/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/scheduler.h"

#include <algorithm>
#include <deque>
#include <queue>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

namespace {

// Initialize the pending count for each node.
void InitializePending(const Graph* graph, std::vector<int>* pending) {
  pending->resize(graph->num_node_ids());
  for (const Node* node : graph->nodes()) {
    const int id = node->id();
    int num_in_edges = 0;
    if (IsMerge(node)) {
      // For forward execution order, Merge nodes are special. We process
      // them only once, when one of their inputs is processed.
      for (const Edge* edge : node->in_edges()) {
        if (edge->IsControlEdge()) {
          // Bit 0 is reserved to indicate if there is a data input.
          num_in_edges += 2;
        }
      }
    } else {
      num_in_edges = node->in_edges().size();
    }
    (*pending)[id] = num_in_edges;
  }
}

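// A note on the Merge encoding used above and in UpdatePending below: bit 0
// of a Merge node's pending count records whether a data input has arrived,
// and each outstanding control edge contributes 2. For example, a Merge node
// with one control edge starts at 2; when the control edge completes, the
// count drops to 0, and the first data input then sets bit 0 and enables the
// node. A Merge node is thus ready once all of its control edges and at
// least one of its data inputs have completed.
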
// Return true if the update makes the destination of the edge ready to run.
bool UpdatePending(const Edge* edge, std::vector<int>* pending_count) {
  const Node* out = edge->dst();
  if (IsMerge(out)) {
    if (edge->IsControlEdge()) {
      (*pending_count)[out->id()] -= 2;
      // Return true if we have already seen at least one data input
      //   and this control edge is the enabling one.
      return ((*pending_count)[out->id()] == 1);
    } else {
      int count = (*pending_count)[out->id()];
      (*pending_count)[out->id()] |= 0x1;
      // If the first data input is the enabling one, the count goes from
      //   0 to 1 in this step. Return true iff count was zero.
      return (count == 0);
    }
  } else {
    return (--(*pending_count)[out->id()] == 0);
  }
}

}  // end namespace

SlackAnalysis::SlackAnalysis(const Graph* g, const CostModel* cost_model)
    : graph_(g), cost_model_(cost_model) {}

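// Computes the earliest possible start time of each node (the classic
// "as soon as possible", or ASAP, schedule) with a forward sweep from the
// source node, and returns the ASAP time of the sink node, i.e. the
// estimated makespan of the graph.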
Microseconds SlackAnalysis::ComputeAsap(std::vector<Microseconds>* asap_times) {
  asap_times->resize(graph_->num_node_ids());

  std::vector<int> pending_count(graph_->num_node_ids());
  InitializePending(graph_, &pending_count);

  std::deque<const Node*> queue;
  Node* srcNode = graph_->source_node();
  queue.push_back(srcNode);
  (*asap_times)[srcNode->id()] = 0;

  while (!queue.empty()) {
    const Node* curr = queue.front();
    queue.pop_front();
    Microseconds ctime = cost_model_->TimeEstimate(curr);
    for (const Edge* out_edge : curr->out_edges()) {
      // The time needed for 'out' to get its input from 'curr'.
      Microseconds copy_time(0);
      const Node* out = out_edge->dst();
      if (!out_edge->IsControlEdge() &&
          curr->assigned_device_name() != out->assigned_device_name()) {
        // Add an arbitrary 10 microseconds for each cross-device copy.
        // TODO(yuanbyu): Use below with the real cost model.
        // int index = out_edge->src_output();
        // Bytes nb = cost_model_->SizeEstimate(curr, index);
        // copy_time = CostModel::CopyTimeEstimate(nb);
        copy_time = 10;
      }
      Microseconds new_asap = (*asap_times)[curr->id()] + ctime + copy_time;
      if ((*asap_times)[out->id()] < new_asap) {
        (*asap_times)[out->id()] = new_asap;
      }

      bool is_ready = UpdatePending(out_edge, &pending_count);
      if (is_ready) {
        queue.push_back(out);
      }
    }
  }
  return (*asap_times)[graph_->sink_node()->id()];
}

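// Computes the latest possible start time of each node (the "as late as
// possible", or ALAP, schedule) with a backward sweep from the sink node.
// Times are relative to the sink at 0 and hence non-positive; the returned
// ALAP time of the source node is the negated makespan.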
Microseconds SlackAnalysis::ComputeAlap(std::vector<Microseconds>* alap_times) {
  alap_times->resize(graph_->num_node_ids());

  std::vector<int> pending_count;
  pending_count.resize(graph_->num_node_ids());
  for (const Node* n : graph_->nodes()) {
    // For reverse execution order, Switch nodes are special. We process
    // them only once, when one of their outputs is processed.
    if (IsSwitch(n)) {
      int32 num_control_edges = 0;
      for (const Edge* edge : n->out_edges()) {
        if (edge->IsControlEdge()) {
          num_control_edges++;
        }
      }
      pending_count[n->id()] = num_control_edges + 1;
    } else {
      pending_count[n->id()] = n->out_edges().size();
    }
  }

  std::deque<const Node*> queue;
  Node* sinkNode = graph_->sink_node();
  queue.push_back(sinkNode);
  (*alap_times)[sinkNode->id()] = 0;

  while (!queue.empty()) {
    const Node* curr = queue.front();
    queue.pop_front();
    for (const Edge* in_edge : curr->in_edges()) {
      // The time needed for 'curr' to get its input from 'src'.
      Microseconds copy_time(0);
      const Node* src = in_edge->src();
      if (!in_edge->IsControlEdge() &&
          src->assigned_device_name() != curr->assigned_device_name()) {
        // Add an arbitrary 10 microseconds for each cross-device copy.
        // TODO(yuanbyu): Use the real cost model.
        // int index = in_edge->src_output();
        // Bytes nb = cost_model_->SizeEstimate(src, index);
        // copy_time = CostModel::CopyTimeEstimate(nb);
        copy_time = 10;
      }
      Microseconds ctime = cost_model_->TimeEstimate(src);
      Microseconds new_latest = (*alap_times)[curr->id()] - ctime - copy_time;
      if ((*alap_times)[src->id()] > new_latest) {
        (*alap_times)[src->id()] = new_latest;
      }

      int count = --pending_count[src->id()];
      if (count == 0) {
        queue.push_back(src);
      }
    }
  }
  return (*alap_times)[graph_->source_node()->id()];
}

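// Computes the slack of each node: the gap between its latest and earliest
// possible start times. Nodes with zero slack are on the critical path, and
// smaller slack means higher scheduling urgency.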
void SlackAnalysis::ComputeSlack(std::vector<int64>* slacks) {
  std::vector<Microseconds> asap_times;
  std::vector<Microseconds> alap_times;
  ComputeAsap(&asap_times);
  ComputeAlap(&alap_times);
  slacks->resize(graph_->num_node_ids());
  Node* srcNode = graph_->source_node();
  Microseconds makespan = alap_times[srcNode->id()];
  for (Node* node : graph_->nodes()) {
    Microseconds latest_stime = alap_times[node->id()] - makespan;
    (*slacks)[node->id()] = (latest_stime - asap_times[node->id()]).value();
  }
}

GreedyScheduler::GreedyScheduler(const DeviceSet* devices,
                                 const CostModel* cost_model, const Graph* g,
                                 std::vector<int64>* priority)
    : devices_(devices),
      cost_model_(cost_model),
      graph_(g),
      priority_(priority) {
  for (Device* d : devices_->devices()) {
    Sim* s = new Sim;
    // The number of compute units on a device. Set to 2 for now.
    s->degree_parallelism = 2;
    s->num_running = 0;
    device_states_.insert(std::make_pair(d->name(), s));
  }
}

GreedyScheduler::~GreedyScheduler() {
  for (auto& ds : device_states_) {
    delete ds.second;
  }
}

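// Simulates the execution of the graph as an event-driven list scheduler.
// A completion event frees a slot on the node's device and may make the
// node's successors ready; a ready event queues a node on its assigned
// device. Each device runs at most 'degree_parallelism' nodes at a time and
// always launches the ready node with the highest priority (i.e. the lowest
// priority value). Returns the simulated makespan.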
Microseconds GreedyScheduler::ComputeSchedule(
    std::vector<Microseconds>* start_times) {
  start_times->resize(graph_->num_node_ids());

  // Initialize pending_count.
  std::vector<int> pending_count(graph_->num_node_ids());
  InitializePending(graph_, &pending_count);

  // Initialize the event queue.
  std::priority_queue<Event> event_queue;
  Event src_event;
  src_event.node = graph_->source_node();
  src_event.time = 0;
  src_event.is_completion = true;
  event_queue.push(src_event);
  Microseconds max_completion = Microseconds(0);

  while (!event_queue.empty()) {
    Event event = event_queue.top();
    event_queue.pop();
    if (event.is_completion) {
      // A node completed: free its device slot and update its successors.
      Sim* sim = device_states_[event.node->assigned_device_name()];
      --sim->num_running;

      if (event.time > max_completion) {
        max_completion = event.time;
      }

      for (const Edge* out_edge : event.node->out_edges()) {
        Microseconds copy_time(0);
        const Node* out = out_edge->dst();
        if (!out_edge->IsControlEdge() &&
            event.node->assigned_device_name() != out->assigned_device_name()) {
          // Add an arbitrary 10 microseconds for each cross-device copy.
          // TODO(yuanbyu): Use below with the real cost model.
          // int index = out_edge->src_output();
          // Bytes nb = cost_model_->SizeEstimate(event.node, index);
          // copy_time = CostModel::CopyTimeEstimate(nb);
          copy_time = 10;
        }
        if ((*start_times)[out->id()] < event.time + copy_time) {
          (*start_times)[out->id()] = event.time + copy_time;
        }

        bool is_ready = UpdatePending(out_edge, &pending_count);
        if (is_ready) {
          Event e{out, (*start_times)[out->id()], false};
          event_queue.push(e);
        }
      }
    } else {
      // A node became ready: queue it on its assigned device.
      Sim* sim = device_states_[event.node->assigned_device_name()];
      sim->ready_nodes.push_back(event.node);
    }

    // Launch ready nodes on every device that still has a free slot.
    for (auto& x : device_states_) {
      Sim* sim = x.second;
      while (sim->num_running < sim->degree_parallelism &&
             !sim->ready_nodes.empty()) {
        Event e;
        e.node = GetNodeWithHighestPriority(sim->ready_nodes);
        // Remove the chosen node from the ready queue so that it cannot be
        // scheduled a second time.
        sim->ready_nodes.erase(std::find(sim->ready_nodes.begin(),
                                         sim->ready_nodes.end(), e.node));
        e.time = event.time + cost_model_->TimeEstimate(e.node);
        e.is_completion = true;
        event_queue.push(e);
        (*start_times)[e.node->id()] = event.time;
        ++sim->num_running;
      }
    }
  }
  return max_completion;
}

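// Returns the ready node with the highest priority, i.e. the smallest value
// in 'priority_' (for slack-based priorities, the node with the least slack).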
const Node* GreedyScheduler::GetNodeWithHighestPriority(
    const std::vector<const Node*>& nodes) {
  const Node* curr_node = nullptr;
  int64 curr_priority = kint64max;
  for (const Node* n : nodes) {
    if ((*priority_)[n->id()] < curr_priority) {
      curr_node = n;
      curr_priority = (*priority_)[n->id()];
    }
  }
  return curr_node;
}

PriorityScheduler::PriorityScheduler(const DeviceSet* devices,
                                     const CostModel* cost_model,
                                     const Graph* g)
    : devices_(devices), cost_model_(cost_model), graph_(g) {}

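// Schedules the graph by computing per-node slacks and then running the
// greedy list scheduler with slack as the priority, so that nodes with the
// least slack are launched first.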
Microseconds PriorityScheduler::ComputeSchedule(
    std::vector<Microseconds>* start_times) {
  std::vector<int64> slacks;
  SlackAnalysis slack(graph_, cost_model_);
  slack.ComputeSlack(&slacks);
  GreedyScheduler greedysched(devices_, cost_model_, graph_, &slacks);
  return greedysched.ComputeSchedule(start_times);
}

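// Uses each node's simulated start time as its priority, so that nodes that
// start earlier in the simulation also run earlier at execution time.
// Returns the simulated makespan.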
Microseconds PriorityScheduler::AssignPriorities(
    std::vector<int64>* priorities) {
  std::vector<Microseconds> start_times;
  Microseconds makespan = ComputeSchedule(&start_times);

  priorities->resize(graph_->num_node_ids());
  for (const Node* n : graph_->nodes()) {
    (*priorities)[n->id()] = start_times[n->id()].value();
  }
  return makespan;
}

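// Example usage (a sketch: it assumes an already-placed Graph, a DeviceSet,
// and a populated CostModel, all of which are provided by other parts of the
// distributed runtime):
//
//   PriorityScheduler scheduler(&device_set, &cost_model, &graph);
//   std::vector<int64> priorities;
//   Microseconds makespan = scheduler.AssignPriorities(&priorities);
//   // priorities[n->id()] now holds the simulated start time of node 'n'.
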
}  // namespace tensorflow