1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "abstract/param_validator.h"
18
19 #include <algorithm>
20 #include <set>
21 #include <string>
22 #include <sstream>
23 #include <memory>
24 #include "utils/symbolic.h"
25 #include "abstract/utils.h"
26
27 namespace mindspore {
28 namespace abstract {
29 #define ABSTRACT_REPORT_NAME_DEC(abstract) constexpr char ReportNameTraits<Abstract##abstract>::name[];
30
31 ABSTRACT_REPORT_NAME_DEC(Tensor)
ABSTRACT_REPORT_NAME_DEC(Tuple)32 ABSTRACT_REPORT_NAME_DEC(Tuple)
33 ABSTRACT_REPORT_NAME_DEC(Scalar)
34 ABSTRACT_REPORT_NAME_DEC(List)
35 ABSTRACT_REPORT_NAME_DEC(Dictionary)
36 ABSTRACT_REPORT_NAME_DEC(Slice)
37 ABSTRACT_REPORT_NAME_DEC(Function)
38 ABSTRACT_REPORT_NAME_DEC(Type)
39 ABSTRACT_REPORT_NAME_DEC(KeywordArg)
40 ABSTRACT_REPORT_NAME_DEC(Class)
41
42 TypePtr CheckType(TypePtr type, const TypePtrList &accepts, const std::string &error_message_prefix) {
43 bool ok = std::any_of(accepts.begin(), accepts.end(),
44 [type](const TypePtr &accept) -> bool { return IsIdentidityOrSubclass(type, accept); });
45 if (ok) {
46 return type;
47 } else {
48 MS_LOG(EXCEPTION) << error_message_prefix << accepts << " but is " << type->ToString();
49 }
50 }
51
CheckTensorDType(const AbstractTensorPtr & tensor,const TypePtrList & accepts,const std::string & error_message_prefix)52 TypePtr CheckTensorDType(const AbstractTensorPtr &tensor, const TypePtrList &accepts,
53 const std::string &error_message_prefix) {
54 MS_EXCEPTION_IF_NULL(tensor);
55 TypePtr type = tensor->BuildType();
56 MS_EXCEPTION_IF_NULL(type);
57 if (!type->isa<TensorType>()) {
58 MS_LOG(EXCEPTION) << error_message_prefix << "requires Tensor but got " << type->ToString();
59 }
60 auto elem = tensor->element();
61 MS_EXCEPTION_IF_NULL(elem);
62 TypePtr ele_type = elem->BuildType();
63 if (ele_type == nullptr) {
64 MS_LOG(EXCEPTION) << "Abstract tensor element type nullptr";
65 }
66 return CheckType(ele_type, accepts, error_message_prefix);
67 }
68
CheckTensorsDTypeSame(const AbstractTensorPtrList & tensor_list,const TypePtrList & accepts,const std::string & error_message_prefix)69 TypePtr CheckTensorsDTypeSame(const AbstractTensorPtrList &tensor_list, const TypePtrList &accepts,
70 const std::string &error_message_prefix) {
71 if (tensor_list.empty()) {
72 MS_LOG(EXCEPTION) << "Array list is empty";
73 }
74
75 auto sample_tensor = tensor_list[0];
76 MS_EXCEPTION_IF_NULL(sample_tensor);
77 auto sample_elem = sample_tensor->element();
78 MS_EXCEPTION_IF_NULL(sample_elem);
79 TypePtr sample_type = sample_elem->BuildType();
80 MS_EXCEPTION_IF_NULL(sample_type);
81 std::ostringstream loginfoBuffer;
82 loginfoBuffer << "same type, got";
83 // Check if other elements have the same type with the first element.
84 for (size_t index = 1; index < tensor_list.size(); ++index) {
85 MS_EXCEPTION_IF_NULL(tensor_list[index]);
86 auto elem = tensor_list[index]->element();
87 MS_EXCEPTION_IF_NULL(elem);
88 auto a_type = elem->BuildType();
89 MS_EXCEPTION_IF_NULL(a_type);
90 loginfoBuffer << " " << a_type->ToString();
91 if (sample_type->type_id() != a_type->type_id()) {
92 MS_LOG(EXCEPTION) << "Expected type " << sample_type->ToString() << ", but got " << a_type->ToString()
93 << ", index " << index;
94 }
95 }
96 MS_LOG(DEBUG) << error_message_prefix << loginfoBuffer.str();
97 return CheckTensorDType(sample_tensor, accepts, error_message_prefix);
98 }
99
CheckScalarType(const AbstractScalarPtr & scalar,const TypePtrList & accepts,const std::string & error_message_prefix)100 TypePtr CheckScalarType(const AbstractScalarPtr &scalar, const TypePtrList &accepts,
101 const std::string &error_message_prefix) {
102 if (scalar == nullptr) {
103 MS_LOG(EXCEPTION) << "Scalar nullptr";
104 }
105 auto type = scalar->BuildType();
106 if (type == nullptr) {
107 MS_LOG(EXCEPTION) << "Scalar value nullptr";
108 }
109
110 return CheckType(type, accepts, error_message_prefix);
111 }
112
CheckShapeSame(const std::string & op,const AbstractTensorPtr & tensor_base,const AbstractTensorPtr & tensor)113 ShapePtr CheckShapeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor) {
114 MS_EXCEPTION_IF_NULL(tensor_base);
115 ShapePtr shape_base = tensor_base->shape();
116 MS_EXCEPTION_IF_NULL(shape_base);
117 MS_EXCEPTION_IF_NULL(tensor);
118 ShapePtr shape = tensor->shape();
119 MS_EXCEPTION_IF_NULL(shape);
120 if (*shape != *shape_base) {
121 MS_LOG(EXCEPTION) << op << " evaluator first arg shape " << shape->ToString()
122 << " are not consistent with second arg shape " << shape_base->ToString();
123 }
124 return shape_base;
125 }
126
CheckDtypeSame(const std::string & op,const AbstractTensorPtr & tensor_base,const AbstractTensorPtr & tensor)127 TypePtr CheckDtypeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor) {
128 MS_EXCEPTION_IF_NULL(tensor_base);
129 auto base_elem = tensor_base->element();
130 MS_EXCEPTION_IF_NULL(base_elem);
131 TypePtr type_base = base_elem->BuildType();
132 MS_EXCEPTION_IF_NULL(tensor);
133 auto tensor_elem = tensor->element();
134 MS_EXCEPTION_IF_NULL(tensor_elem);
135 TypePtr type = tensor_elem->BuildType();
136 MS_EXCEPTION_IF_NULL(type_base);
137 MS_EXCEPTION_IF_NULL(type);
138 if (*type != *type_base) {
139 MS_LOG(EXCEPTION) << op << " evaluator first arg dtype " << type_base->ToString()
140 << " are not consistent with second arg dtype " << type->ToString();
141 }
142 return type_base;
143 }
144
CheckAxis(const std::string & op,const ValuePtr & axis,int64_t minimum,int64_t max)145 int64_t CheckAxis(const std::string &op, const ValuePtr &axis, int64_t minimum, int64_t max) {
146 if (axis == nullptr) {
147 MS_LOG(EXCEPTION) << op << " evaluator axis is null";
148 }
149 if (!axis->isa<Int64Imm>()) {
150 MS_LOG(EXCEPTION) << op << " evaluator axis should be int64_t, but got " << axis->type_name();
151 }
152 int64_t axis_value = GetValue<int64_t>(axis);
153 if (axis_value > max || axis_value < minimum) {
154 MS_LOG(EXCEPTION) << op << " evaluator axis value should be in the range [" << minimum << ", " << max
155 << "], but get " << axis_value;
156 }
157 return axis_value;
158 }
CheckArgsSize(const std::string & op,const mindspore::abstract::AbstractBasePtrList & args_spec_list,size_t size_expect)159 void CheckArgsSize(const std::string &op, const mindspore::abstract::AbstractBasePtrList &args_spec_list,
160 size_t size_expect) {
161 if (args_spec_list.size() != size_expect) {
162 MS_LOG(EXCEPTION) << op << " input args size should be " << size_expect << ", but got " << args_spec_list.size();
163 }
164
165 for (size_t i = 0; i < size_expect; i++) {
166 MS_EXCEPTION_IF_NULL(args_spec_list[i]);
167 }
168 }
169
CheckShapeAllPositive(const std::string & op,const ShapeVector & shape)170 void CheckShapeAllPositive(const std::string &op, const ShapeVector &shape) {
171 for (size_t i = 0; i < shape.size(); ++i) {
172 if (shape[i] < 0) {
173 MS_LOG(EXCEPTION) << op << " shape element [" << i << "] must be positive integer, but got " << shape[i];
174 }
175 }
176 }
177
CheckShapeAnyAndPositive(const std::string & op,const ShapeVector & shape)178 void CheckShapeAnyAndPositive(const std::string &op, const ShapeVector &shape) {
179 for (size_t i = 0; i < shape.size(); ++i) {
180 if ((shape[i] < 0) && (shape[i] != Shape::SHP_ANY)) {
181 MS_EXCEPTION(ValueError) << op << " shape element [" << i << "] must be positive integer or SHP_ANY, but got "
182 << shape[i];
183 }
184 }
185 }
186
CheckAttrPositiveInt64(const std::string & op,const ValuePtr & attr,const std::string & attr_name)187 int64_t CheckAttrPositiveInt64(const std::string &op, const ValuePtr &attr, const std::string &attr_name) {
188 MS_EXCEPTION_IF_NULL(attr);
189 auto int64_value = attr->cast<Int64ImmPtr>();
190 MS_EXCEPTION_IF_NULL(int64_value);
191 int64_t attr_val = int64_value->value();
192 if (attr_val <= 0) {
193 MS_LOG(EXCEPTION) << op << " invalid " << attr_name << " value: " << attr_val << ", should be greater then 0";
194 }
195 return attr_val;
196 }
197
CheckAttrIntOrTuple(const std::string & op,const ValuePtr & attr,const size_t start_idx,const size_t num_element)198 std::vector<int64_t> CheckAttrIntOrTuple(const std::string &op, const ValuePtr &attr, const size_t start_idx,
199 const size_t num_element) {
200 std::vector<int64_t> result;
201 MS_EXCEPTION_IF_NULL(attr);
202 if (attr->isa<ValueTuple>()) {
203 auto tuple_attr = attr->cast<ValueTuplePtr>();
204 MS_EXCEPTION_IF_NULL(tuple_attr);
205 std::vector<ValuePtr> attr_vec = tuple_attr->value();
206 if (start_idx > attr_vec.size() || start_idx + num_element > attr_vec.size()) {
207 MS_EXCEPTION(IndexError) << op << " attr index is out of range, attr size is " << attr_vec.size()
208 << "but start idx got" << start_idx << " num element " << num_element;
209 }
210 auto it_start = attr_vec.begin() + start_idx;
211 (void)std::transform(it_start, it_start + num_element, std::back_inserter(result),
212 [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
213 } else {
214 auto int64_imm = attr->cast<Int64ImmPtr>();
215 MS_EXCEPTION_IF_NULL(int64_imm);
216 int64_t attr_val = int64_imm->value();
217 (void)result.insert(result.begin(), num_element, attr_val);
218 }
219 return result;
220 }
221
CheckAttrStringSet(const std::string & op,const ValuePtr & attr,const std::string & attr_name,const std::set<std::string> & val_set)222 std::string CheckAttrStringSet(const std::string &op, const ValuePtr &attr, const std::string &attr_name,
223 const std::set<std::string> &val_set) {
224 MS_EXCEPTION_IF_NULL(attr);
225 auto string_attr = attr->cast<StringImmPtr>();
226 MS_EXCEPTION_IF_NULL(string_attr);
227 std::string attr_val = string_attr->value();
228 if (val_set.find(attr_val) == val_set.end()) {
229 std::ostringstream buffer;
230 bool f_begin = true;
231 buffer << "{";
232 for (auto &x : val_set) {
233 if (!f_begin) {
234 buffer << ", ";
235 } else {
236 f_begin = false;
237 }
238 buffer << x;
239 }
240 buffer << "}";
241 MS_LOG(EXCEPTION) << op << "Unsupported " << attr_name << ": " << attr_val << ". use " << buffer.str();
242 }
243 return attr_val;
244 }
245
CheckRequiredArgsSize(const std::string & op,const mindspore::abstract::AbstractBasePtrList & args_spec_list,size_t size_expect)246 void CheckRequiredArgsSize(const std::string &op, const mindspore::abstract::AbstractBasePtrList &args_spec_list,
247 size_t size_expect) {
248 if (args_spec_list.size() < size_expect) {
249 MS_LOG(EXCEPTION) << op << " required input args size " << size_expect << ", but got " << args_spec_list.size();
250 }
251 for (size_t i = 0; i < size_expect; i++) {
252 MS_EXCEPTION_IF_NULL(args_spec_list[i]);
253 }
254 }
255
256 } // namespace abstract
257 } // namespace mindspore
258