1 /**
2 * Copyright 2021-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "backend/common/graph_kernel/graph_kernel_recompute.h"
18
19 #include <algorithm>
20 #include <deque>
21 #include <functional>
22 #include <limits>
23 #include <map>
24 #include <memory>
25 #include <queue>
26 #include <set>
27 #include <sstream>
28 #include <stack>
29 #include <tuple>
30 #include <utility>
31 #include <vector>
32 #include "mindspore/core/ops/sparse_ops.h"
33 #include "mindspore/core/ops/sequence_ops.h"
34 #include "mindspore/core/ops/math_ops.h"
35 #include "mindspore/core/ops/array_ops.h"
36 #include "mindspore/core/ops/framework_ops.h"
37 #include "kernel/framework_utils.h"
38 #include "backend/common/graph_kernel/graph_kernel_helper.h"
39 #include "backend/common/graph_kernel/core/graph_builder.h"
40 #include "ir/func_graph_cloner.h"
41
42 namespace mindspore::graphkernel {
43 namespace {
GetGetitemIndex(const AnfNodePtr & getitem)44 int64_t GetGetitemIndex(const AnfNodePtr &getitem) {
45 auto vnode = GetValueNode(getitem->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem));
46 return GetValue<int64_t>(vnode);
47 }
48
GetOutput(const FuncGraphPtr & func_graph,size_t i)49 AnfNodePtr GetOutput(const FuncGraphPtr &func_graph, size_t i) {
50 auto output_node = func_graph->output()->cast<CNodePtr>();
51 MS_EXCEPTION_IF_NULL(output_node);
52 if (IsPrimitiveCNode(output_node, prim::kPrimMakeTuple)) {
53 if (i + 1 >= output_node->size()) {
54 MS_LOG(EXCEPTION) << i << " is out of range of MakeTuple's size " << output_node->size();
55 }
56 return output_node->input(i + 1);
57 } else {
58 if (i > 0) {
59 MS_LOG(EXCEPTION) << "the graph is single output but i is not 0. it's " << i;
60 }
61 return output_node->cast<AnfNodePtr>();
62 }
63 }
64
IsExclude(const AnfNodePtr & node)65 bool IsExclude(const AnfNodePtr &node) {
66 static std::vector<PrimitivePtr> excludes = {prim::kPrimReturn, prim::kPrimUpdateState, prim::kPrimLoad,
67 prim::kPrimMakeTuple, prim::kPrimDepend};
68 return std::any_of(excludes.begin(), excludes.end(),
69 [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
70 }
71
72 enum class VisitType : char { FOLLOW, STOP };
73 using VisitFunc = std::function<VisitType(const AnfNodePtr &)>;
74 using NextFunc = std::function<AnfNodePtrList(const AnfNodePtr &)>;
75 using ProcessFunc = std::function<void(const AnfNodePtr &)>;
76
Dfs(const AnfNodePtr & current,const VisitFunc & visit_func,const NextFunc & next_func,const ProcessFunc & before_func,const ProcessFunc & after_func,std::set<AnfNodePtr> * visited)77 void Dfs(const AnfNodePtr ¤t, const VisitFunc &visit_func, const NextFunc &next_func,
78 const ProcessFunc &before_func, const ProcessFunc &after_func, std::set<AnfNodePtr> *visited) {
79 if (visited->count(current) > 0) {
80 return;
81 }
82 (void)visited->insert(current);
83 if (visit_func(current) != VisitType::FOLLOW) {
84 return;
85 }
86
87 for (const auto &next : next_func(current)) {
88 before_func(next);
89 Dfs(next, visit_func, next_func, before_func, after_func, visited);
90 after_func(next);
91 }
92 }
93
CollectLinkPaths(const std::map<AnfNodePtr,MemorySize> & topo_indice,const OrderedSet<AnfNodePtr> & direct_users,MemorySize max_topo_user_index,const FuncGraphManagerPtr & mng)94 OrderedMap<AnfNodePtr, AnfNodePtrList> CollectLinkPaths(const std::map<AnfNodePtr, MemorySize> &topo_indice,
95 const OrderedSet<AnfNodePtr> &direct_users,
96 MemorySize max_topo_user_index,
97 const FuncGraphManagerPtr &mng) {
98 std::stack<AnfNodePtr> cur_stack;
99 OrderedMap<AnfNodePtr, AnfNodePtrList> link_paths;
100 auto TmpVisitFunc = [&topo_indice, max_topo_user_index](const AnfNodePtr &n) -> VisitType {
101 if (IsExclude(n)) {
102 return VisitType::STOP;
103 }
104
105 auto iter = topo_indice.find(n);
106 if (iter == topo_indice.end()) {
107 MS_LOG(EXCEPTION) << "Cannot find " << n->fullname_with_scope() << " in topo indices!";
108 }
109 if (iter->second > max_topo_user_index) {
110 return VisitType::STOP;
111 }
112 return VisitType::FOLLOW;
113 };
114
115 auto TmpNextFunc = [&mng](const AnfNodePtr &n) -> AnfNodePtrList {
116 auto users = mng->node_users()[n];
117 AnfNodePtrList nexts;
118 (void)std::transform(users.cbegin(), users.cend(), std::back_inserter(nexts),
119 [](const std::pair<AnfNodePtr, int> &user) { return user.first; });
120 return nexts;
121 };
122
123 auto TmpBeforeFunc = [&link_paths, &cur_stack, &direct_users](const AnfNodePtr &next) -> void {
124 if (direct_users.count(next) == 0) {
125 return;
126 }
127 auto cur_node = cur_stack.top();
128 if (link_paths.find(cur_node) == link_paths.end()) {
129 (void)link_paths.emplace(cur_node, AnfNodePtrList());
130 }
131 link_paths[cur_node].push_back(next);
132 cur_stack.push(next);
133 };
134
135 auto TmpAfterFunc = [&cur_stack, &direct_users](const AnfNodePtr &next) -> void {
136 if (direct_users.count(next) == 0) {
137 return;
138 }
139 cur_stack.push(next);
140 };
141
142 std::set<AnfNodePtr> visited;
143 for (auto user : direct_users) {
144 cur_stack.push(user);
145 Dfs(user, TmpVisitFunc, TmpNextFunc, TmpBeforeFunc, TmpAfterFunc, &visited);
146 cur_stack.pop();
147 }
148
149 return link_paths;
150 }
151
GetLongTermNodes(const AnfNodePtrList & nodes,const AnfNodePtr & end_node,const std::map<AnfNodePtr,MemorySize> & topo_indices,const FuncGraphManagerPtr & mng)152 OrderedSet<AnfNodePtr> GetLongTermNodes(const AnfNodePtrList &nodes, const AnfNodePtr &end_node,
153 const std::map<AnfNodePtr, MemorySize> &topo_indices,
154 const FuncGraphManagerPtr &mng) {
155 OrderedSet<AnfNodePtr> long_term_nodes;
156 for (auto node : nodes) {
157 auto real_node = common::AnfAlgo::VisitKernelWithReturnType(node, 0).first;
158 // Parameter or value have long term tensors.
159 if (!utils::isa<CNodePtr>(real_node)) {
160 (void)long_term_nodes.insert(node);
161 continue;
162 }
163
164 auto users = mng->node_users()[node];
165 if (std::any_of(users.cbegin(), users.cend(), [&topo_indices, &end_node](const std::pair<AnfNodePtr, int> &user) {
166 auto user_topo = topo_indices.find(user.first);
167 auto end_topo = topo_indices.find(end_node);
168 return user_topo->second >= end_topo->second;
169 })) {
170 (void)long_term_nodes.insert(node);
171 }
172 }
173 return long_term_nodes;
174 }
175
176 /**
177 * @brief Remove real input which is not used and change the related graph parameters.
178 *
179 * @param func_graph Graph.
180 * @param inputs Real inputs for graph cnode.
181 */
ElimRedundantInputsAndGraphParameters(const FuncGraphPtr & func_graph,AnfNodePtrList * inputs)182 void ElimRedundantInputsAndGraphParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs) {
183 MS_EXCEPTION_IF_NULL(inputs);
184 const auto &ori_parameter = func_graph->parameters();
185 auto nodes = TopoSort(func_graph->get_return());
186 std::set<AnfNodePtr> used_param;
187 for (auto node : nodes) {
188 if (node->isa<Parameter>()) {
189 (void)used_param.insert(node);
190 }
191 }
192 if (used_param.size() == ori_parameter.size()) {
193 return;
194 }
195 AnfNodePtrList new_parameter, new_inputs;
196 for (size_t i = 0; i < ori_parameter.size(); ++i) {
197 if (used_param.count(ori_parameter[i]) != 0) {
198 new_parameter.push_back(ori_parameter[i]);
199 new_inputs.push_back((*inputs)[i]);
200 }
201 }
202 func_graph->set_parameters(new_parameter);
203 *inputs = std::move(new_inputs);
204 }
205 } // namespace
206
Run(const FuncGraphPtr & func_graph)207 std::vector<Candidate> AutoRecompute::Run(const FuncGraphPtr &func_graph) {
208 lifetime_threshold_ = GraphKernelFlags::GetInstance().recompute_increment_threshold;
209 local_peak_threshold_ = GraphKernelFlags::GetInstance().recompute_peak_threshold;
210 if (!IsThresholdDefaultValue()) {
211 FindCandidates(func_graph);
212 }
213 return candidates_;
214 }
215
216 /**
217 * @brief Filter the input tensor(that live longer than end node) out and return valid inputs for memory calculation. \n
218 * If the topo indices of the input's user is at least one greater than end_node, \n
219 * it will retain when after end_node's execution.
220 *
221 * @param source_node
222 * @param end_node
223 * @param edge_pos
224 * @param mng
225 * @return AnfNodePtrList
226 */
Filter(const AnfNodePtr & source_node,const AnfNodePtr & end_node,int edge_pos,const FuncGraphManagerPtr & mng)227 AnfNodePtrList AutoRecompute::Filter(const AnfNodePtr &source_node, const AnfNodePtr &end_node, int edge_pos,
228 const FuncGraphManagerPtr &mng) {
229 auto source_cnode = source_node->cast<CNodePtr>();
230 MS_EXCEPTION_IF_NULL(source_cnode);
231 AnfNodePtrList node_inputs(source_cnode->inputs().begin() + 1, source_cnode->inputs().end());
232 OrderedSet<AnfNodePtr> long_term_inputs = GetLongTermNodes(node_inputs, end_node, topo_indice_, mng);
233
234 AnfNodePtrList check_inputs;
235 if (IsPrimitiveCNode(end_node->cast<CNodePtr>()->input(IntToSize(edge_pos)), prim::kPrimTupleGetItem)) {
236 auto out_index = GetSourceLinkOutPos(end_node, edge_pos);
237 auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(source_node);
238 auto out = sub_graph->output();
239 if (!IsPrimitiveCNode(out, prim::kPrimMakeTuple)) {
240 MS_LOG(EXCEPTION) << "Expect MakeTuple node, but got " << common::AnfAlgo::GetCNodeName(out);
241 }
242
243 // Find subgraph's input according to edge node.
244 auto start_node = out->cast<CNodePtr>()->input(IntToSize(out_index + 1));
245 AnfNodePtrList sub_input_parameters;
246 std::queue<AnfNodePtr> node_q;
247 node_q.push(start_node);
248 while (!node_q.empty()) {
249 auto cur = node_q.front();
250 node_q.pop();
251 if (utils::isa<ParameterPtr>(cur)) {
252 sub_input_parameters.push_back(cur);
253 }
254 auto cur_cnode = cur->cast<CNodePtr>();
255 if (cur_cnode) {
256 for (size_t i = 1; i < cur_cnode->size(); ++i) {
257 node_q.push(cur_cnode->input(i));
258 }
259 }
260 }
261
262 // Filte input that user's topo index is great than source graph.
263 for (auto para : sub_input_parameters) {
264 for (size_t i = 0; i < sub_graph->parameters().size(); ++i) {
265 if (para == sub_graph->parameters()[i]) {
266 check_inputs.push_back(node_inputs[i]);
267 }
268 }
269 }
270 } else {
271 check_inputs = node_inputs;
272 }
273
274 AnfNodePtrList res;
275 for (auto input : check_inputs) {
276 if (long_term_inputs.count(input) == 0) {
277 res.push_back(input);
278 }
279 }
280
281 return res;
282 }
283
284 /**
285 * @brief Get valid users information by giving node, excluding TupleGetItem, Load and so on.
286 */
GetValidUsers(const AnfNodePtr & node,const FuncGraphManagerPtr & mng)287 std::tuple<OrderedSet<AnfNodePtr>, OutPosLinkMap, MemorySize> AutoRecompute::GetValidUsers(
288 const AnfNodePtr &node, const FuncGraphManagerPtr &mng) {
289 auto &user_map = mng->node_users();
290 auto users = user_map[node];
291 MemorySize max_topo_user_index = 0;
292 std::queue<std::pair<AnfNodePtr, int>> users_queue;
293 for (auto user_index : users) {
294 users_queue.push(user_index);
295 }
296 OrderedSet<AnfNodePtr> direct_users;
297 OutPosLinkMap user_edge_pos;
298 while (!users_queue.empty()) {
299 auto [user, index] = users_queue.front();
300 users_queue.pop();
301 if (IsPrimitiveCNode(user, prim::kPrimTupleGetItem)) {
302 for (auto get_item_user : user_map[user]) {
303 users_queue.push(get_item_user);
304 }
305 continue;
306 } else if (IsExclude(user)) {
307 continue;
308 }
309 user_edge_pos[user].push_back(index);
310 (void)direct_users.insert(user);
311 // Update maximum topo value.
312 if (topo_indice_[user] > max_topo_user_index) {
313 max_topo_user_index = topo_indice_[user];
314 }
315 }
316
317 return {direct_users, user_edge_pos, max_topo_user_index};
318 }
319
320 /**
321 * @brief Judege target node for recompute according to current node, and capture source node information when find \n
322 * target. There two type for tensor of the edge between source node and target node, example: \n
323 * source ──[Short-Term]── A ── other \n
324 * │ │ \n
325 * └───────[Long-Term]────── target \n
326 * For this example, \n
327 * 1. There are two path from source node to target node, and target is directly user for source node, \n
328 * so the tensor of their edge is a long-term tensor. \n
329 * 2. From source node to A, there is only one path, and A is directly user for source node, \n
330 * so the tensor of their edge is a short-term tensor.
331 *
332 * @param node Source node.
333 * @param mng Graph manager.
334 * @return OutPosLinkList Vector[Tuple(target node, input positions of target node for edge, edge type)].
335 */
JudegeTargetAndCaptureSource(const AnfNodePtr & node,const FuncGraphManagerPtr & mng)336 OutPosLinkList AutoRecompute::JudegeTargetAndCaptureSource(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) {
337 auto [direct_users, user_edge_pos, max_topo_user_index] = GetValidUsers(node, mng);
338 OutPosLinkList target_link_infos;
339 OrderedSet<AnfNodePtr> long_term_users;
340 // If the number of direct users is less than 2, there will no side way to its user....
341 if (direct_users.size() >= 2) {
342 OrderedMap<AnfNodePtr, AnfNodePtrList> link_paths =
343 CollectLinkPaths(topo_indice_, direct_users, max_topo_user_index, mng);
344 for (const auto &[source, paths] : link_paths) {
345 for (auto target : paths) {
346 if (target != source) {
347 (void)target_link_infos.emplace_back(target, user_edge_pos[target], EdgeLifeTimeType::LongTerm);
348 (void)long_term_users.insert(target);
349 }
350 }
351 }
352 }
353
354 // Direct users include long term users and short term users.
355 // If the short term user is graph kernel composite node, it may be absorb and reduce the local peak memory.
356 for (const auto &user : direct_users) {
357 if (long_term_users.count(user) == 0 && common::AnfAlgo::IsGraphKernel(user)) {
358 (void)target_link_infos.emplace_back(user, user_edge_pos[user], EdgeLifeTimeType::ShortTerm);
359 }
360 }
361
362 RecomputeLinkEdgeLog(node, direct_users, target_link_infos);
363 return target_link_infos;
364 }
365
366 /**
367 * @brief Get position of edge tensor between source node and target node. \n
368 * For example, giving target node and edge position 0, will return 1: \n
369 * source node \n
370 * [0] [1] [2] <- output position \n
371 * | \n
372 * | \n
373 * / \n
374 * [0] [1] <- input position \n
375 * target node
376 *
377 * @param target Target node.
378 * @param pos The input position of target node for edge.
379 * @return int The output position of source node for edge.
380 */
GetSourceLinkOutPos(const AnfNodePtr & target,int pos) const381 int AutoRecompute::GetSourceLinkOutPos(const AnfNodePtr &target, int pos) const {
382 // If the input is get-item, than use get-item's index, otherwise zero.
383 auto cnode = target->cast<CNodePtr>();
384 MS_EXCEPTION_IF_NULL(cnode);
385 auto prenode = cnode->input(IntToSize(pos));
386 if (!IsPrimitiveCNode(prenode, prim::kPrimTupleGetItem)) {
387 return 0;
388 }
389
390 auto get_item_cnode = prenode->cast<CNodePtr>();
391 MS_EXCEPTION_IF_NULL(get_item_cnode);
392 auto value_input = get_item_cnode->input(kInputNodeOutputIndexInTupleGetItem);
393 MS_EXCEPTION_IF_NULL(value_input);
394 auto value_node = value_input->cast<ValueNodePtr>();
395 MS_EXCEPTION_IF_NULL(value_node);
396 return static_cast<int>(GetValue<int64_t>(value_node->value()));
397 }
398
SelectThreshold(EdgeLifeTimeType type) const399 MemorySize AutoRecompute::SelectThreshold(EdgeLifeTimeType type) const {
400 MemorySize threshold = 0;
401 auto local_peak_th = local_peak_threshold_ == 0 ? std::numeric_limits<MemorySize>::max() : local_peak_threshold_;
402 auto lifetime_th = lifetime_threshold_ == 0 ? std::numeric_limits<MemorySize>::max() : lifetime_threshold_;
403 if (type == EdgeLifeTimeType::ShortTerm) {
404 threshold = local_peak_th;
405 } else if (type == EdgeLifeTimeType::LongTerm) {
406 threshold = std::min(local_peak_th, lifetime_th);
407 }
408
409 return threshold;
410 }
411
IsThresholdDefaultValue() const412 bool AutoRecompute::IsThresholdDefaultValue() const {
413 if (local_peak_threshold_ == 0 && lifetime_threshold_ == 0) {
414 return true;
415 }
416 return false;
417 }
418
419 /**
420 * @brief Find recompute candidates(source node, target node, edge and its type) in func_graph. \n
421 * Result will be add to candidates_.
422 *
423 * @param func_graph
424 */
FindCandidates(const FuncGraphPtr & func_graph)425 void AutoRecompute::FindCandidates(const FuncGraphPtr &func_graph) {
426 MS_EXCEPTION_IF_NULL(func_graph);
427 candidates_.clear();
428
429 auto mng = func_graph->manager();
430 MS_EXCEPTION_IF_NULL(mng);
431
432 auto topo_nodes = TopoSort(func_graph->get_return());
433 // Topo indice is use to early stop in predecessor check.
434 for (size_t i = 0; i < topo_nodes.size(); ++i) {
435 (void)topo_indice_.emplace(topo_nodes[i], i);
436 }
437
438 // Candidate condition:
439 // 1. Judge current node can see its graph_kernel input with other input's backward path.
440 // 2. Memory variety between split out and origin more than threshold:
441 // `Size(gs_direct_outs_to_gt) - filter(gs_inputs, its) > threshold`.
442 for (auto node : topo_nodes) {
443 if (!common::AnfAlgo::IsGraphKernel(node)) {
444 continue;
445 }
446 auto target_graphs = JudegeTargetAndCaptureSource(node, mng);
447 if (target_graphs.empty()) {
448 continue;
449 }
450 auto node_candidates = FindNodeRecomputeCandidates(node, target_graphs, mng);
451 // Delete duplicated link.
452 for (const auto &[source, target_and_link] : node_candidates) {
453 for (const auto &[target, link] : target_and_link) {
454 candidates_.push_back({source, target, link.first, link.second});
455 }
456 }
457 }
458
459 RecomputeCandidatesLog(candidates_);
460 }
461
462 /**
463 * @brief Find recompute candidates for node as source graph.
464 *
465 * @param node Source graph node.
466 * @param target_graphs Vector of [AnfNodePtr, std::vector<int>, EdgeLifeTimeType].
467 * @param mng Manager of main graph(which contains this node).
468 * @return AutoRecompute::NodeRecomputeCandidates
469 */
FindNodeRecomputeCandidates(const AnfNodePtr & node,const OutPosLinkList & target_graphs,const FuncGraphManagerPtr & mng)470 AutoRecompute::NodeRecomputeCandidates AutoRecompute::FindNodeRecomputeCandidates(const AnfNodePtr &node,
471 const OutPosLinkList &target_graphs,
472 const FuncGraphManagerPtr &mng) {
473 MS_EXCEPTION_IF_NULL(node);
474 MS_EXCEPTION_IF_NULL(mng);
475 NodeRecomputeCandidates node_candidates;
476 auto graph_node = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
477 MS_EXCEPTION_IF_NULL(graph_node);
478 auto nodes = graph_node->nodes();
479 if (std::any_of(nodes.cbegin(), nodes.cend(),
480 [](const AnfNodePtr n) { return IsPrimitiveCNode(n, prim::kPrimReduceSum); })) {
481 return node_candidates;
482 }
483 for (auto [gt, gt_in_pos_vec, edge_life_time_type] : target_graphs) {
484 MemorySize threshold = SelectThreshold(edge_life_time_type);
485 for (auto gt_in_pos : gt_in_pos_vec) {
486 MemorySize out_tensor_size =
487 static_cast<MemorySize>(AnfAlgo::GetOutputTensorMemSize(node, IntToSize(GetSourceLinkOutPos(gt, gt_in_pos))));
488 MemorySize absorb_input_tensor_size = 0;
489 for (auto input : Filter(node, gt, gt_in_pos, mng)) {
490 absorb_input_tensor_size += static_cast<MemorySize>(AnfAlgo::GetOutputTensorMemSize(input, 0));
491 }
492 auto gt_cnode = gt->cast<CNodePtr>();
493 MS_EXCEPTION_IF_NULL(gt_cnode);
494 auto edge = gt_cnode->input(IntToSize(gt_in_pos));
495
496 MS_LOG(DEBUG) << "Recompute case: GS(" << node->fullname_with_scope() << ") -> GT(" << gt->fullname_with_scope()
497 << ") with Edge(" << edge->fullname_with_scope() << "<" << edge_life_time_type << ">.";
498
499 if (out_tensor_size < absorb_input_tensor_size) {
500 MS_LOG(DEBUG) << " ==> Skip this case because memory reduction.";
501 continue;
502 }
503
504 auto memory_increment = out_tensor_size - absorb_input_tensor_size;
505 MS_LOG(DEBUG) << " ==> Threshold: " << threshold << ", Out Tensor[" << out_tensor_size << "] - Absort Tensor["
506 << absorb_input_tensor_size << "] = " << memory_increment;
507
508 if (memory_increment > threshold) {
509 if (node_candidates[node].find(gt) == node_candidates[node].end()) {
510 node_candidates[node][gt] = {edge_life_time_type, AnfNodePtrList{}};
511 }
512 // Only add getitem node as edge, if GS is single output node, there will be no edges.
513 if (IsPrimitiveCNode(edge, prim::kPrimTupleGetItem)) {
514 node_candidates[node][gt].second.push_back(edge);
515 }
516 }
517 }
518 }
519 return node_candidates;
520 }
521
RecomputeLinkEdgeLog(const AnfNodePtr & node,const OrderedSet<AnfNodePtr> & direct_users,const OutPosLinkList & target_link_infos) const522 void AutoRecompute::RecomputeLinkEdgeLog(const AnfNodePtr &node, const OrderedSet<AnfNodePtr> &direct_users,
523 const OutPosLinkList &target_link_infos) const {
524 MS_EXCEPTION_IF_NULL(node);
525 MS_LOG(DEBUG) << "Recompute users for node: " << node->fullname_with_scope();
526 for (const auto &direct_user : direct_users) {
527 MS_LOG(DEBUG) << " └─ " << direct_user->fullname_with_scope();
528 }
529
530 MS_LOG(DEBUG) << "Edge Link relation: ";
531 for (const auto &[target, tartget_in_index, life_type] : target_link_infos) {
532 MS_EXCEPTION_IF_NULL(target);
533 MS_LOG(DEBUG) << " └[" << tartget_in_index << "|<" << life_type
534 << ">]─> Link to: " << target->fullname_with_scope();
535 }
536 }
537
RecomputeCandidatesLog(const std::vector<Candidate> & candidates) const538 void AutoRecompute::RecomputeCandidatesLog(const std::vector<Candidate> &candidates) const {
539 MS_LOG(INFO) << "Recompute candidates: ";
540 for (auto candidate : candidates) {
541 MS_LOG(INFO) << " └─ GS: " << candidate.source_graph->fullname_with_scope();
542 MS_LOG(INFO) << " └─ GT: " << candidate.target_graph->fullname_with_scope();
543 for (auto edge : candidate.recompute_edges) {
544 MS_LOG(INFO) << " └─[Edge]─> " << edge->fullname_with_scope();
545 }
546 }
547 }
548
Run(const FuncGraphPtr & func_graph)549 std::vector<Candidate> CSRRecompute::Run(const FuncGraphPtr &func_graph) {
550 FindCandidates(func_graph);
551 return candidates_;
552 }
553
CheckPrimitiveInput(AnfNodePtr base,const PrimitivePtr & prim_type) const554 bool CSRRecompute::CheckPrimitiveInput(AnfNodePtr base, const PrimitivePtr &prim_type) const {
555 std::deque<AnfNodePtr> q{base};
556 std::set<AnfNodePtr> visited;
557 while (!q.empty()) {
558 auto node = q.front();
559 q.pop_front();
560 if (visited.count(node) > 0) {
561 continue;
562 }
563 (void)visited.insert(node);
564 if (IsPrimitiveCNode(node, prim_type)) {
565 return true;
566 }
567 auto cnode = node->cast<CNodePtr>();
568 if (cnode == nullptr) {
569 continue;
570 }
571 auto inputs = cnode->inputs();
572 (void)q.insert(q.begin(), inputs.begin(), inputs.end());
573 }
574 return false;
575 }
576
FindNodeRecomputeCandidates(const AnfNodePtr & node,const OutPosLinkList & target_graphs,const FuncGraphManagerPtr & mng)577 AutoRecompute::NodeRecomputeCandidates CSRRecompute::FindNodeRecomputeCandidates(const AnfNodePtr &node,
578 const OutPosLinkList &target_graphs,
579 const FuncGraphManagerPtr &mng) {
580 MS_EXCEPTION_IF_NULL(node);
581 MS_EXCEPTION_IF_NULL(mng);
582 NodeRecomputeCandidates node_candidates;
583 auto graph_node = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
584 MS_EXCEPTION_IF_NULL(graph_node);
585 // subgraphs outputting UnsortedSegmentSum or CSRReduceSum along with other ops
586 // (likely the result of Gather), or containing CSRDiv without outputting
587 // UnsortedSegmentSum or CSRReduceSum, are selected as candidates for recompute.
588 auto TargetTail = [](const AnfNodePtr n) {
589 return IsPrimitiveCNode(n, prim::kPrimUnsortedSegmentSum) || IsPrimitiveCNode(n, prim::kPrimCSRReduceSum);
590 };
591 auto TargetHead = [](const AnfNodePtr n) { return IsPrimitiveCNode(n, prim::kPrimCSRDiv); };
592 auto return_node = graph_node->get_return();
593 MS_EXCEPTION_IF_NULL(return_node);
594 auto return_cnode = return_node->cast<CNodePtr>();
595 MS_EXCEPTION_IF_NULL(return_cnode);
596 auto return_inputs = return_cnode->inputs();
597 auto return_tup = return_inputs[1]->cast<CNodePtr>();
598 MS_EXCEPTION_IF_NULL(return_tup);
599 auto tuple_inputs = return_tup->inputs();
600 std::set<size_t> candidate_idx;
601 if (std::any_of(tuple_inputs.cbegin(), tuple_inputs.cend(), TargetTail)) {
602 for (size_t i = 1; i < tuple_inputs.size(); ++i) {
603 if (!TargetTail(tuple_inputs[i])) {
604 (void)candidate_idx.insert(i - 1);
605 }
606 }
607 } else if (std::any_of(tuple_inputs.cbegin(), tuple_inputs.cend(), TargetHead)) {
608 for (size_t i = 1; i < tuple_inputs.size(); ++i) {
609 if (CheckPrimitiveInput(tuple_inputs[i], prim::kPrimCSRDiv)) {
610 (void)candidate_idx.insert(i - 1);
611 }
612 }
613 }
614 if (candidate_idx.empty()) {
615 return node_candidates;
616 }
617 for (size_t i = 0; i < target_graphs.size(); ++i) {
618 AnfNodePtr gt;
619 std::vector<int> gt_in_pos_vec;
620 std::tie(gt, gt_in_pos_vec, std::ignore) = target_graphs[i];
621 for (auto gt_in_pos : gt_in_pos_vec) {
622 auto gt_cnode = gt->cast<CNodePtr>();
623 MS_EXCEPTION_IF_NULL(gt_cnode);
624 auto edge = gt_cnode->input(IntToSize(gt_in_pos));
625 if (!IsPrimitiveCNode(edge, prim::kPrimTupleGetItem)) {
626 continue;
627 }
628 auto edge_cnode = edge->cast<CNodePtr>();
629 MS_EXCEPTION_IF_NULL(edge_cnode);
630 auto tuple_idx = common::AnfAlgo::GetTupleGetItemOutIndex(edge_cnode);
631 if (candidate_idx.count(tuple_idx) > 0) {
632 node_candidates[node][gt].second.push_back(edge);
633 }
634 }
635 }
636 return node_candidates;
637 }
638
CloneGraph(const CNodePtr & source_graph,const AnfNodePtrList & recompute_edges) const639 std::pair<FuncGraphPtr, AnfNodePtrList> GraphKernelRecompute::CloneGraph(const CNodePtr &source_graph,
640 const AnfNodePtrList &recompute_edges) const {
641 MS_EXCEPTION_IF_NULL(source_graph);
642 auto gs = common::AnfAlgo::GetCNodeFuncGraphPtr(source_graph);
643 MS_EXCEPTION_IF_NULL(gs);
644 AnfNodePtrList inputs(source_graph->inputs().begin() + 1, source_graph->inputs().end());
645 auto new_funcgraph = BasicClone(gs);
646 auto output_node = new_funcgraph->output()->cast<CNodePtr>();
647 MS_EXCEPTION_IF_NULL(output_node);
648 if (!IsPrimitiveCNode(output_node, prim::kPrimMakeTuple)) {
649 return {new_funcgraph, inputs};
650 }
651 // remove outputs that not in recompute edges.
652 AnfNodePtrList new_outputs;
653 for (auto &edge : recompute_edges) {
654 auto idx = GetGetitemIndex(edge);
655 new_outputs.push_back(GetOutput(new_funcgraph, LongToSize(idx)));
656 }
657 if (new_outputs.size() + 1 == output_node->size()) {
658 return {new_funcgraph, inputs};
659 }
660 (void)new_outputs.insert(new_outputs.cbegin(), output_node->input(0));
661 auto new_output_node = new_funcgraph->NewCNode(new_outputs);
662 // use the old abstract, since the new_funcgraph will be deleted in later process.
663 new_output_node->set_abstract(output_node->abstract());
664 new_output_node->set_kernel_info(std::make_shared<device::KernelInfo>());
665 new_funcgraph->set_output(new_output_node);
666 ElimRedundantInputsAndGraphParameters(new_funcgraph, &inputs);
667 return {new_funcgraph, inputs};
668 }
669
LinkIntoTargetFuncGraph(const Candidate & candidate,const FuncGraphPtr & cloned_func,const AnfNodePtrList & cloned_inputs,const std::function<std::pair<bool,size_t> (const Candidate &,const AnfNodePtr &)> & edge_match_func) const670 void GraphKernelRecompute::LinkIntoTargetFuncGraph(
671 const Candidate &candidate, const FuncGraphPtr &cloned_func, const AnfNodePtrList &cloned_inputs,
672 const std::function<std::pair<bool, size_t>(const Candidate &, const AnfNodePtr &)> &edge_match_func) const {
673 auto cloned_nodes = TopoSort(cloned_func->get_return());
674 auto gt = common::AnfAlgo::GetCNodeFuncGraphPtr(candidate.target_graph);
675 MS_EXCEPTION_IF_NULL(gt);
676 auto mng = gt->manager();
677 if (mng == nullptr) {
678 mng = Manage(gt, true);
679 gt->set_manager(mng);
680 }
681
682 // link the outputs to gt
683 auto gt_node = candidate.target_graph->cast<CNodePtr>();
684 MS_EXCEPTION_IF_NULL(gt_node);
685 AnfNodePtrList new_parameters;
686 AnfNodePtrList new_inputs;
687 auto ¶ms = gt->parameters();
688 for (size_t i = 0; i < params.size(); i++) {
689 // if the parameter is a recompute edge, then links the param to the cloned_func's output.
690 auto [is_match, out_index] = edge_match_func(candidate, gt_node->input(i + 1));
691 if (is_match) {
692 (void)mng->Replace(params[i], GetOutput(cloned_func, out_index));
693 } else {
694 new_parameters.push_back(params[i]);
695 new_inputs.push_back(gt_node->input(i + 1));
696 }
697 }
698
699 // add new parameters
700 auto &cloned_func_params = cloned_func->parameters();
701 for (size_t i = 0; i < cloned_func_params.size(); i++) {
702 auto iter = std::find(new_inputs.begin(), new_inputs.end(), cloned_inputs[i]);
703 if (iter != new_inputs.end()) {
704 auto idx = iter - new_inputs.begin();
705 (void)cloned_func->manager()->Replace(cloned_func_params[i], new_parameters[LongToSize(idx)]);
706 } else {
707 new_parameters.push_back(gt->add_parameter());
708 new_inputs.push_back(cloned_inputs[i]);
709 (void)cloned_func->manager()->Replace(cloned_func_params[i], new_parameters.back());
710 }
711 }
712
713 // reset the func_graph for cloned_nodes.
714 for (auto &node : cloned_nodes) {
715 if (node->isa<CNode>()) {
716 node->set_func_graph(gt);
717 }
718 }
719 AnfNodePtrList new_node_inputs = {gt_node->input(0)};
720 (void)new_node_inputs.insert(new_node_inputs.cend(), new_inputs.cbegin(), new_inputs.cend());
721 gt->set_parameters(new_parameters);
722 gt_node->set_inputs(new_node_inputs);
723 AnfNodePtrList outputs;
724 kernel::GetFuncGraphOutputNodes(gt, &outputs);
725 gt_node->set_kernel_info(std::make_shared<device::KernelInfo>());
726 Callback::Instance()->SetGraphKernelNodeKernelInfo(gt_node);
727 mng->RemoveRoots();
728 mng->KeepRoots({gt});
729 }
730
Process(const Candidate & candidate) const731 void GraphKernelRecompute::Process(const Candidate &candidate) const {
732 FuncGraphPtr new_funcgraph;
733 AnfNodePtrList inputs;
734 std::function<std::pair<bool, size_t>(const Candidate &, const AnfNodePtr &)> edge_match_func;
735 if (candidate.recompute_edges.empty()) {
736 // single output, clone the whole source_graph.
737 auto gs = common::AnfAlgo::GetCNodeFuncGraphPtr(candidate.source_graph);
738 MS_EXCEPTION_IF_NULL(gs);
739 new_funcgraph = BasicClone(gs);
740 auto source_cnode = candidate.source_graph->cast<CNodePtr>();
741 MS_EXCEPTION_IF_NULL(source_cnode);
742 auto source_inputs = source_cnode->inputs();
743 (void)inputs.insert(inputs.cend(), source_inputs.cbegin() + 1, source_inputs.cend());
744 edge_match_func = [](const Candidate &match_candidate, const AnfNodePtr &to_match) -> std::pair<bool, size_t> {
745 if (match_candidate.source_graph == to_match) {
746 return std::make_pair(true, 0);
747 }
748 return std::make_pair(false, 0);
749 };
750 } else {
751 std::tie(new_funcgraph, inputs) = CloneGraph(candidate.source_graph->cast<CNodePtr>(), candidate.recompute_edges);
752 edge_match_func = [](const Candidate &match_candidate, const AnfNodePtr &to_match) -> std::pair<bool, size_t> {
753 auto iter = std::find(match_candidate.recompute_edges.begin(), match_candidate.recompute_edges.end(), to_match);
754 if (iter != match_candidate.recompute_edges.end()) {
755 auto out_index = iter - match_candidate.recompute_edges.begin();
756 return std::make_pair(true, LongToSize(out_index));
757 }
758 return std::make_pair(false, 0);
759 };
760 }
761
762 auto mng = new_funcgraph->manager();
763 if (mng == nullptr) {
764 mng = Manage(new_funcgraph, true);
765 new_funcgraph->set_manager(mng);
766 }
767
768 if (common::AnfAlgo::IsGraphKernel(candidate.target_graph)) {
769 // the target graph is a GraphKernel, push the new_funcgraph into the target graph.
770 LinkIntoTargetFuncGraph(candidate, new_funcgraph, inputs, edge_match_func);
771 } else {
772 // The target graph is not a GraphKernel, build the new_funcgraph to a CNode.
773 MS_LOG(WARNING) << "Target node " << candidate.target_graph->fullname_with_scope()
774 << " is not a graph kernel node, cannot absort the link edge!";
775 return;
776 }
777 }
778
DoRun(const FuncGraphPtr & func_graph,bool use_csr)779 bool GraphKernelRecompute::DoRun(const FuncGraphPtr &func_graph, bool use_csr) {
780 int repeat_times = 2;
781 while ((repeat_times--) != 0) {
782 if (use_csr) {
783 CSRRecompute csr_recompute;
784 candidates_ = csr_recompute.Run(func_graph);
785 } else {
786 AutoRecompute auto_recompute;
787 candidates_ = auto_recompute.Run(func_graph);
788 }
789 if (candidates_.empty()) {
790 return false;
791 }
792 auto mng = func_graph->manager();
793 MS_EXCEPTION_IF_NULL(mng);
794 for (auto &c : candidates_) {
795 if (!common::AnfAlgo::IsGraphKernel(c.target_graph)) {
796 continue;
797 }
798 std::ostringstream oss;
799 for (auto &e : c.recompute_edges) {
800 if (!IsPrimitiveCNode(e, prim::kPrimTupleGetItem)) {
801 MS_LOG(EXCEPTION) << "The edge should be GetItem but got " << e->fullname_with_scope();
802 }
803 oss << e->fullname_with_scope() << ", ";
804 }
805 MS_LOG(INFO) << "Clone " << c.source_graph->fullname_with_scope() << " to "
806 << c.target_graph->fullname_with_scope() << ", edges [" << oss.str() << "]";
807 Process(c);
808 }
809 mng->RemoveRoots();
810 mng->KeepRoots({func_graph});
811 }
812 return true;
813 }
814
Run(const FuncGraphPtr & func_graph)815 bool GraphKernelRecompute::Run(const FuncGraphPtr &func_graph) {
816 bool status = DoRun(func_graph);
817 if (GraphKernelFlags::GetInstance().enable_csr_fusion) {
818 status = DoRun(func_graph, true) || status;
819 }
820 return status;
821 }
822 } // namespace mindspore::graphkernel
823