/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_

#include <algorithm>
#include <atomic>
#include <functional>
#include <memory>
#include <random>
#include <unordered_map>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/periodic_function.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/threadpool_interface.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"

namespace tensorflow {
namespace serving {
namespace internal {
template <typename TaskType>
class ASBSBatch;

template <typename TaskType>
class ASBSQueue;
}  // namespace internal

// Shared batch scheduler designed to minimize latency. The scheduler keeps
// track of a number of queues (one per model or model version) which are
// continuously enqueuing requests. The scheduler groups the requests into
// batches which it periodically sends off for processing (see
// shared_batch_scheduler.h for more details). AdaptiveSharedBatchScheduler
// (ASBS) prioritizes batches primarily by age (i.e. the age of the batch's
// oldest request), along with a configurable preference for scheduling larger
// batches first.
//
// ASBS tries to keep the system busy by maintaining an adjustable number of
// concurrently processed batches. If a new batch is created and the number of
// in-flight batches is below the target, the next (i.e. oldest) batch is
// immediately scheduled. Similarly, when a batch finishes processing, the
// target is rechecked, and another batch may be scheduled. To avoid the need
// to carefully tune the target for each workload, model type, platform, etc.,
// it is dynamically adjusted in order to provide the lowest average latency.
//
// Some potential use cases:
// Hardware Accelerators (GPUs & TPUs) - If some phase of batch processing
//   involves serial processing by a device, from a latency perspective it is
//   desirable to keep the device evenly loaded, avoiding the need to wait for
//   the device to process prior batches.
// CPU utilization - If the batch processing is CPU dominated, you can reap
//   latency gains when underutilized by increasing the processing rate, but
//   back the rate off when the load increases to avoid overload.

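// Example usage (a minimal sketch; `MyTask` stands in for any BatchTask
// implementation and `ProcessMyBatch` for the caller's batch-processing
// function - both are placeholders, not part of this header):
//
//   AdaptiveSharedBatchScheduler<MyTask>::Options options;
//   std::shared_ptr<AdaptiveSharedBatchScheduler<MyTask>> scheduler;
//   TF_RETURN_IF_ERROR(
//       AdaptiveSharedBatchScheduler<MyTask>::Create(options, &scheduler));
//
//   AdaptiveSharedBatchScheduler<MyTask>::QueueOptions queue_options;
//   std::unique_ptr<BatchScheduler<MyTask>> queue;
//   TF_RETURN_IF_ERROR(
//       scheduler->AddQueue(queue_options, &ProcessMyBatch, &queue));
//
//   std::unique_ptr<MyTask> task = CreateMyTask();  // hypothetical helper
//   TF_RETURN_IF_ERROR(queue->Schedule(&task));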
template <typename TaskType>
class AdaptiveSharedBatchScheduler
    : public std::enable_shared_from_this<
          AdaptiveSharedBatchScheduler<TaskType>> {
 public:
  ~AdaptiveSharedBatchScheduler() {
    // Finish processing batches before destroying other class members.
    batch_thread_pool_.reset();
  }

  struct Options {
    // The name to use for the pool of batch threads.
    string thread_pool_name = {"batch_threads"};
    // Number of batch processing threads - the maximum value of
    // in_flight_batches_limit_. It is recommended that this value be set by
    // running the system under load, observing the learned value for
    // in_flight_batches_limit_, and setting this maximum to ~ 2x the value.
    // Under low load, in_flight_batches_limit_ has no substantial effect on
    // latency and therefore undergoes a random walk. Unreasonably large values
    // for num_batch_threads allow for a large in_flight_batches_limit_, which
    // will harm latency for some time once load increases again.
    int64 num_batch_threads = port::MaxParallelism();
    // You can pass a ThreadPoolInterface directly rather than the above two
    // parameters. If given, the above two parameters are ignored. Ownership of
    // the threadpool is not transferred.
    thread::ThreadPoolInterface* thread_pool = nullptr;
    // Lower bound for in_flight_batches_limit_. As discussed above, can be
    // used to minimize the damage caused by the random walk under low load.
    int64 min_in_flight_batches_limit = 1;
    // Although batch selection is primarily based on age, this parameter
    // specifies a preference for larger batches. A full batch will be
    // scheduled before an older, nearly empty batch as long as the age gap is
    // less than full_batch_scheduling_boost_micros. The optimal value for this
    // parameter should be of order the batch processing latency, but must be
    // chosen carefully, as too large a value will harm tail latency.
    int64 full_batch_scheduling_boost_micros = 0;
    // The environment to use (typically only overridden by test code).
    Env* env = Env::Default();
    // Initial limit for number of batches being concurrently processed.
    // Non-integer values correspond to probabilistic limits - i.e. a value of
    // 3.2 results in an actual cap of 3 80% of the time, and 4 20% of the time.
    double initial_in_flight_batches_limit = 3;
    // Number of batches between adjustments of in_flight_batches_limit. Larger
    // numbers will give less noisy latency measurements, but will be less
    // responsive to changes in workload.
    int64 batches_to_average_over = 1000;
  };

  // Ownership is shared between the caller of Create() and any queues created
  // via AddQueue().
  static Status Create(
      const Options& options,
      std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>>* scheduler);

  struct QueueOptions {
    // Maximum size of a batch that's formed within
    // `ASBSQueue<TaskType>::Schedule`.
    int max_batch_size = 1000;
    // Maximum size of an input task, which is submitted to the queue by
    // calling `ASBSQueue<TaskType>::Schedule` and used to form batches.
    //
    // If specified, it should be larger than or equal to 'max_batch_size'.
    absl::optional<int> max_input_task_size = absl::nullopt;
    // Maximum number of enqueued (i.e. non-scheduled) batches.
    int max_enqueued_batches = 10;
    // Amount of time non-full batches must wait before becoming schedulable.
    // A non-zero value can improve performance by limiting the scheduling of
    // nearly empty batches.
    int64 batch_timeout_micros = 0;
    // If not nullptr, split_input_task_func should split input_task into
    // multiple tasks, the first of which has size first_size and the remaining
    // ones not exceeding max_size. This function may acquire ownership of
    // input_task and should return a status indicating whether the split was
    // successful. Upon success, the caller can assume that all output_tasks
    // will be scheduled. Including this option allows the scheduler to pack
    // batches better and should usually improve overall throughput.
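    // E.g. an input task of size 77 split with first_size = 2 and
    // max_batch_size = 10 might yield output tasks of sizes
    // {2, 10, 10, 10, 10, 10, 10, 10, 5}.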
    std::function<Status(std::unique_ptr<TaskType>* input_task, int first_size,
                         int max_batch_size,
                         std::vector<std::unique_ptr<TaskType>>* output_tasks)>
        split_input_task_func;
  };

  using BatchProcessor = std::function<void(std::unique_ptr<Batch<TaskType>>)>;

  // Adds queue (and its callback) to be managed by this scheduler.
  Status AddQueue(const QueueOptions& options,
                  BatchProcessor process_batch_callback,
                  std::unique_ptr<BatchScheduler<TaskType>>* queue);

  double in_flight_batches_limit() {
    mutex_lock l(mu_);
    return in_flight_batches_limit_;
  }

 private:
  // Gives ASBSQueue access to AddBatch, MaybeScheduleClosedBatches,
  // RemoveQueue, and GetEnv.
  friend class internal::ASBSQueue<TaskType>;

  explicit AdaptiveSharedBatchScheduler(const Options& options);

  // Tracks processing latency and adjusts in_flight_batches_limit to minimize
  // it.
  void CallbackWrapper(const internal::ASBSBatch<TaskType>* batch,
                       BatchProcessor callback, bool is_express);

  // Schedules batch if in_flight_batches_limit_ is not met.
  void MaybeScheduleNextBatch() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Schedules all closed batches in batches_ for which an idle thread is
  // available in batch_thread_pool_.
  // Batches scheduled this way are called express batches.
  // Express batches are not limited by in_flight_batches_limit_, and
  // their latencies will not affect in_flight_batches_limit_.
  void MaybeScheduleClosedBatches();

  void MaybeScheduleClosedBatchesLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Notifies scheduler of a non-empty batch which is eligible for processing.
  void AddBatch(const internal::ASBSBatch<TaskType>* batch);

  // Removes queue from scheduler.
  void RemoveQueue(const internal::ASBSQueue<TaskType>* queue);

  Env* GetEnv() const { return options_.env; }

  const Options options_;

  // Collection of batches added by AddBatch, ordered by age. Owned by
  // scheduler until they are released for processing.
  std::vector<const internal::ASBSBatch<TaskType>*> batches_ TF_GUARDED_BY(mu_);

  // Unowned queues and callbacks added by AddQueue.
  std::unordered_map<const internal::ASBSQueue<TaskType>*, BatchProcessor>
      queues_and_callbacks_ TF_GUARDED_BY(mu_);

  mutex mu_;

  // Responsible for running the batch processing callbacks.
  std::unique_ptr<thread::ThreadPool> batch_thread_pool_;

  // Limit on number of batches which can be concurrently processed.
  // Non-integer values correspond to probabilistic limits - i.e. a value of
  // 3.2 results in an actual cap of 3 80% of the time, and 4 20% of the time.
  double in_flight_batches_limit_ TF_GUARDED_BY(mu_);

  // Number of regular batches currently being processed.
  int64 in_flight_batches_ TF_GUARDED_BY(mu_) = 0;
  // Number of express batches currently being processed.
  int64 in_flight_express_batches_ TF_GUARDED_BY(mu_) = 0;

  // RNG engine and distribution.
  std::default_random_engine rand_engine_;
  std::uniform_real_distribution<double> rand_double_;

  // Fields controlling the dynamic adjustment of in_flight_batches_limit_.
  // Number of batches since the last in_flight_batches_limit_ adjustment.
  int64 batch_count_ TF_GUARDED_BY(mu_) = 0;
  // Sum of processing latency for batches counted by batch_count_.
  int64 batch_latency_sum_ TF_GUARDED_BY(mu_) = 0;
  // Average batch latency for previous value of in_flight_batches_limit_.
  double last_avg_latency_ms_ TF_GUARDED_BY(mu_) = 0;
  // Did last_avg_latency_ms_ decrease from the previous last_avg_latency_ms_?
  bool last_latency_decreased_ TF_GUARDED_BY(mu_) = false;
  // Current direction (+-) to adjust in_flight_batches_limit_.
  int step_direction_ TF_GUARDED_BY(mu_) = 1;
  // Max adjustment size (as a fraction of in_flight_batches_limit_).
  constexpr static double kMaxStepSizeMultiplier = 0.125;  // 1/8
  // Min adjustment size (as a fraction of in_flight_batches_limit_).
  constexpr static double kMinStepSizeMultiplier = 0.0078125;  // 1/128
  // Current adjustment size (as a fraction of in_flight_batches_limit_).
  double step_size_multiplier_ TF_GUARDED_BY(mu_) = kMaxStepSizeMultiplier;

  TF_DISALLOW_COPY_AND_ASSIGN(AdaptiveSharedBatchScheduler);
};

//////////////////////////////////////////////////////////
// Implementation details follow. API users need not read.

namespace internal {
// Consolidates tasks into batches, passing them off to the
// AdaptiveSharedBatchScheduler for processing.
template <typename TaskType>
class ASBSQueue : public BatchScheduler<TaskType> {
 public:
  using QueueOptions =
      typename AdaptiveSharedBatchScheduler<TaskType>::QueueOptions;

  ASBSQueue(std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler,
            const QueueOptions& options);

  ~ASBSQueue() override;

  // Adds task to the current batch. Fails if the task size is larger than the
  // batch size or if the current batch is full and this queue's number of
  // outstanding batches is at its maximum.
  Status Schedule(std::unique_ptr<TaskType>* task) override;

  // Number of tasks waiting to be scheduled.
  size_t NumEnqueuedTasks() const override;

  // Number of size-1 tasks which could currently be scheduled without failing.
  size_t SchedulingCapacity() const override;

  // Notifies queue that a batch is about to be scheduled; the queue should not
  // place any more tasks in this batch.
  void ReleaseBatch(const ASBSBatch<TaskType>* batch);

  size_t max_task_size() const override { return options_.max_batch_size; }

 private:
  // Number of size-1 tasks which could currently be scheduled without failing.
  size_t SchedulingCapacityLocked() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Returns a uint64 one greater than the value returned by the previous call.
  // Context ids are reused after std::numeric_limits<uint64>::max is
  // exhausted.
  static uint64 NewTraceMeContextIdForBatch();

  std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler_;
  const QueueOptions options_;
  // Owned by scheduler_.
  ASBSBatch<TaskType>* current_batch_ TF_GUARDED_BY(mu_) = nullptr;
  int64 num_enqueued_batches_ TF_GUARDED_BY(mu_) = 0;
  int64 num_enqueued_tasks_ TF_GUARDED_BY(mu_) = 0;
  mutable mutex mu_;
  TF_DISALLOW_COPY_AND_ASSIGN(ASBSQueue);
};

// Batch which remembers when and by whom it was created.
template <typename TaskType>
class ASBSBatch : public Batch<TaskType> {
 public:
  ASBSBatch(ASBSQueue<TaskType>* queue, int64 creation_time_micros,
            int64 batch_timeout_micros, uint64 traceme_context_id)
      : queue_(queue),
        creation_time_micros_(creation_time_micros),
        schedulable_time_micros_(creation_time_micros + batch_timeout_micros),
        traceme_context_id_(traceme_context_id) {}

  ~ASBSBatch() override {}

  ASBSQueue<TaskType>* queue() const { return queue_; }

  int64 creation_time_micros() const { return creation_time_micros_; }

  int64 schedulable_time_micros() const { return schedulable_time_micros_; }

  uint64 traceme_context_id() const { return traceme_context_id_; }

 private:
  ASBSQueue<TaskType>* queue_;
  const int64 creation_time_micros_;
  const int64 schedulable_time_micros_;
  const uint64 traceme_context_id_;
  TF_DISALLOW_COPY_AND_ASSIGN(ASBSBatch);
};
}  // namespace internal

// ---------------- AdaptiveSharedBatchScheduler ----------------

template <typename TaskType>
constexpr double AdaptiveSharedBatchScheduler<TaskType>::kMaxStepSizeMultiplier;

template <typename TaskType>
constexpr double AdaptiveSharedBatchScheduler<TaskType>::kMinStepSizeMultiplier;

template <typename TaskType>
Status AdaptiveSharedBatchScheduler<TaskType>::Create(
    const Options& options,
    std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>>* scheduler) {
  if (options.num_batch_threads < 1) {
    return errors::InvalidArgument("num_batch_threads must be positive; was ",
                                   options.num_batch_threads);
  }
  if (options.min_in_flight_batches_limit < 1) {
    return errors::InvalidArgument(
        "min_in_flight_batches_limit must be >= 1; was ",
        options.min_in_flight_batches_limit);
  }
  if (options.min_in_flight_batches_limit > options.num_batch_threads) {
    return errors::InvalidArgument(
        "min_in_flight_batches_limit (", options.min_in_flight_batches_limit,
        ") must be <= num_batch_threads (", options.num_batch_threads, ")");
  }
  if (options.full_batch_scheduling_boost_micros < 0) {
    return errors::InvalidArgument(
        "full_batch_scheduling_boost_micros can't be negative; was ",
        options.full_batch_scheduling_boost_micros);
  }
  if (options.initial_in_flight_batches_limit > options.num_batch_threads) {
    return errors::InvalidArgument(
        "initial_in_flight_batches_limit (",
        options.initial_in_flight_batches_limit,
        ") should not be larger than num_batch_threads (",
        options.num_batch_threads, ")");
  }
  if (options.initial_in_flight_batches_limit <
      options.min_in_flight_batches_limit) {
    return errors::InvalidArgument(
        "initial_in_flight_batches_limit (",
        options.initial_in_flight_batches_limit,
        ") must be >= min_in_flight_batches_limit (",
        options.min_in_flight_batches_limit, ")");
  }
  if (options.batches_to_average_over < 1) {
    return errors::InvalidArgument(
        "batches_to_average_over should be "
        "greater than or equal to 1; was ",
        options.batches_to_average_over);
  }
  scheduler->reset(new AdaptiveSharedBatchScheduler<TaskType>(options));
  return Status::OK();
}

template <typename TaskType>
AdaptiveSharedBatchScheduler<TaskType>::AdaptiveSharedBatchScheduler(
    const Options& options)
    : options_(options),
      in_flight_batches_limit_(options.initial_in_flight_batches_limit),
      rand_double_(0.0, 1.0) {
  std::random_device device;
  rand_engine_.seed(device());
  if (options.thread_pool == nullptr) {
    batch_thread_pool_.reset(new thread::ThreadPool(
        GetEnv(), options.thread_pool_name, options.num_batch_threads));
  } else {
    batch_thread_pool_.reset(new thread::ThreadPool(options.thread_pool));
  }
}

template <typename TaskType>
Status AdaptiveSharedBatchScheduler<TaskType>::AddQueue(
    const QueueOptions& options, BatchProcessor process_batch_callback,
    std::unique_ptr<BatchScheduler<TaskType>>* queue) {
  if (options.max_batch_size <= 0) {
    return errors::InvalidArgument("max_batch_size must be positive; was ",
                                   options.max_batch_size);
  }
  if (options.max_enqueued_batches <= 0) {
    return errors::InvalidArgument(
        "max_enqueued_batches must be positive; was ",
        options.max_enqueued_batches);
  }
  if (options.max_input_task_size.has_value()) {
    if (options.max_input_task_size.value() < options.max_batch_size) {
      return errors::InvalidArgument(
          "max_input_task_size must be larger than or equal to "
          "max_batch_size; got max_input_task_size as ",
          options.max_input_task_size.value(), " and max_batch_size as ",
          options.max_batch_size);
    }
  }
  internal::ASBSQueue<TaskType>* asbs_queue_raw;
  queue->reset(asbs_queue_raw = new internal::ASBSQueue<TaskType>(
                   this->shared_from_this(), options));
  mutex_lock l(mu_);
  queues_and_callbacks_[asbs_queue_raw] = process_batch_callback;
  return Status::OK();
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::AddBatch(
    const internal::ASBSBatch<TaskType>* batch) {
  mutex_lock l(mu_);
  batches_.push_back(batch);
  int64 delay_micros = batch->schedulable_time_micros() - GetEnv()->NowMicros();
  if (delay_micros <= 0) {
    MaybeScheduleNextBatch();
    return;
  }
  // Try to schedule the batch once it becomes schedulable. Although the
  // scheduler waits for all batches to finish processing before allowing
  // itself to be deleted, MaybeScheduleNextBatch() is called in other places,
  // and therefore it's possible the scheduler could be deleted by the time
  // this closure runs. Grab a shared_ptr reference to prevent this from
  // happening.
  GetEnv()->SchedClosureAfter(
      delay_micros, [this, lifetime_preserver = this->shared_from_this()] {
        mutex_lock l(mu_);
        MaybeScheduleNextBatch();
      });
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::RemoveQueue(
    const internal::ASBSQueue<TaskType>* queue) {
  mutex_lock l(mu_);
  queues_and_callbacks_.erase(queue);
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::MaybeScheduleNextBatch() {
  if (batches_.empty() || in_flight_batches_ >= in_flight_batches_limit_)
    return;
  // Non-integer limit handled probabilistically.
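  // E.g. with a limit of 3.2 and 3 batches already in flight, the remaining
  // headroom is 0.2, so a new batch is scheduled with probability 0.2.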
  if (in_flight_batches_limit_ - in_flight_batches_ < 1 &&
      rand_double_(rand_engine_) >
          in_flight_batches_limit_ - in_flight_batches_) {
    return;
  }
  auto best_it = batches_.end();
  double best_score = (std::numeric_limits<double>::max)();
  int64 now_micros = GetEnv()->NowMicros();
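  // Pick the schedulable batch with the lowest score: its creation time minus
  // a boost (up to full_batch_scheduling_boost_micros) that grows with how
  // full the batch is, so older and fuller batches are preferred.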
  for (auto it = batches_.begin(); it != batches_.end(); it++) {
    if ((*it)->schedulable_time_micros() > now_micros) continue;
    const double score =
        (*it)->creation_time_micros() -
        options_.full_batch_scheduling_boost_micros * (*it)->size() /
            static_cast<double>((*it)->queue()->max_task_size());
    if (best_it == batches_.end() || score < best_score) {
      best_score = score;
      best_it = it;
    }
  }
  // No schedulable batches.
  if (best_it == batches_.end()) return;
  const internal::ASBSBatch<TaskType>* batch = *best_it;
  batches_.erase(best_it);
  // Queue may destroy itself after ReleaseBatch is called.
  batch->queue()->ReleaseBatch(batch);
  batch_thread_pool_->Schedule(
      std::bind(&AdaptiveSharedBatchScheduler<TaskType>::CallbackWrapper, this,
                batch, queues_and_callbacks_[batch->queue()], false));
  in_flight_batches_++;
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::MaybeScheduleClosedBatches() {
  mutex_lock l(mu_);
  MaybeScheduleClosedBatchesLocked();
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<
    TaskType>::MaybeScheduleClosedBatchesLocked() {
  // Only schedule closed batches if we have spare capacity.
  int available_threads =
      static_cast<int>(options_.num_batch_threads - in_flight_batches_ -
                       in_flight_express_batches_);
  for (auto it = batches_.begin();
       it != batches_.end() && available_threads > 0;) {
    if ((*it)->IsClosed()) {
      const internal::ASBSBatch<TaskType>* batch = *it;
      it = batches_.erase(it);
      batch->queue()->ReleaseBatch(batch);
      batch_thread_pool_->Schedule(
          std::bind(&AdaptiveSharedBatchScheduler<TaskType>::CallbackWrapper,
                    this, batch, queues_and_callbacks_[batch->queue()], true));
      in_flight_express_batches_++;
      available_threads--;
    } else {
      ++it;
    }
  }
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::CallbackWrapper(
    const internal::ASBSBatch<TaskType>* batch,
    AdaptiveSharedBatchScheduler<TaskType>::BatchProcessor callback,
    bool is_express) {
  profiler::TraceMeConsumer trace_me(
      [&] {
        return profiler::TraceMeEncode(
            "ProcessBatch", {{"batch_size_before_padding", batch->size()}});
      },
      profiler::ContextType::kAdaptiveSharedBatchScheduler,
      batch->traceme_context_id());
  int64 start_time = batch->creation_time_micros();
  callback(std::unique_ptr<Batch<TaskType>>(
      const_cast<internal::ASBSBatch<TaskType>*>(batch)));
  int64 end_time = GetEnv()->NowMicros();
  mutex_lock l(mu_);
  if (is_express) {
    in_flight_express_batches_--;
    MaybeScheduleClosedBatchesLocked();
    return;
  }
  in_flight_batches_--;
  batch_count_++;
  batch_latency_sum_ += end_time - start_time;
  // Occasionally adjust in_flight_batches_limit_ to minimize average latency.
  // Although the optimal value may depend on the workload, the latency should
  // be a simple convex function of in_flight_batches_limit_, allowing us to
  // locate the global minimum relatively quickly.
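  // Each adjustment nudges the limit by step_size_multiplier_ of its current
  // value, e.g. a limit of 8 with a multiplier of 1/8 moves to 7 or 9
  // depending on step_direction_.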
  if (batch_count_ == options_.batches_to_average_over) {
    double current_avg_latency_ms = (batch_latency_sum_ / 1000.) / batch_count_;
    bool current_latency_decreased =
        current_avg_latency_ms < last_avg_latency_ms_;
    if (current_latency_decreased) {
      // If the latency improvement was because we're moving in the correct
      // direction, increase step_size so that we can get to the minimum
      // faster. If the latency improvement was due to backtracking from a
      // previous failure, decrease step_size in order to refine our location.
      step_size_multiplier_ *= (last_latency_decreased_ ? 2 : 0.5);
      step_size_multiplier_ =
          std::min(step_size_multiplier_, kMaxStepSizeMultiplier);
      step_size_multiplier_ =
          std::max(step_size_multiplier_, kMinStepSizeMultiplier);
    } else {
      // Return (nearly) to the previous position and confirm that latency is
      // better there before decreasing the step size.
      step_direction_ = -step_direction_;
    }
    in_flight_batches_limit_ +=
        step_direction_ * in_flight_batches_limit_ * step_size_multiplier_;
    in_flight_batches_limit_ =
        std::min(in_flight_batches_limit_,
                 static_cast<double>(options_.num_batch_threads));
    in_flight_batches_limit_ =
        std::max(in_flight_batches_limit_,
                 static_cast<double>(options_.min_in_flight_batches_limit));
    last_avg_latency_ms_ = current_avg_latency_ms;
    last_latency_decreased_ = current_latency_decreased;
    batch_count_ = 0;
    batch_latency_sum_ = 0;
  }
  MaybeScheduleNextBatch();
}

// ---------------- ASBSQueue ----------------

namespace internal {
template <typename TaskType>
ASBSQueue<TaskType>::ASBSQueue(
    std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler,
    const QueueOptions& options)
    : scheduler_(scheduler), options_(options) {}

template <typename TaskType>
ASBSQueue<TaskType>::~ASBSQueue() {
  // Wait until the last batch has been scheduled.
  const int kSleepMicros = 1000;
  for (;;) {
    {
      mutex_lock l(mu_);
      if (num_enqueued_batches_ == 0) {
        break;
      }
    }
    scheduler_->GetEnv()->SleepForMicroseconds(kSleepMicros);
  }
  scheduler_->RemoveQueue(this);
}

template <typename TaskType>
Status ASBSQueue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
  size_t size = (*task)->size();
  if (options_.split_input_task_func == nullptr &&
      size > options_.max_batch_size) {
    return errors::InvalidArgument("Task size ", size,
                                   " is larger than maximum batch size ",
                                   options_.max_batch_size);
  }
  if (options_.max_input_task_size.has_value() &&
      (size > options_.max_input_task_size.value())) {
    return errors::InvalidArgument("Task size ", size,
                                   " is larger than max input task size ",
                                   options_.max_input_task_size.value());
  }

  std::vector<std::unique_ptr<TaskType>> tasks_to_schedule;
  std::vector<ASBSBatch<TaskType>*> new_batches;
  bool closed_batch = false;
  {
    mutex_lock l(mu_);
    if (size > SchedulingCapacityLocked()) {
      return errors::Unavailable("The batch scheduling queue is full");
    }

    int remaining_batch_size =
        current_batch_ == nullptr
            ? options_.max_batch_size
            : options_.max_batch_size - current_batch_->size();
    if (options_.split_input_task_func == nullptr ||
        size <= remaining_batch_size) {
      // Either we don't allow task splitting or the task fits within the
      // current batch.
      tasks_to_schedule.push_back(std::move(*task));
    } else {
      // Split the task in order to completely fill the current batch.
      // Beyond this point Schedule should not fail, as the caller has been
      // promised that all of the split tasks will be scheduled.
      TF_RETURN_IF_ERROR(options_.split_input_task_func(
          task, remaining_batch_size, options_.max_batch_size,
          &tasks_to_schedule));
    }
    for (auto& task : tasks_to_schedule) {
      // Can't fit within the current batch; close it off and try to create
      // another.
      if (current_batch_ &&
          current_batch_->size() + task->size() > options_.max_batch_size) {
        current_batch_->Close();
        closed_batch = true;
        current_batch_ = nullptr;
      }
      if (!current_batch_) {
        num_enqueued_batches_++;
        // batch.traceme_context_id connects TraceMeProducer and
        // TraceMeConsumer. When multiple calls to "ASBS::Schedule" accumulate
        // into one batch, they are processed together and share the same
        // traceme_context_id.
        current_batch_ = new ASBSBatch<TaskType>(
            this, scheduler_->GetEnv()->NowMicros(),
            options_.batch_timeout_micros, NewTraceMeContextIdForBatch());
        new_batches.push_back(current_batch_);
      }

      // Annotate each task (corresponding to one call of Schedule) with a
      // TraceMeProducer.
      profiler::TraceMeProducer trace_me(
          [task_size = task->size()] {
            return profiler::TraceMeEncode(
                "ASBSQueue::Schedule",
                {{"batching_input_task_size", task_size}});
          },
          profiler::ContextType::kAdaptiveSharedBatchScheduler,
          this->current_batch_->traceme_context_id());
      current_batch_->AddTask(std::move(task));
      num_enqueued_tasks_++;
      // If current_batch_ is now full, allow it to be processed immediately.
      if (current_batch_->size() == options_.max_batch_size) {
        current_batch_->Close();
        closed_batch = true;
        current_batch_ = nullptr;
      }
    }
  }
  // Scheduler functions must be called outside of the lock, since they may
  // call ReleaseBatch.
  for (auto* batch : new_batches) {
    scheduler_->AddBatch(batch);
  }
  if (closed_batch) {
    scheduler_->MaybeScheduleClosedBatches();
  }
  return Status::OK();
}

template <typename TaskType>
void ASBSQueue<TaskType>::ReleaseBatch(const ASBSBatch<TaskType>* batch) {
  mutex_lock l(mu_);
  num_enqueued_batches_--;
  num_enqueued_tasks_ -= batch->num_tasks();
  if (batch == current_batch_) {
    current_batch_->Close();
    current_batch_ = nullptr;
  }
}

template <typename TaskType>
size_t ASBSQueue<TaskType>::NumEnqueuedTasks() const {
  mutex_lock l(mu_);
  return num_enqueued_tasks_;
}

template <typename TaskType>
size_t ASBSQueue<TaskType>::SchedulingCapacity() const {
  mutex_lock l(mu_);
  return SchedulingCapacityLocked();
}

template <typename TaskType>
size_t ASBSQueue<TaskType>::SchedulingCapacityLocked() const {
  const int current_batch_capacity =
      current_batch_ ? options_.max_batch_size - current_batch_->size() : 0;
  const int spare_batches =
      options_.max_enqueued_batches - num_enqueued_batches_;
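  // E.g. with max_batch_size 1000, max_enqueued_batches 10, and three batches
  // already enqueued (one of them the open current batch, at size 400):
  // capacity = (10 - 3) * 1000 + (1000 - 400) = 7600.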
  return spare_batches * options_.max_batch_size + current_batch_capacity;
}

template <typename TaskType>
// static
uint64 ASBSQueue<TaskType>::NewTraceMeContextIdForBatch() {
  static std::atomic<uint64> traceme_context_id(0);
  return traceme_context_id.fetch_add(1, std::memory_order_relaxed);
}
}  // namespace internal
}  // namespace serving
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_