1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <vector>
18 #include <set>
19 #include <algorithm>
20 #include <memory>
21 #include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h"
22 #include "src/common/log_adapter.h"
23 #include "src/common/utils.h"
24 #include "include/errorcode.h"
25 #include "schema/inner/model_generated.h"
26 #include "src/common/log_util.h"
27
28 namespace mindspore {
29 namespace lite {
GetSubgraphAllTensorIndices(const std::unique_ptr<SubGraphT> & subgraph,schema::MetaGraphT * graph,std::set<uint32_t> * tensors_indices)30 STATUS SubgraphNodePass::GetSubgraphAllTensorIndices(const std::unique_ptr<SubGraphT> &subgraph,
31 schema::MetaGraphT *graph, std::set<uint32_t> *tensors_indices) {
32 for (auto &node_idx : subgraph->nodeIndices) {
33 if (node_idx >= graph->nodes.size()) {
34 MS_LOG(ERROR) << "node_idx: " << node_idx << " bigger than graph->nodes.size(): " << graph->nodes.size();
35 for (auto &cur_subgraph : graph->subGraph) {
36 MS_LOG(ERROR) << cur_subgraph->name << " : " << cur_subgraph->nodeIndices;
37 }
38 return RET_ERROR;
39 }
40 auto &node = graph->nodes.at(node_idx);
41 for (auto &input_idx : node->inputIndex) {
42 tensors_indices->insert(input_idx);
43 }
44 for (auto &output_idx : node->outputIndex) {
45 tensors_indices->insert(output_idx);
46 }
47 }
48 return RET_OK;
49 }
50
IsNodeInputInSubgraph(const std::set<uint32_t> & tensors_indices,const std::unique_ptr<CNodeT> & node,const std::unique_ptr<SubGraphT> & subgraph)51 bool SubgraphNodePass::IsNodeInputInSubgraph(const std::set<uint32_t> &tensors_indices,
52 const std::unique_ptr<CNodeT> &node,
53 const std::unique_ptr<SubGraphT> &subgraph) {
54 return std::any_of(node->inputIndex.begin(), node->inputIndex.end(), [&tensors_indices, &subgraph](uint32_t idx) {
55 return (tensors_indices.count(idx) > 0) || IsContain(subgraph->inputIndices, idx);
56 });
57 }
58
IsNodeOutputInSubgraph(const std::set<uint32_t> & tensors_indices,const std::unique_ptr<CNodeT> & node,const std::unique_ptr<SubGraphT> & subgraph)59 bool SubgraphNodePass::IsNodeOutputInSubgraph(const std::set<uint32_t> &tensors_indices,
60 const std::unique_ptr<CNodeT> &node,
61 const std::unique_ptr<SubGraphT> &subgraph) {
62 return std::any_of(node->outputIndex.begin(), node->outputIndex.end(), [&tensors_indices, &subgraph](uint32_t idx) {
63 return (tensors_indices.count(idx) > 0) || IsContain(subgraph->outputIndices, idx);
64 });
65 }
66
DecreaseSubgraphNodeIndices(const size_t & node_idx,const schema::MetaGraphT & graph)67 void SubgraphNodePass::DecreaseSubgraphNodeIndices(const size_t &node_idx, const schema::MetaGraphT &graph) {
68 for (auto &subgraph : graph.subGraph) {
69 std::transform(subgraph->nodeIndices.begin(), subgraph->nodeIndices.end(), subgraph->nodeIndices.begin(),
70 [&node_idx](uint32_t idx) {
71 if (idx > node_idx) {
72 return --idx;
73 }
74 return idx;
75 });
76 }
77 }
78
IncreaseSubgraphNodeIndices(const size_t & node_idx,const schema::MetaGraphT & graph)79 void SubgraphNodePass::IncreaseSubgraphNodeIndices(const size_t &node_idx, const schema::MetaGraphT &graph) {
80 for (auto &subgraph : graph.subGraph) {
81 std::transform(subgraph->nodeIndices.begin(), subgraph->nodeIndices.end(), subgraph->nodeIndices.begin(),
82 [&node_idx](uint32_t idx) {
83 if (idx >= node_idx) {
84 return ++idx;
85 }
86 return idx;
87 });
88 }
89 }
90
Run(schema::MetaGraphT * graph)91 STATUS SubgraphNodePass::Run(schema::MetaGraphT *graph) {
92 CHECK_NULL_RETURN(graph);
93 std::vector<schema::CNodeT *> new_nodes{};
94 std::transform(graph->nodes.begin(), graph->nodes.end(), std::back_inserter(new_nodes),
95 [](const std::unique_ptr<CNodeT> &node) { return node.get(); });
96
97 for (auto it = old_nodes_.begin(); it != old_nodes_.end();) {
98 if (!IsContain(new_nodes, *it)) {
99 size_t node_idx = static_cast<size_t>(it - old_nodes_.begin());
100 for (auto &subgraph : graph->subGraph) {
101 auto node_idx_pos = std::find(subgraph->nodeIndices.begin(), subgraph->nodeIndices.end(), node_idx);
102 if (node_idx_pos != subgraph->nodeIndices.end()) {
103 subgraph->nodeIndices.erase(node_idx_pos);
104 DecreaseSubgraphNodeIndices(node_idx, *graph);
105 break;
106 }
107 }
108 it = old_nodes_.erase(it);
109 } else {
110 it++;
111 }
112 }
113
114 for (uint32_t i = 0; i < new_nodes.size(); i++) {
115 if (!IsContain(old_nodes_, new_nodes[i])) {
116 auto &node = graph->nodes.at(i);
117 std::vector<SubGraphT *> contain_node_input_subgraphs{};
118 std::vector<SubGraphT *> contain_node_output_subgraphs{};
119 std::vector<SubGraphT *> contain_subgraphs{};
120 for (auto &subgraph : graph->subGraph) {
121 std::set<uint32_t> tensors_indices{};
122 int ret = GetSubgraphAllTensorIndices(subgraph, graph, &tensors_indices);
123 if (ret != RET_OK) {
124 MS_LOG(ERROR) << "GetSubgraphAllTensorIndices failed.";
125 return ret;
126 }
127 if (IsNodeInputInSubgraph(tensors_indices, node, subgraph)) {
128 contain_node_input_subgraphs.push_back(subgraph.get());
129 }
130 if (IsNodeOutputInSubgraph(tensors_indices, node, subgraph)) {
131 contain_node_output_subgraphs.push_back(subgraph.get());
132 }
133 }
134 for (auto subgraph : contain_node_input_subgraphs) {
135 if (IsContain(contain_node_output_subgraphs, subgraph)) {
136 contain_subgraphs.emplace_back(subgraph);
137 }
138 }
139 if (contain_subgraphs.size() == 1) {
140 IncreaseSubgraphNodeIndices(i, *graph);
141 contain_subgraphs[0]->nodeIndices.push_back(i);
142 continue;
143 }
144 if (contain_node_input_subgraphs.size() == 1 && contain_node_output_subgraphs.empty()) {
145 IncreaseSubgraphNodeIndices(i, *graph);
146 contain_node_input_subgraphs[0]->nodeIndices.push_back(i);
147 continue;
148 }
149 if (contain_node_output_subgraphs.size() == 1 && contain_node_input_subgraphs.empty()) {
150 IncreaseSubgraphNodeIndices(i, *graph);
151 contain_node_output_subgraphs[0]->nodeIndices.push_back(i);
152 continue;
153 } else {
154 MS_LOG(ERROR) << "Not able to find which subgraph to insert node: " << node->name;
155 return RET_ERROR;
156 }
157 }
158 }
159 return RET_OK;
160 }
161 } // namespace lite
162 } // namespace mindspore
163