1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SERIAL_DEVICE_BATCH_SCHEDULER_H_
16 #define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SERIAL_DEVICE_BATCH_SCHEDULER_H_
17
#include <algorithm>
#include <cmath>
#include <functional>
#include <memory>
#include <random>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
33
34 namespace tensorflow {
35 namespace serving {
namespace internal {
// Forward declaration (defined later in this header): a Batch that records
// the queue which created it and its creation time.
template <typename TaskType>
class SDBSBatch;

// Forward declaration (defined later in this header): the per-queue
// BatchScheduler implementation created by
// SerialDeviceBatchScheduler::AddQueue().
template <typename TaskType>
class SDBSQueue;
}  // namespace internal
43
44 // EXPERIMENTAL: API MAY BE SUBJECTED TO SUDDEN CHANGES.
45 //
46 // Shared batch scheduler designed for batches which are processed by a serial
47 // device (e.g. GPU, TPU). When batch processing involves a mix of
48 // parallelizable cpu work and non-parallelizable on-device work, overall
49 // latency can be minimized by producing batches at a (load dependent) rate
50 // which keeps the serial device uniformly busy.
51 //
52 // SerialDeviceBatchScheduler (SDBS) controls the batching rate by limiting the
53 // allowed number of concurrently processed batches. Too large a limit causes
54 // batches to pile up behind the serial device, adding to the overall batch
55 // latency. Too small a limit underutilizes the serial device and harms latency
56 // by forcing batches to wait longer to be processed. Feedback from the device
57 // (i.e. avg number of batches directly pending on the device) is used to set
58 // the correct limit.
59 //
60 // SDBS groups requests into per model batches which are processed when a batch
61 // processing thread becomes available. SDBS prioritizes batches primarily by
62 // age (i.e. the batch's oldest request) along with a configurable preference
63 // for scheduling larger batches first.
64
65
66 template <typename TaskType>
67 class SerialDeviceBatchScheduler : public std::enable_shared_from_this<
68 SerialDeviceBatchScheduler<TaskType>> {
69 public:
70 ~SerialDeviceBatchScheduler();
71
72 struct Options {
73 // The name to use for the pool of batch threads.
74 string thread_pool_name = {"batch_threads"};
75 // Maximum number of batch processing threads.
76 int64 num_batch_threads = port::NumSchedulableCPUs();
77 // Although batch selection is primarily based on age, this parameter
78 // specifies a preference for larger batches. A full batch will be
79 // scheduled before an older, nearly empty batch as long as the age gap is
80 // less than full_batch_scheduling_boost_micros. The optimal value for this
81 // parameter should be of order the batch processing latency, but must be
82 // chosen carefully, as too large a value will harm tail latency.
83 int64 full_batch_scheduling_boost_micros = 0;
84 // The environment to use (typically only overridden by test code).
85 Env* env = Env::Default();
86 // Initial limit for number of batches being concurrently processed.
87 int64 initial_in_flight_batches_limit = 3;
88 // Returns the current number of batches directly waiting to be processed
89 // by the serial device (i.e. GPU, TPU).
90 std::function<int64()> get_pending_on_serial_device;
91 // Desired average number of batches directly waiting to be processed by the
92 // serial device. Small numbers of O(1) should deliver the best latency.
93 double target_pending = 2;
94 // Number of batches between potential adjustments of
95 // in_flight_batches_limit. Larger numbers will reduce noise, but will be
96 // less responsive to sudden changes in workload.
97 int64 batches_to_average_over = 1000;
98 };
99
100 // Ownership is shared between the caller of Create() and any queues created
101 // via AddQueue().
102 static Status Create(
103 const Options& options,
104 std::shared_ptr<SerialDeviceBatchScheduler<TaskType>>* scheduler);
105
106 struct QueueOptions {
107 // Maximum size of each batch.
108 int max_batch_size = 1000;
109 // Maximum number of enqueued (i.e. non-scheduled) batches.
110 int max_enqueued_batches = 10;
111 };
112
113 using BatchProcessor = std::function<void(std::unique_ptr<Batch<TaskType>>)>;
114
115 // Adds queue (and its callback) to be managed by this scheduler.
116 Status AddQueue(const QueueOptions& options,
117 BatchProcessor process_batch_callback,
118 std::unique_ptr<BatchScheduler<TaskType>>* queue);
119
in_flight_batches_limit()120 double in_flight_batches_limit() {
121 mutex_lock l(mu_);
122 return in_flight_batches_limit_;
123 }
124
recent_low_traffic_ratio()125 double recent_low_traffic_ratio() {
126 mutex_lock l(mu_);
127 return recent_low_traffic_ratio_;
128 }
129
130 private:
131 // access to AddBatch(), RemoveQueue(), env().
132 friend class internal::SDBSQueue<TaskType>;
133
134 explicit SerialDeviceBatchScheduler(const Options& options);
135
136 // Continuously retrieves and processes batches.
137 void ProcessBatches();
138
139 // Notifies scheduler of non-empty batch which is eligible for processing.
140 void AddBatch(const internal::SDBSBatch<TaskType>* batch);
141
142 // Removes queue from scheduler.
143 void RemoveQueue(const internal::SDBSQueue<TaskType>* queue);
144
env()145 Env* env() const { return options_.env; }
146
147 const Options options_;
148
149 // Collection of batches added by AddBatch. Owned by scheduler until they are
150 // released for processing.
151 std::vector<const internal::SDBSBatch<TaskType>*> batches_ TF_GUARDED_BY(mu_);
152
153 // Unowned queues and callbacks added by AddQueue.
154 std::unordered_map<const internal::SDBSQueue<TaskType>*, BatchProcessor>
155 queues_and_callbacks_ TF_GUARDED_BY(mu_);
156
157 // Responsible for running the batch processing callbacks.
158 std::unique_ptr<thread::ThreadPool> batch_thread_pool_;
159
160 // Limit on number of batches which can be concurrently processed.
161 int64 in_flight_batches_limit_ TF_GUARDED_BY(mu_);
162
163 // Number of batch processing threads.
164 int64 processing_threads_ TF_GUARDED_BY(mu_) = 0;
165
166 // Number of batches processed since the last in_flight_batches_limit_
167 // adjustment.
168 int64 batch_count_ TF_GUARDED_BY(mu_) = 0;
169
170 // Number of times since the last in_flight_batches_limit_ adjustment when a
171 // processing thread was available but there were no batches to process.
172 int64 no_batch_count_ TF_GUARDED_BY(mu_) = 0;
173
174 // Sum of batches pending on the serial device since the last
175 // in_flight_batches_limit_ adjustment.
176 int64 pending_sum_ = 0;
177
178 // Sum of batch latencies since the last in_flight_batches_limit_ adjustment.
179 int64 batch_latency_sum_ = 0;
180
181 // Average period between which two consecutive batches begin processing.
182 int64 batch_period_micros_ = 0;
183
184 // Moving average tracking the fraction of recent in_flight_batches_limit_
185 // adjustments where the external traffic was not high enough to provide
186 // useful feedback for an adjustment.
187 double recent_low_traffic_ratio_ = 0;
188
189 mutex mu_;
190
191 TF_DISALLOW_COPY_AND_ASSIGN(SerialDeviceBatchScheduler);
192 };
193
194 //////////////////////////////////////////////////////////
195 // Implementation details follow. API users need not read.
196
197 namespace internal {
// Consolidates tasks into batches, passing them off to the
// SerialDeviceBatchScheduler for processing.
template <typename TaskType>
class SDBSQueue : public BatchScheduler<TaskType> {
 public:
  using QueueOptions =
      typename SerialDeviceBatchScheduler<TaskType>::QueueOptions;

  // Holds a shared reference to `scheduler` and a copy of `options`.
  SDBSQueue(std::shared_ptr<SerialDeviceBatchScheduler<TaskType>> scheduler,
            const QueueOptions& options);

  // Blocks (polling) until no batches of this queue remain enqueued, then
  // removes the queue from the scheduler.
  ~SDBSQueue() override;

  // Adds task to current batch. Fails if the task size is larger than the batch
  // size or if the current batch is full and this queue's number of outstanding
  // batches is at its maximum.
  Status Schedule(std::unique_ptr<TaskType>* task) override;

  // Number of tasks waiting to be scheduled.
  size_t NumEnqueuedTasks() const override;

  // Number of size 1 tasks which could currently be scheduled without failing.
  size_t SchedulingCapacity() const override;

  // Notifies queue that a batch is about to be scheduled; the queue should not
  // place any more tasks in this batch.
  void ReleaseBatch(const SDBSBatch<TaskType>* batch);

  // The largest task this queue accepts equals the configured max batch size.
  size_t max_task_size() const override { return options_.max_batch_size; }

 private:
  // Shared ownership of the scheduler this queue feeds.
  std::shared_ptr<SerialDeviceBatchScheduler<TaskType>> scheduler_;
  const QueueOptions options_;
  // Owned by scheduler_.
  SDBSBatch<TaskType>* current_batch_ TF_GUARDED_BY(mu_) = nullptr;
  // Batches created by Schedule() that have not yet been passed to
  // ReleaseBatch().
  int64 num_enqueued_batches_ TF_GUARDED_BY(mu_) = 0;
  // Incremented once per scheduled task; reduced by the released batch's
  // num_tasks() in ReleaseBatch().
  int64 num_enqueued_tasks_ TF_GUARDED_BY(mu_) = 0;
  // Mutable so the const accessors can lock it.
  mutable mutex mu_;
  TF_DISALLOW_COPY_AND_ASSIGN(SDBSQueue);
};
238
239 // Batch which remembers when and by whom it was created.
240 template <typename TaskType>
241 class SDBSBatch : public Batch<TaskType> {
242 public:
SDBSBatch(SDBSQueue<TaskType> * queue,int64 creation_time_micros)243 SDBSBatch(SDBSQueue<TaskType>* queue, int64 creation_time_micros)
244 : queue_(queue), creation_time_micros_(creation_time_micros) {}
245
~SDBSBatch()246 ~SDBSBatch() override {}
247
queue()248 SDBSQueue<TaskType>* queue() const { return queue_; }
249
creation_time_micros()250 int64 creation_time_micros() const { return creation_time_micros_; }
251
252 private:
253 SDBSQueue<TaskType>* queue_;
254 const int64 creation_time_micros_;
255 TF_DISALLOW_COPY_AND_ASSIGN(SDBSBatch);
256 };
257 } // namespace internal
258
259 // ---------------- SerialDeviceBatchScheduler ----------------
260
261 template <typename TaskType>
Create(const Options & options,std::shared_ptr<SerialDeviceBatchScheduler<TaskType>> * scheduler)262 Status SerialDeviceBatchScheduler<TaskType>::Create(
263 const Options& options,
264 std::shared_ptr<SerialDeviceBatchScheduler<TaskType>>* scheduler) {
265 if (options.num_batch_threads < 1) {
266 return errors::InvalidArgument("num_batch_threads must be positive; was ",
267 options.num_batch_threads);
268 }
269 if (options.initial_in_flight_batches_limit < 1) {
270 return errors::InvalidArgument(
271 "initial_in_flight_batches_limit must be positive; was ",
272 options.initial_in_flight_batches_limit);
273 }
274 if (options.initial_in_flight_batches_limit > options.num_batch_threads) {
275 return errors::InvalidArgument(
276 "initial_in_flight_batches_limit (",
277 options.initial_in_flight_batches_limit,
278 ") should not be larger than num_batch_threads (",
279 options.num_batch_threads, ")");
280 }
281 if (options.full_batch_scheduling_boost_micros < 0) {
282 return errors::InvalidArgument(
283 "full_batch_scheduling_boost_micros can't be negative; was ",
284 options.full_batch_scheduling_boost_micros);
285 }
286 if (options.batches_to_average_over < 1) {
287 return errors::InvalidArgument(
288 "batches_to_average_over should be "
289 "greater than or equal to 1; was ",
290 options.batches_to_average_over);
291 }
292 if (options.target_pending <= 0) {
293 return errors::InvalidArgument(
294 "target_pending should be larger than zero; was ",
295 options.target_pending);
296 }
297 if (!options.get_pending_on_serial_device) {
298 return errors::InvalidArgument(
299 "get_pending_on_serial_device must be "
300 "specified");
301 }
302 scheduler->reset(new SerialDeviceBatchScheduler<TaskType>(options));
303 return Status::OK();
304 }
305
306 template <typename TaskType>
SerialDeviceBatchScheduler(const Options & options)307 SerialDeviceBatchScheduler<TaskType>::SerialDeviceBatchScheduler(
308 const Options& options)
309 : options_(options),
310 in_flight_batches_limit_(options.initial_in_flight_batches_limit),
311 processing_threads_(options.initial_in_flight_batches_limit) {
312 batch_thread_pool_.reset(new thread::ThreadPool(
313 env(), options.thread_pool_name, options.num_batch_threads));
314 for (int i = 0; i < processing_threads_; i++) {
315 batch_thread_pool_->Schedule(
316 std::bind(&SerialDeviceBatchScheduler<TaskType>::ProcessBatches, this));
317 }
318 }
319
320 template <typename TaskType>
~SerialDeviceBatchScheduler()321 SerialDeviceBatchScheduler<TaskType>::~SerialDeviceBatchScheduler() {
322 // Signal processing threads to exit.
323 {
324 mutex_lock l(mu_);
325 processing_threads_ = 0;
326 }
327 // Hangs until all threads finish.
328 batch_thread_pool_.reset();
329 }
330
331 template <typename TaskType>
AddQueue(const QueueOptions & options,BatchProcessor process_batch_callback,std::unique_ptr<BatchScheduler<TaskType>> * queue)332 Status SerialDeviceBatchScheduler<TaskType>::AddQueue(
333 const QueueOptions& options, BatchProcessor process_batch_callback,
334 std::unique_ptr<BatchScheduler<TaskType>>* queue) {
335 if (options.max_batch_size <= 0) {
336 return errors::InvalidArgument("max_batch_size must be positive; was ",
337 options.max_batch_size);
338 }
339 if (options.max_enqueued_batches <= 0) {
340 return errors::InvalidArgument(
341 "max_enqueued_batches must be positive; was ",
342 options.max_enqueued_batches);
343 }
344 internal::SDBSQueue<TaskType>* SDBS_queue_raw;
345 queue->reset(SDBS_queue_raw = new internal::SDBSQueue<TaskType>(
346 this->shared_from_this(), options));
347 mutex_lock l(mu_);
348 queues_and_callbacks_[SDBS_queue_raw] = process_batch_callback;
349 return Status::OK();
350 }
351
352 template <typename TaskType>
AddBatch(const internal::SDBSBatch<TaskType> * batch)353 void SerialDeviceBatchScheduler<TaskType>::AddBatch(
354 const internal::SDBSBatch<TaskType>* batch) {
355 mutex_lock l(mu_);
356 batches_.push_back(batch);
357 }
358
359 template <typename TaskType>
RemoveQueue(const internal::SDBSQueue<TaskType> * queue)360 void SerialDeviceBatchScheduler<TaskType>::RemoveQueue(
361 const internal::SDBSQueue<TaskType>* queue) {
362 mutex_lock l(mu_);
363 queues_and_callbacks_.erase(queue);
364 }
365
366 template <typename TaskType>
ProcessBatches()367 void SerialDeviceBatchScheduler<TaskType>::ProcessBatches() {
368 const int64 kIdleThreadSleepTimeMicros = 1000;
369 const double kMaxNoBatchRatio = .1;
370 const double kLowTrafficMovingAverageFactor = .1;
371 for (;;) {
372 mu_.lock();
373 if (processing_threads_ < 1 ||
374 processing_threads_ > in_flight_batches_limit_) {
375 processing_threads_--;
376 mu_.unlock();
377 break;
378 }
379 if (batches_.empty()) {
380 no_batch_count_++;
381 int64 sleep_time = batch_period_micros_ ? batch_period_micros_
382 : kIdleThreadSleepTimeMicros;
383 mu_.unlock();
384 env()->SleepForMicroseconds(sleep_time);
385 continue;
386 }
387 auto best_it = batches_.begin();
388 double best_score =
389 (*best_it)->creation_time_micros() -
390 options_.full_batch_scheduling_boost_micros * (*best_it)->size() /
391 static_cast<double>((*best_it)->queue()->max_task_size());
392 for (auto it = batches_.begin() + 1; it != batches_.end(); it++) {
393 const double score =
394 (*it)->creation_time_micros() -
395 options_.full_batch_scheduling_boost_micros * (*it)->size() /
396 static_cast<double>((*it)->queue()->max_task_size());
397 if (score < best_score) {
398 best_score = score;
399 best_it = it;
400 }
401 }
402 const internal::SDBSBatch<TaskType>* batch = *best_it;
403 batches_.erase(best_it);
404 // Queue may destroy itself after ReleaseBatch is called.
405 batch->queue()->ReleaseBatch(batch);
406 auto callback = queues_and_callbacks_[batch->queue()];
407 mu_.unlock();
408 int64 start_time = env()->NowMicros();
409 callback(std::unique_ptr<Batch<TaskType>>(
410 const_cast<internal::SDBSBatch<TaskType>*>(batch)));
411 int64 end_time = env()->NowMicros();
412 mu_.lock();
413 batch_count_++;
414 batch_latency_sum_ += end_time - start_time;
415 pending_sum_ += options_.get_pending_on_serial_device();
416 if (batch_count_ == options_.batches_to_average_over) {
417 recent_low_traffic_ratio_ *= (1 - kLowTrafficMovingAverageFactor);
418 // Only adjust in_flight_batches_limit_ if external load is large enough
419 // to consistently provide batches. Otherwise we would (mistakenly) assume
420 // that the device is underutilized because in_flight_batches_limit_ is
421 // too small.
422 if (no_batch_count_ < kMaxNoBatchRatio * batch_count_) {
423 double avg_pending = pending_sum_ / static_cast<double>(batch_count_);
424 // Avg processing time / # of concurrent batches gives the avg period
425 // between which two consecutive batches begin processing. Used to set a
426 // reasonable sleep time for idle batch processing threads.
427 batch_period_micros_ =
428 batch_latency_sum_ / batch_count_ / in_flight_batches_limit_;
429 // When the processing pipeline is consistently busy, the average number
430 // of pending batches differs from in_flight_batches_limit_ by a
431 // load-dependent offset. Adjust in_flight_batches_limit_to maintain
432 // the desired target pending.
433 in_flight_batches_limit_ +=
434 std::round(options_.target_pending - avg_pending);
435 in_flight_batches_limit_ = std::max(in_flight_batches_limit_, int64{1});
436 in_flight_batches_limit_ =
437 std::min(in_flight_batches_limit_, options_.num_batch_threads);
438 // Add extra processing threads if necessary.
439 if (processing_threads_ > 0 &&
440 processing_threads_ < in_flight_batches_limit_) {
441 int extra_threads = in_flight_batches_limit_ - processing_threads_;
442 for (int i = 0; i < extra_threads; i++) {
443 batch_thread_pool_->Schedule(std::bind(
444 &SerialDeviceBatchScheduler<TaskType>::ProcessBatches, this));
445 }
446 processing_threads_ = in_flight_batches_limit_;
447 }
448 } else {
449 recent_low_traffic_ratio_ += kLowTrafficMovingAverageFactor;
450 }
451 batch_count_ = 0;
452 no_batch_count_ = 0;
453 pending_sum_ = 0;
454 batch_latency_sum_ = 0;
455 }
456 mu_.unlock();
457 }
458 }
459
460 // ---------------- SDBSQueue ----------------
461
462 namespace internal {
463 template <typename TaskType>
SDBSQueue(std::shared_ptr<SerialDeviceBatchScheduler<TaskType>> scheduler,const QueueOptions & options)464 SDBSQueue<TaskType>::SDBSQueue(
465 std::shared_ptr<SerialDeviceBatchScheduler<TaskType>> scheduler,
466 const QueueOptions& options)
467 : scheduler_(scheduler), options_(options) {}
468
469 template <typename TaskType>
~SDBSQueue()470 SDBSQueue<TaskType>::~SDBSQueue() {
471 // Wait until last batch has been scheduled.
472 const int kSleepMicros = 1000;
473 for (;;) {
474 {
475 mutex_lock l(mu_);
476 if (num_enqueued_batches_ == 0) {
477 break;
478 }
479 }
480 scheduler_->env()->SleepForMicroseconds(kSleepMicros);
481 }
482 scheduler_->RemoveQueue(this);
483 }
484
485 template <typename TaskType>
Schedule(std::unique_ptr<TaskType> * task)486 Status SDBSQueue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
487 SDBSBatch<TaskType>* new_batch = nullptr;
488 size_t size = (*task)->size();
489 if (size > options_.max_batch_size) {
490 return errors::InvalidArgument("Task size ", size,
491 " is larger than maximum batch size ",
492 options_.max_batch_size);
493 }
494 {
495 mutex_lock l(mu_);
496 // Current batch is full, create another if allowed.
497 if (current_batch_ &&
498 current_batch_->size() + size > options_.max_batch_size) {
499 if (num_enqueued_batches_ >= options_.max_enqueued_batches) {
500 return errors::Unavailable("The batch scheduling queue is full");
501 }
502 current_batch_->Close();
503 current_batch_ = nullptr;
504 }
505 if (!current_batch_) {
506 num_enqueued_batches_++;
507 current_batch_ = new_batch =
508 new SDBSBatch<TaskType>(this, scheduler_->env()->NowMicros());
509 }
510 current_batch_->AddTask(std::move(*task));
511 num_enqueued_tasks_++;
512 }
513 // AddBatch must be called outside of lock, since it may call ReleaseBatch.
514 if (new_batch != nullptr) scheduler_->AddBatch(new_batch);
515 return Status::OK();
516 }
517
518 template <typename TaskType>
ReleaseBatch(const SDBSBatch<TaskType> * batch)519 void SDBSQueue<TaskType>::ReleaseBatch(const SDBSBatch<TaskType>* batch) {
520 mutex_lock l(mu_);
521 num_enqueued_batches_--;
522 num_enqueued_tasks_ -= batch->num_tasks();
523 if (batch == current_batch_) {
524 current_batch_->Close();
525 current_batch_ = nullptr;
526 }
527 }
528
529 template <typename TaskType>
NumEnqueuedTasks()530 size_t SDBSQueue<TaskType>::NumEnqueuedTasks() const {
531 mutex_lock l(mu_);
532 return num_enqueued_tasks_;
533 }
534
535 template <typename TaskType>
SchedulingCapacity()536 size_t SDBSQueue<TaskType>::SchedulingCapacity() const {
537 mutex_lock l(mu_);
538 const int current_batch_capacity =
539 current_batch_ ? options_.max_batch_size - current_batch_->size() : 0;
540 const int spare_batches =
541 options_.max_enqueued_batches - num_enqueued_batches_;
542 return spare_batches * options_.max_batch_size + current_batch_capacity;
543 }
544 } // namespace internal
545 } // namespace serving
546 } // namespace tensorflow
547
548 #endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SERIAL_DEVICE_BATCH_SCHEDULER_H_
549