/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "abstract/utils.h"

#include <string>
#include <sstream>
#include <memory>
// Standard headers for std::map, std::min/std::max and std::accumulate used below.
#include <map>
#include <algorithm>
#include <numeric>
#include <functional>
#include "utils/ms_context.h"
#include "utils/symbolic.h"
#include "abstract/param_validator.h"

namespace mindspore {
namespace abstract {
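// The size in bytes of each numeric TypeId.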
const std::map<TypeId, size_t> type_map = {
  {kNumberTypeBool, 1},       {kNumberTypeInt, 4},     {kNumberTypeInt8, 1},    {kNumberTypeInt16, 2},
  {kNumberTypeInt32, 4},      {kNumberTypeInt64, 8},   {kNumberTypeUInt, 4},    {kNumberTypeUInt8, 1},
  {kNumberTypeUInt16, 2},     {kNumberTypeUInt32, 4},  {kNumberTypeUInt64, 8},  {kNumberTypeFloat, 4},
  {kNumberTypeFloat16, 2},    {kNumberTypeFloat32, 4}, {kNumberTypeFloat64, 8}, {kNumberTypeComplex64, 8},
  {kNumberTypeComplex128, 16}};

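// Join two values: equal values join to themselves; otherwise the join widens to kAnyValue.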
ValuePtr ValueJoin(const ValuePtr &value1, const ValuePtr &value2) {
  MS_EXCEPTION_IF_NULL(value1);
  MS_EXCEPTION_IF_NULL(value2);
  if (*value1 == *value2) {
    return value1;
  }
  return kAnyValue;
}

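// Join two types: equal types join to themselves; otherwise the join widens to kAnyType.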
TypePtr TypeJoin(const TypePtr &type1, const TypePtr &type2) {
  MS_EXCEPTION_IF_NULL(type1);
  MS_EXCEPTION_IF_NULL(type2);
  if (*type1 == *type2) {
    return type1;
  }
  return kAnyType;
}

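// Compute min/max bounds for the joined shape dims: a static dimension bounds itself,
// while a dynamic dimension (SHP_ANY) takes the min/max over both input shapes, using
// their recorded min_shape/max_shape where an input dimension is itself dynamic.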
ShapePtr CalculateDynamicShape(const ShapePtr &shape1, const ShapePtr &shape2, const ShapeVector &dims) {
  // Calculate the min/max bounds of the joined dynamic shape.
  ShapeVector min_dims(dims.size());
  ShapeVector max_dims(dims.size());
  MS_EXCEPTION_IF_NULL(shape1);
  MS_EXCEPTION_IF_NULL(shape2);
  for (size_t i = 0; i < dims.size(); ++i) {
    if (dims[i] != Shape::SHP_ANY) {
      min_dims[i] = max_dims[i] = dims[i];
      continue;
    }
    if (shape1->shape()[i] != Shape::SHP_ANY && shape2->shape()[i] != Shape::SHP_ANY) {
      min_dims[i] = std::min(shape1->shape()[i], shape2->shape()[i]);
      max_dims[i] = std::max(shape1->shape()[i], shape2->shape()[i]);
      continue;
    }
    if (shape1->shape()[i] == Shape::SHP_ANY && shape2->shape()[i] != Shape::SHP_ANY) {
      if (shape1->min_shape().size() <= i || shape1->max_shape().size() <= i) {
        MS_EXCEPTION(ValueError) << "Shape " << shape1->ToString()
                                 << " has dynamic shape, but does not have min/max shape info.";
      }
      min_dims[i] = std::min(shape1->min_shape()[i], shape2->shape()[i]);
      max_dims[i] = std::max(shape1->max_shape()[i], shape2->shape()[i]);
      continue;
    }
    if (shape1->shape()[i] != Shape::SHP_ANY && shape2->shape()[i] == Shape::SHP_ANY) {
      if (shape2->min_shape().size() <= i || shape2->max_shape().size() <= i) {
        MS_EXCEPTION(ValueError) << "Shape " << shape2->ToString()
                                 << " has dynamic shape, but does not have min/max shape info.";
      }
      min_dims[i] = std::min(shape1->shape()[i], shape2->min_shape()[i]);
      max_dims[i] = std::max(shape1->shape()[i], shape2->max_shape()[i]);
      continue;
    }
    // Both shapes contain a dynamic dimension here.
    if (shape1->min_shape().size() <= i || shape1->max_shape().size() <= i) {
      MS_EXCEPTION(ValueError) << "Shape " << shape1->ToString()
                               << " has dynamic shape, but does not have min/max shape info.";
    }
    if (shape2->min_shape().size() <= i || shape2->max_shape().size() <= i) {
      MS_EXCEPTION(ValueError) << "Shape " << shape2->ToString()
                               << " has dynamic shape, but does not have min/max shape info.";
    }
    min_dims[i] = std::min(shape1->min_shape()[i], shape2->min_shape()[i]);
    max_dims[i] = std::max(shape1->max_shape()[i], shape2->max_shape()[i]);
  }
  return std::make_shared<Shape>(dims, min_dims, max_dims);
}

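// Join two shapes. Equal shapes join to themselves; shapes of different ranks fail to
// join (nullptr), except for the shape(1)/shape() case; otherwise every mismatched
// dimension becomes dynamic, e.g. joining (2, 3) and (2, 4) gives (2, -1) with
// min shape (2, 3) and max shape (2, 4).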
ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2) {
  MS_EXCEPTION_IF_NULL(shape1);
  MS_EXCEPTION_IF_NULL(shape2);
  if (*shape1 == *shape2) {
    return shape1;
  }
  // If the two shapes don't have the same rank, the join fails.
  if (shape1->shape().size() != shape2->shape().size()) {
    // Special case: shape(1), shape() -> shape(1).
    if (shape1->shape().size() == 1 && shape1->shape()[0] == 1 && shape2->shape().empty()) {
      return shape1;
    }
    if (shape2->shape().size() == 1 && shape2->shape()[0] == 1 && shape1->shape().empty()) {
      return shape2;
    }
    return nullptr;
  }
  ShapeVector dims;
  bool has_dynamic_shape = false;
  dims.resize(shape1->shape().size());
  for (std::size_t i = 0; i < shape1->shape().size(); i++) {
    if (shape1->shape()[i] == shape2->shape()[i]) {
      dims[i] = shape1->shape()[i];
      if (shape1->shape()[i] == Shape::SHP_ANY) {
        has_dynamic_shape = true;
      }
    } else {
      dims[i] = Shape::SHP_ANY;
      has_dynamic_shape = true;
    }
  }
  if (!has_dynamic_shape) {
    return std::make_shared<Shape>(dims);
  }
  return CalculateDynamicShape(shape1, shape2, dims);
}

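// Fold Join over a non-empty list of abstract values, producing their join (least upper bound).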
AbstractBasePtr AbstractJoin(const AbstractBasePtrList &args_spec_list) {
  if (args_spec_list.empty()) {
    MS_LOG(EXCEPTION) << "AbstractJoin requires at least 1 param, while the input size is " << args_spec_list.size()
                      << ".";
  }
  AbstractBasePtr arg_spec_tmp = args_spec_list[0];
  MS_EXCEPTION_IF_NULL(arg_spec_tmp);
  for (const auto &arg_spec : args_spec_list) {
    MS_EXCEPTION_IF_NULL(arg_spec);
    arg_spec_tmp = arg_spec_tmp->Join(arg_spec);
    MS_EXCEPTION_IF_NULL(arg_spec_tmp);
  }
  return arg_spec_tmp;
}

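// Join two abstract lists element-wise; returns spec1 itself if no element widened.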
AbstractBasePtrList AbstractJoin(const AbstractBasePtrList &spec1, const AbstractBasePtrList &spec2) {
  if (spec1.size() != spec2.size()) {
    MS_LOG(EXCEPTION) << "Join failed as the lists don't have the same size. spec1: " << ::mindspore::ToString(spec1)
                      << ", spec2: " << ::mindspore::ToString(spec2);
  }
  AbstractBasePtrList joined_list;
  bool changes = false;
  for (std::size_t i = 0; i < spec1.size(); i++) {
    MS_EXCEPTION_IF_NULL(spec1[i]);
    auto joined_elem = spec1[i]->Join(spec2[i]);
    MS_EXCEPTION_IF_NULL(joined_elem);
    if (joined_elem != spec1[i]) {
      changes = true;
    }
    joined_list.push_back(joined_elem);
  }
  if (!changes) {
    return spec1;
  }
  return joined_list;
}

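// Map an abstract to the abstract of its sensitivity value: function abstracts map to
// an environment-typed scalar; anything else is cloned unchanged.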
AbstractBasePtr SensitivityTransform(const AbstractBasePtr &spec) {
  AbstractFunctionPtr f_spec = dyn_cast<AbstractFunction>(spec);
  if (f_spec != nullptr) {
    return std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
  }
  return spec->Clone();
}

namespace {
// Join all types in args_type_list.
TypePtr TypeJoin(const TypePtrList &args_type_list) {
  if (args_type_list.empty()) {
    MS_LOG(EXCEPTION) << "args_type_list is empty";
  }

  TypePtr type_tmp = args_type_list[0];
  for (std::size_t i = 1; i < args_type_list.size(); i++) {
    type_tmp = abstract::TypeJoin(type_tmp, args_type_list[i]);
  }
  return type_tmp;
}
}  // namespace

bool CheckType(const TypePtr &expected_type, const TypePtr &x) {
  // Since x and expected_type are both MindSpore types known statically, we only need to
  // judge whether x is expected_type or a subclass of it.
  return IsIdentidityOrSubclass(x, expected_type);
}

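// Check that every type in args_type_list matches the predicate, then return their join.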
TypePtr CheckTypeList(const TypePtr &predicate, const TypePtrList &args_type_list) {
  MS_EXCEPTION_IF_NULL(predicate);
  for (const auto &arg_type : args_type_list) {
    MS_EXCEPTION_IF_NULL(arg_type);
    if (!CheckType(predicate, arg_type)) {
      MS_LOG(EXCEPTION) << "The expected type is " << predicate->ToString() << ", not " << arg_type->ToString();
    }
  }
  return TypeJoin(args_type_list);
}

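// Normalize a possibly negative axis index by adding the rank (increment) once.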
int64_t GetPositiveAxis(int64_t axis_value, size_t increment) {
  if (axis_value < 0) {
    axis_value = axis_value + SizeToLong(increment);
  }

  if (axis_value < 0) {
    MS_LOG(EXCEPTION) << "axis_value should not still be less than 0 after adding the rank.";
  }

  return axis_value;
}

// Broadcast two shapes, padding the shorter one with leading 1s;
// the broadcast shape is returned.
ShapeVector RealBroadcast(const std::string &op, ShapeVector x_shape, ShapeVector y_shape) {
  std::reverse(x_shape.begin(), x_shape.end());
  std::reverse(y_shape.begin(), y_shape.end());
  // Fill a placeholder value 1 which will be replaced later.
  size_t std_len = x_shape.size() > y_shape.size() ? x_shape.size() : y_shape.size();
  y_shape.resize(std_len, 1);
  x_shape.resize(std_len, 1);

  ShapeVector broadcast_shape;
  for (size_t i = 0; i < std_len; i++) {
    int64_t x_i = x_shape[i];  // i-th dimension of x
    int64_t y_i = y_shape[i];  // i-th dimension of y
    int64_t output_i = 0;      // i-th dimension of the output
    if (x_i == y_i) {
      output_i = x_i;
    } else if (x_i == 1) {
      output_i = y_i;
    } else if (y_i == 1) {
      output_i = x_i;
    } else {
      MS_LOG(EXCEPTION)
        << op
        << " evaluator: the shape of the first tensor and the shape of the second tensor do not meet the "
           "broadcasting requirements.";
    }
    broadcast_shape.push_back(output_i);
  }
  std::reverse(broadcast_shape.begin(), broadcast_shape.end());
  return broadcast_shape;
}

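// Compute the broadcast shape of shpx and shpy after left-padding the shorter one with 1s.
// A dimension of 1 or -1 (dynamic) yields to the other shape's dimension; an empty vector
// is returned if the shapes cannot be broadcast, e.g. BroadcastShape({3, 1, 5}, {4, 5})
// gives {3, 4, 5}, while BroadcastShape({3, 2}, {4, 2}) gives {}.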
ShapeVector BroadcastShape(ShapeVector shpx, ShapeVector shpy) {
  int dlen = SizeToInt(shpx.size()) - SizeToInt(shpy.size());
  if (dlen < 0) {
    for (int i = 0; i < -dlen; ++i) {
      (void)shpx.insert(shpx.begin(), 1);
    }
  } else if (dlen > 0) {
    for (int i = 0; i < dlen; i++) {
      (void)shpy.insert(shpy.begin(), 1);
    }
  }
  if (shpx.size() != shpy.size()) {
    MS_LOG(EXCEPTION) << "Failure: shpx.size() != shpy.size().";
  }
  ShapeVector shp;
  for (size_t i = 0; i < shpx.size(); i++) {
    auto a = shpx[i];
    auto b = shpy[i];
    if (a == 1) {
      shp.push_back(b);
    } else if (b == 1) {
      shp.push_back(a);
    } else if (a == -1) {
      shp.push_back(b);
    } else if (b == -1) {
      shp.push_back(a);
    } else if (a == b) {
      shp.push_back(a);
    } else {
      return ShapeVector();
    }
  }
  return shp;
}

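// Return the size in bytes of data_type, or 0 if the type is unsupported.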
size_t TypeIdSize(const TypeId data_type) {
  const size_t unsupported_type_error = 0;
  auto iter = type_map.find(data_type);
  if (iter != type_map.end()) {
    return iter->second;
  }
  return unsupported_type_error;
}

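// Return the number of elements of a shape, i.e. the product of its dimensions.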
size_t ShapeSize(const std::vector<size_t> &shape) {
  return std::accumulate(shape.begin(), shape.end(), IntToSize(1), std::multiplies<size_t>());
}

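// Default empty min/max shapes to the static shape.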
void CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVector *max_shape) {
  *min_shape = (*min_shape).empty() ? shape : *min_shape;
  *max_shape = (*max_shape).empty() ? shape : *max_shape;
}

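// Extract the num_segments argument (the third input) of an UnsortedSegment* op;
// it may be given as an int32/int64 tensor or as a scalar.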
int64_t GetUnsortedSegmentOpScalarArg(const AbstractBasePtrList &args_spec_list, const std::string &op_name) {
  int64_t num_segments_value = 0;
  constexpr size_t scalar_index = 2;
  if (args_spec_list[scalar_index]->isa<AbstractTensor>()) {  // num_segments is a Tensor
    auto num_segments = args_spec_list[scalar_index]->cast<AbstractTensorPtr>();
    MS_EXCEPTION_IF_NULL(num_segments);
    auto num_segments_value_ptr = num_segments->BuildValue();
    MS_EXCEPTION_IF_NULL(num_segments_value_ptr);
    auto num_segments_tensor = num_segments_value_ptr->cast<tensor::TensorPtr>();
    MS_EXCEPTION_IF_NULL(num_segments_tensor);
    if (num_segments->element()->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) {
      num_segments_value = *static_cast<int64_t *>(num_segments_tensor->data_c());
    } else {
      num_segments_value = *static_cast<int32_t *>(num_segments_tensor->data_c());
    }
  } else if (args_spec_list[scalar_index]->isa<AbstractScalar>()) {  // num_segments is a Scalar
    auto num_segments = CheckArg<AbstractScalar>(op_name, args_spec_list, scalar_index);
    if (num_segments->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) {
      num_segments_value = GetValue<int64_t>(num_segments->BuildValue());
    } else {
      num_segments_value = GetValue<int32_t>(num_segments->BuildValue());
    }
  } else {
    MS_LOG(EXCEPTION) << "num_segments has an incorrect type in " << op_name;
  }
  return num_segments_value;
}

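// Build an AbstractTensor from a shape (including any min/max shape info) and an element
// type, unwrapping a TensorType to its element type when one is given.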
AbstractBasePtr MakeAbstractTensor(const ShapePtr &shape, const TypePtr &type) {
  MS_EXCEPTION_IF_NULL(shape);
  MS_EXCEPTION_IF_NULL(type);
  AbstractBasePtr tensor = nullptr;
  auto ret_vec = shape->shape();
  ShapeVector min_shape_vec;
  ShapeVector max_shape_vec;

  if (!shape->min_shape().empty()) {
    min_shape_vec = shape->min_shape();
  }
  if (!shape->max_shape().empty()) {
    max_shape_vec = shape->max_shape();
  }

  auto ret_shape = std::make_shared<abstract::Shape>(ret_vec, min_shape_vec, max_shape_vec);
  if (type->isa<TensorType>()) {
    auto tensor_type = type->cast<TensorTypePtr>();
    MS_EXCEPTION_IF_NULL(tensor_type);
    auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, tensor_type->element());
    tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);
  } else {
    auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, type);
    tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);
  }
  return tensor;
}

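// Convert a monad type to the corresponding monad abstract (UMonad or IOMonad).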
AbstractBasePtr MakeMonadAbstract(const MonadTypePtr &type) {
  if (type->isa<UMonadType>()) {
    return kUMonad->ToAbstract();
  } else if (type->isa<IOMonadType>()) {
    return kIOMonad->ToAbstract();
  }
  MS_EXCEPTION(UnknownError) << "Unsupported conversion of type " << type->ToString() << " to a monad abstract.";
}

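// Build an abstract value from an inferred shape and type, recursing into tuples and
// lists, and handling scalars, tensors, None and monads.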
AbstractBasePtr MakeAbstract(const BaseShapePtr &base_shape, const TypePtr &type) {
  MS_EXCEPTION_IF_NULL(base_shape);
  MS_EXCEPTION_IF_NULL(type);
  if ((base_shape->isa<Shape>())) {
    auto shape = base_shape->cast<ShapePtr>();
    MS_EXCEPTION_IF_NULL(shape);
    auto shape_vec = shape->shape();
    // If the shape list is empty and the type is not a tensor type, return a scalar abstract.
    if (shape_vec.empty() && (!type->isa<TensorType>())) {
      abstract::AbstractScalarPtr abs_scalar = std::make_shared<abstract::AbstractScalar>(kAnyValue, type);
      return abs_scalar;
    }
    return MakeAbstractTensor(shape, type);
  } else if (base_shape->isa<TupleShape>() && type->isa<Tuple>()) {
    auto shape_tuple = base_shape->cast<TupleShapePtr>();
    auto type_tuple = type->cast<TuplePtr>();
    AbstractBasePtrList ptr_list;
    for (size_t it = 0; it < shape_tuple->size(); ++it) {
      auto tensor_it = MakeAbstract((*shape_tuple)[it], (*type_tuple)[it]);
      ptr_list.push_back(tensor_it);
    }
    auto tuple = std::make_shared<abstract::AbstractTuple>(ptr_list);
    return tuple;
  } else if (base_shape->isa<ListShape>() && type->isa<List>()) {
    auto shape_list = base_shape->cast<ListShapePtr>();
    auto type_list = type->cast<ListPtr>();
    AbstractBasePtrList ptr_list;
    for (size_t it = 0; it < shape_list->size(); ++it) {
      auto tensor_it = MakeAbstract((*shape_list)[it], (*type_list)[it]);
      ptr_list.push_back(tensor_it);
    }
    auto list = std::make_shared<abstract::AbstractList>(ptr_list);
    return list;
  } else if (base_shape->isa<NoShape>() && type->isa<TypeNone>()) {
    // AbstractNone indicates there is no output for this CNode.
    auto abstract_none = std::make_shared<abstract::AbstractNone>();
    return abstract_none;
  } else if (type->isa<Monad>()) {
    // Return a monad abstract if it is a monad type.
    return MakeMonadAbstract(type->cast<MonadTypePtr>());
  } else {
    // When sparse is enabled, Undetermined might be raised and then eliminated in opt passes.
    auto context = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(context);
    bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
    if (enable_sparse) {
      return std::make_shared<abstract::AbstractUndetermined>();
    }
    MS_LOG(EXCEPTION) << "Evaluator returned an invalid shape " << base_shape->ToString() << " or type "
                      << type->ToString() << ".";
  }
}
}  // namespace abstract
}  // namespace mindspore