/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/session/ascend_auto_monad.h"
#include <set>
#include <map>
#include <stack>
#include <vector>
#include <string>
#include <tuple>
#include <queue>
#include <utility>
#include <memory>
#include <algorithm>
#include <fstream>  // std::ofstream used in DumpExecuteOrder.
#include "utils/ms_context.h"
#include "utils/ordered_map.h"
#include "base/core_ops.h"
#include "debug/anf_ir_dump.h"
#include "pipeline/jit/base.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "runtime/device/ascend/kernel_select_ascend.h"

namespace mindspore {
namespace session {
namespace {
// Pair of graph and its actual arguments.
using GraphArgPair = std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>;

// We start label id from 0, and use 0xFFFFFFFF to indicate label not set.
constexpr uint32_t kNoLabel = 0xFFFFFFFF;

// For Assign ops we start from input index 2, since inputs[1] is the target (output)
// and inputs[2] is the source (input).
constexpr size_t kInputIndex = 2;

// Primitive attribute for argument link assign.
const char LINK[] = "link";

// Attribute to indicate that the node should not be eliminated.
// Used to keep argument Assign nodes for recursive graphs.
const char KEEP[] = "keep";

// Attribute to indicate that this is an assign for output.
const char OUTPUT[] = "output";

// Attribute to indicate that the node is the last node in an iteration.
const char ITEREND[] = "PROFILING_ITER_END";

#ifdef ENABLE_DUMP_IR
bool IsSaveGraph() {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  return context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
}

void DumpAllGraphs(NotNull<KernelGraphPtr> kg, std::set<KernelGraphPtr> *memo) {
  if (memo->find(kg) != memo->end()) {
    return;
  }
  memo->insert(kg);
  std::string file_name = "ascend_auto_monad_" + std::to_string(kg->graph_id()) + ".ir";
  DumpIR(file_name, kg.get());
  for (auto &child : kg->child_graph_order()) {
    auto cg = child.lock();
    if (cg) {
      DumpAllGraphs(NOT_NULL(cg), memo);
    }
  }
}

void DumpGraphForDebug(const NotNull<KernelGraphPtr> kg) {
  if (IsSaveGraph()) {
    std::set<KernelGraphPtr> memo;
    DumpAllGraphs(kg, &memo);
  }
}
#endif

#ifndef ENABLE_SECURITY
void DumpExecuteOrder(const NotNull<KernelGraphPtr> kg) {
  if (!IsSaveGraph()) {
    return;
  }
  std::string filename = "ascend_execute_order_" + std::to_string(kg->graph_id()) + ".dat";
  auto filepath = GetSaveGraphsPathName(filename);
  if (filepath.size() >= PATH_MAX) {
    MS_LOG(ERROR) << "File path: " << filepath << " is too long.";
    return;
  }
  char real_path[PATH_MAX] = {0};
#if defined(_WIN32) || defined(_WIN64)
  if (_fullpath(real_path, filename.c_str(), PATH_MAX) == nullptr) {
    MS_LOG(DEBUG) << "Dir " << filename << " does not exist.";
  }
#else
  if (realpath(filepath.c_str(), real_path) == nullptr) {
    MS_LOG(DEBUG) << "Dir " << filepath << " does not exist.";
  }
#endif

  std::ofstream fout(real_path);
  if (!fout.is_open()) {
    MS_LOG(ERROR) << "Open file '" << real_path << "' failed!";
    return;
  }

  fout << "Execute order:\n";
  int index = 0;
  for (auto &cnode : kg->execution_order()) {
    MS_EXCEPTION_IF_NULL(cnode);
    if (IsPrimitiveCNode(cnode, prim::kPrimLabelSet)) {
      fout << "L" << AnfAlgo::GetNodeAttr<uint32_t>(cnode, kAttrLabelIndex) << ":\n";
    }
    fout << " [" << index << "], " << cnode->DebugString();
    if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, cnode)) {
      fout << " : L" << AnfAlgo::GetNodeAttr<uint32_t>(cnode, kAttrLabelIndex);
    }
    if (AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, cnode)) {
      auto labels = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(cnode, kAttrLabelSwitchList);
      fout << " : ";
      for (size_t i = 0; i < labels.size(); ++i) {
        fout << ((i > 0) ? ", L" : "L") << labels[i];
      }
    }
    fout << '\n';
    index++;
  }
  fout.close();
}
#endif

// Return kNoLabel when the label id attribute is not set for the graph.
uint32_t GetGraphLabel(const KernelGraphPtr &kg) {
  auto value = kg->get_attr(kAttrLabelIndex);
  if (value == nullptr) {
    return kNoLabel;
  }
  return GetValue<uint32_t>(value);
}

// Check if one abstract is compatible with another abstract.
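// Roughly (illustrative summary of the checks below): tuples are compared element-wise;
// otherwise the types must be equal and the shapes must match, except that a shape of [1]
// is treated as compatible with a scalar shape [].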
bool IsCompatible(const abstract::AbstractBasePtr &a1, const abstract::AbstractBasePtr &a2) {
  if (a1 == nullptr || a2 == nullptr) {
    return false;
  }
  if (a1 == a2) {
    return true;
  }
  // Check AbstractTuple.
  if (a1->isa<abstract::AbstractTuple>() && a2->isa<abstract::AbstractTuple>()) {
    auto &a1_tuple = static_cast<abstract::AbstractTuple &>(*a1);
    auto &a2_tuple = static_cast<abstract::AbstractTuple &>(*a2);
    auto &a1_elements = a1_tuple.elements();
    auto &a2_elements = a2_tuple.elements();
    if (a1_elements.size() != a2_elements.size()) {
      return false;
    }
    for (size_t i = 0; i < a1_elements.size(); i++) {
      MS_EXCEPTION_IF_NULL(a1_elements[i]);
      MS_EXCEPTION_IF_NULL(a2_elements[i]);
      if (!IsCompatible(a1_elements[i], a2_elements[i])) {
        return false;
      }
    }
    return true;
  }
  // Check AbstractTensor and AbstractRef.
  auto type1 = a1->BuildType();
  auto type2 = a2->BuildType();
  if (type1 != type2 && *type1 != *type2) {
    return false;
  }
  auto shape1 = a1->BuildShape();
  auto shape2 = a2->BuildShape();
  if (shape1 == shape2) {
    return true;
  }
  if (shape1->isa<abstract::Shape>() && shape2->isa<abstract::Shape>()) {
    const auto &shape1_vec = shape1->cast<abstract::ShapePtr>()->shape();
    const auto &shape2_vec = shape2->cast<abstract::ShapePtr>()->shape();
    if ((shape1_vec == ShapeVector({1}) && shape2_vec == ShapeVector()) ||
        (shape1_vec == ShapeVector() && shape2_vec == ShapeVector({1}))) {
      return true;
    }
  }
  return *shape1 == *shape2;
}

struct CallBranch {
  KernelGraphPtr graph;
  std::vector<AnfNodePtr> args;
};

struct CallSite {
  // The Call/Switch/SwitchLayer cnode.
  CNodePtr cnode;

  // The CNode after conversion to LabelGoto/LabelSwitch/LabelSet.
  CNodePtr conversion_cnode;

  // The last monad before the call.
  AnfNodePtr last_monad = nullptr;

  // Branch graphs called.
  std::vector<CallBranch> callees;

  // Parameter for the return value.
  AnfNodePtr out_param = nullptr;

  // Label id for return.
  uint32_t return_label = kNoLabel;

  // Map from label parameter to label index.
  std::map<AnfNodePtr, uint32_t> label_indexes;

  // True if this is a recursive call.
  bool recursive = false;

  // True if this is a tail call.
  bool tail = false;

  // True if tail-call optimization is disabled for this call.
  bool disable_tail = false;
};

struct ReturnPoint {
  CallSite *call_site = nullptr;
};

struct CallInfo {
  // Call sites in the current graph.
  std::vector<CallSite> call_sites;

  // Return points of the current graph.
  std::vector<ReturnPoint> return_points;

  // Parameter to store the label index; set when there
  // are multiple return points.
  AnfNodePtr label_param = nullptr;

  // Return monad.
  AnfNodePtr return_monad_ = nullptr;

  // True if the current graph is involved with recursive calls.
  bool recursive = false;
};

//
// ParameterPool caches parameters by their abstract, so that we can reuse a
// parameter with the same abstract to store return values.
//
class ParameterPool {
 public:
  explicit ParameterPool(const KernelGraphPtr &top_graph) : top_graph_(top_graph) {}
  ~ParameterPool() = default;

  // Create or get a parameter from the pool with the given abstract.
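  // For example (illustrative): two call sites whose results are both float32 tensors of
  // shape [2, 3] share a single pooled output parameter instead of creating two.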
  AnfNodePtr GetParameter(const abstract::AbstractBasePtr &abs) {
    // Find a parameter in the pool by the given abstract.
    auto iter = std::find_if(paras_.begin(), paras_.end(), [&abs](auto &para) {
      auto para_abs = para->abstract();
      // Reuse output parameter with compatible abstract.
      return IsCompatible(abs, para_abs);
    });
    // Return the parameter if found.
    if (iter != paras_.end()) {
      return *iter;
    }
    // If no parameter with the given abstract is found, create a new one.
    auto para = top_graph_->NewParameter(abs);
    auto out_para = top_graph_->TransTupleToMakeTuple(para);
    // This is required, so that device memory can be allocated for it.
    top_graph_->AddChildGraphResult(out_para);
    // Save the new parameter to the pool.
    paras_.push_back(out_para);
    return out_para;
  }

 private:
  // The top graph.
  const KernelGraphPtr &top_graph_;

  // Cached parameters.
  std::vector<AnfNodePtr> paras_;
};

//
// Base class for context.
//
class BaseContext {
 public:
  void MarkVisited(const KernelGraphPtr &kg) { visited_graphs_.insert(kg); }

  bool IsVisited(const KernelGraphPtr &kg) const { return visited_graphs_.find(kg) != visited_graphs_.end(); }

  const std::set<KernelGraphPtr> &visited_graphs() const { return visited_graphs_; }

  void ClearVisited() { visited_graphs_.clear(); }

  virtual ~BaseContext() {}

 private:
  std::set<KernelGraphPtr> visited_graphs_;
};

//
// AscendAutoMonadContext holds some shared states during auto-monad.
//
class AscendAutoMonadContext : public BaseContext {
 public:
  explicit AscendAutoMonadContext(const KernelGraphPtr &kg) : top_graph_(kg), param_pool_(kg) {}
  ~AscendAutoMonadContext() = default;

  // Label ids start from 0 and increase by 1 for each new id.
  uint32_t NewLabel() { return label_id_++; }

  // Current label id, which is also the number of label ids we have used.
  uint32_t CurrentLabel() const { return label_id_; }

  // Create a new parameter.
  // Output parameters are all created on the top graph.
  AnfNodePtr CreateParameter(const AbstractBasePtr &abs) {
    auto para = top_graph_->NewParameter(abs);
    auto out_para = top_graph_->TransTupleToMakeTuple(para);
    // This is required, so that device memory can be allocated for it.
    top_graph_->AddChildGraphResult(out_para);
    return out_para;
  }

  // Get or create a temporary parameter for the given abstract.
  AnfNodePtr GetTempParameter(const AbstractBasePtr &abs) { return param_pool_.GetParameter(abs); }

  const KernelGraphPtr &TopGraph() const { return top_graph_; }

  // Whether a stack has already been created.
  const bool HasInitedStack() const { return inited_stack_; }

  // Set the flag that indicates whether a stack has already been created.
  void SetInitedStack(bool flag) { inited_stack_ = flag; }

  // Whether the graphs have recursion.
  bool HasRecursiveCall() const { return has_recursive_call_; }
  // Whether the graphs have subgraph multi-call.
  bool HasSubgraphMultiCall() const { return has_subgraph_multicall_; }
  // Set the flag that indicates whether there is recursion.
  void SetRecursiveCall(bool flag) { has_recursive_call_ = flag; }
  // Set the flag that indicates whether there is multi-call.
  void SetSubGraphMultiCall(bool flag) { has_subgraph_multicall_ = flag; }

  // Map kernel_graph to its call info.
  OrderedMap<KernelGraphPtr, CallInfo> call_info_map;

 private:
  // The top graph.
  const KernelGraphPtr &top_graph_;

  // The parameter pool that caches parameters for return values.
  ParameterPool param_pool_;

  // Current label id.
  uint32_t label_id_ = 0;

  // Whether a stack has been created for multi-call and non-tail recursion.
  bool inited_stack_ = false;
  // Whether the graphs have recursion.
  bool has_recursive_call_ = false;
  // Whether the graphs have subgraph multi-call.
  bool has_subgraph_multicall_ = false;
};

//
// Call info finder finds graph call information.
//
class CallInfoFinder {
 public:
  static void Run(AscendAutoMonadContext *context) {
    CallInfoFinder finder(context->TopGraph(), context);
    finder.Run();
  }

 private:
  CallInfoFinder(const KernelGraphPtr &kg, AscendAutoMonadContext *context) : kernel_graph_(kg), context_(*context) {}
  ~CallInfoFinder() = default;

  void Run() {
    FindCallSites();
    FindRecursiveCalls();
    DisableTailCalls();
    FindCallReturns();
  }

  // Find all call sites.
  void FindCallSites() {
    auto call_info = CreateCallInfo();
    if (call_info == nullptr) {
      // Skip if call_info for this graph already exists.
      return;
    }
    // Update directly called sub-graphs.
    kernel_graph_->UpdateChildGraphOrder();
    // Find Call/Switch/SwitchLayer nodes, and make CallSites for them.
    AnfNodePtr last_monad = nullptr;
    auto nodes = TopoSort(kernel_graph_->output());
    for (auto &node : nodes) {
      MS_EXCEPTION_IF_NULL(node);
      if (HasAbstractUMonad(node)) {
        // Found a node with UMonad abstract, set it as the last monad.
        last_monad = node;
        call_info->return_monad_ = last_monad;
      } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) {
        MakeCallSite(node->cast<CNodePtr>(), last_monad, call_info);
        call_info->return_monad_ = nullptr;
      } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch) ||
                 AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitchLayer)) {
        MakeSwitchCallSite(node->cast<CNodePtr>(), last_monad, call_info);
        call_info->return_monad_ = nullptr;
      }
    }
    // Set the last call as a tail call if it is the output node.
    // We don't set tail call for the top graph because return is always required.
    if (kernel_graph_ != context_.TopGraph() && !call_info->call_sites.empty()) {
      auto real_output = GetRealNode(kernel_graph_->output());
      if (real_output == call_info->call_sites.back().cnode) {
        call_info->call_sites.back().tail = true;
      }
    }
    // Recursively find CallSites from sub-graphs.
    for (auto &call_site : call_info->call_sites) {
      for (auto &callee : call_site.callees) {
        CallInfoFinder finder(callee.graph, &context_);
        finder.FindCallSites();
      }
    }
  }

  // Find recursive non-tail calls.
  void FindRecursiveCalls() {
    for (auto &[caller, call_info] : context_.call_info_map) {
      for (auto &call_site : call_info.call_sites) {
        if (!call_site.tail) {
          SearchRecursiveCall(caller, &call_site);
        }
      }
    }
  }

  // Disable tail call optimization for recursive call graphs.
  void DisableTailCalls() {
    for (auto &entry : context_.call_info_map) {
      auto &call_info = entry.second;
      if (call_info.recursive && !call_info.call_sites.empty()) {
        call_info.call_sites.back().tail = false;
        call_info.call_sites.back().disable_tail = true;
      }
    }
  }

  // Find call-return pairs.
  void FindCallReturns() {
    for (auto &[caller, call_info] : context_.call_info_map) {
      for (auto &call_site : call_info.call_sites) {
        for (auto &callee : call_site.callees) {
          MakeGraphLabel(callee.graph);
        }
        if (!call_site.tail) {
          SearchCallReturns(caller, &call_site);
        }
      }
    }
  }

  // Create an entry label for the given graph if not set.
  void MakeGraphLabel(const KernelGraphPtr &kg) {
    auto label = GetGraphLabel(kg);
    if (label == kNoLabel) {
      // Allocate a new label id and save it to the graph.
      label = context_.NewLabel();
      kg->set_attr(kAttrLabelIndex, MakeValue(label));
    }
  }

  // Search return points for all non-tail calls.
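  // For example (illustrative): if graph g1 makes a non-tail call to g2, and g2 ends with
  // a tail call to g3, the search follows the tail call so that it is g3 that returns to
  // the return label of g1's call site.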
  void SearchCallReturns(const KernelGraphPtr &caller, CallSite *call_site) {
    std::set<KernelGraphPtr> visited = {caller};
    std::queue<CallSite *> call_sites;
    call_sites.push(call_site);
    while (!call_sites.empty()) {
      auto site = call_sites.front();
      call_sites.pop();
      for (auto &callee : site->callees) {
        auto &kg = callee.graph;
        if (visited.find(kg) != visited.end()) {
          // Skip visited graphs.
          continue;
        }
        // Mark visited.
        visited.emplace(kg);
        // Check callee.
        auto &call_info = context_.call_info_map[kg];
        auto &sites = call_info.call_sites;
        if (!sites.empty() && sites.back().tail) {
          // Follow tail call.
          call_sites.push(&sites.back());
        } else {
          // Found a call-return relation.
          HandleCallReturn(call_site, kg);
        }
      }
    }
  }

  struct SearchRecursiveContext {
    const KernelGraphPtr &start_caller;
    CallSite *start_site;
    std::set<KernelGraphPtr> visited;
    std::vector<KernelGraphPtr> call_path;
  };

  // Search for recursive calls from a call-site.
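  // For example (illustrative): for a call cycle g1 -> g2 -> g1, the depth-first search
  // from g1's call site reaches g1 again via g2, so both g1 and g2 are marked recursive
  // and the starting call site gets recursive = true.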
  void SearchRecursiveCall(const KernelGraphPtr &start_caller, CallSite *start_site) {
    SearchRecursiveContext context{.start_caller = start_caller, .start_site = start_site};
    DoSearchRecursiveCall(start_caller, *start_site, &context);
  }

  void DoSearchRecursiveCall(const KernelGraphPtr &graph, const CallSite &call_site, SearchRecursiveContext *ctx) {
    MS_EXCEPTION_IF_NULL(ctx);
    // Record call path.
    ctx->call_path.push_back(graph);
    // Handle callee graphs.
    for (auto &callee : call_site.callees) {
      auto &sub_graph = callee.graph;
      if (sub_graph == ctx->start_caller) {
        // Found a recursive call path.
        for (auto &g : ctx->call_path) {
          // Mark recursive for all graphs in the call path.
          context_.call_info_map[g].recursive = true;
        }
        // Mark recursive for the start call-site.
        MS_EXCEPTION_IF_NULL(ctx->start_site);
        ctx->start_site->recursive = true;
        continue;
      }
      if (ctx->visited.find(sub_graph) != ctx->visited.end()) {
        // Skip visited graphs.
        continue;
      }
      // Mark visited.
      (void)ctx->visited.emplace(sub_graph);
      // Check call sites in the sub-graph.
      auto &call_info = context_.call_info_map[sub_graph];
      auto &sites = call_info.call_sites;
      for (auto &site : sites) {
        if (!site.callees.empty()) {
          DoSearchRecursiveCall(sub_graph, site, ctx);
        }
      }
    }
    // Pop the current graph from the call path before returning (backtrack).
    ctx->call_path.pop_back();
  }

  // Handle a call-return relation.
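  // For example (illustrative): if both g1 and g2 make non-tail calls to g3, g3 gets two
  // return points. A label_param is then created for g3; each caller assigns its own index
  // (0 or 1) to that parameter before the call, and g3 returns through a LabelSwitch on
  // label_param that jumps back to the matching return label.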
  void HandleCallReturn(CallSite *call_site, const KernelGraphPtr &callee) {
    // Create a label for the return point.
    if (call_site->return_label == kNoLabel) {
      call_site->return_label = context_.NewLabel();
    }
    if (!IsCompatible(call_site->cnode->abstract(), callee->output()->abstract())) {
      MS_LOG(EXCEPTION) << "call_site node: " << call_site->cnode->DebugString() << " has a different abstract() from "
                        << callee->ToString() << " output(), [ " << call_site->cnode->abstract()->ToString()
                        << " != " << callee->output()->abstract()->ToString() << " ]. "
                        << "This situation is not supported, please check if the graphs are correct.";
    }

    // Create a parameter for the return value.
    if (call_site->out_param == nullptr) {
      call_site->out_param = context_.CreateParameter(call_site->cnode->abstract());
    }
    // Add a return point for the callee graph.
    auto &call_info = context_.call_info_map[callee];
    auto &return_point = call_info.return_points.emplace_back();
    return_point.call_site = call_site;

    // Set up the label index if there are multiple return points.
    const auto n_return_points = call_info.return_points.size();
    const size_t return_point_sizes = 2;
    if (n_return_points > 1) {
      if (n_return_points == return_point_sizes) {
        // Create a parameter to store the label index.
        const ShapeVector shape = {1};
        auto abs = std::make_shared<abstract::AbstractTensor>(kInt32, shape);
        call_info.label_param = context_.CreateParameter(abs);
        // Add the label index for the first call site.
        call_info.return_points.front().call_site->label_indexes.emplace(call_info.label_param, 0);
        // Check whether the last call_site forms a loop; if so, set the recursive attribute.
        if (!call_info.call_sites.empty() && call_info.call_sites.back().disable_tail) {
          SearchRecursiveCall(callee, &call_info.call_sites.back());
        }
      }
      // Add the label index for the current call site.
      auto label_index = static_cast<uint32_t>(call_info.return_points.size() - 1);
      call_site->label_indexes.emplace(call_info.label_param, label_index);
    }
  }

  // Create a CallInfo for the current kernel graph; return null if it already exists.
  CallInfo *CreateCallInfo() {
    auto [iter, ok] = context_.call_info_map.add(kernel_graph_);
    if (!ok) {
      // CallInfo already exists.
      return nullptr;
    }
    return &(iter->second);
  }

  // Create a CallSite for a Call node.
  void MakeCallSite(const CNodePtr &cnode, const AnfNodePtr &last_monad, CallInfo *call_info) {
    auto &call_site = call_info->call_sites.emplace_back();
    call_site.cnode = cnode;
    call_site.last_monad = last_monad;
    call_site.callees.emplace_back(GetCallBranch(cnode));
  }

  // Create a CallSite for a Switch/SwitchLayer node.
  void MakeSwitchCallSite(const CNodePtr &cnode, const AnfNodePtr &last_monad, CallInfo *call_info) {
    auto &call_site = call_info->call_sites.emplace_back();
    call_site.cnode = cnode;
    call_site.last_monad = last_monad;
    call_site.callees = GetSwitchBranches(cnode);
  }

  CallBranch GetCallBranch(const CNodePtr &cnode) {
    auto input_graph = cnode->input(kCallKernelGraphIndex);
    MS_EXCEPTION_IF_NULL(input_graph);
    auto kg = GetValueNode<KernelGraphPtr>(input_graph);
    MS_EXCEPTION_IF_NULL(kg);
    constexpr int64_t call_arg_index = 2;
    auto &inputs = cnode->inputs();
    std::vector<AnfNodePtr> args{inputs.begin() + call_arg_index, inputs.end()};
    return {.graph = kg, .args = std::move(args)};
  }

  std::vector<CallBranch> GetSwitchBranches(const CNodePtr &cnode) {
    constexpr size_t cond_start_index = 2;
    std::vector<CallBranch> branches;
    for (size_t index = cond_start_index; index < cnode->inputs().size(); ++index) {
      branches.emplace_back(GetSwitchBranch(cnode, index));
    }
    return branches;
  }

  CallBranch GetSwitchBranch(const CNodePtr &cnode, size_t index) {
    auto partial_cnode = dyn_cast<CNode>(cnode->input(index));
    if (partial_cnode == nullptr) {
      return {nullptr, {}};
    }
    auto &inputs = partial_cnode->inputs();
    if (!IsPrimitive(inputs.at(0), prim::kPrimPartial)) {
      MS_LOG(EXCEPTION) << "Invalid switch node: " << cnode->DebugString();
    }
    auto graph = GetValueNode<KernelGraphPtr>(inputs.at(1));
    constexpr int64_t arg_index = 2;
    std::vector<AnfNodePtr> args{inputs.begin() + arg_index, inputs.end()};
    return {.graph = graph, .args = std::move(args)};
  }

  static AnfNodePtr GetRealNode(const AnfNodePtr &node) {
    if (!IsPrimitiveCNode(node, prim::kPrimDepend)) {
      return node;
    }
    return GetRealNode(node->cast<CNodePtr>()->input(1));
  }

  const KernelGraphPtr &kernel_graph_;
  AscendAutoMonadContext &context_;
};

//
// AscendAutoMonadConverter converts control flow to monad form
// for a kernel graph and its child graphs recursively.
//
class AscendAutoMonadConverter {
 public:
  static void Run(AscendAutoMonadContext *context) {
    for (auto &entry : context->call_info_map) {
      AscendAutoMonadConverter converter(entry.first, context, &entry.second);
      converter.Run();
    }
    const auto &top_graph = context->TopGraph();
    SetIterEndAttrForTopGraph(context, top_graph);
  }

 private:
  AscendAutoMonadConverter(const KernelGraphPtr &kg, AscendAutoMonadContext *context, CallInfo *call_info)
      : kernel_graph_(kg),
        context_(*context),
        call_info_(*call_info),
        name_index_(0),
        need_stackops_(call_info->recursive) {}
  ~AscendAutoMonadConverter() = default;

  void Run() {
    // Create a stack if needed.
    InitStack();
    // Set up the entry label if found.
    SetupEntryLabel();

    // Handle call sites.
    for (auto &call_site : call_info_.call_sites) {
      HandleCallSite(&call_site);
    }
    // Handle return points.
    HandleReturnPoints();
    // Let output depend on monad.
    if (monad_) {
      MakeMonadDepend();
    }
    // Handle recursive calls.
    kernel_graph_->SetExecOrderByDefault();
    if (call_info_.recursive) {
      const auto &nodes = kernel_graph_->execution_order();
      AnfAlgo::SetNodeAttr(kAttrRecursiveStart, prim::kValueOne, *nodes.begin());
      AnfAlgo::SetNodeAttr(kAttrRecursiveEnd, prim::kValueOne, *nodes.rbegin());
    }
    for (auto &call_site : call_info_.call_sites) {
      if (need_stackops_ && call_site.recursive) {
        MS_LOG(INFO) << "graph:" << kernel_graph_->ToString() << ", loop call_site:" << call_site.cnode->DebugString();
        InsertStackOps(call_site);
      }
    }
  }

  // Set iteration end points for profiling.
  static void SetIterEndAttrForTopGraph(AscendAutoMonadContext *context, const KernelGraphPtr &kg) {
    MS_EXCEPTION_IF_NULL(kg);
    kg->SetExecOrderByDefault();
    auto &nodes = kg->execution_order();
    auto end_iter = nodes.rend();
    std::set<KernelGraphPtr> memo;
    memo.insert(kg);
    auto call_info = context->call_info_map[kg];
    if (call_info.call_sites.empty()) {
      SetIterEndAttr(context, kg, false);
      return;
    } else {
      const auto &end_node = call_info.call_sites.back().cnode;
      end_iter = std::find(nodes.rbegin(), nodes.rend(), end_node);
    }
    for (auto iter = nodes.rbegin(); iter != end_iter; ++iter) {
      if (!AnfAlgo::IsRealCNodeKernel(*iter)) {
        continue;
      }
      if (AnfAlgo::CheckPrimitiveType(*iter, prim::kPrimLabelSet)) {
        const auto &last_call_site = context->call_info_map[kg].call_sites.back();
        for (auto &branch : last_call_site.callees) {
          if (memo.find(branch.graph) != memo.end()) {
            continue;
          }
          FindProfilingEndPoints(context, branch.graph, &memo);
        }
        break;
      }
      AnfAlgo::SetNodeAttr(ITEREND, prim::kValueOne, *iter);
      MS_LOG(INFO) << "Set profiling iter-end point: " << (*iter)->DebugString();
      return;
    }
  }

  // Set the attribute on the iter-end points.
  static void SetIterEndAttr(AscendAutoMonadContext *context, const KernelGraphPtr &kg, bool has_call_site) {
    MS_EXCEPTION_IF_NULL(kg);
    kg->SetExecOrderByDefault();
    auto &nodes = kg->execution_order();
    auto end_iter = nodes.rend();
    if (has_call_site) {
      const auto &end_node = context->call_info_map[kg].call_sites.back().cnode;
      end_iter = std::find(nodes.rbegin(), nodes.rend(), end_node);
    }
    for (auto iter = nodes.rbegin(); iter != end_iter; ++iter) {
      if (!AnfAlgo::IsRealCNodeKernel(*iter)) {
        continue;
      }
      if (AnfAlgo::CheckPrimitiveType(*iter, prim::kPrimLabelGoto) && AnfAlgo::HasNodeAttr(kAttrReturn, *iter)) {
        continue;
      }
      if (AnfAlgo::CheckPrimitiveType(*iter, prim::kPrimLabelGoto) ||
          AnfAlgo::CheckPrimitiveType(*iter, prim::kPrimLabelSwitch) ||
          AnfAlgo::CheckPrimitiveType(*iter, prim::kPrimLabelSet)) {
        MS_LOG(ERROR) << "This node is a Label node, iter-end point not found.";
        break;
      }
      AnfAlgo::SetNodeAttr(ITEREND, prim::kValueOne, *iter);
      MS_LOG(INFO) << "Set profiling iter-end point: " << (*iter)->DebugString();
      return;
    }
    MS_LOG(ERROR) << "Iter-end point not found.";
  }

  // Find all iteration end points recursively.
  static void FindProfilingEndPoints(AscendAutoMonadContext *context, const KernelGraphPtr &kg,
                                     std::set<KernelGraphPtr> *memo) {
    MS_EXCEPTION_IF_NULL(memo);
    memo->insert(kg);
    auto call_info = context->call_info_map[kg];
    // 1. Find the last call site; if there is no call site, go to step 3.
    // 2. Check whether the call site is a tail call.
    // 3. If yes, recursively find call sites in the subgraph; if no, find the last TBE node and set the attr.
    if (!call_info.call_sites.empty()) {
      const auto &last_call_site = call_info.call_sites.back();
      if (last_call_site.tail) {
        for (auto &branch : last_call_site.callees) {
          if (memo->find(branch.graph) != memo->end()) {
            continue;
          }
          FindProfilingEndPoints(context, branch.graph, memo);
        }
      } else {
        SetIterEndAttr(context, kg, true);
      }
    } else {
      SetIterEndAttr(context, kg, false);
    }
  }

  // Create a Stack for StackOps if needed.
  void InitStack() {
    if (!context_.HasInitedStack() && need_stackops_) {
      auto top_graph = context_.TopGraph();
      MS_EXCEPTION_IF_NULL(top_graph);
      auto exec_order = top_graph->execution_order();
      auto stack_init = StackInit(top_graph);
      AnfAlgo::KeepOrder(top_graph, stack_init, *exec_order.begin());
      auto stack_destroy = StackDestroy(top_graph);
      AnfAlgo::KeepOrder(top_graph, *exec_order.rbegin(), stack_destroy);
      top_graph->SetExecOrderByDefault();
      context_.SetRecursiveCall(true);
      context_.SetInitedStack(true);
    }
  }

  // Insert StackOps for a call_site in the recursive graph.
  void InsertStackOps(const CallSite &call_site) {
    auto call_point = call_site.conversion_cnode;
    auto exec_order = kernel_graph_->execution_order();
    std::vector<AnfNodePtr> before_nodes;
    std::vector<CNodePtr> stack_pushs;
    bool find_call_point = false;
    for (auto &node : exec_order) {
      auto node_name = AnfAlgo::GetCNodeName(node);
      if (node == call_point) {
        find_call_point = true;
        continue;
      }
      if (!find_call_point) {
        if (node_name == kLabelGotoOpName || node_name == kLabelSwitchOpName || node_name == kLabelSetOpName ||
            node_name == prim::kPrimAssign->name()) {
          MS_LOG(DEBUG) << "Ignore goto/switch/set/assign ops";
        } else {
          before_nodes.push_back(node);
          MS_LOG(DEBUG) << "push back node:" << node->DebugString();
        }
        continue;
      }
      if (node->size() == 0 || node_name == kLabelGotoOpName || node_name == kLabelSetOpName ||
          node_name == prim::kPrimAssign->name()) {
        continue;
      }
      FindInputNode(before_nodes, node, &stack_pushs);
    }
    InsertStackPush(kernel_graph_, call_point, stack_pushs);
  }

  // Find nodes which need StackOps, and insert StackOps for them.
  void FindInputNode(const std::vector<AnfNodePtr> &before_nodes, const CNodePtr &node,
                     std::vector<CNodePtr> *stack_pushs) {
    MS_EXCEPTION_IF_NULL(node);
    uint32_t start_index = 1;
    if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimAssign)) {
      start_index = kInputIndex;
    }
    for (uint32_t i = start_index; i < node->inputs().size(); i++) {
      auto node_input = node->input(i);
      // No need to save monad.
      if (HasAbstractMonad(node_input)) {
        continue;
      }
      MS_EXCEPTION_IF_NULL(node_input);
      MS_LOG(DEBUG) << "check node input[" << i << "]: " << node_input->DebugString();
      if (node_input->isa<Parameter>()) {
        MS_LOG(DEBUG) << "node_input:" << node_input->DebugString() << " is a param";
        CNodePtr stack_pop = InsertStackPop(node_input, stack_pushs);
        node->set_input(i, stack_pop);
        KeepOrderForStackPop(kernel_graph_, stack_pop, node);
        continue;
      }
      auto iter = std::find_if(before_nodes.begin(), before_nodes.end(),
                               [node_input](auto before_node) { return before_node == node_input; });
      if (iter != before_nodes.end()) {
        CNodePtr stack_pop = InsertStackPop(*iter, stack_pushs);
        node->set_input(i, stack_pop);
        KeepOrderForStackPop(kernel_graph_, stack_pop, node);
      }
    }
  }

  // Create StackOps for node_input.
  CNodePtr InsertStackPop(const AnfNodePtr &node_input, std::vector<CNodePtr> *stack_pushs) {
    MS_EXCEPTION_IF_NULL(node_input);
    MS_EXCEPTION_IF_NULL(stack_pushs);
    auto stack_push = StackPush(node_input);
    stack_pushs->emplace_back(stack_push);
    auto stack_pop = StackPop();
    MS_EXCEPTION_IF_NULL(stack_pop);
    stack_pop->set_abstract(node_input->abstract());
    return stack_pop;
  }

  // Arrange StackPush nodes so that the value pushed last is popped first (LIFO),
  // while ensuring that the StackPush executed last sits right before the jump_node.
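  // For example (illustrative): if stack_pushs holds {P0, P1, P2} in creation order, the
  // enforced execution order is ..., P2, P1, P0, jump_node, so after the jump the pops
  // receive P0's, P1's, then P2's value in turn.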
  void InsertStackPush(const KernelGraphPtr &kg, const CNodePtr &jump_node, const std::vector<CNodePtr> &stack_pushs) {
    MS_LOG(DEBUG) << "There are " << stack_pushs.size() << " stack_push ops";
    if (stack_pushs.size() < 1) {
      return;
    }
    for (uint32_t i = 1; i < stack_pushs.size(); i++) {
      AnfAlgo::KeepOrder(kg, stack_pushs[i], stack_pushs[i - 1]);
    }
    auto nodes = kg->execution_order();
    auto node_iter = std::find(nodes.begin(), nodes.end(), jump_node);
    AnfAlgo::KeepOrder(kg, stack_pushs[0], jump_node);
    if (node_iter != nodes.begin()) {
      AnfAlgo::KeepOrder(kg, *(node_iter - 1), *stack_pushs.rbegin());
    }
  }

  // Ensure StackPop is next to the jump_node.
  void KeepOrderForStackPop(const KernelGraphPtr &kg, const CNodePtr &pop, const CNodePtr &jump_node) {
    auto nodes = kg->execution_order();
    auto node_iter = std::find(nodes.cbegin(), nodes.cend(), jump_node);
    if (node_iter == nodes.cend()) {
      MS_LOG(EXCEPTION) << "Cannot find node: " << jump_node->DebugString();
    }
    // Insert between jump_node-1 and jump_node.
    if (node_iter != nodes.begin()) {
      CNodePtr node = *(node_iter - 1);
      AnfAlgo::KeepOrder(kg, node, pop);
    }
    AnfAlgo::KeepOrder(kg, pop, jump_node);
  }

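  // Convert a call site into label-based control flow.
  // Illustrative sketch for a non-tail call with a single callee (roughly):
  //   out = Call(@sub_graph, arg)
  // becomes:
  //   link = Assign(sub_graph_para, arg, U)
  //   u = UpdateState(U, link)
  //   LabelGoto(sub_graph_label)
  //   LabelSet(return_label)
  //   out = Depend(out_param, LabelSet)   // or an Assign from a temp parameter for multi-return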
  void HandleCallSite(CallSite *call_site) {
    // Update last_monad_.
    last_monad_ = call_site->last_monad;

    // The call/switch/switch_layer cnode.
    auto &cnode = call_site->cnode;

    // Get branches of the call_site.
    // For call, there is one branch;
    // for switch, the first one is the true branch;
    // for switch_layer, the first one is the 0 branch.
    auto &branches = call_site->callees;

    // Link arguments and find labels for branches.
    std::vector<KernelGraphPtr> graphes;
    std::vector<uint32_t> labels;
    graphes.reserve(branches.size());
    labels.reserve(branches.size());
    bool monad_update = false;
    for (auto &[graph, args] : branches) {
      MS_EXCEPTION_IF_NULL(graph);
      auto linked_args = LinkArguments(args, graph);
      if (linked_args != nullptr) {
        monad_ = UpdateState(GetMonad(), linked_args);
        monad_update = true;
      }
      graphes.push_back(graph);
      labels.push_back(GetGraphLabel(graph));
    }
    if (!monad_update) {
      monad_ = last_monad_;
    }

    // Assign label indexes if required.
    AssignLabelIndexes(call_site);

    // For Switch, we reverse the graphs and labels, so that the false branch
    // is the first one, since for the LabelSwitch kernel, false is the first branch.
    if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
      std::reverse(graphes.begin(), graphes.end());
      std::reverse(labels.begin(), labels.end());
    }

    // Create a LabelGoto or LabelSwitch node.
    auto label_goto_switch = MakeLabelGotoSwitch(cnode, graphes, labels);
    call_site->conversion_cnode = label_goto_switch;
    if (call_site->recursive) {
      AnfAlgo::SetNodeAttr(kAttrRecursive, prim::kValueOne, label_goto_switch);
    }

    // Set up the return label and output if required.
    if (call_site->return_label != kNoLabel) {
      auto label_node = LabelSet(call_site->return_label);
      AnfNodePtr output = call_site->out_param;
      MS_EXCEPTION_IF_NULL(output);
      const bool is_single_call = call_site->label_indexes.empty();
      if (is_single_call) {
        // For a single call, let output depend on the label node;
        // this ensures the return label is set before output is used.
        output = MakeDepend(output, label_node);
      } else {
        // For a multi-return call, assign the result from the temp parameter to the
        // output parameter; this prevents the result from being overwritten by the next call.
        auto tmp_param = context_.GetTempParameter(output->abstract());
        output = AssignAll(output, tmp_param, false, false, true);
        monad_ = UpdateState(GetMonad(), output);
      }
      // Replace the call/switch node with the output.
      ReplaceNode(cnode, output);
      return;
    }

    // If no return label is required, it should be a tail call.
    if (!call_site->tail) {
      MS_LOG(EXCEPTION) << "Return label not set for non-tail call " << cnode->DebugString();
    }
    // For tail calls, replace the origin call node with label_goto/label_switch.
    ReplaceNode(cnode, label_goto_switch);
    kernel_graph_->set_end_goto(label_goto_switch);
  }

  // Assign label indexes to label parameters for a call site.
  void AssignLabelIndexes(const CallSite *call_site) {
    for (auto &[label_param, label_index] : call_site->label_indexes) {
      auto index_value = GetIndexValueNode(label_index);
      auto assign = Assign(label_param, index_value, false, false, false);
      monad_ = UpdateState(GetMonad(), assign);
    }
  }

  // Create or reuse a ValueNode for the index.
  ValueNodePtr GetIndexValueNode(uint32_t index) {
    auto iter = index_nodes_.find(index);
    if (iter != index_nodes_.end()) {
      // Reuse the ValueNode for the same index.
      return iter->second;
    }
    // Create a new ValueNode on the top graph for the index.
    auto &top_graph = context_.TopGraph();
    std::vector<int64_t> data = {static_cast<int64_t>(index)};
    auto tensor = std::make_shared<tensor::Tensor>(data, kInt32);
    auto value_node = top_graph->NewValueNode(tensor->ToAbstract(), tensor);
    top_graph->AddValueNodeToGraph(value_node);
    index_nodes_.emplace(index, value_node);
    return value_node;
  }

  // Replace a node with a new node in the current kernel graph.
  // We also replace the arguments used for sub-graph calls.
  void ReplaceNode(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
    kernel_graph_->ReplaceNode(old_node, new_node);
    for (auto &call_site : call_info_.call_sites) {
      for (auto &callee : call_site.callees) {
        std::replace(callee.args.begin(), callee.args.end(), old_node, new_node);
      }
    }
  }

  // Make a label_goto or label_switch for a Call/Switch/SwitchLayer node.
  CNodePtr MakeLabelGotoSwitch(const CNodePtr &cnode, const std::vector<KernelGraphPtr> &graphes,
                               const std::vector<uint32_t> &labels) {
    // Create LabelGoto or LabelSwitch according to the cnode type.
    const bool is_call = AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall);
    auto label_goto_switch = (is_call ? LabelGoto(labels.front()) : LabelSwitch(cnode->input(1), labels));

    // Set the child graph attribute for the LabelGoto or LabelSwitch node.
    SetChildGrapAttr(label_goto_switch, graphes);

    // Mark the label_switch node as being for 'switch_layer' if it is.
    if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) {
      AnfAlgo::SetNodeAttr(kAttrSwitchLayer, prim::kValueOne, label_goto_switch);
    }
    return label_goto_switch;
  }

  // Handle return points:
  // use label_goto for a single return point;
  // use label_switch for multiple return points.
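  // Illustrative sketch for two return points: the sub-graph ends with
  //   Assign(tmp_param, output, U)
  //   LabelSwitch(label_param, [return_label_0, return_label_1])
  // so the index the caller stored in label_param selects which return label to jump to.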
  void HandleReturnPoints() {
    auto &return_points = call_info_.return_points;
    // No return points.
    if (return_points.empty()) {
      return;
    }
    if (call_info_.return_monad_ != nullptr) {
      monad_ = call_info_.return_monad_;
    }
    // Assign output according to the return points.
    AssignOutput(return_points);
    // Single return point.
    if (return_points.size() == 1) {
      // Insert label_goto for return.
      auto &return_point = return_points.front();
      auto return_goto = LabelGoto(return_point.call_site->return_label);
      AnfAlgo::SetNodeAttr(kAttrReturn, prim::kValueOne, return_goto);
      kernel_graph_->set_end_goto(return_goto);
      return;
    }
    // Multiple return points.
    std::vector<uint32_t> return_labels;
    return_labels.reserve(return_points.size());
    // Get return labels from the return points.
    std::transform(return_points.begin(), return_points.end(), std::back_inserter(return_labels),
                   [](const ReturnPoint &return_point) { return return_point.call_site->return_label; });
    // Insert label_switch for multiple return points.
    auto &label_param = call_info_.label_param;
    MS_EXCEPTION_IF_NULL(label_param);
    auto return_switch = LabelSwitch(label_param, return_labels);
    AnfAlgo::SetNodeAttr(kAttrReturn, prim::kValueOne, return_switch);
    if (!call_info_.recursive) {
      AnfAlgo::SetNodeAttr(kAttrMultiCallEnd, prim::kValueOne, return_switch);
    }
    kernel_graph_->set_end_goto(return_switch);
    context_.SetSubGraphMultiCall(true);
  }

  // Assign the graph output to the output parameter.
  void AssignOutput(const std::vector<ReturnPoint> &return_points) {
    // For a single call: we directly assign output to the output parameter of the call site;
    // for multi-call: we assign output to a temp parameter, and let the caller assign the
    // temp parameter to an output parameter after the call returns.
    auto call_site = return_points.front().call_site;
    MS_EXCEPTION_IF_NULL(call_site);
    const bool is_single_call = (return_points.size() == 1 && call_site->label_indexes.empty());
    AnfNodePtr out_param =
      (is_single_call ? call_site->out_param : context_.GetTempParameter(kernel_graph_->output()->abstract()));
    MS_EXCEPTION_IF_NULL(out_param);
    auto assign_output = AssignAll(out_param, kernel_graph_->output(), false, false, true);
    monad_ = UpdateState(GetMonad(), assign_output);
  }

  // Link actual arguments to the graph's formal arguments.
  // 1. For multi-args:
  //      r = Call(fg, arg1, arg2, u)
  //    linked arguments:
  //      r1 = Assign(para1, arg1, c)
  //      r2 = Assign(para2, arg2, c)
  //      tuple = MakeTuple(r1, r2, u)
  // 2. For single-arg:
  //      r = Call(fg, arg)
  //    linked arguments:
  //      r = Assign(para1, arg, c)
  // 3. For empty-arg:
  //      r = Call(fg)
  //    linked arguments return null.
  AnfNodePtr LinkArguments(const std::vector<AnfNodePtr> &args, const KernelGraphPtr &graph) {
    auto &paras = graph->inputs();
    if (args.size() != paras.size()) {
      MS_LOG(EXCEPTION) << "Wrong arg number! " << graph->ToString() << " " << args.size() << " != " << paras.size();
    }
    // If there is no argument, return null.
    if (args.empty()) {
      return nullptr;
    }
    // We do not eliminate argument Assign for recursive graphs.
    const bool keep = IsRecursive(graph);
    // Single argument.
    if (args.size() == 1) {
      auto &value = args.front();
      if (HasAbstractMonad(value) || paras.front() == value) {
        // No assign for a single monad argument, return it.
        return value;
      }
      return AssignAll(paras.front(), value, true, keep, false);
    }
    // Multiple arguments.
    AnfNodePtrList tuple_inputs;
    tuple_inputs.reserve(args.size() + 1);
    tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
    for (size_t i = 0; i < args.size(); ++i) {
      auto &value = args.at(i);
      if (HasAbstractMonad(value)) {
        // No assign for monad arguments.
        tuple_inputs.emplace_back(value);
        continue;
      }
      // Assign general arguments.
      auto &target = paras.at(i);
      if (target == value) {
        continue;
      }
      (void)tuple_inputs.emplace_back(AssignAll(target, value, true, keep, false));
    }
    auto new_tuple = kernel_graph_->NewCNode(tuple_inputs);
    // Set abstract for the MakeTuple node.
    abstract::AbstractBasePtrList element_abstracts;
    (void)std::transform(tuple_inputs.begin() + 1, tuple_inputs.end(), std::back_inserter(element_abstracts),
                         [](const AnfNodePtr &input) { return input->abstract(); });
    new_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(element_abstracts));
    return new_tuple;
  }

  // Return true if the graph is involved with recursive calls.
  bool IsRecursive(const KernelGraphPtr &kg) { return context_.call_info_map[kg].recursive; }

  // For some cnodes, attributes may be set on the primitive instance, so we create a new
  // primitive instance for each cnode.
  AnfNodePtr NewPrimitive(const PrimitivePtr &prim) { return NewValueNode(std::make_shared<Primitive>(prim->name())); }

  AnfNodePtr GetLinkMonad() {
    if (last_monad_ != nullptr) {
      return last_monad_;
    }
    return GetMonad();
  }

  // Make an Assign cnode.
  CNodePtr Assign(const AnfNodePtr &target, const AnfNodePtr &source, bool link, bool keep, bool output) {
    auto monad = (link ? GetLinkMonad() : GetMonad());
    auto assign_prim = std::make_shared<Primitive>(prim::kPrimAssign->name());
    if (link) {
      // Mark that this assign links a real argument to a formal argument.
      assign_prim->set_attr(LINK, prim::kValueOne);
    }
    if (keep) {
      // Mark that this assign should not be eliminated.
      assign_prim->set_attr(KEEP, prim::kValueOne);
    }
    if (output) {
      // Mark that this assign is used for the output parameter.
      assign_prim->set_attr(OUTPUT, prim::kValueOne);
    }
    auto assign = NewValueNode(assign_prim);
    auto cnode = kernel_graph_->NewCNode({assign, target, source, monad});
    cnode->set_abstract(target->abstract());
    return cnode;
  }

  // AssignAll supports tuple-to-tuple assign.
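  // Illustrative sketch: AssignAll(MakeTuple(t1, t2), MakeTuple(s1, s2), ...) expands to
  //   MakeTuple(Assign(t1, s1, U), Assign(t2, s2, U))
  // so each tuple element is assigned individually.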
  AnfNodePtr AssignAll(const AnfNodePtr &target, const AnfNodePtr &source, bool link, bool keep, bool output) {
    if (!AnfAlgo::CheckPrimitiveType(target, prim::kPrimMakeTuple)) {
      // Assign a single value.
      return Assign(target, source, link, keep, output);
    }
    // Assign tuple.
    std::vector<AnfNodePtr> targets = AnfAlgo::GetAllOutput(target, {prim::kPrimTupleGetItem});
    std::vector<AnfNodePtr> sources = AnfAlgo::GetAllOutput(source, {prim::kPrimTupleGetItem});
    if (targets.size() != sources.size()) {
      MS_LOG(EXCEPTION) << "Target size " << targets.size() << " != source size " << sources.size();
    }
    AnfNodePtrList tuple_inputs;
    tuple_inputs.reserve(targets.size() + 1);
    tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
    for (size_t i = 0; i < targets.size(); ++i) {
      (void)tuple_inputs.emplace_back(Assign(targets[i], sources[i], link, keep, output));
    }
    auto new_tuple = kernel_graph_->NewCNode(tuple_inputs);
    // Set abstract for the MakeTuple node.
    abstract::AbstractBasePtrList element_abstracts;
    (void)std::transform(tuple_inputs.begin() + 1, tuple_inputs.end(), std::back_inserter(element_abstracts),
                         [](const AnfNodePtr &input) { return input->abstract(); });
    new_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(element_abstracts));
    return new_tuple;
  }

  // Insert UpdateState after the input node.
  AnfNodePtr UpdateState(const AnfNodePtr &state, const AnfNodePtr &input) {
    auto update_state = NewValueNode(prim::kPrimUpdateState);
    auto update_state_cnode = kernel_graph_->NewCNode({update_state, state, input});
    update_state_cnode->set_abstract(state->abstract());
    return update_state_cnode;
  }

  // Make an entry label for the current graph.
  // from:
  //   def sub_graph(x, y):
  //     return add(x, y)
  // to:
  //   def sub_graph(x, y, c):
  //     c = LabelSet(c) : entry_label
  //     return add(x, y)
  void SetupEntryLabel() {
    auto entry_label = GetGraphLabel(kernel_graph_);
    if (entry_label != kNoLabel) {
      // Set entry label.
      auto label_node = LabelSet(entry_label);
      // Make the start label the first one in execution order.
      kernel_graph_->set_start_label(label_node);
    }
  }

  // Make a Depend cnode.
  CNodePtr MakeDepend(const AnfNodePtr &origin, const AnfNodePtr &input) {
    auto depend = NewValueNode(prim::kPrimDepend);
    auto depend_cnode = kernel_graph_->NewCNode({depend, origin, input});
    depend_cnode->set_abstract(origin->abstract());
    return depend_cnode;
  }

  // Let output depend on monad.
  void MakeMonadDepend() {
    auto monad = GetMonad();
    auto origin_output = kernel_graph_->output();
    MS_EXCEPTION_IF_NULL(origin_output);
    if (origin_output != monad) {
      auto depend_cnode = MakeDepend(origin_output, monad);
      kernel_graph_->set_output(depend_cnode);
    }
  }

  // Get the last monad node; we use a separate UMonad for control flow.
  AnfNodePtr &GetMonad() {
    if (monad_ == nullptr) {
      monad_ = GetMonadValue();
    }
    return monad_;
  }

  // Get the monad const value node.
  AnfNodePtr &GetMonadValue() {
    if (monad_value_ == nullptr) {
      // We should create the monad value node by the kernel graph,
      // so that kernel_info is properly set for it.
      monad_value_ = kernel_graph_->NewValueNode(kUMonad->ToAbstract(), kUMonad);
    }
    return monad_value_;
  }

  // Make a LabelGoto node.
  CNodePtr LabelGoto(uint32_t label_id) {
    auto monad = GetMonad();
    auto label_goto = NewPrimitive(prim::kPrimLabelGoto);
    auto cnode = kernel_graph_->NewCNode({label_goto, monad});
    AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(label_id), cnode);
    cnode->set_abstract(monad->abstract());
    monad_ = cnode;
    return cnode;
  }

  // Make a LabelSet node.
  CNodePtr LabelSet(uint32_t label_id) {
    auto monad = GetMonad();
    auto label_set = NewPrimitive(prim::kPrimLabelSet);
    auto cnode = kernel_graph_->NewCNode({label_set, monad});
    AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(label_id), cnode);
    cnode->set_abstract(monad->abstract());
    monad_ = cnode;
    return cnode;
  }

  // Make a LabelSwitch node.
  CNodePtr LabelSwitch(const AnfNodePtr &cond, const std::vector<uint32_t> &labels) {
    auto monad = GetMonad();
    auto label_switch = NewPrimitive(prim::kPrimLabelSwitch);
    auto cnode = kernel_graph_->NewCNode({label_switch, cond, monad});
    auto label_list = MakeValue(labels);
    AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, label_list, cnode);
    cnode->set_abstract(monad->abstract());
    monad_ = cnode;
    return cnode;
  }

  // Set the child graph attribute for a label_goto/label_switch node.
  void SetChildGrapAttr(const AnfNodePtr &node, const std::vector<KernelGraphPtr> &graphs) {
    AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue(graphs), node);
  }

  // Make a StackInit node.
  CNodePtr StackInit(const KernelGraphPtr &kg) {
    auto monad = AnfAlgo::MakeMonadValueNode(kg);
    auto stack_init = NewPrimitive(prim::kPrimStackInit);
    auto cnode = kg->NewCNode({stack_init, monad});
    AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue<int64_t>(0), cnode);
    cnode->set_abstract(monad->abstract());
    return cnode;
  }

  // Make a StackDestroy node.
  CNodePtr StackDestroy(const KernelGraphPtr &kg) {
    auto monad = AnfAlgo::MakeMonadValueNode(kg);
    auto stack_destroy = NewPrimitive(prim::kPrimStackDestroy);
    auto cnode = kg->NewCNode({stack_destroy, monad});
    AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue<int64_t>(0), cnode);
    cnode->set_abstract(monad->abstract());
    return cnode;
  }

  // Make a StackPush node.
  CNodePtr StackPush(const AnfNodePtr &input) {
    auto monad = AnfAlgo::MakeMonadValueNode(kernel_graph_);
    auto stack_push = NewPrimitive(prim::kPrimStackPush);
    auto cnode = kernel_graph_->NewCNode({stack_push, input, monad});
    AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue<int64_t>(0), cnode);
    auto op_name = std::to_string(kernel_graph_->graph_id()) + "_stack_push_" + std::to_string(name_index_++);
    AnfAlgo::SetNodeAttr(kAttrStackOpName, MakeValue(op_name), cnode);
    cnode->set_abstract(monad->abstract());
    return cnode;
  }

  // Make a StackPop node.
  CNodePtr StackPop() {
    auto monad = AnfAlgo::MakeMonadValueNode(kernel_graph_);
    auto stack_pop = NewPrimitive(prim::kPrimStackPop);
    auto cnode = kernel_graph_->NewCNode({stack_pop, monad});
    AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue<int64_t>(0), cnode);
    auto op_name = std::to_string(kernel_graph_->graph_id()) + "_stack_pop_" + std::to_string(name_index_++);
    AnfAlgo::SetNodeAttr(kAttrStackOpName, MakeValue(op_name), cnode);
    cnode->set_abstract(monad->abstract());  // The caller needs to refresh the output's abstract().
    return cnode;
  }

  const KernelGraphPtr &kernel_graph_;
  AscendAutoMonadContext &context_;

  // Call info for the current kernel graph.
  CallInfo &call_info_;

  // The last monad for the Call/Switch node.
  AnfNodePtr last_monad_;

  // The current control flow monad.
  AnfNodePtr monad_;

  // The control flow monad const value node.
  AnfNodePtr monad_value_;

  // Index value node cache for reuse.
  std::map<uint32_t, ValueNodePtr> index_nodes_;

  // The index used for StackOps names.
  uint32_t name_index_;

  // Whether StackOps need to be inserted.
  bool need_stackops_;
};

constexpr size_t kAssignTargetIndex = 1;
constexpr size_t kAssignSourceIndex = 2;

class ExecuteOrderGenerator {
 public:
  class Context : public BaseContext {};
  ExecuteOrderGenerator(Context &context, const KernelGraphPtr &graph) : context_(context), graph_(graph) {}
  ~ExecuteOrderGenerator() = default;

  void Run() {
    GenerateExecuteOrder();
    EraseParameter();
    EraseLabel();
    UnfoldRepeatedLabels();
  }
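  // Pipeline overview (summarizing the calls above): flatten every reachable child
  // graph into one execution order, erase removable argument-link Assigns, erase
  // removable LabelGoto/LabelSet nodes, then split any label that is targeted more
  // than once so each LabelSwitch/LabelGoto target id stays unique.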

 private:
  void GenerateGraphOrder(const KernelGraphPtr &graph) {
    ExecuteOrderGenerator generator(context_, graph);
    generator.GenerateExecuteOrder();
  }

  uint32_t FindMaxLabelId(const std::vector<CNodePtr> &nodes) {
    uint32_t max_label = 0;
    for (auto &node : nodes) {
      if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSet)) {
        auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex);
        max_label = std::max(label_id, max_label);
      }
    }
    return max_label;
  }

  void HandleLabelSwitch(const AnfNodePtr &node, std::vector<uint32_t> *labels, std::vector<uint32_t> *switch_labels,
                         std::multimap<uint32_t, uint32_t> *labels_multimap) {
    bool is_new_labels = false;
    auto label_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(node, kAttrLabelSwitchList);
    std::vector<uint32_t> new_labels;
    new_labels.reserve(label_list.size());
    for (auto label_id : label_list) {
      auto iter = std::find_if(labels->begin(), labels->end(), [label_id](auto id) { return id == label_id; });
      // Allocate a new label if a repeated label is found.
      if (iter == labels->end()) {
        (void)new_labels.emplace_back(label_id);
        (void)labels->emplace_back(label_id);
        continue;
      }
      (void)new_labels.emplace_back(++max_label_);
      (void)labels_multimap->emplace(*iter, max_label_);
      (void)labels->emplace_back(label_id);
      is_new_labels = true;
    }
    (void)switch_labels->insert(switch_labels->end(), new_labels.begin(), new_labels.end());
    if (is_new_labels) {
      AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, MakeValue(new_labels), node);
    }
  }

  void HandleLabelGoto(const AnfNodePtr &node, std::vector<uint32_t> *labels, std::vector<uint32_t> *switch_labels,
                       std::multimap<uint32_t, uint32_t> *labels_multimap) {
    auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex);
    auto iter = std::find(switch_labels->begin(), switch_labels->end(), label_id);
    if (iter == switch_labels->end()) {
      (void)labels->emplace_back(label_id);
      return;
    }
    AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(++max_label_), node);
    (void)labels_multimap->emplace(*iter, max_label_);
    (void)labels->emplace_back(max_label_);
  }

  // Unfold repeated labels so that the same label id is not targeted by multiple LabelSwitch/LabelGoto nodes.
  void UnfoldRepeatedLabels() {
    auto nodes = graph_->execution_order();
    std::vector<uint32_t> labels;
    std::vector<uint32_t> switch_labels;
    std::multimap<uint32_t, uint32_t> labels_multimap;
    max_label_ = FindMaxLabelId(nodes);
    for (auto &node : nodes) {
      if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) {
        HandleLabelSwitch(node, &labels, &switch_labels, &labels_multimap);
        continue;
      }
      if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) {
        HandleLabelGoto(node, &labels, &switch_labels, &labels_multimap);
        continue;
      }
    }
    InsertLabelSet(&nodes, labels_multimap);
    graph_->set_label_num(max_label_ + 1);
    graph_->set_execution_order(nodes);
  }
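  // Illustration (hypothetical label ids): given the execution order
  //   LabelSwitch(c1, [3, 4]); ...; LabelSwitch(c2, [4, 5]); ...; LabelSet(4)
  // label 4 is targeted by two switches, so the second switch is rewritten to use a
  // fresh label (say 6), and InsertLabelSet below adds a LabelSet(6) just before
  // LabelSet(4) so both ids mark the same position in the execution order.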

  void InsertLabelSet(std::vector<CNodePtr> *nodes, const std::multimap<uint32_t, uint32_t> &labels_multimap) {
    for (auto labels : labels_multimap) {
      auto old_label = labels.first;
      auto new_label = labels.second;
      auto iter = std::find_if(nodes->begin(), nodes->end(), [old_label](auto node) {
        if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSet)) {
          return false;
        }
        auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex);
        return label_id == old_label;
      });
      if (iter == nodes->end()) {
        MS_LOG(EXCEPTION) << "Not found labelset:" << old_label;
      }
      auto label_set = NewValueNode(std::make_shared<Primitive>(prim::kPrimLabelSet->name()));
      auto cnode = graph_->NewCNode({label_set});
      AnfAlgo::CopyNodeAttrs(*iter, cnode);
      AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(new_label), cnode);
      auto monad = graph_->NewValueNode(kUMonad->ToAbstract(), kUMonad);
      cnode->set_abstract(monad->abstract());
      (void)device::ascend::SelectKernelInfo(cnode);
      (void)nodes->insert(iter, cnode);
    }
  }

  void AppendGraphOrder(std::vector<CNodePtr> *execution_order, const KernelGraphPtr &graph) {
    auto &order = graph->execution_order();
    execution_order->insert(execution_order->end(), order.begin(), order.end());
  }

  bool HasSubGraphs(const CNodePtr &cnode) { return (cnode && AnfAlgo::HasNodeAttr(kAttrChildGraph, cnode)); }

  std::vector<KernelGraphPtr> GetSubGraphs(const CNodePtr &cnode) {
    return AnfAlgo::GetNodeAttr<std::vector<KernelGraphPtr>>(cnode, kAttrChildGraph);
  }

  void EraseNodeFromExecOrder(const AnfNodePtr &node, const NotNull<std::vector<CNodePtr> *> exec_order) {
    MS_EXCEPTION_IF_NULL(node);
    auto exec_iter = std::find(exec_order->begin(), exec_order->end(), node);
    if (exec_iter == exec_order->end()) {
      MS_LOG(EXCEPTION) << "Cannot find " << node->DebugString() << " in exec order.";
    }
    exec_order->erase(exec_iter);
  }

  void GenerateExecuteOrder() {
    // Mark this graph as visited.
    context_.MarkVisited(graph_);

    // Generate a topo-sorted kernel cnode list for this graph.
    graph_->SetExecOrderByDefault();

    std::vector<CNodePtr> execution_order;
    const auto &cnodes = graph_->execution_order();
    for (auto &cnode : cnodes) {
      // Push the current node to the execution order list.
      execution_order.push_back(cnode);
      // For cnodes with sub-graphs, such as LabelSwitch and LabelGoto,
      // generate the execute order for these sub-graphs,
      // and then append them to the current execution order list.
      if (HasSubGraphs(cnode)) {
        auto sub_graphs = GetSubGraphs(cnode);
        if (!AnfAlgo::HasNodeAttr(kAttrSwitchLayer, cnode)) {
          // For Switch, we use reversed order to generate the sub-graphs' execution order,
          // because the true branch of LabelSwitch is the second one, but
          // we want to place the true branch ahead of the false branch in the
          // generated execution order.
          std::reverse(sub_graphs.begin(), sub_graphs.end());
        }
        for (auto &sub_graph : sub_graphs) {
          if (context_.IsVisited(sub_graph)) {
            // Skip visited sub-graphs.
            continue;
          }
          GenerateGraphOrder(sub_graph);
          AppendGraphOrder(&execution_order, sub_graph);
        }
        // Clear the ChildGraph attribute after the execute order is generated.
        AnfAlgo::EraseNodeAttr(kAttrChildGraph, cnode);
      }
    }
    // Save the generated execution order into the graph.
    graph_->set_execution_order(std::move(execution_order));
  }
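  // Example (sketch): if a LabelSwitch carries kAttrChildGraph = {false_graph, true_graph},
  // the reversal above makes true_graph's kernels get appended first, immediately after
  // the switch node, followed by false_graph's kernels.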

  std::set<CNodePtr> GetAllNodes(std::map<CNodePtr, const size_t> *search_list) {
    const auto &all_graphs = context_.visited_graphs();
    std::set<CNodePtr> all_nodes;
    for (auto &graph : all_graphs) {
      auto out = graph->get_return();
      MS_EXCEPTION_IF_NULL(out);
      (void)search_list->emplace(out->cast<CNodePtr>(), 0);
      auto nodes = TopoSort(out);
      for (auto &node : nodes) {
        MS_EXCEPTION_IF_NULL(node);
        auto cnode = node->cast<CNodePtr>();
        if (cnode != nullptr) {
          (void)all_nodes.insert(cnode);
        }
      }
    }
    return all_nodes;
  }

  static const AnfNodePtr &GetRealNode(const AnfNodePtr &input) {
    if (IsPrimitiveCNode(input, prim::kPrimLoad) || IsPrimitiveCNode(input, prim::kPrimDepend)) {
      return input->cast<CNodePtr>()->inputs().at(1);
    }
    return input;
  }

  void RemoveSameInputsAssigns(std::vector<CNodePtr> *exec_order) const {
    for (auto iter = exec_order->begin(); iter != exec_order->end();) {
      auto &node = *iter;
      auto &inputs = node->inputs();
      if (IsPrimitiveCNode(node, prim::kPrimAssign) &&
          (inputs.at(kAssignTargetIndex) == GetRealNode(inputs.at(kAssignSourceIndex)))) {
        iter = exec_order->erase(iter);
      } else {
        ++iter;
      }
    }
  }
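  // Example: Assign(%p, Load(%p, U)) writes a parameter back to itself once GetRealNode
  // peels the Load wrapper, so such assigns are dropped here.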

  // Erase redundant parameters and assign nodes.
  void EraseParameter() {
    // Copy out the execution order list.
    auto exec_order = graph_->execution_order();
    std::map<CNodePtr, const size_t> search_list;
    for (size_t i = 0; i < exec_order.size(); i++) {
      search_list.emplace(exec_order[i], i);
    }

    // Remove assigns whose target and source are the same.
    RemoveSameInputsAssigns(&exec_order);

    // Get all nodes and all graphs.
    std::set<CNodePtr> all_nodes = GetAllNodes(&search_list);
    auto &all_graphs = context_.visited_graphs();

    // Count parameter write times by checking all assign nodes.
    auto param_write_times = CountParameterAssigns(search_list, exec_order);

    // Erase redundant assigns.
    for (auto iter = exec_order.begin(); iter != exec_order.end();) {
      auto &node = *iter;
      // We only try to erase argument link assign nodes;
      // other assign nodes are skipped.
      if (IsOptimizableAssign(node)) {
        auto &target = node->inputs().at(kAssignTargetIndex);
        MS_EXCEPTION_IF_NULL(target);
        auto para = param_write_times.find(target);
        if (para != param_write_times.end() && para->second.first == 1) {
          // Check the source of the Assign.
          auto &source = node->inputs().at(kAssignSourceIndex);
          MS_EXCEPTION_IF_NULL(source);
          if (source->isa<Parameter>()) {
            auto it = param_write_times.find(source);
            const auto index = search_list[node];
            if (it != param_write_times.end() && it->second.first > 0 && it->second.second > index) {
              // Skip if the Assign source is a parameter that is written again later.
              ++iter;
              continue;
            }
          }
          // If the target is written only once and the source is not written afterwards,
          // replace the target with the source and erase the Assign node.
          auto kg = target->func_graph()->cast<KernelGraphPtr>();
          MS_EXCEPTION_IF_NULL(kg);
          kg->ReplaceNode(target, source);

          // Replace the parameter in graph inputs.
          for (auto &g : all_graphs) {
            auto child_graph_inputs = g->MutableInputs();
            std::replace(child_graph_inputs->begin(), child_graph_inputs->end(), target, source);
            MS_LOG(DEBUG) << "Replace parameter " << target->DebugString() << " by " << source->DebugString()
                          << " in graph " << g->graph_id() << " inputs";
          }

          // Replace the parameter in nodes.
          for (auto &iter_node : all_nodes) {
            for (size_t i = 0; i < iter_node->size(); ++i) {
              if (iter_node->input(i) == target) {
                MS_LOG(INFO) << "Replace " << iter_node->DebugString() << " input " << i << " by "
                             << source->DebugString();
                iter_node->set_input(i, source);
              }
            }
          }
          iter = exec_order.erase(iter);
          continue;
        }
      }
      // Go to the next node.
      ++iter;
    }
    // Set the new execution order with redundant assigns removed.
    graph_->set_execution_order(std::move(exec_order));
  }
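  // Illustration (hypothetical node names): for an argument-link assign
  //   Assign[link](%child_param, %caller_arg)
  // where %child_param is written exactly once and %caller_arg is not written after
  // this point, %child_param is replaced by %caller_arg in every visited graph and
  // the Assign itself is erased from the execution order.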

  // Count parameter write times by checking all assign (ref) nodes.
  std::map<AnfNodePtr, std::pair<int, size_t>> CountParameterAssigns(
    const std::map<CNodePtr, const size_t> &search_list, const std::vector<CNodePtr> &exec_order) {
    auto ref_map = graph_->GetRefMap();
    std::multimap<AnfNodePtr, std::tuple<size_t, AnfNodePtr, size_t>> ref_multimap;
    std::set<AnfNodePtr> root_inputs(graph_->inputs().begin(), graph_->inputs().end());
    (void)std::transform(ref_map.begin(), ref_map.end(), std::inserter(ref_multimap, ref_multimap.end()),
                         [](const std::pair<std::pair<AnfNodePtr, size_t>, std::pair<AnfNodePtr, size_t>> &p)
                           -> std::pair<AnfNodePtr, std::tuple<size_t, AnfNodePtr, size_t>> {
                           return {p.first.first, {p.first.second, p.second.first, p.second.second}};
                         });
    auto validate_ref_parameter = [](AnfNodePtr node) -> AnfNodePtr {
      if (node->isa<CNode>() && AnfAlgo::CheckPrimitiveType(node, prim::kPrimTransData)) {
        auto cnode = node->cast<CNodePtr>();
        MS_EXCEPTION_IF_NULL(cnode);
        auto first_input = cnode->input(kFirstDataInputIndex);
        MS_EXCEPTION_IF_NULL(first_input);
        return first_input;
      }
      return node;
    };

    // Find all graph input parameters.
    std::map<AnfNodePtr, std::pair<int, size_t>> param_write_times;
    const auto &all_graphs = context_.visited_graphs();
    for (const auto &graph : all_graphs) {
      for (auto &input : graph->inputs()) {
        if (input->isa<Parameter>()) {
          param_write_times.emplace(input, std::make_pair(0, 0));
        }
      }
    }

    // Search all ref nodes for parameter write assigns.
    for (auto &node : exec_order) {
      if (ref_multimap.find(node) == ref_multimap.end()) {
        // Skip nodes that are not ref nodes; they cannot write parameters.
        continue;
      }
      std::set<AnfNodePtr> refed_parameters;
      for (auto [iter, end] = ref_multimap.equal_range(node); iter != end; ++iter) {
        (void)refed_parameters.insert(validate_ref_parameter(std::get<1>(iter->second)));
      }
      for (auto &in : node->inputs()) {
        auto visit_node = AnfAlgo::VisitKernelWithReturnType(in, 0).first;
        visit_node = validate_ref_parameter(visit_node);
        if (!visit_node->isa<Parameter>() || root_inputs.find(visit_node) != root_inputs.end()) {
          continue;
        }
        if (refed_parameters.find(visit_node) != refed_parameters.end()) {
          auto iter = param_write_times.find(visit_node);
          if (iter != param_write_times.end()) {
            // Found a parameter writer, count it.
            ++(iter->second.first);
            if (search_list.find(node) == search_list.end()) {
              MS_LOG(EXCEPTION) << "node: " << node->DebugString() << " cannot be found in search list.";
            }
            iter->second.second = search_list.at(node);
          }
        }
      }
    }
    return param_write_times;
  }
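  // Example: if parameter %p is written by ref kernels at execution indices 10 and 42,
  // param_write_times[%p] ends up holding {2, 42}: two writes, with the last counted
  // write at execution-order index 42.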

  // Check if a node is an assign for argument link and can be optimized.
  bool IsOptimizableAssign(const AnfNodePtr &node) {
    auto cnode = dyn_cast<CNode>(node);
    if (cnode == nullptr) {
      return false;
    }
    auto prim = GetValueNode<PrimitivePtr>(cnode->inputs().at(0));
    if (!IsPrimitiveEquals(prim, prim::kPrimAssign)) {
      return false;
    }
    return (prim->GetAttr(LINK) == prim::kValueOne) && (prim->GetAttr(KEEP) != prim::kValueOne);
  }

  // Erase removable LabelGoto and LabelSet nodes.
  void EraseLabel() {
    // Find used labels (as jump targets).
    std::set<uint32_t> label_used;
    auto exec_order = graph_->execution_order();
    for (auto iter = exec_order.begin(); iter != exec_order.end();) {
      auto &node = *iter;
      if (IsPrimitiveCNode(node, prim::kPrimLabelSwitch)) {
        auto labels = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(node, kAttrLabelSwitchList);
        for (auto label : labels) {
          label_used.insert(label);
        }
      } else if (IsPrimitiveCNode(node, prim::kPrimLabelGoto)) {
        auto label = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex);
        auto next = std::next(iter);
        if (next != exec_order.end() && IsPrimitiveCNode(*next, prim::kPrimLabelSet)) {
          // A LabelGoto that jumps to the immediately following LabelSet can be removed.
          auto next_label = AnfAlgo::GetNodeAttr<uint32_t>(*next, kAttrLabelIndex);
          if (next_label == label) {
            iter = exec_order.erase(iter);
            continue;
          }
        }
        label_used.insert(label);
      }
      ++iter;
    }
    // Erase unused LabelSet nodes.
    for (auto iter = exec_order.begin(); iter != exec_order.end();) {
      auto &node = *iter;
      if (IsPrimitiveCNode(node, prim::kPrimLabelSet)) {
        auto label = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex);
        if (label_used.find(label) == label_used.end()) {
          iter = exec_order.erase(iter);
          continue;
        }
      }
      ++iter;
    }
    graph_->set_execution_order(std::move(exec_order));
  }
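  // Example: in "..., LabelGoto(7), LabelSet(7), ..." the goto jumps to the next
  // instruction and is erased; if no remaining goto/switch targets label 7, the
  // LabelSet(7) is erased in the second pass as well.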

  Context &context_;
  const KernelGraphPtr graph_;
  uint32_t max_label_ = 0;
};
}  // namespace

void AscendAutoMonad::Run() {
  MS_LOG(DEBUG) << "Ascend auto-monad start.";
  auto kg = kernel_graph_.get();
  AscendAutoMonadContext context(kg);
  CallInfoFinder::Run(&context);
  AscendAutoMonadConverter::Run(&context);
  kernel_graph_->set_label_num(context.CurrentLabel() + 1);
  kernel_graph_->set_recursive_call(context.HasRecursiveCall());
  kernel_graph_->set_subgraph_multi_call(context.HasSubgraphMultiCall());
  MS_LOG(DEBUG) << "Ascend auto-monad finish.";
#ifdef ENABLE_DUMP_IR
  DumpGraphForDebug(kernel_graph_);
#endif
}

void AscendAutoMonad::GenerateExecuteOrder() {
  MS_LOG(DEBUG) << "Ascend generate execute order start.";
  ExecuteOrderGenerator::Context context;
  ExecuteOrderGenerator generator(context, kernel_graph_.get());
  generator.Run();
  MS_LOG(DEBUG) << "Ascend generate execute order finish.";
#ifndef ENABLE_SECURITY
  DumpExecuteOrder(kernel_graph_);
#endif
}
}  // namespace session
}  // namespace mindspore