/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/scheduler.h"

#include <deque>
#include <queue>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

namespace {

// Initialize the pending count for each node.
void InitializePending(const Graph* graph, std::vector<int>* pending) {
  pending->resize(graph->num_node_ids());
  for (const Node* node : graph->nodes()) {
    const int id = node->id();
    int num_in_edges = 0;
    if (IsMerge(node)) {
      // For forward execution order, Merge nodes are special: a Merge node is
      // processed only once, as soon as one of its data inputs (and all of
      // its control inputs) has been processed.
      for (const Edge* edge : node->in_edges()) {
        if (edge->IsControlEdge()) {
          // Bit 0 is reserved to indicate if there is a data input.
          num_in_edges += 2;
        }
      }
    } else {
      num_in_edges = node->in_edges().size();
    }
    (*pending)[id] = num_in_edges;
  }
}

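// For a Merge node, the pending count packs two pieces of state: every
// not-yet-arrived control edge contributes 2 (see InitializePending above),
// and bit 0 records whether a data input has arrived yet.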
// Return true if the update makes the destination of the edge ready to run.
bool UpdatePending(const Edge* edge, std::vector<int>* pending_count) {
  const Node* out = edge->dst();
  if (IsMerge(out)) {
    if (edge->IsControlEdge()) {
      (*pending_count)[out->id()] -= 2;
      // Return true if we have already seen a data input and this control
      // edge is the enabling one.
      return ((*pending_count)[out->id()] == 1);
    } else {
      int count = (*pending_count)[out->id()];
      (*pending_count)[out->id()] |= 0x1;
      // If the first data input is the enabling one, the count goes from
      // 0 to 1 in this step. Return true iff count was zero.
      return (count == 0);
    }
  } else {
    return (--(*pending_count)[out->id()] == 0);
  }
}

}  // end namespace

SlackAnalysis::SlackAnalysis(const Graph* g, const CostModel* cost_model)
    : graph_(g), cost_model_(cost_model) {}

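// Computes the earliest (as-soon-as-possible) start time of every node by a
// forward traversal from the source node. Returns the ASAP time of the sink
// node, an estimate of the critical-path length under the cost model.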
Microseconds SlackAnalysis::ComputeAsap(std::vector<Microseconds>* asap_times) {
  asap_times->resize(graph_->num_node_ids());

  std::vector<int> pending_count(graph_->num_node_ids());
  InitializePending(graph_, &pending_count);

  std::deque<const Node*> queue;
  Node* srcNode = graph_->source_node();
  queue.push_back(srcNode);
  (*asap_times)[srcNode->id()] = 0;

  while (!queue.empty()) {
    const Node* curr = queue.front();
    queue.pop_front();
    Microseconds ctime = cost_model_->TimeEstimate(curr);
    for (const Edge* out_edge : curr->out_edges()) {
      // The time needed for 'out' to get its input from 'curr'.
      Microseconds copy_time(0);
      const Node* out = out_edge->dst();
      if (!out_edge->IsControlEdge() &&
          curr->assigned_device_name() != out->assigned_device_name()) {
        // Add an arbitrary 10 microseconds for each cross-device copy.
        // TODO(yuanbyu): Use the real cost model, as sketched below.
        // int index = out_edge->src_output();
        // Bytes nb = cost_model_->SizeEstimate(curr, index);
        // copy_time = CostModel::CopyTimeEstimate(nb);
        copy_time = 10;
      }
      Microseconds new_asap = (*asap_times)[curr->id()] + ctime + copy_time;
      if ((*asap_times)[out->id()] < new_asap) {
        (*asap_times)[out->id()] = new_asap;
      }

      bool is_ready = UpdatePending(out_edge, &pending_count);
      if (is_ready) {
        queue.push_back(out);
      }
    }
  }
  return (*asap_times)[graph_->sink_node()->id()];
}

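// Computes the latest (as-late-as-possible) start time of every node by a
// backward traversal from the sink node. The times are offsets relative to
// the sink and hence non-positive; the returned value for the source node is
// the negated critical-path length.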
Microseconds SlackAnalysis::ComputeAlap(std::vector<Microseconds>* alap_times) {
  alap_times->resize(graph_->num_node_ids());

  std::vector<int> pending_count;
  pending_count.resize(graph_->num_node_ids());
  for (const Node* n : graph_->nodes()) {
    // For reverse execution order, Switch nodes are special: a Switch node is
    // processed only once, when one of its outputs is processed.
    if (IsSwitch(n)) {
      int32 num_control_edges = 0;
      for (const Edge* edge : n->out_edges()) {
        if (edge->IsControlEdge()) {
          num_control_edges++;
        }
      }
      pending_count[n->id()] = num_control_edges + 1;
    } else {
      pending_count[n->id()] = n->out_edges().size();
    }
  }

  std::deque<const Node*> queue;
  Node* sinkNode = graph_->sink_node();
  queue.push_back(sinkNode);
  (*alap_times)[sinkNode->id()] = 0;

  while (!queue.empty()) {
    const Node* curr = queue.front();
    queue.pop_front();
    for (const Edge* in_edge : curr->in_edges()) {
      // The time needed for 'curr' to get its input from 'src'.
      Microseconds copy_time(0);
      const Node* src = in_edge->src();
      if (!in_edge->IsControlEdge() &&
          src->assigned_device_name() != curr->assigned_device_name()) {
        // TODO(yuanbyu): Use the real cost model, as sketched below.
        // int index = in_edge->src_output();
        // Bytes nb = cost_model_->SizeEstimate(src, index);
        // copy_time = CostModel::CopyTimeEstimate(nb);
        copy_time = 10;
      }
      Microseconds ctime = cost_model_->TimeEstimate(src);
      Microseconds new_latest = (*alap_times)[curr->id()] - ctime - copy_time;
      if ((*alap_times)[src->id()] > new_latest) {
        (*alap_times)[src->id()] = new_latest;
      }

      int count = --pending_count[src->id()];
      if (count == 0) {
        queue.push_back(src);
      }
    }
  }
  return (*alap_times)[graph_->source_node()->id()];
}

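// Slack of a node = its latest possible start time (shifted so that the
// source node's latest start is 0) minus its earliest possible start time.
// Nodes on the critical path have zero slack; a smaller slack means the node
// is more timing-critical.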
void SlackAnalysis::ComputeSlack(std::vector<int64>* slacks) {
  std::vector<Microseconds> asap_times;
  std::vector<Microseconds> alap_times;
  ComputeAsap(&asap_times);
  ComputeAlap(&alap_times);
  slacks->resize(graph_->num_node_ids());
  Node* srcNode = graph_->source_node();
  Microseconds makespan = alap_times[srcNode->id()];
  for (Node* node : graph_->nodes()) {
    Microseconds latest_stime = alap_times[node->id()] - makespan;
    (*slacks)[node->id()] = (latest_stime - asap_times[node->id()]).value();
  }
}

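// Each device is given a Sim that tracks its simulated state: how many nodes
// may run concurrently (degree_parallelism), how many are currently running,
// and which nodes are ready to run on it.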
GreedyScheduler::GreedyScheduler(const DeviceSet* devices,
                                 const CostModel* cost_model, const Graph* g,
                                 std::vector<int64>* priority)
    : devices_(devices),
      cost_model_(cost_model),
      graph_(g),
      priority_(priority) {
  for (Device* d : devices_->devices()) {
    Sim* s = new Sim;
    // The number of compute units on a device. Set to 2 for now.
    s->degree_parallelism = 2;
    s->num_running = 0;
    device_states_.insert(std::make_pair(d->name(), s));
  }
}

GreedyScheduler::~GreedyScheduler() {
  for (auto& ds : device_states_) {
    delete ds.second;
  }
}

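// Simulates the execution of the graph as a sequence of events. A completion
// event updates the pending counts of the node's successors and enqueues the
// ones that become ready; a non-completion event places the node on its
// device's ready list. After each event, every device starts as many ready
// nodes as its degree of parallelism allows, picking the node with the
// smallest priority value first. Fills in 'start_times' and returns the
// simulated makespan.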
Microseconds GreedyScheduler::ComputeSchedule(
    std::vector<Microseconds>* start_times) {
  // Make sure 'start_times' can be indexed by node id.
  start_times->resize(graph_->num_node_ids());

  // Initialize pending_count.
  std::vector<int> pending_count(graph_->num_node_ids());
  InitializePending(graph_, &pending_count);

  // Initialize the event queue with the source node.
  std::priority_queue<Event> event_queue;
  Event src_event;
  src_event.node = graph_->source_node();
  src_event.time = 0;
  src_event.is_completion = true;
  event_queue.push(src_event);
  Microseconds max_completion = Microseconds(0);

  while (!event_queue.empty()) {
    Event event = event_queue.top();
    event_queue.pop();
    if (event.is_completion) {
      Sim* sim = device_states_[event.node->assigned_device_name()];
      --sim->num_running;

      if (event.time > max_completion) {
        max_completion = event.time;
      }

      for (const Edge* out_edge : event.node->out_edges()) {
        Microseconds copy_time(0);
        const Node* out = out_edge->dst();
        if (!out_edge->IsControlEdge() &&
            event.node->assigned_device_name() != out->assigned_device_name()) {
          // TODO(yuanbyu): Use the real cost model, as sketched below.
          // int index = out_edge->src_output();
          // Bytes nb = cost_model_->SizeEstimate(event.node, index);
          // copy_time = CostModel::CopyTimeEstimate(nb);
          copy_time = 10;
        }
        if ((*start_times)[out->id()] < event.time + copy_time) {
          (*start_times)[out->id()] = event.time + copy_time;
        }

        bool is_ready = UpdatePending(out_edge, &pending_count);
        if (is_ready) {
          Event e{out, (*start_times)[out->id()], false};
          event_queue.push(e);
        }
      }
    } else {
      Sim* sim = device_states_[event.node->assigned_device_name()];
      sim->ready_nodes.push_back(event.node);
    }

    for (auto& x : device_states_) {
      Sim* sim = x.second;
      while (sim->num_running < sim->degree_parallelism &&
             !sim->ready_nodes.empty()) {
        Event e;
        e.node = GetNodeWithHighestPriority(sim->ready_nodes);
        e.time = event.time + cost_model_->TimeEstimate(e.node);
        e.is_completion = true;
        event_queue.push(e);
        (*start_times)[e.node->id()] = event.time;
        ++sim->num_running;
        // Remove the started node from the ready list so it is not scheduled
        // again.
        for (auto it = sim->ready_nodes.begin(); it != sim->ready_nodes.end();
             ++it) {
          if (*it == e.node) {
            sim->ready_nodes.erase(it);
            break;
          }
        }
      }
    }
  }
  return max_completion;
}

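// Returns the ready node with the smallest priority value, i.e. the highest
// scheduling priority, or nullptr if 'nodes' is empty.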
const Node* GreedyScheduler::GetNodeWithHighestPriority(
    const std::vector<const Node*>& nodes) {
  const Node* curr_node = nullptr;
  int64 curr_priority = kint64max;
  for (const Node* n : nodes) {
    if ((*priority_)[n->id()] < curr_priority) {
      curr_node = n;
      curr_priority = (*priority_)[n->id()];
    }
  }
  return curr_node;
}

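// PriorityScheduler ties the two pieces above together: it derives per-node
// priorities from a slack analysis and feeds them to the greedy scheduler.
// A minimal usage sketch (the device set, cost model and graph are assumed
// to be provided by the caller):
//
//   PriorityScheduler scheduler(&devices, &cost_model, graph);
//   std::vector<int64> priorities(graph->num_node_ids());
//   Microseconds makespan = scheduler.AssignPriorities(&priorities);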
PriorityScheduler::PriorityScheduler(const DeviceSet* devices,
                                     const CostModel* cost_model,
                                     const Graph* g)
    : devices_(devices), cost_model_(cost_model), graph_(g) {}

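// Computes a schedule by running the greedy scheduler with each node's slack
// as its priority, so nodes with less slack are scheduled first.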
Microseconds PriorityScheduler::ComputeSchedule(
    std::vector<Microseconds>* start_times) {
  std::vector<int64> slacks;
  SlackAnalysis slack(graph_, cost_model_);
  slack.ComputeSlack(&slacks);
  GreedyScheduler greedysched(devices_, cost_model_, graph_, &slacks);
  return greedysched.ComputeSchedule(start_times);
}

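// Assigns each node's simulated start time as its priority and returns the
// makespan of the simulated schedule. 'priorities' must already be sized to
// at least graph_->num_node_ids().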
Microseconds PriorityScheduler::AssignPriorities(
    std::vector<int64>* priorities) {
  std::vector<Microseconds> start_times;
  Microseconds makespan = ComputeSchedule(&start_times);

  for (const Node* n : graph_->nodes()) {
    (*priorities)[n->id()] = start_times[n->id()].value();
  }
  return makespan;
}

}  // namespace tensorflow