1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/optimizer/irpass/recompute.h"
18 #include <set>
19 #include <unordered_map>
20 #include "ops/array_ops.h"
21
22 namespace mindspore {
23 namespace opt {
24 namespace irpass {
EnableCellReuse()25 bool EnableCellReuse() {
26 auto context = MsContext::GetInstance();
27 MS_EXCEPTION_IF_NULL(context);
28 const auto cell_reuse = context->CellReuseLevel() != CellReuseLevel::kNoCellReuse;
29 return cell_reuse;
30 }
31
HasBpropGetter(const OptimizerPtr & opt,const AnfNodePtr & k_fg_caller)32 bool HasBpropGetter(const OptimizerPtr &opt, const AnfNodePtr &k_fg_caller) {
33 MS_EXCEPTION_IF_NULL(opt);
34 auto manager = opt->manager();
35 MS_EXCEPTION_IF_NULL(manager);
36 const auto &node_users = manager->node_users();
37 auto iter = node_users.find(k_fg_caller);
38 if (iter == node_users.end()) {
39 MS_LOG(EXCEPTION) << "The node " << k_fg_caller->DebugString() << " should have users.";
40 }
41
42 return std::any_of(iter->second.begin(), iter->second.end(), [](const std::pair<AnfNodePtr, int> &node_and_idx) {
43 auto user = node_and_idx.first;
44 return IsPrimitiveCNode(user, prim::kPrimTupleGetItem) &&
45 common::AnfAlgo::GetTupleGetItemOutIndex(user->cast<CNodePtr>()) == 1;
46 });
47 }
48
// Looks up the single user of `bprop_getter` in the manager's node-user map.
// Returns nullptr when `bprop_getter` has no entry in the map. Raises when `manager` is null, when the entry does
// not hold exactly one user, or when that user does not consume `bprop_getter` as input 0.
AnfNodePtr GetBpropCaller(const FuncGraphManagerPtr &manager, const AnfNodePtr &bprop_getter) {
  MS_EXCEPTION_IF_NULL(manager);
  const auto &users_map = manager->node_users();
  const auto found = users_map.find(bprop_getter);
  if (found == users_map.end()) {
    return nullptr;
  }
  const auto &users = found->second;
  if (users.size() != 1) {
    MS_LOG(EXCEPTION) << "The number of bprop caller should be 1, but got " << users.size()
                      << ", bprop_getter: " << bprop_getter->DebugString();
  }
  const auto only_user = users.begin();
  if (only_user->second != 0) {
    MS_LOG(EXCEPTION) << "The bprop_getter should be used in input 0, but got " << only_user->second;
  }
  return only_user->first;
}
66
67 namespace {
// Prefix compared against a node's full scope name to recognise nodes created in a gradient scope.
constexpr const char *kGradientsFlag = "Gradients";
// Name of the CNode attribute this file sets on the k graph callers it creates and tests before handling a caller.
constexpr const char *kAttrReplacedWithPrimal = "replaced_with_primal";
// Name of the CNode attribute this file sets on the make_tuple nodes it creates for new primal nodes.
constexpr const char *kAttrRecomputeMakeTuple = "recompute_make_tuple";
71
WithRecomputedScope(const AnfNodePtr & node)72 bool WithRecomputedScope(const AnfNodePtr &node) {
73 MS_EXCEPTION_IF_NULL(node);
74 if (!node->isa<CNode>()) {
75 return false;
76 }
77 const auto &full_name_with_scope = node->fullname_with_scope();
78 return full_name_with_scope.compare(0, strlen(kAttrRecompute), kAttrRecompute) == 0;
79 }
80
IsRecomputeKGraphCaller(const AnfNodePtr & node)81 bool IsRecomputeKGraphCaller(const AnfNodePtr &node) {
82 auto cnode = dyn_cast_ptr<CNode>(node);
83 if (cnode == nullptr) {
84 return false;
85 }
86 auto call_fg = GetValueNode<FuncGraphPtr>(cnode->input(0));
87 if (call_fg != nullptr && call_fg->has_flag(FUNC_GRAPH_RECOMPUTE_K_GRAPH)) {
88 return true;
89 }
90 return false;
91 }
92
WithGradientScope(const AnfNodePtr & node)93 bool WithGradientScope(const AnfNodePtr &node) {
94 return node->fullname_with_scope().compare(0, strlen(kGradientsFlag), kGradientsFlag) == 0;
95 }
96
IsFromBpropOutput(const AnfNodePtr & node)97 bool IsFromBpropOutput(const AnfNodePtr &node) {
98 if (!IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
99 return false;
100 }
101 auto cur_node = node;
102 while (IsPrimitiveCNode(cur_node, prim::kPrimTupleGetItem)) {
103 cur_node = cur_node->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem);
104 }
105 if (WithGradientScope(cur_node)) {
106 return true;
107 }
108 auto cur_cnode = cur_node->cast<CNodePtr>();
109 if (cur_cnode == nullptr) {
110 return false;
111 }
112 auto func_abs = dyn_cast<abstract::FuncGraphAbstractClosure>(cur_cnode->input(0)->abstract());
113 if (func_abs == nullptr) {
114 return false;
115 }
116 auto fg = func_abs->func_graph();
117 MS_EXCEPTION_IF_NULL(fg);
118 return fg->has_flag(FUNC_GRAPH_RECOMPUTE_GRAD_GRAPH);
119 }
120
IsGradNode(const AnfNodePtr & node)121 bool IsGradNode(const AnfNodePtr &node) {
122 MS_EXCEPTION_IF_NULL(node);
123 return WithGradientScope(node) || IsFromBpropOutput(node);
124 }
125
IsFpropReturn(const AnfNodePtr & make_tuple)126 bool IsFpropReturn(const AnfNodePtr &make_tuple) {
127 auto cnode = make_tuple->cast<CNodePtr>();
128 constexpr size_t fprop_output_size = 2;
129 if (cnode->size() != fprop_output_size + 1) {
130 return false;
131 }
132 return IsValueNode<FuncGraph>(cnode->input(fprop_output_size));
133 }
134
GetPrimalFromFprop(const FuncGraphPtr & k_fg)135 AnfNodePtr GetPrimalFromFprop(const FuncGraphPtr &k_fg) {
136 if (!IsPrimitiveCNode(k_fg->output(), prim::kPrimMakeTuple)) {
137 return nullptr;
138 }
139 auto k_fg_outputs = k_fg->output()->cast<CNodePtr>()->inputs();
140 if (k_fg_outputs.size() != 3) {
141 return nullptr;
142 }
143 return k_fg_outputs[kIndex1];
144 }
145
ShouldAddNewPrimalOutput(const AnfNodePtr & node,bool recompute_cell)146 bool ShouldAddNewPrimalOutput(const AnfNodePtr &node, bool recompute_cell) {
147 return !IsGradNode(node) || recompute_cell;
148 }
149
IsForwardDepend(const AnfNodePtr & node)150 bool IsForwardDepend(const AnfNodePtr &node) {
151 return IsPrimitiveCNode(node, prim::kPrimDepend) && !node->cast_ptr<CNode>()->HasAttr(kRecomputeInsert);
152 }
153
AddNewPrimalNode(const FuncGraphManagerPtr & manager,const FuncGraphPtr & fg,const AnfNodePtr & origin_primal,const AnfNodePtr & new_primal,bool recompute_cell,std::unordered_map<AnfNodePtr,AnfNodePtr> * origin_to_new_primal)154 bool AddNewPrimalNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, const AnfNodePtr &origin_primal,
155 const AnfNodePtr &new_primal, bool recompute_cell,
156 std::unordered_map<AnfNodePtr, AnfNodePtr> *origin_to_new_primal) {
157 bool changed = false;
158 auto node_users = manager->node_users()[origin_primal];
159 for (auto &node_and_idx : node_users) {
160 auto user = node_and_idx.first;
161 MS_EXCEPTION_IF_NULL(user);
162 // The forward part may have multiple outputs.
163 if (IsPrimitiveCNode(user, prim::kPrimTupleGetItem) && ShouldAddNewPrimalOutput(user, recompute_cell)) {
164 // Make new tuple_getitem to get corresponding output.
165 auto new_primal_getitem = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), new_primal,
166 user->cast_ptr<CNode>()->input(kInputNodeOutputIndexInTupleGetItem)});
167 changed =
168 AddNewPrimalNode(manager, fg, user, new_primal_getitem, recompute_cell, origin_to_new_primal) || changed;
169 continue;
170 }
171 if (IsForwardDepend(user) && ShouldAddNewPrimalOutput(user, recompute_cell)) {
172 // Make new depend node in forward to get corresponding output.
173 auto new_depend = fg->NewCNode(user->cast_ptr<CNode>()->inputs());
174 new_depend->set_input(IntToSize(node_and_idx.second), new_primal);
175 changed = AddNewPrimalNode(manager, fg, user, new_depend, recompute_cell, origin_to_new_primal) || changed;
176 continue;
177 }
178 // The op like concat will have a make_tuple input.
179 if (IsPrimitiveCNode(user, prim::kPrimMakeTuple) && !IsFpropReturn(user) &&
180 ShouldAddNewPrimalOutput(user, recompute_cell)) {
181 auto user_cnode = user->cast<CNodePtr>();
182 MS_EXCEPTION_IF_NULL(user_cnode);
183 if (user_cnode->HasAttr(kAttrRecomputeMakeTuple)) {
184 manager->SetEdge(user_cnode, node_and_idx.second, new_primal);
185 continue;
186 }
187 auto iter = origin_to_new_primal->find(user);
188 if (iter != origin_to_new_primal->end()) {
189 // The new make_tuple has been created, just set its inputs.
190 manager->SetEdge(iter->second, node_and_idx.second, new_primal);
191 continue;
192 }
193 // Create a new primal make_tuple.
194 std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple)};
195 for (size_t i = 1; i < user_cnode->size(); ++i) {
196 (void)make_tuple_inputs.emplace_back(user_cnode->input(i));
197 }
198 auto new_primal_make_tuple = fg->NewCNode(make_tuple_inputs);
199 new_primal_make_tuple->set_input(node_and_idx.second, new_primal);
200 new_primal_make_tuple->AddAttr(kAttrRecomputeMakeTuple, MakeValue(true));
201 (void)origin_to_new_primal->emplace(user, new_primal_make_tuple);
202 changed =
203 AddNewPrimalNode(manager, fg, user, new_primal_make_tuple, recompute_cell, origin_to_new_primal) || changed;
204 continue;
205 }
206
207 // Set edge to not recomputed primal nodes.
208 if (recompute_cell || (!IsRecomputeKGraphCaller(user) && !IsGradNode(user))) {
209 MS_LOG(DEBUG) << "Set edge to user: " << user->DebugString() << ", new primal: " << new_primal->DebugString();
210 manager->SetEdge(user, node_and_idx.second, new_primal);
211 changed = true;
212 }
213 }
214 return changed;
215 }
216
IsRecomputeCell(const FuncGraphPtr & k_fg)217 bool IsRecomputeCell(const FuncGraphPtr &k_fg) {
218 auto primal_iter = k_fg->transforms().find("primal");
219 if (primal_iter == k_fg->transforms().end()) {
220 MS_LOG(EXCEPTION) << "The k_fg: " << k_fg << " should have primal part.";
221 }
222 return primal_iter->second.func_graph() != nullptr;
223 }
224
HasRecomputedInput(const CNodePtr & k_fg_caller_cnode)225 bool HasRecomputedInput(const CNodePtr &k_fg_caller_cnode) {
226 for (auto &input : k_fg_caller_cnode->inputs()) {
227 if (IsPrimitiveCNode(input, prim::kPrimMakeTuple)) {
228 return HasRecomputedInput(input->cast<CNodePtr>());
229 }
230 if (IsPrimitiveCNode(input, prim::kPrimDepend) && HasRecomputedInput(input->cast<CNodePtr>())) {
231 return true;
232 }
233 // The recomputed input should be a tuple_getitem to get the forward part of recomputed k graph.
234 if (!IsPrimitiveCNode(input, prim::kPrimTupleGetItem)) {
235 continue;
236 }
237 auto tmp = input->cast<CNodePtr>()->input(1);
238 auto input_k_fg_caller = tmp;
239 // The forward part may have multiple outputs.
240 if (IsPrimitiveCNode(tmp, prim::kPrimTupleGetItem)) {
241 input_k_fg_caller = tmp->cast<CNodePtr>()->input(1);
242 }
243
244 auto cnode = dyn_cast_ptr<CNode>(input_k_fg_caller);
245 if (cnode == nullptr) {
246 continue;
247 }
248 auto call_fg = GetValueNode<FuncGraphPtr>(cnode->input(0));
249 // The output of recomputed cell would not be recomputed.
250 if (call_fg != nullptr && call_fg->has_flag(FUNC_GRAPH_RECOMPUTE_K_GRAPH) && !IsRecomputeCell(call_fg)) {
251 return true;
252 }
253 }
254 return false;
255 }
256
IsForwardGetterTupleGetItem(const AnfNodePtr & node)257 bool IsForwardGetterTupleGetItem(const AnfNodePtr &node) {
258 if (!IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
259 return false;
260 }
261 auto idx = GetValueNode<Int64ImmPtr>(node->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem));
262 if (idx != nullptr && idx->value() == 0) {
263 return true;
264 }
265 return false;
266 }
267
GetForwardGetter(const FuncGraphManagerPtr & manager,const CNodePtr & node)268 AnfNodePtr GetForwardGetter(const FuncGraphManagerPtr &manager, const CNodePtr &node) {
269 const auto &user_nodes = manager->node_users()[node];
270 auto iter = std::find_if(user_nodes.begin(), user_nodes.end(), [](const auto &node_and_idx) -> bool {
271 return IsForwardGetterTupleGetItem(node_and_idx.first);
272 });
273 if (iter != user_nodes.end()) {
274 return iter->first;
275 }
276 return nullptr;
277 }
278
GetBpropGetter(const FuncGraphManagerPtr & manager,const CNodePtr & node)279 AnfNodePtr GetBpropGetter(const FuncGraphManagerPtr &manager, const CNodePtr &node) {
280 const auto &user_nodes = manager->node_users()[node];
281 for (const auto &iter : user_nodes) {
282 if (IsPrimitiveCNode(iter.first, prim::kPrimTupleGetItem)) {
283 auto idx = GetValueNode<Int64ImmPtr>(iter.first->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem));
284 if (idx != nullptr && idx->value() == 1) {
285 return iter.first;
286 }
287 }
288 }
289 return nullptr;
290 }
291
HasRecomputedOutput(const FuncGraphManagerPtr & manager,const AnfNodePtr & node)292 bool HasRecomputedOutput(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) {
293 // The forward part may have multiple outputs.
294 if (IsOneOfPrimitiveCNode(node, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple, prim::kPrimDepend})) {
295 const auto &user_nodes = manager->node_users()[node];
296 return std::any_of(user_nodes.begin(), user_nodes.end(),
297 [&manager](const auto &iter) { return HasRecomputedOutput(manager, iter.first); });
298 }
299 return IsRecomputeKGraphCaller(node);
300 }
301
GetGradUsers(const FuncGraphManagerPtr & manager,const CNodePtr & node,const CNodePtr & pre_node,std::vector<AnfNodePtr> * grad_users)302 void GetGradUsers(const FuncGraphManagerPtr &manager, const CNodePtr &node, const CNodePtr &pre_node,
303 std::vector<AnfNodePtr> *grad_users) {
304 // The forward part may have multiple outputs.
305 if (IsOneOfPrimitiveCNode(node, {prim::kPrimTupleGetItem, prim::kPrimDepend})) {
306 const auto &user_nodes = manager->node_users()[node];
307 for (const auto &iter : user_nodes) {
308 GetGradUsers(manager, iter.first->cast<CNodePtr>(), node, grad_users);
309 }
310 return;
311 }
312 if (IsGradNode(node)) {
313 const auto &inputs = node->inputs();
314 for (size_t i = 1; i < inputs.size(); ++i) {
315 if (inputs[i] != pre_node && !inputs[i]->isa<ValueNode>() && IsGradNode(inputs[i])) {
316 (void)grad_users->emplace_back(inputs[i]);
317 }
318 }
319 }
320 }
321
IsFromForwardGetter(const AnfNodePtr & forward_getter,const AnfNodePtr & depend_node)322 bool IsFromForwardGetter(const AnfNodePtr &forward_getter, const AnfNodePtr &depend_node) {
323 if (forward_getter == depend_node) {
324 return true;
325 }
326 if (!IsOneOfPrimitiveCNode(depend_node, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple, prim::kPrimZerosLike})) {
327 return false;
328 }
329 const auto &depend_node_inputs = depend_node->cast<CNodePtr>()->inputs();
330 return std::any_of(depend_node_inputs.begin(), depend_node_inputs.end(),
331 [&forward_getter](const auto &input) { return IsFromForwardGetter(forward_getter, input); });
332 }
333
GetDependencies(const FuncGraphManagerPtr & manager,const CNodePtr & k_fg_caller,mindspore::CompactSet<CNodePtr> * final_nodes,mindspore::CompactSet<AnfNodePtr> * dependencies)334 void GetDependencies(const FuncGraphManagerPtr &manager, const CNodePtr &k_fg_caller,
335 mindspore::CompactSet<CNodePtr> *final_nodes, mindspore::CompactSet<AnfNodePtr> *dependencies) {
336 if (final_nodes->find(k_fg_caller) != final_nodes->end()) {
337 return;
338 }
339 bool is_recompute_k_fg_caller = IsRecomputeKGraphCaller(k_fg_caller);
340 // We only handle the recomputed k graph caller.
341 if (!is_recompute_k_fg_caller &&
342 !IsOneOfPrimitiveCNode(k_fg_caller, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple, prim::kPrimDepend})) {
343 return;
344 }
345 if (is_recompute_k_fg_caller) {
346 auto forward_getter = GetForwardGetter(manager, k_fg_caller);
347 // If the k graph caller has no forward getter, it should not output to any other recomputed nodes.
348 if (forward_getter == nullptr) {
349 auto bprop_caller = GetBpropCaller(manager, GetBpropGetter(manager, k_fg_caller));
350 // Add the dout input of its bprop function to the dependencies.
351 if (bprop_caller == nullptr) {
352 return;
353 }
354 (void)final_nodes->insert(k_fg_caller);
355 (void)dependencies->insert(bprop_caller->cast<CNodePtr>()->input(1));
356 return;
357 }
358 if (!HasRecomputedOutput(manager, forward_getter)) {
359 std::vector<AnfNodePtr> grad_users;
360 // Add the other inputs of the grad node to the dependencies.
361 GetGradUsers(manager, forward_getter->cast<CNodePtr>(), k_fg_caller, &grad_users);
362 if (!grad_users.empty()) {
363 for (auto &user : grad_users) {
364 (void)final_nodes->insert(k_fg_caller);
365 (void)dependencies->insert(user);
366 }
367 return;
368 }
369 // Add the dout input of its bprop function to the dependencies.
370 auto bprop_caller = GetBpropCaller(manager, GetBpropGetter(manager, k_fg_caller));
371 if (bprop_caller == nullptr) {
372 return;
373 }
374 (void)final_nodes->insert(k_fg_caller);
375 auto dout = bprop_caller->cast<CNodePtr>()->input(1);
376 if (IsPrimitiveCNode(dout, prim::kPrimMakeTuple) && IsFromForwardGetter(forward_getter, dout)) {
377 return;
378 }
379 (void)dependencies->insert(dout);
380 return;
381 }
382 }
383
384 const auto &user_nodes = manager->node_users()[k_fg_caller];
385 for (const auto &iter : user_nodes) {
386 if (IsPrimitiveCNode(iter.first, prim::kPrimTupleGetItem)) {
387 auto idx = GetValueNode<Int64ImmPtr>(iter.first->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem));
388 // Skip bprop getter.
389 if (idx != nullptr && idx->value() == 1 && is_recompute_k_fg_caller) {
390 continue;
391 }
392 }
393 GetDependencies(manager, iter.first->cast<CNodePtr>(), final_nodes, dependencies);
394 }
395 }
396
CopyOriginalInputs(const FuncGraphPtr & bprop_fg,const CNodePtr & node,const AnfNodePtr & depend_nodes,std::vector<AnfNodePtr> * new_inputs)397 void CopyOriginalInputs(const FuncGraphPtr &bprop_fg, const CNodePtr &node, const AnfNodePtr &depend_nodes,
398 std::vector<AnfNodePtr> *new_inputs) {
399 (void)std::transform(
400 node->inputs().begin(), node->inputs().end(), std::back_inserter(*new_inputs),
401 [&bprop_fg](const AnfNodePtr &input) -> AnfNodePtr {
402 // Make sure there is only one u monad fv.
403 if (HasAbstractUMonad(input) && input->func_graph() != nullptr && input->func_graph() != bprop_fg) {
404 return NewValueNode(kUMonad);
405 }
406 return input;
407 });
408 // The recomputed cell should insert depend node at all inputs.
409 if (!IsRecomputeCell(GetValueNode<FuncGraphPtr>(node->input(0)))) {
410 auto depend = bprop_fg->NewCNode({NewValueNode(prim::kPrimDepend), (*new_inputs)[1], depend_nodes});
411 depend->AddAttr(kRecomputeInsert, MakeValue(true));
412 (*new_inputs)[1] = depend;
413 }
414 }
415
// Returns the counterpart of `node` built in `bprop_fg`, creating it on first request and caching it in
// `origin_to_new_nodes`. `node` itself is returned, uncached, when it is a recomputed k graph caller without the
// kAttrReplacedWithPrimal attribute, when it is neither such a caller nor a
// tuple_getitem/make_tuple/Depend/UpdateState node, or when it is the forward getter of a caller that lacks the
// attribute. For a recomputed caller, the original bprop getter is replaced by one reading from the new caller.
CNodePtr MoveKCallerToBprop(const FuncGraphManagerPtr &manager, const FuncGraphPtr &bprop_fg, const CNodePtr &node,
                            const AnfNodePtr &depend_nodes,
                            std::unordered_map<CNodePtr, CNodePtr> *origin_to_new_nodes) {
  const auto cached = origin_to_new_nodes->find(node);
  if (cached != origin_to_new_nodes->end()) {
    return cached->second;
  }
  std::vector<AnfNodePtr> new_inputs;
  // Fills new_inputs from the inputs of `node`: CNode inputs are moved recursively, the rest are reused as is.
  const auto move_inputs = [&]() {
    for (auto &input : node->inputs()) {
      if (input->isa<CNode>()) {
        (void)new_inputs.emplace_back(
          MoveKCallerToBprop(manager, bprop_fg, input->cast<CNodePtr>(), depend_nodes, origin_to_new_nodes));
      } else {
        (void)new_inputs.emplace_back(input);
      }
    }
  };

  if (IsRecomputeKGraphCaller(node)) {
    if (!node->HasAttr(kAttrReplacedWithPrimal)) {
      return node;
    }
    if (HasRecomputedInput(node)) {
      move_inputs();
    } else {
      CopyOriginalInputs(bprop_fg, node, depend_nodes, &new_inputs);
    }
    if (IsRecomputeCell(GetValueNode<FuncGraphPtr>(node->input(0)))) {
      // Add the dout input of its bprop function to the dependencies.
      auto cell_depend_nodes = depend_nodes;
      auto bprop_caller = GetBpropCaller(manager, GetBpropGetter(manager, node));
      if (bprop_caller != nullptr) {
        const auto depend_cnode = depend_nodes->cast<CNodePtr>();
        const auto &origin_depend_inputs = depend_cnode->inputs();
        std::vector<AnfNodePtr> cell_depend_inputs(origin_depend_inputs.begin(), origin_depend_inputs.end());
        (void)cell_depend_inputs.emplace_back(bprop_caller->cast<CNodePtr>()->input(1));
        cell_depend_nodes = bprop_fg->NewCNode(cell_depend_inputs);
      }
      // Every argument of a recomputed cell gets its own depend.
      for (size_t index = 1; index < new_inputs.size(); ++index) {
        auto arg_depend = bprop_fg->NewCNode({NewValueNode(prim::kPrimDepend), new_inputs[index], cell_depend_nodes});
        arg_depend->AddAttr(kRecomputeInsert, MakeValue(true));
        new_inputs[index] = arg_depend;
      }
    }
    auto moved_caller = bprop_fg->NewCNode(new_inputs);
    moved_caller->AddAttr(kAddedRecomputeDependAttr, MakeValue(true));
    moved_caller->AddAttr(kAttrReplacedWithPrimal, MakeValue(true));
    auto primal_fg_caller = node->user_data<CNode>(kPrimalFgCallerUserDataKey);
    if (primal_fg_caller != nullptr) {
      moved_caller->set_user_data(kPrimalFgCallerUserDataKey, primal_fg_caller);
    }
    // Replace the bprop getter with the new k graph caller in bprop graph.
    auto origin_bprop_getter = GetBpropGetter(manager, node);
    if (origin_bprop_getter != nullptr) {
      auto moved_bprop_getter = bprop_fg->NewCNodeInOrder(
        {NewValueNode(prim::kPrimTupleGetItem), moved_caller, NewValueNode(static_cast<int64_t>(1))});
      moved_bprop_getter->set_abstract(origin_bprop_getter->abstract());
      (void)manager->Replace(origin_bprop_getter, moved_bprop_getter);
    }
    (void)origin_to_new_nodes->emplace(node, moved_caller);
    return moved_caller;
  }

  // If it is not tuple_getitem, it should be node which is not set recomputed.
  if (!IsOneOfPrimitiveCNode(
        node, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple, prim::kPrimDepend, prim::kPrimUpdateState})) {
    return node;
  }
  // If the other branch has not been handled, it should not create new forward getter.
  if (IsForwardGetterTupleGetItem(node)) {
    auto getter_source = node->cast<CNodePtr>()->input(1);
    if (IsRecomputeKGraphCaller(getter_source) &&
        !getter_source->cast<CNodePtr>()->HasAttr(kAttrReplacedWithPrimal)) {
      return node;
    }
  }
  move_inputs();
  auto moved_node = bprop_fg->NewCNode(new_inputs);
  (void)origin_to_new_nodes->emplace(node, moved_node);
  return moved_node;
}
499
GetKGraphCallerFromTupleGetitem(const AnfNodePtr & node)500 CNodePtr GetKGraphCallerFromTupleGetitem(const AnfNodePtr &node) {
501 auto idx = GetValueNode<Int64ImmPtr>(node->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem));
502 // The k_fg_caller return a tuple of forward result and bprop.
503 if (idx == nullptr || idx->value() != 0) {
504 return nullptr;
505 }
506 auto k_fg_caller = node->cast<CNodePtr>()->input(1);
507 MS_EXCEPTION_IF_NULL(k_fg_caller);
508 return k_fg_caller->cast<CNodePtr>();
509 }
510
ReplaceFinalForwardGetter(const FuncGraphManagerPtr & manager,const FuncGraphPtr & fg,const AnfNodePtr & origin_forward_getter,const AnfNodePtr & new_forward_getter)511 void ReplaceFinalForwardGetter(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg,
512 const AnfNodePtr &origin_forward_getter, const AnfNodePtr &new_forward_getter) {
513 auto node_users = manager->node_users()[origin_forward_getter];
514 for (auto &node_and_idx : node_users) {
515 auto user = node_and_idx.first;
516 MS_EXCEPTION_IF_NULL(user);
517 MS_LOG(DEBUG) << "User: " << user->DebugString();
518 // The forward part may have multiple outputs.
519 if (IsPrimitiveCNode(user, prim::kPrimTupleGetItem)) {
520 // Make new tuple_getitem to get corresponding output.
521 auto new_getitem = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), new_forward_getter,
522 user->cast_ptr<CNode>()->input(kInputNodeOutputIndexInTupleGetItem)});
523 ReplaceFinalForwardGetter(manager, fg, user, new_getitem);
524 continue;
525 }
526 if (IsPrimitiveCNode(user, prim::kPrimDepend)) {
527 // Make new depend to get corresponding output.
528 auto new_depend = fg->NewCNode(user->cast_ptr<CNode>()->inputs());
529 new_depend->set_input(IntToSize(node_and_idx.second), new_forward_getter);
530 ReplaceFinalForwardGetter(manager, fg, user, new_depend);
531 continue;
532 }
533 MS_LOG(DEBUG) << "Set edge for user: " << user->DebugString();
534 manager->SetEdge(user, node_and_idx.second, new_forward_getter);
535 }
536 }
537
GetAllRecomputeKFgCallers(const CNodePtr & final_node,mindspore::HashSet<CNodePtr> * recompute_k_fg_callers)538 void GetAllRecomputeKFgCallers(const CNodePtr &final_node, mindspore::HashSet<CNodePtr> *recompute_k_fg_callers) {
539 for (const auto &input : final_node->inputs()) {
540 if (!input->isa<CNode>()) {
541 continue;
542 }
543 auto input_cnode = input->cast<CNodePtr>();
544 if (IsPrimitiveCNode(input_cnode, prim::kPrimTupleGetItem)) {
545 GetAllRecomputeKFgCallers(input_cnode, recompute_k_fg_callers);
546 continue;
547 }
548 // Only get the nodes visited in this round.
549 if (!input_cnode->HasAttr(kAttrReplacedWithPrimal) || !IsRecomputeKGraphCaller(input) ||
550 recompute_k_fg_callers->find(input_cnode) != recompute_k_fg_callers->end()) {
551 continue;
552 }
553 (void)recompute_k_fg_callers->insert(input_cnode);
554 GetAllRecomputeKFgCallers(input_cnode, recompute_k_fg_callers);
555 }
556 }
557
IsFromRecomputeKFgCaller(const FuncGraphPtr & bprop_fg,const mindspore::HashSet<CNodePtr> & recompute_k_fg_callers,const CNodePtr & node,mindspore::HashMap<CNodePtr,bool> * is_from_recompute_k_fg_caller)558 bool IsFromRecomputeKFgCaller(const FuncGraphPtr &bprop_fg, const mindspore::HashSet<CNodePtr> &recompute_k_fg_callers,
559 const CNodePtr &node, mindspore::HashMap<CNodePtr, bool> *is_from_recompute_k_fg_caller) {
560 auto iter = is_from_recompute_k_fg_caller->find(node);
561 if (iter != is_from_recompute_k_fg_caller->end()) {
562 return iter->second;
563 }
564 if (recompute_k_fg_callers.find(node) != recompute_k_fg_callers.end()) {
565 (void)is_from_recompute_k_fg_caller->emplace(node, true);
566 return true;
567 }
568
569 for (const auto &input : node->inputs()) {
570 MS_EXCEPTION_IF_NULL(input);
571 if (!input->isa<CNode>()) {
572 continue;
573 }
574 auto input_cnode = input->cast<CNodePtr>();
575 if (input_cnode->func_graph() != bprop_fg) {
576 AnfNodePtr cur_node = input_cnode;
577 while (IsPrimitiveCNode(cur_node, prim::kPrimTupleGetItem)) {
578 cur_node = cur_node->cast<CNodePtr>()->input(1);
579 }
580 if (cur_node->isa<CNode>() &&
581 recompute_k_fg_callers.find(cur_node->cast<CNodePtr>()) != recompute_k_fg_callers.end()) {
582 (void)is_from_recompute_k_fg_caller->emplace(node, true);
583 return true;
584 }
585 continue;
586 }
587 if (IsFromRecomputeKFgCaller(bprop_fg, recompute_k_fg_callers, input_cnode, is_from_recompute_k_fg_caller)) {
588 (void)is_from_recompute_k_fg_caller->emplace(node, true);
589 return true;
590 }
591 }
592 (void)is_from_recompute_k_fg_caller->emplace(node, false);
593 return false;
594 }
595
AddDependNodes(const FuncGraphManagerPtr & manager,const FuncGraphPtr & fg,const CNodePtr & k_fg_caller_cnode)596 void AddDependNodes(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, const CNodePtr &k_fg_caller_cnode) {
597 // Get the nodes which the recomputed part should depend on;
598 mindspore::CompactSet<CNodePtr> final_nodes;
599 mindspore::CompactSet<AnfNodePtr> dependencies;
600 GetDependencies(manager, k_fg_caller_cnode, &final_nodes, &dependencies);
601 if (dependencies.empty()) {
602 return;
603 }
604 FuncGraphPtr bprop_fg;
605 auto bprop_caller = GetBpropCaller(manager, GetBpropGetter(manager, k_fg_caller_cnode));
606 if (bprop_caller == nullptr) {
607 bprop_fg = (*dependencies.begin())->func_graph();
608 } else {
609 bprop_fg = bprop_caller->func_graph();
610 }
611 MS_EXCEPTION_IF_NULL(bprop_fg);
612 // Filter the dependent nodes in case of producing loops.
613 mindspore::HashSet<CNodePtr> recompute_k_fg_callers;
614 for (const auto &final_node : final_nodes) {
615 (void)recompute_k_fg_callers.insert(final_node);
616 GetAllRecomputeKFgCallers(final_node, &recompute_k_fg_callers);
617 }
618 std::vector<AnfNodePtr> depend_inputs{NewValueNode(prim::kPrimMakeTuple)};
619 mindspore::HashMap<CNodePtr, bool> is_from_recompute_k_fg_caller;
620 (void)std::copy_if(dependencies.begin(), dependencies.end(), std::back_inserter(depend_inputs),
621 [bprop_fg, &recompute_k_fg_callers, &is_from_recompute_k_fg_caller](const AnfNodePtr &dependency) {
622 if (!dependency->isa<CNode>()) {
623 return true;
624 }
625 return !IsFromRecomputeKFgCaller(bprop_fg, recompute_k_fg_callers, dependency->cast<CNodePtr>(),
626 &is_from_recompute_k_fg_caller);
627 });
628 // Add the dependency nodes to the first recomputed nodes.
629 auto depend_nodes = bprop_fg->NewCNode(depend_inputs);
630 if (bprop_fg == fg) {
631 if (!IsRecomputeCell(GetValueNode<FuncGraphPtr>(k_fg_caller_cnode->input(0)))) {
632 auto depend = fg->NewCNode({NewValueNode(prim::kPrimDepend), k_fg_caller_cnode->input(1), depend_nodes});
633 depend->AddAttr(kRecomputeInsert, MakeValue(true));
634 manager->SetEdge(k_fg_caller_cnode, 1, depend);
635 k_fg_caller_cnode->AddAttr(kAddedRecomputeDependAttr, MakeValue(true));
636 } else {
637 std::vector<AnfNodePtr> new_k_fg_caller_inputs{k_fg_caller_cnode->input(0)};
638 (void)std::transform(k_fg_caller_cnode->inputs().begin() + 1, k_fg_caller_cnode->inputs().end(),
639 std::back_inserter(new_k_fg_caller_inputs),
640 [&fg, &depend_nodes](const AnfNodePtr &input) -> AnfNodePtr {
641 auto depend = fg->NewCNodeInOrder({NewValueNode(prim::kPrimDepend), input, depend_nodes});
642 depend->AddAttr(kRecomputeInsert, MakeValue(true));
643 return depend;
644 });
645 auto new_k_fg_caller = fg->NewCNodeInOrder(new_k_fg_caller_inputs);
646 auto primal_fg_caller = k_fg_caller_cnode->user_data<CNode>(kPrimalFgCallerUserDataKey);
647 if (primal_fg_caller != nullptr) {
648 new_k_fg_caller->set_user_data(kPrimalFgCallerUserDataKey, primal_fg_caller);
649 }
650 (void)manager->Replace(k_fg_caller_cnode, new_k_fg_caller);
651 new_k_fg_caller->AddAttr(kAddedRecomputeDependAttr, MakeValue(true));
652 new_k_fg_caller->AddAttr(kAttrReplacedWithPrimal, MakeValue(true));
653 }
654 return;
655 }
656 // If the graph of the bprop caller is not the same as the graph of k graph caller, we should move the k graph
657 // caller to the graph of the bprop.
658 std::unordered_map<CNodePtr, CNodePtr> origin_to_new_nodes;
659 for (const auto &final_node : final_nodes) {
660 auto new_k_fg_caller = MoveKCallerToBprop(manager, bprop_fg, final_node, depend_nodes, &origin_to_new_nodes);
661 new_k_fg_caller->AddAttr(kAddedRecomputeDependAttr, MakeValue(true));
662 }
663 for (auto &iter : origin_to_new_nodes) {
664 if (!IsRecomputeKGraphCaller(iter.first)) {
665 continue;
666 }
667 auto forward_getter = GetForwardGetter(manager, iter.first);
668 if (forward_getter == nullptr) {
669 (void)manager->Replace(iter.first, iter.second);
670 } else {
671 auto new_forward_getter =
672 bprop_fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), iter.second, NewValueNode(static_cast<int64_t>(0))});
673 ReplaceFinalForwardGetter(manager, bprop_fg, forward_getter, new_forward_getter);
674 }
675 }
676 }
677
AddDuplicatedAttr(const FuncGraphPtr & k_fg)678 void AddDuplicatedAttr(const FuncGraphPtr &k_fg) {
679 for (const auto &node : k_fg->nodes()) {
680 if (!node->isa<CNode>()) {
681 continue;
682 }
683 node->cast_ptr<CNode>()->AddAttr(kAttrDuplicated, MakeValue(true));
684 }
685 }
686
AddCseAttr(const FuncGraphPtr & root,bool changed)687 void AddCseAttr(const FuncGraphPtr &root, bool changed) {
688 if (!changed) {
689 return;
690 }
691 auto all_node = TopoSort(root->get_return(), SuccDeeperSimple, AlwaysInclude);
692 for (const auto &node : all_node) {
693 if (WithRecomputedScope(node)) {
694 node->cast<CNodePtr>()->AddAttr(kAttrNeedCseAfterRecompute, MakeValue(true));
695 }
696 }
697 }
698
GetPrimal(const FuncGraphPtr & k_fg,bool * recompute_cell)699 AnfNodePtr GetPrimal(const FuncGraphPtr &k_fg, bool *recompute_cell) {
700 auto primal_iter = k_fg->transforms().find("primal");
701 if (primal_iter == k_fg->transforms().end()) {
702 return nullptr;
703 }
704 AnfNodePtr primal = nullptr;
705 auto primal_fg = primal_iter->second.func_graph();
706 if (primal_fg != nullptr) {
707 primal = NewValueNode(primal_fg);
708 *recompute_cell = true;
709 } else {
710 auto primal_primitive = primal_iter->second.primitive();
711 if (primal_primitive != nullptr) {
712 primal = NewValueNode(primal_primitive);
713 }
714 }
715 return primal;
716 }
717
IsNestedRecomputed(const AnfNodePtr & node)718 bool IsNestedRecomputed(const AnfNodePtr &node) {
719 auto fg = node->func_graph();
720 MS_EXCEPTION_IF_NULL(fg);
721 return fg->has_flag(FUNC_GRAPH_RECOMPUTE_K_GRAPH);
722 }
723
SetPrimalAttrs(const CNodePtr & new_primal,const FuncGraphPtr & k_fg)724 void SetPrimalAttrs(const CNodePtr &new_primal, const FuncGraphPtr &k_fg) {
725 auto forward_in_k_fg = GetPrimalFromFprop(k_fg);
726 auto forward_cnode_in_k_fg = dyn_cast<CNode>(forward_in_k_fg);
727 if (forward_cnode_in_k_fg != nullptr) {
728 new_primal->set_primal_attrs(forward_cnode_in_k_fg->primal_attrs());
729 }
730 }
731 } // namespace
732
AddRecomputeNodes(const FuncGraphPtr & root,const opt::OptimizerPtr & opt)733 bool AddRecomputeNodes(const FuncGraphPtr &root, const opt::OptimizerPtr &opt) {
734 if (!EnableCellReuse()) {
735 return false;
736 }
737 #ifdef ENABLE_DUMP_IR
738 auto context = MsContext::GetInstance();
739 MS_EXCEPTION_IF_NULL(context);
740 bool enable_save_graphs = context->CanDump(kIntroductory);
741 if (enable_save_graphs) {
742 DumpIR("before_recompute_root.ir", root);
743 }
744 #endif
745 MS_EXCEPTION_IF_NULL(root);
746 MS_EXCEPTION_IF_NULL(opt);
747 auto manager = opt->manager();
748 MS_EXCEPTION_IF_NULL(manager);
749 bool changed = false;
750 auto all_node = TopoSort(root->get_return(), SuccDeeperSimple, AlwaysInclude);
751 for (auto iter = all_node.crbegin(); iter != all_node.crend(); (void)iter++) {
752 const auto &node = *iter;
753 if (!IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
754 continue;
755 }
756 auto k_fg_caller_cnode = GetKGraphCallerFromTupleGetitem(node);
757 if (k_fg_caller_cnode == nullptr || k_fg_caller_cnode->HasAttr(kAddedRecomputeDependAttr)) {
758 continue;
759 }
760 auto k_fg = GetValueNode<FuncGraphPtr>(k_fg_caller_cnode->input(0));
761 if (k_fg == nullptr || !k_fg->has_flag(FUNC_GRAPH_RECOMPUTE_K_GRAPH)) {
762 continue;
763 }
764 if (IsNestedRecomputed(k_fg_caller_cnode)) {
765 MS_LOG(WARNING)
766 << "The node and its graph both have been set recomputed, the node would not be handled. The node: "
767 << k_fg_caller_cnode->DebugString();
768 continue;
769 }
770 bool recompute_cell = false;
771 auto primal = GetPrimal(k_fg, &recompute_cell);
772 if (primal == nullptr) {
773 continue;
774 }
775 // Replace the forward getter with the origin primal.
776 constexpr auto recursive_level = 2;
777 MS_LOG(DEBUG) << "Handle recompute k graph forward getter: " << node->DebugString(recursive_level);
778 std::vector<AnfNodePtr> inputs{primal};
779 (void)inputs.insert(inputs.cend(), k_fg_caller_cnode->inputs().begin() + 1, k_fg_caller_cnode->inputs().end());
780 auto fg = node->func_graph();
781 MS_EXCEPTION_IF_NULL(fg);
782 auto new_primal = fg->NewCNodeInOrder(inputs);
783 if (IsValueNode<Primitive>(primal)) {
784 SetPrimalAttrs(new_primal, k_fg);
785 }
786 std::unordered_map<AnfNodePtr, AnfNodePtr> origin_to_new_primal;
787 bool change = AddNewPrimalNode(manager, fg, node, new_primal, recompute_cell, &origin_to_new_primal);
788 changed = change || changed;
789 if (change && recompute_cell) {
790 k_fg_caller_cnode->set_user_data(kPrimalFgCallerUserDataKey, new_primal);
791 }
792 k_fg_caller_cnode->AddAttr(kAttrReplacedWithPrimal, MakeValue(true));
793 // Add duplicated attr to help debugging.
794 AddDuplicatedAttr(k_fg);
795 if (HasRecomputedInput(k_fg_caller_cnode)) {
796 continue;
797 }
798
799 MS_LOG(DEBUG) << "Not has recomputed input k_fg_caller_cnode: " << k_fg_caller_cnode->DebugString();
800 AddDependNodes(manager, fg, k_fg_caller_cnode);
801 }
802 AddCseAttr(root, changed);
803 #ifdef ENABLE_DUMP_IR
804 if (enable_save_graphs) {
805 DumpIR("after_recompute_root.ir", root);
806 }
807 #endif
808 return changed;
809 }
810 } // namespace irpass
811 } // namespace opt
812 } // namespace mindspore
813