1 /**
2 * Copyright 2020-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "tools/converter/parser/onnx/onnx_model_parser.h"
17 #include <algorithm>
18 #include <map>
19 #include <memory>
20 #include <queue>
21 #include <set>
22 #include <unordered_map>
23 #include <utility>
24 #include <vector>
25 #include "include/registry/node_parser_registry.h"
26 #include "ir/func_graph.h"
27 #include "mindspore/core/ops/nn_ops.h"
28 #include "nnacl/op_base.h"
29 #include "ops/auto_generate/gen_lite_ops.h"
30 #include "ops/make_tuple.h"
31 #include "ops/return.h"
32 #include "ops/tensor_list_stack.h"
33 #include "ops/tuple_get_item.h"
34 #include "src/common/log_util.h"
35 #include "tools/common/graph_util.h"
36 #include "tools/common/protobuf_utils.h"
37 #include "tools/common/tensor_util.h"
38 #include "tools/converter/converter_context.h"
39 #include "tools/converter/parser/lite_model_parser_creator.h"
40 #include "tools/converter/parser/onnx/onnx_einsum_adjust.h"
41 #include "tools/converter/parser/onnx/onnx_inputs_adjust.h"
42 #include "tools/converter/parser/onnx/onnx_megatron_op_adjust.h"
43 #include "tools/converter/parser/onnx/onnx_nonzero_adjust.h"
44 #include "tools/converter/parser/onnx/onnx_pad_adjust.h"
45 #include "tools/converter/parser/onnx/onnx_concat_adjust.h"
46 #include "tools/converter/parser/onnx/onnx_quantize_linear_adjust.h"
47 #include "tools/converter/parser/onnx/onnx_deform_conv2d_adjust.h"
48 #include "tools/converter/parser/onnx/onnx_custom_op_adjust.h"
49 #include "tools/converter/parser/parser_utils.h"
50 #include "tools/converter/parser/unify_format.h"
51 #include "tools/converter/quantizer/quant_param_holder.h"
52 #include "tools/optimizer/common/gllo_utils.h"
53 #include "tools/converter/parser/onnx/onnx_dtype_adjust.h"
54
55 using mindspore::converter::kFmkTypeOnnx;
56 namespace mindspore {
57 namespace lite {
58 namespace {
// Tensor-list layout constants. They presumably describe a 3-slot header laid out as
// (element type, element shape, tensor count) — TODO confirm against their users later in this file.
constexpr int kTensorListDatasize{3};
constexpr int kTypeIndex{0};
constexpr int kElementShapeIndex{1};
constexpr int kTensorsNumIndex{2};
63
Onnx2AnfAdjust(const std::set<FuncGraphPtr> & all_func_graphs,const converter::ConverterParameters & flag)64 int Onnx2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs, const converter::ConverterParameters &flag) {
65 for (const auto &func_graph : all_func_graphs) {
66 CHECK_NULL_RETURN(func_graph);
67 if (!OnnxMegatronOpAdjust::Adjust(func_graph, flag)) {
68 MS_LOG(ERROR) << "onnx magatron adjust failed.";
69 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
70 return RET_ERROR;
71 }
72 if (!OnnxInputAdjust::Adjust(func_graph, flag)) {
73 MS_LOG(ERROR) << "onnx adjust failed.";
74 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
75 return RET_ERROR;
76 }
77 if (!OnnxDtypeAdjust::Adjust(func_graph, flag)) {
78 MS_LOG(ERROR) << "onnx dtype adjust failed!";
79 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
80 return RET_ERROR;
81 }
82 if (!OnnxPadAdjust::Adjust(func_graph)) {
83 MS_LOG(ERROR) << "onnx pad adjust failed.";
84 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
85 return RET_ERROR;
86 }
87 if (!OnnxConcatAdjust::Adjust(func_graph)) {
88 MS_LOG(ERROR) << "onnx OnnxConcatOp adjust failed.";
89 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
90 return RET_ERROR;
91 }
92 if (!OnnxNonZeroAdjust::Adjust(func_graph)) {
93 MS_LOG(ERROR) << "onnx nonzero adjust failed.";
94 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
95 return RET_ERROR;
96 }
97 if (!OnnxEinsumAdjust::Adjust(func_graph)) {
98 MS_LOG(ERROR) << "onnx einsum adjust failed.";
99 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
100 return RET_ERROR;
101 }
102 if (!OnnxQuantizeLinearAdjust::Adjust(func_graph)) {
103 MS_LOG(ERROR) << "onnx quantize linear adjust failed.";
104 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
105 return RET_ERROR;
106 }
107 if (!OnnxDeformConv2dAdjust::Adjust(func_graph)) {
108 MS_LOG(ERROR) << "onnx MMCVModulatedDeformConv2d adjust failed.";
109 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
110 return RET_ERROR;
111 }
112 if (!OnnxCustomOpAdjust::Adjust(func_graph)) {
113 MS_LOG(ERROR) << "onnx OnnxCustomOp adjust failed.";
114 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
115 return RET_ERROR;
116 }
117 }
118 return RET_OK;
119 }
120
CreateConstParamter(const FuncGraphPtr & anf_graph,int val)121 ParameterPtr CreateConstParamter(const FuncGraphPtr &anf_graph, int val) {
122 MS_CHECK_TRUE_RET(anf_graph != nullptr, nullptr);
123 auto const_node = anf_graph->add_parameter();
124 MS_CHECK_TRUE_RET(const_node != nullptr, nullptr);
125 auto const_abstract = CreateTensorAbstract({}, kNumberTypeInt32);
126 if (const_abstract == nullptr) {
127 MS_LOG(ERROR) << "Create tensor abstarct failed";
128 return nullptr;
129 }
130 const_node->set_abstract(const_abstract);
131 int *tensor_data = new (std::nothrow) int[1];
132 if (tensor_data == nullptr) {
133 MS_LOG(ERROR) << "new int[] failed";
134 return nullptr;
135 }
136 tensor_data[0] = val;
137 auto tensor_info = CreateTensorInfo(tensor_data, sizeof(int), {1}, kNumberTypeInt32);
138 if (tensor_info == nullptr) {
139 MS_LOG(ERROR) << "create tensor info failed.";
140 delete[] tensor_data;
141 tensor_data = nullptr;
142 return nullptr;
143 }
144 delete[] tensor_data;
145 tensor_data = nullptr;
146 const_node->set_default_param(tensor_info);
147 return const_node;
148 }
149
CreateValueNode(const schema::PrimitiveType & op_type)150 ValueNodePtr CreateValueNode(const schema::PrimitiveType &op_type) {
151 auto node_type = schema::EnumNamePrimitiveType(op_type);
152 auto op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
153 if (op_primc_fns.find(node_type) == op_primc_fns.end()) {
154 MS_LOG(ERROR) << "have no func to create primitive.";
155 return nullptr;
156 }
157 auto prim = op_primc_fns[node_type]();
158 if (prim == nullptr) {
159 MS_LOG(ERROR) << "cannot create primitive.";
160 return nullptr;
161 }
162 return NewValueNode(prim);
163 }
164
AddIterNumsUpdateEdge(const FuncGraphPtr & anf_graph,std::vector<AnfNodePtr> * return_new_inputs,const std::unordered_map<std::string,AnfNodePtr> & anf_nodes_map,const std::string & trip_cout_name,const std::string & loop_node_name)165 STATUS AddIterNumsUpdateEdge(const FuncGraphPtr &anf_graph, std::vector<AnfNodePtr> *return_new_inputs,
166 const std::unordered_map<std::string, AnfNodePtr> &anf_nodes_map,
167 const std::string &trip_cout_name, const std::string &loop_node_name) {
168 CHECK_NULL_RETURN(anf_graph);
169 CHECK_NULL_RETURN(return_new_inputs);
170 // trip_cout need -1 after every iteration
171 auto sub_value_node = CreateValueNode(schema::PrimitiveType_SubFusion);
172 if (sub_value_node == nullptr) {
173 MS_LOG(ERROR) << "create sub failed.";
174 return RET_NULL_PTR;
175 }
176 auto trip_cout_paramter_iter = anf_nodes_map.find(trip_cout_name);
177 if (trip_cout_paramter_iter == anf_nodes_map.end()) {
178 MS_LOG(ERROR) << "cannot find " << trip_cout_name;
179 return RET_ERROR;
180 }
181 auto &trip_cout_paramter = trip_cout_paramter_iter->second;
182 if (trip_cout_paramter == nullptr) {
183 MS_LOG(ERROR) << "trip_cout_paramter found failed";
184 return RET_ERROR;
185 }
186 auto const_one_parameter = CreateConstParamter(anf_graph, 1);
187 MS_CHECK_TRUE_MSG(const_one_parameter != nullptr, RET_ERROR, "create const parameter return nullptr");
188 const_one_parameter->set_name(loop_node_name + "_index_update_parameter");
189
190 std::vector<AnfNodePtr> sub_inputs = {sub_value_node, trip_cout_paramter, const_one_parameter};
191 auto sub_cnode = anf_graph->NewCNode(sub_inputs);
192 if (sub_cnode == nullptr) {
193 MS_LOG(ERROR) << "new cnode error";
194 return RET_ERROR;
195 }
196 sub_cnode->set_fullname_with_scope(loop_node_name + "_sub");
197 sub_cnode->set_abstract(trip_cout_paramter->abstract());
198 return_new_inputs->insert(return_new_inputs->begin() + 1, sub_cnode);
199 return RET_OK;
200 }
201
GetCNodeFromControlFlowNodesMap(const std::string & loop_node_name,const std::unordered_map<std::string,std::unordered_map<std::string,AnfNodePtr> * > & control_nodes_map)202 CNodePtr GetCNodeFromControlFlowNodesMap(
203 const std::string &loop_node_name,
204 const std::unordered_map<std::string, std::unordered_map<std::string, AnfNodePtr> *> &control_nodes_map) {
205 auto iter1 = control_nodes_map.find(loop_node_name);
206 if (iter1 == control_nodes_map.end()) {
207 return nullptr;
208 } // namespace
209 auto iter2 = iter1->second->find(loop_node_name);
210 if (iter2 == iter1->second->end()) {
211 return nullptr;
212 }
213 return iter2->second->cast<CNodePtr>();
214 }
215
BuildReturnNode(const FuncGraphPtr & anf_graph,const std::vector<AnfNodePtr> & return_inputs)216 STATUS BuildReturnNode(const FuncGraphPtr &anf_graph, const std::vector<AnfNodePtr> &return_inputs) {
217 MS_CHECK_TRUE_RET(anf_graph != nullptr, RET_NULL_PTR);
218 auto return_prim = std::make_shared<ops::Return>();
219 if (return_prim == nullptr) {
220 MS_LOG(ERROR) << "new Return failed";
221 return RET_NULL_PTR;
222 }
223 if (return_inputs.empty()) {
224 MS_LOG(ERROR) << "return input is empty";
225 return RET_ERROR;
226 }
227 auto input = return_inputs[0];
228 MS_EXCEPTION_IF_NULL(input);
229 auto abstract = input->abstract();
230 if (abstract == nullptr) {
231 MS_LOG(ERROR) << "Input node abstract is null, node: " << input->fullname_with_scope();
232 return RET_ERROR;
233 }
234
235 auto return_prim_c = return_prim->GetPrim();
236 CHECK_NULL_RETURN(return_prim_c);
237 auto return_cnode = anf_graph->NewCNode(return_prim_c, return_inputs);
238 if (return_cnode == nullptr) {
239 MS_LOG(ERROR) << "new cnode error";
240 return RET_ERROR;
241 }
242 return_cnode->set_fullname_with_scope("Return");
243 return_cnode->set_abstract(abstract);
244 anf_graph->set_return(return_cnode);
245 return RET_OK;
246 }
247
BuildParameterNode(const ParameterPtr & parameter_node,const onnx::TensorProto & tensor,const std::string & model_file,std::map<std::string,std::pair<size_t,uint8_t * >> * external_datas)248 STATUS BuildParameterNode(const ParameterPtr ¶meter_node, const onnx::TensorProto &tensor,
249 const std::string &model_file,
250 std::map<std::string, std::pair<size_t, uint8_t *>> *external_datas) {
251 MS_CHECK_TRUE_RET(parameter_node != nullptr, RET_NULL_PTR);
252 auto data_type = OnnxNodeParser::GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(tensor.data_type()));
253 if (data_type == kTypeUnknown) {
254 MS_LOG(ERROR) << "not support onnx data type " << static_cast<onnx::TensorProto_DataType>(tensor.data_type());
255 return RET_ERROR;
256 }
257 std::vector<int64_t> shape_vector(tensor.dims().begin(), tensor.dims().end());
258 auto abstract_tensor = CreateTensorAbstract(shape_vector, data_type);
259 if (abstract_tensor == nullptr) {
260 MS_LOG(ERROR) << "Create tensor abstract failed";
261 return RET_ERROR;
262 }
263 parameter_node->set_abstract(abstract_tensor);
264 parameter_node->set_name(tensor.name());
265
266 tensor::TensorPtr tensor_info;
267 if (tensor.data_location() != onnx::TensorProto::EXTERNAL) {
268 tensor_info = OnnxNodeParser::CopyOnnxTensorData(tensor);
269 if (tensor_info == nullptr) {
270 MS_LOG(ERROR) << "copy data failed.";
271 return RET_ERROR;
272 }
273 } else {
274 tensor_info = std::make_shared<tensor::Tensor>(data_type, shape_vector);
275 MS_CHECK_TRUE_MSG(tensor_info != nullptr, RET_NULL_PTR, "create tensor_info return nullptr");
276 std::vector<int> shape;
277 std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape),
278 [](const int64_t &value) { return static_cast<int>(value); });
279 auto status = OnnxNodeParser::LoadOnnxExternalTensorData(tensor, tensor_info, model_file, external_datas);
280 if (status != RET_OK) {
281 MS_LOG(ERROR) << "load external data failed.";
282 return status;
283 }
284 }
285 parameter_node->set_default_param(tensor_info);
286 return RET_OK;
287 }
288
BuildOpOutputs(const onnx::NodeProto & onnx_node,const FuncGraphPtr & anf_graph,std::unordered_map<std::string,AnfNodePtr> * anf_nodes_map,const CNodePtr & cnode)289 STATUS BuildOpOutputs(const onnx::NodeProto &onnx_node, const FuncGraphPtr &anf_graph,
290 std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, const CNodePtr &cnode) {
291 CHECK_NULL_RETURN(anf_graph);
292 CHECK_NULL_RETURN(anf_nodes_map);
293 if (onnx_node.output_size() == 1) {
294 auto abstract_tensor = CreateTensorAbstract({}, kNumberTypeFloat32);
295 if (abstract_tensor == nullptr) {
296 MS_LOG(ERROR) << "Create tensor abstarct failed";
297 return RET_ERROR;
298 }
299 cnode->set_abstract(abstract_tensor);
300 anf_nodes_map->emplace(onnx_node.output(0), cnode);
301 } else {
302 AbstractBasePtrList abstract_list;
303 int op_idx = 0;
304 for (const auto &output_name : onnx_node.output()) {
305 auto abstract_tensor = CreateTensorAbstract({}, kNumberTypeFloat32);
306 if (abstract_tensor == nullptr) {
307 MS_LOG(ERROR) << "Create tensor abstarct failed";
308 return RET_ERROR;
309 }
310 abstract_list.emplace_back(abstract_tensor);
311 auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
312 if (tuple_get_item_prim_ptr == nullptr) {
313 MS_LOG(ERROR) << "new TupleGetItem failed";
314 return RET_NULL_PTR;
315 }
316 auto tuple_get_item_prim_c = tuple_get_item_prim_ptr->GetPrim();
317 MS_CHECK_TRUE_MSG(tuple_get_item_prim_c != nullptr, RET_NULL_PTR, "create tuple_get_item_prim_c return nullptr");
318 auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_c);
319 MS_CHECK_TRUE_MSG(tuple_get_item_prim != nullptr, RET_NULL_PTR, "create ValueNode return nullptr");
320 auto get_item_value = NewValueNode(MakeValue<int64_t>(op_idx));
321 MS_CHECK_TRUE_MSG(get_item_value != nullptr, RET_NULL_PTR, "create ValueNode return nullptr");
322 std::vector<AnfNodePtr> inputs{tuple_get_item_prim, cnode, get_item_value};
323 CNodePtr get_item_cnode = anf_graph->NewCNode(inputs);
324 if (get_item_cnode == nullptr) {
325 MS_LOG(ERROR) << "new cnode error";
326 return RET_ERROR;
327 }
328 get_item_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_getitem_" + std::to_string(op_idx));
329 auto get_item_abstract = CreateTensorAbstract({}, kNumberTypeFloat32);
330 if (get_item_abstract == nullptr) {
331 MS_LOG(ERROR) << "Create tensor abstarct failed";
332 return RET_ERROR;
333 }
334 get_item_cnode->set_abstract(get_item_abstract);
335 anf_nodes_map->emplace(output_name, get_item_cnode);
336 op_idx++;
337 }
338 auto new_abstract_list = std::make_shared<abstract::AbstractTuple>(abstract_list);
339 CHECK_NULL_RETURN(new_abstract_list);
340 cnode->set_abstract(new_abstract_list);
341 }
342 if (onnx_node.op_type() == "Loop" || onnx_node.op_type() == "If") {
343 anf_nodes_map->emplace(onnx_node.name(), cnode);
344 }
345 return RET_OK;
346 }
347
ConvertConstTensors(const onnx::GraphProto & onnx_graph,const FuncGraphPtr & func_graph_ptr,std::unordered_map<std::string,AnfNodePtr> * anf_nodes_map,const std::string & model_file)348 STATUS ConvertConstTensors(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr,
349 std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, const std::string &model_file) {
350 CHECK_NULL_RETURN(func_graph_ptr);
351 CHECK_NULL_RETURN(anf_nodes_map);
352 std::map<std::string, std::pair<size_t, uint8_t *>> external_datas;
353 auto free_external_data = [&external_datas]() {
354 for (auto &&item : external_datas) {
355 if (item.second.second) {
356 delete[] item.second.second;
357 }
358 }
359 external_datas.clear();
360 };
361 for (const auto &onnx_const_value : onnx_graph.initializer()) {
362 auto parameter = func_graph_ptr->add_parameter();
363 MS_CHECK_TRUE_MSG(parameter != nullptr, RET_NULL_PTR, "create parameter return nullptr");
364 auto status = BuildParameterNode(parameter, onnx_const_value, model_file, &external_datas);
365 if (status != RET_OK) {
366 MS_LOG(ERROR) << "parameter node build failed.";
367 free_external_data();
368 return status;
369 }
370 anf_nodes_map->emplace(onnx_const_value.name(), parameter);
371 }
372 free_external_data();
373 return RET_OK;
374 }
375
ConvertGraphInputs(const onnx::GraphProto & onnx_graph,const FuncGraphPtr & func_graph_ptr,std::unordered_map<std::string,AnfNodePtr> * anf_nodes_map)376 STATUS ConvertGraphInputs(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr,
377 std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map) {
378 CHECK_NULL_RETURN(func_graph_ptr);
379 CHECK_NULL_RETURN(anf_nodes_map);
380 for (int i = 0; i < onnx_graph.input().size(); ++i) {
381 const auto &input_value = onnx_graph.input(i);
382 if (anf_nodes_map->find(input_value.name()) != anf_nodes_map->end()) {
383 continue;
384 }
385 auto parameter = func_graph_ptr->add_parameter();
386 MS_CHECK_TRUE_MSG(parameter != nullptr, RET_NULL_PTR, "create parameter return nullptr");
387 auto data_type = OnnxNodeParser::GetDataTypeFromOnnx(
388 static_cast<onnx::TensorProto_DataType>(input_value.type().tensor_type().elem_type()));
389 if (data_type == kTypeUnknown) {
390 MS_LOG(ERROR) << "not support onnx data type "
391 << static_cast<onnx::TensorProto_DataType>(input_value.type().tensor_type().elem_type());
392 return RET_ERROR;
393 }
394 std::vector<int64_t> shape_vector =
395 ConverterInnerContext::GetInstance()->GetGraphInputTensorShape(input_value.name());
396 if (ConverterInnerContext::GetInstance()->GetGraphInputTensorShapeMapSize() > 0 && shape_vector.empty()) {
397 MS_LOG(WARNING) << "Cannot find name in map. name is " << input_value.name();
398 }
399 if (shape_vector.empty()) {
400 auto onnx_shape = input_value.type().tensor_type().shape().dim();
401 std::transform(onnx_shape.begin(), onnx_shape.end(), std::back_inserter(shape_vector),
402 [](const onnx::TensorShapeProto_Dimension &val) { return static_cast<int64_t>(val.dim_value()); });
403 std::replace(shape_vector.begin(), shape_vector.end(), 0, -1);
404 }
405 auto abstract_tensor = CreateTensorAbstract(shape_vector, data_type);
406 if (abstract_tensor == nullptr) {
407 MS_LOG(ERROR) << "Create tensor abstarct failed";
408 return RET_ERROR;
409 }
410 parameter->set_abstract(abstract_tensor);
411 parameter->set_name(input_value.name());
412 anf_nodes_map->emplace(input_value.name(), parameter);
413 }
414 return RET_OK;
415 }
416
ConvertGraphOutputs(const onnx::GraphProto & onnx_graph,const FuncGraphPtr & anf_graph,const std::unordered_map<std::string,AnfNodePtr> & anf_nodes_map)417 STATUS ConvertGraphOutputs(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &anf_graph,
418 const std::unordered_map<std::string, AnfNodePtr> &anf_nodes_map) {
419 MS_CHECK_TRUE_RET(anf_graph != nullptr, RET_NULL_PTR);
420 std::vector<AnfNodePtr> return_inputs;
421 if (onnx_graph.output_size() == 0) {
422 MS_LOG(ERROR) << "onnx graph has no output";
423 return RET_ERROR;
424 }
425 if (onnx_graph.output_size() > 1) {
426 std::vector<AnfNodePtr> make_tuple_inputs;
427 auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>();
428 AbstractBasePtrList elem;
429 if (make_tuple_prim_ptr == nullptr) {
430 MS_LOG(ERROR) << "new MakeTuple failed";
431 return RET_NULL_PTR;
432 }
433 for (const auto &graph_out : onnx_graph.output()) {
434 if (anf_nodes_map.find(graph_out.name()) == anf_nodes_map.end()) {
435 MS_LOG(ERROR) << "graph output get failed.";
436 return RET_ERROR;
437 }
438 auto cnode = anf_nodes_map.at(graph_out.name());
439 if (cnode == nullptr) {
440 MS_LOG(ERROR) << "Can't find input node.";
441 return RET_NOT_FIND_OP;
442 }
443 elem.emplace_back(cnode->abstract());
444 make_tuple_inputs.emplace_back(cnode);
445 }
446 auto make_tuple_prim_c = make_tuple_prim_ptr->GetPrim();
447 CHECK_NULL_RETURN(make_tuple_prim_c);
448 auto make_tuple_cnode = anf_graph->NewCNode(make_tuple_prim_c, make_tuple_inputs);
449 if (make_tuple_cnode == nullptr) {
450 MS_LOG(ERROR) << "new cnode error";
451 return RET_ERROR;
452 }
453
454 make_tuple_cnode->set_fullname_with_scope("return tuple");
455 make_tuple_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
456 return_inputs.emplace_back(make_tuple_cnode);
457 } else {
458 const auto &graph_out = onnx_graph.output(0);
459 if (anf_nodes_map.find(graph_out.name()) == anf_nodes_map.end()) {
460 MS_LOG(ERROR) << "graph output get failed.";
461 return RET_ERROR;
462 }
463 auto cnode = anf_nodes_map.at(graph_out.name());
464 if (cnode == nullptr) {
465 MS_LOG(ERROR) << "Can't find input node.";
466 return RET_NOT_FIND_OP;
467 }
468 return_inputs.emplace_back(cnode);
469 }
470 if (BuildReturnNode(anf_graph, return_inputs) != RET_OK) {
471 MS_LOG(ERROR) << "build return node failed.";
472 return RET_ERROR;
473 }
474 return RET_OK;
475 }
476
BuildCondGraph(const AnfNodePtr & root_while_node,int inputs_num,const std::string & cond_graph_name)477 FuncGraphPtr BuildCondGraph(const AnfNodePtr &root_while_node, int inputs_num, const std::string &cond_graph_name) {
478 MS_CHECK_TRUE_RET(root_while_node != nullptr, nullptr);
479 auto cond_graph = std::make_shared<FuncGraph>();
480 MS_CHECK_TRUE_MSG(cond_graph != nullptr, nullptr, "create cond_graph return nullptr");
481 CNodePtr less_cnode = nullptr;
482 for (int i = 0; i < inputs_num; i++) {
483 auto input_parameter = cond_graph->add_parameter();
484 MS_CHECK_TRUE_MSG(input_parameter != nullptr, nullptr, "create input_parameter return nullptr");
485 input_parameter->set_name(cond_graph_name + "_input_" + std::to_string(i) + "_parameter");
486 auto input_abstract = CreateTensorAbstract({}, kNumberTypeInt32);
487 if (input_abstract == nullptr) {
488 MS_LOG(ERROR) << "Create tensor abstarct failed";
489 return nullptr;
490 }
491 input_parameter->set_abstract(input_abstract);
492 if (i == 0) {
493 auto zero_parameter = CreateConstParamter(cond_graph, 0);
494 MS_CHECK_TRUE_MSG(zero_parameter != nullptr, nullptr, "create zero_parameter return nullptr");
495 zero_parameter->set_name(root_while_node->fullname_with_scope() + "_const_0");
496 auto less_value_node = CreateValueNode(schema::PrimitiveType_Less);
497 MS_CHECK_TRUE_MSG(less_value_node != nullptr, nullptr, "create less_value_node return nullptr");
498 std::vector<AnfNodePtr> less_inputs = {less_value_node, zero_parameter, input_parameter};
499 less_cnode = cond_graph->NewCNode(less_inputs);
500 if (less_cnode == nullptr) {
501 MS_LOG(ERROR) << "new cnode error";
502 return nullptr;
503 }
504 auto less_abstract = CreateTensorAbstract({}, kNumberTypeBool);
505 if (less_abstract == nullptr) {
506 MS_LOG(ERROR) << "Create tensor abstarct failed";
507 return nullptr;
508 }
509 less_cnode->set_abstract(less_abstract);
510 less_cnode->set_fullname_with_scope(cond_graph_name + "_less_cnode");
511 }
512 if (i == 1) {
513 auto and_value_node = CreateValueNode(schema::PrimitiveType_LogicalAnd);
514 MS_CHECK_TRUE_MSG(and_value_node != nullptr, nullptr, "CreateValueNode failed");
515 std::vector<AnfNodePtr> and_inputs = {and_value_node, less_cnode, input_parameter};
516 auto and_cnode = cond_graph->NewCNode(and_inputs);
517 if (and_cnode == nullptr) {
518 MS_LOG(ERROR) << "new cnode error";
519 return nullptr;
520 }
521 and_cnode->set_abstract(less_cnode->abstract());
522 and_cnode->set_fullname_with_scope(cond_graph_name + "_output_" + std::to_string(0) + "_cnode");
523 auto status = BuildReturnNode(cond_graph, {and_cnode});
524 if (status != RET_OK) {
525 MS_LOG(ERROR) << "build return node failed: " << status;
526 return nullptr;
527 }
528 }
529 }
530 cond_graph->set_attr("graph_name", MakeValue(cond_graph_name));
531 return cond_graph;
532 }
533
ConvertGraph(api::FuncGraphPtr func_graph)534 FuncGraphPtr ConvertGraph(api::FuncGraphPtr func_graph) {
535 auto impl = func_graph->impl();
536 return std::dynamic_pointer_cast<FuncGraph>(impl);
537 }
538 } // namespace
539
BuildBodyGraph(const onnx::NodeProto & loop_node,const onnx::GraphProto & subgraph_proto,int * cond_graph_input_num)540 FuncGraphPtr OnnxModelParser::BuildBodyGraph(const onnx::NodeProto &loop_node, const onnx::GraphProto &subgraph_proto,
541 int *cond_graph_input_num) {
542 MS_CHECK_TRUE_RET(cond_graph_input_num != nullptr, nullptr);
543 auto &loop_node_name = loop_node.name();
544 auto node_inputs_num = loop_node.input_size();
545 auto node_outputs_num = loop_node.output_size();
546 // skip trip_cout and cond input,scan_output nums
547 auto act_outputs_num = node_outputs_num - (node_inputs_num - 2);
548 auto loop_body_graph = std::make_shared<FuncGraph>();
549 MS_CHECK_TRUE_MSG(loop_body_graph != nullptr, nullptr, "create loop_body_graph return nullptr");
550 std::unordered_map<std::string, AnfNodePtr> anf_nodes_map;
551 std::vector<AnfNodePtr> gen_subgraph_inputs;
552 auto status = ConvertOnnxGraph(subgraph_proto, loop_body_graph, &anf_nodes_map, &gen_subgraph_inputs, loop_node_name);
553 if (status != RET_OK) {
554 MS_LOG(ERROR) << "convert loop OnnxGraph: " << status;
555 return nullptr;
556 }
557 auto return_node = loop_body_graph->get_return();
558 MS_CHECK_TRUE_MSG(return_node != nullptr, nullptr, "return node of subgraph is nullptr");
559 MS_CHECK_TRUE_RET(return_node->size() == DIMENSION_2D, nullptr);
560 auto return_tuple_cnode = return_node->input(1)->cast<CNodePtr>();
561 MS_CHECK_TRUE_RET(return_tuple_cnode != nullptr, nullptr);
562 auto return_new_inputs = return_tuple_cnode->inputs();
563 return_new_inputs.insert(return_new_inputs.end() - act_outputs_num, gen_subgraph_inputs.begin(),
564 gen_subgraph_inputs.end());
565
566 std::string max_trip_count_name = subgraph_proto.input(0).name();
567 status =
568 AddIterNumsUpdateEdge(loop_body_graph, &return_new_inputs, anf_nodes_map, max_trip_count_name, loop_node_name);
569 if (status != RET_OK) {
570 MS_LOG(ERROR) << "add iter nums update edge failed: " << status;
571 return nullptr;
572 }
573 auto root_while_node = GetCNodeFromControlFlowNodesMap(loop_node_name, control_nodes_map_);
574 MS_CHECK_TRUE_MSG(root_while_node != nullptr, nullptr, "cannot find root_while_node is control_nodes_map");
575 std::vector<AnfNodePtr> body_graph_inputs;
576 body_graph_inputs.reserve(subgraph_proto.input_size());
577 for (int j = 0; j < subgraph_proto.input_size(); j++) {
578 body_graph_inputs.emplace_back(anf_nodes_map[subgraph_proto.input(j).name()]);
579 }
580 body_graph_inputs.insert(body_graph_inputs.end(), gen_subgraph_inputs.begin(), gen_subgraph_inputs.end());
581 if (act_outputs_num != 0) {
582 status =
583 AddTensorArrayEdge(loop_body_graph, &return_new_inputs, loop_node_name, &body_graph_inputs, act_outputs_num);
584 if (status != RET_OK) {
585 MS_LOG(ERROR) << "add tensorarray update edge failed: " << status;
586 return nullptr;
587 }
588 // insert tensorliststack after while output
589 status = AddTensorListStackNode(root_while_node, loop_node, act_outputs_num, body_graph_inputs.size());
590 if (status != RET_OK) {
591 MS_LOG(ERROR) << "add tensorliststack node failed: " << status;
592 return nullptr;
593 }
594 }
595 return_tuple_cnode->set_inputs(return_new_inputs);
596 auto body_graph_name = loop_node_name + "_body_graph";
597 for (size_t j = 0; j < body_graph_inputs.size(); j++) {
598 MS_CHECK_TRUE_RET(body_graph_inputs[j] != nullptr, nullptr);
599 auto body_input = body_graph_inputs[j]->cast<ParameterPtr>();
600 MS_CHECK_TRUE_RET(body_input != nullptr, nullptr);
601 body_input->set_name(body_graph_name + "_input_" + std::to_string(j) + "_parameter");
602 }
603 for (size_t j = 1; j < return_new_inputs.size(); j++) {
604 if (utils::isa<CNodePtr>(return_new_inputs[j])) {
605 return_new_inputs[j]->cast<CNodePtr>()->set_fullname_with_scope(body_graph_name + "_output_" +
606 std::to_string(j - 1) + "_cnode");
607 } else if (utils::isa<ParameterPtr>(return_new_inputs[j])) {
608 return_new_inputs[j]->cast<ParameterPtr>()->set_name(body_graph_name + "_output_" + std::to_string(j - 1) +
609 "_parameter");
610 }
611 }
612 *cond_graph_input_num = return_new_inputs.size() - 1;
613 loop_body_graph->set_attr("graph_name", MakeValue(body_graph_name));
614 return loop_body_graph;
615 }
616
617 namespace {
CheckOnnxModel(const onnx::GraphProto & onnx_graph)618 STATUS CheckOnnxModel(const onnx::GraphProto &onnx_graph) {
619 // all input should in initialize
620 std::set<std::string> providers;
621 for (const auto &const_tensor : onnx_graph.initializer()) {
622 const auto &name = const_tensor.name();
623 if (providers.count(name) != 0) {
624 MS_LOG(ERROR) << "const tensor repeated";
625 return RET_ERROR;
626 }
627 providers.insert(name);
628 }
629 for (int i = 0; i < onnx_graph.input().size(); ++i) {
630 providers.insert(onnx_graph.input(i).name());
631 }
632 for (const auto &onnx_node : onnx_graph.node()) {
633 for (int i = 0; i < onnx_node.output_size(); i++) {
634 auto &output = onnx_node.output(i);
635 if (providers.count(output) != 0) {
636 MS_LOG(ERROR) << "Output tensor repeated";
637 return RET_ERROR;
638 }
639 providers.insert(output);
640 }
641 }
642 // all output should find
643 for (const auto &onnx_node : onnx_graph.node()) {
644 for (int i = 0; i < onnx_node.input_size(); i++) {
645 auto &input = onnx_node.input(i);
646 if (providers.count(input) == 0) {
647 MS_LOG(WARNING) << "Cannot find input: " << input << " of node: " << onnx_node.name();
648 }
649 }
650 }
651 return RET_OK;
652 }
653 } // namespace
654
Parse(const converter::ConverterParameters & flag)655 api::FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag) {
656 auto model_file = flag.model_file;
657 NotSupportOp::GetInstance()->set_fmk_type("ONNX");
658 auto graph = std::make_shared<FuncGraph>();
659 MS_CHECK_TRUE_MSG(graph != nullptr, nullptr, "create FuncGraph failed");
660 res_graph_ = api::MakeShared<api::FuncGraph>(graph);
661 auto status = InitOriginModel(model_file);
662 if (RET_OK != status) {
663 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
664 MS_LOG(ERROR) << "init origin model failed.";
665 return nullptr;
666 }
667 MS_ASSERT(onnx_root_graph_ != nullptr);
668
669 status = ConvertOnnxGraph(onnx_root_graph_, graph, &anf_nodes_map_, {}, "root_node");
670 if (RET_OK != status) {
671 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
672 MS_LOG(ERROR) << "convert onnx graph failed.";
673 return nullptr;
674 }
675 static auto root_func_manager = Manage(graph);
676
677 for (auto &subgraph : all_subgraphs_) {
678 MS_CHECK_TRUE_RET(subgraph != nullptr, nullptr);
679 subgraph->set_manager(root_func_manager);
680 subgraph->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeOnnx)));
681 }
682 graph->set_attr("graph_name", MakeValue("main_graph"));
683 graph->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeOnnx)));
684 if ((status = CommonAnfAdjust(graph)) != RET_OK) {
685 MS_LOG(ERROR) << "AdjustForAnf failed.";
686 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
687 return nullptr;
688 }
689 std::set<FuncGraphPtr> all_func_graphs = {};
690 GetAllFuncGraph(graph, &all_func_graphs);
691 if ((status = Onnx2AnfAdjust(all_func_graphs, flag)) != RET_OK) {
692 MS_LOG(ERROR) << "Onnx2AnfAdjust failed.";
693 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
694 return nullptr;
695 }
696 auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeOnnx, false, flag.save_type);
697 MS_CHECK_TRUE_MSG(unify_format != nullptr, nullptr, "create unify_format return nullptr");
698 if (!unify_format->Run(graph)) {
699 MS_LOG(ERROR) << "Run insert transpose failed.";
700 return nullptr;
701 }
702 return res_graph_;
703 }
704
InitOriginModel(const std::string & model_file)705 STATUS OnnxModelParser::InitOriginModel(const std::string &model_file) {
706 MS_CHECK_TRUE_RET(res_graph_ != nullptr, RET_NULL_PTR);
707 auto res_graph = ConvertGraph(res_graph_);
708 auto status = ValidateFileStr(model_file, ".onnx");
709 if (status != RET_OK) {
710 MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.onnx";
711 return status;
712 }
713 model_file_ = model_file;
714 status = ReadProtoFromBinaryFile(model_file, &onnx_model_);
715 if (status != RET_OK) {
716 MS_LOG(ERROR) << "Read onnx model file failed, model path: " << model_file;
717 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
718 return status;
719 }
720 OnnxNodeParser::set_opset_version(onnx_model_.opset_import().Get(0).version());
721 OnnxNodeParser::SetOnnxModelFile(model_file);
722 onnx_root_graph_ = onnx_model_.graph();
723 auto fmk_value_node = MakeValue(static_cast<int>(converter::kFmkTypeOnnx));
724 CHECK_NULL_RETURN(fmk_value_node);
725 res_graph->set_attr("fmk", fmk_value_node);
726 return RET_OK;
727 }
728
ConvertOnnxGraph(const onnx::GraphProto & onnx_graph,const FuncGraphPtr & anf_graph,std::unordered_map<std::string,AnfNodePtr> * anf_nodes_map,std::vector<AnfNodePtr> * extra_subgraph_inputs,const std::string & root_node_name)729 STATUS OnnxModelParser::ConvertOnnxGraph(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &anf_graph,
730 std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
731 std::vector<AnfNodePtr> *extra_subgraph_inputs,
732 const std::string &root_node_name) {
733 MS_ASSERT(anf_graph != nullptr && anf_nodes_map != nullptr && extra_subgraph_inputs != nullptr);
734 STATUS status = RET_OK;
735 status = CheckOnnxModel(onnx_graph);
736 if (status != RET_OK) {
737 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
738 MS_LOG(ERROR) << "input onnx model error: " << status;
739 return status;
740 }
741 status = ConvertConstTensors(onnx_graph, anf_graph, anf_nodes_map, model_file_);
742 if (RET_OK != status) {
743 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
744 MS_LOG(ERROR) << "convert const nodes failed.";
745 return RET_ERROR;
746 }
747
748 status = ConvertGraphInputs(onnx_graph, anf_graph, anf_nodes_map);
749 if (RET_OK != status) {
750 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
751 MS_LOG(ERROR) << "convert graph inputs failed.";
752 return RET_OK;
753 }
754
755 status = ConvertNodes(onnx_graph, anf_graph, anf_nodes_map, extra_subgraph_inputs, root_node_name);
756 if (RET_OK != status) {
757 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
758 MS_LOG(ERROR) << "convert nodes failed.";
759 return RET_ERROR;
760 }
761
762 status = ConvertGraphOutputs(onnx_graph, anf_graph, *anf_nodes_map);
763 if (RET_OK != status) {
764 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
765 MS_LOG(ERROR) << "convert graph outputs failed.";
766 return RET_ERROR;
767 }
768 // save original output tensor names.
769 if (root_node_name == "root_node") {
770 std::vector<std::string> output_names;
771 std::transform(onnx_graph.output().begin(), onnx_graph.output().end(), std::back_inserter(output_names),
772 [](auto &graph_output) { return graph_output.name(); });
773 ConverterInnerContext::GetInstance()->SetGraphOutputTensorNames(output_names);
774 }
775 return status;
776 }
777
SortOnnxNodeIndex(const onnx::GraphProto & onnx_graph)778 std::vector<int> OnnxModelParser::SortOnnxNodeIndex(const onnx::GraphProto &onnx_graph) {
779 std::vector<int> sorted_node_index;
780 std::queue<int> onnx_nodes_queue;
781 std::set<std::string> node_names;
782 // for const tensor
783 for (const auto &const_tensor : onnx_graph.initializer()) {
784 const auto &name = const_tensor.name();
785 node_names.insert(name);
786 }
787 // for graph input
788 for (int i = 0; i < onnx_graph.input().size(); i++) {
789 node_names.insert(onnx_graph.input(i).name());
790 }
791 for (int i = 0; i < onnx_graph.node().size(); i++) {
792 auto onnx_node = onnx_graph.node(i);
793 if (onnx_node.op_type() == "If" || onnx_node.op_type() == "Loop" || has_subgraph_) {
794 sorted_node_index.clear();
795 has_subgraph_ = true;
796 for (int index = 0; index < onnx_graph.node().size(); index++) {
797 sorted_node_index.push_back(index);
798 }
799 return sorted_node_index;
800 }
801 if (onnx_node.op_type() == "Constant") {
802 sorted_node_index.push_back(i);
803 for (auto output_name : onnx_node.output()) {
804 node_names.insert(output_name);
805 }
806 } else {
807 onnx_nodes_queue.push(i);
808 }
809 }
810 bool find = false;
811 int pre_node_index = -1;
812 while (!onnx_nodes_queue.empty()) {
813 auto node_index = onnx_nodes_queue.front();
814 auto onnx_node = onnx_graph.node(node_index);
815 if (std::any_of(onnx_node.input().begin(), onnx_node.input().end(),
816 [&](const string &name) { return node_names.count(name) == 0 && !name.empty(); })) {
817 onnx_nodes_queue.pop();
818 onnx_nodes_queue.push(node_index);
819 if (!find && pre_node_index == node_index) {
820 MS_LOG(ERROR) << "sort onnx node failed.";
821 return {};
822 }
823 find = false;
824 pre_node_index = pre_node_index == -1 ? node_index : pre_node_index;
825 } else {
826 find = true;
827 pre_node_index = pre_node_index == node_index ? -1 : pre_node_index;
828 sorted_node_index.push_back(node_index);
829 onnx_nodes_queue.pop();
830 for (int i = 0; i < onnx_node.output_size(); i++) {
831 node_names.insert(onnx_node.output(i));
832 }
833 }
834 }
835 return sorted_node_index;
836 }
837
ConvertNodes(const onnx::GraphProto & onnx_graph,const FuncGraphPtr & anf_graph,std::unordered_map<std::string,AnfNodePtr> * anf_nodes_map,std::vector<AnfNodePtr> * graph_inputs,const std::string & root_node_name)838 STATUS OnnxModelParser::ConvertNodes(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &anf_graph,
839 std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
840 std::vector<AnfNodePtr> *graph_inputs, const std::string &root_node_name) {
841 CHECK_NULL_RETURN(anf_graph);
842 CHECK_NULL_RETURN(anf_nodes_map);
843 auto sorted_node_index = SortOnnxNodeIndex(onnx_graph);
844 if (sorted_node_index.empty()) {
845 MS_LOG(ERROR) << "SortOnnxNodeIndex failed.";
846 return RET_ERROR;
847 }
848 STATUS status = RET_OK;
849 for (auto node_index : sorted_node_index) {
850 const auto &onnx_node = onnx_graph.node(node_index);
851 ops::PrimitiveCPtr primitive_c;
852 auto node_parser = registry::NodeParserRegistry::GetNodeParser(kFmkTypeOnnx, onnx_node.op_type());
853 if (node_parser != nullptr) {
854 primitive_c = node_parser->Parse(onnx_graph, onnx_node)->GetPrim();
855 } else {
856 auto node_parser_builtin = OnnxNodeParserRegistry::GetInstance().GetNodeParser(onnx_node.op_type());
857 if (node_parser_builtin == nullptr) {
858 NotSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
859 status = status == RET_OK ? RET_NOT_FIND_OP : status;
860 MS_LOG(ERROR) << "not support onnx data type " << onnx_node.op_type();
861 }
862 if (status != RET_OK) {
863 continue;
864 }
865 MS_LOG(INFO) << "parse op:" << onnx_node.op_type();
866 primitive_c = node_parser_builtin->Parse(onnx_graph, onnx_node);
867 }
868
869 if (primitive_c == nullptr) {
870 MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed.";
871 status = RET_ERROR;
872 continue;
873 }
874 if (primitive_c->GetAttr(ops::kOriginalFormat) == nullptr) {
875 primitive_c->AddAttr(mindspore::ops::kOriginalFormat, MakeValue<int64_t>(NCHW));
876 }
877 status = ConvertOpQuantParams(onnx_node, primitive_c);
878 if (status != RET_OK) {
879 MS_LOG(ERROR) << "convert " << onnx_node.op_type() << " quant param failed.";
880 continue;
881 }
882 // build CNode
883 status = BuildCNode(onnx_node, anf_graph, anf_nodes_map, graph_inputs, primitive_c, root_node_name);
884 if (status != RET_OK) {
885 MS_LOG(ERROR) << "build cnode " << onnx_node.op_type() << " failed.";
886 }
887
888 if (onnx_node.op_type() == "Loop") {
889 child_root_map_[onnx_node.name()] = root_node_name;
890 control_nodes_map_[onnx_node.name()] = anf_nodes_map;
891
892 status = ConvertLoopOnnxNode(onnx_node, anf_nodes_map, root_node_name);
893 if (status != RET_OK) {
894 MS_LOG(ERROR) << "build loop node failed.";
895 }
896 }
897 if (onnx_node.op_type() == "If") {
898 child_root_map_[onnx_node.name()] = root_node_name;
899 control_nodes_map_[onnx_node.name()] = anf_nodes_map;
900
901 status = ConvertIfOnnxNode(onnx_node, anf_nodes_map, root_node_name);
902 if (status != RET_OK) {
903 MS_LOG(ERROR) << "build if node failed.";
904 }
905 }
906 }
907 return status;
908 }
909
ConvertIfSubgraph(const onnx::GraphProto & subgraph_proto,const FuncGraphPtr & subgraph,const std::string & subgraph_name,const std::string & if_node_name,const std::string & root_node_name)910 STATUS OnnxModelParser::ConvertIfSubgraph(const onnx::GraphProto &subgraph_proto, const FuncGraphPtr &subgraph,
911 const std::string &subgraph_name, const std::string &if_node_name,
912 const std::string &root_node_name) {
913 MS_CHECK_TRUE_RET(subgraph != nullptr, RET_NULL_PTR);
914 std::unordered_map<std::string, AnfNodePtr> anf_nodes_map;
915 std::vector<AnfNodePtr> subgraph_extra_inputs;
916 auto status = ConvertOnnxGraph(subgraph_proto, subgraph, &anf_nodes_map, &subgraph_extra_inputs, if_node_name);
917 if (status != RET_OK) {
918 MS_LOG(ERROR) << "convert loop OnnxGraph failed";
919 return status;
920 }
921 subgraph->set_attr("graph_name", MakeValue(subgraph_name));
922 // update subgraph in out name
923 for (int j = 0; j < subgraph_proto.input_size(); j++) {
924 auto input_anode_iter = anf_nodes_map.find(subgraph_proto.input(j).name());
925 if (input_anode_iter == anf_nodes_map.end()) {
926 MS_LOG(ERROR) << "cannot find input anode";
927 return RET_ERROR;
928 }
929 auto input_parameter = input_anode_iter->second->cast<ParameterPtr>();
930 MS_CHECK_TRUE_MSG(input_parameter != nullptr, RET_ERROR, "subgraph input should be a parameter");
931 input_parameter->set_name(subgraph_name + "_input_" + std::to_string(j) + "_parameter");
932 }
933 for (size_t j = 0; j < subgraph_extra_inputs.size(); j++) {
934 auto input_parameter = subgraph_extra_inputs[j]->cast<ParameterPtr>();
935 MS_CHECK_TRUE_MSG(input_parameter != nullptr, RET_ERROR, "subgraph input should be a parameter");
936 input_parameter->set_name(subgraph_name + "_input_" + std::to_string(j + subgraph_proto.input_size()) +
937 "_parameter");
938 }
939 auto return_node = subgraph->get_return();
940 MS_CHECK_TRUE_MSG(return_node != nullptr, RET_ERROR, "subgraph has no return");
941 MS_CHECK_GE(return_node->size(), kInputSize1, RET_ERROR);
942 std::vector<AnfNodePtr> return_act_inputs;
943 int start_index = 0;
944 if (subgraph_proto.output_size() > 1) {
945 auto return_cnode = return_node->input(1)->cast<CNodePtr>();
946 MS_CHECK_TRUE_RET(return_cnode != nullptr, RET_NULL_PTR);
947 return_act_inputs = return_cnode->inputs();
948 start_index = 1;
949 } else {
950 return_act_inputs = {return_node->input(1)};
951 }
952 for (size_t j = start_index; j < return_act_inputs.size(); j++) {
953 if (utils::isa<CNodePtr>(return_act_inputs[j])) {
954 return_act_inputs[j]->cast<CNodePtr>()->set_fullname_with_scope(subgraph_name + "_output_" +
955 std::to_string(j - start_index) + "_cnode");
956 } else if (utils::isa<ParameterPtr>(return_act_inputs[j])) {
957 return_act_inputs[j]->cast<ParameterPtr>()->set_name(subgraph_name + "_output_" +
958 std::to_string(j - start_index) + "_parameter");
959 }
960 }
961 return RET_OK;
962 }
963
ConvertIfOnnxNode(const onnx::NodeProto & onnx_node,std::unordered_map<std::string,AnfNodePtr> * anf_root_nodes_map,const std::string & root_node_name)964 STATUS OnnxModelParser::ConvertIfOnnxNode(const onnx::NodeProto &onnx_node,
965 std::unordered_map<std::string, AnfNodePtr> *anf_root_nodes_map,
966 const std::string &root_node_name) {
967 CHECK_NULL_RETURN(anf_root_nodes_map);
968 FuncGraphPtr then_branch_graph = nullptr;
969 FuncGraphPtr else_branch_graph = nullptr;
970 FuncGraphPtr subgraph = nullptr;
971 std::string subgraph_name;
972 auto &if_node_name = onnx_node.name();
973
974 for (int i = 0; i < onnx_node.attribute_size(); i++) {
975 auto &attr = onnx_node.attribute(i);
976 auto &subgraph_proto = attr.g();
977 if (attr.name().find("then_branch") != std::string::npos) {
978 subgraph_name = if_node_name + "_then_branch";
979 then_branch_graph = std::make_shared<FuncGraph>();
980 MS_CHECK_TRUE_MSG(then_branch_graph != nullptr, RET_NULL_PTR, "create then_branch_graph return nullptr");
981 auto status = ConvertIfSubgraph(subgraph_proto, then_branch_graph, subgraph_name, if_node_name, root_node_name);
982 if (status != RET_OK) {
983 MS_LOG(ERROR) << "build if node else branch failed.";
984 }
985 } else if (attr.name().find("else_branch") != std::string::npos) {
986 subgraph_name = if_node_name + "_else_branch";
987 else_branch_graph = std::make_shared<FuncGraph>();
988 MS_CHECK_TRUE_MSG(else_branch_graph != nullptr, RET_NULL_PTR, "create else_branch_graph return nullptr");
989 auto status = ConvertIfSubgraph(subgraph_proto, else_branch_graph, subgraph_name, if_node_name, root_node_name);
990 if (status != RET_OK) {
991 MS_LOG(ERROR) << "build if node else branch failed.";
992 }
993 } else {
994 continue;
995 }
996 }
997 all_subgraphs_.emplace_back(then_branch_graph);
998 all_subgraphs_.emplace_back(else_branch_graph);
999 auto then_value_node = NewValueNode(then_branch_graph);
1000 MS_CHECK_TRUE_MSG(then_value_node != nullptr, RET_NULL_PTR, "create then_value_node return nullptr");
1001 auto else_value_node = NewValueNode(else_branch_graph);
1002 MS_CHECK_TRUE_MSG(else_value_node != nullptr, RET_NULL_PTR, "create else_value_node return nullptr");
1003 auto root_if_node = GetCNodeFromControlFlowNodesMap(if_node_name, control_nodes_map_);
1004 MS_CHECK_TRUE_MSG(root_if_node != nullptr, RET_ERROR, "cannot find root_if_node is control_nodes_map");
1005 auto if_new_inputs = root_if_node->inputs();
1006 if_new_inputs.insert(if_new_inputs.begin() + 1, {then_value_node, else_value_node});
1007
1008 std::vector<AnfNodePtr> if_new_input_not_same{};
1009 std::set<AnfNodePtr> if_set{};
1010 for (auto &input : if_new_inputs) {
1011 if (if_set.find(input) != if_set.end()) {
1012 continue;
1013 }
1014 if_new_input_not_same.push_back(input);
1015 if_set.insert(input);
1016 }
1017
1018 root_if_node->set_inputs(if_new_input_not_same);
1019 return RET_OK;
1020 }
1021
BuildCNode(const onnx::NodeProto & onnx_node,const FuncGraphPtr & anf_graph,std::unordered_map<std::string,AnfNodePtr> * anf_nodes_map,std::vector<AnfNodePtr> * graph_inputs,PrimitiveCPtr primitive_c,std::string loop_name)1022 STATUS OnnxModelParser::BuildCNode(const onnx::NodeProto &onnx_node, const FuncGraphPtr &anf_graph,
1023 std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
1024 std::vector<AnfNodePtr> *graph_inputs, PrimitiveCPtr primitive_c,
1025 std::string loop_name) {
1026 CHECK_NULL_RETURN(anf_graph);
1027 CHECK_NULL_RETURN(anf_nodes_map);
1028 CHECK_NULL_RETURN(primitive_c);
1029 std::vector<AnfNodePtr> op_inputs;
1030 for (int i = 0; i < onnx_node.input_size(); i++) {
1031 auto input_name = onnx_node.input(i);
1032 if (input_name.empty()) {
1033 std::string empty_input_index = "empty_input_index";
1034 primitive_c->AddAttr(empty_input_index, MakeValue<int>(i));
1035 continue;
1036 }
1037
1038 if (anf_nodes_map->find(input_name) != anf_nodes_map->end()) {
1039 op_inputs.push_back(anf_nodes_map->at(input_name));
1040 } else {
1041 // subgraph may refer root graph nodes
1042 std::vector<CNodePtr> need_add_input_nodes;
1043 auto ext_subgraph_input = anf_graph->add_parameter();
1044 MS_CHECK_TRUE_MSG(ext_subgraph_input != nullptr, RET_NULL_PTR, "create parameter return nullptr");
1045 ParameterPtr inner_extra_paramter = nullptr;
1046 while (!loop_name.empty() && child_root_map_.find(loop_name) != child_root_map_.end()) {
1047 auto cur_node_map = control_nodes_map_[loop_name];
1048 CHECK_NULL_RETURN(cur_node_map);
1049 if (cur_node_map->find(input_name) != cur_node_map->end()) {
1050 auto outside_input_node = cur_node_map->at(input_name);
1051 CHECK_NULL_RETURN(outside_input_node);
1052 // copy outside input parameter value to inside subgraph
1053 ext_subgraph_input->set_abstract(outside_input_node->abstract());
1054 ext_subgraph_input->set_name(input_name);
1055 if (outside_input_node->isa<Parameter>()) {
1056 auto parameter = outside_input_node->cast<ParameterPtr>();
1057 if (!parameter->has_default()) {
1058 MS_LOG(ERROR) << "outside_input_node should has data.";
1059 return RET_ERROR;
1060 }
1061 auto tensor_info = parameter->default_param()->cast<tensor::TensorPtr>();
1062 auto copy_tensor_info = CreateTensorInfo(tensor_info->data_c(), tensor_info->Size(), tensor_info->shape(),
1063 tensor_info->data_type());
1064 if (copy_tensor_info == nullptr) {
1065 MS_LOG(ERROR) << "memcpy failed.";
1066 return RET_ERROR;
1067 }
1068 ext_subgraph_input->set_default_param(copy_tensor_info);
1069 } else {
1070 // output inside cnode need make extra input
1071 CHECK_NULL_RETURN(graph_inputs);
1072 graph_inputs->emplace_back(ext_subgraph_input);
1073 if (cur_node_map->find(loop_name) != cur_node_map->end()) {
1074 CHECK_NULL_RETURN(cur_node_map->at(loop_name));
1075 auto control_node = cur_node_map->at(loop_name)->cast<CNodePtr>();
1076 MS_CHECK_TRUE_RET(control_node != nullptr, RET_NULL_PTR);
1077 control_node->add_input(outside_input_node);
1078 } else {
1079 MS_LOG(ERROR) << "loop node: " << loop_name << " not found in cur node map.";
1080 return RET_ERROR;
1081 }
1082 for (auto &control_node : need_add_input_nodes) {
1083 CHECK_NULL_RETURN(control_node);
1084 auto func_graph = control_node->func_graph();
1085 auto extra_input_parameter = func_graph->add_parameter();
1086 MS_CHECK_TRUE_MSG(extra_input_parameter != nullptr, RET_NULL_PTR, "create parameter return nullptr");
1087 extra_input_parameter->set_name(input_name);
1088 extra_input_parameter->set_abstract(outside_input_node->abstract());
1089 control_node->add_input(extra_input_parameter);
1090 }
1091 }
1092 op_inputs.push_back(ext_subgraph_input);
1093 anf_nodes_map->emplace(input_name, ext_subgraph_input);
1094 break;
1095 } else {
1096 if (cur_node_map->find(loop_name) != cur_node_map->end()) {
1097 CHECK_NULL_RETURN(cur_node_map->at(loop_name));
1098 need_add_input_nodes.emplace_back(cur_node_map->at(loop_name)->cast<CNodePtr>());
1099 } else {
1100 MS_LOG(ERROR) << "loop node: " << loop_name << " not found in cur node map.";
1101 return RET_ERROR;
1102 }
1103 loop_name = child_root_map_[loop_name];
1104 }
1105 }
1106 }
1107 }
1108 auto new_cnode = anf_graph->NewCNode(primitive_c, op_inputs);
1109 if (new_cnode == nullptr) {
1110 MS_LOG(ERROR) << "new cnode error";
1111 return RET_ERROR;
1112 }
1113 new_cnode->set_fullname_with_scope(onnx_node.name());
1114 auto status = BuildOpOutputs(onnx_node, anf_graph, anf_nodes_map, new_cnode);
1115 return status;
1116 }
1117
ConvertOpQuantParams(const onnx::NodeProto & onnx_node,ops::PrimitiveCPtr primitive_c)1118 STATUS OnnxModelParser::ConvertOpQuantParams(const onnx::NodeProto &onnx_node, ops::PrimitiveCPtr primitive_c) {
1119 CHECK_NULL_RETURN(primitive_c);
1120 auto status = ParseQuantParam(onnx_node);
1121 if (status != RET_OK) {
1122 MS_LOG(ERROR) << "parse quant param failed.";
1123 return RET_ERROR;
1124 }
1125 // set input tensors
1126 std::map<int, std::vector<schema::QuantParamT>> input_quant_params;
1127 size_t idx = 0;
1128 for (int i = 0; i < onnx_node.input_size(); ++i) {
1129 const auto &input_name = onnx_node.input(i);
1130 std::vector<schema::QuantParamT> quant_params;
1131 status = SetTensorQuantParam(input_name, &quant_params);
1132 if (status != RET_OK) {
1133 MS_LOG(ERROR) << "set input tensor quant param failed.";
1134 return status;
1135 }
1136 if (!quant_params.empty()) {
1137 input_quant_params.insert({idx, quant_params});
1138 idx++;
1139 }
1140 }
1141 // set out tensors
1142 idx = 0;
1143 std::map<int, std::vector<schema::QuantParamT>> output_quant_params;
1144 for (int i = 0; i < onnx_node.output_size(); ++i) {
1145 const auto &output_name = onnx_node.output(i);
1146 std::vector<schema::QuantParamT> quant_params;
1147 status = SetTensorQuantParam(output_name, &quant_params);
1148 if (status != RET_OK) {
1149 MS_LOG(ERROR) << "set output tensor quant param failed.";
1150 return status;
1151 }
1152 if (!quant_params.empty()) {
1153 output_quant_params.insert({idx, quant_params});
1154 idx++;
1155 }
1156 }
1157 schema::QuantParamT quant_param;
1158 bool has_quant = false;
1159 for (const auto &onnx_node_attr : onnx_node.attribute()) {
1160 if (onnx_node_attr.name() == "scale") {
1161 float scale = onnx_node_attr.f();
1162 quant_param.scale = scale;
1163 MS_LOG(INFO) << onnx_node.name() << " scale is " << quant_param.scale;
1164 has_quant = true;
1165 } else if (onnx_node_attr.name() == "offset") {
1166 float offset = onnx_node_attr.f();
1167 quant_param.zeroPoint = static_cast<int>(offset);
1168 MS_LOG(INFO) << onnx_node.name() << " offset is " << quant_param.zeroPoint;
1169 has_quant = true;
1170 } else if (onnx_node_attr.name() == "quant_bit") {
1171 int64_t quant_bit = onnx_node_attr.i();
1172 quant_param.numBits = static_cast<int>(quant_bit);
1173 MS_LOG(INFO) << onnx_node.name() << " quant_bit is " << quant_param.numBits;
1174 has_quant = true;
1175 }
1176 }
1177 if (has_quant) {
1178 std::vector<schema::QuantParamT> quant_params;
1179 quant_param.inited = true;
1180 quant_params.push_back(quant_param);
1181 input_quant_params.insert({0, quant_params});
1182 output_quant_params.insert({0, quant_params});
1183 }
1184 if (!input_quant_params.empty() || !output_quant_params.empty()) {
1185 auto quant_params_holder = std::make_shared<QuantParamHolder>(0, 0);
1186 CHECK_NULL_RETURN(quant_params_holder);
1187 for (auto &iter : input_quant_params) {
1188 quant_params_holder->set_input_quant_param(iter.first, iter.second);
1189 }
1190 for (auto &iter : output_quant_params) {
1191 quant_params_holder->set_output_quant_param(iter.first, iter.second);
1192 }
1193 primitive_c->AddAttr("quant_params", quant_params_holder);
1194 }
1195 return RET_OK;
1196 }
1197
ParseQuantParam(const onnx::NodeProto & onnx_node)1198 STATUS OnnxModelParser::ParseQuantParam(const onnx::NodeProto &onnx_node) {
1199 for (const auto &onnx_node_attr : onnx_node.attribute()) {
1200 if (onnx_node_attr.name() == "Y_scale") {
1201 float scale = onnx_node_attr.f();
1202 if (BuildParameterNodeForQuantParam(&scale, "scale_" + onnx_node.output(0), kNumberTypeFloat32) != RET_OK) {
1203 MS_LOG(ERROR) << "parse quant param failed.";
1204 return RET_ERROR;
1205 }
1206 } else if (onnx_node_attr.name() == "Y_zero_point") {
1207 int64_t zero_point = onnx_node_attr.i();
1208 if (BuildParameterNodeForQuantParam(&zero_point, "zero_point_" + onnx_node.output(0), kNumberTypeInt64) !=
1209 RET_OK) {
1210 MS_LOG(ERROR) << "parse quant param failed.";
1211 return RET_ERROR;
1212 }
1213 }
1214 }
1215 return RET_OK;
1216 }
1217
SetTensorQuantParam(const std::string & tensor_name,std::vector<QuantParamT> * quant_params)1218 STATUS OnnxModelParser::SetTensorQuantParam(const std::string &tensor_name, std::vector<QuantParamT> *quant_params) {
1219 MS_CHECK_TRUE_RET(quant_params != nullptr, RET_NULL_PTR);
1220 quant_params->clear();
1221 auto quant_param = std::make_unique<QuantParamT>();
1222 MS_CHECK_TRUE_MSG(quant_param != nullptr, RET_NULL_PTR, "create QuantParamT return nullptr");
1223 for (int i = 0; i < onnx_root_graph_.quantization_annotation_size(); ++i) {
1224 auto tensor_annotation = onnx_root_graph_.quantization_annotation(i);
1225 if (!tensor_annotation.has_tensor_name() || tensor_annotation.tensor_name() != tensor_name) {
1226 continue;
1227 }
1228 for (const auto &item : tensor_annotation.quant_parameter_tensor_names()) {
1229 if (!item.has_key() || !item.has_value()) {
1230 continue;
1231 }
1232
1233 const auto &quant_tensor_name = item.value();
1234 if (item.key() == "SCALE_TENSOR") {
1235 auto status = CopyTensorQuantParam(quant_tensor_name, quant_param.get(), true);
1236 if (status != RET_OK) {
1237 MS_LOG(ERROR) << "quant param scale get failed";
1238 return status;
1239 }
1240 } else if (item.key() == "ZERO_POINT_TENSOR") {
1241 auto status = CopyTensorQuantParam(quant_tensor_name, quant_param.get(), false);
1242 if (status != RET_OK) {
1243 MS_LOG(ERROR) << "quant param zero_point get failed";
1244 return status;
1245 }
1246 }
1247 }
1248 break;
1249 }
1250 if (quant_param->inited) {
1251 quant_params->push_back(*std::move(quant_param));
1252 return RET_OK;
1253 }
1254 return SetTensorQuantParamFromNode(tensor_name, quant_params);
1255 }
1256
SetTensorQuantParamFromNode(const std::string & tensor_name,std::vector<QuantParamT> * quant_params)1257 STATUS OnnxModelParser::SetTensorQuantParamFromNode(const std::string &tensor_name,
1258 std::vector<QuantParamT> *quant_params) {
1259 MS_CHECK_TRUE_RET(quant_params != nullptr, RET_NULL_PTR);
1260 quant_params->clear();
1261 auto quant_param = std::make_unique<QuantParamT>();
1262 MS_CHECK_TRUE_MSG(quant_param != nullptr, RET_NULL_PTR, "create QuantParamT return nullptr");
1263 if (OnnxNodeParser::opset_version() <= 15) {
1264 quant_param->multiplier = 0;
1265 }
1266 std::string quant_tensor_name = "scale_" + tensor_name;
1267 auto status = CopyTensorQuantParam(quant_tensor_name, quant_param.get(), true);
1268 if (status != RET_OK) {
1269 MS_LOG(ERROR) << "quant param scale get failed";
1270 return status;
1271 }
1272 quant_tensor_name = "zero_point_" + tensor_name;
1273 status = CopyTensorQuantParam(quant_tensor_name, quant_param.get(), false);
1274 if (status != RET_OK) {
1275 MS_LOG(ERROR) << "quant param zero_point get failed";
1276 return status;
1277 }
1278 if (quant_param->inited) {
1279 quant_params->push_back(*std::move(quant_param));
1280 }
1281 return RET_OK;
1282 }
1283
CopyTensorQuantParam(const std::string & tensor_name,QuantParamT * quant_param,bool scale_or_not)1284 STATUS OnnxModelParser::CopyTensorQuantParam(const std::string &tensor_name, QuantParamT *quant_param,
1285 bool scale_or_not) {
1286 CHECK_NULL_RETURN(quant_param);
1287 auto iter = anf_nodes_map_.find(tensor_name);
1288 if (iter == anf_nodes_map_.end()) {
1289 MS_LOG(DEBUG) << "has no quant param";
1290 return RET_OK;
1291 }
1292 if (!utils::isa<ParameterPtr>(iter->second)) {
1293 MS_LOG(ERROR) << "quant param get failed";
1294 return RET_ERROR;
1295 }
1296 auto quant_parameter_node = iter->second->cast<ParameterPtr>();
1297 MS_CHECK_TRUE_RET(quant_parameter_node != nullptr, RET_NULL_PTR);
1298 if (!quant_parameter_node->has_default()) {
1299 MS_LOG(ERROR) << "quant param get failed";
1300 return RET_ERROR;
1301 }
1302 auto tensor_info = quant_parameter_node->default_param()->cast<tensor::TensorPtr>();
1303 if (tensor_info == nullptr) {
1304 MS_LOG(ERROR) << "parameterNode's default param is not tensor::TensorPtr";
1305 return RET_ERROR;
1306 }
1307 if (tensor_info->data_c() == nullptr) {
1308 MS_LOG(ERROR) << "parameterNode's default param has no data";
1309 return RET_ERROR;
1310 }
1311 if (scale_or_not) {
1312 quant_param->scale = *reinterpret_cast<float *>(tensor_info->data_c());
1313 quant_param->inited = true;
1314 } else {
1315 quant_param->zeroPoint = *reinterpret_cast<int64_t *>(tensor_info->data_c());
1316 quant_param->inited = true;
1317 }
1318 return RET_OK;
1319 }
1320
AddTensorListStackNode(const AnfNodePtr & root_while_node,const onnx::NodeProto & onnx_node,int act_outputs_num,int body_output_size)1321 STATUS OnnxModelParser::AddTensorListStackNode(const AnfNodePtr &root_while_node, const onnx::NodeProto &onnx_node,
1322 int act_outputs_num, int body_output_size) {
1323 MS_CHECK_TRUE_RET(root_while_node != nullptr, RET_NULL_PTR);
1324 auto &loop_node_name = onnx_node.name();
1325 auto root_anf_graph = root_while_node->func_graph();
1326 auto stack_elem_node = CreateConstParamter(root_anf_graph, -1);
1327 MS_CHECK_TRUE_MSG(stack_elem_node != nullptr, RET_NULL_PTR, "create const parameter return nullptr");
1328 stack_elem_node->set_name(loop_node_name + "_element_shape");
1329 for (int j = 0; j < act_outputs_num; j++) {
1330 auto output_size = onnx_node.output_size();
1331 auto &loop_output_name = onnx_node.output(output_size - act_outputs_num + j);
1332 MS_CHECK_TRUE_RET(control_nodes_map_.find(loop_node_name) != control_nodes_map_.end(), RET_NULL_PTR);
1333 MS_CHECK_TRUE_RET(
1334 control_nodes_map_[loop_node_name]->find(loop_output_name) != control_nodes_map_[loop_node_name]->end(),
1335 RET_NULL_PTR);
1336 auto &while_output_node = control_nodes_map_[loop_node_name]->at(loop_output_name);
1337 MS_CHECK_TRUE_MSG(while_output_node != nullptr, RET_ERROR, "cannot find while_output_node is control_nodes_map");
1338 auto tensor_list_stack_prim = std::make_shared<ops::TensorListStack>();
1339 if (tensor_list_stack_prim == nullptr) {
1340 MS_LOG(ERROR) << "create stack failed";
1341 return RET_ERROR;
1342 }
1343 tensor_list_stack_prim->set_num_elements(-1);
1344 auto prim_c = tensor_list_stack_prim->GetPrim();
1345 MS_CHECK_TRUE_RET(prim_c != nullptr, RET_ERROR);
1346 auto stack_value_node = NewValueNode(prim_c);
1347 MS_CHECK_TRUE_MSG(stack_value_node != nullptr, RET_NULL_PTR, "create stack_value_node return nullptr");
1348 std::vector<AnfNodePtr> stack_inputs = {stack_value_node, while_output_node, stack_elem_node};
1349 auto tensorlist_stack_cnode = root_anf_graph->NewCNode(stack_inputs);
1350 if (tensorlist_stack_cnode == nullptr) {
1351 MS_LOG(ERROR) << "new cnode error";
1352 return RET_ERROR;
1353 }
1354 tensorlist_stack_cnode->set_fullname_with_scope(loop_node_name + "_tensorlist_stack_node_" + std::to_string(j));
1355 tensorlist_stack_cnode->set_abstract(stack_elem_node->abstract());
1356
1357 // update getitem value output index
1358 auto new_get_item_value = NewValueNode(MakeValue<int64_t>(body_output_size - act_outputs_num + j));
1359 MS_CHECK_TRUE_MSG(new_get_item_value != nullptr, RET_NULL_PTR, "create new_get_item_value return nullptr");
1360 CHECK_NULL_RETURN(while_output_node->cast<CNodePtr>());
1361 while_output_node->cast<CNodePtr>()->set_input(2, new_get_item_value);
1362 // insert tensorliststack after while_output
1363 (*control_nodes_map_[loop_node_name])[loop_output_name] = tensorlist_stack_cnode;
1364 }
1365 return RET_OK;
1366 }
1367
1368 // onnx loop scan_output need through tensorlist op,while node need add new inputs
AddTensorArrayEdge(const FuncGraphPtr & anf_graph,std::vector<AnfNodePtr> * return_new_inputs,const std::string & loop_node_name,std::vector<AnfNodePtr> * body_graph_inputs,int act_output_num)1369 STATUS OnnxModelParser::AddTensorArrayEdge(const FuncGraphPtr &anf_graph, std::vector<AnfNodePtr> *return_new_inputs,
1370 const std::string &loop_node_name,
1371 std::vector<AnfNodePtr> *body_graph_inputs, int act_output_num) {
1372 MS_CHECK_TRUE_RET(anf_graph != nullptr && return_new_inputs != nullptr && body_graph_inputs != nullptr, RET_NULL_PTR);
1373 // body graph output is trip_count,cond_count,loop_var,placeholder,scan_outputs
1374 auto root_while_node = GetCNodeFromControlFlowNodesMap(loop_node_name, control_nodes_map_);
1375 MS_CHECK_TRUE_MSG(root_while_node != nullptr, RET_ERROR, "cannot find root_while_node is control_nodes_map");
1376 if (root_while_node == nullptr) {
1377 MS_LOG(ERROR) << "anf root node map cannot find loop node" << loop_node_name;
1378 return RET_ERROR;
1379 }
1380 auto anf_root_graph = root_while_node->func_graph();
1381 auto root_item_index_parameter = CreateConstParamter(anf_root_graph, 0);
1382 MS_CHECK_TRUE_MSG(root_item_index_parameter != nullptr, RET_NULL_PTR,
1383 "create root_item_index_parameter return nullptr");
1384 root_item_index_parameter->set_name(loop_node_name + "_item_index");
1385 root_while_node->add_input(root_item_index_parameter);
1386 // fake parameter need pass by root while node input
1387 auto item_index_parameter = anf_graph->add_parameter();
1388 MS_CHECK_TRUE_MSG(item_index_parameter != nullptr, RET_NULL_PTR, "create item_index_parameter return nullptr");
1389 item_index_parameter->set_name(loop_node_name + "_item_index_2");
1390 item_index_parameter->set_abstract(root_item_index_parameter->abstract());
1391 body_graph_inputs->emplace_back(item_index_parameter);
1392 // item index++ edge
1393 auto add_value_node = CreateValueNode(schema::PrimitiveType_AddFusion);
1394 if (add_value_node == nullptr) {
1395 MS_LOG(ERROR) << "create add failed.";
1396 return RET_NULL_PTR;
1397 }
1398 auto add_one_input = CreateConstParamter(anf_graph, 1);
1399 MS_CHECK_TRUE_MSG(root_item_index_parameter != nullptr, RET_NULL_PTR, "create add_one_input return nullptr");
1400 add_one_input->set_name(loop_node_name + "_const_placeholder_1");
1401 std::vector<AnfNodePtr> add_inputs = {add_value_node, item_index_parameter, add_one_input};
1402 auto add_cnode = anf_graph->NewCNode(add_inputs);
1403 if (add_cnode == nullptr) {
1404 MS_LOG(ERROR) << "new cnode error";
1405 return RET_ERROR;
1406 }
1407 add_cnode->set_fullname_with_scope(loop_node_name + "item_index_add_node");
1408 add_cnode->set_abstract(root_item_index_parameter->abstract());
1409 // return node inputs will be trip_count,cond_out,loop_var,placeholder,tensorarray...
1410 if (static_cast<int>(return_new_inputs->size()) < act_output_num || act_output_num < 0) {
1411 MS_LOG(ERROR) << "act_output_num out of range of return_new_inputs";
1412 return RET_ERROR;
1413 }
1414 return_new_inputs->insert(return_new_inputs->end() - act_output_num, add_cnode);
1415
1416 for (int i = 0; i < act_output_num; i++) {
1417 // tensor_array need as root while input
1418 auto while_tensor_array_input = anf_root_graph->add_parameter();
1419 MS_CHECK_TRUE_MSG(while_tensor_array_input != nullptr, RET_NULL_PTR,
1420 "create while_tensor_array_input return nullptr");
1421 std::vector<int> tensor_list_data(kTensorListDatasize);
1422 tensor_list_data[kTypeIndex] = kTypeUnknown;
1423 tensor_list_data[kElementShapeIndex] = 0;
1424 tensor_list_data[kTensorsNumIndex] = -1;
1425 if (INT_MUL_OVERFLOW_THRESHOLD(tensor_list_data.size(), sizeof(int), SIZE_MAX)) {
1426 MS_LOG(ERROR) << "data_size overflow";
1427 return RET_ERROR;
1428 }
1429 auto tensor_info = CreateTensorInfo(tensor_list_data.data(), tensor_list_data.size() * sizeof(int),
1430 {static_cast<int64_t>(tensor_list_data.size())}, kObjectTypeTensorType);
1431 if (tensor_info == nullptr) {
1432 MS_LOG(ERROR) << "Create tensor info failed";
1433 return RET_ERROR;
1434 }
1435 auto abstract_tensor = tensor_info->ToAbstract();
1436 if (abstract_tensor == nullptr) {
1437 MS_LOG(ERROR) << "Create tensor abstarct failed";
1438 return RET_ERROR;
1439 }
1440 while_tensor_array_input->set_abstract(abstract_tensor);
1441 while_tensor_array_input->set_default_param(tensor_info);
1442 while_tensor_array_input->set_name(loop_node_name + "_scan_outputs_tensorarray_while_input");
1443 root_while_node->add_input(while_tensor_array_input);
1444
1445 auto subgraph_tensor_array_input = anf_graph->add_parameter();
1446 MS_CHECK_TRUE_MSG(subgraph_tensor_array_input != nullptr, RET_NULL_PTR,
1447 "create subgraph_tensor_array_input return nullptr");
1448 subgraph_tensor_array_input->set_name(loop_node_name + "_scan_outputs_tensorarray_body_fg_input");
1449 subgraph_tensor_array_input->set_abstract(abstract_tensor);
1450 body_graph_inputs->emplace_back(subgraph_tensor_array_input);
1451 // skip trip_count ,cond_out,loop_var,no_loop_var,place_holder, output
1452 auto loop_output_idx = return_new_inputs->size() - act_output_num + i;
1453 auto loop_output_node = (*return_new_inputs)[loop_output_idx];
1454 auto set_item_value_node = CreateValueNode(schema::PrimitiveType_TensorListSetItem);
1455 if (set_item_value_node == nullptr) {
1456 MS_LOG(ERROR) << "create tensor list set item failed.";
1457 return RET_NULL_PTR;
1458 }
1459 std::vector<AnfNodePtr> set_item_inputs = {set_item_value_node, subgraph_tensor_array_input, item_index_parameter,
1460 loop_output_node};
1461 auto tensorlist_setitem_cnode = anf_graph->NewCNode(set_item_inputs);
1462 if (tensorlist_setitem_cnode == nullptr) {
1463 MS_LOG(ERROR) << "new cnode error";
1464 return RET_ERROR;
1465 }
1466 tensorlist_setitem_cnode->set_fullname_with_scope(loop_node_name + "_tensorlist_setitem_node");
1467 tensorlist_setitem_cnode->set_abstract(abstract_tensor);
1468 // loop output need replace by tensorliststack_output
1469 (*return_new_inputs)[loop_output_idx] = tensorlist_setitem_cnode;
1470 }
1471
1472 return RET_OK;
1473 }
1474
ConvertLoopOnnxNode(const onnx::NodeProto & onnx_node,std::unordered_map<std::string,AnfNodePtr> * anf_root_nodes_map,const std::string & root_node_name)1475 STATUS OnnxModelParser::ConvertLoopOnnxNode(const onnx::NodeProto &onnx_node,
1476 std::unordered_map<std::string, AnfNodePtr> *anf_root_nodes_map,
1477 const std::string &root_node_name) {
1478 MS_CHECK_TRUE_RET(anf_root_nodes_map != nullptr, RET_NULL_PTR);
1479 for (int i = 0; i < onnx_node.attribute_size(); i++) {
1480 auto &attr = onnx_node.attribute(i);
1481 if (attr.name() != "body" || attr.type() != onnx::AttributeProto_AttributeType_GRAPH) {
1482 continue;
1483 }
1484 auto &subgraph_proto = attr.g();
1485 int cond_graph_input_num = -1;
1486 auto loop_body_graph = BuildBodyGraph(onnx_node, subgraph_proto, &cond_graph_input_num);
1487 MS_CHECK_TRUE_MSG(loop_body_graph != nullptr, RET_NULL_PTR, "create loop_body_graph return nullptr");
1488 auto root_while_node = GetCNodeFromControlFlowNodesMap(onnx_node.name(), control_nodes_map_);
1489 MS_CHECK_TRUE_MSG(root_while_node != nullptr, RET_ERROR, "cannot find root_while_node");
1490 auto loop_cond_graph = BuildCondGraph(root_while_node, cond_graph_input_num, onnx_node.name() + "_cond_graph");
1491 MS_CHECK_TRUE_MSG(loop_cond_graph != nullptr, RET_NULL_PTR, "create loop_cond_graph return nullptr");
1492 all_subgraphs_.emplace_back(loop_body_graph);
1493 all_subgraphs_.emplace_back(loop_cond_graph);
1494 auto body_value_node = NewValueNode(loop_body_graph);
1495 MS_CHECK_TRUE_MSG(body_value_node != nullptr, RET_NULL_PTR, "create body_value_node return nullptr");
1496 auto inputs = root_while_node->inputs();
1497 auto cond_value_node = NewValueNode(loop_cond_graph);
1498 MS_CHECK_TRUE_MSG(cond_value_node != nullptr, RET_NULL_PTR, "create cond_value_node return nullptr");
1499 inputs.insert(inputs.begin() + 1, {cond_value_node, body_value_node});
1500 root_while_node->set_inputs(inputs);
1501 }
1502 return RET_OK;
1503 }
1504
BuildParameterNodeForQuantParam(const void * data,const std::string & name,TypeId type)1505 STATUS OnnxModelParser::BuildParameterNodeForQuantParam(const void *data, const std::string &name, TypeId type) {
1506 CHECK_NULL_RETURN(data);
1507 if (type != kNumberTypeInt64 && type != kNumberTypeFloat32) {
1508 MS_LOG(ERROR) << "quant param type don't support.";
1509 return RET_NOT_SUPPORT;
1510 }
1511 auto res_graph = ConvertGraph(res_graph_);
1512 auto parameter_node = res_graph->add_parameter();
1513 MS_CHECK_TRUE_MSG(parameter_node != nullptr, RET_NULL_PTR, "create parameter return nullptr");
1514 auto abstract_tensor = CreateTensorAbstract({}, type);
1515 if (abstract_tensor == nullptr) {
1516 MS_LOG(ERROR) << "Create tensor abstarct failed";
1517 return RET_ERROR;
1518 }
1519 parameter_node->set_abstract(abstract_tensor);
1520 parameter_node->set_name(name);
1521 int data_size = 0;
1522 if (type == kNumberTypeFloat32) {
1523 data_size = sizeof(float);
1524 } else {
1525 data_size = sizeof(int64_t);
1526 }
1527 auto tensor_info = CreateTensorInfo(data, data_size, {1}, type);
1528 if (tensor_info == nullptr) {
1529 MS_LOG(ERROR) << "create tensor info failed.";
1530 return RET_ERROR;
1531 }
1532 parameter_node->set_default_param(tensor_info);
1533 anf_nodes_map_.emplace(name, parameter_node);
1534 return RET_OK;
1535 }
1536
// Registers a LiteModelParserCreator<OnnxModelParser> as the model parser creator for framework type kFmkTypeOnnx.
REG_MODEL_PARSER(kFmkTypeOnnx, LiteModelParserCreator<OnnxModelParser>)
1538 } // namespace lite
1539 } // namespace mindspore
1540