• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "pipeline/jit/static_analysis/auto_monad.h"
18 #include <set>
19 #include <map>
20 #include <list>
21 #include <unordered_map>
22 #include <vector>
23 #include <stack>
24 #include <utility>
25 #include <algorithm>
26 #include "pipeline/jit/parse/resolve.h"
27 #include "frontend/operator/ops.h"
28 #include "frontend/operator/composite/multitype_funcgraph.h"
29 #include "utils/flags.h"
30 #include "utils/utils.h"
31 #include "utils/ordered_map.h"
32 #include "base/core_ops.h"
33 #include "abstract/abstract_value.h"
34 
35 namespace mindspore {
36 namespace pipeline {
37 namespace {  // namespace anonymous
38 using ClassTypePtr = std::shared_ptr<parse::ClassType>;
39 using RefInputs = OrderedMap<AnfNodePtr, std::vector<size_t>>;
40 
41 // Add or get a monad parameter.
AddMonadParameter(const FuncGraphPtr & func_graph,const std::string & name,const abstract::AbstractBasePtr & abs)42 AnfNodePtr AddMonadParameter(const FuncGraphPtr &func_graph, const std::string &name,
43                              const abstract::AbstractBasePtr &abs) {
44   MS_EXCEPTION_IF_NULL(func_graph);
45   size_t params_size = func_graph->parameters().size();
46   size_t io_monad_location = params_size;
47   // Search for existed parameters, return it if found.
48   for (size_t i = 0; i < params_size; i++) {
49     auto &node = func_graph->parameters()[i];
50     auto para = dyn_cast<Parameter>(node);
51     if (para == nullptr) {
52       continue;
53     }
54     auto para_abs = para->abstract();
55     if (para_abs && *para_abs == *abs) {
56       return para;
57     }
58     if (HasAbstractIOMonad(para)) {
59       io_monad_location = i;
60     }
61   }
62   // Create a new parameter if not existed.
63   auto para = std::make_shared<Parameter>(func_graph);
64   para->set_name(name);
65   para->debug_info()->set_name(name);
66   para->set_abstract(abs);
67   // If io monad parameter added before u monad parameter, should insert u monad before io monad in parameters
68   if (io_monad_location != params_size && abs->isa<abstract::AbstractUMonad>()) {
69     std::vector<AnfNodePtr> params = func_graph->parameters();
70     (void)params.insert(params.begin() + SizeToInt(io_monad_location), para);
71     func_graph->set_parameters(params);
72   } else {
73     func_graph->add_parameter(para);
74   }
75   return para;
76 }
77 
78 // Gets side effect propagate attribute value from a ClassType object.
GetSideEffectPropagate(const ClassTypePtr & class_type)79 int GetSideEffectPropagate(const ClassTypePtr &class_type) {
80   if (class_type) {
81     auto obj = class_type->obj();
82     if (py::hasattr(obj, GRAPH_FLAG_SIDE_EFFECT_PROPAGATE)) {
83       auto value = py::getattr(obj, GRAPH_FLAG_SIDE_EFFECT_PROPAGATE);
84       return value.cast<int>();
85     }
86   }
87   return 0;
88 }
89 
90 // Gets 'side_effect_propagate' attribute value from a primitive.
GetSideEffectPropagate(const PrimitivePtr & prim)91 int GetSideEffectPropagate(const PrimitivePtr &prim) {
92   if (prim) {
93     auto attr = prim->GetAttr(GRAPH_FLAG_SIDE_EFFECT_PROPAGATE);
94     if (attr && attr->isa<Int64Imm>()) {
95       return static_cast<int>(attr->cast<Int64ImmPtr>()->value());
96     }
97   }
98   return 0;
99 }
100 
101 // Return true if the node has Ref abstract.
HasAbstractRef(const AnfNodePtr & node)102 bool HasAbstractRef(const AnfNodePtr &node) {
103   if (node == nullptr) {
104     return false;
105   }
106   auto &abs = node->abstract();
107   return (abs != nullptr) && abs->isa<abstract::AbstractRef>();
108 }
109 
110 // Gets ref inputs and its indexes from a cnode.
GetRefInputs(const CNodePtr & cnode)111 RefInputs GetRefInputs(const CNodePtr &cnode) {
112   RefInputs ref_inputs;
113   MS_EXCEPTION_IF_NULL(cnode);
114   for (size_t i = 1; i < cnode->size(); ++i) {
115     auto &input = cnode->inputs().at(i);
116     if (HasAbstractRef(input)) {
117       ref_inputs[input].push_back(i);
118     }
119   }
120   return ref_inputs;
121 }
122 
123 // Return true if cnode has ref input.
HasRefInput(const CNodePtr & cnode)124 bool HasRefInput(const CNodePtr &cnode) {
125   if (cnode == nullptr || cnode->inputs().empty()) {
126     return false;
127   }
128   auto &inputs = cnode->inputs();
129   // Return true if any of arguments is ref.
130   return std::any_of(inputs.begin() + 1, inputs.end(), [](const auto &input) { return HasAbstractRef(input); });
131 }
132 
133 // Return true if we don't need Load for the given primitive.
134 // i.e. keep Ref as Ref for some primitives.
IsKeepRef(const PrimitivePtr & prim)135 bool IsKeepRef(const PrimitivePtr &prim) {
136   return (GetSideEffectPropagate(prim) != 0) || IsPrimitiveEquals(prim, prim::kPrimRefToEmbed) ||
137          IsPrimitiveEquals(prim, prim::kPrimPull);
138 }
139 
140 // Gets primitive if the node is a primitive value node.
GetPrimitive(const AnfNodePtr & node)141 PrimitivePtr GetPrimitive(const AnfNodePtr &node) {
142   PrimitivePtr prim = GetValueNode<PrimitivePtr>(node);
143   auto do_sig = dyn_cast<mindspore::prim::DoSignaturePrimitive>(prim);
144   if (do_sig) {
145     auto val = do_sig->function();
146     return dyn_cast<Primitive>(val);
147   }
148   return prim;
149 }
150 
151 // Gets primitive from the given cnode, return nullptr if cnode.inputs[0] is not a primitive.
GetPrimitive(const CNodePtr & cnode)152 PrimitivePtr GetPrimitive(const CNodePtr &cnode) {
153   if (cnode == nullptr || cnode->inputs().empty()) {
154     return nullptr;
155   }
156   return GetPrimitive(cnode->input(0));
157 }
158 
159 // Gets func_graph from the given cnode, return nullptr if it is not a func graph call.
GetFuncGraph(const CNodePtr & cnode)160 FuncGraphPtr GetFuncGraph(const CNodePtr &cnode) {
161   if (cnode != nullptr && !cnode->inputs().empty()) {
162     return GetValueNode<FuncGraphPtr>(cnode->input(0));
163   }
164   return nullptr;
165 }
166 
167 // Gets class_type from the given cnode->inputs[0].
GetClassType(const CNodePtr & cnode)168 ClassTypePtr GetClassType(const CNodePtr &cnode) {
169   if (cnode && !cnode->inputs().empty()) {
170     auto apply = cnode->input(0);
171     auto apply_cnode = dyn_cast<CNode>(apply);
172     if (apply_cnode && !apply_cnode->inputs().empty()) {
173       return GetValueNode<ClassTypePtr>(apply_cnode->input(0));
174     }
175   }
176   return nullptr;
177 }
178 
179 // Gets first input as cnode from the given cnode,
180 // return null if input[0] is not a cnode.
GetFuncCNode(const CNodePtr & cnode)181 CNodePtr GetFuncCNode(const CNodePtr &cnode) {
182   if (cnode != nullptr && !cnode->inputs().empty()) {
183     return dyn_cast<CNode>(cnode->input(0));
184   }
185   return nullptr;
186 }
187 
188 // Gets first input as function parameter from the given cnode,
189 // return null if input[0] is not a parameter.
GetFuncParameter(const CNodePtr & cnode)190 ParameterPtr GetFuncParameter(const CNodePtr &cnode) {
191   if (cnode != nullptr && !cnode->inputs().empty()) {
192     return dyn_cast<Parameter>(cnode->input(0));
193   }
194   return nullptr;
195 }
196 
197 // Gets first input as MultitypeFuncGraph from the given cnode,
198 // return null if input[0] is not a MultitypeFuncGraph.
GetFuncMultitypeFuncGraph(const CNodePtr & cnode)199 prim::MultitypeFuncGraphPtr GetFuncMultitypeFuncGraph(const CNodePtr &cnode) {
200   if (cnode != nullptr && !cnode->inputs().empty()) {
201     return GetValueNode<prim::MultitypeFuncGraphPtr>(cnode->input(0));
202   }
203   return nullptr;
204 }
205 
206 // --------------------------------------------------------------------
207 // SCC (Strongly Connected Components) related types.
208 // --------------------------------------------------------------------
209 using SccVector = std::set<FuncGraphPtr>;
210 using SccPtr = std::shared_ptr<SccVector>;
211 using SccMap = std::unordered_map<FuncGraphPtr, SccPtr>;
212 
213 // ---------------------------------------------------------------------
214 // SccFinder find SCCs using Tarjan's algorithm.
215 // ---------------------------------------------------------------------
216 class SccFinder {
217  public:
SccFinder(const FuncGraphPtr & root)218   explicit SccFinder(const FuncGraphPtr &root) : root_(root) {}
219   ~SccFinder() = default;
Run()220   void Run() { (void)Search(root_); }
scc_map() const221   const SccMap &scc_map() const { return scc_map_; }
222 
223  private:
224   // Save state of a func graph.
225   struct State {
226     size_t index = 0;
227     size_t lowlink = 0;
228     bool in_stack = false;
Statemindspore::pipeline::__anon897cf14f0111::SccFinder::State229     explicit State(size_t index) : index(index), lowlink(index), in_stack(false) {}
230     ~State() = default;
231   };
232 
233   // Search SCCs from the given graph.
Search(FuncGraphPtr graph)234   const State &Search(FuncGraphPtr graph) {
235     // Create graph state, set it as visited.
236     MS_EXCEPTION_IF_NULL(graph);
237     auto [inserted, ok] = visited_.emplace(graph, State(index_++));
238     if (!ok) {
239       MS_LOG(EXCEPTION) << "Already visited: " << graph->ToString();
240     }
241     auto &state = inserted->second;
242     // Push visited graph to stack.
243     stack_.push(graph);
244     state.in_stack = true;
245     // Search successor graphs.
246     for (auto &used : graph->func_graphs_used()) {
247       auto &sg = used.first;
248       auto iter = visited_.find(sg);
249       if (iter == visited_.end()) {
250         // Successor graph has not yet been visited, recurse on it.
251         auto &sg_state = Search(sg);
252         state.lowlink = std::min(state.lowlink, sg_state.lowlink);
253       } else if (iter->second.in_stack) {
254         // Successor graph is in stack and hence in the current SCC.
255         state.lowlink = std::min(state.lowlink, iter->second.index);
256       }
257     }
258     // If index == lowlink, this means it is the root of SCC.
259     if (state.index == state.lowlink) {
260       // Pop members of the SCC from stack, they are on top of its root.
261       auto scc = std::make_shared<SccVector>();
262       while (!stack_.empty()) {
263         auto g = stack_.top();
264         stack_.pop();
265         auto found = visited_.find(g);
266         if (found == visited_.end()) {
267           MS_LOG(EXCEPTION) << "Unexpected graph: " << g->ToString();
268         }
269         found->second.in_stack = false;
270         // Add graph to SCC, and create the map from graph to SCC.
271         scc->insert(g);
272         scc_map_.emplace(g, scc);
273         if (g == graph) {
274           break;
275         }
276       }
277       // SCC should not be empty.
278       if (scc->empty()) {
279         MS_LOG(EXCEPTION) << "Invalid SCC for: " << graph->ToString();
280       }
281     }
282     return state;
283   }
284 
285   // The root graph.
286   FuncGraphPtr root_;
287 
288   // Current index by DFS order.
289   size_t index_ = 1;
290 
291   // Visited graphs and their states.
292   std::unordered_map<FuncGraphPtr, State> visited_;
293 
294   // The stack for Tarjan algorithm.
295   std::stack<FuncGraphPtr> stack_;
296 
297   // The result SCC map, from graph to its SCC.
298   SccMap scc_map_;
299 };
300 
301 struct SwitchLayerCall {
302   CNodePtr caller;
303   EffectInfo effect_info;
304   std::vector<FuncGraphPtr> branches;
305 };
306 
307 // -------------------------------------------------------------------------------
308 // SideEffectFinder search and mark side effects for graph and its sub-graphs.
309 // -------------------------------------------------------------------------------
310 class SideEffectFinder {
311  public:
Search(const FuncGraphPtr & root)312   static void Search(const FuncGraphPtr &root) {
313     SideEffectFinder finder(root);
314     finder.Run();
315   }
316 
317  private:
SideEffectFinder(const FuncGraphPtr & root)318   explicit SideEffectFinder(const FuncGraphPtr &root) : root_(root) {}
319   ~SideEffectFinder() = default;
320 
Run()321   void Run() {
322     // To handle recursive calls, we generate SCC map before search.
323     GenerateSccMap();
324     // Update order list to include outer cnodes.
325     UpdateOrderLists();
326     // Find side effects by DFS from the top graph.
327     (void)GetEffectInfo(root_);
328     // Check switch layer calls, add monad arguments if need.
329     HandleSwitchLayerCalls();
330   }
331 
UpdateOrderLists() const332   void UpdateOrderLists() const {
333     // Some cnodes used in current func graph but belong to other func graph, we have to
334     // insert them into order list so that we can handle side effects for them.
335     UpdateOrderList(root_);
336     for (auto &fg : root_->func_graphs_used_total()) {
337       UpdateOrderList(fg);
338     }
339   }
340 
UpdateOrderList(const FuncGraphPtr & func_graph)341   static void UpdateOrderList(const FuncGraphPtr &func_graph) {
342     MS_EXCEPTION_IF_NULL(func_graph);
343     OrderedSet<CNodePtr> new_order_list;
344     const auto &order_list = func_graph->order_list();
345     for (auto &cnode : order_list) {
346       PushToOrderList(func_graph, cnode, &new_order_list);
347     }
348     func_graph->set_order_list(std::move(new_order_list));
349   }
350 
PushToOrderList(const FuncGraphPtr & fg,const CNodePtr & cnode,OrderedSet<CNodePtr> * new_order_list)351   static void PushToOrderList(const FuncGraphPtr &fg, const CNodePtr &cnode, OrderedSet<CNodePtr> *new_order_list) {
352     MS_EXCEPTION_IF_NULL(cnode);
353     MS_EXCEPTION_IF_NULL(new_order_list);
354     if (new_order_list->contains(cnode)) {
355       return;
356     }
357     for (auto &input : cnode->inputs()) {
358       auto input_cnode = dyn_cast<CNode>(input);
359       if (input_cnode != nullptr && input_cnode->func_graph() != fg) {
360         PushToOrderList(fg, input_cnode, new_order_list);
361       }
362     }
363     new_order_list->push_back(cnode);
364   }
365 
366   // Generate SCC map by SccFinder.
GenerateSccMap()367   void GenerateSccMap() {
368     SccFinder scc_finder(root_);
369     scc_finder.Run();
370     scc_map_ = std::move(scc_finder.scc_map());
371   }
372 
373   // Gets branch graph from a switch cnode at given input index.
GetSwitchBranch(const CNodePtr & cnode,size_t index)374   FuncGraphPtr GetSwitchBranch(const CNodePtr &cnode, size_t index) {
375     MS_EXCEPTION_IF_NULL(cnode);
376     return GetValueNode<FuncGraphPtr>(cnode->inputs().at(index));
377   }
378 
379   // Gets branch graphs from a switch cnode.
GetSwitchBranches(const CNodePtr & cnode)380   std::vector<FuncGraphPtr> GetSwitchBranches(const CNodePtr &cnode) {
381     MS_EXCEPTION_IF_NULL(cnode);
382     constexpr size_t switch_cnode_size = 4;
383     constexpr size_t true_index = 2;
384     constexpr size_t false_index = 3;
385     // Check size.
386     if (cnode->size() != switch_cnode_size) {
387       MS_LOG(EXCEPTION) << "Invalid switch: " << cnode->DebugString();
388     }
389     // Add both branches, in some case, only one branch is set.
390     std::vector<FuncGraphPtr> branches;
391     auto true_branch = GetSwitchBranch(cnode, true_index);
392     if (true_branch != nullptr) {
393       branches.emplace_back(true_branch);
394     }
395     auto false_branch = GetSwitchBranch(cnode, false_index);
396     if (false_branch != nullptr) {
397       branches.emplace_back(false_branch);
398     }
399     if (branches.empty()) {
400       MS_LOG(EXCEPTION) << "Invalid switch: " << cnode->DebugString();
401     }
402     return branches;
403   }
404 
405   // Add monad parameter to switch branch graphs.
AddMonadParameters(const std::vector<FuncGraphPtr> & branches,const std::string & name,const AbstractBasePtr & abs)406   void AddMonadParameters(const std::vector<FuncGraphPtr> &branches, const std::string &name,
407                           const AbstractBasePtr &abs) {
408     for (auto &branch : branches) {
409       (void)AddMonadParameter(branch, name, abs);
410     }
411   }
412 
413   // Trace effect info for Switch cnode.
TraceSwitchEffectInfo(const CNodePtr & cnode)414   EffectInfo TraceSwitchEffectInfo(const CNodePtr &cnode) {
415     // Find branches from switch cnode.
416     auto branches = GetSwitchBranches(cnode);
417     // For some case, only one branch is set.
418     if (branches.size() == 1) {
419       auto &branch = branches.front();
420       // Save branch caller, so that we can update arguments for the caller.
421       SaveBranchCaller(cnode, branch);
422       return GetEffectInfo(branch);
423     }
424     // When both branches are set, merge their effect infos.
425     EffectInfo info = MergeEffectInfo(branches);
426     if (info.state == EffectInfo::kDetected) {
427       // Setup both branches according the merged effect info.
428       SetupEffectBranches(info, branches);
429     }
430     return info;
431   }
432 
433   // Trace effect info for SwitchLayer cnode.
TraceSwitchLayerEffectInfo(const CNodePtr & cnode)434   EffectInfo TraceSwitchLayerEffectInfo(const CNodePtr &cnode) {
435     // Find branches from switch_layer cnode.
436     auto branches = GetSwitchLayerBranches(cnode);
437     // Merge effect info from all branches.
438     EffectInfo info = MergeEffectInfo(branches);
439     if (info.state == EffectInfo::kDetected) {
440       // Setup branches according the merged effect info.
441       SetupEffectBranches(info, branches);
442       // Save the switch_layer call, so that we can add monad argument for it if need.
443       auto &call = switch_layer_calls_.emplace_back();
444       call.caller = caller_;
445       call.effect_info = info;
446       call.branches = move(branches);
447     }
448     return info;
449   }
450 
HandleSwitchLayerCalls()451   void HandleSwitchLayerCalls() {
452     for (auto &call : switch_layer_calls_) {
453       const auto &info = call.effect_info;
454       const auto &branches = call.branches;
455       auto new_info = MergeEffectInfo(branches);
456       // Reset branches if effect info changed.
457       if (new_info.memory != info.memory || new_info.load != info.load || new_info.io != info.io) {
458         AddMonadForCaller(call.caller, new_info);
459         SetupEffectBranches(new_info, branches);
460       }
461     }
462   }
463 
464   // Gets branch graphs from a switch_layer cnode.
GetSwitchLayerBranches(const CNodePtr & cnode)465   std::vector<FuncGraphPtr> GetSwitchLayerBranches(const CNodePtr &cnode) {
466     MS_EXCEPTION_IF_NULL(cnode);
467     constexpr size_t func_tuple_index = 2;
468     if (cnode->size() <= func_tuple_index) {
469       MS_LOG(EXCEPTION) << "Invalid switch_layer: " << cnode->DebugString(2);
470     }
471     auto func_tuple = cnode->inputs().at(func_tuple_index);
472     return GetGraphsFromTuple(func_tuple);
473   }
474 
475   // Get and trace graphs from a tuple of func node for switch_layer.
GetGraphsFromTuple(const AnfNodePtr & func_tuple)476   std::vector<FuncGraphPtr> GetGraphsFromTuple(const AnfNodePtr &func_tuple) {
477     // The func tuple maker.
478     if (IsPrimitiveCNode(func_tuple, prim::kPrimMakeTuple)) {
479       return GetGraphsFromMakeTuple(func_tuple->cast<CNodePtr>());
480     }
481     // Trace tuple from parameter.
482     auto para = dyn_cast<Parameter>(func_tuple);
483     if (para != nullptr) {
484       std::vector<FuncGraphPtr> graphs;
485       ForEachRealArguments(para,
486                            [this, &graphs](const AnfNodePtr &arg) { graphs = std::move(GetGraphsFromTuple(arg)); });
487       return graphs;
488     }
489     // Trace tuple returned from func graph call.
490     auto cnode = dyn_cast<CNode>(func_tuple);
491     auto func_graph = GetFuncGraph(cnode);
492     if (func_graph != nullptr) {
493       return GetGraphsFromTuple(func_graph->output());
494     }
495     MS_LOG(EXCEPTION) << "Invalid input for switch_layer: func_graph is nullptr.";
496   }
497 
498   // Get graphs from a tuple of funcs make node for switch_layer.
GetGraphsFromMakeTuple(const CNodePtr & make_tuple)499   std::vector<FuncGraphPtr> GetGraphsFromMakeTuple(const CNodePtr &make_tuple) {
500     MS_EXCEPTION_IF_NULL(make_tuple);
501     auto &inputs = make_tuple->inputs();
502     if (inputs.size() <= 1) {
503       MS_LOG(EXCEPTION) << "Invalid make_tuple for switch_layer: " << make_tuple->DebugString(2);
504     }
505     std::vector<FuncGraphPtr> graphs;
506     graphs.reserve(inputs.size() - 1);
507     for (size_t i = 1; i < inputs.size(); ++i) {
508       auto func_graph = GetValueNode<FuncGraphPtr>(inputs.at(i));
509       if (func_graph == nullptr) {
510         MS_LOG(WARNING) << "Non-graph found in switch_layer input: " << make_tuple->DebugString(2) << " index=" << i;
511         continue;
512       }
513       graphs.push_back(func_graph);
514     }
515     return graphs;
516   }
517 
518   // Trace effect info from tuple_getitem cnode.
TraceTupleGetItemEffectInfo(const CNodePtr & cnode,std::stack<int64_t> * tuple_indexes)519   EffectInfo TraceTupleGetItemEffectInfo(const CNodePtr &cnode, std::stack<int64_t> *tuple_indexes) {
520     constexpr size_t tuple_input = 1;
521     constexpr size_t index_input = 2;
522     constexpr size_t cnode_size = 3;
523     if (cnode->size() != cnode_size) {
524       MS_LOG(EXCEPTION) << "Invalid tuple_getitem: " << cnode->DebugString();
525     }
526     // Get item index.
527     auto &index_node = cnode->inputs().at(index_input);
528     auto index_value = GetValueNode<Int64ImmPtr>(index_node);
529     if (index_value == nullptr) {
530       MS_LOG(EXCEPTION) << "Tuple_getitem with non-const index " << cnode->DebugString();
531     }
532     int64_t index = index_value->value();
533 
534     // Get tuple value.
535     const auto &tuple_node = cnode->inputs().at(tuple_input);
536     // Push tuple index.
537     tuple_indexes->push(index);
538     return TraceTupleEffectInfo(tuple_node, tuple_indexes);
539   }
540 
TraceTupleEffectInfo(const AnfNodePtr & tuple_node,std::stack<int64_t> * tuple_indexes)541   EffectInfo TraceTupleEffectInfo(const AnfNodePtr &tuple_node, std::stack<int64_t> *tuple_indexes) {
542     MS_EXCEPTION_IF_NULL(tuple_indexes);
543     auto para = dyn_cast<Parameter>(tuple_node);
544     if (para != nullptr) {
545       return TraceTupleParaEffectInfo(para, *tuple_indexes);
546     }
547     auto tuple_cnode = dyn_cast<CNode>(tuple_node);
548     if (tuple_cnode != nullptr) {
549       return TraceTupleCNodeEffectInfo(tuple_cnode, tuple_indexes);
550     }
551     // Should not reach here.
552     MS_LOG(EXCEPTION) << "Side effects untraceable: tuple_cnode is nullptr.";
553   }
554 
TraceTupleParaEffectInfo(const ParameterPtr & para,const std::stack<int64_t> & tuple_indexes)555   EffectInfo TraceTupleParaEffectInfo(const ParameterPtr &para, const std::stack<int64_t> &tuple_indexes) {
556     EffectInfo info{EffectInfo::kDetected, false, false, false};
557     ForEachRealArguments(para, [this, &info, tuple_indexes](const AnfNodePtr &arg) {
558       // Merge real argument effect info.
559       auto tuple_indexes_copy = tuple_indexes;
560       auto arg_info = TraceTupleEffectInfo(arg, &tuple_indexes_copy);
561       info.Merge(arg_info);
562     });
563     return info;
564   }
565 
TraceTupleCNodeEffectInfo(const CNodePtr & cnode,std::stack<int64_t> * tuple_indexes)566   EffectInfo TraceTupleCNodeEffectInfo(const CNodePtr &cnode, std::stack<int64_t> *tuple_indexes) {
567     MS_EXCEPTION_IF_NULL(tuple_indexes);
568     MS_EXCEPTION_IF_NULL(cnode);
569     auto prim = GetPrimitive(cnode);
570     // Trace MakeTuple.
571     if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple)) {
572       if (tuple_indexes->empty()) {
573         MS_LOG(EXCEPTION) << "Unexpected make_tuple: " << cnode->DebugString(2);
574         return {EffectInfo::kDetected, false, false, false};
575       }
576       // Pop out tuple index.
577       auto top_index = tuple_indexes->top();
578       tuple_indexes->pop();
579       size_t input_index = 0;
580       // Support tuple index is negative
581       if (top_index < 0) {
582         if (SizeToLong(cnode->size()) + top_index < 0) {
583           MS_LOG(EXCEPTION) << "Invalid make_tuple: " << cnode->DebugString() << " index=" << top_index;
584         }
585         input_index = static_cast<size_t>(cnode->size() + top_index);
586       } else {
587         // Follow the tuple item according the index.
588         input_index = static_cast<size_t>(top_index) + 1;
589       }
590       if (input_index >= cnode->size()) {
591         MS_LOG(EXCEPTION) << "Invalid make_tuple: " << cnode->DebugString() << " index=" << top_index;
592       }
593       if (tuple_indexes->empty()) {
594         // Trace non-tuple.
595         return TraceEffectInfo(cnode->inputs().at(input_index));
596       }
597       // This is the tuple of tuple case.
598       return TraceTupleEffectInfo(cnode->inputs().at(input_index), tuple_indexes);
599     }
600     // Trace TupleGetItem (tuple of tuple).
601     if (IsPrimitiveEquals(prim, prim::kPrimTupleGetItem)) {
602       return TraceTupleGetItemEffectInfo(cnode, tuple_indexes);
603     }
604     // Trace primitive propagating side effect from its input, such as Depend, Identity, etc.
605     int input_index = GetSideEffectPropagate(prim);
606     if (input_index > 0 && input_index < static_cast<int>(cnode->size())) {
607       return TraceTupleEffectInfo(cnode->input(static_cast<size_t>(input_index)), tuple_indexes);
608     }
609     // Tuple returned from func graph call.
610     auto func_graph = GetFuncGraph(cnode);
611     if (func_graph != nullptr) {
612       return TraceTupleEffectInfo(func_graph->output(), tuple_indexes);
613     }
614     // Tuple returned from a Switch call.
615     if (cnode->size() == 1 && IsPrimitiveCNode(cnode->input(0), prim::kPrimSwitch)) {
616       return TraceTupleFromSwitch(cnode->input(0)->cast<CNodePtr>(), *tuple_indexes);
617     }
618     // Tuple is returned from J().
619     //   %1 = J(primal)
620     //   tuple = %1(args)
621     if (cnode->size() > 0 && IsPrimitiveCNode(cnode->input(0), prim::kPrimJ)) {
622       MS_LOG(DEBUG) << "Tuple from J: " << cnode->DebugString(2);
623       return {EffectInfo::kDetected, false, false, false};
624     }
625     // Rare case.
626     MS_LOG(WARNING) << "Tuple untraceable from: " << cnode->DebugString(2);
627     return {EffectInfo::kDetected, false, false, false};
628   }
629 
630   // Trace effect info from a Switch node that output is a tuple.
TraceTupleFromSwitch(const CNodePtr & switch_cnode,const std::stack<int64_t> & tuple_indexes)631   EffectInfo TraceTupleFromSwitch(const CNodePtr &switch_cnode, const std::stack<int64_t> &tuple_indexes) {
632     auto branches = GetSwitchBranches(switch_cnode);
633     EffectInfo info = {EffectInfo::kDetected, false, false, false};
634     for (auto &branch : branches) {
635       auto tuple_indexes_copy = tuple_indexes;
636       EffectInfo branch_info = TraceTupleEffectInfo(branch->output(), &tuple_indexes_copy);
637       info.Merge(branch_info);
638     }
639     return info;
640   }
641 
642   // Setup all branches according the effect info.
SetupEffectBranches(const EffectInfo & info,const std::vector<FuncGraphPtr> & branches)643   void SetupEffectBranches(const EffectInfo &info, const std::vector<FuncGraphPtr> &branches) {
644     // Setup monad parameters for all branches according the effect info.
645     if (info.memory || info.load) {
646       AddMonadParameters(branches, "u", kUMonad->ToAbstract());
647     }
648     if (info.io) {
649       AddMonadParameters(branches, "io", kIOMonad->ToAbstract());
650     }
651     // Set merged effect info to both branches.
652     for (auto &branch : branches) {
653       MS_EXCEPTION_IF_NULL(branch);
654       branch->SetEffectInfo(info);
655       // Update caller if it is existed.
656       UpdateBranchCaller(branch);
657     }
658   }
659 
660   // Merge effect info for switch or switch_layer branch graphs.
MergeEffectInfo(const std::vector<FuncGraphPtr> & branches)661   EffectInfo MergeEffectInfo(const std::vector<FuncGraphPtr> &branches) {
662     EffectInfo info = {EffectInfo::kDetected, false, false, false};
663     for (auto &branch : branches) {
664       MS_EXCEPTION_IF_NULL(branch);
665       EffectInfo branch_info = GetEffectInfo(branch);
666       info.Merge(branch_info);
667     }
668     return info;
669   }
670 
671   // Trace a cnode for effect info.
TraceEffectInfo(const CNodePtr & cnode)672   EffectInfo TraceEffectInfo(const CNodePtr &cnode) {
673     MS_EXCEPTION_IF_NULL(cnode);
674     auto prim = GetPrimitive(cnode);
675     if (IsPrimitiveEquals(prim, prim::kPrimSwitch)) {
676       // Special handling for Switch primitive.
677       return TraceSwitchEffectInfo(cnode);
678     }
679 
680     if (IsPrimitiveEquals(prim, prim::kPrimSwitchLayer)) {
681       // Special handling for SwitchLayer primitive.
682       return TraceSwitchLayerEffectInfo(cnode);
683     }
684 
685     if (IsPrimitiveEquals(prim, prim::kPrimTupleGetItem)) {
686       // Trace tuple_getitem.
687       std::stack<int64_t> tuple_indexes;
688       return TraceTupleGetItemEffectInfo(cnode, &tuple_indexes);
689     }
690 
691     // For high-order pritimive such as Partial,
692     // we trace effect info from its argument.
693     int index_prim = GetSideEffectPropagate(prim);
694     if (index_prim > 0 && index_prim < static_cast<int>(cnode->size())) {
695       return TraceEffectInfo(cnode->input(static_cast<size_t>(index_prim)));
696     }
697 
698     // For func graph calls, we trace effect info from graph output.
699     auto called_graph = GetFuncGraph(cnode);
700     if (called_graph) {
701       return TraceEffectInfo(called_graph->output());
702     }
703 
704     //
705     // For ClassType as the input[0], if it is a primitive class
706     // with 'side_effect_propagate' attribute, we trace side effect
707     // from its argument indxed by the attribute value.
708     //
709     // e.g.:
710     //     setpara = P.Partial()(P.Assign, self.para)
711     //     setpara(x)
712     //
713     auto class_type = GetClassType(cnode);
714     if (class_type) {
715       int index = GetSideEffectPropagate(class_type);
716       if (index > 0 && index < static_cast<int>(cnode->size())) {
717         return TraceEffectInfo(cnode->input(static_cast<size_t>(index)));
718       }
719     }
720 
721     // Otherwise, no side effect found and stop trace.
722     return {EffectInfo::kDetected, false, false, false};
723   }
724 
725   // Trace an ANFNode for effect info.
TraceEffectInfo(const AnfNodePtr & node)726   EffectInfo TraceEffectInfo(const AnfNodePtr &node) {
727     if (node) {
728       // Trace cnode.
729       auto cnode = node->cast<CNodePtr>();
730       if (cnode) {
731         return TraceEffectInfo(cnode);
732       }
733 
734       // Trace parameter.
735       auto para = node->cast<ParameterPtr>();
736       if (para) {
737         return TraceEffectInfo(para);
738       }
739 
740       // Trace primitive.
741       auto prim = GetPrimitive(node);
742       if (prim) {
743         return GetPrimEffectInfo(prim);
744       }
745 
746       // Trace func graph.
747       auto value_node = node->cast<ValueNodePtr>();
748       if (value_node && value_node->value()) {
749         auto graph = value_node->value()->cast<FuncGraphPtr>();
750         if (graph) {
751           return GetEffectInfo(graph);
752         }
753       }
754     }
755     // Something is wrong if we reached here.
756     MS_LOG(WARNING) << "EffectInfo untraceable: node is a nullptr.";
757     return {EffectInfo::kDetected, false, false, false};
758   }
759 
GetParameterIndex(const FuncGraphPtr & func_graph,const ParameterPtr & para)760   int GetParameterIndex(const FuncGraphPtr &func_graph, const ParameterPtr &para) {
761     int parameter_index = 0;
762     for (auto &parameter : func_graph->parameters()) {
763       if (para == parameter) {
764         return parameter_index;
765       }
766       ++parameter_index;
767     }
768     MS_LOG(EXCEPTION) << "Parameter not found: " << (para ? para->DebugString() : "<null>");
769   }
770 
771   // Trace effect info from function parameter.
TraceEffectInfo(const ParameterPtr & para)772   EffectInfo TraceEffectInfo(const ParameterPtr &para) {
773     EffectInfo info{EffectInfo::kDetected, false, false, false};
774     ForEachRealArguments(para, [this, &info](const AnfNodePtr &arg) {
775       // Merge caller input effect info.
776       auto input_info = TraceEffectInfo(arg);
777       info.Merge(input_info);
778     });
779     return info;
780   }
781 
ForEachRealArguments(const ParameterPtr & para,const std::function<void (const AnfNodePtr &)> & handler)782   void ForEachRealArguments(const ParameterPtr &para, const std::function<void(const AnfNodePtr &)> &handler) {
783     MS_EXCEPTION_IF_NULL(para);
784     auto func_graph = para->func_graph();
785     MS_EXCEPTION_IF_NULL(func_graph);
786     // Find index of the parameter, starts from 0.
787     const int para_index = GetParameterIndex(func_graph, para);
788     const size_t input_index = static_cast<size_t>(para_index) + 1;
789     // Search user cnodes of the func graph.
790     auto &users = func_graph->func_graph_cnodes_index();
791     if (users.empty()) {
792       MS_LOG(WARNING) << "Unused graph for parameter " << para->DebugString();
793     }
794     for (auto &user : users) {
795       auto use_index = user.first->second;
796       if (use_index != 0) {
797         // Skip non-caller usage.
798         continue;
799       }
800       // Caller cnode.
801       auto cnode = dyn_cast<CNode>(user.first->first);
802       MS_EXCEPTION_IF_NULL(cnode);
803       if (cnode && input_index < cnode->size()) {
804         auto &real_arg = cnode->input(input_index);
805         if (real_arg == para) {
806           // Skip if the real argument is the given parameter.
807           continue;
808         }
809         handler(real_arg);
810       }
811     }
812   }
813 
814   // For call node, returns effect info of the callee graph.
GetCallEffectInfo(const CNodePtr & cnode)815   EffectInfo GetCallEffectInfo(const CNodePtr &cnode) {
816     MS_EXCEPTION_IF_NULL(cnode);
817     constexpr size_t min_call_node_size = 2;
818     if (cnode->size() < min_call_node_size) {
819       MS_LOG(EXCEPTION) << "Invalid call node: " << cnode->DebugString();
820     }
821     auto func_graph = GetValueNode<FuncGraphPtr>(cnode->inputs().at(1));
822     if (func_graph == nullptr) {
823       MS_LOG(EXCEPTION) << "Invalid call node: " << cnode->DebugString();
824     }
825     return GetEffectInfo(func_graph);
826   }
827 
828   // Detect effect info by depth first search.
DetectEffectInfo(const CNodePtr & cnode)829   EffectInfo DetectEffectInfo(const CNodePtr &cnode) {
830     // For primitive, get effect info from its attributes and inputs.
831     auto prim = GetPrimitive(cnode);
832     if (prim) {
833       // Skip 'return' cnode.
834       if (IsPrimitiveEquals(prim, prim::kPrimReturn)) {
835         return {EffectInfo::kDetected, false, false, false};
836       }
837       // Special handling for 'call' cnode.
838       if (IsPrimitiveEquals(prim, prim::kPrimCall)) {
839         return GetCallEffectInfo(cnode);
840       }
841       auto info = GetPrimEffectInfo(prim);
842       if (!info.memory && !IsKeepRef(prim)) {
843         // For primitive calls, if no memory effects but
844         // Ref parameter used, we will insert 'load' before them.
845         // Except for primitives like J(f) or Partial(f, x) which propagate side effect,
846         // load is inserted inside the func_graph f.
847         info.load = HasRefInput(cnode);
848       }
849       return info;
850     }
851 
852     // For func graph, detect effect info by its children cnodes.
853     auto func_graph = GetFuncGraph(cnode);
854     if (func_graph) {
855       return GetEffectInfo(func_graph);
856     }
857 
858     // When input[0] is a cnode, it is a function returned from
859     // a high-order function call, we trace it by return value.
860     auto func_cnode = GetFuncCNode(cnode);
861     if (func_cnode) {
862       caller_ = cnode;
863       return TraceEffectInfo(func_cnode);
864     }
865 
866     // When input[0] is a parameter, it is a function parameter for
867     // the high-order function, we trace it by caller.
868     auto func_para = GetFuncParameter(cnode);
869     if (func_para) {
870       return TraceEffectInfo(func_para);
871     }
872 
873     // When input[0] is a MultitypeFuncGraph, it's not specialized
874     // as one of its parameters is AbstractUndertermined,
875     // This MultitypeFuncGraph may be specialized at next Renormalize
876     // process, but we have to keep the order by insert UMonad now,
877     // otherwise order will be lost in next Renormalize.
878     // So assume it has memory side effect conservatively.
879     auto func_multitype = GetFuncMultitypeFuncGraph(cnode);
880     if (func_multitype) {
881       MS_LOG(DEBUG) << "Assume memory side effect for: " << cnode->DebugString();
882       return {EffectInfo::kDetected, true, false, false};
883     }
884 
885     MS_LOG(WARNING) << "Side effect undetectable: " << cnode->DebugString(2);
886     return {EffectInfo::kDetected, false, false, false};
887   }
888 
889   // Gets EffectInfo for CNode.
GetEffectInfo(const CNodePtr & cnode)890   EffectInfo GetEffectInfo(const CNodePtr &cnode) {
891     const auto &effect_info = cnode->GetEffectInfo();
892     if (effect_info.state == EffectInfo::kDetected) {
893       // Effect info already detected, return it.
894       return effect_info;
895     }
896 
897     // Detect effect info for the cnode.
898     EffectInfo info = DetectEffectInfo(cnode);
899     if (info.state == EffectInfo::kDetected) {
900       // Save detected info into cnode.
901       cnode->SetEffectInfo(info);
902     }
903     return info;
904   }
905 
906   // Gets SCC that the given graph belongs to.
GetScc(const FuncGraphPtr & func_graph) const907   const SccPtr &GetScc(const FuncGraphPtr &func_graph) const {
908     auto found = scc_map_.find(func_graph);
909     if (found == scc_map_.end()) {
910       MS_LOG(EXCEPTION) << "SCC not found for " << (func_graph ? func_graph->ToString() : "FG(null)");
911     }
912     return found->second;
913   }
914 
915   // Set effect info for all member graphs in the SCC.
SetSccEffectInfo(const SccPtr & scc,const EffectInfo & info) const916   void SetSccEffectInfo(const SccPtr &scc, const EffectInfo &info) const {
917     MS_EXCEPTION_IF_NULL(scc);
918     for (auto &g : *scc) {
919       MS_EXCEPTION_IF_NULL(g);
920       g->SetEffectInfo(info);
921     }
922   }
923 
924   // Gets EffectInfo for func graph.
GetEffectInfo(const FuncGraphPtr & func_graph)925   EffectInfo GetEffectInfo(const FuncGraphPtr &func_graph) {
926     MS_EXCEPTION_IF_NULL(func_graph);
927     const auto &effect_info = func_graph->GetEffectInfo();
928     if (effect_info.state != EffectInfo::kUnknown) {
929       // Effect info already set, return it.
930       return effect_info;
931     }
932     // Get SCC that this graph belongs to.
933     auto &scc = GetScc(func_graph);
934     MS_EXCEPTION_IF_NULL(scc);
935     // To prevent SCC members be visited again, we set effect info
936     // to 'kDetecting' state before start to check cnodes.
937     EffectInfo info{EffectInfo::kDetecting, false, false, false};
938     SetSccEffectInfo(scc, info);
939     // Check side effects for all cnodes in the SCC.
940     std::vector<CNodePtr> undetected;
941     for (auto &g : *scc) {
942       MS_EXCEPTION_IF_NULL(g);
943       for (auto &cnode : g->order_list()) {
944         auto cnode_effect = GetEffectInfo(cnode);
945         if (cnode_effect.state != EffectInfo::kDetected) {
946           // For side effect undetected node, it could be a call to the SCC member graph,
947           // we will try to check side effect again after SCC side effect detected.
948           undetected.push_back(cnode);
949         }
950         // Merge effect info from the node.
951         info.Merge(cnode_effect);
952       }
953       // Make sure all sub-graphs is checked. since some sub-graphs may not directly called,
954       // for example: return ValueNode(sub_graph).
955       for (auto &sg : g->func_graphs_used()) {
956         (void)GetEffectInfo(sg.first);
957       }
958     }
959     // Update effect into for all members of the SCC.
960     info.state = EffectInfo::kDetected;
961     SetSccEffectInfo(scc, info);
962     // Check undetected cnodes again after side effect of the SCC is detected.
963     for (auto &cnode : undetected) {
964       MS_EXCEPTION_IF_NULL(cnode);
965       auto cnode_effect = GetEffectInfo(cnode);
966       // Side effect should be detected now.
967       if (cnode_effect.state != EffectInfo::kDetected) {
968         MS_LOG(EXCEPTION) << "Side effect is undectable: " << cnode->DebugString();
969       }
970     }
971     // graph which need PipelineSplit doesn't have effect.
972     if (func_graph->stage() != -1) {
973       info.memory = false;
974       info.load = false;
975       info.io = false;
976     }
977     return info;
978   }
979 
SaveBranchCaller(const CNodePtr & switch_node,const FuncGraphPtr & branch)980   void SaveBranchCaller(const CNodePtr &switch_node, const FuncGraphPtr &branch) {
981     MS_EXCEPTION_IF_NULL(branch);
982     MS_EXCEPTION_IF_NULL(switch_node);
983     auto manager = branch->manager();
984     MS_EXCEPTION_IF_NULL(manager);
985     auto &node_users = manager->node_users();
986     auto found = node_users.find(switch_node);
987     if (found == node_users.end()) {
988       MS_LOG(WARNING) << "Caller not found for " << switch_node->DebugString();
989       return;
990     }
991     if (found->second.size() != 1) {
992       MS_LOG(WARNING) << "Wrong callers " << found->second.size() << " for " << switch_node->DebugString();
993       return;
994     }
995     auto &user = *found->second.begin();
996     auto cnode = dyn_cast<CNode>(user.first);
997     if (cnode != nullptr || user.second == 0) {
998       branch_caller_map.emplace(branch, cnode);
999     }
1000   }
1001 
UpdateBranchCaller(const FuncGraphPtr & branch)1002   void UpdateBranchCaller(const FuncGraphPtr &branch) {
1003     MS_EXCEPTION_IF_NULL(branch);
1004     auto iter = branch_caller_map.find(branch);
1005     if (iter == branch_caller_map.end()) {
1006       return;
1007     }
1008     const auto &caller = iter->second;
1009     const auto &info = branch->GetEffectInfo();
1010     AddMonadForCaller(caller, info);
1011   }
1012 
AddMonadForCaller(const CNodePtr & caller,const EffectInfo & info)1013   void AddMonadForCaller(const CNodePtr &caller, const EffectInfo &info) {
1014     if (info.memory || info.load) {
1015       // Add u monad argument to caller if need.
1016       AddMonadArgument(caller, kUMonad);
1017     }
1018     if (info.io) {
1019       // Add io monad argument to caller if need.
1020       AddMonadArgument(caller, kIOMonad);
1021     }
1022   }
1023 
AddMonadArgument(const CNodePtr & cnode,const ValuePtr & monad)1024   void AddMonadArgument(const CNodePtr &cnode, const ValuePtr &monad) {
1025     MS_EXCEPTION_IF_NULL(cnode);
1026     MS_EXCEPTION_IF_NULL(monad);
1027     auto monad_abs = monad->ToAbstract();
1028     for (size_t i = 1; i < cnode->size(); ++i) {
1029       auto abs = cnode->inputs().at(i)->abstract();
1030       if (abs != nullptr && *abs == *monad_abs) {
1031         // Skip if monad argument already existed.
1032         return;
1033       }
1034     }
1035     // Add monad argument if not yet.
1036     auto monad_input = NewValueNode(monad);
1037     monad_input->set_abstract(monad_abs);
1038     if ((monad == kUMonad) && cnode->size() > 1 && HasAbstractIOMonad(cnode->inputs().back())) {
1039       // Insert u monad before io monad.
1040       size_t last_index = cnode->size() - 1;
1041       cnode->add_input(cnode->input(last_index));
1042       cnode->set_input(last_index, monad_input);
1043     } else {
1044       // Add monad as the last input.
1045       cnode->add_input(monad_input);
1046     }
1047   }
1048 
1049   // The root graph.
1050   FuncGraphPtr root_;
1051 
1052   // SCC map.
1053   SccMap scc_map_;
1054 
1055   // Single branch (in switch) and its caller cnode.
1056   std::map<FuncGraphPtr, CNodePtr> branch_caller_map;
1057 
1058   // Current high order func caller cnode.
1059   CNodePtr caller_ = nullptr;
1060 
1061   // switch_layer_calls save all switch_layer calls, so that
1062   // we can check whether monad argument should be added for them.
1063   std::vector<SwitchLayerCall> switch_layer_calls_;
1064 };  // class SideEffectFinder
1065 
1066 // --------------------------------------------------------------------
1067 // AutoMonadConverter converts side-effect cnodes into monad form.
1068 // --------------------------------------------------------------------
1069 class AutoMonadConverter {
1070  public:
Handle(const FuncGraphPtr & func_graph,bool top)1071   static bool Handle(const FuncGraphPtr &func_graph, bool top) {
1072     AutoMonadConverter converter(func_graph, top);
1073     return converter.Run();
1074   }
1075 
1076  private:
AutoMonadConverter(const FuncGraphPtr & func_graph,bool top)1077   AutoMonadConverter(const FuncGraphPtr &func_graph, bool top)
1078       : func_graph_(func_graph), manager_(func_graph->manager()), top_(top) {}
1079 
1080   ~AutoMonadConverter() = default;
1081 
Run()1082   bool Run() {
1083     // Handle cnodes for side effects.
1084     const auto &info = func_graph_->GetEffectInfo();
1085     if (info.state == EffectInfo::kDetected) {
1086       HandleCNodes();
1087     }
1088 
1089     // Safe to clear isolated nodes after handled side effect nodes.
1090     ClearIsolatedNodes();
1091 
1092     // Clean up after conversion finished.
1093     func_graph_->ClearOrderList();
1094     return has_effect_cnodes_;
1095   }
1096 
1097   // Check if there are side effects from effect info.
HasSideEffects(const EffectInfo & info)1098   static bool HasSideEffects(const EffectInfo &info) { return (info.memory || info.io || info.load); }
1099 
1100   // Gets effect info for a cnode.
GetEffectInfo(const CNodePtr & cnode) const1101   const EffectInfo &GetEffectInfo(const CNodePtr &cnode) const {
1102     MS_EXCEPTION_IF_NULL(cnode);
1103     auto &effect_info = cnode->GetEffectInfo();
1104     if (effect_info.state != EffectInfo::kDetected) {
1105       // Effect info should have been set by SideEffectFinder.
1106       MS_LOG(EXCEPTION) << "Side effects not detected: " << cnode->DebugString();
1107     }
1108     return effect_info;
1109   }
1110 
1111   // Handle CNodes for side effects.
HandleCNodes()1112   void HandleCNodes() {
1113     // Check whether UpdateState and Depend are required.
1114     bool update_state = NeedUpdateState();
1115 
1116     // Check all cnodes in order list.
1117     for (auto &cnode : func_graph_->order_list()) {
1118       auto &info = GetEffectInfo(cnode);
1119       has_effect_cnodes_ = (has_effect_cnodes_ || HasSideEffects(info));
1120       if (cnode->func_graph() != func_graph_) {
1121         // Handle outer cnode.
1122         HandleOuterNode(cnode, info);
1123       } else {
1124         // Handle cnode with memory side effects.
1125         if (info.memory) {
1126           HandleMemoryEffects(cnode, update_state);
1127         } else if (info.load) {
1128           // If no memory side effects, handle load if need.
1129           HandleLoad(cnode, update_state);
1130         }
1131         // Handle cnode with IO side effects.
1132         if (info.io) {
1133           HandleIoEffects(cnode, update_state);
1134         }
1135         // If the node has no side effects but 'no_eliminate' flag is set,
1136         // we save it to no_eliminate_nodes and handle them late.
1137         if (!info.memory && !info.io && IsNoEliminateNode(cnode)) {
1138           no_eliminate_nodes_.emplace_back(cnode);
1139         }
1140       }
1141       cnode->SetEffectHandled(true);
1142     }
1143     // Attach no eliminate nodes to output.
1144     HandleNoEliminateNodes();
1145     // Attach monad to output if required.
1146     if (update_state) {
1147       AttachMonadToOutput();
1148     }
1149   }
1150 
1151   // Return true if the given cnode is primitive cnode with 'no_eliminate' flag.
IsNoEliminateNode(const CNodePtr & cnode)1152   bool IsNoEliminateNode(const CNodePtr &cnode) {
1153     if (cnode == nullptr || cnode->size() == 0) {
1154       return false;
1155     }
1156     auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
1157     if (prim == nullptr) {
1158       return false;
1159     }
1160     return GetPrimitiveFlag(prim, ATTR_NO_ELIMINATE);
1161   }
1162 
1163   // Attach no eliminate nodes to output.
HandleNoEliminateNodes()1164   void HandleNoEliminateNodes() {
1165     if (no_eliminate_nodes_.empty()) {
1166       // Skip if no nodes to be handled.
1167       return;
1168     }
1169     // If only one node, attach it to output directly.
1170     if (no_eliminate_nodes_.size() == 1) {
1171       AttachToOutput(no_eliminate_nodes_.front());
1172       return;
1173     }
1174     // For multiple nodes, attach them to output by a tuple.
1175     std::vector<AnfNodePtr> tuple_inputs;
1176     AbstractBasePtrList element_abstracts;
1177     tuple_inputs.reserve(no_eliminate_nodes_.size() + 1);
1178     element_abstracts.reserve(no_eliminate_nodes_.size());
1179     tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
1180     for (auto &node : no_eliminate_nodes_) {
1181       tuple_inputs.emplace_back(node);
1182       element_abstracts.emplace_back(node->abstract());
1183     }
1184     auto make_tuple_node = func_graph_->NewCNode(tuple_inputs);
1185     make_tuple_node->set_abstract(std::make_shared<abstract::AbstractTuple>(element_abstracts));
1186     AttachToOutput(make_tuple_node);
1187   }
1188 
1189   // Clean no side effect dependency nodes.
1190   //   From:  output = Depend(output, StopGrad)
1191   //          return output
1192   //
1193   //   To:    return output
ClearIsolatedNodes() const1194   void ClearIsolatedNodes() const {
1195     auto output = GetGraphOutput();
1196     if (IsPrimitiveCNode(output, prim::kPrimDepend) &&
1197         IsPrimitiveCNode(output->cast<CNodePtr>()->input(2), prim::kPrimStopGradient)) {
1198       // Replace Depend(orig_output, StopGrad) node with orig_output.
1199       // After that, nodes may be eliminated if have no side effects.
1200       auto &orig_output = output->cast<CNodePtr>()->input(1);
1201       func_graph_->set_output(orig_output);
1202     }
1203   }
1204 
HandleOuterNode(const CNodePtr & cnode,const EffectInfo & info)1205   void HandleOuterNode(const CNodePtr &cnode, const EffectInfo &info) {
1206     MS_EXCEPTION_IF_NULL(cnode);
1207     if (info.memory || info.load) {
1208       (void)GetUniverse();
1209       bool load_with_primitive = (info.load && IsPrimitiveCNode(cnode));
1210       if (!cnode->IsEffectHandled() && !load_with_primitive) {
1211         auto u_node = NewValueNode(kUMonad);
1212         u_node->set_abstract(kUMonad->ToAbstract());
1213         cnode->add_input(u_node);
1214       }
1215     }
1216     if (info.io) {
1217       (void)GetIoState();
1218       if (!cnode->IsEffectHandled()) {
1219         auto io = NewValueNode(kIOMonad);
1220         io->set_abstract(kIOMonad->ToAbstract());
1221         cnode->add_input(io);
1222       }
1223     }
1224   }
1225 
1226   //
1227   // Convert cnode with memory side effect to monad form,
1228   // from:
1229   //    output = func(input)
1230   // to:
1231   //    output = func(input, u)
1232   //    u = UpdateState(u, output) # if update_state is true
1233   //
HandleMemoryEffects(const CNodePtr & cnode,bool update_state)1234   void HandleMemoryEffects(const CNodePtr &cnode, bool update_state) {
1235     const auto &u = GetUniverse();
1236     AddMonadInput(cnode, u);
1237     if (update_state) {
1238       u_ = UpdateState(u, cnode);
1239     }
1240   }
1241 
1242   //
1243   // Convert cnode with io side effect to monad form,
1244   // from:
1245   //    output = func(input)
1246   // to:
1247   //    output = func(input, io)
1248   //    io = UpdateState(io, output) # if update_state is true
1249   //
HandleIoEffects(const CNodePtr & cnode,bool update_state)1250   void HandleIoEffects(const CNodePtr &cnode, bool update_state) {
1251     const auto &io = GetIoState();
1252     AddMonadInput(cnode, io);
1253     if (update_state) {
1254       io_ = UpdateState(io, cnode);
1255     }
1256   }
1257 
HandleLoad(const CNodePtr & cnode,bool update_state)1258   void HandleLoad(const CNodePtr &cnode, bool update_state) {
1259     MS_EXCEPTION_IF_NULL(cnode);
1260     auto value = GetValueNode(cnode->input(0));
1261     if (value && value->isa<Primitive>()) {
1262       // For primitive calls that use Ref as input, insert Loads before them.
1263       InsertLoads(cnode, update_state);
1264     } else {
1265       // For non-primitive calls, load is used inside the callee,
1266       // We do not insert load for it but handle it as a side
1267       // effects cnode.
1268       HandleMemoryEffects(cnode, update_state);
1269     }
1270   }
1271 
1272   //
1273   // Insert Loads for a primitive cnode that use Ref as input.
1274   // for example, from:
1275   //    out = Prim(self.para1, self.para2, other_args)
1276   // to:
1277   //    p1 = Load(self.para1, u)
1278   //    p2 = Load(self.para2, u)
1279   //    t = make_tuple(p1, p2) # if update_state
1280   //    u1 = UpdateState(u, t)   # is required
1281   //    out = Prim(p1, p2, other_args)
1282   //
InsertLoads(const CNodePtr & cnode,bool update_state)1283   void InsertLoads(const CNodePtr &cnode, bool update_state) {
1284     // Find ref inputs.
1285     auto ref_inputs = GetRefInputs(cnode);
1286     if (ref_inputs.empty()) {
1287       MS_LOG(WARNING) << "Ref input not found for load insertion: " << cnode->DebugString();
1288       return;
1289     }
1290     // Current u monad.
1291     auto current_u = GetUniverse();
1292     // Create Load cnodes.
1293     auto loads = MakeLoads(cnode, ref_inputs, current_u);
1294     if (loads.empty() || !update_state) {
1295       // Skip UpdateState insertion.
1296       return;
1297     }
1298     // Insert UpdateState if required.
1299     if (loads.size() == 1) {
1300       // One Load, no make_tuple needed.
1301       u_ = UpdateState(current_u, loads.front());
1302       return;
1303     }
1304     // Multiple Loads, Create a MakeTuple before UpdateState.
1305     abstract::AbstractBasePtrList load_abstracts;
1306     std::transform(loads.begin(), loads.end(), std::back_inserter(load_abstracts),
1307                    [](const AnfNodePtr &load) { return load->abstract(); });
1308     loads.insert(loads.begin(), NewValueNode(prim::kPrimMakeTuple));
1309     auto make_tuple = func_graph_->NewCNode(loads);
1310     make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(load_abstracts));
1311     u_ = UpdateState(current_u, make_tuple);
1312   }
1313 
MakeLoads(const CNodePtr & cnode,const RefInputs & ref_inputs,const AnfNodePtr & u)1314   std::vector<AnfNodePtr> MakeLoads(const CNodePtr &cnode, const RefInputs &ref_inputs, const AnfNodePtr &u) {
1315     std::vector<AnfNodePtr> loads;
1316     for (auto &ref_input : ref_inputs) {
1317       // Make a Load cnode for ref input.
1318       auto &ref = ref_input.first;
1319       auto load = MakeLoad(cnode, ref, u);
1320       // Replace input with the load cnode.
1321       for (size_t index : ref_input.second) {
1322         manager_->SetEdge(cnode, index, load);
1323       }
1324       loads.emplace_back(std::move(load));
1325     }
1326     return loads;
1327   }
1328 
MakeLoad(const CNodePtr & cnode,const AnfNodePtr & ref,const AnfNodePtr & u)1329   CNodePtr MakeLoad(const CNodePtr &cnode, const AnfNodePtr &ref, const AnfNodePtr &u) {
1330     static const std::string primitive_target = "primitive_target";
1331     // Create Load cnode.
1332     auto load_prim = NewValueNode(prim::kPrimLoad);
1333     auto load_cnode = func_graph_->NewCNode({load_prim, ref, u});
1334     // Set device target for Load CNode.
1335     std::string target = GetCNodeTarget(cnode);
1336     load_cnode->set_user_data(primitive_target, std::make_shared<std::string>(target));
1337     // Set load_cnode abstract to Tensor according the input Ref[Tensor].
1338     auto ref_abs = dyn_cast<abstract::AbstractRef>(ref->abstract());
1339     MS_EXCEPTION_IF_NULL(ref_abs);
1340     load_cnode->set_abstract(ref_abs->CloneAsTensor());
1341     return load_cnode;
1342   }
1343 
1344   // Add or replace monad input.
AddMonadInput(const CNodePtr & cnode,const AnfNodePtr & monad)1345   void AddMonadInput(const CNodePtr &cnode, const AnfNodePtr &monad) {
1346     MS_EXCEPTION_IF_NULL(cnode);
1347     constexpr size_t max_monad_inputs = 2;
1348     auto monad_abs = monad->abstract();
1349     auto &inputs = cnode->inputs();
1350     int last = static_cast<int>(inputs.size()) - 1;
1351     int stop = last - max_monad_inputs;
1352     // Search monad in inputs, replace it if found.
1353     for (int i = last; i > 0 && i > stop; --i) {
1354       size_t index = static_cast<size_t>(i);
1355       auto input_abs = inputs[index]->abstract();
1356       if (input_abs && *input_abs == *monad_abs) {
1357         manager_->SetEdge(cnode, i, monad);
1358         return;
1359       }
1360     }
1361     // If monad not found in inputs, add a monad input.
1362     manager_->AddEdge(cnode, monad);
1363   }
1364 
AttachMonadToOutput() const1365   void AttachMonadToOutput() const {
1366     if (u_) {
1367       AttachToOutput(u_);
1368     }
1369     if (io_) {
1370       AttachToOutput(io_);
1371     }
1372   }
1373 
AttachToOutput(const AnfNodePtr & node) const1374   void AttachToOutput(const AnfNodePtr &node) const {
1375     auto output = GetGraphOutput();
1376     auto depend = NewValueNode(prim::kPrimDepend);
1377     // If isolated nodes dependencies exist.
1378     if (IsPrimitiveCNode(output, prim::kPrimDepend) &&
1379         IsPrimitiveCNode(output->cast<CNodePtr>()->input(kDependAttachNodeIndex), prim::kPrimStopGradient)) {
1380       // Insert new Depend node before isolated Depend node.
1381       auto isolated_depend = output->cast<CNodePtr>();
1382       auto &orig_output = isolated_depend->input(1);
1383       auto state_depend = func_graph_->NewCNode({depend, orig_output, node});
1384       state_depend->set_abstract(orig_output->abstract());
1385       manager_->SetEdge(isolated_depend, 1, state_depend);
1386       return;
1387     }
1388     // Insert Depend node and set it as output, if no isolated nodes.
1389     auto depend_cnode = func_graph_->NewCNode({depend, output, node});
1390     depend_cnode->set_abstract(output->abstract());
1391     func_graph_->set_output(depend_cnode);
1392   }
1393 
GetGraphOutput() const1394   AnfNodePtr GetGraphOutput() const {
1395     auto output = func_graph_->output();
1396     if (output != nullptr) {
1397       return output;
1398     }
1399     return NewValueNode(kNone);
1400   }
1401 
UpdateState(const AnfNodePtr & state,const AnfNodePtr & attach)1402   AnfNodePtr UpdateState(const AnfNodePtr &state, const AnfNodePtr &attach) {
1403     MS_EXCEPTION_IF_NULL(attach);
1404     // Not attach UpdateState if set kAttrIgnoreSideEffect.
1405     auto attr_ignore_side_effect = attach->cast<CNodePtr>()->GetAttr(kAttrIgnoreSideEffect);
1406     auto ignore_side_effect = attr_ignore_side_effect != nullptr && attr_ignore_side_effect->isa<BoolImm>() &&
1407                               GetValue<bool>(attr_ignore_side_effect);
1408     if (ignore_side_effect) {
1409       return state;
1410     }
1411 
1412     auto update_state = NewValueNode(prim::kPrimUpdateState);
1413     auto update_state_cnode = func_graph_->NewCNode({update_state, state, attach});
1414     update_state_cnode->set_abstract(state->abstract());
1415     return update_state_cnode;
1416   }
1417 
GetUniverse()1418   AnfNodePtr &GetUniverse() {
1419     if (u_ == nullptr) {
1420       if (top_) {
1421         u_ = NewValueNode(kUMonad);
1422         u_->set_abstract(kUMonad->ToAbstract());
1423       } else {
1424         u_ = AddMonadParameter(func_graph_, "u", kUMonad->ToAbstract());
1425       }
1426     }
1427     return u_;
1428   }
1429 
GetIoState()1430   AnfNodePtr &GetIoState() {
1431     if (io_ == nullptr) {
1432       if (top_) {
1433         io_ = NewValueNode(kIOMonad);
1434         io_->set_abstract(kIOMonad->ToAbstract());
1435       } else {
1436         io_ = AddMonadParameter(func_graph_, "io", kIOMonad->ToAbstract());
1437       }
1438     }
1439     return io_;
1440   }
1441 
1442   // Return true if update_state should be used in this func graph.
1443   // In some case, update_state can be omitted, such as:
1444   //   def side_effect_tail_call(args):
1445   //       a = pure_func(args)
1446   //       return side_effect_call(a)
NeedUpdateState() const1447   bool NeedUpdateState() const {
1448     // Search for the only one side effect cnode.
1449     CNodePtr side_effect_cnode = nullptr;
1450     for (auto &cnode : func_graph_->order_list()) {
1451       if (HasSideEffect(cnode)) {
1452         if (side_effect_cnode != nullptr) {
1453           // There are multiple side effect cnodes, update state is required.
1454           return true;
1455         }
1456         side_effect_cnode = cnode;
1457       }
1458     }
1459     if (side_effect_cnode == nullptr) {
1460       // No side effect cnode, no update state.
1461       return false;
1462     }
1463     if (IsPrimitiveCNode(side_effect_cnode)) {
1464       // Always add update_state for primitive cnode.
1465       return true;
1466     }
1467     // If the only side effect cnode is not the tail call, update_state is required.
1468     return func_graph_->output() != side_effect_cnode;
1469   }
1470 
HasSideEffect(const CNodePtr & cnode) const1471   bool HasSideEffect(const CNodePtr &cnode) const {
1472     const auto &cnode_info = GetEffectInfo(cnode);
1473     return (cnode_info.memory || cnode_info.load || cnode_info.io);
1474   }
1475 
1476   // The func graph to be converted.
1477   const FuncGraphPtr &func_graph_;
1478 
1479   // The func graph manager, used for graph edge update.
1480   FuncGraphManagerPtr manager_;
1481 
1482   // True if converting top graph.
1483   const bool top_;
1484 
1485   // True if there are side effect cnodes within this func graph.
1486   bool has_effect_cnodes_ = false;
1487 
1488   // CNodes that should not be eliminated even it is isolated node.
1489   std::vector<CNodePtr> no_eliminate_nodes_;
1490 
1491   // Current memory state node, null if no memory side effects.
1492   AnfNodePtr u_;
1493 
1494   // Current IO state node, null if no IO side effects.
1495   AnfNodePtr io_;
1496 };  // class AutoMonadConverter
1497 }  // namespace
1498 
1499 // Entry point of the auto-monad phase,
1500 // the func_graph should be resolved and infer is done.
1501 // return true if side effect nodes found in func_graph.
AutoMonad(const FuncGraphPtr & func_graph)1502 bool AutoMonad(const FuncGraphPtr &func_graph) {
1503   MS_EXCEPTION_IF_NULL(func_graph);
1504   MS_EXCEPTION_IF_NULL(func_graph->manager());
1505 
1506   // Search and mark side effects for the graph and sub-graphs.
1507   // this should be called before auto-monad starts.
1508   SideEffectFinder::Search(func_graph);
1509 
1510   // Execute auto-monad conversion on top graph.
1511   bool has_effects = AutoMonadConverter::Handle(func_graph, true);
1512   // Convert used sub-graphs.
1513   auto fg_used_total = func_graph->func_graphs_used_total();
1514   for (auto &fg : fg_used_total) {
1515     auto top_flag = fg->has_flag(mindspore::kFuncGraphFlagBackPropEntry);
1516     if (fg->stage() != -1) {
1517       top_flag = true;
1518     }
1519     bool fg_has_effects = AutoMonadConverter::Handle(fg, top_flag);
1520     has_effects = has_effects || fg_has_effects;
1521   }
1522   return has_effects;
1523 }
1524 
ReAutoMonad(const FuncGraphPtr & func_graph)1525 bool ReAutoMonad(const FuncGraphPtr &func_graph) {
1526   // AutoMonad for bprop network, only Monad for func graphs which back propogators have side effects.
1527   // Or AutoMonad for MultitypeFuncGraph which specialized in Renormalize other than the first Specialize pass.
1528   MS_EXCEPTION_IF_NULL(func_graph);
1529   bool need_auto_monad = false;
1530   std::vector<FuncGraphPtr> auto_monaded_fg;
1531   func_graph->EraseUnusedNodeInOrder();
1532   for (auto &fg : func_graph->func_graphs_used_total()) {
1533     if (fg->has_flag(mindspore::kFuncGraphFlagReAutoMonad)) {
1534       auto_monaded_fg.push_back(fg);
1535       for (auto &used_fg : fg->func_graphs_used_total()) {
1536         used_fg->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
1537         auto_monaded_fg.push_back(used_fg);
1538       }
1539       need_auto_monad = true;
1540       MS_LOG(DEBUG) << "AutoMonad Grad for func graph: " << fg->ToString();
1541     }
1542     fg->EraseUnusedNodeInOrder();
1543   }
1544   bool changed = false;
1545   if (need_auto_monad) {
1546     for (auto &fg : func_graph->func_graphs_used_total()) {
1547       if (!fg->has_flag(mindspore::kFuncGraphFlagReAutoMonad)) {
1548         fg->ClearOrderList();
1549       }
1550     }
1551     changed = AutoMonad(func_graph);
1552     for (auto &fg : auto_monaded_fg) {
1553       fg->erase_flag(mindspore::kFuncGraphFlagReAutoMonad);
1554     }
1555     // After auto monad, Order List and Isolate nodes in graph and manager will be cleared.
1556   } else {
1557     func_graph->ClearOrderList();
1558     for (auto &fg : func_graph->func_graphs_used_total()) {
1559       fg->ClearOrderList();
1560     }
1561   }
1562   return changed;
1563 }
1564 }  // namespace pipeline
1565 }  // namespace mindspore
1566