• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/pipeline_transformer/fold_pipeline_transformer.h"
18 #include <set>
19 #include <vector>
20 #include <string>
21 #include <utility>
22 #include <algorithm>
23 #include <memory>
24 #include "frontend/parallel/pipeline_transformer/pipeline_transformer.h"
25 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
26 #include "frontend/parallel/graph_util/graph_splitter.h"
27 #include "frontend/parallel/ops_info/ops_utils.h"
28 #include "frontend/parallel/group_manager.h"
29 #include "frontend/parallel/parameter_manager.h"
30 #include "include/common/utils/parallel_context.h"
31 #include "frontend/parallel/step_parallel.h"
32 #include "frontend/parallel/node_check.h"
33 #include "frontend/parallel/graph_util/node_info.h"
34 #include "frontend/parallel/graph_util/pipeline_split_utils.h"
35 #include "frontend/parallel/step_parallel_utils.h"
36 #include "ir/anf.h"
37 #include "ir/graph_utils.h"
38 #include "ops/other_ops.h"
39 #include "ops/array_ops.h"
40 #include "ops/framework_ops.h"
41 #include "include/common/utils/comm_manager.h"
42 #include "utils/ms_context.h"
43 #include "utils/parallel_node_check.h"
44 
namespace mindspore {
namespace parallel {
// Global sr_tag bookkeeping for fold-pipeline Send/Receive pairing.
// Key: peer global rank (dest rank on the send side, src rank on the recv
// side); value: the last tag issued for that peer. Send and Receive sides
// must generate identical tag sequences so matching ops agree on SR_TAG
// (see InsertSend / ComputeRecvTag). Cleared in CutGraph for share-cell mode.
mindspore::HashMap<int64_t, int64_t> fold_send_tag_map;
mindspore::HashMap<int64_t, int64_t> fold_recv_tag_map;
49 
CreateForwardGroup2()50 void FoldPipelineTransformer::CreateForwardGroup2() {
51   auto rank_id = g_device_manager->global_rank();
52   auto stage_id = g_device_manager->stage_id();
53   auto stage_num = g_device_manager->stage_num();
54 
55   std::vector<int64_t> forward_rank_list;
56   forward_rank_list.push_back(rank_id);
57   if (stage_id < stage_num - 1) {
58     forward_rank_list.push_back(rank_id + per_stage_rank_num_);
59   } else {
60     forward_rank_list.push_back(rank_id + per_stage_rank_num_ * (0 - stage_id));
61   }
62 
63   Group g;
64 
65   if (g_device_manager->CreateGroup(forward_rank_list, &g) != SUCCESS) {
66     MS_LOG(EXCEPTION) << "Create forward communication group between all pipeline stages failed, the rank_list is: "
67                       << forward_rank_list;
68   }
69 
70   std::vector<int64_t> backward_rank_list;
71   if (stage_id == 0) {
72     backward_rank_list.push_back(rank_id + per_stage_rank_num_ * (stage_num - 1));
73   } else {
74     backward_rank_list.push_back(rank_id - per_stage_rank_num_);
75   }
76   backward_rank_list.push_back(rank_id);
77 
78   Group g_back;
79   if (g_device_manager->CreateGroup(backward_rank_list, &g_back) != SUCCESS) {
80     MS_LOG(EXCEPTION) << "Create backward communication group between all pipeline stages failed, the rank_list is: "
81                       << backward_rank_list;
82   }
83 
84   group_.push_back(g.name());
85   group_.push_back(g_back.name());
86 }
HandleSegment(const ValuePtr & value,const FuncGraphPtr & graph)87 void HandleSegment(const ValuePtr &value, const FuncGraphPtr &graph) {
88   MS_EXCEPTION_IF_NULL(graph);
89   auto nodes = graph->nodes();
90   for (auto node : nodes) {
91     if (node->isa<CNode>()) {
92       auto cnode = node->cast<CNodePtr>();
93       MS_LOG(INFO) << "Handle Segment cnode: " << cnode->fullname_with_scope();
94       cnode->AddPrimalAttr(SEGMENT, value);
95     }
96   }
97 }
// Assigns stage/segment ids ("colors") to the call sites of partitioned
// sub-graphs and propagates them to caller graphs until a fixed point is
// reached. Afterwards it validates that the number of distinct stages and
// segments actually used matches the configured pipeline split.
void FoldPipelineTransformer::Coloring() {
  auto need_coloring = true;
  std::set<int64_t> stage_set;    // distinct stage ids seen on sub-graphs
  std::set<int64_t> segment_set;  // distinct segment ids seen on sub-graphs
  if (!IsTraining(manager_)) {
    is_train_ = false;
  }
  // Fixed-point loop: repeat until one full pass changes no caller graph.
  while (need_coloring) {
    need_coloring = false;
    for (auto &fg : manager_->func_graphs()) {
      // In training the root graph keeps its own coloring.
      if (fg == root_ && is_train_) {
        continue;
      }
      auto value_nodes = fg->value_nodes();
      for (auto value_pair = value_nodes.cbegin(); value_pair != value_nodes.cend(); ++value_pair) {
        auto node = (*value_pair).first;
        // Only FuncGraph-valued nodes (calls to partitioned cells) matter.
        if (!IsValueNode<FuncGraph>(node)) {
          continue;
        }
        auto graph = GetValueNode<FuncGraphPtr>(node);
        if (graph->stage() == -1) {
          continue;  // graph not assigned to any pipeline stage
        }
        (void)stage_set.insert(graph->stage());
        (void)segment_set.insert(graph->segment());
        auto node_users = manager_->node_users()[node];
        // Stamp SEGMENT on every CNode inside the called graph.
        HandleSegment(MakeValue(graph->segment()), graph);
        for (auto &user_pair : node_users) {
          auto user_node = user_pair.first->cast<CNodePtr>();
          // The call site inherits the callee's stage/segment.
          user_node->set_user_data<NodeStageInfo>(std::make_shared<NodeStageInfo>(graph->stage()));
          user_node->set_user_data<NodeSegmentInfo>(std::make_shared<NodeSegmentInfo>(graph->segment()));
          auto user_node_graph = user_node->func_graph();
          // Propagate upward: an uncolored caller graph adopts the callee's
          // ids when the callee belongs to this device's stage.
          if (graph->stage() == stage_ && user_node_graph->stage() == -1) {
            user_node_graph->set_stage(graph->stage());
            MS_LOG(INFO) << "Set_segment in Coloring" << graph->segment();
            user_node_graph->set_segment(graph->segment());
            need_coloring = true;  // another pass may now color more graphs
          }
        }
      }
    }
  }
  MS_EXCEPTION_IF_NULL(g_device_manager);
  auto stage_num = g_device_manager->stage_num();
  auto segment_num = ParallelContext::GetInstance()->pipeline_segment_split_num();
  // Every configured stage/segment must actually be used, and vice versa.
  if (SizeToLong(stage_set.size()) != stage_num) {
    MS_LOG(EXCEPTION) << "Stage num is " << stage_num << " is not equal to stage used: " << stage_set.size();
  }
  if (SizeToLong(segment_set.size()) != segment_num) {
    MS_LOG(EXCEPTION) << "Segment num is " << segment_num << " is not equal to segment used: " << segment_set.size();
  }
}
150 
ColorForNodes()151 void FoldPipelineTransformer::ColorForNodes() {
152   for (auto &fg : manager_->func_graphs()) {
153     auto stage = fg->stage();
154     auto segment = fg->segment();
155     if (stage < 0) {
156       continue;
157     }
158     if (segment < 0) {
159       continue;
160     }
161     if (fg == root_ || fg == main_graph_ || fg == shared_cell_) {
162       continue;
163     }
164     auto all_nodes = fg->nodes();
165     for (auto node : all_nodes) {
166       if (node->user_data<NodeStageInfo>() != nullptr) {
167         continue;
168       }
169       node->set_user_data<NodeStageInfo>(std::make_shared<NodeStageInfo>(stage));
170       if (node->user_data<NodeSegmentInfo>() != nullptr) {
171         continue;
172       }
173       node->set_user_data<NodeSegmentInfo>(std::make_shared<NodeSegmentInfo>(segment));
174     }
175   }
176 }
177 
// Fixed-point propagation of stage/segment coloring along data edges inside
// main_graph_: a consumer of a colored node inherits the producer's ids when
// it has none, and is promoted when the producer's id is larger. Terminates
// when a full pass changes nothing, then stamps any remaining nodes via
// ColorForNodes().
void FoldPipelineTransformer::BroadCastColoring() {
  auto need_coloring = true;
  while (need_coloring) {
    need_coloring = false;
    auto all_nodes = main_graph_->nodes();
    auto node_users = manager_->node_users();
    for (auto node = all_nodes.cbegin(); node != all_nodes.cend(); ++node) {
      auto stage_info = (*node)->user_data<NodeStageInfo>();
      auto segment_info = (*node)->user_data<NodeSegmentInfo>();
      // Only colored CNodes propagate; UpdateState is monad bookkeeping and
      // carries no activation data.
      if (!(*node)->isa<CNode>() || stage_info == nullptr || stage_info->stage() == -1 ||
          IsPrimitiveCNode(*node, prim::kPrimUpdateState)) {
        continue;
      }
      auto stage = stage_info->stage();
      // NOTE(review): segment_info is dereferenced here without a null check;
      // this assumes stage and segment infos are always set together (as
      // Coloring and the branch below do) — confirm no writer sets only one.
      auto segment = segment_info->segment();
      for (auto &user_pair : node_users[*node]) {
        auto user_node = user_pair.first->cast<CNodePtr>();
        auto user_stage_info = user_node->user_data<NodeStageInfo>();
        auto user_segment_info = user_node->user_data<NodeSegmentInfo>();
        // Uncolored consumer: inherit both ids from the producer.
        if (user_stage_info == nullptr) {
          user_node->set_user_data<NodeStageInfo>(std::make_shared<NodeStageInfo>(stage));
          user_node->set_user_data<NodeSegmentInfo>(std::make_shared<NodeSegmentInfo>(segment));
          need_coloring = true;
          continue;
        }
        auto user_node_stage = user_stage_info->stage();
        auto user_node_segment = user_segment_info->segment();
        // Within one segment, a consumer may not sit on an earlier stage than
        // its producer — promote it (warn when it is a real graph call).
        if (stage > user_node_stage && segment == user_node_segment) {
          if (IsValueNode<FuncGraph>(user_node->input(0))) {
            MS_LOG(WARNING) << "The stage setting is incorrect. PreNode's stage: " << stage
                            << " is larger than NextNode's stage:" << user_node_stage;
          }
          user_node->set_user_data<NodeStageInfo>(std::make_shared<NodeStageInfo>(stage));
          need_coloring = true;
        }
        // Segments are monotone along data flow as well.
        if (segment > user_node_segment) {
          user_node->set_user_data<NodeSegmentInfo>(std::make_shared<NodeSegmentInfo>(segment));
          need_coloring = true;
        }
      }
    }
  }
  ColorForNodes();
}
222 
// Builds a Send op that ships `parameter` (an activation or a shared
// parameter) from `node_stage` to `user_node_stage`; peer ranks differ by
// multiples of per_stage_rank_num_. `value` is the micro-batch id (recorded
// as MICRO and as PIPELINE_END / PIPELINE_PARAM); `segment` is recorded as
// the SEGMENT primal attr. Returns the slice shape/dtype plus a
// Depend(parameter, send) node that keeps the Send alive in the graph
// without changing data flow.
SendAttr FoldPipelineTransformer::InsertSend(const AnfNodePtr &parameter, int64_t user_node_stage, int64_t node_stage,
                                             const ValuePtr &value, int64_t segment) {
  auto dest_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_;
  int64_t send_tag;
  auto stage_num = g_device_manager->stage_num();
  // Fold edge (stage 0 -> stage > 1 with more than two stages): draw the tag
  // from the recv map so the sequence mirrors ComputeRecvTag on the peer.
  if (node_stage == 0 && user_node_stage > 1 && stage_num > 2) {
    if (fold_recv_tag_map.find(dest_rank) != fold_recv_tag_map.end()) {
      send_tag = fold_recv_tag_map[dest_rank] + 1;
      fold_recv_tag_map[dest_rank] += 1;
    } else {
      send_tag = 0;
      fold_recv_tag_map[dest_rank] = 0;
    }
  } else {
    // Ordinary edge: first send to a rank uses tag 0, then increments.
    if (fold_send_tag_map.find(dest_rank) != fold_send_tag_map.end()) {
      send_tag = fold_send_tag_map[dest_rank] + 1;
      fold_send_tag_map[dest_rank] += 1;
    } else {
      send_tag = 0;
      fold_send_tag_map[dest_rank] = 0;
    }
  }
  Attr attr_tag = std::make_pair(SR_TAG, MakeValue(send_tag));
  Attr attr_rank = std::make_pair(DEST_RANK, MakeValue(user_node_stage));
  Attr attr_group = std::make_pair(GROUP, MakeValue(group_[0]));
  Attr attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[1]));
  // With more than two stages the groups built by CreateForwardGroup2 hold
  // only two ranks, so DEST_RANK becomes a 0/1 index — presumably the peer's
  // position inside that two-rank group (TODO confirm against Send kernel).
  if (stage_num > 2) {
    auto next = (user_node_stage == 0) ? 0 : 1;
    attr_rank = std::make_pair(DEST_RANK, MakeValue(next));
    attr_group = std::make_pair(GROUP, MakeValue(group_[0]));
    attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[0]));
  }

  // Fold edges travel through the backward group instead.
  if (node_stage == 0 && user_node_stage > 1 && stage_num > 2) {
    attr_group = std::make_pair(GROUP, MakeValue(group_[1]));
    attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[1]));
    attr_rank = std::make_pair(DEST_RANK, MakeValue(1));
  }
  auto graph = enable_share_cell_ ? shared_cell_ : main_graph_;
  std::vector<AnfNodePtr> send_input = {parameter};
  OperatorAttrs attrs = {attr_tag, attr_rank, attr_group, attr_group_back};
  CNodePtr send = CreateCNodeByInputsAndAttr(graph, SEND, SEND, send_input, attrs);
  auto prim = GetCNodePrimitive(send);
  AnfNodePtr care_node;
  bool is_param = true;
  auto op_info_pair = GetOpInfoPair(parameter, parameter, &care_node, &is_param);
  auto tensor_info = GetTensorInfo(op_info_pair, is_param);

  auto index = op_info_pair.second;
  auto op_info = op_info_pair.first;
  auto slice_shape = tensor_info.slice_shape();
  auto shape_type_pair = GetShapeType(parameter, slice_shape, 0);
  prim->set_attr(SHAPE, shape_type_pair.first);
  prim->set_attr(DTYPE, shape_type_pair.second);
  // Distinguish shared-parameter sends from pipeline-boundary (activation)
  // sends so later passes can treat them differently.
  if (!is_param) {
    send->AddPrimalAttr(PIPELINE_END, value);
  } else {
    send->AddPrimalAttr(PIPELINE_PARAM, value);
    send->set_user_data<OperatorInfo>(op_info);
    send->AddPrimalAttr(PARAM_INDEX, MakeValue(index));
    auto param = care_node ? care_node : parameter;
    send->set_user_data<AnfNode>(INPUT_PARAM, param);
  }
  send->AddPrimalAttr(MICRO, value);
  send->AddPrimalAttr(SEGMENT, MakeValue(segment));
  MS_LOG(INFO) << "Insert Send op, segment is " << segment;
  send->AddPrimalAttr(DEST_RANK, MakeValue(user_node_stage));
  OperatorAttrs depend_attrs;
  // Depend forwards `parameter`'s value while keeping the Send reachable.
  CNodePtr depend = CreateCNodeByInputsAndAttr(graph, DEPEND, DEPEND, AnfNodePtrList{parameter, send}, depend_attrs);
  auto abstract = parameter->abstract();
  if (care_node) {
    abstract = care_node->abstract();
  }
  depend->set_abstract(abstract);
  send->set_abstract(abstract);
  SendAttr send_out = {shape_type_pair.first, shape_type_pair.second, depend};

  send->set_user_data<int64_t>(DEST_RANK, std::make_shared<int64_t>(dest_rank));
  send->set_user_data<int64_t>(USER_NODE_STAGE, std::make_shared<int64_t>(user_node_stage));
  return send_out;
}
304 
ComputeRecvTag(int64_t node_stage,int64_t user_node_stage,int64_t stage_num,int64_t src_rank)305 int64_t FoldPipelineTransformer::ComputeRecvTag(int64_t node_stage, int64_t user_node_stage, int64_t stage_num,
306                                                 int64_t src_rank) {
307   int64_t recv_tag;
308   if (node_stage == 0 && user_node_stage > 1 && stage_num > 2) {
309     if (fold_send_tag_map.find(src_rank) != fold_send_tag_map.end()) {
310       recv_tag = fold_send_tag_map[src_rank] + 1;
311       fold_send_tag_map[src_rank] += 1;
312     } else {
313       recv_tag = 0;
314       fold_send_tag_map[src_rank] = 0;
315     }
316   } else {
317     if (fold_recv_tag_map.find(src_rank) != fold_recv_tag_map.end()) {
318       recv_tag = fold_recv_tag_map[src_rank] + 1;
319       fold_recv_tag_map[src_rank] += 1;
320     } else {
321       recv_tag = 0;
322       fold_recv_tag_map[src_rank] = 0;
323     }
324   }
325   return recv_tag;
326 }
327 
// Builds a Receive op delivering `node`'s value from `node_stage` to
// `user_node_stage`, rewires `use_node`'s input `index` to it, and returns
// the Receive. `value` is the micro-batch id; `graph_param` supplies the
// operator info when `node` is a parameter passed into a sub-graph;
// `segment` is recorded as the SEGMENT primal attr.
AnfNodePtr FoldPipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                                  const AnfNodePtr &use_node, int index, int64_t user_node_stage,
                                                  int64_t node_stage, const ValuePtr &value,
                                                  const AnfNodePtr &graph_param, int64_t segment) {
  auto src_rank = global_rank_ - (user_node_stage - node_stage) * per_stage_rank_num_;
  auto stage_num = g_device_manager->stage_num();
  // Tag sequence must match the peer's InsertSend (see ComputeRecvTag).
  auto recv_tag = ComputeRecvTag(node_stage, user_node_stage, stage_num, src_rank);
  Attr attr_tag = std::make_pair(SR_TAG, MakeValue(recv_tag));
  Attr attr_rank = std::make_pair(SRC_RANK, MakeValue(node_stage));
  Attr attr_group = std::make_pair(GROUP, MakeValue(group_[0]));
  Attr attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[1]));

  // With more than two stages the groups are two-rank pairs, so SRC_RANK
  // becomes a 0/1 index inside the group (mirrors InsertSend's DEST_RANK).
  if (stage_num > 2) {
    auto next = (user_node_stage == 0) ? 1 : 0;
    attr_rank = std::make_pair(SRC_RANK, MakeValue(next));
    attr_group = std::make_pair(GROUP, MakeValue(group_[1]));
    attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[1]));
  }
  bool is_param = true;
  AnfNodePtr care_node;
  auto op_info_pair = GetOpInfoPair(node, graph_param, &care_node, &is_param);
  auto tensor_info = GetTensorInfo(op_info_pair, is_param);
  auto tensor_layout = tensor_info.tensor_layout();
  Shape slice_shape = tensor_info.slice_shape();
  auto shape_type_pair = GetShapeType(node, slice_shape, 0);
  Attr attr_shape = std::make_pair(SHAPE, shape_type_pair.first);
  Attr attr_dtype = std::make_pair(DTYPE, shape_type_pair.second);
  // Fold edge (stage 0 -> stage > 1): receive through the forward group.
  if (node_stage == 0 && user_node_stage > 1 && stage_num > 2) {
    attr_group = std::make_pair(GROUP, MakeValue(group_[0]));
    attr_group_back = std::make_pair(GROUP_BACK, MakeValue(group_[0]));
    attr_rank = std::make_pair(SRC_RANK, MakeValue(0));
  }
  OperatorAttrs attrs = {attr_tag, attr_rank, attr_shape, attr_dtype, attr_group, attr_group_back};
  std::vector<AnfNodePtr> recv_input;
  // A parameter feeds the Receive directly; otherwise a placeholder
  // (virtual_param_) is used since the real producer lives on another stage.
  if (node->isa<Parameter>()) {
    recv_input = {node};
  } else {
    recv_input = {virtual_param_};
  }
  auto recv = CreateCNodeByInputsAndAttr(graph, RECEIVE, RECEIVE, recv_input, attrs);
  if (is_param) {
    recv->set_user_data<AnfNode>(PIPELINE_PARAM, node);
    recv->AddPrimalAttr(PIPELINE_PARAM, value);
    auto param = care_node ? care_node : node;
    recv->set_user_data<AnfNode>(INPUT_PARAM, param);
  } else {
    recv->AddPrimalAttr(PIPELINE_BEGIN, value);
  }
  recv->AddPrimalAttr(MICRO, value);
  recv->AddPrimalAttr(SRC_RANK, MakeValue(node_stage));
  recv->AddPrimalAttr(SEGMENT, MakeValue(segment));
  MS_LOG(INFO) << "Insertreceive segment" << segment;
  // The Receive's abstract mimics the producer: for a sub-graph call, use
  // the called graph's output abstract.
  auto node_abstract = node->abstract();
  if (node->isa<CNode>()) {
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    if (IsValueNode<FuncGraph>(cnode->input(0))) {
      auto output = GetValueNode<FuncGraphPtr>(cnode->input(0))->output();
      MS_EXCEPTION_IF_NULL(output);
      node_abstract = output->abstract();
    }
  }
  MS_EXCEPTION_IF_NULL(node_abstract);
  recv->set_abstract(node_abstract);
  // For parameters, shrink the (possibly ref) parameter's abstract to the
  // sharded slice shape and attach the layout so later passes see the slice.
  if (node->isa<Parameter>()) {
    BaseShapePtr parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
    auto abstract_clone = node->abstract()->Clone();
    MS_EXCEPTION_IF_NULL(abstract_clone);
    abstract_clone->set_shape(parallel_shape);
    node->set_abstract(abstract_clone);
    node->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
    auto actual_param = RefParameterToActualParameter(node);
    if (actual_param) {
      actual_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
      auto actual_param_abstract = actual_param->abstract()->Clone();
      actual_param_abstract->set_shape(parallel_shape);
      actual_param->set_abstract(actual_param_abstract);
    }
  }
  recv->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
  recv->set_user_data<OperatorInfo>(op_info_pair.first);

  recv->set_user_data<int64_t>(SRC_RANK, std::make_shared<int64_t>(src_rank));
  recv->set_user_data<int64_t>(NODE_STAGE, std::make_shared<int64_t>(node_stage));
  recv->set_user_data<Type>(SLICE_DTYPE, shape_type_pair.second);
  recv->set_user_data<Shape>(SLICE_SHAPE, std::make_shared<Shape>(slice_shape));

  // Rewire the consumer to read from the Receive.
  manager_->SetEdge(use_node, index, recv);
  return recv;
}
418 
// Searches `out_input` (previously created Send/Depend or Receive nodes,
// zipped with their segment ids from `out_input_segment`) for one that
// already transfers `node` to/from `stage` within `node_segment`. `tag`
// selects which primitive attr holds the peer stage (DEST_RANK for sends,
// SRC_RANK for receives). Returns the matching node, or nullptr.
AnfNodePtr FoldPipelineTransformer::Reuse(const AnfNodePtr &node, int64_t stage, int64_t node_segment,
                                          const std::vector<AnfNodePtr> &out_input,
                                          const std::vector<int64_t> &out_input_segment, const std::string &tag) {
  // Zip the candidate ops with their segment ids (parallel vectors).
  std::vector<std::pair<AnfNodePtr, int64_t>> zipped;
  zipped.reserve(out_input.size());
  std::transform(out_input.begin(), out_input.end(), out_input_segment.begin(), std::back_inserter(zipped),
                 [](const auto &send, const auto &send_segment) { return std::make_pair(send, send_segment); });

  for (auto &zipp : zipped) {
    auto input = zipp.first;
    auto send_segment = zipp.second;
    auto cnode = input->cast<CNodePtr>();
    if (!cnode) {
      continue;
    }
    // Sends are wrapped as Depend(parameter, send); unwrap to the real op.
    if (IsPrimitiveCNode(cnode, prim::kPrimDepend)) {
      cnode = cnode->input(DEPEND_NODE_SOURCE_INDEX)->cast<CNodePtr>();
      // Fix: the unwrapped input may not be a CNode; previously it was
      // dereferenced below without this null check.
      if (!cnode) {
        continue;
      }
    }
    if (cnode->input(1) == node) {
      auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
      MS_EXCEPTION_IF_NULL(prim);
      auto dest_rank_send = GetValue<int64_t>(prim->GetAttr(tag));
      if (dest_rank_send == stage && node_segment == send_segment) {
        return input;
      }
    }
  }
  return nullptr;
}
446 
// Handles a stage-crossing edge whose producer is a parameter passed through
// a sub-graph call. On the consuming stage it inserts (or reuses) a Receive
// at argument position `pos` of `use_node`; on the producing stage it
// inserts (or reuses) a Send. Returns the Send's Depend node, or nullptr
// when an existing op was reused (or a Receive was wired in directly).
AnfNodePtr FoldPipelineTransformer::HandleParameterGraph(const AnfNodePtr &node, const AnfNodePtr &use_node,
                                                         int64_t stage, int64_t user_stage, const ValuePtr &micro,
                                                         size_t pos, const std::vector<AnfNodePtr> &ops) {
  CNodePtr call_node = nullptr;
  // Resolve through the call chain to the real producing node/parameter.
  auto argument = GetRealKernelNode(node, -1, &call_node).first;

  auto use_cnode = use_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(use_cnode);
  if (!IsValueNode<FuncGraph>(use_cnode->input(0))) {
    MS_LOG(EXCEPTION) << "Parameter must be used by a graph, but got: " << use_cnode->DebugString();
  }
  auto use_graph = GetValueNode<FuncGraphPtr>(use_cnode->input(0));
  auto use_parameter_list = use_graph->parameters();
  // pos counts call inputs (input 0 is the graph), hence pos - 1.
  auto parameter = use_parameter_list.at(pos - 1);

  // insert receive
  if (stage_ == user_stage) {
    // Reuse an existing Receive for this argument/stage if one exists.
    auto recv = PipelineTransformer::Reuse(argument, stage, ops, SRC_RANK);
    if (recv) {
      manager_->SetEdge(use_node, SizeToInt(pos), recv);
      return nullptr;
    }
    auto root_param = argument;
    if (argument->isa<Parameter>() && argument->func_graph() != root_) {
      root_param = GetArgumentsByParameter(argument);
    }
    // Record that this stage also uses the (shared) parameter.
    (void)parameter_color_map_[root_param].insert(user_stage);
    auto graph = enable_share_cell_ ? shared_cell_ : main_graph_;
    return InsertReceive(graph, argument, use_node, SizeToInt(pos), user_stage, stage, micro, parameter, 0);
  }
  // insert send
  if (PipelineTransformer::Reuse(argument, user_stage, ops, DEST_RANK)) {
    return nullptr;
  }
  auto send_out = InsertSend(argument, user_stage, stage_, micro, 0);
  send_out.depend->set_user_data<Type>(DTYPE, send_out.type);
  send_out.depend->set_user_data<ValueList>(SHAPE, send_out.shape);
  return send_out.depend;
}
486 
// True when the edge from (node_stage, node_segment) to (user_node_stage,
// user_node_segment) crosses a pipeline cut and needs a Send/Receive pair:
// either the caller already classified it as an embedding edge (`isEmbed`),
// or the consumer sits on a later stage within the same segment, or the edge
// folds from the last stage back to stage 0 into a later segment.
bool IsStageConflict(int64_t node_stage, int64_t user_node_stage, int64_t node_segment, int64_t user_node_segment,
                     int64_t stage_num, bool isEmbed) {
  const bool forward_in_segment = node_stage < user_node_stage && node_segment == user_node_segment;
  const bool fold_edge = node_stage == stage_num - 1 && user_node_stage == 0 && node_segment < user_node_segment;
  return isEmbed || forward_in_segment || fold_edge;
}
495 
// For one colored node, examines every consumer edge and inserts the
// Send/Receive pair when the edge crosses a stage/segment cut involving this
// device's stage. Newly inserted sends (as Depend nodes) and their segment
// ids are appended to `send_ops`/`send_ops_segment`; receives to
// `receive_ops`. A single Receive is created per node and reused for all of
// its local consumers.
void FoldPipelineTransformer::CutBorderForNode(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                               std::vector<AnfNodePtr> *send_ops,
                                               std::vector<int64_t> *send_ops_segment,
                                               std::vector<AnfNodePtr> *receive_ops) {
  auto stage_info = node->user_data<NodeStageInfo>();
  auto segment_info = node->user_data<NodeSegmentInfo>();
  auto node_users = manager_->node_users()[node];
  AnfNodePtr receive = nullptr;
  for (auto &user_pair : node_users) {
    auto user_node = user_pair.first;
    auto node_stage = stage_info->stage();
    auto node_segment = segment_info->segment();
    auto user_stage_info = user_node->user_data<NodeStageInfo>();
    if (user_stage_info == nullptr) {
      continue;  // uncolored consumer: no cut on this edge
    }
    auto user_segment_info = user_node->user_data<NodeSegmentInfo>();
    if (user_segment_info == nullptr) {
      continue;
    }
    auto user_node_stage = user_stage_info->stage();
    // Only edges touching this device's stage are materialized here.
    if (node_stage != stage_ && user_node_stage != stage_) {
      continue;
    }
    auto micro = user_node->cast<CNodePtr>()->GetPrimalAttr(MICRO);
    auto user_node_segment = user_segment_info->segment();
    if (!micro) {
      MS_LOG(INFO) << "Can't find micro_batch information, use micro(0)";
      micro = MakeValue(int64_t(0));
    }
    auto stage_num = g_device_manager->stage_num();

    // Forward edge into a different segment — presumably an embedding-style
    // shared-parameter use handled via HandleParameterGraph below.
    bool isEmbed = node_stage < user_node_stage && node_segment != user_node_segment;
    if (IsStageConflict(node_stage, user_node_stage, node_segment, user_node_segment, stage_num, isEmbed)) {
      if (node_stage == stage_) {
        // Producing side: insert (or reuse) a Send.
        if (IsParameterGraph(node) && isEmbed) {
          auto send_depend = HandleParameterGraph(node, user_node, node_stage, user_node_stage, micro,
                                                  IntToSize(user_pair.second), *send_ops);
          if (!send_depend) {
            continue;  // reused an existing send/receive
          }
          // Parameter sends go to the front; keep segment list in sync.
          (void)send_ops->insert(send_ops->cbegin(), send_depend);
          (void)send_ops_segment->insert(send_ops_segment->begin(), node_segment);
          continue;
        }
        if (Reuse(node, user_node_stage, user_node_segment, *send_ops, *send_ops_segment, DEST_RANK)) {
          continue;
        }
        auto send_out = InsertSend(node, user_node_stage, node_stage, micro, node_segment);
        MS_EXCEPTION_IF_NULL(send_out.depend);
        send_ops->push_back(send_out.depend);
        send_ops_segment->push_back(node_segment);
        send_out.depend->set_user_data<Type>(DTYPE, send_out.type);
        send_out.depend->set_user_data<ValueList>(SHAPE, send_out.shape);
      } else {
        // Consuming side: insert one Receive and rewire later consumers to it.
        if (!receive) {
          if (IsParameterGraph(node)) {
            receive = HandleParameterGraph(node, user_node, node_stage, user_node_stage, micro,
                                           IntToSize(user_pair.second), *receive_ops);
            if (!receive) {
              continue;
            }
            receive_ops->push_back(receive);
          } else {
            receive = InsertReceive(graph, node, user_node, user_pair.second, user_node_stage, node_stage, micro, node,
                                    user_node_segment);
            receive_ops->push_back(receive);
          }
        } else {
          manager_->SetEdge(user_node, user_pair.second, receive);
        }
      }
      continue;
    }
    // Any remaining backward edge inside one segment is a coloring error.
    if (node_stage > user_node_stage && node_segment == user_node_segment) {
      MS_LOG(EXCEPTION) << "Within a segment, node_stage: " << node_stage
                        << " must be smaller than user_node_stage: " << user_node_stage;
    }
  }
}
576 
CutBorder(const FuncGraphPtr & graph)577 std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> FoldPipelineTransformer::CutBorder(
578   const FuncGraphPtr &graph) {
579   std::vector<AnfNodePtr> send_ops;
580   std::vector<int64_t> send_ops_segment;
581   std::vector<AnfNodePtr> receive_ops;
582   auto ret = graph->get_return();
583   MS_EXCEPTION_IF_NULL(ret);
584   std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
585   std::reverse(all_nodes.begin(), all_nodes.end());
586   auto stage_num = g_device_manager->stage_num();
587   if (is_train_ && (stage_num > micro_size_)) {
588     MS_LOG(EXCEPTION) << "MicroBatch size: " << micro_size_ << " can't less than stage num: " << stage_num;
589   }
590   for (auto &node : all_nodes) {
591     auto stage_info = node->user_data<NodeStageInfo>();
592     if (!node->isa<CNode>() || stage_info == nullptr || stage_info->stage() == -1 ||
593         IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
594       continue;
595     }
596     CutBorderForNode(graph, node, &send_ops, &send_ops_segment, &receive_ops);
597   }
598   RemoveMonadNode();
599   return std::make_pair(send_ops, receive_ops);
600 }
601 
// Creates Send/Receive pairs for parameters shared across more than one
// pipeline stage. The stage that "owns" the parameter (the smallest stage in
// its color set) sends it to every other using stage; those stages receive
// it. Returns the created {sends, receives}.
std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> FoldPipelineTransformer::HandleSharedParameter() {
  auto parameters = root_->parameters();
  std::vector<AnfNodePtr> sends = {};
  std::vector<AnfNodePtr> recvs = {};
  for (auto &parameter : parameters) {
    auto parameter_stage = parameter_color_map_[parameter];
    if (parameter_stage.size() <= 1) {
      continue;  // used by a single stage: nothing to transfer
    }
    const auto &node_users_map = manager_->node_users();
    auto users = GetParameterLoadUsers(parameter, node_users_map);
    for (auto &user : users) {
      auto node = user.first;
      auto cnode = node->cast<CNodePtr>();
      auto graph = node->func_graph();
      // For a sub-graph call, the relevant stage is the callee's.
      if (IsValueNode<FuncGraph>(cnode->input(0))) {
        graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
      }
      if (graph == root_ || graph->stage() == -1 || parameter_stage.count(stage_) == 0) {
        continue;  // this device's stage does not use the parameter
      }
      auto micro = cnode->GetPrimalAttr(MICRO);
      if (!micro) {
        MS_LOG(INFO) << "Parameter: " << parameter->ToString() << " doesn't have micro batch";
        micro = MakeValue(int64_t(0));
      }
      // Owning stage (smallest in the color set) sends; others receive.
      if (stage_ == *parameter_stage.begin()) {
        auto user_stage = graph->stage();
        auto stage_info = node->user_data<NodeStageInfo>();
        if (stage_info) {
          user_stage = stage_info->stage();
        }
        if (graph->stage() == stage_ || user_stage == -1) {
          continue;  // local use: no transfer needed
        }
        if (PipelineTransformer::Reuse(parameter, user_stage, sends, DEST_RANK)) {
          continue;  // already sending this parameter to that stage
        }
        auto send_out = InsertSend(parameter, user_stage, stage_, micro, 0);
        sends.push_back(send_out.depend);
      } else {
        // Reuse one Receive per parameter; rewire further users to it.
        auto receive = PipelineTransformer::Reuse(parameter, *parameter_stage.begin(), recvs, SRC_RANK);
        if (receive) {
          manager_->SetEdge(node, user.second, receive);
        } else {
          AnfNodePtr recv;
          auto fg = enable_share_cell_ ? shared_cell_ : main_graph_;
          recv = InsertReceive(fg, parameter, node, user.second, stage_, *parameter_stage.begin(), micro, parameter, 0);
          (void)(recvs.push_back(recv));
        }
      }
    }
  }
  return std::make_pair(sends, recvs);
}
657 
// Top-level entry: builds the communication groups, inserts all Send/Receive
// pairs (shared parameters first, then stage-boundary edges), and rewires
// the main graph's output so the inserted sends stay reachable. In
// share-cell mode it instead delegates output/input rewiring to
// HandleGraphOutputs/HandleGraphInputs.
void FoldPipelineTransformer::CutGraph() {
  CreateForwardGroup2();
  MS_EXCEPTION_IF_NULL(main_graph_);
  auto send_recv_shared_param = HandleSharedParameter();
  auto graph = enable_share_cell_ ? shared_cell_ : main_graph_;
  MS_EXCEPTION_IF_NULL(graph);
  auto send_recv_cut_border = CutBorder(graph);
  // Collect every send: shared-parameter sends first, then border sends.
  std::vector<AnfNodePtr> send_ops;
  (void)(send_ops.insert(send_ops.end(), send_recv_shared_param.first.begin(), send_recv_shared_param.first.end()));
  (void)(send_ops.insert(send_ops.end(), send_recv_cut_border.first.begin(), send_recv_cut_border.first.end()));
  if (IsLastStage() && !enable_share_cell_) {
    // Last stage keeps the real output; attach sends via Depend so they are
    // not dead code.
    auto out_node = main_graph_->output();

    auto make_tuple = CreateMakeTupleNode(main_graph_, send_ops);

    std::vector<AnfNodePtr> tuple_out_depend = {NewValueNode(prim::kPrimDepend)};
    tuple_out_depend.push_back(out_node);
    tuple_out_depend.push_back(make_tuple);

    auto tuple_out_depend_node = main_graph_->NewCNode(tuple_out_depend);
    tuple_out_depend_node->set_abstract(out_node->abstract());
    (void)manager_->Replace(main_graph_->output(), tuple_out_depend_node);
    return;
  }
  if (send_ops.empty() && !is_train_) {
    return;  // inference on a non-last stage with nothing to send
  }
  if (!send_ops.empty()) {
    // Remember the last send's dtype/shape for later passes.
    type_ptr_ = send_ops.back()->user_data<Type>(DTYPE);
    shape_ = send_ops.back()->user_data<ValueList>(SHAPE);
  }
  if (!enable_share_cell_) {
    // Non-last stage: replace the output with zeros, Depend-chained to the
    // sends so they execute.
    auto make_tuple = CreateMakeTupleNode(main_graph_, send_ops);
    auto zero_outputs = GetZeroOutputs(main_graph_);
    std::vector<AnfNodePtr> out = {NewValueNode(prim::kPrimDepend), zero_outputs, make_tuple};
    auto out_node = main_graph_->NewCNode(out);
    (void)manager_->Replace(main_graph_->output(), out_node);
    return;
  }
  // Share-cell mode: reset tag bookkeeping and let the shared-cell helpers
  // splice sends/receives into the cell's outputs/inputs.
  fold_send_tag_map.clear();
  fold_recv_tag_map.clear();
  if (!IsLastStage()) {
    HandleGraphOutputs(send_ops);
  }
  std::vector<AnfNodePtr> recv_ops;
  (void)(recv_ops.insert(recv_ops.end(), send_recv_shared_param.second.begin(), send_recv_shared_param.second.end()));
  (void)(recv_ops.insert(recv_ops.end(), send_recv_cut_border.second.begin(), send_recv_cut_border.second.end()));
  HandleGraphInputs(recv_ops);
}
707 
708 }  // namespace parallel
709 }  // namespace mindspore
710