/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <vector>
#include <string>
#include <utility>
#include <algorithm>
#include <memory>
#include "frontend/parallel/pipeline_transformer/pipeline_scheduler.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "frontend/parallel/node_check.h"
#include "mindspore/core/ops/array_ops.h"
#include "mindspore/core/ops/other_ops.h"
#include "ir/anf.h"
#include "ir/graph_utils.h"

namespace mindspore {
namespace parallel {
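// Follow a Receive border node through any Depend users to the cell (the FuncGraph-call
// CNode) that consumes it, and propagate the ORDER primal attr onto that cell so the
// cell can later be sorted alongside the communication borders.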
CNodePtr GetCellByReceive(const AnfNodePtr &node, const FuncGraphManagerPtr &manager) {
  // receive->fg
  if (!IsPrimitiveCNode(node, prim::kPrimReceive)) {
    return nullptr;
  }
  auto users = manager->node_users()[node];
  auto user = users.front().first;
  while (IsPrimitiveCNode(user, prim::kPrimDepend)) {
    users = manager->node_users()[user];
    user = users.front().first;
  }
  auto fg_cnode = user->cast<CNodePtr>();
  auto cnode = node->cast<CNodePtr>();
  if (cnode->HasPrimalAttr(ORDER)) {
    auto order = cnode->GetPrimalAttr(ORDER);
    fg_cnode->AddPrimalAttr(ORDER, order);
  }
  return fg_cnode;
}

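// Walk back from a Send border node through TupleGetItem/Depend inputs to the cell
// that produces the sent value, propagating the ORDER primal attr in the same way.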
CNodePtr GetCellBySend(const AnfNodePtr &node) {
  // send->tuple_getitem->fg->slice
  if (!IsPrimitiveCNode(node, prim::kPrimSend)) {
    return nullptr;
  }
  auto cnode = node->cast<CNodePtr>();
  auto fg_node = cnode->input(1);
  while (IsPrimitiveCNode(fg_node, prim::kPrimTupleGetItem) || IsPrimitiveCNode(fg_node, prim::kPrimDepend)) {
    fg_node = fg_node->cast<CNodePtr>()->input(1);
  }
  auto fg_cnode = fg_node->cast<CNodePtr>();
  if (cnode->HasPrimalAttr(ORDER)) {
    auto order = cnode->GetPrimalAttr(ORDER);
    fg_cnode->AddPrimalAttr(ORDER, order);
  }
  return fg_cnode;
}

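// Classify a backward-pass communication node (one carrying kPrimalAttrForwardNodeName).
// In the backward graph the roles are mirrored: the Send dual of a forward PIPELINE_BEGIN
// Receive ends a backward micro (bwd_end_), while the Receive dual of a forward
// PIPELINE_END Send begins one (bwd_begin_). Boundary stages additionally record the
// adjacent cell as a border.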
void InterleavedScheduler::GetBackwardBorderNode(const CNodePtr &cnode) {
  auto chunk = GetValue<int64_t>(cnode->GetPrimalAttr(CHUNK));
  auto micro = GetValue<int64_t>(cnode->GetPrimalAttr(MICRO));
  Border border = {cnode, chunk, micro};
  Border border_cell = {nullptr, chunk, micro};
  if (cnode->HasPrimalAttr(kPrimalAttrForwardNodeName)) {
    if (cnode->HasPrimalAttr(PIPELINE_BEGIN)) {
      auto bwd_cell = GetCellBySend(cnode);
      MS_EXCEPTION_IF_NULL(bwd_cell);
      if (stage_ == stage_num_ - 1 && chunk == chunk_num_ - 1) {
        Border bwd_begin = {bwd_cell, chunk, micro};
        bwd_begin_.emplace_back(bwd_begin);
        border_cell.border = bwd_cell;
        bwd_cell_.emplace_back(border_cell);
      }
      bwd_end_.emplace_back(border);
    }
    if (cnode->HasPrimalAttr(PIPELINE_END)) {
      auto bwd_cell = GetCellByReceive(cnode, manager_);
      MS_EXCEPTION_IF_NULL(bwd_cell);
      if (stage_ == 0 && chunk == 0) {
        Border bwd_end = {bwd_cell, chunk, micro};
        bwd_end_.emplace_back(bwd_end);
      }
      border_cell.border = bwd_cell;
      bwd_cell_.emplace_back(border_cell);
      bwd_begin_.emplace_back(border);
    }
    if (cnode->HasPrimalAttr(PIPELINE_PARAM)) {
      bwd_params_.emplace_back(border);
    }
  }
}

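// Scan the whole graph once to derive chunk_num_ and micro_size_ from the CHUNK/MICRO
// primal attrs, then classify every Send/Receive into the forward/backward
// begin/end/cell/param border lists used by the reorder passes.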
void InterleavedScheduler::GetBorderNode() {
  auto all_nodes = DeepScopedGraphSearch(root_->get_return());
  for (auto &node : all_nodes) {
    if (!IsPrimitiveCNode(node, prim::kPrimSend) && !IsPrimitiveCNode(node, prim::kPrimReceive)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    auto chunk = GetValue<int64_t>(cnode->GetPrimalAttr(CHUNK));
    chunk_num_ = (chunk + 1) > chunk_num_ ? (chunk + 1) : chunk_num_;
    auto micro = GetValue<int64_t>(cnode->GetPrimalAttr(MICRO));
    micro_size_ = (micro + 1) > micro_size_ ? (micro + 1) : micro_size_;
  }
  for (auto &node : all_nodes) {
    if (!IsPrimitiveCNode(node, prim::kPrimSend) && !IsPrimitiveCNode(node, prim::kPrimReceive)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    auto chunk = GetValue<int64_t>(cnode->GetPrimalAttr(CHUNK));
    auto micro = GetValue<int64_t>(cnode->GetPrimalAttr(MICRO));
    Border border = {cnode, chunk, micro};
    Border border_cell = {nullptr, chunk, micro};
    if (cnode->HasPrimalAttr(kPrimalAttrForwardNodeName)) {
      GetBackwardBorderNode(cnode);
      continue;
    }
    if (cnode->HasPrimalAttr(PIPELINE_BEGIN)) {
      auto fwd_cell = GetCellByReceive(cnode, manager_);
      MS_EXCEPTION_IF_NULL(fwd_cell);
      if (stage_ == stage_num_ - 1 && chunk == chunk_num_ - 1) {
        Border fwd_end = {fwd_cell, chunk, micro};
        fwd_end_.emplace_back(fwd_end);
      }
      border_cell.border = fwd_cell;
      fwd_cell_.emplace_back(border_cell);
      fwd_begin_.emplace_back(border);
      continue;
    }
    if (cnode->HasPrimalAttr(PIPELINE_END)) {
      auto fwd_cell = GetCellBySend(cnode);
      MS_EXCEPTION_IF_NULL(fwd_cell);
      if (stage_ == 0 && chunk == 0) {
        Border fwd_begin = {fwd_cell, chunk, micro};
        fwd_begin_.emplace_back(fwd_begin);
        border_cell.border = fwd_cell;
        fwd_cell_.emplace_back(border_cell);
      }
      fwd_end_.emplace_back(border);
      continue;
    }
    if (cnode->HasPrimalAttr(PIPELINE_PARAM)) {
      fwd_params_.emplace_back(border);
      continue;
    }
  }
}

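// Comparator for borders of the same (chunk, micro): order by the ORDER primal attr.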
bool SortFuncInsideMicro(const Border &b_i, const Border &b_j) {
  auto node_i = b_i.border;
  auto node_j = b_j.border;
  auto order_i = node_i->GetPrimalAttr(ORDER);
  auto order_j = node_j->GetPrimalAttr(ORDER);
  MS_EXCEPTION_IF_NULL(order_i);
  MS_EXCEPTION_IF_NULL(order_j);
  return (GetValue<int64_t>(order_i) < GetValue<int64_t>(order_j));
}

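// Comparator across micro batches for the interleaved schedule: sort by interleaving
// loop (groups of stage_num micros, with the first loop widened by `offset`), then by
// chunk (reversed for the backward pass), then by micro id.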
static bool SortFuncBetweenMicro(const BorderPair &b_i, const BorderPair &b_j, int64_t stage_num, bool is_backward,
                                 int64_t offset) {
  auto micro_i = b_i.first.micro;
  auto micro_j = b_j.first.micro;
  auto chunk_i = b_i.first.chunk;
  auto chunk_j = b_j.first.chunk;
  auto loop_i = (micro_i - offset) / stage_num;
  auto loop_j = (micro_j - offset) / stage_num;
  auto loop_i_offset = micro_i / (stage_num + offset);
  auto loop_j_offset = micro_j / (stage_num + offset);
  loop_i = loop_i_offset == 0 ? 0 : loop_i;
  loop_j = loop_j_offset == 0 ? 0 : loop_j;
  if (loop_i != loop_j) {
    return loop_i < loop_j;
  }
  if (chunk_i != chunk_j) {
    if (is_backward) {
      return chunk_i > chunk_j;
    }
    return chunk_i < chunk_j;
  }

  if (micro_i == micro_j) {
    MS_LOG(EXCEPTION) << "Something went wrong when sorting borders between micro batches: duplicate micro id.";
  }
  return micro_i < micro_j;
}

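// Serialize two borders by rewiring node_last's first input through a Depend on
// node_prior, e.g. (a minimal sketch of the rewrite):
//   before: node_last(x, ...)
//   after:  node_last(Depend(x, node_prior), ...)
// The inserted Depend is tagged with the "pipeline_control" primal attr.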
void PipelineScheduler::ControlOrder(const Border &b_prior, const Border &b_last) {
  auto node_prior = b_prior.border;
  auto node_last = b_last.border;
  if (node_prior == node_last) {
    return;
  }
  std::vector<AnfNodePtr> depend_input = {NewValueNode(prim::kPrimDepend), node_last->input(1), node_prior};
  auto depend_node = root_->NewCNode(depend_input);
  depend_node->set_abstract(node_last->input(1)->abstract());
  depend_node->AddPrimalAttr("pipeline_control", MakeValue(true));
  manager_->SetEdge(node_last, 1, depend_node);
}

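// For every (chunk, micro) pair, resolve its (first, last) border pair, inserting
// control edges between duplicate borders along the way (see SpecifiedBorder).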
std::vector<BorderPair> PipelineScheduler::SortInsideMicro(const std::vector<Border> &borders) {
  std::vector<BorderPair> out;
  for (int64_t chunk = 0; chunk < chunk_num_; ++chunk) {
    for (int64_t micro = 0; micro < micro_size_; ++micro) {
      auto border = SpecifiedBorder(borders, chunk, micro);
      out.emplace_back(border);
    }
  }
  return out;
}

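// Sort the per-micro border pairs into the global execution order of this schedule.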
std::vector<BorderPair> InterleavedScheduler::SortBetweenMicro(const std::vector<Border> &borders, bool is_backward) {
  auto sorted_borders = SortInsideMicro(borders);
  std::sort(sorted_borders.begin(), sorted_borders.end(), [this, is_backward](BorderPair a, BorderPair b) -> bool {
    return SortFuncBetweenMicro(a, b, this->stage_num_, is_backward, this->offset_);
  });
  return sorted_borders;
}

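// Pick the borders matching (chunk, micro). If several match, sort them by ORDER and
// chain them with control edges, then return the (first, last) pair.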
std::pair<Border, Border> PipelineScheduler::SpecifiedBorder(const std::vector<Border> &borders, int64_t chunk,
                                                             int64_t micro) {
  std::vector<Border> candidates;
  std::copy_if(borders.begin(), borders.end(), std::back_inserter(candidates),
               [&chunk, &micro](const auto &b) { return (b.chunk == chunk && b.micro == micro); });
  if (candidates.empty()) {
    MS_LOG(EXCEPTION) << "Cannot find the pipeline border for chunk " << chunk << ", micro " << micro << ".";
  }
  if (candidates.size() > 1) {
    std::sort(candidates.begin(), candidates.end(), SortFuncInsideMicro);
    for (size_t index = 0; index < candidates.size() - 1; ++index) {
      auto prior = candidates[index];
      auto last = candidates[index + 1];
      ControlOrder(prior, last);
    }
  }
  return std::make_pair(candidates.front(), candidates.back());
}

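// Warm-up phase of the interleaved 1F1B schedule: chain the first bias_ forward micros
// into execution order. The branches differ for the last stage and for odd stages,
// where the forward begin/end borders interleave with the surrounding cells under the
// offset_ lookahead.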
void InterleavedScheduler::WarmUpPhaseReorder() {
  auto sorted_fwd_begin = SortBetweenMicro(fwd_begin_, false);
  auto sorted_fwd_end = SortBetweenMicro(fwd_end_, false);
  auto sorted_fwd_cell = SortBetweenMicro(fwd_cell_, false);
  auto sorted_bwd_cell = SortBetweenMicro(bwd_cell_, true);
  // WarmUp phase
  for (size_t i = 0; i < LongToSize(bias_); ++i) {
    if (i == LongToSize(micro_size_ * chunk_num_ - 1)) {
      return;
    }
    // last stage
    if (stage_ == stage_num_ - 1) {
      if (offset_ > 0) {
        auto prior = sorted_fwd_cell[i].second;
        auto last = sorted_fwd_begin[i + 1].first;
        ControlOrder(prior, last);
      }
      if (is_even_stage_) {
        if (offset_ > 0) {
          if (i + LongToSize(offset_) >= LongToSize(bias_)) {
            auto prior1 = sorted_bwd_cell[i + LongToSize(offset_) - LongToSize(bias_)].second;
            auto last1 = sorted_fwd_end[i].first;
            ControlOrder(prior1, last1);
          } else {
            auto prior1 = sorted_fwd_cell[i + LongToSize(offset_)].second;
            auto last1 = sorted_fwd_end[i].first;
            ControlOrder(prior1, last1);
          }
        }
        auto prior2 = sorted_fwd_end[i].second;
        auto last2 = sorted_fwd_begin[i + LongToSize(offset_) + 1].first;
        ControlOrder(prior2, last2);
        continue;
      }
      auto prior1 = sorted_fwd_cell[i + LongToSize(offset_)].second;
      if (i + LongToSize(offset_) >= LongToSize(bias_)) {
        prior1 = sorted_bwd_cell[i + LongToSize(offset_) - LongToSize(bias_)].second;
      }
      auto last1 = sorted_fwd_begin[i + LongToSize(offset_) + 1].first;
      ControlOrder(prior1, last1);
      auto prior2 = sorted_fwd_begin[i + LongToSize(offset_) + 1].second;
      auto last2 = sorted_fwd_end[i].first;
      ControlOrder(prior2, last2);
      auto prior3 = sorted_fwd_end[i].second;
      auto last3 = sorted_fwd_cell[i + LongToSize(offset_) + 1].first;
      ControlOrder(prior3, last3);
      continue;
    }
    if (is_even_stage_) {
      auto prior = sorted_fwd_end[i].second;
      auto last = sorted_fwd_begin[i + 1].first;
      ControlOrder(prior, last);
      continue;
    }
    auto prior = sorted_fwd_cell[i].second;
    auto last = sorted_fwd_begin[i + 1].first;
    ControlOrder(prior, last);
    auto prior1 = sorted_fwd_begin[i + 1].second;
    auto last1 = sorted_fwd_end[i].first;
    ControlOrder(prior1, last1);
    auto prior2 = sorted_fwd_end[i].second;
    auto last2 = sorted_fwd_cell[i + 1].first;
    ControlOrder(prior2, last2);
  }
}

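// Stitch the last forward micro into the backward stream: after the final forward end,
// release the matching backward begin. `index` selects (roughly) the backward micro
// that pairs with the last forward under the warm-up bias.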
void InterleavedScheduler::LastForwardMicroReorder() {
  auto sorted_fwd_end = SortBetweenMicro(fwd_end_, false);
  auto sorted_fwd_cell = SortBetweenMicro(fwd_cell_, false);
  auto sorted_bwd_begin = SortBetweenMicro(bwd_begin_, true);
  auto sorted_bwd_cell = SortBetweenMicro(bwd_cell_, true);
  auto sorted_bwd_end = SortBetweenMicro(bwd_end_, true);
  auto index = chunk_num_ * micro_size_ - 1 - SizeToLong(bias_);
  if (index < 0) {
    auto prior = sorted_fwd_end.back().second;
    auto last = sorted_bwd_begin.front().first;
    ControlOrder(prior, last);
    return;
  }
  if (stage_ == stage_num_ - 1) {
    auto prior = sorted_fwd_end.back().second;
    auto last = sorted_bwd_begin[index].first;
    ControlOrder(prior, last);
    return;
  }
  auto prior = sorted_bwd_cell[index].second;
  auto last = sorted_fwd_end.back().first;
  ControlOrder(prior, last);
  auto prior1 = sorted_fwd_cell.back().second;
  auto last1 = sorted_bwd_begin[index].first;
  ControlOrder(prior1, last1);
  if (stage_ == 0 && sorted_bwd_end[index].second.chunk == 0) {
    auto prior2 = sorted_fwd_end.back().second;
    auto last2 = sorted_bwd_begin[index + 1].first;
    ControlOrder(prior2, last2);
    return;
  }
  if (stage_ == 0) {
    auto loop_index = sorted_bwd_end[index].second.micro / (stage_num_ + SizeToLong(offset_));
    if (loop_index == 0) {
      auto prior2 = sorted_fwd_end.back().second;
      auto last2 = sorted_bwd_begin[index + 1].first;
      ControlOrder(prior2, last2);
    } else {
      auto prior2 = sorted_fwd_end.back().second;
      auto last2 = sorted_bwd_end[index].first;
      ControlOrder(prior2, last2);
    }
    return;
  }
  if (is_even_stage_) {
    auto prior2 = sorted_fwd_end.back().second;
    auto last2 = sorted_bwd_end[index].first;
    ControlOrder(prior2, last2);
  } else {
    auto prior2 = sorted_fwd_end.back().second;
    auto last2 = sorted_bwd_begin[index + 1].first;
    ControlOrder(prior2, last2);
  }
}

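// Cool-down phase: once the forwards are exhausted, chain the remaining backward
// micros, with stage 0 applying the offset_ lookahead for chunks that still interleave.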
void InterleavedScheduler::EndPhaseReorder() {
  auto sorted_fwd_end = SortBetweenMicro(fwd_end_, false);
  auto sorted_bwd_begin = SortBetweenMicro(bwd_begin_, true);
  auto sorted_bwd_cell = SortBetweenMicro(bwd_cell_, true);
  auto sorted_bwd_end = SortBetweenMicro(bwd_end_, true);
  auto begin_index =
    LongToSize(chunk_num_ * micro_size_) > bias_ ? LongToSize(chunk_num_ * micro_size_ - bias_ - 1) : 0;
  for (size_t i = LongToSize(begin_index); i < LongToSize(chunk_num_ * micro_size_ - 1); ++i) {
    if (stage_ == 0) {
      auto loop_index = sorted_bwd_end[i].second.micro / (stage_num_ + SizeToLong(offset_));
      auto offset = LongToSize(offset_);
      if (loop_index != 0 || sorted_bwd_end[i].second.chunk == 0) {
        offset = 0;
      }
      if (offset > 0) {
        auto prior = sorted_bwd_cell[i].second;
        auto last = sorted_bwd_begin[i + 1].first;
        ControlOrder(prior, last);
        auto prior1 = sorted_bwd_cell[i + offset].second;
        auto last1 = sorted_bwd_end[i].first;
        ControlOrder(prior1, last1);
      }
      auto prior2 = sorted_bwd_end[i].second;
      auto last2 = sorted_bwd_begin[i + offset + 1].first;
      ControlOrder(prior2, last2);
      continue;
    }
    if (is_even_stage_ || (stage_ == stage_num_ - 1 && sorted_bwd_begin[i + 1].first.chunk == chunk_num_ - 1)) {
      auto prior = sorted_bwd_end[i].second;
      auto last = sorted_bwd_begin[i + 1].first;
      ControlOrder(prior, last);
      continue;
    }
    auto prior1 = sorted_bwd_cell[i].second;
    auto last1 = sorted_bwd_begin[i + 1].first;
    ControlOrder(prior1, last1);
    auto prior2 = sorted_bwd_begin[i + 1].second;
    auto last2 = sorted_bwd_end[i].first;
    ControlOrder(prior2, last2);
    auto prior3 = sorted_bwd_end[i].second;
    auto last3 = sorted_bwd_cell[i + 1].first;
    ControlOrder(prior3, last3);
  }
}

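// Build the abstract for a MakeTuple input list: a single element keeps its own
// abstract, otherwise an AbstractTuple is assembled from all elements after the
// primitive.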
AbstractBasePtr InterleavedScheduler::GenerateTupleAbstract(const std::vector<AnfNodePtr> &nodes) {
  AbstractBasePtr abs;
  if (nodes.size() == 2) {
    auto cnode = nodes.back()->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    abs = cnode->abstract();
  } else {
    AbstractBasePtrList abstract_list;
    abstract_list.resize(nodes.size() - 1);
    (void)std::transform(nodes.begin() + 1, nodes.end(), abstract_list.begin(), [](const AnfNodePtr &node) {
      auto cnode = node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      return cnode->abstract();
    });
    abs = std::make_shared<abstract::AbstractTuple>(abstract_list);
  }
  return abs;
}

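// With the parallel optimizer (optimizer weight sharding) enabled, collect the
// AllGather inputs of every non-first-chunk cell into a MakeTuple and hang it, via a
// Depend, before the first forward begin, so the weight all-gathers are issued ahead
// of the pipeline.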
void InterleavedScheduler::OptimizerShardCommReorder() {
  auto enable_opt_shard = ParallelContext::GetInstance()->enable_parallel_optimizer();
  if (!enable_opt_shard) {
    return;
  }
  auto sorted_fwd_begin = SortBetweenMicro(fwd_begin_, false);
  auto sorted_fwd_cell = SortBetweenMicro(fwd_cell_, false);
  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
  for (int64_t chunk = 1; chunk < chunk_num_; ++chunk) {
    for (const auto &border : sorted_fwd_cell) {
      if (border.first.chunk == chunk) {
        auto cnode = border.first.border;
        for (const auto &input : cnode->inputs()) {
          if (!IsPrimitiveCNode(input, prim::kPrimAllGather)) {
            continue;
          }
          make_tuple_inputs.emplace_back(input);
        }
      }
    }
  }
  if (make_tuple_inputs.size() > 1) {
    auto make_tuple = root_->NewCNode(make_tuple_inputs);
    auto abs = GenerateTupleAbstract(make_tuple_inputs);
    make_tuple->set_abstract(abs);
    auto begin_node = sorted_fwd_begin.front().first.border;
    if (begin_node->inputs().size() < 2) {
      return;
    }
    std::vector<AnfNodePtr> depend_inputs = {NewValueNode(prim::kPrimDepend), begin_node->input(1), make_tuple};
    auto depend = root_->NewCNode(depend_inputs);
    depend->set_abstract(begin_node->input(1)->abstract());
    manager_->SetEdge(begin_node, 1, depend);
  }
}

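// Keep parameter-passing borders (PIPELINE_PARAM) outside the micro schedule: forward
// parameter sends/receives finish before the first forward begin, and backward ones
// start only after the last backward end.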
void InterleavedScheduler::ParameterReorder(const std::vector<BorderPair> &sorted_fwd_begin,
                                            const std::vector<BorderPair> &sorted_bwd_end) {
  if (!fwd_params_.empty()) {
    std::sort(fwd_params_.begin(), fwd_params_.end(), SortFuncInsideMicro);
    std::sort(bwd_params_.begin(), bwd_params_.end(), SortFuncInsideMicro);
    auto prior = fwd_params_.back();
    auto last = sorted_fwd_begin.front().first;
    ControlOrder(prior, last);
    auto prior2 = sorted_bwd_end.back().second;
    auto last2 = bwd_params_.front();
    ControlOrder(prior2, last2);
  }
}

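// Warm-up phase of the memory-optimized (ENABLE_LESS_MEM_VPP) variant: like
// WarmUpPhaseReorder, but stage 0 interleaves cells and borders with an offset_
// lookback to bound the number of in-flight activations.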
void InterleavedScheduler::MemoryOptimizedWarmUpPhaseReorder() {
  auto sorted_fwd_begin = SortBetweenMicro(fwd_begin_, false);
  auto sorted_fwd_end = SortBetweenMicro(fwd_end_, false);
  auto sorted_fwd_cell = SortBetweenMicro(fwd_cell_, false);
  auto sorted_bwd_cell = SortBetweenMicro(bwd_cell_, true);
  for (size_t i = 0; i < LongToSize(bias_); ++i) {
    if (stage_ != 0) {
      auto prior = sorted_fwd_end[i].second;
      auto last = sorted_fwd_begin[i + 1].first;
      ControlOrder(prior, last);
      continue;
    } else {
      size_t offset = 0;
      if (sorted_fwd_begin[i + 1].first.chunk != 0) {
        offset = offset_;
      }
      auto prior = sorted_fwd_end[i].second;
      auto last = sorted_fwd_cell[i + 1].first;
      ControlOrder(prior, last);
      auto prior1 = sorted_fwd_cell[i - LongToSize(offset)].second;
      auto last1 = sorted_fwd_begin[i + 1].first;
      ControlOrder(prior1, last1);
      auto prior2 = sorted_fwd_begin[i + 1].second;
      if (last.border == prior2.border) {
        continue;
      }
      auto last2 = sorted_fwd_end[i - LongToSize(offset)].first;
      ControlOrder(prior2, last2);
    }
  }
}

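// Stable (1F1B) phase of the memory-optimized variant: alternate one forward and one
// backward micro, pairing forward micro i with backward micro i - bias_.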
void InterleavedScheduler::MemoryOptimizedStablePhaseReorder() {
  auto sorted_fwd_begin = SortBetweenMicro(fwd_begin_, false);
  auto sorted_fwd_end = SortBetweenMicro(fwd_end_, false);
  auto sorted_bwd_begin = SortBetweenMicro(bwd_begin_, true);
  auto sorted_bwd_end = SortBetweenMicro(bwd_end_, true);
  auto sorted_fwd_cell = SortBetweenMicro(fwd_cell_, false);
  auto sorted_bwd_cell = SortBetweenMicro(bwd_cell_, true);
  for (size_t i = bias_; i < LongToSize(micro_size_ * chunk_num_); ++i) {
    if (i == LongToSize(micro_size_ * chunk_num_ - 1)) {
      if (stage_ != 0) {
        auto prior = sorted_fwd_end[i].second;
        auto last = sorted_bwd_begin[i - bias_].first;
        ControlOrder(prior, last);
      } else {
        auto prior = sorted_fwd_cell[i].second;
        auto last = sorted_bwd_begin[i - bias_].first;
        ControlOrder(prior, last);
        auto prior1 = sorted_bwd_begin[i - bias_].second;
        auto last1 = sorted_fwd_end[i].first;
        ControlOrder(prior1, last1);
        auto prior2 = sorted_fwd_end[i].second;
        auto last2 = sorted_bwd_cell[i - bias_].first;
        ControlOrder(prior2, last2);
      }
      continue;
    }
    if (stage_ != 0) {
      auto prior = sorted_bwd_end[i - bias_].second;
      auto last = sorted_fwd_begin[i + 1].first;
      ControlOrder(prior, last);
    } else {
      auto offset = offset_;
      auto loop_index_bwd = sorted_bwd_end[i - bias_].second.micro / (stage_num_ + SizeToLong(offset_));
      if (loop_index_bwd != 0) {
        offset = 0;
      }
      auto loop_index_fwd = sorted_fwd_end[i + 1].second.micro / (stage_num_ + SizeToLong(offset_));
      if (loop_index_fwd == 0) {
        auto prior1 = sorted_fwd_end[i - offset_].second;
        auto last1 = sorted_fwd_cell[i + 1 - offset_].first;
        ControlOrder(prior1, last1);
        auto prior2 = sorted_fwd_cell[i - offset_].second;
        auto last2 = sorted_fwd_begin[i + 1].first;
        ControlOrder(prior2, last2);
        auto prior3 = sorted_fwd_begin[i + 1].second;
        auto last3 = sorted_fwd_end[i - offset_].first;
        ControlOrder(prior3, last3);
      }
      if (sorted_bwd_end[i - bias_].second.chunk != 0) {
        auto prior1 = sorted_bwd_cell[i - bias_].second;
        auto last1 = sorted_fwd_cell[i + 1].first;
        ControlOrder(prior1, last1);
        if (i + 1 + offset > LongToSize(micro_size_ * chunk_num_ - 1)) {
          auto prior2 = sorted_bwd_begin[i - bias_ + 1 + offset].second;
          auto last2 = sorted_bwd_end[i - bias_].first;
          ControlOrder(prior2, last2);
        } else {
          auto prior2 = sorted_fwd_end[i + 1 + offset].second;
          auto last2 = sorted_bwd_end[i - bias_].first;
          ControlOrder(prior2, last2);
        }
        if ((i + 1 + offset <= LongToSize(micro_size_ * chunk_num_ - 1)) &&
            sorted_fwd_end[i + 1 + offset].second.chunk != chunk_num_ - 1) {
          auto prior3 = sorted_bwd_end[i - bias_].second;
          auto last3 = sorted_fwd_begin[i + 1 + LongToSize(stage_num_) + offset].first;
          ControlOrder(prior3, last3);
          auto prior4 = sorted_fwd_begin[i + 1 + LongToSize(stage_num_) + offset].second;
          auto last4 = sorted_bwd_cell[i - bias_ + 1 + offset].first;
          ControlOrder(prior4, last4);
        } else {
          auto prior3 = sorted_bwd_end[i - bias_].second;
          auto last3 = sorted_bwd_cell[i - bias_ + 1 + offset].first;
          ControlOrder(prior3, last3);
        }
      } else {
        auto prior = sorted_bwd_end[i - bias_].second;
        auto last = sorted_fwd_cell[i + 1].first;
        ControlOrder(prior, last);
      }
    }
    if (stage_ != stage_num_ - 1 || sorted_fwd_end[i].second.chunk != chunk_num_ - 1) {
      auto prior = sorted_fwd_cell[i].second;
      auto last = sorted_bwd_begin[i - bias_].first;
      ControlOrder(prior, last);
      auto prior1 = sorted_bwd_begin[i - bias_].second;
      auto last1 = sorted_fwd_end[i].first;
      ControlOrder(prior1, last1);
      auto prior2 = sorted_fwd_end[i].second;
      auto last2 = sorted_bwd_cell[i - bias_].first;
      ControlOrder(prior2, last2);
    } else {
      auto prior = sorted_fwd_end[i].second;
      auto last = sorted_bwd_begin[i - bias_].first;
      ControlOrder(prior, last);
    }
  }
}

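// Entry point of the memory-optimized schedule: compute offset_ and a smaller warm-up
// bias_ (one, rather than two, extra micros per remaining stage), then run the warm-up
// and stable phases followed by the final backward chain below.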
void InterleavedScheduler::MemoryOptimizedReorder() {
  offset_ = LongToSize(micro_size_ % stage_num_);
  bias_ = LongToSize((stage_num_ + offset_) * (chunk_num_ - 1) + stage_num_ - stage_ - 1);
  auto sorted_bwd_begin = SortBetweenMicro(bwd_begin_, true);
  auto sorted_bwd_end = SortBetweenMicro(bwd_end_, true);
  auto sorted_bwd_cell = SortBetweenMicro(bwd_cell_, true);
  if (micro_size_ < stage_num_) {
    MS_LOG(EXCEPTION) << "For the 1F1B scheduler, the MicroBatch num must be greater than or equal to the StageNum, "
                      << "but got MicroBatch: " << micro_size_ << ", StageNum: " << stage_num_;
  }
  // WarmUp phase
  MemoryOptimizedWarmUpPhaseReorder();

  // Stable phase
  MemoryOptimizedStablePhaseReorder();

  for (size_t i = LongToSize(micro_size_ * chunk_num_ - bias_ - 1); i < LongToSize(micro_size_ * chunk_num_ - 1); ++i) {
    if (stage_ != stage_num_ - 1 || sorted_bwd_begin[i + 1].first.chunk == chunk_num_ - 1) {
      auto prior = sorted_bwd_end[i].second;
      auto last = sorted_bwd_begin[i + 1].first;
      ControlOrder(prior, last);
    } else {
      auto prior = sorted_bwd_cell[i].second;
      auto last = sorted_bwd_begin[i + 1].first;
      ControlOrder(prior, last);
      auto prior1 = sorted_bwd_begin[i + 1].second;
      auto last1 = sorted_bwd_end[i].first;
      ControlOrder(prior1, last1);
      auto prior2 = sorted_bwd_end[i].second;
      auto last2 = sorted_bwd_cell[i + 1].first;
      ControlOrder(prior2, last2);
    }
  }
}

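// The memory-optimized schedule is only taken under kernel-by-kernel execution
// (jit_level O0/O1) and when the ENABLE_LESS_MEM_VPP env var is set to "1".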
static bool EnableKbk() {
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  auto jit_level = context->get_param<std::string>(MS_CTX_JIT_LEVEL);
  MS_LOG(WARNING) << "ENABLE_LESS_MEM_VPP status: " << common::GetEnv("ENABLE_LESS_MEM_VPP");
  return (jit_level == "O0" || jit_level == "O1") && common::GetEnv("ENABLE_LESS_MEM_VPP") == "1";
}

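// Stable (1F1B) phase of the default interleaved schedule: pair forward micro i with
// backward micro i - bias_, with extra edges on boundary stages and odd stages where
// the communication borders interleave differently.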
void InterleavedScheduler::StablePhaseReorder() {
  auto sorted_fwd_begin = SortBetweenMicro(fwd_begin_, false);
  auto sorted_fwd_end = SortBetweenMicro(fwd_end_, false);
  auto sorted_bwd_begin = SortBetweenMicro(bwd_begin_, true);
  auto sorted_bwd_end = SortBetweenMicro(bwd_end_, true);
  auto sorted_fwd_cell = SortBetweenMicro(fwd_cell_, false);
  auto sorted_bwd_cell = SortBetweenMicro(bwd_cell_, true);
  for (size_t i = LongToSize(bias_); i < LongToSize(micro_size_ * chunk_num_ - 1); ++i) {
    if (stage_ == stage_num_ - 1 && sorted_fwd_end[i].first.chunk == chunk_num_ - 1) {
      auto prior = sorted_fwd_end[i].second;
      auto last = sorted_bwd_begin[i - LongToSize(bias_)].first;
      ControlOrder(prior, last);
      auto prior1 = sorted_bwd_cell[i - LongToSize(bias_)].second;
      auto last1 = sorted_fwd_begin[i + 1].first;
      ControlOrder(prior1, last1);
    } else {
      auto prior = sorted_fwd_cell[i].second;
      auto last = sorted_bwd_begin[i - LongToSize(bias_)].first;
      ControlOrder(prior, last);
    }
    if (is_even_stage_) {
      if (stage_ != stage_num_ - 1 || sorted_fwd_end[i].first.chunk != chunk_num_ - 1) {
        auto prior = sorted_bwd_cell[i - LongToSize(bias_)].second;
        auto last = sorted_fwd_end[i].first;
        ControlOrder(prior, last);
        auto prior1 = sorted_fwd_end[i].second;
        auto last1 = sorted_fwd_begin[i + 1].first;
        ControlOrder(prior1, last1);
      }
      if (stage_ != 0 || sorted_bwd_end[i - LongToSize(bias_)].first.chunk != 0) {
        auto loop_index = sorted_bwd_end[i - LongToSize(bias_)].first.micro / (stage_num_ + SizeToLong(offset_));
        auto offset = LongToSize(offset_);
        if (loop_index != 0 || stage_ != 0) {
          offset = 0;
        }
        if (i + offset + 1 > LongToSize(micro_size_ * chunk_num_ - 1)) {
          auto prior = sorted_bwd_cell[i + offset - LongToSize(bias_)].second;
          auto last = sorted_bwd_end[i - LongToSize(bias_)].first;
          ControlOrder(prior, last);
        } else {
          auto prior = sorted_fwd_cell[i + offset + 1].second;
          auto last = sorted_bwd_end[i - LongToSize(bias_)].first;
          ControlOrder(prior, last);
        }
        auto prior1 = sorted_bwd_end[i - LongToSize(bias_)].second;
        auto last1 = sorted_bwd_begin[i + offset + 1 - LongToSize(bias_)].first;
        ControlOrder(prior1, last1);
      }
      continue;
    }
    if (stage_ != stage_num_ - 1 || sorted_fwd_end[i].first.chunk != chunk_num_ - 1) {
      auto prior = sorted_bwd_cell[i - LongToSize(bias_)].second;
      auto last = sorted_fwd_begin[i + 1].first;
      ControlOrder(prior, last);
      auto prior1 = sorted_fwd_begin[i + 1].second;
      auto last1 = sorted_fwd_end[i].first;
      ControlOrder(prior1, last1);
      auto prior2 = sorted_fwd_end[i].second;
      auto last2 = sorted_fwd_cell[i + 1].first;
      ControlOrder(prior2, last2);
    }
    if (stage_ != stage_num_ - 1 || sorted_bwd_begin[i - LongToSize(bias_) + 1].second.chunk != chunk_num_ - 1) {
      auto prior = sorted_bwd_begin[i - LongToSize(bias_) + 1].second;
      auto last = sorted_bwd_end[i - LongToSize(bias_)].first;
      ControlOrder(prior, last);
      auto prior1 = sorted_bwd_end[i - LongToSize(bias_)].second;
      auto last1 = sorted_bwd_cell[i - LongToSize(bias_) + 1].first;
      ControlOrder(prior1, last1);
      continue;
    }
    auto prior = sorted_fwd_cell[i + 1].second;
    auto last = sorted_bwd_end[i - LongToSize(bias_)].first;
    ControlOrder(prior, last);
    auto prior1 = sorted_bwd_end[i - LongToSize(bias_)].second;
    auto last1 = sorted_bwd_cell[i - LongToSize(bias_) + 1].first;
    ControlOrder(prior1, last1);
  }
}

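// Top-level reorder: take the memory-optimized schedule when enabled; otherwise compute
// offset_ (remainder micros) and bias_ (warm-up length: two extra micros per stage away
// from the last, plus a full round per extra chunk), then chain the warm-up, stable,
// end, and parameter phases.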
void InterleavedScheduler::Reorder() {
  auto enable_kbk = EnableKbk();
  auto sorted_fwd_begin = SortBetweenMicro(fwd_begin_, false);
  auto sorted_bwd_end = SortBetweenMicro(bwd_end_, true);
  if (enable_kbk) {
    MemoryOptimizedReorder();
    ParameterReorder(sorted_fwd_begin, sorted_bwd_end);
    OptimizerShardCommReorder();
    return;
  }
  offset_ = LongToSize(micro_size_ % stage_num_);
  bias_ = LongToSize((stage_num_ + SizeToLong(offset_)) * (chunk_num_ - 1) + (stage_num_ - stage_ - 1) * INT64_TWO);
  is_even_stage_ = stage_ % INT64_TWO == 0;
  if (micro_size_ < stage_num_) {
    MS_LOG(EXCEPTION) << "For the 1F1B scheduler, the MicroBatch num must be greater than or equal to the StageNum, "
                      << "but got MicroBatch: " << micro_size_ << ", StageNum: " << stage_num_;
  }
  // WarmUp phase
  WarmUpPhaseReorder();

  // Stable phase
  StablePhaseReorder();
  LastForwardMicroReorder();

  // End phase
  EndPhaseReorder();

  // Parameters phase
  ParameterReorder(sorted_fwd_begin, sorted_bwd_end);
  OptimizerShardCommReorder();
}
}  // namespace parallel
}  // namespace mindspore