1 /**
2 * Copyright 2020-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "pipeline/jit/ps/static_analysis/auto_monad.h"
18 #include <list>
19 #include <vector>
20 #include <stack>
21 #include <string>
22 #include <utility>
23 #include <memory>
24 #include <algorithm>
25 #include "mindspore/core/ops/structure_ops.h"
26 #include "mindspore/core/ops/sparse_ops.h"
27 #include "mindspore/core/ops/sequence_ops.h"
28 #include "mindspore/core/ops/nn_ops.h"
29 #include "mindspore/core/ops/array_ops.h"
30 #include "mindspore/core/ops/framework_ops.h"
31 #include "ir/anf.h"
32 #include "pipeline/jit/ps/parse/resolve.h"
33 #include "frontend/operator/ops.h"
34 #include "frontend/operator/composite/multitype_funcgraph.h"
35 #include "utils/flags.h"
36 #include "include/common/utils/utils.h"
37 #include "include/common/utils/anfalgo.h"
38 #include "utils/hash_map.h"
39 #include "utils/hash_set.h"
40 #include "utils/log_adapter.h"
41 #include "utils/ordered_map.h"
42 #include "utils/ordered_set.h"
43 #include "base/effect_info.h"
44 #include "abstract/abstract_value.h"
45 #include "pipeline/jit/ps/debug/trace.h"
46
47 namespace mindspore {
48 namespace pipeline {
49 namespace { // namespace anonymous
50 using ClassTypePtr = std::shared_ptr<parse::ClassType>;
51 using RefInputs = OrderedMap<AnfNodePtr, std::vector<size_t>>;
52
53 // Add or get a monad parameter.
AddMonadParameter(const FuncGraphPtr & func_graph,const std::string & name,const abstract::AbstractBasePtr & abs)54 AnfNodePtr AddMonadParameter(const FuncGraphPtr &func_graph, const std::string &name,
55 const abstract::AbstractBasePtr &abs) {
56 MS_EXCEPTION_IF_NULL(func_graph);
57 MS_EXCEPTION_IF_NULL(abs);
58 size_t params_size = func_graph->parameters().size();
59 size_t io_monad_location = params_size;
60 // Search for existed parameters, return it if found.
61 for (size_t i = 0; i < params_size; i++) {
62 auto &node = func_graph->parameters()[i];
63 auto para = dyn_cast<Parameter>(node);
64 if (para == nullptr) {
65 continue;
66 }
67 auto para_abs = para->abstract();
68 if (para_abs && *para_abs == *abs) {
69 return para;
70 }
71 if (HasAbstractIOMonad(para)) {
72 io_monad_location = i;
73 }
74 }
75 // Create a new parameter if not existed.
76 auto para = std::make_shared<Parameter>(func_graph);
77 para->set_name(name);
78 MS_EXCEPTION_IF_NULL(para->debug_info());
79 para->debug_info()->set_name(name);
80 para->set_abstract(abs);
81 // If io monad parameter added before u monad parameter, should insert u monad before io monad in parameters
82 if (io_monad_location != params_size && abs->isa<abstract::AbstractUMonad>()) {
83 std::vector<AnfNodePtr> params = func_graph->parameters();
84 (void)params.insert(params.begin() + SizeToInt(io_monad_location), para);
85 func_graph->set_parameters(params);
86 } else {
87 func_graph->add_parameter(para);
88 }
89 return para;
90 }
91
92 // Gets side effect propagate attribute value from a ClassType object.
GetSideEffectPropagate(const ClassTypePtr & class_type)93 int GetSideEffectPropagate(const ClassTypePtr &class_type) {
94 if (class_type) {
95 auto obj = class_type->obj();
96 if (py::hasattr(obj, GRAPH_FLAG_SIDE_EFFECT_PROPAGATE)) {
97 auto value = py::getattr(obj, GRAPH_FLAG_SIDE_EFFECT_PROPAGATE);
98 return value.cast<int>();
99 }
100 }
101 return 0;
102 }
103
104 // Gets 'side_effect_propagate' attribute value from a primitive.
GetSideEffectPropagate(const PrimitivePtr & prim)105 int GetSideEffectPropagate(const PrimitivePtr &prim) {
106 if (prim) {
107 auto attr = prim->GetAttr(GRAPH_FLAG_SIDE_EFFECT_PROPAGATE);
108 if (attr && attr->isa<Int64Imm>()) {
109 return static_cast<int>(attr->cast<Int64ImmPtr>()->value());
110 }
111 }
112 return 0;
113 }
114
115 // Gets ref inputs and its indexes from a cnode.
GetRefInputs(const CNodePtr & cnode)116 RefInputs GetRefInputs(const CNodePtr &cnode) {
117 RefInputs ref_inputs;
118 MS_EXCEPTION_IF_NULL(cnode);
119 for (size_t i = 1; i < cnode->size(); ++i) {
120 auto &input = cnode->input(i);
121 if (common::AnfAlgo::HasAbstractRef(input)) {
122 ref_inputs[input].push_back(i);
123 }
124 }
125 return ref_inputs;
126 }
127
128 // Return true if cnode has ref input.
HasRefInput(const CNodePtr & cnode)129 bool HasRefInput(const CNodePtr &cnode) {
130 if (cnode == nullptr || cnode->empty()) {
131 return false;
132 }
133 // Return true if any of arguments is ref.
134 return std::any_of(cnode->weak_inputs().begin() + 1, cnode->weak_inputs().end(), [](const auto &weak_input) {
135 const auto &input = weak_input.lock();
136 MS_EXCEPTION_IF_NULL(input);
137 return common::AnfAlgo::HasAbstractRef(input);
138 });
139 }
140
141 // Return true if cnode has tuple(ref) or list(ref).
HasRefSequenceInput(const CNodePtr & cnode)142 bool HasRefSequenceInput(const CNodePtr &cnode) {
143 if (cnode == nullptr || cnode->empty()) {
144 return false;
145 }
146 for (size_t index = 1; index < cnode->size(); ++index) {
147 const auto &input = cnode->input(index);
148 MS_EXCEPTION_IF_NULL(input);
149 if (common::AnfAlgo::SequenceHasAbstractRef(input)) {
150 return true;
151 }
152 }
153 return false;
154 }
155
156 // Return true if we don't need Load for the given primitive.
157 // i.e. keep Ref as Ref for some primitives.
IsKeepRef(const PrimitivePtr & prim)158 bool IsKeepRef(const PrimitivePtr &prim) {
159 return (GetSideEffectPropagate(prim) != 0) || IsPrimitiveEquals(prim, prim::kPrimRefToEmbed) ||
160 IsPrimitiveEquals(prim, prim::kPrimPull) || IsPrimitiveEquals(prim, prim::kPrimMakeTuple) ||
161 IsPrimitiveEquals(prim, prim::kPrimMakeList);
162 }
163
164 // Gets func_graph from the given cnode, return nullptr if it is not a func graph call.
GetFuncGraph(const CNodePtr & cnode)165 FuncGraphPtr GetFuncGraph(const CNodePtr &cnode) {
166 if (cnode != nullptr && !cnode->empty()) {
167 return GetValueNode<FuncGraphPtr>(cnode->input(0));
168 }
169 return nullptr;
170 }
171
172 // Gets first input as cnode from the given cnode,
173 // return null if input[0] is not a cnode.
GetFuncCNode(const CNodePtr & cnode)174 CNodePtr GetFuncCNode(const CNodePtr &cnode) {
175 if (cnode != nullptr && !cnode->empty()) {
176 return dyn_cast<CNode>(cnode->input(0));
177 }
178 return nullptr;
179 }
180
181 // Gets first input as function parameter from the given cnode,
182 // return null if input[0] is not a parameter.
GetFuncParameter(const CNodePtr & cnode)183 ParameterPtr GetFuncParameter(const CNodePtr &cnode) {
184 if (cnode != nullptr && !cnode->empty()) {
185 return dyn_cast<Parameter>(cnode->input(0));
186 }
187 return nullptr;
188 }
189
GetFuncGraphFromPartialAbstract(const abstract::AbstractBasePtr & abs)190 FuncGraphPtr GetFuncGraphFromPartialAbstract(const abstract::AbstractBasePtr &abs) {
191 if (abs == nullptr || !abs->isa<abstract::PartialAbstractClosure>()) {
192 return nullptr;
193 }
194
195 auto partial_closure = dyn_cast<abstract::PartialAbstractClosure>(abs);
196 MS_EXCEPTION_IF_NULL(partial_closure);
197 if (partial_closure->fn() == nullptr) {
198 MS_LOG(ERROR) << "Partial closure's func graph is null, " << abs->ToString();
199 return nullptr;
200 }
201 auto func_graph_abstract = dyn_cast<abstract::FuncGraphAbstractClosure>(partial_closure->fn());
202 if (func_graph_abstract != nullptr) {
203 MS_EXCEPTION_IF_NULL(func_graph_abstract);
204 if (!func_graph_abstract->specialized()) {
205 MS_LOG(DEBUG) << "Unspecialized func graph, partial abs: " << abs->ToString()
206 << ", partial fn abs: " << func_graph_abstract->ToString();
207 return nullptr;
208 }
209 return func_graph_abstract->func_graph();
210 }
211
212 // Nested Partial.
213 return GetFuncGraphFromPartialAbstract(partial_closure->fn());
214 }
215
GetFuncGraphFromFuncGraphAbstract(const abstract::AbstractBasePtr & abs)216 FuncGraphPtr GetFuncGraphFromFuncGraphAbstract(const abstract::AbstractBasePtr &abs) {
217 auto func_closure = dyn_cast<abstract::FuncGraphAbstractClosure>(abs);
218 if (func_closure == nullptr) {
219 return nullptr;
220 }
221 if (func_closure->func_graph() == nullptr) {
222 MS_LOG(DEBUG) << "FuncGraph closure's func graph is null, " << abs->ToString();
223 return nullptr;
224 }
225 return func_closure->func_graph();
226 }
227
228 // Gets first input as MultitypeFuncGraph from the given cnode,
229 // return null if input[0] is not a MultitypeFuncGraph.
GetFuncMultitypeFuncGraph(const CNodePtr & cnode)230 prim::MultitypeFuncGraphPtr GetFuncMultitypeFuncGraph(const CNodePtr &cnode) {
231 if (cnode != nullptr && !cnode->empty()) {
232 return GetValueNode<prim::MultitypeFuncGraphPtr>(cnode->input(0));
233 }
234 return nullptr;
235 }
236
237 // The cnode is non-effect-node, and the cnode is real node, and the inputs of cnode is dynamic.
IsNonEffectRealNodeAndInputIsDynamic(const CNodePtr & cnode)238 bool IsNonEffectRealNodeAndInputIsDynamic(const CNodePtr &cnode) {
239 MS_EXCEPTION_IF_NULL(cnode);
240 static const PrimitiveSet dynamic_input_node_prims = {
241 prim::kPrimStack, prim::kPrimConcat, prim::kPrimAddN, prim::kPrimIdentityN,
242 prim::kPrimSparseConcat, prim::kPrimMeshgrid, prim::kPrimDynamicStitch, prim::kPrimPyExecute,
243 prim::kPrimPyInterpret, prim::kPrimMakeDict};
244 PrimitivePtr prim = cnode->empty() ? nullptr : GetValueNode<PrimitivePtr>(cnode->input(0));
245 if (prim == nullptr) {
246 return false;
247 }
248 return dynamic_input_node_prims.find(prim) != dynamic_input_node_prims.end();
249 }
250
251 // --------------------------------------------------------------------
252 // SCC (Strongly Connected Components) related types.
253 // --------------------------------------------------------------------
254 using SccVector = mindspore::HashSet<FuncGraphPtr>;
255 using SccPtr = std::shared_ptr<SccVector>;
256 using SccMap = mindspore::HashMap<FuncGraphPtr, SccPtr>;
257
258 // ---------------------------------------------------------------------
259 // SccFinder find SCCs using Tarjan's algorithm.
260 // ---------------------------------------------------------------------
261 class SccFinder {
262 public:
SccFinder(const FuncGraphPtr & root)263 explicit SccFinder(const FuncGraphPtr &root) : root_(root) {}
264 ~SccFinder() = default;
Run()265 void Run() { Search(root_); }
scc_map()266 SccMap scc_map() { return std::move(scc_map_); }
267
268 private:
269 // Store each layer of visit stack.
270 struct SccVisitInfo {
271 FuncGraphPtr graph{nullptr};
272 const FuncGraphCounterMap *func_graphs_used_ptr{nullptr};
273 FuncGraphCounterMap::const_iterator visit_iter;
274 };
275
276 // Tarjan algorithm. Search SCCs from the given graph.
277 // Iterative implementation.
Search(const FuncGraphPtr & graph)278 void Search(const FuncGraphPtr &graph) {
279 MS_EXCEPTION_IF_NULL(graph);
280 std::stack<SccVisitInfo> visit_stack;
281 auto seen = NewFgSeenGeneration();
282 // Push the origin graph.
283 SccVisitInfo info;
284 info.graph = graph;
285 info.graph->seen_ = seen; // If visited.
286 info.graph->extra_seen_ = seen; // If in stack.
287 auto index = 1;
288 info.graph->set_user_data<size_t>("index", std::make_shared<size_t>(index));
289 info.graph->set_user_data<size_t>("low", std::make_shared<size_t>(index));
290 stack_.push(graph);
291 visit_stack.push(std::move(info));
292 while (!visit_stack.empty()) {
293 auto ¤t_info = visit_stack.top();
294 if (current_info.func_graphs_used_ptr == nullptr) {
295 current_info.func_graphs_used_ptr = ¤t_info.graph->func_graphs_used();
296 current_info.visit_iter = current_info.func_graphs_used_ptr->cbegin();
297 }
298 // If there's not visited used func graph, continue visiting the left used.
299 if (current_info.visit_iter != current_info.func_graphs_used_ptr->cend()) {
300 auto used_graph = current_info.visit_iter->first;
301 ++current_info.visit_iter;
302 if (used_graph->seen_ != seen) {
303 // First visit, push it.
304 MS_LOG(DEBUG) << "Push graph: " << used_graph->ToString();
305 stack_.push(used_graph);
306 SccVisitInfo used_info;
307 ++index;
308 used_info.graph = used_graph;
309 used_info.graph->set_user_data<size_t>("index", std::make_shared<size_t>(index));
310 used_info.graph->set_user_data<size_t>("low", std::make_shared<size_t>(index));
311 used_info.graph->seen_ = seen; // If visited.
312 used_info.graph->extra_seen_ = seen; // If in stack.
313 visit_stack.push(std::move(used_info));
314 } else if (used_graph->extra_seen_ == seen) {
315 // Visited before AND in stack, update low.
316 auto min_low = std::min(*current_info.graph->user_data<size_t>("low"), *used_graph->user_data<size_t>("low"));
317 current_info.graph->set_user_data<size_t>("low", std::make_shared<size_t>(min_low));
318 MS_LOG(DEBUG) << "Update low [" << min_low << "] for " << current_info.graph->ToString() << " by "
319 << used_graph->ToString();
320 }
321 continue;
322 }
323 // If all used func graphs are visited, pop it and check if it's SCC root.
324 auto current_graph = current_info.graph;
325 if (*current_graph->user_data<size_t>("low") != *current_graph->user_data<size_t>("index")) {
326 // Update low when pop.
327 visit_stack.pop();
328 auto &next_info = visit_stack.top();
329 auto min_low = std::min(*next_info.graph->user_data<size_t>("low"), *current_graph->user_data<size_t>("low"));
330 next_info.graph->set_user_data<size_t>("low", std::make_shared<size_t>(min_low));
331 MS_LOG(DEBUG) << "Update low [" << min_low << "] for " << next_info.graph->ToString() << " by "
332 << current_graph->ToString();
333 continue;
334 }
335 MS_LOG(DEBUG) << "Found SCC root: " << current_graph->ToString();
336 // Pop members of the SCC from stack, they are on top of its root.
337 auto scc = std::make_shared<SccVector>();
338 while (!stack_.empty()) {
339 auto g = stack_.top();
340 g->extra_seen_ = 0; // Not in stack any more.
341 stack_.pop();
342 // Add graph to SCC, and create the map from graph to SCC.
343 scc->insert(g);
344 (void)scc_map_.emplace(g, scc);
345 if (g == current_graph) {
346 break;
347 }
348 }
349 // SCC should not be empty.
350 if (scc->empty()) {
351 MS_LOG(INTERNAL_EXCEPTION) << "Invalid SCC for: " << graph->ToString();
352 }
353 visit_stack.pop();
354 }
355 }
356
357 // The root graph.
358 FuncGraphPtr root_;
359
360 // The stack for Tarjan algorithm.
361 std::stack<FuncGraphPtr> stack_;
362
363 // The result SCC map, from graph to its SCC.
364 SccMap scc_map_;
365 };
366
367 struct SwitchLayerCall {
368 CNodePtr caller;
369 EffectInfo effect_info;
370 std::vector<FuncGraphPtr> branches;
371 };
372
373 class NodeStackGuard {
374 public:
NodeStackGuard(OrderedSet<AnfNodePtr> * stack,const AnfNodePtr & node)375 NodeStackGuard(OrderedSet<AnfNodePtr> *stack, const AnfNodePtr &node) : stack_(stack) { stack_->push_front(node); }
~NodeStackGuard()376 ~NodeStackGuard() {
377 try {
378 (void)stack_->pop();
379 } catch (const std::exception &e) {
380 MS_LOG(ERROR) << "Exception when pop. Error info " << e.what();
381 }
382
383 stack_ = nullptr;
384 }
385
386 private:
387 OrderedSet<AnfNodePtr> *stack_;
388 };
389
390 // -------------------------------------------------------------------------------
391 // SideEffectFinder search and mark side effects for graph and its sub-graphs.
392 // -------------------------------------------------------------------------------
393 class SideEffectFinder {
394 public:
Search(const FuncGraphPtr & root)395 static void Search(const FuncGraphPtr &root) {
396 SideEffectFinder finder(root);
397 finder.Run();
398 }
399
400 private:
SideEffectFinder(const FuncGraphPtr & root)401 explicit SideEffectFinder(const FuncGraphPtr &root) : root_(root) {}
402 ~SideEffectFinder() = default;
403
Run()404 void Run() {
405 // To handle recursive calls, we generate SCC map before search.
406 GenerateSccMap();
407 // Update order list to include outer cnodes.
408 UpdateOrderLists();
409 // Find side effects by DFS from the top graph.
410 ObtainEffectInfoForFuncGraphs(root_);
411 // Check Switch calls, add monad arguments if need.
412 HandleSwitchCalls();
413 // Check SwitchLayer calls, add monad arguments if need.
414 HandleSwitchLayerCalls();
415 // Check Partial CNode calls, add monad arguments if need.
416 HandlePartialCalls();
417 }
418
UpdateOrderLists() const419 void UpdateOrderLists() const {
420 // Some cnodes used in current func graph but belong to other func graph, we have to
421 // insert them into order list so that we can handle side effects for them.
422 UpdateOrderList(root_);
423 for (auto &fg : root_->func_graphs_used_total()) {
424 UpdateOrderList(fg);
425 }
426 }
427
UpdateOrderList(const FuncGraphPtr & func_graph)428 static void UpdateOrderList(const FuncGraphPtr &func_graph) {
429 MS_EXCEPTION_IF_NULL(func_graph);
430 std::list<CNodeWeakPtr> new_order_list;
431 const auto &order_list = func_graph->order_list();
432 for (auto &weak_cnode : order_list) {
433 const auto &cnode = weak_cnode.lock();
434 if (cnode != nullptr) {
435 PushToOrderList(func_graph, cnode, &new_order_list);
436 }
437 }
438 func_graph->set_order_list(std::move(new_order_list));
439 }
440
PushToOrderList(const FuncGraphPtr & fg,const CNodePtr & cnode,std::list<CNodeWeakPtr> * new_order_list)441 static void PushToOrderList(const FuncGraphPtr &fg, const CNodePtr &cnode, std::list<CNodeWeakPtr> *new_order_list) {
442 MS_EXCEPTION_IF_NULL(cnode);
443 MS_EXCEPTION_IF_NULL(new_order_list);
444 // If contains.
445 auto iter = std::find_if(new_order_list->cbegin(), new_order_list->cend(), [&cnode](const CNodeWeakPtr &node) {
446 return node.lock() != nullptr && node.lock() == cnode;
447 });
448 if (iter != new_order_list->cend()) {
449 return;
450 }
451
452 for (auto &weak_input : cnode->weak_inputs()) {
453 auto input = weak_input.lock();
454 MS_EXCEPTION_IF_NULL(input);
455 auto input_cnode = dyn_cast<CNode>(input);
456 if (input_cnode != nullptr && input_cnode->func_graph() != fg) {
457 PushToOrderList(fg, input_cnode, new_order_list);
458 }
459 }
460 new_order_list->emplace_back(CNodeWeakPtr(cnode));
461 }
462
463 // Generate SCC map by SccFinder.
GenerateSccMap()464 void GenerateSccMap() {
465 SccFinder scc_finder(root_);
466 scc_finder.Run();
467 scc_map_ = std::move(scc_finder.scc_map());
468 }
469
470 // Gets branch graph from a switch cnode at given input index.
GetSwitchBranch(const CNodePtr & cnode,size_t index) const471 FuncGraphPtr GetSwitchBranch(const CNodePtr &cnode, size_t index) const {
472 MS_EXCEPTION_IF_NULL(cnode);
473 const auto &branch_node = cnode->input(index);
474 AnfNodePtr branch_fg_node = branch_node;
475 if (IsPrimitiveCNode(branch_node, prim::kPrimPartial)) {
476 auto branch_abs = branch_node->abstract();
477 constexpr auto recursive_level = 2;
478 MS_LOG(DEBUG) << "branch_node: " << branch_node->DebugString(recursive_level)
479 << ", abstract: " << (branch_abs != nullptr ? branch_abs->ToString() : "null");
480 auto branch_cnode = branch_node->cast_ptr<CNode>();
481 MS_EXCEPTION_IF_NULL(branch_cnode);
482 branch_fg_node = branch_cnode->input(1);
483 MS_EXCEPTION_IF_NULL(branch_fg_node);
484 MS_LOG(DEBUG) << "branch_fg_node: " << branch_fg_node->DebugString(recursive_level);
485 }
486 return GetValueNode<FuncGraphPtr>(branch_fg_node);
487 }
488
489 // Gets branch graphs from a switch cnode.
GetSwitchBranches(const CNodePtr & cnode) const490 std::vector<FuncGraphPtr> GetSwitchBranches(const CNodePtr &cnode) const {
491 MS_EXCEPTION_IF_NULL(cnode);
492 constexpr size_t switch_cnode_size = 4;
493 constexpr size_t true_index = 2;
494 constexpr size_t false_index = 3;
495 // Check size.
496 if (cnode->size() != switch_cnode_size) {
497 MS_LOG(INTERNAL_EXCEPTION) << "Invalid switch: " << cnode->DebugString();
498 }
499 // Add both branches, in some case, only one branch is set.
500 std::vector<FuncGraphPtr> branches;
501 auto true_branch = GetSwitchBranch(cnode, true_index);
502 if (true_branch != nullptr) {
503 (void)branches.emplace_back(true_branch);
504 }
505 auto false_branch = GetSwitchBranch(cnode, false_index);
506 if (false_branch != nullptr) {
507 (void)branches.emplace_back(false_branch);
508 }
509 if (branches.empty()) {
510 constexpr auto recursive_level = 2;
511 MS_LOG(INTERNAL_EXCEPTION) << "Invalid switch: " << cnode->DebugString(recursive_level);
512 }
513 return branches;
514 }
515
516 // Add monad parameter to switch branch graphs.
AddMonadParameters(const std::vector<FuncGraphPtr> & branches,const std::string & name,const AbstractBasePtr & abs) const517 void AddMonadParameters(const std::vector<FuncGraphPtr> &branches, const std::string &name,
518 const AbstractBasePtr &abs) const {
519 for (auto &branch : branches) {
520 (void)AddMonadParameter(branch, name, abs);
521 }
522 }
523
524 // Trace effect info for Partial call node.
TracePartialCallEffectInfo(const CNodePtr & cnode,const EffectInfo & old_info)525 EffectInfo TracePartialCallEffectInfo(const CNodePtr &cnode, const EffectInfo &old_info) {
526 const AnfNodePtr &func_node = cnode->input(0);
527 MS_EXCEPTION_IF_NULL(func_node);
528 // Only handle for Parameter or Non-Partial CNode.
529 if (!func_node->isa<Parameter>() && (!func_node->isa<CNode>() || IsPrimitiveCNode(func_node, prim::kPrimPartial))) {
530 return old_info;
531 }
532 auto partial_real_func = GetFuncGraphFromPartialAbstract(func_node->abstract());
533 if (partial_real_func == nullptr) {
534 return old_info;
535 }
536
537 // Not retry checking, if has already confirmed the Partial func graph has side effect, or still detect ongoing.
538 if (old_info.state != EffectInfo::kDetected || old_info.memory || old_info.io || old_info.load ||
539 old_info.back_mem) {
540 return old_info;
541 }
542
543 // Record the Partial callers and real func graph.
544 (void)partial_cnode_calls_.emplace(cnode, partial_real_func);
545
546 // Try to obtain the effect info of func graph.
547 auto effect_info = ObtainEffectInfoForFuncGraph(partial_real_func);
548 MS_EXCEPTION_IF_NULL(func_node->abstract());
549 MS_LOG(DEBUG) << "CNode or Parameter func: " << func_node->DebugString()
550 << ", partial_real_func: " << partial_real_func->ToString() << ", "
551 << func_node->abstract()->ToString() << ", cnode: " << cnode->DebugString()
552 << ", effect_info: " << effect_info.memory << "/" << effect_info.io << "/" << effect_info.load;
553 return effect_info;
554 }
555
556 // Trace effect info for Switch cnode.
TraceSwitchEffectInfo(const CNodePtr & cnode)557 EffectInfo TraceSwitchEffectInfo(const CNodePtr &cnode) {
558 // Find branches from switch cnode.
559 auto branches = GetSwitchBranches(cnode);
560 // Save branch caller, so that we can update arguments for the caller.
561 SaveBranchCaller(cnode, branches);
562 // For some case, only one branch is set.
563 if (branches.size() == 1) {
564 auto &branch = branches.front();
565 return ObtainEffectInfoForFuncGraph(branch);
566 }
567 // When both branches are set, merge their effect infos.
568 EffectInfo info = MergeEffectInfo(branches);
569 if (info.state == EffectInfo::kDetected) {
570 // Setup both branches according the merged effect info.
571 SetupEffectBranches(info, branches);
572 }
573 return info;
574 }
575
576 // Trace effect info for SwitchLayer cnode.
TraceSwitchLayerEffectInfo(const CNodePtr & cnode)577 EffectInfo TraceSwitchLayerEffectInfo(const CNodePtr &cnode) {
578 // Find branches from switch_layer cnode.
579 auto branches = GetSwitchLayerBranches(cnode);
580 // Merge effect info from all branches.
581 EffectInfo info = MergeEffectInfo(branches);
582 if (info.state == EffectInfo::kDetected) {
583 // Setup branches according the merged effect info.
584 SetupEffectBranches(info, branches);
585 // Save the switch_layer call, so that we can add monad argument for it if need.
586 auto &call = switch_layer_calls_.emplace_back();
587 call.caller = caller_;
588 call.effect_info = info;
589 call.branches = move(branches);
590 }
591 return info;
592 }
593
HandlePartialCalls()594 void HandlePartialCalls() {
595 for (auto &call : partial_cnode_calls_) {
596 const auto &caller = call.first;
597 const auto &func_graph = call.second;
598 const auto &effect_info = ObtainEffectInfoForFuncGraph(func_graph);
599 MS_EXCEPTION_IF_NULL(caller->abstract());
600 MS_LOG(DEBUG) << "func_graph: " << func_graph->ToString() << ", caller: " << caller->DebugString() << ", "
601 << caller->abstract()->ToString() << ", effect_info: " << effect_info.memory << "/"
602 << effect_info.io << "/" << effect_info.load << "/" << effect_info.back_mem;
603 AddMonadForCaller(caller, effect_info);
604 // Setup monad parameters for func graph according the effect info.
605 if (effect_info.memory || effect_info.load) {
606 (void)AddMonadParameter(func_graph, "u", kUMonad->ToAbstract());
607 }
608 if (effect_info.io) {
609 (void)AddMonadParameter(func_graph, "io", kIOMonad->ToAbstract());
610 }
611 }
612 }
613
HandleSwitchCalls()614 void HandleSwitchCalls() {
615 for (auto &call : switch_calls_) {
616 const auto &caller = call.first;
617 const auto &branches = call.second;
618 CheckAndFixSwitchCall(caller, branches);
619 }
620 }
621
CheckAndFixSwitchCall(const CNodePtr & caller,const FuncGraphVector & branches) const622 void CheckAndFixSwitchCall(const CNodePtr &caller, const FuncGraphVector &branches) const {
623 MS_EXCEPTION_IF_NULL(caller);
624 const auto caller_input_size = caller->size() - 1;
625 for (size_t i = 0; i < branches.size(); ++i) {
626 const auto &branch = branches[i];
627 MS_EXCEPTION_IF_NULL(branch);
628
629 // Get partial branch input size.
630 size_t extra_input_size = 0;
631 const auto &switch_node = caller->input(0);
632 if (!IsPrimitiveCNode(switch_node, prim::kPrimSwitch)) {
633 MS_LOG(INTERNAL_EXCEPTION) << "Not switch CNode, " << switch_node->DebugString();
634 }
635 const auto &switch_cnode = dyn_cast<CNode>(switch_node);
636 constexpr auto ignore_switch_and_cond_count = 2;
637 const auto &branch_node = switch_cnode->input(i + ignore_switch_and_cond_count);
638 if (IsPrimitiveCNode(branch_node, prim::kPrimPartial)) {
639 const auto &branch_cnode = branch_node->cast_ptr<CNode>();
640 constexpr auto ignore_partial_and_fg_count = 2;
641 extra_input_size = branch_cnode->size() - ignore_partial_and_fg_count;
642 }
643
644 // Check inputs size.
645 if (caller_input_size + extra_input_size != branch->parameters().size()) {
646 // Fix branch if number of parameter mismatch.
647 FixSwitchBranch(caller, branch);
648 // The number of parameter should matched after fix.
649 if (caller_input_size + extra_input_size != branch->parameters().size()) {
650 constexpr auto recursive_count = 2;
651 MS_LOG(INTERNAL_EXCEPTION) << "Fix switch branch parameters failed! " << caller->DebugString(recursive_count)
652 << ", branch: " << branch->ToString()
653 << ", branch node: " << branch_node->DebugString(recursive_count)
654 << ", size: " << caller_input_size << " + " << extra_input_size << " not equal to "
655 << branch->parameters().size();
656 }
657 }
658 }
659 }
660
FixSwitchBranch(const CNodePtr & caller,const FuncGraphPtr & branch) const661 void FixSwitchBranch(const CNodePtr &caller, const FuncGraphPtr &branch) const {
662 MS_EXCEPTION_IF_NULL(branch);
663 for (size_t i = caller->size() - 1; i > 0; --i) {
664 auto &input = caller->input(i);
665 MS_EXCEPTION_IF_NULL(input);
666 if (HasAbstractUMonad(input)) {
667 (void)AddMonadParameter(branch, "u", input->abstract());
668 } else if (HasAbstractIOMonad(input)) {
669 (void)AddMonadParameter(branch, "io", input->abstract());
670 }
671 }
672 }
673
HandleSwitchLayerCalls()674 void HandleSwitchLayerCalls() {
675 for (auto &call : switch_layer_calls_) {
676 const auto &info = call.effect_info;
677 const auto &branches = call.branches;
678 auto new_info = MergeEffectInfo(branches);
679 // Reset branches if effect info changed.
680 if (new_info.memory != info.memory || new_info.load != info.load || new_info.io != info.io) {
681 AddMonadForCaller(call.caller, new_info);
682 SetupEffectBranches(new_info, branches);
683 }
684 }
685 }
686
687 // Gets branch graphs from a switch_layer cnode.
GetSwitchLayerBranches(const CNodePtr & cnode)688 std::vector<FuncGraphPtr> GetSwitchLayerBranches(const CNodePtr &cnode) {
689 MS_EXCEPTION_IF_NULL(cnode);
690 constexpr size_t func_tuple_index = 2;
691 constexpr int recursive_level = 2;
692 if (cnode->size() <= func_tuple_index) {
693 MS_LOG(INTERNAL_EXCEPTION) << "Invalid switch_layer: " << cnode->DebugString(recursive_level);
694 }
695 auto func_tuple = cnode->input(func_tuple_index);
696 return GetGraphsFromTuple(func_tuple);
697 }
698
GetGraphFromSwitchWithDeadNode(const CNodePtr & cnode) const699 FuncGraphPtr GetGraphFromSwitchWithDeadNode(const CNodePtr &cnode) const {
700 MS_EXCEPTION_IF_NULL(cnode);
701 auto input = cnode->input(0);
702 MS_EXCEPTION_IF_NULL(input);
703 if (!IsPrimitiveCNode(input, prim::kPrimSwitch)) {
704 return nullptr;
705 }
706 auto node = input->cast_ptr<CNode>();
707 if (node->size() < kSwitchInputSize) {
708 MS_LOG(EXCEPTION) << "Switch inputs size: " << node->size() << "less than " << kSwitchInputSize;
709 }
710 auto cond_node = node->input(kSwitchCondIndex);
711 auto cond_abs = cond_node->abstract();
712 MS_EXCEPTION_IF_NULL(cond_abs);
713 auto cond_abs_val = cond_abs->BuildValue();
714 MS_EXCEPTION_IF_NULL(cond_abs_val);
715 if (cond_abs_val->ContainsValueAny()) {
716 return nullptr;
717 }
718 auto cond_abs_bool_val = dyn_cast<BoolImm>(cond_abs_val);
719 MS_EXCEPTION_IF_NULL(cond_abs_bool_val);
720 auto branch =
721 cond_abs_bool_val->value() ? node->input(kSwitchTrueBranchIndex) : node->input(kSwitchFalseBranchIndex);
722 return GetValueNode<FuncGraphPtr>(branch);
723 }
724
725 // Get and trace graphs from a tuple of func node for switch_layer.
GetGraphsFromTuple(const AnfNodePtr & func_tuple)726 std::vector<FuncGraphPtr> GetGraphsFromTuple(const AnfNodePtr &func_tuple) {
727 // The functions make tuple CNode.
728 if (IsPrimitiveCNode(func_tuple, prim::kPrimMakeTuple)) {
729 return GetGraphsFromMakeTuple(func_tuple->cast<CNodePtr>());
730 }
731 // The functions value tuple.
732 if (IsValueNode<ValueTuple>(func_tuple)) {
733 return GetGraphsFromValueTuple(func_tuple->cast<ValueNodePtr>());
734 }
735 // Trace tuple from parameter.
736 auto para = dyn_cast<Parameter>(func_tuple);
737 if (para != nullptr) {
738 std::vector<FuncGraphPtr> graphs;
739 ForEachRealArguments(para,
740 [this, &graphs](const AnfNodePtr &arg) { graphs = std::move(GetGraphsFromTuple(arg)); });
741 return graphs;
742 }
743 // Trace tuple returned from func graph call.
744 auto cnode = dyn_cast<CNode>(func_tuple);
745 MS_EXCEPTION_IF_NULL(cnode);
746 auto func_graph = GetFuncGraph(cnode);
747 if (func_graph != nullptr) {
748 return GetGraphsFromTuple(func_graph->output());
749 }
750 // Trace tuple returned from func graph call including switch with dead node.
751 func_graph = GetGraphFromSwitchWithDeadNode(cnode);
752 if (func_graph != nullptr) {
753 return GetGraphsFromTuple(func_graph->output());
754 }
755 MS_LOG(INTERNAL_EXCEPTION) << "Invalid input for switch_layer: func_graph is nullptr.";
756 }
757
758 // Get graphs from a tuple of funcs make node for switch_layer.
GetGraphsFromMakeTuple(const CNodePtr & make_tuple) const759 std::vector<FuncGraphPtr> GetGraphsFromMakeTuple(const CNodePtr &make_tuple) const {
760 MS_EXCEPTION_IF_NULL(make_tuple);
761 constexpr int recursive_level = 2;
762 if (make_tuple->size() <= 1) {
763 MS_LOG(INTERNAL_EXCEPTION) << "Invalid make_tuple for switch_layer: " << make_tuple->DebugString(recursive_level);
764 }
765 std::vector<FuncGraphPtr> graphs;
766 graphs.reserve(make_tuple->size() - 1);
767 for (size_t i = 1; i < make_tuple->size(); ++i) {
768 auto func_graph = GetValueNode<FuncGraphPtr>(make_tuple->input(i));
769 if (func_graph == nullptr) {
770 MS_LOG(WARNING) << "Non-graph found in switch_layer input: " << make_tuple->DebugString(recursive_level)
771 << ", index: " << i;
772 continue;
773 }
774 graphs.push_back(func_graph);
775 }
776 return graphs;
777 }
778
779 // Get graphs from a tuple of functions value tuple for switch_layer.
GetGraphsFromValueTuple(const ValueNodePtr & value_node) const780 std::vector<FuncGraphPtr> GetGraphsFromValueTuple(const ValueNodePtr &value_node) const {
781 MS_EXCEPTION_IF_NULL(value_node);
782 const auto &value = value_node->value();
783 MS_EXCEPTION_IF_NULL(value);
784 auto value_tuple = value->cast_ptr<ValueTuple>();
785 MS_EXCEPTION_IF_NULL(value_tuple);
786 std::vector<FuncGraphPtr> graphs;
787 graphs.reserve(value_tuple->size());
788 const auto &tuple_elements = value_tuple->value();
789 for (size_t i = 0; i < tuple_elements.size(); ++i) {
790 const auto &tuple_element = tuple_elements[i];
791 MS_EXCEPTION_IF_NULL(tuple_element);
792 auto func_graph = tuple_element->cast<FuncGraphPtr>();
793 if (func_graph == nullptr) {
794 MS_LOG(WARNING) << "Non-graph found in switch_layer input: " << value_node->DebugString() << ", index: " << i;
795 continue;
796 }
797 graphs.push_back(func_graph);
798 }
799 return graphs;
800 }
801
802 // Trace effect info from tuple_getitem cnode.
TraceGetItemEffectInfo(const CNodePtr & cnode,std::stack<ValuePtr> * indexes)803 EffectInfo TraceGetItemEffectInfo(const CNodePtr &cnode, std::stack<ValuePtr> *indexes) {
804 MS_EXCEPTION_IF_NULL(cnode);
805 MS_EXCEPTION_IF_NULL(indexes);
806 constexpr size_t tuple_or_list_or_dict_input = 1;
807 constexpr size_t index_input = 2;
808 constexpr size_t cnode_size = 3;
809 if (cnode->size() != cnode_size) {
810 MS_LOG(INTERNAL_EXCEPTION) << "Invalid getitem: " << cnode->DebugString();
811 }
812 // Get item index.
813 auto &index_node = cnode->input(index_input);
814 auto index_value = dyn_cast<ValueNode>(index_node);
815 if (index_value == nullptr) {
816 MS_LOG(INTERNAL_EXCEPTION) << "getitem with non-const index, cnode: " << cnode->DebugString();
817 }
818
819 // Get tuple, list or dict value.
820 const auto &tuple_or_list_or_dict_node = cnode->input(tuple_or_list_or_dict_input);
821 // Push tuple, list or dict index.
822 indexes->push(index_value->value());
823 return TraceTupleListOrDictEffectInfo(tuple_or_list_or_dict_node, indexes);
824 }
825
TraceTupleListOrDictEffectInfo(const AnfNodePtr & node,std::stack<ValuePtr> * indexes)826 EffectInfo TraceTupleListOrDictEffectInfo(const AnfNodePtr &node, std::stack<ValuePtr> *indexes) {
827 MS_EXCEPTION_IF_NULL(indexes);
828 auto para = dyn_cast<Parameter>(node);
829 if (para != nullptr) {
830 return TraceTupleListParaEffectInfo(para, *indexes);
831 }
832 auto cnode = dyn_cast<CNode>(node);
833 if (cnode != nullptr) {
834 return TraceTupleListCNodeEffectInfo(cnode, indexes);
835 }
836 // Should not reach here.
837 MS_LOG(INTERNAL_EXCEPTION) << "Side effects untraceable: cnode is nullptr. Invalid node: " << node->DebugString();
838 }
839
TraceTupleListParaEffectInfo(const ParameterPtr & para,const std::stack<ValuePtr> & indexes)840 EffectInfo TraceTupleListParaEffectInfo(const ParameterPtr ¶, const std::stack<ValuePtr> &indexes) {
841 EffectInfo info{EffectInfo::kDetected, false, false, false, false};
842 ForEachRealArguments(para, [this, &info, indexes](const AnfNodePtr &arg) {
843 // Merge real argument effect info.
844 auto indexes_copy = indexes;
845 auto arg_info = TraceTupleListOrDictEffectInfo(arg, &indexes_copy);
846 info.Merge(arg_info);
847 });
848 return info;
849 }
850
GetInputIndex(const ValuePtr & top_index_value,const CNodePtr & origin_cnode,size_t inputs_size)851 size_t GetInputIndex(const ValuePtr &top_index_value, const CNodePtr &origin_cnode, size_t inputs_size) {
852 auto int64_imm = dyn_cast<Int64Imm>(top_index_value);
853 if (int64_imm == nullptr) {
854 MS_LOG(INTERNAL_EXCEPTION) << "Invalid make_tuple: " << origin_cnode->DebugString()
855 << ", index: " << (top_index_value == nullptr ? "null" : top_index_value->ToString());
856 }
857 auto top_index = int64_imm->value();
858 size_t input_index = 0;
859 // Support tuple index is negative
860 if (top_index < 0) {
861 if (SizeToLong(inputs_size) + top_index < 0) {
862 MS_LOG(INTERNAL_EXCEPTION) << "Invalid make_tuple: " << origin_cnode->DebugString() << " index=" << top_index;
863 }
864 input_index = static_cast<size_t>(inputs_size + top_index);
865 } else {
866 // Follow the tuple item according the index.
867 input_index = static_cast<size_t>(top_index) + 1;
868 }
869 if (input_index >= inputs_size) {
870 MS_LOG(INTERNAL_EXCEPTION) << "Invalid make_tuple: " << origin_cnode->DebugString() << " index=" << top_index;
871 }
872 return input_index;
873 }
874
TraceMakeTupleListEffectInfo(const CNodePtr & cnode,std::stack<ValuePtr> * indexes)875 EffectInfo TraceMakeTupleListEffectInfo(const CNodePtr &cnode, std::stack<ValuePtr> *indexes) {
876 constexpr int recursive_level = 2;
877 if (indexes->empty()) {
878 MS_LOG(INTERNAL_EXCEPTION) << "Unexpected make_tuple or make_list: " << cnode->DebugString(recursive_level);
879 }
880 // Pop out tuple index.
881 auto top_index_value = indexes->top();
882 indexes->pop();
883 auto input_index = GetInputIndex(top_index_value, cnode, cnode->size());
884 if (indexes->empty()) {
885 // Trace non-tuple.
886 return TraceEffectInfo(cnode->input(input_index));
887 }
888 // This is the tuple of tuple case.
889 return TraceTupleListOrDictEffectInfo(cnode->input(input_index), indexes);
890 }
891
TraceMakeDictEffectInfo(const CNodePtr & cnode,std::stack<ValuePtr> * indexes)892 EffectInfo TraceMakeDictEffectInfo(const CNodePtr &cnode, std::stack<ValuePtr> *indexes) {
893 constexpr int recursive_level = 2;
894 if (indexes->empty()) {
895 MS_LOG(INTERNAL_EXCEPTION) << "Unexpected make_dict: " << cnode->DebugString(recursive_level);
896 }
897 // Pop out dict index.
898 auto top_key_value = indexes->top();
899 MS_EXCEPTION_IF_NULL(top_key_value);
900 indexes->pop();
901 constexpr size_t keys_node_index = 1;
902 constexpr size_t values_node_index = 2;
903 auto keys_node = cnode->input(keys_node_index);
904 MS_EXCEPTION_IF_NULL(keys_node);
905 auto keys = GetValueNode<ValueTuplePtr>(keys_node);
906 if (keys == nullptr) {
907 MS_LOG(INTERNAL_EXCEPTION) << "Invalid make_dict: " << cnode->DebugString()
908 << ", the keys node: " << keys_node->DebugString();
909 }
910 for (size_t i = 0; i < keys->size(); ++i) {
911 MS_EXCEPTION_IF_NULL(keys->value()[i]);
912 if (*(keys->value()[i]) == *top_key_value) {
913 // The values_node is a make_dict.
914 indexes->push(MakeValue(SizeToLong(i)));
915 return TraceTupleListOrDictEffectInfo(cnode->input(values_node_index), indexes);
916 }
917 }
918 MS_LOG(WARNING) << "make_dict untraceable from: " << cnode->DebugString(recursive_level);
919 return {EffectInfo::kDetected, false, false, false};
920 }
921
TraceDictItemsEffectInfo(const CNodePtr & cnode,std::stack<ValuePtr> * indexes)922 EffectInfo TraceDictItemsEffectInfo(const CNodePtr &cnode, std::stack<ValuePtr> *indexes) {
923 constexpr int recursive_level = 2;
924 // Pop dict_getitem index.
925 if (indexes->empty()) {
926 MS_LOG(INTERNAL_EXCEPTION) << "Unexpected dict_items: " << cnode->DebugString(recursive_level);
927 }
928 auto list_getitem_index_value = indexes->top();
929 indexes->pop();
930 // Pop dict_getitem index.
931 if (indexes->empty()) {
932 MS_LOG(INTERNAL_EXCEPTION) << "Unexpected dict_items: " << cnode->DebugString(recursive_level);
933 }
934 auto tuple_getitem_index_value = indexes->top();
935 indexes->pop();
936 constexpr size_t key_and_value_tuple_size = 2;
937 auto tuple_getitem_index = GetInputIndex(tuple_getitem_index_value, cnode, key_and_value_tuple_size + 1);
938 // If the item is a value_node, skip.
939 if (tuple_getitem_index == 1) {
940 MS_LOG(INFO) << "dict_items untraceable from: " << cnode->DebugString(recursive_level);
941 return {EffectInfo::kDetected, false, false, false};
942 }
943 // dict_items(make_dict(keys_value_tuple, make_tuple()))
944 if (!IsPrimitiveCNode(cnode->input(1), prim::kPrimMakeDict)) {
945 MS_LOG(WARNING) << "dict_items untraceable from: " << cnode->DebugString(recursive_level);
946 return {EffectInfo::kDetected, false, false, false};
947 }
948 // Trace the make_tuple.
949 auto make_dict_cnode = cnode->input(1)->cast<CNodePtr>();
950 constexpr size_t values_node_index = 2;
951 indexes->push(list_getitem_index_value);
952 return TraceTupleListOrDictEffectInfo(make_dict_cnode->input(values_node_index), indexes);
953 }
954
TraceTupleListCNodeEffectInfo(const CNodePtr & cnode,std::stack<ValuePtr> * indexes)955 EffectInfo TraceTupleListCNodeEffectInfo(const CNodePtr &cnode, std::stack<ValuePtr> *indexes) {
956 MS_EXCEPTION_IF_NULL(indexes);
957 MS_EXCEPTION_IF_NULL(cnode);
958 auto prim = GetCNodePrimitiveWithoutDoSignature(cnode);
959 constexpr int recursive_level = 2;
960 // Trace MakeTuple or MakeList.
961 if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple) || IsPrimitiveEquals(prim, prim::kPrimMakeList)) {
962 return TraceMakeTupleListEffectInfo(cnode, indexes);
963 }
964 // Trace MakeDict.
965 if (IsPrimitiveEquals(prim, prim::kPrimMakeDict)) {
966 return TraceMakeDictEffectInfo(cnode, indexes);
967 }
968 // Trace the case of tuple, list or dict nested.
969 if (IsPrimitiveEquals(prim, prim::kPrimTupleGetItem) || IsPrimitiveEquals(prim, prim::kPrimListGetItem) ||
970 IsPrimitiveEquals(prim, prim::kPrimDictGetItem)) {
971 return TraceGetItemEffectInfo(cnode, indexes);
972 }
973 if (IsPrimitiveEquals(prim, prim::kPrimDictGetValues) && IsPrimitiveCNode(cnode->input(1), prim::kPrimMakeDict)) {
974 auto make_dict_cnode = cnode->input(1)->cast<CNodePtr>();
975 constexpr size_t values_node_index = 2;
976 return TraceTupleListOrDictEffectInfo(make_dict_cnode->input(values_node_index), indexes);
977 }
978 if (IsPrimitiveEquals(prim, prim::kPrimDictItems)) {
979 return TraceDictItemsEffectInfo(cnode, indexes);
980 }
981 // Trace primitive propagating side effect from its input, such as Depend, etc.
982 int input_index = GetSideEffectPropagate(prim);
983 if (input_index > 0 && input_index < static_cast<int>(cnode->size())) {
984 return TraceTupleListOrDictEffectInfo(cnode->input(static_cast<size_t>(input_index)), indexes);
985 }
986 // Tuple returned from func graph call.
987 auto func_graph = GetFuncGraph(cnode);
988 if (func_graph != nullptr) {
989 return TraceTupleListOrDictEffectInfo(func_graph->output(), indexes);
990 }
991 // Tuple returned from a Switch call.
992 if (cnode->size() == 1 && IsPrimitiveCNode(cnode->input(0), prim::kPrimSwitch)) {
993 return TraceTupleFromSwitch(cnode->input(0)->cast<CNodePtr>(), *indexes);
994 }
995 // Tuple is returned from J().
996 // %1 = J(primal)
997 // tuple = %1(args)
998 if (cnode->size() > 0 && IsPrimitiveCNode(cnode->input(0), prim::kPrimJ)) {
999 MS_LOG(DEBUG) << "Tuple from J: " << cnode->DebugString(recursive_level);
1000 constexpr size_t func_index = 1;
1001 auto j_conde = cnode->input(0)->cast<CNodePtr>();
1002 auto j_func = j_conde->input(func_index);
1003 auto func_info = TraceEffectInfo(j_func);
1004 // In order to add the Umonad arg to the bprop_top_cell in advance,
1005 // so that the side effects in the bprop graph are sorted earlier than the side effects of the optimizer.
1006 return {EffectInfo::kDetected, false, false, false, func_info.back_mem};
1007 }
1008 // Rare case.
1009 MS_LOG(WARNING) << "Tuple untraceable from: " << cnode->DebugString(recursive_level);
1010 return {EffectInfo::kDetected, false, false, false};
1011 }
1012
1013 // Trace effect info from a Switch node that output is a tuple.
TraceTupleFromSwitch(const CNodePtr & switch_cnode,const std::stack<ValuePtr> & tuple_indexes)1014 EffectInfo TraceTupleFromSwitch(const CNodePtr &switch_cnode, const std::stack<ValuePtr> &tuple_indexes) {
1015 auto branches = GetSwitchBranches(switch_cnode);
1016 EffectInfo info = {EffectInfo::kDetected, false, false, false, false};
1017 for (auto &branch : branches) {
1018 MS_EXCEPTION_IF_NULL(branch);
1019 auto tuple_indexes_copy = tuple_indexes;
1020 EffectInfo branch_info = TraceTupleListOrDictEffectInfo(branch->output(), &tuple_indexes_copy);
1021 info.Merge(branch_info);
1022 }
1023 return info;
1024 }
1025
1026 // Setup all branches according the effect info.
SetupEffectBranches(const EffectInfo & info,const std::vector<FuncGraphPtr> & branches)1027 void SetupEffectBranches(const EffectInfo &info, const std::vector<FuncGraphPtr> &branches) {
1028 // Setup monad parameters for all branches according the effect info.
1029 if (info.memory || info.load) {
1030 AddMonadParameters(branches, "u", kUMonad->ToAbstract());
1031 }
1032 if (info.io) {
1033 AddMonadParameters(branches, "io", kIOMonad->ToAbstract());
1034 }
1035 // Set merged effect info to both branches.
1036 for (auto &branch : branches) {
1037 MS_EXCEPTION_IF_NULL(branch);
1038 branch->SetEffectInfo(info);
1039 // Update caller if it is existed.
1040 UpdateBranchCaller(branch);
1041 }
1042 }
1043
1044 // Merge effect info for switch or switch_layer branch graphs.
MergeEffectInfo(const std::vector<FuncGraphPtr> & branches)1045 EffectInfo MergeEffectInfo(const std::vector<FuncGraphPtr> &branches) {
1046 EffectInfo info = {EffectInfo::kDetected, false, false, false, false};
1047 for (auto &branch : branches) {
1048 MS_EXCEPTION_IF_NULL(branch);
1049 EffectInfo branch_info = ObtainEffectInfoForFuncGraph(branch);
1050 info.Merge(branch_info);
1051 }
1052 return info;
1053 }
1054
1055 // Trace a cnode for effect info.
TraceEffectInfoForCNode(const CNodePtr & cnode)1056 EffectInfo TraceEffectInfoForCNode(const CNodePtr &cnode) {
1057 MS_EXCEPTION_IF_NULL(cnode);
1058 auto prim = GetCNodePrimitiveWithoutDoSignature(cnode);
1059 if (IsPrimitiveEquals(prim, prim::kPrimSwitch)) {
1060 // Special handling for Switch primitive.
1061 return TraceSwitchEffectInfo(cnode);
1062 }
1063
1064 if (IsPrimitiveEquals(prim, prim::kPrimSwitchLayer)) {
1065 // Special handling for SwitchLayer primitive.
1066 return TraceSwitchLayerEffectInfo(cnode);
1067 }
1068
1069 if (IsPrimitiveEquals(prim, prim::kPrimTupleGetItem) || IsPrimitiveEquals(prim, prim::kPrimListGetItem) ||
1070 IsPrimitiveEquals(prim, prim::kPrimDictGetItem)) {
1071 // Trace tuple_getitem or list_getitem or dict_getitem.
1072 std::stack<ValuePtr> indexes;
1073 return TraceGetItemEffectInfo(cnode, &indexes);
1074 }
1075
1076 if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple) || IsPrimitiveEquals(prim, prim::kPrimMakeList)) {
1077 // Trace make_tuple or make_list.
1078 EffectInfo info{EffectInfo::kDetected, false, false, false, false};
1079 for (size_t i = 1; i < cnode->size(); ++i) {
1080 auto input_info = TraceEffectInfo(cnode->input(i));
1081 info.Merge(input_info);
1082 }
1083 return info;
1084 }
1085
1086 // For high-order primitive such as Partial,
1087 // we trace effect info from its argument.
1088 int index_prim = GetSideEffectPropagate(prim);
1089 if (index_prim > 0 && index_prim < static_cast<int>(cnode->size())) {
1090 return TraceEffectInfo(cnode->input(static_cast<size_t>(index_prim)));
1091 }
1092
1093 // For func graph calls, we trace effect info from graph output.
1094 auto called_graph = GetFuncGraph(cnode);
1095 if (called_graph != nullptr) {
1096 // Save the caller of the graph, so that we can update
1097 // monad parameters for it when requires.
1098 (void)graph_callers_[called_graph].emplace(cnode);
1099 return TraceEffectInfo(called_graph->output());
1100 }
1101
1102 auto func_cnode = GetFuncCNode(cnode);
1103 if (func_cnode != nullptr) {
1104 //
1105 // For ClassType as the input[0], if it is a primitive class
1106 // with 'side_effect_propagate' attribute, we trace side effect
1107 // from its argument indxed by the attribute value.
1108 //
1109 // e.g.:
1110 // setpara = P.Partial()(P.Assign, self.para)
1111 // setpara(x)
1112 //
1113 auto class_type = GetValueNode<ClassTypePtr>(func_cnode->input(0));
1114 if (class_type != nullptr) {
1115 int index = GetSideEffectPropagate(class_type);
1116 if (index > 0 && index < static_cast<int>(cnode->size())) {
1117 return TraceEffectInfo(cnode->input(static_cast<size_t>(index)));
1118 }
1119 }
1120
1121 // For high order cnode, trace effect info from the output of the input cnode.
1122 return TraceOutputEffectInfo(func_cnode);
1123 }
1124
1125 // %0 = ExtractKeywordArg("key", value) // Maybe func_graph which has side effect.
1126 // %1 = %0(arg1, arg2) // Need add monad.
1127 if (IsPrimitiveCNode(cnode, prim::kPrimExtractKeywordArg)) {
1128 auto abs = cnode->abstract();
1129 auto real_func = GetFuncGraphFromFuncGraphAbstract(abs);
1130 if (real_func != nullptr) {
1131 // Try to obtain the effect info of func graph.
1132 auto effect_info = ObtainEffectInfoForFuncGraph(real_func);
1133 MS_LOG(DEBUG) << "The real_func: " << real_func->ToString() << ", " << abs->ToString()
1134 << ", cnode: " << cnode->DebugString() << ", effect_info: " << effect_info.memory << "/"
1135 << effect_info.io << "/" << effect_info.load;
1136 return effect_info;
1137 }
1138 }
1139 // Otherwise, assume no side effect and stop trace.
1140 MS_LOG(INFO) << "CNode side effect unknown: " << cnode->DebugString();
1141 return {EffectInfo::kDetected, false, false, false, false};
1142 }
1143
1144 // Trace effect info from output of the cnode.
TraceOutputEffectInfo(const CNodePtr & cnode)1145 EffectInfo TraceOutputEffectInfo(const CNodePtr &cnode) {
1146 MS_EXCEPTION_IF_NULL(cnode);
1147 std::vector<ValuePtr> values;
1148 GetOutputValues(cnode, &values);
1149 if (values.size() == 1) {
1150 return ObtainEffectInfoForValue(values.front());
1151 }
1152 EffectInfo info{EffectInfo::kDetected, false, false, false, false};
1153 for (auto &value : values) {
1154 info.Merge(ObtainEffectInfoForValue(value));
1155 }
1156 return info;
1157 }
1158
ObtainEffectInfoForValue(const ValuePtr & value)1159 EffectInfo ObtainEffectInfoForValue(const ValuePtr &value) {
1160 MS_EXCEPTION_IF_NULL(value);
1161 // FuncGraph.
1162 auto graph = dyn_cast<FuncGraph>(value);
1163 if (graph != nullptr) {
1164 return ObtainEffectInfoForFuncGraph(graph);
1165 }
1166 // Primitive.
1167 auto prim = dyn_cast<Primitive>(value);
1168 if (prim != nullptr) {
1169 return GetPrimEffectInfo(prim);
1170 }
1171 MS_LOG(INFO) << "Value side effect unknown: " << value->ToString();
1172 return {EffectInfo::kDetected, false, false, false, false};
1173 }
1174
GetOutputValues(const CNodePtr & cnode,std::vector<ValuePtr> * values)1175 void GetOutputValues(const CNodePtr &cnode, std::vector<ValuePtr> *values) {
1176 MS_EXCEPTION_IF_NULL(cnode);
1177 // CNode is a func graph call.
1178 auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
1179 if (graph != nullptr) {
1180 GetOutputValues(graph, values);
1181 return;
1182 }
1183 // CNode is applying another cnode.
1184 auto func_cnode = dyn_cast<CNode>(cnode->input(0));
1185 if (func_cnode != nullptr) {
1186 GetOutputValues(func_cnode, values);
1187 return;
1188 }
1189 // Primitive cnode.
1190 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
1191 if (IsPrimitiveEquals(prim, prim::kPrimSwitch)) {
1192 // Switch.
1193 auto branches = GetSwitchBranches(cnode);
1194 GetOutputValues(branches, values);
1195 return;
1196 }
1197 if (IsPrimitiveEquals(prim, prim::kPrimSwitchLayer)) {
1198 // Switch layer.
1199 auto branches = GetSwitchLayerBranches(cnode);
1200 GetOutputValues(branches, values);
1201 return;
1202 }
1203 if (IsPrimitiveEquals(prim, prim::kPrimPartial)) {
1204 // Partial.
1205 auto fg = GetValueNode<FuncGraphPtr>(cnode->input(1));
1206 if (fg != nullptr) {
1207 GetOutputValues(fg, values);
1208 return;
1209 }
1210 }
1211 // Other cases not supported yet.
1212 MS_LOG(INFO) << "Output unknown: " << cnode->DebugString();
1213 }
1214
GetOutputValues(const FuncGraphPtr & graph,std::vector<ValuePtr> * values)1215 void GetOutputValues(const FuncGraphPtr &graph, std::vector<ValuePtr> *values) {
1216 MS_EXCEPTION_IF_NULL(graph);
1217 MS_EXCEPTION_IF_NULL(values);
1218 auto output = graph->output();
1219 // Output is a value node.
1220 auto value = GetValueNode(output);
1221 if (value != nullptr) {
1222 (void)values->emplace_back(value);
1223 return;
1224 }
1225
1226 // Output is a cnode.
1227 auto cnode = dyn_cast<CNode>(output);
1228 if (cnode != nullptr) {
1229 GetOutputValues(cnode, values);
1230 return;
1231 }
1232 MS_EXCEPTION_IF_NULL(output);
1233 MS_LOG(INFO) << "Unexpected output: " << output->DebugString();
1234 }
1235
GetOutputValues(const std::vector<FuncGraphPtr> & graphs,std::vector<ValuePtr> * values)1236 void GetOutputValues(const std::vector<FuncGraphPtr> &graphs, std::vector<ValuePtr> *values) {
1237 for (auto &graph : graphs) {
1238 GetOutputValues(graph, values);
1239 }
1240 }
1241
1242 // Trace an AnfNode for effect info.
TraceEffectInfo(const AnfNodePtr & node)1243 EffectInfo TraceEffectInfo(const AnfNodePtr &node) {
1244 MS_EXCEPTION_IF_NULL(node);
1245 // Trace cnode.
1246 auto cnode = node->cast<CNodePtr>();
1247 if (cnode != nullptr) {
1248 return TraceEffectInfoForCNode(cnode);
1249 }
1250
1251 // Trace parameter.
1252 auto para = node->cast<ParameterPtr>();
1253 if (para != nullptr) {
1254 return TraceEffectInfoForParameter(para);
1255 }
1256
1257 // Trace primitive.
1258 auto prim = GetPrimitiveWithoutDoSignature(node);
1259 if (prim != nullptr) {
1260 return GetPrimEffectInfo(prim);
1261 }
1262
1263 // Trace func graph.
1264 auto graph = GetValueNode<FuncGraphPtr>(node);
1265 if (graph != nullptr) {
1266 return ObtainEffectInfoForFuncGraph(graph);
1267 }
1268
1269 // Other ValueNode has no side effects. For example: ValueNode<ClassType> node.
1270 // node1 = ValueNode<ClassType> class 'mindspore.ops.operations.debug_ops.Print'
1271 // node2 = _get_cache_prim(node1) // the node has side effects.
1272 if (node->isa<ValueNode>()) {
1273 MS_LOG(DEBUG) << "The ValueNode has no side effect: " << node->DebugString();
1274 return {EffectInfo::kDetected, false, false, false, false};
1275 }
1276 // Something is wrong if we reached here.
1277 MS_LOG(WARNING) << "The effect info of the node is untraceable: " << node->DebugString()
1278 << ".\nLine:" << trace::GetDebugInfoStr(node->debug_info());
1279 return {EffectInfo::kDetected, false, false, false, false};
1280 }
1281
GetParameterIndex(const FuncGraphPtr & func_graph,const ParameterPtr & para) const1282 int GetParameterIndex(const FuncGraphPtr &func_graph, const ParameterPtr ¶) const {
1283 int parameter_index = 0;
1284 for (auto ¶meter : func_graph->parameters()) {
1285 if (para == parameter) {
1286 return parameter_index;
1287 }
1288 ++parameter_index;
1289 }
1290 MS_LOG(INTERNAL_EXCEPTION) << "Parameter not found: " << (para ? para->DebugString() : "<null>");
1291 }
1292
1293 // Trace effect info from function parameter.
TraceEffectInfoForParameter(const ParameterPtr & para)1294 EffectInfo TraceEffectInfoForParameter(const ParameterPtr ¶) {
1295 EffectInfo info{EffectInfo::kDetected, false, false, false, false};
1296 ForEachRealArguments(para, [this, ¶, &info](const AnfNodePtr &arg) {
1297 // Merge caller input effect info.
1298 auto input_info = TraceEffectInfo(arg);
1299 info.Merge(input_info);
1300 });
1301 return info;
1302 }
1303
ForEachRealArguments(const ParameterPtr & para,const std::function<void (const AnfNodePtr &)> & handler)1304 void ForEachRealArguments(const ParameterPtr ¶, const std::function<void(const AnfNodePtr &)> &handler) {
1305 MS_EXCEPTION_IF_NULL(para);
1306 auto func_graph = para->func_graph();
1307 MS_EXCEPTION_IF_NULL(func_graph);
1308 // Find index of the parameter, starts from 0.
1309 const int para_index = GetParameterIndex(func_graph, para);
1310 const size_t input_index = static_cast<size_t>(para_index) + 1;
1311 // Search user cnodes of the func graph.
1312 auto &users = func_graph->func_graph_cnodes_index();
1313 if (users.empty()) {
1314 MS_LOG(WARNING) << "Unused graph for parameter " << para->DebugString();
1315 }
1316 // Push the parameter to a stack so that we can check cycle binding.
1317 NodeStackGuard param_stack_guard(&formal_param_stack_, para);
1318 for (auto &user : users) {
1319 auto use_index = user.first->second;
1320 if (use_index != 0) {
1321 // Skip non-caller usage.
1322 continue;
1323 }
1324 // Caller cnode.
1325 auto cnode = dyn_cast<CNode>(user.first->first);
1326 MS_EXCEPTION_IF_NULL(cnode);
1327 if (cnode != nullptr && input_index < cnode->size()) {
1328 auto &input = cnode->input(input_index);
1329 if (formal_param_stack_.contains(input)) {
1330 // Skip if the input is a parameter that we are finding its real argument.
1331 continue;
1332 }
1333 handler(input);
1334 }
1335 }
1336 }
1337
1338 // For call node, returns effect info of the callee graph.
GetCallEffectInfo(const CNodePtr & cnode)1339 EffectInfo GetCallEffectInfo(const CNodePtr &cnode) {
1340 MS_EXCEPTION_IF_NULL(cnode);
1341 constexpr size_t min_call_node_size = 2;
1342 if (cnode->size() < min_call_node_size) {
1343 MS_LOG(INTERNAL_EXCEPTION) << "Invalid call node: " << cnode->DebugString();
1344 }
1345 auto func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
1346 if (func_graph == nullptr) {
1347 MS_LOG(INTERNAL_EXCEPTION) << "Invalid call node: " << cnode->DebugString();
1348 }
1349 return ObtainEffectInfoForFuncGraph(func_graph);
1350 }
1351
1352 // Detect effect info by depth first search.
ObtainEffectInfoForCNodeInner(const CNodePtr & cnode)1353 EffectInfo ObtainEffectInfoForCNodeInner(const CNodePtr &cnode) {
1354 // For primitive, get effect info from its attributes and inputs.
1355 auto prim = GetCNodePrimitiveWithoutDoSignature(cnode);
1356 if (prim != nullptr) {
1357 // Skip 'return' cnode.
1358 if (IsPrimitiveEquals(prim, prim::kPrimReturn)) {
1359 return {EffectInfo::kDetected, false, false, false, false};
1360 }
1361 // Special handling for 'call' cnode.
1362 if (IsPrimitiveEquals(prim, prim::kPrimCall)) {
1363 return GetCallEffectInfo(cnode);
1364 }
1365 auto info = GetPrimEffectInfo(prim);
1366 if (!info.memory && !IsKeepRef(prim)) {
1367 // For primitive calls, if no memory effects but
1368 // Ref parameter used, we will insert 'load' before them.
1369 // Except for primitives like J(f) or Partial(f, x) which propagate side effect,
1370 // load is inserted inside the func_graph f.
1371 info.load = HasRefInput(cnode);
1372 }
1373 if (!info.memory && IsNonEffectRealNodeAndInputIsDynamic(cnode)) {
1374 info.load = HasRefSequenceInput(cnode);
1375 }
1376 return info;
1377 }
1378
1379 // For func graph, detect effect info by its children cnodes.
1380 auto func_graph = GetFuncGraph(cnode);
1381 if (func_graph != nullptr) {
1382 // Save the caller of the graph, so that we can update
1383 // monad parameters for it when requires.
1384 (void)graph_callers_[func_graph].emplace(cnode);
1385 return ObtainEffectInfoForFuncGraph(func_graph);
1386 }
1387
1388 // When input[0] is a cnode, it is a function returned from
1389 // a high-order function call, we trace it by return value.
1390 auto func_cnode = GetFuncCNode(cnode);
1391 if (func_cnode != nullptr) {
1392 caller_ = cnode;
1393 auto effect_info = TraceEffectInfoForCNode(func_cnode);
1394 // Retry for Partial call.
1395 return TracePartialCallEffectInfo(cnode, effect_info);
1396 }
1397
1398 // When input[0] is a parameter, it is a function parameter for
1399 // the high-order function, we trace it by caller.
1400 auto func_para = GetFuncParameter(cnode);
1401 if (func_para != nullptr) {
1402 auto effect_info = TraceEffectInfoForParameter(func_para);
1403 // Retry for Partial call.
1404 return TracePartialCallEffectInfo(cnode, effect_info);
1405 }
1406
1407 // When input[0] is a MultitypeFuncGraph, it's not specialized
1408 // as one of its parameters is AbstractUndertermined,
1409 // This MultitypeFuncGraph may be specialized at next Renormalize
1410 // process, but we have to keep the order by insert UMonad now,
1411 // otherwise order will be lost in next Renormalize.
1412 // So assume it has memory side effect conservatively.
1413 auto func_multitype = GetFuncMultitypeFuncGraph(cnode);
1414 if (func_multitype != nullptr) {
1415 MS_LOG(DEBUG) << "Assume memory side effect for: " << cnode->DebugString();
1416 return {EffectInfo::kDetected, true, false, false, false};
1417 }
1418
1419 // For other cnodes, we assume that they have no side effects.
1420 MS_LOG(DEBUG) << "Assume no side effect for: " << cnode->DebugString();
1421 return {EffectInfo::kDetected, false, false, false, false};
1422 }
1423
1424 // Gets EffectInfo for CNode.
ObtainEffectInfoForCNode(const CNodePtr & cnode)1425 EffectInfo ObtainEffectInfoForCNode(const CNodePtr &cnode) {
1426 const auto &effect_info = cnode->GetEffectInfo();
1427 if (effect_info.state == EffectInfo::kDetected) {
1428 // Effect info already detected, return it.
1429 return effect_info;
1430 }
1431
1432 // Detect effect info for the cnode.
1433 EffectInfo info = ObtainEffectInfoForCNodeInner(cnode);
1434 if (info.state == EffectInfo::kDetected) {
1435 // Save detected info into cnode.
1436 cnode->SetEffectInfo(info);
1437 }
1438 return info;
1439 }
1440
1441 // Gets SCC that the given graph belongs to.
GetScc(const FuncGraphPtr & func_graph) const1442 SccPtr GetScc(const FuncGraphPtr &func_graph) const {
1443 auto found = scc_map_.find(func_graph);
1444 if (found == scc_map_.end()) {
1445 return nullptr;
1446 }
1447 return found->second;
1448 }
1449
1450 // Set effect info for all member graphs in the SCC.
SetSccEffectInfo(const SccPtr & scc,const EffectInfo & info) const1451 void SetSccEffectInfo(const SccPtr &scc, const EffectInfo &info) const {
1452 MS_EXCEPTION_IF_NULL(scc);
1453 for (auto &g : *scc) {
1454 MS_EXCEPTION_IF_NULL(g);
1455 g->SetEffectInfo(info);
1456 }
1457 }
1458
1459 // Gets EffectInfo for func graph's total used.
ObtainEffectInfoForFuncGraphs(const FuncGraphPtr & func_graph)1460 void ObtainEffectInfoForFuncGraphs(const FuncGraphPtr &func_graph) {
1461 MS_EXCEPTION_IF_NULL(func_graph);
1462 auto &used_func_graphs = func_graph->func_graphs_used_total();
1463 for (auto iter = used_func_graphs.crbegin(); iter != used_func_graphs.crend(); ++iter) {
1464 auto used_func_graph = *iter;
1465 MS_EXCEPTION_IF_NULL(used_func_graph);
1466 (void)ObtainEffectInfoForFuncGraph(used_func_graph);
1467 }
1468 ObtainEffectInfoForFuncGraph(func_graph);
1469 }
1470
1471 // Gets EffectInfo for func graph.
ObtainEffectInfoForFuncGraph(const FuncGraphPtr & func_graph)1472 EffectInfo ObtainEffectInfoForFuncGraph(const FuncGraphPtr &func_graph) {
1473 MS_EXCEPTION_IF_NULL(func_graph);
1474 auto effect_info = func_graph->GetEffectInfo();
1475 if (effect_info.state != EffectInfo::kUnknown) {
1476 return effect_info;
1477 }
1478
1479 // Get SCC that this graph belongs to.
1480 auto scc = GetScc(func_graph);
1481 if (scc == nullptr) {
1482 MS_LOG(INTERNAL_EXCEPTION) << "Scc should not be null, func_graph: " << func_graph->ToString();
1483 }
1484 // To prevent SCC members be visited again, we set effect info
1485 // to 'kDetecting' state before start to check cnodes.
1486 EffectInfo info{EffectInfo::kDetecting, false, false, false, false};
1487 SetSccEffectInfo(scc, info);
1488
1489 // Check side effects for all cnodes in the SCC.
1490 std::vector<CNodePtr> undetected;
1491 for (auto &g : *scc) {
1492 MS_EXCEPTION_IF_NULL(g);
1493 for (auto &weak_cnode : g->order_list()) {
1494 const auto &cnode = weak_cnode.lock();
1495 if (cnode == nullptr) {
1496 continue;
1497 }
1498 auto cnode_effect = ObtainEffectInfoForCNode(cnode);
1499 if (cnode_effect.state != EffectInfo::kDetected) {
1500 // For side effect undetected node, it could be a call to the SCC member graph,
1501 // we will try to check side effect again after SCC side effect detected.
1502 undetected.push_back(cnode);
1503 }
1504 // Merge effect info from the node.
1505 info.Merge(cnode_effect);
1506 }
1507 // Make sure all sub-graphs is checked. since some sub-graphs may not directly called,
1508 // for example: return ValueNode(sub_graph).
1509 for (auto &sg : g->func_graphs_used()) {
1510 (void)ObtainEffectInfoForFuncGraph(sg.first);
1511 }
1512 }
1513 // Update effect into for all members of the SCC.
1514 info.state = EffectInfo::kDetected;
1515 SetSccEffectInfo(scc, info);
1516
1517 // Check undetected cnodes again after side effect of the SCC is detected.
1518 for (auto &cnode : undetected) {
1519 MS_EXCEPTION_IF_NULL(cnode);
1520 auto cnode_effect = ObtainEffectInfoForCNode(cnode);
1521 // Side effect should be detected now, except free variable nodes that not belong to current SCC.
1522 if (cnode_effect.state != EffectInfo::kDetected && scc->find(cnode->func_graph()) != scc->end()) {
1523 MS_LOG(INTERNAL_EXCEPTION) << "Side effect is undetectable: " << cnode->DebugString();
1524 }
1525 }
1526 return info;
1527 }
1528
1529 // The caller of switch node is also a caller of the branches, we save them
1530 // so that we can update monad parameters for the caller when it requires.
SaveBranchCaller(const CNodePtr & switch_node,const FuncGraphVector & branches)1531 void SaveBranchCaller(const CNodePtr &switch_node, const FuncGraphVector &branches) {
1532 MS_EXCEPTION_IF_NULL(switch_node);
1533 auto fg = switch_node->func_graph();
1534 MS_EXCEPTION_IF_NULL(fg);
1535 auto manager = fg->manager();
1536 MS_EXCEPTION_IF_NULL(manager);
1537 auto &node_users = manager->node_users();
1538 auto found = node_users.find(switch_node);
1539 if (found == node_users.end()) {
1540 MS_LOG(WARNING) << "Caller not found for " << switch_node->DebugString();
1541 return;
1542 }
1543 bool is_multi_branches = (branches.size() > 1);
1544 for (auto &user : found->second) {
1545 auto cnode = dyn_cast<CNode>(user.first);
1546 if (cnode == nullptr || user.second != 0) {
1547 continue;
1548 }
1549 // The cnode is the switch caller.
1550 if (is_multi_branches) {
1551 // Caller to branches.
1552 (void)switch_calls_.emplace(cnode, branches);
1553 }
1554 for (auto &branch : branches) {
1555 // Branch to caller.
1556 (void)graph_callers_[branch].emplace(cnode);
1557 }
1558 }
1559 }
1560
UpdateBranchCaller(const FuncGraphPtr & branch)1561 void UpdateBranchCaller(const FuncGraphPtr &branch) {
1562 MS_EXCEPTION_IF_NULL(branch);
1563 auto iter = graph_callers_.find(branch);
1564 if (iter == graph_callers_.end()) {
1565 return;
1566 }
1567 const auto &info = branch->GetEffectInfo();
1568 for (auto &caller : iter->second) {
1569 AddMonadForCaller(caller, info);
1570 }
1571 }
1572
AddMonadForCaller(const CNodePtr & caller,const EffectInfo & info) const1573 void AddMonadForCaller(const CNodePtr &caller, const EffectInfo &info) const {
1574 if (info.memory || info.load) {
1575 // Add u monad argument to caller if need.
1576 AddMonadArgument(caller, kUMonad);
1577 }
1578 if (info.io) {
1579 // Add io monad argument to caller if need.
1580 AddMonadArgument(caller, kIOMonad);
1581 }
1582 }
1583
AddMonadArgument(const CNodePtr & cnode,const ValuePtr & monad) const1584 void AddMonadArgument(const CNodePtr &cnode, const ValuePtr &monad) const {
1585 MS_EXCEPTION_IF_NULL(cnode);
1586 MS_EXCEPTION_IF_NULL(monad);
1587 auto monad_abs = monad->ToAbstract();
1588 for (size_t i = 1; i < cnode->size(); ++i) {
1589 auto abs = cnode->input(i)->abstract();
1590 if (abs != nullptr && *abs == *monad_abs) {
1591 // Skip if monad argument already existed.
1592 return;
1593 }
1594 }
1595 // Add monad argument if not yet.
1596 auto monad_input = NewValueNode(monad);
1597 monad_input->set_abstract(monad_abs);
1598 if ((monad == kUMonad) && cnode->size() > 1 && HasAbstractIOMonad(cnode->weak_inputs().back().lock())) {
1599 // Insert u monad before io monad.
1600 size_t last_index = cnode->size() - 1;
1601 cnode->add_input(cnode->input(last_index));
1602 cnode->set_input(last_index, monad_input);
1603 } else {
1604 // Add monad as the last input.
1605 cnode->add_input(monad_input);
1606 }
1607 }
1608
1609 // The root graph.
1610 FuncGraphPtr root_;
1611
1612 // SCC map.
1613 SccMap scc_map_;
1614
// Map each graph to its caller cnodes, so that we can add monad inputs to the
// caller cnodes when we later find that the graph has gained monad parameters.
1617 mindspore::HashMap<FuncGraphPtr, mindspore::HashSet<CNodePtr>> graph_callers_;
1618
1619 // Current high order func caller cnode.
1620 CNodePtr caller_ = nullptr;
1621
// Save the caller cnodes of partial CNodes and their real func graphs, so that we can check
// and update monad parameters for the real func graph according to the caller inputs.
1624 mindspore::HashMap<CNodePtr, FuncGraphPtr> partial_cnode_calls_;
1625
// Save switch caller cnodes and their branches, so that we can check and
// update monad parameters for the branches according to the caller inputs.
1628 mindspore::HashMap<CNodePtr, FuncGraphVector> switch_calls_;
1629
1630 // switch_layer_calls save all switch_layer calls, so that
1631 // we can check whether monad argument should be added for them.
1632 std::vector<SwitchLayerCall> switch_layer_calls_;
1633
1634 // Save traced formal parameters so that we can check cycle parameter binding.
1635 OrderedSet<AnfNodePtr> formal_param_stack_;
1636 }; // class SideEffectFinder
1637
1638 // --------------------------------------------------------------------
1639 // AutoMonadConverter converts side-effect cnodes into monad form.
1640 // --------------------------------------------------------------------
1641 class AutoMonadConverter {
1642 public:
Handle(const FuncGraphPtr & func_graph,bool top)1643 static bool Handle(const FuncGraphPtr &func_graph, bool top) {
1644 AutoMonadConverter converter(func_graph, top);
1645 return converter.Run();
1646 }
1647
1648 private:
AutoMonadConverter(const FuncGraphPtr & func_graph,bool top)1649 AutoMonadConverter(const FuncGraphPtr &func_graph, bool top)
1650 : func_graph_(func_graph), manager_(func_graph->manager()), top_(top) {}
1651
1652 ~AutoMonadConverter() = default;
1653
Run()1654 bool Run() {
1655 // Handle cnodes for side effects.
1656 const auto &info = func_graph_->GetEffectInfo();
1657 if (info.state == EffectInfo::kDetected) {
1658 HandleCNodes();
1659 }
1660
1661 // Safe to clear isolated nodes after handled side effect nodes.
1662 ClearIsolatedNodes();
1663
1664 // Clean up after conversion finished.
1665 func_graph_->ClearOrderList();
1666 return has_effect_cnodes_;
1667 }
1668
1669 // Check if there are side effects from effect info.
HasSideEffects(const EffectInfo & info)1670 static bool HasSideEffects(const EffectInfo &info) { return (info.memory || info.io || info.load || info.back_mem); }
1671
1672 // Gets effect info for a cnode.
GetEffectInfoFromCNode(const CNodePtr & cnode) const1673 const EffectInfo &GetEffectInfoFromCNode(const CNodePtr &cnode) const {
1674 MS_EXCEPTION_IF_NULL(cnode);
1675 auto &effect_info = cnode->GetEffectInfo();
1676 if (effect_info.state != EffectInfo::kDetected) {
1677 // Effect info should have been set by SideEffectFinder.
1678 MS_LOG(WARNING) << "Side effects not detected: " << cnode->DebugString();
1679 }
1680 return effect_info;
1681 }
1682
1683 // Handle CNodes for side effects.
HandleCNodes()1684 void HandleCNodes() {
1685 // Check whether UpdateState and Depend are required.
1686 bool update_state = NeedUpdateState();
1687
1688 // Check all cnodes in order list.
1689 for (auto &weak_cnode : func_graph_->order_list()) {
1690 const auto &cnode = weak_cnode.lock();
1691 if (cnode == nullptr) {
1692 continue;
1693 }
1694 // Process param.value() Load(param, U) ---> Load(param, GetUniverse())
1695 if (IsPrimitiveCNode(cnode, prim::kPrimLoad)) {
1696 const size_t param_index = 1;
1697 const size_t monad_index = 2;
1698 auto param = cnode->input(param_index);
1699 auto load_monad = cnode->input(monad_index);
1700 auto param_abs = param->abstract();
1701 MS_EXCEPTION_IF_NULL(param_abs);
1702 if (param_abs->isa<abstract::AbstractRefTensor>() && IsValueNode<UMonad>(load_monad)) {
1703 auto current_u = GetUniverse();
1704 manager_->SetEdge(cnode, SizeToInt(monad_index), current_u);
1705 u_ = UpdateState(current_u, cnode);
1706 continue;
1707 }
1708 }
1709 auto &info = GetEffectInfoFromCNode(cnode);
1710 has_effect_cnodes_ = (has_effect_cnodes_ || HasSideEffects(info));
1711 if (cnode->func_graph() != func_graph_) {
1712 // Handle outer cnode.
1713 HandleOuterNode(cnode, info);
1714 } else {
1715 // Handle cnode with memory side effects.
1716 if (info.memory) {
1717 HandleMemoryEffects(cnode, update_state);
1718 } else if (info.load) {
1719 // If no memory side effects, handle load if need.
1720 HandleLoad(cnode, update_state);
1721 }
1722 // Handle cnode with IO side effects.
1723 if (info.io) {
1724 HandleIoEffects(cnode, update_state);
1725 }
1726 // If the node has no side effects but 'no_eliminate' flag is set,
1727 // we save it to no_eliminate_nodes and handle them late.
1728 if (!info.memory && !info.io && IsNoEliminateNode(cnode)) {
1729 (void)no_eliminate_nodes_.emplace_back(cnode);
1730 }
1731 }
1732 cnode->SetEffectHandled(true);
1733 }
1734 // Attach no eliminate nodes to output.
1735 HandleNoEliminateNodes();
1736 // Attach monad to output if required.
1737 if (update_state) {
1738 AttachMonadToOutput();
1739 }
1740 }
1741
1742 // Return true if the given cnode is primitive cnode with 'no_eliminate' flag.
IsNoEliminateNode(const CNodePtr & cnode) const1743 bool IsNoEliminateNode(const CNodePtr &cnode) const {
1744 if (cnode == nullptr || cnode->size() == 0) {
1745 return false;
1746 }
1747 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
1748 if (prim == nullptr) {
1749 return false;
1750 }
1751 return GetPrimitiveFlag(prim, ATTR_NO_ELIMINATE);
1752 }
1753
1754 // Attach no eliminate nodes to output.
HandleNoEliminateNodes()1755 void HandleNoEliminateNodes() {
1756 if (no_eliminate_nodes_.empty()) {
1757 // Skip if no nodes to be handled.
1758 return;
1759 }
1760 // If only one node, attach it to output directly.
1761 if (no_eliminate_nodes_.size() == 1) {
1762 AttachToOutput(no_eliminate_nodes_.front());
1763 return;
1764 }
1765 // For multiple nodes, attach them to output by a tuple.
1766 std::vector<AnfNodePtr> tuple_inputs;
1767 AbstractBasePtrList element_abstracts;
1768 tuple_inputs.reserve(no_eliminate_nodes_.size() + 1);
1769 element_abstracts.reserve(no_eliminate_nodes_.size());
1770 (void)tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
1771 for (auto &node : no_eliminate_nodes_) {
1772 (void)tuple_inputs.emplace_back(node);
1773 (void)element_abstracts.emplace_back(node->abstract());
1774 }
1775 auto make_tuple_node = func_graph_->NewCNode(tuple_inputs);
1776 make_tuple_node->set_abstract(std::make_shared<abstract::AbstractTuple>(element_abstracts));
1777 AttachToOutput(make_tuple_node);
1778 }
1779
1780 // Clean no side effect dependency nodes.
1781 // From: output = Depend(output, StopGrad)
1782 // return output
1783 //
1784 // To: return output
ClearIsolatedNodes() const1785 void ClearIsolatedNodes() const {
1786 auto output = GetGraphOutput();
1787 constexpr size_t attach_index = 2;
1788 if (IsPrimitiveCNode(output, prim::kPrimDepend)) {
1789 auto attach_node = output->cast<CNodePtr>()->input(attach_index);
1790 if (IsPrimitiveCNode(attach_node, prim::kPrimStopGradient)) {
1791 auto attach_cnode = attach_node->cast<CNodePtr>();
1792 auto input = attach_cnode->input(1);
1793 // Check the input of stop_gradient.
1794 if (input->isa<CNode>() && input->cast<CNodePtr>()->has_side_effect_node()) {
1795 MS_LOG(WARNING) << "Some side effect nodes were eliminated by mistake.";
1796 }
1797 // Replace Depend(orig_output, StopGrad) node with orig_output.
1798 // After that, nodes may be eliminated if have no side effects.
1799 auto &orig_output = output->cast<CNodePtr>()->input(1);
1800 func_graph_->set_output(orig_output);
1801 }
1802 }
1803 }
1804
HandleOuterNode(const CNodePtr & cnode,const EffectInfo & info)1805 void HandleOuterNode(const CNodePtr &cnode, const EffectInfo &info) {
1806 MS_EXCEPTION_IF_NULL(cnode);
1807 if (info.memory || info.load) {
1808 (void)GetUniverse();
1809 bool load_with_primitive = (info.load && IsPrimitiveCNode(cnode));
1810 if (!cnode->IsEffectHandled() && !load_with_primitive) {
1811 auto u_node = NewValueNode(kUMonad);
1812 u_node->set_abstract(kUMonad->ToAbstract());
1813 cnode->add_input(u_node);
1814 }
1815 }
1816 if (info.io) {
1817 (void)GetIoState();
1818 if (!cnode->IsEffectHandled()) {
1819 auto io = NewValueNode(kIOMonad);
1820 io->set_abstract(kIOMonad->ToAbstract());
1821 cnode->add_input(io);
1822 }
1823 }
1824 }
1825
1826 //
1827 // Convert cnode with memory side effect to monad form,
1828 // from:
1829 // output = func(input)
1830 // to:
1831 // output = func(input, u)
1832 // u = UpdateState(u, output) # if update_state is true
1833 //
HandleMemoryEffects(const CNodePtr & cnode,bool update_state)1834 void HandleMemoryEffects(const CNodePtr &cnode, bool update_state) {
1835 const auto &u = GetUniverse();
1836 AddMonadInput(cnode, u);
1837 if (update_state) {
1838 u_ = UpdateState(u, cnode);
1839 }
1840 }
1841
1842 //
1843 // Convert cnode with io side effect to monad form,
1844 // from:
1845 // output = func(input)
1846 // to:
1847 // output = func(input, io)
1848 // io = UpdateState(io, output) # if update_state is true
1849 //
HandleIoEffects(const CNodePtr & cnode,bool update_state)1850 void HandleIoEffects(const CNodePtr &cnode, bool update_state) {
1851 const auto &io = GetIoState();
1852 AddMonadInput(cnode, io);
1853 if (update_state) {
1854 io_ = UpdateState(io, cnode);
1855 }
1856 }
1857
HandleLoad(const CNodePtr & cnode,bool update_state)1858 void HandleLoad(const CNodePtr &cnode, bool update_state) {
1859 MS_EXCEPTION_IF_NULL(cnode);
1860 // Check if a sequence which has ref exists in the inputs of the cnode, and the cnode is a real node.
1861 if (IsNonEffectRealNodeAndInputIsDynamic(cnode)) {
1862 return InsertLoadForSequenceRef(cnode, update_state);
1863 }
1864 if (IsValueNode<Primitive>(cnode->input(0))) {
1865 // For primitive calls that use Ref as input, insert Loads before them.
1866 InsertLoads(cnode, update_state);
1867 } else {
1868 // For non-primitive calls, load is used inside the callee,
1869 // We do not insert load for it but handle it as a side
1870 // effects cnode.
1871 HandleMemoryEffects(cnode, update_state);
1872 }
1873 }
1874
// Creates a cnode that gets item 'index' from the sequence node 'node'.
// TupleGetItem is used when 'seq_abs' is a tuple abstract and ListGetItem when
// it is a list abstract; the new cnode gets 'item_abs' as its abstract.
// If 'item_abs' is a Ref tensor abstract, the item is wrapped in a Load on the
// current universe and that Load is returned instead.
AnfNodePtr NewItemNode(const AnfNodePtr &node, const AbstractBasePtr &seq_abs, const AbstractBasePtr &item_abs,
                       size_t index) {
  std::vector<AnfNodePtr> getitem_inputs;
  const bool is_tuple = seq_abs->isa<abstract::AbstractTuple>();
  if (is_tuple || seq_abs->isa<abstract::AbstractList>()) {
    getitem_inputs.push_back(NewValueNode(is_tuple ? prim::kPrimTupleGetItem : prim::kPrimListGetItem));
  }
  getitem_inputs.push_back(node);
  getitem_inputs.push_back(NewValueNode(SizeToLong(index)));
  auto item_node = func_graph_->NewCNode(std::move(getitem_inputs));
  item_node->set_abstract(item_abs);
  if (!item_abs->isa<abstract::AbstractRefTensor>()) {
    return item_node;
  }
  // The item is a Ref tensor: make a Load for it with the current u monad.
  auto current_u = GetUniverse();
  return MakeLoad(node, item_node, current_u);
}
1895
1896 // params = (param1, param2, ..., value)
1897 // addn(params, xxx) non-effect-node need insert load for params.
InsertLoadForSequenceRef(const CNodePtr & cnode,bool update_state)1898 void InsertLoadForSequenceRef(const CNodePtr &cnode, bool update_state) {
1899 abstract::AbstractBasePtrList new_seq_abstracts;
1900 for (size_t index = 1; index < cnode->size(); ++index) {
1901 const auto &input = cnode->input(index);
1902 const auto &input_abs = input->abstract();
1903 MS_EXCEPTION_IF_NULL(input_abs);
1904 if (!input_abs->isa<abstract::AbstractTuple>() && !input_abs->isa<abstract::AbstractList>()) {
1905 (void)new_seq_abstracts.emplace_back(input_abs);
1906 continue;
1907 }
1908 // Handle the input which is sequence.
1909 std::vector<AnfNodePtr> new_sequence_inputs;
1910 if (input_abs->isa<abstract::AbstractTuple>()) {
1911 (void)new_sequence_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
1912 } else if (input_abs->isa<abstract::AbstractList>()) {
1913 (void)new_sequence_inputs.emplace_back(NewValueNode(prim::kPrimMakeList));
1914 }
1915 auto seq_abs = input_abs->cast_ptr<abstract::AbstractSequence>();
1916 MS_EXCEPTION_IF_NULL(seq_abs);
1917 const auto &elements = seq_abs->elements();
1918 for (size_t item_index = 0; item_index < elements.size(); ++item_index) {
1919 const auto &item_abs = elements[item_index];
1920 auto item = NewItemNode(input, input_abs, item_abs, item_index);
1921 (void)new_sequence_inputs.emplace_back(item);
1922 (void)new_seq_abstracts.emplace_back(item->abstract());
1923 }
1924 auto new_seq = func_graph_->NewCNode(std::move(new_sequence_inputs));
1925 MS_LOG(DEBUG) << "Replace the input of non-effect-node:" << cnode->DebugString()
1926 << " with:" << new_seq->DebugString();
1927 if (input_abs->isa<abstract::AbstractTuple>()) {
1928 new_seq->set_abstract(std::make_shared<abstract::AbstractTuple>(new_seq_abstracts));
1929 } else if (input_abs->isa<abstract::AbstractList>()) {
1930 new_seq->set_abstract(std::make_shared<abstract::AbstractList>(new_seq_abstracts));
1931 }
1932 manager_->SetEdge(cnode, SizeToInt(index), new_seq);
1933 if (update_state) {
1934 auto current_u = GetUniverse();
1935 // In the order_enforce phase, the cnode will be added to the updatestate to ensure the order,
1936 // and the input of the updatestate is maintained here to 2.
1937 // to ensure the verification of the updatestate in the relevant pass.
1938 u_ = UpdateState(current_u, new_seq);
1939 }
1940 }
1941 }
1942
1943 //
1944 // Insert Loads for a primitive cnode that use Ref as input.
1945 // for example, from:
1946 // out = Prim(self.para1, self.para2, other_args)
1947 // to:
1948 // p1 = Load(self.para1, u)
1949 // p2 = Load(self.para2, u)
1950 // t = make_tuple(p1, p2) # if update_state
1951 // u1 = UpdateState(u, t) # is required
1952 // out = Prim(p1, p2, other_args)
1953 //
InsertLoads(const CNodePtr & cnode,bool update_state)1954 void InsertLoads(const CNodePtr &cnode, bool update_state) {
1955 // Find ref inputs.
1956 auto ref_inputs = GetRefInputs(cnode);
1957 if (ref_inputs.empty()) {
1958 MS_LOG(WARNING) << "Ref input not found for load insertion: " << cnode->DebugString();
1959 return;
1960 }
1961 // Current u monad.
1962 auto current_u = GetUniverse();
1963 // Create Load cnodes.
1964 auto loads = MakeLoads(cnode, ref_inputs, current_u);
1965 if (loads.empty() || !update_state) {
1966 // Skip UpdateState insertion.
1967 return;
1968 }
1969 // Insert UpdateState if required.
1970 if (loads.size() == 1) {
1971 // One Load, no make_tuple needed.
1972 u_ = UpdateState(current_u, loads.front());
1973 return;
1974 }
1975 // Multiple Loads, Create a MakeTuple before UpdateState.
1976 abstract::AbstractBasePtrList load_abstracts;
1977 (void)std::transform(loads.begin(), loads.end(), std::back_inserter(load_abstracts),
1978 [](const AnfNodePtr &load) { return load->abstract(); });
1979 (void)loads.insert(loads.begin(), NewValueNode(prim::kPrimMakeTuple));
1980 auto make_tuple = func_graph_->NewCNode(loads);
1981 make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(load_abstracts));
1982 u_ = UpdateState(current_u, make_tuple);
1983 }
1984
MakeLoads(const CNodePtr & cnode,const RefInputs & ref_inputs,const AnfNodePtr & u)1985 std::vector<AnfNodePtr> MakeLoads(const CNodePtr &cnode, const RefInputs &ref_inputs, const AnfNodePtr &u) {
1986 std::vector<AnfNodePtr> loads;
1987 for (auto &ref_input : ref_inputs) {
1988 // Make a Load cnode for ref input.
1989 auto &ref = ref_input.first;
1990 auto load = MakeLoad(cnode, ref, u);
1991 // Replace input with the load cnode.
1992 for (size_t index : ref_input.second) {
1993 manager_->SetEdge(cnode, SizeToInt(index), load);
1994 }
1995 (void)loads.emplace_back(std::move(load));
1996 }
1997 return loads;
1998 }
1999
MakeLoad(const AnfNodePtr & node,const AnfNodePtr & ref,const AnfNodePtr & u)2000 CNodePtr MakeLoad(const AnfNodePtr &node, const AnfNodePtr &ref, const AnfNodePtr &u) {
2001 static const std::string primitive_target = "primitive_target";
2002 // Create Load cnode.
2003 auto load_prim = NewValueNode(prim::kPrimLoad);
2004 auto load_cnode = func_graph_->NewCNode({load_prim, ref, u});
2005 // Set device target for Load CNode.
2006 std::string target = GetCNodeTarget(node);
2007 load_cnode->set_user_data(primitive_target, std::make_shared<std::string>(target));
2008 // Set load_cnode abstract to Tensor according the input Ref[Tensor].
2009 auto ref_abs = dyn_cast<abstract::AbstractRefTensor>(ref->abstract());
2010 MS_EXCEPTION_IF_NULL(ref_abs);
2011 load_cnode->set_abstract(ref_abs->CloneAsTensor());
2012 return load_cnode;
2013 }
2014
2015 // Add or replace monad input.
AddMonadInput(const CNodePtr & cnode,const AnfNodePtr & monad)2016 void AddMonadInput(const CNodePtr &cnode, const AnfNodePtr &monad) {
2017 MS_EXCEPTION_IF_NULL(cnode);
2018 constexpr size_t max_monad_inputs = 2;
2019 auto monad_abs = monad->abstract();
2020 int last = static_cast<int>(cnode->size()) - 1;
2021 int stop = last - max_monad_inputs;
2022 // Search monad in inputs, replace it if found.
2023 for (int i = last; i > 0 && i > stop; --i) {
2024 size_t index = static_cast<size_t>(i);
2025 auto input_abs = cnode->input(index)->abstract();
2026 if (input_abs && *input_abs == *monad_abs) {
2027 manager_->SetEdge(cnode, i, monad);
2028 return;
2029 }
2030 }
2031 // If monad not found in inputs, add a monad input.
2032 manager_->AddEdge(cnode, monad);
2033 }
2034
AttachMonadToOutput() const2035 void AttachMonadToOutput() const {
2036 if (u_) {
2037 AttachToOutput(u_);
2038 }
2039 if (io_) {
2040 AttachToOutput(io_);
2041 }
2042 }
2043
AttachToOutput(const AnfNodePtr & node) const2044 void AttachToOutput(const AnfNodePtr &node) const {
2045 auto output = GetGraphOutput();
2046 TraceGuard guard(std::make_shared<TraceCopy>(output->debug_info()));
2047 auto depend = NewValueNode(prim::kPrimDepend);
2048 // If isolated nodes dependencies exist.
2049 if (IsPrimitiveCNode(output, prim::kPrimDepend) &&
2050 IsPrimitiveCNode(output->cast<CNodePtr>()->input(kDependAttachNodeIndex), prim::kPrimStopGradient)) {
2051 // Insert new Depend node before isolated Depend node.
2052 auto isolated_depend = output->cast<CNodePtr>();
2053 auto &orig_output = isolated_depend->input(1);
2054 auto state_depend = func_graph_->NewCNode({depend, orig_output, node});
2055 state_depend->set_abstract(orig_output->abstract());
2056 manager_->SetEdge(isolated_depend, 1, state_depend);
2057 return;
2058 }
2059 // Insert Depend node and set it as output, if no isolated nodes.
2060 auto depend_cnode = func_graph_->NewCNode({depend, output, node});
2061 depend_cnode->set_abstract(output->abstract());
2062 func_graph_->set_output(depend_cnode);
2063 }
2064
GetGraphOutput() const2065 AnfNodePtr GetGraphOutput() const {
2066 auto output = func_graph_->output();
2067 if (output != nullptr) {
2068 return output;
2069 }
2070 return NewValueNode(kNone);
2071 }
2072
// Creates UpdateState(state, attach) and returns it as the new state node;
// the new node takes the abstract of 'state'.
// 'attach' must be a cnode. If it carries a true bool kAttrIgnoreSideEffect
// attribute, no UpdateState is created and 'state' is returned unchanged.
AnfNodePtr UpdateState(const AnfNodePtr &state, const AnfNodePtr &attach) {
  MS_EXCEPTION_IF_NULL(attach);
  auto attach_cnode = attach->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(attach_cnode);
  // Do not attach UpdateState if kAttrIgnoreSideEffect is set.
  const auto ignore_attr = attach_cnode->GetAttr(kAttrIgnoreSideEffect);
  if (ignore_attr != nullptr && ignore_attr->isa<BoolImm>() && GetValue<bool>(ignore_attr)) {
    return state;
  }

  auto update_state_cnode = func_graph_->NewCNode({NewValueNode(prim::kPrimUpdateState), state, attach});
  update_state_cnode->set_abstract(state->abstract());
  return update_state_cnode;
}
2090
GetUniverse()2091 AnfNodePtr &GetUniverse() {
2092 if (u_ == nullptr) {
2093 if (top_) {
2094 u_ = NewValueNode(kUMonad);
2095 u_->set_abstract(kUMonad->ToAbstract());
2096 } else {
2097 u_ = AddMonadParameter(func_graph_, "u", kUMonad->ToAbstract());
2098 }
2099 }
2100 return u_;
2101 }
2102
GetIoState()2103 AnfNodePtr &GetIoState() {
2104 if (io_ == nullptr) {
2105 if (top_) {
2106 io_ = NewValueNode(kIOMonad);
2107 io_->set_abstract(kIOMonad->ToAbstract());
2108 } else {
2109 io_ = AddMonadParameter(func_graph_, "io", kIOMonad->ToAbstract());
2110 }
2111 }
2112 return io_;
2113 }
2114
2115 // Return true if update_state should be used in this func graph.
2116 // In some case, update_state can be omitted, such as:
2117 // def side_effect_tail_call(args):
2118 // a = pure_func(args)
2119 // return side_effect_call(a)
NeedUpdateState() const2120 bool NeedUpdateState() const {
2121 // Search for the only one side effect cnode.
2122 CNodePtr side_effect_cnode = nullptr;
2123 for (auto &weak_cnode : func_graph_->order_list()) {
2124 const auto &cnode = weak_cnode.lock();
2125 if (cnode == nullptr) {
2126 continue;
2127 }
2128 if (HasSideEffect(cnode)) {
2129 if (side_effect_cnode != nullptr) {
2130 // There are multiple side effect cnodes, update state is required.
2131 return true;
2132 }
2133 side_effect_cnode = cnode;
2134 }
2135 }
2136 if (side_effect_cnode == nullptr) {
2137 // No side effect cnode, no update state.
2138 return false;
2139 }
2140 if (IsPrimitiveCNode(side_effect_cnode)) {
2141 // Always add update_state for primitive cnode.
2142 return true;
2143 }
2144 // If the only side effect cnode is not the tail call, update_state is required.
2145 return func_graph_->output() != side_effect_cnode;
2146 }
2147
HasSideEffect(const CNodePtr & cnode) const2148 bool HasSideEffect(const CNodePtr &cnode) const {
2149 const auto &cnode_info = GetEffectInfoFromCNode(cnode);
2150 return (cnode_info.memory || cnode_info.load || cnode_info.io);
2151 }
2152
2153 // The func graph to be converted.
2154 const FuncGraphPtr &func_graph_;
2155
2156 // The func graph manager, used for graph edge update.
2157 FuncGraphManagerPtr manager_;
2158
2159 // True if converting top graph.
2160 const bool top_;
2161
2162 // True if there are side effect cnodes within this func graph.
2163 bool has_effect_cnodes_ = false;
2164
// CNodes that should not be eliminated even if they are isolated nodes.
2166 std::vector<CNodePtr> no_eliminate_nodes_;
2167
2168 // Current memory state node, null if no memory side effects.
2169 AnfNodePtr u_;
2170
2171 // Current IO state node, null if no IO side effects.
2172 AnfNodePtr io_;
2173 }; // class AutoMonadConverter
2174 } // namespace
2175
2176 // Entry point of the auto-monad phase,
2177 // the func_graph should be resolved and infer is done.
2178 // return true if side effect nodes found in func_graph.
AutoMonad(const FuncGraphPtr & func_graph)2179 bool AutoMonad(const FuncGraphPtr &func_graph) {
2180 MS_EXCEPTION_IF_NULL(func_graph);
2181 MS_EXCEPTION_IF_NULL(func_graph->manager());
2182
2183 // Search and mark side effects for the graph and sub-graphs.
2184 // this should be called before auto-monad starts.
2185 SideEffectFinder::Search(func_graph);
2186
2187 // Execute auto-monad conversion on top graph.
2188 bool has_effects = AutoMonadConverter::Handle(func_graph, true);
2189 // Convert used sub-graphs.
2190 auto fg_used_total = func_graph->func_graphs_used_total();
2191 for (auto &fg : fg_used_total) {
2192 MS_EXCEPTION_IF_NULL(fg);
2193 auto top_flag = fg->has_flag(mindspore::kFuncGraphFlagBackPropEntry);
2194 bool fg_has_effects = AutoMonadConverter::Handle(fg, top_flag);
2195 has_effects = has_effects || fg_has_effects;
2196 }
2197 return has_effects;
2198 }
2199
ReAutoMonad(const FuncGraphPtr & func_graph)2200 bool ReAutoMonad(const FuncGraphPtr &func_graph) {
2201 // AutoMonad for bprop network, only Monad for func graphs which back propogators have side effects.
2202 // Or AutoMonad for MultitypeFuncGraph which specialized in Renormalize other than the first Specialize pass.
2203 MS_EXCEPTION_IF_NULL(func_graph);
2204 bool need_auto_monad = false;
2205 std::vector<FuncGraphPtr> auto_monaded_fg;
2206 func_graph->EraseUnusedNodeInOrder();
2207 for (auto &fg : func_graph->func_graphs_used_total()) {
2208 MS_EXCEPTION_IF_NULL(fg);
2209 if (fg->has_flag(mindspore::kFuncGraphFlagReAutoMonad)) {
2210 auto_monaded_fg.push_back(fg);
2211 for (auto &used_fg : fg->func_graphs_used_total()) {
2212 MS_EXCEPTION_IF_NULL(used_fg);
2213 used_fg->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
2214 auto_monaded_fg.push_back(used_fg);
2215 }
2216 need_auto_monad = true;
2217 MS_LOG(DEBUG) << "AutoMonad Grad for func graph: " << fg->ToString();
2218 }
2219 fg->EraseUnusedNodeInOrder();
2220 }
2221 bool changed = false;
2222 if (need_auto_monad) {
2223 for (auto &fg : func_graph->func_graphs_used_total()) {
2224 MS_EXCEPTION_IF_NULL(fg);
2225 if (!fg->has_flag(mindspore::kFuncGraphFlagReAutoMonad)) {
2226 fg->ClearOrderList();
2227 }
2228 }
2229 changed = AutoMonad(func_graph);
2230 for (auto &fg : auto_monaded_fg) {
2231 MS_EXCEPTION_IF_NULL(fg);
2232 fg->erase_flag(mindspore::kFuncGraphFlagReAutoMonad);
2233 }
2234 // After auto monad, Order List and Isolate nodes in graph and manager will be cleared.
2235 } else {
2236 func_graph->ClearOrderList();
2237 for (auto &fg : func_graph->func_graphs_used_total()) {
2238 fg->ClearOrderList();
2239 }
2240 }
2241 return changed;
2242 }
2243 } // namespace pipeline
2244 } // namespace mindspore
2245