1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/graph_util/pipeline_split_utils.h"
18 #include <algorithm>
19 #include <list>
20 #include <memory>
21 #include <queue>
22 #include <set>
23 #include "include/common/utils/comm_manager.h"
24 #include "frontend/parallel/device_manager.h"
25 #include "frontend/parallel/graph_util/generate_graph.h"
26 #include "frontend/parallel/graph_util/node_info.h"
27 #include "frontend/parallel/ops_info/ops_utils.h"
28 #include "frontend/parallel/step_parallel.h"
29 #include "frontend/parallel/step_parallel_utils.h"
30 #include "frontend/parallel/dynamic_shape/dynamic_shape.h"
31 #include "frontend/parallel/graph_util/fold_pipeline_split_utils.h"
32 #include "include/common/utils/parallel_context.h"
33 #include "ir/value.h"
34 #include "ops/array_ops.h"
35 #include "ops/framework_ops.h"
36 #include "ops/other_ops.h"
37 #include "ops/sequence_ops.h"
38 #include "utils/parallel_node_check.h"
39 
40 namespace mindspore {
41 namespace parallel {
42 namespace {
43 bool IsSendRec(const AnfNodePtr &node) {
44   return IsPrimitiveCNode(node, prim::kPrimSend) || IsPrimitiveCNode(node, prim::kPrimReceive);
45 }
46 
47 std::string TagForSendRecDepend(const AnfNodePtr &prior_node, const AnfNodePtr &post_node) {
48   if (!IsSendRec(prior_node) || !IsSendRec(post_node)) {
49     return "";
50   }
51   if (prior_node->cast<CNodePtr>()->HasPrimalAttr(kPrimalAttrForwardNodeName) ==
52       post_node->cast<CNodePtr>()->HasPrimalAttr(kPrimalAttrForwardNodeName)) {
53     return "";
54   }
55   return std::string(SEND_REC_DEPEND);
56 }
57 }  // namespace
58 
59 bool IsFirstStage() {
60   MS_EXCEPTION_IF_NULL(g_device_manager);
61   auto stage_id = g_device_manager->stage_id();
62   return stage_id == 0;
63 }
64 
65 bool IsLastStage() {
66   MS_EXCEPTION_IF_NULL(g_device_manager);
67   auto stage_num = g_device_manager->stage_num();
68   auto stage_id = g_device_manager->stage_id();
69   return ((stage_num - 1) == stage_id);
70 }
71 
72 static ValuePtr GetReceiveMicro(const CNodePtr &cnode) {
73   std::queue<CNodePtr> que;
74   std::set<AnfNodePtr> visited;
75   que.push(cnode);
76   while (!que.empty()) {
77     auto front = que.front();
78     que.pop();
79     (void)(visited.insert(front));
80     for (size_t i = 1; i < front->size(); ++i) {
81       auto input = front->input(i);
82       if (!input->isa<CNode>()) {
83         continue;
84       }
85       auto cinput = input->cast<CNodePtr>();
86       MS_EXCEPTION_IF_NULL(cinput);
87       if (IsPrimitiveCNode(cinput, prim::kPrimReceive)) {
88         return cinput->GetPrimalAttr(MICRO);
89       }
90       if (visited.find(cinput) == visited.end()) {
91         que.push(cinput);
92       }
93     }
94   }
95   return nullptr;
96 }
97 
98 static ValuePtr GetReceiveSegment(const CNodePtr &cnode) {
99   std::queue<CNodePtr> que;
100   std::set<AnfNodePtr> visited;
101   que.push(cnode);
102   while (!que.empty()) {
103     auto front = que.front();
104     que.pop();
105     (void)(visited.insert(front));
106     for (size_t i = 1; i < front->size(); ++i) {
107       auto input = front->input(i);
108       if (!input->isa<CNode>()) {
109         continue;
110       }
111       auto cinput = input->cast<CNodePtr>();
112       MS_EXCEPTION_IF_NULL(cinput);
113       if (IsPrimitiveCNode(cinput, prim::kPrimReceive)) {
114         return cinput->GetPrimalAttr(SEGMENT);
115       }
116       if (visited.find(cinput) == visited.end()) {
117         que.push(cinput);
118       }
119     }
120   }
121   return nullptr;
122 }
123 
124 static bool EnableShareCell() {
125   auto context = MsContext::GetInstance();
126   MS_EXCEPTION_IF_NULL(context);
127   const auto cell_reuse = context->CellReuseLevel() != CellReuseLevel::kNoCellReuse;
128   const auto &comm_reuse_env = common::GetEnv("MS_COMM_COMPILER_OPT");
129   if (!comm_reuse_env.empty() && cell_reuse) {
130     MS_LOG(EXCEPTION) << "The cell reuse cannot be used with communication reuse,"
131                          " please unset environment variable 'MS_COMM_COMPILER_OPT'";
132   }
133   MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
134   bool grad_accumulation_shard = ParallelContext::GetInstance()->grad_accumulation_shard();
135   if (grad_accumulation_shard && cell_reuse) {
136     MS_LOG(EXCEPTION)
137       << "The cell reuse cannot be used with sharding accumulate grad parameter with optimizer parallel,"
138          " please set_auto_parallel_context(parallel_optimizer_config={'gradient_accumulation_shard':False})";
139   }
140   return cell_reuse;
141 }
142 
143 static AnfNodePtr GetCallBackwardEndNext(const AnfNodePtr &node) {
144   if (!node->has_user_data(CALL_BACKWARD_END_NEXT)) {
145     return node;
146   }
147   return node->user_data<AnfNode>(CALL_BACKWARD_END_NEXT);
148 }
149 
150 bool IsValidNode(const AnfNodePtr &node, const AnfNodePtr &return_node, const NodeUsersMap &node_user_map) {
151   if (node == return_node) {
152     return true;
153   }
154   auto iter = node_user_map.find(node);
155   if (iter == node_user_map.end()) {
156     return false;
157   }
158   const auto &users = (*iter).second;
159   return std::any_of(users.begin(), users.end(),
160                      [&return_node, &node_user_map](const std::pair<AnfNodePtr, int> &user) {
161                        return IsValidNode(user.first, return_node, node_user_map);
162                      });
163 }
164 
165 // Judge whether the graph calls the specified grad nodes; the specified grad nodes are in grad_graph.
166 // Search for calls to the specified grad nodes via DFS.
167 static bool CallGradNodes(const FuncGraphPtr &graph, const FuncGraphPtr &grad_graph,
168                           std::set<FuncGraphPtr> *const visit) {
169   if (visit->find(graph) != visit->end()) {
170     return false;
171   }
172   if (graph == grad_graph) {
173     return true;
174   }
175   (void)(visit->insert(graph));
176   const auto &cnodes = graph->GetOrderedCnodes();
177   for (const auto &cnode : cnodes) {
178     const auto &abs = cnode->input(0)->abstract();
179     if (!abs || !abs->isa<abstract::AbstractFunction>()) {
180       continue;
181     }
182     const auto &abs_func = abs->cast<abstract::AbstractFunctionPtr>();
183     if (!abs_func->isa<abstract::FuncGraphAbstractClosure>()) {
184       continue;
185     }
186     const auto &abs_func_graph = abs->cast<abstract::FuncGraphAbstractClosurePtr>();
187     auto fg = abs_func_graph->func_graph();
188     if (fg && fg == grad_graph) {
189       return true;
190     }
191     if (CallGradNodes(fg, grad_graph, visit)) {
192       return true;
193     }
194   }
195   return false;
196 }
197 
198 static FuncGraphPtr FindGradGraph(const FuncGraphPtr &root) {
199   const auto &nodes = DeepScopedGraphSearch(root->get_return());
200   for (const auto &node : nodes) {
201     if (!node->isa<CNode>()) {
202       continue;
203     }
204     const auto &cnode = node->cast<CNodePtr>();
205     if (cnode->HasPrimalAttr(PARAMETER_START_SHARE_CELL) && cnode->HasPrimalAttr(kPrimalAttrForwardNodeName)) {
206       const auto &grad_graph = cnode->func_graph();
207       MS_LOG(INFO) << "The specified grad nodes are in graph " << grad_graph->ToString();
208       return grad_graph;
209     }
210   }
211   MS_LOG(EXCEPTION) << "Stage0: The grad graph has not been found in lazy inline mode.";
212   return nullptr;
213 }
214 
215 void SetParameterStartForCellShare(const FuncGraphPtr &root) {
216   MS_EXCEPTION_IF_NULL(root);
217   auto share_cell = EnableShareCell();
218   if (!share_cell) {
219     return;
220   }
221   if (!IsFirstStage()) {
222     return;
223   }
224   FuncGraphPtr grad_graph = FindGradGraph(root);
225   MS_EXCEPTION_IF_NULL(grad_graph);
226   const auto &manager = root->manager();
227   auto node_user_map = manager->node_users();
228   auto all_nodes = root->GetOrderedCnodes();
229   std::set<FuncGraphPtr> call_grad_nodes;
230   bool has_find = false;
231   for (auto &node : all_nodes) {
232     // if cnode is a call_backward node
233     if (!IsPrimitiveCNode(node->input(0), prim::kPrimTupleGetItem)) {
234       continue;
235     }
236     const auto &abs = node->input(0)->abstract();
237     if (!abs || !abs->isa<abstract::AbstractFunction>()) {
238       continue;
239     }
240     const auto &abs_func = abs->cast<abstract::AbstractFunctionPtr>();
241     if (!abs_func->isa<abstract::FuncGraphAbstractClosure>()) {
242       continue;
243     }
244     std::set<FuncGraphPtr> visit;
245     const auto &abs_func_graph = abs->cast<abstract::FuncGraphAbstractClosurePtr>();
246     auto fg = abs_func_graph->func_graph();
247     if (!fg || (call_grad_nodes.find(fg) == call_grad_nodes.end() && !CallGradNodes(fg, grad_graph, &visit))) {
248       continue;
249     }
250     if (call_grad_nodes.find(fg) == call_grad_nodes.end()) {
251       (void)(call_grad_nodes.insert(fg));
252     }
253     auto micro = GetReceiveMicro(node);
254     MS_EXCEPTION_IF_NULL(micro);
255     auto node_abs = node->abstract();
256     if (node_abs->isa<abstract::AbstractTuple>()) {
257       CNodePtr next = nullptr;
258       const auto &users = node_user_map[node];
259       for (const auto &user : users) {
260         const auto &cuser = user.first->cast<CNodePtr>();
261         MS_EXCEPTION_IF_NULL(cuser);
262         if (IsPrimitiveCNode(cuser, prim::kPrimTupleGetItem) && IsValidNode(cuser, root->get_return(), node_user_map)) {
263           next = cuser;
264           break;
265         }
266       }
267       node->set_user_data<AnfNode>(CALL_BACKWARD_END_NEXT, next);
268     }
269     has_find = true;
270     node->AddPrimalAttr(MICRO, micro);
271     node->AddPrimalAttr(PARAMETER_START, micro);
272     auto parallel_context = parallel::ParallelContext::GetInstance();
273     if (parallel_context->enable_fold_pipeline()) {
274       auto segment = GetReceiveSegment(node);
275       MS_EXCEPTION_IF_NULL(segment);
276       node->AddPrimalAttr(SEGMENT, segment);
277     }
278   }
279   if (!has_find) {
280     MS_LOG(EXCEPTION) << "Stage0: The backward end flag has not been marked in lazy inline mode.";
281   } else {
282     MS_LOG(INFO) << "Stage0: The backward end flag has been marked in lazy inline mode.";
283   }
284 }
285 
286 AnfNodePtr FindAccuGrad(const CNodePtr &cnode) {
287   auto pre_node = cnode->input(1);
288   size_t depth = 0;
289   while (true) {
290     if (depth > MAX_RECURSIVE_DEPTH) {
291       return nullptr;
292     }
293     depth += 1;
294     if (pre_node->isa<Parameter>()) {
295       return pre_node;
296     } else {
297       if (pre_node->isa<CNode>()) {
298         auto pre_cnode = pre_node->cast<CNodePtr>();
299         pre_node = pre_cnode->input(1);
300       } else {
301         return nullptr;
302       }
303     }
304   }
305   return nullptr;
306 }
307 
308 void SetStridedSliceStrategy(const AnfNodePtr &node) {
309   MS_EXCEPTION_IF_NULL(node);
310   if (!IsPrimitiveCNode(node, prim::kPrimStridedSlice)) {
311     return;
312   }
313   bool full_batch = ParallelContext::GetInstance()->full_batch();
314   auto dev_num = g_device_manager->stage_device_num();
315   auto cnode = node->cast<CNodePtr>();
316   MS_EXCEPTION_IF_NULL(cnode);
317   std::vector<Shapes> shape_list;
318   if (InDynamicGraph(cnode)) {
319     shape_list = ExtractRealDivisor(cnode);
320     MS_LOG(INFO) << "the node is in dynamic shape graph, the divisor is " << ShapesToString(shape_list[0]);
321   } else {
322     shape_list = ExtractShape(cnode);
323   }
324 
325   if (shape_list.empty()) {
326     MS_LOG(EXCEPTION) << "Failure: node " << cnode->ToString() << " failed to extract shape";
327   }
328   std::vector<ValuePtr> elements;
329   for (size_t i = 0; i < shape_list[0].size(); i++) {
330     if (shape_list[0][i].empty()) {
331       MS_LOG(EXCEPTION) << "shape_list[ " << i << " ].size() is zero";
332     }
333     Dimensions input_strategy;
334     for (size_t j = 0; j < shape_list[0][i].size(); j++) {
335       input_strategy.push_back(1);
336     }
337     static const auto skip_redis = (common::GetEnv("PIPELINE_SLICE_SKIP_REDISTRIBUTION") == "1");
338     if (skip_redis && !full_batch && input_strategy.size() > 0) {
339       auto dim = shape_list[1][0][0];
340       if (dev_num <= dim && ((dim % dev_num) == 0)) {
341         input_strategy[0] = dev_num;
342       } else if (dim < dev_num && ((dev_num % dim) == 0)) {
343         input_strategy[0] = dim;
344       }
345       auto prim = GetCNodePrimitive(node);
346       if (prim->HasAttr("out_shard_size")) {
347         auto out_shard_size = GetValue<int64_t>(prim->GetAttr("out_shard_size"));
348         input_strategy[0] = out_shard_size;
349       }
350       auto attrs = prim->attrs();
351       attrs[parallel::SKIP_REDISTRIBUTION] = MakeValue<bool>(true);
352       (void)prim->SetAttrs(attrs);
353     }
354 
355     elements.push_back(MakeValue(input_strategy));
356   }
357   ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements);
358   cnode->AddPrimalAttr(IN_STRATEGY, strategy);
359 }
360 
361 CNodePtr FindNodeWithMircoSize(const AnfNodePtr &node_user, const NodeUsersMap &node_users_map) {
362   // Recursively find micro tags; this may take much longer if there are many layers.
363   std::queue<AnfNodePtr> visited;
364   visited.push(node_user);
365   while (!visited.empty()) {
366     auto cur_node = visited.front();
367     visited.pop();
368     if (node_users_map.find(cur_node) == node_users_map.end()) {
369       continue;
370     }
371     auto users = node_users_map.at(cur_node);
372     for (auto &temp_user : users) {
373       auto cnode = temp_user.first->cast<CNodePtr>();
374       MS_EXCEPTION_IF_NULL(cnode);
375       if (!cnode->HasPrimalAttr(MICRO)) {
376         visited.push(temp_user.first);
377       } else {
378         return cnode;
379       }
380     }
381   }
382   return nullptr;
383 }
384 
385 bool IsSourceUsedByMirror(const CNodePtr &node, const NodeUsersMap &node_user_map) {
386   if (node->size() < 2) {
387     return false;
388   }
389   auto parameter_node = node->input(1);
390   if (parameter_node->cast<ParameterPtr>()) {
391     for (auto &item : node_user_map.at(parameter_node)) {
392       if (IsPrimitiveCNode(item.first, prim::kPrimMirrorMicroStep)) {
393         return true;
394       }
395     }
396   }
397   return false;
398 }
399 void InsertVirtualAssignAdd(const std::pair<AnfNodePtr, int> &node_user, const FuncGraphManagerPtr &manager,
400                             const AnfNodePtr &accu_parameter, const NodeUsersMap &node_user_map) {
401   auto cnode = node_user.first->cast<CNodePtr>();
402   if (IsPrimitiveCNode(cnode, prim::kPrimReceive) || !cnode->in_forward_flag()) {
403     return;
404   }
405   MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
406   bool enable_parallel_optimizer = ParallelContext::GetInstance()->enable_parallel_optimizer();
407   bool grad_accumulation_shard = ParallelContext::GetInstance()->grad_accumulation_shard();
408   auto is_pp_interleave = ParallelContext::GetInstance()->pipeline_interleave();
409   if (!is_pp_interleave && IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) {
410     return;
411   }
412   if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && enable_parallel_optimizer &&
413       IsSourceUsedByMirror(cnode, node_user_map)) {
414     return;
415   }
416   auto param_ptr = accu_parameter->cast<ParameterPtr>();
417   MS_EXCEPTION_IF_NULL(param_ptr);
418   // If grad_accumulation_shard is true, a ReduceScatter will be inserted at each micro step,
419   // so the fusion id should be different for each micro step;
420   // otherwise they would be fused into one ReduceScatter across micro steps.
421   // If grad_accumulation_shard is false, we pass an empty group, so no ReduceScatter will be inserted.
422   ValuePtr args1 = nullptr;
423   ValuePtr args2 = nullptr;
424   ValuePtr micro = nullptr;
425   int64_t step = 0;
426   if (grad_accumulation_shard) {
427     auto cnode_with_micro_size = FindNodeWithMircoSize(cnode, node_user_map);
428     if (cnode_with_micro_size && cnode_with_micro_size->HasPrimalAttr(MICRO)) {
429       micro = cnode_with_micro_size->GetPrimalAttr(MICRO);
430       step = GetValue<int64_t>(micro);
431     }
432   }
433   args1 = MakeValue(param_ptr->user_data<TensorLayout>()->opt_shard_group());
434   args2 = MakeValue(LongToSize(param_ptr->param_info()->comm_fusion()) + LongToSize(step) * PIPELINE_FUSTION_OFFSET);
435   OperatorAttrs attrs = {};
436   auto py_instance = CreateOpInstance(attrs, VIRTUAL_ASSIGN_ADD, VIRTUAL_ASSIGN_ADD);
437   auto value_node = NewValueNode(py_instance);
438   // Set the attribute of the reduce scatter
439   auto new_prim = GetValueNode<PrimitivePtr>(value_node);
440   MS_EXCEPTION_IF_NULL(new_prim);
441   auto attrs_prim = new_prim->attrs();
442   attrs_prim[GROUP] = args1;
443   attrs_prim[kAttrFusion] = args2;
444   if (cnode->HasPrimalAttr(PIPELINE_PARAM)) {
445     attrs_prim[PIPELINE_PARAM] = MakeValue(true);
446   }
447   (void)new_prim->SetAttrs(attrs_prim);
448 
449   std::vector<AnfNodePtr> virtual_node_input = {value_node, cnode->input(IntToSize(node_user.second)), accu_parameter};
450   auto graph = cnode->func_graph();
451   auto virtual_node = graph->NewCNode(virtual_node_input);
452   manager->SetEdge(cnode, node_user.second, virtual_node);
453 }
454 
455 void InsertVirtualAccuGrad(const AnfNodePtr &recv, const FuncGraphManagerPtr &manager, const AnfNodePtr &param) {
456   auto cnode = recv->cast<CNodePtr>();
457   MS_EXCEPTION_IF_NULL(cnode);
458   OperatorAttrs attrs;
459   auto py_instance = CreateOpInstance(attrs, VIRTUAL_ACCU_GRAD, VIRTUAL_ACCU_GRAD);
460   auto value_node = NewValueNode(py_instance);
461   std::vector<AnfNodePtr> virtual_node_input = {value_node, recv, param};
462   auto graph = cnode->func_graph();
463   MS_EXCEPTION_IF_NULL(graph);
464   auto virtual_node = graph->NewCNode(virtual_node_input);
465   (void)manager->Replace(recv, virtual_node);
466 }
467 
468 AnfNodePtr FindGradAccuParameter(const std::vector<AnfNodePtr> &parameters, const std::string &name) {
469   for (auto &parameter : parameters) {
470     auto param_ptr = parameter->cast<ParameterPtr>();
471     MS_EXCEPTION_IF_NULL(param_ptr);
472     if (param_ptr->name() == name) {
473       continue;
474     }
475     auto expect_name = "accu_grads." + name;
476     if (param_ptr->name() == expect_name) {
477       return parameter;
478     }
479   }
480   return nullptr;
481 }
482 
483 // If the graph looks like the following:
484 // 1. MicroStepAllGather->MirrorMicro->Load, we need to visit the param after the Load
485 std::vector<std::pair<AnfNodePtr, int>> FindNextNode(
486   const std::pair<AnfNodePtr, int> &node_ptr, const NodeUsersMap &node_users_map,
487   const std::set<string> &check_list = {prim::kPrimMirrorMicroStep->name(), prim::kPrimMicroStepAllGather->name(),
488                                         prim::kPrimLoad->name()}) {
489   std::vector<std::pair<AnfNodePtr, int>> to_be_visited_set;
490   if (!IsSomePrimitiveList(node_ptr.first->cast<CNodePtr>(), check_list)) {
491     (void)to_be_visited_set.emplace_back(node_ptr);
492     return to_be_visited_set;
493   }
494   auto node_set = node_users_map.at(node_ptr.first);
495   std::queue<std::pair<std::shared_ptr<AnfNode>, int>> visited;
496   for (auto &node_user : node_set) {
497     visited.push(node_user);
498   }
499   while (visited.size() >= 1) {
500     auto node = visited.front();
501     visited.pop();
502     if (!IsSomePrimitiveList(node.first->cast<CNodePtr>(), check_list)) {
503       (void)to_be_visited_set.emplace_back(node);
504     } else {
505       auto next_node_set = node_users_map.at(node.first);
506       for (auto &node_user : next_node_set) {
507         visited.push(node_user);
508       }
509     }
510   }
511   return to_be_visited_set;
512 }
513 
514 std::set<std::pair<AnfNodePtr, int>> FuncNodeUsersSet(const AnfNodePtr &parameter) {
515   MS_EXCEPTION_IF_NULL(parameter->func_graph());
516   MS_EXCEPTION_IF_NULL(parameter->func_graph()->manager());
517   auto node_users_map = parameter->func_graph()->manager()->node_users();
518   auto node_users = node_users_map[parameter];
519   std::set<std::pair<AnfNodePtr, int>> all_node_users;
520   for (auto &n_pair : node_users) {
521     auto users_skip_virtual_nodes =
522       FindNextNode(n_pair, node_users_map,
523                    {prim::kPrimMirrorMicroStep->name(), prim::kPrimMicroStepAllGather->name(), prim::kPrimLoad->name(),
524                     prim::kPrimCast->name()});
525     for (const auto &node_pair : users_skip_virtual_nodes) {
526       auto func_node_users = FuncGraphNodeUsers(node_pair);
527       if (func_node_users.empty()) {
528         (void)all_node_users.insert(node_pair);
529         continue;
530       }
531       for (const auto &func_node_user : func_node_users) {
532         (void)all_node_users.insert(func_node_user);
533       }
534     }
535   }
536   return all_node_users;
537 }
538 
539 void HandleReceiveParam(const FuncGraphPtr &root) {
540   auto parameters = root->parameters();
541   auto node_users_map = root->manager()->node_users();
542   auto all_nodes = TopoSort(root->get_return(), SuccDeeperSimple);
543   for (auto &node : all_nodes) {
544     if (!IsPrimitiveCNode(node, prim::kPrimReceive)) {
545       continue;
546     }
547     auto cnode = node->cast<CNodePtr>();
548     if (!cnode->HasPrimalAttr(PIPELINE_PARAM)) {
549       continue;
550     }
551     auto parameter_ptr = cnode->input(1)->cast<ParameterPtr>();
552     MS_EXCEPTION_IF_NULL(parameter_ptr);
553     auto accu_parameter = FindGradAccuParameter(parameters, parameter_ptr->name());
554     if (!accu_parameter) {
555       continue;
556     }
557     auto base_shape = accu_parameter->Shape();
558     auto shape_ptr = dyn_cast<abstract::Shape>(base_shape);
559     auto slice_shape = shape_ptr->shape();
560     auto prim = GetCNodePrimitive(cnode);
561     std::vector<ValuePtr> element;
562     (void)std::transform(slice_shape.begin(), slice_shape.end(), std::back_inserter(element),
563                          [](int64_t elem) { return MakeValue(elem); });
564     auto value = std::make_shared<ValueList>(element);
565     prim->set_attr(SHAPE, value);
566     std::set<std::pair<AnfNodePtr, int>> all_node_users = FuncNodeUsersSet(node);
567     for (auto &temp_user : all_node_users) {
568       auto temp_node = temp_user.first;
569       // Micro virtual operator might be inserted after cast
570       if (IsPrimitiveCNode(temp_node, prim::kPrimCast)) {
571         temp_node = node_users_map[temp_node].begin()->first;
572       }
573       if (IsPrimitiveCNode(temp_node, prim::kPrimMirrorMicroStep) ||
574           IsPrimitiveCNode(temp_node, prim::kPrimMicroStepAllGather)) {
575         auto node_set = node_users_map[temp_node];
576         for (auto &node_user : node_set) {
577           InsertVirtualAssignAdd(node_user, root->manager(), accu_parameter, node_users_map);
578         }
579       } else {
580         InsertVirtualAssignAdd(temp_user, root->manager(), accu_parameter, node_users_map);
581       }
582     }
583     InsertVirtualAccuGrad(node, root->manager(), accu_parameter);
584   }
585 }
586 
587 void AddVirtualAssignAdd(const FuncGraphPtr &root) {
588   auto parameters = root->parameters();
589   auto node_users_map = root->manager()->node_users();
590   for (auto &parameter : parameters) {
591     auto parameter_ptr = parameter->cast<ParameterPtr>();
592     auto accu_parameter = FindGradAccuParameter(parameters, parameter_ptr->name());
593     if (!accu_parameter) {
594       continue;
595     }
596     std::set<std::pair<AnfNodePtr, int>> all_node_users = FuncNodeUsersSet(parameter);
597     for (auto &temp_user : all_node_users) {
598       // Micro virtual operator might be inserted after cast
599       auto temp_node = temp_user;
600       if (IsPrimitiveCNode(temp_node.first, prim::kPrimCast)) {
601         temp_node = *node_users_map[temp_node.first].begin();
602       }
603       if (!IsSomePrimitiveList(
604             temp_node.first->cast<CNodePtr>(),
605             {prim::kPrimMirrorMicroStep->name(), prim::kPrimMicroStepAllGather->name(), prim::kPrimLoad->name()})) {
606         InsertVirtualAssignAdd(temp_node, root->manager(), accu_parameter, node_users_map);
607         continue;
608       }
609       auto node_set = FindNextNode(temp_node, node_users_map);
610       for (auto &node_user : node_set) {
611         InsertVirtualAssignAdd(node_user, root->manager(), accu_parameter, node_users_map);
612       }
613     }
614   }
615 }
616 
617 bool SliceSort(const CNodePtr &cnode1, const CNodePtr &cnode2) {
618   if (IsPrimitiveCNode(cnode1, prim::kPrimStridedSlice) && IsPrimitiveCNode(cnode2, prim::kPrimStridedSlice)) {
619     auto slice_index1 = GetValue<int64_t>(cnode1->GetPrimalAttr(SLICE_INDEX));
620     auto slice_index2 = GetValue<int64_t>(cnode2->GetPrimalAttr(SLICE_INDEX));
621     return slice_index1 < slice_index2;
622   }
623   if (IsPrimitiveCNode(cnode1, prim::kPrimStridedSlice)) {
624     return false;
625   }
626   return true;
627 }
628 
629 bool CompFunc(const AnfNodePtr &node1, const AnfNodePtr &node2) {
630   MS_EXCEPTION_IF_NULL(node1);
631   MS_EXCEPTION_IF_NULL(node2);
632   auto cnode1 = node1->cast<CNodePtr>();
633   auto cnode2 = node2->cast<CNodePtr>();
634   MS_EXCEPTION_IF_NULL(cnode1);
635   MS_EXCEPTION_IF_NULL(cnode2);
636   auto micro1 = cnode1->GetPrimalAttr(MICRO);
637   auto micro2 = cnode2->GetPrimalAttr(MICRO);
638   MS_EXCEPTION_IF_NULL(micro1);
639   MS_EXCEPTION_IF_NULL(micro2);
640   auto micro1_value = GetValue<int64_t>(micro1);
641   auto micro2_value = GetValue<int64_t>(micro2);
642   if (micro1_value == micro2_value) {
643     if (IsPrimitiveCNode(node1, prim::kPrimStridedSlice) || IsPrimitiveCNode(node2, prim::kPrimStridedSlice)) {
644       return SliceSort(cnode1, cnode2);
645     }
646     auto prim1 = GetCNodePrimitive(cnode1);
647     auto prim2 = GetCNodePrimitive(cnode2);
648     if (EnableShareCell() && prim1 == nullptr && prim2 == nullptr) {
649       return false;
650     }
651     MS_EXCEPTION_IF_NULL(prim1);
652     MS_EXCEPTION_IF_NULL(prim2);
653     auto rank_tag1 = prim1->GetAttr(SRC_RANK);
654     auto rank_tag2 = prim2->GetAttr(SRC_RANK);
655     if (rank_tag1 == nullptr) {
656       rank_tag1 = prim1->GetAttr(DEST_RANK);
657     }
658     if (rank_tag2 == nullptr) {
659       rank_tag2 = prim2->GetAttr(DEST_RANK);
660     }
661     if (!rank_tag1 || !rank_tag2) {
662       return false;
663     }
664     auto rank1_value = GetValue<int64_t>(rank_tag1);
665     auto rank2_value = GetValue<int64_t>(rank_tag2);
666     if (rank1_value == rank2_value) {
667       auto sr_tag1 = prim1->GetAttr(SR_TAG);
668       auto sr_tag2 = prim2->GetAttr(SR_TAG);
669       MS_EXCEPTION_IF_NULL(sr_tag1);
670       MS_EXCEPTION_IF_NULL(sr_tag2);
671       auto sr1_value = GetValue<int64_t>(sr_tag1);
672       auto sr2_value = GetValue<int64_t>(sr_tag2);
673       return sr1_value < sr2_value;
674     }
675     return rank1_value < rank2_value;
676   }
677   return micro1_value < micro2_value;
678 }
679 
680 void InsertDepend(const AnfNodePtr &prior_node, const AnfNodePtr &post_node, const FuncGraphManagerPtr &manager,
681                   const FuncGraphPtr &root, const std::string &attr_tag) {
682   MS_EXCEPTION_IF_NULL(prior_node);
683   MS_EXCEPTION_IF_NULL(post_node);
684   auto post_cnode = post_node->cast<CNodePtr>();
685   MS_EXCEPTION_IF_NULL(post_cnode);
686   std::vector<AnfNodePtr> depend_input = {NewValueNode(prim::kPrimDepend), post_cnode->input(1), prior_node};
687   auto depend_node = root->NewCNode(depend_input);
688   depend_node->set_abstract(post_cnode->input(1)->abstract());
689   if (!attr_tag.empty()) {
690     depend_node->AddAttr(attr_tag, MakeValue<bool>(true));
691   }
692   manager->SetEdge(post_node, 1, depend_node);
693 }
694 
695 void ReorderForForward(const std::vector<AnfNodePtr> &forward_start, const std::vector<AnfNodePtr> &forward_end,
696                        const FuncGraphPtr &root) {
697   MS_EXCEPTION_IF_NULL(g_device_manager);
698   MS_EXCEPTION_IF_NULL(root);
699   auto manager = root->manager();
700   MS_EXCEPTION_IF_NULL(manager);
701   auto stage_num = g_device_manager->stage_num();
702   auto stage_id = g_device_manager->stage_id();
703   for (size_t i = 1; i < LongToSize(stage_num - stage_id); ++i) {
704     auto prior_node = forward_end[i - 1];
705     auto post_node = forward_start[i];
706     InsertDepend(prior_node, post_node, manager, root);
707   }
708 }
709 
710 void ReorderForBackward(const PipelinePair &forward_start_pair, const PipelinePair &forward_end_pair,
711                         const PipelinePair &backward_start_pair, const PipelinePair &backward_end_pair,
712                         const PipelinePair &forward_end_before_pair, const FuncGraphPtr &root) {
713   MS_EXCEPTION_IF_NULL(g_device_manager);
714   MS_EXCEPTION_IF_NULL(root);
715   auto manager = root->manager();
716   MS_EXCEPTION_IF_NULL(manager);
717   auto stage_num = g_device_manager->stage_num();
718   auto stage_id = g_device_manager->stage_id();
719   for (size_t i = LongToSize(stage_num - stage_id); i < (forward_start_pair.first.size()); ++i) {
720     auto prior_node1 = forward_end_before_pair.second[i];
721     auto post_node1 = backward_start_pair.first[LongToSize(SizeToLong(i) - stage_num + stage_id + 1)];
722     InsertDepend(prior_node1, post_node1, manager, root, TagForSendRecDepend(prior_node1, post_node1));
723     auto prior_node2 = backward_end_pair.second[LongToSize(SizeToLong(i) - stage_num + stage_id)];
724     prior_node2 = GetCallBackwardEndNext(prior_node2);
725     auto post_node2 = forward_start_pair.first[i];
726     InsertDepend(prior_node2, post_node2, manager, root, TagForSendRecDepend(prior_node2, post_node2));
727   }
728   for (size_t i = LongToSize(stage_num - stage_id); i < (forward_start_pair.first.size() + 1); ++i) {
729     if (!IsLastStage()) {
730       auto prior_node3 = backward_start_pair.second[LongToSize(SizeToLong(i) - stage_num + stage_id)];
731       auto post_node3 = forward_end_pair.first[i - 1];
732       InsertDepend(prior_node3, post_node3, manager, root, TagForSendRecDepend(prior_node3, post_node3));
733       auto prior_node4 = forward_end_pair.second[i - 1];
734       auto post_node4 = backward_end_pair.first[LongToSize(SizeToLong(i) - stage_num + stage_id)];
735       InsertDepend(prior_node4, post_node4, manager, root, TagForSendRecDepend(prior_node4, post_node4));
736     }
737   }
738   for (size_t j = LongToSize(SizeToLong(backward_start_pair.first.size()) - stage_num + stage_id + 1);
739        j < backward_start_pair.first.size(); ++j) {
740     auto prior_node5 = backward_end_pair.second[j - 1];
741     prior_node5 = GetCallBackwardEndNext(prior_node5);
742     auto post_node5 = backward_start_pair.first[j];
743     InsertDepend(prior_node5, post_node5, manager, root, TagForSendRecDepend(prior_node5, post_node5));
744   }
745   if (!IsLastStage()) {
746     auto prior_node6 = forward_end_before_pair.second[LongToSize(stage_num - 1 - stage_id)];
747     auto post_node6 = backward_start_pair.first[0];
748     InsertDepend(prior_node6, post_node6, manager, root, TagForSendRecDepend(prior_node6, post_node6));
749   }
750 }
751 
752 void ReorderForParams(const PipelinePair &backward_params_pair, const PipelinePair &forward_params_pair,
753                       const PipelinePair &backward_end_pair, const PipelinePair &forward_start_pair,
754                       const FuncGraphPtr &root) {
755   auto manager = root->manager();
756   MS_EXCEPTION_IF_NULL(manager);
757   if (!forward_params_pair.second.empty()) {
758     auto prior_node = forward_params_pair.second.back();
759     auto post_node = forward_start_pair.first.front();
760     InsertDepend(prior_node, post_node, manager, root);
761   }
762   if (!backward_params_pair.first.empty()) {
763     auto prior_node2 = backward_end_pair.second.back();
764     prior_node2 = GetCallBackwardEndNext(prior_node2);
765     auto post_node2 = backward_params_pair.first.front();
766     InsertDepend(prior_node2, post_node2, manager, root);
767   }
768 }
769 
770 int64_t GetMicroBatch(const AnfNodePtr &node) {
771   MS_EXCEPTION_IF_NULL(node);
772   auto cnode = node->cast<CNodePtr>();
773   MS_EXCEPTION_IF_NULL(cnode);
774   auto micro_value = cnode->GetPrimalAttr(MICRO);
775   MS_EXCEPTION_IF_NULL(micro_value);
776   return GetValue<int64_t>(micro_value);
777 }
778 
779 void CommonDeduplicate(const std::vector<AnfNodePtr> &node_vector, std::vector<AnfNodePtr> *out_vec_begin,
780                        std::vector<AnfNodePtr> *out_vec_end, const FuncGraphPtr &root, int64_t micro_max,
781                        int64_t seg_max, int64_t h, bool is_train) {
782   std::vector<AnfNodePtr> temp_vec;
783   auto manager = root->manager();
784   for (int64_t i = 0; i <= micro_max; ++i) {
785     temp_vec.clear();
786     if (!is_train) {
787       temp_vec = node_vector;
788     } else {
789       for (auto &node : node_vector) {
790         auto node_micro = GetMicroBatch(node);
791         if (seg_max >= 1) {
792           auto node_seg = GetSegment(node);
793           if (node_micro == i && node_seg == h) {
794             temp_vec.push_back(node);
795           }
796         } else {
797           if (node_micro == i) {
798             temp_vec.push_back(node);
799           }
800         }
801       }
802     }
803     if (temp_vec.empty()) {
804       MS_LOG(INFO) << "No Duplicate MicroBatch.";
805       continue;
806     }
807     if (temp_vec.size() == 1) {
808       if (seg_max >= 1) {
809         MS_LOG(WARNING) << "Single element, no need to deduplicate.";
810         out_vec_begin->push_back(temp_vec.front());
811         out_vec_end->push_back(temp_vec.back());
812       }
813       continue;
814     }
815     std::sort(temp_vec.begin(), temp_vec.end(), CompFunc);
816     for (size_t j = 0; j < temp_vec.size() - 1; ++j) {
817       auto prior_node = temp_vec[j];
818       prior_node = GetCallBackwardEndNext(prior_node);
819       auto post_node = temp_vec[j + 1];
820       InsertDepend(prior_node, post_node, manager, root);
821     }
822     if (!temp_vec.empty()) {
823       out_vec_begin->push_back(temp_vec.front());
824       out_vec_end->push_back(temp_vec.back());
825     }
826   }
827 }
828 
829 PipelinePair GetForwardEndBeforePair(const PipelinePair &forward_end_pair) {
830   PipelinePair forward_end_before_pair;
831   if (!IsLastStage()) {
832     for (auto &node : forward_end_pair.first) {
833       auto cnode = node->cast<CNodePtr>();
834       auto temp_node = GetActualOp(cnode->input(1));
835       MS_EXCEPTION_IF_NULL(temp_node);
836       forward_end_before_pair.first.push_back(temp_node);
837     }
838     for (auto &node : forward_end_pair.second) {
839       auto cnode = node->cast<CNodePtr>();
840       auto temp_node = GetActualOp(cnode->input(1));
841       MS_EXCEPTION_IF_NULL(temp_node);
842       forward_end_before_pair.second.push_back(temp_node);
843     }
844   } else {
845     forward_end_before_pair = forward_end_pair;
846   }
847   return forward_end_before_pair;
848 }
849 
850 int64_t GetMicroMax(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &forward_end) {
851   int64_t micro_max = 0;
852   if (forward_end.empty()) {
853     MS_LOG(EXCEPTION) << "Cannot find the end node of the pipeline; you are advised to use 'PipelineCell' to fix it.";
854   } else {
855     auto forward_end_cnode = forward_end.back()->cast<CNodePtr>();
856     auto micro_size = forward_end_cnode->GetPrimalAttr(MICRO);
857     MS_EXCEPTION_IF_NULL(micro_size);
858     micro_max = GetValue<int64_t>(micro_size);
859   }
860   return micro_max;
861 }
862 
863 int64_t GetSegment(const AnfNodePtr &node) {
864   MS_EXCEPTION_IF_NULL(node);
865   auto cnode = node->cast<CNodePtr>();
866   MS_EXCEPTION_IF_NULL(cnode);
867   auto seg_value = cnode->GetPrimalAttr(SEGMENT);
868   MS_EXCEPTION_IF_NULL(seg_value);
869   return GetValue<int64_t>(seg_value);
870 }
871 
872 void BroadCastMicroBatch(const CNodePtr &node, NodeUsersMap *node_users_map, const ValuePtr &value, size_t max_depth) {
873   auto node_users = (*node_users_map)[node];
874   if (max_depth > MAX_RECURSIVE_DEPTH) {
875     MS_LOG(EXCEPTION) << "Recursive call is larger than 100000.";
876   }
877   for (auto &node_pair : node_users) {
878     auto user_node = node_pair.first->cast<CNodePtr>();
879     if (user_node->HasPrimalAttr(MICRO) || IsPrimitiveCNode(user_node, prim::kPrimUpdateState)) {
880       continue;
881     }
882     user_node->AddPrimalAttr(MICRO, value);
883     BroadCastMicroBatch(user_node, node_users_map, value, max_depth + 1);
884   }
885 }
886 
887 void BroadCastNeedGrad(const AnfNodePtr &node, NodeUsersMap *node_user_map, const FuncGraphPtr &root) {
888   auto node_users = (*node_user_map)[node];
889   for (auto &node_user : node_users) {
890     auto cnode = node_user.first->cast<CNodePtr>();
891     MS_EXCEPTION_IF_NULL(cnode);
892     if (cnode->HasPrimalAttr(NEED_GRAD)) {
893       continue;
894     }
895     if (cnode->func_graph() == root) {
896       continue;
897     }
898     cnode->AddPrimalAttr(NEED_GRAD, MakeValue(1));
899     BroadCastNeedGrad(cnode, node_user_map, root);
900   }
901 }
902 
903 // Label nodes that need backpropagation
904 void LabelNeedGrad(const FuncGraphManagerPtr &manager, const FuncGraphPtr &root) {
905   auto parameters = root->parameters();
906   auto &node_user_map = manager->node_users();
907   for (auto &parameter : parameters) {
908     if (!ParameterRequireGrad(parameter)) {
909       continue;
910     }
911     auto param_ptr = parameter->cast<ParameterPtr>();
912     MS_EXCEPTION_IF_NULL(param_ptr);
913     if (param_ptr->name().find(ACCU_GRADS) != std::string::npos) {
914       continue;
915     }
916     BroadCastNeedGrad(parameter, &node_user_map, root);
917   }
918 }
919 
920 void LastStageEndNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager,
921                       const FuncGraphPtr &root) {
922   if (!IsLastStage()) {
923     return;
924   }
925   LabelNeedGrad(manager, root);
926   for (auto &node : all_nodes) {
927     if (!node->isa<CNode>()) {
928       continue;
929     }
930     auto cnode = node->cast<CNodePtr>();
931     if (!cnode->HasPrimalAttr(MICRO)) {
932       continue;
933     }
934     auto prim = GetCNodePrimitive(node);
935     if (prim && prim->HasAttr(PIPELINE_END)) {
936       for (size_t i = 0; i < cnode->size(); ++i) {
937         auto temp_node = GetRealKernelNode(cnode->input(i), -1, nullptr).first;
938         if (!temp_node->isa<CNode>()) {
939           continue;
940         }
941         auto temp_prim = GetCNodePrimitive(temp_node);
942         if (!temp_prim || temp_prim->HasAttr(PIPELINE_END)) {
943           continue;
944         }
945         InsertVirtualPipelineEndNode(cnode, manager, i);
946       }
947     }
948   }
949 }
950 
951 ValuePtr Micro(const CNodePtr &cnode, NodeUsersMap *node_users_map, size_t max_depth) {
952   if (max_depth > MAX_RECURSIVE_DEPTH) {
953     MS_LOG(EXCEPTION) << "Recursive call is larger than 100000.";
954   }
955   if (cnode->HasPrimalAttr(MICRO)) {
956     return cnode->GetPrimalAttr(MICRO);
957   }
958   auto node_users = (*node_users_map)[cnode];
959   for (auto &node_pair : node_users) {
960     auto user_node = node_pair.first->cast<CNodePtr>();
961     auto micro = Micro(user_node, node_users_map, max_depth + 1);
962     if (micro) {
963       return micro;
964     }
965   }
966   return nullptr;
967 }
968 
969 void ParameterStartNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager) {
970   auto &node_users_map = manager->node_users();
971   for (auto &node : all_nodes) {
972     if (!node->isa<CNode>()) {
973       continue;
974     }
975     auto cnode = node->cast<CNodePtr>();
976     auto prim = GetCNodePrimitive(node);
977     if (prim && prim->HasAttr(PARAMETER_START_SHARE_CELL)) {
978       cnode->AddPrimalAttr(PARAMETER_START_SHARE_CELL, prim->GetAttr(PARAMETER_START_SHARE_CELL));
979       continue;
980     }
981     if (prim && prim->HasAttr(PARAMETER_START)) {
982       auto micro = Micro(cnode, &node_users_map, 0);
983       MS_EXCEPTION_IF_NULL(micro);
984       auto new_prim = prim->Clone();
985       new_prim->SetAttrs(prim->attrs());
986       manager->SetEdge(cnode, 0, NewValueNode(new_prim));
987       cnode->AddPrimalAttr(MICRO, micro);
988       cnode->AddPrimalAttr(PARAMETER_START, micro);
989       int64_t seg = 0;
990       cnode->AddPrimalAttr(SEGMENT, MakeValue(seg));
991     }
992   }
993 }
994 
995 void HandleMicroBatch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager) {
996   auto &node_users_map = manager->node_users();
997   for (auto &node : all_nodes) {
998     if (!node->isa<CNode>()) {
999       continue;
1000     }
1001     auto cnode = node->cast<CNodePtr>();
1002     if (!cnode->HasPrimalAttr(MICRO)) {
1003       continue;
1004     }
1005     auto micro = cnode->GetPrimalAttr(MICRO);
1006     MS_EXCEPTION_IF_NULL(micro);
1007     BroadCastMicroBatch(cnode, &node_users_map, micro, 0);
1008   }
1009 }
1010 
1011 AnfNodePtr GetActualOp(const AnfNodePtr &node) {
1012   if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
1013     auto cnode = node->cast<CNodePtr>();
1014     return cnode->input(1);
1015   }
1016   return node;
1017 }
1018 
1019 void GetBorderNode(std::vector<AnfNodePtr> *forward_start, std::vector<AnfNodePtr> *forward_end,
1020                    std::vector<AnfNodePtr> *backward_start, std::vector<AnfNodePtr> *backward_end,
1021                    std::vector<AnfNodePtr> *forward_params, std::vector<AnfNodePtr> *backward_params,
1022                    std::vector<AnfNodePtr> *allreduce_params, const FuncGraphPtr &root) {
1023   std::list<ValuePtr> name_list = {};
1024   int64_t slice_index = 0;
1025   auto all_nodes = DeepScopedGraphSearch(root->get_return());
1026   for (auto &node : all_nodes) {
1027     if (!node->isa<CNode>() || IsPrimitiveCNode(node, prim::kPrimDepend) ||
1028         IsPrimitiveCNode(node, prim::kPrimZerosLike)) {
1029       continue;
1030     }
1031     auto prim = GetCNodePrimitive(node);
1032     auto cnode = node->cast<CNodePtr>();
1033     auto share_cell = EnableShareCell();
1034     if (share_cell && cnode->HasPrimalAttr(PARAMETER_START)) {
1035       backward_end->push_back(node);
1036     }
1037     if (cnode->HasPrimalAttr(kPrimalAttrForwardNodeName)) {
1038       auto forward_node_name = cnode->GetPrimalAttr(kPrimalAttrForwardNodeName);
1039       if (std::find(name_list.begin(), name_list.end(), forward_node_name) != name_list.end()) {
1040         continue;
1041       }
1042       name_list.push_back(forward_node_name);
1043       if (cnode->HasPrimalAttr(PIPELINE_END)) {
1044         backward_start->push_back(node);
1045       }
1046       if (cnode->HasPrimalAttr(PIPELINE_BEGIN)) {
1047         backward_end->push_back(node);
1048       }
1049       if (!share_cell && cnode->HasPrimalAttr(PARAMETER_START)) {
1050         backward_end->push_back(node);
1051       }
1052       if (cnode->HasPrimalAttr(PIPELINE_PARAM)) {
1053         backward_params->push_back(node);
1054       }
1055       if (prim->HasAttr(PARAMETER_MICRO)) {
1056         allreduce_params->push_back(node);
1057       }
1058       continue;
1059     }
1060     // At this point cnode->HasPrimalAttr(kPrimalAttrForwardNodeName) is false.
1061     if (cnode->HasPrimalAttr(PIPELINE_BEGIN)) {
1062       if (IsPrimitiveCNode(cnode, prim::kPrimStridedSlice)) {
1063         cnode->AddPrimalAttr(SLICE_INDEX, MakeValue(slice_index));
1064         slice_index += 1;
1065       }
1066       forward_start->push_back(node);
1067     }
1068     if (cnode->HasPrimalAttr(PIPELINE_END)) {
1069       forward_end->push_back(node);
1070     }
1071     if (cnode->HasPrimalAttr(PIPELINE_PARAM)) {
1072       forward_params->push_back(node);
1073     }
1074   }
1075   std::sort((*backward_start).begin(), (*backward_start).end(), CompFuncBySegDescending);
1076   std::sort((*backward_end).begin(), (*backward_end).end(), CompFuncBySegDescending);
1077   std::sort((*forward_start).begin(), (*forward_start).end(), CompFuncBySegAscending);
1078   std::sort((*forward_end).begin(), (*forward_end).end(), CompFuncBySegAscending);
1079   std::sort((*backward_params).begin(), (*backward_params).end(), CompFunc);
1080   std::sort((*forward_params).begin(), (*forward_params).end(), CompFunc);
1081 }
1082 
1083 void CheckBorderNode(const PipelinePair &forward_start_pair, const PipelinePair &forward_end_pair,
1084                      const PipelinePair &backward_start_pair, const PipelinePair &backward_end_pair,
1085                      std::vector<int64_t> seg_micro_max) {
1086   auto micro_size = LongToSize(seg_micro_max[0] + 1);
1087   auto seg_size = LongToSize(seg_micro_max[1] + 1);
1088   auto total_micro_size = micro_size * seg_size;
1089   std::string cause = ". One possible cause is that the @lazy_inline decorator is misplaced.";
1090   if (forward_start_pair.first.size() != total_micro_size) {
1091     MS_LOG(EXCEPTION) << "forward_node's size: " << forward_start_pair.first.size()
1092                       << " is not equal to micro size: " << total_micro_size << cause;
1093   }
1094   if (forward_end_pair.first.size() != total_micro_size) {
1095     MS_LOG(EXCEPTION) << "forward_node's size: " << forward_end_pair.first.size()
1096                       << " is not equal to micro size: " << total_micro_size << cause;
1097   }
1098   if (backward_start_pair.first.size() != total_micro_size) {
1099     MS_LOG(EXCEPTION) << "backward_node's size: " << backward_start_pair.first.size()
1100                       << " is not equal to micro size: " << total_micro_size << cause;
1101   }
1102   if (backward_end_pair.first.size() != total_micro_size) {
1103     MS_LOG(EXCEPTION) << "backward_node's size: " << backward_end_pair.first.size()
1104                       << " is not equal to micro size: " << total_micro_size << cause;
1105   }
1106 }
1107 
1108 void Reorder(const FuncGraphPtr &root) {
1109   std::vector<AnfNodePtr> forward_start;
1110   std::vector<AnfNodePtr> forward_end;
1111   std::vector<AnfNodePtr> forward_params;
1112   std::vector<AnfNodePtr> backward_start;
1113   std::vector<AnfNodePtr> backward_end;
1114   std::vector<AnfNodePtr> backward_params;
1115   std::vector<AnfNodePtr> allreduce_params;
1116   SetParameterStartForCellShare(root);
1117   GetBorderNode(&forward_start, &forward_end, &backward_start, &backward_end, &forward_params, &backward_params,
1118                 &allreduce_params, root);
1119   int64_t micro_max = GetMicroMax(root, forward_end);
1120   std::vector<int64_t> seg_micro_max{micro_max, 0};
1121   auto backward_start_pair = Deduplicate(backward_start, root, micro_max, 0, true);
1122   auto backward_end_pair = Deduplicate(backward_end, root, micro_max, 0, true);
1123   auto forward_start_pair = Deduplicate(forward_start, root, micro_max, 0, true);
1124   auto forward_end_pair = Deduplicate(forward_end, root, micro_max, 0, true);
1125   auto forward_params_pair = Deduplicate(forward_params, root, micro_max, 0, true);
1126   auto backward_params_pair = Deduplicate(backward_params, root, micro_max, 0, true);
1127   CheckBorderNode(forward_start_pair, forward_end_pair, backward_start_pair, backward_end_pair, seg_micro_max);
1128   auto forward_end_before_pair = GetForwardEndBeforePair(forward_end_pair);
1129   auto ret_after = root->get_return();
1130   MS_EXCEPTION_IF_NULL(ret_after);
1131   auto all_nodes = DeepScopedGraphSearch(ret_after);
1132   auto manager = root->manager();
1133   for (auto &node : all_nodes) {
1134     if (!node->isa<CNode>()) {
1135       continue;
1136     }
1137     if (IsSomePrimitive(node->cast<CNodePtr>(), kNPUClearFloatStatusOpName)) {
1138       InsertDepend(node, forward_end.front(), manager, root);
1139       break;
1140     }
1141   }
1142   ReorderForForward(forward_start_pair.first, forward_end_pair.second, root);
1143   ReorderForBackward(forward_start_pair, forward_end_pair, backward_start_pair, backward_end_pair,
1144                      forward_end_before_pair, root);
1145   ReorderForParams(backward_params_pair, forward_params_pair, backward_end_pair, forward_start_pair, root);
1146 }
1147 
1148 void ReorderForPredict(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
1149   std::vector<AnfNodePtr> forward_end;
1150   std::vector<AnfNodePtr> forward_start;
1151   std::vector<AnfNodePtr> forward_params;
1152   int64_t slice_index = 0;
1153   for (auto &node : root->nodes()) {
1154     if (!node->isa<CNode>()) {
1155       continue;
1156     }
1157     auto cnode = node->cast<CNodePtr>();
1158     if (cnode->HasPrimalAttr(PIPELINE_BEGIN)) {
1159       if (IsPrimitiveCNode(cnode, prim::kPrimStridedSlice)) {
1160         cnode->AddPrimalAttr(SLICE_INDEX, MakeValue(slice_index));
1161         slice_index += 1;
1162       }
1163       forward_start.push_back(node);
1164     }
1165     if (cnode->HasPrimalAttr(PIPELINE_END)) {
1166       forward_end.push_back(node);
1167     }
1168     if (cnode->HasPrimalAttr(PIPELINE_PARAM)) {
1169       forward_params.push_back(node);
1170     }
1171   }
1172   std::sort(forward_start.begin(), forward_start.end(), CompFunc);
1173   std::sort(forward_end.begin(), forward_end.end(), CompFunc);
1174   std::sort(forward_params.begin(), forward_params.end(), CompFunc);
1175   auto forward_start_pair = Deduplicate(forward_start, root, 0, 0, false);
1176   auto forward_end_pair = Deduplicate(forward_end, root, 0, 0, false);
1177   auto forward_params_pair = Deduplicate(forward_params, root, 0, 0, false);
1178   if (!forward_end.empty() && !forward_params.empty()) {
1179     InsertDepend(forward_params_pair.second[0], forward_end_pair.first[0], manager, root);
1180   }
1181   if (!forward_start.empty() && !forward_params.empty()) {
1182     InsertDepend(forward_params_pair.second[0], forward_start_pair.first[0], manager, root);
1183   }
1184 }
1185 
1186 int64_t GetRank() {
1187   auto ms_context = MsContext::GetInstance();
1188   MS_EXCEPTION_IF_NULL(ms_context);
1189   auto world_group = GetWorldGroup();
1190   int64_t global_rank = parallel::ParallelContext::GetInstance()->global_rank();
1191   uint32_t rank_id = 0;
1192   if (!parallel::ParallelContext::GetInstance()->global_rank_is_set()) {
1193     if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) {
1194       MS_LOG(EXCEPTION) << "Get rank id failed.";
1195     }
1196     global_rank = UintToInt(rank_id);
1197   }
1198   return global_rank;
1199 }
1200 
1201 std::string GetWorldGroup() {
1202   auto context = MsContext::GetInstance();
1203   MS_EXCEPTION_IF_NULL(context);
1204   std::string group;
1205   std::string backend = context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
1206   if (backend == kAscendDevice) {
1207     group = parallel::HCCL_WORLD_GROUP;
1208   } else if (backend == kGPUDevice) {
1209     group = parallel::NCCL_WORLD_GROUP;
1210   } else {
1211     MS_LOG(EXCEPTION) << "Invalid backend: " << backend;
1212   }
1213   return group;
1214 }
1215 
1216 int64_t InferStage() {
1217   auto global_rank = GetRank();
1218   auto world_group = GetWorldGroup();
1219   uint32_t world_rank_size = 0;
1220   int64_t device_num = 0;
1221   auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num();
1222   if (!parallel::ParallelContext::GetInstance()->device_num_is_set()) {
1223     if (!CommManager::GetInstance().GetRankSize(world_group, &world_rank_size)) {
1224       MS_LOG(EXCEPTION) << "Get rank size failed";
1225     }
1226     device_num = UintToInt(world_rank_size);
1227     MS_LOG(INFO) << "Get device num from communication model, the device num is  " << device_num;
1228   } else {
1229     device_num = parallel::ParallelContext::GetInstance()->device_num();
1230   }
1231 
1232   if (device_num < 1) {
1233     MS_LOG(ERROR) << "For 'PipelineSplit', the argument 'device_num' must be positive, "
1234                      "but got the value of device_num: "
1235                   << device_num;
1236   }
1237   if (global_rank < 0) {
1238     MS_LOG(ERROR) << "For 'PipelineSplit', the argument 'global_rank' must be nonnegative, "
1239                      "but got the value of global_rank: "
1240                   << global_rank;
1241   }
1242   if (stage_num == 0) {
1243     MS_LOG(EXCEPTION) << "Stage_num is zero";
1244   }
1245   if (device_num % stage_num != 0) {
1246     MS_LOG(EXCEPTION) << "Device_num must be divisible by the stage_num, got device_num: " << device_num
1247                       << " stage_num: " << stage_num;
1248   }
1249   auto per_stage_rank_num = device_num / stage_num;
1250   return global_rank / per_stage_rank_num;
1251 }
1252 }  // namespace parallel
1253 }  // namespace mindspore
1254