1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/graph_util/pipeline_split_utils.h"
18 #include <algorithm>
19 #include <list>
20 #include <memory>
21 #include <queue>
22 #include <set>
23 #include "include/common/utils/comm_manager.h"
24 #include "frontend/parallel/device_manager.h"
25 #include "frontend/parallel/graph_util/generate_graph.h"
26 #include "frontend/parallel/graph_util/node_info.h"
27 #include "frontend/parallel/ops_info/ops_utils.h"
28 #include "frontend/parallel/step_parallel.h"
29 #include "frontend/parallel/step_parallel_utils.h"
30 #include "frontend/parallel/dynamic_shape/dynamic_shape.h"
31 #include "frontend/parallel/graph_util/fold_pipeline_split_utils.h"
32 #include "include/common/utils/parallel_context.h"
33 #include "ir/value.h"
34 #include "ops/array_ops.h"
35 #include "ops/framework_ops.h"
36 #include "ops/other_ops.h"
37 #include "ops/sequence_ops.h"
38 #include "utils/parallel_node_check.h"
39
40 namespace mindspore {
41 namespace parallel {
42 namespace {
IsSendRec(const AnfNodePtr & node)43 bool IsSendRec(const AnfNodePtr &node) {
44 return IsPrimitiveCNode(node, prim::kPrimSend) || IsPrimitiveCNode(node, prim::kPrimReceive);
45 }
46
TagForSendRecDepend(const AnfNodePtr & prior_node,const AnfNodePtr & post_node)47 std::string TagForSendRecDepend(const AnfNodePtr &prior_node, const AnfNodePtr &post_node) {
48 if (!IsSendRec(prior_node) || !IsSendRec(post_node)) {
49 return "";
50 }
51 if (prior_node->cast<CNodePtr>()->HasPrimalAttr(kPrimalAttrForwardNodeName) ==
52 post_node->cast<CNodePtr>()->HasPrimalAttr(kPrimalAttrForwardNodeName)) {
53 return "";
54 }
55 return std::string(SEND_REC_DEPEND);
56 }
57 } // namespace
58
IsFirstStage()59 bool IsFirstStage() {
60 MS_EXCEPTION_IF_NULL(g_device_manager);
61 auto stage_id = g_device_manager->stage_id();
62 return stage_id == 0;
63 }
64
IsLastStage()65 bool IsLastStage() {
66 MS_EXCEPTION_IF_NULL(g_device_manager);
67 auto stage_num = g_device_manager->stage_num();
68 auto stage_id = g_device_manager->stage_id();
69 return ((stage_num - 1) == stage_id);
70 }
71
GetReceiveMicro(const CNodePtr & cnode)72 static ValuePtr GetReceiveMicro(const CNodePtr &cnode) {
73 std::queue<CNodePtr> que;
74 std::set<AnfNodePtr> visited;
75 que.push(cnode);
76 while (!que.empty()) {
77 auto front = que.front();
78 que.pop();
79 (void)(visited.insert(front));
80 for (size_t i = 1; i < front->size(); ++i) {
81 auto input = front->input(i);
82 if (!input->isa<CNode>()) {
83 continue;
84 }
85 auto cinput = input->cast<CNodePtr>();
86 MS_EXCEPTION_IF_NULL(cinput);
87 if (IsPrimitiveCNode(cinput, prim::kPrimReceive)) {
88 return cinput->GetPrimalAttr(MICRO);
89 }
90 if (visited.find(cinput) == visited.end()) {
91 que.push(cinput);
92 }
93 }
94 }
95 return nullptr;
96 }
97
GetReceiveSegment(const CNodePtr & cnode)98 static ValuePtr GetReceiveSegment(const CNodePtr &cnode) {
99 std::queue<CNodePtr> que;
100 std::set<AnfNodePtr> visited;
101 que.push(cnode);
102 while (!que.empty()) {
103 auto front = que.front();
104 que.pop();
105 (void)(visited.insert(front));
106 for (size_t i = 1; i < front->size(); ++i) {
107 auto input = front->input(i);
108 if (!input->isa<CNode>()) {
109 continue;
110 }
111 auto cinput = input->cast<CNodePtr>();
112 MS_EXCEPTION_IF_NULL(cinput);
113 if (IsPrimitiveCNode(cinput, prim::kPrimReceive)) {
114 return cinput->GetPrimalAttr(SEGMENT);
115 }
116 if (visited.find(cinput) == visited.end()) {
117 que.push(cinput);
118 }
119 }
120 }
121 return nullptr;
122 }
123
EnableShareCell()124 static bool EnableShareCell() {
125 auto context = MsContext::GetInstance();
126 MS_EXCEPTION_IF_NULL(context);
127 const auto cell_reuse = context->CellReuseLevel() != CellReuseLevel::kNoCellReuse;
128 const auto &comm_reuse_env = common::GetEnv("MS_COMM_COMPILER_OPT");
129 if (!comm_reuse_env.empty() && cell_reuse) {
130 MS_LOG(EXCEPTION) << "The cell reuse cannot be used with communication reuse,"
131 " please unset environment variable 'MS_COMM_COMPILER_OPT'";
132 }
133 MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
134 bool grad_accumulation_shard = ParallelContext::GetInstance()->grad_accumulation_shard();
135 if (grad_accumulation_shard && cell_reuse) {
136 MS_LOG(EXCEPTION)
137 << "The cell reuse cannot be used with sharding accumulate grad parameter with optimizer parallel,"
138 " please set_auto_parallel_context(parallel_optimizer_config={'gradient_accumulation_shard':False})";
139 }
140 return cell_reuse;
141 }
142
// Returns the node stored under CALL_BACKWARD_END_NEXT user data (recorded by SetParameterStartForCellShare
// for tuple-returning call_backward nodes), or `node` itself when no such user data exists.
// Note: the stored value may be null if no suitable TupleGetItem successor was found.
static AnfNodePtr GetCallBackwardEndNext(const AnfNodePtr &node) {
  return node->has_user_data(CALL_BACKWARD_END_NEXT) ? node->user_data<AnfNode>(CALL_BACKWARD_END_NEXT) : node;
}
149
IsValidNode(const AnfNodePtr & node,const AnfNodePtr & return_node,const NodeUsersMap & node_user_map)150 bool IsValidNode(const AnfNodePtr &node, const AnfNodePtr &return_node, const NodeUsersMap &node_user_map) {
151 if (node == return_node) {
152 return true;
153 }
154 auto iter = node_user_map.find(node);
155 if (iter == node_user_map.end()) {
156 return false;
157 }
158 const auto &users = (*iter).second;
159 return std::any_of(users.begin(), users.end(),
160 [&return_node, &node_user_map](const std::pair<AnfNodePtr, int> &user) {
161 return IsValidNode(user.first, return_node, node_user_map);
162 });
163 }
164
165 // judge if the graph call specified grad nodes, the specified grad nodes is in grad_graph
166 // search if call specified grad nodes according to dfs
CallGradNodes(const FuncGraphPtr & graph,const FuncGraphPtr & grad_graph,std::set<FuncGraphPtr> * const visit)167 static bool CallGradNodes(const FuncGraphPtr &graph, const FuncGraphPtr &grad_graph,
168 std::set<FuncGraphPtr> *const visit) {
169 if (visit->find(graph) != visit->end()) {
170 return false;
171 }
172 if (graph == grad_graph) {
173 return true;
174 }
175 (void)(visit->insert(graph));
176 const auto &cnodes = graph->GetOrderedCnodes();
177 for (const auto &cnode : cnodes) {
178 const auto &abs = cnode->input(0)->abstract();
179 if (!abs || !abs->isa<abstract::AbstractFunction>()) {
180 continue;
181 }
182 const auto &abs_func = abs->cast<abstract::AbstractFunctionPtr>();
183 if (!abs_func->isa<abstract::FuncGraphAbstractClosure>()) {
184 continue;
185 }
186 const auto &abs_func_graph = abs->cast<abstract::FuncGraphAbstractClosurePtr>();
187 auto fg = abs_func_graph->func_graph();
188 if (fg && fg == grad_graph) {
189 return true;
190 }
191 if (CallGradNodes(fg, grad_graph, visit)) {
192 return true;
193 }
194 }
195 return false;
196 }
197
// Locates the func graph holding the "specified grad nodes": the first CNode (in DeepScopedGraphSearch
// order from root's return) carrying both the PARAMETER_START_SHARE_CELL and kPrimalAttrForwardNodeName
// primal attributes. Throws if no such node exists.
static FuncGraphPtr FindGradGraph(const FuncGraphPtr &root) {
  const auto &nodes = DeepScopedGraphSearch(root->get_return());
  const auto hit = std::find_if(nodes.begin(), nodes.end(), [](const AnfNodePtr &candidate) {
    if (!candidate->isa<CNode>()) {
      return false;
    }
    const auto &candidate_cnode = candidate->cast<CNodePtr>();
    return candidate_cnode->HasPrimalAttr(PARAMETER_START_SHARE_CELL) &&
           candidate_cnode->HasPrimalAttr(kPrimalAttrForwardNodeName);
  });
  if (hit == nodes.end()) {
    MS_LOG(EXCEPTION) << "Stage0: The grad graph has not been found in lazy inline mode.";
    return nullptr;
  }
  const auto &grad_graph = (*hit)->func_graph();
  MS_LOG(INFO) << "The specified grad nodes is in graph " << grad_graph->ToString();
  return grad_graph;
}
214
SetParameterStartForCellShare(const FuncGraphPtr & root)215 void SetParameterStartForCellShare(const FuncGraphPtr &root) {
216 MS_EXCEPTION_IF_NULL(root);
217 auto share_cell = EnableShareCell();
218 if (!share_cell) {
219 return;
220 }
221 if (!IsFirstStage()) {
222 return;
223 }
224 FuncGraphPtr grad_graph = FindGradGraph(root);
225 MS_EXCEPTION_IF_NULL(grad_graph);
226 const auto &manager = root->manager();
227 auto node_user_map = manager->node_users();
228 auto all_nodes = root->GetOrderedCnodes();
229 std::set<FuncGraphPtr> call_grad_nodes;
230 bool has_find = false;
231 for (auto &node : all_nodes) {
232 // if cnode is a call_backward node
233 if (!IsPrimitiveCNode(node->input(0), prim::kPrimTupleGetItem)) {
234 continue;
235 }
236 const auto &abs = node->input(0)->abstract();
237 if (!abs || !abs->isa<abstract::AbstractFunction>()) {
238 continue;
239 }
240 const auto &abs_func = abs->cast<abstract::AbstractFunctionPtr>();
241 if (!abs_func->isa<abstract::FuncGraphAbstractClosure>()) {
242 continue;
243 }
244 std::set<FuncGraphPtr> visit;
245 const auto &abs_func_graph = abs->cast<abstract::FuncGraphAbstractClosurePtr>();
246 auto fg = abs_func_graph->func_graph();
247 if (!fg || (call_grad_nodes.find(fg) == call_grad_nodes.end() && !CallGradNodes(fg, grad_graph, &visit))) {
248 continue;
249 }
250 if (call_grad_nodes.find(fg) == call_grad_nodes.end()) {
251 (void)(call_grad_nodes.insert(fg));
252 }
253 auto micro = GetReceiveMicro(node);
254 MS_EXCEPTION_IF_NULL(micro);
255 auto node_abs = node->abstract();
256 if (node_abs->isa<abstract::AbstractTuple>()) {
257 CNodePtr next = nullptr;
258 const auto &users = node_user_map[node];
259 for (const auto &user : users) {
260 const auto &cuser = user.first->cast<CNodePtr>();
261 MS_EXCEPTION_IF_NULL(cuser);
262 if (IsPrimitiveCNode(cuser, prim::kPrimTupleGetItem) && IsValidNode(cuser, root->get_return(), node_user_map)) {
263 next = cuser;
264 break;
265 }
266 }
267 node->set_user_data<AnfNode>(CALL_BACKWARD_END_NEXT, next);
268 }
269 has_find = true;
270 node->AddPrimalAttr(MICRO, micro);
271 node->AddPrimalAttr(PARAMETER_START, micro);
272 auto parallel_context = parallel::ParallelContext::GetInstance();
273 if (parallel_context->enable_fold_pipeline()) {
274 auto segment = GetReceiveSegment(node);
275 MS_EXCEPTION_IF_NULL(segment);
276 node->AddPrimalAttr(SEGMENT, segment);
277 }
278 }
279 if (!has_find) {
280 MS_LOG(EXCEPTION) << "Stage0: The backward end flag has not been marked in lazy inline mode.";
281 } else {
282 MS_LOG(INFO) << "Stage0: The backward end flag has been marked in lazy inline mode.";
283 }
284 }
285
FindAccuGrad(const CNodePtr & cnode)286 AnfNodePtr FindAccuGrad(const CNodePtr &cnode) {
287 auto pre_node = cnode->input(1);
288 size_t depth = 0;
289 while (true) {
290 if (depth > MAX_RECURSIVE_DEPTH) {
291 return nullptr;
292 }
293 depth += 1;
294 if (pre_node->isa<Parameter>()) {
295 return pre_node;
296 } else {
297 if (pre_node->isa<CNode>()) {
298 auto pre_cnode = pre_node->cast<CNodePtr>();
299 pre_node = pre_cnode->input(1);
300 } else {
301 return nullptr;
302 }
303 }
304 }
305 return nullptr;
306 }
307
SetStridedSliceStrategy(const AnfNodePtr & node)308 void SetStridedSliceStrategy(const AnfNodePtr &node) {
309 MS_EXCEPTION_IF_NULL(node);
310 if (!IsPrimitiveCNode(node, prim::kPrimStridedSlice)) {
311 return;
312 }
313 bool full_batch = ParallelContext::GetInstance()->full_batch();
314 auto dev_num = g_device_manager->stage_device_num();
315 auto cnode = node->cast<CNodePtr>();
316 MS_EXCEPTION_IF_NULL(cnode);
317 std::vector<Shapes> shape_list;
318 if (InDynamicGraph(cnode)) {
319 shape_list = ExtractRealDivisor(cnode);
320 MS_LOG(INFO) << "the node is in dynamic shape graph, the divisor is " << ShapesToString(shape_list[0]);
321 } else {
322 shape_list = ExtractShape(cnode);
323 }
324
325 if (shape_list.empty()) {
326 MS_LOG(EXCEPTION) << "Failure:node " << cnode->ToString() << " failed to extract shape";
327 }
328 std::vector<ValuePtr> elements;
329 for (size_t i = 0; i < shape_list[0].size(); i++) {
330 if (shape_list[0][i].empty()) {
331 MS_LOG(EXCEPTION) << "shape_list[ " << i << " ].size() is zero";
332 }
333 Dimensions input_strategy;
334 for (size_t j = 0; j < shape_list[0][i].size(); j++) {
335 input_strategy.push_back(1);
336 }
337 static const auto skip_redis = (common::GetEnv("PIPELINE_SLICE_SKIP_REDISTRIBUTION") == "1");
338 if (skip_redis && !full_batch && input_strategy.size() > 0) {
339 auto dim = shape_list[1][0][0];
340 if (dev_num <= dim && ((dim % dev_num) == 0)) {
341 input_strategy[0] = dev_num;
342 } else if (dim < dev_num && ((dev_num % dim) == 0)) {
343 input_strategy[0] = dim;
344 }
345 auto prim = GetCNodePrimitive(node);
346 if (prim->HasAttr("out_shard_size")) {
347 auto out_shard_size = GetValue<int64_t>(prim->GetAttr("out_shard_size"));
348 input_strategy[0] = out_shard_size;
349 }
350 auto attrs = prim->attrs();
351 attrs[parallel::SKIP_REDISTRIBUTION] = MakeValue<bool>(true);
352 (void)prim->SetAttrs(attrs);
353 }
354
355 elements.push_back(MakeValue(input_strategy));
356 }
357 ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements);
358 cnode->AddPrimalAttr(IN_STRATEGY, strategy);
359 }
360
FindNodeWithMircoSize(const AnfNodePtr & node_user,const NodeUsersMap & node_users_map)361 CNodePtr FindNodeWithMircoSize(const AnfNodePtr &node_user, const NodeUsersMap &node_users_map) {
362 // Recursively find micro tags, this may takes much more time if layers are too much
363 std::queue<AnfNodePtr> visited;
364 visited.push(node_user);
365 while (!visited.empty()) {
366 auto cur_node = visited.front();
367 visited.pop();
368 if (node_users_map.find(cur_node) == node_users_map.end()) {
369 continue;
370 }
371 auto users = node_users_map.at(cur_node);
372 for (auto &temp_user : users) {
373 auto cnode = temp_user.first->cast<CNodePtr>();
374 MS_EXCEPTION_IF_NULL(cnode);
375 if (!cnode->HasPrimalAttr(MICRO)) {
376 visited.push(temp_user.first);
377 } else {
378 return cnode;
379 }
380 }
381 }
382 return nullptr;
383 }
384
IsSourceUsedByMirror(const CNodePtr & node,const NodeUsersMap & node_user_map)385 bool IsSourceUsedByMirror(const CNodePtr &node, const NodeUsersMap &node_user_map) {
386 if (node->size() < 2) {
387 return false;
388 }
389 auto parameter_node = node->input(1);
390 if (parameter_node->cast<ParameterPtr>()) {
391 for (auto &item : node_user_map.at(parameter_node)) {
392 if (IsPrimitiveCNode(item.first, prim::kPrimMirrorMicroStep)) {
393 return true;
394 }
395 }
396 }
397 return false;
398 }
InsertVirtualAssignAdd(const std::pair<AnfNodePtr,int> & node_user,const FuncGraphManagerPtr & manager,const AnfNodePtr & accu_parameter,const NodeUsersMap & node_user_map)399 void InsertVirtualAssignAdd(const std::pair<AnfNodePtr, int> &node_user, const FuncGraphManagerPtr &manager,
400 const AnfNodePtr &accu_parameter, const NodeUsersMap &node_user_map) {
401 auto cnode = node_user.first->cast<CNodePtr>();
402 if (IsPrimitiveCNode(cnode, prim::kPrimReceive) || !cnode->in_forward_flag()) {
403 return;
404 }
405 MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
406 bool enable_parallel_optimizer = ParallelContext::GetInstance()->enable_parallel_optimizer();
407 bool grad_accumulation_shard = ParallelContext::GetInstance()->grad_accumulation_shard();
408 auto is_pp_interleave = ParallelContext::GetInstance()->pipeline_interleave();
409 if (!is_pp_interleave && IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) {
410 return;
411 }
412 if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && enable_parallel_optimizer &&
413 IsSourceUsedByMirror(cnode, node_user_map)) {
414 return;
415 }
416 auto param_ptr = accu_parameter->cast<ParameterPtr>();
417 MS_EXCEPTION_IF_NULL(param_ptr);
418 // If grad_accumulation_shard is ture, a ReduceScatter will be inserted at each micro step,
419 // So the fusion id should be different for each micro step
420 // otherwise they will be fused into the one ReduceScatter alone micro_steps.
421 // if grad_accumulation_shard is false, we pass an empty group, so no ReduceScatter will be inserted
422 ValuePtr args1 = nullptr;
423 ValuePtr args2 = nullptr;
424 ValuePtr micro = nullptr;
425 int64_t step = 0;
426 if (grad_accumulation_shard) {
427 auto cnode_with_micro_size = FindNodeWithMircoSize(cnode, node_user_map);
428 if (cnode_with_micro_size && cnode_with_micro_size->HasPrimalAttr(MICRO)) {
429 micro = cnode_with_micro_size->GetPrimalAttr(MICRO);
430 step = GetValue<int64_t>(micro);
431 }
432 }
433 args1 = MakeValue(param_ptr->user_data<TensorLayout>()->opt_shard_group());
434 args2 = MakeValue(LongToSize(param_ptr->param_info()->comm_fusion()) + LongToSize(step) * PIPELINE_FUSTION_OFFSET);
435 OperatorAttrs attrs = {};
436 auto py_instance = CreateOpInstance(attrs, VIRTUAL_ASSIGN_ADD, VIRTUAL_ASSIGN_ADD);
437 auto value_node = NewValueNode(py_instance);
438 // Set the attribute of the reduce scatter
439 auto new_prim = GetValueNode<PrimitivePtr>(value_node);
440 MS_EXCEPTION_IF_NULL(new_prim);
441 auto attrs_prim = new_prim->attrs();
442 attrs_prim[GROUP] = args1;
443 attrs_prim[kAttrFusion] = args2;
444 if (cnode->HasPrimalAttr(PIPELINE_PARAM)) {
445 attrs_prim[PIPELINE_PARAM] = MakeValue(true);
446 }
447 (void)new_prim->SetAttrs(attrs_prim);
448
449 std::vector<AnfNodePtr> virtual_node_input = {value_node, cnode->input(IntToSize(node_user.second)), accu_parameter};
450 auto graph = cnode->func_graph();
451 auto virtual_node = graph->NewCNode(virtual_node_input);
452 manager->SetEdge(cnode, node_user.second, virtual_node);
453 }
454
InsertVirtualAccuGrad(const AnfNodePtr & recv,const FuncGraphManagerPtr & manager,const AnfNodePtr & param)455 void InsertVirtualAccuGrad(const AnfNodePtr &recv, const FuncGraphManagerPtr &manager, const AnfNodePtr ¶m) {
456 auto cnode = recv->cast<CNodePtr>();
457 MS_EXCEPTION_IF_NULL(cnode);
458 OperatorAttrs attrs;
459 auto py_instance = CreateOpInstance(attrs, VIRTUAL_ACCU_GRAD, VIRTUAL_ACCU_GRAD);
460 auto value_node = NewValueNode(py_instance);
461 std::vector<AnfNodePtr> virtual_node_input = {value_node, recv, param};
462 auto graph = cnode->func_graph();
463 MS_EXCEPTION_IF_NULL(graph);
464 auto virtual_node = graph->NewCNode(virtual_node_input);
465 (void)manager->Replace(recv, virtual_node);
466 }
467
FindGradAccuParameter(const std::vector<AnfNodePtr> & parameters,const std::string & name)468 AnfNodePtr FindGradAccuParameter(const std::vector<AnfNodePtr> ¶meters, const std::string &name) {
469 for (auto ¶meter : parameters) {
470 auto param_ptr = parameter->cast<ParameterPtr>();
471 MS_EXCEPTION_IF_NULL(param_ptr);
472 if (param_ptr->name() == name) {
473 continue;
474 }
475 auto expect_name = "accu_grads." + name;
476 if (param_ptr->name() == expect_name) {
477 return parameter;
478 }
479 }
480 return nullptr;
481 }
482
483 // If the graph likes the followings:
484 // 1. MicroStepAllGather->MirrorMicro->load, we need to visit the param after the load
FindNextNode(const std::pair<AnfNodePtr,int> & node_ptr,const NodeUsersMap & node_users_map,const std::set<string> & check_list={prim::kPrimMirrorMicroStep->name(), prim::kPrimMicroStepAllGather->name(), prim::kPrimLoad->name()})485 std::vector<std::pair<AnfNodePtr, int>> FindNextNode(
486 const std::pair<AnfNodePtr, int> &node_ptr, const NodeUsersMap &node_users_map,
487 const std::set<string> &check_list = {prim::kPrimMirrorMicroStep->name(), prim::kPrimMicroStepAllGather->name(),
488 prim::kPrimLoad->name()}) {
489 std::vector<std::pair<AnfNodePtr, int>> to_be_visited_set;
490 if (!IsSomePrimitiveList(node_ptr.first->cast<CNodePtr>(), check_list)) {
491 (void)to_be_visited_set.emplace_back(node_ptr);
492 return to_be_visited_set;
493 }
494 auto node_set = node_users_map.at(node_ptr.first);
495 std::queue<std::pair<std::shared_ptr<AnfNode>, int>> visited;
496 for (auto &node_user : node_set) {
497 visited.push(node_user);
498 }
499 while (visited.size() >= 1) {
500 auto node = visited.front();
501 visited.pop();
502 if (!IsSomePrimitiveList(node.first->cast<CNodePtr>(), check_list)) {
503 (void)to_be_visited_set.emplace_back(node);
504 } else {
505 auto next_node_set = node_users_map.at(node.first);
506 for (auto &node_user : next_node_set) {
507 visited.push(node_user);
508 }
509 }
510 }
511 return to_be_visited_set;
512 }
513
FuncNodeUsersSet(const AnfNodePtr & parameter)514 std::set<std::pair<AnfNodePtr, int>> FuncNodeUsersSet(const AnfNodePtr ¶meter) {
515 MS_EXCEPTION_IF_NULL(parameter->func_graph());
516 MS_EXCEPTION_IF_NULL(parameter->func_graph()->manager());
517 auto node_users_map = parameter->func_graph()->manager()->node_users();
518 auto node_users = node_users_map[parameter];
519 std::set<std::pair<AnfNodePtr, int>> all_node_users;
520 for (auto &n_pair : node_users) {
521 auto users_skip_virtual_nodes =
522 FindNextNode(n_pair, node_users_map,
523 {prim::kPrimMirrorMicroStep->name(), prim::kPrimMicroStepAllGather->name(), prim::kPrimLoad->name(),
524 prim::kPrimCast->name()});
525 for (const auto &node_pair : users_skip_virtual_nodes) {
526 auto func_node_users = FuncGraphNodeUsers(node_pair);
527 if (func_node_users.empty()) {
528 (void)all_node_users.insert(node_pair);
529 continue;
530 }
531 for (const auto &func_node_user : func_node_users) {
532 (void)all_node_users.insert(func_node_user);
533 }
534 }
535 }
536 return all_node_users;
537 }
538
HandleReceiveParam(const FuncGraphPtr & root)539 void HandleReceiveParam(const FuncGraphPtr &root) {
540 auto parameters = root->parameters();
541 auto node_users_map = root->manager()->node_users();
542 auto all_nodes = TopoSort(root->get_return(), SuccDeeperSimple);
543 for (auto &node : all_nodes) {
544 if (!IsPrimitiveCNode(node, prim::kPrimReceive)) {
545 continue;
546 }
547 auto cnode = node->cast<CNodePtr>();
548 if (!cnode->HasPrimalAttr(PIPELINE_PARAM)) {
549 continue;
550 }
551 auto parameter_ptr = cnode->input(1)->cast<ParameterPtr>();
552 MS_EXCEPTION_IF_NULL(parameter_ptr);
553 auto accu_parameter = FindGradAccuParameter(parameters, parameter_ptr->name());
554 if (!accu_parameter) {
555 continue;
556 }
557 auto base_shape = accu_parameter->Shape();
558 auto shape_ptr = dyn_cast<abstract::Shape>(base_shape);
559 auto slice_shape = shape_ptr->shape();
560 auto prim = GetCNodePrimitive(cnode);
561 std::vector<ValuePtr> element;
562 (void)std::transform(slice_shape.begin(), slice_shape.end(), std::back_inserter(element),
563 [](int64_t elem) { return MakeValue(elem); });
564 auto value = std::make_shared<ValueList>(element);
565 prim->set_attr(SHAPE, value);
566 std::set<std::pair<AnfNodePtr, int>> all_node_users = FuncNodeUsersSet(node);
567 for (auto &temp_user : all_node_users) {
568 auto temp_node = temp_user.first;
569 // Micro virtual operator might be inserted after cast
570 if (IsPrimitiveCNode(temp_node, prim::kPrimCast)) {
571 temp_node = node_users_map[temp_node].begin()->first;
572 }
573 if (IsPrimitiveCNode(temp_node, prim::kPrimMirrorMicroStep) ||
574 IsPrimitiveCNode(temp_node, prim::kPrimMicroStepAllGather)) {
575 auto node_set = node_users_map[temp_node];
576 for (auto &node_user : node_set) {
577 InsertVirtualAssignAdd(node_user, root->manager(), accu_parameter, node_users_map);
578 }
579 } else {
580 InsertVirtualAssignAdd(temp_user, root->manager(), accu_parameter, node_users_map);
581 }
582 }
583 InsertVirtualAccuGrad(node, root->manager(), accu_parameter);
584 }
585 }
586
AddVirtualAssignAdd(const FuncGraphPtr & root)587 void AddVirtualAssignAdd(const FuncGraphPtr &root) {
588 auto parameters = root->parameters();
589 auto node_users_map = root->manager()->node_users();
590 for (auto ¶meter : parameters) {
591 auto parameter_ptr = parameter->cast<ParameterPtr>();
592 auto accu_parameter = FindGradAccuParameter(parameters, parameter_ptr->name());
593 if (!accu_parameter) {
594 continue;
595 }
596 std::set<std::pair<AnfNodePtr, int>> all_node_users = FuncNodeUsersSet(parameter);
597 for (auto &temp_user : all_node_users) {
598 // Micro virtual operator might be inserted after cast
599 auto temp_node = temp_user;
600 if (IsPrimitiveCNode(temp_node.first, prim::kPrimCast)) {
601 temp_node = *node_users_map[temp_node.first].begin();
602 }
603 if (!IsSomePrimitiveList(
604 temp_node.first->cast<CNodePtr>(),
605 {prim::kPrimMirrorMicroStep->name(), prim::kPrimMicroStepAllGather->name(), prim::kPrimLoad->name()})) {
606 InsertVirtualAssignAdd(temp_node, root->manager(), accu_parameter, node_users_map);
607 continue;
608 }
609 auto node_set = FindNextNode(temp_node, node_users_map);
610 for (auto &node_user : node_set) {
611 InsertVirtualAssignAdd(node_user, root->manager(), accu_parameter, node_users_map);
612 }
613 }
614 }
615 }
616
SliceSort(const CNodePtr & cnode1,const CNodePtr & cnode2)617 bool SliceSort(const CNodePtr &cnode1, const CNodePtr &cnode2) {
618 if (IsPrimitiveCNode(cnode1, prim::kPrimStridedSlice) && IsPrimitiveCNode(cnode2, prim::kPrimStridedSlice)) {
619 auto slice_index1 = GetValue<int64_t>(cnode1->GetPrimalAttr(SLICE_INDEX));
620 auto slice_index2 = GetValue<int64_t>(cnode2->GetPrimalAttr(SLICE_INDEX));
621 return slice_index1 < slice_index2;
622 }
623 if (IsPrimitiveCNode(cnode1, prim::kPrimStridedSlice)) {
624 return false;
625 }
626 return true;
627 }
628
CompFunc(const AnfNodePtr & node1,const AnfNodePtr & node2)629 bool CompFunc(const AnfNodePtr &node1, const AnfNodePtr &node2) {
630 MS_EXCEPTION_IF_NULL(node1);
631 MS_EXCEPTION_IF_NULL(node2);
632 auto cnode1 = node1->cast<CNodePtr>();
633 auto cnode2 = node2->cast<CNodePtr>();
634 MS_EXCEPTION_IF_NULL(cnode1);
635 MS_EXCEPTION_IF_NULL(cnode2);
636 auto micro1 = cnode1->GetPrimalAttr(MICRO);
637 auto micro2 = cnode2->GetPrimalAttr(MICRO);
638 MS_EXCEPTION_IF_NULL(micro1);
639 MS_EXCEPTION_IF_NULL(micro2);
640 auto micro1_value = GetValue<int64_t>(micro1);
641 auto micro2_value = GetValue<int64_t>(micro2);
642 if (micro1_value == micro2_value) {
643 if (IsPrimitiveCNode(node1, prim::kPrimStridedSlice) || IsPrimitiveCNode(node2, prim::kPrimStridedSlice)) {
644 return SliceSort(cnode1, cnode2);
645 }
646 auto prim1 = GetCNodePrimitive(cnode1);
647 auto prim2 = GetCNodePrimitive(cnode2);
648 if (EnableShareCell() && prim1 == nullptr && prim2 == nullptr) {
649 return false;
650 }
651 MS_EXCEPTION_IF_NULL(prim1);
652 MS_EXCEPTION_IF_NULL(prim2);
653 auto rank_tag1 = prim1->GetAttr(SRC_RANK);
654 auto rank_tag2 = prim2->GetAttr(SRC_RANK);
655 if (rank_tag1 == nullptr) {
656 rank_tag1 = prim1->GetAttr(DEST_RANK);
657 }
658 if (rank_tag2 == nullptr) {
659 rank_tag2 = prim2->GetAttr(DEST_RANK);
660 }
661 if (!rank_tag1 || !rank_tag2) {
662 return false;
663 }
664 auto rank1_value = GetValue<int64_t>(rank_tag1);
665 auto rank2_value = GetValue<int64_t>(rank_tag2);
666 if (rank1_value == rank2_value) {
667 auto sr_tag1 = prim1->GetAttr(SR_TAG);
668 auto sr_tag2 = prim2->GetAttr(SR_TAG);
669 MS_EXCEPTION_IF_NULL(sr_tag1);
670 MS_EXCEPTION_IF_NULL(sr_tag2);
671 auto sr1_value = GetValue<int64_t>(sr_tag1);
672 auto sr2_value = GetValue<int64_t>(sr_tag2);
673 return sr1_value < sr2_value;
674 }
675 return rank1_value < rank2_value;
676 }
677 return micro1_value < micro2_value;
678 }
679
InsertDepend(const AnfNodePtr & prior_node,const AnfNodePtr & post_node,const FuncGraphManagerPtr & manager,const FuncGraphPtr & root,const std::string & attr_tag)680 void InsertDepend(const AnfNodePtr &prior_node, const AnfNodePtr &post_node, const FuncGraphManagerPtr &manager,
681 const FuncGraphPtr &root, const std::string &attr_tag) {
682 MS_EXCEPTION_IF_NULL(prior_node);
683 MS_EXCEPTION_IF_NULL(post_node);
684 auto post_cnode = post_node->cast<CNodePtr>();
685 MS_EXCEPTION_IF_NULL(post_cnode);
686 std::vector<AnfNodePtr> depend_input = {NewValueNode(prim::kPrimDepend), post_cnode->input(1), prior_node};
687 auto depend_node = root->NewCNode(depend_input);
688 depend_node->set_abstract(post_cnode->input(1)->abstract());
689 if (!attr_tag.empty()) {
690 depend_node->AddAttr(attr_tag, MakeValue<bool>(true));
691 }
692 manager->SetEdge(post_node, 1, depend_node);
693 }
694
ReorderForForward(const std::vector<AnfNodePtr> & forward_start,const std::vector<AnfNodePtr> & forward_end,const FuncGraphPtr & root)695 void ReorderForForward(const std::vector<AnfNodePtr> &forward_start, const std::vector<AnfNodePtr> &forward_end,
696 const FuncGraphPtr &root) {
697 MS_EXCEPTION_IF_NULL(g_device_manager);
698 MS_EXCEPTION_IF_NULL(root);
699 auto manager = root->manager();
700 MS_EXCEPTION_IF_NULL(manager);
701 auto stage_num = g_device_manager->stage_num();
702 auto stage_id = g_device_manager->stage_id();
703 for (size_t i = 1; i < LongToSize(stage_num - stage_id); ++i) {
704 auto prior_node = forward_end[i - 1];
705 auto post_node = forward_start[i];
706 InsertDepend(prior_node, post_node, manager, root);
707 }
708 }
709
ReorderForBackward(const PipelinePair & forward_start_pair,const PipelinePair & forward_end_pair,const PipelinePair & backward_start_pair,const PipelinePair & backward_end_pair,const PipelinePair & forward_end_before_pair,const FuncGraphPtr & root)710 void ReorderForBackward(const PipelinePair &forward_start_pair, const PipelinePair &forward_end_pair,
711 const PipelinePair &backward_start_pair, const PipelinePair &backward_end_pair,
712 const PipelinePair &forward_end_before_pair, const FuncGraphPtr &root) {
713 MS_EXCEPTION_IF_NULL(g_device_manager);
714 MS_EXCEPTION_IF_NULL(root);
715 auto manager = root->manager();
716 MS_EXCEPTION_IF_NULL(manager);
717 auto stage_num = g_device_manager->stage_num();
718 auto stage_id = g_device_manager->stage_id();
719 for (size_t i = LongToSize(stage_num - stage_id); i < (forward_start_pair.first.size()); ++i) {
720 auto prior_node1 = forward_end_before_pair.second[i];
721 auto post_node1 = backward_start_pair.first[LongToSize(SizeToLong(i) - stage_num + stage_id + 1)];
722 InsertDepend(prior_node1, post_node1, manager, root, TagForSendRecDepend(prior_node1, post_node1));
723 auto prior_node2 = backward_end_pair.second[LongToSize(SizeToLong(i) - stage_num + stage_id)];
724 prior_node2 = GetCallBackwardEndNext(prior_node2);
725 auto post_node2 = forward_start_pair.first[i];
726 InsertDepend(prior_node2, post_node2, manager, root, TagForSendRecDepend(prior_node2, post_node2));
727 }
728 for (size_t i = LongToSize(stage_num - stage_id); i < (forward_start_pair.first.size() + 1); ++i) {
729 if (!IsLastStage()) {
730 auto prior_node3 = backward_start_pair.second[LongToSize(SizeToLong(i) - stage_num + stage_id)];
731 auto post_node3 = forward_end_pair.first[i - 1];
732 InsertDepend(prior_node3, post_node3, manager, root, TagForSendRecDepend(prior_node3, post_node3));
733 auto prior_node4 = forward_end_pair.second[i - 1];
734 auto post_node4 = backward_end_pair.first[LongToSize(SizeToLong(i) - stage_num + stage_id)];
735 InsertDepend(prior_node4, post_node4, manager, root, TagForSendRecDepend(prior_node4, post_node4));
736 }
737 }
738 for (size_t j = LongToSize(SizeToLong(backward_start_pair.first.size()) - stage_num + stage_id + 1);
739 j < backward_start_pair.first.size(); ++j) {
740 auto prior_node5 = backward_end_pair.second[j - 1];
741 prior_node5 = GetCallBackwardEndNext(prior_node5);
742 auto post_node5 = backward_start_pair.first[j];
743 InsertDepend(prior_node5, post_node5, manager, root, TagForSendRecDepend(prior_node5, post_node5));
744 }
745 if (!IsLastStage()) {
746 auto prior_node6 = forward_end_before_pair.second[LongToSize(stage_num - 1 - stage_id)];
747 auto post_node6 = backward_start_pair.first[0];
748 InsertDepend(prior_node6, post_node6, manager, root, TagForSendRecDepend(prior_node6, post_node6));
749 }
750 }
751
ReorderForParams(const PipelinePair & backward_params_pair,const PipelinePair & forward_params_pair,const PipelinePair & backward_end_pair,const PipelinePair & forward_start_pair,const FuncGraphPtr & root)752 void ReorderForParams(const PipelinePair &backward_params_pair, const PipelinePair &forward_params_pair,
753 const PipelinePair &backward_end_pair, const PipelinePair &forward_start_pair,
754 const FuncGraphPtr &root) {
755 auto manager = root->manager();
756 MS_EXCEPTION_IF_NULL(manager);
757 if (!forward_params_pair.second.empty()) {
758 auto prior_node = forward_params_pair.second.back();
759 auto post_node = forward_start_pair.first.front();
760 InsertDepend(prior_node, post_node, manager, root);
761 }
762 if (!backward_params_pair.first.empty()) {
763 auto prior_node2 = backward_end_pair.second.back();
764 prior_node2 = GetCallBackwardEndNext(prior_node2);
765 auto post_node2 = backward_params_pair.first.front();
766 InsertDepend(prior_node2, post_node2, manager, root);
767 }
768 }
769
GetMicroBatch(const AnfNodePtr & node)770 int64_t GetMicroBatch(const AnfNodePtr &node) {
771 MS_EXCEPTION_IF_NULL(node);
772 auto cnode = node->cast<CNodePtr>();
773 MS_EXCEPTION_IF_NULL(cnode);
774 auto micro_value = cnode->GetPrimalAttr(MICRO);
775 MS_EXCEPTION_IF_NULL(micro_value);
776 return GetValue<int64_t>(micro_value);
777 }
778
CommonDeduplicate(const std::vector<AnfNodePtr> & node_vector,std::vector<AnfNodePtr> * out_vec_begin,std::vector<AnfNodePtr> * out_vec_end,const FuncGraphPtr & root,int64_t micro_max,int64_t seg_max,int64_t h,bool is_train)779 void CommonDeduplicate(const std::vector<AnfNodePtr> &node_vector, std::vector<AnfNodePtr> *out_vec_begin,
780 std::vector<AnfNodePtr> *out_vec_end, const FuncGraphPtr &root, int64_t micro_max,
781 int64_t seg_max, int64_t h, bool is_train) {
782 std::vector<AnfNodePtr> temp_vec;
783 auto manager = root->manager();
784 for (int64_t i = 0; i <= micro_max; ++i) {
785 temp_vec.clear();
786 if (!is_train) {
787 temp_vec = node_vector;
788 } else {
789 for (auto &node : node_vector) {
790 auto node_micro = GetMicroBatch(node);
791 if (seg_max >= 1) {
792 auto node_seg = GetSegment(node);
793 if (node_micro == i && node_seg == h) {
794 temp_vec.push_back(node);
795 }
796 } else {
797 if (node_micro == i) {
798 temp_vec.push_back(node);
799 }
800 }
801 }
802 }
803 if (temp_vec.empty()) {
804 MS_LOG(INFO) << "No Duplicate MicroBatch.";
805 continue;
806 }
807 if (temp_vec.size() == 1) {
808 if (seg_max >= 1) {
809 MS_LOG(WARNING) << "Single element, no need to deduplicate.";
810 out_vec_begin->push_back(temp_vec.front());
811 out_vec_end->push_back(temp_vec.back());
812 }
813 continue;
814 }
815 std::sort(temp_vec.begin(), temp_vec.end(), CompFunc);
816 for (size_t j = 0; j < temp_vec.size() - 1; ++j) {
817 auto prior_node = temp_vec[j];
818 prior_node = GetCallBackwardEndNext(prior_node);
819 auto post_node = temp_vec[j + 1];
820 InsertDepend(prior_node, post_node, manager, root);
821 }
822 if (!temp_vec.empty()) {
823 out_vec_begin->push_back(temp_vec.front());
824 out_vec_end->push_back(temp_vec.back());
825 }
826 }
827 }
828
// On every stage except the last, maps each forward end node to the operator feeding its input 1, unwrapping one
// Depend through GetActualOp. The mapping is applied to both `.first` and `.second`, preserving order.
// On the last stage, `forward_end_pair` is returned unchanged.
// Raises (MS_EXCEPTION_IF_NULL) when a node is null or not a CNode, or when the resolved input is null.
PipelinePair GetForwardEndBeforePair(const PipelinePair &forward_end_pair) {
  if (IsLastStage()) {
    return forward_end_pair;
  }
  // Appends, for each end node, the actual op producing its input 1.
  auto collect_producers = [](const std::vector<AnfNodePtr> &end_nodes, std::vector<AnfNodePtr> *out) {
    for (auto &node : end_nodes) {
      MS_EXCEPTION_IF_NULL(node);
      auto cnode = node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      auto temp_node = GetActualOp(cnode->input(1));
      MS_EXCEPTION_IF_NULL(temp_node);
      out->push_back(temp_node);
    }
  };
  PipelinePair forward_end_before_pair;
  collect_producers(forward_end_pair.first, &forward_end_before_pair.first);
  collect_producers(forward_end_pair.second, &forward_end_before_pair.second);
  return forward_end_before_pair;
}
849
GetMicroMax(const FuncGraphPtr & root,const std::vector<AnfNodePtr> & forward_end)850 int64_t GetMicroMax(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &forward_end) {
851 int64_t micro_max = 0;
852 if (forward_end.empty()) {
853 MS_LOG(EXCEPTION) << "can not find the end node of pipeline, you are advised to use 'PipelineCell' to fix it.";
854 } else {
855 auto forward_end_cnode = forward_end.back()->cast<CNodePtr>();
856 auto micro_size = forward_end_cnode->GetPrimalAttr(MICRO);
857 MS_EXCEPTION_IF_NULL(micro_size);
858 micro_max = GetValue<int64_t>(micro_size);
859 }
860 return micro_max;
861 }
862
GetSegment(const AnfNodePtr & node)863 int64_t GetSegment(const AnfNodePtr &node) {
864 MS_EXCEPTION_IF_NULL(node);
865 auto cnode = node->cast<CNodePtr>();
866 MS_EXCEPTION_IF_NULL(cnode);
867 auto seg_value = cnode->GetPrimalAttr(SEGMENT);
868 MS_EXCEPTION_IF_NULL(seg_value);
869 return GetValue<int64_t>(seg_value);
870 }
871
BroadCastMicroBatch(const CNodePtr & node,NodeUsersMap * node_users_map,const ValuePtr & value,size_t max_depth)872 void BroadCastMicroBatch(const CNodePtr &node, NodeUsersMap *node_users_map, const ValuePtr &value, size_t max_depth) {
873 auto node_users = (*node_users_map)[node];
874 if (max_depth > MAX_RECURSIVE_DEPTH) {
875 MS_LOG(EXCEPTION) << "Recursive call is larger than 100000.";
876 }
877 for (auto &node_pair : node_users) {
878 auto user_node = node_pair.first->cast<CNodePtr>();
879 if (user_node->HasPrimalAttr(MICRO) || IsPrimitiveCNode(user_node, prim::kPrimUpdateState)) {
880 continue;
881 }
882 user_node->AddPrimalAttr(MICRO, value);
883 BroadCastMicroBatch(user_node, node_users_map, value, max_depth + 1);
884 }
885 }
886
BroadCastNeedGrad(const AnfNodePtr & node,NodeUsersMap * node_user_map,const FuncGraphPtr & root)887 void BroadCastNeedGrad(const AnfNodePtr &node, NodeUsersMap *node_user_map, const FuncGraphPtr &root) {
888 auto node_users = (*node_user_map)[node];
889 for (auto &node_user : node_users) {
890 auto cnode = node_user.first->cast<CNodePtr>();
891 MS_EXCEPTION_IF_NULL(cnode);
892 if (cnode->HasPrimalAttr(NEED_GRAD)) {
893 continue;
894 }
895 if (cnode->func_graph() == root) {
896 continue;
897 }
898 cnode->AddPrimalAttr(NEED_GRAD, MakeValue(1));
899 BroadCastNeedGrad(cnode, node_user_map, root);
900 }
901 }
902
903 // Label node that need backpropagation
LabelNeedGrad(const FuncGraphManagerPtr & manager,const FuncGraphPtr & root)904 void LabelNeedGrad(const FuncGraphManagerPtr &manager, const FuncGraphPtr &root) {
905 auto parameters = root->parameters();
906 auto &node_user_map = manager->node_users();
907 for (auto ¶meter : parameters) {
908 if (!ParameterRequireGrad(parameter)) {
909 continue;
910 }
911 auto param_ptr = parameter->cast<ParameterPtr>();
912 MS_EXCEPTION_IF_NULL(param_ptr);
913 if (param_ptr->name().find(ACCU_GRADS) != std::string::npos) {
914 continue;
915 }
916 BroadCastNeedGrad(parameter, &node_user_map, root);
917 }
918 }
919
LastStageEndNode(const std::vector<AnfNodePtr> & all_nodes,const FuncGraphManagerPtr & manager,const FuncGraphPtr & root)920 void LastStageEndNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager,
921 const FuncGraphPtr &root) {
922 if (!IsLastStage()) {
923 return;
924 }
925 LabelNeedGrad(manager, root);
926 for (auto &node : all_nodes) {
927 if (!node->isa<CNode>()) {
928 continue;
929 }
930 auto cnode = node->cast<CNodePtr>();
931 if (!cnode->HasPrimalAttr(MICRO)) {
932 continue;
933 }
934 auto prim = GetCNodePrimitive(node);
935 if (prim && prim->HasAttr(PIPELINE_END)) {
936 for (size_t i = 0; i < cnode->size(); ++i) {
937 auto temp_node = GetRealKernelNode(cnode->input(i), -1, nullptr).first;
938 if (!temp_node->isa<CNode>()) {
939 continue;
940 }
941 auto temp_prim = GetCNodePrimitive(temp_node);
942 if (!temp_prim || temp_prim->HasAttr(PIPELINE_END)) {
943 continue;
944 }
945 InsertVirtualPipelineEndNode(cnode, manager, i);
946 }
947 }
948 }
949 }
950
Micro(const CNodePtr & cnode,NodeUsersMap * node_users_map,size_t max_depth)951 ValuePtr Micro(const CNodePtr &cnode, NodeUsersMap *node_users_map, size_t max_depth) {
952 if (max_depth > MAX_RECURSIVE_DEPTH) {
953 MS_LOG(EXCEPTION) << "Recursive call is larger than 100000.";
954 }
955 if (cnode->HasPrimalAttr(MICRO)) {
956 return cnode->GetPrimalAttr(MICRO);
957 }
958 auto node_users = (*node_users_map)[cnode];
959 for (auto &node_pair : node_users) {
960 auto user_node = node_pair.first->cast<CNodePtr>();
961 auto micro = Micro(user_node, node_users_map, max_depth + 1);
962 if (micro) {
963 return micro;
964 }
965 }
966 return nullptr;
967 }
968
ParameterStartNode(const std::vector<AnfNodePtr> & all_nodes,const FuncGraphManagerPtr & manager)969 void ParameterStartNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager) {
970 auto &node_users_map = manager->node_users();
971 for (auto &node : all_nodes) {
972 if (!node->isa<CNode>()) {
973 continue;
974 }
975 auto cnode = node->cast<CNodePtr>();
976 auto prim = GetCNodePrimitive(node);
977 if (prim && prim->HasAttr(PARAMETER_START_SHARE_CELL)) {
978 cnode->AddPrimalAttr(PARAMETER_START_SHARE_CELL, prim->GetAttr(PARAMETER_START_SHARE_CELL));
979 continue;
980 }
981 if (prim && prim->HasAttr(PARAMETER_START)) {
982 auto micro = Micro(cnode, &node_users_map, 0);
983 MS_EXCEPTION_IF_NULL(micro);
984 auto new_prim = prim->Clone();
985 new_prim->SetAttrs(prim->attrs());
986 manager->SetEdge(cnode, 0, NewValueNode(new_prim));
987 cnode->AddPrimalAttr(MICRO, micro);
988 cnode->AddPrimalAttr(PARAMETER_START, micro);
989 int64_t seg = 0;
990 cnode->AddPrimalAttr(SEGMENT, MakeValue(seg));
991 }
992 }
993 }
994
HandleMicroBatch(const std::vector<AnfNodePtr> & all_nodes,const FuncGraphManagerPtr & manager)995 void HandleMicroBatch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager) {
996 auto &node_users_map = manager->node_users();
997 for (auto &node : all_nodes) {
998 if (!node->isa<CNode>()) {
999 continue;
1000 }
1001 auto cnode = node->cast<CNodePtr>();
1002 if (!cnode->HasPrimalAttr(MICRO)) {
1003 continue;
1004 }
1005 auto micro = cnode->GetPrimalAttr(MICRO);
1006 MS_EXCEPTION_IF_NULL(micro);
1007 BroadCastMicroBatch(cnode, &node_users_map, micro, 0);
1008 }
1009 }
1010
// Unwraps a single Depend: returns input 1 of `node` when it is a Depend CNode, otherwise `node` itself.
AnfNodePtr GetActualOp(const AnfNodePtr &node) {
  if (!IsPrimitiveCNode(node, prim::kPrimDepend)) {
    return node;
  }
  return node->cast<CNodePtr>()->input(1);
}
1018
GetBorderNode(std::vector<AnfNodePtr> * forward_start,std::vector<AnfNodePtr> * forward_end,std::vector<AnfNodePtr> * backward_start,std::vector<AnfNodePtr> * backward_end,std::vector<AnfNodePtr> * forward_params,std::vector<AnfNodePtr> * backward_params,std::vector<AnfNodePtr> * allreduce_params,const FuncGraphPtr & root)1019 void GetBorderNode(std::vector<AnfNodePtr> *forward_start, std::vector<AnfNodePtr> *forward_end,
1020 std::vector<AnfNodePtr> *backward_start, std::vector<AnfNodePtr> *backward_end,
1021 std::vector<AnfNodePtr> *forward_params, std::vector<AnfNodePtr> *backward_params,
1022 std::vector<AnfNodePtr> *allreduce_params, const FuncGraphPtr &root) {
1023 std::list<ValuePtr> name_list = {};
1024 int64_t slice_index = 0;
1025 auto all_nodes = DeepScopedGraphSearch(root->get_return());
1026 for (auto &node : all_nodes) {
1027 if (!node->isa<CNode>() || IsPrimitiveCNode(node, prim::kPrimDepend) ||
1028 IsPrimitiveCNode(node, prim::kPrimZerosLike)) {
1029 continue;
1030 }
1031 auto prim = GetCNodePrimitive(node);
1032 auto cnode = node->cast<CNodePtr>();
1033 auto share_cell = EnableShareCell();
1034 if (share_cell && cnode->HasPrimalAttr(PARAMETER_START)) {
1035 backward_end->push_back(node);
1036 }
1037 if (cnode->HasPrimalAttr(kPrimalAttrForwardNodeName)) {
1038 auto forward_node_name = cnode->GetPrimalAttr(kPrimalAttrForwardNodeName);
1039 if (std::find(name_list.begin(), name_list.end(), forward_node_name) != name_list.end()) {
1040 continue;
1041 }
1042 name_list.push_back(forward_node_name);
1043 if (cnode->HasPrimalAttr(PIPELINE_END)) {
1044 backward_start->push_back(node);
1045 }
1046 if (cnode->HasPrimalAttr(PIPELINE_BEGIN)) {
1047 backward_end->push_back(node);
1048 }
1049 if (!share_cell && cnode->HasPrimalAttr(PARAMETER_START)) {
1050 backward_end->push_back(node);
1051 }
1052 if (cnode->HasPrimalAttr(PIPELINE_PARAM)) {
1053 backward_params->push_back(node);
1054 }
1055 if (prim->HasAttr(PARAMETER_MICRO)) {
1056 allreduce_params->push_back(node);
1057 }
1058 continue;
1059 }
1060 // the return of cnode->HasPrimalAttr(kPrimalAttrForwardNodeName) is false.
1061 if (cnode->HasPrimalAttr(PIPELINE_BEGIN)) {
1062 if (IsPrimitiveCNode(cnode, prim::kPrimStridedSlice)) {
1063 cnode->AddPrimalAttr(SLICE_INDEX, MakeValue(slice_index));
1064 slice_index += 1;
1065 }
1066 forward_start->push_back(node);
1067 }
1068 if (cnode->HasPrimalAttr(PIPELINE_END)) {
1069 forward_end->push_back(node);
1070 }
1071 if (cnode->HasPrimalAttr(PIPELINE_PARAM)) {
1072 forward_params->push_back(node);
1073 }
1074 }
1075 std::sort((*backward_start).begin(), (*backward_start).end(), CompFuncBySegDescending);
1076 std::sort((*backward_end).begin(), (*backward_end).end(), CompFuncBySegDescending);
1077 std::sort((*forward_start).begin(), (*forward_start).end(), CompFuncBySegAscending);
1078 std::sort((*forward_end).begin(), (*forward_end).end(), CompFuncBySegAscending);
1079 std::sort((*backward_params).begin(), (*backward_params).end(), CompFunc);
1080 std::sort((*forward_params).begin(), (*forward_params).end(), CompFunc);
1081 }
1082
CheckBorderNode(const PipelinePair & forward_start_pair,const PipelinePair & forward_end_pair,const PipelinePair & backward_start_pair,const PipelinePair & backward_end_pair,std::vector<int64_t> seg_micro_max)1083 void CheckBorderNode(const PipelinePair &forward_start_pair, const PipelinePair &forward_end_pair,
1084 const PipelinePair &backward_start_pair, const PipelinePair &backward_end_pair,
1085 std::vector<int64_t> seg_micro_max) {
1086 auto micro_size = LongToSize(seg_micro_max[0] + 1);
1087 auto seg_size = LongToSize(seg_micro_max[1] + 1);
1088 auto total_micro_size = micro_size * seg_size;
1089 std::string cause = ". One possible cause is that the @lazy_inline decorator is misplaced.";
1090 if (forward_start_pair.first.size() != total_micro_size) {
1091 MS_LOG(EXCEPTION) << "forward_node's size:" << forward_start_pair.first.size()
1092 << "is not equal to micro size:" << total_micro_size << cause;
1093 }
1094 if (forward_end_pair.first.size() != total_micro_size) {
1095 MS_LOG(EXCEPTION) << "forward_node's size:" << forward_end_pair.first.size()
1096 << "is not equal to micro size:" << total_micro_size << cause;
1097 }
1098 if (backward_start_pair.first.size() != total_micro_size) {
1099 MS_LOG(EXCEPTION) << "backward_node's size:" << backward_start_pair.first.size()
1100 << "is not equal to micro size:" << total_micro_size << cause;
1101 }
1102 if (backward_end_pair.first.size() != total_micro_size) {
1103 MS_LOG(EXCEPTION) << "backward_node's size:" << backward_end_pair.first.size()
1104 << "is not equal to micro size:" << total_micro_size << cause;
1105 }
1106 }
1107
Reorder(const FuncGraphPtr & root)1108 void Reorder(const FuncGraphPtr &root) {
1109 std::vector<AnfNodePtr> forward_start;
1110 std::vector<AnfNodePtr> forward_end;
1111 std::vector<AnfNodePtr> forward_params;
1112 std::vector<AnfNodePtr> backward_start;
1113 std::vector<AnfNodePtr> backward_end;
1114 std::vector<AnfNodePtr> backward_params;
1115 std::vector<AnfNodePtr> allreduce_params;
1116 SetParameterStartForCellShare(root);
1117 GetBorderNode(&forward_start, &forward_end, &backward_start, &backward_end, &forward_params, &backward_params,
1118 &allreduce_params, root);
1119 int64_t micro_max = GetMicroMax(root, forward_end);
1120 std::vector<int64_t> seg_micro_max{micro_max, 0};
1121 auto backward_start_pair = Deduplicate(backward_start, root, micro_max, 0, true);
1122 auto backward_end_pair = Deduplicate(backward_end, root, micro_max, 0, true);
1123 auto forward_start_pair = Deduplicate(forward_start, root, micro_max, 0, true);
1124 auto forward_end_pair = Deduplicate(forward_end, root, micro_max, 0, true);
1125 auto forward_params_pair = Deduplicate(forward_params, root, micro_max, 0, true);
1126 auto backward_params_pair = Deduplicate(backward_params, root, micro_max, 0, true);
1127 CheckBorderNode(forward_start_pair, forward_end_pair, backward_start_pair, backward_end_pair, seg_micro_max);
1128 auto forward_end_before_pair = GetForwardEndBeforePair(forward_end_pair);
1129 auto ret_after = root->get_return();
1130 MS_EXCEPTION_IF_NULL(ret_after);
1131 auto all_nodes = DeepScopedGraphSearch(ret_after);
1132 auto manager = root->manager();
1133 for (auto &node : all_nodes) {
1134 if (!node->isa<CNode>()) {
1135 continue;
1136 }
1137 if (IsSomePrimitive(node->cast<CNodePtr>(), kNPUClearFloatStatusOpName)) {
1138 InsertDepend(node, forward_end.front(), manager, root);
1139 break;
1140 }
1141 }
1142 ReorderForForward(forward_start_pair.first, forward_end_pair.second, root);
1143 ReorderForBackward(forward_start_pair, forward_end_pair, backward_start_pair, backward_end_pair,
1144 forward_end_before_pair, root);
1145 ReorderForParams(backward_params_pair, forward_params_pair, backward_end_pair, forward_start_pair, root);
1146 }
1147
ReorderForPredict(const FuncGraphPtr & root,const FuncGraphManagerPtr & manager)1148 void ReorderForPredict(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
1149 std::vector<AnfNodePtr> forward_end;
1150 std::vector<AnfNodePtr> forward_start;
1151 std::vector<AnfNodePtr> forward_params;
1152 int64_t slice_index = 0;
1153 for (auto &node : root->nodes()) {
1154 if (!node->isa<CNode>()) {
1155 continue;
1156 }
1157 auto cnode = node->cast<CNodePtr>();
1158 if (cnode->HasPrimalAttr(PIPELINE_BEGIN)) {
1159 if (IsPrimitiveCNode(cnode, prim::kPrimStridedSlice)) {
1160 cnode->AddPrimalAttr(SLICE_INDEX, MakeValue(slice_index));
1161 slice_index += 1;
1162 }
1163 forward_start.push_back(node);
1164 }
1165 if (cnode->HasPrimalAttr(PIPELINE_END)) {
1166 forward_end.push_back(node);
1167 }
1168 if (cnode->HasPrimalAttr(PIPELINE_PARAM)) {
1169 forward_params.push_back(node);
1170 }
1171 }
1172 std::sort(forward_start.begin(), forward_start.end(), CompFunc);
1173 std::sort(forward_end.begin(), forward_end.end(), CompFunc);
1174 std::sort(forward_params.begin(), forward_params.end(), CompFunc);
1175 auto forward_start_pair = Deduplicate(forward_start, root, 0, 0, false);
1176 auto forward_end_pair = Deduplicate(forward_end, root, 0, 0, false);
1177 auto forward_params_pair = Deduplicate(forward_params, root, 0, 0, false);
1178 if (!forward_end.empty() && !forward_params.empty()) {
1179 InsertDepend(forward_params_pair.second[0], forward_end_pair.first[0], manager, root);
1180 }
1181 if (!forward_start.empty() && !forward_params.empty()) {
1182 InsertDepend(forward_params_pair.second[0], forward_start_pair.first[0], manager, root);
1183 }
1184 }
1185
GetRank()1186 int64_t GetRank() {
1187 auto ms_context = MsContext::GetInstance();
1188 MS_EXCEPTION_IF_NULL(ms_context);
1189 auto world_group = GetWorldGroup();
1190 int64_t global_rank = parallel::ParallelContext::GetInstance()->global_rank();
1191 uint32_t rank_id = 0;
1192 if (!parallel::ParallelContext::GetInstance()->global_rank_is_set()) {
1193 if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) {
1194 MS_LOG(EXCEPTION) << "Get rank id failed.";
1195 }
1196 global_rank = UintToInt(rank_id);
1197 }
1198 return global_rank;
1199 }
1200
GetWorldGroup()1201 std::string GetWorldGroup() {
1202 auto context = MsContext::GetInstance();
1203 MS_EXCEPTION_IF_NULL(context);
1204 std::string group;
1205 std::string backend = context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
1206 if (backend == kAscendDevice) {
1207 group = parallel::HCCL_WORLD_GROUP;
1208 } else if (backend == kGPUDevice) {
1209 group = parallel::NCCL_WORLD_GROUP;
1210 } else {
1211 MS_LOG(EXCEPTION) << "Invalid backend: " << backend;
1212 }
1213 return group;
1214 }
1215
InferStage()1216 int64_t InferStage() {
1217 auto global_rank = GetRank();
1218 auto world_group = GetWorldGroup();
1219 uint32_t world_rank_size = 0;
1220 int64_t device_num = 0;
1221 auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num();
1222 if (!parallel::ParallelContext::GetInstance()->device_num_is_set()) {
1223 if (!CommManager::GetInstance().GetRankSize(world_group, &world_rank_size)) {
1224 MS_LOG(EXCEPTION) << "Get rank size failed";
1225 }
1226 device_num = UintToInt(world_rank_size);
1227 MS_LOG(INFO) << "Get device num from communication model, the device num is " << device_num;
1228 } else {
1229 device_num = parallel::ParallelContext::GetInstance()->device_num();
1230 }
1231
1232 if (device_num < 1) {
1233 MS_LOG(ERROR) << "For 'PipelineSplit', the argument 'device_num' must be positive, "
1234 "but got the value of device_num: "
1235 << device_num;
1236 }
1237 if (global_rank < 0) {
1238 MS_LOG(ERROR) << "For 'PipelineSplit', the argument 'global_rank' must be nonnegative, "
1239 "but got the value of global_rank: "
1240 << global_rank;
1241 }
1242 if (stage_num == 0) {
1243 MS_LOG(EXCEPTION) << "Stage_num is zero";
1244 }
1245 if (device_num % stage_num != 0) {
1246 MS_LOG(EXCEPTION) << "Device_num must be divisible by the stage_num, got device_num: " << device_num
1247 << " stage_num: " << stage_num;
1248 }
1249 auto per_stage_rank_num = device_num / stage_num;
1250 return global_rank / per_stage_rank_num;
1251 }
1252 } // namespace parallel
1253 } // namespace mindspore
1254