1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/common/graph_kernel/symbol_engine/multi_symbol_engine.h"
17 #include <utility>
18 #include "mindspore/core/symbolic_shape/utils.h"
19 #include "mindspore/core/symbolic_shape/int_symbol.h"
20
21 namespace mindspore {
22 namespace graphkernel {
23 namespace symshape {
24 using mindspore::symshape::CloneAbstractIfSymbolExists;
25
SaveInputParaMap(std::map<SymbolPtr,SymbolPtr> * input_para_map,const SymbolPtr & inp,const SymbolPtr & para)26 void MultiSymbolEngine::SaveInputParaMap(std::map<SymbolPtr, SymbolPtr> *input_para_map, const SymbolPtr &inp,
27 const SymbolPtr ¶) {
28 if (inp->tid() != para->tid()) {
29 MS_LOG(WARNING) << "The type of input and para are not match, " << inp->type_name() << " vs " << para->type_name();
30 return;
31 }
32 (*input_para_map)[inp] = para;
33 }
34
BuildShapeWithInputHint(const AbstractBasePtr & para_abs,const std::vector<ListSymbolPtr> & inputs,std::map<SymbolPtr,SymbolPtr> * input_para_map)35 ListSymbolPtr MultiSymbolEngine::BuildShapeWithInputHint(const AbstractBasePtr ¶_abs,
36 const std::vector<ListSymbolPtr> &inputs,
37 std::map<SymbolPtr, SymbolPtr> *input_para_map) {
38 // only support TensorShape, that input is int-list symbol.
39 if (!para_abs->GetShape()->isa<abstract::TensorShape>()) {
40 return nullptr;
41 }
42 auto cur_shape = inputs.back();
43 for (auto &inp_para : *input_para_map) {
44 if (cur_shape->EqualsTo(inp_para.first)) {
45 return inp_para.second->as_sptr<ListSymbol>();
46 }
47 }
48 if (cur_shape->is_dyn_len()) {
49 return ListSymbol::Make();
50 }
51 SymbolPtrList para_shape;
52 para_shape.reserve(cur_shape->size());
53 for (auto &cur_item : cur_shape->symbols()) {
54 if (cur_item->is<IntSymbol>() && cur_item->HasData()) {
55 (void)para_shape.emplace_back(cur_item);
56 continue;
57 }
58 SymbolPtr para_item = nullptr;
59 for (auto &inp_para : *input_para_map) {
60 if (cur_item->EqualsTo(inp_para.first)) {
61 para_item = inp_para.second;
62 break;
63 }
64 }
65 if (para_item == nullptr) {
66 (void)para_shape.emplace_back(IntSymbol::Make());
67 SaveInputParaMap(input_para_map, cur_item, para_shape.back());
68 } else {
69 (void)para_shape.emplace_back(std::move(para_item));
70 }
71 }
72 return ListSymbol::Make(std::move(para_shape));
73 }
74
75 // set symbol info for subgraph's parameters, according to the outer cnode's input symbol info.
GenInputSymbols(const CNodePtr & cnode,const FuncGraphPtr & sub_fg,size_t begin_input_index)76 void MultiSymbolEngine::GenInputSymbols(const CNodePtr &cnode, const FuncGraphPtr &sub_fg, size_t begin_input_index) {
77 std::vector<ListSymbolPtr> input_symbolic_shapes;
78 std::map<SymbolPtr, SymbolPtr> input_para_map;
79 input_symbolic_shapes.reserve(cnode->size());
80 for (size_t i = 0; i < sub_fg->parameters().size(); i++) {
81 auto inp_abs = cnode->input(i + begin_input_index)->abstract();
82 MS_EXCEPTION_IF_NULL(inp_abs);
83 auto para_abs = CloneAbstractIfSymbolExists(sub_fg->parameters()[i]);
84 MS_EXCEPTION_IF_NULL(para_abs);
85 (void)input_symbolic_shapes.emplace_back(inp_abs->GetSymbolicShape());
86 if (input_symbolic_shapes.back() != nullptr) {
87 auto s = BuildShapeWithInputHint(para_abs, input_symbolic_shapes, &input_para_map);
88 if (s == nullptr || !s->is<ListSymbol>()) {
89 s = para_abs->GetShape()->BuildSymbolicShape();
90 }
91 para_abs->SetSymbolicShape(s->as_sptr<ListSymbol>());
92 SaveInputParaMap(&input_para_map, input_symbolic_shapes.back(), s);
93 }
94 if (inp_abs->GetSymbolicValue() != nullptr) {
95 para_abs->SetSymbolicValue(mindspore::symshape::BuildSymbolicValue(para_abs));
96 }
97 }
98 }
99
Build(const FuncGraphPtr & func_graph)100 void MultiSymbolEngine::Build(const FuncGraphPtr &func_graph) {
101 auto engine = std::make_shared<MultiSymbolEngine>(func_graph);
102 func_graph->set_symbol_engine(engine);
103 engine->PreBuild();
104 engine->BuildImpl();
105 }
106
BuildSubEngine(const AnfNodePtr & node)107 void MultiSymbolEngine::BuildSubEngine(const AnfNodePtr &node) {
108 auto sub_fg = GetCNodeFuncGraph(node);
109 MS_EXCEPTION_IF_NULL(sub_fg);
110 auto engine = std::make_shared<MultiSymbolEngine>(sub_fg);
111 sub_fg->set_symbol_engine(engine);
112 engine->PreBuild();
113 auto main_engine = node->func_graph()->symbol_engine();
114 if (main_engine != nullptr && main_engine->isa<MultiSymbolEngine>()) {
115 main_engine->cast_ptr<MultiSymbolEngine>()->BuildSubgraphImpl(node->cast<CNodePtr>(), sub_fg, 1);
116 } else {
117 engine->BuildImpl();
118 }
119 }
120
PreBuildQuerySubgraphDependStatus(const CNodePtr & cnode,const FuncGraphPtr & sub_fg,size_t begin_input_index)121 void MultiSymbolEngine::PreBuildQuerySubgraphDependStatus(const CNodePtr &cnode, const FuncGraphPtr &sub_fg,
122 size_t begin_input_index) {
123 auto sub_engine = std::make_shared<MultiSymbolEngine>(sub_fg);
124 sub_fg->set_symbol_engine(sub_engine);
125 sub_engine->depend_status_map_[sub_fg->output()] = this->depend_status_map_[cnode];
126 sub_engine->PreBuild();
127
128 for (auto ¶m : sub_fg->parameters()) {
129 auto &cnode_input_depend_status = this->depend_status_map_[cnode->input(begin_input_index++)];
130 auto depend_status = sub_engine->depend_status_map_[param];
131 if (depend_status.shape) {
132 cnode_input_depend_status.shape = true;
133 }
134 if (depend_status.value) {
135 cnode_input_depend_status.value = true;
136 }
137 }
138 }
139
BuildSubgraphImpl(const CNodePtr & cnode,const FuncGraphPtr & sub_fg,size_t begin_input_index)140 void MultiSymbolEngine::BuildSubgraphImpl(const CNodePtr &cnode, const FuncGraphPtr &sub_fg, size_t begin_input_index) {
141 MS_EXCEPTION_IF_NULL(sub_fg);
142 MS_LOG(DEBUG) << "Build subgraph " << sub_fg->ToString() << " of node " << cnode->fullname_with_scope();
143
144 MS_EXCEPTION_IF_NULL(sub_fg->symbol_engine());
145 auto sub_engine = sub_fg->symbol_engine()->cast_ptr<MultiSymbolEngine>();
146 MS_EXCEPTION_IF_NULL(sub_engine);
147 GenInputSymbols(cnode, sub_fg, begin_input_index);
148
149 sub_engine->BuildImpl();
150
151 auto out_abs = sub_fg->output()->abstract();
152 MS_EXCEPTION_IF_NULL(out_abs);
153 auto cnode_abs = CloneAbstractIfSymbolExists(cnode);
154 MS_EXCEPTION_IF_NULL(cnode_abs);
155 cnode_abs->SetSymbolicShape(out_abs->GetSymbolicShape());
156 cnode_abs->SetSymbolicValue(out_abs->GetSymbolicValue());
157 }
158 } // namespace symshape
159 } // namespace graphkernel
160 } // namespace mindspore
161