• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2023 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "pipeline/jit/ps/static_analysis/program_specialize.h"
20 
21 #include <algorithm>
22 #include <exception>
23 #include <unordered_set>
24 #include "mindspore/core/ops/sequence_ops.h"
25 #include "mindspore/core/ops/framework_ops.h"
26 #include "frontend/operator/ops.h"
27 #include "frontend/operator/composite/do_signature.h"
28 #include "abstract/abstract_function.h"
29 #include "abstract/utils.h"
30 #include "ir/graph_utils.h"
31 #include "utils/log_adapter.h"
32 #include "utils/compile_config.h"
33 #include "pipeline/jit/ps/debug/trace.h"
34 #include "pipeline/jit/ps/fallback.h"
35 #include "include/common/fallback.h"
36 #include "include/common/utils/convert_utils_py.h"
37 
38 namespace mindspore {
39 namespace abstract {
40 namespace {
GetEvalResult(const AnfNodeConfigPtr & conf)41 EvalResultPtr GetEvalResult(const AnfNodeConfigPtr &conf) {
42   try {
43     MS_EXCEPTION_IF_NULL(conf);
44     const auto &eval_result = conf->ObtainEvalResult();
45     MS_EXCEPTION_IF_NULL(eval_result);
46     return eval_result;
47   } catch (const std::exception &e) {
48     constexpr int recursive_level = 2;
49     static const bool enable_pre_lift = (common::GetCompileConfig("PRE_LIFT") == "1");
50     if (enable_pre_lift && IsPrimitiveCNode(conf->node(), prim::kPrimPartial)) {
51       MS_LOG(ERROR) << "node: " << conf->node()->DebugString(recursive_level);
52       auto abs_res = std::make_shared<AbstractNone>();
53       auto eval_result = std::make_shared<EvalResult>(abs_res, std::make_shared<AttrValueMap>());
54       return eval_result;
55     }
56     MS_LOG(INTERNAL_EXCEPTION) << "Fail to get eval result with conf " << conf->ToString();
57   }
58 }
59 
// Creates a ValueNode holding `v`, attaches `abs_base` as its abstract and copies the debug info of
// `origin_node` onto it.
//
// Both `origin_node` and `abs_base` are dereferenced below, so both are validated up front.
AnfNodePtr BuildValueNode(const ValuePtr &v, const AnfNodePtr &origin_node, const AbstractBasePtr &abs_base) {
  MS_EXCEPTION_IF_NULL(origin_node);
  MS_EXCEPTION_IF_NULL(abs_base);
  AnfNodePtr value_node = NewValueNode(v);
  MS_EXCEPTION_IF_NULL(value_node);
  value_node->set_abstract(abs_base);
  value_node->set_debug_info(origin_node->debug_info());
  MS_LOG(DEBUG) << "Create ValueNode: " << value_node->ToString() << ", with abstract: " << abs_base->ToString();
  return value_node;
}
68 
IsVisible(FuncGraphPtr fg,const FuncGraphPtr & parent)69 bool IsVisible(FuncGraphPtr fg, const FuncGraphPtr &parent) {
70   while (fg != nullptr && fg != parent) {
71     fg = fg->parent();
72   }
73   return fg == parent;
74 }
75 
CanSpecializeValueNode(const AnfNodePtr & node)76 bool CanSpecializeValueNode(const AnfNodePtr &node) {
77   if (IsValueNode<MetaFuncGraph>(node) || IsValueNode<Primitive>(node)) {
78     return true;
79   }
80   if (IsValueNode<FuncGraph>(node)) {
81     if (node->abstract() != nullptr) {
82       auto abs_func = node->abstract()->cast_ptr<FuncGraphAbstractClosure>();
83       // If this funcgraph had specialized in ProcessCNode of FirstPass,
84       // then ignore it.
85       if (abs_func != nullptr && abs_func->specialized()) {
86         MS_LOG(DEBUG) << "Ignore specializing func graph: " << abs_func->ToString();
87         return false;
88       }
89     }
90     return true;
91   }
92   return false;
93 }
94 
PurifyAbstractOfSequence(ProgramSpecializer * const specializer)95 void PurifyAbstractOfSequence(ProgramSpecializer *const specializer) {
96   MS_EXCEPTION_IF_NULL(specializer);
97   constexpr int recursive_level = 2;
98   for (auto &abstract_and_node : specializer->sequence_abstract_list()) {
99     auto &sequence_abs = abstract_and_node.first;
100     MS_EXCEPTION_IF_NULL(sequence_abs);
101     MS_EXCEPTION_IF_NULL(abstract_and_node.second);
102     if (!sequence_abs->PurifyElements()) {
103       MS_LOG(INFO) << "Purify elements failed, abstract: " << sequence_abs->ToString()
104                    << ", node: " << abstract_and_node.second->DebugString(recursive_level);
105     } else {
106       MS_LOG(DEBUG) << "Purify elements, abstract: " << sequence_abs->ToString()
107                     << ", node: " << abstract_and_node.second->DebugString(recursive_level);
108     }
109   }
110 }
111 
112 // Second elimination.
113 // Eliminate the dead node in sequence node, and purify the abstract of sequence node.
EliminateCollectedSequenceNodes(ProgramSpecializer * const specializer)114 void EliminateCollectedSequenceNodes(ProgramSpecializer *const specializer) {
115   MS_EXCEPTION_IF_NULL(specializer);
116   // Call PurifyElements() to purify tuple/list elements.
117   static const auto enable_only_mark_unused_element = (common::GetCompileConfig("DDE_ONLY_MARK") == "1");
118   if (enable_only_mark_unused_element) {
119     return;
120   }
121 
122   // Purify the abstract of tuple/list.
123   PurifyAbstractOfSequence(specializer);
124   // Eliminate DeadNode in tuple/list.
125   for (auto &dead_node_info : specializer->dead_node_list()) {
126     auto pos = dead_node_info.second;
127     auto node = dead_node_info.first;
128     auto flags = GetSequenceNodeElementsUseFlags(node);
129     if (flags == nullptr) {
130       continue;
131     }
132 
133     // Handle MakeTuple/MakeList CNode.
134     auto cnode = dyn_cast_ptr<CNode>(node);
135     if (cnode != nullptr) {
136       if (pos + 1 >= cnode->size()) {
137         continue;
138       }
139       if (!IsDeadNode(cnode->input(pos + 1))) {
140         continue;
141       }
142 
143       constexpr int recursive_level = 2;
144       MS_LOG(DEBUG) << "Erase elements[" << pos << "] DeadNode as zero for " << cnode->DebugString(recursive_level);
145       // Change the node.
146       auto zero_value = NewValueNode(MakeValue<int64_t>(0));
147       zero_value->set_abstract(
148         std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(0), std::make_shared<Problem>()));
149       cnode->set_input(pos + 1, zero_value);
150 
151       // Change the abstract.
152       (*flags)[pos] = false;  // Change the use flag as 0.
153       auto sequence_abs = dyn_cast_ptr<AbstractSequence>(node->abstract());
154       if (sequence_abs != nullptr && !sequence_abs->PurifyElements()) {
155         MS_LOG(ERROR) << "Purify elements failed, abstract: " << sequence_abs->ToString()
156                       << ", node: " << node->DebugString(recursive_level);
157       }
158       continue;
159     }
160     // Handle ValueTuple/ValueList.
161     if (IsValueNode<ValueTuple>(node) || IsValueNode<ValueList>(node)) {
162       auto sequence_value = GetValuePtr<ValueSequence>(node);
163       MS_EXCEPTION_IF_NULL(sequence_value);
164       if (pos >= sequence_value->value().size()) {
165         continue;
166       }
167       ValuePtr element_value = sequence_value->value()[pos];
168       auto element_err_value = element_value->cast_ptr<ValueProblem>();
169       if (element_err_value == nullptr || !element_err_value->IsDead()) {
170         continue;
171       }
172 
173       MS_LOG(DEBUG) << "Erase elements[" << pos << "] DeadNode as zero for " << node->DebugString();
174       // Change the node.
175       auto zero = MakeValue<int64_t>(0);
176       auto value_list = const_cast<ValuePtrList &>(sequence_value->value());
177       value_list[pos] = zero;
178 
179       // Change the abstract.
180       (*flags)[pos] = false;  // Change the use flag as 0.
181       auto sequence_abs = dyn_cast_ptr<AbstractSequence>(node->abstract());
182       if (sequence_abs != nullptr && !sequence_abs->PurifyElements()) {
183         constexpr int recursive_level = 2;
184         MS_LOG(ERROR) << "Purify elements failed, abstract: " << sequence_abs->ToString()
185                       << ", node: " << node->DebugString(recursive_level);
186       }
187     }
188   }
189 }
190 
BroadenArgs(const AbstractBasePtrList & args_abs_list,AbstractBasePtrList * broaded_args)191 void BroadenArgs(const AbstractBasePtrList &args_abs_list, AbstractBasePtrList *broaded_args) {
192   MS_EXCEPTION_IF_NULL(broaded_args);
193   (void)std::transform(args_abs_list.begin(), args_abs_list.end(), std::back_inserter(*broaded_args),
194                        [](const AbstractBasePtr &arg) -> AbstractBasePtr {
195                          MS_EXCEPTION_IF_NULL(arg);
196                          if (arg->GetValueTrack() != kValueAny) {
197                            return arg->Broaden();
198                          }
199                          return arg;
200                        });
201 }
202 
203 // These abstract sequence can't handled by DDE.
IsInvalidAbstractSequence(const AbstractSequencePtr & abs)204 bool IsInvalidAbstractSequence(const AbstractSequencePtr &abs) {
205   if (abs == nullptr || abs->isa<AbstractSparseTensor>() || abs->sequence_nodes() == nullptr ||
206       abs->sequence_nodes()->empty()) {
207     return true;
208   }
209   if (abs->dyn_len_arg() || abs->dynamic_len()) {
210     return true;
211   }
212   return false;
213 }
214 }  // namespace
215 
// Specializes the program rooted at `fg` under `context` and returns the specialized top func graph.
// The first context seen is remembered as the top context. Func graph specializers are run from the
// todo stack until every one of them reports done(), then the collected sequence nodes are cleaned up.
FuncGraphPtr ProgramSpecializer::Run(const FuncGraphPtr &fg, const AnalysisContextPtr &context) {
  MS_EXCEPTION_IF_NULL(fg);
  MS_EXCEPTION_IF_NULL(context);
  MS_LOG(DEBUG) << "Specialize topmost function graph: "
                << (context->func_graph() ? context->func_graph()->ToString() : "FG(Null)");
  if (top_context_ == nullptr) {
    top_context_ = context;
    MS_LOG(INFO) << "Specialize set top func graph context: " << context->ToString();
  }

  auto root_spec = NewFuncGraphSpecializer(context, fg);
  PushFuncGraphTodoItem(root_spec);
  while (!func_graph_todo_items_.empty()) {
    // Hold a copy: running a specializer may push further items onto the stack.
    auto pending = func_graph_todo_items_.top();
    MS_EXCEPTION_IF_NULL(pending);
    if (!pending->done()) {
      // run current func graph specializer
      pending->Run();
      continue;
    }
    func_graph_todo_items_.pop();
  }

  auto specialized_top = root_spec->specialized_func_graph();
  MS_LOG(DEBUG) << "Specialized top graph: " << specialized_top->ToString();
  EliminateCollectedSequenceNodes(this);
  return specialized_top;
}
242 
GetFuncGraphSpecializer(const AnalysisContextPtr & context)243 std::shared_ptr<FuncGraphSpecializer> ProgramSpecializer::GetFuncGraphSpecializer(const AnalysisContextPtr &context) {
244   MS_EXCEPTION_IF_NULL(context);
245   auto iter = specializations_.find(context);
246   if (iter != specializations_.end()) {
247     return iter->second;
248   }
249   return nullptr;
250 }
251 
NewFuncGraphSpecializer(const AnalysisContextPtr & context,const FuncGraphPtr & fg)252 FuncGraphSpecializerPtr ProgramSpecializer::NewFuncGraphSpecializer(const AnalysisContextPtr &context,
253                                                                     const FuncGraphPtr &fg) {
254   MS_EXCEPTION_IF_NULL(context);
255   auto result = specializations_.emplace(context, nullptr);
256   if (result.second) {
257     MS_LOG(DEBUG) << "Make new specializer of context: " << context->ToString() << ", fg: " << fg->ToString();
258     auto fg_spec = std::make_shared<FuncGraphSpecializer>(this, fg, context);
259     result.first->second = fg_spec;
260     return fg_spec;
261   }
262   MS_LOG(INTERNAL_EXCEPTION) << "Specializer exist in cache, can't not create again, context: " << context->ToString();
263 }
264 
SetSpecializedAbstract(const AbstractFunctionPtr & old_abs_func,const AbstractFunctionPtr & new_abs_func,const CNodePtr & cnode,const AnfNodePtr & func)265 void ProgramSpecializer::SetSpecializedAbstract(const AbstractFunctionPtr &old_abs_func,
266                                                 const AbstractFunctionPtr &new_abs_func, const CNodePtr &cnode,
267                                                 const AnfNodePtr &func) {
268   MS_EXCEPTION_IF_NULL(cnode);
269   MS_EXCEPTION_IF_NULL(func);
270   MS_EXCEPTION_IF_NULL(old_abs_func);
271   MS_EXCEPTION_IF_NULL(new_abs_func);
272   auto iter = specialized_abs_map_.find(old_abs_func);
273   if (iter == specialized_abs_map_.end()) {
274     MS_LOG(DEBUG) << "Emplace cnode: " << cnode->DebugString() << ", func: " << func->ToString()
275                   << ", old_abstract: " << old_abs_func->ToString() << ", new_abs_func: " << new_abs_func->ToString();
276     (void)specialized_abs_map_.emplace(old_abs_func, std::make_pair(true, new_abs_func));
277   } else {
278     MS_LOG(DEBUG) << "Duplicate abstract from cnode: " << cnode->DebugString() << ", func: " << func->ToString()
279                   << ", old_abstract: " << old_abs_func->ToString() << ", new_abs_func: " << new_abs_func->ToString();
280     if (!(*iter->second.second == *new_abs_func)) {
281       MS_LOG(DEBUG) << "Duplicate abstract from cnode: " << cnode->DebugString() << ", func: " << func->ToString()
282                     << ", old_abstract: " << old_abs_func->ToString() << ", first: " << iter->second.second->ToString()
283                     << ", new_abs_func: " << new_abs_func->ToString();
284       // Cannot determined which one to use.
285       iter->second.first = false;
286     }
287   }
288 }
289 
// Returns the specialized abstract recorded for `old_abs_func`, or nullptr.
//  - A record whose flag is still true yields the recorded abstract; a record whose flag was cleared
//    yields nullptr.
//  - Without a record, a FuncGraphAbstractClosure falls back to the unique abstract of its func graph,
//    if GetUniqueFuncGraphAbstract() provides one.
AbstractFunctionPtr ProgramSpecializer::GetSpecializedAbstract(const AbstractFunctionPtr &old_abs_func) {
  MS_EXCEPTION_IF_NULL(old_abs_func);
  const auto found = specialized_abs_map_.find(old_abs_func);
  if (found != specialized_abs_map_.end()) {
    const auto &record = found->second;
    if (!record.first) {
      return nullptr;
    }
    MS_EXCEPTION_IF_NULL(record.second);
    MS_LOG(DEBUG) << "Find abstract for old_abstract: " << old_abs_func->ToString()
                  << ", new_abs_func: " << record.second->ToString();
    return record.second;
  }

  if (old_abs_func->isa<FuncGraphAbstractClosure>()) {
    const auto fg_closure = dyn_cast_ptr<FuncGraphAbstractClosure>(old_abs_func);
    auto unique_abs = GetUniqueFuncGraphAbstract(fg_closure->func_graph());
    if (unique_abs != nullptr) {
      MS_EXCEPTION_IF_NULL(fg_closure->func_graph());
      MS_LOG(DEBUG) << "Find unique abstract for funcgraph: " << fg_closure->func_graph()->ToString() << " in "
                    << old_abs_func->ToString() << ", unique_abs: " << unique_abs->ToString();
      return unique_abs;
    }
  }

  MS_LOG(DEBUG) << "Cannot find abstract for old_abstract: " << old_abs_func->ToString();
  return nullptr;
}
315 
SpecializeAbstractFuncRecursively(const AbstractFunctionPtr & old_abs_func)316 AbstractFunctionPtr ProgramSpecializer::SpecializeAbstractFuncRecursively(const AbstractFunctionPtr &old_abs_func) {
317   MS_EXCEPTION_IF_NULL(old_abs_func);
318   AbstractFunctionPtr new_abs = nullptr;
319   if (old_abs_func->isa<AbstractFuncUnion>()) {
320     AbstractFuncAtomPtrList func_atoms;
321     auto build_new_abs = [this, &func_atoms](const AbstractFuncAtomPtr &poss) {
322       MS_EXCEPTION_IF_NULL(poss);
323       auto resolved_atom = poss;
324       if (poss->isa<AsyncAbstractFuncAtom>()) {
325         auto async_abs_func = poss->cast_ptr<AsyncAbstractFuncAtom>();
326         const auto &resolved_func = async_abs_func->GetUnique();
327         MS_EXCEPTION_IF_NULL(resolved_func);
328         resolved_atom = resolved_func->cast<AbstractFuncAtomPtr>();
329         MS_EXCEPTION_IF_NULL(resolved_atom);
330         MS_LOG(DEBUG) << "Resolved AsyncAbstractFuncAtom is: " << resolved_atom->ToString();
331       }
332       auto specialized_abs = this->SpecializeAbstractFuncRecursively(resolved_atom);
333       AbstractFuncAtomPtr new_abs_atom = nullptr;
334       if (specialized_abs == nullptr) {
335         MS_LOG(DEBUG) << "Cannot resolve old_abs: " << resolved_atom->ToString()
336                       << " to specialized abstract, use old one";
337         new_abs_atom = resolved_atom;
338       } else if (specialized_abs->isa<AbstractFuncAtom>()) {
339         MS_LOG(DEBUG) << "Resolve old_abs: " << resolved_atom->ToString()
340                       << " to specialized abstract, specialized abstract: " << specialized_abs->ToString();
341         new_abs_atom = specialized_abs->cast<AbstractFuncAtomPtr>();
342       } else {
343         MS_LOG(DEBUG) << "Cannot resolve old_abs: " << resolved_atom->ToString()
344                       << " to AbstractFuncAtom, use old one. Specialized abstract: " << specialized_abs->ToString();
345         new_abs_atom = resolved_atom;
346       }
347       func_atoms.push_back(new_abs_atom);
348     };
349     old_abs_func->Visit(build_new_abs);
350     new_abs = std::make_shared<AbstractFuncUnion>(func_atoms);
351   } else if (old_abs_func->isa<FuncGraphAbstractClosure>() || old_abs_func->isa<MetaFuncGraphAbstractClosure>()) {
352     new_abs = GetSpecializedAbstract(old_abs_func);
353     if (new_abs != nullptr) {
354       MS_LOG(DEBUG) << "Find specialized abstract, old_abstract: " << old_abs_func->ToString()
355                     << ", specialized_abstract: " << new_abs->ToString();
356     } else {
357       MS_LOG(DEBUG) << "cannot find specialized abstract, old_abstract: " << old_abs_func->ToString();
358     }
359   } else if (old_abs_func->isa<PartialAbstractClosure>()) {
360     const auto &old_partial_abs = old_abs_func->cast<PartialAbstractClosurePtr>();
361     const auto &old_abs_fn = old_partial_abs->fn();
362     auto new_abs_fn = GetSpecializedAbstract(old_abs_fn);
363     if (new_abs_fn != nullptr && new_abs_fn->isa<AbstractFuncAtom>()) {
364       auto new_abs_fn_atom = new_abs_fn->cast<AbstractFuncAtomPtr>();
365       auto new_partial_abs =
366         std::make_shared<PartialAbstractClosure>(new_abs_fn_atom, old_partial_abs->args(), old_partial_abs->node());
367       new_partial_abs->set_need_append_to_end(old_partial_abs->need_append_to_end());
368       new_abs = new_partial_abs;
369       MS_LOG(DEBUG) << "Find specialized abstract, old_abstract: " << old_abs_func->ToString()
370                     << ", specialized_abstract: " << new_abs->ToString();
371     } else {
372       MS_LOG(DEBUG) << "Cannot find specialized abstract, old_abstract: " << old_abs_func->ToString();
373     }
374   }
375   return new_abs;
376 }
377 
SpecializeCNodeInput0FuncGraph()378 void ProgramSpecializer::SpecializeCNodeInput0FuncGraph() {
379   MS_EXCEPTION_IF_NULL(manager_);
380   const auto &all_nodes = manager_->all_nodes();
381   for (auto node : all_nodes) {
382     MS_EXCEPTION_IF_NULL(node);
383     if (!node->isa<CNode>()) {
384       continue;
385     }
386     auto &input0 = node->cast_ptr<CNode>()->input(0);
387     MS_EXCEPTION_IF_NULL(input0);
388     if (IsValueNode<FuncGraph>(input0) || IsValueNode<Primitive>(input0)) {
389       continue;
390     }
391     MS_EXCEPTION_IF_NULL(node);
392     const auto &old_abs = input0->abstract();
393     if (old_abs == nullptr) {
394       constexpr auto recursive_level = 2;
395       MS_LOG(INTERNAL_EXCEPTION) << "Node's first input abstract should not be null, "
396                                  << node->DebugString(recursive_level);
397     }
398     if (!(old_abs->isa<FuncGraphAbstractClosure>() || old_abs->isa<MetaFuncGraphAbstractClosure>() ||
399           old_abs->isa<AbstractFuncUnion>() || old_abs->isa<PartialAbstractClosure>())) {
400       continue;
401     }
402     auto old_abs_func = old_abs->cast<AbstractFunctionPtr>();
403     auto new_abs_func = SpecializeAbstractFuncRecursively(old_abs_func);
404     if (new_abs_func != nullptr) {
405       input0->set_abstract(new_abs_func);
406       MS_LOG(DEBUG) << "Find specialized abstract for node: " << input0->DebugString()
407                     << ", old_abstract: " << old_abs->ToString()
408                     << ", specialized_abstract: " << new_abs_func->ToString();
409     } else {
410       MS_LOG(DEBUG) << "cannot find specialized abstract for node: " << input0->DebugString()
411                     << ", old_abstract: " << old_abs_func->ToString();
412     }
413   }
414 }
415 
// Hands out consecutive ids starting from 1; every call returns the current id and advances it.
static int64_t GetNextCounter() {
  static int64_t next_id = 1;
  const int64_t current_id = next_id;
  ++next_id;
  return current_id;
}
420 
FuncGraphSpecializer(ProgramSpecializer * const s,const FuncGraphPtr & fg,const AnalysisContextPtr & context)421 FuncGraphSpecializer::FuncGraphSpecializer(ProgramSpecializer *const s, const FuncGraphPtr &fg,
422                                            const AnalysisContextPtr &context)
423     : specializer_(s), func_graph_(fg), context_(context) {
424   parent_ = s->GetFuncGraphSpecializer(context->parent());
425   MS_EXCEPTION_IF_NULL(context->parent());
426   if (ParentNotSpecialized(context)) {
427     MS_LOG(INTERNAL_EXCEPTION) << "Parent func graph should be handled in advance, fg: " << fg->ToString()
428                                << ", context: " << context->ToString()
429                                << ", parent context: " << context->parent()->ToString();
430   }
431   engine_ = s->engine();
432   cloner_ = SpecializerClone(fg, std::make_shared<TraceSpecialize>(GetNextCounter()));
433   specialized_func_graph_ = cloner_->cloned_func_graphs().find(fg)->second;
434   AddTodoItem(fg->get_return());
435   AddTodoItem(fg->parameters());
436 }
437 
// Returns the replica of `node` in the specializer that owns it, cloning the node if needed.
// ValueNodes are returned as they are. For a CNode the inputs of the fresh clone are replicated too.
// Raises an internal exception if no owning specializer is found, if a CNode is not cloned to a CNode,
// or if the clone is not recorded (or is recorded as the original node itself) afterwards.
AnfNodePtr FuncGraphSpecializer::ReplicateDisconnectedNode(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (node->isa<ValueNode>()) {
    return node;
  }

  std::shared_ptr<FuncGraphSpecializer> owner = GetTopSpecializer(node);
  if (owner == nullptr) {
    constexpr auto recursive_level = 2;
    MS_LOG(INTERNAL_EXCEPTION) << "Specializer should not be null, node: " << node->DebugString(recursive_level)
                               << ", NodeInfo: \n"
                               << trace::GetDebugInfoStr(node->debug_info()) << "\n"
                               << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " has no parent context?";
  }

  // If had replicated, just return that.
  auto known = owner->cloned_nodes().find(node);
  if (known != owner->cloned_nodes().end()) {
    return known->second;
  }

  auto replica = owner->cloner_->CloneDisconnected(node);
  if (node->isa<CNode>()) {
    if (!replica->isa<CNode>()) {
      MS_LOG(INTERNAL_EXCEPTION) << "new_node must be a CNode, but is " << replica->DebugString() << ".";
    }
    UpdateNewCNodeInputs(node, replica);
  }

  // The clone must now be recorded, and must differ from the original node.
  auto recorded = owner->cloned_nodes().find(node);
  if (recorded == owner->cloned_nodes().end()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Replicate node failed, node: " << node->ToString();
  }
  if (recorded->second == node) {
    MS_LOG(INTERNAL_EXCEPTION) << "Replicated is same as original node, node: " << node->ToString();
  }
  return replica;
}
475 
UpdateNewCNodeInputs(const AnfNodePtr & node,const AnfNodePtr & new_node)476 void FuncGraphSpecializer::UpdateNewCNodeInputs(const AnfNodePtr &node, const AnfNodePtr &new_node) {
477   MS_EXCEPTION_IF_NULL(node);
478   auto c_node = node->cast_ptr<CNode>();
479   MS_EXCEPTION_IF_NULL(c_node);
480   auto inputs = c_node->weak_inputs();
481   AnfNodeWeakPtrList new_inputs;
482   (void)std::transform(inputs.cbegin(), inputs.cend(), std::back_inserter(new_inputs),
483                        [this](const AnfNodeWeakPtr &weak_inp) -> AnfNodePtr {
484                          auto inp = weak_inp.lock();
485                          MS_EXCEPTION_IF_NULL(inp);
486                          auto new_inp = ReplicateDisconnectedNode(inp);
487                          // Refer the comments in BuildReplacedNode.
488                          if (inp->isa<CNode>()) {
489                            auto c_inp = inp->cast<CNodePtr>();
490                            MS_EXCEPTION_IF_NULL(c_inp);
491                            auto c_new_inp = new_inp->cast<CNodePtr>();
492                            MS_EXCEPTION_IF_NULL(c_new_inp);
493                            MS_EXCEPTION_IF_NULL(c_new_inp->func_graph());
494                            MS_LOG(DEBUG) << "Replace in order, inp node: " << inp->DebugString() << " -> "
495                                          << new_inp->DebugString();
496                            c_new_inp->func_graph()->ReplaceInOrder(c_inp, c_new_inp);
497                          }
498                          return new_inp;
499                        });
500   MS_EXCEPTION_IF_NULL(new_node);
501   auto c_new_node = new_node->cast_ptr<CNode>();
502   MS_EXCEPTION_IF_NULL(c_new_node);
503   c_new_node->set_weak_inputs(new_inputs);
504 }
505 
// Returns the replica recorded for `node` in the specializer that owns it, or `node` itself when no
// replica is recorded. Raises an internal exception if no owning specializer is found.
AnfNodePtr FuncGraphSpecializer::GetReplicatedNode(const AnfNodePtr &node) {
  std::shared_ptr<FuncGraphSpecializer> owner = GetTopSpecializer(node);
  if (owner == nullptr) {
    constexpr auto recursive_level = 2;
    MS_LOG(INTERNAL_EXCEPTION) << "Specializer should not be null, node: " << node->DebugString(recursive_level)
                               << ", NodeInfo: \n"
                               << trace::GetDebugInfoStr(node->debug_info()) << "\n"
                               << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " has no parent context?";
  }
  auto recorded = owner->cloned_nodes().find(node);
  if (recorded == owner->cloned_nodes().end()) {
    return node;
  }
  return recorded->second;
}
521 
522 // Return itself if node's ValueNode as top,
523 // return the top func graph specializer as top if node's forward Parameter,
524 // or, return the top parent specializer as top.
GetTopSpecializer(const AnfNodePtr & node)525 std::shared_ptr<FuncGraphSpecializer> FuncGraphSpecializer::GetTopSpecializer(const AnfNodePtr &node) {
526   MS_EXCEPTION_IF_NULL(node);
527   FuncGraphPtr fg = node->func_graph();
528   if (fg == nullptr) {  // If ValueNode, return current specializer.
529     MS_LOG(DEBUG) << "Node's a ValueNode, node: " << node->DebugString();
530     return shared_from_this();
531   }
532   std::shared_ptr<FuncGraphSpecializer> specializer = shared_from_this();
533   while (fg != specializer->func_graph_) {
534     if (specializer->parent_ == nullptr && node->isa<Parameter>()) {
535       // If `parent_` is null and forwarded `node` is a Parameter, we'll try to use top func graph as parent.
536       auto &top_context = specializer_->top_context();
537       MS_EXCEPTION_IF_NULL(top_context);
538       if (top_context->func_graph() == fg) {  // `fg` is top func graph.
539         MS_LOG(INFO) << "Used top func graph specializer as parent for "
540                      << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << ", node: " << node->DebugString()
541                      << ", NodeInfo: " << trace::GetDebugInfoStr(node->debug_info());
542         specializer = specializer_->GetFuncGraphSpecializer(top_context);
543         if (specializer == nullptr) {
544           constexpr auto recursive_level = 2;
545           MS_LOG(INTERNAL_EXCEPTION) << "Specializer must not be null, node: " << node->DebugString(recursive_level)
546                                      << ", NodeInfo: " << trace::GetDebugInfoStr(node->debug_info());
547         }
548       } else {
549         MS_EXCEPTION_IF_NULL(top_context->func_graph());
550         MS_LOG(INFO) << "Used current specializer, fg: " << fg->ToString()
551                      << ", current fg: " << specializer->func_graph_->ToString()
552                      << ", top fg: " << top_context->func_graph()->ToString();
553       }
554       break;
555     } else {
556       specializer = specializer->parent_;
557     }
558     if (specializer == nullptr) {
559       return nullptr;
560     }
561   }
562   return specializer;
563 }
564 
Run()565 void FuncGraphSpecializer::Run() {
566   MS_LOG(DEBUG) << "Before run, origin func graph name: " << (func_graph_ ? func_graph_->ToString() : "FG(Null)")
567                 << ", cloned func graph name: "
568                 << (specialized_func_graph_ ? specialized_func_graph_->ToString() : "FG(Null)") << ", func graph: "
569                 << (func_graph_ ? func_graph_->get_return() ? func_graph_->get_return()->DebugString() : "return null"
570                                 : "FG(null)");
571   FirstPass();
572   SecondPass();
573   MS_LOG(DEBUG) << "After run, origin func graph name: " << (func_graph_ ? func_graph_->ToString() : "FG(Null)")
574                 << ", cloned func graph name: "
575                 << (specialized_func_graph_ ? specialized_func_graph_->ToString() : "FG(Null)") << ", new func graph: "
576                 << (specialized_func_graph_ ? specialized_func_graph_->get_return()
577                                                 ? specialized_func_graph_->get_return()->DebugString()
578                                                 : "return null"
579                                             : "FG(null)");
580 }
581 
FirstPass()582 void FuncGraphSpecializer::FirstPass() {
583   while (!todo_.empty()) {
584     AnfNodePtr node = todo_.back();
585     todo_.pop_back();
586     if (node->func_graph() == nullptr) {
587       // Do nothing for ValueNode
588       continue;
589     }
590     if (node->func_graph() != func_graph_) {
591       std::shared_ptr<FuncGraphSpecializer> parent = nullptr;
592       if (parent_ != nullptr) {
593         parent = parent_;
594       } else if (specializer_->top_context() && specializer_->top_context()->func_graph() == node->func_graph() &&
595                  node->isa<Parameter>()) {
596         // If `parent_` is null and forwarded `node` is a Parameter, we'll try to use top func graph as parent.
597         parent = specializer_->GetFuncGraphSpecializer(specializer_->top_context());
598         MS_LOG(INFO) << "Used top func graph specializer as parent for " << func_graph_->ToString()
599                      << ", node: " << node->DebugString()
600                      << ", NodeInfo: " << trace::GetDebugInfoStr(node->debug_info());
601       }
602       if (parent == nullptr) {
603         MS_LOG(INTERNAL_EXCEPTION) << "Parent must not be null, node: " << node->DebugString()
604                                    << ", NodeInfo: " << trace::GetDebugInfoStr(node->debug_info());
605       }
606       parent->AddTodoItem(node);
607       parent->FirstPass();
608       AnfNodePtr new_node = parent->GetReplicatedNode(node);
609       if (new_node->isa<CNode>()) {
610         MS_LOG(DEBUG) << "ProcessCNode in FirstPass for " << func_graph_->ToString()
611                       << ", node: " << node->DebugString() << ", new_node: " << new_node->DebugString();
612         (void)parent->ProcessCNode(new_node->cast<CNodePtr>());
613       }
614       continue;
615     }
616     if (marked_.count(node) > 0) {
617       continue;
618     }
619     (void)marked_.insert(node);
620     ProcessNode(node);
621   }
622 }
623 
624 // Specialize CNode in func graphs
SecondPass()625 void FuncGraphSpecializer::SecondPass() {
626   if (second_pass_todo_.empty()) {
627     second_pass_todo_ = BroadFirstSearchGraphCNodes(specialized_func_graph_->return_node());
628   }
629   MS_LOG(DEBUG) << "Start in index: " << second_pass_todo_index_ << ", fg: " << func_graph_->ToString()
630                 << ", todo list size: " << second_pass_todo_.size();
631   while (second_pass_todo_index_ < second_pass_todo_.size()) {
632     auto success = ProcessCNode(second_pass_todo_[second_pass_todo_index_]);
633     if (!success) {
634       MS_LOG(DEBUG) << "Suspend in index: " << second_pass_todo_index_
635                     << ", node: " << second_pass_todo_[second_pass_todo_index_]->DebugString();
636       return;
637     }
638     ++second_pass_todo_index_;
639   }
640   MS_EXCEPTION_IF_NULL(func_graph_);
641   MS_LOG(DEBUG) << "Set done of fg: " << func_graph_->ToString();
642   done_ = true;
643 }
644 
645 namespace {
UpdateForEmptySequenceNode(const AnfNodePtr & new_node,const AnfNodePtr & old_node,const AbstractSequencePtr & old_sequence_abs)646 void UpdateForEmptySequenceNode(const AnfNodePtr &new_node, const AnfNodePtr &old_node,
647                                 const AbstractSequencePtr &old_sequence_abs) {
648   if (!IsValueNode<ValueTuple>(new_node) && !IsValueNode<ValueList>(new_node)) {
649     return;
650   }
651   MS_EXCEPTION_IF_NULL(old_sequence_abs);
652   auto sequence_nodes = std::make_shared<AnfNodeWeakPtrList>();
653   (void)sequence_nodes->emplace_back(AnfNodeWeakPtr(new_node));
654   old_sequence_abs->set_sequence_nodes(sequence_nodes);
655   auto flags = GetSequenceNodeElementsUseFlags(old_node);
656   if (flags != nullptr) {
657     SetSequenceNodeElementsUseFlags(new_node, flags);
658   } else {
659     SetSequenceNodeElementsUseFlags(new_node,
660                                     std::make_shared<std::vector<bool>>(old_sequence_abs->elements().size(), true));
661   }
662 }
663 
664 // Update elements use flags for MakeTuple/tuple node,
665 // and update the node's AbstractSequence 'sequence_nodes' info.
UpdateSequenceNode(const AnfNodePtr & new_node,const AnfNodePtr & old_node,const AbstractBasePtr & old_abs)666 void UpdateSequenceNode(const AnfNodePtr &new_node, const AnfNodePtr &old_node, const AbstractBasePtr &old_abs) {
667   if (new_node == old_node) {
668     return;
669   }
670   MS_EXCEPTION_IF_NULL(old_node);
671   auto old_sequence_abs = dyn_cast<AbstractSequence>(old_abs);
672   if (old_sequence_abs == nullptr || old_sequence_abs->isa<AbstractSparseTensor>()) {
673     MS_LOG(DEBUG) << "The abstract is not AbstractTuple/AbstractList, " << old_node->DebugString() << " --> "
674                   << new_node->DebugString();
675     return;
676   }
677   if (old_sequence_abs->sequence_nodes() == nullptr || old_sequence_abs->sequence_nodes()->empty()) {
678     MS_LOG(DEBUG) << "No sequence node in old abs, " << old_node->DebugString() << " --> " << new_node->DebugString();
679     // The abstract of old_node may have not sequence_nodes when it is a parameter or tuple output cnode.
680     UpdateForEmptySequenceNode(new_node, old_node, old_sequence_abs);
681     return;
682   }
683 
684   // Since the 'old_node' may not equal to 'old_abs' sequence node,
685   // if the new_node is built by the abstract of 'forward old node',
686   // we just set 'new_node' as 'old_abs' sequence node here.
687   if (IsValueNode<ValueTuple>(new_node) || IsValueNode<ValueList>(new_node)) {
688     // Just find a valid sequence node.
689     for (auto &weak_node : *old_sequence_abs->sequence_nodes()) {
690       auto sequence_node = weak_node.lock();
691       if (sequence_node == nullptr) {
692         continue;
693       }
694       auto flags = GetSequenceNodeElementsUseFlags(sequence_node);
695       if (flags == nullptr) {
696         continue;
697       }
698       // Copy the flags to new node, and set new node to sequence abstract.
699       // Actually, here we needn't require unique sequence nodes pointer between abstract any more.
700       SetSequenceNodeElementsUseFlags(new_node, flags);
701       old_sequence_abs->InsertSequenceNode(new_node);
702       return;
703     }
704     MS_LOG(INFO) << "Not found any valid sequence node, " << old_node->DebugString() << " --> "
705                  << new_node->DebugString();
706     return;
707   }
708 
709   for (auto &weak_node : *old_sequence_abs->sequence_nodes()) {
710     auto sequence_node = weak_node.lock();
711     if (sequence_node == nullptr) {
712       MS_LOG(DEBUG) << "The sequence_nodes is free. " << old_node->DebugString() << " --> " << new_node->DebugString();
713       continue;
714     }
715     if (sequence_node != old_node) {
716       continue;
717     }
718 
719     // Update new node's flags with old one, and update old sequence abstract's source node.
720     auto flags = GetSequenceNodeElementsUseFlags(old_node);
721     MS_LOG(DEBUG) << "Update sequence node, " << old_node->DebugString() << " --> " << new_node->DebugString()
722                   << ", elements_use_flags: " << (*flags);
723     SetSequenceNodeElementsUseFlags(new_node, flags);
724     old_sequence_abs->UpdateSequenceNode(sequence_node, new_node);
725 
726     // Update new sequence abstract if it's not equal to old one.
727     const AbstractBasePtr &new_abs = new_node->abstract();
728     if (old_abs == new_abs) {
729       continue;
730     }
731     MS_LOG(ERROR) << "New abstract, " << old_node->DebugString() << " --> " << new_node->DebugString()
732                   << ", elements_use_flags: " << (*flags);
733     auto new_sequence_abs = dyn_cast_ptr<AbstractSequence>(new_abs);
734     if (new_sequence_abs == nullptr) {
735       MS_LOG(INTERNAL_EXCEPTION) << "New node should be sequence type as well, but got " << new_abs->ToString();
736     }
737     if (new_sequence_abs->sequence_nodes() == nullptr || new_sequence_abs->sequence_nodes()->empty()) {
738       std::shared_ptr<AnfNodeWeakPtrList> new_sequence_nodes = std::make_shared<AnfNodeWeakPtrList>();
739       (void)new_sequence_nodes->emplace_back(AnfNodeWeakPtr(new_node));
740       new_sequence_abs->set_sequence_nodes(new_sequence_nodes);
741     } else {
742       new_sequence_abs->InsertSequenceNode(new_node);
743     }
744   }
745 }
746 
747 // Purify specific input of a CNode.
748 template <typename T, typename S>
PurifySequenceValueNode(const CNodePtr & cnode,size_t index,ProgramSpecializer * const specializer)749 void PurifySequenceValueNode(const CNodePtr &cnode, size_t index, ProgramSpecializer *const specializer) {
750   MS_EXCEPTION_IF_NULL(cnode);
751   const auto &old_input = cnode->input(index);
752   MS_EXCEPTION_IF_NULL(old_input);
753   auto sequence_value = GetValuePtr<T>(old_input);
754   if (sequence_value == nullptr) {
755     return;
756   }
757   auto flags = GetSequenceNodeElementsUseFlags(old_input);
758   if (flags == nullptr) {
759     return;
760   }
761   auto old_input_abs = old_input->abstract();
762   MS_EXCEPTION_IF_NULL(old_input_abs);
763   auto old_sequence_abs = dyn_cast<AbstractSequence>(old_input_abs);
764   MS_EXCEPTION_IF_NULL(old_sequence_abs);
765   // Dynamic len abstract sequence no need purify.
766   if (IsInvalidAbstractSequence(old_sequence_abs)) {
767     return;
768   }
769 
770   std::vector<size_t> dead_node_positions;
771   ValuePtrList elements;
772   AbstractBasePtrList elements_abs{};
773   auto sequence_value_size = sequence_value->value().size();
774   if (flags->size() < sequence_value_size) {
775     MS_LOG(INTERNAL_EXCEPTION) << "Inner exception. CNode: " << cnode->ToString() << " input: " << old_input->ToString()
776                                << " flags size: " << flags->size()
777                                << " values size: " << sequence_value->value().size();
778   }
779   for (size_t i = 0; i < sequence_value_size; ++i) {
780     ValuePtr old_sequence_value = sequence_value->value()[i];
781     MS_EXCEPTION_IF_NULL(old_sequence_value);
782     auto old_sequence_err_value = old_sequence_value->cast_ptr<ValueProblem>();
783     if (old_sequence_err_value != nullptr && old_sequence_err_value->IsDead()) {
784       MS_LOG(DEBUG) << "Collect for erasing elements[" << i << "] DeadNode as zero for " << old_input->DebugString()
785                     << ", which is inputs[" << index << "] of " << cnode->DebugString();
786       (void)dead_node_positions.emplace_back(i);
787     }
788     if (!(*flags)[i]) {
789       auto zero = MakeValue<int64_t>(0);
790       (void)elements.emplace_back(zero);
791       (void)elements_abs.emplace_back(zero->ToAbstract());
792       MS_LOG(DEBUG) << "Erase elements[" << i << "] as zero for " << old_input->DebugString() << ", which is inputs["
793                     << index << "] of " << cnode->DebugString();
794     } else {
795       (void)elements.emplace_back(old_sequence_value);
796       (void)elements_abs.emplace_back(old_sequence_abs->elements()[i]);
797     }
798   }
799   auto new_sequence_value = std::make_shared<T>(elements);
800   auto new_input = NewValueNode(new_sequence_value);
801   auto new_sequence_abs = std::make_shared<S>(elements_abs);
802   std::shared_ptr<AnfNodeWeakPtrList> sequence_nodes = std::make_shared<AnfNodeWeakPtrList>();
803   (void)sequence_nodes->emplace_back(AnfNodeWeakPtr(new_input));
804   new_sequence_abs->set_sequence_nodes(sequence_nodes);
805   if constexpr (std::is_same<S, AbstractList>()) {
806     auto old_sequence_abs_list = old_sequence_abs->cast<AbstractListPtr>();
807     MS_EXCEPTION_IF_NULL(old_sequence_abs_list);
808     if (fallback::HasObjInExtraInfoHolder(old_sequence_abs_list)) {
809       MS_LOG(DEBUG) << "old AbstractList has python object, attach it to new AbstractList.";
810       auto list_obj = fallback::GetObjFromExtraInfoHolder(old_sequence_abs_list);
811       auto create_in_graph = fallback::GetCreateInGraphFromExtraInfoHolder(old_sequence_abs_list);
812       fallback::AttachPyObjToExtraInfoHolder(new_sequence_abs, list_obj, create_in_graph);
813     }
814   }
815 
816   new_input->set_abstract(new_sequence_abs);
817 
818   // Always reset tuple value node's use flags as non-use.
819   SetSequenceNodeElementsUseFlags(new_input, flags);
820   MS_LOG(DEBUG) << "Update ValueTuple/ValueList, " << old_input->DebugString() << " --> " << new_input->DebugString()
821                 << ", which is inputs[" << index << "] of " << cnode->DebugString() << ", flags: " << (*flags);
822   // Keep the node not to release before we purify its abstract.
823   (void)specializer->sequence_abstract_list().emplace_back(std::pair(new_sequence_abs, old_input));
824   for (size_t pos : dead_node_positions) {
825     (void)specializer->dead_node_list().emplace_back(std::pair(new_input, pos));
826   }
827   cnode->set_input(index, new_input);
828 }
829 
PurifyNamedTupleValueNode(const CNodePtr & cnode,size_t index,ProgramSpecializer * const specializer)830 void PurifyNamedTupleValueNode(const CNodePtr &cnode, size_t index, ProgramSpecializer *const specializer) {
831   MS_EXCEPTION_IF_NULL(cnode);
832   const auto &old_input = cnode->input(index);
833   MS_EXCEPTION_IF_NULL(old_input);
834   auto sequence_value = GetValuePtr<ValueNamedTuple>(old_input);
835   if (sequence_value == nullptr) {
836     return;
837   }
838   auto flags = GetSequenceNodeElementsUseFlags(old_input);
839   if (flags == nullptr) {
840     return;
841   }
842   auto old_input_abs = old_input->abstract();
843   MS_EXCEPTION_IF_NULL(old_input_abs);
844   auto old_sequence_abs = dyn_cast<AbstractSequence>(old_input_abs);
845   MS_EXCEPTION_IF_NULL(old_sequence_abs);
846   // Dynamic len abstract sequence no need purify.
847   if (IsInvalidAbstractSequence(old_sequence_abs)) {
848     return;
849   }
850 
851   std::vector<size_t> dead_node_positions;
852   ValuePtrList elements;
853   AbstractBasePtrList elements_abs{};
854   auto sequence_value_size = sequence_value->value().size();
855   if (flags->size() < sequence_value_size) {
856     MS_LOG(INTERNAL_EXCEPTION) << "Inner exception. CNode: " << cnode->ToString() << " input: " << old_input->ToString()
857                                << " flags size: " << flags->size()
858                                << " values size: " << sequence_value->value().size();
859   }
860   for (size_t i = 0; i < sequence_value_size; ++i) {
861     ValuePtr old_sequence_value = sequence_value->value()[i];
862     MS_EXCEPTION_IF_NULL(old_sequence_value);
863     auto old_sequence_err_value = old_sequence_value->cast_ptr<ValueProblem>();
864     if (old_sequence_err_value != nullptr && old_sequence_err_value->IsDead()) {
865       MS_LOG(DEBUG) << "Collect for erasing elements[" << i << "] DeadNode as zero for " << old_input->DebugString()
866                     << ", which is inputs[" << index << "] of " << cnode->DebugString();
867       (void)dead_node_positions.emplace_back(i);
868     }
869     if (!(*flags)[i]) {
870       auto zero = MakeValue<int64_t>(0);
871       (void)elements.emplace_back(zero);
872       (void)elements_abs.emplace_back(zero->ToAbstract());
873       MS_LOG(DEBUG) << "Erase elements[" << i << "] as zero for " << old_input->DebugString() << ", which is inputs["
874                     << index << "] of " << cnode->DebugString();
875     } else {
876       (void)elements.emplace_back(old_sequence_value);
877       (void)elements_abs.emplace_back(old_sequence_abs->elements()[i]);
878     }
879   }
880 
881   const auto &sub_class_name = sequence_value->sub_class_name();
882   const auto &keys = sequence_value->key();
883   abstract::AbstractBasePtrList key_abs;
884   (void)std::transform(keys.begin(), keys.end(), std::back_inserter(key_abs), [](const ValuePtr &key) {
885     MS_EXCEPTION_IF_NULL(key);
886     return key->ToAbstract();
887   });
888   auto new_sequence_value = std::make_shared<ValueNamedTuple>(sub_class_name, keys, elements);
889   auto new_input = NewValueNode(new_sequence_value);
890   auto new_sequence_abs = std::make_shared<AbstractNamedTuple>(sub_class_name, key_abs, elements_abs);
891   std::shared_ptr<AnfNodeWeakPtrList> sequence_nodes = std::make_shared<AnfNodeWeakPtrList>();
892   (void)sequence_nodes->emplace_back(AnfNodeWeakPtr(new_input));
893   new_sequence_abs->set_sequence_nodes(sequence_nodes);
894 
895   new_input->set_abstract(new_sequence_abs);
896 
897   // Always reset tuple value node's use flags as non-use.
898   SetSequenceNodeElementsUseFlags(new_input, flags);
899   MS_LOG(DEBUG) << "Update ValueNamedTuple, " << old_input->DebugString() << " --> " << new_input->DebugString()
900                 << ", which is inputs[" << index << "] of " << cnode->DebugString() << ", flags: " << (*flags);
901   // Keep the node not to release before we purify its abstract.
902   (void)specializer->sequence_abstract_list().emplace_back(std::pair(new_sequence_abs, old_input));
903   for (size_t pos : dead_node_positions) {
904     (void)specializer->dead_node_list().emplace_back(std::pair(new_input, pos));
905   }
906   cnode->set_input(index, new_input);
907 }
908 }  // namespace
909 
910 // First elimination.
911 // Eliminate the unused items of Tuple/List.
912 // Just adjust the nodes, not change the abstracts and dead nodes.
EliminateUnusedSequenceItem(const CNodePtr & cnode) const913 void FuncGraphSpecializer::EliminateUnusedSequenceItem(const CNodePtr &cnode) const {
914   if (cnode == nullptr || cnode->abstract() == nullptr) {
915     MS_LOG(INTERNAL_EXCEPTION) << "The parameter \'node\' and its abstract should not be null.";
916   }
917   auto &sequence_abstract_list = specializer_->sequence_abstract_list();
918 
919   // Add CNode's inputs if they're sequence abstract, and sequence nodes exist.
920   (void)std::for_each(cnode->weak_inputs().cbegin(), cnode->weak_inputs().cend(),
921                       [&sequence_abstract_list](const AnfNodeWeakPtr &weak_input) {
922                         auto input = weak_input.lock();
923                         MS_EXCEPTION_IF_NULL(input);
924                         const AbstractBasePtr input_abs = input->abstract();
925                         AbstractSequencePtr input_sequence_abs = dyn_cast<AbstractSequence>(input_abs);
926                         if (IsInvalidAbstractSequence(input_sequence_abs)) {
927                           return;
928                         }
929                         // Not call PurifyElements() here, just add to list.
930                         (void)sequence_abstract_list.emplace_back(std::pair(input_sequence_abs, input));
931                       });
932 
933   // Add CNode if it's sequence abstract, and sequence nodes exist.
934   const AbstractBasePtr abs = cnode->abstract();
935   AbstractSequencePtr sequence_abs = dyn_cast<AbstractSequence>(abs);
936   if (IsInvalidAbstractSequence(sequence_abs)) {
937     return;
938   }
939   // Not call PurifyElements() here, just add to list.
940   (void)sequence_abstract_list.emplace_back(std::pair(sequence_abs, cnode));
941 
942   // Purify MakeTuple/MakeList CNode.
943   if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) || IsPrimitiveCNode(cnode, prim::kPrimMakeList)) {
944     auto flags = GetSequenceNodeElementsUseFlags(cnode);
945     if (flags != nullptr) {
946       std::vector<AnfNodePtr> inputs;
947       (void)inputs.emplace_back(cnode->input(0));
948       for (size_t i = 0; i < (*flags).size(); ++i) {
949         auto old_input = cnode->input(i + 1);
950         if (!(*flags)[i]) {
951           auto zero_value = NewValueNode(MakeValue<int64_t>(0));
952           zero_value->set_abstract(std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(0)));
953           (void)inputs.emplace_back(zero_value);
954           constexpr int recursive_level = 2;
955           MS_LOG(DEBUG) << "Erase elements[" << i << "] as zero for " << cnode->DebugString(recursive_level);
956         } else if (IsDeadNode(old_input)) {
957           constexpr int recursive_level = 2;
958           MS_LOG(DEBUG) << "Collect for erasing elements[" << i << "] DeadNode as zero for " << cnode << "/"
959                         << cnode->DebugString(recursive_level);
960           (void)specializer_->dead_node_list().emplace_back(std::pair(cnode, i));
961           (void)inputs.emplace_back(old_input);
962         } else {
963           (void)inputs.emplace_back(old_input);
964         }
965       }
966       cnode->set_inputs(std::move(inputs));
967       cnode->set_abstract(sequence_abs);
968     }
969   }
970   // Purify each Tuple/List ValueNode in CNode.
971   for (size_t i = 1; i < cnode->size(); ++i) {
972     if (IsValueNode<ValueTuple>(cnode->input(i))) {
973       if (IsValueNode<ValueNamedTuple>(cnode->input(i))) {
974         PurifyNamedTupleValueNode(cnode, i, specializer_);
975       } else {
976         PurifySequenceValueNode<ValueTuple, AbstractTuple>(cnode, i, specializer_);
977       }
978     } else if (IsValueNode<ValueList>(cnode->input(i))) {
979       PurifySequenceValueNode<ValueList, AbstractList>(cnode, i, specializer_);
980     }
981   }
982 }
983 
ProcessNode(const AnfNodePtr & node)984 void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
985   MS_EXCEPTION_IF_NULL(node);
986   ScopeGuard scope_guard(node->scope());
987   AnfNodeConfigPtr conf = MakeConfig(node);
988   MS_EXCEPTION_IF_NULL(conf);
989   TraceGuard guard(std::make_shared<TraceCopy>(node->debug_info()));
990   AnfNodePtr new_node = GetReplicatedNode(node);
991   MS_EXCEPTION_IF_NULL(new_node);
992   if (new_node->func_graph() != specialized_func_graph_) {
993     MS_LOG(INTERNAL_EXCEPTION) << "Found not specialized node, node: " << node->DebugString()
994                                << ", new_node: " << new_node->DebugString() << ", new_node->func_graph(): "
995                                << (new_node->func_graph() ? new_node->func_graph()->ToString() : "FG(Null)")
996                                << ", specialized_func_graph_: " << specialized_func_graph_->ToString();
997   }
998   const EvalResultPtr &conf_eval_result = GetEvalResult(conf);
999   MS_EXCEPTION_IF_NULL(conf_eval_result);
1000   new_node->set_abstract(conf_eval_result->abstract());
1001   MS_EXCEPTION_IF_NULL(new_node->abstract());
1002 
1003   // Update PartialAbstractClosure's bound node.
1004   if (new_node->isa<CNode>() && new_node->abstract()->isa<PartialAbstractClosure>()) {
1005     auto partial_closure = dyn_cast_ptr<PartialAbstractClosure>(new_node->abstract());
1006     MS_EXCEPTION_IF_NULL(partial_closure);
1007     auto partial_node = partial_closure->node();
1008     if (partial_node != nullptr && GetTopSpecializer(partial_node) != nullptr) {
1009       auto new_partial_node = GetReplicatedNode(partial_node);
1010       if (new_partial_node != partial_node) {  // Old Partial CNode was replaced. Need update.
1011         partial_closure->set_node(new_partial_node);
1012       }
1013     }
1014   }
1015   MS_LOG(DEBUG) << "Set new_node: " << new_node->DebugString() << ", abstract as: " << new_node->abstract()->ToString()
1016                 << ", func_graph_: " << func_graph_->ToString()
1017                 << ", specialized_func_graph_: " << specialized_func_graph_->ToString();
1018 
1019   if (!node->isa<CNode>()) {
1020     return;
1021   }
1022   static const auto enable_eliminate_unused_element = (common::GetCompileConfig("ENABLE_DDE") != "0");
1023   auto attrs = conf_eval_result->attribute();
1024   auto c_old = node->cast_ptr<CNode>();
1025   auto c_new = new_node->cast_ptr<CNode>();
1026   MS_EXCEPTION_IF_NULL(c_new);
1027   auto new_inputs = c_new->weak_inputs();
1028   auto old_inputs = c_old->weak_inputs();
1029   for (size_t i = 0; i < old_inputs.size(); ++i) {
1030     auto node_input = old_inputs[i].lock();
1031     MS_EXCEPTION_IF_NULL(node_input);
1032     AnfNodeConfigPtr input_conf = MakeConfig(node_input);
1033     MS_EXCEPTION_IF_NULL(input_conf);
1034     const auto &eval_result = GetEvalResult(input_conf);
1035     const AbstractBasePtr &abs = eval_result->abstract();
1036     // Check if there's an inplace abstract and use it.
1037     AbstractBasePtr real_abs;
1038     if (abs->inplace_abstract() == nullptr) {
1039       real_abs = abs;
1040     } else {
1041       real_abs = abs->inplace_abstract();
1042       MS_LOG(INFO) << "Use inplace abstract, " << abs->ToString() << " -> " << real_abs->ToString();
1043     }
1044     bool ignore_build_value = false;
1045     AnfNodePtr replace_node = nullptr;
1046     MS_EXCEPTION_IF_NULL(specializer_->engine());
1047     if (specializer_->engine()->check_side_effect()) {
1048       auto cnode_input = dyn_cast_ptr<CNode>(node_input);
1049       ignore_build_value = (cnode_input != nullptr && cnode_input->has_side_effect_node());
1050       if (ignore_build_value) {
1051         MS_LOG(INFO) << "Don't build value node for CNode which contains isolated side-effect inputs, node: "
1052                      << cnode_input->DebugString() << ", flag: " << cnode_input->has_side_effect_node();
1053       }
1054     }
1055     if (!ignore_build_value) {
1056       // First try to check if node_input can be replaced by a ValueNode. If cannot, then try to check if
1057       // can be replaced by another CNode from anfnode_config_map, otherwise use the replicated node.
1058       replace_node = BuildPossibleValueNode(node_input, real_abs, attrs, node);
1059     }
1060     if (replace_node == nullptr) {
1061       replace_node = BuildReplacedNode(input_conf);
1062       MS_EXCEPTION_IF_NULL(replace_node);
1063       replace_node->set_abstract(real_abs);
1064       MS_LOG(DEBUG) << "Set replaced input[" << i << "]: " << replace_node->DebugString()
1065                     << ", NodeConfig: " << input_conf->ToString() << ", result: " << real_abs.get() << "/"
1066                     << real_abs->ToString();
1067     } else {
1068       MS_EXCEPTION_IF_NULL(real_abs);
1069       MS_LOG(DEBUG) << "Build possible value node for node: " << node_input->DebugString()
1070                     << ", real_abs: " << real_abs->ToString() << ", replace_node: " << replace_node->DebugString();
1071     }
1072     MS_EXCEPTION_IF_NULL(replace_node);
1073     if (enable_eliminate_unused_element) {
1074       UpdateSequenceNode(replace_node, node_input, real_abs);
1075     }
1076     if (new_inputs[i].lock() != replace_node) {
1077       new_node->func_graph()->AddOwnNode(replace_node);
1078       new_inputs[i] = replace_node;
1079       MS_LOG(DEBUG) << "Set new_input[" << i << "]: " << replace_node->DebugString();
1080     }
1081   }
1082   c_new->set_weak_inputs(new_inputs);
1083   MS_LOG(DEBUG) << "Update cnode: " << c_new << "/" << c_new->DebugString();
1084 }
1085 
BuildReplacedNode(const AnfNodeConfigPtr & conf)1086 AnfNodePtr FuncGraphSpecializer::BuildReplacedNode(const AnfNodeConfigPtr &conf) {
1087   MS_EXCEPTION_IF_NULL(conf);
1088   auto conf_iter = engine_->anfnode_config_map().find(conf);
1089   AnfNodeConfigPtr new_conf = conf;
1090   while (conf_iter != engine_->anfnode_config_map().end()) {
1091     MS_LOG(DEBUG) << "Origin conf: node(" << (new_conf->node() ? new_conf->node()->DebugString() : "Node(Null)") << ")";
1092     new_conf = conf_iter->second;
1093     MS_EXCEPTION_IF_NULL(new_conf);
1094     const auto &forward_node = new_conf->node();
1095     MS_EXCEPTION_IF_NULL(forward_node);
1096     MS_LOG(DEBUG) << "Replaced conf: node(" << forward_node->DebugString() << ")";
1097     const auto &replicated_forward_node = ReplicateDisconnectedNode(forward_node);
1098     if (replicated_forward_node && replicated_forward_node->isa<CNode>()) {
1099       // The AnfNode in order_list can be:
1100       // case 1: also in FuncGraphManager, so it can be got from nodes API of func_graph. it will
1101       //         be replaced in CloneOrderList in Cloner.
1102       // case 2: AnfNode is not in FuncGraphManager which generated in Analyze phase, so it will not
1103       //         be cloned by normal clone API.
1104       //    2.1: A forward node , the original node is in FuncGraphManager. The original node will
1105       //         be cloned in CloneOrderList in Cloner, and the replicated forward node will replace
1106       //         the replicated original node here.
1107       //    2.2: an input of a forward node, such as Cast CNode generated in DoCast. It is also another
1108       //         original node to fowrad.
1109       //    2.3: an input of an input of a forward node, but it's not an original node. Like the Cast CNode
1110       //         in MixedPrecisionCastHelper.
1111       // For 2.2 and 2.3, we will put a placeholder in order list of replicated func_graph, refer to
1112       // CloneOrderlist, and it will be replaced inside ReplicateDisconnectedNode.
1113       // For 2.1 the following code will do the job, replace replicated origin cnode with the replicated
1114       // forward one in the replicated func_graph.
1115       MS_EXCEPTION_IF_NULL(conf_iter->first);
1116       const auto &origin_node = conf_iter->first->node();
1117       const auto &replicated_origin_node = GetReplicatedNode(origin_node);
1118       if (replicated_origin_node != origin_node) {
1119         MS_LOG(DEBUG) << "Replace replicated origin node in order list: " << replicated_origin_node->DebugString()
1120                       << ", with replicated forwarded node: " << replicated_forward_node->DebugString();
1121         MS_EXCEPTION_IF_NULL(replicated_forward_node->func_graph());
1122         replicated_forward_node->func_graph()->ReplaceInOrder(replicated_origin_node, replicated_forward_node);
1123       } else {
1124         MS_LOG(INTERNAL_EXCEPTION) << "Origin node is not replicated in specialized func_graph, origin node: "
1125                                    << (origin_node ? origin_node->DebugString() : "Node(Null)");
1126       }
1127     }
1128     conf_iter = engine_->anfnode_config_map().find(new_conf);
1129   }
1130   AddTodoItem(new_conf->node());
1131   auto repl = GetReplicatedNode(new_conf->node());
1132   if (repl->func_graph()) {
1133     MS_LOG(DEBUG) << "Set repl: graph(" << repl->func_graph()->ToString() << "), node: " << repl->DebugString()
1134                   << ") to replace origin: " << new_conf->node()->DebugString();
1135   } else {
1136     MS_LOG(DEBUG) << "Set repl: graph(nullptr), node(" << repl->DebugString()
1137                   << ") to replace origin: " << new_conf->node()->DebugString();
1138   }
1139   return repl;
1140 }
1141 
// Build the specialized node for function `func`, whose abstract is `abs`, called with `args_abs_list`.
// Returns nullptr when the inner builder yields no node but reports success (the child graph specializes).
// For the dead/poly status a value node holding the matching ValueProblem is returned instead; any other
// failure status raises an internal exception.
AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const CNodePtr &cnode, const AnfNodePtr &func,
                                                      const AbstractBasePtr &abs,
                                                      const AbstractBasePtrList &args_abs_list) {
  MS_EXCEPTION_IF_NULL(abs);
  MS_EXCEPTION_IF_NULL(func);
  auto abs_function = dyn_cast_ptr<AbstractFunction>(abs);
  MS_EXCEPTION_IF_NULL(abs_function);

  AbstractFunctionPtr func_abs = abs_function->GetUnique();
  SpecializeStatusCode errcode;
  ScopeGuard scope_guard(func->scope());
  AnfNodePtr specialized_node = BuildSpecializedNodeInner(cnode, func, abs, func_abs, args_abs_list, &errcode);
  if (specialized_node == nullptr) {
    // If errcode is success, it means child graph specialize.
    if (errcode == kSpecializeSuccess) {
      return nullptr;
    }
    const bool is_dead = (errcode == kSpecializeDead);
    if (!is_dead && errcode != kSpecializePoly) {
      MS_LOG(INTERNAL_EXCEPTION) << "Failed to build specialized func, func: " << func->DebugString()
                                 << ", abstract: " << abs->ToString();
    }
    // Dead and poly share the same construction; only the problem type and the log tag differ.
    const auto problem_value =
      std::make_shared<ValueProblem>(is_dead ? ValueProblemType::kDead : ValueProblemType::kPoly);
    const auto problem_abstract = std::make_shared<AbstractProblem>(problem_value, func);
    specialized_node = BuildValueNode(problem_value, cnode, problem_abstract);
    constexpr auto recursive_level = 2;
    MS_LOG(DEBUG) << (is_dead ? "DEAD" : "POLY") << " for func: " << func->DebugString(recursive_level)
                  << ", abstract: " << abs->ToString();
  }

  // Set the flag, so this MetaFuncGraph will be Re-AutoMonaded.
  MS_EXCEPTION_IF_NULL(func_abs);
  if (!func_abs->isa<MetaFuncGraphAbstractClosure>()) {
    return specialized_node;
  }
  auto specialized_fg = GetValuePtr<FuncGraph>(specialized_node);
  const bool ends_with_umonad = (specialized_fg != nullptr) && (args_abs_list.size() > 1) &&
                                (args_abs_list.back() != nullptr) && args_abs_list.back()->isa<AbstractUMonad>();
  if (ends_with_umonad) {
    specialized_fg->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
  }
  return specialized_node;
}
1188 
BuildSpecializedNodeInner(const CNodePtr & cnode,const AnfNodePtr & func,const AbstractBasePtr & abs,const AbstractFunctionPtr & func_abs,const AbstractBasePtrList & args,SpecializeStatusCode * errcode)1189 AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const CNodePtr &cnode, const AnfNodePtr &func,
1190                                                            const AbstractBasePtr &abs,
1191                                                            const AbstractFunctionPtr &func_abs,
1192                                                            const AbstractBasePtrList &args,
1193                                                            SpecializeStatusCode *errcode) {
1194   MS_EXCEPTION_IF_NULL(abs);
1195   MS_EXCEPTION_IF_NULL(func_abs);
1196   MS_EXCEPTION_IF_NULL(errcode);
1197   *errcode = kSpecializeSuccess;
1198   auto real_func = dyn_cast_ptr<TypedPrimitiveAbstractClosure>(func_abs);
1199   if (real_func != nullptr) {
1200     return BuildValueNode(real_func->prim(), cnode, abs);
1201   }
1202 
1203   EvaluatorPtr eval = engine_->GetEvaluatorFor(func_abs);
1204   MS_EXCEPTION_IF_NULL(eval);
1205   eval->set_bound_node(cnode);
1206   AbstractBasePtrList args_abs_list = eval->NormalizeArgs(args);
1207   std::pair<AbstractBasePtrList, AbstractBasePtr> result;
1208   SpecializeStatusCode status = AcquireUniqueEvalResult(func_abs, eval, args_abs_list, &result);
1209   if (status != kSpecializeSuccess) {
1210     *errcode = status;
1211     return nullptr;
1212   }
1213   args_abs_list = result.first;
1214   AbstractBasePtr unique_output = result.second;
1215 
1216   auto prim_func = dyn_cast_ptr<PrimitiveAbstractClosure>(func_abs);
1217   if (prim_func != nullptr) {
1218     auto type_func = std::make_shared<TypedPrimitiveAbstractClosure>(prim_func->prim(), args_abs_list, unique_output);
1219     return BuildValueNode(prim_func->prim(), cnode, type_func);
1220   }
1221 
1222   if (!eval->isa<BaseFuncGraphEvaluator>()) {
1223     MS_LOG(INTERNAL_EXCEPTION) << "Expect the eval is a BaseGraphEvaluator, but got " << eval->ToString()
1224                                << ", func: " << func->DebugString() << ", abs: " << func_abs->ToString()
1225                                << ", args: " << args;
1226   }
1227   auto real_eval = dyn_cast<BaseFuncGraphEvaluator>(eval);
1228 
1229   if (func_abs->context() == nullptr) {
1230     MS_LOG(INTERNAL_EXCEPTION) << "Func context is nullptr NodeInfo: "
1231                                << trace::GetDebugInfoStr(func_graph_->debug_info());
1232   }
1233   auto context = GetAnalysisContext(engine_, real_eval, args_abs_list);
1234   if (context == nullptr) {
1235     MS_LOG(INTERNAL_EXCEPTION) << "Failed to get context from static analysis cache, call node: "
1236                                << cnode->DebugString() << ", args: " << mindspore::ToString(args);
1237   }
1238 
1239   constexpr auto recursive_level = 2;
1240   MS_LOG(DEBUG) << "Specialize function graph: " << context->func_graph()->ToString() << ", args: " << args_abs_list
1241                 << ", func: " << func->DebugString(recursive_level) << ", context: " << context.get() << ", "
1242                 << context->ToString();
1243   MS_EXCEPTION_IF_NULL(context->func_graph());
1244   if (context->func_graph()->stub()) {
1245     MS_EXCEPTION_IF_NULL(context->func_graph()->get_return());
1246     MS_LOG(DEBUG) << "Specialize stub function graph, return the original node: " << context->func_graph()->ToString()
1247                   << ", args: " << args_abs_list.size()
1248                   << ", graph: " << context->func_graph()->get_return()->DebugString() << ", " << func->ToString();
1249     return func;
1250   }
1251   // Get the upper most func graph of which parent has been specialized.
1252   while (ParentNotSpecialized(context)) {
1253     context = context->parent();
1254   }
1255   auto fg_spec = specializer_->GetFuncGraphSpecializer(context);
1256   // If func graph specializer dose not exist before, make a new specializer and push to stack, and return nullptr.
1257   if (fg_spec == nullptr) {
1258     fg_spec = specializer_->NewFuncGraphSpecializer(context, context->func_graph());
1259     specializer_->PushFuncGraphTodoItem(fg_spec);
1260     return nullptr;
1261   }
1262 
1263   FuncGraphPtr func_graph = fg_spec->specialized_func_graph();
1264   MS_LOG(DEBUG) << "Get spec fg of func graph: " << context->func_graph()->ToString()
1265                 << ", specialized fg: " << func_graph->ToString();
1266   MS_EXCEPTION_IF_NULL(func_graph);
1267   func_graph->set_flag(kFuncGraphFlagUndetermined, false);
1268   static auto dummy_context = AnalysisContext::DummyContext();
1269   MS_EXCEPTION_IF_NULL(dummy_context);
1270   // Build a map that map unspecialized abstract function to specialized function, later it can be used
1271   // for specialize input0 of CNode in specialized func graph if input0 is not FuncGraph.
1272   auto new_abs_func = std::make_shared<FuncGraphAbstractClosure>(func_graph, dummy_context, nullptr, true);
1273   specializer_->SetSpecializedAbstract(func_abs, new_abs_func, cnode, func);
1274   if (func_abs->isa<FuncGraphAbstractClosure>()) {
1275     const auto &func_graph_abs = dyn_cast_ptr<FuncGraphAbstractClosure>(func_abs);
1276     specializer_->SetSpecializedFuncGraphToAbstract(func_graph_abs->func_graph(), new_abs_func);
1277   }
1278   return BuildValueNode(func_graph, cnode, new_abs_func);
1279 }
1280 
1281 // The CNode function is Parameter.
1282 // If the Parameter is PartialApp, unpack it and rebuild a new one.
BuildSpecializedParameterCNode(const CNodePtr & cnode)1283 AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterCNode(const CNodePtr &cnode) {
1284   MS_EXCEPTION_IF_NULL(cnode);
1285   auto new_inputs = cnode->weak_inputs();
1286   if (new_inputs.empty()) {
1287     MS_LOG(INTERNAL_EXCEPTION) << "inputs can't be empty.";
1288   }
1289   AnfNodePtr func = new_inputs[0].lock();
1290   MS_EXCEPTION_IF_NULL(func);
1291   AbstractBasePtr func_abs = func->abstract();
1292 
1293   AbstractBasePtrList args;
1294   auto real_func_abs = func_abs;
1295   MS_EXCEPTION_IF_NULL(func_abs);
1296   if (func_abs->isa<PartialAbstractClosure>()) {
1297     auto partial_closure = dyn_cast_ptr<PartialAbstractClosure>(func_abs);
1298     real_func_abs = partial_closure->fn();
1299     args = partial_closure->args();
1300   }
1301   (void)std::transform(new_inputs.cbegin() + 1, new_inputs.cend(), std::back_inserter(args),
1302                        [](const AnfNodeWeakPtr &weak_inp) -> AbstractBasePtr {
1303                          auto inp = weak_inp.lock();
1304                          MS_EXCEPTION_IF_NULL(inp);
1305                          return inp->abstract();
1306                        });
1307 
1308   ScopeGuard scope_guard(cnode->scope());
1309   auto specialized_node = BuildSpecializedNode(cnode, func, real_func_abs, args);
1310   if (specialized_node == nullptr) {
1311     return nullptr;
1312   }
1313 
1314   // Built for Non-Partial parameter function.
1315   if (!func_abs->isa<PartialAbstractClosure>()) {
1316     MS_LOG(DEBUG) << "cnode: " << cnode->DebugString() << ", func_abs: " << func_abs->ToString()
1317                   << ", specialized_node: " << specialized_node->DebugString();
1318     return specialized_node;
1319   }
1320 
1321   // To build for Partial parameter function.
1322   auto partial_closure = dyn_cast<PartialAbstractClosure>(func_abs);
1323   AnfNodePtrList partial_node_list = {BuildValueNode(prim::kPrimPartial, cnode, FromValueInside(prim::kPrimPartial)),
1324                                       specialized_node};
1325   auto partial_node = partial_closure->node();
1326   if (partial_node == nullptr) {
1327     MS_LOG(INTERNAL_EXCEPTION) << "Partial node is null, cnode: " << cnode->DebugString()
1328                                << ", func_abs: " << func_abs->ToString();
1329   }
1330   if (!partial_node->isa<CNode>()) {
1331     MS_LOG(INTERNAL_EXCEPTION) << "Must be cnode, but " << partial_node->DebugString();
1332   }
1333   auto partial_cnode = partial_node->cast<CNodePtr>();
1334   constexpr auto extra_args_size = 2;
1335   if (partial_cnode->size() != partial_closure->args().size() + extra_args_size) {
1336     MS_LOG(INTERNAL_EXCEPTION) << "Size of cnode: " << partial_cnode->DebugString()
1337                                << " is not equal to 2 added to size of args: "
1338                                << mindspore::ToString(partial_closure->args());
1339   }
1340   auto attrs = std::make_shared<AttrValueMap>();
1341   for (size_t i = 0; i < partial_closure->args().size(); i++) {
1342     auto old_node = partial_cnode->input(i + extra_args_size);
1343     MS_EXCEPTION_IF_NULL(old_node);
1344     auto possibile_value_node = BuildPossibleValueNode(old_node, partial_closure->args()[i], attrs);
1345     if (possibile_value_node != nullptr) {
1346       partial_node_list.push_back(possibile_value_node);
1347     } else {
1348       if (!(old_node->isa<CNode>() || old_node->isa<Parameter>())) {
1349         MS_LOG(INTERNAL_EXCEPTION) << "Old node should be CNode or Parameter, but " << old_node->ToString();
1350       }
1351       partial_node_list.push_back(old_node);
1352     }
1353   }
1354   MS_EXCEPTION_IF_NULL(cnode->func_graph());
1355   auto wrapped_node = cnode->func_graph()->NewCNode(std::move(partial_node_list));
1356   wrapped_node->set_abstract(partial_closure);
1357   MS_LOG(DEBUG) << "cnode: " << cnode->DebugString() << ", func_abs: " << func_abs->ToString()
1358                 << ", wrapped_node: " << wrapped_node->DebugString();
1359   return wrapped_node;
1360 }
1361 
GetEvalCache(const EvaluatorPtr & eval)1362 const EvaluatorCacheMgrPtr FuncGraphSpecializer::GetEvalCache(const EvaluatorPtr &eval) {
1363   MS_EXCEPTION_IF_NULL(eval);
1364   auto cache_iter = eval_cache_.find(eval);
1365   if (cache_iter == eval_cache_.end()) {
1366     eval_cache_[eval] = eval->evaluator_cache_mgr();
1367     return eval->evaluator_cache_mgr();
1368   }
1369   return cache_iter->second;
1370 }
1371 
BuildFromBroadedArgs(const EvaluatorPtr & eval)1372 std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromBroadedArgs(const EvaluatorPtr &eval) {
1373   MS_EXCEPTION_IF_NULL(eval);
1374   std::unordered_set<AbstractBasePtrList, AbstractBasePtrListHasher, AbstractBasePtrListEqual> choices;
1375   EvalResultPtr res = nullptr;
1376   AbstractBasePtrList broaded_args_list;
1377   std::vector<AbstractBasePtrList> args_vector;
1378   auto eval_cache_iter = eval_cache_.find(eval);
1379   if (eval_cache_iter == eval_cache_.end()) {
1380     MS_LOG(INTERNAL_EXCEPTION) << "Evaluator: " << eval->ToString() << " not exist in cache.";
1381   }
1382   MS_EXCEPTION_IF_NULL(eval_cache_iter->second);
1383   auto &origin_eval_cache = eval_cache_iter->second->GetCache();
1384   for (auto &args_map : origin_eval_cache) {
1385     auto args = args_map.first;
1386     args_vector.push_back(args);
1387   }
1388   // If joinable, maybe choices size is 1 or dynamic shape.
1389   constexpr auto args_size = 2;
1390   if (args_vector.size() < args_size) {
1391     MS_LOG(INTERNAL_EXCEPTION) << "Should have " << args_size << " or more choices, but: " << args_vector.size();
1392   }
1393   AbstractBasePtrList joined_args = args_vector[0];
1394   for (size_t i = 1; i < args_vector.size(); ++i) {
1395     // The args may be not joinable (AbstractScalar join with AbstractTensor), just ignore that case.
1396     try {
1397       MS_LOG_TRY_CATCH_SCOPE;
1398       joined_args = abstract::AbstractJoin(joined_args, args_vector[i]);
1399     } catch (const std::exception &e) {
1400       MS_LOG(DEBUG) << "Cannot join, args1: " << ::mindspore::ToString(joined_args)
1401                     << ", args2: " << ::mindspore::ToString(args_vector[i]);
1402       return std::make_pair(AbstractBasePtrList(), nullptr);
1403     }
1404   }
1405   MS_LOG(DEBUG) << "Joined args list: " << joined_args.size() << ", " << ::mindspore::ToString(joined_args);
1406 
1407   EvaluatorCacheMgrPtr real = std::make_shared<EvaluatorCacheMgr>();
1408   const auto joined_eval_result = origin_eval_cache.get(joined_args);
1409   if (joined_eval_result != nullptr) {
1410     MS_LOG(DEBUG) << "Find unique choice in original eval cache for joined args list: "
1411                   << joined_eval_result->abstract()->ToString();
1412     real->SetValue(joined_args, joined_eval_result);
1413     eval_cache_[eval] = real;
1414     return std::make_pair(joined_args, joined_eval_result->abstract());
1415   }
1416   for (const auto &args : args_vector) {
1417     broaded_args_list.clear();
1418     BroadenArgs(args, &broaded_args_list);
1419     (void)choices.insert(broaded_args_list);
1420     MS_LOG(DEBUG) << "Broaded args list: " << broaded_args_list.size() << ", "
1421                   << ::mindspore::ToString(broaded_args_list);
1422   }
1423   if (choices.size() == 1) {
1424     ConfigPtrList args_conf_list;
1425     (void)std::transform(broaded_args_list.cbegin(), broaded_args_list.cend(), std ::back_inserter(args_conf_list),
1426                          [](const AbstractBasePtr &v) -> ConfigPtr { return std::make_shared<VirtualConfig>(v); });
1427     MS_LOG(DEBUG) << "Cannot find joined args in cache, run with broaded args list: " << broaded_args_list.size()
1428                   << ", " << ::mindspore::ToString(broaded_args_list);
1429     res = eval->SingleRun(engine_, args_conf_list, nullptr);
1430     MS_EXCEPTION_IF_NULL(res);
1431     real->SetValue(broaded_args_list, res);
1432     eval_cache_[eval] = real;
1433     return std::make_pair(broaded_args_list, res->abstract());
1434   }
1435   MS_LOG(DEBUG) << "Choices.size: " << choices.size();
1436   return std::make_pair(AbstractBasePtrList(), nullptr);
1437 }
1438 
1439 namespace {
IsHighOrderCall(const AnfNodePtr & func)1440 bool IsHighOrderCall(const AnfNodePtr &func) {
1441   return !func->isa<ValueNode>() && func->abstract()->isa<AbstractFunction>() &&
1442          !func->abstract()->isa<AbstractFuncUnion>();
1443 }
1444 
1445 // Update inputs' user data from their abstracts to nodes.
UpdateInputsUserData(const CNodePtr & old_cnode,const AnfNodeWeakPtrList & new_weak_inputs)1446 void UpdateInputsUserData(const CNodePtr &old_cnode, const AnfNodeWeakPtrList &new_weak_inputs) {
1447   const auto &old_weak_inputs = old_cnode->weak_inputs();
1448   if (old_weak_inputs.size() != new_weak_inputs.size()) {
1449     MS_LOG(DEBUG) << "Old inputs size is not equal to new inputs size, node: " << old_cnode->DebugString();
1450     return;
1451   }
1452   // Update real type and shape info.
1453   for (size_t i = 0; i < old_cnode->size(); ++i) {
1454     const auto &old_input = old_weak_inputs[i].lock();
1455     MS_EXCEPTION_IF_NULL(old_input);
1456     const auto &old_input_abs = old_input->abstract();
1457     if (old_input_abs == nullptr) {
1458       MS_LOG(INTERNAL_EXCEPTION) << "The pointer 'old_input_abs' is null, old input node: " << old_input->DebugString();
1459     }
1460     auto new_weak_input = new_weak_inputs[i].lock();
1461     if (new_weak_input == nullptr) {
1462       MS_LOG(INTERNAL_EXCEPTION) << "The " << i << "th input is null, " << old_cnode->DebugString();
1463     }
1464     if (fallback::HasRealType(old_input_abs)) {
1465       const auto &real_type = fallback::GetRealType<AbstractBase, Type>(old_input_abs);
1466       fallback::SetRealType<AnfNode, Type>(new_weak_input, real_type);
1467     }
1468     if (fallback::HasRealShape(old_input_abs)) {
1469       const auto &real_type = fallback::GetRealShape<AbstractBase, BaseShape>(old_input_abs);
1470       fallback::SetRealShape<AnfNode, BaseShape>(new_weak_input, real_type);
1471     }
1472     if (fallback::HasObjInExtraInfoHolder(old_input_abs)) {
1473       MS_LOG(DEBUG) << "Inherit python list object from old input abstract.";
1474       auto list_py_obj = fallback::GetObjFromExtraInfoHolder(old_input_abs);
1475       fallback::AttachPyObjToExtraInfoHolder(new_weak_input->abstract(), list_py_obj, false);
1476     }
1477   }
1478 }
1479 
// Unpack nested Partial CNodes starting from `func`.
// For each Partial CNode met, its bound arguments (inputs from index 2) are inserted at the front of
// `*new_inputs_ptr`, and the walk continues with its function input (index 1).
// Returns the first function node that is not a Partial CNode.
// Raises an exception if `new_inputs_ptr` is null, if a Partial CNode has less than 2 inputs, or if the function
// input of a Partial CNode is no longer alive.
AnfNodePtr BuildRealInputsFromPartialCNode(const AnfNodePtr &func, AnfNodeWeakPtrList *new_inputs_ptr) {
  MS_EXCEPTION_IF_NULL(new_inputs_ptr);
  auto &new_inputs = *new_inputs_ptr;
  AnfNodePtr real_func = func;
  constexpr int arg_start_index = 2;
  constexpr size_t min_partial_inputs_size = 2;
  while (IsPrimitiveCNode(real_func, prim::kPrimPartial)) {
    auto func_cnode = real_func->cast_ptr<CNode>();
    MS_EXCEPTION_IF_NULL(func_cnode);
    auto &inputs = func_cnode->weak_inputs();
    // Guard the accesses to inputs[1] and cbegin() + 2 below.
    if (inputs.size() < min_partial_inputs_size) {
      MS_LOG(INTERNAL_EXCEPTION) << "Partial CNode should have at least " << min_partial_inputs_size
                                 << " inputs, but got " << inputs.size() << ", node: " << func_cnode->DebugString();
    }
    // First element is partial, second is func so arg is start from 2
    (void)new_inputs.insert(new_inputs.cbegin(), inputs.cbegin() + arg_start_index, inputs.cend());
    auto wrapped_func = inputs[1].lock();
    MS_EXCEPTION_IF_NULL(wrapped_func);
    // Log before `real_func` is reassigned, `func_cnode` and `inputs` refer to the node it currently holds.
    MS_LOG(DEBUG) << "Real func: " << wrapped_func->ToString() << ", func_cnode: " << func_cnode->DebugString()
                  << ", new_inputs size: " << new_inputs.size();
    real_func = wrapped_func;
  }
  return real_func;
}
1496 
1497 // If it's Partial CNode, repack the inputs.
1498 // CNode(CNode(Partial, f, arg1), arg2, ...) --> CNode(f, arg1, arg2, ...)
GetCNodeRealInputs(const CNodePtr & cnode)1499 AnfNodeWeakPtrList GetCNodeRealInputs(const CNodePtr &cnode) {
1500   auto &inputs = cnode->weak_inputs();
1501   if (inputs.empty()) {
1502     MS_LOG(INTERNAL_EXCEPTION) << "Inputs of CNode is empty";
1503   }
1504   AnfNodePtr func = inputs[0].lock();
1505   MS_EXCEPTION_IF_NULL(func);
1506   if (!IsPrimitiveCNode(func, prim::kPrimPartial)) {
1507     return inputs;
1508   }
1509 
1510   // First element is func, so start from 1.
1511   AnfNodeWeakPtrList new_inputs(inputs.begin() + 1, inputs.end());
1512   func = BuildRealInputsFromPartialCNode(func, &new_inputs);
1513   (void)new_inputs.insert(new_inputs.cbegin(), func);
1514   cnode->func_graph()->AddOwnNode(func);
1515   return new_inputs;
1516 }
1517 }  // namespace
1518 
ProcessCNodeEnd(const CNodePtr & cnode,const AnfNodeWeakPtrList & new_weak_inputs)1519 void FuncGraphSpecializer::ProcessCNodeEnd(const CNodePtr &cnode, const AnfNodeWeakPtrList &new_weak_inputs) {
1520   // Update inputs' user data from their abstracts to nodes.
1521   UpdateInputsUserData(cnode, new_weak_inputs);
1522   // Set the updated inputs.
1523   cnode->set_weak_inputs(new_weak_inputs);
1524 
1525   // Eliminate the unused elements in the tuple/list.
1526   static const auto enable_eliminate_unused_element = (common::GetCompileConfig("ENABLE_DDE") != "0");
1527   static const auto enable_only_mark_unused_element = (common::GetCompileConfig("DDE_ONLY_MARK") == "1");
1528   if (enable_eliminate_unused_element && !enable_only_mark_unused_element) {
1529     EliminateUnusedSequenceItem(cnode);
1530   }
1531   constexpr auto recursive_level = 2;
1532   // Only success processed node can be added to seen.
1533   MS_LOG(DEBUG) << "New CNode: " << cnode->DebugString(recursive_level);
1534   specializer_->AddSeen(cnode);
1535 }
1536 
1537 // Process Switch App CNode in advance.
1538 // Including: Switch App CNode, Switch CNode, and Switch inputs CNodes(Partial CNode).
ProcessSwitchAppCNode(const CNodePtr & cnode)1539 bool FuncGraphSpecializer::ProcessSwitchAppCNode(const CNodePtr &cnode) {
1540   auto new_switch_app_inputs = cnode->weak_inputs();
1541   if (new_switch_app_inputs.empty()) {
1542     MS_LOG(INTERNAL_EXCEPTION) << "Inputs of CNode is empty";
1543   }
1544   const AnfNodePtr &func = new_switch_app_inputs[0].lock();
1545   MS_EXCEPTION_IF_NULL(func);
1546   if (!IsPrimitiveCNode(func, prim::kPrimSwitch)) {
1547     return false;
1548   }
1549   const auto &switch_cnode = dyn_cast<CNode>(func);
1550   auto new_switch_inputs = switch_cnode->weak_inputs();
1551   if (new_switch_inputs.empty()) {
1552     MS_LOG(INTERNAL_EXCEPTION) << "Switch CNode input is empty";
1553   }
1554 
1555   // Specialize the switch app fg arguments, from index 1(cond).
1556   bool finished = true;
1557   constexpr size_t switch_fg_arg_start_index = 1;
1558   constexpr size_t switch_fg_arg_end_index = 4;
1559   for (size_t i = switch_fg_arg_start_index; i < switch_fg_arg_end_index; ++i) {
1560     auto switch_input_node = new_switch_inputs[i].lock();
1561     MS_EXCEPTION_IF_NULL(switch_input_node);
1562     CNodePtr switch_input_cnode = nullptr;
1563     AnfNodePtr real_switch_input_cnode_func = nullptr;
1564     AnfNodeWeakPtrList real_switch_input_cnode_inputs;
1565     if (IsPrimitiveCNode(switch_input_node, prim::kPrimPartial)) {
1566       switch_input_cnode = dyn_cast<CNode>(switch_input_node);
1567       MS_EXCEPTION_IF_NULL(switch_input_cnode);
1568       real_switch_input_cnode_func =
1569         BuildRealInputsFromPartialCNode(switch_input_cnode, &real_switch_input_cnode_inputs);
1570     } else {
1571       if (!IsValueNode<FuncGraph>(switch_input_node)) {
1572         // The Switch input[i] is not Partial CNode, or FuncGraph node
1573         continue;
1574       }
1575       real_switch_input_cnode_func = switch_input_node;
1576       // Since BuildSpecializedNode() 1st argument CNode is used for debug info, we use switch node for FuncGraph input.
1577       switch_input_cnode = switch_cnode;
1578     }
1579 
1580     if (!CanSpecializeValueNode(real_switch_input_cnode_func)) {
1581       continue;
1582     }
1583     constexpr size_t switch_app_arg_start_index = 1;
1584     for (size_t j = switch_app_arg_start_index; j < new_switch_app_inputs.size(); ++j) {
1585       (void)real_switch_input_cnode_inputs.emplace_back(new_switch_app_inputs[j]);
1586     }
1587     AbstractBasePtrList args;
1588     AbstractBasePtr func_abs = real_switch_input_cnode_func->abstract();
1589     // First element is function, so the arguments start from 1.
1590     for (size_t j = 0; j < real_switch_input_cnode_inputs.size(); ++j) {
1591       args.push_back(real_switch_input_cnode_inputs[j].lock()->abstract());
1592     }
1593     auto specialized_func_node = BuildSpecializedNode(switch_input_cnode, real_switch_input_cnode_func, func_abs, args);
1594     if (specialized_func_node == nullptr) {
1595       finished = false;
1596       continue;
1597     }
1598     if (!finished) {
1599       continue;
1600     }
1601     // Rebuild a Partial CNode.
1602     if (!IsDeadNode(specialized_func_node) && IsPrimitiveCNode(switch_input_node, prim::kPrimPartial)) {
1603       // Fill new Partial CNode's inputs list.
1604       AnfNodePtr partial_value_node = NewValueNode(prim::kPrimPartial);
1605       partial_value_node->set_abstract(FromValueInside(prim::kPrimPartial));
1606       partial_value_node->set_debug_info(switch_input_node->debug_info());
1607       MS_EXCEPTION_IF_NULL(switch_input_cnode->func_graph());
1608       switch_input_cnode->func_graph()->AddOwnNode(partial_value_node);
1609       switch_input_cnode->func_graph()->AddOwnNode(specialized_func_node);
1610       AnfNodeWeakPtrList partial_node_list = {partial_value_node, specialized_func_node};
1611       // Specialize Partial CNode func graph inputs.
1612       constexpr auto partial_arg_start_index = 2;
1613       (void)std::copy(switch_input_cnode->weak_inputs().cbegin() + partial_arg_start_index,
1614                       switch_input_cnode->weak_inputs().cend(), std::back_inserter(partial_node_list));
1615       for (size_t j = partial_arg_start_index; j < partial_node_list.size(); ++j) {
1616         auto old_node = partial_node_list[j].lock();
1617         MS_EXCEPTION_IF_NULL(old_node);
1618         if (CanSpecializeValueNode(old_node)) {
1619           auto new_partial_input_node =
1620             BuildSpecializedNode(switch_input_cnode, old_node, old_node->abstract(), std::vector<AbstractBasePtr>{});
1621           if (new_partial_input_node == nullptr) {
1622             return false;
1623           }
1624           partial_node_list[j] = new_partial_input_node;
1625           switch_input_cnode->func_graph()->AddOwnNode(new_partial_input_node);
1626         }
1627       }
1628 
1629       // Finish the Partial CNode specialize.
1630       MS_EXCEPTION_IF_NULL(switch_input_cnode);
1631       ProcessCNodeEnd(switch_input_cnode, partial_node_list);
1632       new_switch_inputs[i] = switch_input_cnode;
1633     } else {
1634       new_switch_inputs[i] = specialized_func_node;
1635     }
1636   }
1637 
1638   // Wait for sub func graph specialize finish.
1639   if (!finished) {
1640     return false;
1641   }
1642 
1643   ProcessCNodeEnd(switch_cnode, new_switch_inputs);
1644 
1645   new_switch_app_inputs[0] = switch_cnode;
1646   ProcessCNodeEnd(cnode, new_switch_app_inputs);
1647 
1648   return true;
1649 }
1650 
ProcessCNode(const CNodePtr & cnode)1651 bool FuncGraphSpecializer::ProcessCNode(const CNodePtr &cnode) {
1652   MS_EXCEPTION_IF_NULL(cnode);
1653   if (specializer_->seen().count(cnode) > 0) {
1654     return true;
1655   }
1656   constexpr auto recursive_level = 2;
1657   MS_LOG(DEBUG) << "Handle CNode: " << cnode->DebugString(recursive_level);
1658   auto new_inputs = GetCNodeRealInputs(cnode);
1659   const AnfNodePtr &func = new_inputs[0].lock();
1660 
1661   // Deal with Switch App CNode.
1662   static const bool enable_pre_lift = (common::GetCompileConfig("PRE_LIFT") == "1");
1663   if (enable_pre_lift && IsPrimitiveCNode(func, prim::kPrimSwitch)) {
1664     return ProcessSwitchAppCNode(cnode);
1665   }
1666 
1667   // Deal with the CNode|Parameter function call including Partial closure ahead.
1668   if (IsHighOrderCall(func)) {
1669     MS_EXCEPTION_IF_NULL(func->abstract());
1670     auto func_abs = func->abstract()->cast<AbstractFunctionPtr>();
1671     EvaluatorPtr eval = engine_->GetEvaluatorFor(func_abs);
1672     std::pair<AbstractBasePtrList, AbstractBasePtr> result;
1673     AbstractBasePtrList empty_args;
1674     auto status = AcquireUniqueEvalResult(func_abs, eval, empty_args, &result);
1675     MS_EXCEPTION_IF_NULL(func->func_graph());
1676     MS_LOG(DEBUG) << "POLY: " << (status == kSpecializePoly) << ", func: " << func->ToString()
1677                   << ", abstract: " << func_abs->ToString() << ", "
1678                   << func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER);
1679     // If a node is a poly node, or an input parameter is a PartialAbstractClosure, expand it early.
1680     if (status == kSpecializePoly ||
1681         (func->isa<Parameter>() && func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER))) {
1682       auto wrapped_node = BuildSpecializedParameterCNode(cnode);
1683       if (wrapped_node == nullptr) {
1684         return false;
1685       }
1686       MS_LOG(DEBUG) << "Partial closure or parameter call is handled, wrapped_node: "
1687                     << wrapped_node->DebugString(recursive_level);
1688       new_inputs[0] = wrapped_node;
1689       cnode->func_graph()->AddOwnNode(wrapped_node);
1690     }
1691   }
1692 
1693   // Specialize the function, aka inputs[0], if input0 is a ValueNode<FuncGraph> or ValueNode<Primitive>,
1694   // CanSpecializeValueNode return true, otherwise false.
1695   if (CanSpecializeValueNode(func)) {
1696     // For primitive node, we build the primitive node with inferred attributes in the first pass,
1697     // so we do not build replaced node again here in second pass.
1698     if (IsValueNode<Primitive>(func)) {
1699       new_inputs[0] = func;
1700       cnode->func_graph()->AddOwnNode(func);
1701     } else {
1702       AbstractBasePtrList args;
1703       AbstractBasePtr func_abs = new_inputs[0].lock()->abstract();
1704       // First element is function, so the arguments start from 1.
1705       for (size_t i = 1; i < new_inputs.size(); ++i) {
1706         args.push_back(new_inputs[i].lock()->abstract());
1707       }
1708       auto specialized_func_node = BuildSpecializedNode(cnode, func, func_abs, args);
1709       if (specialized_func_node == nullptr) {
1710         return false;
1711       }
1712 
1713       new_inputs[0] = specialized_func_node;
1714       cnode->func_graph()->AddOwnNode(specialized_func_node);
1715       MS_LOG(DEBUG) << "Specalize func: " << func->type_name() << "/" << func->DebugString(recursive_level)
1716                     << ", new_func: " << new_inputs[0].lock()->DebugString(recursive_level) << ", args: " << args;
1717     }
1718   }
1719 
1720   // Specialize the arguments, except inputs[0].
1721   for (size_t i = 1; i < new_inputs.size(); ++i) {
1722     auto old_node = new_inputs[i].lock();
1723     if (CanSpecializeValueNode(old_node)) {
1724       auto new_node = BuildSpecializedNode(cnode, old_node, old_node->abstract(), std::vector<AbstractBasePtr>{});
1725       if (new_node == nullptr) {
1726         return false;
1727       }
1728 
1729       MS_LOG(DEBUG) << "Specalize arg[" << i << "]: " << old_node->DebugString(recursive_level)
1730                     << ", new_node: " << new_node->DebugString(recursive_level);
1731       new_inputs[i] = new_node;
1732       cnode->func_graph()->AddOwnNode(new_node);
1733     }
1734   }
1735   ProcessCNodeEnd(cnode, new_inputs);
1736   return true;
1737 }
1738 
ParentNotSpecialized(const AnalysisContextPtr & context) const1739 bool FuncGraphSpecializer::ParentNotSpecialized(const AnalysisContextPtr &context) const {
1740   auto parent_context = context->parent();
1741   auto parent_specializer = specializer_->GetFuncGraphSpecializer(parent_context);
1742   // If can't get specializer of parent and parent is not DummyContext, it means parent not specialized.
1743   auto parent_not_specialized = parent_specializer == nullptr && parent_context->func_graph() != nullptr;
1744   return parent_not_specialized;
1745 }
1746 
1747 namespace {
DumpEvaluatorCache(const EvaluatorPtr & eval,const AbstractBasePtrList & args_abs_list)1748 void DumpEvaluatorCache(const EvaluatorPtr &eval, const AbstractBasePtrList &args_abs_list) {
1749   MS_EXCEPTION_IF_NULL(eval);
1750   const EvaluatorCacheMgrPtr &evaluator_cache_mgr = eval->evaluator_cache_mgr();
1751   MS_EXCEPTION_IF_NULL(evaluator_cache_mgr);
1752   MS_LOG(DEBUG) << "Find unique args_abs_list failed, total " << args_abs_list.size() << ". Check cache all items.";
1753   MS_LOG(DEBUG) << "[" << eval << "/" << eval->ToString()
1754                 << "] Dump current key, args_abs_list hash: " << AbstractBasePtrListHash(args_abs_list)
1755                 << ", args_abs_list: " << args_abs_list;
1756 
1757   int64_t i = 0;
1758   const EvalResultCache &map = evaluator_cache_mgr->GetCache();
1759   for (const auto &item : map) {
1760     MS_LOG(DEBUG) << "\tevaluator_cache[" << i++ << "]: {args_abs_list hash: " << AbstractBasePtrListHash(item.first)
1761                   << ", args_abs_list: " << item.first << "}";
1762   }
1763 }
1764 
IsPolyFunc(const AbstractFunctionPtr & func,const AbstractBasePtrList & args_abs_list)1765 bool IsPolyFunc(const AbstractFunctionPtr &func, const AbstractBasePtrList &args_abs_list) {
1766   MS_EXCEPTION_IF_NULL(func);
1767   if (func->isa<PrimitiveAbstractClosure>() && args_abs_list.empty()) {
1768     MS_LOG(DEBUG) << "High order primitive return POLY.";
1769     return true;
1770   }
1771   if (func->isa<MetaFuncGraphAbstractClosure>() && args_abs_list.empty()) {
1772     auto meta_func_graph_wrapper = dyn_cast_ptr<MetaFuncGraphAbstractClosure>(func);
1773     auto meta_func_graph = meta_func_graph_wrapper->meta_func_graph();
1774     if (meta_func_graph != nullptr && meta_func_graph->isa<prim::DoSignatureMetaFuncGraph>()) {
1775       auto do_signature = dyn_cast_ptr<prim::DoSignatureMetaFuncGraph>(meta_func_graph);
1776       if (do_signature != nullptr && do_signature->function()->isa<Primitive>()) {
1777         MS_LOG(DEBUG) << "High order primitive " << do_signature->function()->ToString() << " return POLY.";
1778         return true;
1779       }
1780     }
1781   }
1782   return false;
1783 }
1784 }  // namespace
1785 
AcquireUniqueEvalResult(const AbstractFunctionPtr & func,const EvaluatorPtr & eval,const AbstractBasePtrList & args_abs_list,std::pair<AbstractBasePtrList,AbstractBasePtr> * res)1786 SpecializeStatusCode FuncGraphSpecializer::AcquireUniqueEvalResult(
1787   const AbstractFunctionPtr &func, const EvaluatorPtr &eval, const AbstractBasePtrList &args_abs_list,
1788   std::pair<AbstractBasePtrList, AbstractBasePtr> *res) {
1789   MS_EXCEPTION_IF_NULL(func);
1790   MS_EXCEPTION_IF_NULL(eval);
1791   MS_EXCEPTION_IF_NULL(res);
1792 
1793   EvaluatorCacheMgrPtr evaluator_cache_mgr = eval->evaluator_cache_mgr();
1794   MS_EXCEPTION_IF_NULL(evaluator_cache_mgr);
1795   auto data = evaluator_cache_mgr->GetValue(args_abs_list);
1796   if (data != nullptr) {
1797     *res = std::make_pair(args_abs_list, data->abstract());
1798     return kSpecializeSuccess;
1799   }
1800   DumpEvaluatorCache(eval, args_abs_list);
1801 
1802   auto cache = GetEvalCache(eval);
1803   MS_EXCEPTION_IF_NULL(cache);
1804   const EvalResultCache &choices = cache->GetCache();
1805   auto eval_result = choices.get(args_abs_list);
1806   if (eval_result != nullptr) {
1807     *res = std::make_pair(args_abs_list, eval_result->abstract());
1808     return kSpecializeSuccess;
1809   } else if (choices.size() == 1) {
1810     MS_LOG(DEBUG) << "Evaluator cache has a single item, just use it.";
1811     MS_EXCEPTION_IF_NULL(choices.begin()->second);
1812     *res = std::make_pair(choices.begin()->first, choices.begin()->second->abstract());
1813     return kSpecializeSuccess;
1814   } else if (choices.empty()) {
1815     MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase " << func->ToString() << " | "
1816                   << func->type_name() << ", evaluator: " << eval->ToString() << ", ptr: " << eval.get();
1817     return kSpecializeDead;
1818   } else {
1819     if (IsPolyFunc(func, args_abs_list)) {
1820       return kSpecializePoly;
1821     }
1822     *res = BuildFromBroadedArgs(eval);
1823     if (!res->first.empty()) {
1824       MS_LOG(DEBUG) << "Build for generalized args_abs_list successfully.";
1825       // Synchronize the new evaluated abstract with the abstract from common evaluating routine.
1826       MS_EXCEPTION_IF_NULL(res->second);
1827       auto new_sequence_abs = dyn_cast<abstract::AbstractSequence>(res->second);
1828       for (auto &choice : choices) {
1829         MS_EXCEPTION_IF_NULL(choice.second);
1830         MS_EXCEPTION_IF_NULL(choice.second->abstract());
1831         auto abs = choice.second->abstract()->cast<AbstractSequencePtr>();
1832         if (abs != nullptr) {
1833           SynchronizeSequenceElementsUseFlagsRecursively(abs, new_sequence_abs);
1834         }
1835       }
1836       return kSpecializeSuccess;
1837     }
1838     MS_LOG(DEBUG) << "Found POLY node, it may be unused code or unresolved polymorphism, "
1839                   << "func: " << func->ToString() << ", choices.size: " << choices.size()
1840                   << ", args_abs_list.size: " << args_abs_list.size();
1841     return kSpecializePoly;
1842   }
1843 }
1844 
// Return a primitive carrying every attribute in `attrs`.
// If `prim` already has each attribute in `attrs` with an equal value, `prim` itself is returned.
// Otherwise `prim` is cloned and all entries of `attrs` are added to the clone, which is returned.
// Raises if `prim` or `attrs` is null, or if a compared attribute value is null.
static PrimitivePtr BuildPrimtiveValueWithAttributes(const PrimitivePtr &prim, const AttrValueMapPtr &attrs) {
  MS_EXCEPTION_IF_NULL(prim);
  // `attrs` is dereferenced below, so reject a null map up front instead of crashing.
  MS_EXCEPTION_IF_NULL(attrs);
  const auto &prim_attrs = prim->attrs();
  bool needs_clone = false;
  for (const auto &wanted : *attrs) {
    const auto found = prim_attrs.find(wanted.first);
    if (found == prim_attrs.end()) {
      // The primitive lacks this attribute entirely.
      needs_clone = true;
      break;
    }
    MS_EXCEPTION_IF_NULL(found->second);
    MS_EXCEPTION_IF_NULL(wanted.second);
    if (!(*(found->second) == *(wanted.second))) {
      needs_clone = true;
      break;
    }
  }
  if (!needs_clone) {
    return prim;
  }
  // Leave the given primitive untouched; the attributes go onto a clone.
  auto cloned_prim = prim->Clone();
  for (const auto &wanted : *attrs) {
    cloned_prim->AddAttr(wanted.first, wanted.second);
  }
  return cloned_prim;
}
1872 
GetValueForAbstractFunction(const AbstractFunctionPtr & abs,const AttrValueMapPtr & attrs)1873 ValuePtr GetValueForAbstractFunction(const AbstractFunctionPtr &abs, const AttrValueMapPtr &attrs) {
1874   ValuePtr value = nullptr;
1875   if (abs->isa<PrimitiveAbstractClosure>()) {
1876     auto real_fn = dyn_cast_ptr<PrimitiveAbstractClosure>(abs);
1877     MS_EXCEPTION_IF_NULL(real_fn);
1878     // For primitive, check if the attribute is the same with cnode inferred attribute, if not, clone a new one
1879     if (attrs != nullptr) {
1880       value = BuildPrimtiveValueWithAttributes(real_fn->prim(), attrs);
1881     } else {
1882       value = real_fn->prim();
1883     }
1884   } else if (abs->isa<MetaFuncGraphAbstractClosure>()) {
1885     auto real_fn = dyn_cast_ptr<MetaFuncGraphAbstractClosure>(abs);
1886     value = real_fn->meta_func_graph();
1887   } else if (abs->isa<FuncGraphAbstractClosure>()) {
1888     auto real_fn = dyn_cast_ptr<FuncGraphAbstractClosure>(abs);
1889     value = real_fn->func_graph();
1890   } else {
1891     return nullptr;
1892   }
1893   return value;
1894 }
1895 
// Try to build a ValueNode for the value behind the abstract function `abs`.
// Returns nullptr when no value can be extracted, when the value is a func graph flagged
// FUNC_GRAPH_RECOMPUTE_GRAD_GRAPH, or when neither of the two cases below applies.
// A ValueNode is built via BuildValueNode(value, origin_node, ival) when:
//   1. the value is not a func graph, or it is a func graph with no parent, or `origin_node` is a
//      func graph ValueNode and the parent is visible from `func_graph_`; or
//   2. `cnode` is a J primitive CNode, `origin_node` is a Parameter and the func graph is not
//      flagged FUNC_GRAPH_FLAG_K_GRAPH.
AnfNodePtr FuncGraphSpecializer::BuildValueNodeForAbstractFunction(const AnfNodePtr &origin_node,
                                                                   const AbstractBasePtr &ival,
                                                                   const AttrValueMapPtr &attrs,
                                                                   const AnfNodePtr &cnode,
                                                                   const AbstractFunctionPtr &abs) {
  ValuePtr value = GetValueForAbstractFunction(abs, attrs);
  if (value == nullptr) {
    return nullptr;
  }
  // Non-null only when the value is a func graph.
  auto graph = value->isa<FuncGraph>() ? value->cast_ptr<FuncGraph>() : nullptr;
  if (graph != nullptr && graph->has_flag(FUNC_GRAPH_RECOMPUTE_GRAD_GRAPH)) {
    return nullptr;
  }

  const bool can_build_directly =
    graph == nullptr || graph->parent() == nullptr ||
    (IsValueNode<FuncGraph>(origin_node) && IsVisible(func_graph_, graph->parent()));
  if (can_build_directly) {
    if (IS_OUTPUT_ON(MsLogLevel::kDebug)) {
      if (cnode != nullptr) {
        MS_LOG(DEBUG) << "Specialize non-value to func graph, value: " << value->ToString()
                      << ", cnode: " << cnode->DebugString() << ", origin_node: " << origin_node->DebugString()
                      << ", func_graph_: " << func_graph_->ToString();
      }
      if (graph != nullptr && graph->parent() != nullptr) {
        MS_LOG(DEBUG) << "Specialize func graph, " << value->ToString()
                      << " has_parent, is_visible: " << IsVisible(func_graph_, graph->parent());
      }
    }
    return BuildValueNode(value, origin_node, ival);
  }

  // From here on the value is a func graph, so `graph` is non-null.
  // Only if J(Parameter=func_graph) and func_graph(aka 'value') is not K graph.
  const bool is_j_parameter = cnode != nullptr && IsPrimitiveCNode(cnode, prim::kPrimJ) &&
                              origin_node->isa<Parameter>() && !graph->has_flag(FUNC_GRAPH_FLAG_K_GRAPH);
  if (!is_j_parameter) {
    return nullptr;
  }
  MS_LOG(DEBUG) << "Specialize the parameter used by J CNode, cnode: " << cnode->DebugString();
  return BuildValueNode(value, origin_node, ival);
}
1930 
// Try to replace `origin_node` with a ValueNode derived from its abstract `ival`.
// For an abstract function, the work is delegated to BuildValueNodeForAbstractFunction(), except
// that an AbstractFuncUnion yields nullptr. For any other abstract, the value built from `ival`
// is wrapped in a ValueNode unless one of the exclusions below applies.
// Returns nullptr when no ValueNode is built. Raises if `origin_node` or `ival` is null.
AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival,
                                                        const AttrValueMapPtr &attrs, const AnfNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(origin_node);
  MS_EXCEPTION_IF_NULL(ival);

  AbstractFunctionPtr func_abs = dyn_cast<AbstractFunction>(ival);
  if (func_abs != nullptr) {
    // Cannot build a deterministic ValueNode if there are multiple possible AbstractFunction.
    if (func_abs->isa<AbstractFuncUnion>()) {
      return nullptr;
    }
    return BuildValueNodeForAbstractFunction(origin_node, ival, attrs, cnode, func_abs);
  }

  ValuePtr built_value = ival->BuildValue();
  if (built_value->ContainsValueAny()) {
    return nullptr;
  }
  // If node is an AutoMonad node, don't convert the node to value node `U` or `IO` to avoid side-effect op miss.
  if (built_value->isa<Monad>()) {
    return nullptr;
  }
  // Keep primitives 'depend', 'ListInplaceClear' and 'PyExecute' not to be optimized.
  const bool keep_as_cnode = IsPrimitiveCNode(origin_node, prim::kPrimDepend) ||
                             IsPrimitiveCNode(origin_node, prim::kPrimListInplaceClear) ||
                             IsPrimitiveCNode(origin_node, prim::kPrimPyExecute);
  if (keep_as_cnode) {
    return nullptr;
  }
  return BuildValueNode(built_value, origin_node, ival);
}
1967 
GetAnalysisContext(const AnalysisEnginePtr & engine,const BaseFuncGraphEvaluatorPtr & evaluator,const AbstractBasePtrList & args_abs_list) const1968 inline AnalysisContextPtr FuncGraphSpecializer::GetAnalysisContext(const AnalysisEnginePtr &engine,
1969                                                                    const BaseFuncGraphEvaluatorPtr &evaluator,
1970                                                                    const AbstractBasePtrList &args_abs_list) const {
1971   MS_EXCEPTION_IF_NULL(evaluator);
1972   // If it is common calling header, try to use the context generated by the infer process of body calling header, so
1973   // need broaden the args to keep context of common calling header same with context of body calling header.
1974   AbstractBasePtrList normalized_args_abs_list = evaluator->NormalizeArgs(args_abs_list);
1975   FuncGraphPtr fg = evaluator->GetFuncGraph(engine, normalized_args_abs_list);
1976   auto parent_context = evaluator->parent_context();
1977   MS_EXCEPTION_IF_NULL(parent_context);
1978   auto cached_context = parent_context->GetCachedContext(fg, normalized_args_abs_list);
1979   if (cached_context != nullptr) {
1980     return cached_context;
1981   }
1982   // If can't get context by broadened args, try to get context by not broadened args.
1983   cached_context = parent_context->GetCachedContext(fg, args_abs_list);
1984   if (cached_context != nullptr) {
1985     return cached_context;
1986   }
1987   // if it is a bprop meta func graph, need to make a new context and do static analysis in ProcessNode.
1988   return NewContext(parent_context, fg, normalized_args_abs_list);
1989 }
1990 }  // namespace abstract
1991 }  // namespace mindspore
1992