• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define USE_DEPRECATED_API
18 #include "tools/common/graph_util.h"
19 #include <algorithm>
20 #include <functional>
21 #include <ctime>
22 #include <utility>
23 #include <set>
24 #include "mindspore/core/ops/sequence_ops.h"
25 #include "mindspore/core/ops/framework_ops.h"
26 #include "tools/common/meta_graph_utils.h"
27 #include "schema/inner/model_generated.h"
28 #include "tools/common/tensor_util.h"
29 #include "src/common/log_adapter.h"
30 #include "src/common/utils.h"
31 #include "nnacl/op_base.h"
32 #include "ops/make_tuple.h"
33 #include "tools/converter/converter_context.h"
34 #include "tools/optimizer/common/gllo_utils.h"
35 #include "tools/common/string_util.h"
36 
37 namespace mindspore {
38 namespace lite {
namespace {
// Zero-point shift applied when a QuantDTypeCast switches a tensor between int8 and uint8
// (see InsertNodeBefore / InsertNodeAfter).
constexpr int kZeroPointGap = 128;
// Index of the tuple operand in a TupleGetItem CNode's inputs.
constexpr size_t kTupleGetItemFirstInputIdx = 1;
// Minimum number of inputs a Depend CNode must have before it is traced.
constexpr size_t kDependInputNum = 3;
// Index of the Depend input that TraceOutput follows as the real output.
constexpr size_t kDependFirstInputIdx = 1;
// Required number of inputs of a TupleGetItem / ListGetItem CNode.
constexpr size_t kSequenceCodeGetItemInputSize = 3;
// Index of the sequence operand in a TupleGetItem / ListGetItem CNode's inputs.
constexpr size_t kSecondIndex = 1;
// Sentinel meaning "no element index".
constexpr size_t kInvalidSize = SIZE_MAX;
// Function names TraceOutput treats as sequence constructors.
constexpr auto kMakeTuple = "MakeTuple";
constexpr auto kMakeList = "make_list";
// Required encryption key length in bytes (checked in InitEncryptKey).
constexpr size_t kEncMaxLen = 16;
}  // namespace
51 
GetAbstractfromSequenceCodeGetItem(const CNodePtr & cnode,AbstractBasePtr * abstract,size_t * idx)52 static STATUS GetAbstractfromSequenceCodeGetItem(const CNodePtr &cnode, AbstractBasePtr *abstract, size_t *idx) {
53   MS_CHECK_TRUE_MSG(abstract != nullptr, lite::RET_ERROR, "Abstract is nullptr.");
54   MS_CHECK_TRUE_MSG(idx != nullptr, lite::RET_ERROR, "idx is nullptr.");
55   auto SequenceCode_inputs = cnode->inputs();
56   MS_CHECK_TRUE_MSG(SequenceCode_inputs.size() == kSequenceCodeGetItemInputSize, lite::RET_ERROR,
57                     "The node must have 3 inputs!");
58   auto get_item_input_cnode = SequenceCode_inputs.at(kSecondIndex);
59   MS_CHECK_TRUE_MSG(get_item_input_cnode != nullptr, lite::RET_ERROR, "input node is nullptr.");
60 
61   AbstractBasePtrList abstract_list;
62   if (opt::CheckPrimitiveType(cnode, prim::kPrimTupleGetItem)) {
63     *idx = opt::GetTupleGetItemOutIndex(cnode);
64     if (!mindspore::utils::isa<mindspore::abstract::AbstractTuplePtr>(get_item_input_cnode->abstract())) {
65       MS_LOG(ERROR) << "TupleGetItem's abstract is not AbstractTuple, cnode name: "
66                     << get_item_input_cnode->fullname_with_scope();
67       return lite::RET_ERROR;
68     }
69     auto input_node_abstract = utils::cast<abstract::AbstractTuplePtr>(get_item_input_cnode->abstract());
70     abstract_list = input_node_abstract->elements();
71   } else {
72     *idx = opt::GetListGetItemOutIndex(cnode);
73     if (!mindspore::utils::isa<mindspore::abstract::AbstractListPtr>(get_item_input_cnode->abstract())) {
74       MS_LOG(ERROR) << "ListGetItem's abstract is not AbstractTuple, cnode name: "
75                     << get_item_input_cnode->fullname_with_scope();
76       return lite::RET_ERROR;
77     }
78     auto input_node_abstract = utils::cast<abstract::AbstractListPtr>(get_item_input_cnode->abstract());
79     abstract_list = input_node_abstract->elements();
80   }
81 
82   if (abstract_list.size() <= *idx) {
83     MS_LOG(ERROR) << "Abstract's size is smaller than expect";
84     return lite::RET_ERROR;
85   }
86   *abstract = abstract_list[*idx];
87   return lite::RET_OK;
88 }
89 
GetShapeVectorFromParameter(const mindspore::ParameterPtr & param_node,std::vector<int64_t> * shape_vector)90 STATUS GetShapeVectorFromParameter(const mindspore::ParameterPtr &param_node, std::vector<int64_t> *shape_vector) {
91   MS_CHECK_TRUE_MSG(shape_vector != nullptr, RET_ERROR, "shape vector is nullptr.");
92   auto abstract_base = param_node->abstract();
93   if (abstract_base == nullptr) {
94     MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
95     return RET_ERROR;
96   }
97 
98   if (!abstract_base->isa<abstract::AbstractTensor>()) {
99     MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << param_node->name();
100     return lite::RET_ERROR;
101   }
102   auto abstract_tensor = abstract_base->cast<abstract::AbstractTensorPtr>();
103   MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
104   *shape_vector = abstract_tensor->shape()->shape();
105   return lite::RET_OK;
106 }
107 
GetShapeVectorAndIdxFromCNode(const CNodePtr & cnode,std::vector<int64_t> * shape_vector,size_t * idx)108 STATUS GetShapeVectorAndIdxFromCNode(const CNodePtr &cnode, std::vector<int64_t> *shape_vector, size_t *idx) {
109   MS_CHECK_TRUE_MSG(shape_vector != nullptr, lite::RET_ERROR, "shape is nullptr");
110 
111   AbstractBasePtr cnode_abstract = nullptr;
112   if ((opt::CheckPrimitiveType(cnode, prim::kPrimTupleGetItem)) ||
113       (opt::CheckPrimitiveType(cnode, prim::kPrimListGetItem))) {
114     // idx is only used when cnode is type of kPrimTupleGetItem or kPrimListGetItem.
115     MS_CHECK_TRUE_MSG(idx != nullptr, lite::RET_ERROR, "idx is nullptr");
116     if (GetAbstractfromSequenceCodeGetItem(cnode, &cnode_abstract, idx) != lite::RET_OK) {
117       MS_LOG(ERROR) << "Get abstract from tuple get item failed.";
118       return lite::RET_ERROR;
119     }
120   } else {
121     cnode_abstract = cnode->abstract();
122   }
123   // the control flow model may be nullptr
124   if (cnode_abstract == nullptr) {
125     *shape_vector = std::vector<int64_t>();
126     return lite::RET_OK;
127   }
128   if (cnode_abstract->BuildShape() == mindspore::abstract::kNoShape) {
129     *shape_vector = std::vector<int64_t>();
130     return lite::RET_OK;
131   }
132   if (!utils::isa<mindspore::abstract::AbstractTensorPtr>(cnode_abstract)) {
133     MS_LOG(ERROR) << "Abstract is not abstract tensor. " << cnode->fullname_with_scope();
134     return lite::RET_ERROR;
135   }
136   auto cnode_abstract_tensor = cnode_abstract->cast<mindspore::abstract::AbstractTensorPtr>();
137   CHECK_NULL_RETURN(cnode_abstract_tensor);
138   if (!utils::isa<mindspore::abstract::ShapePtr>(cnode_abstract_tensor->BuildShape())) {
139     MS_LOG(ERROR) << "Shape of abstract tensor should be ShapePtr. " << cnode->fullname_with_scope();
140     return lite::RET_ERROR;
141   }
142   auto shape_ptr = utils::cast<mindspore::abstract::ShapePtr>(cnode_abstract_tensor->BuildShape());
143   CHECK_NULL_RETURN(shape_ptr);
144   if (shape_ptr->shape().empty()) {
145     MS_LOG(WARNING) << "Shape is empty " << cnode->fullname_with_scope();
146   }
147   *shape_vector = shape_ptr->shape();
148   return lite::RET_OK;
149 }
150 
GetCNodeOrParameterShapeVec(const AnfNodePtr & anf_node,std::vector<int> * shape)151 STATUS GetCNodeOrParameterShapeVec(const AnfNodePtr &anf_node, std::vector<int> *shape) {
152   auto int64_t_to_int_func = [](int64_t x) -> int { return static_cast<int>(x); };
153   std::vector<int64_t> in_shape;
154   if (anf_node->isa<CNode>()) {
155     auto status = GetShapeVectorAndIdxFromCNode(anf_node->cast<CNodePtr>(), &in_shape);
156     if (status != RET_OK) {
157       MS_LOG(ERROR) << "Get shape from CNode failed.";
158       return status;
159     }
160   } else if (anf_node->isa<Parameter>()) {
161     auto param_node = anf_node->cast<ParameterPtr>();
162     auto status = GetShapeVectorFromParameter(param_node, &in_shape);
163     if (status != RET_OK) {
164       MS_LOG(ERROR) << "Get shape from Parameter failed.";
165       return status;
166     }
167   } else {
168     MS_LOG(ERROR) << "Node type is not recognized.";
169     return RET_ERROR;
170   }
171   shape->resize(in_shape.size());
172   (void)std::transform(in_shape.begin(), in_shape.end(), shape->begin(), int64_t_to_int_func);
173   return RET_OK;
174 }
175 
TraceOutput(const AnfNodePtr & node,std::vector<std::pair<AnfNodePtr,int64_t>> * outputs,std::vector<std::string> * output_names,std::vector<std::vector<int64_t>> * output_dims)176 static STATUS TraceOutput(const AnfNodePtr &node, std::vector<std::pair<AnfNodePtr, int64_t>> *outputs,
177                           std::vector<std::string> *output_names, std::vector<std::vector<int64_t>> *output_dims) {
178   static size_t iter = 0;
179   CHECK_NULL_RETURN(node);
180   if (utils::isa<ParameterPtr>(node) || utils::isa<ValueNode>(node)) {
181     MS_LOG(INFO) << "Name of graph output value node is : " << node->fullname_with_scope();
182     outputs->emplace_back(std::pair<AnfNodePtr, int64_t>(node, 0));
183     output_names->push_back(node->fullname_with_scope());
184     output_dims->emplace_back(std::vector<int64_t>());
185     return lite::RET_OK;
186   }
187   AnfNodePtr cur_node = node;
188   CNodePtr pre_node = nullptr;
189   while (cur_node->isa<CNode>() && IsPrimitiveCNode(cur_node, prim::kPrimTupleGetItem)) {
190     auto tmp = cur_node->cast<CNodePtr>();
191     CHECK_NULL_RETURN(tmp);
192     pre_node = tmp;
193     cur_node = tmp->input(kTupleGetItemFirstInputIdx);
194     CHECK_NULL_RETURN(cur_node);
195   }
196   auto cnode = cur_node->cast<CNodePtr>();
197   CHECK_NULL_RETURN(cnode);
198   std::string name = GetCNodeFuncName(cnode);
199   iter++;
200   MS_LOG(INFO) << "Func name of cnode " << name << " ,trace iter: " << iter;
201   if ((name == kMakeTuple) || (name == kMakeList)) {
202     for (size_t i = 1; i < cnode->size(); ++i) {
203       auto make_tuple_input = cnode->input(i);
204       if (opt::CheckPrimitiveType(make_tuple_input, prim::kPrimUpdateState) ||
205           opt::CheckPrimitiveType(make_tuple_input, prim::kPrimLoad)) {
206         continue;
207       }
208       if (TraceOutput(make_tuple_input, outputs, output_names, output_dims) != lite::RET_OK) {
209         MS_LOG(ERROR) << "The input[ " << i << "]"
210                       << " trace output failed, name: " << name;
211         return lite::RET_ERROR;
212       }
213     }
214   } else if (name == prim::kPrimDepend->name()) {
215     if (cnode->size() < kDependInputNum) {
216       MS_LOG(ERROR) << "Length of inputs is " << cnode->size() << ", which is less than three.";
217       return lite::RET_ERROR;
218     }
219     if (TraceOutput(cnode->input(kDependFirstInputIdx), outputs, output_names, output_dims) != lite::RET_OK) {
220       MS_LOG(ERROR) << "Depend node trace output failed.";
221       return lite::RET_ERROR;
222     }
223   } else {
224     MS_LOG(INFO) << "Name of graph output node is " << cnode->fullname_with_scope();
225     std::string node_name = cnode->fullname_with_scope();
226     std::vector<int64_t> dims;
227     size_t idx = -1;
228     STATUS ret;
229     if (pre_node != nullptr && IsPrimitiveCNode(pre_node, prim::kPrimTupleGetItem)) {
230       ret = GetShapeVectorAndIdxFromCNode(pre_node, &dims, &idx);
231       node_name = node_name + "_" + std::to_string(idx);
232     } else {
233       ret = GetShapeVectorAndIdxFromCNode(cnode, &dims, &idx);
234     }
235     if (ret != lite::RET_OK) {
236       MS_LOG(ERROR) << "Get node shape failed.";
237       return lite::RET_ERROR;
238     }
239     outputs->emplace_back(std::pair<AnfNodePtr, int64_t>(cnode, idx));
240     output_names->emplace_back(node_name);
241     output_dims->emplace_back(dims);
242   }
243   return lite::RET_OK;
244 }
245 
SetFuncGraphOutput(const FuncGraphPtr & graph,const std::vector<AnfNodePtr> & outputs)246 int SetFuncGraphOutput(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &outputs) {
247   if (graph == nullptr || outputs.empty()) {
248     MS_LOG(DEBUG) << "Input graph is nullptr or outputs is empty";
249     return RET_INPUT_PARAM_INVALID;
250   }
251   if (outputs.size() == 1) {
252     graph->set_output(outputs.front(), false);
253     return RET_OK;
254   }
255   auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>();
256   if (make_tuple_prim_ptr == nullptr) {
257     MS_LOG(DEBUG) << "new MakeTuple failed";
258     return lite::RET_NULL_PTR;
259   }
260   auto make_tuple_prim_c = make_tuple_prim_ptr->GetPrim();
261   MS_CHECK_TRUE_MSG(make_tuple_prim_c != nullptr, lite::RET_NULL_PTR, "make_tuple_prim_c is nullptr");
262   auto make_tuple_cnode = graph->NewCNode(make_tuple_prim_c, outputs);
263   if (make_tuple_cnode == nullptr) {
264     MS_LOG(DEBUG) << "new cnode failed";
265     return lite::RET_NULL_PTR;
266   }
267   make_tuple_cnode->set_fullname_with_scope("return tuple");
268   graph->set_output(make_tuple_cnode, false);
269   return RET_OK;
270 }
271 
GetSimpleOpCopyer()272 OpDefCopyer GetSimpleOpCopyer() {
273   return [](const CNodeT &inCNode) -> std::unique_ptr<CNodeT> {
274     std::unique_ptr<CNodeT> newCNode = std::make_unique<CNodeT>();
275     if (newCNode == nullptr) {
276       return nullptr;
277     }
278 
279     newCNode->name = inCNode.name;
280     newCNode->quantType = inCNode.quantType;
281     newCNode->primitive = std::make_unique<schema::PrimitiveT>();
282     newCNode->primitive->value.type = inCNode.primitive->value.type;
283     return newCNode;
284   };
285 }
286 
AddTensor2Node(schema::MetaGraphT * graphT,uint32_t nodeIdx,std::unique_ptr<TensorT> tensor,InsertPlace place)287 STATUS AddTensor2Node(schema::MetaGraphT *graphT, uint32_t nodeIdx, std::unique_ptr<TensorT> tensor,
288                       InsertPlace place) {
289   MS_CHECK_TRUE_MSG(graphT != nullptr, RET_NULL_PTR, "graphT is nullptr");
290   if (nodeIdx >= graphT->nodes.size()) {
291     MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx;
292     return RET_PARAM_INVALID;
293   }
294   graphT->allTensors.emplace_back(std::move(tensor));
295   uint32_t newTensorIdx = graphT->allTensors.size() - 1;
296   auto node = graphT->nodes.at(nodeIdx).get();
297   MS_CHECK_TRUE_MSG(node != nullptr, RET_NULL_PTR, "node is nullptr");
298   if (place == kBefore) {
299     node->inputIndex.emplace_back(newTensorIdx);
300   } else {
301     node->outputIndex.emplace_back(newTensorIdx);
302   }
303   return RET_OK;
304 }
305 
ReplaceTensorOfNode(schema::MetaGraphT * graphT,uint32_t nodeIdx,uint32_t inTensorIdx,std::unique_ptr<TensorT> tensor)306 STATUS ReplaceTensorOfNode(schema::MetaGraphT *graphT, uint32_t nodeIdx, uint32_t inTensorIdx,
307                            std::unique_ptr<TensorT> tensor) {
308   MS_ASSERT(graphT != nullptr);
309   if (nodeIdx >= graphT->nodes.size()) {
310     MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx;
311     return RET_PARAM_INVALID;
312   }
313   auto node = graphT->nodes.at(nodeIdx).get();
314   MS_CHECK_TRUE_MSG(node != nullptr, RET_NULL_PTR, "node is nullptr");
315   if (inTensorIdx >= graphT->allTensors.size()) {
316     MS_LOG(ERROR) << "inTensorIdx out of range: " << nodeIdx;
317     return RET_PARAM_INVALID;
318   }
319   if (!IsContain(node->inputIndex, inTensorIdx)) {
320     MS_LOG(ERROR) << "inTensorIdx(" << inTensorIdx << ") is not a inputIdx of node(" << nodeIdx << ")";
321     return RET_PARAM_INVALID;
322   }
323   graphT->allTensors.at(inTensorIdx).swap(tensor);
324   return RET_OK;
325 }
326 
InsertNode(schema::MetaGraphT * graphT,uint32_t existNodeIdx,InsertPlace place,size_t inoutIndex,std::unique_ptr<CNodeT> toAddNode,STATUS * errorCode,int * insert_num,const OpDefCopyer & opDefCopyer)327 NodeIter InsertNode(schema::MetaGraphT *graphT, uint32_t existNodeIdx, InsertPlace place, size_t inoutIndex,
328                     std::unique_ptr<CNodeT> toAddNode, STATUS *errorCode, int *insert_num,
329                     const OpDefCopyer &opDefCopyer) {
330   MS_ASSERT(graphT != nullptr);
331   MS_ASSERT(errorCode != nullptr);
332   if (existNodeIdx >= graphT->nodes.size()) {
333     MS_LOG(ERROR) << "nodeIdx out of range: " << existNodeIdx;
334     return graphT->nodes.end();
335   }
336   auto node_iter = graphT->nodes.begin() + existNodeIdx;
337   MS_ASSERT(node_iter != graphT->nodes.begin());
338   MS_ASSERT((*node_iter) != nullptr);
339   return InsertNode(graphT, node_iter, place, inoutIndex, std::move(toAddNode), errorCode, insert_num);
340 }
341 
InsertNode(schema::MetaGraphT * graphT,NodeIter existNodeIter,InsertPlace place,size_t inoutIndexIdx,std::unique_ptr<CNodeT> toAddNode,STATUS * errorCode,int * insert_num,const OpDefCopyer & opDefCopyer)342 NodeIter InsertNode(schema::MetaGraphT *graphT, NodeIter existNodeIter, InsertPlace place, size_t inoutIndexIdx,
343                     std::unique_ptr<CNodeT> toAddNode, STATUS *errorCode, int *insert_num,
344                     const OpDefCopyer &opDefCopyer) {
345   MS_ASSERT(graphT != nullptr);
346   MS_ASSERT(errorCode != nullptr);
347   if (place == kBefore) {
348     return InsertNodeBefore(graphT, existNodeIter, inoutIndexIdx, std::move(toAddNode), errorCode, insert_num,
349                             opDefCopyer);
350   } else if (place == kAfter) {
351     return InsertNodeAfter(graphT, existNodeIter, inoutIndexIdx, std::move(toAddNode), errorCode, insert_num,
352                            opDefCopyer);
353   } else {
354     MS_LOG(ERROR) << "Invalid InsertPlace : " << place;
355     return graphT->nodes.end();
356   }
357 }
358 
InsertNodeBefore(schema::MetaGraphT * graphT,NodeIter existNodeIter,size_t inputIndexIdx,std::unique_ptr<CNodeT> toAddNodeIn,STATUS * errorCode,int * insert_num,const OpDefCopyer & opDefCopyer)359 NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t inputIndexIdx,
360                           std::unique_ptr<CNodeT> toAddNodeIn, STATUS *errorCode, int *insert_num,
361                           const OpDefCopyer &opDefCopyer) {
362   MS_ASSERT(graphT != nullptr);
363   MS_ASSERT(errorCode != nullptr);
364   auto &existNode = *existNodeIter;
365   MS_ASSERT(existNode != nullptr);
366   MS_ASSERT(existNode->inputIndex.size() > inputIndexIdx);
367   MS_ASSERT(toAddNodeIn != nullptr);
368   auto preTensorIdx = existNode->inputIndex.at(inputIndexIdx);
369   MS_ASSERT(graphT->allTensors.size() > preTensorIdx);
370 
371   auto preNodeIdxes = GetInputNodeIdx(*graphT, *(existNode), inputIndexIdx);
372   size_t insert_node_num = preNodeIdxes.empty() ? 1 : preNodeIdxes.size();
373   std::vector<std::unique_ptr<CNodeT>> toAddNodes;
374   for (size_t i = 0; i < insert_node_num; ++i) {
375     auto &preTensor = graphT->allTensors.at(preTensorIdx);
376     MS_ASSERT(preTensor != nullptr);
377     auto toAddTensor = CopyTensorDefT(preTensor);
378     if (toAddTensor == nullptr) {
379       *errorCode = RET_NULL_PTR;
380       MS_LOG(ERROR) << "Copy Tensor failed";
381       return graphT->nodes.end();
382     }
383     toAddTensor->nodeType = NodeType_CNode;
384     toAddTensor->refCount = 0;
385     toAddTensor->data.clear();
386     MS_ASSERT(toAddNodeIn->primitive != nullptr);
387     if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
388       auto prim = toAddNodeIn->primitive->value.AsQuantDTypeCast();
389       MS_ASSERT(prim != nullptr);
390       if (prim->src_t == TypeId::kNumberTypeUInt8) {
391         if (preTensor->dataType == TypeId::kNumberTypeUInt8) {
392           toAddTensor->quantParams.front()->zeroPoint -= kZeroPointGap;
393         } else {
394           preTensor->quantParams.front()->zeroPoint += kZeroPointGap;
395         }
396       } else if (prim->dst_t == TypeId::kNumberTypeUInt8) {
397         if (preTensor->dataType == TypeId::kNumberTypeInt8) {
398           toAddTensor->quantParams.front()->zeroPoint += kZeroPointGap;
399         } else {
400           preTensor->quantParams.front()->zeroPoint -= kZeroPointGap;
401         }
402       }
403       preTensor->dataType = prim->src_t;
404       toAddTensor->dataType = prim->dst_t;
405     }
406     graphT->allTensors.emplace_back(std::move(toAddTensor));
407     size_t toAddTensorIdx = graphT->allTensors.size() - 1;
408     auto toAddNode = opDefCopyer(*toAddNodeIn);
409     if (toAddNode == nullptr) {
410       MS_LOG(ERROR) << "copy toAddNodeIn failed";
411       *errorCode = RET_NULL_PTR;
412       return graphT->nodes.end();
413     }
414     if (!preNodeIdxes.empty()) {
415       toAddNode->name = toAddNodeIn->name + "_" + std::to_string(i);
416     }
417     toAddNode->inputIndex.clear();
418     toAddNode->inputIndex.push_back(preTensorIdx);
419     toAddNode->outputIndex.clear();
420     toAddNode->outputIndex.push_back(toAddTensorIdx);
421     for (auto iter = existNode->inputIndex.begin(); iter != existNode->inputIndex.end(); iter++) {
422       if (*iter == preTensorIdx) {
423         *iter = toAddTensorIdx;
424         break;
425       }
426     }
427     toAddNodes.emplace_back(std::move(toAddNode));
428   }
429   for (auto &toAddNode : toAddNodes) {
430     existNodeIter = graphT->nodes.insert(existNodeIter, std::move(toAddNode));
431     existNodeIter++;
432     *insert_num += 1;
433   }
434   *errorCode = RET_OK;
435   return existNodeIter;
436 }
437 
InsertNodeAfter(schema::MetaGraphT * graphT,NodeIter existNodeIter,size_t outputIndexIdx,std::unique_ptr<schema::CNodeT> toAddNodeIn,STATUS * errorCode,int * insert_num,const OpDefCopyer & opDefCopyer)438 NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t outputIndexIdx,
439                          std::unique_ptr<schema::CNodeT> toAddNodeIn, STATUS *errorCode, int *insert_num,
440                          const OpDefCopyer &opDefCopyer) {
441   MS_ASSERT(graphT != nullptr);
442   MS_ASSERT(errorCode != nullptr);
443   auto &existNode = *existNodeIter;
444   MS_ASSERT(existNode != nullptr);
445   MS_ASSERT(existNode->outputIndex.size() > outputIndexIdx);
446   MS_ASSERT(toAddNodeIn != nullptr);
447   auto postTensorIdx = existNode->outputIndex.at(outputIndexIdx);
448   MS_ASSERT(graphT->allTensors.size() > postTensorIdx);
449   auto postNodeIdxes = GetOutputNodeIdx(*graphT, *(existNode), outputIndexIdx);
450   bool is_output_index = IsContain(graphT->outputIndex, postTensorIdx);
451   size_t insert_node_num = (postNodeIdxes.empty() || is_output_index) ? postNodeIdxes.size() + 1 : postNodeIdxes.size();
452   bool has_insert_for_graph_out = postNodeIdxes.empty() || is_output_index;
453   std::vector<std::unique_ptr<schema::CNodeT>> toAddNodes;
454   for (size_t i = 0; i < insert_node_num; ++i) {
455     auto &postTensor = graphT->allTensors.at(postTensorIdx);
456     MS_ASSERT(postTensor != nullptr);
457     auto toAddTensor = CopyTensorDefT(postTensor);
458     if (toAddTensor == nullptr) {
459       MS_LOG(ERROR) << "Copy TensorT failed";
460       *errorCode = RET_NULL_PTR;
461       return graphT->nodes.end();
462     }
463     toAddTensor->nodeType = NodeType_CNode;
464     MS_ASSERT(toAddNodeIn->primitive != nullptr);
465     if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
466       auto prim = toAddNodeIn->primitive->value.AsQuantDTypeCast();
467       MS_ASSERT(prim != nullptr);
468       if (prim->dst_t == TypeId::kNumberTypeUInt8) {
469         if (postTensor->dataType == TypeId::kNumberTypeUInt8) {
470           postTensor->quantParams.front()->zeroPoint -= kZeroPointGap;
471         } else {
472           toAddTensor->quantParams.front()->zeroPoint += kZeroPointGap;
473         }
474       } else if (prim->src_t == TypeId::kNumberTypeUInt8) {
475         if (postTensor->dataType == TypeId::kNumberTypeUInt8) {
476           toAddTensor->quantParams.front()->zeroPoint -= kZeroPointGap;
477         } else {
478           postTensor->quantParams.front()->zeroPoint += kZeroPointGap;
479         }
480       }
481       postTensor->dataType = prim->src_t;
482       toAddTensor->dataType = prim->dst_t;
483     }
484     graphT->allTensors.emplace_back(std::move(toAddTensor));
485     size_t toAddTensorIdx = graphT->allTensors.size() - 1;
486     auto toAddNode = opDefCopyer(*toAddNodeIn);
487     if (toAddNode == nullptr) {
488       MS_LOG(ERROR) << "copy toAddNodeIn failed";
489       *errorCode = RET_NULL_PTR;
490       return graphT->nodes.end();
491     }
492     toAddNode->inputIndex.clear();
493     toAddNode->inputIndex.push_back(postTensorIdx);
494     toAddNode->outputIndex.clear();
495     toAddNode->outputIndex.push_back(toAddTensorIdx);
496     if (!postNodeIdxes.empty()) {
497       toAddNode->name = toAddNodeIn->name + "_" + std::to_string(i);
498     }
499     if (has_insert_for_graph_out) {
500       ReplaceOutput(postTensorIdx, toAddTensorIdx, graphT);
501       has_insert_for_graph_out = false;
502     } else {
503       auto &postNode = graphT->nodes.at(postNodeIdxes[is_output_index ? i - 1 : i]);
504       for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
505         if (*iter == postTensorIdx) {
506           *iter = toAddTensorIdx;
507         }
508       }
509     }
510     toAddNodes.emplace_back(std::move(toAddNode));
511   }
512   for (auto &toAddNode : toAddNodes) {
513     existNodeIter = graphT->nodes.insert(existNodeIter, std::move(toAddNode));
514     existNodeIter++;
515     *insert_num += 1;
516   }
517   *errorCode = RET_OK;
518   return existNodeIter;
519 }
520 
ValidateFileStr(const std::string & modelFile,const std::string & fileType)521 STATUS ValidateFileStr(const std::string &modelFile, const std::string &fileType) {
522   if (modelFile.size() > fileType.size() && modelFile.substr(modelFile.size() - fileType.size()) == fileType) {
523     return RET_OK;
524   } else {
525     return RET_ERROR;
526   }
527 }
528 
SetSubgraphTensorIndices(schema::MetaGraphT * meta_graphT)529 void SetSubgraphTensorIndices(schema::MetaGraphT *meta_graphT) {
530   if (meta_graphT == nullptr) {
531     MS_LOG(ERROR) << "meta_graphT is nullptr.";
532     return;
533   }
534   for (auto &subgraph : meta_graphT->subGraph) {
535     std::vector<uint32_t> subgraph_indices{};
536     subgraph_indices.assign(subgraph->inputIndices.begin(), subgraph->inputIndices.end());
537     subgraph_indices.assign(subgraph->outputIndices.begin(), subgraph->outputIndices.end());
538     for (auto &node_idx : subgraph->nodeIndices) {
539       auto &node = meta_graphT->nodes.at(node_idx);
540       for (auto &input_idx : node->inputIndex) {
541         if (IsContain(subgraph_indices, input_idx)) {
542           continue;
543         } else {
544           subgraph_indices.push_back(input_idx);
545         }
546       }
547       for (auto &output_idx : node->outputIndex) {
548         if (IsContain(subgraph_indices, output_idx)) {
549           continue;
550         } else {
551           subgraph_indices.push_back(output_idx);
552         }
553       }
554     }
555     subgraph->tensorIndices.assign(subgraph_indices.begin(), subgraph_indices.end());
556   }
557 }
558 
/**
 * Derives a model name from a file path.
 *
 * Everything up to and including the last '/' is dropped, then everything from the last '.' onward,
 * e.g. "/path/to/model.tflite" -> "model".
 */
std::string GetModelName(const std::string &modelFile) {
  const auto slash_pos = modelFile.find_last_of('/');
  std::string file_name = (slash_pos == std::string::npos) ? modelFile : modelFile.substr(slash_pos + 1);
  const auto dot_pos = file_name.find_last_of('.');
  if (dot_pos != std::string::npos) {
    file_name.erase(dot_pos);
  }
  return file_name;
}
565 
GetTransposePerm(MetaGraphT * graph,const std::unique_ptr<CNodeT> & cnode)566 std::vector<int> GetTransposePerm(MetaGraphT *graph, const std::unique_ptr<CNodeT> &cnode) {
567   MS_ASSERT(graph != nullptr && cnode != nullptr);
568   std::vector<int> perm;
569   if (cnode->primitive->value.type != schema::PrimitiveType_Transpose) {
570     return perm;
571   }
572   if (cnode->inputIndex.size() < 2) {
573     MS_LOG(ERROR) << "transpose node input size is less than 2.";
574     return perm;
575   }
576   MS_ASSERT(cnode->outputIndex.at(1) < graph->allTensors.size());
577   auto &perm_tensor = graph->allTensors.at(cnode->inputIndex.at(1));
578   if (perm_tensor->data.empty()) {
579     return perm;
580   }
581   MS_ASSERT(perm_tensor->dims.size() != 0);
582   perm.resize(perm_tensor->dims[0]);
583   if (memcpy_s(perm.data(), perm_tensor->dims[0] * sizeof(int), perm_tensor->data.data(),
584                perm_tensor->dims[0] * sizeof(int)) != EOK) {
585     MS_LOG(ERROR) << "memcpy data failed.";
586     return {};
587   }
588   return perm;
589 }
590 
GetAbstractTensorDtype(const abstract::AbstractTensorPtr & tensor)591 TypeId GetAbstractTensorDtype(const abstract::AbstractTensorPtr &tensor) {
592   if (tensor == nullptr || tensor->element() == nullptr) {
593     MS_LOG(ERROR) << "abstract_tensor or abstract_tensor->element() is nullptr";
594     return kTypeUnknown;
595   }
596   auto type_ptr = tensor->element()->GetTypeTrack();
597   MS_CHECK_TRUE_MSG(type_ptr != nullptr, kTypeUnknown, "type_ptr is nullptr");
598   return type_ptr->type_id();
599 }
600 
GetParameterDtype(const ParameterPtr & param_node)601 TypeId GetParameterDtype(const ParameterPtr &param_node) {
602   MS_CHECK_TRUE_MSG(param_node != nullptr, kTypeUnknown, "param_node is nullptr");
603   auto abstract_base = param_node->abstract();
604   MS_CHECK_TRUE_MSG(abstract_base != nullptr, kTypeUnknown, "abstract_base is nullptr");
605   auto abstract_tensor = abstract_base->cast<abstract::AbstractTensorPtr>();
606   MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, kTypeUnknown, "Cast to abstract tensor failed!");
607   auto type_ptr = abstract_tensor->element()->GetTypeTrack();
608   MS_CHECK_TRUE_MSG(type_ptr != nullptr, kTypeUnknown, "type_ptr is nullptr");
609   return type_ptr->type_id();
610 }
611 
UpdateFuncGraphInputsAndOutputsDtype(const FuncGraphPtr & func_graph)612 STATUS UpdateFuncGraphInputsAndOutputsDtype(const FuncGraphPtr &func_graph) {
613   MS_ASSERT(func_graph != nullptr);
614   // update graph inputs dtype
615   size_t idx = 0;
616   for (auto &input : func_graph->get_inputs()) {
617     TypeId type = GetParameterDtype(input->cast<ParameterPtr>());
618     ConverterInnerContext::GetInstance()->UpdateGraphInputDType(idx, type);
619     idx++;
620   }
621   // update graph outputs dtype
622   auto graph_return = func_graph->get_return();
623   idx = 0;
624   for (auto &input : graph_return->inputs()) {
625     if (input->isa<CNode>()) {
626       if (utils::isa<abstract::AbstractTuple>(input->abstract())) {
627         auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(input->abstract());
628         if (tuple == nullptr) {
629           MS_LOG(ERROR) << "tuple is nullptr";
630           return RET_ERROR;
631         }
632         for (const auto &tuple_item : tuple->elements()) {
633           if (utils::isa<abstract::AbstractTuple>(tuple_item)) {
634             continue;
635           }
636           TypeId type = GetAbstractTensorDtype(tuple_item->cast<abstract::AbstractTensorPtr>());
637           ConverterInnerContext::GetInstance()->UpdateGraphOutputDType(idx, type);
638           idx++;
639         }
640       } else if (utils::isa<abstract::AbstractTensor>(input->abstract())) {
641         TypeId type = GetAbstractTensorDtype(input->abstract()->cast<abstract::AbstractTensorPtr>());
642         ConverterInnerContext::GetInstance()->UpdateGraphOutputDType(idx, type);
643         idx++;
644       } else {
645         ConverterInnerContext::GetInstance()->UpdateGraphOutputDType(idx, kTypeUnknown);
646         idx++;
647       }
648     }
649   }
650   return RET_OK;
651 }
652 
GetFuncGraphOutputsInfo(const FuncGraphPtr & func_graph,std::vector<std::pair<AnfNodePtr,int64_t>> * outputs,std::vector<std::string> * output_names,std::vector<std::vector<int64_t>> * output_dims)653 STATUS GetFuncGraphOutputsInfo(const FuncGraphPtr &func_graph, std::vector<std::pair<AnfNodePtr, int64_t>> *outputs,
654                                std::vector<std::string> *output_names, std::vector<std::vector<int64_t>> *output_dims) {
655   MS_CHECK_TRUE_MSG(outputs != nullptr, lite::RET_ERROR, "Output is nullptr.");
656   MS_CHECK_TRUE_MSG(output_names != nullptr, lite::RET_ERROR, "Output names is nullptr.");
657   MS_CHECK_TRUE_MSG(output_dims != nullptr, lite::RET_ERROR, "Output dims is nullptr.");
658   AnfNodePtr return_input = func_graph->output();
659   CHECK_NULL_RETURN(return_input);
660   if (TraceOutput(return_input, outputs, output_names, output_dims) != lite::RET_OK) {
661     MS_LOG(ERROR) << "Trace output failed.";
662     return lite::RET_ERROR;
663   }
664   return lite::RET_OK;
665 }
666 
UpdateGraphOutputName(schema::MetaGraphT * meta_graph)667 STATUS UpdateGraphOutputName(schema::MetaGraphT *meta_graph) {
668   MS_CHECK_TRUE_MSG(meta_graph != nullptr, RET_NULL_PTR, "meta_graph is nullptr");
669   auto output_names = ConverterInnerContext::GetInstance()->GetGraphOutputTensorNames();
670   if (output_names.size() > meta_graph->outputIndex.size()) {
671     MS_LOG(ERROR) << "the num of setting output_names is greater than actual, " << output_names.size() << " > "
672                   << meta_graph->outputIndex.size() << ".";
673     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
674     return RET_ERROR;
675   }
676   for (size_t idx = 0; idx < output_names.size(); idx++) {
677     auto &tensor = meta_graph->allTensors.at(meta_graph->outputIndex.at(idx));
678     tensor->name = output_names.at(idx);
679   }
680   return RET_OK;
681 }
682 
TransferMetaGraph(const schema::MetaGraphT & graph,void ** model_buf,size_t * size)683 int TransferMetaGraph(const schema::MetaGraphT &graph, void **model_buf, size_t *size) {
684   if (model_buf == nullptr) {
685     MS_LOG(ERROR) << "input model_buf invalid";
686     return RET_ERROR;
687   }
688   if (size == nullptr) {
689     MS_LOG(ERROR) << "input size invalid";
690     return RET_ERROR;
691   }
692 
693   /* model_buf malloc here, free outside */
694   if (*model_buf != nullptr) {
695     MS_LOG(ERROR) << "input model_buf must be nullptr";
696     return RET_ERROR;
697   }
698   flatbuffers::FlatBufferBuilder builder(MAX_GRAPH_SIZE);
699   auto offset = schema::MetaGraph::Pack(builder, &graph);
700   builder.Finish(offset);
701   schema::FinishMetaGraphBuffer(builder, offset);
702   *size = builder.GetSize();
703   auto content = builder.GetBufferPointer();
704   if (content == nullptr) {
705     MS_LOG(ERROR) << "GetBufferPointer nullptr";
706     return RET_ERROR;
707   }
708   *model_buf = new (std::nothrow) char[*size];
709   if (*model_buf == nullptr) {
710     MS_LOG(ERROR) << "malloc model_buf failed";
711     return RET_ERROR;
712   }
713   return memcpy_s(*model_buf, *size, content, *size);
714 }
715 
InitEncryptKey(const std::shared_ptr<ConverterPara> & param,unsigned char * encKey,size_t * keyLen)716 int InitEncryptKey(const std::shared_ptr<ConverterPara> &param, unsigned char *encKey, size_t *keyLen) {
717   if (!param->enable_encryption) {
718     return RET_OK;
719   }
720   if (param->encrypt_key.empty()) {
721     MS_LOG(ERROR) << "param->encrypt_key is empty.";
722     return RET_INPUT_PARAM_INVALID;
723   }
724   *keyLen = lite::Hex2ByteArray(param->encrypt_key, encKey, kEncMaxLen);
725   if (*keyLen != kEncMaxLen) {
726     MS_LOG(ERROR) << "enc_key must expressed in hexadecimal characters "
727                   << " and only support AES-GCM method and the key length is " << kEncMaxLen;
728     return RET_INPUT_PARAM_INVALID;
729   }
730 
731   return RET_OK;
732 }
733 }  // namespace lite
734 }  // namespace mindspore
735