1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/check_convert_utils.h"
18
19 #include <algorithm>
20 #include <functional>
21 #include <iterator>
22 #include <map>
23 #include <set>
24 #include <typeinfo>
25 #include <utility>
26 #include <vector>
27
28 #include "abstract/abstract_value.h"
29 #include "ir/dtype.h"
30 #include "ir/dtype/tensor_type.h"
31 #include "ir/dtype/type.h"
32 #include "ir/scalar.h"
33 #include "ir/tensor.h"
34 #include "ir/value.h"
35 #include "mindapi/base/format.h"
36 #include "mindapi/base/type_id.h"
37 #include "mindapi/base/types.h"
38 #include "ops/op_name.h"
39 #include "utils/convert_utils_base.h"
40 #include "utils/ms_context.h"
41 #include "ops/op_utils.h"
42 #include "ir/kernel_tensor_value.h"
43
44 namespace mindspore {
45 static std::map<std::string, int64_t> DataFormatToEnumMap = {
46 {"NCHW", Format::NCHW}, {"NHWC", Format::NHWC}, {"NHWC4", Format::NHWC4},
47 {"HWKC", Format::HWKC}, {"HWCK", Format::HWCK}, {"KCHW", Format::KCHW},
48 {"CKHW", Format::CKHW}, {"KHWC", Format::KHWC}, {"CHWK", Format::CHWK},
49 {"HW", Format::HW}, {"HW4", Format::HW4}, {"NC", Format::NC},
50 {"NC4", Format::NC4}, {"NC4HW4", Format::NC4HW4}, {"NUM_OF_FORMAT", Format::NUM_OF_FORMAT},
51 {"NCDHW", Format::NCDHW}, {"NWC", Format::NWC}, {"NCW", Format::NCW},
52 };
53
54 static std::map<int64_t, std::string> DataFormatToStrMap = {
55 {Format::NCHW, "NCHW"}, {Format::NHWC, "NHWC"}, {Format::NHWC4, "NHWC4"},
56 {Format::HWKC, "HWKC"}, {Format::HWCK, "HWCK"}, {Format::KCHW, "KCHW"},
57 {Format::CKHW, "CKHW"}, {Format::KHWC, "KHWC"}, {Format::CHWK, "CHWK"},
58 {Format::HW, "HW"}, {Format::HW4, "HW4"}, {Format::NC, "NC"},
59 {Format::NC4, "NC4"}, {Format::NC4HW4, "NC4HW4"}, {Format::NUM_OF_FORMAT, "NUM_OF_FORMAT"},
60 {Format::NCDHW, "NCDHW"}, {Format::NWC, "NWC"}, {Format::NCW, "NCW"},
61 };
62
63 static std::map<std::string, int64_t> ReductionToEnumMap = {
64 {"sum", Reduction::REDUCTION_SUM},
65 {"mean", Reduction::MEAN},
66 {"none", Reduction::NONE},
67 };
68
69 static std::map<int64_t, std::string> ReductionToStrMap = {
70 {Reduction::REDUCTION_SUM, "sum"},
71 {Reduction::MEAN, "mean"},
72 {Reduction::NONE, "none"},
73 };
74
75 static std::map<std::string, int64_t> PadModToEnumMap = {
76 {"pad", PadMode::PAD},
77 {"same", PadMode::SAME},
78 {"valid", PadMode::VALID},
79 };
80
81 static std::map<int64_t, std::string> PadModToStrMap = {
82 {PadMode::PAD, "pad"},
83 {PadMode::SAME, "same"},
84 {PadMode::VALID, "valid"},
85 };
86
87 static std::map<std::string, int64_t> PadModToEnumUpperMap = {
88 {"PAD", PadMode::PAD},
89 {"SAME", PadMode::SAME},
90 {"VALID", PadMode::VALID},
91 // this should be removed, cause some op change "PAD" to "CALCULATED" in python.
92 {"CALCULATED", PadMode::PAD},
93 };
94
95 static std::map<int64_t, std::string> PadModToStrUpperMap = {
96 {PadMode::PAD, "PAD"},
97 {PadMode::SAME, "SAME"},
98 {PadMode::VALID, "VALID"},
99 };
100
101 AttrConverterPair DataFormatConverter(DataFormatToEnumMap, DataFormatToStrMap);
102 AttrConverterPair PadModeConverter(PadModToEnumMap, PadModToStrMap);
103 AttrConverterPair PadModeUpperConverter(PadModToEnumUpperMap, PadModToStrUpperMap);
104 AttrConverterPair ReductionConverter(ReductionToEnumMap, ReductionToStrMap);
105
106 static std::map<std::string, AttrConverterPair> FormatAndPadAttrMap = {
107 {ops::kFormat, DataFormatConverter},
108 {ops::kPadMode, PadModeConverter},
109 };
110
111 static std::map<std::string, AttrConverterPair> FormatAndPadUpperAttrMap = {
112 {ops::kFormat, DataFormatConverter},
113 {ops::kPadMode, PadModeUpperConverter},
114 };
115
116 static std::map<std::string, AttrConverterPair> DataFormatMap = {
117 {ops::kFormat, DataFormatConverter},
118 };
119
120 static std::map<std::string, AttrConverterPair> FormatAndDataFormatMap = {
121 {ops::kFormat, DataFormatConverter},
122 {ops::kDataFormat, DataFormatConverter},
123 };
124
125 static std::map<std::string, AttrConverterPair> ReductionMap = {
126 {ops::kReduction, ReductionConverter},
127 };
128
129 static std::map<std::string, std::map<std::string, AttrConverterPair>> PrimAttrConvertMap = {
130 {"Conv2D", FormatAndPadAttrMap},
131 {"Conv2DTranspose", FormatAndPadUpperAttrMap},
132 {"Conv2DBackpropInput", FormatAndPadUpperAttrMap},
133 {"Conv2DBackpropFilter", FormatAndPadUpperAttrMap},
134 {"Conv3D", FormatAndPadAttrMap},
135 {"Conv3DBackpropInput", FormatAndPadAttrMap},
136 {"Conv3DBackpropFilter", FormatAndPadAttrMap},
137 {"Conv3DTranspose", DataFormatMap},
138 {"DepthwiseConv2dNative", FormatAndPadAttrMap},
139 {"DepthwiseConv2dNativeBackpropInput", FormatAndPadAttrMap},
140 {"DepthwiseConv2dNativeBackpropFilter", FormatAndPadAttrMap},
141 {"AvgPool", FormatAndPadUpperAttrMap},
142 {"MaxPoolV1", FormatAndPadUpperAttrMap},
143 {"MaxPool", FormatAndPadUpperAttrMap},
144 {"MaxPoolWithArgmax", FormatAndPadUpperAttrMap},
145 {"AvgPoolGrad", FormatAndPadUpperAttrMap},
146 {"AvgPoolGradVm", FormatAndPadUpperAttrMap},
147 {"AvgPoolGradGpu", FormatAndPadUpperAttrMap},
148 {"AvgPoolGradCpu", FormatAndPadUpperAttrMap},
149 {"AvgPoolV1", FormatAndPadUpperAttrMap},
150 {"AvgPoolGradV1", FormatAndPadUpperAttrMap},
151 {"MaxPoolGrad", FormatAndPadUpperAttrMap},
152 {"MaxPoolGradV1", FormatAndPadUpperAttrMap},
153 {"MaxPoolGradGrad", FormatAndPadUpperAttrMap},
154 {"MaxPoolGradWithArgmax", FormatAndPadUpperAttrMap},
155 {"MaxPoolGradGradWithArgmax", FormatAndPadUpperAttrMap},
156 {"BatchNorm", DataFormatMap},
157 {"BatchNormGrad", DataFormatMap},
158 {"BiasAdd", DataFormatMap},
159 {"BiasAddGrad", DataFormatMap},
160 {"BinaryCrossEntropy", ReductionMap},
161 {"BinaryCrossEntropyGrad", ReductionMap},
162 {"NLLLoss", ReductionMap},
163 {"NLLLossGrad", ReductionMap},
164 {"DepthToSpace", FormatAndDataFormatMap},
165 {"SpaceToDepth", FormatAndDataFormatMap},
166 {"Pooling", DataFormatMap},
167 {"Deconvolution", DataFormatMap},
168 {"AvgPoolV2", DataFormatMap},
169 {"MaxPoolV3", DataFormatMap},
170 {"FusedBatchNorm", DataFormatMap},
171 {"DeformableConv2d", DataFormatMap}};
172
CheckPrimAttrConverted(const std::string & op_name)173 bool CheckAndConvertUtils::CheckPrimAttrConverted(const std::string &op_name) {
174 return PrimAttrConvertMap.find(op_name) != PrimAttrConvertMap.end();
175 }
176
GetDataFormatEnumValue(const ValuePtr & value,int64_t * enum_value)177 bool CheckAndConvertUtils::GetDataFormatEnumValue(const ValuePtr &value, int64_t *enum_value) {
178 MS_EXCEPTION_IF_NULL(value);
179 if (!value->isa<StringImm>()) {
180 *enum_value = GetValue<int64_t>(value);
181 return true;
182 }
183 auto attr_value_str = GetValue<std::string>(value);
184 auto iter = DataFormatToEnumMap.find(attr_value_str);
185 if (iter == DataFormatToEnumMap.end()) {
186 MS_LOG(DEBUG) << "The data format " << attr_value_str << " not be converted to enum.";
187 return false;
188 }
189 *enum_value = iter->second;
190 return true;
191 }
192
GetPadModEnumValue(const ValuePtr & value,int64_t * enum_value,bool is_upper)193 void CheckAndConvertUtils::GetPadModEnumValue(const ValuePtr &value, int64_t *enum_value, bool is_upper) {
194 MS_EXCEPTION_IF_NULL(value);
195 if (!value->isa<StringImm>()) {
196 *enum_value = GetValue<int64_t>(value);
197 return;
198 }
199 auto attr_value_str = GetValue<std::string>(value);
200
201 if (is_upper) {
202 auto iter = PadModToEnumUpperMap.find(attr_value_str);
203 if (iter == PadModToEnumUpperMap.end()) {
204 MS_LOG(EXCEPTION) << "Invalid pad mode " << attr_value_str << " use pad, valid or same";
205 }
206 *enum_value = iter->second;
207 return;
208 }
209 auto iter = PadModToEnumMap.find(attr_value_str);
210 if (iter == PadModToEnumMap.end()) {
211 MS_LOG(EXCEPTION) << "Invalid pad mode " << attr_value_str << " use pad, valid or same";
212 }
213 *enum_value = iter->second;
214 }
215
GetReductionEnumValue(const ValuePtr & value,int64_t * enum_value)216 void CheckAndConvertUtils::GetReductionEnumValue(const ValuePtr &value, int64_t *enum_value) {
217 MS_EXCEPTION_IF_NULL(value);
218 if (!value->isa<StringImm>()) {
219 *enum_value = GetValue<int64_t>(value);
220 return;
221 }
222 auto attr_value_str = GetValue<std::string>(value);
223 auto iter = ReductionToEnumMap.find(attr_value_str);
224 if (iter == ReductionToEnumMap.end()) {
225 MS_LOG(EXCEPTION) << "Invalid pad mode " << attr_value_str << " use pad, valid or same";
226 }
227 *enum_value = iter->second;
228 }
229
GetAttrConvertPair(const std::string & op_type,const std::string & attr_name)230 AttrConverterPair CheckAndConvertUtils::GetAttrConvertPair(const std::string &op_type, const std::string &attr_name) {
231 AttrConverterPair attr_pair;
232 if (op_type.empty() || attr_name.empty()) {
233 return attr_pair;
234 }
235 auto op_attr_map_it = PrimAttrConvertMap.find(op_type);
236 if (op_attr_map_it == PrimAttrConvertMap.end()) {
237 return attr_pair;
238 }
239 auto attr_pair_it = op_attr_map_it->second.find(attr_name);
240 if (attr_pair_it == op_attr_map_it->second.end()) {
241 return attr_pair;
242 }
243
244 return attr_pair_it->second;
245 }
246
ConvertAttrValueToInt(const std::string & op_type,const std::string & attr_name,ValuePtr * const value)247 bool CheckAndConvertUtils::ConvertAttrValueToInt(const std::string &op_type, const std::string &attr_name,
248 ValuePtr *const value) {
249 if (value == nullptr || *value == nullptr) {
250 MS_LOG(DEBUG) << "value of attr " << op_type << attr_name << " is nullptr.";
251 return false;
252 }
253 if (!(*value)->isa<StringImm>()) {
254 return false;
255 }
256 auto attr_map_pair = GetAttrConvertPair(op_type, attr_name);
257 if (attr_map_pair.first.empty()) {
258 return false;
259 }
260
261 std::string real_value = std::dynamic_pointer_cast<StringImm>(*value)->value();
262 bool do_convert = false;
263 if (attr_map_pair.first.find(real_value) != attr_map_pair.first.end()) {
264 do_convert = true;
265 }
266 if (!do_convert) {
267 transform(real_value.begin(), real_value.end(), real_value.begin(), ::toupper);
268 if (attr_map_pair.first.find(real_value) != attr_map_pair.first.end()) {
269 do_convert = true;
270 }
271 }
272 if (!do_convert) {
273 transform(real_value.begin(), real_value.end(), real_value.begin(), ::tolower);
274 if (attr_map_pair.first.find(real_value) == attr_map_pair.first.end()) {
275 MS_LOG(DEBUG) << "Can not convert " << op_type << " attr " << attr_name << ": " << real_value << " to int";
276 return false;
277 }
278 }
279 *value = MakeValue<int64_t>(attr_map_pair.first[real_value]);
280 MS_LOG(DEBUG) << "convert str to int, name: " << op_type << ", attr: " << attr_name;
281 return true;
282 }
283
ConvertAttrValueToString(const std::string & op_type,const std::string & attr_name,ValuePtr * const value)284 bool CheckAndConvertUtils::ConvertAttrValueToString(const std::string &op_type, const std::string &attr_name,
285 ValuePtr *const value) {
286 if (value == nullptr || *value == nullptr) {
287 MS_LOG(DEBUG) << "value is nullptr! op_type = " << op_type << ", attr_name = " << attr_name;
288 return false;
289 }
290 if (!(*value)->isa<Int64Imm>()) {
291 return false;
292 }
293 auto attr_map_pair = GetAttrConvertPair(op_type, attr_name);
294 if (attr_map_pair.second.empty()) {
295 return false;
296 }
297
298 int64_t real_value = std::dynamic_pointer_cast<Int64Imm>(*value)->value();
299 if (attr_map_pair.second.find(real_value) == attr_map_pair.second.end()) {
300 MS_LOG(DEBUG) << "Can not convert " << op_type << " attr " << attr_name << ": " << real_value << " to string";
301 return false;
302 }
303 *value = MakeValue<std::string>(attr_map_pair.second[real_value]);
304 MS_LOG(DEBUG) << "convert int to str, name: " << op_type << ", attr: " << attr_name;
305 return true;
306 }
307
GetFormatStringVal(const PrimitivePtr & prim,std::string * format)308 void CheckAndConvertUtils::GetFormatStringVal(const PrimitivePtr &prim, std::string *format) {
309 if (prim == nullptr || format == nullptr) {
310 MS_LOG(DEBUG) << "Prim or format is nullptr.";
311 return;
312 }
313 auto value_ptr = prim->GetAttr(ops::kFormat);
314 if (value_ptr == nullptr) {
315 MS_LOG(DEBUG) << "Val is nullptr! op type = " << prim->name();
316 return;
317 }
318 int64_t data_format;
319 bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value_ptr, &data_format);
320 if (result) {
321 if (DataFormatToStrMap.find(data_format) != DataFormatToStrMap.end()) {
322 *format = DataFormatToStrMap.at(data_format);
323 }
324 }
325 }
326
CheckAbstractShapeSame(const std::vector<AbstractBasePtr> & abs_list)327 size_t CheckAndConvertUtils::CheckAbstractShapeSame(const std::vector<AbstractBasePtr> &abs_list) {
328 if (abs_list.size() <= 1) {
329 return 0;
330 }
331 const auto &first_elem_abs = abs_list[0];
332 MS_EXCEPTION_IF_NULL(first_elem_abs);
333 auto abs1_shape = first_elem_abs->GetShape();
334 MS_EXCEPTION_IF_NULL(abs1_shape);
335 for (size_t i = 0; i < abs_list.size(); ++i) {
336 MS_EXCEPTION_IF_NULL(abs_list[i]);
337 auto abs2_shape = abs_list[i]->GetShape();
338 MS_EXCEPTION_IF_NULL(abs2_shape);
339 if (*abs1_shape != *abs2_shape) {
340 MS_LOG(ERROR) << "Abstract shapes are not same, shape1:" << abs1_shape->ToString()
341 << ", shape2:" << abs2_shape->ToString();
342 return i;
343 }
344 }
345 return 0;
346 }
347
// For example,
// TensorType(element type is float16) and TensorType(element type is int32) are not the same,
// while TupleType(elements num is 3) and TupleType(elements num is 4) are the same.
CheckAbstractTypeSame(const std::vector<AbstractBasePtr> & abs_list)351 size_t CheckAndConvertUtils::CheckAbstractTypeSame(const std::vector<AbstractBasePtr> &abs_list) {
352 if (abs_list.size() <= 1) {
353 return 0;
354 }
355 const auto &first_elem_abs = abs_list[0];
356 MS_EXCEPTION_IF_NULL(first_elem_abs);
357 auto abs1_type = first_elem_abs->BuildType();
358 MS_EXCEPTION_IF_NULL(abs1_type);
359 for (size_t i = 1; i < abs_list.size(); ++i) {
360 MS_EXCEPTION_IF_NULL(abs_list[i]);
361 auto abs2_type = abs_list[i]->BuildType();
362 MS_EXCEPTION_IF_NULL(abs2_type);
363 if (!(*abs1_type == *abs2_type)) {
364 MS_LOG(ERROR) << "Abstract types are not same, type1:" << abs1_type->ToString()
365 << ", type2:" << abs2_type->ToString();
366 return i;
367 }
368 }
369 return 0;
370 }
371
CheckAttrInt64Positive(const std::string & op,const ValuePtr & attr,const std::string & attr_name)372 int64_t CheckAndConvertUtils::CheckAttrInt64Positive(const std::string &op, const ValuePtr &attr,
373 const std::string &attr_name) {
374 MS_EXCEPTION_IF_NULL(attr);
375 int64_t attr_val = attr->cast<Int64ImmPtr>()->value();
376 if (attr_val <= 0) {
377 MS_EXCEPTION(ValueError) << "For '" << op << "', the '" << attr_name
378 << "' should be greater than 0, but got: " << attr_val << ".";
379 }
380 return attr_val;
381 }
382
CheckAbstractTypeAndShapeSame(const std::vector<AbstractBasePtr> & abs_list,const std::string & precondition_log,const std::string & standard_abs_description,const std::string & differ_abs_description)383 void CheckAndConvertUtils::CheckAbstractTypeAndShapeSame(const std::vector<AbstractBasePtr> &abs_list,
384 const std::string &precondition_log,
385 const std::string &standard_abs_description,
386 const std::string &differ_abs_description) {
387 if (abs_list.size() <= 1) {
388 return;
389 }
390 auto differ_index = CheckAndConvertUtils::CheckAbstractTypeSame(abs_list);
391 if (differ_index == 0) {
392 differ_index = CheckAndConvertUtils::CheckAbstractShapeSame(abs_list);
393 }
394 if (differ_index != 0) {
395 auto log_info1 = standard_abs_description.empty() ? "sequence[0] item" : standard_abs_description;
396 auto log_info2 =
397 differ_abs_description.empty() ? "sequence[" + std::to_string(differ_index) + "] item" : differ_abs_description;
398 MS_EXCEPTION(TypeError) << precondition_log << ", the " << log_info1 << " abstract '" << abs_list[0]->ToString()
399 << "' is not same with the " << log_info2 << " abstract '"
400 << abs_list[differ_index]->ToString() << "'.";
401 }
402 }
403
CheckElementAbstractUnSupport(const AbstractBasePtr abs)404 bool CheckElementAbstractUnSupport(const AbstractBasePtr abs) {
405 if (abs == nullptr) {
406 return false;
407 }
408 if (abs->isa<abstract::AbstractSequence>() && !abs->isa<abstract::AbstractSparseTensor>()) {
409 abstract::AbstractSequencePtr seq = abs->cast<abstract::AbstractSequencePtr>();
410 if (seq->dynamic_len()) {
411 auto elem = seq->dynamic_len_element_abs();
412 if (elem->BuildType()->type_id() == kNumberTypeInt64) {
413 return false;
414 }
415 } else {
416 auto elements = seq->elements();
417 if (elements.empty() || std::all_of(elements.cbegin(), elements.cend(), [](const AbstractBasePtr &elem) {
418 return elem->BuildType()->type_id() == kNumberTypeInt64;
419 })) {
420 return false;
421 }
422 }
423 return true;
424 }
425 if (abs->isa<abstract::AbstractDictionary>()) {
426 return true;
427 }
428 if (abs->isa<abstract::AbstractAny>()) {
429 return true;
430 }
431 auto abs_type = abs->BuildType();
432 if (abs_type != nullptr && abs_type->isa<External>()) {
433 return true;
434 }
435 return false;
436 }
437
// Returns true when some abstract in `abs_list` cannot be handled as a regular sequence input. That is:
//  - it is a dictionary, or
//  - it is a sequence (sparse tensors excluded) where either:
//    * dynamic length: its dynamic_len_element_abs() is unsupported per CheckElementAbstractUnSupport;
//    * fixed length: its first element is unsupported per CheckElementAbstractUnSupport, or a later element
//      differs from the first in generic type id or shape, or its Join with the first element throws.
// Null entries, non-sequence entries, sparse tensors and empty sequences are skipped; otherwise returns false.
bool CheckAndConvertUtils::CheckContainNestedOrIrregularSequence(const std::vector<AbstractBasePtr> &abs_list) {
  // Check input abs has nested sequence, or irregular sequence,
  // such as sequence contains elements with different shape or type.
  for (auto abs : abs_list) {
    if (abs == nullptr) {
      continue;
    }
    if (abs->isa<abstract::AbstractDictionary>()) {
      return true;
    }
    // Only real sequences are inspected further.
    if (!abs->isa<abstract::AbstractSequence>() || abs->isa<abstract::AbstractSparseTensor>()) {
      continue;
    }
    auto abs_seq = abs->cast<abstract::AbstractSequencePtr>();
    if (abs_seq->dynamic_len()) {
      // Dynamic-length sequences are checked only through their single element abstract.
      if (CheckElementAbstractUnSupport(abs_seq->dynamic_len_element_abs())) {
        return true;
      }
      continue;
    }
    const auto &elements = abs_seq->elements();
    if (elements.size() == 0) {
      continue;
    }
    // The first element is the reference every other element is compared against.
    auto first_element = elements[0];
    MS_EXCEPTION_IF_NULL(first_element);
    if (CheckElementAbstractUnSupport(first_element)) {
      return true;
    }
    auto first_element_shape = first_element->GetShape();
    MS_EXCEPTION_IF_NULL(first_element_shape);
    auto first_element_type = first_element->BuildType();
    MS_EXCEPTION_IF_NULL(first_element_type);
    auto first_element_type_id = first_element_type->generic_type_id();
    for (size_t i = 1; i < elements.size(); ++i) {
      auto cur_element = elements[i];
      MS_EXCEPTION_IF_NULL(cur_element);
      auto cur_element_type = cur_element->BuildType();
      MS_EXCEPTION_IF_NULL(cur_element_type);
      auto cur_element_type_id = cur_element_type->generic_type_id();
      if (first_element_type_id != cur_element_type_id) {
        return true;
      }
      auto cur_element_shape = cur_element->GetShape();
      MS_EXCEPTION_IF_NULL(cur_element_shape);
      if (*first_element_shape != *cur_element_shape) {
        return true;
      }
      try {
        // cppcheck-suppress unreadVariable
        MS_LOG_TRY_CATCH_SCOPE;
        // A Join that throws a std::exception marks the sequence as irregular; the joined result is discarded.
        (void)first_element->Join(cur_element);
      } catch (std::exception &) {
        return true;
      }
    }
  }
  return false;
}
497
// Returns a new tuple/list (same kind and sequence_nodes as `sequence`) whose elements are broadened.
// Nested sequences are processed recursively. Other elements are cloned and broadened; scalar clones are
// first marked as variable.
abstract::AbstractSequencePtr CheckAndConvertUtils::BroadenAllSequenceElements(
  const abstract::AbstractSequencePtr &sequence) {
  MS_EXCEPTION_IF_NULL(sequence);
  const auto &elements = sequence->elements();
  AbstractBasePtrList new_elements;
  new_elements.reserve(elements.size());
  // Iterate by const reference to avoid an atomic shared_ptr refcount bump per element.
  for (const auto &element : elements) {
    MS_EXCEPTION_IF_NULL(element);
    AbstractBasePtr new_element = nullptr;
    if (element->isa<abstract::AbstractSequence>()) {
      new_element = BroadenAllSequenceElements(element->cast<abstract::AbstractSequencePtr>());
    } else {
      auto tmp_element = element->Clone();
      MS_EXCEPTION_IF_NULL(tmp_element);
      if (element->isa<abstract::AbstractScalar>()) {
        tmp_element->cast<abstract::AbstractScalarPtr>()->set_is_variable(true);
      }
      new_element = tmp_element->Broaden();
    }
    new_elements.emplace_back(std::move(new_element));
  }
  if (sequence->isa<abstract::AbstractList>()) {
    return std::make_shared<abstract::AbstractList>(new_elements, sequence->sequence_nodes());
  }
  return std::make_shared<abstract::AbstractTuple>(new_elements, sequence->sequence_nodes());
}
521
CheckValueSame(const ValuePtr & value_1,const ValuePtr & value_2)522 bool CheckAndConvertUtils::CheckValueSame(const ValuePtr &value_1, const ValuePtr &value_2) {
523 MS_EXCEPTION_IF_NULL(value_1);
524 MS_EXCEPTION_IF_NULL(value_2);
525 if (!value_1->IsSameTypeId(value_2->tid())) {
526 return false;
527 }
528 if (value_1->isa<tensor::BaseTensor>()) {
529 auto list_tensor_value = value_2->cast_ptr<tensor::BaseTensor>();
530 return value_1->cast_ptr<tensor::BaseTensor>()->ValueEqual(*list_tensor_value);
531 }
532 return *value_1 == *value_2;
533 }
534
ConvertAttrValueInExport(const std::string & op_type,const std::string & attr_name,ValuePtr * const value)535 void CheckAndConvertUtils::ConvertAttrValueInExport(const std::string &op_type, const std::string &attr_name,
536 ValuePtr *const value) {
537 if (value == nullptr || *value == nullptr) {
538 MS_LOG(DEBUG) << "value is nullptr! op_type = " << op_type << ", attr_name = " << attr_name;
539 return;
540 }
541 // convert enum to string
542 ConvertAttrValueToString(op_type, attr_name, value);
543 }
544
ConvertAttrValueInLoad(const std::string & op_type,const std::string & attr_name,ValuePtr * const value)545 void CheckAndConvertUtils::ConvertAttrValueInLoad(const std::string &op_type, const std::string &attr_name,
546 ValuePtr *const value) {
547 if (value == nullptr || *value == nullptr) {
548 MS_LOG(DEBUG) << "value is nullptr! op_type = " << op_type << ", attr_name = " << attr_name;
549 return;
550 }
551 // convert string to enum
552 ConvertAttrValueToInt(op_type, attr_name, value);
553 }
554
555 namespace {
556 typedef std::map<std::string, std::function<ValuePtr(ValuePtr)>> AttrFunction;
557
L2NormalizeAttrConversion(ValuePtr attr)558 ValuePtr L2NormalizeAttrConversion(ValuePtr attr) {
559 if (attr->isa<Int64Imm>()) {
560 return attr;
561 }
562 auto attr_value = GetValue<std::vector<int64_t>>(attr);
563 return MakeValue(attr_value[0]);
564 }
565
566 std::map<std::string, AttrFunction> kIrAttrToOpAttr = {{"L2Normalize", {{"axis", L2NormalizeAttrConversion}}},
567 {"L2NormalizeGrad", {{"axis", L2NormalizeAttrConversion}}}};
CheckType(const TypePtr & check_type,const std::set<TypePtr> & template_types)568 inline bool CheckType(const TypePtr &check_type, const std::set<TypePtr> &template_types) {
569 return std::any_of(template_types.begin(), template_types.end(), [&check_type](const TypePtr &accept) -> bool {
570 return IsIdentidityOrSubclass(check_type, accept);
571 });
572 }
573 } // namespace
574
CheckString(const std::string & arg_name,const std::string & arg_value,const std::set<std::string> & check_list,const std::string & prim_name)575 std::string CheckAndConvertUtils::CheckString(const std::string &arg_name, const std::string &arg_value,
576 const std::set<std::string> &check_list, const std::string &prim_name) {
577 if (check_list.find(arg_value) != check_list.end()) {
578 return arg_value;
579 }
580 std::ostringstream buffer;
581 buffer << "For primitive[" << prim_name << "], the attribute[" << arg_name << "]";
582 if (check_list.size() == 1) {
583 buffer << " must be \"" << (*check_list.begin()) << "\", but got \"" << arg_value << "\".";
584 MS_EXCEPTION(ValueError) << buffer.str();
585 }
586 buffer << " should be a element of {";
587 for (const auto &item : check_list) {
588 buffer << "\"" << item << "\", ";
589 }
590 buffer << "}"
591 << ",but got \"" << arg_value << "\""
592 << ".";
593 MS_EXCEPTION(ValueError) << buffer.str();
594 }
595
CheckInteger(const std::string & arg_name,int64_t arg_value,CompareEnum compare_operator,int64_t match_value,const std::string & prim_name)596 int64_t CheckAndConvertUtils::CheckInteger(const std::string &arg_name, int64_t arg_value, CompareEnum compare_operator,
597 int64_t match_value, const std::string &prim_name) {
598 auto iter = kCompareMap<float>.find(compare_operator);
599 if (iter == kCompareMap<float>.end()) {
600 MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator << " cannot find in the compare map";
601 }
602 if (iter->second(arg_value, match_value)) {
603 return arg_value;
604 }
605 std::ostringstream buffer;
606 if (prim_name.empty()) {
607 buffer << "The argument[" << arg_name << "] must ";
608 } else {
609 buffer << "For primitive[" << prim_name << "], the " << arg_name << " must ";
610 }
611 auto iter_to_string = kCompareToString.find(compare_operator);
612 if (iter_to_string == kCompareToString.end()) {
613 MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator << " cannot find in the compare string map";
614 }
615 buffer << iter_to_string->second << match_value << ", but got " << arg_value << ".";
616 MS_EXCEPTION(ValueError) << buffer.str();
617 }
618
FormatCheckMsg(const std::string & arg_name,const std::vector<int64_t> & arg_value,CompareEnum compare_type,const std::vector<int64_t> & value,const PrimitivePtr & prim)619 std::string CheckAndConvertUtils::FormatCheckMsg(const std::string &arg_name, const std::vector<int64_t> &arg_value,
620 CompareEnum compare_type, const std::vector<int64_t> &value,
621 const PrimitivePtr &prim) {
622 std::ostringstream buffer;
623 if (prim == nullptr) {
624 buffer << "The attribute[" << arg_name << "]:";
625 } else {
626 buffer << "For primitive[" << prim->name() << "], the " << arg_name << ":";
627 }
628 auto iter_to_string = kCompareToString.find(compare_type);
629 if (iter_to_string == kCompareToString.end()) {
630 MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_type << " cannot find in the compare string map";
631 }
632
633 buffer << " [";
634 for (auto item : arg_value) {
635 buffer << item << ",";
636 }
637 buffer << "]";
638 buffer << " must " << iter_to_string->second << "[";
639 for (auto item : value) {
640 buffer << item << ",";
641 }
642 buffer << "]";
643 return buffer.str();
644 }
645
CheckInputArgs(const std::vector<AbstractBasePtr> & input_args,const CompareEnum compare_operator,const int64_t match_value,const std::string & prim_name)646 void CheckAndConvertUtils::CheckInputArgs(const std::vector<AbstractBasePtr> &input_args,
647 const CompareEnum compare_operator, const int64_t match_value,
648 const std::string &prim_name) {
649 (void)CheckInteger("input number", SizeToLong(input_args.size()), compare_operator, match_value, prim_name);
650 for (size_t index = 0; index < input_args.size(); index++) {
651 if (input_args[index] == nullptr) {
652 MS_EXCEPTION(ValueError) << "The " << index << "'s input of " << prim_name << " is nullptr.";
653 }
654 }
655 }
656
ConvertShapePtrToShapeMap(const BaseShapePtr & shape)657 ShapeMap CheckAndConvertUtils::ConvertShapePtrToShapeMap(const BaseShapePtr &shape) {
658 MS_EXCEPTION_IF_NULL(shape);
659 if (!shape->isa<abstract::Shape>()) {
660 return std::map<std::string, std::vector<int64_t>>();
661 }
662 auto shape_element = shape->cast<abstract::ShapePtr>();
663 MS_EXCEPTION_IF_NULL(shape_element);
664 ShapeMap shape_map;
665 shape_map[kShape] = shape_element->shape();
666 shape_map[kMaxShape] = shape_element->max_shape();
667 return shape_map;
668 }
669
GetTensorInputShape(const std::string & prim_name,const std::vector<AbstractBasePtr> & input_args,size_t index)670 abstract::ShapePtr CheckAndConvertUtils::GetTensorInputShape(const std::string &prim_name,
671 const std::vector<AbstractBasePtr> &input_args,
672 size_t index) {
673 auto abstract = CheckAndConvertUtils::CheckArgsType(prim_name, input_args, index, kObjectTypeTensorType);
674 MS_EXCEPTION_IF_NULL(abstract);
675 auto base_shape = abstract->GetShape();
676 MS_EXCEPTION_IF_NULL(base_shape);
677 if (!base_shape->isa<abstract::TensorShape>()) {
678 MS_LOG(EXCEPTION) << prim_name << " can not get shape for input " << index;
679 }
680 auto shape = base_shape->cast<abstract::TensorShapePtr>();
681 MS_EXCEPTION_IF_NULL(shape);
682 return shape;
683 }
684
GetTensorInputType(const std::string & prim_name,const std::vector<AbstractBasePtr> & input_args,size_t index)685 TypePtr CheckAndConvertUtils::GetTensorInputType(const std::string &prim_name,
686 const std::vector<AbstractBasePtr> &input_args, size_t index) {
687 if (input_args.size() <= index) {
688 MS_EXCEPTION(ValueError) << "For " << prim_name << ", the index " << index << " is out of the input number "
689 << input_args.size();
690 }
691 auto input_arg = input_args[index];
692 if (input_arg == nullptr) {
693 MS_EXCEPTION(ValueError) << "The " << index << "'s input of " << prim_name << " is nullptr.";
694 }
695 auto base_type = input_arg->GetType();
696 MS_EXCEPTION_IF_NULL(base_type);
697 if (!base_type->isa<TensorType>()) {
698 MS_EXCEPTION(TypeError) << "The " << index << "'s input type of " << prim_name << " is not Tensor.";
699 }
700 auto tensor_type = base_type->cast<TensorTypePtr>();
701 MS_EXCEPTION_IF_NULL(tensor_type);
702 auto type = tensor_type->element();
703 MS_EXCEPTION_IF_NULL(type);
704 return type;
705 }
706
Check(const string & arg_name,int64_t arg_value,CompareEnum compare_type,int64_t value,const string & prim_name,ExceptionType)707 void CheckAndConvertUtils::Check(const string &arg_name, int64_t arg_value, CompareEnum compare_type, int64_t value,
708 const string &prim_name, ExceptionType) {
709 auto iter = kCompareMap<float>.find(compare_type);
710 if (iter == kCompareMap<float>.end()) {
711 MS_EXCEPTION(NotExistsError) << "the compare type :" << compare_type << " is not in the compare map";
712 }
713 if (iter->second(arg_value, value)) {
714 return;
715 }
716 std::ostringstream buffer;
717 if (prim_name.empty()) {
718 buffer << "The attribute[" << arg_name << "] must ";
719 } else {
720 buffer << "For primitive[" << prim_name << "], the attribute[" << arg_name << "] must ";
721 }
722 auto iter_to_string = kCompareToString.find(compare_type);
723 if (iter_to_string == kCompareToString.end()) {
724 MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_type << " cannot find in the compare string map";
725 }
726 buffer << iter_to_string->second << value << ", but got " << arg_value << ".";
727 MS_EXCEPTION(ValueError) << buffer.str();
728 }
729
CheckTensorTypeSame(const std::map<std::string,TypePtr> & types,const std::set<TypePtr> & check_list,const std::string & prim_name)730 TypePtr CheckAndConvertUtils::CheckTensorTypeSame(const std::map<std::string, TypePtr> &types,
731 const std::set<TypePtr> &check_list, const std::string &prim_name) {
732 if (types.empty()) {
733 MS_EXCEPTION(ArgumentError) << "Trying to use the function to check a empty types map!";
734 }
735 // Check Input type is tensor type
736 for (const auto &item : types) {
737 auto type = item.second;
738 MS_EXCEPTION_IF_NULL(type);
739 if (!type->isa<TensorType>()) {
740 size_t i = 1;
741 std::ostringstream buffer;
742 buffer << "The primitive[" << prim_name << "]'s input arguments[";
743 for (const auto &item_type : types) {
744 buffer << item_type.first;
745 if (i < types.size()) {
746 buffer << ", ";
747 ++i;
748 }
749 }
750 i = 1;
751 buffer << "] must be all tensor and those type must be same.";
752 for (const auto &type_info : types) {
753 if (!type_info.second->isa<TensorType>()) {
754 buffer << " But got input argument[" << type_info.first << "]"
755 << ":" << type_info.second->ToString() << "\n";
756 }
757 }
758 if (!check_list.empty()) {
759 buffer << "Valid type list: {";
760 std::set<string> order_set;
761 for (auto const &valid_type : check_list) {
762 if (valid_type->isa<TensorType>()) {
763 (void)order_set.emplace(valid_type->ToString());
764 break;
765 } else {
766 (void)order_set.emplace("Tensor[" + valid_type->ToString() + "]");
767 }
768 }
769 for (auto const &error_item : order_set) {
770 buffer << error_item;
771 if (error_item != *(--order_set.end())) {
772 buffer << ", ";
773 }
774 }
775 buffer << "}.";
776 }
777 MS_EXCEPTION(TypeError) << buffer.str();
778 }
779 }
780 (void)CheckTypeSame(types, prim_name, false);
781 return CheckTensorSubClass(types.begin()->first, types.begin()->second, check_list, prim_name);
782 }
783
CheckMathBinaryOpTensorType(const std::map<std::string,TypePtr> & types,const std::set<TypePtr> & check_list,const std::string & prim_name)784 TypePtr CheckAndConvertUtils::CheckMathBinaryOpTensorType(const std::map<std::string, TypePtr> &types,
785 const std::set<TypePtr> &check_list,
786 const std::string &prim_name) {
787 constexpr size_t n = 2;
788 if (types.size() != n) {
789 MS_EXCEPTION(ArgumentError) << "For primitive[" << prim_name << "], the size of types to check must be " << n
790 << ", but got " << types.size();
791 }
792 // Check Input type is tensor type
793 std::vector<TypeId> type_ids;
794 std::vector<TypePtr> type_ptr;
795 bool has_complex = false;
796 for (const auto &item : types) {
797 MS_EXCEPTION_IF_NULL(item.second);
798 if (!item.second->isa<TensorType>()) {
799 MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s input arguments[" << item.first
800 << "] must be Tensor, but got " << item.second->ToString();
801 }
802 auto tensor_type = item.second->cast<TensorTypePtr>();
803 MS_EXCEPTION_IF_NULL(tensor_type);
804 auto element = tensor_type->element();
805 MS_EXCEPTION_IF_NULL(element);
806 auto type_id = element->type_id();
807 if (!has_complex && (type_id == kNumberTypeComplex64 || type_id == kNumberTypeComplex128)) {
808 has_complex = true;
809 }
810 type_ids.push_back(type_id);
811 type_ptr.push_back(item.second);
812 }
813 // Deal with complex data type
814 if (has_complex) {
815 static std::map<std::pair<TypeId, TypeId>, TypeId> type_infer_dict = {
816 {{kNumberTypeComplex64, kNumberTypeComplex64}, kNumberTypeComplex64},
817 {{kNumberTypeComplex64, kNumberTypeFloat32}, kNumberTypeComplex64},
818 {{kNumberTypeFloat32, kNumberTypeComplex64}, kNumberTypeComplex64},
819 {{kNumberTypeComplex128, kNumberTypeComplex128}, kNumberTypeComplex128},
820 {{kNumberTypeComplex128, kNumberTypeFloat64}, kNumberTypeComplex128},
821 {{kNumberTypeFloat64, kNumberTypeComplex128}, kNumberTypeComplex128}};
822 std::pair<TypeId, TypeId> type_info(type_ids[0], type_ids[1]);
823 auto iter = type_infer_dict.find(type_info);
824 if (iter != type_infer_dict.end()) {
825 return type_ids[0] == iter->second ? type_ptr[0] : type_ptr[1];
826 }
827 std::ostringstream buffer;
828 buffer << "For primitive[" << prim_name << "], complex math binary op expecting Tensor";
829 for (const auto &items : type_infer_dict) {
830 buffer << "[" << TypeIdToString(items.first.first) << ", " << TypeIdToString(items.first.second) << "], ";
831 }
832 buffer << "but got Tensor[" << TypeIdToString(type_ids[0]) << ", " << TypeIdToString(type_ids[1]) << "]";
833 MS_EXCEPTION(TypeError) << buffer.str();
834 }
835 // Deal with non-complex data type
836 if (type_ids[0] != type_ids[1]) {
837 MS_EXCEPTION(TypeError) << "For primitive[" << prim_name
838 << "], the input arguments must have same data type, but got Tensor["
839 << TypeIdToString(type_ids[0]) << "] and Tensor[" << TypeIdToString(type_ids[1]) << "]";
840 }
841 (void)CheckTensorSubClass(types.begin()->first, types.begin()->second, check_list, prim_name);
842 return types.begin()->second;
843 }
844
CheckTensorShapeSame(const std::map<std::string,BaseShapePtr> & shapes,const std::vector<int64_t> & check_shape,const std::string & prim_name)845 ShapeVector CheckAndConvertUtils::CheckTensorShapeSame(const std::map<std::string, BaseShapePtr> &shapes,
846 const std::vector<int64_t> &check_shape,
847 const std::string &prim_name) {
848 if (shapes.empty()) {
849 MS_EXCEPTION(ArgumentError) << "Trying to use the function to check a empty shapes map!";
850 }
851 for (const auto &shape : shapes) {
852 auto _shape_ptr_ = shape.second;
853 MS_EXCEPTION_IF_NULL(_shape_ptr_);
854 auto _shape_ = ConvertShapePtrToShapeMap(_shape_ptr_)[kShape];
855 (void)CheckPositiveVectorExcludeZero(shape.first, _shape_, prim_name);
856 if (!ShapeVectorIsSame(_shape_, check_shape)) {
857 std::ostringstream buffer;
858 buffer << "The primitive[" << prim_name << "]'s input arguments " << shape.first << " shape should equal to "
859 << ShapeVectorToStr(check_shape) << ", but get the real shape " << ShapeVectorToStr(_shape_) << ".";
860 MS_EXCEPTION(ValueError) << buffer.str();
861 }
862 }
863 return check_shape;
864 }
865
// Requires `type` to be a Tensor type and checks its element type against `check_list`.
// If `check_list` contains a Tensor type whose element is null, the element type is returned without
// further checks. Otherwise the result of CheckTensorSubClass is returned.
// Note: the returned type is the element type, so it may differ from the input `type`.
TypePtr CheckAndConvertUtils::CheckTensorTypeValid(const std::string &type_name, const TypePtr &type,
                                                   const std::set<TypePtr> &check_list, const std::string &prim_name) {
  MS_EXCEPTION_IF_NULL(type);
  if (!type->isa<TensorType>()) {
    MS_EXCEPTION(TypeError) << "For Primitive[" << prim_name << "], the type of input argument[" << type_name
                            << "] must be Tensor but got " << type->ToString() << ".";
  }
  auto element = type->cast<TensorTypePtr>()->element();
  MS_EXCEPTION_IF_NULL(element);
  const bool accepts_any_element =
    std::any_of(check_list.begin(), check_list.end(), [](const TypePtr &candidate) {
      return candidate->isa<TensorType>() && candidate->cast<TensorTypePtr>()->element() == nullptr;
    });
  if (accepts_any_element) {
    return element;
  }
  return CheckTensorSubClass(type_name, type, check_list, prim_name);
}
887
// Requires `type` to be a SparseTensorType and returns its element type.
// The third parameter (accepted types) is not consulted.
// @throws TypeError if `type` is not a sparse tensor type or the cast to SparseTensorTypePtr fails.
TypePtr CheckAndConvertUtils::CheckSparseTensorTypeValid(const std::string &type_name, const TypePtr &type,
                                                         const std::set<TypePtr> &, const std::string &prim_name) {
  MS_EXCEPTION_IF_NULL(type);
  if (type->isa<SparseTensorType>()) {
    auto sparse_type = type->cast<SparseTensorTypePtr>();
    if (sparse_type == nullptr) {
      MS_EXCEPTION(TypeError) << "For Primitive[" << prim_name << "], the input argument[" << type_name
                              << "] cast to SparseTensorTypePtr failed! Get type : " << type->ToString() << ".";
    }
    return sparse_type->element_type();
  }
  MS_EXCEPTION(TypeError) << "For Primitive[" << prim_name << "], the input argument[" << type_name
                          << "] must be a CSRTensor or COOTensor, but got " << type->ToString() << ".";
}
903
CheckTensorIntValue(const std::string & tensor_name,const ValuePtr & value,const std::string & prim_name)904 ShapeVector CheckAndConvertUtils::CheckTensorIntValue(const std::string &tensor_name, const ValuePtr &value,
905 const std::string &prim_name) {
906 if (value == nullptr) {
907 MS_EXCEPTION(ValueError) << "For primitive[" << prim_name << "], the input argument[" << tensor_name
908 << "] value is nullptr.";
909 }
910 ShapeVector tensor_value;
911 if (!value->isa<tensor::BaseTensor>()) {
912 MS_EXCEPTION(ValueError) << "For primitive[" << prim_name << "], the input argument[" << tensor_name
913 << "] must be a tensor, but got " << value->ToString();
914 }
915 auto input_tensor = value->cast<tensor::BaseTensorPtr>();
916 MS_EXCEPTION_IF_NULL(input_tensor);
917 size_t data_size = input_tensor->DataSize();
918 auto tensor_type = input_tensor->Dtype();
919 if (tensor_type->type_id() == kNumberTypeInt32) {
920 auto data_c = reinterpret_cast<int *>(input_tensor->data_c());
921 MS_EXCEPTION_IF_NULL(data_c);
922 for (size_t i = 0; i < data_size; i++) {
923 tensor_value.push_back(static_cast<int64_t>(*data_c));
924 ++data_c;
925 }
926 } else if (tensor_type->type_id() == kNumberTypeInt64) {
927 auto tensor_data = reinterpret_cast<int64_t *>(input_tensor->data_c());
928 MS_EXCEPTION_IF_NULL(tensor_data);
929 tensor_value = {tensor_data, tensor_data + data_size};
930 } else if (tensor_type->type_id() == kNumberTypeUInt32) {
931 auto tensor_data = reinterpret_cast<uint32_t *>(input_tensor->data_c());
932 MS_EXCEPTION_IF_NULL(tensor_data);
933 tensor_value = {tensor_data, tensor_data + data_size};
934 } else if (tensor_type->type_id() == kNumberTypeUInt64) {
935 auto tensor_data = reinterpret_cast<uint64_t *>(input_tensor->data_c());
936 MS_EXCEPTION_IF_NULL(tensor_data);
937 tensor_value = {tensor_data, tensor_data + data_size};
938 } else {
939 MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the input argument[" << tensor_name
940 << "] must be a Tensor[Int64] or Tensor[Int32]"
941 << " or Tensor[UInt64] or Tensor[UInt32] type, but got " << value->ToString();
942 }
943 return tensor_value;
944 }
945
CheckTensorIntValue(const std::string & tensor_name,const ValuePtr & value,const std::string & prim_name,const TypePtr & type)946 ShapeVector CheckAndConvertUtils::CheckTensorIntValue(const std::string &tensor_name, const ValuePtr &value,
947 const std::string &prim_name, const TypePtr &type) {
948 if (value == nullptr) {
949 MS_EXCEPTION(ValueError) << "For primitive[" << prim_name << "], the input argument[" << tensor_name
950 << "] value is nullptr.";
951 }
952 if (value->isa<ValueAny>() || value->isa<None>()) {
953 MS_EXCEPTION(ValueError) << "For primitive[" << prim_name << "], the input argument[" << tensor_name
954 << "] value is unknown.";
955 }
956 if (type->object_type() != kObjectTypeTensorType) {
957 MS_EXCEPTION(ValueError) << "For primitive[" << prim_name << "], the input argument[" << tensor_name
958 << "] must be a tensor, but got " << type->ToString();
959 }
960 ShapeVector tensor_value;
961 auto tensor_type_ptr = type->cast<TensorTypePtr>();
962 MS_EXCEPTION_IF_NULL(tensor_type_ptr);
963 auto tensor_type = tensor_type_ptr->element()->type_id();
964 if (tensor_type == kNumberTypeInt32) {
965 auto data_opt = ops::GetArrayValue<int>(value);
966 const auto &data = data_opt.value();
967 for (size_t i = 0; i < data.size(); i++) {
968 tensor_value.push_back(static_cast<int64_t>(data[i]));
969 }
970 } else if (tensor_type == kNumberTypeInt64) {
971 auto data_opt = ops::GetArrayValue<int64_t>(value);
972 tensor_value = data_opt.value().ToVector();
973 } else if (tensor_type == kNumberTypeUInt32) {
974 auto data_opt = ops::GetArrayValue<uint32_t>(value);
975 const auto &data = data_opt.value();
976 for (size_t i = 0; i < data.size(); i++) {
977 tensor_value.push_back(static_cast<int64_t>(data[i]));
978 }
979 } else if (tensor_type == kNumberTypeUInt64) {
980 auto data_opt = ops::GetArrayValue<uint64_t>(value);
981 const auto &data = data_opt.value();
982 for (size_t i = 0; i < data.size(); i++) {
983 tensor_value.push_back(static_cast<int64_t>(data[i]));
984 }
985 } else {
986 MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the input argument[" << tensor_name
987 << "] must be a Tensor[Int64] or Tensor[Int32]"
988 << " or Tensor[UInt64] or Tensor[UInt32] type, but got " << value->ToString();
989 }
990 return tensor_value;
991 }
992
// Checks `type` (or its element type, if it is a Tensor) against `template_types` via CheckType.
// Returns the checked type on success. On failure the TypeError message lists the accepted types sorted:
// Tensor entries verbatim, plain entries as Tensor[...], and, when `is_mix` is set, plain entries as-is too.
TypePtr CheckAndConvertUtils::CheckTensorSubClass(const string &type_name, const TypePtr &type,
                                                  const std::set<TypePtr> &template_types, const string &prim_name,
                                                  bool is_mix) {
  MS_EXCEPTION_IF_NULL(type);
  const TypePtr real_type = type->isa<TensorType>() ? type->cast<TensorTypePtr>()->element() : type;
  if (CheckType(real_type, template_types)) {
    return real_type;
  }
  std::set<string> accepted_names;
  for (const auto &candidate : template_types) {
    if (is_mix) {
      (void)accepted_names.insert(candidate->ToString());
    }
    (void)accepted_names.insert(candidate->isa<TensorType>() ? candidate->ToString()
                                                              : "Tensor[" + candidate->ToString() + "]");
  }
  std::ostringstream buffer;
  buffer << "For primitive[" << prim_name << "], the input argument[" << type_name << "] must be a type of {";
  bool first = true;
  for (const auto &name : accepted_names) {
    if (!first) {
      buffer << ", ";
    }
    buffer << name;
    first = false;
  }
  buffer << "}, but got " << type->ToString() << ".";
  MS_EXCEPTION(TypeError) << buffer.str();
}
1033
// Returns `type` if CheckType accepts it against `template_types`. Otherwise throws TypeError listing the
// accepted type names in sorted order.
TypePtr CheckAndConvertUtils::CheckSubClass(const std::string &type_name, const TypePtr &type,
                                            const std::set<TypePtr> &template_types, const std::string &prim_name) {
  if (!CheckType(type, template_types)) {
    std::set<string> sorted_names;
    (void)std::transform(template_types.begin(), template_types.end(),
                         std::inserter(sorted_names, sorted_names.end()),
                         [](const TypePtr &candidate) { return candidate->ToString(); });
    std::ostringstream buffer;
    buffer << "For primitive[" << prim_name << "], the input argument[" << type_name << "] must be a type of {";
    for (auto it = sorted_names.begin(); it != sorted_names.end(); ++it) {
      if (it != sorted_names.begin()) {
        buffer << ", ";
      }
      buffer << *it;
    }
    buffer << "}, but got " << type->ToString() << ".";
    MS_EXCEPTION(TypeError) << buffer.str();
  }
  return type;
}
1055
// Same check as CheckSubClass, but the TypeError message inserts `more_info` after the argument name.
// Returns `type` when CheckType accepts it against `template_types`.
TypePtr CheckAndConvertUtils::CheckSubClassWithMoreInfo(const std::string &type_name, const TypePtr &type,
                                                        const std::string &more_info,
                                                        const std::set<TypePtr> &template_types,
                                                        const std::string &prim_name) {
  if (!CheckType(type, template_types)) {
    std::set<string> sorted_names;
    for (const auto &candidate : template_types) {
      (void)sorted_names.insert(candidate->ToString());
    }
    std::string joined;
    for (const auto &name : sorted_names) {
      if (!joined.empty()) {
        joined += ", ";
      }
      joined += name;
    }
    std::ostringstream buffer;
    buffer << "For primitive[" << prim_name << "], the input argument[" << type_name << "] " << more_info
           << " must be a type of {" << joined << "}, but got " << type->ToString() << ".";
    MS_EXCEPTION(TypeError) << buffer.str();
  }
  return type;
}
1080
CheckScalarOrTensorTypesSame(const std::map<std::string,TypePtr> & args,const std::set<TypePtr> & valid_values,const std::string & prim_name,bool allow_mix)1081 TypePtr CheckAndConvertUtils::CheckScalarOrTensorTypesSame(const std::map<std::string, TypePtr> &args,
1082 const std::set<TypePtr> &valid_values,
1083 const std::string &prim_name, bool allow_mix) {
1084 (void)CheckTypeSame(args, prim_name, allow_mix);
1085 return CheckTensorSubClass(args.begin()->first, args.begin()->second, valid_values, prim_name, true);
1086 }
1087
CheckTypeSame(const std::map<std::string,TypePtr> & args,const std::string & prim_name,const bool allow_mix)1088 TypePtr CheckAndConvertUtils::CheckTypeSame(const std::map<std::string, TypePtr> &args, const std::string &prim_name,
1089 const bool allow_mix) {
1090 if (args.empty()) {
1091 MS_EXCEPTION(ArgumentError) << "Trying to use the function to check a empty types map!";
1092 }
1093 std::ostringstream buffer;
1094 TypePtr return_type = args.begin()->second;
1095 buffer << "For primitive[" << prim_name << "], the ";
1096 bool tensor_flag = return_type->isa<TensorType>();
1097 std::set<TypeId> types_id;
1098 for (const auto &elem : args) {
1099 auto type = elem.second;
1100 MS_EXCEPTION_IF_NULL(type);
1101 if (!allow_mix) {
1102 // input must be all tensor or all other type
1103 if ((tensor_flag && !type->isa<TensorType>()) || (!tensor_flag && type->isa<TensorType>())) {
1104 buffer << "input type must be same.\n";
1105 for (const auto &error_elem : args) {
1106 buffer << "input argument[" << error_elem.first << "]:" << error_elem.second->ToString() << "\n";
1107 }
1108 MS_EXCEPTION(TypeError) << buffer.str();
1109 }
1110 }
1111 if (type->isa<TensorType>()) {
1112 auto tensor_type = type->cast<TensorTypePtr>();
1113 auto element = tensor_type->element();
1114 MS_EXCEPTION_IF_NULL(element);
1115 return_type = element;
1116 (void)types_id.emplace(element->type_id());
1117 } else {
1118 if (return_type->isa<TensorType>()) {
1119 return_type = type;
1120 }
1121 (void)types_id.emplace(type->type_id());
1122 }
1123 if (types_id.size() > 1) {
1124 buffer << "input type must be same.\n";
1125 for (const auto &item : args) {
1126 buffer << "name:[" << item.first << "]:" << item.second->ToString() << ".\n";
1127 }
1128 MS_EXCEPTION(TypeError) << buffer.str();
1129 }
1130 }
1131 return return_type->DeepCopy();
1132 }
1133
// Checks `arg_type` against a non-empty `valid_type` set.
// Tensor types are delegated to CheckTensorTypeValid; all other types go to CheckSubClass.
TypePtr CheckAndConvertUtils::CheckTypeValid(const std::string &arg_name, const TypePtr &arg_type,
                                             const std::set<TypePtr> &valid_type, const std::string &prim_name) {
  if (valid_type.empty()) {
    MS_EXCEPTION(ArgumentError) << "Trying to use the function to check a empty valid_type!";
  }
  MS_EXCEPTION_IF_NULL(arg_type);
  const bool is_tensor = arg_type->isa<TensorType>();
  return is_tensor ? CheckTensorTypeValid(arg_name, arg_type, valid_type, prim_name)
                   : CheckSubClass(arg_name, arg_type, valid_type, prim_name);
}
1145
// Like CheckTypeValid, but non-Tensor failures include `more_info` in the message
// (via CheckSubClassWithMoreInfo). Tensor types are delegated to CheckTensorTypeValid, which does not
// receive `more_info`.
TypePtr CheckAndConvertUtils::CheckTypeValidWithMoreInfo(const std::string &arg_name, const TypePtr &arg_type,
                                                         const std::string &more_info,
                                                         const std::set<TypePtr> &valid_type,
                                                         const std::string &prim_name) {
  if (valid_type.empty()) {
    MS_EXCEPTION(ArgumentError) << "Trying to use the function to check a empty valid_type!";
  }
  MS_EXCEPTION_IF_NULL(arg_type);
  if (!arg_type->isa<TensorType>()) {
    return CheckSubClassWithMoreInfo(arg_name, arg_type, more_info, valid_type, prim_name);
  }
  return CheckTensorTypeValid(arg_name, arg_type, valid_type, prim_name);
}
1159
CheckIrAttrtoOpAttr(const std::string & op_type,const std::string & attr_name,ValuePtr * const value)1160 bool CheckAndConvertUtils::CheckIrAttrtoOpAttr(const std::string &op_type, const std::string &attr_name,
1161 ValuePtr *const value) {
1162 if (*value == nullptr) {
1163 MS_LOG(DEBUG) << "value is nullptr! op_type = " << op_type << ", attr_name = " << attr_name;
1164 return false;
1165 }
1166 if (op_type.empty() || attr_name.empty()) {
1167 return false;
1168 }
1169 auto op_map = kIrAttrToOpAttr.find(op_type);
1170 if (op_map == kIrAttrToOpAttr.end()) {
1171 return false;
1172 }
1173 auto attr_func = op_map->second.find(attr_name);
1174 if (attr_func == op_map->second.end()) {
1175 return false;
1176 }
1177 *value = attr_func->second(*value);
1178 MS_LOG(DEBUG) << "convert ir attr to op attr, name: " << op_type << ", attr: " << attr_name;
1179 return true;
1180 }
1181
CheckSummaryParam(const AbstractBasePtr & name,const AbstractBasePtr & value,const std::string & class_name)1182 void CheckAndConvertUtils::CheckSummaryParam(const AbstractBasePtr &name, const AbstractBasePtr &value,
1183 const std::string &class_name) {
1184 MS_EXCEPTION_IF_NULL(name);
1185 MS_EXCEPTION_IF_NULL(value);
1186 (void)CheckTypeValid("name", name->BuildType(), {kString}, class_name);
1187 auto s = GetValue<std::string>(name->BuildValue());
1188 if (s.empty()) {
1189 MS_EXCEPTION(ValueError) << "For primitive[" << class_name << "], the input argument[name]"
1190 << " cannot be an empty string.";
1191 }
1192 (void)CheckTypeValid("value", value->BuildType(), {kTensorType}, class_name);
1193 }
1194
CheckTensorFloatValue(const std::string & type_name,const ValuePtr & value,const std::string & prim_name)1195 std::vector<double> CheckAndConvertUtils::CheckTensorFloatValue(const std::string &type_name, const ValuePtr &value,
1196 const std::string &prim_name) {
1197 if (value == nullptr) {
1198 MS_EXCEPTION(ValueError) << "For primitive[" << prim_name << "], the input argument[" << type_name
1199 << "] value is nullptr.";
1200 }
1201 std::vector<double> tensor_value;
1202 if (!value->isa<tensor::BaseTensor>()) {
1203 MS_EXCEPTION(ValueError) << "For primitive[" << prim_name << "], the input argument[" << type_name
1204 << "] must be a tensor, but got " << value->ToString();
1205 }
1206 auto input_tensor = value->cast<tensor::BaseTensorPtr>();
1207 MS_EXCEPTION_IF_NULL(input_tensor);
1208 size_t data_size = input_tensor->DataSize();
1209 auto tensor_type = input_tensor->Dtype();
1210 if (tensor_type->type_id() == kNumberTypeFloat32) {
1211 auto data_c = static_cast<float *>(input_tensor->data_c());
1212 MS_EXCEPTION_IF_NULL(data_c);
1213 for (size_t i = 0; i < data_size; i++) {
1214 tensor_value.push_back(static_cast<double>(*data_c));
1215 ++data_c;
1216 }
1217 } else if (tensor_type->type_id() == kNumberTypeFloat64) {
1218 auto tensor_data = static_cast<double *>(input_tensor->data_c());
1219 MS_EXCEPTION_IF_NULL(tensor_data);
1220 tensor_value = {tensor_data, tensor_data + data_size};
1221 } else {
1222 MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the input argument[" << type_name
1223 << "] must be a Tensor[Float32] or Tensor[Float64], but got " << value->ToString();
1224 }
1225 return tensor_value;
1226 }
1227
CheckListOrTupleFloat(const std::string & arg_name,const ValuePtr & attr,const std::string & prim_name)1228 std::vector<double> CheckAndConvertUtils::CheckListOrTupleFloat(const std::string &arg_name, const ValuePtr &attr,
1229 const std::string &prim_name) {
1230 std::vector<double> result;
1231 bool is_correct = false;
1232 MS_EXCEPTION_IF_NULL(attr);
1233 if (attr->isa<ValueTuple>() || attr->isa<ValueList>()) {
1234 auto attr_vec =
1235 attr->isa<ValueTuple>() ? attr->cast<ValueTuplePtr>()->value() : attr->cast<ValueListPtr>()->value();
1236 if (attr_vec.empty()) {
1237 return result;
1238 }
1239 is_correct = std::all_of(attr_vec.begin(), attr_vec.end(), [&result](const ValuePtr &e) -> bool {
1240 MS_EXCEPTION_IF_NULL(e);
1241 if (e->isa<FP32Imm>()) {
1242 (void)result.emplace_back(static_cast<double>(GetValue<float>(e)));
1243 return true;
1244 } else if (e->isa<FP64Imm>()) {
1245 (void)result.emplace_back(GetValue<double>(e));
1246 return true;
1247 }
1248 return false;
1249 });
1250 }
1251 if (!is_correct) {
1252 MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the " << arg_name
1253 << " must be one of ['tuple', 'list'] with all Float elements, but got "
1254 << attr->ToString();
1255 }
1256 return result;
1257 }
1258
CheckListOrTupleFloat(const std::string & arg_name,const AbstractBasePtr & abs,const std::string & prim_name)1259 std::vector<pyfloat> CheckAndConvertUtils::CheckListOrTupleFloat(const std::string &arg_name,
1260 const AbstractBasePtr &abs,
1261 const std::string &prim_name) {
1262 std::vector<pyfloat> result{};
1263 if (IsSequence(abs)) {
1264 const auto &type_list = GetSequenceElementTypes(abs);
1265 if (type_list.empty()) {
1266 return result;
1267 }
1268 auto is_correct = std::all_of(type_list.begin(), type_list.end(), [](const TypePtr &e) -> bool {
1269 MS_EXCEPTION_IF_NULL(e);
1270 return e->type_id() == kNumberTypeFloat64 || e->type_id() == kNumberTypeFloat32;
1271 });
1272 if (is_correct) {
1273 const auto &arr_value = ops::GetArrayValue<pyfloat>(abs);
1274 if (arr_value->HasUnknownValue()) {
1275 MS_EXCEPTION(ValueError) << "For primitive[" << prim_name << "], there are unknown values in the " << arg_name
1276 << ", please handle this case before calling this function.";
1277 }
1278 result = arr_value->ToVector();
1279 } else {
1280 MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the " << arg_name
1281 << " must be one of ['tuple', 'list'] with all Float elements, but got "
1282 << abs->ToString();
1283 }
1284 return result;
1285 }
1286 MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the " << arg_name
1287 << " must be one of ['tuple', 'list'] with all Float elements, but got " << abs->ToString();
1288 }
1289
CheckIntOrTupleInt(const std::string & arg_name,const ValuePtr & attr,const std::string & prim_name)1290 std::vector<int64_t> CheckAndConvertUtils::CheckIntOrTupleInt(const std::string &arg_name, const ValuePtr &attr,
1291 const std::string &prim_name) {
1292 std::vector<int64_t> result;
1293 bool is_correct = false;
1294 MS_EXCEPTION_IF_NULL(attr);
1295 if (attr->isa<ValueTuple>() || attr->isa<ValueList>()) {
1296 auto attr_vec =
1297 attr->isa<ValueTuple>() ? attr->cast<ValueTuplePtr>()->value() : attr->cast<ValueListPtr>()->value();
1298 if (attr_vec.empty()) {
1299 return result;
1300 }
1301 is_correct = std::all_of(attr_vec.begin(), attr_vec.end(), [&result](const ValuePtr &e) -> bool {
1302 MS_EXCEPTION_IF_NULL(e);
1303 if (e->isa<Int64Imm>()) {
1304 (void)result.emplace_back(GetValue<int64_t>(e));
1305 return true;
1306 } else if (e->isa<Int32Imm>()) {
1307 (void)result.emplace_back(GetValue<int32_t>(e));
1308 return true;
1309 }
1310 return false;
1311 });
1312 } else {
1313 if (attr->isa<Int64Imm>()) {
1314 is_correct = true;
1315 int64_t attr_val = attr->cast<Int64ImmPtr>()->value();
1316 result.push_back(attr_val);
1317 } else if (attr->isa<Int32Imm>()) {
1318 is_correct = true;
1319 int64_t attr_val = attr->cast<Int32ImmPtr>()->value();
1320 result.push_back(attr_val);
1321 }
1322 }
1323 if (!is_correct) {
1324 MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the " << arg_name
1325 << " must be one of ['int', 'tuple', 'list'] with all Int elements, but got "
1326 << attr->ToString();
1327 }
1328 return result;
1329 }
1330
CheckIntOrTupleInt(const std::string & arg_name,const AbstractBasePtr & abs,const std::string & prim_name)1331 std::vector<int64_t> CheckAndConvertUtils::CheckIntOrTupleInt(const std::string &arg_name, const AbstractBasePtr &abs,
1332 const std::string &prim_name) {
1333 std::vector<int64_t> result{};
1334 if (IsSequence(abs)) {
1335 const auto &type_list = GetSequenceElementTypes(abs);
1336 if (type_list.empty()) {
1337 return result;
1338 }
1339 auto is_correct = std::all_of(type_list.begin(), type_list.end(), [](const TypePtr &e) -> bool {
1340 MS_EXCEPTION_IF_NULL(e);
1341 return e->type_id() == kNumberTypeInt64 || e->type_id() == kNumberTypeInt32;
1342 });
1343 if (!is_correct) {
1344 MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], when the " << arg_name
1345 << "'s type is one of ['tuple', 'list'], its element data type must be int32 or int64, "
1346 "but got "
1347 << abs->ToString();
1348 } else if (type_list.front()->type_id() == kNumberTypeInt64) {
1349 const auto &arr_value = ops::GetArrayValue<int64_t>(abs);
1350 if (arr_value->HasUnknownValue()) {
1351 MS_EXCEPTION(ValueError) << "For primitive[" << prim_name << "], there are unknown values in the " << arg_name
1352 << ", please handle this case before calling this function.";
1353 }
1354 result = arr_value->ToVector();
1355 } else if (type_list.front()->type_id() == kNumberTypeInt32) {
1356 const auto &arr_value = ops::GetArrayValue<int>(abs);
1357 if (arr_value->HasUnknownValue()) {
1358 MS_EXCEPTION(ValueError) << "For primitive[" << prim_name << "], there are unknown values in the " << arg_name
1359 << ", please handle this case before calling this function.";
1360 }
1361 const auto &vec_value = arr_value->ToVector();
1362 (void)std::transform(vec_value.begin(), vec_value.end(), std::back_inserter(result),
1363 [](int ele) -> int64_t { return static_cast<int64_t>(ele); });
1364 }
1365 } else {
1366 if (!ops::IsValueKnown(abs)) {
1367 MS_EXCEPTION(ValueError) << "For primitive[" << prim_name << "], the value of [" << arg_name
1368 << "] is unknown, please handle this case before calling this function.";
1369 }
1370 auto data_type = abs->GetType();
1371 MS_EXCEPTION_IF_NULL(data_type);
1372 if (data_type->type_id() == kNumberTypeInt64) {
1373 const auto &val = ops::GetScalarValue<int64_t>(abs->GetValue());
1374 result.push_back(val.value());
1375 } else if (data_type->type_id() == kNumberTypeInt32) {
1376 const auto &val = ops::GetScalarValue<int>(abs->GetValue());
1377 result.push_back(val.value());
1378 } else {
1379 MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], when the " << arg_name
1380 << "'s type is 'int', its data type must be int32 or int64, but got "
1381 << data_type->ToString();
1382 }
1383 }
1384 return result;
1385 }
1386
CheckAttrTuple(const PrimitivePtr & prim,const std::string & attr_name,size_t num_element)1387 std::vector<int64_t> CheckAndConvertUtils::CheckAttrTuple(const PrimitivePtr &prim, const std::string &attr_name,
1388 size_t num_element) {
1389 MS_EXCEPTION_IF_NULL(prim);
1390 auto attr = prim->GetAttr(attr_name);
1391 MS_EXCEPTION_IF_NULL(attr);
1392 std::vector<int64_t> result;
1393 if (!attr->isa<ValueTuple>()) {
1394 MS_EXCEPTION(ValueError) << "For '" << prim->name() << "', the '" << attr_name
1395 << "' should be a tuple[int64], but got: " << attr->ToString() << ".";
1396 }
1397 std::vector<ValuePtr> attr_vec = attr->cast<ValueTuplePtr>()->value();
1398 if (attr_vec.size() != num_element) {
1399 MS_EXCEPTION(ValueError) << "For '" << prim->name() << "', the '" << attr_name
1400 << "' should be a tuple[int64] with size " << num_element << ", but its size is "
1401 << attr_vec.size() << ".";
1402 }
1403 (void)std::transform(attr_vec.begin(), attr_vec.end(), std::back_inserter(result),
1404 [&prim, &attr_name](const ValuePtr &e) -> int64_t {
1405 auto value = GetValue<int64_t>(e);
1406 if (value < 0) {
1407 MS_EXCEPTION(ValueError) << "For '" << prim->name() << "', the element of '" << attr_name
1408 << "' should not be negative number, but got " << value << ".";
1409 }
1410 return value;
1411 });
1412 return result;
1413 }
1414
CheckTupleInt(const std::string & arg_name,const ValuePtr & attr,const std::string & prim_name)1415 std::vector<int64_t> CheckAndConvertUtils::CheckTupleInt(const std::string &arg_name, const ValuePtr &attr,
1416 const std::string &prim_name) {
1417 std::vector<int64_t> result;
1418 MS_EXCEPTION_IF_NULL(attr);
1419 if (attr->isa<ValueTuple>()) {
1420 std::vector<ValuePtr> attr_vec = attr->cast<ValueTuplePtr>()->value();
1421 (void)std::transform(
1422 attr_vec.begin(), attr_vec.end(), std::back_inserter(result), [=](const ValuePtr &e) -> int64_t {
1423 if (!e->isa<Int64Imm>()) {
1424 MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the " << arg_name
1425 << " must be a tuple with all Int elements, but got " << attr->type_name();
1426 }
1427 return GetValue<int64_t>(e);
1428 });
1429 } else if (attr->isa<KernelTensorValue>()) {
1430 // to_do: check type of the KernelTensorValue is int64
1431 auto data_opt = ops::GetArrayValue<int64_t>(attr);
1432 const auto &data_array = data_opt.value();
1433 result = data_array.ToVector();
1434 } else {
1435 MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the " << arg_name
1436 << " must be a tuple with all Int elements, but got " << attr->type_name() << ".";
1437 }
1438 return result;
1439 }
1440
CheckListInt(const std::string & arg_name,const ValuePtr & attr,const std::string & prim_name)1441 std::vector<int64_t> CheckAndConvertUtils::CheckListInt(const std::string &arg_name, const ValuePtr &attr,
1442 const std::string &prim_name) {
1443 std::vector<int64_t> result;
1444 MS_EXCEPTION_IF_NULL(attr);
1445 if (attr->isa<ValueList>()) {
1446 std::vector<ValuePtr> attr_vec = attr->cast<ValueListPtr>()->value();
1447 (void)std::transform(
1448 attr_vec.begin(), attr_vec.end(), std::back_inserter(result), [=](const ValuePtr &e) -> int64_t {
1449 if (!e->isa<Int64Imm>()) {
1450 MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the " << arg_name
1451 << " must be a list with all Int elements, but got " << attr->ToString();
1452 }
1453 return GetValue<int64_t>(e);
1454 });
1455 } else if (attr->isa<KernelTensorValue>()) {
1456 // to_do: check type of the KernelTensorValue is int64
1457 auto data_opt = ops::GetArrayValue<int64_t>(attr);
1458 const auto &data_array = data_opt.value();
1459 result = data_array.ToVector();
1460 } else {
1461 MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the " << arg_name
1462 << " must be a list with all Int elements, but got " << attr->ToString() << ".";
1463 }
1464 return result;
1465 }
1466
GetAndCheckFormat(const ValuePtr & value)1467 int64_t CheckAndConvertUtils::GetAndCheckFormat(const ValuePtr &value) {
1468 int64_t data_format;
1469 bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value, &data_format);
1470 if (!result ||
1471 (data_format != static_cast<int64_t>(Format::NHWC) && data_format != static_cast<int64_t>(Format::NCHW) &&
1472 data_format != static_cast<int64_t>(Format::NCDHW))) {
1473 MS_LOG(EXCEPTION) << "data format value " << data_format << " is invalid, only support NCHW, NHWC and NCDHW";
1474 }
1475 return data_format;
1476 }
GetRemoveMonadAbsNum(const AbstractBasePtrList & abs_list)1477 size_t CheckAndConvertUtils::GetRemoveMonadAbsNum(const AbstractBasePtrList &abs_list) {
1478 size_t remove_monad_count = abs_list.size();
1479 for (const auto &item : abs_list) {
1480 if (item->isa<abstract::AbstractMonad>()) {
1481 --remove_monad_count;
1482 }
1483 }
1484
1485 for (size_t i = 0; i < remove_monad_count; ++i) {
1486 if (abs_list[i]->isa<abstract::AbstractMonad>()) {
1487 MS_EXCEPTION(UnknownError) << "The monad inputs of the node must at last of the node inputs.";
1488 }
1489 }
1490 return remove_monad_count;
1491 }
GetRemoveUMonadAbsNum(const AbstractBasePtrList & abs_list)1492 size_t CheckAndConvertUtils::GetRemoveUMonadAbsNum(const AbstractBasePtrList &abs_list) {
1493 size_t remove_umonad_count = abs_list.size();
1494 for (const auto &item : abs_list) {
1495 if (item->isa<abstract::AbstractUMonad>()) {
1496 --remove_umonad_count;
1497 }
1498 }
1499
1500 for (size_t i = 0; i < remove_umonad_count; ++i) {
1501 if (abs_list[i]->isa<abstract::AbstractUMonad>()) {
1502 MS_EXCEPTION(UnknownError) << "The umonad inputs of the node must at last of the node inputs.";
1503 }
1504 }
1505 return remove_umonad_count;
1506 }
HasDynamicShapeInput(const AbstractBasePtrList & abs_list)1507 bool CheckAndConvertUtils::HasDynamicShapeInput(const AbstractBasePtrList &abs_list) {
1508 for (const auto &item : abs_list) {
1509 MS_EXCEPTION_IF_NULL(item);
1510 auto shape = item->GetShape();
1511 if (shape->IsDynamic()) {
1512 return true;
1513 }
1514 }
1515 return false;
1516 }
1517
CheckArgsType(const std::string & op,const AbstractBasePtrList & args_spec_list,size_t index,TypeId type_id)1518 AbstractBasePtr CheckAndConvertUtils::CheckArgsType(const std::string &op, const AbstractBasePtrList &args_spec_list,
1519 size_t index, TypeId type_id) {
1520 if (index >= args_spec_list.size()) {
1521 MS_EXCEPTION(ValueError) << op << " evaluator arguments list index out of bound, size " << args_spec_list.size()
1522 << ", index " << index;
1523 }
1524 auto args_abs = args_spec_list[index];
1525 MS_EXCEPTION_IF_NULL(args_abs);
1526 if (args_abs->GetType()->object_type() != type_id) {
1527 MS_EXCEPTION(TypeError) << "For primitive[" << op << "], the input[" << index << "] should be a "
1528 << TypeIdToType(type_id)->ToString() << ", but got " << args_abs->GetType()->ToString()
1529 << ".";
1530 }
1531 return args_abs;
1532 }
1533
CheckArgsSequenceType(const std::string & op,const AbstractBasePtrList & args_spec_list,size_t index)1534 AbstractBasePtr CheckAndConvertUtils::CheckArgsSequenceType(const std::string &op,
1535 const AbstractBasePtrList &args_spec_list, size_t index) {
1536 if (index >= args_spec_list.size()) {
1537 MS_EXCEPTION(ValueError) << op << " evaluator arguments list index out of bound, size " << args_spec_list.size()
1538 << ", index " << index;
1539 }
1540 auto args_abs = args_spec_list[index];
1541 MS_EXCEPTION_IF_NULL(args_abs);
1542 if (!IsSequence(args_abs)) {
1543 MS_EXCEPTION(TypeError) << "For primitive[" << op << "], the input[" << index << "] should be a "
1544 << "tuple or list, but got " << args_abs->GetType()->ToString() << ".";
1545 }
1546 return args_abs;
1547 }
1548 } // namespace mindspore
1549