• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <vector>
18 #include <set>
19 #include <algorithm>
20 #include <memory>
21 #include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h"
22 #include "src/common/log_adapter.h"
23 #include "src/common/utils.h"
24 #include "tools/common/graph_util.h"
25 #include "include/errorcode.h"
26 #include "schema/inner/model_generated.h"
27 #include "src/common/log_util.h"
28 
29 namespace mindspore {
30 namespace lite {
GetSubgraphAllTensorIndices(const std::unique_ptr<SubGraphT> & subgraph,schema::MetaGraphT * graph,std::set<uint32_t> * tensors_indices)31 STATUS SubgraphNodePass::GetSubgraphAllTensorIndices(const std::unique_ptr<SubGraphT> &subgraph,
32                                                      schema::MetaGraphT *graph, std::set<uint32_t> *tensors_indices) {
33   for (auto &node_idx : subgraph->nodeIndices) {
34     if (node_idx >= graph->nodes.size()) {
35       MS_LOG(ERROR) << "node_idx: " << node_idx << " bigger than graph->nodes.size(): " << graph->nodes.size();
36       for (auto &cur_subgraph : graph->subGraph) {
37         MS_LOG(ERROR) << cur_subgraph->name << " : " << cur_subgraph->nodeIndices;
38       }
39       return RET_ERROR;
40     }
41     auto &node = graph->nodes.at(node_idx);
42     for (auto &input_idx : node->inputIndex) {
43       tensors_indices->insert(input_idx);
44     }
45     for (auto &output_idx : node->outputIndex) {
46       tensors_indices->insert(output_idx);
47     }
48   }
49   return RET_OK;
50 }
51 
IsNodeInputInSubgraph(const std::set<uint32_t> & tensors_indices,const std::unique_ptr<CNodeT> & node,const std::unique_ptr<SubGraphT> & subgraph)52 bool SubgraphNodePass::IsNodeInputInSubgraph(const std::set<uint32_t> &tensors_indices,
53                                              const std::unique_ptr<CNodeT> &node,
54                                              const std::unique_ptr<SubGraphT> &subgraph) {
55   return std::any_of(node->inputIndex.begin(), node->inputIndex.end(), [&tensors_indices, &subgraph](uint32_t idx) {
56     return (tensors_indices.count(idx) > 0) || IsContain(subgraph->inputIndices, idx);
57   });
58 }
59 
IsNodeOutputInSubgraph(const std::set<uint32_t> & tensors_indices,const std::unique_ptr<CNodeT> & node,const std::unique_ptr<SubGraphT> & subgraph)60 bool SubgraphNodePass::IsNodeOutputInSubgraph(const std::set<uint32_t> &tensors_indices,
61                                               const std::unique_ptr<CNodeT> &node,
62                                               const std::unique_ptr<SubGraphT> &subgraph) {
63   return std::any_of(node->outputIndex.begin(), node->outputIndex.end(), [&tensors_indices, &subgraph](uint32_t idx) {
64     return (tensors_indices.count(idx) > 0) || IsContain(subgraph->outputIndices, idx);
65   });
66 }
67 
DecreaseSubgraphNodeIndices(const size_t & node_idx,schema::MetaGraphT * graph)68 void SubgraphNodePass::DecreaseSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph) {
69   for (auto &subgraph : graph->subGraph) {
70     std::transform(subgraph->nodeIndices.begin(), subgraph->nodeIndices.end(), subgraph->nodeIndices.begin(),
71                    [&node_idx](uint32_t idx) {
72                      if (idx > node_idx) {
73                        return --idx;
74                      }
75                      return idx;
76                    });
77   }
78 }
79 
IncreaseSubgraphNodeIndices(const size_t & node_idx,schema::MetaGraphT * graph)80 void SubgraphNodePass::IncreaseSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph) {
81   for (auto &subgraph : graph->subGraph) {
82     std::transform(subgraph->nodeIndices.begin(), subgraph->nodeIndices.end(), subgraph->nodeIndices.begin(),
83                    [&node_idx](uint32_t idx) {
84                      if (idx >= node_idx) {
85                        return ++idx;
86                      }
87                      return idx;
88                    });
89   }
90 }
91 
Run(schema::MetaGraphT * graph)92 STATUS SubgraphNodePass::Run(schema::MetaGraphT *graph) {
93   CHECK_NULL_RETURN(graph);
94   std::vector<schema::CNodeT *> new_nodes{};
95   std::transform(graph->nodes.begin(), graph->nodes.end(), std::back_inserter(new_nodes),
96                  [](std::unique_ptr<CNodeT> &node) { return node.get(); });
97 
98   for (auto it = old_nodes_.begin(); it != old_nodes_.end();) {
99     if (!IsContain(new_nodes, *it)) {
100       size_t node_idx = it - old_nodes_.begin();
101       for (auto &subgraph : graph->subGraph) {
102         auto node_idx_pos = std::find(subgraph->nodeIndices.begin(), subgraph->nodeIndices.end(), node_idx);
103         if (node_idx_pos != subgraph->nodeIndices.end()) {
104           subgraph->nodeIndices.erase(node_idx_pos);
105           DecreaseSubgraphNodeIndices(node_idx, graph);
106           break;
107         }
108       }
109       it = old_nodes_.erase(it);
110     } else {
111       it++;
112     }
113   }
114 
115   for (uint32_t i = 0; i < new_nodes.size(); i++) {
116     if (!IsContain(old_nodes_, new_nodes[i])) {
117       auto &node = graph->nodes.at(i);
118       std::vector<SubGraphT *> contain_node_input_subgraphs{};
119       std::vector<SubGraphT *> contain_node_output_subgraphs{};
120       std::vector<SubGraphT *> contain_subgraphs{};
121       for (auto &subgraph : graph->subGraph) {
122         std::set<uint32_t> tensors_indices{};
123         int ret = GetSubgraphAllTensorIndices(subgraph, graph, &tensors_indices);
124         if (ret != RET_OK) {
125           MS_LOG(ERROR) << "GetSubgraphAllTensorIndices failed.";
126           return ret;
127         }
128         if (IsNodeInputInSubgraph(tensors_indices, node, subgraph)) {
129           contain_node_input_subgraphs.push_back(subgraph.get());
130         }
131         if (IsNodeOutputInSubgraph(tensors_indices, node, subgraph)) {
132           contain_node_output_subgraphs.push_back(subgraph.get());
133         }
134       }
135       for (auto subgraph : contain_node_input_subgraphs) {
136         if (IsContain(contain_node_output_subgraphs, subgraph)) {
137           contain_subgraphs.emplace_back(subgraph);
138         }
139       }
140       if (contain_subgraphs.size() == 1) {
141         IncreaseSubgraphNodeIndices(i, graph);
142         contain_subgraphs[0]->nodeIndices.push_back(i);
143         continue;
144       }
145       if (contain_node_input_subgraphs.size() == 1 && contain_node_output_subgraphs.empty()) {
146         IncreaseSubgraphNodeIndices(i, graph);
147         contain_node_input_subgraphs[0]->nodeIndices.push_back(i);
148         continue;
149       }
150       if (contain_node_output_subgraphs.size() == 1 && contain_node_input_subgraphs.empty()) {
151         IncreaseSubgraphNodeIndices(i, graph);
152         contain_node_output_subgraphs[0]->nodeIndices.push_back(i);
153         continue;
154       } else {
155         MS_LOG(ERROR) << "Not able to find which subgraph to insert node: " << node->name;
156         return RET_ERROR;
157       }
158     }
159   }
160   return RET_OK;
161 }
162 }  // namespace lite
163 }  // namespace mindspore
164