• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "pipeline/jit/ps/static_analysis/auto_monad.h"
18 #include <list>
19 #include <vector>
20 #include <stack>
21 #include <string>
22 #include <utility>
23 #include <memory>
24 #include <algorithm>
25 #include "mindspore/core/ops/structure_ops.h"
26 #include "mindspore/core/ops/sparse_ops.h"
27 #include "mindspore/core/ops/sequence_ops.h"
28 #include "mindspore/core/ops/nn_ops.h"
29 #include "mindspore/core/ops/array_ops.h"
30 #include "mindspore/core/ops/framework_ops.h"
31 #include "ir/anf.h"
32 #include "pipeline/jit/ps/parse/resolve.h"
33 #include "frontend/operator/ops.h"
34 #include "frontend/operator/composite/multitype_funcgraph.h"
35 #include "utils/flags.h"
36 #include "include/common/utils/utils.h"
37 #include "include/common/utils/anfalgo.h"
38 #include "utils/hash_map.h"
39 #include "utils/hash_set.h"
40 #include "utils/log_adapter.h"
41 #include "utils/ordered_map.h"
42 #include "utils/ordered_set.h"
43 #include "base/effect_info.h"
44 #include "abstract/abstract_value.h"
45 #include "pipeline/jit/ps/debug/trace.h"
46 
47 namespace mindspore {
48 namespace pipeline {
49 namespace {  // namespace anonymous
50 using ClassTypePtr = std::shared_ptr<parse::ClassType>;
51 using RefInputs = OrderedMap<AnfNodePtr, std::vector<size_t>>;
52 
53 // Add or get a monad parameter.
AddMonadParameter(const FuncGraphPtr & func_graph,const std::string & name,const abstract::AbstractBasePtr & abs)54 AnfNodePtr AddMonadParameter(const FuncGraphPtr &func_graph, const std::string &name,
55                              const abstract::AbstractBasePtr &abs) {
56   MS_EXCEPTION_IF_NULL(func_graph);
57   MS_EXCEPTION_IF_NULL(abs);
58   size_t params_size = func_graph->parameters().size();
59   size_t io_monad_location = params_size;
60   // Search for existed parameters, return it if found.
61   for (size_t i = 0; i < params_size; i++) {
62     auto &node = func_graph->parameters()[i];
63     auto para = dyn_cast<Parameter>(node);
64     if (para == nullptr) {
65       continue;
66     }
67     auto para_abs = para->abstract();
68     if (para_abs && *para_abs == *abs) {
69       return para;
70     }
71     if (HasAbstractIOMonad(para)) {
72       io_monad_location = i;
73     }
74   }
75   // Create a new parameter if not existed.
76   auto para = std::make_shared<Parameter>(func_graph);
77   para->set_name(name);
78   MS_EXCEPTION_IF_NULL(para->debug_info());
79   para->debug_info()->set_name(name);
80   para->set_abstract(abs);
81   // If io monad parameter added before u monad parameter, should insert u monad before io monad in parameters
82   if (io_monad_location != params_size && abs->isa<abstract::AbstractUMonad>()) {
83     std::vector<AnfNodePtr> params = func_graph->parameters();
84     (void)params.insert(params.begin() + SizeToInt(io_monad_location), para);
85     func_graph->set_parameters(params);
86   } else {
87     func_graph->add_parameter(para);
88   }
89   return para;
90 }
91 
92 // Gets side effect propagate attribute value from a ClassType object.
GetSideEffectPropagate(const ClassTypePtr & class_type)93 int GetSideEffectPropagate(const ClassTypePtr &class_type) {
94   if (class_type) {
95     auto obj = class_type->obj();
96     if (py::hasattr(obj, GRAPH_FLAG_SIDE_EFFECT_PROPAGATE)) {
97       auto value = py::getattr(obj, GRAPH_FLAG_SIDE_EFFECT_PROPAGATE);
98       return value.cast<int>();
99     }
100   }
101   return 0;
102 }
103 
104 // Gets 'side_effect_propagate' attribute value from a primitive.
GetSideEffectPropagate(const PrimitivePtr & prim)105 int GetSideEffectPropagate(const PrimitivePtr &prim) {
106   if (prim) {
107     auto attr = prim->GetAttr(GRAPH_FLAG_SIDE_EFFECT_PROPAGATE);
108     if (attr && attr->isa<Int64Imm>()) {
109       return static_cast<int>(attr->cast<Int64ImmPtr>()->value());
110     }
111   }
112   return 0;
113 }
114 
115 // Gets ref inputs and its indexes from a cnode.
GetRefInputs(const CNodePtr & cnode)116 RefInputs GetRefInputs(const CNodePtr &cnode) {
117   RefInputs ref_inputs;
118   MS_EXCEPTION_IF_NULL(cnode);
119   for (size_t i = 1; i < cnode->size(); ++i) {
120     auto &input = cnode->input(i);
121     if (common::AnfAlgo::HasAbstractRef(input)) {
122       ref_inputs[input].push_back(i);
123     }
124   }
125   return ref_inputs;
126 }
127 
128 // Return true if cnode has ref input.
HasRefInput(const CNodePtr & cnode)129 bool HasRefInput(const CNodePtr &cnode) {
130   if (cnode == nullptr || cnode->empty()) {
131     return false;
132   }
133   // Return true if any of arguments is ref.
134   return std::any_of(cnode->weak_inputs().begin() + 1, cnode->weak_inputs().end(), [](const auto &weak_input) {
135     const auto &input = weak_input.lock();
136     MS_EXCEPTION_IF_NULL(input);
137     return common::AnfAlgo::HasAbstractRef(input);
138   });
139 }
140 
141 // Return true if cnode has tuple(ref) or list(ref).
HasRefSequenceInput(const CNodePtr & cnode)142 bool HasRefSequenceInput(const CNodePtr &cnode) {
143   if (cnode == nullptr || cnode->empty()) {
144     return false;
145   }
146   for (size_t index = 1; index < cnode->size(); ++index) {
147     const auto &input = cnode->input(index);
148     MS_EXCEPTION_IF_NULL(input);
149     if (common::AnfAlgo::SequenceHasAbstractRef(input)) {
150       return true;
151     }
152   }
153   return false;
154 }
155 
156 // Return true if we don't need Load for the given primitive.
157 // i.e. keep Ref as Ref for some primitives.
IsKeepRef(const PrimitivePtr & prim)158 bool IsKeepRef(const PrimitivePtr &prim) {
159   return (GetSideEffectPropagate(prim) != 0) || IsPrimitiveEquals(prim, prim::kPrimRefToEmbed) ||
160          IsPrimitiveEquals(prim, prim::kPrimPull) || IsPrimitiveEquals(prim, prim::kPrimMakeTuple) ||
161          IsPrimitiveEquals(prim, prim::kPrimMakeList);
162 }
163 
164 // Gets func_graph from the given cnode, return nullptr if it is not a func graph call.
GetFuncGraph(const CNodePtr & cnode)165 FuncGraphPtr GetFuncGraph(const CNodePtr &cnode) {
166   if (cnode != nullptr && !cnode->empty()) {
167     return GetValueNode<FuncGraphPtr>(cnode->input(0));
168   }
169   return nullptr;
170 }
171 
172 // Gets first input as cnode from the given cnode,
173 // return null if input[0] is not a cnode.
GetFuncCNode(const CNodePtr & cnode)174 CNodePtr GetFuncCNode(const CNodePtr &cnode) {
175   if (cnode != nullptr && !cnode->empty()) {
176     return dyn_cast<CNode>(cnode->input(0));
177   }
178   return nullptr;
179 }
180 
181 // Gets first input as function parameter from the given cnode,
182 // return null if input[0] is not a parameter.
GetFuncParameter(const CNodePtr & cnode)183 ParameterPtr GetFuncParameter(const CNodePtr &cnode) {
184   if (cnode != nullptr && !cnode->empty()) {
185     return dyn_cast<Parameter>(cnode->input(0));
186   }
187   return nullptr;
188 }
189 
GetFuncGraphFromPartialAbstract(const abstract::AbstractBasePtr & abs)190 FuncGraphPtr GetFuncGraphFromPartialAbstract(const abstract::AbstractBasePtr &abs) {
191   if (abs == nullptr || !abs->isa<abstract::PartialAbstractClosure>()) {
192     return nullptr;
193   }
194 
195   auto partial_closure = dyn_cast<abstract::PartialAbstractClosure>(abs);
196   MS_EXCEPTION_IF_NULL(partial_closure);
197   if (partial_closure->fn() == nullptr) {
198     MS_LOG(ERROR) << "Partial closure's func graph is null, " << abs->ToString();
199     return nullptr;
200   }
201   auto func_graph_abstract = dyn_cast<abstract::FuncGraphAbstractClosure>(partial_closure->fn());
202   if (func_graph_abstract != nullptr) {
203     MS_EXCEPTION_IF_NULL(func_graph_abstract);
204     if (!func_graph_abstract->specialized()) {
205       MS_LOG(DEBUG) << "Unspecialized func graph, partial abs: " << abs->ToString()
206                     << ", partial fn abs: " << func_graph_abstract->ToString();
207       return nullptr;
208     }
209     return func_graph_abstract->func_graph();
210   }
211 
212   // Nested Partial.
213   return GetFuncGraphFromPartialAbstract(partial_closure->fn());
214 }
215 
GetFuncGraphFromFuncGraphAbstract(const abstract::AbstractBasePtr & abs)216 FuncGraphPtr GetFuncGraphFromFuncGraphAbstract(const abstract::AbstractBasePtr &abs) {
217   auto func_closure = dyn_cast<abstract::FuncGraphAbstractClosure>(abs);
218   if (func_closure == nullptr) {
219     return nullptr;
220   }
221   if (func_closure->func_graph() == nullptr) {
222     MS_LOG(DEBUG) << "FuncGraph closure's func graph is null, " << abs->ToString();
223     return nullptr;
224   }
225   return func_closure->func_graph();
226 }
227 
228 // Gets first input as MultitypeFuncGraph from the given cnode,
229 // return null if input[0] is not a MultitypeFuncGraph.
GetFuncMultitypeFuncGraph(const CNodePtr & cnode)230 prim::MultitypeFuncGraphPtr GetFuncMultitypeFuncGraph(const CNodePtr &cnode) {
231   if (cnode != nullptr && !cnode->empty()) {
232     return GetValueNode<prim::MultitypeFuncGraphPtr>(cnode->input(0));
233   }
234   return nullptr;
235 }
236 
237 // The cnode is non-effect-node, and the cnode is real node, and the inputs of cnode is dynamic.
IsNonEffectRealNodeAndInputIsDynamic(const CNodePtr & cnode)238 bool IsNonEffectRealNodeAndInputIsDynamic(const CNodePtr &cnode) {
239   MS_EXCEPTION_IF_NULL(cnode);
240   static const PrimitiveSet dynamic_input_node_prims = {
241     prim::kPrimStack,        prim::kPrimConcat,   prim::kPrimAddN,          prim::kPrimIdentityN,
242     prim::kPrimSparseConcat, prim::kPrimMeshgrid, prim::kPrimDynamicStitch, prim::kPrimPyExecute,
243     prim::kPrimPyInterpret,  prim::kPrimMakeDict};
244   PrimitivePtr prim = cnode->empty() ? nullptr : GetValueNode<PrimitivePtr>(cnode->input(0));
245   if (prim == nullptr) {
246     return false;
247   }
248   return dynamic_input_node_prims.find(prim) != dynamic_input_node_prims.end();
249 }
250 
251 // --------------------------------------------------------------------
252 // SCC (Strongly Connected Components) related types.
253 // --------------------------------------------------------------------
254 using SccVector = mindspore::HashSet<FuncGraphPtr>;
255 using SccPtr = std::shared_ptr<SccVector>;
256 using SccMap = mindspore::HashMap<FuncGraphPtr, SccPtr>;
257 
258 // ---------------------------------------------------------------------
259 // SccFinder find SCCs using Tarjan's algorithm.
260 // ---------------------------------------------------------------------
261 class SccFinder {
262  public:
SccFinder(const FuncGraphPtr & root)263   explicit SccFinder(const FuncGraphPtr &root) : root_(root) {}
264   ~SccFinder() = default;
Run()265   void Run() { Search(root_); }
scc_map()266   SccMap scc_map() { return std::move(scc_map_); }
267 
268  private:
269   // Store each layer of visit stack.
270   struct SccVisitInfo {
271     FuncGraphPtr graph{nullptr};
272     const FuncGraphCounterMap *func_graphs_used_ptr{nullptr};
273     FuncGraphCounterMap::const_iterator visit_iter;
274   };
275 
276   // Tarjan algorithm. Search SCCs from the given graph.
277   // Iterative implementation.
Search(const FuncGraphPtr & graph)278   void Search(const FuncGraphPtr &graph) {
279     MS_EXCEPTION_IF_NULL(graph);
280     std::stack<SccVisitInfo> visit_stack;
281     auto seen = NewFgSeenGeneration();
282     // Push the origin graph.
283     SccVisitInfo info;
284     info.graph = graph;
285     info.graph->seen_ = seen;        // If visited.
286     info.graph->extra_seen_ = seen;  // If in stack.
287     auto index = 1;
288     info.graph->set_user_data<size_t>("index", std::make_shared<size_t>(index));
289     info.graph->set_user_data<size_t>("low", std::make_shared<size_t>(index));
290     stack_.push(graph);
291     visit_stack.push(std::move(info));
292     while (!visit_stack.empty()) {
293       auto &current_info = visit_stack.top();
294       if (current_info.func_graphs_used_ptr == nullptr) {
295         current_info.func_graphs_used_ptr = &current_info.graph->func_graphs_used();
296         current_info.visit_iter = current_info.func_graphs_used_ptr->cbegin();
297       }
298       // If there's not visited used func graph, continue visiting the left used.
299       if (current_info.visit_iter != current_info.func_graphs_used_ptr->cend()) {
300         auto used_graph = current_info.visit_iter->first;
301         ++current_info.visit_iter;
302         if (used_graph->seen_ != seen) {
303           // First visit, push it.
304           MS_LOG(DEBUG) << "Push graph: " << used_graph->ToString();
305           stack_.push(used_graph);
306           SccVisitInfo used_info;
307           ++index;
308           used_info.graph = used_graph;
309           used_info.graph->set_user_data<size_t>("index", std::make_shared<size_t>(index));
310           used_info.graph->set_user_data<size_t>("low", std::make_shared<size_t>(index));
311           used_info.graph->seen_ = seen;        // If visited.
312           used_info.graph->extra_seen_ = seen;  // If in stack.
313           visit_stack.push(std::move(used_info));
314         } else if (used_graph->extra_seen_ == seen) {
315           // Visited before AND in stack, update low.
316           auto min_low = std::min(*current_info.graph->user_data<size_t>("low"), *used_graph->user_data<size_t>("low"));
317           current_info.graph->set_user_data<size_t>("low", std::make_shared<size_t>(min_low));
318           MS_LOG(DEBUG) << "Update low [" << min_low << "] for " << current_info.graph->ToString() << " by "
319                         << used_graph->ToString();
320         }
321         continue;
322       }
323       // If all used func graphs are visited, pop it and check if it's SCC root.
324       auto current_graph = current_info.graph;
325       if (*current_graph->user_data<size_t>("low") != *current_graph->user_data<size_t>("index")) {
326         // Update low when pop.
327         visit_stack.pop();
328         auto &next_info = visit_stack.top();
329         auto min_low = std::min(*next_info.graph->user_data<size_t>("low"), *current_graph->user_data<size_t>("low"));
330         next_info.graph->set_user_data<size_t>("low", std::make_shared<size_t>(min_low));
331         MS_LOG(DEBUG) << "Update low [" << min_low << "] for " << next_info.graph->ToString() << " by "
332                       << current_graph->ToString();
333         continue;
334       }
335       MS_LOG(DEBUG) << "Found SCC root: " << current_graph->ToString();
336       // Pop members of the SCC from stack, they are on top of its root.
337       auto scc = std::make_shared<SccVector>();
338       while (!stack_.empty()) {
339         auto g = stack_.top();
340         g->extra_seen_ = 0;  // Not in stack any more.
341         stack_.pop();
342         // Add graph to SCC, and create the map from graph to SCC.
343         scc->insert(g);
344         (void)scc_map_.emplace(g, scc);
345         if (g == current_graph) {
346           break;
347         }
348       }
349       // SCC should not be empty.
350       if (scc->empty()) {
351         MS_LOG(INTERNAL_EXCEPTION) << "Invalid SCC for: " << graph->ToString();
352       }
353       visit_stack.pop();
354     }
355   }
356 
357   // The root graph.
358   FuncGraphPtr root_;
359 
360   // The stack for Tarjan algorithm.
361   std::stack<FuncGraphPtr> stack_;
362 
363   // The result SCC map, from graph to its SCC.
364   SccMap scc_map_;
365 };
366 
367 struct SwitchLayerCall {
368   CNodePtr caller;
369   EffectInfo effect_info;
370   std::vector<FuncGraphPtr> branches;
371 };
372 
373 class NodeStackGuard {
374  public:
NodeStackGuard(OrderedSet<AnfNodePtr> * stack,const AnfNodePtr & node)375   NodeStackGuard(OrderedSet<AnfNodePtr> *stack, const AnfNodePtr &node) : stack_(stack) { stack_->push_front(node); }
~NodeStackGuard()376   ~NodeStackGuard() {
377     try {
378       (void)stack_->pop();
379     } catch (const std::exception &e) {
380       MS_LOG(ERROR) << "Exception when pop. Error info " << e.what();
381     }
382 
383     stack_ = nullptr;
384   }
385 
386  private:
387   OrderedSet<AnfNodePtr> *stack_;
388 };
389 
390 // -------------------------------------------------------------------------------
391 // SideEffectFinder search and mark side effects for graph and its sub-graphs.
392 // -------------------------------------------------------------------------------
393 class SideEffectFinder {
394  public:
Search(const FuncGraphPtr & root)395   static void Search(const FuncGraphPtr &root) {
396     SideEffectFinder finder(root);
397     finder.Run();
398   }
399 
400  private:
SideEffectFinder(const FuncGraphPtr & root)401   explicit SideEffectFinder(const FuncGraphPtr &root) : root_(root) {}
402   ~SideEffectFinder() = default;
403 
Run()404   void Run() {
405     // To handle recursive calls, we generate SCC map before search.
406     GenerateSccMap();
407     // Update order list to include outer cnodes.
408     UpdateOrderLists();
409     // Find side effects by DFS from the top graph.
410     ObtainEffectInfoForFuncGraphs(root_);
411     // Check Switch calls, add monad arguments if need.
412     HandleSwitchCalls();
413     // Check SwitchLayer calls, add monad arguments if need.
414     HandleSwitchLayerCalls();
415     // Check Partial CNode calls, add monad arguments if need.
416     HandlePartialCalls();
417   }
418 
UpdateOrderLists() const419   void UpdateOrderLists() const {
420     // Some cnodes used in current func graph but belong to other func graph, we have to
421     // insert them into order list so that we can handle side effects for them.
422     UpdateOrderList(root_);
423     for (auto &fg : root_->func_graphs_used_total()) {
424       UpdateOrderList(fg);
425     }
426   }
427 
UpdateOrderList(const FuncGraphPtr & func_graph)428   static void UpdateOrderList(const FuncGraphPtr &func_graph) {
429     MS_EXCEPTION_IF_NULL(func_graph);
430     std::list<CNodeWeakPtr> new_order_list;
431     const auto &order_list = func_graph->order_list();
432     for (auto &weak_cnode : order_list) {
433       const auto &cnode = weak_cnode.lock();
434       if (cnode != nullptr) {
435         PushToOrderList(func_graph, cnode, &new_order_list);
436       }
437     }
438     func_graph->set_order_list(std::move(new_order_list));
439   }
440 
PushToOrderList(const FuncGraphPtr & fg,const CNodePtr & cnode,std::list<CNodeWeakPtr> * new_order_list)441   static void PushToOrderList(const FuncGraphPtr &fg, const CNodePtr &cnode, std::list<CNodeWeakPtr> *new_order_list) {
442     MS_EXCEPTION_IF_NULL(cnode);
443     MS_EXCEPTION_IF_NULL(new_order_list);
444     // If contains.
445     auto iter = std::find_if(new_order_list->cbegin(), new_order_list->cend(), [&cnode](const CNodeWeakPtr &node) {
446       return node.lock() != nullptr && node.lock() == cnode;
447     });
448     if (iter != new_order_list->cend()) {
449       return;
450     }
451 
452     for (auto &weak_input : cnode->weak_inputs()) {
453       auto input = weak_input.lock();
454       MS_EXCEPTION_IF_NULL(input);
455       auto input_cnode = dyn_cast<CNode>(input);
456       if (input_cnode != nullptr && input_cnode->func_graph() != fg) {
457         PushToOrderList(fg, input_cnode, new_order_list);
458       }
459     }
460     new_order_list->emplace_back(CNodeWeakPtr(cnode));
461   }
462 
463   // Generate SCC map by SccFinder.
GenerateSccMap()464   void GenerateSccMap() {
465     SccFinder scc_finder(root_);
466     scc_finder.Run();
467     scc_map_ = std::move(scc_finder.scc_map());
468   }
469 
470   // Gets branch graph from a switch cnode at given input index.
GetSwitchBranch(const CNodePtr & cnode,size_t index) const471   FuncGraphPtr GetSwitchBranch(const CNodePtr &cnode, size_t index) const {
472     MS_EXCEPTION_IF_NULL(cnode);
473     const auto &branch_node = cnode->input(index);
474     AnfNodePtr branch_fg_node = branch_node;
475     if (IsPrimitiveCNode(branch_node, prim::kPrimPartial)) {
476       auto branch_abs = branch_node->abstract();
477       constexpr auto recursive_level = 2;
478       MS_LOG(DEBUG) << "branch_node: " << branch_node->DebugString(recursive_level)
479                     << ", abstract: " << (branch_abs != nullptr ? branch_abs->ToString() : "null");
480       auto branch_cnode = branch_node->cast_ptr<CNode>();
481       MS_EXCEPTION_IF_NULL(branch_cnode);
482       branch_fg_node = branch_cnode->input(1);
483       MS_EXCEPTION_IF_NULL(branch_fg_node);
484       MS_LOG(DEBUG) << "branch_fg_node: " << branch_fg_node->DebugString(recursive_level);
485     }
486     return GetValueNode<FuncGraphPtr>(branch_fg_node);
487   }
488 
489   // Gets branch graphs from a switch cnode.
GetSwitchBranches(const CNodePtr & cnode) const490   std::vector<FuncGraphPtr> GetSwitchBranches(const CNodePtr &cnode) const {
491     MS_EXCEPTION_IF_NULL(cnode);
492     constexpr size_t switch_cnode_size = 4;
493     constexpr size_t true_index = 2;
494     constexpr size_t false_index = 3;
495     // Check size.
496     if (cnode->size() != switch_cnode_size) {
497       MS_LOG(INTERNAL_EXCEPTION) << "Invalid switch: " << cnode->DebugString();
498     }
499     // Add both branches, in some case, only one branch is set.
500     std::vector<FuncGraphPtr> branches;
501     auto true_branch = GetSwitchBranch(cnode, true_index);
502     if (true_branch != nullptr) {
503       (void)branches.emplace_back(true_branch);
504     }
505     auto false_branch = GetSwitchBranch(cnode, false_index);
506     if (false_branch != nullptr) {
507       (void)branches.emplace_back(false_branch);
508     }
509     if (branches.empty()) {
510       constexpr auto recursive_level = 2;
511       MS_LOG(INTERNAL_EXCEPTION) << "Invalid switch: " << cnode->DebugString(recursive_level);
512     }
513     return branches;
514   }
515 
516   // Add monad parameter to switch branch graphs.
AddMonadParameters(const std::vector<FuncGraphPtr> & branches,const std::string & name,const AbstractBasePtr & abs) const517   void AddMonadParameters(const std::vector<FuncGraphPtr> &branches, const std::string &name,
518                           const AbstractBasePtr &abs) const {
519     for (auto &branch : branches) {
520       (void)AddMonadParameter(branch, name, abs);
521     }
522   }
523 
524   // Trace effect info for Partial call node.
TracePartialCallEffectInfo(const CNodePtr & cnode,const EffectInfo & old_info)525   EffectInfo TracePartialCallEffectInfo(const CNodePtr &cnode, const EffectInfo &old_info) {
526     const AnfNodePtr &func_node = cnode->input(0);
527     MS_EXCEPTION_IF_NULL(func_node);
528     // Only handle for Parameter or Non-Partial CNode.
529     if (!func_node->isa<Parameter>() && (!func_node->isa<CNode>() || IsPrimitiveCNode(func_node, prim::kPrimPartial))) {
530       return old_info;
531     }
532     auto partial_real_func = GetFuncGraphFromPartialAbstract(func_node->abstract());
533     if (partial_real_func == nullptr) {
534       return old_info;
535     }
536 
537     // Not retry checking, if has already confirmed the Partial func graph has side effect, or still detect ongoing.
538     if (old_info.state != EffectInfo::kDetected || old_info.memory || old_info.io || old_info.load ||
539         old_info.back_mem) {
540       return old_info;
541     }
542 
543     // Record the Partial callers and real func graph.
544     (void)partial_cnode_calls_.emplace(cnode, partial_real_func);
545 
546     // Try to obtain the effect info of func graph.
547     auto effect_info = ObtainEffectInfoForFuncGraph(partial_real_func);
548     MS_EXCEPTION_IF_NULL(func_node->abstract());
549     MS_LOG(DEBUG) << "CNode or Parameter func: " << func_node->DebugString()
550                   << ", partial_real_func: " << partial_real_func->ToString() << ", "
551                   << func_node->abstract()->ToString() << ", cnode: " << cnode->DebugString()
552                   << ", effect_info: " << effect_info.memory << "/" << effect_info.io << "/" << effect_info.load;
553     return effect_info;
554   }
555 
556   // Trace effect info for Switch cnode.
TraceSwitchEffectInfo(const CNodePtr & cnode)557   EffectInfo TraceSwitchEffectInfo(const CNodePtr &cnode) {
558     // Find branches from switch cnode.
559     auto branches = GetSwitchBranches(cnode);
560     // Save branch caller, so that we can update arguments for the caller.
561     SaveBranchCaller(cnode, branches);
562     // For some case, only one branch is set.
563     if (branches.size() == 1) {
564       auto &branch = branches.front();
565       return ObtainEffectInfoForFuncGraph(branch);
566     }
567     // When both branches are set, merge their effect infos.
568     EffectInfo info = MergeEffectInfo(branches);
569     if (info.state == EffectInfo::kDetected) {
570       // Setup both branches according the merged effect info.
571       SetupEffectBranches(info, branches);
572     }
573     return info;
574   }
575 
576   // Trace effect info for SwitchLayer cnode.
TraceSwitchLayerEffectInfo(const CNodePtr & cnode)577   EffectInfo TraceSwitchLayerEffectInfo(const CNodePtr &cnode) {
578     // Find branches from switch_layer cnode.
579     auto branches = GetSwitchLayerBranches(cnode);
580     // Merge effect info from all branches.
581     EffectInfo info = MergeEffectInfo(branches);
582     if (info.state == EffectInfo::kDetected) {
583       // Setup branches according the merged effect info.
584       SetupEffectBranches(info, branches);
585       // Save the switch_layer call, so that we can add monad argument for it if need.
586       auto &call = switch_layer_calls_.emplace_back();
587       call.caller = caller_;
588       call.effect_info = info;
589       call.branches = move(branches);
590     }
591     return info;
592   }
593 
HandlePartialCalls()594   void HandlePartialCalls() {
595     for (auto &call : partial_cnode_calls_) {
596       const auto &caller = call.first;
597       const auto &func_graph = call.second;
598       const auto &effect_info = ObtainEffectInfoForFuncGraph(func_graph);
599       MS_EXCEPTION_IF_NULL(caller->abstract());
600       MS_LOG(DEBUG) << "func_graph: " << func_graph->ToString() << ", caller: " << caller->DebugString() << ", "
601                     << caller->abstract()->ToString() << ", effect_info: " << effect_info.memory << "/"
602                     << effect_info.io << "/" << effect_info.load << "/" << effect_info.back_mem;
603       AddMonadForCaller(caller, effect_info);
604       // Setup monad parameters for func graph according the effect info.
605       if (effect_info.memory || effect_info.load) {
606         (void)AddMonadParameter(func_graph, "u", kUMonad->ToAbstract());
607       }
608       if (effect_info.io) {
609         (void)AddMonadParameter(func_graph, "io", kIOMonad->ToAbstract());
610       }
611     }
612   }
613 
HandleSwitchCalls()614   void HandleSwitchCalls() {
615     for (auto &call : switch_calls_) {
616       const auto &caller = call.first;
617       const auto &branches = call.second;
618       CheckAndFixSwitchCall(caller, branches);
619     }
620   }
621 
CheckAndFixSwitchCall(const CNodePtr & caller,const FuncGraphVector & branches) const622   void CheckAndFixSwitchCall(const CNodePtr &caller, const FuncGraphVector &branches) const {
623     MS_EXCEPTION_IF_NULL(caller);
624     const auto caller_input_size = caller->size() - 1;
625     for (size_t i = 0; i < branches.size(); ++i) {
626       const auto &branch = branches[i];
627       MS_EXCEPTION_IF_NULL(branch);
628 
629       // Get partial branch input size.
630       size_t extra_input_size = 0;
631       const auto &switch_node = caller->input(0);
632       if (!IsPrimitiveCNode(switch_node, prim::kPrimSwitch)) {
633         MS_LOG(INTERNAL_EXCEPTION) << "Not switch CNode, " << switch_node->DebugString();
634       }
635       const auto &switch_cnode = dyn_cast<CNode>(switch_node);
636       constexpr auto ignore_switch_and_cond_count = 2;
637       const auto &branch_node = switch_cnode->input(i + ignore_switch_and_cond_count);
638       if (IsPrimitiveCNode(branch_node, prim::kPrimPartial)) {
639         const auto &branch_cnode = branch_node->cast_ptr<CNode>();
640         constexpr auto ignore_partial_and_fg_count = 2;
641         extra_input_size = branch_cnode->size() - ignore_partial_and_fg_count;
642       }
643 
644       // Check inputs size.
645       if (caller_input_size + extra_input_size != branch->parameters().size()) {
646         // Fix branch if number of parameter mismatch.
647         FixSwitchBranch(caller, branch);
648         // The number of parameter should matched after fix.
649         if (caller_input_size + extra_input_size != branch->parameters().size()) {
650           constexpr auto recursive_count = 2;
651           MS_LOG(INTERNAL_EXCEPTION) << "Fix switch branch parameters failed! " << caller->DebugString(recursive_count)
652                                      << ", branch: " << branch->ToString()
653                                      << ", branch node: " << branch_node->DebugString(recursive_count)
654                                      << ", size: " << caller_input_size << " + " << extra_input_size << " not equal to "
655                                      << branch->parameters().size();
656         }
657       }
658     }
659   }
660 
FixSwitchBranch(const CNodePtr & caller,const FuncGraphPtr & branch) const661   void FixSwitchBranch(const CNodePtr &caller, const FuncGraphPtr &branch) const {
662     MS_EXCEPTION_IF_NULL(branch);
663     for (size_t i = caller->size() - 1; i > 0; --i) {
664       auto &input = caller->input(i);
665       MS_EXCEPTION_IF_NULL(input);
666       if (HasAbstractUMonad(input)) {
667         (void)AddMonadParameter(branch, "u", input->abstract());
668       } else if (HasAbstractIOMonad(input)) {
669         (void)AddMonadParameter(branch, "io", input->abstract());
670       }
671     }
672   }
673 
HandleSwitchLayerCalls()674   void HandleSwitchLayerCalls() {
675     for (auto &call : switch_layer_calls_) {
676       const auto &info = call.effect_info;
677       const auto &branches = call.branches;
678       auto new_info = MergeEffectInfo(branches);
679       // Reset branches if effect info changed.
680       if (new_info.memory != info.memory || new_info.load != info.load || new_info.io != info.io) {
681         AddMonadForCaller(call.caller, new_info);
682         SetupEffectBranches(new_info, branches);
683       }
684     }
685   }
686 
687   // Gets branch graphs from a switch_layer cnode.
GetSwitchLayerBranches(const CNodePtr & cnode)688   std::vector<FuncGraphPtr> GetSwitchLayerBranches(const CNodePtr &cnode) {
689     MS_EXCEPTION_IF_NULL(cnode);
690     constexpr size_t func_tuple_index = 2;
691     constexpr int recursive_level = 2;
692     if (cnode->size() <= func_tuple_index) {
693       MS_LOG(INTERNAL_EXCEPTION) << "Invalid switch_layer: " << cnode->DebugString(recursive_level);
694     }
695     auto func_tuple = cnode->input(func_tuple_index);
696     return GetGraphsFromTuple(func_tuple);
697   }
698 
GetGraphFromSwitchWithDeadNode(const CNodePtr & cnode) const699   FuncGraphPtr GetGraphFromSwitchWithDeadNode(const CNodePtr &cnode) const {
700     MS_EXCEPTION_IF_NULL(cnode);
701     auto input = cnode->input(0);
702     MS_EXCEPTION_IF_NULL(input);
703     if (!IsPrimitiveCNode(input, prim::kPrimSwitch)) {
704       return nullptr;
705     }
706     auto node = input->cast_ptr<CNode>();
707     if (node->size() < kSwitchInputSize) {
708       MS_LOG(EXCEPTION) << "Switch inputs size: " << node->size() << "less than " << kSwitchInputSize;
709     }
710     auto cond_node = node->input(kSwitchCondIndex);
711     auto cond_abs = cond_node->abstract();
712     MS_EXCEPTION_IF_NULL(cond_abs);
713     auto cond_abs_val = cond_abs->BuildValue();
714     MS_EXCEPTION_IF_NULL(cond_abs_val);
715     if (cond_abs_val->ContainsValueAny()) {
716       return nullptr;
717     }
718     auto cond_abs_bool_val = dyn_cast<BoolImm>(cond_abs_val);
719     MS_EXCEPTION_IF_NULL(cond_abs_bool_val);
720     auto branch =
721       cond_abs_bool_val->value() ? node->input(kSwitchTrueBranchIndex) : node->input(kSwitchFalseBranchIndex);
722     return GetValueNode<FuncGraphPtr>(branch);
723   }
724 
725   // Get and trace graphs from a tuple of func node for switch_layer.
GetGraphsFromTuple(const AnfNodePtr & func_tuple)726   std::vector<FuncGraphPtr> GetGraphsFromTuple(const AnfNodePtr &func_tuple) {
727     // The functions make tuple CNode.
728     if (IsPrimitiveCNode(func_tuple, prim::kPrimMakeTuple)) {
729       return GetGraphsFromMakeTuple(func_tuple->cast<CNodePtr>());
730     }
731     // The functions value tuple.
732     if (IsValueNode<ValueTuple>(func_tuple)) {
733       return GetGraphsFromValueTuple(func_tuple->cast<ValueNodePtr>());
734     }
735     // Trace tuple from parameter.
736     auto para = dyn_cast<Parameter>(func_tuple);
737     if (para != nullptr) {
738       std::vector<FuncGraphPtr> graphs;
739       ForEachRealArguments(para,
740                            [this, &graphs](const AnfNodePtr &arg) { graphs = std::move(GetGraphsFromTuple(arg)); });
741       return graphs;
742     }
743     // Trace tuple returned from func graph call.
744     auto cnode = dyn_cast<CNode>(func_tuple);
745     MS_EXCEPTION_IF_NULL(cnode);
746     auto func_graph = GetFuncGraph(cnode);
747     if (func_graph != nullptr) {
748       return GetGraphsFromTuple(func_graph->output());
749     }
750     // Trace tuple returned from func graph call including switch with dead node.
751     func_graph = GetGraphFromSwitchWithDeadNode(cnode);
752     if (func_graph != nullptr) {
753       return GetGraphsFromTuple(func_graph->output());
754     }
755     MS_LOG(INTERNAL_EXCEPTION) << "Invalid input for switch_layer: func_graph is nullptr.";
756   }
757 
758   // Get graphs from a tuple of funcs make node for switch_layer.
GetGraphsFromMakeTuple(const CNodePtr & make_tuple) const759   std::vector<FuncGraphPtr> GetGraphsFromMakeTuple(const CNodePtr &make_tuple) const {
760     MS_EXCEPTION_IF_NULL(make_tuple);
761     constexpr int recursive_level = 2;
762     if (make_tuple->size() <= 1) {
763       MS_LOG(INTERNAL_EXCEPTION) << "Invalid make_tuple for switch_layer: " << make_tuple->DebugString(recursive_level);
764     }
765     std::vector<FuncGraphPtr> graphs;
766     graphs.reserve(make_tuple->size() - 1);
767     for (size_t i = 1; i < make_tuple->size(); ++i) {
768       auto func_graph = GetValueNode<FuncGraphPtr>(make_tuple->input(i));
769       if (func_graph == nullptr) {
770         MS_LOG(WARNING) << "Non-graph found in switch_layer input: " << make_tuple->DebugString(recursive_level)
771                         << ", index: " << i;
772         continue;
773       }
774       graphs.push_back(func_graph);
775     }
776     return graphs;
777   }
778 
779   // Get graphs from a tuple of functions value tuple for switch_layer.
GetGraphsFromValueTuple(const ValueNodePtr & value_node) const780   std::vector<FuncGraphPtr> GetGraphsFromValueTuple(const ValueNodePtr &value_node) const {
781     MS_EXCEPTION_IF_NULL(value_node);
782     const auto &value = value_node->value();
783     MS_EXCEPTION_IF_NULL(value);
784     auto value_tuple = value->cast_ptr<ValueTuple>();
785     MS_EXCEPTION_IF_NULL(value_tuple);
786     std::vector<FuncGraphPtr> graphs;
787     graphs.reserve(value_tuple->size());
788     const auto &tuple_elements = value_tuple->value();
789     for (size_t i = 0; i < tuple_elements.size(); ++i) {
790       const auto &tuple_element = tuple_elements[i];
791       MS_EXCEPTION_IF_NULL(tuple_element);
792       auto func_graph = tuple_element->cast<FuncGraphPtr>();
793       if (func_graph == nullptr) {
794         MS_LOG(WARNING) << "Non-graph found in switch_layer input: " << value_node->DebugString() << ", index: " << i;
795         continue;
796       }
797       graphs.push_back(func_graph);
798     }
799     return graphs;
800   }
801 
802   // Trace effect info from tuple_getitem cnode.
TraceGetItemEffectInfo(const CNodePtr & cnode,std::stack<ValuePtr> * indexes)803   EffectInfo TraceGetItemEffectInfo(const CNodePtr &cnode, std::stack<ValuePtr> *indexes) {
804     MS_EXCEPTION_IF_NULL(cnode);
805     MS_EXCEPTION_IF_NULL(indexes);
806     constexpr size_t tuple_or_list_or_dict_input = 1;
807     constexpr size_t index_input = 2;
808     constexpr size_t cnode_size = 3;
809     if (cnode->size() != cnode_size) {
810       MS_LOG(INTERNAL_EXCEPTION) << "Invalid getitem: " << cnode->DebugString();
811     }
812     // Get item index.
813     auto &index_node = cnode->input(index_input);
814     auto index_value = dyn_cast<ValueNode>(index_node);
815     if (index_value == nullptr) {
816       MS_LOG(INTERNAL_EXCEPTION) << "getitem with non-const index, cnode: " << cnode->DebugString();
817     }
818 
819     // Get tuple, list or dict value.
820     const auto &tuple_or_list_or_dict_node = cnode->input(tuple_or_list_or_dict_input);
821     // Push tuple, list or dict index.
822     indexes->push(index_value->value());
823     return TraceTupleListOrDictEffectInfo(tuple_or_list_or_dict_node, indexes);
824   }
825 
TraceTupleListOrDictEffectInfo(const AnfNodePtr & node,std::stack<ValuePtr> * indexes)826   EffectInfo TraceTupleListOrDictEffectInfo(const AnfNodePtr &node, std::stack<ValuePtr> *indexes) {
827     MS_EXCEPTION_IF_NULL(indexes);
828     auto para = dyn_cast<Parameter>(node);
829     if (para != nullptr) {
830       return TraceTupleListParaEffectInfo(para, *indexes);
831     }
832     auto cnode = dyn_cast<CNode>(node);
833     if (cnode != nullptr) {
834       return TraceTupleListCNodeEffectInfo(cnode, indexes);
835     }
836     // Should not reach here.
837     MS_LOG(INTERNAL_EXCEPTION) << "Side effects untraceable: cnode is nullptr. Invalid node: " << node->DebugString();
838   }
839 
TraceTupleListParaEffectInfo(const ParameterPtr & para,const std::stack<ValuePtr> & indexes)840   EffectInfo TraceTupleListParaEffectInfo(const ParameterPtr &para, const std::stack<ValuePtr> &indexes) {
841     EffectInfo info{EffectInfo::kDetected, false, false, false, false};
842     ForEachRealArguments(para, [this, &info, indexes](const AnfNodePtr &arg) {
843       // Merge real argument effect info.
844       auto indexes_copy = indexes;
845       auto arg_info = TraceTupleListOrDictEffectInfo(arg, &indexes_copy);
846       info.Merge(arg_info);
847     });
848     return info;
849   }
850 
GetInputIndex(const ValuePtr & top_index_value,const CNodePtr & origin_cnode,size_t inputs_size)851   size_t GetInputIndex(const ValuePtr &top_index_value, const CNodePtr &origin_cnode, size_t inputs_size) {
852     auto int64_imm = dyn_cast<Int64Imm>(top_index_value);
853     if (int64_imm == nullptr) {
854       MS_LOG(INTERNAL_EXCEPTION) << "Invalid make_tuple: " << origin_cnode->DebugString()
855                                  << ", index: " << (top_index_value == nullptr ? "null" : top_index_value->ToString());
856     }
857     auto top_index = int64_imm->value();
858     size_t input_index = 0;
859     // Support tuple index is negative
860     if (top_index < 0) {
861       if (SizeToLong(inputs_size) + top_index < 0) {
862         MS_LOG(INTERNAL_EXCEPTION) << "Invalid make_tuple: " << origin_cnode->DebugString() << " index=" << top_index;
863       }
864       input_index = static_cast<size_t>(inputs_size + top_index);
865     } else {
866       // Follow the tuple item according the index.
867       input_index = static_cast<size_t>(top_index) + 1;
868     }
869     if (input_index >= inputs_size) {
870       MS_LOG(INTERNAL_EXCEPTION) << "Invalid make_tuple: " << origin_cnode->DebugString() << " index=" << top_index;
871     }
872     return input_index;
873   }
874 
TraceMakeTupleListEffectInfo(const CNodePtr & cnode,std::stack<ValuePtr> * indexes)875   EffectInfo TraceMakeTupleListEffectInfo(const CNodePtr &cnode, std::stack<ValuePtr> *indexes) {
876     constexpr int recursive_level = 2;
877     if (indexes->empty()) {
878       MS_LOG(INTERNAL_EXCEPTION) << "Unexpected make_tuple or make_list: " << cnode->DebugString(recursive_level);
879     }
880     // Pop out tuple index.
881     auto top_index_value = indexes->top();
882     indexes->pop();
883     auto input_index = GetInputIndex(top_index_value, cnode, cnode->size());
884     if (indexes->empty()) {
885       // Trace non-tuple.
886       return TraceEffectInfo(cnode->input(input_index));
887     }
888     // This is the tuple of tuple case.
889     return TraceTupleListOrDictEffectInfo(cnode->input(input_index), indexes);
890   }
891 
TraceMakeDictEffectInfo(const CNodePtr & cnode,std::stack<ValuePtr> * indexes)892   EffectInfo TraceMakeDictEffectInfo(const CNodePtr &cnode, std::stack<ValuePtr> *indexes) {
893     constexpr int recursive_level = 2;
894     if (indexes->empty()) {
895       MS_LOG(INTERNAL_EXCEPTION) << "Unexpected make_dict: " << cnode->DebugString(recursive_level);
896     }
897     // Pop out dict index.
898     auto top_key_value = indexes->top();
899     MS_EXCEPTION_IF_NULL(top_key_value);
900     indexes->pop();
901     constexpr size_t keys_node_index = 1;
902     constexpr size_t values_node_index = 2;
903     auto keys_node = cnode->input(keys_node_index);
904     MS_EXCEPTION_IF_NULL(keys_node);
905     auto keys = GetValueNode<ValueTuplePtr>(keys_node);
906     if (keys == nullptr) {
907       MS_LOG(INTERNAL_EXCEPTION) << "Invalid make_dict: " << cnode->DebugString()
908                                  << ", the keys node: " << keys_node->DebugString();
909     }
910     for (size_t i = 0; i < keys->size(); ++i) {
911       MS_EXCEPTION_IF_NULL(keys->value()[i]);
912       if (*(keys->value()[i]) == *top_key_value) {
913         // The values_node is a make_dict.
914         indexes->push(MakeValue(SizeToLong(i)));
915         return TraceTupleListOrDictEffectInfo(cnode->input(values_node_index), indexes);
916       }
917     }
918     MS_LOG(WARNING) << "make_dict untraceable from: " << cnode->DebugString(recursive_level);
919     return {EffectInfo::kDetected, false, false, false};
920   }
921 
TraceDictItemsEffectInfo(const CNodePtr & cnode,std::stack<ValuePtr> * indexes)922   EffectInfo TraceDictItemsEffectInfo(const CNodePtr &cnode, std::stack<ValuePtr> *indexes) {
923     constexpr int recursive_level = 2;
924     // Pop dict_getitem index.
925     if (indexes->empty()) {
926       MS_LOG(INTERNAL_EXCEPTION) << "Unexpected dict_items: " << cnode->DebugString(recursive_level);
927     }
928     auto list_getitem_index_value = indexes->top();
929     indexes->pop();
930     // Pop dict_getitem index.
931     if (indexes->empty()) {
932       MS_LOG(INTERNAL_EXCEPTION) << "Unexpected dict_items: " << cnode->DebugString(recursive_level);
933     }
934     auto tuple_getitem_index_value = indexes->top();
935     indexes->pop();
936     constexpr size_t key_and_value_tuple_size = 2;
937     auto tuple_getitem_index = GetInputIndex(tuple_getitem_index_value, cnode, key_and_value_tuple_size + 1);
938     // If the item is a value_node, skip.
939     if (tuple_getitem_index == 1) {
940       MS_LOG(INFO) << "dict_items untraceable from: " << cnode->DebugString(recursive_level);
941       return {EffectInfo::kDetected, false, false, false};
942     }
943     // dict_items(make_dict(keys_value_tuple, make_tuple()))
944     if (!IsPrimitiveCNode(cnode->input(1), prim::kPrimMakeDict)) {
945       MS_LOG(WARNING) << "dict_items untraceable from: " << cnode->DebugString(recursive_level);
946       return {EffectInfo::kDetected, false, false, false};
947     }
948     // Trace the make_tuple.
949     auto make_dict_cnode = cnode->input(1)->cast<CNodePtr>();
950     constexpr size_t values_node_index = 2;
951     indexes->push(list_getitem_index_value);
952     return TraceTupleListOrDictEffectInfo(make_dict_cnode->input(values_node_index), indexes);
953   }
954 
TraceTupleListCNodeEffectInfo(const CNodePtr & cnode,std::stack<ValuePtr> * indexes)955   EffectInfo TraceTupleListCNodeEffectInfo(const CNodePtr &cnode, std::stack<ValuePtr> *indexes) {
956     MS_EXCEPTION_IF_NULL(indexes);
957     MS_EXCEPTION_IF_NULL(cnode);
958     auto prim = GetCNodePrimitiveWithoutDoSignature(cnode);
959     constexpr int recursive_level = 2;
960     // Trace MakeTuple or MakeList.
961     if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple) || IsPrimitiveEquals(prim, prim::kPrimMakeList)) {
962       return TraceMakeTupleListEffectInfo(cnode, indexes);
963     }
964     // Trace MakeDict.
965     if (IsPrimitiveEquals(prim, prim::kPrimMakeDict)) {
966       return TraceMakeDictEffectInfo(cnode, indexes);
967     }
968     // Trace the case of tuple, list or dict nested.
969     if (IsPrimitiveEquals(prim, prim::kPrimTupleGetItem) || IsPrimitiveEquals(prim, prim::kPrimListGetItem) ||
970         IsPrimitiveEquals(prim, prim::kPrimDictGetItem)) {
971       return TraceGetItemEffectInfo(cnode, indexes);
972     }
973     if (IsPrimitiveEquals(prim, prim::kPrimDictGetValues) && IsPrimitiveCNode(cnode->input(1), prim::kPrimMakeDict)) {
974       auto make_dict_cnode = cnode->input(1)->cast<CNodePtr>();
975       constexpr size_t values_node_index = 2;
976       return TraceTupleListOrDictEffectInfo(make_dict_cnode->input(values_node_index), indexes);
977     }
978     if (IsPrimitiveEquals(prim, prim::kPrimDictItems)) {
979       return TraceDictItemsEffectInfo(cnode, indexes);
980     }
981     // Trace primitive propagating side effect from its input, such as Depend, etc.
982     int input_index = GetSideEffectPropagate(prim);
983     if (input_index > 0 && input_index < static_cast<int>(cnode->size())) {
984       return TraceTupleListOrDictEffectInfo(cnode->input(static_cast<size_t>(input_index)), indexes);
985     }
986     // Tuple returned from func graph call.
987     auto func_graph = GetFuncGraph(cnode);
988     if (func_graph != nullptr) {
989       return TraceTupleListOrDictEffectInfo(func_graph->output(), indexes);
990     }
991     // Tuple returned from a Switch call.
992     if (cnode->size() == 1 && IsPrimitiveCNode(cnode->input(0), prim::kPrimSwitch)) {
993       return TraceTupleFromSwitch(cnode->input(0)->cast<CNodePtr>(), *indexes);
994     }
995     // Tuple is returned from J().
996     //   %1 = J(primal)
997     //   tuple = %1(args)
998     if (cnode->size() > 0 && IsPrimitiveCNode(cnode->input(0), prim::kPrimJ)) {
999       MS_LOG(DEBUG) << "Tuple from J: " << cnode->DebugString(recursive_level);
1000       constexpr size_t func_index = 1;
1001       auto j_conde = cnode->input(0)->cast<CNodePtr>();
1002       auto j_func = j_conde->input(func_index);
1003       auto func_info = TraceEffectInfo(j_func);
1004       // In order to add the Umonad arg to the bprop_top_cell in advance,
1005       // so that the side effects in the bprop graph are sorted earlier than the side effects of the optimizer.
1006       return {EffectInfo::kDetected, false, false, false, func_info.back_mem};
1007     }
1008     // Rare case.
1009     MS_LOG(WARNING) << "Tuple untraceable from: " << cnode->DebugString(recursive_level);
1010     return {EffectInfo::kDetected, false, false, false};
1011   }
1012 
1013   // Trace effect info from a Switch node that output is a tuple.
TraceTupleFromSwitch(const CNodePtr & switch_cnode,const std::stack<ValuePtr> & tuple_indexes)1014   EffectInfo TraceTupleFromSwitch(const CNodePtr &switch_cnode, const std::stack<ValuePtr> &tuple_indexes) {
1015     auto branches = GetSwitchBranches(switch_cnode);
1016     EffectInfo info = {EffectInfo::kDetected, false, false, false, false};
1017     for (auto &branch : branches) {
1018       MS_EXCEPTION_IF_NULL(branch);
1019       auto tuple_indexes_copy = tuple_indexes;
1020       EffectInfo branch_info = TraceTupleListOrDictEffectInfo(branch->output(), &tuple_indexes_copy);
1021       info.Merge(branch_info);
1022     }
1023     return info;
1024   }
1025 
1026   // Setup all branches according the effect info.
SetupEffectBranches(const EffectInfo & info,const std::vector<FuncGraphPtr> & branches)1027   void SetupEffectBranches(const EffectInfo &info, const std::vector<FuncGraphPtr> &branches) {
1028     // Setup monad parameters for all branches according the effect info.
1029     if (info.memory || info.load) {
1030       AddMonadParameters(branches, "u", kUMonad->ToAbstract());
1031     }
1032     if (info.io) {
1033       AddMonadParameters(branches, "io", kIOMonad->ToAbstract());
1034     }
1035     // Set merged effect info to both branches.
1036     for (auto &branch : branches) {
1037       MS_EXCEPTION_IF_NULL(branch);
1038       branch->SetEffectInfo(info);
1039       // Update caller if it is existed.
1040       UpdateBranchCaller(branch);
1041     }
1042   }
1043 
1044   // Merge effect info for switch or switch_layer branch graphs.
MergeEffectInfo(const std::vector<FuncGraphPtr> & branches)1045   EffectInfo MergeEffectInfo(const std::vector<FuncGraphPtr> &branches) {
1046     EffectInfo info = {EffectInfo::kDetected, false, false, false, false};
1047     for (auto &branch : branches) {
1048       MS_EXCEPTION_IF_NULL(branch);
1049       EffectInfo branch_info = ObtainEffectInfoForFuncGraph(branch);
1050       info.Merge(branch_info);
1051     }
1052     return info;
1053   }
1054 
1055   // Trace a cnode for effect info.
TraceEffectInfoForCNode(const CNodePtr & cnode)1056   EffectInfo TraceEffectInfoForCNode(const CNodePtr &cnode) {
1057     MS_EXCEPTION_IF_NULL(cnode);
1058     auto prim = GetCNodePrimitiveWithoutDoSignature(cnode);
1059     if (IsPrimitiveEquals(prim, prim::kPrimSwitch)) {
1060       // Special handling for Switch primitive.
1061       return TraceSwitchEffectInfo(cnode);
1062     }
1063 
1064     if (IsPrimitiveEquals(prim, prim::kPrimSwitchLayer)) {
1065       // Special handling for SwitchLayer primitive.
1066       return TraceSwitchLayerEffectInfo(cnode);
1067     }
1068 
1069     if (IsPrimitiveEquals(prim, prim::kPrimTupleGetItem) || IsPrimitiveEquals(prim, prim::kPrimListGetItem) ||
1070         IsPrimitiveEquals(prim, prim::kPrimDictGetItem)) {
1071       // Trace tuple_getitem or list_getitem or dict_getitem.
1072       std::stack<ValuePtr> indexes;
1073       return TraceGetItemEffectInfo(cnode, &indexes);
1074     }
1075 
1076     if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple) || IsPrimitiveEquals(prim, prim::kPrimMakeList)) {
1077       // Trace make_tuple or make_list.
1078       EffectInfo info{EffectInfo::kDetected, false, false, false, false};
1079       for (size_t i = 1; i < cnode->size(); ++i) {
1080         auto input_info = TraceEffectInfo(cnode->input(i));
1081         info.Merge(input_info);
1082       }
1083       return info;
1084     }
1085 
1086     // For high-order primitive such as Partial,
1087     // we trace effect info from its argument.
1088     int index_prim = GetSideEffectPropagate(prim);
1089     if (index_prim > 0 && index_prim < static_cast<int>(cnode->size())) {
1090       return TraceEffectInfo(cnode->input(static_cast<size_t>(index_prim)));
1091     }
1092 
1093     // For func graph calls, we trace effect info from graph output.
1094     auto called_graph = GetFuncGraph(cnode);
1095     if (called_graph != nullptr) {
1096       // Save the caller of the graph, so that we can update
1097       // monad parameters for it when requires.
1098       (void)graph_callers_[called_graph].emplace(cnode);
1099       return TraceEffectInfo(called_graph->output());
1100     }
1101 
1102     auto func_cnode = GetFuncCNode(cnode);
1103     if (func_cnode != nullptr) {
1104       //
1105       // For ClassType as the input[0], if it is a primitive class
1106       // with 'side_effect_propagate' attribute, we trace side effect
1107       // from its argument indxed by the attribute value.
1108       //
1109       // e.g.:
1110       //     setpara = P.Partial()(P.Assign, self.para)
1111       //     setpara(x)
1112       //
1113       auto class_type = GetValueNode<ClassTypePtr>(func_cnode->input(0));
1114       if (class_type != nullptr) {
1115         int index = GetSideEffectPropagate(class_type);
1116         if (index > 0 && index < static_cast<int>(cnode->size())) {
1117           return TraceEffectInfo(cnode->input(static_cast<size_t>(index)));
1118         }
1119       }
1120 
1121       // For high order cnode, trace effect info from the output of the input cnode.
1122       return TraceOutputEffectInfo(func_cnode);
1123     }
1124 
1125     // %0 = ExtractKeywordArg("key", value) // Maybe func_graph which has side effect.
1126     // %1 = %0(arg1, arg2)                  // Need add monad.
1127     if (IsPrimitiveCNode(cnode, prim::kPrimExtractKeywordArg)) {
1128       auto abs = cnode->abstract();
1129       auto real_func = GetFuncGraphFromFuncGraphAbstract(abs);
1130       if (real_func != nullptr) {
1131         // Try to obtain the effect info of func graph.
1132         auto effect_info = ObtainEffectInfoForFuncGraph(real_func);
1133         MS_LOG(DEBUG) << "The real_func: " << real_func->ToString() << ", " << abs->ToString()
1134                       << ", cnode: " << cnode->DebugString() << ", effect_info: " << effect_info.memory << "/"
1135                       << effect_info.io << "/" << effect_info.load;
1136         return effect_info;
1137       }
1138     }
1139     // Otherwise, assume no side effect and stop trace.
1140     MS_LOG(INFO) << "CNode side effect unknown: " << cnode->DebugString();
1141     return {EffectInfo::kDetected, false, false, false, false};
1142   }
1143 
1144   // Trace effect info from output of the cnode.
TraceOutputEffectInfo(const CNodePtr & cnode)1145   EffectInfo TraceOutputEffectInfo(const CNodePtr &cnode) {
1146     MS_EXCEPTION_IF_NULL(cnode);
1147     std::vector<ValuePtr> values;
1148     GetOutputValues(cnode, &values);
1149     if (values.size() == 1) {
1150       return ObtainEffectInfoForValue(values.front());
1151     }
1152     EffectInfo info{EffectInfo::kDetected, false, false, false, false};
1153     for (auto &value : values) {
1154       info.Merge(ObtainEffectInfoForValue(value));
1155     }
1156     return info;
1157   }
1158 
ObtainEffectInfoForValue(const ValuePtr & value)1159   EffectInfo ObtainEffectInfoForValue(const ValuePtr &value) {
1160     MS_EXCEPTION_IF_NULL(value);
1161     // FuncGraph.
1162     auto graph = dyn_cast<FuncGraph>(value);
1163     if (graph != nullptr) {
1164       return ObtainEffectInfoForFuncGraph(graph);
1165     }
1166     // Primitive.
1167     auto prim = dyn_cast<Primitive>(value);
1168     if (prim != nullptr) {
1169       return GetPrimEffectInfo(prim);
1170     }
1171     MS_LOG(INFO) << "Value side effect unknown: " << value->ToString();
1172     return {EffectInfo::kDetected, false, false, false, false};
1173   }
1174 
GetOutputValues(const CNodePtr & cnode,std::vector<ValuePtr> * values)1175   void GetOutputValues(const CNodePtr &cnode, std::vector<ValuePtr> *values) {
1176     MS_EXCEPTION_IF_NULL(cnode);
1177     // CNode is a func graph call.
1178     auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
1179     if (graph != nullptr) {
1180       GetOutputValues(graph, values);
1181       return;
1182     }
1183     // CNode is applying another cnode.
1184     auto func_cnode = dyn_cast<CNode>(cnode->input(0));
1185     if (func_cnode != nullptr) {
1186       GetOutputValues(func_cnode, values);
1187       return;
1188     }
1189     // Primitive cnode.
1190     auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
1191     if (IsPrimitiveEquals(prim, prim::kPrimSwitch)) {
1192       // Switch.
1193       auto branches = GetSwitchBranches(cnode);
1194       GetOutputValues(branches, values);
1195       return;
1196     }
1197     if (IsPrimitiveEquals(prim, prim::kPrimSwitchLayer)) {
1198       // Switch layer.
1199       auto branches = GetSwitchLayerBranches(cnode);
1200       GetOutputValues(branches, values);
1201       return;
1202     }
1203     if (IsPrimitiveEquals(prim, prim::kPrimPartial)) {
1204       // Partial.
1205       auto fg = GetValueNode<FuncGraphPtr>(cnode->input(1));
1206       if (fg != nullptr) {
1207         GetOutputValues(fg, values);
1208         return;
1209       }
1210     }
1211     // Other cases not supported yet.
1212     MS_LOG(INFO) << "Output unknown: " << cnode->DebugString();
1213   }
1214 
GetOutputValues(const FuncGraphPtr & graph,std::vector<ValuePtr> * values)1215   void GetOutputValues(const FuncGraphPtr &graph, std::vector<ValuePtr> *values) {
1216     MS_EXCEPTION_IF_NULL(graph);
1217     MS_EXCEPTION_IF_NULL(values);
1218     auto output = graph->output();
1219     // Output is a value node.
1220     auto value = GetValueNode(output);
1221     if (value != nullptr) {
1222       (void)values->emplace_back(value);
1223       return;
1224     }
1225 
1226     // Output is a cnode.
1227     auto cnode = dyn_cast<CNode>(output);
1228     if (cnode != nullptr) {
1229       GetOutputValues(cnode, values);
1230       return;
1231     }
1232     MS_EXCEPTION_IF_NULL(output);
1233     MS_LOG(INFO) << "Unexpected output: " << output->DebugString();
1234   }
1235 
GetOutputValues(const std::vector<FuncGraphPtr> & graphs,std::vector<ValuePtr> * values)1236   void GetOutputValues(const std::vector<FuncGraphPtr> &graphs, std::vector<ValuePtr> *values) {
1237     for (auto &graph : graphs) {
1238       GetOutputValues(graph, values);
1239     }
1240   }
1241 
1242   // Trace an AnfNode for effect info.
TraceEffectInfo(const AnfNodePtr & node)1243   EffectInfo TraceEffectInfo(const AnfNodePtr &node) {
1244     MS_EXCEPTION_IF_NULL(node);
1245     // Trace cnode.
1246     auto cnode = node->cast<CNodePtr>();
1247     if (cnode != nullptr) {
1248       return TraceEffectInfoForCNode(cnode);
1249     }
1250 
1251     // Trace parameter.
1252     auto para = node->cast<ParameterPtr>();
1253     if (para != nullptr) {
1254       return TraceEffectInfoForParameter(para);
1255     }
1256 
1257     // Trace primitive.
1258     auto prim = GetPrimitiveWithoutDoSignature(node);
1259     if (prim != nullptr) {
1260       return GetPrimEffectInfo(prim);
1261     }
1262 
1263     // Trace func graph.
1264     auto graph = GetValueNode<FuncGraphPtr>(node);
1265     if (graph != nullptr) {
1266       return ObtainEffectInfoForFuncGraph(graph);
1267     }
1268 
1269     // Other ValueNode has no side effects. For example: ValueNode<ClassType> node.
1270     //  node1 = ValueNode<ClassType> class 'mindspore.ops.operations.debug_ops.Print'
1271     //  node2 = _get_cache_prim(node1) // the node has side effects.
1272     if (node->isa<ValueNode>()) {
1273       MS_LOG(DEBUG) << "The ValueNode has no side effect: " << node->DebugString();
1274       return {EffectInfo::kDetected, false, false, false, false};
1275     }
1276     // Something is wrong if we reached here.
1277     MS_LOG(WARNING) << "The effect info of the node is untraceable: " << node->DebugString()
1278                     << ".\nLine:" << trace::GetDebugInfoStr(node->debug_info());
1279     return {EffectInfo::kDetected, false, false, false, false};
1280   }
1281 
GetParameterIndex(const FuncGraphPtr & func_graph,const ParameterPtr & para) const1282   int GetParameterIndex(const FuncGraphPtr &func_graph, const ParameterPtr &para) const {
1283     int parameter_index = 0;
1284     for (auto &parameter : func_graph->parameters()) {
1285       if (para == parameter) {
1286         return parameter_index;
1287       }
1288       ++parameter_index;
1289     }
1290     MS_LOG(INTERNAL_EXCEPTION) << "Parameter not found: " << (para ? para->DebugString() : "<null>");
1291   }
1292 
1293   // Trace effect info from function parameter.
TraceEffectInfoForParameter(const ParameterPtr & para)1294   EffectInfo TraceEffectInfoForParameter(const ParameterPtr &para) {
1295     EffectInfo info{EffectInfo::kDetected, false, false, false, false};
1296     ForEachRealArguments(para, [this, &para, &info](const AnfNodePtr &arg) {
1297       // Merge caller input effect info.
1298       auto input_info = TraceEffectInfo(arg);
1299       info.Merge(input_info);
1300     });
1301     return info;
1302   }
1303 
ForEachRealArguments(const ParameterPtr & para,const std::function<void (const AnfNodePtr &)> & handler)1304   void ForEachRealArguments(const ParameterPtr &para, const std::function<void(const AnfNodePtr &)> &handler) {
1305     MS_EXCEPTION_IF_NULL(para);
1306     auto func_graph = para->func_graph();
1307     MS_EXCEPTION_IF_NULL(func_graph);
1308     // Find index of the parameter, starts from 0.
1309     const int para_index = GetParameterIndex(func_graph, para);
1310     const size_t input_index = static_cast<size_t>(para_index) + 1;
1311     // Search user cnodes of the func graph.
1312     auto &users = func_graph->func_graph_cnodes_index();
1313     if (users.empty()) {
1314       MS_LOG(WARNING) << "Unused graph for parameter " << para->DebugString();
1315     }
1316     // Push the parameter to a stack so that we can check cycle binding.
1317     NodeStackGuard param_stack_guard(&formal_param_stack_, para);
1318     for (auto &user : users) {
1319       auto use_index = user.first->second;
1320       if (use_index != 0) {
1321         // Skip non-caller usage.
1322         continue;
1323       }
1324       // Caller cnode.
1325       auto cnode = dyn_cast<CNode>(user.first->first);
1326       MS_EXCEPTION_IF_NULL(cnode);
1327       if (cnode != nullptr && input_index < cnode->size()) {
1328         auto &input = cnode->input(input_index);
1329         if (formal_param_stack_.contains(input)) {
1330           // Skip if the input is a parameter that we are finding its real argument.
1331           continue;
1332         }
1333         handler(input);
1334       }
1335     }
1336   }
1337 
1338   // For call node, returns effect info of the callee graph.
GetCallEffectInfo(const CNodePtr & cnode)1339   EffectInfo GetCallEffectInfo(const CNodePtr &cnode) {
1340     MS_EXCEPTION_IF_NULL(cnode);
1341     constexpr size_t min_call_node_size = 2;
1342     if (cnode->size() < min_call_node_size) {
1343       MS_LOG(INTERNAL_EXCEPTION) << "Invalid call node: " << cnode->DebugString();
1344     }
1345     auto func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
1346     if (func_graph == nullptr) {
1347       MS_LOG(INTERNAL_EXCEPTION) << "Invalid call node: " << cnode->DebugString();
1348     }
1349     return ObtainEffectInfoForFuncGraph(func_graph);
1350   }
1351 
1352   // Detect effect info by depth first search.
ObtainEffectInfoForCNodeInner(const CNodePtr & cnode)1353   EffectInfo ObtainEffectInfoForCNodeInner(const CNodePtr &cnode) {
1354     // For primitive, get effect info from its attributes and inputs.
1355     auto prim = GetCNodePrimitiveWithoutDoSignature(cnode);
1356     if (prim != nullptr) {
1357       // Skip 'return' cnode.
1358       if (IsPrimitiveEquals(prim, prim::kPrimReturn)) {
1359         return {EffectInfo::kDetected, false, false, false, false};
1360       }
1361       // Special handling for 'call' cnode.
1362       if (IsPrimitiveEquals(prim, prim::kPrimCall)) {
1363         return GetCallEffectInfo(cnode);
1364       }
1365       auto info = GetPrimEffectInfo(prim);
1366       if (!info.memory && !IsKeepRef(prim)) {
1367         // For primitive calls, if no memory effects but
1368         // Ref parameter used, we will insert 'load' before them.
1369         // Except for primitives like J(f) or Partial(f, x) which propagate side effect,
1370         // load is inserted inside the func_graph f.
1371         info.load = HasRefInput(cnode);
1372       }
1373       if (!info.memory && IsNonEffectRealNodeAndInputIsDynamic(cnode)) {
1374         info.load = HasRefSequenceInput(cnode);
1375       }
1376       return info;
1377     }
1378 
1379     // For func graph, detect effect info by its children cnodes.
1380     auto func_graph = GetFuncGraph(cnode);
1381     if (func_graph != nullptr) {
1382       // Save the caller of the graph, so that we can update
1383       // monad parameters for it when requires.
1384       (void)graph_callers_[func_graph].emplace(cnode);
1385       return ObtainEffectInfoForFuncGraph(func_graph);
1386     }
1387 
1388     // When input[0] is a cnode, it is a function returned from
1389     // a high-order function call, we trace it by return value.
1390     auto func_cnode = GetFuncCNode(cnode);
1391     if (func_cnode != nullptr) {
1392       caller_ = cnode;
1393       auto effect_info = TraceEffectInfoForCNode(func_cnode);
1394       // Retry for Partial call.
1395       return TracePartialCallEffectInfo(cnode, effect_info);
1396     }
1397 
1398     // When input[0] is a parameter, it is a function parameter for
1399     // the high-order function, we trace it by caller.
1400     auto func_para = GetFuncParameter(cnode);
1401     if (func_para != nullptr) {
1402       auto effect_info = TraceEffectInfoForParameter(func_para);
1403       // Retry for Partial call.
1404       return TracePartialCallEffectInfo(cnode, effect_info);
1405     }
1406 
1407     // When input[0] is a MultitypeFuncGraph, it's not specialized
1408     // as one of its parameters is AbstractUndertermined,
1409     // This MultitypeFuncGraph may be specialized at next Renormalize
1410     // process, but we have to keep the order by insert UMonad now,
1411     // otherwise order will be lost in next Renormalize.
1412     // So assume it has memory side effect conservatively.
1413     auto func_multitype = GetFuncMultitypeFuncGraph(cnode);
1414     if (func_multitype != nullptr) {
1415       MS_LOG(DEBUG) << "Assume memory side effect for: " << cnode->DebugString();
1416       return {EffectInfo::kDetected, true, false, false, false};
1417     }
1418 
1419     // For other cnodes, we assume that they have no side effects.
1420     MS_LOG(DEBUG) << "Assume no side effect for: " << cnode->DebugString();
1421     return {EffectInfo::kDetected, false, false, false, false};
1422   }
1423 
1424   // Gets EffectInfo for CNode.
ObtainEffectInfoForCNode(const CNodePtr & cnode)1425   EffectInfo ObtainEffectInfoForCNode(const CNodePtr &cnode) {
1426     const auto &effect_info = cnode->GetEffectInfo();
1427     if (effect_info.state == EffectInfo::kDetected) {
1428       // Effect info already detected, return it.
1429       return effect_info;
1430     }
1431 
1432     // Detect effect info for the cnode.
1433     EffectInfo info = ObtainEffectInfoForCNodeInner(cnode);
1434     if (info.state == EffectInfo::kDetected) {
1435       // Save detected info into cnode.
1436       cnode->SetEffectInfo(info);
1437     }
1438     return info;
1439   }
1440 
1441   // Gets SCC that the given graph belongs to.
GetScc(const FuncGraphPtr & func_graph) const1442   SccPtr GetScc(const FuncGraphPtr &func_graph) const {
1443     auto found = scc_map_.find(func_graph);
1444     if (found == scc_map_.end()) {
1445       return nullptr;
1446     }
1447     return found->second;
1448   }
1449 
1450   // Set effect info for all member graphs in the SCC.
SetSccEffectInfo(const SccPtr & scc,const EffectInfo & info) const1451   void SetSccEffectInfo(const SccPtr &scc, const EffectInfo &info) const {
1452     MS_EXCEPTION_IF_NULL(scc);
1453     for (auto &g : *scc) {
1454       MS_EXCEPTION_IF_NULL(g);
1455       g->SetEffectInfo(info);
1456     }
1457   }
1458 
1459   // Gets EffectInfo for func graph's total used.
ObtainEffectInfoForFuncGraphs(const FuncGraphPtr & func_graph)1460   void ObtainEffectInfoForFuncGraphs(const FuncGraphPtr &func_graph) {
1461     MS_EXCEPTION_IF_NULL(func_graph);
1462     auto &used_func_graphs = func_graph->func_graphs_used_total();
1463     for (auto iter = used_func_graphs.crbegin(); iter != used_func_graphs.crend(); ++iter) {
1464       auto used_func_graph = *iter;
1465       MS_EXCEPTION_IF_NULL(used_func_graph);
1466       (void)ObtainEffectInfoForFuncGraph(used_func_graph);
1467     }
1468     ObtainEffectInfoForFuncGraph(func_graph);
1469   }
1470 
1471   // Gets EffectInfo for func graph.
ObtainEffectInfoForFuncGraph(const FuncGraphPtr & func_graph)1472   EffectInfo ObtainEffectInfoForFuncGraph(const FuncGraphPtr &func_graph) {
1473     MS_EXCEPTION_IF_NULL(func_graph);
1474     auto effect_info = func_graph->GetEffectInfo();
1475     if (effect_info.state != EffectInfo::kUnknown) {
1476       return effect_info;
1477     }
1478 
1479     // Get SCC that this graph belongs to.
1480     auto scc = GetScc(func_graph);
1481     if (scc == nullptr) {
1482       MS_LOG(INTERNAL_EXCEPTION) << "Scc should not be null, func_graph: " << func_graph->ToString();
1483     }
1484     // To prevent SCC members be visited again, we set effect info
1485     // to 'kDetecting' state before start to check cnodes.
1486     EffectInfo info{EffectInfo::kDetecting, false, false, false, false};
1487     SetSccEffectInfo(scc, info);
1488 
1489     // Check side effects for all cnodes in the SCC.
1490     std::vector<CNodePtr> undetected;
1491     for (auto &g : *scc) {
1492       MS_EXCEPTION_IF_NULL(g);
1493       for (auto &weak_cnode : g->order_list()) {
1494         const auto &cnode = weak_cnode.lock();
1495         if (cnode == nullptr) {
1496           continue;
1497         }
1498         auto cnode_effect = ObtainEffectInfoForCNode(cnode);
1499         if (cnode_effect.state != EffectInfo::kDetected) {
1500           // For side effect undetected node, it could be a call to the SCC member graph,
1501           // we will try to check side effect again after SCC side effect detected.
1502           undetected.push_back(cnode);
1503         }
1504         // Merge effect info from the node.
1505         info.Merge(cnode_effect);
1506       }
1507       // Make sure all sub-graphs is checked. since some sub-graphs may not directly called,
1508       // for example: return ValueNode(sub_graph).
1509       for (auto &sg : g->func_graphs_used()) {
1510         (void)ObtainEffectInfoForFuncGraph(sg.first);
1511       }
1512     }
1513     // Update effect into for all members of the SCC.
1514     info.state = EffectInfo::kDetected;
1515     SetSccEffectInfo(scc, info);
1516 
1517     // Check undetected cnodes again after side effect of the SCC is detected.
1518     for (auto &cnode : undetected) {
1519       MS_EXCEPTION_IF_NULL(cnode);
1520       auto cnode_effect = ObtainEffectInfoForCNode(cnode);
1521       // Side effect should be detected now, except free variable nodes that not belong to current SCC.
1522       if (cnode_effect.state != EffectInfo::kDetected && scc->find(cnode->func_graph()) != scc->end()) {
1523         MS_LOG(INTERNAL_EXCEPTION) << "Side effect is undetectable: " << cnode->DebugString();
1524       }
1525     }
1526     return info;
1527   }
1528 
1529   // The caller of switch node is also a caller of the branches, we save them
1530   // so that we can update monad parameters for the caller when it requires.
SaveBranchCaller(const CNodePtr & switch_node,const FuncGraphVector & branches)1531   void SaveBranchCaller(const CNodePtr &switch_node, const FuncGraphVector &branches) {
1532     MS_EXCEPTION_IF_NULL(switch_node);
1533     auto fg = switch_node->func_graph();
1534     MS_EXCEPTION_IF_NULL(fg);
1535     auto manager = fg->manager();
1536     MS_EXCEPTION_IF_NULL(manager);
1537     auto &node_users = manager->node_users();
1538     auto found = node_users.find(switch_node);
1539     if (found == node_users.end()) {
1540       MS_LOG(WARNING) << "Caller not found for " << switch_node->DebugString();
1541       return;
1542     }
1543     bool is_multi_branches = (branches.size() > 1);
1544     for (auto &user : found->second) {
1545       auto cnode = dyn_cast<CNode>(user.first);
1546       if (cnode == nullptr || user.second != 0) {
1547         continue;
1548       }
1549       // The cnode is the switch caller.
1550       if (is_multi_branches) {
1551         // Caller to branches.
1552         (void)switch_calls_.emplace(cnode, branches);
1553       }
1554       for (auto &branch : branches) {
1555         // Branch to caller.
1556         (void)graph_callers_[branch].emplace(cnode);
1557       }
1558     }
1559   }
1560 
UpdateBranchCaller(const FuncGraphPtr & branch)1561   void UpdateBranchCaller(const FuncGraphPtr &branch) {
1562     MS_EXCEPTION_IF_NULL(branch);
1563     auto iter = graph_callers_.find(branch);
1564     if (iter == graph_callers_.end()) {
1565       return;
1566     }
1567     const auto &info = branch->GetEffectInfo();
1568     for (auto &caller : iter->second) {
1569       AddMonadForCaller(caller, info);
1570     }
1571   }
1572 
AddMonadForCaller(const CNodePtr & caller,const EffectInfo & info) const1573   void AddMonadForCaller(const CNodePtr &caller, const EffectInfo &info) const {
1574     if (info.memory || info.load) {
1575       // Add u monad argument to caller if need.
1576       AddMonadArgument(caller, kUMonad);
1577     }
1578     if (info.io) {
1579       // Add io monad argument to caller if need.
1580       AddMonadArgument(caller, kIOMonad);
1581     }
1582   }
1583 
AddMonadArgument(const CNodePtr & cnode,const ValuePtr & monad) const1584   void AddMonadArgument(const CNodePtr &cnode, const ValuePtr &monad) const {
1585     MS_EXCEPTION_IF_NULL(cnode);
1586     MS_EXCEPTION_IF_NULL(monad);
1587     auto monad_abs = monad->ToAbstract();
1588     for (size_t i = 1; i < cnode->size(); ++i) {
1589       auto abs = cnode->input(i)->abstract();
1590       if (abs != nullptr && *abs == *monad_abs) {
1591         // Skip if monad argument already existed.
1592         return;
1593       }
1594     }
1595     // Add monad argument if not yet.
1596     auto monad_input = NewValueNode(monad);
1597     monad_input->set_abstract(monad_abs);
1598     if ((monad == kUMonad) && cnode->size() > 1 && HasAbstractIOMonad(cnode->weak_inputs().back().lock())) {
1599       // Insert u monad before io monad.
1600       size_t last_index = cnode->size() - 1;
1601       cnode->add_input(cnode->input(last_index));
1602       cnode->set_input(last_index, monad_input);
1603     } else {
1604       // Add monad as the last input.
1605       cnode->add_input(monad_input);
1606     }
1607   }
1608 
1609   // The root graph.
1610   FuncGraphPtr root_;
1611 
1612   // SCC map.
1613   SccMap scc_map_;
1614 
  // Map graph to its caller cnodes, so that we can add monad inputs to the
  // caller cnode when we later find that the graph added monad parameters.
1617   mindspore::HashMap<FuncGraphPtr, mindspore::HashSet<CNodePtr>> graph_callers_;
1618 
1619   // Current high order func caller cnode.
1620   CNodePtr caller_ = nullptr;
1621 
  // Save partial CNode caller cnodes and their real func graphs, so that we can check and
  // update monad parameters for the real func graph according to the caller inputs.
1624   mindspore::HashMap<CNodePtr, FuncGraphPtr> partial_cnode_calls_;
1625 
  // Save switch caller cnodes and their branches, so that we can check and
  // update monad parameters for branches according to the caller inputs.
1628   mindspore::HashMap<CNodePtr, FuncGraphVector> switch_calls_;
1629 
1630   // switch_layer_calls save all switch_layer calls, so that
1631   // we can check whether monad argument should be added for them.
1632   std::vector<SwitchLayerCall> switch_layer_calls_;
1633 
1634   // Save traced formal parameters so that we can check cycle parameter binding.
1635   OrderedSet<AnfNodePtr> formal_param_stack_;
1636 };  // class SideEffectFinder
1637 
1638 // --------------------------------------------------------------------
1639 // AutoMonadConverter converts side-effect cnodes into monad form.
1640 // --------------------------------------------------------------------
1641 class AutoMonadConverter {
1642  public:
Handle(const FuncGraphPtr & func_graph,bool top)1643   static bool Handle(const FuncGraphPtr &func_graph, bool top) {
1644     AutoMonadConverter converter(func_graph, top);
1645     return converter.Run();
1646   }
1647 
1648  private:
AutoMonadConverter(const FuncGraphPtr & func_graph,bool top)1649   AutoMonadConverter(const FuncGraphPtr &func_graph, bool top)
1650       : func_graph_(func_graph), manager_(func_graph->manager()), top_(top) {}
1651 
1652   ~AutoMonadConverter() = default;
1653 
Run()1654   bool Run() {
1655     // Handle cnodes for side effects.
1656     const auto &info = func_graph_->GetEffectInfo();
1657     if (info.state == EffectInfo::kDetected) {
1658       HandleCNodes();
1659     }
1660 
1661     // Safe to clear isolated nodes after handled side effect nodes.
1662     ClearIsolatedNodes();
1663 
1664     // Clean up after conversion finished.
1665     func_graph_->ClearOrderList();
1666     return has_effect_cnodes_;
1667   }
1668 
1669   // Check if there are side effects from effect info.
HasSideEffects(const EffectInfo & info)1670   static bool HasSideEffects(const EffectInfo &info) { return (info.memory || info.io || info.load || info.back_mem); }
1671 
1672   // Gets effect info for a cnode.
GetEffectInfoFromCNode(const CNodePtr & cnode) const1673   const EffectInfo &GetEffectInfoFromCNode(const CNodePtr &cnode) const {
1674     MS_EXCEPTION_IF_NULL(cnode);
1675     auto &effect_info = cnode->GetEffectInfo();
1676     if (effect_info.state != EffectInfo::kDetected) {
1677       // Effect info should have been set by SideEffectFinder.
1678       MS_LOG(WARNING) << "Side effects not detected: " << cnode->DebugString();
1679     }
1680     return effect_info;
1681   }
1682 
1683   // Handle CNodes for side effects.
HandleCNodes()1684   void HandleCNodes() {
1685     // Check whether UpdateState and Depend are required.
1686     bool update_state = NeedUpdateState();
1687 
1688     // Check all cnodes in order list.
1689     for (auto &weak_cnode : func_graph_->order_list()) {
1690       const auto &cnode = weak_cnode.lock();
1691       if (cnode == nullptr) {
1692         continue;
1693       }
1694       // Process param.value()  Load(param, U) ---> Load(param, GetUniverse())
1695       if (IsPrimitiveCNode(cnode, prim::kPrimLoad)) {
1696         const size_t param_index = 1;
1697         const size_t monad_index = 2;
1698         auto param = cnode->input(param_index);
1699         auto load_monad = cnode->input(monad_index);
1700         auto param_abs = param->abstract();
1701         MS_EXCEPTION_IF_NULL(param_abs);
1702         if (param_abs->isa<abstract::AbstractRefTensor>() && IsValueNode<UMonad>(load_monad)) {
1703           auto current_u = GetUniverse();
1704           manager_->SetEdge(cnode, SizeToInt(monad_index), current_u);
1705           u_ = UpdateState(current_u, cnode);
1706           continue;
1707         }
1708       }
1709       auto &info = GetEffectInfoFromCNode(cnode);
1710       has_effect_cnodes_ = (has_effect_cnodes_ || HasSideEffects(info));
1711       if (cnode->func_graph() != func_graph_) {
1712         // Handle outer cnode.
1713         HandleOuterNode(cnode, info);
1714       } else {
1715         // Handle cnode with memory side effects.
1716         if (info.memory) {
1717           HandleMemoryEffects(cnode, update_state);
1718         } else if (info.load) {
1719           // If no memory side effects, handle load if need.
1720           HandleLoad(cnode, update_state);
1721         }
1722         // Handle cnode with IO side effects.
1723         if (info.io) {
1724           HandleIoEffects(cnode, update_state);
1725         }
1726         // If the node has no side effects but 'no_eliminate' flag is set,
1727         // we save it to no_eliminate_nodes and handle them late.
1728         if (!info.memory && !info.io && IsNoEliminateNode(cnode)) {
1729           (void)no_eliminate_nodes_.emplace_back(cnode);
1730         }
1731       }
1732       cnode->SetEffectHandled(true);
1733     }
1734     // Attach no eliminate nodes to output.
1735     HandleNoEliminateNodes();
1736     // Attach monad to output if required.
1737     if (update_state) {
1738       AttachMonadToOutput();
1739     }
1740   }
1741 
1742   // Return true if the given cnode is primitive cnode with 'no_eliminate' flag.
IsNoEliminateNode(const CNodePtr & cnode) const1743   bool IsNoEliminateNode(const CNodePtr &cnode) const {
1744     if (cnode == nullptr || cnode->size() == 0) {
1745       return false;
1746     }
1747     auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
1748     if (prim == nullptr) {
1749       return false;
1750     }
1751     return GetPrimitiveFlag(prim, ATTR_NO_ELIMINATE);
1752   }
1753 
1754   // Attach no eliminate nodes to output.
HandleNoEliminateNodes()1755   void HandleNoEliminateNodes() {
1756     if (no_eliminate_nodes_.empty()) {
1757       // Skip if no nodes to be handled.
1758       return;
1759     }
1760     // If only one node, attach it to output directly.
1761     if (no_eliminate_nodes_.size() == 1) {
1762       AttachToOutput(no_eliminate_nodes_.front());
1763       return;
1764     }
1765     // For multiple nodes, attach them to output by a tuple.
1766     std::vector<AnfNodePtr> tuple_inputs;
1767     AbstractBasePtrList element_abstracts;
1768     tuple_inputs.reserve(no_eliminate_nodes_.size() + 1);
1769     element_abstracts.reserve(no_eliminate_nodes_.size());
1770     (void)tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
1771     for (auto &node : no_eliminate_nodes_) {
1772       (void)tuple_inputs.emplace_back(node);
1773       (void)element_abstracts.emplace_back(node->abstract());
1774     }
1775     auto make_tuple_node = func_graph_->NewCNode(tuple_inputs);
1776     make_tuple_node->set_abstract(std::make_shared<abstract::AbstractTuple>(element_abstracts));
1777     AttachToOutput(make_tuple_node);
1778   }
1779 
1780   // Clean no side effect dependency nodes.
1781   //   From:  output = Depend(output, StopGrad)
1782   //          return output
1783   //
1784   //   To:    return output
ClearIsolatedNodes() const1785   void ClearIsolatedNodes() const {
1786     auto output = GetGraphOutput();
1787     constexpr size_t attach_index = 2;
1788     if (IsPrimitiveCNode(output, prim::kPrimDepend)) {
1789       auto attach_node = output->cast<CNodePtr>()->input(attach_index);
1790       if (IsPrimitiveCNode(attach_node, prim::kPrimStopGradient)) {
1791         auto attach_cnode = attach_node->cast<CNodePtr>();
1792         auto input = attach_cnode->input(1);
1793         // Check the input of stop_gradient.
1794         if (input->isa<CNode>() && input->cast<CNodePtr>()->has_side_effect_node()) {
1795           MS_LOG(WARNING) << "Some side effect nodes were eliminated by mistake.";
1796         }
1797         // Replace Depend(orig_output, StopGrad) node with orig_output.
1798         // After that, nodes may be eliminated if have no side effects.
1799         auto &orig_output = output->cast<CNodePtr>()->input(1);
1800         func_graph_->set_output(orig_output);
1801       }
1802     }
1803   }
1804 
HandleOuterNode(const CNodePtr & cnode,const EffectInfo & info)1805   void HandleOuterNode(const CNodePtr &cnode, const EffectInfo &info) {
1806     MS_EXCEPTION_IF_NULL(cnode);
1807     if (info.memory || info.load) {
1808       (void)GetUniverse();
1809       bool load_with_primitive = (info.load && IsPrimitiveCNode(cnode));
1810       if (!cnode->IsEffectHandled() && !load_with_primitive) {
1811         auto u_node = NewValueNode(kUMonad);
1812         u_node->set_abstract(kUMonad->ToAbstract());
1813         cnode->add_input(u_node);
1814       }
1815     }
1816     if (info.io) {
1817       (void)GetIoState();
1818       if (!cnode->IsEffectHandled()) {
1819         auto io = NewValueNode(kIOMonad);
1820         io->set_abstract(kIOMonad->ToAbstract());
1821         cnode->add_input(io);
1822       }
1823     }
1824   }
1825 
1826   //
1827   // Convert cnode with memory side effect to monad form,
1828   // from:
1829   //    output = func(input)
1830   // to:
1831   //    output = func(input, u)
1832   //    u = UpdateState(u, output) # if update_state is true
1833   //
HandleMemoryEffects(const CNodePtr & cnode,bool update_state)1834   void HandleMemoryEffects(const CNodePtr &cnode, bool update_state) {
1835     const auto &u = GetUniverse();
1836     AddMonadInput(cnode, u);
1837     if (update_state) {
1838       u_ = UpdateState(u, cnode);
1839     }
1840   }
1841 
1842   //
1843   // Convert cnode with io side effect to monad form,
1844   // from:
1845   //    output = func(input)
1846   // to:
1847   //    output = func(input, io)
1848   //    io = UpdateState(io, output) # if update_state is true
1849   //
HandleIoEffects(const CNodePtr & cnode,bool update_state)1850   void HandleIoEffects(const CNodePtr &cnode, bool update_state) {
1851     const auto &io = GetIoState();
1852     AddMonadInput(cnode, io);
1853     if (update_state) {
1854       io_ = UpdateState(io, cnode);
1855     }
1856   }
1857 
HandleLoad(const CNodePtr & cnode,bool update_state)1858   void HandleLoad(const CNodePtr &cnode, bool update_state) {
1859     MS_EXCEPTION_IF_NULL(cnode);
1860     // Check if a sequence which has ref exists in the inputs of the cnode, and the cnode is a real node.
1861     if (IsNonEffectRealNodeAndInputIsDynamic(cnode)) {
1862       return InsertLoadForSequenceRef(cnode, update_state);
1863     }
1864     if (IsValueNode<Primitive>(cnode->input(0))) {
1865       // For primitive calls that use Ref as input, insert Loads before them.
1866       InsertLoads(cnode, update_state);
1867     } else {
1868       // For non-primitive calls, load is used inside the callee,
1869       // We do not insert load for it but handle it as a side
1870       // effects cnode.
1871       HandleMemoryEffects(cnode, update_state);
1872     }
1873   }
1874 
NewItemNode(const AnfNodePtr & node,const AbstractBasePtr & seq_abs,const AbstractBasePtr & item_abs,size_t index)1875   AnfNodePtr NewItemNode(const AnfNodePtr &node, const AbstractBasePtr &seq_abs, const AbstractBasePtr &item_abs,
1876                          size_t index) {
1877     std::vector<AnfNodePtr> item_inputs;
1878     if (seq_abs->isa<abstract::AbstractTuple>()) {
1879       (void)item_inputs.emplace_back(NewValueNode(prim::kPrimTupleGetItem));
1880     } else if (seq_abs->isa<abstract::AbstractList>()) {
1881       (void)item_inputs.emplace_back(NewValueNode(prim::kPrimListGetItem));
1882     }
1883     (void)item_inputs.emplace_back(node);
1884     (void)item_inputs.emplace_back(NewValueNode(SizeToLong(index)));
1885     auto new_item = func_graph_->NewCNode(std::move(item_inputs));
1886     new_item->set_abstract(item_abs);
1887     if (item_abs->isa<abstract::AbstractRefTensor>()) {
1888       // Current u monad.
1889       auto current_u = GetUniverse();
1890       // Make a Load for item node.
1891       new_item = MakeLoad(node, new_item, current_u);
1892     }
1893     return new_item;
1894   }
1895 
1896   // params = (param1, param2, ..., value)
1897   // addn(params, xxx)  non-effect-node need insert load for params.
InsertLoadForSequenceRef(const CNodePtr & cnode,bool update_state)1898   void InsertLoadForSequenceRef(const CNodePtr &cnode, bool update_state) {
1899     abstract::AbstractBasePtrList new_seq_abstracts;
1900     for (size_t index = 1; index < cnode->size(); ++index) {
1901       const auto &input = cnode->input(index);
1902       const auto &input_abs = input->abstract();
1903       MS_EXCEPTION_IF_NULL(input_abs);
1904       if (!input_abs->isa<abstract::AbstractTuple>() && !input_abs->isa<abstract::AbstractList>()) {
1905         (void)new_seq_abstracts.emplace_back(input_abs);
1906         continue;
1907       }
1908       // Handle the input which is sequence.
1909       std::vector<AnfNodePtr> new_sequence_inputs;
1910       if (input_abs->isa<abstract::AbstractTuple>()) {
1911         (void)new_sequence_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
1912       } else if (input_abs->isa<abstract::AbstractList>()) {
1913         (void)new_sequence_inputs.emplace_back(NewValueNode(prim::kPrimMakeList));
1914       }
1915       auto seq_abs = input_abs->cast_ptr<abstract::AbstractSequence>();
1916       MS_EXCEPTION_IF_NULL(seq_abs);
1917       const auto &elements = seq_abs->elements();
1918       for (size_t item_index = 0; item_index < elements.size(); ++item_index) {
1919         const auto &item_abs = elements[item_index];
1920         auto item = NewItemNode(input, input_abs, item_abs, item_index);
1921         (void)new_sequence_inputs.emplace_back(item);
1922         (void)new_seq_abstracts.emplace_back(item->abstract());
1923       }
1924       auto new_seq = func_graph_->NewCNode(std::move(new_sequence_inputs));
1925       MS_LOG(DEBUG) << "Replace the input of non-effect-node:" << cnode->DebugString()
1926                     << " with:" << new_seq->DebugString();
1927       if (input_abs->isa<abstract::AbstractTuple>()) {
1928         new_seq->set_abstract(std::make_shared<abstract::AbstractTuple>(new_seq_abstracts));
1929       } else if (input_abs->isa<abstract::AbstractList>()) {
1930         new_seq->set_abstract(std::make_shared<abstract::AbstractList>(new_seq_abstracts));
1931       }
1932       manager_->SetEdge(cnode, SizeToInt(index), new_seq);
1933       if (update_state) {
1934         auto current_u = GetUniverse();
1935         // In the order_enforce phase, the cnode will be added to the updatestate to ensure the order,
1936         // and the input of the updatestate is maintained here to 2.
1937         // to ensure the verification of the updatestate in the relevant pass.
1938         u_ = UpdateState(current_u, new_seq);
1939       }
1940     }
1941   }
1942 
1943   //
1944   // Insert Loads for a primitive cnode that use Ref as input.
1945   // for example, from:
1946   //    out = Prim(self.para1, self.para2, other_args)
1947   // to:
1948   //    p1 = Load(self.para1, u)
1949   //    p2 = Load(self.para2, u)
1950   //    t = make_tuple(p1, p2) # if update_state
1951   //    u1 = UpdateState(u, t)   # is required
1952   //    out = Prim(p1, p2, other_args)
1953   //
InsertLoads(const CNodePtr & cnode,bool update_state)1954   void InsertLoads(const CNodePtr &cnode, bool update_state) {
1955     // Find ref inputs.
1956     auto ref_inputs = GetRefInputs(cnode);
1957     if (ref_inputs.empty()) {
1958       MS_LOG(WARNING) << "Ref input not found for load insertion: " << cnode->DebugString();
1959       return;
1960     }
1961     // Current u monad.
1962     auto current_u = GetUniverse();
1963     // Create Load cnodes.
1964     auto loads = MakeLoads(cnode, ref_inputs, current_u);
1965     if (loads.empty() || !update_state) {
1966       // Skip UpdateState insertion.
1967       return;
1968     }
1969     // Insert UpdateState if required.
1970     if (loads.size() == 1) {
1971       // One Load, no make_tuple needed.
1972       u_ = UpdateState(current_u, loads.front());
1973       return;
1974     }
1975     // Multiple Loads, Create a MakeTuple before UpdateState.
1976     abstract::AbstractBasePtrList load_abstracts;
1977     (void)std::transform(loads.begin(), loads.end(), std::back_inserter(load_abstracts),
1978                          [](const AnfNodePtr &load) { return load->abstract(); });
1979     (void)loads.insert(loads.begin(), NewValueNode(prim::kPrimMakeTuple));
1980     auto make_tuple = func_graph_->NewCNode(loads);
1981     make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(load_abstracts));
1982     u_ = UpdateState(current_u, make_tuple);
1983   }
1984 
MakeLoads(const CNodePtr & cnode,const RefInputs & ref_inputs,const AnfNodePtr & u)1985   std::vector<AnfNodePtr> MakeLoads(const CNodePtr &cnode, const RefInputs &ref_inputs, const AnfNodePtr &u) {
1986     std::vector<AnfNodePtr> loads;
1987     for (auto &ref_input : ref_inputs) {
1988       // Make a Load cnode for ref input.
1989       auto &ref = ref_input.first;
1990       auto load = MakeLoad(cnode, ref, u);
1991       // Replace input with the load cnode.
1992       for (size_t index : ref_input.second) {
1993         manager_->SetEdge(cnode, SizeToInt(index), load);
1994       }
1995       (void)loads.emplace_back(std::move(load));
1996     }
1997     return loads;
1998   }
1999 
MakeLoad(const AnfNodePtr & node,const AnfNodePtr & ref,const AnfNodePtr & u)2000   CNodePtr MakeLoad(const AnfNodePtr &node, const AnfNodePtr &ref, const AnfNodePtr &u) {
2001     static const std::string primitive_target = "primitive_target";
2002     // Create Load cnode.
2003     auto load_prim = NewValueNode(prim::kPrimLoad);
2004     auto load_cnode = func_graph_->NewCNode({load_prim, ref, u});
2005     // Set device target for Load CNode.
2006     std::string target = GetCNodeTarget(node);
2007     load_cnode->set_user_data(primitive_target, std::make_shared<std::string>(target));
2008     // Set load_cnode abstract to Tensor according the input Ref[Tensor].
2009     auto ref_abs = dyn_cast<abstract::AbstractRefTensor>(ref->abstract());
2010     MS_EXCEPTION_IF_NULL(ref_abs);
2011     load_cnode->set_abstract(ref_abs->CloneAsTensor());
2012     return load_cnode;
2013   }
2014 
2015   // Add or replace monad input.
AddMonadInput(const CNodePtr & cnode,const AnfNodePtr & monad)2016   void AddMonadInput(const CNodePtr &cnode, const AnfNodePtr &monad) {
2017     MS_EXCEPTION_IF_NULL(cnode);
2018     constexpr size_t max_monad_inputs = 2;
2019     auto monad_abs = monad->abstract();
2020     int last = static_cast<int>(cnode->size()) - 1;
2021     int stop = last - max_monad_inputs;
2022     // Search monad in inputs, replace it if found.
2023     for (int i = last; i > 0 && i > stop; --i) {
2024       size_t index = static_cast<size_t>(i);
2025       auto input_abs = cnode->input(index)->abstract();
2026       if (input_abs && *input_abs == *monad_abs) {
2027         manager_->SetEdge(cnode, i, monad);
2028         return;
2029       }
2030     }
2031     // If monad not found in inputs, add a monad input.
2032     manager_->AddEdge(cnode, monad);
2033   }
2034 
AttachMonadToOutput() const2035   void AttachMonadToOutput() const {
2036     if (u_) {
2037       AttachToOutput(u_);
2038     }
2039     if (io_) {
2040       AttachToOutput(io_);
2041     }
2042   }
2043 
AttachToOutput(const AnfNodePtr & node) const2044   void AttachToOutput(const AnfNodePtr &node) const {
2045     auto output = GetGraphOutput();
2046     TraceGuard guard(std::make_shared<TraceCopy>(output->debug_info()));
2047     auto depend = NewValueNode(prim::kPrimDepend);
2048     // If isolated nodes dependencies exist.
2049     if (IsPrimitiveCNode(output, prim::kPrimDepend) &&
2050         IsPrimitiveCNode(output->cast<CNodePtr>()->input(kDependAttachNodeIndex), prim::kPrimStopGradient)) {
2051       // Insert new Depend node before isolated Depend node.
2052       auto isolated_depend = output->cast<CNodePtr>();
2053       auto &orig_output = isolated_depend->input(1);
2054       auto state_depend = func_graph_->NewCNode({depend, orig_output, node});
2055       state_depend->set_abstract(orig_output->abstract());
2056       manager_->SetEdge(isolated_depend, 1, state_depend);
2057       return;
2058     }
2059     // Insert Depend node and set it as output, if no isolated nodes.
2060     auto depend_cnode = func_graph_->NewCNode({depend, output, node});
2061     depend_cnode->set_abstract(output->abstract());
2062     func_graph_->set_output(depend_cnode);
2063   }
2064 
GetGraphOutput() const2065   AnfNodePtr GetGraphOutput() const {
2066     auto output = func_graph_->output();
2067     if (output != nullptr) {
2068       return output;
2069     }
2070     return NewValueNode(kNone);
2071   }
2072 
// Create an UpdateState(state, attach) cnode in this func graph and return it.
//   state:  the current state node; its abstract is copied to the new node.
//   attach: the node to attach; must be a non-null CNode, otherwise an exception is raised.
// Returns 'state' unchanged (no UpdateState is created) when 'attach' carries a
// boolean kAttrIgnoreSideEffect attribute whose value is true.
AnfNodePtr UpdateState(const AnfNodePtr &state, const AnfNodePtr &attach) {
  MS_EXCEPTION_IF_NULL(attach);
  auto attach_cnode = attach->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(attach_cnode);
  // Not attach UpdateState if set kAttrIgnoreSideEffect.
  auto attr_ignore_side_effect = attach_cnode->GetAttr(kAttrIgnoreSideEffect);
  auto ignore_side_effect = attr_ignore_side_effect != nullptr && attr_ignore_side_effect->isa<BoolImm>() &&
                            GetValue<bool>(attr_ignore_side_effect);
  if (ignore_side_effect) {
    return state;
  }

  // 'state' is dereferenced below; raise an exception rather than dereference a null pointer.
  // The check is placed after the early return so that path behaves exactly as before.
  MS_EXCEPTION_IF_NULL(state);
  auto update_state = NewValueNode(prim::kPrimUpdateState);
  auto update_state_cnode = func_graph_->NewCNode({update_state, state, attach});
  update_state_cnode->set_abstract(state->abstract());
  return update_state_cnode;
}
2090 
GetUniverse()2091   AnfNodePtr &GetUniverse() {
2092     if (u_ == nullptr) {
2093       if (top_) {
2094         u_ = NewValueNode(kUMonad);
2095         u_->set_abstract(kUMonad->ToAbstract());
2096       } else {
2097         u_ = AddMonadParameter(func_graph_, "u", kUMonad->ToAbstract());
2098       }
2099     }
2100     return u_;
2101   }
2102 
GetIoState()2103   AnfNodePtr &GetIoState() {
2104     if (io_ == nullptr) {
2105       if (top_) {
2106         io_ = NewValueNode(kIOMonad);
2107         io_->set_abstract(kIOMonad->ToAbstract());
2108       } else {
2109         io_ = AddMonadParameter(func_graph_, "io", kIOMonad->ToAbstract());
2110       }
2111     }
2112     return io_;
2113   }
2114 
2115   // Return true if update_state should be used in this func graph.
2116   // In some case, update_state can be omitted, such as:
2117   //   def side_effect_tail_call(args):
2118   //       a = pure_func(args)
2119   //       return side_effect_call(a)
NeedUpdateState() const2120   bool NeedUpdateState() const {
2121     // Search for the only one side effect cnode.
2122     CNodePtr side_effect_cnode = nullptr;
2123     for (auto &weak_cnode : func_graph_->order_list()) {
2124       const auto &cnode = weak_cnode.lock();
2125       if (cnode == nullptr) {
2126         continue;
2127       }
2128       if (HasSideEffect(cnode)) {
2129         if (side_effect_cnode != nullptr) {
2130           // There are multiple side effect cnodes, update state is required.
2131           return true;
2132         }
2133         side_effect_cnode = cnode;
2134       }
2135     }
2136     if (side_effect_cnode == nullptr) {
2137       // No side effect cnode, no update state.
2138       return false;
2139     }
2140     if (IsPrimitiveCNode(side_effect_cnode)) {
2141       // Always add update_state for primitive cnode.
2142       return true;
2143     }
2144     // If the only side effect cnode is not the tail call, update_state is required.
2145     return func_graph_->output() != side_effect_cnode;
2146   }
2147 
HasSideEffect(const CNodePtr & cnode) const2148   bool HasSideEffect(const CNodePtr &cnode) const {
2149     const auto &cnode_info = GetEffectInfoFromCNode(cnode);
2150     return (cnode_info.memory || cnode_info.load || cnode_info.io);
2151   }
2152 
2153   // The func graph to be converted.
2154   const FuncGraphPtr &func_graph_;
2155 
2156   // The func graph manager, used for graph edge update.
2157   FuncGraphManagerPtr manager_;
2158 
2159   // True if converting top graph.
2160   const bool top_;
2161 
2162   // True if there are side effect cnodes within this func graph.
2163   bool has_effect_cnodes_ = false;
2164 
2165   // CNodes that should not be eliminated even if they are isolated nodes.
2166   std::vector<CNodePtr> no_eliminate_nodes_;
2167 
2168   // Current memory state node, null if no memory side effects.
2169   AnfNodePtr u_;
2170 
2171   // Current IO state node, null if no IO side effects.
2172   AnfNodePtr io_;
2173 };  // class AutoMonadConverter
2174 }  // namespace
2175 
2176 // Entry point of the auto-monad phase,
2177 // the func_graph should be resolved and infer is done.
2178 // return true if side effect nodes found in func_graph.
AutoMonad(const FuncGraphPtr & func_graph)2179 bool AutoMonad(const FuncGraphPtr &func_graph) {
2180   MS_EXCEPTION_IF_NULL(func_graph);
2181   MS_EXCEPTION_IF_NULL(func_graph->manager());
2182 
2183   // Search and mark side effects for the graph and sub-graphs.
2184   // this should be called before auto-monad starts.
2185   SideEffectFinder::Search(func_graph);
2186 
2187   // Execute auto-monad conversion on top graph.
2188   bool has_effects = AutoMonadConverter::Handle(func_graph, true);
2189   // Convert used sub-graphs.
2190   auto fg_used_total = func_graph->func_graphs_used_total();
2191   for (auto &fg : fg_used_total) {
2192     MS_EXCEPTION_IF_NULL(fg);
2193     auto top_flag = fg->has_flag(mindspore::kFuncGraphFlagBackPropEntry);
2194     bool fg_has_effects = AutoMonadConverter::Handle(fg, top_flag);
2195     has_effects = has_effects || fg_has_effects;
2196   }
2197   return has_effects;
2198 }
2199 
ReAutoMonad(const FuncGraphPtr & func_graph)2200 bool ReAutoMonad(const FuncGraphPtr &func_graph) {
2201   // AutoMonad for bprop network, only Monad for func graphs which back propogators have side effects.
2202   // Or AutoMonad for MultitypeFuncGraph which specialized in Renormalize other than the first Specialize pass.
2203   MS_EXCEPTION_IF_NULL(func_graph);
2204   bool need_auto_monad = false;
2205   std::vector<FuncGraphPtr> auto_monaded_fg;
2206   func_graph->EraseUnusedNodeInOrder();
2207   for (auto &fg : func_graph->func_graphs_used_total()) {
2208     MS_EXCEPTION_IF_NULL(fg);
2209     if (fg->has_flag(mindspore::kFuncGraphFlagReAutoMonad)) {
2210       auto_monaded_fg.push_back(fg);
2211       for (auto &used_fg : fg->func_graphs_used_total()) {
2212         MS_EXCEPTION_IF_NULL(used_fg);
2213         used_fg->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
2214         auto_monaded_fg.push_back(used_fg);
2215       }
2216       need_auto_monad = true;
2217       MS_LOG(DEBUG) << "AutoMonad Grad for func graph: " << fg->ToString();
2218     }
2219     fg->EraseUnusedNodeInOrder();
2220   }
2221   bool changed = false;
2222   if (need_auto_monad) {
2223     for (auto &fg : func_graph->func_graphs_used_total()) {
2224       MS_EXCEPTION_IF_NULL(fg);
2225       if (!fg->has_flag(mindspore::kFuncGraphFlagReAutoMonad)) {
2226         fg->ClearOrderList();
2227       }
2228     }
2229     changed = AutoMonad(func_graph);
2230     for (auto &fg : auto_monaded_fg) {
2231       MS_EXCEPTION_IF_NULL(fg);
2232       fg->erase_flag(mindspore::kFuncGraphFlagReAutoMonad);
2233     }
2234     // After auto monad, Order List and Isolate nodes in graph and manager will be cleared.
2235   } else {
2236     func_graph->ClearOrderList();
2237     for (auto &fg : func_graph->func_graphs_used_total()) {
2238       fg->ClearOrderList();
2239     }
2240   }
2241   return changed;
2242 }
2243 }  // namespace pipeline
2244 }  // namespace mindspore
2245