• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/common/graph_kernel/symbol_engine/multi_symbol_engine.h"
17 #include <utility>
18 #include "mindspore/core/symbolic_shape/utils.h"
19 #include "mindspore/core/symbolic_shape/int_symbol.h"
20 
21 namespace mindspore {
22 namespace graphkernel {
23 namespace symshape {
24 using mindspore::symshape::CloneAbstractIfSymbolExists;
25 
SaveInputParaMap(std::map<SymbolPtr,SymbolPtr> * input_para_map,const SymbolPtr & inp,const SymbolPtr & para)26 void MultiSymbolEngine::SaveInputParaMap(std::map<SymbolPtr, SymbolPtr> *input_para_map, const SymbolPtr &inp,
27                                          const SymbolPtr &para) {
28   if (inp->tid() != para->tid()) {
29     MS_LOG(WARNING) << "The type of input and para are not match, " << inp->type_name() << " vs " << para->type_name();
30     return;
31   }
32   (*input_para_map)[inp] = para;
33 }
34 
BuildShapeWithInputHint(const AbstractBasePtr & para_abs,const std::vector<ListSymbolPtr> & inputs,std::map<SymbolPtr,SymbolPtr> * input_para_map)35 ListSymbolPtr MultiSymbolEngine::BuildShapeWithInputHint(const AbstractBasePtr &para_abs,
36                                                          const std::vector<ListSymbolPtr> &inputs,
37                                                          std::map<SymbolPtr, SymbolPtr> *input_para_map) {
38   // only support TensorShape, that input is int-list symbol.
39   if (!para_abs->GetShape()->isa<abstract::TensorShape>()) {
40     return nullptr;
41   }
42   auto cur_shape = inputs.back();
43   for (auto &inp_para : *input_para_map) {
44     if (cur_shape->EqualsTo(inp_para.first)) {
45       return inp_para.second->as_sptr<ListSymbol>();
46     }
47   }
48   if (cur_shape->is_dyn_len()) {
49     return ListSymbol::Make();
50   }
51   SymbolPtrList para_shape;
52   para_shape.reserve(cur_shape->size());
53   for (auto &cur_item : cur_shape->symbols()) {
54     if (cur_item->is<IntSymbol>() && cur_item->HasData()) {
55       (void)para_shape.emplace_back(cur_item);
56       continue;
57     }
58     SymbolPtr para_item = nullptr;
59     for (auto &inp_para : *input_para_map) {
60       if (cur_item->EqualsTo(inp_para.first)) {
61         para_item = inp_para.second;
62         break;
63       }
64     }
65     if (para_item == nullptr) {
66       (void)para_shape.emplace_back(IntSymbol::Make());
67       SaveInputParaMap(input_para_map, cur_item, para_shape.back());
68     } else {
69       (void)para_shape.emplace_back(std::move(para_item));
70     }
71   }
72   return ListSymbol::Make(std::move(para_shape));
73 }
74 
75 // set symbol info for subgraph's parameters, according to the outer cnode's input symbol info.
GenInputSymbols(const CNodePtr & cnode,const FuncGraphPtr & sub_fg,size_t begin_input_index)76 void MultiSymbolEngine::GenInputSymbols(const CNodePtr &cnode, const FuncGraphPtr &sub_fg, size_t begin_input_index) {
77   std::vector<ListSymbolPtr> input_symbolic_shapes;
78   std::map<SymbolPtr, SymbolPtr> input_para_map;
79   input_symbolic_shapes.reserve(cnode->size());
80   for (size_t i = 0; i < sub_fg->parameters().size(); i++) {
81     auto inp_abs = cnode->input(i + begin_input_index)->abstract();
82     MS_EXCEPTION_IF_NULL(inp_abs);
83     auto para_abs = CloneAbstractIfSymbolExists(sub_fg->parameters()[i]);
84     MS_EXCEPTION_IF_NULL(para_abs);
85     (void)input_symbolic_shapes.emplace_back(inp_abs->GetSymbolicShape());
86     if (input_symbolic_shapes.back() != nullptr) {
87       auto s = BuildShapeWithInputHint(para_abs, input_symbolic_shapes, &input_para_map);
88       if (s == nullptr || !s->is<ListSymbol>()) {
89         s = para_abs->GetShape()->BuildSymbolicShape();
90       }
91       para_abs->SetSymbolicShape(s->as_sptr<ListSymbol>());
92       SaveInputParaMap(&input_para_map, input_symbolic_shapes.back(), s);
93     }
94     if (inp_abs->GetSymbolicValue() != nullptr) {
95       para_abs->SetSymbolicValue(mindspore::symshape::BuildSymbolicValue(para_abs));
96     }
97   }
98 }
99 
Build(const FuncGraphPtr & func_graph)100 void MultiSymbolEngine::Build(const FuncGraphPtr &func_graph) {
101   auto engine = std::make_shared<MultiSymbolEngine>(func_graph);
102   func_graph->set_symbol_engine(engine);
103   engine->PreBuild();
104   engine->BuildImpl();
105 }
106 
BuildSubEngine(const AnfNodePtr & node)107 void MultiSymbolEngine::BuildSubEngine(const AnfNodePtr &node) {
108   auto sub_fg = GetCNodeFuncGraph(node);
109   MS_EXCEPTION_IF_NULL(sub_fg);
110   auto engine = std::make_shared<MultiSymbolEngine>(sub_fg);
111   sub_fg->set_symbol_engine(engine);
112   engine->PreBuild();
113   auto main_engine = node->func_graph()->symbol_engine();
114   if (main_engine != nullptr && main_engine->isa<MultiSymbolEngine>()) {
115     main_engine->cast_ptr<MultiSymbolEngine>()->BuildSubgraphImpl(node->cast<CNodePtr>(), sub_fg, 1);
116   } else {
117     engine->BuildImpl();
118   }
119 }
120 
PreBuildQuerySubgraphDependStatus(const CNodePtr & cnode,const FuncGraphPtr & sub_fg,size_t begin_input_index)121 void MultiSymbolEngine::PreBuildQuerySubgraphDependStatus(const CNodePtr &cnode, const FuncGraphPtr &sub_fg,
122                                                           size_t begin_input_index) {
123   auto sub_engine = std::make_shared<MultiSymbolEngine>(sub_fg);
124   sub_fg->set_symbol_engine(sub_engine);
125   sub_engine->depend_status_map_[sub_fg->output()] = this->depend_status_map_[cnode];
126   sub_engine->PreBuild();
127 
128   for (auto &param : sub_fg->parameters()) {
129     auto &cnode_input_depend_status = this->depend_status_map_[cnode->input(begin_input_index++)];
130     auto depend_status = sub_engine->depend_status_map_[param];
131     if (depend_status.shape) {
132       cnode_input_depend_status.shape = true;
133     }
134     if (depend_status.value) {
135       cnode_input_depend_status.value = true;
136     }
137   }
138 }
139 
BuildSubgraphImpl(const CNodePtr & cnode,const FuncGraphPtr & sub_fg,size_t begin_input_index)140 void MultiSymbolEngine::BuildSubgraphImpl(const CNodePtr &cnode, const FuncGraphPtr &sub_fg, size_t begin_input_index) {
141   MS_EXCEPTION_IF_NULL(sub_fg);
142   MS_LOG(DEBUG) << "Build subgraph " << sub_fg->ToString() << " of node " << cnode->fullname_with_scope();
143 
144   MS_EXCEPTION_IF_NULL(sub_fg->symbol_engine());
145   auto sub_engine = sub_fg->symbol_engine()->cast_ptr<MultiSymbolEngine>();
146   MS_EXCEPTION_IF_NULL(sub_engine);
147   GenInputSymbols(cnode, sub_fg, begin_input_index);
148 
149   sub_engine->BuildImpl();
150 
151   auto out_abs = sub_fg->output()->abstract();
152   MS_EXCEPTION_IF_NULL(out_abs);
153   auto cnode_abs = CloneAbstractIfSymbolExists(cnode);
154   MS_EXCEPTION_IF_NULL(cnode_abs);
155   cnode_abs->SetSymbolicShape(out_abs->GetSymbolicShape());
156   cnode_abs->SetSymbolicValue(out_abs->GetSymbolicValue());
157 }
158 }  // namespace symshape
159 }  // namespace graphkernel
160 }  // namespace mindspore
161