1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "tools/common/meta_graph_utils.h"
#include <algorithm>
#include <set>
#include <vector>
#include "inner/model_generated.h"
#include "src/common/utils.h"
#include "nnacl/op_base.h"
23 namespace mindspore::lite {
24 namespace {
GetRefCount(schema::MetaGraphT * graphT,uint32_t tensorIdx)25 size_t GetRefCount(schema::MetaGraphT *graphT, uint32_t tensorIdx) {
26 MS_ASSERT(graphT != nullptr);
27 MS_ASSERT(graphT->allTensors.size() > tensorIdx);
28 size_t refCount = 0;
29 for (auto &node : graphT->nodes) {
30 MS_ASSERT(node != nullptr);
31 if (IsContain(node->inputIndex, tensorIdx)) {
32 refCount++;
33 }
34 }
35 return refCount;
36 }
37 } // namespace
38
GetLinkedPostIdx(const schema::MetaGraphT & graphT,const size_t & tensorIdx)39 std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
40 std::vector<size_t> postNodeIdx;
41 for (size_t i = 0; i < graphT.nodes.size(); i++) {
42 auto &oldNode = graphT.nodes.at(i);
43 if (oldNode == nullptr) {
44 continue;
45 }
46 auto inputIndexes = oldNode->inputIndex;
47 if (IsContain<uint32_t>(inputIndexes, tensorIdx)) {
48 postNodeIdx.emplace_back(i);
49 }
50 }
51 return postNodeIdx;
52 }
53
GetLinkedPreIdx(const schema::MetaGraphT & graphT,const size_t & tensorIdx)54 std::vector<size_t> GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
55 std::vector<size_t> preNodeIdx;
56 for (size_t i = 0; i < graphT.nodes.size(); i++) {
57 auto &oldNode = graphT.nodes.at(i);
58 if (oldNode == nullptr) {
59 continue;
60 }
61 auto outputIndexes = oldNode->outputIndex;
62 if (IsContain<uint32_t>(outputIndexes, tensorIdx)) {
63 preNodeIdx.emplace_back(i);
64 }
65 }
66 return preNodeIdx;
67 }
68
GetInputNodeIdx(const schema::MetaGraphT & graphT,const schema::CNodeT & node,const int inputIndexIdx)69 std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,
70 const int inputIndexIdx) {
71 std::vector<uint32_t> inputIndexes;
72 if (inputIndexIdx == -1) {
73 inputIndexes = node.inputIndex;
74 } else {
75 MS_ASSERT(node.inputIndex.size() > static_cast<uint32_t>(inputIndexIdx));
76 inputIndexes.emplace_back(node.inputIndex.at(inputIndexIdx));
77 }
78 std::set<size_t> inputNodeIdx;
79 for (uint32_t inputIdx : inputIndexes) {
80 auto linkedPreIdx = GetLinkedPreIdx(graphT, inputIdx);
81 inputNodeIdx.insert(linkedPreIdx.begin(), linkedPreIdx.end());
82 }
83 std::vector<size_t> ret;
84 ret.insert(ret.end(), inputNodeIdx.begin(), inputNodeIdx.end());
85 return ret;
86 }
87
GetInputNodeIdx(const schema::MetaGraphT & graphT,const size_t & nodeIdx,const int inputIndexIdx)88 std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, const int inputIndexIdx) {
89 return GetInputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), inputIndexIdx);
90 }
91
GetOutputNodeIdx(const schema::MetaGraphT & graphT,const schema::CNodeT & node,const int outputIndexIdx)92 std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,
93 const int outputIndexIdx) {
94 std::vector<uint32_t> outputIndexes;
95 if (outputIndexIdx == -1) {
96 outputIndexes = node.outputIndex;
97 } else {
98 MS_ASSERT(node.outputIndex.size() > static_cast<uint32_t>(outputIndexIdx));
99 outputIndexes.emplace_back(node.outputIndex.at(outputIndexIdx));
100 }
101 std::set<size_t> outputNodeIdx;
102 for (uint32_t outputIdx : outputIndexes) {
103 auto linkedPostIdx = GetLinkedPostIdx(graphT, outputIdx);
104 outputNodeIdx.insert(linkedPostIdx.begin(), linkedPostIdx.end());
105 }
106 std::vector<size_t> ret;
107 ret.insert(ret.end(), outputNodeIdx.begin(), outputNodeIdx.end());
108 return ret;
109 }
110
GetOutputNodeIdx(const schema::MetaGraphT & graphT,const size_t & nodeIdx,const int outputIndexIdx)111 std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx,
112 const int outputIndexIdx) {
113 return GetOutputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), outputIndexIdx);
114 }
115
ReplaceOutput(const uint32_t & old_index,const uint32_t & new_index,schema::MetaGraphT * graphT)116 void ReplaceOutput(const uint32_t &old_index, const uint32_t &new_index, schema::MetaGraphT *graphT) {
117 std::replace_if(
118 std::begin(graphT->outputIndex), std::end(graphT->outputIndex),
119 [&old_index](uint32_t outputIndex) { return outputIndex == old_index; }, new_index);
120
121 for (auto &subGraph : graphT->subGraph) {
122 std::replace_if(
123 std::begin(subGraph->outputIndices), std::end(subGraph->outputIndices),
124 [&old_index](uint32_t outputIndex) { return outputIndex == old_index; }, new_index);
125 }
126 }
127
UpdateNodeIndex(schema::CNodeT * node,uint32_t deleteIdx)128 STATUS UpdateNodeIndex(schema::CNodeT *node, uint32_t deleteIdx) {
129 MS_ASSERT(node != nullptr);
130 for (auto inIdxIt = node->inputIndex.begin(); inIdxIt != node->inputIndex.end();) {
131 if (*inIdxIt == deleteIdx) {
132 inIdxIt = node->inputIndex.erase(inIdxIt);
133 } else {
134 if (*inIdxIt > deleteIdx) {
135 (*inIdxIt)--;
136 }
137 inIdxIt++;
138 }
139 }
140 // update nodes output indexes
141 for (auto outIdxIt = node->outputIndex.begin(); outIdxIt != node->outputIndex.end();) {
142 if (*outIdxIt == deleteIdx) {
143 outIdxIt = node->outputIndex.erase(outIdxIt);
144 } else {
145 if (*outIdxIt > deleteIdx) {
146 (*outIdxIt)--;
147 }
148 outIdxIt++;
149 }
150 }
151 return RET_OK;
152 }
153
RemoveTensor(schema::MetaGraphT * graphT,std::vector<uint32_t> toDeleteTensorIdxes,bool forceDelete)154 STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTensorIdxes, bool forceDelete) {
155 MS_CHECK_TRUE_MSG(graphT != nullptr, RET_NULL_PTR, "graphT is nullptr");
156 for (auto iter = toDeleteTensorIdxes.begin(); iter != toDeleteTensorIdxes.end();) {
157 uint32_t deleteIdx = *iter;
158 if (!forceDelete) {
159 if (GetRefCount(graphT, deleteIdx) > 1) {
160 iter++;
161 continue;
162 }
163 }
164 // update graph input indices
165 for (auto gInIdx = graphT->inputIndex.begin(); gInIdx != graphT->inputIndex.end(); gInIdx++) {
166 if (*gInIdx > deleteIdx) {
167 (*gInIdx)--;
168 }
169 }
170 // update graph output indices
171 for (auto gOutIdx = graphT->outputIndex.begin(); gOutIdx != graphT->outputIndex.end(); gOutIdx++) {
172 if (*gOutIdx > deleteIdx) {
173 (*gOutIdx)--;
174 }
175 }
176
177 for (auto &subgraph : graphT->subGraph) {
178 // update subgraph input indices
179 for (auto gInIdx = subgraph->inputIndices.begin(); gInIdx != subgraph->inputIndices.end(); gInIdx++) {
180 if (*gInIdx > deleteIdx) {
181 (*gInIdx)--;
182 }
183 }
184 // update subgraph output indices
185 for (auto gOutIdx = subgraph->outputIndices.begin(); gOutIdx != subgraph->outputIndices.end(); gOutIdx++) {
186 if (*gOutIdx > deleteIdx) {
187 (*gOutIdx)--;
188 }
189 }
190 // update subgraph output indices
191 for (auto idx = subgraph->tensorIndices.begin(); idx != subgraph->tensorIndices.end(); idx++) {
192 if (*idx > deleteIdx) {
193 (*idx)--;
194 }
195 }
196 }
197
198 // update nodes indexes
199 for (auto node_iter = graphT->nodes.begin(); node_iter != graphT->nodes.end(); node_iter++) {
200 // update nodes input indexes
201 UpdateNodeIndex((*node_iter).get(), deleteIdx);
202 }
203 // update deleteTensorIdx
204 for (auto selfIt = toDeleteTensorIdxes.begin(); selfIt != toDeleteTensorIdxes.end(); selfIt++) {
205 if (*selfIt > deleteIdx) {
206 (*selfIt)--;
207 }
208 }
209 graphT->allTensors.erase(graphT->allTensors.begin() + deleteIdx);
210 iter = toDeleteTensorIdxes.erase(iter);
211 }
212 return RET_OK;
213 }
214
IsolateNode(schema::MetaGraphT * graphT,schema::CNodeT * node)215 STATUS IsolateNode(schema::MetaGraphT *graphT, schema::CNodeT *node) {
216 MS_CHECK_TRUE_MSG(graphT != nullptr, RET_NULL_PTR, "graphT is nullptr");
217 MS_CHECK_TRUE_MSG(node != nullptr, RET_NULL_PTR, "node is nullptr");
218 size_t nodeIdx = 0;
219 for (size_t i = 0; i < graphT->nodes.size(); i++) {
220 auto &inNode = graphT->nodes.at(i);
221 MS_CHECK_TRUE_MSG(inNode != nullptr, RET_NULL_PTR, "inNode is nullptr");
222 if (inNode->name == node->name) {
223 nodeIdx = i;
224 break;
225 }
226 }
227 auto inputTensorIdxes = node->inputIndex;
228 auto outputTensorIdxes = node->outputIndex;
229 if (inputTensorIdxes.empty()) {
230 MS_LOG(ERROR) << "Node " << node->name.c_str() << "should has no inputs";
231 return RET_ERROR;
232 }
233 if (outputTensorIdxes.size() != 1) {
234 MS_LOG(ERROR) << "FakeQuantNode " << node->name.c_str()
235 << "should has 1 output, in fact: " << outputTensorIdxes.size();
236 return RET_ERROR;
237 }
238 auto inDataTensorIdx = inputTensorIdxes.front();
239 auto outDataTensorIdx = outputTensorIdxes.front();
240
241 MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx);
242 ReplaceOutput(outDataTensorIdx, inDataTensorIdx, graphT);
243
244 // find poseNode
245 auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
246 for (auto postNodeIdx : postNodeIdxes) {
247 MS_ASSERT(graphT->nodes.size() > postNodeIdx);
248 auto &postNode = graphT->nodes.at(postNodeIdx);
249 MS_CHECK_TRUE_MSG(postNode != nullptr, RET_NULL_PTR, "postNode is nullptr");
250 for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
251 if (*iter == outDataTensorIdx) {
252 *iter = inDataTensorIdx;
253 break;
254 }
255 }
256 }
257 RemoveTensor(graphT, outputTensorIdxes);
258 node->inputIndex.clear();
259 node->outputIndex.clear();
260 return RET_OK;
261 }
262
IsolateOneWayNode(schema::MetaGraphT * graphT,size_t nodeIdx,bool removeTensor)263 STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool removeTensor) {
264 MS_CHECK_TRUE_MSG(graphT != nullptr, RET_NULL_PTR, "graphT is nullptr");
265 if (graphT->nodes.size() <= nodeIdx) {
266 MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx;
267 return RET_PARAM_INVALID;
268 }
269 schema::CNodeT *node = graphT->nodes.at(nodeIdx).get();
270 if (node == nullptr) {
271 MS_LOG(ERROR) << "node is null";
272 return RET_NULL_PTR;
273 }
274 auto inputTensorIdxes = node->inputIndex;
275 auto outputTensorIdxes = node->outputIndex;
276 auto preNodeIdxes = GetInputNodeIdx(*graphT, nodeIdx);
277 if (preNodeIdxes.size() > 1 || outputTensorIdxes.size() > 1) {
278 MS_LOG(ERROR) << "Only support node who has no more than one input and one output";
279 return RET_ERROR;
280 }
281 if (inputTensorIdxes.empty()) {
282 MS_LOG(ERROR) << "Error, " << nodeIdx << "th node has no input tensor";
283 return RET_ERROR;
284 }
285 auto inDataTensorIdx = inputTensorIdxes.front();
286 if (!outputTensorIdxes.empty()) {
287 auto outDataTensorIdx = outputTensorIdxes.front();
288 MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx);
289 MS_ASSERT(graphT->allTensors.at(inDataTensorIdx) != nullptr);
290 ReplaceOutput(outDataTensorIdx, inDataTensorIdx, graphT);
291
292 // find poseNode
293 auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
294 for (auto postNodeIdx : postNodeIdxes) {
295 MS_ASSERT(graphT->nodes.size() > postNodeIdx);
296 auto &postNode = graphT->nodes.at(postNodeIdx);
297 MS_CHECK_TRUE_MSG(postNode != nullptr, RET_NULL_PTR, "postNode is nullptr");
298 for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
299 if (*iter == outDataTensorIdx) {
300 *iter = inDataTensorIdx;
301 break;
302 }
303 }
304 }
305 }
306 if (removeTensor) {
307 // now all node's outputTensors are useless
308 // remove all node's outputTensors
309 auto status = RemoveTensor(graphT, outputTensorIdxes);
310 if (status != RET_OK) {
311 MS_LOG(ERROR) << "RemoveOutputTensors of node " << node->name.c_str() << "failed";
312 return RET_ERROR;
313 }
314 }
315 node->inputIndex.clear();
316 node->outputIndex.clear();
317 return RET_OK;
318 }
319
IsolateOneWayNode(schema::MetaGraphT * graph,size_t subGraphIdx,size_t nodeIdx,bool removeTensor)320 STATUS IsolateOneWayNode(schema::MetaGraphT *graph, size_t subGraphIdx, size_t nodeIdx, bool removeTensor) {
321 MS_CHECK_TRUE_MSG(graph != nullptr, RET_NULL_PTR, "graph is nullptr");
322 return IsolateOneWayNode(graph, nodeIdx, removeTensor);
323 }
324
IsolateOneWayNode(schema::MetaGraphT * graphT,schema::CNodeT * node,bool removeTensor)325 STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, schema::CNodeT *node, bool removeTensor) {
326 MS_CHECK_TRUE_MSG(graphT != nullptr, RET_NULL_PTR, "graphT is nullptr");
327 MS_CHECK_TRUE_MSG(node != nullptr, RET_NULL_PTR, "node is nullptr");
328 bool isSubNode = false;
329 size_t nodeIdx = 0;
330 for (size_t i = 0; i < graphT->nodes.size(); i++) {
331 auto &inNode = graphT->nodes.at(i);
332 MS_CHECK_TRUE_MSG(inNode != nullptr, RET_NULL_PTR, "inNode is nullptr");
333 if (inNode->name == node->name) {
334 isSubNode = true;
335 nodeIdx = i;
336 break;
337 }
338 }
339 if (!isSubNode) {
340 MS_LOG(ERROR) << "Node " << node->name.c_str() << "is not in graphT " << graphT->name.c_str();
341 return RET_PARAM_INVALID;
342 } else {
343 return IsolateOneWayNode(graphT, nodeIdx, removeTensor);
344 }
345 }
346 } // namespace mindspore::lite
347