1 /** 2 * Copyright 2019-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H_ 18 #define MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H_ 19 #include <vector> 20 #include <string> 21 #include <map> 22 #include <set> 23 #include <utility> 24 #include <typeinfo> 25 #include <memory> 26 #include "abstract/param_validator.h" 27 #include "base/base.h" 28 #include "ir/anf.h" 29 #include "include/api/format.h" 30 #include "utils/log_adapter.h" 31 #if __has_include("include/mindapi/base/types.h") 32 #include "include/mindapi/base/types.h" 33 #else 34 #include "mindapi/base/types.h" 35 #endif 36 37 namespace mindspore { 38 typedef std::pair<std::map<std::string, int64_t>, std::map<int64_t, std::string>> AttrConverterPair; 39 typedef std::map<std::string, std::vector<int64_t>> ShapeMap; 40 constexpr auto kShape = "shape"; 41 constexpr auto kMaxShape = "max_shape"; 42 43 enum CompareEnum : int64_t { 44 kEqual = 1, // == 45 kNotEqual = 2, // != 46 kLessThan = 3, // < 47 kLessEqual = 4, // <= 48 kGreaterThan = 5, // > 49 kGreaterEqual = 6, // >= 50 }; 51 52 enum CompareRange { 53 kIncludeNeither = 1, // (a,b) 54 kIncludeLeft = 2, // [a,b) 55 kIncludeRight = 3, // (a,b] 56 kIncludeBoth = 4, // [a,b] 57 }; 58 59 enum ReduceType : int64_t { 60 REDUCE_MAX = 0, 61 REDUCE_MEAN = 1, 62 REDUCE_ALL = 2, 63 REDUCE_ANY = 3, 64 REDUCE_LOG_SUM_EXP = 4, 65 REDUCE_PROD = 5, 66 REDUCE_SUM = 6, 67 REDUCE_UNKNOW = 7, 68 }; 69 70 enum GateOrderMode : int64_t { RZH = 0, ZRH = 1 }; 71 72 template <typename T> 73 const std::map<CompareEnum, std::function<bool(T, T)>> kCompareMap = { 74 {kEqual, [](T num1, T num2) -> bool { return num1 == num2; }}, 75 {kNotEqual, [](T num1, T num2) -> bool { return num1 != num2; }}, 76 {kLessThan, [](T num1, T num2) -> bool { return num1 < num2; }}, 77 {kLessEqual, [](T num1, T num2) -> bool { return num1 <= num2; }}, 78 {kGreaterThan, [](T num1, T num2) -> bool { return num1 > num2; }}, 79 {kGreaterEqual, [](T num1, T num2) -> bool { return num1 >= num2; }}}; 80 81 template <typename T> 82 const std::map<CompareRange, std::function<bool(T, std::pair<T, T>)>> kCompareRangeMap = { 83 {kIncludeNeither, [](T num1, std::pair<T, T> range) -> bool { return num1 > range.first && num1 < range.second; }}, 84 {kIncludeLeft, [](T num1, std::pair<T, T> range) -> bool { return num1 >= range.first && num1 < range.second; }}, 85 {kIncludeBoth, [](T num1, std::pair<T, T> range) -> bool { return num1 >= range.first && num1 <= range.second; }}, 86 {kIncludeRight, [](T num1, std::pair<T, T> range) -> bool { return num1 > range.first && num1 <= range.second; }}}; 87 88 const std::map<CompareEnum, std::string> kCompareToString = { 89 {kEqual, "be equal to "}, {kNotEqual, "be not equal to "}, 90 {kLessThan, "be less than "}, {kLessEqual, "be less than or equal to "}, 91 {kGreaterThan, "be greater than "}, {kGreaterEqual, "be greater than or equal to "}}; 92 93 const std::map<CompareRange, std::pair<std::string, std::string>> kCompareRangeToString = { 94 {kIncludeNeither, {"in (", ")"}}, 95 {kIncludeLeft, {"in [", ")"}}, 96 {kIncludeRight, {"in (", "]"}}, 97 {kIncludeBoth, {"in [", "]"}}}; 98 99 class MS_CORE_API CheckAndConvertUtils { 100 public: 101 template <typename T> CheckPositiveVector(const std::string & arg_name,const std::vector<T> & arg_value,const std::string & prim_name)102 static std::vector<T> CheckPositiveVector(const std::string &arg_name, const std::vector<T> &arg_value, 103 const std::string &prim_name) { 104 std::ostringstream buffer; 105 buffer << "For primitive[" << prim_name << "], the attribute[" << arg_name 106 << "] should be a vector with all positive item. but got ["; 107 if (std::any_of(arg_value.begin(), arg_value.end(), [](T item) { return item < T(0); })) { 108 for (auto item : arg_value) { 109 buffer << item << ", "; 110 } 111 buffer << "]."; 112 MS_EXCEPTION(ValueError) << buffer.str(); 113 } 114 115 return arg_value; 116 } 117 static int64_t CheckAttrInt64Positive(const std::string &op, const ValuePtr &attr, const std::string &attr_name); 118 static std::vector<int64_t> CheckAttrTuple(const PrimitivePtr &prim, const std::string &attr_name, 119 size_t num_element); 120 121 static std::string CheckString(const std::string &arg_name, const std::string &arg_value, 122 const std::set<std::string> &check_list, const std::string &prim_name); 123 124 // CheckValue should replace CheckInteger 125 static int64_t CheckInteger(const std::string &arg_name, int64_t arg_value, CompareEnum compare_operator, 126 int64_t match_value, const std::string &prim_name = ""); 127 128 template <class T, class U> FormatCheckIntegerMsg(const std::string & arg_name,T arg_value,CompareEnum compare_operator,U match_value,const PrimitivePtr & prim)129 static std::string FormatCheckIntegerMsg(const std::string &arg_name, T arg_value, CompareEnum compare_operator, 130 U match_value, const PrimitivePtr &prim) { 131 std::ostringstream buffer; 132 if (prim == nullptr) { 133 buffer << "The argument[" << arg_name << "] must "; 134 } else { 135 auto prim_name = prim->name(); 136 buffer << "For primitive[" << prim_name << "], the " << arg_name << " must "; 137 } 138 auto iter_to_string = kCompareToString.find(compare_operator); 139 if (iter_to_string == kCompareToString.end()) { 140 MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator 141 << " cannot find in the compare string map"; 142 } 143 buffer << iter_to_string->second << match_value << ", but got " << arg_value << "."; 144 return buffer.str(); 145 } 146 147 static std::string FormatCheckMsg(const std::string &arg_name, const std::vector<int64_t> &arg_value, 148 CompareEnum compare_type, const std::vector<int64_t> &value, 149 const PrimitivePtr &prim); 150 151 template <typename T> FormatCommMsg(T arg0)152 static std::string FormatCommMsg(T arg0) { 153 std::ostringstream buffer; 154 buffer << arg0; 155 return buffer.str(); 156 } 157 158 template <typename T, typename... Args> FormatCommMsg(T arg0,Args...args)159 static std::string FormatCommMsg(T arg0, Args... args) { 160 std::ostringstream buffer; 161 buffer << arg0 << FormatCommMsg(args...); 162 return buffer.str(); 163 } 164 165 template <typename T> FormatCheckInRangeMsg(const std::string & arg_name,T arg_value,CompareRange compare_operator,const std::pair<T,T> & range,const PrimitivePtr & prim)166 static std::string FormatCheckInRangeMsg(const std::string &arg_name, T arg_value, CompareRange compare_operator, 167 const std::pair<T, T> &range, const PrimitivePtr &prim) { 168 std::ostringstream buffer; 169 if (prim == nullptr) { 170 buffer << "The attribute[" << arg_name << "] must be "; 171 } else { 172 auto prim_name = prim->name(); 173 buffer << "For primitive[" << prim_name << "], the " << arg_name << " must be "; 174 } 175 auto iter_to_string = kCompareRangeToString.find(compare_operator); 176 if (iter_to_string == kCompareRangeToString.end()) { 177 MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator 178 << " cannot find in the compare string map"; 179 } 180 auto range_strng = iter_to_string->second; 181 buffer << range_strng.first << range.first << "," << range.second << range_strng.second << ", but got " << arg_value 182 << "."; 183 return buffer.str(); 184 } 185 186 template <typename T> CheckPositiveVectorExcludeZero(const std::string & arg_name,const std::vector<T> & arg_value,const std::string & prim_name)187 static std::vector<T> CheckPositiveVectorExcludeZero(const std::string &arg_name, const std::vector<T> &arg_value, 188 const std::string &prim_name) { 189 std::ostringstream buffer; 190 buffer << "For primitive[" << prim_name << "], the attribute[" << arg_name 191 << "] should be a vector with all positive item. but got ["; 192 if (std::any_of(arg_value.begin(), arg_value.end(), [](T item) { return item <= T(0); })) { 193 for (auto item : arg_value) { 194 buffer << item << ", "; 195 } 196 buffer << "]."; 197 MS_EXCEPTION(ValueError) << buffer.str(); 198 } 199 200 return arg_value; 201 } 202 203 template <typename T> CheckValue(const std::string & arg_name,T arg_value,CompareEnum compare_operator,T match_value,const std::string & prim_name)204 static T CheckValue(const std::string &arg_name, T arg_value, CompareEnum compare_operator, T match_value, 205 const std::string &prim_name) { 206 auto iter = kCompareMap<T>.find(compare_operator); 207 if (iter == kCompareMap<T>.end()) { 208 MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator << " cannot find in the compare map"; 209 } 210 if (iter->second(arg_value, match_value)) { 211 return arg_value; 212 } 213 std::ostringstream buffer; 214 if (prim_name.empty()) { 215 buffer << "The attribute[" << arg_name << "] must "; 216 } else { 217 buffer << "For primitive[" << prim_name << "], the " << arg_name << " must "; 218 } 219 auto iter_to_string = kCompareToString.find(compare_operator); 220 if (iter_to_string == kCompareToString.end()) { 221 MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator 222 << " cannot find in the compare string map"; 223 } 224 buffer << iter_to_string->second << match_value << " , but got " << arg_value << "."; 225 MS_EXCEPTION(ValueError) << buffer.str(); 226 } 227 228 template <typename T> CheckValue(const std::string & arg_name,T arg_value,CompareEnum compare_operator,const std::string & match_name,T match_value,const std::string & prim_name)229 static T CheckValue(const std::string &arg_name, T arg_value, CompareEnum compare_operator, 230 const std::string &match_name, T match_value, const std::string &prim_name) { 231 auto iter = kCompareMap<T>.find(compare_operator); 232 if (iter == kCompareMap<T>.end()) { 233 MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator << " cannot find in the compare map"; 234 } 235 if (iter->second(arg_value, match_value)) { 236 return arg_value; 237 } 238 std::ostringstream buffer; 239 if (prim_name.empty()) { 240 buffer << "The attribute[" << arg_name << "] must "; 241 } else { 242 buffer << "For primitive[" << prim_name << "], the " << arg_name << " must "; 243 } 244 auto iter_to_string = kCompareToString.find(compare_operator); 245 if (iter_to_string == kCompareToString.end()) { 246 MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator 247 << " cannot find in the compare string map"; 248 } 249 buffer << iter_to_string->second << match_name << " which is " << match_value << " , but got " << arg_value << "."; 250 MS_EXCEPTION(ValueError) << buffer.str(); 251 } 252 253 template <typename T> CheckInRange(const std::string & arg_name,T arg_value,CompareRange compare_operator,const std::pair<T,T> & range,const std::string & prim_name)254 static void CheckInRange(const std::string &arg_name, T arg_value, CompareRange compare_operator, 255 const std::pair<T, T> &range, const std::string &prim_name) { 256 auto iter = kCompareRangeMap<T>.find(compare_operator); 257 if (iter == kCompareRangeMap<T>.end()) { 258 MS_EXCEPTION(NotExistsError) << "For " << prim_name << ", compare_operator " << compare_operator 259 << " cannot find in the compare map"; 260 } 261 if (range.first >= range.second) { 262 MS_EXCEPTION(ValueError) << "For " << prim_name 263 << ", the check range left must be smaller than right number but got left: " 264 << range.first << " and right: " << range.second << "."; 265 } 266 if (iter->second(arg_value, range)) { 267 return; 268 } 269 std::ostringstream buffer; 270 if (prim_name.empty()) { 271 buffer << "The attribute[" << arg_name << "] must be "; 272 } else { 273 buffer << "For primitive[" << prim_name << "], the " << arg_name << " must be "; 274 } 275 auto iter_to_string = kCompareRangeToString.find(compare_operator); 276 if (iter_to_string == kCompareRangeToString.end()) { 277 MS_EXCEPTION(NotExistsError) << "For " << prim_name << ", compare_operator " << compare_operator 278 << " cannot find in the compare string map"; 279 } 280 auto range_strng = iter_to_string->second; 281 buffer << range_strng.first << range.first << "," << range.second << range_strng.second << ", but got " << arg_value 282 << "."; 283 MS_EXCEPTION(ValueError) << buffer.str(); 284 } 285 286 static ShapeMap ConvertShapePtrToShapeMap(const BaseShapePtr &shape); 287 static abstract::ShapePtr GetTensorInputShape(const std::string &prim_name, 288 const std::vector<AbstractBasePtr> &input_args, size_t index); 289 static TypePtr GetTensorInputType(const std::string &prim_name, const std::vector<AbstractBasePtr> &input_args, 290 size_t index); 291 static void Check(const std::string &arg_name, int64_t arg_value, CompareEnum compare_type, int64_t value, 292 const std::string &prim_name = "", ExceptionType exception_type = ValueError); 293 294 template <typename T> 295 static void Check(const std::string &arg_name, const std::vector<T> &arg_value, CompareEnum compare_type, 296 const std::vector<T> &value, const std::string &prim_name = "", 297 ExceptionType exception_type = ValueError) { 298 if (compare_type != kEqual) { 299 auto iter = kCompareToString.find(compare_type); 300 if (iter != kCompareToString.end()) { 301 MS_EXCEPTION(NotSupportError) << "Only supported equal to compare two vectors but got " << iter->second; 302 } 303 MS_EXCEPTION(UnknownError) << "Cannot find the operator " << compare_type << "in the compare map!"; 304 } 305 if (arg_value == value) { 306 return; 307 } 308 std::ostringstream buffer; 309 if (prim_name.empty()) { 310 buffer << "The attribute[" << arg_name << "]:"; 311 } else { 312 buffer << "For primitive[" << prim_name << "], the " << arg_name << ":"; 313 } 314 auto iter_to_string = kCompareToString.find(compare_type); 315 if (iter_to_string == kCompareToString.end()) { 316 MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_type << " cannot find in the compare string map"; 317 } 318 319 buffer << " ["; 320 for (auto item : arg_value) { 321 buffer << item << ","; 322 } 323 buffer << "]"; 324 buffer << " must " << iter_to_string->second << "["; 325 for (auto item : value) { 326 buffer << item << ","; 327 } 328 buffer << "]"; 329 MS_EXCEPTION(exception_type) << buffer.str(); 330 } 331 332 template <typename T> CheckArgs(const std::string & op,const AbstractBasePtrList & args_spec_list,size_t index)333 static std::shared_ptr<T> CheckArgs(const std::string &op, const AbstractBasePtrList &args_spec_list, size_t index) { 334 if (index >= args_spec_list.size()) { 335 MS_EXCEPTION(ValueError) << op << " evaluator arguments list index out of bound, size " << args_spec_list.size() 336 << ", index " << index; 337 } 338 auto args_abs = args_spec_list[index]; 339 MS_EXCEPTION_IF_NULL(args_abs); 340 auto arg = dyn_cast<T>(args_abs); 341 if (arg == nullptr) { 342 MS_EXCEPTION(TypeError) << "For primitive[" << op << "], the input[" << index << "] should be a " 343 << abstract::ReportNameTraits<T>::name << ", but got " 344 << args_spec_list[index]->BuildType()->ToString() << "."; 345 } 346 return arg; 347 } 348 IsScalar(const AbstractBasePtr & abs)349 static inline bool IsScalar(const AbstractBasePtr &abs) { return abs->GetType()->object_type() == kObjectTypeNumber; } IsTuple(const AbstractBasePtr & abs)350 static inline bool IsTuple(const AbstractBasePtr &abs) { return abs->GetType()->object_type() == kObjectTypeTuple; } IsList(const AbstractBasePtr & abs)351 static inline bool IsList(const AbstractBasePtr &abs) { return abs->GetType()->object_type() == kObjectTypeList; } IsTensor(const AbstractBasePtr & abs)352 static inline bool IsTensor(const AbstractBasePtr &abs) { 353 return abs->GetType()->object_type() == kObjectTypeTensorType; 354 } IsSequence(const AbstractBasePtr & abs)355 static inline bool IsSequence(const AbstractBasePtr &abs) { 356 return abs->GetType()->object_type() == kObjectTypeTuple || abs->GetType()->object_type() == kObjectTypeList; 357 } IsDynamicSequence(const AbstractBasePtr & abs)358 static inline bool IsDynamicSequence(const AbstractBasePtr &abs) { 359 return abs->GetShape()->isa<abstract::DynamicSequenceShape>(); 360 } GetSequenceElementTypes(const AbstractBasePtr & abs)361 static inline TypePtrList GetSequenceElementTypes(const AbstractBasePtr &abs) { 362 if (IsDynamicSequence(abs)) { 363 MS_EXCEPTION(TypeError) << "The input must not be a dynamic sequence."; 364 } 365 auto const &input_type = abs->GetType(); 366 MS_EXCEPTION_IF_NULL(input_type); 367 TypePtrList types_list{}; 368 if (input_type->object_type() == kObjectTypeTuple) { 369 types_list = input_type->cast<TuplePtr>()->elements(); 370 } else if (input_type->object_type() == kObjectTypeList) { 371 types_list = input_type->cast<ListPtr>()->elements(); 372 } else { 373 MS_EXCEPTION(TypeError) << "The input must be a tuple or a list, but got " << input_type->ToString() << "."; 374 } 375 return types_list; 376 } GetSequenceElementShapes(const AbstractBasePtr & abs)377 static inline abstract::BaseShapePtrList GetSequenceElementShapes(const AbstractBasePtr &abs) { 378 if (IsDynamicSequence(abs)) { 379 MS_EXCEPTION(TypeError) << "The input must not be a dynamic sequence."; 380 } 381 auto const &input_shape = abs->GetShape(); 382 MS_EXCEPTION_IF_NULL(input_shape); 383 auto const &input_type = abs->GetType(); 384 MS_EXCEPTION_IF_NULL(input_type); 385 abstract::BaseShapePtrList shapes_list{}; 386 if (input_type->object_type() == kObjectTypeTuple) { 387 shapes_list = input_shape->cast<abstract::TupleShapePtr>()->shape(); 388 } else if (input_type->object_type() == kObjectTypeList) { 389 shapes_list = input_shape->cast<abstract::ListShapePtr>()->shape(); 390 } else { 391 MS_EXCEPTION(TypeError) << "The input must be a tuple or a list, but got " << input_type->ToString() << "."; 392 } 393 return shapes_list; 394 } 395 static AbstractBasePtr CheckArgsType(const std::string &op, const AbstractBasePtrList &args_spec_list, size_t index, 396 TypeId type_id); 397 static AbstractBasePtr CheckArgsSequenceType(const std::string &op, const AbstractBasePtrList &args_spec_list, 398 size_t index); 399 400 static ShapeVector CheckTensorShapeSame(const std::map<std::string, BaseShapePtr> &shapes, 401 const std::vector<int64_t> &check_shape, const std::string &prim_name); 402 static TypePtr CheckTensorTypeSame(const std::map<std::string, TypePtr> &types, const std::set<TypePtr> &check_list, 403 const std::string &prim_name); 404 // Return Tensor type 405 static TypePtr CheckMathBinaryOpTensorType(const std::map<std::string, TypePtr> &types, 406 const std::set<TypePtr> &check_list, const std::string &prim_name); 407 // ==========================old========================= 408 static ShapeVector CheckTensorIntValue(const std::string &tensor_name, const ValuePtr &value, 409 const std::string &prim_name); 410 // ==========================new========================= 411 static ShapeVector CheckTensorIntValue(const std::string &tensor_name, const ValuePtr &value, 412 const std::string &prim_name, const TypePtr &type); 413 static TypePtr CheckTensorTypeValid(const std::string &type_name, const TypePtr &type, 414 const std::set<TypePtr> &check_list, const std::string &prim_name); 415 static TypePtr CheckSparseTensorTypeValid(const std::string &type_name, const TypePtr &type, 416 const std::set<TypePtr> &check_list, const std::string &prim_name); 417 static TypePtr CheckSubClass(const std::string &type_name, const TypePtr &type, 418 const std::set<TypePtr> &template_types, const std::string &prim_name); 419 static TypePtr CheckSubClassWithMoreInfo(const std::string &type_name, const TypePtr &type, 420 const std::string &more_info, const std::set<TypePtr> &template_types, 421 const std::string &prim_name); 422 static TypePtr CheckScalarOrTensorTypesSame(const std::map<std::string, TypePtr> &args, 423 const std::set<TypePtr> &valid_values, const std::string &prim_name, 424 bool allow_mix = false); 425 static TypePtr CheckTypeValid(const std::string &arg_name, const TypePtr &arg_type, 426 const std::set<TypePtr> &valid_type, const std::string &prim_name); 427 static TypePtr CheckTypeValidWithMoreInfo(const std::string &arg_name, const TypePtr &arg_type, 428 const std::string &more_info, const std::set<TypePtr> &valid_type, 429 const std::string &prim_name); 430 static bool ConvertAttrValueToInt(const std::string &op_type, const std::string &attr_name, ValuePtr *const value); 431 static bool ConvertAttrValueToString(const std::string &op_type, const std::string &attr_name, ValuePtr *const value); 432 static void ConvertAttrValueInExport(const std::string &op_type, const std::string &attr_name, ValuePtr *const value); 433 static void ConvertAttrValueInLoad(const std::string &op_type, const std::string &attr_name, ValuePtr *const value); 434 static AttrConverterPair GetAttrConvertPair(const std::string &op_type, const std::string &attr_name); 435 static bool GetDataFormatEnumValue(const ValuePtr &value, int64_t *enum_value); 436 static void GetPadModEnumValue(const ValuePtr &value, int64_t *enum_value, bool is_upper = false); 437 static void GetReductionEnumValue(const ValuePtr &value, int64_t *enum_value); 438 static bool CheckIrAttrtoOpAttr(const std::string &op_type, const std::string &attr_name, ValuePtr *const value); 439 static void CheckSummaryParam(const AbstractBasePtr &name, const AbstractBasePtr &value, 440 const std::string &class_name); 441 static void CheckMode(const std::string &class_name); 442 static std::vector<double> CheckTensorFloatValue(const std::string &type_name, const ValuePtr &value, 443 const std::string &prim_name); 444 static std::vector<double> CheckListOrTupleFloat(const std::string &arg_name, const ValuePtr &attr, 445 const std::string &prim_name); 446 // ==========================old========================= 447 static std::vector<int64_t> CheckIntOrTupleInt(const std::string &arg_name, const ValuePtr &attr, 448 const std::string &prim_name); 449 // ==========================new========================= 450 static std::vector<pyfloat> CheckListOrTupleFloat(const std::string &arg_name, const AbstractBasePtr &abs, 451 const std::string &prim_name); 452 static std::vector<int64_t> CheckIntOrTupleInt(const std::string &arg_name, const AbstractBasePtr &abs, 453 const std::string &prim_name); 454 static std::vector<int64_t> CheckTupleInt(const std::string &arg_name, const ValuePtr &attr, 455 const std::string &prim_name); 456 static std::vector<int64_t> CheckListInt(const std::string &arg_name, const ValuePtr &attr, 457 const std::string &prim_name); 458 static int64_t GetAndCheckFormat(const ValuePtr &value); 459 static size_t GetRemoveMonadAbsNum(const AbstractBasePtrList &abs_list); 460 static size_t GetRemoveUMonadAbsNum(const AbstractBasePtrList &abs_list); 461 static void CheckInputArgs(const std::vector<AbstractBasePtr> &input_args, const CompareEnum compare_operator, 462 const int64_t match_value, const std::string &prim_name); 463 static bool HasDynamicShapeInput(const AbstractBasePtrList &abs_list); 464 static void GetFormatStringVal(const PrimitivePtr &prim, std::string *format); 465 static size_t CheckAbstractShapeSame(const std::vector<AbstractBasePtr> &abs_list); 466 static size_t CheckAbstractTypeSame(const std::vector<AbstractBasePtr> &abs_list); 467 static void CheckAbstractTypeAndShapeSame(const std::vector<AbstractBasePtr> &abs_list, 468 const std::string &precondition_log, 469 const std::string &standard_abs_description = "", 470 const std::string &differ_abs_description = ""); 471 static bool CheckContainNestedOrIrregularSequence(const std::vector<AbstractBasePtr> &abs_list); 472 static bool CheckValueSame(const ValuePtr &value_1, const ValuePtr &value_2); 473 static abstract::AbstractSequencePtr BroadenAllSequenceElements(const abstract::AbstractSequencePtr &sequence); 474 static TypePtr CheckTypeSame(const std::map<std::string, TypePtr> &args, const std::string &prim_name, 475 const bool allow_mix = false); 476 static bool CheckPrimAttrConverted(const std::string &op_name); 477 478 private: 479 static TypePtr CheckTensorSubClass(const std::string &type_name, const TypePtr &type, 480 const std::set<TypePtr> &template_types, const std::string &prim_name, 481 bool is_mix = false); 482 }; 483 } // namespace mindspore 484 #endif // MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H_ 485