• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "backend/session/ascend_auto_monad.h"
18 #include <set>
19 #include <map>
20 #include <stack>
21 #include <vector>
22 #include <string>
23 #include <tuple>
24 #include <queue>
25 #include <utility>
26 #include <memory>
27 #include <algorithm>
28 #include "utils/ms_context.h"
29 #include "utils/ordered_map.h"
30 #include "base/core_ops.h"
31 #include "debug/anf_ir_dump.h"
32 #include "pipeline/jit/base.h"
33 #include "backend/session/anf_runtime_algorithm.h"
34 #include "runtime/device/ascend/kernel_select_ascend.h"
35 
36 namespace mindspore {
37 namespace session {
38 namespace {
39 // Pair of graph and its actual arguments.
40 using GraphArgPair = std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>;
41 
42 // We start label id from 0, and use 0xFFFFFFFF to indicate label not set.
43 constexpr uint32_t kNoLabel = 0xFFFFFFFF;
44 
45 // We start input index from 2 for AssignOp, as for inputs[2] is input, inputs[1] is output;
46 constexpr size_t kInputIndex = 2;
47 
48 // Primitive attribute for argument link assign.
49 const char LINK[] = "link";
50 
51 // Attribute to indicate that the node should not be eliminated.
52 // Used to keep argument Assign nodes for recursive graphs.
53 const char KEEP[] = "keep";
54 
55 // Attribute to indicate that this is an assign for output.
56 const char OUTPUT[] = "output";
57 
58 // Attribute to indicate that the node is last node in an iteration.
59 const char ITEREND[] = "PROFILING_ITER_END";
60 
61 #ifdef ENABLE_DUMP_IR
IsSaveGraph()62 bool IsSaveGraph() {
63   auto context_ptr = MsContext::GetInstance();
64   MS_EXCEPTION_IF_NULL(context_ptr);
65   return context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
66 }
67 
DumpAllGraphs(NotNull<KernelGraphPtr> kg,std::set<KernelGraphPtr> * memo)68 void DumpAllGraphs(NotNull<KernelGraphPtr> kg, std::set<KernelGraphPtr> *memo) {
69   if (memo->find(kg) != memo->end()) {
70     return;
71   }
72   memo->insert(kg);
73   std::string file_name = "ascend_auto_monad_" + std::to_string(kg->graph_id()) + ".ir";
74   DumpIR(file_name, kg.get());
75   for (auto &child : kg->child_graph_order()) {
76     auto cg = child.lock();
77     if (cg) {
78       DumpAllGraphs(NOT_NULL(cg), memo);
79     }
80   }
81 }
82 
DumpGraphForDebug(const NotNull<KernelGraphPtr> kg)83 void DumpGraphForDebug(const NotNull<KernelGraphPtr> kg) {
84   if (IsSaveGraph()) {
85     std::set<KernelGraphPtr> memo;
86     DumpAllGraphs(kg, &memo);
87   }
88 }
89 #endif
90 
91 #ifndef ENABLE_SECURITY
DumpExecuteOrder(const NotNull<KernelGraphPtr> kg)92 void DumpExecuteOrder(const NotNull<KernelGraphPtr> kg) {
93   if (!IsSaveGraph()) {
94     return;
95   }
96   std::string filename = "ascend_execute_order_" + std::to_string(kg->graph_id()) + ".dat";
97   auto filepath = GetSaveGraphsPathName(filename);
98   if (filepath.size() >= PATH_MAX) {
99     MS_LOG(ERROR) << "File path: " << filepath << " is too long.";
100     return;
101   }
102   char real_path[PATH_MAX] = {0};
103 #if defined(_WIN32) || defined(_WIN64)
104   if (_fullpath(filepath, filename.c_str(), PATH_MAX) == nullptr) {
105     MS_LOG(DEBUG) << "dir " << filename << " does not exit.";
106   }
107 #else
108   if (realpath(filepath.c_str(), real_path) == nullptr) {
109     MS_LOG(DEBUG) << "Dir " << filepath << " does not exit.";
110   }
111 #endif
112 
113   std::ofstream fout(real_path);
114   if (!fout.is_open()) {
115     MS_LOG(ERROR) << "Open file '" << real_path << "' failed!";
116     return;
117   }
118 
119   fout << "Execute order:\n";
120   int index = 0;
121   for (auto &cnode : kg->execution_order()) {
122     MS_EXCEPTION_IF_NULL(cnode);
123     if (IsPrimitiveCNode(cnode, prim::kPrimLabelSet)) {
124       fout << "L" << AnfAlgo::GetNodeAttr<uint32_t>(cnode, kAttrLabelIndex) << ":\n";
125     }
126     fout << "  [" << index << "], " << cnode->DebugString();
127     if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, cnode)) {
128       fout << " : L" << AnfAlgo::GetNodeAttr<uint32_t>(cnode, kAttrLabelIndex);
129     }
130     if (AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, cnode)) {
131       auto labels = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(cnode, kAttrLabelSwitchList);
132       fout << " : ";
133       for (size_t i = 0; i < labels.size(); ++i) {
134         fout << ((i > 0) ? ", L" : "L") << labels[i];
135       }
136     }
137     fout << '\n';
138     index++;
139   }
140   fout.close();
141 }
142 #endif
143 
144 // Return kNoLabel when label id attribute not set for the graph.
GetGraphLabel(const KernelGraphPtr & kg)145 uint32_t GetGraphLabel(const KernelGraphPtr &kg) {
146   auto value = kg->get_attr(kAttrLabelIndex);
147   if (value == nullptr) {
148     return kNoLabel;
149   }
150   return GetValue<uint32_t>(value);
151 }
152 
153 // Check if one abstract is compatible with another abstract.
IsCompatible(const abstract::AbstractBasePtr & a1,const abstract::AbstractBasePtr & a2)154 bool IsCompatible(const abstract::AbstractBasePtr &a1, const abstract::AbstractBasePtr &a2) {
155   if (a1 == nullptr || a2 == nullptr) {
156     return false;
157   }
158   if (a1 == a2) {
159     return true;
160   }
161   // Check AbstractTuple.
162   if (a1->isa<abstract::AbstractTuple>() && a2->isa<abstract::AbstractTuple>()) {
163     auto &a1_tuple = static_cast<abstract::AbstractTuple &>(*a1);
164     auto &a2_tuple = static_cast<abstract::AbstractTuple &>(*a2);
165     auto &a1_elements = a1_tuple.elements();
166     auto &a2_elements = a2_tuple.elements();
167     if (a1_elements.size() != a2_elements.size()) {
168       return false;
169     }
170     for (size_t i = 0; i < a1_elements.size(); i++) {
171       MS_EXCEPTION_IF_NULL(a1_elements[i]);
172       MS_EXCEPTION_IF_NULL(a2_elements[i]);
173       if (!IsCompatible(a1_elements[i], a2_elements[i])) {
174         return false;
175       }
176     }
177     return true;
178   }
179   // Check AbstractTensor and AbstractRef.
180   auto type1 = a1->BuildType();
181   auto type2 = a2->BuildType();
182   if (type1 != type2 && *type1 != *type2) {
183     return false;
184   }
185   auto shape1 = a1->BuildShape();
186   auto shape2 = a2->BuildShape();
187   if (shape1 == shape2) {
188     return true;
189   }
190   if (shape1->isa<abstract::Shape>() && shape2->isa<abstract::Shape>()) {
191     const auto &shape1_vec = shape1->cast<abstract::ShapePtr>()->shape();
192     const auto &shape2_vec = shape2->cast<abstract::ShapePtr>()->shape();
193     if ((shape1_vec == ShapeVector({1}) && shape2_vec == ShapeVector()) ||
194         (shape1_vec == ShapeVector() && shape2_vec == ShapeVector({1}))) {
195       return true;
196     }
197   }
198   return *shape1 == *shape2;
199 }
200 
201 struct CallBranch {
202   KernelGraphPtr graph;
203   std::vector<AnfNodePtr> args;
204 };
205 
206 struct CallSite {
207   // Call/Switch/SwitchLayer
208   CNodePtr cnode;
209 
210   // CNode after transferring to LabelGoto/LabelSwitch/LabelSet.
211   CNodePtr conversion_cnode;
212 
213   // The last monad before call.
214   AnfNodePtr last_monad = nullptr;
215 
216   // Branch graph called.
217   std::vector<CallBranch> callees;
218 
219   // Parameter for return value.
220   AnfNodePtr out_param = nullptr;
221 
222   // Label id for return.
223   uint32_t return_label = kNoLabel;
224 
225   // Label param to index map.
226   std::map<AnfNodePtr, uint32_t> label_indexes;
227 
228   // True if this is a recursive call.
229   bool recursive = false;
230 
231   // True if this is a tail call.
232   bool tail = false;
233 
234   // True if this call is a disable tail-opt call.
235   bool disable_tail = false;
236 };
237 
238 struct ReturnPoint {
239   CallSite *call_site = nullptr;
240 };
241 
242 struct CallInfo {
243   // Call sites in current graph.
244   std::vector<CallSite> call_sites;
245 
246   // Return points of current graph.
247   std::vector<ReturnPoint> return_points;
248 
249   // Parameter to store label index, if there are
250   // multi return points, this should be set.
251   AnfNodePtr label_param = nullptr;
252 
253   // Return monad.
254   AnfNodePtr return_monad_ = nullptr;
255 
256   // True if current graph is involved with recursive calls.
257   bool recursive = false;
258 };
259 
260 //
261 // ParameterPool cache parameters by its abstract, so that we can reuse
262 // parameter with same abstract to store return values.
263 //
264 class ParameterPool {
265  public:
ParameterPool(const KernelGraphPtr & top_graph)266   explicit ParameterPool(const KernelGraphPtr &top_graph) : top_graph_(top_graph) {}
267   ~ParameterPool() = default;
268 
269   // Create or get a parameter from pool with the given abstract.
GetParameter(const abstract::AbstractBasePtr & abs)270   AnfNodePtr GetParameter(const abstract::AbstractBasePtr &abs) {
271     // Find parameter in pool by the given abstract.
272     auto iter = std::find_if(paras_.begin(), paras_.end(), [&abs](auto &para) {
273       auto para_abs = para->abstract();
274       // Reuse output parameter with compatible abstract.
275       return IsCompatible(abs, para_abs);
276     });
277     // Return the parameter if found.
278     if (iter != paras_.end()) {
279       return *iter;
280     }
281     // If parameter not found with the given abstract, create a new one.
282     auto para = top_graph_->NewParameter(abs);
283     auto out_para = top_graph_->TransTupleToMakeTuple(para);
284     // This is required, so that device memory can be allocated for it.
285     top_graph_->AddChildGraphResult(out_para);
286     // Save new para to pool.
287     paras_.push_back(out_para);
288     return out_para;
289   }
290 
291  private:
292   // The top graph.
293   const KernelGraphPtr &top_graph_;
294 
295   // Cached parameters.
296   std::vector<AnfNodePtr> paras_;
297 };
298 
299 //
300 // Base class for context.
301 //
302 class BaseContext {
303  public:
MarkVisited(const KernelGraphPtr & kg)304   void MarkVisited(const KernelGraphPtr &kg) { visited_graphs_.insert(kg); }
305 
IsVisited(const KernelGraphPtr & kg) const306   bool IsVisited(const KernelGraphPtr &kg) const { return visited_graphs_.find(kg) != visited_graphs_.end(); }
307 
visited_graphs() const308   const std::set<KernelGraphPtr> &visited_graphs() const { return visited_graphs_; }
309 
ClearVisited()310   void ClearVisited() { visited_graphs_.clear(); }
311 
~BaseContext()312   virtual ~BaseContext() {}
313 
314  private:
315   std::set<KernelGraphPtr> visited_graphs_;
316 };
317 
318 //
319 // AscendAutoMonadContext holds some shared states during auto-monad.
320 //
321 class AscendAutoMonadContext : public BaseContext {
322  public:
AscendAutoMonadContext(const KernelGraphPtr & kg)323   explicit AscendAutoMonadContext(const KernelGraphPtr &kg) : top_graph_(kg), param_pool_(kg) {}
324   ~AscendAutoMonadContext() = default;
325 
326   // Label id start from 1, and increased by 1 for each new id.
NewLabel()327   uint32_t NewLabel() { return label_id_++; }
328 
329   // Current label id, also the number of label ids we currently used.
CurrentLabel() const330   uint32_t CurrentLabel() const { return label_id_; }
331 
332   // Create a new parameter.
333   // Output parameters are all created on top graph.
CreateParameter(const AbstractBasePtr & abs)334   AnfNodePtr CreateParameter(const AbstractBasePtr &abs) {
335     auto para = top_graph_->NewParameter(abs);
336     auto out_para = top_graph_->TransTupleToMakeTuple(para);
337     // This is required, so that device memory can be allocated for it.
338     top_graph_->AddChildGraphResult(out_para);
339     return out_para;
340   }
341 
342   // Get or create a temporary parameter for the given abstract.
GetTempParameter(const AbstractBasePtr & abs)343   AnfNodePtr GetTempParameter(const AbstractBasePtr &abs) { return param_pool_.GetParameter(abs); }
344 
TopGraph() const345   const KernelGraphPtr &TopGraph() const { return top_graph_; }
346 
347   // Has already created an stack.
HasInitedStack() const348   const bool HasInitedStack() const { return inited_stack_; }
349 
350   // Set flag to indicate whether has already created an stack or not.
SetInitedStack(bool flag)351   void SetInitedStack(bool flag) { inited_stack_ = flag; }
352 
353   // The graphs has recursion.
HasRecursiveCall() const354   bool HasRecursiveCall() const { return has_recursive_call_; }
355   // The graphs has subgraph multi-call.
HasSubgraphMultiCall() const356   bool HasSubgraphMultiCall() const { return has_subgraph_multicall_; }
357   // set flag to indicate whether has recursion.
SetRecursiveCall(bool flag)358   void SetRecursiveCall(bool flag) { has_recursive_call_ = flag; }
359   // set flag to indicate whether has multi-call.
SetSubGraphMultiCall(bool flag)360   void SetSubGraphMultiCall(bool flag) { has_subgraph_multicall_ = flag; }
361 
362   // Map kernel_graph to its call info.
363   OrderedMap<KernelGraphPtr, CallInfo> call_info_map;
364 
365  private:
366   // The top graph.
367   const KernelGraphPtr &top_graph_;
368 
369   // The parameter pool that cache parameters for return value.
370   ParameterPool param_pool_;
371 
372   // Current label id.
373   uint32_t label_id_ = 0;
374 
375   // Create an stack for multi-call and non-tail recursion.
376   bool inited_stack_ = false;
377   // The graphs has recursion or not.
378   bool has_recursive_call_ = false;
379   // The graphs has subgraph multi-call or not.
380   bool has_subgraph_multicall_ = false;
381 };
382 
383 //
384 // Call info finder finds graph call information.
385 //
386 class CallInfoFinder {
387  public:
Run(AscendAutoMonadContext * context)388   static void Run(AscendAutoMonadContext *context) {
389     CallInfoFinder finder(context->TopGraph(), context);
390     finder.Run();
391   }
392 
393  private:
CallInfoFinder(const KernelGraphPtr & kg,AscendAutoMonadContext * context)394   CallInfoFinder(const KernelGraphPtr &kg, AscendAutoMonadContext *context) : kernel_graph_(kg), context_(*context) {}
395   ~CallInfoFinder() = default;
396 
Run()397   void Run() {
398     FindCallSites();
399     FindRecursiveCalls();
400     DisableTailCalls();
401     FindCallReturns();
402   }
403 
404   // Find all call sites.
FindCallSites()405   void FindCallSites() {
406     auto call_info = CreateCallInfo();
407     if (call_info == nullptr) {
408       // Skip if call_info for this graph already existed.
409       return;
410     }
411     // Update directly called sub-graphs.
412     kernel_graph_->UpdateChildGraphOrder();
413     // Find Call/Switch/SwitchLayer nodes, and make CallSites for them.
414     AnfNodePtr last_monad = nullptr;
415     auto nodes = TopoSort(kernel_graph_->output());
416     for (auto &node : nodes) {
417       MS_EXCEPTION_IF_NULL(node);
418       if (HasAbstractUMonad(node)) {
419         // Found a node with UMonad abstract, set it as the last monad.
420         last_monad = node;
421         call_info->return_monad_ = last_monad;
422       } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) {
423         MakeCallSite(node->cast<CNodePtr>(), last_monad, call_info);
424         call_info->return_monad_ = nullptr;
425       } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch) ||
426                  AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitchLayer)) {
427         MakeSwitchCallSite(node->cast<CNodePtr>(), last_monad, call_info);
428         call_info->return_monad_ = nullptr;
429       }
430     }
431     // Set the last call as tail call if it is the output node.
432     // We don't set tail call for top graph because return is always required.
433     if (kernel_graph_ != context_.TopGraph() && !call_info->call_sites.empty()) {
434       auto real_output = GetRealNode(kernel_graph_->output());
435       if (real_output == call_info->call_sites.back().cnode) {
436         call_info->call_sites.back().tail = true;
437       }
438     }
439     // Recursively find CallSites from sub-graphs.
440     for (auto &call_site : call_info->call_sites) {
441       for (auto &callee : call_site.callees) {
442         CallInfoFinder finder(callee.graph, &context_);
443         finder.FindCallSites();
444       }
445     }
446   }
447 
448   // Find recursive non-tail calls.
FindRecursiveCalls()449   void FindRecursiveCalls() {
450     for (auto &[caller, call_info] : context_.call_info_map) {
451       for (auto &call_site : call_info.call_sites) {
452         if (!call_site.tail) {
453           SearchRecursiveCall(caller, &call_site);
454         }
455       }
456     }
457   }
458 
459   // Disable tail call optimization for recursive call graphs.
DisableTailCalls()460   void DisableTailCalls() {
461     for (auto &entry : context_.call_info_map) {
462       auto &call_info = entry.second;
463       if (call_info.recursive && !call_info.call_sites.empty()) {
464         call_info.call_sites.back().tail = false;
465         call_info.call_sites.back().disable_tail = true;
466       }
467     }
468   }
469 
470   // Find call-return pairs.
FindCallReturns()471   void FindCallReturns() {
472     for (auto &[caller, call_info] : context_.call_info_map) {
473       for (auto &call_site : call_info.call_sites) {
474         for (auto &callee : call_site.callees) {
475           MakeGraphLabel(callee.graph);
476         }
477         if (!call_site.tail) {
478           SearchCallReturns(caller, &call_site);
479         }
480       }
481     }
482   }
483 
484   // Create entry label for the given graph if not set.
MakeGraphLabel(const KernelGraphPtr & kg)485   void MakeGraphLabel(const KernelGraphPtr &kg) {
486     auto label = GetGraphLabel(kg);
487     if (label == kNoLabel) {
488       // Allocate a new label id and save it to the graph.
489       label = context_.NewLabel();
490       kg->set_attr(kAttrLabelIndex, MakeValue(label));
491     }
492   }
493 
494   // Search return points for all non-tail calls.
SearchCallReturns(const KernelGraphPtr & caller,CallSite * call_site)495   void SearchCallReturns(const KernelGraphPtr &caller, CallSite *call_site) {
496     std::set<KernelGraphPtr> visited = {caller};
497     std::queue<CallSite *> call_sites;
498     call_sites.push(call_site);
499     while (!call_sites.empty()) {
500       auto site = call_sites.front();
501       call_sites.pop();
502       for (auto &callee : site->callees) {
503         auto &kg = callee.graph;
504         if (visited.find(kg) != visited.end()) {
505           // Skip visited graphs.
506           continue;
507         }
508         // Mark visited.
509         visited.emplace(kg);
510         // Check callee.
511         auto &call_info = context_.call_info_map[kg];
512         auto &sites = call_info.call_sites;
513         if (!sites.empty() && sites.back().tail) {
514           // Follow tail call.
515           call_sites.push(&sites.back());
516         } else {
517           // Find a call-return relation.
518           HandleCallReturn(call_site, kg);
519         }
520       }
521     }
522   }
523 
524   struct SearchRecursiveContext {
525     const KernelGraphPtr &start_caller;
526     CallSite *start_site;
527     std::set<KernelGraphPtr> visited;
528     std::vector<KernelGraphPtr> call_path;
529   };
530 
531   // Search recursive call from a call-site.
SearchRecursiveCall(const KernelGraphPtr & start_caller,CallSite * start_site)532   void SearchRecursiveCall(const KernelGraphPtr &start_caller, CallSite *start_site) {
533     SearchRecursiveContext context{.start_caller = start_caller, .start_site = start_site};
534     DoSearchRecursiveCall(start_caller, *start_site, &context);
535   }
536 
DoSearchRecursiveCall(const KernelGraphPtr & graph,const CallSite & call_site,SearchRecursiveContext * ctx)537   void DoSearchRecursiveCall(const KernelGraphPtr &graph, const CallSite &call_site, SearchRecursiveContext *ctx) {
538     MS_EXCEPTION_IF_NULL(ctx);
539     // Record call path.
540     ctx->call_path.push_back(graph);
541     // Handle callee graphs.
542     for (auto &callee : call_site.callees) {
543       auto &sub_graph = callee.graph;
544       if (sub_graph == ctx->start_caller) {
545         // Find a recursive call path.
546         for (auto &g : ctx->call_path) {
547           // Mark recursive for all graphs in call path.
548           context_.call_info_map[g].recursive = true;
549         }
550         // Mark recursive for the start call-site.
551         MS_EXCEPTION_IF_NULL(ctx->start_site);
552         ctx->start_site->recursive = true;
553         continue;
554       }
555       if (ctx->visited.find(sub_graph) != ctx->visited.end()) {
556         // Skip visited graphs.
557         continue;
558       }
559       // Mark visited.
560       (void)ctx->visited.emplace(sub_graph);
561       // Check call sites in the sub-graph.
562       auto &call_info = context_.call_info_map[sub_graph];
563       auto &sites = call_info.call_sites;
564       for (auto &site : sites) {
565         if (!site.callees.empty()) {
566           DoSearchRecursiveCall(sub_graph, site, ctx);
567         }
568       }
569     }
570     // Don't forget this.
571     ctx->call_path.pop_back();
572   }
573 
574   // Handle a call-return relation.
HandleCallReturn(CallSite * call_site,const KernelGraphPtr & callee)575   void HandleCallReturn(CallSite *call_site, const KernelGraphPtr &callee) {
576     // Create a label for the return point.
577     if (call_site->return_label == kNoLabel) {
578       call_site->return_label = context_.NewLabel();
579     }
580     if (!IsCompatible(call_site->cnode->abstract(), callee->output()->abstract())) {
581       MS_LOG(EXCEPTION) << "call_site node: " << call_site->cnode->DebugString() << " has different abstract() with "
582                         << callee->ToString() << " output(), [ " << call_site->cnode->abstract()->ToString()
583                         << " != " << callee->output()->abstract()->ToString() << " ],"
584                         << "Do not support this situation, pls check if the graghs are correct.";
585     }
586 
587     // Create a parameter for the return value.
588     if (call_site->out_param == nullptr) {
589       call_site->out_param = context_.CreateParameter(call_site->cnode->abstract());
590     }
591     // Add a return point for the callee graph.
592     auto &call_info = context_.call_info_map[callee];
593     auto &return_point = call_info.return_points.emplace_back();
594     return_point.call_site = call_site;
595 
596     // Setup label index if there are multi return points.
597     const auto n_return_points = call_info.return_points.size();
598     const size_t return_point_sizes = 2;
599     if (n_return_points > 1) {
600       if (n_return_points == return_point_sizes) {
601         // Create a parameter to store label index.
602         const ShapeVector shape = {1};
603         auto abs = std::make_shared<abstract::AbstractTensor>(kInt32, shape);
604         call_info.label_param = context_.CreateParameter(abs);
605         // Add label index for the first call site.
606         call_info.return_points.front().call_site->label_indexes.emplace(call_info.label_param, 0);
607         // Judge the last call_site whether is loop, set recursive attr if yes.
608         if (!call_info.call_sites.empty() && call_info.call_sites.back().disable_tail) {
609           SearchRecursiveCall(callee, &call_info.call_sites.back());
610         }
611       }
612       // Add label index for the current call site.
613       auto label_index = static_cast<uint32_t>(call_info.return_points.size() - 1);
614       call_site->label_indexes.emplace(call_info.label_param, label_index);
615     }
616   }
617 
618   // Create a CallInfo for current kernel graph, return null if it is already existed.
CreateCallInfo()619   CallInfo *CreateCallInfo() {
620     auto [iter, ok] = context_.call_info_map.add(kernel_graph_);
621     if (!ok) {
622       // CallInfo already existed.
623       return nullptr;
624     }
625     return &(iter->second);
626   }
627 
628   // Create CallSite for Call node.
MakeCallSite(const CNodePtr & cnode,const AnfNodePtr & last_monad,CallInfo * call_info)629   void MakeCallSite(const CNodePtr &cnode, const AnfNodePtr &last_monad, CallInfo *call_info) {
630     auto &call_site = call_info->call_sites.emplace_back();
631     call_site.cnode = cnode;
632     call_site.last_monad = last_monad;
633     call_site.callees.emplace_back(GetCallBranch(cnode));
634   }
635 
636   // Create CallSite for Switch/SwitchLayer node.
MakeSwitchCallSite(const CNodePtr & cnode,const AnfNodePtr & last_monad,CallInfo * call_info)637   void MakeSwitchCallSite(const CNodePtr &cnode, const AnfNodePtr &last_monad, CallInfo *call_info) {
638     auto &call_site = call_info->call_sites.emplace_back();
639     call_site.cnode = cnode;
640     call_site.last_monad = last_monad;
641     call_site.callees = GetSwitchBranches(cnode);
642   }
643 
GetCallBranch(const CNodePtr & cnode)644   CallBranch GetCallBranch(const CNodePtr &cnode) {
645     auto input_graph = cnode->input(kCallKernelGraphIndex);
646     MS_EXCEPTION_IF_NULL(input_graph);
647     auto kg = GetValueNode<KernelGraphPtr>(input_graph);
648     MS_EXCEPTION_IF_NULL(kg);
649     constexpr int64_t call_arg_index = 2;
650     auto &inputs = cnode->inputs();
651     std::vector<AnfNodePtr> args{inputs.begin() + call_arg_index, inputs.end()};
652     return {.graph = kg, .args = std::move(args)};
653   }
654 
GetSwitchBranches(const CNodePtr & cnode)655   std::vector<CallBranch> GetSwitchBranches(const CNodePtr &cnode) {
656     constexpr size_t cond_start_index = 2;
657     std::vector<CallBranch> branches;
658     for (size_t index = cond_start_index; index < cnode->inputs().size(); ++index) {
659       branches.emplace_back(GetSwitchBranch(cnode, index));
660     }
661     return branches;
662   }
663 
GetSwitchBranch(const CNodePtr & cnode,size_t index)664   CallBranch GetSwitchBranch(const CNodePtr &cnode, size_t index) {
665     auto partial_cnode = dyn_cast<CNode>(cnode->input(index));
666     if (partial_cnode == nullptr) {
667       return {nullptr, {}};
668     }
669     auto &inputs = partial_cnode->inputs();
670     if (!IsPrimitive(inputs.at(0), prim::kPrimPartial)) {
671       MS_LOG(EXCEPTION) << "Invalid switch node: " << cnode->DebugString();
672     }
673     auto graph = GetValueNode<KernelGraphPtr>(inputs.at(1));
674     constexpr int64_t arg_index = 2;
675     std::vector<AnfNodePtr> args{inputs.begin() + arg_index, inputs.end()};
676     return {.graph = graph, .args = std::move(args)};
677   }
678 
GetRealNode(const AnfNodePtr & node)679   static AnfNodePtr GetRealNode(const AnfNodePtr &node) {
680     if (!IsPrimitiveCNode(node, prim::kPrimDepend)) {
681       return node;
682     }
683     return GetRealNode(node->cast<CNodePtr>()->input(1));
684   }
685 
686   const KernelGraphPtr &kernel_graph_;
687   AscendAutoMonadContext &context_;
688 };
689 
690 //
691 // AscendAutoMonadConverter convert control flow to monad form
692 // for a kernel graph and its children graphs recursively.
693 //
694 class AscendAutoMonadConverter {
695  public:
Run(AscendAutoMonadContext * context)696   static void Run(AscendAutoMonadContext *context) {
697     for (auto &entry : context->call_info_map) {
698       AscendAutoMonadConverter converter(entry.first, context, &entry.second);
699       converter.Run();
700     }
701     const auto &top_graph = context->TopGraph();
702     SetIterEndAttrForTopGraph(context, top_graph);
703   }
704 
705  private:
  // Build a converter for one kernel graph.
  //   kg        - the graph to convert.
  //   context   - shared auto-monad context; dereferenced here, so it must not be null.
  //   call_info - call info of `kg`; dereferenced here, so it must not be null.
  // Stack operations are needed exactly when the graph's call info is marked recursive.
  AscendAutoMonadConverter(const KernelGraphPtr &kg, AscendAutoMonadContext *context, CallInfo *call_info)
      : kernel_graph_(kg),
        context_(*context),
        call_info_(*call_info),
        name_index_(0),
        need_stackops_(call_info->recursive) {}
713 
Run()714   void Run() {
715     // Create an stack
716     InitStack();
717     // Setup entry label if found.
718     SetupEntryLabel();
719 
720     // Handle call sites.
721     for (auto &call_site : call_info_.call_sites) {
722       HandleCallSite(&call_site);
723     }
724     // Handle return points.
725     HandleReturnPoints();
726     // Let output depend on monad.
727     if (monad_) {
728       MakeMonadDepend();
729     }
730     // Handle recursive call.
731     kernel_graph_->SetExecOrderByDefault();
732     if (call_info_.recursive) {
733       const auto &nodes = kernel_graph_->execution_order();
734       AnfAlgo::SetNodeAttr(kAttrRecursiveStart, prim::kValueOne, *nodes.begin());
735       AnfAlgo::SetNodeAttr(kAttrRecursiveEnd, prim::kValueOne, *nodes.rbegin());
736     }
737     for (auto &call_site : call_info_.call_sites) {
738       if (need_stackops_ && call_site.recursive) {
739         MS_LOG(INFO) << "graph:" << kernel_graph_->ToString() << ", loop call_site:" << call_site.cnode->DebugString();
740         InsertStackOps(call_site);
741       }
742     }
743   }
744 
745   // Set iteration end points for Profiling.
SetIterEndAttrForTopGraph(AscendAutoMonadContext * context,const KernelGraphPtr & kg)746   static void SetIterEndAttrForTopGraph(AscendAutoMonadContext *context, const KernelGraphPtr &kg) {
747     MS_EXCEPTION_IF_NULL(kg);
748     kg->SetExecOrderByDefault();
749     auto &nodes = kg->execution_order();
750     auto end_iter = nodes.rend();
751     std::set<KernelGraphPtr> memo;
752     memo.insert(kg);
753     auto call_info = context->call_info_map[kg];
754     if (call_info.call_sites.empty()) {
755       SetIterEndAttr(context, kg, false);
756       return;
757     } else {
758       const auto &end_node = call_info.call_sites.back().cnode;
759       end_iter = std::find(nodes.rbegin(), nodes.rend(), end_node);
760     }
761     for (auto iter = nodes.rbegin(); iter != end_iter; ++iter) {
762       if (!AnfAlgo::IsRealCNodeKernel(*iter)) {
763         continue;
764       }
765       if (AnfAlgo::CheckPrimitiveType(*iter, prim::kPrimLabelSet)) {
766         const auto &last_call_site = context->call_info_map[kg].call_sites.back();
767         for (auto &branch : last_call_site.callees) {
768           if (memo.find(branch.graph) != memo.end()) {
769             continue;
770           }
771           FindProfilingEndPoints(context, branch.graph, &memo);
772         }
773         break;
774       }
775       AnfAlgo::SetNodeAttr(ITEREND, prim::kValueOne, *iter);
776       MS_LOG(INFO) << "Set profiling iter-end points: " << (*iter)->DebugString();
777       return;
778     }
779   }
780 
781   // Set Attr to the iter-end points.
SetIterEndAttr(AscendAutoMonadContext * context,const KernelGraphPtr & kg,bool has_call_site)782   static void SetIterEndAttr(AscendAutoMonadContext *context, const KernelGraphPtr &kg, bool has_call_site) {
783     MS_EXCEPTION_IF_NULL(kg);
784     kg->SetExecOrderByDefault();
785     auto &nodes = kg->execution_order();
786     auto end_iter = nodes.rend();
787     if (has_call_site) {
788       const auto &end_node = context->call_info_map[kg].call_sites.back().cnode;
789       end_iter = std::find(nodes.rbegin(), nodes.rend(), end_node);
790     }
791     for (auto iter = nodes.rbegin(); iter != end_iter; ++iter) {
792       if (!AnfAlgo::IsRealCNodeKernel(*iter)) {
793         continue;
794       }
795       if (AnfAlgo::CheckPrimitiveType(*iter, prim::kPrimLabelGoto) && AnfAlgo::HasNodeAttr(kAttrReturn, *iter)) {
796         continue;
797       }
798       if (AnfAlgo::CheckPrimitiveType(*iter, prim::kPrimLabelGoto) ||
799           AnfAlgo::CheckPrimitiveType(*iter, prim::kPrimLabelSwitch) ||
800           AnfAlgo::CheckPrimitiveType(*iter, prim::kPrimLabelSet)) {
801         MS_LOG(ERROR) << "this node is Labelxxxx, do not found iter end.";
802         break;
803       }
804       AnfAlgo::SetNodeAttr(ITEREND, prim::kValueOne, *iter);
805       MS_LOG(INFO) << "Set profiling iter-end points: " << (*iter)->DebugString();
806       return;
807     }
808     MS_LOG(ERROR) << "Do not find iter_end point";
809   }
810 
811   // Find all iteration end points recursively.
FindProfilingEndPoints(AscendAutoMonadContext * context,const KernelGraphPtr & kg,std::set<KernelGraphPtr> * memo)812   static void FindProfilingEndPoints(AscendAutoMonadContext *context, const KernelGraphPtr &kg,
813                                      std::set<KernelGraphPtr> *memo) {
814     MS_EXCEPTION_IF_NULL(memo);
815     memo->insert(kg);
816     auto call_info = context->call_info_map[kg];
817     // 1. find the last call site; if no call site, goto step 3.
818     // 2. Judge the call site whether is tail call or not.
819     // 3. if yes, recursively find call site in subgraph; if no, find the last TBE node and set extra attr.
820     if (!call_info.call_sites.empty()) {
821       const auto &last_call_site = call_info.call_sites.back();
822       if (last_call_site.tail) {
823         for (auto &branch : last_call_site.callees) {
824           if (memo->find(branch.graph) != memo->end()) {
825             continue;
826           }
827           FindProfilingEndPoints(context, branch.graph, memo);
828         }
829       } else {
830         SetIterEndAttr(context, kg, true);
831       }
832     } else {
833       SetIterEndAttr(context, kg, false);
834     }
835   }
836 
837   // Create a Stack for StackOps if needed.
InitStack()838   void InitStack() {
839     if (!context_.HasInitedStack() && need_stackops_) {
840       auto top_graph = context_.TopGraph();
841       MS_EXCEPTION_IF_NULL(top_graph);
842       auto exec_order = top_graph->execution_order();
843       auto stack_init = StackInit(top_graph);
844       AnfAlgo::KeepOrder(top_graph, stack_init, *exec_order.begin());
845       auto stack_destroy = StackDestroy(top_graph);
846       AnfAlgo::KeepOrder(top_graph, *exec_order.rbegin(), stack_destroy);
847       top_graph->SetExecOrderByDefault();
848       context_.SetRecursiveCall(true);
849       context_.SetInitedStack(true);
850     }
851   }
852 
853   // Insert StackOps for call_site in the recursive graph.
InsertStackOps(const CallSite & call_site)854   void InsertStackOps(const CallSite &call_site) {
855     auto call_point = call_site.conversion_cnode;
856     auto exec_order = kernel_graph_->execution_order();
857     std::vector<AnfNodePtr> before_nodes;
858     std::vector<CNodePtr> stack_pushs;
859     bool find_call_point = false;
860     for (auto &node : exec_order) {
861       auto node_name = AnfAlgo::GetCNodeName(node);
862       if (node == call_point) {
863         find_call_point = true;
864         continue;
865       }
866       if (!find_call_point) {
867         if (node_name == kLabelGotoOpName || node_name == kLabelSwitchOpName || node_name == kLabelSetOpName ||
868             node_name == prim::kPrimAssign->name()) {
869           MS_LOG(DEBUG) << "Ignore goto/switch/set/assign ops";
870         } else {
871           before_nodes.push_back(node);
872           MS_LOG(DEBUG) << "push back node:" << node->DebugString();
873         }
874         continue;
875       }
876       if (node->size() == 0 || node_name == kLabelGotoOpName || node_name == kLabelSetOpName ||
877           node_name == prim::kPrimAssign->name()) {
878         continue;
879       }
880       FindInputNode(before_nodes, node, &stack_pushs);
881     }
882     InsertStackPush(kernel_graph_, call_point, stack_pushs);
883   }
884 
885   // Find nodes which need StackOps, and insert StackOps for node.
FindInputNode(const std::vector<AnfNodePtr> & before_nodes,const CNodePtr & node,std::vector<CNodePtr> * stack_pushs)886   void FindInputNode(const std::vector<AnfNodePtr> &before_nodes, const CNodePtr &node,
887                      std::vector<CNodePtr> *stack_pushs) {
888     MS_EXCEPTION_IF_NULL(node);
889     uint32_t start_index = 1;
890     if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimAssign)) {
891       start_index = kInputIndex;
892     }
893     for (uint32_t i = start_index; i < node->inputs().size(); i++) {
894       auto node_input = node->input(i);
895       // not need to save monad.
896       if (HasAbstractMonad(node_input)) {
897         continue;
898       }
899       MS_EXCEPTION_IF_NULL(node_input);
900       MS_LOG(DEBUG) << "check node input[" << i << "]: " << node_input->DebugString();
901       if (node_input->isa<Parameter>()) {
902         MS_LOG(DEBUG) << "node_input:" << node_input->DebugString() << " is a param";
903         CNodePtr stack_pop = InsertStackPop(node_input, stack_pushs);
904         node->set_input(i, stack_pop);
905         KeepOrderForStackPop(kernel_graph_, stack_pop, node);
906         continue;
907       }
908       auto iter = std::find_if(before_nodes.begin(), before_nodes.end(),
909                                [node_input](auto before_node) { return before_node == node_input; });
910       if (iter != before_nodes.end()) {
911         CNodePtr stack_pop = InsertStackPop(*iter, stack_pushs);
912         node->set_input(i, stack_pop);
913         KeepOrderForStackPop(kernel_graph_, stack_pop, node);
914       }
915     }
916   }
917 
918   // Create StackOps for node_input.
InsertStackPop(const AnfNodePtr & node_input,std::vector<CNodePtr> * stack_pushs)919   CNodePtr InsertStackPop(const AnfNodePtr &node_input, std::vector<CNodePtr> *stack_pushs) {
920     MS_EXCEPTION_IF_NULL(node_input);
921     MS_EXCEPTION_IF_NULL(stack_pushs);
922     auto stack_push = StackPush(node_input);
923     stack_pushs->emplace_back(stack_push);
924     auto stack_pop = StackPop();
925     MS_EXCEPTION_IF_NULL(stack_pop);
926     stack_pop->set_abstract(node_input->abstract());
927     return stack_pop;
928   }
929 
930   // Arrange StackPushs according to the rules of the last pop-up StackPush first,
931   // while ensuring that the last StackPush node is next to the jump_node.
InsertStackPush(const KernelGraphPtr & kg,const CNodePtr & jump_node,const std::vector<CNodePtr> & stack_pushs)932   void InsertStackPush(const KernelGraphPtr &kg, const CNodePtr &jump_node, const std::vector<CNodePtr> &stack_pushs) {
933     MS_LOG(DEBUG) << "There are " << stack_pushs.size() << " stack_push ops";
934     if (stack_pushs.size() < 1) {
935       return;
936     }
937     for (uint32_t i = 1; i < stack_pushs.size(); i++) {
938       AnfAlgo::KeepOrder(kg, stack_pushs[i], stack_pushs[i - 1]);
939     }
940     auto nodes = kg->execution_order();
941     auto node_iter = std::find(nodes.begin(), nodes.end(), jump_node);
942     AnfAlgo::KeepOrder(kg, stack_pushs[0], jump_node);
943     if (node_iter != nodes.begin()) {
944       AnfAlgo::KeepOrder(kg, *(node_iter - 1), *stack_pushs.rbegin());
945     }
946   }
947 
948   // Ensure StackPop is next to the jump_node.
KeepOrderForStackPop(const KernelGraphPtr & kg,const CNodePtr & pop,const CNodePtr & jump_node)949   void KeepOrderForStackPop(const KernelGraphPtr &kg, const CNodePtr &pop, const CNodePtr &jump_node) {
950     auto nodes = kg->execution_order();
951     auto node_iter = std::find(nodes.cbegin(), nodes.cend(), jump_node);
952     if (node_iter == nodes.cend()) {
953       MS_LOG(EXCEPTION) << "Cannot find node: " << jump_node->DebugString();
954     }
955     // Insert between jump_node-1 and jump_node.
956     if (node_iter != nodes.begin()) {
957       CNodePtr node = *(node_iter - 1);
958       AnfAlgo::KeepOrder(kg, node, pop);
959     }
960     AnfAlgo::KeepOrder(kg, pop, jump_node);
961   }
962 
HandleCallSite(CallSite * call_site)963   void HandleCallSite(CallSite *call_site) {
964     // Update last_monad_.
965     last_monad_ = call_site->last_monad;
966 
967     // The call/switch/switch_layer cnode.
968     auto &cnode = call_site->cnode;
969 
970     // Get branches of the call_site.
971     // for call, there is one branch;
972     // for switch, the first one is true branch;
973     // for switch_layer, the first one is 0 branch.
974     auto &branches = call_site->callees;
975 
976     // Link arguments and find labels for branches.
977     std::vector<KernelGraphPtr> graphes;
978     std::vector<uint32_t> labels;
979     graphes.reserve(branches.size());
980     labels.reserve(branches.size());
981     bool monad_update = false;
982     for (auto &[graph, args] : branches) {
983       MS_EXCEPTION_IF_NULL(graph);
984       auto linked_args = LinkArguments(args, graph);
985       if (linked_args != nullptr) {
986         monad_ = UpdateState(GetMonad(), linked_args);
987         monad_update = true;
988       }
989       graphes.push_back(graph);
990       labels.push_back(GetGraphLabel(graph));
991     }
992     if (!monad_update) {
993       monad_ = last_monad_;
994     }
995 
996     // Assign label indexes if required.
997     AssignLabelIndexes(call_site);
998 
999     // For Switch, we reverse the graphes and labels, so that the false branch
1000     // is the first one, since for kernel LabelSwitch, false is the first branch.
1001     if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
1002       std::reverse(graphes.begin(), graphes.end());
1003       std::reverse(labels.begin(), labels.end());
1004     }
1005 
1006     // Create LabelGoto or LabelSwitch node.
1007     auto label_goto_switch = MakeLabelGotoSwitch(cnode, graphes, labels);
1008     call_site->conversion_cnode = label_goto_switch;
1009     if (call_site->recursive) {
1010       AnfAlgo::SetNodeAttr(kAttrRecursive, prim::kValueOne, label_goto_switch);
1011     }
1012 
1013     // Setup return label and output if required.
1014     if (call_site->return_label != kNoLabel) {
1015       auto label_node = LabelSet(call_site->return_label);
1016       AnfNodePtr output = call_site->out_param;
1017       MS_EXCEPTION_IF_NULL(output);
1018       const bool is_single_call = call_site->label_indexes.empty();
1019       if (is_single_call) {
1020         // For single call, let output depend on the label node,
1021         // this ensures the return label is set before output is used.
1022         output = MakeDepend(output, label_node);
1023       } else {
1024         // For multi-return call, assign result from temp parameter to
1025         // output parameter, this prevent result be overwritten by next call.
1026         auto tmp_param = context_.GetTempParameter(output->abstract());
1027         output = AssignAll(output, tmp_param, false, false, true);
1028         monad_ = UpdateState(GetMonad(), output);
1029       }
1030       // Replace the the call/switch node with the output.
1031       ReplaceNode(cnode, output);
1032       return;
1033     }
1034 
1035     // If no return label required, it should be a tail call.
1036     if (!call_site->tail) {
1037       MS_LOG(EXCEPTION) << "Return label not set for non-tail call " << cnode->DebugString();
1038     }
1039     // For tail calls, replace origin call node with label_goto/label_switch.
1040     ReplaceNode(cnode, label_goto_switch);
1041     kernel_graph_->set_end_goto(label_goto_switch);
1042   }
1043 
1044   // Assign label indexes to label parameters for a call site.
AssignLabelIndexes(const CallSite * call_site)1045   void AssignLabelIndexes(const CallSite *call_site) {
1046     for (auto &[label_param, label_index] : call_site->label_indexes) {
1047       auto index_value = GetIndexValueNode(label_index);
1048       auto assign = Assign(label_param, index_value, false, false, false);
1049       monad_ = UpdateState(GetMonad(), assign);
1050     }
1051   }
1052 
1053   // Create or reuse ValueNode for the index.
GetIndexValueNode(uint32_t index)1054   ValueNodePtr GetIndexValueNode(uint32_t index) {
1055     auto iter = index_nodes_.find(index);
1056     if (iter != index_nodes_.end()) {
1057       // Reuse ValueNode for same index.
1058       return iter->second;
1059     }
1060     // Create a new ValueNode on top graph for the index.
1061     auto &top_graph = context_.TopGraph();
1062     std::vector<int64_t> data = {static_cast<int64_t>(index)};
1063     auto tensor = std::make_shared<tensor::Tensor>(data, kInt32);
1064     auto value_node = top_graph->NewValueNode(tensor->ToAbstract(), tensor);
1065     top_graph->AddValueNodeToGraph(value_node);
1066     index_nodes_.emplace(index, value_node);
1067     return value_node;
1068   }
1069 
1070   // Replace a node with new node in current kernel graph.
1071   // We also replace the arguments used for sub-graph calls.
ReplaceNode(const AnfNodePtr & old_node,const AnfNodePtr & new_node)1072   void ReplaceNode(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
1073     kernel_graph_->ReplaceNode(old_node, new_node);
1074     for (auto &call_site : call_info_.call_sites) {
1075       for (auto &callee : call_site.callees) {
1076         std::replace(callee.args.begin(), callee.args.end(), old_node, new_node);
1077       }
1078     }
1079   }
1080 
1081   // Make a label_goto or label_switch for a Call/Switch/SwitchLayer node.
MakeLabelGotoSwitch(const CNodePtr & cnode,const std::vector<KernelGraphPtr> & graphes,const std::vector<uint32_t> & labels)1082   CNodePtr MakeLabelGotoSwitch(const CNodePtr &cnode, const std::vector<KernelGraphPtr> &graphes,
1083                                const std::vector<uint32_t> &labels) {
1084     // Create LabelGoto or LabelSwitch according the cnode type.
1085     const bool is_call = AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall);
1086     auto label_goto_switch = (is_call ? LabelGoto(labels.front()) : LabelSwitch(cnode->input(1), labels));
1087 
1088     // Set child graph attribute for the LabelGoto or LabelSwitch node.
1089     SetChildGrapAttr(label_goto_switch, graphes);
1090 
1091     // Mark the label_switch node is for 'switch_layer' if it is.
1092     if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) {
1093       AnfAlgo::SetNodeAttr(kAttrSwitchLayer, prim::kValueOne, label_goto_switch);
1094     }
1095     return label_goto_switch;
1096   }
1097 
1098   // Handle return points.
1099   // use label_goto for single return point;
1100   // use label_switch for multi return points.
HandleReturnPoints()1101   void HandleReturnPoints() {
1102     auto &return_points = call_info_.return_points;
1103     // No return points.
1104     if (return_points.empty()) {
1105       return;
1106     }
1107     if (call_info_.return_monad_ != nullptr) {
1108       monad_ = call_info_.return_monad_;
1109     }
1110     // Assign output according the return points.
1111     AssignOutput(return_points);
1112     // Single return point.
1113     if (return_points.size() == 1) {
1114       // Insert label_goto for return.
1115       auto &return_point = return_points.front();
1116       auto return_goto = LabelGoto(return_point.call_site->return_label);
1117       AnfAlgo::SetNodeAttr(kAttrReturn, prim::kValueOne, return_goto);
1118       kernel_graph_->set_end_goto(return_goto);
1119       return;
1120     }
1121     // Multi return points.
1122     std::vector<uint32_t> return_labels;
1123     return_labels.reserve(return_points.size());
1124     // Get return labels from return points.
1125     std::transform(return_points.begin(), return_points.end(), std::back_inserter(return_labels),
1126                    [](const ReturnPoint &return_point) { return return_point.call_site->return_label; });
1127     // Insert label_switch for multi return points.
1128     auto &label_param = call_info_.label_param;
1129     MS_EXCEPTION_IF_NULL(label_param);
1130     auto return_switch = LabelSwitch(label_param, return_labels);
1131     AnfAlgo::SetNodeAttr(kAttrReturn, prim::kValueOne, return_switch);
1132     if (!call_info_.recursive) {
1133       AnfAlgo::SetNodeAttr(kAttrMultiCallEnd, prim::kValueOne, return_switch);
1134     }
1135     kernel_graph_->set_end_goto(return_switch);
1136     context_.SetSubGraphMultiCall(true);
1137   }
1138 
1139   // Assign graph output to the output parameter.
AssignOutput(const std::vector<ReturnPoint> & return_points)1140   void AssignOutput(const std::vector<ReturnPoint> &return_points) {
1141     // For single call: we directly assign output to the output parameter of the call site;
1142     // For multi call: we assign output to a temp parameter, and let caller assign the
1143     // temp parameter to a output parameter after returned.
1144     auto call_site = return_points.front().call_site;
1145     MS_EXCEPTION_IF_NULL(call_site);
1146     const bool is_single_call = (return_points.size() == 1 && call_site->label_indexes.empty());
1147     AnfNodePtr out_param =
1148       (is_single_call ? call_site->out_param : context_.GetTempParameter(kernel_graph_->output()->abstract()));
1149     MS_EXCEPTION_IF_NULL(out_param);
1150     auto assign_output = AssignAll(out_param, kernel_graph_->output(), false, false, true);
1151     monad_ = UpdateState(GetMonad(), assign_output);
1152   }
1153 
1154   // Link actual arguments to graph's formal arguments.
1155   // 1. for multi-args:
1156   //   r = Call(fg, arg1, arg2, u)
1157   // linked arguments:
1158   //   r1 = Assign(para1, arg1, c)
1159   //   r2 = Assign(para2, arg2, c)
1160   //   tuple = MakeTuple(r1, r2, u)
1161   // 2. for single-arg:
1162   //   r = Call(fg, arg)
1163   // linked arguments:
1164   //   r = Assign(para1, arg1, c)
1165   // 3. for empty-arg:
1166   //   r = Call(fg)
1167   // linked arguments return null.
LinkArguments(const std::vector<AnfNodePtr> & args,const KernelGraphPtr & graph)1168   AnfNodePtr LinkArguments(const std::vector<AnfNodePtr> &args, const KernelGraphPtr &graph) {
1169     auto &paras = graph->inputs();
1170     if (args.size() != paras.size()) {
1171       MS_LOG(EXCEPTION) << "Wrong arg number! " << graph->ToString() << " " << args.size() << " != " << paras.size();
1172     }
1173     // If no argument, return null.
1174     if (args.empty()) {
1175       return nullptr;
1176     }
1177     // We do not eliminate argument Assign for recursive graphs.
1178     const bool keep = IsRecursive(graph);
1179     // Single argument.
1180     if (args.size() == 1) {
1181       auto &value = args.front();
1182       if (HasAbstractMonad(value) || paras.front() == value) {
1183         // No assign for single monad argument, return it.
1184         return value;
1185       }
1186       return AssignAll(paras.front(), value, true, keep, false);
1187     }
1188     // Multi arguments.
1189     AnfNodePtrList tuple_inputs;
1190     tuple_inputs.reserve(args.size() + 1);
1191     tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
1192     for (size_t i = 0; i < args.size(); ++i) {
1193       auto &value = args.at(i);
1194       if (HasAbstractMonad(value)) {
1195         // No assign for monad arguments.
1196         tuple_inputs.emplace_back(value);
1197         continue;
1198       }
1199       // Assign general arguments.
1200       auto &target = paras.at(i);
1201       if (target == value) {
1202         continue;
1203       }
1204       (void)tuple_inputs.emplace_back(AssignAll(target, value, true, keep, false));
1205     }
1206     auto new_tuple = kernel_graph_->NewCNode(tuple_inputs);
1207     // Set abstract for the MakeTuple node.
1208     abstract::AbstractBasePtrList element_abstracts;
1209     (void)std::transform(tuple_inputs.begin() + 1, tuple_inputs.end(), std::back_inserter(element_abstracts),
1210                          [](const AnfNodePtr &input) { return input->abstract(); });
1211     new_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(element_abstracts));
1212     return new_tuple;
1213   }
1214 
1215   // Return true if the graph is involved with recursive calls.
IsRecursive(const KernelGraphPtr & kg)1216   bool IsRecursive(const KernelGraphPtr &kg) { return context_.call_info_map[kg].recursive; }
1217 
1218   // For some cnode, attributes may set to primitive instance, so we create a new prim instance for each cnode.
NewPrimitive(const PrimitivePtr & prim)1219   AnfNodePtr NewPrimitive(const PrimitivePtr &prim) { return NewValueNode(std::make_shared<Primitive>(prim->name())); }
1220 
GetLinkMonad()1221   AnfNodePtr GetLinkMonad() {
1222     if (last_monad_ != nullptr) {
1223       return last_monad_;
1224     }
1225     return GetMonad();
1226   }
1227 
1228   // Make a assign cnode.
Assign(const AnfNodePtr & target,const AnfNodePtr & source,bool link,bool keep,bool output)1229   CNodePtr Assign(const AnfNodePtr &target, const AnfNodePtr &source, bool link, bool keep, bool output) {
1230     auto monad = (link ? GetLinkMonad() : GetMonad());
1231     auto assign_prim = std::make_shared<Primitive>(prim::kPrimAssign->name());
1232     if (link) {
1233       // Mark this assign is to link real argument to formal argument.
1234       assign_prim->set_attr(LINK, prim::kValueOne);
1235     }
1236     if (keep) {
1237       // Mark this assign should not be eliminated.
1238       assign_prim->set_attr(KEEP, prim::kValueOne);
1239     }
1240     if (output) {
1241       // Mark this assign is used for output parameter.
1242       assign_prim->set_attr(OUTPUT, prim::kValueOne);
1243     }
1244     auto assign = NewValueNode(assign_prim);
1245     auto cnode = kernel_graph_->NewCNode({assign, target, source, monad});
1246     cnode->set_abstract(target->abstract());
1247     return cnode;
1248   }
1249 
1250   // AissgnAll support tuple to tuple assign.
AssignAll(const AnfNodePtr & target,const AnfNodePtr & source,bool link,bool keep,bool output)1251   AnfNodePtr AssignAll(const AnfNodePtr &target, const AnfNodePtr &source, bool link, bool keep, bool output) {
1252     if (!AnfAlgo::CheckPrimitiveType(target, prim::kPrimMakeTuple)) {
1253       // Assign single value.
1254       return Assign(target, source, link, keep, output);
1255     }
1256     // Assign tuple.
1257     std::vector<AnfNodePtr> targets = AnfAlgo::GetAllOutput(target, {prim::kPrimTupleGetItem});
1258     std::vector<AnfNodePtr> sources = AnfAlgo::GetAllOutput(source, {prim::kPrimTupleGetItem});
1259     if (targets.size() != sources.size()) {
1260       MS_LOG(EXCEPTION) << "Target size " << targets.size() << " != source size " << sources.size();
1261     }
1262     AnfNodePtrList tuple_inputs;
1263     tuple_inputs.reserve(targets.size() + 1);
1264     tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
1265     for (size_t i = 0; i < targets.size(); ++i) {
1266       (void)tuple_inputs.emplace_back(Assign(targets[i], sources[i], link, keep, output));
1267     }
1268     auto new_tuple = kernel_graph_->NewCNode(tuple_inputs);
1269     // Set abstract for the MakeTuple node.
1270     abstract::AbstractBasePtrList element_abstracts;
1271     (void)std::transform(tuple_inputs.begin() + 1, tuple_inputs.end(), std::back_inserter(element_abstracts),
1272                          [](const AnfNodePtr &input) { return input->abstract(); });
1273     new_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(element_abstracts));
1274     return new_tuple;
1275   }
1276 
1277   // Insert UpdateState after input node.
UpdateState(const AnfNodePtr & state,const AnfNodePtr & input)1278   AnfNodePtr UpdateState(const AnfNodePtr &state, const AnfNodePtr &input) {
1279     auto update_state = NewValueNode(prim::kPrimUpdateState);
1280     auto update_state_cnode = kernel_graph_->NewCNode({update_state, state, input});
1281     update_state_cnode->set_abstract(state->abstract());
1282     return update_state_cnode;
1283   }
1284 
1285   // Make entry label for current graph.
1286   // from:
1287   //   def sub_graph(x, y):
1288   //     return add(x, y)
1289   // to:
1290   //   def sub_graph(x, y, c):
1291   //     c = LabelSet(c) : entry_label
1292   //     return add(x, y)
SetupEntryLabel()1293   void SetupEntryLabel() {
1294     auto entry_label = GetGraphLabel(kernel_graph_);
1295     if (entry_label != kNoLabel) {
1296       // Set entry label.
1297       auto label_node = LabelSet(entry_label);
1298       // Make start label the first one in execution order.
1299       kernel_graph_->set_start_label(label_node);
1300     }
1301   }
1302 
1303   // Make a Depend cnode.
MakeDepend(const AnfNodePtr & origin,const AnfNodePtr & input)1304   CNodePtr MakeDepend(const AnfNodePtr &origin, const AnfNodePtr &input) {
1305     auto depend = NewValueNode(prim::kPrimDepend);
1306     auto depend_cnode = kernel_graph_->NewCNode({depend, origin, input});
1307     depend_cnode->set_abstract(origin->abstract());
1308     return depend_cnode;
1309   }
1310 
1311   // Let output depend on monad.
MakeMonadDepend()1312   void MakeMonadDepend() {
1313     auto monad = GetMonad();
1314     auto origin_output = kernel_graph_->output();
1315     MS_EXCEPTION_IF_NULL(origin_output);
1316     if (origin_output != monad) {
1317       auto depend_cnode = MakeDepend(origin_output, monad);
1318       kernel_graph_->set_output(depend_cnode);
1319     }
1320   }
1321 
1322   // Gets the last monad node, we use a separated UMonad for control flow.
GetMonad()1323   AnfNodePtr &GetMonad() {
1324     if (monad_ == nullptr) {
1325       monad_ = GetMonadValue();
1326     }
1327     return monad_;
1328   }
1329 
1330   // Gets the monad const value node.
GetMonadValue()1331   AnfNodePtr &GetMonadValue() {
1332     if (monad_value_ == nullptr) {
1333       // We should create monad value node by kernel graph,
1334       // so that kernel_info is properly set for it.
1335       monad_value_ = kernel_graph_->NewValueNode(kUMonad->ToAbstract(), kUMonad);
1336     }
1337     return monad_value_;
1338   }
1339 
1340   // Make a LabelGoto node.
LabelGoto(uint32_t label_id)1341   CNodePtr LabelGoto(uint32_t label_id) {
1342     auto monad = GetMonad();
1343     auto label_goto = NewPrimitive(prim::kPrimLabelGoto);
1344     auto cnode = kernel_graph_->NewCNode({label_goto, monad});
1345     AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(label_id), cnode);
1346     cnode->set_abstract(monad->abstract());
1347     monad_ = cnode;
1348     return cnode;
1349   }
1350 
1351   // Make a LabelSet node.
LabelSet(uint32_t label_id)1352   CNodePtr LabelSet(uint32_t label_id) {
1353     auto monad = GetMonad();
1354     auto label_set = NewPrimitive(prim::kPrimLabelSet);
1355     auto cnode = kernel_graph_->NewCNode({label_set, monad});
1356     AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(label_id), cnode);
1357     cnode->set_abstract(monad->abstract());
1358     monad_ = cnode;
1359     return cnode;
1360   }
1361 
1362   // Make a LabelSwitch node.
LabelSwitch(const AnfNodePtr & cond,const std::vector<uint32_t> & labels)1363   CNodePtr LabelSwitch(const AnfNodePtr &cond, const std::vector<uint32_t> &labels) {
1364     auto monad = GetMonad();
1365     auto label_switch = NewPrimitive(prim::kPrimLabelSwitch);
1366     auto cnode = kernel_graph_->NewCNode({label_switch, cond, monad});
1367     auto label_list = MakeValue(labels);
1368     AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, label_list, cnode);
1369     cnode->set_abstract(monad->abstract());
1370     monad_ = cnode;
1371     return cnode;
1372   }
1373 
1374   // Set child graph attribute for label_goto/label_switch node.
SetChildGrapAttr(const AnfNodePtr & node,const std::vector<KernelGraphPtr> & graphs)1375   void SetChildGrapAttr(const AnfNodePtr &node, const std::vector<KernelGraphPtr> &graphs) {
1376     AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue(graphs), node);
1377   }
1378 
1379   // Make a StackInit node.
StackInit(const KernelGraphPtr & kg)1380   CNodePtr StackInit(const KernelGraphPtr &kg) {
1381     auto monad = AnfAlgo::MakeMonadValueNode(kg);
1382     auto stack_init = NewPrimitive(prim::kPrimStackInit);
1383     auto cnode = kg->NewCNode({stack_init, monad});
1384     AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue<int64_t>(0), cnode);
1385     cnode->set_abstract(monad->abstract());
1386     return cnode;
1387   }
1388 
1389   // Make a StackDestroy node.
StackDestroy(const KernelGraphPtr & kg)1390   CNodePtr StackDestroy(const KernelGraphPtr &kg) {
1391     auto monad = AnfAlgo::MakeMonadValueNode(kg);
1392     auto stack_destroy = NewPrimitive(prim::kPrimStackDestroy);
1393     auto cnode = kg->NewCNode({stack_destroy, monad});
1394     AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue<int64_t>(0), cnode);
1395     cnode->set_abstract(monad->abstract());
1396     return cnode;
1397   }
1398 
1399   // Make a StackPush node.
StackPush(const AnfNodePtr & input)1400   CNodePtr StackPush(const AnfNodePtr &input) {
1401     auto monad = AnfAlgo::MakeMonadValueNode(kernel_graph_);
1402     auto stack_push = NewPrimitive(prim::kPrimStackPush);
1403     auto cnode = kernel_graph_->NewCNode({stack_push, input, monad});
1404     AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue<int64_t>(0), cnode);
1405     auto op_name = std::to_string(kernel_graph_->graph_id()) + "_stack_push_" + std::to_string(name_index_++);
1406     AnfAlgo::SetNodeAttr(kAttrStackOpName, MakeValue(op_name), cnode);
1407     cnode->set_abstract(monad->abstract());
1408     return cnode;
1409   }
1410 
1411   // Make a StackPop node.
StackPop()1412   CNodePtr StackPop() {
1413     auto monad = AnfAlgo::MakeMonadValueNode(kernel_graph_);
1414     auto stack_pop = NewPrimitive(prim::kPrimStackPop);
1415     auto cnode = kernel_graph_->NewCNode({stack_pop, monad});
1416     AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue<int64_t>(0), cnode);
1417     auto op_name = std::to_string(kernel_graph_->graph_id()) + "_stack_pop_" + std::to_string(name_index_++);
1418     AnfAlgo::SetNodeAttr(kAttrStackOpName, MakeValue(op_name), cnode);
1419     cnode->set_abstract(monad->abstract());  // need to refresh output's abstract().
1420     return cnode;
1421   }
1422 
1423   const KernelGraphPtr &kernel_graph_;
1424   AscendAutoMonadContext &context_;
1425 
1426   // Call info for current kernel graph.
1427   CallInfo &call_info_;
1428 
1429   // The last monad for Call/Switch node.
1430   AnfNodePtr last_monad_;
1431 
1432   // The current control flow monad.
1433   AnfNodePtr monad_;
1434 
1435   // The control flow monad const value node.
1436   AnfNodePtr monad_value_;
1437 
1438   // Index value node cache for reuse.
1439   std::map<uint32_t, ValueNodePtr> index_nodes_;
1440 
1441   // The index of stackops name.
1442   uint32_t name_index_;
1443 
1444   // The flag which indicates to insert stackops.
1445   bool need_stackops_;
1446 };
1447 
// Input positions of an Assign cnode (input 0 is the primitive):
// the assigned target comes first, the source value second.
constexpr size_t kAssignTargetIndex{1};
constexpr size_t kAssignSourceIndex{2};
1450 
1451 class ExecuteOrderGenerator {
1452  public:
1453   class Context : public BaseContext {};
ExecuteOrderGenerator(Context & context,const KernelGraphPtr & graph)1454   ExecuteOrderGenerator(Context &context, const KernelGraphPtr &graph) : context_(context), graph_(graph) {}
1455   ~ExecuteOrderGenerator() = default;
1456 
Run()1457   void Run() {
1458     GenerateExecuteOrder();
1459     EraseParameter();
1460     EraseLabel();
1461     UnfoldRepeatedLabels();
1462   }
1463 
1464  private:
GenerateGraphOrder(const KernelGraphPtr & graph)1465   void GenerateGraphOrder(const KernelGraphPtr &graph) {
1466     ExecuteOrderGenerator generator(context_, graph);
1467     generator.GenerateExecuteOrder();
1468   }
1469 
FindMaxLabelId(const std::vector<CNodePtr> & nodes)1470   uint32_t FindMaxLabelId(const std::vector<CNodePtr> &nodes) {
1471     uint32_t max_label = 0;
1472     for (auto &node : nodes) {
1473       if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSet)) {
1474         auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex);
1475         max_label = std::max(label_id, max_label);
1476       }
1477     }
1478     return max_label;
1479   }
1480 
HandleLabelSwitch(const AnfNodePtr & node,std::vector<uint32_t> * labels,std::vector<uint32_t> * switch_labels,std::multimap<uint32_t,uint32_t> * labels_multimap)1481   void HandleLabelSwitch(const AnfNodePtr &node, std::vector<uint32_t> *labels, std::vector<uint32_t> *switch_labels,
1482                          std::multimap<uint32_t, uint32_t> *labels_multimap) {
1483     bool is_new_labels = false;
1484     auto label_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(node, kAttrLabelSwitchList);
1485     std::vector<uint32_t> new_labels;
1486     new_labels.reserve(label_list.size());
1487     for (auto label_id : label_list) {
1488       auto iter = std::find_if(labels->begin(), labels->end(), [label_id](auto id) { return id == label_id; });
1489       // Use new label if find repeated label.
1490       if (iter == labels->end()) {
1491         (void)new_labels.emplace_back(label_id);
1492         (void)labels->emplace_back(label_id);
1493         continue;
1494       }
1495       (void)new_labels.emplace_back(++max_label_);
1496       (void)labels_multimap->emplace(*iter, max_label_);
1497       (void)labels->emplace_back(label_id);
1498       is_new_labels = true;
1499     }
1500     (void)switch_labels->insert(switch_labels->end(), new_labels.begin(), new_labels.end());
1501     if (is_new_labels) {
1502       AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, MakeValue(new_labels), node);
1503     }
1504   }
1505 
HandleLabelGoto(const AnfNodePtr & node,std::vector<uint32_t> * labels,std::vector<uint32_t> * switch_labels,std::multimap<uint32_t,uint32_t> * labels_multimap)1506   void HandleLabelGoto(const AnfNodePtr &node, std::vector<uint32_t> *labels, std::vector<uint32_t> *switch_labels,
1507                        std::multimap<uint32_t, uint32_t> *labels_multimap) {
1508     auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex);
1509     auto iter = std::find(switch_labels->begin(), switch_labels->end(), label_id);
1510     if (iter == switch_labels->end()) {
1511       (void)labels->emplace_back(label_id);
1512       return;
1513     }
1514     AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(++max_label_), node);
1515     (void)labels_multimap->emplace(*iter, max_label_);
1516     (void)labels->emplace_back(max_label_);
1517   }
1518 
1519   // Unfold Repeated Labels, avoid same label in labelswitches.
UnfoldRepeatedLabels()1520   void UnfoldRepeatedLabels() {
1521     auto nodes = graph_->execution_order();
1522     std::vector<uint32_t> labels;
1523     std::vector<uint32_t> switch_labels;
1524     std::multimap<uint32_t, uint32_t> labels_multimap;
1525     max_label_ = FindMaxLabelId(nodes);
1526     for (auto &node : nodes) {
1527       if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) {
1528         HandleLabelSwitch(node, &labels, &switch_labels, &labels_multimap);
1529         continue;
1530       }
1531       if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) {
1532         HandleLabelGoto(node, &labels, &switch_labels, &labels_multimap);
1533         continue;
1534       }
1535     }
1536     InsertLabelSet(&nodes, labels_multimap);
1537     graph_->set_label_num(max_label_ + 1);
1538     graph_->set_execution_order(nodes);
1539   }
1540 
InsertLabelSet(std::vector<CNodePtr> * nodes,const std::multimap<uint32_t,uint32_t> & labels_multimap)1541   void InsertLabelSet(std::vector<CNodePtr> *nodes, const std::multimap<uint32_t, uint32_t> &labels_multimap) {
1542     for (auto labels : labels_multimap) {
1543       auto old_label = labels.first;
1544       auto new_label = labels.second;
1545       auto iter = std::find_if(nodes->begin(), nodes->end(), [old_label](auto node) {
1546         if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSet)) {
1547           return false;
1548         }
1549         auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex);
1550         return label_id == old_label;
1551       });
1552       if (iter == nodes->end()) {
1553         MS_LOG(EXCEPTION) << "Not found labelset:" << old_label;
1554       }
1555       auto label_set = NewValueNode(std::make_shared<Primitive>(prim::kPrimLabelSet->name()));
1556       auto cnode = graph_->NewCNode({label_set});
1557       AnfAlgo::CopyNodeAttrs(*iter, cnode);
1558       AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(new_label), cnode);
1559       auto monad = graph_->NewValueNode(kUMonad->ToAbstract(), kUMonad);
1560       cnode->set_abstract(monad->abstract());
1561       (void)device::ascend::SelectKernelInfo(cnode);
1562       (void)nodes->insert(iter, cnode);
1563     }
1564   }
1565 
AppendGraphOrder(std::vector<CNodePtr> * execution_order,const KernelGraphPtr & graph)1566   void AppendGraphOrder(std::vector<CNodePtr> *execution_order, const KernelGraphPtr &graph) {
1567     auto &order = graph->execution_order();
1568     execution_order->insert(execution_order->end(), order.begin(), order.end());
1569   }
1570 
HasSubGraphs(const CNodePtr & cnode)1571   bool HasSubGraphs(const CNodePtr &cnode) { return (cnode && AnfAlgo::HasNodeAttr(kAttrChildGraph, cnode)); }
1572 
GetSubGraphs(const CNodePtr & cnode)1573   std::vector<KernelGraphPtr> GetSubGraphs(const CNodePtr &cnode) {
1574     return AnfAlgo::GetNodeAttr<std::vector<KernelGraphPtr>>(cnode, kAttrChildGraph);
1575   }
1576 
EraseNodeFromExecOrder(const AnfNodePtr & node,const NotNull<std::vector<CNodePtr> * > exec_order)1577   void EraseNodeFromExecOrder(const AnfNodePtr &node, const NotNull<std::vector<CNodePtr> *> exec_order) {
1578     MS_EXCEPTION_IF_NULL(node);
1579     auto exec_iter = std::find(exec_order->begin(), exec_order->end(), node);
1580     if (exec_iter == exec_order->end()) {
1581       MS_LOG(EXCEPTION) << "Cannot find " << node->DebugString() << " in exec order.";
1582     }
1583     exec_order->erase(exec_iter);
1584   }
1585 
GenerateExecuteOrder()1586   void GenerateExecuteOrder() {
1587     // Mark graph is visited.
1588     context_.MarkVisited(graph_);
1589 
1590     // Generate topo-sorted kernel cnodes list for this graph.
1591     graph_->SetExecOrderByDefault();
1592 
1593     std::vector<CNodePtr> execution_order;
1594     const auto &cnodes = graph_->execution_order();
1595     for (auto &cnode : cnodes) {
1596       // Push current node to execution order list.
1597       execution_order.push_back(cnode);
1598       // For cnode with sub-graphs, such as LabelSwitch, LabelGoto,
1599       // Generate execute order for these sub-graphs,
1600       // and then append them to current execution order list.
1601       if (HasSubGraphs(cnode)) {
1602         auto sub_graphs = GetSubGraphs(cnode);
1603         if (!AnfAlgo::HasNodeAttr(kAttrSwitchLayer, cnode)) {
1604           // For Switch, we use reversed order to generate sub-graph's execution order,
1605           // because the true branch of LabelSwitch is the second one, but
1606           // we want to make true branch ahead of false branch in the generated
1607           // execution order.
1608           std::reverse(sub_graphs.begin(), sub_graphs.end());
1609         }
1610         for (auto &sub_graph : sub_graphs) {
1611           if (context_.IsVisited(sub_graph)) {
1612             // Skip visited sub-graphs.
1613             continue;
1614           }
1615           GenerateGraphOrder(sub_graph);
1616           AppendGraphOrder(&execution_order, sub_graph);
1617         }
1618         // Clear ChildGraph attribute after execute order generated.
1619         AnfAlgo::EraseNodeAttr(kAttrChildGraph, cnode);
1620       }
1621     }
1622     // Save generated execution order into the graph.
1623     graph_->set_execution_order(std::move(execution_order));
1624   }
1625 
GetAllNodes(std::map<CNodePtr,const size_t> * search_list)1626   std::set<CNodePtr> GetAllNodes(std::map<CNodePtr, const size_t> *search_list) {
1627     const auto &all_graphs = context_.visited_graphs();
1628     std::set<CNodePtr> all_nodes;
1629     for (auto &graph : all_graphs) {
1630       auto out = graph->get_return();
1631       MS_EXCEPTION_IF_NULL(out);
1632       (void)search_list->emplace(out->cast<CNodePtr>(), 0);
1633       auto nodes = TopoSort(out);
1634       for (auto &node : nodes) {
1635         MS_EXCEPTION_IF_NULL(node);
1636         auto cnode = node->cast<CNodePtr>();
1637         if (cnode != nullptr) {
1638           (void)all_nodes.insert(cnode);
1639         }
1640       }
1641     }
1642     return all_nodes;
1643   }
1644 
GetRealNode(const AnfNodePtr & input)1645   static const AnfNodePtr &GetRealNode(const AnfNodePtr &input) {
1646     if (IsPrimitiveCNode(input, prim::kPrimLoad) || IsPrimitiveCNode(input, prim::kPrimDepend)) {
1647       return input->cast<CNodePtr>()->inputs().at(1);
1648     }
1649     return input;
1650   }
1651 
RemoveSameInputsAssigns(std::vector<CNodePtr> * exec_order) const1652   void RemoveSameInputsAssigns(std::vector<CNodePtr> *exec_order) const {
1653     for (auto iter = exec_order->begin(); iter != exec_order->end();) {
1654       auto &node = *iter;
1655       auto &inputs = node->inputs();
1656       if (IsPrimitiveCNode(node, prim::kPrimAssign) &&
1657           (inputs.at(kAssignTargetIndex) == GetRealNode(inputs.at(kAssignSourceIndex)))) {
1658         iter = exec_order->erase(iter);
1659       } else {
1660         ++iter;
1661       }
1662     }
1663   }
1664 
1665   // Erase redundant parameters and assign nodes.
EraseParameter()1666   void EraseParameter() {
1667     // Copy out execution order list.
1668     auto exec_order = graph_->execution_order();
1669     std::map<CNodePtr, const size_t> search_list;
1670     for (size_t i = 0; i < exec_order.size(); i++) {
1671       search_list.emplace(exec_order[i], i);
1672     }
1673 
1674     // Remove assigns that target and source are same.
1675     RemoveSameInputsAssigns(&exec_order);
1676 
1677     // Get all nodes and all graphs
1678     std::set<CNodePtr> all_nodes = GetAllNodes(&search_list);
1679     auto &all_graphs = context_.visited_graphs();
1680 
1681     // Count parameter write times by check all assign nodes.
1682     auto param_write_times = CountParameterAssigns(search_list, exec_order);
1683 
1684     // Erase redundant assigns.
1685     for (auto iter = exec_order.begin(); iter != exec_order.end();) {
1686       auto &node = *iter;
1687       // We only try to erase argument link assign nodes,
1688       // other assign nodes are skipped.
1689       if (IsOptimizableAssign(node)) {
1690         auto &target = node->inputs().at(kAssignTargetIndex);
1691         MS_EXCEPTION_IF_NULL(target);
1692         auto para = param_write_times.find(target);
1693         if (para != param_write_times.end() && para->second.first == 1) {
1694           // Check source of the Assign.
1695           auto &source = node->inputs().at(kAssignSourceIndex);
1696           MS_EXCEPTION_IF_NULL(source);
1697           if (source->isa<Parameter>()) {
1698             auto it = param_write_times.find(source);
1699             const auto index = search_list[node];
1700             if (it != param_write_times.end() && it->second.first > 0 && it->second.second > index) {
1701               // Skip if Assign source is a parameter and be written in other place.
1702               ++iter;
1703               continue;
1704             }
1705           }
1706           // If target only write once, and source not be written,
1707           // replace target with source and erase the Assign node.
1708           auto kg = target->func_graph()->cast<KernelGraphPtr>();
1709           MS_EXCEPTION_IF_NULL(kg);
1710           kg->ReplaceNode(target, source);
1711 
1712           // replace parameter in graph input
1713           for (auto &g : all_graphs) {
1714             auto child_graph_inputs = g->MutableInputs();
1715             std::replace(child_graph_inputs->begin(), child_graph_inputs->end(), target, source);
1716             MS_LOG(DEBUG) << "Replace parameter " << target->DebugString() << " by " << source->DebugString()
1717                           << " in graph " << g->graph_id() << " inputs";
1718           }
1719 
1720           // replace parameter in node
1721           for (auto &iter_node : all_nodes) {
1722             for (size_t i = 0; i < iter_node->size(); ++i) {
1723               if (iter_node->input(i) == target) {
1724                 MS_LOG(INFO) << "Replace " << iter_node->DebugString() << " input " << i << " by "
1725                              << source->DebugString();
1726                 iter_node->set_input(i, source);
1727               }
1728             }
1729           }
1730           iter = exec_order.erase(iter);
1731           continue;
1732         }
1733       }
1734       // Go next node.
1735       ++iter;
1736     }
1737     // Set new execution order with redundant assign removed.
1738     graph_->set_execution_order(std::move(exec_order));
1739   }
1740 
1741   // Count parameter write times by check all assign nodes.
CountParameterAssigns(const std::map<CNodePtr,const size_t> & search_list,const std::vector<CNodePtr> & exec_order)1742   std::map<AnfNodePtr, std::pair<int, size_t>> CountParameterAssigns(
1743     const std::map<CNodePtr, const size_t> &search_list, const std::vector<CNodePtr> &exec_order) {
1744     auto ref_map = graph_->GetRefMap();
1745     std::multimap<AnfNodePtr, std::tuple<size_t, AnfNodePtr, size_t>> ref_multimap;
1746     std::set<AnfNodePtr> root_inputs(graph_->inputs().begin(), graph_->inputs().end());
1747     (void)std::transform(ref_map.begin(), ref_map.end(), std::inserter(ref_multimap, ref_multimap.end()),
1748                          [](const std::pair<std::pair<AnfNodePtr, size_t>, std::pair<AnfNodePtr, size_t>> &p)
1749                            -> std::pair<AnfNodePtr, std::tuple<size_t, AnfNodePtr, size_t>> {
1750                            return {p.first.first, {p.first.second, p.second.first, p.second.second}};
1751                          });
1752     auto validate_ref_parameter = [](AnfNodePtr node) -> AnfNodePtr {
1753       if (node->isa<CNode>() && AnfAlgo::CheckPrimitiveType(node, prim::kPrimTransData)) {
1754         auto cnode = node->cast<CNodePtr>();
1755         MS_EXCEPTION_IF_NULL(cnode);
1756         auto first_input = cnode->input(kFirstDataInputIndex);
1757         MS_EXCEPTION_IF_NULL(first_input);
1758         return first_input;
1759       }
1760       return node;
1761     };
1762 
1763     // Find all graph input parameters.
1764     std::map<AnfNodePtr, std::pair<int, size_t>> param_write_times;
1765     const auto &all_graphs = context_.visited_graphs();
1766     for (const auto &graph : all_graphs) {
1767       for (auto &input : graph->inputs()) {
1768         if (input->isa<Parameter>()) {
1769           param_write_times.emplace(input, std::make_pair(0, 0));
1770         }
1771       }
1772     }
1773 
1774     // Search all refnodes for parameter write assigns.
1775     for (auto &node : exec_order) {
1776       if (ref_multimap.find(node) == ref_multimap.end()) {
1777         // if node is not refnode which cannot write param, skip it.
1778         continue;
1779       }
1780       std::set<AnfNodePtr> refed_parameters;
1781       for (auto [iter, end] = ref_multimap.equal_range(node); iter != end; ++iter) {
1782         (void)refed_parameters.insert(validate_ref_parameter(std::get<1>(iter->second)));
1783       }
1784       for (auto &in : node->inputs()) {
1785         auto visit_node = AnfAlgo::VisitKernelWithReturnType(in, 0).first;
1786         visit_node = validate_ref_parameter(visit_node);
1787         if (!visit_node->isa<Parameter>() || root_inputs.find(visit_node) != root_inputs.end()) {
1788           continue;
1789         }
1790         if (refed_parameters.find(visit_node) != refed_parameters.end()) {
1791           auto iter = param_write_times.find(visit_node);
1792           if (iter != param_write_times.end()) {
1793             // Found a parameter writer, count it.
1794             ++(iter->second.first);
1795             if (search_list.find(node) == search_list.end()) {
1796               MS_LOG(EXCEPTION) << "node: " << node->DebugString() << " cannot found in search list.";
1797             }
1798             iter->second.second = search_list.at(node);
1799           }
1800         }
1801       }
1802     }
1803     return param_write_times;
1804   }
1805 
1806   // Check if a node is an assign for argument link and can be optimized.
IsOptimizableAssign(const AnfNodePtr & node)1807   bool IsOptimizableAssign(const AnfNodePtr &node) {
1808     auto cnode = dyn_cast<CNode>(node);
1809     if (cnode == nullptr) {
1810       return false;
1811     }
1812     auto prim = GetValueNode<PrimitivePtr>(cnode->inputs().at(0));
1813     if (!IsPrimitiveEquals(prim, prim::kPrimAssign)) {
1814       return false;
1815     }
1816     return (prim->GetAttr(LINK) == prim::kValueOne) && (prim->GetAttr(KEEP) != prim::kValueOne);
1817   }
1818 
1819   // Erase LabelGoto and LabelSet
EraseLabel()1820   void EraseLabel() {
1821     // Find used labels (as jump target).
1822     std::set<uint32_t> label_used;
1823     auto exec_order = graph_->execution_order();
1824     for (auto iter = exec_order.begin(); iter != exec_order.end();) {
1825       auto &node = *iter;
1826       if (IsPrimitiveCNode(node, prim::kPrimLabelSwitch)) {
1827         auto labels = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(node, kAttrLabelSwitchList);
1828         for (auto label : labels) {
1829           label_used.insert(label);
1830         }
1831       } else if (IsPrimitiveCNode(node, prim::kPrimLabelGoto)) {
1832         auto label = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex);
1833         auto next = std::next(iter);
1834         if (next != exec_order.end() && IsPrimitiveCNode(*next, prim::kPrimLabelSet)) {
1835           // The LabelGoto that jump to next node can be removed.
1836           auto next_label = AnfAlgo::GetNodeAttr<uint32_t>(*next, kAttrLabelIndex);
1837           if (next_label == label) {
1838             iter = exec_order.erase(iter);
1839             continue;
1840           }
1841         }
1842         label_used.insert(label);
1843       }
1844       ++iter;
1845     }
1846     // Erase unused LabelSet nodes.
1847     for (auto iter = exec_order.begin(); iter != exec_order.end();) {
1848       auto &node = *iter;
1849       if (IsPrimitiveCNode(node, prim::kPrimLabelSet)) {
1850         auto label = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex);
1851         if (label_used.find(label) == label_used.end()) {
1852           iter = exec_order.erase(iter);
1853           continue;
1854         }
1855       }
1856       ++iter;
1857     }
1858     graph_->set_execution_order(std::move(exec_order));
1859   }
1860 
1861   Context &context_;
1862   const KernelGraphPtr graph_;
1863   uint32_t max_label_ = 0;
1864 };
1865 }  // namespace
1866 
Run()1867 void AscendAutoMonad::Run() {
1868   MS_LOG(DEBUG) << "Ascend auto-monad start.";
1869   auto kg = kernel_graph_.get();
1870   AscendAutoMonadContext context(kg);
1871   CallInfoFinder::Run(&context);
1872   AscendAutoMonadConverter::Run(&context);
1873   kernel_graph_->set_label_num(context.CurrentLabel() + 1);
1874   kernel_graph_->set_recursive_call(context.HasRecursiveCall());
1875   kernel_graph_->set_subgraph_multi_call(context.HasSubgraphMultiCall());
1876   MS_LOG(DEBUG) << "Ascend auto-monad finish.";
1877 #ifdef ENABLE_DUMP_IR
1878   DumpGraphForDebug(kernel_graph_);
1879 #endif
1880 }
1881 
GenerateExecuteOrder()1882 void AscendAutoMonad::GenerateExecuteOrder() {
1883   MS_LOG(DEBUG) << "Ascend generate execute order start.";
1884   ExecuteOrderGenerator::Context context;
1885   ExecuteOrderGenerator generator(context, kernel_graph_.get());
1886   generator.Run();
1887   MS_LOG(DEBUG) << "Ascend generate execute order finish.";
1888 #ifndef ENABLE_SECURITY
1889   DumpExecuteOrder(kernel_graph_);
1890 #endif
1891 }
1892 }  // namespace session
1893 }  // namespace mindspore
1894