/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SHARED_BATCH_SCHEDULER_H_
#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SHARED_BATCH_SCHEDULER_H_

#include <stddef.h>

#include <deque>
#include <functional>
#include <list>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/time/clock.h"
#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/periodic_function.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"

namespace tensorflow {
namespace serving {
namespace internal {
template <typename TaskType>
class Queue;
}  // namespace internal
}  // namespace serving
}  // namespace tensorflow

namespace tensorflow {
namespace serving {

// A batch scheduler for server instances that service multiple request types
// (e.g. multiple machine-learned models, or multiple versions of a model served
// concurrently), or even multiple distinct tasks for a given request. The
// scheduler multiplexes batches of different kinds of tasks onto a fixed-size
// thread pool (each batch contains tasks of a single type), in a carefully
// controlled manner. A common configuration is to set the number of threads
// equal to the number of hardware accelerator units, in which case the
// scheduler takes care of multiplexing the task types onto the shared hardware,
// in a manner that is both fair and efficient.
//
// Semantically, SharedBatchScheduler behaves like having N instances of
// BasicBatchScheduler (see basic_batch_scheduler.h), one per task type. The
// difference is that under the covers there is a single shared thread pool,
// instead of N independent ones, with their sharing deliberately coordinated.
//
// SharedBatchScheduler does not implement the BatchScheduler API; rather, it
// presents an abstraction of "queues", where each queue corresponds to one type
// of task. Tasks submitted to a given queue are placed in their own batches,
// and cannot be mixed with other tasks. Queues can be added and deleted
// dynamically, to accommodate e.g. versions of a model being brought up and
// down over the lifetime of a server.
//
// The batch thread pool round-robins through the queues, running one batch
// from a queue and then moving to the next queue. Each queue behaves like a
// BasicBatchScheduler instance, in the sense that it has maximum batch size and
// timeout parameters, which govern when a batch is eligible to be processed.
//
// Each queue is independently configured with a maximum size (in terms of the
// maximum number of batches worth of enqueued tasks). For online serving, it is
// recommended that the queue sizes be configured such that the sum of the sizes
// of the active queues roughly equals the number of batch threads. (The idea is
// that if all threads become available at roughly the same time, there will be
// enough enqueued work for them to take on, but no more.)
//
// If queue sizes are configured in the manner suggested above, the maximum time
// a task can spend in a queue before being placed in a batch and assigned to a
// thread for processing is the greater of:
//  - the maximum time to process one batch of tasks from any active queue
//  - the configured timeout parameter for the task's queue (which can be 0)
//
// For bulk processing jobs and throughput-oriented benchmarks, you may want to
// set the maximum queue size to a large value.
//
// TODO(b/26539183): Support queue servicing policies other than round-robin.
// E.g. let each queue specify a "share" (an int >= 1), so e.g. with queues A
// and B having shares 1 and 2 respectively, the servicing pattern is ABBABB...
//
//
// PERFORMANCE TUNING: See README.md.
//
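// Example use, as a minimal sketch (the task type `MyTask` and the batch
// callback `ProcessMyBatch` are hypothetical placeholders, not part of this
// header; error handling is elided):
//
//   SharedBatchScheduler<MyTask>::Options options;
//   options.num_batch_threads = 4;
//   std::shared_ptr<SharedBatchScheduler<MyTask>> scheduler;
//   TF_CHECK_OK(SharedBatchScheduler<MyTask>::Create(options, &scheduler));
//
//   SharedBatchScheduler<MyTask>::QueueOptions queue_options;
//   queue_options.input_batch_size_limit = 100;
//   queue_options.batch_timeout_micros = 1000;
//   std::unique_ptr<BatchScheduler<MyTask>> queue;
//   TF_CHECK_OK(scheduler->AddQueue(queue_options, &ProcessMyBatch, &queue));
//
//   std::unique_ptr<MyTask> task = ...;  // A task of size <= 100.
//   TF_CHECK_OK(queue->Schedule(&task));
//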
template <typename TaskType>
class SharedBatchScheduler
    : public std::enable_shared_from_this<SharedBatchScheduler<TaskType>> {
 public:
  // TODO(b/25089730): Tune defaults based on best practices as they develop.
  struct Options {
    // The name to use for the pool of batch threads.
    string thread_pool_name = {"batch_threads"};

    // The number of threads to use to process batches.
    // Must be >= 1, and should be tuned carefully.
    int num_batch_threads = port::MaxParallelism();

    // The environment to use.
    // (Typically only overridden by test code.)
    Env* env = Env::Default();
  };
  // Ownership is shared between the caller of Create() and any queues created
  // via AddQueue().
  static Status Create(
      const Options& options,
      std::shared_ptr<SharedBatchScheduler<TaskType>>* scheduler);

  ~SharedBatchScheduler();

  // Adds a queue to which tasks may be submitted. The returned queue implements
  // the BatchScheduler API. Each queue has its own set of scheduling options,
  // and its own callback to process batches of tasks submitted to the queue.
  //
  // The returned queue's destructor blocks until all tasks submitted to it have
  // been processed.
  struct QueueOptions {
    // The size limit of an input batch to the queue.
    //
    // If `enable_large_batch_splitting` is true, `input_batch_size_limit`
    // should be greater than or equal to `max_execution_batch_size`; otherwise
    // `input_batch_size_limit` should be equal to `max_execution_batch_size`.
    size_t input_batch_size_limit = 1000;

    // If a task has been enqueued for this amount of time (in microseconds),
    // and a thread is available, the scheduler will immediately form a batch
    // from enqueued tasks and assign the batch to the thread for processing,
    // even if the batch's size is below 'input_batch_size_limit'.
    //
    // This parameter offers a way to bound queue latency, so that a task isn't
    // stuck in the queue indefinitely waiting for enough tasks to arrive to
    // make a full batch. (The latency bound is given in the class documentation
    // above.)
    //
    // The goal is to smooth out batch sizes under low request rates, and thus
    // avoid latency spikes.
    int64 batch_timeout_micros = 0;

    // The maximum allowable number of enqueued tasks, expressed in terms of
    // batches. (A task counts as enqueued once it has been accepted by
    // Schedule() but is not yet being processed on a batch thread.) If this
    // limit is reached, Schedule() will return an UNAVAILABLE error. See the
    // class documentation above for guidelines on how to tune this parameter.
    size_t max_enqueued_batches = 10;

    // If true, the queue implementation splits one input batch task into
    // subtasks (as specified by `split_input_task_func` below) and fits the
    // subtasks into different batches.
    //
    // For usage of `split_input_task_func`, please see its comment.
    bool enable_large_batch_splitting = false;

    // `input_task`: a unit of task to be split.
    // `first_output_task_size`: task size of the first output.
    // `max_execution_batch_size`: maximum size of each batch.
    // `output_tasks`: a list of output tasks after the split.
    //
    // REQUIRED:
    // 1) All `output_tasks` must be non-empty tasks.
    // 2) The sizes of `output_tasks` must add up to the size of `input_task`.
    //
    // NOTE:
    // Instantiations of `TaskType` may vary, so it is up to the caller to
    // define how (e.g., which members to access) to split input tasks. An
    // example configuration sketch is given after this struct.
    std::function<Status(std::unique_ptr<TaskType>* input_task,
                         int first_output_task_size, int input_batch_size_limit,
                         std::vector<std::unique_ptr<TaskType>>* output_tasks)>
        split_input_task_func;

    // The maximum size of each enqueued batch (i.e., in `batches_`).
    //
    // The scheduler may form batches of any size between 1 and this number
    // (inclusive). If there is a need to quantize the batch sizes, i.e. only
    // submit batches whose size is in a small set of allowed sizes, that can be
    // done by adding padding in the process-batch callback.
    size_t max_execution_batch_size = 1000;
  };
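  //
  // Example of configuring a queue with large-batch splitting, as a minimal
  // sketch only (it assumes a hypothetical task type `MyTask` that exposes
  // size() and a hypothetical helper `MyTask::Slice(start, limit)` for carving
  // out a chunk; neither is defined in this header):
  //
  //   SharedBatchScheduler<MyTask>::QueueOptions opts;
  //   opts.input_batch_size_limit = 4000;
  //   opts.max_execution_batch_size = 1000;
  //   opts.enable_large_batch_splitting = true;
  //   opts.split_input_task_func =
  //       [](std::unique_ptr<MyTask>* input_task, int first_output_task_size,
  //          int max_batch_size, std::vector<std::unique_ptr<MyTask>>* output) {
  //         // The first chunk fills the remaining room in the open batch;
  //         // later chunks are at most max_batch_size. Chunk sizes must sum
  //         // to the input task's size, and every chunk must be non-empty.
  //         const int total = (*input_task)->size();
  //         int chunk =
  //             first_output_task_size > 0 ? first_output_task_size
  //                                        : max_batch_size;
  //         int start = 0;
  //         while (start < total) {
  //           chunk = std::min(chunk, total - start);
  //           output->push_back((*input_task)->Slice(start, start + chunk));
  //           start += chunk;
  //           chunk = max_batch_size;
  //         }
  //         return Status::OK();
  //       };
  //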
  Status AddQueue(const QueueOptions& options,
                  std::function<void(std::unique_ptr<Batch<TaskType>>)>
                      process_batch_callback,
                  std::unique_ptr<BatchScheduler<TaskType>>* queue);

 private:
  explicit SharedBatchScheduler(const Options& options);

  // The code executed in 'batch_threads_'. Obtains a batch to process from the
  // queue pointed to by 'next_queue_to_schedule_', and processes it. If that
  // queue declines to provide a batch to process, moves on to the next queue.
  // If no queue provides a batch to process, just sleeps briefly and exits.
  void ThreadLogic();

  const Options options_;

  mutex mu_;

  // A list of queues. (We use std::list instead of std::vector to ensure that
  // iterators are not invalidated by adding/removing elements. It also offers
  // efficient removal of elements from the middle.)
  using QueueList = std::list<std::unique_ptr<internal::Queue<TaskType>>>;

  // All "active" queues, i.e. ones that either:
  //  - have not been removed, or
  //  - have been removed but are not yet empty.
  QueueList queues_ TF_GUARDED_BY(mu_);

  // An iterator over 'queues_', pointing to the queue from which the next
  // available batch thread should grab work.
  typename QueueList::iterator next_queue_to_schedule_ TF_GUARDED_BY(mu_);

  // Used by idle batch threads to wait for work to enter the system. Notified
  // whenever a batch becomes schedulable.
  condition_variable schedulable_batch_cv_;

  // Threads that process batches obtained from the queues.
  std::vector<std::unique_ptr<PeriodicFunction>> batch_threads_;

  TF_DISALLOW_COPY_AND_ASSIGN(SharedBatchScheduler);
};

//////////
// Implementation details follow. API users need not read.

namespace internal {

// A task queue for SharedBatchScheduler. Accepts tasks and accumulates them
// into batches, and dispenses those batches to be processed via a "pull"
// interface. The queue's behavior is governed by maximum batch size, timeout
// and maximum queue length parameters; see their documentation in
// SharedBatchScheduler.
//
// The queue is implemented as a deque of batches, with these invariants:
//  - The number of batches is between 1 and 'options_.max_enqueued_batches'.
//  - The back-most batch is open; the rest are closed.
//
// Submitted tasks are added to the open batch. If that batch doesn't have room
// but the queue isn't full, then that batch is closed and a new open batch is
// started.
//
// Batch pull requests are handled by dequeuing the front-most batch if it is
// closed. If the front-most batch is open (i.e. the queue contains only one
// batch) and has reached the timeout, it is immediately closed and returned;
// otherwise no batch is returned for the request.
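//
// For illustration (not a normative example): with 'max_enqueued_batches' = 3
// the deque may hold up to [closed, closed, open]; Schedule() appends tasks to
// the back-most open batch, and ScheduleBatch() pops the front-most batch once
// it is closed (or closes and pops the open batch once its timeout expires).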
template <typename TaskType>
class Queue {
 public:
  using ProcessBatchCallback =
      std::function<void(std::unique_ptr<Batch<TaskType>>)>;
  using SchedulableBatchCallback = std::function<void()>;
  using SplitInputTaskIntoSubtasksCallback = std::function<Status(
      std::unique_ptr<TaskType>* input_task, int open_batch_remaining_slot,
      int max_execution_batch_size,
      std::vector<std::unique_ptr<TaskType>>* output_tasks)>;
  Queue(const typename SharedBatchScheduler<TaskType>::QueueOptions& options,
        Env* env, ProcessBatchCallback process_batch_callback,
        SchedulableBatchCallback schedulable_batch_callback);

  // Illegal to destruct unless the queue is empty.
  ~Queue();

  // Submits a task to the queue, with the same semantics as
  // BatchScheduler::Schedule().
  Status Schedule(std::unique_ptr<TaskType>* task);

  // Submits a task without splitting it across batches; used when
  // 'enable_large_batch_splitting' is false.
  Status ScheduleWithoutSplit(std::unique_ptr<TaskType>* task);

  // Submits a task, splitting it into subtasks across batches when it does not
  // fit in the open batch; used when 'enable_large_batch_splitting' is true.
  Status ScheduleWithSplit(std::unique_ptr<TaskType>* task);

  // Returns the number of enqueued tasks, with the same semantics as
  // BatchScheduler::NumEnqueuedTasks().
  size_t NumEnqueuedTasks() const;

  // Returns the queue capacity, with the same semantics as
  // BatchScheduler::SchedulingCapacity().
  size_t SchedulingCapacity() const;

  // Returns the maximum allowed size of tasks submitted to the queue.
  size_t max_task_size() const { return options_.input_batch_size_limit; }

  // Returns the maximum allowed size of tasks to be enqueued. The returned
  // value is less than or equal to the maximum allowed input size provided by
  // the caller of the batch scheduler.
  size_t max_execution_batch_size() const {
    if (options_.enable_large_batch_splitting) {
      return options_.max_execution_batch_size;
    } else {
      return options_.input_batch_size_limit;
    }
  }

  // Called by a thread that is ready to process a batch, to request one from
  // this queue. Either returns a batch that is ready to be processed, or
  // nullptr if the queue declines to schedule a batch at this time. If it
  // returns a batch, the batch is guaranteed to be closed.
  std::unique_ptr<Batch<TaskType>> ScheduleBatch();

  // Processes a batch that has been returned earlier by ScheduleBatch().
  void ProcessBatch(std::unique_ptr<Batch<TaskType>> batch);

  // Determines whether the queue is empty, i.e. has no tasks waiting or being
  // processed.
  bool IsEmpty() const;

  // Marks the queue closed, and waits until it is empty.
  void CloseAndWaitUntilEmpty();

  bool closed() const TF_NO_THREAD_SAFETY_ANALYSIS { return closed_.load(); }

 private:
  // Same as IsEmpty(), but assumes the caller already holds a lock on 'mu_'.
  bool IsEmptyInternal() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Closes the open batch residing at the back of 'batches_', and inserts a
  // fresh open batch behind it.
  void StartNewBatch() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Splits `input_task` into `output_tasks`, according to the queue's
  // `split_input_task_func` and the remaining room in the open batch.
  Status SplitInputBatchIntoSubtasks(
      std::unique_ptr<TaskType>* input_task,
      std::vector<std::unique_ptr<TaskType>>* output_tasks)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Determines whether the open batch residing at the back of 'batches_' is
  // currently schedulable.
  bool IsOpenBatchSchedulable() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  const typename SharedBatchScheduler<TaskType>::QueueOptions options_;

  // The environment to use.
  Env* env_;

  // A callback invoked to process a batch of work units. Always invoked from a
  // batch thread.
  ProcessBatchCallback process_batch_callback_;

  // A callback invoked to notify the scheduler that a new batch has become
  // schedulable.
  SchedulableBatchCallback schedulable_batch_callback_;

  mutable mutex mu_;

  // Whether this queue is closed to new tasks. This variable is monotonic: it
  // starts as false, and then at some point gets set to true and remains true
  // for the duration of this object's life.
  std::atomic<bool> closed_ TF_GUARDED_BY(mu_){false};

  // The enqueued batches. See the invariants in the class comments above.
  std::deque<std::unique_ptr<Batch<TaskType>>> batches_ TF_GUARDED_BY(mu_);

  // The counter of the TraceMe context ids.
  uint64 traceme_context_id_counter_ TF_GUARDED_BY(mu_) = 0;

  // The time at which the first task was added to the open (back-most) batch
  // in 'batches_'. Valid iff that batch contains at least one task.
  uint64 open_batch_start_time_micros_ TF_GUARDED_BY(mu_);

  // Whether this queue contains a batch that is eligible to be scheduled.
  // Used to keep track of when to call 'schedulable_batch_callback_'.
  bool schedulable_batch_ TF_GUARDED_BY(mu_) = false;

  // The number of batches currently being processed by batch threads.
  // Incremented in ScheduleBatch() and decremented in ProcessBatch().
  int num_batches_being_processed_ TF_GUARDED_BY(mu_) = 0;

  // Used by CloseAndWaitUntilEmpty() to wait until the queue is empty, for
  // the case in which the queue is not empty when CloseAndWaitUntilEmpty()
  // starts. When ProcessBatch() dequeues the last batch and makes the queue
  // empty, if 'empty_notification_' is non-null it calls
  // 'empty_notification_->Notify()'.
  Notification* empty_notification_ TF_GUARDED_BY(mu_) = nullptr;

  TF_DISALLOW_COPY_AND_ASSIGN(Queue);
};

// A RAII-style object that points to a Queue and implements
// the BatchScheduler API. To be handed out to clients who call AddQueue().
template <typename TaskType>
class QueueHandle : public BatchScheduler<TaskType> {
 public:
  QueueHandle(std::shared_ptr<SharedBatchScheduler<TaskType>> scheduler,
              Queue<TaskType>* queue);
  ~QueueHandle() override;

  Status Schedule(std::unique_ptr<TaskType>* task) override;
  size_t NumEnqueuedTasks() const override;
  size_t SchedulingCapacity() const override;

  size_t max_task_size() const override { return queue_->max_task_size(); }

 private:
  // The scheduler that owns 'queue_'.
  std::shared_ptr<SharedBatchScheduler<TaskType>> scheduler_;

  // The queue this handle wraps. Owned by 'scheduler_', which keeps it alive at
  // least until this class's destructor closes it.
  Queue<TaskType>* queue_;

  TF_DISALLOW_COPY_AND_ASSIGN(QueueHandle);
};

}  // namespace internal

template <typename TaskType>
Status SharedBatchScheduler<TaskType>::Create(
    const Options& options,
    std::shared_ptr<SharedBatchScheduler<TaskType>>* scheduler) {
  if (options.num_batch_threads < 1) {
    return errors::InvalidArgument("num_batch_threads must be positive; was ",
                                   options.num_batch_threads);
  }
  scheduler->reset(new SharedBatchScheduler<TaskType>(options));
  return Status::OK();
}

template <typename TaskType>
SharedBatchScheduler<TaskType>::~SharedBatchScheduler() {
  // Wait until the batch threads finish clearing out and deleting the closed
  // queues.
  for (;;) {
    {
      mutex_lock l(mu_);
      if (queues_.empty()) {
        break;
      }
    }
    const int64 kSleepTimeMicros = 100;
    options_.env->SleepForMicroseconds(kSleepTimeMicros);
  }
  // Delete the batch threads before allowing state the threads may access
  // (e.g. 'mu_') to be deleted.
  batch_threads_.clear();
}

template <typename TaskType>
Status SharedBatchScheduler<TaskType>::AddQueue(
    const QueueOptions& options,
    std::function<void(std::unique_ptr<Batch<TaskType>>)>
        process_batch_callback,
    std::unique_ptr<BatchScheduler<TaskType>>* queue) {
  if (options.input_batch_size_limit == 0) {
    return errors::InvalidArgument(
        "input_batch_size_limit must be positive; was ",
        options.input_batch_size_limit);
  }
  if (options.batch_timeout_micros < 0) {
    return errors::InvalidArgument(
        "batch_timeout_micros must be non-negative; was ",
        options.batch_timeout_micros);
  }
  if (options.max_enqueued_batches < 0) {
    return errors::InvalidArgument(
        "max_enqueued_batches must be non-negative; was ",
        options.max_enqueued_batches);
  }

  if (options.enable_large_batch_splitting &&
      options.split_input_task_func == nullptr) {
    return errors::InvalidArgument(
        "split_input_task_func must be specified when "
        "enable_large_batch_splitting is true: ",
        options.enable_large_batch_splitting);
  }

  if (options.enable_large_batch_splitting &&
      (options.input_batch_size_limit < options.max_execution_batch_size)) {
    return errors::InvalidArgument(
        "When enable_large_batch_splitting is true, input_batch_size_limit "
        "must be greater than or equal to max_execution_batch_size; got ",
        options.enable_large_batch_splitting, ", ",
        options.input_batch_size_limit, ", ",
        options.max_execution_batch_size);
  }

  auto schedulable_batch_callback = [this] {
    mutex_lock l(mu_);
    schedulable_batch_cv_.notify_one();
  };
  auto internal_queue =
      std::unique_ptr<internal::Queue<TaskType>>(new internal::Queue<TaskType>(
          options, options_.env, process_batch_callback,
          schedulable_batch_callback));
  auto handle = std::unique_ptr<BatchScheduler<TaskType>>(
      new internal::QueueHandle<TaskType>(this->shared_from_this(),
                                          internal_queue.get()));
  {
    mutex_lock l(mu_);
    queues_.push_back(std::move(internal_queue));
    if (next_queue_to_schedule_ == queues_.end()) {
      next_queue_to_schedule_ = queues_.begin();
    }
  }
  *queue = std::move(handle);
  return Status::OK();
}

template <typename TaskType>
SharedBatchScheduler<TaskType>::SharedBatchScheduler(const Options& options)
    : options_(options), next_queue_to_schedule_(queues_.end()) {
  // Kick off the batch threads.
  PeriodicFunction::Options periodic_fn_options;
  periodic_fn_options.thread_name_prefix =
      strings::StrCat(options.thread_pool_name, "_");
  for (int i = 0; i < options.num_batch_threads; ++i) {
    std::unique_ptr<PeriodicFunction> thread(new PeriodicFunction(
        [this] { this->ThreadLogic(); },
        0 /* function invocation interval time */, periodic_fn_options));
    batch_threads_.push_back(std::move(thread));
  }
}

template <typename TaskType>
void SharedBatchScheduler<TaskType>::ThreadLogic() {
  // A batch to process next (or nullptr if no work to do).
  std::unique_ptr<Batch<TaskType>> batch_to_process;
  // The queue with which 'batch_to_process' is associated.
  internal::Queue<TaskType>* queue_for_batch = nullptr;
  {
    mutex_lock l(mu_);

    const int num_queues = queues_.size();
    for (int num_queues_tried = 0;
         batch_to_process == nullptr && num_queues_tried < num_queues;
         ++num_queues_tried) {
      DCHECK(next_queue_to_schedule_ != queues_.end());

      // If a closed queue responds to ScheduleBatch() with nullptr, the queue
      // will never yield any further batches so we can drop it. To avoid a
      // race, we take a snapshot of the queue's closedness state *before*
      // calling ScheduleBatch().
      const bool queue_closed = (*next_queue_to_schedule_)->closed();

      // Ask '*next_queue_to_schedule_' if it wants us to process a batch.
      batch_to_process = (*next_queue_to_schedule_)->ScheduleBatch();
      if (batch_to_process != nullptr) {
        queue_for_batch = next_queue_to_schedule_->get();
      }

      // Advance 'next_queue_to_schedule_'.
      if (queue_closed && (*next_queue_to_schedule_)->IsEmpty() &&
          batch_to_process == nullptr) {
        // We've encountered a closed queue with no work to do. Drop it.
        DCHECK_NE(queue_for_batch, next_queue_to_schedule_->get());
        next_queue_to_schedule_ = queues_.erase(next_queue_to_schedule_);
      } else {
        ++next_queue_to_schedule_;
      }
      if (next_queue_to_schedule_ == queues_.end() && !queues_.empty()) {
        // We've hit the end. Wrap to the first queue.
        next_queue_to_schedule_ = queues_.begin();
      }
    }

    if (batch_to_process == nullptr) {
      // We couldn't find any work to do. Wait until a new batch becomes
      // schedulable, or some time has elapsed, before checking again.
      const int64 kTimeoutMillis = 1;  // The smallest accepted granule of time.
      WaitForMilliseconds(&l, &schedulable_batch_cv_, kTimeoutMillis);
      return;
    }
  }

  queue_for_batch->ProcessBatch(std::move(batch_to_process));
}

namespace internal {

template <typename TaskType>
Queue<TaskType>::Queue(
    const typename SharedBatchScheduler<TaskType>::QueueOptions& options,
    Env* env, ProcessBatchCallback process_batch_callback,
    SchedulableBatchCallback schedulable_batch_callback)
    : options_(options),
      env_(env),
      process_batch_callback_(process_batch_callback),
      schedulable_batch_callback_(schedulable_batch_callback) {
  // Set the higher 32 bits of traceme_context_id_counter_ to the creation
  // time of the queue. This prevents batches in different queues from having
  // the same traceme_context_id_counter_.
  traceme_context_id_counter_ = absl::GetCurrentTimeNanos() << 32;
  // Create an initial, open batch.
  batches_.emplace_back(new Batch<TaskType>);
}

template <typename TaskType>
Queue<TaskType>::~Queue() {
  mutex_lock l(mu_);
  DCHECK(IsEmptyInternal());

  // Close the (empty) open batch, so its destructor doesn't block.
  batches_.back()->Close();
}

template <typename TaskType>
Status Queue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
  if (options_.enable_large_batch_splitting) {
    return ScheduleWithSplit(std::move(task));
  }
  return ScheduleWithoutSplit(std::move(task));
}

template <typename TaskType>
Status Queue<TaskType>::ScheduleWithoutSplit(std::unique_ptr<TaskType>* task) {
  if ((*task)->size() > options_.input_batch_size_limit) {
    return errors::InvalidArgument("Task size ", (*task)->size(),
                                   " is larger than maximum input batch size ",
                                   options_.input_batch_size_limit);
  }

  bool notify_of_schedulable_batch = false;
  {
    mutex_lock l(mu_);

    DCHECK(!closed_);

    if (batches_.back()->size() + (*task)->size() >
        options_.input_batch_size_limit) {
      if (batches_.size() >= options_.max_enqueued_batches) {
        return errors::Unavailable(
            "The batch scheduling queue to which this task was submitted is "
            "full");
      }
      StartNewBatch();
    }
    if (batches_.back()->empty()) {
      open_batch_start_time_micros_ = env_->NowMicros();
    }
    profiler::TraceMeProducer trace_me(
        [task] {
          return profiler::TraceMeEncode(
              "ScheduleWithoutSplit",
              {{"batching_input_task_size", (*task)->size()}});
        },
        profiler::ContextType::kSharedBatchScheduler,
        batches_.back()->traceme_context_id());
    batches_.back()->AddTask(std::move(*task));

    if (!schedulable_batch_) {
      if (batches_.size() > 1 || IsOpenBatchSchedulable()) {
        schedulable_batch_ = true;
        notify_of_schedulable_batch = true;
      }
    }
  }

  if (notify_of_schedulable_batch) {
    schedulable_batch_callback_();
  }

  return Status::OK();
}

// TODO(b/154140947):
// Merge `ScheduleWithSplit` and `ScheduleWithoutSplit` into `Schedule`.
// Two variants were created so that the original path (ScheduleWithoutSplit)
// is kept as-is.
template <typename TaskType>
Status Queue<TaskType>::ScheduleWithSplit(std::unique_ptr<TaskType>* task) {
  profiler::TraceMe trace_me([task] {
    return profiler::TraceMeEncode(
        "ScheduleWithSplit", {{"batching_input_task_size", (*task)->size()}});
  });
  if ((*task)->size() > options_.input_batch_size_limit) {
    return errors::InvalidArgument("Task size ", (*task)->size(),
                                   " is larger than maximum input batch size ",
                                   options_.input_batch_size_limit);
  }

  // The maximum size of a batch to be enqueued.
  const int max_execution_batch_size = options_.max_execution_batch_size;

  bool notify_of_schedulable_batch = false;
  {
    mutex_lock l(mu_);

    DCHECK(!closed_);

    const int num_new_batches_schedulable =
        options_.max_enqueued_batches - batches_.size();
    const int open_batch_capacity =
        max_execution_batch_size - batches_.back()->size();
    const int scheduling_capacity =
        (num_new_batches_schedulable * max_execution_batch_size) +
        open_batch_capacity;

    // The scenario in which concurrent incoming tasks arrive and use up all of
    // the queue's capacity isn't covered by a unit test; covering it boils
    // down to specifying the "function library" in a way that one batch task
    // can synchronize with another, so that the two tasks run concurrently. An
    // integration test might be a better fit.
    if ((*task)->size() > scheduling_capacity) {
      return errors::Unavailable(
          "The batch scheduling queue to which this task was submitted is "
          "full");
    }

    const int64 open_batch_remaining_slot =
        max_execution_batch_size - batches_.back()->size();

    const int64 input_task_size = (*task)->size();

    std::vector<std::unique_ptr<TaskType>> output_tasks;

    if (input_task_size <= open_batch_remaining_slot) {
      // This is the fast path when the input doesn't need to be split.
      output_tasks.push_back(std::move(*task));
    } else {
      TF_RETURN_IF_ERROR(SplitInputBatchIntoSubtasks(task, &output_tasks));
    }

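    // Greedily pack the output tasks into the open batch, starting a new batch
    // whenever the next task would push the open batch past
    // 'max_execution_batch_size'.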
    for (int i = 0; i < output_tasks.size(); ++i) {
      if (batches_.back()->size() + output_tasks[i]->size() >
          options_.max_execution_batch_size) {
        StartNewBatch();
      }
      if (batches_.back()->empty()) {
        open_batch_start_time_micros_ = env_->NowMicros();
      }
      profiler::TraceMeProducer trace_me(
          [&output_tasks, i] {
            return profiler::TraceMeEncode("ScheduleOutputTask",
                                           {{"size", output_tasks[i]->size()}});
          },
          profiler::ContextType::kSharedBatchScheduler,
          batches_.back()->traceme_context_id());
      batches_.back()->AddTask(std::move(output_tasks[i]));
    }

    if (!schedulable_batch_) {
      if (batches_.size() > 1 || IsOpenBatchSchedulable()) {
        schedulable_batch_ = true;
        notify_of_schedulable_batch = true;
      }
    }
  }

  if (notify_of_schedulable_batch) {
    schedulable_batch_callback_();
  }

  return Status::OK();
}

template <typename TaskType>
size_t Queue<TaskType>::NumEnqueuedTasks() const {
  mutex_lock l(mu_);
  size_t num_enqueued_tasks = 0;
  for (const auto& batch : batches_) {
    num_enqueued_tasks += batch->num_tasks();
  }
  return num_enqueued_tasks;
}

template <typename TaskType>
size_t Queue<TaskType>::SchedulingCapacity() const {
  mutex_lock l(mu_);
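  // Capacity is the room left in the open batch plus the full capacity of each
  // batch slot that is not yet used. For example (illustrative numbers, not
  // defaults): with max_enqueued_batches = 10, three batches currently in
  // 'batches_', max_execution_batch_size() = 1000, and 400 units already in
  // the open batch, the capacity is (10 - 3) * 1000 + (1000 - 400) = 7600.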
  const int num_new_batches_schedulable =
      options_.max_enqueued_batches - batches_.size();
  const int open_batch_capacity =
      max_execution_batch_size() - batches_.back()->size();
  return (num_new_batches_schedulable * max_execution_batch_size()) +
         open_batch_capacity;
}

template <typename TaskType>
std::unique_ptr<Batch<TaskType>> Queue<TaskType>::ScheduleBatch() {
  // The batch to schedule, which we may populate below. (If left as nullptr,
  // that means we are electing not to schedule a batch at this time.)
  std::unique_ptr<Batch<TaskType>> batch_to_schedule;

  {
    mutex_lock l(mu_);

    // Consider closing the open batch at this time, to schedule it.
    if (batches_.size() == 1 && IsOpenBatchSchedulable()) {
      StartNewBatch();
    }

    if (batches_.size() >= 2) {
      // There is at least one closed batch that is ready to be scheduled.
      ++num_batches_being_processed_;
      batch_to_schedule = std::move(batches_.front());
      batches_.pop_front();
    } else {
      schedulable_batch_ = false;
    }
  }

  return batch_to_schedule;
}

template <typename TaskType>
void Queue<TaskType>::ProcessBatch(std::unique_ptr<Batch<TaskType>> batch) {
  profiler::TraceMeConsumer trace_me(
      [&] {
        return profiler::TraceMeEncode(
            "ProcessBatch", {{"batch_size_before_padding", batch->size()}});
      },
      profiler::ContextType::kSharedBatchScheduler,
      batch->traceme_context_id());
  process_batch_callback_(std::move(batch));

  {
    mutex_lock l(mu_);
    --num_batches_being_processed_;
    if (empty_notification_ != nullptr && IsEmptyInternal()) {
      empty_notification_->Notify();
    }
  }
}

template <typename TaskType>
bool Queue<TaskType>::IsEmpty() const {
  mutex_lock l(mu_);
  return IsEmptyInternal();
}

template <typename TaskType>
void Queue<TaskType>::CloseAndWaitUntilEmpty() {
  Notification empty;
  {
    mutex_lock l(mu_);
    closed_ = true;
    if (IsEmptyInternal()) {
      empty.Notify();
    } else {
      // Arrange for ProcessBatch() to notify when the queue becomes empty.
      empty_notification_ = &empty;
    }
  }
  empty.WaitForNotification();
}

template <typename TaskType>
bool Queue<TaskType>::IsEmptyInternal() const {
  return num_batches_being_processed_ == 0 && batches_.size() == 1 &&
         batches_.back()->empty();
}

template <typename TaskType>
void Queue<TaskType>::StartNewBatch() {
  batches_.back()->Close();
  batches_.emplace_back(new Batch<TaskType>(++traceme_context_id_counter_));
}

template <typename TaskType>
Status Queue<TaskType>::SplitInputBatchIntoSubtasks(
    std::unique_ptr<TaskType>* input_task,
    std::vector<std::unique_ptr<TaskType>>* output_tasks) {
  const int open_batch_remaining_slot =
      max_execution_batch_size() - batches_.back()->size();
  return options_.split_input_task_func(
      std::move(input_task), open_batch_remaining_slot,
      max_execution_batch_size(), std::move(output_tasks));
}

template <typename TaskType>
bool Queue<TaskType>::IsOpenBatchSchedulable() const {
  Batch<TaskType>* open_batch = batches_.back().get();
  if (open_batch->empty()) {
    return false;
  }
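  // A non-empty open batch becomes schedulable once the queue is closed, the
  // batch is full, or the batch timeout has elapsed.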
  return closed_ || open_batch->size() >= max_execution_batch_size() ||
         env_->NowMicros() >=
             open_batch_start_time_micros_ + options_.batch_timeout_micros;
}

template <typename TaskType>
QueueHandle<TaskType>::QueueHandle(
    std::shared_ptr<SharedBatchScheduler<TaskType>> scheduler,
    Queue<TaskType>* queue)
    : scheduler_(scheduler), queue_(queue) {}

template <typename TaskType>
QueueHandle<TaskType>::~QueueHandle() {
  queue_->CloseAndWaitUntilEmpty();
}

template <typename TaskType>
Status QueueHandle<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
  return queue_->Schedule(task);
}

template <typename TaskType>
size_t QueueHandle<TaskType>::NumEnqueuedTasks() const {
  return queue_->NumEnqueuedTasks();
}

template <typename TaskType>
size_t QueueHandle<TaskType>::SchedulingCapacity() const {
  return queue_->SchedulingCapacity();
}

}  // namespace internal

}  // namespace serving
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SHARED_BATCH_SCHEDULER_H_