1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <algorithm>
18 #include <unordered_map>
19 #include <set>
20 #include <map>
21 #include <deque>
22 #include <queue>
23 #include <iostream>
24 #include <string>
25 #include <cmath>
26 #include <iomanip>
27 #include <functional>
#include <cstdint>
#include <fstream>
#include <list>
#include <memory>
#include <numeric>
#include <sstream>
#include <string_view>
#include <utility>
#include <vector>
28
29 #include "mindspore/ccsrc/frontend/parallel/pass/comp_comm_scheduling.h"
30 #include "mindspore/ccsrc/include/common/utils/utils.h"
31 #include "mindspore/ccsrc/frontend/parallel/step_parallel.h"
32 #include "mindspore/core/utils/misc.h"
33 #include "mindspore/core/utils/convert_utils_base.h"
34
35 #include "include/backend/optimizer/helper.h"
36
37 namespace mindspore {
38 namespace opt {
39 // Subroutines Implementing "Scheduling to Dependencies"
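// "Scheduling to dependencies" converts a computed schedule (a start/end interval per task) into explicit
// dependency edges: within each task type, an edge (a, b) is emitted whenever b must wait for a, so that the
// relative order chosen by the scheduler is preserved even after the schedule itself is discarded.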
40 struct SortByStart {
41 bool operator()(const Interval &interval1, const Interval &interval2) const {
42 const auto &id1 = interval1.id;
43 const auto &start1 = interval1.start;
44 const auto &end1 = interval1.end;
45 const auto &id2 = interval2.id;
46 const auto &start2 = interval2.start;
47 const auto &end2 = interval2.end;
48 return start1 < start2 || (start1 == start2 && end1 < end2) || (start1 == start2 && end1 == end2 && id1 < id2);
49 }
50 };
51
52 struct SortByEnd {
53 bool operator()(const Interval &interval1, const Interval &interval2) const {
54 const auto &id1 = interval1.id;
55 const auto &start1 = interval1.start;
56 const auto &end1 = interval1.end;
57 const auto &id2 = interval2.id;
58 const auto &start2 = interval2.start;
59 const auto &end2 = interval2.end;
60 return end1 < end2 || (end1 == end2 && start1 < start2) || (end1 == end2 && start1 == start2 && id1 < id2);
61 }
62 };
63
64 bool Overlap(const Time &start1, const Time &end1, const Time &start2, const Time &end2) {
65 return (start1 >= start2 && start1 < end2) ||
66 (start2 >= start1 && start2 < end1); // half-open intervals: if one starts exactly where the other ends, they do not overlap
67 }
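// Example: with half-open intervals, [0, 5) and [5, 8) do not overlap, while [0, 5) and [3, 8) do.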
68
69 std::vector<std::pair<TaskId, TaskId>> FastGreedyScheduler::ScheduleToDependencies(const SchedulingOutput &schedule) {
70 std::vector<std::pair<TaskId, TaskId>> dependencies; // to return
71 MS_LOG(INFO) << "Started Preprocessing of Intervals";
72 // Distinguish types and sort
73 std::unordered_map<TaskType, std::set<Interval, SortByStart>> tasks_start;
74 std::unordered_map<TaskType, std::set<Interval, SortByEnd>> tasks_end;
75 for (const auto &task_time : schedule.task_times) {
76 tasks_start[task_time.type].insert(task_time);
77 tasks_end[task_time.type].insert(task_time);
78 }
79 MS_LOG(INFO) << "Finished Preprocessing of Intervals";
80 MS_LOG(INFO) << "Started Main Loop";
81 // Main loop: check each task for potential dependencies in its right neighborhood
82 for (const auto &type_to_set : tasks_start) {
83 const auto &type = type_to_set.first;
84 for (auto it = tasks_start[type].begin(); it != tasks_start[type].end(); ++it) {
85 tasks_end[type].erase(*it);
86 // Dismiss overlapping tasks: save min end value of non-overlapping task to the right
87 std::unordered_map<TaskId, bool> dismissed;
88 auto it1 = std::next(it);
89 for (; it1 != tasks_start[type].end() && Overlap(it->start, it->end, it1->start, it1->end); ++it1) {
90 dismissed[it1->id] = true;
91 }
92 Time min_end_value = 0;
93 for (auto it2 = tasks_end[type].begin(); it2 != tasks_end[type].end(); ++it2) {
94 if (!dismissed[it2->id]) {
95 min_end_value = it2->end;
96 break;
97 }
98 }
99 // Add dependencies to immediate right neighborhood
100 for (; it1 != tasks_start[type].end() && it1->start < min_end_value; ++it1) {
101 dependencies.emplace_back(it->id, it1->id);
102 }
103 }
104 }
105 MS_LOG(INFO) << "Finished Main Loop";
106 MS_LOG(INFO) << "Generated " << dependencies.size() << " dependencies";
107 return dependencies;
108 }
109
110 // Sorting for tasks
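// Each comparator below defines a strict weak ordering over tasks: the primary criterion is followed by
// tie-breakers (typically weight) and, last, the task id, so the resulting order is deterministic across runs.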
111 bool SortByWeightMax(const std::shared_ptr<Task> &task1, const std::shared_ptr<Task> &task2) {
112 return task1->weight() > task2->weight() || (task1->weight() == task2->weight() && task1->id() < task2->id());
113 }
114
115 bool SortByWeightMin(const std::shared_ptr<Task> &task1, const std::shared_ptr<Task> &task2) {
116 return task1->weight() < task2->weight() || (task1->weight() == task2->weight() && task1->id() < task2->id());
117 }
118
119 bool SortBySuccDiff(const std::shared_ptr<Task> &task1, const std::shared_ptr<Task> &task2) {
120 return task1->succ_diff_type() > task2->succ_diff_type() ||
121 (task1->succ_diff_type() == task2->succ_diff_type() && task1->weight() > task2->weight()) ||
122 (task1->succ_diff_type() == task2->succ_diff_type() && task1->weight() == task2->weight() &&
123 task1->id() < task2->id());
124 }
125
126 bool SortByBottomLevelMax(const std::shared_ptr<Task> &task1, const std::shared_ptr<Task> &task2) {
127 return task1->bottom_level() > task2->bottom_level() ||
128 (task1->bottom_level() == task2->bottom_level() && task1->weight() > task2->weight()) ||
129 (task1->bottom_level() == task2->bottom_level() && task1->weight() == task2->weight() &&
130 task1->id() < task2->id());
131 }
132
133 bool SortByBottomLevelMin(const std::shared_ptr<Task> &task1, const std::shared_ptr<Task> &task2) {
134 return task1->bottom_level() < task2->bottom_level() ||
135 (task1->bottom_level() == task2->bottom_level() && task1->weight() > task2->weight()) ||
136 (task1->bottom_level() == task2->bottom_level() && task1->weight() == task2->weight() &&
137 task1->id() < task2->id());
138 }
139
140 bool SortByTopLevelMax(const std::shared_ptr<Task> &task1, const std::shared_ptr<Task> &task2) {
141 return task1->top_level() > task2->top_level() ||
142 (task1->top_level() == task2->top_level() && task1->weight() > task2->weight()) ||
143 (task1->top_level() == task2->top_level() && task1->weight() == task2->weight() && task1->id() < task2->id());
144 }
145
146 bool SortByTopLevelMin(const std::shared_ptr<Task> &task1, const std::shared_ptr<Task> &task2) {
147 return task1->top_level() < task2->top_level() ||
148 (task1->top_level() == task2->top_level() && task1->weight() > task2->weight()) ||
149 (task1->top_level() == task2->top_level() && task1->weight() == task2->weight() && task1->id() < task2->id());
150 }
151
152 bool SortByBottomTopLevelMaxSum(const std::shared_ptr<Task> &task1, const std::shared_ptr<Task> &task2) {
153 return task1->top_level() + task1->bottom_level() > task2->top_level() + task2->bottom_level() ||
154 (task1->top_level() + task1->bottom_level() == task2->top_level() + task2->bottom_level() &&
155 task1->weight() > task2->weight()) ||
156 (task1->top_level() + task1->bottom_level() == task2->top_level() + task2->bottom_level() &&
157 task1->weight() == task2->weight() && task1->id() < task2->id());
158 }
159
160 bool SortByBottomTopLevelMinSum(const std::shared_ptr<Task> &task1, const std::shared_ptr<Task> &task2) {
161 return task1->top_level() + task1->bottom_level() < task2->top_level() + task2->bottom_level() ||
162 (task1->top_level() + task1->bottom_level() == task2->top_level() + task2->bottom_level() &&
163 task1->weight() > task2->weight()) ||
164 (task1->top_level() + task1->bottom_level() == task2->top_level() + task2->bottom_level() &&
165 task1->weight() == task2->weight() && task1->id() < task2->id());
166 }
167
168 // Ishfaq Ahmad, Yu-Kwong Kwok, and Min-You Wu.
169 // Analysis, evaluation, and comparison of algorithms for scheduling task graphs on parallel processors.
170 // Second International Symposium on Parallel Architectures, Algorithms, and Networks (I-SPAN '96),
171 // pages 207-213. IEEE, 1996.
172 bool SortByBottomTopLevelComposite(const std::shared_ptr<Task> &task1, const std::shared_ptr<Task> &task2) {
173 return task1->bottom_level() - task1->top_level() > task2->bottom_level() - task2->top_level() ||
174 (task1->bottom_level() - task1->top_level() == task2->bottom_level() - task2->top_level() &&
175 task1->weight() > task2->weight()) ||
176 (task1->bottom_level() - task1->top_level() == task2->bottom_level() - task2->top_level() &&
177 task1->weight() == task2->weight() && task1->id() < task2->id());
178 }
179
180 // Behrooz Shirazi, Mingfang Wang, and Girish Pathak.
181 // Analysis and evaluation of heuristic methods for static task scheduling.
182 // Journal of Parallel and Distributed Computing, 10(3):222-232, 1990.
183 bool SortByWeightedLength(const std::shared_ptr<Task> &task1, const std::shared_ptr<Task> &task2) {
184 return task1->weighted_length() > task2->weighted_length() ||
185 (task1->weighted_length() == task2->weighted_length() && task1->id() < task2->id());
186 }
187
188 // DFS with weights for tie breaking
189 bool SortByDepthMax(const std::shared_ptr<Task> &task1, const std::shared_ptr<Task> &task2) {
190 return task1->depth() > task2->depth() || (task1->depth() == task2->depth() && task1->weight() > task2->weight()) ||
191 (task1->depth() == task2->depth() && task1->weight() == task2->weight() && task1->id() < task2->id());
192 }
193
194 // BFS with weights for tie breaking
195 bool SortByDepthMin(const std::shared_ptr<Task> &task1, const std::shared_ptr<Task> &task2) {
196 return task1->depth() < task2->depth() || (task1->depth() == task2->depth() && task1->weight() > task2->weight()) ||
197 (task1->depth() == task2->depth() && task1->weight() == task2->weight() && task1->id() < task2->id());
198 }
199
200 // Sort by predecessor to comm
201 bool SortByPredComm(const std::shared_ptr<Task> &task1, const std::shared_ptr<Task> &task2) {
202 return task1->pred_comm() < task2->pred_comm() ||
203 (task1->pred_comm() == task2->pred_comm() && task1->bottom_level() > task2->bottom_level()) ||
204 (task1->pred_comm() == task2->pred_comm() && task1->bottom_level() == task2->bottom_level() &&
205 task1->id() < task2->id());
206 }
207
208 // Sort by predecessor to comm DFS
209 bool SortByPredCommDepth(const std::shared_ptr<Task> &task1, const std::shared_ptr<Task> &task2) {
210 return task1->pred_comm() < task2->pred_comm() ||
211 (task1->pred_comm() == task2->pred_comm() && task1->depth() > task2->depth()) ||
212 (task1->pred_comm() == task2->pred_comm() && task1->depth() == task2->depth() && task1->id() < task2->id());
213 }
214
215 // Sorting by load for processing elements
216 struct SortByLoad {
217 bool operator()(const ProcessingElement &pe1, const ProcessingElement &pe2) const {
218 return pe1.load < pe2.load || (pe1.load == pe2.load && pe1.id < pe2.id);
219 }
220 };
221
222 // Get PEs description
223 std::unordered_map<TaskType, int32_t> GetTestPEs() {
224 std::unordered_map<TaskType, int32_t> new_pem;
225 new_pem[kComm] = 1;
226 new_pem[kComp] = 1;
227 return new_pem;
228 }
229
230 // Auxiliary subroutines and lower bounds
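// ComputeDepthAndTopLevel: a forward BFS over the DAG that sets, for every task, its depth (longest path in
// edges from a source) and its top level (maximum cumulative parallel_weight along any path from a source up
// to and including the task).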
231 void FastGreedyScheduler::ComputeDepthAndTopLevel(std::vector<std::shared_ptr<Task>> &tasks) {
232 MS_LOG(INFO) << "Top Level: Start Initialization";
233 std::unordered_map<TaskId, size_t> unprocessed_parents;
234 std::queue<std::shared_ptr<Task>> tasks_to_visit;
235 // Initialization loop
236 for (size_t j = 0; j < tasks.size(); ++j) {
237 const auto &id = tasks[j]->id();
238 unprocessed_parents[id] = tasks[j]->parents().size();
239 if (unprocessed_parents[id] == 0) {
240 tasks[j]->set_top_level(tasks[j]->parallel_weight());
241 tasks_to_visit.push(tasks[j]);
242 }
243 }
244 MS_LOG(INFO) << "Top Level: End Initialization";
245 MS_LOG(INFO) << "Top Level: Start Traversal Loop";
246 while (!tasks_to_visit.empty()) {
247 const auto &selected_task = tasks_to_visit.front();
248 // Update candidate tasks
249 for (auto &successor : selected_task->children()) {
250 const auto &succ_id = successor->id();
251 successor->set_depth(std::max(successor->depth(), selected_task->depth() + 1));
252 successor->set_top_level(
253 std::max(successor->top_level(), selected_task->top_level() + successor->parallel_weight()));
254 unprocessed_parents[succ_id] -= 1;
255 if (unprocessed_parents[succ_id] == 0) {
256 tasks_to_visit.push(successor);
257 }
258 }
259 tasks_to_visit.pop();
260 }
261 MS_LOG(INFO) << "Top Level: End Traversal Loop";
262 }
263
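// ComputeBottomLevelAndWeightedLength: a reverse BFS from the sinks. The bottom level of a task is the maximum
// cumulative parallel_weight along any path from the task down to a sink. The weighted length additionally
// mixes in the children's weighted lengths: parallel_weight + max_child_length + sum_child_lengths / max_child_length.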
264 void FastGreedyScheduler::ComputeBottomLevelAndWeightedLength(std::vector<std::shared_ptr<Task>> &tasks) {
265 MS_LOG(INFO) << "Bottom Level: Start Initialization";
266 std::unordered_map<TaskId, size_t> unprocessed_children;
267 std::unordered_map<TaskId, double> children_sum;
268 std::unordered_map<TaskId, double> children_max;
269 std::queue<std::shared_ptr<Task>> tasks_to_visit;
270 // Initialization loop
271 for (auto &task : tasks) {
272 const auto &id = task->id();
273 task->set_bottom_level(task->parallel_weight());
274 task->set_weighted_length(task->parallel_weight());
275 unprocessed_children[id] = task->children().size();
276 if (unprocessed_children[id] == 0) {
277 tasks_to_visit.push(task);
278 }
279 }
280 MS_LOG(INFO) << "Bottom Level: End Initialization";
281 MS_LOG(INFO) << "Bottom Level: Start Traversal Loop";
282 while (!tasks_to_visit.empty()) {
283 const auto &selected_task = tasks_to_visit.front();
284 // Update candidate tasks
285 for (auto &predecessor : selected_task->parents()) {
286 const auto &pred_id = predecessor.lock()->id();
287 predecessor.lock()->set_bottom_level(std::max(
288 predecessor.lock()->bottom_level(), selected_task->bottom_level() + predecessor.lock()->parallel_weight()));
289 children_sum[pred_id] += selected_task->weighted_length();
290 children_max[pred_id] = std::max(children_max[pred_id], selected_task->weighted_length());
291 unprocessed_children[pred_id] -= 1;
292 if (unprocessed_children[pred_id] == 0) {
293 if (children_max[pred_id] == 0) {
294 MS_LOG(EXCEPTION) << "divisor children_max[pred_id] cannot be 0!";
295 }
296 predecessor.lock()->set_weighted_length(predecessor.lock()->parallel_weight() + children_max[pred_id] +
297 children_sum[pred_id] / children_max[pred_id]);
298 tasks_to_visit.push(predecessor.lock());
299 }
300 }
301 tasks_to_visit.pop();
302 }
303 MS_LOG(INFO) << "Bottom Level: End Traversal Loop";
304 }
305
306 void FastGreedyScheduler::ComputePredComm(std::vector<std::shared_ptr<Task>> &tasks) {
307 for (auto &task : tasks) {
308 task->set_pred_comm(0);
309 for (auto &predecessor : task->parents()) {
310 if (predecessor.lock()->type() == kComm) {
311 task->set_pred_comm(task->pred_comm() + 1);
312 }
313 }
314 }
315 }
316
317 Time FastGreedyScheduler::LowerBoundBottomLevel(std::vector<std::shared_ptr<Task>> &tasks) {
318 Time max_bottom_level = 0;
319 for (const auto &task : tasks) {
320 max_bottom_level = std::max(max_bottom_level, task->bottom_level());
321 }
322 return max_bottom_level;
323 }
324
325 Time FastGreedyScheduler::LowerBoundPEs(std::vector<std::shared_ptr<Task>> &tasks,
326 std::unordered_map<TaskType, int32_t> &type_to_num_cores_map) {
327 double lower_bound = 0;
328
329 std::unordered_map<TaskType, Time> type_task_sum;
330 for (const auto &task : tasks) {
331 type_task_sum[task->type()] += task->weight();
332 }
333 for (const auto &type_to_num : type_to_num_cores_map) {
334 const auto &type = type_to_num.first;
335 const auto &num_cores = type_to_num.second;
336 if (num_cores == 0) {
337 MS_LOG(EXCEPTION) << "divisor num_cores cannot be 0!";
338 }
339 lower_bound = std::max(lower_bound, type_task_sum[type] / (1.0 * num_cores));
340 }
341 return std::ceil(lower_bound);
342 }
343
344 // Main algorithms/subroutines
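// Two placement strategies are provided for a ready task with release time can_start:
//  - SelectPEandTime: scan PEs of the task's type in increasing load order and place the task in the first idle
//    slot that can hold it no earlier than can_start (the PE set entry is re-keyed after its load changes).
//  - SelectPEandTimeAvailableStart: scan all PEs of the type and pick the slot giving the earliest feasible
//    start time, preferring an earlier start over a lighter load.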
345 std::pair<PeId, Time> SelectPEandTime(const Task &task, Time can_start,
346 std::set<ProcessingElement, SortByLoad> *PEs_ptr) {
347 auto &PEs = *PEs_ptr;
348 std::pair<PeId, Time> return_pair = std::make_pair(0, 0);
349 for (auto it = PEs.begin(); it != PEs.end(); ++it) {
350 // unsafe use of const_cast, but we modify only idle list and not key sorting parameters like load, id, etc.
351 // cf: https://stackoverflow.com/questions/43340050/modification-of-elements-of-stdset-defined-behavior
352 auto &mut_pe = const_cast<ProcessingElement &>(*it);
353 // Put in first idle that fits it
354 for (auto idle_it = mut_pe.idle.begin(); idle_it != mut_pe.idle.end(); ++idle_it) {
355 Time start_time;
356 bool case_flag = false;
357 // Distinguish cases based on can_start constraint
358 if (can_start <= idle_it->first) {
359 start_time = idle_it->first;
360 } else if (can_start <= idle_it->second) {
361 start_time = can_start;
362 case_flag = true;
363 } else { // can_start > idle_it->second means we are not allowed to schedule the task here
364 continue;
365 }
366 // If the task fits, then place it here
367 if (idle_it->second - start_time >= task.weight()) {
368 // Save info to return: start task at start_time on this processing element
369 return_pair.first = (*it).id;
370 return_pair.second = start_time;
371 // Update idle list
372 if (!case_flag) {
373 if (idle_it->second - idle_it->first == task.weight()) { // whole idle interval is filled in, erase it
374 mut_pe.idle.erase(idle_it);
375 } else { // idle_it->second - idle_it->first > task.weight()
376 idle_it->first += task.weight();
377 }
378 } else { // case_flag = true, idle interval is broken into two sub-blocks [idle_it->first, can_start] and
379 // (maybe empty) [can_start + weight, idle_it->second]
380 Time upper = idle_it->second;
381 idle_it->second = can_start;
382 if (upper - can_start - task.weight() > 0) {
383 std::pair<Time, Time> new_idle = std::make_pair(can_start + task.weight(), upper);
384 mut_pe.idle.emplace(std::next(idle_it), new_idle);
385 }
386 }
387 // Update load and PEs set
388 auto updated_PE = PEs.extract(it);
389 updated_PE.value().load += task.weight();
390 PEs.insert(std::move(updated_PE));
391 return return_pair;
392 }
393 }
394 }
395 return return_pair;
396 }
397
398 std::pair<PeId, Time> SelectPEandTimeAvailableStart(const Task &task, Time can_start,
399 std::vector<ProcessingElement> *PEs_ptr) {
400 auto &PEs = *PEs_ptr;
401 // Precompute min first available start for task
402 Time min_start = SIZE_MAX;
403 bool min_case = false;
404 std::vector<ProcessingElement>::iterator min_it;
405 std::list<std::pair<Time, Time>>::iterator min_idle_it;
406 for (auto it = PEs.begin(); it != PEs.end(); ++it) {
407 for (auto idle_it = it->idle.begin(); idle_it != it->idle.end(); ++idle_it) {
408 Time start_time;
409 bool case_flag = false;
410 // Distinguish cases based on can_start constraint
411 if (can_start <= idle_it->first) {
412 start_time = idle_it->first;
413 } else if (can_start <= idle_it->second) {
414 start_time = can_start;
415 case_flag = true;
416 } else { // can_start > idle_it->second means we are not allowed to schedule the task here
417 continue;
418 }
419 if (idle_it->second - start_time >= task.weight()) {
420 if (min_start > start_time) {
421 min_start = start_time;
422 min_case = case_flag;
423 min_it = it;
424 min_idle_it = idle_it;
425 break;
426 }
427 }
428 }
429 }
430 // Assign task to min PE
431 std::pair<PeId, Time> return_pair = std::make_pair(0, 0);
432 // Save info to return: start task at min_start on the selected processing element
433 return_pair.first = (*min_it).id;
434 return_pair.second = min_start;
435 // Update idle list
436 if (!min_case) {
437 if (min_idle_it->second - min_idle_it->first == task.weight()) { // whole idle interval is filled in, erase it
438 min_it->idle.erase(min_idle_it);
439 } else { // idle_it->second - idle_it->first > task.weight()
440 min_idle_it->first += task.weight();
441 }
442 } else { // min_case = true, idle interval is broken into two sub-blocks [idle_it->first, can_start] and
443 // (maybe empty)[can_start + task.weight(), idle_it->second]
444 Time upper = min_idle_it->second;
445 min_idle_it->second = can_start;
446 if (upper - can_start - task.weight() > 0) {
447 std::pair<Time, Time> new_idle = std::make_pair(can_start + task.weight(), upper);
448 min_it->idle.emplace(std::next(min_idle_it), new_idle);
449 }
450 }
451 // Update load
452 min_it->load += task.weight();
453 return return_pair;
454 }
455
456 constexpr TaskSortFunction THREAD_SORT[] = {SortByWeightMax,
457 SortByWeightMin,
458 SortBySuccDiff,
459 SortByBottomLevelMax,
460 SortByBottomLevelMin,
461 SortByTopLevelMax,
462 SortByTopLevelMin,
463 SortByBottomTopLevelMaxSum,
464 SortByBottomTopLevelMinSum,
465 SortByBottomTopLevelComposite,
466 SortByWeightedLength,
467 SortByDepthMax,
468 SortByDepthMin,
469 SortByPredComm,
470 SortByPredCommDepth};
471
472 constexpr std::string_view THREAD_SORT_NAMES[] = {"SortByWeightMax",
473 "SortByWeightMin",
474 "SortBySuccDiff",
475 "SortByBottomLevelMax",
476 "SortByBottomLevelMin",
477 "SortByTopLevelMax",
478 "SortByTopLevelMin",
479 "SortByBottomTopLevelMaxSum",
480 "SortByBottomTopLevelMinSum",
481 "SortByBottomTopLevelComposite",
482 "SortByWeightedLength",
483 "SortByDepthMax",
484 "SortByDepthMin",
485 "SortByPredComm",
486 "SortByPredCommDepth"};
487
488 enum class PEsSort { kSortByLoad = 0, kSortByValidStart, kNumPEsSort };
489
490 constexpr std::string_view PE_NAME_SORT[] = {"SortByLoad", "SortByValidStart"};
491
492 SchedulingOutput FastGreedyScheduler::Process(SchedulingInput &input, const std::string &graph_name) {
493 std::vector<std::shared_ptr<Task>> *tasks = &(input.tasks);
494 auto type_to_num_cores_map = GetTestPEs();
495 SchedulingOutput output{{}, SIZE_MAX};
496 // Optional: verify input task graph is a DAG
497 if (VerifyDAG(*tasks)) {
498 MS_LOG(INFO) << "Verification of DAG: SUCCESS";
499 } else {
500 MS_LOG(INFO) << "Verification of DAG: FAILURE";
501 }
502
503 // Preprocessing: compute the values required by the sorting criteria
504 ComputeBottomLevelAndWeightedLength(*tasks);
505 ComputeDepthAndTopLevel(*tasks);
506 ComputePredComm(*tasks);
507
508 // Loop over all sorting combinations
509 std::unordered_map<std::shared_ptr<Task>, Time> best_start;
510 std::unordered_map<std::shared_ptr<Task>, Time> best_end; // to use in verify dependencies only
511 std::string best_solution;
512 MS_LOG(INFO) << "Start loop multiple scheduling functions";
513 for (size_t task_sort = 0; task_sort < static_cast<size_t>(kNumTaskSort); ++task_sort) {
514 for (size_t pes_sort = 0; pes_sort < static_cast<size_t>(PEsSort::kNumPEsSort); ++pes_sort) {
515 MS_LOG(INFO) << THREAD_SORT_NAMES[task_sort] << " and " << PE_NAME_SORT[pes_sort];
516 SchedulingOutput solution = ProcessCore(*tasks, type_to_num_cores_map, THREAD_SORT[task_sort],
517 (pes_sort == static_cast<size_t>(PEsSort::kSortByLoad)));
518 if (solution.makespan < output.makespan) {
519 output = solution;
520 best_solution = THREAD_SORT_NAMES[task_sort];
521 for (const auto &task : *tasks) { // to use in verify dependencies only
522 best_start[task] = task->start();
523 best_end[task] = task->end();
524 }
525 }
526 for (const auto &task : *tasks) {
527 task->ResetStartEnd();
528 }
529 }
530 }
531 MS_LOG(INFO) << "End loop multiple scheduling functions";
532
533 // Print stats about best solution
534 MS_LOG(INFO) << "Best solution is: " << best_solution;
535 MS_LOG(INFO) << "Makespan of best solution is " << output.makespan;
536 MS_LOG(INFO) << "Bottom level lower bound is " << LowerBoundBottomLevel(*tasks);
537 MS_LOG(INFO) << "Max type lower bound is " << LowerBoundPEs(*tasks, type_to_num_cores_map);
538 MS_LOG(INFO) << "Solution relative error is " << std::setprecision(5)
539 << ((output.makespan /
540 (1.0 * std::max(LowerBoundBottomLevel(*tasks), LowerBoundPEs(*tasks, type_to_num_cores_map))) -
541 1) *
542 100)
543 << "%";
544
545 // Create and (optionally) verify dependencies (here only for testing)
546 MS_LOG(INFO) << "Start Schedule to Dependencies";
547 auto dependencies = ScheduleToDependencies(output);
548 MS_LOG(INFO) << "End Schedule to Dependencies";
549 for (const auto &task : *tasks) {
550 task->set_start(best_start[task]);
551 task->set_end(best_end[task]);
552 }
553
554 // Output log file with all info (scheduling and dependencies)
555 MS_LOG(INFO) << "Start printing output log file";
556 PrintLog(output, dependencies, graph_name);
557 MS_LOG(INFO) << "End printing output log file";
558
559 return output;
560 }
561
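// ProcessCore: classic list scheduling. Ready tasks (all parents scheduled) are kept in a set ordered by the
// given sorting criterion; the loop repeatedly takes the first candidate, places it on a processing element of
// its type via the selected PE strategy, updates the makespan, and releases children that become ready.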
562 SchedulingOutput FastGreedyScheduler::ProcessCore(std::vector<std::shared_ptr<Task>> &tasks,
563 std::unordered_map<TaskType, int32_t> &type_to_num_cores_map,
564 const TaskSortFunction &sortPtr, bool pe_load_sort) {
565 SchedulingOutput output{{}, 0};
566
567 // Initializations for tasks
568 MS_LOG(INFO) << "Started Task Initialization";
569 std::set<std::shared_ptr<Task>, TaskSortFunction> candidate_tasks(sortPtr);
570 std::unordered_map<TaskId, Time> can_start;
571 std::unordered_map<TaskId, size_t> unprocessed_parents;
572 for (auto &task : tasks) {
573 const auto &id = task->id();
574 can_start[id] = 0;
575 unprocessed_parents[id] = task->parents().size();
576 if (unprocessed_parents[id] == 0) {
577 candidate_tasks.insert(task);
578 }
579 }
580 MS_LOG(INFO) << "Finished Task Initialization";
581
582 // Initializations for processing elements
583 // Pick a sorting for processing elements
584 // Implemented: SortByLoad, SortByAvailableStart
585 // Only one structure to be used depending on argument; we define both here
586 std::unordered_map<TaskType, std::set<ProcessingElement, SortByLoad>> PEs_load;
587 std::unordered_map<TaskType, std::vector<ProcessingElement>> PEs_start;
588 MS_LOG(INFO) << "Started Processing Element Initialization";
589 size_t count = 0;
590 for (const auto &type_to_num : type_to_num_cores_map) {
591 const auto &type = type_to_num.first;
592 const auto &num_cores = type_to_num.second;
593 for (int i = 0; i < num_cores; ++i) {
594 ProcessingElement new_pe;
595 new_pe.id = count + IntToSize(i);
596 new_pe.type = type;
597 new_pe.load = 0;
598 new_pe.idle.emplace_back(0, SIZE_MAX);
599 if (pe_load_sort) {
600 PEs_load[type].insert(new_pe);
601 } else {
602 PEs_start[type].push_back(new_pe);
603 }
604 }
605 count += num_cores;
606 }
607 MS_LOG(INFO) << "Finished Processing Element Initialization";
608
609 // Task graph scheduling loop
610 MS_LOG(INFO) << "Started Scheduling Main Loop";
611 while (!candidate_tasks.empty()) {
612 // Select task and schedule it, save info for output
613 const auto selected_task = *(candidate_tasks.begin());
614 const auto &selected_id = selected_task->id();
615 // Selected PE and start time
616 std::pair<PeId, Time> PE_and_time;
617 if (pe_load_sort) {
618 PE_and_time = SelectPEandTime(*selected_task, can_start[selected_id], &PEs_load[selected_task->type()]);
619 } else {
620 PE_and_time =
621 SelectPEandTimeAvailableStart(*selected_task, can_start[selected_id], &PEs_start[selected_task->type()]);
622 }
623
624 const auto &sigma = PE_and_time.second;
625
626 // Maintenance of task interval
627 selected_task->set_start(sigma);
628 selected_task->set_end(sigma + selected_task->weight());
629 // New interval for task in output
630 Interval new_interval{selected_id, selected_task->type(), selected_task->start(), selected_task->end()};
631 output.task_times.push_back(new_interval);
632 // Update makespan
633 output.makespan = std::max(output.makespan, selected_task->end());
634 // Update candidate tasks
635 candidate_tasks.erase(selected_task);
636 for (const auto &successor : selected_task->children()) {
637 const auto &succ_id = successor->id();
638 can_start[succ_id] = std::max(can_start[succ_id], selected_task->end());
639 unprocessed_parents[succ_id] -= 1;
640 if (unprocessed_parents[succ_id] == 0) {
641 candidate_tasks.insert(successor);
642 }
643 }
644 }
645 MS_LOG(INFO) << "Finished Scheduling Main Loop";
646 MS_LOG(INFO) << "Makespan is " << output.makespan;
647 // Verification of scheduling solution (optional)
648 if (VerifyScheduling(tasks)) {
649 MS_LOG(INFO) << "Verification of Scheduling: SUCCESS";
650 } else {
651 MS_LOG(INFO) << "Verification of Scheduling: FAILURE";
652 }
653
654 return output;
655 }
656
657 SchedulingOutput FastGreedyScheduler::ProcessSingle(const SchedulingInput &input, const TaskSortFunction &sortPtr,
658 bool pe_load_sort, const std::string &graph_name) {
659 auto tasks = input.tasks;
660 auto type_to_num_cores_map = GetTestPEs();
661 SchedulingOutput output{{}, 0};
662 // Optional: verify input task graph is a DAG
663 if (VerifyDAG(tasks)) {
664 MS_LOG(INFO) << "Verification of DAG: SUCCESS";
665 } else {
666 MS_LOG(INFO) << "Verification of DAG: FAILURE";
667 }
668 // Preprocessing: compute the values required by the sorting criteria
669 ComputeBottomLevelAndWeightedLength(tasks);
670 ComputeDepthAndTopLevel(tasks);
671 // Initializations for tasks
672 MS_LOG(INFO) << "Started Task Initialization";
673 std::set<std::shared_ptr<Task>, TaskSortFunction> candidate_tasks(sortPtr);
674 std::unordered_map<TaskId, Time> can_start;
675 std::unordered_map<TaskId, size_t> unprocessed_parents;
676 for (auto &task : tasks) {
677 const auto &id = task->id();
678 can_start[id] = 0;
679 unprocessed_parents[id] = task->parents().size();
680 if (unprocessed_parents[id] == 0) {
681 candidate_tasks.insert(task);
682 }
683 }
684 MS_LOG(INFO) << "Finished Task Initialization";
685
686 // Initializations for processing elements
687 // Pick a sorting for processing elements
688 // Implemented: SortByLoad, SortByAvailableStart
689 // Only one structure to be used depending on argument; we define both here
690 std::unordered_map<TaskType, std::set<ProcessingElement, SortByLoad>> PEs_load;
691 std::unordered_map<TaskType, std::vector<ProcessingElement>> PEs_start;
692 MS_LOG(INFO) << "Started Processing Element Initialization";
693 size_t count = 0;
694 for (const auto &type_to_num : type_to_num_cores_map) {
695 const auto &type = type_to_num.first;
696 const auto &num_cores = type_to_num.second;
697 for (int i = 0; i < num_cores; ++i) {
698 ProcessingElement new_pe;
699 new_pe.id = count + i;
700 new_pe.type = type;
701 new_pe.load = 0;
702 new_pe.idle.emplace_back(0, SIZE_MAX);
703 if (pe_load_sort) {
704 PEs_load[type].insert(new_pe);
705 } else {
706 PEs_start[type].push_back(new_pe);
707 }
708 }
709 count += num_cores;
710 }
711 MS_LOG(INFO) << "Finished Processing Element Initialization";
712
713 // Task graph scheduling loop
714 MS_LOG(INFO) << "Started Scheduling Main Loop";
715 while (!candidate_tasks.empty()) {
716 // Select task and schedule it, save info for output
717 const auto selected_task = *(candidate_tasks.begin());
718 const auto &selected_id = selected_task->id();
719 // Selected PE and start time
720 std::pair<PeId, Time> PE_and_time;
721 if (pe_load_sort) {
722 PE_and_time = SelectPEandTime(*selected_task, can_start[selected_id], &PEs_load[selected_task->type()]);
723 } else {
724 PE_and_time =
725 SelectPEandTimeAvailableStart(*selected_task, can_start[selected_id], &PEs_start[selected_task->type()]);
726 }
727 const auto &sigma = PE_and_time.second;
728
729 // Maintenance of task interval
730 selected_task->set_start(sigma);
731 selected_task->set_end(sigma + selected_task->weight());
732 // New interval for task in output
733 Interval new_interval{selected_id, selected_task->type(), selected_task->start(), selected_task->end()};
734 output.task_times.push_back(new_interval);
735 // Update makespan
736 output.makespan = std::max(output.makespan, selected_task->end());
737 // Update candidate tasks
738 candidate_tasks.erase(selected_task);
739 for (const auto &successor : selected_task->children()) {
740 const auto &succ_id = successor->id();
741 can_start[succ_id] = std::max(can_start[succ_id], selected_task->end());
742 unprocessed_parents[succ_id] -= 1;
743 if (unprocessed_parents[succ_id] == 0) {
744 candidate_tasks.insert(successor);
745 }
746 }
747 }
748 MS_LOG(INFO) << "Finished Scheduling Main Loop";
749 MS_LOG(INFO) << "Makespan is " << output.makespan;
750 MS_LOG(INFO) << "Bottom level lower bound is " << LowerBoundBottomLevel(tasks);
751 MS_LOG(INFO) << "Max type lower bound is " << LowerBoundPEs(tasks, type_to_num_cores_map);
752 MS_LOG(INFO) << "Solution relative error is " << std::setprecision(5)
753 << ((output.makespan /
754 (1.0 * std::max(LowerBoundBottomLevel(tasks), LowerBoundPEs(tasks, type_to_num_cores_map))) -
755 1) *
756 100)
757 << "%";
758 // Verification of scheduling solution (optional)
759 if (VerifyScheduling(tasks)) {
760 MS_LOG(INFO) << "Verification of Scheduling: SUCCESS";
761 } else {
762 MS_LOG(INFO) << "Verification of Scheduling: FAILURE";
763 }
764 // Scheduling to Dependencies (here only for testing)
765 MS_LOG(INFO) << "Start Schedule to Dependencies";
766 auto dependencies = ScheduleToDependencies(output);
767 MS_LOG(INFO) << "End Schedule to Dependencies";
768 if (VerifyDependencies(tasks, dependencies)) {
769 MS_LOG(INFO) << "Verification of Dependencies: SUCCESS";
770 } else {
771 MS_LOG(INFO) << "Verification of Dependencies: FAILURE";
772 }
773 PrintLog(output, dependencies, graph_name);
774
775 return output;
776 }
777
778 bool FastGreedyScheduler::VerifyScheduling(std::vector<std::shared_ptr<Task>> &tasks) {
779 bool flag = true;
780 MS_LOG(INFO) << "Start Verification of Scheduling";
781 for (auto &task : tasks) {
782 // Check if task is scheduled before its children
783 for (auto child = task->children().begin(); child != task->children().end(); ++child) {
784 if (!(task->start() < task->end() && task->end() <= (*child)->start() &&
785 (*child)->start() < (*child)->end())) { // intervals are half-open [start, end) and have non-zero length
786 MS_LOG(INFO) << "Verification violation: task " << task->id() << " [" << task->start() << "," << task->end()
787 << "] and task " << (*child)->id() << " [" << (*child)->start() << "," << (*child)->end() << "]";
788 flag = false;
789 }
790 }
791 }
792 MS_LOG(INFO) << "End Verification of Scheduling";
793 return flag;
794 }
795
796 bool BFSsort(const std::shared_ptr<Task> &task1, const std::shared_ptr<Task> &task2) {
797 return task1->depth() < task2->depth() || (task1->depth() == task2->depth() && task1->id() < task2->id());
798 }
799
800 bool FastGreedyScheduler::VerifyDependencies(std::vector<std::shared_ptr<Task>> &tasks,
801 std::vector<std::pair<TaskId, TaskId>> &dependencies) {
802 bool flag = true;
803
804 MS_LOG(INFO) << "Start Verification of Dependencies";
805 // Traverse graph by depth to maintain ancestor info
806 auto tasks_sorted = tasks;
807 std::sort(tasks_sorted.begin(), tasks_sorted.end(), BFSsort);
808 std::map<TaskId, std::map<TaskId, bool>> exists_path;
809 std::map<TaskId, std::shared_ptr<Task>> id_to_ptr;
810 for (auto current = tasks_sorted.begin(); current != tasks_sorted.end(); ++current) {
811 id_to_ptr[(*current)->id()] = *current;
812 for (auto parent = (*current)->parents().begin(); parent != (*current)->parents().end(); ++parent) {
813 exists_path[(*parent).lock()->id()][(*current)->id()] = true;
814 for (auto &it : tasks_sorted) {
815 if (exists_path[it->id()][(*parent).lock()->id()]) {
816 exists_path[it->id()][(*current)->id()] = true;
817 }
818 }
819 }
820 }
821 // For each dependency, check whether it is redundant, whether it forms a directed cycle, and whether the
822 // corresponding tasks are scheduled correctly
823 size_t redundant_count = 0;
824 for (auto it = dependencies.begin(); it != dependencies.end(); ++it) {
825 const auto &source = id_to_ptr[it->first];
826 const auto &dst = id_to_ptr[it->second];
827 if (exists_path[it->first][it->second]) {
828 redundant_count++;
829 }
830 if (exists_path[it->second][it->first]) {
831 MS_LOG(INFO) << "Dependency cycle formation: task " << source->id() << " [" << source->start() << ","
832 << source->end() << "] and task " << dst->id() << " [" << dst->start() << "," << dst->end() << "]";
833 }
834 if (!(source->start() < source->end() && source->end() <= dst->start() && dst->start() < dst->end())) {
835 // allow weights of size 0
836 MS_LOG(INFO) << "Dependency scheduling violation: task " << source->id() << " [" << source->start() << ","
837 << source->end() << "] and task " << dst->id() << " [" << dst->start() << "," << dst->end() << "]";
838 }
839 }
840 MS_LOG(INFO) << "End Verification of Dependencies";
841 MS_LOG(INFO) << redundant_count << " dependencies are redundant, " << dependencies.size() - redundant_count
842 << " are real";
843
844 return flag;
845 }
846
847 bool FastGreedyScheduler::VerifyDAG(std::vector<std::shared_ptr<Task>> &tasks) {
848 // simple verifier that no directed cycle exists
849 std::unordered_map<TaskId, bool> visited;
850 std::unordered_map<TaskId, size_t> unprocessed_parents;
851 std::deque<std::shared_ptr<Task>> to_visit;
852 MS_LOG(INFO) << "Start Verification of DAG";
853 for (auto &task : tasks) {
854 const auto &id = task->id();
855 visited[id] = false;
856 unprocessed_parents[id] = task->parents().size();
857 if (unprocessed_parents[id] == 0) {
858 to_visit.push_back(task);
859 }
860 }
861 while (!to_visit.empty()) {
862 const auto selected_task = *(to_visit.begin());
863 const auto &selected_id = selected_task->id();
864 if (visited[selected_id]) {
865 MS_LOG(INFO) << "Cycle including task " << selected_id;
866 return false;
867 } else {
868 visited[selected_id] = true;
869 }
870 to_visit.pop_front();
871 for (const auto &successor : selected_task->children()) {
872 const auto &succ_id = successor->id();
873 unprocessed_parents[succ_id] -= 1;
874 if (unprocessed_parents[succ_id] == 0) {
875 to_visit.push_back(successor);
876 }
877 }
878 }
879 MS_LOG(INFO) << "End Verification of DAG";
880
881 return true;
882 }
883
884 void FastGreedyScheduler::PrintLog(const SchedulingOutput &output,
885 const std::vector<std::pair<TaskId, TaskId>> &dependencies,
886 const std::string &graph_name) {
887 std::ofstream out_file("comp_comm_scheduling_out_" + graph_name + ".log", std::ios::out | std::ios::trunc);
888 if (!out_file.is_open()) {
889 MS_LOG(ERROR) << "Could not open comp_comm_scheduling_out_" << graph_name << ".log";
890 return;
891 }
892
893 // Print info for tasks
894 const auto &tasks = output.task_times;
895 for (const auto &task : tasks) {
896 out_file << "THREAD id=" << std::to_string(task.id) << ", type=" << std::to_string(task.type)
897 << ", start=" << std::to_string(task.start) << ", end=" << std::to_string(task.end) << "\n";
898 }
899 // Print dependencies
900 for (const auto &dependency : dependencies) {
901 const auto &source = dependency.first;
902 const auto &dst = dependency.second;
903 out_file << "DEPENDENCY " << std::to_string(source) << " " << std::to_string(dst) << "\n";
904 }
905 out_file.close();
906 }
907
908 void InsertTaskGraph(const std::vector<CNodePtr> &cnode_vec,
909 std::unordered_map<CNodePtr, TaskPtr> *cnode_to_task_map_ptr) {
910 for (size_t i = 0; i < cnode_vec.size(); ++i) {
911 for (size_t j = 0; j < cnode_vec[i]->size(); ++j) {
912 const auto &input_node = cnode_vec[i]->input(j)->cast<CNodePtr>();
913 if ((*cnode_to_task_map_ptr).count(input_node) == 0) continue;
914
915 ((*cnode_to_task_map_ptr)[cnode_vec[i]])->AddParent((*cnode_to_task_map_ptr)[input_node]);
916 ((*cnode_to_task_map_ptr)[input_node])->AddChild((*cnode_to_task_map_ptr)[cnode_vec[i]]);
917 MS_LOG(INFO) << "Edge " << (*cnode_to_task_map_ptr)[input_node]->id() << " "
918 << (*cnode_to_task_map_ptr)[cnode_vec[i]]->id();
919 MS_LOG(INFO) << "Edge (UniqueName) " << input_node->UniqueName() << " " << cnode_vec[i]->UniqueName();
920 }
921 }
922 }
923
924 SchedulingInput ExtractSchedulingInput(const FuncGraphManagerPtr &manager, const std::vector<CNodePtr> &cnode_vec,
925 std::unordered_map<CNodePtr, TaskPtr> *cnode_to_task_map_ptr) {
926 SchedulingInput scheduling_input; // to fill in and return
927
928 // Create a task per node
929 for (size_t i = 0; i < cnode_vec.size(); ++i) {
930 std::shared_ptr<Task> task1 =
931 std::make_shared<Task>(i, common::AnfAlgo::IsCommunicationOp(cnode_vec[i]) ? kComm : kComp);
932 MS_LOG(INFO) << "Start Assign Weight";
933 const auto &cnode = cnode_vec[i];
934 size_t output_num = AnfUtils::GetOutputTensorNum(cnode);
935 size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
936 Time weight = 0;
937
938 // For each operator we consider its inputs and outputs.
939 // For each input, we multiply the shape dimensions to get the element count and then multiply by the data type size.
940 // We sum these costs over all inputs.
941 // If the operator has more than one output, we do the same for the outputs;
942 // if it has exactly one output, the output cost is 0.
943 // The weight is the sum of the input costs and the output costs.
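// Illustrative example: a float32 input of shape [32, 128] contributes 32 * 128 * 4 = 16384 bytes to the weight.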
944 for (size_t j = 0; j < input_num; j++) {
945 KernelWithIndex kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(cnode, j);
946 if (dyn_cast<abstract::BaseShape>(kernel_with_index.first->Shape()) == nullptr ||
947 dyn_cast<Type>(kernel_with_index.first->Type()) == nullptr) {
948 MS_LOG(INFO) << "shape or type is nullptr, ignore";
949 continue;
950 }
951 ShapeVector shape = common::AnfAlgo::GetOutputInferShape(kernel_with_index.first, kernel_with_index.second);
952 if (shape.empty()) continue;
953
954 const TypeId type = common::AnfAlgo::GetOutputInferDataType(kernel_with_index.first, kernel_with_index.second);
955 if (type == kObjectTypeUMonad || type == kObjectTypeMonad || type == kObjectTypeFunction) continue;
956
957 size_t type_size = GetDataTypeSize(type);
958 // size_t type_size = convert_type_to_num(type);
959 weight += std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<size_t>()) * type_size;
960 }
961
962 if (output_num > 1) {
963 for (size_t j = 0; j < output_num; j++) {
964 ShapeVector shape = common::AnfAlgo::GetOutputInferShape(cnode, j);
965 if (shape.empty()) continue;
966
967 const TypeId type = common::AnfAlgo::GetOutputInferDataType(cnode, j);
968 if (type == kObjectTypeUMonad || type == kObjectTypeMonad || type == kObjectTypeFunction) continue;
969
970 size_t type_size = GetDataTypeSize(type);
971 // size_t type_size = convert_type_to_num(type);
972 weight += std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<size_t>()) * type_size;
973 }
974 }
975
976 if (weight < 0) MS_LOG(EXCEPTION) << "Weight < 0, replace by SIZE_MAX";
977
978 task1->AssignWeight(weight);
979 MS_LOG(INFO) << "End Assign Weight";
980
981 (*cnode_to_task_map_ptr)[cnode_vec[i]] = task1;
982 scheduling_input.tasks.push_back(task1);
983 MS_LOG(INFO) << "Task " << task1->id() << " with name " << cnode->UniqueName() << " and CNodePtr " << cnode_vec[i]
984 << " with weight " << weight;
985 }
986
987 // Insert task graph edges
988 InsertTaskGraph(cnode_vec, cnode_to_task_map_ptr);
989
990 return scheduling_input;
991 }
992
993 void AddRealDependencies(const FuncGraphManagerPtr &manager, const std::vector<CNodePtr> &cnode_vec,
994 const std::vector<std::pair<TaskId, TaskId>> &dependencies,
995 std::unordered_map<CNodePtr, TaskPtr> *cnode_to_task) {
996 size_t count = 0;
997 size_t redundant_count = 0;
998 for (const auto &dependency : dependencies) {
999 MS_LOG(INFO) << "Checking dependency " << dependency.first << " " << dependency.second;
1000 const auto &source = cnode_vec[dependency.first];
1001 const auto &dest = cnode_vec[dependency.second];
1002
1003 // Ignore dependencies already there
1004 if ((*cnode_to_task)[source]->HasChild((*cnode_to_task)[dest])) {
1005 MS_LOG(INFO) << "Dependency " << dependency.first << " " << dependency.second
1006 << " is redundant (already parent and child)";
1007 redundant_count++;
1008 continue;
1009 }
1010
1011 // Add dependency only between Comp (check also between Comm later)
1012 bool comp_comp = true;
1013 if (common::AnfAlgo::IsCommunicationOp(source) && common::AnfAlgo::IsCommunicationOp(dest)) {
1014 comp_comp = false;
1015 MS_LOG(INFO) << "Ignore Comm to Comm dependency " << dependency.first << " " << dependency.second;
1016 continue;
1017 }
1018
1019 // At least two inputs in destination node (input 0 is the node primitive)
1020 if (dest->size() < 2) {
1021 MS_LOG(INFO) << "Destination inputs size < 2: ignore";
1022 continue;
1023 }
1024
1025 // If destination node (comp) has comm inputs, make dependency involving one of them
1026 for (size_t j = 1; j < dest->size(); ++j) { // input 0 is node primitive: ignore
1027 if (!utils::isa<CNodePtr>(dest->input(j))) {
1028 MS_LOG(INFO) << "Not a cnodeptr at input " << j;
1029 continue;
1030 }
1031 if (comp_comp && !common::AnfAlgo::IsCommunicationOp(dest->input(j))) {
1032 MS_LOG(INFO) << "Dest " << dest << " Task " << dependency.second << " Input " << j
1033 << " is not CommunicationOp: Ignore";
1034 continue;
1035 }
1036 if (!comp_comp && common::AnfAlgo::IsCommunicationOp(dest->input(j))) {
1037 MS_LOG(INFO) << "Dest " << dest << " Task " << dependency.second << " Input " << j
1038 << " is CommunicationOp: Ignore";
1039 continue;
1040 }
1041
1042 // Insert a Depend node: dest's input j becomes Depend(input_node, source), forwarding input_node's value while forcing source to execute before dest
1043 const auto &input_node = dest->input(j)->cast<CNodePtr>();
1044 std::vector<AnfNodePtr> depend_inputs{NewValueNode(prim::kPrimDepend), input_node, source};
1045 auto depend_node = dest->func_graph()->NewCNode(depend_inputs);
1046 MS_EXCEPTION_IF_NULL(depend_node);
1047 depend_node->set_abstract(input_node->abstract()->Clone());
1048 depend_node->AddAttr("comp_comm_scheduling_depend", MakeValue(true));
1049 auto &nodes = manager->node_users()[input_node];
1050 auto it = std::find_if(nodes.begin(), nodes.end(), [dest](const auto &user) { return user.first == dest; });
1051 if (it != nodes.end()) {
1052 int idx = (*it).second;
1053 manager->SetEdge(dest, idx, depend_node);
1054 MS_LOG(INFO) << "Added dependency from " << dependency.first << ", unique name " << source->UniqueName()
1055 << ", to " << dependency.second << ", unique name " << dest->UniqueName();
1056 count++;
1057 break; // add dependency involving only one destination node input
1058 } else {
1059 MS_LOG(INFO) << "User index not found: Ignore dependency and continue";
1060 continue;
1061 }
1062 }
1063 }
1064 MS_LOG(INFO) << "Num of real dependencies added is " << count;
1065 MS_LOG(INFO) << "Num of redundant dependencies (HasChild) is " << redundant_count;
1066 return;
1067 }
1068
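// Pass entry point: when running in semi-auto or auto parallel mode with MS_ENABLE_FRONTEND_SCHEDULING_OPTIMIZATION
// set, this builds a task per CNode of every func_graph held by the manager, runs the greedy scheduler over all
// sorting criteria, converts the best schedule into dependency pairs, and materializes them as Depend edges.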
1069 void CompCommScheduling(const FuncGraphPtr &graph) {
1070 if (parallel::ParallelContext::GetInstance()->parallel_mode() != parallel::kSemiAutoParallel &&
1071 parallel::ParallelContext::GetInstance()->parallel_mode() != parallel::kAutoParallel) {
1072 return;
1073 }
1074 if (common::GetEnv("MS_ENABLE_FRONTEND_SCHEDULING_OPTIMIZATION") != "1") {
1075 return;
1076 }
1077 MS_EXCEPTION_IF_NULL(graph);
1078 auto manager = graph->manager();
1079 MS_LOG(INFO) << "Main graph pointer: " << graph;
1080 MS_EXCEPTION_IF_NULL(manager);
1081
1082 FuncGraphSet graphs = manager->func_graphs();
1083 for (const auto &subgraph : graphs) {
1084 MS_LOG(INFO) << "Start Scheduling Subgraph " << subgraph;
1085 std::stringstream graph_ss;
1086 graph_ss << subgraph;
1087 std::string graph_name = graph_ss.str();
1088 std::list<CNodePtr> cnode_list = subgraph->GetOrderedCnodes();
1089 std::vector<CNodePtr> cnode_vec(cnode_list.cbegin(), cnode_list.cend());
1090
1091 MS_LOG(INFO) << "Start ExtractSchedulingInput";
1092 std::unordered_map<CNodePtr, TaskPtr> cnode_to_task;
1093 SchedulingInput scheduling_input = ExtractSchedulingInput(manager, cnode_vec, &cnode_to_task);
1094 MS_LOG(INFO) << "End ExtractSchedulingInput";
1095
1096 auto scheduling_output = FastGreedyScheduler::Process(scheduling_input, graph_name);
1097 auto dependencies = FastGreedyScheduler::ScheduleToDependencies(scheduling_output);
1098
1099 MS_LOG(INFO) << "Start AddRealDependencies";
1100 AddRealDependencies(manager, cnode_vec, dependencies, &cnode_to_task);
1101 MS_LOG(INFO) << "End AddRealDependencies";
1102 }
1103 }
1104 } // namespace opt
1105 } // namespace mindspore
1106