1 /**
2 * Copyright 2023-2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include "include/common/symbol_engine/symbol_engine_impl.h"
#include <algorithm>
#include <functional>
#include <iterator>
#include <map>
#include <memory>
#include <numeric>
#include <ostream>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ir/graph_utils.h"
#include "ops/array_ops.h"
#include "ops/framework_ops.h"
#include "ops/sequence_ops.h"
#include "mindspore/core/ops/symbol_ops_impl/switch.h"
#include "mindspore/core/ops/symbol_ops_impl/j_op.h"
#include "utils/check_convert_utils.h"
#include "utils/anf_utils.h"
#include "mindspore/core/symbolic_shape/utils.h"
#include "mindspore/core/symbolic_shape/operation_builder.h"
31
32 namespace mindspore {
33 namespace symshape {
GetCNodesOfFuncGraph(const FuncGraphPtr & fg)34 AnfNodePtrList GetCNodesOfFuncGraph(const FuncGraphPtr &fg) {
35 bool has_node_in_other_graph = false;
36 auto nodes = TopoSort(fg->output(), SuccDeeperSimple, [&has_node_in_other_graph, &fg](const AnfNodePtr &node) {
37 if (!node->isa<CNode>()) {
38 if (GetValuePtr<FuncGraph>(node) != nullptr) {
39 // some nodes of this graph can be linked in other graph.
40 has_node_in_other_graph = true;
41 return FOLLOW;
42 }
43 return EXCLUDE;
44 }
45 if (node->func_graph() != fg) {
46 has_node_in_other_graph = true;
47 }
48 return FOLLOW;
49 });
50 // at frontend, a node may directly links to other node in other graph.
51 if (has_node_in_other_graph) {
52 (void)nodes.erase(
53 std::remove_if(nodes.begin(), nodes.end(),
54 [&fg](const AnfNodePtr &node) { return !node->isa<CNode>() || node->func_graph() != fg; }),
55 nodes.end());
56 }
57 return nodes;
58 }
59
GetFuncGraphFromCNode(const CNodePtr & cnode)60 std::pair<FuncGraphPtr, size_t> GetFuncGraphFromCNode(const CNodePtr &cnode) {
61 auto sub_fg = GetCNodeFuncGraph(cnode);
62 size_t begin_index = kIndex1;
63 if (sub_fg == nullptr && IsPrimitiveCNode(cnode, prim::kPrimPartial)) {
64 auto vnode = cnode->input(kIndex1)->cast<ValueNodePtr>();
65 MS_EXCEPTION_IF_NULL(vnode);
66 sub_fg = vnode->value()->cast<FuncGraphPtr>();
67 MS_EXCEPTION_IF_NULL(sub_fg);
68 begin_index = kIndex2;
69 }
70 if (sub_fg != nullptr && sub_fg->parameters().size() + begin_index < cnode->size()) {
71 MS_LOG(INTERNAL_EXCEPTION) << "For graph " << sub_fg->ToString() << ", the parameter size "
72 << sub_fg->parameters().size() << " is less than cnode input num "
73 << cnode->size() - begin_index;
74 }
75 return std::make_pair(sub_fg, begin_index);
76 }
77
78 class ControlFlowJoinNode : public SpecialCNodeHelper {
79 public:
80 using SpecialCNodeHelper::SpecialCNodeHelper;
Match(const CNodePtr & cnode)81 static bool Match(const CNodePtr &cnode) { return IsPrimitiveCNode(cnode->input(0), prim::kPrimSwitch); }
SetDependStatus(std::map<AnfNodePtr,DependStatus> * depend_status_map)82 void SetDependStatus(std::map<AnfNodePtr, DependStatus> *depend_status_map) override {
83 auto input0 = input();
84 (*depend_status_map)[input0->input(kIndex1)].value = true;
85 SetFuncGraphDepend(input0->input(kIndex2));
86 SetFuncGraphDepend(input0->input(kIndex3));
87 }
ExtractInputs()88 std::pair<PrimitivePtr, AbstractBasePtrList> ExtractInputs() override {
89 auto prim = std::make_shared<Primitive>(ops::kControlFlowJoin);
90 AbstractBasePtrList inputs;
91 auto input0 = input();
92 (void)inputs.emplace_back(input0->input(kIndex1)->abstract());
93 (void)inputs.emplace_back(GetFuncGraphOutAbs(input0->input(kIndex2)));
94 (void)inputs.emplace_back(GetFuncGraphOutAbs(input0->input(kIndex3)));
95 return std::make_pair(std::move(prim), std::move(inputs));
96 }
97
98 protected:
input() const99 CNodePtr input() const {
100 auto input0 = cnode_->input(0)->cast<CNodePtr>();
101 MS_EXCEPTION_IF_NULL(input0);
102 return input0;
103 }
symbol_engine() const104 SymbolEngineImplPtr symbol_engine() const {
105 auto symbol_engine = cnode_->func_graph()->symbol_engine();
106 MS_EXCEPTION_IF_NULL(symbol_engine);
107 auto symbol_engine_impl = symbol_engine->cast<SymbolEngineImplPtr>();
108 MS_EXCEPTION_IF_NULL(symbol_engine_impl);
109 return symbol_engine_impl;
110 }
SetFuncGraphDepend(const AnfNodePtr & node) const111 void SetFuncGraphDepend(const AnfNodePtr &node) const {
112 auto fg = GetValueNode<FuncGraphPtr>(node);
113 if (fg != nullptr) {
114 symbol_engine()->PreBuildQuerySubgraphDependStatus(cnode_, fg, kIndex1);
115 }
116 }
117
GetFuncGraphOutAbs(const AnfNodePtr & node) const118 AbstractBasePtr GetFuncGraphOutAbs(const AnfNodePtr &node) const {
119 if (IsPrimitiveCNode(node, prim::kPrimPartial)) {
120 return GetFuncGraphFromCNode(node->cast<CNodePtr>()).first->output()->abstract();
121 }
122 // the graph with Partial is build symbols ahead, build the pure graph (without Partial) in Switch.
123 auto fg = GetValueNode<FuncGraphPtr>(node);
124 if (fg == nullptr) {
125 MS_EXCEPTION_IF_NULL(node->abstract());
126 return node->abstract();
127 }
128 symbol_engine()->BuildSubgraphImpl(cnode_, fg, kIndex1);
129 return fg->output()->abstract();
130 }
131 };
132
133 class JFuncCaller : public SpecialCNodeHelper {
134 public:
135 /// \brief The call node of PrimJ:
136 ///
137 /// %0 = J(@fg) // primitive "J"
138 /// %1 = %0(inp1, inp2, ...) // the node output a tuple of "(tensor, Func)"
139 /// %2 = TupleGetItem(%1, 1) // get the output "Func"
140 /// %3 = %2(loss_scale) // call the "Func".
141 ///
142 /// this pattern match the "%3", and the output shape is same as "inp1, inp2, ...".
JFuncCaller(const CNodePtr & cnode)143 explicit JFuncCaller(const CNodePtr &cnode) : SpecialCNodeHelper(cnode) {
144 auto getitem1 = cnode->input(kIndex0)->cast<CNodePtr>();
145 MS_EXCEPTION_IF_NULL(getitem1);
146 input_ = getitem1->input(kIndex1)->cast<CNodePtr>();
147 MS_EXCEPTION_IF_NULL(input_);
148 }
149 ~JFuncCaller() override = default;
Match(const CNodePtr & cnode)150 static bool Match(const CNodePtr &cnode) {
151 auto getitem1 = cnode->input(kIndex0)->cast<CNodePtr>();
152 if (getitem1 == nullptr || !IsPrimitiveCNode(getitem1, prim::kPrimTupleGetItem)) {
153 return false;
154 }
155 if (GetValue<int64_t>(GetValueNode(getitem1->input(kIndex2))) != 1) {
156 return false;
157 }
158 auto callj = getitem1->input(kIndex1)->cast<CNodePtr>();
159 return callj != nullptr && IsPrimitiveCNode(callj->input(kIndex0), prim::kPrimJ);
160 }
SetDependStatus(std::map<AnfNodePtr,DependStatus> * depend_status_map)161 void SetDependStatus(std::map<AnfNodePtr, DependStatus> *depend_status_map) override {
162 for (size_t i = 1; i < input_->size(); i++) {
163 (*depend_status_map)[input_->input(i)] = (*depend_status_map)[cnode_];
164 }
165 }
ExtractInputs()166 std::pair<PrimitivePtr, AbstractBasePtrList> ExtractInputs() override {
167 auto prim = std::make_shared<Primitive>(ops::kJFuncCaller);
168 AbstractBasePtrList inputs;
169 inputs.reserve(input_->size());
170 (void)std::transform(input_->inputs().begin(), input_->inputs().end(), std::back_inserter(inputs),
171 [](const AnfNodePtr &node) { return node->abstract(); });
172 return std::make_pair(std::move(prim), std::move(inputs));
173 }
174
175 protected:
176 CNodePtr input_{nullptr};
177 };
178
Build(const FuncGraphPtr & func_graph)179 SymbolEngineImplPtr SymbolEngineImpl::Build(const FuncGraphPtr &func_graph) {
180 if (func_graph->symbol_engine() != nullptr) {
181 CleanSymbols(func_graph);
182 }
183 auto engine = std::make_shared<SymbolEngineImpl>(func_graph);
184 func_graph->set_symbol_engine(engine);
185 engine->PreBuild();
186 engine->BuildImpl();
187 return engine;
188 }
189
BuildNodesSymbol(const FuncGraphPtr & fg,const AnfNodePtrList & cnodes)190 void SymbolEngineImpl::BuildNodesSymbol(const FuncGraphPtr &fg, const AnfNodePtrList &cnodes) {
191 for (auto &node : cnodes) {
192 auto cnode = node->cast<CNodePtr>();
193 MS_EXCEPTION_IF_NULL(cnode);
194 if (auto fg_with_index = GetFuncGraphFromCNode(cnode); fg_with_index.first != nullptr) {
195 // "call" or "Partial" node
196 BuildSubgraphImpl(cnode, fg_with_index.first, fg_with_index.second);
197 } else {
198 BuildCNodeSymbol(cnode);
199 }
200 }
201 // the funcgraph can be empty or only return a ValueNode.
202 if (!cnodes.empty()) {
203 return;
204 }
205 auto node = fg->output();
206 if (node->isa<ValueNode>()) {
207 auto depend_status = depend_status_map_[node];
208 auto node_abs = CloneAbstractIfSymbolExists(node);
209 MS_EXCEPTION_IF_NULL(node_abs);
210 if (depend_status.shape) {
211 auto sym_shape = node_abs->GetShape()->BuildSymbolicShape();
212 MS_LOG(DEBUG) << "Set shape for node: " << node->DebugString() << ". symbol: " << sym_shape->ToString();
213 node_abs->SetSymbolicShape(sym_shape);
214 }
215 if (depend_status.value) {
216 auto sym_value = BuildSymbolicValue(node_abs);
217 MS_LOG(DEBUG) << "Set value for node: " << node->DebugString() << ". symbol: " << sym_value->ToString();
218 node_abs->SetSymbolicValue(sym_value);
219 }
220 }
221 }
222
PreBuild()223 void SymbolEngineImpl::PreBuild() {
224 auto func_graph = func_graph_.lock();
225 MS_EXCEPTION_IF_NULL(func_graph);
226 cnodes_ = GetCNodesOfFuncGraph(func_graph);
227 visited_graph_[func_graph.get()] = 1;
228 PreBuildQueryDependStatus(cnodes_);
229 visited_graph_.clear();
230 }
231
BuildImpl()232 void SymbolEngineImpl::BuildImpl() {
233 auto func_graph = func_graph_.lock();
234 MS_EXCEPTION_IF_NULL(func_graph);
235 MS_LOG(DEBUG) << "Build " << ToString() << " with graph " << func_graph->ToString();
236 emitter_ = std::make_unique<OperationEmitter>(&ops_);
237 visited_graph_[func_graph.get()] = 1;
238 BuildNodesSymbol(func_graph, cnodes_);
239 emitter_->Clean();
240 visited_graph_.clear();
241 generalized_shape_.clear();
242 generalized_value_.clear();
243 }
244
PreBuildSpecialNode(const CNodePtr & cnode)245 void SymbolEngineImpl::PreBuildSpecialNode(const CNodePtr &cnode) {
246 std::shared_ptr<SpecialCNodeHelper> helper = nullptr;
247 if (ControlFlowJoinNode::Match(cnode)) {
248 helper = std::make_shared<ControlFlowJoinNode>(cnode);
249 } else if (JFuncCaller::Match(cnode)) {
250 helper = std::make_shared<JFuncCaller>(cnode);
251 } else {
252 MS_LOG(DEBUG) << "The special node " << cnode->fullname_with_scope() << " is not supported.";
253 return;
254 }
255 special_cnodes_[cnode] = helper;
256 helper->SetDependStatus(&depend_status_map_);
257 }
258
SetInputDependStatus(const CNodePtr & cnode,bool depend_value)259 void SymbolEngineImpl::SetInputDependStatus(const CNodePtr &cnode, bool depend_value) {
260 auto prim = GetCNodePrimitive(cnode);
261 MS_EXCEPTION_IF_NULL(prim);
262 size_t input_num = cnode->size() - 1;
263 auto depends = depend_value ? GetValueDepends(prim, input_num) : GetShapeDepends(prim, input_num);
264 for (size_t i = 0; i < depends.size(); i++) {
265 if (depends[i] == DependOn::kValue) {
266 depend_status_map_[cnode->input(i + 1)].value = true;
267 } else if (depends[i] == DependOn::kShape) {
268 depend_status_map_[cnode->input(i + 1)].shape = true;
269 }
270 }
271 }
272
PreBuildQueryDependStatus(const AnfNodePtrList & cnodes)273 void SymbolEngineImpl::PreBuildQueryDependStatus(const AnfNodePtrList &cnodes) {
274 for (auto iter = cnodes.rbegin(); iter != cnodes.rend(); ++iter) {
275 auto cnode = (*iter)->cast<CNodePtr>();
276 MS_EXCEPTION_IF_NULL(cnode);
277 auto &depend_status = depend_status_map_[cnode];
278 if (!depend_status.value && !depend_status.shape) {
279 // build symbolic shape for the node even though it's not depended by any nodes.
280 depend_status.shape = true;
281 }
282 MS_LOG(DEBUG) << "The depend status of " << cnode->DebugString() << "(" << cnode->fullname_with_scope()
283 << "): shape-depend=" << depend_status.shape << ", value-depend=" << depend_status.value;
284 if (cnode->input(0)->isa<CNode>()) {
285 PreBuildSpecialNode(cnode);
286 continue;
287 }
288 // the "call" node or Partial node.
289 auto subfg_with_index = GetFuncGraphFromCNode(cnode);
290 if (subfg_with_index.first != nullptr) {
291 PreBuildQuerySubgraphDependStatus(cnode, subfg_with_index.first, subfg_with_index.second);
292 continue;
293 }
294 // the normal CNode, check the depend status from operation builder info.
295 if (!OperationBuilderInfoRegistry::HasOp(AnfUtils::GetCNodeName(cnode))) {
296 continue;
297 }
298 if (depend_status.shape) {
299 SetInputDependStatus(cnode, false);
300 }
301 if (depend_status.value) {
302 SetInputDependStatus(cnode, true);
303 }
304 }
305 }
306
PreBuildQuerySubgraphDependStatus(const CNodePtr & cnode,const FuncGraphPtr & sub_fg,size_t begin_input_index)307 void SymbolEngineImpl::PreBuildQuerySubgraphDependStatus(const CNodePtr &cnode, const FuncGraphPtr &sub_fg,
308 size_t begin_input_index) {
309 if (++visited_graph_[sub_fg.get()] > 1) {
310 return;
311 }
312 sub_fg->set_symbol_engine(shared_from_base<SymbolEngine>());
313 depend_status_map_[sub_fg->output()] = depend_status_map_[cnode];
314 PreBuildQueryDependStatus(GetCNodesOfFuncGraph(sub_fg));
315 for (auto ¶m : sub_fg->parameters()) {
316 if (begin_input_index >= cnode->size()) {
317 break;
318 }
319 auto &cnode_input_depend_status = depend_status_map_[cnode->input(begin_input_index++)];
320 auto depend_status = depend_status_map_[param];
321 if (depend_status.shape) {
322 cnode_input_depend_status.shape = true;
323 }
324 if (depend_status.value) {
325 cnode_input_depend_status.value = true;
326 }
327 }
328 }
329
Infer(const AbstractBasePtrList & inputs)330 bool SymbolEngineImpl::Infer(const AbstractBasePtrList &inputs) {
331 if (!support_infer_) {
332 MS_LOG(WARNING) << "The " << ToString() << " does not support infer";
333 return false;
334 }
335 MS_LOG(DEBUG) << "Infer " << ToString() << " with inputs: " << inputs;
336 auto fg = func_graph_.lock();
337 MS_EXCEPTION_IF_NULL(fg);
338 auto ¶ms = fg->parameters();
339 // There may be params like UpdateStates, which won't contribute to infer
340 if (params.size() < inputs.size()) {
341 MS_LOG(EXCEPTION) << "The parameter size should be equal to or larger than inputs size, but got " << params.size()
342 << " vs " << inputs.size();
343 }
344 for (size_t i = 0; i < inputs.size(); i++) {
345 if (auto shape = params[i]->abstract()->GetSymbolicShape(); shape != nullptr) {
346 auto cur_shape = inputs[i]->GetShape()->BuildSymbolicShape();
347 MS_EXCEPTION_IF_NULL(cur_shape);
348 MS_LOG(DEBUG) << "Update shape for input[" << i << "]: " << cur_shape->ToRawString();
349 shape->Update(cur_shape);
350 }
351 if (auto value = params[i]->abstract()->GetSymbolicValue(); value != nullptr && value->CanUpdate()) {
352 auto cur_value = BuildSymbolicValue(inputs[i]);
353 MS_EXCEPTION_IF_NULL(cur_value);
354 MS_LOG(DEBUG) << "Update value for input[" << i << "]: " << cur_value->ToRawString();
355 value->Update(cur_value);
356 }
357 }
358 for (auto &op : ops_) {
359 op->Run();
360 }
361 return true;
362 }
363
IsDependValue(const AnfNodePtr & node)364 bool SymbolEngineImpl::IsDependValue(const AnfNodePtr &node) {
365 if (depend_status_map_.find(node) != depend_status_map_.end()) {
366 return depend_status_map_[node].value;
367 }
368 return false;
369 }
370
IsDependShape(const AnfNodePtr & node)371 bool SymbolEngineImpl::IsDependShape(const AnfNodePtr &node) {
372 if (depend_status_map_.find(node) != depend_status_map_.end()) {
373 return depend_status_map_[node].shape;
374 }
375 return false;
376 }
377
QuerySymbolExprHelper(const SymbolPtr & s,const std::unordered_map<std::string,std::string> & symbol_expr_map)378 std::string SymbolEngineImpl::QuerySymbolExprHelper(
379 const SymbolPtr &s, const std::unordered_map<std::string, std::string> &symbol_expr_map) {
380 auto raw_string = s->ToRawString();
381 if (s->is<ListSymbol>() || s->HasData()) {
382 return raw_string;
383 }
384 if (symbol_expr_map.find(raw_string) != symbol_expr_map.end()) {
385 return raw_string;
386 }
387 auto operation = s->operation();
388 if (operation == nullptr) {
389 return raw_string;
390 }
391 std::ostringstream oss;
392 oss << operation->name() << "(";
393 bool first = true;
394 for (auto &input : operation->inputs()) {
395 if (first == true) {
396 first = false;
397 } else {
398 oss << ", ";
399 }
400 oss << QuerySymbolExprHelper(input, symbol_expr_map);
401 }
402 oss << ")";
403 return oss.str();
404 }
405
QuerySymbolExpr(const AnfNodePtr & node,std::unordered_map<std::string,std::string> * symbol_expr_map)406 void SymbolEngineImpl::QuerySymbolExpr(const AnfNodePtr &node,
407 std::unordered_map<std::string, std::string> *symbol_expr_map) {
408 // todo, use SymbolVisitor to export symbol expr.
409 auto symbolic_shape = node->abstract()->GetSymbolicShape();
410 if (symbolic_shape == nullptr) {
411 return;
412 }
413 for (const auto &symbol : symbolic_shape->symbols()) {
414 auto name = symbol->ToRawString();
415 if (name[0] == 's' && symbol_expr_map->find(name) == symbol_expr_map->end()) {
416 auto expr = QuerySymbolExprHelper(symbol, *symbol_expr_map);
417 (*symbol_expr_map)[name] = expr;
418 }
419 }
420 }
421
GeneralizeParamShape(const AnfNodePtr & param,const AbstractBasePtr & input_abs)422 bool SymbolEngineImpl::GeneralizeParamShape(const AnfNodePtr ¶m, const AbstractBasePtr &input_abs) {
423 if (generalized_shape_.count(param) > 0) {
424 return false;
425 }
426 auto param_abs = param->abstract();
427 MS_EXCEPTION_IF_NULL(param_abs);
428 if (param_abs->GetSymbolicShape() == nullptr || input_abs->GetSymbolicShape() == nullptr) {
429 return false;
430 }
431 auto param_shape = param_abs->GetSymbolicShape();
432 auto input_shape = input_abs->GetSymbolicShape();
433 if (param_shape->EqualsTo(input_shape)) {
434 return false;
435 }
436 bool build_again = false;
437 bool gen_all = false;
438 auto NewInt = [&build_again]() -> IntSymbolPtr {
439 build_again = true;
440 return IntSymbol::Make();
441 };
442 std::function<SymbolPtrList(const SymbolPtrList &, const SymbolPtrList &)> process;
443 process = [&NewInt, &gen_all, &process](const SymbolPtrList &s1, const SymbolPtrList &s2) {
444 SymbolPtrList ret;
445 if (s1.size() != s2.size()) {
446 gen_all = true;
447 return ret;
448 }
449 ret = s1;
450 for (size_t i = 0; i < s1.size(); i++) {
451 if (s1[i]->EqualsTo(s2[i])) {
452 continue;
453 }
454 if (s1[i]->is<ListSymbol>()) {
455 ret[i] = ListSymbol::Make(process(s1[i]->as<ListSymbol>()->symbols(), s2[i]->as<ListSymbol>()->symbols()));
456 continue;
457 }
458 auto v1 = s1[i]->as<IntSymbol>();
459 auto v2 = s2[i]->as<IntSymbol>();
460 if (v2->is_const()) {
461 if (v1->is_const()) {
462 ret[i] = NewInt();
463 } else {
464 auto d1 = v1->divisor();
465 auto r1 = v1->remainder();
466 // if "d1 * v1 + r1 == v2", that's v2 match the condition of v1.
467 if ((v2->value() - r1 + d1) % d1 != 0) {
468 ret[i] = NewInt();
469 }
470 }
471 } else if (v1->is_const()) {
472 ret[i] = NewInt();
473 } else {
474 // two symbols are variable
475 auto d1 = v1->divisor();
476 auto r1 = v1->remainder();
477 auto d2 = v2->divisor();
478 auto r2 = v2->remainder();
479 if (r1 == r2) {
480 if (d2 % d1 != 0) {
481 auto t = NewInt();
482 t->SetDivisorRemainder(std::gcd(d1, d2), r1);
483 ret[i] = t;
484 }
485 } else {
486 ret[i] = NewInt();
487 }
488 }
489 }
490 return ret;
491 };
492 auto ret = process(param_shape->symbols(), input_shape->symbols());
493 if (gen_all) {
494 (void)generalized_shape_.insert(param);
495 param_abs->SetSymbolicShape(param_abs->GetShape()->BuildSymbolicShape());
496 return true;
497 }
498 return build_again;
499 }
500
GeneralizeParamValue(const AnfNodePtr & param,const AbstractBasePtr & input_abs)501 bool SymbolEngineImpl::GeneralizeParamValue(const AnfNodePtr ¶m, const AbstractBasePtr &input_abs) {
502 if (generalized_value_.count(param) > 0) {
503 return false;
504 }
505 auto param_abs = param->abstract();
506 MS_EXCEPTION_IF_NULL(param_abs);
507 if (param_abs->GetSymbolicValue() == nullptr || input_abs->GetSymbolicValue() == nullptr) {
508 return false;
509 }
510 auto param_value = param_abs->GetSymbolicValue();
511 auto input_value = input_abs->GetSymbolicValue();
512 if (param_value->EqualsTo(input_value)) {
513 return false;
514 }
515 param_abs->SetSymbolicValue(BuildSymbolicValue(param_abs));
516 (void)generalized_value_.insert(param);
517 return true;
518 }
519
SetParamSymbols(const CNodePtr & cnode,const FuncGraphPtr & sub_fg,size_t begin_input_index,size_t visit_cnt)520 bool SymbolEngineImpl::SetParamSymbols(const CNodePtr &cnode, const FuncGraphPtr &sub_fg, size_t begin_input_index,
521 size_t visit_cnt) {
522 bool build_again = false;
523 const size_t max_visit_cnt = 5; // to avoid unexplained dead loop
524 for (size_t i = begin_input_index; i < cnode->size(); i++) {
525 auto inp = cnode->input(i);
526 auto input_abs = inp->abstract();
527 MS_EXCEPTION_IF_NULL(input_abs);
528 if (IsDependShape(inp) && input_abs->GetSymbolicShape() == nullptr) {
529 input_abs->SetSymbolicShape(input_abs->GetShape()->BuildSymbolicShape());
530 }
531 if (IsDependValue(inp) && input_abs->GetSymbolicValue() == nullptr) {
532 input_abs->SetSymbolicValue(BuildSymbolicValue(input_abs));
533 }
534 auto param = sub_fg->parameters()[i - begin_input_index];
535 if (visit_cnt == 1) {
536 auto param_abs = CloneAbstractIfSymbolExists(param);
537 MS_EXCEPTION_IF_NULL(param_abs);
538 param_abs->SetSymbolicShape(input_abs->GetSymbolicShape());
539 param_abs->SetSymbolicValue(input_abs->GetSymbolicValue());
540 } else if (visit_cnt <= max_visit_cnt) {
541 build_again = GeneralizeParamShape(param, input_abs) || build_again;
542 build_again = GeneralizeParamValue(param, input_abs) || build_again;
543 }
544 }
545 return build_again;
546 }
547
BuildSubgraphImpl(const CNodePtr & cnode,const FuncGraphPtr & sub_fg,size_t begin_input_index)548 void SymbolEngineImpl::BuildSubgraphImpl(const CNodePtr &cnode, const FuncGraphPtr &sub_fg, size_t begin_input_index) {
549 MS_EXCEPTION_IF_NULL(sub_fg);
550 auto visit_cnt = ++visited_graph_[sub_fg.get()];
551 MS_LOG(DEBUG) << "Build subgraph " << sub_fg->ToString() << " of node " << cnode->fullname_with_scope()
552 << ". visit count: " << visit_cnt;
553 bool build_again = SetParamSymbols(cnode, sub_fg, begin_input_index, visit_cnt);
554 if (visit_cnt > 1) {
555 if (!build_again) {
556 MS_LOG(DEBUG) << "The inputs of graph " << sub_fg->ToString() << " are equal to last building, don't build again";
557 return;
558 }
559 support_infer_ = false;
560 }
561
562 BuildNodesSymbol(sub_fg, GetCNodesOfFuncGraph(sub_fg));
563 // only set the abstract for "call" node.
564 if (IsValueNode<FuncGraph>(cnode->input(0))) {
565 auto out_abs = sub_fg->output()->abstract();
566 MS_EXCEPTION_IF_NULL(out_abs);
567 auto cnode_abs = CloneAbstractIfSymbolExists(cnode);
568 MS_EXCEPTION_IF_NULL(cnode_abs);
569 cnode_abs->SetSymbolicShape(out_abs->GetSymbolicShape());
570 cnode_abs->SetSymbolicValue(out_abs->GetSymbolicValue());
571 }
572 }
573
BuildCNodeSymbolicShape(OperationBuilder * builder,const PrimitivePtr & prim,const AbstractBasePtrList & inputs,const AbstractBasePtr & abstract,const CNodePtr & cnode)574 SymbolPtr SymbolEngineImpl::BuildCNodeSymbolicShape(OperationBuilder *builder, const PrimitivePtr &prim,
575 const AbstractBasePtrList &inputs, const AbstractBasePtr &abstract,
576 const CNodePtr &cnode) {
577 auto digital_shape = abstract->GetShape();
578 MS_EXCEPTION_IF_NULL(digital_shape);
579 if (common::GetEnv("MS_DEV_FORCE_BUILD_SYMBOL") != "on" && !digital_shape->IsDynamic()) {
580 auto static_shape = digital_shape->BuildSymbolicShape();
581 MS_LOG(DEBUG) << "Node " << cnode->fullname_with_scope() << " is static shape: " << digital_shape->ToString();
582 return static_shape;
583 }
584 if (builder == nullptr) {
585 support_infer_ = false;
586 MS_LOG(DEBUG) << "Node " << cnode->fullname_with_scope() << " does not support BuildShape, builder not found.";
587 return digital_shape->BuildSymbolicShape();
588 }
589 SymbolPtr s = nullptr;
590 try {
591 MS_LOG_TRY_CATCH_SCOPE;
592 s = builder->BuildShape(prim, inputs, abstract);
593 } catch (std::exception &e) {
594 MS_LOG(INFO) << "Failed to build symbolic shape for " << cnode->fullname_with_scope() << " with inputs: " << inputs
595 << ". error msg: " << e.what();
596 s = nullptr;
597 }
598 if (s == nullptr) {
599 support_infer_ = false;
600 MS_LOG(DEBUG) << "Node " << cnode->fullname_with_scope() << " does not support BuildShape.";
601 return digital_shape->BuildSymbolicShape();
602 }
603 return s;
604 }
605
BuildCNodeSymbolicValue(OperationBuilder * builder,const PrimitivePtr & prim,const AbstractBasePtrList & inputs,const AbstractBasePtr & abstract,const CNodePtr & cnode)606 SymbolPtr SymbolEngineImpl::BuildCNodeSymbolicValue(OperationBuilder *builder, const PrimitivePtr &prim,
607 const AbstractBasePtrList &inputs, const AbstractBasePtr &abstract,
608 const CNodePtr &cnode) {
609 if (builder == nullptr) {
610 support_infer_ = false;
611 MS_LOG(DEBUG) << "Node " << cnode->fullname_with_scope() << " does not support BuildValue, builder not found.";
612 return BuildSymbolicValue(abstract);
613 }
614 SymbolPtr s = nullptr;
615 try {
616 MS_LOG_TRY_CATCH_SCOPE;
617 s = builder->BuildValue(prim, inputs, abstract);
618 } catch (std::exception &e) {
619 MS_LOG(INFO) << "Failed to build symbolic value for " << cnode->fullname_with_scope() << " with inputs: " << inputs
620 << ". error msg: " << e.what();
621 s = nullptr;
622 }
623 if (s == nullptr) {
624 support_infer_ = false;
625 MS_LOG(DEBUG) << "Node " << cnode->fullname_with_scope() << " does not support BuildValue.";
626 return BuildSymbolicValue(abstract);
627 }
628 return s;
629 }
630
ExtractInputsAbstract(const CNodePtr & cnode)631 AbstractBasePtrList SymbolEngineImpl::ExtractInputsAbstract(const CNodePtr &cnode) {
632 AbstractBasePtrList abs_list;
633 abs_list.reserve(cnode->size());
634 (void)std::transform(cnode->inputs().cbegin() + 1, cnode->inputs().cend(), std::back_inserter(abs_list),
635 [](const AnfNodePtr &node) {
636 MS_EXCEPTION_IF_NULL(node);
637 auto abs = node->abstract();
638 if (abs == nullptr) {
639 if (node->isa<ValueNode>()) {
640 auto vnode = node->cast_ptr<ValueNode>();
641 MS_EXCEPTION_IF_NULL(vnode);
642 MS_EXCEPTION_IF_NULL(vnode->value());
643 abs = vnode->value()->ToAbstract();
644 MS_EXCEPTION_IF_NULL(abs);
645 node->set_abstract(abs);
646 MS_LOG(DEBUG) << "Set new abstract for input node " << node->DebugString();
647 } else {
648 // Do not raise exception here, this input may not be used by operation.
649 MS_LOG(INFO) << "The input " << node->DebugString() << " has null abstract.";
650 }
651 }
652 return abs;
653 });
654 return abs_list;
655 }
656
HasAbstractAny(const AbstractBasePtrList & inputs,const AbstractBasePtr & output)657 bool SymbolEngineImpl::HasAbstractAny(const AbstractBasePtrList &inputs, const AbstractBasePtr &output) {
658 return output->isa<abstract::AbstractAny>() ||
659 std::any_of(inputs.begin(), inputs.end(),
660 [](const AbstractBasePtr &abs) { return abs->isa<abstract::AbstractAny>(); });
661 }
662
BuildCNodeSymbol(const CNodePtr & cnode)663 void SymbolEngineImpl::BuildCNodeSymbol(const CNodePtr &cnode) {
664 PrimitivePtr prim;
665 AbstractBasePtrList inputs;
666 if (cnode->input(0)->isa<CNode>()) {
667 if (auto iter = special_cnodes_.find(cnode); iter != special_cnodes_.end()) {
668 auto ret = iter->second->ExtractInputs();
669 prim = std::move(ret.first);
670 inputs = std::move(ret.second);
671 }
672 if (prim == nullptr) {
673 prim = std::make_shared<Primitive>("_SpecialCNode");
674 }
675 } else {
676 prim = GetCNodePrimitive(cnode);
677 if (prim == nullptr) {
678 prim = std::make_shared<Primitive>("_UnsupportedCNode");
679 }
680 inputs = ExtractInputsAbstract(cnode);
681 }
682 auto abstract = CloneAbstractIfSymbolExists(cnode);
683 MS_EXCEPTION_IF_NULL(abstract);
684 if (HasAbstractAny(inputs, abstract)) {
685 MS_LOG(DEBUG) << "The input or output of " << cnode->fullname_with_scope()
686 << " has AbstractAny, which is not supported by symbol engine. node: " << cnode->DebugString();
687 return;
688 }
689
690 auto builder = OperationBuilderInfoRegistry::GetBuilder(prim->name(), emitter_.get());
691 // theoretically, it's possible that both shape and value are required for a same node.
692 const auto &depend_status = depend_status_map_[cnode];
693 if (depend_status.value) {
694 MS_LOG(DEBUG) << "Build value for node " << cnode->fullname_with_scope() << ". " << cnode->DebugString();
695 auto sym_value = BuildCNodeSymbolicValue(builder.get(), prim, inputs, abstract, cnode);
696 MS_LOG(DEBUG) << "Set value for node: " << cnode->fullname_with_scope() << ". symbol: " << sym_value->ToString();
697 abstract->SetSymbolicValue(sym_value);
698 }
699
700 if (depend_status.shape) {
701 MS_LOG(DEBUG) << "Build shape for node " << cnode->fullname_with_scope() << ". " << cnode->DebugString();
702 auto sym_shape = BuildCNodeSymbolicShape(builder.get(), prim, inputs, abstract, cnode);
703 MS_EXCEPTION_IF_NULL(sym_shape);
704 MS_LOG(DEBUG) << "Set shape for node: " << cnode->fullname_with_scope() << ". symbol: " << sym_shape->ToString();
705 abstract->SetSymbolicShape(sym_shape->as_sptr<ListSymbol>());
706 }
707 }
708
DumpText() const709 std::string SymbolEngineImpl::DumpText() const {
710 std::ostringstream oss;
711 oss << ToString() << " {\n";
712 for (auto op : ops_) {
713 oss << op->DumpText();
714 }
715 oss << "}\n";
716 return oss.str();
717 }
718
/// \brief Return a clone of `abs` with its symbolic shape/value cleared when it carries any symbol, else `abs` itself.
///
/// When cloning throws, the original `abs` is returned (with a debug log) so the caller can still proceed.
AbstractBasePtr CloneAbstractIfSymbolExists(const AbstractBasePtr &abs) {
  const bool has_symbol =
    abs != nullptr && (abs->GetSymbolicShape() != nullptr || abs->GetSymbolicValue() != nullptr);
  if (!has_symbol) {
    return abs;
  }
  try {
    MS_LOG_TRY_CATCH_SCOPE;
    auto cloned = abs->Clone();
    MS_EXCEPTION_IF_NULL(cloned);
    cloned->SetSymbolicShape(nullptr);
    cloned->SetSymbolicValue(nullptr);
    return cloned;
  } catch (std::exception &e) {
    if (IS_OUTPUT_ON(MsLogLevel::kDebug)) {
      std::string sym_shape;
      if (abs->GetSymbolicShape() != nullptr) {
        sym_shape = abs->GetSymbolicShape()->ToString();
      }
      std::string sym_value;
      if (abs->GetSymbolicValue() != nullptr) {
        sym_value = abs->GetSymbolicValue()->ToString();
      }
      MS_LOG(DEBUG) << "The abstract has symbol (S:" << sym_shape << ", V:" << sym_value
                    << ") but cannot be cloned. abstract: " << abs->ToString() << ", msg:" << e.what();
    }
  }
  return abs;
}
743
CleanSymbols(const FuncGraphPtr & func_graph)744 void CleanSymbols(const FuncGraphPtr &func_graph) {
745 std::set<AbstractBasePtr> params_abs;
746 for (auto ¶m : func_graph->parameters()) {
747 (void)params_abs.insert(param->abstract());
748 }
749 auto nodes = TopoSort(func_graph->get_return(), SuccDeeperSimple, AlwaysInclude);
750 for (auto &node : nodes) {
751 auto abs = node->abstract();
752 if (params_abs.find(abs) != params_abs.end()) {
753 // do not clean the parameters' symbol
754 continue;
755 }
756 if (abs != nullptr) {
757 abs->SetSymbolicShape(nullptr);
758 abs->SetSymbolicValue(nullptr);
759 }
760 auto fg = node->func_graph();
761 if (fg != nullptr) {
762 fg->set_symbol_engine(nullptr);
763 }
764 }
765 }
766 } // namespace symshape
767 } // namespace mindspore
768