• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "abstract/param_validator.h"
18 
19 #include <algorithm>
20 #include <set>
21 #include <string>
22 #include <sstream>
23 #include <memory>
24 #include "utils/symbolic.h"
25 #include "abstract/utils.h"
26 
27 namespace mindspore {
28 namespace abstract {
29 #define ABSTRACT_REPORT_NAME_DEC(abstract) constexpr char ReportNameTraits<Abstract##abstract>::name[];
30 
31 ABSTRACT_REPORT_NAME_DEC(Tensor)
ABSTRACT_REPORT_NAME_DEC(Tuple)32 ABSTRACT_REPORT_NAME_DEC(Tuple)
33 ABSTRACT_REPORT_NAME_DEC(Scalar)
34 ABSTRACT_REPORT_NAME_DEC(List)
35 ABSTRACT_REPORT_NAME_DEC(Dictionary)
36 ABSTRACT_REPORT_NAME_DEC(Slice)
37 ABSTRACT_REPORT_NAME_DEC(Function)
38 ABSTRACT_REPORT_NAME_DEC(Type)
39 ABSTRACT_REPORT_NAME_DEC(KeywordArg)
40 ABSTRACT_REPORT_NAME_DEC(Class)
41 
42 TypePtr CheckType(TypePtr type, const TypePtrList &accepts, const std::string &error_message_prefix) {
43   bool ok = std::any_of(accepts.begin(), accepts.end(),
44                         [type](const TypePtr &accept) -> bool { return IsIdentidityOrSubclass(type, accept); });
45   if (ok) {
46     return type;
47   } else {
48     MS_LOG(EXCEPTION) << error_message_prefix << accepts << " but is " << type->ToString();
49   }
50 }
51 
CheckTensorDType(const AbstractTensorPtr & tensor,const TypePtrList & accepts,const std::string & error_message_prefix)52 TypePtr CheckTensorDType(const AbstractTensorPtr &tensor, const TypePtrList &accepts,
53                          const std::string &error_message_prefix) {
54   MS_EXCEPTION_IF_NULL(tensor);
55   TypePtr type = tensor->BuildType();
56   MS_EXCEPTION_IF_NULL(type);
57   if (!type->isa<TensorType>()) {
58     MS_LOG(EXCEPTION) << error_message_prefix << "requires Tensor but got " << type->ToString();
59   }
60   auto elem = tensor->element();
61   MS_EXCEPTION_IF_NULL(elem);
62   TypePtr ele_type = elem->BuildType();
63   if (ele_type == nullptr) {
64     MS_LOG(EXCEPTION) << "Abstract tensor element type nullptr";
65   }
66   return CheckType(ele_type, accepts, error_message_prefix);
67 }
68 
CheckTensorsDTypeSame(const AbstractTensorPtrList & tensor_list,const TypePtrList & accepts,const std::string & error_message_prefix)69 TypePtr CheckTensorsDTypeSame(const AbstractTensorPtrList &tensor_list, const TypePtrList &accepts,
70                               const std::string &error_message_prefix) {
71   if (tensor_list.empty()) {
72     MS_LOG(EXCEPTION) << "Array list is empty";
73   }
74 
75   auto sample_tensor = tensor_list[0];
76   MS_EXCEPTION_IF_NULL(sample_tensor);
77   auto sample_elem = sample_tensor->element();
78   MS_EXCEPTION_IF_NULL(sample_elem);
79   TypePtr sample_type = sample_elem->BuildType();
80   MS_EXCEPTION_IF_NULL(sample_type);
81   std::ostringstream loginfoBuffer;
82   loginfoBuffer << "same type, got";
83   // Check if other elements have the same type with the first element.
84   for (size_t index = 1; index < tensor_list.size(); ++index) {
85     MS_EXCEPTION_IF_NULL(tensor_list[index]);
86     auto elem = tensor_list[index]->element();
87     MS_EXCEPTION_IF_NULL(elem);
88     auto a_type = elem->BuildType();
89     MS_EXCEPTION_IF_NULL(a_type);
90     loginfoBuffer << " " << a_type->ToString();
91     if (sample_type->type_id() != a_type->type_id()) {
92       MS_LOG(EXCEPTION) << "Expected type " << sample_type->ToString() << ", but got " << a_type->ToString()
93                         << ", index " << index;
94     }
95   }
96   MS_LOG(DEBUG) << error_message_prefix << loginfoBuffer.str();
97   return CheckTensorDType(sample_tensor, accepts, error_message_prefix);
98 }
99 
CheckScalarType(const AbstractScalarPtr & scalar,const TypePtrList & accepts,const std::string & error_message_prefix)100 TypePtr CheckScalarType(const AbstractScalarPtr &scalar, const TypePtrList &accepts,
101                         const std::string &error_message_prefix) {
102   if (scalar == nullptr) {
103     MS_LOG(EXCEPTION) << "Scalar nullptr";
104   }
105   auto type = scalar->BuildType();
106   if (type == nullptr) {
107     MS_LOG(EXCEPTION) << "Scalar value nullptr";
108   }
109 
110   return CheckType(type, accepts, error_message_prefix);
111 }
112 
CheckShapeSame(const std::string & op,const AbstractTensorPtr & tensor_base,const AbstractTensorPtr & tensor)113 ShapePtr CheckShapeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor) {
114   MS_EXCEPTION_IF_NULL(tensor_base);
115   ShapePtr shape_base = tensor_base->shape();
116   MS_EXCEPTION_IF_NULL(shape_base);
117   MS_EXCEPTION_IF_NULL(tensor);
118   ShapePtr shape = tensor->shape();
119   MS_EXCEPTION_IF_NULL(shape);
120   if (*shape != *shape_base) {
121     MS_LOG(EXCEPTION) << op << " evaluator first arg shape " << shape->ToString()
122                       << " are not consistent with second arg shape " << shape_base->ToString();
123   }
124   return shape_base;
125 }
126 
CheckDtypeSame(const std::string & op,const AbstractTensorPtr & tensor_base,const AbstractTensorPtr & tensor)127 TypePtr CheckDtypeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor) {
128   MS_EXCEPTION_IF_NULL(tensor_base);
129   auto base_elem = tensor_base->element();
130   MS_EXCEPTION_IF_NULL(base_elem);
131   TypePtr type_base = base_elem->BuildType();
132   MS_EXCEPTION_IF_NULL(tensor);
133   auto tensor_elem = tensor->element();
134   MS_EXCEPTION_IF_NULL(tensor_elem);
135   TypePtr type = tensor_elem->BuildType();
136   MS_EXCEPTION_IF_NULL(type_base);
137   MS_EXCEPTION_IF_NULL(type);
138   if (*type != *type_base) {
139     MS_LOG(EXCEPTION) << op << " evaluator first arg dtype " << type_base->ToString()
140                       << " are not consistent with second arg dtype " << type->ToString();
141   }
142   return type_base;
143 }
144 
CheckAxis(const std::string & op,const ValuePtr & axis,int64_t minimum,int64_t max)145 int64_t CheckAxis(const std::string &op, const ValuePtr &axis, int64_t minimum, int64_t max) {
146   if (axis == nullptr) {
147     MS_LOG(EXCEPTION) << op << " evaluator axis is null";
148   }
149   if (!axis->isa<Int64Imm>()) {
150     MS_LOG(EXCEPTION) << op << " evaluator axis should be int64_t, but got " << axis->type_name();
151   }
152   int64_t axis_value = GetValue<int64_t>(axis);
153   if (axis_value > max || axis_value < minimum) {
154     MS_LOG(EXCEPTION) << op << " evaluator axis value should be in the range [" << minimum << ", " << max
155                       << "], but get " << axis_value;
156   }
157   return axis_value;
158 }
CheckArgsSize(const std::string & op,const mindspore::abstract::AbstractBasePtrList & args_spec_list,size_t size_expect)159 void CheckArgsSize(const std::string &op, const mindspore::abstract::AbstractBasePtrList &args_spec_list,
160                    size_t size_expect) {
161   if (args_spec_list.size() != size_expect) {
162     MS_LOG(EXCEPTION) << op << " input args size should be " << size_expect << ", but got " << args_spec_list.size();
163   }
164 
165   for (size_t i = 0; i < size_expect; i++) {
166     MS_EXCEPTION_IF_NULL(args_spec_list[i]);
167   }
168 }
169 
CheckShapeAllPositive(const std::string & op,const ShapeVector & shape)170 void CheckShapeAllPositive(const std::string &op, const ShapeVector &shape) {
171   for (size_t i = 0; i < shape.size(); ++i) {
172     if (shape[i] < 0) {
173       MS_LOG(EXCEPTION) << op << " shape element [" << i << "] must be positive integer, but got " << shape[i];
174     }
175   }
176 }
177 
CheckShapeAnyAndPositive(const std::string & op,const ShapeVector & shape)178 void CheckShapeAnyAndPositive(const std::string &op, const ShapeVector &shape) {
179   for (size_t i = 0; i < shape.size(); ++i) {
180     if ((shape[i] < 0) && (shape[i] != Shape::SHP_ANY)) {
181       MS_EXCEPTION(ValueError) << op << " shape element [" << i << "] must be positive integer or SHP_ANY, but got "
182                                << shape[i];
183     }
184   }
185 }
186 
CheckAttrPositiveInt64(const std::string & op,const ValuePtr & attr,const std::string & attr_name)187 int64_t CheckAttrPositiveInt64(const std::string &op, const ValuePtr &attr, const std::string &attr_name) {
188   MS_EXCEPTION_IF_NULL(attr);
189   auto int64_value = attr->cast<Int64ImmPtr>();
190   MS_EXCEPTION_IF_NULL(int64_value);
191   int64_t attr_val = int64_value->value();
192   if (attr_val <= 0) {
193     MS_LOG(EXCEPTION) << op << " invalid " << attr_name << " value: " << attr_val << ", should be greater then 0";
194   }
195   return attr_val;
196 }
197 
CheckAttrIntOrTuple(const std::string & op,const ValuePtr & attr,const size_t start_idx,const size_t num_element)198 std::vector<int64_t> CheckAttrIntOrTuple(const std::string &op, const ValuePtr &attr, const size_t start_idx,
199                                          const size_t num_element) {
200   std::vector<int64_t> result;
201   MS_EXCEPTION_IF_NULL(attr);
202   if (attr->isa<ValueTuple>()) {
203     auto tuple_attr = attr->cast<ValueTuplePtr>();
204     MS_EXCEPTION_IF_NULL(tuple_attr);
205     std::vector<ValuePtr> attr_vec = tuple_attr->value();
206     if (start_idx > attr_vec.size() || start_idx + num_element > attr_vec.size()) {
207       MS_EXCEPTION(IndexError) << op << " attr index is out of range, attr size is " << attr_vec.size()
208                                << "but start idx got" << start_idx << " num element " << num_element;
209     }
210     auto it_start = attr_vec.begin() + start_idx;
211     (void)std::transform(it_start, it_start + num_element, std::back_inserter(result),
212                          [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
213   } else {
214     auto int64_imm = attr->cast<Int64ImmPtr>();
215     MS_EXCEPTION_IF_NULL(int64_imm);
216     int64_t attr_val = int64_imm->value();
217     (void)result.insert(result.begin(), num_element, attr_val);
218   }
219   return result;
220 }
221 
CheckAttrStringSet(const std::string & op,const ValuePtr & attr,const std::string & attr_name,const std::set<std::string> & val_set)222 std::string CheckAttrStringSet(const std::string &op, const ValuePtr &attr, const std::string &attr_name,
223                                const std::set<std::string> &val_set) {
224   MS_EXCEPTION_IF_NULL(attr);
225   auto string_attr = attr->cast<StringImmPtr>();
226   MS_EXCEPTION_IF_NULL(string_attr);
227   std::string attr_val = string_attr->value();
228   if (val_set.find(attr_val) == val_set.end()) {
229     std::ostringstream buffer;
230     bool f_begin = true;
231     buffer << "{";
232     for (auto &x : val_set) {
233       if (!f_begin) {
234         buffer << ", ";
235       } else {
236         f_begin = false;
237       }
238       buffer << x;
239     }
240     buffer << "}";
241     MS_LOG(EXCEPTION) << op << "Unsupported " << attr_name << ": " << attr_val << ". use " << buffer.str();
242   }
243   return attr_val;
244 }
245 
CheckRequiredArgsSize(const std::string & op,const mindspore::abstract::AbstractBasePtrList & args_spec_list,size_t size_expect)246 void CheckRequiredArgsSize(const std::string &op, const mindspore::abstract::AbstractBasePtrList &args_spec_list,
247                            size_t size_expect) {
248   if (args_spec_list.size() < size_expect) {
249     MS_LOG(EXCEPTION) << op << " required input args size " << size_expect << ", but got " << args_spec_list.size();
250   }
251   for (size_t i = 0; i < size_expect; i++) {
252     MS_EXCEPTION_IF_NULL(args_spec_list[i]);
253   }
254 }
255 
256 }  // namespace abstract
257 }  // namespace mindspore
258