• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/optimizer/irpass/recompute.h"
18 #include <set>
19 #include <unordered_map>
20 #include "ops/array_ops.h"
21 
22 namespace mindspore {
23 namespace opt {
24 namespace irpass {
EnableCellReuse()25 bool EnableCellReuse() {
26   auto context = MsContext::GetInstance();
27   MS_EXCEPTION_IF_NULL(context);
28   const auto cell_reuse = context->CellReuseLevel() != CellReuseLevel::kNoCellReuse;
29   return cell_reuse;
30 }
31 
HasBpropGetter(const OptimizerPtr & opt,const AnfNodePtr & k_fg_caller)32 bool HasBpropGetter(const OptimizerPtr &opt, const AnfNodePtr &k_fg_caller) {
33   MS_EXCEPTION_IF_NULL(opt);
34   auto manager = opt->manager();
35   MS_EXCEPTION_IF_NULL(manager);
36   const auto &node_users = manager->node_users();
37   auto iter = node_users.find(k_fg_caller);
38   if (iter == node_users.end()) {
39     MS_LOG(EXCEPTION) << "The node " << k_fg_caller->DebugString() << " should have users.";
40   }
41 
42   return std::any_of(iter->second.begin(), iter->second.end(), [](const std::pair<AnfNodePtr, int> &node_and_idx) {
43     auto user = node_and_idx.first;
44     return IsPrimitiveCNode(user, prim::kPrimTupleGetItem) &&
45            common::AnfAlgo::GetTupleGetItemOutIndex(user->cast<CNodePtr>()) == 1;
46   });
47 }
48 
// Looks up the node that uses `bprop_getter`.
// Returns nullptr when `bprop_getter` has no entry in the manager's user map. Raises an exception when the manager is
// null, when the getter does not have exactly one user, or when it is not used as input 0 of that user.
AnfNodePtr GetBpropCaller(const FuncGraphManagerPtr &manager, const AnfNodePtr &bprop_getter) {
  MS_EXCEPTION_IF_NULL(manager);
  const auto &users_map = manager->node_users();
  const auto users_iter = users_map.find(bprop_getter);
  if (users_iter == users_map.end()) {
    return nullptr;
  }
  const auto &users = users_iter->second;
  if (users.size() != 1) {
    MS_LOG(EXCEPTION) << "The number of bprop caller should be 1, but got " << users.size()
                      << ", bprop_getter: " << bprop_getter->DebugString();
  }
  const auto only_user = users.begin();
  const auto used_at = only_user->second;
  if (used_at != 0) {
    MS_LOG(EXCEPTION) << "The bprop_getter should be used in input 0, but got " << used_at;
  }
  return only_user->first;
}
66 
67 namespace {
// Prefix of fullname_with_scope that WithGradientScope() looks for.
constexpr auto kGradientsFlag = "Gradients";
// Attribute name checked and set on k graph caller CNodes in this file.
constexpr auto kAttrReplacedWithPrimal = "replaced_with_primal";
// Attribute name set on the make_tuple nodes that AddNewPrimalNode() creates.
constexpr auto kAttrRecomputeMakeTuple = "recompute_make_tuple";
71 
WithRecomputedScope(const AnfNodePtr & node)72 bool WithRecomputedScope(const AnfNodePtr &node) {
73   MS_EXCEPTION_IF_NULL(node);
74   if (!node->isa<CNode>()) {
75     return false;
76   }
77   const auto &full_name_with_scope = node->fullname_with_scope();
78   return full_name_with_scope.compare(0, strlen(kAttrRecompute), kAttrRecompute) == 0;
79 }
80 
IsRecomputeKGraphCaller(const AnfNodePtr & node)81 bool IsRecomputeKGraphCaller(const AnfNodePtr &node) {
82   auto cnode = dyn_cast_ptr<CNode>(node);
83   if (cnode == nullptr) {
84     return false;
85   }
86   auto call_fg = GetValueNode<FuncGraphPtr>(cnode->input(0));
87   if (call_fg != nullptr && call_fg->has_flag(FUNC_GRAPH_RECOMPUTE_K_GRAPH)) {
88     return true;
89   }
90   return false;
91 }
92 
WithGradientScope(const AnfNodePtr & node)93 bool WithGradientScope(const AnfNodePtr &node) {
94   return node->fullname_with_scope().compare(0, strlen(kGradientsFlag), kGradientsFlag) == 0;
95 }
96 
IsFromBpropOutput(const AnfNodePtr & node)97 bool IsFromBpropOutput(const AnfNodePtr &node) {
98   if (!IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
99     return false;
100   }
101   auto cur_node = node;
102   while (IsPrimitiveCNode(cur_node, prim::kPrimTupleGetItem)) {
103     cur_node = cur_node->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem);
104   }
105   if (WithGradientScope(cur_node)) {
106     return true;
107   }
108   auto cur_cnode = cur_node->cast<CNodePtr>();
109   if (cur_cnode == nullptr) {
110     return false;
111   }
112   auto func_abs = dyn_cast<abstract::FuncGraphAbstractClosure>(cur_cnode->input(0)->abstract());
113   if (func_abs == nullptr) {
114     return false;
115   }
116   auto fg = func_abs->func_graph();
117   MS_EXCEPTION_IF_NULL(fg);
118   return fg->has_flag(FUNC_GRAPH_RECOMPUTE_GRAD_GRAPH);
119 }
120 
IsGradNode(const AnfNodePtr & node)121 bool IsGradNode(const AnfNodePtr &node) {
122   MS_EXCEPTION_IF_NULL(node);
123   return WithGradientScope(node) || IsFromBpropOutput(node);
124 }
125 
IsFpropReturn(const AnfNodePtr & make_tuple)126 bool IsFpropReturn(const AnfNodePtr &make_tuple) {
127   auto cnode = make_tuple->cast<CNodePtr>();
128   constexpr size_t fprop_output_size = 2;
129   if (cnode->size() != fprop_output_size + 1) {
130     return false;
131   }
132   return IsValueNode<FuncGraph>(cnode->input(fprop_output_size));
133 }
134 
GetPrimalFromFprop(const FuncGraphPtr & k_fg)135 AnfNodePtr GetPrimalFromFprop(const FuncGraphPtr &k_fg) {
136   if (!IsPrimitiveCNode(k_fg->output(), prim::kPrimMakeTuple)) {
137     return nullptr;
138   }
139   auto k_fg_outputs = k_fg->output()->cast<CNodePtr>()->inputs();
140   if (k_fg_outputs.size() != 3) {
141     return nullptr;
142   }
143   return k_fg_outputs[kIndex1];
144 }
145 
ShouldAddNewPrimalOutput(const AnfNodePtr & node,bool recompute_cell)146 bool ShouldAddNewPrimalOutput(const AnfNodePtr &node, bool recompute_cell) {
147   return !IsGradNode(node) || recompute_cell;
148 }
149 
IsForwardDepend(const AnfNodePtr & node)150 bool IsForwardDepend(const AnfNodePtr &node) {
151   return IsPrimitiveCNode(node, prim::kPrimDepend) && !node->cast_ptr<CNode>()->HasAttr(kRecomputeInsert);
152 }
153 
AddNewPrimalNode(const FuncGraphManagerPtr & manager,const FuncGraphPtr & fg,const AnfNodePtr & origin_primal,const AnfNodePtr & new_primal,bool recompute_cell,std::unordered_map<AnfNodePtr,AnfNodePtr> * origin_to_new_primal)154 bool AddNewPrimalNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, const AnfNodePtr &origin_primal,
155                       const AnfNodePtr &new_primal, bool recompute_cell,
156                       std::unordered_map<AnfNodePtr, AnfNodePtr> *origin_to_new_primal) {
157   bool changed = false;
158   auto node_users = manager->node_users()[origin_primal];
159   for (auto &node_and_idx : node_users) {
160     auto user = node_and_idx.first;
161     MS_EXCEPTION_IF_NULL(user);
162     // The forward part may have multiple outputs.
163     if (IsPrimitiveCNode(user, prim::kPrimTupleGetItem) && ShouldAddNewPrimalOutput(user, recompute_cell)) {
164       // Make new tuple_getitem to get corresponding output.
165       auto new_primal_getitem = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), new_primal,
166                                               user->cast_ptr<CNode>()->input(kInputNodeOutputIndexInTupleGetItem)});
167       changed =
168         AddNewPrimalNode(manager, fg, user, new_primal_getitem, recompute_cell, origin_to_new_primal) || changed;
169       continue;
170     }
171     if (IsForwardDepend(user) && ShouldAddNewPrimalOutput(user, recompute_cell)) {
172       // Make new depend node in forward to get corresponding output.
173       auto new_depend = fg->NewCNode(user->cast_ptr<CNode>()->inputs());
174       new_depend->set_input(IntToSize(node_and_idx.second), new_primal);
175       changed = AddNewPrimalNode(manager, fg, user, new_depend, recompute_cell, origin_to_new_primal) || changed;
176       continue;
177     }
178     // The op like concat will have a make_tuple input.
179     if (IsPrimitiveCNode(user, prim::kPrimMakeTuple) && !IsFpropReturn(user) &&
180         ShouldAddNewPrimalOutput(user, recompute_cell)) {
181       auto user_cnode = user->cast<CNodePtr>();
182       MS_EXCEPTION_IF_NULL(user_cnode);
183       if (user_cnode->HasAttr(kAttrRecomputeMakeTuple)) {
184         manager->SetEdge(user_cnode, node_and_idx.second, new_primal);
185         continue;
186       }
187       auto iter = origin_to_new_primal->find(user);
188       if (iter != origin_to_new_primal->end()) {
189         // The new make_tuple has been created, just set its inputs.
190         manager->SetEdge(iter->second, node_and_idx.second, new_primal);
191         continue;
192       }
193       // Create a new primal make_tuple.
194       std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple)};
195       for (size_t i = 1; i < user_cnode->size(); ++i) {
196         (void)make_tuple_inputs.emplace_back(user_cnode->input(i));
197       }
198       auto new_primal_make_tuple = fg->NewCNode(make_tuple_inputs);
199       new_primal_make_tuple->set_input(node_and_idx.second, new_primal);
200       new_primal_make_tuple->AddAttr(kAttrRecomputeMakeTuple, MakeValue(true));
201       (void)origin_to_new_primal->emplace(user, new_primal_make_tuple);
202       changed =
203         AddNewPrimalNode(manager, fg, user, new_primal_make_tuple, recompute_cell, origin_to_new_primal) || changed;
204       continue;
205     }
206 
207     // Set edge to not recomputed primal nodes.
208     if (recompute_cell || (!IsRecomputeKGraphCaller(user) && !IsGradNode(user))) {
209       MS_LOG(DEBUG) << "Set edge to user: " << user->DebugString() << ", new primal: " << new_primal->DebugString();
210       manager->SetEdge(user, node_and_idx.second, new_primal);
211       changed = true;
212     }
213   }
214   return changed;
215 }
216 
IsRecomputeCell(const FuncGraphPtr & k_fg)217 bool IsRecomputeCell(const FuncGraphPtr &k_fg) {
218   auto primal_iter = k_fg->transforms().find("primal");
219   if (primal_iter == k_fg->transforms().end()) {
220     MS_LOG(EXCEPTION) << "The k_fg: " << k_fg << " should have primal part.";
221   }
222   return primal_iter->second.func_graph() != nullptr;
223 }
224 
HasRecomputedInput(const CNodePtr & k_fg_caller_cnode)225 bool HasRecomputedInput(const CNodePtr &k_fg_caller_cnode) {
226   for (auto &input : k_fg_caller_cnode->inputs()) {
227     if (IsPrimitiveCNode(input, prim::kPrimMakeTuple)) {
228       return HasRecomputedInput(input->cast<CNodePtr>());
229     }
230     if (IsPrimitiveCNode(input, prim::kPrimDepend) && HasRecomputedInput(input->cast<CNodePtr>())) {
231       return true;
232     }
233     // The recomputed input should be a tuple_getitem to get the forward part of recomputed k graph.
234     if (!IsPrimitiveCNode(input, prim::kPrimTupleGetItem)) {
235       continue;
236     }
237     auto tmp = input->cast<CNodePtr>()->input(1);
238     auto input_k_fg_caller = tmp;
239     // The forward part may have multiple outputs.
240     if (IsPrimitiveCNode(tmp, prim::kPrimTupleGetItem)) {
241       input_k_fg_caller = tmp->cast<CNodePtr>()->input(1);
242     }
243 
244     auto cnode = dyn_cast_ptr<CNode>(input_k_fg_caller);
245     if (cnode == nullptr) {
246       continue;
247     }
248     auto call_fg = GetValueNode<FuncGraphPtr>(cnode->input(0));
249     // The output of recomputed cell would not be recomputed.
250     if (call_fg != nullptr && call_fg->has_flag(FUNC_GRAPH_RECOMPUTE_K_GRAPH) && !IsRecomputeCell(call_fg)) {
251       return true;
252     }
253   }
254   return false;
255 }
256 
IsForwardGetterTupleGetItem(const AnfNodePtr & node)257 bool IsForwardGetterTupleGetItem(const AnfNodePtr &node) {
258   if (!IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
259     return false;
260   }
261   auto idx = GetValueNode<Int64ImmPtr>(node->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem));
262   if (idx != nullptr && idx->value() == 0) {
263     return true;
264   }
265   return false;
266 }
267 
GetForwardGetter(const FuncGraphManagerPtr & manager,const CNodePtr & node)268 AnfNodePtr GetForwardGetter(const FuncGraphManagerPtr &manager, const CNodePtr &node) {
269   const auto &user_nodes = manager->node_users()[node];
270   auto iter = std::find_if(user_nodes.begin(), user_nodes.end(), [](const auto &node_and_idx) -> bool {
271     return IsForwardGetterTupleGetItem(node_and_idx.first);
272   });
273   if (iter != user_nodes.end()) {
274     return iter->first;
275   }
276   return nullptr;
277 }
278 
GetBpropGetter(const FuncGraphManagerPtr & manager,const CNodePtr & node)279 AnfNodePtr GetBpropGetter(const FuncGraphManagerPtr &manager, const CNodePtr &node) {
280   const auto &user_nodes = manager->node_users()[node];
281   for (const auto &iter : user_nodes) {
282     if (IsPrimitiveCNode(iter.first, prim::kPrimTupleGetItem)) {
283       auto idx = GetValueNode<Int64ImmPtr>(iter.first->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem));
284       if (idx != nullptr && idx->value() == 1) {
285         return iter.first;
286       }
287     }
288   }
289   return nullptr;
290 }
291 
HasRecomputedOutput(const FuncGraphManagerPtr & manager,const AnfNodePtr & node)292 bool HasRecomputedOutput(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) {
293   // The forward part may have multiple outputs.
294   if (IsOneOfPrimitiveCNode(node, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple, prim::kPrimDepend})) {
295     const auto &user_nodes = manager->node_users()[node];
296     return std::any_of(user_nodes.begin(), user_nodes.end(),
297                        [&manager](const auto &iter) { return HasRecomputedOutput(manager, iter.first); });
298   }
299   return IsRecomputeKGraphCaller(node);
300 }
301 
GetGradUsers(const FuncGraphManagerPtr & manager,const CNodePtr & node,const CNodePtr & pre_node,std::vector<AnfNodePtr> * grad_users)302 void GetGradUsers(const FuncGraphManagerPtr &manager, const CNodePtr &node, const CNodePtr &pre_node,
303                   std::vector<AnfNodePtr> *grad_users) {
304   // The forward part may have multiple outputs.
305   if (IsOneOfPrimitiveCNode(node, {prim::kPrimTupleGetItem, prim::kPrimDepend})) {
306     const auto &user_nodes = manager->node_users()[node];
307     for (const auto &iter : user_nodes) {
308       GetGradUsers(manager, iter.first->cast<CNodePtr>(), node, grad_users);
309     }
310     return;
311   }
312   if (IsGradNode(node)) {
313     const auto &inputs = node->inputs();
314     for (size_t i = 1; i < inputs.size(); ++i) {
315       if (inputs[i] != pre_node && !inputs[i]->isa<ValueNode>() && IsGradNode(inputs[i])) {
316         (void)grad_users->emplace_back(inputs[i]);
317       }
318     }
319   }
320 }
321 
IsFromForwardGetter(const AnfNodePtr & forward_getter,const AnfNodePtr & depend_node)322 bool IsFromForwardGetter(const AnfNodePtr &forward_getter, const AnfNodePtr &depend_node) {
323   if (forward_getter == depend_node) {
324     return true;
325   }
326   if (!IsOneOfPrimitiveCNode(depend_node, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple, prim::kPrimZerosLike})) {
327     return false;
328   }
329   const auto &depend_node_inputs = depend_node->cast<CNodePtr>()->inputs();
330   return std::any_of(depend_node_inputs.begin(), depend_node_inputs.end(),
331                      [&forward_getter](const auto &input) { return IsFromForwardGetter(forward_getter, input); });
332 }
333 
GetDependencies(const FuncGraphManagerPtr & manager,const CNodePtr & k_fg_caller,mindspore::CompactSet<CNodePtr> * final_nodes,mindspore::CompactSet<AnfNodePtr> * dependencies)334 void GetDependencies(const FuncGraphManagerPtr &manager, const CNodePtr &k_fg_caller,
335                      mindspore::CompactSet<CNodePtr> *final_nodes, mindspore::CompactSet<AnfNodePtr> *dependencies) {
336   if (final_nodes->find(k_fg_caller) != final_nodes->end()) {
337     return;
338   }
339   bool is_recompute_k_fg_caller = IsRecomputeKGraphCaller(k_fg_caller);
340   // We only handle the recomputed k graph caller.
341   if (!is_recompute_k_fg_caller &&
342       !IsOneOfPrimitiveCNode(k_fg_caller, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple, prim::kPrimDepend})) {
343     return;
344   }
345   if (is_recompute_k_fg_caller) {
346     auto forward_getter = GetForwardGetter(manager, k_fg_caller);
347     // If the k graph caller has no forward getter, it should not output to any other recomputed nodes.
348     if (forward_getter == nullptr) {
349       auto bprop_caller = GetBpropCaller(manager, GetBpropGetter(manager, k_fg_caller));
350       // Add the dout input of its bprop function to the dependencies.
351       if (bprop_caller == nullptr) {
352         return;
353       }
354       (void)final_nodes->insert(k_fg_caller);
355       (void)dependencies->insert(bprop_caller->cast<CNodePtr>()->input(1));
356       return;
357     }
358     if (!HasRecomputedOutput(manager, forward_getter)) {
359       std::vector<AnfNodePtr> grad_users;
360       // Add the other inputs of the grad node to the dependencies.
361       GetGradUsers(manager, forward_getter->cast<CNodePtr>(), k_fg_caller, &grad_users);
362       if (!grad_users.empty()) {
363         for (auto &user : grad_users) {
364           (void)final_nodes->insert(k_fg_caller);
365           (void)dependencies->insert(user);
366         }
367         return;
368       }
369       // Add the dout input of its bprop function to the dependencies.
370       auto bprop_caller = GetBpropCaller(manager, GetBpropGetter(manager, k_fg_caller));
371       if (bprop_caller == nullptr) {
372         return;
373       }
374       (void)final_nodes->insert(k_fg_caller);
375       auto dout = bprop_caller->cast<CNodePtr>()->input(1);
376       if (IsPrimitiveCNode(dout, prim::kPrimMakeTuple) && IsFromForwardGetter(forward_getter, dout)) {
377         return;
378       }
379       (void)dependencies->insert(dout);
380       return;
381     }
382   }
383 
384   const auto &user_nodes = manager->node_users()[k_fg_caller];
385   for (const auto &iter : user_nodes) {
386     if (IsPrimitiveCNode(iter.first, prim::kPrimTupleGetItem)) {
387       auto idx = GetValueNode<Int64ImmPtr>(iter.first->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem));
388       // Skip bprop getter.
389       if (idx != nullptr && idx->value() == 1 && is_recompute_k_fg_caller) {
390         continue;
391       }
392     }
393     GetDependencies(manager, iter.first->cast<CNodePtr>(), final_nodes, dependencies);
394   }
395 }
396 
CopyOriginalInputs(const FuncGraphPtr & bprop_fg,const CNodePtr & node,const AnfNodePtr & depend_nodes,std::vector<AnfNodePtr> * new_inputs)397 void CopyOriginalInputs(const FuncGraphPtr &bprop_fg, const CNodePtr &node, const AnfNodePtr &depend_nodes,
398                         std::vector<AnfNodePtr> *new_inputs) {
399   (void)std::transform(
400     node->inputs().begin(), node->inputs().end(), std::back_inserter(*new_inputs),
401     [&bprop_fg](const AnfNodePtr &input) -> AnfNodePtr {
402       // Make sure there is only one u monad fv.
403       if (HasAbstractUMonad(input) && input->func_graph() != nullptr && input->func_graph() != bprop_fg) {
404         return NewValueNode(kUMonad);
405       }
406       return input;
407     });
408   // The recomputed cell should insert depend node at all inputs.
409   if (!IsRecomputeCell(GetValueNode<FuncGraphPtr>(node->input(0)))) {
410     auto depend = bprop_fg->NewCNode({NewValueNode(prim::kPrimDepend), (*new_inputs)[1], depend_nodes});
411     depend->AddAttr(kRecomputeInsert, MakeValue(true));
412     (*new_inputs)[1] = depend;
413   }
414 }
415 
MoveKCallerToBprop(const FuncGraphManagerPtr & manager,const FuncGraphPtr & bprop_fg,const CNodePtr & node,const AnfNodePtr & depend_nodes,std::unordered_map<CNodePtr,CNodePtr> * origin_to_new_nodes)416 CNodePtr MoveKCallerToBprop(const FuncGraphManagerPtr &manager, const FuncGraphPtr &bprop_fg, const CNodePtr &node,
417                             const AnfNodePtr &depend_nodes,
418                             std::unordered_map<CNodePtr, CNodePtr> *origin_to_new_nodes) {
419   auto iter = origin_to_new_nodes->find(node);
420   if (iter != origin_to_new_nodes->end()) {
421     return iter->second;
422   }
423   std::vector<AnfNodePtr> new_inputs;
424   if (IsRecomputeKGraphCaller(node)) {
425     if (!node->HasAttr(kAttrReplacedWithPrimal)) {
426       return node;
427     }
428     if (!HasRecomputedInput(node)) {
429       CopyOriginalInputs(bprop_fg, node, depend_nodes, &new_inputs);
430     } else {
431       for (auto &input : node->inputs()) {
432         if (!input->isa<CNode>()) {
433           (void)new_inputs.emplace_back(input);
434           continue;
435         }
436         (void)new_inputs.emplace_back(
437           MoveKCallerToBprop(manager, bprop_fg, input->cast<CNodePtr>(), depend_nodes, origin_to_new_nodes));
438       }
439     }
440     if (IsRecomputeCell(GetValueNode<FuncGraphPtr>(node->input(0)))) {
441       // Add the dout input of its bprop function to the dependencies.
442       auto new_depend_nodes = depend_nodes;
443       auto bprop_caller = GetBpropCaller(manager, GetBpropGetter(manager, node));
444       if (bprop_caller != nullptr) {
445         std::vector<AnfNodePtr> new_depend_nodes_inputs;
446         (void)std::copy(depend_nodes->cast<CNodePtr>()->inputs().begin(),
447                         depend_nodes->cast<CNodePtr>()->inputs().end(), std::back_inserter(new_depend_nodes_inputs));
448         (void)new_depend_nodes_inputs.emplace_back(bprop_caller->cast<CNodePtr>()->input(1));
449         new_depend_nodes = bprop_fg->NewCNode(new_depend_nodes_inputs);
450       }
451       for (size_t i = 1; i < new_inputs.size(); ++i) {
452         auto depend = bprop_fg->NewCNode({NewValueNode(prim::kPrimDepend), new_inputs[i], new_depend_nodes});
453         depend->AddAttr(kRecomputeInsert, MakeValue(true));
454         new_inputs[i] = depend;
455       }
456     }
457     auto new_k_fg_caller = bprop_fg->NewCNode(new_inputs);
458     new_k_fg_caller->AddAttr(kAddedRecomputeDependAttr, MakeValue(true));
459     new_k_fg_caller->AddAttr(kAttrReplacedWithPrimal, MakeValue(true));
460     auto primal_fg_caller = node->user_data<CNode>(kPrimalFgCallerUserDataKey);
461     if (primal_fg_caller != nullptr) {
462       new_k_fg_caller->set_user_data(kPrimalFgCallerUserDataKey, primal_fg_caller);
463     }
464     // Replace the bprop getter with the new k graph caller in bprop graph.
465     auto origin_bprop_getter = GetBpropGetter(manager, node);
466     if (origin_bprop_getter != nullptr) {
467       auto new_bprop_getter = bprop_fg->NewCNodeInOrder(
468         {NewValueNode(prim::kPrimTupleGetItem), new_k_fg_caller, NewValueNode(static_cast<int64_t>(1))});
469       new_bprop_getter->set_abstract(origin_bprop_getter->abstract());
470       (void)manager->Replace(origin_bprop_getter, new_bprop_getter);
471     }
472     (void)origin_to_new_nodes->emplace(node, new_k_fg_caller);
473     return new_k_fg_caller;
474   }
475   // If it is not tuple_getitem, it should be node which is not set recomputed.
476   if (!IsOneOfPrimitiveCNode(
477         node, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple, prim::kPrimDepend, prim::kPrimUpdateState})) {
478     return node;
479   }
480   // If the other branch has not been handle, it should not create new forward getter.
481   if (IsForwardGetterTupleGetItem(node)) {
482     auto real_node = node->cast<CNodePtr>()->input(1);
483     if (IsRecomputeKGraphCaller(real_node) && !real_node->cast<CNodePtr>()->HasAttr(kAttrReplacedWithPrimal)) {
484       return node;
485     }
486   }
487   for (auto &input : node->inputs()) {
488     if (!input->isa<CNode>()) {
489       (void)new_inputs.emplace_back(input);
490       continue;
491     }
492     (void)new_inputs.emplace_back(
493       MoveKCallerToBprop(manager, bprop_fg, input->cast<CNodePtr>(), depend_nodes, origin_to_new_nodes));
494   }
495   auto new_node = bprop_fg->NewCNode(new_inputs);
496   (void)origin_to_new_nodes->emplace(node, new_node);
497   return new_node;
498 }
499 
GetKGraphCallerFromTupleGetitem(const AnfNodePtr & node)500 CNodePtr GetKGraphCallerFromTupleGetitem(const AnfNodePtr &node) {
501   auto idx = GetValueNode<Int64ImmPtr>(node->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem));
502   // The k_fg_caller return a tuple of forward result and bprop.
503   if (idx == nullptr || idx->value() != 0) {
504     return nullptr;
505   }
506   auto k_fg_caller = node->cast<CNodePtr>()->input(1);
507   MS_EXCEPTION_IF_NULL(k_fg_caller);
508   return k_fg_caller->cast<CNodePtr>();
509 }
510 
ReplaceFinalForwardGetter(const FuncGraphManagerPtr & manager,const FuncGraphPtr & fg,const AnfNodePtr & origin_forward_getter,const AnfNodePtr & new_forward_getter)511 void ReplaceFinalForwardGetter(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg,
512                                const AnfNodePtr &origin_forward_getter, const AnfNodePtr &new_forward_getter) {
513   auto node_users = manager->node_users()[origin_forward_getter];
514   for (auto &node_and_idx : node_users) {
515     auto user = node_and_idx.first;
516     MS_EXCEPTION_IF_NULL(user);
517     MS_LOG(DEBUG) << "User: " << user->DebugString();
518     // The forward part may have multiple outputs.
519     if (IsPrimitiveCNode(user, prim::kPrimTupleGetItem)) {
520       // Make new tuple_getitem to get corresponding output.
521       auto new_getitem = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), new_forward_getter,
522                                        user->cast_ptr<CNode>()->input(kInputNodeOutputIndexInTupleGetItem)});
523       ReplaceFinalForwardGetter(manager, fg, user, new_getitem);
524       continue;
525     }
526     if (IsPrimitiveCNode(user, prim::kPrimDepend)) {
527       // Make new depend to get corresponding output.
528       auto new_depend = fg->NewCNode(user->cast_ptr<CNode>()->inputs());
529       new_depend->set_input(IntToSize(node_and_idx.second), new_forward_getter);
530       ReplaceFinalForwardGetter(manager, fg, user, new_depend);
531       continue;
532     }
533     MS_LOG(DEBUG) << "Set edge for user: " << user->DebugString();
534     manager->SetEdge(user, node_and_idx.second, new_forward_getter);
535   }
536 }
537 
GetAllRecomputeKFgCallers(const CNodePtr & final_node,mindspore::HashSet<CNodePtr> * recompute_k_fg_callers)538 void GetAllRecomputeKFgCallers(const CNodePtr &final_node, mindspore::HashSet<CNodePtr> *recompute_k_fg_callers) {
539   for (const auto &input : final_node->inputs()) {
540     if (!input->isa<CNode>()) {
541       continue;
542     }
543     auto input_cnode = input->cast<CNodePtr>();
544     if (IsPrimitiveCNode(input_cnode, prim::kPrimTupleGetItem)) {
545       GetAllRecomputeKFgCallers(input_cnode, recompute_k_fg_callers);
546       continue;
547     }
548     // Only get the nodes visited in this round.
549     if (!input_cnode->HasAttr(kAttrReplacedWithPrimal) || !IsRecomputeKGraphCaller(input) ||
550         recompute_k_fg_callers->find(input_cnode) != recompute_k_fg_callers->end()) {
551       continue;
552     }
553     (void)recompute_k_fg_callers->insert(input_cnode);
554     GetAllRecomputeKFgCallers(input_cnode, recompute_k_fg_callers);
555   }
556 }
557 
IsFromRecomputeKFgCaller(const FuncGraphPtr & bprop_fg,const mindspore::HashSet<CNodePtr> & recompute_k_fg_callers,const CNodePtr & node,mindspore::HashMap<CNodePtr,bool> * is_from_recompute_k_fg_caller)558 bool IsFromRecomputeKFgCaller(const FuncGraphPtr &bprop_fg, const mindspore::HashSet<CNodePtr> &recompute_k_fg_callers,
559                               const CNodePtr &node, mindspore::HashMap<CNodePtr, bool> *is_from_recompute_k_fg_caller) {
560   auto iter = is_from_recompute_k_fg_caller->find(node);
561   if (iter != is_from_recompute_k_fg_caller->end()) {
562     return iter->second;
563   }
564   if (recompute_k_fg_callers.find(node) != recompute_k_fg_callers.end()) {
565     (void)is_from_recompute_k_fg_caller->emplace(node, true);
566     return true;
567   }
568 
569   for (const auto &input : node->inputs()) {
570     MS_EXCEPTION_IF_NULL(input);
571     if (!input->isa<CNode>()) {
572       continue;
573     }
574     auto input_cnode = input->cast<CNodePtr>();
575     if (input_cnode->func_graph() != bprop_fg) {
576       AnfNodePtr cur_node = input_cnode;
577       while (IsPrimitiveCNode(cur_node, prim::kPrimTupleGetItem)) {
578         cur_node = cur_node->cast<CNodePtr>()->input(1);
579       }
580       if (cur_node->isa<CNode>() &&
581           recompute_k_fg_callers.find(cur_node->cast<CNodePtr>()) != recompute_k_fg_callers.end()) {
582         (void)is_from_recompute_k_fg_caller->emplace(node, true);
583         return true;
584       }
585       continue;
586     }
587     if (IsFromRecomputeKFgCaller(bprop_fg, recompute_k_fg_callers, input_cnode, is_from_recompute_k_fg_caller)) {
588       (void)is_from_recompute_k_fg_caller->emplace(node, true);
589       return true;
590     }
591   }
592   (void)is_from_recompute_k_fg_caller->emplace(node, false);
593   return false;
594 }
595 
AddDependNodes(const FuncGraphManagerPtr & manager,const FuncGraphPtr & fg,const CNodePtr & k_fg_caller_cnode)596 void AddDependNodes(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, const CNodePtr &k_fg_caller_cnode) {
597   // Get the nodes which the recomputed part should depend on;
598   mindspore::CompactSet<CNodePtr> final_nodes;
599   mindspore::CompactSet<AnfNodePtr> dependencies;
600   GetDependencies(manager, k_fg_caller_cnode, &final_nodes, &dependencies);
601   if (dependencies.empty()) {
602     return;
603   }
604   FuncGraphPtr bprop_fg;
605   auto bprop_caller = GetBpropCaller(manager, GetBpropGetter(manager, k_fg_caller_cnode));
606   if (bprop_caller == nullptr) {
607     bprop_fg = (*dependencies.begin())->func_graph();
608   } else {
609     bprop_fg = bprop_caller->func_graph();
610   }
611   MS_EXCEPTION_IF_NULL(bprop_fg);
612   // Filter the dependent nodes in case of producing loops.
613   mindspore::HashSet<CNodePtr> recompute_k_fg_callers;
614   for (const auto &final_node : final_nodes) {
615     (void)recompute_k_fg_callers.insert(final_node);
616     GetAllRecomputeKFgCallers(final_node, &recompute_k_fg_callers);
617   }
618   std::vector<AnfNodePtr> depend_inputs{NewValueNode(prim::kPrimMakeTuple)};
619   mindspore::HashMap<CNodePtr, bool> is_from_recompute_k_fg_caller;
620   (void)std::copy_if(dependencies.begin(), dependencies.end(), std::back_inserter(depend_inputs),
621                      [bprop_fg, &recompute_k_fg_callers, &is_from_recompute_k_fg_caller](const AnfNodePtr &dependency) {
622                        if (!dependency->isa<CNode>()) {
623                          return true;
624                        }
625                        return !IsFromRecomputeKFgCaller(bprop_fg, recompute_k_fg_callers, dependency->cast<CNodePtr>(),
626                                                         &is_from_recompute_k_fg_caller);
627                      });
628   // Add the dependency nodes to the first recomputed nodes.
629   auto depend_nodes = bprop_fg->NewCNode(depend_inputs);
630   if (bprop_fg == fg) {
631     if (!IsRecomputeCell(GetValueNode<FuncGraphPtr>(k_fg_caller_cnode->input(0)))) {
632       auto depend = fg->NewCNode({NewValueNode(prim::kPrimDepend), k_fg_caller_cnode->input(1), depend_nodes});
633       depend->AddAttr(kRecomputeInsert, MakeValue(true));
634       manager->SetEdge(k_fg_caller_cnode, 1, depend);
635       k_fg_caller_cnode->AddAttr(kAddedRecomputeDependAttr, MakeValue(true));
636     } else {
637       std::vector<AnfNodePtr> new_k_fg_caller_inputs{k_fg_caller_cnode->input(0)};
638       (void)std::transform(k_fg_caller_cnode->inputs().begin() + 1, k_fg_caller_cnode->inputs().end(),
639                            std::back_inserter(new_k_fg_caller_inputs),
640                            [&fg, &depend_nodes](const AnfNodePtr &input) -> AnfNodePtr {
641                              auto depend = fg->NewCNodeInOrder({NewValueNode(prim::kPrimDepend), input, depend_nodes});
642                              depend->AddAttr(kRecomputeInsert, MakeValue(true));
643                              return depend;
644                            });
645       auto new_k_fg_caller = fg->NewCNodeInOrder(new_k_fg_caller_inputs);
646       auto primal_fg_caller = k_fg_caller_cnode->user_data<CNode>(kPrimalFgCallerUserDataKey);
647       if (primal_fg_caller != nullptr) {
648         new_k_fg_caller->set_user_data(kPrimalFgCallerUserDataKey, primal_fg_caller);
649       }
650       (void)manager->Replace(k_fg_caller_cnode, new_k_fg_caller);
651       new_k_fg_caller->AddAttr(kAddedRecomputeDependAttr, MakeValue(true));
652       new_k_fg_caller->AddAttr(kAttrReplacedWithPrimal, MakeValue(true));
653     }
654     return;
655   }
656   // If the graph of the bprop caller is not the same as the graph of k graph caller, we should move the k graph
657   // caller to the graph of the bprop.
658   std::unordered_map<CNodePtr, CNodePtr> origin_to_new_nodes;
659   for (const auto &final_node : final_nodes) {
660     auto new_k_fg_caller = MoveKCallerToBprop(manager, bprop_fg, final_node, depend_nodes, &origin_to_new_nodes);
661     new_k_fg_caller->AddAttr(kAddedRecomputeDependAttr, MakeValue(true));
662   }
663   for (auto &iter : origin_to_new_nodes) {
664     if (!IsRecomputeKGraphCaller(iter.first)) {
665       continue;
666     }
667     auto forward_getter = GetForwardGetter(manager, iter.first);
668     if (forward_getter == nullptr) {
669       (void)manager->Replace(iter.first, iter.second);
670     } else {
671       auto new_forward_getter =
672         bprop_fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), iter.second, NewValueNode(static_cast<int64_t>(0))});
673       ReplaceFinalForwardGetter(manager, bprop_fg, forward_getter, new_forward_getter);
674     }
675   }
676 }
677 
AddDuplicatedAttr(const FuncGraphPtr & k_fg)678 void AddDuplicatedAttr(const FuncGraphPtr &k_fg) {
679   for (const auto &node : k_fg->nodes()) {
680     if (!node->isa<CNode>()) {
681       continue;
682     }
683     node->cast_ptr<CNode>()->AddAttr(kAttrDuplicated, MakeValue(true));
684   }
685 }
686 
AddCseAttr(const FuncGraphPtr & root,bool changed)687 void AddCseAttr(const FuncGraphPtr &root, bool changed) {
688   if (!changed) {
689     return;
690   }
691   auto all_node = TopoSort(root->get_return(), SuccDeeperSimple, AlwaysInclude);
692   for (const auto &node : all_node) {
693     if (WithRecomputedScope(node)) {
694       node->cast<CNodePtr>()->AddAttr(kAttrNeedCseAfterRecompute, MakeValue(true));
695     }
696   }
697 }
698 
GetPrimal(const FuncGraphPtr & k_fg,bool * recompute_cell)699 AnfNodePtr GetPrimal(const FuncGraphPtr &k_fg, bool *recompute_cell) {
700   auto primal_iter = k_fg->transforms().find("primal");
701   if (primal_iter == k_fg->transforms().end()) {
702     return nullptr;
703   }
704   AnfNodePtr primal = nullptr;
705   auto primal_fg = primal_iter->second.func_graph();
706   if (primal_fg != nullptr) {
707     primal = NewValueNode(primal_fg);
708     *recompute_cell = true;
709   } else {
710     auto primal_primitive = primal_iter->second.primitive();
711     if (primal_primitive != nullptr) {
712       primal = NewValueNode(primal_primitive);
713     }
714   }
715   return primal;
716 }
717 
IsNestedRecomputed(const AnfNodePtr & node)718 bool IsNestedRecomputed(const AnfNodePtr &node) {
719   auto fg = node->func_graph();
720   MS_EXCEPTION_IF_NULL(fg);
721   return fg->has_flag(FUNC_GRAPH_RECOMPUTE_K_GRAPH);
722 }
723 
SetPrimalAttrs(const CNodePtr & new_primal,const FuncGraphPtr & k_fg)724 void SetPrimalAttrs(const CNodePtr &new_primal, const FuncGraphPtr &k_fg) {
725   auto forward_in_k_fg = GetPrimalFromFprop(k_fg);
726   auto forward_cnode_in_k_fg = dyn_cast<CNode>(forward_in_k_fg);
727   if (forward_cnode_in_k_fg != nullptr) {
728     new_primal->set_primal_attrs(forward_cnode_in_k_fg->primal_attrs());
729   }
730 }
731 }  // namespace
732 
AddRecomputeNodes(const FuncGraphPtr & root,const opt::OptimizerPtr & opt)733 bool AddRecomputeNodes(const FuncGraphPtr &root, const opt::OptimizerPtr &opt) {
734   if (!EnableCellReuse()) {
735     return false;
736   }
737 #ifdef ENABLE_DUMP_IR
738   auto context = MsContext::GetInstance();
739   MS_EXCEPTION_IF_NULL(context);
740   bool enable_save_graphs = context->CanDump(kIntroductory);
741   if (enable_save_graphs) {
742     DumpIR("before_recompute_root.ir", root);
743   }
744 #endif
745   MS_EXCEPTION_IF_NULL(root);
746   MS_EXCEPTION_IF_NULL(opt);
747   auto manager = opt->manager();
748   MS_EXCEPTION_IF_NULL(manager);
749   bool changed = false;
750   auto all_node = TopoSort(root->get_return(), SuccDeeperSimple, AlwaysInclude);
751   for (auto iter = all_node.crbegin(); iter != all_node.crend(); (void)iter++) {
752     const auto &node = *iter;
753     if (!IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
754       continue;
755     }
756     auto k_fg_caller_cnode = GetKGraphCallerFromTupleGetitem(node);
757     if (k_fg_caller_cnode == nullptr || k_fg_caller_cnode->HasAttr(kAddedRecomputeDependAttr)) {
758       continue;
759     }
760     auto k_fg = GetValueNode<FuncGraphPtr>(k_fg_caller_cnode->input(0));
761     if (k_fg == nullptr || !k_fg->has_flag(FUNC_GRAPH_RECOMPUTE_K_GRAPH)) {
762       continue;
763     }
764     if (IsNestedRecomputed(k_fg_caller_cnode)) {
765       MS_LOG(WARNING)
766         << "The node and its graph both have been set recomputed, the node would not be handled. The node: "
767         << k_fg_caller_cnode->DebugString();
768       continue;
769     }
770     bool recompute_cell = false;
771     auto primal = GetPrimal(k_fg, &recompute_cell);
772     if (primal == nullptr) {
773       continue;
774     }
775     // Replace the forward getter with the origin primal.
776     constexpr auto recursive_level = 2;
777     MS_LOG(DEBUG) << "Handle recompute k graph forward getter: " << node->DebugString(recursive_level);
778     std::vector<AnfNodePtr> inputs{primal};
779     (void)inputs.insert(inputs.cend(), k_fg_caller_cnode->inputs().begin() + 1, k_fg_caller_cnode->inputs().end());
780     auto fg = node->func_graph();
781     MS_EXCEPTION_IF_NULL(fg);
782     auto new_primal = fg->NewCNodeInOrder(inputs);
783     if (IsValueNode<Primitive>(primal)) {
784       SetPrimalAttrs(new_primal, k_fg);
785     }
786     std::unordered_map<AnfNodePtr, AnfNodePtr> origin_to_new_primal;
787     bool change = AddNewPrimalNode(manager, fg, node, new_primal, recompute_cell, &origin_to_new_primal);
788     changed = change || changed;
789     if (change && recompute_cell) {
790       k_fg_caller_cnode->set_user_data(kPrimalFgCallerUserDataKey, new_primal);
791     }
792     k_fg_caller_cnode->AddAttr(kAttrReplacedWithPrimal, MakeValue(true));
793     // Add duplicated attr to help debugging.
794     AddDuplicatedAttr(k_fg);
795     if (HasRecomputedInput(k_fg_caller_cnode)) {
796       continue;
797     }
798 
799     MS_LOG(DEBUG) << "Not has recomputed input k_fg_caller_cnode: " << k_fg_caller_cnode->DebugString();
800     AddDependNodes(manager, fg, k_fg_caller_cnode);
801   }
802   AddCseAttr(root, changed);
803 #ifdef ENABLE_DUMP_IR
804   if (enable_save_graphs) {
805     DumpIR("after_recompute_root.ir", root);
806   }
807 #endif
808   return changed;
809 }
810 }  // namespace irpass
811 }  // namespace opt
812 }  // namespace mindspore
813