• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_PIPELINE_SCHEDULER_H_
18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_PIPELINE_SCHEDULER_H_
19 
20 #include <set>
21 #include <utility>
22 #include <string>
23 #include <memory>
24 #include <vector>
25 #include "ir/value.h"
26 #include "ir/graph_utils.h"
27 #include "base/base.h"
28 #include "utils/hash_map.h"
29 #include "frontend/parallel/step_parallel.h"
30 #include "frontend/parallel/graph_util/generate_graph.h"
31 
32 namespace mindspore {
33 namespace parallel {
34 typedef struct {
35   CNodePtr border;
36   int64_t chunk;
37   int64_t micro;
38 } Border;
39 
40 using BorderPair = std::pair<Border, Border>;
41 
42 class PipelineScheduler {
43  public:
PipelineScheduler(const FuncGraphManagerPtr & manager,const FuncGraphPtr & root,int64_t stage,int64_t stage_num)44   explicit PipelineScheduler(const FuncGraphManagerPtr &manager, const FuncGraphPtr &root, int64_t stage,
45                              int64_t stage_num)
46       : manager_(manager), root_(root), stage_(stage), stage_num_(stage_num) {}
47   virtual ~PipelineScheduler() = default;
48   virtual void GetBorderNode() = 0;
49   virtual void Reorder() = 0;
50 
51  protected:
52   std::vector<BorderPair> SortInsideMicro(const std::vector<Border> &borders);
53   std::pair<Border, Border> SpecifiedBorder(const std::vector<Border> &borders, int64_t chunk, int64_t micro);
54   void ControlOrder(const Border &b_prior, const Border &b_last);
55   int64_t micro_size_ = 1;
56   int64_t chunk_num_ = 1;
57   FuncGraphManagerPtr manager_;
58   FuncGraphPtr root_;
59   int64_t stage_;
60   int64_t stage_num_;
61 };
62 
63 class InterleavedScheduler : public PipelineScheduler {
64  public:
InterleavedScheduler(const FuncGraphManagerPtr & manager,const FuncGraphPtr & root,int64_t stage,int64_t stage_num)65   InterleavedScheduler(const FuncGraphManagerPtr &manager, const FuncGraphPtr &root, int64_t stage, int64_t stage_num)
66       : PipelineScheduler(manager, root, stage, stage_num) {}
67   virtual ~InterleavedScheduler() = default;
68 
69   void GetBorderNode() override;
70   void Reorder() override;
71 
72  private:
73   void MemoryOptimizedWarmUpPhaseReorder();
74   void MemoryOptimizedStablePhaseReorder();
75   void MemoryOptimizedReorder();
76   void WarmUpPhaseReorder();
77   void StablePhaseReorder();
78   void LastForwardMicroReorder();
79   void EndPhaseReorder();
80   AbstractBasePtr GenerateTupleAbstract(const std::vector<AnfNodePtr> &nodes);
81   void OptimizerShardCommReorder();
82   void ParameterReorder(const std::vector<BorderPair> &sorted_fwd_begin, const std::vector<BorderPair> &sorted_bwd_end);
83   void GetBackwardBorderNode(const CNodePtr &cnode);
84   std::vector<BorderPair> SortBetweenMicro(const std::vector<Border> &borders, bool is_backward);
85   std::vector<Border> fwd_begin_;
86   std::vector<Border> fwd_end_;
87   std::vector<Border> bwd_begin_;
88   std::vector<Border> bwd_end_;
89   std::vector<Border> fwd_cell_;
90   std::vector<Border> bwd_cell_;
91   std::vector<Border> fwd_params_;
92   std::vector<Border> bwd_params_;
93   size_t bias_ = 0;
94   size_t offset_ = 0;
95   bool is_even_stage_ = true;
96 };
97 bool SortFuncInsideMicro(const Border &b_i, const Border &b_j);
98 CNodePtr GetCellByReceive(const AnfNodePtr &node, const FuncGraphManagerPtr &manager);
99 CNodePtr GetCellBySend(const AnfNodePtr &node);
100 }  // namespace parallel
101 }  // namespace mindspore
102 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_PIPELINE_SCHEDULER_H_
103