• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "backend/common/graph_kernel/graph_kernel_recompute.h"
18 
19 #include <algorithm>
20 #include <deque>
21 #include <functional>
22 #include <limits>
23 #include <map>
24 #include <memory>
25 #include <queue>
26 #include <set>
27 #include <sstream>
28 #include <stack>
29 #include <tuple>
30 #include <utility>
31 #include <vector>
32 #include "mindspore/core/ops/sparse_ops.h"
33 #include "mindspore/core/ops/sequence_ops.h"
34 #include "mindspore/core/ops/math_ops.h"
35 #include "mindspore/core/ops/array_ops.h"
36 #include "mindspore/core/ops/framework_ops.h"
37 #include "kernel/framework_utils.h"
38 #include "backend/common/graph_kernel/graph_kernel_helper.h"
39 #include "backend/common/graph_kernel/core/graph_builder.h"
40 #include "ir/func_graph_cloner.h"
41 
42 namespace mindspore::graphkernel {
43 namespace {
GetGetitemIndex(const AnfNodePtr & getitem)44 int64_t GetGetitemIndex(const AnfNodePtr &getitem) {
45   auto vnode = GetValueNode(getitem->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem));
46   return GetValue<int64_t>(vnode);
47 }
48 
GetOutput(const FuncGraphPtr & func_graph,size_t i)49 AnfNodePtr GetOutput(const FuncGraphPtr &func_graph, size_t i) {
50   auto output_node = func_graph->output()->cast<CNodePtr>();
51   MS_EXCEPTION_IF_NULL(output_node);
52   if (IsPrimitiveCNode(output_node, prim::kPrimMakeTuple)) {
53     if (i + 1 >= output_node->size()) {
54       MS_LOG(EXCEPTION) << i << " is out of range of MakeTuple's size " << output_node->size();
55     }
56     return output_node->input(i + 1);
57   } else {
58     if (i > 0) {
59       MS_LOG(EXCEPTION) << "the graph is single output but i is not 0. it's " << i;
60     }
61     return output_node->cast<AnfNodePtr>();
62   }
63 }
64 
IsExclude(const AnfNodePtr & node)65 bool IsExclude(const AnfNodePtr &node) {
66   static std::vector<PrimitivePtr> excludes = {prim::kPrimReturn, prim::kPrimUpdateState, prim::kPrimLoad,
67                                                prim::kPrimMakeTuple, prim::kPrimDepend};
68   return std::any_of(excludes.begin(), excludes.end(),
69                      [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
70 }
71 
72 enum class VisitType : char { FOLLOW, STOP };
73 using VisitFunc = std::function<VisitType(const AnfNodePtr &)>;
74 using NextFunc = std::function<AnfNodePtrList(const AnfNodePtr &)>;
75 using ProcessFunc = std::function<void(const AnfNodePtr &)>;
76 
Dfs(const AnfNodePtr & current,const VisitFunc & visit_func,const NextFunc & next_func,const ProcessFunc & before_func,const ProcessFunc & after_func,std::set<AnfNodePtr> * visited)77 void Dfs(const AnfNodePtr &current, const VisitFunc &visit_func, const NextFunc &next_func,
78          const ProcessFunc &before_func, const ProcessFunc &after_func, std::set<AnfNodePtr> *visited) {
79   if (visited->count(current) > 0) {
80     return;
81   }
82   (void)visited->insert(current);
83   if (visit_func(current) != VisitType::FOLLOW) {
84     return;
85   }
86 
87   for (const auto &next : next_func(current)) {
88     before_func(next);
89     Dfs(next, visit_func, next_func, before_func, after_func, visited);
90     after_func(next);
91   }
92 }
93 
CollectLinkPaths(const std::map<AnfNodePtr,MemorySize> & topo_indice,const OrderedSet<AnfNodePtr> & direct_users,MemorySize max_topo_user_index,const FuncGraphManagerPtr & mng)94 OrderedMap<AnfNodePtr, AnfNodePtrList> CollectLinkPaths(const std::map<AnfNodePtr, MemorySize> &topo_indice,
95                                                         const OrderedSet<AnfNodePtr> &direct_users,
96                                                         MemorySize max_topo_user_index,
97                                                         const FuncGraphManagerPtr &mng) {
98   std::stack<AnfNodePtr> cur_stack;
99   OrderedMap<AnfNodePtr, AnfNodePtrList> link_paths;
100   auto TmpVisitFunc = [&topo_indice, max_topo_user_index](const AnfNodePtr &n) -> VisitType {
101     if (IsExclude(n)) {
102       return VisitType::STOP;
103     }
104 
105     auto iter = topo_indice.find(n);
106     if (iter == topo_indice.end()) {
107       MS_LOG(EXCEPTION) << "Cannot find " << n->fullname_with_scope() << " in topo indices!";
108     }
109     if (iter->second > max_topo_user_index) {
110       return VisitType::STOP;
111     }
112     return VisitType::FOLLOW;
113   };
114 
115   auto TmpNextFunc = [&mng](const AnfNodePtr &n) -> AnfNodePtrList {
116     auto users = mng->node_users()[n];
117     AnfNodePtrList nexts;
118     (void)std::transform(users.cbegin(), users.cend(), std::back_inserter(nexts),
119                          [](const std::pair<AnfNodePtr, int> &user) { return user.first; });
120     return nexts;
121   };
122 
123   auto TmpBeforeFunc = [&link_paths, &cur_stack, &direct_users](const AnfNodePtr &next) -> void {
124     if (direct_users.count(next) == 0) {
125       return;
126     }
127     auto cur_node = cur_stack.top();
128     if (link_paths.find(cur_node) == link_paths.end()) {
129       (void)link_paths.emplace(cur_node, AnfNodePtrList());
130     }
131     link_paths[cur_node].push_back(next);
132     cur_stack.push(next);
133   };
134 
135   auto TmpAfterFunc = [&cur_stack, &direct_users](const AnfNodePtr &next) -> void {
136     if (direct_users.count(next) == 0) {
137       return;
138     }
139     cur_stack.push(next);
140   };
141 
142   std::set<AnfNodePtr> visited;
143   for (auto user : direct_users) {
144     cur_stack.push(user);
145     Dfs(user, TmpVisitFunc, TmpNextFunc, TmpBeforeFunc, TmpAfterFunc, &visited);
146     cur_stack.pop();
147   }
148 
149   return link_paths;
150 }
151 
GetLongTermNodes(const AnfNodePtrList & nodes,const AnfNodePtr & end_node,const std::map<AnfNodePtr,MemorySize> & topo_indices,const FuncGraphManagerPtr & mng)152 OrderedSet<AnfNodePtr> GetLongTermNodes(const AnfNodePtrList &nodes, const AnfNodePtr &end_node,
153                                         const std::map<AnfNodePtr, MemorySize> &topo_indices,
154                                         const FuncGraphManagerPtr &mng) {
155   OrderedSet<AnfNodePtr> long_term_nodes;
156   for (auto node : nodes) {
157     auto real_node = common::AnfAlgo::VisitKernelWithReturnType(node, 0).first;
158     // Parameter or value have long term tensors.
159     if (!utils::isa<CNodePtr>(real_node)) {
160       (void)long_term_nodes.insert(node);
161       continue;
162     }
163 
164     auto users = mng->node_users()[node];
165     if (std::any_of(users.cbegin(), users.cend(), [&topo_indices, &end_node](const std::pair<AnfNodePtr, int> &user) {
166           auto user_topo = topo_indices.find(user.first);
167           auto end_topo = topo_indices.find(end_node);
168           return user_topo->second >= end_topo->second;
169         })) {
170       (void)long_term_nodes.insert(node);
171     }
172   }
173   return long_term_nodes;
174 }
175 
176 /**
177  * @brief Remove real input which is not used and change the related graph parameters.
178  *
179  * @param func_graph Graph.
180  * @param inputs Real inputs for graph cnode.
181  */
ElimRedundantInputsAndGraphParameters(const FuncGraphPtr & func_graph,AnfNodePtrList * inputs)182 void ElimRedundantInputsAndGraphParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs) {
183   MS_EXCEPTION_IF_NULL(inputs);
184   const auto &ori_parameter = func_graph->parameters();
185   auto nodes = TopoSort(func_graph->get_return());
186   std::set<AnfNodePtr> used_param;
187   for (auto node : nodes) {
188     if (node->isa<Parameter>()) {
189       (void)used_param.insert(node);
190     }
191   }
192   if (used_param.size() == ori_parameter.size()) {
193     return;
194   }
195   AnfNodePtrList new_parameter, new_inputs;
196   for (size_t i = 0; i < ori_parameter.size(); ++i) {
197     if (used_param.count(ori_parameter[i]) != 0) {
198       new_parameter.push_back(ori_parameter[i]);
199       new_inputs.push_back((*inputs)[i]);
200     }
201   }
202   func_graph->set_parameters(new_parameter);
203   *inputs = std::move(new_inputs);
204 }
205 }  // namespace
206 
Run(const FuncGraphPtr & func_graph)207 std::vector<Candidate> AutoRecompute::Run(const FuncGraphPtr &func_graph) {
208   lifetime_threshold_ = GraphKernelFlags::GetInstance().recompute_increment_threshold;
209   local_peak_threshold_ = GraphKernelFlags::GetInstance().recompute_peak_threshold;
210   if (!IsThresholdDefaultValue()) {
211     FindCandidates(func_graph);
212   }
213   return candidates_;
214 }
215 
216 /**
217  * @brief Filter the input tensor(that live longer than end node) out and return valid inputs for memory calculation. \n
218  *        If the topo indices of the input's user is at least one greater than end_node,                              \n
219  *        it will retain when after end_node's execution.
220  *
221  * @param source_node
222  * @param end_node
223  * @param edge_pos
224  * @param mng
225  * @return AnfNodePtrList
226  */
Filter(const AnfNodePtr & source_node,const AnfNodePtr & end_node,int edge_pos,const FuncGraphManagerPtr & mng)227 AnfNodePtrList AutoRecompute::Filter(const AnfNodePtr &source_node, const AnfNodePtr &end_node, int edge_pos,
228                                      const FuncGraphManagerPtr &mng) {
229   auto source_cnode = source_node->cast<CNodePtr>();
230   MS_EXCEPTION_IF_NULL(source_cnode);
231   AnfNodePtrList node_inputs(source_cnode->inputs().begin() + 1, source_cnode->inputs().end());
232   OrderedSet<AnfNodePtr> long_term_inputs = GetLongTermNodes(node_inputs, end_node, topo_indice_, mng);
233 
234   AnfNodePtrList check_inputs;
235   if (IsPrimitiveCNode(end_node->cast<CNodePtr>()->input(IntToSize(edge_pos)), prim::kPrimTupleGetItem)) {
236     auto out_index = GetSourceLinkOutPos(end_node, edge_pos);
237     auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(source_node);
238     auto out = sub_graph->output();
239     if (!IsPrimitiveCNode(out, prim::kPrimMakeTuple)) {
240       MS_LOG(EXCEPTION) << "Expect MakeTuple node, but got " << common::AnfAlgo::GetCNodeName(out);
241     }
242 
243     // Find subgraph's input according to edge node.
244     auto start_node = out->cast<CNodePtr>()->input(IntToSize(out_index + 1));
245     AnfNodePtrList sub_input_parameters;
246     std::queue<AnfNodePtr> node_q;
247     node_q.push(start_node);
248     while (!node_q.empty()) {
249       auto cur = node_q.front();
250       node_q.pop();
251       if (utils::isa<ParameterPtr>(cur)) {
252         sub_input_parameters.push_back(cur);
253       }
254       auto cur_cnode = cur->cast<CNodePtr>();
255       if (cur_cnode) {
256         for (size_t i = 1; i < cur_cnode->size(); ++i) {
257           node_q.push(cur_cnode->input(i));
258         }
259       }
260     }
261 
262     // Filte input that user's topo index is great than source graph.
263     for (auto para : sub_input_parameters) {
264       for (size_t i = 0; i < sub_graph->parameters().size(); ++i) {
265         if (para == sub_graph->parameters()[i]) {
266           check_inputs.push_back(node_inputs[i]);
267         }
268       }
269     }
270   } else {
271     check_inputs = node_inputs;
272   }
273 
274   AnfNodePtrList res;
275   for (auto input : check_inputs) {
276     if (long_term_inputs.count(input) == 0) {
277       res.push_back(input);
278     }
279   }
280 
281   return res;
282 }
283 
284 /**
285  * @brief Get valid users information by giving node, excluding TupleGetItem, Load and so on.
286  */
GetValidUsers(const AnfNodePtr & node,const FuncGraphManagerPtr & mng)287 std::tuple<OrderedSet<AnfNodePtr>, OutPosLinkMap, MemorySize> AutoRecompute::GetValidUsers(
288   const AnfNodePtr &node, const FuncGraphManagerPtr &mng) {
289   auto &user_map = mng->node_users();
290   auto users = user_map[node];
291   MemorySize max_topo_user_index = 0;
292   std::queue<std::pair<AnfNodePtr, int>> users_queue;
293   for (auto user_index : users) {
294     users_queue.push(user_index);
295   }
296   OrderedSet<AnfNodePtr> direct_users;
297   OutPosLinkMap user_edge_pos;
298   while (!users_queue.empty()) {
299     auto [user, index] = users_queue.front();
300     users_queue.pop();
301     if (IsPrimitiveCNode(user, prim::kPrimTupleGetItem)) {
302       for (auto get_item_user : user_map[user]) {
303         users_queue.push(get_item_user);
304       }
305       continue;
306     } else if (IsExclude(user)) {
307       continue;
308     }
309     user_edge_pos[user].push_back(index);
310     (void)direct_users.insert(user);
311     // Update maximum topo value.
312     if (topo_indice_[user] > max_topo_user_index) {
313       max_topo_user_index = topo_indice_[user];
314     }
315   }
316 
317   return {direct_users, user_edge_pos, max_topo_user_index};
318 }
319 
320 /**
321  * @brief Judege target node for recompute according to current node, and capture source node information when find   \n
322  *        target. There two type for tensor of the edge between source node and target node, example:                 \n
323  *          source ──[Short-Term]── A ── other                                                                        \n
324  *             │                           │                                                                          \n
325  *             └───────[Long-Term]────── target                                                                       \n
326  *          For this example,                                                                                         \n
327  *          1. There are two path from source node to target node, and target is directly user for source node,       \n
328  *             so the tensor of their edge is a long-term tensor.                                                     \n
329  *          2. From source node to A, there is only one path, and A is directly user for source node,                 \n
330  *             so the tensor of their edge is a short-term tensor.
331  *
332  * @param node Source node.
333  * @param mng Graph manager.
334  * @return OutPosLinkList Vector[Tuple(target node, input positions of target node for edge, edge type)].
335  */
JudegeTargetAndCaptureSource(const AnfNodePtr & node,const FuncGraphManagerPtr & mng)336 OutPosLinkList AutoRecompute::JudegeTargetAndCaptureSource(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) {
337   auto [direct_users, user_edge_pos, max_topo_user_index] = GetValidUsers(node, mng);
338   OutPosLinkList target_link_infos;
339   OrderedSet<AnfNodePtr> long_term_users;
340   // If the number of direct users is less than 2, there will no side way to its user....
341   if (direct_users.size() >= 2) {
342     OrderedMap<AnfNodePtr, AnfNodePtrList> link_paths =
343       CollectLinkPaths(topo_indice_, direct_users, max_topo_user_index, mng);
344     for (const auto &[source, paths] : link_paths) {
345       for (auto target : paths) {
346         if (target != source) {
347           (void)target_link_infos.emplace_back(target, user_edge_pos[target], EdgeLifeTimeType::LongTerm);
348           (void)long_term_users.insert(target);
349         }
350       }
351     }
352   }
353 
354   // Direct users include long term users and short term users.
355   // If the short term user is graph kernel composite node, it may be absorb and reduce the local peak memory.
356   for (const auto &user : direct_users) {
357     if (long_term_users.count(user) == 0 && common::AnfAlgo::IsGraphKernel(user)) {
358       (void)target_link_infos.emplace_back(user, user_edge_pos[user], EdgeLifeTimeType::ShortTerm);
359     }
360   }
361 
362   RecomputeLinkEdgeLog(node, direct_users, target_link_infos);
363   return target_link_infos;
364 }
365 
366 /**
367  * @brief Get position of edge tensor between source node and target node.     \n
368  *        For example, giving target node and edge position 0, will return 1:  \n
369  *          source node                                                        \n
370  *          [0] [1] [2]  <- output position                                    \n
371  *               |                                                             \n
372  *               |                                                             \n
373  *              /                                                              \n
374  *            [0] [1]    <- input position                                     \n
375  *          target node
376  *
377  * @param target Target node.
378  * @param pos The input position of target node for edge.
379  * @return int The output position of source node for edge.
380  */
GetSourceLinkOutPos(const AnfNodePtr & target,int pos) const381 int AutoRecompute::GetSourceLinkOutPos(const AnfNodePtr &target, int pos) const {
382   // If the input is get-item, than use get-item's index, otherwise zero.
383   auto cnode = target->cast<CNodePtr>();
384   MS_EXCEPTION_IF_NULL(cnode);
385   auto prenode = cnode->input(IntToSize(pos));
386   if (!IsPrimitiveCNode(prenode, prim::kPrimTupleGetItem)) {
387     return 0;
388   }
389 
390   auto get_item_cnode = prenode->cast<CNodePtr>();
391   MS_EXCEPTION_IF_NULL(get_item_cnode);
392   auto value_input = get_item_cnode->input(kInputNodeOutputIndexInTupleGetItem);
393   MS_EXCEPTION_IF_NULL(value_input);
394   auto value_node = value_input->cast<ValueNodePtr>();
395   MS_EXCEPTION_IF_NULL(value_node);
396   return static_cast<int>(GetValue<int64_t>(value_node->value()));
397 }
398 
SelectThreshold(EdgeLifeTimeType type) const399 MemorySize AutoRecompute::SelectThreshold(EdgeLifeTimeType type) const {
400   MemorySize threshold = 0;
401   auto local_peak_th = local_peak_threshold_ == 0 ? std::numeric_limits<MemorySize>::max() : local_peak_threshold_;
402   auto lifetime_th = lifetime_threshold_ == 0 ? std::numeric_limits<MemorySize>::max() : lifetime_threshold_;
403   if (type == EdgeLifeTimeType::ShortTerm) {
404     threshold = local_peak_th;
405   } else if (type == EdgeLifeTimeType::LongTerm) {
406     threshold = std::min(local_peak_th, lifetime_th);
407   }
408 
409   return threshold;
410 }
411 
IsThresholdDefaultValue() const412 bool AutoRecompute::IsThresholdDefaultValue() const {
413   if (local_peak_threshold_ == 0 && lifetime_threshold_ == 0) {
414     return true;
415   }
416   return false;
417 }
418 
419 /**
420  * @brief Find recompute candidates(source node, target node, edge and its type) in func_graph. \n
421  *        Result will be add to candidates_.
422  *
423  * @param func_graph
424  */
FindCandidates(const FuncGraphPtr & func_graph)425 void AutoRecompute::FindCandidates(const FuncGraphPtr &func_graph) {
426   MS_EXCEPTION_IF_NULL(func_graph);
427   candidates_.clear();
428 
429   auto mng = func_graph->manager();
430   MS_EXCEPTION_IF_NULL(mng);
431 
432   auto topo_nodes = TopoSort(func_graph->get_return());
433   // Topo indice is use to early stop in predecessor check.
434   for (size_t i = 0; i < topo_nodes.size(); ++i) {
435     (void)topo_indice_.emplace(topo_nodes[i], i);
436   }
437 
438   // Candidate condition:
439   // 1. Judge current node can see its graph_kernel input with other input's backward path.
440   // 2. Memory variety between split out and origin more than threshold:
441   //    `Size(gs_direct_outs_to_gt) - filter(gs_inputs, its) > threshold`.
442   for (auto node : topo_nodes) {
443     if (!common::AnfAlgo::IsGraphKernel(node)) {
444       continue;
445     }
446     auto target_graphs = JudegeTargetAndCaptureSource(node, mng);
447     if (target_graphs.empty()) {
448       continue;
449     }
450     auto node_candidates = FindNodeRecomputeCandidates(node, target_graphs, mng);
451     // Delete duplicated link.
452     for (const auto &[source, target_and_link] : node_candidates) {
453       for (const auto &[target, link] : target_and_link) {
454         candidates_.push_back({source, target, link.first, link.second});
455       }
456     }
457   }
458 
459   RecomputeCandidatesLog(candidates_);
460 }
461 
462 /**
463  * @brief Find recompute candidates for node as source graph.
464  *
465  * @param node Source graph node.
466  * @param target_graphs Vector of [AnfNodePtr, std::vector<int>, EdgeLifeTimeType].
467  * @param mng Manager of main graph(which contains this node).
468  * @return AutoRecompute::NodeRecomputeCandidates
469  */
FindNodeRecomputeCandidates(const AnfNodePtr & node,const OutPosLinkList & target_graphs,const FuncGraphManagerPtr & mng)470 AutoRecompute::NodeRecomputeCandidates AutoRecompute::FindNodeRecomputeCandidates(const AnfNodePtr &node,
471                                                                                   const OutPosLinkList &target_graphs,
472                                                                                   const FuncGraphManagerPtr &mng) {
473   MS_EXCEPTION_IF_NULL(node);
474   MS_EXCEPTION_IF_NULL(mng);
475   NodeRecomputeCandidates node_candidates;
476   auto graph_node = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
477   MS_EXCEPTION_IF_NULL(graph_node);
478   auto nodes = graph_node->nodes();
479   if (std::any_of(nodes.cbegin(), nodes.cend(),
480                   [](const AnfNodePtr n) { return IsPrimitiveCNode(n, prim::kPrimReduceSum); })) {
481     return node_candidates;
482   }
483   for (auto [gt, gt_in_pos_vec, edge_life_time_type] : target_graphs) {
484     MemorySize threshold = SelectThreshold(edge_life_time_type);
485     for (auto gt_in_pos : gt_in_pos_vec) {
486       MemorySize out_tensor_size =
487         static_cast<MemorySize>(AnfAlgo::GetOutputTensorMemSize(node, IntToSize(GetSourceLinkOutPos(gt, gt_in_pos))));
488       MemorySize absorb_input_tensor_size = 0;
489       for (auto input : Filter(node, gt, gt_in_pos, mng)) {
490         absorb_input_tensor_size += static_cast<MemorySize>(AnfAlgo::GetOutputTensorMemSize(input, 0));
491       }
492       auto gt_cnode = gt->cast<CNodePtr>();
493       MS_EXCEPTION_IF_NULL(gt_cnode);
494       auto edge = gt_cnode->input(IntToSize(gt_in_pos));
495 
496       MS_LOG(DEBUG) << "Recompute case: GS(" << node->fullname_with_scope() << ") -> GT(" << gt->fullname_with_scope()
497                     << ") with Edge(" << edge->fullname_with_scope() << "<" << edge_life_time_type << ">.";
498 
499       if (out_tensor_size < absorb_input_tensor_size) {
500         MS_LOG(DEBUG) << " ==> Skip this case because memory reduction.";
501         continue;
502       }
503 
504       auto memory_increment = out_tensor_size - absorb_input_tensor_size;
505       MS_LOG(DEBUG) << " ==> Threshold: " << threshold << ", Out Tensor[" << out_tensor_size << "] - Absort Tensor["
506                     << absorb_input_tensor_size << "] = " << memory_increment;
507 
508       if (memory_increment > threshold) {
509         if (node_candidates[node].find(gt) == node_candidates[node].end()) {
510           node_candidates[node][gt] = {edge_life_time_type, AnfNodePtrList{}};
511         }
512         // Only add getitem node as edge, if GS is single output node, there will be no edges.
513         if (IsPrimitiveCNode(edge, prim::kPrimTupleGetItem)) {
514           node_candidates[node][gt].second.push_back(edge);
515         }
516       }
517     }
518   }
519   return node_candidates;
520 }
521 
RecomputeLinkEdgeLog(const AnfNodePtr & node,const OrderedSet<AnfNodePtr> & direct_users,const OutPosLinkList & target_link_infos) const522 void AutoRecompute::RecomputeLinkEdgeLog(const AnfNodePtr &node, const OrderedSet<AnfNodePtr> &direct_users,
523                                          const OutPosLinkList &target_link_infos) const {
524   MS_EXCEPTION_IF_NULL(node);
525   MS_LOG(DEBUG) << "Recompute users for node: " << node->fullname_with_scope();
526   for (const auto &direct_user : direct_users) {
527     MS_LOG(DEBUG) << "  └─ " << direct_user->fullname_with_scope();
528   }
529 
530   MS_LOG(DEBUG) << "Edge Link relation: ";
531   for (const auto &[target, tartget_in_index, life_type] : target_link_infos) {
532     MS_EXCEPTION_IF_NULL(target);
533     MS_LOG(DEBUG) << "  └[" << tartget_in_index << "|<" << life_type
534                   << ">]─> Link to: " << target->fullname_with_scope();
535   }
536 }
537 
RecomputeCandidatesLog(const std::vector<Candidate> & candidates) const538 void AutoRecompute::RecomputeCandidatesLog(const std::vector<Candidate> &candidates) const {
539   MS_LOG(INFO) << "Recompute candidates: ";
540   for (auto candidate : candidates) {
541     MS_LOG(INFO) << "  └─ GS: " << candidate.source_graph->fullname_with_scope();
542     MS_LOG(INFO) << "  └─ GT: " << candidate.target_graph->fullname_with_scope();
543     for (auto edge : candidate.recompute_edges) {
544       MS_LOG(INFO) << "    └─[Edge]─> " << edge->fullname_with_scope();
545     }
546   }
547 }
548 
Run(const FuncGraphPtr & func_graph)549 std::vector<Candidate> CSRRecompute::Run(const FuncGraphPtr &func_graph) {
550   FindCandidates(func_graph);
551   return candidates_;
552 }
553 
CheckPrimitiveInput(AnfNodePtr base,const PrimitivePtr & prim_type) const554 bool CSRRecompute::CheckPrimitiveInput(AnfNodePtr base, const PrimitivePtr &prim_type) const {
555   std::deque<AnfNodePtr> q{base};
556   std::set<AnfNodePtr> visited;
557   while (!q.empty()) {
558     auto node = q.front();
559     q.pop_front();
560     if (visited.count(node) > 0) {
561       continue;
562     }
563     (void)visited.insert(node);
564     if (IsPrimitiveCNode(node, prim_type)) {
565       return true;
566     }
567     auto cnode = node->cast<CNodePtr>();
568     if (cnode == nullptr) {
569       continue;
570     }
571     auto inputs = cnode->inputs();
572     (void)q.insert(q.begin(), inputs.begin(), inputs.end());
573   }
574   return false;
575 }
576 
FindNodeRecomputeCandidates(const AnfNodePtr & node,const OutPosLinkList & target_graphs,const FuncGraphManagerPtr & mng)577 AutoRecompute::NodeRecomputeCandidates CSRRecompute::FindNodeRecomputeCandidates(const AnfNodePtr &node,
578                                                                                  const OutPosLinkList &target_graphs,
579                                                                                  const FuncGraphManagerPtr &mng) {
580   MS_EXCEPTION_IF_NULL(node);
581   MS_EXCEPTION_IF_NULL(mng);
582   NodeRecomputeCandidates node_candidates;
583   auto graph_node = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
584   MS_EXCEPTION_IF_NULL(graph_node);
585   // subgraphs outputting UnsortedSegmentSum or CSRReduceSum along with other ops
586   // (likely the result of Gather), or containing CSRDiv without outputting
587   // UnsortedSegmentSum or CSRReduceSum, are selected as candidates for recompute.
588   auto TargetTail = [](const AnfNodePtr n) {
589     return IsPrimitiveCNode(n, prim::kPrimUnsortedSegmentSum) || IsPrimitiveCNode(n, prim::kPrimCSRReduceSum);
590   };
591   auto TargetHead = [](const AnfNodePtr n) { return IsPrimitiveCNode(n, prim::kPrimCSRDiv); };
592   auto return_node = graph_node->get_return();
593   MS_EXCEPTION_IF_NULL(return_node);
594   auto return_cnode = return_node->cast<CNodePtr>();
595   MS_EXCEPTION_IF_NULL(return_cnode);
596   auto return_inputs = return_cnode->inputs();
597   auto return_tup = return_inputs[1]->cast<CNodePtr>();
598   MS_EXCEPTION_IF_NULL(return_tup);
599   auto tuple_inputs = return_tup->inputs();
600   std::set<size_t> candidate_idx;
601   if (std::any_of(tuple_inputs.cbegin(), tuple_inputs.cend(), TargetTail)) {
602     for (size_t i = 1; i < tuple_inputs.size(); ++i) {
603       if (!TargetTail(tuple_inputs[i])) {
604         (void)candidate_idx.insert(i - 1);
605       }
606     }
607   } else if (std::any_of(tuple_inputs.cbegin(), tuple_inputs.cend(), TargetHead)) {
608     for (size_t i = 1; i < tuple_inputs.size(); ++i) {
609       if (CheckPrimitiveInput(tuple_inputs[i], prim::kPrimCSRDiv)) {
610         (void)candidate_idx.insert(i - 1);
611       }
612     }
613   }
614   if (candidate_idx.empty()) {
615     return node_candidates;
616   }
617   for (size_t i = 0; i < target_graphs.size(); ++i) {
618     AnfNodePtr gt;
619     std::vector<int> gt_in_pos_vec;
620     std::tie(gt, gt_in_pos_vec, std::ignore) = target_graphs[i];
621     for (auto gt_in_pos : gt_in_pos_vec) {
622       auto gt_cnode = gt->cast<CNodePtr>();
623       MS_EXCEPTION_IF_NULL(gt_cnode);
624       auto edge = gt_cnode->input(IntToSize(gt_in_pos));
625       if (!IsPrimitiveCNode(edge, prim::kPrimTupleGetItem)) {
626         continue;
627       }
628       auto edge_cnode = edge->cast<CNodePtr>();
629       MS_EXCEPTION_IF_NULL(edge_cnode);
630       auto tuple_idx = common::AnfAlgo::GetTupleGetItemOutIndex(edge_cnode);
631       if (candidate_idx.count(tuple_idx) > 0) {
632         node_candidates[node][gt].second.push_back(edge);
633       }
634     }
635   }
636   return node_candidates;
637 }
638 
CloneGraph(const CNodePtr & source_graph,const AnfNodePtrList & recompute_edges) const639 std::pair<FuncGraphPtr, AnfNodePtrList> GraphKernelRecompute::CloneGraph(const CNodePtr &source_graph,
640                                                                          const AnfNodePtrList &recompute_edges) const {
641   MS_EXCEPTION_IF_NULL(source_graph);
642   auto gs = common::AnfAlgo::GetCNodeFuncGraphPtr(source_graph);
643   MS_EXCEPTION_IF_NULL(gs);
644   AnfNodePtrList inputs(source_graph->inputs().begin() + 1, source_graph->inputs().end());
645   auto new_funcgraph = BasicClone(gs);
646   auto output_node = new_funcgraph->output()->cast<CNodePtr>();
647   MS_EXCEPTION_IF_NULL(output_node);
648   if (!IsPrimitiveCNode(output_node, prim::kPrimMakeTuple)) {
649     return {new_funcgraph, inputs};
650   }
651   // remove outputs that not in recompute edges.
652   AnfNodePtrList new_outputs;
653   for (auto &edge : recompute_edges) {
654     auto idx = GetGetitemIndex(edge);
655     new_outputs.push_back(GetOutput(new_funcgraph, LongToSize(idx)));
656   }
657   if (new_outputs.size() + 1 == output_node->size()) {
658     return {new_funcgraph, inputs};
659   }
660   (void)new_outputs.insert(new_outputs.cbegin(), output_node->input(0));
661   auto new_output_node = new_funcgraph->NewCNode(new_outputs);
662   // use the old abstract, since the new_funcgraph will be deleted in later process.
663   new_output_node->set_abstract(output_node->abstract());
664   new_output_node->set_kernel_info(std::make_shared<device::KernelInfo>());
665   new_funcgraph->set_output(new_output_node);
666   ElimRedundantInputsAndGraphParameters(new_funcgraph, &inputs);
667   return {new_funcgraph, inputs};
668 }
669 
LinkIntoTargetFuncGraph(const Candidate & candidate,const FuncGraphPtr & cloned_func,const AnfNodePtrList & cloned_inputs,const std::function<std::pair<bool,size_t> (const Candidate &,const AnfNodePtr &)> & edge_match_func) const670 void GraphKernelRecompute::LinkIntoTargetFuncGraph(
671   const Candidate &candidate, const FuncGraphPtr &cloned_func, const AnfNodePtrList &cloned_inputs,
672   const std::function<std::pair<bool, size_t>(const Candidate &, const AnfNodePtr &)> &edge_match_func) const {
673   auto cloned_nodes = TopoSort(cloned_func->get_return());
674   auto gt = common::AnfAlgo::GetCNodeFuncGraphPtr(candidate.target_graph);
675   MS_EXCEPTION_IF_NULL(gt);
676   auto mng = gt->manager();
677   if (mng == nullptr) {
678     mng = Manage(gt, true);
679     gt->set_manager(mng);
680   }
681 
682   // link the outputs to gt
683   auto gt_node = candidate.target_graph->cast<CNodePtr>();
684   MS_EXCEPTION_IF_NULL(gt_node);
685   AnfNodePtrList new_parameters;
686   AnfNodePtrList new_inputs;
687   auto &params = gt->parameters();
688   for (size_t i = 0; i < params.size(); i++) {
689     // if the parameter is a recompute edge, then links the param to the cloned_func's output.
690     auto [is_match, out_index] = edge_match_func(candidate, gt_node->input(i + 1));
691     if (is_match) {
692       (void)mng->Replace(params[i], GetOutput(cloned_func, out_index));
693     } else {
694       new_parameters.push_back(params[i]);
695       new_inputs.push_back(gt_node->input(i + 1));
696     }
697   }
698 
699   // add new parameters
700   auto &cloned_func_params = cloned_func->parameters();
701   for (size_t i = 0; i < cloned_func_params.size(); i++) {
702     auto iter = std::find(new_inputs.begin(), new_inputs.end(), cloned_inputs[i]);
703     if (iter != new_inputs.end()) {
704       auto idx = iter - new_inputs.begin();
705       (void)cloned_func->manager()->Replace(cloned_func_params[i], new_parameters[LongToSize(idx)]);
706     } else {
707       new_parameters.push_back(gt->add_parameter());
708       new_inputs.push_back(cloned_inputs[i]);
709       (void)cloned_func->manager()->Replace(cloned_func_params[i], new_parameters.back());
710     }
711   }
712 
713   // reset the func_graph for cloned_nodes.
714   for (auto &node : cloned_nodes) {
715     if (node->isa<CNode>()) {
716       node->set_func_graph(gt);
717     }
718   }
719   AnfNodePtrList new_node_inputs = {gt_node->input(0)};
720   (void)new_node_inputs.insert(new_node_inputs.cend(), new_inputs.cbegin(), new_inputs.cend());
721   gt->set_parameters(new_parameters);
722   gt_node->set_inputs(new_node_inputs);
723   AnfNodePtrList outputs;
724   kernel::GetFuncGraphOutputNodes(gt, &outputs);
725   gt_node->set_kernel_info(std::make_shared<device::KernelInfo>());
726   Callback::Instance()->SetGraphKernelNodeKernelInfo(gt_node);
727   mng->RemoveRoots();
728   mng->KeepRoots({gt});
729 }
730 
Process(const Candidate & candidate) const731 void GraphKernelRecompute::Process(const Candidate &candidate) const {
732   FuncGraphPtr new_funcgraph;
733   AnfNodePtrList inputs;
734   std::function<std::pair<bool, size_t>(const Candidate &, const AnfNodePtr &)> edge_match_func;
735   if (candidate.recompute_edges.empty()) {
736     // single output, clone the whole source_graph.
737     auto gs = common::AnfAlgo::GetCNodeFuncGraphPtr(candidate.source_graph);
738     MS_EXCEPTION_IF_NULL(gs);
739     new_funcgraph = BasicClone(gs);
740     auto source_cnode = candidate.source_graph->cast<CNodePtr>();
741     MS_EXCEPTION_IF_NULL(source_cnode);
742     auto source_inputs = source_cnode->inputs();
743     (void)inputs.insert(inputs.cend(), source_inputs.cbegin() + 1, source_inputs.cend());
744     edge_match_func = [](const Candidate &match_candidate, const AnfNodePtr &to_match) -> std::pair<bool, size_t> {
745       if (match_candidate.source_graph == to_match) {
746         return std::make_pair(true, 0);
747       }
748       return std::make_pair(false, 0);
749     };
750   } else {
751     std::tie(new_funcgraph, inputs) = CloneGraph(candidate.source_graph->cast<CNodePtr>(), candidate.recompute_edges);
752     edge_match_func = [](const Candidate &match_candidate, const AnfNodePtr &to_match) -> std::pair<bool, size_t> {
753       auto iter = std::find(match_candidate.recompute_edges.begin(), match_candidate.recompute_edges.end(), to_match);
754       if (iter != match_candidate.recompute_edges.end()) {
755         auto out_index = iter - match_candidate.recompute_edges.begin();
756         return std::make_pair(true, LongToSize(out_index));
757       }
758       return std::make_pair(false, 0);
759     };
760   }
761 
762   auto mng = new_funcgraph->manager();
763   if (mng == nullptr) {
764     mng = Manage(new_funcgraph, true);
765     new_funcgraph->set_manager(mng);
766   }
767 
768   if (common::AnfAlgo::IsGraphKernel(candidate.target_graph)) {
769     // the target graph is a GraphKernel, push the new_funcgraph into the target graph.
770     LinkIntoTargetFuncGraph(candidate, new_funcgraph, inputs, edge_match_func);
771   } else {
772     // The target graph is not a GraphKernel, build the new_funcgraph to a CNode.
773     MS_LOG(WARNING) << "Target node " << candidate.target_graph->fullname_with_scope()
774                     << " is not a graph kernel node, cannot absort the link edge!";
775     return;
776   }
777 }
778 
DoRun(const FuncGraphPtr & func_graph,bool use_csr)779 bool GraphKernelRecompute::DoRun(const FuncGraphPtr &func_graph, bool use_csr) {
780   int repeat_times = 2;
781   while ((repeat_times--) != 0) {
782     if (use_csr) {
783       CSRRecompute csr_recompute;
784       candidates_ = csr_recompute.Run(func_graph);
785     } else {
786       AutoRecompute auto_recompute;
787       candidates_ = auto_recompute.Run(func_graph);
788     }
789     if (candidates_.empty()) {
790       return false;
791     }
792     auto mng = func_graph->manager();
793     MS_EXCEPTION_IF_NULL(mng);
794     for (auto &c : candidates_) {
795       if (!common::AnfAlgo::IsGraphKernel(c.target_graph)) {
796         continue;
797       }
798       std::ostringstream oss;
799       for (auto &e : c.recompute_edges) {
800         if (!IsPrimitiveCNode(e, prim::kPrimTupleGetItem)) {
801           MS_LOG(EXCEPTION) << "The edge should be GetItem but got " << e->fullname_with_scope();
802         }
803         oss << e->fullname_with_scope() << ", ";
804       }
805       MS_LOG(INFO) << "Clone " << c.source_graph->fullname_with_scope() << " to "
806                    << c.target_graph->fullname_with_scope() << ", edges [" << oss.str() << "]";
807       Process(c);
808     }
809     mng->RemoveRoots();
810     mng->KeepRoots({func_graph});
811   }
812   return true;
813 }
814 
Run(const FuncGraphPtr & func_graph)815 bool GraphKernelRecompute::Run(const FuncGraphPtr &func_graph) {
816   bool status = DoRun(func_graph);
817   if (GraphKernelFlags::GetInstance().enable_csr_fusion) {
818     status = DoRun(func_graph, true) || status;
819   }
820   return status;
821 }
822 }  // namespace mindspore::graphkernel
823