/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_

#include <algorithm>
#include <atomic>
#include <deque>
#include <functional>
#include <limits>
#include <memory>
#include <random>
#include <unordered_map>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/periodic_function.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/threadpool_interface.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"

namespace tensorflow {
namespace serving {
namespace internal {
template <typename TaskType>
class ASBSBatch;

template <typename TaskType>
class ASBSQueue;
}  // namespace internal

// Shared batch scheduler designed to minimize latency. The scheduler keeps
// track of a number of queues (one per model or model version) which are
// continuously enqueuing requests. The scheduler groups the requests into
// batches which it periodically sends off for processing (see
// shared_batch_scheduler.h for more details). AdaptiveSharedBatchScheduler
// (ASBS) prioritizes batches primarily by age (i.e. the batch's oldest
// request) along with a configurable preference for scheduling larger batches
// first.
//
// ASBS tries to keep the system busy by maintaining an adjustable number of
// concurrently processed batches.  If a new batch is created, and the number
// of in-flight batches is below the target, the next (i.e. oldest) batch is
// immediately scheduled.  Similarly, when a batch finishes processing, the
// target is rechecked, and another batch may be scheduled.  To avoid the need
// to carefully tune the target for each workload, model type, platform, etc.,
// it is dynamically adjusted in order to provide the lowest average latency.
//
// Some potential use cases:
// Hardware Accelerators (GPUs & TPUs) - If some phase of batch processing
//   involves serial processing by a device, from a latency perspective it is
//   desirable to keep the device evenly loaded, avoiding the need to wait for
//   the device to process prior batches.
// CPU utilization - If the batch processing is cpu dominated, you can reap
//   latency gains when underutilized by increasing the processing rate, but
//   back the rate off when the load increases to avoid overload.
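//
// A minimal usage sketch (hypothetical: MyTask stands in for a client-defined
// BatchTask subclass, and ProcessBatch for the client's batch-processing
// function; neither is defined in this header):
//
//   AdaptiveSharedBatchScheduler<MyTask>::Options options;
//   std::shared_ptr<AdaptiveSharedBatchScheduler<MyTask>> scheduler;
//   TF_RETURN_IF_ERROR(
//       AdaptiveSharedBatchScheduler<MyTask>::Create(options, &scheduler));
//
//   AdaptiveSharedBatchScheduler<MyTask>::QueueOptions queue_options;
//   std::unique_ptr<BatchScheduler<MyTask>> queue;
//   TF_RETURN_IF_ERROR(scheduler->AddQueue(
//       queue_options,
//       [](std::unique_ptr<Batch<MyTask>> batch) {
//         ProcessBatch(std::move(batch));
//       },
//       &queue));
//
//   std::unique_ptr<MyTask> task = ...;
//   TF_RETURN_IF_ERROR(queue->Schedule(&task));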
template <typename TaskType>
class AdaptiveSharedBatchScheduler
    : public std::enable_shared_from_this<
          AdaptiveSharedBatchScheduler<TaskType>> {
 public:
  ~AdaptiveSharedBatchScheduler() {
    // Finish processing batches before destroying other class members.
    if (owned_batch_thread_pool_) {
      delete batch_thread_pool_;
    }
  }

  struct Options {
    // The name to use for the pool of batch threads.
    string thread_pool_name = {"batch_threads"};
    // Number of batch processing threads - the maximum value of
    // in_flight_batches_limit_.  It is recommended that this value be set by
    // running the system under load, observing the learned value for
    // in_flight_batches_limit_, and setting this maximum to ~2x that value.
    // Under low load, in_flight_batches_limit_ has no substantial effect on
    // latency and therefore undergoes a random walk.  Unreasonably large
    // values for num_batch_threads allow for a large in_flight_batches_limit_,
    // which will harm latency for some time once load increases again.
    int64 num_batch_threads = port::MaxParallelism();
    // You can pass a ThreadPool directly rather than the above two
    // parameters.  If given, the above two parameters are ignored.  Ownership
    // of the thread pool is not transferred.
    thread::ThreadPool* thread_pool = nullptr;

    // Lower bound for in_flight_batches_limit_. As discussed above, can be
    // used to minimize the damage caused by the random walk under low load.
    int64 min_in_flight_batches_limit = 1;
    // Although batch selection is primarily based on age, this parameter
    // specifies a preference for larger batches.  A full batch will be
    // scheduled before an older, nearly empty batch as long as the age gap is
    // less than full_batch_scheduling_boost_micros.  The optimal value for
    // this parameter should be of the order of the batch processing latency,
    // but must be chosen carefully, as too large a value will harm tail
    // latency.
    int64 full_batch_scheduling_boost_micros = 0;
    // The environment to use (typically only overridden by test code).
    Env* env = Env::Default();
    // Initial limit for number of batches being concurrently processed.
    // Non-integer values correspond to probabilistic limits - i.e. a value of
    // 3.2 results in an actual cap of 3 80% of the time, and 4 20% of the
    // time.
    double initial_in_flight_batches_limit = 3;
    // Number of batches between adjustments of in_flight_batches_limit.
    // Larger numbers will give less noisy latency measurements, but will be
    // less responsive to changes in workload.
    int64 batches_to_average_over = 1000;

    // If true, schedule batches using a FIFO policy.
    // Requires that `full_batch_scheduling_boost_micros` is zero.
    // NOTE:
    // A new parameter is introduced (rather than re-using
    // full_batch_scheduling_boost_micros == 0) to keep the API backward
    // compatible.
    bool fifo_scheduling = false;
  };
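
  // Worked example for full_batch_scheduling_boost_micros (hypothetical
  // numbers): with a boost of 10000 and max_batch_size = 100, a full batch is
  // scored as if it were 10ms older than its creation time, while a half-full
  // batch gets only a 5ms boost; the full batch is therefore chosen unless
  // the half-full batch is more than 5ms older (see the score computed in
  // MaybeScheduleNextBatch()).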

  // Ownership is shared between the caller of Create() and any queues created
  // via AddQueue().
  static Status Create(
      const Options& options,
      std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>>* scheduler);

  struct QueueOptions {
    // Maximum size of a batch that's formed within
    // `ASBSQueue<TaskType>::Schedule`.
    int max_batch_size = 1000;
    // Maximum size of input task, which is submitted to the queue by
    // calling `ASBSQueue<TaskType>::Schedule` and used to form batches.
    //
    // If specified, it should be larger than or equal to 'max_batch_size'.
    absl::optional<int> max_input_task_size = absl::nullopt;
    // Maximum number of enqueued (i.e. non-scheduled) batches.
    int max_enqueued_batches = 10;
    // Amount of time non-full batches must wait before becoming schedulable.
    // A non-zero value can improve performance by limiting the scheduling of
    // nearly empty batches.
    int64 batch_timeout_micros = 0;
    // If not nullptr, split_input_task_func should split input_task into
    // multiple tasks, the first of which has size first_size and the remaining
    // not exceeding max_size. This function may acquire ownership of
    // input_task and should return a status indicating if the split was
    // successful. Upon success, the caller can assume that all output_tasks
    // will be scheduled. Including this option allows the scheduler to pack
    // batches better and should usually improve overall throughput.  (A
    // hedged sketch of such a function follows this struct.)
    std::function<Status(std::unique_ptr<TaskType>* input_task, int first_size,
                         int max_batch_size,
                         std::vector<std::unique_ptr<TaskType>>* output_tasks)>
        split_input_task_func;
  };
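
  // A hedged sketch of a split_input_task_func, assuming a hypothetical task
  // type MyTask with a size() method and a Slice(offset, n) helper returning
  // the given sub-task as a std::unique_ptr<MyTask> (neither helper is part
  // of this header):
  //
  //   queue_options.split_input_task_func =
  //       [](std::unique_ptr<MyTask>* input_task, int first_size,
  //          int max_batch_size,
  //          std::vector<std::unique_ptr<MyTask>>* output_tasks) -> Status {
  //     int offset = 0;
  //     int chunk = first_size;  // First chunk tops off the current batch.
  //     while (offset < (*input_task)->size()) {
  //       chunk = std::min<int>(chunk, (*input_task)->size() - offset);
  //       output_tasks->push_back((*input_task)->Slice(offset, chunk));
  //       offset += chunk;
  //       chunk = max_batch_size;  // Subsequent chunks may fill whole batches.
  //     }
  //     return Status::OK();
  //   };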

  using BatchProcessor = std::function<void(std::unique_ptr<Batch<TaskType>>)>;

  // Adds queue (and its callback) to be managed by this scheduler.
  Status AddQueue(const QueueOptions& options,
                  BatchProcessor process_batch_callback,
                  std::unique_ptr<BatchScheduler<TaskType>>* queue);

  double in_flight_batches_limit() {
    mutex_lock l(mu_);
    return in_flight_batches_limit_;
  }

 private:
  // For access to AddBatch, MaybeScheduleClosedBatches, RemoveQueue, GetEnv.
  friend class internal::ASBSQueue<TaskType>;

  explicit AdaptiveSharedBatchScheduler(const Options& options);

  // Tracks processing latency and adjusts in_flight_batches_limit to
  // minimize it.
  void CallbackWrapper(const internal::ASBSBatch<TaskType>* batch,
                       BatchProcessor callback, bool is_express);

  // Schedules batch if in_flight_batches_limit_ is not met.
  void MaybeScheduleNextBatch() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Schedules batch using FIFO policy if in_flight_batches_limit_ is not met.
  void MaybeScheduleNextBatchFIFO() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Schedules all closed batches in batches_ for which an idle thread is
  // available in batch_thread_pool_.
  // Batches scheduled this way are called express batches.
  // Express batches are not limited by in_flight_batches_limit_, and
  // their latencies will not affect in_flight_batches_limit_.
  void MaybeScheduleClosedBatches();

  void MaybeScheduleClosedBatchesLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  void MaybeScheduleClosedBatchesLockedFIFO() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Notifies scheduler of non-empty batch which is eligible for processing.
  void AddBatch(const internal::ASBSBatch<TaskType>* batch);

  // Removes queue from scheduler.
  void RemoveQueue(const internal::ASBSQueue<TaskType>* queue);

  Env* GetEnv() const { return options_.env; }

  const Options options_;

  // Collection of batches added by AddBatch when fifo_scheduling is false,
  // ordered by age. Owned by the scheduler until they are released for
  // processing.
  std::vector<const internal::ASBSBatch<TaskType>*> batches_ TF_GUARDED_BY(mu_);

  // Collection of batches added by AddBatch when fifo_scheduling is true,
  // ordered by age. Owned by the scheduler until they are released for
  // processing.
  std::deque<const internal::ASBSBatch<TaskType>*> fifo_batches_
      TF_GUARDED_BY(mu_);

  // Unowned queues and callbacks added by AddQueue.
  std::unordered_map<const internal::ASBSQueue<TaskType>*, BatchProcessor>
      queues_and_callbacks_ TF_GUARDED_BY(mu_);

  mutex mu_;

  // Responsible for running the batch processing callbacks.
  thread::ThreadPool* batch_thread_pool_;

  bool owned_batch_thread_pool_ = false;

  // Limit on number of batches which can be concurrently processed.
  // Non-integer values correspond to probabilistic limits - i.e. a value of
  // 3.2 results in an actual cap of 3 80% of the time, and 4 20% of the time.
  double in_flight_batches_limit_ TF_GUARDED_BY(mu_);

  // Number of regular batches currently being processed.
  int64 in_flight_batches_ TF_GUARDED_BY(mu_) = 0;
  // Number of express batches currently being processed.
  int64 in_flight_express_batches_ TF_GUARDED_BY(mu_) = 0;

  // RNG engine and distribution.
  std::default_random_engine rand_engine_;
  std::uniform_real_distribution<double> rand_double_;

  // Fields controlling the dynamic adjustment of in_flight_batches_limit_.
  // Number of batches since the last in_flight_batches_limit_ adjustment.
  int64 batch_count_ TF_GUARDED_BY(mu_) = 0;
  // Sum of processing latency for batches counted by batch_count_.
  int64 batch_latency_sum_ TF_GUARDED_BY(mu_) = 0;
  // Average batch latency for previous value of in_flight_batches_limit_.
  double last_avg_latency_ms_ TF_GUARDED_BY(mu_) = 0;
  // Did last_avg_latency_ms_ decrease from the previous last_avg_latency_ms_?
  bool last_latency_decreased_ TF_GUARDED_BY(mu_) = false;
  // Current direction (+-) to adjust in_flight_batches_limit_.
  int step_direction_ TF_GUARDED_BY(mu_) = 1;
  // Max adjustment size (as a fraction of in_flight_batches_limit_).
  constexpr static double kMaxStepSizeMultiplier = 0.125;  // 1/8
  // Min adjustment size (as a fraction of in_flight_batches_limit_).
  constexpr static double kMinStepSizeMultiplier = 0.0078125;  // 1/128
  // Current adjustment size (as a fraction of in_flight_batches_limit_).
  double step_size_multiplier_ TF_GUARDED_BY(mu_) = kMaxStepSizeMultiplier;

  TF_DISALLOW_COPY_AND_ASSIGN(AdaptiveSharedBatchScheduler);
};

//////////////////////////////////////////////////////////
// Implementation details follow. API users need not read.

namespace internal {
// Consolidates tasks into batches, passing them off to the
// AdaptiveSharedBatchScheduler for processing.
template <typename TaskType>
class ASBSQueue : public BatchScheduler<TaskType> {
 public:
  using QueueOptions =
      typename AdaptiveSharedBatchScheduler<TaskType>::QueueOptions;

  ASBSQueue(std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler,
            const QueueOptions& options);

  ~ASBSQueue() override;

  // Adds task to current batch. Fails if the task size is larger than the
  // batch size or if the current batch is full and this queue's number of
  // outstanding batches is at its maximum.
  Status Schedule(std::unique_ptr<TaskType>* task) override;

  // Number of tasks waiting to be scheduled.
  size_t NumEnqueuedTasks() const override;

  // Number of size 1 tasks which could currently be scheduled without failing.
  size_t SchedulingCapacity() const override;

  // Notifies queue that a batch is about to be scheduled; the queue should not
  // place any more tasks in this batch.
  void ReleaseBatch(const ASBSBatch<TaskType>* batch);

  size_t max_task_size() const override { return options_.max_batch_size; }

 private:
  // Number of size 1 tasks which could currently be scheduled without failing.
  size_t SchedulingCapacityLocked() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Returns a uint64 one greater than was returned by the previous call.
  // Context id is reused after std::numeric_limits<uint64>::max is exhausted.
  static uint64 NewTraceMeContextIdForBatch();

  std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler_;
  const QueueOptions options_;
  // Owned by scheduler_.
  ASBSBatch<TaskType>* current_batch_ TF_GUARDED_BY(mu_) = nullptr;
  int64 num_enqueued_batches_ TF_GUARDED_BY(mu_) = 0;
  int64 num_enqueued_tasks_ TF_GUARDED_BY(mu_) = 0;
  mutable mutex mu_;
  TF_DISALLOW_COPY_AND_ASSIGN(ASBSQueue);
};

// Batch which remembers when and by whom it was created.
template <typename TaskType>
class ASBSBatch : public Batch<TaskType> {
 public:
  ASBSBatch(ASBSQueue<TaskType>* queue, int64_t creation_time_micros,
            int64_t batch_timeout_micros, uint64 traceme_context_id)
      : queue_(queue),
        creation_time_micros_(creation_time_micros),
        schedulable_time_micros_(creation_time_micros + batch_timeout_micros),
        traceme_context_id_(traceme_context_id) {}

  ~ASBSBatch() override {}

  ASBSQueue<TaskType>* queue() const { return queue_; }

  int64 creation_time_micros() const { return creation_time_micros_; }

  int64 schedulable_time_micros() const { return schedulable_time_micros_; }

  uint64 traceme_context_id() const { return traceme_context_id_; }

 private:
  ASBSQueue<TaskType>* queue_;
  const int64 creation_time_micros_;
  const int64 schedulable_time_micros_;
  const uint64 traceme_context_id_;
  TF_DISALLOW_COPY_AND_ASSIGN(ASBSBatch);
};
}  // namespace internal

// ---------------- AdaptiveSharedBatchScheduler ----------------

template <typename TaskType>
constexpr double AdaptiveSharedBatchScheduler<TaskType>::kMaxStepSizeMultiplier;

template <typename TaskType>
constexpr double AdaptiveSharedBatchScheduler<TaskType>::kMinStepSizeMultiplier;

template <typename TaskType>
Status AdaptiveSharedBatchScheduler<TaskType>::Create(
    const Options& options,
    std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>>* scheduler) {
  if (options.num_batch_threads < 1) {
    return errors::InvalidArgument("num_batch_threads must be positive; was ",
                                   options.num_batch_threads);
  }
  if (options.min_in_flight_batches_limit < 1) {
    return errors::InvalidArgument(
        "min_in_flight_batches_limit must be >= 1; was ",
        options.min_in_flight_batches_limit);
  }
  if (options.min_in_flight_batches_limit > options.num_batch_threads) {
    return errors::InvalidArgument(
        "min_in_flight_batches_limit (", options.min_in_flight_batches_limit,
        ") must be <= num_batch_threads (", options.num_batch_threads, ")");
  }
  if (options.full_batch_scheduling_boost_micros < 0) {
    return errors::InvalidArgument(
        "full_batch_scheduling_boost_micros can't be negative; was ",
        options.full_batch_scheduling_boost_micros);
  }
  if (options.initial_in_flight_batches_limit > options.num_batch_threads) {
    return errors::InvalidArgument(
        "initial_in_flight_batches_limit (",
        options.initial_in_flight_batches_limit,
        ") should not be larger than num_batch_threads (",
        options.num_batch_threads, ")");
  }
  if (options.initial_in_flight_batches_limit <
      options.min_in_flight_batches_limit) {
    return errors::InvalidArgument("initial_in_flight_batches_limit (",
                                   options.initial_in_flight_batches_limit,
                                   ") must be >= min_in_flight_batches_limit (",
                                   options.min_in_flight_batches_limit, ")");
  }
  if (options.batches_to_average_over < 1) {
    return errors::InvalidArgument(
        "batches_to_average_over should be "
        "greater than or equal to 1; was ",
        options.batches_to_average_over);
  }
  scheduler->reset(new AdaptiveSharedBatchScheduler<TaskType>(options));
  return Status::OK();
}

template <typename TaskType>
AdaptiveSharedBatchScheduler<TaskType>::AdaptiveSharedBatchScheduler(
    const Options& options)
    : options_(options),
      in_flight_batches_limit_(options.initial_in_flight_batches_limit),
      rand_double_(0.0, 1.0) {
  std::random_device device;
  rand_engine_.seed(device());
  if (options.thread_pool == nullptr) {
    owned_batch_thread_pool_ = true;
    batch_thread_pool_ = new thread::ThreadPool(
        GetEnv(), options.thread_pool_name, options.num_batch_threads);
  } else {
    owned_batch_thread_pool_ = false;
    batch_thread_pool_ = options.thread_pool;
  }
}

template <typename TaskType>
Status AdaptiveSharedBatchScheduler<TaskType>::AddQueue(
    const QueueOptions& options, BatchProcessor process_batch_callback,
    std::unique_ptr<BatchScheduler<TaskType>>* queue) {
  if (options.max_batch_size <= 0) {
    return errors::InvalidArgument("max_batch_size must be positive; was ",
                                   options.max_batch_size);
  }
  if (options.max_enqueued_batches <= 0) {
    return errors::InvalidArgument(
        "max_enqueued_batches must be positive; was ",
        options.max_enqueued_batches);
  }
  if (options.max_input_task_size.has_value()) {
    if (options.max_input_task_size.value() < options.max_batch_size) {
      return errors::InvalidArgument(
          "max_input_task_size must be larger than or equal to "
          "max_batch_size; got max_input_task_size as ",
          options.max_input_task_size.value(), " and max_batch_size as ",
          options.max_batch_size);
    }
  }
  internal::ASBSQueue<TaskType>* asbs_queue_raw;
  queue->reset(asbs_queue_raw = new internal::ASBSQueue<TaskType>(
                   this->shared_from_this(), options));
  mutex_lock l(mu_);
  queues_and_callbacks_[asbs_queue_raw] = process_batch_callback;
  return Status::OK();
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::AddBatch(
    const internal::ASBSBatch<TaskType>* batch) {
  mutex_lock l(mu_);
  if (options_.fifo_scheduling) {
    fifo_batches_.push_back(batch);
  } else {
    batches_.push_back(batch);
  }
  int64_t delay_micros =
      batch->schedulable_time_micros() - GetEnv()->NowMicros();
  if (delay_micros <= 0) {
    MaybeScheduleNextBatch();
    return;
  }
  // Try to schedule the batch once it becomes schedulable. Although the
  // scheduler waits for all batches to finish processing before allowing
  // itself to be deleted, MaybeScheduleNextBatch() is called in other places,
  // and therefore it's possible the scheduler could be deleted by the time
  // this closure runs. Grab a shared_ptr reference to prevent this from
  // happening.
  GetEnv()->SchedClosureAfter(
      delay_micros, [this, lifetime_preserver = this->shared_from_this()] {
        mutex_lock l(mu_);
        MaybeScheduleNextBatch();
      });
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::RemoveQueue(
    const internal::ASBSQueue<TaskType>* queue) {
  mutex_lock l(mu_);
  queues_and_callbacks_.erase(queue);
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::MaybeScheduleNextBatchFIFO() {
  const internal::ASBSBatch<TaskType>* batch = fifo_batches_.front();
  fifo_batches_.pop_front();
  // Queue may destroy itself after ReleaseBatch is called.
  batch->queue()->ReleaseBatch(batch);
  batch_thread_pool_->Schedule(std::bind(
      &AdaptiveSharedBatchScheduler<TaskType>::CallbackWrapper, this, batch,
      queues_and_callbacks_[batch->queue()], false /* is express */));
  in_flight_batches_++;
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<
    TaskType>::MaybeScheduleClosedBatchesLockedFIFO() {
  // Only schedule closed batches if we have spare capacity.
  int available_threads =
      static_cast<int>(options_.num_batch_threads - in_flight_batches_ -
                       in_flight_express_batches_);
  for (auto it = fifo_batches_.begin();
       it != fifo_batches_.end() && available_threads > 0;
       it = fifo_batches_.begin()) {
    if ((*it)->IsClosed()) {
      const internal::ASBSBatch<TaskType>* batch = *it;
      fifo_batches_.pop_front();
      batch->queue()->ReleaseBatch(batch);
      batch_thread_pool_->Schedule(
          std::bind(&AdaptiveSharedBatchScheduler<TaskType>::CallbackWrapper,
                    this, batch, queues_and_callbacks_[batch->queue()], true));
      in_flight_express_batches_++;
      available_threads--;
    } else {
      // Batches are FIFO, so stop iterating after finding the first
      // non-closed batch.
      break;
    }
  }
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::MaybeScheduleNextBatch() {
  bool batches_empty =
      options_.fifo_scheduling ? fifo_batches_.empty() : batches_.empty();
  if (batches_empty || in_flight_batches_ >= in_flight_batches_limit_) return;
  // Non-integer limit handled probabilistically.
  if (in_flight_batches_limit_ - in_flight_batches_ < 1 &&
      rand_double_(rand_engine_) >
          in_flight_batches_limit_ - in_flight_batches_) {
    return;
  }

  if (options_.fifo_scheduling) {
    MaybeScheduleNextBatchFIFO();
    return;
  }

  auto best_it = batches_.end();
  double best_score = (std::numeric_limits<double>::max)();
  int64_t now_micros = GetEnv()->NowMicros();
  for (auto it = batches_.begin(); it != batches_.end(); it++) {
    if ((*it)->schedulable_time_micros() > now_micros) continue;
    const double score =
        (*it)->creation_time_micros() -
        options_.full_batch_scheduling_boost_micros * (*it)->size() /
            static_cast<double>((*it)->queue()->max_task_size());
    if (best_it == batches_.end() || score < best_score) {
      best_score = score;
      best_it = it;
    }
  }
  // No schedulable batches.
  if (best_it == batches_.end()) return;
  const internal::ASBSBatch<TaskType>* batch = *best_it;
  batches_.erase(best_it);
  // Queue may destroy itself after ReleaseBatch is called.
  batch->queue()->ReleaseBatch(batch);
  batch_thread_pool_->Schedule(
      std::bind(&AdaptiveSharedBatchScheduler<TaskType>::CallbackWrapper, this,
                batch, queues_and_callbacks_[batch->queue()], false));
  in_flight_batches_++;
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::MaybeScheduleClosedBatches() {
  mutex_lock l(mu_);
  MaybeScheduleClosedBatchesLocked();
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<
    TaskType>::MaybeScheduleClosedBatchesLocked() {
  if (options_.fifo_scheduling) {
    MaybeScheduleClosedBatchesLockedFIFO();
    return;
  }
  // Only schedule closed batches if we have spare capacity.
  int available_threads =
      static_cast<int>(options_.num_batch_threads - in_flight_batches_ -
                       in_flight_express_batches_);
  for (auto it = batches_.begin();
       it != batches_.end() && available_threads > 0;) {
    if ((*it)->IsClosed()) {
      const internal::ASBSBatch<TaskType>* batch = *it;
      it = batches_.erase(it);
      batch->queue()->ReleaseBatch(batch);
      batch_thread_pool_->Schedule(
          std::bind(&AdaptiveSharedBatchScheduler<TaskType>::CallbackWrapper,
                    this, batch, queues_and_callbacks_[batch->queue()], true));
      in_flight_express_batches_++;
      available_threads--;
    } else {
      ++it;
    }
  }
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::CallbackWrapper(
    const internal::ASBSBatch<TaskType>* batch,
    AdaptiveSharedBatchScheduler<TaskType>::BatchProcessor callback,
    bool is_express) {
  profiler::TraceMeConsumer trace_me(
      [&] {
        return profiler::TraceMeEncode(
            "ProcessBatch", {{"batch_size_before_padding", batch->size()},
                             {"_r", 2} /*root_event*/});
      },
      profiler::ContextType::kAdaptiveSharedBatchScheduler,
      batch->traceme_context_id());
  int64_t start_time = batch->creation_time_micros();
  callback(std::unique_ptr<Batch<TaskType>>(
      const_cast<internal::ASBSBatch<TaskType>*>(batch)));
  int64_t end_time = GetEnv()->NowMicros();
  mutex_lock l(mu_);
  if (is_express) {
    in_flight_express_batches_--;
    MaybeScheduleClosedBatchesLocked();
    return;
  }
  in_flight_batches_--;
  batch_count_++;
  batch_latency_sum_ += end_time - start_time;
  // Occasionally adjust in_flight_batches_limit_ to minimize average latency.
  // Although the optimal value may depend on the workload, the latency should
  // be a simple convex function of in_flight_batches_limit_, allowing us to
  // locate the global minimum relatively quickly.
  if (batch_count_ == options_.batches_to_average_over) {
    double current_avg_latency_ms = (batch_latency_sum_ / 1000.) / batch_count_;
    bool current_latency_decreased =
        current_avg_latency_ms < last_avg_latency_ms_;
    if (current_latency_decreased) {
      // If the latency improvement came from moving in the current direction,
      // increase step_size so that we can get to the minimum faster.  If it
      // came from backtracking after a previous failure, decrease step_size
      // in order to refine our location.
      step_size_multiplier_ *= (last_latency_decreased_ ? 2 : 0.5);
      step_size_multiplier_ =
          std::min(step_size_multiplier_, kMaxStepSizeMultiplier);
      step_size_multiplier_ =
          std::max(step_size_multiplier_, kMinStepSizeMultiplier);
    } else {
      // Return (nearly) to the previous position and confirm that latency is
      // better there before decreasing the step size.
      step_direction_ = -step_direction_;
    }
    in_flight_batches_limit_ +=
        step_direction_ * in_flight_batches_limit_ * step_size_multiplier_;
    in_flight_batches_limit_ =
        std::min(in_flight_batches_limit_,
                 static_cast<double>(options_.num_batch_threads));
    in_flight_batches_limit_ =
        std::max(in_flight_batches_limit_,
                 static_cast<double>(options_.min_in_flight_batches_limit));
    last_avg_latency_ms_ = current_avg_latency_ms;
    last_latency_decreased_ = current_latency_decreased;
    batch_count_ = 0;
    batch_latency_sum_ = 0;
  }
  MaybeScheduleNextBatch();
}
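
// In symbols, the adjustment performed above is (a hedged restatement of the
// code, not additional behavior):
//   limit <- clamp(limit + direction * step * limit,
//                  min_in_flight_batches_limit, num_batch_threads)
// where step is doubled after two consecutive latency improvements, halved
// when an improvement follows a backtrack, and direction is reversed whenever
// latency worsens.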

// ---------------- ASBSQueue ----------------

namespace internal {
template <typename TaskType>
ASBSQueue<TaskType>::ASBSQueue(
    std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler,
    const QueueOptions& options)
    : scheduler_(scheduler), options_(options) {}

template <typename TaskType>
ASBSQueue<TaskType>::~ASBSQueue() {
  // Wait until the last batch has been scheduled.
  const int kSleepMicros = 1000;
  for (;;) {
    {
      mutex_lock l(mu_);
      if (num_enqueued_batches_ == 0) {
        break;
      }
    }
    scheduler_->GetEnv()->SleepForMicroseconds(kSleepMicros);
  }
  scheduler_->RemoveQueue(this);
}

template <typename TaskType>
Status ASBSQueue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
  size_t size = (*task)->size();
  if (options_.split_input_task_func == nullptr &&
      size > options_.max_batch_size) {
    return errors::InvalidArgument("Task size ", size,
                                   " is larger than maximum batch size ",
                                   options_.max_batch_size);
  }
  if (options_.max_input_task_size.has_value() &&
      (size > options_.max_input_task_size.value())) {
    return errors::InvalidArgument("Task size ", size,
                                   " is larger than max input task size ",
                                   options_.max_input_task_size.value());
  }

  std::vector<std::unique_ptr<TaskType>> tasks_to_schedule;
  std::vector<ASBSBatch<TaskType>*> new_batches;
  bool closed_batch = false;
  {
    mutex_lock l(mu_);
    if (size > SchedulingCapacityLocked()) {
      return errors::Unavailable("The batch scheduling queue is full");
    }

    int remaining_batch_size =
        current_batch_ == nullptr
            ? options_.max_batch_size
            : options_.max_batch_size - current_batch_->size();
    if (options_.split_input_task_func == nullptr ||
        size <= remaining_batch_size) {
      // Either we don't allow task splitting or the task fits within the
      // current batch.
      tasks_to_schedule.push_back(std::move(*task));
    } else {
      // Split the task in order to completely fill the current batch.
      // Beyond this point Schedule should not fail, as the caller has been
      // promised that all of the split tasks will be scheduled.
      TF_RETURN_IF_ERROR(options_.split_input_task_func(
          task, remaining_batch_size, options_.max_batch_size,
          &tasks_to_schedule));
    }
    for (auto& task : tasks_to_schedule) {
      // If the task can't fit within the current batch, close it off and try
      // to create another.
      if (current_batch_ &&
          current_batch_->size() + task->size() > options_.max_batch_size) {
        current_batch_->Close();
        closed_batch = true;
        current_batch_ = nullptr;
      }
      if (!current_batch_) {
        num_enqueued_batches_++;
        // batch.traceme_context_id connects TraceMeProducer and
        // TraceMeConsumer.
        // When multiple calls to "ASBS::Schedule" accumulate into one batch,
        // they are processed in the same batch and should share
        // traceme_context_id.
        current_batch_ = new ASBSBatch<TaskType>(
            this, scheduler_->GetEnv()->NowMicros(),
            options_.batch_timeout_micros, NewTraceMeContextIdForBatch());
        new_batches.push_back(current_batch_);
      }

      // Annotate each task (corresponding to one call of Schedule) with a
      // TraceMeProducer.
      profiler::TraceMeProducer trace_me(
          [task_size = task->size()] {
            return profiler::TraceMeEncode(
                "ASBSQueue::Schedule",
                {{"batching_input_task_size", task_size}});
          },
          profiler::ContextType::kAdaptiveSharedBatchScheduler,
          this->current_batch_->traceme_context_id());
      current_batch_->AddTask(std::move(task));
      num_enqueued_tasks_++;
      // If current_batch_ is now full, allow it to be processed immediately.
      if (current_batch_->size() == options_.max_batch_size) {
        current_batch_->Close();
        closed_batch = true;
        current_batch_ = nullptr;
      }
    }
  }
  // Scheduler functions must be called outside of the lock, since they may
  // call ReleaseBatch.
  for (auto* batch : new_batches) {
    scheduler_->AddBatch(batch);
  }
  if (closed_batch) {
    scheduler_->MaybeScheduleClosedBatches();
  }
  return Status::OK();
}

template <typename TaskType>
void ASBSQueue<TaskType>::ReleaseBatch(const ASBSBatch<TaskType>* batch) {
  mutex_lock l(mu_);
  num_enqueued_batches_--;
  num_enqueued_tasks_ -= batch->num_tasks();
  if (batch == current_batch_) {
    current_batch_->Close();
    current_batch_ = nullptr;
  }
}

template <typename TaskType>
size_t ASBSQueue<TaskType>::NumEnqueuedTasks() const {
  mutex_lock l(mu_);
  return num_enqueued_tasks_;
}

template <typename TaskType>
size_t ASBSQueue<TaskType>::SchedulingCapacity() const {
  mutex_lock l(mu_);
  return SchedulingCapacityLocked();
}

template <typename TaskType>
size_t ASBSQueue<TaskType>::SchedulingCapacityLocked() const {
  const int current_batch_capacity =
      current_batch_ ? options_.max_batch_size - current_batch_->size() : 0;
  const int spare_batches =
      options_.max_enqueued_batches - num_enqueued_batches_;
  return spare_batches * options_.max_batch_size + current_batch_capacity;
}
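
// Worked example for SchedulingCapacityLocked (hypothetical numbers): with
// max_batch_size = 100 and max_enqueued_batches = 4, if two batches are
// enqueued and one of them is the open current batch holding 30 tasks, the
// capacity is (4 - 2) * 100 + (100 - 30) = 270 size-1 tasks.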

template <typename TaskType>
// static
uint64 ASBSQueue<TaskType>::NewTraceMeContextIdForBatch() {
  static std::atomic<uint64> traceme_context_id(0);
  return traceme_context_id.fetch_add(1, std::memory_order_relaxed);
}
}  // namespace internal
}  // namespace serving
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_