1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "mindspore/core/symbolic_shape/operation_builder.h"
17 #include "mindspore/core/symbolic_shape/utils.h"
18
19 namespace mindspore {
20 namespace symshape {
BuildShape(const PrimitivePtr & prim,const AbstractBasePtrList & input_args,const AbstractBasePtr & out)21 SymbolPtr OperationBuilder::BuildShape(const PrimitivePtr &prim, const AbstractBasePtrList &input_args,
22 const AbstractBasePtr &out) {
23 is_building_shape_ = true;
24 prim_ = prim;
25 input_args_ = &input_args;
26 out_ = out;
27 if (symbol_builder_info_.build_shape_func == nullptr) {
28 return nullptr;
29 }
30 return symbol_builder_info_.build_shape_func(this);
31 }
32
BuildValue(const PrimitivePtr & prim,const AbstractBasePtrList & input_args,const AbstractBasePtr & out)33 SymbolPtr OperationBuilder::BuildValue(const PrimitivePtr &prim, const AbstractBasePtrList &input_args,
34 const AbstractBasePtr &out) {
35 is_building_shape_ = false;
36 prim_ = prim;
37 input_args_ = &input_args;
38 out_ = out;
39 if (symbol_builder_info_.build_value_func == nullptr) {
40 return nullptr;
41 }
42 return symbol_builder_info_.build_value_func(this);
43 }
44
GetShape(const AbstractBasePtr & abs) const45 SymbolPtr OperationBuilder::GetShape(const AbstractBasePtr &abs) const {
46 auto real_shape = abs->GetSymbolicShape();
47 if (real_shape != nullptr) {
48 return real_shape;
49 }
50 auto baseshape = abs->GetShape();
51 MS_EXCEPTION_IF_NULL(baseshape);
52 real_shape = baseshape->BuildSymbolicShape();
53 MS_EXCEPTION_IF_NULL(real_shape);
54 abs->SetSymbolicShape(real_shape);
55 return real_shape;
56 }
57
GetValue(const AbstractBasePtr & abs) const58 SymbolPtr OperationBuilder::GetValue(const AbstractBasePtr &abs) const {
59 SymbolPtr smbl = abs->GetSymbolicValue();
60 if (smbl != nullptr) {
61 return smbl;
62 }
63 smbl = BuildSymbolicValue(abs);
64 MS_EXCEPTION_IF_NULL(smbl);
65 abs->SetSymbolicValue(smbl);
66 return smbl;
67 }
68
GetAttr(const std::string & attr_name) const69 SymbolPtr OperationBuilder::GetAttr(const std::string &attr_name) const {
70 auto attr = prim_->GetAttr(attr_name);
71 if (attr == nullptr) {
72 return nullptr;
73 }
74 return ConstValueToSymbol(attr);
75 }
76
GetInputOrAttr(size_t index,const std::string & attr_name) const77 SymbolPtr OperationBuilder::GetInputOrAttr(size_t index, const std::string &attr_name) const {
78 if (input_args_->size() > index) {
79 return GetInputValue(index);
80 }
81 return GetAttr(attr_name);
82 }
83
Emit(const OpPtr & op) const84 SymbolPtr OperationBuilder::Emit(const OpPtr &op) const {
85 op->SetOutAbstract(this->out_abstract());
86 auto ret = emitter_->Emit(op);
87 op->SetOutAbstract(nullptr);
88 return ret;
89 }
90
TransparentInput(OperationBuilder * b)91 SymbolPtr TransparentInput(OperationBuilder *b) {
92 bool build_value = !b->is_building_shape();
93 auto depends = b->symbol_builder_info().GetDepends(b->prim(), b->input_num(), build_value);
94 // check only one depend status in the list.
95 auto iter1 = std::find_if(depends.begin(), depends.end(), [](DependOn d) { return d != DependOn::kNone; });
96 if (iter1 == depends.end()) {
97 return nullptr;
98 }
99 auto iter2 = std::find_if(iter1 + 1, depends.end(), [](DependOn d) { return d != DependOn::kNone; });
100 if (iter2 != depends.end()) {
101 return nullptr;
102 }
103 size_t idx = iter1 - depends.begin();
104 return (*iter1 == DependOn::kShape) ? b->GetInputShape(idx) : b->GetInputValue(idx);
105 }
106
GetBuildInfo(const std::string & name)107 const OperationBuilderInfo *OperationBuilderInfoRegistry::GetBuildInfo(const std::string &name) {
108 const auto &builders = OperationBuilderInfoRegistry::Instance().builders_;
109 auto iter = builders.find(name);
110 return (iter == builders.end() ? nullptr : &(iter->second));
111 }
112
GetBuilder(const std::string & name,OperationEmitter * e)113 OperationBuilderPtr OperationBuilderInfoRegistry::GetBuilder(const std::string &name, OperationEmitter *e) {
114 auto *build_info = GetBuildInfo(name);
115 if (build_info == nullptr) {
116 return nullptr;
117 }
118 return std::make_unique<OperationBuilder>(e, *build_info);
119 }
120
GetShapeDepends(const PrimitivePtr & prim,size_t input_num)121 std::vector<DependOn> GetShapeDepends(const PrimitivePtr &prim, size_t input_num) {
122 MS_EXCEPTION_IF_NULL(prim);
123 auto build_info = OperationBuilderInfoRegistry::GetBuildInfo(prim->name());
124 if (build_info == nullptr) {
125 return std::vector<DependOn>();
126 }
127 auto ret = build_info->GetDepends(prim, input_num, false);
128 if (!ret.empty()) {
129 ret.resize(input_num, DependOn::kNone);
130 }
131 return ret;
132 }
133
GetValueDepends(const PrimitivePtr & prim,size_t input_num)134 std::vector<DependOn> GetValueDepends(const PrimitivePtr &prim, size_t input_num) {
135 MS_EXCEPTION_IF_NULL(prim);
136 auto build_info = OperationBuilderInfoRegistry::GetBuildInfo(prim->name());
137 if (build_info == nullptr) {
138 return std::vector<DependOn>();
139 }
140 auto ret = build_info->GetDepends(prim, input_num, true);
141 if (!ret.empty()) {
142 ret.resize(input_num, DependOn::kNone);
143 }
144 return ret;
145 }
146 } // namespace symshape
147 } // namespace mindspore
148