1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_COMP_COMM_SCHEDULING_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_COMP_COMM_SCHEDULING_H_ 19 20 #include <list> 21 #include <vector> 22 #include <utility> 23 #include <memory> 24 #include <unordered_map> 25 #include <string> 26 27 #include "mindspore/core/ir/anf.h" 28 #include "mindspore/core/ir/manager.h" 29 #include "mindspore/core/mindapi/base/shape_vector.h" 30 31 namespace mindspore { 32 namespace opt { 33 // Preliminary definitions 34 using Time = size_t; 35 using TaskId = size_t; 36 using PeId = size_t; 37 enum TaskType { kNone, kComp, kComm }; 38 39 struct ProcessingElement { 40 PeId id; 41 TaskType type; 42 Time load; 43 std::list<std::pair<Time, Time>> idle; 44 }; 45 46 struct Interval { // Information extracted by scheduling 47 TaskId id; 48 TaskType type; 49 Time start; 50 Time end; 51 }; 52 53 enum TaskSort { 54 kSortByWeightMax = 0, 55 kSortByWeightMin, 56 kSortBySuccDiff, 57 kSortByBottomLevelMax, 58 kSortByBottomLevelMin, 59 kSortByTopLevelMax, 60 kSortByTopLevelMin, 61 kSortByBottomTopLevelMaxSum, 62 kSortByBottomTopLevelMinSum, 63 kSortByBottomTopLevelComposite, 64 kSortByWeightedLength, 65 kSortByDepthMax, 66 kSortByDepthMin, 67 kSortByPredComm, 68 kNumTaskSort 69 }; 70 71 // Class to define tasks for scheduling 72 class Task { 73 public: 74 // Constructors Task(const TaskId & id,const TaskType & type)75 Task(const TaskId &id, const TaskType &type) { 76 this->id_ = id; 77 this->type_ = type; 78 // scheduling related 79 this->weight_ = 1; 80 this->parallel_weight_ = 1; 81 this->bottom_level_ = 0; 82 this->top_level_ = 0; 83 this->depth_ = 0; 84 this->succ_diff_type_ = 0; 85 this->weighted_length_ = 0.0; 86 this->start_ = SIZE_MAX; 87 this->end_ = 0; 88 this->pred_comm_ = 0; 89 } 90 91 // Accessors id()92 TaskId id() const { return this->id_; } type()93 TaskType type() const { return this->type_; } 94 weight()95 Time weight() const { return this->weight_; } parallel_weight()96 Time parallel_weight() const { return this->parallel_weight_; } bottom_level()97 Time bottom_level() const { return this->bottom_level_; } top_level()98 Time top_level() const { return this->top_level_; } depth()99 size_t depth() const { return this->depth_; } succ_diff_type()100 size_t succ_diff_type() const { return this->succ_diff_type_; } weighted_length()101 double weighted_length() const { return this->weighted_length_; } start()102 Time start() const { return this->start_; } end()103 Time end() const { return this->end_; } pred_comm()104 size_t pred_comm() const { return this->pred_comm_; } 105 parents()106 std::vector<std::weak_ptr<Task>> &parents() { return this->parents_; } children()107 std::vector<std::shared_ptr<Task>> &children() { return this->children_; } no_dep_grandchildren()108 std::vector<std::shared_ptr<Task>> &no_dep_grandchildren() { return this->no_dep_grandchildren_; } 109 110 // Mutators set_id(TaskId id)111 void set_id(TaskId id) { this->id_ = id; } set_type(TaskType type)112 void set_type(TaskType type) { this->type_ = type; } 113 set_weight(Time weight)114 void set_weight(Time weight) { this->weight_ = weight; } set_parallel_weight(Time parallel_weight)115 void set_parallel_weight(Time parallel_weight) { this->parallel_weight_ = parallel_weight; } set_bottom_level(Time bottom_level)116 void set_bottom_level(Time bottom_level) { this->bottom_level_ = bottom_level; } set_top_level(Time top_level)117 void set_top_level(Time top_level) { this->top_level_ = top_level; } set_depth(size_t depth)118 void set_depth(size_t depth) { this->depth_ = depth; } set_succ_diff_type(size_t succ_diff_type)119 void set_succ_diff_type(size_t succ_diff_type) { this->succ_diff_type_ = succ_diff_type; } set_weighted_length(double weighted_length)120 void set_weighted_length(double weighted_length) { this->weighted_length_ = weighted_length; } set_start(Time start)121 void set_start(Time start) { this->start_ = start; } set_end(Time end)122 void set_end(Time end) { this->end_ = end; } set_pred_comm(size_t pred_comm)123 void set_pred_comm(size_t pred_comm) { this->pred_comm_ = pred_comm; } 124 125 // Maintaining Graph Topology AddParent(std::weak_ptr<Task> parent)126 void AddParent(std::weak_ptr<Task> parent) { this->parents_.push_back(parent); } AddChild(std::shared_ptr<Task> child)127 void AddChild(std::shared_ptr<Task> child) { this->children_.push_back(child); } AddNoDepGrandchild(std::shared_ptr<Task> grandchild)128 void AddNoDepGrandchild(std::shared_ptr<Task> grandchild) { this->no_dep_grandchildren_.push_back(grandchild); } 129 HasChild(std::shared_ptr<Task> child)130 bool HasChild(std::shared_ptr<Task> child) { 131 return std::find(children_.begin(), children_.end(), child) != children_.end(); 132 } HasNoDepGrandchild(std::shared_ptr<Task> grandchild)133 bool HasNoDepGrandchild(std::shared_ptr<Task> grandchild) { 134 return std::find(no_dep_grandchildren_.begin(), no_dep_grandchildren_.end(), grandchild) != 135 no_dep_grandchildren_.end(); 136 } 137 138 // Weighting (Evaluation) AssignWeight(size_t weight)139 void AssignWeight(size_t weight) { 140 if (weight == 0) { 141 this->weight_ = 1; 142 } else if (weight < 0) { 143 this->weight_ = SIZE_MAX; 144 } else { 145 this->weight_ = weight; 146 } 147 this->parallel_weight_ = this->weight_; 148 } 149 150 // Other ResetStartEnd()151 void ResetStartEnd() { 152 this->set_start(SIZE_MAX); 153 this->set_end(0); 154 } 155 156 private: 157 TaskId id_; 158 TaskType type_; 159 160 // Attributes used to select task during scheduling 161 Time weight_; 162 Time parallel_weight_; 163 Time bottom_level_; 164 Time top_level_; 165 size_t depth_; 166 size_t succ_diff_type_; 167 double weighted_length_; 168 Time start_; 169 Time end_; 170 size_t pred_comm_; 171 172 // Attributes to maintain graph info 173 174 std::vector<std::weak_ptr<Task>> parents_; 175 std::vector<std::shared_ptr<Task>> children_; 176 std::vector<std::shared_ptr<Task>> no_dep_grandchildren_; 177 }; 178 179 using TaskPtr = std::shared_ptr<Task>; 180 using TaskSortFunction = bool (*)(std::shared_ptr<Task> const &, std::shared_ptr<Task> const &); 181 182 // Class for scheduling algorithms 183 struct SchedulingInput { 184 std::vector<std::shared_ptr<Task>> tasks; 185 }; 186 187 struct SchedulingOutput { 188 std::vector<Interval> task_times; 189 Time makespan; 190 }; 191 192 namespace FastGreedyScheduler { 193 // Main functionality 194 SchedulingOutput Process(SchedulingInput &, const std::string &); 195 SchedulingOutput ProcessCore(std::vector<std::shared_ptr<Task>> &, std::unordered_map<TaskType, int32_t> &, 196 const TaskSortFunction &, bool); 197 SchedulingOutput ProcessSingle(const SchedulingInput &, const TaskSortFunction &, bool, const std::string &); 198 199 // Compute Auxiliary Values for Task Sorting 200 void ComputeBottomLevelAndWeightedLength(std::vector<std::shared_ptr<Task>> &); 201 void ComputeDepthAndTopLevel(std::vector<std::shared_ptr<Task>> &); 202 void ComputePredComm(std::vector<std::shared_ptr<Task>> &); 203 204 // Lower Bounds 205 Time LowerBoundBottomLevel(std::vector<std::shared_ptr<Task>> &); 206 Time LowerBoundPEs(std::vector<std::shared_ptr<Task>> &, std::unordered_map<TaskType, int32_t> &); 207 208 // Dependency Generation 209 std::vector<std::pair<TaskId, TaskId>> ScheduleToDependencies(const SchedulingOutput &); 210 211 // Verification 212 bool VerifyDAG(std::vector<std::shared_ptr<Task>> &); 213 bool VerifyScheduling(std::vector<std::shared_ptr<Task>> &); 214 bool VerifyDependencies(std::vector<std::shared_ptr<Task>> &, std::vector<std::pair<TaskId, TaskId>> &); 215 216 // Log 217 void PrintLog(const SchedulingOutput &, const std::vector<std::pair<TaskId, TaskId>> &, const std::string &); 218 } // namespace FastGreedyScheduler 219 220 SchedulingInput ExtractSchedulingInput(const FuncGraphManagerPtr &, const std::vector<CNodePtr> &, 221 std::unordered_map<CNodePtr, TaskPtr> *); 222 void AddRealDependencies(const FuncGraphManagerPtr &, const std::vector<CNodePtr> &, 223 const std::vector<std::pair<TaskId, TaskId>> &, std::unordered_map<CNodePtr, TaskPtr> &); 224 225 // Functions for integration 226 void CompCommScheduling(const FuncGraphPtr &); 227 } // namespace opt 228 } // namespace mindspore 229 #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_COMP_COMM_SCHEDULING_H_ 230