• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "tools/converter/parser/onnx/onnx_model_parser.h"
18 #include <algorithm>
19 #include <memory>
20 #include <set>
21 #include <vector>
22 #include <unordered_map>
23 #include <utility>
24 #include "include/registry/node_parser_registry.h"
25 #include "tools/optimizer/common/gllo_utils.h"
26 #include "tools/common/graph_util.h"
27 #include "tools/common/protobuf_utils.h"
28 #include "tools/common/tensor_util.h"
29 #include "tools/converter/ops/ops_def.h"
30 #include "ops/tensor_list_stack.h"
31 #include "ir/func_graph.h"
32 #include "tools/converter/converter_flags.h"
33 #include "tools/converter/quant_param_holder.h"
34 #include "tools/converter/converter_context.h"
35 #include "tools/converter/parser/onnx/onnx_inputs_adjust.h"
36 #include "tools/converter/parser/onnx/onnx_pad_adjust.h"
37 #include "tools/converter/parser/parser_utils.h"
38 #include "tools/converter/parser/unify_format.h"
39 #include "nnacl/op_base.h"
40 #include "src/common/log_util.h"
41 
42 using mindspore::converter::kFmkTypeOnnx;
43 namespace mindspore {
44 namespace lite {
45 namespace {
// Index constants for TensorList-related data. They are not referenced in this part of the file;
// presumably they describe a 3-slot layout [type, element shape, tensor count] -- confirm against
// the TensorList helpers that consume them.
constexpr int kTensorListDatasize{3};
constexpr int kTypeIndex{0};
constexpr int kElementShapeIndex{1};
constexpr int kTensorsNumIndex{2};
50 
Onnx2AnfAdjust(const std::set<FuncGraphPtr> & all_func_graphs)51 int Onnx2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs) {
52   for (const auto &func_graph : all_func_graphs) {
53     MS_ASSERT(func_graph != nullptr);
54     if (!OnnxInputAdjust::Adjust(func_graph)) {
55       MS_LOG(ERROR) << "onnx adjust failed.";
56       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
57       return RET_ERROR;
58     }
59     if (!OnnxPadAdjust::Adjust(func_graph)) {
60       MS_LOG(ERROR) << "onnx pad adjust failed.";
61       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
62       return RET_ERROR;
63     }
64   }
65   return RET_OK;
66 }
67 
CreateConstParamter(const FuncGraphPtr & anf_graph,int val)68 ParameterPtr CreateConstParamter(const FuncGraphPtr &anf_graph, int val) {
69   MS_ASSERT(anf_graph != nullptr);
70   auto const_node = anf_graph->add_parameter();
71   MS_CHECK_TRUE_RET(const_node != nullptr, nullptr);
72   auto const_abstract = CreateTensorAbstract({}, kNumberTypeInt32);
73   if (const_abstract == nullptr) {
74     MS_LOG(ERROR) << "Create tensor abstarct failed";
75     return nullptr;
76   }
77   const_node->set_abstract(const_abstract);
78   int *tensor_data = new (std::nothrow) int[1];
79   if (tensor_data == nullptr) {
80     MS_LOG(ERROR) << "new int[] failed";
81     return nullptr;
82   }
83   tensor_data[0] = val;
84   auto tensor_info = CreateTensorInfo(tensor_data, sizeof(int), {1}, kNumberTypeInt32);
85   if (tensor_info == nullptr) {
86     MS_LOG(ERROR) << "create tensor info failed.";
87     delete[] tensor_data;
88     tensor_data = nullptr;
89     return nullptr;
90   }
91   tensor_data = nullptr;
92   delete[] tensor_data;
93   const_node->set_default_param(tensor_info);
94   return const_node;
95 }
96 
CreateValueNode(const schema::PrimitiveType & op_type)97 ValueNodePtr CreateValueNode(const schema::PrimitiveType &op_type) {
98   auto node_type = schema::EnumNamePrimitiveType(op_type);
99   auto op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
100   if (op_primc_fns.find(node_type) == op_primc_fns.end()) {
101     MS_LOG(ERROR) << "have no func to create primitive.";
102     return nullptr;
103   }
104   auto prim = op_primc_fns[node_type]();
105   if (prim == nullptr) {
106     MS_LOG(ERROR) << "cannot create primitive.";
107     return nullptr;
108   }
109   return NewValueNode(prim);
110 }
111 
AddIterNumsUpdateEdge(const FuncGraphPtr & anf_graph,std::vector<AnfNodePtr> * return_new_inputs,const std::unordered_map<std::string,AnfNodePtr> & anf_nodes_map,const std::string & trip_cout_name,const std::string & loop_node_name)112 STATUS AddIterNumsUpdateEdge(const FuncGraphPtr &anf_graph, std::vector<AnfNodePtr> *return_new_inputs,
113                              const std::unordered_map<std::string, AnfNodePtr> &anf_nodes_map,
114                              const std::string &trip_cout_name, const std::string &loop_node_name) {
115   MS_ASSERT(anf_graph != nullptr && return_new_inputs != nullptr);
116   // trip_cout need -1 after every iteration
117   auto sub_value_node = CreateValueNode(schema::PrimitiveType_SubFusion);
118   if (sub_value_node == nullptr) {
119     MS_LOG(ERROR) << "create sub failed.";
120     return RET_NULL_PTR;
121   }
122   auto trip_cout_paramter_iter = anf_nodes_map.find(trip_cout_name);
123   if (trip_cout_paramter_iter == anf_nodes_map.end()) {
124     MS_LOG(ERROR) << "can not find " << trip_cout_name;
125     return RET_ERROR;
126   }
127   auto &trip_cout_paramter = trip_cout_paramter_iter->second;
128   if (trip_cout_paramter == nullptr) {
129     MS_LOG(ERROR) << "trip_cout_paramter found failed";
130     return RET_ERROR;
131   }
132   auto const_one_parameter = CreateConstParamter(anf_graph, 1);
133   MS_CHECK_TRUE_MSG(const_one_parameter != nullptr, RET_ERROR, "create const parameter return nullptr");
134   const_one_parameter->set_name(loop_node_name + "_index_update_parameter");
135 
136   std::vector<AnfNodePtr> sub_inputs = {sub_value_node, trip_cout_paramter, const_one_parameter};
137   auto sub_cnode = anf_graph->NewCNode(sub_inputs);
138   if (sub_cnode == nullptr) {
139     MS_LOG(ERROR) << "new cnode error";
140     return RET_ERROR;
141   }
142   sub_cnode->set_fullname_with_scope(loop_node_name + "_sub");
143   sub_cnode->set_abstract(trip_cout_paramter->abstract());
144   return_new_inputs->insert(return_new_inputs->begin() + 1, sub_cnode);
145   return RET_OK;
146 }
147 
GetCNodeFromControlFlowNodesMap(const std::string & loop_node_name,const std::unordered_map<std::string,std::unordered_map<std::string,AnfNodePtr> * > & control_nodes_map)148 CNodePtr GetCNodeFromControlFlowNodesMap(
149   const std::string &loop_node_name,
150   const std::unordered_map<std::string, std::unordered_map<std::string, AnfNodePtr> *> &control_nodes_map) {
151   auto iter1 = control_nodes_map.find(loop_node_name);
152   if (iter1 == control_nodes_map.end()) {
153     return nullptr;
154   }
155   auto iter2 = iter1->second->find(loop_node_name);
156   if (iter2 == iter1->second->end()) {
157     return nullptr;
158   }
159   return iter2->second->cast<CNodePtr>();
160 }
161 
BuildReturnNode(const FuncGraphPtr & anf_graph,const std::vector<AnfNodePtr> & return_inputs)162 STATUS BuildReturnNode(const FuncGraphPtr &anf_graph, const std::vector<AnfNodePtr> &return_inputs) {
163   MS_ASSERT(anf_graph != nullptr);
164   auto return_prim = std::make_shared<lite::Return>();
165   if (return_prim == nullptr) {
166     MS_LOG(ERROR) << "new Return failed";
167     return RET_NULL_PTR;
168   }
169   auto return_cnode = anf_graph->NewCNode(return_prim, return_inputs);
170   if (return_cnode == nullptr) {
171     MS_LOG(ERROR) << "new cnode error";
172     return RET_ERROR;
173   }
174   return_cnode->set_fullname_with_scope("Return");
175   anf_graph->set_return(return_cnode);
176   return RET_OK;
177 }
178 
BuildParameterNode(const ParameterPtr & parameter_node,const onnx::TensorProto & tensor)179 STATUS BuildParameterNode(const ParameterPtr &parameter_node, const onnx::TensorProto &tensor) {
180   MS_ASSERT(parameter_node != nullptr);
181   auto data_type = OnnxNodeParser::GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(tensor.data_type()));
182   if (data_type == kTypeUnknown) {
183     MS_LOG(ERROR) << "not support onnx data type " << static_cast<onnx::TensorProto_DataType>(tensor.data_type());
184     return RET_ERROR;
185   }
186   std::vector<int64_t> shape_vector(tensor.dims().begin(), tensor.dims().end());
187   auto abstract_tensor = CreateTensorAbstract(shape_vector, data_type);
188   if (abstract_tensor == nullptr) {
189     MS_LOG(ERROR) << "Create tensor abstarct failed";
190     return RET_ERROR;
191   }
192   parameter_node->set_abstract(abstract_tensor);
193   parameter_node->set_name(tensor.name());
194 
195   auto tensor_info = std::make_shared<tensor::Tensor>(data_type, shape_vector);
196   MS_CHECK_TRUE_MSG(tensor_info != nullptr, RET_NULL_PTR, "create tensor_info return nullptr");
197   std::vector<int> shape;
198   std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape),
199                  [](const int64_t &value) { return static_cast<int>(value); });
200   auto status = OnnxNodeParser::CopyOnnxTensorData(tensor, tensor_info);
201   if (status != RET_OK) {
202     MS_LOG(ERROR) << "copy data failed.";
203     return status;
204   }
205   parameter_node->set_default_param(tensor_info);
206   return RET_OK;
207 }
208 
BuildOpOutputs(const onnx::NodeProto & onnx_node,const FuncGraphPtr & anf_graph,std::unordered_map<std::string,AnfNodePtr> * anf_nodes_map,const CNodePtr & cnode)209 STATUS BuildOpOutputs(const onnx::NodeProto &onnx_node, const FuncGraphPtr &anf_graph,
210                       std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, const CNodePtr &cnode) {
211   MS_ASSERT(anf_graph != nullptr && cnode != nullptr && anf_nodes_map != nullptr);
212   if (onnx_node.output_size() == 1) {
213     auto abstract_tensor = CreateTensorAbstract({}, kNumberTypeFloat32);
214     if (abstract_tensor == nullptr) {
215       MS_LOG(ERROR) << "Create tensor abstarct failed";
216       return RET_ERROR;
217     }
218     cnode->set_abstract(abstract_tensor);
219     anf_nodes_map->emplace(onnx_node.output(0), cnode);
220   } else {
221     AbstractBasePtrList abstract_list;
222     int op_idx = 0;
223     for (const auto &output_name : onnx_node.output()) {
224       auto abstract_tensor = CreateTensorAbstract({}, kNumberTypeFloat32);
225       if (abstract_tensor == nullptr) {
226         MS_LOG(ERROR) << "Create tensor abstarct failed";
227         return RET_ERROR;
228       }
229       abstract_list.emplace_back(abstract_tensor);
230       auto tuple_get_item_prim_ptr = std::make_shared<lite::TupleGetItem>();
231       if (tuple_get_item_prim_ptr == nullptr) {
232         MS_LOG(ERROR) << "new TupleGetItem failed";
233         return RET_NULL_PTR;
234       }
235       auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr);
236       MS_CHECK_TRUE_MSG(tuple_get_item_prim != nullptr, RET_NULL_PTR, "create ValueNode return nullptr");
237       auto get_item_value = NewValueNode(MakeValue<int>(op_idx));
238       MS_CHECK_TRUE_MSG(get_item_value != nullptr, RET_NULL_PTR, "create ValueNode return nullptr");
239       std::vector<AnfNodePtr> inputs{tuple_get_item_prim, cnode, get_item_value};
240       CNodePtr get_item_cnode = anf_graph->NewCNode(inputs);
241       if (get_item_cnode == nullptr) {
242         MS_LOG(ERROR) << "new cnode error";
243         return RET_ERROR;
244       }
245       get_item_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_getitem_" + std::to_string(op_idx));
246       auto get_item_abstract = CreateTensorAbstract({}, kNumberTypeFloat32);
247       if (get_item_abstract == nullptr) {
248         MS_LOG(ERROR) << "Create tensor abstarct failed";
249         return RET_ERROR;
250       }
251       get_item_cnode->set_abstract(get_item_abstract);
252       anf_nodes_map->emplace(output_name, get_item_cnode);
253       op_idx++;
254     }
255     auto new_abstract_list = std::make_shared<abstract::AbstractTuple>(abstract_list);
256     CHECK_NULL_RETURN(new_abstract_list);
257     cnode->set_abstract(new_abstract_list);
258   }
259   anf_nodes_map->emplace(onnx_node.name(), cnode);
260   return RET_OK;
261 }
262 
ConvertConstTensors(const onnx::GraphProto & onnx_graph,const FuncGraphPtr & func_graph_ptr,std::unordered_map<std::string,AnfNodePtr> * anf_nodes_map)263 STATUS ConvertConstTensors(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr,
264                            std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map) {
265   MS_ASSERT(func_graph_ptr != nullptr && anf_nodes_map != nullptr);
266   for (const auto &onnx_const_value : onnx_graph.initializer()) {
267     auto parameter = func_graph_ptr->add_parameter();
268     MS_CHECK_TRUE_MSG(parameter != nullptr, RET_NULL_PTR, "create parameter return nullptr");
269     auto status = BuildParameterNode(parameter, onnx_const_value);
270     if (status != RET_OK) {
271       MS_LOG(ERROR) << "parameter node build failed.";
272       return status;
273     }
274     anf_nodes_map->emplace(onnx_const_value.name(), parameter);
275   }
276   return RET_OK;
277 }
278 
ConvertGraphInputs(const onnx::GraphProto & onnx_graph,const FuncGraphPtr & func_graph_ptr,std::unordered_map<std::string,AnfNodePtr> * anf_nodes_map)279 STATUS ConvertGraphInputs(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr,
280                           std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map) {
281   MS_ASSERT(func_graph_ptr != nullptr && anf_nodes_map != nullptr);
282   for (int i = 0; i < onnx_graph.input().size(); ++i) {
283     const auto &input_value = onnx_graph.input(i);
284     if (anf_nodes_map->find(input_value.name()) != anf_nodes_map->end()) {
285       continue;
286     }
287     auto parameter = func_graph_ptr->add_parameter();
288     MS_CHECK_TRUE_MSG(parameter != nullptr, RET_NULL_PTR, "create parameter return nullptr");
289     auto data_type = OnnxNodeParser::GetDataTypeFromOnnx(
290       static_cast<onnx::TensorProto_DataType>(input_value.type().tensor_type().elem_type()));
291     if (data_type == kTypeUnknown) {
292       MS_LOG(ERROR) << "not support onnx data type "
293                     << static_cast<onnx::TensorProto_DataType>(input_value.type().tensor_type().elem_type());
294       return RET_ERROR;
295     }
296     std::vector<int64_t> shape_vector = ConverterContext::GetInstance()->GetGraphInputTensorShape(input_value.name());
297     if (ConverterContext::GetInstance()->GetGraphInputTensorShapeMapSize() > 0 && shape_vector.empty()) {
298       MS_LOG(WARNING) << "Can not find name in map. name is " << input_value.name();
299     }
300     if (shape_vector.empty()) {
301       auto onnx_shape = input_value.type().tensor_type().shape().dim();
302       std::transform(onnx_shape.begin(), onnx_shape.end(), std::back_inserter(shape_vector),
303                      [](const onnx::TensorShapeProto_Dimension &val) { return static_cast<int64_t>(val.dim_value()); });
304       std::replace(shape_vector.begin(), shape_vector.end(), 0, -1);
305     }
306     auto abstract_tensor = CreateTensorAbstract(shape_vector, data_type);
307     if (abstract_tensor == nullptr) {
308       MS_LOG(ERROR) << "Create tensor abstarct failed";
309       return RET_ERROR;
310     }
311     parameter->set_abstract(abstract_tensor);
312     parameter->set_name(input_value.name());
313     anf_nodes_map->emplace(input_value.name(), parameter);
314   }
315   return RET_OK;
316 }
317 
ConvertGraphOutputs(const onnx::GraphProto & onnx_graph,const FuncGraphPtr & anf_graph,const std::unordered_map<std::string,AnfNodePtr> & anf_nodes_map)318 STATUS ConvertGraphOutputs(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &anf_graph,
319                            const std::unordered_map<std::string, AnfNodePtr> &anf_nodes_map) {
320   MS_ASSERT(anf_graph != nullptr);
321   std::vector<AnfNodePtr> return_inputs;
322   if (onnx_graph.output_size() == 0) {
323     MS_LOG(ERROR) << "onnx graph has no output";
324     return RET_ERROR;
325   }
326   if (onnx_graph.output_size() > 1) {
327     std::vector<AnfNodePtr> make_tuple_inputs;
328     auto make_tuple_prim_ptr = std::make_shared<lite::MakeTuple>();
329     if (make_tuple_prim_ptr == nullptr) {
330       MS_LOG(ERROR) << "new MakeTuple failed";
331       return RET_NULL_PTR;
332     }
333     for (const auto &graph_out : onnx_graph.output()) {
334       if (anf_nodes_map.find(graph_out.name()) == anf_nodes_map.end()) {
335         MS_LOG(ERROR) << "graph output get failed.";
336         return RET_ERROR;
337       }
338       auto cnode = anf_nodes_map.at(graph_out.name());
339       if (cnode == nullptr) {
340         MS_LOG(ERROR) << "Can't find input node.";
341         return RET_NOT_FIND_OP;
342       }
343       make_tuple_inputs.emplace_back(cnode);
344     }
345     auto make_tuple_cnode = anf_graph->NewCNode(make_tuple_prim_ptr, make_tuple_inputs);
346     if (make_tuple_cnode == nullptr) {
347       MS_LOG(ERROR) << "new cnode error";
348       return RET_ERROR;
349     }
350     make_tuple_cnode->set_fullname_with_scope("return tuple");
351     return_inputs.emplace_back(make_tuple_cnode);
352   } else {
353     const auto &graph_out = onnx_graph.output(0);
354     if (anf_nodes_map.find(graph_out.name()) == anf_nodes_map.end()) {
355       MS_LOG(ERROR) << "graph output get failed.";
356       return RET_ERROR;
357     }
358     auto cnode = anf_nodes_map.at(graph_out.name());
359     if (cnode == nullptr) {
360       MS_LOG(ERROR) << "Can't find input node.";
361       return RET_NOT_FIND_OP;
362     }
363     return_inputs.emplace_back(cnode);
364   }
365   if (BuildReturnNode(anf_graph, return_inputs) != RET_OK) {
366     MS_LOG(ERROR) << "build return node failed.";
367     return RET_ERROR;
368   }
369   return RET_OK;
370 }
371 
BuildCondGraph(const AnfNodePtr & root_while_node,int inputs_num,const std::string & cond_graph_name)372 FuncGraphPtr BuildCondGraph(const AnfNodePtr &root_while_node, int inputs_num, const std::string &cond_graph_name) {
373   MS_ASSERT(root_while_node != nullptr);
374   auto cond_graph = std::make_shared<FuncGraph>();
375   MS_CHECK_TRUE_MSG(cond_graph != nullptr, nullptr, "create cond_graph return nullptr");
376   CNodePtr less_cnode = nullptr;
377   for (int i = 0; i < inputs_num; i++) {
378     auto input_parameter = cond_graph->add_parameter();
379     MS_CHECK_TRUE_MSG(input_parameter != nullptr, nullptr, "create input_parameter return nullptr");
380     input_parameter->set_name(cond_graph_name + "_input_" + std::to_string(i) + "_parameter");
381     auto input_abstract = CreateTensorAbstract({}, kNumberTypeInt32);
382     if (input_abstract == nullptr) {
383       MS_LOG(ERROR) << "Create tensor abstarct failed";
384       return nullptr;
385     }
386     input_parameter->set_abstract(input_abstract);
387     if (i == 0) {
388       auto zero_parameter = CreateConstParamter(cond_graph, 0);
389       MS_CHECK_TRUE_MSG(zero_parameter != nullptr, nullptr, "create zero_parameter return nullptr");
390       zero_parameter->set_name(root_while_node->fullname_with_scope() + "_const_0");
391       auto less_value_node = CreateValueNode(schema::PrimitiveType_Less);
392       MS_CHECK_TRUE_MSG(less_value_node != nullptr, nullptr, "create less_value_node return nullptr");
393       std::vector<AnfNodePtr> less_inputs = {less_value_node, zero_parameter, input_parameter};
394       less_cnode = cond_graph->NewCNode(less_inputs);
395       if (less_cnode == nullptr) {
396         MS_LOG(ERROR) << "new cnode error";
397         return nullptr;
398       }
399       auto less_abstract = CreateTensorAbstract({}, kNumberTypeBool);
400       if (less_abstract == nullptr) {
401         MS_LOG(ERROR) << "Create tensor abstarct failed";
402         return nullptr;
403       }
404       less_cnode->set_abstract(less_abstract);
405       less_cnode->set_fullname_with_scope(cond_graph_name + "_less_cnode");
406     }
407     if (i == 1) {
408       auto and_value_node = CreateValueNode(schema::PrimitiveType_LogicalAnd);
409       MS_CHECK_TRUE_MSG(and_value_node != nullptr, nullptr, "CreateValueNode failed");
410       std::vector<AnfNodePtr> and_inputs = {and_value_node, less_cnode, input_parameter};
411       auto and_cnode = cond_graph->NewCNode(and_inputs);
412       if (and_cnode == nullptr) {
413         MS_LOG(ERROR) << "new cnode error";
414         return nullptr;
415       }
416       and_cnode->set_abstract(less_cnode->abstract());
417       and_cnode->set_fullname_with_scope(cond_graph_name + "_output_" + std::to_string(0) + "_cnode");
418       auto status = BuildReturnNode(cond_graph, {and_cnode});
419       if (status != RET_OK) {
420         MS_LOG(ERROR) << "build return node failed: " << status;
421         return nullptr;
422       }
423     }
424   }
425   cond_graph->set_attr("graph_name", MakeValue(cond_graph_name));
426   return cond_graph;
427 }
428 }  // namespace
429 
BuildBodyGraph(const onnx::NodeProto & loop_node,const onnx::GraphProto & subgraph_proto,int * cond_graph_input_num)430 FuncGraphPtr OnnxModelParser::BuildBodyGraph(const onnx::NodeProto &loop_node, const onnx::GraphProto &subgraph_proto,
431                                              int *cond_graph_input_num) {
432   MS_ASSERT(cond_graph_input_num != nullptr);
433   auto &loop_node_name = loop_node.name();
434   auto node_inputs_num = loop_node.input_size();
435   auto node_outputs_num = loop_node.output_size();
436   // skip trip_cout and cond input,scan_output nums
437   auto act_outputs_num = node_outputs_num - (node_inputs_num - 2);
438   auto loop_body_graph = std::make_shared<FuncGraph>();
439   MS_CHECK_TRUE_MSG(loop_body_graph != nullptr, nullptr, "create loop_body_graph return nullptr");
440   std::unordered_map<std::string, AnfNodePtr> anf_nodes_map;
441   std::vector<AnfNodePtr> gen_subgraph_inputs;
442   auto status = ConvertOnnxGraph(subgraph_proto, loop_body_graph, &anf_nodes_map, &gen_subgraph_inputs, loop_node_name);
443   if (status != RET_OK) {
444     MS_LOG(ERROR) << "convert loop OnnxGraph: " << status;
445     return nullptr;
446   }
447   auto return_node = loop_body_graph->get_return();
448   MS_CHECK_TRUE_MSG(return_node != nullptr, nullptr, "return node of subgraph is nullptr");
449   MS_ASSERT(return_node->inputs().size() == DIMENSION_2D);
450   auto return_tuple_cnode = return_node->input(1)->cast<CNodePtr>();
451   MS_ASSERT(return_tuple_cnode != nullptr);
452   auto return_new_inputs = return_tuple_cnode->inputs();
453   return_new_inputs.insert(return_new_inputs.end() - act_outputs_num, gen_subgraph_inputs.begin(),
454                            gen_subgraph_inputs.end());
455 
456   std::string max_trip_count_name = subgraph_proto.input(0).name();
457   status =
458     AddIterNumsUpdateEdge(loop_body_graph, &return_new_inputs, anf_nodes_map, max_trip_count_name, loop_node_name);
459   if (status != RET_OK) {
460     MS_LOG(ERROR) << "add iter nums update edge failed: " << status;
461     return nullptr;
462   }
463   auto root_while_node = GetCNodeFromControlFlowNodesMap(loop_node_name, control_nodes_map_);
464   MS_CHECK_TRUE_MSG(root_while_node != nullptr, nullptr, "can not find root_while_node is control_nodes_map");
465   std::vector<AnfNodePtr> body_graph_inputs;
466   body_graph_inputs.reserve(subgraph_proto.input_size());
467   for (int j = 0; j < subgraph_proto.input_size(); j++) {
468     body_graph_inputs.emplace_back(anf_nodes_map[subgraph_proto.input(j).name()]);
469   }
470   body_graph_inputs.insert(body_graph_inputs.end(), gen_subgraph_inputs.begin(), gen_subgraph_inputs.end());
471   if (act_outputs_num != 0) {
472     status =
473       AddTensorArrayEdge(loop_body_graph, &return_new_inputs, loop_node_name, &body_graph_inputs, act_outputs_num);
474     if (status != RET_OK) {
475       MS_LOG(ERROR) << "add tensorarray update edge failed: " << status;
476       return nullptr;
477     }
478     // insert tensorliststack after while output
479     status = AddTensorListStackNode(root_while_node, loop_node, act_outputs_num, body_graph_inputs.size());
480     if (status != RET_OK) {
481       MS_LOG(ERROR) << "add tensorliststack node failed: " << status;
482       return nullptr;
483     }
484   }
485   return_tuple_cnode->set_inputs(return_new_inputs);
486   auto body_graph_name = loop_node_name + "_body_graph";
487   for (size_t j = 0; j < body_graph_inputs.size(); j++) {
488     MS_CHECK_TRUE_RET(body_graph_inputs[j] != nullptr, nullptr);
489     auto body_input = body_graph_inputs[j]->cast<ParameterPtr>();
490     MS_ASSERT(body_input != nullptr);
491     body_input->set_name(body_graph_name + "_input_" + std::to_string(j) + "_parameter");
492   }
493   for (size_t j = 1; j < return_new_inputs.size(); j++) {
494     if (utils::isa<CNodePtr>(return_new_inputs[j])) {
495       return_new_inputs[j]->cast<CNodePtr>()->set_fullname_with_scope(body_graph_name + "_output_" +
496                                                                       std::to_string(j - 1) + "_cnode");
497     } else if (utils::isa<ParameterPtr>(return_new_inputs[j])) {
498       return_new_inputs[j]->cast<ParameterPtr>()->set_name(body_graph_name + "_output_" + std::to_string(j - 1) +
499                                                            "_parameter");
500     }
501   }
502   *cond_graph_input_num = return_new_inputs.size() - 1;
503   loop_body_graph->set_attr("graph_name", MakeValue(body_graph_name));
504   return loop_body_graph;
505 }
506 
507 namespace {
CheckOnnxModel(const onnx::GraphProto & onnx_graph)508 STATUS CheckOnnxModel(const onnx::GraphProto &onnx_graph) {
509   // all input should in initialize
510   std::set<std::string> providers;
511   for (const auto &const_tensor : onnx_graph.initializer()) {
512     const auto &name = const_tensor.name();
513     if (providers.count(name) != 0) {
514       MS_LOG(ERROR) << "const tensor repeated";
515       return RET_ERROR;
516     }
517     providers.insert(name);
518   }
519   for (int i = 0; i < onnx_graph.input().size(); ++i) {
520     providers.insert(onnx_graph.input(i).name());
521   }
522   for (const auto &onnx_node : onnx_graph.node()) {
523     for (int i = 0; i < onnx_node.output_size(); i++) {
524       auto &output = onnx_node.output(i);
525       if (providers.count(output) != 0) {
526         MS_LOG(ERROR) << "Output tensor repeated";
527         return RET_ERROR;
528       }
529       providers.insert(output);
530     }
531   }
532   // all output should find
533   for (const auto &onnx_node : onnx_graph.node()) {
534     for (int i = 0; i < onnx_node.input_size(); i++) {
535       auto &input = onnx_node.input(i);
536       if (providers.count(input) == 0) {
537         MS_LOG(WARNING) << "Can not find node input: " << input;
538       }
539     }
540   }
541   return RET_OK;
542 }
543 }  // namespace
544 
Parse(const converter::ConverterParameters & flag)545 api::FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag) {
546   auto model_file = flag.model_file;
547   NotSupportOp::GetInstance()->set_fmk_type("ONNX");
548   res_graph_ = std::make_shared<FuncGraph>();
549   MS_CHECK_TRUE_MSG(res_graph_ != nullptr, nullptr, "create FuncGraph failed");
550   auto status = InitOriginModel(model_file);
551   if (RET_OK != status) {
552     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
553     MS_LOG(ERROR) << "init origin model failed.";
554     return nullptr;
555   }
556   MS_ASSERT(onnx_root_graph_ != nullptr);
557 
558   auto func_graph = std::dynamic_pointer_cast<FuncGraph>(res_graph_);
559   MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
560   status = ConvertOnnxGraph(onnx_root_graph_, func_graph, &anf_nodes_map_, {}, "root_node");
561   if (RET_OK != status) {
562     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
563     MS_LOG(ERROR) << "convert onnx graph failed.";
564     return nullptr;
565   }
566   static auto root_func_manager = Manage(func_graph);
567   MS_ASSERT(root_func_manager != nullptr);
568   for (auto &subgraph : all_subgraphs_) {
569     MS_ASSERT(subgraph != nullptr);
570     subgraph->set_manager(root_func_manager);
571     subgraph->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeOnnx)));
572   }
573   res_graph_->set_attr("graph_name", MakeValue("main_graph"));
574   res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeOnnx)));
575   std::set<FuncGraphPtr> all_func_graphs = {};
576   GetAllFuncGraph(func_graph, &all_func_graphs);
577   if ((status = CommonAnfAdjust(all_func_graphs)) != RET_OK) {
578     MS_LOG(ERROR) << "AdjustForAnf failed.";
579     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
580     return nullptr;
581   }
582   if ((status = Onnx2AnfAdjust(all_func_graphs)) != RET_OK) {
583     MS_LOG(ERROR) << "Onnx2AnfAdjust failed.";
584     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
585     return nullptr;
586   }
587   auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeOnnx, false);
588   MS_CHECK_TRUE_MSG(unify_format != nullptr, nullptr, "create unify_format return nullptr");
589   if (!unify_format->Run(func_graph)) {
590     MS_LOG(ERROR) << "Run insert transpose failed.";
591     return nullptr;
592   }
593   return res_graph_;
594 }
595 
InitOriginModel(const std::string & model_file)596 STATUS OnnxModelParser::InitOriginModel(const std::string &model_file) {
597   MS_ASSERT(res_graph_ != nullptr);
598   auto status = ValidateFileStr(model_file, ".onnx");
599   if (status != RET_OK) {
600     MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.onnx";
601     return status;
602   }
603 
604   status = ReadProtoFromBinaryFile(model_file, &onnx_model_);
605   if (status != RET_OK) {
606     MS_LOG(ERROR) << "Read onnx model file failed, model path: " << model_file;
607     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
608     return status;
609   }
610   MS_ASSERT(onnx_model_ != nullptr);
611   OnnxNodeParser::set_opset_version(onnx_model_.opset_import().Get(0).version());
612   onnx_root_graph_ = onnx_model_.graph();
613   auto fmk_value_node = MakeValue(static_cast<int>(converter::kFmkTypeOnnx));
614   CHECK_NULL_RETURN(fmk_value_node);
615   res_graph_->set_attr("fmk", fmk_value_node);
616   return RET_OK;
617 }
618 
ConvertOnnxGraph(const onnx::GraphProto & onnx_graph,const FuncGraphPtr & anf_graph,std::unordered_map<std::string,AnfNodePtr> * anf_nodes_map,std::vector<AnfNodePtr> * extra_subgraph_inputs,const std::string & root_node_name)619 STATUS OnnxModelParser::ConvertOnnxGraph(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &anf_graph,
620                                          std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
621                                          std::vector<AnfNodePtr> *extra_subgraph_inputs,
622                                          const std::string &root_node_name) {
623   MS_ASSERT(anf_graph != nullptr && anf_nodes_map != nullptr && extra_subgraph_inputs != nullptr);
624   STATUS status = RET_OK;
625   status = CheckOnnxModel(onnx_graph);
626   if (status != RET_OK) {
627     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
628     MS_LOG(ERROR) << "input onnx model error: " << status;
629     return status;
630   }
631   status = ConvertConstTensors(onnx_graph, anf_graph, anf_nodes_map);
632   if (RET_OK != status) {
633     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
634     MS_LOG(ERROR) << "convert const nodes failed.";
635     return RET_ERROR;
636   }
637 
638   status = ConvertGraphInputs(onnx_graph, anf_graph, anf_nodes_map);
639   if (RET_OK != status) {
640     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
641     MS_LOG(ERROR) << "convert graph inputs failed.";
642     return RET_OK;
643   }
644 
645   status = ConvertNodes(onnx_graph, anf_graph, anf_nodes_map, extra_subgraph_inputs, root_node_name);
646   if (RET_OK != status) {
647     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
648     MS_LOG(ERROR) << "convert nodes failed.";
649     return RET_ERROR;
650   }
651 
652   status = ConvertGraphOutputs(onnx_graph, anf_graph, *anf_nodes_map);
653   if (RET_OK != status) {
654     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
655     MS_LOG(ERROR) << "convert graph outputs failed.";
656     return RET_ERROR;
657   }
658   // save original output tensor names.
659   if (root_node_name == "root_node") {
660     std::vector<std::string> output_names;
661     std::transform(onnx_graph.output().begin(), onnx_graph.output().end(), std::back_inserter(output_names),
662                    [](auto &graph_output) { return graph_output.name(); });
663     ConverterContext::GetInstance()->SetGraphOutputTensorNames(output_names);
664   }
665   return status;
666 }
667 
ConvertNodes(const onnx::GraphProto & onnx_graph,const FuncGraphPtr & anf_graph,std::unordered_map<std::string,AnfNodePtr> * anf_nodes_map,std::vector<AnfNodePtr> * graph_inputs,const std::string & root_node_name)668 STATUS OnnxModelParser::ConvertNodes(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &anf_graph,
669                                      std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
670                                      std::vector<AnfNodePtr> *graph_inputs, const std::string &root_node_name) {
671   MS_ASSERT(anf_graph != nullptr && anf_nodes_map != nullptr && extra_subgraph_inputs != nullptr);
672   STATUS status = RET_OK;
673   for (const auto &onnx_node : onnx_graph.node()) {
674     ops::PrimitiveC *primitive_c = nullptr;
675     auto node_parser = registry::NodeParserRegistry::GetNodeParser(kFmkTypeOnnx, onnx_node.op_type());
676     if (node_parser != nullptr) {
677       primitive_c = node_parser->Parse(onnx_graph, onnx_node);
678     } else {
679       auto node_parser_builtin = OnnxNodeParserRegistry::GetInstance().GetNodeParser(onnx_node.op_type());
680       if (node_parser_builtin == nullptr) {
681         NotSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
682         status = status == RET_OK ? RET_NOT_FIND_OP : status;
683         MS_LOG(ERROR) << "not support onnx data type " << onnx_node.op_type();
684       }
685       if (status != RET_OK) {
686         continue;
687       }
688       MS_LOG(INFO) << "parse op:" << onnx_node.op_type();
689       primitive_c = node_parser_builtin->Parse(onnx_graph, onnx_node);
690     }
691 
692     if (primitive_c == nullptr) {
693       MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed.";
694       status = RET_ERROR;
695       continue;
696     }
697     if (primitive_c->GetAttr(ops::kOriginalFormat) == nullptr) {
698       primitive_c->AddAttr(mindspore::ops::kOriginalFormat, MakeValue<int64_t>(NCHW));
699     }
700     status = ConvertOpQuantParams(onnx_node, primitive_c);
701     if (status != RET_OK) {
702       MS_LOG(ERROR) << "convert " << onnx_node.op_type() << " quant param failed.";
703       continue;
704     }
705     // build CNode
706     status = BuildCNode(onnx_node, anf_graph, anf_nodes_map, graph_inputs, primitive_c, root_node_name);
707     if (status != RET_OK) {
708       MS_LOG(ERROR) << "build cnode " << onnx_node.op_type() << " failed.";
709     }
710 
711     if (onnx_node.op_type() == "Loop") {
712       child_root_map_[onnx_node.name()] = root_node_name;
713       control_nodes_map_[onnx_node.name()] = anf_nodes_map;
714 
715       status = ConvertLoopOnnxNode(onnx_node, anf_nodes_map, root_node_name);
716       if (status != RET_OK) {
717         MS_LOG(ERROR) << "build loop node  failed.";
718       }
719     }
720     if (onnx_node.op_type() == "If") {
721       child_root_map_[onnx_node.name()] = root_node_name;
722       control_nodes_map_[onnx_node.name()] = anf_nodes_map;
723 
724       status = ConvertIfOnnxNode(onnx_node, anf_nodes_map, root_node_name);
725       if (status != RET_OK) {
726         MS_LOG(ERROR) << "build if node  failed.";
727       }
728     }
729   }
730   return status;
731 }
732 
ConvertIfSubgraph(const onnx::GraphProto & subgraph_proto,const FuncGraphPtr & subgraph,const std::string & subgraph_name,const std::string & if_node_name,const std::string & root_node_name)733 STATUS OnnxModelParser::ConvertIfSubgraph(const onnx::GraphProto &subgraph_proto, const FuncGraphPtr &subgraph,
734                                           const std::string &subgraph_name, const std::string &if_node_name,
735                                           const std::string &root_node_name) {
736   MS_ASSERT(subgraph != nullptr);
737   std::unordered_map<std::string, AnfNodePtr> anf_nodes_map;
738   std::vector<AnfNodePtr> subgraph_extra_inputs;
739   auto status = ConvertOnnxGraph(subgraph_proto, subgraph, &anf_nodes_map, &subgraph_extra_inputs, if_node_name);
740   if (status != RET_OK) {
741     MS_LOG(ERROR) << "convert loop OnnxGraph failed";
742     return status;
743   }
744   subgraph->set_attr("graph_name", MakeValue(subgraph_name));
745   // update subgraph in out name
746   for (int j = 0; j < subgraph_proto.input_size(); j++) {
747     auto input_anode_iter = anf_nodes_map.find(subgraph_proto.input(j).name());
748     if (input_anode_iter == anf_nodes_map.end()) {
749       MS_LOG(ERROR) << "can not find input anode";
750       return RET_ERROR;
751     }
752     auto input_parameter = input_anode_iter->second->cast<ParameterPtr>();
753     MS_CHECK_TRUE_MSG(input_parameter != nullptr, RET_ERROR, "subgraph input should be a parameter");
754     input_parameter->set_name(subgraph_name + "_input_" + std::to_string(j) + "_parameter");
755   }
756   for (size_t j = 0; j < subgraph_extra_inputs.size(); j++) {
757     auto input_parameter = subgraph_extra_inputs[j]->cast<ParameterPtr>();
758     MS_CHECK_TRUE_MSG(input_parameter != nullptr, RET_ERROR, "subgraph input should be a parameter");
759     input_parameter->set_name(subgraph_name + "_input_" + std::to_string(j + subgraph_proto.input_size()) +
760                               "_parameter");
761   }
762   auto return_node = subgraph->get_return();
763   MS_CHECK_TRUE_MSG(return_node != nullptr, RET_ERROR, "subgraph has no return");
764   std::vector<AnfNodePtr> return_act_inputs;
765   int start_index = 0;
766   if (subgraph_proto.output_size() > 1) {
767     auto return_cnode = return_node->input(1)->cast<CNodePtr>();
768     MS_ASSERT(return_cnode != nullptr);
769     return_act_inputs = return_cnode->inputs();
770     start_index = 1;
771   } else {
772     return_act_inputs = {return_node->input(1)};
773   }
774   for (size_t j = start_index; j < return_act_inputs.size(); j++) {
775     if (utils::isa<CNodePtr>(return_act_inputs[j])) {
776       return_act_inputs[j]->cast<CNodePtr>()->set_fullname_with_scope(subgraph_name + "_output_" +
777                                                                       std::to_string(j - start_index) + "_cnode");
778     } else if (utils::isa<ParameterPtr>(return_act_inputs[j])) {
779       return_act_inputs[j]->cast<ParameterPtr>()->set_name(subgraph_name + "_output_" +
780                                                            std::to_string(j - start_index) + "_parameter");
781     }
782   }
783   return RET_OK;
784 }
785 
ConvertIfOnnxNode(const onnx::NodeProto & onnx_node,std::unordered_map<std::string,AnfNodePtr> * anf_root_nodes_map,const std::string & root_node_name)786 STATUS OnnxModelParser::ConvertIfOnnxNode(const onnx::NodeProto &onnx_node,
787                                           std::unordered_map<std::string, AnfNodePtr> *anf_root_nodes_map,
788                                           const std::string &root_node_name) {
789   MS_ASSERT(anf_root_nodes_map != nullptr);
790   FuncGraphPtr then_branch_graph = nullptr;
791   FuncGraphPtr else_branch_graph = nullptr;
792   FuncGraphPtr subgraph = nullptr;
793   std::string subgraph_name;
794   auto &if_node_name = onnx_node.name();
795 
796   for (int i = 0; i < onnx_node.attribute_size(); i++) {
797     auto &attr = onnx_node.attribute(i);
798     auto &subgraph_proto = attr.g();
799     if (attr.name().find("then_branch") != std::string::npos) {
800       subgraph_name = if_node_name + "_then_branch";
801       then_branch_graph = std::make_shared<FuncGraph>();
802       MS_CHECK_TRUE_MSG(then_branch_graph != nullptr, RET_NULL_PTR, "create then_branch_graph return nullptr");
803       auto status = ConvertIfSubgraph(subgraph_proto, then_branch_graph, subgraph_name, if_node_name, root_node_name);
804       if (status != RET_OK) {
805         MS_LOG(ERROR) << "build if node else branch failed.";
806       }
807     } else if (attr.name().find("else_branch") != std::string::npos) {
808       subgraph_name = if_node_name + "_else_branch";
809       else_branch_graph = std::make_shared<FuncGraph>();
810       MS_CHECK_TRUE_MSG(else_branch_graph != nullptr, RET_NULL_PTR, "create else_branch_graph return nullptr");
811       auto status = ConvertIfSubgraph(subgraph_proto, else_branch_graph, subgraph_name, if_node_name, root_node_name);
812       if (status != RET_OK) {
813         MS_LOG(ERROR) << "build if node else branch failed.";
814       }
815     } else {
816       continue;
817     }
818   }
819   all_subgraphs_.emplace_back(then_branch_graph);
820   all_subgraphs_.emplace_back(else_branch_graph);
821   auto then_value_node = NewValueNode(then_branch_graph);
822   MS_CHECK_TRUE_MSG(then_value_node != nullptr, RET_NULL_PTR, "create then_value_node return nullptr");
823   auto else_value_node = NewValueNode(else_branch_graph);
824   MS_CHECK_TRUE_MSG(else_value_node != nullptr, RET_NULL_PTR, "create else_value_node return nullptr");
825   auto root_if_node = GetCNodeFromControlFlowNodesMap(if_node_name, control_nodes_map_);
826   MS_CHECK_TRUE_MSG(root_if_node != nullptr, RET_ERROR, "can not find root_if_node is control_nodes_map");
827   auto if_new_inputs = root_if_node->inputs();
828   if_new_inputs.insert(if_new_inputs.begin() + 1, {then_value_node, else_value_node});
829 
830   std::vector<AnfNodePtr> if_new_input_not_same{};
831   std::set<AnfNodePtr> if_set{};
832   for (auto &input : if_new_inputs) {
833     if (if_set.find(input) != if_set.end()) {
834       continue;
835     }
836     if_new_input_not_same.push_back(input);
837     if_set.insert(input);
838   }
839 
840   root_if_node->set_inputs(if_new_input_not_same);
841   return RET_OK;
842 }
843 
BuildCNode(const onnx::NodeProto & onnx_node,const FuncGraphPtr & anf_graph,std::unordered_map<std::string,AnfNodePtr> * anf_nodes_map,std::vector<AnfNodePtr> * graph_inputs,ops::PrimitiveC * primitive_c,std::string loop_name)844 STATUS OnnxModelParser::BuildCNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &anf_graph,
845                                    std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
846                                    std::vector<AnfNodePtr> *graph_inputs, ops::PrimitiveC *primitive_c,
847                                    std::string loop_name) {
848   MS_ASSERT(anf_graph != nullptr && anf_nodes_map != nullptr && graph_inputs != nullptr && primitive_c != nullptr);
849   std::vector<AnfNodePtr> op_inputs;
850   for (const auto &input_name : onnx_node.input()) {
851     if (input_name.empty()) {
852       continue;
853     }
854 
855     if (anf_nodes_map->find(input_name) != anf_nodes_map->end()) {
856       op_inputs.push_back(anf_nodes_map->at(input_name));
857     } else {
858       // subgraph may refer root graph nodes
859       std::vector<CNodePtr> need_add_input_nodes;
860       auto ext_subgraph_input = anf_graph->add_parameter();
861       MS_CHECK_TRUE_MSG(ext_subgraph_input != nullptr, RET_NULL_PTR, "create parameter return nullptr");
862       ParameterPtr inner_extra_paramter = nullptr;
863       while (!loop_name.empty() && child_root_map_.find(loop_name) != child_root_map_.end()) {
864         auto cur_node_map = control_nodes_map_[loop_name];
865         CHECK_NULL_RETURN(cur_node_map);
866         if (cur_node_map->find(input_name) != cur_node_map->end()) {
867           auto outside_input_node = cur_node_map->at(input_name);
868           CHECK_NULL_RETURN(outside_input_node);
869           // copy outside input parameter value to inside subgraph
870           ext_subgraph_input->set_abstract(outside_input_node->abstract());
871           ext_subgraph_input->set_name(input_name);
872           if (outside_input_node->isa<Parameter>()) {
873             auto parameter = outside_input_node->cast<ParameterPtr>();
874             if (!parameter->has_default()) {
875               MS_LOG(ERROR) << "outside_input_node should has data.";
876               return RET_ERROR;
877             }
878             auto tensor_info = parameter->default_param()->cast<tensor::TensorPtr>();
879             auto copy_tensor_info = CreateTensorInfo(tensor_info->data_c(), tensor_info->Size(), tensor_info->shape(),
880                                                      tensor_info->data_type());
881             if (copy_tensor_info == nullptr) {
882               MS_LOG(ERROR) << "memcpy failed.";
883               return RET_ERROR;
884             }
885             ext_subgraph_input->set_default_param(copy_tensor_info);
886           } else {
887             // output inside cnode need make extra input
888             graph_inputs->emplace_back(ext_subgraph_input);
889             if (cur_node_map->find(loop_name) != cur_node_map->end()) {
890               CHECK_NULL_RETURN(cur_node_map->at(loop_name));
891               auto control_node = cur_node_map->at(loop_name)->cast<CNodePtr>();
892               MS_ASSERT(control_node != nullptr);
893               control_node->add_input(outside_input_node);
894             } else {
895               MS_LOG(ERROR) << "loop node: " << loop_name << " not found in cur node map.";
896               return RET_ERROR;
897             }
898             for (auto &control_node : need_add_input_nodes) {
899               CHECK_NULL_RETURN(control_node);
900               auto func_graph = control_node->func_graph();
901               auto extra_input_parameter = func_graph->add_parameter();
902               MS_CHECK_TRUE_MSG(extra_input_parameter != nullptr, RET_NULL_PTR, "create parameter return nullptr");
903               extra_input_parameter->set_name(input_name);
904               extra_input_parameter->set_abstract(outside_input_node->abstract());
905               control_node->add_input(extra_input_parameter);
906             }
907           }
908           op_inputs.push_back(ext_subgraph_input);
909           anf_nodes_map->emplace(input_name, ext_subgraph_input);
910           break;
911         } else {
912           if (cur_node_map->find(loop_name) != cur_node_map->end()) {
913             CHECK_NULL_RETURN(cur_node_map->at(loop_name));
914             need_add_input_nodes.emplace_back(cur_node_map->at(loop_name)->cast<CNodePtr>());
915           } else {
916             MS_LOG(ERROR) << "loop node: " << loop_name << " not found in cur node map.";
917             return RET_ERROR;
918           }
919           loop_name = child_root_map_[loop_name];
920         }
921       }
922     }
923   }
924   auto new_cnode = anf_graph->NewCNode(std::shared_ptr<ops::PrimitiveC>(primitive_c), op_inputs);
925   if (new_cnode == nullptr) {
926     MS_LOG(ERROR) << "new cnode error";
927     return RET_ERROR;
928   }
929   new_cnode->set_fullname_with_scope(onnx_node.op_type() + "_" + onnx_node.output(0));
930   auto status = BuildOpOutputs(onnx_node, anf_graph, anf_nodes_map, new_cnode);
931   return status;
932 }
933 
ConvertOpQuantParams(const onnx::NodeProto & onnx_node,ops::PrimitiveC * primitive_c)934 STATUS OnnxModelParser::ConvertOpQuantParams(const onnx::NodeProto &onnx_node, ops::PrimitiveC *primitive_c) {
935   MS_ASSERT(primitive_c != nullptr);
936   auto status = ParseQuantParam(onnx_node);
937   if (status != RET_OK) {
938     MS_LOG(ERROR) << "parse quant param failed.";
939     return RET_ERROR;
940   }
941   // set input tensors
942   auto quant_params_holder = std::make_shared<QuantParamHolder>(onnx_node.input_size(), onnx_node.output_size());
943   MS_CHECK_TRUE_MSG(quant_params_holder != nullptr, RET_NULL_PTR, "create QuantParamHolder return nullptr");
944   for (int i = 0; i < onnx_node.input_size(); ++i) {
945     const auto &input_name = onnx_node.input(i);
946     std::vector<schema::QuantParamT> quant_params;
947     status = SetTensorQuantParam(input_name, &quant_params);
948     if (status != RET_OK) {
949       MS_LOG(ERROR) << "set input tensor quant param failed.";
950       return status;
951     }
952     quant_params_holder->set_input_quant_param(i, quant_params);
953   }
954   // set out tensors
955   for (int i = 0; i < onnx_node.output_size(); ++i) {
956     const auto &output_name = onnx_node.output(i);
957     std::vector<schema::QuantParamT> quant_params;
958     status = SetTensorQuantParam(output_name, &quant_params);
959     if (status != RET_OK) {
960       MS_LOG(ERROR) << "set output tensor quant param failed.";
961       return status;
962     }
963     quant_params_holder->set_output_quant_param(i, quant_params);
964   }
965   primitive_c->AddAttr("quant_params", quant_params_holder);
966   return RET_OK;
967 }
968 
ParseQuantParam(const onnx::NodeProto & onnx_node)969 STATUS OnnxModelParser::ParseQuantParam(const onnx::NodeProto &onnx_node) {
970   for (const auto &onnx_node_attr : onnx_node.attribute()) {
971     if (onnx_node_attr.name() == "Y_scale") {
972       float scale = onnx_node_attr.f();
973       if (BuildParameterNodeForQuantParam(&scale, "scale_" + onnx_node.output(0), kNumberTypeFloat32) != RET_OK) {
974         MS_LOG(ERROR) << "parse quant param failed.";
975         return RET_ERROR;
976       }
977     } else if (onnx_node_attr.name() == "Y_zero_point") {
978       int64_t zero_point = onnx_node_attr.i();
979       if (BuildParameterNodeForQuantParam(&zero_point, "zero_point_" + onnx_node.output(0), kNumberTypeInt64) !=
980           RET_OK) {
981         MS_LOG(ERROR) << "parse quant param failed.";
982         return RET_ERROR;
983       }
984     }
985   }
986   return RET_OK;
987 }
988 
SetTensorQuantParam(const std::string & tensor_name,std::vector<QuantParamT> * quant_params)989 STATUS OnnxModelParser::SetTensorQuantParam(const std::string &tensor_name, std::vector<QuantParamT> *quant_params) {
990   MS_ASSERT(quant_params != nullptr);
991   quant_params->clear();
992   auto quant_param = std::make_unique<QuantParamT>();
993   MS_CHECK_TRUE_MSG(quant_param != nullptr, RET_NULL_PTR, "create QuantParamT return nullptr");
994   for (int i = 0; i < onnx_root_graph_.quantization_annotation_size(); ++i) {
995     auto tensor_annotation = onnx_root_graph_.quantization_annotation(i);
996     if (!tensor_annotation.has_tensor_name() || tensor_annotation.tensor_name() != tensor_name) {
997       continue;
998     }
999     for (const auto &item : tensor_annotation.quant_parameter_tensor_names()) {
1000       if (!item.has_key() || !item.has_value()) {
1001         continue;
1002       }
1003 
1004       const auto &quant_tensor_name = item.value();
1005       if (item.key() == "SCALE_TENSOR") {
1006         auto status = CopyTensorQuantParam(quant_tensor_name, quant_param.get(), true);
1007         if (status != RET_OK) {
1008           MS_LOG(ERROR) << "quant param scale get failed";
1009           return status;
1010         }
1011       } else if (item.key() == "ZERO_POINT_TENSOR") {
1012         auto status = CopyTensorQuantParam(quant_tensor_name, quant_param.get(), false);
1013         if (status != RET_OK) {
1014           MS_LOG(ERROR) << "quant param zero_point get failed";
1015           return status;
1016         }
1017       }
1018     }
1019     break;
1020   }
1021   if (quant_param->inited) {
1022     quant_params->push_back(*std::move(quant_param));
1023     return RET_OK;
1024   }
1025   return SetTensorQuantParamFromNode(tensor_name, quant_params);
1026 }
1027 
SetTensorQuantParamFromNode(const std::string & tensor_name,std::vector<QuantParamT> * quant_params)1028 STATUS OnnxModelParser::SetTensorQuantParamFromNode(const std::string &tensor_name,
1029                                                     std::vector<QuantParamT> *quant_params) {
1030   MS_ASSERT(quant_params != nullptr);
1031   quant_params->clear();
1032   auto quant_param = std::make_unique<QuantParamT>();
1033   MS_CHECK_TRUE_MSG(quant_param != nullptr, RET_NULL_PTR, "create QuantParamT return nullptr");
1034   if (OnnxNodeParser::opset_version() <= 15) {
1035     quant_param->multiplier = 0;
1036   }
1037   std::string quant_tensor_name = "scale_" + tensor_name;
1038   auto status = CopyTensorQuantParam(quant_tensor_name, quant_param.get(), true);
1039   if (status != RET_OK) {
1040     MS_LOG(ERROR) << "quant param scale get failed";
1041     return status;
1042   }
1043   quant_tensor_name = "zero_point_" + tensor_name;
1044   status = CopyTensorQuantParam(quant_tensor_name, quant_param.get(), false);
1045   if (status != RET_OK) {
1046     MS_LOG(ERROR) << "quant param zero_point get failed";
1047     return status;
1048   }
1049   if (quant_param->inited) {
1050     quant_params->push_back(*std::move(quant_param));
1051   } else {
1052     std::vector<schema::QuantParamT> notinited_quant_params(1);
1053     *quant_params = notinited_quant_params;
1054   }
1055   return RET_OK;
1056 }
1057 
CopyTensorQuantParam(const std::string & tensor_name,QuantParamT * quant_param,bool scale_or_not)1058 STATUS OnnxModelParser::CopyTensorQuantParam(const std::string &tensor_name, QuantParamT *quant_param,
1059                                              bool scale_or_not) {
1060   MS_ASSERT(quant_param != nullptr);
1061   auto iter = anf_nodes_map_.find(tensor_name);
1062   if (iter == anf_nodes_map_.end()) {
1063     MS_LOG(DEBUG) << "has no quant param";
1064     return RET_OK;
1065   }
1066   if (!utils::isa<ParameterPtr>(iter->second)) {
1067     MS_LOG(ERROR) << "quant param get failed";
1068     return RET_ERROR;
1069   }
1070   auto quant_parameter_node = iter->second->cast<ParameterPtr>();
1071   MS_ASSERT(quant_parameter_node != nullptr);
1072   if (!quant_parameter_node->has_default()) {
1073     MS_LOG(ERROR) << "quant param get failed";
1074     return RET_ERROR;
1075   }
1076   auto tensor_info = quant_parameter_node->default_param()->cast<tensor::TensorPtr>();
1077   if (tensor_info == nullptr) {
1078     MS_LOG(ERROR) << "parameterNode's default param is not tensor::TensorPtr";
1079     return RET_ERROR;
1080   }
1081   if (tensor_info->data_c() == nullptr) {
1082     MS_LOG(ERROR) << "parameterNode's default param has no data";
1083     return RET_ERROR;
1084   }
1085   if (scale_or_not) {
1086     quant_param->scale = *reinterpret_cast<float *>(tensor_info->data_c());
1087     quant_param->inited = true;
1088   } else {
1089     quant_param->zeroPoint = *reinterpret_cast<int64_t *>(tensor_info->data_c());
1090     quant_param->inited = true;
1091   }
1092   return RET_OK;
1093 }
1094 
AddTensorListStackNode(const AnfNodePtr & root_while_node,const onnx::NodeProto & onnx_node,int act_outputs_num,int body_output_size)1095 STATUS OnnxModelParser::AddTensorListStackNode(const AnfNodePtr &root_while_node, const onnx::NodeProto &onnx_node,
1096                                                int act_outputs_num, int body_output_size) {
1097   MS_ASSERT(root_while_node != nullptr);
1098   auto &loop_node_name = onnx_node.name();
1099   auto root_anf_graph = root_while_node->func_graph();
1100   auto stack_elem_node = CreateConstParamter(root_anf_graph, -1);
1101   MS_CHECK_TRUE_MSG(stack_elem_node != nullptr, RET_NULL_PTR, "create const parameter return nullptr");
1102   stack_elem_node->set_name(loop_node_name + "_element_shape");
1103   for (int j = 0; j < act_outputs_num; j++) {
1104     auto output_size = onnx_node.output_size();
1105     auto &loop_output_name = onnx_node.output(output_size - act_outputs_num + j);
1106     MS_ASSERT(control_nodes_map_.find(loop_node_name) != control_nodes_map_.end());
1107     MS_ASSERT(control_nodes_map_[loop_node_name]->find(loop_output_name) != control_nodes_map_[loop_node_name]->end());
1108     auto &while_output_node = control_nodes_map_[loop_node_name]->at(loop_output_name);
1109     MS_CHECK_TRUE_MSG(while_output_node != nullptr, RET_ERROR, "can not find while_output_node is control_nodes_map");
1110     auto tensor_list_stack_prim = std::make_shared<ops::TensorListStack>();
1111     if (tensor_list_stack_prim == nullptr) {
1112       MS_LOG(ERROR) << "create stack failed";
1113       return RET_ERROR;
1114     }
1115     tensor_list_stack_prim->set_num_elements(-1);
1116     auto stack_value_node = NewValueNode(tensor_list_stack_prim);
1117     MS_CHECK_TRUE_MSG(stack_value_node != nullptr, RET_NULL_PTR, "create stack_value_node return nullptr");
1118     std::vector<AnfNodePtr> stack_inputs = {stack_value_node, while_output_node, stack_elem_node};
1119     auto tensorlist_stack_cnode = root_anf_graph->NewCNode(stack_inputs);
1120     if (tensorlist_stack_cnode == nullptr) {
1121       MS_LOG(ERROR) << "new cnode error";
1122       return RET_ERROR;
1123     }
1124     tensorlist_stack_cnode->set_fullname_with_scope(loop_node_name + "_tensorlist_stack_node_" + std::to_string(j));
1125     tensorlist_stack_cnode->set_abstract(stack_elem_node->abstract());
1126 
1127     // update getitem value output index
1128     auto new_get_item_value = NewValueNode(MakeValue<int>(body_output_size - act_outputs_num + j));
1129     MS_CHECK_TRUE_MSG(new_get_item_value != nullptr, RET_NULL_PTR, "create new_get_item_value return nullptr");
1130     MS_ASSERT(while_output_node->cast<CNodePtr>() != nullptr);
1131     while_output_node->cast<CNodePtr>()->set_input(2, new_get_item_value);
1132     // insert tensorliststack after while_output
1133     (*control_nodes_map_[loop_node_name])[loop_output_name] = tensorlist_stack_cnode;
1134   }
1135   return RET_OK;
1136 }
1137 
1138 // onnx loop scan_output need through tensorlist op,while node need add new inputs
AddTensorArrayEdge(const FuncGraphPtr & anf_graph,std::vector<AnfNodePtr> * return_new_inputs,const std::string & loop_node_name,std::vector<AnfNodePtr> * body_graph_inputs,int act_output_num)1139 STATUS OnnxModelParser::AddTensorArrayEdge(const FuncGraphPtr &anf_graph, std::vector<AnfNodePtr> *return_new_inputs,
1140                                            const std::string &loop_node_name,
1141                                            std::vector<AnfNodePtr> *body_graph_inputs, int act_output_num) {
1142   MS_ASSERT(anf_graph != nullptr && return_new_inputs != nullptr && body_graph_inputs != nullptr);
1143   // body graph output is  trip_count,cond_count,loop_var,placeholder,scan_outputs
1144   auto root_while_node = GetCNodeFromControlFlowNodesMap(loop_node_name, control_nodes_map_);
1145   MS_CHECK_TRUE_MSG(root_while_node != nullptr, RET_ERROR, "can not find root_while_node is control_nodes_map");
1146   if (root_while_node == nullptr) {
1147     MS_LOG(ERROR) << "anf root node map cannot find loop node" << loop_node_name;
1148     return RET_ERROR;
1149   }
1150   auto anf_root_graph = root_while_node->func_graph();
1151   auto root_item_index_parameter = CreateConstParamter(anf_root_graph, 0);
1152   MS_CHECK_TRUE_MSG(root_item_index_parameter != nullptr, RET_NULL_PTR,
1153                     "create root_item_index_parameter return nullptr");
1154   root_item_index_parameter->set_name(loop_node_name + "_item_index");
1155   root_while_node->add_input(root_item_index_parameter);
1156   // fake parameter need pass by root while node input
1157   auto item_index_parameter = anf_graph->add_parameter();
1158   MS_CHECK_TRUE_MSG(item_index_parameter != nullptr, RET_NULL_PTR, "create item_index_parameter return nullptr");
1159   item_index_parameter->set_name(loop_node_name + "_item_index_2");
1160   item_index_parameter->set_abstract(root_item_index_parameter->abstract());
1161   body_graph_inputs->emplace_back(item_index_parameter);
1162   // item index++ edge
1163   auto add_value_node = CreateValueNode(schema::PrimitiveType_AddFusion);
1164   if (add_value_node == nullptr) {
1165     MS_LOG(ERROR) << "create add failed.";
1166     return RET_NULL_PTR;
1167   }
1168   auto add_one_input = CreateConstParamter(anf_graph, 1);
1169   MS_CHECK_TRUE_MSG(root_item_index_parameter != nullptr, RET_NULL_PTR, "create add_one_input return nullptr");
1170   add_one_input->set_name(loop_node_name + "_const_placeholder_1");
1171   std::vector<AnfNodePtr> add_inputs = {add_value_node, item_index_parameter, add_one_input};
1172   auto add_cnode = anf_graph->NewCNode(add_inputs);
1173   if (add_cnode == nullptr) {
1174     MS_LOG(ERROR) << "new cnode error";
1175     return RET_ERROR;
1176   }
1177   add_cnode->set_fullname_with_scope(loop_node_name + "item_index_add_node");
1178   add_cnode->set_abstract(root_item_index_parameter->abstract());
1179   // return node inputs will be trip_count,cond_out,loop_var,placeholder,tensorarray...
1180   if (static_cast<int>(return_new_inputs->size()) < act_output_num || act_output_num < 0) {
1181     MS_LOG(ERROR) << "act_output_num out of range of return_new_inputs";
1182     return RET_ERROR;
1183   }
1184   return_new_inputs->insert(return_new_inputs->end() - act_output_num, add_cnode);
1185 
1186   for (int i = 0; i < act_output_num; i++) {
1187     // tensor_array need as root while input
1188     auto while_tensor_array_input = anf_root_graph->add_parameter();
1189     MS_CHECK_TRUE_MSG(while_tensor_array_input != nullptr, RET_NULL_PTR,
1190                       "create while_tensor_array_input return nullptr");
1191     std::vector<int> tensor_list_data(kTensorListDatasize);
1192     tensor_list_data[kTypeIndex] = kTypeUnknown;
1193     tensor_list_data[kElementShapeIndex] = 0;
1194     tensor_list_data[kTensorsNumIndex] = -1;
1195     if (INT_MUL_OVERFLOW_THRESHOLD(tensor_list_data.size(), sizeof(int), SIZE_MAX)) {
1196       MS_LOG(ERROR) << "data_size overflow";
1197       return RET_ERROR;
1198     }
1199     auto tensor_info = CreateTensorInfo(tensor_list_data.data(), tensor_list_data.size() * sizeof(int),
1200                                         {static_cast<int64_t>(tensor_list_data.size())}, kObjectTypeTensorType);
1201     if (tensor_info == nullptr) {
1202       MS_LOG(ERROR) << "Create tensor info failed";
1203       return RET_ERROR;
1204     }
1205     auto abstract_tensor = tensor_info->ToAbstract();
1206     if (abstract_tensor == nullptr) {
1207       MS_LOG(ERROR) << "Create tensor abstarct failed";
1208       return RET_ERROR;
1209     }
1210     while_tensor_array_input->set_abstract(abstract_tensor);
1211     while_tensor_array_input->set_default_param(tensor_info);
1212     while_tensor_array_input->set_name(loop_node_name + "_scan_outputs_tensorarray_while_input");
1213     root_while_node->add_input(while_tensor_array_input);
1214 
1215     auto subgraph_tensor_array_input = anf_graph->add_parameter();
1216     MS_CHECK_TRUE_MSG(subgraph_tensor_array_input != nullptr, RET_NULL_PTR,
1217                       "create subgraph_tensor_array_input return nullptr");
1218     subgraph_tensor_array_input->set_name(loop_node_name + "_scan_outputs_tensorarray_body_fg_input");
1219     subgraph_tensor_array_input->set_abstract(abstract_tensor);
1220     body_graph_inputs->emplace_back(subgraph_tensor_array_input);
1221     // skip trip_count ,cond_out,loop_var,no_loop_var,place_holder, output
1222     auto loop_output_idx = return_new_inputs->size() - act_output_num + i;
1223     auto loop_output_node = (*return_new_inputs)[loop_output_idx];
1224     auto set_item_value_node = CreateValueNode(schema::PrimitiveType_TensorListSetItem);
1225     if (set_item_value_node == nullptr) {
1226       MS_LOG(ERROR) << "create tensor list set item failed.";
1227       return RET_NULL_PTR;
1228     }
1229     std::vector<AnfNodePtr> set_item_inputs = {set_item_value_node, subgraph_tensor_array_input, item_index_parameter,
1230                                                loop_output_node};
1231     auto tensorlist_setitem_cnode = anf_graph->NewCNode(set_item_inputs);
1232     if (tensorlist_setitem_cnode == nullptr) {
1233       MS_LOG(ERROR) << "new cnode error";
1234       return RET_ERROR;
1235     }
1236     tensorlist_setitem_cnode->set_fullname_with_scope(loop_node_name + "_tensorlist_setitem_node");
1237     tensorlist_setitem_cnode->set_abstract(abstract_tensor);
1238     // loop output need replace by tensorliststack_output
1239     (*return_new_inputs)[loop_output_idx] = tensorlist_setitem_cnode;
1240   }
1241 
1242   return RET_OK;
1243 }
1244 
ConvertLoopOnnxNode(const onnx::NodeProto & onnx_node,std::unordered_map<std::string,AnfNodePtr> * anf_root_nodes_map,const std::string & root_node_name)1245 STATUS OnnxModelParser::ConvertLoopOnnxNode(const onnx::NodeProto &onnx_node,
1246                                             std::unordered_map<std::string, AnfNodePtr> *anf_root_nodes_map,
1247                                             const std::string &root_node_name) {
1248   MS_ASSERT(anf_root_nodes_map != nullptr);
1249   for (int i = 0; i < onnx_node.attribute_size(); i++) {
1250     auto &attr = onnx_node.attribute(i);
1251     if (attr.name() != "body" || attr.type() != onnx::AttributeProto_AttributeType_GRAPH) {
1252       continue;
1253     }
1254     auto &subgraph_proto = attr.g();
1255     int cond_graph_input_num = -1;
1256     auto loop_body_graph = BuildBodyGraph(onnx_node, subgraph_proto, &cond_graph_input_num);
1257     MS_CHECK_TRUE_MSG(loop_body_graph != nullptr, RET_NULL_PTR, "create loop_body_graph return nullptr");
1258     auto root_while_node = GetCNodeFromControlFlowNodesMap(onnx_node.name(), control_nodes_map_);
1259     MS_CHECK_TRUE_MSG(root_while_node != nullptr, RET_ERROR, "can not find root_while_node");
1260     auto loop_cond_graph = BuildCondGraph(root_while_node, cond_graph_input_num, onnx_node.name() + "_cond_graph");
1261     MS_CHECK_TRUE_MSG(loop_cond_graph != nullptr, RET_NULL_PTR, "create loop_cond_graph return nullptr");
1262     all_subgraphs_.emplace_back(loop_body_graph);
1263     all_subgraphs_.emplace_back(loop_cond_graph);
1264     auto body_value_node = NewValueNode(loop_body_graph);
1265     MS_CHECK_TRUE_MSG(body_value_node != nullptr, RET_NULL_PTR, "create body_value_node return nullptr");
1266     auto inputs = root_while_node->inputs();
1267     auto cond_value_node = NewValueNode(loop_cond_graph);
1268     MS_CHECK_TRUE_MSG(cond_value_node != nullptr, RET_NULL_PTR, "create cond_value_node return nullptr");
1269     inputs.insert(inputs.begin() + 1, {cond_value_node, body_value_node});
1270     root_while_node->set_inputs(inputs);
1271   }
1272   return RET_OK;
1273 }
1274 
BuildParameterNodeForQuantParam(const void * data,const std::string & name,TypeId type)1275 STATUS OnnxModelParser::BuildParameterNodeForQuantParam(const void *data, const std::string &name, TypeId type) {
1276   MS_ASSERT(data != nullptr);
1277   if (type != kNumberTypeInt64 && type != kNumberTypeFloat32) {
1278     MS_LOG(ERROR) << "quant param type don't support.";
1279     return RET_NOT_SUPPORT;
1280   }
1281   auto parameter_node = res_graph_->add_parameter();
1282   MS_CHECK_TRUE_MSG(parameter_node != nullptr, RET_NULL_PTR, "create parameter return nullptr");
1283   auto abstract_tensor = CreateTensorAbstract({}, type);
1284   if (abstract_tensor == nullptr) {
1285     MS_LOG(ERROR) << "Create tensor abstarct failed";
1286     return RET_ERROR;
1287   }
1288   parameter_node->set_abstract(abstract_tensor);
1289   parameter_node->set_name(name);
1290   int data_size = 0;
1291   if (type == kNumberTypeFloat32) {
1292     data_size = sizeof(float);
1293   } else {
1294     data_size = sizeof(int64_t);
1295   }
1296   auto tensor_info = CreateTensorInfo(data, data_size, {1}, type);
1297   if (tensor_info == nullptr) {
1298     MS_LOG(ERROR) << "create tensor info failed.";
1299     return RET_ERROR;
1300   }
1301   parameter_node->set_default_param(tensor_info);
1302   anf_nodes_map_.emplace(name, parameter_node);
1303   return RET_OK;
1304 }
1305 
// Registration hook for the ONNX front end: hands the REG_MODEL_PARSER macro
// the framework type kFmkTypeOnnx and a creator (LiteModelParserCreator) for
// OnnxModelParser. Presumably this registers OnnxModelParser as the parser the
// converter instantiates for ONNX models — confirm against the macro definition.
1306 REG_MODEL_PARSER(kFmkTypeOnnx, LiteModelParserCreator<OnnxModelParser>)
1307 }  // namespace lite
1308 }  // namespace mindspore
1309