1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <algorithm>
18 #include <unordered_map>
19 #include <set>
20 #include <map>
21 #include <deque>
22 #include <queue>
23 #include <iostream>
24 #include <string>
25 #include <cmath>
26 #include <iomanip>
27 #include <functional>
#include <cstdint>
#include <fstream>
#include <list>
#include <memory>
#include <numeric>
#include <sstream>
#include <string_view>
#include <utility>
#include <vector>
28 
29 #include "mindspore/ccsrc/frontend/parallel/pass/comp_comm_scheduling.h"
30 #include "mindspore/ccsrc/include/common/utils/utils.h"
31 #include "mindspore/ccsrc/frontend/parallel/step_parallel.h"
32 #include "mindspore/core/utils/misc.h"
33 #include "mindspore/core/utils/convert_utils_base.h"
34 
35 #include "include/backend/optimizer/helper.h"
36 
37 namespace mindspore {
38 namespace opt {
39 // Subroutines Implementing "Scheduling to Dependencies"
40 struct SortByStart {
41   bool operator()(const Interval &interval1, const Interval &interval2) const {
42     const auto &id1 = interval1.id;
43     const auto &start1 = interval1.start;
44     const auto &end1 = interval1.end;
45     const auto &id2 = interval2.id;
46     const auto &start2 = interval2.start;
47     const auto &end2 = interval2.end;
48     return start1 < start2 || (start1 == start2 && end1 < end2) || (start1 == start2 && end1 == end2 && id1 < id2);
49   }
50 };
51 
52 struct SortByEnd {
53   bool operator()(const Interval &interval1, const Interval &interval2) const {
54     const auto &id1 = interval1.id;
55     const auto &start1 = interval1.start;
56     const auto &end1 = interval1.end;
57     const auto &id2 = interval2.id;
58     const auto &start2 = interval2.start;
59     const auto &end2 = interval2.end;
60     return end1 < end2 || (end1 == end2 && start1 < start2) || (end1 == end2 && start1 == start2 && id1 < id2);
61   }
62 };
63 
64 bool Overlap(const Time &start1, const Time &end1, const Time &start2, const Time &end2) {
65   return (start1 >= start2 && start1 < end2) ||
66          (start2 >= start1 && start2 < end1);  // half-open intervals: sharing only an endpoint is not an overlap
67 }
68 
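// ScheduleToDependencies converts a time-based schedule into explicit precedence edges. For each task type,
// tasks are kept sorted by start time and by end time. For every task, the tasks that overlap it in time are
// dismissed, the earliest end among the remaining (non-overlapping) tasks is taken, and a dependency edge is
// emitted to every later task that starts before that end, i.e. to the task's immediate right neighborhood.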
69 std::vector<std::pair<TaskId, TaskId>> FastGreedyScheduler::ScheduleToDependencies(const SchedulingOutput &schedule) {
70   std::vector<std::pair<TaskId, TaskId>> dependencies;  // to return
71   MS_LOG(INFO) << "Started Preprocessing of Intervals";
72   // Distinguish types and sort
73   std::unordered_map<TaskType, std::set<Interval, SortByStart>> tasks_start;
74   std::unordered_map<TaskType, std::set<Interval, SortByEnd>> tasks_end;
75   for (const auto &task_time : schedule.task_times) {
76     tasks_start[task_time.type].insert(task_time);
77     tasks_end[task_time.type].insert(task_time);
78   }
79   MS_LOG(INFO) << "Finished Preprocessing of Intervals";
80   MS_LOG(INFO) << "Started Main Loop";
81   // Main loop: check each task for potential dependencies in its right neighborhood
82   for (const auto &type_to_set : tasks_start) {
83     const auto &type = type_to_set.first;
84     for (auto it = tasks_start[type].begin(); it != tasks_start[type].end(); ++it) {
85       tasks_end[type].erase(*it);
86       // Dismiss overlapping tasks: save min end value of non-overlapping task to the right
87       std::unordered_map<TaskId, bool> dismissed;
88       auto it1 = std::next(it);
89       for (; it1 != tasks_start[type].end() && Overlap(it->start, it->end, it1->start, it1->end); ++it1) {
90         dismissed[it1->id] = true;
91       }
92       Time min_end_value = 0;
93       for (auto it2 = tasks_end[type].begin(); it2 != tasks_end[type].end(); ++it2) {
94         if (!dismissed[it2->id]) {
95           min_end_value = it2->end;
96           break;
97         }
98       }
99       // Add dependencies to immediate right neighborhood
100       for (; it1 != tasks_start[type].end() && it1->start < min_end_value; ++it1) {
101         dependencies.emplace_back(it->id, it1->id);
102       }
103     }
104   }
105   MS_LOG(INFO) << "Finished Main Loop";
106   MS_LOG(INFO) << "Generated " << dependencies.size() << " dependencies";
107   return dependencies;
108 }
109 
110 // Sorting for tasks
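// Each comparator below uses a primary criterion, breaks ties by weight where applicable, and always falls back
// to the task id, so it defines a strict weak ordering and can be used as the comparator of an ordered container.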
111 bool SortByWeightMax(const std::shared_ptr<Task> &task1, const std::shared_ptr<Task> &task2) {
112   return task1->weight() > task2->weight() || (task1->weight() == task2->weight() && task1->id() < task2->id());
113 }
114 
115 bool SortByWeightMin(const std::shared_ptr<Task> &task1, const std::shared_ptr<Task> &task2) {
116   return task1->weight() < task2->weight() || (task1->weight() == task2->weight() && task1->id() < task2->id());
117 }
118 
119 bool SortBySuccDiff(const std::shared_ptr<Task> &task1, const std::shared_ptr<Task> &task2) {
120   return task1->succ_diff_type() > task2->succ_diff_type() ||
121          (task1->succ_diff_type() == task2->succ_diff_type() && task1->weight() > task2->weight()) ||
122          (task1->succ_diff_type() == task2->succ_diff_type() && task1->weight() == task2->weight() &&
123           task1->id() < task2->id());
124 }
125 
126 bool SortByBottomLevelMax(const std::shared_ptr<Task> &task1, const std::shared_ptr<Task> &task2) {
127   return task1->bottom_level() > task2->bottom_level() ||
128          (task1->bottom_level() == task2->bottom_level() && task1->weight() > task2->weight()) ||
129          (task1->bottom_level() == task2->bottom_level() && task1->weight() == task2->weight() &&
130           task1->id() < task2->id());
131 }
132 
133 bool SortByBottomLevelMin(const std::shared_ptr<Task> &task1, const std::shared_ptr<Task> &task2) {
134   return task1->bottom_level() < task2->bottom_level() ||
135          (task1->bottom_level() == task2->bottom_level() && task1->weight() > task2->weight()) ||
136          (task1->bottom_level() == task2->bottom_level() && task1->weight() == task2->weight() &&
137           task1->id() < task2->id());
138 }
139 
140 bool SortByTopLevelMax(const std::shared_ptr<Task> &task1, const std::shared_ptr<Task> &task2) {
141   return task1->top_level() > task2->top_level() ||
142          (task1->top_level() == task2->top_level() && task1->weight() > task2->weight()) ||
143          (task1->top_level() == task2->top_level() && task1->weight() == task2->weight() && task1->id() < task2->id());
144 }
145 
146 bool SortByTopLevelMin(const std::shared_ptr<Task> &task1, const std::shared_ptr<Task> &task2) {
147   return task1->top_level() < task2->top_level() ||
148          (task1->top_level() == task2->top_level() && task1->weight() > task2->weight()) ||
149          (task1->top_level() == task2->top_level() && task1->weight() == task2->weight() && task1->id() < task2->id());
150 }
151 
152 bool SortByBottomTopLevelMaxSum(const std::shared_ptr<Task> &task1, const std::shared_ptr<Task> &task2) {
153   return task1->top_level() + task1->bottom_level() > task2->top_level() + task2->bottom_level() ||
154          (task1->top_level() + task1->bottom_level() == task2->top_level() + task2->bottom_level() &&
155           task1->weight() > task2->weight()) ||
156          (task1->top_level() + task1->bottom_level() == task2->top_level() + task2->bottom_level() &&
157           task1->weight() == task2->weight() && task1->id() < task2->id());
158 }
159 
160 bool SortByBottomTopLevelMinSum(const std::shared_ptr<Task> &task1, const std::shared_ptr<Task> &task2) {
161   return task1->top_level() + task1->bottom_level() < task2->top_level() + task2->bottom_level() ||
162          (task1->top_level() + task1->bottom_level() == task2->top_level() + task2->bottom_level() &&
163           task1->weight() > task2->weight()) ||
164          (task1->top_level() + task1->bottom_level() == task2->top_level() + task2->bottom_level() &&
165           task1->weight() == task2->weight() && task1->id() < task2->id());
166 }
167 
168 // Ishfaq Ahmad, Yu-Kwong Kwok, and Min-You Wu.
169 // Analysis, evaluation, and comparison of algorithms for scheduling task graphs on parallel processors.
170 // Second International Symposium on Parallel Architectures, Algorithms, and Networks (I-SPAN '96),
171 // pages 207-213. IEEE, 1996.
172 bool SortByBottomTopLevelComposite(const std::shared_ptr<Task> &task1, const std::shared_ptr<Task> &task2) {
173   return task1->bottom_level() - task1->top_level() > task2->bottom_level() - task2->top_level() ||
174          (task1->bottom_level() - task1->top_level() == task2->bottom_level() - task2->top_level() &&
175           task1->weight() > task2->weight()) ||
176          (task1->bottom_level() - task1->top_level() == task2->bottom_level() - task2->top_level() &&
177           task1->weight() == task2->weight() && task1->id() < task2->id());
178 }
179 
180 // Behrooz Shirazi, Mingfang Wang, and Girish Pathak.
181 // Analysis and evaluation of heuristic methods for static task scheduling.
182 // Journal of Parallel and Distributed Computing, 10(3):222-232, 1990.
183 bool SortByWeightedLength(const std::shared_ptr<Task> &task1, const std::shared_ptr<Task> &task2) {
184   return task1->weighted_length() > task2->weighted_length() ||
185          (task1->weighted_length() == task2->weighted_length() && task1->id() < task2->id());
186 }
187 
188 // DFS with weights for tie breaking
189 bool SortByDepthMax(const std::shared_ptr<Task> &task1, const std::shared_ptr<Task> &task2) {
190   return task1->depth() > task2->depth() || (task1->depth() == task2->depth() && task1->weight() > task2->weight()) ||
191          (task1->depth() == task2->depth() && task1->weight() == task2->weight() && task1->id() < task2->id());
192 }
193 
194 // BFS with weights for tie breaking
195 bool SortByDepthMin(const std::shared_ptr<Task> &task1, const std::shared_ptr<Task> &task2) {
196   return task1->depth() < task2->depth() || (task1->depth() == task2->depth() && task1->weight() > task2->weight()) ||
197          (task1->depth() == task2->depth() && task1->weight() == task2->weight() && task1->id() < task2->id());
198 }
199 
200 // Sort by number of Comm predecessors (fewest first), then by bottom level
201 bool SortByPredComm(const std::shared_ptr<Task> &task1, const std::shared_ptr<Task> &task2) {
202   return task1->pred_comm() < task2->pred_comm() ||
203          (task1->pred_comm() == task2->pred_comm() && task1->bottom_level() > task2->bottom_level()) ||
204          (task1->pred_comm() == task2->pred_comm() && task1->bottom_level() == task2->bottom_level() &&
205           task1->id() < task2->id());
206 }
207 
208 // Sort by number of Comm predecessors (fewest first), then by depth (DFS-like)
209 bool SortByPredCommDepth(const std::shared_ptr<Task> &task1, const std::shared_ptr<Task> &task2) {
210   return task1->pred_comm() < task2->pred_comm() ||
211          (task1->pred_comm() == task2->pred_comm() && task1->depth() > task2->depth()) ||
212          (task1->pred_comm() == task2->pred_comm() && task1->depth() == task2->depth() && task1->id() < task2->id());
213 }
214 
215 // Sorting by load for processing elements
216 struct SortByLoad {
217   bool operator()(const ProcessingElement &pe1, const ProcessingElement &pe2) const {
218     return pe1.load < pe2.load || (pe1.load == pe2.load && pe1.id < pe2.id);
219   }
220 };
221 
222 // Get PEs description
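// Test configuration: expose one processing element (stream) per task type, i.e. one for computation and one for
// communication.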
223 std::unordered_map<TaskType, int32_t> GetTestPEs() {
224   std::unordered_map<TaskType, int32_t> new_pem;
225   new_pem[kComm] = 1;
226   new_pem[kComp] = 1;
227   return new_pem;
228 }
229 
230 // Auxiliary subroutines and lower bounds
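// Kahn-style forward traversal over the task DAG: depth is the maximum number of edges from any source task,
// and top_level is the maximum accumulated parallel_weight along any path from a source, including the task itself.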
231 void FastGreedyScheduler::ComputeDepthAndTopLevel(std::vector<std::shared_ptr<Task>> &tasks) {
232   MS_LOG(INFO) << "Top Level: Start Initialization";
233   std::unordered_map<TaskId, size_t> unprocessed_parents;
234   std::queue<std::shared_ptr<Task>> tasks_to_visit;
235   // Initialization loop
236   for (size_t j = 0; j < tasks.size(); ++j) {
237     const auto &id = tasks[j]->id();
238     unprocessed_parents[id] = tasks[j]->parents().size();
239     if (unprocessed_parents[id] == 0) {
240       tasks[j]->set_top_level(tasks[j]->parallel_weight());
241       tasks_to_visit.push(tasks[j]);
242     }
243   }
244   MS_LOG(INFO) << "Top Level: End Initialization";
245   MS_LOG(INFO) << "Top Level: Start Traversal Loop";
246   while (!tasks_to_visit.empty()) {
247     const auto &selected_task = tasks_to_visit.front();
248     // Update candidate tasks
249     for (auto &successor : selected_task->children()) {
250       const auto &succ_id = successor->id();
251       successor->set_depth(std::max(successor->depth(), selected_task->depth() + 1));
252       successor->set_top_level(
253         std::max(successor->top_level(), selected_task->top_level() + successor->parallel_weight()));
254       unprocessed_parents[succ_id] -= 1;
255       if (unprocessed_parents[succ_id] == 0) {
256         tasks_to_visit.push(successor);
257       }
258     }
259     tasks_to_visit.pop();
260   }
261   MS_LOG(INFO) << "Top Level: End Traversal Loop";
262 }
263 
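// Reverse (sink-to-source) traversal: bottom_level is the maximum accumulated parallel_weight along any path to a
// sink, including the task itself; weighted_length additionally mixes in the children's weighted lengths
// (parallel_weight + max over children + sum over children / max over children).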
264 void FastGreedyScheduler::ComputeBottomLevelAndWeightedLength(std::vector<std::shared_ptr<Task>> &tasks) {
265   MS_LOG(INFO) << "Bottom Level: Start Initialization";
266   std::unordered_map<TaskId, size_t> unprocessed_children;
267   std::unordered_map<TaskId, double> children_sum;
268   std::unordered_map<TaskId, double> children_max;
269   std::queue<std::shared_ptr<Task>> tasks_to_visit;
270   // Initialization loop
271   for (auto &task : tasks) {
272     const auto &id = task->id();
273     task->set_bottom_level(task->parallel_weight());
274     task->set_weighted_length(task->parallel_weight());
275     unprocessed_children[id] = task->children().size();
276     if (unprocessed_children[id] == 0) {
277       tasks_to_visit.push(task);
278     }
279   }
280   MS_LOG(INFO) << "Bottom Level: End Initialization";
281   MS_LOG(INFO) << "Bottom Level: Start Traversal Loop";
282   while (!tasks_to_visit.empty()) {
283     const auto &selected_task = tasks_to_visit.front();
284     // Update candidate tasks
285     for (auto &predecessor : selected_task->parents()) {
286       const auto &pred_id = predecessor.lock()->id();
287       predecessor.lock()->set_bottom_level(std::max(
288         predecessor.lock()->bottom_level(), selected_task->bottom_level() + predecessor.lock()->parallel_weight()));
289       children_sum[pred_id] += selected_task->weighted_length();
290       children_max[pred_id] = std::max(children_max[pred_id], selected_task->weighted_length());
291       unprocessed_children[pred_id] -= 1;
292       if (unprocessed_children[pred_id] == 0) {
293         if (children_max[pred_id] == 0) {
294           MS_LOG(EXCEPTION) << "divisor children_max[pred_id] cannot be 0!";
295         }
296         predecessor.lock()->set_weighted_length(predecessor.lock()->parallel_weight() + children_max[pred_id] +
297                                                 children_sum[pred_id] / children_max[pred_id]);
298         tasks_to_visit.push(predecessor.lock());
299       }
300     }
301     tasks_to_visit.pop();
302   }
303   MS_LOG(INFO) << "Bottom Level: End Traversal Loop";
304 }
305 
306 void FastGreedyScheduler::ComputePredComm(std::vector<std::shared_ptr<Task>> &tasks) {
307   for (auto &task : tasks) {
308     task->set_pred_comm(0);
309     for (auto &predecessor : task->parents()) {
310       if (predecessor.lock()->type() == kComm) {
311         task->set_pred_comm(task->pred_comm() + 1);
312       }
313     }
314   }
315 }
316 
317 Time FastGreedyScheduler::LowerBoundBottomLevel(std::vector<std::shared_ptr<Task>> &tasks) {
318   Time max_bottom_level = 0;
319   for (const auto &task : tasks) {
320     max_bottom_level = std::max(max_bottom_level, task->bottom_level());
321   }
322   return max_bottom_level;
323 }
324 
325 Time FastGreedyScheduler::LowerBoundPEs(std::vector<std::shared_ptr<Task>> &tasks,
326                                         std::unordered_map<TaskType, int32_t> &type_to_num_cores_map) {
327   double lower_bound = 0;
328 
329   std::unordered_map<TaskType, Time> type_task_sum;
330   for (const auto &task : tasks) {
331     type_task_sum[task->type()] += task->weight();
332   }
333   for (const auto &type_to_num : type_to_num_cores_map) {
334     const auto &type = type_to_num.first;
335     const auto &num_cores = type_to_num.second;
336     if (num_cores == 0) {
337       MS_LOG(EXCEPTION) << "divisor num_cores cannot be 0!";
338     }
339     lower_bound = std::max(lower_bound, type_task_sum[type] / (1.0 * num_cores));
340   }
341   return std::ceil(lower_bound);
342 }
343 
344 // Main algorithms/subroutines
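// SelectPEandTime scans processing elements of the task's type in load order and greedily places the task in the
// first idle interval that can hold it, no earlier than can_start. The chosen idle interval is shrunk or split.
// For example, placing a task of weight 3 with can_start = 5 into the idle interval [2, 10) leaves [2, 5) and [8, 10).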
345 std::pair<PeId, Time> SelectPEandTime(const Task &task, Time can_start,
346                                       std::set<ProcessingElement, SortByLoad> *PEs_ptr) {
347   auto &PEs = *PEs_ptr;
348   std::pair<PeId, Time> return_pair = std::make_pair(0, 0);
349   for (auto it = PEs.begin(); it != PEs.end(); ++it) {
350     // unsafe use of const_cast, but we modify only idle list and not key sorting parameters like load, id, etc.
351     // cf: https://stackoverflow.com/questions/43340050/modification-of-elements-of-stdset-defined-behavior
352     auto &mut_pe = const_cast<ProcessingElement &>(*it);
353     // Put in first idle that fits it
354     for (auto idle_it = mut_pe.idle.begin(); idle_it != mut_pe.idle.end(); ++idle_it) {
355       Time start_time;
356       bool case_flag = false;
357       // Distinguish cases based on can_start constraint
358       if (can_start <= idle_it->first) {
359         start_time = idle_it->first;
360       } else if (can_start <= idle_it->second) {
361         start_time = can_start;
362         case_flag = true;
363       } else {  // can_start > idle_it->second means we are not allowed to schedule the task here
364         continue;
365       }
366       // If the task fits, then place it here
367       if (idle_it->second - start_time >= task.weight()) {
368         // Save info to return: start task at time idle_it->first
369         return_pair.first = (*it).id;
370         return_pair.second = start_time;
371         // Update idle list
372         if (!case_flag) {
373           if (idle_it->second - idle_it->first == task.weight()) {  // whole idle interval is filled in, erase it
374             mut_pe.idle.erase(idle_it);
375           } else {  // idle_it->second - idle_it->first > task.weight()
376             idle_it->first += task.weight();
377           }
378         } else {  // case_flag = true, idle interval is broken into two sub-blocks [idle_it->first, can_start] and
379                   // (maybe empty) [can_start + weight, idle_it->second]
380           Time upper = idle_it->second;
381           idle_it->second = can_start;
382           if (upper - can_start - task.weight() > 0) {
383             std::pair<Time, Time> new_idle = std::make_pair(can_start + task.weight(), upper);
384             mut_pe.idle.emplace(std::next(idle_it), new_idle);
385           }
386         }
387         // Update load and PEs set
388         auto updated_PE = PEs.extract(it);
389         updated_PE.value().load += task.weight();
390         PEs.insert(std::move(updated_PE));
391         return return_pair;
392       }
393     }
394   }
395   return return_pair;
396 }
397 
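// SelectPEandTimeAvailableStart differs from SelectPEandTime: it scans every processing element of the type and
// picks the globally earliest feasible start time instead of favoring the least-loaded PE. It assumes each PE was
// initialized with an idle interval reaching SIZE_MAX, so a feasible slot always exists; otherwise min_it below
// would remain uninitialized.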
398 std::pair<PeId, Time> SelectPEandTimeAvailableStart(const Task &task, Time can_start,
399                                                     std::vector<ProcessingElement> *PEs_ptr) {
400   auto &PEs = *PEs_ptr;
401   // Precompute min first available start for task
402   Time min_start = SIZE_MAX;
403   bool min_case = false;
404   std::vector<ProcessingElement>::iterator min_it;
405   std::list<std::pair<Time, Time>>::iterator min_idle_it;
406   for (auto it = PEs.begin(); it != PEs.end(); ++it) {
407     for (auto idle_it = it->idle.begin(); idle_it != it->idle.end(); ++idle_it) {
408       Time start_time;
409       bool case_flag = false;
410       // Distinguish cases based on can_start constraint
411       if (can_start <= idle_it->first) {
412         start_time = idle_it->first;
413       } else if (can_start <= idle_it->second) {
414         start_time = can_start;
415         case_flag = true;
416       } else {  // can_start > idle_it->second means we are not allowed to schedule the task here
417         continue;
418       }
419       if (idle_it->second - start_time >= task.weight()) {
420         if (min_start > start_time) {
421           min_start = start_time;
422           min_case = case_flag;
423           min_it = it;
424           min_idle_it = idle_it;
425           break;
426         }
427       }
428     }
429   }
430   // Assign task to min PE
431   std::pair<PeId, Time> return_pair = std::make_pair(0, 0);
432   // Save info to return: start task at time idle_it->first
433   return_pair.first = (*min_it).id;
434   return_pair.second = min_start;
435   // Update idle list
436   if (!min_case) {
437     if (min_idle_it->second - min_idle_it->first == task.weight()) {  // whole idle interval is filled in, erase it
438       min_it->idle.erase(min_idle_it);
439     } else {  // idle_it->second - idle_it->first > task.weight()
440       min_idle_it->first += task.weight();
441     }
442   } else {  // min_case = true, idle interval is broken into two sub-blocks [idle_it->first, can_start] and
443             // (maybe empty)[can_start + task.weight(), idle_it->second]
444     Time upper = min_idle_it->second;
445     min_idle_it->second = can_start;
446     if (upper - can_start - task.weight() > 0) {
447       std::pair<Time, Time> new_idle = std::make_pair(can_start + task.weight(), upper);
448       min_it->idle.emplace(std::next(min_idle_it), new_idle);
449     }
450   }
451   // Update load
452   min_it->load += task.weight();
453   return return_pair;
454 }
455 
456 constexpr TaskSortFunction THREAD_SORT[] = {SortByWeightMax,
457                                             SortByWeightMin,
458                                             SortBySuccDiff,
459                                             SortByBottomLevelMax,
460                                             SortByBottomLevelMin,
461                                             SortByTopLevelMax,
462                                             SortByTopLevelMin,
463                                             SortByBottomTopLevelMaxSum,
464                                             SortByBottomTopLevelMinSum,
465                                             SortByBottomTopLevelComposite,
466                                             SortByWeightedLength,
467                                             SortByDepthMax,
468                                             SortByDepthMin,
469                                             SortByPredComm,
470                                             SortByPredCommDepth};
471 
472 constexpr std::string_view THREAD_SORT_NAMES[] = {"SortByWeightMax",
473                                                   "SortByWeightMin",
474                                                   "SortBySuccDiff",
475                                                   "SortByBottomLevelMax",
476                                                   "SortByBottomLevelMin",
477                                                   "SortByTopLevelMax",
478                                                   "SortByTopLevelMin",
479                                                   "SortByBottomTopLevelMaxSum",
480                                                   "SortByBottomTopLevelMinSum",
481                                                   "SortByBottomTopLevelComposite",
482                                                   "SortByWeightedLength",
483                                                   "SortByDepthMax",
484                                                   "SortByDepthMin",
485                                                   "SortByPredComm",
486                                                   "SortByPredCommDepth"};
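// THREAD_SORT and THREAD_SORT_NAMES are index-aligned; keep the two arrays in the same order.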
487 
488 enum class PEsSort { kSortByLoad = 0, kSortByValidStart, kNumPEsSort };
489 
490 constexpr std::string_view PE_NAME_SORT[] = {"SortByLoad", "SortByValidStart"};
491 
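// Process tries every combination of task-sorting heuristic (THREAD_SORT) and PE-selection rule (by load or by
// earliest available start), keeps the schedule with the smallest makespan, restores that schedule's start/end
// times on the tasks, and finally derives and logs the corresponding dependencies.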
492 SchedulingOutput FastGreedyScheduler::Process(SchedulingInput &input, const std::string &graph_name) {
493   std::vector<std::shared_ptr<Task>> *tasks = &(input.tasks);
494   auto type_to_num_cores_map = GetTestPEs();
495   SchedulingOutput output{{}, SIZE_MAX};
496   // Optional: verify input task graph is a DAG
497   if (VerifyDAG(*tasks)) {
498     MS_LOG(INFO) << "Verification of DAG: SUCCESS";
499   } else {
500     MS_LOG(INFO) << "Verification of DAG: FAILURE";
501   }
502 
503   // Preprocessing: compute the values needed by the sorting criteria
504   ComputeBottomLevelAndWeightedLength(*tasks);
505   ComputeDepthAndTopLevel(*tasks);
506   ComputePredComm(*tasks);
507 
508   // Loop over all sorting combinations
509   std::unordered_map<std::shared_ptr<Task>, Time> best_start;
510   std::unordered_map<std::shared_ptr<Task>, Time> best_end;  // to use in verify dependencies only
511   std::string best_solution;
512   MS_LOG(INFO) << "Start loop multiple scheduling functions";
513   for (size_t task_sort = 0; task_sort < static_cast<size_t>(kNumTaskSort); ++task_sort) {
514     for (size_t pes_sort = 0; pes_sort < static_cast<size_t>(PEsSort::kNumPEsSort); ++pes_sort) {
515       MS_LOG(INFO) << THREAD_SORT_NAMES[task_sort] << " and " << PE_NAME_SORT[pes_sort];
516       SchedulingOutput solution = ProcessCore(*tasks, type_to_num_cores_map, THREAD_SORT[task_sort],
517                                               (pes_sort == static_cast<size_t>(PEsSort::kSortByLoad)));
518       if (solution.makespan < output.makespan) {
519         output = solution;
520         best_solution = THREAD_SORT_NAMES[task_sort];
521         for (const auto &task : *tasks) {  // to use in verify dependencies only
522           best_start[task] = task->start();
523           best_end[task] = task->end();
524         }
525       }
526       for (const auto &task : *tasks) {
527         task->ResetStartEnd();
528       }
529     }
530   }
531   MS_LOG(INFO) << "End loop multiple scheduling functions";
532 
533   // Print stats about best solution
534   MS_LOG(INFO) << "Best solution is: " << best_solution;
535   MS_LOG(INFO) << "Makespan of best solution is " << output.makespan;
536   MS_LOG(INFO) << "Bottom level lower bound is " << LowerBoundBottomLevel(*tasks);
537   MS_LOG(INFO) << "Max type lower bound is " << LowerBoundPEs(*tasks, type_to_num_cores_map);
538   MS_LOG(INFO) << "Solution relative error is " << std::setprecision(5)
539                << ((output.makespan /
540                       (1.0 * std::max(LowerBoundBottomLevel(*tasks), LowerBoundPEs(*tasks, type_to_num_cores_map))) -
541                     1) *
542                    100)
543                << "%";
544 
545   // Create and (optionally) verify dependencies (here only for testing)
546   MS_LOG(INFO) << "Start Schedule to Dependencies";
547   auto dependencies = ScheduleToDependencies(output);
548   MS_LOG(INFO) << "End Schedule to Dependencies";
549   for (const auto &task : *tasks) {
550     task->set_start(best_start[task]);
551     task->set_end(best_end[task]);
552   }
553 
554   // Output log file with all info (scheduling and dependencies)
555   MS_LOG(INFO) << "Start printing output log file";
556   PrintLog(output, dependencies, graph_name);
557   MS_LOG(INFO) << "End printing output log file";
558 
559   return output;
560 }
561 
562 SchedulingOutput FastGreedyScheduler::ProcessCore(std::vector<std::shared_ptr<Task>> &tasks,
563                                                   std::unordered_map<TaskType, int32_t> &type_to_num_cores_map,
564                                                   const TaskSortFunction &sortPtr, bool pe_load_sort) {
565   SchedulingOutput output{{}, 0};
566 
567   // Initializations for tasks
568   MS_LOG(INFO) << "Started Task Initialization";
569   std::set<std::shared_ptr<Task>, TaskSortFunction> candidate_tasks(sortPtr);
570   std::unordered_map<TaskId, Time> can_start;
571   std::unordered_map<TaskId, size_t> unprocessed_parents;
572   for (auto &task : tasks) {
573     const auto &id = task->id();
574     can_start[id] = 0;
575     unprocessed_parents[id] = task->parents().size();
576     if (unprocessed_parents[id] == 0) {
577       candidate_tasks.insert(task);
578     }
579   }
580   MS_LOG(INFO) << "Finished Task Initialization";
581 
582   // Initializations for processing elements
583   // Pick a sorting for processing elements
584   // Implemented: SortByLoad, SortByAvailableStart
585   // Only one of the two structures is used, depending on the argument; both are defined here
586   std::unordered_map<TaskType, std::set<ProcessingElement, SortByLoad>> PEs_load;
587   std::unordered_map<TaskType, std::vector<ProcessingElement>> PEs_start;
588   MS_LOG(INFO) << "Started Processing Element Initialization";
589   size_t count = 0;
590   for (const auto &type_to_num : type_to_num_cores_map) {
591     const auto &type = type_to_num.first;
592     const auto &num_cores = type_to_num.second;
593     for (int i = 0; i < num_cores; ++i) {
594       ProcessingElement new_pe;
595       new_pe.id = count + IntToSize(i);
596       new_pe.type = type;
597       new_pe.load = 0;
598       new_pe.idle.emplace_back(0, SIZE_MAX);
599       if (pe_load_sort) {
600         PEs_load[type].insert(new_pe);
601       } else {
602         PEs_start[type].push_back(new_pe);
603       }
604     }
605     count += num_cores;
606   }
607   MS_LOG(INFO) << "Finished Processing Element Initialization";
608 
609   // Task graph scheduling loop
610   MS_LOG(INFO) << "Started Scheduling Main Loop";
611   while (!candidate_tasks.empty()) {
612     // Select task and schedule it, save info for output
613     const auto selected_task = *(candidate_tasks.begin());
614     const auto &selected_id = selected_task->id();
615     // Selected PE and start time
616     std::pair<PeId, Time> PE_and_time;
617     if (pe_load_sort) {
618       PE_and_time = SelectPEandTime(*selected_task, can_start[selected_id], &PEs_load[selected_task->type()]);
619     } else {
620       PE_and_time =
621         SelectPEandTimeAvailableStart(*selected_task, can_start[selected_id], &PEs_start[selected_task->type()]);
622     }
623 
624     const auto &sigma = PE_and_time.second;
625 
626     // Maintenance of task interval
627     selected_task->set_start(sigma);
628     selected_task->set_end(sigma + selected_task->weight());
629     // New interval for task in output
630     Interval new_interval{selected_id, selected_task->type(), selected_task->start(), selected_task->end()};
631     output.task_times.push_back(new_interval);
632     // Update makespan
633     output.makespan = std::max(output.makespan, selected_task->end());
634     // Update candidate tasks
635     candidate_tasks.erase(selected_task);
636     for (const auto &successor : selected_task->children()) {
637       const auto &succ_id = successor->id();
638       can_start[succ_id] = std::max(can_start[succ_id], selected_task->end());
639       unprocessed_parents[succ_id] -= 1;
640       if (unprocessed_parents[succ_id] == 0) {
641         candidate_tasks.insert(successor);
642       }
643     }
644   }
645   MS_LOG(INFO) << "Finished Scheduling Main Loop";
646   MS_LOG(INFO) << "Makespan is " << output.makespan;
647   // Verification of scheduling solution (optional)
648   if (VerifyScheduling(tasks)) {
649     MS_LOG(INFO) << "Verification of Scheduling: SUCCESS";
650   } else {
651     MS_LOG(INFO) << "Verification of Scheduling: FAILURE";
652   }
653 
654   return output;
655 }
656 
657 SchedulingOutput FastGreedyScheduler::ProcessSingle(const SchedulingInput &input, const TaskSortFunction &sortPtr,
658                                                     bool pe_load_sort, const std::string &graph_name) {
659   auto tasks = input.tasks;
660   auto type_to_num_cores_map = GetTestPEs();
661   SchedulingOutput output{{}, 0};
662   // Optional: verify input task graph is a DAG
663   if (VerifyDAG(tasks)) {
664     MS_LOG(INFO) << "Verification of DAG: SUCCESS";
665   } else {
666     MS_LOG(INFO) << "Verification of DAG: FAILURE";
667   }
668   // Preprocessing: compute the values needed by the sorting criteria
669   ComputeBottomLevelAndWeightedLength(tasks);
670   ComputeDepthAndTopLevel(tasks);
671   // Initializations for tasks
672   MS_LOG(INFO) << "Started Task Initialization";
673   std::set<std::shared_ptr<Task>, TaskSortFunction> candidate_tasks(sortPtr);
674   std::unordered_map<TaskId, Time> can_start;
675   std::unordered_map<TaskId, size_t> unprocessed_parents;
676   for (auto &task : tasks) {
677     const auto &id = task->id();
678     can_start[id] = 0;
679     unprocessed_parents[id] = task->parents().size();
680     if (unprocessed_parents[id] == 0) {
681       candidate_tasks.insert(task);
682     }
683   }
684   MS_LOG(INFO) << "Finished Task Initialization";
685 
686   // Initializations for processing elements
687   // Pick a sorting for processing elements
688   // Implemented: SortByLoad, SortByAvailableStart
689   // Only one of the two structures is used, depending on the argument; both are defined here
690   std::unordered_map<TaskType, std::set<ProcessingElement, SortByLoad>> PEs_load;
691   std::unordered_map<TaskType, std::vector<ProcessingElement>> PEs_start;
692   MS_LOG(INFO) << "Started Processing Element Initialization";
693   size_t count = 0;
694   for (const auto &type_to_num : type_to_num_cores_map) {
695     const auto &type = type_to_num.first;
696     const auto &num_cores = type_to_num.second;
697     for (int i = 0; i < num_cores; ++i) {
698       ProcessingElement new_pe;
699       new_pe.id = count + i;
700       new_pe.type = type;
701       new_pe.load = 0;
702       new_pe.idle.emplace_back(0, SIZE_MAX);
703       if (pe_load_sort) {
704         PEs_load[type].insert(new_pe);
705       } else {
706         PEs_start[type].push_back(new_pe);
707       }
708     }
709     count += num_cores;
710   }
711   MS_LOG(INFO) << "Finished Processing Element Initialization";
712 
713   // Task graph scheduling loop
714   MS_LOG(INFO) << "Started Scheduling Main Loop";
715   while (!candidate_tasks.empty()) {
716     // Select task and schedule it, save info for output
717     const auto selected_task = *(candidate_tasks.begin());
718     const auto &selected_id = selected_task->id();
719     // Selected PE and start time
720     std::pair<PeId, Time> PE_and_time;
721     if (pe_load_sort) {
722       PE_and_time = SelectPEandTime(*selected_task, can_start[selected_id], &PEs_load[selected_task->type()]);
723     } else {
724       PE_and_time =
725         SelectPEandTimeAvailableStart(*selected_task, can_start[selected_id], &PEs_start[selected_task->type()]);
726     }
727     const auto &sigma = PE_and_time.second;
728 
729     // Maintenance of task interval
730     selected_task->set_start(sigma);
731     selected_task->set_end(sigma + selected_task->weight());
732     // New interval for task in output
733     Interval new_interval{selected_id, selected_task->type(), selected_task->start(), selected_task->end()};
734     output.task_times.push_back(new_interval);
735     // Update makespan
736     output.makespan = std::max(output.makespan, selected_task->end());
737     // Update candidate tasks
738     candidate_tasks.erase(selected_task);
739     for (const auto &successor : selected_task->children()) {
740       const auto &succ_id = successor->id();
741       can_start[succ_id] = std::max(can_start[succ_id], selected_task->end());
742       unprocessed_parents[succ_id] -= 1;
743       if (unprocessed_parents[succ_id] == 0) {
744         candidate_tasks.insert(successor);
745       }
746     }
747   }
748   MS_LOG(INFO) << "Finished Scheduling Main Loop";
749   MS_LOG(INFO) << "Makespan is " << output.makespan;
750   MS_LOG(INFO) << "Bottom level lower bound is " << LowerBoundBottomLevel(tasks);
751   MS_LOG(INFO) << "Max type lower bound is " << LowerBoundPEs(tasks, type_to_num_cores_map);
752   MS_LOG(INFO) << "Solution relative error is " << std::setprecision(5)
753                << ((output.makespan /
754                       (1.0 * std::max(LowerBoundBottomLevel(tasks), LowerBoundPEs(tasks, type_to_num_cores_map))) -
755                     1) *
756                    100)
757                << "%";
758   // Verification of scheduling solution (optional)
759   if (VerifyScheduling(tasks)) {
760     MS_LOG(INFO) << "Verification of Scheduling: SUCCESS";
761   } else {
762     MS_LOG(INFO) << "Verification of Scheduling: FAILURE";
763   }
764   // Scheduling to Dependencies (here only for testing)
765   MS_LOG(INFO) << "Start Schedule to Dependencies";
766   auto dependencies = ScheduleToDependencies(output);
767   MS_LOG(INFO) << "End Schedule to Dependencies";
768   if (VerifyDependencies(tasks, dependencies)) {
769     MS_LOG(INFO) << "Verification of Dependencies: SUCCESS";
770   } else {
771     MS_LOG(INFO) << "Verification of Dependencies: FAILURE";
772   }
773   PrintLog(output, dependencies, graph_name);
774 
775   return output;
776 }
777 
778 bool FastGreedyScheduler::VerifyScheduling(std::vector<std::shared_ptr<Task>> &tasks) {
779   bool flag = true;
780   MS_LOG(INFO) << "Start Verification of Scheduling";
781   for (auto &task : tasks) {
782     // Check if task is scheduled before its children
783     for (auto child = task->children().begin(); child != task->children().end(); ++child) {
784       if (!(task->start() < task->end() && task->end() <= (*child)->start() &&
785             (*child)->start() < (*child)->end())) {  // assumes right-open intervals of non-zero length
786         MS_LOG(INFO) << "Verification violation: task " << task->id() << " [" << task->start() << "," << task->end()
787                      << "] and task " << (*child)->id() << " [" << (*child)->start() << "," << (*child)->end() << "]";
788         flag = false;
789       }
790     }
791   }
792   MS_LOG(INFO) << "End Verification of Scheduling";
793   return flag;
794 }
795 
796 bool BFSsort(const std::shared_ptr<Task> &task1, const std::shared_ptr<Task> &task2) {
797   return task1->depth() < task2->depth() || (task1->depth() == task2->depth() && task1->id() < task2->id());
798 }
799 
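// VerifyDependencies first computes transitive reachability (exists_path) by visiting tasks in BFS depth order,
// then classifies each generated dependency: redundant if already implied by the task graph, cycle-forming if the
// reverse path exists, and a scheduling violation if the two tasks' intervals are not ordered correctly.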
800 bool FastGreedyScheduler::VerifyDependencies(std::vector<std::shared_ptr<Task>> &tasks,
801                                              std::vector<std::pair<TaskId, TaskId>> &dependencies) {
802   bool flag = true;
803 
804   MS_LOG(INFO) << "Start Verification of Dependencies";
805   // Traverse graph by depth to maintain ancestor info
806   auto tasks_sorted = tasks;
807   std::sort(tasks_sorted.begin(), tasks_sorted.end(), BFSsort);
808   std::map<TaskId, std::map<TaskId, bool>> exists_path;
809   std::map<TaskId, std::shared_ptr<Task>> id_to_ptr;
810   for (auto current = tasks_sorted.begin(); current != tasks_sorted.end(); ++current) {
811     id_to_ptr[(*current)->id()] = *current;
812     for (auto parent = (*current)->parents().begin(); parent != (*current)->parents().end(); ++parent) {
813       exists_path[(*parent).lock()->id()][(*current)->id()] = true;
814       for (auto &it : tasks_sorted) {
815         if (exists_path[it->id()][(*parent).lock()->id()]) {
816           exists_path[it->id()][(*current)->id()] = true;
817         }
818       }
819     }
820   }
821   // For each dependency, check whether it is redundant, whether it forms a directed cycle, and whether the
822   // corresponding tasks are scheduled correctly
823   size_t redundant_count = 0;
824   for (auto it = dependencies.begin(); it != dependencies.end(); ++it) {
825     const auto &source = id_to_ptr[it->first];
826     const auto &dst = id_to_ptr[it->second];
827     if (exists_path[it->first][it->second]) {
828       redundant_count++;
829     }
830     if (exists_path[it->second][it->first]) {
831       MS_LOG(INFO) << "Dependency cycle formation: task " << source->id() << " [" << source->start() << ","
832                    << source->end() << "] and task " << dst->id() << " [" << dst->start() << "," << dst->end() << "]";
833     }
834     if (!(source->start() < source->end() && source->end() <= dst->start() && dst->start() < dst->end())) {
835       // allow weights of size 0
836       MS_LOG(INFO) << "Dependency scheduling violation: task " << source->id() << " [" << source->start() << ","
837                    << source->end() << "] and task " << dst->id() << " [" << dst->start() << "," << dst->end() << "]";
838     }
839   }
840   MS_LOG(INFO) << "End Verification of Dependencies";
841   MS_LOG(INFO) << redundant_count << " dependencies are redundant, " << dependencies.size() - redundant_count
842                << " are real";
843 
844   return flag;
845 }
846 
847 bool FastGreedyScheduler::VerifyDAG(std::vector<std::shared_ptr<Task>> &tasks) {
848   // simple verifier that no directed cycle exists
849   std::unordered_map<TaskId, bool> visited;
850   std::unordered_map<TaskId, size_t> unprocessed_parents;
851   std::deque<std::shared_ptr<Task>> to_visit;
852   MS_LOG(INFO) << "Start Verification of DAG";
853   for (auto &task : tasks) {
854     const auto &id = task->id();
855     visited[id] = false;
856     unprocessed_parents[id] = task->parents().size();
857     if (unprocessed_parents[id] == 0) {
858       to_visit.push_back(task);
859     }
860   }
861   while (!to_visit.empty()) {
862     const auto selected_task = *(to_visit.begin());
863     const auto &selected_id = selected_task->id();
864     if (visited[selected_id]) {
865       MS_LOG(INFO) << "Cycle including task " << selected_id;
866       return false;
867     } else {
868       visited[selected_id] = true;
869     }
870     to_visit.pop_front();
871     for (const auto &successor : selected_task->children()) {
872       const auto &succ_id = successor->id();
873       unprocessed_parents[succ_id] -= 1;
874       if (unprocessed_parents[succ_id] == 0) {
875         to_visit.push_back(successor);
876       }
877     }
878   }
879   MS_LOG(INFO) << "End Verification of DAG";
880 
881   return true;
882 }
883 
884 void FastGreedyScheduler::PrintLog(const SchedulingOutput &output,
885                                    const std::vector<std::pair<TaskId, TaskId>> &dependencies,
886                                    const std::string &graph_name) {
887   std::ofstream out_file("comp_comm_scheduling_out_" + graph_name + ".log", std::ios::out | std::ios::trunc);
888   if (!out_file.is_open()) {
889     MS_LOG(ERROR) << "Could not open comp_comm_scheduling_out_" << graph_name << ".log";
890     return;
891   }
892 
893   // Print info for tasks
894   const auto &tasks = output.task_times;
895   for (const auto &task : tasks) {
896     out_file << "THREAD id=" << std::to_string(task.id) << ", type=" << std::to_string(task.type)
897              << ", start=" << std::to_string(task.start) << ", end=" << std::to_string(task.end) << "\n";
898   }
899   // Print dependencies
900   for (const auto &dependency : dependencies) {
901     const auto &source = dependency.first;
902     const auto &dst = dependency.second;
903     out_file << "DEPENDENCY " << std::to_string(source) << " " << std::to_string(dst) << "\n";
904   }
905   out_file.close();
906 }
907 
908 void InsertTaskGraph(const std::vector<CNodePtr> &cnode_vec,
909                      std::unordered_map<CNodePtr, TaskPtr> *cnode_to_task_map_ptr) {
910   for (size_t i = 0; i < cnode_vec.size(); ++i) {
911     for (size_t j = 0; j < cnode_vec[i]->size(); ++j) {
912       const auto &input_node = cnode_vec[i]->input(j)->cast<CNodePtr>();
913       if ((*cnode_to_task_map_ptr).count(input_node) == 0) continue;
914 
915       ((*cnode_to_task_map_ptr)[cnode_vec[i]])->AddParent((*cnode_to_task_map_ptr)[input_node]);
916       ((*cnode_to_task_map_ptr)[input_node])->AddChild((*cnode_to_task_map_ptr)[cnode_vec[i]]);
917       MS_LOG(INFO) << "Edge " << (*cnode_to_task_map_ptr)[input_node]->id() << " "
918                    << (*cnode_to_task_map_ptr)[cnode_vec[i]]->id();
919       MS_LOG(INFO) << "Edge (UniqueName) " << input_node->UniqueName() << " " << cnode_vec[i]->UniqueName();
920     }
921   }
922 }
923 
924 SchedulingInput ExtractSchedulingInput(const FuncGraphManagerPtr &manager, const std::vector<CNodePtr> &cnode_vec,
925                                        std::unordered_map<CNodePtr, TaskPtr> *cnode_to_task_map_ptr) {
926   SchedulingInput scheduling_input;  // to fill in and return
927 
928   // Create a task per node
929   for (size_t i = 0; i < cnode_vec.size(); ++i) {
930     std::shared_ptr<Task> task1 =
931       std::make_shared<Task>(i, common::AnfAlgo::IsCommunicationOp(cnode_vec[i]) ? kComm : kComp);
932     MS_LOG(INFO) << "Start Assign Weight";
933     const auto &cnode = cnode_vec[i];
934     size_t output_num = AnfUtils::GetOutputTensorNum(cnode);
935     size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
936     Time weight = 0;
937 
938     // For each operator, estimate a cost from its inputs and outputs:
939     // for every input, multiply the shape dimensions to get the element count and scale by the data type size,
940     // then sum over all inputs.
941     // If there is more than one output, do the same for every output;
942     // with a single output the output cost is 0.
943     // The task weight is the sum of the input costs and the output costs.
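    // For example (illustration only): a single float32 input of shape [32, 128] contributes
    // 32 * 128 * 4 bytes = 16384 to the weight.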
944     for (size_t j = 0; j < input_num; j++) {
945       KernelWithIndex kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(cnode, j);
946       if (dyn_cast<abstract::BaseShape>(kernel_with_index.first->Shape()) == nullptr ||
947           dyn_cast<Type>(kernel_with_index.first->Type()) == nullptr) {
948         MS_LOG(INFO) << "shape or type is nullptr, ignore";
949         continue;
950       }
951       ShapeVector shape = common::AnfAlgo::GetOutputInferShape(kernel_with_index.first, kernel_with_index.second);
952       if (shape.empty()) continue;
953 
954       const TypeId type = common::AnfAlgo::GetOutputInferDataType(kernel_with_index.first, 0);
955       if (type == kObjectTypeUMonad || type == kObjectTypeMonad || type == kObjectTypeFunction) continue;
956 
957       size_t type_size = GetDataTypeSize(type);
958       // size_t type_size = convert_type_to_num(type);
959       weight += std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<size_t>()) * type_size;
960     }
961 
962     if (output_num > 1) {
963       for (size_t j = 0; j < output_num; j++) {
964         ShapeVector shape = common::AnfAlgo::GetOutputInferShape(cnode, j);
965         if (shape.empty()) continue;
966 
967         const TypeId type = common::AnfAlgo::GetOutputInferDataType(cnode, j);
968         if (type == kObjectTypeUMonad || type == kObjectTypeMonad || type == kObjectTypeFunction) continue;
969 
970         size_t type_size = GetDataTypeSize(type);
971         // size_t type_size = convert_type_to_num(type);
972         weight += std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<size_t>()) * type_size;
973       }
974     }
975 
976     if (weight < 0) MS_LOG(EXCEPTION) << "Weight < 0, replace by SIZE_MAX";
977 
978     task1->AssignWeight(weight);
979     MS_LOG(INFO) << "End Assign Weight";
980 
981     (*cnode_to_task_map_ptr)[cnode_vec[i]] = task1;
982     scheduling_input.tasks.push_back(task1);
983     MS_LOG(INFO) << "Task " << task1->id() << " with name " << cnode->UniqueName() << " and CNodePtr " << cnode_vec[i]
984                  << " with weight " << weight;
985   }
986 
987   // Insert task graph edges
988   InsertTaskGraph(cnode_vec, cnode_to_task_map_ptr);
989 
990   return scheduling_input;
991 }
992 
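// AddRealDependencies materializes each scheduling-induced ordering in the graph: for a dependency (source, dest)
// it wraps a suitable input of dest in a Depend(input, source) node, so execution order is enforced without
// changing the values that flow into dest. Dependencies already implied by an existing edge and Comm-to-Comm
// dependencies are skipped.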
993 void AddRealDependencies(const FuncGraphManagerPtr &manager, const std::vector<CNodePtr> &cnode_vec,
994                          const std::vector<std::pair<TaskId, TaskId>> &dependencies,
995                          std::unordered_map<CNodePtr, TaskPtr> *cnode_to_task) {
996   size_t count = 0;
997   size_t redundant_count = 0;
998   for (const auto &dependency : dependencies) {
999     MS_LOG(INFO) << "Checking dependency " << dependency.first << " " << dependency.second;
1000     const auto &source = cnode_vec[dependency.first];
1001     const auto &dest = cnode_vec[dependency.second];
1002 
1003     // Ignore dependencies already there
1004     if ((*cnode_to_task)[source]->HasChild((*cnode_to_task)[dest])) {
1005       MS_LOG(INFO) << "Dependency " << dependency.first << " " << dependency.second
1006                    << " is redundant (already parent and child)";
1007       redundant_count++;
1008       continue;
1009     }
1010 
1011     // Skip dependencies where both endpoints are Comm ops; only dependencies involving a Comp op are added for now
1012     bool comp_comp = true;
1013     if (common::AnfAlgo::IsCommunicationOp(source) && common::AnfAlgo::IsCommunicationOp(dest)) {
1014       comp_comp = false;
1015       MS_LOG(INFO) << "Ignore Comm to Comm dependency " << dependency.first << " " << dependency.second;
1016       continue;
1017     }
1018 
1019     // At least two inputs in destination node (input 0 is the node primitive)
1020     if (dest->size() < 2) {
1021       MS_LOG(INFO) << "Destination inputs size < 2: ignore";
1022       continue;
1023     }
1024 
1025     // If destination node (comp) has comm inputs, make dependency involving one of them
1026     for (size_t j = 1; j < dest->size(); ++j) {  // input 0 is node primitive: ignore
1027       if (!utils::isa<CNodePtr>(dest->input(j))) {
1028         MS_LOG(INFO) << "Not a cnodeptr at input " << j;
1029         continue;
1030       }
1031       if (comp_comp && !common::AnfAlgo::IsCommunicationOp(dest->input(j))) {
1032         MS_LOG(INFO) << "Dest " << dest << " Task " << dependency.second << " Input " << j
1033                      << " is not CommunicationOp: Ignore";
1034         continue;
1035       }
1036       if (!comp_comp && common::AnfAlgo::IsCommunicationOp(dest->input(j))) {
1037         MS_LOG(INFO) << "Dest " << dest << " Task " << dependency.second << " Input " << j
1038                      << " is CommunicationOp: Ignore";
1039         continue;
1040       }
1041 
1042       // Insert the real dependency as a Depend node on this input edge
1043       const auto &input_node = dest->input(j)->cast<CNodePtr>();
1044       std::vector<AnfNodePtr> depend_inputs{NewValueNode(prim::kPrimDepend), input_node, source};
1045       auto depend_node = dest->func_graph()->NewCNode(depend_inputs);
1046       MS_EXCEPTION_IF_NULL(depend_node);
1047       depend_node->set_abstract(input_node->abstract()->Clone());
1048       depend_node->AddAttr("comp_comm_scheduling_depend", MakeValue(true));
1049       auto &nodes = manager->node_users()[input_node];
1050       auto it = std::find_if(nodes.begin(), nodes.end(), [dest](const auto &user) { return user.first == dest; });
1051       if (it != nodes.end()) {
1052         int idx = (*it).second;
1053         manager->SetEdge(dest, idx, depend_node);
1054         MS_LOG(INFO) << "Added dependency from " << dependency.first << ", unique name " << source->UniqueName()
1055                      << ", to " << dependency.second << ", unique name " << dest->UniqueName();
1056         count++;
1057         break;  // add dependency involving only one destination node input
1058       } else {
1059         MS_LOG(INFO) << "User index not found: Ignore dependency and continue";
1060         continue;
1061       }
1062     }
1063   }
1064   MS_LOG(INFO) << "Num of real dependencies added is " << count;
1065   MS_LOG(INFO) << "Num of redundant dependencies (HasChild) is " << redundant_count;
1066   return;
1067 }
1068 
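// Entry point of the pass: runs only under (semi-)auto-parallel and when MS_ENABLE_FRONTEND_SCHEDULING_OPTIMIZATION
// is set to "1"; each subgraph is scheduled independently and the resulting dependencies are inserted into it.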
1069 void CompCommScheduling(const FuncGraphPtr &graph) {
1070   if (parallel::ParallelContext::GetInstance()->parallel_mode() != parallel::kSemiAutoParallel &&
1071       parallel::ParallelContext::GetInstance()->parallel_mode() != parallel::kAutoParallel) {
1072     return;
1073   }
1074   if (common::GetEnv("MS_ENABLE_FRONTEND_SCHEDULING_OPTIMIZATION") != "1") {
1075     return;
1076   }
1077   MS_EXCEPTION_IF_NULL(graph);
1078   auto manager = graph->manager();
1079   MS_LOG(INFO) << "Main graph pointer: " << graph;
1080   MS_EXCEPTION_IF_NULL(manager);
1081 
1082   FuncGraphSet graphs = manager->func_graphs();
1083   for (const auto &subgraph : graphs) {
1084     MS_LOG(INFO) << "Start Scheduling Subgraph " << subgraph;
1085     std::stringstream graph_ss;
1086     graph_ss << subgraph;
1087     std::string graph_name = graph_ss.str();
1088     std::list<CNodePtr> cnode_list = subgraph->GetOrderedCnodes();
1089     std::vector<CNodePtr> cnode_vec(cnode_list.cbegin(), cnode_list.cend());
1090 
1091     MS_LOG(INFO) << "Start ExtractSchedulingInput";
1092     std::unordered_map<CNodePtr, TaskPtr> cnode_to_task;
1093     SchedulingInput scheduling_input = ExtractSchedulingInput(manager, cnode_vec, &cnode_to_task);
1094     MS_LOG(INFO) << "End ExtractSchedulingInput";
1095 
1096     auto scheduling_output = FastGreedyScheduler::Process(scheduling_input, graph_name);
1097     auto dependencies = FastGreedyScheduler::ScheduleToDependencies(scheduling_output);
1098 
1099     MS_LOG(INFO) << "Start AddRealDependencies";
1100     AddRealDependencies(manager, cnode_vec, dependencies, &cnode_to_task);
1101     MS_LOG(INFO) << "End AddRealDependencies";
1102   }
1103 }
1104 }  // namespace opt
1105 }  // namespace mindspore
1106