1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "tools/converter/parser/onnx/onnx_model_parser.h"
18 #include <algorithm>
19 #include <memory>
20 #include <set>
21 #include <vector>
22 #include <unordered_map>
23 #include <utility>
24 #include "include/registry/node_parser_registry.h"
25 #include "tools/optimizer/common/gllo_utils.h"
26 #include "tools/common/graph_util.h"
27 #include "tools/common/protobuf_utils.h"
28 #include "tools/common/tensor_util.h"
29 #include "tools/converter/ops/ops_def.h"
30 #include "ops/tensor_list_stack.h"
31 #include "ir/func_graph.h"
32 #include "tools/converter/converter_flags.h"
33 #include "tools/converter/quant_param_holder.h"
34 #include "tools/converter/converter_context.h"
35 #include "tools/converter/parser/onnx/onnx_inputs_adjust.h"
36 #include "tools/converter/parser/onnx/onnx_pad_adjust.h"
37 #include "tools/converter/parser/parser_utils.h"
38 #include "tools/converter/parser/unify_format.h"
39 #include "nnacl/op_base.h"
40 #include "src/common/log_util.h"
41
using mindspore::converter::kFmkTypeOnnx;
namespace mindspore {
namespace lite {
namespace {
// Size and field indices of the 3-element tensor-list descriptor
// (element type, element shape, tensor count).
// NOTE(review): not referenced in this visible chunk — presumably used by the
// tensor-list handling further down the file; confirm before removing.
constexpr int kTensorListDatasize = 3;
constexpr int kTypeIndex = 0;
constexpr int kElementShapeIndex = 1;
constexpr int kTensorsNumIndex = 2;
50
Onnx2AnfAdjust(const std::set<FuncGraphPtr> & all_func_graphs)51 int Onnx2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs) {
52 for (const auto &func_graph : all_func_graphs) {
53 MS_ASSERT(func_graph != nullptr);
54 if (!OnnxInputAdjust::Adjust(func_graph)) {
55 MS_LOG(ERROR) << "onnx adjust failed.";
56 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
57 return RET_ERROR;
58 }
59 if (!OnnxPadAdjust::Adjust(func_graph)) {
60 MS_LOG(ERROR) << "onnx pad adjust failed.";
61 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
62 return RET_ERROR;
63 }
64 }
65 return RET_OK;
66 }
67
CreateConstParamter(const FuncGraphPtr & anf_graph,int val)68 ParameterPtr CreateConstParamter(const FuncGraphPtr &anf_graph, int val) {
69 MS_ASSERT(anf_graph != nullptr);
70 auto const_node = anf_graph->add_parameter();
71 MS_CHECK_TRUE_RET(const_node != nullptr, nullptr);
72 auto const_abstract = CreateTensorAbstract({}, kNumberTypeInt32);
73 if (const_abstract == nullptr) {
74 MS_LOG(ERROR) << "Create tensor abstarct failed";
75 return nullptr;
76 }
77 const_node->set_abstract(const_abstract);
78 int *tensor_data = new (std::nothrow) int[1];
79 if (tensor_data == nullptr) {
80 MS_LOG(ERROR) << "new int[] failed";
81 return nullptr;
82 }
83 tensor_data[0] = val;
84 auto tensor_info = CreateTensorInfo(tensor_data, sizeof(int), {1}, kNumberTypeInt32);
85 if (tensor_info == nullptr) {
86 MS_LOG(ERROR) << "create tensor info failed.";
87 delete[] tensor_data;
88 tensor_data = nullptr;
89 return nullptr;
90 }
91 tensor_data = nullptr;
92 delete[] tensor_data;
93 const_node->set_default_param(tensor_info);
94 return const_node;
95 }
96
CreateValueNode(const schema::PrimitiveType & op_type)97 ValueNodePtr CreateValueNode(const schema::PrimitiveType &op_type) {
98 auto node_type = schema::EnumNamePrimitiveType(op_type);
99 auto op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
100 if (op_primc_fns.find(node_type) == op_primc_fns.end()) {
101 MS_LOG(ERROR) << "have no func to create primitive.";
102 return nullptr;
103 }
104 auto prim = op_primc_fns[node_type]();
105 if (prim == nullptr) {
106 MS_LOG(ERROR) << "cannot create primitive.";
107 return nullptr;
108 }
109 return NewValueNode(prim);
110 }
111
AddIterNumsUpdateEdge(const FuncGraphPtr & anf_graph,std::vector<AnfNodePtr> * return_new_inputs,const std::unordered_map<std::string,AnfNodePtr> & anf_nodes_map,const std::string & trip_cout_name,const std::string & loop_node_name)112 STATUS AddIterNumsUpdateEdge(const FuncGraphPtr &anf_graph, std::vector<AnfNodePtr> *return_new_inputs,
113 const std::unordered_map<std::string, AnfNodePtr> &anf_nodes_map,
114 const std::string &trip_cout_name, const std::string &loop_node_name) {
115 MS_ASSERT(anf_graph != nullptr && return_new_inputs != nullptr);
116 // trip_cout need -1 after every iteration
117 auto sub_value_node = CreateValueNode(schema::PrimitiveType_SubFusion);
118 if (sub_value_node == nullptr) {
119 MS_LOG(ERROR) << "create sub failed.";
120 return RET_NULL_PTR;
121 }
122 auto trip_cout_paramter_iter = anf_nodes_map.find(trip_cout_name);
123 if (trip_cout_paramter_iter == anf_nodes_map.end()) {
124 MS_LOG(ERROR) << "can not find " << trip_cout_name;
125 return RET_ERROR;
126 }
127 auto &trip_cout_paramter = trip_cout_paramter_iter->second;
128 if (trip_cout_paramter == nullptr) {
129 MS_LOG(ERROR) << "trip_cout_paramter found failed";
130 return RET_ERROR;
131 }
132 auto const_one_parameter = CreateConstParamter(anf_graph, 1);
133 MS_CHECK_TRUE_MSG(const_one_parameter != nullptr, RET_ERROR, "create const parameter return nullptr");
134 const_one_parameter->set_name(loop_node_name + "_index_update_parameter");
135
136 std::vector<AnfNodePtr> sub_inputs = {sub_value_node, trip_cout_paramter, const_one_parameter};
137 auto sub_cnode = anf_graph->NewCNode(sub_inputs);
138 if (sub_cnode == nullptr) {
139 MS_LOG(ERROR) << "new cnode error";
140 return RET_ERROR;
141 }
142 sub_cnode->set_fullname_with_scope(loop_node_name + "_sub");
143 sub_cnode->set_abstract(trip_cout_paramter->abstract());
144 return_new_inputs->insert(return_new_inputs->begin() + 1, sub_cnode);
145 return RET_OK;
146 }
147
GetCNodeFromControlFlowNodesMap(const std::string & loop_node_name,const std::unordered_map<std::string,std::unordered_map<std::string,AnfNodePtr> * > & control_nodes_map)148 CNodePtr GetCNodeFromControlFlowNodesMap(
149 const std::string &loop_node_name,
150 const std::unordered_map<std::string, std::unordered_map<std::string, AnfNodePtr> *> &control_nodes_map) {
151 auto iter1 = control_nodes_map.find(loop_node_name);
152 if (iter1 == control_nodes_map.end()) {
153 return nullptr;
154 }
155 auto iter2 = iter1->second->find(loop_node_name);
156 if (iter2 == iter1->second->end()) {
157 return nullptr;
158 }
159 return iter2->second->cast<CNodePtr>();
160 }
161
BuildReturnNode(const FuncGraphPtr & anf_graph,const std::vector<AnfNodePtr> & return_inputs)162 STATUS BuildReturnNode(const FuncGraphPtr &anf_graph, const std::vector<AnfNodePtr> &return_inputs) {
163 MS_ASSERT(anf_graph != nullptr);
164 auto return_prim = std::make_shared<lite::Return>();
165 if (return_prim == nullptr) {
166 MS_LOG(ERROR) << "new Return failed";
167 return RET_NULL_PTR;
168 }
169 auto return_cnode = anf_graph->NewCNode(return_prim, return_inputs);
170 if (return_cnode == nullptr) {
171 MS_LOG(ERROR) << "new cnode error";
172 return RET_ERROR;
173 }
174 return_cnode->set_fullname_with_scope("Return");
175 anf_graph->set_return(return_cnode);
176 return RET_OK;
177 }
178
BuildParameterNode(const ParameterPtr & parameter_node,const onnx::TensorProto & tensor)179 STATUS BuildParameterNode(const ParameterPtr ¶meter_node, const onnx::TensorProto &tensor) {
180 MS_ASSERT(parameter_node != nullptr);
181 auto data_type = OnnxNodeParser::GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(tensor.data_type()));
182 if (data_type == kTypeUnknown) {
183 MS_LOG(ERROR) << "not support onnx data type " << static_cast<onnx::TensorProto_DataType>(tensor.data_type());
184 return RET_ERROR;
185 }
186 std::vector<int64_t> shape_vector(tensor.dims().begin(), tensor.dims().end());
187 auto abstract_tensor = CreateTensorAbstract(shape_vector, data_type);
188 if (abstract_tensor == nullptr) {
189 MS_LOG(ERROR) << "Create tensor abstarct failed";
190 return RET_ERROR;
191 }
192 parameter_node->set_abstract(abstract_tensor);
193 parameter_node->set_name(tensor.name());
194
195 auto tensor_info = std::make_shared<tensor::Tensor>(data_type, shape_vector);
196 MS_CHECK_TRUE_MSG(tensor_info != nullptr, RET_NULL_PTR, "create tensor_info return nullptr");
197 std::vector<int> shape;
198 std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape),
199 [](const int64_t &value) { return static_cast<int>(value); });
200 auto status = OnnxNodeParser::CopyOnnxTensorData(tensor, tensor_info);
201 if (status != RET_OK) {
202 MS_LOG(ERROR) << "copy data failed.";
203 return status;
204 }
205 parameter_node->set_default_param(tensor_info);
206 return RET_OK;
207 }
208
// Register the outputs of a converted onnx node in `anf_nodes_map`.
// Single-output nodes map the output name directly to `cnode`; multi-output
// nodes get one TupleGetItem cnode per output and the cnode itself receives
// an AbstractTuple. All abstracts are placeholder scalar float32 —
// presumably refined by later shape/type inference; TODO confirm.
STATUS BuildOpOutputs(const onnx::NodeProto &onnx_node, const FuncGraphPtr &anf_graph,
                      std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, const CNodePtr &cnode) {
  MS_ASSERT(anf_graph != nullptr && cnode != nullptr && anf_nodes_map != nullptr);
  if (onnx_node.output_size() == 1) {
    auto abstract_tensor = CreateTensorAbstract({}, kNumberTypeFloat32);
    if (abstract_tensor == nullptr) {
      MS_LOG(ERROR) << "Create tensor abstarct failed";
      return RET_ERROR;
    }
    cnode->set_abstract(abstract_tensor);
    anf_nodes_map->emplace(onnx_node.output(0), cnode);
  } else {
    AbstractBasePtrList abstract_list;
    int op_idx = 0;  // index of the output within the tuple
    for (const auto &output_name : onnx_node.output()) {
      auto abstract_tensor = CreateTensorAbstract({}, kNumberTypeFloat32);
      if (abstract_tensor == nullptr) {
        MS_LOG(ERROR) << "Create tensor abstarct failed";
        return RET_ERROR;
      }
      abstract_list.emplace_back(abstract_tensor);
      // Each onnx output becomes TupleGetItem(cnode, op_idx).
      auto tuple_get_item_prim_ptr = std::make_shared<lite::TupleGetItem>();
      if (tuple_get_item_prim_ptr == nullptr) {
        MS_LOG(ERROR) << "new TupleGetItem failed";
        return RET_NULL_PTR;
      }
      auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr);
      MS_CHECK_TRUE_MSG(tuple_get_item_prim != nullptr, RET_NULL_PTR, "create ValueNode return nullptr");
      auto get_item_value = NewValueNode(MakeValue<int>(op_idx));
      MS_CHECK_TRUE_MSG(get_item_value != nullptr, RET_NULL_PTR, "create ValueNode return nullptr");
      std::vector<AnfNodePtr> inputs{tuple_get_item_prim, cnode, get_item_value};
      CNodePtr get_item_cnode = anf_graph->NewCNode(inputs);
      if (get_item_cnode == nullptr) {
        MS_LOG(ERROR) << "new cnode error";
        return RET_ERROR;
      }
      get_item_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_getitem_" + std::to_string(op_idx));
      auto get_item_abstract = CreateTensorAbstract({}, kNumberTypeFloat32);
      if (get_item_abstract == nullptr) {
        MS_LOG(ERROR) << "Create tensor abstarct failed";
        return RET_ERROR;
      }
      get_item_cnode->set_abstract(get_item_abstract);
      anf_nodes_map->emplace(output_name, get_item_cnode);
      op_idx++;
    }
    // The producing cnode's abstract is the tuple of all output abstracts.
    auto new_abstract_list = std::make_shared<abstract::AbstractTuple>(abstract_list);
    CHECK_NULL_RETURN(new_abstract_list);
    cnode->set_abstract(new_abstract_list);
  }
  // Also map the node's own name to the cnode (used by control-flow lookups).
  anf_nodes_map->emplace(onnx_node.name(), cnode);
  return RET_OK;
}
262
ConvertConstTensors(const onnx::GraphProto & onnx_graph,const FuncGraphPtr & func_graph_ptr,std::unordered_map<std::string,AnfNodePtr> * anf_nodes_map)263 STATUS ConvertConstTensors(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr,
264 std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map) {
265 MS_ASSERT(func_graph_ptr != nullptr && anf_nodes_map != nullptr);
266 for (const auto &onnx_const_value : onnx_graph.initializer()) {
267 auto parameter = func_graph_ptr->add_parameter();
268 MS_CHECK_TRUE_MSG(parameter != nullptr, RET_NULL_PTR, "create parameter return nullptr");
269 auto status = BuildParameterNode(parameter, onnx_const_value);
270 if (status != RET_OK) {
271 MS_LOG(ERROR) << "parameter node build failed.";
272 return status;
273 }
274 anf_nodes_map->emplace(onnx_const_value.name(), parameter);
275 }
276 return RET_OK;
277 }
278
// Create a Parameter for every onnx graph input that was not already provided
// as an initializer. Input shapes may be overridden by user-supplied shapes
// from the ConverterContext; otherwise they come from the onnx value-info,
// with 0-sized dims rewritten to -1 (treated as dynamic).
STATUS ConvertGraphInputs(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr,
                          std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map) {
  MS_ASSERT(func_graph_ptr != nullptr && anf_nodes_map != nullptr);
  for (int i = 0; i < onnx_graph.input().size(); ++i) {
    const auto &input_value = onnx_graph.input(i);
    // Initializers also appear in graph.input(); those were already converted
    // by ConvertConstTensors, so skip names that are present in the map.
    if (anf_nodes_map->find(input_value.name()) != anf_nodes_map->end()) {
      continue;
    }
    auto parameter = func_graph_ptr->add_parameter();
    MS_CHECK_TRUE_MSG(parameter != nullptr, RET_NULL_PTR, "create parameter return nullptr");
    auto data_type = OnnxNodeParser::GetDataTypeFromOnnx(
      static_cast<onnx::TensorProto_DataType>(input_value.type().tensor_type().elem_type()));
    if (data_type == kTypeUnknown) {
      MS_LOG(ERROR) << "not support onnx data type "
                    << static_cast<onnx::TensorProto_DataType>(input_value.type().tensor_type().elem_type());
      return RET_ERROR;
    }
    // User-specified input shape takes precedence; warn (but continue with the
    // model's own shape) when shapes were supplied but this input has none.
    std::vector<int64_t> shape_vector = ConverterContext::GetInstance()->GetGraphInputTensorShape(input_value.name());
    if (ConverterContext::GetInstance()->GetGraphInputTensorShapeMapSize() > 0 && shape_vector.empty()) {
      MS_LOG(WARNING) << "Can not find name in map. name is " << input_value.name();
    }
    if (shape_vector.empty()) {
      auto onnx_shape = input_value.type().tensor_type().shape().dim();
      std::transform(onnx_shape.begin(), onnx_shape.end(), std::back_inserter(shape_vector),
                     [](const onnx::TensorShapeProto_Dimension &val) { return static_cast<int64_t>(val.dim_value()); });
      // dim_value() is 0 for symbolic/unknown dims; map those to -1 (dynamic).
      std::replace(shape_vector.begin(), shape_vector.end(), 0, -1);
    }
    auto abstract_tensor = CreateTensorAbstract(shape_vector, data_type);
    if (abstract_tensor == nullptr) {
      MS_LOG(ERROR) << "Create tensor abstarct failed";
      return RET_ERROR;
    }
    parameter->set_abstract(abstract_tensor);
    parameter->set_name(input_value.name());
    anf_nodes_map->emplace(input_value.name(), parameter);
  }
  return RET_OK;
}
317
ConvertGraphOutputs(const onnx::GraphProto & onnx_graph,const FuncGraphPtr & anf_graph,const std::unordered_map<std::string,AnfNodePtr> & anf_nodes_map)318 STATUS ConvertGraphOutputs(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &anf_graph,
319 const std::unordered_map<std::string, AnfNodePtr> &anf_nodes_map) {
320 MS_ASSERT(anf_graph != nullptr);
321 std::vector<AnfNodePtr> return_inputs;
322 if (onnx_graph.output_size() == 0) {
323 MS_LOG(ERROR) << "onnx graph has no output";
324 return RET_ERROR;
325 }
326 if (onnx_graph.output_size() > 1) {
327 std::vector<AnfNodePtr> make_tuple_inputs;
328 auto make_tuple_prim_ptr = std::make_shared<lite::MakeTuple>();
329 if (make_tuple_prim_ptr == nullptr) {
330 MS_LOG(ERROR) << "new MakeTuple failed";
331 return RET_NULL_PTR;
332 }
333 for (const auto &graph_out : onnx_graph.output()) {
334 if (anf_nodes_map.find(graph_out.name()) == anf_nodes_map.end()) {
335 MS_LOG(ERROR) << "graph output get failed.";
336 return RET_ERROR;
337 }
338 auto cnode = anf_nodes_map.at(graph_out.name());
339 if (cnode == nullptr) {
340 MS_LOG(ERROR) << "Can't find input node.";
341 return RET_NOT_FIND_OP;
342 }
343 make_tuple_inputs.emplace_back(cnode);
344 }
345 auto make_tuple_cnode = anf_graph->NewCNode(make_tuple_prim_ptr, make_tuple_inputs);
346 if (make_tuple_cnode == nullptr) {
347 MS_LOG(ERROR) << "new cnode error";
348 return RET_ERROR;
349 }
350 make_tuple_cnode->set_fullname_with_scope("return tuple");
351 return_inputs.emplace_back(make_tuple_cnode);
352 } else {
353 const auto &graph_out = onnx_graph.output(0);
354 if (anf_nodes_map.find(graph_out.name()) == anf_nodes_map.end()) {
355 MS_LOG(ERROR) << "graph output get failed.";
356 return RET_ERROR;
357 }
358 auto cnode = anf_nodes_map.at(graph_out.name());
359 if (cnode == nullptr) {
360 MS_LOG(ERROR) << "Can't find input node.";
361 return RET_NOT_FIND_OP;
362 }
363 return_inputs.emplace_back(cnode);
364 }
365 if (BuildReturnNode(anf_graph, return_inputs) != RET_OK) {
366 MS_LOG(ERROR) << "build return node failed.";
367 return RET_ERROR;
368 }
369 return RET_OK;
370 }
371
// Build the condition subgraph for the While created from an onnx Loop.
// The cond graph takes `inputs_num` int32 parameters mirroring the body-graph
// inputs; its result is (0 < input0) && input1, i.e. "trip count remaining AND
// loop condition still true", built while iterating over the inputs (Less at
// i == 0, LogicalAnd + Return at i == 1).
FuncGraphPtr BuildCondGraph(const AnfNodePtr &root_while_node, int inputs_num, const std::string &cond_graph_name) {
  MS_ASSERT(root_while_node != nullptr);
  auto cond_graph = std::make_shared<FuncGraph>();
  MS_CHECK_TRUE_MSG(cond_graph != nullptr, nullptr, "create cond_graph return nullptr");
  CNodePtr less_cnode = nullptr;  // set at i == 0, consumed at i == 1
  for (int i = 0; i < inputs_num; i++) {
    auto input_parameter = cond_graph->add_parameter();
    MS_CHECK_TRUE_MSG(input_parameter != nullptr, nullptr, "create input_parameter return nullptr");
    input_parameter->set_name(cond_graph_name + "_input_" + std::to_string(i) + "_parameter");
    auto input_abstract = CreateTensorAbstract({}, kNumberTypeInt32);
    if (input_abstract == nullptr) {
      MS_LOG(ERROR) << "Create tensor abstarct failed";
      return nullptr;
    }
    input_parameter->set_abstract(input_abstract);
    if (i == 0) {
      // Input 0 is the remaining trip count: build Less(0, trip_count).
      auto zero_parameter = CreateConstParamter(cond_graph, 0);
      MS_CHECK_TRUE_MSG(zero_parameter != nullptr, nullptr, "create zero_parameter return nullptr");
      zero_parameter->set_name(root_while_node->fullname_with_scope() + "_const_0");
      auto less_value_node = CreateValueNode(schema::PrimitiveType_Less);
      MS_CHECK_TRUE_MSG(less_value_node != nullptr, nullptr, "create less_value_node return nullptr");
      std::vector<AnfNodePtr> less_inputs = {less_value_node, zero_parameter, input_parameter};
      less_cnode = cond_graph->NewCNode(less_inputs);
      if (less_cnode == nullptr) {
        MS_LOG(ERROR) << "new cnode error";
        return nullptr;
      }
      auto less_abstract = CreateTensorAbstract({}, kNumberTypeBool);
      if (less_abstract == nullptr) {
        MS_LOG(ERROR) << "Create tensor abstarct failed";
        return nullptr;
      }
      less_cnode->set_abstract(less_abstract);
      less_cnode->set_fullname_with_scope(cond_graph_name + "_less_cnode");
    }
    if (i == 1) {
      // Input 1 is the loop condition flag: AND it with the Less result and
      // return the combined condition.
      auto and_value_node = CreateValueNode(schema::PrimitiveType_LogicalAnd);
      MS_CHECK_TRUE_MSG(and_value_node != nullptr, nullptr, "CreateValueNode failed");
      std::vector<AnfNodePtr> and_inputs = {and_value_node, less_cnode, input_parameter};
      auto and_cnode = cond_graph->NewCNode(and_inputs);
      if (and_cnode == nullptr) {
        MS_LOG(ERROR) << "new cnode error";
        return nullptr;
      }
      and_cnode->set_abstract(less_cnode->abstract());
      and_cnode->set_fullname_with_scope(cond_graph_name + "_output_" + std::to_string(0) + "_cnode");
      auto status = BuildReturnNode(cond_graph, {and_cnode});
      if (status != RET_OK) {
        MS_LOG(ERROR) << "build return node failed: " << status;
        return nullptr;
      }
    }
  }
  cond_graph->set_attr("graph_name", MakeValue(cond_graph_name));
  return cond_graph;
}
428 } // namespace
429
// Convert the onnx Loop body subgraph into the body FuncGraph of a While:
// converts the subgraph, splices captured (extra) inputs into the return
// tuple, adds the trip-count decrement edge, wires tensor-array edges and a
// TensorListStack for scan outputs, and renames all inputs/outputs to the
// "<loop>_body_graph_*" convention. On success *cond_graph_input_num receives
// the number of body-graph outputs (= cond graph inputs). Returns nullptr on
// any failure.
FuncGraphPtr OnnxModelParser::BuildBodyGraph(const onnx::NodeProto &loop_node, const onnx::GraphProto &subgraph_proto,
                                             int *cond_graph_input_num) {
  MS_ASSERT(cond_graph_input_num != nullptr);
  auto &loop_node_name = loop_node.name();
  auto node_inputs_num = loop_node.input_size();
  auto node_outputs_num = loop_node.output_size();
  // skip trip_cout and cond input,scan_output nums
  // (onnx Loop: inputs = [M, cond, carried...], outputs = [carried..., scan...],
  // so scan-output count = outputs - (inputs - 2).)
  auto act_outputs_num = node_outputs_num - (node_inputs_num - 2);
  auto loop_body_graph = std::make_shared<FuncGraph>();
  MS_CHECK_TRUE_MSG(loop_body_graph != nullptr, nullptr, "create loop_body_graph return nullptr");
  std::unordered_map<std::string, AnfNodePtr> anf_nodes_map;
  std::vector<AnfNodePtr> gen_subgraph_inputs;  // inputs captured from outer scope
  auto status = ConvertOnnxGraph(subgraph_proto, loop_body_graph, &anf_nodes_map, &gen_subgraph_inputs, loop_node_name);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "convert loop OnnxGraph: " << status;
    return nullptr;
  }
  auto return_node = loop_body_graph->get_return();
  MS_CHECK_TRUE_MSG(return_node != nullptr, nullptr, "return node of subgraph is nullptr");
  MS_ASSERT(return_node->inputs().size() == DIMENSION_2D);
  auto return_tuple_cnode = return_node->input(1)->cast<CNodePtr>();
  MS_ASSERT(return_tuple_cnode != nullptr);
  // Captured inputs are appended as extra outputs, placed before the scan
  // outputs so carried values keep their positions.
  auto return_new_inputs = return_tuple_cnode->inputs();
  return_new_inputs.insert(return_new_inputs.end() - act_outputs_num, gen_subgraph_inputs.begin(),
                           gen_subgraph_inputs.end());

  // Subgraph input 0 is the iteration counter; add the "count - 1" edge.
  std::string max_trip_count_name = subgraph_proto.input(0).name();
  status =
    AddIterNumsUpdateEdge(loop_body_graph, &return_new_inputs, anf_nodes_map, max_trip_count_name, loop_node_name);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "add iter nums update edge failed: " << status;
    return nullptr;
  }
  auto root_while_node = GetCNodeFromControlFlowNodesMap(loop_node_name, control_nodes_map_);
  MS_CHECK_TRUE_MSG(root_while_node != nullptr, nullptr, "can not find root_while_node is control_nodes_map");
  std::vector<AnfNodePtr> body_graph_inputs;
  body_graph_inputs.reserve(subgraph_proto.input_size());
  for (int j = 0; j < subgraph_proto.input_size(); j++) {
    body_graph_inputs.emplace_back(anf_nodes_map[subgraph_proto.input(j).name()]);
  }
  body_graph_inputs.insert(body_graph_inputs.end(), gen_subgraph_inputs.begin(), gen_subgraph_inputs.end());
  if (act_outputs_num != 0) {
    // Scan outputs are accumulated through tensor arrays in the body graph.
    status =
      AddTensorArrayEdge(loop_body_graph, &return_new_inputs, loop_node_name, &body_graph_inputs, act_outputs_num);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "add tensorarray update edge failed: " << status;
      return nullptr;
    }
    // insert tensorliststack after while output
    status = AddTensorListStackNode(root_while_node, loop_node, act_outputs_num, body_graph_inputs.size());
    if (status != RET_OK) {
      MS_LOG(ERROR) << "add tensorliststack node failed: " << status;
      return nullptr;
    }
  }
  return_tuple_cnode->set_inputs(return_new_inputs);
  auto body_graph_name = loop_node_name + "_body_graph";
  // Normalize input parameter names to the body-graph convention.
  for (size_t j = 0; j < body_graph_inputs.size(); j++) {
    MS_CHECK_TRUE_RET(body_graph_inputs[j] != nullptr, nullptr);
    auto body_input = body_graph_inputs[j]->cast<ParameterPtr>();
    MS_ASSERT(body_input != nullptr);
    body_input->set_name(body_graph_name + "_input_" + std::to_string(j) + "_parameter");
  }
  // Normalize output names (index 0 of return_new_inputs is the primitive).
  for (size_t j = 1; j < return_new_inputs.size(); j++) {
    if (utils::isa<CNodePtr>(return_new_inputs[j])) {
      return_new_inputs[j]->cast<CNodePtr>()->set_fullname_with_scope(body_graph_name + "_output_" +
                                                                      std::to_string(j - 1) + "_cnode");
    } else if (utils::isa<ParameterPtr>(return_new_inputs[j])) {
      return_new_inputs[j]->cast<ParameterPtr>()->set_name(body_graph_name + "_output_" + std::to_string(j - 1) +
                                                           "_parameter");
    }
  }
  *cond_graph_input_num = return_new_inputs.size() - 1;
  loop_body_graph->set_attr("graph_name", MakeValue(body_graph_name));
  return loop_body_graph;
}
506
507 namespace {
CheckOnnxModel(const onnx::GraphProto & onnx_graph)508 STATUS CheckOnnxModel(const onnx::GraphProto &onnx_graph) {
509 // all input should in initialize
510 std::set<std::string> providers;
511 for (const auto &const_tensor : onnx_graph.initializer()) {
512 const auto &name = const_tensor.name();
513 if (providers.count(name) != 0) {
514 MS_LOG(ERROR) << "const tensor repeated";
515 return RET_ERROR;
516 }
517 providers.insert(name);
518 }
519 for (int i = 0; i < onnx_graph.input().size(); ++i) {
520 providers.insert(onnx_graph.input(i).name());
521 }
522 for (const auto &onnx_node : onnx_graph.node()) {
523 for (int i = 0; i < onnx_node.output_size(); i++) {
524 auto &output = onnx_node.output(i);
525 if (providers.count(output) != 0) {
526 MS_LOG(ERROR) << "Output tensor repeated";
527 return RET_ERROR;
528 }
529 providers.insert(output);
530 }
531 }
532 // all output should find
533 for (const auto &onnx_node : onnx_graph.node()) {
534 for (int i = 0; i < onnx_node.input_size(); i++) {
535 auto &input = onnx_node.input(i);
536 if (providers.count(input) == 0) {
537 MS_LOG(WARNING) << "Can not find node input: " << input;
538 }
539 }
540 }
541 return RET_OK;
542 }
543 } // namespace
544
// Entry point: parse the .onnx file named in `flag`, convert it (and all
// control-flow subgraphs) to FuncGraphs, run the common and onnx-specific
// adjust passes, then unify the format to NHWC. Returns nullptr on failure,
// recording the error in the global return code.
api::FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag) {
  auto model_file = flag.model_file;
  NotSupportOp::GetInstance()->set_fmk_type("ONNX");
  res_graph_ = std::make_shared<FuncGraph>();
  MS_CHECK_TRUE_MSG(res_graph_ != nullptr, nullptr, "create FuncGraph failed");
  auto status = InitOriginModel(model_file);
  if (RET_OK != status) {
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
    MS_LOG(ERROR) << "init origin model failed.";
    return nullptr;
  }
  MS_ASSERT(onnx_root_graph_ != nullptr);

  auto func_graph = std::dynamic_pointer_cast<FuncGraph>(res_graph_);
  MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
  status = ConvertOnnxGraph(onnx_root_graph_, func_graph, &anf_nodes_map_, {}, "root_node");
  if (RET_OK != status) {
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
    MS_LOG(ERROR) << "convert onnx graph failed.";
    return nullptr;
  }
  // NOTE(review): function-local static — the manager created on the first
  // Parse call is reused for every later call; confirm this is intended if
  // the parser can be invoked more than once per process.
  static auto root_func_manager = Manage(func_graph);
  MS_ASSERT(root_func_manager != nullptr);
  // Attach all control-flow subgraphs to the same manager and tag their fmk.
  for (auto &subgraph : all_subgraphs_) {
    MS_ASSERT(subgraph != nullptr);
    subgraph->set_manager(root_func_manager);
    subgraph->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeOnnx)));
  }
  res_graph_->set_attr("graph_name", MakeValue("main_graph"));
  res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeOnnx)));
  std::set<FuncGraphPtr> all_func_graphs = {};
  GetAllFuncGraph(func_graph, &all_func_graphs);
  if ((status = CommonAnfAdjust(all_func_graphs)) != RET_OK) {
    MS_LOG(ERROR) << "AdjustForAnf failed.";
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
    return nullptr;
  }
  if ((status = Onnx2AnfAdjust(all_func_graphs)) != RET_OK) {
    MS_LOG(ERROR) << "Onnx2AnfAdjust failed.";
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
    return nullptr;
  }
  auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeOnnx, false);
  MS_CHECK_TRUE_MSG(unify_format != nullptr, nullptr, "create unify_format return nullptr");
  if (!unify_format->Run(func_graph)) {
    MS_LOG(ERROR) << "Run insert transpose failed.";
    return nullptr;
  }
  return res_graph_;
}
595
InitOriginModel(const std::string & model_file)596 STATUS OnnxModelParser::InitOriginModel(const std::string &model_file) {
597 MS_ASSERT(res_graph_ != nullptr);
598 auto status = ValidateFileStr(model_file, ".onnx");
599 if (status != RET_OK) {
600 MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.onnx";
601 return status;
602 }
603
604 status = ReadProtoFromBinaryFile(model_file, &onnx_model_);
605 if (status != RET_OK) {
606 MS_LOG(ERROR) << "Read onnx model file failed, model path: " << model_file;
607 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
608 return status;
609 }
610 MS_ASSERT(onnx_model_ != nullptr);
611 OnnxNodeParser::set_opset_version(onnx_model_.opset_import().Get(0).version());
612 onnx_root_graph_ = onnx_model_.graph();
613 auto fmk_value_node = MakeValue(static_cast<int>(converter::kFmkTypeOnnx));
614 CHECK_NULL_RETURN(fmk_value_node);
615 res_graph_->set_attr("fmk", fmk_value_node);
616 return RET_OK;
617 }
618
ConvertOnnxGraph(const onnx::GraphProto & onnx_graph,const FuncGraphPtr & anf_graph,std::unordered_map<std::string,AnfNodePtr> * anf_nodes_map,std::vector<AnfNodePtr> * extra_subgraph_inputs,const std::string & root_node_name)619 STATUS OnnxModelParser::ConvertOnnxGraph(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &anf_graph,
620 std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
621 std::vector<AnfNodePtr> *extra_subgraph_inputs,
622 const std::string &root_node_name) {
623 MS_ASSERT(anf_graph != nullptr && anf_nodes_map != nullptr && extra_subgraph_inputs != nullptr);
624 STATUS status = RET_OK;
625 status = CheckOnnxModel(onnx_graph);
626 if (status != RET_OK) {
627 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
628 MS_LOG(ERROR) << "input onnx model error: " << status;
629 return status;
630 }
631 status = ConvertConstTensors(onnx_graph, anf_graph, anf_nodes_map);
632 if (RET_OK != status) {
633 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
634 MS_LOG(ERROR) << "convert const nodes failed.";
635 return RET_ERROR;
636 }
637
638 status = ConvertGraphInputs(onnx_graph, anf_graph, anf_nodes_map);
639 if (RET_OK != status) {
640 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
641 MS_LOG(ERROR) << "convert graph inputs failed.";
642 return RET_OK;
643 }
644
645 status = ConvertNodes(onnx_graph, anf_graph, anf_nodes_map, extra_subgraph_inputs, root_node_name);
646 if (RET_OK != status) {
647 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
648 MS_LOG(ERROR) << "convert nodes failed.";
649 return RET_ERROR;
650 }
651
652 status = ConvertGraphOutputs(onnx_graph, anf_graph, *anf_nodes_map);
653 if (RET_OK != status) {
654 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
655 MS_LOG(ERROR) << "convert graph outputs failed.";
656 return RET_ERROR;
657 }
658 // save original output tensor names.
659 if (root_node_name == "root_node") {
660 std::vector<std::string> output_names;
661 std::transform(onnx_graph.output().begin(), onnx_graph.output().end(), std::back_inserter(output_names),
662 [](auto &graph_output) { return graph_output.name(); });
663 ConverterContext::GetInstance()->SetGraphOutputTensorNames(output_names);
664 }
665 return status;
666 }
667
ConvertNodes(const onnx::GraphProto & onnx_graph,const FuncGraphPtr & anf_graph,std::unordered_map<std::string,AnfNodePtr> * anf_nodes_map,std::vector<AnfNodePtr> * graph_inputs,const std::string & root_node_name)668 STATUS OnnxModelParser::ConvertNodes(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &anf_graph,
669 std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
670 std::vector<AnfNodePtr> *graph_inputs, const std::string &root_node_name) {
671 MS_ASSERT(anf_graph != nullptr && anf_nodes_map != nullptr && extra_subgraph_inputs != nullptr);
672 STATUS status = RET_OK;
673 for (const auto &onnx_node : onnx_graph.node()) {
674 ops::PrimitiveC *primitive_c = nullptr;
675 auto node_parser = registry::NodeParserRegistry::GetNodeParser(kFmkTypeOnnx, onnx_node.op_type());
676 if (node_parser != nullptr) {
677 primitive_c = node_parser->Parse(onnx_graph, onnx_node);
678 } else {
679 auto node_parser_builtin = OnnxNodeParserRegistry::GetInstance().GetNodeParser(onnx_node.op_type());
680 if (node_parser_builtin == nullptr) {
681 NotSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
682 status = status == RET_OK ? RET_NOT_FIND_OP : status;
683 MS_LOG(ERROR) << "not support onnx data type " << onnx_node.op_type();
684 }
685 if (status != RET_OK) {
686 continue;
687 }
688 MS_LOG(INFO) << "parse op:" << onnx_node.op_type();
689 primitive_c = node_parser_builtin->Parse(onnx_graph, onnx_node);
690 }
691
692 if (primitive_c == nullptr) {
693 MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed.";
694 status = RET_ERROR;
695 continue;
696 }
697 if (primitive_c->GetAttr(ops::kOriginalFormat) == nullptr) {
698 primitive_c->AddAttr(mindspore::ops::kOriginalFormat, MakeValue<int64_t>(NCHW));
699 }
700 status = ConvertOpQuantParams(onnx_node, primitive_c);
701 if (status != RET_OK) {
702 MS_LOG(ERROR) << "convert " << onnx_node.op_type() << " quant param failed.";
703 continue;
704 }
705 // build CNode
706 status = BuildCNode(onnx_node, anf_graph, anf_nodes_map, graph_inputs, primitive_c, root_node_name);
707 if (status != RET_OK) {
708 MS_LOG(ERROR) << "build cnode " << onnx_node.op_type() << " failed.";
709 }
710
711 if (onnx_node.op_type() == "Loop") {
712 child_root_map_[onnx_node.name()] = root_node_name;
713 control_nodes_map_[onnx_node.name()] = anf_nodes_map;
714
715 status = ConvertLoopOnnxNode(onnx_node, anf_nodes_map, root_node_name);
716 if (status != RET_OK) {
717 MS_LOG(ERROR) << "build loop node failed.";
718 }
719 }
720 if (onnx_node.op_type() == "If") {
721 child_root_map_[onnx_node.name()] = root_node_name;
722 control_nodes_map_[onnx_node.name()] = anf_nodes_map;
723
724 status = ConvertIfOnnxNode(onnx_node, anf_nodes_map, root_node_name);
725 if (status != RET_OK) {
726 MS_LOG(ERROR) << "build if node failed.";
727 }
728 }
729 }
730 return status;
731 }
732
ConvertIfSubgraph(const onnx::GraphProto & subgraph_proto,const FuncGraphPtr & subgraph,const std::string & subgraph_name,const std::string & if_node_name,const std::string & root_node_name)733 STATUS OnnxModelParser::ConvertIfSubgraph(const onnx::GraphProto &subgraph_proto, const FuncGraphPtr &subgraph,
734 const std::string &subgraph_name, const std::string &if_node_name,
735 const std::string &root_node_name) {
736 MS_ASSERT(subgraph != nullptr);
737 std::unordered_map<std::string, AnfNodePtr> anf_nodes_map;
738 std::vector<AnfNodePtr> subgraph_extra_inputs;
739 auto status = ConvertOnnxGraph(subgraph_proto, subgraph, &anf_nodes_map, &subgraph_extra_inputs, if_node_name);
740 if (status != RET_OK) {
741 MS_LOG(ERROR) << "convert loop OnnxGraph failed";
742 return status;
743 }
744 subgraph->set_attr("graph_name", MakeValue(subgraph_name));
745 // update subgraph in out name
746 for (int j = 0; j < subgraph_proto.input_size(); j++) {
747 auto input_anode_iter = anf_nodes_map.find(subgraph_proto.input(j).name());
748 if (input_anode_iter == anf_nodes_map.end()) {
749 MS_LOG(ERROR) << "can not find input anode";
750 return RET_ERROR;
751 }
752 auto input_parameter = input_anode_iter->second->cast<ParameterPtr>();
753 MS_CHECK_TRUE_MSG(input_parameter != nullptr, RET_ERROR, "subgraph input should be a parameter");
754 input_parameter->set_name(subgraph_name + "_input_" + std::to_string(j) + "_parameter");
755 }
756 for (size_t j = 0; j < subgraph_extra_inputs.size(); j++) {
757 auto input_parameter = subgraph_extra_inputs[j]->cast<ParameterPtr>();
758 MS_CHECK_TRUE_MSG(input_parameter != nullptr, RET_ERROR, "subgraph input should be a parameter");
759 input_parameter->set_name(subgraph_name + "_input_" + std::to_string(j + subgraph_proto.input_size()) +
760 "_parameter");
761 }
762 auto return_node = subgraph->get_return();
763 MS_CHECK_TRUE_MSG(return_node != nullptr, RET_ERROR, "subgraph has no return");
764 std::vector<AnfNodePtr> return_act_inputs;
765 int start_index = 0;
766 if (subgraph_proto.output_size() > 1) {
767 auto return_cnode = return_node->input(1)->cast<CNodePtr>();
768 MS_ASSERT(return_cnode != nullptr);
769 return_act_inputs = return_cnode->inputs();
770 start_index = 1;
771 } else {
772 return_act_inputs = {return_node->input(1)};
773 }
774 for (size_t j = start_index; j < return_act_inputs.size(); j++) {
775 if (utils::isa<CNodePtr>(return_act_inputs[j])) {
776 return_act_inputs[j]->cast<CNodePtr>()->set_fullname_with_scope(subgraph_name + "_output_" +
777 std::to_string(j - start_index) + "_cnode");
778 } else if (utils::isa<ParameterPtr>(return_act_inputs[j])) {
779 return_act_inputs[j]->cast<ParameterPtr>()->set_name(subgraph_name + "_output_" +
780 std::to_string(j - start_index) + "_parameter");
781 }
782 }
783 return RET_OK;
784 }
785
ConvertIfOnnxNode(const onnx::NodeProto & onnx_node,std::unordered_map<std::string,AnfNodePtr> * anf_root_nodes_map,const std::string & root_node_name)786 STATUS OnnxModelParser::ConvertIfOnnxNode(const onnx::NodeProto &onnx_node,
787 std::unordered_map<std::string, AnfNodePtr> *anf_root_nodes_map,
788 const std::string &root_node_name) {
789 MS_ASSERT(anf_root_nodes_map != nullptr);
790 FuncGraphPtr then_branch_graph = nullptr;
791 FuncGraphPtr else_branch_graph = nullptr;
792 FuncGraphPtr subgraph = nullptr;
793 std::string subgraph_name;
794 auto &if_node_name = onnx_node.name();
795
796 for (int i = 0; i < onnx_node.attribute_size(); i++) {
797 auto &attr = onnx_node.attribute(i);
798 auto &subgraph_proto = attr.g();
799 if (attr.name().find("then_branch") != std::string::npos) {
800 subgraph_name = if_node_name + "_then_branch";
801 then_branch_graph = std::make_shared<FuncGraph>();
802 MS_CHECK_TRUE_MSG(then_branch_graph != nullptr, RET_NULL_PTR, "create then_branch_graph return nullptr");
803 auto status = ConvertIfSubgraph(subgraph_proto, then_branch_graph, subgraph_name, if_node_name, root_node_name);
804 if (status != RET_OK) {
805 MS_LOG(ERROR) << "build if node else branch failed.";
806 }
807 } else if (attr.name().find("else_branch") != std::string::npos) {
808 subgraph_name = if_node_name + "_else_branch";
809 else_branch_graph = std::make_shared<FuncGraph>();
810 MS_CHECK_TRUE_MSG(else_branch_graph != nullptr, RET_NULL_PTR, "create else_branch_graph return nullptr");
811 auto status = ConvertIfSubgraph(subgraph_proto, else_branch_graph, subgraph_name, if_node_name, root_node_name);
812 if (status != RET_OK) {
813 MS_LOG(ERROR) << "build if node else branch failed.";
814 }
815 } else {
816 continue;
817 }
818 }
819 all_subgraphs_.emplace_back(then_branch_graph);
820 all_subgraphs_.emplace_back(else_branch_graph);
821 auto then_value_node = NewValueNode(then_branch_graph);
822 MS_CHECK_TRUE_MSG(then_value_node != nullptr, RET_NULL_PTR, "create then_value_node return nullptr");
823 auto else_value_node = NewValueNode(else_branch_graph);
824 MS_CHECK_TRUE_MSG(else_value_node != nullptr, RET_NULL_PTR, "create else_value_node return nullptr");
825 auto root_if_node = GetCNodeFromControlFlowNodesMap(if_node_name, control_nodes_map_);
826 MS_CHECK_TRUE_MSG(root_if_node != nullptr, RET_ERROR, "can not find root_if_node is control_nodes_map");
827 auto if_new_inputs = root_if_node->inputs();
828 if_new_inputs.insert(if_new_inputs.begin() + 1, {then_value_node, else_value_node});
829
830 std::vector<AnfNodePtr> if_new_input_not_same{};
831 std::set<AnfNodePtr> if_set{};
832 for (auto &input : if_new_inputs) {
833 if (if_set.find(input) != if_set.end()) {
834 continue;
835 }
836 if_new_input_not_same.push_back(input);
837 if_set.insert(input);
838 }
839
840 root_if_node->set_inputs(if_new_input_not_same);
841 return RET_OK;
842 }
843
// Build the CNode for `onnx_node` inside `anf_graph`.
// Inputs are resolved from `anf_nodes_map`; an input name that is absent from the map is
// assumed to be an outer-scope reference from an enclosing control-flow graph, and is
// resolved by walking up the chain of parent graphs (child_root_map_ / control_nodes_map_),
// threading a new parameter through every control-flow node on the way.
// `loop_name` is the name of the enclosing control-flow node ("root_node" at top level).
STATUS OnnxModelParser::BuildCNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &anf_graph,
                                   std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
                                   std::vector<AnfNodePtr> *graph_inputs, ops::PrimitiveC *primitive_c,
                                   std::string loop_name) {
  MS_ASSERT(anf_graph != nullptr && anf_nodes_map != nullptr && graph_inputs != nullptr && primitive_c != nullptr);
  std::vector<AnfNodePtr> op_inputs;
  for (const auto &input_name : onnx_node.input()) {
    // Empty names denote omitted optional inputs in ONNX; skip them.
    if (input_name.empty()) {
      continue;
    }

    if (anf_nodes_map->find(input_name) != anf_nodes_map->end()) {
      op_inputs.push_back(anf_nodes_map->at(input_name));
    } else {
      // subgraph may refer root graph nodes
      // Control-flow nodes between this graph and the graph that defines the input;
      // each of them must also receive the value as an extra input/parameter.
      std::vector<CNodePtr> need_add_input_nodes;
      auto ext_subgraph_input = anf_graph->add_parameter();
      MS_CHECK_TRUE_MSG(ext_subgraph_input != nullptr, RET_NULL_PTR, "create parameter return nullptr");
      ParameterPtr inner_extra_paramter = nullptr;
      // Walk up the enclosing-graph chain until the defining scope is found.
      while (!loop_name.empty() && child_root_map_.find(loop_name) != child_root_map_.end()) {
        auto cur_node_map = control_nodes_map_[loop_name];
        CHECK_NULL_RETURN(cur_node_map);
        if (cur_node_map->find(input_name) != cur_node_map->end()) {
          auto outside_input_node = cur_node_map->at(input_name);
          CHECK_NULL_RETURN(outside_input_node);
          // copy outside input parameter value to inside subgraph
          ext_subgraph_input->set_abstract(outside_input_node->abstract());
          ext_subgraph_input->set_name(input_name);
          if (outside_input_node->isa<Parameter>()) {
            // Constant/weight parameter: duplicate its tensor data into this subgraph.
            auto parameter = outside_input_node->cast<ParameterPtr>();
            if (!parameter->has_default()) {
              MS_LOG(ERROR) << "outside_input_node should has data.";
              return RET_ERROR;
            }
            auto tensor_info = parameter->default_param()->cast<tensor::TensorPtr>();
            auto copy_tensor_info = CreateTensorInfo(tensor_info->data_c(), tensor_info->Size(), tensor_info->shape(),
                                                     tensor_info->data_type());
            if (copy_tensor_info == nullptr) {
              MS_LOG(ERROR) << "memcpy failed.";
              return RET_ERROR;
            }
            ext_subgraph_input->set_default_param(copy_tensor_info);
          } else {
            // output inside cnode need make extra input
            graph_inputs->emplace_back(ext_subgraph_input);
            // The defining scope's control-flow node feeds the value in from outside.
            if (cur_node_map->find(loop_name) != cur_node_map->end()) {
              CHECK_NULL_RETURN(cur_node_map->at(loop_name));
              auto control_node = cur_node_map->at(loop_name)->cast<CNodePtr>();
              MS_ASSERT(control_node != nullptr);
              control_node->add_input(outside_input_node);
            } else {
              MS_LOG(ERROR) << "loop node: " << loop_name << " not found in cur node map.";
              return RET_ERROR;
            }
            // Thread a matching parameter through every intermediate control-flow node.
            for (auto &control_node : need_add_input_nodes) {
              CHECK_NULL_RETURN(control_node);
              auto func_graph = control_node->func_graph();
              auto extra_input_parameter = func_graph->add_parameter();
              MS_CHECK_TRUE_MSG(extra_input_parameter != nullptr, RET_NULL_PTR, "create parameter return nullptr");
              extra_input_parameter->set_name(input_name);
              extra_input_parameter->set_abstract(outside_input_node->abstract());
              control_node->add_input(extra_input_parameter);
            }
          }
          op_inputs.push_back(ext_subgraph_input);
          // Cache so later references in this subgraph reuse the same parameter.
          anf_nodes_map->emplace(input_name, ext_subgraph_input);
          break;
        } else {
          // Not defined in this scope: remember its control-flow node and go one level up.
          if (cur_node_map->find(loop_name) != cur_node_map->end()) {
            CHECK_NULL_RETURN(cur_node_map->at(loop_name));
            need_add_input_nodes.emplace_back(cur_node_map->at(loop_name)->cast<CNodePtr>());
          } else {
            MS_LOG(ERROR) << "loop node: " << loop_name << " not found in cur node map.";
            return RET_ERROR;
          }
          loop_name = child_root_map_[loop_name];
        }
      }
    }
  }
  auto new_cnode = anf_graph->NewCNode(std::shared_ptr<ops::PrimitiveC>(primitive_c), op_inputs);
  if (new_cnode == nullptr) {
    MS_LOG(ERROR) << "new cnode error";
    return RET_ERROR;
  }
  new_cnode->set_fullname_with_scope(onnx_node.op_type() + "_" + onnx_node.output(0));
  // Register the node's outputs into anf_nodes_map (tuple-getitem nodes for multi-output ops).
  auto status = BuildOpOutputs(onnx_node, anf_graph, anf_nodes_map, new_cnode);
  return status;
}
933
ConvertOpQuantParams(const onnx::NodeProto & onnx_node,ops::PrimitiveC * primitive_c)934 STATUS OnnxModelParser::ConvertOpQuantParams(const onnx::NodeProto &onnx_node, ops::PrimitiveC *primitive_c) {
935 MS_ASSERT(primitive_c != nullptr);
936 auto status = ParseQuantParam(onnx_node);
937 if (status != RET_OK) {
938 MS_LOG(ERROR) << "parse quant param failed.";
939 return RET_ERROR;
940 }
941 // set input tensors
942 auto quant_params_holder = std::make_shared<QuantParamHolder>(onnx_node.input_size(), onnx_node.output_size());
943 MS_CHECK_TRUE_MSG(quant_params_holder != nullptr, RET_NULL_PTR, "create QuantParamHolder return nullptr");
944 for (int i = 0; i < onnx_node.input_size(); ++i) {
945 const auto &input_name = onnx_node.input(i);
946 std::vector<schema::QuantParamT> quant_params;
947 status = SetTensorQuantParam(input_name, &quant_params);
948 if (status != RET_OK) {
949 MS_LOG(ERROR) << "set input tensor quant param failed.";
950 return status;
951 }
952 quant_params_holder->set_input_quant_param(i, quant_params);
953 }
954 // set out tensors
955 for (int i = 0; i < onnx_node.output_size(); ++i) {
956 const auto &output_name = onnx_node.output(i);
957 std::vector<schema::QuantParamT> quant_params;
958 status = SetTensorQuantParam(output_name, &quant_params);
959 if (status != RET_OK) {
960 MS_LOG(ERROR) << "set output tensor quant param failed.";
961 return status;
962 }
963 quant_params_holder->set_output_quant_param(i, quant_params);
964 }
965 primitive_c->AddAttr("quant_params", quant_params_holder);
966 return RET_OK;
967 }
968
ParseQuantParam(const onnx::NodeProto & onnx_node)969 STATUS OnnxModelParser::ParseQuantParam(const onnx::NodeProto &onnx_node) {
970 for (const auto &onnx_node_attr : onnx_node.attribute()) {
971 if (onnx_node_attr.name() == "Y_scale") {
972 float scale = onnx_node_attr.f();
973 if (BuildParameterNodeForQuantParam(&scale, "scale_" + onnx_node.output(0), kNumberTypeFloat32) != RET_OK) {
974 MS_LOG(ERROR) << "parse quant param failed.";
975 return RET_ERROR;
976 }
977 } else if (onnx_node_attr.name() == "Y_zero_point") {
978 int64_t zero_point = onnx_node_attr.i();
979 if (BuildParameterNodeForQuantParam(&zero_point, "zero_point_" + onnx_node.output(0), kNumberTypeInt64) !=
980 RET_OK) {
981 MS_LOG(ERROR) << "parse quant param failed.";
982 return RET_ERROR;
983 }
984 }
985 }
986 return RET_OK;
987 }
988
SetTensorQuantParam(const std::string & tensor_name,std::vector<QuantParamT> * quant_params)989 STATUS OnnxModelParser::SetTensorQuantParam(const std::string &tensor_name, std::vector<QuantParamT> *quant_params) {
990 MS_ASSERT(quant_params != nullptr);
991 quant_params->clear();
992 auto quant_param = std::make_unique<QuantParamT>();
993 MS_CHECK_TRUE_MSG(quant_param != nullptr, RET_NULL_PTR, "create QuantParamT return nullptr");
994 for (int i = 0; i < onnx_root_graph_.quantization_annotation_size(); ++i) {
995 auto tensor_annotation = onnx_root_graph_.quantization_annotation(i);
996 if (!tensor_annotation.has_tensor_name() || tensor_annotation.tensor_name() != tensor_name) {
997 continue;
998 }
999 for (const auto &item : tensor_annotation.quant_parameter_tensor_names()) {
1000 if (!item.has_key() || !item.has_value()) {
1001 continue;
1002 }
1003
1004 const auto &quant_tensor_name = item.value();
1005 if (item.key() == "SCALE_TENSOR") {
1006 auto status = CopyTensorQuantParam(quant_tensor_name, quant_param.get(), true);
1007 if (status != RET_OK) {
1008 MS_LOG(ERROR) << "quant param scale get failed";
1009 return status;
1010 }
1011 } else if (item.key() == "ZERO_POINT_TENSOR") {
1012 auto status = CopyTensorQuantParam(quant_tensor_name, quant_param.get(), false);
1013 if (status != RET_OK) {
1014 MS_LOG(ERROR) << "quant param zero_point get failed";
1015 return status;
1016 }
1017 }
1018 }
1019 break;
1020 }
1021 if (quant_param->inited) {
1022 quant_params->push_back(*std::move(quant_param));
1023 return RET_OK;
1024 }
1025 return SetTensorQuantParamFromNode(tensor_name, quant_params);
1026 }
1027
SetTensorQuantParamFromNode(const std::string & tensor_name,std::vector<QuantParamT> * quant_params)1028 STATUS OnnxModelParser::SetTensorQuantParamFromNode(const std::string &tensor_name,
1029 std::vector<QuantParamT> *quant_params) {
1030 MS_ASSERT(quant_params != nullptr);
1031 quant_params->clear();
1032 auto quant_param = std::make_unique<QuantParamT>();
1033 MS_CHECK_TRUE_MSG(quant_param != nullptr, RET_NULL_PTR, "create QuantParamT return nullptr");
1034 if (OnnxNodeParser::opset_version() <= 15) {
1035 quant_param->multiplier = 0;
1036 }
1037 std::string quant_tensor_name = "scale_" + tensor_name;
1038 auto status = CopyTensorQuantParam(quant_tensor_name, quant_param.get(), true);
1039 if (status != RET_OK) {
1040 MS_LOG(ERROR) << "quant param scale get failed";
1041 return status;
1042 }
1043 quant_tensor_name = "zero_point_" + tensor_name;
1044 status = CopyTensorQuantParam(quant_tensor_name, quant_param.get(), false);
1045 if (status != RET_OK) {
1046 MS_LOG(ERROR) << "quant param zero_point get failed";
1047 return status;
1048 }
1049 if (quant_param->inited) {
1050 quant_params->push_back(*std::move(quant_param));
1051 } else {
1052 std::vector<schema::QuantParamT> notinited_quant_params(1);
1053 *quant_params = notinited_quant_params;
1054 }
1055 return RET_OK;
1056 }
1057
CopyTensorQuantParam(const std::string & tensor_name,QuantParamT * quant_param,bool scale_or_not)1058 STATUS OnnxModelParser::CopyTensorQuantParam(const std::string &tensor_name, QuantParamT *quant_param,
1059 bool scale_or_not) {
1060 MS_ASSERT(quant_param != nullptr);
1061 auto iter = anf_nodes_map_.find(tensor_name);
1062 if (iter == anf_nodes_map_.end()) {
1063 MS_LOG(DEBUG) << "has no quant param";
1064 return RET_OK;
1065 }
1066 if (!utils::isa<ParameterPtr>(iter->second)) {
1067 MS_LOG(ERROR) << "quant param get failed";
1068 return RET_ERROR;
1069 }
1070 auto quant_parameter_node = iter->second->cast<ParameterPtr>();
1071 MS_ASSERT(quant_parameter_node != nullptr);
1072 if (!quant_parameter_node->has_default()) {
1073 MS_LOG(ERROR) << "quant param get failed";
1074 return RET_ERROR;
1075 }
1076 auto tensor_info = quant_parameter_node->default_param()->cast<tensor::TensorPtr>();
1077 if (tensor_info == nullptr) {
1078 MS_LOG(ERROR) << "parameterNode's default param is not tensor::TensorPtr";
1079 return RET_ERROR;
1080 }
1081 if (tensor_info->data_c() == nullptr) {
1082 MS_LOG(ERROR) << "parameterNode's default param has no data";
1083 return RET_ERROR;
1084 }
1085 if (scale_or_not) {
1086 quant_param->scale = *reinterpret_cast<float *>(tensor_info->data_c());
1087 quant_param->inited = true;
1088 } else {
1089 quant_param->zeroPoint = *reinterpret_cast<int64_t *>(tensor_info->data_c());
1090 quant_param->inited = true;
1091 }
1092 return RET_OK;
1093 }
1094
// Append a TensorListStack node after each scan-output of a converted onnx Loop so the
// accumulated tensor list is stacked back into a dense tensor, and retarget the
// corresponding getitem index to the body graph's scan-output slot.
// `act_outputs_num` is the number of scan outputs; `body_output_size` is the total
// number of body-graph outputs.
STATUS OnnxModelParser::AddTensorListStackNode(const AnfNodePtr &root_while_node, const onnx::NodeProto &onnx_node,
                                               int act_outputs_num, int body_output_size) {
  MS_ASSERT(root_while_node != nullptr);
  auto &loop_node_name = onnx_node.name();
  auto root_anf_graph = root_while_node->func_graph();
  // Shared element-shape input (-1 = unknown) for every stack node created below.
  auto stack_elem_node = CreateConstParamter(root_anf_graph, -1);
  MS_CHECK_TRUE_MSG(stack_elem_node != nullptr, RET_NULL_PTR, "create const parameter return nullptr");
  stack_elem_node->set_name(loop_node_name + "_element_shape");
  for (int j = 0; j < act_outputs_num; j++) {
    // Scan outputs are the trailing `act_outputs_num` outputs of the Loop node.
    auto output_size = onnx_node.output_size();
    auto &loop_output_name = onnx_node.output(output_size - act_outputs_num + j);
    MS_ASSERT(control_nodes_map_.find(loop_node_name) != control_nodes_map_.end());
    MS_ASSERT(control_nodes_map_[loop_node_name]->find(loop_output_name) != control_nodes_map_[loop_node_name]->end());
    auto &while_output_node = control_nodes_map_[loop_node_name]->at(loop_output_name);
    MS_CHECK_TRUE_MSG(while_output_node != nullptr, RET_ERROR, "can not find while_output_node is control_nodes_map");
    auto tensor_list_stack_prim = std::make_shared<ops::TensorListStack>();
    if (tensor_list_stack_prim == nullptr) {
      MS_LOG(ERROR) << "create stack failed";
      return RET_ERROR;
    }
    // -1: number of stacked elements is unknown at conversion time.
    tensor_list_stack_prim->set_num_elements(-1);
    auto stack_value_node = NewValueNode(tensor_list_stack_prim);
    MS_CHECK_TRUE_MSG(stack_value_node != nullptr, RET_NULL_PTR, "create stack_value_node return nullptr");
    std::vector<AnfNodePtr> stack_inputs = {stack_value_node, while_output_node, stack_elem_node};
    auto tensorlist_stack_cnode = root_anf_graph->NewCNode(stack_inputs);
    if (tensorlist_stack_cnode == nullptr) {
      MS_LOG(ERROR) << "new cnode error";
      return RET_ERROR;
    }
    tensorlist_stack_cnode->set_fullname_with_scope(loop_node_name + "_tensorlist_stack_node_" + std::to_string(j));
    tensorlist_stack_cnode->set_abstract(stack_elem_node->abstract());

    // update getitem value output index
    // while_output_node is a TupleGetItem; input 2 is its index value node.
    auto new_get_item_value = NewValueNode(MakeValue<int>(body_output_size - act_outputs_num + j));
    MS_CHECK_TRUE_MSG(new_get_item_value != nullptr, RET_NULL_PTR, "create new_get_item_value return nullptr");
    MS_ASSERT(while_output_node->cast<CNodePtr>() != nullptr);
    while_output_node->cast<CNodePtr>()->set_input(2, new_get_item_value);
    // insert tensorliststack after while_output
    (*control_nodes_map_[loop_node_name])[loop_output_name] = tensorlist_stack_cnode;
  }
  return RET_OK;
}
1137
1138 // onnx loop scan_output need through tensorlist op,while node need add new inputs
AddTensorArrayEdge(const FuncGraphPtr & anf_graph,std::vector<AnfNodePtr> * return_new_inputs,const std::string & loop_node_name,std::vector<AnfNodePtr> * body_graph_inputs,int act_output_num)1139 STATUS OnnxModelParser::AddTensorArrayEdge(const FuncGraphPtr &anf_graph, std::vector<AnfNodePtr> *return_new_inputs,
1140 const std::string &loop_node_name,
1141 std::vector<AnfNodePtr> *body_graph_inputs, int act_output_num) {
1142 MS_ASSERT(anf_graph != nullptr && return_new_inputs != nullptr && body_graph_inputs != nullptr);
1143 // body graph output is trip_count,cond_count,loop_var,placeholder,scan_outputs
1144 auto root_while_node = GetCNodeFromControlFlowNodesMap(loop_node_name, control_nodes_map_);
1145 MS_CHECK_TRUE_MSG(root_while_node != nullptr, RET_ERROR, "can not find root_while_node is control_nodes_map");
1146 if (root_while_node == nullptr) {
1147 MS_LOG(ERROR) << "anf root node map cannot find loop node" << loop_node_name;
1148 return RET_ERROR;
1149 }
1150 auto anf_root_graph = root_while_node->func_graph();
1151 auto root_item_index_parameter = CreateConstParamter(anf_root_graph, 0);
1152 MS_CHECK_TRUE_MSG(root_item_index_parameter != nullptr, RET_NULL_PTR,
1153 "create root_item_index_parameter return nullptr");
1154 root_item_index_parameter->set_name(loop_node_name + "_item_index");
1155 root_while_node->add_input(root_item_index_parameter);
1156 // fake parameter need pass by root while node input
1157 auto item_index_parameter = anf_graph->add_parameter();
1158 MS_CHECK_TRUE_MSG(item_index_parameter != nullptr, RET_NULL_PTR, "create item_index_parameter return nullptr");
1159 item_index_parameter->set_name(loop_node_name + "_item_index_2");
1160 item_index_parameter->set_abstract(root_item_index_parameter->abstract());
1161 body_graph_inputs->emplace_back(item_index_parameter);
1162 // item index++ edge
1163 auto add_value_node = CreateValueNode(schema::PrimitiveType_AddFusion);
1164 if (add_value_node == nullptr) {
1165 MS_LOG(ERROR) << "create add failed.";
1166 return RET_NULL_PTR;
1167 }
1168 auto add_one_input = CreateConstParamter(anf_graph, 1);
1169 MS_CHECK_TRUE_MSG(root_item_index_parameter != nullptr, RET_NULL_PTR, "create add_one_input return nullptr");
1170 add_one_input->set_name(loop_node_name + "_const_placeholder_1");
1171 std::vector<AnfNodePtr> add_inputs = {add_value_node, item_index_parameter, add_one_input};
1172 auto add_cnode = anf_graph->NewCNode(add_inputs);
1173 if (add_cnode == nullptr) {
1174 MS_LOG(ERROR) << "new cnode error";
1175 return RET_ERROR;
1176 }
1177 add_cnode->set_fullname_with_scope(loop_node_name + "item_index_add_node");
1178 add_cnode->set_abstract(root_item_index_parameter->abstract());
1179 // return node inputs will be trip_count,cond_out,loop_var,placeholder,tensorarray...
1180 if (static_cast<int>(return_new_inputs->size()) < act_output_num || act_output_num < 0) {
1181 MS_LOG(ERROR) << "act_output_num out of range of return_new_inputs";
1182 return RET_ERROR;
1183 }
1184 return_new_inputs->insert(return_new_inputs->end() - act_output_num, add_cnode);
1185
1186 for (int i = 0; i < act_output_num; i++) {
1187 // tensor_array need as root while input
1188 auto while_tensor_array_input = anf_root_graph->add_parameter();
1189 MS_CHECK_TRUE_MSG(while_tensor_array_input != nullptr, RET_NULL_PTR,
1190 "create while_tensor_array_input return nullptr");
1191 std::vector<int> tensor_list_data(kTensorListDatasize);
1192 tensor_list_data[kTypeIndex] = kTypeUnknown;
1193 tensor_list_data[kElementShapeIndex] = 0;
1194 tensor_list_data[kTensorsNumIndex] = -1;
1195 if (INT_MUL_OVERFLOW_THRESHOLD(tensor_list_data.size(), sizeof(int), SIZE_MAX)) {
1196 MS_LOG(ERROR) << "data_size overflow";
1197 return RET_ERROR;
1198 }
1199 auto tensor_info = CreateTensorInfo(tensor_list_data.data(), tensor_list_data.size() * sizeof(int),
1200 {static_cast<int64_t>(tensor_list_data.size())}, kObjectTypeTensorType);
1201 if (tensor_info == nullptr) {
1202 MS_LOG(ERROR) << "Create tensor info failed";
1203 return RET_ERROR;
1204 }
1205 auto abstract_tensor = tensor_info->ToAbstract();
1206 if (abstract_tensor == nullptr) {
1207 MS_LOG(ERROR) << "Create tensor abstarct failed";
1208 return RET_ERROR;
1209 }
1210 while_tensor_array_input->set_abstract(abstract_tensor);
1211 while_tensor_array_input->set_default_param(tensor_info);
1212 while_tensor_array_input->set_name(loop_node_name + "_scan_outputs_tensorarray_while_input");
1213 root_while_node->add_input(while_tensor_array_input);
1214
1215 auto subgraph_tensor_array_input = anf_graph->add_parameter();
1216 MS_CHECK_TRUE_MSG(subgraph_tensor_array_input != nullptr, RET_NULL_PTR,
1217 "create subgraph_tensor_array_input return nullptr");
1218 subgraph_tensor_array_input->set_name(loop_node_name + "_scan_outputs_tensorarray_body_fg_input");
1219 subgraph_tensor_array_input->set_abstract(abstract_tensor);
1220 body_graph_inputs->emplace_back(subgraph_tensor_array_input);
1221 // skip trip_count ,cond_out,loop_var,no_loop_var,place_holder, output
1222 auto loop_output_idx = return_new_inputs->size() - act_output_num + i;
1223 auto loop_output_node = (*return_new_inputs)[loop_output_idx];
1224 auto set_item_value_node = CreateValueNode(schema::PrimitiveType_TensorListSetItem);
1225 if (set_item_value_node == nullptr) {
1226 MS_LOG(ERROR) << "create tensor list set item failed.";
1227 return RET_NULL_PTR;
1228 }
1229 std::vector<AnfNodePtr> set_item_inputs = {set_item_value_node, subgraph_tensor_array_input, item_index_parameter,
1230 loop_output_node};
1231 auto tensorlist_setitem_cnode = anf_graph->NewCNode(set_item_inputs);
1232 if (tensorlist_setitem_cnode == nullptr) {
1233 MS_LOG(ERROR) << "new cnode error";
1234 return RET_ERROR;
1235 }
1236 tensorlist_setitem_cnode->set_fullname_with_scope(loop_node_name + "_tensorlist_setitem_node");
1237 tensorlist_setitem_cnode->set_abstract(abstract_tensor);
1238 // loop output need replace by tensorliststack_output
1239 (*return_new_inputs)[loop_output_idx] = tensorlist_setitem_cnode;
1240 }
1241
1242 return RET_OK;
1243 }
1244
ConvertLoopOnnxNode(const onnx::NodeProto & onnx_node,std::unordered_map<std::string,AnfNodePtr> * anf_root_nodes_map,const std::string & root_node_name)1245 STATUS OnnxModelParser::ConvertLoopOnnxNode(const onnx::NodeProto &onnx_node,
1246 std::unordered_map<std::string, AnfNodePtr> *anf_root_nodes_map,
1247 const std::string &root_node_name) {
1248 MS_ASSERT(anf_root_nodes_map != nullptr);
1249 for (int i = 0; i < onnx_node.attribute_size(); i++) {
1250 auto &attr = onnx_node.attribute(i);
1251 if (attr.name() != "body" || attr.type() != onnx::AttributeProto_AttributeType_GRAPH) {
1252 continue;
1253 }
1254 auto &subgraph_proto = attr.g();
1255 int cond_graph_input_num = -1;
1256 auto loop_body_graph = BuildBodyGraph(onnx_node, subgraph_proto, &cond_graph_input_num);
1257 MS_CHECK_TRUE_MSG(loop_body_graph != nullptr, RET_NULL_PTR, "create loop_body_graph return nullptr");
1258 auto root_while_node = GetCNodeFromControlFlowNodesMap(onnx_node.name(), control_nodes_map_);
1259 MS_CHECK_TRUE_MSG(root_while_node != nullptr, RET_ERROR, "can not find root_while_node");
1260 auto loop_cond_graph = BuildCondGraph(root_while_node, cond_graph_input_num, onnx_node.name() + "_cond_graph");
1261 MS_CHECK_TRUE_MSG(loop_cond_graph != nullptr, RET_NULL_PTR, "create loop_cond_graph return nullptr");
1262 all_subgraphs_.emplace_back(loop_body_graph);
1263 all_subgraphs_.emplace_back(loop_cond_graph);
1264 auto body_value_node = NewValueNode(loop_body_graph);
1265 MS_CHECK_TRUE_MSG(body_value_node != nullptr, RET_NULL_PTR, "create body_value_node return nullptr");
1266 auto inputs = root_while_node->inputs();
1267 auto cond_value_node = NewValueNode(loop_cond_graph);
1268 MS_CHECK_TRUE_MSG(cond_value_node != nullptr, RET_NULL_PTR, "create cond_value_node return nullptr");
1269 inputs.insert(inputs.begin() + 1, {cond_value_node, body_value_node});
1270 root_while_node->set_inputs(inputs);
1271 }
1272 return RET_OK;
1273 }
1274
BuildParameterNodeForQuantParam(const void * data,const std::string & name,TypeId type)1275 STATUS OnnxModelParser::BuildParameterNodeForQuantParam(const void *data, const std::string &name, TypeId type) {
1276 MS_ASSERT(data != nullptr);
1277 if (type != kNumberTypeInt64 && type != kNumberTypeFloat32) {
1278 MS_LOG(ERROR) << "quant param type don't support.";
1279 return RET_NOT_SUPPORT;
1280 }
1281 auto parameter_node = res_graph_->add_parameter();
1282 MS_CHECK_TRUE_MSG(parameter_node != nullptr, RET_NULL_PTR, "create parameter return nullptr");
1283 auto abstract_tensor = CreateTensorAbstract({}, type);
1284 if (abstract_tensor == nullptr) {
1285 MS_LOG(ERROR) << "Create tensor abstarct failed";
1286 return RET_ERROR;
1287 }
1288 parameter_node->set_abstract(abstract_tensor);
1289 parameter_node->set_name(name);
1290 int data_size = 0;
1291 if (type == kNumberTypeFloat32) {
1292 data_size = sizeof(float);
1293 } else {
1294 data_size = sizeof(int64_t);
1295 }
1296 auto tensor_info = CreateTensorInfo(data, data_size, {1}, type);
1297 if (tensor_info == nullptr) {
1298 MS_LOG(ERROR) << "create tensor info failed.";
1299 return RET_ERROR;
1300 }
1301 parameter_node->set_default_param(tensor_info);
1302 anf_nodes_map_.emplace(name, parameter_node);
1303 return RET_OK;
1304 }
1305
// Register OnnxModelParser as the model parser for the ONNX framework type.
REG_MODEL_PARSER(kFmkTypeOnnx, LiteModelParserCreator<OnnxModelParser>)
1307 } // namespace lite
1308 } // namespace mindspore
1309