1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "pipeline/jit/static_analysis/auto_monad.h"
18 #include <set>
19 #include <map>
20 #include <list>
21 #include <unordered_map>
22 #include <vector>
23 #include <stack>
24 #include <utility>
25 #include <algorithm>
26 #include "pipeline/jit/parse/resolve.h"
27 #include "frontend/operator/ops.h"
28 #include "frontend/operator/composite/multitype_funcgraph.h"
29 #include "utils/flags.h"
30 #include "utils/utils.h"
31 #include "utils/ordered_map.h"
32 #include "base/core_ops.h"
33 #include "abstract/abstract_value.h"
34
35 namespace mindspore {
36 namespace pipeline {
37 namespace { // namespace anonymous
38 using ClassTypePtr = std::shared_ptr<parse::ClassType>;
39 using RefInputs = OrderedMap<AnfNodePtr, std::vector<size_t>>;
40
41 // Add or get a monad parameter.
AddMonadParameter(const FuncGraphPtr & func_graph,const std::string & name,const abstract::AbstractBasePtr & abs)42 AnfNodePtr AddMonadParameter(const FuncGraphPtr &func_graph, const std::string &name,
43 const abstract::AbstractBasePtr &abs) {
44 MS_EXCEPTION_IF_NULL(func_graph);
45 size_t params_size = func_graph->parameters().size();
46 size_t io_monad_location = params_size;
47 // Search for existed parameters, return it if found.
48 for (size_t i = 0; i < params_size; i++) {
49 auto &node = func_graph->parameters()[i];
50 auto para = dyn_cast<Parameter>(node);
51 if (para == nullptr) {
52 continue;
53 }
54 auto para_abs = para->abstract();
55 if (para_abs && *para_abs == *abs) {
56 return para;
57 }
58 if (HasAbstractIOMonad(para)) {
59 io_monad_location = i;
60 }
61 }
62 // Create a new parameter if not existed.
63 auto para = std::make_shared<Parameter>(func_graph);
64 para->set_name(name);
65 para->debug_info()->set_name(name);
66 para->set_abstract(abs);
67 // If io monad parameter added before u monad parameter, should insert u monad before io monad in parameters
68 if (io_monad_location != params_size && abs->isa<abstract::AbstractUMonad>()) {
69 std::vector<AnfNodePtr> params = func_graph->parameters();
70 (void)params.insert(params.begin() + SizeToInt(io_monad_location), para);
71 func_graph->set_parameters(params);
72 } else {
73 func_graph->add_parameter(para);
74 }
75 return para;
76 }
77
78 // Gets side effect propagate attribute value from a ClassType object.
GetSideEffectPropagate(const ClassTypePtr & class_type)79 int GetSideEffectPropagate(const ClassTypePtr &class_type) {
80 if (class_type) {
81 auto obj = class_type->obj();
82 if (py::hasattr(obj, GRAPH_FLAG_SIDE_EFFECT_PROPAGATE)) {
83 auto value = py::getattr(obj, GRAPH_FLAG_SIDE_EFFECT_PROPAGATE);
84 return value.cast<int>();
85 }
86 }
87 return 0;
88 }
89
90 // Gets 'side_effect_propagate' attribute value from a primitive.
GetSideEffectPropagate(const PrimitivePtr & prim)91 int GetSideEffectPropagate(const PrimitivePtr &prim) {
92 if (prim) {
93 auto attr = prim->GetAttr(GRAPH_FLAG_SIDE_EFFECT_PROPAGATE);
94 if (attr && attr->isa<Int64Imm>()) {
95 return static_cast<int>(attr->cast<Int64ImmPtr>()->value());
96 }
97 }
98 return 0;
99 }
100
101 // Return true if the node has Ref abstract.
HasAbstractRef(const AnfNodePtr & node)102 bool HasAbstractRef(const AnfNodePtr &node) {
103 if (node == nullptr) {
104 return false;
105 }
106 auto &abs = node->abstract();
107 return (abs != nullptr) && abs->isa<abstract::AbstractRef>();
108 }
109
110 // Gets ref inputs and its indexes from a cnode.
GetRefInputs(const CNodePtr & cnode)111 RefInputs GetRefInputs(const CNodePtr &cnode) {
112 RefInputs ref_inputs;
113 MS_EXCEPTION_IF_NULL(cnode);
114 for (size_t i = 1; i < cnode->size(); ++i) {
115 auto &input = cnode->inputs().at(i);
116 if (HasAbstractRef(input)) {
117 ref_inputs[input].push_back(i);
118 }
119 }
120 return ref_inputs;
121 }
122
123 // Return true if cnode has ref input.
HasRefInput(const CNodePtr & cnode)124 bool HasRefInput(const CNodePtr &cnode) {
125 if (cnode == nullptr || cnode->inputs().empty()) {
126 return false;
127 }
128 auto &inputs = cnode->inputs();
129 // Return true if any of arguments is ref.
130 return std::any_of(inputs.begin() + 1, inputs.end(), [](const auto &input) { return HasAbstractRef(input); });
131 }
132
133 // Return true if we don't need Load for the given primitive.
134 // i.e. keep Ref as Ref for some primitives.
IsKeepRef(const PrimitivePtr & prim)135 bool IsKeepRef(const PrimitivePtr &prim) {
136 return (GetSideEffectPropagate(prim) != 0) || IsPrimitiveEquals(prim, prim::kPrimRefToEmbed) ||
137 IsPrimitiveEquals(prim, prim::kPrimPull);
138 }
139
140 // Gets primitive if the node is a primitive value node.
GetPrimitive(const AnfNodePtr & node)141 PrimitivePtr GetPrimitive(const AnfNodePtr &node) {
142 PrimitivePtr prim = GetValueNode<PrimitivePtr>(node);
143 auto do_sig = dyn_cast<mindspore::prim::DoSignaturePrimitive>(prim);
144 if (do_sig) {
145 auto val = do_sig->function();
146 return dyn_cast<Primitive>(val);
147 }
148 return prim;
149 }
150
151 // Gets primitive from the given cnode, return nullptr if cnode.inputs[0] is not a primitive.
GetPrimitive(const CNodePtr & cnode)152 PrimitivePtr GetPrimitive(const CNodePtr &cnode) {
153 if (cnode == nullptr || cnode->inputs().empty()) {
154 return nullptr;
155 }
156 return GetPrimitive(cnode->input(0));
157 }
158
159 // Gets func_graph from the given cnode, return nullptr if it is not a func graph call.
GetFuncGraph(const CNodePtr & cnode)160 FuncGraphPtr GetFuncGraph(const CNodePtr &cnode) {
161 if (cnode != nullptr && !cnode->inputs().empty()) {
162 return GetValueNode<FuncGraphPtr>(cnode->input(0));
163 }
164 return nullptr;
165 }
166
167 // Gets class_type from the given cnode->inputs[0].
GetClassType(const CNodePtr & cnode)168 ClassTypePtr GetClassType(const CNodePtr &cnode) {
169 if (cnode && !cnode->inputs().empty()) {
170 auto apply = cnode->input(0);
171 auto apply_cnode = dyn_cast<CNode>(apply);
172 if (apply_cnode && !apply_cnode->inputs().empty()) {
173 return GetValueNode<ClassTypePtr>(apply_cnode->input(0));
174 }
175 }
176 return nullptr;
177 }
178
179 // Gets first input as cnode from the given cnode,
180 // return null if input[0] is not a cnode.
GetFuncCNode(const CNodePtr & cnode)181 CNodePtr GetFuncCNode(const CNodePtr &cnode) {
182 if (cnode != nullptr && !cnode->inputs().empty()) {
183 return dyn_cast<CNode>(cnode->input(0));
184 }
185 return nullptr;
186 }
187
188 // Gets first input as function parameter from the given cnode,
189 // return null if input[0] is not a parameter.
GetFuncParameter(const CNodePtr & cnode)190 ParameterPtr GetFuncParameter(const CNodePtr &cnode) {
191 if (cnode != nullptr && !cnode->inputs().empty()) {
192 return dyn_cast<Parameter>(cnode->input(0));
193 }
194 return nullptr;
195 }
196
197 // Gets first input as MultitypeFuncGraph from the given cnode,
198 // return null if input[0] is not a MultitypeFuncGraph.
GetFuncMultitypeFuncGraph(const CNodePtr & cnode)199 prim::MultitypeFuncGraphPtr GetFuncMultitypeFuncGraph(const CNodePtr &cnode) {
200 if (cnode != nullptr && !cnode->inputs().empty()) {
201 return GetValueNode<prim::MultitypeFuncGraphPtr>(cnode->input(0));
202 }
203 return nullptr;
204 }
205
206 // --------------------------------------------------------------------
207 // SCC (Strongly Connected Components) related types.
208 // --------------------------------------------------------------------
209 using SccVector = std::set<FuncGraphPtr>;
210 using SccPtr = std::shared_ptr<SccVector>;
211 using SccMap = std::unordered_map<FuncGraphPtr, SccPtr>;
212
213 // ---------------------------------------------------------------------
214 // SccFinder find SCCs using Tarjan's algorithm.
215 // ---------------------------------------------------------------------
216 class SccFinder {
217 public:
SccFinder(const FuncGraphPtr & root)218 explicit SccFinder(const FuncGraphPtr &root) : root_(root) {}
219 ~SccFinder() = default;
Run()220 void Run() { (void)Search(root_); }
scc_map() const221 const SccMap &scc_map() const { return scc_map_; }
222
223 private:
224 // Save state of a func graph.
225 struct State {
226 size_t index = 0;
227 size_t lowlink = 0;
228 bool in_stack = false;
Statemindspore::pipeline::__anon897cf14f0111::SccFinder::State229 explicit State(size_t index) : index(index), lowlink(index), in_stack(false) {}
230 ~State() = default;
231 };
232
233 // Search SCCs from the given graph.
Search(FuncGraphPtr graph)234 const State &Search(FuncGraphPtr graph) {
235 // Create graph state, set it as visited.
236 MS_EXCEPTION_IF_NULL(graph);
237 auto [inserted, ok] = visited_.emplace(graph, State(index_++));
238 if (!ok) {
239 MS_LOG(EXCEPTION) << "Already visited: " << graph->ToString();
240 }
241 auto &state = inserted->second;
242 // Push visited graph to stack.
243 stack_.push(graph);
244 state.in_stack = true;
245 // Search successor graphs.
246 for (auto &used : graph->func_graphs_used()) {
247 auto &sg = used.first;
248 auto iter = visited_.find(sg);
249 if (iter == visited_.end()) {
250 // Successor graph has not yet been visited, recurse on it.
251 auto &sg_state = Search(sg);
252 state.lowlink = std::min(state.lowlink, sg_state.lowlink);
253 } else if (iter->second.in_stack) {
254 // Successor graph is in stack and hence in the current SCC.
255 state.lowlink = std::min(state.lowlink, iter->second.index);
256 }
257 }
258 // If index == lowlink, this means it is the root of SCC.
259 if (state.index == state.lowlink) {
260 // Pop members of the SCC from stack, they are on top of its root.
261 auto scc = std::make_shared<SccVector>();
262 while (!stack_.empty()) {
263 auto g = stack_.top();
264 stack_.pop();
265 auto found = visited_.find(g);
266 if (found == visited_.end()) {
267 MS_LOG(EXCEPTION) << "Unexpected graph: " << g->ToString();
268 }
269 found->second.in_stack = false;
270 // Add graph to SCC, and create the map from graph to SCC.
271 scc->insert(g);
272 scc_map_.emplace(g, scc);
273 if (g == graph) {
274 break;
275 }
276 }
277 // SCC should not be empty.
278 if (scc->empty()) {
279 MS_LOG(EXCEPTION) << "Invalid SCC for: " << graph->ToString();
280 }
281 }
282 return state;
283 }
284
285 // The root graph.
286 FuncGraphPtr root_;
287
288 // Current index by DFS order.
289 size_t index_ = 1;
290
291 // Visited graphs and their states.
292 std::unordered_map<FuncGraphPtr, State> visited_;
293
294 // The stack for Tarjan algorithm.
295 std::stack<FuncGraphPtr> stack_;
296
297 // The result SCC map, from graph to its SCC.
298 SccMap scc_map_;
299 };
300
301 struct SwitchLayerCall {
302 CNodePtr caller;
303 EffectInfo effect_info;
304 std::vector<FuncGraphPtr> branches;
305 };
306
307 // -------------------------------------------------------------------------------
308 // SideEffectFinder search and mark side effects for graph and its sub-graphs.
309 // -------------------------------------------------------------------------------
310 class SideEffectFinder {
311 public:
Search(const FuncGraphPtr & root)312 static void Search(const FuncGraphPtr &root) {
313 SideEffectFinder finder(root);
314 finder.Run();
315 }
316
317 private:
SideEffectFinder(const FuncGraphPtr & root)318 explicit SideEffectFinder(const FuncGraphPtr &root) : root_(root) {}
319 ~SideEffectFinder() = default;
320
Run()321 void Run() {
322 // To handle recursive calls, we generate SCC map before search.
323 GenerateSccMap();
324 // Update order list to include outer cnodes.
325 UpdateOrderLists();
326 // Find side effects by DFS from the top graph.
327 (void)GetEffectInfo(root_);
328 // Check switch layer calls, add monad arguments if need.
329 HandleSwitchLayerCalls();
330 }
331
UpdateOrderLists() const332 void UpdateOrderLists() const {
333 // Some cnodes used in current func graph but belong to other func graph, we have to
334 // insert them into order list so that we can handle side effects for them.
335 UpdateOrderList(root_);
336 for (auto &fg : root_->func_graphs_used_total()) {
337 UpdateOrderList(fg);
338 }
339 }
340
UpdateOrderList(const FuncGraphPtr & func_graph)341 static void UpdateOrderList(const FuncGraphPtr &func_graph) {
342 MS_EXCEPTION_IF_NULL(func_graph);
343 OrderedSet<CNodePtr> new_order_list;
344 const auto &order_list = func_graph->order_list();
345 for (auto &cnode : order_list) {
346 PushToOrderList(func_graph, cnode, &new_order_list);
347 }
348 func_graph->set_order_list(std::move(new_order_list));
349 }
350
PushToOrderList(const FuncGraphPtr & fg,const CNodePtr & cnode,OrderedSet<CNodePtr> * new_order_list)351 static void PushToOrderList(const FuncGraphPtr &fg, const CNodePtr &cnode, OrderedSet<CNodePtr> *new_order_list) {
352 MS_EXCEPTION_IF_NULL(cnode);
353 MS_EXCEPTION_IF_NULL(new_order_list);
354 if (new_order_list->contains(cnode)) {
355 return;
356 }
357 for (auto &input : cnode->inputs()) {
358 auto input_cnode = dyn_cast<CNode>(input);
359 if (input_cnode != nullptr && input_cnode->func_graph() != fg) {
360 PushToOrderList(fg, input_cnode, new_order_list);
361 }
362 }
363 new_order_list->push_back(cnode);
364 }
365
366 // Generate SCC map by SccFinder.
GenerateSccMap()367 void GenerateSccMap() {
368 SccFinder scc_finder(root_);
369 scc_finder.Run();
370 scc_map_ = std::move(scc_finder.scc_map());
371 }
372
373 // Gets branch graph from a switch cnode at given input index.
GetSwitchBranch(const CNodePtr & cnode,size_t index)374 FuncGraphPtr GetSwitchBranch(const CNodePtr &cnode, size_t index) {
375 MS_EXCEPTION_IF_NULL(cnode);
376 return GetValueNode<FuncGraphPtr>(cnode->inputs().at(index));
377 }
378
379 // Gets branch graphs from a switch cnode.
GetSwitchBranches(const CNodePtr & cnode)380 std::vector<FuncGraphPtr> GetSwitchBranches(const CNodePtr &cnode) {
381 MS_EXCEPTION_IF_NULL(cnode);
382 constexpr size_t switch_cnode_size = 4;
383 constexpr size_t true_index = 2;
384 constexpr size_t false_index = 3;
385 // Check size.
386 if (cnode->size() != switch_cnode_size) {
387 MS_LOG(EXCEPTION) << "Invalid switch: " << cnode->DebugString();
388 }
389 // Add both branches, in some case, only one branch is set.
390 std::vector<FuncGraphPtr> branches;
391 auto true_branch = GetSwitchBranch(cnode, true_index);
392 if (true_branch != nullptr) {
393 branches.emplace_back(true_branch);
394 }
395 auto false_branch = GetSwitchBranch(cnode, false_index);
396 if (false_branch != nullptr) {
397 branches.emplace_back(false_branch);
398 }
399 if (branches.empty()) {
400 MS_LOG(EXCEPTION) << "Invalid switch: " << cnode->DebugString();
401 }
402 return branches;
403 }
404
405 // Add monad parameter to switch branch graphs.
AddMonadParameters(const std::vector<FuncGraphPtr> & branches,const std::string & name,const AbstractBasePtr & abs)406 void AddMonadParameters(const std::vector<FuncGraphPtr> &branches, const std::string &name,
407 const AbstractBasePtr &abs) {
408 for (auto &branch : branches) {
409 (void)AddMonadParameter(branch, name, abs);
410 }
411 }
412
413 // Trace effect info for Switch cnode.
TraceSwitchEffectInfo(const CNodePtr & cnode)414 EffectInfo TraceSwitchEffectInfo(const CNodePtr &cnode) {
415 // Find branches from switch cnode.
416 auto branches = GetSwitchBranches(cnode);
417 // For some case, only one branch is set.
418 if (branches.size() == 1) {
419 auto &branch = branches.front();
420 // Save branch caller, so that we can update arguments for the caller.
421 SaveBranchCaller(cnode, branch);
422 return GetEffectInfo(branch);
423 }
424 // When both branches are set, merge their effect infos.
425 EffectInfo info = MergeEffectInfo(branches);
426 if (info.state == EffectInfo::kDetected) {
427 // Setup both branches according the merged effect info.
428 SetupEffectBranches(info, branches);
429 }
430 return info;
431 }
432
433 // Trace effect info for SwitchLayer cnode.
TraceSwitchLayerEffectInfo(const CNodePtr & cnode)434 EffectInfo TraceSwitchLayerEffectInfo(const CNodePtr &cnode) {
435 // Find branches from switch_layer cnode.
436 auto branches = GetSwitchLayerBranches(cnode);
437 // Merge effect info from all branches.
438 EffectInfo info = MergeEffectInfo(branches);
439 if (info.state == EffectInfo::kDetected) {
440 // Setup branches according the merged effect info.
441 SetupEffectBranches(info, branches);
442 // Save the switch_layer call, so that we can add monad argument for it if need.
443 auto &call = switch_layer_calls_.emplace_back();
444 call.caller = caller_;
445 call.effect_info = info;
446 call.branches = move(branches);
447 }
448 return info;
449 }
450
HandleSwitchLayerCalls()451 void HandleSwitchLayerCalls() {
452 for (auto &call : switch_layer_calls_) {
453 const auto &info = call.effect_info;
454 const auto &branches = call.branches;
455 auto new_info = MergeEffectInfo(branches);
456 // Reset branches if effect info changed.
457 if (new_info.memory != info.memory || new_info.load != info.load || new_info.io != info.io) {
458 AddMonadForCaller(call.caller, new_info);
459 SetupEffectBranches(new_info, branches);
460 }
461 }
462 }
463
464 // Gets branch graphs from a switch_layer cnode.
GetSwitchLayerBranches(const CNodePtr & cnode)465 std::vector<FuncGraphPtr> GetSwitchLayerBranches(const CNodePtr &cnode) {
466 MS_EXCEPTION_IF_NULL(cnode);
467 constexpr size_t func_tuple_index = 2;
468 if (cnode->size() <= func_tuple_index) {
469 MS_LOG(EXCEPTION) << "Invalid switch_layer: " << cnode->DebugString(2);
470 }
471 auto func_tuple = cnode->inputs().at(func_tuple_index);
472 return GetGraphsFromTuple(func_tuple);
473 }
474
475 // Get and trace graphs from a tuple of func node for switch_layer.
GetGraphsFromTuple(const AnfNodePtr & func_tuple)476 std::vector<FuncGraphPtr> GetGraphsFromTuple(const AnfNodePtr &func_tuple) {
477 // The func tuple maker.
478 if (IsPrimitiveCNode(func_tuple, prim::kPrimMakeTuple)) {
479 return GetGraphsFromMakeTuple(func_tuple->cast<CNodePtr>());
480 }
481 // Trace tuple from parameter.
482 auto para = dyn_cast<Parameter>(func_tuple);
483 if (para != nullptr) {
484 std::vector<FuncGraphPtr> graphs;
485 ForEachRealArguments(para,
486 [this, &graphs](const AnfNodePtr &arg) { graphs = std::move(GetGraphsFromTuple(arg)); });
487 return graphs;
488 }
489 // Trace tuple returned from func graph call.
490 auto cnode = dyn_cast<CNode>(func_tuple);
491 auto func_graph = GetFuncGraph(cnode);
492 if (func_graph != nullptr) {
493 return GetGraphsFromTuple(func_graph->output());
494 }
495 MS_LOG(EXCEPTION) << "Invalid input for switch_layer: func_graph is nullptr.";
496 }
497
498 // Get graphs from a tuple of funcs make node for switch_layer.
GetGraphsFromMakeTuple(const CNodePtr & make_tuple)499 std::vector<FuncGraphPtr> GetGraphsFromMakeTuple(const CNodePtr &make_tuple) {
500 MS_EXCEPTION_IF_NULL(make_tuple);
501 auto &inputs = make_tuple->inputs();
502 if (inputs.size() <= 1) {
503 MS_LOG(EXCEPTION) << "Invalid make_tuple for switch_layer: " << make_tuple->DebugString(2);
504 }
505 std::vector<FuncGraphPtr> graphs;
506 graphs.reserve(inputs.size() - 1);
507 for (size_t i = 1; i < inputs.size(); ++i) {
508 auto func_graph = GetValueNode<FuncGraphPtr>(inputs.at(i));
509 if (func_graph == nullptr) {
510 MS_LOG(WARNING) << "Non-graph found in switch_layer input: " << make_tuple->DebugString(2) << " index=" << i;
511 continue;
512 }
513 graphs.push_back(func_graph);
514 }
515 return graphs;
516 }
517
518 // Trace effect info from tuple_getitem cnode.
TraceTupleGetItemEffectInfo(const CNodePtr & cnode,std::stack<int64_t> * tuple_indexes)519 EffectInfo TraceTupleGetItemEffectInfo(const CNodePtr &cnode, std::stack<int64_t> *tuple_indexes) {
520 constexpr size_t tuple_input = 1;
521 constexpr size_t index_input = 2;
522 constexpr size_t cnode_size = 3;
523 if (cnode->size() != cnode_size) {
524 MS_LOG(EXCEPTION) << "Invalid tuple_getitem: " << cnode->DebugString();
525 }
526 // Get item index.
527 auto &index_node = cnode->inputs().at(index_input);
528 auto index_value = GetValueNode<Int64ImmPtr>(index_node);
529 if (index_value == nullptr) {
530 MS_LOG(EXCEPTION) << "Tuple_getitem with non-const index " << cnode->DebugString();
531 }
532 int64_t index = index_value->value();
533
534 // Get tuple value.
535 const auto &tuple_node = cnode->inputs().at(tuple_input);
536 // Push tuple index.
537 tuple_indexes->push(index);
538 return TraceTupleEffectInfo(tuple_node, tuple_indexes);
539 }
540
TraceTupleEffectInfo(const AnfNodePtr & tuple_node,std::stack<int64_t> * tuple_indexes)541 EffectInfo TraceTupleEffectInfo(const AnfNodePtr &tuple_node, std::stack<int64_t> *tuple_indexes) {
542 MS_EXCEPTION_IF_NULL(tuple_indexes);
543 auto para = dyn_cast<Parameter>(tuple_node);
544 if (para != nullptr) {
545 return TraceTupleParaEffectInfo(para, *tuple_indexes);
546 }
547 auto tuple_cnode = dyn_cast<CNode>(tuple_node);
548 if (tuple_cnode != nullptr) {
549 return TraceTupleCNodeEffectInfo(tuple_cnode, tuple_indexes);
550 }
551 // Should not reach here.
552 MS_LOG(EXCEPTION) << "Side effects untraceable: tuple_cnode is nullptr.";
553 }
554
TraceTupleParaEffectInfo(const ParameterPtr & para,const std::stack<int64_t> & tuple_indexes)555 EffectInfo TraceTupleParaEffectInfo(const ParameterPtr ¶, const std::stack<int64_t> &tuple_indexes) {
556 EffectInfo info{EffectInfo::kDetected, false, false, false};
557 ForEachRealArguments(para, [this, &info, tuple_indexes](const AnfNodePtr &arg) {
558 // Merge real argument effect info.
559 auto tuple_indexes_copy = tuple_indexes;
560 auto arg_info = TraceTupleEffectInfo(arg, &tuple_indexes_copy);
561 info.Merge(arg_info);
562 });
563 return info;
564 }
565
TraceTupleCNodeEffectInfo(const CNodePtr & cnode,std::stack<int64_t> * tuple_indexes)566 EffectInfo TraceTupleCNodeEffectInfo(const CNodePtr &cnode, std::stack<int64_t> *tuple_indexes) {
567 MS_EXCEPTION_IF_NULL(tuple_indexes);
568 MS_EXCEPTION_IF_NULL(cnode);
569 auto prim = GetPrimitive(cnode);
570 // Trace MakeTuple.
571 if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple)) {
572 if (tuple_indexes->empty()) {
573 MS_LOG(EXCEPTION) << "Unexpected make_tuple: " << cnode->DebugString(2);
574 return {EffectInfo::kDetected, false, false, false};
575 }
576 // Pop out tuple index.
577 auto top_index = tuple_indexes->top();
578 tuple_indexes->pop();
579 size_t input_index = 0;
580 // Support tuple index is negative
581 if (top_index < 0) {
582 if (SizeToLong(cnode->size()) + top_index < 0) {
583 MS_LOG(EXCEPTION) << "Invalid make_tuple: " << cnode->DebugString() << " index=" << top_index;
584 }
585 input_index = static_cast<size_t>(cnode->size() + top_index);
586 } else {
587 // Follow the tuple item according the index.
588 input_index = static_cast<size_t>(top_index) + 1;
589 }
590 if (input_index >= cnode->size()) {
591 MS_LOG(EXCEPTION) << "Invalid make_tuple: " << cnode->DebugString() << " index=" << top_index;
592 }
593 if (tuple_indexes->empty()) {
594 // Trace non-tuple.
595 return TraceEffectInfo(cnode->inputs().at(input_index));
596 }
597 // This is the tuple of tuple case.
598 return TraceTupleEffectInfo(cnode->inputs().at(input_index), tuple_indexes);
599 }
600 // Trace TupleGetItem (tuple of tuple).
601 if (IsPrimitiveEquals(prim, prim::kPrimTupleGetItem)) {
602 return TraceTupleGetItemEffectInfo(cnode, tuple_indexes);
603 }
604 // Trace primitive propagating side effect from its input, such as Depend, Identity, etc.
605 int input_index = GetSideEffectPropagate(prim);
606 if (input_index > 0 && input_index < static_cast<int>(cnode->size())) {
607 return TraceTupleEffectInfo(cnode->input(static_cast<size_t>(input_index)), tuple_indexes);
608 }
609 // Tuple returned from func graph call.
610 auto func_graph = GetFuncGraph(cnode);
611 if (func_graph != nullptr) {
612 return TraceTupleEffectInfo(func_graph->output(), tuple_indexes);
613 }
614 // Tuple returned from a Switch call.
615 if (cnode->size() == 1 && IsPrimitiveCNode(cnode->input(0), prim::kPrimSwitch)) {
616 return TraceTupleFromSwitch(cnode->input(0)->cast<CNodePtr>(), *tuple_indexes);
617 }
618 // Tuple is returned from J().
619 // %1 = J(primal)
620 // tuple = %1(args)
621 if (cnode->size() > 0 && IsPrimitiveCNode(cnode->input(0), prim::kPrimJ)) {
622 MS_LOG(DEBUG) << "Tuple from J: " << cnode->DebugString(2);
623 return {EffectInfo::kDetected, false, false, false};
624 }
625 // Rare case.
626 MS_LOG(WARNING) << "Tuple untraceable from: " << cnode->DebugString(2);
627 return {EffectInfo::kDetected, false, false, false};
628 }
629
630 // Trace effect info from a Switch node that output is a tuple.
TraceTupleFromSwitch(const CNodePtr & switch_cnode,const std::stack<int64_t> & tuple_indexes)631 EffectInfo TraceTupleFromSwitch(const CNodePtr &switch_cnode, const std::stack<int64_t> &tuple_indexes) {
632 auto branches = GetSwitchBranches(switch_cnode);
633 EffectInfo info = {EffectInfo::kDetected, false, false, false};
634 for (auto &branch : branches) {
635 auto tuple_indexes_copy = tuple_indexes;
636 EffectInfo branch_info = TraceTupleEffectInfo(branch->output(), &tuple_indexes_copy);
637 info.Merge(branch_info);
638 }
639 return info;
640 }
641
642 // Setup all branches according the effect info.
SetupEffectBranches(const EffectInfo & info,const std::vector<FuncGraphPtr> & branches)643 void SetupEffectBranches(const EffectInfo &info, const std::vector<FuncGraphPtr> &branches) {
644 // Setup monad parameters for all branches according the effect info.
645 if (info.memory || info.load) {
646 AddMonadParameters(branches, "u", kUMonad->ToAbstract());
647 }
648 if (info.io) {
649 AddMonadParameters(branches, "io", kIOMonad->ToAbstract());
650 }
651 // Set merged effect info to both branches.
652 for (auto &branch : branches) {
653 MS_EXCEPTION_IF_NULL(branch);
654 branch->SetEffectInfo(info);
655 // Update caller if it is existed.
656 UpdateBranchCaller(branch);
657 }
658 }
659
660 // Merge effect info for switch or switch_layer branch graphs.
MergeEffectInfo(const std::vector<FuncGraphPtr> & branches)661 EffectInfo MergeEffectInfo(const std::vector<FuncGraphPtr> &branches) {
662 EffectInfo info = {EffectInfo::kDetected, false, false, false};
663 for (auto &branch : branches) {
664 MS_EXCEPTION_IF_NULL(branch);
665 EffectInfo branch_info = GetEffectInfo(branch);
666 info.Merge(branch_info);
667 }
668 return info;
669 }
670
671 // Trace a cnode for effect info.
TraceEffectInfo(const CNodePtr & cnode)672 EffectInfo TraceEffectInfo(const CNodePtr &cnode) {
673 MS_EXCEPTION_IF_NULL(cnode);
674 auto prim = GetPrimitive(cnode);
675 if (IsPrimitiveEquals(prim, prim::kPrimSwitch)) {
676 // Special handling for Switch primitive.
677 return TraceSwitchEffectInfo(cnode);
678 }
679
680 if (IsPrimitiveEquals(prim, prim::kPrimSwitchLayer)) {
681 // Special handling for SwitchLayer primitive.
682 return TraceSwitchLayerEffectInfo(cnode);
683 }
684
685 if (IsPrimitiveEquals(prim, prim::kPrimTupleGetItem)) {
686 // Trace tuple_getitem.
687 std::stack<int64_t> tuple_indexes;
688 return TraceTupleGetItemEffectInfo(cnode, &tuple_indexes);
689 }
690
691 // For high-order pritimive such as Partial,
692 // we trace effect info from its argument.
693 int index_prim = GetSideEffectPropagate(prim);
694 if (index_prim > 0 && index_prim < static_cast<int>(cnode->size())) {
695 return TraceEffectInfo(cnode->input(static_cast<size_t>(index_prim)));
696 }
697
698 // For func graph calls, we trace effect info from graph output.
699 auto called_graph = GetFuncGraph(cnode);
700 if (called_graph) {
701 return TraceEffectInfo(called_graph->output());
702 }
703
704 //
705 // For ClassType as the input[0], if it is a primitive class
706 // with 'side_effect_propagate' attribute, we trace side effect
707 // from its argument indxed by the attribute value.
708 //
709 // e.g.:
710 // setpara = P.Partial()(P.Assign, self.para)
711 // setpara(x)
712 //
713 auto class_type = GetClassType(cnode);
714 if (class_type) {
715 int index = GetSideEffectPropagate(class_type);
716 if (index > 0 && index < static_cast<int>(cnode->size())) {
717 return TraceEffectInfo(cnode->input(static_cast<size_t>(index)));
718 }
719 }
720
721 // Otherwise, no side effect found and stop trace.
722 return {EffectInfo::kDetected, false, false, false};
723 }
724
725 // Trace an ANFNode for effect info.
TraceEffectInfo(const AnfNodePtr & node)726 EffectInfo TraceEffectInfo(const AnfNodePtr &node) {
727 if (node) {
728 // Trace cnode.
729 auto cnode = node->cast<CNodePtr>();
730 if (cnode) {
731 return TraceEffectInfo(cnode);
732 }
733
734 // Trace parameter.
735 auto para = node->cast<ParameterPtr>();
736 if (para) {
737 return TraceEffectInfo(para);
738 }
739
740 // Trace primitive.
741 auto prim = GetPrimitive(node);
742 if (prim) {
743 return GetPrimEffectInfo(prim);
744 }
745
746 // Trace func graph.
747 auto value_node = node->cast<ValueNodePtr>();
748 if (value_node && value_node->value()) {
749 auto graph = value_node->value()->cast<FuncGraphPtr>();
750 if (graph) {
751 return GetEffectInfo(graph);
752 }
753 }
754 }
755 // Something is wrong if we reached here.
756 MS_LOG(WARNING) << "EffectInfo untraceable: node is a nullptr.";
757 return {EffectInfo::kDetected, false, false, false};
758 }
759
GetParameterIndex(const FuncGraphPtr & func_graph,const ParameterPtr & para)760 int GetParameterIndex(const FuncGraphPtr &func_graph, const ParameterPtr ¶) {
761 int parameter_index = 0;
762 for (auto ¶meter : func_graph->parameters()) {
763 if (para == parameter) {
764 return parameter_index;
765 }
766 ++parameter_index;
767 }
768 MS_LOG(EXCEPTION) << "Parameter not found: " << (para ? para->DebugString() : "<null>");
769 }
770
771 // Trace effect info from function parameter.
TraceEffectInfo(const ParameterPtr & para)772 EffectInfo TraceEffectInfo(const ParameterPtr ¶) {
773 EffectInfo info{EffectInfo::kDetected, false, false, false};
774 ForEachRealArguments(para, [this, &info](const AnfNodePtr &arg) {
775 // Merge caller input effect info.
776 auto input_info = TraceEffectInfo(arg);
777 info.Merge(input_info);
778 });
779 return info;
780 }
781
ForEachRealArguments(const ParameterPtr & para,const std::function<void (const AnfNodePtr &)> & handler)782 void ForEachRealArguments(const ParameterPtr ¶, const std::function<void(const AnfNodePtr &)> &handler) {
783 MS_EXCEPTION_IF_NULL(para);
784 auto func_graph = para->func_graph();
785 MS_EXCEPTION_IF_NULL(func_graph);
786 // Find index of the parameter, starts from 0.
787 const int para_index = GetParameterIndex(func_graph, para);
788 const size_t input_index = static_cast<size_t>(para_index) + 1;
789 // Search user cnodes of the func graph.
790 auto &users = func_graph->func_graph_cnodes_index();
791 if (users.empty()) {
792 MS_LOG(WARNING) << "Unused graph for parameter " << para->DebugString();
793 }
794 for (auto &user : users) {
795 auto use_index = user.first->second;
796 if (use_index != 0) {
797 // Skip non-caller usage.
798 continue;
799 }
800 // Caller cnode.
801 auto cnode = dyn_cast<CNode>(user.first->first);
802 MS_EXCEPTION_IF_NULL(cnode);
803 if (cnode && input_index < cnode->size()) {
804 auto &real_arg = cnode->input(input_index);
805 if (real_arg == para) {
806 // Skip if the real argument is the given parameter.
807 continue;
808 }
809 handler(real_arg);
810 }
811 }
812 }
813
814 // For call node, returns effect info of the callee graph.
GetCallEffectInfo(const CNodePtr & cnode)815 EffectInfo GetCallEffectInfo(const CNodePtr &cnode) {
816 MS_EXCEPTION_IF_NULL(cnode);
817 constexpr size_t min_call_node_size = 2;
818 if (cnode->size() < min_call_node_size) {
819 MS_LOG(EXCEPTION) << "Invalid call node: " << cnode->DebugString();
820 }
821 auto func_graph = GetValueNode<FuncGraphPtr>(cnode->inputs().at(1));
822 if (func_graph == nullptr) {
823 MS_LOG(EXCEPTION) << "Invalid call node: " << cnode->DebugString();
824 }
825 return GetEffectInfo(func_graph);
826 }
827
828 // Detect effect info by depth first search.
DetectEffectInfo(const CNodePtr & cnode)829 EffectInfo DetectEffectInfo(const CNodePtr &cnode) {
830 // For primitive, get effect info from its attributes and inputs.
831 auto prim = GetPrimitive(cnode);
832 if (prim) {
833 // Skip 'return' cnode.
834 if (IsPrimitiveEquals(prim, prim::kPrimReturn)) {
835 return {EffectInfo::kDetected, false, false, false};
836 }
837 // Special handling for 'call' cnode.
838 if (IsPrimitiveEquals(prim, prim::kPrimCall)) {
839 return GetCallEffectInfo(cnode);
840 }
841 auto info = GetPrimEffectInfo(prim);
842 if (!info.memory && !IsKeepRef(prim)) {
843 // For primitive calls, if no memory effects but
844 // Ref parameter used, we will insert 'load' before them.
845 // Except for primitives like J(f) or Partial(f, x) which propagate side effect,
846 // load is inserted inside the func_graph f.
847 info.load = HasRefInput(cnode);
848 }
849 return info;
850 }
851
852 // For func graph, detect effect info by its children cnodes.
853 auto func_graph = GetFuncGraph(cnode);
854 if (func_graph) {
855 return GetEffectInfo(func_graph);
856 }
857
858 // When input[0] is a cnode, it is a function returned from
859 // a high-order function call, we trace it by return value.
860 auto func_cnode = GetFuncCNode(cnode);
861 if (func_cnode) {
862 caller_ = cnode;
863 return TraceEffectInfo(func_cnode);
864 }
865
866 // When input[0] is a parameter, it is a function parameter for
867 // the high-order function, we trace it by caller.
868 auto func_para = GetFuncParameter(cnode);
869 if (func_para) {
870 return TraceEffectInfo(func_para);
871 }
872
873 // When input[0] is a MultitypeFuncGraph, it's not specialized
874 // as one of its parameters is AbstractUndertermined,
875 // This MultitypeFuncGraph may be specialized at next Renormalize
876 // process, but we have to keep the order by insert UMonad now,
877 // otherwise order will be lost in next Renormalize.
878 // So assume it has memory side effect conservatively.
879 auto func_multitype = GetFuncMultitypeFuncGraph(cnode);
880 if (func_multitype) {
881 MS_LOG(DEBUG) << "Assume memory side effect for: " << cnode->DebugString();
882 return {EffectInfo::kDetected, true, false, false};
883 }
884
885 MS_LOG(WARNING) << "Side effect undetectable: " << cnode->DebugString(2);
886 return {EffectInfo::kDetected, false, false, false};
887 }
888
889 // Gets EffectInfo for CNode.
GetEffectInfo(const CNodePtr & cnode)890 EffectInfo GetEffectInfo(const CNodePtr &cnode) {
891 const auto &effect_info = cnode->GetEffectInfo();
892 if (effect_info.state == EffectInfo::kDetected) {
893 // Effect info already detected, return it.
894 return effect_info;
895 }
896
897 // Detect effect info for the cnode.
898 EffectInfo info = DetectEffectInfo(cnode);
899 if (info.state == EffectInfo::kDetected) {
900 // Save detected info into cnode.
901 cnode->SetEffectInfo(info);
902 }
903 return info;
904 }
905
906 // Gets SCC that the given graph belongs to.
GetScc(const FuncGraphPtr & func_graph) const907 const SccPtr &GetScc(const FuncGraphPtr &func_graph) const {
908 auto found = scc_map_.find(func_graph);
909 if (found == scc_map_.end()) {
910 MS_LOG(EXCEPTION) << "SCC not found for " << (func_graph ? func_graph->ToString() : "FG(null)");
911 }
912 return found->second;
913 }
914
915 // Set effect info for all member graphs in the SCC.
SetSccEffectInfo(const SccPtr & scc,const EffectInfo & info) const916 void SetSccEffectInfo(const SccPtr &scc, const EffectInfo &info) const {
917 MS_EXCEPTION_IF_NULL(scc);
918 for (auto &g : *scc) {
919 MS_EXCEPTION_IF_NULL(g);
920 g->SetEffectInfo(info);
921 }
922 }
923
924 // Gets EffectInfo for func graph.
GetEffectInfo(const FuncGraphPtr & func_graph)925 EffectInfo GetEffectInfo(const FuncGraphPtr &func_graph) {
926 MS_EXCEPTION_IF_NULL(func_graph);
927 const auto &effect_info = func_graph->GetEffectInfo();
928 if (effect_info.state != EffectInfo::kUnknown) {
929 // Effect info already set, return it.
930 return effect_info;
931 }
932 // Get SCC that this graph belongs to.
933 auto &scc = GetScc(func_graph);
934 MS_EXCEPTION_IF_NULL(scc);
935 // To prevent SCC members be visited again, we set effect info
936 // to 'kDetecting' state before start to check cnodes.
937 EffectInfo info{EffectInfo::kDetecting, false, false, false};
938 SetSccEffectInfo(scc, info);
939 // Check side effects for all cnodes in the SCC.
940 std::vector<CNodePtr> undetected;
941 for (auto &g : *scc) {
942 MS_EXCEPTION_IF_NULL(g);
943 for (auto &cnode : g->order_list()) {
944 auto cnode_effect = GetEffectInfo(cnode);
945 if (cnode_effect.state != EffectInfo::kDetected) {
946 // For side effect undetected node, it could be a call to the SCC member graph,
947 // we will try to check side effect again after SCC side effect detected.
948 undetected.push_back(cnode);
949 }
950 // Merge effect info from the node.
951 info.Merge(cnode_effect);
952 }
953 // Make sure all sub-graphs is checked. since some sub-graphs may not directly called,
954 // for example: return ValueNode(sub_graph).
955 for (auto &sg : g->func_graphs_used()) {
956 (void)GetEffectInfo(sg.first);
957 }
958 }
959 // Update effect into for all members of the SCC.
960 info.state = EffectInfo::kDetected;
961 SetSccEffectInfo(scc, info);
962 // Check undetected cnodes again after side effect of the SCC is detected.
963 for (auto &cnode : undetected) {
964 MS_EXCEPTION_IF_NULL(cnode);
965 auto cnode_effect = GetEffectInfo(cnode);
966 // Side effect should be detected now.
967 if (cnode_effect.state != EffectInfo::kDetected) {
968 MS_LOG(EXCEPTION) << "Side effect is undectable: " << cnode->DebugString();
969 }
970 }
971 // graph which need PipelineSplit doesn't have effect.
972 if (func_graph->stage() != -1) {
973 info.memory = false;
974 info.load = false;
975 info.io = false;
976 }
977 return info;
978 }
979
SaveBranchCaller(const CNodePtr & switch_node,const FuncGraphPtr & branch)980 void SaveBranchCaller(const CNodePtr &switch_node, const FuncGraphPtr &branch) {
981 MS_EXCEPTION_IF_NULL(branch);
982 MS_EXCEPTION_IF_NULL(switch_node);
983 auto manager = branch->manager();
984 MS_EXCEPTION_IF_NULL(manager);
985 auto &node_users = manager->node_users();
986 auto found = node_users.find(switch_node);
987 if (found == node_users.end()) {
988 MS_LOG(WARNING) << "Caller not found for " << switch_node->DebugString();
989 return;
990 }
991 if (found->second.size() != 1) {
992 MS_LOG(WARNING) << "Wrong callers " << found->second.size() << " for " << switch_node->DebugString();
993 return;
994 }
995 auto &user = *found->second.begin();
996 auto cnode = dyn_cast<CNode>(user.first);
997 if (cnode != nullptr || user.second == 0) {
998 branch_caller_map.emplace(branch, cnode);
999 }
1000 }
1001
UpdateBranchCaller(const FuncGraphPtr & branch)1002 void UpdateBranchCaller(const FuncGraphPtr &branch) {
1003 MS_EXCEPTION_IF_NULL(branch);
1004 auto iter = branch_caller_map.find(branch);
1005 if (iter == branch_caller_map.end()) {
1006 return;
1007 }
1008 const auto &caller = iter->second;
1009 const auto &info = branch->GetEffectInfo();
1010 AddMonadForCaller(caller, info);
1011 }
1012
AddMonadForCaller(const CNodePtr & caller,const EffectInfo & info)1013 void AddMonadForCaller(const CNodePtr &caller, const EffectInfo &info) {
1014 if (info.memory || info.load) {
1015 // Add u monad argument to caller if need.
1016 AddMonadArgument(caller, kUMonad);
1017 }
1018 if (info.io) {
1019 // Add io monad argument to caller if need.
1020 AddMonadArgument(caller, kIOMonad);
1021 }
1022 }
1023
AddMonadArgument(const CNodePtr & cnode,const ValuePtr & monad)1024 void AddMonadArgument(const CNodePtr &cnode, const ValuePtr &monad) {
1025 MS_EXCEPTION_IF_NULL(cnode);
1026 MS_EXCEPTION_IF_NULL(monad);
1027 auto monad_abs = monad->ToAbstract();
1028 for (size_t i = 1; i < cnode->size(); ++i) {
1029 auto abs = cnode->inputs().at(i)->abstract();
1030 if (abs != nullptr && *abs == *monad_abs) {
1031 // Skip if monad argument already existed.
1032 return;
1033 }
1034 }
1035 // Add monad argument if not yet.
1036 auto monad_input = NewValueNode(monad);
1037 monad_input->set_abstract(monad_abs);
1038 if ((monad == kUMonad) && cnode->size() > 1 && HasAbstractIOMonad(cnode->inputs().back())) {
1039 // Insert u monad before io monad.
1040 size_t last_index = cnode->size() - 1;
1041 cnode->add_input(cnode->input(last_index));
1042 cnode->set_input(last_index, monad_input);
1043 } else {
1044 // Add monad as the last input.
1045 cnode->add_input(monad_input);
1046 }
1047 }
1048
1049 // The root graph.
1050 FuncGraphPtr root_;
1051
1052 // SCC map.
1053 SccMap scc_map_;
1054
1055 // Single branch (in switch) and its caller cnode.
1056 std::map<FuncGraphPtr, CNodePtr> branch_caller_map;
1057
1058 // Current high order func caller cnode.
1059 CNodePtr caller_ = nullptr;
1060
1061 // switch_layer_calls save all switch_layer calls, so that
1062 // we can check whether monad argument should be added for them.
1063 std::vector<SwitchLayerCall> switch_layer_calls_;
1064 }; // class SideEffectFinder
1065
1066 // --------------------------------------------------------------------
1067 // AutoMonadConverter converts side-effect cnodes into monad form.
1068 // --------------------------------------------------------------------
1069 class AutoMonadConverter {
1070 public:
Handle(const FuncGraphPtr & func_graph,bool top)1071 static bool Handle(const FuncGraphPtr &func_graph, bool top) {
1072 AutoMonadConverter converter(func_graph, top);
1073 return converter.Run();
1074 }
1075
1076 private:
AutoMonadConverter(const FuncGraphPtr & func_graph,bool top)1077 AutoMonadConverter(const FuncGraphPtr &func_graph, bool top)
1078 : func_graph_(func_graph), manager_(func_graph->manager()), top_(top) {}
1079
1080 ~AutoMonadConverter() = default;
1081
Run()1082 bool Run() {
1083 // Handle cnodes for side effects.
1084 const auto &info = func_graph_->GetEffectInfo();
1085 if (info.state == EffectInfo::kDetected) {
1086 HandleCNodes();
1087 }
1088
1089 // Safe to clear isolated nodes after handled side effect nodes.
1090 ClearIsolatedNodes();
1091
1092 // Clean up after conversion finished.
1093 func_graph_->ClearOrderList();
1094 return has_effect_cnodes_;
1095 }
1096
1097 // Check if there are side effects from effect info.
HasSideEffects(const EffectInfo & info)1098 static bool HasSideEffects(const EffectInfo &info) { return (info.memory || info.io || info.load); }
1099
1100 // Gets effect info for a cnode.
GetEffectInfo(const CNodePtr & cnode) const1101 const EffectInfo &GetEffectInfo(const CNodePtr &cnode) const {
1102 MS_EXCEPTION_IF_NULL(cnode);
1103 auto &effect_info = cnode->GetEffectInfo();
1104 if (effect_info.state != EffectInfo::kDetected) {
1105 // Effect info should have been set by SideEffectFinder.
1106 MS_LOG(EXCEPTION) << "Side effects not detected: " << cnode->DebugString();
1107 }
1108 return effect_info;
1109 }
1110
1111 // Handle CNodes for side effects.
HandleCNodes()1112 void HandleCNodes() {
1113 // Check whether UpdateState and Depend are required.
1114 bool update_state = NeedUpdateState();
1115
1116 // Check all cnodes in order list.
1117 for (auto &cnode : func_graph_->order_list()) {
1118 auto &info = GetEffectInfo(cnode);
1119 has_effect_cnodes_ = (has_effect_cnodes_ || HasSideEffects(info));
1120 if (cnode->func_graph() != func_graph_) {
1121 // Handle outer cnode.
1122 HandleOuterNode(cnode, info);
1123 } else {
1124 // Handle cnode with memory side effects.
1125 if (info.memory) {
1126 HandleMemoryEffects(cnode, update_state);
1127 } else if (info.load) {
1128 // If no memory side effects, handle load if need.
1129 HandleLoad(cnode, update_state);
1130 }
1131 // Handle cnode with IO side effects.
1132 if (info.io) {
1133 HandleIoEffects(cnode, update_state);
1134 }
1135 // If the node has no side effects but 'no_eliminate' flag is set,
1136 // we save it to no_eliminate_nodes and handle them late.
1137 if (!info.memory && !info.io && IsNoEliminateNode(cnode)) {
1138 no_eliminate_nodes_.emplace_back(cnode);
1139 }
1140 }
1141 cnode->SetEffectHandled(true);
1142 }
1143 // Attach no eliminate nodes to output.
1144 HandleNoEliminateNodes();
1145 // Attach monad to output if required.
1146 if (update_state) {
1147 AttachMonadToOutput();
1148 }
1149 }
1150
1151 // Return true if the given cnode is primitive cnode with 'no_eliminate' flag.
IsNoEliminateNode(const CNodePtr & cnode)1152 bool IsNoEliminateNode(const CNodePtr &cnode) {
1153 if (cnode == nullptr || cnode->size() == 0) {
1154 return false;
1155 }
1156 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
1157 if (prim == nullptr) {
1158 return false;
1159 }
1160 return GetPrimitiveFlag(prim, ATTR_NO_ELIMINATE);
1161 }
1162
1163 // Attach no eliminate nodes to output.
HandleNoEliminateNodes()1164 void HandleNoEliminateNodes() {
1165 if (no_eliminate_nodes_.empty()) {
1166 // Skip if no nodes to be handled.
1167 return;
1168 }
1169 // If only one node, attach it to output directly.
1170 if (no_eliminate_nodes_.size() == 1) {
1171 AttachToOutput(no_eliminate_nodes_.front());
1172 return;
1173 }
1174 // For multiple nodes, attach them to output by a tuple.
1175 std::vector<AnfNodePtr> tuple_inputs;
1176 AbstractBasePtrList element_abstracts;
1177 tuple_inputs.reserve(no_eliminate_nodes_.size() + 1);
1178 element_abstracts.reserve(no_eliminate_nodes_.size());
1179 tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
1180 for (auto &node : no_eliminate_nodes_) {
1181 tuple_inputs.emplace_back(node);
1182 element_abstracts.emplace_back(node->abstract());
1183 }
1184 auto make_tuple_node = func_graph_->NewCNode(tuple_inputs);
1185 make_tuple_node->set_abstract(std::make_shared<abstract::AbstractTuple>(element_abstracts));
1186 AttachToOutput(make_tuple_node);
1187 }
1188
1189 // Clean no side effect dependency nodes.
1190 // From: output = Depend(output, StopGrad)
1191 // return output
1192 //
1193 // To: return output
ClearIsolatedNodes() const1194 void ClearIsolatedNodes() const {
1195 auto output = GetGraphOutput();
1196 if (IsPrimitiveCNode(output, prim::kPrimDepend) &&
1197 IsPrimitiveCNode(output->cast<CNodePtr>()->input(2), prim::kPrimStopGradient)) {
1198 // Replace Depend(orig_output, StopGrad) node with orig_output.
1199 // After that, nodes may be eliminated if have no side effects.
1200 auto &orig_output = output->cast<CNodePtr>()->input(1);
1201 func_graph_->set_output(orig_output);
1202 }
1203 }
1204
HandleOuterNode(const CNodePtr & cnode,const EffectInfo & info)1205 void HandleOuterNode(const CNodePtr &cnode, const EffectInfo &info) {
1206 MS_EXCEPTION_IF_NULL(cnode);
1207 if (info.memory || info.load) {
1208 (void)GetUniverse();
1209 bool load_with_primitive = (info.load && IsPrimitiveCNode(cnode));
1210 if (!cnode->IsEffectHandled() && !load_with_primitive) {
1211 auto u_node = NewValueNode(kUMonad);
1212 u_node->set_abstract(kUMonad->ToAbstract());
1213 cnode->add_input(u_node);
1214 }
1215 }
1216 if (info.io) {
1217 (void)GetIoState();
1218 if (!cnode->IsEffectHandled()) {
1219 auto io = NewValueNode(kIOMonad);
1220 io->set_abstract(kIOMonad->ToAbstract());
1221 cnode->add_input(io);
1222 }
1223 }
1224 }
1225
1226 //
1227 // Convert cnode with memory side effect to monad form,
1228 // from:
1229 // output = func(input)
1230 // to:
1231 // output = func(input, u)
1232 // u = UpdateState(u, output) # if update_state is true
1233 //
HandleMemoryEffects(const CNodePtr & cnode,bool update_state)1234 void HandleMemoryEffects(const CNodePtr &cnode, bool update_state) {
1235 const auto &u = GetUniverse();
1236 AddMonadInput(cnode, u);
1237 if (update_state) {
1238 u_ = UpdateState(u, cnode);
1239 }
1240 }
1241
1242 //
1243 // Convert cnode with io side effect to monad form,
1244 // from:
1245 // output = func(input)
1246 // to:
1247 // output = func(input, io)
1248 // io = UpdateState(io, output) # if update_state is true
1249 //
HandleIoEffects(const CNodePtr & cnode,bool update_state)1250 void HandleIoEffects(const CNodePtr &cnode, bool update_state) {
1251 const auto &io = GetIoState();
1252 AddMonadInput(cnode, io);
1253 if (update_state) {
1254 io_ = UpdateState(io, cnode);
1255 }
1256 }
1257
HandleLoad(const CNodePtr & cnode,bool update_state)1258 void HandleLoad(const CNodePtr &cnode, bool update_state) {
1259 MS_EXCEPTION_IF_NULL(cnode);
1260 auto value = GetValueNode(cnode->input(0));
1261 if (value && value->isa<Primitive>()) {
1262 // For primitive calls that use Ref as input, insert Loads before them.
1263 InsertLoads(cnode, update_state);
1264 } else {
1265 // For non-primitive calls, load is used inside the callee,
1266 // We do not insert load for it but handle it as a side
1267 // effects cnode.
1268 HandleMemoryEffects(cnode, update_state);
1269 }
1270 }
1271
1272 //
1273 // Insert Loads for a primitive cnode that use Ref as input.
1274 // for example, from:
1275 // out = Prim(self.para1, self.para2, other_args)
1276 // to:
1277 // p1 = Load(self.para1, u)
1278 // p2 = Load(self.para2, u)
1279 // t = make_tuple(p1, p2) # if update_state
1280 // u1 = UpdateState(u, t) # is required
1281 // out = Prim(p1, p2, other_args)
1282 //
InsertLoads(const CNodePtr & cnode,bool update_state)1283 void InsertLoads(const CNodePtr &cnode, bool update_state) {
1284 // Find ref inputs.
1285 auto ref_inputs = GetRefInputs(cnode);
1286 if (ref_inputs.empty()) {
1287 MS_LOG(WARNING) << "Ref input not found for load insertion: " << cnode->DebugString();
1288 return;
1289 }
1290 // Current u monad.
1291 auto current_u = GetUniverse();
1292 // Create Load cnodes.
1293 auto loads = MakeLoads(cnode, ref_inputs, current_u);
1294 if (loads.empty() || !update_state) {
1295 // Skip UpdateState insertion.
1296 return;
1297 }
1298 // Insert UpdateState if required.
1299 if (loads.size() == 1) {
1300 // One Load, no make_tuple needed.
1301 u_ = UpdateState(current_u, loads.front());
1302 return;
1303 }
1304 // Multiple Loads, Create a MakeTuple before UpdateState.
1305 abstract::AbstractBasePtrList load_abstracts;
1306 std::transform(loads.begin(), loads.end(), std::back_inserter(load_abstracts),
1307 [](const AnfNodePtr &load) { return load->abstract(); });
1308 loads.insert(loads.begin(), NewValueNode(prim::kPrimMakeTuple));
1309 auto make_tuple = func_graph_->NewCNode(loads);
1310 make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(load_abstracts));
1311 u_ = UpdateState(current_u, make_tuple);
1312 }
1313
MakeLoads(const CNodePtr & cnode,const RefInputs & ref_inputs,const AnfNodePtr & u)1314 std::vector<AnfNodePtr> MakeLoads(const CNodePtr &cnode, const RefInputs &ref_inputs, const AnfNodePtr &u) {
1315 std::vector<AnfNodePtr> loads;
1316 for (auto &ref_input : ref_inputs) {
1317 // Make a Load cnode for ref input.
1318 auto &ref = ref_input.first;
1319 auto load = MakeLoad(cnode, ref, u);
1320 // Replace input with the load cnode.
1321 for (size_t index : ref_input.second) {
1322 manager_->SetEdge(cnode, index, load);
1323 }
1324 loads.emplace_back(std::move(load));
1325 }
1326 return loads;
1327 }
1328
MakeLoad(const CNodePtr & cnode,const AnfNodePtr & ref,const AnfNodePtr & u)1329 CNodePtr MakeLoad(const CNodePtr &cnode, const AnfNodePtr &ref, const AnfNodePtr &u) {
1330 static const std::string primitive_target = "primitive_target";
1331 // Create Load cnode.
1332 auto load_prim = NewValueNode(prim::kPrimLoad);
1333 auto load_cnode = func_graph_->NewCNode({load_prim, ref, u});
1334 // Set device target for Load CNode.
1335 std::string target = GetCNodeTarget(cnode);
1336 load_cnode->set_user_data(primitive_target, std::make_shared<std::string>(target));
1337 // Set load_cnode abstract to Tensor according the input Ref[Tensor].
1338 auto ref_abs = dyn_cast<abstract::AbstractRef>(ref->abstract());
1339 MS_EXCEPTION_IF_NULL(ref_abs);
1340 load_cnode->set_abstract(ref_abs->CloneAsTensor());
1341 return load_cnode;
1342 }
1343
1344 // Add or replace monad input.
AddMonadInput(const CNodePtr & cnode,const AnfNodePtr & monad)1345 void AddMonadInput(const CNodePtr &cnode, const AnfNodePtr &monad) {
1346 MS_EXCEPTION_IF_NULL(cnode);
1347 constexpr size_t max_monad_inputs = 2;
1348 auto monad_abs = monad->abstract();
1349 auto &inputs = cnode->inputs();
1350 int last = static_cast<int>(inputs.size()) - 1;
1351 int stop = last - max_monad_inputs;
1352 // Search monad in inputs, replace it if found.
1353 for (int i = last; i > 0 && i > stop; --i) {
1354 size_t index = static_cast<size_t>(i);
1355 auto input_abs = inputs[index]->abstract();
1356 if (input_abs && *input_abs == *monad_abs) {
1357 manager_->SetEdge(cnode, i, monad);
1358 return;
1359 }
1360 }
1361 // If monad not found in inputs, add a monad input.
1362 manager_->AddEdge(cnode, monad);
1363 }
1364
AttachMonadToOutput() const1365 void AttachMonadToOutput() const {
1366 if (u_) {
1367 AttachToOutput(u_);
1368 }
1369 if (io_) {
1370 AttachToOutput(io_);
1371 }
1372 }
1373
AttachToOutput(const AnfNodePtr & node) const1374 void AttachToOutput(const AnfNodePtr &node) const {
1375 auto output = GetGraphOutput();
1376 auto depend = NewValueNode(prim::kPrimDepend);
1377 // If isolated nodes dependencies exist.
1378 if (IsPrimitiveCNode(output, prim::kPrimDepend) &&
1379 IsPrimitiveCNode(output->cast<CNodePtr>()->input(kDependAttachNodeIndex), prim::kPrimStopGradient)) {
1380 // Insert new Depend node before isolated Depend node.
1381 auto isolated_depend = output->cast<CNodePtr>();
1382 auto &orig_output = isolated_depend->input(1);
1383 auto state_depend = func_graph_->NewCNode({depend, orig_output, node});
1384 state_depend->set_abstract(orig_output->abstract());
1385 manager_->SetEdge(isolated_depend, 1, state_depend);
1386 return;
1387 }
1388 // Insert Depend node and set it as output, if no isolated nodes.
1389 auto depend_cnode = func_graph_->NewCNode({depend, output, node});
1390 depend_cnode->set_abstract(output->abstract());
1391 func_graph_->set_output(depend_cnode);
1392 }
1393
GetGraphOutput() const1394 AnfNodePtr GetGraphOutput() const {
1395 auto output = func_graph_->output();
1396 if (output != nullptr) {
1397 return output;
1398 }
1399 return NewValueNode(kNone);
1400 }
1401
UpdateState(const AnfNodePtr & state,const AnfNodePtr & attach)1402 AnfNodePtr UpdateState(const AnfNodePtr &state, const AnfNodePtr &attach) {
1403 MS_EXCEPTION_IF_NULL(attach);
1404 // Not attach UpdateState if set kAttrIgnoreSideEffect.
1405 auto attr_ignore_side_effect = attach->cast<CNodePtr>()->GetAttr(kAttrIgnoreSideEffect);
1406 auto ignore_side_effect = attr_ignore_side_effect != nullptr && attr_ignore_side_effect->isa<BoolImm>() &&
1407 GetValue<bool>(attr_ignore_side_effect);
1408 if (ignore_side_effect) {
1409 return state;
1410 }
1411
1412 auto update_state = NewValueNode(prim::kPrimUpdateState);
1413 auto update_state_cnode = func_graph_->NewCNode({update_state, state, attach});
1414 update_state_cnode->set_abstract(state->abstract());
1415 return update_state_cnode;
1416 }
1417
GetUniverse()1418 AnfNodePtr &GetUniverse() {
1419 if (u_ == nullptr) {
1420 if (top_) {
1421 u_ = NewValueNode(kUMonad);
1422 u_->set_abstract(kUMonad->ToAbstract());
1423 } else {
1424 u_ = AddMonadParameter(func_graph_, "u", kUMonad->ToAbstract());
1425 }
1426 }
1427 return u_;
1428 }
1429
GetIoState()1430 AnfNodePtr &GetIoState() {
1431 if (io_ == nullptr) {
1432 if (top_) {
1433 io_ = NewValueNode(kIOMonad);
1434 io_->set_abstract(kIOMonad->ToAbstract());
1435 } else {
1436 io_ = AddMonadParameter(func_graph_, "io", kIOMonad->ToAbstract());
1437 }
1438 }
1439 return io_;
1440 }
1441
1442 // Return true if update_state should be used in this func graph.
1443 // In some case, update_state can be omitted, such as:
1444 // def side_effect_tail_call(args):
1445 // a = pure_func(args)
1446 // return side_effect_call(a)
NeedUpdateState() const1447 bool NeedUpdateState() const {
1448 // Search for the only one side effect cnode.
1449 CNodePtr side_effect_cnode = nullptr;
1450 for (auto &cnode : func_graph_->order_list()) {
1451 if (HasSideEffect(cnode)) {
1452 if (side_effect_cnode != nullptr) {
1453 // There are multiple side effect cnodes, update state is required.
1454 return true;
1455 }
1456 side_effect_cnode = cnode;
1457 }
1458 }
1459 if (side_effect_cnode == nullptr) {
1460 // No side effect cnode, no update state.
1461 return false;
1462 }
1463 if (IsPrimitiveCNode(side_effect_cnode)) {
1464 // Always add update_state for primitive cnode.
1465 return true;
1466 }
1467 // If the only side effect cnode is not the tail call, update_state is required.
1468 return func_graph_->output() != side_effect_cnode;
1469 }
1470
HasSideEffect(const CNodePtr & cnode) const1471 bool HasSideEffect(const CNodePtr &cnode) const {
1472 const auto &cnode_info = GetEffectInfo(cnode);
1473 return (cnode_info.memory || cnode_info.load || cnode_info.io);
1474 }
1475
1476 // The func graph to be converted.
1477 const FuncGraphPtr &func_graph_;
1478
1479 // The func graph manager, used for graph edge update.
1480 FuncGraphManagerPtr manager_;
1481
1482 // True if converting top graph.
1483 const bool top_;
1484
1485 // True if there are side effect cnodes within this func graph.
1486 bool has_effect_cnodes_ = false;
1487
1488 // CNodes that should not be eliminated even it is isolated node.
1489 std::vector<CNodePtr> no_eliminate_nodes_;
1490
1491 // Current memory state node, null if no memory side effects.
1492 AnfNodePtr u_;
1493
1494 // Current IO state node, null if no IO side effects.
1495 AnfNodePtr io_;
1496 }; // class AutoMonadConverter
1497 } // namespace
1498
1499 // Entry point of the auto-monad phase,
1500 // the func_graph should be resolved and infer is done.
1501 // return true if side effect nodes found in func_graph.
AutoMonad(const FuncGraphPtr & func_graph)1502 bool AutoMonad(const FuncGraphPtr &func_graph) {
1503 MS_EXCEPTION_IF_NULL(func_graph);
1504 MS_EXCEPTION_IF_NULL(func_graph->manager());
1505
1506 // Search and mark side effects for the graph and sub-graphs.
1507 // this should be called before auto-monad starts.
1508 SideEffectFinder::Search(func_graph);
1509
1510 // Execute auto-monad conversion on top graph.
1511 bool has_effects = AutoMonadConverter::Handle(func_graph, true);
1512 // Convert used sub-graphs.
1513 auto fg_used_total = func_graph->func_graphs_used_total();
1514 for (auto &fg : fg_used_total) {
1515 auto top_flag = fg->has_flag(mindspore::kFuncGraphFlagBackPropEntry);
1516 if (fg->stage() != -1) {
1517 top_flag = true;
1518 }
1519 bool fg_has_effects = AutoMonadConverter::Handle(fg, top_flag);
1520 has_effects = has_effects || fg_has_effects;
1521 }
1522 return has_effects;
1523 }
1524
ReAutoMonad(const FuncGraphPtr & func_graph)1525 bool ReAutoMonad(const FuncGraphPtr &func_graph) {
1526 // AutoMonad for bprop network, only Monad for func graphs which back propogators have side effects.
1527 // Or AutoMonad for MultitypeFuncGraph which specialized in Renormalize other than the first Specialize pass.
1528 MS_EXCEPTION_IF_NULL(func_graph);
1529 bool need_auto_monad = false;
1530 std::vector<FuncGraphPtr> auto_monaded_fg;
1531 func_graph->EraseUnusedNodeInOrder();
1532 for (auto &fg : func_graph->func_graphs_used_total()) {
1533 if (fg->has_flag(mindspore::kFuncGraphFlagReAutoMonad)) {
1534 auto_monaded_fg.push_back(fg);
1535 for (auto &used_fg : fg->func_graphs_used_total()) {
1536 used_fg->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
1537 auto_monaded_fg.push_back(used_fg);
1538 }
1539 need_auto_monad = true;
1540 MS_LOG(DEBUG) << "AutoMonad Grad for func graph: " << fg->ToString();
1541 }
1542 fg->EraseUnusedNodeInOrder();
1543 }
1544 bool changed = false;
1545 if (need_auto_monad) {
1546 for (auto &fg : func_graph->func_graphs_used_total()) {
1547 if (!fg->has_flag(mindspore::kFuncGraphFlagReAutoMonad)) {
1548 fg->ClearOrderList();
1549 }
1550 }
1551 changed = AutoMonad(func_graph);
1552 for (auto &fg : auto_monaded_fg) {
1553 fg->erase_flag(mindspore::kFuncGraphFlagReAutoMonad);
1554 }
1555 // After auto monad, Order List and Isolate nodes in graph and manager will be cleared.
1556 } else {
1557 func_graph->ClearOrderList();
1558 for (auto &fg : func_graph->func_graphs_used_total()) {
1559 fg->ClearOrderList();
1560 }
1561 }
1562 return changed;
1563 }
1564 } // namespace pipeline
1565 } // namespace mindspore
1566