1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/pipeline_transformer/fold_pipeline_transformer.h"
18 #include <set>
19 #include <vector>
20 #include <string>
21 #include <utility>
22 #include <algorithm>
23 #include <memory>
24 #include "frontend/parallel/pipeline_transformer/pipeline_transformer.h"
25 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
26 #include "frontend/parallel/graph_util/graph_splitter.h"
27 #include "frontend/parallel/ops_info/ops_utils.h"
28 #include "frontend/parallel/group_manager.h"
29 #include "frontend/parallel/parameter_manager.h"
30 #include "include/common/utils/parallel_context.h"
31 #include "frontend/parallel/step_parallel.h"
32 #include "frontend/parallel/node_check.h"
33 #include "frontend/parallel/graph_util/node_info.h"
34 #include "frontend/parallel/graph_util/pipeline_split_utils.h"
35 #include "frontend/parallel/step_parallel_utils.h"
36 #include "ir/anf.h"
37 #include "ir/graph_utils.h"
38 #include "ops/other_ops.h"
39 #include "ops/array_ops.h"
40 #include "ops/framework_ops.h"
41 #include "include/common/utils/comm_manager.h"
42 #include "utils/ms_context.h"
43 #include "utils/parallel_node_check.h"
44
45 namespace mindspore {
46 namespace parallel {
// Per-peer-rank counters used to assign unique, matching sr_tag values to the
// Send/Receive pairs inserted by fold pipeline (keyed by dest/src rank).
// They are cleared in CutGraph before graph inputs/outputs are handled.
mindspore::HashMap<int64_t, int64_t> fold_send_tag_map;
mindspore::HashMap<int64_t, int64_t> fold_recv_tag_map;
49
CreateForwardGroup2()50 void FoldPipelineTransformer::CreateForwardGroup2() {
51 auto rank_id = g_device_manager->global_rank();
52 auto stage_id = g_device_manager->stage_id();
53 auto stage_num = g_device_manager->stage_num();
54
55 std::vector<int64_t> forward_rank_list;
56 forward_rank_list.push_back(rank_id);
57 if (stage_id < stage_num - 1) {
58 forward_rank_list.push_back(rank_id + per_stage_rank_num_);
59 } else {
60 forward_rank_list.push_back(rank_id + per_stage_rank_num_ * (0 - stage_id));
61 }
62
63 Group g;
64
65 if (g_device_manager->CreateGroup(forward_rank_list, &g) != SUCCESS) {
66 MS_LOG(EXCEPTION) << "Create forward communication group between all pipeline stages failed, the rank_list is: "
67 << forward_rank_list;
68 }
69
70 std::vector<int64_t> backward_rank_list;
71 if (stage_id == 0) {
72 backward_rank_list.push_back(rank_id + per_stage_rank_num_ * (stage_num - 1));
73 } else {
74 backward_rank_list.push_back(rank_id - per_stage_rank_num_);
75 }
76 backward_rank_list.push_back(rank_id);
77
78 Group g_back;
79 if (g_device_manager->CreateGroup(backward_rank_list, &g_back) != SUCCESS) {
80 MS_LOG(EXCEPTION) << "Create backward communication group between all pipeline stages failed, the rank_list is: "
81 << backward_rank_list;
82 }
83
84 group_.push_back(g.name());
85 group_.push_back(g_back.name());
86 }
HandleSegment(const ValuePtr & value,const FuncGraphPtr & graph)87 void HandleSegment(const ValuePtr &value, const FuncGraphPtr &graph) {
88 MS_EXCEPTION_IF_NULL(graph);
89 auto nodes = graph->nodes();
90 for (auto node : nodes) {
91 if (node->isa<CNode>()) {
92 auto cnode = node->cast<CNodePtr>();
93 MS_LOG(INFO) << "Handle Segment cnode: " << cnode->fullname_with_scope();
94 cnode->AddPrimalAttr(SEGMENT, value);
95 }
96 }
97 }
// Propagates stage/segment information from stage-annotated sub-graphs to
// their call sites, iterating to a fixed point: a caller graph whose stage is
// still unknown (-1) inherits the stage and segment of a callee that belongs
// to this stage. Afterwards validates that the number of distinct stages and
// segments actually used matches the configured stage_num and
// pipeline_segment_split_num, raising otherwise.
void FoldPipelineTransformer::Coloring() {
  auto need_coloring = true;
  std::set<int64_t> stage_set;
  std::set<int64_t> segment_set;
  if (!IsTraining(manager_)) {
    is_train_ = false;
  }
  while (need_coloring) {
    need_coloring = false;
    for (auto &fg : manager_->func_graphs()) {
      // In training mode the root graph is not colored through this pass.
      if (fg == root_ && is_train_) {
        continue;
      }
      auto value_nodes = fg->value_nodes();
      for (auto value_pair = value_nodes.cbegin(); value_pair != value_nodes.cend(); ++value_pair) {
        auto node = (*value_pair).first;
        // Only ValueNodes holding a FuncGraph (calls to sub-graphs) matter.
        if (!IsValueNode<FuncGraph>(node)) {
          continue;
        }
        auto graph = GetValueNode<FuncGraphPtr>(node);
        if (graph->stage() == -1) {
          continue;
        }
        (void)stage_set.insert(graph->stage());
        (void)segment_set.insert(graph->segment());
        auto node_users = manager_->node_users()[node];
        // Tag every CNode of the callee graph with its segment id.
        HandleSegment(MakeValue(graph->segment()), graph);
        for (auto &user_pair : node_users) {
          auto user_node = user_pair.first->cast<CNodePtr>();
          user_node->set_user_data<NodeStageInfo>(std::make_shared<NodeStageInfo>(graph->stage()));
          user_node->set_user_data<NodeSegmentInfo>(std::make_shared<NodeSegmentInfo>(graph->segment()));
          auto user_node_graph = user_node->func_graph();
          // Color the caller graph only when the callee belongs to this stage
          // and the caller has not been colored yet; this may enable further
          // propagation, so the outer loop runs again.
          if (graph->stage() == stage_ && user_node_graph->stage() == -1) {
            user_node_graph->set_stage(graph->stage());
            MS_LOG(INFO) << "Set_segment in Coloring" << graph->segment();
            user_node_graph->set_segment(graph->segment());
            need_coloring = true;
          }
        }
      }
    }
  }
  MS_EXCEPTION_IF_NULL(g_device_manager);
  auto stage_num = g_device_manager->stage_num();
  auto segment_num = ParallelContext::GetInstance()->pipeline_segment_split_num();
  if (SizeToLong(stage_set.size()) != stage_num) {
    MS_LOG(EXCEPTION) << "Stage num is " << stage_num << " is not equal to stage used: " << stage_set.size();
  }
  if (SizeToLong(segment_set.size()) != segment_num) {
    MS_LOG(EXCEPTION) << "Segment num is " << segment_num << " is not equal to segment used: " << segment_set.size();
  }
}
150
// Back-fills stage/segment tags on nodes from the stage/segment of their
// owning func graph, for every graph that already has both set. The root,
// main and shared-cell graphs are excluded; nodes already tagged are left
// untouched.
void FoldPipelineTransformer::ColorForNodes() {
  for (auto &fg : manager_->func_graphs()) {
    auto stage = fg->stage();
    auto segment = fg->segment();
    if (stage < 0) {
      continue;
    }
    if (segment < 0) {
      continue;
    }
    if (fg == root_ || fg == main_graph_ || fg == shared_cell_) {
      continue;
    }
    auto all_nodes = fg->nodes();
    for (auto node : all_nodes) {
      // NOTE(review): a node that already carries stage info is skipped
      // entirely, so a missing segment info is never back-filled for it —
      // confirm this asymmetry is intended.
      if (node->user_data<NodeStageInfo>() != nullptr) {
        continue;
      }
      node->set_user_data<NodeStageInfo>(std::make_shared<NodeStageInfo>(stage));
      if (node->user_data<NodeSegmentInfo>() != nullptr) {
        continue;
      }
      node->set_user_data<NodeSegmentInfo>(std::make_shared<NodeSegmentInfo>(segment));
    }
  }
}
177
BroadCastColoring()178 void FoldPipelineTransformer::BroadCastColoring() {
179 auto need_coloring = true;
180 while (need_coloring) {
181 need_coloring = false;
182 auto all_nodes = main_graph_->nodes();
183 auto node_users = manager_->node_users();
184 for (auto node = all_nodes.cbegin(); node != all_nodes.cend(); ++node) {
185 auto stage_info = (*node)->user_data<NodeStageInfo>();
186 auto segment_info = (*node)->user_data<NodeSegmentInfo>();
187 if (!(*node)->isa<CNode>() || stage_info == nullptr || stage_info->stage() == -1 ||
188 IsPrimitiveCNode(*node, prim::kPrimUpdateState)) {
189 continue;
190 }
191 auto stage = stage_info->stage();
192 auto segment = segment_info->segment();
193 for (auto &user_pair : node_users[*node]) {
194 auto user_node = user_pair.first->cast<CNodePtr>();
195 auto user_stage_info = user_node->user_data<NodeStageInfo>();
196 auto user_segment_info = user_node->user_data<NodeSegmentInfo>();
197 if (user_stage_info == nullptr) {
198 user_node->set_user_data<NodeStageInfo>(std::make_shared<NodeStageInfo>(stage));
199 user_node->set_user_data<NodeSegmentInfo>(std::make_shared<NodeSegmentInfo>(segment));
200 need_coloring = true;
201 continue;
202 }
203 auto user_node_stage = user_stage_info->stage();
204 auto user_node_segment = user_segment_info->segment();
205 if (stage > user_node_stage && segment == user_node_segment) {
206 if (IsValueNode<FuncGraph>(user_node->input(0))) {
207 MS_LOG(WARNING) << "The stage setting is incorrect. PreNode's stage: " << stage
208 << " is larger than NextNode's stage:" << user_node_stage;
209 }
210 user_node->set_user_data<NodeStageInfo>(std::make_shared<NodeStageInfo>(stage));
211 need_coloring = true;
212 }
213 if (segment > user_node_segment) {
214 user_node->set_user_data<NodeSegmentInfo>(std::make_shared<NodeSegmentInfo>(segment));
215 need_coloring = true;
216 }
217 }
218 }
219 }
220 ColorForNodes();
221 }
222
// Builds a Send op that ships `parameter` from node_stage to user_node_stage
// for micro batch `value` within `segment`. Tag bookkeeping: a send that folds
// over stage 0 (node_stage == 0, user_node_stage > 1, more than two stages)
// draws its sr_tag from the receive counter so it pairs with the matching
// Receive on the peer; other sends use the send counter. Returns the sent
// tensor's shape/dtype together with a Depend node that anchors the Send in
// the graph.
SendAttr FoldPipelineTransformer::InsertSend(const AnfNodePtr &parameter, int64_t user_node_stage, int64_t node_stage,
                                             const ValuePtr &value, int64_t segment) {
  auto dest_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_;
  int64_t send_tag;
  auto stage_num = g_device_manager->stage_num();
  // Fold case: take the tag from (and advance) the receive counter.
  if (node_stage == 0 && user_node_stage > 1 && stage_num > 2) {
    if (fold_recv_tag_map.find(dest_rank) != fold_recv_tag_map.end()) {
      send_tag = fold_recv_tag_map[dest_rank] + 1;
      fold_recv_tag_map[dest_rank] += 1;
    } else {
      send_tag = 0;
      fold_recv_tag_map[dest_rank] = 0;
    }
  } else {
    if (fold_send_tag_map.find(dest_rank) != fold_send_tag_map.end()) {
      send_tag = fold_send_tag_map[dest_rank] + 1;
      fold_send_tag_map[dest_rank] += 1;
    } else {
      send_tag = 0;
      fold_send_tag_map[dest_rank] = 0;
    }
  }
  Attr attr_tag = std::make_pair(SR_TAG, MakeValue(send_tag));
  Attr attr_rank = std::make_pair(DEST_RANK, MakeValue(user_node_stage));
  Attr attr_group = std::make_pair(GROUP, MakeValue(group_[0]));
  Attr attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[1]));
  // With more than two stages the groups are pairwise (see
  // CreateForwardGroup2), so DEST_RANK becomes a rank index inside the
  // two-rank forward group instead of a stage id.
  if (stage_num > 2) {
    auto next = (user_node_stage == 0) ? 0 : 1;
    attr_rank = std::make_pair(DEST_RANK, MakeValue(next));
    attr_group = std::make_pair(GROUP, MakeValue(group_[0]));
    attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[0]));
  }

  // Fold case again: route the transfer over the backward group instead.
  if (node_stage == 0 && user_node_stage > 1 && stage_num > 2) {
    attr_group = std::make_pair(GROUP, MakeValue(group_[1]));
    attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[1]));
    attr_rank = std::make_pair(DEST_RANK, MakeValue(1));
  }
  auto graph = enable_share_cell_ ? shared_cell_ : main_graph_;
  std::vector<AnfNodePtr> send_input = {parameter};
  OperatorAttrs attrs = {attr_tag, attr_rank, attr_group, attr_group_back};
  CNodePtr send = CreateCNodeByInputsAndAttr(graph, SEND, SEND, send_input, attrs);
  auto prim = GetCNodePrimitive(send);
  AnfNodePtr care_node;
  bool is_param = true;
  auto op_info_pair = GetOpInfoPair(parameter, parameter, &care_node, &is_param);
  auto tensor_info = GetTensorInfo(op_info_pair, is_param);

  auto index = op_info_pair.second;
  auto op_info = op_info_pair.first;
  // Attach the sliced shape/dtype so the peer Receive can reconstruct it.
  auto slice_shape = tensor_info.slice_shape();
  auto shape_type_pair = GetShapeType(parameter, slice_shape, 0);
  prim->set_attr(SHAPE, shape_type_pair.first);
  prim->set_attr(DTYPE, shape_type_pair.second);
  if (!is_param) {
    send->AddPrimalAttr(PIPELINE_END, value);
  } else {
    send->AddPrimalAttr(PIPELINE_PARAM, value);
    send->set_user_data<OperatorInfo>(op_info);
    send->AddPrimalAttr(PARAM_INDEX, MakeValue(index));
    auto param = care_node ? care_node : parameter;
    send->set_user_data<AnfNode>(INPUT_PARAM, param);
  }
  send->AddPrimalAttr(MICRO, value);
  send->AddPrimalAttr(SEGMENT, MakeValue(segment));
  MS_LOG(INFO) << "Insert Send op, segment is " << segment;
  send->AddPrimalAttr(DEST_RANK, MakeValue(user_node_stage));
  // Depend keeps the Send alive while forwarding `parameter`'s value.
  OperatorAttrs depend_attrs;
  CNodePtr depend = CreateCNodeByInputsAndAttr(graph, DEPEND, DEPEND, AnfNodePtrList{parameter, send}, depend_attrs);
  auto abstract = parameter->abstract();
  if (care_node) {
    abstract = care_node->abstract();
  }
  depend->set_abstract(abstract);
  send->set_abstract(abstract);
  SendAttr send_out = {shape_type_pair.first, shape_type_pair.second, depend};

  send->set_user_data<int64_t>(DEST_RANK, std::make_shared<int64_t>(dest_rank));
  send->set_user_data<int64_t>(USER_NODE_STAGE, std::make_shared<int64_t>(user_node_stage));
  return send_out;
}
304
ComputeRecvTag(int64_t node_stage,int64_t user_node_stage,int64_t stage_num,int64_t src_rank)305 int64_t FoldPipelineTransformer::ComputeRecvTag(int64_t node_stage, int64_t user_node_stage, int64_t stage_num,
306 int64_t src_rank) {
307 int64_t recv_tag;
308 if (node_stage == 0 && user_node_stage > 1 && stage_num > 2) {
309 if (fold_send_tag_map.find(src_rank) != fold_send_tag_map.end()) {
310 recv_tag = fold_send_tag_map[src_rank] + 1;
311 fold_send_tag_map[src_rank] += 1;
312 } else {
313 recv_tag = 0;
314 fold_send_tag_map[src_rank] = 0;
315 }
316 } else {
317 if (fold_recv_tag_map.find(src_rank) != fold_recv_tag_map.end()) {
318 recv_tag = fold_recv_tag_map[src_rank] + 1;
319 fold_recv_tag_map[src_rank] += 1;
320 } else {
321 recv_tag = 0;
322 fold_recv_tag_map[src_rank] = 0;
323 }
324 }
325 return recv_tag;
326 }
327
// Builds a Receive op on this stage for data produced on node_stage and
// consumed by `use_node` at input `index`, then rewires that edge to the new
// Receive. Tag selection mirrors InsertSend (see ComputeRecvTag). For
// parameters it also shrinks the abstract to the sliced shape and attaches the
// tensor layout to the parameter and its actual (root) parameter. Returns the
// Receive node.
AnfNodePtr FoldPipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                                  const AnfNodePtr &use_node, int index, int64_t user_node_stage,
                                                  int64_t node_stage, const ValuePtr &value,
                                                  const AnfNodePtr &graph_param, int64_t segment) {
  auto src_rank = global_rank_ - (user_node_stage - node_stage) * per_stage_rank_num_;
  auto stage_num = g_device_manager->stage_num();
  auto recv_tag = ComputeRecvTag(node_stage, user_node_stage, stage_num, src_rank);
  Attr attr_tag = std::make_pair(SR_TAG, MakeValue(recv_tag));
  Attr attr_rank = std::make_pair(SRC_RANK, MakeValue(node_stage));
  Attr attr_group = std::make_pair(GROUP, MakeValue(group_[0]));
  Attr attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[1]));

  // With more than two stages the groups are pairwise (see
  // CreateForwardGroup2), so SRC_RANK becomes a rank index inside the
  // two-rank backward group instead of a stage id.
  if (stage_num > 2) {
    auto next = (user_node_stage == 0) ? 1 : 0;
    attr_rank = std::make_pair(SRC_RANK, MakeValue(next));
    attr_group = std::make_pair(GROUP, MakeValue(group_[1]));
    attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[1]));
  }
  bool is_param = true;
  AnfNodePtr care_node;
  auto op_info_pair = GetOpInfoPair(node, graph_param, &care_node, &is_param);
  auto tensor_info = GetTensorInfo(op_info_pair, is_param);
  auto tensor_layout = tensor_info.tensor_layout();
  Shape slice_shape = tensor_info.slice_shape();
  auto shape_type_pair = GetShapeType(node, slice_shape, 0);
  Attr attr_shape = std::make_pair(SHAPE, shape_type_pair.first);
  Attr attr_dtype = std::make_pair(DTYPE, shape_type_pair.second);
  // Fold case (stage 0 -> stage > 1): receive over the forward group, which
  // matches InsertSend routing the fold send over the backward group.
  if (node_stage == 0 && user_node_stage > 1 && stage_num > 2) {
    attr_group = std::make_pair(GROUP, MakeValue(group_[0]));
    attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[0]));
    attr_rank = std::make_pair(SRC_RANK, MakeValue(0));
  }
  OperatorAttrs attrs = {attr_tag, attr_rank, attr_shape, attr_dtype, attr_group, attr_group_back};
  std::vector<AnfNodePtr> recv_input;
  if (node->isa<Parameter>()) {
    recv_input = {node};
  } else {
    // Non-parameter producers live on another stage; feed the Receive a
    // placeholder virtual parameter instead.
    recv_input = {virtual_param_};
  }
  auto recv = CreateCNodeByInputsAndAttr(graph, RECEIVE, RECEIVE, recv_input, attrs);
  if (is_param) {
    recv->set_user_data<AnfNode>(PIPELINE_PARAM, node);
    recv->AddPrimalAttr(PIPELINE_PARAM, value);
    auto param = care_node ? care_node : node;
    recv->set_user_data<AnfNode>(INPUT_PARAM, param);
  } else {
    recv->AddPrimalAttr(PIPELINE_BEGIN, value);
  }
  recv->AddPrimalAttr(MICRO, value);
  recv->AddPrimalAttr(SRC_RANK, MakeValue(node_stage));
  recv->AddPrimalAttr(SEGMENT, MakeValue(segment));
  MS_LOG(INFO) << "Insertreceive segment" << segment;
  // When `node` is a call to a sub-graph, use the callee's output abstract.
  auto node_abstract = node->abstract();
  if (node->isa<CNode>()) {
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    if (IsValueNode<FuncGraph>(cnode->input(0))) {
      auto output = GetValueNode<FuncGraphPtr>(cnode->input(0))->output();
      MS_EXCEPTION_IF_NULL(output);
      node_abstract = output->abstract();
    }
  }
  MS_EXCEPTION_IF_NULL(node_abstract);
  recv->set_abstract(node_abstract);
  // Parameters: record the sliced shape and layout on the parameter itself
  // and on its actual (root) parameter so later passes see consistent shapes.
  if (node->isa<Parameter>()) {
    BaseShapePtr parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
    auto abstract_clone = node->abstract()->Clone();
    MS_EXCEPTION_IF_NULL(abstract_clone);
    abstract_clone->set_shape(parallel_shape);
    node->set_abstract(abstract_clone);
    node->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
    auto actual_param = RefParameterToActualParameter(node);
    if (actual_param) {
      actual_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
      auto actual_param_abstract = actual_param->abstract()->Clone();
      actual_param_abstract->set_shape(parallel_shape);
      actual_param->set_abstract(actual_param_abstract);
    }
  }
  recv->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
  recv->set_user_data<OperatorInfo>(op_info_pair.first);

  recv->set_user_data<int64_t>(SRC_RANK, std::make_shared<int64_t>(src_rank));
  recv->set_user_data<int64_t>(NODE_STAGE, std::make_shared<int64_t>(node_stage));
  recv->set_user_data<Type>(SLICE_DTYPE, shape_type_pair.second);
  recv->set_user_data<Shape>(SLICE_SHAPE, std::make_shared<Shape>(slice_shape));

  // Redirect the consumer to read from the Receive.
  manager_->SetEdge(use_node, index, recv);
  return recv;
}
418
// Searches the already-inserted communication ops (`out_input`, paired
// element-wise with `out_input_segment`) for one that transfers `node`
// to/from `stage` within `node_segment`; returns the reusable op, or nullptr
// if none matches. `tag` selects the primitive attr to compare: DEST_RANK for
// sends, SRC_RANK for receives.
AnfNodePtr FoldPipelineTransformer::Reuse(const AnfNodePtr &node, int64_t stage, int64_t node_segment,
                                          const std::vector<AnfNodePtr> &out_input,
                                          const std::vector<int64_t> &out_input_segment, const std::string &tag) {
  // The two vectors are filled in lockstep by CutBorderForNode; iterate over
  // the shorter length defensively. The original zipped them with
  // std::transform, which reads past the end of the segment vector if the
  // sizes ever diverge.
  const size_t zipped_size = std::min(out_input.size(), out_input_segment.size());
  for (size_t i = 0; i < zipped_size; ++i) {
    const auto &input = out_input[i];
    const auto send_segment = out_input_segment[i];
    auto cnode = input->cast<CNodePtr>();
    if (!cnode) {
      continue;
    }
    if (IsPrimitiveCNode(cnode, prim::kPrimDepend)) {
      cnode = cnode->input(DEPEND_NODE_SOURCE_INDEX)->cast<CNodePtr>();
      // Fix: the Depend source is not necessarily a CNode; the original code
      // dereferenced a null pointer here in that case.
      if (!cnode) {
        continue;
      }
    }
    if (cnode->input(1) == node) {
      auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
      MS_EXCEPTION_IF_NULL(prim);
      auto rank_attr = prim->GetAttr(tag);
      MS_EXCEPTION_IF_NULL(rank_attr);
      auto dest_rank_send = GetValue<int64_t>(rank_attr);
      if (dest_rank_send == stage && node_segment == send_segment) {
        return input;
      }
    }
  }
  return nullptr;
}
446
// Routes a cross-stage use of a parameter-like sub-graph argument: resolves
// the real argument behind the call, then either inserts/reuses a Receive on
// the consuming stage (rewiring input `pos` of `use_node`) or inserts/reuses
// a Send on the producing stage. Returns the new Send's Depend node, or
// nullptr when an existing op was reused (the Receive case rewires in place).
AnfNodePtr FoldPipelineTransformer::HandleParameterGraph(const AnfNodePtr &node, const AnfNodePtr &use_node,
                                                         int64_t stage, int64_t user_stage, const ValuePtr &micro,
                                                         size_t pos, const std::vector<AnfNodePtr> &ops) {
  CNodePtr call_node = nullptr;
  auto argument = GetRealKernelNode(node, -1, &call_node).first;

  auto use_cnode = use_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(use_cnode);
  if (!IsValueNode<FuncGraph>(use_cnode->input(0))) {
    MS_LOG(EXCEPTION) << "Parameter must be used by a graph, but got: " << use_cnode->DebugString();
  }
  auto use_graph = GetValueNode<FuncGraphPtr>(use_cnode->input(0));
  auto use_parameter_list = use_graph->parameters();
  // Call input `pos` corresponds to formal parameter `pos - 1` (input 0 is
  // the graph ValueNode itself).
  auto parameter = use_parameter_list.at(pos - 1);

  // insert receive
  if (stage_ == user_stage) {
    auto recv = PipelineTransformer::Reuse(argument, stage, ops, SRC_RANK);
    if (recv) {
      manager_->SetEdge(use_node, SizeToInt(pos), recv);
      return nullptr;
    }
    auto root_param = argument;
    if (argument->isa<Parameter>() && argument->func_graph() != root_) {
      root_param = GetArgumentsByParameter(argument);
    }
    // Record that this stage also consumes the shared parameter.
    (void)parameter_color_map_[root_param].insert(user_stage);
    auto graph = enable_share_cell_ ? shared_cell_ : main_graph_;
    return InsertReceive(graph, argument, use_node, SizeToInt(pos), user_stage, stage, micro, parameter, 0);
  }
  // insert send
  if (PipelineTransformer::Reuse(argument, user_stage, ops, DEST_RANK)) {
    return nullptr;
  }
  auto send_out = InsertSend(argument, user_stage, stage_, micro, 0);
  send_out.depend->set_user_data<Type>(DTYPE, send_out.type);
  send_out.depend->set_user_data<ValueList>(SHAPE, send_out.shape);
  return send_out.depend;
}
486
// Returns true when a (producer -> consumer) edge crosses a pipeline stage
// border and therefore needs Send/Receive ops: either the edge is an
// embedding-share edge (isEmbed), flows to a later stage inside one segment,
// or wraps from the last stage back to stage 0 into a later segment.
bool IsStageConflict(int64_t node_stage, int64_t user_node_stage, int64_t node_segment, int64_t user_node_segment,
                     int64_t stage_num, bool isEmbed) {
  if (isEmbed) {
    return true;
  }
  const bool forward_in_segment = (node_stage < user_node_stage) && (node_segment == user_node_segment);
  const bool fold_to_next_segment =
    (node_stage == stage_num - 1) && (user_node_stage == 0) && (node_segment < user_node_segment);
  return forward_in_segment || fold_to_next_segment;
}
495
// For every user of `node`, decides whether the producer/consumer pair spans
// a stage border (IsStageConflict) and inserts the matching Send (when this
// rank owns the producer stage) or Receive (when it owns the consumer stage),
// reusing existing ops where possible. `send_ops` and `send_ops_segment` are
// kept in lockstep so Reuse can match a Send with its segment. A Receive is
// created once per node; later users are rewired to the same one.
void FoldPipelineTransformer::CutBorderForNode(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                               std::vector<AnfNodePtr> *send_ops,
                                               std::vector<int64_t> *send_ops_segment,
                                               std::vector<AnfNodePtr> *receive_ops) {
  auto stage_info = node->user_data<NodeStageInfo>();
  auto segment_info = node->user_data<NodeSegmentInfo>();
  auto node_users = manager_->node_users()[node];
  AnfNodePtr receive = nullptr;
  for (auto &user_pair : node_users) {
    auto user_node = user_pair.first;
    auto node_stage = stage_info->stage();
    auto node_segment = segment_info->segment();
    auto user_stage_info = user_node->user_data<NodeStageInfo>();
    if (user_stage_info == nullptr) {
      continue;
    }
    auto user_segment_info = user_node->user_data<NodeSegmentInfo>();
    if (user_segment_info == nullptr) {
      continue;
    }
    auto user_node_stage = user_stage_info->stage();
    // Only edges that touch the current stage need communication ops here.
    if (node_stage != stage_ && user_node_stage != stage_) {
      continue;
    }
    auto micro = user_node->cast<CNodePtr>()->GetPrimalAttr(MICRO);
    auto user_node_segment = user_segment_info->segment();
    if (!micro) {
      MS_LOG(INFO) << "Can't find micro_batch information, use micro(0)";
      micro = MakeValue(int64_t(0));
    }
    auto stage_num = g_device_manager->stage_num();

    // A forward edge that crosses segments is treated as an embedding-share
    // edge and routed through HandleParameterGraph below.
    bool isEmbed = node_stage < user_node_stage && node_segment != user_node_segment;
    if (IsStageConflict(node_stage, user_node_stage, node_segment, user_node_segment, stage_num, isEmbed)) {
      if (node_stage == stage_) {
        // Producer side: insert (or reuse) a Send.
        if (IsParameterGraph(node) && isEmbed) {
          auto send_depend = HandleParameterGraph(node, user_node, node_stage, user_node_stage, micro,
                                                  IntToSize(user_pair.second), *send_ops);
          if (!send_depend) {
            continue;
          }
          // Prepend so both vectors stay aligned element-wise.
          (void)send_ops->insert(send_ops->cbegin(), send_depend);
          (void)send_ops_segment->insert(send_ops_segment->begin(), node_segment);
          continue;
        }
        if (Reuse(node, user_node_stage, user_node_segment, *send_ops, *send_ops_segment, DEST_RANK)) {
          continue;
        }
        auto send_out = InsertSend(node, user_node_stage, node_stage, micro, node_segment);
        MS_EXCEPTION_IF_NULL(send_out.depend);
        send_ops->push_back(send_out.depend);
        send_ops_segment->push_back(node_segment);
        send_out.depend->set_user_data<Type>(DTYPE, send_out.type);
        send_out.depend->set_user_data<ValueList>(SHAPE, send_out.shape);
      } else {
        // Consumer side: create one Receive, rewire later users to it.
        if (!receive) {
          if (IsParameterGraph(node)) {
            receive = HandleParameterGraph(node, user_node, node_stage, user_node_stage, micro,
                                           IntToSize(user_pair.second), *receive_ops);
            if (!receive) {
              continue;
            }
            receive_ops->push_back(receive);
          } else {
            receive = InsertReceive(graph, node, user_node, user_pair.second, user_node_stage, node_stage, micro, node,
                                    user_node_segment);
            receive_ops->push_back(receive);
          }
        } else {
          manager_->SetEdge(user_node, user_pair.second, receive);
        }
      }
      continue;
    }
    // Within one segment, data must flow from earlier to later stages.
    if (node_stage > user_node_stage && node_segment == user_node_segment) {
      MS_LOG(EXCEPTION) << "Within a segment, node_stage: " << node_stage
                        << " must be smaller than user_node_stage: " << user_node_stage;
    }
  }
}
576
CutBorder(const FuncGraphPtr & graph)577 std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> FoldPipelineTransformer::CutBorder(
578 const FuncGraphPtr &graph) {
579 std::vector<AnfNodePtr> send_ops;
580 std::vector<int64_t> send_ops_segment;
581 std::vector<AnfNodePtr> receive_ops;
582 auto ret = graph->get_return();
583 MS_EXCEPTION_IF_NULL(ret);
584 std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
585 std::reverse(all_nodes.begin(), all_nodes.end());
586 auto stage_num = g_device_manager->stage_num();
587 if (is_train_ && (stage_num > micro_size_)) {
588 MS_LOG(EXCEPTION) << "MicroBatch size: " << micro_size_ << " can't less than stage num: " << stage_num;
589 }
590 for (auto &node : all_nodes) {
591 auto stage_info = node->user_data<NodeStageInfo>();
592 if (!node->isa<CNode>() || stage_info == nullptr || stage_info->stage() == -1 ||
593 IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
594 continue;
595 }
596 CutBorderForNode(graph, node, &send_ops, &send_ops_segment, &receive_ops);
597 }
598 RemoveMonadNode();
599 return std::make_pair(send_ops, receive_ops);
600 }
601
// Handles parameters that are used by more than one stage: the stage that
// owns the parameter (the first stage in its color set) inserts Sends toward
// each foreign consuming stage, while consuming stages insert Receives and
// rewire their uses. Existing Send/Receive ops are reused where possible.
// Returns the inserted (sends, receives).
std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> FoldPipelineTransformer::HandleSharedParameter() {
  auto parameters = root_->parameters();
  std::vector<AnfNodePtr> sends = {};
  std::vector<AnfNodePtr> recvs = {};
  for (auto &parameter : parameters) {
    auto parameter_stage = parameter_color_map_[parameter];
    // Parameters used by at most one stage need no communication.
    if (parameter_stage.size() <= 1) {
      continue;
    }
    const auto &node_users_map = manager_->node_users();
    auto users = GetParameterLoadUsers(parameter, node_users_map);
    for (auto &user : users) {
      auto node = user.first;
      auto cnode = node->cast<CNodePtr>();
      auto graph = node->func_graph();
      // For calls to sub-graphs, the relevant stage is the callee's.
      if (IsValueNode<FuncGraph>(cnode->input(0))) {
        graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
      }
      if (graph == root_ || graph->stage() == -1 || parameter_stage.count(stage_) == 0) {
        continue;
      }
      auto micro = cnode->GetPrimalAttr(MICRO);
      if (!micro) {
        MS_LOG(INFO) << "Parameter: " << parameter->ToString() << " doesn't have micro batch";
        micro = MakeValue(int64_t(0));
      }
      if (stage_ == *parameter_stage.begin()) {
        // Owning stage: send the parameter to the consuming stage.
        auto user_stage = graph->stage();
        auto stage_info = node->user_data<NodeStageInfo>();
        if (stage_info) {
          user_stage = stage_info->stage();
        }
        if (graph->stage() == stage_ || user_stage == -1) {
          continue;
        }
        if (PipelineTransformer::Reuse(parameter, user_stage, sends, DEST_RANK)) {
          continue;
        }
        auto send_out = InsertSend(parameter, user_stage, stage_, micro, 0);
        sends.push_back(send_out.depend);
      } else {
        // Consuming stage: receive from the owning stage and rewire the use.
        auto receive = PipelineTransformer::Reuse(parameter, *parameter_stage.begin(), recvs, SRC_RANK);
        if (receive) {
          manager_->SetEdge(node, user.second, receive);
        } else {
          AnfNodePtr recv;
          auto fg = enable_share_cell_ ? shared_cell_ : main_graph_;
          recv = InsertReceive(fg, parameter, node, user.second, stage_, *parameter_stage.begin(), micro, parameter, 0);
          (void)(recvs.push_back(recv));
        }
      }
    }
  }
  return std::make_pair(sends, recvs);
}
657
// Entry point of fold-pipeline graph cutting: creates the pairwise
// communication groups, inserts Send/Receive ops for shared parameters and
// stage-border edges, then anchors the inserted Sends into the graph output
// (and, in share-cell mode, wires Sends/Receives through the graph
// outputs/inputs) so they are not dead-code eliminated.
void FoldPipelineTransformer::CutGraph() {
  CreateForwardGroup2();
  MS_EXCEPTION_IF_NULL(main_graph_);
  auto send_recv_shared_param = HandleSharedParameter();
  auto graph = enable_share_cell_ ? shared_cell_ : main_graph_;
  MS_EXCEPTION_IF_NULL(graph);
  auto send_recv_cut_border = CutBorder(graph);
  std::vector<AnfNodePtr> send_ops;
  (void)(send_ops.insert(send_ops.end(), send_recv_shared_param.first.begin(), send_recv_shared_param.first.end()));
  (void)(send_ops.insert(send_ops.end(), send_recv_cut_border.first.begin(), send_recv_cut_border.first.end()));
  if (IsLastStage() && !enable_share_cell_) {
    // Last stage keeps its real output; a Depend on the MakeTuple of Sends
    // keeps them alive without changing the output value.
    auto out_node = main_graph_->output();

    auto make_tuple = CreateMakeTupleNode(main_graph_, send_ops);

    std::vector<AnfNodePtr> tuple_out_depend = {NewValueNode(prim::kPrimDepend)};
    tuple_out_depend.push_back(out_node);
    tuple_out_depend.push_back(make_tuple);

    auto tuple_out_depend_node = main_graph_->NewCNode(tuple_out_depend);
    tuple_out_depend_node->set_abstract(out_node->abstract());
    (void)manager_->Replace(main_graph_->output(), tuple_out_depend_node);
    return;
  }
  if (send_ops.empty() && !is_train_) {
    return;
  }
  if (!send_ops.empty()) {
    // Remember dtype/shape of the last send for later output handling.
    type_ptr_ = send_ops.back()->user_data<Type>(DTYPE);
    shape_ = send_ops.back()->user_data<ValueList>(SHAPE);
  }
  if (!enable_share_cell_) {
    // Non-last stage: output becomes zeros, with Sends attached via Depend.
    auto make_tuple = CreateMakeTupleNode(main_graph_, send_ops);
    auto zero_outputs = GetZeroOutputs(main_graph_);
    std::vector<AnfNodePtr> out = {NewValueNode(prim::kPrimDepend), zero_outputs, make_tuple};
    auto out_node = main_graph_->NewCNode(out);
    (void)manager_->Replace(main_graph_->output(), out_node);
    return;
  }
  // NOTE(review): the sr_tag counters are reset only on the share-cell path;
  // the early-return paths above leave them populated — confirm intended.
  fold_send_tag_map.clear();
  fold_recv_tag_map.clear();
  if (!IsLastStage()) {
    HandleGraphOutputs(send_ops);
  }
  std::vector<AnfNodePtr> recv_ops;
  (void)(recv_ops.insert(recv_ops.end(), send_recv_shared_param.second.begin(), send_recv_shared_param.second.end()));
  (void)(recv_ops.insert(recv_ops.end(), send_recv_cut_border.second.begin(), send_recv_cut_border.second.end()));
  HandleGraphInputs(recv_ops);
}
707
708 } // namespace parallel
709 } // namespace mindspore
710