• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_COMP_COMM_SCHEDULING_H_
18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_COMP_COMM_SCHEDULING_H_
19 
#include <algorithm>
#include <cstdint>
#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "mindspore/core/ir/anf.h"
#include "mindspore/core/ir/manager.h"
#include "mindspore/core/mindapi/base/shape_vector.h"
30 
31 namespace mindspore {
32 namespace opt {
// Preliminary definitions: scalar aliases and the task kind shared by the
// whole scheduler.
using TaskId = size_t;  // identifier of a Task
using PeId = size_t;    // identifier of a ProcessingElement
using Time = size_t;    // time values (weights, levels, start/end); the unit is not fixed by this header
// Kind of a Task / ProcessingElement. kComp and kComm presumably denote
// computation and communication work; kNone presumably means "no kind assigned".
enum TaskType { kNone = 0, kComp = 1, kComm = 2 };
38 
39 struct ProcessingElement {
40   PeId id;
41   TaskType type;
42   Time load;
43   std::list<std::pair<Time, Time>> idle;
44 };
45 
46 struct Interval {  // Information extracted by scheduling
47   TaskId id;
48   TaskType type;
49   Time start;
50   Time end;
51 };
52 
// Task-ordering strategies available to the scheduler. Each name refers to the
// Task attribute used as the key (weight, succ_diff_type, bottom/top level,
// weighted_length, depth, pred_comm). The Max/Min suffix presumably selects
// descending/ascending order.
// kNumTaskSort equals the number of strategies and must stay last.
enum TaskSort {
  kSortByWeightMax = 0,
  kSortByWeightMin = 1,
  kSortBySuccDiff = 2,
  kSortByBottomLevelMax = 3,
  kSortByBottomLevelMin = 4,
  kSortByTopLevelMax = 5,
  kSortByTopLevelMin = 6,
  kSortByBottomTopLevelMaxSum = 7,
  kSortByBottomTopLevelMinSum = 8,
  kSortByBottomTopLevelComposite = 9,
  kSortByWeightedLength = 10,
  kSortByDepthMax = 11,
  kSortByDepthMin = 12,
  kSortByPredComm = 13,
  kNumTaskSort = 14  // count sentinel; keep last
};
70 
71 // Class to define tasks for scheduling
72 class Task {
73  public:
74   // Constructors
Task(const TaskId & id,const TaskType & type)75   Task(const TaskId &id, const TaskType &type) {
76     this->id_ = id;
77     this->type_ = type;
78     // scheduling related
79     this->weight_ = 1;
80     this->parallel_weight_ = 1;
81     this->bottom_level_ = 0;
82     this->top_level_ = 0;
83     this->depth_ = 0;
84     this->succ_diff_type_ = 0;
85     this->weighted_length_ = 0.0;
86     this->start_ = SIZE_MAX;
87     this->end_ = 0;
88     this->pred_comm_ = 0;
89   }
90 
91   // Accessors
id()92   TaskId id() const { return this->id_; }
type()93   TaskType type() const { return this->type_; }
94 
weight()95   Time weight() const { return this->weight_; }
parallel_weight()96   Time parallel_weight() const { return this->parallel_weight_; }
bottom_level()97   Time bottom_level() const { return this->bottom_level_; }
top_level()98   Time top_level() const { return this->top_level_; }
depth()99   size_t depth() const { return this->depth_; }
succ_diff_type()100   size_t succ_diff_type() const { return this->succ_diff_type_; }
weighted_length()101   double weighted_length() const { return this->weighted_length_; }
start()102   Time start() const { return this->start_; }
end()103   Time end() const { return this->end_; }
pred_comm()104   size_t pred_comm() const { return this->pred_comm_; }
105 
parents()106   std::vector<std::weak_ptr<Task>> &parents() { return this->parents_; }
children()107   std::vector<std::shared_ptr<Task>> &children() { return this->children_; }
no_dep_grandchildren()108   std::vector<std::shared_ptr<Task>> &no_dep_grandchildren() { return this->no_dep_grandchildren_; }
109 
110   // Mutators
set_id(TaskId id)111   void set_id(TaskId id) { this->id_ = id; }
set_type(TaskType type)112   void set_type(TaskType type) { this->type_ = type; }
113 
set_weight(Time weight)114   void set_weight(Time weight) { this->weight_ = weight; }
set_parallel_weight(Time parallel_weight)115   void set_parallel_weight(Time parallel_weight) { this->parallel_weight_ = parallel_weight; }
set_bottom_level(Time bottom_level)116   void set_bottom_level(Time bottom_level) { this->bottom_level_ = bottom_level; }
set_top_level(Time top_level)117   void set_top_level(Time top_level) { this->top_level_ = top_level; }
set_depth(size_t depth)118   void set_depth(size_t depth) { this->depth_ = depth; }
set_succ_diff_type(size_t succ_diff_type)119   void set_succ_diff_type(size_t succ_diff_type) { this->succ_diff_type_ = succ_diff_type; }
set_weighted_length(double weighted_length)120   void set_weighted_length(double weighted_length) { this->weighted_length_ = weighted_length; }
set_start(Time start)121   void set_start(Time start) { this->start_ = start; }
set_end(Time end)122   void set_end(Time end) { this->end_ = end; }
set_pred_comm(size_t pred_comm)123   void set_pred_comm(size_t pred_comm) { this->pred_comm_ = pred_comm; }
124 
125   // Maintaining Graph Topology
AddParent(std::weak_ptr<Task> parent)126   void AddParent(std::weak_ptr<Task> parent) { this->parents_.push_back(parent); }
AddChild(std::shared_ptr<Task> child)127   void AddChild(std::shared_ptr<Task> child) { this->children_.push_back(child); }
AddNoDepGrandchild(std::shared_ptr<Task> grandchild)128   void AddNoDepGrandchild(std::shared_ptr<Task> grandchild) { this->no_dep_grandchildren_.push_back(grandchild); }
129 
HasChild(std::shared_ptr<Task> child)130   bool HasChild(std::shared_ptr<Task> child) {
131     return std::find(children_.begin(), children_.end(), child) != children_.end();
132   }
HasNoDepGrandchild(std::shared_ptr<Task> grandchild)133   bool HasNoDepGrandchild(std::shared_ptr<Task> grandchild) {
134     return std::find(no_dep_grandchildren_.begin(), no_dep_grandchildren_.end(), grandchild) !=
135            no_dep_grandchildren_.end();
136   }
137 
138   // Weighting (Evaluation)
AssignWeight(size_t weight)139   void AssignWeight(size_t weight) {
140     if (weight == 0) {
141       this->weight_ = 1;
142     } else if (weight < 0) {
143       this->weight_ = SIZE_MAX;
144     } else {
145       this->weight_ = weight;
146     }
147     this->parallel_weight_ = this->weight_;
148   }
149 
150   // Other
ResetStartEnd()151   void ResetStartEnd() {
152     this->set_start(SIZE_MAX);
153     this->set_end(0);
154   }
155 
156  private:
157   TaskId id_;
158   TaskType type_;
159 
160   // Attributes used to select task during scheduling
161   Time weight_;
162   Time parallel_weight_;
163   Time bottom_level_;
164   Time top_level_;
165   size_t depth_;
166   size_t succ_diff_type_;
167   double weighted_length_;
168   Time start_;
169   Time end_;
170   size_t pred_comm_;
171 
172   // Attributes to maintain graph info
173 
174   std::vector<std::weak_ptr<Task>> parents_;
175   std::vector<std::shared_ptr<Task>> children_;
176   std::vector<std::shared_ptr<Task>> no_dep_grandchildren_;
177 };
178 
179 using TaskPtr = std::shared_ptr<Task>;
180 using TaskSortFunction = bool (*)(std::shared_ptr<Task> const &, std::shared_ptr<Task> const &);
181 
// Input and output data structures exchanged with the scheduling algorithms
183 struct SchedulingInput {
184   std::vector<std::shared_ptr<Task>> tasks;
185 };
186 
187 struct SchedulingOutput {
188   std::vector<Interval> task_times;
189   Time makespan;
190 };
191 
192 namespace FastGreedyScheduler {
193 // Main functionality
194 SchedulingOutput Process(SchedulingInput &, const std::string &);
195 SchedulingOutput ProcessCore(std::vector<std::shared_ptr<Task>> &, std::unordered_map<TaskType, int32_t> &,
196                              const TaskSortFunction &, bool);
197 SchedulingOutput ProcessSingle(const SchedulingInput &, const TaskSortFunction &, bool, const std::string &);
198 
199 // Compute Auxiliary Values for Task Sorting
200 void ComputeBottomLevelAndWeightedLength(std::vector<std::shared_ptr<Task>> &);
201 void ComputeDepthAndTopLevel(std::vector<std::shared_ptr<Task>> &);
202 void ComputePredComm(std::vector<std::shared_ptr<Task>> &);
203 
204 // Lower Bounds
205 Time LowerBoundBottomLevel(std::vector<std::shared_ptr<Task>> &);
206 Time LowerBoundPEs(std::vector<std::shared_ptr<Task>> &, std::unordered_map<TaskType, int32_t> &);
207 
208 // Dependency Generation
209 std::vector<std::pair<TaskId, TaskId>> ScheduleToDependencies(const SchedulingOutput &);
210 
211 // Verification
212 bool VerifyDAG(std::vector<std::shared_ptr<Task>> &);
213 bool VerifyScheduling(std::vector<std::shared_ptr<Task>> &);
214 bool VerifyDependencies(std::vector<std::shared_ptr<Task>> &, std::vector<std::pair<TaskId, TaskId>> &);
215 
216 // Log
217 void PrintLog(const SchedulingOutput &, const std::vector<std::pair<TaskId, TaskId>> &, const std::string &);
218 }  // namespace FastGreedyScheduler
219 
220 SchedulingInput ExtractSchedulingInput(const FuncGraphManagerPtr &, const std::vector<CNodePtr> &,
221                                        std::unordered_map<CNodePtr, TaskPtr> *);
222 void AddRealDependencies(const FuncGraphManagerPtr &, const std::vector<CNodePtr> &,
223                          const std::vector<std::pair<TaskId, TaskId>> &, std::unordered_map<CNodePtr, TaskPtr> &);
224 
225 // Functions for integration
226 void CompCommScheduling(const FuncGraphPtr &);
227 }  // namespace opt
228 }  // namespace mindspore
229 #endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_COMP_COMM_SCHEDULING_H_
230