• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SERIAL_DEVICE_BATCH_SCHEDULER_H_
16 #define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SERIAL_DEVICE_BATCH_SCHEDULER_H_
17 
#include <algorithm>
#include <cmath>
#include <functional>
#include <memory>
#include <random>
#include <unordered_map>
#include <utility>
#include <vector>
24 
25 #include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/lib/core/threadpool.h"
29 #include "tensorflow/core/platform/cpu_info.h"
30 #include "tensorflow/core/platform/env.h"
31 #include "tensorflow/core/platform/thread_annotations.h"
32 #include "tensorflow/core/platform/types.h"
33 
34 namespace tensorflow {
35 namespace serving {
36 namespace internal {
37 template <typename TaskType>
38 class SDBSBatch;
39 
40 template <typename TaskType>
41 class SDBSQueue;
42 }  // namespace internal
43 
44 // EXPERIMENTAL: API MAY BE SUBJECTED TO SUDDEN CHANGES.
45 //
46 // Shared batch scheduler designed for batches which are processed by a serial
47 // device (e.g. GPU, TPU). When batch processing involves a mix of
48 // parallelizable cpu work and non-parallelizable on-device work, overall
49 // latency can be minimized by producing batches at a (load dependent) rate
50 // which keeps the serial device uniformly busy.
51 //
52 // SerialDeviceBatchScheduler (SDBS) controls the batching rate by limiting the
53 // allowed number of concurrently processed batches. Too large a limit causes
54 // batches to pile up behind the serial device, adding to the overall batch
55 // latency. Too small a limit underutilizes the serial device and harms latency
56 // by forcing batches to wait longer to be processed. Feedback from the device
57 // (i.e. avg number of batches directly pending on the device) is used to set
58 // the correct limit.
59 //
60 // SDBS groups requests into per model batches which are processed when a batch
61 // processing thread becomes available. SDBS prioritizes batches primarily by
62 // age (i.e. the batch's oldest request) along with a configurable preference
63 // for scheduling larger batches first.
64 
65 
66 template <typename TaskType>
67 class SerialDeviceBatchScheduler : public std::enable_shared_from_this<
68                                        SerialDeviceBatchScheduler<TaskType>> {
69  public:
70   ~SerialDeviceBatchScheduler();
71 
72   struct Options {
73     // The name to use for the pool of batch threads.
74     string thread_pool_name = {"batch_threads"};
75     // Maximum number of batch processing threads.
76     int64 num_batch_threads = port::NumSchedulableCPUs();
77     // Although batch selection is primarily based on age, this parameter
78     // specifies a preference for larger batches.  A full batch will be
79     // scheduled before an older, nearly empty batch as long as the age gap is
80     // less than full_batch_scheduling_boost_micros.  The optimal value for this
81     // parameter should be of order the batch processing latency, but must be
82     // chosen carefully, as too large a value will harm tail latency.
83     int64 full_batch_scheduling_boost_micros = 0;
84     // The environment to use (typically only overridden by test code).
85     Env* env = Env::Default();
86     // Initial limit for number of batches being concurrently processed.
87     int64 initial_in_flight_batches_limit = 3;
88     // Returns the current number of batches directly waiting to be processed
89     // by the serial device (i.e. GPU, TPU).
90     std::function<int64()> get_pending_on_serial_device;
91     // Desired average number of batches directly waiting to be processed by the
92     // serial device. Small numbers of O(1) should deliver the best latency.
93     double target_pending = 2;
94     // Number of batches between potential adjustments of
95     // in_flight_batches_limit.  Larger numbers will reduce noise, but will be
96     // less responsive to sudden changes in workload.
97     int64 batches_to_average_over = 1000;
98   };
99 
100   // Ownership is shared between the caller of Create() and any queues created
101   // via AddQueue().
102   static Status Create(
103       const Options& options,
104       std::shared_ptr<SerialDeviceBatchScheduler<TaskType>>* scheduler);
105 
106   struct QueueOptions {
107     // Maximum size of each batch.
108     int max_batch_size = 1000;
109     // Maximum number of enqueued (i.e. non-scheduled) batches.
110     int max_enqueued_batches = 10;
111   };
112 
113   using BatchProcessor = std::function<void(std::unique_ptr<Batch<TaskType>>)>;
114 
115   // Adds queue (and its callback) to be managed by this scheduler.
116   Status AddQueue(const QueueOptions& options,
117                   BatchProcessor process_batch_callback,
118                   std::unique_ptr<BatchScheduler<TaskType>>* queue);
119 
in_flight_batches_limit()120   double in_flight_batches_limit() {
121     mutex_lock l(mu_);
122     return in_flight_batches_limit_;
123   }
124 
recent_low_traffic_ratio()125   double recent_low_traffic_ratio() {
126     mutex_lock l(mu_);
127     return recent_low_traffic_ratio_;
128   }
129 
130  private:
131   // access to AddBatch(), RemoveQueue(), env().
132   friend class internal::SDBSQueue<TaskType>;
133 
134   explicit SerialDeviceBatchScheduler(const Options& options);
135 
136   // Continuously retrieves and processes batches.
137   void ProcessBatches();
138 
139   // Notifies scheduler of non-empty batch which is eligible for processing.
140   void AddBatch(const internal::SDBSBatch<TaskType>* batch);
141 
142   // Removes queue from scheduler.
143   void RemoveQueue(const internal::SDBSQueue<TaskType>* queue);
144 
env()145   Env* env() const { return options_.env; }
146 
147   const Options options_;
148 
149   // Collection of batches added by AddBatch. Owned by scheduler until they are
150   // released for processing.
151   std::vector<const internal::SDBSBatch<TaskType>*> batches_ TF_GUARDED_BY(mu_);
152 
153   // Unowned queues and callbacks added by AddQueue.
154   std::unordered_map<const internal::SDBSQueue<TaskType>*, BatchProcessor>
155       queues_and_callbacks_ TF_GUARDED_BY(mu_);
156 
157   // Responsible for running the batch processing callbacks.
158   std::unique_ptr<thread::ThreadPool> batch_thread_pool_;
159 
160   // Limit on number of batches which can be concurrently processed.
161   int64 in_flight_batches_limit_ TF_GUARDED_BY(mu_);
162 
163   // Number of batch processing threads.
164   int64 processing_threads_ TF_GUARDED_BY(mu_) = 0;
165 
166   // Number of batches processed since the last in_flight_batches_limit_
167   // adjustment.
168   int64 batch_count_ TF_GUARDED_BY(mu_) = 0;
169 
170   // Number of times since the last in_flight_batches_limit_ adjustment when a
171   // processing thread was available but there were no batches to process.
172   int64 no_batch_count_ TF_GUARDED_BY(mu_) = 0;
173 
174   // Sum of batches pending on the serial device since the last
175   // in_flight_batches_limit_ adjustment.
176   int64 pending_sum_ = 0;
177 
178   // Sum of batch latencies since the last in_flight_batches_limit_ adjustment.
179   int64 batch_latency_sum_ = 0;
180 
181   // Average period between which two consecutive batches begin processing.
182   int64 batch_period_micros_ = 0;
183 
184   // Moving average tracking the fraction of recent in_flight_batches_limit_
185   // adjustments where the external traffic was not high enough to provide
186   // useful feedback for an adjustment.
187   double recent_low_traffic_ratio_ = 0;
188 
189   mutex mu_;
190 
191   TF_DISALLOW_COPY_AND_ASSIGN(SerialDeviceBatchScheduler);
192 };
193 
194 //////////////////////////////////////////////////////////
195 // Implementation details follow. API users need not read.
196 
197 namespace internal {
198 // Consolidates tasks into batches, passing them off to the
199 // SerialDeviceBatchScheduler for processing.
200 template <typename TaskType>
201 class SDBSQueue : public BatchScheduler<TaskType> {
202  public:
203   using QueueOptions =
204       typename SerialDeviceBatchScheduler<TaskType>::QueueOptions;
205 
206   SDBSQueue(std::shared_ptr<SerialDeviceBatchScheduler<TaskType>> scheduler,
207             const QueueOptions& options);
208 
209   ~SDBSQueue() override;
210 
211   // Adds task to current batch. Fails if the task size is larger than the batch
212   // size or if the current batch is full and this queue's number of outstanding
213   // batches is at its maximum.
214   Status Schedule(std::unique_ptr<TaskType>* task) override;
215 
216   // Number of tasks waiting to be scheduled.
217   size_t NumEnqueuedTasks() const override;
218 
219   // Number of size 1 tasks which could currently be scheduled without failing.
220   size_t SchedulingCapacity() const override;
221 
222   // Notifies queue that a batch is about to be scheduled; the queue should not
223   // place any more tasks in this batch.
224   void ReleaseBatch(const SDBSBatch<TaskType>* batch);
225 
max_task_size()226   size_t max_task_size() const override { return options_.max_batch_size; }
227 
228  private:
229   std::shared_ptr<SerialDeviceBatchScheduler<TaskType>> scheduler_;
230   const QueueOptions options_;
231   // Owned by scheduler_.
232   SDBSBatch<TaskType>* current_batch_ TF_GUARDED_BY(mu_) = nullptr;
233   int64 num_enqueued_batches_ TF_GUARDED_BY(mu_) = 0;
234   int64 num_enqueued_tasks_ TF_GUARDED_BY(mu_) = 0;
235   mutable mutex mu_;
236   TF_DISALLOW_COPY_AND_ASSIGN(SDBSQueue);
237 };
238 
239 // Batch which remembers when and by whom it was created.
240 template <typename TaskType>
241 class SDBSBatch : public Batch<TaskType> {
242  public:
SDBSBatch(SDBSQueue<TaskType> * queue,int64 creation_time_micros)243   SDBSBatch(SDBSQueue<TaskType>* queue, int64 creation_time_micros)
244       : queue_(queue), creation_time_micros_(creation_time_micros) {}
245 
~SDBSBatch()246   ~SDBSBatch() override {}
247 
queue()248   SDBSQueue<TaskType>* queue() const { return queue_; }
249 
creation_time_micros()250   int64 creation_time_micros() const { return creation_time_micros_; }
251 
252  private:
253   SDBSQueue<TaskType>* queue_;
254   const int64 creation_time_micros_;
255   TF_DISALLOW_COPY_AND_ASSIGN(SDBSBatch);
256 };
257 }  // namespace internal
258 
259 // ---------------- SerialDeviceBatchScheduler ----------------
260 
261 template <typename TaskType>
Create(const Options & options,std::shared_ptr<SerialDeviceBatchScheduler<TaskType>> * scheduler)262 Status SerialDeviceBatchScheduler<TaskType>::Create(
263     const Options& options,
264     std::shared_ptr<SerialDeviceBatchScheduler<TaskType>>* scheduler) {
265   if (options.num_batch_threads < 1) {
266     return errors::InvalidArgument("num_batch_threads must be positive; was ",
267                                    options.num_batch_threads);
268   }
269   if (options.initial_in_flight_batches_limit < 1) {
270     return errors::InvalidArgument(
271         "initial_in_flight_batches_limit must be positive; was ",
272         options.initial_in_flight_batches_limit);
273   }
274   if (options.initial_in_flight_batches_limit > options.num_batch_threads) {
275     return errors::InvalidArgument(
276         "initial_in_flight_batches_limit (",
277         options.initial_in_flight_batches_limit,
278         ") should not be larger than num_batch_threads (",
279         options.num_batch_threads, ")");
280   }
281   if (options.full_batch_scheduling_boost_micros < 0) {
282     return errors::InvalidArgument(
283         "full_batch_scheduling_boost_micros can't be negative; was ",
284         options.full_batch_scheduling_boost_micros);
285   }
286   if (options.batches_to_average_over < 1) {
287     return errors::InvalidArgument(
288         "batches_to_average_over should be "
289         "greater than or equal to 1; was ",
290         options.batches_to_average_over);
291   }
292   if (options.target_pending <= 0) {
293     return errors::InvalidArgument(
294         "target_pending should be larger than zero; was ",
295         options.target_pending);
296   }
297   if (!options.get_pending_on_serial_device) {
298     return errors::InvalidArgument(
299         "get_pending_on_serial_device must be "
300         "specified");
301   }
302   scheduler->reset(new SerialDeviceBatchScheduler<TaskType>(options));
303   return Status::OK();
304 }
305 
306 template <typename TaskType>
SerialDeviceBatchScheduler(const Options & options)307 SerialDeviceBatchScheduler<TaskType>::SerialDeviceBatchScheduler(
308     const Options& options)
309     : options_(options),
310       in_flight_batches_limit_(options.initial_in_flight_batches_limit),
311       processing_threads_(options.initial_in_flight_batches_limit) {
312   batch_thread_pool_.reset(new thread::ThreadPool(
313       env(), options.thread_pool_name, options.num_batch_threads));
314   for (int i = 0; i < processing_threads_; i++) {
315     batch_thread_pool_->Schedule(
316         std::bind(&SerialDeviceBatchScheduler<TaskType>::ProcessBatches, this));
317   }
318 }
319 
320 template <typename TaskType>
~SerialDeviceBatchScheduler()321 SerialDeviceBatchScheduler<TaskType>::~SerialDeviceBatchScheduler() {
322   // Signal processing threads to exit.
323   {
324     mutex_lock l(mu_);
325     processing_threads_ = 0;
326   }
327   // Hangs until all threads finish.
328   batch_thread_pool_.reset();
329 }
330 
331 template <typename TaskType>
AddQueue(const QueueOptions & options,BatchProcessor process_batch_callback,std::unique_ptr<BatchScheduler<TaskType>> * queue)332 Status SerialDeviceBatchScheduler<TaskType>::AddQueue(
333     const QueueOptions& options, BatchProcessor process_batch_callback,
334     std::unique_ptr<BatchScheduler<TaskType>>* queue) {
335   if (options.max_batch_size <= 0) {
336     return errors::InvalidArgument("max_batch_size must be positive; was ",
337                                    options.max_batch_size);
338   }
339   if (options.max_enqueued_batches <= 0) {
340     return errors::InvalidArgument(
341         "max_enqueued_batches must be positive; was ",
342         options.max_enqueued_batches);
343   }
344   internal::SDBSQueue<TaskType>* SDBS_queue_raw;
345   queue->reset(SDBS_queue_raw = new internal::SDBSQueue<TaskType>(
346                    this->shared_from_this(), options));
347   mutex_lock l(mu_);
348   queues_and_callbacks_[SDBS_queue_raw] = process_batch_callback;
349   return Status::OK();
350 }
351 
352 template <typename TaskType>
AddBatch(const internal::SDBSBatch<TaskType> * batch)353 void SerialDeviceBatchScheduler<TaskType>::AddBatch(
354     const internal::SDBSBatch<TaskType>* batch) {
355   mutex_lock l(mu_);
356   batches_.push_back(batch);
357 }
358 
359 template <typename TaskType>
RemoveQueue(const internal::SDBSQueue<TaskType> * queue)360 void SerialDeviceBatchScheduler<TaskType>::RemoveQueue(
361     const internal::SDBSQueue<TaskType>* queue) {
362   mutex_lock l(mu_);
363   queues_and_callbacks_.erase(queue);
364 }
365 
366 template <typename TaskType>
ProcessBatches()367 void SerialDeviceBatchScheduler<TaskType>::ProcessBatches() {
368   const int64 kIdleThreadSleepTimeMicros = 1000;
369   const double kMaxNoBatchRatio = .1;
370   const double kLowTrafficMovingAverageFactor = .1;
371   for (;;) {
372     mu_.lock();
373     if (processing_threads_ < 1 ||
374         processing_threads_ > in_flight_batches_limit_) {
375       processing_threads_--;
376       mu_.unlock();
377       break;
378     }
379     if (batches_.empty()) {
380       no_batch_count_++;
381       int64 sleep_time = batch_period_micros_ ? batch_period_micros_
382                                               : kIdleThreadSleepTimeMicros;
383       mu_.unlock();
384       env()->SleepForMicroseconds(sleep_time);
385       continue;
386     }
387     auto best_it = batches_.begin();
388     double best_score =
389         (*best_it)->creation_time_micros() -
390         options_.full_batch_scheduling_boost_micros * (*best_it)->size() /
391             static_cast<double>((*best_it)->queue()->max_task_size());
392     for (auto it = batches_.begin() + 1; it != batches_.end(); it++) {
393       const double score =
394           (*it)->creation_time_micros() -
395           options_.full_batch_scheduling_boost_micros * (*it)->size() /
396               static_cast<double>((*it)->queue()->max_task_size());
397       if (score < best_score) {
398         best_score = score;
399         best_it = it;
400       }
401     }
402     const internal::SDBSBatch<TaskType>* batch = *best_it;
403     batches_.erase(best_it);
404     // Queue may destroy itself after ReleaseBatch is called.
405     batch->queue()->ReleaseBatch(batch);
406     auto callback = queues_and_callbacks_[batch->queue()];
407     mu_.unlock();
408     int64 start_time = env()->NowMicros();
409     callback(std::unique_ptr<Batch<TaskType>>(
410         const_cast<internal::SDBSBatch<TaskType>*>(batch)));
411     int64 end_time = env()->NowMicros();
412     mu_.lock();
413     batch_count_++;
414     batch_latency_sum_ += end_time - start_time;
415     pending_sum_ += options_.get_pending_on_serial_device();
416     if (batch_count_ == options_.batches_to_average_over) {
417       recent_low_traffic_ratio_ *= (1 - kLowTrafficMovingAverageFactor);
418       // Only adjust in_flight_batches_limit_ if external load is large enough
419       // to consistently provide batches. Otherwise we would (mistakenly) assume
420       // that the device is underutilized because in_flight_batches_limit_ is
421       // too small.
422       if (no_batch_count_ < kMaxNoBatchRatio * batch_count_) {
423         double avg_pending = pending_sum_ / static_cast<double>(batch_count_);
424         // Avg processing time / # of concurrent batches gives the avg period
425         // between which two consecutive batches begin processing. Used to set a
426         // reasonable sleep time for idle batch processing threads.
427         batch_period_micros_ =
428             batch_latency_sum_ / batch_count_ / in_flight_batches_limit_;
429         // When the processing pipeline is consistently busy, the average number
430         // of pending batches differs from in_flight_batches_limit_ by a
431         // load-dependent offset. Adjust in_flight_batches_limit_to maintain
432         // the desired target pending.
433         in_flight_batches_limit_ +=
434             std::round(options_.target_pending - avg_pending);
435         in_flight_batches_limit_ = std::max(in_flight_batches_limit_, int64{1});
436         in_flight_batches_limit_ =
437             std::min(in_flight_batches_limit_, options_.num_batch_threads);
438         // Add extra processing threads if necessary.
439         if (processing_threads_ > 0 &&
440             processing_threads_ < in_flight_batches_limit_) {
441           int extra_threads = in_flight_batches_limit_ - processing_threads_;
442           for (int i = 0; i < extra_threads; i++) {
443             batch_thread_pool_->Schedule(std::bind(
444                 &SerialDeviceBatchScheduler<TaskType>::ProcessBatches, this));
445           }
446           processing_threads_ = in_flight_batches_limit_;
447         }
448       } else {
449         recent_low_traffic_ratio_ += kLowTrafficMovingAverageFactor;
450       }
451       batch_count_ = 0;
452       no_batch_count_ = 0;
453       pending_sum_ = 0;
454       batch_latency_sum_ = 0;
455     }
456     mu_.unlock();
457   }
458 }
459 
460 // ---------------- SDBSQueue ----------------
461 
462 namespace internal {
463 template <typename TaskType>
SDBSQueue(std::shared_ptr<SerialDeviceBatchScheduler<TaskType>> scheduler,const QueueOptions & options)464 SDBSQueue<TaskType>::SDBSQueue(
465     std::shared_ptr<SerialDeviceBatchScheduler<TaskType>> scheduler,
466     const QueueOptions& options)
467     : scheduler_(scheduler), options_(options) {}
468 
469 template <typename TaskType>
~SDBSQueue()470 SDBSQueue<TaskType>::~SDBSQueue() {
471   // Wait until last batch has been scheduled.
472   const int kSleepMicros = 1000;
473   for (;;) {
474     {
475       mutex_lock l(mu_);
476       if (num_enqueued_batches_ == 0) {
477         break;
478       }
479     }
480     scheduler_->env()->SleepForMicroseconds(kSleepMicros);
481   }
482   scheduler_->RemoveQueue(this);
483 }
484 
485 template <typename TaskType>
Schedule(std::unique_ptr<TaskType> * task)486 Status SDBSQueue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
487   SDBSBatch<TaskType>* new_batch = nullptr;
488   size_t size = (*task)->size();
489   if (size > options_.max_batch_size) {
490     return errors::InvalidArgument("Task size ", size,
491                                    " is larger than maximum batch size ",
492                                    options_.max_batch_size);
493   }
494   {
495     mutex_lock l(mu_);
496     // Current batch is full, create another if allowed.
497     if (current_batch_ &&
498         current_batch_->size() + size > options_.max_batch_size) {
499       if (num_enqueued_batches_ >= options_.max_enqueued_batches) {
500         return errors::Unavailable("The batch scheduling queue is full");
501       }
502       current_batch_->Close();
503       current_batch_ = nullptr;
504     }
505     if (!current_batch_) {
506       num_enqueued_batches_++;
507       current_batch_ = new_batch =
508           new SDBSBatch<TaskType>(this, scheduler_->env()->NowMicros());
509     }
510     current_batch_->AddTask(std::move(*task));
511     num_enqueued_tasks_++;
512   }
513   // AddBatch must be called outside of lock, since it may call ReleaseBatch.
514   if (new_batch != nullptr) scheduler_->AddBatch(new_batch);
515   return Status::OK();
516 }
517 
518 template <typename TaskType>
ReleaseBatch(const SDBSBatch<TaskType> * batch)519 void SDBSQueue<TaskType>::ReleaseBatch(const SDBSBatch<TaskType>* batch) {
520   mutex_lock l(mu_);
521   num_enqueued_batches_--;
522   num_enqueued_tasks_ -= batch->num_tasks();
523   if (batch == current_batch_) {
524     current_batch_->Close();
525     current_batch_ = nullptr;
526   }
527 }
528 
529 template <typename TaskType>
NumEnqueuedTasks()530 size_t SDBSQueue<TaskType>::NumEnqueuedTasks() const {
531   mutex_lock l(mu_);
532   return num_enqueued_tasks_;
533 }
534 
535 template <typename TaskType>
SchedulingCapacity()536 size_t SDBSQueue<TaskType>::SchedulingCapacity() const {
537   mutex_lock l(mu_);
538   const int current_batch_capacity =
539       current_batch_ ? options_.max_batch_size - current_batch_->size() : 0;
540   const int spare_batches =
541       options_.max_enqueued_batches - num_enqueued_batches_;
542   return spare_batches * options_.max_batch_size + current_batch_capacity;
543 }
544 }  // namespace internal
545 }  // namespace serving
546 }  // namespace tensorflow
547 
548 #endif  // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SERIAL_DEVICE_BATCH_SCHEDULER_H_
549