• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2022 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "abstract/utils.h"
20 
21 #include "utils/ms_context.h"
22 #include "utils/symbolic.h"
23 #include "abstract/abstract_function.h"
24 
25 namespace mindspore {
26 namespace abstract {
27 const std::map<TypeId, size_t> type_map = {
28   {kNumberTypeBool, 1},        {kNumberTypeInt, 4},      {kNumberTypeInt8, 1},    {kNumberTypeInt16, 2},
29   {kNumberTypeInt32, 4},       {kNumberTypeInt64, 8},    {kNumberTypeUInt, 4},    {kNumberTypeUInt8, 1},
30   {kNumberTypeUInt16, 2},      {kNumberTypeUInt32, 4},   {kNumberTypeUInt64, 8},  {kNumberTypeFloat, 4},
31   {kNumberTypeFloat16, 2},     {kNumberTypeFloat32, 4},  {kNumberTypeFloat64, 8}, {kNumberTypeComplex64, 8},
32   {kNumberTypeComplex128, 16}, {kNumberTypeBFloat16, 2}, {kNumberTypeInt4, 1}};
33 
// Joins two values: when they compare equal the first one is returned unchanged, otherwise the join widens
// to kValueAny. Null arguments are rejected via MS_EXCEPTION_IF_NULL.
ValuePtr ValueJoin(const ValuePtr &value1, const ValuePtr &value2) {
  MS_EXCEPTION_IF_NULL(value1);
  MS_EXCEPTION_IF_NULL(value2);
  const bool identical = (*value1 == *value2);
  if (!identical) {
    return kValueAny;
  }
  return value1;
}
42 
// Joins two types: when they compare equal the first one is returned unchanged, otherwise the join widens
// to kTypeAny. Null arguments are rejected via MS_EXCEPTION_IF_NULL.
TypePtr TypeJoin(const TypePtr &type1, const TypePtr &type2) {
  MS_EXCEPTION_IF_NULL(type1);
  MS_EXCEPTION_IF_NULL(type2);
  const bool identical = (*type1 == *type2);
  if (!identical) {
    return kTypeAny;
  }
  return type1;
}
51 
IsShapesDynamicRank(const std::vector<ShapeVector> & shapes)52 bool IsShapesDynamicRank(const std::vector<ShapeVector> &shapes) {
53   return std::any_of(shapes.begin(), shapes.end(), [](const ShapeVector &shape) {
54     return std::any_of(shape.begin(), shape.end(), [](int64_t dim) { return dim == Shape::kShapeRankAny; });
55   });
56 }
57 
// Handles the one rank mismatch ShapeJoin reconciles: a shape of exactly (1) joined with the empty shape ()
// yields the (1) operand, in either argument order. Returns nullptr for every other pair.
// Both arguments are dereferenced without a null check; callers must pass non-null shapes.
ShapePtr SingleElementShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2) {
  const auto is_unit_vector = [](const ShapeVector &dims) { return dims.size() == 1 && dims[0] == 1; };
  if (is_unit_vector(shape1->shape()) && shape2->shape().empty()) {
    return shape1;
  }
  if (is_unit_vector(shape2->shape()) && shape1->shape().empty()) {
    return shape2;
  }
  return nullptr;
}
68 
// Joins a single dimension: equal extents are kept, any disagreement widens to Shape::kShapeDimAny.
ShapeValueDType SingleShapeValueJoin(const ShapeValueDType &shape_value1, const ShapeValueDType &shape_value2) {
  if (shape_value1 != shape_value2) {
    return Shape::kShapeDimAny;
  }
  return shape_value1;
}
75 
// Joins two shapes.
// - Equal shapes: the first one is returned as-is.
// - Either shape contains kShapeRankAny: the result is a dynamic-rank shape {kShapeRankAny}.
// - Ranks differ: (1) vs () resolves to (1); anything else becomes dynamic rank.
// - Same rank: each dimension is joined with SingleShapeValueJoin; nullptr is returned if a joined
//   dimension equals kShapeError.
ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2) {
  MS_EXCEPTION_IF_NULL(shape1);
  MS_EXCEPTION_IF_NULL(shape2);
  if (*shape1 == *shape2) {
    return shape1;
  }

  const auto &lhs_dims = shape1->shape();
  const auto &rhs_dims = shape2->shape();
  if (IsShapesDynamicRank({lhs_dims, rhs_dims})) {
    return std::make_shared<Shape>(ShapeVector{Shape::kShapeRankAny});
  }

  if (lhs_dims.size() != rhs_dims.size()) {
    auto special_case = SingleElementShapeJoin(shape1, shape2);
    if (special_case == nullptr) {
      // Ranks cannot be reconciled, so the joined rank is unknown.
      return std::make_shared<Shape>(ShapeVector{Shape::kShapeRankAny});
    }
    return special_case;
  }

  ShapeVector joined_dims;
  joined_dims.reserve(lhs_dims.size());
  for (std::size_t axis = 0; axis < lhs_dims.size(); ++axis) {
    const auto joined = SingleShapeValueJoin(lhs_dims[axis], rhs_dims[axis]);
    if (joined == Shape::kShapeError) {
      return nullptr;
    }
    joined_dims.push_back(joined);
  }
  return std::make_shared<Shape>(joined_dims);
}
105 
AbstractJoin(const AbstractBasePtrList & args_abs_list)106 AbstractBasePtr AbstractJoin(const AbstractBasePtrList &args_abs_list) {
107   if (args_abs_list.empty()) {
108     MS_LOG(INTERNAL_EXCEPTION) << "AbstractJoin requires at least 1 params, while the input size is "
109                                << args_abs_list.size() << ".";
110   }
111   AbstractBasePtr arg_spec_tmp = args_abs_list[0];
112   MS_EXCEPTION_IF_NULL(arg_spec_tmp);
113   for (const auto &arg_spec : args_abs_list) {
114     MS_EXCEPTION_IF_NULL(arg_spec);
115     arg_spec_tmp = arg_spec_tmp->Join(arg_spec);
116     MS_EXCEPTION_IF_NULL(arg_spec_tmp);
117   }
118   return arg_spec_tmp;
119 }
120 
// Element-wise join of two abstract lists of equal length.
// Returns `lhs` itself when every joined element is pointer-identical to its lhs counterpart (nothing changed);
// otherwise returns the new list of joined elements.
// Throws when the sizes differ, when an element of either list is null, or when a join yields null.
AbstractBasePtrList AbstractJoin(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) {
  if (lhs.size() != rhs.size()) {
    MS_LOG(EXCEPTION) << "Join failed as list don't have the same size. lhs: " << ::mindspore::ToString(lhs)
                      << ", rhs: " << ::mindspore::ToString(rhs);
  }
  AbstractBasePtrList joined_list;
  // The result has exactly lhs.size() elements; allocate once.
  joined_list.reserve(lhs.size());
  bool changes = false;
  for (std::size_t i = 0; i < lhs.size(); i++) {
    MS_EXCEPTION_IF_NULL(lhs[i]);
    // Validate the right-hand element as well so a null is reported here rather than inside Join().
    MS_EXCEPTION_IF_NULL(rhs[i]);
    auto joined_elem = lhs[i]->Join(rhs[i]);
    MS_EXCEPTION_IF_NULL(joined_elem);
    if (joined_elem != lhs[i]) {
      changes = true;
    }
    joined_list.push_back(joined_elem);
  }
  if (!changes) {
    return lhs;
  }
  return joined_list;
}
142 
// Returns a broadened version of `abs`.
// - Non-sparse sequences with dynamic length: a clone with the original dynamic_len_element_abs re-attached
//   (the element abstract itself is not broadened).
// - Other non-sparse sequences: a new AbstractTuple/AbstractList of recursively broadened elements that keeps
//   the original sequence_nodes(); any other sequence kind raises TypeError.
// - Dictionaries: a new AbstractDictionary with keys kept and values recursively broadened.
// - Everything else: abs->Broaden(). For an AbstractScalar of Number or String type, the INPUT scalar itself
//   (abs_scalar aliases abs, it is not a copy) is first marked with set_is_variable(true).
AbstractBasePtr AbstractBroaden(const AbstractBasePtr &abs) {
  MS_EXCEPTION_IF_NULL(abs);
  if (abs->isa<AbstractSequence>() && !abs->isa<AbstractSparseTensor>()) {
    auto seq = abs->cast<AbstractSequencePtr>();
    if (seq->dynamic_len()) {
      auto element_abs = seq->dynamic_len_element_abs();
      auto result = seq->Clone()->cast<AbstractSequencePtr>();
      result->set_dynamic_len_element_abs(element_abs);
      return result;
    }
    AbstractBasePtrList broadened;
    broadened.reserve(seq->elements().size());
    for (const auto &elem : seq->elements()) {
      broadened.push_back(AbstractBroaden(elem));
    }
    if (seq->isa<AbstractTuple>()) {
      return std::make_shared<AbstractTuple>(broadened, seq->sequence_nodes());
    }
    if (seq->isa<AbstractList>()) {
      return std::make_shared<AbstractList>(broadened, seq->sequence_nodes());
    }
    MS_INTERNAL_EXCEPTION(TypeError) << "Unknown AbstractSequence type:" << abs->ToString();
  }

  if (abs->isa<AbstractDictionary>()) {
    auto dict = abs->cast<AbstractDictionaryPtr>();
    std::vector<AbstractElementPair> broadened_kv;
    for (const auto &item : dict->elements()) {
      MS_EXCEPTION_IF_NULL(item.second);
      broadened_kv.emplace_back(item.first, AbstractBroaden(item.second));
    }
    return std::make_shared<AbstractDictionary>(broadened_kv);
  }

  if (abs->isa<AbstractScalar>()) {
    auto scalar_type = abs->BuildType();
    MS_EXCEPTION_IF_NULL(scalar_type);
    if (scalar_type->isa<Number>() || scalar_type->isa<String>()) {
      // Mutates the caller's scalar in place; must happen before Broaden() below.
      abs->cast<AbstractScalarPtr>()->set_is_variable(true);
    }
  }
  return abs->Broaden();
}
186 
// Transforms `spec` for use as a sensitivity abstract.
// A function abstract is replaced by an AbstractScalar of EnvType with an unknown value (kValueAny);
// any other abstract is returned as a clone.
// Throws on a null `spec` instead of dereferencing it, consistent with the other helpers in this file.
AbstractBasePtr SensitivityTransform(const AbstractBasePtr &spec) {
  MS_EXCEPTION_IF_NULL(spec);
  auto f_spec = dyn_cast_ptr<AbstractFunction>(spec);
  if (f_spec != nullptr) {
    return std::make_shared<AbstractScalar>(kValueAny, std::make_shared<EnvType>());
  }
  return spec->Clone();
}
194 
BroadcastShape(ShapeVector shpx,ShapeVector shpy)195 ShapeVector BroadcastShape(ShapeVector shpx, ShapeVector shpy) {
196   int dlen = SizeToInt(shpx.size()) - SizeToInt(shpy.size());
197   if (dlen < 0) {
198     for (int i = 0; i < -dlen; ++i) {
199       (void)shpx.insert(shpx.begin(), 1);
200     }
201   } else if (dlen > 0) {
202     for (int i = 0; i < dlen; i++) {
203       (void)shpy.insert(shpy.begin(), 1);
204     }
205   }
206   if (shpx.size() != shpy.size()) {
207     MS_LOG(INTERNAL_EXCEPTION) << "Failure: shpx.size() != shpy.size().";
208   }
209   ShapeVector shp;
210   for (size_t i = 0; i < shpx.size(); i++) {
211     auto a = shpx[i];
212     auto b = shpy[i];
213     if (a == 1) {
214       shp.push_back(b);
215     } else if (b == 1) {
216       shp.push_back(a);
217     } else if (a == -1) {
218       shp.push_back(b);
219     } else if (b == -1) {
220       shp.push_back(a);
221     } else if (a == b) {
222       shp.push_back(a);
223     } else {
224       return ShapeVector();
225     }
226   }
227   return shp;
228 }
229 
TypeIdSize(const TypeId data_type)230 size_t TypeIdSize(const TypeId data_type) {
231   const size_t unsupported_type_error = 0;
232   auto iter = type_map.find(data_type);
233   if (iter != type_map.end()) {
234     return iter->second;
235   }
236   return unsupported_type_error;
237 }
238 
MakeAbstractTensor(const ShapePtr & shape,const TypePtr & type)239 AbstractBasePtr MakeAbstractTensor(const ShapePtr &shape, const TypePtr &type) {
240   MS_EXCEPTION_IF_NULL(shape);
241   MS_EXCEPTION_IF_NULL(type);
242   AbstractBasePtr tensor = nullptr;
243 
244   auto ret_shape = shape->Clone();
245   if (type->isa<TensorType>()) {
246     auto tensor_type = type->cast_ptr<TensorType>();
247     MS_EXCEPTION_IF_NULL(tensor_type);
248     auto element = std::make_shared<abstract::AbstractScalar>(kValueAny, tensor_type->element());
249     tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);
250   } else {
251     auto element = std::make_shared<abstract::AbstractScalar>(kValueAny, type);
252     tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);
253   }
254   return tensor;
255 }
256 
MakeMonadAbstract(const MonadTypePtr & type)257 AbstractBasePtr MakeMonadAbstract(const MonadTypePtr &type) {
258   if (type->isa<UMonadType>()) {
259     return kUMonad->ToAbstract();
260   } else if (type->isa<IOMonadType>()) {
261     return kIOMonad->ToAbstract();
262   }
263   MS_INTERNAL_EXCEPTION(UnknownError) << "Unsupported to convert type " << type->ToString() << " to monad abstract";
264 }
265 
MakeAbstract(const BaseShapePtr & base_shape,const TypePtr & type)266 AbstractBasePtr MakeAbstract(const BaseShapePtr &base_shape, const TypePtr &type) {
267   MS_EXCEPTION_IF_NULL(base_shape);
268   MS_EXCEPTION_IF_NULL(type);
269   if ((base_shape->isa<Shape>())) {
270     auto shape = base_shape->cast<ShapePtr>();
271     MS_EXCEPTION_IF_NULL(shape);
272     auto shape_vec = shape->shape();
273     // if the size of shape list is empty, return an scalar abstract
274     if (shape_vec.empty() && (!type->isa<TensorType>())) {
275       abstract::AbstractScalarPtr abs_scalar = std::make_shared<abstract::AbstractScalar>(kValueAny, type);
276       return abs_scalar;
277     }
278     return MakeAbstractTensor(shape, type);
279   } else if (base_shape->isa<NoShape>() && type->isa<Type>()) {
280     return std::make_shared<abstract::AbstractScalar>(kValueAny, type);
281   } else if (base_shape->isa<TupleShape>() && type->isa<Tuple>()) {
282     auto shape_tuple = base_shape->cast_ptr<TupleShape>();
283     auto type_tuple = type->cast_ptr<Tuple>();
284     AbstractBasePtrList ptr_list;
285     for (size_t it = 0; it < shape_tuple->size(); ++it) {
286       auto tensor_it = MakeAbstract((*shape_tuple)[it], (*type_tuple)[it]);
287       ptr_list.push_back(tensor_it);
288     }
289     auto tuple = std::make_shared<abstract::AbstractTuple>(ptr_list);
290     return tuple;
291   } else if (base_shape->isa<ListShape>() && type->isa<List>()) {
292     auto shape_list = base_shape->cast_ptr<ListShape>();
293     auto type_list = type->cast_ptr<List>();
294     AbstractBasePtrList ptr_list;
295     for (size_t it = 0; it < shape_list->size(); ++it) {
296       auto tensor_it = MakeAbstract((*shape_list)[it], (*type_list)[it]);
297       ptr_list.push_back(tensor_it);
298     }
299     auto list = std::make_shared<abstract::AbstractList>(ptr_list);
300     return list;
301   } else if (base_shape->isa<NoShape>() && type->isa<TypeNone>()) {
302     // AbstractNone indicates there is no output for this CNode node.
303     auto abstract_none = std::make_shared<abstract::AbstractNone>();
304     return abstract_none;
305   } else if (type->isa<Monad>()) {
306     // Return monad abstract if it is monad type.
307     return MakeMonadAbstract(type->cast<MonadTypePtr>());
308   }
309   MS_LOG(INTERNAL_EXCEPTION) << "Evaluator return invalid shape " << base_shape->ToString() << " or type. "
310                              << type->ToString();
311 }
312 
SetVariableFlag(const AbstractBasePtr & abs)313 void SetVariableFlag(const AbstractBasePtr &abs) {
314   if (!abs->isa<abstract::AbstractFunction>()) {
315     return;
316   }
317   const auto func_abs = abs->cast_ptr<abstract::AbstractFunction>();
318   MS_EXCEPTION_IF_NULL(func_abs);
319   abstract::FuncGraphAbstractClosure *closure_abs = nullptr;
320   auto partial_closure_abs = func_abs->cast_ptr<abstract::PartialAbstractClosure>();
321   if (partial_closure_abs != nullptr) {
322     closure_abs = partial_closure_abs->fn()->cast_ptr<abstract::FuncGraphAbstractClosure>();
323   } else {
324     closure_abs = func_abs->cast_ptr<abstract::FuncGraphAbstractClosure>();
325   }
326   if (closure_abs != nullptr) {
327     auto func = closure_abs->func_graph();
328     MS_EXCEPTION_IF_NULL(func);
329     func->set_is_tensor_condition_branch(true);
330     MS_LOG(DEBUG) << "Set is_tensor_condition_branch for func_graph:" << func->ToString();
331   }
332 }
333 
334 namespace {
GetFuncGraphFromAbs(const abstract::AbstractBasePtr & abs,const AnfNodePtr & call_node)335 FuncGraphPtr GetFuncGraphFromAbs(const abstract::AbstractBasePtr &abs, const AnfNodePtr &call_node) {
336   MS_EXCEPTION_IF_NULL(call_node);
337   if (abs == nullptr) {
338     MS_LOG(ERROR) << "Null abstract, current node: " << call_node->DebugString();
339     return nullptr;
340   }
341   if (abs->isa<abstract::FuncGraphAbstractClosure>()) {
342     auto abs_func_graph = abs->cast<abstract::FuncGraphAbstractClosurePtr>();
343     MS_EXCEPTION_IF_NULL(abs_func_graph);
344     if (!abs_func_graph->specialized()) {
345       MS_LOG(INFO) << "Unspecialized func graph abstract: " << abs_func_graph->ToString()
346                    << ", node: " << call_node->DebugString();
347     }
348     return abs_func_graph->func_graph();
349   }
350   if (abs->isa<abstract::PartialAbstractClosure>()) {
351     auto abs_partial_closure = abs->cast<abstract::PartialAbstractClosurePtr>();
352     MS_EXCEPTION_IF_NULL(abs_partial_closure);
353     auto abs_func = abs_partial_closure->fn();
354     return GetFuncGraphFromAbs(abs_func, call_node);
355   }
356   MS_LOG(ERROR) << "Unexpected abs: " << abs->ToString() << ", call node: " << call_node->DebugString();
357   return nullptr;
358 }
359 }  // namespace
360 
GetFuncGraphsFromCallNode(const CNodePtr & call_node)361 std::vector<FuncGraphPtr> GetFuncGraphsFromCallNode(const CNodePtr &call_node) {
362   MS_EXCEPTION_IF_NULL(call_node);
363   auto func_node = call_node->input(0);
364   if (IsPrimitiveCNode(func_node, prim::kPrimPartial)) {
365     func_node = func_node->cast<CNodePtr>()->input(1);
366   }
367   if (IsValueNode<FuncGraph>(func_node)) {
368     return {GetValueNode<FuncGraphPtr>(func_node)};
369   }
370   auto abs = func_node->abstract();
371   MS_EXCEPTION_IF_NULL(abs);
372   if (abs == nullptr) {
373     MS_LOG(ERROR) << "Null abstract, current call node: " << call_node->DebugString();
374     return {};
375   }
376   if (!abs->isa<abstract::AbstractFunction>()) {
377     MS_LOG(ERROR) << "Unexpected abs: " << abs->ToString() << ", call_node: " << call_node->DebugString();
378     return {};
379   }
380   auto abs_func = abs->cast<abstract::AbstractFunctionPtr>();
381   MS_EXCEPTION_IF_NULL(abs_func);
382   std::vector<FuncGraphPtr> func_graphs;
383   if (abs->isa<abstract::AbstractFuncUnion>()) {
384     auto visit_func = [&func_graphs, &call_node](const abstract::AbstractFuncAtomPtr &poss) {
385       (void)func_graphs.emplace_back(GetFuncGraphFromAbs(poss, call_node));
386     };
387     abs_func->Visit(visit_func);
388   } else {
389     (void)func_graphs.emplace_back(GetFuncGraphFromAbs(abs_func, call_node));
390   }
391   bool exist_null_fg =
392     std::any_of(func_graphs.cbegin(), func_graphs.cend(), [](const FuncGraphPtr &fg) { return fg == nullptr; });
393   if (exist_null_fg) {
394     MS_LOG(ERROR) << "Get func graphs from abstract failed!";
395     return {};
396   }
397   return func_graphs;
398 }
399 }  // namespace abstract
400 }  // namespace mindspore
401