/**
 * Copyright 2020-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <algorithm>
#include <functional>
#include <limits>
#include <map>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

#include "ir/func_graph.h"
#include "ir/param_info.h"
#include "ir/tensor.h"
#include "mindspore/core/ops/array_ops.h"
#include "mindspore/core/ops/comparison_ops.h"
#include "mindspore/core/ops/conv_pool_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "mindspore/core/ops/image_ops.h"
#include "mindspore/core/ops/math_ops.h"
#include "mindspore/core/ops/nn_ops.h"
#include "mindspore/core/ops/nn_optimizer_ops.h"
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/auto_generate/gen_ops_primitive.h"
#include "proto/onnx.pb.h"
#include "utils/check_convert_utils.h"
#include "utils/hash_map.h"
#include "utils/ms_context.h"
#include "ops/op_utils.h"

namespace mindspore {
const int ONNX_VERSION = 11;
const int kZeroNum = 0;
const int kOneNum = 1;
const int kTwoNum = 2;
const int kThreeNum = 3;
const int kFourNum = 4;
const int kFiveNum = 5;
const int kSixNum = 6;
const int kSevenNum = 7;
const int kEightNum = 8;
const int kNineNum = 9;
const int64_t kOneNumLong = 1;
const float weight_for_mul = 0.5;
enum OpMergeMode {
  OP_MERGE_UNDEFINED = 0,            // undefined behavior
  OP_MERGE_IGNORE = 1,               // indicates an input op merged into another op in the compute node list
  OP_MERGE_CONV = 2,                 // indicates `MindSpore Conv + BiasAdd` --> `ONNX Conv`
  OP_MERGE_GEMM = 3,                 // indicates `MindSpore MatMul + BiasAdd` --> `ONNX Gemm`
  OP_MERGE_BATCH_NORM = 4,           // indicates `MindSpore BatchNorm(x)[0]` --> `ONNX Batch Normalization`
  OP_MERGE_MAXPOOL_WITH_ARGMAX = 5,  // indicates `MindSpore MaxPoolWithArgmax(x)[0]` --> `ONNX MaxPool`
  OP_MERGE_LAYER_NORM = 6,           // indicates `MindSpore LayerNorm(x)[0]` --> `ONNX MeanVarianceNormalization`
  OP_MERGE_CONV2D_TRANSPOSE = 7,     // indicates `MindSpore ConvTranspose + BiasAdd` --> `ONNX ConvTranspose`
  OP_MERGE_DYNAMIC_GRU_V2 = 8,       // indicates `MindSpore DynamicGRUV2(...)[0]` --> `ONNX GRU`
};

struct OpMergedInfo {
  OpMergeMode mode = OP_MERGE_UNDEFINED;
  int referred_count = 0;
};

using GenAttrFuncType =
  std::function<void(ValuePtr, onnx::AttributeProto_AttributeType, onnx::AttributeProto *, const PrimitivePtr &)>;

bool IsIgnoredIdentityNode(const AnfNodePtr &node) {
  return IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad);
}

/*
  If true, the node should not be referenced by anything and should not contribute to any
  ref counts itself
*/
bool IsZeroRefcountNode(const AnfNodePtr &node) { return HasAbstractMonad(node) || IsIgnoredIdentityNode(node); }

// Ideally this should be applied to every node->input() call, not only inside GetNodeInputName
static AnfNodePtr GetRealInput(const AnfNodePtr &origin_input) {
  AnfNodePtr input = origin_input;
  while (IsIgnoredIdentityNode(input)) {
    input = input->cast<CNodePtr>()->inputs().at(1);
  }
  return input;
}

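// For the repeated attribute types (INTS/FLOATS), rep_cnt replicates the single scalar
// value rep_cnt times into the repeated field; for scalar types it is unused.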
template <typename T, size_t rep_cnt = 0>
void SetAttrValueToProto(const ValuePtr &value, onnx::AttributeProto_AttributeType attr_type,
                         onnx::AttributeProto *const attr_proto, const PrimitivePtr &) {
  auto casted_value = dyn_cast<T>(value);
  if (casted_value == nullptr) {
    MS_LOG(EXCEPTION) << "Cast value " << value->ToString() << " to type T failed.";
  }
  auto attr_value = casted_value->value();
  switch (attr_type) {
    case onnx::AttributeProto_AttributeType_INT:
      attr_proto->set_i(static_cast<::google::protobuf::int64>(attr_value));
      break;
    case onnx::AttributeProto_AttributeType_FLOAT:
      attr_proto->set_f(static_cast<float>(attr_value));
      break;
    case onnx::AttributeProto_AttributeType_INTS:
      for (size_t i = 0; i < rep_cnt; ++i) {
        attr_proto->add_ints(static_cast<::google::protobuf::int64>(attr_value));
      }
      break;
    case onnx::AttributeProto_AttributeType_FLOATS:
      for (size_t i = 0; i < rep_cnt; ++i) {
        attr_proto->add_floats(static_cast<float>(attr_value));
      }
      break;
    default:
      MS_LOG(EXCEPTION) << "Convert attribute fail, unexpected ONNX type " << attr_type;
  }
  attr_proto->set_type(attr_type);
}

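// beg_idx skips that many leading tuple elements; e.g. SetAttrTupleValueToProto<2> is used
// below for Conv2D strides/dilations, presumably dropping the leading (N, C) entries of the
// 4-element NCHW-ordered attribute so that only the spatial values are exported.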
template <size_t beg_idx = 0>
void SetAttrTupleValueToProto(const ValuePtr &value, onnx::AttributeProto_AttributeType attr_type,
                              onnx::AttributeProto *const attr_proto, const PrimitivePtr &) {
  auto tuple_ptr = dyn_cast<ValueTuple>(value);
  if (tuple_ptr == nullptr) {
    MS_LOG(EXCEPTION) << "Cast value from type " << value->type_name() << " to ValueTuple failed.";
  }
  switch (attr_type) {
    case onnx::AttributeProto_AttributeType_INTS:
      for (size_t i = beg_idx; i < tuple_ptr->size(); ++i) {
        attr_proto->add_ints(GetValue<int64_t>((*tuple_ptr)[i]));
      }
      break;
    case onnx::AttributeProto_AttributeType_INT:
      attr_proto->set_i(GetValue<int64_t>((*tuple_ptr)[beg_idx]));
      break;
    case onnx::AttributeProto_AttributeType_FLOATS:
      for (size_t i = beg_idx; i < tuple_ptr->size(); ++i) {
        attr_proto->add_floats(GetValue<float>((*tuple_ptr)[i]));
      }
      break;
    default:
      MS_LOG(EXCEPTION) << "Convert attribute fail, unexpected ONNX type " << attr_type;
  }
  attr_proto->set_type(attr_type);
}

void SetPoolingPadMode(const ValuePtr &value, onnx::AttributeProto_AttributeType,
                       onnx::AttributeProto *const attr_proto, const PrimitivePtr &) {
  attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
  int64_t attr_value;
  CheckAndConvertUtils::GetPadModEnumValue(value, &attr_value, true);
  if (attr_value == PadMode::VALID) {
    attr_proto->set_s("VALID");
  } else {
    attr_proto->set_s("SAME_UPPER");
  }
}

void SetConvPadding(const ValuePtr &value, onnx::AttributeProto_AttributeType, onnx::AttributeProto *const attr_proto,
                    const PrimitivePtr &prim) {
  attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
  int64_t attr_value;
  CheckAndConvertUtils::GetPadModEnumValue(value, &attr_value);
  if (attr_value == PadMode::VALID) {
    attr_proto->set_s("VALID");
  } else if (attr_value == PadMode::SAME) {
    attr_proto->set_s("SAME_UPPER");
  } else {  // pad_mode is 'pad', use attribute 'pad_list' to fill ONNX attribute 'pads'
    attr_proto->set_name("pads");
    SetAttrTupleValueToProto(prim->GetAttr("pad_list"), onnx::AttributeProto_AttributeType_INTS, attr_proto, prim);
  }
}

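// Note the asymmetry with SetConvPadding above: for transposed convolutions MindSpore's
// 'same' pad mode is exported as ONNX SAME_LOWER rather than SAME_UPPER.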
void SetConvTransposePadding(const ValuePtr &value, onnx::AttributeProto_AttributeType,
                             onnx::AttributeProto *const attr_proto, const PrimitivePtr &prim) {
  attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
  int64_t attr_value;
  CheckAndConvertUtils::GetPadModEnumValue(value, &attr_value);
  if (attr_value == PadMode::VALID) {
    attr_proto->set_s("VALID");
  } else if (attr_value == PadMode::SAME) {
    attr_proto->set_s("SAME_LOWER");
  } else {  // pad_mode is 'pad', use attribute 'pad_list' to fill ONNX attribute 'pads'
    attr_proto->set_name("pads");
    SetAttrTupleValueToProto(prim->GetAttr("pad_list"), onnx::AttributeProto_AttributeType_INTS, attr_proto, prim);
  }
}

PrimitivePtr GetPrimitive(const CNodePtr &node) {
  AnfNodePtr op = node->input(kZeroNum);
  auto op_value = dyn_cast<ValueNode>(op);
  MS_EXCEPTION_IF_NULL(op_value);
  auto prim = dyn_cast<Primitive>(op_value->value());
  MS_EXCEPTION_IF_NULL(prim);
  return prim;
}

template <typename T>
T GetOpAttribute(const CNodePtr &node, const std::string &name) {
  ValuePtr attr = GetPrimitive(node)->GetAttr(name);
  return GetValue<T>(attr);
}

template <typename T>
std::shared_ptr<T> GetOpAttributePtr(const CNodePtr &node, const std::string &name) {
  ValuePtr attr = GetPrimitive(node)->GetAttr(name);
  auto result = dyn_cast<T>(attr);
  MS_EXCEPTION_IF_NULL(result);
  return result;
}

std::string MakeOutputName(const std::string &node_name, int output_index) {
  return node_name + "_" + std::to_string(output_index);
}

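// Flattens a (possibly partial) multi-dimensional index into a linear offset, with the index
// addressing the leading dimensions of `shape`. For example,
// RavelIndex({1, 2}, {3, 4, 5}) == 1 * (4 * 5) + 2 * 5 == 30.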
int64_t RavelIndex(const std::vector<int64_t> &index, const std::vector<int64_t> &shape) {
  MS_EXCEPTION_IF_CHECK_FAIL(index.size() <= shape.size(), "Index ndims must be <= shape ndims");
  int64_t result = 0;
  int64_t stride = 1;
  for (size_t i = 0; i < shape.size() - index.size(); ++i) {
    stride *= shape[shape.size() - 1 - i];
  }
  for (size_t i = 0; i < index.size(); ++i) {
    size_t rev_i = index.size() - 1 - i;
    result += index[rev_i] * stride;
    stride *= shape[rev_i];
  }
  return result;
}

namespace fp16 {
uint32_t FieldMask(unsigned int field_size) {
  const unsigned int BYTE_SIZE = 8;
  uint32_t mask = std::numeric_limits<uint32_t>::max();
  return mask >> (BYTE_SIZE * sizeof(mask) - field_size);
}

uint32_t ExponentBias(unsigned int exponent_size) { return (1U << (exponent_size - 1U)) - 1U; }

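// Repacks the IEEE 754 binary32 fields (1 sign, 8 exponent, 23 mantissa bits) into binary16
// (1 sign, 5 exponent, 10 mantissa bits) by re-biasing the exponent and truncating the
// mantissa; e.g. 1.0f (bits 0x3F800000) becomes 0x3C00. Inf/NaN and out-of-range values raise.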
uint32_t Fp32ToFp16(float value) {
  const unsigned int FP32_M = 23;
  const unsigned int FP32_E = 32 - 1 - FP32_M;
  const unsigned int FP16_M = 10;
  const unsigned int FP16_E = 16 - 1 - FP16_M;

  uint32_t fp32_bits;
  auto ret = memcpy_s(reinterpret_cast<std::byte *>(&fp32_bits), sizeof(fp32_bits),
                      reinterpret_cast<std::byte *>(&value), sizeof(value));
  if (ret != EOK) {
    MS_LOG(ERROR) << "Set data memcpy_s failed, ret = " << ret;
  }

  uint32_t mantissa = fp32_bits & FieldMask(FP32_M);
  uint32_t fp32_exp_mask = FieldMask(FP32_E);
  uint32_t fp32_exponent = (fp32_bits >> FP32_M) & fp32_exp_mask;
  if (fp32_exponent == fp32_exp_mask) {
    MS_LOG(EXCEPTION) << "Tried to convert inf or nan to float16: " << value;
  }
  uint32_t sign = fp32_bits >> (FP32_E + FP32_M);

  uint32_t fp16_bits = 0;
  fp16_bits |= sign << (FP16_E + FP16_M);
  uint32_t fp16_exponent = 0;
  if (fp32_exponent != 0) {
    fp16_exponent = fp32_exponent - ExponentBias(FP32_E) + ExponentBias(FP16_E);
  }
  if (fp16_exponent >= FieldMask(FP16_E)) {  // inf, nan (==), underflow, or overflow (>)
    MS_LOG(EXCEPTION) << "Conversion of " << value << " to float16 resulted in exponent overflow or underflow";
  }
  fp16_bits |= fp16_exponent << FP16_M;
  fp16_bits |= mantissa >> (FP32_M - FP16_M);

  return fp16_bits;
}
}  // namespace fp16

void AddFloatScalarInitializer(const std::string &name, float value, onnx::TensorProto_DataType type,
                               onnx::GraphProto *graph_proto) {
  onnx::TensorProto *initializer = graph_proto->add_initializer();
  initializer->set_name(name);
  if (type == onnx::TensorProto_DataType_FLOAT16) {
    uint32_t fp16 = fp16::Fp32ToFp16(value);
    initializer->add_int32_data(static_cast<int32_t>(fp16));
  } else if (type == onnx::TensorProto_DataType_FLOAT) {
    initializer->add_float_data(value);
  } else {
    MS_LOG(EXCEPTION) << "Unsupported type: " << type;
  }
  initializer->set_data_type(type);
}

void AddInt64Tensor1DInitializer(const std::string &name, const std::vector<int64_t> &values,
                                 onnx::GraphProto *graph_proto) {
  onnx::TensorProto *initializer = graph_proto->add_initializer();
  initializer->set_name(name);
  initializer->set_data_type(onnx::TensorProto_DataType_INT64);
  initializer->add_dims(static_cast<int64_t>(values.size()));
  for (auto value : values) {
    initializer->add_int64_data(value);
  }
}

void AddFloatTensor1DInitializer(const std::string &name, const std::vector<float> &values,
                                 onnx::TensorProto_DataType type, onnx::GraphProto *graph_proto) {
  onnx::TensorProto *initializer = graph_proto->add_initializer();
  initializer->set_name(name);
  initializer->add_dims(static_cast<int64_t>(values.size()));
  if (type == onnx::TensorProto_DataType_FLOAT16) {
    for (auto value : values) {
      uint32_t fp16 = fp16::Fp32ToFp16(value);
      initializer->add_int32_data(static_cast<int32_t>(fp16));
    }
  } else if (type == onnx::TensorProto_DataType_FLOAT) {
    for (auto value : values) {
      initializer->add_float_data(value);
    }
  } else {
    MS_LOG(EXCEPTION) << "Unsupported type: " << type;
  }
  initializer->set_data_type(type);
}

void AddOp(const std::string &type, const std::vector<std::string> &inputs, const std::vector<std::string> &outputs,
           onnx::GraphProto *graph_proto) {
  onnx::NodeProto *op = graph_proto->add_node();
  op->set_op_type(type);
  op->set_name(outputs.at(0) + type);
  for (const auto &input : inputs) {
    op->add_input(input);
  }
  for (const auto &output : outputs) {
    op->add_output(output);
  }
}

void AddClipOp(const std::string &input, const std::string &output, float min, float max,
               onnx::TensorProto_DataType type, onnx::GraphProto *graph_proto) {
  auto min_input_name = output + "__min_initializer";
  AddFloatScalarInitializer(min_input_name, min, type, graph_proto);

  auto max_input_name = output + "__max_initializer";
  AddFloatScalarInitializer(max_input_name, max, type, graph_proto);

  AddOp("Clip", {input, min_input_name, max_input_name}, {output}, graph_proto);
}

void AddSliceOp(const std::string &input, const std::string &output, const std::vector<int64_t> &start,
                const std::vector<int64_t> &end, const std::vector<int64_t> &axis, const std::vector<int64_t> &step,
                onnx::GraphProto *graph_proto) {
  auto starts_name = output + "__starts_initializer";
  AddInt64Tensor1DInitializer(starts_name, start, graph_proto);

  auto ends_name = output + "__ends_initializer";
  AddInt64Tensor1DInitializer(ends_name, end, graph_proto);

  auto axes_name = output + "__axes_initializer";
  AddInt64Tensor1DInitializer(axes_name, axis, graph_proto);

  auto steps_name = output + "__steps_initializer";
  AddInt64Tensor1DInitializer(steps_name, step, graph_proto);

  AddOp("Slice", {input, starts_name, ends_name, axes_name, steps_name}, {output}, graph_proto);
}

void AddSplitOp(const std::string &input, const std::vector<std::string> &outputs, const std::vector<int64_t> &split,
                int64_t axis, onnx::GraphProto *graph_proto) {
  if (outputs.size() != split.size()) {
    MS_LOG(EXCEPTION) << "Number of splits and number of outputs do not match";
  }

  onnx::NodeProto *split_proto = graph_proto->add_node();
  std::string op_type = "Split";
  split_proto->set_op_type(op_type);
  split_proto->set_name(outputs.at(0) + op_type);
  split_proto->add_input(input);
  for (const auto &output : outputs) {
    split_proto->add_output(output);
  }
  onnx::AttributeProto *axis_attr_proto = split_proto->add_attribute();
  axis_attr_proto->set_name("axis");
  axis_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  axis_attr_proto->set_i(axis);
  onnx::AttributeProto *split_attr_proto = split_proto->add_attribute();
  split_attr_proto->set_name("split");
  split_attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
  for (int64_t n : split) {
    split_attr_proto->add_ints(n);
  }
}

void AddExpandOp(const std::string &input, const std::string &output, const std::vector<int64_t> &shape,
                 onnx::GraphProto *graph_proto) {
  onnx::NodeProto *expand_node_proto = graph_proto->add_node();
  expand_node_proto->set_op_type("Expand");
  expand_node_proto->set_name(output + "_Expand");
  expand_node_proto->add_input(input);
  auto shape_name = output + "_expand_shape_initializer";
  AddInt64Tensor1DInitializer(shape_name, shape, graph_proto);
  expand_node_proto->add_input(shape_name);
  expand_node_proto->add_output(output);
}

void AddReshapeOp(const std::string &input, const std::string &output, const std::vector<int64_t> &shape,
                  onnx::GraphProto *graph_proto) {
  auto shape_name = output + "__shape_initializer";
  AddInt64Tensor1DInitializer(shape_name, shape, graph_proto);
  AddOp("Reshape", {input, shape_name}, {output}, graph_proto);
}

onnx::TensorProto *AddConstantOfShapeOp(const std::string &shape, const std::string &output,
                                        onnx::GraphProto *graph_proto) {
  onnx::NodeProto *op = graph_proto->add_node();
  std::string op_type = "ConstantOfShape";
  op->set_op_type(op_type);
  op->set_name(output + op_type);
  op->add_input(shape);
  op->add_output(output);
  onnx::AttributeProto *value_attr = op->add_attribute();
  value_attr->set_name("value");
  value_attr->set_type(onnx::AttributeProto_AttributeType_TENSOR);
  onnx::TensorProto *value_proto = value_attr->mutable_t();
  value_proto->add_dims(1);
  return value_proto;
}

void AddCastOp(const std::string &input, const std::string &output, onnx::TensorProto_DataType target_type,
               onnx::GraphProto *graph_proto) {
  onnx::NodeProto *node_proto = graph_proto->add_node();
  std::string op_type = "Cast";
  node_proto->set_op_type(op_type);
  node_proto->set_name(output + op_type);
  node_proto->add_input(input);
  node_proto->add_output(output);

  onnx::AttributeProto *target_type_attr = node_proto->add_attribute();
  target_type_attr->set_name("to");
  target_type_attr->set_type(onnx::AttributeProto_AttributeType_INT);
  target_type_attr->set_i(target_type);
}

void AddReduceOp(const std::string &op_type, const std::string &input, const std::string &output,
                 const std::vector<int64_t> &axes, bool keepdims, onnx::GraphProto *graph_proto) {
  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_name(output + op_type);
  node_proto->set_op_type(op_type);
  node_proto->add_input(input);
  node_proto->add_output(output);

  onnx::AttributeProto *keep_dims_proto = node_proto->add_attribute();
  keep_dims_proto->set_name("keepdims");
  keep_dims_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  keep_dims_proto->set_i(static_cast<int64_t>(keepdims));

  onnx::AttributeProto *axes_proto = node_proto->add_attribute();
  axes_proto->set_name("axes");
  axes_proto->set_type(onnx::AttributeProto_AttributeType_INTS);

  for (auto axis : axes) {
    axes_proto->add_ints(axis);
  }
}

void AddMeanVarianceNormalizationOp(const std::string &input, const std::string &gamma, const std::string &beta,
                                    const std::string &output, const std::vector<int64_t> &axes, float epsilon,
                                    const std::vector<int64_t> &input_shape, onnx::TensorProto_DataType input_type,
                                    onnx::GraphProto *graph_proto) {
  auto input_name = output + "_input";
  AddCastOp(input, input_name, onnx::TensorProto_DataType_FLOAT, graph_proto);
  auto gamma_name = output + "_gamma";
  AddCastOp(gamma, gamma_name, onnx::TensorProto_DataType_FLOAT, graph_proto);
  auto beta_name = output + "_beta";
  AddCastOp(beta, beta_name, onnx::TensorProto_DataType_FLOAT, graph_proto);

  // MeanVarianceNormalization is replaced with equivalent ops because it is not supported by CUDAExecutionProvider
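  // The op sequence below computes gamma * (x - mean) / sqrt(variance + epsilon) + beta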
  auto meanvariancenormal_node_name = output + "_normalized";

  auto mean_name = output + "_mean";
  AddReduceOp("ReduceMean", input_name, mean_name, axes, true, graph_proto);
  auto centered_name = output + "_centered";
  AddOp("Sub", {input_name, mean_name}, {centered_name}, graph_proto);

  auto sqsum_name = output + "_sqsum";
  AddReduceOp("ReduceSumSquare", centered_name, sqsum_name, axes, true, graph_proto);
  float reduce_size = std::accumulate(axes.begin(), axes.end(), 1.0f,
                                      [&input_shape](auto acc, auto axis) { return acc * input_shape[axis]; });
  auto reduce_size_name = output + "_reduce_size";
  AddFloatScalarInitializer(reduce_size_name, reduce_size, onnx::TensorProto_DataType_FLOAT, graph_proto);
  auto variance_name = output + "_variance";
  AddOp("Div", {sqsum_name, reduce_size_name}, {variance_name}, graph_proto);

  auto epsilon_name = output + "_epsilon";
  AddFloatScalarInitializer(epsilon_name, epsilon, onnx::TensorProto_DataType_FLOAT, graph_proto);
  auto variance_with_epsilon_name = output + "_variance_with_epsilon";
  AddOp("Add", {variance_name, epsilon_name}, {variance_with_epsilon_name}, graph_proto);
  auto std_name = output + "_std";
  AddOp("Sqrt", {variance_with_epsilon_name}, {std_name}, graph_proto);

  AddOp("Div", {centered_name, std_name}, {meanvariancenormal_node_name}, graph_proto);

  // Add mul and add node
  auto mul_node_name = output + "_rescaled";
  AddOp("Mul", {meanvariancenormal_node_name, gamma_name}, {mul_node_name}, graph_proto);

  // add beta
  auto add_node_name = output;
  if (input_type == onnx::TensorProto_DataType_FLOAT16) {
    add_node_name += "_shifted";
  }
  AddOp("Add", {mul_node_name, beta_name}, {add_node_name}, graph_proto);

  if (input_type == onnx::TensorProto_DataType_FLOAT16) {
    AddCastOp(add_node_name, output, onnx::TensorProto_DataType_FLOAT16, graph_proto);
  }
}

void AddConcatOp(const std::vector<std::string> &inputs, const std::string &output, int axis,
                 onnx::GraphProto *graph_proto) {
  onnx::NodeProto *concat_proto = graph_proto->add_node();
  auto op_type = "Concat";
  concat_proto->set_op_type(op_type);
  concat_proto->set_name(output + op_type);
  for (const auto &input : inputs) {
    concat_proto->add_input(input);
  }
  concat_proto->add_output(output);
  onnx::AttributeProto *axis_proto = concat_proto->add_attribute();
  axis_proto->set_name("axis");
  axis_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  axis_proto->set_i(axis);
}

void ConvertBoxesToXywh(const std::string &startpoints, const std::string &endpoints, const std::string &centerpoints,
                        const std::string &dimensions, onnx::TensorProto_DataType type, onnx::GraphProto *graph_proto) {
  auto coord_sums_name = centerpoints + "__to_div";
  AddOp("Add", {startpoints, endpoints}, {coord_sums_name}, graph_proto);
  auto two_name = centerpoints + "__two_initializer";
  AddFloatScalarInitializer(two_name, 2.0f, type, graph_proto);
  AddOp("Div", {coord_sums_name, two_name}, {centerpoints}, graph_proto);

  auto coord_diffs_name = dimensions + "__to_add";
  AddOp("Sub", {endpoints, startpoints}, {coord_diffs_name}, graph_proto);
  auto one_name = dimensions + "__one_initializer";
  AddFloatScalarInitializer(one_name, 1.0f, type, graph_proto);
  AddOp("Add", {coord_diffs_name, one_name}, {dimensions}, graph_proto);
}

void ConvertBoxesToXyxy(const std::string &centerpoints, const std::string &dimensions, const std::string &startpoints,
                        const std::string &endpoints, onnx::TensorProto_DataType type, onnx::GraphProto *graph_proto) {
  auto half_name = startpoints + "__half_initializer";
  AddFloatScalarInitializer(half_name, 0.5f, type, graph_proto);

  auto half_dim_name = startpoints + "__half_dim";
  auto half_dim_to_sub_name = startpoints + "__to_sub";
  AddOp("Mul", {dimensions, half_name}, {half_dim_to_sub_name}, graph_proto);
  AddOp("Sub", {half_dim_to_sub_name, half_name}, {half_dim_name}, graph_proto);

  AddOp("Sub", {centerpoints, half_dim_name}, {startpoints}, graph_proto);
  AddOp("Add", {centerpoints, half_dim_name}, {endpoints}, graph_proto);
}

void ClipPointsComponent(const std::string &points, const std::string &clipped, float max, int64_t component_idx,
                         onnx::TensorProto_DataType type, onnx::GraphProto *graph_proto) {
  auto res_to_clip_name = clipped + "__clip";
  AddSliceOp(points, res_to_clip_name, {component_idx}, {component_idx + 1}, {1}, {1}, graph_proto);
  AddClipOp(res_to_clip_name, clipped, 0.0f, max, type, graph_proto);
}

// Checks whether the AnfNode's data type is floating-point.
bool IsFloatDataType(const AnfNodePtr &node) {
  auto dtype = node->Type();
  auto elem_type = dyn_cast<TensorType>(dtype)->element()->type_id();
  switch (elem_type) {
    case (kNumberTypeFloat):
    case (kNumberTypeFloat16):
    case (kNumberTypeFloat32):
    case (kNumberTypeFloat64):
      return true;
    default:
      return false;
  }
}

namespace while_loop_export {
namespace {
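// These characters mark the subgraphs MindSpore generates for a while loop; a subgraph is
// identified by its name containing "construct" plus the corresponding marker
// (see IsSubgraphNameCorrect below).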
const char CONTROL_PATTERN[] = "\u21B5";
const char LOOP_BODY_PATTERN[] = "\u21BB";
const char AFTER_LOOP_PATTERN[] = "\u2193";

const size_t LOOP_BODY_INPUT = 2;
const size_t AFTER_LOOP_INPUT = 3;

bool IsSubgraphNameCorrect(const FuncGraphPtr &func_graph, const std::string &part_pattern) {
  auto name = func_graph->ToString();
  return name.find("construct") != std::string::npos && name.find(part_pattern) != std::string::npos;
}

template <typename T>
const std::shared_ptr<T> GetNodeInput(const CNodePtr &node, size_t i) {
  auto input = GetRealInput(node->input(i));
  auto result = dyn_cast<T>(input);
  if (result == nullptr) {
    MS_LOG(EXCEPTION) << "Failed to get input " << i << " of node " << node->DebugString();
  }
  return result;
}

template <typename T>
const std::shared_ptr<T> GetNodeInputValue(const CNodePtr &node, size_t i) {
  auto input = GetNodeInput<ValueNode>(node, i);
  auto result = dyn_cast<T>(input->value());
  if (result == nullptr) {
    MS_LOG(EXCEPTION) << "Failed to get a value from input " << i << " of node " << node->DebugString();
  }
  return result;
}

CNodePtr FindLoopSwitchNode(const FuncGraphPtr &control_subgraph) {
  if (!IsSubgraphNameCorrect(control_subgraph, CONTROL_PATTERN)) {
    MS_LOG(EXCEPTION) << "Expected a loop control structure";
  }
  auto lazy_call_node = GetNodeInput<CNode>(control_subgraph->get_return(), kOneNum);
  if (lazy_call_node->size() != kOneNum || !lazy_call_node->input(kZeroNum)->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "Expected a lazy call node";
  }
  auto switch_node = GetNodeInput<CNode>(lazy_call_node, kZeroNum);
  if (!switch_node->IsApply(prim::kPrimSwitch)) {
    MS_LOG(EXCEPTION) << "Expected a switch node";
  }
  return switch_node;
}

FuncGraphPtr GetSubgraph(const CNodePtr &switch_node, size_t input_index, const std::string &name_pattern) {
  auto input_node = GetNodeInput<CNode>(switch_node, input_index);
  if (!input_node->IsApply(prim::kPrimPartial)) {
    MS_LOG(EXCEPTION) << "Expected a partial node";
  }

  auto subgraph = GetNodeInputValue<FuncGraph>(input_node, kOneNum);
  if (!IsSubgraphNameCorrect(subgraph, name_pattern)) {
    MS_LOG(EXCEPTION) << "Expected a loop part: " << name_pattern;
  }

  return subgraph;
}

// The inputs of this node are the outputs of ONNX Loop
CNodePtr FindLoopRepeatNode(const FuncGraphPtr &loop_subgraph, const FuncGraphPtr &control_subgraph) {
  auto repeat_node = GetNodeInput<CNode>(loop_subgraph->return_node(), kOneNum);
  auto maybe_control_graph = GetNodeInputValue<FuncGraph>(repeat_node, kZeroNum);
  MS_EXCEPTION_IF_CHECK_FAIL(maybe_control_graph == control_subgraph, "Loop matching failed");
  return repeat_node;
}

struct LoopConditionInfo {
  int64_t begin;
  int64_t end;
  int64_t step;
};

/*
  NOTE: loop support is currently very limited, because proper condition export requires more graph surgery (copying
  the condition expression before and inside the Loop subgraph).
  The only while loop form supported currently is the one used in GNMT v2's Beam Search. Python example:
    i = begin
    while i < end:
      ...
      i += step
  To enable proper support for arbitrary while loop conditions, condition calculation should be duplicated inside the
  Loop subgraph. But exporting the same ops twice with different names is not currently supported.
*/
LoopConditionInfo TraceLoopConditionInfo(const CNodePtr &start_node, const CNodePtr &cond_node,
                                         const FuncGraphPtr &control_subgraph, const CNodePtr &loop_repeat_node) {
  MS_EXCEPTION_IF_CHECK_FAIL(cond_node->IsApply(prim::kPrimLess), "Expected Less node");

  auto counter = GetNodeInput<Parameter>(cond_node, kOneNum);
  auto end_tensor = GetNodeInputValue<tensor::Tensor>(cond_node, kTwoNum);
  MS_EXCEPTION_IF_CHECK_FAIL(end_tensor->shape_c().empty(), "Expected a scalar tensor");
  auto end = *reinterpret_cast<const int32_t *>(end_tensor->data_c());

  const auto &subgraph_args = control_subgraph->parameters();
  auto counter_input_pos = std::find(subgraph_args.begin(), subgraph_args.end(), counter) - subgraph_args.begin();

  auto begin_tensor = GetNodeInputValue<tensor::Tensor>(start_node, 1UL + static_cast<size_t>(counter_input_pos));
  MS_EXCEPTION_IF_CHECK_FAIL(begin_tensor->shape_c().empty(), "Expected a scalar tensor");
  auto begin = *reinterpret_cast<const int32_t *>(begin_tensor->data_c());

  auto increment_node = GetNodeInput<CNode>(loop_repeat_node, 1UL + static_cast<size_t>(counter_input_pos));
  MS_EXCEPTION_IF_CHECK_FAIL(increment_node->IsApply(prim::kPrimAdd), "Expected Add node");
  auto step_tensor = GetNodeInputValue<tensor::Tensor>(increment_node, kTwoNum);
  MS_EXCEPTION_IF_CHECK_FAIL(step_tensor->shape_c().empty(), "Expected a scalar tensor");
  auto step = *reinterpret_cast<const int32_t *>(step_tensor->data_c());

  return LoopConditionInfo{begin, end, step};
}

// result[i] is which control subgraph input should be taken for pos i to match the order of loop subgraph inputs
std::vector<size_t> TraceLoopToControlMap(const FuncGraphPtr &control_subgraph) {
  std::vector<size_t> result;

  auto switch_node = FindLoopSwitchNode(control_subgraph);
  auto loop_partial_node = GetNodeInput<CNode>(switch_node, kTwoNum);
  const auto &control_params = control_subgraph->parameters();
  int64_t auxiliary_inputs_num = 2;
  for (size_t i = static_cast<size_t>(auxiliary_inputs_num); i < loop_partial_node->size(); ++i) {
    auto loop_param = GetNodeInput<Parameter>(loop_partial_node, i);
    auto control_param_pos =
      std::find(control_params.begin(), control_params.end(), loop_param) - control_params.begin();
    result.push_back(control_param_pos);
  }

  return result;
}

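// result[i] is which loop subgraph input position the after-loop subgraph input at pos i is
// taken from (mirroring TraceLoopToControlMap above)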
std::vector<size_t> TraceAfterToLoopMap(const FuncGraphPtr &control_subgraph) {
  std::vector<size_t> result;

  auto switch_node = FindLoopSwitchNode(control_subgraph);
  auto loop_partial_node = GetNodeInput<CNode>(switch_node, kTwoNum);
  auto after_partial_node = GetNodeInput<CNode>(switch_node, kThreeNum);
  const auto &loop_params = loop_partial_node->inputs();
  int64_t auxiliary_inputs_num = 2;
  for (size_t i = static_cast<size_t>(auxiliary_inputs_num); i < after_partial_node->size(); ++i) {
    auto after_param = GetNodeInput<Parameter>(after_partial_node, i);
    auto after_param_pos = std::find(loop_params.begin(), loop_params.end(), after_param) - loop_params.begin();
    result.push_back(after_param_pos - auxiliary_inputs_num);
  }

  return result;
}

std::vector<bool> TraceIgnoredLoopParams(const CNodePtr &start_node, const std::vector<size_t> &loop_to_control_map) {
  auto inputs_num = start_node->size() - 1;
  std::vector<bool> result(inputs_num);
  for (size_t loop_i = 0; loop_i < inputs_num; ++loop_i) {
    auto control_i = loop_to_control_map.at(loop_i);
    const auto &input = start_node->input(control_i + 1);
    if ((input->isa<Parameter>() && input->cast<ParameterPtr>()->has_default()) || HasAbstractMonad(input)) {
      result.at(loop_i) = true;
    }
  }
  return result;
}
}  // namespace

bool IsControlSubgraph(const ValuePtr &func_graph_node) {
  auto func_graph = dyn_cast<FuncGraph>(func_graph_node);
  return func_graph != nullptr && IsSubgraphNameCorrect(func_graph, CONTROL_PATTERN);
}

bool IsLoopBodyReturnNode(const CNodePtr &node, const FuncGraphPtr &func_graph) {
  return IsSubgraphNameCorrect(func_graph, LOOP_BODY_PATTERN) && node == func_graph->get_return();
}

bool IsAfterLoopReturnNode(const CNodePtr &node, const FuncGraphPtr &func_graph) {
  return IsSubgraphNameCorrect(func_graph, AFTER_LOOP_PATTERN) && node == func_graph->get_return();
}

struct LoopParts {
  LoopConditionInfo loop_condition_info;
  std::vector<std::pair<size_t, size_t>> after_param_to_output_indices;
  std::vector<size_t> ignored_loop_param_indices;
  std::vector<std::pair<size_t, size_t>> used_loop_to_control_param_indices;
  CNodePtr repeat_node;
  FuncGraphPtr loop_subgraph;
  FuncGraphPtr after_loop_subgraph;
};

LoopParts MatchGraph(const CNodePtr &start_node) {
  LoopParts result;

  auto control_subgraph_value = dyn_cast<ValueNode>(start_node->input(0));
  MS_EXCEPTION_IF_NULL(control_subgraph_value);
  auto control_subgraph = dyn_cast<FuncGraph>(control_subgraph_value->value());
  MS_EXCEPTION_IF_NULL(control_subgraph);

  auto switch_node = FindLoopSwitchNode(control_subgraph);
  auto cond_node = GetNodeInput<CNode>(switch_node, kOneNum);

  result.loop_subgraph = GetSubgraph(switch_node, LOOP_BODY_INPUT, LOOP_BODY_PATTERN);

  result.repeat_node = FindLoopRepeatNode(result.loop_subgraph, control_subgraph);
  result.loop_condition_info = TraceLoopConditionInfo(start_node, cond_node, control_subgraph, result.repeat_node);

  result.after_loop_subgraph = GetSubgraph(switch_node, AFTER_LOOP_INPUT, AFTER_LOOP_PATTERN);

  auto loop_to_control_order_map = TraceLoopToControlMap(control_subgraph);
  auto ignored_loop_params_mask = TraceIgnoredLoopParams(start_node, loop_to_control_order_map);
  auto loop_inputs_num = start_node->size() - 1;
  for (size_t i = 0; i < loop_inputs_num; ++i) {
    if (ignored_loop_params_mask.at(i)) {
      result.ignored_loop_param_indices.push_back(i);
    } else {
      result.used_loop_to_control_param_indices.push_back(std::make_pair(i, loop_to_control_order_map.at(i)));
    }
  }

  auto after_to_loop_order_map = TraceAfterToLoopMap(control_subgraph);
  for (size_t after_i = 0; after_i < result.after_loop_subgraph->parameters().size(); ++after_i) {
    auto loop_i = after_to_loop_order_map.at(after_i);
    if (!ignored_loop_params_mask.at(loop_i)) {
      auto output_i = loop_i;
      for (size_t i = 0; i < loop_i; ++i) {
        output_i -= static_cast<size_t>(ignored_loop_params_mask.at(i));
      }
      result.after_param_to_output_indices.push_back(std::make_pair(after_i, output_i));
    }
  }

  return result;
}
}  // namespace while_loop_export

class OpAttrInfo {
 public:
  OpAttrInfo(const std::string &attr_name, const std::string &onnx_attr_name,
             onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType &fn_gen_attr)
      : attr_name_(attr_name),
        onnx_attr_name_(onnx_attr_name),
        onnx_attr_type_(onnx_attr_type),
        fn_gen_attr_(fn_gen_attr) {}
  ~OpAttrInfo() {}

  const std::string &attr_name() const { return attr_name_; }
  const std::string &onnx_attr_name() const { return onnx_attr_name_; }
  onnx::AttributeProto_AttributeType onnx_attr_type() const { return onnx_attr_type_; }
  GenAttrFuncType fn_gen_attr() const { return fn_gen_attr_; }

 private:
  std::string attr_name_;                              // attribute name of MindSpore
  std::string onnx_attr_name_;                         // corresponding attribute name of ONNX
  onnx::AttributeProto_AttributeType onnx_attr_type_;  // corresponding attribute type of ONNX
  GenAttrFuncType fn_gen_attr_;                        // function used to convert the attribute
};

struct InputConversion {
  int input_index;
  onnx::TensorProto_DataType input_type;
  onnx::TensorProto_DataType target_type;
};

struct OutputConversion {
  int output_index;
  enum class Mode { FIXED, INPUT } mode;
  union {
    onnx::TensorProto_DataType target_type;
    int input_with_matching_type;
  };
};

class OpNameInfo {
 public:
  OpNameInfo &set_op_type(const std::string &op_type) {
    op_type_ = op_type;
    return *this;
  }

  const std::string &op_type() const { return op_type_; }

  OpNameInfo &set_onnx_type(const std::string &onnx_type) {
    onnx_type_ = onnx_type;
    return *this;
  }

  const std::string &onnx_type() const { return onnx_type_; }

  OpNameInfo &Attr(const std::string &attr_name, const std::string &onnx_attr_name,
                   onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType &fn_gen_attr) {
    (void)op_attrs_.emplace_back(OpAttrInfo(attr_name, onnx_attr_name, onnx_attr_type, fn_gen_attr));
    return *this;
  }

  const std::vector<OpAttrInfo> &op_attrs() const { return op_attrs_; }

  const std::vector<InputConversion> &input_casts() const { return input_casts_; }

  OpNameInfo &CastInput(int input_index, onnx::TensorProto_DataType input_type,
                        onnx::TensorProto_DataType target_type) {
    input_casts_.push_back({input_index, input_type, target_type});
    return *this;
  }

  const std::vector<OutputConversion> &output_casts() const { return output_casts_; }

  OpNameInfo &CastOutputToFixedType(onnx::TensorProto_DataType type, int output_index = 0) {
    output_casts_.push_back({output_index, OutputConversion::Mode::FIXED, {type}});
    return *this;
  }

  OpNameInfo &CastOutputToInputType(int input_index, int output_index = 0) {
    auto rule = OutputConversion{output_index, OutputConversion::Mode::INPUT};
    rule.input_with_matching_type = input_index;
    output_casts_.push_back(rule);
    return *this;
  }

  int num_outputs() const { return num_outputs_; }

  OpNameInfo &set_num_outputs(int n) {
    num_outputs_ = n;
    return *this;
  }

 private:
  std::string op_type_;                         // operator type of MindSpore
  std::string onnx_type_;                       // corresponding ONNX operator type
  std::vector<OpAttrInfo> op_attrs_;            // operator attributes map info
  std::vector<InputConversion> input_casts_;    // if input input_index has type input_type, cast it to target_type
  std::vector<OutputConversion> output_casts_;  // cast output output_index to a fixed type or to an input's type
  int num_outputs_ = 1;
};

#define OPERATOR_ONNX_CONVERT_DEFINE(name, onnx_name, impl) \
  OpNameInfo GetOpOnnxConvertInfo_##name() { return impl.set_op_type(#name).set_onnx_type(#onnx_name); }
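// For example, OPERATOR_ONNX_CONVERT_DEFINE(ReLU, Relu, OpNameInfo()) defines
// GetOpOnnxConvertInfo_ReLU(), returning OpNameInfo().set_op_type("ReLU").set_onnx_type("Relu").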

OPERATOR_ONNX_CONVERT_DEFINE(Mod, Mod, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Add, Add, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Mul, Mul, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Pow, Pow, OpNameInfo())

OPERATOR_ONNX_CONVERT_DEFINE(ReLU, Relu, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Sigmoid, Sigmoid, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Sin, Sin, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Round, Round, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Div, Div, OpNameInfo())

OPERATOR_ONNX_CONVERT_DEFINE(Flatten, Flatten, OpNameInfo())

OPERATOR_ONNX_CONVERT_DEFINE(
  Conv2D, Conv,
  OpNameInfo()
    .Attr("dilation", "dilations", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)
    .Attr("group", "group", onnx::AttributeProto_AttributeType_INT, SetAttrValueToProto<Int64Imm>)
    .Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<0>)
    .Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetConvPadding)
    .Attr("stride", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>))
OPERATOR_ONNX_CONVERT_DEFINE(
  Conv3D, Conv,
  OpNameInfo()
    .Attr("dilations", "dilations", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<kTwoNum>)
    .Attr("group", "group", onnx::AttributeProto_AttributeType_INT, SetAttrValueToProto<Int64Imm>)
    .Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<0>)
    .Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetConvPadding)
    .Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<kTwoNum>))
OPERATOR_ONNX_CONVERT_DEFINE(
  Conv3DTranspose, ConvTranspose,
  OpNameInfo()
    .Attr("dilations", "dilations", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<kTwoNum>)
    .Attr("group", "group", onnx::AttributeProto_AttributeType_INT, SetAttrValueToProto<Int64Imm>)
    .Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<0>)
    .Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetConvTransposePadding)
    .Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<kTwoNum>)
    .Attr("output_padding", "output_padding", onnx::AttributeProto_AttributeType_INTS,
          SetAttrTupleValueToProto<kTwoNum>))

OPERATOR_ONNX_CONVERT_DEFINE(DepthToSpace, DepthToSpace,
                             OpNameInfo().Attr("block_size", "blocksize", onnx::AttributeProto_AttributeType_INT,
                                               SetAttrValueToProto<Int64Imm>))

OPERATOR_ONNX_CONVERT_DEFINE(BiasAdd, Add, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(MatMul, Gemm,
                             OpNameInfo()
                               .Attr("transpose_a", "transA", onnx::AttributeProto_AttributeType_INT,
                                     SetAttrValueToProto<BoolImm>)
                               .Attr("transpose_b", "transB", onnx::AttributeProto_AttributeType_INT,
                                     SetAttrValueToProto<BoolImm>))

OPERATOR_ONNX_CONVERT_DEFINE(BatchNorm, BatchNormalization,
                             OpNameInfo()
                               .Attr("epsilon", "epsilon", onnx::AttributeProto_AttributeType_FLOAT,
                                     SetAttrValueToProto<FP32Imm>)
                               .CastInput(0, onnx::TensorProto_DataType_FLOAT16, onnx::TensorProto_DataType_FLOAT)
                               .CastOutputToInputType(0))

OPERATOR_ONNX_CONVERT_DEFINE(Reshape, Reshape, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Cast, Cast, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(PReLU, PRelu, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Argmax, ArgMax,
                             OpNameInfo()
                               .Attr("axis", "axis", onnx::AttributeProto_AttributeType_INT,
                                     SetAttrValueToProto<Int64Imm>)
                               .Attr("", "keepdims", onnx::AttributeProto_AttributeType_INT,
                                     [](ValuePtr, onnx::AttributeProto_AttributeType,
                                        onnx::AttributeProto *const attr_proto, const PrimitivePtr &) {
                                       attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
                                       attr_proto->set_i(0);
                                     })
                               .CastOutputToFixedType(onnx::TensorProto_DataType_INT32))

OPERATOR_ONNX_CONVERT_DEFINE(SimpleMean, AveragePool, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(
  MaxPool, MaxPool,
  OpNameInfo()
    .Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)
    .Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode)
    .Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>))

OPERATOR_ONNX_CONVERT_DEFINE(
  MaxPool3D, MaxPool,
  OpNameInfo()
    .Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)
    .Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode)
    .Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>))

OPERATOR_ONNX_CONVERT_DEFINE(
  MaxPoolWithArgmax, MaxPool,
  OpNameInfo()
    .Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)
    .Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode)
    .Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>))

OPERATOR_ONNX_CONVERT_DEFINE(
  AvgPool, AveragePool,
  OpNameInfo()
    .Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)
    .Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode)
    .Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>))

OPERATOR_ONNX_CONVERT_DEFINE(Gather, Gather, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(MakeTuple, SequenceConstruct, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(RealDiv, Div, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Sub, Sub, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Neg, Neg, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Maximum, Max,
                             OpNameInfo()
                               .CastInput(0, onnx::TensorProto_DataType_INT32, onnx::TensorProto_DataType_FLOAT)
                               .CastInput(1, onnx::TensorProto_DataType_INT32, onnx::TensorProto_DataType_FLOAT)
                               .CastOutputToInputType(0))
OPERATOR_ONNX_CONVERT_DEFINE(Minimum, Min,
                             OpNameInfo()
                               .CastInput(0, onnx::TensorProto_DataType_INT32, onnx::TensorProto_DataType_FLOAT)
                               .CastInput(1, onnx::TensorProto_DataType_INT32, onnx::TensorProto_DataType_FLOAT)
                               .CastOutputToInputType(0))
OPERATOR_ONNX_CONVERT_DEFINE(Transpose, Transpose, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Exp, Exp, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Softplus, Softplus, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Tanh, Tanh, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Abs, Abs, OpNameInfo())

// MindSpore Softmax axis(int, Tuple)
OPERATOR_ONNX_CONVERT_DEFINE(Softmax, Softmax,
                             OpNameInfo().Attr("axis", "axis", onnx::AttributeProto_AttributeType_INT,
                                               SetAttrTupleValueToProto<0>))

// MindSpore LogSoftmax axis(int)
OPERATOR_ONNX_CONVERT_DEFINE(LogSoftmax, LogSoftmax,
                             OpNameInfo().Attr("axis", "axis", onnx::AttributeProto_AttributeType_INT,
                                               SetAttrValueToProto<Int64Imm>))

OPERATOR_ONNX_CONVERT_DEFINE(Softsign, Softsign, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Sqrt, Sqrt, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Equal, Equal, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Floor, Floor, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(ACos, Acos, OpNameInfo())

OPERATOR_ONNX_CONVERT_DEFINE(GatherNd, GatherND,
                             OpNameInfo().CastInput(1, onnx::TensorProto_DataType_INT32,
                                                    onnx::TensorProto_DataType_INT64))
OPERATOR_ONNX_CONVERT_DEFINE(Select, Where, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Log, Log, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Greater, Greater, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(LogicalAnd, And, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(LogicalOr, Or, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(ReverseSequence, ReverseSequence,
                             OpNameInfo()
                               .Attr("seq_dim", "time_axis", onnx::AttributeProto_AttributeType_INT,
                                     SetAttrValueToProto<Int64Imm>)
                               .Attr("batch_dim", "batch_axis", onnx::AttributeProto_AttributeType_INT,
                                     SetAttrValueToProto<Int64Imm>)
                               .CastInput(1, onnx::TensorProto_DataType_INT32, onnx::TensorProto_DataType_INT64))
OPERATOR_ONNX_CONVERT_DEFINE(Less, Less, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(TensorScatterUpdate, ScatterND,
                             OpNameInfo().CastInput(1, onnx::TensorProto_DataType_INT32,
                                                    onnx::TensorProto_DataType_INT64))
OPERATOR_ONNX_CONVERT_DEFINE(Cos, Cos, OpNameInfo())
OPERATOR_ONNX_CONVERT_DEFINE(Atan2, Atan2, OpNameInfo())

#define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name

void RegisterOpConverters(const std::function<void(OpNameInfo &&)> &fn) {
  fn(OP_CONVERT_FUNCTION_NAME(Mod)());
  fn(OP_CONVERT_FUNCTION_NAME(DepthToSpace)());
  fn(OP_CONVERT_FUNCTION_NAME(Add)());
  fn(OP_CONVERT_FUNCTION_NAME(Mul)());
  fn(OP_CONVERT_FUNCTION_NAME(Pow)());
  fn(OP_CONVERT_FUNCTION_NAME(ReLU)());
  fn(OP_CONVERT_FUNCTION_NAME(Sigmoid)());
  fn(OP_CONVERT_FUNCTION_NAME(Conv2D)());
  fn(OP_CONVERT_FUNCTION_NAME(Conv3D)());
  fn(OP_CONVERT_FUNCTION_NAME(Conv3DTranspose)());
  fn(OP_CONVERT_FUNCTION_NAME(Argmax)());
  fn(OP_CONVERT_FUNCTION_NAME(Flatten)());
  fn(OP_CONVERT_FUNCTION_NAME(MaxPool)());
  fn(OP_CONVERT_FUNCTION_NAME(MaxPool3D)());
  fn(OP_CONVERT_FUNCTION_NAME(MaxPoolWithArgmax)());
  fn(OP_CONVERT_FUNCTION_NAME(AvgPool)());

  fn(OP_CONVERT_FUNCTION_NAME(BatchNorm)());
  fn(OP_CONVERT_FUNCTION_NAME(MatMul)());
  fn(OP_CONVERT_FUNCTION_NAME(MakeTuple)());
  fn(OP_CONVERT_FUNCTION_NAME(RealDiv)());
  fn(OP_CONVERT_FUNCTION_NAME(BiasAdd)());
  fn(OP_CONVERT_FUNCTION_NAME(Sub)());
  fn(OP_CONVERT_FUNCTION_NAME(Neg)());
  fn(OP_CONVERT_FUNCTION_NAME(Maximum)());
  fn(OP_CONVERT_FUNCTION_NAME(Minimum)());
  fn(OP_CONVERT_FUNCTION_NAME(Exp)());

  fn(OP_CONVERT_FUNCTION_NAME(Softplus)());
  fn(OP_CONVERT_FUNCTION_NAME(Tanh)());
  fn(OP_CONVERT_FUNCTION_NAME(Softmax)());
  fn(OP_CONVERT_FUNCTION_NAME(LogSoftmax)());
  fn(OP_CONVERT_FUNCTION_NAME(Abs)());
  fn(OP_CONVERT_FUNCTION_NAME(Softsign)());
  fn(OP_CONVERT_FUNCTION_NAME(Sqrt)());
  fn(OP_CONVERT_FUNCTION_NAME(Equal)());
  fn(OP_CONVERT_FUNCTION_NAME(Floor)());
  fn(OP_CONVERT_FUNCTION_NAME(ACos)());

  fn(OP_CONVERT_FUNCTION_NAME(GatherNd)());
  fn(OP_CONVERT_FUNCTION_NAME(Select)());
  fn(OP_CONVERT_FUNCTION_NAME(Log)());
  fn(OP_CONVERT_FUNCTION_NAME(Less)());
  fn(OP_CONVERT_FUNCTION_NAME(Greater)());
  fn(OP_CONVERT_FUNCTION_NAME(LogicalAnd)());
  fn(OP_CONVERT_FUNCTION_NAME(LogicalOr)());
  fn(OP_CONVERT_FUNCTION_NAME(ReverseSequence)());
  fn(OP_CONVERT_FUNCTION_NAME(TensorScatterUpdate)());

  fn(OP_CONVERT_FUNCTION_NAME(Sin)());
  fn(OP_CONVERT_FUNCTION_NAME(Cos)());
  fn(OP_CONVERT_FUNCTION_NAME(Atan2)());
  fn(OP_CONVERT_FUNCTION_NAME(Round)());
  fn(OP_CONVERT_FUNCTION_NAME(Div)());
}

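// Singleton registry mapping a MindSpore op type to its OpNameInfo converter; populated once
// via RegisterAllOpConverters() and queried through GetOpConvertMap().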
1143 class OpConvertRegistry {
1144 public:
~OpConvertRegistry()1145 ~OpConvertRegistry() { Clear(); }
1146
RegisterOneOpConverter(OpNameInfo && op_info)1147 static void RegisterOneOpConverter(OpNameInfo &&op_info) { GetSingleton().op_map_[op_info.op_type()] = op_info; }
1148
RegisterAllOpConverters()1149 static void RegisterAllOpConverters() { RegisterOpConverters(RegisterOneOpConverter); }
1150
GetSingleton()1151 static OpConvertRegistry &GetSingleton() {
1152 static OpConvertRegistry registry = OpConvertRegistry();
1153 return registry;
1154 }
1155
GetOpConvertMap()1156 static const mindspore::HashMap<std::string, OpNameInfo> &GetOpConvertMap() { return GetSingleton().op_map_; }
1157
Clear()1158 void Clear() noexcept { op_map_.clear(); }
1159
1160 private:
OpConvertRegistry()1161 OpConvertRegistry() {}
1162
1163 mindspore::HashMap<std::string, OpNameInfo> op_map_;
1164 };
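
/*
  Illustrative usage sketch (not called anywhere in this file): converters are
  registered once per export and then looked up by MindSpore op type. Assuming
  "Conv2D" is the registered MindSpore name for the convolution converter:

    OpConvertRegistry::RegisterAllOpConverters();
    const auto &convert_map = OpConvertRegistry::GetOpConvertMap();
    auto it = convert_map.find("Conv2D");
    if (it != convert_map.end()) {
      // it->second is the OpNameInfo describing the target ONNX op type,
      // attribute conversions and input casts for this primitive.
    }
*/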

class OnnxExporter {
 public:
  OnnxExporter() {}
  ~OnnxExporter() {}

  std::string GetOnnxProtoString(const FuncGraphPtr &func_graph);

 private:
  void InitModelInfo();

  void ExportFuncGraph(const FuncGraphPtr &func_graph, std::map<AnfNodePtr, std::string> *node_map_ptr,
                       onnx::GraphProto *graph_proto, bool export_inputs = true);
  void ExportInputs(const FuncGraphPtr &func_graph, std::map<AnfNodePtr, std::string> *node_map_ptr,
                    onnx::GraphProto *graph_proto);

  std::string ExportPrimitive(const FuncGraphPtr &func_graph, std::map<AnfNodePtr, std::string> *node_map_ptr,
                              const PrimitivePtr &prim, const std::vector<AnfNodePtr> &inputs,
                              onnx::GraphProto *graph_proto);

  static onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id);
  static onnx::TensorProto_DataType GetOutputType(const AnfNodePtr &node, int64_t output_index = -1);
  void SetValueInfoType(const AnfNodePtr &node, onnx::ValueInfoProto *value_proto, int64_t output_index = -1) const;

  void MatchAndMark(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &nodes,
                    mindspore::HashMap<AnfNodePtr, OpMergedInfo> *op_merged_infos_ptr) const;
  void MatchAndMarkCNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                         mindspore::HashMap<AnfNodePtr, OpMergedInfo> *op_merged_infos_ptr) const;
  void IgnoreMakeTuple(const AnfNodePtr &node, mindspore::HashMap<AnfNodePtr, OpMergedInfo> *op_merged_infos_ptr) const;

  void ExportNodes(const FuncGraphPtr &func_graph, std::map<AnfNodePtr, std::string> *node_map_ptr,
                   onnx::GraphProto *graph_proto);

  void ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
                   std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportWhileLoop(const CNodePtr &start_node, std::map<AnfNodePtr, std::string> *node_map_ptr,
                       onnx::GraphProto *graph_proto);

  void ExportPrimReshape(const FuncGraphPtr &func_graph, const CNodePtr &node,
                         std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimReduce(const FuncGraphPtr &func_graph, const CNodePtr &node,
                        std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimReduceAnyOrAll(const FuncGraphPtr &func_graph, const CNodePtr &node,
                                std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimTranspose(const FuncGraphPtr &func_graph, const CNodePtr &node,
                           std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimStridedSlice(const FuncGraphPtr &func_graph, const CNodePtr &node,
                              std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  onnx::NodeProto *PrimResizeExportHelper(const FuncGraphPtr &, const CNodePtr &node,
                                          std::map<AnfNodePtr, std::string> *node_map_ptr,
                                          onnx::GraphProto *const graph_proto);
  void ExportPrimResizeNearestNeighbor(const FuncGraphPtr &func_graph, const CNodePtr &node,
                                       std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimResizeBilinear(const FuncGraphPtr &func_graph, const CNodePtr &node,
                                std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimExpandDims(const FuncGraphPtr &func_graph, const CNodePtr &node,
                            std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimGatherD(const FuncGraphPtr &func_graph, const CNodePtr &node,
                         std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimPad(const FuncGraphPtr &func_graph, const CNodePtr &node,
                     std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimBatchMatMul(const FuncGraphPtr &func_graph, const CNodePtr &node,
                             std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimBroadcastTo(const FuncGraphPtr &func_graph, const CNodePtr &node,
                             std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimAddN(const FuncGraphPtr &func_graph, const CNodePtr &node,
                      std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimGeLU(const FuncGraphPtr &func_graph, const CNodePtr &node,
                      std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimConcat(const FuncGraphPtr &func_graph, const CNodePtr &node,
                        std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimCast(const FuncGraphPtr &func_graph, const CNodePtr &node,
                      std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimPReLU(const FuncGraphPtr &func_graph, const CNodePtr &node,
                       std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimReLU6(const FuncGraphPtr &func_graph, const CNodePtr &node,
                       std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimDepthwiseConv2d(const FuncGraphPtr &func_graph, const CNodePtr &node,
                                 std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimTile(const FuncGraphPtr &func_graph, const CNodePtr &node,
                      std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimSquare(const FuncGraphPtr &func_graph, const CNodePtr &node,
                        std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimGatherV2(const FuncGraphPtr &func_graph, const CNodePtr &node,
                          std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimTupleGetItem(const FuncGraphPtr &func_graph, const CNodePtr &node,
                              std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimTopK(const FuncGraphPtr &func_graph, const CNodePtr &node,
                      std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimBoundingBoxDecode(const FuncGraphPtr &func_graph, const CNodePtr &node,
                                   std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimNMSWithMask(const FuncGraphPtr &func_graph, const CNodePtr &node,
                             std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimSplit(const FuncGraphPtr &func_graph, const CNodePtr &node,
                       std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimROIAlign(const FuncGraphPtr &func_graph, const CNodePtr &node,
                          std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimSlice(const FuncGraphPtr &func_graph, const CNodePtr &node,
                       std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimOnesLike(const FuncGraphPtr &func_graph, const CNodePtr &node,
                          std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimScatterNd(const FuncGraphPtr &func_graph, const CNodePtr &node,
                           std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimArgMaxWithValue(const FuncGraphPtr &func_graph, const CNodePtr &node,
                                 std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimArgMinWithValue(const FuncGraphPtr &func_graph, const CNodePtr &node,
                                 std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimOneHot(const FuncGraphPtr &func_graph, const CNodePtr &node,
                        std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void PrimConv2DTransposeExportHelper(const CNodePtr &conv_node, const CNodePtr &bias_add_node,
                                       std::map<AnfNodePtr, std::string> *node_map_ptr,
                                       onnx::GraphProto *const graph_proto);
  void ExportPrimConv2DTranspose(const FuncGraphPtr &func_graph, const CNodePtr &node,
                                 std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimGreaterEqual(const FuncGraphPtr &func_graph, const CNodePtr &node,
                              std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimLessEqual(const FuncGraphPtr &func_graph, const CNodePtr &node,
                           std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimNotEqual(const FuncGraphPtr &func_graph, const CNodePtr &node,
                          std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimDense(const FuncGraphPtr &func_graph, const CNodePtr &node,
                       std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimSqueeze(const FuncGraphPtr &func_graph, const CNodePtr &node,
                         std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimDynamicRNN(const FuncGraphPtr &func_graph, const CNodePtr &node,
                            std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *const graph_proto);
  void ExportPrimLSTM(const FuncGraphPtr &, const CNodePtr &node, std::map<AnfNodePtr, std::string> *node_map_ptr,
                      onnx::GraphProto *graph_proto);
  void ExportPrimReverseV2(const FuncGraphPtr &func_graph, const CNodePtr &node,
                           std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimTensorCopySlices(const FuncGraphPtr &, const CNodePtr &node,
                                  std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportPrimStack(const FuncGraphPtr &, const CNodePtr &node, std::map<AnfNodePtr, std::string> *node_map_ptr,
                       onnx::GraphProto *graph_proto);
  void ExportPrimAtan2(const FuncGraphPtr &, const CNodePtr &node, std::map<AnfNodePtr, std::string> *node_map_ptr,
                       onnx::GraphProto *graph_proto);
  void ExportPrimFloorDiv(const FuncGraphPtr &, const CNodePtr &node, std::map<AnfNodePtr, std::string> *node_map_ptr,
                          onnx::GraphProto *graph_proto);
  void ExportPrimFloorMod(const FuncGraphPtr &, const CNodePtr &node, std::map<AnfNodePtr, std::string> *node_map_ptr,
                          onnx::GraphProto *graph_proto);
  void ExportPrimSort(const FuncGraphPtr &, const CNodePtr &node, std::map<AnfNodePtr, std::string> *node_map_ptr,
                      onnx::GraphProto *graph_proto);
  void ExportPrimCustom(const FuncGraphPtr &, const CNodePtr &node, std::map<AnfNodePtr, std::string> *node_map_ptr,
                        onnx::GraphProto *graph_proto);
  void ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node,
                       std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportMergeGemm(const FuncGraphPtr &func_graph, const CNodePtr &node,
                       std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node,
                            std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportMergeMaxPoolWithArgmax(const FuncGraphPtr &func_graph, const CNodePtr &node,
                                    std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportMergeLayerNorm(const FuncGraphPtr &func_graph, const CNodePtr &node,
                            std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportMergeConv2DTranspose(const FuncGraphPtr &, const CNodePtr &node,
                                  std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  void ExportMergeDynamicGRUV2(const FuncGraphPtr &, const CNodePtr &node,
                               std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *const graph_proto);
  void ExportOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &return_arg,
                    std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto);
  std::string GetNodeInputName(const AnfNodePtr &node, std::map<AnfNodePtr, std::string> *node_map_ptr,
                               onnx::GraphProto *const);

  void ConvertTupleToTensor(const ValuePtr &value, onnx::TensorProto *tensor_proto) const;
  void SetTensorData(const ValuePtr &value, onnx::TensorProto *tensor_proto);

  void AddOutputWithCast(onnx::NodeProto *node_proto, const std::string &output_name,
                         onnx::TensorProto_DataType target_type, onnx::GraphProto *graph_proto) const;

  std::string GenerateUniqueName() { return std::to_string(++onnx_node_index_); }
  std::string RegisterNodeWithUniqueName(const AnfNodePtr &node, std::map<AnfNodePtr, std::string> *node_map_ptr) {
    auto name = GenerateUniqueName();
    (*node_map_ptr)[node] = name;
    return name;
  }
  std::string GenerateUniqueParameterName(const ParameterPtr &node, std::map<AnfNodePtr, std::string> *node_map_ptr) {
    auto node_name = node->ToString();
    MS_EXCEPTION_IF_CHECK_FAIL(node_name != "", "Cannot get the name of an ignored parameter");
    auto dup_iter = std::find_if(node_map_ptr->begin(), node_map_ptr->end(),
                                 [&node_name](const auto &pair) { return pair.second == node_name; });
    if (dup_iter != node_map_ptr->end()) {
      node_name = GenerateUniqueName() + node_name;
    }
    return node_name;
  }

  void ResetNodeIndex() { onnx_node_index_ = 0; }

  static int64_t GetInt64Value(const AnfNodePtr &node) {
    auto value_node_ptr = dyn_cast<ValueNode>(node);
    MS_EXCEPTION_IF_NULL(value_node_ptr);
    return GetValue<int64_t>(value_node_ptr->value());
  }

  onnx::ModelProto model_;

  size_t onnx_node_index_ = 0;

  std::map<AnfNodePtr, std::string> renamed_node_map_;
};

std::string OnnxExporter::GetOnnxProtoString(const FuncGraphPtr &func_graph) {
  if (func_graph == nullptr) {
    return "";
  }
  ResetNodeIndex();
  OpConvertRegistry::GetSingleton().Clear();
  OpConvertRegistry::RegisterAllOpConverters();
  InitModelInfo();
  onnx::GraphProto *graph_proto = model_.mutable_graph();
  std::map<AnfNodePtr, std::string> node_map;
  ExportFuncGraph(func_graph, &node_map, graph_proto);
  return model_.SerializeAsString();
}
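
/*
  Minimal driving sketch (assumes `func_graph` was obtained elsewhere, e.g.
  from a compiled MindSpore network; not part of this file):

    OnnxExporter exporter;
    std::string model_bytes = exporter.GetOnnxProtoString(func_graph);
    // `model_bytes` is a serialized onnx::ModelProto targeting opset 11
    // (see InitModelInfo below) and can be written to an .onnx file verbatim.
*/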

void OnnxExporter::InitModelInfo() {
  model_.set_ir_version(onnx::IR_VERSION_2019_1_22);
  model_.set_producer_name("MindSpore");
  model_.set_producer_version("1.0");
  onnx::OperatorSetIdProto *opset_proto = model_.add_opset_import();
  opset_proto->set_version(ONNX_VERSION);
}

void OnnxExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, std::map<AnfNodePtr, std::string> *node_map_ptr,
                                   onnx::GraphProto *const graph_proto, bool export_inputs) {
  MS_LOG(INFO) << "Begin exporting onnx model for graph " << func_graph->ToString();

  // Convert YAML-defined primitives to old-style primitives.
  std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  for (const auto &node : nodes) {
    if (node->isa<CNode>()) {
      auto converted_node = ops::ConvertArgsToAttr(node->cast<CNodePtr>());
      // If the node is already an old-style primitive node, nullptr is returned.
      if (converted_node == nullptr) {
        continue;
      }
      (void)manager->Replace(node, converted_node);
    }
  }
  // set graph name
  graph_proto->set_name(func_graph->ToString());

  // export inputs if the graph is not inlined
  if (export_inputs) {
    ExportInputs(func_graph, node_map_ptr, graph_proto);
  }

  // export computational nodes and output nodes
  ExportNodes(func_graph, node_map_ptr, graph_proto);

  // add names for easier debugging
  for (auto &node : *graph_proto->mutable_node()) {
    if (!node.has_name()) {
      node.set_name(node.output(0) + node.op_type());
    }
  }

  MS_LOG(INFO) << "End exporting onnx model for graph " << func_graph->ToString();
}

void OnnxExporter::ExportInputs(const FuncGraphPtr &func_graph, std::map<AnfNodePtr, std::string> *node_map_ptr,
                                onnx::GraphProto *const graph_proto) {
  for (auto &param : func_graph->parameters()) {
    const ParameterPtr param_ptr = dyn_cast<Parameter>(param);
    if (param_ptr == nullptr) {
      MS_LOG(EXCEPTION) << "Parameter '" << param->ToString() << "' could not be cast to Parameter.";
    }

    if (param_ptr->has_default()) {
      continue;
    }

    // set onnx input.
    std::string name;
    auto renamed_iter = renamed_node_map_.find(param_ptr);
    if (renamed_iter != renamed_node_map_.end()) {
      name = renamed_iter->second;
      if (name == "") {
        continue;
      }
    } else {
      name = GenerateUniqueParameterName(param_ptr, node_map_ptr);
      (*node_map_ptr)[param_ptr] = name;
    }

    onnx::ValueInfoProto *input_proto = graph_proto->add_input();
    input_proto->set_name(name);
    SetValueInfoType(param_ptr, input_proto);
  }
}

onnx::TensorProto_DataType OnnxExporter::GetOnnxDataType(TypeId type_id) {
  // clang-format off
  static mindspore::HashMap<int, onnx::TensorProto_DataType> type_map = {
    {kNumberTypeBool, onnx::TensorProto_DataType_BOOL},
    {kNumberTypeInt8, onnx::TensorProto_DataType_INT8},
    {kNumberTypeInt16, onnx::TensorProto_DataType_INT16},
    {kNumberTypeInt32, onnx::TensorProto_DataType_INT32},
    {kNumberTypeInt64, onnx::TensorProto_DataType_INT64},
    {kNumberTypeUInt8, onnx::TensorProto_DataType_UINT8},
    {kNumberTypeUInt16, onnx::TensorProto_DataType_UINT16},
    {kNumberTypeUInt32, onnx::TensorProto_DataType_UINT32},
    {kNumberTypeUInt64, onnx::TensorProto_DataType_UINT64},
    {kNumberTypeFloat16, onnx::TensorProto_DataType_FLOAT16},
    {kNumberTypeFloat32, onnx::TensorProto_DataType_FLOAT},
    {kNumberTypeFloat64, onnx::TensorProto_DataType_DOUBLE},
  };
  // clang-format on

  auto iter = type_map.find(type_id);
  if (iter == type_map.end()) {
    MS_LOG(EXCEPTION) << "Convert type error, unsupported type " << type_id;
  }

  return iter->second;
}

void OnnxExporter::SetValueInfoType(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto,
                                    int64_t output_index) const {
  auto dtype = GetOutputType(node, output_index);
  auto shape = node->Shape();

  abstract::ShapePtr output_shape;
  if (shape->isa<abstract::TupleShape>()) {
    auto tuple_shape = dyn_cast<abstract::TupleShape>(shape);
    auto base_shape = tuple_shape->shape().at(static_cast<size_t>(output_index));
    output_shape = dyn_cast<abstract::Shape>(base_shape);
    if (output_shape == nullptr) {
      MS_LOG(EXCEPTION) << "Expected " << node->ToString() << " to output a tuple of tensors. Instead got "
                        << base_shape->ToString() << " from output " << output_index;
    }
  } else if (shape->isa<abstract::Shape>()) {
    output_shape = dyn_cast<abstract::Shape>(shape);
  } else {
    MS_LOG(EXCEPTION) << "Unsupported shape: " << shape->ToString();
  }

  auto *type_proto = value_proto->mutable_type();
  type_proto->mutable_tensor_type()->set_elem_type(dtype);
  auto *shape_proto = type_proto->mutable_tensor_type()->mutable_shape();

  for (const auto dim : output_shape->shape()) {
    shape_proto->add_dim()->set_dim_value(dim);
  }
}

void OnnxExporter::MatchAndMark(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &nodes,
                                mindspore::HashMap<AnfNodePtr, OpMergedInfo> *op_merged_infos_ptr) const {
  auto &op_merged_infos = *op_merged_infos_ptr;

  for (auto &node : nodes) {
    if (!node->isa<CNode>() || IsZeroRefcountNode(node)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (cnode == func_graph->get_return()) {
      // if the key `input` does not exist, just create a new one
      op_merged_infos[cnode].referred_count += 1;
    }
    for (auto &weak_input : cnode->weak_inputs()) {
      auto orig_input = weak_input.lock();
      MS_EXCEPTION_IF_NULL(orig_input);
      auto input = GetRealInput(orig_input);
      if (!input->isa<CNode>() || IsZeroRefcountNode(input)) {
        continue;
      }
      // if the key `input` does not exist, just create a new one
      op_merged_infos[input].referred_count += 1;
    }
    MatchAndMarkCNode(func_graph, cnode, op_merged_infos_ptr);
  }
}

struct MergeRule {
  PrimitivePtr node_type;
  PrimitivePtr prev_type;
  OpMergeMode merge_mode;
};

void OnnxExporter::MatchAndMarkCNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                     mindspore::HashMap<AnfNodePtr, OpMergedInfo> *op_merged_infos_ptr) const {
  MS_EXCEPTION_IF_NULL(op_merged_infos_ptr);
  auto &op_merged_infos = *op_merged_infos_ptr;
  const auto ignore = [&op_merged_infos](const AnfNodePtr &node) {
    op_merged_infos[node].mode = OP_MERGE_IGNORE;
    op_merged_infos[node].referred_count -= 1;
  };

  const std::vector<MergeRule> first_input_merge_rules = {
    {prim::kPrimBiasAdd, prim::kPrimConv2D, OP_MERGE_CONV},
    {prim::kPrimBiasAdd, prim::kPrimConv2DTranspose, OP_MERGE_CONV2D_TRANSPOSE},
    {prim::kPrimBiasAdd, prim::kPrimConv3D, OP_MERGE_CONV},
    {prim::kPrimBiasAdd, prim::kPrimConv3DTranspose, OP_MERGE_CONV},
    {prim::kPrimBiasAdd, prim::kPrimMatMul, OP_MERGE_GEMM},
    {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, OP_MERGE_BATCH_NORM},
    {prim::kPrimTupleGetItem, prim::kPrimMaxPoolWithArgmax, OP_MERGE_MAXPOOL_WITH_ARGMAX},
    {prim::kPrimTupleGetItem, prim::kPrimLayerNorm, OP_MERGE_LAYER_NORM},
    {prim::kPrimTupleGetItem, prim::kPrimDynamicGRUV2, OP_MERGE_DYNAMIC_GRU_V2},
  };

  auto rule = std::find_if(first_input_merge_rules.begin(), first_input_merge_rules.end(), [&cnode](const auto &rule) {
    return cnode->IsApply(rule.node_type) && IsPrimitiveCNode(cnode->input(1), rule.prev_type);
  });
  if (rule != first_input_merge_rules.end()) {
    if (cnode->IsApply(prim::kPrimTupleGetItem) && GetInt64Value(cnode->input(kTwoNum)) != 0) {
      MS_LOG(EXCEPTION) << "Multiple outputs for node \"" << cnode->input(1)->ToString() << "\" are not supported";
    }
    op_merged_infos[cnode].mode = rule->merge_mode;
    ignore(cnode->input(1));
  } else if (while_loop_export::IsLoopBodyReturnNode(cnode, func_graph)) {
    // Ignore to replace with other outputs
    ignore(cnode);
    auto repeat_node = dyn_cast<CNode>(GetRealInput(cnode->input(1)));
    MS_EXCEPTION_IF_NULL(repeat_node);
    ignore(repeat_node);
  } else if (while_loop_export::IsAfterLoopReturnNode(cnode, func_graph)) {
    // Ignore to inline after-loop subgraph in main graph
    ignore(cnode);
    auto first_input = GetRealInput(cnode->input(1));
    if (IsPrimitiveCNode(first_input, prim::kPrimMakeTuple)) {
      ignore(first_input);
    }
  } else if (cnode == func_graph->get_return()) {
    auto first_input = GetRealInput(cnode->input(1));  // Unpack Depend
    // Ignore MakeTuple output node to avoid exporting it to SequenceConstruct
    // and handle multiple outputs in ExportOutput
    IgnoreMakeTuple(first_input, op_merged_infos_ptr);
  } else if (cnode->IsApply(prim::kPrimConcat) && IsPrimitiveCNode(cnode->input(1), prim::kPrimMakeTuple)) {
    // Ignore MakeTuple to handle it in ExportPrimConcat
    ignore(cnode->input(1));
  }
}
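
/*
  Example of the first-input merge rules above: in a fragment such as
  `BiasAdd(Conv2D(x, w), b)`, the BiasAdd cnode matches
  {kPrimBiasAdd, kPrimConv2D, OP_MERGE_CONV}, so the Conv2D node is marked
  ignored and ExportMergeConv later emits a single ONNX Conv over x, w and b.
  Likewise `TupleGetItem(BatchNorm(...), 0)` collapses into one ONNX node;
  TupleGetItem with a non-zero index is rejected above because only output 0
  of such merged ops is exportable.
*/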

void OnnxExporter::IgnoreMakeTuple(const AnfNodePtr &node,
                                   mindspore::HashMap<AnfNodePtr, OpMergedInfo> *op_merged_infos_ptr) const {
  MS_EXCEPTION_IF_NULL(op_merged_infos_ptr);
  auto &op_merged_infos = *op_merged_infos_ptr;
  const auto ignore = [&op_merged_infos](const AnfNodePtr &node) {
    op_merged_infos[node].mode = OP_MERGE_IGNORE;
    op_merged_infos[node].referred_count -= 1;
  };
  if (node == nullptr) {
    return;
  }

  if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
    ignore(node);
    auto cnode = dyn_cast<CNode>(node);
    if (cnode != nullptr) {
      for (size_t i = 1; i < cnode->size(); ++i) {
        auto real_input = GetRealInput(cnode->input(i));
        IgnoreMakeTuple(real_input, op_merged_infos_ptr);
      }
    }
  }
}

/**
 * AnfNode
 * +-- CNode
 * +-- ANode
 * |   +-- Parameter
 * |   `-- ValueNode
 */
void OnnxExporter::ExportNodes(const FuncGraphPtr &func_graph, std::map<AnfNodePtr, std::string> *node_map_ptr,
                               onnx::GraphProto *const graph_proto) {
  std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);

  mindspore::HashMap<AnfNodePtr, OpMergedInfo> op_merged_infos;
  MatchAndMark(func_graph, nodes, &op_merged_infos);
  for (const AnfNodePtr &node : nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }

    auto cnode = node->cast<CNodePtr>();
    auto iter = op_merged_infos.find(cnode);
    // the node is not referenced by any other nodes, skip it
    if (iter == op_merged_infos.end()) {
      continue;
    }
    auto merged_info = iter->second;
    // the op node is merged with another node and no longer used, skip it
    if (merged_info.mode == OP_MERGE_IGNORE && merged_info.referred_count == 0) {
      continue;
    }
    if (cnode == func_graph->get_return()) {
      ExportOutput(func_graph, cnode->input(kOneNum), node_map_ptr, graph_proto);
      continue;
    }
    switch (merged_info.mode) {
      case OP_MERGE_CONV:
        ExportMergeConv(func_graph, cnode, node_map_ptr, graph_proto);
        break;
      case OP_MERGE_GEMM:
        ExportMergeGemm(func_graph, cnode, node_map_ptr, graph_proto);
        break;
      case OP_MERGE_BATCH_NORM:
        ExportMergeBatchNorm(func_graph, cnode, node_map_ptr, graph_proto);
        break;
      case OP_MERGE_MAXPOOL_WITH_ARGMAX:
        ExportMergeMaxPoolWithArgmax(func_graph, cnode, node_map_ptr, graph_proto);
        break;
      case OP_MERGE_LAYER_NORM:
        ExportMergeLayerNorm(func_graph, cnode, node_map_ptr, graph_proto);
        break;
      case OP_MERGE_CONV2D_TRANSPOSE:
        ExportMergeConv2DTranspose(func_graph, cnode, node_map_ptr, graph_proto);
        break;
      case OP_MERGE_DYNAMIC_GRU_V2:
        ExportMergeDynamicGRUV2(func_graph, cnode, node_map_ptr, graph_proto);
        break;
      default:
        ExportCNode(func_graph, cnode, node_map_ptr, graph_proto);
        break;
    }
  }
}

void OnnxExporter::ExportPrimReshape(const FuncGraphPtr &, const CNodePtr &node,
                                     std::map<AnfNodePtr, std::string> *node_map_ptr,
                                     onnx::GraphProto *const graph_proto) {
  auto name_x = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto input_shape = node->input(kTwoNum);
  std::string name_shape;
  if (input_shape->isa<ValueNode>()) {
    name_shape = RegisterNodeWithUniqueName(input_shape, node_map_ptr);
    onnx::NodeProto *node_proto = graph_proto->add_node();
    auto name = prim::kPrimReshape->name();

    node_proto->set_name(name_shape + name);
    node_proto->add_output(name_shape);
    node_proto->set_op_type("Constant");
    onnx::AttributeProto *attr_proto = node_proto->add_attribute();
    attr_proto->set_name("value");
    attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
    ConvertTupleToTensor(dyn_cast<ValueNode>(input_shape)->value(), attr_proto->mutable_t());
  } else {
    name_shape = GetNodeInputName(input_shape, node_map_ptr, graph_proto);
    MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to tensor for Reshape.";
  }

  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_op_type(prim::kPrimReshape->name());
  node_proto->add_output(node_name);
  node_proto->add_input(name_x);
  node_proto->add_input(name_shape);
}
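
/*
  Resulting ONNX fragment (sketch): ONNX Reshape takes the target shape as a
  tensor input rather than an attribute, hence the Constant emitted above:

    Constant() -> shape_name            // tuple converted to a 1-D int64 tensor
    Reshape(x_name, shape_name) -> node_name
*/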

void OnnxExporter::ExportPrimReduce(const FuncGraphPtr &, const CNodePtr &node,
                                    std::map<AnfNodePtr, std::string> *node_map_ptr,
                                    onnx::GraphProto *const graph_proto) {
  auto input_data = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto input_axis = node->input(kTwoNum);
  auto keep_dims = GetOpAttribute<bool>(node, "keep_dims");

  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);

  std::string name;
  if (node->IsApply(prim::kPrimReduceSum)) {
    name = "ReduceSum";
  } else if (node->IsApply(prim::kPrimReduceMean)) {
    name = "ReduceMean";
  } else if (node->IsApply(prim::kPrimReduceMax)) {
    name = "ReduceMax";
  } else {
    MS_LOG(EXCEPTION) << "Unsupported reduce op: " << node->ToString();
  }

  std::vector<int64_t> axes;
  if (input_axis->isa<ValueNode>()) {
    auto axis_value = dyn_cast<ValueNode>(input_axis)->value();
    if (axis_value->isa<Int32Imm>()) {
      auto int_ptr = dyn_cast<Int32Imm>(axis_value);
      axes.push_back(int_ptr->value());
    } else if (axis_value->isa<Int64Imm>()) {
      auto int_ptr = dyn_cast<Int64Imm>(axis_value);
      axes.push_back(int_ptr->value());
    } else if (axis_value->isa<ValueTuple>()) {
      auto tuple_ptr = dyn_cast<ValueTuple>(axis_value);
      axes = GetValue<std::vector<int64_t>>(tuple_ptr);
    } else {
      MS_LOG(EXCEPTION) << "Cannot convert value " << axis_value->ToString() << " of type "
                        << axis_value->type()->ToString() << " for \"axes\" attribute of " << name;
    }
  } else {
    MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to attributes for " << name;
  }

  AddReduceOp(name, input_data, node_name, axes, keep_dims, graph_proto);
}

void OnnxExporter::ExportPrimReduceAnyOrAll(const FuncGraphPtr &, const CNodePtr &node,
                                            std::map<AnfNodePtr, std::string> *node_map_ptr,
                                            onnx::GraphProto *const graph_proto) {
  auto input_data_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto input_axis = node->input(kTwoNum);
  auto keep_dims = GetOpAttribute<bool>(node, "keep_dims");
  auto reduce_name = RegisterNodeWithUniqueName(node, node_map_ptr);

  std::string target_node_name = "";
  if (node->IsApply(prim::kPrimReduceAny)) {
    target_node_name = "ReduceSum";
  } else if (node->IsApply(prim::kPrimReduceAll)) {
    target_node_name = "ReduceMin";
  } else {
    MS_LOG(EXCEPTION) << "Unsupported reduce op: " << node->ToString();
  }

  std::string cast_name = GenerateUniqueName();  // Insert cast op
  onnx::NodeProto *cast_proto = graph_proto->add_node();
  cast_proto->add_input(input_data_name);
  cast_proto->add_output(cast_name);
  cast_proto->set_op_type(prim::kPrimCast->name());
  onnx::AttributeProto *attr_proto = cast_proto->add_attribute();
  attr_proto->set_name("to");
  attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  attr_proto->set_i(GetOnnxDataType(TypeId::kNumberTypeFloat32));

  std::vector<int64_t> axes;
  if (input_axis->isa<ValueNode>()) {
    auto axis_value = dyn_cast<ValueNode>(input_axis)->value();
    if (axis_value->isa<Int32Imm>()) {
      auto int_ptr = dyn_cast<Int32Imm>(axis_value);
      axes.push_back(int_ptr->value());
    } else if (axis_value->isa<Int64Imm>()) {
      auto int_ptr = dyn_cast<Int64Imm>(axis_value);
      axes.push_back(int_ptr->value());
    } else if (axis_value->isa<ValueTuple>()) {
      auto tuple_ptr = dyn_cast<ValueTuple>(axis_value);
      axes = GetValue<std::vector<int64_t>>(tuple_ptr);
      if (axes.empty()) {
        const auto &x_shape = dyn_cast<abstract::Shape>(node->input(kOneNum)->Shape())->shape();
        for (size_t i = 0; i < x_shape.size(); ++i) {
          axes.push_back(static_cast<int64_t>(i));
        }
      }
    } else {
      MS_LOG(EXCEPTION) << "Cannot convert value " << axis_value->ToString() << " of type "
                        << axis_value->type()->ToString() << " for \"axes\" attribute of " << target_node_name;
    }
  } else {
    MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to attributes for " << target_node_name;
  }

  std::string greater_name = GenerateUniqueName();
  onnx::TensorProto *zero_initializer_proto = graph_proto->add_initializer();
  auto zero_input_name = greater_name + "_zero";
  zero_initializer_proto->set_name(zero_input_name);
  zero_initializer_proto->set_data_type(GetOnnxDataType(kNumberTypeFloat32));
  zero_initializer_proto->add_float_data(0);

  AddReduceOp(target_node_name, cast_name, greater_name, axes, keep_dims, graph_proto);

  onnx::NodeProto *greater_node_proto = graph_proto->add_node();  // Insert greater op
  greater_node_proto->add_input(greater_name);
  greater_node_proto->add_input(zero_input_name);
  greater_node_proto->add_output(reduce_name);
  greater_node_proto->set_op_type(prim::kPrimGreater->name());
}
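
/*
  Decomposition sketch: opset 11 has no boolean ReduceAny/ReduceAll, so the
  export above lowers them through float arithmetic:

    ReduceAny(x, axes) -> Greater(ReduceSum(Cast<float>(x), axes), 0)
    ReduceAll(x, axes) -> Greater(ReduceMin(Cast<float>(x), axes), 0)

  E.g. for x = [true, false], ReduceAll casts to [1.0, 0.0], ReduceMin yields
  0.0, and Greater(0.0, 0) gives false, as expected.
*/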

void OnnxExporter::ExportPrimTranspose(const FuncGraphPtr &, const CNodePtr &node,
                                       std::map<AnfNodePtr, std::string> *node_map_ptr,
                                       onnx::GraphProto *const graph_proto) {
  auto input_data = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto input_perm = node->input(kTwoNum);
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  onnx::NodeProto *node_proto = graph_proto->add_node();
  auto name = prim::kPrimTranspose->name();

  node_proto->set_name(node_name + name);
  node_proto->set_op_type(name);
  node_proto->add_output(node_name);
  node_proto->add_input(input_data);

  if (input_perm->isa<ValueNode>()) {
    onnx::AttributeProto *attr_proto = node_proto->add_attribute();
    attr_proto->set_name("perm");
    attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
    auto perm_value = dyn_cast<ValueNode>(input_perm)->value();
    auto int_ptr = dyn_cast<Int32Imm>(perm_value);
    if (int_ptr == nullptr) {
      auto tuple_ptr = dyn_cast<ValueTuple>(perm_value);
      MS_EXCEPTION_IF_NULL(tuple_ptr);
      for (size_t i = 0; i < tuple_ptr->size(); ++i) {
        attr_proto->add_ints(GetValue<int64_t>((*tuple_ptr)[i]));
      }
    } else {
      attr_proto->add_ints(int_ptr->value());
    }
  } else {
    MS_LOG(EXCEPTION) << "The input input_perm of Transpose is not a ValueNode! "
                      << "Need to insert op convert variable from tuple to attributes for " << name;
  }
}

/*
  See:
    - mindspore/ccsrc/backend/kernel_compiler/cpu/stridedslice_cpu_kernel.cc
    - mindspore/ccsrc/backend/kernel_compiler/common_utils.cc
*/
void OnnxExporter::ExportPrimStridedSlice(const FuncGraphPtr &, const CNodePtr &node,
                                          std::map<AnfNodePtr, std::string> *node_map_ptr,
                                          onnx::GraphProto *const graph_proto) {
  auto input_data = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  auto name = node_name + prim::kPrimStridedSlice->name();

  auto begin = node->input(kTwoNum);
  auto begin_mask = node->input(kFiveNum);
  // Both the indices and their mask must be constant to be exportable.
  if (!begin->isa<ValueNode>() || !begin_mask->isa<ValueNode>()) {
    MS_LOG(EXCEPTION) << "The input begin of StridedSlice is not a ValueNode! "
                      << "Need to insert op convert variable from tuple to tensor for " << name;
  }
  auto begin_value_node = dyn_cast<ValueNode>(begin);
  auto begin_value = GetValue<std::vector<int64_t>>(begin_value_node->value());
  auto begin_ignore_mask = GetValue<int64_t>(dyn_cast<ValueNode>(begin_mask)->value());
  for (size_t i = 0; i < begin_value.size(); ++i) {
    if ((static_cast<uint64_t>(begin_ignore_mask) & (1UL << i)) != 0) {
      begin_value[i] = 0;
    }
  }

  auto end = node->input(kThreeNum);
  auto end_mask = node->input(kSixNum);
  if (!end->isa<ValueNode>() || !end_mask->isa<ValueNode>()) {
    MS_LOG(EXCEPTION) << "The input end of StridedSlice is not a ValueNode! "
                      << "Need to insert op convert variable from tuple to tensor for " << name;
  }
  auto end_value_node = dyn_cast<ValueNode>(end);
  auto end_value = GetValue<std::vector<int64_t>>(end_value_node->value());
  const auto &x_shape = dyn_cast<abstract::Shape>(node->input(kOneNum)->Shape())->shape();
  auto end_ignore_mask = GetValue<int64_t>(dyn_cast<ValueNode>(end_mask)->value());
  for (size_t i = 0; i < end_value.size(); ++i) {
    if ((static_cast<uint64_t>(end_ignore_mask) & (1UL << i)) != 0) {
      end_value[i] = x_shape[i];
    }
  }

  size_t axes_size = end_value.size();
  std::vector<int64_t> axes_value;
  for (size_t i = 0; i < axes_size; ++i) {
    axes_value.push_back(static_cast<int64_t>(i));
  }

  auto strides = node->input(kFourNum);
  auto shrink_mask = node->input(kNineNum);
  if (!strides->isa<ValueNode>() || !shrink_mask->isa<ValueNode>()) {
    MS_LOG(EXCEPTION) << "The input strides of StridedSlice is not a ValueNode! "
                      << "Need to insert op convert variable from tuple to tensor for " << name;
  }
  auto strides_value_node = dyn_cast<ValueNode>(strides);
  auto strides_value = GetValue<std::vector<int64_t>>(strides_value_node->value());

  auto shrink_axis_mask = GetValue<int64_t>(dyn_cast<ValueNode>(shrink_mask)->value());
  for (size_t i = 0; i < end_value.size(); ++i) {
    if ((static_cast<uint64_t>(shrink_axis_mask) & (1UL << i)) != 0) {
      strides_value[i] = end_value[i] > begin_value[i] ? 1 : -1;
      end_value[i] = begin_value[i] + strides_value[i];
    }
  }

  auto slice_name = node_name;
  if (shrink_axis_mask != 0) {
    slice_name = node_name + "__reshape";
  }

  AddSliceOp(input_data, slice_name, begin_value, end_value, axes_value, strides_value, graph_proto);

  if (shrink_axis_mask != 0) {
    onnx::NodeProto *squeeze_op = graph_proto->add_node();
    squeeze_op->set_op_type("Squeeze");
    squeeze_op->add_input(slice_name);
    squeeze_op->add_output(node_name);
    onnx::AttributeProto *axes_attr = squeeze_op->add_attribute();
    axes_attr->set_name("axes");
    axes_attr->set_type(onnx::AttributeProto_AttributeType_INTS);
    for (size_t i = 0; i < x_shape.size(); ++i) {
      if ((static_cast<uint64_t>(shrink_axis_mask) & (1UL << i)) != 0) {
        axes_attr->add_ints(static_cast<int64_t>(i));
      }
    }
  }
}
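
/*
  Worked example of the mask handling above (illustrative values): for x of
  shape [4, 5] with begin = (1, 2), end = (3, 4), strides = (1, 1),
  begin_mask = 0b01, end_mask = 0b10 and shrink_axis_mask = 0:

    bit 0 of begin_mask -> begin[0] is ignored and reset to 0
    bit 1 of end_mask   -> end[1] is ignored and reset to x_shape[1] = 5

  so the emitted Slice uses starts = [0, 2] and ends = [3, 5]. A non-zero
  shrink_axis_mask additionally narrows each shrunk axis to one element and
  appends the Squeeze over exactly those axes.
*/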

onnx::NodeProto *OnnxExporter::PrimResizeExportHelper(const FuncGraphPtr &, const CNodePtr &node,
                                                      std::map<AnfNodePtr, std::string> *node_map_ptr,
                                                      onnx::GraphProto *const graph_proto) {
  auto input_data = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto x_shape = dyn_cast<abstract::Shape>(node->input(kOneNum)->Shape());

  AnfNodePtr op = node->input(kZeroNum);
  auto op_value = dyn_cast<ValueNode>(op);
  auto prim = dyn_cast<Primitive>(op_value->value());
  std::vector<int64_t> resize_size;

  auto tuple_ptr = dyn_cast<ValueSequence>(prim->GetAttr("size"));  // size may be Tuple or List
  if (tuple_ptr == nullptr) {
    MS_LOG(EXCEPTION) << "Got null pointer, currently the " << prim->name()
                      << " operator in your model is not supported for ONNX export.";
  }

  for (size_t i = 0; i < x_shape->shape().size() - kTwoNum; i++) {
    resize_size.push_back(x_shape->shape()[i]);
  }
  for (size_t i = 0; i < tuple_ptr->size(); i++) {
    ValuePtr elem = (*tuple_ptr)[i];
    resize_size.push_back(dyn_cast<Int64Imm>(elem)->value());
  }
  auto resize_size_ptr = MakeValue<std::vector<int64_t>>(resize_size);
  auto size = NewValueNode(resize_size_ptr)->cast<AnfNodePtr>();

  auto name_size = RegisterNodeWithUniqueName(size, node_map_ptr);
  onnx::NodeProto *node_proto_size = graph_proto->add_node();
  node_proto_size->add_output(name_size);
  node_proto_size->set_op_type("Constant");
  onnx::AttributeProto *attr_proto = node_proto_size->add_attribute();
  attr_proto->set_name("value");
  attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
  ConvertTupleToTensor(resize_size_ptr, attr_proto->mutable_t());

  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);

  onnx::TensorProto *roi_initializer_proto = graph_proto->add_initializer();
  auto roi_name = node_name + "roi_initializer";
  roi_initializer_proto->set_name(roi_name);
  roi_initializer_proto->set_data_type(GetOnnxDataType(kNumberTypeFloat32));
  roi_initializer_proto->add_dims(0);

  onnx::TensorProto *scales_initializer_proto = graph_proto->add_initializer();
  auto scales_name = node_name + "scales_initializer";
  scales_initializer_proto->set_name(scales_name);
  scales_initializer_proto->set_data_type(GetOnnxDataType(kNumberTypeFloat32));
  scales_initializer_proto->add_dims(0);

  onnx::NodeProto *node_proto = graph_proto->add_node();

  node_proto->set_op_type("Resize");
  node_proto->add_output(node_name);
  node_proto->add_input(input_data);
  node_proto->add_input(roi_name);
  node_proto->add_input(scales_name);
  node_proto->add_input(name_size);

  return node_proto;
}

void OnnxExporter::ExportPrimResizeNearestNeighbor(const FuncGraphPtr &graph, const CNodePtr &node,
                                                   std::map<AnfNodePtr, std::string> *node_map_ptr,
                                                   onnx::GraphProto *const graph_proto) {
  onnx::NodeProto *node_proto = PrimResizeExportHelper(graph, node, node_map_ptr, graph_proto);

  auto align_corners = GetOpAttribute<bool>(node, "align_corners");
  std::string coordinate_transformation_mode = align_corners ? "align_corners" : "asymmetric";
  // `nearest_mode` is based on ResizeNearestNeighborCPUKernel::LaunchKernel in
  // mindspore/ccsrc/backend/kernel_compiler/cpu/resize_nearest_neighbor_cpu_kernel.cc
  std::string nearest_mode = align_corners ? "round_prefer_ceil" : "floor";

  onnx::AttributeProto *coordinate_mode_proto = node_proto->add_attribute();
  coordinate_mode_proto->set_name("coordinate_transformation_mode");
  coordinate_mode_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
  coordinate_mode_proto->set_s(coordinate_transformation_mode);

  onnx::AttributeProto *nearest_mode_proto = node_proto->add_attribute();
  nearest_mode_proto->set_name("nearest_mode");
  nearest_mode_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
  nearest_mode_proto->set_s(nearest_mode);
}

void OnnxExporter::ExportPrimResizeBilinear(const FuncGraphPtr &graph, const CNodePtr &node,
                                            std::map<AnfNodePtr, std::string> *node_map_ptr,
                                            onnx::GraphProto *const graph_proto) {
  onnx::NodeProto *node_proto = PrimResizeExportHelper(graph, node, node_map_ptr, graph_proto);

  auto align_corners = GetOpAttribute<bool>(node, "align_corners");
  std::string coordinate_transformation_mode = align_corners ? "align_corners" : "asymmetric";

  onnx::AttributeProto *coordinate_mode_proto = node_proto->add_attribute();
  coordinate_mode_proto->set_name("coordinate_transformation_mode");
  coordinate_mode_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
  coordinate_mode_proto->set_s(coordinate_transformation_mode);

  onnx::AttributeProto *mode_proto = node_proto->add_attribute();
  mode_proto->set_name("mode");
  mode_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
  mode_proto->set_s("linear");
}

// MindSpore ExpandDims -> ONNX Reshape
void OnnxExporter::ExportPrimExpandDims(const FuncGraphPtr &, const CNodePtr &node,
                                        std::map<AnfNodePtr, std::string> *node_map_ptr,
                                        onnx::GraphProto *const graph_proto) {
  auto input_x = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto axis = GetInt64Value(node->input(kTwoNum));
  auto x_shape = dyn_cast<abstract::Shape>(node->input(kOneNum)->Shape());
  auto name = prim::kPrimExpandDims->name();

  std::vector<int64_t> new_shape;
  for (size_t i = 0; i < x_shape->shape().size(); i++) {
    new_shape.push_back(x_shape->shape()[i]);
  }
  if (axis < 0) {
    axis = axis + kOneNumLong + SizeToLong(x_shape->shape().size());
  }
  (void)new_shape.insert(new_shape.begin() + axis, kOneNum);
  auto new_shape_value = MakeValue<std::vector<int64_t>>(new_shape);
  auto shape = NewValueNode(new_shape_value)->cast<AnfNodePtr>();
  std::string name_shape;

  if (shape->isa<ValueNode>()) {
    name_shape = RegisterNodeWithUniqueName(shape, node_map_ptr);
    onnx::NodeProto *node_proto = graph_proto->add_node();
    node_proto->add_output(name_shape);
    node_proto->set_op_type("Constant");
    onnx::AttributeProto *attr_proto = node_proto->add_attribute();
    attr_proto->set_name("value");
    attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
    ConvertTupleToTensor(dyn_cast<ValueNode>(shape)->value(), attr_proto->mutable_t());
  } else {
    name_shape = GetNodeInputName(shape, node_map_ptr, graph_proto);
    MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to tensor for " << name;
  }

  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_op_type("Reshape");
  node_proto->add_output(node_name);
  node_proto->add_input(input_x);
  node_proto->add_input(name_shape);
}
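
/*
  Worked example: for x of shape [2, 3] and axis = -1, the normalization above
  gives axis = -1 + 1 + 2 = 2, a 1 is inserted at that position, and the
  emitted Reshape expands x to shape [2, 3, 1].
*/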

// MindSpore GatherD -> ONNX GatherElements
void OnnxExporter::ExportPrimGatherD(const FuncGraphPtr &, const CNodePtr &node,
                                     std::map<AnfNodePtr, std::string> *node_map_ptr,
                                     onnx::GraphProto *const graph_proto) {
  auto input_x = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto axis = GetInt64Value(node->input(kTwoNum));
  auto input_indices = GetNodeInputName(node->input(kThreeNum), node_map_ptr, graph_proto);
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_op_type("GatherElements");
  node_proto->add_output(node_name);
  node_proto->add_input(input_x);
  node_proto->add_input(input_indices);
  onnx::AttributeProto *attr_proto = node_proto->add_attribute();
  attr_proto->set_name("axis");
  attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  attr_proto->set_i(static_cast<::google::protobuf::int64>(axis));
}

// MindSpore Pad -> ONNX Pad
void OnnxExporter::ExportPrimPad(const FuncGraphPtr &, const CNodePtr &node,
                                 std::map<AnfNodePtr, std::string> *node_map_ptr,
                                 onnx::GraphProto *const graph_proto) {
  auto x_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);

  auto paddings = GetOpAttributePtr<ValueTuple>(node, "paddings");
  std::vector<std::vector<int64_t>> paddings_values = GetValue<std::vector<std::vector<int64_t>>>(paddings);
  std::vector<int64_t> pads_sequence;
  for (size_t i = 0; i < paddings_values.size(); ++i) {
    pads_sequence.push_back(paddings_values[i][0]);
  }
  for (size_t j = 0; j < paddings_values.size(); ++j) {
    pads_sequence.push_back(paddings_values[j][1]);
  }
  auto pads_ptr = MakeValue<std::vector<int64_t>>(pads_sequence);
  auto pads = NewValueNode(pads_ptr)->cast<AnfNodePtr>();

  auto pads_name = RegisterNodeWithUniqueName(pads, node_map_ptr);
  onnx::NodeProto *pads_node = graph_proto->add_node();
  pads_node->add_output(pads_name);
  pads_node->set_op_type("Constant");
  onnx::AttributeProto *pads_attr_proto = pads_node->add_attribute();
  pads_attr_proto->set_name("value");
  pads_attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
  ConvertTupleToTensor(pads_ptr, pads_attr_proto->mutable_t());

  auto ms_pad_node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  onnx::NodeProto *onnx_pad_node = graph_proto->add_node();
  onnx_pad_node->set_op_type("Pad");
  onnx_pad_node->add_output(ms_pad_node_name);
  onnx_pad_node->add_input(x_name);
  onnx_pad_node->add_input(pads_name);
}
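
/*
  Layout example: MindSpore stores paddings per dimension as
  ((x1_begin, x1_end), (x2_begin, x2_end), ...), while ONNX Pad expects one
  flat tensor [x1_begin, x2_begin, ..., x1_end, x2_end, ...]. The two loops
  above perform exactly that reordering, e.g. ((1, 2), (3, 4)) -> [1, 3, 2, 4].
*/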

// MindSpore BatchMatMul -> ONNX Transpose + MatMul
void OnnxExporter::ExportPrimBatchMatMul(const FuncGraphPtr &, const CNodePtr &node,
                                         std::map<AnfNodePtr, std::string> *node_map_ptr,
                                         onnx::GraphProto *const graph_proto) {
  auto input_x = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto input_y = GetNodeInputName(node->input(kTwoNum), node_map_ptr, graph_proto);

  AnfNodePtr batchmatmul_op = node->input(kZeroNum);
  auto op_value = dyn_cast<ValueNode>(batchmatmul_op);
  auto prim = dyn_cast<Primitive>(op_value->value());
  auto transpose_a = GetValue<bool>(prim->GetAttr("transpose_a"));
  auto transpose_b = GetValue<bool>(prim->GetAttr("transpose_b"));
  std::string transpose_input_x_name = "";
  std::string transpose_input_y_name = "";

  if (transpose_a) {
    auto input_x_shape = dyn_cast<abstract::Shape>(node->input(kOneNum)->Shape());
    // Add Transpose node after input_x of BatchMatMul
    transpose_input_x_name = GenerateUniqueName();
    onnx::NodeProto *transpose_inputx_node_proto = graph_proto->add_node();
    transpose_inputx_node_proto->add_input(input_x);
    transpose_inputx_node_proto->add_output(transpose_input_x_name);
    transpose_inputx_node_proto->set_op_type(prim::kPrimTranspose->name());
    onnx::AttributeProto *attr_proto = transpose_inputx_node_proto->add_attribute();
    attr_proto->set_name("perm");
    attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
    for (size_t i = 0; i < input_x_shape->shape().size() - kTwoNum; i++) {
      attr_proto->add_ints(SizeToLong(i));
    }
    attr_proto->add_ints(SizeToLong(input_x_shape->shape().size()) - IntToLong(kOneNum));
    attr_proto->add_ints(SizeToLong(input_x_shape->shape().size()) - IntToLong(kTwoNum));
  }
  if (transpose_b) {
    auto input_y_shape = dyn_cast<abstract::Shape>(node->input(kTwoNum)->Shape());
    // Add Transpose node after input_y of BatchMatMul
    transpose_input_y_name = GenerateUniqueName();
    onnx::NodeProto *transpose_inputy_node_proto = graph_proto->add_node();
    transpose_inputy_node_proto->add_input(input_y);
    transpose_inputy_node_proto->add_output(transpose_input_y_name);
    transpose_inputy_node_proto->set_op_type(prim::kPrimTranspose->name());
    onnx::AttributeProto *attr_proto = transpose_inputy_node_proto->add_attribute();
    attr_proto->set_name("perm");
    attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
    for (size_t i = 0; i < input_y_shape->shape().size() - kTwoNum; i++) {
      attr_proto->add_ints(SizeToLong(i));
    }
    attr_proto->add_ints(SizeToLong(input_y_shape->shape().size()) - IntToLong(kOneNum));
    attr_proto->add_ints(SizeToLong(input_y_shape->shape().size()) - IntToLong(kTwoNum));
  }

  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_op_type("MatMul");
  node_proto->add_output(node_name);
  node_proto->set_name(node_name + "MatMul");
  if (transpose_a) {
    node_proto->add_input(transpose_input_x_name);
  } else {
    node_proto->add_input(input_x);
  }
  if (transpose_b) {
    node_proto->add_input(transpose_input_y_name);
  } else {
    node_proto->add_input(input_y);
  }
}
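
/*
  Permutation sketch: ONNX MatMul has no transpose flags, so a set
  transpose_a/transpose_b is materialized as an explicit Transpose that keeps
  the batch dimensions and swaps the last two axes. For a rank-4 input the
  loop above builds perm = [0, 1, 3, 2].
*/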

// MindSpore BroadcastTo -> ONNX Expand
void OnnxExporter::ExportPrimBroadcastTo(const FuncGraphPtr &, const CNodePtr &node,
                                         std::map<AnfNodePtr, std::string> *node_map_ptr,
                                         onnx::GraphProto *const graph_proto) {
  auto input_x = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto x_shape = dyn_cast<abstract::Shape>(node->input(kOneNum)->Shape());
  auto name = prim::kPrimBroadcastTo->name();

  auto shape_ptr = GetOpAttributePtr<ValueSequeue>(node, "shape");
  auto shape_vec = GetValue<std::vector<int64_t>>(shape_ptr);
  size_t n_shape = shape_vec.size();

  std::vector<int64_t> new_shape;
  for (size_t i = 0; i < n_shape; i++) {
    if (shape_vec[i] == -kOneNum) {
      size_t ids = i + x_shape->shape().size() - n_shape;
      new_shape.push_back(x_shape->shape()[ids]);
    } else {
      new_shape.push_back(shape_vec[i]);
    }
  }

  auto new_shape_value = MakeValue<std::vector<int64_t>>(new_shape);
  auto shape = NewValueNode(new_shape_value)->cast<AnfNodePtr>();
  std::string name_shape;

  if (shape->isa<ValueNode>()) {
    name_shape = RegisterNodeWithUniqueName(shape, node_map_ptr);
    onnx::NodeProto *node_proto = graph_proto->add_node();
    node_proto->add_output(name_shape);
    node_proto->set_op_type("Constant");
    onnx::AttributeProto *attr_proto = node_proto->add_attribute();
    attr_proto->set_name("value");
    attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
    ConvertTupleToTensor(dyn_cast<ValueNode>(shape)->value(), attr_proto->mutable_t());
  } else {
    name_shape = GetNodeInputName(shape, node_map_ptr, graph_proto);
    MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to tensor for " << name;
  }

  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_op_type("Expand");
  node_proto->add_output(node_name);
  node_proto->add_input(input_x);
  node_proto->add_input(name_shape);
}
2266
2267 // MindSpore AddN -> ONNX Add
ExportPrimAddN(const FuncGraphPtr &,const CNodePtr & node,std::map<AnfNodePtr,std::string> * node_map_ptr,onnx::GraphProto * const graph_proto)2268 void OnnxExporter::ExportPrimAddN(const FuncGraphPtr &, const CNodePtr &node,
2269 std::map<AnfNodePtr, std::string> *node_map_ptr,
2270 onnx::GraphProto *const graph_proto) {
2271 auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
2272
2273 auto input_node = node->input(kOneNum)->cast<CNodePtr>();
2274 auto last_input_name = GetNodeInputName(input_node->input(kOneNum), node_map_ptr, graph_proto);
2275 for (size_t i = kTwoNum; i < input_node->size() - 1; ++i) {
2276 auto input_name = GetNodeInputName(input_node->input(i), node_map_ptr, graph_proto);
2277 auto tmp_end_name = node_name + "ADD_" + std::to_string(i);
2278 AddOp("Add", {last_input_name, input_name}, {tmp_end_name}, graph_proto);
2279 last_input_name = tmp_end_name;
2280 }
2281 auto input_end_name = GetNodeInputName(input_node->input(input_node->size() - 1), node_map_ptr, graph_proto);
2282 AddOp("Add", {last_input_name, input_end_name}, {node_name}, graph_proto);
2283 }
2284
2285 // MindSpore GeLU -> ONNX 0.5 * X * (1.0 + tanh((sqrt(2/pi) * (x + 0.044715 * pow(x, 3)))))
ExportPrimGeLU(const FuncGraphPtr &,const CNodePtr & node,std::map<AnfNodePtr,std::string> * node_map_ptr,onnx::GraphProto * const graph_proto)2286 void OnnxExporter::ExportPrimGeLU(const FuncGraphPtr &, const CNodePtr &node,
2287 std::map<AnfNodePtr, std::string> *node_map_ptr,
2288 onnx::GraphProto *const graph_proto) {
2289 auto input_x = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
2290 auto onnx_type = GetOutputType(node->input(kOneNum));
2291
2292 // Add pow node
2293 auto pow_name = GenerateUniqueName();
2294 auto exp_node_name = pow_name + "exponent_initializer";
2295 AddFloatTensor1DInitializer(exp_node_name, {3.0}, onnx_type, graph_proto);
2296 AddOp("Pow", {input_x, exp_node_name}, {pow_name}, graph_proto);
2297
2298 // Add first Mul Node
2299 auto fmul_name = GenerateUniqueName();
2300 auto fmul_input_node_name = fmul_name + "input_y_for_mul_initializer";
2301 AddFloatTensor1DInitializer(fmul_input_node_name, {0.044715}, onnx_type, graph_proto);
2302 AddOp("Mul", {pow_name, fmul_input_node_name}, {fmul_name}, graph_proto);
2303
2304 // Add first Add node
2305 auto fadd_name = GenerateUniqueName();
2306 AddOp("Add", {input_x, fmul_name}, {fadd_name}, graph_proto);
2307
2308 // Add second Mul Node
2309 auto smul_name = GenerateUniqueName();
2310 auto smul_input_node_name = smul_name + "input_y_for_smul_initializer";
2311 AddFloatTensor1DInitializer(smul_input_node_name, {0.7978845608}, onnx_type, graph_proto);
2312 AddOp("Mul", {fadd_name, smul_input_node_name}, {smul_name}, graph_proto);
2313
2314 // Add tanh node
2315 auto tanh_name = GenerateUniqueName();
2316 AddOp("Tanh", {smul_name}, {tanh_name}, graph_proto);
2317
2318 // Add second Add node
2319 auto sadd_name = GenerateUniqueName();
2320 auto sadd_input_node_name = sadd_name + "input_y_for_sadd_initializer";
2321 AddFloatTensor1DInitializer(sadd_input_node_name, {1.0}, onnx_type, graph_proto);
2322 AddOp("Add", {tanh_name, sadd_input_node_name}, {sadd_name}, graph_proto);
2323
2324 // Add third Mul Node
2325 auto tmul_name = GenerateUniqueName();
2326 auto tmul_input_node_name = tmul_name + "input_y_for_tmul_initializer";
2327 AddFloatTensor1DInitializer(tmul_input_node_name, {0.5}, onnx_type, graph_proto);
2328 AddOp("Mul", {sadd_name, tmul_input_node_name}, {tmul_name}, graph_proto);
2329
2330 // Add fourth Mul Node
2331 auto fomul_node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
2332 AddOp("Mul", {input_x, tmul_name}, {fomul_node_name}, graph_proto);
2333 }
2334
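// MindSpore Concat -> ONNX Concat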
void OnnxExporter::ExportPrimConcat(const FuncGraphPtr &, const CNodePtr &node,
                                    std::map<AnfNodePtr, std::string> *node_map_ptr,
                                    onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);

  // Get inputs first: otherwise if an input is a constant, topological order will break
  auto input_node = node->input(kOneNum)->cast<CNodePtr>();
  std::vector<std::string> input_names;
  if (input_node->IsApply(prim::kPrimMakeTuple)) {
    for (size_t i = 1; i < input_node->size(); ++i) {
      auto input_name = GetNodeInputName(input_node->input(i), node_map_ptr, graph_proto);
      input_names.push_back(input_name);
    }
  } else {
    auto input_data = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
    input_names.push_back(input_data);
  }

  AddConcatOp(input_names, node_name, GetOpAttribute<int64_t>(node, "axis"), graph_proto);
}

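// MindSpore Cast -> ONNX Cast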
void OnnxExporter::ExportPrimCast(const FuncGraphPtr &, const CNodePtr &node,
                                  std::map<AnfNodePtr, std::string> *node_map_ptr,
                                  onnx::GraphProto *const graph_proto) {
  auto input_data = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto input_type = node->input(kTwoNum);

  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_op_type(prim::kPrimCast->name());
  node_proto->add_output(node_name);
  node_proto->add_input(input_data);

  if (input_type->isa<ValueNode>()) {
    onnx::AttributeProto *attr_proto = node_proto->add_attribute();
    attr_proto->set_name("to");
    attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
    auto type_value = dyn_cast<ValueNode>(input_type)->value();
    auto type_ptr = dyn_cast<Int64Imm>(type_value);
    MS_EXCEPTION_IF_NULL(type_ptr);
    auto type_id = static_cast<TypeId>(type_ptr->value());
    attr_proto->set_i(GetOnnxDataType(type_id));
  } else {
    MS_LOG(EXCEPTION) << "Need to convert MindSpore Cast input(1) into the ONNX Cast `to` attribute.";
  }
}

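// MindSpore PReLU -> ONNX PRelu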
void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr &, const CNodePtr &node,
                                   std::map<AnfNodePtr, std::string> *node_map_ptr,
                                   onnx::GraphProto *const graph_proto) {
  auto input_x = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto input_slope = GetNodeInputName(node->input(kTwoNum), node_map_ptr, graph_proto);

  auto x_shape = dyn_cast<abstract::Shape>(node->input(kOneNum)->Shape());
  auto slope_shape = dyn_cast<abstract::Shape>(node->input(kTwoNum)->Shape());
  MS_EXCEPTION_IF_NULL(x_shape);
  MS_EXCEPTION_IF_NULL(slope_shape);

  // x is in NCHW format; if the slope is 1-D, Unsqueeze it on axes [1, 2] so it broadcasts per channel
  if (x_shape->shape().size() == kFourNum && slope_shape->shape().size() == kOneNum) {
    auto node_name = GenerateUniqueName();
    onnx::NodeProto *node_proto = graph_proto->add_node();
    node_proto->set_op_type("Unsqueeze");
    node_proto->add_output(node_name);

    onnx::AttributeProto *attr_proto = node_proto->add_attribute();
    attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
    attr_proto->set_name("axes");
    attr_proto->add_ints(kOneNum);
    attr_proto->add_ints(kTwoNum);

    node_proto->add_input(input_slope);
    input_slope = node_name;
  }

  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_op_type("PRelu");
  node_proto->add_output(node_name);
  node_proto->add_input(input_x);
  node_proto->add_input(input_slope);
}

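// MindSpore ReLU6 -> ONNX Clip [0, 6]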
void OnnxExporter::ExportPrimReLU6(const FuncGraphPtr &, const CNodePtr &node,
                                   std::map<AnfNodePtr, std::string> *node_map_ptr,
                                   onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);

  auto input_x_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto onnx_input_type = GetOutputType(node->input(kOneNum));
  AddClipOp(input_x_name, node_name, 0.0f, 6.0f, onnx_input_type, graph_proto);
}

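// MindSpore DepthwiseConv2d -> ONNX Reshape (weight) + Conv (group = in_channels)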
void OnnxExporter::ExportPrimDepthwiseConv2d(const FuncGraphPtr &, const CNodePtr &node,
                                             std::map<AnfNodePtr, std::string> *node_map_ptr,
                                             onnx::GraphProto *const graph_proto) {
  auto input_x = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto input_w = GetNodeInputName(node->input(kTwoNum), node_map_ptr, graph_proto);
  auto x_shape = dyn_cast<abstract::Shape>(node->input(kOneNum)->Shape());
  auto w_shape = dyn_cast<abstract::Shape>(node->input(kTwoNum)->Shape());
  MS_EXCEPTION_IF_NULL(x_shape);
  MS_EXCEPTION_IF_NULL(w_shape);
  if (x_shape->shape().size() != kFourNum || w_shape->shape().size() != kFourNum) {
    MS_LOG(EXCEPTION) << "DepthwiseConv2d input shape should be 4d.";
  }
  if (w_shape->shape()[kZeroNum] != kOneNum && w_shape->shape()[kOneNum] != kOneNum) {
    MS_LOG(EXCEPTION) << "DepthwiseConv2d weight shape[0] != 1 and shape[1] != 1, cannot reshape";
  }
  // create w_shape constant node
  auto node_name = GenerateUniqueName();
  onnx::NodeProto *node_proto = graph_proto->add_node();
  auto name_w_shape = node_name;
  node_proto->add_output(name_w_shape);
  node_proto->set_op_type("Constant");
  // create value tensor
  onnx::AttributeProto *attr_proto = node_proto->add_attribute();
  attr_proto->set_name("value");
  attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
  onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
  tensor_proto->add_dims(static_cast<::google::protobuf::int64>(w_shape->shape().size()));
  tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64);
  // new shape for the weight: swap the first two dimensions
  tensor_proto->add_int64_data(w_shape->shape()[kOneNum]);
  tensor_proto->add_int64_data(w_shape->shape()[kZeroNum]);
  tensor_proto->add_int64_data(w_shape->shape()[kTwoNum]);
  tensor_proto->add_int64_data(w_shape->shape()[kThreeNum]);

  // add reshape node
  node_name = GenerateUniqueName();
  node_proto = graph_proto->add_node();
  node_proto->set_op_type(prim::kPrimReshape->name());
  node_proto->add_input(input_w);
  node_proto->add_input(name_w_shape);
  input_w = node_name;
  node_proto->add_output(input_w);

  // add conv node
  node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  node_proto = graph_proto->add_node();
  node_proto->set_op_type("Conv");
  node_proto->add_input(input_x);
  node_proto->add_input(input_w);
  node_proto->add_output(node_name);
  // set attributes
  AnfNodePtr op = node->input(0);
  auto op_value = dyn_cast<ValueNode>(op);
  auto prim = dyn_cast<Primitive>(op_value->value());
  // set dilations
  onnx::AttributeProto *onnx_attr_proto = node_proto->add_attribute();
  onnx_attr_proto->set_name("dilations");
  SetAttrTupleValueToProto<2>(prim->GetAttr("dilation"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto,
                              prim);
  // set group
  onnx_attr_proto = node_proto->add_attribute();
  onnx_attr_proto->set_name("group");
  onnx_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  onnx_attr_proto->set_i(x_shape->shape()[1]);
  // set kernel_shape
  onnx_attr_proto = node_proto->add_attribute();
  onnx_attr_proto->set_name("kernel_shape");
  SetAttrTupleValueToProto<0>(prim->GetAttr("kernel_size"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto,
                              prim);

  // set pad
  onnx_attr_proto = node_proto->add_attribute();
  int64_t attr_value;
  CheckAndConvertUtils::GetPadModEnumValue(prim->GetAttr("pad_mode"), &attr_value);
  onnx_attr_proto->set_name("auto_pad");
  onnx_attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
  if (attr_value == PadMode::VALID) {
    onnx_attr_proto->set_s("VALID");
  } else if (attr_value == PadMode::SAME) {
    onnx_attr_proto->set_s("SAME_UPPER");
  } else {
    onnx_attr_proto->set_name("pads");
    SetAttrTupleValueToProto(prim->GetAttr("pad_list"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, prim);
  }
  // set strides
  onnx_attr_proto = node_proto->add_attribute();
  onnx_attr_proto->set_name("strides");
  SetAttrTupleValueToProto<2>(prim->GetAttr("stride"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, prim);
}

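// MindSpore Tile -> ONNX Tile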
void OnnxExporter::ExportPrimTile(const FuncGraphPtr &, const CNodePtr &node,
                                  std::map<AnfNodePtr, std::string> *node_map_ptr,
                                  onnx::GraphProto *const graph_proto) {
  auto name_x = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto multiples = node->input(kTwoNum);
  std::string name_multiples;
  if (multiples->isa<ValueNode>()) {
    onnx::NodeProto *node_proto = graph_proto->add_node();
    name_multiples = RegisterNodeWithUniqueName(multiples, node_map_ptr);
    node_proto->add_output(name_multiples);
    node_proto->set_op_type("Constant");
    onnx::AttributeProto *attr_proto = node_proto->add_attribute();
    attr_proto->set_name("value");
    attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
    ConvertTupleToTensor(dyn_cast<ValueNode>(multiples)->value(), attr_proto->mutable_t());
  } else {
    name_multiples = GetNodeInputName(multiples, node_map_ptr, graph_proto);
    MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to tensor for Tile.";
  }

  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_op_type("Tile");
  node_proto->add_output(node_name);
  node_proto->add_input(name_x);
  node_proto->add_input(name_multiples);
}

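// MindSpore Square -> ONNX Pow (exponent = 2)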
void OnnxExporter::ExportPrimSquare(const FuncGraphPtr &, const CNodePtr &node,
                                    std::map<AnfNodePtr, std::string> *node_map_ptr,
                                    onnx::GraphProto *const graph_proto) {
  auto name_x = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto name_exponent = GenerateUniqueName();
  onnx::NodeProto *node_proto_exp = graph_proto->add_node();
  node_proto_exp->add_output(name_exponent);

  node_proto_exp->set_op_type("Constant");
  onnx::AttributeProto *attr_proto = node_proto_exp->add_attribute();
  attr_proto->set_name("value");
  attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
  onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
  const float exponent_value = 2.0;
  tensor_proto->set_name("exponent");
  tensor_proto->add_dims(static_cast<::google::protobuf::int64>(1));
  tensor_proto->set_data_type(GetOnnxDataType(kNumberTypeFloat32));
  tensor_proto->add_float_data(exponent_value);

  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_op_type("Pow");
  node_proto->add_output(node_name);
  node_proto->add_input(name_x);
  node_proto->add_input(name_exponent);
}

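// MindSpore GatherV2 -> ONNX Gather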
void OnnxExporter::ExportPrimGatherV2(const FuncGraphPtr &, const CNodePtr &node,
                                      std::map<AnfNodePtr, std::string> *node_map_ptr,
                                      onnx::GraphProto *const graph_proto) {
  auto name_x = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto name_indices = GetNodeInputName(node->input(kTwoNum), node_map_ptr, graph_proto);
  auto axis = node->input(kThreeNum)->cast<ValueNodePtr>()->value();
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_op_type("Gather");
  node_proto->add_output(node_name);
  node_proto->add_input(name_x);
  node_proto->add_input(name_indices);
  onnx::AttributeProto *attr_proto = node_proto->add_attribute();
  attr_proto->set_name("axis");
  attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  attr_proto->set_i(static_cast<::google::protobuf::int64>(dyn_cast<Int64Imm>(axis)->value()));
}

/*
  This is a workaround for nodes with several outputs used at once
  MatchAndMark cannot help here, because it only supports a single output
  Proposed convention:
    * Nodes with several outputs are registered as
      `(*node_map_ptr)[node] = node_idx;`, just like nodes with a single output
    * Their outputs are named "{node_idx}_{output_idx}"
    * TupleGetItem automatically passes the outputs to the next nodes
  See OnnxExporter::ExportPrimTopK for a usage example
*/
void OnnxExporter::ExportPrimTupleGetItem(const FuncGraphPtr &, const CNodePtr &node,
                                          std::map<AnfNodePtr, std::string> *node_map_ptr,
                                          onnx::GraphProto *const graph_proto) {
  auto index = GetInt64Value(node->input(kTwoNum));
  auto input_node_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto input_name = MakeOutputName(input_node_name, index);

  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);

  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_op_type("Identity");
  node_proto->add_input(input_name);
  node_proto->add_output(node_name);
}

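// MindSpore TopK -> ONNX TopK (indices are cast to int32)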
void OnnxExporter::ExportPrimTopK(const FuncGraphPtr &, const CNodePtr &node,
                                  std::map<AnfNodePtr, std::string> *node_map_ptr,
                                  onnx::GraphProto *const graph_proto) {
  auto x_input_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);

  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);

  auto k_input_name = node_name + "k_initializer";
  auto k = GetInt64Value(node->input(kTwoNum));
  AddInt64Tensor1DInitializer(k_input_name, {k}, graph_proto);

  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_op_type("TopK");
  node_proto->add_input(x_input_name);
  node_proto->add_input(k_input_name);
  node_proto->add_output(MakeOutputName(node_name, kZeroNum));  // Values
  auto indices_name = MakeOutputName(node_name, kOneNum);
  auto indices_cast_name = indices_name + "_cast";
  node_proto->add_output(indices_cast_name);

  onnx::AttributeProto *sorted_attr_proto = node_proto->add_attribute();
  sorted_attr_proto->set_name("sorted");
  sorted_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  auto sorted = GetOpAttribute<bool>(node, "sorted");
  sorted_attr_proto->set_i(sorted);
  AddCastOp(indices_cast_name, indices_name, onnx::TensorProto_DataType_INT32, graph_proto);
}

// Based on mindspore/ccsrc/backend/kernel_compiler/cpu/boundingbox_decode_cpu_kernel.cc
void OnnxExporter::ExportPrimBoundingBoxDecode(const FuncGraphPtr &, const CNodePtr &node,
                                               std::map<AnfNodePtr, std::string> *node_map_ptr,
                                               onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);

  auto anchor_bbox_input_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto deltas_input_name = GetNodeInputName(node->input(kTwoNum), node_map_ptr, graph_proto);
  auto onnx_input_type = GetOutputType(node->input(kOneNum));

  auto means = GetOpAttributePtr<ValueTuple>(node, "means");
  std::vector<float> mean_values = GetValue<std::vector<float>>(means);
  auto means_name = node_name + "means_initializer";
  AddFloatTensor1DInitializer(means_name, mean_values, onnx_input_type, graph_proto);

  auto stds = GetOpAttributePtr<ValueTuple>(node, "stds");
  std::vector<float> std_values = GetValue<std::vector<float>>(stds);
  auto stds_name = node_name + "stds_initializer";
  AddFloatTensor1DInitializer(stds_name, std_values, onnx_input_type, graph_proto);

  auto wh_ratio_clip = GetOpAttribute<float>(node, "wh_ratio_clip");
  auto max_ratio = static_cast<float>(std::abs(std::log(wh_ratio_clip)));

  auto unstd_deltas_name = node_name + "unstd_deltas";
  auto sd_to_add_name = unstd_deltas_name + "__add";
  AddOp("Mul", {deltas_input_name, stds_name}, {sd_to_add_name}, graph_proto);
  AddOp("Add", {sd_to_add_name, means_name}, {unstd_deltas_name}, graph_proto);

  auto center_deltas_name = node_name + "center_deltas";
  auto log_scale_deltas_name = node_name + "log_scale_deltas";
  auto lsd_to_clip_name = log_scale_deltas_name + "__clip";
  AddSplitOp(unstd_deltas_name, {center_deltas_name, lsd_to_clip_name}, {kTwoNum, kTwoNum}, 1, graph_proto);
  AddClipOp(lsd_to_clip_name, log_scale_deltas_name, -max_ratio, max_ratio, onnx_input_type, graph_proto);

  auto anchor_starts_name = node_name + "anchor_starts";
  auto anchor_ends_name = node_name + "anchor_ends";
  AddSplitOp(anchor_bbox_input_name, {anchor_starts_name, anchor_ends_name}, {kTwoNum, kTwoNum}, 1, graph_proto);

  auto anchor_centers_name = node_name + "anchor_centers";
  auto anchor_dimensions_name = node_name + "anchor_dimensions";
  ConvertBoxesToXywh(anchor_starts_name, anchor_ends_name, anchor_centers_name, anchor_dimensions_name,
                     onnx_input_type, graph_proto);

  auto anchor_shifts_name = node_name + "anchor_shifts";
  AddOp("Mul", {anchor_dimensions_name, center_deltas_name}, {anchor_shifts_name}, graph_proto);
  auto result_centers_name = node_name + "result_centers";
  AddOp("Add", {anchor_centers_name, anchor_shifts_name}, {result_centers_name}, graph_proto);

  auto anchor_scales_name = node_name + "anchor_scales";
  AddOp("Exp", {log_scale_deltas_name}, {anchor_scales_name}, graph_proto);
  auto result_dimensions_name = node_name + "result_dimensions";
  AddOp("Mul", {anchor_dimensions_name, anchor_scales_name}, {result_dimensions_name}, graph_proto);

  auto result_starts_to_clip_name = node_name + "result_starts_to_clip";
  auto result_ends_to_clip_name = node_name + "result_ends_to_clip";
  ConvertBoxesToXyxy(result_centers_name, result_dimensions_name, result_starts_to_clip_name,
                     result_ends_to_clip_name, onnx_input_type, graph_proto);

  auto max_shape = GetOpAttributePtr<ValueTuple>(node, "max_shape");
  auto max_y = GetValue<int64_t>((*max_shape)[0]);
  auto max_x = GetValue<int64_t>((*max_shape)[1]);
  auto result_start_xs_name = node_name + "result_start_x";
  auto result_start_ys_name = node_name + "result_start_y";
  auto result_end_xs_name = node_name + "result_end_x";
  auto result_end_ys_name = node_name + "result_end_y";
  ClipPointsComponent(result_starts_to_clip_name, result_start_xs_name, static_cast<float>(max_x), 0, onnx_input_type,
                      graph_proto);
  ClipPointsComponent(result_starts_to_clip_name, result_start_ys_name, static_cast<float>(max_y), 1, onnx_input_type,
                      graph_proto);
  ClipPointsComponent(result_ends_to_clip_name, result_end_xs_name, static_cast<float>(max_x), 0, onnx_input_type,
                      graph_proto);
  ClipPointsComponent(result_ends_to_clip_name, result_end_ys_name, static_cast<float>(max_y), 1, onnx_input_type,
                      graph_proto);

  AddConcatOp({result_start_xs_name, result_start_ys_name, result_end_xs_name, result_end_ys_name}, node_name,
              kOneNum, graph_proto);
}

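// MindSpore NMSWithMask -> ONNX TopK (sort by score) + NonMaxSuppression + ScatterElements (mask)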
void OnnxExporter::ExportPrimNMSWithMask(const FuncGraphPtr &, const CNodePtr &node,
                                         std::map<AnfNodePtr, std::string> *node_map_ptr,
                                         onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);

  auto bboxes_input_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto iou_threshold = GetOpAttribute<float>(node, "iou_threshold");
  auto selected_boxes_output_name = MakeOutputName(node_name, kZeroNum);
  auto selected_idx_output_name = MakeOutputName(node_name, kOneNum);
  auto selected_mask_output_name = MakeOutputName(node_name, kTwoNum);
  auto onnx_input_type = GetOutputType(node->input(kOneNum));

  // Preprocessing

  auto boxes_count_name = node_name + "max_output_boxes";
  auto max_output_boxes_to_squeeze_name = boxes_count_name + "_to_reshape";
  auto input_shape_name = node_name + "input_shape";
  AddOp("Shape", {bboxes_input_name}, {input_shape_name}, graph_proto);
  AddSliceOp(input_shape_name, max_output_boxes_to_squeeze_name, {0}, {1}, {0}, {1}, graph_proto);
  AddReshapeOp(max_output_boxes_to_squeeze_name, boxes_count_name, {}, graph_proto);

  auto scores_name = node_name + "scores";
  auto flat_scores_name = scores_name + "_flat";
  auto sorted_scores_name = flat_scores_name + "_sorted";
  auto scores_to_flatten_name = scores_name + "_to_reshape";
  auto descending_order_name = node_name + "descending_indices";
  const int BBOX_NUM_EL = 4;
  AddSliceOp(bboxes_input_name, scores_to_flatten_name, {BBOX_NUM_EL}, {BBOX_NUM_EL + 1}, {1}, {1}, graph_proto);
  AddReshapeOp(scores_to_flatten_name, flat_scores_name, {-1}, graph_proto);
  AddOp("TopK", {flat_scores_name, max_output_boxes_to_squeeze_name}, {sorted_scores_name, descending_order_name},
        graph_proto);
  AddReshapeOp(sorted_scores_name, scores_name, {1, 1, -1}, graph_proto);
  auto iou_threshold_name = node_name + "iou_threshold_initializer";
  AddFloatScalarInitializer(iou_threshold_name, iou_threshold, onnx::TensorProto_DataType_FLOAT, graph_proto);

  AddOp("Gather", {bboxes_input_name, descending_order_name}, {selected_boxes_output_name},
        graph_proto);  // Output 0: boxes
  auto boxes_name = node_name + "boxes";
  auto boxes_to_reshape_name = boxes_name + "_to_reshape";
  AddSliceOp(selected_boxes_output_name, boxes_to_reshape_name, {0}, {BBOX_NUM_EL}, {1}, {1}, graph_proto);
  AddReshapeOp(boxes_to_reshape_name, boxes_name, {1, -1, BBOX_NUM_EL}, graph_proto);

  if (onnx_input_type == onnx::TensorProto_DataType_FLOAT16) {
    auto fp32_boxes_name = boxes_name + "_fp32";
    AddCastOp(boxes_name, fp32_boxes_name, onnx::TensorProto_DataType_FLOAT, graph_proto);
    boxes_name = fp32_boxes_name;

    auto fp32_scores_name = scores_name + "_fp32";
    AddCastOp(scores_name, fp32_scores_name, onnx::TensorProto_DataType_FLOAT, graph_proto);
    scores_name = fp32_scores_name;
  }

  // NMS op

  auto selected_indices_name = node_name + "selected_indices";
  AddOp("NonMaxSuppression", {boxes_name, scores_name, boxes_count_name, iou_threshold_name}, {selected_indices_name},
        graph_proto);

  // Output 1: indices

  auto flat_indices_name = node_name + "flat_indices";
  auto flat_indices_to_squeeze_name = flat_indices_name + "__reshape";
  const int BOX_INDEX_POS = 2;
  AddSliceOp(selected_indices_name, flat_indices_to_squeeze_name, {BOX_INDEX_POS}, {BOX_INDEX_POS + 1}, {1}, {1},
             graph_proto);
  AddReshapeOp(flat_indices_to_squeeze_name, flat_indices_name, {-1}, graph_proto);

  auto zero_name = node_name + "zero_initializer";
  onnx::TensorProto *zero_initializer = graph_proto->add_initializer();
  zero_initializer->set_name(zero_name);
  zero_initializer->set_data_type(onnx::TensorProto_DataType_INT32);
  zero_initializer->add_int32_data(0);
  auto one_name = node_name + "one_initializer";
  onnx::TensorProto *one_initializer = graph_proto->add_initializer();
  one_initializer->set_name(one_name);
  one_initializer->set_data_type(onnx::TensorProto_DataType_INT32);
  one_initializer->add_int32_data(1);
  auto int32_boxes_count_name = boxes_count_name + "_int32";
  AddCastOp(boxes_count_name, int32_boxes_count_name, onnx::TensorProto_DataType_INT32, graph_proto);
  AddOp("Range", {zero_name, int32_boxes_count_name, one_name}, {selected_idx_output_name}, graph_proto);

  // Output 2: mask

  auto empty_mask_name = selected_mask_output_name + "__scatter";
  onnx::TensorProto *empty_mask_value_proto =
    AddConstantOfShapeOp(max_output_boxes_to_squeeze_name, empty_mask_name, graph_proto);
  empty_mask_value_proto->set_data_type(onnx::TensorProto_DataType_BOOL);
  empty_mask_value_proto->add_int32_data(0);

  auto true_elements_name = node_name + "true";
  auto true_elements_shape_name = true_elements_name + "_shape";
  AddOp("Shape", {flat_indices_name}, {true_elements_shape_name}, graph_proto);
  onnx::TensorProto *true_elements_value_proto =
    AddConstantOfShapeOp(true_elements_shape_name, true_elements_name, graph_proto);
  true_elements_value_proto->set_data_type(onnx::TensorProto_DataType_BOOL);
  true_elements_value_proto->add_int32_data(1);

  AddOp("ScatterElements", {empty_mask_name, flat_indices_name, true_elements_name}, {selected_mask_output_name},
        graph_proto);
}

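// MindSpore Split -> ONNX Split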
void OnnxExporter::ExportPrimSplit(const FuncGraphPtr &, const CNodePtr &node,
                                   std::map<AnfNodePtr, std::string> *node_map_ptr,
                                   onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  auto input_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);

  auto axis = GetOpAttribute<int64_t>(node, "axis");
  auto output_num = GetOpAttribute<int64_t>(node, "output_num");
  if (output_num <= 0) {
    MS_LOG(EXCEPTION) << "output_num must be > 0";
  }
  const auto &input_shape = dyn_cast<abstract::Shape>(node->input(kOneNum)->Shape())->shape();

  if (axis < 0 || static_cast<size_t>(axis) >= input_shape.size()) {
    MS_LOG(EXCEPTION) << "`axis` is out of range";
  }
  if (input_shape[static_cast<size_t>(axis)] % output_num != 0) {
    MS_LOG(EXCEPTION) << "Input dim is not divisible by `output_num`";
  }

  onnx::NodeProto *split_proto = graph_proto->add_node();
  split_proto->set_op_type("Split");
  split_proto->add_input(input_name);
  for (int64_t i = 0; i < output_num; ++i) {
    split_proto->add_output(MakeOutputName(node_name, i));
  }

  onnx::AttributeProto *axis_attr_proto = split_proto->add_attribute();
  axis_attr_proto->set_name("axis");
  axis_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  axis_attr_proto->set_i(axis);

  onnx::AttributeProto *split_attr_proto = split_proto->add_attribute();
  split_attr_proto->set_name("split");
  split_attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
  for (int64_t i = 0; i < output_num; ++i) {
    split_attr_proto->add_ints(input_shape[static_cast<size_t>(axis)] / output_num);
  }
}

/*
  Based on mindspore-project/mindspore/ccsrc/backend/kernel_compiler/cpu/roi_align_cpu_kernel.cc
  Notes:
    * MS version uses avg pool, leaving corresponding ONNX attr as is
    * MS has two ROI end modes, implemented with pre-processing
*/
void OnnxExporter::ExportPrimROIAlign(const FuncGraphPtr &, const CNodePtr &node,
                                      std::map<AnfNodePtr, std::string> *node_map_ptr,
                                      onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  auto features_input_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto rois_input_name = GetNodeInputName(node->input(kTwoNum), node_map_ptr, graph_proto);
  auto onnx_input_type = GetOutputType(node->input(kOneNum));

  auto roi_indices_name = node_name + "roi_indices";
  auto roi_indices_column_name = roi_indices_name + "_column";
  auto roi_starts_name = node_name + "roi_starts";
  auto roi_ends_name = node_name + "roi_ends";
  AddSplitOp(rois_input_name, {roi_indices_column_name, roi_starts_name, roi_ends_name}, {1, kTwoNum, kTwoNum}, 1,
             graph_proto);

  // Indices transformation

  auto flat_roi_indices_name = roi_indices_name + "_flat";
  AddReshapeOp(roi_indices_column_name, flat_roi_indices_name, {-1}, graph_proto);
  auto int_roi_indices_name = roi_indices_name + "_int";
  // This should be fine if indices are whole numbers less than 2^23
  AddCastOp(flat_roi_indices_name, int_roi_indices_name, onnx::TensorProto_DataType_INT64, graph_proto);

  // ROI end mode

  auto roi_end_mode = GetOpAttribute<int64_t>(node, "roi_end_mode");
  auto roi_end_mode_name = node_name + "roi_end_mode_initializer";
  AddFloatScalarInitializer(roi_end_mode_name, static_cast<float>(roi_end_mode), onnx_input_type, graph_proto);

  auto corrected_roi_ends_name = roi_ends_name + "_corrected";
  AddOp("Add", {roi_ends_name, roi_end_mode_name}, {corrected_roi_ends_name}, graph_proto);

  // Concatenate ROIs

  auto corrected_rois_name = node_name + "corrected_rois";
  AddConcatOp({roi_starts_name, corrected_roi_ends_name}, corrected_rois_name, kOneNum, graph_proto);

  // RoiAlign op

  onnx::NodeProto *roi_align_proto = graph_proto->add_node();
  roi_align_proto->set_op_type("RoiAlign");
  roi_align_proto->add_input(features_input_name);
  roi_align_proto->add_input(corrected_rois_name);
  roi_align_proto->add_input(int_roi_indices_name);
  roi_align_proto->add_output(node_name);
  onnx::AttributeProto *height_attr_proto = roi_align_proto->add_attribute();
  height_attr_proto->set_name("output_height");
  height_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  height_attr_proto->set_i(GetOpAttribute<int64_t>(node, "pooled_height"));
  onnx::AttributeProto *width_attr_proto = roi_align_proto->add_attribute();
  width_attr_proto->set_name("output_width");
  width_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  width_attr_proto->set_i(GetOpAttribute<int64_t>(node, "pooled_width"));
  onnx::AttributeProto *scale_attr_proto = roi_align_proto->add_attribute();
  scale_attr_proto->set_name("spatial_scale");
  scale_attr_proto->set_type(onnx::AttributeProto_AttributeType_FLOAT);
  scale_attr_proto->set_f(GetOpAttribute<float>(node, "spatial_scale"));
  onnx::AttributeProto *sampling_ratio_attr_proto = roi_align_proto->add_attribute();
  sampling_ratio_attr_proto->set_name("sampling_ratio");
  sampling_ratio_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  sampling_ratio_attr_proto->set_i(GetOpAttribute<int64_t>(node, "sample_num"));
}

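// MindSpore Slice -> ONNX Slice (end = begin + size)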
void OnnxExporter::ExportPrimSlice(const FuncGraphPtr &, const CNodePtr &node,
                                   std::map<AnfNodePtr, std::string> *node_map_ptr,
                                   onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  auto input_x_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto begin_input_name = GetNodeInputName(node->input(kTwoNum), node_map_ptr, graph_proto);
  auto size_input_name = GetNodeInputName(node->input(kThreeNum), node_map_ptr, graph_proto);

  auto end_name = node_name + "end";
  AddOp("Add", {begin_input_name, size_input_name}, {end_name}, graph_proto);
  AddOp("Slice", {input_x_name, begin_input_name, end_name}, {node_name}, graph_proto);
}

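// MindSpore OnesLike -> ONNX Shape + ConstantOfShape (filled with ones)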
void OnnxExporter::ExportPrimOnesLike(const FuncGraphPtr &, const CNodePtr &node,
                                      std::map<AnfNodePtr, std::string> *node_map_ptr,
                                      onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  auto input_x_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);

  auto shape_name = node_name + "shape";
  AddOp("Shape", {input_x_name}, {shape_name}, graph_proto);

  auto dtype = node->input(kOneNum)->Type();
  auto elem_type = dyn_cast<TensorType>(dtype)->element()->type_id();

  onnx::TensorProto *one_proto = AddConstantOfShapeOp(shape_name, node_name, graph_proto);
  switch (elem_type) {
    case kNumberTypeInt32:
      one_proto->set_data_type(onnx::TensorProto_DataType_INT32);
      one_proto->add_int32_data(1);
      break;
    case kNumberTypeInt64:
      one_proto->set_data_type(onnx::TensorProto_DataType_INT64);
      one_proto->add_int64_data(1);
      break;
    case kNumberTypeFloat32:
      one_proto->set_data_type(onnx::TensorProto_DataType_FLOAT);
      one_proto->add_float_data(1.0f);
      break;
    case kNumberTypeFloat64:
      one_proto->set_data_type(onnx::TensorProto_DataType_DOUBLE);
      one_proto->add_double_data(1.0);
      break;
    default:
      MS_LOG(EXCEPTION) << "Unsupported dtype: " << elem_type;
  }
}

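// MindSpore ScatterNd -> ONNX ScatterND (applied to a zero tensor of the target shape)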
void OnnxExporter::ExportPrimScatterNd(const FuncGraphPtr &, const CNodePtr &node,
                                       std::map<AnfNodePtr, std::string> *node_map_ptr,
                                       onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  auto input_indices_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto input_update_name = GetNodeInputName(node->input(kTwoNum), node_map_ptr, graph_proto);
  auto input_shape_name = GetNodeInputName(node->input(kThreeNum), node_map_ptr, graph_proto);
  auto node_zero_tensor_name = node_name + "_zero";
  auto dtype = node->input(kTwoNum)->Type();
  auto elem_type = dyn_cast<TensorType>(dtype)->element()->type_id();

  onnx::TensorProto *zero_proto = AddConstantOfShapeOp(input_shape_name, node_zero_tensor_name, graph_proto);
  switch (elem_type) {
    case kNumberTypeInt32:
      zero_proto->set_data_type(onnx::TensorProto_DataType_INT32);
      zero_proto->add_int32_data(0);
      break;
    case kNumberTypeInt64:
      zero_proto->set_data_type(onnx::TensorProto_DataType_INT64);
      zero_proto->add_int64_data(0);
      break;
    case kNumberTypeFloat32:
      zero_proto->set_data_type(onnx::TensorProto_DataType_FLOAT);
      zero_proto->add_float_data(0.0f);
      break;
    case kNumberTypeFloat64:
      zero_proto->set_data_type(onnx::TensorProto_DataType_DOUBLE);
      zero_proto->add_double_data(0.0);
      break;
    default:
      MS_LOG(EXCEPTION) << "Unsupported dtype: " << elem_type;
  }
  auto int64_indices_name = input_indices_name + "_int64";
  AddCastOp(input_indices_name, int64_indices_name, onnx::TensorProto_DataType_INT64, graph_proto);

  // Create ScatterND node
  onnx::NodeProto *scatternd_proto = graph_proto->add_node();
  scatternd_proto->set_op_type("ScatterND");
  scatternd_proto->add_input(node_zero_tensor_name);
  scatternd_proto->add_input(int64_indices_name);
  scatternd_proto->add_input(input_update_name);
  scatternd_proto->add_output(node_name);
}

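// MindSpore ArgMaxWithValue -> ONNX ArgMax (indices) + ReduceMax (values)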
void OnnxExporter::ExportPrimArgMaxWithValue(const FuncGraphPtr &, const CNodePtr &node,
                                             std::map<AnfNodePtr, std::string> *node_map_ptr,
                                             onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  auto input_x_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto axis = GetOpAttribute<int64_t>(node, "axis");
  auto keep_dims = GetOpAttribute<bool>(node, "keep_dims");

  auto indices_output_name = MakeOutputName(node_name, kZeroNum);
  auto indices_cast_name = indices_output_name + "_cast";

  onnx::NodeProto *argmax_proto = graph_proto->add_node();
  argmax_proto->set_op_type("ArgMax");
  argmax_proto->add_input(input_x_name);
  argmax_proto->add_output(indices_cast_name);
  onnx::AttributeProto *argmax_axis_attr_proto = argmax_proto->add_attribute();
  argmax_axis_attr_proto->set_name("axis");
  argmax_axis_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  argmax_axis_attr_proto->set_i(axis);
  onnx::AttributeProto *argmax_keepdims_attr_proto = argmax_proto->add_attribute();
  argmax_keepdims_attr_proto->set_name("keepdims");
  argmax_keepdims_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  argmax_keepdims_attr_proto->set_i(keep_dims);

  AddCastOp(indices_cast_name, indices_output_name, onnx::TensorProto_DataType_INT32, graph_proto);

  auto max_output_name = MakeOutputName(node_name, kOneNum);
  AddReduceOp("ReduceMax", input_x_name, max_output_name, {axis}, keep_dims, graph_proto);
}

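// MindSpore ArgMinWithValue -> ONNX ArgMin (indices) + ReduceMin (values)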
void OnnxExporter::ExportPrimArgMinWithValue(const FuncGraphPtr &, const CNodePtr &node,
                                             std::map<AnfNodePtr, std::string> *node_map_ptr,
                                             onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  auto input_x_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto axis = GetOpAttribute<int64_t>(node, "axis");
  auto keep_dims = GetOpAttribute<bool>(node, "keep_dims");

  auto indices_output_name = MakeOutputName(node_name, kZeroNum);
  auto indices_cast_name = indices_output_name + "_cast";

  onnx::NodeProto *argmin_proto = graph_proto->add_node();
  argmin_proto->set_op_type("ArgMin");
  argmin_proto->add_input(input_x_name);
  argmin_proto->add_output(indices_cast_name);
  onnx::AttributeProto *argmin_axis_attr_proto = argmin_proto->add_attribute();
  argmin_axis_attr_proto->set_name("axis");
  argmin_axis_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  argmin_axis_attr_proto->set_i(axis);
  onnx::AttributeProto *argmin_keepdims_attr_proto = argmin_proto->add_attribute();
  argmin_keepdims_attr_proto->set_name("keepdims");
  argmin_keepdims_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  argmin_keepdims_attr_proto->set_i(static_cast<int64_t>(keep_dims));

  AddCastOp(indices_cast_name, indices_output_name, onnx::TensorProto_DataType_INT32, graph_proto);

  auto min_output_name = MakeOutputName(node_name, kOneNum);
  AddReduceOp("ReduceMin", input_x_name, min_output_name, {axis}, keep_dims, graph_proto);
}

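// MindSpore OneHot -> ONNX OneHot (off/on values concatenated into a single input)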
void OnnxExporter::ExportPrimOneHot(const FuncGraphPtr &, const CNodePtr &node,
                                    std::map<AnfNodePtr, std::string> *node_map_ptr,
                                    onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  auto indices_input_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto depth_input_name = GetNodeInputName(node->input(kTwoNum), node_map_ptr, graph_proto);
  auto on_input_name = GetNodeInputName(node->input(kThreeNum), node_map_ptr, graph_proto);
  auto off_input_name = GetNodeInputName(node->input(kFourNum), node_map_ptr, graph_proto);
  auto axis = GetOpAttribute<int64_t>(node, "axis");

  if (GetOutputType(node->input(kOneNum)) == onnx::TensorProto_DataType_INT32) {
    auto indices_cast_name = node_name + "_indices_as_int64";
    AddCastOp(indices_input_name, indices_cast_name, onnx::TensorProto_DataType_INT64, graph_proto);
    indices_input_name = indices_cast_name;
  }

  auto on_1d_name = node_name + "on_1d";
  AddReshapeOp(on_input_name, on_1d_name, {-1}, graph_proto);
  auto off_1d_name = node_name + "off_1d";
  AddReshapeOp(off_input_name, off_1d_name, {-1}, graph_proto);

  auto on_off_name = node_name + "on_off";
  AddConcatOp({off_1d_name, on_1d_name}, on_off_name, kZeroNum, graph_proto);

  onnx::NodeProto *one_hot_proto = graph_proto->add_node();
  one_hot_proto->set_op_type("OneHot");
  one_hot_proto->add_input(indices_input_name);
  one_hot_proto->add_input(depth_input_name);
  one_hot_proto->add_input(on_off_name);
  one_hot_proto->add_output(node_name);
  onnx::AttributeProto *one_hot_axis_attr_proto = one_hot_proto->add_attribute();
  one_hot_axis_attr_proto->set_name("axis");
  one_hot_axis_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  one_hot_axis_attr_proto->set_i(axis);
}

/*
  Based on nn.Conv2dTranspose
  Warning: `output_shape` is an input in MS and an attribute in ONNX,
  so the output shape cannot be changed at runtime
*/
void OnnxExporter::PrimConv2DTransposeExportHelper(const CNodePtr &conv_node, const CNodePtr &bias_add_node,
                                                   std::map<AnfNodePtr, std::string> *node_map_ptr,
                                                   onnx::GraphProto *const graph_proto) {
  std::string node_name;

  std::vector<AnfNodePtr> inputs{conv_node->input(kOneNum), conv_node->input(kTwoNum)};
  if (bias_add_node != nullptr) {
    inputs.push_back(bias_add_node->input(kTwoNum));
    node_name = RegisterNodeWithUniqueName(bias_add_node, node_map_ptr);
  } else {
    node_name = RegisterNodeWithUniqueName(conv_node, node_map_ptr);
  }

  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_op_type("ConvTranspose");
  for (const auto &input : inputs) {
    node_proto->add_input(GetNodeInputName(input, node_map_ptr, graph_proto));
  }
  node_proto->add_output(node_name);

  auto prim = GetPrimitive(conv_node);
  auto attrs_convert_info =
    OpNameInfo()
      .Attr("dilation", "dilations", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<kTwoNum>)
      .Attr("group", "group", onnx::AttributeProto_AttributeType_INT, SetAttrValueToProto<Int64Imm>)
      .Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<0>)
      .Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetConvTransposePadding)
      .Attr("stride", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<kTwoNum>);
  for (const auto &attr_info : attrs_convert_info.op_attrs()) {
    onnx::AttributeProto *attr_proto = node_proto->add_attribute();
    attr_proto->set_name(attr_info.onnx_attr_name());
    auto ms_attr = GetOpAttributePtr<Value>(conv_node, attr_info.attr_name());
    MS_EXCEPTION_IF_NULL(ms_attr);
    attr_info.fn_gen_attr()(ms_attr, attr_info.onnx_attr_type(), attr_proto, prim);
  }

  // Set output shape

  auto input_shape_node = GetRealInput(conv_node->input(kThreeNum));
  if (!input_shape_node->isa<ValueNode>()) {
    MS_LOG(EXCEPTION) << "For ONNX export, the third argument must be a constant "
                         "(Python tuple). Instead got "
                      << input_shape_node->ToString();
  }
  auto input_shape_value_ptr = input_shape_node->cast<ValueNodePtr>()->value();
  if (!input_shape_value_ptr->isa<ValueTuple>()) {
    MS_LOG(EXCEPTION) << "Expected ValueTuple, got " << input_shape_value_ptr->ToString() << " of type "
                      << input_shape_value_ptr->type()->ToString();
  }

  onnx::AttributeProto *output_shape_attr_proto = node_proto->add_attribute();
  output_shape_attr_proto->set_name("output_shape");
  SetAttrTupleValueToProto<0>(input_shape_value_ptr, onnx::AttributeProto_AttributeType_INTS, output_shape_attr_proto,
                              prim);
}

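// MindSpore Conv2DTranspose (no bias) -> ONNX ConvTranspose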
void OnnxExporter::ExportPrimConv2DTranspose(const FuncGraphPtr &, const CNodePtr &node,
                                             std::map<AnfNodePtr, std::string> *node_map_ptr,
                                             onnx::GraphProto *graph_proto) {
  PrimConv2DTransposeExportHelper(node, nullptr, node_map_ptr, graph_proto);
}

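// MindSpore GreaterEqual -> ONNX Less + Not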
void OnnxExporter::ExportPrimGreaterEqual(const FuncGraphPtr &, const CNodePtr &node,
                                          std::map<AnfNodePtr, std::string> *node_map_ptr,
                                          onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);

  auto input_x_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto input_y_name = GetNodeInputName(node->input(kTwoNum), node_map_ptr, graph_proto);
  auto less_name = node_name + "less";

  AddOp("Less", {input_x_name, input_y_name}, {less_name}, graph_proto);
  AddOp("Not", {less_name}, {node_name}, graph_proto);
}

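// MindSpore LessEqual -> ONNX Greater + Not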
void OnnxExporter::ExportPrimLessEqual(const FuncGraphPtr &, const CNodePtr &node,
                                       std::map<AnfNodePtr, std::string> *node_map_ptr,
                                       onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);

  auto input_x_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto input_y_name = GetNodeInputName(node->input(kTwoNum), node_map_ptr, graph_proto);
  auto greater_name = node_name + "greater";

  AddOp("Greater", {input_x_name, input_y_name}, {greater_name}, graph_proto);
  AddOp("Not", {greater_name}, {node_name}, graph_proto);
}

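// MindSpore NotEqual -> ONNX Equal + Not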
void OnnxExporter::ExportPrimNotEqual(const FuncGraphPtr &, const CNodePtr &node,
                                      std::map<AnfNodePtr, std::string> *node_map_ptr,
                                      onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  auto input_x_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto input_y_name = GetNodeInputName(node->input(kTwoNum), node_map_ptr, graph_proto);
  auto equal_name = node_name + "equal";

  AddOp("Equal", {input_x_name, input_y_name}, {equal_name}, graph_proto);
  AddOp("Not", {equal_name}, {node_name}, graph_proto);
}

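// MindSpore Dense (fused MatMul + BiasAdd) is exported through the MatMul primitive with the bias as a third input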
void OnnxExporter::ExportPrimDense(const FuncGraphPtr &func_graph, const CNodePtr &node,
                                   std::map<AnfNodePtr, std::string> *node_map_ptr,
                                   onnx::GraphProto *const graph_proto) {
  auto matmul_node = dyn_cast<CNode>(node->input(kOneNum));
  auto input_x = matmul_node->input(kOneNum);  // matmul input x
  auto input_y = matmul_node->input(kTwoNum);  // matmul input y
  auto input_b = node->input(kTwoNum);         // matmul bias

  PrimitivePtr prim_matmul = dyn_cast<Primitive>((dyn_cast<ValueNode>(matmul_node->input(kZeroNum)))->value());
  std::vector<AnfNodePtr> inputs{input_x, input_y, input_b};
  (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_matmul, inputs, graph_proto);
}

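// MindSpore Squeeze -> ONNX Squeeze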
void OnnxExporter::ExportPrimSqueeze(const FuncGraphPtr &, const CNodePtr &node,
                                     std::map<AnfNodePtr, std::string> *node_map_ptr,
                                     onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);

  auto input_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);

  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_op_type("Squeeze");
  node_proto->add_input(input_name);
  node_proto->add_output(node_name);

  auto axes = GetOpAttributePtr<ValueSequence>(node, "axis");
  auto axes_value = GetValue<std::vector<int64_t>>(axes);
  if (!axes_value.empty()) {
    onnx::AttributeProto *axes_proto = node_proto->add_attribute();
    axes_proto->set_name("axes");
    axes_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
    for (auto axis : axes_value) {
      axes_proto->add_ints(axis);
    }
  }
}

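// Repack an LSTM weight stored in (i, f, c, o) gate order into the ONNX (i, o, f, c) order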
void MakeLSTMWeight(const std::string &input, const std::string &output, const std::vector<int64_t> &output_shape,
                    onnx::GraphProto *graph_proto) {
  auto reshaped_name = output + "__split";
  AddReshapeOp(input, reshaped_name, output_shape, graph_proto);

  auto split_i_name = output + "__concat_i";
  auto split_o_name = output + "__concat_o";
  auto split_f_name = output + "__concat_f";
  auto split_c_name = output + "__concat_c";
  int64_t hidden_size = output_shape[kOneNum] / kFourNum;
  AddSplitOp(reshaped_name, {split_i_name, split_f_name, split_c_name, split_o_name},
             {hidden_size, hidden_size, hidden_size, hidden_size}, 1, graph_proto);

  AddConcatOp({split_i_name, split_o_name, split_f_name, split_c_name}, output, 1, graph_proto);
}

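// Same as MakeLSTMWeight, but for weights stored in (i, c, f, o) gate order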
void MakeLSTMWeight2(const std::string &input, const std::string &output, const std::vector<int64_t> &output_shape,
                     onnx::GraphProto *graph_proto) {
  auto reshaped_name = output + "__split";
  AddReshapeOp(input, reshaped_name, output_shape, graph_proto);

  auto split_i_name = output + "__concat_i";
  auto split_o_name = output + "__concat_o";
  auto split_f_name = output + "__concat_f";
  auto split_c_name = output + "__concat_c";
  int64_t hidden_size = output_shape[kOneNum] / kFourNum;
  AddSplitOp(reshaped_name, {split_i_name, split_c_name, split_f_name, split_o_name},
             {hidden_size, hidden_size, hidden_size, hidden_size}, 1, graph_proto);

  AddConcatOp({split_i_name, split_o_name, split_f_name, split_c_name}, output, 1, graph_proto);
}

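// Exports DynamicRNN as an ONNX LSTM node. The fused MindSpore weight is split
// into input and hidden matrices, transposed to the ONNX [num_dir, 4*hidden, *]
// layout, and gate-reordered via MakeLSTMWeight2. ONNX LSTM adds two bias halves
// (Wb + Rb), so the single DynamicRNN bias is duplicated and divided by two to
// keep the effective bias unchanged.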
void OnnxExporter::ExportPrimDynamicRNN(const FuncGraphPtr &, const CNodePtr &node,
                                        std::map<AnfNodePtr, std::string> *node_map_ptr,
                                        onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);

  auto x_input_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto weight_input_name = GetNodeInputName(node->input(kTwoNum), node_map_ptr, graph_proto);
  auto bias_input_name = GetNodeInputName(node->input(kThreeNum), node_map_ptr, graph_proto);
  auto init_h_input_name = GetNodeInputName(node->input(kFiveNum), node_map_ptr, graph_proto);
  auto init_c_input_name = GetNodeInputName(node->input(kSixNum), node_map_ptr, graph_proto);

  auto hidden_size = GetOpAttribute<int64_t>(node, "hidden_size");
  auto direction_input = GetOpAttribute<std::string>(node, "direction");
  auto x_input_shape = dyn_cast<abstract::Shape>(node->input(kOneNum)->Shape())->shape();
  auto seq_len = x_input_shape[0];
  auto batch_size = x_input_shape[1];
  auto num_dir = direction_input == "UNIDIRECTIONAL" ? 1 : 2;
  auto input_size = x_input_shape[kTwoNum];

  auto onnx_input_weights_name = node_name + "_onnx_input_weights";
  auto onnx_hidden_weights_name = node_name + "_onnx_hidden_weights";
  auto onnx_bias_name = node_name + "_onnx_bias";

  const int num_gates = 4;
  auto gate_size = num_gates * hidden_size;

  auto weight_input_name_reshape = weight_input_name + "_reshape";
  AddReshapeOp(weight_input_name, weight_input_name_reshape, {(input_size * gate_size) + (hidden_size * gate_size)},
               graph_proto);

  auto input_weights_name = node_name + "_input_weights";
  auto hidden_weights_name = node_name + "_hidden_weights";
  std::vector<int64_t> split_sizes = {input_size * gate_size, hidden_size * gate_size};
  std::vector<std::string> split_outputs = {input_weights_name, hidden_weights_name};

  AddSplitOp(weight_input_name_reshape, split_outputs, split_sizes, 0, graph_proto);

  auto input_weights_name_reshape = input_weights_name + "_reshape";
  auto hidden_weights_name_reshape = hidden_weights_name + "_reshape";
  AddReshapeOp(input_weights_name, input_weights_name_reshape, {num_dir, input_size, gate_size}, graph_proto);
  AddReshapeOp(hidden_weights_name, hidden_weights_name_reshape, {num_dir, hidden_size, gate_size}, graph_proto);

  // Transpose input_weights_name
  onnx::NodeProto *transpose_node_proto_1 = graph_proto->add_node();
  auto input_weights_name_reshape_transposed = input_weights_name_reshape + "_transposed";
  transpose_node_proto_1->set_name(input_weights_name_reshape_transposed);
  transpose_node_proto_1->set_op_type("Transpose");
  transpose_node_proto_1->add_input(input_weights_name_reshape);
  transpose_node_proto_1->add_output(input_weights_name_reshape_transposed);

  onnx::AttributeProto *perm_proto_1 = transpose_node_proto_1->add_attribute();
  perm_proto_1->set_name("perm");
  perm_proto_1->set_type(onnx::AttributeProto_AttributeType_INTS);
  perm_proto_1->add_ints(kZeroNum);
  perm_proto_1->add_ints(kTwoNum);
  perm_proto_1->add_ints(kOneNum);

  // Transpose hidden_weights_name
  onnx::NodeProto *transpose_node_proto_2 = graph_proto->add_node();
  auto hidden_weights_name_reshape_transposed = hidden_weights_name_reshape + "_transposed";
  transpose_node_proto_2->set_name(hidden_weights_name_reshape_transposed);
  transpose_node_proto_2->set_op_type("Transpose");
  transpose_node_proto_2->add_input(hidden_weights_name_reshape);
  transpose_node_proto_2->add_output(hidden_weights_name_reshape_transposed);

  onnx::AttributeProto *perm_proto_2 = transpose_node_proto_2->add_attribute();
  perm_proto_2->set_name("perm");
  perm_proto_2->set_type(onnx::AttributeProto_AttributeType_INTS);
  perm_proto_2->add_ints(kZeroNum);
  perm_proto_2->add_ints(kTwoNum);
  perm_proto_2->add_ints(kOneNum);

  MakeLSTMWeight2(input_weights_name_reshape_transposed, onnx_input_weights_name, {num_dir, gate_size, input_size},
                  graph_proto);
  MakeLSTMWeight2(hidden_weights_name_reshape_transposed, onnx_hidden_weights_name, {num_dir, gate_size, hidden_size},
                  graph_proto);

  auto bias_input_name_reshape = bias_input_name + "_reshape";
  AddReshapeOp(bias_input_name, bias_input_name_reshape, {num_dir, gate_size}, graph_proto);

  auto bias_output_name = node_name + "_bias_output_name";
  MakeLSTMWeight2(bias_input_name_reshape, bias_output_name, {num_dir, gate_size}, graph_proto);

  auto bias_concat = bias_output_name + "_concat";
  std::vector<std::string> concat_inputs = {bias_output_name, bias_output_name};
  AddConcatOp(concat_inputs, bias_concat, 1, graph_proto);

  auto div_second_operand_name = node_name + "_div_second_operand";
  const float div_second_operand = 2.0;
  AddFloatScalarInitializer(div_second_operand_name, div_second_operand, onnx::TensorProto_DataType_FLOAT16,
                            graph_proto);

  AddOp("Div", {bias_concat, div_second_operand_name}, {onnx_bias_name}, graph_proto);

  // Create LSTM node
  onnx::NodeProto *lstm_node_proto = graph_proto->add_node();
  lstm_node_proto->set_op_type("LSTM");
  lstm_node_proto->add_input(x_input_name);
  lstm_node_proto->add_input(onnx_input_weights_name);
  lstm_node_proto->add_input(onnx_hidden_weights_name);
  lstm_node_proto->add_input(onnx_bias_name);
  lstm_node_proto->add_input("");  // seqlens
  lstm_node_proto->add_input(init_h_input_name);
  lstm_node_proto->add_input(init_c_input_name);

  auto Y_output_name = node_name + "_Y";
  auto Y_h_output_name = node_name + "_Y_h";
  auto Y_c_output_name = node_name + "_Y_c";
  lstm_node_proto->add_output(Y_output_name);
  lstm_node_proto->add_output(Y_h_output_name);
  lstm_node_proto->add_output(Y_c_output_name);

  onnx::AttributeProto *hidden_size_proto = lstm_node_proto->add_attribute();
  hidden_size_proto->set_name("hidden_size");
  hidden_size_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  hidden_size_proto->set_i(hidden_size);

  auto output_name_Y = MakeOutputName(node_name, kZeroNum);
  auto output_name_Y_h = MakeOutputName(node_name, kOneNum);
  auto output_name_Y_c = MakeOutputName(node_name, kTwoNum);
  AddReshapeOp(Y_output_name, output_name_Y, {seq_len, batch_size, num_dir * hidden_size}, graph_proto);
  AddExpandOp(Y_h_output_name, output_name_Y_h, {seq_len, batch_size, hidden_size}, graph_proto);
  AddExpandOp(Y_c_output_name, output_name_Y_c, {seq_len, batch_size, hidden_size}, graph_proto);
}

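// Converts the flat P.LSTM weight tensor into the separate W, R and B inputs of
// the ONNX LSTM op. Only single-layer, unidirectional LSTM exported from CPU or
// GPU is supported: on GPU the flat tensor carries separate input and hidden
// biases, while on CPU it carries a single bias, so the hidden bias half is
// zero-filled to match the ONNX layout.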
void ExportLSTMWeights(const CNodePtr &node, const std::string &node_name, const std::string &weights_name,
                       onnx::TensorProto_DataType dtype, const std::string &onnx_input_weights_name,
                       const std::string &onnx_hidden_weights_name, const std::string &onnx_bias_name,
                       onnx::GraphProto *graph_proto) {
  auto input_size = GetOpAttribute<int64_t>(node, "input_size");
  auto hidden_size = GetOpAttribute<int64_t>(node, "hidden_size");
  auto num_layers = GetOpAttribute<int64_t>(node, "num_layers");
  auto has_bias = GetOpAttribute<bool>(node, "has_bias");
  auto bidirectional = GetOpAttribute<bool>(node, "bidirectional");
  auto num_dir = 1 + static_cast<int>(bidirectional);
  auto num_gates = 4;
  auto gate_size = num_gates * hidden_size;

  if (num_layers != 1) {
    MS_LOG(EXCEPTION) << "Converter for multilayer LSTM is not implemented";
  }
  if (bidirectional) {
    MS_LOG(EXCEPTION) << "Bidirectional mode for P.LSTM is not implemented";
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto target_device = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  if (target_device != "CPU" && target_device != "GPU") {
    MS_LOG(EXCEPTION) << "Unsupported target device: " << target_device;
  }

  auto input_weights_name = node_name + "_input_weights";
  auto hidden_weights_name = node_name + "_hidden_weights";
  auto input_bias_name = node_name + "_input_bias";
  auto hidden_bias_name = node_name + "_hidden_bias";

  std::vector<int64_t> split_sizes = {input_size * gate_size, hidden_size * gate_size};
  std::vector<std::string> split_outputs = {input_weights_name, hidden_weights_name};
  if (has_bias) {
    if (target_device == "GPU") {
      (void)split_sizes.insert(split_sizes.end(), {gate_size, gate_size});
      (void)split_outputs.insert(split_outputs.end(), {input_bias_name, hidden_bias_name});
    } else if (target_device == "CPU") {
      split_sizes.push_back(gate_size);
      split_outputs.push_back(input_bias_name);
    } else {
      MS_LOG(EXCEPTION) << "Impossible branch";
    }
  }
  AddSplitOp(weights_name, split_outputs, split_sizes, 0, graph_proto);

  MakeLSTMWeight(input_weights_name, onnx_input_weights_name, {num_dir, gate_size, input_size}, graph_proto);
  MakeLSTMWeight(hidden_weights_name, onnx_hidden_weights_name, {num_dir, gate_size, hidden_size}, graph_proto);
  if (has_bias) {
    auto onnx_input_bias_name = node_name + "_onnx_input_bias";
    auto onnx_hidden_bias_name = node_name + "_onnx_hidden_bias";
    if (target_device == "GPU") {
      MakeLSTMWeight(input_bias_name, onnx_input_bias_name, {num_dir, gate_size}, graph_proto);
      MakeLSTMWeight(hidden_bias_name, onnx_hidden_bias_name, {num_dir, gate_size}, graph_proto);
    } else if (target_device == "CPU") {
      MakeLSTMWeight(input_bias_name, onnx_input_bias_name, {num_dir, gate_size}, graph_proto);
      auto bias_shape_name = node_name + "_bias_shape";
      AddOp("Shape", {onnx_input_bias_name}, {bias_shape_name}, graph_proto);
      onnx::TensorProto *zero_padding = AddConstantOfShapeOp(bias_shape_name, onnx_hidden_bias_name, graph_proto);
      zero_padding->set_data_type(dtype);
      if (dtype == onnx::TensorProto_DataType_FLOAT16) {
        zero_padding->add_int32_data(0);  // float 0 and int 0 have identical representations
      } else if (dtype == onnx::TensorProto_DataType_FLOAT) {
        zero_padding->add_float_data(0.0f);
      } else {
        MS_LOG(EXCEPTION) << "Unsupported type: " << dtype;
      }
    } else {
      MS_LOG(EXCEPTION) << "Impossible branch";
    }
    AddConcatOp({onnx_input_bias_name, onnx_hidden_bias_name}, onnx_bias_name, 1, graph_proto);
  }
}

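// Exports P.LSTM as an ONNX LSTM node. The ONNX output Y has layout
// [seq_len, num_dir, batch, hidden], so it is transposed and reshaped back to
// the MindSpore layout [seq_len, batch, num_dir * hidden].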
void OnnxExporter::ExportPrimLSTM(const FuncGraphPtr &, const CNodePtr &node,
                                  std::map<AnfNodePtr, std::string> *node_map_ptr,
                                  onnx::GraphProto *const graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);

  // MS inputs
  auto x_input_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto init_h_input_name = GetNodeInputName(node->input(kTwoNum), node_map_ptr, graph_proto);
  auto init_c_input_name = GetNodeInputName(node->input(kThreeNum), node_map_ptr, graph_proto);

  auto hidden_size = GetOpAttribute<int64_t>(node, "hidden_size");
  auto has_bias = GetOpAttribute<bool>(node, "has_bias");
  auto bidirectional = GetOpAttribute<bool>(node, "bidirectional");
  std::string direction = bidirectional ? "bidirectional" : "forward";
  auto x_input_shape = dyn_cast<abstract::Shape>(node->input(kOneNum)->Shape())->shape();
  auto seq_len = x_input_shape[0];
  auto batch_size = x_input_shape[1];
  auto num_dir = 1 + static_cast<int>(bidirectional);

  auto weights_name = GetNodeInputName(node->input(kFourNum), node_map_ptr, graph_proto);
  auto dtype = GetOutputType(node->input(kOneNum));
  auto onnx_input_weights_name = node_name + "_onnx_input_weights";
  auto onnx_hidden_weights_name = node_name + "_onnx_hidden_weights";
  auto onnx_bias_name = node_name + "_onnx_bias";

  ExportLSTMWeights(node, node_name, weights_name, dtype, onnx_input_weights_name, onnx_hidden_weights_name,
                    onnx_bias_name, graph_proto);

  // Create LSTM node
  onnx::NodeProto *lstm_node_proto = graph_proto->add_node();
  lstm_node_proto->set_op_type("LSTM");
  lstm_node_proto->add_input(x_input_name);
  lstm_node_proto->add_input(onnx_input_weights_name);
  lstm_node_proto->add_input(onnx_hidden_weights_name);
  lstm_node_proto->add_input(has_bias ? onnx_bias_name : "");
  lstm_node_proto->add_input("");  // seqlens
  lstm_node_proto->add_input(init_h_input_name);
  lstm_node_proto->add_input(init_c_input_name);

  auto Y_output_name = node_name + "_Y";
  lstm_node_proto->add_output(Y_output_name);
  lstm_node_proto->add_output(MakeOutputName(node_name, kOneNum));
  lstm_node_proto->add_output(MakeOutputName(node_name, kTwoNum));

  onnx::AttributeProto *hidden_size_proto = lstm_node_proto->add_attribute();
  hidden_size_proto->set_name("hidden_size");
  hidden_size_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  hidden_size_proto->set_i(hidden_size);

  onnx::AttributeProto *direction_proto = lstm_node_proto->add_attribute();
  direction_proto->set_name("direction");
  direction_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
  direction_proto->set_s(direction);

  // Transpose 1st output of the LSTM node
  onnx::NodeProto *transpose_node_proto = graph_proto->add_node();
  auto transpose_node_name = node_name + "_Y_transposed";
  transpose_node_proto->set_name(transpose_node_name);
  transpose_node_proto->set_op_type("Transpose");
  transpose_node_proto->add_input(Y_output_name);
  transpose_node_proto->add_output(transpose_node_name);

  onnx::AttributeProto *perm_proto = transpose_node_proto->add_attribute();
  perm_proto->set_name("perm");
  perm_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
  perm_proto->add_ints(kZeroNum);
  perm_proto->add_ints(kTwoNum);
  perm_proto->add_ints(kOneNum);
  perm_proto->add_ints(kThreeNum);

  // Reshape
  auto output_name = MakeOutputName(node_name, kZeroNum);
  AddReshapeOp(transpose_node_name, output_name, {seq_len, batch_size, num_dir * hidden_size}, graph_proto);
}

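// ReverseV2 has no direct ONNX counterpart; emulate it with a Slice that walks
// each reversed axis backwards: starts = -1, ends = -dim - 1, steps = -1.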
void OnnxExporter::ExportPrimReverseV2(const FuncGraphPtr &, const CNodePtr &node,
                                       std::map<AnfNodePtr, std::string> *node_map_ptr,
                                       onnx::GraphProto *const graph_proto) {
  auto output = RegisterNodeWithUniqueName(node, node_map_ptr);
  auto input = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);

  auto axes_ptr = GetOpAttributePtr<ValueSequence>(node, "axis");
  auto axes_vec = GetValue<std::vector<int64_t>>(axes_ptr);
  size_t n_axes = axes_vec.size();
  auto shape = dyn_cast<abstract::Shape>(node->input(kOneNum)->Shape())->shape();

  std::vector<int64_t> starts_vec(n_axes, -1);
  std::vector<int64_t> ends_vec(n_axes);
  (void)std::transform(axes_vec.begin(), axes_vec.end(), ends_vec.begin(),
                       [&shape](size_t ax) { return -shape.at(ax) - 1; });
  std::vector<int64_t> steps_vec(n_axes, -1);

  AddSliceOp(input, output, starts_vec, ends_vec, axes_vec, steps_vec, graph_proto);
}

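// TensorCopySlices writes `value` into a contiguous slice of `x`. ONNX has no
// in-place update, so flatten x, cut out the regions before and after the
// target slice, concatenate them with the flattened value, and reshape back.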
void OnnxExporter::ExportPrimTensorCopySlices(const FuncGraphPtr &, const CNodePtr &node,
                                              std::map<AnfNodePtr, std::string> *node_map_ptr,
                                              onnx::GraphProto *graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);

  auto x_input = node->input(kOneNum);
  auto value_input = node->input(kTwoNum);

  auto x_input_name = GetNodeInputName(x_input, node_map_ptr, graph_proto);
  auto value_input_name = GetNodeInputName(value_input, node_map_ptr, graph_proto);

  const auto &x_shape = dyn_cast<abstract::Shape>(x_input->Shape())->shape();
  const auto &value_shape = dyn_cast<abstract::Shape>(value_input->Shape())->shape();

  auto begin_node = dyn_cast<ValueNode>(node->input(kThreeNum));
  MS_EXCEPTION_IF_NULL(begin_node);
  auto begin = GetValue<std::vector<int64_t>>(begin_node->value());

  auto end_node = dyn_cast<ValueNode>(node->input(kFourNum));
  MS_EXCEPTION_IF_NULL(end_node);
  auto end = GetValue<std::vector<int64_t>>(end_node->value());

  auto strides_node = dyn_cast<ValueNode>(node->input(kFiveNum));
  MS_EXCEPTION_IF_NULL(strides_node);
  auto strides = GetValue<std::vector<int64_t>>(strides_node->value());

  MS_EXCEPTION_IF_CHECK_FAIL(
    begin.size() == end.size() && end.size() == strides.size() && strides.size() <= x_shape.size(),
    "Sizes of begin, end, and strides must be equal and must not exceed the rank of x");
  // MindSpore only allows contiguous slices of memory
  // A contiguous slice shape follows the pattern: [1, ..., 1, n, :, ..., :]
  bool found_slice = false;
  for (size_t i = 0; i < begin.size(); ++i) {
    int64_t dim = end[i] - begin[i];
    if (!found_slice && dim != 1) {
      found_slice = true;
    } else if (found_slice && dim != x_shape[i]) {
      MS_LOG(EXCEPTION) << "Slice must be contiguous";
    }
  }
  for (auto stride : strides) {
    MS_EXCEPTION_IF_CHECK_FAIL(stride == 1, "Slice must be contiguous");
  }

  int64_t flat_begin_index = RavelIndex(begin, x_shape);

  std::vector<int64_t> end_inclusive;
  (void)std::transform(end.begin(), end.end(), std::back_inserter(end_inclusive), [](auto x) { return x - 1; });
  (void)std::transform(x_shape.begin() + static_cast<int64_t>(end.size()), x_shape.end(),
                       std::back_inserter(end_inclusive), [](auto x) { return x - 1; });
  int64_t flat_end_index = RavelIndex(end_inclusive, x_shape) + 1;

  int64_t x_size = std::accumulate(x_shape.begin(), x_shape.end(), 1, std::multiplies<int64_t>());
  int64_t value_size = std::accumulate(value_shape.begin(), value_shape.end(), 1, std::multiplies<int64_t>());
  MS_EXCEPTION_IF_CHECK_FAIL(value_size == flat_end_index - flat_begin_index, "Cannot copy 'value' to target slice");

  auto flat_x_name = node_name + "_flat_x";
  AddReshapeOp(x_input_name, flat_x_name, {-1}, graph_proto);
  auto begin_slice_name = node_name + "_begin_slice";
  AddSliceOp(flat_x_name, begin_slice_name, {0}, {static_cast<int64_t>(flat_begin_index)}, {0}, {1}, graph_proto);
  auto end_slice_name = node_name + "_end_slice";
  AddSliceOp(flat_x_name, end_slice_name, {static_cast<int64_t>(flat_end_index)}, {x_size}, {0}, {1}, graph_proto);

  auto flat_value_name = node_name + "_flat_value";
  AddReshapeOp(value_input_name, flat_value_name, {-1}, graph_proto);

  auto flat_result_name = node_name + "_flat_result";
  AddConcatOp({begin_slice_name, flat_value_name, end_slice_name}, flat_result_name, 0, graph_proto);
  AddReshapeOp(flat_result_name, node_name, x_shape, graph_proto);
}

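// Stack maps to ONNX ConcatFromSequence with new_axis = 1, which inserts the
// concatenation axis instead of joining along an existing one.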
void OnnxExporter::ExportPrimStack(const FuncGraphPtr &, const CNodePtr &node,
                                   std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);

  auto input_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);

  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_name(node_name + "Stack");
  node_proto->set_op_type("ConcatFromSequence");
  node_proto->add_input(input_name);
  node_proto->add_output(node_name);

  onnx::AttributeProto *axis_proto = node_proto->add_attribute();
  axis_proto->set_name("axis");
  axis_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  axis_proto->set_i(GetOpAttribute<int64_t>(node, "axis"));

  onnx::AttributeProto *new_axis_proto = node_proto->add_attribute();
  new_axis_proto->set_name("new_axis");
  new_axis_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  new_axis_proto->set_i(1);
}

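// ONNX opset 11 has no Atan2. Compute atan(x / (y + eps)) and fix up the
// quadrant: when y < 0, the result is shifted by -pi * sign(atan(x / y)).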
void OnnxExporter::ExportPrimAtan2(const FuncGraphPtr &, const CNodePtr &node,
                                   std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  auto input_node1_anf = node->input(kOneNum);
  auto input_node2_anf = node->input(kTwoNum);
  auto input_node1 = GetNodeInputName(input_node1_anf, node_map_ptr, graph_proto);
  auto input_node2 = GetNodeInputName(input_node2_anf, node_map_ptr, graph_proto);
  auto atan_node = "Atan2_" + node_name + "_atan";
  auto div_node = "Atan2_" + node_name + "_div";
  auto less_node = "Atan2_" + node_name + "_less";
  auto zero_value = "Atan2_" + node_name + "_zero";
  auto neg_pi_value = "Atan2_" + node_name + "_pi";
  auto minimal_value = "Atan2_" + node_name + "_minimal_val";
  auto sign_node = "Atan2_" + node_name + "_sign";
  auto mul_node = "Atan2_" + node_name + "_mul";
  auto less_where_node1 = "Atan2_" + node_name + "_less_then_else1";
  auto add_node = "Atan2_" + node_name + "_add1";
  if (!(IsFloatDataType(input_node1_anf) && IsFloatDataType(input_node2_anf))) {
    auto input_node1_cast = node_name + "_div_cast_fp32_1";
    auto input_node2_cast = node_name + "_div_cast_fp32_2";
    AddCastOp(input_node1, input_node1_cast, onnx::TensorProto_DataType_FLOAT, graph_proto);
    AddCastOp(input_node2, input_node2_cast, onnx::TensorProto_DataType_FLOAT, graph_proto);
    input_node1 = input_node1_cast;
    input_node2 = input_node2_cast;
  }
  AddFloatScalarInitializer(minimal_value, 1e-10, onnx::TensorProto_DataType_FLOAT,
                            graph_proto);  // minimal_value, avoid division by zero
  AddOp("Add", {input_node2, minimal_value}, {add_node}, graph_proto);
  AddOp("Div", {input_node1, add_node}, {div_node}, graph_proto);
  AddOp("Atan", {div_node}, {atan_node}, graph_proto);
  AddFloatScalarInitializer(zero_value, 0, onnx::TensorProto_DataType_FLOAT, graph_proto);
  AddOp("Less", {input_node2, zero_value}, {less_node}, graph_proto);
  AddFloatScalarInitializer(neg_pi_value, -acos(-1), onnx::TensorProto_DataType_FLOAT, graph_proto);  // -PI
  AddOp("Sign", {atan_node}, {sign_node}, graph_proto);
  AddOp("Mul", {neg_pi_value, sign_node}, {mul_node}, graph_proto);
  AddOp("Where", {less_node, mul_node, zero_value}, {less_where_node1}, graph_proto);
  AddOp("Add", {less_where_node1, atan_node}, {node_name}, graph_proto);
}

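// FloorDiv is exported as Floor(Div(x, y)). Integer inputs are cast to float
// for the division and the result is cast back to the original type.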
void OnnxExporter::ExportPrimFloorDiv(const FuncGraphPtr &, const CNodePtr &node,
                                      std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  auto out_name = node_name;
  auto input_x_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto input_y_name = GetNodeInputName(node->input(kTwoNum), node_map_ptr, graph_proto);
  auto onnx_type = GetOutputType(node->input(kOneNum));
  bool is_float = onnx_type == onnx::TensorProto_DataType_FLOAT;

  if (!is_float) {
    auto input_x_name_cast = input_x_name + "_cast";
    auto input_y_name_cast = input_y_name + "_cast";
    AddCastOp(input_x_name, input_x_name_cast, onnx::TensorProto_DataType_FLOAT, graph_proto);
    AddCastOp(input_y_name, input_y_name_cast, onnx::TensorProto_DataType_FLOAT, graph_proto);
    input_x_name = input_x_name_cast;
    input_y_name = input_y_name_cast;
    node_name = node_name + "_floor";
  }

  auto div_name = node_name + "_div";
  AddOp("Div", {input_x_name, input_y_name}, {div_name}, graph_proto);
  AddOp("Floor", {div_name}, {node_name}, graph_proto);

  if (!is_float) {
    AddCastOp(node_name, out_name, onnx_type, graph_proto);
  }
}

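// FloorMod is exported as x - Floor(x / y) * y, with the same float round-trip
// for integer inputs as FloorDiv.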
void OnnxExporter::ExportPrimFloorMod(const FuncGraphPtr &, const CNodePtr &node,
                                      std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  auto out_name = node_name;
  auto input_x_name = GetNodeInputName(node->input(kOneNum), node_map_ptr, graph_proto);
  auto input_y_name = GetNodeInputName(node->input(kTwoNum), node_map_ptr, graph_proto);
  auto onnx_type = GetOutputType(node->input(kOneNum));
  bool is_float = onnx_type == onnx::TensorProto_DataType_FLOAT;

  if (!is_float) {
    auto input_x_name_cast = input_x_name + "_cast";
    auto input_y_name_cast = input_y_name + "_cast";
    AddCastOp(input_x_name, input_x_name_cast, onnx::TensorProto_DataType_FLOAT, graph_proto);
    AddCastOp(input_y_name, input_y_name_cast, onnx::TensorProto_DataType_FLOAT, graph_proto);
    input_x_name = input_x_name_cast;
    input_y_name = input_y_name_cast;
    node_name = node_name + "_sub";
  }

  auto div_name = node_name + "_div";
  auto mul_name = node_name + "_mul";
  auto floor_name = node_name + "_floor";
  AddOp("Div", {input_x_name, input_y_name}, {div_name}, graph_proto);
  AddOp("Floor", {div_name}, {floor_name}, graph_proto);
  AddOp("Mul", {floor_name, input_y_name}, {mul_name}, graph_proto);
  AddOp("Sub", {input_x_name, mul_name}, {node_name}, graph_proto);

  if (!is_float) {
    AddCastOp(node_name, out_name, onnx_type, graph_proto);
  }
}

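// Sort maps to ONNX TopK with k equal to the full extent of the sorted axis;
// `largest` mirrors the `descending` attribute, and the int64 indices output
// is cast to int32 to match the MindSpore output type.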
void OnnxExporter::ExportPrimSort(const FuncGraphPtr &, const CNodePtr &node,
                                  std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);

  auto x_input = node->input(kOneNum);
  auto x_input_name = GetNodeInputName(x_input, node_map_ptr, graph_proto);
  auto x_input_shape = dyn_cast<abstract::Shape>(x_input->Shape())->shape();

  auto axis_attr = GetOpAttribute<int64_t>(node, "axis");
  auto descending_attr = GetOpAttribute<bool>(node, "descending");

  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_name(node_name + "TopK");
  node_proto->set_op_type("TopK");
  node_proto->add_input(x_input_name);

  onnx::TensorProto *k_initializer_proto = graph_proto->add_initializer();
  auto k_input_name = node_name + "_k";  // unique per node to avoid clashing initializer names
  k_initializer_proto->set_name(k_input_name);
  k_initializer_proto->add_dims(static_cast<int64_t>(1));
  k_initializer_proto->set_data_type(GetOnnxDataType(kNumberTypeInt64));
  int64_t k_index = axis_attr;
  if (axis_attr < 0) {
    k_index += SizeToLong(x_input_shape.size());
  }
  if (k_index > SizeToLong(x_input_shape.size()) - 1 || k_index < 0) {
    MS_LOG(EXCEPTION) << "Invalid axis value: " << axis_attr;
  }
  int64_t k_value = x_input_shape[k_index];
  k_initializer_proto->add_int64_data(k_value);
  node_proto->add_input(k_input_name);

  node_proto->add_output(MakeOutputName(node_name, kZeroNum));
  auto indices_output_name = MakeOutputName(node_name, kOneNum);
  auto indices_cast_name = indices_output_name + "_cast";
  node_proto->add_output(indices_cast_name);
  AddCastOp(indices_cast_name, indices_output_name, onnx::TensorProto_DataType_INT32, graph_proto);

  onnx::AttributeProto *axis_attr_proto = node_proto->add_attribute();
  axis_attr_proto->set_name("axis");
  axis_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  axis_attr_proto->set_i(axis_attr);

  onnx::AttributeProto *largest_attr_proto = node_proto->add_attribute();
  largest_attr_proto->set_name("largest");
  largest_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  if (descending_attr) {
    largest_attr_proto->set_i(kOneNum);
  } else {
    largest_attr_proto->set_i(kZeroNum);
  }

  onnx::AttributeProto *sorted_attr_proto = node_proto->add_attribute();
  sorted_attr_proto->set_name("sorted");
  sorted_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
  sorted_attr_proto->set_i(1);
}

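// Exports a Custom op. Value-node inputs whose names are listed in the
// primitive's `attr_names` are really attributes: they are converted to ONNX
// attributes, and only the remaining inputs become ONNX node inputs.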
void OnnxExporter::ExportPrimCustom(const FuncGraphPtr &, const CNodePtr &node,
                                    std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_name("Custom_" + node_name);
  mindspore::HashSet<size_t> input_attrs;

  constexpr auto kAttrCusInputNames = "input_names";
  constexpr auto kAttrCusAttrNames = "attr_names";
  auto input_names_vec = GetOpAttribute<std::vector<std::string>>(node, kAttrCusInputNames);
  auto primitive = GetPrimitive(node);
  auto attr_names = primitive->GetAttr(kAttrCusAttrNames);
  if (attr_names != nullptr) {
    auto attr_names_vec = GetValue<std::vector<std::string>>(attr_names);
    for (size_t i = 0; i < input_names_vec.size(); ++i) {
      if (std::find(attr_names_vec.begin(), attr_names_vec.end(), input_names_vec[i]) != attr_names_vec.end()) {
        (void)input_attrs.insert(i);
      }
    }
  }

  auto inputs = node->inputs();
  std::vector<AnfNodePtr> real_inputs;

  for (size_t i = 0; i < inputs.size() - 1; ++i) {
    auto input_node = inputs[i + 1];
    MS_EXCEPTION_IF_NULL(input_node);
    if (input_attrs.find(i) != input_attrs.end() && input_node->isa<ValueNode>() && !HasAbstractMonad(input_node)) {
      auto value_node = input_node->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(value_node);
      auto attr_value = value_node->value();
      if (attr_value->isa<StringImm>()) {
        auto str_attr = GetValue<std::string>(attr_value);
        onnx::AttributeProto *str_proto = node_proto->add_attribute();
        str_proto->set_name(input_names_vec[i]);
        str_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
        str_proto->set_s(str_attr);
      } else if (attr_value->isa<IntegerImm>()) {
        int64_t int64_attr = attr_value->cast<Int64ImmPtr>()->value();
        onnx::AttributeProto *int64_proto = node_proto->add_attribute();
        int64_proto->set_name(input_names_vec[i]);
        int64_proto->set_type(onnx::AttributeProto_AttributeType_INT);
        int64_proto->set_i(int64_attr);
      } else if (attr_value->isa<FloatImm>()) {
        float fp32_attr = attr_value->cast<FP32ImmPtr>()->value();
        onnx::AttributeProto *fp32_proto = node_proto->add_attribute();
        fp32_proto->set_name(input_names_vec[i]);
        fp32_proto->set_type(onnx::AttributeProto_AttributeType_FLOAT);
        fp32_proto->set_f(fp32_attr);
      } else {
        MS_LOG(EXCEPTION) << "Unsupported attr input type: " << attr_value->ToString();
      }
    } else {
      real_inputs.push_back(inputs[i + 1]);
    }
  }

  for (size_t idx = 0; idx < real_inputs.size(); idx++) {
    auto input_name = GetNodeInputName(real_inputs[idx], node_map_ptr, graph_proto);
    node_proto->add_input(input_name);
  }

  node_proto->add_output(node_name);
  node_proto->set_op_type(GetOpAttribute<std::string>(node, "reg_op_name"));
}

void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
                               std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *const graph_proto) {
  using ExportFunc = std::function<void(OnnxExporter *, const FuncGraphPtr &, const CNodePtr &,
                                        std::map<AnfNodePtr, std::string> *, onnx::GraphProto *const)>;
  static std::vector<std::pair<PrimitivePtr, ExportFunc>> export_table = {
    {prim::kPrimReshape, &OnnxExporter::ExportPrimReshape},
    {prim::kPrimReduceMean, &OnnxExporter::ExportPrimReduce},
    {prim::kPrimReduceSum, &OnnxExporter::ExportPrimReduce},
    {prim::kPrimReduceMax, &OnnxExporter::ExportPrimReduce},
    {prim::kPrimReduceAny, &OnnxExporter::ExportPrimReduceAnyOrAll},
    {prim::kPrimReduceAll, &OnnxExporter::ExportPrimReduceAnyOrAll},
    {prim::kPrimTranspose, &OnnxExporter::ExportPrimTranspose},
    {prim::kPrimStridedSlice, &OnnxExporter::ExportPrimStridedSlice},
    {prim::kPrimResizeNearestNeighbor, &OnnxExporter::ExportPrimResizeNearestNeighbor},
    {prim::kPrimResizeBilinearV2, &OnnxExporter::ExportPrimResizeBilinear},
    {prim::kPrimConcat, &OnnxExporter::ExportPrimConcat},
    {prim::kPrimCast, &OnnxExporter::ExportPrimCast},
    {prim::kPrimPReLU, &OnnxExporter::ExportPrimPReLU},
    {prim::kPrimReLU6, &OnnxExporter::ExportPrimReLU6},
    {prim::kPrimDepthwiseConv2dNative, &OnnxExporter::ExportPrimDepthwiseConv2d},
    {prim::kPrimTile, &OnnxExporter::ExportPrimTile},
    {prim::kPrimSquare, &OnnxExporter::ExportPrimSquare},
    {prim::kPrimGather, &OnnxExporter::ExportPrimGatherV2},
    {prim::kPrimTupleGetItem, &OnnxExporter::ExportPrimTupleGetItem},
    {prim::kPrimTopK, &OnnxExporter::ExportPrimTopK},
    {prim::kPrimBoundingBoxDecode, &OnnxExporter::ExportPrimBoundingBoxDecode},
    {prim::kPrimNMSWithMask, &OnnxExporter::ExportPrimNMSWithMask},
    {prim::kPrimSplit, &OnnxExporter::ExportPrimSplit},
    {prim::kPrimROIAlign, &OnnxExporter::ExportPrimROIAlign},
    {prim::kPrimSlice, &OnnxExporter::ExportPrimSlice},
    {prim::kPrimOnesLike, &OnnxExporter::ExportPrimOnesLike},
    {prim::kPrimScatterNd, &OnnxExporter::ExportPrimScatterNd},
    {prim::kPrimArgMaxWithValue, &OnnxExporter::ExportPrimArgMaxWithValue},
    {prim::kPrimArgMinWithValue, &OnnxExporter::ExportPrimArgMinWithValue},
    {prim::kPrimOneHot, &OnnxExporter::ExportPrimOneHot},
    {prim::kPrimConv2DTranspose, &OnnxExporter::ExportPrimConv2DTranspose},
    {prim::kPrimGreaterEqual, &OnnxExporter::ExportPrimGreaterEqual},
    {prim::kPrimLessEqual, &OnnxExporter::ExportPrimLessEqual},
    {prim::kPrimNotEqual, &OnnxExporter::ExportPrimNotEqual},
    {prim::kPrimDense, &OnnxExporter::ExportPrimDense},
    {prim::kPrimSqueeze, &OnnxExporter::ExportPrimSqueeze},
    {prim::kPrimExpandDims, &OnnxExporter::ExportPrimExpandDims},
    {prim::kPrimGatherD, &OnnxExporter::ExportPrimGatherD},
    {prim::kPrimPad, &OnnxExporter::ExportPrimPad},
    {prim::kPrimBatchMatMul, &OnnxExporter::ExportPrimBatchMatMul},
    {prim::kPrimBroadcastTo, &OnnxExporter::ExportPrimBroadcastTo},
    {prim::kPrimAddN, &OnnxExporter::ExportPrimAddN},
    {prim::kPrimGeLU, &OnnxExporter::ExportPrimGeLU},
    {prim::kPrimLstm, &OnnxExporter::ExportPrimLSTM},
    {prim::kPrimReverseV2, &OnnxExporter::ExportPrimReverseV2},
    {prim::kPrimTensorCopySlices, &OnnxExporter::ExportPrimTensorCopySlices},
    {prim::kPrimDynamicRNN, &OnnxExporter::ExportPrimDynamicRNN},
    {prim::kPrimStack, &OnnxExporter::ExportPrimStack},
    {prim::kPrimAtan2, &OnnxExporter::ExportPrimAtan2},
    {prim::kPrimFloorDiv, &OnnxExporter::ExportPrimFloorDiv},
    {prim::kPrimFloorMod, &OnnxExporter::ExportPrimFloorMod},
    {prim::kPrimSort, &OnnxExporter::ExportPrimSort},
    {prim::kPrimCustom, &OnnxExporter::ExportPrimCustom},
  };

  auto iter = std::find_if(export_table.begin(), export_table.end(),
                           [&node](const auto &item) { return node->IsApply(item.first); });
  if (iter != export_table.end()) {
    iter->second(this, func_graph, node, node_map_ptr, graph_proto);
    return;
  }

  auto inputs = node->inputs();
  if (inputs.size() < 1) {
    MS_LOG(EXCEPTION) << "Inputs of apply node are empty";
  }

  AnfNodePtr op = inputs[kZeroNum];
  std::vector<AnfNodePtr> op_inputs;
  // Process inputs 1, 2, ... first: an input that is a ValueNode requires creating a Constant operator here
  for (size_t i = 1; i < inputs.size(); i++) {
    if (!HasAbstractMonad(inputs[i])) {
      op_inputs.push_back(inputs[i]);
    }
  }

  if (!op->isa<ValueNode>()) {
    MS_LOG(EXCEPTION) << "Need to support node op type " << op->type_name();
  }

  auto op_value = dyn_cast<ValueNode>(op)->value();
  if (op_value->isa<Primitive>()) {
    auto prim = dyn_cast<Primitive>(op_value);
    (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim, op_inputs, graph_proto);
  } else if (while_loop_export::IsControlSubgraph(op_value)) {
    ExportWhileLoop(node, node_map_ptr, graph_proto);
  } else {
    MS_LOG(EXCEPTION) << "Need to support node op value type " << op_value->type_name();
  }
}

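// Exports a MindSpore while loop as an ONNX Loop node. The trip count is
// computed statically from the matched (begin, end, step) counter, the loop
// body is exported as the Loop body subgraph, and the code after the loop
// consumes the Loop outputs through renamed_node_map_.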
void OnnxExporter::ExportWhileLoop(const CNodePtr &start_node, std::map<AnfNodePtr, std::string> *node_map_ptr,
                                   onnx::GraphProto *graph_proto) {
  auto node_name = RegisterNodeWithUniqueName(start_node, node_map_ptr);
  auto loop_parts = while_loop_export::MatchGraph(start_node);

  // 1. Make Loop op

  onnx::NodeProto *loop_proto = graph_proto->add_node();
  loop_proto->set_op_type("Loop");

  auto loop_count_name = node_name + "_M";
  const auto &loop_counter_params = loop_parts.loop_condition_info;
  int64_t loop_count = (loop_counter_params.end - loop_counter_params.begin) / loop_counter_params.step;
  onnx::TensorProto *loop_count_proto = graph_proto->add_initializer();
  loop_count_proto->set_name(loop_count_name);
  loop_count_proto->set_data_type(onnx::TensorProto_DataType_INT64);
  loop_count_proto->add_int64_data(loop_count);

  auto loop_cond_name = node_name + "_cond";
  auto *cond_value = graph_proto->add_initializer();
  cond_value->set_name(loop_cond_name);
  cond_value->set_data_type(onnx::TensorProto_DataType_BOOL);
  cond_value->add_int32_data(true);

  loop_proto->add_input(loop_count_name);
  loop_proto->add_input(loop_cond_name);
  for (const auto &[loop_i, control_i] : loop_parts.used_loop_to_control_param_indices) {
    auto name = GetNodeInputName(start_node->input(control_i + 1), node_map_ptr, graph_proto);
    loop_proto->add_input(name);
    loop_proto->add_output(MakeOutputName(node_name + "_loop", loop_i));
  }

  onnx::AttributeProto *subgraph_attr = loop_proto->add_attribute();
  subgraph_attr->set_type(onnx::AttributeProto_AttributeType_GRAPH);
  subgraph_attr->set_name("body");
  onnx::GraphProto *loop_subgraph_proto = subgraph_attr->mutable_g();

  // 2. Create subgraph for loop body

  auto subgraph_name = loop_parts.loop_subgraph->ToString();
  auto subgraph_input_cond_name = subgraph_name + "_input_cond";

  auto *iter_num_input = loop_subgraph_proto->add_input();
  iter_num_input->set_name(subgraph_name + "_input_M");
  (void)iter_num_input->mutable_type()->mutable_tensor_type()->mutable_shape();  // side-effect: shape created
  iter_num_input->mutable_type()->mutable_tensor_type()->set_elem_type(onnx::TensorProto_DataType_INT64);

  auto *cond_input = loop_subgraph_proto->add_input();
  cond_input->set_name(subgraph_input_cond_name);
  cond_input->mutable_type()->mutable_tensor_type()->set_elem_type(cond_value->data_type());

  auto *cond_output = loop_subgraph_proto->add_output();
  cond_output->set_name(cond_input->name());
  cond_output->mutable_type()->mutable_tensor_type()->set_elem_type(cond_value->data_type());

  MS_EXCEPTION_IF_CHECK_FAIL(renamed_node_map_.empty(), "renamed_nodes must be cleared after subgraph export");
  for (size_t i : loop_parts.ignored_loop_param_indices) {
    const auto &param = loop_parts.loop_subgraph->parameters().at(i);
    renamed_node_map_[param] = "";
  }

  // Export everything except the control call and the output (see MatchAndMark)
  ExportFuncGraph(loop_parts.loop_subgraph, node_map_ptr, loop_subgraph_proto);

  // Export outputs manually
  for (const auto &loop_to_control_i : loop_parts.used_loop_to_control_param_indices) {
    const auto &input = loop_parts.repeat_node->input(loop_to_control_i.second + 1);
    ExportOutput(loop_parts.loop_subgraph, input, node_map_ptr, loop_subgraph_proto);
  }
  renamed_node_map_.clear();

  // 3. Export part after loop

  MS_EXCEPTION_IF_CHECK_FAIL(renamed_node_map_.empty(), "renamed_nodes must be cleared after subgraph export");
  const auto &after_loop_params = loop_parts.after_loop_subgraph->parameters();
  for (const auto &[after_i, output_i] : loop_parts.after_param_to_output_indices) {
    MS_EXCEPTION_IF_CHECK_FAIL(static_cast<int>(output_i) < loop_proto->output_size(), "Output index out of bounds");
    renamed_node_map_[after_loop_params.at(after_i)] = loop_proto->output(output_i);
  }
  ExportFuncGraph(loop_parts.after_loop_subgraph, node_map_ptr, graph_proto, false);

  auto after_loop_retval = GetRealInput(loop_parts.after_loop_subgraph->get_return()->input(1));
  if (after_loop_retval->isa<CNode>() && after_loop_retval->cast<CNodePtr>()->IsApply(prim::kPrimMakeTuple)) {
    auto tuple_retval = dyn_cast<CNode>(after_loop_retval);
    for (size_t i = 1; i < tuple_retval->size(); ++i) {
      auto output_name = GetNodeInputName(tuple_retval->input(i), node_map_ptr, graph_proto);
      AddOp("Identity", {output_name}, {MakeOutputName(node_name, i - 1)}, graph_proto);
    }
  } else {
    auto output_name = GetNodeInputName(after_loop_retval, node_map_ptr, graph_proto);
    AddOp("Identity", {output_name}, {node_name}, graph_proto);
  }
  renamed_node_map_.clear();
}

onnx::TensorProto_DataType OnnxExporter::GetOutputType(const AnfNodePtr &node, int64_t output_index) {
  auto unpacked = GetRealInput(node);
  if (IsPrimitiveCNode(unpacked, prim::kPrimTupleGetItem)) {
    if (output_index != -1) {
      MS_LOG(EXCEPTION) << "Unexpected output index for TupleGetItem: " << output_index;
    }
    auto cnode = dyn_cast<CNode>(unpacked);
    unpacked = cnode->input(kOneNum);
    output_index = GetInt64Value(cnode->input(kTwoNum));
  }

  /*
    Special cases (MS and ONNX type differences) go here
    Example:
      if (IsPrimitiveCNode(unpacked, prim::kPrim<Something>) && output_index == <i>) {
        return onnx::TensorProto_DataType_<TYPE>;
      }
  */

  if (output_index == -1) {
    auto tensor = dyn_cast<TensorType>(unpacked->Type());
    if (tensor == nullptr) {
      MS_LOG(EXCEPTION) << "Expected output of node " << unpacked->ToString()
                        << " to be a single tensor. Instead got: " << unpacked->Type()->ToString();
    }
    return GetOnnxDataType(tensor->element()->type_id());
  } else {
    auto tuple_type = dyn_cast<Tuple>(unpacked->Type());
    if (tuple_type == nullptr) {
      MS_LOG(EXCEPTION) << "Expected output of node " << unpacked->ToString()
                        << " to be a tuple. Instead got: " << unpacked->Type()->ToString();
    }
    auto element_type = tuple_type->elements()[static_cast<size_t>(output_index)];
    MS_EXCEPTION_IF_NULL(element_type);
    auto tensor_type = dyn_cast<TensorType>(element_type);
    if (tensor_type == nullptr) {
      MS_LOG(EXCEPTION) << "Expected output " << output_index << " of node " << unpacked->ToString()
                        << " to be a tensor. Instead got: " << element_type->ToString();
    }
    return GetOnnxDataType(tensor_type->element()->type_id());
  }
}

void OnnxExporter::AddOutputWithCast(onnx::NodeProto *node_proto, const std::string &output_name,
                                     onnx::TensorProto_DataType target_type, onnx::GraphProto *graph_proto) const {
  if (target_type == onnx::TensorProto_DataType_UNDEFINED) {
    node_proto->add_output(output_name);
  } else {
    auto output_to_cast_name = output_name + "_output_to_cast";
    node_proto->add_output(output_to_cast_name);
    AddCastOp(output_to_cast_name, output_name, target_type, graph_proto);
  }
}

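// Generic single-op export driven by OpConvertRegistry. Registered input-cast
// rules are applied first; output-cast rules in INPUT mode restore the output
// type to match the (pre-cast) input type, while FIXED-mode rules cast the
// output to a given type unconditionally.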
std::string OnnxExporter::ExportPrimitive(const FuncGraphPtr &, std::map<AnfNodePtr, std::string> *node_map_ptr,
                                          const PrimitivePtr &prim, const std::vector<AnfNodePtr> &inputs,
                                          onnx::GraphProto *const graph_proto) {
  auto op_map = OpConvertRegistry::GetOpConvertMap();
  MS_EXCEPTION_IF_NULL(prim);
  auto op_iter = op_map.find(prim->name());
  if (op_iter == op_map.end()) {
    MS_LOG(EXCEPTION) << "Cannot find key " << prim->name() << " in convert map. "
                      << "Exporting " << prim->name() << " operator is not yet supported.";
  }
  // Get the inputs first, because an input may be a ValueNode that requires creating a Constant node
  std::vector<std::string> input_list;
  for (const auto &input : inputs) {
    auto input_name = GetNodeInputName(input, node_map_ptr, graph_proto);
    input_list.push_back(input_name);
  }

  const OpNameInfo &op_convert_info = op_iter->second;
  auto node_name = GenerateUniqueName();

  std::vector<onnx::TensorProto_DataType> output_cast_types(op_convert_info.num_outputs(),
                                                            onnx::TensorProto_DataType_UNDEFINED);
  // Cast inputs if needed
  for (const auto &rule : op_convert_info.input_casts()) {
    auto original_type = GetOutputType(inputs[static_cast<size_t>(rule.input_index)]);
    if (original_type != rule.input_type) {
      continue;
    }

    auto cast_input_name = node_name + "cast_input_" + std::to_string(rule.input_index);
    AddCastOp(input_list[static_cast<size_t>(rule.input_index)], cast_input_name, rule.target_type, graph_proto);
    input_list[static_cast<size_t>(rule.input_index)] = cast_input_name;

    auto output_cast = std::find_if(
      op_convert_info.output_casts().begin(), op_convert_info.output_casts().end(), [&rule](const OutputConversion &x) {
        return x.mode == OutputConversion::Mode::INPUT && x.input_with_matching_type == rule.input_index;
      });
    if (output_cast != op_convert_info.output_casts().end()) {
      output_cast_types[static_cast<size_t>(output_cast->output_index)] = original_type;
    }
  }

  for (const auto &output_cast : op_convert_info.output_casts()) {
    if (output_cast.mode == OutputConversion::Mode::FIXED) {
      output_cast_types[static_cast<size_t>(output_cast.output_index)] = output_cast.target_type;
    }
  }

  onnx::NodeProto *node_proto = graph_proto->add_node();
  node_proto->set_name(node_name + op_convert_info.onnx_type());
  node_proto->set_op_type(op_convert_info.onnx_type());

  // Set outputs
  if (op_convert_info.num_outputs() == 1) {
    AddOutputWithCast(node_proto, node_name, output_cast_types[0], graph_proto);
  } else {
    for (int i = 0; i < op_convert_info.num_outputs(); ++i) {
      auto output_name = MakeOutputName(node_name, i);
      AddOutputWithCast(node_proto, output_name, output_cast_types[static_cast<size_t>(i)], graph_proto);
    }
  }

  // Set inputs
  for (const auto &input_name : input_list) {
    node_proto->add_input(input_name);
  }

  // Set node attributes
  for (const OpAttrInfo &attr : op_convert_info.op_attrs()) {
    const std::string &attr_name = attr.attr_name();
    ValuePtr attr_value = nullptr;
    if (!attr_name.empty()) {
      attr_value = prim->GetAttr(attr_name);
      if (attr_value == nullptr) {
        MS_LOG(EXCEPTION) << "Primitive " << prim->name() << " does not have attribute " << attr_name;
      }
    }
    onnx::AttributeProto *onnx_attr_proto = node_proto->add_attribute();
    onnx_attr_proto->set_name(attr.onnx_attr_name());
    attr.fn_gen_attr()(attr_value, attr.onnx_attr_type(), onnx_attr_proto, prim);
  }
  return node_name;
}

void OnnxExporter::ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node,
                                   std::map<AnfNodePtr, std::string> *node_map_ptr,
                                   onnx::GraphProto *const graph_proto) {
  auto conv_node = dyn_cast<CNode>(node->input(kOneNum));
  auto input_x = conv_node->input(kOneNum);  // conv input x
  auto input_w = conv_node->input(kTwoNum);  // conv weight(filter)
  auto input_b = node->input(kTwoNum);       // conv bias

  PrimitivePtr prim_conv = dyn_cast<Primitive>((dyn_cast<ValueNode>(conv_node->input(kZeroNum)))->value());
  std::vector<AnfNodePtr> inputs{input_x, input_w, input_b};
  (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_conv, inputs, graph_proto);
}

void OnnxExporter::ExportMergeGemm(const FuncGraphPtr &func_graph, const CNodePtr &node,
                                   std::map<AnfNodePtr, std::string> *node_map_ptr,
                                   onnx::GraphProto *const graph_proto) {
  auto matmul_node = dyn_cast<CNode>(node->input(kOneNum));
  auto input_x = matmul_node->input(kOneNum);  // matmul input x
  auto input_y = matmul_node->input(kTwoNum);  // matmul input y
  auto input_b = node->input(kTwoNum);         // matmul bias

  PrimitivePtr prim_matmul = dyn_cast<Primitive>((dyn_cast<ValueNode>(matmul_node->input(kZeroNum)))->value());
  std::vector<AnfNodePtr> inputs{input_x, input_y, input_b};
  (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_matmul, inputs, graph_proto);
}

void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node,
                                        std::map<AnfNodePtr, std::string> *node_map_ptr,
                                        onnx::GraphProto *const graph_proto) {
  auto batch_norm_node = dyn_cast<CNode>(node->input(kOneNum));

  auto is_training = GetOpAttribute<bool>(batch_norm_node, "is_training");
  if (is_training) {
    auto input_x_name = GetNodeInputName(batch_norm_node->input(kOneNum), node_map_ptr, graph_proto);
    auto scale_input_name = GetNodeInputName(batch_norm_node->input(kTwoNum), node_map_ptr, graph_proto);
    auto bias_input_name = GetNodeInputName(batch_norm_node->input(kThreeNum), node_map_ptr, graph_proto);

    auto onnx_type = GetOutputType(batch_norm_node->input(kOneNum));

    auto output_name = RegisterNodeWithUniqueName(node, node_map_ptr);

    auto input_shape_ptr = batch_norm_node->input(kOneNum)->Shape();
    auto input_shape = input_shape_ptr->cast<abstract::ShapePtr>()->shape();

    std::vector<int64_t> normalize_axes = {0};
    for (size_t i = kTwoNum; i < input_shape.size(); ++i) {
      normalize_axes.push_back(static_cast<int64_t>(i));
    }

    std::vector<int64_t> scale_bias_shape(input_shape.size(), 1);
    scale_bias_shape[1] = -1;
    auto reshaped_scale_name = output_name + "_reshaped_scale";
    AddReshapeOp(scale_input_name, reshaped_scale_name, scale_bias_shape, graph_proto);
    auto reshaped_bias_name = output_name + "_reshaped_bias";
    AddReshapeOp(bias_input_name, reshaped_bias_name, scale_bias_shape, graph_proto);
    auto epsilon = GetOpAttribute<float>(batch_norm_node, "epsilon");

    AddMeanVarianceNormalizationOp(input_x_name, reshaped_scale_name, reshaped_bias_name, output_name, normalize_axes,
                                   epsilon, input_shape, onnx_type, graph_proto);
  } else {
    PrimitivePtr prim_batch_norm = GetPrimitive(batch_norm_node);
    std::vector<AnfNodePtr> inputs;
    for (size_t i = 1; i < batch_norm_node->size(); i++) {
      inputs.push_back(batch_norm_node->input(i));
    }
    (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_batch_norm, inputs, graph_proto);
  }
}

void OnnxExporter::ExportMergeMaxPoolWithArgmax(const FuncGraphPtr &func_graph, const CNodePtr &node,
                                                std::map<AnfNodePtr, std::string> *node_map_ptr,
                                                onnx::GraphProto *const graph_proto) {
  auto maxpool_with_argmax_node = dyn_cast<CNode>(node->input(kOneNum));

  PrimitivePtr prim_maxpool_with_argmax =
    dyn_cast<Primitive>((dyn_cast<ValueNode>(maxpool_with_argmax_node->input(kZeroNum)))->value());
  std::vector<AnfNodePtr> inputs;
  for (size_t i = 1; i < maxpool_with_argmax_node->size(); i++) {
    inputs.push_back(maxpool_with_argmax_node->input(i));
  }
  (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_maxpool_with_argmax, inputs, graph_proto);
}

// LayerNorm(N, C1, H, W) --> reshape(1, C2, 1, W) + MeanVarianceNormalization + reshape(N, C1, H, W)
void OnnxExporter::ExportMergeLayerNorm(const FuncGraphPtr &, const CNodePtr &node,
                                        std::map<AnfNodePtr, std::string> *node_map_ptr,
                                        onnx::GraphProto *const graph_proto) {
  auto LayerNormNode = dyn_cast<CNode>(node->input(kOneNum));
  auto layernorm_input_x = GetNodeInputName(LayerNormNode->input(kOneNum), node_map_ptr, graph_proto);
  auto layernorm_input_gamma = GetNodeInputName(LayerNormNode->input(kTwoNum), node_map_ptr, graph_proto);
  auto layernorm_input_beta = GetNodeInputName(LayerNormNode->input(kThreeNum), node_map_ptr, graph_proto);

  auto begin_norm_axis = GetOpAttribute<int64_t>(LayerNormNode, "begin_norm_axis");
  auto begin_params_axis = GetOpAttribute<int64_t>(LayerNormNode, "begin_params_axis");
  if (begin_norm_axis != -1 || begin_params_axis != -1) {
    MS_LOG(EXCEPTION) << "begin_norm_axis and begin_params_axis other than -1 are not implemented";
  }

  auto onnx_type = GetOutputType(LayerNormNode->input(kOneNum));
  auto input_shape = dyn_cast<abstract::Shape>(LayerNormNode->input(kOneNum)->Shape())->shape();
  auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
  auto epsilon = GetOpAttribute<float>(LayerNormNode, "epsilon");
  std::vector<int64_t> reduce_axes = {static_cast<int64_t>(input_shape.size()) - 1};

  AddMeanVarianceNormalizationOp(layernorm_input_x, layernorm_input_gamma, layernorm_input_beta, node_name, reduce_axes,
                                 epsilon, input_shape, onnx_type, graph_proto);
}

void OnnxExporter::ExportMergeConv2DTranspose(const FuncGraphPtr &, const CNodePtr &node,
                                              std::map<AnfNodePtr, std::string> *node_map_ptr,
                                              onnx::GraphProto *const graph_proto) {
  auto conv_node = dyn_cast<CNode>(node->input(kOneNum));
  PrimConv2DTransposeExportHelper(conv_node, node, node_map_ptr, graph_proto);
}

4349 void AddTransposeOp(const std::string &input, const std::string &output, onnx::GraphProto *graph_proto) {
4350 onnx::NodeProto *node_proto = graph_proto->add_node();
4351 std::string op_type = "Transpose";
4352 node_proto->set_op_type(op_type);
4353 node_proto->set_name(output + op_type);
4354 node_proto->add_input(input);
4355 node_proto->add_output(output);
4356 }
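// Note: no "perm" attribute is set, so ONNX Transpose reverses all axes by
// default; e.g. a 2-D (input_size, 3*hidden_size) weight becomes
// (3*hidden_size, input_size), which matches what the GRU export below needs
// for its 2-D weights.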
4357
4358 void AddUnsqueezeOp(const std::string &input, const std::string &output, int64_t axis, onnx::GraphProto *graph_proto) {
4359 onnx::NodeProto *node_proto = graph_proto->add_node();
4360 std::string op_type = "Unsqueeze";
4361 node_proto->set_op_type(op_type);
4362 node_proto->set_name(output + op_type);
4363 node_proto->add_input(input);
4364 node_proto->add_output(output);
4365
4366 onnx::AttributeProto *attr_proto = node_proto->add_attribute();
4367 attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
4368 attr_proto->set_name("axes");
4369 attr_proto->add_ints(axis);
4370 }
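// Sketch of the node this helper emits (opset 11 semantics, where "axes" is an
// attribute; from opset 13 onwards it would have to be passed as an input),
// e.g. for axis = 0:
//   Unsqueeze(input) {axes: [0]} -> output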
4371
4372 void AddSqueezeOp(const std::string &input, const std::string &output, int64_t axis, onnx::GraphProto *graph_proto) {
4373 onnx::NodeProto *node_proto = graph_proto->add_node();
4374 std::string op_type = "Squeeze";
4375 node_proto->set_op_type(op_type);
4376 node_proto->set_name(output + op_type);
4377 node_proto->add_input(input);
4378 node_proto->add_output(output);
4379
4380 onnx::AttributeProto *attr_proto = node_proto->add_attribute();
4381 attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS);
4382 attr_proto->set_name("axes");
4383 attr_proto->add_ints(axis);
4384 }
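// The same opset-11 caveat as AddUnsqueezeOp applies here: "axes" is an
// attribute in opset 11 and becomes an input from opset 13 onwards.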
4385
4386 void AddGRUOp(const std::vector<std::string> &inputs, const std::vector<std::string> &outputs, int64_t hidden_size,
4387 int64_t linear_before_reset, onnx::GraphProto *graph_proto) {
4388 onnx::NodeProto *node_proto = graph_proto->add_node();
4389 std::string op_type = "GRU";
4390 node_proto->set_op_type(op_type);
4391 node_proto->set_name(outputs[0] + op_type);
4392
4393 for (const auto &in : inputs) {
4394 node_proto->add_input(in);
4395 }
4396
4397 for (const auto &out : outputs) {
4398 node_proto->add_output(out);
4399 }
4400
4401 onnx::AttributeProto *attr_proto = node_proto->add_attribute();
4402 attr_proto->set_type(onnx::AttributeProto_AttributeType_INT);
4403 attr_proto->set_name("linear_before_reset");
4404 attr_proto->set_i(linear_before_reset);
4405
4406 onnx::AttributeProto *attr2_proto = node_proto->add_attribute();
4407 attr2_proto->set_type(onnx::AttributeProto_AttributeType_INT);
4408 attr2_proto->set_name("hidden_size");
4409 attr2_proto->set_i(hidden_size);
4410 }
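// Callers must pass inputs in the positional order defined by the ONNX GRU
// spec: X, W, R, B, sequence_lens, initial_h, where an empty string leaves the
// corresponding optional input unset. Illustrative call (names hypothetical):
//   AddGRUOp({"x", "w", "r", "b", "", "h0"}, {"y"}, hidden_size, 0, graph_proto);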
4411
4412 void UnsqueezeInputOfGRU(std::string *in_name, const std::string &node_name, const std::string &suffix, int64_t axis,
4413 onnx::GraphProto *graph_proto) {
4414 auto out_name = node_name + suffix;
4415 AddUnsqueezeOp(*in_name, out_name, axis, graph_proto);
4416 *in_name = out_name;
4417 }
4418
4419 void GruRzh2Zrh(std::string *in_name, const std::string &node_name, const std::string &mid_name,
4420 std::vector<std::string> tmp_out_names, const std::vector<int64_t> &hidden_sizes, int64_t axis,
4421 onnx::GraphProto *graph_proto) {
4422 const int kConcatNum = 6;
4423 const int kIndexBiasHiddenR = 3;
4424 const int kIndexBiasHiddenZ = 4;
4425 auto out_name = node_name + mid_name + "_zrh";
4426
4427 AddSplitOp(*in_name, tmp_out_names, hidden_sizes, 0, graph_proto);
4428 swap(tmp_out_names[0], tmp_out_names[1]);
4429 if (tmp_out_names.size() == kConcatNum) {
4430 swap(tmp_out_names[kIndexBiasHiddenR], tmp_out_names[kIndexBiasHiddenZ]);
4431 }
4432 AddConcatOp(tmp_out_names, out_name, 0, graph_proto);
4433 *in_name = out_name;
4434 }
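// Worked example: with hidden_sizes = {H, H, H}, a gate-major weight laid out
// as [r | z | h] is split into three H-row chunks, the first two chunks are
// swapped, and the chunks are concatenated back as [z | r | h]. For the
// concatenated bias (six chunks: b_i_r, b_i_z, b_i_h, b_h_r, b_h_z, b_h_h)
// the hidden-bias r/z chunks at indices 3 and 4 are swapped as well.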
4435
4436 /*
4437 Mapping between the inputs of MindSpore DynamicGRUV2 and ONNX GRU operator.
4438 +----------------------------------------------------------+----------------------------------------------+
4439 | ONNX | MindSpore |
4440 +==========================================================+==============================================+
4441 | X: [seq_length, batch_size, input_size] | x: (num_step, batch_size, input_size) |
4442 +----------------------------------------------------------+----------------------------------------------+
4443 | W: [num_directions, 3*hidden_size, input_size] | weight_input: (input_size, 3*hidden_size) |
4444 +----------------------------------------------------------+----------------------------------------------+
4445 | R: [num_directions, 3*hidden_size, hidden_size] | weight_hidden: (hidden_size, 3*hidden_size) |
4446 +----------------------------------------------------------+----------------------------------------------+
4447 | | bias_input: (3*hidden_size) |
4448  + B: [num_directions, 6*hidden_size] +----------------------------------------------+
4449 | | bias_hidden: (3*hidden_size) |
4450 +----------------------------------------------------------+----------------------------------------------+
4451  | sequence_lens: [batch_size] | seq_length: (batch_size) |
4452 +----------------------------------------------------------+----------------------------------------------+
4453 | initial_h: [num_directions, batch_size, hidden_size] | init_h: (batch_size, hidden_size) |
4454 +----------------------------------------------------------+----------------------------------------------+
4455 | Y:[seq_length, num_directions, batch_size, hidden_size] | y: (num_step, batch_size, hidden_size) |
4456 +----------------------------------------------------------+----------------------------------------------+
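
  The `num_directions` axis required by ONNX is added below by applying
  Unsqueeze (axis 0) to W, R, B and initial_h, and is removed from Y with a
  final Squeeze (axis 1).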
4457 */
4458 void OnnxExporter::ExportMergeDynamicGRUV2(const FuncGraphPtr &, const CNodePtr &node,
4459 std::map<AnfNodePtr, std::string> *node_map_ptr,
4460 onnx::GraphProto *const graph_proto) {
4461 const int kInX = 1;
4462 const int kInWeightInput = 2;
4463 const int kInWeightHidden = 3;
4464 const int kInBiasInput = 4;
4465 const int kInBiasHidden = 5;
4466   // The 6th input 'seq_length' currently only supports None, so it is not used.
4467 const int kInInitH = 7;
4468
4469 const int kWeightHiddenDim = 2;
4470 const int kNumberOfGates = 3;
4471 const std::string kDefaultDir = "UNIDIRECTIONAL";
4472 const std::string kDefaultAct = "tanh";
4473 const std::vector<std::string> kGateOrderSupported{"rzh", "zrh"};
4474
4475 auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
4476 auto gru_node = dyn_cast<CNode>(node->input(1));
4477 MS_EXCEPTION_IF_NULL(gru_node);
4478
4479 /* Get Attributes */
4480 auto direction = GetOpAttribute<std::string>(gru_node, "direction");
4481 auto activation = GetOpAttribute<std::string>(gru_node, "activation");
4482 auto gate_order = GetOpAttribute<std::string>(gru_node, "gate_order");
4483 auto reset_after = GetOpAttribute<bool>(gru_node, "reset_after");
4484
4485 int64_t linear_before_reset = reset_after ? 1 : 0;
4486
4487 if (direction != kDefaultDir) {
4488 MS_LOG(EXCEPTION) << "'direction': " << direction << " is not in supported values[" << kDefaultDir << "]";
4489 }
4490 if (activation != kDefaultAct) {
4491 MS_LOG(EXCEPTION) << "'activation': " << activation << " is not in supported values[" << kDefaultAct << "]";
4492 }
4493 if (gate_order != kGateOrderSupported[0] && gate_order != kGateOrderSupported[1]) {
4494 std::string supported;
4495     for (const auto &order : kGateOrderSupported) {
4496 supported += order;
4497 supported += ", ";
4498 }
4499 MS_LOG(EXCEPTION) << "'gate_order': " << gate_order << " is not in supported values[" << supported << "]";
4500 }
4501
4502 auto x = GetNodeInputName(gru_node->input(kInX), node_map_ptr, graph_proto);
4503 auto weight_input = GetNodeInputName(gru_node->input(kInWeightInput), node_map_ptr, graph_proto);
4504 auto weight_hidden = GetNodeInputName(gru_node->input(kInWeightHidden), node_map_ptr, graph_proto);
4505 auto bias_input = GetNodeInputName(gru_node->input(kInBiasInput), node_map_ptr, graph_proto);
4506 auto bias_hidden = GetNodeInputName(gru_node->input(kInBiasHidden), node_map_ptr, graph_proto);
4507 auto init_h = GetNodeInputName(gru_node->input(kInInitH), node_map_ptr, graph_proto);
4508
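  // ONNX GRU requires all its tensor inputs to share one dtype, while
  // DynamicGRUV2 allows float16 x/weights alongside float32 biases; so when
  // the biases are float32, cast x and both weights up to float32 to match.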
4509 if (GetOutputType(gru_node->input(kInBiasInput)) == onnx::TensorProto_DataType_FLOAT) {
4510 auto x_cast = x + "_cast";
4511 AddCastOp(x, x_cast, onnx::TensorProto_DataType_FLOAT, graph_proto);
4512 x = x_cast;
4513
4514 auto wi_cast = weight_input + "_cast";
4515 AddCastOp(weight_input, wi_cast, onnx::TensorProto_DataType_FLOAT, graph_proto);
4516 weight_input = wi_cast;
4517
4518 auto wh_cast = weight_hidden + "_cast";
4519 AddCastOp(weight_hidden, wh_cast, onnx::TensorProto_DataType_FLOAT, graph_proto);
4520 weight_hidden = wh_cast;
4521 }
4522
4523 auto weight_hidden_shape = dyn_cast<abstract::Shape>(gru_node->input(kInWeightHidden)->Shape())->shape();
4524 if (weight_hidden_shape.size() != kWeightHiddenDim) {
4525 MS_LOG(EXCEPTION) << "The dim of input weight_hidden must be " << kWeightHiddenDim << ".";
4526 }
4527 int64_t hidden_size = weight_hidden_shape[1] / kNumberOfGates;
4528
4529 auto trans_w_i = node_name + "_trans_w_i";
4530 AddTransposeOp(weight_input, trans_w_i, graph_proto);
4531 weight_input = trans_w_i;
4532
4533 auto trans_w_h = node_name + "_trans_w_h";
4534 AddTransposeOp(weight_hidden, trans_w_h, graph_proto);
4535 weight_hidden = trans_w_h;
4536
4537 auto bias_i_h = node_name + "_bias_i_h";
4538 AddConcatOp({bias_input, bias_hidden}, bias_i_h, 0, graph_proto);
4539
4540   // ONNX GRU only supports the "zrh" gate order
4541 if (gate_order == "rzh") {
4542 MS_LOG(INFO) << "change gate order 'rzh' to 'zrh'.";
4543 std::vector<int64_t> hidden_sizes(kNumberOfGates, hidden_size);
4544 GruRzh2Zrh(&weight_input, node_name, "w_i", {node_name + "_w_i_r", node_name + "_w_i_z", node_name + "_w_i_h"},
4545 hidden_sizes, 0, graph_proto);
4546 GruRzh2Zrh(&weight_hidden, node_name, "w_h", {node_name + "_w_h_r", node_name + "_w_h_z", node_name + "_w_h_h"},
4547 hidden_sizes, 0, graph_proto);
4548
4549 std::vector<int64_t> bias_hidden_sizes(kNumberOfGates + kNumberOfGates, hidden_size);
4550 GruRzh2Zrh(&bias_i_h, node_name, "bias",
4551 {node_name + "_b_i_r", node_name + "_b_i_z", node_name + "_b_i_h", node_name + "_b_h_r",
4552 node_name + "_b_h_z", node_name + "_b_h_h"},
4553 bias_hidden_sizes, 0, graph_proto);
4554 }
4555
4556 std::vector<std::string *> input_names = {&weight_input, &weight_hidden, &bias_i_h, &init_h};
4557 std::vector<std::string> suffixes = {"_unsqueeze_w_i", "_unsqueeze_w_h", "_unsqueeze_bias", "_unsqueeze_init_h"};
4558 for (size_t i = 0; i < input_names.size(); i++) {
4559 UnsqueezeInputOfGRU(input_names[i], node_name, suffixes[i], 0, graph_proto);
4560 }
4561
4562 auto y = node_name + "_Y";
4563 // 'seq_length' input of DynamicGRUV2 is None, so pass "" to ONNX GRU.
4564 std::string sequence_lens = "";
4565 AddGRUOp({x, weight_input, weight_hidden, bias_i_h, sequence_lens, init_h}, {y}, hidden_size, linear_before_reset,
4566 graph_proto);
4567
4568 AddSqueezeOp(y, node_name, 1, graph_proto);
4569 }
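// Overall lowering performed above (a sketch):
//   DynamicGRUV2(x, w_i, w_h, b_i, b_h, None, h0)
//     -> optional Cast of x/w_i/w_h to float32
//     -> Transpose(w_i), Transpose(w_h); Concat(b_i, b_h)
//     -> optional rzh -> zrh gate reorder (Split + Concat)
//     -> Unsqueeze(W/R/B/h0, axis 0) -> GRU -> Squeeze(Y, axis 1)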
4570
4571 /*
4572 Kinds of return values:
4573 1) A single Tensor
4574 2) A Tuple returned by an op with multiple outputs like TopK
4575 3) A Tuple returned by MakeTuple. This corresponds to `return x, y`
4576 or equivalent in Python, where x and y are Tensors
4577 In this case MakeTuple itself is not exported, so this case must be handled
4578 separately from the previous one
4579 4) A constant tuple (ValueNode). Example:
4580 class MyCell(nn.Cell):
4581 def __init__(self):
4582 super().__init__()
4583 self.x = ms.Tensor(np.zeros((1, 2, 3)))
4584
4585 def construct(self):
4586 return self.x, self.x
4587
4588 */
4589 void OnnxExporter::ExportOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &return_arg,
4590 std::map<AnfNodePtr, std::string> *node_map_ptr, onnx::GraphProto *const graph_proto) {
4591 AnfNodePtr arg = GetRealInput(return_arg);
4592 if (IsPrimitiveCNode(arg, prim::kPrimMakeTuple)) {
4593 auto arg_cnode = dyn_cast<CNode>(arg);
4594 for (size_t i = 1; i < arg_cnode->size(); ++i) {
4595 const auto &output = arg_cnode->input(i);
4596 ExportOutput(func_graph, output, node_map_ptr, graph_proto);
4597 }
4598 } else if (arg->isa<ValueNode>() && arg->cast<ValueNodePtr>()->value()->isa<ValueTuple>()) {
4599 // Several outputs, all constants
4600 auto tuple = arg->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>();
4601 for (size_t i = 0; i < tuple->value().size(); ++i) {
4602 const auto &element = tuple->value().at(i);
4603 std::string output_name = GenerateUniqueName();
4604
4605 onnx::TensorProto *initializer = graph_proto->add_initializer();
4606 initializer->set_name(output_name);
4607 SetTensorData(element, initializer);
4608
4609 onnx::ValueInfoProto *output_proto = graph_proto->add_output();
4610 output_proto->set_name(output_name);
4611 SetValueInfoType(arg, output_proto, static_cast<int64_t>(i));
4612 }
4613 } else if (arg->Type()->isa<Tuple>()) {
4614 auto arg_name = GetNodeInputName(arg, node_map_ptr, graph_proto);
4615 auto tuple = dyn_cast<Tuple>(arg->Type());
4616
4617 for (size_t i = 0; i < tuple->size(); ++i) {
4618 auto output_name = MakeOutputName(arg_name, i);
4619 onnx::ValueInfoProto *output_proto = graph_proto->add_output();
4620 output_proto->set_name(output_name);
4621 SetValueInfoType(arg, output_proto, static_cast<int64_t>(i));
4622 }
4623 } else if (arg->Type()->isa<TensorType>()) {
4624 auto arg_name = GetNodeInputName(arg, node_map_ptr, graph_proto);
4625 onnx::ValueInfoProto *output_proto = graph_proto->add_output();
4626 output_proto->set_name(arg_name);
4627 SetValueInfoType(arg, output_proto);
4628 } else {
4629 MS_LOG(EXCEPTION) << "Unsupported network output type " << arg->Type()->ToString() << " in node "
4630 << arg->ToString();
4631 }
4632 }
4633
4634 std::string OnnxExporter::GetNodeInputName(const AnfNodePtr &orig_node, std::map<AnfNodePtr, std::string> *node_map_ptr,
4635 onnx::GraphProto *const) {
4636 auto node = GetRealInput(orig_node);
4637
4638 // if node is renamed and not ignored, use alternative name
4639 // if it is ignored, try to find the actual name in global map
4640 auto renamed_iter = renamed_node_map_.find(node);
4641 if (renamed_iter != renamed_node_map_.end() && renamed_iter->second != "") {
4642 return renamed_iter->second;
4643 }
4644
4645 auto iter = node_map_ptr->find(node);
4646 if (iter != node_map_ptr->end()) {
4647 return iter->second;
4648 }
4649
4650 if (node->isa<CNode>() || (node->isa<Parameter>() && !node->cast<ParameterPtr>()->has_default())) {
4651 MS_LOG(EXCEPTION) << "Can not find node '" << node->DebugString() << "' in node_map";
4652 }
4653
4654 // for ValueNode or Parameter with default input, create an initializer
4655 // same value can be used in several subgraphs, so create initializers in root graph
4656 if (node->isa<ValueNode>()) {
4657 auto node_name = RegisterNodeWithUniqueName(node, node_map_ptr);
4658 auto value = node->cast<ValueNodePtr>()->value();
4659
4660 onnx::TensorProto *initializer_proto = model_.mutable_graph()->add_initializer();
4661 initializer_proto->set_name(node_name);
4662 SetTensorData(value, initializer_proto);
4663
4664 (*node_map_ptr)[node] = node_name;
4665 return node_name;
4666 }
4667
4668 if (node->isa<Parameter>()) {
4669 auto param = dyn_cast<Parameter>(node);
4670 auto node_name = GenerateUniqueParameterName(param, node_map_ptr);
4671
4672 onnx::TensorProto *initializer_proto = model_.mutable_graph()->add_initializer();
4673 initializer_proto->set_name(node_name);
4674 SetTensorData(param->default_param(), initializer_proto);
4675
4676 (*node_map_ptr)[node] = node_name;
4677 return node_name;
4678 }
4679
4680 MS_LOG(EXCEPTION) << "Unexpected node type " << node->type_name();
4681 }
4682
4683 void OnnxExporter::ConvertTupleToTensor(const ValuePtr &value, onnx::TensorProto *const tensor_proto) const {
4684 auto tuple_ptr = dyn_cast<ValueTuple>(value);
4685 MS_EXCEPTION_IF_NULL(tuple_ptr);
4686 if (tuple_ptr->size() == 0) {
4687 MS_LOG(EXCEPTION) << "Convert tuple to tensor fail, the size of converted tuple is 0.";
4688 }
4689
4690 ValuePtr first_element = (*tuple_ptr)[0];
4691 if (!first_element->isa<Scalar>()) { // For non-scalars x->type() contains nullptr
4692 MS_LOG(EXCEPTION) << "Expected tuple elements to be scalars. Got: " << value->ToString();
4693 }
4694 auto type_id = first_element->type()->type_id();
4695 for (size_t i = 1; i < tuple_ptr->size(); ++i) {
4696 const auto element_type = (*tuple_ptr)[i]->type();
4697 if (element_type == nullptr || element_type->type_id() != type_id) {
4698 MS_LOG(EXCEPTION) << "Convert tuple to tensor fail, type of tuple elements is not same.";
4699 }
4700 }
4701
4702 onnx::TensorProto_DataType result_type = onnx::TensorProto_DataType_UNDEFINED;
4703 if (first_element->isa<IntegerImm>()) {
4704 result_type = onnx::TensorProto_DataType_INT64;
4705 } else if (first_element->isa<FloatImm>()) {
4706 result_type = onnx::TensorProto_DataType_FLOAT;
4707 } else {
4708 MS_LOG(EXCEPTION) << "Convert tuple to tensor fail, unexpected tuple element type "
4709 << first_element->type()->type_name() << ".";
4710 }
4711
4712 tensor_proto->add_dims(static_cast<::google::protobuf::int64>(tuple_ptr->size()));
4713 tensor_proto->set_data_type(result_type);
4714 for (size_t i = 0; i < tuple_ptr->size(); ++i) {
4715 ValuePtr elem = (*tuple_ptr)[i];
4716 if (elem->isa<Int8Imm>()) {
4717 tensor_proto->add_int64_data(dyn_cast<Int8Imm>(elem)->value());
4718 } else if (elem->isa<Int16Imm>()) {
4719 tensor_proto->add_int64_data(dyn_cast<Int16Imm>(elem)->value());
4720 } else if (elem->isa<Int32Imm>()) {
4721 tensor_proto->add_int64_data(dyn_cast<Int32Imm>(elem)->value());
4722 } else if (elem->isa<Int64Imm>()) {
4723 tensor_proto->add_int64_data(dyn_cast<Int64Imm>(elem)->value());
4724 } else if (elem->isa<FP32Imm>()) {
4725 tensor_proto->add_float_data(dyn_cast<FP32Imm>(elem)->value());
4726 } else {
4727 MS_LOG(EXCEPTION) << "Convert tuple to tensor fail, unexpected tuple element type " << elem->type()->type_name()
4728 << ".";
4729 }
4730 }
4731 }
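// Example: the tuple (2, 3, 4) becomes a 1-D INT64 tensor with dims = [3] and
// int64_data = {2, 3, 4}; a tuple of FP32Imm values becomes a 1-D FLOAT
// tensor. Mixed-type tuples are rejected above.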
4732
4733 void OnnxExporter::SetTensorData(const ValuePtr &value, onnx::TensorProto *tensor_proto) {
4734 if (value->isa<Int32Imm>()) {
4735 auto attr_value = dyn_cast<Int32Imm>(value)->value();
4736 tensor_proto->set_data_type(onnx::TensorProto_DataType_INT32);
4737 tensor_proto->add_int32_data(attr_value);
4738 } else if (value->isa<Int64Imm>()) {
4739 auto attr_value = dyn_cast<Int64Imm>(value)->value();
4740 tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64);
4741 tensor_proto->add_int64_data(attr_value);
4742 } else if (value->isa<tensor::Tensor>()) {
4743 auto data = dyn_cast<tensor::Tensor>(value);
4744 tensor_proto->set_raw_data(data->data_c(), static_cast<size_t>(data->data().nbytes()));
4745 auto dtype = data->data_type();
4746 auto shape = data->shape_c();
4747
4748 tensor_proto->set_data_type(GetOnnxDataType(dtype));
4749 for (const auto dim : shape) {
4750 tensor_proto->add_dims(dim);
4751 }
4752 } else if (value->isa<ValueTuple>()) { // Note: this is a tuple of primitives, not Tensors
4753 ConvertTupleToTensor(value, tensor_proto);
4754 } else {
4755 MS_LOG(EXCEPTION) << "Need to set value " << value->ToString() << " attribute for Constant node";
4756 }
4757 }
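// Note: for scalar Int32Imm/Int64Imm values no dims are added, so the
// resulting initializer is a rank-0 (scalar) tensor.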
4758
4759 std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) {
4760 OnnxExporter exporter;
4761 return exporter.GetOnnxProtoString(func_graph);
4762 }
4763 } // namespace mindspore
4764