1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
#include "abstract/utils.h"

#include <algorithm>
#include <functional>
#include <map>
#include <memory>
#include <numeric>
#include <sstream>
#include <string>
#include <vector>

#include "utils/ms_context.h"
#include "utils/symbolic.h"
#include "abstract/param_validator.h"
27
28 namespace mindspore {
29 namespace abstract {
30 const std::map<TypeId, size_t> type_map = {
31 {kNumberTypeBool, 1}, {kNumberTypeInt, 4}, {kNumberTypeInt8, 1}, {kNumberTypeInt16, 2},
32 {kNumberTypeInt32, 4}, {kNumberTypeInt64, 8}, {kNumberTypeUInt, 4}, {kNumberTypeUInt8, 1},
33 {kNumberTypeUInt16, 2}, {kNumberTypeUInt32, 4}, {kNumberTypeUInt64, 8}, {kNumberTypeFloat, 4},
34 {kNumberTypeFloat16, 2}, {kNumberTypeFloat32, 4}, {kNumberTypeFloat64, 8}, {kNumberTypeComplex64, 8},
35 {kNumberTypeComplex128, 16}};
36
// Join two values: if they compare equal (Value::operator==) the result is value1,
// otherwise the join degrades to kAnyValue. Throws if either pointer is null.
ValuePtr ValueJoin(const ValuePtr &value1, const ValuePtr &value2) {
  MS_EXCEPTION_IF_NULL(value1);
  MS_EXCEPTION_IF_NULL(value2);
  const bool same_value = (*value1 == *value2);
  if (!same_value) {
    return kAnyValue;
  }
  return value1;
}
45
// Join two types: if they compare equal (Type::operator==) the result is type1,
// otherwise the join degrades to kAnyType. Throws if either pointer is null.
TypePtr TypeJoin(const TypePtr &type1, const TypePtr &type2) {
  MS_EXCEPTION_IF_NULL(type1);
  MS_EXCEPTION_IF_NULL(type2);
  const bool same_type = (*type1 == *type2);
  if (!same_type) {
    return kAnyType;
  }
  return type1;
}
54
// Build the joined dynamic Shape of shape1 and shape2 from the already-joined dims `dims`.
//
// For each axis i of `dims`:
//   * if dims[i] is static (not Shape::SHP_ANY), it is its own min and max;
//   * otherwise the axis range is the union of both inputs' ranges, where a static input dim d
//     contributes [d, d] and a dynamic input dim contributes [min_shape()[i], max_shape()[i]].
//
// Throws if either pointer is null, or if either input has fewer axes than `dims`.
// Raises ValueError naming the offending shape when a dynamic input axis has no min/max shape info.
ShapePtr CalculateDynamicShape(const ShapePtr &shape1, const ShapePtr &shape2, const ShapeVector &dims) {
  MS_EXCEPTION_IF_NULL(shape1);
  MS_EXCEPTION_IF_NULL(shape2);
  const ShapeVector &dims1 = shape1->shape();
  const ShapeVector &dims2 = shape2->shape();
  // Guard the per-axis indexing below against out-of-bounds access.
  if (dims1.size() < dims.size() || dims2.size() < dims.size()) {
    MS_LOG(EXCEPTION) << "Cannot join dynamic shapes " << shape1->ToString() << " and " << shape2->ToString()
                      << ": each is expected to have at least " << dims.size() << " axes.";
  }
  // Raise when `shape` is dynamic at `axis` but carries no min/max bound for that axis.
  auto check_min_max_info = [](const ShapePtr &shape, size_t axis) {
    if (shape->min_shape().size() <= axis || shape->max_shape().size() <= axis) {
      MS_EXCEPTION(ValueError) << "Shape " << shape->ToString()
                               << " has dynamic shape, but does not have min/max shape info.";
    }
  };

  ShapeVector min_dims(dims.size());
  ShapeVector max_dims(dims.size());
  for (size_t i = 0; i < dims.size(); ++i) {
    if (dims[i] != Shape::SHP_ANY) {
      min_dims[i] = max_dims[i] = dims[i];
      continue;
    }
    // Range of shape1 on this axis: a static dim bounds itself, a dynamic dim uses its min/max info.
    auto min1 = dims1[i];
    auto max1 = dims1[i];
    if (dims1[i] == Shape::SHP_ANY) {
      check_min_max_info(shape1, i);
      min1 = shape1->min_shape()[i];
      max1 = shape1->max_shape()[i];
    }
    // Same for shape2; the error names shape2 itself when its bounds are missing.
    auto min2 = dims2[i];
    auto max2 = dims2[i];
    if (dims2[i] == Shape::SHP_ANY) {
      check_min_max_info(shape2, i);
      min2 = shape2->min_shape()[i];
      max2 = shape2->max_shape()[i];
    }
    min_dims[i] = std::min(min1, min2);
    max_dims[i] = std::max(max1, max2);
  }
  return std::make_shared<Shape>(dims, min_dims, max_dims);
}
103
// Join two shapes into a shape compatible with both.
//   * Equal shapes (Shape::operator==) join to shape1 itself.
//   * Different ranks fail and return nullptr, except that (1) joined with () yields the rank-1 operand.
//   * Same rank: agreeing axes keep their value and disagreeing axes become Shape::SHP_ANY. If any result
//     axis is SHP_ANY, min/max bounds come from CalculateDynamicShape; otherwise a static Shape is returned.
// Throws if either pointer is null.
ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2) {
  MS_EXCEPTION_IF_NULL(shape1);
  MS_EXCEPTION_IF_NULL(shape2);
  if (*shape1 == *shape2) {
    return shape1;
  }
  const ShapeVector &lhs = shape1->shape();
  const ShapeVector &rhs = shape2->shape();
  const size_t rank = lhs.size();
  if (rank != rhs.size()) {
    // Only the shape(1) / shape() pairing is joinable across ranks.
    auto is_single_one = [](const ShapeVector &v) { return v.size() == 1 && v[0] == 1; };
    if (is_single_one(lhs) && rhs.empty()) {
      return shape1;
    }
    if (is_single_one(rhs) && lhs.empty()) {
      return shape2;
    }
    return nullptr;
  }

  ShapeVector dims(rank);
  bool has_dynamic_shape = false;
  for (size_t axis = 0; axis < rank; ++axis) {
    if (lhs[axis] != rhs[axis]) {
      dims[axis] = Shape::SHP_ANY;
    } else {
      dims[axis] = lhs[axis];
    }
    has_dynamic_shape = has_dynamic_shape || dims[axis] == Shape::SHP_ANY;
  }
  if (has_dynamic_shape) {
    return CalculateDynamicShape(shape1, shape2, dims);
  }
  return std::make_shared<Shape>(dims);
}
140
AbstractJoin(const AbstractBasePtrList & args_spec_list)141 AbstractBasePtr AbstractJoin(const AbstractBasePtrList &args_spec_list) {
142 if (args_spec_list.empty()) {
143 MS_LOG(EXCEPTION) << "AbstractJoin requires at least 1 params, while the input size is " << args_spec_list.size()
144 << ".";
145 }
146 AbstractBasePtr arg_spec_tmp = args_spec_list[0];
147 MS_EXCEPTION_IF_NULL(arg_spec_tmp);
148 for (const auto &arg_spec : args_spec_list) {
149 MS_EXCEPTION_IF_NULL(arg_spec);
150 arg_spec_tmp = arg_spec_tmp->Join(arg_spec);
151 MS_EXCEPTION_IF_NULL(arg_spec_tmp);
152 }
153 return arg_spec_tmp;
154 }
155
// Join two equally-sized abstract lists element-wise (spec1[i]->Join(spec2[i])).
// If every joined element is pointer-identical to the corresponding spec1 element, spec1 is
// returned unchanged; otherwise the freshly joined list is returned.
// Throws if the sizes differ, or if a spec1 element or a join result is null.
AbstractBasePtrList AbstractJoin(const AbstractBasePtrList &spec1, const AbstractBasePtrList &spec2) {
  if (spec1.size() != spec2.size()) {
    MS_LOG(EXCEPTION) << "Join failed as list don't have the same size. spec1: " << ::mindspore::ToString(spec1)
                      << ", spec2: " << ::mindspore::ToString(spec2);
  }
  AbstractBasePtrList joined_list;
  joined_list.reserve(spec1.size());
  bool any_changed = false;
  for (std::size_t i = 0; i < spec1.size(); i++) {
    MS_EXCEPTION_IF_NULL(spec1[i]);
    auto joined_elem = spec1[i]->Join(spec2[i]);
    MS_EXCEPTION_IF_NULL(joined_elem);
    any_changed = any_changed || (joined_elem != spec1[i]);
    joined_list.push_back(joined_elem);
  }
  if (any_changed) {
    return joined_list;
  }
  return spec1;
}
177
// Produce the abstract used for the sensitivity of `spec`:
//   * for an AbstractFunction, a fresh AbstractScalar holding kAnyValue with an EnvType type;
//   * for anything else, a clone of `spec`.
// Throws if `spec` is null. Previously a null `spec` fell through to spec->Clone() and dereferenced null.
AbstractBasePtr SensitivityTransform(const AbstractBasePtr &spec) {
  MS_EXCEPTION_IF_NULL(spec);
  if (spec->isa<AbstractFunction>()) {
    return std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
  }
  return spec->Clone();
}
185
186 namespace {
187 // Join all types in args_type_list;
TypeJoin(const TypePtrList & args_type_list)188 TypePtr TypeJoin(const TypePtrList &args_type_list) {
189 if (args_type_list.empty()) {
190 MS_LOG(EXCEPTION) << "args_type_list is empty";
191 }
192
193 TypePtr type_tmp = args_type_list[0];
194 for (std::size_t i = 1; i < args_type_list.size(); i++) {
195 type_tmp = abstract::TypeJoin(type_tmp, args_type_list[i]);
196 }
197 return type_tmp;
198 }
199 } // namespace
200
CheckType(const TypePtr & expected_type,const TypePtr & x)201 bool CheckType(const TypePtr &expected_type, const TypePtr &x) {
202 // As x and predicate both are mindspore type statically, here we only to judge whether
203 // x is predicate or is a subclass of predicate.
204 return IsIdentidityOrSubclass(x, expected_type);
205 }
206
// Verify that every type in args_type_list matches `predicate` (see CheckType), then
// return the join of all of them. Throws on a null predicate or element, on the first
// mismatching type, or (via the list TypeJoin) when the list is empty.
TypePtr CheckTypeList(const TypePtr &predicate, const TypePtrList &args_type_list) {
  MS_EXCEPTION_IF_NULL(predicate);
  for (size_t idx = 0; idx < args_type_list.size(); ++idx) {
    const TypePtr &arg_type = args_type_list[idx];
    MS_EXCEPTION_IF_NULL(arg_type);
    if (CheckType(predicate, arg_type)) {
      continue;
    }
    MS_LOG(EXCEPTION) << "The expected is " << predicate->ToString() << ", not " << arg_type->ToString();
  }
  return TypeJoin(args_type_list);
}
217
GetPositiveAxis(int64_t axis_value,size_t increment)218 int64_t GetPositiveAxis(int64_t axis_value, size_t increment) {
219 if (axis_value < 0) {
220 axis_value = axis_value + SizeToLong(increment);
221 }
222
223 if (axis_value < 0) {
224 MS_LOG(EXCEPTION) << "axis_value should not still <0";
225 }
226
227 return axis_value;
228 }
229
230 // Return if two shapes can be broadcast.
231 // Broadcast shape is placed in broadcast_output_shape.
RealBroadcast(const std::string & op,ShapeVector x_shape,ShapeVector y_shape)232 ShapeVector RealBroadcast(const std::string &op, ShapeVector x_shape, ShapeVector y_shape) {
233 std::reverse(x_shape.begin(), x_shape.end());
234 std::reverse(y_shape.begin(), y_shape.end());
235 // Fill a placeholder value 1 which will be replaced later.
236 size_t std_len = x_shape.size() > y_shape.size() ? x_shape.size() : y_shape.size();
237 y_shape.resize(std_len, 1);
238 x_shape.resize(std_len, 1);
239
240 ShapeVector broadcast_shape;
241 for (size_t i = 0; i < std_len; i++) {
242 int64_t x_i = x_shape[i]; // i-th dimension of x
243 int64_t y_i = y_shape[i]; // i-th dimension of y
244 int64_t output_i = 0; // i-th dimension of the output
245 if (x_i == y_i) {
246 output_i = x_i;
247 } else if (x_i == 1) {
248 output_i = y_i;
249 } else if (y_i == 1) {
250 output_i = x_i;
251 } else {
252 MS_LOG(EXCEPTION)
253 << op
254 << " evaluator the shape of first tensor and the shape of second tensor do not meet the broadcasting "
255 "requirements";
256 }
257 broadcast_shape.push_back(output_i);
258 }
259 std::reverse(broadcast_shape.begin(), broadcast_shape.end());
260 return broadcast_shape;
261 }
262
BroadcastShape(ShapeVector shpx,ShapeVector shpy)263 ShapeVector BroadcastShape(ShapeVector shpx, ShapeVector shpy) {
264 int dlen = SizeToInt(shpx.size()) - SizeToInt(shpy.size());
265 if (dlen < 0) {
266 for (int i = 0; i < -dlen; ++i) {
267 (void)shpx.insert(shpx.begin(), 1);
268 }
269 } else if (dlen > 0) {
270 for (int i = 0; i < dlen; i++) {
271 (void)shpy.insert(shpy.begin(), 1);
272 }
273 }
274 if (shpx.size() != shpy.size()) {
275 MS_LOG(EXCEPTION) << "Failure: shpx.size() != shpy.size().";
276 }
277 ShapeVector shp;
278 for (size_t i = 0; i < shpx.size(); i++) {
279 auto a = shpx[i];
280 auto b = shpy[i];
281 if (a == 1) {
282 shp.push_back(b);
283 } else if (b == 1) {
284 shp.push_back(a);
285 } else if (a == -1) {
286 shp.push_back(b);
287 } else if (b == -1) {
288 shp.push_back(a);
289 } else if (a == b) {
290 shp.push_back(a);
291 } else {
292 return ShapeVector();
293 }
294 }
295 return shp;
296 }
297
TypeIdSize(const TypeId data_type)298 size_t TypeIdSize(const TypeId data_type) {
299 const size_t unsupported_type_error = 0;
300 auto iter = type_map.find(data_type);
301 if (iter != type_map.end()) {
302 return iter->second;
303 }
304 return unsupported_type_error;
305 }
306
// Number of elements described by `shape`: the product of its dims (1 for an empty shape).
// No overflow check is performed; the product wraps as size_t arithmetic does.
size_t ShapeSize(const std::vector<size_t> &shape) {
  size_t element_count = 1;
  for (const size_t dim : shape) {
    element_count *= dim;
  }
  return element_count;
}
310
CheckMinMaxShape(const ShapeVector & shape,ShapeVector * min_shape,ShapeVector * max_shape)311 void CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVector *max_shape) {
312 *min_shape = (*min_shape).empty() ? shape : *min_shape;
313 *max_shape = (*max_shape).empty() ? shape : *max_shape;
314 }
315
GetUnsortedSegmentOpScalarArg(const AbstractBasePtrList & args_spec_list,const std::string & op_name)316 int64_t GetUnsortedSegmentOpScalarArg(const AbstractBasePtrList &args_spec_list, const std::string &op_name) {
317 int64_t num_segments_value = 0;
318 constexpr size_t scalar_index = 2;
319 if (args_spec_list[scalar_index]->isa<AbstractTensor>()) { // num_segments is Tensor
320 auto num_segments = args_spec_list[scalar_index]->cast<AbstractTensorPtr>();
321 MS_EXCEPTION_IF_NULL(num_segments);
322 auto num_segments_value_ptr = num_segments->BuildValue();
323 MS_EXCEPTION_IF_NULL(num_segments_value_ptr);
324 auto num_segments_tensor = num_segments_value_ptr->cast<tensor::TensorPtr>();
325 MS_EXCEPTION_IF_NULL(num_segments_tensor);
326 if (num_segments->element()->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) {
327 num_segments_value = *static_cast<int64_t *>(num_segments_tensor->data_c());
328 } else {
329 num_segments_value = *static_cast<int32_t *>(num_segments_tensor->data_c());
330 }
331 } else if (args_spec_list[scalar_index]->isa<AbstractScalar>()) { // num_segments is Scalar
332 auto num_segments = CheckArg<AbstractScalar>(op_name, args_spec_list, scalar_index);
333 if (num_segments->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) {
334 num_segments_value = GetValue<int64_t>(num_segments->BuildValue());
335 } else {
336 num_segments_value = GetValue<int32_t>(num_segments->BuildValue());
337 }
338 } else {
339 MS_LOG(EXCEPTION) << "num_segments incorrect type in " << op_name;
340 }
341 return num_segments_value;
342 }
343
MakeAbstractTensor(const ShapePtr & shape,const TypePtr & type)344 AbstractBasePtr MakeAbstractTensor(const ShapePtr &shape, const TypePtr &type) {
345 MS_EXCEPTION_IF_NULL(shape);
346 MS_EXCEPTION_IF_NULL(type);
347 AbstractBasePtr tensor = nullptr;
348 auto ret_vec = shape->shape();
349 ShapeVector min_shape_vec;
350 ShapeVector max_shape_vec;
351
352 if (!shape->min_shape().empty()) {
353 min_shape_vec = shape->min_shape();
354 }
355 if (!shape->max_shape().empty()) {
356 max_shape_vec = shape->max_shape();
357 }
358
359 auto ret_shape = std::make_shared<abstract::Shape>(ret_vec, min_shape_vec, max_shape_vec);
360 if (type->isa<TensorType>()) {
361 auto tensor_type = type->cast<TensorTypePtr>();
362 MS_EXCEPTION_IF_NULL(tensor_type);
363 auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, tensor_type->element());
364 tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);
365 } else {
366 auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, type);
367 tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);
368 }
369 return tensor;
370 }
371
MakeMonadAbstract(const MonadTypePtr & type)372 AbstractBasePtr MakeMonadAbstract(const MonadTypePtr &type) {
373 if (type->isa<UMonadType>()) {
374 return kUMonad->ToAbstract();
375 } else if (type->isa<IOMonadType>()) {
376 return kIOMonad->ToAbstract();
377 }
378 MS_EXCEPTION(UnknownError) << "Unsupported to convert type " << type->ToString() << " to monad abstract";
379 }
380
MakeAbstract(const BaseShapePtr & base_shape,const TypePtr & type)381 AbstractBasePtr MakeAbstract(const BaseShapePtr &base_shape, const TypePtr &type) {
382 MS_EXCEPTION_IF_NULL(base_shape);
383 MS_EXCEPTION_IF_NULL(type);
384 if ((base_shape->isa<Shape>())) {
385 auto shape = base_shape->cast<ShapePtr>();
386 MS_EXCEPTION_IF_NULL(shape);
387 auto shape_vec = shape->shape();
388 // if the size of shape list is empty, return an scalar abstract
389 if (shape_vec.empty() && (!type->isa<TensorType>())) {
390 abstract::AbstractScalarPtr abs_scalar = std::make_shared<abstract::AbstractScalar>(kAnyValue, type);
391 return abs_scalar;
392 }
393 return MakeAbstractTensor(shape, type);
394 } else if (base_shape->isa<TupleShape>() && type->isa<Tuple>()) {
395 auto shape_tuple = base_shape->cast<TupleShapePtr>();
396 auto type_tuple = type->cast<TuplePtr>();
397 AbstractBasePtrList ptr_list;
398 for (size_t it = 0; it < shape_tuple->size(); ++it) {
399 auto tensor_it = MakeAbstract((*shape_tuple)[it], (*type_tuple)[it]);
400 ptr_list.push_back(tensor_it);
401 }
402 auto tuple = std::make_shared<abstract::AbstractTuple>(ptr_list);
403 return tuple;
404 } else if (base_shape->isa<ListShape>() && type->isa<List>()) {
405 auto shape_list = base_shape->cast<ListShapePtr>();
406 auto type_list = type->cast<ListPtr>();
407 AbstractBasePtrList ptr_list;
408 for (size_t it = 0; it < shape_list->size(); ++it) {
409 auto tensor_it = MakeAbstract((*shape_list)[it], (*type_list)[it]);
410 ptr_list.push_back(tensor_it);
411 }
412 auto list = std::make_shared<abstract::AbstractList>(ptr_list);
413 return list;
414 } else if (base_shape->isa<NoShape>() && type->isa<TypeNone>()) {
415 // AbstractNone indicates there is no output for this CNode node.
416 auto abstract_none = std::make_shared<abstract::AbstractNone>();
417 return abstract_none;
418 } else if (type->isa<Monad>()) {
419 // Return monad abstract if it is monad type.
420 return MakeMonadAbstract(type->cast<MonadTypePtr>());
421 } else {
422 // When sparse enabled, the undetermined might be raised and eliminated in opt passes
423 auto context = MsContext::GetInstance();
424 MS_EXCEPTION_IF_NULL(context);
425 bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
426 if (enable_sparse) {
427 return std::make_shared<abstract::AbstractUndetermined>();
428 }
429 MS_LOG(EXCEPTION) << "evaluator return invalid shape " << base_shape->ToString() << "or type. " << type->ToString();
430 }
431 }
432 } // namespace abstract
433 } // namespace mindspore
434