• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "include/common/symbol_engine/symbol_engine_impl.h"
17 #include <algorithm>
18 #include <ostream>
19 #include "ir/anf.h"
20 #include "ir/func_graph.h"
21 #include "ir/graph_utils.h"
22 #include "ops/array_ops.h"
23 #include "ops/framework_ops.h"
24 #include "ops/sequence_ops.h"
25 #include "mindspore/core/ops/symbol_ops_impl/switch.h"
26 #include "mindspore/core/ops/symbol_ops_impl/j_op.h"
27 #include "utils/check_convert_utils.h"
28 #include "utils/anf_utils.h"
29 #include "mindspore/core/symbolic_shape/utils.h"
30 #include "mindspore/core/symbolic_shape/operation_builder.h"
31 
32 namespace mindspore {
33 namespace symshape {
GetCNodesOfFuncGraph(const FuncGraphPtr & fg)34 AnfNodePtrList GetCNodesOfFuncGraph(const FuncGraphPtr &fg) {
35   bool has_node_in_other_graph = false;
36   auto nodes = TopoSort(fg->output(), SuccDeeperSimple, [&has_node_in_other_graph, &fg](const AnfNodePtr &node) {
37     if (!node->isa<CNode>()) {
38       if (GetValuePtr<FuncGraph>(node) != nullptr) {
39         // some nodes of this graph can be linked in other graph.
40         has_node_in_other_graph = true;
41         return FOLLOW;
42       }
43       return EXCLUDE;
44     }
45     if (node->func_graph() != fg) {
46       has_node_in_other_graph = true;
47     }
48     return FOLLOW;
49   });
50   // at frontend, a node may directly links to other node in other graph.
51   if (has_node_in_other_graph) {
52     (void)nodes.erase(
53       std::remove_if(nodes.begin(), nodes.end(),
54                      [&fg](const AnfNodePtr &node) { return !node->isa<CNode>() || node->func_graph() != fg; }),
55       nodes.end());
56   }
57   return nodes;
58 }
59 
GetFuncGraphFromCNode(const CNodePtr & cnode)60 std::pair<FuncGraphPtr, size_t> GetFuncGraphFromCNode(const CNodePtr &cnode) {
61   auto sub_fg = GetCNodeFuncGraph(cnode);
62   size_t begin_index = kIndex1;
63   if (sub_fg == nullptr && IsPrimitiveCNode(cnode, prim::kPrimPartial)) {
64     auto vnode = cnode->input(kIndex1)->cast<ValueNodePtr>();
65     MS_EXCEPTION_IF_NULL(vnode);
66     sub_fg = vnode->value()->cast<FuncGraphPtr>();
67     MS_EXCEPTION_IF_NULL(sub_fg);
68     begin_index = kIndex2;
69   }
70   if (sub_fg != nullptr && sub_fg->parameters().size() + begin_index < cnode->size()) {
71     MS_LOG(INTERNAL_EXCEPTION) << "For graph " << sub_fg->ToString() << ", the parameter size "
72                                << sub_fg->parameters().size() << " is less than cnode input num "
73                                << cnode->size() - begin_index;
74   }
75   return std::make_pair(sub_fg, begin_index);
76 }
77 
78 class ControlFlowJoinNode : public SpecialCNodeHelper {
79  public:
80   using SpecialCNodeHelper::SpecialCNodeHelper;
Match(const CNodePtr & cnode)81   static bool Match(const CNodePtr &cnode) { return IsPrimitiveCNode(cnode->input(0), prim::kPrimSwitch); }
SetDependStatus(std::map<AnfNodePtr,DependStatus> * depend_status_map)82   void SetDependStatus(std::map<AnfNodePtr, DependStatus> *depend_status_map) override {
83     auto input0 = input();
84     (*depend_status_map)[input0->input(kIndex1)].value = true;
85     SetFuncGraphDepend(input0->input(kIndex2));
86     SetFuncGraphDepend(input0->input(kIndex3));
87   }
ExtractInputs()88   std::pair<PrimitivePtr, AbstractBasePtrList> ExtractInputs() override {
89     auto prim = std::make_shared<Primitive>(ops::kControlFlowJoin);
90     AbstractBasePtrList inputs;
91     auto input0 = input();
92     (void)inputs.emplace_back(input0->input(kIndex1)->abstract());
93     (void)inputs.emplace_back(GetFuncGraphOutAbs(input0->input(kIndex2)));
94     (void)inputs.emplace_back(GetFuncGraphOutAbs(input0->input(kIndex3)));
95     return std::make_pair(std::move(prim), std::move(inputs));
96   }
97 
98  protected:
input() const99   CNodePtr input() const {
100     auto input0 = cnode_->input(0)->cast<CNodePtr>();
101     MS_EXCEPTION_IF_NULL(input0);
102     return input0;
103   }
symbol_engine() const104   SymbolEngineImplPtr symbol_engine() const {
105     auto symbol_engine = cnode_->func_graph()->symbol_engine();
106     MS_EXCEPTION_IF_NULL(symbol_engine);
107     auto symbol_engine_impl = symbol_engine->cast<SymbolEngineImplPtr>();
108     MS_EXCEPTION_IF_NULL(symbol_engine_impl);
109     return symbol_engine_impl;
110   }
SetFuncGraphDepend(const AnfNodePtr & node) const111   void SetFuncGraphDepend(const AnfNodePtr &node) const {
112     auto fg = GetValueNode<FuncGraphPtr>(node);
113     if (fg != nullptr) {
114       symbol_engine()->PreBuildQuerySubgraphDependStatus(cnode_, fg, kIndex1);
115     }
116   }
117 
GetFuncGraphOutAbs(const AnfNodePtr & node) const118   AbstractBasePtr GetFuncGraphOutAbs(const AnfNodePtr &node) const {
119     if (IsPrimitiveCNode(node, prim::kPrimPartial)) {
120       return GetFuncGraphFromCNode(node->cast<CNodePtr>()).first->output()->abstract();
121     }
122     // the graph with Partial is build symbols ahead, build the pure graph (without Partial) in Switch.
123     auto fg = GetValueNode<FuncGraphPtr>(node);
124     if (fg == nullptr) {
125       MS_EXCEPTION_IF_NULL(node->abstract());
126       return node->abstract();
127     }
128     symbol_engine()->BuildSubgraphImpl(cnode_, fg, kIndex1);
129     return fg->output()->abstract();
130   }
131 };
132 
133 class JFuncCaller : public SpecialCNodeHelper {
134  public:
135   /// \brief The call node of PrimJ:
136   ///
137   ///  %0 = J(@fg) // primitive "J"
138   ///  %1 = %0(inp1, inp2, ...) // the node output a tuple of "(tensor, Func)"
139   ///  %2 = TupleGetItem(%1, 1)  // get the output "Func"
140   ///  %3 = %2(loss_scale)       // call the "Func".
141   ///
142   /// this pattern match the "%3", and the output shape is same as "inp1, inp2, ...".
JFuncCaller(const CNodePtr & cnode)143   explicit JFuncCaller(const CNodePtr &cnode) : SpecialCNodeHelper(cnode) {
144     auto getitem1 = cnode->input(kIndex0)->cast<CNodePtr>();
145     MS_EXCEPTION_IF_NULL(getitem1);
146     input_ = getitem1->input(kIndex1)->cast<CNodePtr>();
147     MS_EXCEPTION_IF_NULL(input_);
148   }
149   ~JFuncCaller() override = default;
Match(const CNodePtr & cnode)150   static bool Match(const CNodePtr &cnode) {
151     auto getitem1 = cnode->input(kIndex0)->cast<CNodePtr>();
152     if (getitem1 == nullptr || !IsPrimitiveCNode(getitem1, prim::kPrimTupleGetItem)) {
153       return false;
154     }
155     if (GetValue<int64_t>(GetValueNode(getitem1->input(kIndex2))) != 1) {
156       return false;
157     }
158     auto callj = getitem1->input(kIndex1)->cast<CNodePtr>();
159     return callj != nullptr && IsPrimitiveCNode(callj->input(kIndex0), prim::kPrimJ);
160   }
SetDependStatus(std::map<AnfNodePtr,DependStatus> * depend_status_map)161   void SetDependStatus(std::map<AnfNodePtr, DependStatus> *depend_status_map) override {
162     for (size_t i = 1; i < input_->size(); i++) {
163       (*depend_status_map)[input_->input(i)] = (*depend_status_map)[cnode_];
164     }
165   }
ExtractInputs()166   std::pair<PrimitivePtr, AbstractBasePtrList> ExtractInputs() override {
167     auto prim = std::make_shared<Primitive>(ops::kJFuncCaller);
168     AbstractBasePtrList inputs;
169     inputs.reserve(input_->size());
170     (void)std::transform(input_->inputs().begin(), input_->inputs().end(), std::back_inserter(inputs),
171                          [](const AnfNodePtr &node) { return node->abstract(); });
172     return std::make_pair(std::move(prim), std::move(inputs));
173   }
174 
175  protected:
176   CNodePtr input_{nullptr};
177 };
178 
Build(const FuncGraphPtr & func_graph)179 SymbolEngineImplPtr SymbolEngineImpl::Build(const FuncGraphPtr &func_graph) {
180   if (func_graph->symbol_engine() != nullptr) {
181     CleanSymbols(func_graph);
182   }
183   auto engine = std::make_shared<SymbolEngineImpl>(func_graph);
184   func_graph->set_symbol_engine(engine);
185   engine->PreBuild();
186   engine->BuildImpl();
187   return engine;
188 }
189 
BuildNodesSymbol(const FuncGraphPtr & fg,const AnfNodePtrList & cnodes)190 void SymbolEngineImpl::BuildNodesSymbol(const FuncGraphPtr &fg, const AnfNodePtrList &cnodes) {
191   for (auto &node : cnodes) {
192     auto cnode = node->cast<CNodePtr>();
193     MS_EXCEPTION_IF_NULL(cnode);
194     if (auto fg_with_index = GetFuncGraphFromCNode(cnode); fg_with_index.first != nullptr) {
195       // "call" or "Partial" node
196       BuildSubgraphImpl(cnode, fg_with_index.first, fg_with_index.second);
197     } else {
198       BuildCNodeSymbol(cnode);
199     }
200   }
201   // the funcgraph can be empty or only return a ValueNode.
202   if (!cnodes.empty()) {
203     return;
204   }
205   auto node = fg->output();
206   if (node->isa<ValueNode>()) {
207     auto depend_status = depend_status_map_[node];
208     auto node_abs = CloneAbstractIfSymbolExists(node);
209     MS_EXCEPTION_IF_NULL(node_abs);
210     if (depend_status.shape) {
211       auto sym_shape = node_abs->GetShape()->BuildSymbolicShape();
212       MS_LOG(DEBUG) << "Set shape for node: " << node->DebugString() << ". symbol: " << sym_shape->ToString();
213       node_abs->SetSymbolicShape(sym_shape);
214     }
215     if (depend_status.value) {
216       auto sym_value = BuildSymbolicValue(node_abs);
217       MS_LOG(DEBUG) << "Set value for node: " << node->DebugString() << ". symbol: " << sym_value->ToString();
218       node_abs->SetSymbolicValue(sym_value);
219     }
220   }
221 }
222 
PreBuild()223 void SymbolEngineImpl::PreBuild() {
224   auto func_graph = func_graph_.lock();
225   MS_EXCEPTION_IF_NULL(func_graph);
226   cnodes_ = GetCNodesOfFuncGraph(func_graph);
227   visited_graph_[func_graph.get()] = 1;
228   PreBuildQueryDependStatus(cnodes_);
229   visited_graph_.clear();
230 }
231 
BuildImpl()232 void SymbolEngineImpl::BuildImpl() {
233   auto func_graph = func_graph_.lock();
234   MS_EXCEPTION_IF_NULL(func_graph);
235   MS_LOG(DEBUG) << "Build " << ToString() << " with graph " << func_graph->ToString();
236   emitter_ = std::make_unique<OperationEmitter>(&ops_);
237   visited_graph_[func_graph.get()] = 1;
238   BuildNodesSymbol(func_graph, cnodes_);
239   emitter_->Clean();
240   visited_graph_.clear();
241   generalized_shape_.clear();
242   generalized_value_.clear();
243 }
244 
PreBuildSpecialNode(const CNodePtr & cnode)245 void SymbolEngineImpl::PreBuildSpecialNode(const CNodePtr &cnode) {
246   std::shared_ptr<SpecialCNodeHelper> helper = nullptr;
247   if (ControlFlowJoinNode::Match(cnode)) {
248     helper = std::make_shared<ControlFlowJoinNode>(cnode);
249   } else if (JFuncCaller::Match(cnode)) {
250     helper = std::make_shared<JFuncCaller>(cnode);
251   } else {
252     MS_LOG(DEBUG) << "The special node " << cnode->fullname_with_scope() << " is not supported.";
253     return;
254   }
255   special_cnodes_[cnode] = helper;
256   helper->SetDependStatus(&depend_status_map_);
257 }
258 
SetInputDependStatus(const CNodePtr & cnode,bool depend_value)259 void SymbolEngineImpl::SetInputDependStatus(const CNodePtr &cnode, bool depend_value) {
260   auto prim = GetCNodePrimitive(cnode);
261   MS_EXCEPTION_IF_NULL(prim);
262   size_t input_num = cnode->size() - 1;
263   auto depends = depend_value ? GetValueDepends(prim, input_num) : GetShapeDepends(prim, input_num);
264   for (size_t i = 0; i < depends.size(); i++) {
265     if (depends[i] == DependOn::kValue) {
266       depend_status_map_[cnode->input(i + 1)].value = true;
267     } else if (depends[i] == DependOn::kShape) {
268       depend_status_map_[cnode->input(i + 1)].shape = true;
269     }
270   }
271 }
272 
PreBuildQueryDependStatus(const AnfNodePtrList & cnodes)273 void SymbolEngineImpl::PreBuildQueryDependStatus(const AnfNodePtrList &cnodes) {
274   for (auto iter = cnodes.rbegin(); iter != cnodes.rend(); ++iter) {
275     auto cnode = (*iter)->cast<CNodePtr>();
276     MS_EXCEPTION_IF_NULL(cnode);
277     auto &depend_status = depend_status_map_[cnode];
278     if (!depend_status.value && !depend_status.shape) {
279       // build symbolic shape for the node even though it's not depended by any nodes.
280       depend_status.shape = true;
281     }
282     MS_LOG(DEBUG) << "The depend status of " << cnode->DebugString() << "(" << cnode->fullname_with_scope()
283                   << "): shape-depend=" << depend_status.shape << ", value-depend=" << depend_status.value;
284     if (cnode->input(0)->isa<CNode>()) {
285       PreBuildSpecialNode(cnode);
286       continue;
287     }
288     // the "call" node or Partial node.
289     auto subfg_with_index = GetFuncGraphFromCNode(cnode);
290     if (subfg_with_index.first != nullptr) {
291       PreBuildQuerySubgraphDependStatus(cnode, subfg_with_index.first, subfg_with_index.second);
292       continue;
293     }
294     // the normal CNode, check the depend status from operation builder info.
295     if (!OperationBuilderInfoRegistry::HasOp(AnfUtils::GetCNodeName(cnode))) {
296       continue;
297     }
298     if (depend_status.shape) {
299       SetInputDependStatus(cnode, false);
300     }
301     if (depend_status.value) {
302       SetInputDependStatus(cnode, true);
303     }
304   }
305 }
306 
PreBuildQuerySubgraphDependStatus(const CNodePtr & cnode,const FuncGraphPtr & sub_fg,size_t begin_input_index)307 void SymbolEngineImpl::PreBuildQuerySubgraphDependStatus(const CNodePtr &cnode, const FuncGraphPtr &sub_fg,
308                                                          size_t begin_input_index) {
309   if (++visited_graph_[sub_fg.get()] > 1) {
310     return;
311   }
312   sub_fg->set_symbol_engine(shared_from_base<SymbolEngine>());
313   depend_status_map_[sub_fg->output()] = depend_status_map_[cnode];
314   PreBuildQueryDependStatus(GetCNodesOfFuncGraph(sub_fg));
315   for (auto &param : sub_fg->parameters()) {
316     if (begin_input_index >= cnode->size()) {
317       break;
318     }
319     auto &cnode_input_depend_status = depend_status_map_[cnode->input(begin_input_index++)];
320     auto depend_status = depend_status_map_[param];
321     if (depend_status.shape) {
322       cnode_input_depend_status.shape = true;
323     }
324     if (depend_status.value) {
325       cnode_input_depend_status.value = true;
326     }
327   }
328 }
329 
Infer(const AbstractBasePtrList & inputs)330 bool SymbolEngineImpl::Infer(const AbstractBasePtrList &inputs) {
331   if (!support_infer_) {
332     MS_LOG(WARNING) << "The " << ToString() << " does not support infer";
333     return false;
334   }
335   MS_LOG(DEBUG) << "Infer " << ToString() << " with inputs: " << inputs;
336   auto fg = func_graph_.lock();
337   MS_EXCEPTION_IF_NULL(fg);
338   auto &params = fg->parameters();
339   // There may be params like UpdateStates, which won't contribute to infer
340   if (params.size() < inputs.size()) {
341     MS_LOG(EXCEPTION) << "The parameter size should be equal to or larger than inputs size, but got " << params.size()
342                       << " vs " << inputs.size();
343   }
344   for (size_t i = 0; i < inputs.size(); i++) {
345     if (auto shape = params[i]->abstract()->GetSymbolicShape(); shape != nullptr) {
346       auto cur_shape = inputs[i]->GetShape()->BuildSymbolicShape();
347       MS_EXCEPTION_IF_NULL(cur_shape);
348       MS_LOG(DEBUG) << "Update shape for input[" << i << "]: " << cur_shape->ToRawString();
349       shape->Update(cur_shape);
350     }
351     if (auto value = params[i]->abstract()->GetSymbolicValue(); value != nullptr && value->CanUpdate()) {
352       auto cur_value = BuildSymbolicValue(inputs[i]);
353       MS_EXCEPTION_IF_NULL(cur_value);
354       MS_LOG(DEBUG) << "Update value for input[" << i << "]: " << cur_value->ToRawString();
355       value->Update(cur_value);
356     }
357   }
358   for (auto &op : ops_) {
359     op->Run();
360   }
361   return true;
362 }
363 
IsDependValue(const AnfNodePtr & node)364 bool SymbolEngineImpl::IsDependValue(const AnfNodePtr &node) {
365   if (depend_status_map_.find(node) != depend_status_map_.end()) {
366     return depend_status_map_[node].value;
367   }
368   return false;
369 }
370 
IsDependShape(const AnfNodePtr & node)371 bool SymbolEngineImpl::IsDependShape(const AnfNodePtr &node) {
372   if (depend_status_map_.find(node) != depend_status_map_.end()) {
373     return depend_status_map_[node].shape;
374   }
375   return false;
376 }
377 
QuerySymbolExprHelper(const SymbolPtr & s,const std::unordered_map<std::string,std::string> & symbol_expr_map)378 std::string SymbolEngineImpl::QuerySymbolExprHelper(
379   const SymbolPtr &s, const std::unordered_map<std::string, std::string> &symbol_expr_map) {
380   auto raw_string = s->ToRawString();
381   if (s->is<ListSymbol>() || s->HasData()) {
382     return raw_string;
383   }
384   if (symbol_expr_map.find(raw_string) != symbol_expr_map.end()) {
385     return raw_string;
386   }
387   auto operation = s->operation();
388   if (operation == nullptr) {
389     return raw_string;
390   }
391   std::ostringstream oss;
392   oss << operation->name() << "(";
393   bool first = true;
394   for (auto &input : operation->inputs()) {
395     if (first == true) {
396       first = false;
397     } else {
398       oss << ", ";
399     }
400     oss << QuerySymbolExprHelper(input, symbol_expr_map);
401   }
402   oss << ")";
403   return oss.str();
404 }
405 
QuerySymbolExpr(const AnfNodePtr & node,std::unordered_map<std::string,std::string> * symbol_expr_map)406 void SymbolEngineImpl::QuerySymbolExpr(const AnfNodePtr &node,
407                                        std::unordered_map<std::string, std::string> *symbol_expr_map) {
408   // todo, use SymbolVisitor to export symbol expr.
409   auto symbolic_shape = node->abstract()->GetSymbolicShape();
410   if (symbolic_shape == nullptr) {
411     return;
412   }
413   for (const auto &symbol : symbolic_shape->symbols()) {
414     auto name = symbol->ToRawString();
415     if (name[0] == 's' && symbol_expr_map->find(name) == symbol_expr_map->end()) {
416       auto expr = QuerySymbolExprHelper(symbol, *symbol_expr_map);
417       (*symbol_expr_map)[name] = expr;
418     }
419   }
420 }
421 
GeneralizeParamShape(const AnfNodePtr & param,const AbstractBasePtr & input_abs)422 bool SymbolEngineImpl::GeneralizeParamShape(const AnfNodePtr &param, const AbstractBasePtr &input_abs) {
423   if (generalized_shape_.count(param) > 0) {
424     return false;
425   }
426   auto param_abs = param->abstract();
427   MS_EXCEPTION_IF_NULL(param_abs);
428   if (param_abs->GetSymbolicShape() == nullptr || input_abs->GetSymbolicShape() == nullptr) {
429     return false;
430   }
431   auto param_shape = param_abs->GetSymbolicShape();
432   auto input_shape = input_abs->GetSymbolicShape();
433   if (param_shape->EqualsTo(input_shape)) {
434     return false;
435   }
436   bool build_again = false;
437   bool gen_all = false;
438   auto NewInt = [&build_again]() -> IntSymbolPtr {
439     build_again = true;
440     return IntSymbol::Make();
441   };
442   std::function<SymbolPtrList(const SymbolPtrList &, const SymbolPtrList &)> process;
443   process = [&NewInt, &gen_all, &process](const SymbolPtrList &s1, const SymbolPtrList &s2) {
444     SymbolPtrList ret;
445     if (s1.size() != s2.size()) {
446       gen_all = true;
447       return ret;
448     }
449     ret = s1;
450     for (size_t i = 0; i < s1.size(); i++) {
451       if (s1[i]->EqualsTo(s2[i])) {
452         continue;
453       }
454       if (s1[i]->is<ListSymbol>()) {
455         ret[i] = ListSymbol::Make(process(s1[i]->as<ListSymbol>()->symbols(), s2[i]->as<ListSymbol>()->symbols()));
456         continue;
457       }
458       auto v1 = s1[i]->as<IntSymbol>();
459       auto v2 = s2[i]->as<IntSymbol>();
460       if (v2->is_const()) {
461         if (v1->is_const()) {
462           ret[i] = NewInt();
463         } else {
464           auto d1 = v1->divisor();
465           auto r1 = v1->remainder();
466           // if "d1 * v1 + r1 == v2", that's v2 match the condition of v1.
467           if ((v2->value() - r1 + d1) % d1 != 0) {
468             ret[i] = NewInt();
469           }
470         }
471       } else if (v1->is_const()) {
472         ret[i] = NewInt();
473       } else {
474         // two symbols are variable
475         auto d1 = v1->divisor();
476         auto r1 = v1->remainder();
477         auto d2 = v2->divisor();
478         auto r2 = v2->remainder();
479         if (r1 == r2) {
480           if (d2 % d1 != 0) {
481             auto t = NewInt();
482             t->SetDivisorRemainder(std::gcd(d1, d2), r1);
483             ret[i] = t;
484           }
485         } else {
486           ret[i] = NewInt();
487         }
488       }
489     }
490     return ret;
491   };
492   auto ret = process(param_shape->symbols(), input_shape->symbols());
493   if (gen_all) {
494     (void)generalized_shape_.insert(param);
495     param_abs->SetSymbolicShape(param_abs->GetShape()->BuildSymbolicShape());
496     return true;
497   }
498   return build_again;
499 }
500 
GeneralizeParamValue(const AnfNodePtr & param,const AbstractBasePtr & input_abs)501 bool SymbolEngineImpl::GeneralizeParamValue(const AnfNodePtr &param, const AbstractBasePtr &input_abs) {
502   if (generalized_value_.count(param) > 0) {
503     return false;
504   }
505   auto param_abs = param->abstract();
506   MS_EXCEPTION_IF_NULL(param_abs);
507   if (param_abs->GetSymbolicValue() == nullptr || input_abs->GetSymbolicValue() == nullptr) {
508     return false;
509   }
510   auto param_value = param_abs->GetSymbolicValue();
511   auto input_value = input_abs->GetSymbolicValue();
512   if (param_value->EqualsTo(input_value)) {
513     return false;
514   }
515   param_abs->SetSymbolicValue(BuildSymbolicValue(param_abs));
516   (void)generalized_value_.insert(param);
517   return true;
518 }
519 
SetParamSymbols(const CNodePtr & cnode,const FuncGraphPtr & sub_fg,size_t begin_input_index,size_t visit_cnt)520 bool SymbolEngineImpl::SetParamSymbols(const CNodePtr &cnode, const FuncGraphPtr &sub_fg, size_t begin_input_index,
521                                        size_t visit_cnt) {
522   bool build_again = false;
523   const size_t max_visit_cnt = 5;  // to avoid unexplained dead loop
524   for (size_t i = begin_input_index; i < cnode->size(); i++) {
525     auto inp = cnode->input(i);
526     auto input_abs = inp->abstract();
527     MS_EXCEPTION_IF_NULL(input_abs);
528     if (IsDependShape(inp) && input_abs->GetSymbolicShape() == nullptr) {
529       input_abs->SetSymbolicShape(input_abs->GetShape()->BuildSymbolicShape());
530     }
531     if (IsDependValue(inp) && input_abs->GetSymbolicValue() == nullptr) {
532       input_abs->SetSymbolicValue(BuildSymbolicValue(input_abs));
533     }
534     auto param = sub_fg->parameters()[i - begin_input_index];
535     if (visit_cnt == 1) {
536       auto param_abs = CloneAbstractIfSymbolExists(param);
537       MS_EXCEPTION_IF_NULL(param_abs);
538       param_abs->SetSymbolicShape(input_abs->GetSymbolicShape());
539       param_abs->SetSymbolicValue(input_abs->GetSymbolicValue());
540     } else if (visit_cnt <= max_visit_cnt) {
541       build_again = GeneralizeParamShape(param, input_abs) || build_again;
542       build_again = GeneralizeParamValue(param, input_abs) || build_again;
543     }
544   }
545   return build_again;
546 }
547 
BuildSubgraphImpl(const CNodePtr & cnode,const FuncGraphPtr & sub_fg,size_t begin_input_index)548 void SymbolEngineImpl::BuildSubgraphImpl(const CNodePtr &cnode, const FuncGraphPtr &sub_fg, size_t begin_input_index) {
549   MS_EXCEPTION_IF_NULL(sub_fg);
550   auto visit_cnt = ++visited_graph_[sub_fg.get()];
551   MS_LOG(DEBUG) << "Build subgraph " << sub_fg->ToString() << " of node " << cnode->fullname_with_scope()
552                 << ". visit count: " << visit_cnt;
553   bool build_again = SetParamSymbols(cnode, sub_fg, begin_input_index, visit_cnt);
554   if (visit_cnt > 1) {
555     if (!build_again) {
556       MS_LOG(DEBUG) << "The inputs of graph " << sub_fg->ToString() << " are equal to last building, don't build again";
557       return;
558     }
559     support_infer_ = false;
560   }
561 
562   BuildNodesSymbol(sub_fg, GetCNodesOfFuncGraph(sub_fg));
563   // only set the abstract for "call" node.
564   if (IsValueNode<FuncGraph>(cnode->input(0))) {
565     auto out_abs = sub_fg->output()->abstract();
566     MS_EXCEPTION_IF_NULL(out_abs);
567     auto cnode_abs = CloneAbstractIfSymbolExists(cnode);
568     MS_EXCEPTION_IF_NULL(cnode_abs);
569     cnode_abs->SetSymbolicShape(out_abs->GetSymbolicShape());
570     cnode_abs->SetSymbolicValue(out_abs->GetSymbolicValue());
571   }
572 }
573 
BuildCNodeSymbolicShape(OperationBuilder * builder,const PrimitivePtr & prim,const AbstractBasePtrList & inputs,const AbstractBasePtr & abstract,const CNodePtr & cnode)574 SymbolPtr SymbolEngineImpl::BuildCNodeSymbolicShape(OperationBuilder *builder, const PrimitivePtr &prim,
575                                                     const AbstractBasePtrList &inputs, const AbstractBasePtr &abstract,
576                                                     const CNodePtr &cnode) {
577   auto digital_shape = abstract->GetShape();
578   MS_EXCEPTION_IF_NULL(digital_shape);
579   if (common::GetEnv("MS_DEV_FORCE_BUILD_SYMBOL") != "on" && !digital_shape->IsDynamic()) {
580     auto static_shape = digital_shape->BuildSymbolicShape();
581     MS_LOG(DEBUG) << "Node " << cnode->fullname_with_scope() << " is static shape: " << digital_shape->ToString();
582     return static_shape;
583   }
584   if (builder == nullptr) {
585     support_infer_ = false;
586     MS_LOG(DEBUG) << "Node " << cnode->fullname_with_scope() << " does not support BuildShape, builder not found.";
587     return digital_shape->BuildSymbolicShape();
588   }
589   SymbolPtr s = nullptr;
590   try {
591     MS_LOG_TRY_CATCH_SCOPE;
592     s = builder->BuildShape(prim, inputs, abstract);
593   } catch (std::exception &e) {
594     MS_LOG(INFO) << "Failed to build symbolic shape for " << cnode->fullname_with_scope() << " with inputs: " << inputs
595                  << ". error msg: " << e.what();
596     s = nullptr;
597   }
598   if (s == nullptr) {
599     support_infer_ = false;
600     MS_LOG(DEBUG) << "Node " << cnode->fullname_with_scope() << " does not support BuildShape.";
601     return digital_shape->BuildSymbolicShape();
602   }
603   return s;
604 }
605 
BuildCNodeSymbolicValue(OperationBuilder * builder,const PrimitivePtr & prim,const AbstractBasePtrList & inputs,const AbstractBasePtr & abstract,const CNodePtr & cnode)606 SymbolPtr SymbolEngineImpl::BuildCNodeSymbolicValue(OperationBuilder *builder, const PrimitivePtr &prim,
607                                                     const AbstractBasePtrList &inputs, const AbstractBasePtr &abstract,
608                                                     const CNodePtr &cnode) {
609   if (builder == nullptr) {
610     support_infer_ = false;
611     MS_LOG(DEBUG) << "Node " << cnode->fullname_with_scope() << " does not support BuildValue, builder not found.";
612     return BuildSymbolicValue(abstract);
613   }
614   SymbolPtr s = nullptr;
615   try {
616     MS_LOG_TRY_CATCH_SCOPE;
617     s = builder->BuildValue(prim, inputs, abstract);
618   } catch (std::exception &e) {
619     MS_LOG(INFO) << "Failed to build symbolic value for " << cnode->fullname_with_scope() << " with inputs: " << inputs
620                  << ". error msg: " << e.what();
621     s = nullptr;
622   }
623   if (s == nullptr) {
624     support_infer_ = false;
625     MS_LOG(DEBUG) << "Node " << cnode->fullname_with_scope() << " does not support BuildValue.";
626     return BuildSymbolicValue(abstract);
627   }
628   return s;
629 }
630 
ExtractInputsAbstract(const CNodePtr & cnode)631 AbstractBasePtrList SymbolEngineImpl::ExtractInputsAbstract(const CNodePtr &cnode) {
632   AbstractBasePtrList abs_list;
633   abs_list.reserve(cnode->size());
634   (void)std::transform(cnode->inputs().cbegin() + 1, cnode->inputs().cend(), std::back_inserter(abs_list),
635                        [](const AnfNodePtr &node) {
636                          MS_EXCEPTION_IF_NULL(node);
637                          auto abs = node->abstract();
638                          if (abs == nullptr) {
639                            if (node->isa<ValueNode>()) {
640                              auto vnode = node->cast_ptr<ValueNode>();
641                              MS_EXCEPTION_IF_NULL(vnode);
642                              MS_EXCEPTION_IF_NULL(vnode->value());
643                              abs = vnode->value()->ToAbstract();
644                              MS_EXCEPTION_IF_NULL(abs);
645                              node->set_abstract(abs);
646                              MS_LOG(DEBUG) << "Set new abstract for input node " << node->DebugString();
647                            } else {
648                              // Do not raise exception here, this input may not be used by operation.
649                              MS_LOG(INFO) << "The input " << node->DebugString() << " has null abstract.";
650                            }
651                          }
652                          return abs;
653                        });
654   return abs_list;
655 }
656 
HasAbstractAny(const AbstractBasePtrList & inputs,const AbstractBasePtr & output)657 bool SymbolEngineImpl::HasAbstractAny(const AbstractBasePtrList &inputs, const AbstractBasePtr &output) {
658   return output->isa<abstract::AbstractAny>() ||
659          std::any_of(inputs.begin(), inputs.end(),
660                      [](const AbstractBasePtr &abs) { return abs->isa<abstract::AbstractAny>(); });
661 }
662 
BuildCNodeSymbol(const CNodePtr & cnode)663 void SymbolEngineImpl::BuildCNodeSymbol(const CNodePtr &cnode) {
664   PrimitivePtr prim;
665   AbstractBasePtrList inputs;
666   if (cnode->input(0)->isa<CNode>()) {
667     if (auto iter = special_cnodes_.find(cnode); iter != special_cnodes_.end()) {
668       auto ret = iter->second->ExtractInputs();
669       prim = std::move(ret.first);
670       inputs = std::move(ret.second);
671     }
672     if (prim == nullptr) {
673       prim = std::make_shared<Primitive>("_SpecialCNode");
674     }
675   } else {
676     prim = GetCNodePrimitive(cnode);
677     if (prim == nullptr) {
678       prim = std::make_shared<Primitive>("_UnsupportedCNode");
679     }
680     inputs = ExtractInputsAbstract(cnode);
681   }
682   auto abstract = CloneAbstractIfSymbolExists(cnode);
683   MS_EXCEPTION_IF_NULL(abstract);
684   if (HasAbstractAny(inputs, abstract)) {
685     MS_LOG(DEBUG) << "The input or output of " << cnode->fullname_with_scope()
686                   << " has AbstractAny, which is not supported by symbol engine. node: " << cnode->DebugString();
687     return;
688   }
689 
690   auto builder = OperationBuilderInfoRegistry::GetBuilder(prim->name(), emitter_.get());
691   // theoretically, it's possible that both shape and value are required for a same node.
692   const auto &depend_status = depend_status_map_[cnode];
693   if (depend_status.value) {
694     MS_LOG(DEBUG) << "Build value for node " << cnode->fullname_with_scope() << ".   " << cnode->DebugString();
695     auto sym_value = BuildCNodeSymbolicValue(builder.get(), prim, inputs, abstract, cnode);
696     MS_LOG(DEBUG) << "Set value for node: " << cnode->fullname_with_scope() << ". symbol: " << sym_value->ToString();
697     abstract->SetSymbolicValue(sym_value);
698   }
699 
700   if (depend_status.shape) {
701     MS_LOG(DEBUG) << "Build shape for node " << cnode->fullname_with_scope() << ".   " << cnode->DebugString();
702     auto sym_shape = BuildCNodeSymbolicShape(builder.get(), prim, inputs, abstract, cnode);
703     MS_EXCEPTION_IF_NULL(sym_shape);
704     MS_LOG(DEBUG) << "Set shape for node: " << cnode->fullname_with_scope() << ". symbol: " << sym_shape->ToString();
705     abstract->SetSymbolicShape(sym_shape->as_sptr<ListSymbol>());
706   }
707 }
708 
DumpText() const709 std::string SymbolEngineImpl::DumpText() const {
710   std::ostringstream oss;
711   oss << ToString() << " {\n";
712   for (auto op : ops_) {
713     oss << op->DumpText();
714   }
715   oss << "}\n";
716   return oss.str();
717 }
718 
/// \brief Return an abstract that carries no symbols, cloning `abs` only when it already has some.
///
/// A null `abs` and an abstract without symbols are returned unchanged. If cloning fails, the failure is logged
/// at debug level and the original `abs` (still holding its symbols) is returned.
AbstractBasePtr CloneAbstractIfSymbolExists(const AbstractBasePtr &abs) {
  const bool has_symbol =
    (abs != nullptr) && (abs->GetSymbolicShape() != nullptr || abs->GetSymbolicValue() != nullptr);
  if (!has_symbol) {
    return abs;
  }
  try {
    MS_LOG_TRY_CATCH_SCOPE;
    auto new_abs = abs->Clone();
    MS_EXCEPTION_IF_NULL(new_abs);
    new_abs->SetSymbolicShape(nullptr);
    new_abs->SetSymbolicValue(nullptr);
    return new_abs;
  } catch (const std::exception &e) {
    if (IS_OUTPUT_ON(MsLogLevel::kDebug)) {
      auto shape_symbol = abs->GetSymbolicShape();
      auto value_symbol = abs->GetSymbolicValue();
      std::string sym_shape = (shape_symbol == nullptr) ? "" : shape_symbol->ToString();
      std::string sym_value = (value_symbol == nullptr) ? "" : value_symbol->ToString();
      MS_LOG(DEBUG) << "The abstract has symbol (S:" << sym_shape << ", V:" << sym_value
                    << ") but cannot be cloned. abstract: " << abs->ToString() << ", msg:" << e.what();
    }
  }
  return abs;
}
743 
CleanSymbols(const FuncGraphPtr & func_graph)744 void CleanSymbols(const FuncGraphPtr &func_graph) {
745   std::set<AbstractBasePtr> params_abs;
746   for (auto &param : func_graph->parameters()) {
747     (void)params_abs.insert(param->abstract());
748   }
749   auto nodes = TopoSort(func_graph->get_return(), SuccDeeperSimple, AlwaysInclude);
750   for (auto &node : nodes) {
751     auto abs = node->abstract();
752     if (params_abs.find(abs) != params_abs.end()) {
753       // do not clean the parameters' symbol
754       continue;
755     }
756     if (abs != nullptr) {
757       abs->SetSymbolicShape(nullptr);
758       abs->SetSymbolicValue(nullptr);
759     }
760     auto fg = node->func_graph();
761     if (fg != nullptr) {
762       fg->set_symbol_engine(nullptr);
763     }
764   }
765 }
766 }  // namespace symshape
767 }  // namespace mindspore
768