1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2022 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "abstract/utils.h"
20
21 #include "utils/ms_context.h"
22 #include "utils/symbolic.h"
23 #include "abstract/abstract_function.h"
24
25 namespace mindspore {
26 namespace abstract {
// Byte size of each supported numeric TypeId; consumed by TypeIdSize() below.
// Generic kNumberTypeInt/UInt/Float default to the 4-byte width; kNumberTypeInt4
// maps to 1 byte (sub-byte types are rounded up to a whole byte).
const std::map<TypeId, size_t> type_map = {
  {kNumberTypeBool, 1},        {kNumberTypeInt, 4},       {kNumberTypeInt8, 1},      {kNumberTypeInt16, 2},
  {kNumberTypeInt32, 4},       {kNumberTypeInt64, 8},     {kNumberTypeUInt, 4},      {kNumberTypeUInt8, 1},
  {kNumberTypeUInt16, 2},      {kNumberTypeUInt32, 4},    {kNumberTypeUInt64, 8},    {kNumberTypeFloat, 4},
  {kNumberTypeFloat16, 2},     {kNumberTypeFloat32, 4},   {kNumberTypeFloat64, 8},   {kNumberTypeComplex64, 8},
  {kNumberTypeComplex128, 16}, {kNumberTypeBFloat16, 2},  {kNumberTypeInt4, 1}};
33
// Join two values: equal values join to the first operand, differing values
// widen to kValueAny.
ValuePtr ValueJoin(const ValuePtr &value1, const ValuePtr &value2) {
  MS_EXCEPTION_IF_NULL(value1);
  MS_EXCEPTION_IF_NULL(value2);
  return (*value1 == *value2) ? value1 : kValueAny;
}
42
// Join two types: equal types join to the first operand, differing types
// widen to kTypeAny.
TypePtr TypeJoin(const TypePtr &type1, const TypePtr &type2) {
  MS_EXCEPTION_IF_NULL(type1);
  MS_EXCEPTION_IF_NULL(type2);
  return (*type1 == *type2) ? type1 : kTypeAny;
}
51
IsShapesDynamicRank(const std::vector<ShapeVector> & shapes)52 bool IsShapesDynamicRank(const std::vector<ShapeVector> &shapes) {
53 return std::any_of(shapes.begin(), shapes.end(), [](const ShapeVector &shape) {
54 return std::any_of(shape.begin(), shape.end(), [](int64_t dim) { return dim == Shape::kShapeRankAny; });
55 });
56 }
57
// Handle the special join case: shape(1) joined with shape() yields shape(1).
// Returns nullptr when the pair does not match this pattern, letting the
// caller fall back to its generic handling.
ShapePtr SingleElementShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2) {
  const auto &dims1 = shape1->shape();
  const auto &dims2 = shape2->shape();
  auto is_single_one = [](const ShapeVector &dims) { return dims.size() == 1 && dims[0] == 1; };
  if (is_single_one(dims1) && dims2.empty()) {
    return shape1;
  }
  if (is_single_one(dims2) && dims1.empty()) {
    return shape2;
  }
  return nullptr;
}
68
// Join two single dimension values: equal values join to themselves,
// differing values widen to the dynamic-dimension placeholder.
ShapeValueDType SingleShapeValueJoin(const ShapeValueDType &shape_value1, const ShapeValueDType &shape_value2) {
  return (shape_value1 == shape_value2) ? shape_value1 : Shape::kShapeDimAny;
}
75
// Join two shapes into the most specific shape compatible with both.
// Equal shapes join to the first operand; any dynamic-rank input (or a rank
// mismatch outside the single-element special case) widens to dynamic rank;
// otherwise dimensions are joined pairwise. Returns nullptr on dimension
// join failure.
ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2) {
  MS_EXCEPTION_IF_NULL(shape1);
  MS_EXCEPTION_IF_NULL(shape2);
  if (*shape1 == *shape2) {
    return shape1;
  }
  const auto &dims1 = shape1->shape();
  const auto &dims2 = shape2->shape();
  if (IsShapesDynamicRank({dims1, dims2})) {
    return std::make_shared<Shape>(ShapeVector{Shape::kShapeRankAny});
  }
  if (dims1.size() != dims2.size()) {
    // Ranks differ: try the shape(1)/shape() special case first, otherwise
    // the join degrades to a dynamic-rank shape.
    auto special_joined = SingleElementShapeJoin(shape1, shape2);
    if (special_joined != nullptr) {
      return special_joined;
    }
    return std::make_shared<Shape>(ShapeVector({Shape::kShapeRankAny}));
  }
  // Same rank: join dimension by dimension.
  ShapeVector joined_dims;
  joined_dims.reserve(dims1.size());
  for (std::size_t i = 0; i < dims1.size(); ++i) {
    auto joined_dim = SingleShapeValueJoin(dims1[i], dims2[i]);
    if (joined_dim == Shape::kShapeError) {
      return nullptr;
    }
    joined_dims.push_back(joined_dim);
  }
  return std::make_shared<Shape>(joined_dims);
}
105
AbstractJoin(const AbstractBasePtrList & args_abs_list)106 AbstractBasePtr AbstractJoin(const AbstractBasePtrList &args_abs_list) {
107 if (args_abs_list.empty()) {
108 MS_LOG(INTERNAL_EXCEPTION) << "AbstractJoin requires at least 1 params, while the input size is "
109 << args_abs_list.size() << ".";
110 }
111 AbstractBasePtr arg_spec_tmp = args_abs_list[0];
112 MS_EXCEPTION_IF_NULL(arg_spec_tmp);
113 for (const auto &arg_spec : args_abs_list) {
114 MS_EXCEPTION_IF_NULL(arg_spec);
115 arg_spec_tmp = arg_spec_tmp->Join(arg_spec);
116 MS_EXCEPTION_IF_NULL(arg_spec_tmp);
117 }
118 return arg_spec_tmp;
119 }
120
// Join two abstract lists element-wise. The lists must have the same size.
// Returns `lhs` unchanged when every pairwise join yields the lhs element,
// so callers can detect a no-op join by pointer identity.
AbstractBasePtrList AbstractJoin(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) {
  if (lhs.size() != rhs.size()) {
    MS_LOG(EXCEPTION) << "Join failed as list don't have the same size. lhs: " << ::mindspore::ToString(lhs)
                      << ", rhs: " << ::mindspore::ToString(rhs);
  }
  AbstractBasePtrList joined_list;
  joined_list.reserve(lhs.size());  // avoid reallocations during the joins
  bool changes = false;
  for (std::size_t i = 0; i < lhs.size(); i++) {
    MS_EXCEPTION_IF_NULL(lhs[i]);
    // Fix: rhs[i] was previously passed to Join() without a null check,
    // inconsistent with the lhs[i] check above.
    MS_EXCEPTION_IF_NULL(rhs[i]);
    auto joined_elem = lhs[i]->Join(rhs[i]);
    MS_EXCEPTION_IF_NULL(joined_elem);
    if (joined_elem != lhs[i]) {
      changes = true;
    }
    joined_list.push_back(joined_elem);
  }
  if (!changes) {
    return lhs;
  }
  return joined_list;
}
142
// Recursively broaden an abstract value so that it keeps structural/type
// information but drops concrete values:
// - Sequences (tuple/list) are broadened element-wise. AbstractSparseTensor
//   derives from AbstractSequence but is explicitly excluded and falls
//   through to the generic Broaden() at the end.
// - Dictionaries are broadened value-wise; keys are kept as-is.
// - Number/string scalars are marked variable before the generic Broaden().
AbstractBasePtr AbstractBroaden(const AbstractBasePtr &abs) {
  MS_EXCEPTION_IF_NULL(abs);
  if (abs->isa<AbstractSequence>() && !abs->isa<AbstractSparseTensor>()) {
    auto sequence_abs = abs->cast<AbstractSequencePtr>();
    if (sequence_abs->dynamic_len()) {
      // Dynamic-length sequence: clone it and re-install the same element
      // abstract. NOTE(review): the element abstract itself is not broadened
      // here — presumably set_dynamic_len_element_abs performs the needed
      // normalization; confirm this is intentional.
      auto elem_abs = sequence_abs->dynamic_len_element_abs();
      auto cloned_abs = sequence_abs->Clone()->cast<AbstractSequencePtr>();
      cloned_abs->set_dynamic_len_element_abs(elem_abs);
      return cloned_abs;
    }
    // Fixed-length sequence: broaden every element recursively.
    std::vector<AbstractBasePtr> new_elements;
    new_elements.reserve(sequence_abs->elements().size());
    (void)std::transform(sequence_abs->elements().cbegin(), sequence_abs->elements().cend(),
                         std::back_inserter(new_elements), AbstractBroaden);
    // Rebuild the concrete sequence kind, preserving its sequence nodes.
    if (sequence_abs->isa<AbstractTuple>()) {
      return std::make_shared<AbstractTuple>(new_elements, sequence_abs->sequence_nodes());
    }
    if (sequence_abs->isa<AbstractList>()) {
      return std::make_shared<AbstractList>(new_elements, sequence_abs->sequence_nodes());
    }
    MS_INTERNAL_EXCEPTION(TypeError) << "Unknown AbstractSequence type:" << abs->ToString();
  }
  if (abs->isa<AbstractDictionary>()) {
    // Broaden only the values; keys stay untouched.
    auto abs_dict = abs->cast<AbstractDictionaryPtr>();
    const auto &origin_kv = abs_dict->elements();
    std::vector<AbstractElementPair> kv;
    (void)std::transform(origin_kv.cbegin(), origin_kv.cend(), std::back_inserter(kv),
                         [](const AbstractElementPair &item) {
                           MS_EXCEPTION_IF_NULL(item.second);
                           return std::make_pair(item.first, AbstractBroaden(item.second));
                         });
    return std::make_shared<AbstractDictionary>(kv);
  }
  if (abs->isa<AbstractScalar>()) {
    // Number/string scalars become "variable" so later passes treat the value
    // as unknown; other scalar kinds are left as-is before Broaden().
    auto arg_type = abs->BuildType();
    MS_EXCEPTION_IF_NULL(arg_type);
    auto abs_scalar = abs->cast<AbstractScalarPtr>();
    if (arg_type->isa<Number>() || arg_type->isa<String>()) {
      abs_scalar->set_is_variable(true);
    }
  }
  return abs->Broaden();
}
186
// Transform an abstract for sensitivity (gradient) computation: function
// abstracts become an environment-typed scalar; anything else is cloned.
AbstractBasePtr SensitivityTransform(const AbstractBasePtr &spec) {
  // Fix: guard against a null argument before dereferencing it below,
  // consistent with the null checks in the sibling functions of this file.
  MS_EXCEPTION_IF_NULL(spec);
  auto f_spec = dyn_cast_ptr<AbstractFunction>(spec);
  if (f_spec != nullptr) {
    return std::make_shared<AbstractScalar>(kValueAny, std::make_shared<EnvType>());
  }
  return spec->Clone();
}
194
BroadcastShape(ShapeVector shpx,ShapeVector shpy)195 ShapeVector BroadcastShape(ShapeVector shpx, ShapeVector shpy) {
196 int dlen = SizeToInt(shpx.size()) - SizeToInt(shpy.size());
197 if (dlen < 0) {
198 for (int i = 0; i < -dlen; ++i) {
199 (void)shpx.insert(shpx.begin(), 1);
200 }
201 } else if (dlen > 0) {
202 for (int i = 0; i < dlen; i++) {
203 (void)shpy.insert(shpy.begin(), 1);
204 }
205 }
206 if (shpx.size() != shpy.size()) {
207 MS_LOG(INTERNAL_EXCEPTION) << "Failure: shpx.size() != shpy.size().";
208 }
209 ShapeVector shp;
210 for (size_t i = 0; i < shpx.size(); i++) {
211 auto a = shpx[i];
212 auto b = shpy[i];
213 if (a == 1) {
214 shp.push_back(b);
215 } else if (b == 1) {
216 shp.push_back(a);
217 } else if (a == -1) {
218 shp.push_back(b);
219 } else if (b == -1) {
220 shp.push_back(a);
221 } else if (a == b) {
222 shp.push_back(a);
223 } else {
224 return ShapeVector();
225 }
226 }
227 return shp;
228 }
229
TypeIdSize(const TypeId data_type)230 size_t TypeIdSize(const TypeId data_type) {
231 const size_t unsupported_type_error = 0;
232 auto iter = type_map.find(data_type);
233 if (iter != type_map.end()) {
234 return iter->second;
235 }
236 return unsupported_type_error;
237 }
238
MakeAbstractTensor(const ShapePtr & shape,const TypePtr & type)239 AbstractBasePtr MakeAbstractTensor(const ShapePtr &shape, const TypePtr &type) {
240 MS_EXCEPTION_IF_NULL(shape);
241 MS_EXCEPTION_IF_NULL(type);
242 AbstractBasePtr tensor = nullptr;
243
244 auto ret_shape = shape->Clone();
245 if (type->isa<TensorType>()) {
246 auto tensor_type = type->cast_ptr<TensorType>();
247 MS_EXCEPTION_IF_NULL(tensor_type);
248 auto element = std::make_shared<abstract::AbstractScalar>(kValueAny, tensor_type->element());
249 tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);
250 } else {
251 auto element = std::make_shared<abstract::AbstractScalar>(kValueAny, type);
252 tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);
253 }
254 return tensor;
255 }
256
MakeMonadAbstract(const MonadTypePtr & type)257 AbstractBasePtr MakeMonadAbstract(const MonadTypePtr &type) {
258 if (type->isa<UMonadType>()) {
259 return kUMonad->ToAbstract();
260 } else if (type->isa<IOMonadType>()) {
261 return kIOMonad->ToAbstract();
262 }
263 MS_INTERNAL_EXCEPTION(UnknownError) << "Unsupported to convert type " << type->ToString() << " to monad abstract";
264 }
265
// Build an abstract value from a (shape, type) pair produced by an evaluator:
// - Shape + TensorType (or non-empty shape) -> AbstractTensor
// - empty Shape + non-tensor type           -> AbstractScalar
// - NoShape + Type                          -> AbstractScalar
// - TupleShape/ListShape + Tuple/List       -> element-wise recursion
// - NoShape + TypeNone                      -> AbstractNone (node has no output)
// - Monad type                              -> monad abstract
// Any other combination raises an internal exception.
AbstractBasePtr MakeAbstract(const BaseShapePtr &base_shape, const TypePtr &type) {
  MS_EXCEPTION_IF_NULL(base_shape);
  MS_EXCEPTION_IF_NULL(type);
  if ((base_shape->isa<Shape>())) {
    auto shape = base_shape->cast<ShapePtr>();
    MS_EXCEPTION_IF_NULL(shape);
    auto shape_vec = shape->shape();
    // if the size of shape list is empty, return a scalar abstract
    if (shape_vec.empty() && (!type->isa<TensorType>())) {
      abstract::AbstractScalarPtr abs_scalar = std::make_shared<abstract::AbstractScalar>(kValueAny, type);
      return abs_scalar;
    }
    return MakeAbstractTensor(shape, type);
  } else if (base_shape->isa<NoShape>() && type->isa<Type>()) {
    // NOTE(review): isa<Type>() presumably holds for every non-null TypePtr,
    // which would make this guard redundant — confirm before simplifying.
    return std::make_shared<abstract::AbstractScalar>(kValueAny, type);
  } else if (base_shape->isa<TupleShape>() && type->isa<Tuple>()) {
    auto shape_tuple = base_shape->cast_ptr<TupleShape>();
    auto type_tuple = type->cast_ptr<Tuple>();
    AbstractBasePtrList ptr_list;
    // NOTE(review): assumes shape_tuple and type_tuple have the same size;
    // a mismatch would index type_tuple out of range — verify callers.
    for (size_t it = 0; it < shape_tuple->size(); ++it) {
      auto tensor_it = MakeAbstract((*shape_tuple)[it], (*type_tuple)[it]);
      ptr_list.push_back(tensor_it);
    }
    auto tuple = std::make_shared<abstract::AbstractTuple>(ptr_list);
    return tuple;
  } else if (base_shape->isa<ListShape>() && type->isa<List>()) {
    auto shape_list = base_shape->cast_ptr<ListShape>();
    auto type_list = type->cast_ptr<List>();
    AbstractBasePtrList ptr_list;
    // Same size assumption as the tuple branch above.
    for (size_t it = 0; it < shape_list->size(); ++it) {
      auto tensor_it = MakeAbstract((*shape_list)[it], (*type_list)[it]);
      ptr_list.push_back(tensor_it);
    }
    auto list = std::make_shared<abstract::AbstractList>(ptr_list);
    return list;
  } else if (base_shape->isa<NoShape>() && type->isa<TypeNone>()) {
    // AbstractNone indicates there is no output for this CNode node.
    auto abstract_none = std::make_shared<abstract::AbstractNone>();
    return abstract_none;
  } else if (type->isa<Monad>()) {
    // Return monad abstract if it is monad type.
    return MakeMonadAbstract(type->cast<MonadTypePtr>());
  }
  MS_LOG(INTERNAL_EXCEPTION) << "Evaluator return invalid shape " << base_shape->ToString() << " or type. "
                             << type->ToString();
}
312
SetVariableFlag(const AbstractBasePtr & abs)313 void SetVariableFlag(const AbstractBasePtr &abs) {
314 if (!abs->isa<abstract::AbstractFunction>()) {
315 return;
316 }
317 const auto func_abs = abs->cast_ptr<abstract::AbstractFunction>();
318 MS_EXCEPTION_IF_NULL(func_abs);
319 abstract::FuncGraphAbstractClosure *closure_abs = nullptr;
320 auto partial_closure_abs = func_abs->cast_ptr<abstract::PartialAbstractClosure>();
321 if (partial_closure_abs != nullptr) {
322 closure_abs = partial_closure_abs->fn()->cast_ptr<abstract::FuncGraphAbstractClosure>();
323 } else {
324 closure_abs = func_abs->cast_ptr<abstract::FuncGraphAbstractClosure>();
325 }
326 if (closure_abs != nullptr) {
327 auto func = closure_abs->func_graph();
328 MS_EXCEPTION_IF_NULL(func);
329 func->set_is_tensor_condition_branch(true);
330 MS_LOG(DEBUG) << "Set is_tensor_condition_branch for func_graph:" << func->ToString();
331 }
332 }
333
334 namespace {
GetFuncGraphFromAbs(const abstract::AbstractBasePtr & abs,const AnfNodePtr & call_node)335 FuncGraphPtr GetFuncGraphFromAbs(const abstract::AbstractBasePtr &abs, const AnfNodePtr &call_node) {
336 MS_EXCEPTION_IF_NULL(call_node);
337 if (abs == nullptr) {
338 MS_LOG(ERROR) << "Null abstract, current node: " << call_node->DebugString();
339 return nullptr;
340 }
341 if (abs->isa<abstract::FuncGraphAbstractClosure>()) {
342 auto abs_func_graph = abs->cast<abstract::FuncGraphAbstractClosurePtr>();
343 MS_EXCEPTION_IF_NULL(abs_func_graph);
344 if (!abs_func_graph->specialized()) {
345 MS_LOG(INFO) << "Unspecialized func graph abstract: " << abs_func_graph->ToString()
346 << ", node: " << call_node->DebugString();
347 }
348 return abs_func_graph->func_graph();
349 }
350 if (abs->isa<abstract::PartialAbstractClosure>()) {
351 auto abs_partial_closure = abs->cast<abstract::PartialAbstractClosurePtr>();
352 MS_EXCEPTION_IF_NULL(abs_partial_closure);
353 auto abs_func = abs_partial_closure->fn();
354 return GetFuncGraphFromAbs(abs_func, call_node);
355 }
356 MS_LOG(ERROR) << "Unexpected abs: " << abs->ToString() << ", call node: " << call_node->DebugString();
357 return nullptr;
358 }
359 } // namespace
360
GetFuncGraphsFromCallNode(const CNodePtr & call_node)361 std::vector<FuncGraphPtr> GetFuncGraphsFromCallNode(const CNodePtr &call_node) {
362 MS_EXCEPTION_IF_NULL(call_node);
363 auto func_node = call_node->input(0);
364 if (IsPrimitiveCNode(func_node, prim::kPrimPartial)) {
365 func_node = func_node->cast<CNodePtr>()->input(1);
366 }
367 if (IsValueNode<FuncGraph>(func_node)) {
368 return {GetValueNode<FuncGraphPtr>(func_node)};
369 }
370 auto abs = func_node->abstract();
371 MS_EXCEPTION_IF_NULL(abs);
372 if (abs == nullptr) {
373 MS_LOG(ERROR) << "Null abstract, current call node: " << call_node->DebugString();
374 return {};
375 }
376 if (!abs->isa<abstract::AbstractFunction>()) {
377 MS_LOG(ERROR) << "Unexpected abs: " << abs->ToString() << ", call_node: " << call_node->DebugString();
378 return {};
379 }
380 auto abs_func = abs->cast<abstract::AbstractFunctionPtr>();
381 MS_EXCEPTION_IF_NULL(abs_func);
382 std::vector<FuncGraphPtr> func_graphs;
383 if (abs->isa<abstract::AbstractFuncUnion>()) {
384 auto visit_func = [&func_graphs, &call_node](const abstract::AbstractFuncAtomPtr &poss) {
385 (void)func_graphs.emplace_back(GetFuncGraphFromAbs(poss, call_node));
386 };
387 abs_func->Visit(visit_func);
388 } else {
389 (void)func_graphs.emplace_back(GetFuncGraphFromAbs(abs_func, call_node));
390 }
391 bool exist_null_fg =
392 std::any_of(func_graphs.cbegin(), func_graphs.cend(), [](const FuncGraphPtr &fg) { return fg == nullptr; });
393 if (exist_null_fg) {
394 MS_LOG(ERROR) << "Get func graphs from abstract failed!";
395 return {};
396 }
397 return func_graphs;
398 }
399 } // namespace abstract
400 } // namespace mindspore
401