/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_

#include <algorithm>
#include <atomic>
#include <deque>
#include <functional>
#include <memory>
#include <random>
#include <unordered_map>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/periodic_function.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/threadpool_interface.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"

namespace tensorflow {
namespace serving {
namespace internal {
template <typename TaskType>
class ASBSBatch;

template <typename TaskType>
class ASBSQueue;
}  // namespace internal
// Shared batch scheduler designed to minimize latency. The scheduler keeps
// track of a number of queues (one per model or model version) which are
// continuously enqueuing requests. The scheduler groups the requests into
// batches which it periodically sends off for processing (see
// shared_batch_scheduler.h for more details). AdaptiveSharedBatchScheduler
// (ASBS) prioritizes batches primarily by age (i.e. the batch's oldest
// request) along with a configurable preference for scheduling larger batches
// first.
//
// ASBS tries to keep the system busy by maintaining an adjustable number of
// concurrently processed batches. If a new batch is created, and the number
// of in-flight batches is below the target, the next (i.e. oldest) batch is
// immediately scheduled. Similarly, when a batch finishes processing, the
// target is rechecked, and another batch may be scheduled. To avoid the need
// to carefully tune the target for each workload, model type, platform, etc.,
// it is dynamically adjusted in order to provide the lowest average latency.
//
// Some potential use cases:
// Hardware Accelerators (GPUs & TPUs) - If some phase of batch processing
// involves serial processing by a device, from a latency perspective it is
// desirable to keep the device evenly loaded, avoiding the need to wait for
// the device to process prior batches.
// CPU utilization - If the batch processing is CPU dominated, you can reap
// latency gains when underutilized by increasing the processing rate, but
// back the rate off when the load increases to avoid overload.
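//
// A minimal usage sketch (the task type `MyTask`, the callback body, and the
// option values below are illustrative, not prescriptive; MyTask is assumed
// to subclass BatchTask, see batch_scheduler.h):
//
//   AdaptiveSharedBatchScheduler<MyTask>::Options options;
//   options.num_batch_threads = 4;
//   std::shared_ptr<AdaptiveSharedBatchScheduler<MyTask>> scheduler;
//   TF_CHECK_OK(
//       AdaptiveSharedBatchScheduler<MyTask>::Create(options, &scheduler));
//
//   AdaptiveSharedBatchScheduler<MyTask>::QueueOptions queue_options;
//   queue_options.max_batch_size = 100;
//   std::unique_ptr<BatchScheduler<MyTask>> queue;
//   TF_CHECK_OK(scheduler->AddQueue(
//       queue_options,
//       [](std::unique_ptr<Batch<MyTask>> batch) {
//         // Process the batch; it is closed and will not grow further.
//       },
//       &queue));
//
//   std::unique_ptr<MyTask> task = ...;
//   TF_CHECK_OK(queue->Schedule(&task));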

template <typename TaskType>
class AdaptiveSharedBatchScheduler
    : public std::enable_shared_from_this<
          AdaptiveSharedBatchScheduler<TaskType>> {
 public:
  ~AdaptiveSharedBatchScheduler() {
    // Finish processing batches before destroying other class members.
    if (owned_batch_thread_pool_) {
      delete batch_thread_pool_;
    }
  }

  struct Options {
    // The name to use for the pool of batch threads.
    string thread_pool_name = {"batch_threads"};
    // Number of batch processing threads - the maximum value of
    // in_flight_batches_limit_. It is recommended that this value be set by
    // running the system under load, observing the learned value for
    // in_flight_batches_limit_, and setting this maximum to ~2x the value.
    // Under low load, in_flight_batches_limit_ has no substantial effect on
    // latency and therefore undergoes a random walk. Unreasonably large
    // values for num_batch_threads allow for a large in_flight_batches_limit_,
    // which will harm latency for some time once load increases again.
    int64 num_batch_threads = port::MaxParallelism();
    // You can pass a ThreadPool directly rather than the above two
    // parameters. If given, the above two parameters are ignored. Ownership
    // of the threadpool is not transferred.
    thread::ThreadPool* thread_pool = nullptr;

    // Lower bound for in_flight_batches_limit_. As discussed above, this can
    // be used to minimize the damage caused by the random walk under low
    // load.
    int64 min_in_flight_batches_limit = 1;
    // Although batch selection is primarily based on age, this parameter
    // specifies a preference for larger batches. A full batch will be
    // scheduled before an older, nearly empty batch as long as the age gap
    // is less than full_batch_scheduling_boost_micros. The optimal value for
    // this parameter should be of the order of the batch processing latency,
    // but must be chosen carefully, as too large a value will harm tail
    // latency.
    int64 full_batch_scheduling_boost_micros = 0;
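    // Concretely, MaybeScheduleNextBatch() assigns each schedulable batch the
    // score
    //   creation_time_micros
    //       - full_batch_scheduling_boost_micros * size / max_batch_size
    // and processes the batch with the lowest score first, so a full batch is
    // treated as if it were full_batch_scheduling_boost_micros older than it
    // really is.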
    // The environment to use (typically only overridden by test code).
    Env* env = Env::Default();
    // Initial limit for number of batches being concurrently processed.
    // Non-integer values correspond to probabilistic limits - i.e. a value of
    // 3.2 results in an actual cap of 3 80% of the time, and 4 20% of the
    // time.
    double initial_in_flight_batches_limit = 3;
    // Number of batches between adjustments of in_flight_batches_limit.
    // Larger numbers will give less noisy latency measurements, but will be
    // less responsive to changes in workload.
    int64 batches_to_average_over = 1000;

    // If true, schedule batches using FIFO policy.
    // Requires that `full_batch_scheduling_boost_micros` is zero.
    // NOTE:
    // A separate parameter is introduced (rather than reusing
    // full_batch_scheduling_boost_micros == 0) to preserve backward
    // compatibility of the API.
    bool fifo_scheduling = false;
  };

  // Ownership is shared between the caller of Create() and any queues created
  // via AddQueue().
  static Status Create(
      const Options& options,
      std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>>* scheduler);

  struct QueueOptions {
    // Maximum size of a batch that's formed within
    // `ASBSQueue<TaskType>::Schedule`.
    int max_batch_size = 1000;
    // Maximum size of an input task, which is submitted to the queue by
    // calling `ASBSQueue<TaskType>::Schedule` and used to form batches.
    //
    // If specified, it must be larger than or equal to `max_batch_size`.
    absl::optional<int> max_input_task_size = absl::nullopt;
    // Maximum number of enqueued (i.e. non-scheduled) batches.
    int max_enqueued_batches = 10;
    // Amount of time non-full batches must wait before becoming schedulable.
    // A non-zero value can improve performance by limiting the scheduling of
    // nearly empty batches.
    int64 batch_timeout_micros = 0;
    // If not null, split_input_task_func should split input_task into
    // multiple tasks, the first of which has size first_size and the
    // remaining not exceeding max_size. This function may acquire ownership
    // of input_task and should return a status indicating if the split was
    // successful. Upon success, the caller can assume that all output_tasks
    // will be scheduled. Including this option allows the scheduler to pack
    // batches better and should usually improve overall throughput. A sketch
    // of a compatible function appears below this struct.
    std::function<Status(std::unique_ptr<TaskType>* input_task, int first_size,
                         int max_batch_size,
                         std::vector<std::unique_ptr<TaskType>>* output_tasks)>
        split_input_task_func;
  };
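
  // A sketch of a split_input_task_func, assuming a hypothetical task type
  // that exposes size() and a Slice(offset, count) factory returning a
  // sub-task; both names are illustrative, not part of this API:
  //
  //   Status SplitMyTask(std::unique_ptr<MyTask>* input_task, int first_size,
  //                      int max_batch_size,
  //                      std::vector<std::unique_ptr<MyTask>>* output_tasks) {
  //     MyTask* task = input_task->get();
  //     int start = 0;
  //     // The first chunk tops off the current batch; each later chunk may
  //     // fill a whole batch.
  //     for (int limit = first_size; start < static_cast<int>(task->size());
  //          limit = max_batch_size) {
  //       const int chunk =
  //           std::min(limit, static_cast<int>(task->size()) - start);
  //       output_tasks->push_back(task->Slice(start, chunk));
  //       start += chunk;
  //     }
  //     return Status::OK();
  //   }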

  using BatchProcessor = std::function<void(std::unique_ptr<Batch<TaskType>>)>;

  // Adds queue (and its callback) to be managed by this scheduler.
  Status AddQueue(const QueueOptions& options,
                  BatchProcessor process_batch_callback,
                  std::unique_ptr<BatchScheduler<TaskType>>* queue);

  double in_flight_batches_limit() {
    mutex_lock l(mu_);
    return in_flight_batches_limit_;
  }


 private:
  // Gives ASBSQueue access to AddBatch, MaybeScheduleClosedBatches,
  // RemoveQueue and GetEnv.
  friend class internal::ASBSQueue<TaskType>;

  explicit AdaptiveSharedBatchScheduler(const Options& options);

  // Tracks processing latency and adjusts in_flight_batches_limit to
  // minimize it.
  void CallbackWrapper(const internal::ASBSBatch<TaskType>* batch,
                       BatchProcessor callback, bool is_express);

  // Schedules batch if in_flight_batches_limit_ is not met.
  void MaybeScheduleNextBatch() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Schedules batch using FIFO policy if in_flight_batches_limit_ is not met.
  void MaybeScheduleNextBatchFIFO() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Schedules all closed batches in batches_ for which an idle thread is
  // available in batch_thread_pool_.
  // Batches scheduled this way are called express batches.
  // Express batches are not limited by in_flight_batches_limit_, and
  // their latencies will not affect in_flight_batches_limit_.
  void MaybeScheduleClosedBatches();

  void MaybeScheduleClosedBatchesLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  void MaybeScheduleClosedBatchesLockedFIFO() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Notifies scheduler of a non-empty batch which is eligible for processing.
  void AddBatch(const internal::ASBSBatch<TaskType>* batch);

  // Removes queue from scheduler.
  void RemoveQueue(const internal::ASBSQueue<TaskType>* queue);

  Env* GetEnv() const { return options_.env; }


  const Options options_;

  // Collection of batches added by AddBatch, ordered by age. Owned by the
  // scheduler until they are released for processing.
  std::vector<const internal::ASBSBatch<TaskType>*> batches_ TF_GUARDED_BY(mu_);

  // Same as batches_, but stored as a FIFO deque; used instead of batches_
  // when options_.fifo_scheduling is true.
  std::deque<const internal::ASBSBatch<TaskType>*> fifo_batches_
      TF_GUARDED_BY(mu_);

  // Unowned queues and callbacks added by AddQueue.
  std::unordered_map<const internal::ASBSQueue<TaskType>*, BatchProcessor>
      queues_and_callbacks_ TF_GUARDED_BY(mu_);

  mutex mu_;

  // Responsible for running the batch processing callbacks.
  thread::ThreadPool* batch_thread_pool_;

  bool owned_batch_thread_pool_ = false;

  // Limit on the number of batches which can be concurrently processed.
  // Non-integer values correspond to probabilistic limits - i.e. a value of
  // 3.2 results in an actual cap of 3 80% of the time, and 4 20% of the time.
  double in_flight_batches_limit_ TF_GUARDED_BY(mu_);

  // Number of regular batches currently being processed.
  int64 in_flight_batches_ TF_GUARDED_BY(mu_) = 0;
  // Number of express batches currently being processed.
  int64 in_flight_express_batches_ TF_GUARDED_BY(mu_) = 0;

  // RNG engine and distribution.
  std::default_random_engine rand_engine_;
  std::uniform_real_distribution<double> rand_double_;

  // Fields controlling the dynamic adjustment of in_flight_batches_limit_.
  // Number of batches since the last in_flight_batches_limit_ adjustment.
  int64 batch_count_ TF_GUARDED_BY(mu_) = 0;
  // Sum of processing latency for batches counted by batch_count_.
  int64 batch_latency_sum_ TF_GUARDED_BY(mu_) = 0;
  // Average batch latency for the previous value of in_flight_batches_limit_.
  double last_avg_latency_ms_ TF_GUARDED_BY(mu_) = 0;
  // Did last_avg_latency_ms_ decrease from the previous last_avg_latency_ms_?
  bool last_latency_decreased_ TF_GUARDED_BY(mu_) = false;
  // Current direction (+-) to adjust in_flight_batches_limit_.
  int step_direction_ TF_GUARDED_BY(mu_) = 1;
  // Max adjustment size (as a fraction of in_flight_batches_limit_).
  constexpr static double kMaxStepSizeMultiplier = 0.125;  // 1/8
  // Min adjustment size (as a fraction of in_flight_batches_limit_).
  constexpr static double kMinStepSizeMultiplier = 0.0078125;  // 1/128
  // Current adjustment size (as a fraction of in_flight_batches_limit_).
  double step_size_multiplier_ TF_GUARDED_BY(mu_) = kMaxStepSizeMultiplier;
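  // For example, with in_flight_batches_limit_ = 8, the largest step moves
  // the limit by 8 * 1/8 = 1 batch and the smallest by 8 * 1/128 = 1/16 of a
  // batch; fractional limits are meaningful because of the probabilistic cap
  // described above.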

  TF_DISALLOW_COPY_AND_ASSIGN(AdaptiveSharedBatchScheduler);
};

//////////////////////////////////////////////////////////
// Implementation details follow. API users need not read.

namespace internal {
// Consolidates tasks into batches, passing them off to the
// AdaptiveSharedBatchScheduler for processing.
template <typename TaskType>
class ASBSQueue : public BatchScheduler<TaskType> {
 public:
  using QueueOptions =
      typename AdaptiveSharedBatchScheduler<TaskType>::QueueOptions;

  ASBSQueue(std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler,
            const QueueOptions& options);

  ~ASBSQueue() override;

  // Adds task to current batch. Fails if the task size is larger than the
  // batch size or if the current batch is full and this queue's number of
  // outstanding batches is at its maximum.
  Status Schedule(std::unique_ptr<TaskType>* task) override;

  // Number of tasks waiting to be scheduled.
  size_t NumEnqueuedTasks() const override;

  // Number of size-1 tasks which could currently be scheduled without
  // failing.
  size_t SchedulingCapacity() const override;

  // Notifies queue that a batch is about to be scheduled; the queue should
  // not place any more tasks in this batch.
  void ReleaseBatch(const ASBSBatch<TaskType>* batch);

  size_t max_task_size() const override { return options_.max_batch_size; }

 private:
  // Number of size-1 tasks which could currently be scheduled without
  // failing.
  size_t SchedulingCapacityLocked() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Returns a uint64 one greater than the value returned by the previous
  // call. Context ids are reused after std::numeric_limits<uint64>::max is
  // exhausted.
  static uint64 NewTraceMeContextIdForBatch();

  std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler_;
  const QueueOptions options_;
  // Owned by scheduler_.
  ASBSBatch<TaskType>* current_batch_ TF_GUARDED_BY(mu_) = nullptr;
  int64 num_enqueued_batches_ TF_GUARDED_BY(mu_) = 0;
  int64 num_enqueued_tasks_ TF_GUARDED_BY(mu_) = 0;
  mutable mutex mu_;
  TF_DISALLOW_COPY_AND_ASSIGN(ASBSQueue);
};

// Batch which remembers when and by whom it was created.
template <typename TaskType>
class ASBSBatch : public Batch<TaskType> {
 public:
  ASBSBatch(ASBSQueue<TaskType>* queue, int64_t creation_time_micros,
            int64_t batch_timeout_micros, uint64 traceme_context_id)
      : queue_(queue),
        creation_time_micros_(creation_time_micros),
        schedulable_time_micros_(creation_time_micros + batch_timeout_micros),
        traceme_context_id_(traceme_context_id) {}

  ~ASBSBatch() override {}

  ASBSQueue<TaskType>* queue() const { return queue_; }

  int64 creation_time_micros() const { return creation_time_micros_; }

  int64 schedulable_time_micros() const { return schedulable_time_micros_; }

  uint64 traceme_context_id() const { return traceme_context_id_; }

 private:
  ASBSQueue<TaskType>* queue_;
  const int64 creation_time_micros_;
  const int64 schedulable_time_micros_;
  const uint64 traceme_context_id_;
  TF_DISALLOW_COPY_AND_ASSIGN(ASBSBatch);
};
}  // namespace internal

// ---------------- AdaptiveSharedBatchScheduler ----------------

template <typename TaskType>
constexpr double AdaptiveSharedBatchScheduler<TaskType>::kMaxStepSizeMultiplier;

template <typename TaskType>
constexpr double AdaptiveSharedBatchScheduler<TaskType>::kMinStepSizeMultiplier;

template <typename TaskType>
Status AdaptiveSharedBatchScheduler<TaskType>::Create(
    const Options& options,
    std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>>* scheduler) {
  if (options.num_batch_threads < 1) {
    return errors::InvalidArgument("num_batch_threads must be positive; was ",
                                   options.num_batch_threads);
  }
  if (options.min_in_flight_batches_limit < 1) {
    return errors::InvalidArgument(
        "min_in_flight_batches_limit must be >= 1; was ",
        options.min_in_flight_batches_limit);
  }
  if (options.min_in_flight_batches_limit > options.num_batch_threads) {
    return errors::InvalidArgument(
        "min_in_flight_batches_limit (", options.min_in_flight_batches_limit,
        ") must be <= num_batch_threads (", options.num_batch_threads, ")");
  }
  if (options.full_batch_scheduling_boost_micros < 0) {
    return errors::InvalidArgument(
        "full_batch_scheduling_boost_micros can't be negative; was ",
        options.full_batch_scheduling_boost_micros);
  }
  if (options.initial_in_flight_batches_limit > options.num_batch_threads) {
    return errors::InvalidArgument(
        "initial_in_flight_batches_limit (",
        options.initial_in_flight_batches_limit,
        ") should not be larger than num_batch_threads (",
        options.num_batch_threads, ")");
  }
  if (options.initial_in_flight_batches_limit <
      options.min_in_flight_batches_limit) {
    return errors::InvalidArgument(
        "initial_in_flight_batches_limit (",
        options.initial_in_flight_batches_limit,
        ") must be >= min_in_flight_batches_limit (",
        options.min_in_flight_batches_limit, ")");
  }
  if (options.batches_to_average_over < 1) {
    return errors::InvalidArgument(
        "batches_to_average_over should be "
        "greater than or equal to 1; was ",
        options.batches_to_average_over);
  }
  scheduler->reset(new AdaptiveSharedBatchScheduler<TaskType>(options));
  return Status::OK();
}

template <typename TaskType>
AdaptiveSharedBatchScheduler<TaskType>::AdaptiveSharedBatchScheduler(
    const Options& options)
    : options_(options),
      in_flight_batches_limit_(options.initial_in_flight_batches_limit),
      rand_double_(0.0, 1.0) {
  std::random_device device;
  rand_engine_.seed(device());
  if (options.thread_pool == nullptr) {
    owned_batch_thread_pool_ = true;
    batch_thread_pool_ = new thread::ThreadPool(
        GetEnv(), options.thread_pool_name, options.num_batch_threads);
  } else {
    owned_batch_thread_pool_ = false;
    batch_thread_pool_ = options.thread_pool;
  }
}

template <typename TaskType>
Status AdaptiveSharedBatchScheduler<TaskType>::AddQueue(
    const QueueOptions& options, BatchProcessor process_batch_callback,
    std::unique_ptr<BatchScheduler<TaskType>>* queue) {
  if (options.max_batch_size <= 0) {
    return errors::InvalidArgument("max_batch_size must be positive; was ",
                                   options.max_batch_size);
  }
  if (options.max_enqueued_batches <= 0) {
    return errors::InvalidArgument(
        "max_enqueued_batches must be positive; was ",
        options.max_enqueued_batches);
  }
  if (options.max_input_task_size.has_value()) {
    if (options.max_input_task_size.value() < options.max_batch_size) {
      return errors::InvalidArgument(
          "max_input_task_size must be larger than or equal to "
          "max_batch_size; got max_input_task_size as ",
          options.max_input_task_size.value(), " and max_batch_size as ",
          options.max_batch_size);
    }
  }
  internal::ASBSQueue<TaskType>* asbs_queue_raw;
  queue->reset(asbs_queue_raw = new internal::ASBSQueue<TaskType>(
                   this->shared_from_this(), options));
  mutex_lock l(mu_);
  queues_and_callbacks_[asbs_queue_raw] = process_batch_callback;
  return Status::OK();
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::AddBatch(
    const internal::ASBSBatch<TaskType>* batch) {
  mutex_lock l(mu_);
  if (options_.fifo_scheduling) {
    fifo_batches_.push_back(batch);
  } else {
    batches_.push_back(batch);
  }
  int64_t delay_micros =
      batch->schedulable_time_micros() - GetEnv()->NowMicros();
  if (delay_micros <= 0) {
    MaybeScheduleNextBatch();
    return;
  }
  // Try to schedule the batch once it becomes schedulable. Although the
  // scheduler waits for all batches to finish processing before allowing
  // itself to be deleted, MaybeScheduleNextBatch() is called in other places,
  // and therefore it's possible the scheduler could be deleted by the time
  // this closure runs. Grab a shared_ptr reference to prevent this from
  // happening.
  GetEnv()->SchedClosureAfter(
      delay_micros, [this, lifetime_preserver = this->shared_from_this()] {
        mutex_lock l(mu_);
        MaybeScheduleNextBatch();
      });
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::RemoveQueue(
    const internal::ASBSQueue<TaskType>* queue) {
  mutex_lock l(mu_);
  queues_and_callbacks_.erase(queue);
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::MaybeScheduleNextBatchFIFO() {
  const internal::ASBSBatch<TaskType>* batch = *fifo_batches_.begin();
  fifo_batches_.pop_front();
  // Queue may destroy itself after ReleaseBatch is called.
  batch->queue()->ReleaseBatch(batch);
  batch_thread_pool_->Schedule(std::bind(
      &AdaptiveSharedBatchScheduler<TaskType>::CallbackWrapper, this, batch,
      queues_and_callbacks_[batch->queue()], false /* is express */));
  in_flight_batches_++;
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<
    TaskType>::MaybeScheduleClosedBatchesLockedFIFO() {
  // Only schedule closed batches if we have spare capacity.
  int available_threads =
      static_cast<int>(options_.num_batch_threads - in_flight_batches_ -
                       in_flight_express_batches_);
  for (auto it = fifo_batches_.begin();
       it != fifo_batches_.end() && available_threads > 0;
       it = fifo_batches_.begin()) {
    if ((*it)->IsClosed()) {
      const internal::ASBSBatch<TaskType>* batch = *it;
      fifo_batches_.pop_front();
      batch->queue()->ReleaseBatch(batch);
      batch_thread_pool_->Schedule(
          std::bind(&AdaptiveSharedBatchScheduler<TaskType>::CallbackWrapper,
                    this, batch, queues_and_callbacks_[batch->queue()], true));
      in_flight_express_batches_++;
      available_threads--;
    } else {
      // Batches are FIFO, so stop iterating after finding the first
      // non-closed batch.
      break;
    }
  }
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::MaybeScheduleNextBatch() {
  bool batch_empty =
      options_.fifo_scheduling ? fifo_batches_.empty() : batches_.empty();
  if (batch_empty || in_flight_batches_ >= in_flight_batches_limit_) return;
  // Non-integer limit handled probabilistically.
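  // E.g. with a limit of 3.2 and 3 batches already in flight, the remaining
  // headroom is 0.2, so a new batch is scheduled with probability 0.2.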
  if (in_flight_batches_limit_ - in_flight_batches_ < 1 &&
      rand_double_(rand_engine_) >
          in_flight_batches_limit_ - in_flight_batches_) {
    return;
  }

  if (options_.fifo_scheduling) {
    MaybeScheduleNextBatchFIFO();
    return;
  }

  auto best_it = batches_.end();
  double best_score = (std::numeric_limits<double>::max)();
  int64_t now_micros = GetEnv()->NowMicros();
  for (auto it = batches_.begin(); it != batches_.end(); it++) {
    if ((*it)->schedulable_time_micros() > now_micros) continue;
    const double score =
        (*it)->creation_time_micros() -
        options_.full_batch_scheduling_boost_micros * (*it)->size() /
            static_cast<double>((*it)->queue()->max_task_size());
    if (best_it == batches_.end() || score < best_score) {
      best_score = score;
      best_it = it;
    }
  }
  // No schedulable batches.
  if (best_it == batches_.end()) return;
  const internal::ASBSBatch<TaskType>* batch = *best_it;
  batches_.erase(best_it);
  // Queue may destroy itself after ReleaseBatch is called.
  batch->queue()->ReleaseBatch(batch);
  batch_thread_pool_->Schedule(
      std::bind(&AdaptiveSharedBatchScheduler<TaskType>::CallbackWrapper, this,
                batch, queues_and_callbacks_[batch->queue()], false));
  in_flight_batches_++;
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::MaybeScheduleClosedBatches() {
  mutex_lock l(mu_);
  MaybeScheduleClosedBatchesLocked();
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<
    TaskType>::MaybeScheduleClosedBatchesLocked() {
  if (options_.fifo_scheduling) {
    MaybeScheduleClosedBatchesLockedFIFO();
    return;
  }
  // Only schedule closed batches if we have spare capacity.
  int available_threads =
      static_cast<int>(options_.num_batch_threads - in_flight_batches_ -
                       in_flight_express_batches_);
  for (auto it = batches_.begin();
       it != batches_.end() && available_threads > 0;) {
    if ((*it)->IsClosed()) {
      const internal::ASBSBatch<TaskType>* batch = *it;
      it = batches_.erase(it);
      batch->queue()->ReleaseBatch(batch);
      batch_thread_pool_->Schedule(
          std::bind(&AdaptiveSharedBatchScheduler<TaskType>::CallbackWrapper,
                    this, batch, queues_and_callbacks_[batch->queue()], true));
      in_flight_express_batches_++;
      available_threads--;
    } else {
      ++it;
    }
  }
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::CallbackWrapper(
    const internal::ASBSBatch<TaskType>* batch,
    AdaptiveSharedBatchScheduler<TaskType>::BatchProcessor callback,
    bool is_express) {
  profiler::TraceMeConsumer trace_me(
      [&] {
        return profiler::TraceMeEncode(
            "ProcessBatch", {{"batch_size_before_padding", batch->size()},
                             {"_r", 2} /*root_event*/});
      },
      profiler::ContextType::kAdaptiveSharedBatchScheduler,
      batch->traceme_context_id());
  int64_t start_time = batch->creation_time_micros();
  callback(std::unique_ptr<Batch<TaskType>>(
      const_cast<internal::ASBSBatch<TaskType>*>(batch)));
  int64_t end_time = GetEnv()->NowMicros();
  mutex_lock l(mu_);
  if (is_express) {
    in_flight_express_batches_--;
    MaybeScheduleClosedBatchesLocked();
    return;
  }
  in_flight_batches_--;
  batch_count_++;
  batch_latency_sum_ += end_time - start_time;
  // Occasionally adjust in_flight_batches_limit_ to minimize average latency.
  // Although the optimal value may depend on the workload, the latency should
  // be a simple convex function of in_flight_batches_limit_, allowing us to
  // locate the global minimum relatively quickly.
  if (batch_count_ == options_.batches_to_average_over) {
    double current_avg_latency_ms = (batch_latency_sum_ / 1000.) / batch_count_;
    bool current_latency_decreased =
        current_avg_latency_ms < last_avg_latency_ms_;
    if (current_latency_decreased) {
      // If the latency improvement was because we're moving in the correct
      // direction, increase step_size so that we can get to the minimum
      // faster. If the latency improvement was due to backtracking from a
      // previous failure, decrease step_size in order to refine our location.
      step_size_multiplier_ *= (last_latency_decreased_ ? 2 : 0.5);
      step_size_multiplier_ =
          std::min(step_size_multiplier_, kMaxStepSizeMultiplier);
      step_size_multiplier_ =
          std::max(step_size_multiplier_, kMinStepSizeMultiplier);
    } else {
      // Return (nearly) to the previous position and confirm that latency is
      // better there before decreasing the step size.
      step_direction_ = -step_direction_;
    }
    in_flight_batches_limit_ +=
        step_direction_ * in_flight_batches_limit_ * step_size_multiplier_;
    in_flight_batches_limit_ =
        std::min(in_flight_batches_limit_,
                 static_cast<double>(options_.num_batch_threads));
    in_flight_batches_limit_ =
        std::max(in_flight_batches_limit_,
                 static_cast<double>(options_.min_in_flight_batches_limit));
    last_avg_latency_ms_ = current_avg_latency_ms;
    last_latency_decreased_ = current_latency_decreased;
    batch_count_ = 0;
    batch_latency_sum_ = 0;
  }
  MaybeScheduleNextBatch();
}

// ---------------- ASBSQueue ----------------

namespace internal {
template <typename TaskType>
ASBSQueue<TaskType>::ASBSQueue(
    std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler,
    const QueueOptions& options)
    : scheduler_(scheduler), options_(options) {}

template <typename TaskType>
ASBSQueue<TaskType>::~ASBSQueue() {
  // Wait until the last batch has been scheduled.
  const int kSleepMicros = 1000;
  for (;;) {
    {
      mutex_lock l(mu_);
      if (num_enqueued_batches_ == 0) {
        break;
      }
    }
    scheduler_->GetEnv()->SleepForMicroseconds(kSleepMicros);
  }
  scheduler_->RemoveQueue(this);
}

template <typename TaskType>
Status ASBSQueue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
  size_t size = (*task)->size();
  if (options_.split_input_task_func == nullptr &&
      size > options_.max_batch_size) {
    return errors::InvalidArgument("Task size ", size,
                                   " is larger than maximum batch size ",
                                   options_.max_batch_size);
  }
  if (options_.max_input_task_size.has_value() &&
      (size > options_.max_input_task_size.value())) {
    return errors::InvalidArgument("Task size ", size,
                                   " is larger than max input task size ",
                                   options_.max_input_task_size.value());
  }

  std::vector<std::unique_ptr<TaskType>> tasks_to_schedule;
  std::vector<ASBSBatch<TaskType>*> new_batches;
  bool closed_batch = false;
  {
    mutex_lock l(mu_);
    if (size > SchedulingCapacityLocked()) {
      return errors::Unavailable("The batch scheduling queue is full");
    }

    int remaining_batch_size =
        current_batch_ == nullptr
            ? options_.max_batch_size
            : options_.max_batch_size - current_batch_->size();
    if (options_.split_input_task_func == nullptr ||
        size <= remaining_batch_size) {
      // Either we don't allow task splitting or the task fits within the
      // current batch.
      tasks_to_schedule.push_back(std::move(*task));
    } else {
      // Split the task in order to completely fill the current batch.
      // Beyond this point Schedule should not fail, as the caller has been
      // promised that all of the split tasks will be scheduled.
      TF_RETURN_IF_ERROR(options_.split_input_task_func(
          task, remaining_batch_size, options_.max_batch_size,
          &tasks_to_schedule));
    }
    for (auto& task : tasks_to_schedule) {
      // If the task can't fit within the current batch, close that batch off
      // and try to create another.
      if (current_batch_ &&
          current_batch_->size() + task->size() > options_.max_batch_size) {
        current_batch_->Close();
        closed_batch = true;
        current_batch_ = nullptr;
      }
      if (!current_batch_) {
        num_enqueued_batches_++;
        // batch.traceme_context_id connects TraceMeProducer and
        // TraceMeConsumer.
        // When multiple calls to "ASBS::Schedule" accumulate into one batch,
        // they are processed in the same batch and should share a
        // traceme_context_id.
        current_batch_ = new ASBSBatch<TaskType>(
            this, scheduler_->GetEnv()->NowMicros(),
            options_.batch_timeout_micros, NewTraceMeContextIdForBatch());
        new_batches.push_back(current_batch_);
      }

      // Annotate each task (corresponding to one call of Schedule) with a
      // TraceMeProducer.
      profiler::TraceMeProducer trace_me(
          [task_size = task->size()] {
            return profiler::TraceMeEncode(
                "ASBSQueue::Schedule",
                {{"batching_input_task_size", task_size}});
          },
          profiler::ContextType::kAdaptiveSharedBatchScheduler,
          this->current_batch_->traceme_context_id());
      current_batch_->AddTask(std::move(task));
      num_enqueued_tasks_++;
      // If current_batch_ is now full, allow it to be processed immediately.
      if (current_batch_->size() == options_.max_batch_size) {
        current_batch_->Close();
        closed_batch = true;
        current_batch_ = nullptr;
      }
    }
  }
  // Scheduler functions must be called outside of the lock, since they may
  // call ReleaseBatch.
  for (auto* batch : new_batches) {
    scheduler_->AddBatch(batch);
  }
  if (closed_batch) {
    scheduler_->MaybeScheduleClosedBatches();
  }
  return Status::OK();
}

template <typename TaskType>
void ASBSQueue<TaskType>::ReleaseBatch(const ASBSBatch<TaskType>* batch) {
  mutex_lock l(mu_);
  num_enqueued_batches_--;
  num_enqueued_tasks_ -= batch->num_tasks();
  if (batch == current_batch_) {
    current_batch_->Close();
    current_batch_ = nullptr;
  }
}

template <typename TaskType>
size_t ASBSQueue<TaskType>::NumEnqueuedTasks() const {
  mutex_lock l(mu_);
  return num_enqueued_tasks_;
}

template <typename TaskType>
size_t ASBSQueue<TaskType>::SchedulingCapacity() const {
  mutex_lock l(mu_);
  return SchedulingCapacityLocked();
}

template <typename TaskType>
size_t ASBSQueue<TaskType>::SchedulingCapacityLocked() const {
  const int current_batch_capacity =
      current_batch_ ? options_.max_batch_size - current_batch_->size() : 0;
  const int spare_batches =
      options_.max_enqueued_batches - num_enqueued_batches_;
  return spare_batches * options_.max_batch_size + current_batch_capacity;
}

template <typename TaskType>
// static
uint64 ASBSQueue<TaskType>::NewTraceMeContextIdForBatch() {
  static std::atomic<uint64> traceme_context_id(0);
  return traceme_context_id.fetch_add(1, std::memory_order_relaxed);
}
}  // namespace internal
}  // namespace serving
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_