• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "tools/common/graph_util.h"
18 #include <algorithm>
19 #include <functional>
20 #include <ctime>
21 #include <utility>
22 #include <set>
23 #include "schema/inner/model_generated.h"
24 #include "tools/common/tensor_util.h"
25 #include "tools/converter/quantizer/bitpacking.h"
26 #include "tools/common/node_util.h"
27 #include "src/common/log_adapter.h"
28 #include "src/common/utils.h"
29 #include "tools/converter/ops/ops_def.h"
30 #include "nnacl/op_base.h"
31 
32 namespace mindspore {
33 namespace lite {
namespace {
// Quantization bit widths that delimit the bit-packing ranges handled by DoBitPack.
enum QuantBitNum { QuantBitNum_INT8 = 8, QuantBitNum_INT16 = 16 };
// Zero-point offset (2^7) applied when a QuantDTypeCast converts between uint8 and int8 tensors.
constexpr int kZeroPointGap = 128;
}  // namespace
SetFuncGraphOutput(const FuncGraphPtr & graph,const std::vector<AnfNodePtr> & outputs)38 int SetFuncGraphOutput(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &outputs) {
39   if (graph == nullptr || outputs.empty()) {
40     MS_LOG(DEBUG) << "Input graph is nullptr or outputs is empty";
41     return RET_INPUT_PARAM_INVALID;
42   }
43   if (outputs.size() == 1) {
44     graph->set_output(outputs.front(), false);
45     return RET_OK;
46   }
47   auto make_tuple_prim_ptr = std::make_shared<lite::MakeTuple>();
48   if (make_tuple_prim_ptr == nullptr) {
49     MS_LOG(DEBUG) << "new MakeTuple failed";
50     return lite::RET_NULL_PTR;
51   }
52   auto make_tuple_cnode = graph->NewCNode(make_tuple_prim_ptr, outputs);
53   if (make_tuple_prim_ptr == nullptr) {
54     MS_LOG(DEBUG) << "new cnode failed";
55     return lite::RET_NULL_PTR;
56   }
57   make_tuple_cnode->set_fullname_with_scope("return tuple");
58   graph->set_output(make_tuple_cnode, false);
59   return RET_OK;
60 }
61 
GetSimpleOpCopyer()62 OpDefCopyer GetSimpleOpCopyer() {
63   return [](CNodeT *inCNode) -> std::unique_ptr<CNodeT> {
64     std::unique_ptr<CNodeT> newCNode = std::make_unique<CNodeT>();
65     if (newCNode == nullptr) {
66       return nullptr;
67     }
68 
69     newCNode->name = inCNode->name;
70     newCNode->quantType = inCNode->quantType;
71     newCNode->primitive = std::make_unique<schema::PrimitiveT>();
72     newCNode->primitive->value.type = inCNode->primitive->value.type;
73     return newCNode;
74   };
75 }
76 
GetInputNodeIdx(const schema::MetaGraphT & graphT,const size_t & nodeIdx,const int inputIndexIdx)77 std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, const int inputIndexIdx) {
78   return GetInputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), inputIndexIdx);
79 }
80 
GetInputNodeIdx(const schema::MetaGraphT & graphT,const CNodeT & node,const int inputIndexIdx)81 std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const CNodeT &node, const int inputIndexIdx) {
82   std::vector<uint32_t> inputIndexes;
83   if (inputIndexIdx == -1) {
84     inputIndexes = node.inputIndex;
85   } else {
86     MS_ASSERT(node.inputIndex.size() > inputIndexIdx);
87     inputIndexes.emplace_back(node.inputIndex.at(inputIndexIdx));
88   }
89   std::set<size_t> inputNodeIdx;
90   for (uint32_t inputIdx : inputIndexes) {
91     auto linkedPreIdx = GetLinkedPreIdx(graphT, inputIdx);
92     inputNodeIdx.insert(linkedPreIdx.begin(), linkedPreIdx.end());
93   }
94   std::vector<size_t> ret;
95   ret.insert(ret.end(), inputNodeIdx.begin(), inputNodeIdx.end());
96   return ret;
97 }
98 
GetOutputNodeIdx(const schema::MetaGraphT & graphT,const size_t & nodeIdx,const int outputIndexIdx)99 std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx,
100                                      const int outputIndexIdx) {
101   return GetOutputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), outputIndexIdx);
102 }
103 
ReplaceOutput(const uint32_t & old_index,const uint32_t & new_index,schema::MetaGraphT * graphT)104 void ReplaceOutput(const uint32_t &old_index, const uint32_t &new_index, schema::MetaGraphT *graphT) {
105   std::replace_if(
106     std::begin(graphT->outputIndex), std::end(graphT->outputIndex),
107     [&old_index](uint32_t outputIndex) { return outputIndex == old_index; }, new_index);
108 
109   for (auto &subGraph : graphT->subGraph) {
110     std::replace_if(
111       std::begin(subGraph->outputIndices), std::end(subGraph->outputIndices),
112       [&old_index](uint32_t outputIndex) { return outputIndex == old_index; }, new_index);
113   }
114 }
115 
GetOutputNodeIdx(const schema::MetaGraphT & graphT,const CNodeT & node,const int outputIndexIdx)116 std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const CNodeT &node, const int outputIndexIdx) {
117   std::vector<uint32_t> outputIndexes;
118   if (outputIndexIdx == -1) {
119     outputIndexes = node.outputIndex;
120   } else {
121     MS_ASSERT(node.outputIndex.size() > outputIndexIdx);
122     outputIndexes.emplace_back(node.outputIndex.at(outputIndexIdx));
123   }
124   std::set<size_t> outputNodeIdx;
125   for (uint32_t outputIdx : outputIndexes) {
126     auto linkedPostIdx = GetLinkedPostIdx(graphT, outputIdx);
127     outputNodeIdx.insert(linkedPostIdx.begin(), linkedPostIdx.end());
128   }
129   std::vector<size_t> ret;
130   ret.insert(ret.end(), outputNodeIdx.begin(), outputNodeIdx.end());
131   return ret;
132 }
133 
GetLinkedPreIdx(const schema::MetaGraphT & graphT,const size_t & tensorIdx)134 std::vector<size_t> GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
135   std::vector<size_t> preNodeIdx;
136   for (size_t i = 0; i < graphT.nodes.size(); i++) {
137     auto &oldNode = graphT.nodes.at(i);
138     if (oldNode == nullptr) {
139       continue;
140     }
141     auto outputIndexes = oldNode->outputIndex;
142     if (IsContain<uint32_t>(outputIndexes, tensorIdx)) {
143       preNodeIdx.emplace_back(i);
144     }
145   }
146   return preNodeIdx;
147 }
148 
GetLinkedPostIdx(const schema::MetaGraphT & graphT,const size_t & tensorIdx)149 std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
150   std::vector<size_t> postNodeIdx;
151   for (size_t i = 0; i < graphT.nodes.size(); i++) {
152     auto &oldNode = graphT.nodes.at(i);
153     if (oldNode == nullptr) {
154       continue;
155     }
156     auto inputIndexes = oldNode->inputIndex;
157     if (IsContain<uint32_t>(inputIndexes, tensorIdx)) {
158       postNodeIdx.emplace_back(i);
159     }
160   }
161   return postNodeIdx;
162 }
163 
IsolateNode(schema::MetaGraphT * graphT,CNodeT * node)164 STATUS IsolateNode(schema::MetaGraphT *graphT, CNodeT *node) {
165   MS_ASSERT(graphT != nullptr);
166   MS_ASSERT(node != nullptr);
167   size_t nodeIdx = 0;
168   for (size_t i = 0; i < graphT->nodes.size(); i++) {
169     auto &inNode = graphT->nodes.at(i);
170     MS_CHECK_TRUE_MSG(inNode != nullptr, RET_NULL_PTR, "inNode is nullptr");
171     if (inNode->name == node->name) {
172       nodeIdx = i;
173       break;
174     }
175   }
176   auto inputTensorIdxes = node->inputIndex;
177   auto outputTensorIdxes = node->outputIndex;
178   if (inputTensorIdxes.empty()) {
179     MS_LOG(ERROR) << "Node " << node->name.c_str() << "should has no inputs";
180     return RET_ERROR;
181   }
182   if (outputTensorIdxes.size() != 1) {
183     MS_LOG(ERROR) << "FakeQuantNode " << node->name.c_str()
184                   << "should has 1 output, in fact: " << outputTensorIdxes.size();
185     return RET_ERROR;
186   }
187   auto inDataTensorIdx = inputTensorIdxes.front();
188   auto outDataTensorIdx = outputTensorIdxes.front();
189 
190   MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx);
191   ReplaceOutput(outDataTensorIdx, inDataTensorIdx, graphT);
192 
193   // find poseNode
194   auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
195   for (auto postNodeIdx : postNodeIdxes) {
196     MS_ASSERT(graphT->nodes.size() > postNodeIdx);
197     auto &postNode = graphT->nodes.at(postNodeIdx);
198     MS_CHECK_TRUE_MSG(postNode != nullptr, RET_NULL_PTR, "postNode is nullptr");
199     for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
200       if (*iter == outDataTensorIdx) {
201         *iter = inDataTensorIdx;
202         break;
203       }
204     }
205   }
206 
207   RemoveTensor(graphT, outputTensorIdxes);
208   node->inputIndex.clear();
209   node->outputIndex.clear();
210 
211   return RET_OK;
212 }
213 
IsolateOneWayNode(schema::MetaGraphT * graph,size_t subGraphIdx,size_t nodeIdx,bool removeTensor)214 STATUS IsolateOneWayNode(schema::MetaGraphT *graph, size_t subGraphIdx, size_t nodeIdx, bool removeTensor) {
215   MS_ASSERT(graph != nullptr);
216   return IsolateOneWayNode(graph, nodeIdx, removeTensor);
217 }
218 
IsolateOneWayNode(schema::MetaGraphT * graphT,size_t nodeIdx,bool removeTensor)219 STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool removeTensor) {
220   MS_ASSERT(graphT != nullptr);
221   if (graphT->nodes.size() <= nodeIdx) {
222     MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx;
223     return RET_PARAM_INVALID;
224   }
225   CNodeT *node = graphT->nodes.at(nodeIdx).get();
226   if (node == nullptr) {
227     MS_LOG(ERROR) << "node is null";
228     return RET_NULL_PTR;
229   }
230   auto inputTensorIdxes = node->inputIndex;
231   auto outputTensorIdxes = node->outputIndex;
232   auto preNodeIdxes = GetInputNodeIdx(*graphT, nodeIdx);
233   if (preNodeIdxes.size() > 1 || outputTensorIdxes.size() > 1) {
234     MS_LOG(ERROR) << "Only support node who has no more than one input and one output";
235     return RET_ERROR;
236   }
237   if (inputTensorIdxes.empty()) {
238     MS_LOG(ERROR) << "Error, " << nodeIdx << "th node has no input tensor";
239     return RET_ERROR;
240   }
241   auto inDataTensorIdx = inputTensorIdxes.front();
242   if (!outputTensorIdxes.empty()) {
243     auto outDataTensorIdx = outputTensorIdxes.front();
244     MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx);
245     MS_ASSERT(graphT->allTensors.at(inDataTensorIdx) != nullptr);
246     ReplaceOutput(outDataTensorIdx, inDataTensorIdx, graphT);
247 
248     // find poseNode
249     auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
250     for (auto postNodeIdx : postNodeIdxes) {
251       MS_ASSERT(graphT->nodes.size() > postNodeIdx);
252       auto &postNode = graphT->nodes.at(postNodeIdx);
253       MS_CHECK_TRUE_MSG(postNode != nullptr, RET_NULL_PTR, "postNode is nullptr");
254       for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
255         if (*iter == outDataTensorIdx) {
256           *iter = inDataTensorIdx;
257           break;
258         }
259       }
260     }
261   }
262 
263   if (removeTensor) {
264     // now all node's outputTensors are useless
265     // remove all node's outputTensors
266     auto status = RemoveTensor(graphT, outputTensorIdxes);
267     if (status != RET_OK) {
268       MS_LOG(ERROR) << "RemoveOutputTensors of node " << node->name.c_str() << "failed";
269       return RET_ERROR;
270     }
271   }
272   node->inputIndex.clear();
273   node->outputIndex.clear();
274   return RET_OK;
275 }
276 
IsolateOneWayNode(schema::MetaGraphT * graphT,CNodeT * node,bool removeTensor)277 STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, CNodeT *node, bool removeTensor) {
278   MS_ASSERT(graphT != nullptr);
279   MS_ASSERT(node != nullptr);
280   bool isSubNode = false;
281   size_t nodeIdx = 0;
282   for (size_t i = 0; i < graphT->nodes.size(); i++) {
283     auto &inNode = graphT->nodes.at(i);
284     MS_CHECK_TRUE_MSG(inNode != nullptr, RET_NULL_PTR, "inNode is nullptr");
285     if (inNode->name == node->name) {
286       isSubNode = true;
287       nodeIdx = i;
288       break;
289     }
290   }
291   if (!isSubNode) {
292     MS_LOG(ERROR) << "Node " << node->name.c_str() << "is not in graphT " << graphT->name.c_str();
293     return RET_PARAM_INVALID;
294   } else {
295     return IsolateOneWayNode(graphT, nodeIdx, removeTensor);
296   }
297 }
298 
RemoveTensor(schema::MetaGraphT * graphT,std::vector<uint32_t> toDeleteTensorIdxes,bool forceDelete)299 STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTensorIdxes, bool forceDelete) {
300   MS_ASSERT(graphT != nullptr);
301   for (auto iter = toDeleteTensorIdxes.begin(); iter != toDeleteTensorIdxes.end();) {
302     uint32_t deleteIdx = *iter;
303     if (!forceDelete) {
304       if (GetRefCount(graphT, deleteIdx) > 1) {
305         iter++;
306         continue;
307       }
308     }
309     // update graph input indices
310     for (auto gInIdx = graphT->inputIndex.begin(); gInIdx != graphT->inputIndex.end(); gInIdx++) {
311       if (*gInIdx > deleteIdx) {
312         (*gInIdx)--;
313       }
314     }
315     // update graph output indices
316     for (auto gOutIdx = graphT->outputIndex.begin(); gOutIdx != graphT->outputIndex.end(); gOutIdx++) {
317       if (*gOutIdx > deleteIdx) {
318         (*gOutIdx)--;
319       }
320     }
321 
322     for (auto &subgraph : graphT->subGraph) {
323       // update subgraph input indices
324       for (auto gInIdx = subgraph->inputIndices.begin(); gInIdx != subgraph->inputIndices.end(); gInIdx++) {
325         if (*gInIdx > deleteIdx) {
326           (*gInIdx)--;
327         }
328       }
329       // update subgraph output indices
330       for (auto gOutIdx = subgraph->outputIndices.begin(); gOutIdx != subgraph->outputIndices.end(); gOutIdx++) {
331         if (*gOutIdx > deleteIdx) {
332           (*gOutIdx)--;
333         }
334       }
335       // update subgraph output indices
336       for (auto idx = subgraph->tensorIndices.begin(); idx != subgraph->tensorIndices.end(); idx++) {
337         if (*idx > deleteIdx) {
338           (*idx)--;
339         }
340       }
341     }
342 
343     // update nodes indexes
344     for (auto node_iter = graphT->nodes.begin(); node_iter != graphT->nodes.end(); node_iter++) {
345       // update nodes input indexes
346       UpdateNodeIndex((*node_iter).get(), deleteIdx);
347     }
348     // update deleteTensorIdx
349     for (auto selfIt = toDeleteTensorIdxes.begin(); selfIt != toDeleteTensorIdxes.end(); selfIt++) {
350       if (*selfIt > deleteIdx) {
351         (*selfIt)--;
352       }
353     }
354     graphT->allTensors.erase(graphT->allTensors.begin() + deleteIdx);
355     iter = toDeleteTensorIdxes.erase(iter);
356   }
357   return RET_OK;
358 }
359 
UpdateNodeIndex(CNodeT * node,uint32_t deleteIdx)360 STATUS UpdateNodeIndex(CNodeT *node, uint32_t deleteIdx) {
361   MS_ASSERT(node != nullptr);
362   for (auto inIdxIt = node->inputIndex.begin(); inIdxIt != node->inputIndex.end();) {
363     if (*inIdxIt == deleteIdx) {
364       inIdxIt = node->inputIndex.erase(inIdxIt);
365     } else {
366       if (*inIdxIt > deleteIdx) {
367         (*inIdxIt)--;
368       }
369       inIdxIt++;
370     }
371   }
372   // update nodes output indexes
373   for (auto outIdxIt = node->outputIndex.begin(); outIdxIt != node->outputIndex.end();) {
374     if (*outIdxIt == deleteIdx) {
375       outIdxIt = node->outputIndex.erase(outIdxIt);
376     } else {
377       if (*outIdxIt > deleteIdx) {
378         (*outIdxIt)--;
379       }
380       outIdxIt++;
381     }
382   }
383   return RET_OK;
384 }
385 
AddTensor2Node(schema::MetaGraphT * graphT,uint32_t nodeIdx,std::unique_ptr<TensorT> tensor,InsertPlace place)386 STATUS AddTensor2Node(schema::MetaGraphT *graphT, uint32_t nodeIdx, std::unique_ptr<TensorT> tensor,
387                       InsertPlace place) {
388   if (nodeIdx >= graphT->nodes.size()) {
389     MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx;
390     return RET_PARAM_INVALID;
391   }
392   graphT->allTensors.emplace_back(std::move(tensor));
393   uint32_t newTensorIdx = graphT->allTensors.size() - 1;
394   auto node = graphT->nodes.at(nodeIdx).get();
395   MS_CHECK_TRUE_MSG(node != nullptr, RET_NULL_PTR, "node is nullptr");
396   if (place == kBefore) {
397     node->inputIndex.emplace_back(newTensorIdx);
398   } else {
399     node->outputIndex.emplace_back(newTensorIdx);
400   }
401   return RET_OK;
402 }
403 
ReplaceTensorOfNode(schema::MetaGraphT * graphT,uint32_t nodeIdx,uint32_t inTensorIdx,std::unique_ptr<TensorT> tensor)404 STATUS ReplaceTensorOfNode(schema::MetaGraphT *graphT, uint32_t nodeIdx, uint32_t inTensorIdx,
405                            std::unique_ptr<TensorT> tensor) {
406   MS_ASSERT(graphT != nullptr);
407   if (nodeIdx >= graphT->nodes.size()) {
408     MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx;
409     return RET_PARAM_INVALID;
410   }
411   auto node = graphT->nodes.at(nodeIdx).get();
412   MS_CHECK_TRUE_MSG(node != nullptr, RET_NULL_PTR, "node is nullptr");
413   if (inTensorIdx >= graphT->allTensors.size()) {
414     MS_LOG(ERROR) << "inTensorIdx out of range: " << nodeIdx;
415     return RET_PARAM_INVALID;
416   }
417   if (!IsContain(node->inputIndex, inTensorIdx)) {
418     MS_LOG(ERROR) << "inTensorIdx(" << inTensorIdx << ") is not a inputIdx of node(" << nodeIdx << ")";
419     return RET_PARAM_INVALID;
420   }
421   graphT->allTensors.at(inTensorIdx).swap(tensor);
422   return RET_OK;
423 }
424 
DoBitPack(const int & bit_num,schema::TensorT * tensor_input)425 int DoBitPack(const int &bit_num, schema::TensorT *tensor_input) {
426   if (bit_num > 0 && bit_num < 8) {
427     std::vector<int8_t> origin_data(tensor_input->data.size());
428     auto status = memcpy_s(origin_data.data(), origin_data.size() * sizeof(int8_t), tensor_input->data.data(),
429                            tensor_input->data.size() * sizeof(uint8_t));
430     if (status != EOK) {
431       MS_LOG(ERROR) << "memcpy failed. " << status;
432       return RET_ERROR;
433     }
434     std::vector<uint8_t> pack_data{};
435     BitPack::BitPacking<int8_t, uint8_t>(bit_num, origin_data, &pack_data);
436     tensor_input->data.resize(pack_data.size() * sizeof(uint8_t));
437     status = memcpy_s(tensor_input->data.data(), tensor_input->data.size() * sizeof(uint8_t), pack_data.data(),
438                       pack_data.size() * sizeof(uint8_t));
439     if (status != EOK) {
440       MS_LOG(ERROR) << "memcpy_s failed. " << status;
441       return RET_ERROR;
442     }
443   } else if (bit_num > QuantBitNum_INT8 && bit_num < QuantBitNum_INT16) {
444     auto shape_size =
445       std::accumulate(tensor_input->dims.begin(), tensor_input->dims.end(), size_t(1), std::multiplies<size_t>());
446     std::vector<int16_t> origin_data(shape_size);
447     auto status = memcpy_s(origin_data.data(), origin_data.size() * sizeof(int16_t), tensor_input->data.data(),
448                            tensor_input->data.size() * sizeof(uint8_t));
449     if (status != EOK) {
450       MS_LOG(ERROR) << "memcpy failed. " << status;
451       return RET_ERROR;
452     }
453     std::vector<uint16_t> pack_data{};
454     BitPack::BitPacking<int16_t, uint16_t>(bit_num, origin_data, &pack_data);
455     tensor_input->data.resize(pack_data.size() * sizeof(uint16_t));
456     status = memcpy_s(tensor_input->data.data(), tensor_input->data.size() * sizeof(uint8_t), pack_data.data(),
457                       pack_data.size() * sizeof(uint16_t));
458     if (status != EOK) {
459       MS_LOG(ERROR) << "memcpy_s failed. " << status;
460       return RET_ERROR;
461     }
462   }
463   return RET_OK;
464 }
465 
InsertNode(schema::MetaGraphT * graphT,uint32_t existNodeIdx,InsertPlace place,size_t inoutIndex,std::unique_ptr<CNodeT> toAddNode,STATUS * errorCode,int * insert_num,const OpDefCopyer & opDefCopyer)466 NodeIter InsertNode(schema::MetaGraphT *graphT, uint32_t existNodeIdx, InsertPlace place, size_t inoutIndex,
467                     std::unique_ptr<CNodeT> toAddNode, STATUS *errorCode, int *insert_num,
468                     const OpDefCopyer &opDefCopyer) {
469   MS_ASSERT(graphT != nullptr);
470   MS_ASSERT(errorCode != nullptr);
471   if (existNodeIdx >= graphT->nodes.size()) {
472     MS_LOG(ERROR) << "nodeIdx out of range: " << existNodeIdx;
473     return graphT->nodes.end();
474   }
475   auto node_iter = graphT->nodes.begin() + existNodeIdx;
476   MS_ASSERT(node_iter != graphT->nodes.begin());
477   MS_ASSERT((*node_iter) != nullptr);
478   return InsertNode(graphT, node_iter, place, inoutIndex, std::move(toAddNode), errorCode, insert_num);
479 }
480 
InsertNode(schema::MetaGraphT * graphT,NodeIter existNodeIter,InsertPlace place,size_t inoutIndexIdx,std::unique_ptr<CNodeT> toAddNode,STATUS * errorCode,int * insert_num,const OpDefCopyer & opDefCopyer)481 NodeIter InsertNode(schema::MetaGraphT *graphT, NodeIter existNodeIter, InsertPlace place, size_t inoutIndexIdx,
482                     std::unique_ptr<CNodeT> toAddNode, STATUS *errorCode, int *insert_num,
483                     const OpDefCopyer &opDefCopyer) {
484   MS_ASSERT(graphT != nullptr);
485   MS_ASSERT(errorCode != nullptr);
486   if (place == kBefore) {
487     return InsertNodeBefore(graphT, existNodeIter, inoutIndexIdx, std::move(toAddNode), errorCode, insert_num,
488                             opDefCopyer);
489   } else if (place == kAfter) {
490     return InsertNodeAfter(graphT, existNodeIter, inoutIndexIdx, std::move(toAddNode), errorCode, insert_num,
491                            opDefCopyer);
492   } else {
493     MS_LOG(ERROR) << "Invalid InsertPlace : " << place;
494     return graphT->nodes.end();
495   }
496 }
497 
InsertNodeBefore(schema::MetaGraphT * graphT,NodeIter existNodeIter,size_t inputIndexIdx,std::unique_ptr<CNodeT> toAddNodeIn,STATUS * errorCode,int * insert_num,const OpDefCopyer & opDefCopyer)498 NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t inputIndexIdx,
499                           std::unique_ptr<CNodeT> toAddNodeIn, STATUS *errorCode, int *insert_num,
500                           const OpDefCopyer &opDefCopyer) {
501   MS_ASSERT(graphT != nullptr);
502   MS_ASSERT(errorCode != nullptr);
503   auto &existNode = *existNodeIter;
504   MS_ASSERT(existNode != nullptr);
505   MS_ASSERT(existNode->inputIndex.size() > inputIndexIdx);
506   MS_ASSERT(toAddNodeIn != nullptr);
507   auto preTensorIdx = existNode->inputIndex.at(inputIndexIdx);
508   MS_ASSERT(graphT->allTensors.size() > preTensorIdx);
509 
510   auto preNodeIdxes = GetInputNodeIdx(*graphT, *(existNode), inputIndexIdx);
511   size_t insert_node_num = preNodeIdxes.empty() ? 1 : preNodeIdxes.size();
512   std::vector<std::unique_ptr<CNodeT>> toAddNodes;
513   for (size_t i = 0; i < insert_node_num; ++i) {
514     auto &preTensor = graphT->allTensors.at(preTensorIdx);
515     MS_ASSERT(preTensor != nullptr);
516     auto toAddTensor = CopyTensorDefT(preTensor);
517     if (toAddTensor == nullptr) {
518       *errorCode = RET_NULL_PTR;
519       MS_LOG(ERROR) << "Copy Tensor failed";
520       return graphT->nodes.end();
521     }
522     toAddTensor->nodeType = NodeType_CNode;
523     toAddTensor->refCount = 0;
524     toAddTensor->data.clear();
525     MS_ASSERT(toAddNodeIn->primitive != nullptr);
526     if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
527       auto prim = toAddNodeIn->primitive->value.AsQuantDTypeCast();
528       MS_ASSERT(prim != nullptr);
529       if (prim->src_t == TypeId::kNumberTypeUInt8) {
530         if (preTensor->dataType == TypeId::kNumberTypeUInt8) {
531           toAddTensor->quantParams.front()->zeroPoint -= kZeroPointGap;
532         } else {
533           preTensor->quantParams.front()->zeroPoint += kZeroPointGap;
534         }
535       } else if (prim->dst_t == TypeId::kNumberTypeUInt8) {
536         if (preTensor->dataType == TypeId::kNumberTypeInt8) {
537           toAddTensor->quantParams.front()->zeroPoint += kZeroPointGap;
538         } else {
539           preTensor->quantParams.front()->zeroPoint -= kZeroPointGap;
540         }
541       }
542       preTensor->dataType = prim->src_t;
543       toAddTensor->dataType = prim->dst_t;
544     }
545     graphT->allTensors.emplace_back(std::move(toAddTensor));
546     size_t toAddTensorIdx = graphT->allTensors.size() - 1;
547     auto toAddNode = opDefCopyer(toAddNodeIn.get());
548     if (toAddNode == nullptr) {
549       MS_LOG(ERROR) << "copy toAddNodeIn failed";
550       *errorCode = RET_NULL_PTR;
551       return graphT->nodes.end();
552     }
553     if (!preNodeIdxes.empty()) {
554       toAddNode->name = toAddNodeIn->name + "_" + std::to_string(i);
555     }
556     toAddNode->inputIndex.clear();
557     toAddNode->inputIndex.push_back(preTensorIdx);
558     toAddNode->outputIndex.clear();
559     toAddNode->outputIndex.push_back(toAddTensorIdx);
560     for (auto iter = existNode->inputIndex.begin(); iter != existNode->inputIndex.end(); iter++) {
561       if (*iter == preTensorIdx) {
562         *iter = toAddTensorIdx;
563         break;
564       }
565     }
566     toAddNodes.emplace_back(std::move(toAddNode));
567   }
568   for (auto &toAddNode : toAddNodes) {
569     existNodeIter = graphT->nodes.insert(existNodeIter, std::move(toAddNode));
570     existNodeIter++;
571     *insert_num += 1;
572   }
573   *errorCode = RET_OK;
574   return existNodeIter;
575 }
576 
InsertNodeAfter(schema::MetaGraphT * graphT,NodeIter existNodeIter,size_t outputIndexIdx,std::unique_ptr<schema::CNodeT> toAddNodeIn,STATUS * errorCode,int * insert_num,const OpDefCopyer & opDefCopyer)577 NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t outputIndexIdx,
578                          std::unique_ptr<schema::CNodeT> toAddNodeIn, STATUS *errorCode, int *insert_num,
579                          const OpDefCopyer &opDefCopyer) {
580   MS_ASSERT(graphT != nullptr);
581   MS_ASSERT(errorCode != nullptr);
582   auto &existNode = *existNodeIter;
583   MS_ASSERT(existNode != nullptr);
584   MS_ASSERT(existNode->outputIndex.size() > outputIndexIdx);
585   MS_ASSERT(toAddNodeIn != nullptr);
586   auto postTensorIdx = existNode->outputIndex.at(outputIndexIdx);
587   MS_ASSERT(graphT->allTensors.size() > postTensorIdx);
588   auto postNodeIdxes = GetOutputNodeIdx(*graphT, *(existNode), outputIndexIdx);
589   bool is_output_index = IsContain(graphT->outputIndex, postTensorIdx);
590   size_t insert_node_num = (postNodeIdxes.empty() || is_output_index) ? postNodeIdxes.size() + 1 : postNodeIdxes.size();
591   bool has_insert_for_graph_out = postNodeIdxes.empty() || is_output_index;
592   std::vector<std::unique_ptr<schema::CNodeT>> toAddNodes;
593   for (size_t i = 0; i < insert_node_num; ++i) {
594     auto &postTensor = graphT->allTensors.at(postTensorIdx);
595     MS_ASSERT(postTensor != nullptr);
596     auto toAddTensor = CopyTensorDefT(postTensor);
597     if (toAddTensor == nullptr) {
598       MS_LOG(ERROR) << "Copy TensorT failed";
599       *errorCode = RET_NULL_PTR;
600       return graphT->nodes.end();
601     }
602     toAddTensor->nodeType = NodeType_CNode;
603     MS_ASSERT(toAddNodeIn->primitive != nullptr);
604     if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
605       auto prim = toAddNodeIn->primitive->value.AsQuantDTypeCast();
606       MS_ASSERT(prim != nullptr);
607       if (prim->dst_t == TypeId::kNumberTypeUInt8) {
608         if (postTensor->dataType == TypeId::kNumberTypeUInt8) {
609           postTensor->quantParams.front()->zeroPoint -= kZeroPointGap;
610         } else {
611           toAddTensor->quantParams.front()->zeroPoint += kZeroPointGap;
612         }
613       } else if (prim->src_t == TypeId::kNumberTypeUInt8) {
614         if (postTensor->dataType == TypeId::kNumberTypeUInt8) {
615           toAddTensor->quantParams.front()->zeroPoint -= kZeroPointGap;
616         } else {
617           postTensor->quantParams.front()->zeroPoint += kZeroPointGap;
618         }
619       }
620       postTensor->dataType = prim->src_t;
621       toAddTensor->dataType = prim->dst_t;
622     }
623     graphT->allTensors.emplace_back(std::move(toAddTensor));
624     size_t toAddTensorIdx = graphT->allTensors.size() - 1;
625     auto toAddNode = opDefCopyer(toAddNodeIn.get());
626     if (toAddNode == nullptr) {
627       MS_LOG(ERROR) << "copy toAddNodeIn failed";
628       *errorCode = RET_NULL_PTR;
629       return graphT->nodes.end();
630     }
631     toAddNode->inputIndex.clear();
632     toAddNode->inputIndex.push_back(postTensorIdx);
633     toAddNode->outputIndex.clear();
634     toAddNode->outputIndex.push_back(toAddTensorIdx);
635     if (!postNodeIdxes.empty()) {
636       toAddNode->name = toAddNodeIn->name + "_" + std::to_string(i);
637     }
638     if (has_insert_for_graph_out) {
639       ReplaceOutput(postTensorIdx, toAddTensorIdx, graphT);
640       has_insert_for_graph_out = false;
641     } else {
642       auto &postNode = graphT->nodes.at(postNodeIdxes[is_output_index ? i - 1 : i]);
643       for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
644         if (*iter == postTensorIdx) {
645           *iter = toAddTensorIdx;
646         }
647       }
648     }
649     toAddNodes.emplace_back(std::move(toAddNode));
650   }
651   for (auto &toAddNode : toAddNodes) {
652     existNodeIter = graphT->nodes.insert(existNodeIter, std::move(toAddNode));
653     existNodeIter++;
654     *insert_num += 1;
655   }
656   *errorCode = RET_OK;
657   return existNodeIter;
658 }
659 
ValidateFileStr(const std::string & modelFile,const std::string & fileType)660 STATUS ValidateFileStr(const std::string &modelFile, const std::string &fileType) {
661   if (modelFile.size() > fileType.size() && modelFile.substr(modelFile.size() - fileType.size()) == fileType) {
662     return RET_OK;
663   } else {
664     return RET_ERROR;
665   }
666 }
667 
// Extracts the model name from a path: drops everything up to and including the last '/', then drops the
// last '.' and everything after it. Components without a '/' or '.' are kept unchanged.
std::string GetModelName(const std::string &modelFile) {
  const auto slashPos = modelFile.find_last_of('/');
  std::string baseName = modelFile;
  if (slashPos != std::string::npos) {
    baseName = modelFile.substr(slashPos + 1);
  }
  const auto dotPos = baseName.find_last_of('.');
  if (dotPos == std::string::npos) {
    return baseName;
  }
  return baseName.substr(0, dotPos);
}
674 
SetSubgraphTensorIndices(schema::MetaGraphT * meta_graphT)675 int SetSubgraphTensorIndices(schema::MetaGraphT *meta_graphT) {
676   for (auto &subgraph : meta_graphT->subGraph) {
677     std::vector<uint32_t> subgraph_indices{};
678     subgraph_indices.assign(subgraph->inputIndices.begin(), subgraph->inputIndices.end());
679     subgraph_indices.assign(subgraph->outputIndices.begin(), subgraph->outputIndices.end());
680     for (auto &node_idx : subgraph->nodeIndices) {
681       auto &node = meta_graphT->nodes.at(node_idx);
682       for (auto &input_idx : node->inputIndex) {
683         if (IsContain(subgraph_indices, input_idx)) {
684           continue;
685         } else {
686           subgraph_indices.push_back(input_idx);
687         }
688       }
689       for (auto &output_idx : node->outputIndex) {
690         if (IsContain(subgraph_indices, output_idx)) {
691           continue;
692         } else {
693           subgraph_indices.push_back(output_idx);
694         }
695       }
696     }
697     subgraph->tensorIndices.assign(subgraph_indices.begin(), subgraph_indices.end());
698   }
699   return RET_OK;
700 }
701 
GetTransposePerm(MetaGraphT * graph,const std::unique_ptr<CNodeT> & cnode)702 std::vector<int> GetTransposePerm(MetaGraphT *graph, const std::unique_ptr<CNodeT> &cnode) {
703   MS_ASSERT(graph != nullptr && cnode != nullptr);
704   std::vector<int> perm;
705   if (cnode->primitive->value.type != schema::PrimitiveType_Transpose) {
706     return perm;
707   }
708   if (cnode->inputIndex.size() < 2) {
709     MS_LOG(ERROR) << "transpose node input size is less than 2.";
710     return perm;
711   }
712   MS_ASSERT(cnode->outputIndex.at(1) < graph->allTensors.size());
713   auto &perm_tensor = graph->allTensors.at(cnode->inputIndex.at(1));
714   if (perm_tensor->data.empty()) {
715     return perm;
716   }
717   MS_ASSERT(perm_tensor->dims.size() != 0);
718   perm.resize(perm_tensor->dims[0]);
719   if (memcpy_s(perm.data(), perm_tensor->dims[0] * sizeof(int), perm_tensor->data.data(),
720                perm_tensor->dims[0] * sizeof(int)) != EOK) {
721     MS_LOG(ERROR) << "memcpy data failed.";
722     return {};
723   }
724   return perm;
725 }
726 
namespace {
constexpr size_t kBitNumPerByte = 8;
}  // namespace

/**
 * Packs a vector of bools into a byte string, 8 bits per byte, most significant bit first.
 *
 * The first element goes into bit 7 of byte 0. A trailing partial byte is zero-padded in its low bits.
 *
 * @param bool_vec bits to pack.
 * @return string of ceil(bool_vec.size() / 8) bytes (empty for empty input).
 */
std::string BoolVectorToString(const std::vector<bool> &bool_vec) {
  // Round up with integer arithmetic. The previous ceil(a / b) divided integers first, so it truncated.
  // A trailing partial byte then had no storage and was written past the end.
  const size_t size_in_byte = (bool_vec.size() + kBitNumPerByte - 1) / kBitNumPerByte;
  std::string str(size_in_byte, '\0');
  auto iter = str.begin();
  size_t shift = kBitNumPerByte;
  for (bool bit : bool_vec) {
    const unsigned int mask = static_cast<unsigned int>(bit) << (shift - 1);
    *iter = static_cast<char>(static_cast<unsigned char>(*iter) | mask);
    if (--shift == 0) {
      ++iter;
      shift = kBitNumPerByte;
    }
  }
  return str;
}
745 
GetAbstractTensorDtype(const abstract::AbstractTensorPtr & tensor)746 TypeId GetAbstractTensorDtype(const abstract::AbstractTensorPtr &tensor) {
747   if (tensor == nullptr || tensor->element() == nullptr) {
748     MS_LOG(ERROR) << "abstract_tensor or abstract_tensor->element() is nullptr";
749     return kTypeUnknown;
750   }
751   auto type_ptr = tensor->element()->GetTypeTrack();
752   return type_ptr->type_id();
753 }
754 
GetParameterDtype(const ParameterPtr & param_node)755 TypeId GetParameterDtype(const ParameterPtr &param_node) {
756   auto abstract_base = param_node->abstract();
757   auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
758   auto type_ptr = abstract_tensor->element()->GetTypeTrack();
759   return type_ptr->type_id();
760 }
761 
UpdateFuncGraphInputsAndOutputsDtype(const FuncGraphPtr & func_graph)762 STATUS UpdateFuncGraphInputsAndOutputsDtype(const FuncGraphPtr &func_graph) {
763   MS_ASSERT(func_graph != nullptr);
764   // update graph inputs dtype
765   size_t idx = 0;
766   for (auto &input : func_graph->get_inputs()) {
767     TypeId type = GetParameterDtype(input->cast<ParameterPtr>());
768     ConverterContext::GetInstance()->UpdateGraphInputDType(idx, type);
769     idx++;
770   }
771   // update graph outputs dtype
772   auto graph_return = func_graph->get_return();
773   idx = 0;
774   for (auto &input : graph_return->inputs()) {
775     if (input->isa<CNode>()) {
776       if (utils::isa<abstract::AbstractTuple>(input->abstract())) {
777         auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(input->abstract());
778         if (tuple == nullptr) {
779           MS_LOG(ERROR) << "tuple is nullptr";
780           return RET_ERROR;
781         }
782         for (const auto &tuple_item : tuple->elements()) {
783           TypeId type = GetAbstractTensorDtype(tuple_item->cast<abstract::AbstractTensorPtr>());
784           ConverterContext::GetInstance()->UpdateGraphOutputDType(idx, type);
785           idx++;
786         }
787       } else if (utils::isa<abstract::AbstractTensor>(input->abstract())) {
788         TypeId type = GetAbstractTensorDtype(input->abstract()->cast<abstract::AbstractTensorPtr>());
789         ConverterContext::GetInstance()->UpdateGraphOutputDType(idx, type);
790         idx++;
791       } else {
792         ConverterContext::GetInstance()->UpdateGraphOutputDType(idx, kTypeUnknown);
793         idx++;
794       }
795     }
796   }
797   return RET_OK;
798 }
799 }  // namespace lite
800 }  // namespace mindspore
801