• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "tools/converter/parser/onnx/onnx_model_parser.h"
17 #include <algorithm>
18 #include <map>
19 #include <memory>
20 #include <queue>
21 #include <set>
22 #include <unordered_map>
23 #include <utility>
24 #include <vector>
25 #include "include/registry/node_parser_registry.h"
26 #include "ir/func_graph.h"
27 #include "mindspore/core/ops/nn_ops.h"
28 #include "nnacl/op_base.h"
29 #include "ops/auto_generate/gen_lite_ops.h"
30 #include "ops/make_tuple.h"
31 #include "ops/return.h"
32 #include "ops/tensor_list_stack.h"
33 #include "ops/tuple_get_item.h"
34 #include "src/common/log_util.h"
35 #include "tools/common/graph_util.h"
36 #include "tools/common/protobuf_utils.h"
37 #include "tools/common/tensor_util.h"
38 #include "tools/converter/converter_context.h"
39 #include "tools/converter/parser/lite_model_parser_creator.h"
40 #include "tools/converter/parser/onnx/onnx_einsum_adjust.h"
41 #include "tools/converter/parser/onnx/onnx_inputs_adjust.h"
42 #include "tools/converter/parser/onnx/onnx_megatron_op_adjust.h"
43 #include "tools/converter/parser/onnx/onnx_nonzero_adjust.h"
44 #include "tools/converter/parser/onnx/onnx_pad_adjust.h"
45 #include "tools/converter/parser/onnx/onnx_concat_adjust.h"
46 #include "tools/converter/parser/onnx/onnx_quantize_linear_adjust.h"
47 #include "tools/converter/parser/onnx/onnx_deform_conv2d_adjust.h"
48 #include "tools/converter/parser/onnx/onnx_custom_op_adjust.h"
49 #include "tools/converter/parser/parser_utils.h"
50 #include "tools/converter/parser/unify_format.h"
51 #include "tools/converter/quantizer/quant_param_holder.h"
52 #include "tools/optimizer/common/gllo_utils.h"
53 #include "tools/converter/parser/onnx/onnx_dtype_adjust.h"
54 
55 using mindspore::converter::kFmkTypeOnnx;
56 namespace mindspore {
57 namespace lite {
58 namespace {
59 constexpr int kTensorListDatasize = 3;
60 constexpr int kTypeIndex = 0;
61 constexpr int kElementShapeIndex = 1;
62 constexpr int kTensorsNumIndex = 2;
63 
Onnx2AnfAdjust(const std::set<FuncGraphPtr> & all_func_graphs,const converter::ConverterParameters & flag)64 int Onnx2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs, const converter::ConverterParameters &flag) {
65   for (const auto &func_graph : all_func_graphs) {
66     CHECK_NULL_RETURN(func_graph);
67     if (!OnnxMegatronOpAdjust::Adjust(func_graph, flag)) {
68       MS_LOG(ERROR) << "onnx magatron adjust failed.";
69       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
70       return RET_ERROR;
71     }
72     if (!OnnxInputAdjust::Adjust(func_graph, flag)) {
73       MS_LOG(ERROR) << "onnx adjust failed.";
74       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
75       return RET_ERROR;
76     }
77     if (!OnnxDtypeAdjust::Adjust(func_graph, flag)) {
78       MS_LOG(ERROR) << "onnx dtype adjust failed!";
79       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
80       return RET_ERROR;
81     }
82     if (!OnnxPadAdjust::Adjust(func_graph)) {
83       MS_LOG(ERROR) << "onnx pad adjust failed.";
84       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
85       return RET_ERROR;
86     }
87     if (!OnnxConcatAdjust::Adjust(func_graph)) {
88       MS_LOG(ERROR) << "onnx OnnxConcatOp adjust failed.";
89       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
90       return RET_ERROR;
91     }
92     if (!OnnxNonZeroAdjust::Adjust(func_graph)) {
93       MS_LOG(ERROR) << "onnx nonzero adjust failed.";
94       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
95       return RET_ERROR;
96     }
97     if (!OnnxEinsumAdjust::Adjust(func_graph)) {
98       MS_LOG(ERROR) << "onnx einsum adjust failed.";
99       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
100       return RET_ERROR;
101     }
102     if (!OnnxQuantizeLinearAdjust::Adjust(func_graph)) {
103       MS_LOG(ERROR) << "onnx quantize linear adjust failed.";
104       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
105       return RET_ERROR;
106     }
107     if (!OnnxDeformConv2dAdjust::Adjust(func_graph)) {
108       MS_LOG(ERROR) << "onnx MMCVModulatedDeformConv2d adjust failed.";
109       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
110       return RET_ERROR;
111     }
112     if (!OnnxCustomOpAdjust::Adjust(func_graph)) {
113       MS_LOG(ERROR) << "onnx OnnxCustomOp adjust failed.";
114       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
115       return RET_ERROR;
116     }
117   }
118   return RET_OK;
119 }
120 
CreateConstParamter(const FuncGraphPtr & anf_graph,int val)121 ParameterPtr CreateConstParamter(const FuncGraphPtr &anf_graph, int val) {
122   MS_CHECK_TRUE_RET(anf_graph != nullptr, nullptr);
123   auto const_node = anf_graph->add_parameter();
124   MS_CHECK_TRUE_RET(const_node != nullptr, nullptr);
125   auto const_abstract = CreateTensorAbstract({}, kNumberTypeInt32);
126   if (const_abstract == nullptr) {
127     MS_LOG(ERROR) << "Create tensor abstarct failed";
128     return nullptr;
129   }
130   const_node->set_abstract(const_abstract);
131   int *tensor_data = new (std::nothrow) int[1];
132   if (tensor_data == nullptr) {
133     MS_LOG(ERROR) << "new int[] failed";
134     return nullptr;
135   }
136   tensor_data[0] = val;
137   auto tensor_info = CreateTensorInfo(tensor_data, sizeof(int), {1}, kNumberTypeInt32);
138   if (tensor_info == nullptr) {
139     MS_LOG(ERROR) << "create tensor info failed.";
140     delete[] tensor_data;
141     tensor_data = nullptr;
142     return nullptr;
143   }
144   delete[] tensor_data;
145   tensor_data = nullptr;
146   const_node->set_default_param(tensor_info);
147   return const_node;
148 }
149 
CreateValueNode(const schema::PrimitiveType & op_type)150 ValueNodePtr CreateValueNode(const schema::PrimitiveType &op_type) {
151   auto node_type = schema::EnumNamePrimitiveType(op_type);
152   auto op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
153   if (op_primc_fns.find(node_type) == op_primc_fns.end()) {
154     MS_LOG(ERROR) << "have no func to create primitive.";
155     return nullptr;
156   }
157   auto prim = op_primc_fns[node_type]();
158   if (prim == nullptr) {
159     MS_LOG(ERROR) << "cannot create primitive.";
160     return nullptr;
161   }
162   return NewValueNode(prim);
163 }
164 
AddIterNumsUpdateEdge(const FuncGraphPtr & anf_graph,std::vector<AnfNodePtr> * return_new_inputs,const std::unordered_map<std::string,AnfNodePtr> & anf_nodes_map,const std::string & trip_cout_name,const std::string & loop_node_name)165 STATUS AddIterNumsUpdateEdge(const FuncGraphPtr &anf_graph, std::vector<AnfNodePtr> *return_new_inputs,
166                              const std::unordered_map<std::string, AnfNodePtr> &anf_nodes_map,
167                              const std::string &trip_cout_name, const std::string &loop_node_name) {
168   CHECK_NULL_RETURN(anf_graph);
169   CHECK_NULL_RETURN(return_new_inputs);
170   // trip_cout need -1 after every iteration
171   auto sub_value_node = CreateValueNode(schema::PrimitiveType_SubFusion);
172   if (sub_value_node == nullptr) {
173     MS_LOG(ERROR) << "create sub failed.";
174     return RET_NULL_PTR;
175   }
176   auto trip_cout_paramter_iter = anf_nodes_map.find(trip_cout_name);
177   if (trip_cout_paramter_iter == anf_nodes_map.end()) {
178     MS_LOG(ERROR) << "cannot find " << trip_cout_name;
179     return RET_ERROR;
180   }
181   auto &trip_cout_paramter = trip_cout_paramter_iter->second;
182   if (trip_cout_paramter == nullptr) {
183     MS_LOG(ERROR) << "trip_cout_paramter found failed";
184     return RET_ERROR;
185   }
186   auto const_one_parameter = CreateConstParamter(anf_graph, 1);
187   MS_CHECK_TRUE_MSG(const_one_parameter != nullptr, RET_ERROR, "create const parameter return nullptr");
188   const_one_parameter->set_name(loop_node_name + "_index_update_parameter");
189 
190   std::vector<AnfNodePtr> sub_inputs = {sub_value_node, trip_cout_paramter, const_one_parameter};
191   auto sub_cnode = anf_graph->NewCNode(sub_inputs);
192   if (sub_cnode == nullptr) {
193     MS_LOG(ERROR) << "new cnode error";
194     return RET_ERROR;
195   }
196   sub_cnode->set_fullname_with_scope(loop_node_name + "_sub");
197   sub_cnode->set_abstract(trip_cout_paramter->abstract());
198   return_new_inputs->insert(return_new_inputs->begin() + 1, sub_cnode);
199   return RET_OK;
200 }
201 
GetCNodeFromControlFlowNodesMap(const std::string & loop_node_name,const std::unordered_map<std::string,std::unordered_map<std::string,AnfNodePtr> * > & control_nodes_map)202 CNodePtr GetCNodeFromControlFlowNodesMap(
203   const std::string &loop_node_name,
204   const std::unordered_map<std::string, std::unordered_map<std::string, AnfNodePtr> *> &control_nodes_map) {
205   auto iter1 = control_nodes_map.find(loop_node_name);
206   if (iter1 == control_nodes_map.end()) {
207     return nullptr;
208   }  // namespace
209   auto iter2 = iter1->second->find(loop_node_name);
210   if (iter2 == iter1->second->end()) {
211     return nullptr;
212   }
213   return iter2->second->cast<CNodePtr>();
214 }
215 
BuildReturnNode(const FuncGraphPtr & anf_graph,const std::vector<AnfNodePtr> & return_inputs)216 STATUS BuildReturnNode(const FuncGraphPtr &anf_graph, const std::vector<AnfNodePtr> &return_inputs) {
217   MS_CHECK_TRUE_RET(anf_graph != nullptr, RET_NULL_PTR);
218   auto return_prim = std::make_shared<ops::Return>();
219   if (return_prim == nullptr) {
220     MS_LOG(ERROR) << "new Return failed";
221     return RET_NULL_PTR;
222   }
223   if (return_inputs.empty()) {
224     MS_LOG(ERROR) << "return input is empty";
225     return RET_ERROR;
226   }
227   auto input = return_inputs[0];
228   MS_EXCEPTION_IF_NULL(input);
229   auto abstract = input->abstract();
230   if (abstract == nullptr) {
231     MS_LOG(ERROR) << "Input node abstract is null, node: " << input->fullname_with_scope();
232     return RET_ERROR;
233   }
234 
235   auto return_prim_c = return_prim->GetPrim();
236   CHECK_NULL_RETURN(return_prim_c);
237   auto return_cnode = anf_graph->NewCNode(return_prim_c, return_inputs);
238   if (return_cnode == nullptr) {
239     MS_LOG(ERROR) << "new cnode error";
240     return RET_ERROR;
241   }
242   return_cnode->set_fullname_with_scope("Return");
243   return_cnode->set_abstract(abstract);
244   anf_graph->set_return(return_cnode);
245   return RET_OK;
246 }
247 
BuildParameterNode(const ParameterPtr & parameter_node,const onnx::TensorProto & tensor,const std::string & model_file,std::map<std::string,std::pair<size_t,uint8_t * >> * external_datas)248 STATUS BuildParameterNode(const ParameterPtr &parameter_node, const onnx::TensorProto &tensor,
249                           const std::string &model_file,
250                           std::map<std::string, std::pair<size_t, uint8_t *>> *external_datas) {
251   MS_CHECK_TRUE_RET(parameter_node != nullptr, RET_NULL_PTR);
252   auto data_type = OnnxNodeParser::GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(tensor.data_type()));
253   if (data_type == kTypeUnknown) {
254     MS_LOG(ERROR) << "not support onnx data type " << static_cast<onnx::TensorProto_DataType>(tensor.data_type());
255     return RET_ERROR;
256   }
257   std::vector<int64_t> shape_vector(tensor.dims().begin(), tensor.dims().end());
258   auto abstract_tensor = CreateTensorAbstract(shape_vector, data_type);
259   if (abstract_tensor == nullptr) {
260     MS_LOG(ERROR) << "Create tensor abstract failed";
261     return RET_ERROR;
262   }
263   parameter_node->set_abstract(abstract_tensor);
264   parameter_node->set_name(tensor.name());
265 
266   tensor::TensorPtr tensor_info;
267   if (tensor.data_location() != onnx::TensorProto::EXTERNAL) {
268     tensor_info = OnnxNodeParser::CopyOnnxTensorData(tensor);
269     if (tensor_info == nullptr) {
270       MS_LOG(ERROR) << "copy data failed.";
271       return RET_ERROR;
272     }
273   } else {
274     tensor_info = std::make_shared<tensor::Tensor>(data_type, shape_vector);
275     MS_CHECK_TRUE_MSG(tensor_info != nullptr, RET_NULL_PTR, "create tensor_info return nullptr");
276     std::vector<int> shape;
277     std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape),
278                    [](const int64_t &value) { return static_cast<int>(value); });
279     auto status = OnnxNodeParser::LoadOnnxExternalTensorData(tensor, tensor_info, model_file, external_datas);
280     if (status != RET_OK) {
281       MS_LOG(ERROR) << "load external data failed.";
282       return status;
283     }
284   }
285   parameter_node->set_default_param(tensor_info);
286   return RET_OK;
287 }
288 
BuildOpOutputs(const onnx::NodeProto & onnx_node,const FuncGraphPtr & anf_graph,std::unordered_map<std::string,AnfNodePtr> * anf_nodes_map,const CNodePtr & cnode)289 STATUS BuildOpOutputs(const onnx::NodeProto &onnx_node, const FuncGraphPtr &anf_graph,
290                       std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, const CNodePtr &cnode) {
291   CHECK_NULL_RETURN(anf_graph);
292   CHECK_NULL_RETURN(anf_nodes_map);
293   if (onnx_node.output_size() == 1) {
294     auto abstract_tensor = CreateTensorAbstract({}, kNumberTypeFloat32);
295     if (abstract_tensor == nullptr) {
296       MS_LOG(ERROR) << "Create tensor abstarct failed";
297       return RET_ERROR;
298     }
299     cnode->set_abstract(abstract_tensor);
300     anf_nodes_map->emplace(onnx_node.output(0), cnode);
301   } else {
302     AbstractBasePtrList abstract_list;
303     int op_idx = 0;
304     for (const auto &output_name : onnx_node.output()) {
305       auto abstract_tensor = CreateTensorAbstract({}, kNumberTypeFloat32);
306       if (abstract_tensor == nullptr) {
307         MS_LOG(ERROR) << "Create tensor abstarct failed";
308         return RET_ERROR;
309       }
310       abstract_list.emplace_back(abstract_tensor);
311       auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
312       if (tuple_get_item_prim_ptr == nullptr) {
313         MS_LOG(ERROR) << "new TupleGetItem failed";
314         return RET_NULL_PTR;
315       }
316       auto tuple_get_item_prim_c = tuple_get_item_prim_ptr->GetPrim();
317       MS_CHECK_TRUE_MSG(tuple_get_item_prim_c != nullptr, RET_NULL_PTR, "create tuple_get_item_prim_c return nullptr");
318       auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_c);
319       MS_CHECK_TRUE_MSG(tuple_get_item_prim != nullptr, RET_NULL_PTR, "create ValueNode return nullptr");
320       auto get_item_value = NewValueNode(MakeValue<int64_t>(op_idx));
321       MS_CHECK_TRUE_MSG(get_item_value != nullptr, RET_NULL_PTR, "create ValueNode return nullptr");
322       std::vector<AnfNodePtr> inputs{tuple_get_item_prim, cnode, get_item_value};
323       CNodePtr get_item_cnode = anf_graph->NewCNode(inputs);
324       if (get_item_cnode == nullptr) {
325         MS_LOG(ERROR) << "new cnode error";
326         return RET_ERROR;
327       }
328       get_item_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_getitem_" + std::to_string(op_idx));
329       auto get_item_abstract = CreateTensorAbstract({}, kNumberTypeFloat32);
330       if (get_item_abstract == nullptr) {
331         MS_LOG(ERROR) << "Create tensor abstarct failed";
332         return RET_ERROR;
333       }
334       get_item_cnode->set_abstract(get_item_abstract);
335       anf_nodes_map->emplace(output_name, get_item_cnode);
336       op_idx++;
337     }
338     auto new_abstract_list = std::make_shared<abstract::AbstractTuple>(abstract_list);
339     CHECK_NULL_RETURN(new_abstract_list);
340     cnode->set_abstract(new_abstract_list);
341   }
342   if (onnx_node.op_type() == "Loop" || onnx_node.op_type() == "If") {
343     anf_nodes_map->emplace(onnx_node.name(), cnode);
344   }
345   return RET_OK;
346 }
347 
ConvertConstTensors(const onnx::GraphProto & onnx_graph,const FuncGraphPtr & func_graph_ptr,std::unordered_map<std::string,AnfNodePtr> * anf_nodes_map,const std::string & model_file)348 STATUS ConvertConstTensors(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr,
349                            std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, const std::string &model_file) {
350   CHECK_NULL_RETURN(func_graph_ptr);
351   CHECK_NULL_RETURN(anf_nodes_map);
352   std::map<std::string, std::pair<size_t, uint8_t *>> external_datas;
353   auto free_external_data = [&external_datas]() {
354     for (auto &&item : external_datas) {
355       if (item.second.second) {
356         delete[] item.second.second;
357       }
358     }
359     external_datas.clear();
360   };
361   for (const auto &onnx_const_value : onnx_graph.initializer()) {
362     auto parameter = func_graph_ptr->add_parameter();
363     MS_CHECK_TRUE_MSG(parameter != nullptr, RET_NULL_PTR, "create parameter return nullptr");
364     auto status = BuildParameterNode(parameter, onnx_const_value, model_file, &external_datas);
365     if (status != RET_OK) {
366       MS_LOG(ERROR) << "parameter node build failed.";
367       free_external_data();
368       return status;
369     }
370     anf_nodes_map->emplace(onnx_const_value.name(), parameter);
371   }
372   free_external_data();
373   return RET_OK;
374 }
375 
ConvertGraphInputs(const onnx::GraphProto & onnx_graph,const FuncGraphPtr & func_graph_ptr,std::unordered_map<std::string,AnfNodePtr> * anf_nodes_map)376 STATUS ConvertGraphInputs(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr,
377                           std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map) {
378   CHECK_NULL_RETURN(func_graph_ptr);
379   CHECK_NULL_RETURN(anf_nodes_map);
380   for (int i = 0; i < onnx_graph.input().size(); ++i) {
381     const auto &input_value = onnx_graph.input(i);
382     if (anf_nodes_map->find(input_value.name()) != anf_nodes_map->end()) {
383       continue;
384     }
385     auto parameter = func_graph_ptr->add_parameter();
386     MS_CHECK_TRUE_MSG(parameter != nullptr, RET_NULL_PTR, "create parameter return nullptr");
387     auto data_type = OnnxNodeParser::GetDataTypeFromOnnx(
388       static_cast<onnx::TensorProto_DataType>(input_value.type().tensor_type().elem_type()));
389     if (data_type == kTypeUnknown) {
390       MS_LOG(ERROR) << "not support onnx data type "
391                     << static_cast<onnx::TensorProto_DataType>(input_value.type().tensor_type().elem_type());
392       return RET_ERROR;
393     }
394     std::vector<int64_t> shape_vector =
395       ConverterInnerContext::GetInstance()->GetGraphInputTensorShape(input_value.name());
396     if (ConverterInnerContext::GetInstance()->GetGraphInputTensorShapeMapSize() > 0 && shape_vector.empty()) {
397       MS_LOG(WARNING) << "Cannot find name in map. name is " << input_value.name();
398     }
399     if (shape_vector.empty()) {
400       auto onnx_shape = input_value.type().tensor_type().shape().dim();
401       std::transform(onnx_shape.begin(), onnx_shape.end(), std::back_inserter(shape_vector),
402                      [](const onnx::TensorShapeProto_Dimension &val) { return static_cast<int64_t>(val.dim_value()); });
403       std::replace(shape_vector.begin(), shape_vector.end(), 0, -1);
404     }
405     auto abstract_tensor = CreateTensorAbstract(shape_vector, data_type);
406     if (abstract_tensor == nullptr) {
407       MS_LOG(ERROR) << "Create tensor abstarct failed";
408       return RET_ERROR;
409     }
410     parameter->set_abstract(abstract_tensor);
411     parameter->set_name(input_value.name());
412     anf_nodes_map->emplace(input_value.name(), parameter);
413   }
414   return RET_OK;
415 }
416 
ConvertGraphOutputs(const onnx::GraphProto & onnx_graph,const FuncGraphPtr & anf_graph,const std::unordered_map<std::string,AnfNodePtr> & anf_nodes_map)417 STATUS ConvertGraphOutputs(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &anf_graph,
418                            const std::unordered_map<std::string, AnfNodePtr> &anf_nodes_map) {
419   MS_CHECK_TRUE_RET(anf_graph != nullptr, RET_NULL_PTR);
420   std::vector<AnfNodePtr> return_inputs;
421   if (onnx_graph.output_size() == 0) {
422     MS_LOG(ERROR) << "onnx graph has no output";
423     return RET_ERROR;
424   }
425   if (onnx_graph.output_size() > 1) {
426     std::vector<AnfNodePtr> make_tuple_inputs;
427     auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>();
428     AbstractBasePtrList elem;
429     if (make_tuple_prim_ptr == nullptr) {
430       MS_LOG(ERROR) << "new MakeTuple failed";
431       return RET_NULL_PTR;
432     }
433     for (const auto &graph_out : onnx_graph.output()) {
434       if (anf_nodes_map.find(graph_out.name()) == anf_nodes_map.end()) {
435         MS_LOG(ERROR) << "graph output get failed.";
436         return RET_ERROR;
437       }
438       auto cnode = anf_nodes_map.at(graph_out.name());
439       if (cnode == nullptr) {
440         MS_LOG(ERROR) << "Can't find input node.";
441         return RET_NOT_FIND_OP;
442       }
443       elem.emplace_back(cnode->abstract());
444       make_tuple_inputs.emplace_back(cnode);
445     }
446     auto make_tuple_prim_c = make_tuple_prim_ptr->GetPrim();
447     CHECK_NULL_RETURN(make_tuple_prim_c);
448     auto make_tuple_cnode = anf_graph->NewCNode(make_tuple_prim_c, make_tuple_inputs);
449     if (make_tuple_cnode == nullptr) {
450       MS_LOG(ERROR) << "new cnode error";
451       return RET_ERROR;
452     }
453 
454     make_tuple_cnode->set_fullname_with_scope("return tuple");
455     make_tuple_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
456     return_inputs.emplace_back(make_tuple_cnode);
457   } else {
458     const auto &graph_out = onnx_graph.output(0);
459     if (anf_nodes_map.find(graph_out.name()) == anf_nodes_map.end()) {
460       MS_LOG(ERROR) << "graph output get failed.";
461       return RET_ERROR;
462     }
463     auto cnode = anf_nodes_map.at(graph_out.name());
464     if (cnode == nullptr) {
465       MS_LOG(ERROR) << "Can't find input node.";
466       return RET_NOT_FIND_OP;
467     }
468     return_inputs.emplace_back(cnode);
469   }
470   if (BuildReturnNode(anf_graph, return_inputs) != RET_OK) {
471     MS_LOG(ERROR) << "build return node failed.";
472     return RET_ERROR;
473   }
474   return RET_OK;
475 }
476 
// Builds the condition subgraph for a While node converted from an ONNX
// Loop. One parameter is created per loop-carried input; the graph's output
// is (0 < input[0]) && input[1], i.e. "iterations remain AND the loop
// condition flag is true", built from the first two inputs.
// NOTE(review): if inputs_num < 2 the i == 1 branch never runs, so no Return
// node is installed; and the i == 1 branch assumes less_cnode was created at
// i == 0. Callers apparently always pass >= 2 inputs — confirm.
FuncGraphPtr BuildCondGraph(const AnfNodePtr &root_while_node, int inputs_num, const std::string &cond_graph_name) {
  MS_CHECK_TRUE_RET(root_while_node != nullptr, nullptr);
  auto cond_graph = std::make_shared<FuncGraph>();
  MS_CHECK_TRUE_MSG(cond_graph != nullptr, nullptr, "create cond_graph return nullptr");
  CNodePtr less_cnode = nullptr;
  for (int i = 0; i < inputs_num; i++) {
    // Mirror every loop-carried input as an int32 scalar parameter.
    auto input_parameter = cond_graph->add_parameter();
    MS_CHECK_TRUE_MSG(input_parameter != nullptr, nullptr, "create input_parameter return nullptr");
    input_parameter->set_name(cond_graph_name + "_input_" + std::to_string(i) + "_parameter");
    auto input_abstract = CreateTensorAbstract({}, kNumberTypeInt32);
    if (input_abstract == nullptr) {
      MS_LOG(ERROR) << "Create tensor abstarct failed";
      return nullptr;
    }
    input_parameter->set_abstract(input_abstract);
    if (i == 0) {
      // First input is the remaining trip count: build Less(0, trip_count).
      auto zero_parameter = CreateConstParamter(cond_graph, 0);
      MS_CHECK_TRUE_MSG(zero_parameter != nullptr, nullptr, "create zero_parameter return nullptr");
      zero_parameter->set_name(root_while_node->fullname_with_scope() + "_const_0");
      auto less_value_node = CreateValueNode(schema::PrimitiveType_Less);
      MS_CHECK_TRUE_MSG(less_value_node != nullptr, nullptr, "create less_value_node return nullptr");
      std::vector<AnfNodePtr> less_inputs = {less_value_node, zero_parameter, input_parameter};
      less_cnode = cond_graph->NewCNode(less_inputs);
      if (less_cnode == nullptr) {
        MS_LOG(ERROR) << "new cnode error";
        return nullptr;
      }
      auto less_abstract = CreateTensorAbstract({}, kNumberTypeBool);
      if (less_abstract == nullptr) {
        MS_LOG(ERROR) << "Create tensor abstarct failed";
        return nullptr;
      }
      less_cnode->set_abstract(less_abstract);
      less_cnode->set_fullname_with_scope(cond_graph_name + "_less_cnode");
    }
    if (i == 1) {
      // Second input is the loop condition flag: AND it with the Less result
      // and make that the single output of the cond graph.
      auto and_value_node = CreateValueNode(schema::PrimitiveType_LogicalAnd);
      MS_CHECK_TRUE_MSG(and_value_node != nullptr, nullptr, "CreateValueNode failed");
      std::vector<AnfNodePtr> and_inputs = {and_value_node, less_cnode, input_parameter};
      auto and_cnode = cond_graph->NewCNode(and_inputs);
      if (and_cnode == nullptr) {
        MS_LOG(ERROR) << "new cnode error";
        return nullptr;
      }
      and_cnode->set_abstract(less_cnode->abstract());
      and_cnode->set_fullname_with_scope(cond_graph_name + "_output_" + std::to_string(0) + "_cnode");
      auto status = BuildReturnNode(cond_graph, {and_cnode});
      if (status != RET_OK) {
        MS_LOG(ERROR) << "build return node failed: " << status;
        return nullptr;
      }
    }
  }
  cond_graph->set_attr("graph_name", MakeValue(cond_graph_name));
  return cond_graph;
}
533 
ConvertGraph(api::FuncGraphPtr func_graph)534 FuncGraphPtr ConvertGraph(api::FuncGraphPtr func_graph) {
535   auto impl = func_graph->impl();
536   return std::dynamic_pointer_cast<FuncGraph>(impl);
537 }
538 }  // namespace
539 
// Converts the ONNX Loop body subgraph into a FuncGraph suitable as the body
// of the While node already registered for `loop_node_name` in
// control_nodes_map_. Rewires the body's return/parameter lists to carry the
// decremented trip count, pass-through inputs and tensor-array scan outputs,
// renames parameters/outputs to a canonical "<loop>_body_graph_*" scheme,
// and reports (via cond_graph_input_num) how many inputs the matching cond
// graph must have. Returns the body graph, or nullptr on any failure.
FuncGraphPtr OnnxModelParser::BuildBodyGraph(const onnx::NodeProto &loop_node, const onnx::GraphProto &subgraph_proto,
                                             int *cond_graph_input_num) {
  MS_CHECK_TRUE_RET(cond_graph_input_num != nullptr, nullptr);
  auto &loop_node_name = loop_node.name();
  auto node_inputs_num = loop_node.input_size();
  auto node_outputs_num = loop_node.output_size();
  // skip trip_cout and cond input,scan_output nums
  // (ONNX Loop: outputs = (inputs - 2) loop-carried values + scan outputs,
  // so act_outputs_num is the number of scan outputs.)
  auto act_outputs_num = node_outputs_num - (node_inputs_num - 2);
  auto loop_body_graph = std::make_shared<FuncGraph>();
  MS_CHECK_TRUE_MSG(loop_body_graph != nullptr, nullptr, "create loop_body_graph return nullptr");
  std::unordered_map<std::string, AnfNodePtr> anf_nodes_map;
  std::vector<AnfNodePtr> gen_subgraph_inputs;
  auto status = ConvertOnnxGraph(subgraph_proto, loop_body_graph, &anf_nodes_map, &gen_subgraph_inputs, loop_node_name);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "convert loop OnnxGraph: " << status;
    return nullptr;
  }
  // The converted subgraph ends in Return(MakeTuple(...)); edit that tuple's
  // input list to splice in the extra loop plumbing.
  auto return_node = loop_body_graph->get_return();
  MS_CHECK_TRUE_MSG(return_node != nullptr, nullptr, "return node of subgraph is nullptr");
  MS_CHECK_TRUE_RET(return_node->size() == DIMENSION_2D, nullptr);
  auto return_tuple_cnode = return_node->input(1)->cast<CNodePtr>();
  MS_CHECK_TRUE_RET(return_tuple_cnode != nullptr, nullptr);
  auto return_new_inputs = return_tuple_cnode->inputs();
  // Generated subgraph inputs are forwarded as outputs, placed before the
  // scan outputs at the tail.
  return_new_inputs.insert(return_new_inputs.end() - act_outputs_num, gen_subgraph_inputs.begin(),
                           gen_subgraph_inputs.end());

  // Output 0 of the body becomes trip_count - 1 (see AddIterNumsUpdateEdge).
  std::string max_trip_count_name = subgraph_proto.input(0).name();
  status =
    AddIterNumsUpdateEdge(loop_body_graph, &return_new_inputs, anf_nodes_map, max_trip_count_name, loop_node_name);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "add iter nums update edge failed: " << status;
    return nullptr;
  }
  auto root_while_node = GetCNodeFromControlFlowNodesMap(loop_node_name, control_nodes_map_);
  MS_CHECK_TRUE_MSG(root_while_node != nullptr, nullptr, "cannot find root_while_node is control_nodes_map");
  // Collect the body's parameters: the declared subgraph inputs followed by
  // the inputs generated during conversion.
  std::vector<AnfNodePtr> body_graph_inputs;
  body_graph_inputs.reserve(subgraph_proto.input_size());
  for (int j = 0; j < subgraph_proto.input_size(); j++) {
    body_graph_inputs.emplace_back(anf_nodes_map[subgraph_proto.input(j).name()]);
  }
  body_graph_inputs.insert(body_graph_inputs.end(), gen_subgraph_inputs.begin(), gen_subgraph_inputs.end());
  if (act_outputs_num != 0) {
    // Scan outputs are accumulated through tensor-array edges in the body...
    status =
      AddTensorArrayEdge(loop_body_graph, &return_new_inputs, loop_node_name, &body_graph_inputs, act_outputs_num);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "add tensorarray update edge failed: " << status;
      return nullptr;
    }
    // insert tensorliststack after while output
    status = AddTensorListStackNode(root_while_node, loop_node, act_outputs_num, body_graph_inputs.size());
    if (status != RET_OK) {
      MS_LOG(ERROR) << "add tensorliststack node failed: " << status;
      return nullptr;
    }
  }
  return_tuple_cnode->set_inputs(return_new_inputs);
  // Canonical renaming: inputs "<body>_input_<j>_parameter", outputs
  // "<body>_output_<j>_cnode" / "_parameter".
  auto body_graph_name = loop_node_name + "_body_graph";
  for (size_t j = 0; j < body_graph_inputs.size(); j++) {
    MS_CHECK_TRUE_RET(body_graph_inputs[j] != nullptr, nullptr);
    auto body_input = body_graph_inputs[j]->cast<ParameterPtr>();
    MS_CHECK_TRUE_RET(body_input != nullptr, nullptr);
    body_input->set_name(body_graph_name + "_input_" + std::to_string(j) + "_parameter");
  }
  for (size_t j = 1; j < return_new_inputs.size(); j++) {
    if (utils::isa<CNodePtr>(return_new_inputs[j])) {
      return_new_inputs[j]->cast<CNodePtr>()->set_fullname_with_scope(body_graph_name + "_output_" +
                                                                      std::to_string(j - 1) + "_cnode");
    } else if (utils::isa<ParameterPtr>(return_new_inputs[j])) {
      return_new_inputs[j]->cast<ParameterPtr>()->set_name(body_graph_name + "_output_" + std::to_string(j - 1) +
                                                           "_parameter");
    }
  }
  // The cond graph takes one input per body output (minus the MakeTuple prim).
  *cond_graph_input_num = return_new_inputs.size() - 1;
  loop_body_graph->set_attr("graph_name", MakeValue(body_graph_name));
  return loop_body_graph;
}
616 
617 namespace {
CheckOnnxModel(const onnx::GraphProto & onnx_graph)618 STATUS CheckOnnxModel(const onnx::GraphProto &onnx_graph) {
619   // all input should in initialize
620   std::set<std::string> providers;
621   for (const auto &const_tensor : onnx_graph.initializer()) {
622     const auto &name = const_tensor.name();
623     if (providers.count(name) != 0) {
624       MS_LOG(ERROR) << "const tensor repeated";
625       return RET_ERROR;
626     }
627     providers.insert(name);
628   }
629   for (int i = 0; i < onnx_graph.input().size(); ++i) {
630     providers.insert(onnx_graph.input(i).name());
631   }
632   for (const auto &onnx_node : onnx_graph.node()) {
633     for (int i = 0; i < onnx_node.output_size(); i++) {
634       auto &output = onnx_node.output(i);
635       if (providers.count(output) != 0) {
636         MS_LOG(ERROR) << "Output tensor repeated";
637         return RET_ERROR;
638       }
639       providers.insert(output);
640     }
641   }
642   // all output should find
643   for (const auto &onnx_node : onnx_graph.node()) {
644     for (int i = 0; i < onnx_node.input_size(); i++) {
645       auto &input = onnx_node.input(i);
646       if (providers.count(input) == 0) {
647         MS_LOG(WARNING) << "Cannot find input: " << input << " of node: " << onnx_node.name();
648       }
649     }
650   }
651   return RET_OK;
652 }
653 }  // namespace
654 
api::FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag) {
  // Top-level entry: reads the .onnx file, converts the root graph (and any
  // control-flow subgraphs) into ANF form, then runs common/ONNX adjust passes
  // and NHWC format unification. Returns nullptr on failure; the concrete
  // status code is recorded in ReturnCode for the caller.
  auto model_file = flag.model_file;
  NotSupportOp::GetInstance()->set_fmk_type("ONNX");
  auto graph = std::make_shared<FuncGraph>();
  MS_CHECK_TRUE_MSG(graph != nullptr, nullptr, "create FuncGraph failed");
  res_graph_ = api::MakeShared<api::FuncGraph>(graph);
  auto status = InitOriginModel(model_file);
  if (RET_OK != status) {
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
    MS_LOG(ERROR) << "init origin model failed.";
    return nullptr;
  }
  MS_ASSERT(onnx_root_graph_ != nullptr);

  // The root graph needs no extra-subgraph-inputs vector, hence the null
  // pointer ({}); subgraph conversions pass a real vector instead.
  status = ConvertOnnxGraph(onnx_root_graph_, graph, &anf_nodes_map_, {}, "root_node");
  if (RET_OK != status) {
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
    MS_LOG(ERROR) << "convert onnx graph failed.";
    return nullptr;
  }
  // NOTE(review): this static local is initialized on the first Parse() call
  // only; a second Parse() in the same process would attach its subgraphs to
  // the first graph's manager — confirm the parser object is single-use.
  static auto root_func_manager = Manage(graph);

  // All control-flow subgraphs share the root manager and are tagged with the
  // ONNX framework type.
  for (auto &subgraph : all_subgraphs_) {
    MS_CHECK_TRUE_RET(subgraph != nullptr, nullptr);
    subgraph->set_manager(root_func_manager);
    subgraph->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeOnnx)));
  }
  graph->set_attr("graph_name", MakeValue("main_graph"));
  graph->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeOnnx)));
  if ((status = CommonAnfAdjust(graph)) != RET_OK) {
    MS_LOG(ERROR) << "AdjustForAnf failed.";
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
    return nullptr;
  }
  // ONNX-specific adjust passes run over the root graph and all subgraphs.
  std::set<FuncGraphPtr> all_func_graphs = {};
  GetAllFuncGraph(graph, &all_func_graphs);
  if ((status = Onnx2AnfAdjust(all_func_graphs, flag)) != RET_OK) {
    MS_LOG(ERROR) << "Onnx2AnfAdjust failed.";
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
    return nullptr;
  }
  // Insert transposes so the graph follows the NHWC convention expected by
  // the save format.
  auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeOnnx, false, flag.save_type);
  MS_CHECK_TRUE_MSG(unify_format != nullptr, nullptr, "create unify_format return nullptr");
  if (!unify_format->Run(graph)) {
    MS_LOG(ERROR) << "Run insert transpose failed.";
    return nullptr;
  }
  return res_graph_;
}
704 
InitOriginModel(const std::string & model_file)705 STATUS OnnxModelParser::InitOriginModel(const std::string &model_file) {
706   MS_CHECK_TRUE_RET(res_graph_ != nullptr, RET_NULL_PTR);
707   auto res_graph = ConvertGraph(res_graph_);
708   auto status = ValidateFileStr(model_file, ".onnx");
709   if (status != RET_OK) {
710     MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.onnx";
711     return status;
712   }
713   model_file_ = model_file;
714   status = ReadProtoFromBinaryFile(model_file, &onnx_model_);
715   if (status != RET_OK) {
716     MS_LOG(ERROR) << "Read onnx model file failed, model path: " << model_file;
717     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
718     return status;
719   }
720   OnnxNodeParser::set_opset_version(onnx_model_.opset_import().Get(0).version());
721   OnnxNodeParser::SetOnnxModelFile(model_file);
722   onnx_root_graph_ = onnx_model_.graph();
723   auto fmk_value_node = MakeValue(static_cast<int>(converter::kFmkTypeOnnx));
724   CHECK_NULL_RETURN(fmk_value_node);
725   res_graph->set_attr("fmk", fmk_value_node);
726   return RET_OK;
727 }
728 
ConvertOnnxGraph(const onnx::GraphProto & onnx_graph,const FuncGraphPtr & anf_graph,std::unordered_map<std::string,AnfNodePtr> * anf_nodes_map,std::vector<AnfNodePtr> * extra_subgraph_inputs,const std::string & root_node_name)729 STATUS OnnxModelParser::ConvertOnnxGraph(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &anf_graph,
730                                          std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
731                                          std::vector<AnfNodePtr> *extra_subgraph_inputs,
732                                          const std::string &root_node_name) {
733   MS_ASSERT(anf_graph != nullptr && anf_nodes_map != nullptr && extra_subgraph_inputs != nullptr);
734   STATUS status = RET_OK;
735   status = CheckOnnxModel(onnx_graph);
736   if (status != RET_OK) {
737     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
738     MS_LOG(ERROR) << "input onnx model error: " << status;
739     return status;
740   }
741   status = ConvertConstTensors(onnx_graph, anf_graph, anf_nodes_map, model_file_);
742   if (RET_OK != status) {
743     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
744     MS_LOG(ERROR) << "convert const nodes failed.";
745     return RET_ERROR;
746   }
747 
748   status = ConvertGraphInputs(onnx_graph, anf_graph, anf_nodes_map);
749   if (RET_OK != status) {
750     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
751     MS_LOG(ERROR) << "convert graph inputs failed.";
752     return RET_OK;
753   }
754 
755   status = ConvertNodes(onnx_graph, anf_graph, anf_nodes_map, extra_subgraph_inputs, root_node_name);
756   if (RET_OK != status) {
757     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
758     MS_LOG(ERROR) << "convert nodes failed.";
759     return RET_ERROR;
760   }
761 
762   status = ConvertGraphOutputs(onnx_graph, anf_graph, *anf_nodes_map);
763   if (RET_OK != status) {
764     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
765     MS_LOG(ERROR) << "convert graph outputs failed.";
766     return RET_ERROR;
767   }
768   // save original output tensor names.
769   if (root_node_name == "root_node") {
770     std::vector<std::string> output_names;
771     std::transform(onnx_graph.output().begin(), onnx_graph.output().end(), std::back_inserter(output_names),
772                    [](auto &graph_output) { return graph_output.name(); });
773     ConverterInnerContext::GetInstance()->SetGraphOutputTensorNames(output_names);
774   }
775   return status;
776 }
777 
std::vector<int> OnnxModelParser::SortOnnxNodeIndex(const onnx::GraphProto &onnx_graph) {
  // Returns node indices in a topologically-usable order: a node is emitted
  // only once all of its non-empty inputs have a known producer. If the graph
  // contains control flow ("If"/"Loop") anywhere, sorting is skipped and the
  // original proto order is returned unchanged. Returns {} on failure.
  std::vector<int> sorted_node_index;
  std::queue<int> onnx_nodes_queue;
  std::set<std::string> node_names;  // names of all values produced so far
  // for const tensor
  for (const auto &const_tensor : onnx_graph.initializer()) {
    const auto &name = const_tensor.name();
    node_names.insert(name);
  }
  // for graph input
  for (int i = 0; i < onnx_graph.input().size(); i++) {
    node_names.insert(onnx_graph.input(i).name());
  }
  for (int i = 0; i < onnx_graph.node().size(); i++) {
    auto onnx_node = onnx_graph.node(i);
    if (onnx_node.op_type() == "If" || onnx_node.op_type() == "Loop" || has_subgraph_) {
      // Control-flow ops reference subgraphs whose cross-graph inputs this
      // simple sort cannot track: fall back to the original proto order.
      // has_subgraph_ stays sticky so later (sub)graphs also skip sorting.
      sorted_node_index.clear();
      has_subgraph_ = true;
      for (int index = 0; index < onnx_graph.node().size(); index++) {
        sorted_node_index.push_back(index);
      }
      return sorted_node_index;
    }
    if (onnx_node.op_type() == "Constant") {
      // Constants have no inputs: emit immediately and record their outputs.
      sorted_node_index.push_back(i);
      for (auto output_name : onnx_node.output()) {
        node_names.insert(output_name);
      }
    } else {
      onnx_nodes_queue.push(i);
    }
  }
  // Round-robin over the remaining nodes: pop a node; if any input is still
  // unproduced, defer it to the back of the queue. 'pre_node_index' marks the
  // first deferred node of the current pass; reaching that marker again with
  // no node emitted since ('find' still false) means the remaining nodes form
  // a cycle or reference names that are never produced -> fail.
  bool find = false;
  int pre_node_index = -1;
  while (!onnx_nodes_queue.empty()) {
    auto node_index = onnx_nodes_queue.front();
    auto onnx_node = onnx_graph.node(node_index);
    if (std::any_of(onnx_node.input().begin(), onnx_node.input().end(),
                    [&](const string &name) { return node_names.count(name) == 0 && !name.empty(); })) {
      onnx_nodes_queue.pop();
      onnx_nodes_queue.push(node_index);
      if (!find && pre_node_index == node_index) {
        MS_LOG(ERROR) << "sort onnx node failed.";
        return {};
      }
      find = false;
      pre_node_index = pre_node_index == -1 ? node_index : pre_node_index;
    } else {
      // All inputs resolved: emit the node and publish its outputs.
      find = true;
      pre_node_index = pre_node_index == node_index ? -1 : pre_node_index;  // marker node resolved — clear marker
      sorted_node_index.push_back(node_index);
      onnx_nodes_queue.pop();
      for (int i = 0; i < onnx_node.output_size(); i++) {
        node_names.insert(onnx_node.output(i));
      }
    }
  }
  return sorted_node_index;
}
837 
ConvertNodes(const onnx::GraphProto & onnx_graph,const FuncGraphPtr & anf_graph,std::unordered_map<std::string,AnfNodePtr> * anf_nodes_map,std::vector<AnfNodePtr> * graph_inputs,const std::string & root_node_name)838 STATUS OnnxModelParser::ConvertNodes(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &anf_graph,
839                                      std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
840                                      std::vector<AnfNodePtr> *graph_inputs, const std::string &root_node_name) {
841   CHECK_NULL_RETURN(anf_graph);
842   CHECK_NULL_RETURN(anf_nodes_map);
843   auto sorted_node_index = SortOnnxNodeIndex(onnx_graph);
844   if (sorted_node_index.empty()) {
845     MS_LOG(ERROR) << "SortOnnxNodeIndex failed.";
846     return RET_ERROR;
847   }
848   STATUS status = RET_OK;
849   for (auto node_index : sorted_node_index) {
850     const auto &onnx_node = onnx_graph.node(node_index);
851     ops::PrimitiveCPtr primitive_c;
852     auto node_parser = registry::NodeParserRegistry::GetNodeParser(kFmkTypeOnnx, onnx_node.op_type());
853     if (node_parser != nullptr) {
854       primitive_c = node_parser->Parse(onnx_graph, onnx_node)->GetPrim();
855     } else {
856       auto node_parser_builtin = OnnxNodeParserRegistry::GetInstance().GetNodeParser(onnx_node.op_type());
857       if (node_parser_builtin == nullptr) {
858         NotSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
859         status = status == RET_OK ? RET_NOT_FIND_OP : status;
860         MS_LOG(ERROR) << "not support onnx data type " << onnx_node.op_type();
861       }
862       if (status != RET_OK) {
863         continue;
864       }
865       MS_LOG(INFO) << "parse op:" << onnx_node.op_type();
866       primitive_c = node_parser_builtin->Parse(onnx_graph, onnx_node);
867     }
868 
869     if (primitive_c == nullptr) {
870       MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed.";
871       status = RET_ERROR;
872       continue;
873     }
874     if (primitive_c->GetAttr(ops::kOriginalFormat) == nullptr) {
875       primitive_c->AddAttr(mindspore::ops::kOriginalFormat, MakeValue<int64_t>(NCHW));
876     }
877     status = ConvertOpQuantParams(onnx_node, primitive_c);
878     if (status != RET_OK) {
879       MS_LOG(ERROR) << "convert " << onnx_node.op_type() << " quant param failed.";
880       continue;
881     }
882     // build CNode
883     status = BuildCNode(onnx_node, anf_graph, anf_nodes_map, graph_inputs, primitive_c, root_node_name);
884     if (status != RET_OK) {
885       MS_LOG(ERROR) << "build cnode " << onnx_node.op_type() << " failed.";
886     }
887 
888     if (onnx_node.op_type() == "Loop") {
889       child_root_map_[onnx_node.name()] = root_node_name;
890       control_nodes_map_[onnx_node.name()] = anf_nodes_map;
891 
892       status = ConvertLoopOnnxNode(onnx_node, anf_nodes_map, root_node_name);
893       if (status != RET_OK) {
894         MS_LOG(ERROR) << "build loop node  failed.";
895       }
896     }
897     if (onnx_node.op_type() == "If") {
898       child_root_map_[onnx_node.name()] = root_node_name;
899       control_nodes_map_[onnx_node.name()] = anf_nodes_map;
900 
901       status = ConvertIfOnnxNode(onnx_node, anf_nodes_map, root_node_name);
902       if (status != RET_OK) {
903         MS_LOG(ERROR) << "build if node  failed.";
904       }
905     }
906   }
907   return status;
908 }
909 
ConvertIfSubgraph(const onnx::GraphProto & subgraph_proto,const FuncGraphPtr & subgraph,const std::string & subgraph_name,const std::string & if_node_name,const std::string & root_node_name)910 STATUS OnnxModelParser::ConvertIfSubgraph(const onnx::GraphProto &subgraph_proto, const FuncGraphPtr &subgraph,
911                                           const std::string &subgraph_name, const std::string &if_node_name,
912                                           const std::string &root_node_name) {
913   MS_CHECK_TRUE_RET(subgraph != nullptr, RET_NULL_PTR);
914   std::unordered_map<std::string, AnfNodePtr> anf_nodes_map;
915   std::vector<AnfNodePtr> subgraph_extra_inputs;
916   auto status = ConvertOnnxGraph(subgraph_proto, subgraph, &anf_nodes_map, &subgraph_extra_inputs, if_node_name);
917   if (status != RET_OK) {
918     MS_LOG(ERROR) << "convert loop OnnxGraph failed";
919     return status;
920   }
921   subgraph->set_attr("graph_name", MakeValue(subgraph_name));
922   // update subgraph in out name
923   for (int j = 0; j < subgraph_proto.input_size(); j++) {
924     auto input_anode_iter = anf_nodes_map.find(subgraph_proto.input(j).name());
925     if (input_anode_iter == anf_nodes_map.end()) {
926       MS_LOG(ERROR) << "cannot find input anode";
927       return RET_ERROR;
928     }
929     auto input_parameter = input_anode_iter->second->cast<ParameterPtr>();
930     MS_CHECK_TRUE_MSG(input_parameter != nullptr, RET_ERROR, "subgraph input should be a parameter");
931     input_parameter->set_name(subgraph_name + "_input_" + std::to_string(j) + "_parameter");
932   }
933   for (size_t j = 0; j < subgraph_extra_inputs.size(); j++) {
934     auto input_parameter = subgraph_extra_inputs[j]->cast<ParameterPtr>();
935     MS_CHECK_TRUE_MSG(input_parameter != nullptr, RET_ERROR, "subgraph input should be a parameter");
936     input_parameter->set_name(subgraph_name + "_input_" + std::to_string(j + subgraph_proto.input_size()) +
937                               "_parameter");
938   }
939   auto return_node = subgraph->get_return();
940   MS_CHECK_TRUE_MSG(return_node != nullptr, RET_ERROR, "subgraph has no return");
941   MS_CHECK_GE(return_node->size(), kInputSize1, RET_ERROR);
942   std::vector<AnfNodePtr> return_act_inputs;
943   int start_index = 0;
944   if (subgraph_proto.output_size() > 1) {
945     auto return_cnode = return_node->input(1)->cast<CNodePtr>();
946     MS_CHECK_TRUE_RET(return_cnode != nullptr, RET_NULL_PTR);
947     return_act_inputs = return_cnode->inputs();
948     start_index = 1;
949   } else {
950     return_act_inputs = {return_node->input(1)};
951   }
952   for (size_t j = start_index; j < return_act_inputs.size(); j++) {
953     if (utils::isa<CNodePtr>(return_act_inputs[j])) {
954       return_act_inputs[j]->cast<CNodePtr>()->set_fullname_with_scope(subgraph_name + "_output_" +
955                                                                       std::to_string(j - start_index) + "_cnode");
956     } else if (utils::isa<ParameterPtr>(return_act_inputs[j])) {
957       return_act_inputs[j]->cast<ParameterPtr>()->set_name(subgraph_name + "_output_" +
958                                                            std::to_string(j - start_index) + "_parameter");
959     }
960   }
961   return RET_OK;
962 }
963 
ConvertIfOnnxNode(const onnx::NodeProto & onnx_node,std::unordered_map<std::string,AnfNodePtr> * anf_root_nodes_map,const std::string & root_node_name)964 STATUS OnnxModelParser::ConvertIfOnnxNode(const onnx::NodeProto &onnx_node,
965                                           std::unordered_map<std::string, AnfNodePtr> *anf_root_nodes_map,
966                                           const std::string &root_node_name) {
967   CHECK_NULL_RETURN(anf_root_nodes_map);
968   FuncGraphPtr then_branch_graph = nullptr;
969   FuncGraphPtr else_branch_graph = nullptr;
970   FuncGraphPtr subgraph = nullptr;
971   std::string subgraph_name;
972   auto &if_node_name = onnx_node.name();
973 
974   for (int i = 0; i < onnx_node.attribute_size(); i++) {
975     auto &attr = onnx_node.attribute(i);
976     auto &subgraph_proto = attr.g();
977     if (attr.name().find("then_branch") != std::string::npos) {
978       subgraph_name = if_node_name + "_then_branch";
979       then_branch_graph = std::make_shared<FuncGraph>();
980       MS_CHECK_TRUE_MSG(then_branch_graph != nullptr, RET_NULL_PTR, "create then_branch_graph return nullptr");
981       auto status = ConvertIfSubgraph(subgraph_proto, then_branch_graph, subgraph_name, if_node_name, root_node_name);
982       if (status != RET_OK) {
983         MS_LOG(ERROR) << "build if node else branch failed.";
984       }
985     } else if (attr.name().find("else_branch") != std::string::npos) {
986       subgraph_name = if_node_name + "_else_branch";
987       else_branch_graph = std::make_shared<FuncGraph>();
988       MS_CHECK_TRUE_MSG(else_branch_graph != nullptr, RET_NULL_PTR, "create else_branch_graph return nullptr");
989       auto status = ConvertIfSubgraph(subgraph_proto, else_branch_graph, subgraph_name, if_node_name, root_node_name);
990       if (status != RET_OK) {
991         MS_LOG(ERROR) << "build if node else branch failed.";
992       }
993     } else {
994       continue;
995     }
996   }
997   all_subgraphs_.emplace_back(then_branch_graph);
998   all_subgraphs_.emplace_back(else_branch_graph);
999   auto then_value_node = NewValueNode(then_branch_graph);
1000   MS_CHECK_TRUE_MSG(then_value_node != nullptr, RET_NULL_PTR, "create then_value_node return nullptr");
1001   auto else_value_node = NewValueNode(else_branch_graph);
1002   MS_CHECK_TRUE_MSG(else_value_node != nullptr, RET_NULL_PTR, "create else_value_node return nullptr");
1003   auto root_if_node = GetCNodeFromControlFlowNodesMap(if_node_name, control_nodes_map_);
1004   MS_CHECK_TRUE_MSG(root_if_node != nullptr, RET_ERROR, "cannot find root_if_node is control_nodes_map");
1005   auto if_new_inputs = root_if_node->inputs();
1006   if_new_inputs.insert(if_new_inputs.begin() + 1, {then_value_node, else_value_node});
1007 
1008   std::vector<AnfNodePtr> if_new_input_not_same{};
1009   std::set<AnfNodePtr> if_set{};
1010   for (auto &input : if_new_inputs) {
1011     if (if_set.find(input) != if_set.end()) {
1012       continue;
1013     }
1014     if_new_input_not_same.push_back(input);
1015     if_set.insert(input);
1016   }
1017 
1018   root_if_node->set_inputs(if_new_input_not_same);
1019   return RET_OK;
1020 }
1021 
STATUS OnnxModelParser::BuildCNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &anf_graph,
                                   std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
                                   std::vector<AnfNodePtr> *graph_inputs, PrimitiveCPtr primitive_c,
                                   std::string loop_name) {
  // Creates a CNode for 'onnx_node' in 'anf_graph'. Inputs are resolved from
  // 'anf_nodes_map'; a name not found locally is assumed to live in an
  // enclosing (ancestor) graph, so a fresh parameter is added to this subgraph
  // and the reference is threaded outward through the chain of enclosing
  // control-flow nodes (via child_root_map_ / control_nodes_map_).
  // 'loop_name' names the innermost enclosing control-flow node ("root_node"
  // at top level); it is taken by value because it is rewritten while walking
  // up the ancestor chain.
  CHECK_NULL_RETURN(anf_graph);
  CHECK_NULL_RETURN(anf_nodes_map);
  CHECK_NULL_RETURN(primitive_c);
  std::vector<AnfNodePtr> op_inputs;
  for (int i = 0; i < onnx_node.input_size(); i++) {
    auto input_name = onnx_node.input(i);
    if (input_name.empty()) {
      // Optional ONNX inputs are empty strings: record the hole's index on the
      // primitive instead of adding an input.
      std::string empty_input_index = "empty_input_index";
      primitive_c->AddAttr(empty_input_index, MakeValue<int>(i));
      continue;
    }

    if (anf_nodes_map->find(input_name) != anf_nodes_map->end()) {
      op_inputs.push_back(anf_nodes_map->at(input_name));
    } else {
      // subgraph may refer root graph nodes
      // Enclosing control nodes passed on the way out that must each gain a
      // pass-through parameter/input pair once the producer is found.
      std::vector<CNodePtr> need_add_input_nodes;
      auto ext_subgraph_input = anf_graph->add_parameter();
      MS_CHECK_TRUE_MSG(ext_subgraph_input != nullptr, RET_NULL_PTR, "create parameter return nullptr");
      ParameterPtr inner_extra_paramter = nullptr;  // NOTE(review): never used — candidate for removal
      // Walk outward one enclosing graph at a time until the producer of
      // 'input_name' is found or the root is reached.
      while (!loop_name.empty() && child_root_map_.find(loop_name) != child_root_map_.end()) {
        auto cur_node_map = control_nodes_map_[loop_name];
        CHECK_NULL_RETURN(cur_node_map);
        if (cur_node_map->find(input_name) != cur_node_map->end()) {
          auto outside_input_node = cur_node_map->at(input_name);
          CHECK_NULL_RETURN(outside_input_node);
          // copy outside input parameter value to inside subgraph
          ext_subgraph_input->set_abstract(outside_input_node->abstract());
          ext_subgraph_input->set_name(input_name);
          if (outside_input_node->isa<Parameter>()) {
            // Constant parameter outside: clone its tensor so the subgraph
            // owns an independent copy of the data.
            auto parameter = outside_input_node->cast<ParameterPtr>();
            if (!parameter->has_default()) {
              MS_LOG(ERROR) << "outside_input_node should has data.";
              return RET_ERROR;
            }
            auto tensor_info = parameter->default_param()->cast<tensor::TensorPtr>();
            auto copy_tensor_info = CreateTensorInfo(tensor_info->data_c(), tensor_info->Size(), tensor_info->shape(),
                                                     tensor_info->data_type());
            if (copy_tensor_info == nullptr) {
              MS_LOG(ERROR) << "memcpy failed.";
              return RET_ERROR;
            }
            ext_subgraph_input->set_default_param(copy_tensor_info);
          } else {
            // output inside cnode need make extra input
            CHECK_NULL_RETURN(graph_inputs);
            graph_inputs->emplace_back(ext_subgraph_input);
            // The producing graph's control node receives the outside value
            // as an additional input.
            if (cur_node_map->find(loop_name) != cur_node_map->end()) {
              CHECK_NULL_RETURN(cur_node_map->at(loop_name));
              auto control_node = cur_node_map->at(loop_name)->cast<CNodePtr>();
              MS_CHECK_TRUE_RET(control_node != nullptr, RET_NULL_PTR);
              control_node->add_input(outside_input_node);
            } else {
              MS_LOG(ERROR) << "loop node: " << loop_name << " not found in cur node map.";
              return RET_ERROR;
            }
            // Intermediate control nodes between the producer and this
            // subgraph each forward the value via a new parameter.
            for (auto &control_node : need_add_input_nodes) {
              CHECK_NULL_RETURN(control_node);
              auto func_graph = control_node->func_graph();
              auto extra_input_parameter = func_graph->add_parameter();
              MS_CHECK_TRUE_MSG(extra_input_parameter != nullptr, RET_NULL_PTR, "create parameter return nullptr");
              extra_input_parameter->set_name(input_name);
              extra_input_parameter->set_abstract(outside_input_node->abstract());
              control_node->add_input(extra_input_parameter);
            }
          }
          // Cache the new parameter so later references resolve locally.
          op_inputs.push_back(ext_subgraph_input);
          anf_nodes_map->emplace(input_name, ext_subgraph_input);
          break;
        } else {
          // Not found at this level: remember this level's control node and
          // continue with its parent graph.
          if (cur_node_map->find(loop_name) != cur_node_map->end()) {
            CHECK_NULL_RETURN(cur_node_map->at(loop_name));
            need_add_input_nodes.emplace_back(cur_node_map->at(loop_name)->cast<CNodePtr>());
          } else {
            MS_LOG(ERROR) << "loop node: " << loop_name << " not found in cur node map.";
            return RET_ERROR;
          }
          loop_name = child_root_map_[loop_name];
        }
      }
    }
  }
  auto new_cnode = anf_graph->NewCNode(primitive_c, op_inputs);
  if (new_cnode == nullptr) {
    MS_LOG(ERROR) << "new cnode error";
    return RET_ERROR;
  }
  new_cnode->set_fullname_with_scope(onnx_node.name());
  auto status = BuildOpOutputs(onnx_node, anf_graph, anf_nodes_map, new_cnode);
  return status;
}
1117 
STATUS OnnxModelParser::ConvertOpQuantParams(const onnx::NodeProto &onnx_node, ops::PrimitiveCPtr primitive_c) {
  // Gathers quantization parameters for the node's input/output tensors and,
  // when any are found, attaches them to the primitive as a QuantParamHolder
  // under the "quant_params" attr. Sources: companion scale/zero_point
  // parameter nodes (via ParseQuantParam + SetTensorQuantParam) and node-level
  // scale/offset/quant_bit attributes.
  CHECK_NULL_RETURN(primitive_c);
  auto status = ParseQuantParam(onnx_node);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "parse quant param failed.";
    return RET_ERROR;
  }
  // set input tensors
  std::map<int, std::vector<schema::QuantParamT>> input_quant_params;
  size_t idx = 0;
  for (int i = 0; i < onnx_node.input_size(); ++i) {
    const auto &input_name = onnx_node.input(i);
    std::vector<schema::QuantParamT> quant_params;
    status = SetTensorQuantParam(input_name, &quant_params);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "set input tensor quant param failed.";
      return status;
    }
    if (!quant_params.empty()) {
      // NOTE(review): 'idx' advances only for quantized tensors, so the key is
      // the compressed position rather than the ONNX input index 'i' — confirm
      // downstream consumers expect this.
      input_quant_params.insert({idx, quant_params});
      idx++;
    }
  }
  // set out tensors
  idx = 0;
  std::map<int, std::vector<schema::QuantParamT>> output_quant_params;
  for (int i = 0; i < onnx_node.output_size(); ++i) {
    const auto &output_name = onnx_node.output(i);
    std::vector<schema::QuantParamT> quant_params;
    status = SetTensorQuantParam(output_name, &quant_params);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "set output tensor quant param failed.";
      return status;
    }
    if (!quant_params.empty()) {
      output_quant_params.insert({idx, quant_params});
      idx++;
    }
  }
  // Node-level attributes (scale / offset / quant_bit) describe a single
  // quant param that applies to input 0 and output 0.
  schema::QuantParamT quant_param;
  bool has_quant = false;
  for (const auto &onnx_node_attr : onnx_node.attribute()) {
    if (onnx_node_attr.name() == "scale") {
      float scale = onnx_node_attr.f();
      quant_param.scale = scale;
      MS_LOG(INFO) << onnx_node.name() << " scale is " << quant_param.scale;
      has_quant = true;
    } else if (onnx_node_attr.name() == "offset") {
      float offset = onnx_node_attr.f();
      quant_param.zeroPoint = static_cast<int>(offset);
      MS_LOG(INFO) << onnx_node.name() << " offset is " << quant_param.zeroPoint;
      has_quant = true;
    } else if (onnx_node_attr.name() == "quant_bit") {
      int64_t quant_bit = onnx_node_attr.i();
      quant_param.numBits = static_cast<int>(quant_bit);
      MS_LOG(INFO) << onnx_node.name() << " quant_bit is " << quant_param.numBits;
      has_quant = true;
    }
  }
  if (has_quant) {
    std::vector<schema::QuantParamT> quant_params;
    quant_param.inited = true;
    quant_params.push_back(quant_param);
    // map::insert keeps any params already recorded for index 0.
    input_quant_params.insert({0, quant_params});
    output_quant_params.insert({0, quant_params});
  }
  if (!input_quant_params.empty() || !output_quant_params.empty()) {
    auto quant_params_holder = std::make_shared<QuantParamHolder>(0, 0);
    CHECK_NULL_RETURN(quant_params_holder);
    for (auto &iter : input_quant_params) {
      quant_params_holder->set_input_quant_param(iter.first, iter.second);
    }
    for (auto &iter : output_quant_params) {
      quant_params_holder->set_output_quant_param(iter.first, iter.second);
    }
    primitive_c->AddAttr("quant_params", quant_params_holder);
  }
  return RET_OK;
}
1197 
ParseQuantParam(const onnx::NodeProto & onnx_node)1198 STATUS OnnxModelParser::ParseQuantParam(const onnx::NodeProto &onnx_node) {
1199   for (const auto &onnx_node_attr : onnx_node.attribute()) {
1200     if (onnx_node_attr.name() == "Y_scale") {
1201       float scale = onnx_node_attr.f();
1202       if (BuildParameterNodeForQuantParam(&scale, "scale_" + onnx_node.output(0), kNumberTypeFloat32) != RET_OK) {
1203         MS_LOG(ERROR) << "parse quant param failed.";
1204         return RET_ERROR;
1205       }
1206     } else if (onnx_node_attr.name() == "Y_zero_point") {
1207       int64_t zero_point = onnx_node_attr.i();
1208       if (BuildParameterNodeForQuantParam(&zero_point, "zero_point_" + onnx_node.output(0), kNumberTypeInt64) !=
1209           RET_OK) {
1210         MS_LOG(ERROR) << "parse quant param failed.";
1211         return RET_ERROR;
1212       }
1213     }
1214   }
1215   return RET_OK;
1216 }
1217 
STATUS OnnxModelParser::SetTensorQuantParam(const std::string &tensor_name, std::vector<QuantParamT> *quant_params) {
  // Fills 'quant_params' for 'tensor_name' from the root graph's
  // quantization_annotation entries: each matching annotation maps keys
  // "SCALE_TENSOR" / "ZERO_POINT_TENSOR" to the names of parameter nodes
  // holding the values. Falls back to SetTensorQuantParamFromNode when no
  // annotation initialized anything.
  MS_CHECK_TRUE_RET(quant_params != nullptr, RET_NULL_PTR);
  quant_params->clear();
  auto quant_param = std::make_unique<QuantParamT>();
  MS_CHECK_TRUE_MSG(quant_param != nullptr, RET_NULL_PTR, "create QuantParamT return nullptr");
  for (int i = 0; i < onnx_root_graph_.quantization_annotation_size(); ++i) {
    auto tensor_annotation = onnx_root_graph_.quantization_annotation(i);
    if (!tensor_annotation.has_tensor_name() || tensor_annotation.tensor_name() != tensor_name) {
      continue;
    }
    for (const auto &item : tensor_annotation.quant_parameter_tensor_names()) {
      if (!item.has_key() || !item.has_value()) {
        continue;
      }

      // 'value' names the parameter node that stores the scalar.
      const auto &quant_tensor_name = item.value();
      if (item.key() == "SCALE_TENSOR") {
        auto status = CopyTensorQuantParam(quant_tensor_name, quant_param.get(), true);
        if (status != RET_OK) {
          MS_LOG(ERROR) << "quant param scale get failed";
          return status;
        }
      } else if (item.key() == "ZERO_POINT_TENSOR") {
        auto status = CopyTensorQuantParam(quant_tensor_name, quant_param.get(), false);
        if (status != RET_OK) {
          MS_LOG(ERROR) << "quant param zero_point get failed";
          return status;
        }
      }
    }
    // Only the first annotation matching this tensor is consulted.
    break;
  }
  if (quant_param->inited) {
    quant_params->push_back(*std::move(quant_param));
    return RET_OK;
  }
  return SetTensorQuantParamFromNode(tensor_name, quant_params);
}
1256 
SetTensorQuantParamFromNode(const std::string & tensor_name,std::vector<QuantParamT> * quant_params)1257 STATUS OnnxModelParser::SetTensorQuantParamFromNode(const std::string &tensor_name,
1258                                                     std::vector<QuantParamT> *quant_params) {
1259   MS_CHECK_TRUE_RET(quant_params != nullptr, RET_NULL_PTR);
1260   quant_params->clear();
1261   auto quant_param = std::make_unique<QuantParamT>();
1262   MS_CHECK_TRUE_MSG(quant_param != nullptr, RET_NULL_PTR, "create QuantParamT return nullptr");
1263   if (OnnxNodeParser::opset_version() <= 15) {
1264     quant_param->multiplier = 0;
1265   }
1266   std::string quant_tensor_name = "scale_" + tensor_name;
1267   auto status = CopyTensorQuantParam(quant_tensor_name, quant_param.get(), true);
1268   if (status != RET_OK) {
1269     MS_LOG(ERROR) << "quant param scale get failed";
1270     return status;
1271   }
1272   quant_tensor_name = "zero_point_" + tensor_name;
1273   status = CopyTensorQuantParam(quant_tensor_name, quant_param.get(), false);
1274   if (status != RET_OK) {
1275     MS_LOG(ERROR) << "quant param zero_point get failed";
1276     return status;
1277   }
1278   if (quant_param->inited) {
1279     quant_params->push_back(*std::move(quant_param));
1280   }
1281   return RET_OK;
1282 }
1283 
STATUS OnnxModelParser::CopyTensorQuantParam(const std::string &tensor_name, QuantParamT *quant_param,
                                             bool scale_or_not) {
  // Reads a scalar from the parameter node registered under 'tensor_name' and
  // stores it into 'quant_param': scale (float) when 'scale_or_not' is true,
  // zero point (int64) otherwise. A missing node is not an error — it simply
  // means the tensor has no quant param.
  CHECK_NULL_RETURN(quant_param);
  auto iter = anf_nodes_map_.find(tensor_name);
  if (iter == anf_nodes_map_.end()) {
    MS_LOG(DEBUG) << "has no quant param";
    return RET_OK;
  }
  if (!utils::isa<ParameterPtr>(iter->second)) {
    MS_LOG(ERROR) << "quant param get failed";
    return RET_ERROR;
  }
  auto quant_parameter_node = iter->second->cast<ParameterPtr>();
  MS_CHECK_TRUE_RET(quant_parameter_node != nullptr, RET_NULL_PTR);
  if (!quant_parameter_node->has_default()) {
    MS_LOG(ERROR) << "quant param get failed";
    return RET_ERROR;
  }
  auto tensor_info = quant_parameter_node->default_param()->cast<tensor::TensorPtr>();
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "parameterNode's default param is not tensor::TensorPtr";
    return RET_ERROR;
  }
  if (tensor_info->data_c() == nullptr) {
    MS_LOG(ERROR) << "parameterNode's default param has no data";
    return RET_ERROR;
  }
  if (scale_or_not) {
    // Scale tensors are built as float32 by ParseQuantParam.
    quant_param->scale = *reinterpret_cast<float *>(tensor_info->data_c());
    quant_param->inited = true;
  } else {
    // Zero-point tensors are built as int64 by ParseQuantParam; assumes any
    // other producer uses the same dtype — no dtype check is done here
    // (NOTE(review): confirm for annotation-provided tensors).
    quant_param->zeroPoint = *reinterpret_cast<int64_t *>(tensor_info->data_c());
    quant_param->inited = true;
  }
  return RET_OK;
}
1320 
// Appends a TensorListStack node after each scan-output of an onnx Loop so the
// tensor list accumulated by the while body is materialized as a stacked tensor.
// `act_outputs_num` is the number of scan outputs (the trailing outputs of the
// onnx node); `body_output_size` is the total output count of the body graph.
STATUS OnnxModelParser::AddTensorListStackNode(const AnfNodePtr &root_while_node, const onnx::NodeProto &onnx_node,
                                               int act_outputs_num, int body_output_size) {
  MS_CHECK_TRUE_RET(root_while_node != nullptr, RET_NULL_PTR);
  auto &loop_node_name = onnx_node.name();
  auto root_anf_graph = root_while_node->func_graph();
  // Shared element-shape input (-1) for every stack node created below.
  auto stack_elem_node = CreateConstParamter(root_anf_graph, -1);
  MS_CHECK_TRUE_MSG(stack_elem_node != nullptr, RET_NULL_PTR, "create const parameter return nullptr");
  stack_elem_node->set_name(loop_node_name + "_element_shape");
  for (int j = 0; j < act_outputs_num; j++) {
    // Scan outputs are the last `act_outputs_num` outputs of the onnx node.
    auto output_size = onnx_node.output_size();
    auto &loop_output_name = onnx_node.output(output_size - act_outputs_num + j);
    MS_CHECK_TRUE_RET(control_nodes_map_.find(loop_node_name) != control_nodes_map_.end(), RET_NULL_PTR);
    MS_CHECK_TRUE_RET(
      control_nodes_map_[loop_node_name]->find(loop_output_name) != control_nodes_map_[loop_node_name]->end(),
      RET_NULL_PTR);
    auto &while_output_node = control_nodes_map_[loop_node_name]->at(loop_output_name);
    MS_CHECK_TRUE_MSG(while_output_node != nullptr, RET_ERROR, "cannot find while_output_node is control_nodes_map");
    auto tensor_list_stack_prim = std::make_shared<ops::TensorListStack>();
    if (tensor_list_stack_prim == nullptr) {
      MS_LOG(ERROR) << "create stack failed";
      return RET_ERROR;
    }
    // -1: element count unknown until runtime.
    tensor_list_stack_prim->set_num_elements(-1);
    auto prim_c = tensor_list_stack_prim->GetPrim();
    MS_CHECK_TRUE_RET(prim_c != nullptr, RET_ERROR);
    auto stack_value_node = NewValueNode(prim_c);
    MS_CHECK_TRUE_MSG(stack_value_node != nullptr, RET_NULL_PTR, "create stack_value_node return nullptr");
    std::vector<AnfNodePtr> stack_inputs = {stack_value_node, while_output_node, stack_elem_node};
    auto tensorlist_stack_cnode = root_anf_graph->NewCNode(stack_inputs);
    if (tensorlist_stack_cnode == nullptr) {
      MS_LOG(ERROR) << "new cnode error";
      return RET_ERROR;
    }
    tensorlist_stack_cnode->set_fullname_with_scope(loop_node_name + "_tensorlist_stack_node_" + std::to_string(j));
    tensorlist_stack_cnode->set_abstract(stack_elem_node->abstract());

    // update getitem value output index
    // The existing while output is a TupleGetItem; retarget its index (input 2)
    // to the body-graph position of this scan output.
    auto new_get_item_value = NewValueNode(MakeValue<int64_t>(body_output_size - act_outputs_num + j));
    MS_CHECK_TRUE_MSG(new_get_item_value != nullptr, RET_NULL_PTR, "create new_get_item_value return nullptr");
    CHECK_NULL_RETURN(while_output_node->cast<CNodePtr>());
    while_output_node->cast<CNodePtr>()->set_input(2, new_get_item_value);
    // insert tensorliststack after while_output
    (*control_nodes_map_[loop_node_name])[loop_output_name] = tensorlist_stack_cnode;
  }
  return RET_OK;
}
1367 
1368 // onnx loop scan_output need through tensorlist op,while node need add new inputs
AddTensorArrayEdge(const FuncGraphPtr & anf_graph,std::vector<AnfNodePtr> * return_new_inputs,const std::string & loop_node_name,std::vector<AnfNodePtr> * body_graph_inputs,int act_output_num)1369 STATUS OnnxModelParser::AddTensorArrayEdge(const FuncGraphPtr &anf_graph, std::vector<AnfNodePtr> *return_new_inputs,
1370                                            const std::string &loop_node_name,
1371                                            std::vector<AnfNodePtr> *body_graph_inputs, int act_output_num) {
1372   MS_CHECK_TRUE_RET(anf_graph != nullptr && return_new_inputs != nullptr && body_graph_inputs != nullptr, RET_NULL_PTR);
1373   // body graph output is  trip_count,cond_count,loop_var,placeholder,scan_outputs
1374   auto root_while_node = GetCNodeFromControlFlowNodesMap(loop_node_name, control_nodes_map_);
1375   MS_CHECK_TRUE_MSG(root_while_node != nullptr, RET_ERROR, "cannot find root_while_node is control_nodes_map");
1376   if (root_while_node == nullptr) {
1377     MS_LOG(ERROR) << "anf root node map cannot find loop node" << loop_node_name;
1378     return RET_ERROR;
1379   }
1380   auto anf_root_graph = root_while_node->func_graph();
1381   auto root_item_index_parameter = CreateConstParamter(anf_root_graph, 0);
1382   MS_CHECK_TRUE_MSG(root_item_index_parameter != nullptr, RET_NULL_PTR,
1383                     "create root_item_index_parameter return nullptr");
1384   root_item_index_parameter->set_name(loop_node_name + "_item_index");
1385   root_while_node->add_input(root_item_index_parameter);
1386   // fake parameter need pass by root while node input
1387   auto item_index_parameter = anf_graph->add_parameter();
1388   MS_CHECK_TRUE_MSG(item_index_parameter != nullptr, RET_NULL_PTR, "create item_index_parameter return nullptr");
1389   item_index_parameter->set_name(loop_node_name + "_item_index_2");
1390   item_index_parameter->set_abstract(root_item_index_parameter->abstract());
1391   body_graph_inputs->emplace_back(item_index_parameter);
1392   // item index++ edge
1393   auto add_value_node = CreateValueNode(schema::PrimitiveType_AddFusion);
1394   if (add_value_node == nullptr) {
1395     MS_LOG(ERROR) << "create add failed.";
1396     return RET_NULL_PTR;
1397   }
1398   auto add_one_input = CreateConstParamter(anf_graph, 1);
1399   MS_CHECK_TRUE_MSG(root_item_index_parameter != nullptr, RET_NULL_PTR, "create add_one_input return nullptr");
1400   add_one_input->set_name(loop_node_name + "_const_placeholder_1");
1401   std::vector<AnfNodePtr> add_inputs = {add_value_node, item_index_parameter, add_one_input};
1402   auto add_cnode = anf_graph->NewCNode(add_inputs);
1403   if (add_cnode == nullptr) {
1404     MS_LOG(ERROR) << "new cnode error";
1405     return RET_ERROR;
1406   }
1407   add_cnode->set_fullname_with_scope(loop_node_name + "item_index_add_node");
1408   add_cnode->set_abstract(root_item_index_parameter->abstract());
1409   // return node inputs will be trip_count,cond_out,loop_var,placeholder,tensorarray...
1410   if (static_cast<int>(return_new_inputs->size()) < act_output_num || act_output_num < 0) {
1411     MS_LOG(ERROR) << "act_output_num out of range of return_new_inputs";
1412     return RET_ERROR;
1413   }
1414   return_new_inputs->insert(return_new_inputs->end() - act_output_num, add_cnode);
1415 
1416   for (int i = 0; i < act_output_num; i++) {
1417     // tensor_array need as root while input
1418     auto while_tensor_array_input = anf_root_graph->add_parameter();
1419     MS_CHECK_TRUE_MSG(while_tensor_array_input != nullptr, RET_NULL_PTR,
1420                       "create while_tensor_array_input return nullptr");
1421     std::vector<int> tensor_list_data(kTensorListDatasize);
1422     tensor_list_data[kTypeIndex] = kTypeUnknown;
1423     tensor_list_data[kElementShapeIndex] = 0;
1424     tensor_list_data[kTensorsNumIndex] = -1;
1425     if (INT_MUL_OVERFLOW_THRESHOLD(tensor_list_data.size(), sizeof(int), SIZE_MAX)) {
1426       MS_LOG(ERROR) << "data_size overflow";
1427       return RET_ERROR;
1428     }
1429     auto tensor_info = CreateTensorInfo(tensor_list_data.data(), tensor_list_data.size() * sizeof(int),
1430                                         {static_cast<int64_t>(tensor_list_data.size())}, kObjectTypeTensorType);
1431     if (tensor_info == nullptr) {
1432       MS_LOG(ERROR) << "Create tensor info failed";
1433       return RET_ERROR;
1434     }
1435     auto abstract_tensor = tensor_info->ToAbstract();
1436     if (abstract_tensor == nullptr) {
1437       MS_LOG(ERROR) << "Create tensor abstarct failed";
1438       return RET_ERROR;
1439     }
1440     while_tensor_array_input->set_abstract(abstract_tensor);
1441     while_tensor_array_input->set_default_param(tensor_info);
1442     while_tensor_array_input->set_name(loop_node_name + "_scan_outputs_tensorarray_while_input");
1443     root_while_node->add_input(while_tensor_array_input);
1444 
1445     auto subgraph_tensor_array_input = anf_graph->add_parameter();
1446     MS_CHECK_TRUE_MSG(subgraph_tensor_array_input != nullptr, RET_NULL_PTR,
1447                       "create subgraph_tensor_array_input return nullptr");
1448     subgraph_tensor_array_input->set_name(loop_node_name + "_scan_outputs_tensorarray_body_fg_input");
1449     subgraph_tensor_array_input->set_abstract(abstract_tensor);
1450     body_graph_inputs->emplace_back(subgraph_tensor_array_input);
1451     // skip trip_count ,cond_out,loop_var,no_loop_var,place_holder, output
1452     auto loop_output_idx = return_new_inputs->size() - act_output_num + i;
1453     auto loop_output_node = (*return_new_inputs)[loop_output_idx];
1454     auto set_item_value_node = CreateValueNode(schema::PrimitiveType_TensorListSetItem);
1455     if (set_item_value_node == nullptr) {
1456       MS_LOG(ERROR) << "create tensor list set item failed.";
1457       return RET_NULL_PTR;
1458     }
1459     std::vector<AnfNodePtr> set_item_inputs = {set_item_value_node, subgraph_tensor_array_input, item_index_parameter,
1460                                                loop_output_node};
1461     auto tensorlist_setitem_cnode = anf_graph->NewCNode(set_item_inputs);
1462     if (tensorlist_setitem_cnode == nullptr) {
1463       MS_LOG(ERROR) << "new cnode error";
1464       return RET_ERROR;
1465     }
1466     tensorlist_setitem_cnode->set_fullname_with_scope(loop_node_name + "_tensorlist_setitem_node");
1467     tensorlist_setitem_cnode->set_abstract(abstract_tensor);
1468     // loop output need replace by tensorliststack_output
1469     (*return_new_inputs)[loop_output_idx] = tensorlist_setitem_cnode;
1470   }
1471 
1472   return RET_OK;
1473 }
1474 
ConvertLoopOnnxNode(const onnx::NodeProto & onnx_node,std::unordered_map<std::string,AnfNodePtr> * anf_root_nodes_map,const std::string & root_node_name)1475 STATUS OnnxModelParser::ConvertLoopOnnxNode(const onnx::NodeProto &onnx_node,
1476                                             std::unordered_map<std::string, AnfNodePtr> *anf_root_nodes_map,
1477                                             const std::string &root_node_name) {
1478   MS_CHECK_TRUE_RET(anf_root_nodes_map != nullptr, RET_NULL_PTR);
1479   for (int i = 0; i < onnx_node.attribute_size(); i++) {
1480     auto &attr = onnx_node.attribute(i);
1481     if (attr.name() != "body" || attr.type() != onnx::AttributeProto_AttributeType_GRAPH) {
1482       continue;
1483     }
1484     auto &subgraph_proto = attr.g();
1485     int cond_graph_input_num = -1;
1486     auto loop_body_graph = BuildBodyGraph(onnx_node, subgraph_proto, &cond_graph_input_num);
1487     MS_CHECK_TRUE_MSG(loop_body_graph != nullptr, RET_NULL_PTR, "create loop_body_graph return nullptr");
1488     auto root_while_node = GetCNodeFromControlFlowNodesMap(onnx_node.name(), control_nodes_map_);
1489     MS_CHECK_TRUE_MSG(root_while_node != nullptr, RET_ERROR, "cannot find root_while_node");
1490     auto loop_cond_graph = BuildCondGraph(root_while_node, cond_graph_input_num, onnx_node.name() + "_cond_graph");
1491     MS_CHECK_TRUE_MSG(loop_cond_graph != nullptr, RET_NULL_PTR, "create loop_cond_graph return nullptr");
1492     all_subgraphs_.emplace_back(loop_body_graph);
1493     all_subgraphs_.emplace_back(loop_cond_graph);
1494     auto body_value_node = NewValueNode(loop_body_graph);
1495     MS_CHECK_TRUE_MSG(body_value_node != nullptr, RET_NULL_PTR, "create body_value_node return nullptr");
1496     auto inputs = root_while_node->inputs();
1497     auto cond_value_node = NewValueNode(loop_cond_graph);
1498     MS_CHECK_TRUE_MSG(cond_value_node != nullptr, RET_NULL_PTR, "create cond_value_node return nullptr");
1499     inputs.insert(inputs.begin() + 1, {cond_value_node, body_value_node});
1500     root_while_node->set_inputs(inputs);
1501   }
1502   return RET_OK;
1503 }
1504 
BuildParameterNodeForQuantParam(const void * data,const std::string & name,TypeId type)1505 STATUS OnnxModelParser::BuildParameterNodeForQuantParam(const void *data, const std::string &name, TypeId type) {
1506   CHECK_NULL_RETURN(data);
1507   if (type != kNumberTypeInt64 && type != kNumberTypeFloat32) {
1508     MS_LOG(ERROR) << "quant param type don't support.";
1509     return RET_NOT_SUPPORT;
1510   }
1511   auto res_graph = ConvertGraph(res_graph_);
1512   auto parameter_node = res_graph->add_parameter();
1513   MS_CHECK_TRUE_MSG(parameter_node != nullptr, RET_NULL_PTR, "create parameter return nullptr");
1514   auto abstract_tensor = CreateTensorAbstract({}, type);
1515   if (abstract_tensor == nullptr) {
1516     MS_LOG(ERROR) << "Create tensor abstarct failed";
1517     return RET_ERROR;
1518   }
1519   parameter_node->set_abstract(abstract_tensor);
1520   parameter_node->set_name(name);
1521   int data_size = 0;
1522   if (type == kNumberTypeFloat32) {
1523     data_size = sizeof(float);
1524   } else {
1525     data_size = sizeof(int64_t);
1526   }
1527   auto tensor_info = CreateTensorInfo(data, data_size, {1}, type);
1528   if (tensor_info == nullptr) {
1529     MS_LOG(ERROR) << "create tensor info failed.";
1530     return RET_ERROR;
1531   }
1532   parameter_node->set_default_param(tensor_info);
1533   anf_nodes_map_.emplace(name, parameter_node);
1534   return RET_OK;
1535 }
1536 
1537 REG_MODEL_PARSER(kFmkTypeOnnx, LiteModelParserCreator<OnnxModelParser>)
1538 }  // namespace lite
1539 }  // namespace mindspore
1540