• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ADD_FORWARD_MONAD_DEPEND_H_
18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ADD_FORWARD_MONAD_DEPEND_H_
19 
20 #include <vector>
21 #include "frontend/optimizer/optimizer.h"
22 #include "frontend/optimizer/ad/grad.h"
23 #include "ir/func_graph.h"
24 
25 namespace mindspore {
26 namespace opt {
27 namespace irpass {
28 namespace internal {
GetBpropGetter(const FuncGraphManagerPtr & manager,const CNodePtr & node)29 AnfNodePtr GetBpropGetter(const FuncGraphManagerPtr &manager, const CNodePtr &node) {
30   const auto &user_nodes = manager->node_users()[node];
31   for (const auto &iter : user_nodes) {
32     if (IsPrimitiveCNode(iter.first, prim::kPrimTupleGetItem)) {
33       auto idx = GetValueNode<Int64ImmPtr>(iter.first->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem));
34       if (idx != nullptr && idx->value() == 1) {
35         return iter.first;
36       }
37     }
38   }
39   return nullptr;
40 }
41 
// Returns the node that calls `bprop_getter`, i.e. its single user, which must consume it
// as input 0.
//
// @param manager      Manager providing node-user information; must not be null.
// @param bprop_getter Node whose caller is looked up.
// @return The calling node, or nullptr when `bprop_getter` has no entry in the
//         node-users map.
// Raises (MS_LOG(EXCEPTION)) when the number of users is not exactly one, or when the
// sole user consumes `bprop_getter` at an input other than 0.
AnfNodePtr GetBpropUser(const FuncGraphManagerPtr &manager, const AnfNodePtr &bprop_getter) {
  MS_EXCEPTION_IF_NULL(manager);
  const auto &all_users = manager->node_users();
  const auto found = all_users.find(bprop_getter);
  if (found == all_users.end()) {
    return nullptr;
  }

  const auto &users = found->second;
  const auto user_count = users.size();
  if (user_count != 1) {
    MS_LOG(EXCEPTION) << "The number of bprop caller should be 1, but got " << user_count
                      << ", bprop_getter: " << bprop_getter->DebugString();
  }

  // Exactly one (user, input index) pair remains.
  const auto &only_user = *users.begin();
  if (only_user.second != 0) {
    MS_LOG(EXCEPTION) << "The bprop_getter should be used in input 0, but got " << only_user.second;
  }
  return only_user.first;
}
59 
IsMemSideEffectNode(const AnfNodePtr & node)60 bool IsMemSideEffectNode(const AnfNodePtr &node) {
61   auto prim = GetCNodePrimitive(node);
62   if (prim == nullptr) {
63     return false;
64   }
65   return prim->HasAttr(GRAPH_FLAG_SIDE_EFFECT_MEM);
66 }
67 
AddUMonadInput(const FuncGraphManagerPtr & manager,const FuncGraphPtr & bprop_graph,const AnfNodePtr & new_u_para)68 void AddUMonadInput(const FuncGraphManagerPtr &manager, const FuncGraphPtr &bprop_graph, const AnfNodePtr &new_u_para) {
69   auto fprop_graph = bprop_graph->parent();
70   auto is_bprop_node = [&fprop_graph](const AnfNodePtr &node) {
71     if (node->func_graph() == fprop_graph) {
72       return EXCLUDE;
73     }
74     return FOLLOW;
75   };
76   auto all_nodes = TopoSort(bprop_graph->get_return(), SuccDeeperSimple, is_bprop_node);
77   for (const auto &node : all_nodes) {
78     if (!IsMemSideEffectNode(node)) {
79       continue;
80     }
81     MS_LOG(DEBUG) << "Add u monad input for node " << node->DebugString();
82     manager->AddEdge(node, new_u_para);
83   }
84 }
85 
PropagateUMonadInput(const FuncGraphManagerPtr & manager,const FuncGraphPtr & bprop_graph,const AbstractBasePtr & u_abs,bool add_u_input)86 void PropagateUMonadInput(const FuncGraphManagerPtr &manager, const FuncGraphPtr &bprop_graph,
87                           const AbstractBasePtr &u_abs, bool add_u_input) {
88   auto new_u_para = bprop_graph->add_parameter();
89   new_u_para->debug_info()->set_name("forward_u");
90   new_u_para->set_abstract(u_abs);
91   bprop_graph->set_flag(kFuncGraphFlagAddedForwardU, true);
92   if (add_u_input) {
93     AddUMonadInput(manager, bprop_graph, new_u_para);
94   }
95   std::vector<CNodePtr> side_effect_bprop_app_propagate_nodes;
96   for (const auto &node : bprop_graph->nodes()) {
97     auto cnode = dyn_cast<CNode>(node);
98     if (cnode == nullptr) {
99       continue;
100     }
101     if (cnode->HasAttr(kAttrSideEffectBpropAppPropagate) || cnode->HasAttr(kAttrSideEffectBpropApp)) {
102       (void)side_effect_bprop_app_propagate_nodes.emplace_back(cnode);
103     }
104   }
105   if (side_effect_bprop_app_propagate_nodes.empty()) {
106     return;
107   }
108 
109   for (const auto &propagate_node : side_effect_bprop_app_propagate_nodes) {
110     manager->AddEdge(propagate_node, new_u_para);
111     auto bprop_getter_abs = dyn_cast<abstract::FuncGraphAbstractClosure>(propagate_node->input(0)->abstract());
112     if (bprop_getter_abs == nullptr) {
113       MS_LOG(INTERNAL_EXCEPTION) << "The node " << propagate_node->input(0)->DebugString()
114                                  << " should have a FuncGraphAbstractClosure abstract.";
115     }
116     auto bprop_fg = bprop_getter_abs->func_graph();
117     MS_EXCEPTION_IF_NULL(bprop_fg);
118     if (bprop_fg->has_flag(kFuncGraphFlagAddedForwardU)) {
119       continue;
120     }
121     PropagateUMonadInput(manager, bprop_fg, u_abs, propagate_node->HasAttr(kAttrSideEffectBpropApp));
122   }
123 }
124 }  // namespace internal
125 
126 // The origin pattern:
127 // %0 = U
128 // %1 = call fprop(x, y, %0)
129 // %2 = get_item(%1, 1)
130 // %3 = %2[@@bprop](dout)
131 //
132 // graph bprop(dout):
133 //   %0 = side_effect_mem_op(dout)
134 //
135 // After the pass:
136 // kLevelNone(no changes)
137 // %0 = U
138 // %1 = call fprop(x, y, %0)
139 // %2 = get_item(%1, 1)
140 // %3 = %2[@@bprop](dout)
141 //
142 // graph bprop(dout):
143 //   %0 = side_effect_mem_op(dout)
144 //
145 // kLevelTop
146 // %0 = U
147 // %1 = call fprop(x, y, %0)
148 // %2 = get_item(%1, 1)
149 // %3 = %2[@@bprop](dout, %0)
150 //
151 // graph bprop(dout, u):
152 //   %0 = side_effect_mem_op(dout, u)
153 //
154 // kLevelWhole
155 // %0 = U
156 // %1 = call fprop(x, y, %0)
157 // %2 = UpdateState(U, %1)
158 // %3 = get_item(%1, 1)
159 // %4 = %3[@@bprop](dout, %2)
160 //
161 // graph bprop(dout, u):
162 //   %0 = side_effect_mem_op(dout, u)
AddForwardMonadDepend(const FuncGraphPtr & root,const opt::OptimizerPtr & opt)163 bool AddForwardMonadDepend(const FuncGraphPtr &root, const opt::OptimizerPtr &opt) {
164   MS_EXCEPTION_IF_NULL(root);
165   MS_EXCEPTION_IF_NULL(opt);
166   auto manager = opt->manager();
167   MS_EXCEPTION_IF_NULL(manager);
168   std::vector<FuncGraphPtr> top_k_graphs;
169   for (const auto &fg : root->func_graphs_used_total()) {
170     MS_EXCEPTION_IF_NULL(fg);
171     if (fg->has_attr(kAttrBpropAutoMonadLevel) && fg->has_flag(kAttrSideEffectBpropAppPropagate)) {
172       (void)top_k_graphs.emplace_back(fg);
173     }
174   }
175 
176   bool changed = false;
177   for (const auto &top_k_graph : top_k_graphs) {
178     auto bprop_auto_monad_level = GetValue<int>(top_k_graph->get_attr(kAttrBpropAutoMonadLevel));
179     top_k_graph->erase_flag(kAttrBpropAutoMonadLevel);
180     if (bprop_auto_monad_level == ad::BpropAutoMonadLevel::kLevelNone) {
181       break;
182     }
183     FuncGraphPtr bprop_graph = nullptr;
184     AbstractBasePtr u_abs = nullptr;
185     for (const auto &entry : top_k_graph->func_graph_cnodes_index()) {
186       auto k_graph_caller = entry.first->first->cast<CNodePtr>();
187       auto index = entry.first->second;
188       // Get the real graph caller.
189       if (index != 0) {
190         continue;
191       }
192       // Get the monad input.
193       auto umonad_input = k_graph_caller->input(k_graph_caller->size() - 1);
194       if (!HasAbstractUMonad(umonad_input)) {
195         continue;
196       }
197       // Only handle the fprop which has bprop getter.
198       auto bprop_getter = internal::GetBpropGetter(manager, k_graph_caller);
199       if (bprop_getter == nullptr) {
200         continue;
201       }
202       auto bprop_getter_abs = dyn_cast<abstract::FuncGraphAbstractClosure>(bprop_getter->abstract());
203       if (bprop_getter_abs == nullptr) {
204         MS_LOG(INTERNAL_EXCEPTION) << "The node " << bprop_getter->DebugString()
205                                    << " should have a FuncGraphAbstractClosure abstract.";
206       }
207       if (bprop_graph == nullptr) {
208         bprop_graph = bprop_getter_abs->func_graph();
209       } else if (bprop_getter_abs->func_graph() != bprop_graph) {
210         MS_LOG(INTERNAL_EXCEPTION) << "The bprop graphs are not same for the k graph: " << top_k_graph->ToString();
211       }
212       auto bprop_user = internal::GetBpropUser(manager, bprop_getter);
213       if (bprop_user == nullptr) {
214         continue;
215       }
216 
217       auto update_state_to_depend = umonad_input;
218       if (bprop_auto_monad_level == ad::BpropAutoMonadLevel::kLevelWhole) {
219         std::vector<AnfNodePtr> new_update_state_inputs = {NewValueNode(prim::kPrimUpdateState), umonad_input,
220                                                            k_graph_caller};
221         update_state_to_depend = k_graph_caller->func_graph()->NewCNodeInOrder(new_update_state_inputs);
222         update_state_to_depend->set_abstract(umonad_input->abstract());
223       }
224       manager->AddEdge(bprop_user, update_state_to_depend);
225       changed = true;
226       u_abs = umonad_input->abstract();
227     }
228     if (bprop_graph != nullptr && u_abs != nullptr) {
229       internal::PropagateUMonadInput(manager, bprop_graph, u_abs, false);
230     }
231   }
232   return changed;
233 }
234 }  // namespace irpass
235 }  // namespace opt
236 }  // namespace mindspore
237 #endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ADD_FORWARD_MONAD_DEPEND_H_
238