• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2022-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_CCSRC_COMMON_EXPANDER_CORE_EMITTER_H_
18 #define MINDSPORE_CCSRC_COMMON_EXPANDER_CORE_EMITTER_H_
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <unordered_set>
#include <utility>
#include <vector>
#include "include/common/expander/core/infer.h"
#include "include/common/expander/core/node.h"
#include "include/common/utils/utils.h"
#include "ir/func_graph.h"
#include "ir/functor.h"
#include "ops/array_op_name.h"
#include "ops/comparison_op_name.h"
#include "ops/framework_op_name.h"
#include "ops/arithmetic_op_name.h"
#include "ops/math_ops.h"
#include "ops/sequence_ops.h"
#include "ops/shape_calc.h"
#include "ops/auto_generate/gen_ops_name.h"
39 
40 namespace mindspore {
41 namespace expander {
42 using ShapeValidFunc = std::function<bool(size_t, const ShapeVector &)>;
43 
44 class COMMON_EXPORT Emitter {
45  public:
infer_(infer)46   explicit Emitter(const ExpanderInferPtr &infer, const ScopePtr &scope = nullptr) : infer_(infer), scope_(scope) {}
47   virtual ~Emitter() = default;
48 
49   /// \brief Emit a primitive CNode
50   NodePtr Emit(const std::string &op_name, const NodePtrList &inputs, const DAttr &attrs = {});
51   PrimitivePtr NewPrimitive(const std::string &name, const DAttr &attrs = {});
52 
53   /// \brief Emit a ValueNode
54   virtual NodePtr EmitValue(const ValuePtr &value);
55 
NewIrNode(const AnfNodePtr & anfnode)56   NodePtr NewIrNode(const AnfNodePtr &anfnode) { return std::make_shared<IrNode>(anfnode, this); }
NewFuncNode(const ValuePtr & value,const abstract::AbstractBasePtr & abs,InputType input_type)57   FuncNodePtr NewFuncNode(const ValuePtr &value, const abstract::AbstractBasePtr &abs, InputType input_type) {
58     return std::make_shared<FuncNode>(value, abs, input_type, this);
59   }
MakeTuple(const NodePtrList & inputs)60   virtual NodePtr MakeTuple(const NodePtrList &inputs) { return EmitOp(prim::kPrimMakeTuple, inputs); }
MakeList(const NodePtrList & inputs)61   virtual NodePtr MakeList(const NodePtrList &inputs) { return EmitOp(prim::kPrimMakeList, inputs); }
TupleGetItem(const NodePtr & input,size_t i)62   virtual NodePtr TupleGetItem(const NodePtr &input, size_t i) {
63     return Emit(mindspore::kTupleGetItemOpName, {input, Value(static_cast<int64_t>(i))});
64   }
TupleGetItem(const NodePtr & input,const NodePtr & i)65   virtual NodePtr TupleGetItem(const NodePtr &input, const NodePtr &i) { return Emit(kTupleGetItemOpName, {input, i}); }
Len(const NodePtr & input)66   NodePtr Len(const NodePtr &input) { return Emit(kSequenceLenOpName, {input}); }
ScalarAdd(const NodePtr & lhs,const NodePtr & rhs)67   NodePtr ScalarAdd(const NodePtr &lhs, const NodePtr &rhs) { return Emit(ops::kNameScalarAdd, {lhs, rhs}); }
ScalarSub(const NodePtr & lhs,const NodePtr & rhs)68   NodePtr ScalarSub(const NodePtr &lhs, const NodePtr &rhs) { return Emit(ops::kNameScalarSub, {lhs, rhs}); }
ScalarMul(const NodePtr & lhs,const NodePtr & rhs)69   NodePtr ScalarMul(const NodePtr &lhs, const NodePtr &rhs) { return Emit(ops::kNameScalarMul, {lhs, rhs}); }
ScalarDiv(const NodePtr & lhs,const NodePtr & rhs)70   NodePtr ScalarDiv(const NodePtr &lhs, const NodePtr &rhs) { return Emit(ops::kNameScalarDiv, {lhs, rhs}); }
ScalarFloorDiv(const NodePtr & lhs,const NodePtr & rhs)71   NodePtr ScalarFloorDiv(const NodePtr &lhs, const NodePtr &rhs) { return Emit(ops::kNameScalarFloorDiv, {lhs, rhs}); }
ScalarNeg(const NodePtr & node)72   NodePtr ScalarNeg(const NodePtr &node) { return Emit(ops::kNameScalarUsub, {node}); }
73   NodePtr Cast(const NodePtr &node, const TypePtr &type);
Cast(const NodePtr & node,TypeId type_id)74   NodePtr Cast(const NodePtr &node, TypeId type_id) { return Cast(node, TypeIdToType(type_id)); }
75 
76   NodePtr Reshape(const NodePtr &node, const NodePtr &shape);
Reshape(const NodePtr & node,const ShapeVector & shape)77   NodePtr Reshape(const NodePtr &node, const ShapeVector &shape) { return Reshape(node, Value(shape)); }
ExpandDims(const NodePtr & node,int64_t axis)78   NodePtr ExpandDims(const NodePtr &node, int64_t axis) { return Emit(kExpandDimsOpName, {node, Value(axis)}); }
Abs(const NodePtr & node)79   NodePtr Abs(const NodePtr &node) { return Emit(mindspore::kAbsOpName, {node}); }
Neg(const NodePtr & node)80   NodePtr Neg(const NodePtr &node) { return Emit(mindspore::kNegOpName, {node}); }
Reciprocal(const NodePtr & node)81   NodePtr Reciprocal(const NodePtr &node) { return Emit(mindspore::kReciprocalOpName, {node}); }
Square(const NodePtr & node)82   NodePtr Square(const NodePtr &node) { return Emit(mindspore::kSquareOpName, {node}); }
Sign(const NodePtr & node)83   NodePtr Sign(const NodePtr &node) { return Emit(prim::kPrimSign->name(), {node}); }
84   NodePtr Exp(const NodePtr &x);
85   NodePtr Log(const NodePtr &x);
86   NodePtr Transpose(const NodePtr &node, const NodePtr &perm);
Transpose(const NodePtr & node,const ShapeVector & perm)87   NodePtr Transpose(const NodePtr &node, const ShapeVector &perm) { return Transpose(node, Value(perm)); }
88   NodePtr Tile(const NodePtr &node, const NodePtr &dims);
Tile(const NodePtr & node,const ShapeVector & dims)89   NodePtr Tile(const NodePtr &node, const ShapeVector &dims) { return Tile(node, Value(dims)); }
Concat(const NodePtr & input,int64_t axis)90   NodePtr Concat(const NodePtr &input, int64_t axis) { return Emit(kConcatOpName, {input, Value(axis)}); }
Concat(const NodePtrList & inputs,int64_t axis)91   NodePtr Concat(const NodePtrList &inputs, int64_t axis) {
92     return Emit(kConcatOpName, {MakeTuple(inputs), Value(axis)});
93   }
Add(const NodePtr & lhs,const NodePtr & rhs)94   NodePtr Add(const NodePtr &lhs, const NodePtr &rhs) { return UnifyDtypeAndEmit(mindspore::kAddOpName, lhs, rhs); }
Sub(const NodePtr & lhs,const NodePtr & rhs)95   NodePtr Sub(const NodePtr &lhs, const NodePtr &rhs) { return UnifyDtypeAndEmit(mindspore::kSubOpName, lhs, rhs); }
Mul(const NodePtr & lhs,const NodePtr & rhs)96   NodePtr Mul(const NodePtr &lhs, const NodePtr &rhs) { return UnifyDtypeAndEmit(mindspore::kMulOpName, lhs, rhs); }
Div(const NodePtr & lhs,const NodePtr & rhs)97   NodePtr Div(const NodePtr &lhs, const NodePtr &rhs) { return UnifyDtypeAndEmit(kDivOpName, lhs, rhs); }
RealDiv(const NodePtr & lhs,const NodePtr & rhs)98   NodePtr RealDiv(const NodePtr &lhs, const NodePtr &rhs) {
99     return UnifyDtypeAndEmit(mindspore::kRealDivOpName, lhs, rhs);
100   }
Mod(const NodePtr & lhs,const NodePtr & rhs)101   NodePtr Mod(const NodePtr &lhs, const NodePtr &rhs) { return UnifyDtypeAndEmit("Mod", lhs, rhs); }
Pow(const NodePtr & lhs,const NodePtr & rhs)102   NodePtr Pow(const NodePtr &lhs, const NodePtr &rhs) { return UnifyDtypeAndEmit(kPowOpName, lhs, rhs); }
103   NodePtr MatMul(const NodePtr &a, const NodePtr &b, bool transpose_a = false, bool transpose_b = false);
104   NodePtr MatMulExt(const NodePtr &a, const NodePtr &b);
105   NodePtr BatchMatMul(const NodePtr &a, const NodePtr &b, bool transpose_a = false, bool transpose_b = false);
Maximum(const NodePtr & lhs,const NodePtr & rhs)106   NodePtr Maximum(const NodePtr &lhs, const NodePtr &rhs) { return UnifyDtypeAndEmit(kMaximumOpName, lhs, rhs); }
Minimum(const NodePtr & lhs,const NodePtr & rhs)107   NodePtr Minimum(const NodePtr &lhs, const NodePtr &rhs) { return UnifyDtypeAndEmit(kMinimumOpName, lhs, rhs); }
FloorDiv(const NodePtr & lhs,const NodePtr & rhs)108   NodePtr FloorDiv(const NodePtr &lhs, const NodePtr &rhs) { return UnifyDtypeAndEmit("FloorDiv", lhs, rhs); }
FloorMod(const NodePtr & lhs,const NodePtr & rhs)109   NodePtr FloorMod(const NodePtr &lhs, const NodePtr &rhs) { return UnifyDtypeAndEmit("FloorMod", lhs, rhs); }
DivNoNan(const NodePtr & lhs,const NodePtr & rhs)110   NodePtr DivNoNan(const NodePtr &lhs, const NodePtr &rhs) { return UnifyDtypeAndEmit("DivNoNan", lhs, rhs); }
MulNoNan(const NodePtr & lhs,const NodePtr & rhs)111   NodePtr MulNoNan(const NodePtr &lhs, const NodePtr &rhs) { return UnifyDtypeAndEmit("MulNoNan", lhs, rhs); }
Xdivy(const NodePtr & lhs,const NodePtr & rhs)112   NodePtr Xdivy(const NodePtr &lhs, const NodePtr &rhs) { return UnifyDtypeAndEmit("Xdivy", lhs, rhs); }
Xlogy(const NodePtr & lhs,const NodePtr & rhs)113   NodePtr Xlogy(const NodePtr &lhs, const NodePtr &rhs) { return UnifyDtypeAndEmit("Xlogy", lhs, rhs); }
114 
Select(const NodePtr & cond,const NodePtr & lhs,const NodePtr & rhs)115   NodePtr Select(const NodePtr &cond, const NodePtr &lhs, const NodePtr &rhs) {
116     auto [a, b] = UnifyDtype2(lhs, rhs);
117     return Emit(kSelectOpName, {cond, a, b});
118   }
119   NodePtr Less(const NodePtr &lhs, const NodePtr &rhs, const TypePtr &dst_type = nullptr) {
120     return CmpOpWithCast(kLessOpName, lhs, rhs, dst_type);
121   }
122   NodePtr LessEqual(const NodePtr &lhs, const NodePtr &rhs, const TypePtr &dst_type = nullptr) {
123     return CmpOpWithCast(kLessEqualOpName, lhs, rhs, dst_type);
124   }
125   NodePtr Greater(const NodePtr &lhs, const NodePtr &rhs, const TypePtr &dst_type = nullptr) {
126     return CmpOpWithCast(kGreaterOpName, lhs, rhs, dst_type);
127   }
128   NodePtr GreaterEqual(const NodePtr &lhs, const NodePtr &rhs, const TypePtr &dst_type = nullptr) {
129     return CmpOpWithCast(kGreaterEqualOpName, lhs, rhs, dst_type);
130   }
131   NodePtr Equal(const NodePtr &lhs, const NodePtr &rhs, const TypePtr &dst_type = nullptr) {
132     auto abs = lhs->abstract();
133     MS_EXCEPTION_IF_NULL(abs);
134     if (abs->isa<abstract::AbstractTensor>()) {
135       return CmpOpWithCast(kEqualOpName, lhs, rhs, dst_type);
136     } else if (abs->isa<abstract::AbstractScalar>()) {
137       return ScalarEq(lhs, rhs, dst_type);
138     }
139     MS_LOG(EXCEPTION) << "'Equal' only support [Tensor] or [Scalar] input, but got: " << abs->ToString();
140   }
ScalarEq(const NodePtr & lhs,const NodePtr & rhs,const TypePtr & dst_type)141   virtual NodePtr ScalarEq(const NodePtr &lhs, const NodePtr &rhs, const TypePtr &dst_type) {
142     auto node = UnifyDtypeAndEmit("ScalarEq", lhs, rhs);
143     return dst_type == nullptr ? node : Cast(node, dst_type);
144   }
145   NodePtr NotEqual(const NodePtr &lhs, const NodePtr &rhs, const TypePtr &dst_type = nullptr) {
146     return CmpOpWithCast("NotEqual", lhs, rhs, dst_type);
147   }
LogicalAnd(const NodePtr & lhs,const NodePtr & rhs)148   NodePtr LogicalAnd(const NodePtr &lhs, const NodePtr &rhs) { return Emit("LogicalAnd", {lhs, rhs}); }
LogicalOr(const NodePtr & lhs,const NodePtr & rhs)149   NodePtr LogicalOr(const NodePtr &lhs, const NodePtr &rhs) { return Emit("LogicalOr", {lhs, rhs}); }
LogicalNot(const NodePtr & x)150   NodePtr LogicalNot(const NodePtr &x) { return Emit("LogicalNot", {x}); }
151 
152   NodePtr BoolNot(const NodePtr &node);
153 
OnesLike(const NodePtr & x)154   NodePtr OnesLike(const NodePtr &x) { return Emit("OnesLike", {x}); }
UnsortedSegmentSum(const NodePtr & x,const NodePtr & segment_ids,const NodePtr & num_segments)155   NodePtr UnsortedSegmentSum(const NodePtr &x, const NodePtr &segment_ids, const NodePtr &num_segments) {
156     return Emit("UnsortedSegmentSum", {x, segment_ids, num_segments});
157   }
GatherNd(const NodePtr & input_x,const NodePtr & indices)158   NodePtr GatherNd(const NodePtr &input_x, const NodePtr &indices) { return Emit("GatherNd", {input_x, indices}); }
ScatterNd(const NodePtr & indices,const NodePtr & update,const NodePtr & shape)159   NodePtr ScatterNd(const NodePtr &indices, const NodePtr &update, const NodePtr &shape) {
160     return Emit("ScatterNd", {indices, update, shape});
161   }
Stack(const NodePtr & x,const ValuePtr & axis)162   virtual NodePtr Stack(const NodePtr &x, const ValuePtr &axis) { return Emit("Stack", {x}, {{"axis", axis}}); }
Stack(const NodePtrList & x,int64_t axis)163   virtual NodePtr Stack(const NodePtrList &x, int64_t axis) { return Stack(MakeTuple(x), MakeValue(axis)); }
TensorScatterUpdate(const NodePtr & input_x,const NodePtr & indices,const NodePtr & updates)164   NodePtr TensorScatterUpdate(const NodePtr &input_x, const NodePtr &indices, const NodePtr &updates) {
165     return Emit("TensorScatterUpdate", {input_x, indices, updates});
166   }
Slice(const NodePtr & x,const NodePtr & begin,const NodePtr & size)167   NodePtr Slice(const NodePtr &x, const NodePtr &begin, const NodePtr &size) { return Emit("Slice", {x, begin, size}); }
Squeeze(const NodePtr & x,const ValuePtr & axis)168   NodePtr Squeeze(const NodePtr &x, const ValuePtr &axis) { return Emit("Squeeze", {x}, {{"axis", axis}}); }
Sqrt(const NodePtr & x)169   NodePtr Sqrt(const NodePtr &x) { return Emit("Sqrt", {x}); }
MatrixSetDiagV3(const NodePtr & x,const NodePtr & diagonal,const NodePtr & k,const ValuePtr & align)170   NodePtr MatrixSetDiagV3(const NodePtr &x, const NodePtr &diagonal, const NodePtr &k, const ValuePtr &align) {
171     const auto diag_max_length = 200000000;
172     return Emit("MatrixSetDiagV3", {x, diagonal, k},
173                 {{"max_length", MakeValue<int64_t>(diag_max_length)}, {"align", align}});
174   }
MatrixDiagPartV3(const NodePtr & x,const NodePtr & diagonal,const NodePtr & k,const ValuePtr & align)175   NodePtr MatrixDiagPartV3(const NodePtr &x, const NodePtr &diagonal, const NodePtr &k, const ValuePtr &align) {
176     const auto diag_max_length = 200000000;
177     return Emit("MatrixDiagPartV3", {x, diagonal, k},
178                 {{"max_length", MakeValue<int64_t>(diag_max_length)}, {"align", align}});
179   }
LinSpace(const NodePtr & start,const NodePtr & stop,const NodePtr & num)180   NodePtr LinSpace(const NodePtr &start, const NodePtr &stop, const NodePtr &num) {
181     return Emit("LinSpace", {start, stop, num});
182   }
183 
184   // complex
Conj(const NodePtr & input)185   NodePtr Conj(const NodePtr &input) {
186     TypeId type_id = input->dtype()->type_id();
187     if (type_id == kNumberTypeComplex64 || type_id == kNumberTypeComplex128) {
188       return Emit("Conj", {input});
189     }
190     return input;
191   }
Complex(const NodePtr & real,const NodePtr & imag)192   NodePtr Complex(const NodePtr &real, const NodePtr &imag) { return Emit("Complex", {real, imag}); }
Real(const NodePtr & x)193   NodePtr Real(const NodePtr &x) { return Emit(kRealOpName, {x}); }
Imag(const NodePtr & x)194   NodePtr Imag(const NodePtr &x) { return Emit(kImagOpName, {x}); }
195 
CumProd(const NodePtr & x,const NodePtr & axis,const NodePtr & exclusive,const NodePtr & reverse)196   NodePtr CumProd(const NodePtr &x, const NodePtr &axis, const NodePtr &exclusive, const NodePtr &reverse) {
197     return Emit("CumProd", {x, axis, exclusive, reverse});
198   }
CumProd(const NodePtr & x,const NodePtr & axis,const bool & exclusive,const bool & reverse)199   NodePtr CumProd(const NodePtr &x, const NodePtr &axis, const bool &exclusive, const bool &reverse) {
200     return CumProd(x, axis, Value(exclusive), Value(reverse));
201   }
CumSum(const NodePtr & x,const NodePtr & axis,const NodePtr & exclusive,const NodePtr & reverse)202   NodePtr CumSum(const NodePtr &x, const NodePtr &axis, const NodePtr &exclusive, const NodePtr &reverse) {
203     return Emit("CumSum", {x, axis, exclusive, reverse});
204   }
CumSum(const NodePtr & x,const NodePtr & axis,const bool & exclusive,const bool & reverse)205   NodePtr CumSum(const NodePtr &x, const NodePtr &axis, const bool &exclusive, const bool &reverse) {
206     return CumSum(x, axis, Value(exclusive), Value(reverse));
207   }
CSR2COO(const NodePtr & indptr,const NodePtr & nnz)208   NodePtr CSR2COO(const NodePtr &indptr, const NodePtr &nnz) { return Emit("CSR2COO", {indptr, nnz}); }
209   NodePtr ScalarToTensor(const NodePtr &node);
210   NodePtr ScalarToTensor(const NodePtr &node, const TypePtr &dtype);
211   std::pair<bool, ShapeVector> NeedReduce(const ShapeVector &shape, const std::vector<int64_t> &axis, bool keep_dim,
212                                           bool skip_mode = false) const;
213   std::pair<bool, NodePtr> NeedReduce(const NodePtr &shape, const NodePtr &axis, bool keep_dim, bool skip_mode = false);
214   NodePtr ReduceSum(const NodePtr &x, const NodePtr &axis, bool keep_dims = false, bool skip_mode = false);
215   NodePtr ReduceSum(const NodePtr &x, const ShapeVector &axis = {}, bool keep_dims = false);
216   NodePtr SumExt(const NodePtr &input, const NodePtr &axis, const NodePtr &keep_dims);
217   NodePtr BroadcastTo(const NodePtr &x, const NodePtr &y);
218 
219   NodePtr ZerosLike(const NodePtr &node);
Depend(const NodePtr & value,const NodePtr & expr)220   virtual NodePtr Depend(const NodePtr &value, const NodePtr &expr) {
221     return Emit("Depend", {value, expr}, {{"side_effect_propagate", MakeValue(1)}});
222   }
223   NodePtr Fill(double value, const ShapeVector &shape, TypeId data_type);
224   NodePtr Fill(int64_t value, const ShapeVector &shape, TypeId data_type);
225   template <typename T>
Fill(const T & value,const NodePtr & shape,TypeId data_type)226   NodePtr Fill(const T &value, const NodePtr &shape, TypeId data_type) {
227     MS_EXCEPTION_IF_NULL(shape);
228     if (shape->input_type() == InputType::kConstant) {
229       auto v = shape->BuildValue();
230       MS_EXCEPTION_IF_NULL(v);
231       return Fill(value, GetValue<ShapeVector>(v), data_type);
232     }
233     auto value_tensor = Cast(Tensor(value), data_type);
234     return Emit("DynamicBroadcastTo", {value_tensor, shape});
235   }
236 
237   NodePtr Shape(const NodePtr &node, bool tensor = false) {
238     auto shape = node->shape();
239     if (tensor) {
240       return IsDynamic(shape) ? Emit("TensorShape", {node}) : Tensor(shape);
241     } else {
242       return IsDynamic(shape) ? Emit("Shape", {node}) : Value<ShapeVector>(shape);
243     }
244   }
245 
246   NodePtr Gather(const NodePtr &params, const NodePtr &indices, int64_t axis, int64_t batch_dims = 0);
247   NodePtr Gather(const NodePtr &params, const NodePtr &indices, const NodePtr &axis, int64_t batch_dims = 0);
GatherD(const NodePtr & x,const NodePtr & dim,const NodePtr & index)248   NodePtr GatherD(const NodePtr &x, const NodePtr &dim, const NodePtr &index) {
249     return Emit("GatherD", {x, dim, index});
250   }
BatchNormGrad(const NodePtrList & inputs,bool is_scale_or_bias_grad)251   virtual NodePtr BatchNormGrad(const NodePtrList &inputs, bool is_scale_or_bias_grad) {
252     return Emit("BatchNormGrad", inputs);
253   }
254   virtual NodePtr SparseSoftmaxCrossEntropyWithLogits(const NodePtrList &inputs, const DAttr &attrs, const NodePtr &out,
255                                                       const NodePtr &dout, bool is_graph_mode);
256 
257   // By comparing x with itself, test whether x is NaN
IsNanFunc(const NodePtr & x)258   inline NodePtr IsNanFunc(const NodePtr &x) { return NotEqual(x, x); }
259 
Zeros(const NodePtr & x)260   NodePtr Zeros(const NodePtr &x) {
261     auto x_shape = x->shape();
262     if (!x_shape.empty() && !IsDynamicRank(x_shape)) {
263       // There are currently some problems under 0d that need to be fixed later.
264       return Emit("Zeros", {Shape(x), Value<int64_t>(x->dtype()->type_id())});
265     }
266     return ZerosLike(x);
267   }
268 
269   /// \brief Emit a value node
270   template <typename T>
Value(const T & value)271   NodePtr Value(const T &value) {
272     return EmitValue(MakeValue(value));
273   }
274 
275   /// \brief Emit a Tensor node.
276   template <typename T>
277   NodePtr Tensor(T data, TypePtr type_ptr = nullptr) {
278     auto tensor_ptr = std::make_shared<tensor::Tensor>(data, type_ptr);
279     return EmitValue(tensor_ptr);
280   }
281 
282   /// \brief Emit a tensor node.
Tensor(TypeId data_type,const ShapeVector & shape,void * data,TypeId src_data_type)283   NodePtr Tensor(TypeId data_type, const ShapeVector &shape, void *data, TypeId src_data_type) {
284     auto tensor_ptr = std::make_shared<tensor::Tensor>(data_type, shape, data, src_data_type);
285     return EmitValue(tensor_ptr);
286   }
287 
288   /// \brief get the ExpanderInferPtr
infer()289   const ExpanderInferPtr &infer() const { return infer_; }
290 
291   /// \brief Shape calculation. This interface is used to unify the code between static-shape and dynamic-shape
292   /// situation, the output type is depend on types of inputs.
293   ///
294   /// \param[in] functor The ShapeCalcBaseFunctor object.
295   /// \param[in] inputs The input tensors.
296   /// \param[in] value_depend If index i exists in 'value_depend', the value of inputs[i] is sent to 'functor'.
297   ///                         otherwise the shape of inputs[i] is sent.
298   /// \param[in] valid_func The function to check whether the index and input shape is valid.
299   /// \return NodePtrList, the outputs shape list. When inputs are all static-shape tensors, shape vectors are returned.
300   /// otherwise CNode tensors are returned.
301   NodePtrList ShapeCalc(const ShapeCalcBaseFunctorPtr &functor, const NodePtrList &inputs,
302                         const std::vector<int64_t> &value_depend = {}, const ShapeValidFunc &valid_func = nullptr);
303 
304   /// \brief Emit a TensorToTuple node.
305   NodePtr TensorToTuple(const NodePtr &node);
306 
307   using BlockFunc = std::function<NodePtrList(Emitter *)>;
308   /// \brief Generate a conditional block.
309   ///
310   /// \param[in] cond condition node, it should be a tensor of Bool.
311   /// \param[in] true_case  the true branch.
312   /// \param[in] false_case the false branch.
313   /// \return node of tuple or single value, which is depends on the output list of two branches.
314   /// \note The overloaded operators (like a+b) should not be used for captured variables in the true_case/false_case
315   /// functions, use the function argument `Emitter` instead, like `emitter->Add(a, b)`. The output list of two branches
316   /// should match the join rules of control flow.
317   virtual NodePtr Conditional(const NodePtr &cond, const BlockFunc &true_case, const BlockFunc &false_case);
318 
319   /// \brief Generate a while-loop block.
320   ///
321   /// \param[in] cond condition node, it should be a tensor of Bool.
322   /// \param[in] body  the loop body.
323   /// \param[in] init_list the initial variables that would be modified in body.
324   /// \return node of tuple or single value, which is depends on the init_list.
325   /// \note The overloaded operators (like `a+b`) should not be used for captured variables in the body function, use
326   /// the function argument `Emitter` instead, like `emitter->Add(a, b)`. The length and node order of the output list
327   /// of the body function should match init_list.
328   virtual NodePtr While(const NodePtr &cond, const BlockFunc &body, const NodePtrList &init_list);
329 
330  protected:
331   virtual NodePtr EmitOp(const PrimitivePtr &prim, const NodePtrList &inputs);
CmpOpWithCast(const std::string & op,const NodePtr & lhs,const NodePtr & rhs,const TypePtr & dst_type)332   NodePtr CmpOpWithCast(const std::string &op, const NodePtr &lhs, const NodePtr &rhs, const TypePtr &dst_type) {
333     auto node = UnifyDtypeAndEmit(op, lhs, rhs);
334     return dst_type == nullptr ? node : Cast(node, dst_type);
335   }
336   std::tuple<NodePtr, NodePtr> UnifyDtype2(const NodePtr &lhs, const NodePtr &rhs);
337   NodePtr UnifyDtypeAndEmit(const std::string &op, const NodePtr &a, const NodePtr &b, const DAttr &attrs = {}) {
338     auto [lhs, rhs] = UnifyDtype2(a, b);
339     return Emit(op, {lhs, rhs}, attrs);
340   }
341 
342   ExpanderInferPtr infer_{nullptr};
343   ScopePtr scope_{nullptr};
344   inline static const std::vector<size_t> type_vector_ = [] {
345     std::vector<size_t> type_vector(kSparseTypeEnd + 1);
346     type_vector[kNumberTypeBool] = 1;
347     type_vector[kNumberTypeInt8] = 2;
348     type_vector[kNumberTypeUInt8] = 3;
349     type_vector[kNumberTypeInt16] = 4;
350     type_vector[kNumberTypeInt32] = 5;
351     type_vector[kNumberTypeInt64] = 6;
352     type_vector[kNumberTypeFloat16] = 7;
353     type_vector[kNumberTypeFloat32] = 8;
354     type_vector[kNumberTypeFloat64] = 9;
355     return type_vector;
356   }();
primc_func_cache()357   static HashMap<std::string, ops::OpPrimCDefineFunc> &primc_func_cache() {
358     static HashMap<std::string, ops::OpPrimCDefineFunc> cache{};
359     return cache;
360   }
361 };
362 using EmitterPtr = std::shared_ptr<Emitter>;
363 
364 COMMON_EXPORT NodePtr operator+(const NodePtr &lhs, const NodePtr &rhs);
365 COMMON_EXPORT NodePtr operator-(const NodePtr &lhs, const NodePtr &rhs);
366 COMMON_EXPORT NodePtr operator*(const NodePtr &lhs, const NodePtr &rhs);
367 COMMON_EXPORT NodePtr operator/(const NodePtr &lhs, const NodePtr &rhs);
368 COMMON_EXPORT NodePtr operator-(const NodePtr &node);
369 
370 class COMMON_EXPORT CtrlFlowBlock {
371  public:
372   using BlockFunc = std::function<NodePtrList(Emitter *)>;
373   using EmitterCreator = std::function<EmitterPtr(const FuncGraphPtr &, const ExpanderInferPtr &)>;
374   CtrlFlowBlock(Emitter *emitter, const FuncGraphPtr &func_graph, const EmitterCreator &ec = nullptr)
emitter_(emitter)375       : emitter_(emitter), func_graph_(func_graph), emitter_creator_(ec) {
376     MS_EXCEPTION_IF_NULL(emitter);
377     MS_EXCEPTION_IF_NULL(func_graph);
378   }
379   ~CtrlFlowBlock() = default;
380   NodePtr IfThenElse(const NodePtr &cond, const BlockFunc &true_case, const BlockFunc &false_case);
381 
382   NodePtr While(const NodePtr &cond, const BlockFunc &while_body_func, const NodePtrList &init_list);
383 
384  protected:
385   EmitterPtr CreateInnerEmitter(const FuncGraphPtr &fg, const ExpanderInferPtr &infer) const;
386   NodePtr BuildSubgraph(const BlockFunc &func);
387   NodePtrList BuildSubgraphOfPartial(const BlockFunc &func);
388 
389   Emitter *emitter_;
390   FuncGraphPtr func_graph_;
391   EmitterCreator emitter_creator_;
392   size_t output_num_{0};
393   abstract::AbstractBasePtr out_abstract_{nullptr};
394 
395   class CppInferWithPartial : public CppInfer {
396    public:
397     void Infer(const NodePtr &node) override;
398   };
399 };
400 
401 class COMMON_EXPORT IrEmitter : public Emitter {
402  public:
403   IrEmitter(const FuncGraphPtr &func_graph, const ExpanderInferPtr &infer, const ScopePtr &scope = nullptr)
Emitter(infer,scope)404       : Emitter(infer, scope), func_graph_(func_graph) {
405     MS_EXCEPTION_IF_NULL(func_graph);
406     MS_EXCEPTION_IF_NULL(infer);
407   }
408   NodePtr EmitValue(const ValuePtr &value) override;
func_graph()409   FuncGraphPtr func_graph() { return func_graph_; }
410 
411  protected:
412   NodePtr EmitOp(const PrimitivePtr &prim, const NodePtrList &inputs) override;
413   FuncGraphPtr func_graph_;
414 };
415 
416 class PureShapeCalc : public ShapeCalcBaseFunctor {
417  public:
418   // CalcFunc/InferFunc/CalcWithTupleFunc/InferWithTupleFunc are defined as pure function pointer other than a
419   // std::function, meaning that they should be a lambda function without any capture.
420   using CalcFunc = ShapeArray (*)(const ShapeArray &);
421   using InferFunc = std::vector<int64_t> (*)(const ShapeArray &, const HashSet<size_t> &);
422   using CalcWithTupleFunc = ShapeArray (*)(const ShapeArray &, const ElemPosIdx &);
423   using InferWithTupleFunc = InferOutputInfo (*)(const ShapeArray &, const HashSet<size_t> &, const ElemPosIdx &);
424 
PureShapeCalc(const std::string & name)425   explicit PureShapeCalc(const std::string &name) : ShapeCalcBaseFunctor(name) {
426     FunctorRegistry::Instance().Register(name, [this]() { return shared_from_base<Functor>(); });
427   }
428 
429   PureShapeCalc(const PureShapeCalc &) = delete;
430   PureShapeCalc(PureShapeCalc &&) = delete;
431   PureShapeCalc &operator=(const PureShapeCalc &) = delete;
432   PureShapeCalc &operator=(PureShapeCalc &&) = delete;
433   ~PureShapeCalc() override = default;
MS_DECLARE_PARENT(PureShapeCalc,ShapeCalcBaseFunctor)434   MS_DECLARE_PARENT(PureShapeCalc, ShapeCalcBaseFunctor)
435 
436   ValuePtr ToValue() const override { return nullptr; }
FromValue(const ValuePtr &)437   void FromValue(const ValuePtr &) override {}
438 
Calc(const ShapeArray & inputs,const ElemPosIdx & pos_idx)439   ShapeArray Calc(const ShapeArray &inputs, const ElemPosIdx &pos_idx) const override {
440     ShapeArray calc_res;
441     if (calc_func_ != nullptr) {
442       calc_res = calc_func_(inputs);
443     } else if (cal_with_tuple_func_ != nullptr) {
444       calc_res = cal_with_tuple_func_(inputs, pos_idx);
445     } else {
446       MS_LOG(EXCEPTION) << "The calc_func of " << name() << " is nullptr";
447     }
448 
449     return calc_res;
450   }
451 
Infer(const ShapeArray & inputs,const HashSet<size_t> & unknown_inputs,const ElemPosIdx & pos_idx)452   InferOutputInfo Infer(const ShapeArray &inputs, const HashSet<size_t> &unknown_inputs,
453                         const ElemPosIdx &pos_idx) const override {
454     InferOutputInfo infer_res;
455     if (infer_func_ != nullptr) {
456       auto output_shapes = infer_func_(inputs, unknown_inputs);
457       infer_res = std::make_pair(output_shapes, false);
458     } else if (infer_with_tuple_func_ != nullptr) {
459       infer_res = infer_with_tuple_func_(inputs, unknown_inputs, pos_idx);
460     } else {
461       MS_LOG(EXCEPTION) << "The infer_func of " << name() << " is nullptr";
462     }
463 
464     return infer_res;
465   }
466 
SetCalc(const CalcFunc & calc_func)467   PureShapeCalc &SetCalc(const CalcFunc &calc_func) {
468     calc_func_ = calc_func;
469     return *this;
470   }
471 
SetInfer(const InferFunc & infer_func)472   std::shared_ptr<PureShapeCalc> SetInfer(const InferFunc &infer_func) {
473     infer_func_ = infer_func;
474     if (calc_func_ == nullptr || cal_with_tuple_func_ != nullptr) {
475       MS_LOG(EXCEPTION) << "The Calc Function and Infer Function should all not support tuple!";
476     }
477     return shared_from_base<PureShapeCalc>();
478   }
479 
SetCalc(const CalcWithTupleFunc & calc_func)480   PureShapeCalc &SetCalc(const CalcWithTupleFunc &calc_func) {
481     cal_with_tuple_func_ = calc_func;
482     return *this;
483   }
484 
SetInfer(const InferWithTupleFunc & infer_func)485   std::shared_ptr<PureShapeCalc> SetInfer(const InferWithTupleFunc &infer_func) {
486     infer_with_tuple_func_ = infer_func;
487     if (cal_with_tuple_func_ == nullptr || calc_func_ != nullptr) {
488       MS_LOG(EXCEPTION) << "The Calc Function and Infer Function should all support tuple!";
489     }
490     return shared_from_base<PureShapeCalc>();
491   }
492 
493  private:
494   CalcFunc calc_func_{nullptr};
495   InferFunc infer_func_{nullptr};
496   CalcWithTupleFunc cal_with_tuple_func_{nullptr};
497   InferWithTupleFunc infer_with_tuple_func_{nullptr};
498 };
499 
// Define a `static const std::shared_ptr<PureShapeCalc>` called `name`, constructed (and thereby registered) as
// "ShapeCalc_<name>". The expansion ends in a PureShapeCalc&, so the definition continues with the setters, e.g.
//   DEF_PURE_SHAPE_CALC(g_foo).SetCalc(calc_fn).SetInfer(infer_fn);
// where SetInfer returns the owning shared_ptr. The outer parentheses are required so `.SetCalc` applies to the
// dereferenced object rather than to the shared_ptr.
#define DEF_PURE_SHAPE_CALC(name) \
  static const std::shared_ptr<PureShapeCalc> name = (*std::make_shared<PureShapeCalc>("ShapeCalc_" #name))
502 
503 }  // namespace expander
504 }  // namespace mindspore
505 #endif  // MINDSPORE_CCSRC_COMMON_EXPANDER_CORE_EMITTER_H_
506