• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H_
18 #define MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H_
19 #include <vector>
20 #include <string>
21 #include <map>
22 #include <set>
23 #include <utility>
24 #include <typeinfo>
25 #include <memory>
26 #include "abstract/param_validator.h"
27 #include "base/base.h"
28 #include "ir/anf.h"
29 #include "ir/dtype/type_id.h"
30 #include "include/api/format.h"
31 #include "utils/log_adapter.h"
32 namespace mindspore {
/// Pair of lookup tables for one attribute: (string -> int64 code, int64 code -> string).
using AttrConverterPair = std::pair<std::map<std::string, int64_t>, std::map<int64_t, std::string>>;
/// Named shape vectors keyed by kShape / kMinShape / kMaxShape.
using ShapeMap = std::map<std::string, std::vector<int64_t>>;

// Keys used in ShapeMap.
constexpr const char *kShape = "shape";
constexpr const char *kMinShape = "min_shape";
constexpr const char *kMaxShape = "max_shape";
38 
/// Relational operator selector for the value checks in CheckAndConvertUtils.
enum CompareEnum : int64_t {
  kEqual = 1,         ///< ==
  kNotEqual = 2,      ///< !=
  kLessThan = 3,      ///< <
  kLessEqual = 4,     ///< <=
  kGreaterThan = 5,   ///< >
  kGreaterEqual = 6,  ///< >=
};

/// Interval kind for range checks: which endpoints of (first, second) are inclusive.
enum CompareRange {
  kIncludeNeither = 1,  ///< (a, b)
  kIncludeLeft = 2,     ///< [a, b)
  kIncludeRight = 3,    ///< (a, b]
  kIncludeBoth = 4,     ///< [a, b]
};
// Every enumerator below carries an explicit value. Presumably these values are persisted or shared
// with other components (NOTE(review): confirm), so never renumber or reorder them.

/// Activation kinds.
enum ActivationType : int64_t {
  NO_ACTIVATION = 0, RELU = 1,       SIGMOID = 2,        RELU6 = 3,     ELU = 4,       LEAKY_RELU = 5,
  ABS = 6,           RELU1 = 7,      SOFTSIGN = 8,       SOFTPLUS = 9,  TANH = 10,     SELU = 11,
  HSWISH = 12,       HSIGMOID = 13,  THRESHOLDRELU = 14, LINEAR = 15,   HARD_TANH = 16, SIGN = 17,
  SWISH = 18,        GELU = 19,      GLU = 20,           UNKNOWN = 21
};

/// Reduction modes (ReduceMode numbering).
enum ReduceMode : int64_t {
  Reduce_Mean = 0, Reduce_Max = 1, Reduce_Min = 2, Reduce_Prod = 3,
  Reduce_Sum = 4, Reduce_Sum_Square = 5, Reduce_ASum = 6, Reduce_All = 7
};

/// Reduction types (ReduceType numbering, distinct from ReduceMode).
enum ReduceType : int64_t {
  REDUCE_MAX = 0, REDUCE_MEAN = 1, REDUCE_ALL = 2, REDUCE_ANY = 3,
  REDUCE_LOG_SUM_EXP = 4, REDUCE_PROD = 5, REDUCE_SUM = 6, REDUCE_UNKNOW = 7
};
// Operator attribute enums. The values are explicit and must stay stable.

enum EltwiseMode : int64_t {
  PROD = 0,
  SUM = 1,
  MAXIMUM = 2,
  ELTWISEMODE_UNKNOW = 3,
};

enum Reduction : int64_t {
  REDUCTION_SUM = 0,
  MEAN = 1,
  NONE = 2,
};

enum PadMode : int64_t {
  PAD = 0,
  SAME = 1,
  VALID = 2,
};

enum RoundMode : int64_t { FLOOR = 0, CEIL = 1 };

enum PoolMode : int64_t { MAX_POOLING = 0, MEAN_POOLING = 1 };

enum GateOrderMode : int64_t {
  RZH = 0,
  ZRH = 1,
};

enum class LshProjectionType : int64_t {
  UNKNOWN = 0,
  SPARSE = 1,
  DENSE = 2,
};

enum PaddingMode : int64_t {
  CONSTANT = 0,
  REFLECT = 1,
  SYMMETRIC = 2,
  MODE_RESERVED = 3,
};

/// Scoped because its UNKNOWN/LINEAR names would otherwise clash with ActivationType's.
enum class ResizeMethod : int64_t {
  UNKNOWN = -1,
  LINEAR = 0,
  NEAREST = 1,
  CUBIC = 2,
};

enum CoordinateTransformMode : int64_t {
  ASYMMETRIC = 0,
  ALIGN_CORNERS = 1,
  HALF_PIXEL = 2,
  CROP_AND_RESIZE = 3,
};

/// Scoped because its FLOOR/CEIL names would otherwise clash with RoundMode's.
enum class NearestMode : int64_t {
  NORMAL = 0,
  ROUND_HALF_DOWN = 1,
  ROUND_HALF_UP = 2,
  FLOOR = 3,
  CEIL = 4,
};
125 
126 template <typename T>
127 const std::map<CompareEnum, std::function<bool(T, T)>> kCompareMap = {
128   {kEqual, [](T num1, T num2) -> bool { return num1 == num2; }},
129   {kNotEqual, [](T num1, T num2) -> bool { return num1 != num2; }},
130   {kLessThan, [](T num1, T num2) -> bool { return num1 < num2; }},
131   {kLessEqual, [](T num1, T num2) -> bool { return num1 <= num2; }},
132   {kGreaterThan, [](T num1, T num2) -> bool { return num1 > num2; }},
133   {kGreaterEqual, [](T num1, T num2) -> bool { return num1 >= num2; }}};
134 
135 template <typename T>
136 const std::map<CompareRange, std::function<bool(T, std::pair<T, T>)>> kCompareRangeMap = {
137   {kIncludeNeither, [](T num1, std::pair<T, T> range) -> bool { return num1 > range.first && num1 < range.second; }},
138   {kIncludeLeft, [](T num1, std::pair<T, T> range) -> bool { return num1 >= range.first && num1 < range.second; }},
139   {kIncludeBoth, [](T num1, std::pair<T, T> range) -> bool { return num1 >= range.first && num1 <= range.second; }},
140   {kIncludeRight, [](T num1, std::pair<T, T> range) -> bool { return num1 > range.first && num1 <= range.second; }}};
141 
142 const std::map<CompareEnum, std::string> kCompareToString = {
143   {kEqual, "be equal to "},           {kNotEqual, "be not equal to "},
144   {kLessThan, "be less than "},       {kLessEqual, "be less than or equal to "},
145   {kGreaterThan, "be greater than "}, {kGreaterEqual, "be greater than or equal to "}};
146 
147 const std::map<CompareRange, std::pair<std::string, std::string>> kCompareRangeToString = {
148   {kIncludeNeither, {"in (", ")"}},
149   {kIncludeLeft, {"in [", ")"}},
150   {kIncludeRight, {"in (", "]"}},
151   {kIncludeBoth, {"in [", "]"}}};
152 
153 class CheckAndConvertUtils {
154  public:
155   static std::vector<int64_t> CheckPositiveVector(const std::string &arg_name, const std::vector<int64_t> &arg_value,
156                                                   const std::string &prim_name);
157   static std::string CheckString(const std::string &arg_name, const std::string &arg_value,
158                                  const std::set<std::string> &check_list, const std::string &prim_name);
159 
160   // CheckValue should replace CheckInteger
161   static int64_t CheckInteger(const std::string &arg_name, int64_t arg_value, CompareEnum compare_operator,
162                               int64_t match_value, const std::string &prim_name = "");
163 
164   template <typename T>
CheckValue(const std::string & arg_name,T arg_value,CompareEnum compare_operator,T match_value,const std::string & prim_name)165   static T CheckValue(const std::string &arg_name, T arg_value, CompareEnum compare_operator, T match_value,
166                       const std::string &prim_name) {
167     auto iter = kCompareMap<float>.find(compare_operator);
168     if (iter == kCompareMap<float>.end()) {
169       MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator << " cannot find in the compare map";
170     }
171     if (iter->second(arg_value, match_value)) {
172       return arg_value;
173     }
174     std::ostringstream buffer;
175     if (prim_name.empty()) {
176       buffer << "The attribute[" << arg_name << "] must ";
177     } else {
178       buffer << "The primitive[" << prim_name << "]'s " << arg_name << " must ";
179     }
180     auto iter_to_string = kCompareToString.find(compare_operator);
181     if (iter_to_string == kCompareToString.end()) {
182       MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator
183                                    << " cannot find in the compare string map";
184     }
185     buffer << iter_to_string->second << match_value << " , but got " << arg_value << ".";
186     MS_EXCEPTION(ValueError) << buffer.str();
187   }
188 
189   template <typename T>
CheckInRange(const std::string & arg_name,T arg_value,CompareRange compare_operator,const std::pair<T,T> & range,const std::string & prim_name)190   static void CheckInRange(const std::string &arg_name, T arg_value, CompareRange compare_operator,
191                            const std::pair<T, T> &range, const std::string &prim_name) {
192     auto iter = kCompareRangeMap<float>.find(compare_operator);
193     if (iter == kCompareRangeMap<float>.end()) {
194       MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator << " cannot find in the compare map";
195     }
196     if (range.first >= range.second) {
197       MS_EXCEPTION(ArgumentError) << "the check range left must be smaller than right number bug got [ " << range.first
198                                   << "," << range.second;
199     }
200     if (iter->second(arg_value, range)) {
201       return;
202     }
203     std::ostringstream buffer;
204     if (prim_name.empty()) {
205       buffer << "The attribute[" << arg_name << "] must ";
206     } else {
207       buffer << "The primitive[" << prim_name << "] " << arg_name << " must ";
208     }
209     auto iter_to_string = kCompareRangeToString.find(compare_operator);
210     if (iter_to_string == kCompareRangeToString.end()) {
211       MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator
212                                    << " cannot find in the compare string map";
213     }
214     auto range_strng = iter_to_string->second;
215     buffer << range_strng.first << range.first << "," << range.second << range_strng.second << " ,but got " << arg_value
216            << ".";
217     MS_EXCEPTION(ValueError) << buffer.str();
218   }
219 
220   static ShapeMap ConvertShapePtrToShapeMap(const BaseShapePtr &shape);
221   static abstract::ShapePtr GetTensorInputShape(const std::string &prim_name,
222                                                 const std::vector<AbstractBasePtr> &input_args, int64_t index);
223   static void Check(const std::string &arg_name, int64_t arg_value, CompareEnum compare_type,
224                     const std::string &value_name, int64_t value, const std::string &prim_name = "",
225                     ExceptionType exception_type = ValueError);
226 
227   template <typename T>
228   static void Check(const std::string &arg_name, const std::vector<T> &arg_value, CompareEnum compare_type,
229                     const std::string &value_name, const std::vector<T> &value, const std::string &prim_name = "",
230                     ExceptionType exception_type = ValueError) {
231     if (compare_type != kEqual) {
232       auto iter = kCompareToString.find(compare_type);
233       if (iter != kCompareToString.end()) {
234         MS_EXCEPTION(NotSupportError) << "Only supported equal to compare two vectors but got " << iter->second;
235       }
236       MS_EXCEPTION(UnknownError) << "Cannot find the operator " << compare_type << "in the compare map!";
237     }
238     if (arg_value == value) {
239       return;
240     }
241     std::ostringstream buffer;
242     if (prim_name.empty()) {
243       buffer << "The attribute[" << arg_name << "]:";
244     } else {
245       buffer << "The primitive[" << prim_name << "]'s " << arg_name << ":";
246     }
247     auto iter_to_string = kCompareToString.find(compare_type);
248     if (iter_to_string == kCompareToString.end()) {
249       MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_type << " cannot find in the compare string map";
250     }
251 
252     buffer << " [";
253     for (auto item : arg_value) {
254       buffer << item << ",";
255     }
256     buffer << "]";
257     buffer << " must " << iter_to_string->second << "[";
258     for (auto item : value) {
259       buffer << item << ",";
260     }
261     buffer << "]";
262     MS_EXCEPTION(exception_type) << buffer.str();
263   }
264 
265   template <typename T>
CheckArgs(const std::string & op,const AbstractBasePtrList & args_spec_list,size_t index)266   static std::shared_ptr<T> CheckArgs(const std::string &op, const AbstractBasePtrList &args_spec_list, size_t index) {
267     if (index >= args_spec_list.size()) {
268       MS_EXCEPTION(ValueError) << op << " evaluator arguments list index out of bound, size " << args_spec_list.size()
269                                << ", index " << index;
270     }
271     auto args_spec = args_spec_list[index];
272     MS_EXCEPTION_IF_NULL(args_spec);
273     auto arg = dyn_cast<T>(args_spec);
274     if (arg == nullptr) {
275       MS_EXCEPTION(TypeError) << "The primitive[" << op << "]'s input[" << index << "] should be a "
276                               << abstract::ReportNameTraits<T>::name << ", but got "
277                               << args_spec_list[index]->BuildType()->ToString() << ".";
278     }
279     return arg;
280   }
281 
282   static TypePtr CheckTensorTypeSame(const std::map<std::string, TypePtr> &types, const std::set<TypePtr> &check_list,
283                                      const std::string &prim_name);
284   static ShapeVector CheckTensorIntValue(const std::string &type_name, const ValuePtr &value,
285                                          const std::string &prim_name);
286   static TypePtr CheckTensorTypeValid(const std::string &type_name, const TypePtr &type,
287                                       const std::set<TypePtr> &check_list, const std::string &prim_name);
288   static TypePtr CheckSubClass(const std::string &type_name, const TypePtr &type,
289                                const std::set<TypePtr> &template_types, const std::string &prim_name);
290   static TypePtr CheckScalarOrTensorTypesSame(const std::map<std::string, TypePtr> &args,
291                                               const std::set<TypePtr> &valid_values, const std::string &prim_name,
292                                               bool allow_mix = false);
293   static TypePtr CheckTypeValid(const std::string &arg_name, const TypePtr &arg_type,
294                                 const std::set<TypePtr> &valid_type, const std::string &prim_name);
295   static bool ConvertAttrValueToInt(const std::string &op_type, const std::string &attr_name, ValuePtr *const value);
296   static bool ConvertAttrValueToString(const std::string &op_type, const std::string &attr_name, ValuePtr *const value);
297   static void ConvertAttrValueInExport(const std::string &op_type, const std::string &attr_name, ValuePtr *const value);
298   static void ConvertAttrValueInLoad(const std::string &op_type, const std::string &attr_name, ValuePtr *const value);
299   static AttrConverterPair GetAttrConvertPair(const std::string &op_type, const std::string &attr_name);
300   static bool GetDataFormatEnumValue(const ValuePtr &value, int64_t *enum_value);
301   static void GetPadModEnumValue(const ValuePtr &value, int64_t *enum_value, bool is_upper = false);
302   static void GetReductionEnumValue(const ValuePtr &value, int64_t *enum_value);
303   static bool CheckIrAttrtoOpAttr(const std::string &op_type, const std::string &attr_name, ValuePtr *const value);
304   static void CheckSummaryParam(const AbstractBasePtr &name, const AbstractBasePtr &value,
305                                 const std::string &class_name);
306   static void CheckMode(const std::string &class_name);
307   static std::vector<int64_t> CheckAttrIntOrTupleInt(const std::string &prim_name, const ValuePtr &attr,
308                                                      const std::string &arg_name);
309   static std::vector<int64_t> CheckAttrTupleInt(const std::string &prim_name, const ValuePtr &attr,
310                                                 const std::string &arg_name);
311   static void CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVector *max_shape);
312   static int64_t GetAndCheckFormat(const ValuePtr &value);
313   static size_t GetRemoveMonadAbsNum(const AbstractBasePtrList &abs_list);
314   static void CheckInputArgs(const std::vector<AbstractBasePtr> &input_args, const CompareEnum compare_operator,
315                              const int64_t match_value, const std::string &prim_name);
316   static TypePtr GetInputTensorType(const std::vector<AbstractBasePtr> &input_args, const size_t index,
317                                     const std::string &prim_name);
318   static bool HasDynamicShapeInput(const AbstractBasePtrList &abs_list);
319 
320  private:
321   static TypePtr _CheckTypeSame(const std::map<std::string, TypePtr> &args, const std::string &prim_name,
322                                 const bool allow_mix);
323   static std::string GetErrorTypeString(const std::set<TypePtr> &check_list, const TypePtr &check_type);
324 };
325 }  // namespace mindspore
326 #endif  // MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H_
327