• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "tools/common/meta_graph_utils.h"
#include <algorithm>
#include <set>
#include <vector>
#include "inner/model_generated.h"
#include "src/common/utils.h"
#include "nnacl/op_base.h"
23 namespace mindspore::lite {
24 namespace {
GetRefCount(schema::MetaGraphT * graphT,uint32_t tensorIdx)25 size_t GetRefCount(schema::MetaGraphT *graphT, uint32_t tensorIdx) {
26   MS_ASSERT(graphT != nullptr);
27   MS_ASSERT(graphT->allTensors.size() > tensorIdx);
28   size_t refCount = 0;
29   for (auto &node : graphT->nodes) {
30     MS_ASSERT(node != nullptr);
31     if (IsContain(node->inputIndex, tensorIdx)) {
32       refCount++;
33     }
34   }
35   return refCount;
36 }
37 }  // namespace
38 
GetLinkedPostIdx(const schema::MetaGraphT & graphT,const size_t & tensorIdx)39 std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
40   std::vector<size_t> postNodeIdx;
41   for (size_t i = 0; i < graphT.nodes.size(); i++) {
42     auto &oldNode = graphT.nodes.at(i);
43     if (oldNode == nullptr) {
44       continue;
45     }
46     auto inputIndexes = oldNode->inputIndex;
47     if (IsContain<uint32_t>(inputIndexes, tensorIdx)) {
48       postNodeIdx.emplace_back(i);
49     }
50   }
51   return postNodeIdx;
52 }
53 
GetLinkedPreIdx(const schema::MetaGraphT & graphT,const size_t & tensorIdx)54 std::vector<size_t> GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
55   std::vector<size_t> preNodeIdx;
56   for (size_t i = 0; i < graphT.nodes.size(); i++) {
57     auto &oldNode = graphT.nodes.at(i);
58     if (oldNode == nullptr) {
59       continue;
60     }
61     auto outputIndexes = oldNode->outputIndex;
62     if (IsContain<uint32_t>(outputIndexes, tensorIdx)) {
63       preNodeIdx.emplace_back(i);
64     }
65   }
66   return preNodeIdx;
67 }
68 
GetInputNodeIdx(const schema::MetaGraphT & graphT,const schema::CNodeT & node,const int inputIndexIdx)69 std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,
70                                     const int inputIndexIdx) {
71   std::vector<uint32_t> inputIndexes;
72   if (inputIndexIdx == -1) {
73     inputIndexes = node.inputIndex;
74   } else {
75     MS_ASSERT(node.inputIndex.size() > static_cast<uint32_t>(inputIndexIdx));
76     inputIndexes.emplace_back(node.inputIndex.at(inputIndexIdx));
77   }
78   std::set<size_t> inputNodeIdx;
79   for (uint32_t inputIdx : inputIndexes) {
80     auto linkedPreIdx = GetLinkedPreIdx(graphT, inputIdx);
81     inputNodeIdx.insert(linkedPreIdx.begin(), linkedPreIdx.end());
82   }
83   std::vector<size_t> ret;
84   ret.insert(ret.end(), inputNodeIdx.begin(), inputNodeIdx.end());
85   return ret;
86 }
87 
GetInputNodeIdx(const schema::MetaGraphT & graphT,const size_t & nodeIdx,const int inputIndexIdx)88 std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, const int inputIndexIdx) {
89   return GetInputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), inputIndexIdx);
90 }
91 
GetOutputNodeIdx(const schema::MetaGraphT & graphT,const schema::CNodeT & node,const int outputIndexIdx)92 std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const schema::CNodeT &node,
93                                      const int outputIndexIdx) {
94   std::vector<uint32_t> outputIndexes;
95   if (outputIndexIdx == -1) {
96     outputIndexes = node.outputIndex;
97   } else {
98     MS_ASSERT(node.outputIndex.size() > static_cast<uint32_t>(outputIndexIdx));
99     outputIndexes.emplace_back(node.outputIndex.at(outputIndexIdx));
100   }
101   std::set<size_t> outputNodeIdx;
102   for (uint32_t outputIdx : outputIndexes) {
103     auto linkedPostIdx = GetLinkedPostIdx(graphT, outputIdx);
104     outputNodeIdx.insert(linkedPostIdx.begin(), linkedPostIdx.end());
105   }
106   std::vector<size_t> ret;
107   ret.insert(ret.end(), outputNodeIdx.begin(), outputNodeIdx.end());
108   return ret;
109 }
110 
GetOutputNodeIdx(const schema::MetaGraphT & graphT,const size_t & nodeIdx,const int outputIndexIdx)111 std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx,
112                                      const int outputIndexIdx) {
113   return GetOutputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), outputIndexIdx);
114 }
115 
ReplaceOutput(const uint32_t & old_index,const uint32_t & new_index,schema::MetaGraphT * graphT)116 void ReplaceOutput(const uint32_t &old_index, const uint32_t &new_index, schema::MetaGraphT *graphT) {
117   std::replace_if(
118     std::begin(graphT->outputIndex), std::end(graphT->outputIndex),
119     [&old_index](uint32_t outputIndex) { return outputIndex == old_index; }, new_index);
120 
121   for (auto &subGraph : graphT->subGraph) {
122     std::replace_if(
123       std::begin(subGraph->outputIndices), std::end(subGraph->outputIndices),
124       [&old_index](uint32_t outputIndex) { return outputIndex == old_index; }, new_index);
125   }
126 }
127 
UpdateNodeIndex(schema::CNodeT * node,uint32_t deleteIdx)128 STATUS UpdateNodeIndex(schema::CNodeT *node, uint32_t deleteIdx) {
129   MS_ASSERT(node != nullptr);
130   for (auto inIdxIt = node->inputIndex.begin(); inIdxIt != node->inputIndex.end();) {
131     if (*inIdxIt == deleteIdx) {
132       inIdxIt = node->inputIndex.erase(inIdxIt);
133     } else {
134       if (*inIdxIt > deleteIdx) {
135         (*inIdxIt)--;
136       }
137       inIdxIt++;
138     }
139   }
140   // update nodes output indexes
141   for (auto outIdxIt = node->outputIndex.begin(); outIdxIt != node->outputIndex.end();) {
142     if (*outIdxIt == deleteIdx) {
143       outIdxIt = node->outputIndex.erase(outIdxIt);
144     } else {
145       if (*outIdxIt > deleteIdx) {
146         (*outIdxIt)--;
147       }
148       outIdxIt++;
149     }
150   }
151   return RET_OK;
152 }
153 
RemoveTensor(schema::MetaGraphT * graphT,std::vector<uint32_t> toDeleteTensorIdxes,bool forceDelete)154 STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTensorIdxes, bool forceDelete) {
155   MS_CHECK_TRUE_MSG(graphT != nullptr, RET_NULL_PTR, "graphT is nullptr");
156   for (auto iter = toDeleteTensorIdxes.begin(); iter != toDeleteTensorIdxes.end();) {
157     uint32_t deleteIdx = *iter;
158     if (!forceDelete) {
159       if (GetRefCount(graphT, deleteIdx) > 1) {
160         iter++;
161         continue;
162       }
163     }
164     // update graph input indices
165     for (auto gInIdx = graphT->inputIndex.begin(); gInIdx != graphT->inputIndex.end(); gInIdx++) {
166       if (*gInIdx > deleteIdx) {
167         (*gInIdx)--;
168       }
169     }
170     // update graph output indices
171     for (auto gOutIdx = graphT->outputIndex.begin(); gOutIdx != graphT->outputIndex.end(); gOutIdx++) {
172       if (*gOutIdx > deleteIdx) {
173         (*gOutIdx)--;
174       }
175     }
176 
177     for (auto &subgraph : graphT->subGraph) {
178       // update subgraph input indices
179       for (auto gInIdx = subgraph->inputIndices.begin(); gInIdx != subgraph->inputIndices.end(); gInIdx++) {
180         if (*gInIdx > deleteIdx) {
181           (*gInIdx)--;
182         }
183       }
184       // update subgraph output indices
185       for (auto gOutIdx = subgraph->outputIndices.begin(); gOutIdx != subgraph->outputIndices.end(); gOutIdx++) {
186         if (*gOutIdx > deleteIdx) {
187           (*gOutIdx)--;
188         }
189       }
190       // update subgraph output indices
191       for (auto idx = subgraph->tensorIndices.begin(); idx != subgraph->tensorIndices.end(); idx++) {
192         if (*idx > deleteIdx) {
193           (*idx)--;
194         }
195       }
196     }
197 
198     // update nodes indexes
199     for (auto node_iter = graphT->nodes.begin(); node_iter != graphT->nodes.end(); node_iter++) {
200       // update nodes input indexes
201       UpdateNodeIndex((*node_iter).get(), deleteIdx);
202     }
203     // update deleteTensorIdx
204     for (auto selfIt = toDeleteTensorIdxes.begin(); selfIt != toDeleteTensorIdxes.end(); selfIt++) {
205       if (*selfIt > deleteIdx) {
206         (*selfIt)--;
207       }
208     }
209     graphT->allTensors.erase(graphT->allTensors.begin() + deleteIdx);
210     iter = toDeleteTensorIdxes.erase(iter);
211   }
212   return RET_OK;
213 }
214 
IsolateNode(schema::MetaGraphT * graphT,schema::CNodeT * node)215 STATUS IsolateNode(schema::MetaGraphT *graphT, schema::CNodeT *node) {
216   MS_CHECK_TRUE_MSG(graphT != nullptr, RET_NULL_PTR, "graphT is nullptr");
217   MS_CHECK_TRUE_MSG(node != nullptr, RET_NULL_PTR, "node is nullptr");
218   size_t nodeIdx = 0;
219   for (size_t i = 0; i < graphT->nodes.size(); i++) {
220     auto &inNode = graphT->nodes.at(i);
221     MS_CHECK_TRUE_MSG(inNode != nullptr, RET_NULL_PTR, "inNode is nullptr");
222     if (inNode->name == node->name) {
223       nodeIdx = i;
224       break;
225     }
226   }
227   auto inputTensorIdxes = node->inputIndex;
228   auto outputTensorIdxes = node->outputIndex;
229   if (inputTensorIdxes.empty()) {
230     MS_LOG(ERROR) << "Node " << node->name.c_str() << "should has no inputs";
231     return RET_ERROR;
232   }
233   if (outputTensorIdxes.size() != 1) {
234     MS_LOG(ERROR) << "FakeQuantNode " << node->name.c_str()
235                   << "should has 1 output, in fact: " << outputTensorIdxes.size();
236     return RET_ERROR;
237   }
238   auto inDataTensorIdx = inputTensorIdxes.front();
239   auto outDataTensorIdx = outputTensorIdxes.front();
240 
241   MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx);
242   ReplaceOutput(outDataTensorIdx, inDataTensorIdx, graphT);
243 
244   // find poseNode
245   auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
246   for (auto postNodeIdx : postNodeIdxes) {
247     MS_ASSERT(graphT->nodes.size() > postNodeIdx);
248     auto &postNode = graphT->nodes.at(postNodeIdx);
249     MS_CHECK_TRUE_MSG(postNode != nullptr, RET_NULL_PTR, "postNode is nullptr");
250     for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
251       if (*iter == outDataTensorIdx) {
252         *iter = inDataTensorIdx;
253         break;
254       }
255     }
256   }
257   RemoveTensor(graphT, outputTensorIdxes);
258   node->inputIndex.clear();
259   node->outputIndex.clear();
260   return RET_OK;
261 }
262 
IsolateOneWayNode(schema::MetaGraphT * graphT,size_t nodeIdx,bool removeTensor)263 STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool removeTensor) {
264   MS_CHECK_TRUE_MSG(graphT != nullptr, RET_NULL_PTR, "graphT is nullptr");
265   if (graphT->nodes.size() <= nodeIdx) {
266     MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx;
267     return RET_PARAM_INVALID;
268   }
269   schema::CNodeT *node = graphT->nodes.at(nodeIdx).get();
270   if (node == nullptr) {
271     MS_LOG(ERROR) << "node is null";
272     return RET_NULL_PTR;
273   }
274   auto inputTensorIdxes = node->inputIndex;
275   auto outputTensorIdxes = node->outputIndex;
276   auto preNodeIdxes = GetInputNodeIdx(*graphT, nodeIdx);
277   if (preNodeIdxes.size() > 1 || outputTensorIdxes.size() > 1) {
278     MS_LOG(ERROR) << "Only support node who has no more than one input and one output";
279     return RET_ERROR;
280   }
281   if (inputTensorIdxes.empty()) {
282     MS_LOG(ERROR) << "Error, " << nodeIdx << "th node has no input tensor";
283     return RET_ERROR;
284   }
285   auto inDataTensorIdx = inputTensorIdxes.front();
286   if (!outputTensorIdxes.empty()) {
287     auto outDataTensorIdx = outputTensorIdxes.front();
288     MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx);
289     MS_ASSERT(graphT->allTensors.at(inDataTensorIdx) != nullptr);
290     ReplaceOutput(outDataTensorIdx, inDataTensorIdx, graphT);
291 
292     // find poseNode
293     auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
294     for (auto postNodeIdx : postNodeIdxes) {
295       MS_ASSERT(graphT->nodes.size() > postNodeIdx);
296       auto &postNode = graphT->nodes.at(postNodeIdx);
297       MS_CHECK_TRUE_MSG(postNode != nullptr, RET_NULL_PTR, "postNode is nullptr");
298       for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
299         if (*iter == outDataTensorIdx) {
300           *iter = inDataTensorIdx;
301           break;
302         }
303       }
304     }
305   }
306   if (removeTensor) {
307     // now all node's outputTensors are useless
308     // remove all node's outputTensors
309     auto status = RemoveTensor(graphT, outputTensorIdxes);
310     if (status != RET_OK) {
311       MS_LOG(ERROR) << "RemoveOutputTensors of node " << node->name.c_str() << "failed";
312       return RET_ERROR;
313     }
314   }
315   node->inputIndex.clear();
316   node->outputIndex.clear();
317   return RET_OK;
318 }
319 
IsolateOneWayNode(schema::MetaGraphT * graph,size_t subGraphIdx,size_t nodeIdx,bool removeTensor)320 STATUS IsolateOneWayNode(schema::MetaGraphT *graph, size_t subGraphIdx, size_t nodeIdx, bool removeTensor) {
321   MS_CHECK_TRUE_MSG(graph != nullptr, RET_NULL_PTR, "graph is nullptr");
322   return IsolateOneWayNode(graph, nodeIdx, removeTensor);
323 }
324 
IsolateOneWayNode(schema::MetaGraphT * graphT,schema::CNodeT * node,bool removeTensor)325 STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, schema::CNodeT *node, bool removeTensor) {
326   MS_CHECK_TRUE_MSG(graphT != nullptr, RET_NULL_PTR, "graphT is nullptr");
327   MS_CHECK_TRUE_MSG(node != nullptr, RET_NULL_PTR, "node is nullptr");
328   bool isSubNode = false;
329   size_t nodeIdx = 0;
330   for (size_t i = 0; i < graphT->nodes.size(); i++) {
331     auto &inNode = graphT->nodes.at(i);
332     MS_CHECK_TRUE_MSG(inNode != nullptr, RET_NULL_PTR, "inNode is nullptr");
333     if (inNode->name == node->name) {
334       isSubNode = true;
335       nodeIdx = i;
336       break;
337     }
338   }
339   if (!isSubNode) {
340     MS_LOG(ERROR) << "Node " << node->name.c_str() << "is not in graphT " << graphT->name.c_str();
341     return RET_PARAM_INVALID;
342   } else {
343     return IsolateOneWayNode(graphT, nodeIdx, removeTensor);
344   }
345 }
346 }  // namespace mindspore::lite
347