• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "abstract/param_validator.h"
18 
19 #include <algorithm>
20 #include <string>
21 #include <sstream>
22 #include <memory>
23 #include "abstract/dshape.h"
24 #include "ir/dtype.h"
25 #include "ir/dtype/tensor_type.h"
26 #include "ir/scalar.h"
27 namespace mindspore {
28 namespace abstract {
29 #define ABSTRACT_REPORT_NAME_DEC(abstract) constexpr char ReportNameTraits<Abstract##abstract>::name[];
30 
31 ABSTRACT_REPORT_NAME_DEC(Tensor)
ABSTRACT_REPORT_NAME_DEC(Tuple)32 ABSTRACT_REPORT_NAME_DEC(Tuple)
33 ABSTRACT_REPORT_NAME_DEC(Scalar)
34 ABSTRACT_REPORT_NAME_DEC(List)
35 ABSTRACT_REPORT_NAME_DEC(Dictionary)
36 ABSTRACT_REPORT_NAME_DEC(Slice)
37 ABSTRACT_REPORT_NAME_DEC(Function)
38 ABSTRACT_REPORT_NAME_DEC(Type)
39 ABSTRACT_REPORT_NAME_DEC(KeywordArg)
40 
41 TypePtr CheckType(TypePtr type, const TypePtrList &accepts, const std::string &error_message_prefix) {
42   auto ori_type = type;
43   if (type->isa<TensorType>()) {
44     auto tensor = type->cast_ptr<TensorType>();
45     type = tensor->element();
46     MS_EXCEPTION_IF_NULL(type);
47   }
48   bool ok = std::any_of(accepts.begin(), accepts.end(),
49                         [type](const TypePtr &accept) -> bool { return IsIdentidityOrSubclass(type, accept); });
50   if (ok) {
51     return type;
52   } else {
53     MS_EXCEPTION(TypeError) << error_message_prefix << " should be Tensor" << accepts << ",but got "
54                             << ori_type->ToString();
55   }
56 }
57 
CheckTensorDType(const AbstractBasePtr & tensor,const TypePtrList & accepts,const std::string & error_message_prefix)58 TypePtr CheckTensorDType(const AbstractBasePtr &tensor, const TypePtrList &accepts,
59                          const std::string &error_message_prefix) {
60   MS_EXCEPTION_IF_NULL(tensor);
61   TypePtr type = tensor->GetType();
62   MS_EXCEPTION_IF_NULL(type);
63   if (!type->isa<TensorType>()) {
64     MS_LOG(EXCEPTION) << error_message_prefix << "requires Tensor but got " << type->ToString();
65   }
66   return CheckType(type, accepts, error_message_prefix);
67 }
68 
CheckTensorsDTypeSame(const AbstractTensorPtrList & tensor_list,const TypePtrList & accepts,const std::string & error_message_prefix)69 TypePtr CheckTensorsDTypeSame(const AbstractTensorPtrList &tensor_list, const TypePtrList &accepts,
70                               const std::string &error_message_prefix) {
71   if (tensor_list.empty()) {
72     MS_LOG(EXCEPTION) << "Array list is empty";
73   }
74 
75   auto sample_tensor = tensor_list[0];
76   MS_EXCEPTION_IF_NULL(sample_tensor);
77   auto sample_elem = sample_tensor->element();
78   MS_EXCEPTION_IF_NULL(sample_elem);
79   TypePtr sample_type = sample_elem->BuildType();
80   MS_EXCEPTION_IF_NULL(sample_type);
81   std::ostringstream loginfoBuffer;
82   loginfoBuffer << "[" << sample_tensor->BuildType()->ToString();
83   bool error_flag = false;
84   // Check if other elements have the same type with the first element.
85   for (size_t index = 1; index < tensor_list.size(); ++index) {
86     MS_EXCEPTION_IF_NULL(tensor_list[index]);
87     auto elem = tensor_list[index]->element();
88     MS_EXCEPTION_IF_NULL(elem);
89     auto a_type = elem->BuildType();
90     MS_EXCEPTION_IF_NULL(a_type);
91     loginfoBuffer << "," << tensor_list[index]->BuildType()->ToString();
92     if (sample_type->type_id() != a_type->type_id()) {
93       error_flag = true;
94     }
95   }
96   if (error_flag) {
97     MS_EXCEPTION(ValueError) << error_message_prefix << " must be same, but got " << loginfoBuffer.str() << "]";
98   }
99   MS_LOG(DEBUG) << error_message_prefix << loginfoBuffer.str();
100   return CheckTensorDType(sample_tensor, accepts, error_message_prefix);
101 }
102 
CheckScalarType(const AbstractScalarPtr & scalar,const TypePtrList & accepts,const std::string & error_message_prefix)103 TypePtr CheckScalarType(const AbstractScalarPtr &scalar, const TypePtrList &accepts,
104                         const std::string &error_message_prefix) {
105   if (scalar == nullptr) {
106     MS_LOG(INTERNAL_EXCEPTION) << "Scalar nullptr";
107   }
108   auto type = scalar->BuildType();
109   if (type == nullptr) {
110     MS_LOG(INTERNAL_EXCEPTION) << "Scalar value nullptr";
111   }
112 
113   return CheckType(type, accepts, error_message_prefix);
114 }
115 
116 // new function
CheckShapeSame(const std::string & op,const AbstractBasePtr & tensor_base,const AbstractBasePtr & tensor)117 void CheckShapeSame(const std::string &op, const AbstractBasePtr &tensor_base, const AbstractBasePtr &tensor) {
118   MS_EXCEPTION_IF_NULL(tensor_base);
119   if (tensor_base->GetType()->object_type() != kObjectTypeTensorType) {
120     MS_EXCEPTION(TypeError) << "For primitive[" << op << "], the first input should be tensor type, but got "
121                             << tensor_base->GetType()->ToString() << ".";
122   }
123   auto shape_base = tensor_base->GetShape();
124   MS_EXCEPTION_IF_NULL(shape_base);
125   MS_EXCEPTION_IF_NULL(tensor);
126   if (tensor->GetType()->object_type() != kObjectTypeTensorType) {
127     MS_EXCEPTION(TypeError) << "For primitive[" << op << "], the second input should be tensor type, but got "
128                             << tensor->GetType()->ToString() << ".";
129   }
130   auto shape = tensor->GetShape();
131   MS_EXCEPTION_IF_NULL(shape);
132   if (shape_base->IsDimUnknown() || shape->IsDimUnknown()) {
133     return;
134   }
135 
136   const auto &shape_vector = shape->GetShapeVector();
137   const auto &shape_base_vector = shape_base->GetShapeVector();
138   if (shape_vector.size() != shape_base_vector.size()) {
139     MS_EXCEPTION(ValueError) << "For '" << op << "', the shape of two args should be same, but the first arg shape "
140                              << shape_base->ToString() << " are not consistent with second arg shape "
141                              << shape->ToString();
142   }
143 
144   for (size_t i = 0; i < shape_vector.size(); i++) {
145     if (shape_vector[i] == Shape::kShapeDimAny || shape_base_vector[i] == Shape::kShapeDimAny) {
146       continue;
147     }
148     if (shape_vector[i] != shape_base_vector[i]) {
149       MS_EXCEPTION(ValueError) << "For '" << op << "',  the shape of two args should be same, but the first arg shape "
150                                << shape_base->ToString() << " are not consistent with second arg shape "
151                                << shape->ToString();
152     }
153   }
154   return;
155 }
156 
CheckDtypeSame(const std::string & op,const AbstractBasePtr & tensor_base,const AbstractBasePtr & tensor)157 TypePtr CheckDtypeSame(const std::string &op, const AbstractBasePtr &tensor_base, const AbstractBasePtr &tensor) {
158   MS_EXCEPTION_IF_NULL(tensor_base);
159   TypePtr type_base = tensor_base->GetType();
160   MS_EXCEPTION_IF_NULL(tensor);
161   TypePtr type = tensor->GetType();
162   MS_EXCEPTION_IF_NULL(type_base);
163   MS_EXCEPTION_IF_NULL(type);
164   CheckDtypeSame(op, type_base, type);
165   return type_base;
166 }
167 
168 // old function
CheckShapeSame(const std::string & op,const AbstractTensorPtr & tensor_base,const AbstractTensorPtr & tensor)169 void CheckShapeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor) {
170   MS_EXCEPTION_IF_NULL(tensor_base);
171   ShapePtr shape_base = tensor_base->shape();
172   MS_EXCEPTION_IF_NULL(shape_base);
173   MS_EXCEPTION_IF_NULL(tensor);
174   ShapePtr shape = tensor->shape();
175   MS_EXCEPTION_IF_NULL(shape);
176   if (shape_base->IsDimUnknown() || shape->IsDimUnknown()) {
177     return;
178   }
179 
180   auto shape_vector = shape->shape();
181   auto shape_base_vector = shape_base->shape();
182   if (shape_vector.size() != shape_base_vector.size()) {
183     MS_EXCEPTION(ValueError) << "For '" << op << "', the shape of two args should be same, but the first arg shape "
184                              << shape_base->ToString() << " are not consistent with second arg shape "
185                              << shape->ToString();
186   }
187 
188   for (size_t i = 0; i < shape_vector.size(); i++) {
189     if (shape_vector[i] == Shape::kShapeDimAny || shape_base_vector[i] == Shape::kShapeDimAny) {
190       continue;
191     }
192     if (shape_vector[i] != shape_base_vector[i]) {
193       MS_EXCEPTION(ValueError) << "For '" << op << "',  the shape of two args should be same, but the first arg shape "
194                                << shape_base->ToString() << " are not consistent with second arg shape "
195                                << shape->ToString();
196     }
197   }
198   return;
199 }
200 
CheckDtypeSame(const std::string & op,const AbstractTensorPtr & tensor_base,const AbstractTensorPtr & tensor)201 TypePtr CheckDtypeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor) {
202   MS_EXCEPTION_IF_NULL(tensor_base);
203   auto base_elem = tensor_base->element();
204   MS_EXCEPTION_IF_NULL(base_elem);
205   TypePtr type_base = base_elem->BuildType();
206   MS_EXCEPTION_IF_NULL(tensor);
207   auto tensor_elem = tensor->element();
208   MS_EXCEPTION_IF_NULL(tensor_elem);
209   TypePtr type = tensor_elem->BuildType();
210   MS_EXCEPTION_IF_NULL(type_base);
211   MS_EXCEPTION_IF_NULL(type);
212   CheckDtypeSame(op, type_base, type);
213   return type_base;
214 }
215 
CheckAxis(const std::string & op,const std::string & args_name,const ValuePtr & axis,int64_t minimum,int64_t max,const std::string & rank_name)216 int64_t CheckAxis(const std::string &op, const std::string &args_name, const ValuePtr &axis, int64_t minimum,
217                   int64_t max, const std::string &rank_name) {
218   if (axis == nullptr) {
219     MS_LOG(EXCEPTION) << op << " evaluator axis is null";
220   }
221   if (!axis->isa<Int64Imm>()) {
222     MS_LOG(EXCEPTION) << op << " evaluator axis should be int64_t, but got " << axis->type_name();
223   }
224   int64_t axis_value = GetValue<int64_t>(axis);
225   if (axis_value >= max || axis_value < minimum) {
226     MS_LOG(EXCEPTION) << "For primitive[" << op << "], " << rank_name << "'s rank is " << max << ", while the "
227                       << "\'" << args_name << "\' value should be in the range [" << minimum << ", " << max
228                       << "), but got " << axis_value;
229   }
230   if (axis_value < 0) {
231     axis_value = axis_value + max;
232   }
233   return axis_value;
234 }
CheckArgsSize(const std::string & op,const mindspore::abstract::AbstractBasePtrList & args_abs_list,size_t size_expect)235 void CheckArgsSize(const std::string &op, const mindspore::abstract::AbstractBasePtrList &args_abs_list,
236                    size_t size_expect) {
237   if (args_abs_list.size() != size_expect) {
238     MS_LOG(EXCEPTION) << "For '" << op << "', the number of input should be " << size_expect << ", but got "
239                       << args_abs_list.size();
240   }
241 
242   for (size_t i = 0; i < size_expect; i++) {
243     MS_EXCEPTION_IF_NULL(args_abs_list[i]);
244   }
245 }
246 
CheckShapeAllPositive(const std::string & op,const ShapeVector & shape)247 void CheckShapeAllPositive(const std::string &op, const ShapeVector &shape) {
248   for (size_t i = 0; i < shape.size(); ++i) {
249     if (shape[i] < 0) {
250       MS_LOG(EXCEPTION) << "For '" << op << "', shape element [" << i << "] must be positive integer, but got "
251                         << shape[i];
252     }
253   }
254 }
255 
CheckShapeAnyAndPositive(const std::string & op,const ShapeVector & shape)256 void CheckShapeAnyAndPositive(const std::string &op, const ShapeVector &shape) {
257   for (size_t i = 0; i < shape.size(); ++i) {
258     if ((shape[i] < 0) && (shape[i] != Shape::kShapeDimAny)) {
259       MS_EXCEPTION(ValueError) << op << " shape element [" << i
260                                << "] must be positive integer or kShapeDimAny, but got " << shape[i];
261     }
262   }
263 }
264 
CheckRequiredArgsSize(const std::string & op,const mindspore::abstract::AbstractBasePtrList & args_abs_list,size_t size_expect)265 void CheckRequiredArgsSize(const std::string &op, const mindspore::abstract::AbstractBasePtrList &args_abs_list,
266                            size_t size_expect) {
267   if (args_abs_list.size() < size_expect) {
268     MS_LOG(EXCEPTION) << op << " required input args size " << size_expect << ", but got " << args_abs_list.size();
269   }
270   for (size_t i = 0; i < size_expect; i++) {
271     MS_EXCEPTION_IF_NULL(args_abs_list[i]);
272   }
273 }
274 }  // namespace abstract
275 }  // namespace mindspore
276