1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "src/common/log_util.h"
#include "src/train/optimizer/common/fusion_utils.h"
22
23 namespace mindspore {
24 namespace opt {
GetMatchNodeIndex(schema::MetaGraphT * graph,const std::unordered_map<std::string,std::shared_ptr<lite::Path>> & matched_path,const std::string & node_name,size_t * node_index)25 STATUS GetMatchNodeIndex(schema::MetaGraphT *graph,
26 const std::unordered_map<std::string, std::shared_ptr<lite::Path>> &matched_path,
27 const std::string &node_name, size_t *node_index) {
28 auto node_path_iter = matched_path.find(node_name);
29 MS_CHECK_TRUE_MSG(node_path_iter != matched_path.end(), RET_ERROR, "cannot find node_path");
30 const auto &node_path = node_path_iter->second;
31 MS_CHECK_TRUE_MSG(node_path != nullptr, RET_NULL_PTR, "node_path is empty");
32 *node_index = node_path->nodeIdx;
33 MS_CHECK_TRUE_MSG(*node_index < graph->nodes.size(), RET_ERROR, "node_index is out of range");
34 return RET_OK;
35 }
36
IsMultiOutputNode(schema::MetaGraphT * graph,size_t out_node_index)37 bool IsMultiOutputNode(schema::MetaGraphT *graph, size_t out_node_index) {
38 uint32_t count = 0;
39 for (auto &node : graph->nodes) {
40 if (std::find(node->inputIndex.begin(), node->inputIndex.end(), out_node_index) != node->inputIndex.end()) {
41 count++;
42 }
43 if (count > 1) {
44 return true;
45 }
46 }
47 return false;
48 }
49 } // namespace opt
50 } // namespace mindspore
51