1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <algorithm>
#include <iterator>
#include <memory>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h"
#include "src/common/log_adapter.h"
#include "src/common/utils.h"
#include "tools/common/graph_util.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"
#include "src/common/log_util.h"
28
29 namespace mindspore {
30 namespace lite {
GetSubgraphAllTensorIndices(const std::unique_ptr<SubGraphT> & subgraph,schema::MetaGraphT * graph,std::set<uint32_t> * tensors_indices)31 STATUS SubgraphNodePass::GetSubgraphAllTensorIndices(const std::unique_ptr<SubGraphT> &subgraph,
32 schema::MetaGraphT *graph, std::set<uint32_t> *tensors_indices) {
33 for (auto &node_idx : subgraph->nodeIndices) {
34 if (node_idx >= graph->nodes.size()) {
35 MS_LOG(ERROR) << "node_idx: " << node_idx << " bigger than graph->nodes.size(): " << graph->nodes.size();
36 for (auto &cur_subgraph : graph->subGraph) {
37 MS_LOG(ERROR) << cur_subgraph->name << " : " << cur_subgraph->nodeIndices;
38 }
39 return RET_ERROR;
40 }
41 auto &node = graph->nodes.at(node_idx);
42 for (auto &input_idx : node->inputIndex) {
43 tensors_indices->insert(input_idx);
44 }
45 for (auto &output_idx : node->outputIndex) {
46 tensors_indices->insert(output_idx);
47 }
48 }
49 return RET_OK;
50 }
51
IsNodeInputInSubgraph(const std::set<uint32_t> & tensors_indices,const std::unique_ptr<CNodeT> & node,const std::unique_ptr<SubGraphT> & subgraph)52 bool SubgraphNodePass::IsNodeInputInSubgraph(const std::set<uint32_t> &tensors_indices,
53 const std::unique_ptr<CNodeT> &node,
54 const std::unique_ptr<SubGraphT> &subgraph) {
55 return std::any_of(node->inputIndex.begin(), node->inputIndex.end(), [&tensors_indices, &subgraph](uint32_t idx) {
56 return (tensors_indices.count(idx) > 0) || IsContain(subgraph->inputIndices, idx);
57 });
58 }
59
IsNodeOutputInSubgraph(const std::set<uint32_t> & tensors_indices,const std::unique_ptr<CNodeT> & node,const std::unique_ptr<SubGraphT> & subgraph)60 bool SubgraphNodePass::IsNodeOutputInSubgraph(const std::set<uint32_t> &tensors_indices,
61 const std::unique_ptr<CNodeT> &node,
62 const std::unique_ptr<SubGraphT> &subgraph) {
63 return std::any_of(node->outputIndex.begin(), node->outputIndex.end(), [&tensors_indices, &subgraph](uint32_t idx) {
64 return (tensors_indices.count(idx) > 0) || IsContain(subgraph->outputIndices, idx);
65 });
66 }
67
DecreaseSubgraphNodeIndices(const size_t & node_idx,schema::MetaGraphT * graph)68 void SubgraphNodePass::DecreaseSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph) {
69 for (auto &subgraph : graph->subGraph) {
70 std::transform(subgraph->nodeIndices.begin(), subgraph->nodeIndices.end(), subgraph->nodeIndices.begin(),
71 [&node_idx](uint32_t idx) {
72 if (idx > node_idx) {
73 return --idx;
74 }
75 return idx;
76 });
77 }
78 }
79
IncreaseSubgraphNodeIndices(const size_t & node_idx,schema::MetaGraphT * graph)80 void SubgraphNodePass::IncreaseSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph) {
81 for (auto &subgraph : graph->subGraph) {
82 std::transform(subgraph->nodeIndices.begin(), subgraph->nodeIndices.end(), subgraph->nodeIndices.begin(),
83 [&node_idx](uint32_t idx) {
84 if (idx >= node_idx) {
85 return ++idx;
86 }
87 return idx;
88 });
89 }
90 }
91
Run(schema::MetaGraphT * graph)92 STATUS SubgraphNodePass::Run(schema::MetaGraphT *graph) {
93 CHECK_NULL_RETURN(graph);
94 std::vector<schema::CNodeT *> new_nodes{};
95 std::transform(graph->nodes.begin(), graph->nodes.end(), std::back_inserter(new_nodes),
96 [](std::unique_ptr<CNodeT> &node) { return node.get(); });
97
98 for (auto it = old_nodes_.begin(); it != old_nodes_.end();) {
99 if (!IsContain(new_nodes, *it)) {
100 size_t node_idx = it - old_nodes_.begin();
101 for (auto &subgraph : graph->subGraph) {
102 auto node_idx_pos = std::find(subgraph->nodeIndices.begin(), subgraph->nodeIndices.end(), node_idx);
103 if (node_idx_pos != subgraph->nodeIndices.end()) {
104 subgraph->nodeIndices.erase(node_idx_pos);
105 DecreaseSubgraphNodeIndices(node_idx, graph);
106 break;
107 }
108 }
109 it = old_nodes_.erase(it);
110 } else {
111 it++;
112 }
113 }
114
115 for (uint32_t i = 0; i < new_nodes.size(); i++) {
116 if (!IsContain(old_nodes_, new_nodes[i])) {
117 auto &node = graph->nodes.at(i);
118 std::vector<SubGraphT *> contain_node_input_subgraphs{};
119 std::vector<SubGraphT *> contain_node_output_subgraphs{};
120 std::vector<SubGraphT *> contain_subgraphs{};
121 for (auto &subgraph : graph->subGraph) {
122 std::set<uint32_t> tensors_indices{};
123 int ret = GetSubgraphAllTensorIndices(subgraph, graph, &tensors_indices);
124 if (ret != RET_OK) {
125 MS_LOG(ERROR) << "GetSubgraphAllTensorIndices failed.";
126 return ret;
127 }
128 if (IsNodeInputInSubgraph(tensors_indices, node, subgraph)) {
129 contain_node_input_subgraphs.push_back(subgraph.get());
130 }
131 if (IsNodeOutputInSubgraph(tensors_indices, node, subgraph)) {
132 contain_node_output_subgraphs.push_back(subgraph.get());
133 }
134 }
135 for (auto subgraph : contain_node_input_subgraphs) {
136 if (IsContain(contain_node_output_subgraphs, subgraph)) {
137 contain_subgraphs.emplace_back(subgraph);
138 }
139 }
140 if (contain_subgraphs.size() == 1) {
141 IncreaseSubgraphNodeIndices(i, graph);
142 contain_subgraphs[0]->nodeIndices.push_back(i);
143 continue;
144 }
145 if (contain_node_input_subgraphs.size() == 1 && contain_node_output_subgraphs.empty()) {
146 IncreaseSubgraphNodeIndices(i, graph);
147 contain_node_input_subgraphs[0]->nodeIndices.push_back(i);
148 continue;
149 }
150 if (contain_node_output_subgraphs.size() == 1 && contain_node_input_subgraphs.empty()) {
151 IncreaseSubgraphNodeIndices(i, graph);
152 contain_node_output_subgraphs[0]->nodeIndices.push_back(i);
153 continue;
154 } else {
155 MS_LOG(ERROR) << "Not able to find which subgraph to insert node: " << node->name;
156 return RET_ERROR;
157 }
158 }
159 }
160 return RET_OK;
161 }
162 } // namespace lite
163 } // namespace mindspore
164