• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "src/common/log_util.h"
#include "src/train/optimizer/common/fusion_utils.h"
22 
23 namespace mindspore {
24 namespace opt {
GetMatchNodeIndex(schema::MetaGraphT * graph,const std::unordered_map<std::string,std::shared_ptr<lite::Path>> & matched_path,const std::string & node_name,size_t * node_index)25 STATUS GetMatchNodeIndex(schema::MetaGraphT *graph,
26                          const std::unordered_map<std::string, std::shared_ptr<lite::Path>> &matched_path,
27                          const std::string &node_name, size_t *node_index) {
28   auto node_path_iter = matched_path.find(node_name);
29   MS_CHECK_TRUE_MSG(node_path_iter != matched_path.end(), RET_ERROR, "cannot find node_path");
30   const auto &node_path = node_path_iter->second;
31   MS_CHECK_TRUE_MSG(node_path != nullptr, RET_NULL_PTR, "node_path is empty");
32   *node_index = node_path->nodeIdx;
33   MS_CHECK_TRUE_MSG(*node_index < graph->nodes.size(), RET_ERROR, "node_index is out of range");
34   return RET_OK;
35 }
36 
IsMultiOutputNode(schema::MetaGraphT * graph,size_t out_node_index)37 bool IsMultiOutputNode(schema::MetaGraphT *graph, size_t out_node_index) {
38   uint32_t count = 0;
39   for (auto &node : graph->nodes) {
40     if (std::find(node->inputIndex.begin(), node->inputIndex.end(), out_node_index) != node->inputIndex.end()) {
41       count++;
42     }
43     if (count > 1) {
44       return true;
45     }
46   }
47   return false;
48 }
49 }  // namespace opt
50 }  // namespace mindspore
51