• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/check_convert_utils.h"
18 
19 #include <utility>
20 #include <vector>
21 #include <algorithm>
22 #include <typeinfo>
23 #include <functional>
24 
25 #include "abstract/abstract_value.h"
26 #include "ops/op_utils.h"
27 #include "ir/dtype/type.h"
28 #include "ir/dtype/tensor_type.h"
29 #include "ir/dtype.h"
30 #include "utils/ms_context.h"
31 
32 namespace mindspore {
33 static std::map<std::string, int64_t> DataFormatToEnumMap = {
34   {"NCHW", Format::NCHW},   {"NHWC", Format::NHWC},     {"NHWC4", Format::NHWC4},
35   {"HWKC", Format::HWKC},   {"HWCK", Format::HWCK},     {"KCHW", Format::KCHW},
36   {"CKHW", Format::CKHW},   {"KHWC", Format::KHWC},     {"CHWK", Format::CHWK},
37   {"HW", Format::HW},       {"HW4", Format::HW4},       {"NC", Format::NC},
38   {"NC4", Format::NC4},     {"NC4HW4", Format::NC4HW4}, {"NUM_OF_FORMAT", Format::NUM_OF_FORMAT},
39   {"NCDHW", Format::NCDHW}, {"NWC", Format::NWC},       {"NCW", Format::NCW},
40 };
41 
42 static std::map<int64_t, std::string> DataFormatToStrMap = {
43   {Format::NCHW, "NCHW"},   {Format::NHWC, "NHWC"},     {Format::NHWC4, "NHWC4"},
44   {Format::HWKC, "HWKC"},   {Format::HWCK, "HWCK"},     {Format::KCHW, "KCHW"},
45   {Format::CKHW, "CKHW"},   {Format::KHWC, "KHWC"},     {Format::CHWK, "CHWK"},
46   {Format::HW, "HW"},       {Format::HW4, "HW4"},       {Format::NC, "NC"},
47   {Format::NC4, "NC4"},     {Format::NC4HW4, "NC4HW4"}, {Format::NUM_OF_FORMAT, "NUM_OF_FORMAT"},
48   {Format::NCDHW, "NCDHW"}, {Format::NWC, "NWC"},       {Format::NCW, "NCW"},
49 };
50 
51 static std::map<std::string, int64_t> ReductionToEnumMap = {
52   {"sum", Reduction::REDUCTION_SUM},
53   {"mean", Reduction::MEAN},
54   {"none", Reduction::NONE},
55 };
56 
57 static std::map<int64_t, std::string> ReductionToStrMap = {
58   {Reduction::REDUCTION_SUM, "sum"},
59   {Reduction::MEAN, "mean"},
60   {Reduction::NONE, "none"},
61 };
62 
63 static std::map<std::string, int64_t> PadModToEnumMap = {
64   {"pad", PadMode::PAD},
65   {"same", PadMode::SAME},
66   {"valid", PadMode::VALID},
67 };
68 
69 static std::map<int64_t, std::string> PadModToStrMap = {
70   {PadMode::PAD, "pad"},
71   {PadMode::SAME, "same"},
72   {PadMode::VALID, "valid"},
73 };
74 
75 static std::map<std::string, int64_t> PadModToEnumUpperMap = {
76   {"PAD", PadMode::PAD},
77   {"SAME", PadMode::SAME},
78   {"VALID", PadMode::VALID},
79 };
80 
81 static std::map<int64_t, std::string> PadModToStrUpperMap = {
82   {PadMode::PAD, "PAD"},
83   {PadMode::SAME, "SAME"},
84   {PadMode::VALID, "VALID"},
85 };
86 
87 AttrConverterPair DataFormatConverter(DataFormatToEnumMap, DataFormatToStrMap);
88 AttrConverterPair PadModeConverter(PadModToEnumMap, PadModToStrMap);
89 AttrConverterPair PadModeUpperConverter(PadModToEnumUpperMap, PadModToStrUpperMap);
90 AttrConverterPair ReductionConverter(ReductionToEnumMap, ReductionToStrMap);
91 
92 static std::map<std::string, AttrConverterPair> FormatAndPadAttrMap = {
93   {ops::kFormat, DataFormatConverter},
94   {ops::kPadMode, PadModeConverter},
95 };
96 
97 static std::map<std::string, AttrConverterPair> FormatAndPadUpperAttrMap = {
98   {ops::kFormat, DataFormatConverter},
99   {ops::kPadMode, PadModeUpperConverter},
100 };
101 
102 static std::map<std::string, AttrConverterPair> DataFormatMap = {
103   {ops::kFormat, DataFormatConverter},
104 };
105 
106 static std::map<std::string, AttrConverterPair> ReductionMap = {
107   {ops::kReduction, ReductionConverter},
108 };
109 
110 static std::map<std::string, std::map<std::string, AttrConverterPair>> PrimAttrConvertMap = {
111   {"Conv2D", FormatAndPadAttrMap},
112   {"Conv2DTranspose", FormatAndPadUpperAttrMap},
113   {"Conv2DBackpropInput", FormatAndPadUpperAttrMap},
114   {"Conv2DBackpropFilter", FormatAndPadUpperAttrMap},
115   {"Conv3D", FormatAndPadAttrMap},
116   {"Conv3DBackpropInput", FormatAndPadAttrMap},
117   {"Conv3DBackpropFilter", FormatAndPadAttrMap},
118   {"Conv3DTranspose", DataFormatMap},
119   {"DepthwiseConv2dNative", FormatAndPadAttrMap},
120   {"DepthwiseConv2dNativeBackpropInput", FormatAndPadAttrMap},
121   {"DepthwiseConv2dNativeBackpropFilter", FormatAndPadAttrMap},
122   {"AvgPool", FormatAndPadUpperAttrMap},
123   {"MaxPool", FormatAndPadUpperAttrMap},
124   {"MaxPoolWithArgmax", FormatAndPadUpperAttrMap},
125   {"AvgPoolGrad", FormatAndPadUpperAttrMap},
126   {"AvgPoolGradVm", FormatAndPadUpperAttrMap},
127   {"AvgPoolGradGpu", FormatAndPadUpperAttrMap},
128   {"AvgPoolGradCpu", FormatAndPadUpperAttrMap},
129   {"MaxPoolGrad", FormatAndPadUpperAttrMap},
130   {"MaxPoolGradGrad", FormatAndPadUpperAttrMap},
131   {"MaxPoolGradWithArgmax", FormatAndPadUpperAttrMap},
132   {"MaxPoolGradGradWithArgmax", FormatAndPadUpperAttrMap},
133   {"BatchNorm", DataFormatMap},
134   {"BatchNormGrad", DataFormatMap},
135   {"BiasAdd", DataFormatMap},
136   {"BiasAddGrad", DataFormatMap},
137   {"BinaryCrossEntropy", ReductionMap},
138   {"BinaryCrossEntropyGrad", ReductionMap},
139   {"NLLLoss", ReductionMap},
140   {"DepthToSpace", DataFormatMap},
141   {"Pooling", DataFormatMap},
142   {"Deconvolution", DataFormatMap},
143   {"AvgPoolV2", DataFormatMap},
144   {"MaxPoolV3", DataFormatMap},
145   {"FusedBatchNorm", DataFormatMap}};
146 
GetDataFormatEnumValue(const ValuePtr & value,int64_t * enum_value)147 bool CheckAndConvertUtils::GetDataFormatEnumValue(const ValuePtr &value, int64_t *enum_value) {
148   MS_EXCEPTION_IF_NULL(value);
149   if (value->isa<StringImm>()) {
150     auto attr_value_str = GetValue<std::string>(value);
151     if (DataFormatToEnumMap.find(attr_value_str) == DataFormatToEnumMap.end()) {
152       MS_LOG(DEBUG) << "The data format " << attr_value_str << " not be converted to enum.";
153       return false;
154     }
155     *enum_value = DataFormatToEnumMap[attr_value_str];
156     return true;
157   } else {
158     *enum_value = GetValue<int64_t>(value);
159     return true;
160   }
161 }
162 
GetPadModEnumValue(const ValuePtr & value,int64_t * enum_value,bool is_upper)163 void CheckAndConvertUtils::GetPadModEnumValue(const ValuePtr &value, int64_t *enum_value, bool is_upper) {
164   MS_EXCEPTION_IF_NULL(value);
165   if (value->isa<StringImm>()) {
166     auto attr_value_str = GetValue<std::string>(value);
167 
168     std::map<std::string, int64_t> pad_map = PadModToEnumMap;
169     if (is_upper) {
170       pad_map = PadModToEnumUpperMap;
171     }
172     if (pad_map.find(attr_value_str) == pad_map.end()) {
173       MS_LOG(EXCEPTION) << "Invalid pad mode " << attr_value_str << " use pad, valid or same";
174     }
175     *enum_value = pad_map[attr_value_str];
176   } else {
177     *enum_value = GetValue<int64_t>(value);
178   }
179 }
180 
GetReductionEnumValue(const ValuePtr & value,int64_t * enum_value)181 void CheckAndConvertUtils::GetReductionEnumValue(const ValuePtr &value, int64_t *enum_value) {
182   MS_EXCEPTION_IF_NULL(value);
183   if (value->isa<StringImm>()) {
184     auto attr_value_str = GetValue<std::string>(value);
185 
186     std::map<std::string, int64_t> pad_map = ReductionToEnumMap;
187     if (pad_map.find(attr_value_str) == pad_map.end()) {
188       MS_LOG(EXCEPTION) << "Invalid pad mode " << attr_value_str << " use pad, valid or same";
189     }
190     *enum_value = pad_map[attr_value_str];
191   } else {
192     *enum_value = GetValue<int64_t>(value);
193   }
194 }
195 
GetAttrConvertPair(const std::string & op_type,const std::string & attr_name)196 AttrConverterPair CheckAndConvertUtils::GetAttrConvertPair(const std::string &op_type, const std::string &attr_name) {
197   AttrConverterPair attr_pair;
198   if (op_type.empty() || attr_name.empty()) {
199     return attr_pair;
200   }
201   auto op_attr_map_it = PrimAttrConvertMap.find(op_type);
202   if (op_attr_map_it == PrimAttrConvertMap.end()) {
203     return attr_pair;
204   }
205   auto attr_pair_it = op_attr_map_it->second.find(attr_name);
206   if (attr_pair_it == op_attr_map_it->second.end()) {
207     return attr_pair;
208   }
209 
210   return attr_pair_it->second;
211 }
212 
ConvertAttrValueToInt(const std::string & op_type,const std::string & attr_name,ValuePtr * const value)213 bool CheckAndConvertUtils::ConvertAttrValueToInt(const std::string &op_type, const std::string &attr_name,
214                                                  ValuePtr *const value) {
215   if (value == nullptr || *value == nullptr) {
216     MS_LOG(DEBUG) << "value of attr " << op_type << attr_name << " is nullptr.";
217     return false;
218   }
219   if (!(*value)->isa<StringImm>()) {
220     return false;
221   }
222   auto attr_map_pair = GetAttrConvertPair(op_type, attr_name);
223   if (attr_map_pair.first.size() == 0) {
224     return false;
225   }
226 
227   std::string real_value = std::dynamic_pointer_cast<StringImm>(*value)->value();
228   bool do_convert = false;
229   if (attr_map_pair.first.find(real_value) != attr_map_pair.first.end()) {
230     do_convert = true;
231   }
232   if (!do_convert) {
233     transform(real_value.begin(), real_value.end(), real_value.begin(), ::toupper);
234     if (attr_map_pair.first.find(real_value) != attr_map_pair.first.end()) {
235       do_convert = true;
236     }
237   }
238   if (!do_convert) {
239     transform(real_value.begin(), real_value.end(), real_value.begin(), ::tolower);
240     if (attr_map_pair.first.find(real_value) == attr_map_pair.first.end()) {
241       MS_LOG(DEBUG) << "Can not convert " << op_type << " attr " << attr_name << ": " << real_value << " to int";
242       return false;
243     }
244   }
245   *value = MakeValue<int64_t>(attr_map_pair.first[real_value]);
246   MS_LOG(DEBUG) << "convert str to int, name: " << op_type << ", attr: " << attr_name;
247   return true;
248 }
249 
ConvertAttrValueToString(const std::string & op_type,const std::string & attr_name,ValuePtr * const value)250 bool CheckAndConvertUtils::ConvertAttrValueToString(const std::string &op_type, const std::string &attr_name,
251                                                     ValuePtr *const value) {
252   if (value == nullptr || *value == nullptr) {
253     MS_LOG(DEBUG) << "value is nullptr! op_type = " << op_type << ", attr_name = " << attr_name;
254     return false;
255   }
256   if (!(*value)->isa<Int64Imm>()) {
257     return false;
258   }
259   auto attr_map_pair = GetAttrConvertPair(op_type, attr_name);
260   if (attr_map_pair.second.size() == 0) {
261     return false;
262   }
263 
264   int64_t real_value = std::dynamic_pointer_cast<Int64Imm>(*value)->value();
265   if (attr_map_pair.second.find(real_value) == attr_map_pair.second.end()) {
266     MS_LOG(DEBUG) << "Can not convert " << op_type << " attr " << attr_name << ": " << real_value << " to string";
267     return false;
268   }
269   *value = MakeValue<std::string>(attr_map_pair.second[real_value]);
270   MS_LOG(DEBUG) << "convert int to str, name: " << op_type << ", attr: " << attr_name;
271   return true;
272 }
273 
ConvertAttrValueInExport(const std::string & op_type,const std::string & attr_name,ValuePtr * const value)274 void CheckAndConvertUtils::ConvertAttrValueInExport(const std::string &op_type, const std::string &attr_name,
275                                                     ValuePtr *const value) {
276   if (value == nullptr || *value == nullptr) {
277     MS_LOG(DEBUG) << "value is nullptr! op_type = " << op_type << ", attr_name = " << attr_name;
278     return;
279   }
280   // convert enum to string
281   ConvertAttrValueToString(op_type, attr_name, value);
282 }
283 
ConvertAttrValueInLoad(const std::string & op_type,const std::string & attr_name,ValuePtr * const value)284 void CheckAndConvertUtils::ConvertAttrValueInLoad(const std::string &op_type, const std::string &attr_name,
285                                                   ValuePtr *const value) {
286   if (value == nullptr || *value == nullptr) {
287     MS_LOG(DEBUG) << "value is nullptr! op_type = " << op_type << ", attr_name = " << attr_name;
288     return;
289   }
290   // convert string to enum
291   ConvertAttrValueToInt(op_type, attr_name, value);
292 }
293 
294 namespace {
295 typedef std::map<std::string, std::function<ValuePtr(ValuePtr)>> AttrFunction;
296 
L2NormalizeAttrConversion(ValuePtr attr)297 ValuePtr L2NormalizeAttrConversion(ValuePtr attr) {
298   if (attr->isa<Int64Imm>()) {
299     return attr;
300   }
301   auto attr_value = GetValue<std::vector<int64_t>>(attr);
302   return MakeValue(attr_value[0]);
303 }
304 
305 std::map<std::string, AttrFunction> kIrAttrToOpAttr = {{"L2Normalize", {{"axis", L2NormalizeAttrConversion}}},
306                                                        {"L2NormalizeGrad", {{"axis", L2NormalizeAttrConversion}}}};
307 }  // namespace
308 
CheckPositiveVector(const std::string & arg_name,const std::vector<int64_t> & arg_value,const std::string & prim_name)309 std::vector<int64_t> CheckAndConvertUtils::CheckPositiveVector(const std::string &arg_name,
310                                                                const std::vector<int64_t> &arg_value,
311                                                                const std::string &prim_name) {
312   std::ostringstream buffer;
313   buffer << "The primitive[" << prim_name << "]'s attribute[" << arg_name
314          << "] should be a vector with all positive item. but got [";
315   if (std::any_of(arg_value.begin(), arg_value.end(), [](int64_t item) { return item < 0; })) {
316     for (auto item : arg_value) {
317       buffer << item << ", ";
318     }
319     buffer << "].";
320     MS_EXCEPTION(ValueError) << buffer.str();
321   }
322 
323   return arg_value;
324 }
325 
CheckString(const std::string & arg_name,const std::string & arg_value,const std::set<std::string> & check_list,const std::string & prim_name)326 std::string CheckAndConvertUtils::CheckString(const std::string &arg_name, const std::string &arg_value,
327                                               const std::set<std::string> &check_list, const std::string &prim_name) {
328   if (check_list.find(arg_value) != check_list.end()) {
329     return arg_value;
330   }
331   std::ostringstream buffer;
332   buffer << "The primitive[" << prim_name << "]'s attribute[" << arg_name << "]";
333   if (check_list.size() == 1) {
334     buffer << " must be \"" << (*check_list.begin()) << "\",but got \"" << arg_value << "\".";
335     MS_EXCEPTION(ValueError) << buffer.str();
336   }
337   buffer << " should be a element of {";
338   for (const auto &item : check_list) {
339     buffer << "\"" << item << "\", ";
340   }
341   buffer << "}"
342          << ",but got \"" << arg_value << "\""
343          << ".";
344   MS_EXCEPTION(ValueError) << buffer.str();
345 }
346 
CheckInteger(const std::string & arg_name,int64_t arg_value,CompareEnum compare_operator,int64_t match_value,const std::string & prim_name)347 int64_t CheckAndConvertUtils::CheckInteger(const std::string &arg_name, int64_t arg_value, CompareEnum compare_operator,
348                                            int64_t match_value, const std::string &prim_name) {
349   auto iter = kCompareMap<float>.find(compare_operator);
350   if (iter == kCompareMap<float>.end()) {
351     MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator << " cannot find in the compare map";
352   }
353   if (iter->second(arg_value, match_value)) {
354     return arg_value;
355   }
356   std::ostringstream buffer;
357   if (prim_name.empty()) {
358     buffer << "The argument[" << arg_name << "] must ";
359   } else {
360     buffer << "The primitive[" << prim_name << "]'s " << arg_name << " must ";
361   }
362   auto iter_to_string = kCompareToString.find(compare_operator);
363   if (iter_to_string == kCompareToString.end()) {
364     MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator << " cannot find in the compare string map";
365   }
366   buffer << iter_to_string->second << match_value << ", but got " << arg_value << ".";
367   MS_EXCEPTION(ValueError) << buffer.str();
368 }
369 
CheckInputArgs(const std::vector<AbstractBasePtr> & input_args,const CompareEnum compare_operator,const int64_t match_value,const std::string & prim_name)370 void CheckAndConvertUtils::CheckInputArgs(const std::vector<AbstractBasePtr> &input_args,
371                                           const CompareEnum compare_operator, const int64_t match_value,
372                                           const std::string &prim_name) {
373   (void)CheckInteger("input number", SizeToLong(input_args.size()), compare_operator, match_value, prim_name);
374   for (size_t index = 0; index < input_args.size(); index++) {
375     if (input_args[index] == nullptr) {
376       MS_EXCEPTION(ValueError) << "The " << index << "'s input of " << prim_name << " is nullptr.";
377     }
378   }
379 }
380 
GetInputTensorType(const std::vector<AbstractBasePtr> & input_args,const size_t index,const std::string & prim_name)381 TypePtr CheckAndConvertUtils::GetInputTensorType(const std::vector<AbstractBasePtr> &input_args, const size_t index,
382                                                  const std::string &prim_name) {
383   if (input_args.size() <= index) {
384     MS_EXCEPTION(ValueError) << "The primitive[" << prim_name << "]'s input index[" << index
385                              << "] is out of the input number " << input_args.size();
386   }
387   auto input_arg = input_args[index];
388   if (input_arg == nullptr) {
389     MS_EXCEPTION(ValueError) << "The primitive[" << prim_name << "]'s input index[" << index << "] is nullptr.";
390   }
391   auto base_type = input_arg->BuildType();
392   MS_EXCEPTION_IF_NULL(base_type);
393   if (!base_type->isa<TensorType>()) {
394     MS_EXCEPTION(ValueError) << "The primitive[" << prim_name << "]'s input index[" << index << "] is not a tensor.";
395   }
396   auto tensor_type = base_type->cast<TensorTypePtr>();
397   MS_EXCEPTION_IF_NULL(tensor_type);
398   auto type = tensor_type->element();
399   MS_EXCEPTION_IF_NULL(type);
400   return type;
401 }
402 
ConvertShapePtrToShapeMap(const BaseShapePtr & shape)403 ShapeMap CheckAndConvertUtils::ConvertShapePtrToShapeMap(const BaseShapePtr &shape) {
404   MS_EXCEPTION_IF_NULL(shape);
405   if (!shape->isa<abstract::Shape>()) {
406     return std::map<std::string, std::vector<int64_t>>();
407   }
408   auto shape_element = shape->cast<abstract::ShapePtr>();
409   MS_EXCEPTION_IF_NULL(shape_element);
410   ShapeMap shape_map;
411   shape_map[kShape] = shape_element->shape();
412   shape_map[kMinShape] = shape_element->min_shape();
413   shape_map[kMaxShape] = shape_element->max_shape();
414   return shape_map;
415 }
416 
GetTensorInputShape(const std::string & prim_name,const std::vector<AbstractBasePtr> & input_args,int64_t index)417 abstract::ShapePtr CheckAndConvertUtils::GetTensorInputShape(const std::string &prim_name,
418                                                              const std::vector<AbstractBasePtr> &input_args,
419                                                              int64_t index) {
420   auto abstract = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, LongToSize(index));
421   MS_EXCEPTION_IF_NULL(abstract);
422   auto base_shape = abstract->BuildShape();
423   MS_EXCEPTION_IF_NULL(base_shape);
424   if (!base_shape->isa<abstract::Shape>()) {
425     MS_LOG(EXCEPTION) << prim_name << " can not get shape for input " << index;
426   }
427   auto shape = base_shape->cast<abstract::ShapePtr>();
428   MS_EXCEPTION_IF_NULL(shape);
429   return shape;
430 }
431 
Check(const string & arg_name,int64_t arg_value,CompareEnum compare_type,const string &,int64_t value,const string & prim_name,ExceptionType)432 void CheckAndConvertUtils::Check(const string &arg_name, int64_t arg_value, CompareEnum compare_type, const string &,
433                                  int64_t value, const string &prim_name, ExceptionType) {
434   auto iter = kCompareMap<float>.find(compare_type);
435   if (iter == kCompareMap<float>.end()) {
436     MS_EXCEPTION(NotExistsError) << "the compare type :" << compare_type << " is not in the compare map";
437   }
438   if (iter->second(arg_value, value)) {
439     return;
440   }
441   std::ostringstream buffer;
442   if (prim_name.empty()) {
443     buffer << "The attribute[" << arg_name << "] must ";
444   } else {
445     buffer << "The primitive[" << prim_name << "]'s attribute[" << arg_name << "] must ";
446   }
447   auto iter_to_string = kCompareToString.find(compare_type);
448   if (iter_to_string == kCompareToString.end()) {
449     MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_type << " cannot find in the compare string map";
450   }
451   buffer << iter_to_string->second << value << ", but got " << arg_value << ".";
452   MS_EXCEPTION(ValueError) << buffer.str();
453 }
454 
CheckTensorTypeSame(const std::map<std::string,TypePtr> & types,const std::set<TypePtr> & check_list,const std::string & prim_name)455 TypePtr CheckAndConvertUtils::CheckTensorTypeSame(const std::map<std::string, TypePtr> &types,
456                                                   const std::set<TypePtr> &check_list, const std::string &prim_name) {
457   if (types.empty()) {
458     MS_EXCEPTION(ArgumentError) << "Trying to use the function to check a empty types map!";
459   }
460   for (const auto &item : types) {
461     auto type = item.second;
462     MS_EXCEPTION_IF_NULL(type);
463     if (!type->isa<TensorType>()) {
464       std::ostringstream buffer;
465       buffer << "The primitive[" << prim_name << "]'s input arguments must be all tensor.\n";
466       if (!check_list.empty()) {
467         buffer << "Valid type list: {";
468         for (auto const &valid_type : check_list) {
469           if (valid_type->isa<TensorType>()) {
470             buffer << valid_type->ToString() << ", ";
471             break;
472           }
473           buffer << "Tensor[" << valid_type << "]"
474                  << ", ";
475         }
476         buffer << "}.\n";
477       }
478       for (const auto &type_info : types) {
479         buffer << "input argument[" << type_info.first << "]"
480                << ":" << type_info.second->ToString() << "\n";
481       }
482       MS_EXCEPTION(TypeError) << buffer.str();
483     }
484   }
485   auto check_type = _CheckTypeSame(types, prim_name, false);
486   std::string input_names = "";
487   for (const auto &item : types) {
488     (void)input_names.append(item.first);
489     (void)input_names.append(", ");
490   }
491   return CheckSubClass(input_names, check_type, check_list, prim_name);
492 }
493 
// Validates that `type` is a TensorType accepted by `check_list`.
// Raises TypeError when `type` is not a tensor. When `check_list` contains a TensorType whose
// element is null, the tensor's element type is returned without further checking; otherwise the
// decision is delegated to CheckSubClass.
TypePtr CheckAndConvertUtils::CheckTensorTypeValid(const std::string &type_name, const TypePtr &type,
                                                   const std::set<TypePtr> &check_list, const std::string &prim_name) {
  MS_EXCEPTION_IF_NULL(type);
  if (!type->isa<TensorType>()) {
    MS_EXCEPTION(TypeError) << "The Primitive[" << prim_name << "] input argument[" << type_name
                            << "] must be a Tensor but got " << type->ToString() << ".";
  }
  auto element_type = type->cast<TensorTypePtr>()->element();
  MS_EXCEPTION_IF_NULL(element_type);
  const bool has_untyped_tensor = std::any_of(check_list.begin(), check_list.end(), [](const TypePtr &candidate) {
    return candidate->isa<TensorType>() && candidate->cast<TensorTypePtr>()->element() == nullptr;
  });
  if (has_untyped_tensor) {
    return element_type;
  }
  return CheckSubClass(type_name, type, check_list, prim_name);
}
514 
CheckTensorIntValue(const std::string & type_name,const ValuePtr & value,const std::string & prim_name)515 ShapeVector CheckAndConvertUtils::CheckTensorIntValue(const std::string &type_name, const ValuePtr &value,
516                                                       const std::string &prim_name) {
517   if (value == nullptr) {
518     MS_EXCEPTION(ValueError) << "The primitive[" << prim_name << "] input argument[" << type_name
519                              << "] value is nullptr.";
520   }
521   ShapeVector tensor_value;
522   if (!value->isa<tensor::Tensor>()) {
523     MS_EXCEPTION(ValueError) << "The primitive[" << prim_name << "] input argument[" << type_name
524                              << "] must be a tensor,but got " << value->ToString();
525   }
526   auto input_tensor = value->cast<tensor::TensorPtr>();
527   MS_EXCEPTION_IF_NULL(input_tensor);
528   size_t data_size = LongToSize(input_tensor->DataSize());
529   auto tensor_type = input_tensor->Dtype();
530   if (tensor_type->type_id() == kNumberTypeInt32) {
531     auto data_c = reinterpret_cast<int *>(input_tensor->data_c());
532     MS_EXCEPTION_IF_NULL(data_c);
533     for (size_t i = 0; i < data_size; i++) {
534       tensor_value.push_back(static_cast<int64_t>(*data_c));
535       ++data_c;
536     }
537   } else if (tensor_type->type_id() == kNumberTypeInt64) {
538     auto tensor_data = reinterpret_cast<int64_t *>(input_tensor->data_c());
539     MS_EXCEPTION_IF_NULL(tensor_data);
540     tensor_value = {tensor_data, tensor_data + data_size};
541   } else {
542     MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "] input argument[" << type_name
543                             << "] must be a Tensor[Int64] or Tensor[Int32] type,but got " << value->ToString();
544   }
545   return tensor_value;
546 }
547 
// Checks that `type` is identical to, or a subclass of, one of `template_types`.
//
// The type itself is tried first and returned on a match. Otherwise, when it is a TensorType, its
// element type is tried and returned on a match. If neither matches, a TypeError naming the accepted
// types is raised. `type` is validated with MS_EXCEPTION_IF_NULL before it is dereferenced.
TypePtr CheckAndConvertUtils::CheckSubClass(const std::string &type_name, const TypePtr &type,
                                            const std::set<TypePtr> &template_types, const std::string &prim_name) {
  MS_EXCEPTION_IF_NULL(type);
  const auto is_accepted = [&template_types](const TypePtr &candidate) -> bool {
    return std::any_of(template_types.begin(), template_types.end(),
                       [&candidate](const TypePtr &accept) -> bool { return IsIdentidityOrSubclass(candidate, accept); });
  };

  TypePtr check_type = type;
  if (is_accepted(check_type)) {
    return check_type;
  }
  if (type->isa<TensorType>()) {
    // Fall back to the tensor's element type.
    check_type = type->cast<TensorTypePtr>()->element();
  }
  if (is_accepted(check_type)) {
    return check_type;
  }

  std::ostringstream buffer;
  buffer << "Primitive[" << prim_name << "]'s input argument[" << type_name << "] must be a type of ";
  buffer << GetErrorTypeString(template_types, type) << ", but got " << type->ToString();
  buffer << ".";
  MS_EXCEPTION(TypeError) << buffer.str();
}
573 
CheckScalarOrTensorTypesSame(const std::map<std::string,TypePtr> & args,const std::set<TypePtr> & valid_values,const std::string & prim_name,const bool allow_mix)574 TypePtr CheckAndConvertUtils::CheckScalarOrTensorTypesSame(const std::map<std::string, TypePtr> &args,
575                                                            const std::set<TypePtr> &valid_values,
576                                                            const std::string &prim_name, const bool allow_mix) {
577   auto arg_ = _CheckTypeSame(args, prim_name, allow_mix);
578   return CheckTypeValid(args.begin()->first, arg_, valid_values, prim_name);
579 }
580 
_CheckTypeSame(const std::map<std::string,TypePtr> & args,const std::string & prim_name,const bool allow_mix)581 TypePtr CheckAndConvertUtils::_CheckTypeSame(const std::map<std::string, TypePtr> &args, const std::string &prim_name,
582                                              const bool allow_mix) {
583   if (args.empty()) {
584     MS_EXCEPTION(ArgumentError) << "Trying to use the function to check a empty types map!";
585   }
586   std::ostringstream buffer;
587   TypePtr return_type = args.begin()->second;
588   buffer << "The primitive[" << prim_name << "]";
589   bool tensor_flag = return_type->isa<TensorType>();
590   std::set<TypeId> types_id;
591   for (const auto &elem : args) {
592     auto type = elem.second;
593     MS_EXCEPTION_IF_NULL(type);
594     if (!allow_mix) {
595       // input must be all tensor or all other type
596       if ((tensor_flag && !type->isa<TensorType>()) || (!tensor_flag && type->isa<TensorType>())) {
597         buffer << "'s "
598                << "input type must be same.\n";
599         for (const auto &error_elem : args) {
600           buffer << "input argument[" << error_elem.first << "]:" << error_elem.second->ToString() << "\n";
601         }
602         MS_EXCEPTION(TypeError) << buffer.str();
603       }
604     }
605     if (type->isa<TensorType>()) {
606       auto tensor_type = type->cast<TensorTypePtr>();
607       auto element = tensor_type->element();
608       MS_EXCEPTION_IF_NULL(element);
609       if (!allow_mix) {
610         return_type = element;
611       } else {
612         return_type = tensor_type;
613       }
614       (void)types_id.emplace(element->type_id());
615     } else {
616       (void)types_id.emplace(type->type_id());
617     }
618     if (types_id.size() > 1) {
619       buffer << "'s input type must be same.\n";
620       for (const auto &item : args) {
621         buffer << "name:[" << item.first << "]:" << item.second->ToString() << ".\n";
622       }
623       MS_EXCEPTION(TypeError) << buffer.str();
624     }
625   }
626   return return_type->DeepCopy();
627 }
628 
// Validates `arg_type` against `valid_type`: tensor types go through CheckTensorTypeValid, every
// other type through CheckSubClass. Raises ArgumentError when `valid_type` is empty.
TypePtr CheckAndConvertUtils::CheckTypeValid(const std::string &arg_name, const TypePtr &arg_type,
                                             const std::set<TypePtr> &valid_type, const std::string &prim_name) {
  if (valid_type.empty()) {
    MS_EXCEPTION(ArgumentError) << "Trying to use the function to check a empty valid_type!";
  }
  MS_EXCEPTION_IF_NULL(arg_type);
  const bool is_tensor = arg_type->isa<TensorType>();
  return is_tensor ? CheckTensorTypeValid(arg_name, arg_type, valid_type, prim_name)
                   : CheckSubClass(arg_name, arg_type, valid_type, prim_name);
}
640 
CheckIrAttrtoOpAttr(const std::string & op_type,const std::string & attr_name,ValuePtr * const value)641 bool CheckAndConvertUtils::CheckIrAttrtoOpAttr(const std::string &op_type, const std::string &attr_name,
642                                                ValuePtr *const value) {
643   if (*value == nullptr) {
644     MS_LOG(DEBUG) << "value is nullptr! op_type = " << op_type << ", attr_name = " << attr_name;
645     return false;
646   }
647   if (op_type.empty() || attr_name.empty()) {
648     return false;
649   }
650   auto op_map = kIrAttrToOpAttr.find(op_type);
651   if (op_map == kIrAttrToOpAttr.end()) {
652     return false;
653   }
654   auto attr_func = op_map->second.find(attr_name);
655   if (attr_func == op_map->second.end()) {
656     return false;
657   }
658   *value = attr_func->second(*value);
659   MS_LOG(DEBUG) << "convert ir attr to op attr, name: " << op_type << ", attr: " << attr_name;
660   return true;
661 }
662 
CheckSummaryParam(const AbstractBasePtr & name,const AbstractBasePtr & value,const std::string & class_name)663 void CheckAndConvertUtils::CheckSummaryParam(const AbstractBasePtr &name, const AbstractBasePtr &value,
664                                              const std::string &class_name) {
665   MS_EXCEPTION_IF_NULL(name);
666   MS_EXCEPTION_IF_NULL(value);
667   CheckMode(class_name);
668   (void)CheckTypeValid("name", name->BuildType(), {kString}, class_name);
669   auto s = GetValue<std::string>(name->BuildValue());
670   if (s.empty()) {
671     MS_EXCEPTION(ValueError) << "The primitive[" << class_name << "]'s input argument[name] "
672                              << " cannot be an empty string.";
673   }
674   (void)CheckTypeValid("value", value->BuildType(), {kTensorType}, class_name);
675 }
676 
CheckMode(const std::string & class_name)677 void CheckAndConvertUtils::CheckMode(const std::string &class_name) {
678   auto ms_context = MsContext::GetInstance();
679   MS_EXCEPTION_IF_NULL(ms_context);
680   if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
681     MS_EXCEPTION(NotSupportError) << "The primitive[" << class_name << "] does not support PyNativeMode.\n"
682                                   << "Please convert the mode to GraphMode";
683   }
684 }
685 
CheckAttrIntOrTupleInt(const std::string & arg_name,const ValuePtr & attr,const std::string & prim_name)686 std::vector<int64_t> CheckAndConvertUtils::CheckAttrIntOrTupleInt(const std::string &arg_name, const ValuePtr &attr,
687                                                                   const std::string &prim_name) {
688   std::vector<int64_t> result;
689   bool is_correct = false;
690   MS_EXCEPTION_IF_NULL(attr);
691   if (attr->isa<ValueTuple>()) {
692     std::vector<ValuePtr> attr_vec = attr->cast<ValueTuplePtr>()->value();
693     is_correct = std::all_of(attr_vec.begin(), attr_vec.end(), [&result](const ValuePtr &e) -> bool {
694       MS_EXCEPTION_IF_NULL(e);
695       if (e->isa<Int64Imm>()) {
696         (void)result.emplace_back(GetValue<int64_t>(e));
697         return true;
698       }
699       return false;
700     });
701   } else {
702     if (attr->isa<Int64Imm>()) {
703       is_correct = true;
704       int64_t attr_val = attr->cast<Int64ImmPtr>()->value();
705       result.push_back(attr_val);
706     }
707   }
708   if (!is_correct) {
709     MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s attribute[" << arg_name
710                             << "] must be a Int or a tuple with all Int elements, but got " << attr->ToString();
711   }
712   return result;
713 }
714 
CheckAttrTupleInt(const std::string & arg_name,const ValuePtr & attr,const std::string & prim_name)715 std::vector<int64_t> CheckAndConvertUtils::CheckAttrTupleInt(const std::string &arg_name, const ValuePtr &attr,
716                                                              const std::string &prim_name) {
717   std::vector<int64_t> result;
718   MS_EXCEPTION_IF_NULL(attr);
719   if (attr->isa<ValueTuple>()) {
720     std::vector<ValuePtr> attr_vec = attr->cast<ValueTuplePtr>()->value();
721     (void)std::transform(
722       attr_vec.begin(), attr_vec.end(), std::back_inserter(result), [=](const ValuePtr &e) -> int64_t {
723         if (!e->isa<Int64Imm>()) {
724           MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s attribute[" << arg_name
725                                   << "] must be a tuple with all Int elements, but got " << attr->ToString();
726         }
727         return GetValue<int64_t>(e);
728       });
729   } else {
730     MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s attribute[" << arg_name
731                             << "] must be a tuple with all Int elements, but got " << attr->ToString() << ".";
732   }
733   return result;
734 }
735 
CheckMinMaxShape(const ShapeVector & shape,ShapeVector * min_shape,ShapeVector * max_shape)736 void CheckAndConvertUtils::CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVector *max_shape) {
737   *min_shape = (*min_shape).empty() ? shape : *min_shape;
738   *max_shape = (*max_shape).empty() ? shape : *max_shape;
739 }
740 
GetAndCheckFormat(const ValuePtr & value)741 int64_t CheckAndConvertUtils::GetAndCheckFormat(const ValuePtr &value) {
742   int64_t data_format;
743   bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value, &data_format);
744   if (!result || (data_format != Format::NHWC && data_format != Format::NCHW && data_format != Format::NCDHW)) {
745     MS_LOG(EXCEPTION) << "data format is invalid, only support NCHW, NHWC and NCDHW";
746   }
747   return data_format;
748 }
GetRemoveMonadAbsNum(const AbstractBasePtrList & abs_list)749 size_t CheckAndConvertUtils::GetRemoveMonadAbsNum(const AbstractBasePtrList &abs_list) {
750   size_t remove_monad_count = abs_list.size();
751   for (const auto &item : abs_list) {
752     if (item->isa<abstract::AbstractMonad>()) {
753       --remove_monad_count;
754     }
755   }
756 
757   for (size_t i = 0; i < remove_monad_count; ++i) {
758     if (abs_list[i]->isa<abstract::AbstractMonad>()) {
759       MS_EXCEPTION(UnknownError) << "The monad inputs of the node must at last of the node inputs.";
760     }
761   }
762   return remove_monad_count;
763 }
764 
HasDynamicShapeInput(const AbstractBasePtrList & abs_list)765 bool CheckAndConvertUtils::HasDynamicShapeInput(const AbstractBasePtrList &abs_list) {
766   for (const auto &item : abs_list) {
767     MS_EXCEPTION_IF_NULL(item);
768     auto shape = item->BuildShape();
769     if (shape->IsDynamic()) {
770       return true;
771     }
772   }
773   return false;
774 }
775 
GetErrorTypeString(const std::set<TypePtr> & check_list,const TypePtr & check_type)776 std::string CheckAndConvertUtils::GetErrorTypeString(const std::set<TypePtr> &check_list, const TypePtr &check_type) {
777   std::ostringstream buffer;
778   buffer << "{";
779   // got tensor type list
780   for (const auto &item : check_list) {
781     if (item->isa<TensorType>()) {
782       buffer << item->ToString();
783       buffer << ", ";
784       continue;
785     }
786     buffer << "Tensor[" << item->ToString() << "], ";
787   }
788   if (check_type->isa<TensorType>()) {
789     buffer << "}";
790     return buffer.str();
791   }
792   // got python type
793   std::set<std::string> type_string;
794   for (const auto &item : check_list) {
795     if (item->isa<Float>()) {
796       (void)type_string.emplace("Float");
797     }
798     if (item->isa<Int>()) {
799       (void)type_string.emplace("Int");
800     }
801     if (item->isa<Bool>()) {
802       (void)type_string.emplace("Bool");
803     }
804     if (item->isa<UInt>()) {
805       (void)type_string.emplace("UInt");
806     }
807   }
808   for (const auto &item : type_string) {
809     buffer << item << ",";
810   }
811   buffer << "}";
812   return buffer.str();
813 }
814 }  // namespace mindspore
815