• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CORE_SYMBOLIC_SHAPE_OPERATION_BUILDER_H_
17 #define MINDSPORE_CORE_SYMBOLIC_SHAPE_OPERATION_BUILDER_H_
18 #include <vector>
19 #include <string>
20 #include <memory>
21 #include <utility>
22 #include <unordered_map>
23 #include "mindspore/core/ir/primitive.h"
24 #include "mindspore/core/symbolic_shape/symbol.h"
25 #include "mindspore/core/symbolic_shape/operation.h"
26 
27 namespace mindspore {
28 namespace symshape {
class OperationBuilder;

/// \brief How an operation's builder uses one of its inputs.
enum class DependOn : int {
  kShape = 0,  ///< the input's symbolic shape is read
  kValue = 1,  ///< the input's symbolic value is read
  kNone = 2,   ///< the input is not read
};
31 
32 /// \brief Get depend status of inputs when building shape of prim
33 MS_CORE_API std::vector<DependOn> GetShapeDepends(const PrimitivePtr &prim, size_t input_num);
34 
35 /// \brief Get depend status of inputs when building shape of prim
36 MS_CORE_API std::vector<DependOn> GetValueDepends(const PrimitivePtr &prim, size_t input_num);
37 
38 using InferFunc = std::function<SymbolPtr(OperationBuilder *)>;
39 using DependFunc = std::function<std::vector<DependOn>(const PrimitivePtr &, size_t)>;
40 struct MS_CORE_API OperationBuilderInfo {
41   InferFunc build_shape_func{nullptr};
42   InferFunc build_value_func{nullptr};
43   DependFunc shape_depend_func{nullptr};
44   DependFunc value_depend_func{nullptr};
45   std::vector<DependOn> shape_depend_list;
46   std::vector<DependOn> value_depend_list;
GetDependsOperationBuilderInfo47   std::vector<DependOn> GetDepends(const PrimitivePtr &prim, size_t input_num, bool build_value) const {
48     return build_value ? (value_depend_func != nullptr ? value_depend_func(prim, input_num) : value_depend_list)
49                        : (shape_depend_func != nullptr ? shape_depend_func(prim, input_num) : shape_depend_list);
50   }
51 };
52 
53 class MS_CORE_API OperationBuilder {
54  public:
OperationBuilder(OperationEmitter * emitter,const OperationBuilderInfo & info)55   OperationBuilder(OperationEmitter *emitter, const OperationBuilderInfo &info)
56       : emitter_(emitter), symbol_builder_info_(info) {}
57   ~OperationBuilder() = default;
58   SymbolPtr BuildShape(const PrimitivePtr &prim, const AbstractBasePtrList &input_args, const AbstractBasePtr &out);
59   SymbolPtr BuildValue(const PrimitivePtr &prim, const AbstractBasePtrList &input_args, const AbstractBasePtr &out);
60 
61   SymbolPtr Emit(const OpPtr &op) const;
62   SymbolPtr GetShape(const AbstractBasePtr &abs) const;
63   SymbolPtr GetValue(const AbstractBasePtr &abs) const;
GetInput(size_t i)64   const AbstractBasePtr &GetInput(size_t i) const {
65     if (input_args_->at(i) == nullptr) {
66       MS_LOG(INTERNAL_EXCEPTION) << "The pointer[input_args_->at(" << i << ")] is null.";
67     }
68     return (*input_args_)[i];
69   }
GetInputShape(size_t i)70   SymbolPtr GetInputShape(size_t i) const { return GetShape(GetInput(i)); }
GetInputValue(size_t i)71   SymbolPtr GetInputValue(size_t i) const { return GetValue(GetInput(i)); }
72   SymbolPtr GetAttr(const std::string &attr_name) const;
73   SymbolPtr GetInputOrAttr(size_t index, const std::string &attr_name) const;
74 
is_building_shape()75   bool is_building_shape() const { return is_building_shape_; }
prim()76   const PrimitivePtr &prim() const { return prim_; }
input_num()77   size_t input_num() const { return input_args_->size(); }
out_abstract()78   const AbstractBasePtr &out_abstract() const { return out_; }
symbol_builder_info()79   const OperationBuilderInfo &symbol_builder_info() const { return symbol_builder_info_; }
80 
81  protected:
82   OperationEmitter *emitter_;
83   const OperationBuilderInfo &symbol_builder_info_;
84   bool is_building_shape_{false};
85   PrimitivePtr prim_;
86   const AbstractBasePtrList *input_args_;
87   AbstractBasePtr out_;
88 };
89 using OperationBuilderPtr = std::unique_ptr<OperationBuilder>;
90 
91 template <DependOn d, size_t n = 0>
DefaultDepender(const PrimitivePtr &,size_t input_num)92 std::vector<DependOn> DefaultDepender(const PrimitivePtr &, size_t input_num) {
93   if (n == 0) {
94     return std::vector<DependOn>(input_num, d);
95   }
96   return std::vector<DependOn>(n, d);
97 }
98 
99 /// \brief The default builder to create an `Operation`, input shapes or values are related to the depend list.
100 ///
101 /// \tparam OP The class that inherit from `Operation`.
102 template <typename OP, typename = std::enable_if_t<std::is_base_of_v<Operation, OP>>>
DefaultBuilder(OperationBuilder * b)103 SymbolPtr DefaultBuilder(OperationBuilder *b) {
104   bool build_value = !b->is_building_shape();
105   auto depends = b->symbol_builder_info().GetDepends(b->prim(), b->input_num(), build_value);
106   if (depends.empty()) {
107     MS_LOG(WARNING) << "For " << b->prim()->name() << ", the depends list is empty.";
108     return nullptr;
109   }
110   if (b->input_num() < depends.size()) {
111     MS_LOG(WARNING) << "For " << b->prim()->name() << ", the input args num is less than the depends size. "
112                     << b->input_num() << " vs " << depends.size();
113     return nullptr;
114   }
115   SymbolPtrList inputs;
116   inputs.reserve(depends.size());
117   for (size_t i = 0; i < depends.size(); i++) {
118     if (depends[i] == DependOn::kShape) {
119       (void)inputs.emplace_back(b->GetInputShape(i));
120     } else if (depends[i] == DependOn::kValue) {
121       (void)inputs.emplace_back(b->GetInputValue(i));
122     }
123   }
124   return b->Emit(std::make_shared<OP>(std::move(inputs)));
125 }
126 
127 /// \brief Use the input symbol as output directly.
128 ///
129 /// \note When using this function, the `SetShapeDepend` or `SetValueDepend` should be set, and only one
130 /// "DependOn::kShape" (or "DependOn::kValue") exists. the depending symbol is used as output.
131 SymbolPtr TransparentInput(OperationBuilder *b);
132 
133 class MS_CORE_API OperationBuilderInfoRegistry {
134  public:
135   static const OperationBuilderInfo *GetBuildInfo(const std::string &name);
136   static OperationBuilderPtr GetBuilder(const std::string &name, OperationEmitter *e);
HasOp(const std::string & name)137   static inline bool HasOp(const std::string &name) { return GetBuildInfo(name) != nullptr; }
138 
Instance()139   static OperationBuilderInfoRegistry &Instance() {
140     static OperationBuilderInfoRegistry instance{};
141     return instance;
142   }
143 
144   class RegHelper {
145    public:
RegHelper(const std::string & name)146     explicit RegHelper(const std::string &name) : builder_(OperationBuilderInfoRegistry::Instance().NewBuilder(name)) {}
SetShapeDepend(const std::initializer_list<DependOn> & depends)147     RegHelper &SetShapeDepend(const std::initializer_list<DependOn> &depends) {
148       builder_->shape_depend_list = depends;
149       return *this;
150     }
SetShapeDepend(const DependFunc & func)151     RegHelper &SetShapeDepend(const DependFunc &func) {
152       builder_->shape_depend_func = func;
153       return *this;
154     }
155     template <DependOn d, size_t n = 0>
SetShapeDependN()156     RegHelper &SetShapeDependN() {
157       return SetShapeDepend(DefaultDepender<d, n>);
158     }
SetShapeFunc(const InferFunc & func)159     RegHelper &SetShapeFunc(const InferFunc &func) {
160       builder_->build_shape_func = func;
161       return *this;
162     }
163     template <typename OP, typename = std::enable_if_t<std::is_base_of_v<Operation, OP>>>
SetShapeFuncWith()164     RegHelper &SetShapeFuncWith() {
165       return SetShapeFunc(DefaultBuilder<OP>);
166     }
SetShapeTransparentFunc()167     RegHelper &SetShapeTransparentFunc() { return SetShapeFunc(TransparentInput); }
168 
SetValueDepend(const std::initializer_list<DependOn> & depends)169     RegHelper &SetValueDepend(const std::initializer_list<DependOn> &depends) {
170       builder_->value_depend_list = depends;
171       return *this;
172     }
SetValueDepend(const DependFunc & func)173     RegHelper &SetValueDepend(const DependFunc &func) {
174       builder_->value_depend_func = func;
175       return *this;
176     }
177     template <DependOn d, size_t n = 0>
SetValueDependN()178     RegHelper &SetValueDependN() {
179       return SetValueDepend(DefaultDepender<d, n>);
180     }
SetValueFunc(const InferFunc & func)181     RegHelper &SetValueFunc(const InferFunc &func) {
182       builder_->build_value_func = func;
183       return *this;
184     }
185     template <typename OP, typename = std::enable_if_t<std::is_base_of_v<Operation, OP>>>
SetValueFuncWith()186     RegHelper &SetValueFuncWith() {
187       return SetValueFunc(DefaultBuilder<OP>);
188     }
189     OperationBuilderInfo *builder_;
190   };  // class RegHelper
191 
builders()192   const std::unordered_map<std::string, OperationBuilderInfo> &builders() const { return builders_; }
193 
194  private:
NewBuilder(const std::string & name)195   OperationBuilderInfo *NewBuilder(const std::string &name) { return &builders_[name]; }
196   std::unordered_map<std::string, OperationBuilderInfo> builders_;
197 };
198 
// Two-level token pasting: `UNIQUE_NAME` fully expands its arguments (e.g. `__COUNTER__`) before `JOIN` pastes them.
#define JOIN(x, y) x##y
#define UNIQUE_NAME(prefix, cnt) JOIN(prefix, cnt)
/// \brief Start registering builder info for the operation `name`; chain `RegHelper` setters after it.
///
/// The registry type is fully qualified so the macro also compiles when used outside `mindspore::symshape`.
#define REG_SYMBOL_OP_BUILDER(name) \
  const auto UNIQUE_NAME(g_ob_, __COUNTER__) = ::mindspore::symshape::OperationBuilderInfoRegistry::RegHelper(name)
203 }  // namespace symshape
204 }  // namespace mindspore
205 #endif  // MINDSPORE_CORE_SYMBOLIC_SHAPE_OPERATION_BUILDER_H_
206