/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <vector>
#include <string>
#include <utility>
#include <algorithm>
#include <memory>
#include "frontend/parallel/pipeline_transformer/pipeline_scheduler.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "frontend/parallel/node_check.h"
#include "mindspore/core/ops/array_ops.h"
#include "mindspore/core/ops/other_ops.h"
#include "ir/anf.h"
#include "ir/graph_utils.h"

namespace mindspore {
namespace parallel {
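// Given a Receive node, walk past any chain of Depend users to the call node
// (the cell) that consumes the received tensor, and propagate the ORDER primal
// attribute from the Receive onto that call node.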
CNodePtr GetCellByReceive(const AnfNodePtr &node, const FuncGraphManagerPtr &manager) {
  // receive->fg
  if (!IsPrimitiveCNode(node, prim::kPrimReceive)) {
    return nullptr;
  }
  auto users = manager->node_users()[node];
  auto user = users.front().first;
  while (IsPrimitiveCNode(user, prim::kPrimDepend)) {
    users = manager->node_users()[user];
    user = users.front().first;
  }
  auto fg_cnode = user->cast<CNodePtr>();
  auto cnode = node->cast<CNodePtr>();
  if (cnode->HasPrimalAttr(ORDER)) {
    auto order = cnode->GetPrimalAttr(ORDER);
    fg_cnode->AddPrimalAttr(ORDER, order);
  }
  return fg_cnode;
}

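// Given a Send node, walk its input chain past TupleGetItem/Depend nodes to the
// call node (the cell) whose output is being sent, and propagate the ORDER
// primal attribute from the Send onto that call node.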
CNodePtr GetCellBySend(const AnfNodePtr &node) {
  // send->tuple_getitem->fg->slice
  if (!IsPrimitiveCNode(node, prim::kPrimSend)) {
    return nullptr;
  }
  auto cnode = node->cast<CNodePtr>();
  auto fg_node = cnode->input(1);
  while (IsPrimitiveCNode(fg_node, prim::kPrimTupleGetItem) || IsPrimitiveCNode(fg_node, prim::kPrimDepend)) {
    fg_node = fg_node->cast<CNodePtr>()->input(1);
  }
  auto fg_cnode = fg_node->cast<CNodePtr>();
  if (cnode->HasPrimalAttr(ORDER)) {
    auto order = cnode->GetPrimalAttr(ORDER);
    fg_cnode->AddPrimalAttr(ORDER, order);
  }
  return fg_cnode;
}

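// Collect the backward border nodes. Backward borders are the gradients of the
// forward Send/Receive nodes (tagged with kPrimalAttrForwardNodeName), so the
// roles are mirrored: the gradient of a forward PIPELINE_BEGIN Receive is a
// backward-end Send, and the gradient of a forward PIPELINE_END Send is a
// backward-begin Receive. On the last stage's last chunk and the first stage's
// first chunk there is no peer stage, so the adjacent cell stands in for the
// missing border.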
void InterleavedScheduler::GetBackwardBorderNode(const CNodePtr &cnode) {
  auto chunk = GetValue<int64_t>(cnode->GetPrimalAttr(CHUNK));
  auto micro = GetValue<int64_t>(cnode->GetPrimalAttr(MICRO));
  Border border = {cnode, chunk, micro};
  Border border_cell = {nullptr, chunk, micro};
  if (cnode->HasPrimalAttr(kPrimalAttrForwardNodeName)) {
    if (cnode->HasPrimalAttr(PIPELINE_BEGIN)) {
      auto bwd_cell = GetCellBySend(cnode);
      MS_EXCEPTION_IF_NULL(bwd_cell);
      if (stage_ == stage_num_ - 1 && chunk == chunk_num_ - 1) {
        Border bwd_begin = {bwd_cell, chunk, micro};
        bwd_begin_.emplace_back(bwd_begin);
        border_cell.border = bwd_cell;
        bwd_cell_.emplace_back(border_cell);
      }
      bwd_end_.emplace_back(border);
    }
    if (cnode->HasPrimalAttr(PIPELINE_END)) {
      auto bwd_cell = GetCellByReceive(cnode, manager_);
      MS_EXCEPTION_IF_NULL(bwd_cell);
      if (stage_ == 0 && chunk == 0) {
        Border bwd_end = {bwd_cell, chunk, micro};
        bwd_end_.emplace_back(bwd_end);
      }
      border_cell.border = bwd_cell;
      bwd_cell_.emplace_back(border_cell);
      bwd_begin_.emplace_back(border);
    }
    if (cnode->HasPrimalAttr(PIPELINE_PARAM)) {
      bwd_params_.emplace_back(border);
    }
  }
}

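// Scan the whole graph for Send/Receive border nodes. The first pass derives
// chunk_num_ and micro_size_ from the maximum CHUNK/MICRO primal attributes;
// the second pass records each node as a forward begin/end/param border (or
// dispatches it to GetBackwardBorderNode) together with its adjacent cell.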
void InterleavedScheduler::GetBorderNode() {
  auto all_nodes = DeepScopedGraphSearch(root_->get_return());
  for (auto &node : all_nodes) {
    if (!IsPrimitiveCNode(node, prim::kPrimSend) && !IsPrimitiveCNode(node, prim::kPrimReceive)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    auto chunk = GetValue<int64_t>(cnode->GetPrimalAttr(CHUNK));
    chunk_num_ = (chunk + 1) > chunk_num_ ? (chunk + 1) : chunk_num_;
    auto micro = GetValue<int64_t>(cnode->GetPrimalAttr(MICRO));
    micro_size_ = (micro + 1) > micro_size_ ? (micro + 1) : micro_size_;
  }
  for (auto &node : all_nodes) {
    if (!IsPrimitiveCNode(node, prim::kPrimSend) && !IsPrimitiveCNode(node, prim::kPrimReceive)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    auto chunk = GetValue<int64_t>(cnode->GetPrimalAttr(CHUNK));
    auto micro = GetValue<int64_t>(cnode->GetPrimalAttr(MICRO));
    Border border = {cnode, chunk, micro};
    Border border_cell = {nullptr, chunk, micro};
    if (cnode->HasPrimalAttr(kPrimalAttrForwardNodeName)) {
      GetBackwardBorderNode(cnode);
      continue;
    }
    if (cnode->HasPrimalAttr(PIPELINE_BEGIN)) {
      auto fwd_cell = GetCellByReceive(cnode, manager_);
      MS_EXCEPTION_IF_NULL(fwd_cell);
      if (stage_ == stage_num_ - 1 && chunk == chunk_num_ - 1) {
        Border fwd_end = {fwd_cell, chunk, micro};
        fwd_end_.emplace_back(fwd_end);
      }
      border_cell.border = fwd_cell;
      fwd_cell_.emplace_back(border_cell);
      fwd_begin_.emplace_back(border);
      continue;
    }
    if (cnode->HasPrimalAttr(PIPELINE_END)) {
      auto fwd_cell = GetCellBySend(cnode);
      MS_EXCEPTION_IF_NULL(fwd_cell);
      if (stage_ == 0 && chunk == 0) {
        Border fwd_begin = {fwd_cell, chunk, micro};
        fwd_begin_.emplace_back(fwd_begin);
        border_cell.border = fwd_cell;
        fwd_cell_.emplace_back(border_cell);
      }
      fwd_end_.emplace_back(border);
      continue;
    }
    if (cnode->HasPrimalAttr(PIPELINE_PARAM)) {
      fwd_params_.emplace_back(border);
      continue;
    }
  }
}

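// Comparator for borders that share one (chunk, micro) pair: order them by the
// ORDER primal attribute.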
bool SortFuncInsideMicro(const Border &b_i, const Border &b_j) {
  auto node_i = b_i.border;
  auto node_j = b_j.border;
  auto order_i = node_i->GetPrimalAttr(ORDER);
  auto order_j = node_j->GetPrimalAttr(ORDER);
  MS_EXCEPTION_IF_NULL(order_i);
  MS_EXCEPTION_IF_NULL(order_j);
  return (GetValue<int64_t>(order_i) < GetValue<int64_t>(order_j));
}

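// Comparator across micro batches for the interleaved schedule: order by
// pipeline loop first (micro batches are grouped per stage_num, with the first
// loop widened by `offset`, i.e. micro_size % stage_num), then by chunk
// (descending for backward, ascending for forward), and finally by micro index.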
static bool SortFuncBetweenMicro(const BorderPair &b_i, const BorderPair &b_j, int64_t stage_num, bool is_backward,
                                 int64_t offset) {
  auto micro_i = b_i.first.micro;
  auto micro_j = b_j.first.micro;
  auto chunk_i = b_i.first.chunk;
  auto chunk_j = b_j.first.chunk;
  auto loop_i = (micro_i - offset) / stage_num;
  auto loop_j = (micro_j - offset) / stage_num;
  auto loop_i_offset = micro_i / (stage_num + offset);
  auto loop_j_offset = micro_j / (stage_num + offset);
  loop_i = loop_i_offset == 0 ? 0 : loop_i;
  loop_j = loop_j_offset == 0 ? 0 : loop_j;
  if (loop_i != loop_j) {
    return loop_i < loop_j;
  }
  if (chunk_i != chunk_j) {
    if (is_backward) {
      return chunk_i > chunk_j;
    }
    return chunk_i < chunk_j;
  }

  if (micro_i == micro_j) {
    MS_LOG(EXCEPTION) << "Something went wrong when sorting borders between micro batches: duplicated micro index "
                      << micro_i << " within one chunk.";
  }
  return micro_i < micro_j;
}

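// Force b_prior to execute before b_last: rewrite b_last's first input to
// Depend(original_input, b_prior) and tag the new node with the
// "pipeline_control" primal attribute.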
void PipelineScheduler::ControlOrder(const Border &b_prior, const Border &b_last) {
  auto node_prior = b_prior.border;
  auto node_last = b_last.border;
  if (node_prior == node_last) {
    return;
  }
  std::vector<AnfNodePtr> depend_input = {NewValueNode(prim::kPrimDepend), node_last->input(1), node_prior};
  auto depend_node = root_->NewCNode(depend_input);
  depend_node->set_abstract(node_last->input(1)->abstract());
  depend_node->AddPrimalAttr("pipeline_control", MakeValue(true));
  manager_->SetEdge(node_last, 1, depend_node);
}

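// Pick out the (first, last) border pair of every (chunk, micro) combination.
// Note that SpecifiedBorder also inserts control edges between multiple
// borders of the same micro, so this call mutates the graph as a side effect.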
std::vector<BorderPair> PipelineScheduler::SortInsideMicro(const std::vector<Border> &borders) {
  std::vector<BorderPair> out;
  for (int64_t chunk = 0; chunk < chunk_num_; ++chunk) {
    for (int64_t micro = 0; micro < micro_size_; ++micro) {
      auto border = SpecifiedBorder(borders, chunk, micro);
      out.emplace_back(border);
    }
  }
  return out;
}

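// Sort the per-micro border pairs into the interleaved execution order across
// micro batches and chunks (see SortFuncBetweenMicro above).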
std::vector<BorderPair> InterleavedScheduler::SortBetweenMicro(const std::vector<Border> &borders, bool is_backward) {
  auto sorted_borders = SortInsideMicro(borders);
  std::sort(sorted_borders.begin(), sorted_borders.end(), [this, is_backward](BorderPair a, BorderPair b) -> bool {
    return SortFuncBetweenMicro(a, b, this->stage_num_, is_backward, this->offset_);
  });
  return sorted_borders;
}

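// Return the first and last borders of the given (chunk, micro) pair. If the
// pair owns several borders, they are first serialized with control edges in
// ORDER sequence.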
std::pair<Border, Border> PipelineScheduler::SpecifiedBorder(const std::vector<Border> &borders, int64_t chunk,
                                                             int64_t micro) {
  std::vector<Border> candidates;
  std::copy_if(borders.begin(), borders.end(), std::back_inserter(candidates),
               [&chunk, &micro](const auto &b) { return (b.chunk == chunk && b.micro == micro); });
  if (candidates.empty()) {
    MS_LOG(EXCEPTION) << "Cannot find the pipeline border for chunk " << chunk << ", micro " << micro << ".";
  }
  if (candidates.size() > 1) {
    std::sort(candidates.begin(), candidates.end(), SortFuncInsideMicro);
    for (size_t index = 0; index < candidates.size() - 1; ++index) {
      auto prior = candidates[index];
      auto last = candidates[index + 1];
      ControlOrder(prior, last);
    }
  }
  return std::make_pair(candidates.front(), candidates.back());
}

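// Warm-up phase of the default schedule: the first bias_ forward micro batches
// run before any backward computation. Each forward end is chained to the next
// forward begin; the last stage and odd-indexed stages need extra edges since
// their send/receive borders interleave differently with the neighbouring
// chunk (a reading of the branches below, not a documented invariant).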
void InterleavedScheduler::WarmUpPhaseReorder() {
  auto sorted_fwd_begin = SortBetweenMicro(fwd_begin_, false);
  auto sorted_fwd_end = SortBetweenMicro(fwd_end_, false);
  auto sorted_fwd_cell = SortBetweenMicro(fwd_cell_, false);
  auto sorted_bwd_cell = SortBetweenMicro(bwd_cell_, true);
  // WarmUp phase
  for (size_t i = 0; i < LongToSize(bias_); ++i) {
    if (i == LongToSize(micro_size_ * chunk_num_ - 1)) {
      return;
    }
    // last stage
    if (stage_ == stage_num_ - 1) {
      if (offset_ > 0) {
        auto prior = sorted_fwd_cell[i].second;
        auto last = sorted_fwd_begin[i + 1].first;
        ControlOrder(prior, last);
      }
      if (is_even_stage_) {
        if (offset_ > 0) {
          if (i + LongToSize(offset_) >= LongToSize(bias_)) {
            auto prior1 = sorted_bwd_cell[i + LongToSize(offset_) - LongToSize(bias_)].second;
            auto last1 = sorted_fwd_end[i].first;
            ControlOrder(prior1, last1);
          } else {
            auto prior1 = sorted_fwd_cell[i + LongToSize(offset_)].second;
            auto last1 = sorted_fwd_end[i].first;
            ControlOrder(prior1, last1);
          }
        }
        auto prior2 = sorted_fwd_end[i].second;
        auto last2 = sorted_fwd_begin[i + LongToSize(offset_) + 1].first;
        ControlOrder(prior2, last2);
        continue;
      }
      auto prior1 = sorted_fwd_cell[i + LongToSize(offset_)].second;
      if (i + LongToSize(offset_) >= LongToSize(bias_)) {
        prior1 = sorted_bwd_cell[i + LongToSize(offset_) - LongToSize(bias_)].second;
      }
      auto last1 = sorted_fwd_begin[i + LongToSize(offset_) + 1].first;
      ControlOrder(prior1, last1);
      auto prior2 = sorted_fwd_begin[i + LongToSize(offset_) + 1].second;
      auto last2 = sorted_fwd_end[i].first;
      ControlOrder(prior2, last2);
      auto prior3 = sorted_fwd_end[i].second;
      auto last3 = sorted_fwd_cell[i + LongToSize(offset_) + 1].first;
      ControlOrder(prior3, last3);
      continue;
    }
    if (is_even_stage_) {
      auto prior = sorted_fwd_end[i].second;
      auto last = sorted_fwd_begin[i + 1].first;
      ControlOrder(prior, last);
      continue;
    }
    auto prior = sorted_fwd_cell[i].second;
    auto last = sorted_fwd_begin[i + 1].first;
    ControlOrder(prior, last);
    auto prior1 = sorted_fwd_begin[i + 1].second;
    auto last1 = sorted_fwd_end[i].first;
    ControlOrder(prior1, last1);
    auto prior2 = sorted_fwd_end[i].second;
    auto last2 = sorted_fwd_cell[i + 1].first;
    ControlOrder(prior2, last2);
  }
}

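// Bridge between the stable phase and the end phase: order the last forward
// micro batch against the backward stream, whose matching index in the sorted
// backward borders is chunk_num_ * micro_size_ - 1 - bias_.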
void InterleavedScheduler::LastForwardMicroReorder() {
  auto sorted_fwd_end = SortBetweenMicro(fwd_end_, false);
  auto sorted_fwd_cell = SortBetweenMicro(fwd_cell_, false);
  auto sorted_bwd_begin = SortBetweenMicro(bwd_begin_, true);
  auto sorted_bwd_cell = SortBetweenMicro(bwd_cell_, true);
  auto sorted_bwd_end = SortBetweenMicro(bwd_end_, true);
  auto index = chunk_num_ * micro_size_ - 1 - SizeToLong(bias_);
  if (index < 0) {
    auto prior = sorted_fwd_end.back().second;
    auto last = sorted_bwd_begin.front().first;
    ControlOrder(prior, last);
    return;
  }
  if (stage_ == stage_num_ - 1) {
    auto prior = sorted_fwd_end.back().second;
    auto last = sorted_bwd_begin[index].first;
    ControlOrder(prior, last);
    return;
  }
  auto prior = sorted_bwd_cell[index].second;
  auto last = sorted_fwd_end.back().first;
  ControlOrder(prior, last);
  auto prior1 = sorted_fwd_cell.back().second;
  auto last1 = sorted_bwd_begin[index].first;
  ControlOrder(prior1, last1);
  if (stage_ == 0 && sorted_bwd_end[index].second.chunk == 0) {
    auto prior2 = sorted_fwd_end.back().second;
    auto last2 = sorted_bwd_begin[index + 1].first;
    ControlOrder(prior2, last2);
    return;
  }
  if (stage_ == 0) {
    auto loop_index = sorted_bwd_end[index].second.micro / (stage_num_ + SizeToLong(offset_));
    if (loop_index == 0) {
      auto prior2 = sorted_fwd_end.back().second;
      auto last2 = sorted_bwd_begin[index + 1].first;
      ControlOrder(prior2, last2);
    } else {
      auto prior2 = sorted_fwd_end.back().second;
      auto last2 = sorted_bwd_end[index].first;
      ControlOrder(prior2, last2);
    }
    return;
  }
  if (is_even_stage_) {
    auto prior2 = sorted_fwd_end.back().second;
    auto last2 = sorted_bwd_end[index].first;
    ControlOrder(prior2, last2);
  } else {
    auto prior2 = sorted_fwd_end.back().second;
    auto last2 = sorted_bwd_begin[index + 1].first;
    ControlOrder(prior2, last2);
  }
}

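// End phase: after the last forward micro batch, only backward micro batches
// remain. Chain each backward end to the next backward begin, with extra edges
// on stage 0 (apparently to drain the widened first loop) and on odd-indexed
// stages.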
void InterleavedScheduler::EndPhaseReorder() {
  auto sorted_fwd_end = SortBetweenMicro(fwd_end_, false);
  auto sorted_bwd_begin = SortBetweenMicro(bwd_begin_, true);
  auto sorted_bwd_cell = SortBetweenMicro(bwd_cell_, true);
  auto sorted_bwd_end = SortBetweenMicro(bwd_end_, true);
  auto begin_index =
    LongToSize(chunk_num_ * micro_size_) > bias_ ? LongToSize(chunk_num_ * micro_size_ - bias_ - 1) : 0;
  for (size_t i = LongToSize(begin_index); i < LongToSize(chunk_num_ * micro_size_ - 1); ++i) {
    if (stage_ == 0) {
      auto loop_index = sorted_bwd_end[i].second.micro / (stage_num_ + SizeToLong(offset_));
      auto offset = LongToSize(offset_);
      if (loop_index != 0 || sorted_bwd_end[i].second.chunk == 0) {
        offset = 0;
      }
      if (offset > 0) {
        auto prior = sorted_bwd_cell[i].second;
        auto last = sorted_bwd_begin[i + 1].first;
        ControlOrder(prior, last);
        auto prior1 = sorted_bwd_cell[i + offset].second;
        auto last1 = sorted_bwd_end[i].first;
        ControlOrder(prior1, last1);
      }
      auto prior2 = sorted_bwd_end[i].second;
      auto last2 = sorted_bwd_begin[i + offset + 1].first;
      ControlOrder(prior2, last2);
      continue;
    }
    if (is_even_stage_ || (stage_ == stage_num_ - 1 && sorted_bwd_begin[i + 1].first.chunk == chunk_num_ - 1)) {
      auto prior = sorted_bwd_end[i].second;
      auto last = sorted_bwd_begin[i + 1].first;
      ControlOrder(prior, last);
      continue;
    }
    auto prior1 = sorted_bwd_cell[i].second;
    auto last1 = sorted_bwd_begin[i + 1].first;
    ControlOrder(prior1, last1);
    auto prior2 = sorted_bwd_begin[i + 1].second;
    auto last2 = sorted_bwd_end[i].first;
    ControlOrder(prior2, last2);
    auto prior3 = sorted_bwd_end[i].second;
    auto last3 = sorted_bwd_cell[i + 1].first;
    ControlOrder(prior3, last3);
  }
}

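// Build the abstract for a MakeTuple call whose inputs are `nodes` (index 0 is
// the MakeTuple primitive): with a single real input its own abstract is
// reused, otherwise an AbstractTuple over all real inputs is created.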
AbstractBasePtr InterleavedScheduler::GenerateTupleAbstract(const std::vector<AnfNodePtr> &nodes) {
  AbstractBasePtr abs;
  if (nodes.size() == 2) {
    auto cnode = nodes.back()->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    abs = cnode->abstract();
  } else {
    AbstractBasePtrList abstract_list;
    abstract_list.resize(nodes.size() - 1);
    (void)std::transform(nodes.begin() + 1, nodes.end(), abstract_list.begin(), [](const AnfNodePtr &node) {
      auto cnode = node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      return cnode->abstract();
    });
    abs = std::make_shared<abstract::AbstractTuple>(abstract_list);
  }
  return abs;
}

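// When the parallel optimizer (optimizer weight sharding) is enabled, collect
// the AllGather inputs of every cell of the non-first chunks into a MakeTuple
// and attach it, via a Depend, in front of the first forward begin border, so
// this communication is issued before the pipeline body starts.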
void InterleavedScheduler::OptimizerShardCommReorder() {
  auto enable_opt_shard = ParallelContext::GetInstance()->enable_parallel_optimizer();
  if (!enable_opt_shard) {
    return;
  }
  auto sorted_fwd_begin = SortBetweenMicro(fwd_begin_, false);
  auto sorted_fwd_cell = SortBetweenMicro(fwd_cell_, false);
  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
  for (int64_t chunk = 1; chunk < chunk_num_; ++chunk) {
    for (const auto &border : sorted_fwd_cell) {
      if (border.first.chunk == chunk) {
        auto cnode = border.first.border;
        for (const auto &input : cnode->inputs()) {
          if (!IsPrimitiveCNode(input, prim::kPrimAllGather)) {
            continue;
          }
          make_tuple_inputs.emplace_back(input);
        }
      }
    }
  }
  if (make_tuple_inputs.size() > 1) {
    auto make_tuple = root_->NewCNode(make_tuple_inputs);
    auto abs = GenerateTupleAbstract(make_tuple_inputs);
    make_tuple->set_abstract(abs);
    auto begin_node = sorted_fwd_begin.front().first.border;
    if (begin_node->inputs().size() < 2) {
      return;
    }
    std::vector<AnfNodePtr> depend_inputs = {NewValueNode(prim::kPrimDepend), begin_node->input(1), make_tuple};
    auto depend = root_->NewCNode(depend_inputs);
    depend->set_abstract(begin_node->input(1)->abstract());
    manager_->SetEdge(begin_node, 1, depend);
  }
}

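// Order the PIPELINE_PARAM borders against the pipeline body: the last forward
// parameter border must run before the first forward begin, and the first
// backward parameter border waits for the last backward end. bwd_params_ is
// assumed to be non-empty whenever fwd_params_ is.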
void InterleavedScheduler::ParameterReorder(const std::vector<BorderPair> &sorted_fwd_begin,
                                            const std::vector<BorderPair> &sorted_bwd_end) {
  if (!fwd_params_.empty()) {
    std::sort(fwd_params_.begin(), fwd_params_.end(), SortFuncInsideMicro);
    std::sort(bwd_params_.begin(), bwd_params_.end(), SortFuncInsideMicro);
    auto prior = fwd_params_.back();
    auto last = sorted_fwd_begin.front().first;
    ControlOrder(prior, last);
    auto prior2 = sorted_bwd_end.back().second;
    auto last2 = bwd_params_.front();
    ControlOrder(prior2, last2);
  }
}

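// Warm-up phase of the memory-optimized (kernel-by-kernel) schedule. The
// sorted_bwd_cell result is unused here, but SortBetweenMicro still runs for
// its intra-micro control-edge side effect.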
void InterleavedScheduler::MemoryOptimizedWarmUpPhaseReorder() {
  auto sorted_fwd_begin = SortBetweenMicro(fwd_begin_, false);
  auto sorted_fwd_end = SortBetweenMicro(fwd_end_, false);
  auto sorted_fwd_cell = SortBetweenMicro(fwd_cell_, false);
  auto sorted_bwd_cell = SortBetweenMicro(bwd_cell_, true);
  for (size_t i = 0; i < LongToSize(bias_); ++i) {
    if (stage_ != 0) {
      auto prior = sorted_fwd_end[i].second;
      auto last = sorted_fwd_begin[i + 1].first;
      ControlOrder(prior, last);
      continue;
    } else {
      size_t offset = 0;
      if (sorted_fwd_begin[i + 1].first.chunk != 0) {
        offset = offset_;
      }
      auto prior = sorted_fwd_end[i].second;
      auto last = sorted_fwd_cell[i + 1].first;
      ControlOrder(prior, last);
      auto prior1 = sorted_fwd_cell[i - LongToSize(offset)].second;
      auto last1 = sorted_fwd_begin[i + 1].first;
      ControlOrder(prior1, last1);
      auto prior2 = sorted_fwd_begin[i + 1].second;
      if (last.border == prior2.border) {
        continue;
      }
      auto last2 = sorted_fwd_end[i - LongToSize(offset)].first;
      ControlOrder(prior2, last2);
    }
  }
}

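// Stable (1F1B) phase of the memory-optimized schedule: from index bias_
// onward, forward micro i is paired with backward micro i - bias_. Stage 0
// additionally handles the loop boundaries where the offset_ extra micro
// batches of the first loop are drained (a reading of the branches below).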
void InterleavedScheduler::MemoryOptimizedStablePhaseReorder() {
  auto sorted_fwd_begin = SortBetweenMicro(fwd_begin_, false);
  auto sorted_fwd_end = SortBetweenMicro(fwd_end_, false);
  auto sorted_bwd_begin = SortBetweenMicro(bwd_begin_, true);
  auto sorted_bwd_end = SortBetweenMicro(bwd_end_, true);
  auto sorted_fwd_cell = SortBetweenMicro(fwd_cell_, false);
  auto sorted_bwd_cell = SortBetweenMicro(bwd_cell_, true);
  for (size_t i = bias_; i < LongToSize(micro_size_ * chunk_num_); ++i) {
    if (i == LongToSize(micro_size_ * chunk_num_ - 1)) {
      if (stage_ != 0) {
        auto prior = sorted_fwd_end[i].second;
        auto last = sorted_bwd_begin[i - bias_].first;
        ControlOrder(prior, last);
      } else {
        auto prior = sorted_fwd_cell[i].second;
        auto last = sorted_bwd_begin[i - bias_].first;
        ControlOrder(prior, last);
        auto prior1 = sorted_bwd_begin[i - bias_].second;
        auto last1 = sorted_fwd_end[i].first;
        ControlOrder(prior1, last1);
        auto prior2 = sorted_fwd_end[i].second;
        auto last2 = sorted_bwd_cell[i - bias_].first;
        ControlOrder(prior2, last2);
      }
      continue;
    }
    if (stage_ != 0) {
      auto prior = sorted_bwd_end[i - bias_].second;
      auto last = sorted_fwd_begin[i + 1].first;
      ControlOrder(prior, last);
    } else {
      auto offset = offset_;
      auto loop_index_bwd = sorted_bwd_end[i - bias_].second.micro / (stage_num_ + SizeToLong(offset_));
      if (loop_index_bwd != 0) {
        offset = 0;
      }
      auto loop_index_fwd = sorted_fwd_end[i + 1].second.micro / (stage_num_ + SizeToLong(offset_));
      if (loop_index_fwd == 0) {
        auto prior1 = sorted_fwd_end[i - offset_].second;
        auto last1 = sorted_fwd_cell[i + 1 - offset_].first;
        ControlOrder(prior1, last1);
        auto prior2 = sorted_fwd_cell[i - offset_].second;
        auto last2 = sorted_fwd_begin[i + 1].first;
        ControlOrder(prior2, last2);
        auto prior3 = sorted_fwd_begin[i + 1].second;
        auto last3 = sorted_fwd_end[i - offset_].first;
        ControlOrder(prior3, last3);
      }
      if (sorted_bwd_end[i - bias_].second.chunk != 0) {
        auto prior1 = sorted_bwd_cell[i - bias_].second;
        auto last1 = sorted_fwd_cell[i + 1].first;
        ControlOrder(prior1, last1);
        if (i + 1 + offset > LongToSize(micro_size_ * chunk_num_ - 1)) {
          auto prior2 = sorted_bwd_begin[i - bias_ + 1 + offset].second;
          auto last2 = sorted_bwd_end[i - bias_].first;
          ControlOrder(prior2, last2);
        } else {
          auto prior2 = sorted_fwd_end[i + 1 + offset].second;
          auto last2 = sorted_bwd_end[i - bias_].first;
          ControlOrder(prior2, last2);
        }
        if ((i + 1 + offset <= LongToSize(micro_size_ * chunk_num_ - 1)) &&
            sorted_fwd_end[i + 1 + offset].second.chunk != chunk_num_ - 1) {
          auto prior3 = sorted_bwd_end[i - bias_].second;
          auto last3 = sorted_fwd_begin[i + 1 + LongToSize(stage_num_) + offset].first;
          ControlOrder(prior3, last3);
          auto prior4 = sorted_fwd_begin[i + 1 + LongToSize(stage_num_) + offset].second;
          auto last4 = sorted_bwd_cell[i - bias_ + 1 + offset].first;
          ControlOrder(prior4, last4);
        } else {
          auto prior3 = sorted_bwd_end[i - bias_].second;
          auto last3 = sorted_bwd_cell[i - bias_ + 1 + offset].first;
          ControlOrder(prior3, last3);
        }
      } else {
        auto prior = sorted_bwd_end[i - bias_].second;
        auto last = sorted_fwd_cell[i + 1].first;
        ControlOrder(prior, last);
      }
    }
    if (stage_ != stage_num_ - 1 || sorted_fwd_end[i].second.chunk != chunk_num_ - 1) {
      auto prior = sorted_fwd_cell[i].second;
      auto last = sorted_bwd_begin[i - bias_].first;
      ControlOrder(prior, last);
      auto prior1 = sorted_bwd_begin[i - bias_].second;
      auto last1 = sorted_fwd_end[i].first;
      ControlOrder(prior1, last1);
      auto prior2 = sorted_fwd_end[i].second;
      auto last2 = sorted_bwd_cell[i - bias_].first;
      ControlOrder(prior2, last2);
    } else {
      auto prior = sorted_fwd_end[i].second;
      auto last = sorted_bwd_begin[i - bias_].first;
      ControlOrder(prior, last);
    }
  }
}

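// Memory-optimized reorder used under kernel-by-kernel execution. It uses a
// smaller warm-up bias_ than the default schedule, which is intended to lower
// peak activation memory; the trailing loop forms the end phase that chains
// the remaining backward micro batches.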
void InterleavedScheduler::MemoryOptimizedReorder() {
  offset_ = LongToSize(micro_size_ % stage_num_);
  bias_ = LongToSize((stage_num_ + offset_) * (chunk_num_ - 1) + stage_num_ - stage_ - 1);
  auto sorted_bwd_begin = SortBetweenMicro(bwd_begin_, true);
  auto sorted_bwd_end = SortBetweenMicro(bwd_end_, true);
  auto sorted_bwd_cell = SortBetweenMicro(bwd_cell_, true);
  if (micro_size_ < stage_num_) {
    MS_LOG(EXCEPTION) << "For the 1F1B scheduler, the MicroBatch num must be greater than or equal to the StageNum, "
                      << "but got MicroBatch: " << micro_size_ << ", StageNum: " << stage_num_;
  }
  // WarmUp phase
  MemoryOptimizedWarmUpPhaseReorder();

  // Stable phase
  MemoryOptimizedStablePhaseReorder();

  // End phase
  for (size_t i = LongToSize(micro_size_ * chunk_num_ - bias_ - 1); i < LongToSize(micro_size_ * chunk_num_ - 1); ++i) {
    if (stage_ != stage_num_ - 1 || sorted_bwd_begin[i + 1].first.chunk == chunk_num_ - 1) {
      auto prior = sorted_bwd_end[i].second;
      auto last = sorted_bwd_begin[i + 1].first;
      ControlOrder(prior, last);
    } else {
      auto prior = sorted_bwd_cell[i].second;
      auto last = sorted_bwd_begin[i + 1].first;
      ControlOrder(prior, last);
      auto prior1 = sorted_bwd_begin[i + 1].second;
      auto last1 = sorted_bwd_end[i].first;
      ControlOrder(prior1, last1);
      auto prior2 = sorted_bwd_end[i].second;
      auto last2 = sorted_bwd_cell[i + 1].first;
      ControlOrder(prior2, last2);
    }
  }
}

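// The memory-optimized schedule is only enabled under kernel-by-kernel
// execution (jit_level O0 or O1) together with ENABLE_LESS_MEM_VPP=1.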
static bool EnableKbk() {
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  auto jit_level = context->get_param<std::string>(MS_CTX_JIT_LEVEL);
  MS_LOG(WARNING) << "ENABLE_LESS_MEM_VPP status: " << common::GetEnv("ENABLE_LESS_MEM_VPP");
  return (jit_level == "O0" || jit_level == "O1") && common::GetEnv("ENABLE_LESS_MEM_VPP") == "1";
}

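// Stable (1F1B) phase of the default interleaved schedule: forward micro i is
// paired with backward micro i - bias_, with different control-edge patterns
// for even- and odd-indexed stages (is_even_stage_) and for the last chunk of
// the last stage.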
void InterleavedScheduler::StablePhaseReorder() {
  auto sorted_fwd_begin = SortBetweenMicro(fwd_begin_, false);
  auto sorted_fwd_end = SortBetweenMicro(fwd_end_, false);
  auto sorted_bwd_begin = SortBetweenMicro(bwd_begin_, true);
  auto sorted_bwd_end = SortBetweenMicro(bwd_end_, true);
  auto sorted_fwd_cell = SortBetweenMicro(fwd_cell_, false);
  auto sorted_bwd_cell = SortBetweenMicro(bwd_cell_, true);
  for (size_t i = LongToSize(bias_); i < LongToSize(micro_size_ * chunk_num_ - 1); ++i) {
    if (stage_ == stage_num_ - 1 && sorted_fwd_end[i].first.chunk == chunk_num_ - 1) {
      auto prior = sorted_fwd_end[i].second;
      auto last = sorted_bwd_begin[i - LongToSize(bias_)].first;
      ControlOrder(prior, last);
      auto prior1 = sorted_bwd_cell[i - LongToSize(bias_)].second;
      auto last1 = sorted_fwd_begin[i + 1].first;
      ControlOrder(prior1, last1);
    } else {
      auto prior = sorted_fwd_cell[i].second;
      auto last = sorted_bwd_begin[i - LongToSize(bias_)].first;
      ControlOrder(prior, last);
    }
    if (is_even_stage_) {
      if (stage_ != stage_num_ - 1 || sorted_fwd_end[i].first.chunk != chunk_num_ - 1) {
        auto prior = sorted_bwd_cell[i - LongToSize(bias_)].second;
        auto last = sorted_fwd_end[i].first;
        ControlOrder(prior, last);
        auto prior1 = sorted_fwd_end[i].second;
        auto last1 = sorted_fwd_begin[i + 1].first;
        ControlOrder(prior1, last1);
      }
      if (stage_ != 0 || sorted_bwd_end[i - LongToSize(bias_)].first.chunk != 0) {
        auto loop_index = sorted_bwd_end[i - LongToSize(bias_)].first.micro / (stage_num_ + SizeToLong(offset_));
        auto offset = LongToSize(offset_);
        if (loop_index != 0 || stage_ != 0) {
          offset = 0;
        }
        if (i + offset + 1 > LongToSize(micro_size_ * chunk_num_ - 1)) {
          auto prior = sorted_bwd_cell[i + offset - LongToSize(bias_)].second;
          auto last = sorted_bwd_end[i - LongToSize(bias_)].first;
          ControlOrder(prior, last);
        } else {
          auto prior = sorted_fwd_cell[i + offset + 1].second;
          auto last = sorted_bwd_end[i - LongToSize(bias_)].first;
          ControlOrder(prior, last);
        }
        auto prior1 = sorted_bwd_end[i - LongToSize(bias_)].second;
        auto last1 = sorted_bwd_begin[i + offset + 1 - LongToSize(bias_)].first;
        ControlOrder(prior1, last1);
      }
      continue;
    }
    if (stage_ != stage_num_ - 1 || sorted_fwd_end[i].first.chunk != chunk_num_ - 1) {
      auto prior = sorted_bwd_cell[i - LongToSize(bias_)].second;
      auto last = sorted_fwd_begin[i + 1].first;
      ControlOrder(prior, last);
      auto prior1 = sorted_fwd_begin[i + 1].second;
      auto last1 = sorted_fwd_end[i].first;
      ControlOrder(prior1, last1);
      auto prior2 = sorted_fwd_end[i].second;
      auto last2 = sorted_fwd_cell[i + 1].first;
      ControlOrder(prior2, last2);
    }
    if (stage_ != stage_num_ - 1 || sorted_bwd_begin[i - LongToSize(bias_) + 1].second.chunk != chunk_num_ - 1) {
      auto prior = sorted_bwd_begin[i - LongToSize(bias_) + 1].second;
      auto last = sorted_bwd_end[i - LongToSize(bias_)].first;
      ControlOrder(prior, last);
      auto prior1 = sorted_bwd_end[i - LongToSize(bias_)].second;
      auto last1 = sorted_bwd_cell[i - LongToSize(bias_) + 1].first;
      ControlOrder(prior1, last1);
      continue;
    }
    auto prior = sorted_fwd_cell[i + 1].second;
    auto last = sorted_bwd_end[i - LongToSize(bias_)].first;
    ControlOrder(prior, last);
    auto prior1 = sorted_bwd_end[i - LongToSize(bias_)].second;
    auto last1 = sorted_bwd_cell[i - LongToSize(bias_) + 1].first;
    ControlOrder(prior1, last1);
  }
}

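// Entry point of the interleaved scheduler: choose the memory-optimized
// schedule when EnableKbk() holds, otherwise apply the default interleaved
// 1F1B schedule phase by phase.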
void InterleavedScheduler::Reorder() {
  auto enable_kbk = EnableKbk();
  auto sorted_fwd_begin = SortBetweenMicro(fwd_begin_, false);
  auto sorted_bwd_end = SortBetweenMicro(bwd_end_, true);
  if (enable_kbk) {
    MemoryOptimizedReorder();
    ParameterReorder(sorted_fwd_begin, sorted_bwd_end);
    OptimizerShardCommReorder();
    return;
  }
  offset_ = LongToSize(micro_size_ % stage_num_);
  bias_ = LongToSize((stage_num_ + SizeToLong(offset_)) * (chunk_num_ - 1) + (stage_num_ - stage_ - 1) * INT64_TWO);
  is_even_stage_ = stage_ % INT64_TWO == 0;
  if (micro_size_ < stage_num_) {
    MS_LOG(EXCEPTION) << "For the 1F1B scheduler, the MicroBatch num must be greater than or equal to the StageNum, "
                      << "but got MicroBatch: " << micro_size_ << ", StageNum: " << stage_num_;
  }
  // WarmUp phase
  WarmUpPhaseReorder();

  // Stable phase
  StablePhaseReorder();
  LastForwardMicroReorder();

  // End phase
  EndPhaseReorder();

  // Parameters phase
  ParameterReorder(sorted_fwd_begin, sorted_bwd_end);
  OptimizerShardCommReorder();
}
}  // namespace parallel
}  // namespace mindspore