• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <vector>
18 #include <set>
19 #include <algorithm>
20 #include <memory>
21 #include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h"
22 #include "src/common/log_adapter.h"
23 #include "src/common/utils.h"
24 #include "include/errorcode.h"
25 #include "schema/inner/model_generated.h"
26 #include "src/common/log_util.h"
27 
28 namespace mindspore {
29 namespace lite {
GetSubgraphAllTensorIndices(const std::unique_ptr<SubGraphT> & subgraph,schema::MetaGraphT * graph,std::set<uint32_t> * tensors_indices)30 STATUS SubgraphNodePass::GetSubgraphAllTensorIndices(const std::unique_ptr<SubGraphT> &subgraph,
31                                                      schema::MetaGraphT *graph, std::set<uint32_t> *tensors_indices) {
32   for (auto &node_idx : subgraph->nodeIndices) {
33     if (node_idx >= graph->nodes.size()) {
34       MS_LOG(ERROR) << "node_idx: " << node_idx << " bigger than graph->nodes.size(): " << graph->nodes.size();
35       for (auto &cur_subgraph : graph->subGraph) {
36         MS_LOG(ERROR) << cur_subgraph->name << " : " << cur_subgraph->nodeIndices;
37       }
38       return RET_ERROR;
39     }
40     auto &node = graph->nodes.at(node_idx);
41     for (auto &input_idx : node->inputIndex) {
42       tensors_indices->insert(input_idx);
43     }
44     for (auto &output_idx : node->outputIndex) {
45       tensors_indices->insert(output_idx);
46     }
47   }
48   return RET_OK;
49 }
50 
IsNodeInputInSubgraph(const std::set<uint32_t> & tensors_indices,const std::unique_ptr<CNodeT> & node,const std::unique_ptr<SubGraphT> & subgraph)51 bool SubgraphNodePass::IsNodeInputInSubgraph(const std::set<uint32_t> &tensors_indices,
52                                              const std::unique_ptr<CNodeT> &node,
53                                              const std::unique_ptr<SubGraphT> &subgraph) {
54   return std::any_of(node->inputIndex.begin(), node->inputIndex.end(), [&tensors_indices, &subgraph](uint32_t idx) {
55     return (tensors_indices.count(idx) > 0) || IsContain(subgraph->inputIndices, idx);
56   });
57 }
58 
IsNodeOutputInSubgraph(const std::set<uint32_t> & tensors_indices,const std::unique_ptr<CNodeT> & node,const std::unique_ptr<SubGraphT> & subgraph)59 bool SubgraphNodePass::IsNodeOutputInSubgraph(const std::set<uint32_t> &tensors_indices,
60                                               const std::unique_ptr<CNodeT> &node,
61                                               const std::unique_ptr<SubGraphT> &subgraph) {
62   return std::any_of(node->outputIndex.begin(), node->outputIndex.end(), [&tensors_indices, &subgraph](uint32_t idx) {
63     return (tensors_indices.count(idx) > 0) || IsContain(subgraph->outputIndices, idx);
64   });
65 }
66 
DecreaseSubgraphNodeIndices(const size_t & node_idx,const schema::MetaGraphT & graph)67 void SubgraphNodePass::DecreaseSubgraphNodeIndices(const size_t &node_idx, const schema::MetaGraphT &graph) {
68   for (auto &subgraph : graph.subGraph) {
69     std::transform(subgraph->nodeIndices.begin(), subgraph->nodeIndices.end(), subgraph->nodeIndices.begin(),
70                    [&node_idx](uint32_t idx) {
71                      if (idx > node_idx) {
72                        return --idx;
73                      }
74                      return idx;
75                    });
76   }
77 }
78 
IncreaseSubgraphNodeIndices(const size_t & node_idx,const schema::MetaGraphT & graph)79 void SubgraphNodePass::IncreaseSubgraphNodeIndices(const size_t &node_idx, const schema::MetaGraphT &graph) {
80   for (auto &subgraph : graph.subGraph) {
81     std::transform(subgraph->nodeIndices.begin(), subgraph->nodeIndices.end(), subgraph->nodeIndices.begin(),
82                    [&node_idx](uint32_t idx) {
83                      if (idx >= node_idx) {
84                        return ++idx;
85                      }
86                      return idx;
87                    });
88   }
89 }
90 
Run(schema::MetaGraphT * graph)91 STATUS SubgraphNodePass::Run(schema::MetaGraphT *graph) {
92   CHECK_NULL_RETURN(graph);
93   std::vector<schema::CNodeT *> new_nodes{};
94   std::transform(graph->nodes.begin(), graph->nodes.end(), std::back_inserter(new_nodes),
95                  [](const std::unique_ptr<CNodeT> &node) { return node.get(); });
96 
97   for (auto it = old_nodes_.begin(); it != old_nodes_.end();) {
98     if (!IsContain(new_nodes, *it)) {
99       size_t node_idx = static_cast<size_t>(it - old_nodes_.begin());
100       for (auto &subgraph : graph->subGraph) {
101         auto node_idx_pos = std::find(subgraph->nodeIndices.begin(), subgraph->nodeIndices.end(), node_idx);
102         if (node_idx_pos != subgraph->nodeIndices.end()) {
103           subgraph->nodeIndices.erase(node_idx_pos);
104           DecreaseSubgraphNodeIndices(node_idx, *graph);
105           break;
106         }
107       }
108       it = old_nodes_.erase(it);
109     } else {
110       it++;
111     }
112   }
113 
114   for (uint32_t i = 0; i < new_nodes.size(); i++) {
115     if (!IsContain(old_nodes_, new_nodes[i])) {
116       auto &node = graph->nodes.at(i);
117       std::vector<SubGraphT *> contain_node_input_subgraphs{};
118       std::vector<SubGraphT *> contain_node_output_subgraphs{};
119       std::vector<SubGraphT *> contain_subgraphs{};
120       for (auto &subgraph : graph->subGraph) {
121         std::set<uint32_t> tensors_indices{};
122         int ret = GetSubgraphAllTensorIndices(subgraph, graph, &tensors_indices);
123         if (ret != RET_OK) {
124           MS_LOG(ERROR) << "GetSubgraphAllTensorIndices failed.";
125           return ret;
126         }
127         if (IsNodeInputInSubgraph(tensors_indices, node, subgraph)) {
128           contain_node_input_subgraphs.push_back(subgraph.get());
129         }
130         if (IsNodeOutputInSubgraph(tensors_indices, node, subgraph)) {
131           contain_node_output_subgraphs.push_back(subgraph.get());
132         }
133       }
134       for (auto subgraph : contain_node_input_subgraphs) {
135         if (IsContain(contain_node_output_subgraphs, subgraph)) {
136           contain_subgraphs.emplace_back(subgraph);
137         }
138       }
139       if (contain_subgraphs.size() == 1) {
140         IncreaseSubgraphNodeIndices(i, *graph);
141         contain_subgraphs[0]->nodeIndices.push_back(i);
142         continue;
143       }
144       if (contain_node_input_subgraphs.size() == 1 && contain_node_output_subgraphs.empty()) {
145         IncreaseSubgraphNodeIndices(i, *graph);
146         contain_node_input_subgraphs[0]->nodeIndices.push_back(i);
147         continue;
148       }
149       if (contain_node_output_subgraphs.size() == 1 && contain_node_input_subgraphs.empty()) {
150         IncreaseSubgraphNodeIndices(i, *graph);
151         contain_node_output_subgraphs[0]->nodeIndices.push_back(i);
152         continue;
153       } else {
154         MS_LOG(ERROR) << "Not able to find which subgraph to insert node: " << node->name;
155         return RET_ERROR;
156       }
157     }
158   }
159   return RET_OK;
160 }
161 }  // namespace lite
162 }  // namespace mindspore
163