• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/graph_util/fold_pipeline_split_utils.h"
18 #include <memory>
19 #include <list>
20 #include <set>
21 #include <queue>
22 #include <algorithm>
23 
24 #include "frontend/parallel/graph_util/generate_graph.h"
25 #include "frontend/parallel/graph_util/pipeline_split_utils.h"
26 #include "ops/other_ops.h"
27 #include "ops/math_ops.h"
28 #include "ops/framework_ops.h"
29 #include "ops/array_ops.h"
30 #include "ops/nn_ops.h"
31 #include "ir/value.h"
32 #include "frontend/parallel/ops_info/ops_utils.h"
33 #include "frontend/parallel/device_manager.h"
34 #include "include/common/utils/parallel_context.h"
35 #include "frontend/parallel/step_parallel.h"
36 #include "frontend/parallel/step_parallel_utils.h"
37 #include "frontend/parallel/graph_util/node_info.h"
38 #include "utils/parallel_node_check.h"
39 
40 namespace mindspore {
41 namespace parallel {
42 
43 namespace {
44 constexpr int kBackwardEnd = 1;
45 constexpr int kForwardStart = 2;
46 constexpr int kForwardEnd = 3;
47 }  // namespace
48 
49 const std::set<PrimitivePtr> END_NODE_BLACK_LIST = {
50   prim::kPrimDepend,    prim::kPrimTupleGetItem, prim::kPrimAdd,    prim::kPrimSoftmaxCrossEntropyWithLogits,
51   prim::kPrimMakeTuple, prim::kPrimUpdateState,  prim::kPrimReshape};
52 
GetSegmentMax(const FuncGraphPtr & root,const std::vector<AnfNodePtr> & forward_end)53 int64_t GetSegmentMax(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &forward_end) {
54   int64_t seg_max = 0;
55   if (forward_end.empty()) {
56     MS_LOG(EXCEPTION) << "Can not find the end node of pipeline, you are advised to use 'PipelineCell' to fix it.";
57   } else {
58     auto forward_end_cnode = forward_end.back()->cast<CNodePtr>();
59     auto seg_size = forward_end_cnode->GetPrimalAttr(SEGMENT);
60     MS_EXCEPTION_IF_NULL(seg_size);
61     seg_max = GetValue<int64_t>(seg_size);
62   }
63   return seg_max;
64 }
65 
GetSubStepPairs(const PipelinePair & fp_or_bp_pair,int64_t sub_step_num,int64_t seg_num,int64_t sub_micro_num,int64_t micro_num)66 std::vector<PipelinePair> GetSubStepPairs(const PipelinePair &fp_or_bp_pair, int64_t sub_step_num, int64_t seg_num,
67                                           int64_t sub_micro_num, int64_t micro_num) {
68   std::vector<PipelinePair> fp_or_bp_sub_pairs;
69   for (int64_t s = 0; s < sub_step_num; s++) {
70     std::vector<AnfNodePtr> temp_first;
71     std::vector<AnfNodePtr> temp_second;
72     for (int64_t sid = 0; sid < seg_num; sid++) {
73       temp_first.insert(temp_first.end(), fp_or_bp_pair.first.begin() + s * sub_micro_num + sid * micro_num,
74                         fp_or_bp_pair.first.begin() + (s + 1) * sub_micro_num + sid * micro_num);
75       temp_second.insert(temp_second.end(), fp_or_bp_pair.second.begin() + s * sub_micro_num + sid * micro_num,
76                          fp_or_bp_pair.second.begin() + (s + 1) * sub_micro_num + sid * micro_num);
77     }
78     fp_or_bp_sub_pairs.emplace_back(temp_first, temp_second);
79   }
80   return fp_or_bp_sub_pairs;
81 }
82 
CompFuncBySegAscending(const AnfNodePtr & node1,const AnfNodePtr & node2)83 bool CompFuncBySegAscending(const AnfNodePtr &node1, const AnfNodePtr &node2) {
84   auto parallel_context = parallel::ParallelContext::GetInstance();
85   if (parallel_context->enable_fold_pipeline()) {
86     auto get_value_func = [](const AnfNodePtr &node) {
87       MS_EXCEPTION_IF_NULL(node);
88       auto cnode = node->cast<CNodePtr>();
89       MS_EXCEPTION_IF_NULL(cnode);
90       auto seg = cnode->GetPrimalAttr(SEGMENT);
91       MS_EXCEPTION_IF_NULL(seg);
92       return GetValue<int64_t>(seg);
93     };
94 
95     if (get_value_func(node1) != get_value_func(node2)) {
96       return get_value_func(node1) < get_value_func(node2);
97     }
98   }
99   return CompFunc(node1, node2);
100 }
101 
CompFuncBySegDescending(const AnfNodePtr & node1,const AnfNodePtr & node2)102 bool CompFuncBySegDescending(const AnfNodePtr &node1, const AnfNodePtr &node2) {
103   auto parallel_context = parallel::ParallelContext::GetInstance();
104   if (parallel_context->enable_fold_pipeline()) {
105     auto get_value_func = [](const AnfNodePtr &node) {
106       MS_EXCEPTION_IF_NULL(node);
107       auto cnode = node->cast<CNodePtr>();
108       MS_EXCEPTION_IF_NULL(cnode);
109       auto seg = cnode->GetPrimalAttr(SEGMENT);
110       MS_EXCEPTION_IF_NULL(seg);
111       return GetValue<int64_t>(seg);
112     };
113 
114     if (get_value_func(node1) != get_value_func(node2)) {
115       return get_value_func(node1) > get_value_func(node2);
116     }
117   }
118   return CompFunc(node1, node2);
119 }
120 
InsertVirtualFoldPipelineEndNode(const AnfNodePtr & temp_node,const FuncGraphManagerPtr & manager)121 void InsertVirtualFoldPipelineEndNode(const AnfNodePtr &temp_node, const FuncGraphManagerPtr &manager) {
122   auto end_node = GetPreNode(temp_node);
123   MS_EXCEPTION_IF_NULL(end_node);
124   auto end_cnode = end_node->cast<CNodePtr>();
125   MS_EXCEPTION_IF_NULL(end_cnode);
126   auto end_prim = GetCNodePrimitive(end_node);
127   OperatorAttrs attrs_;
128   auto op = CreateOpInstance(attrs_, "_VirtualPipelineEnd", "end_node");
129   auto value_node = NewValueNode(op);
130   auto new_prim = GetValueNode(value_node)->cast<PrimitivePtr>();
131   (void)new_prim->SetAttrs(end_prim->attrs());
132   manager->SetEdge(end_node, 0, value_node);
133   end_cnode->AddPrimalAttr(PIPELINE_END, end_cnode->GetPrimalAttr(MICRO));
134   auto seg = ParallelContext::GetInstance()->pipeline_segment_split_num();
135   end_cnode->AddPrimalAttr(SEGMENT, MakeValue(seg - 1));
136 }
137 
// Return the first user of `node` according to the iteration order of the manager's user set,
// or nullptr when `node` has no recorded users.
//
// Raises an exception when `root` or its manager is null.
AnfNodePtr FindNodeFirstUser(const FuncGraphPtr &root, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(root);
  auto manager = root->manager();
  MS_EXCEPTION_IF_NULL(manager);
  // Bind by reference: copying the complete node-users map on every call is needlessly expensive.
  const auto &node_users_map = manager->node_users();
  // find() rather than operator[] so that a node without users does not get an entry inserted.
  auto iter = node_users_map.find(node);
  if (iter == node_users_map.end()) {
    return nullptr;
  }
  for (const auto &temp_user : iter->second) {
    MS_EXCEPTION_IF_NULL(temp_user.first);
    MS_LOG(INFO) << "Receive user: " << (temp_user.first)->ToString();
    return temp_user.first;
  }
  return nullptr;
}
148 
IsInEndNodeBlackListOrParallelBlackList(const CNodePtr & cnode)149 static bool IsInEndNodeBlackListOrParallelBlackList(const CNodePtr &cnode) {
150   MS_EXCEPTION_IF_NULL(cnode);
151   if (!IsValueNode<Primitive>(cnode->input(0))) {
152     return true;
153   }
154   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
155   if (IsInParallelBlackList(prim)) {
156     return true;
157   }
158   for (auto &prim_node : END_NODE_BLACK_LIST) {
159     if (IsPrimitiveCNode(cnode, prim_node)) {
160       return true;
161     }
162   }
163   return false;
164 }
165 
// Search breadth-first from `node` through the data inputs (inputs 1..n) of each CNode and return
// the first CNode that is not rejected by IsInEndNodeBlackListOrParallelBlackList and carries the
// NEED_GRAD primal attribute.
//
// Raises an exception when `node` is null or not a CNode, or when no such node is reachable.
AnfNodePtr GetPreNode(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);

  std::queue<AnfNodePtr> pending;
  // A node that failed the check once fails it again, so skipping revisits keeps the result
  // unchanged while avoiding repeated expansion of inputs shared by several users.
  std::set<AnfNodePtr> visited;
  pending.push(node);
  (void)visited.insert(node);
  while (!pending.empty()) {
    auto front_node = pending.front();
    pending.pop();
    auto cur_node = front_node->cast<CNodePtr>();
    if (cur_node == nullptr) {
      continue;
    }
    if (!IsInEndNodeBlackListOrParallelBlackList(cur_node) && cur_node->HasPrimalAttr(NEED_GRAD)) {
      MS_LOG(INFO) << "Pipeline End node: " << cur_node->DebugString();
      return cur_node;
    }
    const auto &cur_inputs = cur_node->inputs();
    // Input 0 is the operator itself; only the remaining inputs are followed.
    for (size_t i = 1; i < cur_inputs.size(); ++i) {
      const AnfNodePtr &input_node = cur_inputs[i];
      if (input_node != nullptr && visited.insert(input_node).second) {
        pending.push(input_node);
      }
    }
  }
  MS_LOG(EXCEPTION) << "Get Pipeline End node failed.";
}
184 
ComputeLastSegForwardEndIdx(const PipelinePair & forward_start,size_t curr_idx,int64_t micro_max,int64_t stage_num,int64_t stage_id)185 static bool ComputeLastSegForwardEndIdx(const PipelinePair &forward_start, size_t curr_idx, int64_t micro_max,
186                                         int64_t stage_num, int64_t stage_id) {
187   auto last_seg_idx = static_cast<size_t>(1 + micro_max + 1 - 2 * (stage_num - stage_id - 1) - 1);
188   return curr_idx > forward_start.first.size() - last_seg_idx;
189 }
190 
ReorderForFoldPipelineForward(const std::vector<PipelinePair> & pair_vector,int64_t seg_max,int64_t micro_max,const FuncGraphPtr & root,AnfNodePtr * start_of_forward,AnfNodePtr * end_of_forward,bool enable_1f1b)191 void ReorderForFoldPipelineForward(const std::vector<PipelinePair> &pair_vector, int64_t seg_max, int64_t micro_max,
192                                    const FuncGraphPtr &root, AnfNodePtr *start_of_forward, AnfNodePtr *end_of_forward,
193                                    bool enable_1f1b) {
194   MS_EXCEPTION_IF_NULL(g_device_manager);
195   MS_EXCEPTION_IF_NULL(root);
196   auto manager = root->manager();
197   MS_EXCEPTION_IF_NULL(manager);
198 
199   auto stage_num = g_device_manager->stage_num();
200   auto stage_id = g_device_manager->stage_id();
201   *start_of_forward = pair_vector[kForwardStart].first[0];
202   for (size_t i = 1; i < pair_vector[kForwardStart].first.size(); ++i) {
203     auto prior_node_begin = pair_vector[kForwardEnd].first[i - 1];
204     auto prior_node_end = pair_vector[kForwardEnd].second[i - 1];
205     auto post_node_begin = pair_vector[kForwardStart].first[i];
206     auto post_node_end = pair_vector[kForwardStart].second[i];
207     if (IsFirstStage() && (i > IntToSize(micro_max))) {
208       auto receive_node = post_node_begin;
209       post_node_begin = FindNodeFirstUser(root, post_node_begin);
210 
211       MS_EXCEPTION_IF_NULL(post_node_begin);
212       auto insert_idx = i - LongToSize(micro_max + 1) + LongToSize(stage_num - 1);
213       auto send_node_begin = pair_vector[3].first[insert_idx];
214       auto send_node_end = pair_vector[3].second[insert_idx];
215       InsertDepend(post_node_end, send_node_begin, manager, root);
216 
217       auto send_cnode = send_node_begin->cast<CNodePtr>();
218       auto before_send_node = GetActualOp(send_cnode->input(1));
219 
220       InsertDepend(before_send_node, receive_node, manager, root);
221     }
222     if (enable_1f1b && ComputeLastSegForwardEndIdx(pair_vector[kForwardStart], i, micro_max, stage_num, stage_id)) {
223       continue;
224     }
225 
226     InsertDepend(prior_node_end, post_node_begin, manager, root);
227     *end_of_forward = pair_vector[kForwardEnd].second[i];
228   }
229   (*end_of_forward)->cast<CNodePtr>()->AddPrimalAttr(FORWARD_END, MakeValue(true));
230   (*end_of_forward)->cast<CNodePtr>()->AddPrimalAttr(SEGMENT_MAX, MakeValue(seg_max));
231 }
232 
ReorderForBackwardLastSeg(const std::vector<PipelinePair> & pair_vector,const FuncGraphPtr & root,AnfNodePtr * start_of_backward,AnfNodePtr * end_of_backward,int64_t micro_max)233 void ReorderForBackwardLastSeg(const std::vector<PipelinePair> &pair_vector, const FuncGraphPtr &root,
234                                AnfNodePtr *start_of_backward, AnfNodePtr *end_of_backward, int64_t micro_max) {
235   MS_EXCEPTION_IF_NULL(g_device_manager);
236   MS_EXCEPTION_IF_NULL(root);
237   auto manager = root->manager();
238   MS_EXCEPTION_IF_NULL(manager);
239   auto stage_num = g_device_manager->stage_num();
240   auto stage_id = g_device_manager->stage_id();
241   int64_t seg_max = GetSegmentMax(root, pair_vector[3].second);
242   MS_LOG(INFO) << "Micro max:" << micro_max << "seg_max" << seg_max;
243   int64_t last_seg_index = SizeToLong(pair_vector[2].first.size()) - 1 - micro_max;
244   int64_t cur_stage_fwd_max_idx = 2 * (stage_num - stage_id - 1) + 1;
245   if (!IsFirstStage() && (micro_max + 1 > cur_stage_fwd_max_idx)) {
246     for (size_t i = LongToSize(cur_stage_fwd_max_idx); i < LongToSize(micro_max + 1); ++i) {
247       auto forward_node_begin = pair_vector[2].first[LongToSize(last_seg_index) + i];
248       auto forward_node_end = pair_vector[2].second[LongToSize(last_seg_index) + i];
249       size_t insert_idx;
250       if (i == LongToSize(cur_stage_fwd_max_idx)) {
251         if (IsLastStage()) {
252           continue;
253         }
254         insert_idx = LongToSize(last_seg_index) + i - 1;
255         auto post_node = pair_vector[3].first[insert_idx];
256         InsertDepend(forward_node_end, post_node, manager, root);
257 
258         auto prior_node = pair_vector[4].second[insert_idx];
259         InsertDepend(prior_node, forward_node_begin, manager, root);
260       } else {
261         if (IsLastStage() && i == LongToSize(cur_stage_fwd_max_idx + 1)) {
262           auto post_node0 = pair_vector[1].first[0];
263           InsertDepend(forward_node_end, post_node0, manager, root);
264           auto pre_prior_node = pair_vector[2].second[LongToSize(last_seg_index) + i - 1];
265           InsertDepend(pre_prior_node, forward_node_begin, manager, root);
266           auto pre_post_node = pair_vector[2].first[LongToSize(last_seg_index) + i - 1];
267           auto prior_node0 = GetActualOp(pair_vector[1].first[0]->cast<CNodePtr>()->input(1));
268           InsertDepend(prior_node0, pre_post_node, manager, root);
269           continue;
270         }
271         insert_idx = i - LongToSize(cur_stage_fwd_max_idx) - 1;
272         auto post_node1 = pair_vector[1].first[insert_idx];
273         InsertDepend(forward_node_end, post_node1, manager, root);
274 
275         auto prior_cnode1 = post_node1->cast<CNodePtr>();
276         auto before_prior_cnode = GetActualOp(prior_cnode1->input(1));
277         InsertDepend(before_prior_cnode, forward_node_begin, manager, root);
278       }
279     }
280   }
281 
282   if (micro_max + 1 > cur_stage_fwd_max_idx) {
283     for (size_t i = LongToSize(cur_stage_fwd_max_idx); i < LongToSize(micro_max + 1); ++i) {
284       if (!IsLastStage()) {
285         auto prior_node1 = pair_vector[3].second[last_seg_index + i];
286         auto post_node1 = pair_vector[0].first[LongToSize(SizeToLong(i) - cur_stage_fwd_max_idx + 1)];
287         InsertDepend(prior_node1, post_node1, manager, root);
288       }
289       std::shared_ptr<AnfNode> post_node2;
290       post_node2 = FindNodeFirstUser(root, pair_vector[kForwardStart].first[last_seg_index + i]);
291       auto prior_node2 = pair_vector[1].second[LongToSize(SizeToLong(i) - cur_stage_fwd_max_idx)];
292       InsertDepend(prior_node2, post_node2, manager, root);
293     }
294 
295     for (size_t j = LongToSize(micro_max + 1 - 2 * (stage_num - stage_id - 1)); j < LongToSize(micro_max + 1); ++j) {
296       auto prior_node3 = pair_vector[1].second[j - 1];
297       auto post_node3 = pair_vector[0].first[j];
298       InsertDepend(prior_node3, post_node3, manager, root);
299     }
300   } else {
301     for (size_t j = 1; j < LongToSize(micro_max + 1); ++j) {
302       auto prior_node4 = pair_vector[1].second[j - 1];
303       auto post_node4 = pair_vector[0].first[j];
304       InsertDepend(prior_node4, post_node4, manager, root);
305     }
306   }
307 
308   if (!IsLastStage()) {
309     std::shared_ptr<AnfNode> prior_node5;
310     if ((micro_max + 1 > cur_stage_fwd_max_idx)) {
311       prior_node5 = pair_vector[kForwardEnd].second[LongToSize(last_seg_index + cur_stage_fwd_max_idx - 1)];
312     } else {
313       prior_node5 = pair_vector[kForwardEnd].second[LongToSize(last_seg_index + micro_max)];
314     }
315     auto post_node5 = pair_vector[0].first[0];
316     InsertDepend(prior_node5, post_node5, manager, root);
317   }
318 
319   for (size_t i = 0; i < pair_vector[0].first.size(); ++i) {
320     pair_vector[0].first[i]->cast<CNodePtr>()->AddPrimalAttr(BACKWARD_MICRO_END, MakeValue(true));
321     pair_vector[0].first[i]->cast<CNodePtr>()->AddPrimalAttr(SEGMENT_MAX, MakeValue(seg_max));
322   }
323   *start_of_backward = pair_vector[0].first[0];
324   *end_of_backward = pair_vector[1].second.back();
325   ReorderForBackwardOtherSeg(pair_vector[0], pair_vector[1], micro_max, stage_num, root);
326 }
327 
ReorderForBackwardOtherSeg(const PipelinePair & backward_start_pair,const PipelinePair & backward_end_pair,int64_t micro_max,int64_t stage_num,const FuncGraphPtr & root)328 void ReorderForBackwardOtherSeg(const PipelinePair &backward_start_pair, const PipelinePair &backward_end_pair,
329                                 int64_t micro_max, int64_t stage_num, const FuncGraphPtr &root) {
330   MS_EXCEPTION_IF_NULL(root);
331   auto manager = root->manager();
332   for (size_t i = LongToSize(micro_max) + 1; i < backward_start_pair.first.size(); ++i) {
333     auto prior_node_begin = backward_end_pair.first[i - 1];
334     auto prior_node_end = backward_end_pair.second[i - 1];
335     auto post_node_begin = backward_start_pair.first[i];
336     auto post_node_end = backward_start_pair.second[i];
337 
338     if (IsLastStage() && (i > IntToSize(micro_max))) {
339       auto receive_node = post_node_begin;
340       post_node_begin = FindNodeFirstUser(root, post_node_begin);
341       int64_t insert_idx = SizeToLong(i) - (micro_max + 1) + (stage_num - 1);
342       auto send_node_begin = backward_end_pair.first[insert_idx];
343       auto send_node_end = backward_end_pair.second[insert_idx];
344       InsertDepend(post_node_end, send_node_begin, manager, root);
345 
346       auto send_cnode = send_node_begin->cast<CNodePtr>();
347       auto before_send_node = GetActualOp(send_cnode->input(1));
348       before_send_node = GetActualOp((before_send_node->cast<CNodePtr>())->input(1));
349 
350       InsertDepend(before_send_node, receive_node, manager, root);
351     }
352 
353     InsertDepend(prior_node_end, post_node_begin, manager, root);
354   }
355 }
356 
Deduplicate(const std::vector<AnfNodePtr> & node_vector,const FuncGraphPtr & root,int64_t micro_max,int64_t seg_max,bool is_train)357 PipelinePair Deduplicate(const std::vector<AnfNodePtr> &node_vector, const FuncGraphPtr &root, int64_t micro_max,
358                          int64_t seg_max, bool is_train) {
359   std::vector<AnfNodePtr> out_vec_begin;
360   std::vector<AnfNodePtr> out_vec_end;
361   for (int64_t h = 0; h <= seg_max; ++h) {
362     CommonDeduplicate(node_vector, &out_vec_begin, &out_vec_end, root, micro_max, seg_max, h, is_train);
363   }
364   if (out_vec_begin.empty()) {
365     return std::make_pair(node_vector, node_vector);
366   }
367   return std::make_pair(out_vec_begin, out_vec_end);
368 }
369 
DeduplicateBySegAscending(const std::vector<AnfNodePtr> & node_vector,const FuncGraphPtr & root,int64_t micro_max,bool is_train,int64_t seg_max=0)370 PipelinePair DeduplicateBySegAscending(const std::vector<AnfNodePtr> &node_vector, const FuncGraphPtr &root,
371                                        int64_t micro_max, bool is_train, int64_t seg_max = 0) {
372   std::vector<AnfNodePtr> out_vec_begin;
373   std::vector<AnfNodePtr> out_vec_end;
374   for (int64_t h = 0; h <= seg_max; ++h) {
375     CommonDeduplicate(node_vector, &out_vec_begin, &out_vec_end, root, micro_max, seg_max, h, is_train);
376   }
377   if (out_vec_begin.empty()) {
378     return std::make_pair(node_vector, node_vector);
379   }
380   return std::make_pair(out_vec_begin, out_vec_end);
381 }
382 
DeduplicateBySegDescending(const std::vector<AnfNodePtr> & node_vector,const FuncGraphPtr & root,int64_t micro_max,bool is_train,int64_t seg_max=0)383 PipelinePair DeduplicateBySegDescending(const std::vector<AnfNodePtr> &node_vector, const FuncGraphPtr &root,
384                                         int64_t micro_max, bool is_train, int64_t seg_max = 0) {
385   std::vector<AnfNodePtr> out_vec_begin;
386   std::vector<AnfNodePtr> out_vec_end;
387   for (int64_t h = seg_max; h >= 0; --h) {
388     CommonDeduplicate(node_vector, &out_vec_begin, &out_vec_end, root, micro_max, seg_max, h, is_train);
389   }
390   if (out_vec_begin.empty()) {
391     return std::make_pair(node_vector, node_vector);
392   }
393   return std::make_pair(out_vec_begin, out_vec_end);
394 }
395 
ReorderForFoldPipelineBackward(const std::vector<PipelinePair> & pair_vector,int64_t seg_max,int64_t micro_max,const FuncGraphPtr & root,AnfNodePtr * start_of_backward,AnfNodePtr * end_of_backward)396 void ReorderForFoldPipelineBackward(const std::vector<PipelinePair> &pair_vector, int64_t seg_max, int64_t micro_max,
397                                     const FuncGraphPtr &root, AnfNodePtr *start_of_backward,
398                                     AnfNodePtr *end_of_backward) {
399   MS_EXCEPTION_IF_NULL(g_device_manager);
400   MS_EXCEPTION_IF_NULL(root);
401   auto manager = root->manager();
402   MS_EXCEPTION_IF_NULL(manager);
403   auto stage_num = g_device_manager->stage_num();
404 
405   bool first = true;
406   for (size_t i = 0; i < pair_vector[0].first.size(); ++i) {
407     pair_vector[0].first[i]->cast<CNodePtr>()->AddPrimalAttr(BACKWARD_MICRO_END, MakeValue(true));
408     pair_vector[0].first[i]->cast<CNodePtr>()->AddPrimalAttr(SEGMENT_MAX, MakeValue(seg_max));
409   }
410   for (size_t i = 1; i < pair_vector[0].first.size(); ++i) {
411     auto prior_node_begin = pair_vector[1].first[i - 1];
412     auto prior_node_end = pair_vector[1].second[i - 1];
413     auto post_node_begin = pair_vector[0].first[i];
414     auto post_node_end = pair_vector[0].second[i];
415 
416     if (IsLastStage() && (i > IntToSize(micro_max))) {
417       auto receive_node = post_node_begin;
418       post_node_begin = FindNodeFirstUser(root, post_node_begin);
419       auto insert_idx = i - (IntToSize(micro_max) + 1) + (IntToSize(stage_num) - 1);
420       auto send_node_begin = pair_vector[1].first[insert_idx];
421       auto send_node_end = pair_vector[1].second[insert_idx];
422 
423       InsertDepend(post_node_end, send_node_begin, manager, root);
424 
425       auto send_cnode = send_node_begin->cast<CNodePtr>();
426       auto before_send_node = GetActualOp(send_cnode->input(1));
427       before_send_node = GetActualOp((before_send_node->cast<CNodePtr>())->input(1));
428 
429       InsertDepend(before_send_node, receive_node, manager, root);
430     }
431 
432     InsertDepend(prior_node_end, post_node_begin, manager, root);
433     if (first) {
434       *start_of_backward = pair_vector[0].first[i - 1];
435       first = false;
436     }
437   }
438   *end_of_backward = pair_vector[1].second.back();
439 }
440 
UpdateSubPairs(int64_t sub_step_num,int64_t micro_num,std::vector<PipelinePair> pair_vector,int64_t sub_micro_num,int64_t seg_num)441 PipelinePairVector UpdateSubPairs(int64_t sub_step_num, int64_t micro_num, std::vector<PipelinePair> pair_vector,
442                                   int64_t sub_micro_num, int64_t seg_num) {
443   PipelinePairVector sub_pair_vector;
444   PipelinePairVector tmp_pair_vector;
445   if (micro_num % sub_step_num != 0) {
446     MS_LOG(EXCEPTION) << "Micro_num(" << micro_num << ")cannot be divisible by sub_step_num(" << sub_step_num << ").";
447   }
448 
449   if (sub_micro_num < g_device_manager->stage_num()) {
450     MS_LOG(EXCEPTION) << "Sub_micro_num(" << sub_micro_num << ") is less than stage_num("
451                       << g_device_manager->stage_num() << ").";
452   }
453   MS_LOG(INFO) << "Micro_num=" << micro_num << ",sub_micro_num=" << sub_micro_num << ",seg_num = " << seg_num;
454 
455   std::transform(pair_vector.begin(), pair_vector.end(), std::back_inserter(tmp_pair_vector),
456                  [&sub_step_num, &seg_num, &sub_micro_num, &micro_num](const auto &pipeline_pair) {
457                    return GetSubStepPairs(pipeline_pair, sub_step_num, seg_num, sub_micro_num, micro_num);
458                  });
459 
460   for (size_t i = 0; i < tmp_pair_vector.size(); i++) {
461     std::vector<PipelinePair> sub_step1;
462     std::vector<PipelinePair> sub_step2;
463     if (!sub_pair_vector.empty()) {
464       sub_pair_vector[0].push_back(sub_pair_vector[i][0]);
465       sub_pair_vector[1].push_back(sub_pair_vector[i][1]);
466     } else {
467       sub_step1.push_back(sub_pair_vector[i][0]);
468       sub_pair_vector.push_back(sub_step1);
469       sub_step2.push_back(sub_pair_vector[i][1]);
470       sub_pair_vector.push_back(sub_step2);
471     }
472   }
473   return sub_pair_vector;
474 }
475 
FoldPipelineReorder(const FuncGraphPtr & root)476 void FoldPipelineReorder(const FuncGraphPtr &root) {
477   std::vector<AnfNodePtr> forward_start;
478   std::vector<AnfNodePtr> forward_end;
479   std::vector<AnfNodePtr> forward_params;
480   std::vector<AnfNodePtr> backward_start;
481   std::vector<AnfNodePtr> backward_end;
482   std::vector<AnfNodePtr> backward_params;
483   std::vector<AnfNodePtr> allreduce_params;
484 
485   SetParameterStartForCellShare(root);
486   GetBorderNode(&forward_start, &forward_end, &backward_start, &backward_end, &forward_params, &backward_params,
487                 &allreduce_params, root);
488   int64_t micro_max = GetMicroMax(root, forward_end);
489   int64_t seg_max = GetSegmentMax(root, forward_end);
490   std::vector<int64_t> seg_micro_max{micro_max, seg_max};
491 
492   auto backward_start_pair = DeduplicateBySegDescending(backward_start, root, micro_max, true, seg_max);
493   auto backward_end_pair = DeduplicateBySegDescending(backward_end, root, micro_max, true, seg_max);
494   auto forward_start_pair = DeduplicateBySegAscending(forward_start, root, micro_max, true, seg_max);
495   auto forward_end_pair = DeduplicateBySegAscending(forward_end, root, micro_max, true, seg_max);
496   auto forward_params_pair = Deduplicate(forward_params, root, micro_max, true, seg_max);
497   auto backward_params_pair = Deduplicate(backward_params, root, micro_max, true, seg_max);
498   CheckBorderNode(forward_start_pair, forward_end_pair, backward_start_pair, backward_end_pair, seg_micro_max);
499   auto forward_end_before_pair = GetForwardEndBeforePair(forward_end_pair);
500   std::vector<PipelinePair> pair_vector{backward_start_pair, backward_end_pair, forward_start_pair, forward_end_pair,
501                                         forward_end_before_pair};
502   AnfNodePtr start_of_forward;
503   AnfNodePtr end_of_forward;
504   AnfNodePtr start_of_backward;
505   AnfNodePtr end_of_backward;
506   AnfNodePtr pre_end_of_backward;
507 
508   bool enable_1f1b = false;
509   if (common::GetEnv("FOLD_LAST_SEG_1F1B") != "") {
510     enable_1f1b = true;
511   }
512   int64_t sub_step_num = 0;
513   int64_t sub_micro_num = 0;
514   if (common::GetEnv("FOLD_ACCUMULATION") != "") sub_step_num = std::stoi(common::GetEnv("FOLD_ACCUMULATION"));
515   MS_LOG(INFO) << "Sub_step_num=" << sub_step_num;
516   PipelinePairVector sub_pair_vector;
517   if (sub_step_num > 0) {
518     int64_t micro_num = micro_max + 1;
519     int64_t seg_num = seg_max + 1;
520     sub_micro_num = micro_num / sub_step_num;
521     sub_pair_vector = UpdateSubPairs(sub_step_num, micro_num, pair_vector, sub_micro_num, seg_num);
522   }
523 
524   if (enable_1f1b) {
525     if (sub_step_num > 0) {
526       for (int64_t s = 0; s < sub_step_num; s++) {
527         ReorderForFoldPipelineForward(sub_pair_vector[s], seg_max, sub_micro_num - 1, root, &start_of_forward,
528                                       &end_of_forward, enable_1f1b);
529         ReorderForBackwardLastSeg(sub_pair_vector[s], root, &start_of_backward, &end_of_backward, sub_micro_num - 1);
530         if (s > 0) {
531           InsertDepend(pre_end_of_backward, start_of_forward, root->manager(), root);
532         }
533         pre_end_of_backward = end_of_backward;
534         ReorderForParams(backward_params_pair, forward_params_pair, sub_pair_vector[kBackwardEnd][s],
535                          sub_pair_vector[kForwardStart][s], root);
536       }
537     } else {
538       ReorderForFoldPipelineForward(pair_vector, seg_max, micro_max, root, &start_of_forward, &end_of_forward,
539                                     enable_1f1b);
540       ReorderForBackwardLastSeg(pair_vector, root, &start_of_backward, &end_of_backward, micro_max);
541       ReorderForParams(backward_params_pair, forward_params_pair, backward_end_pair, forward_start_pair, root);
542     }
543   } else {
544     if (sub_step_num > 0) {
545       for (int64_t s = 0; s < sub_step_num; s++) {
546         ReorderForFoldPipelineForward(sub_pair_vector[s], seg_max, sub_micro_num - 1, root, &start_of_forward,
547                                       &end_of_forward, enable_1f1b);
548 
549         ReorderForFoldPipelineBackward(sub_pair_vector[s], seg_max, sub_micro_num - 1, root, &start_of_backward,
550                                        &end_of_backward);
551         InsertDepend(end_of_forward, start_of_backward, root->manager(), root);
552         if (s > 0) {
553           InsertDepend(pre_end_of_backward, start_of_forward, root->manager(), root);
554         }
555         pre_end_of_backward = end_of_backward;
556         ReorderForParams(backward_params_pair, forward_params_pair, sub_pair_vector[1][s], sub_pair_vector[2][s], root);
557       }
558     } else {
559       ReorderForFoldPipelineForward(pair_vector, seg_max, micro_max, root, &start_of_forward, &end_of_forward,
560                                     enable_1f1b);
561       ReorderForFoldPipelineBackward(pair_vector, seg_max, micro_max, root, &start_of_backward, &end_of_backward);
562       InsertDepend(end_of_forward, start_of_backward, root->manager(), root);
563       ReorderForParams(backward_params_pair, forward_params_pair, backward_end_pair, forward_start_pair, root);
564     }
565   }
566 }
567 
568 }  // namespace parallel
569 }  // namespace mindspore
570