• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "include/common/utils/primfunc_utils.h"
18 #include "include/common/utils/convert_utils_py.h"
19 
20 namespace mindspore::ops {
EnumToString(OP_DTYPE dtype)21 std::string EnumToString(OP_DTYPE dtype) {
22   static const std::unordered_map<OP_DTYPE, std::string> kEnumToStringMap = {
23     {OP_DTYPE::DT_BOOL, "bool"},
24     {OP_DTYPE::DT_INT, "int"},
25     {OP_DTYPE::DT_FLOAT, "float"},
26     {OP_DTYPE::DT_NUMBER, "Number"},
27     {OP_DTYPE::DT_TENSOR, "Tensor"},
28     {OP_DTYPE::DT_STR, "string"},
29     {OP_DTYPE::DT_ANY, "Any"},
30     {OP_DTYPE::DT_TUPLE_BOOL, "tuple of bool"},
31     {OP_DTYPE::DT_TUPLE_INT, "tuple of int"},
32     {OP_DTYPE::DT_TUPLE_FLOAT, "tuple of float"},
33     {OP_DTYPE::DT_TUPLE_NUMBER, "tuple of Number"},
34     {OP_DTYPE::DT_TUPLE_TENSOR, "tuple of Tensor"},
35     {OP_DTYPE::DT_TUPLE_STR, "tuple of string"},
36     {OP_DTYPE::DT_TUPLE_ANY, "tuple of Any"},
37     {OP_DTYPE::DT_LIST_BOOL, "list of bool"},
38     {OP_DTYPE::DT_LIST_INT, "list of int"},
39     {OP_DTYPE::DT_LIST_FLOAT, "list of float"},
40     {OP_DTYPE::DT_LIST_NUMBER, "list of number"},
41     {OP_DTYPE::DT_LIST_TENSOR, "list of tensor"},
42     {OP_DTYPE::DT_LIST_STR, "list of string"},
43     {OP_DTYPE::DT_LIST_ANY, "list of Any"},
44     {OP_DTYPE::DT_TYPE, "mstype"},
45   };
46 
47   auto it = kEnumToStringMap.find(dtype);
48   if (it == kEnumToStringMap.end()) {
49     MS_LOG(INTERNAL_EXCEPTION) << "Failed to map Enum[" << dtype << "] to String.";
50   }
51   return it->second;
52 }
53 
54 namespace {
55 template <typename T>
ValidateSequenceType(const AbstractBasePtr & abs_seq,OP_DTYPE type_elem)56 bool ValidateSequenceType(const AbstractBasePtr &abs_seq, OP_DTYPE type_elem) {
57   if (!abs_seq->isa<T>()) {
58     return false;
59   }
60   if (type_elem == OP_DTYPE::DT_ANY) {
61     return true;
62   }
63   auto abs = abs_seq->cast<abstract::AbstractSequencePtr>();
64   MS_EXCEPTION_IF_NULL(abs);
65   if (abs->dynamic_len()) {
66     return true;
67   }
68   for (const auto &abs_elem : abs->elements()) {
69     if (!ValidateArgsType(abs_elem, type_elem)) {
70       return false;
71     }
72   }
73   return true;
74 }
75 
ValidateArgsSequenceType(const AbstractBasePtr & abs_arg,OP_DTYPE type_arg)76 bool ValidateArgsSequenceType(const AbstractBasePtr &abs_arg, OP_DTYPE type_arg) {
77   switch (static_cast<int>(type_arg)) {
78     case OP_DTYPE::DT_TUPLE_BOOL: {
79       return ValidateSequenceType<abstract::AbstractTuple>(abs_arg, OP_DTYPE::DT_BOOL);
80     }
81     case OP_DTYPE::DT_TUPLE_INT: {
82       return ValidateSequenceType<abstract::AbstractTuple>(abs_arg, OP_DTYPE::DT_INT);
83     }
84     case OP_DTYPE::DT_TUPLE_FLOAT: {
85       return ValidateSequenceType<abstract::AbstractTuple>(abs_arg, OP_DTYPE::DT_FLOAT);
86     }
87     case OP_DTYPE::DT_TUPLE_NUMBER: {
88       return ValidateSequenceType<abstract::AbstractTuple>(abs_arg, OP_DTYPE::DT_NUMBER);
89     }
90     case OP_DTYPE::DT_TUPLE_TENSOR: {
91       return ValidateSequenceType<abstract::AbstractTuple>(abs_arg, OP_DTYPE::DT_TENSOR);
92     }
93     case OP_DTYPE::DT_TUPLE_STR: {
94       return ValidateSequenceType<abstract::AbstractTuple>(abs_arg, OP_DTYPE::DT_STR);
95     }
96     case OP_DTYPE::DT_TUPLE_ANY: {
97       return ValidateSequenceType<abstract::AbstractTuple>(abs_arg, OP_DTYPE::DT_ANY);
98     }
99     case OP_DTYPE::DT_LIST_BOOL: {
100       return ValidateSequenceType<abstract::AbstractList>(abs_arg, OP_DTYPE::DT_BOOL);
101     }
102     case OP_DTYPE::DT_LIST_INT: {
103       return ValidateSequenceType<abstract::AbstractList>(abs_arg, OP_DTYPE::DT_INT);
104     }
105     case OP_DTYPE::DT_LIST_FLOAT: {
106       return ValidateSequenceType<abstract::AbstractList>(abs_arg, OP_DTYPE::DT_FLOAT);
107     }
108     case OP_DTYPE::DT_LIST_NUMBER: {
109       return ValidateSequenceType<abstract::AbstractList>(abs_arg, OP_DTYPE::DT_NUMBER);
110     }
111     case OP_DTYPE::DT_LIST_TENSOR: {
112       return ValidateSequenceType<abstract::AbstractList>(abs_arg, OP_DTYPE::DT_TENSOR);
113     }
114     case OP_DTYPE::DT_LIST_STR: {
115       return ValidateSequenceType<abstract::AbstractList>(abs_arg, OP_DTYPE::DT_STR);
116     }
117     case OP_DTYPE::DT_LIST_ANY: {
118       return ValidateSequenceType<abstract::AbstractList>(abs_arg, OP_DTYPE::DT_ANY);
119     }
120     default: {
121       MS_EXCEPTION(ValueError) << "Unknown op dtype " << EnumToString(type_arg);
122     }
123   }
124 }
125 }  // namespace
126 
ValidateArgsType(const AbstractBasePtr & abs_arg,OP_DTYPE type_arg)127 bool ValidateArgsType(const AbstractBasePtr &abs_arg, OP_DTYPE type_arg) {
128   auto abs_type = abs_arg->BuildType();
129   MS_EXCEPTION_IF_NULL(abs_type);
130   switch (static_cast<int>(type_arg)) {
131     case OP_DTYPE::DT_ANY: {
132       return true;
133     }
134     case OP_DTYPE::DT_BOOL: {
135       return abs_arg->isa<abstract::AbstractScalar>() && abs_type->isa<Bool>();
136     }
137     case OP_DTYPE::DT_INT: {
138       return abs_arg->isa<abstract::AbstractScalar>() && (abs_type->isa<Int>() || abs_type->isa<UInt>());
139     }
140     case OP_DTYPE::DT_FLOAT: {
141       return abs_arg->isa<abstract::AbstractScalar>() && (abs_type->isa<Float>() || abs_type->isa<BFloat>());
142     }
143     case OP_DTYPE::DT_NUMBER: {
144       return abs_arg->isa<abstract::AbstractScalar>() && abs_type->isa<Number>();
145     }
146     case OP_DTYPE::DT_STR: {
147       return abs_arg->isa<abstract::AbstractScalar>() && abs_type->isa<String>();
148     }
149     case OP_DTYPE::DT_TENSOR: {
150       return abs_arg->isa<abstract::AbstractTensor>();
151     }
152     case OP_DTYPE::DT_TYPE: {
153       return abs_arg->isa<abstract::AbstractType>() && abs_type->isa<Type>();
154     }
155     default: {
156       return ValidateArgsSequenceType(abs_arg, type_arg);
157     }
158   }
159   return false;
160 }
161 
// Maps an argument handler name to the type the caller actually supplies before the handler
// converts it (e.g. a "mindspore.dtype" converted by dtype_to_type_id). An empty handler, or one
// without a registered source type, leaves `type` unchanged.
static inline std::string GetRealTypeByHandler(const std::string &type, const std::string &handler) {
  if (!handler.empty()) {
    static const std::unordered_map<std::string, std::string> kHandlerSourceTypes = {
      {"dtype_to_type_id", "mindspore.dtype"},
      {"str_to_enum", "string"},
    };
    const auto found = kHandlerSourceTypes.find(handler);
    if (found != kHandlerSourceTypes.end()) {
      return found->second;
    }
  }
  return type;
}
171 
GetRealInputType(const ops::OpInputArg & op_arg)172 static inline std::string GetRealInputType(const ops::OpInputArg &op_arg) {
173   return GetRealTypeByHandler(EnumToString(op_arg.arg_dtype_), op_arg.arg_handler_);
174 }
175 
GetRealTypes(const std::vector<std::string> & op_type_list,const std::vector<OpInputArg> & input_args)176 static inline std::vector<std::string> GetRealTypes(const std::vector<std::string> &op_type_list,
177                                                     const std::vector<OpInputArg> &input_args) {
178   if (input_args.size() != op_type_list.size()) {
179     MS_LOG_EXCEPTION << "size of input_args and op_type_list should be equal, but got " << input_args.size() << " vs "
180                      << op_type_list.size();
181   }
182   std::vector<std::string> real_types(op_type_list.size());
183   for (size_t i = 0; i < op_type_list.size(); ++i) {
184     real_types[i] = GetRealTypeByHandler(op_type_list[i], input_args[i].arg_handler_);
185   }
186   return real_types;
187 }
188 
BuildOpErrorMsg(const OpDefPtr & op_def,const std::vector<std::string> & op_type_list)189 std::string BuildOpErrorMsg(const OpDefPtr &op_def, const std::vector<std::string> &op_type_list) {
190   std::stringstream init_arg_ss;
191   std::stringstream input_arg_ss;
192   for (const auto &op_arg : op_def->args_) {
193     if (op_arg.as_init_arg_) {
194       init_arg_ss << op_arg.arg_name_ << "=<";
195       for (const auto &dtype : op_arg.cast_dtype_) {
196         init_arg_ss << EnumToString(dtype) << ", ";
197       }
198       init_arg_ss << GetRealInputType(op_arg) << ">, ";
199     } else {
200       input_arg_ss << op_arg.arg_name_ << "=<";
201       for (const auto &dtype : op_arg.cast_dtype_) {
202         input_arg_ss << EnumToString(dtype) << ", ";
203       }
204       input_arg_ss << GetRealInputType(op_arg) << ">, ";
205     }
206   }
207 
208   auto init_arg_str = init_arg_ss.str();
209   auto input_arg_str = input_arg_ss.str();
210   constexpr size_t truncate_offset = 2;
211   init_arg_str =
212     init_arg_str.empty() ? "" : init_arg_str.replace(init_arg_str.end() - truncate_offset, init_arg_str.end(), "");
213   input_arg_str =
214     input_arg_str.empty() ? "" : input_arg_str.replace(input_arg_str.end() - truncate_offset, input_arg_str.end(), "");
215 
216   std::stringstream real_init_arg_ss;
217   std::stringstream real_input_arg_ss;
218   auto real_op_type_list = GetRealTypes(op_type_list, op_def->args_);
219   for (size_t i = 0; i < real_op_type_list.size(); i++) {
220     const auto &op_arg = op_def->args_[i];
221     if (op_arg.as_init_arg_) {
222       real_init_arg_ss << op_arg.arg_name_ << "=" << real_op_type_list[i] << ", ";
223     } else {
224       real_input_arg_ss << op_arg.arg_name_ << "=" << real_op_type_list[i] << ", ";
225     }
226   }
227   auto real_init_arg_str = real_init_arg_ss.str();
228   auto real_input_arg_str = real_input_arg_ss.str();
229   real_init_arg_str = real_init_arg_str.empty() ? ""
230                                                 : real_init_arg_str.replace(real_init_arg_str.end() - truncate_offset,
231                                                                             real_init_arg_str.end(), "");
232   real_input_arg_str =
233     real_input_arg_str.empty()
234       ? ""
235       : real_input_arg_str.replace(real_input_arg_str.end() - truncate_offset, real_input_arg_str.end(), "");
236 
237   std::stringstream ss;
238   ss << "Failed calling " << op_def->name_ << " with \"" << op_def->name_ << "(" << real_init_arg_str << ")("
239      << real_input_arg_str << ")\"." << std::endl;
240   ss << "The valid calling should be: " << std::endl;
241   ss << "\"" << op_def->name_ << "(" << init_arg_str << ")(" << input_arg_str << ")\".";
242   return ss.str();
243 }
244 }  // namespace mindspore::ops
245