1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_BACKEND_COMMON_GRAPH_KERNEL_EXPANDER_BASE_IR_BUILDER_H_ 17 #define MINDSPORE_CCSRC_BACKEND_COMMON_GRAPH_KERNEL_EXPANDER_BASE_IR_BUILDER_H_ 18 19 #include <string> 20 #include <vector> 21 #include <functional> 22 #include <memory> 23 #include "utils/hash_map.h" 24 #include "kernel/common_utils.h" 25 #include "include/common/utils/utils.h" 26 #include "backend/common/graph_kernel/expander/base/node.h" 27 #include "backend/common/graph_kernel/expander/base/emitter.h" 28 29 namespace mindspore::graphkernel::expander { 30 class IrBuilder { 31 public: 32 IrBuilder() = default; 33 virtual ~IrBuilder() = default; 34 Init(const EmitterPtr & emitter,const NodePtrList * inputs,const HashMap<std::string,ValuePtr> * attrs,const std::string & processor)35 void Init(const EmitterPtr &emitter, const NodePtrList *inputs, const HashMap<std::string, ValuePtr> *attrs, 36 const std::string &processor) { 37 e = emitter; 38 inputs_ptr_ = inputs; 39 attrs_ptr_ = attrs; 40 processor_ = processor; 41 } 42 virtual NodePtrList Expand() = 0; 43 44 /// \brief build a Tensor node from shape Tensor(std::vector<int64_t> input)45 NodePtr Tensor(std::vector<int64_t> input) const { return e->EmitValue(std::make_shared<tensor::Tensor>(input)); } 46 47 /// \brief build a Tensor node from imm data 48 template <typename T> Tensor(T data,const TypePtr & type_ptr)49 NodePtr Tensor(T data, const TypePtr &type_ptr) const { 50 return e->EmitValue(std::make_shared<tensor::Tensor>(data, type_ptr)); 51 } 52 /// \brief build a Tensor node from data list Tensor(TypeId data_type,const ShapeVector & shape,void * data,TypeId src_data_type)53 NodePtr Tensor(TypeId data_type, const ShapeVector &shape, void *data, TypeId src_data_type) const { 54 auto tensor_ptr = std::make_shared<tensor::Tensor>(data_type, shape, data, src_data_type); 55 return e->EmitValue(tensor_ptr); 56 } 57 /// \brief build a imm value node 58 template <typename T> Value(const T & value)59 NodePtr Value(const T &value) const { 60 return e->EmitValue(MakeValue(value)); 61 } 62 inputs()63 const NodePtrList &inputs() const { return *inputs_ptr_; } input(size_t i)64 const NodePtr &input(size_t i) const { 65 if (i < inputs_ptr_->size()) { 66 return (*inputs_ptr_)[i]; 67 } 68 MS_LOG(EXCEPTION) << "Input index " << i << " out of range of input size " << inputs_ptr_->size(); 69 } attrs()70 const HashMap<std::string, ValuePtr> &attrs() const { return *attrs_ptr_; } attr(const std::string & key)71 ValuePtr attr(const std::string &key) const { 72 auto iter = attrs_ptr_->find(key); 73 return iter != attrs_ptr_->end() ? iter->second : nullptr; 74 } 75 template <typename T> attr(const std::string & key)76 T attr(const std::string &key) const { 77 auto v = attr(key); 78 MS_EXCEPTION_IF_NULL(v); 79 return GetValue<T>(v); 80 } processor()81 const std::string &processor() const { return processor_; } 82 83 // meta ops begin Abs(const NodePtr & node)84 inline NodePtr Abs(const NodePtr &node) const { return e->Emit(MetaOp::Abs, {node}); } Add(const NodePtr & lhs,const NodePtr & rhs)85 inline NodePtr Add(const NodePtr &lhs, const NodePtr &rhs) const { return e->Emit(MetaOp::Add, {lhs, rhs}); } Assign(const NodePtr & dst,const NodePtr & src)86 inline NodePtr Assign(const NodePtr &dst, const NodePtr &src) const { return e->Emit(MetaOp::Assign, {dst, src}); } BroadcastTo(const NodePtr & node,const NodePtr & shape)87 inline NodePtr BroadcastTo(const NodePtr &node, const NodePtr &shape) const { 88 return e->Emit(MetaOp::BroadcastTo, {node, shape}); 89 } Cast(const NodePtr & node,const NodePtr & dst_type)90 inline NodePtr Cast(const NodePtr &node, const NodePtr &dst_type) const { 91 return e->Emit(MetaOp::Cast, {node, dst_type}); 92 } Cast(const NodePtr & node,const TypePtr & dst_type)93 inline NodePtr Cast(const NodePtr &node, const TypePtr &dst_type) const { return Cast(node, dst_type->type_id()); } Cast(const NodePtr & node,TypeId dst_type)94 inline NodePtr Cast(const NodePtr &node, TypeId dst_type) const { 95 return Cast(node, Value(static_cast<int64_t>(dst_type))); 96 } Concat(const NodePtrList & inputs,const NodePtr & axis)97 inline NodePtr Concat(const NodePtrList &inputs, const NodePtr &axis) const { 98 NodePtrList new_inputs(inputs.cbegin(), inputs.cend()); 99 new_inputs.push_back(axis); 100 return e->Emit(MetaOp::Concat, new_inputs); 101 } Div(const NodePtr & lhs,const NodePtr & rhs)102 inline NodePtr Div(const NodePtr &lhs, const NodePtr &rhs) const { return e->Emit(MetaOp::Div, {lhs, rhs}); } Equal(const NodePtr & lhs,const NodePtr & rhs)103 inline NodePtr Equal(const NodePtr &lhs, const NodePtr &rhs) const { return e->Emit(MetaOp::Equal, {lhs, rhs}); } Exp(const NodePtr & node)104 inline NodePtr Exp(const NodePtr &node) const { return e->Emit(MetaOp::Exp, {node}); } Gather(const NodePtr & param,const NodePtr & indices,const NodePtr & axis)105 inline NodePtr Gather(const NodePtr ¶m, const NodePtr &indices, const NodePtr &axis) const { 106 return e->Emit(MetaOp::Gather, {param, indices, axis}); 107 } Greater(const NodePtr & lhs,const NodePtr & rhs)108 inline NodePtr Greater(const NodePtr &lhs, const NodePtr &rhs) const { return e->Emit(MetaOp::Greater, {lhs, rhs}); } GreaterEqual(const NodePtr & lhs,const NodePtr & rhs)109 inline NodePtr GreaterEqual(const NodePtr &lhs, const NodePtr &rhs) const { 110 return e->Emit(MetaOp::GreaterEqual, {lhs, rhs}); 111 } IsInf(const NodePtr & node)112 inline NodePtr IsInf(const NodePtr &node) const { return e->Emit(MetaOp::IsInf, {node}); } IsNan(const NodePtr & node)113 inline NodePtr IsNan(const NodePtr &node) const { return e->Emit(MetaOp::IsNan, {node}); } Less(const NodePtr & lhs,const NodePtr & rhs)114 inline NodePtr Less(const NodePtr &lhs, const NodePtr &rhs) const { return e->Emit(MetaOp::Less, {lhs, rhs}); } LessEqual(const NodePtr & lhs,const NodePtr & rhs)115 inline NodePtr LessEqual(const NodePtr &lhs, const NodePtr &rhs) const { 116 return e->Emit(MetaOp::LessEqual, {lhs, rhs}); 117 } Log(const NodePtr & node)118 inline NodePtr Log(const NodePtr &node) const { return e->Emit(MetaOp::Log, {node}); } LogicalAnd(const NodePtr & lhs,const NodePtr & rhs)119 inline NodePtr LogicalAnd(const NodePtr &lhs, const NodePtr &rhs) const { 120 return e->Emit(MetaOp::LogicalAnd, {lhs, rhs}); 121 } LogicalOr(const NodePtr & lhs,const NodePtr & rhs)122 inline NodePtr LogicalOr(const NodePtr &lhs, const NodePtr &rhs) const { 123 return e->Emit(MetaOp::LogicalOr, {lhs, rhs}); 124 } LogicalNot(const NodePtr & node)125 inline NodePtr LogicalNot(const NodePtr &node) const { return e->Emit(MetaOp::LogicalNot, {node}); } MatMul(const NodePtr & a,const NodePtr & b,const NodePtr & transpose_a,const NodePtr & transpose_b)126 inline NodePtr MatMul(const NodePtr &a, const NodePtr &b, const NodePtr &transpose_a, 127 const NodePtr &transpose_b) const { 128 return e->Emit(MetaOp::MatMul, {a, b, transpose_a, transpose_b}); 129 } Mul(const NodePtr & lhs,const NodePtr & rhs)130 inline NodePtr Mul(const NodePtr &lhs, const NodePtr &rhs) const { return e->Emit(MetaOp::Mul, {lhs, rhs}); } Neg(const NodePtr & node)131 inline NodePtr Neg(const NodePtr &node) const { return e->Emit(MetaOp::Neg, {node}); } ReduceMax(const NodePtr & node,const NodePtr & axis,const NodePtr & keepdims)132 inline NodePtr ReduceMax(const NodePtr &node, const NodePtr &axis, const NodePtr &keepdims) const { 133 return e->Emit(MetaOp::ReduceMax, {node, axis, keepdims}); 134 } ReduceMin(const NodePtr & node,const NodePtr & axis,const NodePtr & keepdims)135 inline NodePtr ReduceMin(const NodePtr &node, const NodePtr &axis, const NodePtr &keepdims) const { 136 return e->Emit(MetaOp::ReduceMin, {node, axis, keepdims}); 137 } ReduceSum(const NodePtr & node,const NodePtr & axis,const NodePtr & keepdims)138 inline NodePtr ReduceSum(const NodePtr &node, const NodePtr &axis, const NodePtr &keepdims) const { 139 return e->Emit(MetaOp::ReduceSum, {node, axis, keepdims, Value(false)}); 140 } ReduceSum(const NodePtr & node,const ShapeVector & axis,bool keepdims)141 inline NodePtr ReduceSum(const NodePtr &node, const ShapeVector &axis, bool keepdims) const { 142 return ReduceSum(node, Tensor(axis), Value(keepdims)); 143 } Reshape(const NodePtr & node,const NodePtr & shape)144 inline NodePtr Reshape(const NodePtr &node, const NodePtr &shape) const { 145 return e->Emit(MetaOp::Reshape, {node, shape}); 146 } Reshape(const NodePtr & node,const ShapeVector & shape)147 inline NodePtr Reshape(const NodePtr &node, const ShapeVector &shape) const { return Reshape(node, Tensor(shape)); } Rsqrt(const NodePtr & node)148 inline NodePtr Rsqrt(const NodePtr &node) const { return e->Emit(MetaOp::Rsqrt, {node}); } Reciprocal(const NodePtr & node)149 inline NodePtr Reciprocal(const NodePtr &node) const { return e->Emit(MetaOp::Reciprocal, {node}); } Select(const NodePtr & cond,const NodePtr & true_case,const NodePtr & false_case)150 inline NodePtr Select(const NodePtr &cond, const NodePtr &true_case, const NodePtr &false_case) const { 151 return e->Emit(MetaOp::Select, {cond, true_case, false_case}); 152 } Shape(const NodePtr & node)153 inline NodePtr Shape(const NodePtr &node) const { return e->Emit(MetaOp::Shape, {node}); } Sqrt(const NodePtr & node)154 inline NodePtr Sqrt(const NodePtr &node) const { return e->Emit(MetaOp::Sqrt, {node}); } StridedSlice(const NodePtr & input,const NodePtr & begin,const NodePtr & end,const NodePtr & strides)155 inline NodePtr StridedSlice(const NodePtr &input, const NodePtr &begin, const NodePtr &end, 156 const NodePtr &strides) const { 157 return e->Emit(MetaOp::StridedSlice, {input, begin, end, strides}); 158 } Sub(const NodePtr & lhs,const NodePtr & rhs)159 inline NodePtr Sub(const NodePtr &lhs, const NodePtr &rhs) const { return e->Emit(MetaOp::Sub, {lhs, rhs}); } Tanh(const NodePtr & node)160 inline NodePtr Tanh(const NodePtr &node) const { 161 if (processor_ == kernel::kProcessorAiCore) { 162 // Tanh(x) = 1 - 2/(e^{2x}+1) 163 auto tanh_exp = Exp(Mul(node, Tensor(2, node->GetDtype()))); 164 auto tanh_add_0 = Add(tanh_exp, Tensor(1, node->GetDtype())); 165 auto tanh_rec = Reciprocal(tanh_add_0); 166 auto tanh_neg = Mul(tanh_rec, Tensor(-2, node->GetDtype())); 167 auto tanh_add_1 = Add(tanh_neg, Tensor(1, node->GetDtype())); 168 return tanh_add_1; 169 } 170 return e->Emit(MetaOp::Tanh, {node}); 171 } Cosh(const NodePtr & node)172 inline NodePtr Cosh(const NodePtr &node) const { 173 if (processor_ == kernel::kProcessorAiCore) { 174 // Cosh(x) = (e^x + e^{-x})/2 175 auto cosh_exp_pos = Exp(node); 176 auto cosh_exp_neg = Exp(Mul(node, Tensor(-1, node->GetDtype()))); 177 auto cosh_add = Add(cosh_exp_pos, cosh_exp_neg); 178 auto cosh_div = Div(cosh_add, Tensor(2, node->GetDtype())); 179 return cosh_div; 180 } 181 return e->Emit(MetaOp::Cosh, {node}); 182 } Sinh(const NodePtr & node)183 inline NodePtr Sinh(const NodePtr &node) const { 184 if (processor_ == kernel::kProcessorAiCore) { 185 auto sinh_exp_pos = Exp(node); 186 auto sinh_exp_neg = Exp(Mul(node, Tensor(-1, node->GetDtype()))); 187 auto sinh_add = Sub(sinh_exp_pos, sinh_exp_neg); 188 auto sinh_div = Div(sinh_add, Tensor(2, node->GetDtype())); 189 return sinh_div; 190 } 191 return e->Emit(MetaOp::Sinh, {node}); 192 } TensorScatterAdd(const NodePtr & input,const NodePtr & indices,const NodePtr & update)193 inline NodePtr TensorScatterAdd(const NodePtr &input, const NodePtr &indices, const NodePtr &update) const { 194 return e->Emit(MetaOp::TensorScatterAdd, {input, indices, update}); 195 } Transpose(const NodePtr & node,const NodePtr & perm)196 inline NodePtr Transpose(const NodePtr &node, const NodePtr &perm) const { 197 return e->Emit(MetaOp::Transpose, {node, perm}); 198 } 199 // meta ops end 200 protected: 201 EmitterPtr e; 202 const NodePtrList *inputs_ptr_{nullptr}; 203 const HashMap<std::string, ValuePtr> *attrs_ptr_{nullptr}; 204 std::string processor_; 205 }; 206 207 class DefaultIrBuilder : public IrBuilder { 208 public: 209 using ExpandFunc = std::function<NodePtrList(const DefaultIrBuilder *)>; DefaultIrBuilder(const ExpandFunc & func,const std::string & name)210 explicit DefaultIrBuilder(const ExpandFunc &func, const std::string &name) : func_(func), name_(name) {} 211 ~DefaultIrBuilder() override = default; 212 Expand()213 NodePtrList Expand() override { return func_(this); } 214 emitter()215 const EmitterPtr &emitter() const { return e; } name()216 const std::string &name() const { return name_; } 217 218 protected: 219 ExpandFunc func_; 220 std::string name_; // name of op 221 }; 222 223 class IrBuilderRegistry { 224 public: 225 using CreatorFunc = std::function<std::unique_ptr<IrBuilder>()>; Instance()226 static IrBuilderRegistry &Instance() { 227 static IrBuilderRegistry reg{}; 228 return reg; 229 } 230 class RegHelper { 231 public: 232 // Register IrBuilder by subclass. RegHelper(const std::string & name,const CreatorFunc & func)233 RegHelper(const std::string &name, const CreatorFunc &func) { IrBuilderRegistry::Instance().Reg(name, func); } 234 // Register DefaultIrBuilder RegHelper(const std::string & name)235 explicit RegHelper(const std::string &name) : name_(name) {} SetBody(const DefaultIrBuilder::ExpandFunc & func)236 RegHelper &SetBody(const DefaultIrBuilder::ExpandFunc &func) { 237 IrBuilderRegistry::Instance().Reg( 238 name_, [func, name = this->name_]() { return std::make_unique<DefaultIrBuilder>(func, name); }); 239 return *this; 240 } 241 242 ~RegHelper() = default; 243 244 protected: 245 std::string name_; 246 DefaultIrBuilder::ExpandFunc expand_func_; 247 }; 248 HasOp(const std::string & name)249 bool HasOp(const std::string &name) const { return creator_map_.count(name) > 0; } GetOp(const std::string & name)250 std::unique_ptr<IrBuilder> GetOp(const std::string &name) const { 251 auto iter = creator_map_.find(name); 252 return (iter != creator_map_.end() ? iter->second() : nullptr); 253 } 254 255 private: 256 IrBuilderRegistry() = default; 257 ~IrBuilderRegistry() = default; 258 Reg(const std::string & name,const CreatorFunc & func)259 void Reg(const std::string &name, const CreatorFunc &func) { creator_map_[name] = func; } 260 HashMap<std::string, CreatorFunc> creator_map_; 261 }; 262 263 #define JOIN(x, y) x##y 264 #define UNIQUE_NAME(prefix, cnt) JOIN(prefix, cnt) 265 #define BODYFUNC(v) [](const DefaultIrBuilder *v) -> NodePtrList 266 267 #define REG_EXPANDER_CLASS(name, cls) \ 268 static const IrBuilderRegistry::RegHelper UNIQUE_NAME(g_op_cls_, __COUNTER__)( \ 269 name, []() noexcept { return std::unique_ptr<IrBuilder>(static_cast<IrBuilder *>(new cls())); }) 270 271 #define REG_EXPANDER_FUNC(name) \ 272 static const IrBuilderRegistry::RegHelper UNIQUE_NAME(g_op_func, __COUNTER__) = IrBuilderRegistry::RegHelper(name) 273 } // namespace mindspore::graphkernel::expander 274 #endif // MINDSPORE_CCSRC_BACKEND_COMMON_GRAPH_KERNEL_EXPANDER_BASE_IR_BUILDER_H_ 275