1 /** 2 * Copyright 2019-2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H_ 18 #define MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H_ 19 #include <vector> 20 #include <string> 21 #include <map> 22 #include <set> 23 #include <utility> 24 #include <typeinfo> 25 #include <memory> 26 #include "abstract/param_validator.h" 27 #include "base/base.h" 28 #include "ir/anf.h" 29 #include "ir/dtype/type_id.h" 30 #include "include/api/format.h" 31 #include "utils/log_adapter.h" 32 namespace mindspore { 33 typedef std::pair<std::map<std::string, int64_t>, std::map<int64_t, std::string>> AttrConverterPair; 34 typedef std::map<std::string, std::vector<int64_t>> ShapeMap; 35 constexpr auto kShape = "shape"; 36 constexpr auto kMinShape = "min_shape"; 37 constexpr auto kMaxShape = "max_shape"; 38 39 enum CompareEnum : int64_t { 40 kEqual = 1, // == 41 kNotEqual = 2, // != 42 kLessThan = 3, // < 43 kLessEqual = 4, // <= 44 kGreaterThan = 5, // > 45 kGreaterEqual = 6, // >= 46 }; 47 48 enum CompareRange { 49 kIncludeNeither = 1, // (a,b) 50 kIncludeLeft = 2, // [a,b) 51 kIncludeRight = 3, // (a,b] 52 kIncludeBoth = 4, // [a,b] 53 }; 54 enum ActivationType : int64_t { 55 NO_ACTIVATION = 0, 56 RELU = 1, 57 SIGMOID = 2, 58 RELU6 = 3, 59 ELU = 4, 60 LEAKY_RELU = 5, 61 ABS = 6, 62 RELU1 = 7, 63 SOFTSIGN = 8, 64 SOFTPLUS = 9, 65 TANH = 10, 66 SELU = 11, 67 HSWISH = 12, 68 HSIGMOID = 13, 69 THRESHOLDRELU = 14, 70 LINEAR = 15, 71 HARD_TANH = 16, 72 SIGN = 17, 73 SWISH = 18, 74 GELU = 19, 75 GLU = 20, 76 UNKNOWN = 21 77 }; 78 enum ReduceMode : int64_t { 79 Reduce_Mean = 0, 80 Reduce_Max = 1, 81 Reduce_Min = 2, 82 Reduce_Prod = 3, 83 Reduce_Sum = 4, 84 Reduce_Sum_Square = 5, 85 Reduce_ASum = 6, 86 Reduce_All = 7 87 }; 88 enum ReduceType : int64_t { 89 REDUCE_MAX = 0, 90 REDUCE_MEAN = 1, 91 REDUCE_ALL = 2, 92 REDUCE_ANY = 3, 93 REDUCE_LOG_SUM_EXP = 4, 94 REDUCE_PROD = 5, 95 REDUCE_SUM = 6, 96 REDUCE_UNKNOW = 7, 97 }; 98 enum EltwiseMode : int64_t { PROD = 0, SUM = 1, MAXIMUM = 2, ELTWISEMODE_UNKNOW = 3 }; 99 100 enum Reduction : int64_t { REDUCTION_SUM = 0, MEAN = 1, NONE = 2 }; 101 102 enum PadMode : int64_t { PAD = 0, SAME = 1, VALID = 2 }; 103 104 enum RoundMode : int64_t { 105 FLOOR = 0, 106 CEIL = 1, 107 }; 108 109 enum PoolMode : int64_t { 110 MAX_POOLING = 0, 111 MEAN_POOLING = 1, 112 }; 113 114 enum GateOrderMode : int64_t { RZH = 0, ZRH = 1 }; 115 116 enum class LshProjectionType : int64_t { UNKNOWN = 0, SPARSE = 1, DENSE = 2 }; 117 118 enum PaddingMode : int64_t { CONSTANT = 0, REFLECT = 1, SYMMETRIC = 2, MODE_RESERVED = 3 }; 119 120 enum class ResizeMethod : int64_t { UNKNOWN = -1, LINEAR = 0, NEAREST = 1, CUBIC = 2 }; 121 122 enum CoordinateTransformMode : int64_t { ASYMMETRIC = 0, ALIGN_CORNERS = 1, HALF_PIXEL = 2, CROP_AND_RESIZE = 3 }; 123 124 enum class NearestMode : int64_t { NORMAL = 0, ROUND_HALF_DOWN = 1, ROUND_HALF_UP = 2, FLOOR = 3, CEIL = 4 }; 125 126 template <typename T> 127 const std::map<CompareEnum, std::function<bool(T, T)>> kCompareMap = { 128 {kEqual, [](T num1, T num2) -> bool { return num1 == num2; }}, 129 {kNotEqual, [](T num1, T num2) -> bool { return num1 != num2; }}, 130 {kLessThan, [](T num1, T num2) -> bool { return num1 < num2; }}, 131 {kLessEqual, [](T num1, T num2) -> bool { return num1 <= num2; }}, 132 {kGreaterThan, [](T num1, T num2) -> bool { return num1 > num2; }}, 133 {kGreaterEqual, [](T num1, T num2) -> bool { return num1 >= num2; }}}; 134 135 template <typename T> 136 const std::map<CompareRange, std::function<bool(T, std::pair<T, T>)>> kCompareRangeMap = { 137 {kIncludeNeither, [](T num1, std::pair<T, T> range) -> bool { return num1 > range.first && num1 < range.second; }}, 138 {kIncludeLeft, [](T num1, std::pair<T, T> range) -> bool { return num1 >= range.first && num1 < range.second; }}, 139 {kIncludeBoth, [](T num1, std::pair<T, T> range) -> bool { return num1 >= range.first && num1 <= range.second; }}, 140 {kIncludeRight, [](T num1, std::pair<T, T> range) -> bool { return num1 > range.first && num1 <= range.second; }}}; 141 142 const std::map<CompareEnum, std::string> kCompareToString = { 143 {kEqual, "be equal to "}, {kNotEqual, "be not equal to "}, 144 {kLessThan, "be less than "}, {kLessEqual, "be less than or equal to "}, 145 {kGreaterThan, "be greater than "}, {kGreaterEqual, "be greater than or equal to "}}; 146 147 const std::map<CompareRange, std::pair<std::string, std::string>> kCompareRangeToString = { 148 {kIncludeNeither, {"in (", ")"}}, 149 {kIncludeLeft, {"in [", ")"}}, 150 {kIncludeRight, {"in (", "]"}}, 151 {kIncludeBoth, {"in [", "]"}}}; 152 153 class CheckAndConvertUtils { 154 public: 155 static std::vector<int64_t> CheckPositiveVector(const std::string &arg_name, const std::vector<int64_t> &arg_value, 156 const std::string &prim_name); 157 static std::string CheckString(const std::string &arg_name, const std::string &arg_value, 158 const std::set<std::string> &check_list, const std::string &prim_name); 159 160 // CheckValue should replace CheckInteger 161 static int64_t CheckInteger(const std::string &arg_name, int64_t arg_value, CompareEnum compare_operator, 162 int64_t match_value, const std::string &prim_name = ""); 163 164 template <typename T> CheckValue(const std::string & arg_name,T arg_value,CompareEnum compare_operator,T match_value,const std::string & prim_name)165 static T CheckValue(const std::string &arg_name, T arg_value, CompareEnum compare_operator, T match_value, 166 const std::string &prim_name) { 167 auto iter = kCompareMap<float>.find(compare_operator); 168 if (iter == kCompareMap<float>.end()) { 169 MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator << " cannot find in the compare map"; 170 } 171 if (iter->second(arg_value, match_value)) { 172 return arg_value; 173 } 174 std::ostringstream buffer; 175 if (prim_name.empty()) { 176 buffer << "The attribute[" << arg_name << "] must "; 177 } else { 178 buffer << "The primitive[" << prim_name << "]'s " << arg_name << " must "; 179 } 180 auto iter_to_string = kCompareToString.find(compare_operator); 181 if (iter_to_string == kCompareToString.end()) { 182 MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator 183 << " cannot find in the compare string map"; 184 } 185 buffer << iter_to_string->second << match_value << " , but got " << arg_value << "."; 186 MS_EXCEPTION(ValueError) << buffer.str(); 187 } 188 189 template <typename T> CheckInRange(const std::string & arg_name,T arg_value,CompareRange compare_operator,const std::pair<T,T> & range,const std::string & prim_name)190 static void CheckInRange(const std::string &arg_name, T arg_value, CompareRange compare_operator, 191 const std::pair<T, T> &range, const std::string &prim_name) { 192 auto iter = kCompareRangeMap<float>.find(compare_operator); 193 if (iter == kCompareRangeMap<float>.end()) { 194 MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator << " cannot find in the compare map"; 195 } 196 if (range.first >= range.second) { 197 MS_EXCEPTION(ArgumentError) << "the check range left must be smaller than right number bug got [ " << range.first 198 << "," << range.second; 199 } 200 if (iter->second(arg_value, range)) { 201 return; 202 } 203 std::ostringstream buffer; 204 if (prim_name.empty()) { 205 buffer << "The attribute[" << arg_name << "] must "; 206 } else { 207 buffer << "The primitive[" << prim_name << "] " << arg_name << " must "; 208 } 209 auto iter_to_string = kCompareRangeToString.find(compare_operator); 210 if (iter_to_string == kCompareRangeToString.end()) { 211 MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator 212 << " cannot find in the compare string map"; 213 } 214 auto range_strng = iter_to_string->second; 215 buffer << range_strng.first << range.first << "," << range.second << range_strng.second << " ,but got " << arg_value 216 << "."; 217 MS_EXCEPTION(ValueError) << buffer.str(); 218 } 219 220 static ShapeMap ConvertShapePtrToShapeMap(const BaseShapePtr &shape); 221 static abstract::ShapePtr GetTensorInputShape(const std::string &prim_name, 222 const std::vector<AbstractBasePtr> &input_args, int64_t index); 223 static void Check(const std::string &arg_name, int64_t arg_value, CompareEnum compare_type, 224 const std::string &value_name, int64_t value, const std::string &prim_name = "", 225 ExceptionType exception_type = ValueError); 226 227 template <typename T> 228 static void Check(const std::string &arg_name, const std::vector<T> &arg_value, CompareEnum compare_type, 229 const std::string &value_name, const std::vector<T> &value, const std::string &prim_name = "", 230 ExceptionType exception_type = ValueError) { 231 if (compare_type != kEqual) { 232 auto iter = kCompareToString.find(compare_type); 233 if (iter != kCompareToString.end()) { 234 MS_EXCEPTION(NotSupportError) << "Only supported equal to compare two vectors but got " << iter->second; 235 } 236 MS_EXCEPTION(UnknownError) << "Cannot find the operator " << compare_type << "in the compare map!"; 237 } 238 if (arg_value == value) { 239 return; 240 } 241 std::ostringstream buffer; 242 if (prim_name.empty()) { 243 buffer << "The attribute[" << arg_name << "]:"; 244 } else { 245 buffer << "The primitive[" << prim_name << "]'s " << arg_name << ":"; 246 } 247 auto iter_to_string = kCompareToString.find(compare_type); 248 if (iter_to_string == kCompareToString.end()) { 249 MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_type << " cannot find in the compare string map"; 250 } 251 252 buffer << " ["; 253 for (auto item : arg_value) { 254 buffer << item << ","; 255 } 256 buffer << "]"; 257 buffer << " must " << iter_to_string->second << "["; 258 for (auto item : value) { 259 buffer << item << ","; 260 } 261 buffer << "]"; 262 MS_EXCEPTION(exception_type) << buffer.str(); 263 } 264 265 template <typename T> CheckArgs(const std::string & op,const AbstractBasePtrList & args_spec_list,size_t index)266 static std::shared_ptr<T> CheckArgs(const std::string &op, const AbstractBasePtrList &args_spec_list, size_t index) { 267 if (index >= args_spec_list.size()) { 268 MS_EXCEPTION(ValueError) << op << " evaluator arguments list index out of bound, size " << args_spec_list.size() 269 << ", index " << index; 270 } 271 auto args_spec = args_spec_list[index]; 272 MS_EXCEPTION_IF_NULL(args_spec); 273 auto arg = dyn_cast<T>(args_spec); 274 if (arg == nullptr) { 275 MS_EXCEPTION(TypeError) << "The primitive[" << op << "]'s input[" << index << "] should be a " 276 << abstract::ReportNameTraits<T>::name << ", but got " 277 << args_spec_list[index]->BuildType()->ToString() << "."; 278 } 279 return arg; 280 } 281 282 static TypePtr CheckTensorTypeSame(const std::map<std::string, TypePtr> &types, const std::set<TypePtr> &check_list, 283 const std::string &prim_name); 284 static ShapeVector CheckTensorIntValue(const std::string &type_name, const ValuePtr &value, 285 const std::string &prim_name); 286 static TypePtr CheckTensorTypeValid(const std::string &type_name, const TypePtr &type, 287 const std::set<TypePtr> &check_list, const std::string &prim_name); 288 static TypePtr CheckSubClass(const std::string &type_name, const TypePtr &type, 289 const std::set<TypePtr> &template_types, const std::string &prim_name); 290 static TypePtr CheckScalarOrTensorTypesSame(const std::map<std::string, TypePtr> &args, 291 const std::set<TypePtr> &valid_values, const std::string &prim_name, 292 bool allow_mix = false); 293 static TypePtr CheckTypeValid(const std::string &arg_name, const TypePtr &arg_type, 294 const std::set<TypePtr> &valid_type, const std::string &prim_name); 295 static bool ConvertAttrValueToInt(const std::string &op_type, const std::string &attr_name, ValuePtr *const value); 296 static bool ConvertAttrValueToString(const std::string &op_type, const std::string &attr_name, ValuePtr *const value); 297 static void ConvertAttrValueInExport(const std::string &op_type, const std::string &attr_name, ValuePtr *const value); 298 static void ConvertAttrValueInLoad(const std::string &op_type, const std::string &attr_name, ValuePtr *const value); 299 static AttrConverterPair GetAttrConvertPair(const std::string &op_type, const std::string &attr_name); 300 static bool GetDataFormatEnumValue(const ValuePtr &value, int64_t *enum_value); 301 static void GetPadModEnumValue(const ValuePtr &value, int64_t *enum_value, bool is_upper = false); 302 static void GetReductionEnumValue(const ValuePtr &value, int64_t *enum_value); 303 static bool CheckIrAttrtoOpAttr(const std::string &op_type, const std::string &attr_name, ValuePtr *const value); 304 static void CheckSummaryParam(const AbstractBasePtr &name, const AbstractBasePtr &value, 305 const std::string &class_name); 306 static void CheckMode(const std::string &class_name); 307 static std::vector<int64_t> CheckAttrIntOrTupleInt(const std::string &prim_name, const ValuePtr &attr, 308 const std::string &arg_name); 309 static std::vector<int64_t> CheckAttrTupleInt(const std::string &prim_name, const ValuePtr &attr, 310 const std::string &arg_name); 311 static void CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVector *max_shape); 312 static int64_t GetAndCheckFormat(const ValuePtr &value); 313 static size_t GetRemoveMonadAbsNum(const AbstractBasePtrList &abs_list); 314 static void CheckInputArgs(const std::vector<AbstractBasePtr> &input_args, const CompareEnum compare_operator, 315 const int64_t match_value, const std::string &prim_name); 316 static TypePtr GetInputTensorType(const std::vector<AbstractBasePtr> &input_args, const size_t index, 317 const std::string &prim_name); 318 static bool HasDynamicShapeInput(const AbstractBasePtrList &abs_list); 319 320 private: 321 static TypePtr _CheckTypeSame(const std::map<std::string, TypePtr> &args, const std::string &prim_name, 322 const bool allow_mix); 323 static std::string GetErrorTypeString(const std::set<TypePtr> &check_list, const TypePtr &check_type); 324 }; 325 } // namespace mindspore 326 #endif // MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H_ 327