1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/check_convert_utils.h"
18
19 #include <utility>
20 #include <vector>
21 #include <algorithm>
22 #include <typeinfo>
23 #include <functional>
24
25 #include "abstract/abstract_value.h"
26 #include "ops/op_utils.h"
27 #include "ir/dtype/type.h"
28 #include "ir/dtype/tensor_type.h"
29 #include "ir/dtype.h"
30 #include "utils/ms_context.h"
31
32 namespace mindspore {
// Lookup table from a data-format name, as written in primitive attributes, to its Format enum value.
// Read by GetDataFormatEnumValue and, through DataFormatConverter, by the attribute load/export conversion.
static std::map<std::string, int64_t> DataFormatToEnumMap = {
  {"NCHW", Format::NCHW}, {"NHWC", Format::NHWC}, {"NHWC4", Format::NHWC4},
  {"HWKC", Format::HWKC}, {"HWCK", Format::HWCK}, {"KCHW", Format::KCHW},
  {"CKHW", Format::CKHW}, {"KHWC", Format::KHWC}, {"CHWK", Format::CHWK},
  {"HW", Format::HW}, {"HW4", Format::HW4}, {"NC", Format::NC},
  {"NC4", Format::NC4}, {"NC4HW4", Format::NC4HW4}, {"NUM_OF_FORMAT", Format::NUM_OF_FORMAT},
  {"NCDHW", Format::NCDHW}, {"NWC", Format::NWC}, {"NCW", Format::NCW},
};

// Inverse of DataFormatToEnumMap (Format enum value -> name), used when exporting attributes as strings.
// NOTE(review): both tables are maintained by hand; every entry added to one must be mirrored in the other.
static std::map<int64_t, std::string> DataFormatToStrMap = {
  {Format::NCHW, "NCHW"}, {Format::NHWC, "NHWC"}, {Format::NHWC4, "NHWC4"},
  {Format::HWKC, "HWKC"}, {Format::HWCK, "HWCK"}, {Format::KCHW, "KCHW"},
  {Format::CKHW, "CKHW"}, {Format::KHWC, "KHWC"}, {Format::CHWK, "CHWK"},
  {Format::HW, "HW"}, {Format::HW4, "HW4"}, {Format::NC, "NC"},
  {Format::NC4, "NC4"}, {Format::NC4HW4, "NC4HW4"}, {Format::NUM_OF_FORMAT, "NUM_OF_FORMAT"},
  {Format::NCDHW, "NCDHW"}, {Format::NWC, "NWC"}, {Format::NCW, "NCW"},
};
50
// Lower-case reduction name -> Reduction enum value (read by GetReductionEnumValue and ReductionConverter).
static std::map<std::string, int64_t> ReductionToEnumMap = {
  {"sum", Reduction::REDUCTION_SUM},
  {"mean", Reduction::MEAN},
  {"none", Reduction::NONE},
};

// Inverse of ReductionToEnumMap (enum value -> lower-case name).
static std::map<int64_t, std::string> ReductionToStrMap = {
  {Reduction::REDUCTION_SUM, "sum"},
  {Reduction::MEAN, "mean"},
  {Reduction::NONE, "none"},
};

// Lower-case pad-mode name -> PadMode enum value (the default table of GetPadModEnumValue).
static std::map<std::string, int64_t> PadModToEnumMap = {
  {"pad", PadMode::PAD},
  {"same", PadMode::SAME},
  {"valid", PadMode::VALID},
};

// Inverse of PadModToEnumMap (enum value -> lower-case name).
static std::map<int64_t, std::string> PadModToStrMap = {
  {PadMode::PAD, "pad"},
  {PadMode::SAME, "same"},
  {PadMode::VALID, "valid"},
};

// Upper-case spelling of the pad modes; selected by GetPadModEnumValue when `is_upper` is true and
// used by PadModeUpperConverter for the primitives registered with FormatAndPadUpperAttrMap.
static std::map<std::string, int64_t> PadModToEnumUpperMap = {
  {"PAD", PadMode::PAD},
  {"SAME", PadMode::SAME},
  {"VALID", PadMode::VALID},
};

// Inverse of PadModToEnumUpperMap (enum value -> upper-case name).
static std::map<int64_t, std::string> PadModToStrUpperMap = {
  {PadMode::PAD, "PAD"},
  {PadMode::SAME, "SAME"},
  {PadMode::VALID, "VALID"},
};
86
// Each converter pairs a name->enum table (`first`) with its enum->name inverse (`second`).
AttrConverterPair DataFormatConverter(DataFormatToEnumMap, DataFormatToStrMap);
AttrConverterPair PadModeConverter(PadModToEnumMap, PadModToStrMap);
AttrConverterPair PadModeUpperConverter(PadModToEnumUpperMap, PadModToStrUpperMap);
AttrConverterPair ReductionConverter(ReductionToEnumMap, ReductionToStrMap);

// Attribute-name -> converter tables shared by several primitives below.
// "format" plus lower-case pad modes.
static std::map<std::string, AttrConverterPair> FormatAndPadAttrMap = {
  {ops::kFormat, DataFormatConverter},
  {ops::kPadMode, PadModeConverter},
};

// "format" plus upper-case pad modes.
static std::map<std::string, AttrConverterPair> FormatAndPadUpperAttrMap = {
  {ops::kFormat, DataFormatConverter},
  {ops::kPadMode, PadModeUpperConverter},
};

// "format" only.
static std::map<std::string, AttrConverterPair> DataFormatMap = {
  {ops::kFormat, DataFormatConverter},
};

// "reduction" only.
static std::map<std::string, AttrConverterPair> ReductionMap = {
  {ops::kReduction, ReductionConverter},
};

// Registry consulted by GetAttrConvertPair: op type -> (attribute name -> converter pair).
// Attributes of op types not listed here are never converted between string and enum form.
static std::map<std::string, std::map<std::string, AttrConverterPair>> PrimAttrConvertMap = {
  {"Conv2D", FormatAndPadAttrMap},
  {"Conv2DTranspose", FormatAndPadUpperAttrMap},
  {"Conv2DBackpropInput", FormatAndPadUpperAttrMap},
  {"Conv2DBackpropFilter", FormatAndPadUpperAttrMap},
  {"Conv3D", FormatAndPadAttrMap},
  {"Conv3DBackpropInput", FormatAndPadAttrMap},
  {"Conv3DBackpropFilter", FormatAndPadAttrMap},
  {"Conv3DTranspose", DataFormatMap},
  {"DepthwiseConv2dNative", FormatAndPadAttrMap},
  {"DepthwiseConv2dNativeBackpropInput", FormatAndPadAttrMap},
  {"DepthwiseConv2dNativeBackpropFilter", FormatAndPadAttrMap},
  {"AvgPool", FormatAndPadUpperAttrMap},
  {"MaxPool", FormatAndPadUpperAttrMap},
  {"MaxPoolWithArgmax", FormatAndPadUpperAttrMap},
  {"AvgPoolGrad", FormatAndPadUpperAttrMap},
  {"AvgPoolGradVm", FormatAndPadUpperAttrMap},
  {"AvgPoolGradGpu", FormatAndPadUpperAttrMap},
  {"AvgPoolGradCpu", FormatAndPadUpperAttrMap},
  {"MaxPoolGrad", FormatAndPadUpperAttrMap},
  {"MaxPoolGradGrad", FormatAndPadUpperAttrMap},
  {"MaxPoolGradWithArgmax", FormatAndPadUpperAttrMap},
  {"MaxPoolGradGradWithArgmax", FormatAndPadUpperAttrMap},
  {"BatchNorm", DataFormatMap},
  {"BatchNormGrad", DataFormatMap},
  {"BiasAdd", DataFormatMap},
  {"BiasAddGrad", DataFormatMap},
  {"BinaryCrossEntropy", ReductionMap},
  {"BinaryCrossEntropyGrad", ReductionMap},
  {"NLLLoss", ReductionMap},
  {"DepthToSpace", DataFormatMap},
  {"Pooling", DataFormatMap},
  {"Deconvolution", DataFormatMap},
  {"AvgPoolV2", DataFormatMap},
  {"MaxPoolV3", DataFormatMap},
  {"FusedBatchNorm", DataFormatMap}};
146
GetDataFormatEnumValue(const ValuePtr & value,int64_t * enum_value)147 bool CheckAndConvertUtils::GetDataFormatEnumValue(const ValuePtr &value, int64_t *enum_value) {
148 MS_EXCEPTION_IF_NULL(value);
149 if (value->isa<StringImm>()) {
150 auto attr_value_str = GetValue<std::string>(value);
151 if (DataFormatToEnumMap.find(attr_value_str) == DataFormatToEnumMap.end()) {
152 MS_LOG(DEBUG) << "The data format " << attr_value_str << " not be converted to enum.";
153 return false;
154 }
155 *enum_value = DataFormatToEnumMap[attr_value_str];
156 return true;
157 } else {
158 *enum_value = GetValue<int64_t>(value);
159 return true;
160 }
161 }
162
GetPadModEnumValue(const ValuePtr & value,int64_t * enum_value,bool is_upper)163 void CheckAndConvertUtils::GetPadModEnumValue(const ValuePtr &value, int64_t *enum_value, bool is_upper) {
164 MS_EXCEPTION_IF_NULL(value);
165 if (value->isa<StringImm>()) {
166 auto attr_value_str = GetValue<std::string>(value);
167
168 std::map<std::string, int64_t> pad_map = PadModToEnumMap;
169 if (is_upper) {
170 pad_map = PadModToEnumUpperMap;
171 }
172 if (pad_map.find(attr_value_str) == pad_map.end()) {
173 MS_LOG(EXCEPTION) << "Invalid pad mode " << attr_value_str << " use pad, valid or same";
174 }
175 *enum_value = pad_map[attr_value_str];
176 } else {
177 *enum_value = GetValue<int64_t>(value);
178 }
179 }
180
GetReductionEnumValue(const ValuePtr & value,int64_t * enum_value)181 void CheckAndConvertUtils::GetReductionEnumValue(const ValuePtr &value, int64_t *enum_value) {
182 MS_EXCEPTION_IF_NULL(value);
183 if (value->isa<StringImm>()) {
184 auto attr_value_str = GetValue<std::string>(value);
185
186 std::map<std::string, int64_t> pad_map = ReductionToEnumMap;
187 if (pad_map.find(attr_value_str) == pad_map.end()) {
188 MS_LOG(EXCEPTION) << "Invalid pad mode " << attr_value_str << " use pad, valid or same";
189 }
190 *enum_value = pad_map[attr_value_str];
191 } else {
192 *enum_value = GetValue<int64_t>(value);
193 }
194 }
195
GetAttrConvertPair(const std::string & op_type,const std::string & attr_name)196 AttrConverterPair CheckAndConvertUtils::GetAttrConvertPair(const std::string &op_type, const std::string &attr_name) {
197 AttrConverterPair attr_pair;
198 if (op_type.empty() || attr_name.empty()) {
199 return attr_pair;
200 }
201 auto op_attr_map_it = PrimAttrConvertMap.find(op_type);
202 if (op_attr_map_it == PrimAttrConvertMap.end()) {
203 return attr_pair;
204 }
205 auto attr_pair_it = op_attr_map_it->second.find(attr_name);
206 if (attr_pair_it == op_attr_map_it->second.end()) {
207 return attr_pair;
208 }
209
210 return attr_pair_it->second;
211 }
212
ConvertAttrValueToInt(const std::string & op_type,const std::string & attr_name,ValuePtr * const value)213 bool CheckAndConvertUtils::ConvertAttrValueToInt(const std::string &op_type, const std::string &attr_name,
214 ValuePtr *const value) {
215 if (value == nullptr || *value == nullptr) {
216 MS_LOG(DEBUG) << "value of attr " << op_type << attr_name << " is nullptr.";
217 return false;
218 }
219 if (!(*value)->isa<StringImm>()) {
220 return false;
221 }
222 auto attr_map_pair = GetAttrConvertPair(op_type, attr_name);
223 if (attr_map_pair.first.size() == 0) {
224 return false;
225 }
226
227 std::string real_value = std::dynamic_pointer_cast<StringImm>(*value)->value();
228 bool do_convert = false;
229 if (attr_map_pair.first.find(real_value) != attr_map_pair.first.end()) {
230 do_convert = true;
231 }
232 if (!do_convert) {
233 transform(real_value.begin(), real_value.end(), real_value.begin(), ::toupper);
234 if (attr_map_pair.first.find(real_value) != attr_map_pair.first.end()) {
235 do_convert = true;
236 }
237 }
238 if (!do_convert) {
239 transform(real_value.begin(), real_value.end(), real_value.begin(), ::tolower);
240 if (attr_map_pair.first.find(real_value) == attr_map_pair.first.end()) {
241 MS_LOG(DEBUG) << "Can not convert " << op_type << " attr " << attr_name << ": " << real_value << " to int";
242 return false;
243 }
244 }
245 *value = MakeValue<int64_t>(attr_map_pair.first[real_value]);
246 MS_LOG(DEBUG) << "convert str to int, name: " << op_type << ", attr: " << attr_name;
247 return true;
248 }
249
ConvertAttrValueToString(const std::string & op_type,const std::string & attr_name,ValuePtr * const value)250 bool CheckAndConvertUtils::ConvertAttrValueToString(const std::string &op_type, const std::string &attr_name,
251 ValuePtr *const value) {
252 if (value == nullptr || *value == nullptr) {
253 MS_LOG(DEBUG) << "value is nullptr! op_type = " << op_type << ", attr_name = " << attr_name;
254 return false;
255 }
256 if (!(*value)->isa<Int64Imm>()) {
257 return false;
258 }
259 auto attr_map_pair = GetAttrConvertPair(op_type, attr_name);
260 if (attr_map_pair.second.size() == 0) {
261 return false;
262 }
263
264 int64_t real_value = std::dynamic_pointer_cast<Int64Imm>(*value)->value();
265 if (attr_map_pair.second.find(real_value) == attr_map_pair.second.end()) {
266 MS_LOG(DEBUG) << "Can not convert " << op_type << " attr " << attr_name << ": " << real_value << " to string";
267 return false;
268 }
269 *value = MakeValue<std::string>(attr_map_pair.second[real_value]);
270 MS_LOG(DEBUG) << "convert int to str, name: " << op_type << ", attr: " << attr_name;
271 return true;
272 }
273
ConvertAttrValueInExport(const std::string & op_type,const std::string & attr_name,ValuePtr * const value)274 void CheckAndConvertUtils::ConvertAttrValueInExport(const std::string &op_type, const std::string &attr_name,
275 ValuePtr *const value) {
276 if (value == nullptr || *value == nullptr) {
277 MS_LOG(DEBUG) << "value is nullptr! op_type = " << op_type << ", attr_name = " << attr_name;
278 return;
279 }
280 // convert enum to string
281 ConvertAttrValueToString(op_type, attr_name, value);
282 }
283
ConvertAttrValueInLoad(const std::string & op_type,const std::string & attr_name,ValuePtr * const value)284 void CheckAndConvertUtils::ConvertAttrValueInLoad(const std::string &op_type, const std::string &attr_name,
285 ValuePtr *const value) {
286 if (value == nullptr || *value == nullptr) {
287 MS_LOG(DEBUG) << "value is nullptr! op_type = " << op_type << ", attr_name = " << attr_name;
288 return;
289 }
290 // convert string to enum
291 ConvertAttrValueToInt(op_type, attr_name, value);
292 }
293
294 namespace {
295 typedef std::map<std::string, std::function<ValuePtr(ValuePtr)>> AttrFunction;
296
L2NormalizeAttrConversion(ValuePtr attr)297 ValuePtr L2NormalizeAttrConversion(ValuePtr attr) {
298 if (attr->isa<Int64Imm>()) {
299 return attr;
300 }
301 auto attr_value = GetValue<std::vector<int64_t>>(attr);
302 return MakeValue(attr_value[0]);
303 }
304
305 std::map<std::string, AttrFunction> kIrAttrToOpAttr = {{"L2Normalize", {{"axis", L2NormalizeAttrConversion}}},
306 {"L2NormalizeGrad", {{"axis", L2NormalizeAttrConversion}}}};
307 } // namespace
308
CheckPositiveVector(const std::string & arg_name,const std::vector<int64_t> & arg_value,const std::string & prim_name)309 std::vector<int64_t> CheckAndConvertUtils::CheckPositiveVector(const std::string &arg_name,
310 const std::vector<int64_t> &arg_value,
311 const std::string &prim_name) {
312 std::ostringstream buffer;
313 buffer << "The primitive[" << prim_name << "]'s attribute[" << arg_name
314 << "] should be a vector with all positive item. but got [";
315 if (std::any_of(arg_value.begin(), arg_value.end(), [](int64_t item) { return item < 0; })) {
316 for (auto item : arg_value) {
317 buffer << item << ", ";
318 }
319 buffer << "].";
320 MS_EXCEPTION(ValueError) << buffer.str();
321 }
322
323 return arg_value;
324 }
325
CheckString(const std::string & arg_name,const std::string & arg_value,const std::set<std::string> & check_list,const std::string & prim_name)326 std::string CheckAndConvertUtils::CheckString(const std::string &arg_name, const std::string &arg_value,
327 const std::set<std::string> &check_list, const std::string &prim_name) {
328 if (check_list.find(arg_value) != check_list.end()) {
329 return arg_value;
330 }
331 std::ostringstream buffer;
332 buffer << "The primitive[" << prim_name << "]'s attribute[" << arg_name << "]";
333 if (check_list.size() == 1) {
334 buffer << " must be \"" << (*check_list.begin()) << "\",but got \"" << arg_value << "\".";
335 MS_EXCEPTION(ValueError) << buffer.str();
336 }
337 buffer << " should be a element of {";
338 for (const auto &item : check_list) {
339 buffer << "\"" << item << "\", ";
340 }
341 buffer << "}"
342 << ",but got \"" << arg_value << "\""
343 << ".";
344 MS_EXCEPTION(ValueError) << buffer.str();
345 }
346
CheckInteger(const std::string & arg_name,int64_t arg_value,CompareEnum compare_operator,int64_t match_value,const std::string & prim_name)347 int64_t CheckAndConvertUtils::CheckInteger(const std::string &arg_name, int64_t arg_value, CompareEnum compare_operator,
348 int64_t match_value, const std::string &prim_name) {
349 auto iter = kCompareMap<float>.find(compare_operator);
350 if (iter == kCompareMap<float>.end()) {
351 MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator << " cannot find in the compare map";
352 }
353 if (iter->second(arg_value, match_value)) {
354 return arg_value;
355 }
356 std::ostringstream buffer;
357 if (prim_name.empty()) {
358 buffer << "The argument[" << arg_name << "] must ";
359 } else {
360 buffer << "The primitive[" << prim_name << "]'s " << arg_name << " must ";
361 }
362 auto iter_to_string = kCompareToString.find(compare_operator);
363 if (iter_to_string == kCompareToString.end()) {
364 MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator << " cannot find in the compare string map";
365 }
366 buffer << iter_to_string->second << match_value << ", but got " << arg_value << ".";
367 MS_EXCEPTION(ValueError) << buffer.str();
368 }
369
CheckInputArgs(const std::vector<AbstractBasePtr> & input_args,const CompareEnum compare_operator,const int64_t match_value,const std::string & prim_name)370 void CheckAndConvertUtils::CheckInputArgs(const std::vector<AbstractBasePtr> &input_args,
371 const CompareEnum compare_operator, const int64_t match_value,
372 const std::string &prim_name) {
373 (void)CheckInteger("input number", SizeToLong(input_args.size()), compare_operator, match_value, prim_name);
374 for (size_t index = 0; index < input_args.size(); index++) {
375 if (input_args[index] == nullptr) {
376 MS_EXCEPTION(ValueError) << "The " << index << "'s input of " << prim_name << " is nullptr.";
377 }
378 }
379 }
380
GetInputTensorType(const std::vector<AbstractBasePtr> & input_args,const size_t index,const std::string & prim_name)381 TypePtr CheckAndConvertUtils::GetInputTensorType(const std::vector<AbstractBasePtr> &input_args, const size_t index,
382 const std::string &prim_name) {
383 if (input_args.size() <= index) {
384 MS_EXCEPTION(ValueError) << "The primitive[" << prim_name << "]'s input index[" << index
385 << "] is out of the input number " << input_args.size();
386 }
387 auto input_arg = input_args[index];
388 if (input_arg == nullptr) {
389 MS_EXCEPTION(ValueError) << "The primitive[" << prim_name << "]'s input index[" << index << "] is nullptr.";
390 }
391 auto base_type = input_arg->BuildType();
392 MS_EXCEPTION_IF_NULL(base_type);
393 if (!base_type->isa<TensorType>()) {
394 MS_EXCEPTION(ValueError) << "The primitive[" << prim_name << "]'s input index[" << index << "] is not a tensor.";
395 }
396 auto tensor_type = base_type->cast<TensorTypePtr>();
397 MS_EXCEPTION_IF_NULL(tensor_type);
398 auto type = tensor_type->element();
399 MS_EXCEPTION_IF_NULL(type);
400 return type;
401 }
402
ConvertShapePtrToShapeMap(const BaseShapePtr & shape)403 ShapeMap CheckAndConvertUtils::ConvertShapePtrToShapeMap(const BaseShapePtr &shape) {
404 MS_EXCEPTION_IF_NULL(shape);
405 if (!shape->isa<abstract::Shape>()) {
406 return std::map<std::string, std::vector<int64_t>>();
407 }
408 auto shape_element = shape->cast<abstract::ShapePtr>();
409 MS_EXCEPTION_IF_NULL(shape_element);
410 ShapeMap shape_map;
411 shape_map[kShape] = shape_element->shape();
412 shape_map[kMinShape] = shape_element->min_shape();
413 shape_map[kMaxShape] = shape_element->max_shape();
414 return shape_map;
415 }
416
GetTensorInputShape(const std::string & prim_name,const std::vector<AbstractBasePtr> & input_args,int64_t index)417 abstract::ShapePtr CheckAndConvertUtils::GetTensorInputShape(const std::string &prim_name,
418 const std::vector<AbstractBasePtr> &input_args,
419 int64_t index) {
420 auto abstract = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, LongToSize(index));
421 MS_EXCEPTION_IF_NULL(abstract);
422 auto base_shape = abstract->BuildShape();
423 MS_EXCEPTION_IF_NULL(base_shape);
424 if (!base_shape->isa<abstract::Shape>()) {
425 MS_LOG(EXCEPTION) << prim_name << " can not get shape for input " << index;
426 }
427 auto shape = base_shape->cast<abstract::ShapePtr>();
428 MS_EXCEPTION_IF_NULL(shape);
429 return shape;
430 }
431
Check(const string & arg_name,int64_t arg_value,CompareEnum compare_type,const string &,int64_t value,const string & prim_name,ExceptionType)432 void CheckAndConvertUtils::Check(const string &arg_name, int64_t arg_value, CompareEnum compare_type, const string &,
433 int64_t value, const string &prim_name, ExceptionType) {
434 auto iter = kCompareMap<float>.find(compare_type);
435 if (iter == kCompareMap<float>.end()) {
436 MS_EXCEPTION(NotExistsError) << "the compare type :" << compare_type << " is not in the compare map";
437 }
438 if (iter->second(arg_value, value)) {
439 return;
440 }
441 std::ostringstream buffer;
442 if (prim_name.empty()) {
443 buffer << "The attribute[" << arg_name << "] must ";
444 } else {
445 buffer << "The primitive[" << prim_name << "]'s attribute[" << arg_name << "] must ";
446 }
447 auto iter_to_string = kCompareToString.find(compare_type);
448 if (iter_to_string == kCompareToString.end()) {
449 MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_type << " cannot find in the compare string map";
450 }
451 buffer << iter_to_string->second << value << ", but got " << arg_value << ".";
452 MS_EXCEPTION(ValueError) << buffer.str();
453 }
454
CheckTensorTypeSame(const std::map<std::string,TypePtr> & types,const std::set<TypePtr> & check_list,const std::string & prim_name)455 TypePtr CheckAndConvertUtils::CheckTensorTypeSame(const std::map<std::string, TypePtr> &types,
456 const std::set<TypePtr> &check_list, const std::string &prim_name) {
457 if (types.empty()) {
458 MS_EXCEPTION(ArgumentError) << "Trying to use the function to check a empty types map!";
459 }
460 for (const auto &item : types) {
461 auto type = item.second;
462 MS_EXCEPTION_IF_NULL(type);
463 if (!type->isa<TensorType>()) {
464 std::ostringstream buffer;
465 buffer << "The primitive[" << prim_name << "]'s input arguments must be all tensor.\n";
466 if (!check_list.empty()) {
467 buffer << "Valid type list: {";
468 for (auto const &valid_type : check_list) {
469 if (valid_type->isa<TensorType>()) {
470 buffer << valid_type->ToString() << ", ";
471 break;
472 }
473 buffer << "Tensor[" << valid_type << "]"
474 << ", ";
475 }
476 buffer << "}.\n";
477 }
478 for (const auto &type_info : types) {
479 buffer << "input argument[" << type_info.first << "]"
480 << ":" << type_info.second->ToString() << "\n";
481 }
482 MS_EXCEPTION(TypeError) << buffer.str();
483 }
484 }
485 auto check_type = _CheckTypeSame(types, prim_name, false);
486 std::string input_names = "";
487 for (const auto &item : types) {
488 (void)input_names.append(item.first);
489 (void)input_names.append(", ");
490 }
491 return CheckSubClass(input_names, check_type, check_list, prim_name);
492 }
493
// Validates a tensor type against `check_list` and returns the accepted type.
// A bare TensorType (null element) in `check_list` accepts any tensor, and the tensor's element type is returned.
// Otherwise the check is delegated to CheckSubClass. Raises TypeError when `type` is not a tensor.
TypePtr CheckAndConvertUtils::CheckTensorTypeValid(const std::string &type_name, const TypePtr &type,
                                                   const std::set<TypePtr> &check_list, const std::string &prim_name) {
  MS_EXCEPTION_IF_NULL(type);
  if (!type->isa<TensorType>()) {
    MS_EXCEPTION(TypeError) << "The Primitive[" << prim_name << "] input argument[" << type_name
                            << "] must be a Tensor but got " << type->ToString() << ".";
  }
  auto tensor_type = type->cast<TensorTypePtr>();
  MS_EXCEPTION_IF_NULL(tensor_type);
  auto element = tensor_type->element();
  MS_EXCEPTION_IF_NULL(element);
  for (const TypePtr &item : check_list) {
    MS_EXCEPTION_IF_NULL(item);
    if (!item->isa<TensorType>()) {
      continue;
    }
    auto item_tensor_type = item->cast<TensorTypePtr>();
    MS_EXCEPTION_IF_NULL(item_tensor_type);
    if (item_tensor_type->element() == nullptr) {
      return element;
    }
  }
  return CheckSubClass(type_name, type, check_list, prim_name);
}
514
CheckTensorIntValue(const std::string & type_name,const ValuePtr & value,const std::string & prim_name)515 ShapeVector CheckAndConvertUtils::CheckTensorIntValue(const std::string &type_name, const ValuePtr &value,
516 const std::string &prim_name) {
517 if (value == nullptr) {
518 MS_EXCEPTION(ValueError) << "The primitive[" << prim_name << "] input argument[" << type_name
519 << "] value is nullptr.";
520 }
521 ShapeVector tensor_value;
522 if (!value->isa<tensor::Tensor>()) {
523 MS_EXCEPTION(ValueError) << "The primitive[" << prim_name << "] input argument[" << type_name
524 << "] must be a tensor,but got " << value->ToString();
525 }
526 auto input_tensor = value->cast<tensor::TensorPtr>();
527 MS_EXCEPTION_IF_NULL(input_tensor);
528 size_t data_size = LongToSize(input_tensor->DataSize());
529 auto tensor_type = input_tensor->Dtype();
530 if (tensor_type->type_id() == kNumberTypeInt32) {
531 auto data_c = reinterpret_cast<int *>(input_tensor->data_c());
532 MS_EXCEPTION_IF_NULL(data_c);
533 for (size_t i = 0; i < data_size; i++) {
534 tensor_value.push_back(static_cast<int64_t>(*data_c));
535 ++data_c;
536 }
537 } else if (tensor_type->type_id() == kNumberTypeInt64) {
538 auto tensor_data = reinterpret_cast<int64_t *>(input_tensor->data_c());
539 MS_EXCEPTION_IF_NULL(tensor_data);
540 tensor_value = {tensor_data, tensor_data + data_size};
541 } else {
542 MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "] input argument[" << type_name
543 << "] must be a Tensor[Int64] or Tensor[Int32] type,but got " << value->ToString();
544 }
545 return tensor_value;
546 }
547
// Returns `type` when it is identical to, or a subclass of, any entry of `template_types`.
// For a tensor type, the element type is tried next and returned on a match.
// Otherwise raises TypeError listing the accepted types.
TypePtr CheckAndConvertUtils::CheckSubClass(const std::string &type_name, const TypePtr &type,
                                            const std::set<TypePtr> &template_types, const std::string &prim_name) {
  auto matches_any = [&template_types](const TypePtr &candidate) -> bool {
    return std::any_of(template_types.begin(), template_types.end(),
                       [&candidate](const TypePtr &accept) -> bool { return IsIdentidityOrSubclass(candidate, accept); });
  };
  if (matches_any(type)) {
    return type;
  }
  TypePtr element_or_self = type;
  if (type->isa<TensorType>()) {
    element_or_self = type->cast<TensorTypePtr>()->element();
  }
  if (!matches_any(element_or_self)) {
    std::ostringstream buffer;
    buffer << "Primitive[" << prim_name << "]'s input argument[" << type_name << "] must be a type of ";
    buffer << GetErrorTypeString(template_types, type) << ", but got " << type->ToString();
    buffer << ".";
    MS_EXCEPTION(TypeError) << buffer.str();
  }
  return element_or_self;
}
573
CheckScalarOrTensorTypesSame(const std::map<std::string,TypePtr> & args,const std::set<TypePtr> & valid_values,const std::string & prim_name,const bool allow_mix)574 TypePtr CheckAndConvertUtils::CheckScalarOrTensorTypesSame(const std::map<std::string, TypePtr> &args,
575 const std::set<TypePtr> &valid_values,
576 const std::string &prim_name, const bool allow_mix) {
577 auto arg_ = _CheckTypeSame(args, prim_name, allow_mix);
578 return CheckTypeValid(args.begin()->first, arg_, valid_values, prim_name);
579 }
580
_CheckTypeSame(const std::map<std::string,TypePtr> & args,const std::string & prim_name,const bool allow_mix)581 TypePtr CheckAndConvertUtils::_CheckTypeSame(const std::map<std::string, TypePtr> &args, const std::string &prim_name,
582 const bool allow_mix) {
583 if (args.empty()) {
584 MS_EXCEPTION(ArgumentError) << "Trying to use the function to check a empty types map!";
585 }
586 std::ostringstream buffer;
587 TypePtr return_type = args.begin()->second;
588 buffer << "The primitive[" << prim_name << "]";
589 bool tensor_flag = return_type->isa<TensorType>();
590 std::set<TypeId> types_id;
591 for (const auto &elem : args) {
592 auto type = elem.second;
593 MS_EXCEPTION_IF_NULL(type);
594 if (!allow_mix) {
595 // input must be all tensor or all other type
596 if ((tensor_flag && !type->isa<TensorType>()) || (!tensor_flag && type->isa<TensorType>())) {
597 buffer << "'s "
598 << "input type must be same.\n";
599 for (const auto &error_elem : args) {
600 buffer << "input argument[" << error_elem.first << "]:" << error_elem.second->ToString() << "\n";
601 }
602 MS_EXCEPTION(TypeError) << buffer.str();
603 }
604 }
605 if (type->isa<TensorType>()) {
606 auto tensor_type = type->cast<TensorTypePtr>();
607 auto element = tensor_type->element();
608 MS_EXCEPTION_IF_NULL(element);
609 if (!allow_mix) {
610 return_type = element;
611 } else {
612 return_type = tensor_type;
613 }
614 (void)types_id.emplace(element->type_id());
615 } else {
616 (void)types_id.emplace(type->type_id());
617 }
618 if (types_id.size() > 1) {
619 buffer << "'s input type must be same.\n";
620 for (const auto &item : args) {
621 buffer << "name:[" << item.first << "]:" << item.second->ToString() << ".\n";
622 }
623 MS_EXCEPTION(TypeError) << buffer.str();
624 }
625 }
626 return return_type->DeepCopy();
627 }
628
// Validates `arg_type` against `valid_type`: tensors go through CheckTensorTypeValid, all other types
// through CheckSubClass. Raises ArgumentError when `valid_type` is empty.
TypePtr CheckAndConvertUtils::CheckTypeValid(const std::string &arg_name, const TypePtr &arg_type,
                                             const std::set<TypePtr> &valid_type, const std::string &prim_name) {
  if (valid_type.empty()) {
    MS_EXCEPTION(ArgumentError) << "Trying to use the function to check a empty valid_type!";
  }
  MS_EXCEPTION_IF_NULL(arg_type);
  return arg_type->isa<TensorType>() ? CheckTensorTypeValid(arg_name, arg_type, valid_type, prim_name)
                                     : CheckSubClass(arg_name, arg_type, valid_type, prim_name);
}
640
CheckIrAttrtoOpAttr(const std::string & op_type,const std::string & attr_name,ValuePtr * const value)641 bool CheckAndConvertUtils::CheckIrAttrtoOpAttr(const std::string &op_type, const std::string &attr_name,
642 ValuePtr *const value) {
643 if (*value == nullptr) {
644 MS_LOG(DEBUG) << "value is nullptr! op_type = " << op_type << ", attr_name = " << attr_name;
645 return false;
646 }
647 if (op_type.empty() || attr_name.empty()) {
648 return false;
649 }
650 auto op_map = kIrAttrToOpAttr.find(op_type);
651 if (op_map == kIrAttrToOpAttr.end()) {
652 return false;
653 }
654 auto attr_func = op_map->second.find(attr_name);
655 if (attr_func == op_map->second.end()) {
656 return false;
657 }
658 *value = attr_func->second(*value);
659 MS_LOG(DEBUG) << "convert ir attr to op attr, name: " << op_type << ", attr: " << attr_name;
660 return true;
661 }
662
CheckSummaryParam(const AbstractBasePtr & name,const AbstractBasePtr & value,const std::string & class_name)663 void CheckAndConvertUtils::CheckSummaryParam(const AbstractBasePtr &name, const AbstractBasePtr &value,
664 const std::string &class_name) {
665 MS_EXCEPTION_IF_NULL(name);
666 MS_EXCEPTION_IF_NULL(value);
667 CheckMode(class_name);
668 (void)CheckTypeValid("name", name->BuildType(), {kString}, class_name);
669 auto s = GetValue<std::string>(name->BuildValue());
670 if (s.empty()) {
671 MS_EXCEPTION(ValueError) << "The primitive[" << class_name << "]'s input argument[name] "
672 << " cannot be an empty string.";
673 }
674 (void)CheckTypeValid("value", value->BuildType(), {kTensorType}, class_name);
675 }
676
CheckMode(const std::string & class_name)677 void CheckAndConvertUtils::CheckMode(const std::string &class_name) {
678 auto ms_context = MsContext::GetInstance();
679 MS_EXCEPTION_IF_NULL(ms_context);
680 if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
681 MS_EXCEPTION(NotSupportError) << "The primitive[" << class_name << "] does not support PyNativeMode.\n"
682 << "Please convert the mode to GraphMode";
683 }
684 }
685
CheckAttrIntOrTupleInt(const std::string & arg_name,const ValuePtr & attr,const std::string & prim_name)686 std::vector<int64_t> CheckAndConvertUtils::CheckAttrIntOrTupleInt(const std::string &arg_name, const ValuePtr &attr,
687 const std::string &prim_name) {
688 std::vector<int64_t> result;
689 bool is_correct = false;
690 MS_EXCEPTION_IF_NULL(attr);
691 if (attr->isa<ValueTuple>()) {
692 std::vector<ValuePtr> attr_vec = attr->cast<ValueTuplePtr>()->value();
693 is_correct = std::all_of(attr_vec.begin(), attr_vec.end(), [&result](const ValuePtr &e) -> bool {
694 MS_EXCEPTION_IF_NULL(e);
695 if (e->isa<Int64Imm>()) {
696 (void)result.emplace_back(GetValue<int64_t>(e));
697 return true;
698 }
699 return false;
700 });
701 } else {
702 if (attr->isa<Int64Imm>()) {
703 is_correct = true;
704 int64_t attr_val = attr->cast<Int64ImmPtr>()->value();
705 result.push_back(attr_val);
706 }
707 }
708 if (!is_correct) {
709 MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s attribute[" << arg_name
710 << "] must be a Int or a tuple with all Int elements, but got " << attr->ToString();
711 }
712 return result;
713 }
714
CheckAttrTupleInt(const std::string & arg_name,const ValuePtr & attr,const std::string & prim_name)715 std::vector<int64_t> CheckAndConvertUtils::CheckAttrTupleInt(const std::string &arg_name, const ValuePtr &attr,
716 const std::string &prim_name) {
717 std::vector<int64_t> result;
718 MS_EXCEPTION_IF_NULL(attr);
719 if (attr->isa<ValueTuple>()) {
720 std::vector<ValuePtr> attr_vec = attr->cast<ValueTuplePtr>()->value();
721 (void)std::transform(
722 attr_vec.begin(), attr_vec.end(), std::back_inserter(result), [=](const ValuePtr &e) -> int64_t {
723 if (!e->isa<Int64Imm>()) {
724 MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s attribute[" << arg_name
725 << "] must be a tuple with all Int elements, but got " << attr->ToString();
726 }
727 return GetValue<int64_t>(e);
728 });
729 } else {
730 MS_EXCEPTION(TypeError) << "The primitive[" << prim_name << "]'s attribute[" << arg_name
731 << "] must be a tuple with all Int elements, but got " << attr->ToString() << ".";
732 }
733 return result;
734 }
735
CheckMinMaxShape(const ShapeVector & shape,ShapeVector * min_shape,ShapeVector * max_shape)736 void CheckAndConvertUtils::CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVector *max_shape) {
737 *min_shape = (*min_shape).empty() ? shape : *min_shape;
738 *max_shape = (*max_shape).empty() ? shape : *max_shape;
739 }
740
GetAndCheckFormat(const ValuePtr & value)741 int64_t CheckAndConvertUtils::GetAndCheckFormat(const ValuePtr &value) {
742 int64_t data_format;
743 bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value, &data_format);
744 if (!result || (data_format != Format::NHWC && data_format != Format::NCHW && data_format != Format::NCDHW)) {
745 MS_LOG(EXCEPTION) << "data format is invalid, only support NCHW, NHWC and NCDHW";
746 }
747 return data_format;
748 }
GetRemoveMonadAbsNum(const AbstractBasePtrList & abs_list)749 size_t CheckAndConvertUtils::GetRemoveMonadAbsNum(const AbstractBasePtrList &abs_list) {
750 size_t remove_monad_count = abs_list.size();
751 for (const auto &item : abs_list) {
752 if (item->isa<abstract::AbstractMonad>()) {
753 --remove_monad_count;
754 }
755 }
756
757 for (size_t i = 0; i < remove_monad_count; ++i) {
758 if (abs_list[i]->isa<abstract::AbstractMonad>()) {
759 MS_EXCEPTION(UnknownError) << "The monad inputs of the node must at last of the node inputs.";
760 }
761 }
762 return remove_monad_count;
763 }
764
HasDynamicShapeInput(const AbstractBasePtrList & abs_list)765 bool CheckAndConvertUtils::HasDynamicShapeInput(const AbstractBasePtrList &abs_list) {
766 for (const auto &item : abs_list) {
767 MS_EXCEPTION_IF_NULL(item);
768 auto shape = item->BuildShape();
769 if (shape->IsDynamic()) {
770 return true;
771 }
772 }
773 return false;
774 }
775
GetErrorTypeString(const std::set<TypePtr> & check_list,const TypePtr & check_type)776 std::string CheckAndConvertUtils::GetErrorTypeString(const std::set<TypePtr> &check_list, const TypePtr &check_type) {
777 std::ostringstream buffer;
778 buffer << "{";
779 // got tensor type list
780 for (const auto &item : check_list) {
781 if (item->isa<TensorType>()) {
782 buffer << item->ToString();
783 buffer << ", ";
784 continue;
785 }
786 buffer << "Tensor[" << item->ToString() << "], ";
787 }
788 if (check_type->isa<TensorType>()) {
789 buffer << "}";
790 return buffer.str();
791 }
792 // got python type
793 std::set<std::string> type_string;
794 for (const auto &item : check_list) {
795 if (item->isa<Float>()) {
796 (void)type_string.emplace("Float");
797 }
798 if (item->isa<Int>()) {
799 (void)type_string.emplace("Int");
800 }
801 if (item->isa<Bool>()) {
802 (void)type_string.emplace("Bool");
803 }
804 if (item->isa<UInt>()) {
805 (void)type_string.emplace("UInt");
806 }
807 }
808 for (const auto &item : type_string) {
809 buffer << item << ",";
810 }
811 buffer << "}";
812 return buffer.str();
813 }
814 } // namespace mindspore
815