1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #ifndef MINDSPORE_CORE_SYMBOLIC_SHAPE_OPERATION_BUILDER_H_
17 #define MINDSPORE_CORE_SYMBOLIC_SHAPE_OPERATION_BUILDER_H_
#include <functional>
#include <initializer_list>
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
#include "mindspore/core/ir/primitive.h"
#include "mindspore/core/symbolic_shape/symbol.h"
#include "mindspore/core/symbolic_shape/operation.h"
26
27 namespace mindspore {
28 namespace symshape {
class OperationBuilder;

/// \brief What a builder needs from one input of a primitive.
enum class DependOn : int {
  kShape = 0,  ///< The shape symbol of the input is used.
  kValue = 1,  ///< The value symbol of the input is used.
  kNone = 2    ///< The input is not used.
};
31
/// \brief Get depend status of inputs when building shape of prim
///
/// \param[in] prim The primitive whose shape is being built.
/// \param[in] input_num Presumably the number of inputs of the primitive (the definition is not in this header).
/// \return The list of `DependOn` statuses for the inputs.
MS_CORE_API std::vector<DependOn> GetShapeDepends(const PrimitivePtr &prim, size_t input_num);
34
/// \brief Get depend status of inputs when building value of prim
///
/// \param[in] prim The primitive whose value is being built.
/// \param[in] input_num Presumably the number of inputs of the primitive (the definition is not in this header).
/// \return The list of `DependOn` statuses for the inputs.
MS_CORE_API std::vector<DependOn> GetValueDepends(const PrimitivePtr &prim, size_t input_num);
37
/// \brief Callback that receives an `OperationBuilder` and returns the resulting symbol.
using InferFunc = std::function<SymbolPtr(OperationBuilder *)>;
/// \brief Callback that receives a primitive and an input count and returns a depend list.
using DependFunc = std::function<std::vector<DependOn>(const PrimitivePtr &, size_t)>;
40 struct MS_CORE_API OperationBuilderInfo {
41 InferFunc build_shape_func{nullptr};
42 InferFunc build_value_func{nullptr};
43 DependFunc shape_depend_func{nullptr};
44 DependFunc value_depend_func{nullptr};
45 std::vector<DependOn> shape_depend_list;
46 std::vector<DependOn> value_depend_list;
GetDependsOperationBuilderInfo47 std::vector<DependOn> GetDepends(const PrimitivePtr &prim, size_t input_num, bool build_value) const {
48 return build_value ? (value_depend_func != nullptr ? value_depend_func(prim, input_num) : value_depend_list)
49 : (shape_depend_func != nullptr ? shape_depend_func(prim, input_num) : shape_depend_list);
50 }
51 };
52
53 class MS_CORE_API OperationBuilder {
54 public:
OperationBuilder(OperationEmitter * emitter,const OperationBuilderInfo & info)55 OperationBuilder(OperationEmitter *emitter, const OperationBuilderInfo &info)
56 : emitter_(emitter), symbol_builder_info_(info) {}
57 ~OperationBuilder() = default;
58 SymbolPtr BuildShape(const PrimitivePtr &prim, const AbstractBasePtrList &input_args, const AbstractBasePtr &out);
59 SymbolPtr BuildValue(const PrimitivePtr &prim, const AbstractBasePtrList &input_args, const AbstractBasePtr &out);
60
61 SymbolPtr Emit(const OpPtr &op) const;
62 SymbolPtr GetShape(const AbstractBasePtr &abs) const;
63 SymbolPtr GetValue(const AbstractBasePtr &abs) const;
GetInput(size_t i)64 const AbstractBasePtr &GetInput(size_t i) const {
65 if (input_args_->at(i) == nullptr) {
66 MS_LOG(INTERNAL_EXCEPTION) << "The pointer[input_args_->at(" << i << ")] is null.";
67 }
68 return (*input_args_)[i];
69 }
GetInputShape(size_t i)70 SymbolPtr GetInputShape(size_t i) const { return GetShape(GetInput(i)); }
GetInputValue(size_t i)71 SymbolPtr GetInputValue(size_t i) const { return GetValue(GetInput(i)); }
72 SymbolPtr GetAttr(const std::string &attr_name) const;
73 SymbolPtr GetInputOrAttr(size_t index, const std::string &attr_name) const;
74
is_building_shape()75 bool is_building_shape() const { return is_building_shape_; }
prim()76 const PrimitivePtr &prim() const { return prim_; }
input_num()77 size_t input_num() const { return input_args_->size(); }
out_abstract()78 const AbstractBasePtr &out_abstract() const { return out_; }
symbol_builder_info()79 const OperationBuilderInfo &symbol_builder_info() const { return symbol_builder_info_; }
80
81 protected:
82 OperationEmitter *emitter_;
83 const OperationBuilderInfo &symbol_builder_info_;
84 bool is_building_shape_{false};
85 PrimitivePtr prim_;
86 const AbstractBasePtrList *input_args_;
87 AbstractBasePtr out_;
88 };
/// \brief Uniquely-owning pointer to an `OperationBuilder`.
using OperationBuilderPtr = std::unique_ptr<OperationBuilder>;
90
91 template <DependOn d, size_t n = 0>
DefaultDepender(const PrimitivePtr &,size_t input_num)92 std::vector<DependOn> DefaultDepender(const PrimitivePtr &, size_t input_num) {
93 if (n == 0) {
94 return std::vector<DependOn>(input_num, d);
95 }
96 return std::vector<DependOn>(n, d);
97 }
98
99 /// \brief The default builder to create an `Operation`, input shapes or values are related to the depend list.
100 ///
101 /// \tparam OP The class that inherit from `Operation`.
102 template <typename OP, typename = std::enable_if_t<std::is_base_of_v<Operation, OP>>>
DefaultBuilder(OperationBuilder * b)103 SymbolPtr DefaultBuilder(OperationBuilder *b) {
104 bool build_value = !b->is_building_shape();
105 auto depends = b->symbol_builder_info().GetDepends(b->prim(), b->input_num(), build_value);
106 if (depends.empty()) {
107 MS_LOG(WARNING) << "For " << b->prim()->name() << ", the depends list is empty.";
108 return nullptr;
109 }
110 if (b->input_num() < depends.size()) {
111 MS_LOG(WARNING) << "For " << b->prim()->name() << ", the input args num is less than the depends size. "
112 << b->input_num() << " vs " << depends.size();
113 return nullptr;
114 }
115 SymbolPtrList inputs;
116 inputs.reserve(depends.size());
117 for (size_t i = 0; i < depends.size(); i++) {
118 if (depends[i] == DependOn::kShape) {
119 (void)inputs.emplace_back(b->GetInputShape(i));
120 } else if (depends[i] == DependOn::kValue) {
121 (void)inputs.emplace_back(b->GetInputValue(i));
122 }
123 }
124 return b->Emit(std::make_shared<OP>(std::move(inputs)));
125 }
126
/// \brief Use the input symbol as output directly.
///
/// \note When using this function, `SetShapeDepend` or `SetValueDepend` should be set, and exactly one
/// "DependOn::kShape" (or "DependOn::kValue") should exist. The depending symbol is used as the output.
SymbolPtr TransparentInput(OperationBuilder *b);
132
133 class MS_CORE_API OperationBuilderInfoRegistry {
134 public:
135 static const OperationBuilderInfo *GetBuildInfo(const std::string &name);
136 static OperationBuilderPtr GetBuilder(const std::string &name, OperationEmitter *e);
HasOp(const std::string & name)137 static inline bool HasOp(const std::string &name) { return GetBuildInfo(name) != nullptr; }
138
Instance()139 static OperationBuilderInfoRegistry &Instance() {
140 static OperationBuilderInfoRegistry instance{};
141 return instance;
142 }
143
144 class RegHelper {
145 public:
RegHelper(const std::string & name)146 explicit RegHelper(const std::string &name) : builder_(OperationBuilderInfoRegistry::Instance().NewBuilder(name)) {}
SetShapeDepend(const std::initializer_list<DependOn> & depends)147 RegHelper &SetShapeDepend(const std::initializer_list<DependOn> &depends) {
148 builder_->shape_depend_list = depends;
149 return *this;
150 }
SetShapeDepend(const DependFunc & func)151 RegHelper &SetShapeDepend(const DependFunc &func) {
152 builder_->shape_depend_func = func;
153 return *this;
154 }
155 template <DependOn d, size_t n = 0>
SetShapeDependN()156 RegHelper &SetShapeDependN() {
157 return SetShapeDepend(DefaultDepender<d, n>);
158 }
SetShapeFunc(const InferFunc & func)159 RegHelper &SetShapeFunc(const InferFunc &func) {
160 builder_->build_shape_func = func;
161 return *this;
162 }
163 template <typename OP, typename = std::enable_if_t<std::is_base_of_v<Operation, OP>>>
SetShapeFuncWith()164 RegHelper &SetShapeFuncWith() {
165 return SetShapeFunc(DefaultBuilder<OP>);
166 }
SetShapeTransparentFunc()167 RegHelper &SetShapeTransparentFunc() { return SetShapeFunc(TransparentInput); }
168
SetValueDepend(const std::initializer_list<DependOn> & depends)169 RegHelper &SetValueDepend(const std::initializer_list<DependOn> &depends) {
170 builder_->value_depend_list = depends;
171 return *this;
172 }
SetValueDepend(const DependFunc & func)173 RegHelper &SetValueDepend(const DependFunc &func) {
174 builder_->value_depend_func = func;
175 return *this;
176 }
177 template <DependOn d, size_t n = 0>
SetValueDependN()178 RegHelper &SetValueDependN() {
179 return SetValueDepend(DefaultDepender<d, n>);
180 }
SetValueFunc(const InferFunc & func)181 RegHelper &SetValueFunc(const InferFunc &func) {
182 builder_->build_value_func = func;
183 return *this;
184 }
185 template <typename OP, typename = std::enable_if_t<std::is_base_of_v<Operation, OP>>>
SetValueFuncWith()186 RegHelper &SetValueFuncWith() {
187 return SetValueFunc(DefaultBuilder<OP>);
188 }
189 OperationBuilderInfo *builder_;
190 }; // class RegHelper
191
builders()192 const std::unordered_map<std::string, OperationBuilderInfo> &builders() const { return builders_; }
193
194 private:
NewBuilder(const std::string & name)195 OperationBuilderInfo *NewBuilder(const std::string &name) { return &builders_[name]; }
196 std::unordered_map<std::string, OperationBuilderInfo> builders_;
197 };
198
// Token-pasting helpers. UNIQUE_NAME goes through JOIN so that `cnt` (e.g. __COUNTER__) is macro-expanded
// before the two tokens are pasted together.
#define JOIN(x, y) x##y
#define UNIQUE_NAME(prefix, cnt) JOIN(prefix, cnt)
// Starts the registration of the operation `name`: declares a uniquely named constant initialized from a
// `RegHelper`, to which the setters can be chained. The helper type is fully qualified so that the macro also
// works when it is expanded outside of namespace mindspore::symshape.
#define REG_SYMBOL_OP_BUILDER(name) \
  const auto UNIQUE_NAME(g_ob_, __COUNTER__) = ::mindspore::symshape::OperationBuilderInfoRegistry::RegHelper(name)
203 } // namespace symshape
204 } // namespace mindspore
205 #endif // MINDSPORE_CORE_SYMBOLIC_SHAPE_OPERATION_BUILDER_H_
206