• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "include/common/expander/core/emitter.h"
18 
19 #include <algorithm>
20 #include <functional>
21 #include <unordered_set>
22 #include <utility>
23 #include "include/common/utils/utils.h"
24 #include "ir/anf.h"
25 #include "ops/sequence_ops.h"
26 #include "ops/math_ops.h"
27 #include "ops/array_ops.h"
28 #include "ops/framework_ops.h"
29 #include "include/common/utils/convert_utils.h"
30 #include "ir/functor.h"
31 #include "ops/primitive_c.h"
32 #include "utils/anf_utils.h"
33 #include "utils/check_convert_utils.h"
34 #include "utils/ms_context.h"
35 #include "ops/op_def.h"
36 #include "ir/primitive.h"
37 
38 namespace mindspore {
39 namespace expander {
40 namespace {
// Try to extract a vector of int64 values from a node's compile-time value.
// Returns {true, values} when the value is fully known (an int scalar, a
// ValueSequence with no ValueAny holes, or a constant tensor), and
// {false, {}} when the value cannot be determined statically.
std::pair<bool, std::vector<int64_t>> GetIntList(const NodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  ValuePtr value_ptr = node->BuildValue();
  if (value_ptr != nullptr) {
    // A sequence without ValueAny placeholders, or a plain scalar, converts directly.
    if ((value_ptr->isa<ValueSequence>() && !value_ptr->ContainsValueAny()) || value_ptr->isa<Scalar>()) {
      return std::make_pair(true, CheckAndConvertUtils::CheckIntOrTupleInt("value", value_ptr, "GetIntList"));
    }
    if (value_ptr->isa<tensor::Tensor>()) {
      auto tensor = value_ptr->cast<tensor::TensorPtr>();
      MS_EXCEPTION_IF_NULL(tensor);
      // In pynative mode, need data sync before get tensor value, otherwise the tensor value may be undefined.
      tensor->data_sync();
      return std::make_pair(true, CheckAndConvertUtils::CheckTensorIntValue("value", value_ptr, "GetIntList"));
    }
  }
  return std::make_pair(false, std::vector<int64_t>{});
}
58 
CreateZeroScalar(const TypePtr & type)59 ValuePtr CreateZeroScalar(const TypePtr &type) {
60   auto tensor = std::make_shared<tensor::Tensor>(0, type);
61   return CreateValueFromTensor(tensor);
62 }
63 
// Resolve the concrete target shape for a Reshape whose dst_shape contains a
// single -1 (kShapeDimAny) placeholder: that dimension becomes
// total_size(x_shape) / product(known dims of dst_shape).
// Requires x_shape to be fully static and dst_shape to have a known rank;
// raises otherwise. Also raises if dst_shape has more than one -1.
ShapeVector CalReshapeRealDstShape(const ShapeVector &x_shape, const ShapeVector &dst_shape) {
  if (!IsDynamicShape(dst_shape)) {
    return dst_shape;
  }

  if (IsDynamicRank(dst_shape) || IsDynamic(x_shape)) {
    MS_LOG(EXCEPTION) << "The source shape(" << x_shape << ") or target shape(" << dst_shape
                      << ") is invalid for Reshape const infer!";
  }

  ShapeVector res_shape(dst_shape.begin(), dst_shape.end());
  if (std::count(dst_shape.begin(), dst_shape.end(), abstract::Shape::kShapeDimAny) != 1) {
    // Fixed typo in the message: "bug got" -> "but got".
    MS_LOG(EXCEPTION) << "The target shape can only have one -1 for Reshape, but got " << dst_shape;
  }

  // Use an int64_t initial value: with a plain `1` std::accumulate deduces T=int
  // and silently truncates the product of int64 dimensions.
  auto total_size = std::accumulate(x_shape.cbegin(), x_shape.cend(), int64_t{1}, std::multiplies<int64_t>());
  size_t target_idx = 0;
  int64_t dst_size = 1;
  for (size_t i = 0; i < dst_shape.size(); ++i) {
    if (dst_shape[i] == abstract::Shape::kShapeDimAny) {
      target_idx = i;
      continue;
    }
    dst_size *= dst_shape[i];
  }
  MS_EXCEPTION_IF_CHECK_FAIL(dst_size != 0, "Cannot divide zeros!");
  res_shape[target_idx] = total_size / dst_size;
  return res_shape;
}
93 }  // namespace
94 
Emit(const std::string & op_name,const NodePtrList & inputs,const DAttr & attrs)95 NodePtr Emitter::Emit(const std::string &op_name, const NodePtrList &inputs, const DAttr &attrs) {
96   auto prim = NewPrimitive(op_name, attrs);
97   return EmitOp(prim, inputs);
98 }
99 
// Hook for materializing a node from a primitive and its inputs.
// The base Emitter cannot create nodes; derived emitters must override this.
NodePtr Emitter::EmitOp(const PrimitivePtr &prim, const NodePtrList &inputs) {
  MS_EXCEPTION(NotImplementedError) << "Base Emitter not implemented EmitOp() method";
}
103 
NewPrimitive(const std::string & op_name,const DAttr & attrs)104 PrimitivePtr Emitter::NewPrimitive(const std::string &op_name, const DAttr &attrs) {
105   PrimitivePtr prim = nullptr;
106   if (mindspore::ops::IsPrimitiveFunction(op_name)) {
107     prim = std::make_shared<Primitive>(op_name);
108     if (!attrs.empty()) {
109       prim->SetAttrs(attrs);
110     }
111   } else {
112     auto &func = Emitter::primc_func_cache()[op_name];
113     if (func == nullptr) {
114       const auto &op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
115       const auto iter = op_primc_fns.find(op_name);
116       prim = iter == op_primc_fns.end() ? std::make_shared<ops::PrimitiveC>(op_name) : (func = iter->second)();
117     } else {
118       prim = func();
119     }
120   }
121   MS_EXCEPTION_IF_NULL(prim);
122   if (!attrs.empty()) {
123     (void)prim->SetAttrs(attrs);
124   }
125   return prim;
126 }
127 
// Hook for wrapping a Value into a node.
// The base Emitter cannot create nodes; derived emitters must override this.
NodePtr Emitter::EmitValue(const ValuePtr &value) {
  MS_EXCEPTION(NotImplementedError) << "Base Emitter not implemented EmitValue() method";
}
131 
// Emit an Exp op with its default base/scale/shift attributes.
NodePtr Emitter::Exp(const NodePtr &x) {
  DAttr exp_attrs{{"base", MakeValue<float>(-1.0)},
                  {"scale", MakeValue<float>(1.0)},
                  {"shift", MakeValue<float>(0.0)}};
  return Emit(kExpOpName, {x}, exp_attrs);
}
136 
// Emit a Log op with its default base/scale/shift attributes and the
// cust_aicpu tag (used by the AICPU custom-op path).
NodePtr Emitter::Log(const NodePtr &x) {
  DAttr log_attrs{{"base", MakeValue<pyfloat>(-1.0)},
                  {"scale", MakeValue<pyfloat>(1.0)},
                  {"shift", MakeValue<pyfloat>(0.0)},
                  {"cust_aicpu", MakeValue(kLogOpName)}};
  return Emit(kLogOpName, {x}, log_attrs);
}
144 
// Cast `node` to `type`, skipping the op entirely when the dtypes already match.
NodePtr Emitter::Cast(const NodePtr &node, const TypePtr &type) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(type);
  const auto dst_type_id = type->type_id();
  // do not emit a node when the dst type is the same as src type
  if (node->dtype()->type_id() == dst_type_id) {
    return node;
  }
  return Emit("Cast", {node, Value(static_cast<int64_t>(dst_type_id))});
}
154 
// Emit a Reshape, with three levels of constant folding:
//  1. shape not statically known -> emit Reshape with the tuple-ified shape;
//  2. node is a constant tensor with host data -> fold to the node itself or a
//     new constant tensor with the resolved target shape;
//  3. otherwise, skip the op when the target shape provably equals the node's
//     current shape (allowing -1 wildcards).
NodePtr Emitter::Reshape(const NodePtr &node, const NodePtr &shape) {
  MS_EXCEPTION_IF_NULL(node);
  auto [success, dst_shape] = GetIntList(shape);
  if (!success) {
    auto tuple_shape = TensorToTuple(shape);
    return Emit(kReshapeOpName, {node, tuple_shape});
  }

  if (node->input_type() == InputType::kConstant) {
    // If node and shape is both known, return node itself or a new tensor with target shape.
    auto value = node->BuildValue();
    MS_EXCEPTION_IF_NULL(value);
    auto tensor = value->cast<tensor::TensorPtr>();
    if (tensor != nullptr && tensor->data().const_data() != nullptr) {
      const auto &tensor_shape = tensor->shape_c();
      // Resolve any single -1 in dst_shape against the tensor's element count.
      auto update_shape = CalReshapeRealDstShape(tensor_shape, dst_shape);
      if (tensor_shape == update_shape) {
        return node;
      }
      auto type_id = tensor->data_type();
      // Rebuild a constant tensor sharing the same raw data with the new shape.
      return this->Tensor(type_id, update_shape, tensor->data_c(), type_id);
    }
  }

  auto node_shape = node->shape();
  if (IsDynamicRank(node_shape)) {
    // Unknown rank: cannot prove the reshape is a no-op, emit it.
    return Emit(kReshapeOpName, {node, shape});
  }
  if (dst_shape.size() != node_shape.size()) {
    return Emit(kReshapeOpName, {node, shape});
  }
  for (size_t i = 0; i < dst_shape.size(); ++i) {
    // A dim must match exactly or be the -1 wildcard for the op to be skippable.
    if (dst_shape[i] != node_shape[i] && dst_shape[i] != -1) {
      return Emit(kReshapeOpName, {node, shape});
    }
  }
  return node;
}
193 
// Emit a MatMul with explicit transpose flags passed as value inputs.
NodePtr Emitter::MatMul(const NodePtr &a, const NodePtr &b, bool transpose_a, bool transpose_b) {
  NodePtrList matmul_inputs{a, b, Value(transpose_a), Value(transpose_b)};
  return Emit(prim::kPrimMatMul->name(), matmul_inputs);
}
197 
// Emit a BatchMatMul with explicit transpose flags passed as value inputs.
NodePtr Emitter::BatchMatMul(const NodePtr &a, const NodePtr &b, bool transpose_a, bool transpose_b) {
  NodePtrList bmm_inputs{a, b, Value(transpose_a), Value(transpose_b)};
  return Emit(prim::kPrimBatchMatMul->name(), bmm_inputs);
}
201 
// Emit a MatMulExt, first unifying both operand dtypes.
NodePtr Emitter::MatMulExt(const NodePtr &a, const NodePtr &b) {
  const auto &op = prim::kPrimMatMulExt->name();
  return UnifyDtypeAndEmit(op, a, b, {});
}
205 
// Emit a Transpose, skipping the op when `perm` is a statically-known
// identity permutation (negative indices are normalized first).
NodePtr Emitter::Transpose(const NodePtr &node, const NodePtr &perm) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(perm);
  auto [is_const, perm_vals] = GetIntList(perm);
  if (!is_const) {
    return Emit(kTransposeOpName, {node, TensorToTuple(perm)});
  }
  const auto rank = SizeToLong(perm_vals.size());
  bool is_identity = true;
  for (size_t i = 0; i < perm_vals.size(); ++i) {
    // perm value may be negative, e.g. [0, -3, 2, 3] is equal to [0, 1, 2, 3]
    auto dim = perm_vals[i];
    if (dim < 0) {
      dim += rank;
    }
    if (dim != static_cast<int64_t>(i)) {
      is_identity = false;
      break;
    }
  }
  return is_identity ? node : Emit(kTransposeOpName, {node, perm});
}
225 
// Emit a Tile, skipping the op when all repeat counts are statically 1 and
// the node's rank is at least the number of repeats.
NodePtr Emitter::Tile(const NodePtr &node, const NodePtr &dims) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(dims);
  auto [is_const, reps] = GetIntList(dims);
  if (!is_const) {
    return Emit(kTileOpName, {node, TensorToTuple(dims)});
  }
  const bool all_ones = std::all_of(reps.cbegin(), reps.cend(), [](int64_t rep) { return rep == 1; });
  if (all_ones && node->shape().size() >= reps.size()) {
    return node;
  }
  return Emit(kTileOpName, {node, dims});
}
240 
// Broadcast x to y's shape; the op is skipped only when both shapes are
// static and already identical.
NodePtr Emitter::BroadcastTo(const NodePtr &x, const NodePtr &y) {
  const bool both_static = !IsDynamic(x->shape()) && !IsDynamic(y->shape());
  if (both_static && x->shape() == y->shape()) {
    return x;
  }
  return Emit("BroadcastTo", {x, Shape(y)});
}
248 
// Build a "zeros like node" value, dispatching on the node's kind:
// constant None/Monad -> Tensor(0); constant sequence -> SequenceZerosLike;
// constant scalar -> zero scalar of the same type; tensor -> ZerosLike op;
// Monad/Type/None abstract -> the node itself (nothing to zero);
// sequence abstract -> SequenceZerosLike (or the node for a known-empty tuple);
// scalar abstract -> zero scalar. Anything else raises.
NodePtr Emitter::ZerosLike(const NodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (node->input_type() == InputType::kConstant) {
    if (node->dtype()->type_id() == kMetaTypeNone) {
      return Tensor(0);
    }
    auto v = node->BuildValue();
    MS_EXCEPTION_IF_NULL(v);
    if (v->isa<ValueSequence>()) {
      return Emit(kSequenceZerosLikeOpName, {node});
    } else if (v->isa<Scalar>()) {
      return EmitValue(CreateZeroScalar(v->type()));
    } else if (v->isa<Type>()) {
      return Tensor(0, v->type());
    } else if (v->isa<Monad>()) {
      return Tensor(0);
    }
  }

  auto abs = node->abstract();
  MS_EXCEPTION_IF_NULL(abs);

  if (abs->isa<abstract::AbstractTensor>()) {
    return Emit(kZerosLikeOpName, {node});
  } else if (abs->isa<abstract::AbstractMonad>() || abs->isa<abstract::AbstractType>() ||
             abs->isa<abstract::AbstractNone>()) {
    // Monads/types/None carry no data; return the node unchanged.
    return node;
  } else if (abs->isa<abstract::AbstractSequence>()) {
    auto sequence_abs = abs->cast<abstract::AbstractSequencePtr>();
    if (!sequence_abs->dynamic_len() && sequence_abs->empty()) {
      // A statically-empty sequence is already "all zeros".
      return node;
    }
    return Emit(kSequenceZerosLikeOpName, {node});
  } else if (abs->isa<abstract::AbstractScalar>()) {
    auto value = CreateZeroScalar(abs->BuildType());
    return EmitValue(value);
  }

  MS_LOG(EXCEPTION) << "Cannot emit ZerosLike for " << node->ToString() << " with abstract " << abs;
}
289 
Fill(double value,const ShapeVector & shape,TypeId data_type)290 NodePtr Emitter::Fill(double value, const ShapeVector &shape, TypeId data_type) {
291   size_t data_num = LongToSize(std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>()));
292   std::vector<double> data(data_num, value);
293   return Tensor(data_type, shape, &data[0], TypeId::kNumberTypeFloat64);
294 }
295 
Fill(int64_t value,const ShapeVector & shape,TypeId data_type)296 NodePtr Emitter::Fill(int64_t value, const ShapeVector &shape, TypeId data_type) {
297   size_t data_num = LongToSize(std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>()));
298   std::vector<int64_t> data(data_num, value);
299   return Tensor(data_type, shape, &data[0], TypeId::kNumberTypeInt64);
300 }
301 
// Convert a node holding a compile-time scalar into a constant tensor node.
// Raises if the node's value is not a Scalar.
NodePtr Emitter::ScalarToTensor(const NodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto value = node->BuildValue();
  MS_EXCEPTION_IF_NULL(value);
  auto scalar = value->cast<ScalarPtr>();
  MS_EXCEPTION_IF_NULL(scalar);
  return EmitValue(mindspore::ScalarToTensor(scalar));
}
311 
// Convert a node holding a scalar into a constant tensor node; when the value
// is not statically known, fall back to a runtime ScalarToTensor op targeting
// `dtype`. (Removed a dead MS_EXCEPTION_IF_NULL(scalar): it was unreachable
// right after the explicit nullptr branch.)
NodePtr Emitter::ScalarToTensor(const NodePtr &node, const TypePtr &dtype) {
  MS_EXCEPTION_IF_NULL(node);
  auto value = node->BuildValue();
  MS_EXCEPTION_IF_NULL(value);
  auto scalar = value->cast<ScalarPtr>();
  if (scalar == nullptr) {
    // Not a compile-time scalar: emit the op with the requested dtype.
    return Emit("ScalarToTensor", {node, Value(static_cast<int64_t>(dtype->type_id()))});
  }
  // NOTE(review): on the constant path `dtype` is not applied — the tensor keeps
  // the scalar's own type. Presumably intentional (matches the no-dtype overload);
  // confirm against callers if a cast to `dtype` is expected here.
  auto tensor = mindspore::ScalarToTensor(scalar);
  return EmitValue(tensor);
}
324 
// Logical NOT of a boolean scalar node. A statically-known value is folded to
// a constant; otherwise a BoolNot op is emitted. Non-scalar inputs raise.
NodePtr Emitter::BoolNot(const NodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto abs = node->abstract();
  MS_EXCEPTION_IF_NULL(abs);
  NodePtr new_node{nullptr};
  if (abs->isa<abstract::AbstractScalar>()) {
    auto value_ptr = node->BuildValue();
    MS_EXCEPTION_IF_NULL(value_ptr);
    if (!(value_ptr->isa<ValueAny>() || value_ptr->isa<None>())) {
      // Constant fold: negate the known boolean value.
      auto value = GetValue<bool>(value_ptr);
      new_node = Value(static_cast<bool>(!value));
    } else {
      new_node = Emit("BoolNot", {node});
    }
  } else {
    // Fixed message typo: "BooNot only support" -> "BoolNot only supports".
    MS_LOG(EXCEPTION) << "BoolNot only supports scalar, but got " << abs->ToString();
  }
  return new_node;
}
344 
NeedReduce(const ShapeVector & shape,const std::vector<int64_t> & axis,bool keep_dim,bool skip_mode) const345 std::pair<bool, ShapeVector> Emitter::NeedReduce(const ShapeVector &shape, const std::vector<int64_t> &axis,
346                                                  bool keep_dim, bool skip_mode) const {
347   if (IsDynamic(shape)) {
348     return std::make_pair(true, shape);
349   }
350   if (shape.empty() || (skip_mode && axis.empty())) {
351     return std::make_pair(false, shape);
352   }
353   auto rank = SizeToLong(shape.size());
354   std::vector<bool> axis_map;
355   if (axis.empty()) {
356     axis_map = std::vector<bool>(shape.size(), true);
357   } else {
358     axis_map = std::vector<bool>(shape.size(), false);
359     for (size_t i = 0; i < axis.size(); ++i) {
360       if (axis[i] < -rank || axis[i] >= rank) {
361         MS_EXCEPTION(ValueError) << "Reduce axis[" << i << "] is " << axis[i] << ", which is out of range [-" << rank
362                                  << ", " << rank << ") for shape: " << shape;
363       }
364       auto axis_i = axis[i] < 0 ? axis[i] + rank : axis[i];
365       axis_map[LongToSize(axis_i)] = true;
366     }
367   }
368   // Calc reduce output shape
369   ShapeVector out_shape;
370   bool need_reduce = false;
371   for (size_t i = 0; i < shape.size(); ++i) {
372     if (!axis_map[i]) {
373       // not reduce axis
374       out_shape.push_back(shape[i]);
375     } else {
376       // reduce axis
377       if (shape[i] != 1) {
378         need_reduce = true;
379       }
380       if (keep_dim) {
381         out_shape.push_back(1);
382       }
383     }
384   }
385   return std::make_pair(need_reduce, out_shape);
386 }
387 
NeedReduce(const NodePtr & shape,const NodePtr & axis,bool keep_dim,bool skip_mode)388 std::pair<bool, NodePtr> Emitter::NeedReduce(const NodePtr &shape, const NodePtr &axis, bool keep_dim, bool skip_mode) {
389   auto [axis_success, axis_value] = GetIntList(axis);
390   if (axis_success && skip_mode && axis_value.empty()) {
391     return std::make_pair(false, shape);
392   }
393   auto [shape_success, shape_value] = GetIntList(shape);
394   if (shape_success && axis_success) {
395     auto [need_reduce, shape_vec] = NeedReduce(shape_value, axis_value, keep_dim, skip_mode);
396     return std::make_pair(need_reduce, Value(shape_vec));
397   }
398 
399   auto v = Value(ShapeVector{});
400   return std::make_pair(true, v);
401 }
402 
// Emit a ReduceSum; when the reduction is provably a no-op (only size-1 dims
// reduced), a Reshape to the result shape is emitted instead.
NodePtr Emitter::ReduceSum(const NodePtr &x, const NodePtr &axis, bool keep_dims, bool skip_mode) {
  MS_EXCEPTION_IF_NULL(x);
  MS_EXCEPTION_IF_NULL(axis);
  auto [reduce_needed, out_shape] = NeedReduce(Shape(x), axis, keep_dims, skip_mode);
  if (!reduce_needed) {
    return Reshape(x, out_shape);
  }
  return Emit(prim::kPrimReduceSum->name(), {x, TensorToTuple(axis), Value(keep_dims), Value(skip_mode)});
}
413 
// ReduceSum with a compile-time axis list. On backend builds an empty axis is
// expanded to all dimensions (when the rank is known), since an empty axis is
// ambiguous for the underlying kernel in the dynamic-rank case.
NodePtr Emitter::ReduceSum(const NodePtr &x, const ShapeVector &axis, bool keep_dims) {
  MS_EXCEPTION_IF_NULL(x);
  auto real_axis = axis;
#ifdef WITH_BACKEND
  const auto &shape = x->shape();
  if (real_axis.empty()) {
    if (IsDynamicRank(shape)) {
      // Rank unknown: cannot enumerate the axes; pass the empty axis through.
      MS_LOG(DEBUG) << "For ReduceSum, it may wrong with a empty axis for dynamic rank case.";
    } else {
      for (int64_t i = 0; i < SizeToLong(shape.size()); i++) {
        real_axis.push_back(i);
      }
    }
  }
#endif
  // skip_mode is always false on this path.
  return ReduceSum(x, Value<ShapeVector>(real_axis), keep_dims, false);
}
431 
// Emit a SumExt op, forwarding the input's own dtype as the dtype argument.
NodePtr Emitter::SumExt(const NodePtr &input, const NodePtr &axis, const NodePtr &keep_dims) {
  MS_EXCEPTION_IF_NULL(input);
  MS_EXCEPTION_IF_NULL(axis);
  MS_EXCEPTION_IF_NULL(keep_dims);
  auto dtype_node = Value(static_cast<int64_t>(input->dtype()->type_id()));
  MS_EXCEPTION_IF_NULL(dtype_node);
  return Emit("SumExt", {input, axis, keep_dims, dtype_node});
}
442 
// Emit a Gather op with axis as a node input and batch_dims as a value input.
NodePtr Emitter::Gather(const NodePtr &params, const NodePtr &indices, const NodePtr &axis, int64_t batch_dims) {
  MS_EXCEPTION_IF_NULL(params);
  MS_EXCEPTION_IF_NULL(indices);
  MS_EXCEPTION_IF_NULL(axis);
  NodePtrList gather_inputs{params, indices, axis, Value(batch_dims)};
  return Emit(kGatherOpName, gather_inputs);
}
// Convenience overload: wrap the integer axis in a value node and delegate.
NodePtr Emitter::Gather(const NodePtr &params, const NodePtr &indices, int64_t axis, int64_t batch_dims) {
  auto axis_node = Value(axis);
  return Gather(params, indices, axis_node, batch_dims);
}
452 
// Collect the constant arguments (values or shapes) of `inputs` for ShapeCalc
// const-folding. For each input, only_depend_shape[i] selects whether its
// value (false) or its shape (true) is needed. Returns:
//   all_const  - true iff every required value/shape is statically known
//                (and passes valid_func, when provided);
//   const_args - the gathered int-vectors, in input order;
//   pos_idx    - per input, the indices into const_args it produced
//                (a sequence input may contribute several entries).
// Collection stops at the first non-constant input.
std::tuple<bool, ShapeArray, std::vector<std::vector<size_t>>> GetConstInputs(
  const NodePtrList &inputs, const std::vector<bool> &only_depend_shape, const ShapeValidFunc &valid_func) {
  bool all_const = true;
  ShapeArray const_args;
  std::vector<std::vector<size_t>> pos_idx;

  for (size_t i = 0; i < inputs.size(); ++i) {
    MS_EXCEPTION_IF_NULL(inputs[i]);

    if (!only_depend_shape[i]) {
      // input[i]'s value is used
      auto [success, vec] = GetIntList(inputs[i]);
      if (!success) {
        all_const = false;
        break;
      }
      pos_idx.push_back({const_args.size()});
      const_args.push_back(vec);
    } else {
      // input[i]'s shape is used
      auto abs = inputs[i]->abstract();
      MS_EXCEPTION_IF_NULL(abs);

      if (auto sequence_abs = abs->cast<abstract::AbstractSequencePtr>(); sequence_abs != nullptr) {
        // A sequence contributes one shape per element; TryGetShapeArg appends
        // them to const_args and records their positions in pos_idx.
        auto begin_idx = const_args.size();
        auto is_const = ops::TryGetShapeArg(sequence_abs, &const_args, &pos_idx);

        if (is_const) {
          // Validate every newly-appended shape (default check: fully static).
          for (size_t j = begin_idx; j < const_args.size(); ++j) {
            is_const = valid_func ? valid_func(j, const_args[j]) : !IsDynamic(const_args[j]);
            if (!is_const) {
              break;
            }
          }
        }

        if (!is_const) {
          all_const = false;
          break;
        }
      } else {
        auto input_shape = inputs[i]->shape();
        auto input_valid = valid_func ? valid_func(i, input_shape) : !IsDynamic(input_shape);
        if (!input_valid) {
          all_const = false;
          break;
        }
        pos_idx.push_back({const_args.size()});
        const_args.push_back(input_shape);
      }
    }
  }

  return std::make_tuple(all_const, const_args, pos_idx);
}
508 
// Run a shape-calculation functor. When every required input shape/value is
// statically known, the functor is evaluated at compile time and the results
// are returned as constant nodes; otherwise a ShapeCalc op is emitted and its
// tuple output is unpacked via TupleGetItem.
// `value_depend` lists the input indices whose *values* (not shapes) the
// functor needs; `valid_func` optionally overrides the "is this shape usable"
// check (default: fully static).
NodePtrList Emitter::ShapeCalc(const ShapeCalcBaseFunctorPtr &functor, const NodePtrList &inputs,
                               const std::vector<int64_t> &value_depend, const ShapeValidFunc &valid_func) {
  std::vector<bool> only_depend_shape(inputs.size(), true);
  for (auto idx : value_depend) {
    only_depend_shape[LongToSize(idx)] = false;
  }

  bool all_const;
  ShapeArray const_args;
  std::vector<std::vector<size_t>> pos_idx;
  // Try to get all const input shapes or values, and call the shape calc function when success.
  std::tie(all_const, const_args, pos_idx) = GetConstInputs(inputs, only_depend_shape, valid_func);
  NodePtrList res;
  // all inputs are static-shape tensors,
  if (all_const) {
    // Compile-time path: evaluate the functor and wrap each result as a Value node.
    auto out = functor->Calc(const_args, pos_idx);
    res.reserve(out.size());
    (void)std::transform(out.begin(), out.end(), std::back_inserter(res),
                         [this](const ShapeVector &sh) { return Value(sh); });
    return res;
  }

  // Runtime path: emit a ShapeCalc op carrying the functor as an attribute.
  auto out = Emit(kShapeCalcOpName, inputs,
                  {{kAttrFunctor, functor},
                   {kAttrOnlyDependShape, MakeValue(only_depend_shape)},
                   {kAttrInputIsDynamicShape, MakeValue(true)}});
  MS_EXCEPTION_IF_NULL(out);
  auto abs = out->abstract();
  MS_EXCEPTION_IF_NULL(abs);
  auto tuple_abs = abs->cast<abstract::AbstractTuplePtr>();
  MS_EXCEPTION_IF_NULL(tuple_abs);
  // A tuple-of-tuples output is unpacked element-wise; a flat tuple is
  // returned as the single result node.
  if (!tuple_abs->dynamic_len() && tuple_abs->size() != 0 && tuple_abs->elements()[0]->isa<abstract::AbstractTuple>()) {
    res.reserve(tuple_abs->size());
    for (size_t i = 0; i < tuple_abs->size(); ++i) {
      res.push_back(TupleGetItem(out, i));
    }
  } else {
    res.push_back(out);
  }
  return res;
}
550 
// Normalize a tensor node into a tuple node: constant tensors are folded to a
// value tuple, dynamic tensors go through a TensorToTuple op, and nodes that
// are already tuples pass through. Anything else raises.
NodePtr Emitter::TensorToTuple(const NodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto abs = node->abstract();
  MS_EXCEPTION_IF_NULL(abs);
  if (abs->isa<abstract::AbstractTuple>()) {
    return node;
  }
  if (!abs->isa<abstract::AbstractTensor>()) {
    MS_LOG(INTERNAL_EXCEPTION) << "A Tensor or tuple is expected, but got " << abs->BuildType()->ToString() << ".";
  }
  auto [is_const, values] = GetIntList(node);
  return is_const ? EmitValue(MakeValue(values)) : Emit(kTensorToTupleOpName, {node});
}
567 
UnifyDtype2(const NodePtr & lhs,const NodePtr & rhs)568 std::tuple<NodePtr, NodePtr> Emitter::UnifyDtype2(const NodePtr &lhs, const NodePtr &rhs) {
569   auto it1 = type_vector_[lhs->dtype()->type_id()];
570   auto it2 = type_vector_[rhs->dtype()->type_id()];
571   if (!it1 || !it2 || it1 == it2) {
572     return {lhs, rhs};
573   }
574   if (it1 < it2) {
575     return {this->Cast(lhs, rhs->dtype()), rhs};
576   }
577   return {lhs, this->Cast(rhs, lhs->dtype())};
578 }
579 
// Emit the gradient form of SparseSoftmaxCrossEntropyWithLogits and scale it
// by dout. In graph mode a Depend on the forward output is inserted first
// (ties the grad op to the forward result before the multiply).
NodePtr Emitter::SparseSoftmaxCrossEntropyWithLogits(const NodePtrList &inputs, const DAttr &attrs, const NodePtr &out,
                                                     const NodePtr &dout, bool is_graph_mode) {
  auto grad = Emit("SparseSoftmaxCrossEntropyWithLogits", inputs, attrs);
  if (is_graph_mode) {
    grad = Depend(grad, out);
  }
  return Mul(grad, dout);
}
589 
// Control-flow hook; the base Emitter has no graph to build branches in,
// so derived emitters must override this.
NodePtr Emitter::Conditional(const NodePtr &cond, const BlockFunc &true_case, const BlockFunc &false_case) {
  MS_EXCEPTION(NotImplementedError) << "Base Emitter not implement Conditional() method";
}
593 
// Control-flow hook; the base Emitter has no graph to build the loop in,
// so derived emitters must override this.
NodePtr Emitter::While(const NodePtr &cond, const BlockFunc &body, const NodePtrList &init_list) {
  MS_EXCEPTION(NotImplementedError) << "Base Emitter not implement While() method";
}
597 
// Lower if/else to graph form: build both branch subgraphs, select one with a
// Switch node, then call the selected graph with an argument-less CNode.
NodePtr CtrlFlowBlock::IfThenElse(const NodePtr &cond, const BlockFunc &true_case, const BlockFunc &false_case) {
  auto true_graph = BuildSubgraph(true_case);
  auto false_graph = BuildSubgraph(false_case);
  auto switch_node = emitter_->Emit("Switch", {cond, true_graph, false_graph});

  auto call_cnode = func_graph_->FuncGraph::NewCNode({switch_node->get()});
  call_cnode->set_abstract(out_abstract_);
  return emitter_->NewIrNode(call_cnode->cast<AnfNodePtr>());
}
608 
// Lower a while-loop to graph form. A dedicated while_fg func-graph hosts the
// condition; captured outer nodes become its parameters. Two Partial closures
// are built — the loop body (which tail-calls while_fg again) and an "exit"
// body returning init_list — and a Switch between them forms the loop.
// Returns a call node in the outer graph invoking while_fg with the captured
// arguments.
NodePtr CtrlFlowBlock::While(const NodePtr &cond, const BlockFunc &while_body_func, const NodePtrList &init_list) {
  auto while_fg = std::make_shared<FuncGraph>();
  MS_EXCEPTION_IF_NULL(while_fg);
  auto cond_cnode = cond->get()->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cond_cnode);

  // Move the condition node into the loop graph so it is re-evaluated per iteration.
  cond_cnode->set_func_graph(while_fg);
  auto while_fg_emitter = CreateInnerEmitter(while_fg, std::make_shared<CppInferWithPartial>());
  MS_EXCEPTION_IF_NULL(while_fg_emitter);
  AnfNodePtrList main_while_fg_inputs = {NewValueNode(while_fg)};
  std::map<AnfNodePtr, ParameterPtr> param_map;
  // Replace a captured outer node with a (memoized) parameter of while_fg and
  // record the original node as a call argument for the outer invocation.
  auto replace_by_param = [&main_while_fg_inputs, &param_map, &while_fg](const AnfNodePtr &inp) {
    auto &param = param_map[inp];
    if (param == nullptr) {
      param = while_fg->add_parameter();
      param->set_abstract(inp->abstract());
      (void)main_while_fg_inputs.emplace_back(inp);
    }
    return param;
  };

  // The "false" branch of the Switch: loop exit, returning init_list unchanged.
  auto empty_body_func = [&init_list](Emitter *) { return init_list; };
  auto empty_body_fg_with_inputs = BuildSubgraphOfPartial(empty_body_func);
  for (size_t i = 1; i < empty_body_fg_with_inputs.size(); i++) {
    auto inp = empty_body_fg_with_inputs[i]->get();
    empty_body_fg_with_inputs[i] = while_fg_emitter->NewIrNode(replace_by_param(inp));
  }
  // Rewire the condition's non-constant inputs to while_fg parameters.
  for (size_t i = 1; i < cond_cnode->size(); i++) {
    auto inp = cond_cnode->input(i);
    MS_EXCEPTION_IF_NULL(inp);
    if (!inp->isa<ValueNode>()) {
      cond_cnode->set_input(i, replace_by_param(inp));
    }
  }

  // The "true" branch: the loop body subgraph, with its captures parameterized too.
  auto body_with_inputs = BuildSubgraphOfPartial(while_body_func);
  auto body_fg = body_with_inputs[0]->get()->cast<ValueNodePtr>()->value()->cast<FuncGraphPtr>();
  for (size_t i = 1; i < body_with_inputs.size(); i++) {
    body_with_inputs[i] = while_fg_emitter->NewIrNode(replace_by_param(body_with_inputs[i]->get()));
  }
  // replace the body's output to call the outside while-fg
  AnfNodePtrList body_while_fg_inputs{NewValueNode(while_fg)};
  if (IsPrimitiveCNode(body_fg->output(), prim::kPrimMakeTuple)) {
    auto mt = body_fg->output()->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(mt);
    (void)body_while_fg_inputs.insert(body_while_fg_inputs.end(), mt->inputs().begin() + 1, mt->inputs().end());
  } else {
    body_while_fg_inputs.push_back(body_fg->output());
  }
  if (body_while_fg_inputs.size() - 1 != init_list.size()) {
    MS_LOG(EXCEPTION) << "The while body's output size should be equal to init_list.size(), but got "
                      << (body_while_fg_inputs.size() - 1) << " vs " << init_list.size();
  }
  // while_fg may have gained extra parameters (captures from cond/exit branch);
  // forward each of them through the body's recursive call as well.
  if (body_while_fg_inputs.size() < main_while_fg_inputs.size()) {
    for (size_t i = body_while_fg_inputs.size(); i < main_while_fg_inputs.size(); i++) {
      auto inp = while_fg->parameters()[i - 1];
      auto iter = std::find_if(body_with_inputs.begin(), body_with_inputs.end(),
                               [&inp](const NodePtr &no) { return no->get() == inp; });
      if (iter != body_with_inputs.end()) {
        // The body already captures this node: reuse its corresponding parameter.
        auto param_idx = iter - body_with_inputs.begin() - 1;
        body_while_fg_inputs.push_back(body_fg->parameters()[LongToSize(param_idx)]);
      } else {
        // Otherwise thread it through the body as a new parameter.
        body_with_inputs.push_back(while_fg_emitter->NewIrNode(inp));
        auto p = body_fg->add_parameter();
        p->set_abstract(inp->abstract());
        body_while_fg_inputs.push_back(p);
      }
    }
  }
  // The body's output becomes a tail-call back into while_fg.
  auto body_call_fg = body_fg->NewCNode(body_while_fg_inputs);
  body_call_fg->set_abstract(out_abstract_);
  body_fg->set_output(body_call_fg);

  // while_fg's output: Switch(cond, Partial(body), Partial(exit)) called with no args.
  auto tb = while_fg_emitter->Emit("Partial", body_with_inputs);
  auto fb = while_fg_emitter->Emit("Partial", empty_body_fg_with_inputs);
  auto s = while_fg_emitter->Emit("Switch", {cond, tb, fb});
  auto cnode = while_fg->NewCNode({s->get()});
  cnode->set_abstract(out_abstract_);
  while_fg->set_output(cnode);

  // Call while_fg from the outer graph with the captured arguments.
  auto main_cnode = func_graph_->FuncGraph::NewCNode(main_while_fg_inputs);
  main_cnode->set_abstract(out_abstract_);
  return emitter_->NewIrNode(main_cnode);
}
693 
CreateInnerEmitter(const FuncGraphPtr & fg,const ExpanderInferPtr & infer) const694 EmitterPtr CtrlFlowBlock::CreateInnerEmitter(const FuncGraphPtr &fg, const ExpanderInferPtr &infer) const {
695   return emitter_creator_ ? emitter_creator_(fg, infer) : std::make_shared<IrEmitter>(fg, infer);
696 }
697 
// Build one branch subgraph by running `func` against a fresh inner emitter.
// Multiple outputs are packed into a tuple; all branches of the same block
// must produce the same number of outputs (tracked via output_num_). The
// first branch built also fixes out_abstract_. Returns the func-graph
// wrapped as a value node.
NodePtr CtrlFlowBlock::BuildSubgraph(const BlockFunc &func) {
  auto fg = std::make_shared<FuncGraph>();
  MS_EXCEPTION_IF_NULL(fg);
  // Mark the graph as indirectly called (invoked through Switch).
  fg->set_indirect(std::make_shared<bool>(true));
  auto e = CreateInnerEmitter(fg, emitter_->infer());
  MS_EXCEPTION_IF_NULL(e);
  auto outputs = func(e.get());
  if (outputs.empty()) {
    MS_LOG(EXCEPTION) << "The block function should not return empty list.";
  }
  if (output_num_ == 0) {
    // First branch defines the expected output arity.
    output_num_ = outputs.size();
  } else if (output_num_ != outputs.size()) {
    MS_LOG(EXCEPTION) << "The count of outputs of each block function should be equal, but got " << output_num_
                      << " vs " << outputs.size() << ".";
  }
  NodePtr output;
  if (output_num_ > 1) {
    output = e->MakeTuple(outputs);
    // Mark every tuple element as used so later passes don't eliminate them.
    SetSequenceNodeElementsUseFlags(output->get(), std::make_shared<std::vector<bool>>(output_num_, true));
  } else {
    output = outputs[0];
  }
  fg->set_output(output->get());
  if (out_abstract_ == nullptr) {
    out_abstract_ = output->abstract();
  }
  return emitter_->Value(fg);
}
727 
// Builds a sub-funcgraph from `func` and returns the node list that will form a
// "Partial" call: {funcgraph-value, captured_outer_input_1, ...}. Any node the
// block captured from an outer graph is rewritten into a fresh parameter of the
// subgraph, and the original outer node is appended to the returned list so the
// caller can pass it in positionally.
NodePtrList CtrlFlowBlock::BuildSubgraphOfPartial(const BlockFunc &func) {
  auto fg = std::make_shared<FuncGraph>();
  MS_EXCEPTION_IF_NULL(fg);
  // The subgraph is invoked indirectly (via Partial/Switch), mark it as such.
  fg->set_indirect(std::make_shared<bool>(true));
  auto sub_emitter = CreateInnerEmitter(fg, emitter_->infer());
  MS_EXCEPTION_IF_NULL(sub_emitter);
  // Run the user block; it emits nodes into `fg` and returns the block outputs.
  auto output = func(sub_emitter.get());
  if (output.empty()) {
    MS_LOG(EXCEPTION) << "The block function should not return empty list.";
  }
  // All blocks of one control-flow construct must produce the same output count;
  // the first block seen fixes output_num_.
  if (output_num_ == 0) {
    output_num_ = output.size();
  } else if (output_num_ != output.size()) {
    MS_LOG(EXCEPTION) << "The count of outputs of each block function should be equal, but got " << output_num_
                      << " vs " << output.size() << ".";
  }
  // Multiple outputs are packed into a tuple; a single output is used directly.
  fg->set_output((output_num_ > 1) ? sub_emitter->MakeTuple(output)->get() : output[0]->get());
  if (out_abstract_ == nullptr) {
    out_abstract_ = fg->output()->abstract();
  }
  if (output_num_ > 1) {
    SetSequenceNodeElementsUseFlags(fg->output(), std::make_shared<std::vector<bool>>(output_num_, true));
  }

  // replace the captured inputs to parameter
  std::function<void(const CNodePtr &)> dfs;
  std::unordered_set<AnfNodePtr> visited;
  // One parameter per distinct captured outer node, so repeated captures share it.
  std::map<AnfNodePtr, ParameterPtr> param_map;
  // Result list; element 0 is the subgraph itself as a value node.
  NodePtrList fg_with_inputs = {emitter_->Value(fg)};
  dfs = [&visited, &dfs, &fg, &param_map, &fg_with_inputs, this](const CNodePtr &node) {
    (void)visited.insert(node);
    for (size_t i = 0; i < node->size(); i++) {
      auto inp = node->input(i);
      // Nodes without a graph (e.g. value nodes) need no rewriting.
      if (inp->func_graph() == nullptr) {
        continue;
      }
      if (inp->func_graph() == fg) {
        // Same-graph CNode: recurse into it once.
        if (inp->isa<CNode>() && visited.count(inp) == 0) {
          dfs(inp->cast<CNodePtr>());
        }
      } else {
        // Captured from an outer graph: map it to a parameter of fg.
        // NOTE: operator[] default-inserts a null ParameterPtr; `param` is a
        // reference into the map, so assigning it below fills the map entry.
        auto &param = param_map[inp];
        if (param == nullptr) {
          param = fg->add_parameter();
          param->set_abstract(inp->abstract());
          (void)fg_with_inputs.emplace_back(emitter_->NewIrNode(inp));
        }
        node->set_input(i, param);
      }
    }
  };
  // Start from the return node so every reachable CNode is visited. Parameters
  // are added after set_output above, which is fine: Partial supplies them.
  dfs(fg->get_return());
  return fg_with_inputs;
}
782 
Infer(const NodePtr & node)783 void CtrlFlowBlock::CppInferWithPartial::Infer(const NodePtr &node) {
784   if (IsPrimitiveCNode(node->get(), prim::kPrimPartial) || IsPrimitiveCNode(node->get(), prim::kPrimSwitch)) {
785     return;
786   }
787   CppInfer::Infer(node);
788 }
789 
EmitOp(const PrimitivePtr & prim,const NodePtrList & inputs)790 NodePtr IrEmitter::EmitOp(const PrimitivePtr &prim, const NodePtrList &inputs) {
791   AnfNodePtrList cnode_inputs = {NewValueNode(prim)};
792   cnode_inputs.reserve(inputs.size() + 1);
793   (void)std::transform(inputs.cbegin(), inputs.cend(), std::back_inserter(cnode_inputs), [](const NodePtr &no) {
794     MS_EXCEPTION_IF_NULL(no);
795     return no->get();
796   });
797   auto cnode = func_graph_->NewCNode(cnode_inputs);
798   if (scope_ != nullptr) {
799     cnode->set_scope(scope_);
800   }
801   auto node = NewIrNode(cnode->cast<AnfNodePtr>());
802   infer_->Infer(node);
803   return node;
804 }
805 
EmitValue(const ValuePtr & value)806 NodePtr IrEmitter::EmitValue(const ValuePtr &value) {
807   auto node = NewIrNode(NewValueNode(value));
808   infer_->Infer(node);
809   return node;
810 }
811 
// Arithmetic sugar for NodePtr: each operator forwards to the emitter that
// owns the left-hand (or sole) operand.
NodePtr operator+(const NodePtr &lhs, const NodePtr &rhs) {
  return lhs->emitter()->Add(lhs, rhs);
}

NodePtr operator-(const NodePtr &lhs, const NodePtr &rhs) {
  return lhs->emitter()->Sub(lhs, rhs);
}

NodePtr operator*(const NodePtr &lhs, const NodePtr &rhs) {
  return lhs->emitter()->Mul(lhs, rhs);
}

NodePtr operator/(const NodePtr &lhs, const NodePtr &rhs) {
  return lhs->emitter()->RealDiv(lhs, rhs);
}

NodePtr operator-(const NodePtr &node) {
  return node->emitter()->Neg(node);
}
817 }  // namespace expander
818 }  // namespace mindspore
819