/**
 * Copyright 2022-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_COMMON_EXPANDER_CORE_EMITTER_H_
#define MINDSPORE_CCSRC_COMMON_EXPANDER_CORE_EMITTER_H_
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <unordered_set>
#include <utility>
#include <vector>
#include "include/common/expander/core/infer.h"
#include "include/common/expander/core/node.h"
#include "include/common/utils/utils.h"
#include "ir/func_graph.h"
#include "ir/functor.h"
#include "ops/array_op_name.h"
#include "ops/comparison_op_name.h"
#include "ops/framework_op_name.h"
#include "ops/arithmetic_op_name.h"
#include "ops/math_ops.h"
#include "ops/sequence_ops.h"
#include "ops/shape_calc.h"
#include "ops/auto_generate/gen_ops_name.h"

namespace mindspore {
namespace expander {
using ShapeValidFunc = std::function<bool(size_t, const ShapeVector &)>;

class COMMON_EXPORT Emitter {
 public:
  explicit Emitter(const ExpanderInferPtr &infer, const ScopePtr &scope = nullptr) : infer_(infer), scope_(scope) {}
  virtual ~Emitter() = default;

  /// \brief Emit a primitive CNode
  NodePtr Emit(const std::string &op_name, const NodePtrList &inputs, const DAttr &attrs = {});
  PrimitivePtr NewPrimitive(const std::string &name, const DAttr &attrs = {});

  /// \brief Emit a ValueNode
  virtual NodePtr EmitValue(const ValuePtr &value);

  NodePtr NewIrNode(const AnfNodePtr &anfnode) { return std::make_shared<IrNode>(anfnode, this); }
  FuncNodePtr NewFuncNode(const ValuePtr &value, const abstract::AbstractBasePtr &abs, InputType input_type) {
    return std::make_shared<FuncNode>(value, abs, input_type, this);
  }
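  // A minimal usage sketch (assumes an Emitter `e` and NodePtr operands `x`, `y`; these names are hypothetical):
  // ops are addressed by name, with node inputs and an optional attribute map.
  //   auto sum = e->Emit("Add", {x, y});
  //   auto squeezed = e->Emit("Squeeze", {x}, {{"axis", MakeValue(ShapeVector{0})}});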
  virtual NodePtr MakeTuple(const NodePtrList &inputs) { return EmitOp(prim::kPrimMakeTuple, inputs); }
  virtual NodePtr MakeList(const NodePtrList &inputs) { return EmitOp(prim::kPrimMakeList, inputs); }
  virtual NodePtr TupleGetItem(const NodePtr &input, size_t i) {
    return Emit(mindspore::kTupleGetItemOpName, {input, Value(static_cast<int64_t>(i))});
  }
  virtual NodePtr TupleGetItem(const NodePtr &input, const NodePtr &i) { return Emit(kTupleGetItemOpName, {input, i}); }
  NodePtr Len(const NodePtr &input) { return Emit(kSequenceLenOpName, {input}); }
  NodePtr ScalarAdd(const NodePtr &lhs, const NodePtr &rhs) { return Emit(ops::kNameScalarAdd, {lhs, rhs}); }
  NodePtr ScalarSub(const NodePtr &lhs, const NodePtr &rhs) { return Emit(ops::kNameScalarSub, {lhs, rhs}); }
  NodePtr ScalarMul(const NodePtr &lhs, const NodePtr &rhs) { return Emit(ops::kNameScalarMul, {lhs, rhs}); }
  NodePtr ScalarDiv(const NodePtr &lhs, const NodePtr &rhs) { return Emit(ops::kNameScalarDiv, {lhs, rhs}); }
  NodePtr ScalarFloorDiv(const NodePtr &lhs, const NodePtr &rhs) { return Emit(ops::kNameScalarFloorDiv, {lhs, rhs}); }
  NodePtr ScalarNeg(const NodePtr &node) { return Emit(ops::kNameScalarUsub, {node}); }
  NodePtr Cast(const NodePtr &node, const TypePtr &type);
  NodePtr Cast(const NodePtr &node, TypeId type_id) { return Cast(node, TypeIdToType(type_id)); }

  NodePtr Reshape(const NodePtr &node, const NodePtr &shape);
  NodePtr Reshape(const NodePtr &node, const ShapeVector &shape) { return Reshape(node, Value(shape)); }
  NodePtr ExpandDims(const NodePtr &node, int64_t axis) { return Emit(kExpandDimsOpName, {node, Value(axis)}); }
  NodePtr Abs(const NodePtr &node) { return Emit(mindspore::kAbsOpName, {node}); }
  NodePtr Neg(const NodePtr &node) { return Emit(mindspore::kNegOpName, {node}); }
  NodePtr Reciprocal(const NodePtr &node) { return Emit(mindspore::kReciprocalOpName, {node}); }
  NodePtr Square(const NodePtr &node) { return Emit(mindspore::kSquareOpName, {node}); }
  NodePtr Sign(const NodePtr &node) { return Emit(prim::kPrimSign->name(), {node}); }
  NodePtr Exp(const NodePtr &x);
  NodePtr Log(const NodePtr &x);
  NodePtr Transpose(const NodePtr &node, const NodePtr &perm);
  NodePtr Transpose(const NodePtr &node, const ShapeVector &perm) { return Transpose(node, Value(perm)); }
  NodePtr Tile(const NodePtr &node, const NodePtr &dims);
  NodePtr Tile(const NodePtr &node, const ShapeVector &dims) { return Tile(node, Value(dims)); }
  NodePtr Concat(const NodePtr &input, int64_t axis) { return Emit(kConcatOpName, {input, Value(axis)}); }
  NodePtr Concat(const NodePtrList &inputs, int64_t axis) {
    return Emit(kConcatOpName, {MakeTuple(inputs), Value(axis)});
  }
  NodePtr Add(const NodePtr &lhs, const NodePtr &rhs) { return UnifyDtypeAndEmit(mindspore::kAddOpName, lhs, rhs); }
  NodePtr Sub(const NodePtr &lhs, const NodePtr &rhs) { return UnifyDtypeAndEmit(mindspore::kSubOpName, lhs, rhs); }
  NodePtr Mul(const NodePtr &lhs, const NodePtr &rhs) { return UnifyDtypeAndEmit(mindspore::kMulOpName, lhs, rhs); }
  NodePtr Div(const NodePtr &lhs, const NodePtr &rhs) { return UnifyDtypeAndEmit(kDivOpName, lhs, rhs); }
  NodePtr RealDiv(const NodePtr &lhs, const NodePtr &rhs) {
    return UnifyDtypeAndEmit(mindspore::kRealDivOpName, lhs, rhs);
  }
  NodePtr Mod(const NodePtr &lhs, const NodePtr &rhs) { return UnifyDtypeAndEmit("Mod", lhs, rhs); }
  NodePtr Pow(const NodePtr &lhs, const NodePtr &rhs) { return UnifyDtypeAndEmit(kPowOpName, lhs, rhs); }
  NodePtr MatMul(const NodePtr &a, const NodePtr &b, bool transpose_a = false, bool transpose_b = false);
  NodePtr MatMulExt(const NodePtr &a, const NodePtr &b);
  NodePtr BatchMatMul(const NodePtr &a, const NodePtr &b, bool transpose_a = false, bool transpose_b = false);
  NodePtr Maximum(const NodePtr &lhs, const NodePtr &rhs) { return UnifyDtypeAndEmit(kMaximumOpName, lhs, rhs); }
  NodePtr Minimum(const NodePtr &lhs, const NodePtr &rhs) { return UnifyDtypeAndEmit(kMinimumOpName, lhs, rhs); }
  NodePtr FloorDiv(const NodePtr &lhs, const NodePtr &rhs) { return UnifyDtypeAndEmit("FloorDiv", lhs, rhs); }
  NodePtr FloorMod(const NodePtr &lhs, const NodePtr &rhs) { return UnifyDtypeAndEmit("FloorMod", lhs, rhs); }
  NodePtr DivNoNan(const NodePtr &lhs, const NodePtr &rhs) { return UnifyDtypeAndEmit("DivNoNan", lhs, rhs); }
  NodePtr MulNoNan(const NodePtr &lhs, const NodePtr &rhs) { return UnifyDtypeAndEmit("MulNoNan", lhs, rhs); }
  NodePtr Xdivy(const NodePtr &lhs, const NodePtr &rhs) { return UnifyDtypeAndEmit("Xdivy", lhs, rhs); }
  NodePtr Xlogy(const NodePtr &lhs, const NodePtr &rhs) { return UnifyDtypeAndEmit("Xlogy", lhs, rhs); }

  NodePtr Select(const NodePtr &cond, const NodePtr &lhs, const NodePtr &rhs) {
    auto [a, b] = UnifyDtype2(lhs, rhs);
    return Emit(kSelectOpName, {cond, a, b});
  }
  NodePtr Less(const NodePtr &lhs, const NodePtr &rhs, const TypePtr &dst_type = nullptr) {
    return CmpOpWithCast(kLessOpName, lhs, rhs, dst_type);
  }
  NodePtr LessEqual(const NodePtr &lhs, const NodePtr &rhs, const TypePtr &dst_type = nullptr) {
    return CmpOpWithCast(kLessEqualOpName, lhs, rhs, dst_type);
  }
  NodePtr Greater(const NodePtr &lhs, const NodePtr &rhs, const TypePtr &dst_type = nullptr) {
    return CmpOpWithCast(kGreaterOpName, lhs, rhs, dst_type);
  }
  NodePtr GreaterEqual(const NodePtr &lhs, const NodePtr &rhs, const TypePtr &dst_type = nullptr) {
    return CmpOpWithCast(kGreaterEqualOpName, lhs, rhs, dst_type);
  }
  NodePtr Equal(const NodePtr &lhs, const NodePtr &rhs, const TypePtr &dst_type = nullptr) {
    auto abs = lhs->abstract();
    MS_EXCEPTION_IF_NULL(abs);
    if (abs->isa<abstract::AbstractTensor>()) {
      return CmpOpWithCast(kEqualOpName, lhs, rhs, dst_type);
    } else if (abs->isa<abstract::AbstractScalar>()) {
      return ScalarEq(lhs, rhs, dst_type);
    }
    MS_LOG(EXCEPTION) << "'Equal' only supports [Tensor] or [Scalar] input, but got: " << abs->ToString();
  }
  virtual NodePtr ScalarEq(const NodePtr &lhs, const NodePtr &rhs, const TypePtr &dst_type) {
    auto node = UnifyDtypeAndEmit("ScalarEq", lhs, rhs);
    return dst_type == nullptr ? node : Cast(node, dst_type);
  }
  NodePtr NotEqual(const NodePtr &lhs, const NodePtr &rhs, const TypePtr &dst_type = nullptr) {
    return CmpOpWithCast("NotEqual", lhs, rhs, dst_type);
  }
  NodePtr LogicalAnd(const NodePtr &lhs, const NodePtr &rhs) { return Emit("LogicalAnd", {lhs, rhs}); }
  NodePtr LogicalOr(const NodePtr &lhs, const NodePtr &rhs) { return Emit("LogicalOr", {lhs, rhs}); }
  NodePtr LogicalNot(const NodePtr &x) { return Emit("LogicalNot", {x}); }

  NodePtr BoolNot(const NodePtr &node);

  NodePtr OnesLike(const NodePtr &x) { return Emit("OnesLike", {x}); }
  NodePtr UnsortedSegmentSum(const NodePtr &x, const NodePtr &segment_ids, const NodePtr &num_segments) {
    return Emit("UnsortedSegmentSum", {x, segment_ids, num_segments});
  }
  NodePtr GatherNd(const NodePtr &input_x, const NodePtr &indices) { return Emit("GatherNd", {input_x, indices}); }
  NodePtr ScatterNd(const NodePtr &indices, const NodePtr &update, const NodePtr &shape) {
    return Emit("ScatterNd", {indices, update, shape});
  }
  virtual NodePtr Stack(const NodePtr &x, const ValuePtr &axis) { return Emit("Stack", {x}, {{"axis", axis}}); }
  virtual NodePtr Stack(const NodePtrList &x, int64_t axis) { return Stack(MakeTuple(x), MakeValue(axis)); }
  NodePtr TensorScatterUpdate(const NodePtr &input_x, const NodePtr &indices, const NodePtr &updates) {
    return Emit("TensorScatterUpdate", {input_x, indices, updates});
  }
  NodePtr Slice(const NodePtr &x, const NodePtr &begin, const NodePtr &size) { return Emit("Slice", {x, begin, size}); }
  NodePtr Squeeze(const NodePtr &x, const ValuePtr &axis) { return Emit("Squeeze", {x}, {{"axis", axis}}); }
  NodePtr Sqrt(const NodePtr &x) { return Emit("Sqrt", {x}); }
  NodePtr MatrixSetDiagV3(const NodePtr &x, const NodePtr &diagonal, const NodePtr &k, const ValuePtr &align) {
    const auto diag_max_length = 200000000;
    return Emit("MatrixSetDiagV3", {x, diagonal, k},
                {{"max_length", MakeValue<int64_t>(diag_max_length)}, {"align", align}});
  }
  NodePtr MatrixDiagPartV3(const NodePtr &x, const NodePtr &diagonal, const NodePtr &k, const ValuePtr &align) {
    const auto diag_max_length = 200000000;
    return Emit("MatrixDiagPartV3", {x, diagonal, k},
                {{"max_length", MakeValue<int64_t>(diag_max_length)}, {"align", align}});
  }
  NodePtr LinSpace(const NodePtr &start, const NodePtr &stop, const NodePtr &num) {
    return Emit("LinSpace", {start, stop, num});
  }

  // complex
  NodePtr Conj(const NodePtr &input) {
    TypeId type_id = input->dtype()->type_id();
    if (type_id == kNumberTypeComplex64 || type_id == kNumberTypeComplex128) {
      return Emit("Conj", {input});
    }
    return input;
  }
  NodePtr Complex(const NodePtr &real, const NodePtr &imag) { return Emit("Complex", {real, imag}); }
  NodePtr Real(const NodePtr &x) { return Emit(kRealOpName, {x}); }
  NodePtr Imag(const NodePtr &x) { return Emit(kImagOpName, {x}); }

  NodePtr CumProd(const NodePtr &x, const NodePtr &axis, const NodePtr &exclusive, const NodePtr &reverse) {
    return Emit("CumProd", {x, axis, exclusive, reverse});
  }
  NodePtr CumProd(const NodePtr &x, const NodePtr &axis, const bool &exclusive, const bool &reverse) {
    return CumProd(x, axis, Value(exclusive), Value(reverse));
  }
  NodePtr CumSum(const NodePtr &x, const NodePtr &axis, const NodePtr &exclusive, const NodePtr &reverse) {
    return Emit("CumSum", {x, axis, exclusive, reverse});
  }
  NodePtr CumSum(const NodePtr &x, const NodePtr &axis, const bool &exclusive, const bool &reverse) {
    return CumSum(x, axis, Value(exclusive), Value(reverse));
  }
  NodePtr CSR2COO(const NodePtr &indptr, const NodePtr &nnz) { return Emit("CSR2COO", {indptr, nnz}); }
  NodePtr ScalarToTensor(const NodePtr &node);
  NodePtr ScalarToTensor(const NodePtr &node, const TypePtr &dtype);
  std::pair<bool, ShapeVector> NeedReduce(const ShapeVector &shape, const std::vector<int64_t> &axis, bool keep_dim,
                                          bool skip_mode = false) const;
  std::pair<bool, NodePtr> NeedReduce(const NodePtr &shape, const NodePtr &axis, bool keep_dim, bool skip_mode = false);
  NodePtr ReduceSum(const NodePtr &x, const NodePtr &axis, bool keep_dims = false, bool skip_mode = false);
  NodePtr ReduceSum(const NodePtr &x, const ShapeVector &axis = {}, bool keep_dims = false);
  NodePtr SumExt(const NodePtr &input, const NodePtr &axis, const NodePtr &keep_dims);
  NodePtr BroadcastTo(const NodePtr &x, const NodePtr &y);

  NodePtr ZerosLike(const NodePtr &node);
  virtual NodePtr Depend(const NodePtr &value, const NodePtr &expr) {
    return Emit("Depend", {value, expr}, {{"side_effect_propagate", MakeValue(1)}});
  }
  NodePtr Fill(double value, const ShapeVector &shape, TypeId data_type);
  NodePtr Fill(int64_t value, const ShapeVector &shape, TypeId data_type);
  template <typename T>
  NodePtr Fill(const T &value, const NodePtr &shape, TypeId data_type) {
    MS_EXCEPTION_IF_NULL(shape);
    if (shape->input_type() == InputType::kConstant) {
      auto v = shape->BuildValue();
      MS_EXCEPTION_IF_NULL(v);
      return Fill(value, GetValue<ShapeVector>(v), data_type);
    }
    auto value_tensor = Cast(Tensor(value), data_type);
    return Emit("DynamicBroadcastTo", {value_tensor, shape});
  }

  NodePtr Shape(const NodePtr &node, bool tensor = false) {
    auto shape = node->shape();
    if (tensor) {
      return IsDynamic(shape) ? Emit("TensorShape", {node}) : Tensor(shape);
    } else {
      return IsDynamic(shape) ? Emit("Shape", {node}) : Value<ShapeVector>(shape);
    }
  }
  NodePtr Gather(const NodePtr &params, const NodePtr &indices, int64_t axis, int64_t batch_dims = 0);
  NodePtr Gather(const NodePtr &params, const NodePtr &indices, const NodePtr &axis, int64_t batch_dims = 0);
  NodePtr GatherD(const NodePtr &x, const NodePtr &dim, const NodePtr &index) {
    return Emit("GatherD", {x, dim, index});
  }
  virtual NodePtr BatchNormGrad(const NodePtrList &inputs, bool is_scale_or_bias_grad) {
    return Emit("BatchNormGrad", inputs);
  }
  virtual NodePtr SparseSoftmaxCrossEntropyWithLogits(const NodePtrList &inputs, const DAttr &attrs, const NodePtr &out,
                                                      const NodePtr &dout, bool is_graph_mode);

  // By comparing x with itself, test whether x is NaN (NaN is the only value that does not equal itself).
  inline NodePtr IsNanFunc(const NodePtr &x) { return NotEqual(x, x); }

  NodePtr Zeros(const NodePtr &x) {
    auto x_shape = x->shape();
    if (!x_shape.empty() && !IsDynamicRank(x_shape)) {
      // There are currently some problems under 0d that need to be fixed later.
      return Emit("Zeros", {Shape(x), Value<int64_t>(x->dtype()->type_id())});
    }
    return ZerosLike(x);
  }

  /// \brief Emit a value node
  template <typename T>
  NodePtr Value(const T &value) {
    return EmitValue(MakeValue(value));
  }

  /// \brief Emit a Tensor node.
  template <typename T>
  NodePtr Tensor(T data, TypePtr type_ptr = nullptr) {
    auto tensor_ptr = std::make_shared<tensor::Tensor>(data, type_ptr);
    return EmitValue(tensor_ptr);
  }

  /// \brief Emit a tensor node.
  NodePtr Tensor(TypeId data_type, const ShapeVector &shape, void *data, TypeId src_data_type) {
    auto tensor_ptr = std::make_shared<tensor::Tensor>(data_type, shape, data, src_data_type);
    return EmitValue(tensor_ptr);
  }

  /// \brief Get the ExpanderInferPtr.
  const ExpanderInferPtr &infer() const { return infer_; }

  /// \brief Shape calculation. This interface unifies the code between the static-shape and dynamic-shape
  /// situations; the output type depends on the types of the inputs.
  ///
  /// \param[in] functor The ShapeCalcBaseFunctor object.
  /// \param[in] inputs The input tensors.
  /// \param[in] value_depend If index i exists in 'value_depend', the value of inputs[i] is sent to 'functor';
  /// otherwise the shape of inputs[i] is sent.
  /// \param[in] valid_func The function to check whether the index and input shape are valid.
  /// \return NodePtrList, the output shape list. When the inputs are all static-shape tensors, shape vectors are
  /// returned; otherwise CNode tensors are returned.
  NodePtrList ShapeCalc(const ShapeCalcBaseFunctorPtr &functor, const NodePtrList &inputs,
                        const std::vector<int64_t> &value_depend = {}, const ShapeValidFunc &valid_func = nullptr);
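  // A minimal call-site sketch (`reduce_shape_calc`, `x` and `axis` are hypothetical; the functor would be built
  // with the DEF_PURE_SHAPE_CALC macro defined at the end of this header). Passing index 1 in value_depend means
  // the *value* of inputs[1] (here `axis`), not its shape, reaches the functor.
  //   auto res = ShapeCalc(reduce_shape_calc, {x, axis}, {1});
  //   // res[0] is a constant shape value node for static-shape inputs, or a CNode tensor for dynamic shapes.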
  /// \brief Emit a TensorToTuple node.
  NodePtr TensorToTuple(const NodePtr &node);

  using BlockFunc = std::function<NodePtrList(Emitter *)>;
  /// \brief Generate a conditional block.
  ///
  /// \param[in] cond condition node; it should be a tensor of Bool.
  /// \param[in] true_case the true branch.
  /// \param[in] false_case the false branch.
  /// \return node of a tuple or a single value, depending on the output lists of the two branches.
  /// \note The overloaded operators (like a+b) should not be used for captured variables in the true_case/false_case
  /// functions; use the function argument `Emitter` instead, like `emitter->Add(a, b)`. The output lists of the two
  /// branches should match the join rules of control flow.
  virtual NodePtr Conditional(const NodePtr &cond, const BlockFunc &true_case, const BlockFunc &false_case);

  /// \brief Generate a while-loop block.
  ///
  /// \param[in] cond condition node; it should be a tensor of Bool.
  /// \param[in] body the loop body.
  /// \param[in] init_list the initial variables that would be modified in the body.
  /// \return node of a tuple or a single value, depending on the init_list.
  /// \note The overloaded operators (like `a+b`) should not be used for captured variables in the body function; use
  /// the function argument `Emitter` instead, like `emitter->Add(a, b)`. The length and node order of the output list
  /// of the body function should match init_list.
  virtual NodePtr While(const NodePtr &cond, const BlockFunc &body, const NodePtrList &init_list);
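  // A minimal sketch of Conditional (`cond`, `x`, `y` are hypothetical NodePtr variables): each branch routes its
  // ops through the Emitter argument it receives, as required by the note above.
  //   auto out = Conditional(
  //       cond, [&](Emitter *e) -> NodePtrList { return {e->Add(x, y)}; },
  //       [&](Emitter *e) -> NodePtrList { return {e->Sub(x, y)}; });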
 protected:
  virtual NodePtr EmitOp(const PrimitivePtr &prim, const NodePtrList &inputs);
  NodePtr CmpOpWithCast(const std::string &op, const NodePtr &lhs, const NodePtr &rhs, const TypePtr &dst_type) {
    auto node = UnifyDtypeAndEmit(op, lhs, rhs);
    return dst_type == nullptr ? node : Cast(node, dst_type);
  }
  std::tuple<NodePtr, NodePtr> UnifyDtype2(const NodePtr &lhs, const NodePtr &rhs);
  NodePtr UnifyDtypeAndEmit(const std::string &op, const NodePtr &a, const NodePtr &b, const DAttr &attrs = {}) {
    auto [lhs, rhs] = UnifyDtype2(a, b);
    return Emit(op, {lhs, rhs}, attrs);
  }

  ExpanderInferPtr infer_{nullptr};
  ScopePtr scope_{nullptr};
  // Promotion ranking of number types (bool < int8 < uint8 < int16 < int32 < int64 < fp16 < fp32 < fp64),
  // used when unifying the dtypes of two operands: the type with the larger rank wins.
  inline static const std::vector<size_t> type_vector_ = [] {
    std::vector<size_t> type_vector(kSparseTypeEnd + 1);
    type_vector[kNumberTypeBool] = 1;
    type_vector[kNumberTypeInt8] = 2;
    type_vector[kNumberTypeUInt8] = 3;
    type_vector[kNumberTypeInt16] = 4;
    type_vector[kNumberTypeInt32] = 5;
    type_vector[kNumberTypeInt64] = 6;
    type_vector[kNumberTypeFloat16] = 7;
    type_vector[kNumberTypeFloat32] = 8;
    type_vector[kNumberTypeFloat64] = 9;
    return type_vector;
  }();
  static HashMap<std::string, ops::OpPrimCDefineFunc> &primc_func_cache() {
    static HashMap<std::string, ops::OpPrimCDefineFunc> cache{};
    return cache;
  }
};
using EmitterPtr = std::shared_ptr<Emitter>;

COMMON_EXPORT NodePtr operator+(const NodePtr &lhs, const NodePtr &rhs);
COMMON_EXPORT NodePtr operator-(const NodePtr &lhs, const NodePtr &rhs);
COMMON_EXPORT NodePtr operator*(const NodePtr &lhs, const NodePtr &rhs);
COMMON_EXPORT NodePtr operator/(const NodePtr &lhs, const NodePtr &rhs);
COMMON_EXPORT NodePtr operator-(const NodePtr &node);

// Builds the control-flow subgraphs (if/else and while) behind Emitter::Conditional and Emitter::While.
class COMMON_EXPORT CtrlFlowBlock {
 public:
  using BlockFunc = std::function<NodePtrList(Emitter *)>;
  using EmitterCreator = std::function<EmitterPtr(const FuncGraphPtr &, const ExpanderInferPtr &)>;
  CtrlFlowBlock(Emitter *emitter, const FuncGraphPtr &func_graph, const EmitterCreator &ec = nullptr)
      : emitter_(emitter), func_graph_(func_graph), emitter_creator_(ec) {
    MS_EXCEPTION_IF_NULL(emitter);
    MS_EXCEPTION_IF_NULL(func_graph);
  }
  ~CtrlFlowBlock() = default;
  NodePtr IfThenElse(const NodePtr &cond, const BlockFunc &true_case, const BlockFunc &false_case);

  NodePtr While(const NodePtr &cond, const BlockFunc &while_body_func, const NodePtrList &init_list);

 protected:
  EmitterPtr CreateInnerEmitter(const FuncGraphPtr &fg, const ExpanderInferPtr &infer) const;
  NodePtr BuildSubgraph(const BlockFunc &func);
  NodePtrList BuildSubgraphOfPartial(const BlockFunc &func);

  Emitter *emitter_;
  FuncGraphPtr func_graph_;
  EmitterCreator emitter_creator_;
  size_t output_num_{0};
  abstract::AbstractBasePtr out_abstract_{nullptr};

  class CppInferWithPartial : public CppInfer {
   public:
    void Infer(const NodePtr &node) override;
  };
};

// An Emitter that writes the emitted nodes into a FuncGraph.
class COMMON_EXPORT IrEmitter : public Emitter {
 public:
  IrEmitter(const FuncGraphPtr &func_graph, const ExpanderInferPtr &infer, const ScopePtr &scope = nullptr)
      : Emitter(infer, scope), func_graph_(func_graph) {
    MS_EXCEPTION_IF_NULL(func_graph);
    MS_EXCEPTION_IF_NULL(infer);
  }
  NodePtr EmitValue(const ValuePtr &value) override;
  FuncGraphPtr func_graph() { return func_graph_; }

 protected:
  NodePtr EmitOp(const PrimitivePtr &prim, const NodePtrList &inputs) override;
  FuncGraphPtr func_graph_;
};
class PureShapeCalc : public ShapeCalcBaseFunctor {
 public:
  // CalcFunc/InferFunc/CalcWithTupleFunc/InferWithTupleFunc are defined as pure function pointers rather than
  // std::function, meaning that they should be lambda functions without any capture.
  using CalcFunc = ShapeArray (*)(const ShapeArray &);
  using InferFunc = std::vector<int64_t> (*)(const ShapeArray &, const HashSet<size_t> &);
  using CalcWithTupleFunc = ShapeArray (*)(const ShapeArray &, const ElemPosIdx &);
  using InferWithTupleFunc = InferOutputInfo (*)(const ShapeArray &, const HashSet<size_t> &, const ElemPosIdx &);

  explicit PureShapeCalc(const std::string &name) : ShapeCalcBaseFunctor(name) {
    FunctorRegistry::Instance().Register(name, [this]() { return shared_from_base<Functor>(); });
  }

  PureShapeCalc(const PureShapeCalc &) = delete;
  PureShapeCalc(PureShapeCalc &&) = delete;
  PureShapeCalc &operator=(const PureShapeCalc &) = delete;
  PureShapeCalc &operator=(PureShapeCalc &&) = delete;
  ~PureShapeCalc() override = default;
  MS_DECLARE_PARENT(PureShapeCalc, ShapeCalcBaseFunctor)

  ValuePtr ToValue() const override { return nullptr; }
  void FromValue(const ValuePtr &) override {}

  ShapeArray Calc(const ShapeArray &inputs, const ElemPosIdx &pos_idx) const override {
    ShapeArray calc_res;
    if (calc_func_ != nullptr) {
      calc_res = calc_func_(inputs);
    } else if (cal_with_tuple_func_ != nullptr) {
      calc_res = cal_with_tuple_func_(inputs, pos_idx);
    } else {
      MS_LOG(EXCEPTION) << "The calc_func of " << name() << " is nullptr";
    }
    return calc_res;
  }

  InferOutputInfo Infer(const ShapeArray &inputs, const HashSet<size_t> &unknown_inputs,
                        const ElemPosIdx &pos_idx) const override {
    InferOutputInfo infer_res;
    if (infer_func_ != nullptr) {
      auto output_shapes = infer_func_(inputs, unknown_inputs);
      infer_res = std::make_pair(output_shapes, false);
    } else if (infer_with_tuple_func_ != nullptr) {
      infer_res = infer_with_tuple_func_(inputs, unknown_inputs, pos_idx);
    } else {
      MS_LOG(EXCEPTION) << "The infer_func of " << name() << " is nullptr";
    }
    return infer_res;
  }

  PureShapeCalc &SetCalc(const CalcFunc &calc_func) {
    calc_func_ = calc_func;
    return *this;
  }

  std::shared_ptr<PureShapeCalc> SetInfer(const InferFunc &infer_func) {
    infer_func_ = infer_func;
    if (calc_func_ == nullptr || cal_with_tuple_func_ != nullptr) {
      MS_LOG(EXCEPTION) << "The Calc function and Infer function should both not support tuple!";
    }
    return shared_from_base<PureShapeCalc>();
  }

  PureShapeCalc &SetCalc(const CalcWithTupleFunc &calc_func) {
    cal_with_tuple_func_ = calc_func;
    return *this;
  }

  std::shared_ptr<PureShapeCalc> SetInfer(const InferWithTupleFunc &infer_func) {
    infer_with_tuple_func_ = infer_func;
    if (cal_with_tuple_func_ == nullptr || calc_func_ != nullptr) {
      MS_LOG(EXCEPTION) << "The Calc function and Infer function should both support tuple!";
    }
    return shared_from_base<PureShapeCalc>();
  }

 private:
  CalcFunc calc_func_{nullptr};
  InferFunc infer_func_{nullptr};
  CalcWithTupleFunc cal_with_tuple_func_{nullptr};
  InferWithTupleFunc infer_with_tuple_func_{nullptr};
};

#define DEF_PURE_SHAPE_CALC(name) \
  static const std::shared_ptr<PureShapeCalc> name = (*(std::make_shared<PureShapeCalc>("ShapeCalc_" #name)))
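// A minimal definition sketch (`g_element_count` and the lambda bodies are hypothetical). Captureless lambdas
// decay to the required function pointers; SetCalc returns the object itself so SetInfer can complete the chain:
//   DEF_PURE_SHAPE_CALC(g_element_count)
//     .SetCalc([](const ShapeArray &inputs) -> ShapeArray {
//       int64_t size = 1;
//       for (auto d : inputs.at(0)) { size *= d; }
//       return {{size}};  // one output shape holding the element count of input 0
//     })
//     .SetInfer([](const ShapeArray &, const HashSet<size_t> &) -> std::vector<int64_t> {
//       return {1};  // the single output is rank-1 with one element
//     });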
}  // namespace expander
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_COMMON_EXPANDER_CORE_EMITTER_H_