• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_BACKEND_COMMON_GRAPH_KERNEL_EXPANDER_BASE_IR_BUILDER_H_
17 #define MINDSPORE_CCSRC_BACKEND_COMMON_GRAPH_KERNEL_EXPANDER_BASE_IR_BUILDER_H_
18 
19 #include <string>
20 #include <vector>
21 #include <functional>
22 #include <memory>
23 #include "utils/hash_map.h"
24 #include "kernel/common_utils.h"
25 #include "include/common/utils/utils.h"
26 #include "backend/common/graph_kernel/expander/base/node.h"
27 #include "backend/common/graph_kernel/expander/base/emitter.h"
28 
29 namespace mindspore::graphkernel::expander {
30 class IrBuilder {
31  public:
32   IrBuilder() = default;
33   virtual ~IrBuilder() = default;
34 
Init(const EmitterPtr & emitter,const NodePtrList * inputs,const HashMap<std::string,ValuePtr> * attrs,const std::string & processor)35   void Init(const EmitterPtr &emitter, const NodePtrList *inputs, const HashMap<std::string, ValuePtr> *attrs,
36             const std::string &processor) {
37     e = emitter;
38     inputs_ptr_ = inputs;
39     attrs_ptr_ = attrs;
40     processor_ = processor;
41   }
42   virtual NodePtrList Expand() = 0;
43 
44   /// \brief build a Tensor node from shape
Tensor(std::vector<int64_t> input)45   NodePtr Tensor(std::vector<int64_t> input) const { return e->EmitValue(std::make_shared<tensor::Tensor>(input)); }
46 
47   /// \brief build a Tensor node from imm data
48   template <typename T>
Tensor(T data,const TypePtr & type_ptr)49   NodePtr Tensor(T data, const TypePtr &type_ptr) const {
50     return e->EmitValue(std::make_shared<tensor::Tensor>(data, type_ptr));
51   }
52   /// \brief build a Tensor node from data list
Tensor(TypeId data_type,const ShapeVector & shape,void * data,TypeId src_data_type)53   NodePtr Tensor(TypeId data_type, const ShapeVector &shape, void *data, TypeId src_data_type) const {
54     auto tensor_ptr = std::make_shared<tensor::Tensor>(data_type, shape, data, src_data_type);
55     return e->EmitValue(tensor_ptr);
56   }
57   /// \brief build a imm value node
58   template <typename T>
Value(const T & value)59   NodePtr Value(const T &value) const {
60     return e->EmitValue(MakeValue(value));
61   }
62 
inputs()63   const NodePtrList &inputs() const { return *inputs_ptr_; }
input(size_t i)64   const NodePtr &input(size_t i) const {
65     if (i < inputs_ptr_->size()) {
66       return (*inputs_ptr_)[i];
67     }
68     MS_LOG(EXCEPTION) << "Input index " << i << " out of range of input size " << inputs_ptr_->size();
69   }
attrs()70   const HashMap<std::string, ValuePtr> &attrs() const { return *attrs_ptr_; }
attr(const std::string & key)71   ValuePtr attr(const std::string &key) const {
72     auto iter = attrs_ptr_->find(key);
73     return iter != attrs_ptr_->end() ? iter->second : nullptr;
74   }
75   template <typename T>
attr(const std::string & key)76   T attr(const std::string &key) const {
77     auto v = attr(key);
78     MS_EXCEPTION_IF_NULL(v);
79     return GetValue<T>(v);
80   }
processor()81   const std::string &processor() const { return processor_; }
82 
83   // meta ops begin
Abs(const NodePtr & node)84   inline NodePtr Abs(const NodePtr &node) const { return e->Emit(MetaOp::Abs, {node}); }
Add(const NodePtr & lhs,const NodePtr & rhs)85   inline NodePtr Add(const NodePtr &lhs, const NodePtr &rhs) const { return e->Emit(MetaOp::Add, {lhs, rhs}); }
Assign(const NodePtr & dst,const NodePtr & src)86   inline NodePtr Assign(const NodePtr &dst, const NodePtr &src) const { return e->Emit(MetaOp::Assign, {dst, src}); }
BroadcastTo(const NodePtr & node,const NodePtr & shape)87   inline NodePtr BroadcastTo(const NodePtr &node, const NodePtr &shape) const {
88     return e->Emit(MetaOp::BroadcastTo, {node, shape});
89   }
Cast(const NodePtr & node,const NodePtr & dst_type)90   inline NodePtr Cast(const NodePtr &node, const NodePtr &dst_type) const {
91     return e->Emit(MetaOp::Cast, {node, dst_type});
92   }
Cast(const NodePtr & node,const TypePtr & dst_type)93   inline NodePtr Cast(const NodePtr &node, const TypePtr &dst_type) const { return Cast(node, dst_type->type_id()); }
Cast(const NodePtr & node,TypeId dst_type)94   inline NodePtr Cast(const NodePtr &node, TypeId dst_type) const {
95     return Cast(node, Value(static_cast<int64_t>(dst_type)));
96   }
Concat(const NodePtrList & inputs,const NodePtr & axis)97   inline NodePtr Concat(const NodePtrList &inputs, const NodePtr &axis) const {
98     NodePtrList new_inputs(inputs.cbegin(), inputs.cend());
99     new_inputs.push_back(axis);
100     return e->Emit(MetaOp::Concat, new_inputs);
101   }
Div(const NodePtr & lhs,const NodePtr & rhs)102   inline NodePtr Div(const NodePtr &lhs, const NodePtr &rhs) const { return e->Emit(MetaOp::Div, {lhs, rhs}); }
Equal(const NodePtr & lhs,const NodePtr & rhs)103   inline NodePtr Equal(const NodePtr &lhs, const NodePtr &rhs) const { return e->Emit(MetaOp::Equal, {lhs, rhs}); }
Exp(const NodePtr & node)104   inline NodePtr Exp(const NodePtr &node) const { return e->Emit(MetaOp::Exp, {node}); }
Gather(const NodePtr & param,const NodePtr & indices,const NodePtr & axis)105   inline NodePtr Gather(const NodePtr &param, const NodePtr &indices, const NodePtr &axis) const {
106     return e->Emit(MetaOp::Gather, {param, indices, axis});
107   }
Greater(const NodePtr & lhs,const NodePtr & rhs)108   inline NodePtr Greater(const NodePtr &lhs, const NodePtr &rhs) const { return e->Emit(MetaOp::Greater, {lhs, rhs}); }
GreaterEqual(const NodePtr & lhs,const NodePtr & rhs)109   inline NodePtr GreaterEqual(const NodePtr &lhs, const NodePtr &rhs) const {
110     return e->Emit(MetaOp::GreaterEqual, {lhs, rhs});
111   }
IsInf(const NodePtr & node)112   inline NodePtr IsInf(const NodePtr &node) const { return e->Emit(MetaOp::IsInf, {node}); }
IsNan(const NodePtr & node)113   inline NodePtr IsNan(const NodePtr &node) const { return e->Emit(MetaOp::IsNan, {node}); }
Less(const NodePtr & lhs,const NodePtr & rhs)114   inline NodePtr Less(const NodePtr &lhs, const NodePtr &rhs) const { return e->Emit(MetaOp::Less, {lhs, rhs}); }
LessEqual(const NodePtr & lhs,const NodePtr & rhs)115   inline NodePtr LessEqual(const NodePtr &lhs, const NodePtr &rhs) const {
116     return e->Emit(MetaOp::LessEqual, {lhs, rhs});
117   }
Log(const NodePtr & node)118   inline NodePtr Log(const NodePtr &node) const { return e->Emit(MetaOp::Log, {node}); }
LogicalAnd(const NodePtr & lhs,const NodePtr & rhs)119   inline NodePtr LogicalAnd(const NodePtr &lhs, const NodePtr &rhs) const {
120     return e->Emit(MetaOp::LogicalAnd, {lhs, rhs});
121   }
LogicalOr(const NodePtr & lhs,const NodePtr & rhs)122   inline NodePtr LogicalOr(const NodePtr &lhs, const NodePtr &rhs) const {
123     return e->Emit(MetaOp::LogicalOr, {lhs, rhs});
124   }
LogicalNot(const NodePtr & node)125   inline NodePtr LogicalNot(const NodePtr &node) const { return e->Emit(MetaOp::LogicalNot, {node}); }
MatMul(const NodePtr & a,const NodePtr & b,const NodePtr & transpose_a,const NodePtr & transpose_b)126   inline NodePtr MatMul(const NodePtr &a, const NodePtr &b, const NodePtr &transpose_a,
127                         const NodePtr &transpose_b) const {
128     return e->Emit(MetaOp::MatMul, {a, b, transpose_a, transpose_b});
129   }
Mul(const NodePtr & lhs,const NodePtr & rhs)130   inline NodePtr Mul(const NodePtr &lhs, const NodePtr &rhs) const { return e->Emit(MetaOp::Mul, {lhs, rhs}); }
Neg(const NodePtr & node)131   inline NodePtr Neg(const NodePtr &node) const { return e->Emit(MetaOp::Neg, {node}); }
ReduceMax(const NodePtr & node,const NodePtr & axis,const NodePtr & keepdims)132   inline NodePtr ReduceMax(const NodePtr &node, const NodePtr &axis, const NodePtr &keepdims) const {
133     return e->Emit(MetaOp::ReduceMax, {node, axis, keepdims});
134   }
ReduceMin(const NodePtr & node,const NodePtr & axis,const NodePtr & keepdims)135   inline NodePtr ReduceMin(const NodePtr &node, const NodePtr &axis, const NodePtr &keepdims) const {
136     return e->Emit(MetaOp::ReduceMin, {node, axis, keepdims});
137   }
ReduceSum(const NodePtr & node,const NodePtr & axis,const NodePtr & keepdims)138   inline NodePtr ReduceSum(const NodePtr &node, const NodePtr &axis, const NodePtr &keepdims) const {
139     return e->Emit(MetaOp::ReduceSum, {node, axis, keepdims, Value(false)});
140   }
ReduceSum(const NodePtr & node,const ShapeVector & axis,bool keepdims)141   inline NodePtr ReduceSum(const NodePtr &node, const ShapeVector &axis, bool keepdims) const {
142     return ReduceSum(node, Tensor(axis), Value(keepdims));
143   }
Reshape(const NodePtr & node,const NodePtr & shape)144   inline NodePtr Reshape(const NodePtr &node, const NodePtr &shape) const {
145     return e->Emit(MetaOp::Reshape, {node, shape});
146   }
Reshape(const NodePtr & node,const ShapeVector & shape)147   inline NodePtr Reshape(const NodePtr &node, const ShapeVector &shape) const { return Reshape(node, Tensor(shape)); }
Rsqrt(const NodePtr & node)148   inline NodePtr Rsqrt(const NodePtr &node) const { return e->Emit(MetaOp::Rsqrt, {node}); }
Reciprocal(const NodePtr & node)149   inline NodePtr Reciprocal(const NodePtr &node) const { return e->Emit(MetaOp::Reciprocal, {node}); }
Select(const NodePtr & cond,const NodePtr & true_case,const NodePtr & false_case)150   inline NodePtr Select(const NodePtr &cond, const NodePtr &true_case, const NodePtr &false_case) const {
151     return e->Emit(MetaOp::Select, {cond, true_case, false_case});
152   }
Shape(const NodePtr & node)153   inline NodePtr Shape(const NodePtr &node) const { return e->Emit(MetaOp::Shape, {node}); }
Sqrt(const NodePtr & node)154   inline NodePtr Sqrt(const NodePtr &node) const { return e->Emit(MetaOp::Sqrt, {node}); }
StridedSlice(const NodePtr & input,const NodePtr & begin,const NodePtr & end,const NodePtr & strides)155   inline NodePtr StridedSlice(const NodePtr &input, const NodePtr &begin, const NodePtr &end,
156                               const NodePtr &strides) const {
157     return e->Emit(MetaOp::StridedSlice, {input, begin, end, strides});
158   }
Sub(const NodePtr & lhs,const NodePtr & rhs)159   inline NodePtr Sub(const NodePtr &lhs, const NodePtr &rhs) const { return e->Emit(MetaOp::Sub, {lhs, rhs}); }
Tanh(const NodePtr & node)160   inline NodePtr Tanh(const NodePtr &node) const {
161     if (processor_ == kernel::kProcessorAiCore) {
162       // Tanh(x) = 1 - 2/(e^{2x}+1)
163       auto tanh_exp = Exp(Mul(node, Tensor(2, node->GetDtype())));
164       auto tanh_add_0 = Add(tanh_exp, Tensor(1, node->GetDtype()));
165       auto tanh_rec = Reciprocal(tanh_add_0);
166       auto tanh_neg = Mul(tanh_rec, Tensor(-2, node->GetDtype()));
167       auto tanh_add_1 = Add(tanh_neg, Tensor(1, node->GetDtype()));
168       return tanh_add_1;
169     }
170     return e->Emit(MetaOp::Tanh, {node});
171   }
Cosh(const NodePtr & node)172   inline NodePtr Cosh(const NodePtr &node) const {
173     if (processor_ == kernel::kProcessorAiCore) {
174       // Cosh(x) = (e^x + e^{-x})/2
175       auto cosh_exp_pos = Exp(node);
176       auto cosh_exp_neg = Exp(Mul(node, Tensor(-1, node->GetDtype())));
177       auto cosh_add = Add(cosh_exp_pos, cosh_exp_neg);
178       auto cosh_div = Div(cosh_add, Tensor(2, node->GetDtype()));
179       return cosh_div;
180     }
181     return e->Emit(MetaOp::Cosh, {node});
182   }
Sinh(const NodePtr & node)183   inline NodePtr Sinh(const NodePtr &node) const {
184     if (processor_ == kernel::kProcessorAiCore) {
185       auto sinh_exp_pos = Exp(node);
186       auto sinh_exp_neg = Exp(Mul(node, Tensor(-1, node->GetDtype())));
187       auto sinh_add = Sub(sinh_exp_pos, sinh_exp_neg);
188       auto sinh_div = Div(sinh_add, Tensor(2, node->GetDtype()));
189       return sinh_div;
190     }
191     return e->Emit(MetaOp::Sinh, {node});
192   }
TensorScatterAdd(const NodePtr & input,const NodePtr & indices,const NodePtr & update)193   inline NodePtr TensorScatterAdd(const NodePtr &input, const NodePtr &indices, const NodePtr &update) const {
194     return e->Emit(MetaOp::TensorScatterAdd, {input, indices, update});
195   }
Transpose(const NodePtr & node,const NodePtr & perm)196   inline NodePtr Transpose(const NodePtr &node, const NodePtr &perm) const {
197     return e->Emit(MetaOp::Transpose, {node, perm});
198   }
199   // meta ops end
200  protected:
201   EmitterPtr e;
202   const NodePtrList *inputs_ptr_{nullptr};
203   const HashMap<std::string, ValuePtr> *attrs_ptr_{nullptr};
204   std::string processor_;
205 };
206 
207 class DefaultIrBuilder : public IrBuilder {
208  public:
209   using ExpandFunc = std::function<NodePtrList(const DefaultIrBuilder *)>;
DefaultIrBuilder(const ExpandFunc & func,const std::string & name)210   explicit DefaultIrBuilder(const ExpandFunc &func, const std::string &name) : func_(func), name_(name) {}
211   ~DefaultIrBuilder() override = default;
212 
Expand()213   NodePtrList Expand() override { return func_(this); }
214 
emitter()215   const EmitterPtr &emitter() const { return e; }
name()216   const std::string &name() const { return name_; }
217 
218  protected:
219   ExpandFunc func_;
220   std::string name_;  // name of op
221 };
222 
223 class IrBuilderRegistry {
224  public:
225   using CreatorFunc = std::function<std::unique_ptr<IrBuilder>()>;
Instance()226   static IrBuilderRegistry &Instance() {
227     static IrBuilderRegistry reg{};
228     return reg;
229   }
230   class RegHelper {
231    public:
232     // Register IrBuilder by subclass.
RegHelper(const std::string & name,const CreatorFunc & func)233     RegHelper(const std::string &name, const CreatorFunc &func) { IrBuilderRegistry::Instance().Reg(name, func); }
234     // Register DefaultIrBuilder
RegHelper(const std::string & name)235     explicit RegHelper(const std::string &name) : name_(name) {}
SetBody(const DefaultIrBuilder::ExpandFunc & func)236     RegHelper &SetBody(const DefaultIrBuilder::ExpandFunc &func) {
237       IrBuilderRegistry::Instance().Reg(
238         name_, [func, name = this->name_]() { return std::make_unique<DefaultIrBuilder>(func, name); });
239       return *this;
240     }
241 
242     ~RegHelper() = default;
243 
244    protected:
245     std::string name_;
246     DefaultIrBuilder::ExpandFunc expand_func_;
247   };
248 
HasOp(const std::string & name)249   bool HasOp(const std::string &name) const { return creator_map_.count(name) > 0; }
GetOp(const std::string & name)250   std::unique_ptr<IrBuilder> GetOp(const std::string &name) const {
251     auto iter = creator_map_.find(name);
252     return (iter != creator_map_.end() ? iter->second() : nullptr);
253   }
254 
255  private:
256   IrBuilderRegistry() = default;
257   ~IrBuilderRegistry() = default;
258 
Reg(const std::string & name,const CreatorFunc & func)259   void Reg(const std::string &name, const CreatorFunc &func) { creator_map_[name] = func; }
260   HashMap<std::string, CreatorFunc> creator_map_;
261 };
262 
// Token-pasting helpers: the extra UNIQUE_NAME level makes an argument such as __COUNTER__ expand before it is
// pasted, so each macro use below gets a distinct identifier within its translation unit.
#define JOIN(x, y) x##y
#define UNIQUE_NAME(prefix, cnt) JOIN(prefix, cnt)
// Lambda header for an expander body: `BODYFUNC(ib) { ... }` receives the DefaultIrBuilder as `ib`.
#define BODYFUNC(v) [](const DefaultIrBuilder *v) -> NodePtrList

// Register IrBuilder subclass `cls` (must be default-constructible) under op name `name`.
// std::make_unique replaces the former naked `new` + static_cast; the derived-to-base conversion is implicit.
#define REG_EXPANDER_CLASS(name, cls)                                            \
  static const IrBuilderRegistry::RegHelper UNIQUE_NAME(g_op_cls_, __COUNTER__)( \
    name, []() noexcept -> std::unique_ptr<IrBuilder> { return std::make_unique<cls>(); })

// Declare a function-bodied expander for `name`. Nothing is registered until SetBody() is chained onto the
// helper, e.g. `REG_EXPANDER_FUNC("Op").SetBody(BODYFUNC(ib) { ... });`.
#define REG_EXPANDER_FUNC(name) \
  static const IrBuilderRegistry::RegHelper UNIQUE_NAME(g_op_func, __COUNTER__) = IrBuilderRegistry::RegHelper(name)
273 }  // namespace mindspore::graphkernel::expander
274 #endif  // MINDSPORE_CCSRC_BACKEND_COMMON_GRAPH_KERNEL_EXPANDER_BASE_IR_BUILDER_H_
275