• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include "mindspore/core/symbolic_shape/operation_builder.h"

#include <algorithm>
#include <iterator>

#include "mindspore/core/symbolic_shape/utils.h"
18 
19 namespace mindspore {
20 namespace symshape {
BuildShape(const PrimitivePtr & prim,const AbstractBasePtrList & input_args,const AbstractBasePtr & out)21 SymbolPtr OperationBuilder::BuildShape(const PrimitivePtr &prim, const AbstractBasePtrList &input_args,
22                                        const AbstractBasePtr &out) {
23   is_building_shape_ = true;
24   prim_ = prim;
25   input_args_ = &input_args;
26   out_ = out;
27   if (symbol_builder_info_.build_shape_func == nullptr) {
28     return nullptr;
29   }
30   return symbol_builder_info_.build_shape_func(this);
31 }
32 
BuildValue(const PrimitivePtr & prim,const AbstractBasePtrList & input_args,const AbstractBasePtr & out)33 SymbolPtr OperationBuilder::BuildValue(const PrimitivePtr &prim, const AbstractBasePtrList &input_args,
34                                        const AbstractBasePtr &out) {
35   is_building_shape_ = false;
36   prim_ = prim;
37   input_args_ = &input_args;
38   out_ = out;
39   if (symbol_builder_info_.build_value_func == nullptr) {
40     return nullptr;
41   }
42   return symbol_builder_info_.build_value_func(this);
43 }
44 
GetShape(const AbstractBasePtr & abs) const45 SymbolPtr OperationBuilder::GetShape(const AbstractBasePtr &abs) const {
46   auto real_shape = abs->GetSymbolicShape();
47   if (real_shape != nullptr) {
48     return real_shape;
49   }
50   auto baseshape = abs->GetShape();
51   MS_EXCEPTION_IF_NULL(baseshape);
52   real_shape = baseshape->BuildSymbolicShape();
53   MS_EXCEPTION_IF_NULL(real_shape);
54   abs->SetSymbolicShape(real_shape);
55   return real_shape;
56 }
57 
GetValue(const AbstractBasePtr & abs) const58 SymbolPtr OperationBuilder::GetValue(const AbstractBasePtr &abs) const {
59   SymbolPtr smbl = abs->GetSymbolicValue();
60   if (smbl != nullptr) {
61     return smbl;
62   }
63   smbl = BuildSymbolicValue(abs);
64   MS_EXCEPTION_IF_NULL(smbl);
65   abs->SetSymbolicValue(smbl);
66   return smbl;
67 }
68 
GetAttr(const std::string & attr_name) const69 SymbolPtr OperationBuilder::GetAttr(const std::string &attr_name) const {
70   auto attr = prim_->GetAttr(attr_name);
71   if (attr == nullptr) {
72     return nullptr;
73   }
74   return ConstValueToSymbol(attr);
75 }
76 
GetInputOrAttr(size_t index,const std::string & attr_name) const77 SymbolPtr OperationBuilder::GetInputOrAttr(size_t index, const std::string &attr_name) const {
78   if (input_args_->size() > index) {
79     return GetInputValue(index);
80   }
81   return GetAttr(attr_name);
82 }
83 
Emit(const OpPtr & op) const84 SymbolPtr OperationBuilder::Emit(const OpPtr &op) const {
85   op->SetOutAbstract(this->out_abstract());
86   auto ret = emitter_->Emit(op);
87   op->SetOutAbstract(nullptr);
88   return ret;
89 }
90 
TransparentInput(OperationBuilder * b)91 SymbolPtr TransparentInput(OperationBuilder *b) {
92   bool build_value = !b->is_building_shape();
93   auto depends = b->symbol_builder_info().GetDepends(b->prim(), b->input_num(), build_value);
94   // check only one depend status in the list.
95   auto iter1 = std::find_if(depends.begin(), depends.end(), [](DependOn d) { return d != DependOn::kNone; });
96   if (iter1 == depends.end()) {
97     return nullptr;
98   }
99   auto iter2 = std::find_if(iter1 + 1, depends.end(), [](DependOn d) { return d != DependOn::kNone; });
100   if (iter2 != depends.end()) {
101     return nullptr;
102   }
103   size_t idx = iter1 - depends.begin();
104   return (*iter1 == DependOn::kShape) ? b->GetInputShape(idx) : b->GetInputValue(idx);
105 }
106 
GetBuildInfo(const std::string & name)107 const OperationBuilderInfo *OperationBuilderInfoRegistry::GetBuildInfo(const std::string &name) {
108   const auto &builders = OperationBuilderInfoRegistry::Instance().builders_;
109   auto iter = builders.find(name);
110   return (iter == builders.end() ? nullptr : &(iter->second));
111 }
112 
GetBuilder(const std::string & name,OperationEmitter * e)113 OperationBuilderPtr OperationBuilderInfoRegistry::GetBuilder(const std::string &name, OperationEmitter *e) {
114   auto *build_info = GetBuildInfo(name);
115   if (build_info == nullptr) {
116     return nullptr;
117   }
118   return std::make_unique<OperationBuilder>(e, *build_info);
119 }
120 
GetShapeDepends(const PrimitivePtr & prim,size_t input_num)121 std::vector<DependOn> GetShapeDepends(const PrimitivePtr &prim, size_t input_num) {
122   MS_EXCEPTION_IF_NULL(prim);
123   auto build_info = OperationBuilderInfoRegistry::GetBuildInfo(prim->name());
124   if (build_info == nullptr) {
125     return std::vector<DependOn>();
126   }
127   auto ret = build_info->GetDepends(prim, input_num, false);
128   if (!ret.empty()) {
129     ret.resize(input_num, DependOn::kNone);
130   }
131   return ret;
132 }
133 
GetValueDepends(const PrimitivePtr & prim,size_t input_num)134 std::vector<DependOn> GetValueDepends(const PrimitivePtr &prim, size_t input_num) {
135   MS_EXCEPTION_IF_NULL(prim);
136   auto build_info = OperationBuilderInfoRegistry::GetBuildInfo(prim->name());
137   if (build_info == nullptr) {
138     return std::vector<DependOn>();
139   }
140   auto ret = build_info->GetDepends(prim, input_num, true);
141   if (!ret.empty()) {
142     ret.resize(input_num, DependOn::kNone);
143   }
144   return ret;
145 }
146 }  // namespace symshape
147 }  // namespace mindspore
148