1 /**
2 * Copyright 2019-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "abstract/param_validator.h"
18
19 #include <algorithm>
20 #include <string>
21 #include <sstream>
22 #include <memory>
23 #include "abstract/dshape.h"
24 #include "ir/dtype.h"
25 #include "ir/dtype/tensor_type.h"
26 #include "ir/scalar.h"
27 namespace mindspore {
28 namespace abstract {
29 #define ABSTRACT_REPORT_NAME_DEC(abstract) constexpr char ReportNameTraits<Abstract##abstract>::name[];
30
31 ABSTRACT_REPORT_NAME_DEC(Tensor)
ABSTRACT_REPORT_NAME_DEC(Tuple)32 ABSTRACT_REPORT_NAME_DEC(Tuple)
33 ABSTRACT_REPORT_NAME_DEC(Scalar)
34 ABSTRACT_REPORT_NAME_DEC(List)
35 ABSTRACT_REPORT_NAME_DEC(Dictionary)
36 ABSTRACT_REPORT_NAME_DEC(Slice)
37 ABSTRACT_REPORT_NAME_DEC(Function)
38 ABSTRACT_REPORT_NAME_DEC(Type)
39 ABSTRACT_REPORT_NAME_DEC(KeywordArg)
40
41 TypePtr CheckType(TypePtr type, const TypePtrList &accepts, const std::string &error_message_prefix) {
42 auto ori_type = type;
43 if (type->isa<TensorType>()) {
44 auto tensor = type->cast_ptr<TensorType>();
45 type = tensor->element();
46 MS_EXCEPTION_IF_NULL(type);
47 }
48 bool ok = std::any_of(accepts.begin(), accepts.end(),
49 [type](const TypePtr &accept) -> bool { return IsIdentidityOrSubclass(type, accept); });
50 if (ok) {
51 return type;
52 } else {
53 MS_EXCEPTION(TypeError) << error_message_prefix << " should be Tensor" << accepts << ",but got "
54 << ori_type->ToString();
55 }
56 }
57
CheckTensorDType(const AbstractBasePtr & tensor,const TypePtrList & accepts,const std::string & error_message_prefix)58 TypePtr CheckTensorDType(const AbstractBasePtr &tensor, const TypePtrList &accepts,
59 const std::string &error_message_prefix) {
60 MS_EXCEPTION_IF_NULL(tensor);
61 TypePtr type = tensor->GetType();
62 MS_EXCEPTION_IF_NULL(type);
63 if (!type->isa<TensorType>()) {
64 MS_LOG(EXCEPTION) << error_message_prefix << "requires Tensor but got " << type->ToString();
65 }
66 return CheckType(type, accepts, error_message_prefix);
67 }
68
CheckTensorsDTypeSame(const AbstractTensorPtrList & tensor_list,const TypePtrList & accepts,const std::string & error_message_prefix)69 TypePtr CheckTensorsDTypeSame(const AbstractTensorPtrList &tensor_list, const TypePtrList &accepts,
70 const std::string &error_message_prefix) {
71 if (tensor_list.empty()) {
72 MS_LOG(EXCEPTION) << "Array list is empty";
73 }
74
75 auto sample_tensor = tensor_list[0];
76 MS_EXCEPTION_IF_NULL(sample_tensor);
77 auto sample_elem = sample_tensor->element();
78 MS_EXCEPTION_IF_NULL(sample_elem);
79 TypePtr sample_type = sample_elem->BuildType();
80 MS_EXCEPTION_IF_NULL(sample_type);
81 std::ostringstream loginfoBuffer;
82 loginfoBuffer << "[" << sample_tensor->BuildType()->ToString();
83 bool error_flag = false;
84 // Check if other elements have the same type with the first element.
85 for (size_t index = 1; index < tensor_list.size(); ++index) {
86 MS_EXCEPTION_IF_NULL(tensor_list[index]);
87 auto elem = tensor_list[index]->element();
88 MS_EXCEPTION_IF_NULL(elem);
89 auto a_type = elem->BuildType();
90 MS_EXCEPTION_IF_NULL(a_type);
91 loginfoBuffer << "," << tensor_list[index]->BuildType()->ToString();
92 if (sample_type->type_id() != a_type->type_id()) {
93 error_flag = true;
94 }
95 }
96 if (error_flag) {
97 MS_EXCEPTION(ValueError) << error_message_prefix << " must be same, but got " << loginfoBuffer.str() << "]";
98 }
99 MS_LOG(DEBUG) << error_message_prefix << loginfoBuffer.str();
100 return CheckTensorDType(sample_tensor, accepts, error_message_prefix);
101 }
102
CheckScalarType(const AbstractScalarPtr & scalar,const TypePtrList & accepts,const std::string & error_message_prefix)103 TypePtr CheckScalarType(const AbstractScalarPtr &scalar, const TypePtrList &accepts,
104 const std::string &error_message_prefix) {
105 if (scalar == nullptr) {
106 MS_LOG(INTERNAL_EXCEPTION) << "Scalar nullptr";
107 }
108 auto type = scalar->BuildType();
109 if (type == nullptr) {
110 MS_LOG(INTERNAL_EXCEPTION) << "Scalar value nullptr";
111 }
112
113 return CheckType(type, accepts, error_message_prefix);
114 }
115
// Newer overloads: take generic AbstractBasePtr inputs and query them through GetType()/GetShape().
CheckShapeSame(const std::string & op,const AbstractBasePtr & tensor_base,const AbstractBasePtr & tensor)117 void CheckShapeSame(const std::string &op, const AbstractBasePtr &tensor_base, const AbstractBasePtr &tensor) {
118 MS_EXCEPTION_IF_NULL(tensor_base);
119 if (tensor_base->GetType()->object_type() != kObjectTypeTensorType) {
120 MS_EXCEPTION(TypeError) << "For primitive[" << op << "], the first input should be tensor type, but got "
121 << tensor_base->GetType()->ToString() << ".";
122 }
123 auto shape_base = tensor_base->GetShape();
124 MS_EXCEPTION_IF_NULL(shape_base);
125 MS_EXCEPTION_IF_NULL(tensor);
126 if (tensor->GetType()->object_type() != kObjectTypeTensorType) {
127 MS_EXCEPTION(TypeError) << "For primitive[" << op << "], the second input should be tensor type, but got "
128 << tensor->GetType()->ToString() << ".";
129 }
130 auto shape = tensor->GetShape();
131 MS_EXCEPTION_IF_NULL(shape);
132 if (shape_base->IsDimUnknown() || shape->IsDimUnknown()) {
133 return;
134 }
135
136 const auto &shape_vector = shape->GetShapeVector();
137 const auto &shape_base_vector = shape_base->GetShapeVector();
138 if (shape_vector.size() != shape_base_vector.size()) {
139 MS_EXCEPTION(ValueError) << "For '" << op << "', the shape of two args should be same, but the first arg shape "
140 << shape_base->ToString() << " are not consistent with second arg shape "
141 << shape->ToString();
142 }
143
144 for (size_t i = 0; i < shape_vector.size(); i++) {
145 if (shape_vector[i] == Shape::kShapeDimAny || shape_base_vector[i] == Shape::kShapeDimAny) {
146 continue;
147 }
148 if (shape_vector[i] != shape_base_vector[i]) {
149 MS_EXCEPTION(ValueError) << "For '" << op << "', the shape of two args should be same, but the first arg shape "
150 << shape_base->ToString() << " are not consistent with second arg shape "
151 << shape->ToString();
152 }
153 }
154 return;
155 }
156
CheckDtypeSame(const std::string & op,const AbstractBasePtr & tensor_base,const AbstractBasePtr & tensor)157 TypePtr CheckDtypeSame(const std::string &op, const AbstractBasePtr &tensor_base, const AbstractBasePtr &tensor) {
158 MS_EXCEPTION_IF_NULL(tensor_base);
159 TypePtr type_base = tensor_base->GetType();
160 MS_EXCEPTION_IF_NULL(tensor);
161 TypePtr type = tensor->GetType();
162 MS_EXCEPTION_IF_NULL(type_base);
163 MS_EXCEPTION_IF_NULL(type);
164 CheckDtypeSame(op, type_base, type);
165 return type_base;
166 }
167
// Older overloads: take AbstractTensorPtr inputs and query them through shape()/element().
CheckShapeSame(const std::string & op,const AbstractTensorPtr & tensor_base,const AbstractTensorPtr & tensor)169 void CheckShapeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor) {
170 MS_EXCEPTION_IF_NULL(tensor_base);
171 ShapePtr shape_base = tensor_base->shape();
172 MS_EXCEPTION_IF_NULL(shape_base);
173 MS_EXCEPTION_IF_NULL(tensor);
174 ShapePtr shape = tensor->shape();
175 MS_EXCEPTION_IF_NULL(shape);
176 if (shape_base->IsDimUnknown() || shape->IsDimUnknown()) {
177 return;
178 }
179
180 auto shape_vector = shape->shape();
181 auto shape_base_vector = shape_base->shape();
182 if (shape_vector.size() != shape_base_vector.size()) {
183 MS_EXCEPTION(ValueError) << "For '" << op << "', the shape of two args should be same, but the first arg shape "
184 << shape_base->ToString() << " are not consistent with second arg shape "
185 << shape->ToString();
186 }
187
188 for (size_t i = 0; i < shape_vector.size(); i++) {
189 if (shape_vector[i] == Shape::kShapeDimAny || shape_base_vector[i] == Shape::kShapeDimAny) {
190 continue;
191 }
192 if (shape_vector[i] != shape_base_vector[i]) {
193 MS_EXCEPTION(ValueError) << "For '" << op << "', the shape of two args should be same, but the first arg shape "
194 << shape_base->ToString() << " are not consistent with second arg shape "
195 << shape->ToString();
196 }
197 }
198 return;
199 }
200
CheckDtypeSame(const std::string & op,const AbstractTensorPtr & tensor_base,const AbstractTensorPtr & tensor)201 TypePtr CheckDtypeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor) {
202 MS_EXCEPTION_IF_NULL(tensor_base);
203 auto base_elem = tensor_base->element();
204 MS_EXCEPTION_IF_NULL(base_elem);
205 TypePtr type_base = base_elem->BuildType();
206 MS_EXCEPTION_IF_NULL(tensor);
207 auto tensor_elem = tensor->element();
208 MS_EXCEPTION_IF_NULL(tensor_elem);
209 TypePtr type = tensor_elem->BuildType();
210 MS_EXCEPTION_IF_NULL(type_base);
211 MS_EXCEPTION_IF_NULL(type);
212 CheckDtypeSame(op, type_base, type);
213 return type_base;
214 }
215
CheckAxis(const std::string & op,const std::string & args_name,const ValuePtr & axis,int64_t minimum,int64_t max,const std::string & rank_name)216 int64_t CheckAxis(const std::string &op, const std::string &args_name, const ValuePtr &axis, int64_t minimum,
217 int64_t max, const std::string &rank_name) {
218 if (axis == nullptr) {
219 MS_LOG(EXCEPTION) << op << " evaluator axis is null";
220 }
221 if (!axis->isa<Int64Imm>()) {
222 MS_LOG(EXCEPTION) << op << " evaluator axis should be int64_t, but got " << axis->type_name();
223 }
224 int64_t axis_value = GetValue<int64_t>(axis);
225 if (axis_value >= max || axis_value < minimum) {
226 MS_LOG(EXCEPTION) << "For primitive[" << op << "], " << rank_name << "'s rank is " << max << ", while the "
227 << "\'" << args_name << "\' value should be in the range [" << minimum << ", " << max
228 << "), but got " << axis_value;
229 }
230 if (axis_value < 0) {
231 axis_value = axis_value + max;
232 }
233 return axis_value;
234 }
CheckArgsSize(const std::string & op,const mindspore::abstract::AbstractBasePtrList & args_abs_list,size_t size_expect)235 void CheckArgsSize(const std::string &op, const mindspore::abstract::AbstractBasePtrList &args_abs_list,
236 size_t size_expect) {
237 if (args_abs_list.size() != size_expect) {
238 MS_LOG(EXCEPTION) << "For '" << op << "', the number of input should be " << size_expect << ", but got "
239 << args_abs_list.size();
240 }
241
242 for (size_t i = 0; i < size_expect; i++) {
243 MS_EXCEPTION_IF_NULL(args_abs_list[i]);
244 }
245 }
246
CheckShapeAllPositive(const std::string & op,const ShapeVector & shape)247 void CheckShapeAllPositive(const std::string &op, const ShapeVector &shape) {
248 for (size_t i = 0; i < shape.size(); ++i) {
249 if (shape[i] < 0) {
250 MS_LOG(EXCEPTION) << "For '" << op << "', shape element [" << i << "] must be positive integer, but got "
251 << shape[i];
252 }
253 }
254 }
255
CheckShapeAnyAndPositive(const std::string & op,const ShapeVector & shape)256 void CheckShapeAnyAndPositive(const std::string &op, const ShapeVector &shape) {
257 for (size_t i = 0; i < shape.size(); ++i) {
258 if ((shape[i] < 0) && (shape[i] != Shape::kShapeDimAny)) {
259 MS_EXCEPTION(ValueError) << op << " shape element [" << i
260 << "] must be positive integer or kShapeDimAny, but got " << shape[i];
261 }
262 }
263 }
264
CheckRequiredArgsSize(const std::string & op,const mindspore::abstract::AbstractBasePtrList & args_abs_list,size_t size_expect)265 void CheckRequiredArgsSize(const std::string &op, const mindspore::abstract::AbstractBasePtrList &args_abs_list,
266 size_t size_expect) {
267 if (args_abs_list.size() < size_expect) {
268 MS_LOG(EXCEPTION) << op << " required input args size " << size_expect << ", but got " << args_abs_list.size();
269 }
270 for (size_t i = 0; i < size_expect; i++) {
271 MS_EXCEPTION_IF_NULL(args_abs_list[i]);
272 }
273 }
274 } // namespace abstract
275 } // namespace mindspore
276