• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/check_convert_utils.h"
18 
19 #include <algorithm>
20 #include <functional>
21 #include <iterator>
22 #include <map>
23 #include <set>
24 #include <typeinfo>
25 #include <utility>
26 #include <vector>
27 
28 #include "abstract/abstract_value.h"
29 #include "ir/dtype.h"
30 #include "ir/dtype/tensor_type.h"
31 #include "ir/dtype/type.h"
32 #include "ir/scalar.h"
33 #include "ir/tensor.h"
34 #include "ir/value.h"
35 #include "mindapi/base/format.h"
36 #include "mindapi/base/type_id.h"
37 #include "mindapi/base/types.h"
38 #include "ops/op_name.h"
39 #include "utils/convert_utils_base.h"
40 #include "utils/ms_context.h"
41 #include "ops/op_utils.h"
42 #include "ir/kernel_tensor_value.h"
43 
44 namespace mindspore {
45 static std::map<std::string, int64_t> DataFormatToEnumMap = {
46   {"NCHW", Format::NCHW},   {"NHWC", Format::NHWC},     {"NHWC4", Format::NHWC4},
47   {"HWKC", Format::HWKC},   {"HWCK", Format::HWCK},     {"KCHW", Format::KCHW},
48   {"CKHW", Format::CKHW},   {"KHWC", Format::KHWC},     {"CHWK", Format::CHWK},
49   {"HW", Format::HW},       {"HW4", Format::HW4},       {"NC", Format::NC},
50   {"NC4", Format::NC4},     {"NC4HW4", Format::NC4HW4}, {"NUM_OF_FORMAT", Format::NUM_OF_FORMAT},
51   {"NCDHW", Format::NCDHW}, {"NWC", Format::NWC},       {"NCW", Format::NCW},
52 };
53 
54 static std::map<int64_t, std::string> DataFormatToStrMap = {
55   {Format::NCHW, "NCHW"},   {Format::NHWC, "NHWC"},     {Format::NHWC4, "NHWC4"},
56   {Format::HWKC, "HWKC"},   {Format::HWCK, "HWCK"},     {Format::KCHW, "KCHW"},
57   {Format::CKHW, "CKHW"},   {Format::KHWC, "KHWC"},     {Format::CHWK, "CHWK"},
58   {Format::HW, "HW"},       {Format::HW4, "HW4"},       {Format::NC, "NC"},
59   {Format::NC4, "NC4"},     {Format::NC4HW4, "NC4HW4"}, {Format::NUM_OF_FORMAT, "NUM_OF_FORMAT"},
60   {Format::NCDHW, "NCDHW"}, {Format::NWC, "NWC"},       {Format::NCW, "NCW"},
61 };
62 
63 static std::map<std::string, int64_t> ReductionToEnumMap = {
64   {"sum", Reduction::REDUCTION_SUM},
65   {"mean", Reduction::MEAN},
66   {"none", Reduction::NONE},
67 };
68 
69 static std::map<int64_t, std::string> ReductionToStrMap = {
70   {Reduction::REDUCTION_SUM, "sum"},
71   {Reduction::MEAN, "mean"},
72   {Reduction::NONE, "none"},
73 };
74 
75 static std::map<std::string, int64_t> PadModToEnumMap = {
76   {"pad", PadMode::PAD},
77   {"same", PadMode::SAME},
78   {"valid", PadMode::VALID},
79 };
80 
81 static std::map<int64_t, std::string> PadModToStrMap = {
82   {PadMode::PAD, "pad"},
83   {PadMode::SAME, "same"},
84   {PadMode::VALID, "valid"},
85 };
86 
87 static std::map<std::string, int64_t> PadModToEnumUpperMap = {
88   {"PAD", PadMode::PAD},
89   {"SAME", PadMode::SAME},
90   {"VALID", PadMode::VALID},
91   // this should be removed, cause some op change "PAD" to "CALCULATED" in python.
92   {"CALCULATED", PadMode::PAD},
93 };
94 
95 static std::map<int64_t, std::string> PadModToStrUpperMap = {
96   {PadMode::PAD, "PAD"},
97   {PadMode::SAME, "SAME"},
98   {PadMode::VALID, "VALID"},
99 };
100 
101 AttrConverterPair DataFormatConverter(DataFormatToEnumMap, DataFormatToStrMap);
102 AttrConverterPair PadModeConverter(PadModToEnumMap, PadModToStrMap);
103 AttrConverterPair PadModeUpperConverter(PadModToEnumUpperMap, PadModToStrUpperMap);
104 AttrConverterPair ReductionConverter(ReductionToEnumMap, ReductionToStrMap);
105 
106 static std::map<std::string, AttrConverterPair> FormatAndPadAttrMap = {
107   {ops::kFormat, DataFormatConverter},
108   {ops::kPadMode, PadModeConverter},
109 };
110 
111 static std::map<std::string, AttrConverterPair> FormatAndPadUpperAttrMap = {
112   {ops::kFormat, DataFormatConverter},
113   {ops::kPadMode, PadModeUpperConverter},
114 };
115 
116 static std::map<std::string, AttrConverterPair> DataFormatMap = {
117   {ops::kFormat, DataFormatConverter},
118 };
119 
120 static std::map<std::string, AttrConverterPair> FormatAndDataFormatMap = {
121   {ops::kFormat, DataFormatConverter},
122   {ops::kDataFormat, DataFormatConverter},
123 };
124 
125 static std::map<std::string, AttrConverterPair> ReductionMap = {
126   {ops::kReduction, ReductionConverter},
127 };
128 
129 static std::map<std::string, std::map<std::string, AttrConverterPair>> PrimAttrConvertMap = {
130   {"Conv2D", FormatAndPadAttrMap},
131   {"Conv2DTranspose", FormatAndPadUpperAttrMap},
132   {"Conv2DBackpropInput", FormatAndPadUpperAttrMap},
133   {"Conv2DBackpropFilter", FormatAndPadUpperAttrMap},
134   {"Conv3D", FormatAndPadAttrMap},
135   {"Conv3DBackpropInput", FormatAndPadAttrMap},
136   {"Conv3DBackpropFilter", FormatAndPadAttrMap},
137   {"Conv3DTranspose", DataFormatMap},
138   {"DepthwiseConv2dNative", FormatAndPadAttrMap},
139   {"DepthwiseConv2dNativeBackpropInput", FormatAndPadAttrMap},
140   {"DepthwiseConv2dNativeBackpropFilter", FormatAndPadAttrMap},
141   {"AvgPool", FormatAndPadUpperAttrMap},
142   {"MaxPoolV1", FormatAndPadUpperAttrMap},
143   {"MaxPool", FormatAndPadUpperAttrMap},
144   {"MaxPoolWithArgmax", FormatAndPadUpperAttrMap},
145   {"AvgPoolGrad", FormatAndPadUpperAttrMap},
146   {"AvgPoolGradVm", FormatAndPadUpperAttrMap},
147   {"AvgPoolGradGpu", FormatAndPadUpperAttrMap},
148   {"AvgPoolGradCpu", FormatAndPadUpperAttrMap},
149   {"AvgPoolV1", FormatAndPadUpperAttrMap},
150   {"AvgPoolGradV1", FormatAndPadUpperAttrMap},
151   {"MaxPoolGrad", FormatAndPadUpperAttrMap},
152   {"MaxPoolGradV1", FormatAndPadUpperAttrMap},
153   {"MaxPoolGradGrad", FormatAndPadUpperAttrMap},
154   {"MaxPoolGradWithArgmax", FormatAndPadUpperAttrMap},
155   {"MaxPoolGradGradWithArgmax", FormatAndPadUpperAttrMap},
156   {"BatchNorm", DataFormatMap},
157   {"BatchNormGrad", DataFormatMap},
158   {"BiasAdd", DataFormatMap},
159   {"BiasAddGrad", DataFormatMap},
160   {"BinaryCrossEntropy", ReductionMap},
161   {"BinaryCrossEntropyGrad", ReductionMap},
162   {"NLLLoss", ReductionMap},
163   {"NLLLossGrad", ReductionMap},
164   {"DepthToSpace", FormatAndDataFormatMap},
165   {"SpaceToDepth", FormatAndDataFormatMap},
166   {"Pooling", DataFormatMap},
167   {"Deconvolution", DataFormatMap},
168   {"AvgPoolV2", DataFormatMap},
169   {"MaxPoolV3", DataFormatMap},
170   {"FusedBatchNorm", DataFormatMap},
171   {"DeformableConv2d", DataFormatMap}};
172 
CheckPrimAttrConverted(const std::string & op_name)173 bool CheckAndConvertUtils::CheckPrimAttrConverted(const std::string &op_name) {
174   return PrimAttrConvertMap.find(op_name) != PrimAttrConvertMap.end();
175 }
176 
GetDataFormatEnumValue(const ValuePtr & value,int64_t * enum_value)177 bool CheckAndConvertUtils::GetDataFormatEnumValue(const ValuePtr &value, int64_t *enum_value) {
178   MS_EXCEPTION_IF_NULL(value);
179   if (!value->isa<StringImm>()) {
180     *enum_value = GetValue<int64_t>(value);
181     return true;
182   }
183   auto attr_value_str = GetValue<std::string>(value);
184   auto iter = DataFormatToEnumMap.find(attr_value_str);
185   if (iter == DataFormatToEnumMap.end()) {
186     MS_LOG(DEBUG) << "The data format " << attr_value_str << " not be converted to enum.";
187     return false;
188   }
189   *enum_value = iter->second;
190   return true;
191 }
192 
GetPadModEnumValue(const ValuePtr & value,int64_t * enum_value,bool is_upper)193 void CheckAndConvertUtils::GetPadModEnumValue(const ValuePtr &value, int64_t *enum_value, bool is_upper) {
194   MS_EXCEPTION_IF_NULL(value);
195   if (!value->isa<StringImm>()) {
196     *enum_value = GetValue<int64_t>(value);
197     return;
198   }
199   auto attr_value_str = GetValue<std::string>(value);
200 
201   if (is_upper) {
202     auto iter = PadModToEnumUpperMap.find(attr_value_str);
203     if (iter == PadModToEnumUpperMap.end()) {
204       MS_LOG(EXCEPTION) << "Invalid pad mode " << attr_value_str << " use pad, valid or same";
205     }
206     *enum_value = iter->second;
207     return;
208   }
209   auto iter = PadModToEnumMap.find(attr_value_str);
210   if (iter == PadModToEnumMap.end()) {
211     MS_LOG(EXCEPTION) << "Invalid pad mode " << attr_value_str << " use pad, valid or same";
212   }
213   *enum_value = iter->second;
214 }
215 
GetReductionEnumValue(const ValuePtr & value,int64_t * enum_value)216 void CheckAndConvertUtils::GetReductionEnumValue(const ValuePtr &value, int64_t *enum_value) {
217   MS_EXCEPTION_IF_NULL(value);
218   if (!value->isa<StringImm>()) {
219     *enum_value = GetValue<int64_t>(value);
220     return;
221   }
222   auto attr_value_str = GetValue<std::string>(value);
223   auto iter = ReductionToEnumMap.find(attr_value_str);
224   if (iter == ReductionToEnumMap.end()) {
225     MS_LOG(EXCEPTION) << "Invalid pad mode " << attr_value_str << " use pad, valid or same";
226   }
227   *enum_value = iter->second;
228 }
229 
GetAttrConvertPair(const std::string & op_type,const std::string & attr_name)230 AttrConverterPair CheckAndConvertUtils::GetAttrConvertPair(const std::string &op_type, const std::string &attr_name) {
231   AttrConverterPair attr_pair;
232   if (op_type.empty() || attr_name.empty()) {
233     return attr_pair;
234   }
235   auto op_attr_map_it = PrimAttrConvertMap.find(op_type);
236   if (op_attr_map_it == PrimAttrConvertMap.end()) {
237     return attr_pair;
238   }
239   auto attr_pair_it = op_attr_map_it->second.find(attr_name);
240   if (attr_pair_it == op_attr_map_it->second.end()) {
241     return attr_pair;
242   }
243 
244   return attr_pair_it->second;
245 }
246 
ConvertAttrValueToInt(const std::string & op_type,const std::string & attr_name,ValuePtr * const value)247 bool CheckAndConvertUtils::ConvertAttrValueToInt(const std::string &op_type, const std::string &attr_name,
248                                                  ValuePtr *const value) {
249   if (value == nullptr || *value == nullptr) {
250     MS_LOG(DEBUG) << "value of attr " << op_type << attr_name << " is nullptr.";
251     return false;
252   }
253   if (!(*value)->isa<StringImm>()) {
254     return false;
255   }
256   auto attr_map_pair = GetAttrConvertPair(op_type, attr_name);
257   if (attr_map_pair.first.empty()) {
258     return false;
259   }
260 
261   std::string real_value = std::dynamic_pointer_cast<StringImm>(*value)->value();
262   bool do_convert = false;
263   if (attr_map_pair.first.find(real_value) != attr_map_pair.first.end()) {
264     do_convert = true;
265   }
266   if (!do_convert) {
267     transform(real_value.begin(), real_value.end(), real_value.begin(), ::toupper);
268     if (attr_map_pair.first.find(real_value) != attr_map_pair.first.end()) {
269       do_convert = true;
270     }
271   }
272   if (!do_convert) {
273     transform(real_value.begin(), real_value.end(), real_value.begin(), ::tolower);
274     if (attr_map_pair.first.find(real_value) == attr_map_pair.first.end()) {
275       MS_LOG(DEBUG) << "Can not convert " << op_type << " attr " << attr_name << ": " << real_value << " to int";
276       return false;
277     }
278   }
279   *value = MakeValue<int64_t>(attr_map_pair.first[real_value]);
280   MS_LOG(DEBUG) << "convert str to int, name: " << op_type << ", attr: " << attr_name;
281   return true;
282 }
283 
ConvertAttrValueToString(const std::string & op_type,const std::string & attr_name,ValuePtr * const value)284 bool CheckAndConvertUtils::ConvertAttrValueToString(const std::string &op_type, const std::string &attr_name,
285                                                     ValuePtr *const value) {
286   if (value == nullptr || *value == nullptr) {
287     MS_LOG(DEBUG) << "value is nullptr! op_type = " << op_type << ", attr_name = " << attr_name;
288     return false;
289   }
290   if (!(*value)->isa<Int64Imm>()) {
291     return false;
292   }
293   auto attr_map_pair = GetAttrConvertPair(op_type, attr_name);
294   if (attr_map_pair.second.empty()) {
295     return false;
296   }
297 
298   int64_t real_value = std::dynamic_pointer_cast<Int64Imm>(*value)->value();
299   if (attr_map_pair.second.find(real_value) == attr_map_pair.second.end()) {
300     MS_LOG(DEBUG) << "Can not convert " << op_type << " attr " << attr_name << ": " << real_value << " to string";
301     return false;
302   }
303   *value = MakeValue<std::string>(attr_map_pair.second[real_value]);
304   MS_LOG(DEBUG) << "convert int to str, name: " << op_type << ", attr: " << attr_name;
305   return true;
306 }
307 
GetFormatStringVal(const PrimitivePtr & prim,std::string * format)308 void CheckAndConvertUtils::GetFormatStringVal(const PrimitivePtr &prim, std::string *format) {
309   if (prim == nullptr || format == nullptr) {
310     MS_LOG(DEBUG) << "Prim or format is nullptr.";
311     return;
312   }
313   auto value_ptr = prim->GetAttr(ops::kFormat);
314   if (value_ptr == nullptr) {
315     MS_LOG(DEBUG) << "Val is nullptr! op type = " << prim->name();
316     return;
317   }
318   int64_t data_format;
319   bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value_ptr, &data_format);
320   if (result) {
321     if (DataFormatToStrMap.find(data_format) != DataFormatToStrMap.end()) {
322       *format = DataFormatToStrMap.at(data_format);
323     }
324   }
325 }
326 
CheckAbstractShapeSame(const std::vector<AbstractBasePtr> & abs_list)327 size_t CheckAndConvertUtils::CheckAbstractShapeSame(const std::vector<AbstractBasePtr> &abs_list) {
328   if (abs_list.size() <= 1) {
329     return 0;
330   }
331   const auto &first_elem_abs = abs_list[0];
332   MS_EXCEPTION_IF_NULL(first_elem_abs);
333   auto abs1_shape = first_elem_abs->GetShape();
334   MS_EXCEPTION_IF_NULL(abs1_shape);
335   for (size_t i = 0; i < abs_list.size(); ++i) {
336     MS_EXCEPTION_IF_NULL(abs_list[i]);
337     auto abs2_shape = abs_list[i]->GetShape();
338     MS_EXCEPTION_IF_NULL(abs2_shape);
339     if (*abs1_shape != *abs2_shape) {
340       MS_LOG(ERROR) << "Abstract shapes are not same, shape1:" << abs1_shape->ToString()
341                     << ", shape2:" << abs2_shape->ToString();
342       return i;
343     }
344   }
345   return 0;
346 }
347 
348 // For example,
349 // TensorType(element type is float16) and TensorType(element type is int32) are not same,
350 // TupleType(elements num is 3) and TupleType(elements num is 4) are same.
CheckAbstractTypeSame(const std::vector<AbstractBasePtr> & abs_list)351 size_t CheckAndConvertUtils::CheckAbstractTypeSame(const std::vector<AbstractBasePtr> &abs_list) {
352   if (abs_list.size() <= 1) {
353     return 0;
354   }
355   const auto &first_elem_abs = abs_list[0];
356   MS_EXCEPTION_IF_NULL(first_elem_abs);
357   auto abs1_type = first_elem_abs->BuildType();
358   MS_EXCEPTION_IF_NULL(abs1_type);
359   for (size_t i = 1; i < abs_list.size(); ++i) {
360     MS_EXCEPTION_IF_NULL(abs_list[i]);
361     auto abs2_type = abs_list[i]->BuildType();
362     MS_EXCEPTION_IF_NULL(abs2_type);
363     if (!(*abs1_type == *abs2_type)) {
364       MS_LOG(ERROR) << "Abstract types are not same, type1:" << abs1_type->ToString()
365                     << ", type2:" << abs2_type->ToString();
366       return i;
367     }
368   }
369   return 0;
370 }
371 
CheckAttrInt64Positive(const std::string & op,const ValuePtr & attr,const std::string & attr_name)372 int64_t CheckAndConvertUtils::CheckAttrInt64Positive(const std::string &op, const ValuePtr &attr,
373                                                      const std::string &attr_name) {
374   MS_EXCEPTION_IF_NULL(attr);
375   int64_t attr_val = attr->cast<Int64ImmPtr>()->value();
376   if (attr_val <= 0) {
377     MS_EXCEPTION(ValueError) << "For '" << op << "', the '" << attr_name
378                              << "' should be greater than 0, but got: " << attr_val << ".";
379   }
380   return attr_val;
381 }
382 
CheckAbstractTypeAndShapeSame(const std::vector<AbstractBasePtr> & abs_list,const std::string & precondition_log,const std::string & standard_abs_description,const std::string & differ_abs_description)383 void CheckAndConvertUtils::CheckAbstractTypeAndShapeSame(const std::vector<AbstractBasePtr> &abs_list,
384                                                          const std::string &precondition_log,
385                                                          const std::string &standard_abs_description,
386                                                          const std::string &differ_abs_description) {
387   if (abs_list.size() <= 1) {
388     return;
389   }
390   auto differ_index = CheckAndConvertUtils::CheckAbstractTypeSame(abs_list);
391   if (differ_index == 0) {
392     differ_index = CheckAndConvertUtils::CheckAbstractShapeSame(abs_list);
393   }
394   if (differ_index != 0) {
395     auto log_info1 = standard_abs_description.empty() ? "sequence[0] item" : standard_abs_description;
396     auto log_info2 =
397       differ_abs_description.empty() ? "sequence[" + std::to_string(differ_index) + "] item" : differ_abs_description;
398     MS_EXCEPTION(TypeError) << precondition_log << ", the " << log_info1 << " abstract '" << abs_list[0]->ToString()
399                             << "' is not same with the " << log_info2 << " abstract '"
400                             << abs_list[differ_index]->ToString() << "'.";
401   }
402 }
403 
CheckElementAbstractUnSupport(const AbstractBasePtr abs)404 bool CheckElementAbstractUnSupport(const AbstractBasePtr abs) {
405   if (abs == nullptr) {
406     return false;
407   }
408   if (abs->isa<abstract::AbstractSequence>() && !abs->isa<abstract::AbstractSparseTensor>()) {
409     abstract::AbstractSequencePtr seq = abs->cast<abstract::AbstractSequencePtr>();
410     if (seq->dynamic_len()) {
411       auto elem = seq->dynamic_len_element_abs();
412       if (elem->BuildType()->type_id() == kNumberTypeInt64) {
413         return false;
414       }
415     } else {
416       auto elements = seq->elements();
417       if (elements.empty() || std::all_of(elements.cbegin(), elements.cend(), [](const AbstractBasePtr &elem) {
418             return elem->BuildType()->type_id() == kNumberTypeInt64;
419           })) {
420         return false;
421       }
422     }
423     return true;
424   }
425   if (abs->isa<abstract::AbstractDictionary>()) {
426     return true;
427   }
428   if (abs->isa<abstract::AbstractAny>()) {
429     return true;
430   }
431   auto abs_type = abs->BuildType();
432   if (abs_type != nullptr && abs_type->isa<External>()) {
433     return true;
434   }
435   return false;
436 }
437 
CheckContainNestedOrIrregularSequence(const std::vector<AbstractBasePtr> & abs_list)438 bool CheckAndConvertUtils::CheckContainNestedOrIrregularSequence(const std::vector<AbstractBasePtr> &abs_list) {
439   // Check input abs has nested sequence, or irregular sequence,
440   // such as sequence contains elements with different shape or type.
441   for (auto abs : abs_list) {
442     if (abs == nullptr) {
443       continue;
444     }
445     if (abs->isa<abstract::AbstractDictionary>()) {
446       return true;
447     }
448     if (!abs->isa<abstract::AbstractSequence>() || abs->isa<abstract::AbstractSparseTensor>()) {
449       continue;
450     }
451     auto abs_seq = abs->cast<abstract::AbstractSequencePtr>();
452     if (abs_seq->dynamic_len()) {
453       if (CheckElementAbstractUnSupport(abs_seq->dynamic_len_element_abs())) {
454         return true;
455       }
456       continue;
457     }
458     const auto &elements = abs_seq->elements();
459     if (elements.size() == 0) {
460       continue;
461     }
462     auto first_element = elements[0];
463     MS_EXCEPTION_IF_NULL(first_element);
464     if (CheckElementAbstractUnSupport(first_element)) {
465       return true;
466     }
467     auto first_element_shape = first_element->GetShape();
468     MS_EXCEPTION_IF_NULL(first_element_shape);
469     auto first_element_type = first_element->BuildType();
470     MS_EXCEPTION_IF_NULL(first_element_type);
471     auto first_element_type_id = first_element_type->generic_type_id();
472     for (size_t i = 1; i < elements.size(); ++i) {
473       auto cur_element = elements[i];
474       MS_EXCEPTION_IF_NULL(cur_element);
475       auto cur_element_type = cur_element->BuildType();
476       MS_EXCEPTION_IF_NULL(cur_element_type);
477       auto cur_element_type_id = cur_element_type->generic_type_id();
478       if (first_element_type_id != cur_element_type_id) {
479         return true;
480       }
481       auto cur_element_shape = cur_element->GetShape();
482       MS_EXCEPTION_IF_NULL(cur_element_shape);
483       if (*first_element_shape != *cur_element_shape) {
484         return true;
485       }
486       try {
487         // cppcheck-suppress unreadVariable
488         MS_LOG_TRY_CATCH_SCOPE;
489         (void)first_element->Join(cur_element);
490       } catch (std::exception &) {
491         return true;
492       }
493     }
494   }
495   return false;
496 }
497 
// Returns a new sequence of the same kind (AbstractList or otherwise AbstractTuple), carrying the same
// sequence_nodes(), in which every element is broadened. Nested sequences are handled recursively.
// Scalars are cloned and marked variable before Broaden(). Throws (MS_EXCEPTION_IF_NULL) on a null
// sequence, a null element, or a null clone.
abstract::AbstractSequencePtr CheckAndConvertUtils::BroadenAllSequenceElements(
  const abstract::AbstractSequencePtr &sequence) {
  MS_EXCEPTION_IF_NULL(sequence);
  const auto &elements = sequence->elements();
  AbstractBasePtrList new_elements;
  new_elements.reserve(elements.size());
  for (const auto &element : elements) {
    // Previously a null element was dereferenced directly.
    MS_EXCEPTION_IF_NULL(element);
    AbstractBasePtr new_element = nullptr;
    if (element->isa<abstract::AbstractSequence>()) {
      new_element = BroadenAllSequenceElements(element->cast<abstract::AbstractSequencePtr>());
    } else {
      auto tmp_element = element->Clone();
      MS_EXCEPTION_IF_NULL(tmp_element);
      if (element->isa<abstract::AbstractScalar>()) {
        tmp_element->cast<abstract::AbstractScalarPtr>()->set_is_variable(true);
      }
      new_element = tmp_element->Broaden();
    }
    new_elements.push_back(std::move(new_element));
  }
  if (sequence->isa<abstract::AbstractList>()) {
    return std::make_shared<abstract::AbstractList>(new_elements, sequence->sequence_nodes());
  }
  return std::make_shared<abstract::AbstractTuple>(new_elements, sequence->sequence_nodes());
}
521 
CheckValueSame(const ValuePtr & value_1,const ValuePtr & value_2)522 bool CheckAndConvertUtils::CheckValueSame(const ValuePtr &value_1, const ValuePtr &value_2) {
523   MS_EXCEPTION_IF_NULL(value_1);
524   MS_EXCEPTION_IF_NULL(value_2);
525   if (!value_1->IsSameTypeId(value_2->tid())) {
526     return false;
527   }
528   if (value_1->isa<tensor::BaseTensor>()) {
529     auto list_tensor_value = value_2->cast_ptr<tensor::BaseTensor>();
530     return value_1->cast_ptr<tensor::BaseTensor>()->ValueEqual(*list_tensor_value);
531   }
532   return *value_1 == *value_2;
533 }
534 
ConvertAttrValueInExport(const std::string & op_type,const std::string & attr_name,ValuePtr * const value)535 void CheckAndConvertUtils::ConvertAttrValueInExport(const std::string &op_type, const std::string &attr_name,
536                                                     ValuePtr *const value) {
537   if (value == nullptr || *value == nullptr) {
538     MS_LOG(DEBUG) << "value is nullptr! op_type = " << op_type << ", attr_name = " << attr_name;
539     return;
540   }
541   // convert enum to string
542   ConvertAttrValueToString(op_type, attr_name, value);
543 }
544 
ConvertAttrValueInLoad(const std::string & op_type,const std::string & attr_name,ValuePtr * const value)545 void CheckAndConvertUtils::ConvertAttrValueInLoad(const std::string &op_type, const std::string &attr_name,
546                                                   ValuePtr *const value) {
547   if (value == nullptr || *value == nullptr) {
548     MS_LOG(DEBUG) << "value is nullptr! op_type = " << op_type << ", attr_name = " << attr_name;
549     return;
550   }
551   // convert string to enum
552   ConvertAttrValueToInt(op_type, attr_name, value);
553 }
554 
555 namespace {
556 typedef std::map<std::string, std::function<ValuePtr(ValuePtr)>> AttrFunction;
557 
L2NormalizeAttrConversion(ValuePtr attr)558 ValuePtr L2NormalizeAttrConversion(ValuePtr attr) {
559   if (attr->isa<Int64Imm>()) {
560     return attr;
561   }
562   auto attr_value = GetValue<std::vector<int64_t>>(attr);
563   return MakeValue(attr_value[0]);
564 }
565 
566 std::map<std::string, AttrFunction> kIrAttrToOpAttr = {{"L2Normalize", {{"axis", L2NormalizeAttrConversion}}},
567                                                        {"L2NormalizeGrad", {{"axis", L2NormalizeAttrConversion}}}};
CheckType(const TypePtr & check_type,const std::set<TypePtr> & template_types)568 inline bool CheckType(const TypePtr &check_type, const std::set<TypePtr> &template_types) {
569   return std::any_of(template_types.begin(), template_types.end(), [&check_type](const TypePtr &accept) -> bool {
570     return IsIdentidityOrSubclass(check_type, accept);
571   });
572 }
573 }  // namespace
574 
CheckString(const std::string & arg_name,const std::string & arg_value,const std::set<std::string> & check_list,const std::string & prim_name)575 std::string CheckAndConvertUtils::CheckString(const std::string &arg_name, const std::string &arg_value,
576                                               const std::set<std::string> &check_list, const std::string &prim_name) {
577   if (check_list.find(arg_value) != check_list.end()) {
578     return arg_value;
579   }
580   std::ostringstream buffer;
581   buffer << "For primitive[" << prim_name << "], the attribute[" << arg_name << "]";
582   if (check_list.size() == 1) {
583     buffer << " must be \"" << (*check_list.begin()) << "\", but got \"" << arg_value << "\".";
584     MS_EXCEPTION(ValueError) << buffer.str();
585   }
586   buffer << " should be a element of {";
587   for (const auto &item : check_list) {
588     buffer << "\"" << item << "\", ";
589   }
590   buffer << "}"
591          << ",but got \"" << arg_value << "\""
592          << ".";
593   MS_EXCEPTION(ValueError) << buffer.str();
594 }
595 
CheckInteger(const std::string & arg_name,int64_t arg_value,CompareEnum compare_operator,int64_t match_value,const std::string & prim_name)596 int64_t CheckAndConvertUtils::CheckInteger(const std::string &arg_name, int64_t arg_value, CompareEnum compare_operator,
597                                            int64_t match_value, const std::string &prim_name) {
598   auto iter = kCompareMap<float>.find(compare_operator);
599   if (iter == kCompareMap<float>.end()) {
600     MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator << " cannot find in the compare map";
601   }
602   if (iter->second(arg_value, match_value)) {
603     return arg_value;
604   }
605   std::ostringstream buffer;
606   if (prim_name.empty()) {
607     buffer << "The argument[" << arg_name << "] must ";
608   } else {
609     buffer << "For primitive[" << prim_name << "], the " << arg_name << " must ";
610   }
611   auto iter_to_string = kCompareToString.find(compare_operator);
612   if (iter_to_string == kCompareToString.end()) {
613     MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator << " cannot find in the compare string map";
614   }
615   buffer << iter_to_string->second << match_value << ", but got " << arg_value << ".";
616   MS_EXCEPTION(ValueError) << buffer.str();
617 }
618 
FormatCheckMsg(const std::string & arg_name,const std::vector<int64_t> & arg_value,CompareEnum compare_type,const std::vector<int64_t> & value,const PrimitivePtr & prim)619 std::string CheckAndConvertUtils::FormatCheckMsg(const std::string &arg_name, const std::vector<int64_t> &arg_value,
620                                                  CompareEnum compare_type, const std::vector<int64_t> &value,
621                                                  const PrimitivePtr &prim) {
622   std::ostringstream buffer;
623   if (prim == nullptr) {
624     buffer << "The attribute[" << arg_name << "]:";
625   } else {
626     buffer << "For primitive[" << prim->name() << "], the " << arg_name << ":";
627   }
628   auto iter_to_string = kCompareToString.find(compare_type);
629   if (iter_to_string == kCompareToString.end()) {
630     MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_type << " cannot find in the compare string map";
631   }
632 
633   buffer << " [";
634   for (auto item : arg_value) {
635     buffer << item << ",";
636   }
637   buffer << "]";
638   buffer << " must " << iter_to_string->second << "[";
639   for (auto item : value) {
640     buffer << item << ",";
641   }
642   buffer << "]";
643   return buffer.str();
644 }
645 
CheckInputArgs(const std::vector<AbstractBasePtr> & input_args,const CompareEnum compare_operator,const int64_t match_value,const std::string & prim_name)646 void CheckAndConvertUtils::CheckInputArgs(const std::vector<AbstractBasePtr> &input_args,
647                                           const CompareEnum compare_operator, const int64_t match_value,
648                                           const std::string &prim_name) {
649   (void)CheckInteger("input number", SizeToLong(input_args.size()), compare_operator, match_value, prim_name);
650   for (size_t index = 0; index < input_args.size(); index++) {
651     if (input_args[index] == nullptr) {
652       MS_EXCEPTION(ValueError) << "The " << index << "'s input of " << prim_name << " is nullptr.";
653     }
654   }
655 }
656 
ConvertShapePtrToShapeMap(const BaseShapePtr & shape)657 ShapeMap CheckAndConvertUtils::ConvertShapePtrToShapeMap(const BaseShapePtr &shape) {
658   MS_EXCEPTION_IF_NULL(shape);
659   if (!shape->isa<abstract::Shape>()) {
660     return std::map<std::string, std::vector<int64_t>>();
661   }
662   auto shape_element = shape->cast<abstract::ShapePtr>();
663   MS_EXCEPTION_IF_NULL(shape_element);
664   ShapeMap shape_map;
665   shape_map[kShape] = shape_element->shape();
666   shape_map[kMaxShape] = shape_element->max_shape();
667   return shape_map;
668 }
669 
GetTensorInputShape(const std::string & prim_name,const std::vector<AbstractBasePtr> & input_args,size_t index)670 abstract::ShapePtr CheckAndConvertUtils::GetTensorInputShape(const std::string &prim_name,
671                                                              const std::vector<AbstractBasePtr> &input_args,
672                                                              size_t index) {
673   auto abstract = CheckAndConvertUtils::CheckArgsType(prim_name, input_args, index, kObjectTypeTensorType);
674   MS_EXCEPTION_IF_NULL(abstract);
675   auto base_shape = abstract->GetShape();
676   MS_EXCEPTION_IF_NULL(base_shape);
677   if (!base_shape->isa<abstract::TensorShape>()) {
678     MS_LOG(EXCEPTION) << prim_name << " can not get shape for input " << index;
679   }
680   auto shape = base_shape->cast<abstract::TensorShapePtr>();
681   MS_EXCEPTION_IF_NULL(shape);
682   return shape;
683 }
684 
GetTensorInputType(const std::string & prim_name,const std::vector<AbstractBasePtr> & input_args,size_t index)685 TypePtr CheckAndConvertUtils::GetTensorInputType(const std::string &prim_name,
686                                                  const std::vector<AbstractBasePtr> &input_args, size_t index) {
687   if (input_args.size() <= index) {
688     MS_EXCEPTION(ValueError) << "For " << prim_name << ", the index " << index << " is out of the input number "
689                              << input_args.size();
690   }
691   auto input_arg = input_args[index];
692   if (input_arg == nullptr) {
693     MS_EXCEPTION(ValueError) << "The " << index << "'s input of " << prim_name << " is nullptr.";
694   }
695   auto base_type = input_arg->GetType();
696   MS_EXCEPTION_IF_NULL(base_type);
697   if (!base_type->isa<TensorType>()) {
698     MS_EXCEPTION(TypeError) << "The " << index << "'s input type of " << prim_name << " is not Tensor.";
699   }
700   auto tensor_type = base_type->cast<TensorTypePtr>();
701   MS_EXCEPTION_IF_NULL(tensor_type);
702   auto type = tensor_type->element();
703   MS_EXCEPTION_IF_NULL(type);
704   return type;
705 }
706 
Check(const string & arg_name,int64_t arg_value,CompareEnum compare_type,int64_t value,const string & prim_name,ExceptionType)707 void CheckAndConvertUtils::Check(const string &arg_name, int64_t arg_value, CompareEnum compare_type, int64_t value,
708                                  const string &prim_name, ExceptionType) {
709   auto iter = kCompareMap<float>.find(compare_type);
710   if (iter == kCompareMap<float>.end()) {
711     MS_EXCEPTION(NotExistsError) << "the compare type :" << compare_type << " is not in the compare map";
712   }
713   if (iter->second(arg_value, value)) {
714     return;
715   }
716   std::ostringstream buffer;
717   if (prim_name.empty()) {
718     buffer << "The attribute[" << arg_name << "] must ";
719   } else {
720     buffer << "For primitive[" << prim_name << "], the attribute[" << arg_name << "] must ";
721   }
722   auto iter_to_string = kCompareToString.find(compare_type);
723   if (iter_to_string == kCompareToString.end()) {
724     MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_type << " cannot find in the compare string map";
725   }
726   buffer << iter_to_string->second << value << ", but got " << arg_value << ".";
727   MS_EXCEPTION(ValueError) << buffer.str();
728 }
729 
CheckTensorTypeSame(const std::map<std::string,TypePtr> & types,const std::set<TypePtr> & check_list,const std::string & prim_name)730 TypePtr CheckAndConvertUtils::CheckTensorTypeSame(const std::map<std::string, TypePtr> &types,
731                                                   const std::set<TypePtr> &check_list, const std::string &prim_name) {
732   if (types.empty()) {
733     MS_EXCEPTION(ArgumentError) << "Trying to use the function to check a empty types map!";
734   }
735   // Check Input type is tensor type
736   for (const auto &item : types) {
737     auto type = item.second;
738     MS_EXCEPTION_IF_NULL(type);
739     if (!type->isa<TensorType>()) {
740       size_t i = 1;
741       std::ostringstream buffer;
742       buffer << "The primitive[" << prim_name << "]'s input arguments[";
743       for (const auto &item_type : types) {
744         buffer << item_type.first;
745         if (i < types.size()) {
746           buffer << ", ";
747           ++i;
748         }
749       }
750       i = 1;
751       buffer << "] must be all tensor and those type must be same.";
752       for (const auto &type_info : types) {
753         if (!type_info.second->isa<TensorType>()) {
754           buffer << " But got input argument[" << type_info.first << "]"
755                  << ":" << type_info.second->ToString() << "\n";
756         }
757       }
758       if (!check_list.empty()) {
759         buffer << "Valid type list: {";
760         std::set<string> order_set;
761         for (auto const &valid_type : check_list) {
762           if (valid_type->isa<TensorType>()) {
763             (void)order_set.emplace(valid_type->ToString());
764             break;
765           } else {
766             (void)order_set.emplace("Tensor[" + valid_type->ToString() + "]");
767           }
768         }
769         for (auto const &error_item : order_set) {
770           buffer << error_item;
771           if (error_item != *(--order_set.end())) {
772             buffer << ", ";
773           }
774         }
775         buffer << "}.";
776       }
777       MS_EXCEPTION(TypeError) << buffer.str();
778     }
779   }
780   (void)CheckTypeSame(types, prim_name, false);
781   return CheckTensorSubClass(types.begin()->first, types.begin()->second, check_list, prim_name);
782 }
783 
CheckMathBinaryOpTensorType(const std::map<std::string,TypePtr> & types,const std::set<TypePtr> & check_list,const std::string & prim_name)784 TypePtr CheckAndConvertUtils::CheckMathBinaryOpTensorType(const std::map<std::string, TypePtr> &types,
785                                                           const std::set<TypePtr> &check_list,
786                                                           const std::string &prim_name) {
787   constexpr size_t n = 2;
788   if (types.size() != n) {
789     MS_EXCEPTION(ArgumentError) << "For primitive[" << prim_name << "], the size of types to check must be " << n
790                                 << ", but got " << types.size();
791   }
792   // Check Input type is tensor type
793   std::vector<TypeId> type_ids;
794   std::vector<TypePtr> type_ptr;
795   bool has_complex = false;
796   for (const auto &item : types) {
797     MS_EXCEPTION_IF_NULL(item.second);
798     if (!item.second->isa<TensorType>()) {
799       MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s input arguments[" << item.first
800                               << "] must be Tensor, but got " << item.second->ToString();
801     }
802     auto tensor_type = item.second->cast<TensorTypePtr>();
803     MS_EXCEPTION_IF_NULL(tensor_type);
804     auto element = tensor_type->element();
805     MS_EXCEPTION_IF_NULL(element);
806     auto type_id = element->type_id();
807     if (!has_complex && (type_id == kNumberTypeComplex64 || type_id == kNumberTypeComplex128)) {
808       has_complex = true;
809     }
810     type_ids.push_back(type_id);
811     type_ptr.push_back(item.second);
812   }
813   // Deal with complex data type
814   if (has_complex) {
815     static std::map<std::pair<TypeId, TypeId>, TypeId> type_infer_dict = {
816       {{kNumberTypeComplex64, kNumberTypeComplex64}, kNumberTypeComplex64},
817       {{kNumberTypeComplex64, kNumberTypeFloat32}, kNumberTypeComplex64},
818       {{kNumberTypeFloat32, kNumberTypeComplex64}, kNumberTypeComplex64},
819       {{kNumberTypeComplex128, kNumberTypeComplex128}, kNumberTypeComplex128},
820       {{kNumberTypeComplex128, kNumberTypeFloat64}, kNumberTypeComplex128},
821       {{kNumberTypeFloat64, kNumberTypeComplex128}, kNumberTypeComplex128}};
822     std::pair<TypeId, TypeId> type_info(type_ids[0], type_ids[1]);
823     auto iter = type_infer_dict.find(type_info);
824     if (iter != type_infer_dict.end()) {
825       return type_ids[0] == iter->second ? type_ptr[0] : type_ptr[1];
826     }
827     std::ostringstream buffer;
828     buffer << "For primitive[" << prim_name << "], complex math binary op expecting Tensor";
829     for (const auto &items : type_infer_dict) {
830       buffer << "[" << TypeIdToString(items.first.first) << ", " << TypeIdToString(items.first.second) << "], ";
831     }
832     buffer << "but got Tensor[" << TypeIdToString(type_ids[0]) << ", " << TypeIdToString(type_ids[1]) << "]";
833     MS_EXCEPTION(TypeError) << buffer.str();
834   }
835   // Deal with non-complex data type
836   if (type_ids[0] != type_ids[1]) {
837     MS_EXCEPTION(TypeError) << "For primitive[" << prim_name
838                             << "], the input arguments must have same data type, but got Tensor["
839                             << TypeIdToString(type_ids[0]) << "] and Tensor[" << TypeIdToString(type_ids[1]) << "]";
840   }
841   (void)CheckTensorSubClass(types.begin()->first, types.begin()->second, check_list, prim_name);
842   return types.begin()->second;
843 }
844 
CheckTensorShapeSame(const std::map<std::string,BaseShapePtr> & shapes,const std::vector<int64_t> & check_shape,const std::string & prim_name)845 ShapeVector CheckAndConvertUtils::CheckTensorShapeSame(const std::map<std::string, BaseShapePtr> &shapes,
846                                                        const std::vector<int64_t> &check_shape,
847                                                        const std::string &prim_name) {
848   if (shapes.empty()) {
849     MS_EXCEPTION(ArgumentError) << "Trying to use the function to check a empty shapes map!";
850   }
851   for (const auto &shape : shapes) {
852     auto _shape_ptr_ = shape.second;
853     MS_EXCEPTION_IF_NULL(_shape_ptr_);
854     auto _shape_ = ConvertShapePtrToShapeMap(_shape_ptr_)[kShape];
855     (void)CheckPositiveVectorExcludeZero(shape.first, _shape_, prim_name);
856     if (!ShapeVectorIsSame(_shape_, check_shape)) {
857       std::ostringstream buffer;
858       buffer << "The primitive[" << prim_name << "]'s input arguments " << shape.first << " shape should equal to "
859              << ShapeVectorToStr(check_shape) << ", but get the real shape " << ShapeVectorToStr(_shape_) << ".";
860       MS_EXCEPTION(ValueError) << buffer.str();
861     }
862   }
863   return check_shape;
864 }
865 
// Validates that `type` is a TensorType whose element type is accepted by `check_list`.
// A TensorType entry in `check_list` whose element is null accepts any element type; in that case the input's
// element type is returned directly. Otherwise the result of CheckTensorSubClass is returned.
// Note that the returned type is the element type, not the input TensorType.
TypePtr CheckAndConvertUtils::CheckTensorTypeValid(const std::string &type_name, const TypePtr &type,
                                                   const std::set<TypePtr> &check_list, const std::string &prim_name) {
  MS_EXCEPTION_IF_NULL(type);
  if (!type->isa<TensorType>()) {
    MS_EXCEPTION(TypeError) << "For Primitive[" << prim_name << "], the type of input argument[" << type_name
                            << "] must be Tensor but got " << type->ToString() << ".";
  }
  auto element = type->cast<TensorTypePtr>()->element();
  MS_EXCEPTION_IF_NULL(element);
  const bool accepts_any_tensor = std::any_of(check_list.begin(), check_list.end(), [](const TypePtr &item) {
    return item->isa<TensorType>() && item->cast<TensorTypePtr>()->element() == nullptr;
  });
  if (accepts_any_tensor) {
    return element;
  }
  return CheckTensorSubClass(type_name, type, check_list, prim_name);
}
887 
// Returns the element type of a sparse tensor type (CSRTensor/COOTensor).
// The third (valid-type set) argument is currently unused. Throws TypeError when `type` is not a SparseTensorType or
// the cast to SparseTensorTypePtr yields null.
TypePtr CheckAndConvertUtils::CheckSparseTensorTypeValid(const std::string &type_name, const TypePtr &type,
                                                         const std::set<TypePtr> &, const std::string &prim_name) {
  MS_EXCEPTION_IF_NULL(type);
  if (type->isa<SparseTensorType>()) {
    if (auto sparse_type = type->cast<SparseTensorTypePtr>(); sparse_type != nullptr) {
      return sparse_type->element_type();
    }
    MS_EXCEPTION(TypeError) << "For Primitive[" << prim_name << "], the input argument[" << type_name
                            << "] cast to SparseTensorTypePtr failed! Get type : " << type->ToString() << ".";
  }
  MS_EXCEPTION(TypeError) << "For Primitive[" << prim_name << "], the input argument[" << type_name
                          << "] must be a CSRTensor or COOTensor, but got " << type->ToString() << ".";
}
903 
CheckTensorIntValue(const std::string & tensor_name,const ValuePtr & value,const std::string & prim_name)904 ShapeVector CheckAndConvertUtils::CheckTensorIntValue(const std::string &tensor_name, const ValuePtr &value,
905                                                       const std::string &prim_name) {
906   if (value == nullptr) {
907     MS_EXCEPTION(ValueError) << "For primitive[" << prim_name << "], the input argument[" << tensor_name
908                              << "] value is nullptr.";
909   }
910   ShapeVector tensor_value;
911   if (!value->isa<tensor::BaseTensor>()) {
912     MS_EXCEPTION(ValueError) << "For primitive[" << prim_name << "], the input argument[" << tensor_name
913                              << "] must be a tensor, but got " << value->ToString();
914   }
915   auto input_tensor = value->cast<tensor::BaseTensorPtr>();
916   MS_EXCEPTION_IF_NULL(input_tensor);
917   size_t data_size = input_tensor->DataSize();
918   auto tensor_type = input_tensor->Dtype();
919   if (tensor_type->type_id() == kNumberTypeInt32) {
920     auto data_c = reinterpret_cast<int *>(input_tensor->data_c());
921     MS_EXCEPTION_IF_NULL(data_c);
922     for (size_t i = 0; i < data_size; i++) {
923       tensor_value.push_back(static_cast<int64_t>(*data_c));
924       ++data_c;
925     }
926   } else if (tensor_type->type_id() == kNumberTypeInt64) {
927     auto tensor_data = reinterpret_cast<int64_t *>(input_tensor->data_c());
928     MS_EXCEPTION_IF_NULL(tensor_data);
929     tensor_value = {tensor_data, tensor_data + data_size};
930   } else if (tensor_type->type_id() == kNumberTypeUInt32) {
931     auto tensor_data = reinterpret_cast<uint32_t *>(input_tensor->data_c());
932     MS_EXCEPTION_IF_NULL(tensor_data);
933     tensor_value = {tensor_data, tensor_data + data_size};
934   } else if (tensor_type->type_id() == kNumberTypeUInt64) {
935     auto tensor_data = reinterpret_cast<uint64_t *>(input_tensor->data_c());
936     MS_EXCEPTION_IF_NULL(tensor_data);
937     tensor_value = {tensor_data, tensor_data + data_size};
938   } else {
939     MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the input argument[" << tensor_name
940                             << "] must be a Tensor[Int64] or Tensor[Int32]"
941                             << " or Tensor[UInt64] or Tensor[UInt32] type, but got " << value->ToString();
942   }
943   return tensor_value;
944 }
945 
CheckTensorIntValue(const std::string & tensor_name,const ValuePtr & value,const std::string & prim_name,const TypePtr & type)946 ShapeVector CheckAndConvertUtils::CheckTensorIntValue(const std::string &tensor_name, const ValuePtr &value,
947                                                       const std::string &prim_name, const TypePtr &type) {
948   if (value == nullptr) {
949     MS_EXCEPTION(ValueError) << "For primitive[" << prim_name << "], the input argument[" << tensor_name
950                              << "] value is nullptr.";
951   }
952   if (value->isa<ValueAny>() || value->isa<None>()) {
953     MS_EXCEPTION(ValueError) << "For primitive[" << prim_name << "], the input argument[" << tensor_name
954                              << "] value is unknown.";
955   }
956   if (type->object_type() != kObjectTypeTensorType) {
957     MS_EXCEPTION(ValueError) << "For primitive[" << prim_name << "], the input argument[" << tensor_name
958                              << "] must be a tensor, but got " << type->ToString();
959   }
960   ShapeVector tensor_value;
961   auto tensor_type_ptr = type->cast<TensorTypePtr>();
962   MS_EXCEPTION_IF_NULL(tensor_type_ptr);
963   auto tensor_type = tensor_type_ptr->element()->type_id();
964   if (tensor_type == kNumberTypeInt32) {
965     auto data_opt = ops::GetArrayValue<int>(value);
966     const auto &data = data_opt.value();
967     for (size_t i = 0; i < data.size(); i++) {
968       tensor_value.push_back(static_cast<int64_t>(data[i]));
969     }
970   } else if (tensor_type == kNumberTypeInt64) {
971     auto data_opt = ops::GetArrayValue<int64_t>(value);
972     tensor_value = data_opt.value().ToVector();
973   } else if (tensor_type == kNumberTypeUInt32) {
974     auto data_opt = ops::GetArrayValue<uint32_t>(value);
975     const auto &data = data_opt.value();
976     for (size_t i = 0; i < data.size(); i++) {
977       tensor_value.push_back(static_cast<int64_t>(data[i]));
978     }
979   } else if (tensor_type == kNumberTypeUInt64) {
980     auto data_opt = ops::GetArrayValue<uint64_t>(value);
981     const auto &data = data_opt.value();
982     for (size_t i = 0; i < data.size(); i++) {
983       tensor_value.push_back(static_cast<int64_t>(data[i]));
984     }
985   } else {
986     MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the input argument[" << tensor_name
987                             << "] must be a Tensor[Int64] or Tensor[Int32]"
988                             << " or Tensor[UInt64] or Tensor[UInt32] type, but got " << value->ToString();
989   }
990   return tensor_value;
991 }
992 
// Unwraps a TensorType to its element type and returns it when CheckType accepts it against `template_types`.
// Otherwise throws TypeError listing the accepted types in sorted order. TensorType entries are listed as-is, and
// other entries are listed as "Tensor[<type>]". With `is_mix` the plain spelling of every entry is listed as well.
TypePtr CheckAndConvertUtils::CheckTensorSubClass(const string &type_name, const TypePtr &type,
                                                  const std::set<TypePtr> &template_types, const string &prim_name,
                                                  bool is_mix) {
  MS_EXCEPTION_IF_NULL(type);
  TypePtr real_type = type->isa<TensorType>() ? type->cast<TensorTypePtr>()->element() : type;
  if (CheckType(real_type, template_types)) {
    return real_type;
  }
  std::set<string> candidates;
  for (const auto &item : template_types) {
    const std::string plain = item->ToString();
    if (is_mix) {
      (void)candidates.emplace(plain);
    }
    (void)candidates.emplace(item->isa<TensorType>() ? plain : "Tensor[" + plain + "]");
  }
  std::ostringstream buffer;
  buffer << "For primitive[" << prim_name << "], the input argument[" << type_name << "] must be a type of {";
  const char *separator = "";
  for (const auto &candidate : candidates) {
    buffer << separator << candidate;
    separator = ", ";
  }
  buffer << "}, but got " << type->ToString() << ".";
  MS_EXCEPTION(TypeError) << buffer.str();
}
1033 
// Returns `type` when CheckType accepts it against `template_types`; otherwise throws TypeError listing the accepted
// types (sorted by their string form).
TypePtr CheckAndConvertUtils::CheckSubClass(const std::string &type_name, const TypePtr &type,
                                            const std::set<TypePtr> &template_types, const std::string &prim_name) {
  if (CheckType(type, template_types)) {
    return type;
  }
  std::set<string> sorted_names;
  (void)std::transform(template_types.begin(), template_types.end(),
                       std::inserter(sorted_names, sorted_names.end()),
                       [](const TypePtr &item) { return item->ToString(); });
  std::ostringstream buffer;
  buffer << "For primitive[" << prim_name << "], the input argument[" << type_name << "] must be a type of {";
  bool first = true;
  for (const auto &name : sorted_names) {
    if (!first) {
      buffer << ", ";
    }
    buffer << name;
    first = false;
  }
  buffer << "}, but got " << type->ToString() << ".";
  MS_EXCEPTION(TypeError) << buffer.str();
}
1055 
// Same contract as CheckSubClass, but the TypeError message also carries `more_info` right after the argument name.
TypePtr CheckAndConvertUtils::CheckSubClassWithMoreInfo(const std::string &type_name, const TypePtr &type,
                                                        const std::string &more_info,
                                                        const std::set<TypePtr> &template_types,
                                                        const std::string &prim_name) {
  if (CheckType(type, template_types)) {
    return type;
  }
  std::set<string> sorted_names;
  for (const auto &candidate : template_types) {
    (void)sorted_names.insert(candidate->ToString());
  }
  std::ostringstream buffer;
  buffer << "For primitive[" << prim_name << "], the input argument[" << type_name << "] " << more_info
         << " must be a type of {";
  for (auto it = sorted_names.begin(); it != sorted_names.end(); ++it) {
    if (it != sorted_names.begin()) {
      buffer << ", ";
    }
    buffer << *it;
  }
  buffer << "}, but got " << type->ToString() << ".";
  MS_EXCEPTION(TypeError) << buffer.str();
}
1080 
CheckScalarOrTensorTypesSame(const std::map<std::string,TypePtr> & args,const std::set<TypePtr> & valid_values,const std::string & prim_name,bool allow_mix)1081 TypePtr CheckAndConvertUtils::CheckScalarOrTensorTypesSame(const std::map<std::string, TypePtr> &args,
1082                                                            const std::set<TypePtr> &valid_values,
1083                                                            const std::string &prim_name, bool allow_mix) {
1084   (void)CheckTypeSame(args, prim_name, allow_mix);
1085   return CheckTensorSubClass(args.begin()->first, args.begin()->second, valid_values, prim_name, true);
1086 }
1087 
CheckTypeSame(const std::map<std::string,TypePtr> & args,const std::string & prim_name,const bool allow_mix)1088 TypePtr CheckAndConvertUtils::CheckTypeSame(const std::map<std::string, TypePtr> &args, const std::string &prim_name,
1089                                             const bool allow_mix) {
1090   if (args.empty()) {
1091     MS_EXCEPTION(ArgumentError) << "Trying to use the function to check a empty types map!";
1092   }
1093   std::ostringstream buffer;
1094   TypePtr return_type = args.begin()->second;
1095   buffer << "For primitive[" << prim_name << "], the ";
1096   bool tensor_flag = return_type->isa<TensorType>();
1097   std::set<TypeId> types_id;
1098   for (const auto &elem : args) {
1099     auto type = elem.second;
1100     MS_EXCEPTION_IF_NULL(type);
1101     if (!allow_mix) {
1102       // input must be all tensor or all other type
1103       if ((tensor_flag && !type->isa<TensorType>()) || (!tensor_flag && type->isa<TensorType>())) {
1104         buffer << "input type must be same.\n";
1105         for (const auto &error_elem : args) {
1106           buffer << "input argument[" << error_elem.first << "]:" << error_elem.second->ToString() << "\n";
1107         }
1108         MS_EXCEPTION(TypeError) << buffer.str();
1109       }
1110     }
1111     if (type->isa<TensorType>()) {
1112       auto tensor_type = type->cast<TensorTypePtr>();
1113       auto element = tensor_type->element();
1114       MS_EXCEPTION_IF_NULL(element);
1115       return_type = element;
1116       (void)types_id.emplace(element->type_id());
1117     } else {
1118       if (return_type->isa<TensorType>()) {
1119         return_type = type;
1120       }
1121       (void)types_id.emplace(type->type_id());
1122     }
1123     if (types_id.size() > 1) {
1124       buffer << "input type must be same.\n";
1125       for (const auto &item : args) {
1126         buffer << "name:[" << item.first << "]:" << item.second->ToString() << ".\n";
1127       }
1128       MS_EXCEPTION(TypeError) << buffer.str();
1129     }
1130   }
1131   return return_type->DeepCopy();
1132 }
1133 
// Validates `arg_type` against `valid_type`. Tensor types are routed to CheckTensorTypeValid (which returns the
// element type); all other types are routed to CheckSubClass.
// Throws ArgumentError when `valid_type` is empty.
TypePtr CheckAndConvertUtils::CheckTypeValid(const std::string &arg_name, const TypePtr &arg_type,
                                             const std::set<TypePtr> &valid_type, const std::string &prim_name) {
  if (valid_type.empty()) {
    MS_EXCEPTION(ArgumentError) << "Trying to use the function to check a empty valid_type!";
  }
  MS_EXCEPTION_IF_NULL(arg_type);
  return arg_type->isa<TensorType>() ? CheckTensorTypeValid(arg_name, arg_type, valid_type, prim_name)
                                     : CheckSubClass(arg_name, arg_type, valid_type, prim_name);
}
1145 
// Like CheckTypeValid, but non-tensor failures include `more_info` in the error message.
// Tensor types go through CheckTensorTypeValid, which does not take `more_info`.
// Throws ArgumentError when `valid_type` is empty.
TypePtr CheckAndConvertUtils::CheckTypeValidWithMoreInfo(const std::string &arg_name, const TypePtr &arg_type,
                                                         const std::string &more_info,
                                                         const std::set<TypePtr> &valid_type,
                                                         const std::string &prim_name) {
  if (valid_type.empty()) {
    MS_EXCEPTION(ArgumentError) << "Trying to use the function to check a empty valid_type!";
  }
  MS_EXCEPTION_IF_NULL(arg_type);
  if (!arg_type->isa<TensorType>()) {
    return CheckSubClassWithMoreInfo(arg_name, arg_type, more_info, valid_type, prim_name);
  }
  return CheckTensorTypeValid(arg_name, arg_type, valid_type, prim_name);
}
1159 
CheckIrAttrtoOpAttr(const std::string & op_type,const std::string & attr_name,ValuePtr * const value)1160 bool CheckAndConvertUtils::CheckIrAttrtoOpAttr(const std::string &op_type, const std::string &attr_name,
1161                                                ValuePtr *const value) {
1162   if (*value == nullptr) {
1163     MS_LOG(DEBUG) << "value is nullptr! op_type = " << op_type << ", attr_name = " << attr_name;
1164     return false;
1165   }
1166   if (op_type.empty() || attr_name.empty()) {
1167     return false;
1168   }
1169   auto op_map = kIrAttrToOpAttr.find(op_type);
1170   if (op_map == kIrAttrToOpAttr.end()) {
1171     return false;
1172   }
1173   auto attr_func = op_map->second.find(attr_name);
1174   if (attr_func == op_map->second.end()) {
1175     return false;
1176   }
1177   *value = attr_func->second(*value);
1178   MS_LOG(DEBUG) << "convert ir attr to op attr, name: " << op_type << ", attr: " << attr_name;
1179   return true;
1180 }
1181 
CheckSummaryParam(const AbstractBasePtr & name,const AbstractBasePtr & value,const std::string & class_name)1182 void CheckAndConvertUtils::CheckSummaryParam(const AbstractBasePtr &name, const AbstractBasePtr &value,
1183                                              const std::string &class_name) {
1184   MS_EXCEPTION_IF_NULL(name);
1185   MS_EXCEPTION_IF_NULL(value);
1186   (void)CheckTypeValid("name", name->BuildType(), {kString}, class_name);
1187   auto s = GetValue<std::string>(name->BuildValue());
1188   if (s.empty()) {
1189     MS_EXCEPTION(ValueError) << "For primitive[" << class_name << "], the input argument[name]"
1190                              << " cannot be an empty string.";
1191   }
1192   (void)CheckTypeValid("value", value->BuildType(), {kTensorType}, class_name);
1193 }
1194 
CheckTensorFloatValue(const std::string & type_name,const ValuePtr & value,const std::string & prim_name)1195 std::vector<double> CheckAndConvertUtils::CheckTensorFloatValue(const std::string &type_name, const ValuePtr &value,
1196                                                                 const std::string &prim_name) {
1197   if (value == nullptr) {
1198     MS_EXCEPTION(ValueError) << "For primitive[" << prim_name << "], the input argument[" << type_name
1199                              << "] value is nullptr.";
1200   }
1201   std::vector<double> tensor_value;
1202   if (!value->isa<tensor::BaseTensor>()) {
1203     MS_EXCEPTION(ValueError) << "For primitive[" << prim_name << "], the input argument[" << type_name
1204                              << "] must be a tensor, but got " << value->ToString();
1205   }
1206   auto input_tensor = value->cast<tensor::BaseTensorPtr>();
1207   MS_EXCEPTION_IF_NULL(input_tensor);
1208   size_t data_size = input_tensor->DataSize();
1209   auto tensor_type = input_tensor->Dtype();
1210   if (tensor_type->type_id() == kNumberTypeFloat32) {
1211     auto data_c = static_cast<float *>(input_tensor->data_c());
1212     MS_EXCEPTION_IF_NULL(data_c);
1213     for (size_t i = 0; i < data_size; i++) {
1214       tensor_value.push_back(static_cast<double>(*data_c));
1215       ++data_c;
1216     }
1217   } else if (tensor_type->type_id() == kNumberTypeFloat64) {
1218     auto tensor_data = static_cast<double *>(input_tensor->data_c());
1219     MS_EXCEPTION_IF_NULL(tensor_data);
1220     tensor_value = {tensor_data, tensor_data + data_size};
1221   } else {
1222     MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the input argument[" << type_name
1223                             << "] must be a Tensor[Float32] or Tensor[Float64], but got " << value->ToString();
1224   }
1225   return tensor_value;
1226 }
1227 
CheckListOrTupleFloat(const std::string & arg_name,const ValuePtr & attr,const std::string & prim_name)1228 std::vector<double> CheckAndConvertUtils::CheckListOrTupleFloat(const std::string &arg_name, const ValuePtr &attr,
1229                                                                 const std::string &prim_name) {
1230   std::vector<double> result;
1231   bool is_correct = false;
1232   MS_EXCEPTION_IF_NULL(attr);
1233   if (attr->isa<ValueTuple>() || attr->isa<ValueList>()) {
1234     auto attr_vec =
1235       attr->isa<ValueTuple>() ? attr->cast<ValueTuplePtr>()->value() : attr->cast<ValueListPtr>()->value();
1236     if (attr_vec.empty()) {
1237       return result;
1238     }
1239     is_correct = std::all_of(attr_vec.begin(), attr_vec.end(), [&result](const ValuePtr &e) -> bool {
1240       MS_EXCEPTION_IF_NULL(e);
1241       if (e->isa<FP32Imm>()) {
1242         (void)result.emplace_back(static_cast<double>(GetValue<float>(e)));
1243         return true;
1244       } else if (e->isa<FP64Imm>()) {
1245         (void)result.emplace_back(GetValue<double>(e));
1246         return true;
1247       }
1248       return false;
1249     });
1250   }
1251   if (!is_correct) {
1252     MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the " << arg_name
1253                             << " must be one of ['tuple', 'list'] with all Float elements, but got "
1254                             << attr->ToString();
1255   }
1256   return result;
1257 }
1258 
CheckListOrTupleFloat(const std::string & arg_name,const AbstractBasePtr & abs,const std::string & prim_name)1259 std::vector<pyfloat> CheckAndConvertUtils::CheckListOrTupleFloat(const std::string &arg_name,
1260                                                                  const AbstractBasePtr &abs,
1261                                                                  const std::string &prim_name) {
1262   std::vector<pyfloat> result{};
1263   if (IsSequence(abs)) {
1264     const auto &type_list = GetSequenceElementTypes(abs);
1265     if (type_list.empty()) {
1266       return result;
1267     }
1268     auto is_correct = std::all_of(type_list.begin(), type_list.end(), [](const TypePtr &e) -> bool {
1269       MS_EXCEPTION_IF_NULL(e);
1270       return e->type_id() == kNumberTypeFloat64 || e->type_id() == kNumberTypeFloat32;
1271     });
1272     if (is_correct) {
1273       const auto &arr_value = ops::GetArrayValue<pyfloat>(abs);
1274       if (arr_value->HasUnknownValue()) {
1275         MS_EXCEPTION(ValueError) << "For primitive[" << prim_name << "], there are unknown values in the " << arg_name
1276                                  << ", please handle this case before calling this function.";
1277       }
1278       result = arr_value->ToVector();
1279     } else {
1280       MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the " << arg_name
1281                               << " must be one of ['tuple', 'list'] with all Float elements, but got "
1282                               << abs->ToString();
1283     }
1284     return result;
1285   }
1286   MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the " << arg_name
1287                           << " must be one of ['tuple', 'list'] with all Float elements, but got " << abs->ToString();
1288 }
1289 
CheckIntOrTupleInt(const std::string & arg_name,const ValuePtr & attr,const std::string & prim_name)1290 std::vector<int64_t> CheckAndConvertUtils::CheckIntOrTupleInt(const std::string &arg_name, const ValuePtr &attr,
1291                                                               const std::string &prim_name) {
1292   std::vector<int64_t> result;
1293   bool is_correct = false;
1294   MS_EXCEPTION_IF_NULL(attr);
1295   if (attr->isa<ValueTuple>() || attr->isa<ValueList>()) {
1296     auto attr_vec =
1297       attr->isa<ValueTuple>() ? attr->cast<ValueTuplePtr>()->value() : attr->cast<ValueListPtr>()->value();
1298     if (attr_vec.empty()) {
1299       return result;
1300     }
1301     is_correct = std::all_of(attr_vec.begin(), attr_vec.end(), [&result](const ValuePtr &e) -> bool {
1302       MS_EXCEPTION_IF_NULL(e);
1303       if (e->isa<Int64Imm>()) {
1304         (void)result.emplace_back(GetValue<int64_t>(e));
1305         return true;
1306       } else if (e->isa<Int32Imm>()) {
1307         (void)result.emplace_back(GetValue<int32_t>(e));
1308         return true;
1309       }
1310       return false;
1311     });
1312   } else {
1313     if (attr->isa<Int64Imm>()) {
1314       is_correct = true;
1315       int64_t attr_val = attr->cast<Int64ImmPtr>()->value();
1316       result.push_back(attr_val);
1317     } else if (attr->isa<Int32Imm>()) {
1318       is_correct = true;
1319       int64_t attr_val = attr->cast<Int32ImmPtr>()->value();
1320       result.push_back(attr_val);
1321     }
1322   }
1323   if (!is_correct) {
1324     MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the " << arg_name
1325                             << " must be one of ['int', 'tuple', 'list'] with all Int elements, but got "
1326                             << attr->ToString();
1327   }
1328   return result;
1329 }
1330 
CheckIntOrTupleInt(const std::string & arg_name,const AbstractBasePtr & abs,const std::string & prim_name)1331 std::vector<int64_t> CheckAndConvertUtils::CheckIntOrTupleInt(const std::string &arg_name, const AbstractBasePtr &abs,
1332                                                               const std::string &prim_name) {
1333   std::vector<int64_t> result{};
1334   if (IsSequence(abs)) {
1335     const auto &type_list = GetSequenceElementTypes(abs);
1336     if (type_list.empty()) {
1337       return result;
1338     }
1339     auto is_correct = std::all_of(type_list.begin(), type_list.end(), [](const TypePtr &e) -> bool {
1340       MS_EXCEPTION_IF_NULL(e);
1341       return e->type_id() == kNumberTypeInt64 || e->type_id() == kNumberTypeInt32;
1342     });
1343     if (!is_correct) {
1344       MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], when the " << arg_name
1345                               << "'s type is one of ['tuple', 'list'], its element data type must be int32 or int64, "
1346                                  "but got "
1347                               << abs->ToString();
1348     } else if (type_list.front()->type_id() == kNumberTypeInt64) {
1349       const auto &arr_value = ops::GetArrayValue<int64_t>(abs);
1350       if (arr_value->HasUnknownValue()) {
1351         MS_EXCEPTION(ValueError) << "For primitive[" << prim_name << "], there are unknown values in the " << arg_name
1352                                  << ", please handle this case before calling this function.";
1353       }
1354       result = arr_value->ToVector();
1355     } else if (type_list.front()->type_id() == kNumberTypeInt32) {
1356       const auto &arr_value = ops::GetArrayValue<int>(abs);
1357       if (arr_value->HasUnknownValue()) {
1358         MS_EXCEPTION(ValueError) << "For primitive[" << prim_name << "], there are unknown values in the " << arg_name
1359                                  << ", please handle this case before calling this function.";
1360       }
1361       const auto &vec_value = arr_value->ToVector();
1362       (void)std::transform(vec_value.begin(), vec_value.end(), std::back_inserter(result),
1363                            [](int ele) -> int64_t { return static_cast<int64_t>(ele); });
1364     }
1365   } else {
1366     if (!ops::IsValueKnown(abs)) {
1367       MS_EXCEPTION(ValueError) << "For primitive[" << prim_name << "], the value of  [" << arg_name
1368                                << "] is unknown, please handle this case before calling this function.";
1369     }
1370     auto data_type = abs->GetType();
1371     MS_EXCEPTION_IF_NULL(data_type);
1372     if (data_type->type_id() == kNumberTypeInt64) {
1373       const auto &val = ops::GetScalarValue<int64_t>(abs->GetValue());
1374       result.push_back(val.value());
1375     } else if (data_type->type_id() == kNumberTypeInt32) {
1376       const auto &val = ops::GetScalarValue<int>(abs->GetValue());
1377       result.push_back(val.value());
1378     } else {
1379       MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], when the " << arg_name
1380                               << "'s type is 'int', its data type must be int32 or int64, but got "
1381                               << data_type->ToString();
1382     }
1383   }
1384   return result;
1385 }
1386 
CheckAttrTuple(const PrimitivePtr & prim,const std::string & attr_name,size_t num_element)1387 std::vector<int64_t> CheckAndConvertUtils::CheckAttrTuple(const PrimitivePtr &prim, const std::string &attr_name,
1388                                                           size_t num_element) {
1389   MS_EXCEPTION_IF_NULL(prim);
1390   auto attr = prim->GetAttr(attr_name);
1391   MS_EXCEPTION_IF_NULL(attr);
1392   std::vector<int64_t> result;
1393   if (!attr->isa<ValueTuple>()) {
1394     MS_EXCEPTION(ValueError) << "For '" << prim->name() << "', the '" << attr_name
1395                              << "' should be a tuple[int64], but got: " << attr->ToString() << ".";
1396   }
1397   std::vector<ValuePtr> attr_vec = attr->cast<ValueTuplePtr>()->value();
1398   if (attr_vec.size() != num_element) {
1399     MS_EXCEPTION(ValueError) << "For '" << prim->name() << "', the '" << attr_name
1400                              << "' should be a tuple[int64] with size " << num_element << ", but its size is "
1401                              << attr_vec.size() << ".";
1402   }
1403   (void)std::transform(attr_vec.begin(), attr_vec.end(), std::back_inserter(result),
1404                        [&prim, &attr_name](const ValuePtr &e) -> int64_t {
1405                          auto value = GetValue<int64_t>(e);
1406                          if (value < 0) {
1407                            MS_EXCEPTION(ValueError) << "For '" << prim->name() << "', the element of '" << attr_name
1408                                                     << "' should not be negative number, but got " << value << ".";
1409                          }
1410                          return value;
1411                        });
1412   return result;
1413 }
1414 
CheckTupleInt(const std::string & arg_name,const ValuePtr & attr,const std::string & prim_name)1415 std::vector<int64_t> CheckAndConvertUtils::CheckTupleInt(const std::string &arg_name, const ValuePtr &attr,
1416                                                          const std::string &prim_name) {
1417   std::vector<int64_t> result;
1418   MS_EXCEPTION_IF_NULL(attr);
1419   if (attr->isa<ValueTuple>()) {
1420     std::vector<ValuePtr> attr_vec = attr->cast<ValueTuplePtr>()->value();
1421     (void)std::transform(
1422       attr_vec.begin(), attr_vec.end(), std::back_inserter(result), [=](const ValuePtr &e) -> int64_t {
1423         if (!e->isa<Int64Imm>()) {
1424           MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the " << arg_name
1425                                   << " must be a tuple with all Int elements, but got " << attr->type_name();
1426         }
1427         return GetValue<int64_t>(e);
1428       });
1429   } else if (attr->isa<KernelTensorValue>()) {
1430     // to_do: check type of the KernelTensorValue is int64
1431     auto data_opt = ops::GetArrayValue<int64_t>(attr);
1432     const auto &data_array = data_opt.value();
1433     result = data_array.ToVector();
1434   } else {
1435     MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the " << arg_name
1436                             << " must be a tuple with all Int elements, but got " << attr->type_name() << ".";
1437   }
1438   return result;
1439 }
1440 
CheckListInt(const std::string & arg_name,const ValuePtr & attr,const std::string & prim_name)1441 std::vector<int64_t> CheckAndConvertUtils::CheckListInt(const std::string &arg_name, const ValuePtr &attr,
1442                                                         const std::string &prim_name) {
1443   std::vector<int64_t> result;
1444   MS_EXCEPTION_IF_NULL(attr);
1445   if (attr->isa<ValueList>()) {
1446     std::vector<ValuePtr> attr_vec = attr->cast<ValueListPtr>()->value();
1447     (void)std::transform(
1448       attr_vec.begin(), attr_vec.end(), std::back_inserter(result), [=](const ValuePtr &e) -> int64_t {
1449         if (!e->isa<Int64Imm>()) {
1450           MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the " << arg_name
1451                                   << " must be a list with all Int elements, but got " << attr->ToString();
1452         }
1453         return GetValue<int64_t>(e);
1454       });
1455   } else if (attr->isa<KernelTensorValue>()) {
1456     // to_do: check type of the KernelTensorValue is int64
1457     auto data_opt = ops::GetArrayValue<int64_t>(attr);
1458     const auto &data_array = data_opt.value();
1459     result = data_array.ToVector();
1460   } else {
1461     MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the " << arg_name
1462                             << " must be a list with all Int elements, but got " << attr->ToString() << ".";
1463   }
1464   return result;
1465 }
1466 
GetAndCheckFormat(const ValuePtr & value)1467 int64_t CheckAndConvertUtils::GetAndCheckFormat(const ValuePtr &value) {
1468   int64_t data_format;
1469   bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value, &data_format);
1470   if (!result ||
1471       (data_format != static_cast<int64_t>(Format::NHWC) && data_format != static_cast<int64_t>(Format::NCHW) &&
1472        data_format != static_cast<int64_t>(Format::NCDHW))) {
1473     MS_LOG(EXCEPTION) << "data format value " << data_format << " is invalid, only support NCHW, NHWC and NCDHW";
1474   }
1475   return data_format;
1476 }
GetRemoveMonadAbsNum(const AbstractBasePtrList & abs_list)1477 size_t CheckAndConvertUtils::GetRemoveMonadAbsNum(const AbstractBasePtrList &abs_list) {
1478   size_t remove_monad_count = abs_list.size();
1479   for (const auto &item : abs_list) {
1480     if (item->isa<abstract::AbstractMonad>()) {
1481       --remove_monad_count;
1482     }
1483   }
1484 
1485   for (size_t i = 0; i < remove_monad_count; ++i) {
1486     if (abs_list[i]->isa<abstract::AbstractMonad>()) {
1487       MS_EXCEPTION(UnknownError) << "The monad inputs of the node must at last of the node inputs.";
1488     }
1489   }
1490   return remove_monad_count;
1491 }
GetRemoveUMonadAbsNum(const AbstractBasePtrList & abs_list)1492 size_t CheckAndConvertUtils::GetRemoveUMonadAbsNum(const AbstractBasePtrList &abs_list) {
1493   size_t remove_umonad_count = abs_list.size();
1494   for (const auto &item : abs_list) {
1495     if (item->isa<abstract::AbstractUMonad>()) {
1496       --remove_umonad_count;
1497     }
1498   }
1499 
1500   for (size_t i = 0; i < remove_umonad_count; ++i) {
1501     if (abs_list[i]->isa<abstract::AbstractUMonad>()) {
1502       MS_EXCEPTION(UnknownError) << "The umonad inputs of the node must at last of the node inputs.";
1503     }
1504   }
1505   return remove_umonad_count;
1506 }
HasDynamicShapeInput(const AbstractBasePtrList & abs_list)1507 bool CheckAndConvertUtils::HasDynamicShapeInput(const AbstractBasePtrList &abs_list) {
1508   for (const auto &item : abs_list) {
1509     MS_EXCEPTION_IF_NULL(item);
1510     auto shape = item->GetShape();
1511     if (shape->IsDynamic()) {
1512       return true;
1513     }
1514   }
1515   return false;
1516 }
1517 
CheckArgsType(const std::string & op,const AbstractBasePtrList & args_spec_list,size_t index,TypeId type_id)1518 AbstractBasePtr CheckAndConvertUtils::CheckArgsType(const std::string &op, const AbstractBasePtrList &args_spec_list,
1519                                                     size_t index, TypeId type_id) {
1520   if (index >= args_spec_list.size()) {
1521     MS_EXCEPTION(ValueError) << op << " evaluator arguments list index out of bound, size " << args_spec_list.size()
1522                              << ", index " << index;
1523   }
1524   auto args_abs = args_spec_list[index];
1525   MS_EXCEPTION_IF_NULL(args_abs);
1526   if (args_abs->GetType()->object_type() != type_id) {
1527     MS_EXCEPTION(TypeError) << "For primitive[" << op << "], the input[" << index << "] should be a "
1528                             << TypeIdToType(type_id)->ToString() << ", but got " << args_abs->GetType()->ToString()
1529                             << ".";
1530   }
1531   return args_abs;
1532 }
1533 
CheckArgsSequenceType(const std::string & op,const AbstractBasePtrList & args_spec_list,size_t index)1534 AbstractBasePtr CheckAndConvertUtils::CheckArgsSequenceType(const std::string &op,
1535                                                             const AbstractBasePtrList &args_spec_list, size_t index) {
1536   if (index >= args_spec_list.size()) {
1537     MS_EXCEPTION(ValueError) << op << " evaluator arguments list index out of bound, size " << args_spec_list.size()
1538                              << ", index " << index;
1539   }
1540   auto args_abs = args_spec_list[index];
1541   MS_EXCEPTION_IF_NULL(args_abs);
1542   if (!IsSequence(args_abs)) {
1543     MS_EXCEPTION(TypeError) << "For primitive[" << op << "], the input[" << index << "] should be a "
1544                             << "tuple or list, but got " << args_abs->GetType()->ToString() << ".";
1545   }
1546   return args_abs;
1547 }
1548 }  // namespace mindspore
1549