1 /**
2 * Copyright 2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ADD_FORWARD_MONAD_DEPEND_H_
18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ADD_FORWARD_MONAD_DEPEND_H_
19
20 #include <vector>
21 #include "frontend/optimizer/optimizer.h"
22 #include "frontend/optimizer/ad/grad.h"
23 #include "ir/func_graph.h"
24
25 namespace mindspore {
26 namespace opt {
27 namespace irpass {
28 namespace internal {
GetBpropGetter(const FuncGraphManagerPtr & manager,const CNodePtr & node)29 AnfNodePtr GetBpropGetter(const FuncGraphManagerPtr &manager, const CNodePtr &node) {
30 const auto &user_nodes = manager->node_users()[node];
31 for (const auto &iter : user_nodes) {
32 if (IsPrimitiveCNode(iter.first, prim::kPrimTupleGetItem)) {
33 auto idx = GetValueNode<Int64ImmPtr>(iter.first->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem));
34 if (idx != nullptr && idx->value() == 1) {
35 return iter.first;
36 }
37 }
38 }
39 return nullptr;
40 }
41
// Returns the single node that calls the bprop closure produced by `bprop_getter`.
// Returns nullptr when `bprop_getter` has no users, whether it has no entry in the manager's
// node-users map or only an empty user set.
// Throws if there is more than one user, or if the user does not use `bprop_getter` as input 0
// (the callee position).
// Declared inline because it is defined in a header.
inline AnfNodePtr GetBpropUser(const FuncGraphManagerPtr &manager, const AnfNodePtr &bprop_getter) {
  MS_EXCEPTION_IF_NULL(manager);
  const auto &node_users = manager->node_users();
  auto iter = node_users.find(bprop_getter);
  // An empty user set means "no users", exactly like a missing entry; it is not an error.
  if (iter == node_users.end() || iter->second.empty()) {
    return nullptr;
  }
  if (iter->second.size() != 1) {
    MS_LOG(EXCEPTION) << "The number of bprop caller should be 1, but got " << iter->second.size()
                      << ", bprop_getter: " << bprop_getter->DebugString();
  }
  auto user_node_idx = iter->second.begin();
  if (user_node_idx->second != 0) {
    MS_LOG(EXCEPTION) << "The bprop_getter should be used in input 0, but got " << user_node_idx->second;
  }
  return user_node_idx->first;
}
59
IsMemSideEffectNode(const AnfNodePtr & node)60 bool IsMemSideEffectNode(const AnfNodePtr &node) {
61 auto prim = GetCNodePrimitive(node);
62 if (prim == nullptr) {
63 return false;
64 }
65 return prim->HasAttr(GRAPH_FLAG_SIDE_EFFECT_MEM);
66 }
67
AddUMonadInput(const FuncGraphManagerPtr & manager,const FuncGraphPtr & bprop_graph,const AnfNodePtr & new_u_para)68 void AddUMonadInput(const FuncGraphManagerPtr &manager, const FuncGraphPtr &bprop_graph, const AnfNodePtr &new_u_para) {
69 auto fprop_graph = bprop_graph->parent();
70 auto is_bprop_node = [&fprop_graph](const AnfNodePtr &node) {
71 if (node->func_graph() == fprop_graph) {
72 return EXCLUDE;
73 }
74 return FOLLOW;
75 };
76 auto all_nodes = TopoSort(bprop_graph->get_return(), SuccDeeperSimple, is_bprop_node);
77 for (const auto &node : all_nodes) {
78 if (!IsMemSideEffectNode(node)) {
79 continue;
80 }
81 MS_LOG(DEBUG) << "Add u monad input for node " << node->DebugString();
82 manager->AddEdge(node, new_u_para);
83 }
84 }
85
PropagateUMonadInput(const FuncGraphManagerPtr & manager,const FuncGraphPtr & bprop_graph,const AbstractBasePtr & u_abs,bool add_u_input)86 void PropagateUMonadInput(const FuncGraphManagerPtr &manager, const FuncGraphPtr &bprop_graph,
87 const AbstractBasePtr &u_abs, bool add_u_input) {
88 auto new_u_para = bprop_graph->add_parameter();
89 new_u_para->debug_info()->set_name("forward_u");
90 new_u_para->set_abstract(u_abs);
91 bprop_graph->set_flag(kFuncGraphFlagAddedForwardU, true);
92 if (add_u_input) {
93 AddUMonadInput(manager, bprop_graph, new_u_para);
94 }
95 std::vector<CNodePtr> side_effect_bprop_app_propagate_nodes;
96 for (const auto &node : bprop_graph->nodes()) {
97 auto cnode = dyn_cast<CNode>(node);
98 if (cnode == nullptr) {
99 continue;
100 }
101 if (cnode->HasAttr(kAttrSideEffectBpropAppPropagate) || cnode->HasAttr(kAttrSideEffectBpropApp)) {
102 (void)side_effect_bprop_app_propagate_nodes.emplace_back(cnode);
103 }
104 }
105 if (side_effect_bprop_app_propagate_nodes.empty()) {
106 return;
107 }
108
109 for (const auto &propagate_node : side_effect_bprop_app_propagate_nodes) {
110 manager->AddEdge(propagate_node, new_u_para);
111 auto bprop_getter_abs = dyn_cast<abstract::FuncGraphAbstractClosure>(propagate_node->input(0)->abstract());
112 if (bprop_getter_abs == nullptr) {
113 MS_LOG(INTERNAL_EXCEPTION) << "The node " << propagate_node->input(0)->DebugString()
114 << " should have a FuncGraphAbstractClosure abstract.";
115 }
116 auto bprop_fg = bprop_getter_abs->func_graph();
117 MS_EXCEPTION_IF_NULL(bprop_fg);
118 if (bprop_fg->has_flag(kFuncGraphFlagAddedForwardU)) {
119 continue;
120 }
121 PropagateUMonadInput(manager, bprop_fg, u_abs, propagate_node->HasAttr(kAttrSideEffectBpropApp));
122 }
123 }
124 } // namespace internal
125
126 // The original pattern:
127 // %0 = U
128 // %1 = call fprop(x, y, %0)
129 // %2 = get_item(%1, 1)
130 // %3 = %2[@@bprop](dout)
131 //
132 // graph bprop(dout):
133 // %0 = side_effect_mem_op(dout)
134 //
135 // After the pass:
136 // kLevelNone(no changes)
137 // %0 = U
138 // %1 = call fprop(x, y, %0)
139 // %2 = get_item(%1, 1)
140 // %3 = %2[@@bprop](dout)
141 //
142 // graph bprop(dout):
143 // %0 = side_effect_mem_op(dout)
144 //
145 // kLevelTop
146 // %0 = U
147 // %1 = call fprop(x, y, %0)
148 // %2 = get_item(%1, 1)
149 // %3 = %2[@@bprop](dout, %0)
150 //
151 // graph bprop(dout, u):
152 // %0 = side_effect_mem_op(dout, u)
153 //
154 // kLevelWhole
155 // %0 = U
156 // %1 = call fprop(x, y, %0)
157 // %2 = UpdateState(%0, %1)
158 // %3 = get_item(%1, 1)
159 // %4 = %3[@@bprop](dout, %2)
160 //
161 // graph bprop(dout, u):
162 // %0 = side_effect_mem_op(dout, u)
AddForwardMonadDepend(const FuncGraphPtr & root,const opt::OptimizerPtr & opt)163 bool AddForwardMonadDepend(const FuncGraphPtr &root, const opt::OptimizerPtr &opt) {
164 MS_EXCEPTION_IF_NULL(root);
165 MS_EXCEPTION_IF_NULL(opt);
166 auto manager = opt->manager();
167 MS_EXCEPTION_IF_NULL(manager);
168 std::vector<FuncGraphPtr> top_k_graphs;
169 for (const auto &fg : root->func_graphs_used_total()) {
170 MS_EXCEPTION_IF_NULL(fg);
171 if (fg->has_attr(kAttrBpropAutoMonadLevel) && fg->has_flag(kAttrSideEffectBpropAppPropagate)) {
172 (void)top_k_graphs.emplace_back(fg);
173 }
174 }
175
176 bool changed = false;
177 for (const auto &top_k_graph : top_k_graphs) {
178 auto bprop_auto_monad_level = GetValue<int>(top_k_graph->get_attr(kAttrBpropAutoMonadLevel));
179 top_k_graph->erase_flag(kAttrBpropAutoMonadLevel);
180 if (bprop_auto_monad_level == ad::BpropAutoMonadLevel::kLevelNone) {
181 break;
182 }
183 FuncGraphPtr bprop_graph = nullptr;
184 AbstractBasePtr u_abs = nullptr;
185 for (const auto &entry : top_k_graph->func_graph_cnodes_index()) {
186 auto k_graph_caller = entry.first->first->cast<CNodePtr>();
187 auto index = entry.first->second;
188 // Get the real graph caller.
189 if (index != 0) {
190 continue;
191 }
192 // Get the monad input.
193 auto umonad_input = k_graph_caller->input(k_graph_caller->size() - 1);
194 if (!HasAbstractUMonad(umonad_input)) {
195 continue;
196 }
197 // Only handle the fprop which has bprop getter.
198 auto bprop_getter = internal::GetBpropGetter(manager, k_graph_caller);
199 if (bprop_getter == nullptr) {
200 continue;
201 }
202 auto bprop_getter_abs = dyn_cast<abstract::FuncGraphAbstractClosure>(bprop_getter->abstract());
203 if (bprop_getter_abs == nullptr) {
204 MS_LOG(INTERNAL_EXCEPTION) << "The node " << bprop_getter->DebugString()
205 << " should have a FuncGraphAbstractClosure abstract.";
206 }
207 if (bprop_graph == nullptr) {
208 bprop_graph = bprop_getter_abs->func_graph();
209 } else if (bprop_getter_abs->func_graph() != bprop_graph) {
210 MS_LOG(INTERNAL_EXCEPTION) << "The bprop graphs are not same for the k graph: " << top_k_graph->ToString();
211 }
212 auto bprop_user = internal::GetBpropUser(manager, bprop_getter);
213 if (bprop_user == nullptr) {
214 continue;
215 }
216
217 auto update_state_to_depend = umonad_input;
218 if (bprop_auto_monad_level == ad::BpropAutoMonadLevel::kLevelWhole) {
219 std::vector<AnfNodePtr> new_update_state_inputs = {NewValueNode(prim::kPrimUpdateState), umonad_input,
220 k_graph_caller};
221 update_state_to_depend = k_graph_caller->func_graph()->NewCNodeInOrder(new_update_state_inputs);
222 update_state_to_depend->set_abstract(umonad_input->abstract());
223 }
224 manager->AddEdge(bprop_user, update_state_to_depend);
225 changed = true;
226 u_abs = umonad_input->abstract();
227 }
228 if (bprop_graph != nullptr && u_abs != nullptr) {
229 internal::PropagateUMonadInput(manager, bprop_graph, u_abs, false);
230 }
231 }
232 return changed;
233 }
234 } // namespace irpass
235 } // namespace opt
236 } // namespace mindspore
237 #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ADD_FORWARD_MONAD_DEPEND_H_
238