1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "include/common/utils/primfunc_utils.h"
18 #include "include/common/utils/convert_utils_py.h"
19
20 namespace mindspore::ops {
EnumToString(OP_DTYPE dtype)21 std::string EnumToString(OP_DTYPE dtype) {
22 static const std::unordered_map<OP_DTYPE, std::string> kEnumToStringMap = {
23 {OP_DTYPE::DT_BOOL, "bool"},
24 {OP_DTYPE::DT_INT, "int"},
25 {OP_DTYPE::DT_FLOAT, "float"},
26 {OP_DTYPE::DT_NUMBER, "Number"},
27 {OP_DTYPE::DT_TENSOR, "Tensor"},
28 {OP_DTYPE::DT_STR, "string"},
29 {OP_DTYPE::DT_ANY, "Any"},
30 {OP_DTYPE::DT_TUPLE_BOOL, "tuple of bool"},
31 {OP_DTYPE::DT_TUPLE_INT, "tuple of int"},
32 {OP_DTYPE::DT_TUPLE_FLOAT, "tuple of float"},
33 {OP_DTYPE::DT_TUPLE_NUMBER, "tuple of Number"},
34 {OP_DTYPE::DT_TUPLE_TENSOR, "tuple of Tensor"},
35 {OP_DTYPE::DT_TUPLE_STR, "tuple of string"},
36 {OP_DTYPE::DT_TUPLE_ANY, "tuple of Any"},
37 {OP_DTYPE::DT_LIST_BOOL, "list of bool"},
38 {OP_DTYPE::DT_LIST_INT, "list of int"},
39 {OP_DTYPE::DT_LIST_FLOAT, "list of float"},
40 {OP_DTYPE::DT_LIST_NUMBER, "list of number"},
41 {OP_DTYPE::DT_LIST_TENSOR, "list of tensor"},
42 {OP_DTYPE::DT_LIST_STR, "list of string"},
43 {OP_DTYPE::DT_LIST_ANY, "list of Any"},
44 {OP_DTYPE::DT_TYPE, "mstype"},
45 };
46
47 auto it = kEnumToStringMap.find(dtype);
48 if (it == kEnumToStringMap.end()) {
49 MS_LOG(INTERNAL_EXCEPTION) << "Failed to map Enum[" << dtype << "] to String.";
50 }
51 return it->second;
52 }
53
54 namespace {
55 template <typename T>
ValidateSequenceType(const AbstractBasePtr & abs_seq,OP_DTYPE type_elem)56 bool ValidateSequenceType(const AbstractBasePtr &abs_seq, OP_DTYPE type_elem) {
57 if (!abs_seq->isa<T>()) {
58 return false;
59 }
60 if (type_elem == OP_DTYPE::DT_ANY) {
61 return true;
62 }
63 auto abs = abs_seq->cast<abstract::AbstractSequencePtr>();
64 MS_EXCEPTION_IF_NULL(abs);
65 if (abs->dynamic_len()) {
66 return true;
67 }
68 for (const auto &abs_elem : abs->elements()) {
69 if (!ValidateArgsType(abs_elem, type_elem)) {
70 return false;
71 }
72 }
73 return true;
74 }
75
ValidateArgsSequenceType(const AbstractBasePtr & abs_arg,OP_DTYPE type_arg)76 bool ValidateArgsSequenceType(const AbstractBasePtr &abs_arg, OP_DTYPE type_arg) {
77 switch (static_cast<int>(type_arg)) {
78 case OP_DTYPE::DT_TUPLE_BOOL: {
79 return ValidateSequenceType<abstract::AbstractTuple>(abs_arg, OP_DTYPE::DT_BOOL);
80 }
81 case OP_DTYPE::DT_TUPLE_INT: {
82 return ValidateSequenceType<abstract::AbstractTuple>(abs_arg, OP_DTYPE::DT_INT);
83 }
84 case OP_DTYPE::DT_TUPLE_FLOAT: {
85 return ValidateSequenceType<abstract::AbstractTuple>(abs_arg, OP_DTYPE::DT_FLOAT);
86 }
87 case OP_DTYPE::DT_TUPLE_NUMBER: {
88 return ValidateSequenceType<abstract::AbstractTuple>(abs_arg, OP_DTYPE::DT_NUMBER);
89 }
90 case OP_DTYPE::DT_TUPLE_TENSOR: {
91 return ValidateSequenceType<abstract::AbstractTuple>(abs_arg, OP_DTYPE::DT_TENSOR);
92 }
93 case OP_DTYPE::DT_TUPLE_STR: {
94 return ValidateSequenceType<abstract::AbstractTuple>(abs_arg, OP_DTYPE::DT_STR);
95 }
96 case OP_DTYPE::DT_TUPLE_ANY: {
97 return ValidateSequenceType<abstract::AbstractTuple>(abs_arg, OP_DTYPE::DT_ANY);
98 }
99 case OP_DTYPE::DT_LIST_BOOL: {
100 return ValidateSequenceType<abstract::AbstractList>(abs_arg, OP_DTYPE::DT_BOOL);
101 }
102 case OP_DTYPE::DT_LIST_INT: {
103 return ValidateSequenceType<abstract::AbstractList>(abs_arg, OP_DTYPE::DT_INT);
104 }
105 case OP_DTYPE::DT_LIST_FLOAT: {
106 return ValidateSequenceType<abstract::AbstractList>(abs_arg, OP_DTYPE::DT_FLOAT);
107 }
108 case OP_DTYPE::DT_LIST_NUMBER: {
109 return ValidateSequenceType<abstract::AbstractList>(abs_arg, OP_DTYPE::DT_NUMBER);
110 }
111 case OP_DTYPE::DT_LIST_TENSOR: {
112 return ValidateSequenceType<abstract::AbstractList>(abs_arg, OP_DTYPE::DT_TENSOR);
113 }
114 case OP_DTYPE::DT_LIST_STR: {
115 return ValidateSequenceType<abstract::AbstractList>(abs_arg, OP_DTYPE::DT_STR);
116 }
117 case OP_DTYPE::DT_LIST_ANY: {
118 return ValidateSequenceType<abstract::AbstractList>(abs_arg, OP_DTYPE::DT_ANY);
119 }
120 default: {
121 MS_EXCEPTION(ValueError) << "Unknown op dtype " << EnumToString(type_arg);
122 }
123 }
124 }
125 } // namespace
126
ValidateArgsType(const AbstractBasePtr & abs_arg,OP_DTYPE type_arg)127 bool ValidateArgsType(const AbstractBasePtr &abs_arg, OP_DTYPE type_arg) {
128 auto abs_type = abs_arg->BuildType();
129 MS_EXCEPTION_IF_NULL(abs_type);
130 switch (static_cast<int>(type_arg)) {
131 case OP_DTYPE::DT_ANY: {
132 return true;
133 }
134 case OP_DTYPE::DT_BOOL: {
135 return abs_arg->isa<abstract::AbstractScalar>() && abs_type->isa<Bool>();
136 }
137 case OP_DTYPE::DT_INT: {
138 return abs_arg->isa<abstract::AbstractScalar>() && (abs_type->isa<Int>() || abs_type->isa<UInt>());
139 }
140 case OP_DTYPE::DT_FLOAT: {
141 return abs_arg->isa<abstract::AbstractScalar>() && (abs_type->isa<Float>() || abs_type->isa<BFloat>());
142 }
143 case OP_DTYPE::DT_NUMBER: {
144 return abs_arg->isa<abstract::AbstractScalar>() && abs_type->isa<Number>();
145 }
146 case OP_DTYPE::DT_STR: {
147 return abs_arg->isa<abstract::AbstractScalar>() && abs_type->isa<String>();
148 }
149 case OP_DTYPE::DT_TENSOR: {
150 return abs_arg->isa<abstract::AbstractTensor>();
151 }
152 case OP_DTYPE::DT_TYPE: {
153 return abs_arg->isa<abstract::AbstractType>() && abs_type->isa<Type>();
154 }
155 default: {
156 return ValidateArgsSequenceType(abs_arg, type_arg);
157 }
158 }
159 return false;
160 }
161
// Returns the source type registered for `handler`. Falls back to `type` unchanged when `handler` is empty or
// has no registered source type.
static inline std::string GetRealTypeByHandler(const std::string &type, const std::string &handler) {
  static const std::unordered_map<std::string, std::string> kHandlerSourceTypes{
    {"dtype_to_type_id", "mindspore.dtype"},
    {"str_to_enum", "string"},
  };
  if (!handler.empty()) {
    const auto match = kHandlerSourceTypes.find(handler);
    if (match != kHandlerSourceTypes.end()) {
      return match->second;
    }
  }
  return type;
}
171
GetRealInputType(const ops::OpInputArg & op_arg)172 static inline std::string GetRealInputType(const ops::OpInputArg &op_arg) {
173 return GetRealTypeByHandler(EnumToString(op_arg.arg_dtype_), op_arg.arg_handler_);
174 }
175
GetRealTypes(const std::vector<std::string> & op_type_list,const std::vector<OpInputArg> & input_args)176 static inline std::vector<std::string> GetRealTypes(const std::vector<std::string> &op_type_list,
177 const std::vector<OpInputArg> &input_args) {
178 if (input_args.size() != op_type_list.size()) {
179 MS_LOG_EXCEPTION << "size of input_args and op_type_list should be equal, but got " << input_args.size() << " vs "
180 << op_type_list.size();
181 }
182 std::vector<std::string> real_types(op_type_list.size());
183 for (size_t i = 0; i < op_type_list.size(); ++i) {
184 real_types[i] = GetRealTypeByHandler(op_type_list[i], input_args[i].arg_handler_);
185 }
186 return real_types;
187 }
188
BuildOpErrorMsg(const OpDefPtr & op_def,const std::vector<std::string> & op_type_list)189 std::string BuildOpErrorMsg(const OpDefPtr &op_def, const std::vector<std::string> &op_type_list) {
190 std::stringstream init_arg_ss;
191 std::stringstream input_arg_ss;
192 for (const auto &op_arg : op_def->args_) {
193 if (op_arg.as_init_arg_) {
194 init_arg_ss << op_arg.arg_name_ << "=<";
195 for (const auto &dtype : op_arg.cast_dtype_) {
196 init_arg_ss << EnumToString(dtype) << ", ";
197 }
198 init_arg_ss << GetRealInputType(op_arg) << ">, ";
199 } else {
200 input_arg_ss << op_arg.arg_name_ << "=<";
201 for (const auto &dtype : op_arg.cast_dtype_) {
202 input_arg_ss << EnumToString(dtype) << ", ";
203 }
204 input_arg_ss << GetRealInputType(op_arg) << ">, ";
205 }
206 }
207
208 auto init_arg_str = init_arg_ss.str();
209 auto input_arg_str = input_arg_ss.str();
210 constexpr size_t truncate_offset = 2;
211 init_arg_str =
212 init_arg_str.empty() ? "" : init_arg_str.replace(init_arg_str.end() - truncate_offset, init_arg_str.end(), "");
213 input_arg_str =
214 input_arg_str.empty() ? "" : input_arg_str.replace(input_arg_str.end() - truncate_offset, input_arg_str.end(), "");
215
216 std::stringstream real_init_arg_ss;
217 std::stringstream real_input_arg_ss;
218 auto real_op_type_list = GetRealTypes(op_type_list, op_def->args_);
219 for (size_t i = 0; i < real_op_type_list.size(); i++) {
220 const auto &op_arg = op_def->args_[i];
221 if (op_arg.as_init_arg_) {
222 real_init_arg_ss << op_arg.arg_name_ << "=" << real_op_type_list[i] << ", ";
223 } else {
224 real_input_arg_ss << op_arg.arg_name_ << "=" << real_op_type_list[i] << ", ";
225 }
226 }
227 auto real_init_arg_str = real_init_arg_ss.str();
228 auto real_input_arg_str = real_input_arg_ss.str();
229 real_init_arg_str = real_init_arg_str.empty() ? ""
230 : real_init_arg_str.replace(real_init_arg_str.end() - truncate_offset,
231 real_init_arg_str.end(), "");
232 real_input_arg_str =
233 real_input_arg_str.empty()
234 ? ""
235 : real_input_arg_str.replace(real_input_arg_str.end() - truncate_offset, real_input_arg_str.end(), "");
236
237 std::stringstream ss;
238 ss << "Failed calling " << op_def->name_ << " with \"" << op_def->name_ << "(" << real_init_arg_str << ")("
239 << real_input_arg_str << ")\"." << std::endl;
240 ss << "The valid calling should be: " << std::endl;
241 ss << "\"" << op_def->name_ << "(" << init_arg_str << ")(" << input_arg_str << ")\".";
242 return ss.str();
243 }
244 } // namespace mindspore::ops
245