1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "frontend/parallel/graph_util/fold_pipeline_split_utils.h"
#include <algorithm>
#include <exception>
#include <list>
#include <memory>
#include <queue>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "frontend/parallel/graph_util/generate_graph.h"
#include "frontend/parallel/graph_util/pipeline_split_utils.h"
#include "ops/other_ops.h"
#include "ops/math_ops.h"
#include "ops/framework_ops.h"
#include "ops/array_ops.h"
#include "ops/nn_ops.h"
#include "ir/value.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/device_manager.h"
#include "include/common/utils/parallel_context.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "frontend/parallel/graph_util/node_info.h"
#include "utils/parallel_node_check.h"
39
40 namespace mindspore {
41 namespace parallel {
42
namespace {
// Positions inside the border-node pair vector assembled by FoldPipelineReorder, whose layout is
// {backward_start, backward_end, forward_start, forward_end, forward_end_before}.
constexpr int kBackwardEnd{1};
constexpr int kForwardStart{2};
constexpr int kForwardEnd{3};
}  // namespace
48
49 const std::set<PrimitivePtr> END_NODE_BLACK_LIST = {
50 prim::kPrimDepend, prim::kPrimTupleGetItem, prim::kPrimAdd, prim::kPrimSoftmaxCrossEntropyWithLogits,
51 prim::kPrimMakeTuple, prim::kPrimUpdateState, prim::kPrimReshape};
52
GetSegmentMax(const FuncGraphPtr & root,const std::vector<AnfNodePtr> & forward_end)53 int64_t GetSegmentMax(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &forward_end) {
54 int64_t seg_max = 0;
55 if (forward_end.empty()) {
56 MS_LOG(EXCEPTION) << "Can not find the end node of pipeline, you are advised to use 'PipelineCell' to fix it.";
57 } else {
58 auto forward_end_cnode = forward_end.back()->cast<CNodePtr>();
59 auto seg_size = forward_end_cnode->GetPrimalAttr(SEGMENT);
60 MS_EXCEPTION_IF_NULL(seg_size);
61 seg_max = GetValue<int64_t>(seg_size);
62 }
63 return seg_max;
64 }
65
GetSubStepPairs(const PipelinePair & fp_or_bp_pair,int64_t sub_step_num,int64_t seg_num,int64_t sub_micro_num,int64_t micro_num)66 std::vector<PipelinePair> GetSubStepPairs(const PipelinePair &fp_or_bp_pair, int64_t sub_step_num, int64_t seg_num,
67 int64_t sub_micro_num, int64_t micro_num) {
68 std::vector<PipelinePair> fp_or_bp_sub_pairs;
69 for (int64_t s = 0; s < sub_step_num; s++) {
70 std::vector<AnfNodePtr> temp_first;
71 std::vector<AnfNodePtr> temp_second;
72 for (int64_t sid = 0; sid < seg_num; sid++) {
73 temp_first.insert(temp_first.end(), fp_or_bp_pair.first.begin() + s * sub_micro_num + sid * micro_num,
74 fp_or_bp_pair.first.begin() + (s + 1) * sub_micro_num + sid * micro_num);
75 temp_second.insert(temp_second.end(), fp_or_bp_pair.second.begin() + s * sub_micro_num + sid * micro_num,
76 fp_or_bp_pair.second.begin() + (s + 1) * sub_micro_num + sid * micro_num);
77 }
78 fp_or_bp_sub_pairs.emplace_back(temp_first, temp_second);
79 }
80 return fp_or_bp_sub_pairs;
81 }
82
CompFuncBySegAscending(const AnfNodePtr & node1,const AnfNodePtr & node2)83 bool CompFuncBySegAscending(const AnfNodePtr &node1, const AnfNodePtr &node2) {
84 auto parallel_context = parallel::ParallelContext::GetInstance();
85 if (parallel_context->enable_fold_pipeline()) {
86 auto get_value_func = [](const AnfNodePtr &node) {
87 MS_EXCEPTION_IF_NULL(node);
88 auto cnode = node->cast<CNodePtr>();
89 MS_EXCEPTION_IF_NULL(cnode);
90 auto seg = cnode->GetPrimalAttr(SEGMENT);
91 MS_EXCEPTION_IF_NULL(seg);
92 return GetValue<int64_t>(seg);
93 };
94
95 if (get_value_func(node1) != get_value_func(node2)) {
96 return get_value_func(node1) < get_value_func(node2);
97 }
98 }
99 return CompFunc(node1, node2);
100 }
101
CompFuncBySegDescending(const AnfNodePtr & node1,const AnfNodePtr & node2)102 bool CompFuncBySegDescending(const AnfNodePtr &node1, const AnfNodePtr &node2) {
103 auto parallel_context = parallel::ParallelContext::GetInstance();
104 if (parallel_context->enable_fold_pipeline()) {
105 auto get_value_func = [](const AnfNodePtr &node) {
106 MS_EXCEPTION_IF_NULL(node);
107 auto cnode = node->cast<CNodePtr>();
108 MS_EXCEPTION_IF_NULL(cnode);
109 auto seg = cnode->GetPrimalAttr(SEGMENT);
110 MS_EXCEPTION_IF_NULL(seg);
111 return GetValue<int64_t>(seg);
112 };
113
114 if (get_value_func(node1) != get_value_func(node2)) {
115 return get_value_func(node1) > get_value_func(node2);
116 }
117 }
118 return CompFunc(node1, node2);
119 }
120
InsertVirtualFoldPipelineEndNode(const AnfNodePtr & temp_node,const FuncGraphManagerPtr & manager)121 void InsertVirtualFoldPipelineEndNode(const AnfNodePtr &temp_node, const FuncGraphManagerPtr &manager) {
122 auto end_node = GetPreNode(temp_node);
123 MS_EXCEPTION_IF_NULL(end_node);
124 auto end_cnode = end_node->cast<CNodePtr>();
125 MS_EXCEPTION_IF_NULL(end_cnode);
126 auto end_prim = GetCNodePrimitive(end_node);
127 OperatorAttrs attrs_;
128 auto op = CreateOpInstance(attrs_, "_VirtualPipelineEnd", "end_node");
129 auto value_node = NewValueNode(op);
130 auto new_prim = GetValueNode(value_node)->cast<PrimitivePtr>();
131 (void)new_prim->SetAttrs(end_prim->attrs());
132 manager->SetEdge(end_node, 0, value_node);
133 end_cnode->AddPrimalAttr(PIPELINE_END, end_cnode->GetPrimalAttr(MICRO));
134 auto seg = ParallelContext::GetInstance()->pipeline_segment_split_num();
135 end_cnode->AddPrimalAttr(SEGMENT, MakeValue(seg - 1));
136 }
137
// Returns the first entry in the manager's user set of `node`, or nullptr when `node` has no recorded users.
// Throws if `root` or its manager is null.
AnfNodePtr FindNodeFirstUser(const FuncGraphPtr &root, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(root);
  auto manager = root->manager();
  MS_EXCEPTION_IF_NULL(manager);
  // Bind by reference and use find(): the previous code copied the whole node-users map (and default-inserted
  // `node` into that copy) on every call, and this helper runs inside per-micro-batch loops.
  const auto &node_users_map = manager->node_users();
  auto users_iter = node_users_map.find(node);
  if (users_iter == node_users_map.end()) {
    return nullptr;
  }
  const auto &users = users_iter->second;
  auto first_user = users.begin();
  if (first_user == users.end()) {
    return nullptr;
  }
  MS_LOG(INFO) << "Receive user: " << (first_user->first)->ToString();
  return first_user->first;
}
148
IsInEndNodeBlackListOrParallelBlackList(const CNodePtr & cnode)149 static bool IsInEndNodeBlackListOrParallelBlackList(const CNodePtr &cnode) {
150 MS_EXCEPTION_IF_NULL(cnode);
151 if (!IsValueNode<Primitive>(cnode->input(0))) {
152 return true;
153 }
154 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
155 if (IsInParallelBlackList(prim)) {
156 return true;
157 }
158 for (auto &prim_node : END_NODE_BLACK_LIST) {
159 if (IsPrimitiveCNode(cnode, prim_node)) {
160 return true;
161 }
162 }
163 return false;
164 }
165
// Breadth-first search backwards from `node` through CNode inputs (inputs 1..n). It returns the first CNode that
// carries the NEED_GRAD primal attribute and is rejected by neither END_NODE_BLACK_LIST nor the parallel black list.
// Throws if `node` is null or not a CNode, or if no such node is reachable.
AnfNodePtr GetPreNode(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // A FIFO queue avoids the O(n) erase-from-front of a vector. `visited` stops inputs shared by several nodes from
  // being expanded again and again. The first matching node in BFS order is unchanged: a revisited node was already
  // tested, and its inputs were already queued.
  std::queue<AnfNodePtr> node_queue;
  std::set<AnfNodePtr> visited;
  node_queue.push(node);
  (void)visited.insert(node);
  while (!node_queue.empty()) {
    auto front = node_queue.front();
    node_queue.pop();
    if (front == nullptr) {
      continue;
    }
    auto cur_node = front->cast<CNodePtr>();
    if (cur_node == nullptr) {
      continue;
    }
    if (!IsInEndNodeBlackListOrParallelBlackList(cur_node) && cur_node->HasPrimalAttr(NEED_GRAD)) {
      MS_LOG(INFO) << "Pipeline End node: " << cur_node->DebugString();
      return cur_node;
    }
    // Input 0 holds the operator; only inputs 1..n are followed. An index loop is safe even if inputs() is empty.
    const auto &inputs = cur_node->inputs();
    for (size_t index = 1; index < inputs.size(); ++index) {
      if (visited.insert(inputs[index]).second) {
        node_queue.push(inputs[index]);
      }
    }
  }
  MS_LOG(EXCEPTION) << "Get Pipeline End node failed.";
}
184
ComputeLastSegForwardEndIdx(const PipelinePair & forward_start,size_t curr_idx,int64_t micro_max,int64_t stage_num,int64_t stage_id)185 static bool ComputeLastSegForwardEndIdx(const PipelinePair &forward_start, size_t curr_idx, int64_t micro_max,
186 int64_t stage_num, int64_t stage_id) {
187 auto last_seg_idx = static_cast<size_t>(1 + micro_max + 1 - 2 * (stage_num - stage_id - 1) - 1);
188 return curr_idx > forward_start.first.size() - last_seg_idx;
189 }
190
ReorderForFoldPipelineForward(const std::vector<PipelinePair> & pair_vector,int64_t seg_max,int64_t micro_max,const FuncGraphPtr & root,AnfNodePtr * start_of_forward,AnfNodePtr * end_of_forward,bool enable_1f1b)191 void ReorderForFoldPipelineForward(const std::vector<PipelinePair> &pair_vector, int64_t seg_max, int64_t micro_max,
192 const FuncGraphPtr &root, AnfNodePtr *start_of_forward, AnfNodePtr *end_of_forward,
193 bool enable_1f1b) {
194 MS_EXCEPTION_IF_NULL(g_device_manager);
195 MS_EXCEPTION_IF_NULL(root);
196 auto manager = root->manager();
197 MS_EXCEPTION_IF_NULL(manager);
198
199 auto stage_num = g_device_manager->stage_num();
200 auto stage_id = g_device_manager->stage_id();
201 *start_of_forward = pair_vector[kForwardStart].first[0];
202 for (size_t i = 1; i < pair_vector[kForwardStart].first.size(); ++i) {
203 auto prior_node_begin = pair_vector[kForwardEnd].first[i - 1];
204 auto prior_node_end = pair_vector[kForwardEnd].second[i - 1];
205 auto post_node_begin = pair_vector[kForwardStart].first[i];
206 auto post_node_end = pair_vector[kForwardStart].second[i];
207 if (IsFirstStage() && (i > IntToSize(micro_max))) {
208 auto receive_node = post_node_begin;
209 post_node_begin = FindNodeFirstUser(root, post_node_begin);
210
211 MS_EXCEPTION_IF_NULL(post_node_begin);
212 auto insert_idx = i - LongToSize(micro_max + 1) + LongToSize(stage_num - 1);
213 auto send_node_begin = pair_vector[3].first[insert_idx];
214 auto send_node_end = pair_vector[3].second[insert_idx];
215 InsertDepend(post_node_end, send_node_begin, manager, root);
216
217 auto send_cnode = send_node_begin->cast<CNodePtr>();
218 auto before_send_node = GetActualOp(send_cnode->input(1));
219
220 InsertDepend(before_send_node, receive_node, manager, root);
221 }
222 if (enable_1f1b && ComputeLastSegForwardEndIdx(pair_vector[kForwardStart], i, micro_max, stage_num, stage_id)) {
223 continue;
224 }
225
226 InsertDepend(prior_node_end, post_node_begin, manager, root);
227 *end_of_forward = pair_vector[kForwardEnd].second[i];
228 }
229 (*end_of_forward)->cast<CNodePtr>()->AddPrimalAttr(FORWARD_END, MakeValue(true));
230 (*end_of_forward)->cast<CNodePtr>()->AddPrimalAttr(SEGMENT_MAX, MakeValue(seg_max));
231 }
232
ReorderForBackwardLastSeg(const std::vector<PipelinePair> & pair_vector,const FuncGraphPtr & root,AnfNodePtr * start_of_backward,AnfNodePtr * end_of_backward,int64_t micro_max)233 void ReorderForBackwardLastSeg(const std::vector<PipelinePair> &pair_vector, const FuncGraphPtr &root,
234 AnfNodePtr *start_of_backward, AnfNodePtr *end_of_backward, int64_t micro_max) {
235 MS_EXCEPTION_IF_NULL(g_device_manager);
236 MS_EXCEPTION_IF_NULL(root);
237 auto manager = root->manager();
238 MS_EXCEPTION_IF_NULL(manager);
239 auto stage_num = g_device_manager->stage_num();
240 auto stage_id = g_device_manager->stage_id();
241 int64_t seg_max = GetSegmentMax(root, pair_vector[3].second);
242 MS_LOG(INFO) << "Micro max:" << micro_max << "seg_max" << seg_max;
243 int64_t last_seg_index = SizeToLong(pair_vector[2].first.size()) - 1 - micro_max;
244 int64_t cur_stage_fwd_max_idx = 2 * (stage_num - stage_id - 1) + 1;
245 if (!IsFirstStage() && (micro_max + 1 > cur_stage_fwd_max_idx)) {
246 for (size_t i = LongToSize(cur_stage_fwd_max_idx); i < LongToSize(micro_max + 1); ++i) {
247 auto forward_node_begin = pair_vector[2].first[LongToSize(last_seg_index) + i];
248 auto forward_node_end = pair_vector[2].second[LongToSize(last_seg_index) + i];
249 size_t insert_idx;
250 if (i == LongToSize(cur_stage_fwd_max_idx)) {
251 if (IsLastStage()) {
252 continue;
253 }
254 insert_idx = LongToSize(last_seg_index) + i - 1;
255 auto post_node = pair_vector[3].first[insert_idx];
256 InsertDepend(forward_node_end, post_node, manager, root);
257
258 auto prior_node = pair_vector[4].second[insert_idx];
259 InsertDepend(prior_node, forward_node_begin, manager, root);
260 } else {
261 if (IsLastStage() && i == LongToSize(cur_stage_fwd_max_idx + 1)) {
262 auto post_node0 = pair_vector[1].first[0];
263 InsertDepend(forward_node_end, post_node0, manager, root);
264 auto pre_prior_node = pair_vector[2].second[LongToSize(last_seg_index) + i - 1];
265 InsertDepend(pre_prior_node, forward_node_begin, manager, root);
266 auto pre_post_node = pair_vector[2].first[LongToSize(last_seg_index) + i - 1];
267 auto prior_node0 = GetActualOp(pair_vector[1].first[0]->cast<CNodePtr>()->input(1));
268 InsertDepend(prior_node0, pre_post_node, manager, root);
269 continue;
270 }
271 insert_idx = i - LongToSize(cur_stage_fwd_max_idx) - 1;
272 auto post_node1 = pair_vector[1].first[insert_idx];
273 InsertDepend(forward_node_end, post_node1, manager, root);
274
275 auto prior_cnode1 = post_node1->cast<CNodePtr>();
276 auto before_prior_cnode = GetActualOp(prior_cnode1->input(1));
277 InsertDepend(before_prior_cnode, forward_node_begin, manager, root);
278 }
279 }
280 }
281
282 if (micro_max + 1 > cur_stage_fwd_max_idx) {
283 for (size_t i = LongToSize(cur_stage_fwd_max_idx); i < LongToSize(micro_max + 1); ++i) {
284 if (!IsLastStage()) {
285 auto prior_node1 = pair_vector[3].second[last_seg_index + i];
286 auto post_node1 = pair_vector[0].first[LongToSize(SizeToLong(i) - cur_stage_fwd_max_idx + 1)];
287 InsertDepend(prior_node1, post_node1, manager, root);
288 }
289 std::shared_ptr<AnfNode> post_node2;
290 post_node2 = FindNodeFirstUser(root, pair_vector[kForwardStart].first[last_seg_index + i]);
291 auto prior_node2 = pair_vector[1].second[LongToSize(SizeToLong(i) - cur_stage_fwd_max_idx)];
292 InsertDepend(prior_node2, post_node2, manager, root);
293 }
294
295 for (size_t j = LongToSize(micro_max + 1 - 2 * (stage_num - stage_id - 1)); j < LongToSize(micro_max + 1); ++j) {
296 auto prior_node3 = pair_vector[1].second[j - 1];
297 auto post_node3 = pair_vector[0].first[j];
298 InsertDepend(prior_node3, post_node3, manager, root);
299 }
300 } else {
301 for (size_t j = 1; j < LongToSize(micro_max + 1); ++j) {
302 auto prior_node4 = pair_vector[1].second[j - 1];
303 auto post_node4 = pair_vector[0].first[j];
304 InsertDepend(prior_node4, post_node4, manager, root);
305 }
306 }
307
308 if (!IsLastStage()) {
309 std::shared_ptr<AnfNode> prior_node5;
310 if ((micro_max + 1 > cur_stage_fwd_max_idx)) {
311 prior_node5 = pair_vector[kForwardEnd].second[LongToSize(last_seg_index + cur_stage_fwd_max_idx - 1)];
312 } else {
313 prior_node5 = pair_vector[kForwardEnd].second[LongToSize(last_seg_index + micro_max)];
314 }
315 auto post_node5 = pair_vector[0].first[0];
316 InsertDepend(prior_node5, post_node5, manager, root);
317 }
318
319 for (size_t i = 0; i < pair_vector[0].first.size(); ++i) {
320 pair_vector[0].first[i]->cast<CNodePtr>()->AddPrimalAttr(BACKWARD_MICRO_END, MakeValue(true));
321 pair_vector[0].first[i]->cast<CNodePtr>()->AddPrimalAttr(SEGMENT_MAX, MakeValue(seg_max));
322 }
323 *start_of_backward = pair_vector[0].first[0];
324 *end_of_backward = pair_vector[1].second.back();
325 ReorderForBackwardOtherSeg(pair_vector[0], pair_vector[1], micro_max, stage_num, root);
326 }
327
ReorderForBackwardOtherSeg(const PipelinePair & backward_start_pair,const PipelinePair & backward_end_pair,int64_t micro_max,int64_t stage_num,const FuncGraphPtr & root)328 void ReorderForBackwardOtherSeg(const PipelinePair &backward_start_pair, const PipelinePair &backward_end_pair,
329 int64_t micro_max, int64_t stage_num, const FuncGraphPtr &root) {
330 MS_EXCEPTION_IF_NULL(root);
331 auto manager = root->manager();
332 for (size_t i = LongToSize(micro_max) + 1; i < backward_start_pair.first.size(); ++i) {
333 auto prior_node_begin = backward_end_pair.first[i - 1];
334 auto prior_node_end = backward_end_pair.second[i - 1];
335 auto post_node_begin = backward_start_pair.first[i];
336 auto post_node_end = backward_start_pair.second[i];
337
338 if (IsLastStage() && (i > IntToSize(micro_max))) {
339 auto receive_node = post_node_begin;
340 post_node_begin = FindNodeFirstUser(root, post_node_begin);
341 int64_t insert_idx = SizeToLong(i) - (micro_max + 1) + (stage_num - 1);
342 auto send_node_begin = backward_end_pair.first[insert_idx];
343 auto send_node_end = backward_end_pair.second[insert_idx];
344 InsertDepend(post_node_end, send_node_begin, manager, root);
345
346 auto send_cnode = send_node_begin->cast<CNodePtr>();
347 auto before_send_node = GetActualOp(send_cnode->input(1));
348 before_send_node = GetActualOp((before_send_node->cast<CNodePtr>())->input(1));
349
350 InsertDepend(before_send_node, receive_node, manager, root);
351 }
352
353 InsertDepend(prior_node_end, post_node_begin, manager, root);
354 }
355 }
356
Deduplicate(const std::vector<AnfNodePtr> & node_vector,const FuncGraphPtr & root,int64_t micro_max,int64_t seg_max,bool is_train)357 PipelinePair Deduplicate(const std::vector<AnfNodePtr> &node_vector, const FuncGraphPtr &root, int64_t micro_max,
358 int64_t seg_max, bool is_train) {
359 std::vector<AnfNodePtr> out_vec_begin;
360 std::vector<AnfNodePtr> out_vec_end;
361 for (int64_t h = 0; h <= seg_max; ++h) {
362 CommonDeduplicate(node_vector, &out_vec_begin, &out_vec_end, root, micro_max, seg_max, h, is_train);
363 }
364 if (out_vec_begin.empty()) {
365 return std::make_pair(node_vector, node_vector);
366 }
367 return std::make_pair(out_vec_begin, out_vec_end);
368 }
369
DeduplicateBySegAscending(const std::vector<AnfNodePtr> & node_vector,const FuncGraphPtr & root,int64_t micro_max,bool is_train,int64_t seg_max=0)370 PipelinePair DeduplicateBySegAscending(const std::vector<AnfNodePtr> &node_vector, const FuncGraphPtr &root,
371 int64_t micro_max, bool is_train, int64_t seg_max = 0) {
372 std::vector<AnfNodePtr> out_vec_begin;
373 std::vector<AnfNodePtr> out_vec_end;
374 for (int64_t h = 0; h <= seg_max; ++h) {
375 CommonDeduplicate(node_vector, &out_vec_begin, &out_vec_end, root, micro_max, seg_max, h, is_train);
376 }
377 if (out_vec_begin.empty()) {
378 return std::make_pair(node_vector, node_vector);
379 }
380 return std::make_pair(out_vec_begin, out_vec_end);
381 }
382
DeduplicateBySegDescending(const std::vector<AnfNodePtr> & node_vector,const FuncGraphPtr & root,int64_t micro_max,bool is_train,int64_t seg_max=0)383 PipelinePair DeduplicateBySegDescending(const std::vector<AnfNodePtr> &node_vector, const FuncGraphPtr &root,
384 int64_t micro_max, bool is_train, int64_t seg_max = 0) {
385 std::vector<AnfNodePtr> out_vec_begin;
386 std::vector<AnfNodePtr> out_vec_end;
387 for (int64_t h = seg_max; h >= 0; --h) {
388 CommonDeduplicate(node_vector, &out_vec_begin, &out_vec_end, root, micro_max, seg_max, h, is_train);
389 }
390 if (out_vec_begin.empty()) {
391 return std::make_pair(node_vector, node_vector);
392 }
393 return std::make_pair(out_vec_begin, out_vec_end);
394 }
395
ReorderForFoldPipelineBackward(const std::vector<PipelinePair> & pair_vector,int64_t seg_max,int64_t micro_max,const FuncGraphPtr & root,AnfNodePtr * start_of_backward,AnfNodePtr * end_of_backward)396 void ReorderForFoldPipelineBackward(const std::vector<PipelinePair> &pair_vector, int64_t seg_max, int64_t micro_max,
397 const FuncGraphPtr &root, AnfNodePtr *start_of_backward,
398 AnfNodePtr *end_of_backward) {
399 MS_EXCEPTION_IF_NULL(g_device_manager);
400 MS_EXCEPTION_IF_NULL(root);
401 auto manager = root->manager();
402 MS_EXCEPTION_IF_NULL(manager);
403 auto stage_num = g_device_manager->stage_num();
404
405 bool first = true;
406 for (size_t i = 0; i < pair_vector[0].first.size(); ++i) {
407 pair_vector[0].first[i]->cast<CNodePtr>()->AddPrimalAttr(BACKWARD_MICRO_END, MakeValue(true));
408 pair_vector[0].first[i]->cast<CNodePtr>()->AddPrimalAttr(SEGMENT_MAX, MakeValue(seg_max));
409 }
410 for (size_t i = 1; i < pair_vector[0].first.size(); ++i) {
411 auto prior_node_begin = pair_vector[1].first[i - 1];
412 auto prior_node_end = pair_vector[1].second[i - 1];
413 auto post_node_begin = pair_vector[0].first[i];
414 auto post_node_end = pair_vector[0].second[i];
415
416 if (IsLastStage() && (i > IntToSize(micro_max))) {
417 auto receive_node = post_node_begin;
418 post_node_begin = FindNodeFirstUser(root, post_node_begin);
419 auto insert_idx = i - (IntToSize(micro_max) + 1) + (IntToSize(stage_num) - 1);
420 auto send_node_begin = pair_vector[1].first[insert_idx];
421 auto send_node_end = pair_vector[1].second[insert_idx];
422
423 InsertDepend(post_node_end, send_node_begin, manager, root);
424
425 auto send_cnode = send_node_begin->cast<CNodePtr>();
426 auto before_send_node = GetActualOp(send_cnode->input(1));
427 before_send_node = GetActualOp((before_send_node->cast<CNodePtr>())->input(1));
428
429 InsertDepend(before_send_node, receive_node, manager, root);
430 }
431
432 InsertDepend(prior_node_end, post_node_begin, manager, root);
433 if (first) {
434 *start_of_backward = pair_vector[0].first[i - 1];
435 first = false;
436 }
437 }
438 *end_of_backward = pair_vector[1].second.back();
439 }
440
UpdateSubPairs(int64_t sub_step_num,int64_t micro_num,std::vector<PipelinePair> pair_vector,int64_t sub_micro_num,int64_t seg_num)441 PipelinePairVector UpdateSubPairs(int64_t sub_step_num, int64_t micro_num, std::vector<PipelinePair> pair_vector,
442 int64_t sub_micro_num, int64_t seg_num) {
443 PipelinePairVector sub_pair_vector;
444 PipelinePairVector tmp_pair_vector;
445 if (micro_num % sub_step_num != 0) {
446 MS_LOG(EXCEPTION) << "Micro_num(" << micro_num << ")cannot be divisible by sub_step_num(" << sub_step_num << ").";
447 }
448
449 if (sub_micro_num < g_device_manager->stage_num()) {
450 MS_LOG(EXCEPTION) << "Sub_micro_num(" << sub_micro_num << ") is less than stage_num("
451 << g_device_manager->stage_num() << ").";
452 }
453 MS_LOG(INFO) << "Micro_num=" << micro_num << ",sub_micro_num=" << sub_micro_num << ",seg_num = " << seg_num;
454
455 std::transform(pair_vector.begin(), pair_vector.end(), std::back_inserter(tmp_pair_vector),
456 [&sub_step_num, &seg_num, &sub_micro_num, µ_num](const auto &pipeline_pair) {
457 return GetSubStepPairs(pipeline_pair, sub_step_num, seg_num, sub_micro_num, micro_num);
458 });
459
460 for (size_t i = 0; i < tmp_pair_vector.size(); i++) {
461 std::vector<PipelinePair> sub_step1;
462 std::vector<PipelinePair> sub_step2;
463 if (!sub_pair_vector.empty()) {
464 sub_pair_vector[0].push_back(sub_pair_vector[i][0]);
465 sub_pair_vector[1].push_back(sub_pair_vector[i][1]);
466 } else {
467 sub_step1.push_back(sub_pair_vector[i][0]);
468 sub_pair_vector.push_back(sub_step1);
469 sub_step2.push_back(sub_pair_vector[i][1]);
470 sub_pair_vector.push_back(sub_step2);
471 }
472 }
473 return sub_pair_vector;
474 }
475
FoldPipelineReorder(const FuncGraphPtr & root)476 void FoldPipelineReorder(const FuncGraphPtr &root) {
477 std::vector<AnfNodePtr> forward_start;
478 std::vector<AnfNodePtr> forward_end;
479 std::vector<AnfNodePtr> forward_params;
480 std::vector<AnfNodePtr> backward_start;
481 std::vector<AnfNodePtr> backward_end;
482 std::vector<AnfNodePtr> backward_params;
483 std::vector<AnfNodePtr> allreduce_params;
484
485 SetParameterStartForCellShare(root);
486 GetBorderNode(&forward_start, &forward_end, &backward_start, &backward_end, &forward_params, &backward_params,
487 &allreduce_params, root);
488 int64_t micro_max = GetMicroMax(root, forward_end);
489 int64_t seg_max = GetSegmentMax(root, forward_end);
490 std::vector<int64_t> seg_micro_max{micro_max, seg_max};
491
492 auto backward_start_pair = DeduplicateBySegDescending(backward_start, root, micro_max, true, seg_max);
493 auto backward_end_pair = DeduplicateBySegDescending(backward_end, root, micro_max, true, seg_max);
494 auto forward_start_pair = DeduplicateBySegAscending(forward_start, root, micro_max, true, seg_max);
495 auto forward_end_pair = DeduplicateBySegAscending(forward_end, root, micro_max, true, seg_max);
496 auto forward_params_pair = Deduplicate(forward_params, root, micro_max, true, seg_max);
497 auto backward_params_pair = Deduplicate(backward_params, root, micro_max, true, seg_max);
498 CheckBorderNode(forward_start_pair, forward_end_pair, backward_start_pair, backward_end_pair, seg_micro_max);
499 auto forward_end_before_pair = GetForwardEndBeforePair(forward_end_pair);
500 std::vector<PipelinePair> pair_vector{backward_start_pair, backward_end_pair, forward_start_pair, forward_end_pair,
501 forward_end_before_pair};
502 AnfNodePtr start_of_forward;
503 AnfNodePtr end_of_forward;
504 AnfNodePtr start_of_backward;
505 AnfNodePtr end_of_backward;
506 AnfNodePtr pre_end_of_backward;
507
508 bool enable_1f1b = false;
509 if (common::GetEnv("FOLD_LAST_SEG_1F1B") != "") {
510 enable_1f1b = true;
511 }
512 int64_t sub_step_num = 0;
513 int64_t sub_micro_num = 0;
514 if (common::GetEnv("FOLD_ACCUMULATION") != "") sub_step_num = std::stoi(common::GetEnv("FOLD_ACCUMULATION"));
515 MS_LOG(INFO) << "Sub_step_num=" << sub_step_num;
516 PipelinePairVector sub_pair_vector;
517 if (sub_step_num > 0) {
518 int64_t micro_num = micro_max + 1;
519 int64_t seg_num = seg_max + 1;
520 sub_micro_num = micro_num / sub_step_num;
521 sub_pair_vector = UpdateSubPairs(sub_step_num, micro_num, pair_vector, sub_micro_num, seg_num);
522 }
523
524 if (enable_1f1b) {
525 if (sub_step_num > 0) {
526 for (int64_t s = 0; s < sub_step_num; s++) {
527 ReorderForFoldPipelineForward(sub_pair_vector[s], seg_max, sub_micro_num - 1, root, &start_of_forward,
528 &end_of_forward, enable_1f1b);
529 ReorderForBackwardLastSeg(sub_pair_vector[s], root, &start_of_backward, &end_of_backward, sub_micro_num - 1);
530 if (s > 0) {
531 InsertDepend(pre_end_of_backward, start_of_forward, root->manager(), root);
532 }
533 pre_end_of_backward = end_of_backward;
534 ReorderForParams(backward_params_pair, forward_params_pair, sub_pair_vector[kBackwardEnd][s],
535 sub_pair_vector[kForwardStart][s], root);
536 }
537 } else {
538 ReorderForFoldPipelineForward(pair_vector, seg_max, micro_max, root, &start_of_forward, &end_of_forward,
539 enable_1f1b);
540 ReorderForBackwardLastSeg(pair_vector, root, &start_of_backward, &end_of_backward, micro_max);
541 ReorderForParams(backward_params_pair, forward_params_pair, backward_end_pair, forward_start_pair, root);
542 }
543 } else {
544 if (sub_step_num > 0) {
545 for (int64_t s = 0; s < sub_step_num; s++) {
546 ReorderForFoldPipelineForward(sub_pair_vector[s], seg_max, sub_micro_num - 1, root, &start_of_forward,
547 &end_of_forward, enable_1f1b);
548
549 ReorderForFoldPipelineBackward(sub_pair_vector[s], seg_max, sub_micro_num - 1, root, &start_of_backward,
550 &end_of_backward);
551 InsertDepend(end_of_forward, start_of_backward, root->manager(), root);
552 if (s > 0) {
553 InsertDepend(pre_end_of_backward, start_of_forward, root->manager(), root);
554 }
555 pre_end_of_backward = end_of_backward;
556 ReorderForParams(backward_params_pair, forward_params_pair, sub_pair_vector[1][s], sub_pair_vector[2][s], root);
557 }
558 } else {
559 ReorderForFoldPipelineForward(pair_vector, seg_max, micro_max, root, &start_of_forward, &end_of_forward,
560 enable_1f1b);
561 ReorderForFoldPipelineBackward(pair_vector, seg_max, micro_max, root, &start_of_backward, &end_of_backward);
562 InsertDepend(end_of_forward, start_of_backward, root->manager(), root);
563 ReorderForParams(backward_params_pair, forward_params_pair, backward_end_pair, forward_start_pair, root);
564 }
565 }
566 }
567
568 } // namespace parallel
569 } // namespace mindspore
570