/**
 * Copyright 2020-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <algorithm>
#include <functional>
#include <limits>
#include <map>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

#include "ir/func_graph.h"
#include "ir/param_info.h"
#include "ir/tensor.h"
#include "mindspore/core/ops/array_ops.h"
#include "mindspore/core/ops/comparison_ops.h"
#include "mindspore/core/ops/conv_pool_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "mindspore/core/ops/image_ops.h"
#include "mindspore/core/ops/math_ops.h"
#include "mindspore/core/ops/nn_ops.h"
#include "mindspore/core/ops/nn_optimizer_ops.h"
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/auto_generate/gen_ops_primitive.h"
#include "proto/onnx.pb.h"
#include "utils/check_convert_utils.h"
#include "utils/hash_map.h"
#include "utils/ms_context.h"
#include "ops/op_utils.h"

namespace mindspore {
const int ONNX_VERSION = 11;
const int kZeroNum = 0;
const int kOneNum = 1;
const int kTwoNum = 2;
const int kThreeNum = 3;
const int kFourNum = 4;
const int kFiveNum = 5;
const int kSixNum = 6;
const int kSevenNum = 7;
const int kEightNum = 8;
const int kNineNum = 9;
const int64_t kOneNumLong = 1;
const float weight_for_mul = 0.5;
enum OpMergeMode {
  OP_MERGE_UNDEFINED = 0,            // undefined behavior
  OP_MERGE_IGNORE = 1,               // indicate an input op merged into other op in compute node list
  OP_MERGE_CONV = 2,                 // indicate `MindSpore Conv + BiasAdd` --> `ONNX Conv`
  OP_MERGE_GEMM = 3,                 // indicate `MindSpore MatMul + BiasAdd` --> `ONNX Gemm`
  OP_MERGE_BATCH_NORM = 4,           // indicate `MindSpore BatchNorm(x)[0]` --> `ONNX Batch Normalization`
  OP_MERGE_MAXPOOL_WITH_ARGMAX = 5,  // indicate `MindSpore MaxPoolWithArgmax(x)[0]` --> `ONNX MaxPool`
  OP_MERGE_LAYER_NORM = 6,           // indicate `MindSpore LayerNorm(x)[0]` --> `ONNX MeanVarianceNormalization`
  OP_MERGE_CONV2D_TRANSPOSE = 7,     // indicate `MindSpore ConvTranspose + BiasAdd` --> `ONNX ConvTranspose`
  OP_MERGE_DYNAMIC_GRU_V2 = 8,       // indicate `MindSpore DynamicGRUV2(...)[0]` --> `ONNX GRU`
};

struct OpMergedInfo {
  OpMergeMode mode = OP_MERGE_UNDEFINED;
  int referred_count = 0;
};

using GenAttrFuncType =
  std::function<void(ValuePtr, onnx::AttributeProto_AttributeType, onnx::AttributeProto *, const PrimitivePtr &)>;

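// Depend and Load act as identities for export purposes and are looked through
// (their first data input, input(1), is used instead).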
bool IsIgnoredIdentityNode(const AnfNodePtr &node) {
  return IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad);
}

/*
 If true, the node should not be referenced by anything and should not be contributing to any
 ref counts itself
 */
bool IsZeroRefcountNode(const AnfNodePtr &node) { return HasAbstractMonad(node) || IsIgnoredIdentityNode(node); }

// Ideally this should be applied to every node->input() call, not only inside GetNodeInputName
static AnfNodePtr GetRealInput(const AnfNodePtr &origin_input) {
  AnfNodePtr input = origin_input;
  while (IsIgnoredIdentityNode(input)) {
    input = input->cast<CNodePtr>()->inputs().at(1);
  }
  return input;
}

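// Exports a scalar MindSpore value of type T as an ONNX attribute. For the repeated
// attribute types (INTS/FLOATS), the same scalar is replicated rep_cnt times.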
template <typename T, size_t rep_cnt = 0>
void SetAttrValueToProto(const ValuePtr &value, onnx::AttributeProto_AttributeType attr_type,
                         onnx::AttributeProto *const attr_proto, const PrimitivePtr &) {
  auto casted_value = dyn_cast<T>(value);
  if (casted_value == nullptr) {
    MS_LOG(EXCEPTION) << "Cast value " << value->ToString() << " to type T failed.";
  }
  auto attr_value = casted_value->value();
  switch (attr_type) {
    case onnx::AttributeProto_AttributeType_INT:
      attr_proto->set_i(static_cast<::google::protobuf::int64>(attr_value));
      break;
    case onnx::AttributeProto_AttributeType_FLOAT:
      attr_proto->set_f(static_cast<float>(attr_value));
      break;
    case onnx::AttributeProto_AttributeType_INTS:
      for (size_t i = 0; i < rep_cnt; ++i) {
        attr_proto->add_ints(static_cast<::google::protobuf::int64>(attr_value));
      }
      break;
    case onnx::AttributeProto_AttributeType_FLOATS:
      for (size_t i = 0; i < rep_cnt; ++i) {
        attr_proto->add_floats(static_cast<float>(attr_value));
      }
      break;
    default:
      MS_LOG(EXCEPTION) << "Failed to convert attribute, unexpected ONNX type " << attr_type;
  }
  attr_proto->set_type(attr_type);
}

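// Exports a MindSpore ValueTuple as an ONNX attribute; beg_idx skips leading tuple
// elements (e.g. SetAttrTupleValueToProto<2> drops the first two entries of a 4-element
// strides/dilations tuple, keeping only the spatial values ONNX expects).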
template <size_t beg_idx = 0>
void SetAttrTupleValueToProto(const ValuePtr &value, onnx::AttributeProto_AttributeType attr_type,
                              onnx::AttributeProto *const attr_proto, const PrimitivePtr &) {
  auto tuple_ptr = dyn_cast<ValueTuple>(value);
  if (tuple_ptr == nullptr) {
    MS_LOG(EXCEPTION) << "Cast value from type " << value->type_name() << " to ValueTuple failed.";
  }
  switch (attr_type) {
    case onnx::AttributeProto_AttributeType_INTS:
      for (size_t i = beg_idx; i < tuple_ptr->size(); ++i) {
        attr_proto->add_ints(GetValue<int64_t>((*tuple_ptr)[i]));
      }
      break;
    case onnx::AttributeProto_AttributeType_INT:
      attr_proto->set_i(GetValue<int64_t>((*tuple_ptr)[beg_idx]));
      break;
    case onnx::AttributeProto_AttributeType_FLOATS:
      for (size_t i = beg_idx; i < tuple_ptr->size(); ++i) {
        attr_proto->add_floats(GetValue<float>((*tuple_ptr)[i]));
      }
      break;
    default:
      MS_LOG(EXCEPTION) << "Failed to convert attribute, unexpected ONNX type " << attr_type;
  }
  attr_proto->set_type(attr_type);
}

void SetPoolingPadMode(const ValuePtr &value, onnx::AttributeProto_AttributeType,
                       onnx::AttributeProto *const attr_proto, const PrimitivePtr &) {
  attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
  int64_t attr_value;
  CheckAndConvertUtils::GetPadModEnumValue(value, &attr_value, true);
  if (attr_value == PadMode::VALID) {
    attr_proto->set_s("VALID");
  } else {
    attr_proto->set_s("SAME_UPPER");
  }
}

void SetConvPadding(const ValuePtr &value, onnx::AttributeProto_AttributeType, onnx::AttributeProto *const attr_proto,
                    const PrimitivePtr &prim) {
  attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
  int64_t attr_value;
  CheckAndConvertUtils::GetPadModEnumValue(value, &attr_value);
  if (attr_value == PadMode::VALID) {
    attr_proto->set_s("VALID");
  } else if (attr_value == PadMode::SAME) {
    attr_proto->set_s("SAME_UPPER");
  } else {  // pad_mode is 'pad', use attribute 'pad_list' to fill ONNX attribute 'pads'
    attr_proto->set_name("pads");
    SetAttrTupleValueToProto(prim->GetAttr("pad_list"), onnx::AttributeProto_AttributeType_INTS, attr_proto, prim);
  }
}

void SetConvTransposePadding(const ValuePtr &value, onnx::AttributeProto_AttributeType,
                             onnx::AttributeProto *const attr_proto, const PrimitivePtr &prim) {
  attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
  int64_t attr_value;
  CheckAndConvertUtils::GetPadModEnumValue(value, &attr_value);
  if (attr_value == PadMode::VALID) {
    attr_proto->set_s("VALID");
  } else if (attr_value == PadMode::SAME) {
    attr_proto->set_s("SAME_LOWER");
  } else {  // pad_mode is 'pad', use attribute 'pad_list' to fill ONNX attribute 'pads'
    attr_proto->set_name("pads");
    SetAttrTupleValueToProto(prim->GetAttr("pad_list"), onnx::AttributeProto_AttributeType_INTS, attr_proto, prim);
  }
}

PrimitivePtr GetPrimitive(const CNodePtr &node) {
  AnfNodePtr op = node->input(kZeroNum);
  auto op_value = dyn_cast<ValueNode>(op);
  MS_EXCEPTION_IF_NULL(op_value);
  auto prim = dyn_cast<Primitive>(op_value->value());
  MS_EXCEPTION_IF_NULL(prim);
  return prim;
}

template <typename T>
T GetOpAttribute(const CNodePtr &node, const std::string &name) {
  ValuePtr attr = GetPrimitive(node)->GetAttr(name);
  return GetValue<T>(attr);
}

template <typename T>
std::shared_ptr<T> GetOpAttributePtr(const CNodePtr &node, const std::string &name) {
  ValuePtr attr = GetPrimitive(node)->GetAttr(name);
  auto result = dyn_cast<T>(attr);
  MS_EXCEPTION_IF_NULL(result);
  return result;
}

std::string MakeOutputName(const std::string &node_name, int output_index) {
  return node_name + "_" + std::to_string(output_index);
}

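// Converts a multi-dimensional index into a flat offset for a row-major array of the given
// shape. The index may have fewer dimensions than the shape; it then addresses the leading
// dimensions, and the missing trailing ones are treated as zero.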
int64_t RavelIndex(const std::vector<int64_t> &index, const std::vector<int64_t> &shape) {
  MS_EXCEPTION_IF_CHECK_FAIL(index.size() <= shape.size(), "Index ndims must be <= shape ndims");
  int64_t result = 0;
  int64_t stride = 1;
  for (size_t i = 0; i < shape.size() - index.size(); ++i) {
    stride *= shape[shape.size() - 1 - i];
  }
  for (size_t i = 0; i < index.size(); ++i) {
    size_t rev_i = index.size() - 1 - i;
    result += index[rev_i] * stride;
    stride *= shape[rev_i];
  }
  return result;
}

namespace fp16 {
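// Returns a mask with the field_size lowest bits set.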
uint32_t FieldMask(unsigned int field_size) {
  const unsigned int BYTE_SIZE = 8;
  uint32_t mask = std::numeric_limits<uint32_t>::max();
  return mask >> (BYTE_SIZE * sizeof(mask) - field_size);
}

uint32_t ExponentBias(unsigned int exponent_size) { return (1U << (exponent_size - 1U)) - 1U; }

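// Bit-level fp32 -> fp16 conversion: the sign bit is copied, the exponent is re-biased from
// 8 to 5 bits, and the mantissa is truncated from 23 to 10 bits. Inf/NaN inputs and values
// whose re-biased exponent does not fit into fp16 raise an exception.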
uint32_t Fp32ToFp16(float value) {
  const unsigned int FP32_M = 23;
  const unsigned int FP32_E = 32 - 1 - FP32_M;
  const unsigned int FP16_M = 10;
  const unsigned int FP16_E = 16 - 1 - FP16_M;

  uint32_t fp32_bits;
  auto ret = memcpy_s(reinterpret_cast<std::byte *>(&fp32_bits), sizeof(fp32_bits),
                      reinterpret_cast<std::byte *>(&value), sizeof(value));
  if (ret != EOK) {
    MS_LOG(ERROR) << "Set data memcpy_s failed, ret = " << ret;
  }

  uint32_t mantissa = fp32_bits & FieldMask(FP32_M);
  uint32_t fp32_exp_mask = FieldMask(FP32_E);
  uint32_t fp32_exponent = (fp32_bits >> FP32_M) & fp32_exp_mask;
  if (fp32_exponent == fp32_exp_mask) {
    MS_LOG(EXCEPTION) << "Tried to convert inf or nan to float16: " << value;
  }
  uint32_t sign = fp32_bits >> (FP32_E + FP32_M);

  uint32_t fp16_bits = 0;
  fp16_bits |= sign << (FP16_E + FP16_M);
  uint32_t fp16_exponent = 0;
  if (fp32_exponent != 0) {
    fp16_exponent = fp32_exponent - ExponentBias(FP32_E) + ExponentBias(FP16_E);
  }
  if (fp16_exponent >= FieldMask(FP16_E)) {  // inf, nan (==), underflow, or overflow (>)
    MS_LOG(EXCEPTION) << "Conversion of " << value << " to float16 resulted in exponent overflow or underflow";
  }
  fp16_bits |= fp16_exponent << FP16_M;
  fp16_bits |= mantissa >> (FP32_M - FP16_M);

  return fp16_bits;
}
}  // namespace fp16

void AddFloatScalarInitializer(const std::string &name, float value, onnx::TensorProto_DataType type,
                               onnx::GraphProto *graph_proto) {
  onnx::TensorProto *initializer = graph_proto->add_initializer();
  initializer->set_name(name);
  if (type == onnx::TensorProto_DataType_FLOAT16) {
    uint32_t fp16 = fp16::Fp32ToFp16(value);
    initializer->add_int32_data(static_cast<int32_t>(fp16));
  } else if (type == onnx::TensorProto_DataType_FLOAT) {
    initializer->add_float_data(value);
  } else {
    MS_LOG(EXCEPTION) << "Unsupported type: " << type;
  }
  initializer->set_data_type(type);
}

void AddInt64Tensor1DInitializer(const std::string &name, const std::vector<int64_t> &values,
                                 onnx::GraphProto *graph_proto) {
  onnx::TensorProto *initializer = graph_proto->add_initializer();
  initializer->set_name(name);
  initializer->set_data_type(onnx::TensorProto_DataType_INT64);
  initializer->add_dims(static_cast<int64_t>(values.size()));
  for (auto value : values) {
    initializer->add_int64_data(value);
  }
}

void AddFloatTensor1DInitializer(const std::string &name, const std::vector<float> &values,
                                 onnx::TensorProto_DataType type, onnx::GraphProto *graph_proto) {
  onnx::TensorProto *initializer = graph_proto->add_initializer();
  initializer->set_name(name);
  initializer->add_dims(static_cast<int64_t>(values.size()));
  if (type == onnx::TensorProto_DataType_FLOAT16) {
    for (auto value : values) {
      uint32_t fp16 = fp16::Fp32ToFp16(value);
      initializer->add_int32_data(static_cast<int32_t>(fp16));
    }
  } else if (type == onnx::TensorProto_DataType_FLOAT) {
    for (auto value : values) {
      initializer->add_float_data(value);
    }
  } else {
    MS_LOG(EXCEPTION) << "Unsupported type: " << type;
  }
  initializer->set_data_type(type);
}

void AddOp(const std::string &type, const std::vector<std::string> &inputs, const std::vector<std::string> &outputs,
           onnx::GraphProto *graph_proto) {
  onnx::NodeProto *op = graph_proto->add_node();
  op->set_op_type(type);
  op->set_name(outputs.at(0) + type);
  for (const auto &input : inputs) {
    op->add_input(input);
  }
  for (const auto &output : outputs) {
    op->add_output(output);
  }
}

void AddClipOp(const std::string &input, const std::string &output, float min, float max,
               onnx::TensorProto_DataType type, onnx::GraphProto *graph_proto) {
  auto min_input_name = output + "__min_initializer";
  AddFloatScalarInitializer(min_input_name, min, type, graph_proto);

  auto max_input_name = output + "__max_initializer";
  AddFloatScalarInitializer(max_input_name, max, type, graph_proto);

  AddOp("Clip", {input, min_input_name, max_input_name}, {output}, graph_proto);
}

void AddSliceOp(const std::string &input, const std::string &output, const std::vector<int64_t> &start,
                const std::vector<int64_t> &end, const std::vector<int64_t> &axis, const std::vector<int64_t> &step,
                onnx::GraphProto *graph_proto) {
  auto starts_name = output + "__starts_initializer";
  AddInt64Tensor1DInitializer(starts_name, start, graph_proto);

  auto ends_name = output + "__ends_initializer";
  AddInt64Tensor1DInitializer(ends_name, end, graph_proto);

  auto axes_name = output + "__axes_initializer";
  AddInt64Tensor1DInitializer(axes_name, axis, graph_proto);

  auto steps_name = output + "__steps_initializer";
  AddInt64Tensor1DInitializer(steps_name, step, graph_proto);

  AddOp("Slice", {input, starts_name, ends_name, axes_name, steps_name}, {output}, graph_proto);
}

void AddSplitOp(const std::string &input, const std::vector<std::string> &outputs, const std::vector<int64_t> &split,
                int64_t axis, onnx::GraphProto *graph_proto) {
  if (outputs.size() != split.size()) {
    MS_LOG(EXCEPTION) << "Number of splits and number of outputs do not match";
  }

  onnx::NodeProto *split_proto = graph_proto->add_node();
  std::string op_type = "Split";
  split_proto->set_op_type(op_type);
  split_proto->set_name(outputs.at(0) + op_type);
  split_proto->add_input(input);
  for (const auto &output : outputs) {
    split_proto->add_output(output);
  }
  onnx::AttributeProto *axis_attr_proto = split_proto->add_attribute();
  axis_attr_proto->set_name("axis");
  axis_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  axis_attr_proto->set_i(axis);
  onnx::AttributeProto *split_attr_proto = split_proto->add_attribute();
  split_attr_proto->set_name("split");
  split_attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
  for (int64_t n : split) {
    split_attr_proto->add_ints(n);
  }
}

void AddExpandOp(const std::string &input, const std::string &output, const std::vector<int64_t> &shape,
                 onnx::GraphProto *graph_proto) {
  onnx::NodeProto *expand_node_proto = graph_proto->add_node();
  expand_node_proto->set_op_type("Expand");
  expand_node_proto->set_name(output + "_Expand");
  expand_node_proto->add_input(input);
  auto shape_name = output + "_expand_shape_initializer";
  AddInt64Tensor1DInitializer(shape_name, shape, graph_proto);
  expand_node_proto->add_input(shape_name);
  expand_node_proto->add_output(output);
}

void AddReshapeOp(const std::string &input, const std::string &output, const std::vector<int64_t> &shape,
                  onnx::GraphProto *graph_proto) {
  auto shape_name = output + "__shape_initializer";
  AddInt64Tensor1DInitializer(shape_name, shape, graph_proto);
  AddOp("Reshape", {input, shape_name}, {output}, graph_proto);
}

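// Emits a ConstantOfShape node and returns its "value" tensor so the caller can set the
// element type and the fill value.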
onnx::TensorProto *AddConstantOfShapeOp(const std::string &shape, const std::string &output,
                                        onnx::GraphProto *graph_proto) {
  onnx::NodeProto *op = graph_proto->add_node();
  std::string op_type = "ConstantOfShape";
  op->set_op_type(op_type);
  op->set_name(output + op_type);
  op->add_input(shape);
  op->add_output(output);
  onnx::AttributeProto *value_attr = op->add_attribute();
  value_attr->set_name("value");
  value_attr->set_type(onnx::AttributeProto_AttributeType_TENSOR);
  onnx::TensorProto *value_proto = value_attr->mutable_t();
  value_proto->add_dims(1);
  return value_proto;
}

void AddCastOp(const std::string &input, const std::string &output, onnx::TensorProto_DataType target_type,
               onnx::GraphProto *graph_proto) {
  onnx::NodeProto *node_proto = graph_proto->add_node();
  std::string op_type = "Cast";
  node_proto->set_op_type(op_type);
  node_proto->set_name(output + op_type);
  node_proto->add_input(input);
  node_proto->add_output(output);

  onnx::AttributeProto *target_type_attr = node_proto->add_attribute();
  target_type_attr->set_name("to");
  target_type_attr->set_type(onnx::AttributeProto_AttributeType_INT);
  target_type_attr->set_i(target_type);
}

void AddReduceOp(const std::string &op_type, const std::string &input, const std::string &output,
                 const std::vector<int64_t> &axes, bool keepdims, onnx::GraphProto *graph_proto) {
  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_name(output + op_type);
  node_proto->set_op_type(op_type);
  node_proto->add_input(input);
  node_proto->add_output(output);

  onnx::AttributeProto *keep_dims_proto = node_proto->add_attribute();
  keep_dims_proto->set_name("keepdims");
  keep_dims_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  keep_dims_proto->set_i(static_cast<int64_t>(keepdims));

  onnx::AttributeProto *axes_proto = node_proto->add_attribute();
  axes_proto->set_name("axes");
  axes_proto->set_type(onnx::AttributeProto_AttributeType_INTS);

  for (auto axis : axes) {
    axes_proto->add_ints(axis);
  }
}

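// Expands normalization into primitive ONNX ops:
//   output = (input - mean) / sqrt(variance + epsilon) * gamma + beta
// computed in float32 (with a cast back to float16 when the input is float16).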
void AddMeanVarianceNormalizationOp(const std::string &input, const std::string &gamma, const std::string &beta,
                                    const std::string &output, const std::vector<int64_t> &axes, float epsilon,
                                    const std::vector<int64_t> &input_shape, onnx::TensorProto_DataType input_type,
                                    onnx::GraphProto *graph_proto) {
  auto input_name = output + "_input";
  AddCastOp(input, input_name, onnx::TensorProto_DataType_FLOAT, graph_proto);
  auto gamma_name = output + "_gamma";
  AddCastOp(gamma, gamma_name, onnx::TensorProto_DataType_FLOAT, graph_proto);
  auto beta_name = output + "_beta";
  AddCastOp(beta, beta_name, onnx::TensorProto_DataType_FLOAT, graph_proto);

  // MeanVarianceNormalization is replaced with equivalent ops because it is not supported by CUDAExecutionProvider
  auto meanvariancenormal_node_name = output + "_normalized";

  auto mean_name = output + "_mean";
  AddReduceOp("ReduceMean", input_name, mean_name, axes, true, graph_proto);
  auto centered_name = output + "_centered";
  AddOp("Sub", {input_name, mean_name}, {centered_name}, graph_proto);

  auto sqsum_name = output + "_sqsum";
  AddReduceOp("ReduceSumSquare", centered_name, sqsum_name, axes, true, graph_proto);
  float reduce_size = std::accumulate(axes.begin(), axes.end(), 1.0f,
                                      [&input_shape](auto acc, auto axis) { return acc * input_shape[axis]; });
  auto reduce_size_name = output + "_reduce_size";
  AddFloatScalarInitializer(reduce_size_name, reduce_size, onnx::TensorProto_DataType_FLOAT, graph_proto);
  auto variance_name = output + "_variance";
  AddOp("Div", {sqsum_name, reduce_size_name}, {variance_name}, graph_proto);

  auto epsilon_name = output + "_epsilon";
  AddFloatScalarInitializer(epsilon_name, epsilon, onnx::TensorProto_DataType_FLOAT, graph_proto);
  auto variance_with_epsilon_name = output + "_variance_with_epsilon";
  AddOp("Add", {variance_name, epsilon_name}, {variance_with_epsilon_name}, graph_proto);
  auto std_name = output + "_std";
  AddOp("Sqrt", {variance_with_epsilon_name}, {std_name}, graph_proto);

  AddOp("Div", {centered_name, std_name}, {meanvariancenormal_node_name}, graph_proto);

  // Add mul and add node
  auto mul_node_name = output + "_rescaled";
  AddOp("Mul", {meanvariancenormal_node_name, gamma_name}, {mul_node_name}, graph_proto);

  // add beta
  auto add_node_name = output;
  if (input_type == onnx::TensorProto_DataType_FLOAT16) {
    add_node_name += "_shifted";
  }
  AddOp("Add", {mul_node_name, beta_name}, {add_node_name}, graph_proto);

  if (input_type == onnx::TensorProto_DataType_FLOAT16) {
    AddCastOp(add_node_name, output, onnx::TensorProto_DataType_FLOAT16, graph_proto);
  }
}

void AddConcatOp(const std::vector<std::string> &inputs, const std::string &output, int axis,
                 onnx::GraphProto *graph_proto) {
  onnx::NodeProto *concat_proto = graph_proto->add_node();
  auto op_type = "Concat";
  concat_proto->set_op_type(op_type);
  concat_proto->set_name(output + op_type);
  for (const auto &input : inputs) {
    concat_proto->add_input(input);
  }
  concat_proto->add_output(output);
  onnx::AttributeProto *axis_proto = concat_proto->add_attribute();
  axis_proto->set_name("axis");
  axis_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  axis_proto->set_i(axis);
}

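// Converts corner boxes (start, end) to (center, dimensions); coordinates are treated as
// inclusive pixel indices, hence the +1 when computing dimensions.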
void ConvertBoxesToXywh(const std::string &startpoints, const std::string &endpoints, const std::string &centerpoints,
                        const std::string &dimensions, onnx::TensorProto_DataType type, onnx::GraphProto *graph_proto) {
  auto coord_sums_name = centerpoints + "__to_div";
  AddOp("Add", {startpoints, endpoints}, {coord_sums_name}, graph_proto);
  auto two_name = centerpoints + "__two_initializer";
  AddFloatScalarInitializer(two_name, 2.0f, type, graph_proto);
  AddOp("Div", {coord_sums_name, two_name}, {centerpoints}, graph_proto);

  auto coord_diffs_name = dimensions + "__to_add";
  AddOp("Sub", {endpoints, startpoints}, {coord_diffs_name}, graph_proto);
  auto one_name = dimensions + "__one_initializer";
  AddFloatScalarInitializer(one_name, 1.0f, type, graph_proto);
  AddOp("Add", {coord_diffs_name, one_name}, {dimensions}, graph_proto);
}

void ConvertBoxesToXyxy(const std::string &centerpoints, const std::string &dimensions, const std::string &startpoints,
                        const std::string &endpoints, onnx::TensorProto_DataType type, onnx::GraphProto *graph_proto) {
  auto half_name = startpoints + "__half_initializer";
  AddFloatScalarInitializer(half_name, 0.5f, type, graph_proto);

  auto half_dim_name = startpoints + "__half_dim";
  auto half_dim_to_sub_name = startpoints + "__to_sub";
  AddOp("Mul", {dimensions, half_name}, {half_dim_to_sub_name}, graph_proto);
  AddOp("Sub", {half_dim_to_sub_name, half_name}, {half_dim_name}, graph_proto);

  AddOp("Sub", {centerpoints, half_dim_name}, {startpoints}, graph_proto);
  AddOp("Add", {centerpoints, half_dim_name}, {endpoints}, graph_proto);
}

void ClipPointsComponent(const std::string &points, const std::string &clipped, float max, int64_t component_idx,
                         onnx::TensorProto_DataType type, onnx::GraphProto *graph_proto) {
  auto res_to_clip_name = clipped + "__clip";
  AddSliceOp(points, res_to_clip_name, {component_idx}, {component_idx + 1}, {1}, {1}, graph_proto);
  AddClipOp(res_to_clip_name, clipped, 0.0f, max, type, graph_proto);
}

// Check whether the AnfNode's data type is a float type.
bool IsFloatDataType(const AnfNodePtr &node) {
  auto dtype = node->Type();
  auto elem_type = dyn_cast<TensorType>(dtype)->element()->type_id();
  switch (elem_type) {
    case (kNumberTypeFloat):
    case (kNumberTypeFloat16):
    case (kNumberTypeFloat32):
    case (kNumberTypeFloat64):
      return true;
    default:
      return false;
  }
}

namespace while_loop_export {
namespace {
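// Unicode markers embedded in the names of the subgraphs a while loop is compiled into:
// "\u21B5" marks the loop control graph, "\u21BB" the loop body, "\u2193" the after-loop part.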
const char CONTROL_PATTERN[] = "\u21B5";
const char LOOP_BODY_PATTERN[] = "\u21BB";
const char AFTER_LOOP_PATTERN[] = "\u2193";

const size_t LOOP_BODY_INPUT = 2;
const size_t AFTER_LOOP_INPUT = 3;

bool IsSubgraphNameCorrect(const FuncGraphPtr &func_graph, const std::string &part_pattern) {
  auto name = func_graph->ToString();
  return name.find("construct") != std::string::npos && name.find(part_pattern) != std::string::npos;
}

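// Typed CNode input accessors that look through Depend/Load wrappers via GetRealInput.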
template <typename T>
const std::shared_ptr<T> GetNodeInput(const CNodePtr &node, size_t i) {
  auto input = GetRealInput(node->input(i));
  auto result = dyn_cast<T>(input);
  if (result == nullptr) {
    MS_LOG(EXCEPTION) << "Failed to get input " << i << " of node " << node->DebugString();
  }
  return result;
}

template <typename T>
const std::shared_ptr<T> GetNodeInputValue(const CNodePtr &node, size_t i) {
  auto input = GetNodeInput<ValueNode>(node, i);
  auto result = dyn_cast<T>(input->value());
  if (result == nullptr) {
    MS_LOG(EXCEPTION) << "Failed to get a value from input " << i << " of node " << node->DebugString();
  }
  return result;
}

CNodePtr FindLoopSwitchNode(const FuncGraphPtr &control_subgraph) {
  if (!IsSubgraphNameCorrect(control_subgraph, CONTROL_PATTERN)) {
    MS_LOG(EXCEPTION) << "Expected a loop control structure";
  }
  auto lazy_call_node = GetNodeInput<CNode>(control_subgraph->get_return(), kOneNum);
  if (lazy_call_node->size() != kOneNum || !lazy_call_node->input(kZeroNum)->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "Expected a lazy call node";
  }
  auto switch_node = GetNodeInput<CNode>(lazy_call_node, kZeroNum);
  if (!switch_node->IsApply(prim::kPrimSwitch)) {
    MS_LOG(EXCEPTION) << "Expected a switch node";
  }
  return switch_node;
}

FuncGraphPtr GetSubgraph(const CNodePtr &switch_node, size_t input_index, const std::string &name_pattern) {
  auto input_node = GetNodeInput<CNode>(switch_node, input_index);
  if (!input_node->IsApply(prim::kPrimPartial)) {
    MS_LOG(EXCEPTION) << "Expected a partial node";
  }

  auto subgraph = GetNodeInputValue<FuncGraph>(input_node, kOneNum);
  if (!IsSubgraphNameCorrect(subgraph, name_pattern)) {
    MS_LOG(EXCEPTION) << "Expected a loop part: " << name_pattern;
  }

  return subgraph;
}

// The inputs of this node are the outputs of ONNX Loop
CNodePtr FindLoopRepeatNode(const FuncGraphPtr &loop_subgraph, const FuncGraphPtr &control_subgraph) {
  auto repeat_node = GetNodeInput<CNode>(loop_subgraph->return_node(), kOneNum);
  auto maybe_control_graph = GetNodeInputValue<FuncGraph>(repeat_node, kZeroNum);
  MS_EXCEPTION_IF_CHECK_FAIL(maybe_control_graph == control_subgraph, "Loop matching failed");
  return repeat_node;
}

struct LoopConditionInfo {
  int64_t begin;
  int64_t end;
  int64_t step;
};

/*
  NOTE: loop support is currently very limited, because proper condition export requires more graph surgery (copying
  condition expression before and inside Loop subgraph)
  The only while loop form supported currently is the one used in GNMT v2's Beam Search. Python example:
    i = begin
    while i < end:
        ...
        i += step
  To enable proper support for arbitrary while loop conditions, condition calculation should be duplicated inside the
  Loop subgraph. But exporting the same ops twice with different names is not currently supported.
 */
LoopConditionInfo TraceLoopConditionInfo(const CNodePtr &start_node, const CNodePtr &cond_node,
                                         const FuncGraphPtr &control_subgraph, const CNodePtr &loop_repeat_node) {
  MS_EXCEPTION_IF_CHECK_FAIL(cond_node->IsApply(prim::kPrimLess), "Expected Less node");

  auto counter = GetNodeInput<Parameter>(cond_node, kOneNum);
  auto end_tensor = GetNodeInputValue<tensor::Tensor>(cond_node, kTwoNum);
  MS_EXCEPTION_IF_CHECK_FAIL(end_tensor->shape_c().empty(), "Expected a scalar tensor");
  auto end = *reinterpret_cast<const int32_t *>(end_tensor->data_c());

  const auto &subgraph_args = control_subgraph->parameters();
  auto counter_input_pos = std::find(subgraph_args.begin(), subgraph_args.end(), counter) - subgraph_args.begin();

  auto begin_tensor = GetNodeInputValue<tensor::Tensor>(start_node, 1UL + static_cast<size_t>(counter_input_pos));
  MS_EXCEPTION_IF_CHECK_FAIL(begin_tensor->shape_c().empty(), "Expected a scalar tensor");
  auto begin = *reinterpret_cast<const int32_t *>(begin_tensor->data_c());

  auto increment_node = GetNodeInput<CNode>(loop_repeat_node, 1UL + static_cast<size_t>(counter_input_pos));
  MS_EXCEPTION_IF_CHECK_FAIL(increment_node->IsApply(prim::kPrimAdd), "Expected Add node");
  auto step_tensor = GetNodeInputValue<tensor::Tensor>(increment_node, kTwoNum);
  MS_EXCEPTION_IF_CHECK_FAIL(step_tensor->shape_c().empty(), "Expected a scalar tensor");
  auto step = *reinterpret_cast<const int32_t *>(step_tensor->data_c());

  return LoopConditionInfo{begin, end, step};
}

// result[i] is which control subgraph input should be taken for pos i to match the order of loop subgraph inputs
std::vector<size_t> TraceLoopToControlMap(const FuncGraphPtr &control_subgraph) {
  std::vector<size_t> result;

  auto switch_node = FindLoopSwitchNode(control_subgraph);
  auto loop_partial_node = GetNodeInput<CNode>(switch_node, kTwoNum);
  const auto &control_params = control_subgraph->parameters();
  int64_t auxiliary_inputs_num = 2;
  for (size_t i = static_cast<size_t>(auxiliary_inputs_num); i < loop_partial_node->size(); ++i) {
    auto loop_param = GetNodeInput<Parameter>(loop_partial_node, i);
    auto control_param_pos =
      std::find(control_params.begin(), control_params.end(), loop_param) - control_params.begin();
    result.push_back(control_param_pos);
  }

  return result;
}

std::vector<size_t> TraceAfterToLoopMap(const FuncGraphPtr &control_subgraph) {
  std::vector<size_t> result;

  auto switch_node = FindLoopSwitchNode(control_subgraph);
  auto loop_partial_node = GetNodeInput<CNode>(switch_node, kTwoNum);
  auto after_partial_node = GetNodeInput<CNode>(switch_node, kThreeNum);
  const auto &loop_params = loop_partial_node->inputs();
  int64_t auxiliary_inputs_num = 2;
  for (size_t i = static_cast<size_t>(auxiliary_inputs_num); i < after_partial_node->size(); ++i) {
    auto after_param = GetNodeInput<Parameter>(after_partial_node, i);
    auto after_param_pos = std::find(loop_params.begin(), loop_params.end(), after_param) - loop_params.begin();
    result.push_back(after_param_pos - auxiliary_inputs_num);
  }

  return result;
}

std::vector<bool> TraceIgnoredLoopParams(const CNodePtr &start_node, const std::vector<size_t> &loop_to_control_map) {
  auto inputs_num = start_node->size() - 1;
  std::vector<bool> result(inputs_num);
  for (size_t loop_i = 0; loop_i < inputs_num; ++loop_i) {
    auto control_i = loop_to_control_map.at(loop_i);
    const auto &input = start_node->input(control_i + 1);
    if ((input->isa<Parameter>() && input->cast<ParameterPtr>()->has_default()) || HasAbstractMonad(input)) {
      result.at(loop_i) = true;
    }
  }
  return result;
}
}  // namespace

bool IsControlSubgraph(const ValuePtr &func_graph_node) {
  auto func_graph = dyn_cast<FuncGraph>(func_graph_node);
  return func_graph != nullptr && IsSubgraphNameCorrect(func_graph, CONTROL_PATTERN);
}

bool IsLoopBodyReturnNode(const CNodePtr &node, const FuncGraphPtr &func_graph) {
  return IsSubgraphNameCorrect(func_graph, LOOP_BODY_PATTERN) && node == func_graph->get_return();
}

bool IsAfterLoopReturnNode(const CNodePtr &node, const FuncGraphPtr &func_graph) {
  return IsSubgraphNameCorrect(func_graph, AFTER_LOOP_PATTERN) && node == func_graph->get_return();
}

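// Everything needed to export a matched while loop as an ONNX Loop construct.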
struct LoopParts {
  LoopConditionInfo loop_condition_info;
  std::vector<std::pair<size_t, size_t>> after_param_to_output_indices;
  std::vector<size_t> ignored_loop_param_indices;
  std::vector<std::pair<size_t, size_t>> used_loop_to_control_param_indices;
  CNodePtr repeat_node;
  FuncGraphPtr loop_subgraph;
  FuncGraphPtr after_loop_subgraph;
};

LoopParts MatchGraph(const CNodePtr &start_node) {
  LoopParts result;

  auto control_subgraph_value = dyn_cast<ValueNode>(start_node->input(0));
  MS_EXCEPTION_IF_NULL(control_subgraph_value);
  auto control_subgraph = dyn_cast<FuncGraph>(control_subgraph_value->value());
  MS_EXCEPTION_IF_NULL(control_subgraph);

  auto switch_node = FindLoopSwitchNode(control_subgraph);
  auto cond_node = GetNodeInput<CNode>(switch_node, kOneNum);

  result.loop_subgraph = GetSubgraph(switch_node, LOOP_BODY_INPUT, LOOP_BODY_PATTERN);

  result.repeat_node = FindLoopRepeatNode(result.loop_subgraph, control_subgraph);
  result.loop_condition_info = TraceLoopConditionInfo(start_node, cond_node, control_subgraph, result.repeat_node);

  result.after_loop_subgraph = GetSubgraph(switch_node, AFTER_LOOP_INPUT, AFTER_LOOP_PATTERN);

  auto loop_to_control_order_map = TraceLoopToControlMap(control_subgraph);
  auto ignored_loop_params_mask = TraceIgnoredLoopParams(start_node, loop_to_control_order_map);
  auto loop_inputs_num = start_node->size() - 1;
  for (size_t i = 0; i < loop_inputs_num; ++i) {
    if (ignored_loop_params_mask.at(i)) {
      result.ignored_loop_param_indices.push_back(i);
    } else {
      result.used_loop_to_control_param_indices.push_back(std::make_pair(i, loop_to_control_order_map.at(i)));
    }
  }

  auto after_to_loop_order_map = TraceAfterToLoopMap(control_subgraph);
  for (size_t after_i = 0; after_i < result.after_loop_subgraph->parameters().size(); ++after_i) {
    auto loop_i = after_to_loop_order_map.at(after_i);
    if (!ignored_loop_params_mask.at(loop_i)) {
      auto output_i = loop_i;
      for (size_t i = 0; i < loop_i; ++i) {
        output_i -= static_cast<size_t>(ignored_loop_params_mask.at(i));
      }
      result.after_param_to_output_indices.push_back(std::make_pair(after_i, output_i));
    }
  }

  return result;
}
}  // namespace while_loop_export

class OpAttrInfo {
 public:
  OpAttrInfo(const std::string &attr_name, const string &onnx_attr_name,
             onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType &fn_gen_attr)
      : attr_name_(attr_name),
        onnx_attr_name_(onnx_attr_name),
        onnx_attr_type_(onnx_attr_type),
        fn_gen_attr_(fn_gen_attr) {}
  ~OpAttrInfo() {}

  const std::string &attr_name() const { return attr_name_; }
  const std::string &onnx_attr_name() const { return onnx_attr_name_; }
  onnx::AttributeProto_AttributeType onnx_attr_type() const { return onnx_attr_type_; }
  GenAttrFuncType fn_gen_attr() const { return fn_gen_attr_; }

 private:
  std::string attr_name_;                              // attribute name of MindSpore
  std::string onnx_attr_name_;                         // corresponding attribute name of ONNX
  onnx::AttributeProto_AttributeType onnx_attr_type_;  // corresponding attribute type of ONNX
  GenAttrFuncType fn_gen_attr_;                        // function used to convert the attribute value
};

struct InputConversion {
  int input_index;
  onnx::TensorProto_DataType input_type;
  onnx::TensorProto_DataType target_type;
};

struct OutputConversion {
  int output_index;
  enum class Mode { FIXED, INPUT } mode;
  union {
    onnx::TensorProto_DataType target_type;
    int input_with_matching_type;
  };
};

class OpNameInfo {
 public:
  OpNameInfo &set_op_type(const std::string &op_type) {
    op_type_ = op_type;
    return *this;
  }

  const std::string &op_type() const { return op_type_; }

  OpNameInfo &set_onnx_type(const std::string &onnx_type) {
    onnx_type_ = onnx_type;
    return *this;
  }

  const std::string &onnx_type() const { return onnx_type_; }

  OpNameInfo &Attr(const std::string &attr_name, const std::string &onnx_attr_name,
                   onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType &fn_gen_attr) {
    (void)op_attrs_.emplace_back(OpAttrInfo(attr_name, onnx_attr_name, onnx_attr_type, fn_gen_attr));
    return *this;
  }

  const std::vector<OpAttrInfo> &op_attrs() const { return op_attrs_; }

  const std::vector<InputConversion> &input_casts() const { return input_casts_; }

  OpNameInfo &CastInput(int input_index, onnx::TensorProto_DataType input_type,
                        onnx::TensorProto_DataType target_type) {
    input_casts_.push_back({input_index, input_type, target_type});
    return *this;
  }

  const std::vector<OutputConversion> &output_casts() const { return output_casts_; }

  OpNameInfo &CastOutputToFixedType(onnx::TensorProto_DataType type, int output_index = 0) {
    output_casts_.push_back({output_index, OutputConversion::Mode::FIXED, {type}});
    return *this;
  }

  OpNameInfo &CastOutputToInputType(int input_index, int output_index = 0) {
    auto rule = OutputConversion{output_index, OutputConversion::Mode::INPUT};
    rule.input_with_matching_type = input_index;
    output_casts_.push_back(rule);
    return *this;
  }

  int num_outputs() const { return num_outputs_; }

  OpNameInfo &set_num_outputs(int n) {
    num_outputs_ = n;
    return *this;
  }

 private:
  std::string op_type_;                         // operator type of MindSpore
  std::string onnx_type_;                       // corresponding ONNX operator type
  std::vector<OpAttrInfo> op_attrs_;            // operator attributes map info
  std::vector<InputConversion> input_casts_;    // if input input_index has type input_type, cast it to target_type
  std::vector<OutputConversion> output_casts_;  // cast output output_index to fixed type or input type
  int num_outputs_ = 1;
};

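// Defines a function GetOpOnnxConvertInfo_<name> returning the conversion rule for one op:
// the target ONNX op type plus attribute and input/output cast settings.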
919 #define OPERATOR_ONNX_CONVERT_DEFINE(name, onnx_name, impl) \
920   OpNameInfo GetOpOnnxConvertInfo_##name() { return impl.set_op_type(#name).set_onnx_type(#onnx_name); }
921 
OPERATOR_ONNX_CONVERT_DEFINE(Mod,Mod,OpNameInfo ())922 OPERATOR_ONNX_CONVERT_DEFINE(Mod, Mod, OpNameInfo())
923 OPERATOR_ONNX_CONVERT_DEFINE(Add, Add, OpNameInfo())
924 OPERATOR_ONNX_CONVERT_DEFINE(Mul, Mul, OpNameInfo())
925 OPERATOR_ONNX_CONVERT_DEFINE(Pow, Pow, OpNameInfo())
926 
927 OPERATOR_ONNX_CONVERT_DEFINE(ReLU, Relu, OpNameInfo())
928 OPERATOR_ONNX_CONVERT_DEFINE(Sigmoid, Sigmoid, OpNameInfo())
929 OPERATOR_ONNX_CONVERT_DEFINE(Sin, Sin, OpNameInfo())
930 OPERATOR_ONNX_CONVERT_DEFINE(Round, Round, OpNameInfo())
931 OPERATOR_ONNX_CONVERT_DEFINE(Div, Div, OpNameInfo())
932 
933 OPERATOR_ONNX_CONVERT_DEFINE(Flatten, Flatten, OpNameInfo())
934 
935 OPERATOR_ONNX_CONVERT_DEFINE(
936   Conv2D, Conv,
937   OpNameInfo()
938     .Attr("dilation", "dilations", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)
939     .Attr("group", "group", onnx::AttributeProto_AttributeType_INT, SetAttrValueToProto<Int64Imm>)
940     .Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<0>)
941     .Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetConvPadding)
942     .Attr("stride", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>))
943 OPERATOR_ONNX_CONVERT_DEFINE(
944   Conv3D, Conv,
945   OpNameInfo()
946     .Attr("dilations", "dilations", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<kTwoNum>)
947     .Attr("group", "group", onnx::AttributeProto_AttributeType_INT, SetAttrValueToProto<Int64Imm>)
948     .Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<0>)
949     .Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetConvPadding)
950     .Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<kTwoNum>))
951 OPERATOR_ONNX_CONVERT_DEFINE(
952   Conv3DTranspose, ConvTranspose,
953   OpNameInfo()
954     .Attr("dilations", "dilations", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<kTwoNum>)
955     .Attr("group", "group", onnx::AttributeProto_AttributeType_INT, SetAttrValueToProto<Int64Imm>)
956     .Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<0>)
957     .Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetConvTransposePadding)
958     .Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<kTwoNum>)
959     .Attr("output_padding", "output_padding", onnx::AttributeProto_AttributeType_INTS,
960           SetAttrTupleValueToProto<kTwoNum>))
961 
962 OPERATOR_ONNX_CONVERT_DEFINE(DepthToSpace, DepthToSpace,
963                              OpNameInfo().Attr("block_size", "blocksize", onnx::AttributeProto_AttributeType_INT,
964                                                SetAttrValueToProto<Int64Imm>))
965 
966 OPERATOR_ONNX_CONVERT_DEFINE(BiasAdd, Add, OpNameInfo())
967 OPERATOR_ONNX_CONVERT_DEFINE(MatMul, Gemm,
968                              OpNameInfo()
969                                .Attr("transpose_a", "transA", onnx::AttributeProto_AttributeType_INT,
970                                      SetAttrValueToProto<BoolImm>)
971                                .Attr("transpose_b", "transB", onnx::AttributeProto_AttributeType_INT,
972                                      SetAttrValueToProto<BoolImm>))
973 
974 OPERATOR_ONNX_CONVERT_DEFINE(BatchNorm, BatchNormalization,
975                              OpNameInfo()
976                                .Attr("epsilon", "epsilon", onnx::AttributeProto_AttributeType_FLOAT,
977                                      SetAttrValueToProto<FP32Imm>)
978                                .CastInput(0, onnx::TensorProto_DataType_FLOAT16, onnx::TensorProto_DataType_FLOAT)
979                                .CastOutputToInputType(0))
980 
981 OPERATOR_ONNX_CONVERT_DEFINE(Reshape, Reshape, OpNameInfo())
982 OPERATOR_ONNX_CONVERT_DEFINE(Cast, Cast, OpNameInfo())
983 OPERATOR_ONNX_CONVERT_DEFINE(PReLU, PRelu, OpNameInfo())
984 OPERATOR_ONNX_CONVERT_DEFINE(Argmax, ArgMax,
985                              OpNameInfo()
986                                .Attr("axis", "axis", onnx::AttributeProto_AttributeType_INT,
987                                      SetAttrValueToProto<Int64Imm>)
988                                .Attr("", "keepdims", onnx::AttributeProto_AttributeType_INT,
989                                      [](ValuePtr, onnx::AttributeProto_AttributeType,
990                                         onnx::AttributeProto *const attr_proto, const PrimitivePtr &) {
991                                        attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
992                                        attr_proto->set_i(0);
993                                      })
994                                .CastOutputToFixedType(onnx::TensorProto_DataType_INT32))
995 
996 OPERATOR_ONNX_CONVERT_DEFINE(SimpleMean, AveragePool, OpNameInfo())
997 OPERATOR_ONNX_CONVERT_DEFINE(
998   MaxPool, MaxPool,
999   OpNameInfo()
1000     .Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)
1001     .Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode)
1002     .Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>))
1003 
1004 OPERATOR_ONNX_CONVERT_DEFINE(
1005   MaxPool3D, MaxPool,
1006   OpNameInfo()
    .Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)
    .Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode)
    .Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>))

OPERATOR_ONNX_CONVERT_DEFINE(
  MaxPoolWithArgmax, MaxPool,
  OpNameInfo()
    .Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)
    .Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode)
    .Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>))

OPERATOR_ONNX_CONVERT_DEFINE(
  AvgPool, AveragePool,
  OpNameInfo()
    .Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)
    .Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode)
    .Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>))

OPERATOR_ONNX_CONVERT_DEFINE(Gather, Gather, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(MakeTuple, SequenceConstruct, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(RealDiv, Div, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Sub, Sub, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Neg, Neg, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Maximum, Max,
                             OpNameInfo()
                               .CastInput(0, onnx::TensorProto_DataType_INT32, onnx::TensorProto_DataType_FLOAT)
                               .CastInput(1, onnx::TensorProto_DataType_INT32, onnx::TensorProto_DataType_FLOAT)
                               .CastOutputToInputType(0))
OPERATOR_ONNX_CONVERT_DEFINE(Minimum, Min,
                             OpNameInfo()
                               .CastInput(0, onnx::TensorProto_DataType_INT32, onnx::TensorProto_DataType_FLOAT)
                               .CastInput(1, onnx::TensorProto_DataType_INT32, onnx::TensorProto_DataType_FLOAT)
                               .CastOutputToInputType(0))
OPERATOR_ONNX_CONVERT_DEFINE(Transpose, Transpose, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Exp, Exp, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Softplus, Softplus, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Tanh, Tanh, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Abs, Abs, OpNameInfo())

// MindSpore Softmax axis(int, Tuple)
OPERATOR_ONNX_CONVERT_DEFINE(Softmax, Softmax,
                             OpNameInfo().Attr("axis", "axis", onnx::AttributeProto_AttributeType_INT,
                                               SetAttrTupleValueToProto<0>))

// MindSpore LogSoftmax axis(int)
OPERATOR_ONNX_CONVERT_DEFINE(LogSoftmax, LogSoftmax,
                             OpNameInfo().Attr("axis", "axis", onnx::AttributeProto_AttributeType_INT,
                                               SetAttrValueToProto<Int64Imm>))

OPERATOR_ONNX_CONVERT_DEFINE(Softsign, Softsign, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Sqrt, Sqrt, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Equal, Equal, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Floor, Floor, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(ACos, Acos, OpNameInfo())

OPERATOR_ONNX_CONVERT_DEFINE(GatherNd, GatherND,
                             OpNameInfo().CastInput(1, onnx::TensorProto_DataType_INT32,
                                                    onnx::TensorProto_DataType_INT64))
OPERATOR_ONNX_CONVERT_DEFINE(Select, Where, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Log, Log, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Greater, Greater, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(LogicalAnd, And, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(LogicalOr, Or, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(ReverseSequence, ReverseSequence,
                             OpNameInfo()
                               .Attr("seq_dim", "time_axis", onnx::AttributeProto_AttributeType_INT,
                                     SetAttrValueToProto<Int64Imm>)
                               .Attr("batch_dim", "batch_axis", onnx::AttributeProto_AttributeType_INT,
                                     SetAttrValueToProto<Int64Imm>)
                               .CastInput(1, onnx::TensorProto_DataType_INT32, onnx::TensorProto_DataType_INT64))
OPERATOR_ONNX_CONVERT_DEFINE(Less, Less, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(TensorScatterUpdate, ScatterND,
                             OpNameInfo().CastInput(1, onnx::TensorProto_DataType_INT32,
                                                    onnx::TensorProto_DataType_INT64))
OPERATOR_ONNX_CONVERT_DEFINE(Cos, Cos, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Atan2, Atan2, OpNameInfo())

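// For reference: given how these definitions are consumed below (through
// OP_CONVERT_FUNCTION_NAME), each OPERATOR_ONNX_CONVERT_DEFINE(ms_name,
// onnx_name, op_name_info) apparently generates a factory function along the
// lines of this sketch (the exact expansion is an assumption, not shown here):
//
//   OpNameInfo GetOpOnnxConvertInfo_Softmax() {
//     return OpNameInfo()  // records the "Softmax" -> "Softmax" mapping
//       .Attr("axis", "axis", onnx::AttributeProto_AttributeType_INT,
//             SetAttrTupleValueToProto<0>);
//   }
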
#define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name

void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) {
  fn(OP_CONVERT_FUNCTION_NAME(Mod)());
  fn(OP_CONVERT_FUNCTION_NAME(DepthToSpace)());
  fn(OP_CONVERT_FUNCTION_NAME(Add)());
  fn(OP_CONVERT_FUNCTION_NAME(Mul)());
  fn(OP_CONVERT_FUNCTION_NAME(Pow)());
  fn(OP_CONVERT_FUNCTION_NAME(ReLU)());
  fn(OP_CONVERT_FUNCTION_NAME(Sigmoid)());
  fn(OP_CONVERT_FUNCTION_NAME(Conv2D)());
  fn(OP_CONVERT_FUNCTION_NAME(Conv3D)());
  fn(OP_CONVERT_FUNCTION_NAME(Conv3DTranspose)());
  fn(OP_CONVERT_FUNCTION_NAME(Argmax)());
  fn(OP_CONVERT_FUNCTION_NAME(Flatten)());
  fn(OP_CONVERT_FUNCTION_NAME(MaxPool)());
  fn(OP_CONVERT_FUNCTION_NAME(MaxPool3D)());
  fn(OP_CONVERT_FUNCTION_NAME(MaxPoolWithArgmax)());
  fn(OP_CONVERT_FUNCTION_NAME(AvgPool)());

  fn(OP_CONVERT_FUNCTION_NAME(BatchNorm)());
  fn(OP_CONVERT_FUNCTION_NAME(MatMul)());
  fn(OP_CONVERT_FUNCTION_NAME(MakeTuple)());
  fn(OP_CONVERT_FUNCTION_NAME(RealDiv)());
  fn(OP_CONVERT_FUNCTION_NAME(BiasAdd)());
  fn(OP_CONVERT_FUNCTION_NAME(Sub)());
  fn(OP_CONVERT_FUNCTION_NAME(Neg)());
  fn(OP_CONVERT_FUNCTION_NAME(Maximum)());
  fn(OP_CONVERT_FUNCTION_NAME(Minimum)());
  fn(OP_CONVERT_FUNCTION_NAME(Exp)());

  fn(OP_CONVERT_FUNCTION_NAME(Softplus)());
  fn(OP_CONVERT_FUNCTION_NAME(Tanh)());
  fn(OP_CONVERT_FUNCTION_NAME(Softmax)());
  fn(OP_CONVERT_FUNCTION_NAME(LogSoftmax)());
  fn(OP_CONVERT_FUNCTION_NAME(Abs)());
  fn(OP_CONVERT_FUNCTION_NAME(Softsign)());
  fn(OP_CONVERT_FUNCTION_NAME(Sqrt)());
  fn(OP_CONVERT_FUNCTION_NAME(Equal)());
  fn(OP_CONVERT_FUNCTION_NAME(Floor)());
  fn(OP_CONVERT_FUNCTION_NAME(ACos)());

  fn(OP_CONVERT_FUNCTION_NAME(GatherNd)());
  fn(OP_CONVERT_FUNCTION_NAME(Select)());
  fn(OP_CONVERT_FUNCTION_NAME(Log)());
  fn(OP_CONVERT_FUNCTION_NAME(Less)());
  fn(OP_CONVERT_FUNCTION_NAME(Greater)());
  fn(OP_CONVERT_FUNCTION_NAME(LogicalAnd)());
  fn(OP_CONVERT_FUNCTION_NAME(LogicalOr)());
  fn(OP_CONVERT_FUNCTION_NAME(ReverseSequence)());
  fn(OP_CONVERT_FUNCTION_NAME(TensorScatterUpdate)());

  fn(OP_CONVERT_FUNCTION_NAME(Sin)());
  fn(OP_CONVERT_FUNCTION_NAME(Cos)());
  fn(OP_CONVERT_FUNCTION_NAME(Atan2)());
  fn(OP_CONVERT_FUNCTION_NAME(Round)());
  fn(OP_CONVERT_FUNCTION_NAME(Div)());
}

class OpConvertRegistry {
 public:
  ~OpConvertRegistry() { Clear(); }

  static void RegisterOneOpConverter(OpNameInfo &&op_info) { GetSingleton().op_map_[op_info.op_type()] = op_info; }

  static void RegisterAllOpConverters() { RegisterOpConverters(RegisterOneOpConverter); }

  static OpConvertRegistry &GetSingleton() {
    static OpConvertRegistry registry = OpConvertRegistry();
    return registry;
  }

  static const mindspore::HashMap<std::string, OpNameInfo> &GetOpConvertMap() { return GetSingleton().op_map_; }

  void Clear() noexcept { op_map_.clear(); }

 private:
  OpConvertRegistry() {}

  mindspore::HashMap<std::string, OpNameInfo> op_map_;
};

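// A minimal sketch of the lookup flow for the registry above; the real lookup
// presumably happens inside ExportPrimitive and related helpers, so the call
// site and the `prim` variable here are illustrative only:
//
//   OpConvertRegistry::RegisterAllOpConverters();
//   const auto &convert_map = OpConvertRegistry::GetOpConvertMap();
//   auto it = convert_map.find(prim->name());
//   if (it != convert_map.end()) {
//     // export the node using it->second (the registered OpNameInfo)
//   }
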
class OnnxExporter {
 public:
  OnnxExporter() {}
  ~OnnxExporter() {}

  std::string GetOnnxProtoString(const FuncGraphPtr &func_graph);

 private:
  void InitModelInfo();

  void ExportFuncGraph(const FuncGraphPtr &func_graph, std::map<AnfNodePtr, std::string> *node_map_ptr,
                       onnx::GraphProto *graph_proto, bool export_inputs = true);
  void ExportInputs(const FuncGraphPtr &func_graph, std::map<AnfNodePtr, std::string> *node_map_ptr,
                    onnx::GraphProto *graph_proto);

  std::string ExportPrimitive(const FuncGraphPtr &func_graph, std::map<AnfNodePtr, std::string> *node_map_ptr,
                              const PrimitivePtr &prim, const std::vector<AnfNodePtr> &inputs,
                              onnx::GraphProto *graph_proto);

  static onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id);
  static onnx::TensorProto_DataType GetOutputType(const AnfNodePtr &node, int64_t output_index = -1);
  void SetValueInfoType(const AnfNodePtr &node, onnx::ValueInfoProto *value_proto, int64_t output_index = -1) const;

  void MatchAndMark(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &nodes,
                    mindspore::HashMap<AnfNodePtr, OpMergedInfo> *op_merged_infos_ptr) const;
  void MatchAndMarkCNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                         mindspore::HashMap<AnfNodePtr, OpMergedInfo> *op_merged_infos_ptr) const;
  void IgnoreMakeTuple(const AnfNodePtr &node, mindspore::HashMap<AnfNodePtr, OpMergedInfo> *op_merged_infos_ptr) const;

  void ExportNodes(const FuncGraphPtr &func_graph, std::map<AnfNodePtr, std::string> *node_map_ptr,
                   onnx::GraphProto *graph_proto);

  void ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
                   std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportWhileLoop(const CNodePtr &start_node, std::map<AnfNodePtr, std::string> *node_map_ptr,
                       onnx::GraphProto *graph_proto);

  void ExportPrimReshape(const FuncGraphPtr &func_graph, const CNodePtr &node,
                         std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimReduce(const FuncGraphPtr &func_graph, const CNodePtr &node,
                        std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimReduceAnyOrAll(const FuncGraphPtr &func_graph, const CNodePtr &node,
                                std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimTranspose(const FuncGraphPtr &func_graph, const CNodePtr &node,
                           std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimStridedSlice(const FuncGraphPtr &func_graph, const CNodePtr &node,
                              std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  onnx::NodeProto *PrimResizeExportHelper(const FuncGraphPtr &, const CNodePtr &node,
                                          std::map<AnfNodePtr, std::string> *node_map_ptr,
                                          onnx::GraphProto *const graph_proto);
  void ExportPrimResizeNearestNeighbor(const FuncGraphPtr &func_graph, const CNodePtr &node,
                                       std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimResizeBilinear(const FuncGraphPtr &func_graph, const CNodePtr &node,
                                std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimExpandDims(const FuncGraphPtr &func_graph, const CNodePtr &node,
                            std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimGatherD(const FuncGraphPtr &func_graph, const CNodePtr &node,
                         std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimPad(const FuncGraphPtr &func_graph, const CNodePtr &node,
                     std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimBatchMatMul(const FuncGraphPtr &func_graph, const CNodePtr &node,
                             std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimBroadcastTo(const FuncGraphPtr &func_graph, const CNodePtr &node,
                             std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimAddN(const FuncGraphPtr &func_graph, const CNodePtr &node,
                      std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimGeLU(const FuncGraphPtr &func_graph, const CNodePtr &node,
                      std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimConcat(const FuncGraphPtr &func_graph, const CNodePtr &node,
                        std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimCast(const FuncGraphPtr &func_graph, const CNodePtr &node,
                      std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimPReLU(const FuncGraphPtr &func_graph, const CNodePtr &node,
                       std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimReLU6(const FuncGraphPtr &func_graph, const CNodePtr &node,
                       std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimDepthwiseConv2d(const FuncGraphPtr &func_graph, const CNodePtr &node,
                                 std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimTile(const FuncGraphPtr &func_graph, const CNodePtr &node,
                      std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimSquare(const FuncGraphPtr &func_graph, const CNodePtr &node,
                        std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimGatherV2(const FuncGraphPtr &func_graph, const CNodePtr &node,
                          std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimTupleGetItem(const FuncGraphPtr &func_graph, const CNodePtr &node,
                              std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimTopK(const FuncGraphPtr &func_graph, const CNodePtr &node,
                      std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimBoundingBoxDecode(const FuncGraphPtr &func_graph, const CNodePtr &node,
                                   std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimNMSWithMask(const FuncGraphPtr &func_graph, const CNodePtr &node,
                             std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimSplit(const FuncGraphPtr &func_graph, const CNodePtr &node,
                       std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimROIAlign(const FuncGraphPtr &func_graph, const CNodePtr &node,
                          std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimSlice(const FuncGraphPtr &func_graph, const CNodePtr &node,
                       std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimOnesLike(const FuncGraphPtr &func_graph, const CNodePtr &node,
                          std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimScatterNd(const FuncGraphPtr &func_graph, const CNodePtr &node,
                           std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimArgMaxWithValue(const FuncGraphPtr &func_graph, const CNodePtr &node,
                                 std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimArgMinWithValue(const FuncGraphPtr &func_graph, const CNodePtr &node,
                                 std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimOneHot(const FuncGraphPtr &func_graph, const CNodePtr &node,
                        std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void PrimConv2DTransposeExportHelper(const CNodePtr &conv_node, const CNodePtr &bias_add_node,
                                       std::map<AnfNodePtr, std::string> *node_map_ptr,
                                       onnx::GraphProto *const graph_proto);
  void ExportPrimConv2DTranspose(const FuncGraphPtr &func_graph, const CNodePtr &node,
                                 std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimGreaterEqual(const FuncGraphPtr &func_graph, const CNodePtr &node,
                              std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimLessEqual(const FuncGraphPtr &func_graph, const CNodePtr &node,
                           std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimNotEqual(const FuncGraphPtr &func_graph, const CNodePtr &node,
                          std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimDense(const FuncGraphPtr &func_graph, const CNodePtr &node,
                       std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimSqueeze(const FuncGraphPtr &func_graph, const CNodePtr &node,
                         std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimDynamicRNN(const FuncGraphPtr &func_graph, const CNodePtr &node,
                            std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *const graph_proto);
  void ExportPrimLSTM(const FuncGraphPtr &, const CNodePtr &node, std::map<AnfNodePtr, std::string> *node_map_ptr,
                      onnx::GraphProto *graph_proto);
  void ExportPrimReverseV2(const FuncGraphPtr &func_graph, const CNodePtr &node,
                           std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimTensorCopySlices(const FuncGraphPtr &, const CNodePtr &node,
                                  std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimStack(const FuncGraphPtr &, const CNodePtr &node, std::map<AnfNodePtr, std::string> *node_map_ptr,
                       onnx::GraphProto *graph_proto);
  void ExportPrimAtan2(const FuncGraphPtr &, const CNodePtr &node, std::map<AnfNodePtr, std::string> *node_map_ptr,
                       onnx::GraphProto *graph_proto);
  void ExportPrimFloorDiv(const FuncGraphPtr &, const CNodePtr &node, std::map<AnfNodePtr, std::string> *node_map_ptr,
                          onnx::GraphProto *graph_proto);
  void ExportPrimFloorMod(const FuncGraphPtr &, const CNodePtr &node, std::map<AnfNodePtr, std::string> *node_map_ptr,
                          onnx::GraphProto *graph_proto);
  void ExportPrimSort(const FuncGraphPtr &, const CNodePtr &node, std::map<AnfNodePtr, std::string> *node_map_ptr,
                      onnx::GraphProto *graph_proto);
  void ExportPrimCustom(const FuncGraphPtr &, const CNodePtr &node, std::map<AnfNodePtr, std::string> *node_map_ptr,
                        onnx::GraphProto *graph_proto);
  void ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node,
                       std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportMergeGemm(const FuncGraphPtr &func_graph, const CNodePtr &node,
                       std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node,
                            std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportMergeMaxPoolWithArgmax(const FuncGraphPtr &func_graph, const CNodePtr &node,
                                    std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportMergeLayerNorm(const FuncGraphPtr &func_graph, const CNodePtr &node,
                            std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportMergeConv2DTranspose(const FuncGraphPtr &, const CNodePtr &node,
                                  std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportMergeDynamicGRUV2(const FuncGraphPtr &, const CNodePtr &node,
                               std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *const graph_proto);
  void ExportOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &return_arg,
                    std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  std::string GetNodeInputName(const AnfNodePtr &node, std::map<AnfNodePtr, std::string> *node_map_ptr,
                               onnx::GraphProto *const);

  void ConvertTupleToTensor(const ValuePtr &value, onnx::TensorProto *tensor_proto) const;
  void SetTensorData(const ValuePtr &value, onnx::TensorProto *tensor_proto);

  void AddOutputWithCast(onnx::NodeProto *node_proto, const std::string &output_name,
                         onnx::TensorProto_DataType target_type, onnx::GraphProto *graph_proto) const;

  std::string GenerateUniqueName() { return std::to_string(++onnx_node_index_); }
  std::string RegisterNodeWithUniqueName(const AnfNodePtr &node, std::map<AnfNodePtr, std::string> *node_map_ptr) {
    auto name = GenerateUniqueName();
    (*node_map_ptr)[node] = name;
    return name;
  }
  std::string GenerateUniqueParameterName(const ParameterPtr &node, std::map<AnfNodePtr, std::string> *node_map_ptr) {
    auto node_name = node->ToString();
    MS_EXCEPTION_IF_CHECK_FAIL(node_name != "", "Cannot get the name of an ignored parameter");
    auto dup_iter = std::find_if(node_map_ptr->begin(), node_map_ptr->end(),
                                 [&node_name](const auto &pair) { return pair.second == node_name; });
    if (dup_iter != node_map_ptr->end()) {
      node_name = GenerateUniqueName() + node_name;
    }
    return node_name;
  }

  void ResetNodeIndex() { onnx_node_index_ = 0; }

  static int64_t GetInt64Value(const AnfNodePtr &node) {
    auto value_node_ptr = dyn_cast<ValueNode>(node);
    MS_EXCEPTION_IF_NULL(value_node_ptr);
    return GetValue<int64_t>(value_node_ptr->value());
  }

  onnx::ModelProto model_;

  size_t onnx_node_index_ = 0;

  std::map<AnfNodePtr, std::string> renamed_node_map_;
};

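// Minimal driving sketch for the exporter above. GetOnnxProtoString is the
// public entry point; the call site below is illustrative, not a call site
// taken from this file:
//
//   OnnxExporter exporter;
//   std::string bytes = exporter.GetOnnxProtoString(func_graph);
//   // `bytes` holds a serialized onnx::ModelProto, ready to be written out
//   // as a .onnx file.
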
std::string OnnxExporter::GetOnnxProtoString(const FuncGraphPtr &func_graph) {
  if (func_graph == nullptr) {
    return "";
  }
  ResetNodeIndex();
  OpConvertRegistry::GetSingleton().Clear();
  OpConvertRegistry::RegisterAllOpConverters();
  InitModelInfo();
  onnx::GraphProto *graph_proto = model_.mutable_graph();
  std::map<AnfNodePtr, std::string> node_map;
  ExportFuncGraph(func_graph, &node_map, graph_proto);
  return model_.SerializeAsString();
}

void OnnxExporter::InitModelInfo() {
  model_.set_ir_version(onnx::IR_VERSION_2019_1_22);
  model_.set_producer_name("MindSpore");
  model_.set_producer_version("1.0");
  onnx::OperatorSetIdProto *opset_proto = model_.add_opset_import();
  opset_proto->set_version(ONNX_VERSION);
}

void OnnxExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, std::map<AnfNodePtr, std::string> *node_map_ptr,
                                   onnx::GraphProto *const graph_proto, bool export_inputs) {
  MS_LOG(INFO) << "Begin exporting onnx model for graph " << func_graph->ToString();

  // Convert YAML-defined primitives to old-style primitives.
  std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  for (const auto &node : nodes) {
    if (node->isa<CNode>()) {
      auto converted_node = ops::ConvertArgsToAttr(node->cast<CNodePtr>());
      // If the node is already an old-style primitive node, nullptr is returned.
      if (converted_node == nullptr) {
        continue;
      }
      (void)manager->Replace(node, converted_node);
    }
  }
  // set graph name
  graph_proto->set_name(func_graph->ToString());

  // export inputs if the graph is not inlined
  if (export_inputs) {
    ExportInputs(func_graph, node_map_ptr, graph_proto);
  }

  // export computational nodes and output nodes
  ExportNodes(func_graph, node_map_ptr, graph_proto);

  // add names for easier debugging
  for (auto &node : *graph_proto->mutable_node()) {
    if (!node.has_name()) {
      node.set_name(node.output(0) + node.op_type());
    }
  }

  MS_LOG(INFO) << "End exporting onnx model for graph " << func_graph->ToString();
}

void OnnxExporter::ExportInputs(const FuncGraphPtr &func_graph, std::map<AnfNodePtr, std::string> *node_map_ptr,
                                onnx::GraphProto *const graph_proto) {
  for (auto &param : func_graph->parameters()) {
    const ParameterPtr param_ptr = dyn_cast<Parameter>(param);
    if (param_ptr == nullptr) {
      MS_LOG(EXCEPTION) << "Parameter '" << param->ToString() << "' could not be cast to a Parameter.";
    }

    if (param_ptr->has_default()) {
      continue;
    }

    // set onnx input.
    std::string name;
    auto renamed_iter = renamed_node_map_.find(param_ptr);
    if (renamed_iter != renamed_node_map_.end()) {
      name = renamed_iter->second;
      if (name == "") {
        continue;
      }
    } else {
      name = GenerateUniqueParameterName(param_ptr, node_map_ptr);
      (*node_map_ptr)[param_ptr] = name;
    }

    onnx::ValueInfoProto *input_proto = graph_proto->add_input();
    input_proto->set_name(name);
    SetValueInfoType(param_ptr, input_proto);
  }
}

onnx::TensorProto_DataType OnnxExporter::GetOnnxDataType(TypeId type_id) {
  // clang-format off
  static mindspore::HashMap<int, onnx::TensorProto_DataType> type_map = {
    {kNumberTypeBool, onnx::TensorProto_DataType_BOOL},
    {kNumberTypeInt8, onnx::TensorProto_DataType_INT8},
    {kNumberTypeInt16, onnx::TensorProto_DataType_INT16},
    {kNumberTypeInt32, onnx::TensorProto_DataType_INT32},
    {kNumberTypeInt64, onnx::TensorProto_DataType_INT64},
    {kNumberTypeUInt8, onnx::TensorProto_DataType_UINT8},
    {kNumberTypeUInt16, onnx::TensorProto_DataType_UINT16},
    {kNumberTypeUInt32, onnx::TensorProto_DataType_UINT32},
    {kNumberTypeUInt64, onnx::TensorProto_DataType_UINT64},
    {kNumberTypeFloat16, onnx::TensorProto_DataType_FLOAT16},
    {kNumberTypeFloat32, onnx::TensorProto_DataType_FLOAT},
    {kNumberTypeFloat64, onnx::TensorProto_DataType_DOUBLE},
  };
  // clang-format on

  auto iter = type_map.find(type_id);
  if (iter == type_map.end()) {
    MS_LOG(EXCEPTION) << "Convert type error, unsupported type " << type_id;
  }

  return iter->second;
}

void OnnxExporter::SetValueInfoType(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto,
                                    int64_t output_index) const {
  auto dtype = GetOutputType(node, output_index);
  auto shape = node->Shape();

  abstract::ShapePtr output_shape;
  if (shape->isa<abstract::TupleShape>()) {
    auto tuple_shape = dyn_cast<abstract::TupleShape>(shape);
    auto base_shape = tuple_shape->shape().at(static_cast<size_t>(output_index));
    output_shape = dyn_cast<abstract::Shape>(base_shape);
    if (output_shape == nullptr) {
      MS_LOG(EXCEPTION) << "Expected " << node->ToString() << " to output a tuple of tensors. Instead got "
                        << base_shape->ToString() << " from output " << output_index;
    }
  } else if (shape->isa<abstract::Shape>()) {
    output_shape = dyn_cast<abstract::Shape>(shape);
  } else {
    MS_LOG(EXCEPTION) << "Unsupported shape: " << shape->ToString();
  }

  auto *type_proto = value_proto->mutable_type();
  type_proto->mutable_tensor_type()->set_elem_type(dtype);
  auto *shape_proto = type_proto->mutable_tensor_type()->mutable_shape();

  for (const auto dim : output_shape->shape()) {
    shape_proto->add_dim()->set_dim_value(dim);
  }
}

void OnnxExporter::MatchAndMark(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &nodes,
                                mindspore::HashMap<AnfNodePtr, OpMergedInfo> *op_merged_infos_ptr) const {
  auto &op_merged_infos = *op_merged_infos_ptr;

  for (auto &node : nodes) {
    if (!node->isa<CNode>() || IsZeroRefcountNode(node)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (cnode == func_graph->get_return()) {
      // if the key `input` does not exist, just create a new one
      op_merged_infos[cnode].referred_count += 1;
    }
    for (auto &weak_input : cnode->weak_inputs()) {
      auto orig_input = weak_input.lock();
      MS_EXCEPTION_IF_NULL(orig_input);
      auto input = GetRealInput(orig_input);
      if (!input->isa<CNode>() || IsZeroRefcountNode(input)) {
        continue;
      }
      // if the key `input` does not exist, just create a new one
      op_merged_infos[input].referred_count += 1;
    }
    MatchAndMarkCNode(func_graph, cnode, op_merged_infos_ptr);
  }
}

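// Note: MatchAndMark above is the first of two passes over the topologically
// sorted nodes. It only counts references and records fusion decisions in
// op_merged_infos; the second pass (ExportNodes, further below) consults that
// map to skip unreferenced or merged-away nodes before emitting ONNX nodes.
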
struct MergeRule {
  PrimitivePtr node_type;
  PrimitivePtr prev_type;
  OpMergeMode merge_mode;
};

void OnnxExporter::MatchAndMarkCNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                     mindspore::HashMap<AnfNodePtr, OpMergedInfo> *op_merged_infos_ptr) const {
  MS_EXCEPTION_IF_NULL(op_merged_infos_ptr);
  auto &op_merged_infos = *op_merged_infos_ptr;
  const auto ignore = [&op_merged_infos](const AnfNodePtr &node) {
    op_merged_infos[node].mode = OP_MERGE_IGNORE;
    op_merged_infos[node].referred_count -= 1;
  };

  const std::vector<MergeRule> first_input_merge_rules = {
    {prim::kPrimBiasAdd, prim::kPrimConv2D, OP_MERGE_CONV},
    {prim::kPrimBiasAdd, prim::kPrimConv2DTranspose, OP_MERGE_CONV2D_TRANSPOSE},
    {prim::kPrimBiasAdd, prim::kPrimConv3D, OP_MERGE_CONV},
    {prim::kPrimBiasAdd, prim::kPrimConv3DTranspose, OP_MERGE_CONV},
    {prim::kPrimBiasAdd, prim::kPrimMatMul, OP_MERGE_GEMM},
    {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, OP_MERGE_BATCH_NORM},
    {prim::kPrimTupleGetItem, prim::kPrimMaxPoolWithArgmax, OP_MERGE_MAXPOOL_WITH_ARGMAX},
    {prim::kPrimTupleGetItem, prim::kPrimLayerNorm, OP_MERGE_LAYER_NORM},
    {prim::kPrimTupleGetItem, prim::kPrimDynamicGRUV2, OP_MERGE_DYNAMIC_GRU_V2},
  };

  auto rule = std::find_if(first_input_merge_rules.begin(), first_input_merge_rules.end(), [&cnode](const auto &rule) {
    return cnode->IsApply(rule.node_type) && IsPrimitiveCNode(cnode->input(1), rule.prev_type);
  });
  if (rule != first_input_merge_rules.end()) {
    if (cnode->IsApply(prim::kPrimTupleGetItem) && GetInt64Value(cnode->input(kTwoNum)) != 0) {
      MS_LOG(EXCEPTION) << "Multiple outputs for node \"" << cnode->input(1)->ToString() << "\" are not supported";
    }
    op_merged_infos[cnode].mode = rule->merge_mode;
    ignore(cnode->input(1));
  } else if (while_loop_export::IsLoopBodyReturnNode(cnode, func_graph)) {
    // Ignored so it can be replaced with the loop's other outputs
    ignore(cnode);
    auto repeat_node = dyn_cast<CNode>(GetRealInput(cnode->input(1)));
    MS_EXCEPTION_IF_NULL(repeat_node);
    ignore(repeat_node);
  } else if (while_loop_export::IsAfterLoopReturnNode(cnode, func_graph)) {
    // Ignored so the after-loop subgraph can be inlined into the main graph
    ignore(cnode);
    auto first_input = GetRealInput(cnode->input(1));
    if (IsPrimitiveCNode(first_input, prim::kPrimMakeTuple)) {
      ignore(first_input);
    }
  } else if (cnode == func_graph->get_return()) {
    auto first_input = GetRealInput(cnode->input(1));  // Unpack Depend
    // Ignore the MakeTuple output node to avoid exporting it to SequenceConstruct
    // and handle multiple outputs in ExportOutput
    IgnoreMakeTuple(first_input, op_merged_infos_ptr);
  } else if (cnode->IsApply(prim::kPrimConcat) && IsPrimitiveCNode(cnode->input(1), prim::kPrimMakeTuple)) {
    // Ignore MakeTuple to handle it in ExportPrimConcat
    ignore(cnode->input(1));
  }
}

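// Example of the fusion marking above: for a subgraph
// `y = BiasAdd(Conv2D(x, w), b)`, the BiasAdd CNode is marked OP_MERGE_CONV
// and the Conv2D node it consumes is ignored, so ExportNodes later emits the
// pair as one fused node via ExportMergeConv (presumably a single ONNX Conv
// with the bias folded in as a third input).
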
void OnnxExporter::IgnoreMakeTuple(const AnfNodePtr &node,
                                   mindspore::HashMap<AnfNodePtr, OpMergedInfo> *op_merged_infos_ptr) const {
  MS_EXCEPTION_IF_NULL(op_merged_infos_ptr);
  auto &op_merged_infos = *op_merged_infos_ptr;
  const auto ignore = [&op_merged_infos](const AnfNodePtr &node) {
    op_merged_infos[node].mode = OP_MERGE_IGNORE;
    op_merged_infos[node].referred_count -= 1;
  };
  if (node == nullptr) {
    return;
  }

  if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
    ignore(node);
    auto cnode = dyn_cast<CNode>(node);
    if (cnode != nullptr) {
      for (size_t i = 1; i < cnode->size(); ++i) {
        auto real_input = GetRealInput(cnode->input(i));
        IgnoreMakeTuple(real_input, op_merged_infos_ptr);
      }
    }
  }
}

/**
 * AnfNode
 * +-- CNode
 * +-- ANode
 * |   +-- Parameter
 * |   `-- ValueNode
 */
void OnnxExporter::ExportNodes(const FuncGraphPtr &func_graph, std::map<AnfNodePtr, std::string> *node_map_ptr,
                               onnx::GraphProto *const graph_proto) {
  std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);

  mindspore::HashMap<AnfNodePtr, OpMergedInfo> op_merged_infos;
  MatchAndMark(func_graph, nodes, &op_merged_infos);
  for (const AnfNodePtr &node : nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }

    auto cnode = node->cast<CNodePtr>();
    auto iter = op_merged_infos.find(cnode);
    // the node is not referenced by any other nodes, skip it
    if (iter == op_merged_infos.end()) {
      continue;
    }
    auto merged_info = iter->second;
    // the op node was merged into another node and is no longer used, skip it
    if (merged_info.mode == OP_MERGE_IGNORE && merged_info.referred_count == 0) {
      continue;
    }
    if (cnode == func_graph->get_return()) {
      ExportOutput(func_graph, cnode->input(kOneNum), node_map_ptr, graph_proto);
      continue;
    }
    switch (merged_info.mode) {
      case OP_MERGE_CONV:
        ExportMergeConv(func_graph, cnode, node_map_ptr, graph_proto);
        break;
      case OP_MERGE_GEMM:
        ExportMergeGemm(func_graph, cnode, node_map_ptr, graph_proto);
        break;
      case OP_MERGE_BATCH_NORM:
        ExportMergeBatchNorm(func_graph, cnode, node_map_ptr, graph_proto);
        break;
      case OP_MERGE_MAXPOOL_WITH_ARGMAX:
        ExportMergeMaxPoolWithArgmax(func_graph, cnode, node_map_ptr, graph_proto);
        break;
      case OP_MERGE_LAYER_NORM:
        ExportMergeLayerNorm(func_graph, cnode, node_map_ptr, graph_proto);
        break;
      case OP_MERGE_CONV2D_TRANSPOSE:
        ExportMergeConv2DTranspose(func_graph, cnode, node_map_ptr, graph_proto);
        break;
      case OP_MERGE_DYNAMIC_GRU_V2:
        ExportMergeDynamicGRUV2(func_graph, cnode, node_map_ptr, graph_proto);
        break;
      default:
        ExportCNode(func_graph, cnode, node_map_ptr, graph_proto);
        break;
    }
  }
}

void OnnxExporter::ExportPrimReshape(const FuncGraphPtr &, const CNodePtr &node,
                                     std::map<AnfNodePtr, std::string> *node_map_ptr,
                                     onnx::GraphProto *const graph_proto) {
  auto name_x = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto input_shape = node->input(kTwoNum);
  std::string name_shape;
  if (input_shape->isa<ValueNode>()) {
    name_shape = RegisterNodeWithUniqueName(input_shape, node_map_ptr);
    onnx::NodeProto *node_proto = graph_proto->add_node();
    auto name = prim::kPrimReshape->name();

    node_proto->set_name(name_shape + name);
    node_proto->add_output(name_shape);
    node_proto->set_op_type("Constant");
    onnx::AttributeProto *attr_proto = node_proto->add_attribute();
    attr_proto->set_name("value");
    attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
    ConvertTupleToTensor(dyn_cast<ValueNode>(input_shape)->value(), attr_proto->mutable_t());
  } else {
    name_shape = GetNodeInputName(input_shape, node_map_ptr, graph_proto);
    MS_LOG(EXCEPTION) << "Need to insert an op to convert the variable from tuple to tensor for Reshape.";
  }

  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_op_type(prim::kPrimReshape->name());
  node_proto->add_output(node_name);
  node_proto->add_input(name_x);
  node_proto->add_input(name_shape);
}

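// The emitted ONNX subgraph for a constant target shape consists of the two
// nodes added above (names are the generated unique names):
//
//   Constant()                  -> name_shape   // the shape tuple as a tensor
//   Reshape(name_x, name_shape) -> node_name
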
void OnnxExporter::ExportPrimReduce(const FuncGraphPtr &, const CNodePtr &node,
                                    std::map<AnfNodePtr, std::string> *node_map_ptr,
                                    onnx::GraphProto *const graph_proto) {
  auto input_data = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto input_axis = node->input(kTwoNum);
  auto keep_dims = GetOpAttribute<bool>(node, "keep_dims");

  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);

  std::string name;
  if (node->IsApply(prim::kPrimReduceSum)) {
    name = "ReduceSum";
  } else if (node->IsApply(prim::kPrimReduceMean)) {
    name = "ReduceMean";
  } else if (node->IsApply(prim::kPrimReduceMax)) {
    name = "ReduceMax";
  } else {
    MS_LOG(EXCEPTION) << "Unsupported reduce op: " << node->ToString();
  }

  std::vector<int64_t> axes;
  if (input_axis->isa<ValueNode>()) {
    auto axis_value = dyn_cast<ValueNode>(input_axis)->value();
    if (axis_value->isa<Int32Imm>()) {
      auto int_ptr = dyn_cast<Int32Imm>(axis_value);
      axes.push_back(int_ptr->value());
    } else if (axis_value->isa<Int64Imm>()) {
      auto int_ptr = dyn_cast<Int64Imm>(axis_value);
      axes.push_back(int_ptr->value());
    } else if (axis_value->isa<ValueTuple>()) {
      auto tuple_ptr = dyn_cast<ValueTuple>(axis_value);
      axes = GetValue<std::vector<int64_t>>(tuple_ptr);
    } else {
      MS_LOG(EXCEPTION) << "Cannot convert value " << axis_value->ToString() << " of type "
                        << axis_value->type()->ToString() << " for \"axes\" attribute of " << name;
    }
  } else {
    MS_LOG(EXCEPTION) << "Need to insert an op to convert the variable from tuple to attributes for " << name;
  }

  AddReduceOp(name, input_data, node_name, axes, keep_dims, graph_proto);
}

void OnnxExporter::ExportPrimReduceAnyOrAll(const FuncGraphPtr &, const CNodePtr &node,
                                            std::map<AnfNodePtr, std::string> *node_map_ptr,
                                            onnx::GraphProto *const graph_proto) {
  auto input_data_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto input_axis = node->input(kTwoNum);
  auto keep_dims = GetOpAttribute<bool>(node, "keep_dims");
  auto reduce_name = RegisterNodeWithUniqueName(node, node_map_ptr);

  std::string target_node_name = "";
  if (node->IsApply(prim::kPrimReduceAny)) {
    target_node_name = "ReduceSum";
  } else if (node->IsApply(prim::kPrimReduceAll)) {
    target_node_name = "ReduceMin";
  } else {
    MS_LOG(EXCEPTION) << "Unsupported reduce op: " << node->ToString();
  }

  std::string cast_name = GenerateUniqueName();  // Insert cast op
  onnx::NodeProto *cast_proto = graph_proto->add_node();
  cast_proto->add_input(input_data_name);
  cast_proto->add_output(cast_name);
  cast_proto->set_op_type(prim::kPrimCast->name());
  onnx::AttributeProto *attr_proto = cast_proto->add_attribute();
  attr_proto->set_name("to");
  attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  attr_proto->set_i(GetOnnxDataType(TypeId::kNumberTypeFloat32));

  std::vector<int64_t> axes;
  if (input_axis->isa<ValueNode>()) {
    auto axis_value = dyn_cast<ValueNode>(input_axis)->value();
    if (axis_value->isa<Int32Imm>()) {
      auto int_ptr = dyn_cast<Int32Imm>(axis_value);
      axes.push_back(int_ptr->value());
    } else if (axis_value->isa<Int64Imm>()) {
      auto int_ptr = dyn_cast<Int64Imm>(axis_value);
      axes.push_back(int_ptr->value());
    } else if (axis_value->isa<ValueTuple>()) {
      auto tuple_ptr = dyn_cast<ValueTuple>(axis_value);
      axes = GetValue<std::vector<int64_t>>(tuple_ptr);
      if (axes.empty()) {
        const auto &x_shape = dyn_cast<abstract::Shape>(node->input(kOneNum)->Shape())->shape();
        for (size_t i = 0; i < x_shape.size(); ++i) {
          axes.push_back(static_cast<int64_t>(i));
        }
      }
    } else {
      MS_LOG(EXCEPTION) << "Cannot convert value " << axis_value->ToString() << " of type "
                        << axis_value->type()->ToString() << " for \"axes\" attribute of " << target_node_name;
    }
  } else {
    MS_LOG(EXCEPTION) << "Need to insert an op to convert the variable from tuple to attributes for "
                      << target_node_name;
  }

  std::string greater_name = GenerateUniqueName();
  onnx::TensorProto *zero_initializer_proto = graph_proto->add_initializer();
  auto zero_input_name = greater_name + "_zero";
  zero_initializer_proto->set_name(zero_input_name);
  zero_initializer_proto->set_data_type(GetOnnxDataType(kNumberTypeFloat32));
  zero_initializer_proto->add_float_data(0);

  AddReduceOp(target_node_name, cast_name, greater_name, axes, keep_dims, graph_proto);

  onnx::NodeProto *greater_node_proto = graph_proto->add_node();  // Insert greater op
  greater_node_proto->add_input(greater_name);
  greater_node_proto->add_input(zero_input_name);
  greater_node_proto->add_output(reduce_name);
  greater_node_proto->set_op_type(prim::kPrimGreater->name());
}

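// In pseudo-ONNX form, the lowering implemented above is (a sketch: the Cast
// to float32 and the final comparison with zero are exactly what the code
// emits; the motivation, that ONNX reduce ops do not take bool inputs, is a
// presumption):
//
//   ReduceAny(x, axes) => Greater(ReduceSum(Cast<float>(x), axes), 0.0f)
//   ReduceAll(x, axes) => Greater(ReduceMin(Cast<float>(x), axes), 0.0f)
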
void OnnxExporter::ExportPrimTranspose(const FuncGraphPtr &, const CNodePtr &node,
                                       std::map<AnfNodePtr, std::string> *node_map_ptr,
                                       onnx::GraphProto *const graph_proto) {
  auto input_data = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto input_perm = node->input(kTwoNum);
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  onnx::NodeProto *node_proto = graph_proto->add_node();
  auto name = prim::kPrimTranspose->name();

  node_proto->set_name(node_name + name);
  node_proto->set_op_type(name);
  node_proto->add_output(node_name);
  node_proto->add_input(input_data);

  if (input_perm->isa<ValueNode>()) {
    onnx::AttributeProto *attr_proto = node_proto->add_attribute();
    attr_proto->set_name("perm");
    attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
    auto perm_value = dyn_cast<ValueNode>(input_perm)->value();
    auto int_ptr = dyn_cast<Int32Imm>(perm_value);
    if (int_ptr == nullptr) {
      auto tuple_ptr = dyn_cast<ValueTuple>(perm_value);
      MS_EXCEPTION_IF_NULL(tuple_ptr);
      for (size_t i = 0; i < tuple_ptr->size(); ++i) {
        attr_proto->add_ints(GetValue<int64_t>((*tuple_ptr)[i]));
      }
    } else {
      attr_proto->add_ints(int_ptr->value());
    }
  } else {
    MS_LOG(EXCEPTION) << "The input input_perm of Transpose is not a ValueNode! "
                      << "Need to insert an op to convert the variable from tuple to attributes for " << name;
  }
}

/*
  See:
    - mindspore/ccsrc/backend/kernel_compiler/cpu/stridedslice_cpu_kernel.cc
    - mindspore/ccsrc/backend/kernel_compiler/common_utils.cc
 */
void OnnxExporter::ExportPrimStridedSlice(const FuncGraphPtr &, const CNodePtr &node,
                                          std::map<AnfNodePtr, std::string> *node_map_ptr,
                                          onnx::GraphProto *const graph_proto) {
  auto input_data = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  auto name = node_name + prim::kPrimStridedSlice->name();

  auto begin = node->input(kTwoNum);
  auto begin_mask = node->input(kFiveNum);
  if (!begin->isa<ValueNode>() && !begin_mask->isa<ValueNode>()) {
    MS_LOG(EXCEPTION) << "The input begin of StridedSlice is not a ValueNode! "
                      << "Need to insert an op to convert the variable from tuple to tensor for " << name;
  }
  auto begin_value_node = dyn_cast<ValueNode>(begin);
  auto begin_value = GetValue<std::vector<int64_t>>(begin_value_node->value());
  auto begin_ignore_mask = GetValue<int64_t>(dyn_cast<ValueNode>(begin_mask)->value());
  for (size_t i = 0; i < begin_value.size(); ++i) {
    if ((static_cast<uint64_t>(begin_ignore_mask) & (1UL << i)) != 0) {
      begin_value[i] = 0;
    }
  }

  auto end = node->input(kThreeNum);
  auto end_mask = node->input(kSixNum);
  if (!end->isa<ValueNode>() && !end_mask->isa<ValueNode>()) {
    MS_LOG(EXCEPTION) << "The input end of StridedSlice is not a ValueNode! "
                      << "Need to insert an op to convert the variable from tuple to tensor for " << name;
  }
  auto end_value_node = dyn_cast<ValueNode>(end);
  auto end_value = GetValue<std::vector<int64_t>>(end_value_node->value());
  const auto &x_shape = dyn_cast<abstract::Shape>(node->input(kOneNum)->Shape())->shape();
  auto end_ignore_mask = GetValue<int64_t>(dyn_cast<ValueNode>(end_mask)->value());
  for (size_t i = 0; i < end_value.size(); ++i) {
    if ((static_cast<uint64_t>(end_ignore_mask) & (1UL << i)) != 0) {
      end_value[i] = x_shape[i];
    }
  }

  size_t axes_size = end_value.size();
  std::vector<int64_t> axes_value;
  for (size_t i = 0; i < axes_size; ++i) {
    axes_value.push_back(static_cast<int64_t>(i));
  }

  auto strides = node->input(kFourNum);
  auto shrink_mask = node->input(kNineNum);
  if (!strides->isa<ValueNode>() && !shrink_mask->isa<ValueNode>()) {
    MS_LOG(EXCEPTION) << "The input strides of StridedSlice is not a ValueNode! "
                      << "Need to insert an op to convert the variable from tuple to tensor for " << name;
  }
  auto strides_value_node = dyn_cast<ValueNode>(strides);
  auto strides_value = GetValue<std::vector<int64_t>>(strides_value_node->value());

  auto shrink_axis_mask = GetValue<int64_t>(dyn_cast<ValueNode>(shrink_mask)->value());
  for (size_t i = 0; i < end_value.size(); ++i) {
    if ((static_cast<uint64_t>(shrink_axis_mask) & (1UL << i)) != 0) {
      strides_value[i] = end_value[i] > begin_value[i] ? 1 : -1;
      end_value[i] = begin_value[i] + strides_value[i];
    }
  }

  auto slice_name = node_name;
  if (shrink_axis_mask != 0) {
    slice_name = node_name + "__reshape";
  }

  AddSliceOp(input_data, slice_name, begin_value, end_value, axes_value, strides_value, graph_proto);

  if (shrink_axis_mask != 0) {
    onnx::NodeProto *squeeze_op = graph_proto->add_node();
    squeeze_op->set_op_type("Squeeze");
    squeeze_op->add_input(slice_name);
    squeeze_op->add_output(node_name);
    onnx::AttributeProto *axes_attr = squeeze_op->add_attribute();
    axes_attr->set_name("axes");
    axes_attr->set_type(onnx::AttributeProto_AttributeType_INTS);
    for (size_t i = 0; i < x_shape.size(); ++i) {
      if ((static_cast<uint64_t>(shrink_axis_mask) & (1UL << i)) != 0) {
        axes_attr->add_ints(static_cast<int64_t>(i));
      }
    }
  }
}

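// Worked example of the mask handling above, with illustrative values:
// for x of shape [4, 5], begin = (1, 2), end = (3, 4), strides = (1, 1):
//   begin_mask = 0b01 clears begin[0]             -> begin = (0, 2)
//   end_mask   = 0b10 resets end[1] to x_shape[1] -> end   = (3, 5)
// A bit set in shrink_axis_mask turns the slice on that axis into a single
// element, and the trailing Squeeze node then removes that axis.
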
onnx::NodeProto *OnnxExporter::PrimResizeExportHelper(const FuncGraphPtr &, const CNodePtr &node,
                                                      std::map<AnfNodePtr, std::string> *node_map_ptr,
                                                      onnx::GraphProto *const graph_proto) {
  auto input_data = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto x_shape = dyn_cast<abstract::Shape>(node->input(kOneNum)->Shape());

  AnfNodePtr op = node->input(kZeroNum);
  auto op_value = dyn_cast<ValueNode>(op);
  auto prim = dyn_cast<Primitive>(op_value->value());
  std::vector<int64_t> resize_size;

  auto tuple_ptr = dyn_cast<ValueSequence>(prim->GetAttr("size"));  // size may be Tuple or List
  if (tuple_ptr == nullptr) {
    MS_LOG(EXCEPTION) << "Got a null pointer; currently the " << prim->name()
                      << " operator in your model is not supported for ONNX export.";
  }

  for (size_t i = 0; i < x_shape->shape().size() - kTwoNum; i++) {
    resize_size.push_back(x_shape->shape()[i]);
  }
  for (size_t i = 0; i < tuple_ptr->size(); i++) {
    ValuePtr elem = (*tuple_ptr)[i];
    resize_size.push_back(dyn_cast<Int64Imm>(elem)->value());
  }
  auto resize_size_ptr = MakeValue<std::vector<int64_t>>(resize_size);
  auto size = NewValueNode(resize_size_ptr)->cast<AnfNodePtr>();

  auto name_size = RegisterNodeWithUniqueName(size, node_map_ptr);
  onnx::NodeProto *node_proto_size = graph_proto->add_node();
  node_proto_size->add_output(name_size);
  node_proto_size->set_op_type("Constant");
  onnx::AttributeProto *attr_proto = node_proto_size->add_attribute();
  attr_proto->set_name("value");
  attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
  ConvertTupleToTensor(resize_size_ptr, attr_proto->mutable_t());

  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);

  onnx::TensorProto *roi_initializer_proto = graph_proto->add_initializer();
  auto roi_name = node_name + "roi_initializer";
  roi_initializer_proto->set_name(roi_name);
  roi_initializer_proto->set_data_type(GetOnnxDataType(kNumberTypeFloat32));
  roi_initializer_proto->add_dims(0);

  onnx::TensorProto *scales_initializer_proto = graph_proto->add_initializer();
  auto scales_name = node_name + "scales_initializer";
  scales_initializer_proto->set_name(scales_name);
  scales_initializer_proto->set_data_type(GetOnnxDataType(kNumberTypeFloat32));
  scales_initializer_proto->add_dims(0);

  onnx::NodeProto *node_proto = graph_proto->add_node();

  node_proto->set_op_type("Resize");
  node_proto->add_output(node_name);
  node_proto->add_input(input_data);
  node_proto->add_input(roi_name);
  node_proto->add_input(scales_name);
  node_proto->add_input(name_size);

  return node_proto;
}

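// Note on the Resize inputs above: `roi` and `scales` are emitted as empty
// (zero-length) initializers, and the target shape is passed through the
// fourth input, `sizes`, matching the opset-11 Resize signature
// (X, roi, scales, sizes).
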
void OnnxExporter::ExportPrimResizeNearestNeighbor(const FuncGraphPtr &graph, const CNodePtr &node,
                                                   std::map<AnfNodePtr, std::string> *node_map_ptr,
                                                   onnx::GraphProto *const graph_proto) {
  onnx::NodeProto *node_proto = PrimResizeExportHelper(graph, node, node_map_ptr, graph_proto);

  auto align_corners = GetOpAttribute<bool>(node, "align_corners");
  std::string coordinate_transformation_mode = align_corners ? "align_corners" : "asymmetric";
  // `nearest_mode` is based on ResizeNearestNeighborCPUKernel::LaunchKernel in
  // mindspore/ccsrc/backend/kernel_compiler/cpu/resize_nearest_neighbor_cpu_kernel.cc
  std::string nearest_mode = align_corners ? "round_prefer_ceil" : "floor";

  onnx::AttributeProto *coordinate_mode_proto = node_proto->add_attribute();
  coordinate_mode_proto->set_name("coordinate_transformation_mode");
  coordinate_mode_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
  coordinate_mode_proto->set_s(coordinate_transformation_mode);

  onnx::AttributeProto *nearest_mode_proto = node_proto->add_attribute();
  nearest_mode_proto->set_name("nearest_mode");
  nearest_mode_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
  nearest_mode_proto->set_s(nearest_mode);
}

void OnnxExporter::ExportPrimResizeBilinear(const FuncGraphPtr &graph, const CNodePtr &node,
                                            std::map<AnfNodePtr, std::string> *node_map_ptr,
                                            onnx::GraphProto *const graph_proto) {
  onnx::NodeProto *node_proto = PrimResizeExportHelper(graph, node, node_map_ptr, graph_proto);

  auto align_corners = GetOpAttribute<bool>(node, "align_corners");
  std::string coordinate_transformation_mode = align_corners ? "align_corners" : "asymmetric";

  onnx::AttributeProto *coordinate_mode_proto = node_proto->add_attribute();
  coordinate_mode_proto->set_name("coordinate_transformation_mode");
  coordinate_mode_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
  coordinate_mode_proto->set_s(coordinate_transformation_mode);

  onnx::AttributeProto *mode_proto = node_proto->add_attribute();
  mode_proto->set_name("mode");
  mode_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
  mode_proto->set_s("linear");
}

2056 // MindSpore ExpandDims -> ONNX Reshape
ExportPrimExpandDims(const FuncGraphPtr &,const CNodePtr & node,std::map<AnfNodePtr,std::string> * node_map_ptr,onnx::GraphProto * const graph_proto)2057 void OnnxExporter::ExportPrimExpandDims(const FuncGraphPtr &, const CNodePtr &node,
2058                                         std::map<AnfNodePtr, std::string> *node_map_ptr,
2059                                         onnx::GraphProto *const graph_proto) {
2060   auto input_x = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
2061   auto axis = GetInt64Value(node->input(kTwoNum));
2062   auto x_shape = dyn_cast<abstract::Shape>(node->input(kOneNum)->Shape());
2063   auto name = prim::kPrimExpandDims->name();
2064 
2065   std::vector<int64_t> new_shape;
2066   for (size_t i = 0; i < x_shape->shape().size(); i++) {
2067     new_shape.push_back(x_shape->shape()[i]);
2068   }
2069   if (axis < 0) {
2070     axis = axis + kOneNumLong + SizeToLong(x_shape->shape().size());
2071   }
2072   (void)new_shape.insert(new_shape.begin() + axis, kOneNum);
2073   auto new_shape_value = MakeValue<std::vector<int64_t>>(new_shape);
2074   auto shape = NewValueNode(new_shape_value)->cast<AnfNodePtr>();
2075   std::string name_shape;
2076 
2077   if (shape->isa<ValueNode>()) {
2078     name_shape = RegisterNodeWithUniqueName(shape, node_map_ptr);
2079     onnx::NodeProto *node_proto = graph_proto->add_node();
2080     node_proto->add_output(name_shape);
2081     node_proto->set_op_type("Constant");
2082     onnx::AttributeProto *attr_proto = node_proto->add_attribute();
2083     attr_proto->set_name("value");
2084     attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
2085     ConvertTupleToTensor(dyn_cast<ValueNode>(shape)->value(), attr_proto->mutable_t());
2086   } else {
2087     name_shape = GetNodeInputName(shape, node_map_ptr, graph_proto);
2088     MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to tensor for " << name;
2089   }
2090 
2091   auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
2092   onnx::NodeProto *node_proto = graph_proto->add_node();
2093   node_proto->set_op_type("Reshape");
2094   node_proto->add_output(node_name);
2095   node_proto->add_input(input_x);
2096   node_proto->add_input(name_shape);
2097 }
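/*
 Worked example (illustrative): ExpandDims(x, -1) with x of shape [2, 3]
 normalizes the axis to -1 + 1 + 2 = 2, so new_shape = [2, 3, 1] and the op
 is exported as Reshape(x, Constant([2, 3, 1])).
*/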

// MindSpore GatherD -> ONNX GatherElements
void OnnxExporter::ExportPrimGatherD(const FuncGraphPtr &, const CNodePtr &node,
                                     std::map<AnfNodePtr, std::string> *node_map_ptr,
                                     onnx::GraphProto *const graph_proto) {
  auto input_x = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto axis = GetInt64Value(node->input(kTwoNum));
  auto input_indices = GetNodeInputName(node->input(kThreeNum), node_map_ptr, graph_proto);
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_op_type("GatherElements");
  node_proto->add_output(node_name);
  node_proto->add_input(input_x);
  node_proto->add_input(input_indices);
  onnx::AttributeProto *attr_proto = node_proto->add_attribute();
  attr_proto->set_name("axis");
  attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  attr_proto->set_i(static_cast<::google::protobuf::int64>(axis));
}

// MindSpore Pad -> ONNX Pad
void OnnxExporter::ExportPrimPad(const FuncGraphPtr &, const CNodePtr &node,
                                 std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *const graph_proto) {
  auto x_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);

  auto paddings = GetOpAttributePtr<ValueTuple>(node, "paddings");
  std::vector<std::vector<int64_t>> paddings_values = GetValue<std::vector<std::vector<int64_t>>>(paddings);
  std::vector<int64_t> pads_sequence;
  for (size_t i = 0; i < paddings_values.size(); ++i) {
    pads_sequence.push_back(paddings_values[i][0]);
  }
  for (size_t j = 0; j < paddings_values.size(); ++j) {
    pads_sequence.push_back(paddings_values[j][1]);
  }
  auto pads_ptr = MakeValue<std::vector<int64_t>>(pads_sequence);
  auto pads = NewValueNode(pads_ptr)->cast<AnfNodePtr>();

  auto pads_name = RegisterNodeWithUniqueName(pads, node_map_ptr);
  onnx::NodeProto *pads_node = graph_proto->add_node();
  pads_node->add_output(pads_name);
  pads_node->set_op_type("Constant");
  onnx::AttributeProto *pads_attr_proto = pads_node->add_attribute();
  pads_attr_proto->set_name("value");
  pads_attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
  ConvertTupleToTensor(pads_ptr, pads_attr_proto->mutable_t());

  auto ms_pad_node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  onnx::NodeProto *onnx_pad_node = graph_proto->add_node();
  onnx_pad_node->set_op_type("Pad");
  onnx_pad_node->add_output(ms_pad_node_name);
  onnx_pad_node->add_input(x_name);
  onnx_pad_node->add_input(pads_name);
}
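/*
 Worked example (illustrative): MindSpore paddings ((1, 2), (3, 4)) are stored
 as per-dimension (begin, end) pairs, while ONNX Pad expects all begins first
 and then all ends, so the constant built above is pads = [1, 3, 2, 4].
*/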

// MindSpore BatchMatMul -> ONNX Transpose + MatMul
void OnnxExporter::ExportPrimBatchMatMul(const FuncGraphPtr &, const CNodePtr &node,
                                         std::map<AnfNodePtr, std::string> *node_map_ptr,
                                         onnx::GraphProto *const graph_proto) {
  auto input_x = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto input_y = GetNodeInputName(node->input(kTwoNum), node_map_ptr, graph_proto);

  AnfNodePtr batchmatmul_op = node->input(kZeroNum);
  auto op_value = dyn_cast<ValueNode>(batchmatmul_op);
  auto prim = dyn_cast<Primitive>(op_value->value());
  auto transpose_a = GetValue<bool>(prim->GetAttr("transpose_a"));
  auto transpose_b = GetValue<bool>(prim->GetAttr("transpose_b"));
  std::string transpose_input_x_name = "";
  std::string transpose_input_y_name = "";

  if (transpose_a) {
    auto input_x_shape = dyn_cast<abstract::Shape>(node->input(kOneNum)->Shape());
    // Add Transpose node after input_x of BatchMatMul
    transpose_input_x_name = GenerateUniqueName();
    onnx::NodeProto *transpose_inputx_node_proto = graph_proto->add_node();
    transpose_inputx_node_proto->add_input(input_x);
    transpose_inputx_node_proto->add_output(transpose_input_x_name);
    transpose_inputx_node_proto->set_op_type(prim::kPrimTranspose->name());
    onnx::AttributeProto *attr_proto = transpose_inputx_node_proto->add_attribute();
    attr_proto->set_name("perm");
    attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
    for (size_t i = 0; i < input_x_shape->shape().size() - kTwoNum; i++) {
      attr_proto->add_ints(SizeToLong(i));
    }
    attr_proto->add_ints(SizeToLong(input_x_shape->shape().size()) - IntToLong(kOneNum));
    attr_proto->add_ints(SizeToLong(input_x_shape->shape().size()) - IntToLong(kTwoNum));
  }
  if (transpose_b) {
    auto input_y_shape = dyn_cast<abstract::Shape>(node->input(kTwoNum)->Shape());
    // Add Transpose node after input_y of BatchMatMul
    transpose_input_y_name = GenerateUniqueName();
    onnx::NodeProto *transpose_inputy_node_proto = graph_proto->add_node();
    transpose_inputy_node_proto->add_input(input_y);
    transpose_inputy_node_proto->add_output(transpose_input_y_name);
    transpose_inputy_node_proto->set_op_type(prim::kPrimTranspose->name());
    onnx::AttributeProto *attr_proto = transpose_inputy_node_proto->add_attribute();
    attr_proto->set_name("perm");
    attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
    for (size_t i = 0; i < input_y_shape->shape().size() - kTwoNum; i++) {
      attr_proto->add_ints(SizeToLong(i));
    }
    attr_proto->add_ints(SizeToLong(input_y_shape->shape().size()) - IntToLong(kOneNum));
    attr_proto->add_ints(SizeToLong(input_y_shape->shape().size()) - IntToLong(kTwoNum));
  }

  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_op_type("MatMul");
  node_proto->add_output(node_name);
  node_proto->set_name(node_name + "MatMul");
  if (transpose_a) {
    node_proto->add_input(transpose_input_x_name);
  } else {
    node_proto->add_input(input_x);
  }
  if (transpose_b) {
    node_proto->add_input(transpose_input_y_name);
  } else {
    node_proto->add_input(input_y);
  }
}
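/*
 Worked example (illustrative): for a rank-4 input with transpose_a=true, the
 permutation built above keeps the batch dimensions and swaps the last two,
 i.e. perm = [0, 1, 3, 2], so Transpose followed by MatMul reproduces
 BatchMatMul(x, y, transpose_a=true).
*/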

// MindSpore BroadcastTo -> ONNX Expand
void OnnxExporter::ExportPrimBroadcastTo(const FuncGraphPtr &, const CNodePtr &node,
                                         std::map<AnfNodePtr, std::string> *node_map_ptr,
                                         onnx::GraphProto *const graph_proto) {
  auto input_x = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto x_shape = dyn_cast<abstract::Shape>(node->input(kOneNum)->Shape());
  auto name = prim::kPrimBroadcastTo->name();

  auto shape_ptr = GetOpAttributePtr<ValueSequeue>(node, "shape");
  auto shape_vec = GetValue<std::vector<int64_t>>(shape_ptr);
  size_t n_shape = shape_vec.size();

  std::vector<int64_t> new_shape;
  for (size_t i = 0; i < n_shape; i++) {
    if (shape_vec[i] == -kOneNum) {
      // -1 means "keep the corresponding (trailing-aligned) dimension of x"
      size_t ids = i + x_shape->shape().size() - n_shape;
      new_shape.push_back(x_shape->shape()[ids]);
    } else {
      new_shape.push_back(shape_vec[i]);
    }
  }

  auto new_shape_value = MakeValue<std::vector<int64_t>>(new_shape);
  auto shape = NewValueNode(new_shape_value)->cast<AnfNodePtr>();
  std::string name_shape;

  if (shape->isa<ValueNode>()) {
    name_shape = RegisterNodeWithUniqueName(shape, node_map_ptr);
    onnx::NodeProto *node_proto = graph_proto->add_node();
    node_proto->add_output(name_shape);
    node_proto->set_op_type("Constant");
    onnx::AttributeProto *attr_proto = node_proto->add_attribute();
    attr_proto->set_name("value");
    attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
    ConvertTupleToTensor(dyn_cast<ValueNode>(shape)->value(), attr_proto->mutable_t());
  } else {
    name_shape = GetNodeInputName(shape, node_map_ptr, graph_proto);
    MS_LOG(EXCEPTION) << "Need to insert an op to convert the variable from tuple to tensor for " << name;
  }

  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_op_type("Expand");
  node_proto->add_output(node_name);
  node_proto->add_input(input_x);
  node_proto->add_input(name_shape);
}
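/*
 Worked example (illustrative): BroadcastTo(shape=(-1, 32)) on x of shape
 [4, 1] resolves the -1 entry from the trailing-aligned dimension of x
 (ids = 0 + 2 - 2 = 0, so dimension 4), giving new_shape = [4, 32], which is
 exported as Expand(x, Constant([4, 32])).
*/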

// MindSpore AddN -> ONNX Add
void OnnxExporter::ExportPrimAddN(const FuncGraphPtr &, const CNodePtr &node,
                                  std::map<AnfNodePtr, std::string> *node_map_ptr,
                                  onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);

  auto input_node = node->input(kOneNum)->cast<CNodePtr>();
  auto last_input_name = GetNodeInputName(input_node->input(kOneNum), node_map_ptr, graph_proto);
  for (size_t i = kTwoNum; i < input_node->size() - 1; ++i) {
    auto input_name = GetNodeInputName(input_node->input(i), node_map_ptr, graph_proto);
    auto tmp_end_name = node_name + "ADD_" + std::to_string(i);
    AddOp("Add", {last_input_name, input_name}, {tmp_end_name}, graph_proto);
    last_input_name = tmp_end_name;
  }
  auto input_end_name = GetNodeInputName(input_node->input(input_node->size() - 1), node_map_ptr, graph_proto);
  AddOp("Add", {last_input_name, input_end_name}, {node_name}, graph_proto);
}
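/*
 Illustrative expansion: AddN((a, b, c)) becomes a left-to-right chain of
 binary adds, Add(Add(a, b), c), where only the final Add produces the
 registered output name.
*/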

// MindSpore GeLU -> ONNX 0.5 * x * (1.0 + tanh(sqrt(2/pi) * (x + 0.044715 * pow(x, 3))))
void OnnxExporter::ExportPrimGeLU(const FuncGraphPtr &, const CNodePtr &node,
                                  std::map<AnfNodePtr, std::string> *node_map_ptr,
                                  onnx::GraphProto *const graph_proto) {
  auto input_x = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto onnx_type = GetOutputType(node->input(kOneNum));

  // Add Pow node: x^3
  auto pow_name = GenerateUniqueName();
  auto exp_node_name = pow_name + "exponent_initializer";
  AddFloatTensor1DInitializer(exp_node_name, {3.0}, onnx_type, graph_proto);
  AddOp("Pow", {input_x, exp_node_name}, {pow_name}, graph_proto);

  // Add first Mul node: 0.044715 * x^3
  auto fmul_name = GenerateUniqueName();
  auto fmul_input_node_name = fmul_name + "input_y_for_mul_initializer";
  AddFloatTensor1DInitializer(fmul_input_node_name, {0.044715}, onnx_type, graph_proto);
  AddOp("Mul", {pow_name, fmul_input_node_name}, {fmul_name}, graph_proto);

  // Add first Add node: x + 0.044715 * x^3
  auto fadd_name = GenerateUniqueName();
  AddOp("Add", {input_x, fmul_name}, {fadd_name}, graph_proto);

  // Add second Mul node: sqrt(2/pi) * (x + 0.044715 * x^3)
  auto smul_name = GenerateUniqueName();
  auto smul_input_node_name = smul_name + "input_y_for_smul_initializer";
  AddFloatTensor1DInitializer(smul_input_node_name, {0.7978845608}, onnx_type, graph_proto);
  AddOp("Mul", {fadd_name, smul_input_node_name}, {smul_name}, graph_proto);

  // Add Tanh node
  auto tanh_name = GenerateUniqueName();
  AddOp("Tanh", {smul_name}, {tanh_name}, graph_proto);

  // Add second Add node: 1.0 + tanh(...)
  auto sadd_name = GenerateUniqueName();
  auto sadd_input_node_name = sadd_name + "input_y_for_sadd_initializer";
  AddFloatTensor1DInitializer(sadd_input_node_name, {1.0}, onnx_type, graph_proto);
  AddOp("Add", {tanh_name, sadd_input_node_name}, {sadd_name}, graph_proto);

  // Add third Mul node: 0.5 * (1.0 + tanh(...))
  auto tmul_name = GenerateUniqueName();
  auto tmul_input_node_name = tmul_name + "input_y_for_tmul_initializer";
  AddFloatTensor1DInitializer(tmul_input_node_name, {0.5}, onnx_type, graph_proto);
  AddOp("Mul", {sadd_name, tmul_input_node_name}, {tmul_name}, graph_proto);

  // Add fourth Mul node: x * (0.5 * (1.0 + tanh(...)))
  auto fomul_node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  AddOp("Mul", {input_x, tmul_name}, {fomul_node_name}, graph_proto);
}
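/*
 Note on the constant above: sqrt(2/pi) ~= 0.7978845608, so the subgraph built
 here is the tanh approximation of GeLU,
   gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
 evaluated element-wise in the input's ONNX dtype.
*/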

void OnnxExporter::ExportPrimConcat(const FuncGraphPtr &, const CNodePtr &node,
                                    std::map<AnfNodePtr, std::string> *node_map_ptr,
                                    onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);

  // Get inputs first: otherwise, if an input is a constant, topological order will break
  auto input_node = node->input(kOneNum)->cast<CNodePtr>();
  std::vector<std::string> input_names;
  if (input_node->IsApply(prim::kPrimMakeTuple)) {
    for (size_t i = 1; i < input_node->size(); ++i) {
      auto input_name = GetNodeInputName(input_node->input(i), node_map_ptr, graph_proto);
      input_names.push_back(input_name);
    }
  } else {
    auto input_data = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
    input_names.push_back(input_data);
  }

  AddConcatOp(input_names, node_name, GetOpAttribute<int64_t>(node, "axis"), graph_proto);
}

void OnnxExporter::ExportPrimCast(const FuncGraphPtr &, const CNodePtr &node,
                                  std::map<AnfNodePtr, std::string> *node_map_ptr,
                                  onnx::GraphProto *const graph_proto) {
  auto input_data = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto input_type = node->input(kTwoNum);

  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_op_type(prim::kPrimCast->name());
  node_proto->add_output(node_name);
  node_proto->add_input(input_data);

  if (input_type->isa<ValueNode>()) {
    onnx::AttributeProto *attr_proto = node_proto->add_attribute();
    attr_proto->set_name("to");
    attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
    auto type_value = dyn_cast<ValueNode>(input_type)->value();
    auto type_ptr = dyn_cast<Int64Imm>(type_value);
    MS_EXCEPTION_IF_NULL(type_ptr);
    auto type_id = static_cast<TypeId>(type_ptr->value());
    attr_proto->set_i(GetOnnxDataType(type_id));
  } else {
    MS_LOG(EXCEPTION) << "Need to convert MindSpore Cast input(1) to the ONNX Cast `to` attribute.";
  }
}

void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr &, const CNodePtr &node,
                                   std::map<AnfNodePtr, std::string> *node_map_ptr,
                                   onnx::GraphProto *const graph_proto) {
  auto input_x = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto input_slope = GetNodeInputName(node->input(kTwoNum), node_map_ptr, graph_proto);

  auto x_shape = dyn_cast<abstract::Shape>(node->input(kOneNum)->Shape());
  auto slope_shape = dyn_cast<abstract::Shape>(node->input(kTwoNum)->Shape());
  MS_EXCEPTION_IF_NULL(x_shape);
  MS_EXCEPTION_IF_NULL(slope_shape);

  // x is NCHW; if the slope is 1-D, insert Unsqueeze on axes [1, 2] so it
  // becomes [C, 1, 1] and broadcasts over H and W
  if (x_shape->shape().size() == kFourNum && slope_shape->shape().size() == kOneNum) {
    auto node_name = GenerateUniqueName();
    onnx::NodeProto *node_proto = graph_proto->add_node();
    node_proto->set_op_type("Unsqueeze");
    node_proto->add_output(node_name);

    onnx::AttributeProto *attr_proto = node_proto->add_attribute();
    attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
    attr_proto->set_name("axes");
    attr_proto->add_ints(kOneNum);
    attr_proto->add_ints(kTwoNum);

    node_proto->add_input(input_slope);
    input_slope = node_name;
  }

  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_op_type("PRelu");
  node_proto->add_output(node_name);
  node_proto->add_input(input_x);
  node_proto->add_input(input_slope);
}

void OnnxExporter::ExportPrimReLU6(const FuncGraphPtr &, const CNodePtr &node,
                                   std::map<AnfNodePtr, std::string> *node_map_ptr,
                                   onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);

  auto input_x_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto onnx_input_type = GetOutputType(node->input(kOneNum));
  AddClipOp(input_x_name, node_name, 0.0f, 6.0f, onnx_input_type, graph_proto);
}

void OnnxExporter::ExportPrimDepthwiseConv2d(const FuncGraphPtr &, const CNodePtr &node,
                                             std::map<AnfNodePtr, std::string> *node_map_ptr,
                                             onnx::GraphProto *const graph_proto) {
  auto input_x = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto input_w = GetNodeInputName(node->input(kTwoNum), node_map_ptr, graph_proto);
  auto x_shape = dyn_cast<abstract::Shape>(node->input(kOneNum)->Shape());
  auto w_shape = dyn_cast<abstract::Shape>(node->input(kTwoNum)->Shape());
  MS_EXCEPTION_IF_NULL(x_shape);
  MS_EXCEPTION_IF_NULL(w_shape);
  if (x_shape->shape().size() != kFourNum || w_shape->shape().size() != kFourNum) {
    MS_LOG(EXCEPTION) << "DepthwiseConv2d input shape should be 4d.";
  }
  if (w_shape->shape()[kZeroNum] != kOneNum && w_shape->shape()[kOneNum] != kOneNum) {
    MS_LOG(EXCEPTION) << "DepthwiseConv2d weight shape[0] != 1 and shape[1] != 1, cannot reshape";
  }
  // create w_shape constant node
  auto node_name = GenerateUniqueName();
  onnx::NodeProto *node_proto = graph_proto->add_node();
  auto name_w_shape = node_name;
  node_proto->add_output(name_w_shape);
  node_proto->set_op_type("Constant");
  // create value tensor holding the weight shape with the first two dimensions swapped
  onnx::AttributeProto *attr_proto = node_proto->add_attribute();
  attr_proto->set_name("value");
  attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
  onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
  tensor_proto->add_dims(static_cast<::google::protobuf::int64>(w_shape->shape().size()));
  tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64);
  // reshape
  tensor_proto->add_int64_data(w_shape->shape()[kOneNum]);
  tensor_proto->add_int64_data(w_shape->shape()[kZeroNum]);
  tensor_proto->add_int64_data(w_shape->shape()[kTwoNum]);
  tensor_proto->add_int64_data(w_shape->shape()[kThreeNum]);

  // add reshape node
  node_name = GenerateUniqueName();
  node_proto = graph_proto->add_node();
  node_proto->set_op_type(prim::kPrimReshape->name());
  node_proto->add_input(input_w);
  node_proto->add_input(name_w_shape);
  input_w = node_name;
  node_proto->add_output(input_w);

  // add conv node
  node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  node_proto = graph_proto->add_node();
  node_proto->set_op_type("Conv");
  node_proto->add_input(input_x);
  node_proto->add_input(input_w);
  node_proto->add_output(node_name);
  // set attributes
  AnfNodePtr op = node->input(0);
  auto op_value = dyn_cast<ValueNode>(op);
  auto prim = dyn_cast<Primitive>(op_value->value());
  // set dilations
  onnx::AttributeProto *onnx_attr_proto = node_proto->add_attribute();
  onnx_attr_proto->set_name("dilations");
  SetAttrTupleValueToProto<2>(prim->GetAttr("dilation"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto,
                              prim);
  // set group
  onnx_attr_proto = node_proto->add_attribute();
  onnx_attr_proto->set_name("group");
  onnx_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  onnx_attr_proto->set_i(x_shape->shape()[1]);
  // set kernel_shape
  onnx_attr_proto = node_proto->add_attribute();
  onnx_attr_proto->set_name("kernel_shape");
  SetAttrTupleValueToProto<0>(prim->GetAttr("kernel_size"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto,
                              prim);

  // set pad
  onnx_attr_proto = node_proto->add_attribute();
  int64_t attr_value;
  CheckAndConvertUtils::GetPadModEnumValue(prim->GetAttr("pad_mode"), &attr_value);
  onnx_attr_proto->set_name("auto_pad");
  onnx_attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
  if (attr_value == PadMode::VALID) {
    onnx_attr_proto->set_s("VALID");
  } else if (attr_value == PadMode::SAME) {
    onnx_attr_proto->set_s("SAME_UPPER");
  } else {
    onnx_attr_proto->set_name("pads");
    SetAttrTupleValueToProto(prim->GetAttr("pad_list"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, prim);
  }
  // set strides
  onnx_attr_proto = node_proto->add_attribute();
  onnx_attr_proto->set_name("strides");
  SetAttrTupleValueToProto<2>(prim->GetAttr("stride"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, prim);
}
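/*
 Worked example (illustrative, assuming the checked case where one of the
 first two weight dimensions is 1): a weight of shape [1, 32, 3, 3] is
 reshaped above to [32, 1, 3, 3] and exported as a grouped Conv with
 group = 32 (the input channel count), which is how ONNX encodes depthwise
 convolution.
*/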

void OnnxExporter::ExportPrimTile(const FuncGraphPtr &, const CNodePtr &node,
                                  std::map<AnfNodePtr, std::string> *node_map_ptr,
                                  onnx::GraphProto *const graph_proto) {
  auto name_x = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto multiples = node->input(kTwoNum);
  std::string name_multiples;
  if (multiples->isa<ValueNode>()) {
    onnx::NodeProto *node_proto = graph_proto->add_node();
    name_multiples = RegisterNodeWithUniqueName(multiples, node_map_ptr);
    node_proto->add_output(name_multiples);
    node_proto->set_op_type("Constant");
    onnx::AttributeProto *attr_proto = node_proto->add_attribute();
    attr_proto->set_name("value");
    attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
    ConvertTupleToTensor(dyn_cast<ValueNode>(multiples)->value(), attr_proto->mutable_t());
  } else {
    name_multiples = GetNodeInputName(multiples, node_map_ptr, graph_proto);
    MS_LOG(EXCEPTION) << "Need to insert an op to convert the variable from tuple to tensor for Tile.";
  }

  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_op_type("Tile");
  node_proto->add_output(node_name);
  node_proto->add_input(name_x);
  node_proto->add_input(name_multiples);
}

void OnnxExporter::ExportPrimSquare(const FuncGraphPtr &, const CNodePtr &node,
                                    std::map<AnfNodePtr, std::string> *node_map_ptr,
                                    onnx::GraphProto *const graph_proto) {
  auto name_x = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto name_exponent = GenerateUniqueName();
  onnx::NodeProto *node_proto_exp = graph_proto->add_node();
  node_proto_exp->add_output(name_exponent);

  node_proto_exp->set_op_type("Constant");
  onnx::AttributeProto *attr_proto = node_proto_exp->add_attribute();
  attr_proto->set_name("value");
  attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
  onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
  const float exponent_value = 2.0;
  tensor_proto->set_name("exponent");
  tensor_proto->add_dims(static_cast<::google::protobuf::int64>(1));
  tensor_proto->set_data_type(GetOnnxDataType(kNumberTypeFloat32));
  tensor_proto->add_float_data(exponent_value);

  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_op_type("Pow");
  node_proto->add_output(node_name);
  node_proto->add_input(name_x);
  node_proto->add_input(name_exponent);
}

void OnnxExporter::ExportPrimGatherV2(const FuncGraphPtr &, const CNodePtr &node,
                                      std::map<AnfNodePtr, std::string> *node_map_ptr,
                                      onnx::GraphProto *const graph_proto) {
  auto name_x = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto name_indices = GetNodeInputName(node->input(kTwoNum), node_map_ptr, graph_proto);
  auto axis = node->input(kThreeNum)->cast<ValueNodePtr>()->value();
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_op_type("Gather");
  node_proto->add_output(node_name);
  node_proto->add_input(name_x);
  node_proto->add_input(name_indices);
  onnx::AttributeProto *attr_proto = node_proto->add_attribute();
  attr_proto->set_name("axis");
  attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  attr_proto->set_i(static_cast<::google::protobuf::int64>(dyn_cast<Int64Imm>(axis)->value()));
}

/*
  This is a workaround for nodes with several outputs used at once.
  MatchAndMark cannot help here, because it only supports a single output.
  Proposed convention:
    * Nodes with several outputs are registered as
      `(*node_map_ptr)[node] = node_idx;`, just like nodes with a single output
    * Their outputs are named "{node_idx}_{output_idx}"
    * TupleGetItem automatically passes the outputs to the next nodes
  See OnnxExporter::ExportPrimTopK for a usage example
*/
void OnnxExporter::ExportPrimTupleGetItem(const FuncGraphPtr &, const CNodePtr &node,
                                          std::map<AnfNodePtr, std::string> *node_map_ptr,
                                          onnx::GraphProto *const graph_proto) {
  auto index = GetInt64Value(node->input(kTwoNum));
  auto input_node_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto input_name = MakeOutputName(input_node_name, index);

  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);

  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_op_type("Identity");
  node_proto->add_input(input_name);
  node_proto->add_output(node_name);
}

void OnnxExporter::ExportPrimTopK(const FuncGraphPtr &, const CNodePtr &node,
                                  std::map<AnfNodePtr, std::string> *node_map_ptr,
                                  onnx::GraphProto *const graph_proto) {
  auto x_input_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);

  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);

  auto k_input_name = node_name + "k_initializer";
  auto k = GetInt64Value(node->input(kTwoNum));
  AddInt64Tensor1DInitializer(k_input_name, {k}, graph_proto);

  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_op_type("TopK");
  node_proto->add_input(x_input_name);
  node_proto->add_input(k_input_name);
  node_proto->add_output(MakeOutputName(node_name, kZeroNum));  // Values
  auto indices_name = MakeOutputName(node_name, kOneNum);
  auto indices_cast_name = indices_name + "_cast";
  node_proto->add_output(indices_cast_name);

  onnx::AttributeProto *sorted_attr_proto = node_proto->add_attribute();
  sorted_attr_proto->set_name("sorted");
  sorted_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  auto sorted = GetOpAttribute<bool>(node, "sorted");
  sorted_attr_proto->set_i(sorted);
  AddCastOp(indices_cast_name, indices_name, onnx::TensorProto_DataType_INT32, graph_proto);
}
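/*
 The multi-output convention in practice (illustrative): for t = TopK(x, k)
 the ONNX node writes "{node_idx}_0" (values) and an intermediate
 "{node_idx}_1_cast" (int64 indices), which is then Cast to int32 into
 "{node_idx}_1" to match MindSpore; TupleGetItem(t, i) above resolves to
 "{node_idx}_{i}" via MakeOutputName.
*/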

// Based on mindspore/ccsrc/backend/kernel_compiler/cpu/boundingbox_decode_cpu_kernel.cc
void OnnxExporter::ExportPrimBoundingBoxDecode(const FuncGraphPtr &, const CNodePtr &node,
                                               std::map<AnfNodePtr, std::string> *node_map_ptr,
                                               onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);

  auto anchor_bbox_input_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto deltas_input_name = GetNodeInputName(node->input(kTwoNum), node_map_ptr, graph_proto);
  auto onnx_input_type = GetOutputType(node->input(kOneNum));

  auto means = GetOpAttributePtr<ValueTuple>(node, "means");
  std::vector<float> mean_values = GetValue<std::vector<float>>(means);
  auto means_name = node_name + "means_initializer";
  AddFloatTensor1DInitializer(means_name, mean_values, onnx_input_type, graph_proto);

  auto stds = GetOpAttributePtr<ValueTuple>(node, "stds");
  std::vector<float> std_values = GetValue<std::vector<float>>(stds);
  auto stds_name = node_name + "stds_initializer";
  AddFloatTensor1DInitializer(stds_name, std_values, onnx_input_type, graph_proto);

  auto wh_ratio_clip = GetOpAttribute<float>(node, "wh_ratio_clip");
  auto max_ratio = static_cast<float>(std::abs(std::log(wh_ratio_clip)));

  auto unstd_deltas_name = node_name + "unstd_deltas";
  auto sd_to_add_name = unstd_deltas_name + "__add";
  AddOp("Mul", {deltas_input_name, stds_name}, {sd_to_add_name}, graph_proto);
  AddOp("Add", {sd_to_add_name, means_name}, {unstd_deltas_name}, graph_proto);

  auto center_deltas_name = node_name + "center_deltas";
  auto log_scale_deltas_name = node_name + "log_scale_deltas";
  auto lsd_to_clip_name = log_scale_deltas_name + "__clip";
  AddSplitOp(unstd_deltas_name, {center_deltas_name, lsd_to_clip_name}, {kTwoNum, kTwoNum}, 1, graph_proto);
  AddClipOp(lsd_to_clip_name, log_scale_deltas_name, -max_ratio, max_ratio, onnx_input_type, graph_proto);

  auto anchor_starts_name = node_name + "anchor_starts";
  auto anchor_ends_name = node_name + "anchor_ends";
  AddSplitOp(anchor_bbox_input_name, {anchor_starts_name, anchor_ends_name}, {kTwoNum, kTwoNum}, 1, graph_proto);

  auto anchor_centers_name = node_name + "anchor_centers";
  auto anchor_dimensions_name = node_name + "anchor_dimensions";
  ConvertBoxesToXywh(anchor_starts_name, anchor_ends_name, anchor_centers_name, anchor_dimensions_name,
                     onnx_input_type, graph_proto);

  auto anchor_shifts_name = node_name + "anchor_shifts";
  AddOp("Mul", {anchor_dimensions_name, center_deltas_name}, {anchor_shifts_name}, graph_proto);
  auto result_centers_name = node_name + "result_centers";
  AddOp("Add", {anchor_centers_name, anchor_shifts_name}, {result_centers_name}, graph_proto);

  auto anchor_scales_name = node_name + "anchor_scales";
  AddOp("Exp", {log_scale_deltas_name}, {anchor_scales_name}, graph_proto);
  auto result_dimensions_name = node_name + "result_dimensions";
  AddOp("Mul", {anchor_dimensions_name, anchor_scales_name}, {result_dimensions_name}, graph_proto);

  auto result_starts_to_clip_name = node_name + "result_starts_to_clip";
  auto result_ends_to_clip_name = node_name + "result_ends_to_clip";
  ConvertBoxesToXyxy(result_centers_name, result_dimensions_name, result_starts_to_clip_name, result_ends_to_clip_name,
                     onnx_input_type, graph_proto);

  auto max_shape = GetOpAttributePtr<ValueTuple>(node, "max_shape");
  auto max_y = GetValue<int64_t>((*max_shape)[0]);
  auto max_x = GetValue<int64_t>((*max_shape)[1]);
  auto result_start_xs_name = node_name + "result_start_x";
  auto result_start_ys_name = node_name + "result_start_y";
  auto result_end_xs_name = node_name + "result_end_x";
  auto result_end_ys_name = node_name + "result_end_y";
  ClipPointsComponent(result_starts_to_clip_name, result_start_xs_name, static_cast<float>(max_x), 0, onnx_input_type,
                      graph_proto);
  ClipPointsComponent(result_starts_to_clip_name, result_start_ys_name, static_cast<float>(max_y), 1, onnx_input_type,
                      graph_proto);
  ClipPointsComponent(result_ends_to_clip_name, result_end_xs_name, static_cast<float>(max_x), 0, onnx_input_type,
                      graph_proto);
  ClipPointsComponent(result_ends_to_clip_name, result_end_ys_name, static_cast<float>(max_y), 1, onnx_input_type,
                      graph_proto);

  AddConcatOp({result_start_xs_name, result_start_ys_name, result_end_xs_name, result_end_ys_name}, node_name, kOneNum,
              graph_proto);
}
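/*
 Decoding sketch (an illustrative reading of the graph above): with
 unstandardized deltas d = delta * std + mean split into center deltas
 (dx, dy) and log-scale deltas (dw, dh) clipped to +-|log(wh_ratio_clip)|,
   center = anchor_center + anchor_size * (dx, dy)
   size   = anchor_size * exp((dw, dh))
 and the resulting corner coordinates are clipped to max_shape before the
 final concatenation.
*/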

void OnnxExporter::ExportPrimNMSWithMask(const FuncGraphPtr &, const CNodePtr &node,
                                         std::map<AnfNodePtr, std::string> *node_map_ptr,
                                         onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);

  auto bboxes_input_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto iou_threshold = GetOpAttribute<float>(node, "iou_threshold");
  auto selected_boxes_output_name = MakeOutputName(node_name, kZeroNum);
  auto selected_idx_output_name = MakeOutputName(node_name, kOneNum);
  auto selected_mask_output_name = MakeOutputName(node_name, kTwoNum);
  auto onnx_input_type = GetOutputType(node->input(kOneNum));

  // Preprocessing

  auto boxes_count_name = node_name + "max_output_boxes";
  auto max_output_boxes_to_squeeze_name = boxes_count_name + "_to_reshape";
  auto input_shape_name = node_name + "input_shape";
  AddOp("Shape", {bboxes_input_name}, {input_shape_name}, graph_proto);
  AddSliceOp(input_shape_name, max_output_boxes_to_squeeze_name, {0}, {1}, {0}, {1}, graph_proto);
  AddReshapeOp(max_output_boxes_to_squeeze_name, boxes_count_name, {}, graph_proto);

  auto scores_name = node_name + "scores";
  auto flat_scores_name = scores_name + "_flat";
  auto sorted_scores_name = flat_scores_name + "_sorted";
  auto scores_to_flatten_name = scores_name + "_to_reshape";
  auto descending_order_name = node_name + "descending_indices";
  const int BBOX_NUM_EL = 4;
  AddSliceOp(bboxes_input_name, scores_to_flatten_name, {BBOX_NUM_EL}, {BBOX_NUM_EL + 1}, {1}, {1}, graph_proto);
  AddReshapeOp(scores_to_flatten_name, flat_scores_name, {-1}, graph_proto);
  AddOp("TopK", {flat_scores_name, max_output_boxes_to_squeeze_name}, {sorted_scores_name, descending_order_name},
        graph_proto);
  AddReshapeOp(sorted_scores_name, scores_name, {1, 1, -1}, graph_proto);
  auto iou_threshold_name = node_name + "iou_threshold_initializer";
  AddFloatScalarInitializer(iou_threshold_name, iou_threshold, onnx::TensorProto_DataType_FLOAT, graph_proto);

  AddOp("Gather", {bboxes_input_name, descending_order_name}, {selected_boxes_output_name},
        graph_proto);  // Output 0: boxes
  auto boxes_name = node_name + "boxes";
  auto boxes_to_reshape_name = boxes_name + "_to_reshape";
  AddSliceOp(selected_boxes_output_name, boxes_to_reshape_name, {0}, {BBOX_NUM_EL}, {1}, {1}, graph_proto);
  AddReshapeOp(boxes_to_reshape_name, boxes_name, {1, -1, BBOX_NUM_EL}, graph_proto);

  if (onnx_input_type == onnx::TensorProto_DataType_FLOAT16) {
    auto fp32_boxes_name = boxes_name + "_fp32";
    AddCastOp(boxes_name, fp32_boxes_name, onnx::TensorProto_DataType_FLOAT, graph_proto);
    boxes_name = fp32_boxes_name;

    auto fp32_scores_name = scores_name + "_fp32";
    AddCastOp(scores_name, fp32_scores_name, onnx::TensorProto_DataType_FLOAT, graph_proto);
    scores_name = fp32_scores_name;
  }

  // NMS op

  auto selected_indices_name = node_name + "selected_indices";
  AddOp("NonMaxSuppression", {boxes_name, scores_name, boxes_count_name, iou_threshold_name}, {selected_indices_name},
        graph_proto);

  // Output 1: indices

  auto flat_indices_name = node_name + "flat_indices";
  auto flat_indices_to_squeeze_name = flat_indices_name + "__reshape";
  const int BOX_INDEX_POS = 2;
  AddSliceOp(selected_indices_name, flat_indices_to_squeeze_name, {BOX_INDEX_POS}, {BOX_INDEX_POS + 1}, {1}, {1},
             graph_proto);
  AddReshapeOp(flat_indices_to_squeeze_name, flat_indices_name, {-1}, graph_proto);

  auto zero_name = node_name + "zero_initializer";
  onnx::TensorProto *zero_initializer = graph_proto->add_initializer();
  zero_initializer->set_name(zero_name);
  zero_initializer->set_data_type(onnx::TensorProto_DataType_INT32);
  zero_initializer->add_int32_data(0);
  auto one_name = node_name + "one_initializer";
  onnx::TensorProto *one_initializer = graph_proto->add_initializer();
  one_initializer->set_name(one_name);
  one_initializer->set_data_type(onnx::TensorProto_DataType_INT32);
  one_initializer->add_int32_data(1);
  auto int32_boxes_count_name = boxes_count_name + "_int32";
  AddCastOp(boxes_count_name, int32_boxes_count_name, onnx::TensorProto_DataType_INT32, graph_proto);
  AddOp("Range", {zero_name, int32_boxes_count_name, one_name}, {selected_idx_output_name}, graph_proto);

  // Output 2: mask

  auto empty_mask_name = selected_mask_output_name + "__scatter";
  onnx::TensorProto *empty_mask_value_proto =
    AddConstantOfShapeOp(max_output_boxes_to_squeeze_name, empty_mask_name, graph_proto);
  empty_mask_value_proto->set_data_type(onnx::TensorProto_DataType_BOOL);
  empty_mask_value_proto->add_int32_data(0);

  auto true_elements_name = node_name + "true";
  auto true_elements_shape_name = true_elements_name + "_shape";
  AddOp("Shape", {flat_indices_name}, {true_elements_shape_name}, graph_proto);
  onnx::TensorProto *true_elements_value_proto =
    AddConstantOfShapeOp(true_elements_shape_name, true_elements_name, graph_proto);
  true_elements_value_proto->set_data_type(onnx::TensorProto_DataType_BOOL);
  true_elements_value_proto->add_int32_data(1);

  AddOp("ScatterElements", {empty_mask_name, flat_indices_name, true_elements_name}, {selected_mask_output_name},
        graph_proto);
}

void OnnxExporter::ExportPrimSplit(const FuncGraphPtr &, const CNodePtr &node,
                                   std::map<AnfNodePtr, std::string> *node_map_ptr,
                                   onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  auto input_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);

  auto axis = GetOpAttribute<int64_t>(node, "axis");
  auto output_num = GetOpAttribute<int64_t>(node, "output_num");
  if (output_num == 0) {
    MS_LOG(EXCEPTION) << "output_num must be > 0";
  }
  const auto &input_shape = dyn_cast<abstract::Shape>(node->input(kOneNum)->Shape())->shape();

  if (axis < 0 || static_cast<size_t>(axis) >= input_shape.size()) {
    MS_LOG(EXCEPTION) << "`axis` is out of range";
  }
  if (input_shape[static_cast<size_t>(axis)] % output_num != 0) {
    MS_LOG(EXCEPTION) << "Input dim is not divisible by `output_num`";
  }

  onnx::NodeProto *split_proto = graph_proto->add_node();
  split_proto->set_op_type("Split");
  split_proto->add_input(input_name);
  for (int64_t i = 0; i < output_num; ++i) {
    split_proto->add_output(MakeOutputName(node_name, i));
  }

  onnx::AttributeProto *axis_attr_proto = split_proto->add_attribute();
  axis_attr_proto->set_name("axis");
  axis_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  axis_attr_proto->set_i(axis);

  onnx::AttributeProto *split_attr_proto = split_proto->add_attribute();
  split_attr_proto->set_name("split");
  split_attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
  for (int64_t i = 0; i < output_num; ++i) {
    split_attr_proto->add_ints(input_shape[static_cast<size_t>(axis)] / output_num);
  }
}
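/*
 Worked example (illustrative): Split with axis=1 and output_num=3 on an
 input of shape [4, 6] passes the divisibility check (6 % 3 == 0) and emits
 an ONNX Split with split = [2, 2, 2], producing outputs "{node_idx}_0",
 "{node_idx}_1", "{node_idx}_2".
*/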

/*
  Based on mindspore-project/mindspore/ccsrc/backend/kernel_compiler/cpu/roi_align_cpu_kernel.cc
  Notes:
    * MS version uses avg pool, leaving corresponding ONNX attr as is
    * MS has two ROI end modes, implemented with pre-processing
 */
void OnnxExporter::ExportPrimROIAlign(const FuncGraphPtr &, const CNodePtr &node,
                                      std::map<AnfNodePtr, std::string> *node_map_ptr,
                                      onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  auto features_input_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto rois_input_name = GetNodeInputName(node->input(kTwoNum), node_map_ptr, graph_proto);
  auto onnx_input_type = GetOutputType(node->input(kOneNum));

  auto roi_indices_name = node_name + "roi_indices";
  auto roi_indices_column_name = roi_indices_name + "_column";
  auto roi_starts_name = node_name + "roi_starts";
  auto roi_ends_name = node_name + "roi_ends";
  AddSplitOp(rois_input_name, {roi_indices_column_name, roi_starts_name, roi_ends_name}, {1, kTwoNum, kTwoNum}, 1,
             graph_proto);

  // Indices transformation

  auto flat_roi_indices_name = roi_indices_name + "_flat";
  AddReshapeOp(roi_indices_column_name, flat_roi_indices_name, {-1}, graph_proto);
  auto int_roi_indices_name = roi_indices_name + "_int";
  // This should be fine if indices are whole numbers less than 2^23
  AddCastOp(flat_roi_indices_name, int_roi_indices_name, onnx::TensorProto_DataType_INT64, graph_proto);

  // ROI end mode

  auto roi_end_mode = GetOpAttribute<int64_t>(node, "roi_end_mode");
  auto roi_end_mode_name = node_name + "roi_end_mode_initializer";
  AddFloatScalarInitializer(roi_end_mode_name, static_cast<float>(roi_end_mode), onnx_input_type, graph_proto);

  auto corrected_roi_ends_name = roi_ends_name + "_corrected";
  AddOp("Add", {roi_ends_name, roi_end_mode_name}, {corrected_roi_ends_name}, graph_proto);

  // Concatenate ROIs

  auto corrected_rois_name = node_name + "corrected_rois";
  AddConcatOp({roi_starts_name, corrected_roi_ends_name}, corrected_rois_name, kOneNum, graph_proto);

  // RoiAlign op

  onnx::NodeProto *roi_align_proto = graph_proto->add_node();
  roi_align_proto->set_op_type("RoiAlign");
  roi_align_proto->add_input(features_input_name);
  roi_align_proto->add_input(corrected_rois_name);
  roi_align_proto->add_input(int_roi_indices_name);
  roi_align_proto->add_output(node_name);
  onnx::AttributeProto *height_attr_proto = roi_align_proto->add_attribute();
  height_attr_proto->set_name("output_height");
  height_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  height_attr_proto->set_i(GetOpAttribute<int64_t>(node, "pooled_height"));
  onnx::AttributeProto *width_attr_proto = roi_align_proto->add_attribute();
  width_attr_proto->set_name("output_width");
  width_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  width_attr_proto->set_i(GetOpAttribute<int64_t>(node, "pooled_width"));
  onnx::AttributeProto *scale_attr_proto = roi_align_proto->add_attribute();
  scale_attr_proto->set_name("spatial_scale");
  scale_attr_proto->set_type(onnx::AttributeProto_AttributeType_FLOAT);
  scale_attr_proto->set_f(GetOpAttribute<float>(node, "spatial_scale"));
  onnx::AttributeProto *sampling_ratio_attr_proto = roi_align_proto->add_attribute();
  sampling_ratio_attr_proto->set_name("sampling_ratio");
  sampling_ratio_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  sampling_ratio_attr_proto->set_i(GetOpAttribute<int64_t>(node, "sample_num"));
}

void OnnxExporter::ExportPrimSlice(const FuncGraphPtr &, const CNodePtr &node,
                                   std::map<AnfNodePtr, std::string> *node_map_ptr,
                                   onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  auto input_x_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto begin_input_name = GetNodeInputName(node->input(kTwoNum), node_map_ptr, graph_proto);
  auto size_input_name = GetNodeInputName(node->input(kThreeNum), node_map_ptr, graph_proto);

  auto end_name = node_name + "end";
  AddOp("Add", {begin_input_name, size_input_name}, {end_name}, graph_proto);
  AddOp("Slice", {input_x_name, begin_input_name, end_name}, {node_name}, graph_proto);
}
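/*
 Worked example (illustrative): MindSpore Slice(x, begin=(1, 0), size=(2, 3))
 maps to ONNX Slice with starts = begin and ends = begin + size = (3, 3),
 where the ends tensor is computed by the Add node above.
*/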

void OnnxExporter::ExportPrimOnesLike(const FuncGraphPtr &, const CNodePtr &node,
                                      std::map<AnfNodePtr, std::string> *node_map_ptr,
                                      onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  auto input_x_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);

  auto shape_name = node_name + "shape";
  AddOp("Shape", {input_x_name}, {shape_name}, graph_proto);

  auto dtype = node->input(kOneNum)->Type();
  auto elem_type = dyn_cast<TensorType>(dtype)->element()->type_id();

  onnx::TensorProto *one_proto = AddConstantOfShapeOp(shape_name, node_name, graph_proto);
  switch (elem_type) {
    case kNumberTypeInt32:
      one_proto->set_data_type(onnx::TensorProto_DataType_INT32);
      one_proto->add_int32_data(1);
      break;
    case kNumberTypeInt64:
      one_proto->set_data_type(onnx::TensorProto_DataType_INT64);
      one_proto->add_int64_data(1);
      break;
    case kNumberTypeFloat32:
      one_proto->set_data_type(onnx::TensorProto_DataType_FLOAT);
      one_proto->add_float_data(1.0f);
      break;
    case kNumberTypeFloat64:
      one_proto->set_data_type(onnx::TensorProto_DataType_DOUBLE);
      one_proto->add_double_data(1.0);
      break;
    default:
      MS_LOG(EXCEPTION) << "Unsupported dtype: " << elem_type;
  }
}
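/*
 Illustrative expansion: OnesLike(x) for a float32 x becomes
 Shape(x) -> ConstantOfShape(value = 1.0f), i.e. a tensor of ones with x's
 runtime shape; the switch above picks the fill value's ONNX dtype to match
 x's element type.
*/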
2979 
ExportPrimScatterNd(const FuncGraphPtr &,const CNodePtr & node,std::map<AnfNodePtr,std::string> * node_map_ptr,onnx::GraphProto * const graph_proto)2980 void OnnxExporter::ExportPrimScatterNd(const FuncGraphPtr &, const CNodePtr &node,
2981                                        std::map<AnfNodePtr, std::string> *node_map_ptr,
2982                                        onnx::GraphProto *const graph_proto) {
2983   auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
2984   auto input_indices_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
2985   auto input_update_name = GetNodeInputName(node->input(kTwoNum), node_map_ptr, graph_proto);
2986   auto input_shape_name = GetNodeInputName(node->input(kThreeNum), node_map_ptr, graph_proto);
2987   auto node_zero_tensor_name = node_name + "_zero";
2988   auto dtype = node->input(kTwoNum)->Type();
2989   auto elem_type = dyn_cast<TensorType>(dtype)->element()->type_id();
2990 
2991   onnx::TensorProto *zero_proto = AddConstantOfShapeOp(input_shape_name, node_zero_tensor_name, graph_proto);
2992   switch (elem_type) {
2993     case kNumberTypeInt32:
2994       zero_proto->set_data_type(onnx::TensorProto_DataType_INT32);
2995       zero_proto->add_int32_data(0);
2996       break;
2997     case kNumberTypeInt64:
2998       zero_proto->set_data_type(onnx::TensorProto_DataType_INT64);
2999       zero_proto->add_int64_data(0);
3000       break;
3001     case kNumberTypeFloat32:
3002       zero_proto->set_data_type(onnx::TensorProto_DataType_FLOAT);
3003       zero_proto->add_float_data(0.0f);
3004       break;
3005     case kNumberTypeFloat64:
3006       zero_proto->set_data_type(onnx::TensorProto_DataType_DOUBLE);
3007       zero_proto->add_double_data(0.0);
3008       break;
3009     default:
3010       MS_LOG(EXCEPTION) << "Unsupported dtype: " << elem_type;
3011   }
3012   auto int64_indices_name = input_indices_name + "_int64";
3013   AddCastOp(input_indices_name, int64_indices_name, onnx::TensorProto_DataType_INT64, graph_proto);
3014 
3015   // Create ScatterND node
3016   onnx::NodeProto *scatternd_proto = graph_proto->add_node();
3017   scatternd_proto->set_op_type("ScatterND");
3018   scatternd_proto->add_input(node_zero_tensor_name);
3019   scatternd_proto->add_input(int64_indices_name);
3020   scatternd_proto->add_input(input_update_name);
3021   scatternd_proto->add_output(node_name);
3022 }
3023 
void OnnxExporter::ExportPrimArgMaxWithValue(const FuncGraphPtr &, const CNodePtr &node,
                                             std::map<AnfNodePtr, std::string> *node_map_ptr,
                                             onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  auto input_x_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto axis = GetOpAttribute<int64_t>(node, "axis");
  auto keep_dims = GetOpAttribute<bool>(node, "keep_dims");

  auto indices_output_name = MakeOutputName(node_name, kZeroNum);
  auto indices_cast_name = indices_output_name + "_cast";

  onnx::NodeProto *argmax_proto = graph_proto->add_node();
  argmax_proto->set_op_type("ArgMax");
  argmax_proto->add_input(input_x_name);
  argmax_proto->add_output(indices_cast_name);
  onnx::AttributeProto *argmax_axis_attr_proto = argmax_proto->add_attribute();
  argmax_axis_attr_proto->set_name("axis");
  argmax_axis_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  argmax_axis_attr_proto->set_i(axis);
  onnx::AttributeProto *argmax_keepdims_attr_proto = argmax_proto->add_attribute();
  argmax_keepdims_attr_proto->set_name("keepdims");
  argmax_keepdims_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  argmax_keepdims_attr_proto->set_i(static_cast<int64_t>(keep_dims));

  AddCastOp(indices_cast_name, indices_output_name, onnx::TensorProto_DataType_INT32, graph_proto);

  auto max_output_name = MakeOutputName(node_name, kOneNum);
  AddReduceOp("ReduceMax", input_x_name, max_output_name, {axis}, keep_dims, graph_proto);
}

void OnnxExporter::ExportPrimArgMinWithValue(const FuncGraphPtr &, const CNodePtr &node,
                                             std::map<AnfNodePtr, std::string> *node_map_ptr,
                                             onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  auto input_x_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto axis = GetOpAttribute<int64_t>(node, "axis");
  auto keep_dims = GetOpAttribute<bool>(node, "keep_dims");

  auto indices_output_name = MakeOutputName(node_name, kZeroNum);
  auto indices_cast_name = indices_output_name + "_cast";

  onnx::NodeProto *argmin_proto = graph_proto->add_node();
  argmin_proto->set_op_type("ArgMin");
  argmin_proto->add_input(input_x_name);
  argmin_proto->add_output(indices_cast_name);
  onnx::AttributeProto *argmin_axis_attr_proto = argmin_proto->add_attribute();
  argmin_axis_attr_proto->set_name("axis");
  argmin_axis_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  argmin_axis_attr_proto->set_i(axis);
  onnx::AttributeProto *argmin_keepdims_attr_proto = argmin_proto->add_attribute();
  argmin_keepdims_attr_proto->set_name("keepdims");
  argmin_keepdims_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  argmin_keepdims_attr_proto->set_i(static_cast<int64_t>(keep_dims));

  AddCastOp(indices_cast_name, indices_output_name, onnx::TensorProto_DataType_INT32, graph_proto);

  auto min_output_name = MakeOutputName(node_name, kOneNum);
  AddReduceOp("ReduceMin", input_x_name, min_output_name, {axis}, keep_dims, graph_proto);
}

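/*
  MindSpore OneHot takes separate scalar on_value/off_value inputs, while ONNX OneHot
  takes one rank-1 `values` tensor laid out as [off_value, on_value]; hence the
  reshape-and-concat below. int32 indices are widened to int64 for runtime compatibility.
 */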
void OnnxExporter::ExportPrimOneHot(const FuncGraphPtr &, const CNodePtr &node,
                                    std::map<AnfNodePtr, std::string> *node_map_ptr,
                                    onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  auto indices_input_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto depth_input_name = GetNodeInputName(node->input(kTwoNum), node_map_ptr, graph_proto);
  auto on_input_name = GetNodeInputName(node->input(kThreeNum), node_map_ptr, graph_proto);
  auto off_input_name = GetNodeInputName(node->input(kFourNum), node_map_ptr, graph_proto);
  auto axis = GetOpAttribute<int64_t>(node, "axis");

  if (GetOutputType(node->input(kOneNum)) == onnx::TensorProto_DataType_INT32) {
    auto indices_cast_name = node_name + "_indices_as_int64";
    AddCastOp(indices_input_name, indices_cast_name, onnx::TensorProto_DataType_INT64, graph_proto);
    indices_input_name = indices_cast_name;
  }

  auto on_1d_name = node_name + "_on_1d";
  AddReshapeOp(on_input_name, on_1d_name, {-1}, graph_proto);
  auto off_1d_name = node_name + "_off_1d";
  AddReshapeOp(off_input_name, off_1d_name, {-1}, graph_proto);

  // ONNX expects the `values` input as [off_value, on_value]
  auto on_off_name = node_name + "_on_off";
  AddConcatOp({off_1d_name, on_1d_name}, on_off_name, kZeroNum, graph_proto);

  onnx::NodeProto *one_hot_proto = graph_proto->add_node();
  one_hot_proto->set_op_type("OneHot");
  one_hot_proto->add_input(indices_input_name);
  one_hot_proto->add_input(depth_input_name);
  one_hot_proto->add_input(on_off_name);
  one_hot_proto->add_output(node_name);
  onnx::AttributeProto *one_hot_axis_attr_proto = one_hot_proto->add_attribute();
  one_hot_axis_attr_proto->set_name("axis");
  one_hot_axis_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  one_hot_axis_attr_proto->set_i(axis);
}

/*
  Based on nn.Conv2dTranspose
  Warning: `output_shape` is an input in MS and an attribute in ONNX. Hence
           it is not possible to change the output shape at runtime
 */
void OnnxExporter::PrimConv2DTransposeExportHelper(const CNodePtr &conv_node, const CNodePtr &bias_add_node,
                                                   std::map<AnfNodePtr, std::string> *node_map_ptr,
                                                   onnx::GraphProto *const graph_proto) {
  std::string node_name;

  std::vector<AnfNodePtr> inputs{conv_node->input(kOneNum), conv_node->input(kTwoNum)};
  if (bias_add_node != nullptr) {
    inputs.push_back(bias_add_node->input(kTwoNum));
    node_name = RegisterNodeWithUniqueName(bias_add_node, node_map_ptr);
  } else {
    node_name = RegisterNodeWithUniqueName(conv_node, node_map_ptr);
  }

  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_op_type("ConvTranspose");
  for (const auto &input : inputs) {
    node_proto->add_input(GetNodeInputName(input, node_map_ptr, graph_proto));
  }
  node_proto->add_output(node_name);

  auto prim = GetPrimitive(conv_node);
  auto attrs_convert_info =
    OpNameInfo()
      .Attr("dilation", "dilations", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<kTwoNum>)
      .Attr("group", "group", onnx::AttributeProto_AttributeType_INT, SetAttrValueToProto<Int64Imm>)
      .Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<0>)
      .Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetConvTransposePadding)
      .Attr("stride", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<kTwoNum>);
  for (const auto &attr_info : attrs_convert_info.op_attrs()) {
    onnx::AttributeProto *attr_proto = node_proto->add_attribute();
    attr_proto->set_name(attr_info.onnx_attr_name());
    auto ms_attr = GetOpAttributePtr<Value>(conv_node, attr_info.attr_name());
    MS_EXCEPTION_IF_NULL(ms_attr);
    attr_info.fn_gen_attr()(ms_attr, attr_info.onnx_attr_type(), attr_proto, prim);
  }

  // Set output shape

  auto input_shape_node = GetRealInput(conv_node->input(kThreeNum));
  if (!input_shape_node->isa<ValueNode>()) {
    MS_LOG(EXCEPTION) << "For ONNX export, the third argument must be a constant "
                         "(Python tuple). Instead got "
                      << input_shape_node->ToString();
  }
  auto input_shape_value_ptr = input_shape_node->cast<ValueNodePtr>()->value();
  if (!input_shape_value_ptr->isa<ValueTuple>()) {
    MS_LOG(EXCEPTION) << "Expected ValueTuple, got " << input_shape_value_ptr->ToString() << " of type "
                      << input_shape_value_ptr->type()->ToString();
  }

  onnx::AttributeProto *output_shape_attr_proto = node_proto->add_attribute();
  output_shape_attr_proto->set_name("output_shape");
  SetAttrTupleValueToProto<0>(input_shape_value_ptr, onnx::AttributeProto_AttributeType_INTS, output_shape_attr_proto,
                              prim);
}

void OnnxExporter::ExportPrimConv2DTranspose(const FuncGraphPtr &, const CNodePtr &node,
                                             std::map<AnfNodePtr, std::string> *node_map_ptr,
                                             onnx::GraphProto *graph_proto) {
  PrimConv2DTransposeExportHelper(node, nullptr, node_map_ptr, graph_proto);
}

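/*
  Opset 11 predates GreaterOrEqual/LessOrEqual, so the three comparison exporters below
  go through the negated op, e.g. x >= y  ==>  Not(Less(x, y)). For example, with
  x = [1, 2, 3] and y = [2, 2, 2]: Less -> [true, false, false], Not -> [false, true, true].
 */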
void OnnxExporter::ExportPrimGreaterEqual(const FuncGraphPtr &, const CNodePtr &node,
                                          std::map<AnfNodePtr, std::string> *node_map_ptr,
                                          onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);

  auto input_x_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto input_y_name = GetNodeInputName(node->input(kTwoNum), node_map_ptr, graph_proto);
  auto less_name = node_name + "less";

  AddOp("Less", {input_x_name, input_y_name}, {less_name}, graph_proto);
  AddOp("Not", {less_name}, {node_name}, graph_proto);
}

void OnnxExporter::ExportPrimLessEqual(const FuncGraphPtr &, const CNodePtr &node,
                                       std::map<AnfNodePtr, std::string> *node_map_ptr,
                                       onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);

  auto input_x_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto input_y_name = GetNodeInputName(node->input(kTwoNum), node_map_ptr, graph_proto);
  auto greater_name = node_name + "greater";

  AddOp("Greater", {input_x_name, input_y_name}, {greater_name}, graph_proto);
  AddOp("Not", {greater_name}, {node_name}, graph_proto);
}

void OnnxExporter::ExportPrimNotEqual(const FuncGraphPtr &, const CNodePtr &node,
                                      std::map<AnfNodePtr, std::string> *node_map_ptr,
                                      onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  auto input_x_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto input_y_name = GetNodeInputName(node->input(kTwoNum), node_map_ptr, graph_proto);
  auto equal_name = node_name + "equal";

  AddOp("Equal", {input_x_name, input_y_name}, {equal_name}, graph_proto);
  AddOp("Not", {equal_name}, {node_name}, graph_proto);
}

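/*
  Dense arrives here as a matched `MatMul + BiasAdd` pair (cf. OP_MERGE_GEMM): the node is
  re-exported through the MatMul primitive with three inputs (x, y, b), which the generic
  primitive exporter is expected to lower to ONNX Gemm.
 */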
void OnnxExporter::ExportPrimDense(const FuncGraphPtr &func_graph, const CNodePtr &node,
                                   std::map<AnfNodePtr, std::string> *node_map_ptr,
                                   onnx::GraphProto *const graph_proto) {
  auto matmul_node = dyn_cast<CNode>(node->input(kOneNum));
  MS_EXCEPTION_IF_NULL(matmul_node);
  auto input_x = matmul_node->input(kOneNum);  // matmul input x
  auto input_y = matmul_node->input(kTwoNum);  // matmul input y
  auto input_b = node->input(kTwoNum);         // matmul bias

  PrimitivePtr prim_matmul = dyn_cast<Primitive>((dyn_cast<ValueNode>(matmul_node->input(kZeroNum)))->value());
  std::vector<AnfNodePtr> inputs{input_x, input_y, input_b};
  (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_matmul, inputs, graph_proto);
}

void OnnxExporter::ExportPrimSqueeze(const FuncGraphPtr &, const CNodePtr &node,
                                     std::map<AnfNodePtr, std::string> *node_map_ptr,
                                     onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);

  auto input_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);

  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_op_type("Squeeze");
  node_proto->add_input(input_name);
  node_proto->add_output(node_name);

  auto axes = GetOpAttributePtr<ValueSequence>(node, "axis");
  auto axes_value = GetValue<std::vector<int64_t>>(axes);
  if (!axes_value.empty()) {
    onnx::AttributeProto *axes_proto = node_proto->add_attribute();
    axes_proto->set_name("axes");
    axes_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
    for (auto axis : axes_value) {
      axes_proto->add_ints(axis);
    }
  }
}

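/*
  MindSpore packs LSTM gate weights in (i, f, c, o) order, while ONNX LSTM expects
  (i, o, f, c). MakeLSTMWeight reshapes the flat blob, splits it into four per-gate
  slices along axis 1, and concatenates them back in ONNX order:
    [W_i | W_f | W_c | W_o]  ->  [W_i | W_o | W_f | W_c]
 */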
void MakeLSTMWeight(const std::string &input, const std::string &output, const std::vector<int64_t> &output_shape,
                    onnx::GraphProto *graph_proto) {
  auto reshaped_name = output + "__split";
  AddReshapeOp(input, reshaped_name, output_shape, graph_proto);

  auto split_i_name = output + "__concat_i";
  auto split_o_name = output + "__concat_o";
  auto split_f_name = output + "__concat_f";
  auto split_c_name = output + "__concat_c";
  int64_t hidden_size = output_shape[kOneNum] / kFourNum;
  AddSplitOp(reshaped_name, {split_i_name, split_f_name, split_c_name, split_o_name},
             {hidden_size, hidden_size, hidden_size, hidden_size}, 1, graph_proto);

  AddConcatOp({split_i_name, split_o_name, split_f_name, split_c_name}, output, 1, graph_proto);
}

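/*
  Same gate reordering for DynamicRNN weights, whose layout is (i, c, f, o) instead of
  (i, f, c, o).
 */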
void MakeLSTMWeight2(const std::string &input, const std::string &output, const std::vector<int64_t> &output_shape,
                     onnx::GraphProto *graph_proto) {
  auto reshaped_name = output + "__split";
  AddReshapeOp(input, reshaped_name, output_shape, graph_proto);

  auto split_i_name = output + "__concat_i";
  auto split_o_name = output + "__concat_o";
  auto split_f_name = output + "__concat_f";
  auto split_c_name = output + "__concat_c";
  int64_t hidden_size = output_shape[kOneNum] / kFourNum;
  AddSplitOp(reshaped_name, {split_i_name, split_c_name, split_f_name, split_o_name},
             {hidden_size, hidden_size, hidden_size, hidden_size}, 1, graph_proto);

  AddConcatOp({split_i_name, split_o_name, split_f_name, split_c_name}, output, 1, graph_proto);
}

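/*
  DynamicRNN -> ONNX LSTM: the fused weight of shape [input_size + hidden_size,
  4 * hidden_size] is flattened, split into input/hidden halves, transposed to
  [num_dir, 4 * hidden_size, *], and gate-reordered via MakeLSTMWeight2. ONNX keeps two
  bias halves (Wb, Rb) that are summed at runtime, so the single DynamicRNN bias is
  duplicated and halved below to preserve Wb + Rb == bias.
 */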
void OnnxExporter::ExportPrimDynamicRNN(const FuncGraphPtr &, const CNodePtr &node,
                                        std::map<AnfNodePtr, std::string> *node_map_ptr,
                                        onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);

  auto x_input_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto weight_input_name = GetNodeInputName(node->input(kTwoNum), node_map_ptr, graph_proto);
  auto bias_input_name = GetNodeInputName(node->input(kThreeNum), node_map_ptr, graph_proto);
  auto init_h_input_name = GetNodeInputName(node->input(kFiveNum), node_map_ptr, graph_proto);
  auto init_c_input_name = GetNodeInputName(node->input(kSixNum), node_map_ptr, graph_proto);

  auto hidden_size = GetOpAttribute<int64_t>(node, "hidden_size");
  auto direction_input = GetOpAttribute<std::string>(node, "direction");
  auto x_input_shape = dyn_cast<abstract::Shape>(node->input(kOneNum)->Shape())->shape();
  auto seq_len = x_input_shape[0];
  auto batch_size = x_input_shape[1];
  auto num_dir = direction_input == "UNIDIRECTIONAL" ? 1 : 2;
  auto input_size = x_input_shape[kTwoNum];

  auto onnx_input_weights_name = node_name + "_onnx_input_weights";
  auto onnx_hidden_weights_name = node_name + "_onnx_hidden_weights";
  auto onnx_bias_name = node_name + "_onnx_bias";

  const int num_gates = 4;
  auto gate_size = num_gates * hidden_size;

  auto weight_input_name_reshape = weight_input_name + "_reshape";
  AddReshapeOp(weight_input_name, weight_input_name_reshape, {(input_size * gate_size) + (hidden_size * gate_size)},
               graph_proto);

  auto input_weights_name = node_name + "_input_weights";
  auto hidden_weights_name = node_name + "_hidden_weights";
  std::vector<int64_t> split_sizes = {input_size * gate_size, hidden_size * gate_size};
  std::vector<std::string> split_outputs = {input_weights_name, hidden_weights_name};

  AddSplitOp(weight_input_name_reshape, split_outputs, split_sizes, 0, graph_proto);

  auto input_weights_name_reshape = input_weights_name + "_reshape";
  auto hidden_weights_name_reshape = hidden_weights_name + "_reshape";
  AddReshapeOp(input_weights_name, input_weights_name_reshape, {num_dir, input_size, gate_size}, graph_proto);
  AddReshapeOp(hidden_weights_name, hidden_weights_name_reshape, {num_dir, hidden_size, gate_size}, graph_proto);

  // Transpose input_weights_name
  onnx::NodeProto *transpose_node_proto_1 = graph_proto->add_node();
  auto input_weights_name_reshape_transposed = input_weights_name_reshape + "_transposed";
  transpose_node_proto_1->set_name(input_weights_name_reshape_transposed);
  transpose_node_proto_1->set_op_type("Transpose");
  transpose_node_proto_1->add_input(input_weights_name_reshape);
  transpose_node_proto_1->add_output(input_weights_name_reshape_transposed);

  onnx::AttributeProto *perm_proto_1 = transpose_node_proto_1->add_attribute();
  perm_proto_1->set_name("perm");
  perm_proto_1->set_type(onnx::AttributeProto_AttributeType_INTS);
  perm_proto_1->add_ints(kZeroNum);
  perm_proto_1->add_ints(kTwoNum);
  perm_proto_1->add_ints(kOneNum);

  // Transpose hidden_weights_name
  onnx::NodeProto *transpose_node_proto_2 = graph_proto->add_node();
  auto hidden_weights_name_reshape_transposed = hidden_weights_name_reshape + "_transposed";
  transpose_node_proto_2->set_name(hidden_weights_name_reshape_transposed);
  transpose_node_proto_2->set_op_type("Transpose");
  transpose_node_proto_2->add_input(hidden_weights_name_reshape);
  transpose_node_proto_2->add_output(hidden_weights_name_reshape_transposed);

  onnx::AttributeProto *perm_proto_2 = transpose_node_proto_2->add_attribute();
  perm_proto_2->set_name("perm");
  perm_proto_2->set_type(onnx::AttributeProto_AttributeType_INTS);
  perm_proto_2->add_ints(kZeroNum);
  perm_proto_2->add_ints(kTwoNum);
  perm_proto_2->add_ints(kOneNum);

  MakeLSTMWeight2(input_weights_name_reshape_transposed, onnx_input_weights_name, {num_dir, gate_size, input_size},
                  graph_proto);
  MakeLSTMWeight2(hidden_weights_name_reshape_transposed, onnx_hidden_weights_name, {num_dir, gate_size, hidden_size},
                  graph_proto);

  auto bias_input_name_reshape = bias_input_name + "_reshape";
  AddReshapeOp(bias_input_name, bias_input_name_reshape, {num_dir, gate_size}, graph_proto);

  auto bias_output_name = node_name + "_bias_output_name";
  MakeLSTMWeight2(bias_input_name_reshape, bias_output_name, {num_dir, gate_size}, graph_proto);

  auto bias_concat = bias_output_name + "_concat";
  std::vector<std::string> concat_inputs = {bias_output_name, bias_output_name};
  AddConcatOp(concat_inputs, bias_concat, 1, graph_proto);

  // Halve the duplicated bias so that ONNX's Wb + Rb equals the original bias
  // (the bias is assumed to be float16 here)
  auto div_second_operand_name = node_name + "_div_second_operand";
  const float div_second_operand = 2.0;
  AddFloatScalarInitializer(div_second_operand_name, div_second_operand, onnx::TensorProto_DataType_FLOAT16,
                            graph_proto);

  AddOp("Div", {bias_concat, div_second_operand_name}, {onnx_bias_name}, graph_proto);

  // Create LSTM node
  onnx::NodeProto *lstm_node_proto = graph_proto->add_node();
  lstm_node_proto->set_op_type("LSTM");
  lstm_node_proto->add_input(x_input_name);
  lstm_node_proto->add_input(onnx_input_weights_name);
  lstm_node_proto->add_input(onnx_hidden_weights_name);
  lstm_node_proto->add_input(onnx_bias_name);
  lstm_node_proto->add_input("");  // sequence_lens
  lstm_node_proto->add_input(init_h_input_name);
  lstm_node_proto->add_input(init_c_input_name);

  auto Y_output_name = node_name + "_Y";
  auto Y_h_output_name = node_name + "_Y_h";
  auto Y_c_output_name = node_name + "_Y_c";
  lstm_node_proto->add_output(Y_output_name);
  lstm_node_proto->add_output(Y_h_output_name);
  lstm_node_proto->add_output(Y_c_output_name);

  onnx::AttributeProto *hidden_size_proto = lstm_node_proto->add_attribute();
  hidden_size_proto->set_name("hidden_size");
  hidden_size_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  hidden_size_proto->set_i(hidden_size);

  auto output_name_Y = MakeOutputName(node_name, kZeroNum);
  auto output_name_Y_h = MakeOutputName(node_name, kOneNum);
  auto output_name_Y_c = MakeOutputName(node_name, kTwoNum);
  AddReshapeOp(Y_output_name, output_name_Y, {seq_len, batch_size, num_dir * hidden_size}, graph_proto);
  AddExpandOp(Y_h_output_name, output_name_Y_h, {seq_len, batch_size, hidden_size}, graph_proto);
  AddExpandOp(Y_c_output_name, output_name_Y_c, {seq_len, batch_size, hidden_size}, graph_proto);
}

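/*
  The packed P.LSTM weight blob differs by backend: on GPU it holds input weights, hidden
  weights, input bias, and hidden bias; on CPU it holds a single bias slice, so the
  missing hidden bias is synthesized as zeros below.
 */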
void ExportLSTMWeights(const CNodePtr &node, const std::string &node_name, const std::string &weights_name,
                       onnx::TensorProto_DataType dtype, const std::string &onnx_input_weights_name,
                       const std::string &onnx_hidden_weights_name, const std::string &onnx_bias_name,
                       onnx::GraphProto *graph_proto) {
  auto input_size = GetOpAttribute<int64_t>(node, "input_size");
  auto hidden_size = GetOpAttribute<int64_t>(node, "hidden_size");
  auto num_layers = GetOpAttribute<int64_t>(node, "num_layers");
  auto has_bias = GetOpAttribute<bool>(node, "has_bias");
  auto bidirectional = GetOpAttribute<bool>(node, "bidirectional");
  auto num_dir = 1 + static_cast<int>(bidirectional);
  auto num_gates = 4;
  auto gate_size = num_gates * hidden_size;

  if (num_layers != 1) {
    MS_LOG(EXCEPTION) << "Converter for multilayer LSTM is not implemented";
  }
  if (bidirectional) {
    MS_LOG(EXCEPTION) << "Bidirectional mode for P.LSTM is not implemented";
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto target_device = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  if (target_device != "CPU" && target_device != "GPU") {
    MS_LOG(EXCEPTION) << "Unsupported target device: " << target_device;
  }

  auto input_weights_name = node_name + "_input_weights";
  auto hidden_weights_name = node_name + "_hidden_weights";
  auto input_bias_name = node_name + "_input_bias";
  auto hidden_bias_name = node_name + "_hidden_bias";

  std::vector<int64_t> split_sizes = {input_size * gate_size, hidden_size * gate_size};
  std::vector<std::string> split_outputs = {input_weights_name, hidden_weights_name};
  if (has_bias) {
    if (target_device == "GPU") {
      (void)split_sizes.insert(split_sizes.end(), {gate_size, gate_size});
      (void)split_outputs.insert(split_outputs.end(), {input_bias_name, hidden_bias_name});
    } else if (target_device == "CPU") {
      split_sizes.push_back(gate_size);
      split_outputs.push_back(input_bias_name);
    } else {
      MS_LOG(EXCEPTION) << "Impossible branch";
    }
  }
  AddSplitOp(weights_name, split_outputs, split_sizes, 0, graph_proto);

  MakeLSTMWeight(input_weights_name, onnx_input_weights_name, {num_dir, gate_size, input_size}, graph_proto);
  MakeLSTMWeight(hidden_weights_name, onnx_hidden_weights_name, {num_dir, gate_size, hidden_size}, graph_proto);
  if (has_bias) {
    auto onnx_input_bias_name = node_name + "_onnx_input_bias";
    auto onnx_hidden_bias_name = node_name + "_onnx_hidden_bias";
    if (target_device == "GPU") {
      MakeLSTMWeight(input_bias_name, onnx_input_bias_name, {num_dir, gate_size}, graph_proto);
      MakeLSTMWeight(hidden_bias_name, onnx_hidden_bias_name, {num_dir, gate_size}, graph_proto);
    } else if (target_device == "CPU") {
      MakeLSTMWeight(input_bias_name, onnx_input_bias_name, {num_dir, gate_size}, graph_proto);
      // CPU weights carry only one bias, so the hidden bias is zero-filled
      auto bias_shape_name = node_name + "_bias_shape";
      AddOp("Shape", {onnx_input_bias_name}, {bias_shape_name}, graph_proto);
      onnx::TensorProto *zero_padding = AddConstantOfShapeOp(bias_shape_name, onnx_hidden_bias_name, graph_proto);
      zero_padding->set_data_type(dtype);
      if (dtype == onnx::TensorProto_DataType_FLOAT16) {
        zero_padding->add_int32_data(0);  // float 0 and int 0 have identical representations
      } else if (dtype == onnx::TensorProto_DataType_FLOAT) {
        zero_padding->add_float_data(0.0f);
      } else {
        MS_LOG(EXCEPTION) << "Unsupported type: " << dtype;
      }
    } else {
      MS_LOG(EXCEPTION) << "Impossible branch";
    }
    AddConcatOp({onnx_input_bias_name, onnx_hidden_bias_name}, onnx_bias_name, 1, graph_proto);
  }
}

void OnnxExporter::ExportPrimLSTM(const FuncGraphPtr &, const CNodePtr &node,
                                  std::map<AnfNodePtr, std::string> *node_map_ptr,
                                  onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);

  // MS inputs
  auto x_input_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto init_h_input_name = GetNodeInputName(node->input(kTwoNum), node_map_ptr, graph_proto);
  auto init_c_input_name = GetNodeInputName(node->input(kThreeNum), node_map_ptr, graph_proto);

  auto hidden_size = GetOpAttribute<int64_t>(node, "hidden_size");
  auto has_bias = GetOpAttribute<bool>(node, "has_bias");
  auto bidirectional = GetOpAttribute<bool>(node, "bidirectional");
  std::string direction = bidirectional ? "bidirectional" : "forward";
  auto x_input_shape = dyn_cast<abstract::Shape>(node->input(kOneNum)->Shape())->shape();
  auto seq_len = x_input_shape[0];
  auto batch_size = x_input_shape[1];
  auto num_dir = 1 + static_cast<int>(bidirectional);

  auto weights_name = GetNodeInputName(node->input(kFourNum), node_map_ptr, graph_proto);
  auto dtype = GetOutputType(node->input(kOneNum));
  auto onnx_input_weights_name = node_name + "_onnx_input_weights";
  auto onnx_hidden_weights_name = node_name + "_onnx_hidden_weights";
  auto onnx_bias_name = node_name + "_onnx_bias";

  ExportLSTMWeights(node, node_name, weights_name, dtype, onnx_input_weights_name, onnx_hidden_weights_name,
                    onnx_bias_name, graph_proto);

  // Create LSTM node
  onnx::NodeProto *lstm_node_proto = graph_proto->add_node();
  lstm_node_proto->set_op_type("LSTM");
  lstm_node_proto->add_input(x_input_name);
  lstm_node_proto->add_input(onnx_input_weights_name);
  lstm_node_proto->add_input(onnx_hidden_weights_name);
  lstm_node_proto->add_input(has_bias ? onnx_bias_name : "");
  lstm_node_proto->add_input("");  // seqlens
  lstm_node_proto->add_input(init_h_input_name);
  lstm_node_proto->add_input(init_c_input_name);

  auto Y_output_name = node_name + "_Y";
  lstm_node_proto->add_output(Y_output_name);
  lstm_node_proto->add_output(MakeOutputName(node_name, kOneNum));
  lstm_node_proto->add_output(MakeOutputName(node_name, kTwoNum));

  onnx::AttributeProto *hidden_size_proto = lstm_node_proto->add_attribute();
  hidden_size_proto->set_name("hidden_size");
  hidden_size_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  hidden_size_proto->set_i(hidden_size);

  onnx::AttributeProto *direction_proto = lstm_node_proto->add_attribute();
  direction_proto->set_name("direction");
  direction_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
  direction_proto->set_s(direction);

  // Transpose 1st output of the LSTM node
  onnx::NodeProto *transpose_node_proto = graph_proto->add_node();
  auto transpose_node_name = node_name + "_Y_transposed";
  transpose_node_proto->set_name(transpose_node_name);
  transpose_node_proto->set_op_type("Transpose");
  transpose_node_proto->add_input(Y_output_name);
  transpose_node_proto->add_output(transpose_node_name);

  onnx::AttributeProto *perm_proto = transpose_node_proto->add_attribute();
  perm_proto->set_name("perm");
  perm_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
  perm_proto->add_ints(kZeroNum);
  perm_proto->add_ints(kTwoNum);
  perm_proto->add_ints(kOneNum);
  perm_proto->add_ints(kThreeNum);

  // Reshape to the MS output layout [seq_len, batch_size, num_dir * hidden_size]
  auto output_name = MakeOutputName(node_name, kZeroNum);
  AddReshapeOp(transpose_node_name, output_name, {seq_len, batch_size, num_dir * hidden_size}, graph_proto);
}

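/*
  ReverseV2 maps to Slice with negative steps: along each reversed axis, start at -1 and
  step by -1 down to -(len + 1) exclusive. E.g. for len == 4, starts=-1, ends=-5,
  steps=-1 selects indices 3, 2, 1, 0.
 */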
void OnnxExporter::ExportPrimReverseV2(const FuncGraphPtr &, const CNodePtr &node,
                                       std::map<AnfNodePtr, std::string> *node_map_ptr,
                                       onnx::GraphProto *const graph_proto) {
  auto output = RegisterNodeWithUniqueName(node, node_map_ptr);
  auto input = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);

  auto axes_ptr = GetOpAttributePtr<ValueSequence>(node, "axis");
  auto axes_vec = GetValue<std::vector<int64_t>>(axes_ptr);
  size_t n_axes = axes_vec.size();
  auto shape = dyn_cast<abstract::Shape>(node->input(kOneNum)->Shape())->shape();

  std::vector<int64_t> starts_vec(n_axes, -1);
  std::vector<int64_t> ends_vec(n_axes);
  (void)std::transform(axes_vec.begin(), axes_vec.end(), ends_vec.begin(),
                       [&shape](int64_t ax) { return -shape.at(static_cast<size_t>(ax)) - 1; });
  std::vector<int64_t> steps_vec(n_axes, -1);

  AddSliceOp(input, output, starts_vec, ends_vec, axes_vec, steps_vec, graph_proto);
}

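/*
  TensorCopySlices is exported by flattening: since MindSpore only permits contiguous
  slices, the target region is one run [flat_begin, flat_end) of the flattened x, and the
  result is Concat(x[0:flat_begin], value, x[flat_end:]) reshaped back to x_shape.
  E.g. with x_shape = [2, 3], begin = [1, 0], end = [2, 3]: flat_begin = 3, flat_end = 6.
 */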
void OnnxExporter::ExportPrimTensorCopySlices(const FuncGraphPtr &, const CNodePtr &node,
                                              std::map<AnfNodePtr, std::string> *node_map_ptr,
                                              onnx::GraphProto *graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);

  auto x_input = node->input(kOneNum);
  auto value_input = node->input(kTwoNum);

  auto x_input_name = GetNodeInputName(x_input, node_map_ptr, graph_proto);
  auto value_input_name = GetNodeInputName(value_input, node_map_ptr, graph_proto);

  const auto &x_shape = dyn_cast<abstract::Shape>(x_input->Shape())->shape();
  const auto &value_shape = dyn_cast<abstract::Shape>(value_input->Shape())->shape();

  auto begin_node = dyn_cast<ValueNode>(node->input(kThreeNum));
  MS_EXCEPTION_IF_NULL(begin_node);
  auto begin = GetValue<std::vector<int64_t>>(begin_node->value());

  auto end_node = dyn_cast<ValueNode>(node->input(kFourNum));
  MS_EXCEPTION_IF_NULL(end_node);
  auto end = GetValue<std::vector<int64_t>>(end_node->value());

  auto strides_node = dyn_cast<ValueNode>(node->input(kFiveNum));
  MS_EXCEPTION_IF_NULL(strides_node);
  auto strides = GetValue<std::vector<int64_t>>(strides_node->value());

  MS_EXCEPTION_IF_CHECK_FAIL(
    begin.size() == end.size() && end.size() == strides.size() && strides.size() <= x_shape.size(),
    "Sizes of begin, end, and strides must be equal and must not exceed the rank of x");
  // MindSpore only allows contiguous slices of memory
  // A contiguous slice shape follows the pattern: [1, ..., 1, n, :, ..., :]
  bool found_slice = false;
  for (size_t i = 0; i < begin.size(); ++i) {
    int64_t dim = end[i] - begin[i];
    if (!found_slice && dim != 1) {
      found_slice = true;
    } else if (found_slice && dim != x_shape[i]) {
      MS_LOG(EXCEPTION) << "Slice must be contiguous";
    }
  }
  for (auto stride : strides) {
    MS_EXCEPTION_IF_CHECK_FAIL(stride == 1, "Slice must be contiguous");
  }

  int64_t flat_begin_index = RavelIndex(begin, x_shape);

  std::vector<int64_t> end_inclusive;
  (void)std::transform(end.begin(), end.end(), std::back_inserter(end_inclusive), [](auto x) { return x - 1; });
  (void)std::transform(x_shape.begin() + static_cast<int64_t>(end.size()), x_shape.end(),
                       std::back_inserter(end_inclusive), [](auto x) { return x - 1; });
  int64_t flat_end_index = RavelIndex(end_inclusive, x_shape) + 1;

  int64_t x_size = std::accumulate(x_shape.begin(), x_shape.end(), 1, std::multiplies<int64_t>());
  int64_t value_size = std::accumulate(value_shape.begin(), value_shape.end(), 1, std::multiplies<int64_t>());
  MS_EXCEPTION_IF_CHECK_FAIL(value_size == flat_end_index - flat_begin_index, "Cannot copy 'value' to target slice");

  auto flat_x_name = node_name + "_flat_x";
  AddReshapeOp(x_input_name, flat_x_name, {-1}, graph_proto);
  auto begin_slice_name = node_name + "_begin_slice";
  AddSliceOp(flat_x_name, begin_slice_name, {0}, {static_cast<int64_t>(flat_begin_index)}, {0}, {1}, graph_proto);
  auto end_slice_name = node_name + "_end_slice";
  AddSliceOp(flat_x_name, end_slice_name, {static_cast<int64_t>(flat_end_index)}, {x_size}, {0}, {1}, graph_proto);

  auto flat_value_name = node_name + "_flat_value";
  AddReshapeOp(value_input_name, flat_value_name, {-1}, graph_proto);

  auto flat_result_name = node_name + "_flat_result";
  AddConcatOp({begin_slice_name, flat_value_name, end_slice_name}, flat_result_name, 0, graph_proto);
  AddReshapeOp(flat_result_name, node_name, x_shape, graph_proto);
}

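/*
  Stack maps to ConcatFromSequence with new_axis = 1, which inserts a fresh dimension at
  `axis` before concatenating, matching MindSpore's Stack semantics.
 */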
void OnnxExporter::ExportPrimStack(const FuncGraphPtr &, const CNodePtr &node,
                                   std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);

  auto input_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);

  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_name(node_name + "Stack");
  node_proto->set_op_type("ConcatFromSequence");
  node_proto->add_input(input_name);
  node_proto->add_output(node_name);

  onnx::AttributeProto *axis_proto = node_proto->add_attribute();
  axis_proto->set_name("axis");
  axis_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  axis_proto->set_i(GetOpAttribute<int64_t>(node, "axis"));

  onnx::AttributeProto *new_axis_proto = node_proto->add_attribute();
  new_axis_proto->set_name("new_axis");
  new_axis_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  new_axis_proto->set_i(1);
}

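/*
  Opset 11 has no Atan2; the emulation below computes (up to the 1e-10 epsilon):
    atan2(x1, x2) = atan(x1 / (x2 + eps)) + (x2 < 0 ? -pi * sign(atan(x1 / (x2 + eps))) : 0)
  which shifts the atan result into the correct quadrant when x2 is negative.
 */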
void OnnxExporter::ExportPrimAtan2(const FuncGraphPtr &, const CNodePtr &node,
                                   std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  auto input_node1_anf = node->input(kOneNum);
  auto input_node2_anf = node->input(kTwoNum);
  auto input_node1 = GetNodeInputName(input_node1_anf, node_map_ptr, graph_proto);
  auto input_node2 = GetNodeInputName(input_node2_anf, node_map_ptr, graph_proto);
  auto atan_node = "Atan2_" + node_name + "_atan";
  auto div_node = "Atan2_" + node_name + "_div";
  auto less_node = "Atan2_" + node_name + "_less";
  auto zero_value = "Atan2_" + node_name + "_zero";
  auto neg_pi_value = "Atan2_" + node_name + "_pi";
  auto minimal_value = "Atan2_" + node_name + "_minimal_val";
  auto sign_node = "Atan2_" + node_name + "_sign";
  auto mul_node = "Atan2_" + node_name + "_mul";
  auto less_where_node1 = "Atan2_" + node_name + "_less_then_else1";
  auto add_node = "Atan2_" + node_name + "_add1";
  if (!(IsFloatDataType(input_node1_anf) && IsFloatDataType(input_node2_anf))) {
    auto input_node1_cast = node_name + "_div_cast_fp32_1";
    auto input_node2_cast = node_name + "_div_cast_fp32_2";
    AddCastOp(input_node1, input_node1_cast, onnx::TensorProto_DataType_FLOAT, graph_proto);
    AddCastOp(input_node2, input_node2_cast, onnx::TensorProto_DataType_FLOAT, graph_proto);
    input_node1 = input_node1_cast;
    input_node2 = input_node2_cast;
  }
  AddFloatScalarInitializer(minimal_value, 1e-10, onnx::TensorProto_DataType_FLOAT,
                            graph_proto);  // minimal_value, avoid division by zero
  AddOp("Add", {input_node2, minimal_value}, {add_node}, graph_proto);
  AddOp("Div", {input_node1, add_node}, {div_node}, graph_proto);
  AddOp("Atan", {div_node}, {atan_node}, graph_proto);
  AddFloatScalarInitializer(zero_value, 0, onnx::TensorProto_DataType_FLOAT, graph_proto);
  AddOp("Less", {input_node2, zero_value}, {less_node}, graph_proto);
  AddFloatScalarInitializer(neg_pi_value, -acos(-1), onnx::TensorProto_DataType_FLOAT, graph_proto);  // -PI
  AddOp("Sign", {atan_node}, {sign_node}, graph_proto);
  AddOp("Mul", {neg_pi_value, sign_node}, {mul_node}, graph_proto);
  AddOp("Where", {less_node, mul_node, zero_value}, {less_where_node1}, graph_proto);
  AddOp("Add", {less_where_node1, atan_node}, {node_name}, graph_proto);
}

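/*
  FloorDiv and FloorMod (below) are lowered through float arithmetic: integer inputs are
  cast to float, the result is computed with Div/Floor (plus Mul/Sub for FloorMod), and
  then cast back to the original integer type.
 */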
void OnnxExporter::ExportPrimFloorDiv(const FuncGraphPtr &, const CNodePtr &node,
                                      std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  auto out_name = node_name;
  auto input_x_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto input_y_name = GetNodeInputName(node->input(kTwoNum), node_map_ptr, graph_proto);
  auto onnx_type = GetOutputType(node->input(kOneNum));
  bool is_float = onnx_type == onnx::TensorProto_DataType_FLOAT;

  if (!is_float) {
    auto input_x_name_cast = input_x_name + "_cast";
    auto input_y_name_cast = input_y_name + "_cast";
    AddCastOp(input_x_name, input_x_name_cast, onnx::TensorProto_DataType_FLOAT, graph_proto);
    AddCastOp(input_y_name, input_y_name_cast, onnx::TensorProto_DataType_FLOAT, graph_proto);
    input_x_name = input_x_name_cast;
    input_y_name = input_y_name_cast;
    node_name = node_name + "_floor";  // write to a temp name; the final Cast restores out_name
  }

  auto div_name = node_name + "_div";
  AddOp("Div", {input_x_name, input_y_name}, {div_name}, graph_proto);
  AddOp("Floor", {div_name}, {node_name}, graph_proto);

  if (!is_float) {
    AddCastOp(node_name, out_name, onnx_type, graph_proto);
  }
}

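/*
  FloorMod follows floormod(x, y) = x - floor(x / y) * y; e.g. floormod(-7, 3)
  = -7 - floor(-7 / 3) * 3 = -7 - (-3) * 3 = 2, i.e. the result takes the divisor's sign.
 */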
void OnnxExporter::ExportPrimFloorMod(const FuncGraphPtr &, const CNodePtr &node,
                                      std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  auto out_name = node_name;
  auto input_x_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto input_y_name = GetNodeInputName(node->input(kTwoNum), node_map_ptr, graph_proto);
  auto onnx_type = GetOutputType(node->input(kOneNum));
  bool is_float = onnx_type == onnx::TensorProto_DataType_FLOAT;

  if (!is_float) {
    auto input_x_name_cast = input_x_name + "_cast";
    auto input_y_name_cast = input_y_name + "_cast";
    AddCastOp(input_x_name, input_x_name_cast, onnx::TensorProto_DataType_FLOAT, graph_proto);
    AddCastOp(input_y_name, input_y_name_cast, onnx::TensorProto_DataType_FLOAT, graph_proto);
    input_x_name = input_x_name_cast;
    input_y_name = input_y_name_cast;
    node_name = node_name + "_sub";  // write to a temp name; the final Cast restores out_name
  }

  auto div_name = node_name + "_div";
  auto mul_name = node_name + "_mul";
  auto floor_name = node_name + "_floor";
  AddOp("Div", {input_x_name, input_y_name}, {div_name}, graph_proto);
  AddOp("Floor", {div_name}, {floor_name}, graph_proto);
  AddOp("Mul", {floor_name, input_y_name}, {mul_name}, graph_proto);
  AddOp("Sub", {input_x_name, mul_name}, {node_name}, graph_proto);

  if (!is_float) {
    AddCastOp(node_name, out_name, onnx_type, graph_proto);
  }
}

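/*
  Sort maps to TopK with k equal to the full extent of the sorted axis, so every element
  is returned in order; `largest` mirrors the `descending` attribute, and indices are
  cast back to int32 to match MindSpore's output dtype.
 */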
void OnnxExporter::ExportPrimSort(const FuncGraphPtr &, const CNodePtr &node,
                                  std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);

  auto x_input = node->input(kOneNum);
  auto x_input_name = GetNodeInputName(x_input, node_map_ptr, graph_proto);
  auto x_input_shape = dyn_cast<abstract::Shape>(x_input->Shape())->shape();

  auto axis_attr = GetOpAttribute<int64_t>(node, "axis");
  auto descending_attr = GetOpAttribute<bool>(node, "descending");

  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_name(node_name + "TopK");
  node_proto->set_op_type("TopK");
  node_proto->add_input(x_input_name);

  onnx::TensorProto *k_initializer_proto = graph_proto->add_initializer();
  auto k_input_name = node_name + "_k";  // initializer names must be unique within the graph
  k_initializer_proto->set_name(k_input_name);
  k_initializer_proto->add_dims(static_cast<int64_t>(1));
  k_initializer_proto->set_data_type(GetOnnxDataType(kNumberTypeInt64));
  int64_t k_index = axis_attr;
  if (axis_attr < 0) {
    k_index += SizeToLong(x_input_shape.size());
  }
  if (k_index > SizeToLong(x_input_shape.size()) - 1 || k_index < 0) {
    MS_LOG(EXCEPTION) << "Invalid axis value: " << axis_attr;
  }
  int64_t k_value = x_input_shape[k_index];
  k_initializer_proto->add_int64_data(k_value);
  node_proto->add_input(k_input_name);

  node_proto->add_output(MakeOutputName(node_name, kZeroNum));
  auto indices_output_name = MakeOutputName(node_name, kOneNum);
  auto indices_cast_name = indices_output_name + "_cast";
  node_proto->add_output(indices_cast_name);
  AddCastOp(indices_cast_name, indices_output_name, onnx::TensorProto_DataType_INT32, graph_proto);

  onnx::AttributeProto *axis_attr_proto = node_proto->add_attribute();
  axis_attr_proto->set_name("axis");
  axis_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  axis_attr_proto->set_i(axis_attr);

  onnx::AttributeProto *largest_attr_proto = node_proto->add_attribute();
  largest_attr_proto->set_name("largest");
  largest_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  if (descending_attr) {
    largest_attr_proto->set_i(kOneNum);
  } else {
    largest_attr_proto->set_i(kZeroNum);
  }

  onnx::AttributeProto *sorted_attr_proto = node_proto->add_attribute();
  sorted_attr_proto->set_name("sorted");
  sorted_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  sorted_attr_proto->set_i(1);
}

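/*
  Custom ops may receive some attributes positionally as constant inputs. The exporter
  consults the `input_names`/`attr_names` attributes to tell attribute inputs apart from
  real data inputs, re-emits the former as ONNX attributes (string/int/float), and wires
  only the latter as node inputs.
 */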
void OnnxExporter::ExportPrimCustom(const FuncGraphPtr &, const CNodePtr &node,
                                    std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_name("Custom_" + node_name);
  mindspore::HashSet<size_t> input_attrs;

  constexpr auto kAttrCusInputNames = "input_names";
  constexpr auto kAttrCusAttrNames = "attr_names";
  auto input_names_vec = GetOpAttribute<std::vector<std::string>>(node, kAttrCusInputNames);
  auto primitive = GetPrimitive(node);
  auto attr_names = primitive->GetAttr(kAttrCusAttrNames);
  if (attr_names != nullptr) {
    auto attr_names_vec = GetValue<std::vector<std::string>>(attr_names);
    for (size_t i = 0; i < input_names_vec.size(); ++i) {
      if (std::find(attr_names_vec.begin(), attr_names_vec.end(), input_names_vec[i]) != attr_names_vec.end()) {
        (void)input_attrs.insert(i);
      }
    }
  }

  auto inputs = node->inputs();
  std::vector<AnfNodePtr> real_inputs;

  for (size_t i = 0; i < inputs.size() - 1; ++i) {
    auto input_node = inputs[i + 1];
    MS_EXCEPTION_IF_NULL(input_node);
    if (input_attrs.find(i) != input_attrs.end() && input_node->isa<ValueNode>() && !HasAbstractMonad(input_node)) {
      auto value_node = input_node->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(value_node);
      auto attr_value = value_node->value();
      if (attr_value->isa<StringImm>()) {
        auto str_attr = GetValue<std::string>(attr_value);
        onnx::AttributeProto *str_proto = node_proto->add_attribute();
        str_proto->set_name(input_names_vec[i]);
        str_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
        str_proto->set_s(str_attr);
      } else if (attr_value->isa<IntegerImm>()) {
        int64_t int64_attr = attr_value->cast<Int64ImmPtr>()->value();
        onnx::AttributeProto *int64_proto = node_proto->add_attribute();
        int64_proto->set_name(input_names_vec[i]);
        int64_proto->set_type(onnx::AttributeProto_AttributeType_INT);
        int64_proto->set_i(int64_attr);
      } else if (attr_value->isa<FloatImm>()) {
        float fp32_attr = attr_value->cast<FP32ImmPtr>()->value();
        onnx::AttributeProto *fp32_proto = node_proto->add_attribute();
        fp32_proto->set_name(input_names_vec[i]);
        fp32_proto->set_type(onnx::AttributeProto_AttributeType_FLOAT);
        fp32_proto->set_f(fp32_attr);
      } else {
        MS_LOG(EXCEPTION) << "Unsupported attr input type: " << attr_value->ToString();
      }
    } else {
      real_inputs.push_back(inputs[i + 1]);
    }
  }

  for (size_t idx = 0; idx < real_inputs.size(); idx++) {
    auto input_name = GetNodeInputName(real_inputs[idx], node_map_ptr, graph_proto);
    node_proto->add_input(input_name);
  }

  node_proto->add_output(node_name);
  node_proto->set_op_type(GetOpAttribute<std::string>(node, "reg_op_name"));
}

void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
                               std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *const graph_proto) {
  using ExportFunc = std::function<void(OnnxExporter *, const FuncGraphPtr &, const CNodePtr &,
                                        std::map<AnfNodePtr, std::string> *, onnx::GraphProto *const)>;
  static std::vector<std::pair<PrimitivePtr, ExportFunc>> export_table = {
    {prim::kPrimReshape, &OnnxExporter::ExportPrimReshape},
    {prim::kPrimReduceMean, &OnnxExporter::ExportPrimReduce},
    {prim::kPrimReduceSum, &OnnxExporter::ExportPrimReduce},
    {prim::kPrimReduceMax, &OnnxExporter::ExportPrimReduce},
    {prim::kPrimReduceAny, &OnnxExporter::ExportPrimReduceAnyOrAll},
    {prim::kPrimReduceAll, &OnnxExporter::ExportPrimReduceAnyOrAll},
    {prim::kPrimTranspose, &OnnxExporter::ExportPrimTranspose},
    {prim::kPrimStridedSlice, &OnnxExporter::ExportPrimStridedSlice},
    {prim::kPrimResizeNearestNeighbor, &OnnxExporter::ExportPrimResizeNearestNeighbor},
    {prim::kPrimResizeBilinearV2, &OnnxExporter::ExportPrimResizeBilinear},
    {prim::kPrimConcat, &OnnxExporter::ExportPrimConcat},
    {prim::kPrimCast, &OnnxExporter::ExportPrimCast},
    {prim::kPrimPReLU, &OnnxExporter::ExportPrimPReLU},
    {prim::kPrimReLU6, &OnnxExporter::ExportPrimReLU6},
    {prim::kPrimDepthwiseConv2dNative, &OnnxExporter::ExportPrimDepthwiseConv2d},
    {prim::kPrimTile, &OnnxExporter::ExportPrimTile},
    {prim::kPrimSquare, &OnnxExporter::ExportPrimSquare},
    {prim::kPrimGather, &OnnxExporter::ExportPrimGatherV2},
    {prim::kPrimTupleGetItem, &OnnxExporter::ExportPrimTupleGetItem},
    {prim::kPrimTopK, &OnnxExporter::ExportPrimTopK},
    {prim::kPrimBoundingBoxDecode, &OnnxExporter::ExportPrimBoundingBoxDecode},
    {prim::kPrimNMSWithMask, &OnnxExporter::ExportPrimNMSWithMask},
    {prim::kPrimSplit, &OnnxExporter::ExportPrimSplit},
    {prim::kPrimROIAlign, &OnnxExporter::ExportPrimROIAlign},
    {prim::kPrimSlice, &OnnxExporter::ExportPrimSlice},
    {prim::kPrimOnesLike, &OnnxExporter::ExportPrimOnesLike},
    {prim::kPrimScatterNd, &OnnxExporter::ExportPrimScatterNd},
    {prim::kPrimArgMaxWithValue, &OnnxExporter::ExportPrimArgMaxWithValue},
    {prim::kPrimArgMinWithValue, &OnnxExporter::ExportPrimArgMinWithValue},
    {prim::kPrimOneHot, &OnnxExporter::ExportPrimOneHot},
    {prim::kPrimConv2DTranspose, &OnnxExporter::ExportPrimConv2DTranspose},
    {prim::kPrimGreaterEqual, &OnnxExporter::ExportPrimGreaterEqual},
    {prim::kPrimLessEqual, &OnnxExporter::ExportPrimLessEqual},
    {prim::kPrimNotEqual, &OnnxExporter::ExportPrimNotEqual},
    {prim::kPrimDense, &OnnxExporter::ExportPrimDense},
    {prim::kPrimSqueeze, &OnnxExporter::ExportPrimSqueeze},
    {prim::kPrimExpandDims, &OnnxExporter::ExportPrimExpandDims},
    {prim::kPrimGatherD, &OnnxExporter::ExportPrimGatherD},
    {prim::kPrimPad, &OnnxExporter::ExportPrimPad},
    {prim::kPrimBatchMatMul, &OnnxExporter::ExportPrimBatchMatMul},
    {prim::kPrimBroadcastTo, &OnnxExporter::ExportPrimBroadcastTo},
    {prim::kPrimAddN, &OnnxExporter::ExportPrimAddN},
    {prim::kPrimGeLU, &OnnxExporter::ExportPrimGeLU},
    {prim::kPrimLstm, &OnnxExporter::ExportPrimLSTM},
    {prim::kPrimReverseV2, &OnnxExporter::ExportPrimReverseV2},
    {prim::kPrimTensorCopySlices, &OnnxExporter::ExportPrimTensorCopySlices},
    {prim::kPrimDynamicRNN, &OnnxExporter::ExportPrimDynamicRNN},
    {prim::kPrimStack, &OnnxExporter::ExportPrimStack},
    {prim::kPrimAtan2, &OnnxExporter::ExportPrimAtan2},
    {prim::kPrimFloorDiv, &OnnxExporter::ExportPrimFloorDiv},
    {prim::kPrimFloorMod, &OnnxExporter::ExportPrimFloorMod},
    {prim::kPrimSort, &OnnxExporter::ExportPrimSort},
    {prim::kPrimCustom, &OnnxExporter::ExportPrimCustom},
  };

  auto iter = std::find_if(export_table.begin(), export_table.end(),
                           [&node](const auto &item) { return node->IsApply(item.first); });
  if (iter != export_table.end()) {
    iter->second(this, func_graph, node, node_map_ptr, graph_proto);
    return;
  }

  auto inputs = node->inputs();
  if (inputs.empty()) {
    MS_LOG(EXCEPTION) << "Inputs of apply node are empty";
  }

  AnfNodePtr op = inputs[kZeroNum];
  std::vector<AnfNodePtr> op_inputs;
  // Process node inputs 1, 2, ... first, since a ValueNode input requires creating a Constant operator here
  for (size_t i = 1; i < inputs.size(); i++) {
    if (!HasAbstractMonad(inputs[i])) {
      op_inputs.push_back(inputs[i]);
    }
  }

  if (!op->isa<ValueNode>()) {
    MS_LOG(EXCEPTION) << "Unsupported node op type: " << op->type_name();
  }

  auto op_value = dyn_cast<ValueNode>(op)->value();
  if (op_value->isa<Primitive>()) {
    auto prim = dyn_cast<Primitive>(op_value);
    (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim, op_inputs, graph_proto);
  } else if (while_loop_export::IsControlSubgraph(op_value)) {
    ExportWhileLoop(node, node_map_ptr, graph_proto);
  } else {
    MS_LOG(EXCEPTION) << "Unsupported node op value type: " << op_value->type_name();
  }
}

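/*
  While loops are exported as ONNX Loop: the trip count M is computed statically as
  (end - begin) / step from the matched counter, the loop condition is a constant `true`,
  and the body subgraph carries one state variable per used loop parameter.
 */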
ExportWhileLoop(const CNodePtr & start_node,std::map<AnfNodePtr,std::string> * node_map_ptr,onnx::GraphProto * graph_proto)4001 void OnnxExporter::ExportWhileLoop(const CNodePtr &start_node, std::map<AnfNodePtr, std::string> *node_map_ptr,
4002                                    onnx::GraphProto *graph_proto) {
4003   auto node_name = RegisterNodeWithUniqueName(start_node, node_map_ptr);
4004   auto loop_parts = while_loop_export::MatchGraph(start_node);
4005 
4006   // 1. Make Loop op
4007 
4008   onnx::NodeProto *loop_proto = graph_proto->add_node();
4009   loop_proto->set_op_type("Loop");
4010 
4011   auto loop_count_name = node_name + "_M";
4012   const auto &loop_counter_params = loop_parts.loop_condition_info;
4013   int64_t loop_count = (loop_counter_params.end - loop_counter_params.begin) / loop_counter_params.step;
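  // e.g. begin = 0, end = 10, step = 2 gives loop_count = (10 - 0) / 2 = 5 iterations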
4014   onnx::TensorProto *loop_count_proto = graph_proto->add_initializer();
4015   loop_count_proto->set_name(loop_count_name);
4016   loop_count_proto->set_data_type(onnx::TensorProto_DataType_INT64);
4017   loop_count_proto->add_int64_data(loop_count);
4018 
4019   auto loop_cond_name = node_name + "_cond";
4020   auto *cond_value = graph_proto->add_initializer();
4021   cond_value->set_name(loop_cond_name);
4022   cond_value->set_data_type(onnx::TensorProto_DataType_BOOL);
4023   cond_value->add_int32_data(true);
4024 
4025   loop_proto->add_input(loop_count_name);
4026   loop_proto->add_input(loop_cond_name);
4027   for (const auto &[loop_i, control_i] : loop_parts.used_loop_to_control_param_indices) {
4028     auto name = GetNodeInputName(start_node->input(control_i + 1), node_map_ptr, graph_proto);
4029     loop_proto->add_input(name);
4030     loop_proto->add_output(MakeOutputName(node_name + "_loop", loop_i));
4031   }
4032 
4033   onnx::AttributeProto *subgraph_attr = loop_proto->add_attribute();
4034   subgraph_attr->set_type(onnx::AttributeProto_AttributeType_GRAPH);
4035   subgraph_attr->set_name("body");
4036   onnx::GraphProto *loop_subgraph_proto = subgraph_attr->mutable_g();
4037 
4038   // 2. Create subgraph for loop body
4039 
4040   auto subgraph_name = loop_parts.loop_subgraph->ToString();
4041   auto subgraph_input_cond_name = subgraph_name + "_input_cond";
4042 
4043   auto *iter_num_input = loop_subgraph_proto->add_input();
4044   iter_num_input->set_name(subgraph_name + "_input_M");
4045   (void)iter_num_input->mutable_type()->mutable_tensor_type()->mutable_shape();  // side-effect: shape created
4046   iter_num_input->mutable_type()->mutable_tensor_type()->set_elem_type(onnx::TensorProto_DataType_INT64);
4047 
4048   auto *cond_input = loop_subgraph_proto->add_input();
4049   cond_input->set_name(subgraph_input_cond_name);
4050   cond_input->mutable_type()->mutable_tensor_type()->set_elem_type(cond_value->data_type());
4051 
4052   auto *cond_output = loop_subgraph_proto->add_output();
4053   cond_output->set_name(cond_input->name());
4054   cond_output->mutable_type()->mutable_tensor_type()->set_elem_type(cond_value->data_type());
4055 
4056   MS_EXCEPTION_IF_CHECK_FAIL(renamed_node_map_.empty(), "renamed_nodes must be cleared after subgraph export");
4057   for (size_t i : loop_parts.ignored_loop_param_indices) {
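    // An empty alias marks the parameter as ignored: GetNodeInputName treats "" as
    // "skip the alias and fall back to the global node map"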
4058     const auto &param = loop_parts.loop_subgraph->parameters().at(i);
4059     renamed_node_map_[param] = "";
4060   }
4061 
4062   // Export everything except the control call and the output (see MatchAndMark)
4063   ExportFuncGraph(loop_parts.loop_subgraph, node_map_ptr, loop_subgraph_proto);
4064 
4065   // Export outputs manually
4066   for (const auto &loop_to_control_i : loop_parts.used_loop_to_control_param_indices) {
4067     const auto &input = loop_parts.repeat_node->input(loop_to_control_i.second + 1);
4068     ExportOutput(loop_parts.loop_subgraph, input, node_map_ptr, loop_subgraph_proto);
4069   }
4070   renamed_node_map_.clear();
4071 
4072   // 3. Export part after loop
4073 
4074   MS_EXCEPTION_IF_CHECK_FAIL(renamed_node_map_.empty(), "renamed_nodes must be cleared after subgraph export");
4075   const auto &after_loop_params = loop_parts.after_loop_subgraph->parameters();
4076   for (const auto &[after_i, output_i] : loop_parts.after_param_to_output_indices) {
4077     MS_EXCEPTION_IF_CHECK_FAIL(static_cast<int>(output_i) < loop_proto->output_size(), "Output index out of bounds");
4078     renamed_node_map_[after_loop_params.at(after_i)] = loop_proto->output(output_i);
4079   }
4080   ExportFuncGraph(loop_parts.after_loop_subgraph, node_map_ptr, graph_proto, false);
4081 
4082   auto after_loop_retval = GetRealInput(loop_parts.after_loop_subgraph->get_return()->input(1));
4083   if (after_loop_retval->isa<CNode>() && after_loop_retval->cast<CNodePtr>()->IsApply(prim::kPrimMakeTuple)) {
4084     auto tuple_retval = dyn_cast<CNode>(after_loop_retval);
4085     for (size_t i = 1; i < tuple_retval->size(); ++i) {
4086       auto output_name = GetNodeInputName(tuple_retval->input(i), node_map_ptr, graph_proto);
4087       AddOp("Identity", {output_name}, {MakeOutputName(node_name, i - 1)}, graph_proto);
4088     }
4089   } else {
4090     auto output_name = GetNodeInputName(after_loop_retval, node_map_ptr, graph_proto);
4091     AddOp("Identity", {output_name}, {node_name}, graph_proto);
4092   }
4093   renamed_node_map_.clear();
4094 }
4095 
4096 onnx::TensorProto_DataType OnnxExporter::GetOutputType(const AnfNodePtr &node, int64_t output_index) {
4097   auto unpacked = GetRealInput(node);
4098   if (IsPrimitiveCNode(unpacked, prim::kPrimTupleGetItem)) {
4099     if (output_index != -1) {
4100       MS_LOG(EXCEPTION) << "Unexpected output index for TupleGetItem: " << output_index;
4101     }
4102     auto cnode = dyn_cast<CNode>(unpacked);
4103     unpacked = cnode->input(kOneNum);
4104     output_index = GetInt64Value(cnode->input(kTwoNum));
4105   }
4106 
4107   /*
4108     Special cases (MS and ONNX type differences) go here
4109     Example:
4110       if (IsPrimitiveCNode(unpacked, prim::kPrim<Something>) && output_index == <i>) {
4111         return onnx::TensorProto_DataType_<TYPE>;
4112       }
4113   */
4114 
4115   if (output_index == -1) {
4116     auto tensor = dyn_cast<TensorType>(unpacked->Type());
4117     if (tensor == nullptr) {
4118       MS_LOG(EXCEPTION) << "Expected output of node " << unpacked->ToString()
4119                         << " to be a single tensor. Instead got: " << unpacked->Type()->ToString();
4120     }
4121     return GetOnnxDataType(tensor->element()->type_id());
4122   } else {
4123     auto tuple_type = dyn_cast<Tuple>(unpacked->Type());
4124     if (tuple_type == nullptr) {
4125       MS_LOG(EXCEPTION) << "Expected output of node " << unpacked->ToString()
4126                         << " to be a tuple. Instead got: " << unpacked->Type()->ToString();
4127     }
4128     auto element_type = tuple_type->elements()[static_cast<size_t>(output_index)];
4129     MS_EXCEPTION_IF_NULL(element_type);
4130     auto tensor_type = dyn_cast<TensorType>(element_type);
4131     if (tensor_type == nullptr) {
4132       MS_LOG(EXCEPTION) << "Expected output " << output_index << " of node " << unpacked->ToString()
4133                         << " to be a tensor. Instead got: " << element_type->ToString();
4134     }
4135     return GetOnnxDataType(tensor_type->element()->type_id());
4136   }
4137 }
4138 
4139 void OnnxExporter::AddOutputWithCast(onnx::NodeProto *node_proto, const std::string &output_name,
4140                                      onnx::TensorProto_DataType target_type, onnx::GraphProto *graph_proto) const {
4141   if (target_type == onnx::TensorProto_DataType_UNDEFINED) {
4142     node_proto->add_output(output_name);
4143   } else {
4144     auto output_to_cast_name = output_name + "_output_to_cast";
4145     node_proto->add_output(output_to_cast_name);
4146     AddCastOp(output_to_cast_name, output_name, target_type, graph_proto);
4147   }
4148 }
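// Illustrative behavior: with target_type == FLOAT16 the node's raw output is named
// "<output_name>_output_to_cast" and a trailing Cast op produces "<output_name>" in the
// requested type; with UNDEFINED, the output name is attached to the node directly.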
4149 
4150 std::string OnnxExporter::ExportPrimitive(const FuncGraphPtr &, std::map<AnfNodePtr, std::string> *node_map_ptr,
4151                                           const PrimitivePtr &prim, const std::vector<AnfNodePtr> &inputs,
4152                                           onnx::GraphProto *const graph_proto) {
4153   auto op_map = OpConvertRegistry::GetOpConvertMap();
4154   MS_EXCEPTION_IF_NULL(prim);
4155   auto op_iter = op_map.find(prim->name());
4156   if (op_iter == op_map.end()) {
4157     MS_LOG(EXCEPTION) << "Cannot find key " << prim->name() << " in convert map. "
4158                       << "Exporting " << prim->name() << " operator is not yet supported.";
4159   }
4160   // Get the inputs first, because an input may be a ValueNode, which requires creating a Constant node
4161   std::vector<std::string> input_list;
4162   for (const auto &input : inputs) {
4163     auto input_name = GetNodeInputName(input, node_map_ptr, graph_proto);
4164     input_list.push_back(input_name);
4165   }
4166 
4167   const OpNameInfo &op_convert_info = op_iter->second;
4168   auto node_name = GenerateUniqueName();
4169 
4170   std::vector<onnx::TensorProto_DataType> output_cast_types(op_convert_info.num_outputs(),
4171                                                             onnx::TensorProto_DataType_UNDEFINED);
4172   // Cast inputs if needed
4173   for (const auto &rule : op_convert_info.input_casts()) {
4174     auto original_type = GetOutputType(inputs[static_cast<size_t>(rule.input_index)]);
4175     if (original_type != rule.input_type) {
4176       continue;
4177     }
4178 
4179     auto cast_input_name = node_name + "cast_input_" + std::to_string(rule.input_index);
4180     AddCastOp(input_list[static_cast<size_t>(rule.input_index)], cast_input_name, rule.target_type, graph_proto);
4181     input_list[static_cast<size_t>(rule.input_index)] = cast_input_name;
4182 
4183     auto output_cast = std::find_if(
4184       op_convert_info.output_casts().begin(), op_convert_info.output_casts().end(), [&rule](const OutputConversion &x) {
4185         return x.mode == OutputConversion::Mode::INPUT && x.input_with_matching_type == rule.input_index;
4186       });
4187     if (output_cast != op_convert_info.output_casts().end()) {
4188       output_cast_types[static_cast<size_t>(output_cast->output_index)] = original_type;
4189     }
4190   }
4191 
4192   for (const auto &output_cast : op_convert_info.output_casts()) {
4193     if (output_cast.mode == OutputConversion::Mode::FIXED) {
4194       output_cast_types[static_cast<size_t>(output_cast.output_index)] = output_cast.target_type;
4195     }
4196   }
4197 
4198   onnx::NodeProto *node_proto = graph_proto->add_node();
4199   node_proto->set_name(node_name + op_convert_info.onnx_type());
4200   node_proto->set_op_type(op_convert_info.onnx_type());
4201 
4202   // Set outputs
4203   if (op_convert_info.num_outputs() == 1) {
4204     AddOutputWithCast(node_proto, node_name, output_cast_types[0], graph_proto);
4205   } else {
4206     for (int i = 0; i < op_convert_info.num_outputs(); ++i) {
4207       auto output_name = MakeOutputName(node_name, i);
4208       AddOutputWithCast(node_proto, output_name, output_cast_types[static_cast<size_t>(i)], graph_proto);
4209     }
4210   }
4211 
4212   // Set inputs
4213   for (const auto &input_name : input_list) {
4214     node_proto->add_input(input_name);
4215   }
4216 
4217   // Set node attribute
4218   for (const OpAttrInfo &attr : op_convert_info.op_attrs()) {
4219     const std::string &attr_name = attr.attr_name();
4220     ValuePtr attr_value = nullptr;
4221     if (!attr_name.empty()) {
4222       attr_value = prim->GetAttr(attr_name);
4223       if (attr_value == nullptr) {
4224         MS_LOG(EXCEPTION) << "Primitive " << prim->name() << " does not have attribute " << attr_name;
4225       }
4226     }
4227     onnx::AttributeProto *onnx_attr_proto = node_proto->add_attribute();
4228     onnx_attr_proto->set_name(attr.onnx_attr_name());
4229     attr.fn_gen_attr()(attr_value, attr.onnx_attr_type(), onnx_attr_proto, prim);
4230   }
4231   return node_name;
4232 }
4233 
4234 void OnnxExporter::ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node,
4235                                    std::map<AnfNodePtr, std::string> *node_map_ptr,
4236                                    onnx::GraphProto *const graph_proto) {
4237   auto conv_node = dyn_cast<CNode>(node->input(kOneNum));
4238   auto input_x = conv_node->input(kOneNum);  // conv input x
4239   auto input_w = conv_node->input(kTwoNum);  // conv weight(filter)
4240   auto input_b = node->input(kTwoNum);       // conv bias
4241 
4242   PrimitivePtr prim_conv = dyn_cast<Primitive>((dyn_cast<ValueNode>(conv_node->input(kZeroNum)))->value());
4243   std::vector<AnfNodePtr> inputs{input_x, input_w, input_b};
4244   (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_conv, inputs, graph_proto);
4245 }
4246 
4247 void OnnxExporter::ExportMergeGemm(const FuncGraphPtr &func_graph, const CNodePtr &node,
4248                                    std::map<AnfNodePtr, std::string> *node_map_ptr,
4249                                    onnx::GraphProto *const graph_proto) {
4250   auto matmul_node = dyn_cast<CNode>(node->input(kOneNum));
4251   auto input_x = matmul_node->input(kOneNum);  // matmul input x
4252   auto input_y = matmul_node->input(kTwoNum);  // matmul input y
4253   auto input_b = node->input(kTwoNum);         // matmul bias
4254 
4255   PrimitivePtr prim_matmul = dyn_cast<Primitive>((dyn_cast<ValueNode>(matmul_node->input(kZeroNum)))->value());
4256   std::vector<AnfNodePtr> inputs{input_x, input_y, input_b};
4257   (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_matmul, inputs, graph_proto);
4258 }
4259 
4260 void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node,
4261                                         std::map<AnfNodePtr, std::string> *node_map_ptr,
4262                                         onnx::GraphProto *const graph_proto) {
4263   auto batch_norm_node = dyn_cast<CNode>(node->input(kOneNum));
4264 
4265   auto is_training = GetOpAttribute<bool>(batch_norm_node, "is_training");
4266   if (is_training) {
4267     auto input_x_name = GetNodeInputName(batch_norm_node->input(kOneNum), node_map_ptr, graph_proto);
4268     auto scale_input_name = GetNodeInputName(batch_norm_node->input(kTwoNum), node_map_ptr, graph_proto);
4269     auto bias_input_name = GetNodeInputName(batch_norm_node->input(kThreeNum), node_map_ptr, graph_proto);
4270 
4271     auto onnx_type = GetOutputType(batch_norm_node->input(kOneNum));
4272 
4273     auto output_name = RegisterNodeWithUniqueName(node, node_map_ptr);
4274 
4275     auto input_shape_ptr = batch_norm_node->input(kOneNum)->Shape();
4276     auto input_shape = input_shape_ptr->cast<abstract::ShapePtr>()->shape();
4277 
4278     std::vector<int64_t> normalize_axes = {0};
4279     for (size_t i = kTwoNum; i < input_shape.size(); ++i) {
4280       normalize_axes.push_back(static_cast<int64_t>(i));
4281     }
4282 
4283     std::vector<int64_t> scale_bias_shape(input_shape.size(), 1);
4284     scale_bias_shape[1] = -1;
4285     auto reshaped_scale_name = output_name + "_reshaped_scale";
4286     AddReshapeOp(scale_input_name, reshaped_scale_name, scale_bias_shape, graph_proto);
4287     auto reshaped_bias_name = output_name + "_reshaped_bias";
4288     AddReshapeOp(bias_input_name, reshaped_bias_name, scale_bias_shape, graph_proto);
4289     auto epsilon = GetOpAttribute<float>(batch_norm_node, "epsilon");
4290 
4291     AddMeanVarianceNormalizationOp(input_x_name, reshaped_scale_name, reshaped_bias_name, output_name, normalize_axes,
4292                                    epsilon, input_shape, onnx_type, graph_proto);
4293   } else {
4294     PrimitivePtr prim_batch_norm = GetPrimitive(batch_norm_node);
4295     std::vector<AnfNodePtr> inputs;
4296     for (size_t i = 1; i < batch_norm_node->size(); i++) {
4297       inputs.push_back(batch_norm_node->input(i));
4298     }
4299     (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_batch_norm, inputs, graph_proto);
4300   }
4301 }
4302 
4303 void OnnxExporter::ExportMergeMaxPoolWithArgmax(const FuncGraphPtr &func_graph, const CNodePtr &node,
4304                                                 std::map<AnfNodePtr, std::string> *node_map_ptr,
4305                                                 onnx::GraphProto *const graph_proto) {
4306   auto maxpool_with_argmax_node = dyn_cast<CNode>(node->input(kOneNum));
4307 
4308   PrimitivePtr prim_maxpool_with_argmax =
4309     dyn_cast<Primitive>((dyn_cast<ValueNode>(maxpool_with_argmax_node->input(kZeroNum)))->value());
4310   std::vector<AnfNodePtr> inputs;
4311   for (size_t i = 1; i < maxpool_with_argmax_node->size(); i++) {
4312     inputs.push_back(maxpool_with_argmax_node->input(i));
4313   }
4314   (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_maxpool_with_argmax, inputs, graph_proto);
4315 }
4316 
4317 // LayerNorm(N, C1, H, W) --> reshape(1, C2, 1, W) + MeanVarianceNormalization + reshape(N, C1, H, W)
4318 void OnnxExporter::ExportMergeLayerNorm(const FuncGraphPtr &, const CNodePtr &node,
4319                                         std::map<AnfNodePtr, std::string> *node_map_ptr,
4320                                         onnx::GraphProto *const graph_proto) {
4321   auto LayerNormNode = dyn_cast<CNode>(node->input(kOneNum));
4322   auto layernorm_input_x = GetNodeInputName(LayerNormNode->input(kOneNum), node_map_ptr, graph_proto);
4323   auto layernorm_input_gamma = GetNodeInputName(LayerNormNode->input(kTwoNum), node_map_ptr, graph_proto);
4324   auto layernorm_input_beta = GetNodeInputName(LayerNormNode->input(kThreeNum), node_map_ptr, graph_proto);
4325 
4326   auto begin_norm_axis = GetOpAttribute<int64_t>(LayerNormNode, "begin_norm_axis");
4327   auto begin_params_axis = GetOpAttribute<int64_t>(LayerNormNode, "begin_params_axis");
4328   if (begin_norm_axis != -1 || begin_params_axis != -1) {
4329     MS_LOG(EXCEPTION) << "Only begin_norm_axis == -1 and begin_params_axis == -1 are supported";
4330   }
4331 
4332   auto onnx_type = GetOutputType(LayerNormNode->input(kOneNum));
4333   auto input_shape = dyn_cast<abstract::Shape>(LayerNormNode->input(kOneNum)->Shape())->shape();
4334   auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
4335   auto epsilon = GetOpAttribute<float>(LayerNormNode, "epsilon");
4336   std::vector<int64_t> reduce_axes = {static_cast<int64_t>(input_shape.size()) - 1};
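  // e.g. a (N, C, H, W) input yields reduce_axes == {3}: with begin_norm_axis == -1,
  // normalization runs over the last dimension only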
4337 
4338   AddMeanVarianceNormalizationOp(layernorm_input_x, layernorm_input_gamma, layernorm_input_beta, node_name, reduce_axes,
4339                                  epsilon, input_shape, onnx_type, graph_proto);
4340 }
4341 
4342 void OnnxExporter::ExportMergeConv2DTranspose(const FuncGraphPtr &, const CNodePtr &node,
4343                                               std::map<AnfNodePtr, std::string> *node_map_ptr,
4344                                               onnx::GraphProto *const graph_proto) {
4345   auto conv_node = dyn_cast<CNode>(node->input(kOneNum));
4346   PrimConv2DTransposeExportHelper(conv_node, node, node_map_ptr, graph_proto);
4347 }
4348 
4349 void AddTransposeOp(const std::string &input, const std::string &output, onnx::GraphProto *graph_proto) {
4350   onnx::NodeProto *node_proto = graph_proto->add_node();
4351   std::string op_type = "Transpose";
4352   node_proto->set_op_type(op_type);
4353   node_proto->set_name(output + op_type);
4354   node_proto->add_input(input);
4355   node_proto->add_output(output);
4356 }
4357 
4358 void AddUnsqueezeOp(const std::string &input, const std::string &output, int64_t axis, onnx::GraphProto *graph_proto) {
4359   onnx::NodeProto *node_proto = graph_proto->add_node();
4360   std::string op_type = "Unsqueeze";
4361   node_proto->set_op_type(op_type);
4362   node_proto->set_name(output + op_type);
4363   node_proto->add_input(input);
4364   node_proto->add_output(output);
4365 
4366   onnx::AttributeProto *attr_proto = node_proto->add_attribute();
4367   attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
4368   attr_proto->set_name("axes");
4369   attr_proto->add_ints(axis);
4370 }
4371 
4372 void AddSqueezeOp(const std::string &input, const std::string &output, int64_t axis, onnx::GraphProto *graph_proto) {
4373   onnx::NodeProto *node_proto = graph_proto->add_node();
4374   std::string op_type = "Squeeze";
4375   node_proto->set_op_type(op_type);
4376   node_proto->set_name(output + op_type);
4377   node_proto->add_input(input);
4378   node_proto->add_output(output);
4379 
4380   onnx::AttributeProto *attr_proto = node_proto->add_attribute();
4381   attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
4382   attr_proto->set_name("axes");
4383   attr_proto->add_ints(axis);
4384 }
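// Note: AddUnsqueezeOp and AddSqueezeOp emit "axes" as a node attribute, which matches the
// opset 11 target declared by ONNX_VERSION; from opset 13 onward, ONNX expects "axes" as a
// tensor input of Squeeze/Unsqueeze instead.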
4385 
4386 void AddGRUOp(const std::vector<std::string> &inputs, const std::vector<std::string> &outputs, int64_t hidden_size,
4387               int64_t linear_before_reset, onnx::GraphProto *graph_proto) {
4388   onnx::NodeProto *node_proto = graph_proto->add_node();
4389   std::string op_type = "GRU";
4390   node_proto->set_op_type(op_type);
4391   node_proto->set_name(outputs[0] + op_type);
4392 
4393   for (const auto &in : inputs) {
4394     node_proto->add_input(in);
4395   }
4396 
4397   for (const auto &out : outputs) {
4398     node_proto->add_output(out);
4399   }
4400 
4401   onnx::AttributeProto *attr_proto = node_proto->add_attribute();
4402   attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
4403   attr_proto->set_name("linear_before_reset");
4404   attr_proto->set_i(linear_before_reset);
4405 
4406   onnx::AttributeProto *attr2_proto = node_proto->add_attribute();
4407   attr2_proto->set_type(onnx::AttributeProto_AttributeType_INT);
4408   attr2_proto->set_name("hidden_size");
4409   attr2_proto->set_i(hidden_size);
4410 }
4411 
4412 void UnsqueezeInputOfGRU(std::string *in_name, const std::string &node_name, const std::string &suffix, int64_t axis,
4413                          onnx::GraphProto *graph_proto) {
4414   auto out_name = node_name + suffix;
4415   AddUnsqueezeOp(*in_name, out_name, axis, graph_proto);
4416   *in_name = out_name;
4417 }
4418 
4419 void GruRzh2Zrh(std::string *in_name, const std::string &node_name, const std::string &mid_name,
4420                 std::vector<std::string> tmp_out_names, const std::vector<int64_t> &hidden_sizes, int64_t axis,
4421                 onnx::GraphProto *graph_proto) {
4422   const int kConcatNum = 6;
4423   const int kIndexBiasHiddenR = 3;
4424   const int kIndexBiasHiddenZ = 4;
4425   auto out_name = node_name + mid_name + "_zrh";
4426 
4427   AddSplitOp(*in_name, tmp_out_names, hidden_sizes, 0, graph_proto);
4428   swap(tmp_out_names[0], tmp_out_names[1]);
4429   if (tmp_out_names.size() == kConcatNum) {
4430     swap(tmp_out_names[kIndexBiasHiddenR], tmp_out_names[kIndexBiasHiddenZ]);
4431   }
4432   AddConcatOp(tmp_out_names, out_name, 0, graph_proto);
4433   *in_name = out_name;
4434 }
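// Illustrative effect: a tensor laid out as [r | z | h] blocks of hidden_size along axis 0
// is split and reassembled as [z | r | h]; for the six-part concatenated bias
// [b_i_r | b_i_z | b_i_h | b_h_r | b_h_z | b_h_h], the hidden-bias r and z parts
// (indices 3 and 4) are swapped as well.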
4435 
4436 /*
4437   Mapping between the inputs of MindSpore DynamicGRUV2 and ONNX GRU operator.
4438   +----------------------------------------------------------+----------------------------------------------+
4439   |                          ONNX                            |                  MindSpore                   |
4440   +==========================================================+==============================================+
4441   | X: [seq_length, batch_size, input_size]                  | x: (num_step, batch_size, input_size)        |
4442   +----------------------------------------------------------+----------------------------------------------+
4443   | W: [num_directions, 3*hidden_size, input_size]           | weight_input: (input_size, 3*hidden_size)    |
4444   +----------------------------------------------------------+----------------------------------------------+
4445   | R: [num_directions, 3*hidden_size, hidden_size]          | weight_hidden: (hidden_size, 3*hidden_size)  |
4446   +----------------------------------------------------------+----------------------------------------------+
4447   |                                                          | bias_input:  (3*hidden_size)                 |
4448   + B: [num_directions, 6*hidden_size]                       +----------------------------------------------+
4449   |                                                          | bias_hidden: (3*hidden_size)                 |
4450   +----------------------------------------------------------+----------------------------------------------+
4451   | sequence_lens: [batch_size]                              | seq_length: (batch_size)                     |
4452   +----------------------------------------------------------+----------------------------------------------+
4453   | initial_h: [num_directions, batch_size, hidden_size]     | init_h: (batch_size, hidden_size)            |
4454   +----------------------------------------------------------+----------------------------------------------+
4455   | Y:[seq_length, num_directions, batch_size, hidden_size]  | y: (num_step, batch_size, hidden_size)       |
4456   +----------------------------------------------------------+----------------------------------------------+
4457 */
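/*
  Shape sketch (illustrative values: input_size = 16, hidden_size = 32, so 3*hidden_size = 96):
  weight_input (16, 96) is transposed to (96, 16) and unsqueezed to W (1, 96, 16);
  bias_input (96) and bias_hidden (96) are concatenated to (192) and unsqueezed to B (1, 192);
  init_h (batch_size, 32) is unsqueezed to initial_h (1, batch_size, 32).
*/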
4458 void OnnxExporter::ExportMergeDynamicGRUV2(const FuncGraphPtr &, const CNodePtr &node,
4459                                            std::map<AnfNodePtr, std::string> *node_map_ptr,
4460                                            onnx::GraphProto *const graph_proto) {
4461   const int kInX = 1;
4462   const int kInWeightInput = 2;
4463   const int kInWeightHidden = 3;
4464   const int kInBiasInput = 4;
4465   const int kInBiasHidden = 5;
4466   // The 6th input 'seq_length' currently only supports None, so it is not used.
4467   const int kInInitH = 7;
4468 
4469   const int kWeightHiddenDim = 2;
4470   const int kNumberOfGates = 3;
4471   const std::string kDefaultDir = "UNIDIRECTIONAL";
4472   const std::string kDefaultAct = "tanh";
4473   const std::vector<std::string> kGateOrderSupported{"rzh", "zrh"};
4474 
4475   auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
4476   auto gru_node = dyn_cast<CNode>(node->input(1));
4477   MS_EXCEPTION_IF_NULL(gru_node);
4478 
4479   /* Get Attributes */
4480   auto direction = GetOpAttribute<std::string>(gru_node, "direction");
4481   auto activation = GetOpAttribute<std::string>(gru_node, "activation");
4482   auto gate_order = GetOpAttribute<std::string>(gru_node, "gate_order");
4483   auto reset_after = GetOpAttribute<bool>(gru_node, "reset_after");
4484 
4485   int64_t linear_before_reset = reset_after ? 1 : 0;
4486 
4487   if (direction != kDefaultDir) {
4488     MS_LOG(EXCEPTION) << "'direction': " << direction << " is not in supported values[" << kDefaultDir << "]";
4489   }
4490   if (activation != kDefaultAct) {
4491     MS_LOG(EXCEPTION) << "'activation': " << activation << " is not in supported values[" << kDefaultAct << "]";
4492   }
4493   if (gate_order != kGateOrderSupported[0] && gate_order != kGateOrderSupported[1]) {
4494     std::string supported;
4495     for (const auto &order : kGateOrderSupported) {
4496       supported += order;
4497       supported += ", ";
4498     }
4499     MS_LOG(EXCEPTION) << "'gate_order': " << gate_order << " is not in supported values[" << supported << "]";
4500   }
4501 
4502   auto x = GetNodeInputName(gru_node->input(kInX), node_map_ptr, graph_proto);
4503   auto weight_input = GetNodeInputName(gru_node->input(kInWeightInput), node_map_ptr, graph_proto);
4504   auto weight_hidden = GetNodeInputName(gru_node->input(kInWeightHidden), node_map_ptr, graph_proto);
4505   auto bias_input = GetNodeInputName(gru_node->input(kInBiasInput), node_map_ptr, graph_proto);
4506   auto bias_hidden = GetNodeInputName(gru_node->input(kInBiasHidden), node_map_ptr, graph_proto);
4507   auto init_h = GetNodeInputName(gru_node->input(kInInitH), node_map_ptr, graph_proto);
4508 
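  // ONNX GRU expects X, W, and R to share the bias element type, so if the bias is
  // float32 (e.g. in an otherwise float16 network), cast the other inputs up to float32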
4509   if (GetOutputType(gru_node->input(kInBiasInput)) == onnx::TensorProto_DataType_FLOAT) {
4510     auto x_cast = x + "_cast";
4511     AddCastOp(x, x_cast, onnx::TensorProto_DataType_FLOAT, graph_proto);
4512     x = x_cast;
4513 
4514     auto wi_cast = weight_input + "_cast";
4515     AddCastOp(weight_input, wi_cast, onnx::TensorProto_DataType_FLOAT, graph_proto);
4516     weight_input = wi_cast;
4517 
4518     auto wh_cast = weight_hidden + "_cast";
4519     AddCastOp(weight_hidden, wh_cast, onnx::TensorProto_DataType_FLOAT, graph_proto);
4520     weight_hidden = wh_cast;
4521   }
4522 
4523   auto weight_hidden_shape = dyn_cast<abstract::Shape>(gru_node->input(kInWeightHidden)->Shape())->shape();
4524   if (weight_hidden_shape.size() != kWeightHiddenDim) {
4525     MS_LOG(EXCEPTION) << "The dim of input weight_hidden must be " << kWeightHiddenDim << ".";
4526   }
4527   int64_t hidden_size = weight_hidden_shape[1] / kNumberOfGates;
4528 
4529   auto trans_w_i = node_name + "_trans_w_i";
4530   AddTransposeOp(weight_input, trans_w_i, graph_proto);
4531   weight_input = trans_w_i;
4532 
4533   auto trans_w_h = node_name + "_trans_w_h";
4534   AddTransposeOp(weight_hidden, trans_w_h, graph_proto);
4535   weight_hidden = trans_w_h;
4536 
4537   auto bias_i_h = node_name + "_bias_i_h";
4538   AddConcatOp({bias_input, bias_hidden}, bias_i_h, 0, graph_proto);
4539 
4540   // ONNX GRU only supports the "zrh" gate order
4541   if (gate_order == "rzh") {
4542     MS_LOG(INFO) << "Change gate order 'rzh' to 'zrh'.";
4543     std::vector<int64_t> hidden_sizes(kNumberOfGates, hidden_size);
4544     GruRzh2Zrh(&weight_input, node_name, "w_i", {node_name + "_w_i_r", node_name + "_w_i_z", node_name + "_w_i_h"},
4545                hidden_sizes, 0, graph_proto);
4546     GruRzh2Zrh(&weight_hidden, node_name, "w_h", {node_name + "_w_h_r", node_name + "_w_h_z", node_name + "_w_h_h"},
4547                hidden_sizes, 0, graph_proto);
4548 
4549     std::vector<int64_t> bias_hidden_sizes(kNumberOfGates + kNumberOfGates, hidden_size);
4550     GruRzh2Zrh(&bias_i_h, node_name, "bias",
4551                {node_name + "_b_i_r", node_name + "_b_i_z", node_name + "_b_i_h", node_name + "_b_h_r",
4552                 node_name + "_b_h_z", node_name + "_b_h_h"},
4553                bias_hidden_sizes, 0, graph_proto);
4554   }
4555 
4556   std::vector<std::string *> input_names = {&weight_input, &weight_hidden, &bias_i_h, &init_h};
4557   std::vector<std::string> suffixes = {"_unsqueeze_w_i", "_unsqueeze_w_h", "_unsqueeze_bias", "_unsqueeze_init_h"};
4558   for (size_t i = 0; i < input_names.size(); i++) {
4559     UnsqueezeInputOfGRU(input_names[i], node_name, suffixes[i], 0, graph_proto);
4560   }
4561 
4562   auto y = node_name + "_Y";
4563   // 'seq_length' input of DynamicGRUV2 is None, so pass "" to ONNX GRU.
4564   std::string sequence_lens = "";
4565   AddGRUOp({x, weight_input, weight_hidden, bias_i_h, sequence_lens, init_h}, {y}, hidden_size, linear_before_reset,
4566            graph_proto);
4567 
4568   AddSqueezeOp(y, node_name, 1, graph_proto);
4569 }
4570 
4571 /*
4572   Kinds of return values:
4573   1) A single Tensor
4574   2) A Tuple returned by an op with multiple outputs like TopK
4575   3) A Tuple returned by MakeTuple. This corresponds to `return x, y`
4576      or equivalent in Python, where x and y are Tensors
4577      In this case MakeTuple itself is not exported, so this case must be handled
4578      separately from the previous one
4579   4) A constant tuple (ValueNode). Example:
4580         class MyCell(nn.Cell):
4581             def __init__(self):
4582                 super().__init__()
4583                 self.x = ms.Tensor(np.zeros((1, 2, 3)))
4584 
4585             def construct(self):
4586                 return self.x, self.x
4587 
4588  */
4589 void OnnxExporter::ExportOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &return_arg,
4590                                 std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *const graph_proto) {
4591   AnfNodePtr arg = GetRealInput(return_arg);
4592   if (IsPrimitiveCNode(arg, prim::kPrimMakeTuple)) {
4593     auto arg_cnode = dyn_cast<CNode>(arg);
4594     for (size_t i = 1; i < arg_cnode->size(); ++i) {
4595       const auto &output = arg_cnode->input(i);
4596       ExportOutput(func_graph, output, node_map_ptr, graph_proto);
4597     }
4598   } else if (arg->isa<ValueNode>() && arg->cast<ValueNodePtr>()->value()->isa<ValueTuple>()) {
4599     // Several outputs, all constants
4600     auto tuple = arg->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>();
4601     for (size_t i = 0; i < tuple->value().size(); ++i) {
4602       const auto &element = tuple->value().at(i);
4603       std::string output_name = GenerateUniqueName();
4604 
4605       onnx::TensorProto *initializer = graph_proto->add_initializer();
4606       initializer->set_name(output_name);
4607       SetTensorData(element, initializer);
4608 
4609       onnx::ValueInfoProto *output_proto = graph_proto->add_output();
4610       output_proto->set_name(output_name);
4611       SetValueInfoType(arg, output_proto, static_cast<int64_t>(i));
4612     }
4613   } else if (arg->Type()->isa<Tuple>()) {
4614     auto arg_name = GetNodeInputName(arg, node_map_ptr, graph_proto);
4615     auto tuple = dyn_cast<Tuple>(arg->Type());
4616 
4617     for (size_t i = 0; i < tuple->size(); ++i) {
4618       auto output_name = MakeOutputName(arg_name, i);
4619       onnx::ValueInfoProto *output_proto = graph_proto->add_output();
4620       output_proto->set_name(output_name);
4621       SetValueInfoType(arg, output_proto, static_cast<int64_t>(i));
4622     }
4623   } else if (arg->Type()->isa<TensorType>()) {
4624     auto arg_name = GetNodeInputName(arg, node_map_ptr, graph_proto);
4625     onnx::ValueInfoProto *output_proto = graph_proto->add_output();
4626     output_proto->set_name(arg_name);
4627     SetValueInfoType(arg, output_proto);
4628   } else {
4629     MS_LOG(EXCEPTION) << "Unsupported network output type " << arg->Type()->ToString() << " in node "
4630                       << arg->ToString();
4631   }
4632 }
4633 
4634 std::string OnnxExporter::GetNodeInputName(const AnfNodePtr &orig_node, std::map<AnfNodePtr, std::string> *node_map_ptr,
4635                                            onnx::GraphProto *const) {
4636   auto node = GetRealInput(orig_node);
4637 
4638   // if node is renamed and not ignored, use alternative name
4639   // if it is ignored, try to find the actual name in global map
4640   auto renamed_iter = renamed_node_map_.find(node);
4641   if (renamed_iter != renamed_node_map_.end() && renamed_iter->second != "") {
4642     return renamed_iter->second;
4643   }
4644 
4645   auto iter = node_map_ptr->find(node);
4646   if (iter != node_map_ptr->end()) {
4647     return iter->second;
4648   }
4649 
4650   if (node->isa<CNode>() || (node->isa<Parameter>() && !node->cast<ParameterPtr>()->has_default())) {
4651     MS_LOG(EXCEPTION) << "Cannot find node '" << node->DebugString() << "' in node_map";
4652   }
4653 
4654   // for ValueNode or Parameter with default input, create an initializer
4655   // same value can be used in several subgraphs, so create initializers in root graph
4656   if (node->isa<ValueNode>()) {
4657     auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
4658     auto value = node->cast<ValueNodePtr>()->value();
4659 
4660     onnx::TensorProto *initializer_proto = model_.mutable_graph()->add_initializer();
4661     initializer_proto->set_name(node_name);
4662     SetTensorData(value, initializer_proto);
4663 
4664     (*node_map_ptr)[node] = node_name;
4665     return node_name;
4666   }
4667 
4668   if (node->isa<Parameter>()) {
4669     auto param = dyn_cast<Parameter>(node);
4670     auto node_name = GenerateUniqueParameterName(param, node_map_ptr);
4671 
4672     onnx::TensorProto *initializer_proto = model_.mutable_graph()->add_initializer();
4673     initializer_proto->set_name(node_name);
4674     SetTensorData(param->default_param(), initializer_proto);
4675 
4676     (*node_map_ptr)[node] = node_name;
4677     return node_name;
4678   }
4679 
4680   MS_LOG(EXCEPTION) << "Unexpected node type " << node->type_name();
4681 }
4682 
4683 void OnnxExporter::ConvertTupleToTensor(const ValuePtr &value, onnx::TensorProto *const tensor_proto) const {
4684   auto tuple_ptr = dyn_cast<ValueTuple>(value);
4685   MS_EXCEPTION_IF_NULL(tuple_ptr);
4686   if (tuple_ptr->size() == 0) {
4687     MS_LOG(EXCEPTION) << "Failed to convert tuple to tensor: the tuple is empty.";
4688   }
4689 
4690   ValuePtr first_element = (*tuple_ptr)[0];
4691   if (!first_element->isa<Scalar>()) {  // For non-scalar elements, type() returns nullptr
4692     MS_LOG(EXCEPTION) << "Expected tuple elements to be scalars. Got: " << value->ToString();
4693   }
4694   auto type_id = first_element->type()->type_id();
4695   for (size_t i = 1; i < tuple_ptr->size(); ++i) {
4696     const auto element_type = (*tuple_ptr)[i]->type();
4697     if (element_type == nullptr || element_type->type_id() != type_id) {
4698       MS_LOG(EXCEPTION) << "Failed to convert tuple to tensor: tuple elements do not all have the same type.";
4699     }
4700   }
4701 
4702   onnx::TensorProto_DataType result_type = onnx::TensorProto_DataType_UNDEFINED;
4703   if (first_element->isa<IntegerImm>()) {
4704     result_type = onnx::TensorProto_DataType_INT64;
4705   } else if (first_element->isa<FloatImm>()) {
4706     result_type = onnx::TensorProto_DataType_FLOAT;
4707   } else {
4708     MS_LOG(EXCEPTION) << "Failed to convert tuple to tensor: unexpected tuple element type "
4709                       << first_element->type()->type_name() << ".";
4710   }
4711 
4712   tensor_proto->add_dims(static_cast<::google::protobuf::int64>(tuple_ptr->size()));
4713   tensor_proto->set_data_type(result_type);
4714   for (size_t i = 0; i < tuple_ptr->size(); ++i) {
4715     ValuePtr elem = (*tuple_ptr)[i];
4716     if (elem->isa<Int8Imm>()) {
4717       tensor_proto->add_int64_data(dyn_cast<Int8Imm>(elem)->value());
4718     } else if (elem->isa<Int16Imm>()) {
4719       tensor_proto->add_int64_data(dyn_cast<Int16Imm>(elem)->value());
4720     } else if (elem->isa<Int32Imm>()) {
4721       tensor_proto->add_int64_data(dyn_cast<Int32Imm>(elem)->value());
4722     } else if (elem->isa<Int64Imm>()) {
4723       tensor_proto->add_int64_data(dyn_cast<Int64Imm>(elem)->value());
4724     } else if (elem->isa<FP32Imm>()) {
4725       tensor_proto->add_float_data(dyn_cast<FP32Imm>(elem)->value());
4726     } else {
4727       MS_LOG(EXCEPTION) << "Failed to convert tuple to tensor: unexpected tuple element type "
4728                         << elem->type()->type_name() << ".";
4729     }
4730   }
4731 }
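// Example: a ValueTuple (2, 3, 4) of Int64Imm elements becomes a 1-D INT64 tensor with
// dims = [3] and int64_data = {2, 3, 4}; tuples mixing element types are rejected above.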
4732 
4733 void OnnxExporter::SetTensorData(const ValuePtr &value, onnx::TensorProto *tensor_proto) {
4734   if (value->isa<Int32Imm>()) {
4735     auto attr_value = dyn_cast<Int32Imm>(value)->value();
4736     tensor_proto->set_data_type(onnx::TensorProto_DataType_INT32);
4737     tensor_proto->add_int32_data(attr_value);
4738   } else if (value->isa<Int64Imm>()) {
4739     auto attr_value = dyn_cast<Int64Imm>(value)->value();
4740     tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64);
4741     tensor_proto->add_int64_data(attr_value);
4742   } else if (value->isa<tensor::Tensor>()) {
4743     auto data = dyn_cast<tensor::Tensor>(value);
4744     tensor_proto->set_raw_data(data->data_c(), static_cast<size_t>(data->data().nbytes()));
4745     auto dtype = data->data_type();
4746     auto shape = data->shape_c();
4747 
4748     tensor_proto->set_data_type(GetOnnxDataType(dtype));
4749     for (const auto dim : shape) {
4750       tensor_proto->add_dims(dim);
4751     }
4752   } else if (value->isa<ValueTuple>()) {  // Note: this is a tuple of primitives, not Tensors
4753     ConvertTupleToTensor(value, tensor_proto);
4754   } else {
4755     MS_LOG(EXCEPTION) << "Unsupported value for Constant node: " << value->ToString();
4756   }
4757 }
4758 
4759 std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) {
4760   OnnxExporter exporter;
4761   return exporter.GetOnnxProtoString(func_graph);
4762 }
4763 }  // namespace mindspore
4764