1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "absl/strings/str_cat.h"
17 #include "tensorflow/core/common_runtime/device_mgr.h"
18 #include "tensorflow/core/framework/device.h"
19 #include "tensorflow/core/framework/function.h"
20 #include "tensorflow/core/framework/op_kernel.h"
21 #include "tensorflow/core/framework/resource_mgr.h"
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/framework/tensor_util.h"
24 #include "tensorflow/core/framework/types.h"
25 #include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h"
26 #include "tensorflow/core/kernels/batching_util/batch_resource_base.h"
27 #include "tensorflow/core/kernels/batching_util/concat_split_util.h"
28 #include "tensorflow/core/kernels/batching_util/periodic_function.h"
29 #include "tensorflow/core/kernels/ops_util.h"
30 #include "tensorflow/core/lib/monitoring/gauge.h"
31 #include "tensorflow/core/lib/random/random.h"
32 #include "tensorflow/core/platform/errors.h"
33 #include "tensorflow/core/platform/logging.h"
34 #include "tensorflow/core/platform/macros.h"
35 #include "tensorflow/core/platform/numbers.h"
36 #include "tensorflow/core/platform/threadpool.h"
37 
38 namespace tensorflow {
namespace {
// Names of the optional, underscore-prefixed op attributes that configure the
// adaptive batch scheduler.
constexpr char kEnableAdaptiveSchedulerAttr[] = "_enable_adaptive_scheduler";
constexpr char kBatchesToAverageOverAttr[] = "_batches_to_average_over";
constexpr char kMinInflightBatchesAttr[] = "_min_inflight_batches";
constexpr char kInitialInflightBatchesAttr[] = "_initial_inflight_batches";
constexpr char kMaxInflightBatchesAttr[] = "_max_inflight_batches";

// Per-model defaults for the in-flight batch parameters, used when the
// corresponding attribute above is absent.
constexpr int64_t kBatchesToAverageOver = 10;
constexpr int64_t kMinInflightBatches = 16;
constexpr int64_t kInitialInflightBatches = 16;
constexpr int64_t kMaxInflightBatches = 64;

// Upper bound on the number of threads in the per-process thread pool that
// is shared by all executions of batch-op.
constexpr int64_t kBatchThreadPoolSize = 128;
}  // namespace
57 
58 auto* batch_op_split_usage = monitoring::Gauge<string, 1>::New(
59     "/tensorflow/serving/batching/enable_large_batch_splitting",
60     "Tracks the usage of attribute `enable_large_batch_splitting` for "
61     "BatchFunction kernel in a saved model.",
62     "model_name");
63 
RecordBatchSplitUsage(absl::optional<bool> maybe_enable_large_batch_splitting,const string & model_name)64 void RecordBatchSplitUsage(
65     absl::optional<bool> maybe_enable_large_batch_splitting,
66     const string& model_name) {
67   if (maybe_enable_large_batch_splitting.has_value()) {
68     if (maybe_enable_large_batch_splitting.value()) {
69       batch_op_split_usage->GetCell(model_name)->Set("true");
70     } else {
71       batch_op_split_usage->GetCell(model_name)->Set("false");
72     }
73   } else {
74     batch_op_split_usage->GetCell(model_name)->Set("unset");
75   }
76 }
77 
RecordBatchParamNumBatchThreads(int64_t num_batch_threads,const string & model_name)78 void RecordBatchParamNumBatchThreads(int64_t num_batch_threads,
79                                      const string& model_name) {
80   static auto* cell = monitoring::Gauge<int64, 1>::New(
81       "/tensorflow/serving/batching/num_batch_threads",
82       "Tracks the number of batch threads of a model.", "model_name");
83   cell->GetCell(model_name)->Set(num_batch_threads);
84 }
85 
GetModelName(OpKernelContext * ctx)86 const string& GetModelName(OpKernelContext* ctx) {
87   static string* kModelNameUnset = new string("model_name_unset");
88   if (!ctx->session_metadata()) return *kModelNameUnset;
89   if (ctx->session_metadata()->name().empty()) return *kModelNameUnset;
90   return ctx->session_metadata()->name();
91 }
92 
93 using ::tensorflow::concat_split_util::Concat;
94 using ::tensorflow::concat_split_util::Split;
95 
GetOrCreateBatchThreadsPool(const string & thread_name,int num_batch_threads)96 static thread::ThreadPool* GetOrCreateBatchThreadsPool(
97     const string& thread_name, int num_batch_threads) {
98   static thread::ThreadPool* pool =
99       new thread::ThreadPool(Env::Default(), thread_name, num_batch_threads);
100   return pool;
101 }
102 
103 // A class encapsulating the state and logic for batching tensors.
104 class BatchResource : public serving::BatchResourceBase {
105  public:
Create(int32_t num_batch_threads,int32_t max_execution_batch_size,int32_t batch_timeout_micros,int32_t max_enqueued_batches,const std::vector<int32> & allowed_batch_sizes,FunctionLibraryRuntime::Handle fhandle,FunctionLibraryRuntime * flib,bool enable_large_batch_splitting,std::unique_ptr<BatchResource> * resource)106   static Status Create(int32_t num_batch_threads,
107                        int32_t max_execution_batch_size,
108                        int32_t batch_timeout_micros,
109                        int32_t max_enqueued_batches,
110                        const std::vector<int32>& allowed_batch_sizes,
111                        FunctionLibraryRuntime::Handle fhandle,
112                        FunctionLibraryRuntime* flib,
113                        bool enable_large_batch_splitting,
114                        std::unique_ptr<BatchResource>* resource) {
115     BatcherT::Options batcher_options;
116     batcher_options.num_batch_threads = num_batch_threads;
117     std::shared_ptr<BatcherT> batcher;
118     TF_RETURN_IF_ERROR(BatcherT::Create(batcher_options, &batcher));
119 
120     resource->reset(new BatchResource(
121         fhandle, flib, std::move(batcher),
122         GetBatcherQueueOptions(num_batch_threads, max_execution_batch_size,
123                                batch_timeout_micros, max_enqueued_batches,
124                                allowed_batch_sizes,
125                                enable_large_batch_splitting),
126         allowed_batch_sizes));
127     return Status::OK();
128   }
129 
Create(AdaptiveBatcherT::Options adaptive_shared_batch_scheduler_options,int32_t max_batch_size,int32_t batch_timeout_micros,int32_t max_enqueued_batches,const std::vector<int32> & allowed_batch_sizes,FunctionLibraryRuntime::Handle fhandle,FunctionLibraryRuntime * flib,std::unique_ptr<BatchResource> * resource)130   static Status Create(
131       AdaptiveBatcherT::Options adaptive_shared_batch_scheduler_options,
132       int32_t max_batch_size, int32_t batch_timeout_micros,
133       int32_t max_enqueued_batches,
134       const std::vector<int32>& allowed_batch_sizes,
135       FunctionLibraryRuntime::Handle fhandle, FunctionLibraryRuntime* flib,
136       std::unique_ptr<BatchResource>* resource) {
137     std::shared_ptr<AdaptiveBatcherT> batcher;
138     TF_RETURN_IF_ERROR(AdaptiveBatcherT::Create(
139         adaptive_shared_batch_scheduler_options, &batcher));
140 
141     resource->reset(new BatchResource(
142         fhandle, flib, std::move(batcher),
143         GetAdaptiveBatcherQueueOptions(
144             max_batch_size, batch_timeout_micros, max_enqueued_batches,
145             true /* enable large batch split */, allowed_batch_sizes),
146         allowed_batch_sizes));
147     return Status::OK();
148   }
149 
DebugString() const150   string DebugString() const final { return "BatchResource"; }
151 
152  private:
BatchResource(FunctionLibraryRuntime::Handle fhandle,FunctionLibraryRuntime * flib,std::shared_ptr<BatcherT> batcher,const BatcherT::QueueOptions & batcher_queue_options,std::vector<int32> allowed_batch_sizes)153   BatchResource(FunctionLibraryRuntime::Handle fhandle,
154                 FunctionLibraryRuntime* flib, std::shared_ptr<BatcherT> batcher,
155                 const BatcherT::QueueOptions& batcher_queue_options,
156                 std::vector<int32> allowed_batch_sizes)
157       : BatchResourceBase(
158             /*has_process_batch_function=*/fhandle != kInvalidHandle,
159             std::move(batcher), batcher_queue_options,
160             std::move(allowed_batch_sizes)),
161         fhandle_(fhandle),
162         flib_(flib) {}
163 
BatchResource(FunctionLibraryRuntime::Handle fhandle,FunctionLibraryRuntime * flib,std::shared_ptr<AdaptiveBatcherT> batcher,const AdaptiveBatcherT::QueueOptions & batcher_queue_options,std::vector<int32> allowed_batch_sizes)164   BatchResource(FunctionLibraryRuntime::Handle fhandle,
165                 FunctionLibraryRuntime* flib,
166                 std::shared_ptr<AdaptiveBatcherT> batcher,
167                 const AdaptiveBatcherT::QueueOptions& batcher_queue_options,
168                 std::vector<int32> allowed_batch_sizes)
169       : BatchResourceBase(
170             /*has_process_batch_function=*/fhandle != kInvalidHandle,
171             std::move(batcher), batcher_queue_options,
172             std::move(allowed_batch_sizes)),
173         fhandle_(fhandle),
174         flib_(flib) {}
175 
ProcessFuncBatchImpl(const BatchTask & last_task,absl::Span<const Tensor> inputs,std::vector<Tensor> * combined_outputs,std::function<void (const Status &)> done) const176   void ProcessFuncBatchImpl(
177       const BatchTask& last_task, absl::Span<const Tensor> inputs,
178       std::vector<Tensor>* combined_outputs,
179       std::function<void(const Status&)> done) const override {
180     auto* last_task_context = last_task.context;
181     FunctionLibraryRuntime::Options opts;
182     opts.step_container = last_task_context->step_container();
183     opts.cancellation_manager = last_task_context->cancellation_manager();
184     opts.collective_executor = last_task_context->collective_executor();
185     opts.stats_collector = last_task_context->stats_collector();
186     opts.runner = last_task_context->runner();
187     opts.run_all_kernels_inline = last_task_context->run_all_kernels_inline();
188     // We do not set 'opts.rendezvous', since if the function is run multiple
189     // times in parallel with the same rendezvous, a _Send node from one run
190     // might be matched with a _Recv node of a different run. Not setting the
191     // rendezvous causes a new rendezvous to be used for each run.
192     Notification done_notif;
193 
194     flib_->Run(opts, fhandle_, inputs, combined_outputs,
195                [&](const Status& run_status) {
196                  done(run_status);
197                  done_notif.Notify();
198                });
199     // By waiting for the notification we are ensuring that this thread isn't
200     // used for processing other batches, which gives the batches time to
201     // coalesce upstream. So overall the number of batches going through the
202     // devices goes down, improving latency and throughput in most cases.
203     done_notif.WaitForNotification();
204   }
205 
206   FunctionLibraryRuntime::Handle fhandle_;
207   FunctionLibraryRuntime* flib_;
208 };
209 
210 class BatchFunctionKernel : public AsyncOpKernel {
211  public:
BatchFunctionKernel(OpKernelConstruction * c)212   explicit BatchFunctionKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
213     OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
214     OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
215     OP_REQUIRES_OK(c, c->GetAttr("batching_queue", &batcher_queue_));
216     OP_REQUIRES_OK(c, c->GetAttr("num_batch_threads", &num_batch_threads_));
217     OP_REQUIRES_OK(c, c->GetAttr("max_batch_size", &max_batch_size_));
218     OP_REQUIRES_OK(c,
219                    c->GetAttr("batch_timeout_micros", &batch_timeout_micros_));
220     OP_REQUIRES_OK(c,
221                    c->GetAttr("max_enqueued_batches", &max_enqueued_batches_));
222     OP_REQUIRES_OK(c, c->GetAttr("allowed_batch_sizes", &allowed_batch_sizes_));
223 
224     OP_REQUIRES_OK(c, c->GetAttr("f", &func_));
225     flib_ = c->function_library();
226 
227     if (c->HasAttr("enable_large_batch_splitting")) {
228       OP_REQUIRES_OK(c, c->GetAttr("enable_large_batch_splitting",
229                                    &enable_large_batch_splitting_));
230       has_attribute_enable_large_batch_splitting_ = true;
231     } else {
232       enable_large_batch_splitting_ = false;
233       has_attribute_enable_large_batch_splitting_ = false;
234     }
235 
236     // Helper function `SetAdaptiveBatchSchedulerOptions` calls
237     // `OP_REQUIRES_OK`, which exits the current function upon error.
238     // So validate status of `op-kernel-construction`.
239     SetAdaptiveBatchSchedulerOptions(c, num_batch_threads_);
240     if (!c->status().ok()) {
241       return;
242     }
243 
244     if (enable_adaptive_batch_threads_) {
245       // One scheduler instance contains a couple of queue instances,
246       // `batcher_queue_` is the key to find queue for this batch-op in the
247       // graph.
248       // Use `shared_name_` and name() as prefix for `batcher_queue_`.
249       // Note name() is unique per session (from session metadata).
250       batcher_queue_ = name() + "/" + shared_name_ + batcher_queue_;
251     }
252 
253     if (shared_name_.empty()) {
254       // If shared_name is not supplied, use name instead (prevent collisions by
255       // default).
256       shared_name_ = name();
257     }
258 
259     OP_REQUIRES_OK(c, ValidateAllowedBatchSizes());
260   }
261 
IsExpensive()262   bool IsExpensive() override { return false; }
263 
ComputeAsync(OpKernelContext * c,DoneCallback done)264   void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
265     RecordBatchSplitUsage(
266         has_attribute_enable_large_batch_splitting_
267             ? absl::make_optional(enable_large_batch_splitting_)
268             : absl::nullopt,
269         GetModelName(c));
270     // TODO(b/173255290): Add num_batch_threads_ parameter to TFRT batch kernel.
271     RecordBatchParamNumBatchThreads(num_batch_threads_, GetModelName(c));
272 
273     std::function<Status(BatchResource**)> creator;
274 
275     FunctionLibraryRuntime::Handle handle;
276     OP_REQUIRES_OK_ASYNC(c, GetOrCreateFunctionHandle(c, &handle), done);
277 
278     if (adaptive_batch_scheduler_options_ != absl::nullopt) {
279       creator = [this, handle](BatchResource** r) {
280         serving::AdaptiveSharedBatchScheduler<
281             serving::BatchResourceBase::BatchTask>::Options
282             adaptive_shared_batch_scheduler_options;
283         adaptive_shared_batch_scheduler_options.thread_pool_name =
284             "adaptive_batch_threads";
285         adaptive_shared_batch_scheduler_options.num_batch_threads =
286             adaptive_batch_scheduler_options_->max_in_flight_batches_limit;
287         adaptive_shared_batch_scheduler_options.thread_pool =
288             GetOrCreateBatchThreadsPool(std::string("adaptive_batch_threads"),
289                                         kBatchThreadPoolSize);
290         // adaptive_shared_batch_scheduler_options.full_batch_scheduling_boost_micros
291         // is 0 (default value) intentionally, so tasks are scheduled in a FIFO
292         // way.
293         // Two rationales to use default value (zero) for
294         // `full_batch_scheduling_boost_micros`
295         // 1) In this way, tasks scheduling policy is FIFO. Compared with round
296         // robin (what shared batch scheduler does), FIFO ensures that model
297         // with low QPS (i.e., models enqueue fewer tasks in the shared queue)
298         // will be processed timely.
299         // 2) If set, `full_batch_scheduling_boost_micros` should be of order
300         // the batch processing latency (which varies on a model basis).
301         // If a non-zero value is not set properly, it harms tail latency.
302         adaptive_shared_batch_scheduler_options.min_in_flight_batches_limit =
303             adaptive_batch_scheduler_options_->min_in_flight_batches_limit;
304         adaptive_shared_batch_scheduler_options
305             .initial_in_flight_batches_limit =
306             adaptive_batch_scheduler_options_->initial_in_flight_batches_limit;
307         adaptive_shared_batch_scheduler_options.batches_to_average_over =
308             adaptive_batch_scheduler_options_->batches_to_average_over;
309         adaptive_shared_batch_scheduler_options.fifo_scheduling = true;
310         std::unique_ptr<BatchResource> new_resource;
311         TF_RETURN_IF_ERROR(BatchResource::Create(
312             adaptive_shared_batch_scheduler_options, max_batch_size_,
313             batch_timeout_micros_, max_enqueued_batches_, allowed_batch_sizes_,
314             handle, flib_, &new_resource));
315         *r = new_resource.release();
316         return Status::OK();
317       };
318     } else {
319       creator = [this, handle](BatchResource** r) {
320         std::unique_ptr<BatchResource> new_resource;
321         TF_RETURN_IF_ERROR(BatchResource::Create(
322             num_batch_threads_, max_batch_size_, batch_timeout_micros_,
323             max_enqueued_batches_, allowed_batch_sizes_, handle, flib_,
324             enable_large_batch_splitting_, &new_resource));
325         *r = new_resource.release();
326         return Status::OK();
327       };
328     }
329 
330     BatchResource* br;
331     OP_REQUIRES_OK_ASYNC(c,
332                          c->resource_manager()->LookupOrCreate(
333                              container_, shared_name_, &br, creator),
334                          done);
335     const Status status =
336         br->RegisterInput(random::New64(), c, batcher_queue_, done);
337     br->Unref();
338     OP_REQUIRES_OK_ASYNC(c, status, done);
339     // Assume br calls done, so nothing to do here.
340   }
341 
InstantiateFunction(OpKernelContext * c,FunctionLibraryRuntime::Handle * handle) const342   Status InstantiateFunction(OpKernelContext* c,
343                              FunctionLibraryRuntime::Handle* handle) const {
344     // TODO(b/173748062): Merge this instantiation logic with PartitionedCall.
345     if (!flib_) {
346       return errors::Internal("No function library");
347     }
348 
349     FunctionLibraryRuntime::InstantiateOptions opts;
350     opts.target = flib_->device() == nullptr ? "" : flib_->device()->name();
351     opts.is_multi_device_function = true;
352     const ConfigProto* config = flib_->config_proto();
353     if (config) {
354       opts.config_proto = *config;
355     }
356 
357     Device* cpu_device;
358     TF_RETURN_IF_ERROR(flib_->device_mgr()->LookupDevice("CPU:0", &cpu_device));
359 
360     const FunctionDef* fdef =
361         flib_->GetFunctionLibraryDefinition()->Find(func_.name());
362     if (!fdef) {
363       return errors::NotFound("Failed to find definition for function \"",
364                               func_.name(), "\"");
365     }
366     OpInputList in_tensors;
367     TF_RETURN_IF_ERROR(c->input_list("in_tensors", &in_tensors));
368     for (int i = 0; i < in_tensors.size(); i++) {
369       if (in_tensors[i].dtype() == DT_RESOURCE) {
370         return errors::InvalidArgument(
371             "BatchFunction cannot take resource inputs but input ", i,
372             " is a resource.");
373       } else {
374         // Currently, inputs are on CPU since they are concatenated on CPU
375         opts.input_devices.push_back(cpu_device->name());
376       }
377     }
378     OpInputList captured_tensors;
379     TF_RETURN_IF_ERROR(c->input_list("captured_tensors", &captured_tensors));
380     for (const Tensor& t : captured_tensors) {
381       if (t.dtype() == DT_RESOURCE) {
382         const ResourceHandle& rhandle = t.flat<ResourceHandle>()(0);
383         opts.input_devices.push_back(rhandle.device());
384       } else {
385         opts.input_devices.push_back(cpu_device->name());
386       }
387     }
388     const OpDef& signature = fdef->signature();
389     for (int i = 0; i < signature.output_arg_size(); i++) {
390       // Currently, outputs must be on CPU since they are split on CPU.
391       opts.output_devices.push_back(cpu_device->name());
392     }
393     if (opts.input_devices.size() != signature.input_arg_size()) {
394       return errors::InvalidArgument(
395           "Function takes ", signature.input_arg_size(), " argument(s) but ",
396           opts.input_devices.size(), " argument(s) were passed");
397     }
398     return flib_->Instantiate(func_.name(), AttrSlice(&func_.attr()), opts,
399                               handle);
400   }
401 
GetOrCreateFunctionHandle(OpKernelContext * c,FunctionLibraryRuntime::Handle * handle)402   Status GetOrCreateFunctionHandle(OpKernelContext* c,
403                                    FunctionLibraryRuntime::Handle* handle) {
404     mutex_lock ml(mu_);
405     if (!fhandle_) {
406       TF_RETURN_IF_ERROR(InstantiateFunction(c, handle));
407       fhandle_ = *handle;
408     } else {
409       *handle = fhandle_.value();
410     }
411     return Status::OK();
412   }
413 
414   // Validates 'allowed_batch_sizes_'. The entries must increase monotonically.
415   // If large batch split is not enabled, the last one must equal
416   // `max_batch_size_`. otherwise the last element must be smaller than or equal
417   // to `max_batch_size_`.
ValidateAllowedBatchSizes() const418   Status ValidateAllowedBatchSizes() const {
419     if (allowed_batch_sizes_.empty()) {
420       return Status::OK();
421     }
422     int32_t last_size = 0;
423     for (size_t i = 0; i < allowed_batch_sizes_.size(); ++i) {
424       const int32_t size = allowed_batch_sizes_.at(i);
425       if (i > 0 && size <= last_size) {
426         return errors::InvalidArgument(
427             "allowed_batch_sizes entries must be monotonically increasing");
428       }
429 
430       if ((!enable_large_batch_splitting_) &&
431           (i == allowed_batch_sizes_.size() - 1) && (size != max_batch_size_)) {
432         return errors::InvalidArgument(
433             "final entry in allowed_batch_sizes must equal max_batch_size when "
434             "enable_large_batch_splitting is False");
435       }
436 
437       last_size = size;
438     }
439     return Status::OK();
440   }
441 
442  private:
443   // Initialize vars by reading from op-kernel-construction.
444   // Vars
445   // - enable_adaptive_batch_threads_
446   //   true if value of attribute `kEnableAdaptiveSchedulerAttr` is true, or
447   //   if `num_batch_threads` is not positive.
448   // - adaptive_batch_scheduler_options_
449   //   Read from corresponding attributes as long as they are set.
SetAdaptiveBatchSchedulerOptions(OpKernelConstruction * c,int32_t num_batch_threads)450   void SetAdaptiveBatchSchedulerOptions(OpKernelConstruction* c,
451                                         int32_t num_batch_threads) {
452     if (c->HasAttr(kEnableAdaptiveSchedulerAttr)) {
453       OP_REQUIRES_OK(c, c->GetAttr(kEnableAdaptiveSchedulerAttr,
454                                    &enable_adaptive_batch_threads_));
455     }
456 
457     if (num_batch_threads <= 0) {
458       enable_adaptive_batch_threads_ = true;
459     }
460 
461     if (!enable_adaptive_batch_threads_) {
462       // adaptive_batch_scheduler_options_ is nullopt.
463       return;
464     }
465 
466     // adaptive_batch_scheduler_options_ is not nullopt
467     AdaptiveBatchSchedulerOptions options;
468 
469     if (c->HasAttr(kBatchesToAverageOverAttr)) {
470       OP_REQUIRES_OK(c, c->GetAttr(kBatchesToAverageOverAttr,
471                                    &options.batches_to_average_over));
472     }
473 
474     if (c->HasAttr(kMinInflightBatchesAttr)) {
475       OP_REQUIRES_OK(c, c->GetAttr(kMinInflightBatchesAttr,
476                                    &options.min_in_flight_batches_limit));
477     }
478 
479     if (c->HasAttr(kInitialInflightBatchesAttr)) {
480       OP_REQUIRES_OK(c, c->GetAttr(kInitialInflightBatchesAttr,
481                                    &options.initial_in_flight_batches_limit));
482     }
483 
484     if (c->HasAttr(kMaxInflightBatchesAttr)) {
485       OP_REQUIRES_OK(c, c->GetAttr(kMaxInflightBatchesAttr,
486                                    &options.max_in_flight_batches_limit));
487     }
488 
489     adaptive_batch_scheduler_options_ = options;
490   }
491 
492   string container_;
493   string shared_name_;
494   string batcher_queue_;
495   int32 num_batch_threads_;
496   int32 max_batch_size_;
497   int32 batch_timeout_micros_;
498   int32 max_enqueued_batches_;
499   std::vector<int32> allowed_batch_sizes_;
500   NameAttrList func_;
501   absl::optional<FunctionLibraryRuntime::Handle> fhandle_ TF_GUARDED_BY(mu_);
502   FunctionLibraryRuntime* flib_;
503   bool enable_large_batch_splitting_;
504   bool has_attribute_enable_large_batch_splitting_;
505   bool enable_adaptive_batch_threads_ = false;
506   mutex mu_;
507 
508   // Parameters for adaptive batch scheduler only.
509   // Note 'num_batch_threads_' above is shared by two implementations of batch
510   // scheduler.
511   struct AdaptiveBatchSchedulerOptions {
512     int32 min_in_flight_batches_limit = kMinInflightBatches;
513     int32 initial_in_flight_batches_limit = kInitialInflightBatches;
514     int32 max_in_flight_batches_limit = kMaxInflightBatches;
515     int32 batches_to_average_over = kBatchesToAverageOver;
516   };
517   absl::optional<AdaptiveBatchSchedulerOptions>
518       adaptive_batch_scheduler_options_ = absl::nullopt;
519 };
520 
521 REGISTER_KERNEL_BUILDER(Name("BatchFunction").Device(DEVICE_CPU),
522                         BatchFunctionKernel);
523 // Currently all inputs and outputs are on the host.
524 // TODO(b/173748277): Accept inputs/outputs on the device.
525 REGISTER_KERNEL_BUILDER(Name("BatchFunction")
526                             .Device(DEVICE_GPU)
527                             .HostMemory("in_tensors")
528                             .HostMemory("captured_tensors")
529                             .HostMemory("out_tensors"),
530                         BatchFunctionKernel);
531 
532 class BatchKernel : public AsyncOpKernel {
533  public:
BatchKernel(OpKernelConstruction * c)534   explicit BatchKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
535     OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
536     OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
537     // If shared_name is not supplied, use name instead (prevent collisions by
538     // default).
539     if (shared_name_.empty()) {
540       shared_name_ = name();
541     }
542     OP_REQUIRES_OK(c, c->GetAttr("batching_queue", &batcher_queue_));
543     OP_REQUIRES_OK(c, c->GetAttr("num_batch_threads", &num_batch_threads_));
544     OP_REQUIRES_OK(c, c->GetAttr("max_batch_size", &max_batch_size_));
545     OP_REQUIRES_OK(c,
546                    c->GetAttr("batch_timeout_micros", &batch_timeout_micros_));
547     OP_REQUIRES_OK(c,
548                    c->GetAttr("max_enqueued_batches", &max_enqueued_batches_));
549     OP_REQUIRES_OK(c, c->GetAttr("allowed_batch_sizes", &allowed_batch_sizes_));
550     OP_REQUIRES_OK(c, ValidateAllowedBatchSizes());
551   }
552 
ComputeAsync(OpKernelContext * c,DoneCallback done)553   void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
554     BatchResource* br;
555     std::function<Status(BatchResource**)> creator = [this](BatchResource** r) {
556       std::unique_ptr<BatchResource> new_resource;
557       TF_RETURN_IF_ERROR(BatchResource::Create(
558           num_batch_threads_, max_batch_size_, batch_timeout_micros_,
559           max_enqueued_batches_, allowed_batch_sizes_, kInvalidHandle,
560           /*flib=*/nullptr, false, &new_resource));
561       *r = new_resource.release();
562       return Status::OK();
563     };
564     OP_REQUIRES_OK_ASYNC(c,
565                          c->resource_manager()->LookupOrCreate(
566                              container_, shared_name_, &br, creator),
567                          done);
568     const Status status =
569         br->RegisterInput(random::New64(), c, batcher_queue_, done);
570     br->Unref();
571     OP_REQUIRES_OK_ASYNC(c, status, done);
572     // Assume br calls done, so nothing to do here.
573   }
574 
575   // Validates 'allowed_batch_sizes_'. The entries must increase
576   // monotonically, and the last one must equal 'max_batch_size_'.
ValidateAllowedBatchSizes() const577   Status ValidateAllowedBatchSizes() const {
578     if (allowed_batch_sizes_.empty()) {
579       return Status::OK();
580     }
581     int32_t last_size = 0;
582     for (size_t i = 0; i < allowed_batch_sizes_.size(); ++i) {
583       const int32_t size = allowed_batch_sizes_.at(i);
584       if (i > 0 && size <= last_size) {
585         return errors::InvalidArgument(
586             "allowed_batch_sizes entries must be monotonically increasing");
587       }
588       if (i == allowed_batch_sizes_.size() - 1 && size != max_batch_size_) {
589         return errors::InvalidArgument(
590             "final entry in allowed_batch_sizes must equal max_batch_size");
591       }
592       last_size = size;
593     }
594     return Status::OK();
595   }
596 
597  private:
598   string container_;
599   string shared_name_;
600   string batcher_queue_;
601   int32 num_batch_threads_;
602   int32 max_batch_size_;
603   int32 batch_timeout_micros_;
604   int32 max_enqueued_batches_;
605   std::vector<int32> allowed_batch_sizes_;
606 };
607 
608 REGISTER_KERNEL_BUILDER(Name("Batch").Device(DEVICE_CPU), BatchKernel);
609 
610 // A class encapsulating the state and logic for unbatching tensors.
611 //
612 // UnbatchResource keeps two data structures indexed by batch-key: one which has
613 // the continuations for all concurrent kernels which are waiting for tensors
614 // and another which has tensors which are waiting for their corresponding
615 // kernels to run. Whenever a kernel runs, we either grab its tensor if it's
616 // waiting already, or we insert it in the queue and then look at its tensor to
617 // see if it can be used to dispatch any stored continuations.
618 class UnbatchResource : public ResourceBase {
619  public:
UnbatchResource(int32_t timeout_micros)620   explicit UnbatchResource(int32_t timeout_micros)
621       : timeout_micros_(timeout_micros),
622         timeout_enforcer_(new serving::PeriodicFunction(
623             [this] { EnforceTimeout(); }, 1000 /* 1 ms */)) {}
624 
~UnbatchResource()625   ~UnbatchResource() override {
626     // Tear down 'timeout_enforcer_' first, since it accesses other state in
627     // this class.
628     timeout_enforcer_ = nullptr;
629   }
630 
DebugString() const631   string DebugString() const final { return "UnbatchResource"; }
632 
Compute(OpKernelContext * context,AsyncOpKernel::DoneCallback done)633   Status Compute(OpKernelContext* context, AsyncOpKernel::DoneCallback done) {
634     const Tensor& data_t = context->input(0);
635     const Tensor& batch_index_t = context->input(1);
636 
637     if (batch_index_t.shape().dim_size(0) > data_t.shape().dim_size(0)) {
638       return errors::InvalidArgument(
639           "Wrong shape for index tensor. Expected 0th dimension size to be no "
640           "greater than ",
641           data_t.shape().dim_size(0),
642           "; Got: ", batch_index_t.shape().dim_size(0), ".");
643     }
644     if (batch_index_t.shape().dim_size(1) != 3) {
645       return errors::InvalidArgument(
646           "Wrong shape for index tensor. Expected 1st dimension size to be 3 ; "
647           "Got: ",
648           batch_index_t.shape().dim_size(1), ".");
649     }
650 
651     const int64_t batch_key = context->input(2).scalar<int64>()();
652     const bool nonempty_input = batch_index_t.dim_size(0) > 0;
653 
654     // If we have a non-empty tensor, slice it up.
655     // (It is important to do this outside of the critical section below.)
656     // The following variables are populated iff 'nonempty_input==true'.
657     std::vector<int64> sizes;
658     std::vector<int64> batch_keys;
659     std::vector<Tensor> split_inputs;
660     if (nonempty_input) {
661       auto batch_indices =
662           batch_index_t.shaped<int64, 2>({batch_index_t.dim_size(0), 3});
663       for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
664         sizes.push_back(batch_indices(i, 2) - batch_indices(i, 1));
665         batch_keys.push_back(batch_indices(i, 0));
666       }
667 
668       TF_RETURN_IF_ERROR(Split(context, data_t, sizes, &split_inputs));
669     }
670 
671     // Critical section.
672     std::vector<AsyncOpKernel::DoneCallback> done_callbacks_to_call;
673     Status status = [&]() -> Status {
674       mutex_lock ml(mu_);
675 
676       // Check to see whether the tensor we want is already ready.
677       auto tensor_it = waiting_tensors_.find(batch_key);
678       if (tensor_it != waiting_tensors_.end()) {
679         context->set_output(0, tensor_it->second.tensor);
680         waiting_tensors_.erase(tensor_it);
681         done_callbacks_to_call.push_back(done);
682         return Status::OK();
683       }
684 
685       const uint64 deadline_micros =
686           Env::Default()->NowMicros() + timeout_micros_;
687 
688       // Add ourselves to the waitlist for tensors.
689       if (!waiting_callbacks_
690                .emplace(batch_key,
691                         WaitingCallback{deadline_micros, context, done})
692                .second) {
693         return errors::AlreadyExists(
694             "Multiple session runs with the same batch key.");
695       }
696 
697       // If we have a non-empty tensor, finish the waitlisted runs,
698       // and store any remaining pieces.
699       if (nonempty_input) {
700         for (size_t i = 0; i < batch_keys.size(); ++i) {
701           auto runs_it = waiting_callbacks_.find(batch_keys[i]);
702           if (runs_it != waiting_callbacks_.end()) {
703             runs_it->second.context->set_output(0, split_inputs[i]);
704             done_callbacks_to_call.push_back(runs_it->second.done);
705             waiting_callbacks_.erase(runs_it);
706           } else {
707             // Note: the deadline here is in case we are arriving late and the
708             // kernel that should rendezvous with this tensor has already waited
709             // and timed out.
710             if (!waiting_tensors_
711                      .emplace(batch_keys[i],
712                               WaitingTensor{deadline_micros, split_inputs[i]})
713                      .second) {
714               return errors::AlreadyExists(
715                   "Multiple tensors returned for same batch key.");
716             }
717           }
718         }
719       }
720 
721       return Status::OK();
722     }();
723 
724     for (const AsyncOpKernel::DoneCallback& done_callback :
725          done_callbacks_to_call) {
726       done_callback();
727     }
728 
729     return status;
730   }
731 
732  private:
733   // Evicts waiting tensors and callbacks that have exceeded their deadline.
EnforceTimeout()734   void EnforceTimeout() {
735     const uint64 now = Env::Default()->NowMicros();
736     std::vector<WaitingCallback> evicted_callbacks;
737 
738     {
739       mutex_lock ml(mu_);
740 
741       for (auto it = waiting_tensors_.begin(); it != waiting_tensors_.end();) {
742         const WaitingTensor& waiting_tensor = it->second;
743         if (waiting_tensor.deadline_micros < now) {
744           it = waiting_tensors_.erase(it);
745         } else {
746           ++it;
747         }
748       }
749 
750       for (auto it = waiting_callbacks_.begin();
751            it != waiting_callbacks_.end();) {
752         const WaitingCallback& waiting_callback = it->second;
753         if (waiting_callback.deadline_micros < now) {
754           evicted_callbacks.push_back(waiting_callback);
755           it = waiting_callbacks_.erase(it);
756         } else {
757           ++it;
758         }
759       }
760     }
761 
762     for (const WaitingCallback& evicted_callback : evicted_callbacks) {
763       evicted_callback.context->CtxFailureWithWarning(errors::DeadlineExceeded(
764           "Batched data did not arrive within timeout window."));
765       evicted_callback.done();
766     }
767   }
768 
769   struct WaitingTensor {
770     uint64 deadline_micros;
771     Tensor tensor;
772   };
773 
774   struct WaitingCallback {
775     uint64 deadline_micros;
776     OpKernelContext* context;
777     AsyncOpKernel::DoneCallback done;
778   };
779 
780   const int32 timeout_micros_;
781 
782   mutex mu_;
783 
784   // Maps keyed by BatchKey of tensors waiting for callbacks and callbacks
785   // waiting for tensors.
786   std::unordered_map<int64, WaitingTensor> waiting_tensors_ TF_GUARDED_BY(mu_);
787   std::unordered_map<int64, WaitingCallback> waiting_callbacks_
788       TF_GUARDED_BY(mu_);
789 
790   // A thread that evicts waiting tensors and callbacks that have exceeded their
791   // deadline.
792   std::unique_ptr<serving::PeriodicFunction> timeout_enforcer_;
793 };
794 
795 class UnbatchKernel : public AsyncOpKernel {
796  public:
UnbatchKernel(OpKernelConstruction * c)797   explicit UnbatchKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
798     OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
799     OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
800     // If shared_name is not supplied, use name instead (prevent collisions by
801     // default).
802     if (shared_name_.empty()) {
803       shared_name_ = name();
804     }
805     OP_REQUIRES_OK(c, c->GetAttr("timeout_micros", &timeout_micros_));
806   }
807 
ComputeAsync(OpKernelContext * c,DoneCallback done)808   void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
809     UnbatchResource* ubr;
810     std::function<Status(UnbatchResource**)> creator =
811         [this](UnbatchResource** r) {
812           *r = new UnbatchResource(timeout_micros_);
813           return Status::OK();
814         };
815     OP_REQUIRES_OK_ASYNC(c,
816                          c->resource_manager()->LookupOrCreate(
817                              container_, shared_name_, &ubr, creator),
818                          done);
819     auto status = ubr->Compute(c, done);
820     ubr->Unref();
821     OP_REQUIRES_OK_ASYNC(c, status, done);
822     // Assume ubr calls done, so nothing to do here.
823   }
824 
825  private:
826   string container_;
827   string shared_name_;
828   int32 timeout_micros_;
829 };
830 REGISTER_KERNEL_BUILDER(Name("Unbatch").Device(DEVICE_CPU), UnbatchKernel);
831 
832 // A class encapsulating the state and logic for batching tensors
833 // deterministically for the gradient of unbatch.
834 class UnbatchGradResource : public ResourceBase {
835  public:
UnbatchGradResource()836   UnbatchGradResource() {}
837 
DebugString() const838   string DebugString() const final { return "UnbatchGradResource"; }
839 
840   // Flushes the information for one batch, given its context and done
841   // callback. Clears all information about it from the available_tensors_.
OutputBatch(OpKernelContext * context,const AsyncOpKernel::DoneCallback & done)842   Status OutputBatch(OpKernelContext* context,
843                      const AsyncOpKernel::DoneCallback& done)
844       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
845     const Tensor& batch_index_t = context->input(1);
846     auto batch_index =
847         batch_index_t.shaped<int64, 2>({batch_index_t.dim_size(0), 3});
848     std::vector<Tensor> tensors;
849     for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
850       auto available_it = available_tensors_.find(batch_index(i, 0));
851       if (available_it == available_tensors_.end()) {
852         return errors::Internal("bad bookkeeping of available tensors.");
853       }
854       tensors.push_back(available_it->second);
855       available_tensors_.erase(available_it);
856     }
857 
858     const DataType type = tensors[0].dtype();
859     Tensor concatenated_tensor;
860     switch (type) {
861 #define CASE(type)                                                            \
862   case DataTypeToEnum<type>::value:                                           \
863     TF_RETURN_IF_ERROR(Concat<type>(context, tensors, &concatenated_tensor)); \
864     context->set_output(0, concatenated_tensor);                              \
865     break;
866       TF_CALL_ALL_TYPES(CASE);
867 #undef CASE
868       default:
869         return errors::InvalidArgument("Unsupported data type: ", type);
870     }
871     done();
872     return Status::OK();
873   }
874 
875   // Ingests data from one invocation of the op.
Compute(OpKernelContext * context,const AsyncOpKernel::DoneCallback & done)876   Status Compute(OpKernelContext* context,
877                  const AsyncOpKernel::DoneCallback& done) {
878     const Tensor& data_t = context->input(0);
879     const Tensor& batch_index_t = context->input(1);
880     const Tensor& grad_t = context->input(2);
881 
882     mutex_lock ml(mu_);
883 
884     const int64_t batch_key = context->input(3).scalar<int64>()();
885     // Mark our tensor as available.
886     if (!available_tensors_.emplace(batch_key, grad_t).second) {
887       return errors::InvalidArgument("Two runs with the same batch key.");
888     }
889 
890     // Check whether we have a valid input tensor and, if so, create its
891     // dispatch logic.
892     if (data_t.NumElements() > 0) {
893       if (batch_index_t.NumElements() == 0) {
894         return errors::InvalidArgument(
895             "batch_index is empty while the tensor isn't.");
896       }
897       std::unordered_set<int64> missing_tensors;
898       const auto batch_index =
899           batch_index_t.shaped<int64, 2>({batch_index_t.dim_size(0), 3});
900       for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
901         const int64_t batch_key = batch_index(i, 0);
902         if (available_tensors_.find(batch_key) == available_tensors_.end()) {
903           missing_tensors.emplace(batch_key);
904         }
905       }
906       if (missing_tensors.empty()) {
907         return OutputBatch(context, done);
908       }
909       if (!available_batches_
910                .emplace(batch_key, Batch{missing_tensors, context, done})
911                .second) {
912         return errors::InvalidArgument(
913             "Batch key with valid batch used twice.");
914       }
915       for (const int64_t i : missing_tensors) {
916         if (!desired_tensor_to_batch_map_.emplace(i, batch_key).second) {
917           return errors::InvalidArgument(
918               "Missing tensor wanted by more than one batch.");
919         }
920       }
921     } else {
922       // If we don't have a valid input tensor we can output an empty tensor and
923       // call our done closure.
924       TensorShape output_shape(grad_t.shape());
925       output_shape.set_dim(0, 0);
926       Tensor* output = nullptr;
927       TF_RETURN_IF_ERROR(context->allocate_output(0, output_shape, &output));
928       done();
929     }
930 
931     // Search to see whether our tensor is desired by any existing batch.
932     auto desire_it = desired_tensor_to_batch_map_.find(batch_key);
933     if (desire_it != desired_tensor_to_batch_map_.end()) {
934       // Mark our tensor as no longer missing.
935       auto batch_it = available_batches_.find(desire_it->second);
936       desired_tensor_to_batch_map_.erase(desire_it);
937       if (batch_it == available_batches_.end()) {
938         return errors::InvalidArgument("Batch no longer exists.");
939       }
940       batch_it->second.missing_tensors.erase(batch_key);
941       // If all tensors are available we should concatenate them and dispatch
942       // the batch.
943       if (batch_it->second.missing_tensors.empty()) {
944         TF_RETURN_IF_ERROR(
945             OutputBatch(batch_it->second.context, batch_it->second.done));
946         available_batches_.erase(batch_it);
947       }
948     }
949     return Status::OK();
950   }
951 
952  private:
953   mutex mu_;
954 
955   // Represents a still-incomplete batch of tensors. When all tensors become
956   // available they will be concatenated in the right order and sent through the
957   // context.
958   struct Batch {
959     // Batch keys for tensors which are still missing from this batch. When this
960     // is empty the Tensors can be concatenated and forwarded.
961     std::unordered_set<int64> missing_tensors;
962 
963     // Context and callback for the session responsible for finishing this
964     // batch.
965     OpKernelContext* context;
966     AsyncOpKernel::DoneCallback done;
967   };
968 
969   // Map from batch key of the session which will output the batched gradients
970   // to still-incomplete batches.
971   std::unordered_map<int64, Batch> available_batches_;
972 
973   // Map from batch key to tensors which are waiting for their batches to be
974   // available.
975   std::unordered_map<int64, Tensor> available_tensors_;
976 
977   // Map from batch key of a tensor which is not yet available to the batch key
978   // of the batch to which it belongs.
979   std::unordered_map<int64, int64> desired_tensor_to_batch_map_;
980 };
981 
982 class UnbatchGradKernel : public AsyncOpKernel {
983  public:
UnbatchGradKernel(OpKernelConstruction * c)984   explicit UnbatchGradKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
985     OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
986     OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
987     // If shared_name is not supplied, use name instead (prevent collisions by
988     // default).
989     if (shared_name_.empty()) {
990       shared_name_ = name();
991     }
992   }
993 
ComputeAsync(OpKernelContext * c,DoneCallback done)994   void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
995     UnbatchGradResource* ubr;
996     std::function<Status(UnbatchGradResource**)> creator =
997         [](UnbatchGradResource** r) {
998           *r = new UnbatchGradResource();
999           return Status::OK();
1000         };
1001     OP_REQUIRES_OK_ASYNC(c,
1002                          c->resource_manager()->LookupOrCreate(
1003                              container_, shared_name_, &ubr, creator),
1004                          done);
1005     Status status = ubr->Compute(c, done);
1006     ubr->Unref();
1007     OP_REQUIRES_OK_ASYNC(c, status, done);
1008     // Assume ubr calls done, so nothing to do here.
1009   }
1010 
1011  private:
1012   string container_;
1013   string shared_name_;
1014 };
1015 REGISTER_KERNEL_BUILDER(Name("UnbatchGrad").Device(DEVICE_CPU),
1016                         UnbatchGradKernel);
1017 
1018 }  // namespace tensorflow
1019