• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2021 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "pipeline/jit/static_analysis/program_specialize.h"
20 
21 #include <algorithm>
22 #include <exception>
23 #include "frontend/operator/ops.h"
24 #include "frontend/operator/composite/do_signature.h"
25 #include "abstract/abstract_function.h"
26 #include "abstract/utils.h"
27 #include "utils/utils.h"
28 #include "ir/graph_utils.h"
29 #include "utils/log_adapter.h"
30 #include "debug/trace.h"
31 
32 namespace mindspore {
33 namespace abstract {
34 namespace {
GetEvaluatedValue(const AnfNodeConfigPtr & conf)35 inline AbstractBasePtr GetEvaluatedValue(const AnfNodeConfigPtr &conf) {
36   MS_EXCEPTION_IF_NULL(conf);
37   if (conf->node()->intermediate_abstract()) {
38     return conf->node()->intermediate_abstract();
39   }
40   MS_EXCEPTION_IF_NULL(conf->ObtainEvalResult());
41   return conf->ObtainEvalResult()->abstract();
42 }
43 
BuildValueNode(const ValuePtr & v,const AbstractBasePtr & abs_base)44 AnfNodePtr BuildValueNode(const ValuePtr &v, const AbstractBasePtr &abs_base) {
45   MS_EXCEPTION_IF_NULL(abs_base);
46   AnfNodePtr value_node = NewValueNode(v);
47   value_node->set_abstract(abs_base);
48   MS_LOG(DEBUG) << "Create ValueNode: " << value_node->ToString() << ", with abstract: " << abs_base->ToString();
49   return value_node;
50 }
51 
IsVisible(FuncGraphPtr fg,const FuncGraphPtr & parent)52 bool IsVisible(FuncGraphPtr fg, const FuncGraphPtr &parent) {
53   while (fg != nullptr && fg != parent) {
54     fg = fg->parent();
55   }
56   return fg == parent;
57 }
58 
CheckAbstractTensor(const AbstractBasePtr & abs_base)59 bool CheckAbstractTensor(const AbstractBasePtr &abs_base) {
60   MS_EXCEPTION_IF_NULL(abs_base);
61   if (abs_base->isa<AbstractTensor>()) {
62     return true;
63   } else if (abs_base->isa<AbstractSequeue>()) {
64     const auto &abs_seq = abs_base->cast<AbstractSequeuePtr>();
65     MS_EXCEPTION_IF_NULL(abs_seq);
66     const auto &elements = abs_seq->elements();
67     return std::all_of(elements.cbegin(), elements.cend(), [](const auto &v) { return CheckAbstractTensor(v); });
68   } else {
69     return false;
70   }
71 }
72 }  // namespace
73 
// Entry point of program specialization: specializes `fg` under `context`.
// The first context passed in is remembered as the top-level context; FirstPass and
// GetTopSpecializer consult it (via top_context()) for forwarded Parameters.
FuncGraphPtr ProgramSpecializer::Run(const FuncGraphPtr &fg, const AnalysisContextPtr &context) {
  MS_EXCEPTION_IF_NULL(fg);
  MS_EXCEPTION_IF_NULL(context);
  MS_LOG(DEBUG) << "Specialize topmost function graph: "
                << (context->func_graph() != nullptr ? context->func_graph()->ToString() : "FG(Null)");
  const bool is_first_run = (top_context_ == nullptr);
  if (is_first_run) {
    top_context_ = context;
    MS_LOG(INFO) << "Specialize set top func graph context: " << context->ToString();
  }
  return SpecializeFuncGraph(fg, context);
}
85 
// Returns the specialized clone of `fg` for `context`. On the first request for a given
// specialization key, a FuncGraphSpecializer is created, registered and run. On later
// requests the clone already registered for that key is returned.
FuncGraphPtr ProgramSpecializer::SpecializeFuncGraph(const FuncGraphPtr &fg, const AnalysisContextPtr &context) {
  MS_EXCEPTION_IF_NULL(fg);
  MS_EXCEPTION_IF_NULL(context);
  auto iter = specializations_.find(context->SpecializeKey());
  if (iter == specializations_.end()) {
    auto spec = std::make_shared<FuncGraphSpecializer>(this, fg, context);
    FuncGraphPtr specialized = spec->specialized_func_graph();
    // Register before running, so that a nested SpecializeFuncGraph call for the same key made
    // while this specializer runs returns the in-progress clone.
    specializations_[context->SpecializeKey()] = spec;
    spec->Run();
    return specialized;
  }
  MS_EXCEPTION_IF_NULL(iter->second);
  return iter->second->specialized_func_graph();
}
101 
GetFuncGraphSpecializer(const AnalysisContextPtr & context)102 std::shared_ptr<FuncGraphSpecializer> ProgramSpecializer::GetFuncGraphSpecializer(const AnalysisContextPtr &context) {
103   MS_EXCEPTION_IF_NULL(context);
104   auto iter = specializations_.find(context->SpecializeKey());
105   if (iter != specializations_.end()) {
106     return iter->second;
107   }
108   if (context->func_graph() != nullptr) {
109     MS_LOG(EXCEPTION) << "Specialize inner error";
110   }
111   return nullptr;
112 }
113 
// Returns the next value of a process-wide counter as a decimal string, starting at "1".
// FuncGraphSpecializer uses it to label the TraceSpecialize of each specialization clone.
// NOTE(review): the static counter is not synchronized; concurrent callers would race.
std::string GetNextCounter() {
  static int64_t next_clone_id = 1;
  const int64_t current = next_clone_id++;
  return std::to_string(current);
}
120 
FuncGraphSpecializer(ProgramSpecializer * const s,const FuncGraphPtr & fg,const AnalysisContextPtr & context)121 FuncGraphSpecializer::FuncGraphSpecializer(ProgramSpecializer *const s, const FuncGraphPtr &fg,
122                                            const AnalysisContextPtr &context)
123     : specializer_(s), func_graph_(fg), context_(context) {
124   parent_ = s->GetFuncGraphSpecializer(context->parent());
125   engine_ = s->engine();
126   cloner_ = SpecializerClone(fg, std::make_shared<TraceSpecialize>(GetNextCounter()));
127   repl_node_ = cloner_->cloned_node();
128   specialized_func_graph_ = cloner_->cloned_func_graph()[fg];
129   todo_.push_back(fg->get_return());
130   auto ps = fg->parameters();
131   (void)todo_.insert(todo_.end(), ps.begin(), ps.end());
132 }
133 
// Returns the replica of `node` in the specialized (cloned) graph of the specializer that owns
// `node`'s func graph. If no replica exists yet, one is created with CloneDisconnected.
// NOTE(review): presumably this is for nodes the regular clone did not reach, such as nodes
// created during analysis (see the comment in BuildReplacedNode). Confirm.
//  - ValueNodes are returned unchanged.
//  - An existing replica recorded in the owning specializer's `repl_node_` is reused.
//  - For a newly cloned CNode, the inputs are replicated recursively by UpdateNewCNodeInputs.
// Raises if the clone of a CNode is not a CNode. Also raises if, after cloning, `repl_node_`
// has no entry for `node` or maps `node` to itself.
AnfNodePtr FuncGraphSpecializer::ReplicateDisconnectedNode(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (node->isa<ValueNode>()) {
    return node;
  }
  std::shared_ptr<FuncGraphSpecializer> specializer = GetTopSpecializer(node);

  // If had replicated, just return that.
  MS_EXCEPTION_IF_NULL(specializer->repl_node_);
  auto iter = specializer->repl_node_->find(node);
  if (iter != specializer->repl_node_->end()) {
    return iter->second;
  }
  // Clone through the owning specializer's cloner so the replica is recorded in its `repl_node_`.
  auto new_node = specializer->cloner_->CloneDisconnected(node);
  if (node->isa<CNode>()) {
    if (!new_node->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "new_node must be a CNode, but is " << new_node->DebugString() << ".";
    }
    // Recursively replicate the inputs and hook them into the cloned CNode.
    UpdateNewCNodeInputs(node, new_node);
  }

  // Sanity check: `repl_node_` must now hold an entry for `node` that differs from `node` itself.
  iter = specializer->repl_node_->find(node);
  if (iter != specializer->repl_node_->end()) {
    if (iter->second == node) {
      MS_LOG(EXCEPTION) << "Replicated is same as original node, node: " << node->ToString();
    }
  } else {
    MS_LOG(EXCEPTION) << "Replicate node failed, node: " << node->ToString();
  }
  return new_node;
}
165 
UpdateNewCNodeInputs(const AnfNodePtr & node,const AnfNodePtr & new_node)166 void FuncGraphSpecializer::UpdateNewCNodeInputs(const AnfNodePtr &node, const AnfNodePtr &new_node) {
167   MS_EXCEPTION_IF_NULL(node);
168   auto c_node = node->cast<CNodePtr>();
169   MS_EXCEPTION_IF_NULL(c_node);
170   auto inputs = c_node->inputs();
171   std::vector<AnfNodePtr> new_inputs;
172   (void)std::transform(
173     inputs.begin(), inputs.end(), std::back_inserter(new_inputs), [this](const AnfNodePtr &inp) -> AnfNodePtr {
174       auto new_inp = ReplicateDisconnectedNode(inp);
175       // Refer the comments in BuildReplacedNode.
176       if (inp->isa<CNode>()) {
177         auto c_inp = inp->cast<CNodePtr>();
178         MS_EXCEPTION_IF_NULL(c_inp);
179         auto c_new_inp = new_inp->cast<CNodePtr>();
180         MS_EXCEPTION_IF_NULL(c_new_inp);
181         MS_EXCEPTION_IF_NULL(c_new_inp->func_graph());
182         MS_LOG(DEBUG) << "Replace in order, inp node: " << inp->DebugString() << " -> " << new_inp->DebugString();
183         c_new_inp->func_graph()->ReplaceInOrder(c_inp, c_new_inp);
184       }
185       return new_inp;
186     });
187 
188   auto c_new_node = new_node->cast<CNodePtr>();
189   MS_EXCEPTION_IF_NULL(c_new_node);
190   c_new_node->set_inputs(new_inputs);
191 }
192 
// Returns the replica of `node` recorded by the specializer that owns `node`'s graph,
// or `node` itself when no replica has been recorded.
AnfNodePtr FuncGraphSpecializer::GetReplicatedNode(const AnfNodePtr &node) {
  std::shared_ptr<FuncGraphSpecializer> specializer = GetTopSpecializer(node);
  MS_EXCEPTION_IF_NULL(specializer->repl_node_);
  const auto &replicas = *specializer->repl_node_;
  const auto hit = replicas.find(node);
  return hit == replicas.end() ? node : hit->second;
}
202 
203 // Return itself if node's ValueNode as top,
204 // return the top func graph specializer as top if node's forward Parameter,
205 // or, return the top parent specializer as top.
GetTopSpecializer(const AnfNodePtr & node)206 std::shared_ptr<FuncGraphSpecializer> FuncGraphSpecializer::GetTopSpecializer(const AnfNodePtr &node) {
207   MS_EXCEPTION_IF_NULL(node);
208   FuncGraphPtr fg = node->func_graph();
209   if (fg == nullptr) {  // If ValueNode, return current specializer.
210     MS_LOG(DEBUG) << "Node's a ValueNode, node: " << node->DebugString();
211     return shared_from_this();
212   }
213   std::shared_ptr<FuncGraphSpecializer> specializer = shared_from_this();
214   while (fg != specializer->func_graph_) {
215     if (specializer->parent_ == nullptr && node->isa<Parameter>()) {
216       // If `parent_` is null and forwarded `node` is a Parameter, we'll try to use top func graph as parent.
217       MS_EXCEPTION_IF_NULL(specializer_->top_context());
218       if (specializer_->top_context()->func_graph() == fg) {  // `fg` is top func graph.
219         specializer = specializer_->GetFuncGraphSpecializer(specializer_->top_context());
220         MS_LOG(INFO) << "Used top func graph specializer as parent for "
221                      << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << ", node: " << node->DebugString()
222                      << ", NodeInfo: " << trace::GetDebugInfo(node->debug_info());
223         MS_EXCEPTION_IF_NULL(specializer);
224         break;
225       }
226     } else {
227       specializer = specializer->parent_;
228     }
229     if (specializer == nullptr) {
230       MS_LOG(EXCEPTION) << "`specializer` should not be null, node: " << node->DebugString()
231                         << ", NodeInfo: " << trace::GetDebugInfo(node->debug_info()) << ".\n"
232                         << (func_graph_ ? func_graph_->ToString() : "FG(Null)")
233                         << " has no parent context? At least not " << fg->ToString();
234     }
235   }
236   return specializer;
237 }
238 
Run()239 void FuncGraphSpecializer::Run() {
240   MS_LOG(DEBUG) << "Before run, origin func graph name: " << (func_graph_ ? func_graph_->ToString() : "FG(Null)")
241                 << ", cloned func graph name: "
242                 << (specialized_func_graph_ ? specialized_func_graph_->ToString() : "FG(Null)") << ", func graph: "
243                 << (func_graph_ ? func_graph_->get_return() ? func_graph_->get_return()->DebugString() : "return null"
244                                 : "FG(null)");
245   FirstPass();
246   SecondPass();
247   MS_LOG(DEBUG) << "After run, origin func graph name: " << (func_graph_ ? func_graph_->ToString() : "FG(Null)")
248                 << ", cloned func graph name: "
249                 << (specialized_func_graph_ ? specialized_func_graph_->ToString() : "FG(Null)") << ", new func graph: "
250                 << (specialized_func_graph_ ? specialized_func_graph_->get_return()
251                                                 ? specialized_func_graph_->get_return()->DebugString()
252                                                 : "return null"
253                                             : "FG(null)");
254 }
255 
// First pass: drains the `todo_` work list, which the constructor seeds with the return node and
// parameters. `todo_` is used as a stack (back / pop_back).
//  - Nodes without a func graph (ValueNodes) are skipped.
//  - A node belonging to a different graph is delegated to `parent_`. If `parent_` is null and
//    the node is a Parameter of the top func graph, it goes to the top func graph's specializer.
//    That specializer runs its own FirstPass and then specializes the node's replica if it is a
//    CNode.
//  - Every other node is processed once (tracked in `marked_`) via ProcessNode, which may push
//    more nodes onto `todo_`.
// Raises if a node from a different graph has no specializer to delegate to.
void FuncGraphSpecializer::FirstPass() {
  while (todo_.size()) {
    AnfNodePtr node = todo_.back();
    todo_.pop_back();
    if (node->func_graph() == nullptr) {
      // do nothing for ValueNode
      continue;
    }
    if (node->func_graph() != func_graph_) {
      // The node is not owned by this graph: hand it to the specializer that owns it.
      std::shared_ptr<FuncGraphSpecializer> parent = nullptr;
      if (parent_ != nullptr) {
        parent = parent_;
      } else if (specializer_->top_context()->func_graph() == node->func_graph() && node->isa<Parameter>()) {
        // If `parent_` is null and forwarded `node` is a Parameter, we'll try to use top func graph as parent.
        parent = specializer_->GetFuncGraphSpecializer(specializer_->top_context());
        MS_LOG(INFO) << "Used top func graph specializer as parent for " << func_graph_->ToString()
                     << ", node: " << node->DebugString() << ", NodeInfo: " << trace::GetDebugInfo(node->debug_info());
      }
      if (parent == nullptr) {
        MS_LOG(EXCEPTION) << "Parent must not null, node: " << node->DebugString()
                          << ", NodeInfo: " << trace::GetDebugInfo(node->debug_info());
      }
      parent->AddTodoItem(node);
      parent->FirstPass();
      AnfNodePtr new_node = parent->GetReplicatedNode(node);
      if (node->isa<CNode>()) {
        parent->ProcessCNode(new_node->cast<CNodePtr>());
      }
      continue;
    }
    // Process each node of this graph only once.
    if (marked_.count(node) > 0) {
      continue;
    }
    (void)marked_.insert(node);
    ProcessNode(node);
  }
}
293 
294 // Specialize CNode in func graphs
SecondPass()295 void FuncGraphSpecializer::SecondPass() {
296   for (auto &node : BroadFirstSearchGraphCNodes({specialized_func_graph_->get_return()})) {
297     if (node->isa<CNode>()) {
298       ProcessCNode(node->cast<CNodePtr>());
299     }
300   }
301 }
302 
ProcessNode(const AnfNodePtr & node)303 void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
304   MS_EXCEPTION_IF_NULL(node);
305   ScopeGuard scope_guard(node->scope());
306   AnfNodeConfigPtr conf = MakeConfig(node);
307   AnfNodePtr new_node = GetReplicatedNode(node);
308   MS_EXCEPTION_IF_NULL(new_node);
309   if (new_node->func_graph() != specialized_func_graph_) {
310     MS_LOG(EXCEPTION) << "Error in specializer [A] node: " << node->DebugString()
311                       << ", new_node: " << new_node->DebugString() << ", new_node->func_graph(): "
312                       << (new_node->func_graph() ? new_node->func_graph()->ToString() : "FG(Null)")
313                       << ", specialized_func_graph_: " << specialized_func_graph_->ToString();
314     return;
315   }
316   new_node->set_abstract(GetEvaluatedValue(conf));
317   if (new_node->isa<CNode>() && new_node->abstract()->isa<PartialAbstractClosure>()) {
318     auto partial_abstract = dyn_cast<PartialAbstractClosure>(new_node->abstract());
319     if (partial_abstract->node() == node) {
320       partial_abstract->set_node(new_node);
321     }
322   }
323 
324   MS_LOG(DEBUG) << "Set new_node: " << new_node->ToString() << ", abstract as: " << new_node->abstract()->ToString();
325 
326   if (node->isa<CNode>()) {
327     auto attrs = conf->ObtainEvalResult()->attribute();
328     auto c_old = node->cast<CNodePtr>();
329     auto c_new = new_node->cast<CNodePtr>();
330     MS_EXCEPTION_IF_NULL(c_new);
331     auto new_inputs = c_new->inputs();
332     auto old_inputs = c_old->inputs();
333     for (size_t i = 0; i < old_inputs.size(); ++i) {
334       auto node_input = old_inputs[i];
335       AnfNodeConfigPtr iconf = MakeConfig(node_input);
336       AbstractBasePtr ival = GetEvaluatedValue(iconf);
337       // First try to check if node_input can be replaced by a ValueNode. If cannot, then try to check if
338       // can be replaced by another CNode from anfnode_config_map, otherwise use the replicated node.
339       AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival, attrs);
340       if (replace_node == nullptr) {
341         replace_node = BuildReplacedNode(iconf);
342         replace_node->set_abstract(ival);
343         MS_LOG(DEBUG) << "Set replaced: " << replace_node->ToString() << ", to abstract: " << ival->ToString();
344       } else {
345         MS_LOG(DEBUG) << "Build possible value node for node: " << node_input->DebugString()
346                       << ", ival: " << ival->ToString() << ", replace_node: " << replace_node->ToString();
347       }
348       if (new_inputs[i] != replace_node) {
349         new_inputs[i] = replace_node;
350         MS_LOG(DEBUG) << "Set new_input[" << i << "] = " << replace_node->DebugString();
351       }
352     }
353     c_new->set_inputs(new_inputs);
354   }
355 }
356 
// Returns the replicated node that should replace the input described by `conf`.
// The engine's `anfnode_config_map()` may map a config to another ("forwarded") config. This
// follows that chain to its last config and replicates each forwarded node along the way. When
// a replicated forwarded node is a CNode, it replaces the replicated origin node in the
// specialized graph's order list.
// The final node is pushed onto `todo_` for FirstPass. Its replica (or the node itself if none
// exists) is returned.
// NOTE(review): assumes the forwarding chain is acyclic; the loop only stops when a lookup misses.
AnfNodePtr FuncGraphSpecializer::BuildReplacedNode(const AnfNodeConfigPtr &conf) {
  MS_EXCEPTION_IF_NULL(conf);

  auto conf_iter = engine_->anfnode_config_map().find(conf);
  AnfNodeConfigPtr new_conf = conf;
  while (conf_iter != engine_->anfnode_config_map().end()) {
    MS_LOG(DEBUG) << "Origin conf: node(" << (new_conf->node() ? new_conf->node()->DebugString() : "Node(Null)") << ")";
    new_conf = conf_iter->second;
    MS_EXCEPTION_IF_NULL(new_conf);
    const auto &forward_node = new_conf->node();
    MS_LOG(DEBUG) << "Replaced conf: node(" << forward_node->DebugString() << ")";
    const auto &replicated_forward_node = ReplicateDisconnectedNode(forward_node);
    if (replicated_forward_node && replicated_forward_node->isa<CNode>()) {
      // The AnfNode in order_list can be:
      // case 1: also in FuncGraphManager, so it can be obtained from the nodes API of func_graph. It
      //         will be replaced by CloneOrderList in Cloner.
      // case 2: AnfNode is not in FuncGraphManager because it was generated in the Analyze phase, so
      //         it will not be cloned by the normal clone API.
      //    2.1: A forward node, the original node is in FuncGraphManager. The original node will
      //         be cloned by CloneOrderList in Cloner, and the replicated forward node will replace
      //         the replicated original node here.
      //    2.2: an input of a forward node, such as Cast CNode generated in DoCast. It is also another
      //         original node to forward.
      //    2.3: an input of an input of a forward node, but it's not an original node. Like the Cast CNode
      //         in MixedPrecisionCastHelper.
      // For 2.2 and 2.3, we will put a placeholder in order list of replicated func_graph, refer to
      // CloneOrderlist, and it will be replaced inside ReplicateDisconnectedNode.
      // For 2.1 the following code will do the job, replace replicated origin cnode with the replicated
      // forward one in the replicated func_graph.
      MS_EXCEPTION_IF_NULL(conf_iter->first);
      const auto &origin_node = conf_iter->first->node();
      const auto &replicated_origin_node = GetReplicatedNode(origin_node);
      if (replicated_origin_node != origin_node) {
        MS_LOG(DEBUG) << "Replace replicated origin node in order list: " << replicated_origin_node->DebugString()
                      << ", with replicated forwarded node: " << replicated_forward_node->DebugString();
        MS_EXCEPTION_IF_NULL(replicated_forward_node->func_graph());
        replicated_forward_node->func_graph()->ReplaceInOrder(replicated_origin_node, replicated_forward_node);
      } else {
        MS_LOG(EXCEPTION) << "Origin node is not replicated in specialized func_graph, origin node: "
                          << (origin_node ? origin_node->DebugString() : "Node(Null)");
      }
    }
    // Continue with the next hop of the forwarding chain, if any.
    conf_iter = engine_->anfnode_config_map().find(new_conf);
  }
  // Queue the final node so it is processed too, then return its replica.
  todo_.push_back(new_conf->node());
  auto repl = GetReplicatedNode(new_conf->node());
  if (repl->func_graph()) {
    MS_LOG(DEBUG) << "Set repl: graph(" << repl->func_graph()->ToString() << "), node:" << repl->DebugString()
                  << ") to replace origin:" << new_conf->node()->DebugString();
  } else {
    MS_LOG(DEBUG) << "Set repl: graph(nullptr), node(" << repl->DebugString()
                  << ") to replace origin: " << new_conf->node()->DebugString();
  }
  return repl;
}
412 
413 namespace {
414 const StringImmPtr kDeadNode = std::make_shared<StringImm>(kDeadNodeName);
415 const StringImmPtr kPolyNode = std::make_shared<StringImm>(kPolyNodeName);
416 
CanSpecializeNode(const AnfNodePtr & node)417 inline bool CanSpecializeNode(const AnfNodePtr &node) {
418   if (IsValueNode<FuncGraph>(node) || IsValueNode<MetaFuncGraph>(node) || IsValueNode<Primitive>(node)) {
419     return true;
420   }
421   return false;
422 }
423 }  // namespace
424 
// Builds the specialized replacement for the function node `node`, whose abstract is `abs`,
// when it is called with `argvals`.
// If no unique specialization exists, a kDeadNode or kPolyNode placeholder ValueNode is returned
// (with an AbstractError abstract). Any other failure raises. A specialized MetaFuncGraph whose
// last argument is a UMonad is flagged for re-auto-monad.
AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const AnfNodePtr &node, const AbstractBasePtr &abs,
                                                      const AbstractBasePtrList &argvals) {
  MS_EXCEPTION_IF_NULL(abs);
  MS_EXCEPTION_IF_NULL(node);
  AbstractFunctionPtr real_a = dyn_cast<AbstractFunction>(abs);
  MS_EXCEPTION_IF_NULL(real_a);

  AbstractFunctionPtr func = real_a->GetUnique();
  ScopeGuard scope_guard(node->scope());
  SpecializeStatusCode errcode = kSpecializeSuccess;
  AnfNodePtr repl = BuildSpecializedNodeInner(node, abs, func, argvals, &errcode);
  if (repl == nullptr) {
    switch (errcode) {
      case kSpecializeFindUniqueArgvalDead: {
        const auto dead_abstract = std::make_shared<AbstractError>(kDeadNode, node);
        repl = BuildValueNode(kDeadNode, dead_abstract);
        MS_LOG(DEBUG) << "DEAD for node: " << node->DebugString() << ", abstract: " << abs->ToString();
        break;
      }
      case kSpecializeFindUniqueArgvalPoly: {
        const auto poly_abstract = std::make_shared<AbstractError>(kPolyNode, node);
        repl = BuildValueNode(kPolyNode, poly_abstract);
        MS_LOG(DEBUG) << "POLY for node: " << node->DebugString() << ", abstract: " << abs->ToString();
        break;
      }
      default:
        MS_LOG(EXCEPTION) << "Failed to build specialized node, node: " << node->DebugString()
                          << ", abstract: " << abs->ToString();
    }
  }

  // Set the flag, so this MetaFuncGraph will be Re-AutoMonaded.
  MS_EXCEPTION_IF_NULL(func);
  if (!func->isa<MetaFuncGraphAbstractClosure>()) {
    return repl;
  }
  auto specialized_fg = GetValueNode<FuncGraphPtr>(repl);
  const bool ends_with_umonad =
    argvals.size() > 1 && argvals.back() != nullptr && argvals.back()->isa<AbstractUMonad>();
  if (specialized_fg != nullptr && ends_with_umonad) {
    specialized_fg->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
  }
  return repl;
}
462 
BuildSpecializedNodeInner(const AnfNodePtr & node,const AbstractBasePtr & abs,const AbstractFunctionPtr & func,const AbstractBasePtrList & args,SpecializeStatusCode * errcode)463 AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AnfNodePtr &node, const AbstractBasePtr &abs,
464                                                            const AbstractFunctionPtr &func,
465                                                            const AbstractBasePtrList &args,
466                                                            SpecializeStatusCode *errcode) {
467   MS_EXCEPTION_IF_NULL(abs);
468   MS_EXCEPTION_IF_NULL(func);
469   MS_EXCEPTION_IF_NULL(errcode);
470   *errcode = kSpecializeSuccess;
471 
472   auto real_func = dyn_cast<TypedPrimitiveAbstractClosure>(func);
473   if (real_func != nullptr) {
474     return BuildValueNode(real_func->prim(), abs);
475   }
476 
477   EvaluatorPtr eval = engine_->GetEvaluatorFor(func);
478   MS_EXCEPTION_IF_NULL(eval);
479   AbstractBasePtrList argvals = eval->NormalizeArgs(args);
480 
481   std::pair<AbstractBasePtrList, AbstractBasePtr> result;
482   SpecializeStatusCode status = FindUniqueArgvals(func, eval, argvals, &result);
483   if (status != kSpecializeSuccess) {
484     *errcode = status;
485     return nullptr;
486   }
487   argvals = result.first;
488   AbstractBasePtr unique_output = result.second;
489 
490   auto prim_func = dyn_cast<PrimitiveAbstractClosure>(func);
491   if (prim_func != nullptr) {
492     auto type_func = std::make_shared<TypedPrimitiveAbstractClosure>(prim_func->prim(), argvals, unique_output);
493     return BuildValueNode(prim_func->prim(), type_func);
494   }
495 
496   if (!eval->isa<BaseFuncGraphEvaluator>()) {
497     MS_LOG(EXCEPTION) << "Eval is not BaseGraphEvaluator, but " << eval->ToString();
498   }
499   auto real_eval = dyn_cast<BaseFuncGraphEvaluator>(eval);
500 
501   if (func->context() == nullptr) {
502     MS_LOG(EXCEPTION) << "Func context is nullptr NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info());
503   }
504   AnalysisContextPtr context = MakeContext(engine_, real_eval, argvals);
505   MS_LOG(DEBUG) << "Specialize function graph: " << context->func_graph()->ToString() << ", args: " << argvals.size()
506                 << ", graph: " << context->func_graph()->get_return()->DebugString();
507   MS_EXCEPTION_IF_NULL(context->func_graph());
508   if (context->func_graph()->stub()) {
509     MS_LOG(DEBUG) << "Specialize stub function graph, return the original node: " << context->func_graph()->ToString()
510                   << ", args: " << argvals.size() << ", graph: " << context->func_graph()->get_return()->DebugString()
511                   << ", " << node->ToString();
512     return node;
513   }
514   FuncGraphPtr v = specializer_->SpecializeFuncGraph(context->func_graph(), context);
515   MS_EXCEPTION_IF_NULL(v);
516   v->set_flag(kFuncGraphFlagUndetermined, false);
517   return BuildValueNode(v, abs);
518 }
519 
MakeContext(const AnalysisEnginePtr & engine,const BaseFuncGraphEvaluatorPtr & evaluator,const AbstractBasePtrList & args_spec_list)520 inline AnalysisContextPtr FuncGraphSpecializer::MakeContext(const AnalysisEnginePtr &engine,
521                                                             const BaseFuncGraphEvaluatorPtr &evaluator,
522                                                             const AbstractBasePtrList &args_spec_list) {
523   AbstractBasePtrList normalized_args_spec_list = evaluator->NormalizeArgs(args_spec_list);
524   FuncGraphPtr fg = evaluator->GetFuncGraph(engine, normalized_args_spec_list);
525   MS_EXCEPTION_IF_NULL(evaluator->parent_context());
526   AnalysisContextPtr new_context = evaluator->parent_context()->NewContext(fg, normalized_args_spec_list);
527   return new_context;
528 }
529 
BuildSpecializedParameterNode(const CNodePtr & new_node)530 AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &new_node) {
531   MS_EXCEPTION_IF_NULL(new_node);
532   auto new_inputs = new_node->inputs();
533   if (new_inputs.empty()) {
534     MS_LOG(EXCEPTION) << "inputs can't be empty.";
535   }
536   AnfNodePtr func = new_inputs[0];
537   MS_EXCEPTION_IF_NULL(new_inputs[0]);
538   AbstractBasePtr fnval = new_inputs[0]->abstract();
539 
540   AbstractBasePtrList args;
541   auto backed_fnval = fnval;
542   if (fnval->isa<PartialAbstractClosure>()) {
543     auto partial_closure = dyn_cast<PartialAbstractClosure>(fnval);
544     backed_fnval = partial_closure->fn();
545     args = partial_closure->args();
546   }
547   std::transform(new_inputs.cbegin() + 1, new_inputs.cend(), std::back_inserter(args),
548                  [](const AnfNodePtr &inp) { return inp->abstract(); });
549 
550   ScopeGuard scope_guard(new_node->scope());
551 
552   auto specialized_node = BuildSpecializedNode(func, backed_fnval, args);
553   auto wrapped_node = specialized_node;
554   if (fnval->isa<PartialAbstractClosure>()) {
555     auto partial_closure = dyn_cast<PartialAbstractClosure>(fnval);
556     AnfNodePtrList partial_node_list = {BuildValueNode(prim::kPrimPartial, FromValueInside(prim::kPrimPartial)),
557                                         specialized_node};
558     auto anf_node = partial_closure->node();
559     if (!anf_node->isa<CNode>()) {
560       MS_LOG(EXCEPTION) << "Must be cnode, but " << anf_node->DebugString();
561     }
562     auto cnode = anf_node->cast<CNodePtr>();
563     if (cnode->size() != partial_closure->args().size() + 2) {
564       MS_LOG(EXCEPTION) << "Size of cnode: " << cnode->DebugString()
565                         << " is not equal to 2 added to size of args: " << mindspore::ToString(partial_closure->args());
566     }
567     auto attrs = std::make_shared<AttrValueMap>();
568     for (size_t i = 0; i < partial_closure->args().size(); i++) {
569       auto old_node = cnode->input(i + 2);
570       auto possibile_value_node = BuildPossibleValueNode(old_node, partial_closure->args()[i], attrs);
571       if (possibile_value_node != nullptr) {
572         partial_node_list.push_back(possibile_value_node);
573       } else {
574         if (!(old_node->isa<CNode>() || old_node->isa<Parameter>())) {
575           MS_LOG(EXCEPTION) << "Old node should be CNode or Parameter, but " << old_node->ToString();
576         }
577         partial_node_list.push_back(old_node);
578       }
579     }
580     MS_EXCEPTION_IF_NULL(new_node->func_graph());
581     wrapped_node = new_node->func_graph()->NewCNode(partial_node_list);
582     wrapped_node->set_abstract(partial_closure);
583   }
584   return wrapped_node;
585 }
586 
GetEvalCache(const EvaluatorPtr & eval)587 const EvaluatorCacheMgrPtr FuncGraphSpecializer::GetEvalCache(const EvaluatorPtr &eval) {
588   MS_EXCEPTION_IF_NULL(eval);
589   auto cache_iter = evalcaches_.find(eval);
590   if (cache_iter == evalcaches_.end()) {
591     evalcaches_[eval] = eval->evaluator_cache_mgr();
592     return eval->evaluator_cache_mgr();
593   }
594   return cache_iter->second;
595 }
596 
// Tries to find one argument list that covers every cached evaluation of `eval`.
// Every argument list in `eval`'s cache (from `evalcaches_`) is broadened. If all of them broaden
// to the same list, it raises when fewer than two cached entries exist. Otherwise:
//  - if the join (AbstractJoin) of all cached argument lists is itself cached, that entry is used;
//  - else, if every broadened argument is a tensor or a sequence of tensors (CheckAbstractTensor),
//    `eval` is re-run once with the broadened arguments.
// On success, `evalcaches_[eval]` is replaced by a new cache manager that holds only that entry,
// and the (argvals, output abstract) pair is returned. Otherwise ({}, nullptr) is returned.
// Raises if `eval` has no entry in `evalcaches_`.
std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromBroadedArgsVal(
  const EvaluatorPtr &eval) {
  MS_EXCEPTION_IF_NULL(eval);
  std::unordered_set<AbstractBasePtrList, AbstractBasePtrListHasher, AbstractBasePtrListEqual> choices;
  EvalResultPtr ret = nullptr;
  AbstractBasePtrList broaded_argvals;
  std::vector<AbstractBasePtrList> args_vector;
  auto eval_cache_iter = evalcaches_.find(eval);
  if (eval_cache_iter == evalcaches_.end()) {
    MS_LOG(EXCEPTION) << "Evaluator:" << eval->ToString() << " not exist in cache.";
  }
  auto &origin_eval_cache = eval_cache_iter->second->GetCache();
  // Collect every cached argument list and the set of distinct broadened forms.
  for (auto &argvals_map : origin_eval_cache) {
    auto argvals = argvals_map.first;
    args_vector.push_back(argvals);
    broaded_argvals.clear();
    BroadenArgs(argvals, &broaded_argvals);
    (void)choices.insert(broaded_argvals);
    MS_LOG(DEBUG) << "Broaded_argvals: " << broaded_argvals.size() << ", " << ::mindspore::ToString(broaded_argvals);
  }
  // All cached calls broaden to the same list. `broaded_argvals` still holds the last iteration's
  // value, which is therefore that common broadened list.
  if (choices.size() == 1) {
    constexpr auto args_size = 2;
    if (args_vector.size() < args_size) {
      MS_LOG(EXCEPTION) << "Should have " << args_size << " or more choices, but: " << args_vector.size();
    }
    AbstractBasePtrList joined_argvals = args_vector[0];
    for (size_t i = 1; i < args_vector.size(); ++i) {
      joined_argvals = abstract::AbstractJoin(joined_argvals, args_vector[i]);
    }
    MS_LOG(DEBUG) << "Joined argvals: " << joined_argvals.size() << ", " << ::mindspore::ToString(joined_argvals);
    EvaluatorCacheMgrPtr real = std::make_shared<EvaluatorCacheMgr>();
    const auto joined_eval_result = origin_eval_cache.get(joined_argvals);
    if (joined_eval_result != nullptr) {
      MS_LOG(DEBUG) << "Find unique Choices in original eval cache, so use it: " << joined_eval_result->ToString();

      real->SetValue(joined_argvals, joined_eval_result);
      evalcaches_[eval] = real;
      return std::make_pair(joined_argvals, joined_eval_result->abstract());
    } else {
      // The joined list was never evaluated; fall back to re-running with the broadened arguments,
      // but only when they are all tensors (or sequences of tensors).
      bool all_args_tensor = std::all_of(broaded_argvals.cbegin(), broaded_argvals.cend(),
                                         [](const AbstractBasePtr &v) { return CheckAbstractTensor(v); });
      if (all_args_tensor) {
        ConfigPtrList args_conf_list;
        (void)std::transform(broaded_argvals.cbegin(), broaded_argvals.cend(), std ::back_inserter(args_conf_list),
                             [](const AbstractBasePtr &v) -> ConfigPtr { return std::make_shared<VirtualConfig>(v); });
        MS_LOG(WARNING) << "Cannot find joined argvals in cache, run with broaded argsvals: " << broaded_argvals.size()
                        << ", " << ::mindspore::ToString(broaded_argvals);
        ret = eval->SingleRun(engine_, args_conf_list, nullptr);
        MS_EXCEPTION_IF_NULL(ret);
        real->SetValue(broaded_argvals, ret);
        evalcaches_[eval] = real;
        return std::make_pair(broaded_argvals, ret->abstract());
      }
    }
  }
  MS_LOG(DEBUG) << "Choices.size: " << choices.size();
  return std::make_pair(AbstractBasePtrList(), nullptr);
}
655 
ProcessCNode(const CNodePtr & new_node)656 void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
657   MS_EXCEPTION_IF_NULL(new_node);
658   if (specializer_->seen().count(new_node) > 0) {
659     return;
660   }
661   specializer_->AddSeen(new_node);
662   auto new_inputs = new_node->inputs();
663   if (new_inputs.empty()) {
664     MS_LOG(EXCEPTION) << "Inputs of CNode is empty";
665   }
666   AnfNodePtr func = new_inputs[0];
667   MS_EXCEPTION_IF_NULL(func);
668 
669   // First element is func so arg start from 1
670   std::vector<AnfNodePtr> args(new_inputs.begin() + 1, new_inputs.end());
671   // CNode(CNode(Partial, f, arg1), arg2, ...) --> CNode(f, arg1, arg2, ...)
672   const size_t arg_start_index = 2;
673   while (IsPrimitiveCNode(func, prim::kPrimPartial)) {
674     std::vector<AnfNodePtr> inputs = func->cast<CNodePtr>()->inputs();
675     // First element is partial, second is func so arg is start from 2
676     (void)args.insert(args.begin(), inputs.begin() + SizeToInt(arg_start_index), inputs.end());
677     func = inputs[1];
678   }
679   new_inputs = args;
680   (void)new_inputs.insert(new_inputs.begin(), func);
681 
682   AbstractBasePtrList argvals;
683   MS_EXCEPTION_IF_NULL(new_inputs[0]);
684   AbstractBasePtr fnval = new_inputs[0]->abstract();
685   MS_LOG(DEBUG) << "The new_inputs[0] node: pointer: " << new_inputs[0]->ToString() << ", "
686                 << new_inputs[0]->DebugString() << ", abstract: " << new_inputs[0]->abstract()->ToString();
687 
688   // First element is func so function arguments start from 1
689   for (size_t i = 1; i < new_inputs.size(); ++i) {
690     argvals.push_back(new_inputs[i]->abstract());
691     MS_LOG(DEBUG) << "The new_inputs[" << i << "] node: pointer: " << new_inputs[i]->ToString() << ", "
692                   << new_inputs[i]->DebugString() << ", abstract: " << new_inputs[i]->abstract()->ToString();
693   }
694 
695   if (!func->isa<ValueNode>()) {
696     MS_LOG(DEBUG) << func->abstract()->type_name() << " | " << func->abstract()->ToString();
697     if (func->abstract()->isa<AbstractFunction>() && !func->abstract()->isa<AbstractFuncUnion>()) {
698       auto func_abs = func->abstract()->cast<AbstractFunctionPtr>();
699       EvaluatorPtr eval = engine_->GetEvaluatorFor(func_abs);
700       std::pair<AbstractBasePtrList, AbstractBasePtr> result;
701       AbstractBasePtrList empty_args;
702       auto status = FindUniqueArgvals(func_abs, eval, empty_args, &result);
703       MS_LOG(DEBUG) << "FindUniqueArgvals return status: " << status;
704       // if a node is a poly node, or an input parameter is a PartialAbstractClosure, expand it early
705       MS_EXCEPTION_IF_NULL(func->func_graph());
706       if (status == kSpecializeFindUniqueArgvalPoly ||
707           (func->isa<Parameter>() && func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER))) {
708         auto wrapped_node = BuildSpecializedParameterNode(new_node);
709         new_inputs[0] = wrapped_node;
710       }
711     }
712   }
713 
714   if (CanSpecializeNode(func)) {
715     // for primitive node, we build the primitive node with inferred attributes in the first pass
716     // so we do not build replaced node again here in second pass
717     if (IsValueNode<Primitive>(func)) {
718       new_inputs[0] = func;
719     } else {
720       new_inputs[0] = BuildSpecializedNode(func, fnval, argvals);
721     }
722   }
723 
724   for (size_t i = 0; i < argvals.size();) {
725     size_t next = i + 1;
726     if (CanSpecializeNode(args[i])) {
727       new_inputs[next] = BuildSpecializedNode(args[i], argvals[i], std::vector<AbstractBasePtr>{});
728     }
729     i = next;
730   }
731   new_node->set_inputs(new_inputs);
732 }
733 
734 namespace {
DumpEvaluatorCache(const EvaluatorCacheMgrPtr & evaluator_cache_mgr,const AbstractBasePtrList & argvals)735 void DumpEvaluatorCache(const EvaluatorCacheMgrPtr &evaluator_cache_mgr, const AbstractBasePtrList &argvals) {
736   MS_EXCEPTION_IF_NULL(evaluator_cache_mgr);
737   MS_LOG(DEBUG) << "Find unique argvals failed: " << argvals.size() << ", " << argvals << ". Check cache all items.";
738   int64_t i = 0;
739   const EvalResultCache &map = evaluator_cache_mgr->GetCache();
740   for (const auto &item : map) {
741     MS_LOG(DEBUG) << "evaluator_cache[" << i++ << "]: " << item.first;
742   }
743 }
744 
IsPolyFunc(const AbstractFunctionPtr & func,const AbstractBasePtrList & argvals)745 bool IsPolyFunc(const AbstractFunctionPtr &func, const AbstractBasePtrList &argvals) {
746   MS_EXCEPTION_IF_NULL(func);
747   if (func->isa<PrimitiveAbstractClosure>() && argvals.empty()) {
748     MS_LOG(DEBUG) << "High order primitive return POLY.";
749     return true;
750   }
751   if (func->isa<MetaFuncGraphAbstractClosure>() && argvals.empty()) {
752     auto meta_func_graph_wrapper = dyn_cast<MetaFuncGraphAbstractClosure>(func);
753     auto meta_func_graph = meta_func_graph_wrapper->meta_func_graph();
754     if (meta_func_graph != nullptr && meta_func_graph->isa<prim::DoSignatureMetaFuncGraph>()) {
755       auto do_signature = dyn_cast<prim::DoSignatureMetaFuncGraph>(meta_func_graph);
756       if (do_signature != nullptr && do_signature->function()->isa<Primitive>()) {
757         MS_LOG(DEBUG) << "High order primitive " << do_signature->function()->ToString() << " return POLY.";
758         return true;
759       }
760     }
761   }
762   return false;
763 }
764 }  // end anonymous namespace
765 
FindUniqueArgvals(const AbstractFunctionPtr & func,const EvaluatorPtr & eval,const AbstractBasePtrList & argvals,std::pair<AbstractBasePtrList,AbstractBasePtr> * result)766 SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunctionPtr &func, const EvaluatorPtr &eval,
767                                                              const AbstractBasePtrList &argvals,
768                                                              std::pair<AbstractBasePtrList, AbstractBasePtr> *result) {
769   MS_EXCEPTION_IF_NULL(func);
770   MS_EXCEPTION_IF_NULL(eval);
771   MS_EXCEPTION_IF_NULL(result);
772 
773   EvaluatorCacheMgrPtr evaluator_cache_mgr = eval->evaluator_cache_mgr();
774   MS_EXCEPTION_IF_NULL(evaluator_cache_mgr);
775   auto data = evaluator_cache_mgr->GetValue(argvals);
776   if (data != nullptr) {
777     *result = std::make_pair(argvals, data->abstract());
778     return kSpecializeSuccess;
779   }
780   DumpEvaluatorCache(evaluator_cache_mgr, argvals);
781 
782   auto cache = GetEvalCache(eval);
783   MS_EXCEPTION_IF_NULL(cache);
784   const EvalResultCache &choices = cache->GetCache();
785   if (choices.get(argvals) != nullptr) {
786     MS_EXCEPTION_IF_NULL(cache->GetValue(argvals));
787     *result = std::make_pair(argvals, cache->GetValue(argvals)->abstract());
788     return kSpecializeSuccess;
789   } else if (choices.size() == 1) {
790     MS_LOG(DEBUG) << "Evaluator cache has a single item, just use it.";
791     MS_EXCEPTION_IF_NULL(choices.begin()->second);
792     *result = std::make_pair(choices.begin()->first, choices.begin()->second->abstract());
793     return kSpecializeSuccess;
794   } else if (choices.empty()) {
795     MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase " << func->ToString() << " | "
796                   << func->type_name();
797     return kSpecializeFindUniqueArgvalDead;
798   } else {
799     if (IsPolyFunc(func, argvals)) {
800       return kSpecializeFindUniqueArgvalPoly;
801     }
802 
803     MS_LOG(DEBUG) << "Try to find generalized argvals.";
804     *result = BuildFromBroadedArgsVal(eval);
805     if (!result->first.empty()) {
806       return kSpecializeSuccess;
807     }
808     MS_LOG(DEBUG) << "Find POLY code, it may be unused code or unresolved polymorphism.";
809     return kSpecializeFindUniqueArgvalPoly;
810   }
811 }
// Return `prim` unchanged when every attribute in `attrs` already exists on it with an equal value.
// Otherwise return a clone of `prim` to which every entry of `attrs` has been added via AddAttr().
// The original `prim` is never modified.
// Raises (via MS_EXCEPTION_IF_NULL) if `prim` or `attrs` is null, or if an attribute value being
// compared is null.
static PrimitivePtr BuildPrimtiveValueWithAttributes(const PrimitivePtr &prim, const AttrValueMapPtr &attrs) {
  MS_EXCEPTION_IF_NULL(prim);
  // `attrs` is dereferenced below; check it like `prim` instead of relying on every caller to do so.
  MS_EXCEPTION_IF_NULL(attrs);
  const auto &prim_attrs = prim->attrs();
  bool is_attr_same = true;
  for (const auto &item : *attrs) {
    auto itr = prim_attrs.find(item.first);
    if (itr == prim_attrs.end()) {
      is_attr_same = false;
      break;
    }
    MS_EXCEPTION_IF_NULL(itr->second);
    MS_EXCEPTION_IF_NULL(item.second);
    if (!(*(itr->second) == *(item.second))) {
      is_attr_same = false;
      break;
    }
  }
  if (is_attr_same) {
    return prim;
  }
  // Attributes differ: apply them to a clone so the original Primitive object is left untouched.
  auto cloned_prim = prim->Clone();
  for (const auto &item : *attrs) {
    cloned_prim->AddAttr(item.first, item.second);
  }
  return cloned_prim;
}
839 
// Try to turn the abstract value `ival` inferred for `origin_node` into a constant ValueNode.
//
// Non-function abstracts: the value from `ival->BuildValue()` is used, unless it is AnyValue or
// `origin_node` is a Depend CNode.
// Function abstracts: the wrapped Primitive, MetaFuncGraph or FuncGraph is used. A Primitive gets the
// inferred `attrs` applied when `attrs` is non-null. A union of functions, or any other closure kind,
// yields nullptr. A FuncGraph that has a parent is used only when `origin_node` is itself a FuncGraph
// ValueNode and the parent is visible from `func_graph_`.
// Returns nullptr whenever no deterministic ValueNode can be built.
AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival,
                                                        const AttrValueMapPtr &attrs) {
  MS_EXCEPTION_IF_NULL(origin_node);
  MS_EXCEPTION_IF_NULL(ival);

  AbstractFunctionPtr abs = dyn_cast<AbstractFunction>(ival);
  if (abs == nullptr) {
    ValuePtr val = ival->BuildValue();
    if (val->isa<AnyValue>()) {
      return nullptr;
    }
    // Depend nodes are deliberately kept in the graph rather than folded into a constant.
    if (IsPrimitiveCNode(origin_node, prim::kPrimDepend)) {
      return nullptr;
    }
    return BuildValueNode(val, ival);
  }

  // Several possible callees: no single deterministic ValueNode exists.
  if (abs->isa<AbstractFuncUnion>()) {
    return nullptr;
  }

  ValuePtr value = nullptr;
  if (abs->isa<PrimitiveAbstractClosure>()) {
    auto prim_closure = dyn_cast<PrimitiveAbstractClosure>(abs);
    // Re-apply the attributes inferred for the cnode; a clone is made only if they differ.
    value = (attrs == nullptr) ? prim_closure->prim() : BuildPrimtiveValueWithAttributes(prim_closure->prim(), attrs);
  } else if (abs->isa<MetaFuncGraphAbstractClosure>()) {
    value = dyn_cast<MetaFuncGraphAbstractClosure>(abs)->meta_func_graph();
  } else if (abs->isa<FuncGraphAbstractClosure>()) {
    value = dyn_cast<FuncGraphAbstractClosure>(abs)->func_graph();
  } else {
    return nullptr;
  }
  MS_EXCEPTION_IF_NULL(value);

  if (!value->isa<FuncGraph>()) {
    return BuildValueNode(value, ival);
  }
  const auto enclosing_graph = value->cast<FuncGraphPtr>()->parent();
  const bool is_top_level = (enclosing_graph == nullptr);
  if (is_top_level || (IsValueNode<FuncGraph>(origin_node) && IsVisible(func_graph_, enclosing_graph))) {
    return BuildValueNode(value, ival);
  }
  return nullptr;
}
888 
MakeConfig(const AnfNodePtr & node)889 inline AnfNodeConfigPtr FuncGraphSpecializer::MakeConfig(const AnfNodePtr &node) {
890   return engine_->MakeConfig(node, context_, func_graph_);  // `func_graph_` is dummy here.
891 }
892 }  // namespace abstract
893 }  // namespace mindspore
894