/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/strings/str_cat.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/batch_resource_base.h"
#include "tensorflow/core/kernels/batching_util/concat_split_util.h"
#include "tensorflow/core/kernels/batching_util/periodic_function.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/numbers.h"

namespace tensorflow {
namespace {
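// Defaults for the adaptive shared batch scheduler, used by
// BatchFunctionKernel when its `num_batch_threads` attribute is not positive.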
constexpr int64 kMinInflightBatchesLimit = 16;
constexpr double kInitialInflightBatchesLimit = 64;
constexpr int64 kBatchesToAverageOver = 10;
constexpr int64 kMaxInflightBatchesLimit = 128;
}  // namespace

auto* batch_op_split_usage = monitoring::Gauge<string, 1>::New(
    "/tensorflow/serving/batching/enable_large_batch_splitting",
    "Tracks the usage of attribute `enable_large_batch_splitting` for "
    "BatchFunction kernel in a saved model.",
    "model_name");

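// Records whether `enable_large_batch_splitting` was set to true, set to
// false, or left unset for the given model.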
void RecordBatchSplitUsage(
    absl::optional<bool> maybe_enable_large_batch_splitting,
    const string& model_name) {
  if (maybe_enable_large_batch_splitting.has_value()) {
    if (maybe_enable_large_batch_splitting.value()) {
      batch_op_split_usage->GetCell(model_name)->Set("true");
    } else {
      batch_op_split_usage->GetCell(model_name)->Set("false");
    }
  } else {
    batch_op_split_usage->GetCell(model_name)->Set("unset");
  }
}

void RecordBatchParamNumBatchThreads(int64 num_batch_threads,
                                     const string& model_name) {
  static auto* cell = monitoring::Gauge<int64, 1>::New(
      "/tensorflow/serving/batching/num_batch_threads",
      "Tracks the number of batch threads of a model.", "model_name");
  cell->GetCell(model_name)->Set(num_batch_threads);
}

const string& GetModelName(OpKernelContext* ctx) {
  static string* kModelNameUnset = new string("model_name_unset");
  if (!ctx->session_metadata()) return *kModelNameUnset;
  if (ctx->session_metadata()->name().empty()) return *kModelNameUnset;
  return ctx->session_metadata()->name();
}

using ::tensorflow::concat_split_util::Concat;
using ::tensorflow::concat_split_util::Split;

// A class encapsulating the state and logic for batching tensors.
class BatchResource : public serving::BatchResourceBase {
 public:
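  // Creates a BatchResource whose batcher uses a fixed pool of
  // `num_batch_threads` batch threads.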
  static Status Create(int32 num_batch_threads, int32 max_execution_batch_size,
                       int32 batch_timeout_micros, int32 max_enqueued_batches,
                       const std::vector<int32>& allowed_batch_sizes,
                       FunctionLibraryRuntime::Handle fhandle,
                       FunctionLibraryRuntime* flib,
                       bool enable_large_batch_splitting,
                       std::unique_ptr<BatchResource>* resource) {
    BatcherT::Options batcher_options;
    batcher_options.num_batch_threads = num_batch_threads;
    std::shared_ptr<BatcherT> batcher;
    TF_RETURN_IF_ERROR(BatcherT::Create(batcher_options, &batcher));

    resource->reset(new BatchResource(
        fhandle, flib, std::move(batcher),
        GetBatcherQueueOptions(num_batch_threads, max_execution_batch_size,
                               batch_timeout_micros, max_enqueued_batches,
                               allowed_batch_sizes,
                               enable_large_batch_splitting),
        allowed_batch_sizes));
    return Status::OK();
  }

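  // Creates a BatchResource backed by an adaptive shared batch scheduler,
  // which tunes the number of in-flight batches at runtime. Large batch
  // splitting is always enabled for this variant.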
  static Status Create(
      AdaptiveBatcherT::Options adaptive_shared_batch_scheduler_options,
      int32 max_batch_size, int32 batch_timeout_micros,
      int32 max_enqueued_batches, const std::vector<int32>& allowed_batch_sizes,
      FunctionLibraryRuntime::Handle fhandle, FunctionLibraryRuntime* flib,
      std::unique_ptr<BatchResource>* resource) {
    std::shared_ptr<AdaptiveBatcherT> batcher;
    TF_RETURN_IF_ERROR(AdaptiveBatcherT::Create(
        adaptive_shared_batch_scheduler_options, &batcher));

    resource->reset(new BatchResource(
        fhandle, flib, std::move(batcher),
        GetAdaptiveBatcherQueueOptions(
            max_batch_size, batch_timeout_micros, max_enqueued_batches,
            true /* enable large batch split */, allowed_batch_sizes),
        allowed_batch_sizes));
    return Status::OK();
  }

  string DebugString() const final { return "BatchResource"; }

 private:
  BatchResource(FunctionLibraryRuntime::Handle fhandle,
                FunctionLibraryRuntime* flib, std::shared_ptr<BatcherT> batcher,
                const BatcherT::QueueOptions& batcher_queue_options,
                std::vector<int32> allowed_batch_sizes)
      : BatchResourceBase(
            /*has_process_batch_function=*/fhandle != kInvalidHandle,
            std::move(batcher), batcher_queue_options,
            std::move(allowed_batch_sizes)),
        fhandle_(fhandle),
        flib_(flib) {}

  BatchResource(FunctionLibraryRuntime::Handle fhandle,
                FunctionLibraryRuntime* flib,
                std::shared_ptr<AdaptiveBatcherT> batcher,
                const AdaptiveBatcherT::QueueOptions& batcher_queue_options,
                std::vector<int32> allowed_batch_sizes)
      : BatchResourceBase(
            /*has_process_batch_function=*/fhandle != kInvalidHandle,
            std::move(batcher), batcher_queue_options,
            std::move(allowed_batch_sizes)),
        fhandle_(fhandle),
        flib_(flib) {}

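  // Runs the instantiated batch function (`fhandle_`) on the concatenated
  // inputs of one batch and reports the result through `done`.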
  void ProcessFuncBatchImpl(
      const BatchTask& last_task, absl::Span<const Tensor> inputs,
      std::vector<Tensor>* combined_outputs,
      std::function<void(const Status&)> done) const override {
    auto* last_task_context = last_task.context;
    FunctionLibraryRuntime::Options opts;
    opts.step_container = last_task_context->step_container();
    opts.cancellation_manager = last_task_context->cancellation_manager();
    opts.collective_executor = last_task_context->collective_executor();
    opts.stats_collector = last_task_context->stats_collector();
    opts.runner = last_task_context->runner();
    opts.run_all_kernels_inline = last_task_context->run_all_kernels_inline();
    // We do not set 'opts.rendezvous', since if the function is run multiple
    // times in parallel with the same rendezvous, a _Send node from one run
    // might be matched with a _Recv node of a different run. Not setting the
    // rendezvous causes a new rendezvous to be used for each run.
    Notification done_notif;

    flib_->Run(opts, fhandle_, inputs, combined_outputs,
               [&](const Status& run_status) {
                 done(run_status);
                 done_notif.Notify();
               });
    // By waiting for the notification we are ensuring that this thread isn't
    // used for processing other batches, which gives the batches time to
    // coalesce upstream. So overall the number of batches going through the
    // devices goes down, improving latency and throughput in most cases.
    done_notif.WaitForNotification();
  }

  FunctionLibraryRuntime::Handle fhandle_;
  FunctionLibraryRuntime* flib_;
};

class BatchFunctionKernel : public AsyncOpKernel {
 public:
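  // Reads the batching attributes. If `num_batch_threads` is not positive, the
  // kernel switches to the adaptive shared batch scheduler and rewrites
  // `container_`, `shared_name_`, and `batcher_queue_` so that the scheduler's
  // thread pool is shared across models.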
  explicit BatchFunctionKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
    OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
    OP_REQUIRES_OK(c, c->GetAttr("batching_queue", &batcher_queue_));
    OP_REQUIRES_OK(c, c->GetAttr("num_batch_threads", &num_batch_threads_));
    OP_REQUIRES_OK(c, c->GetAttr("max_batch_size", &max_batch_size_));
    OP_REQUIRES_OK(c,
                   c->GetAttr("batch_timeout_micros", &batch_timeout_micros_));
    OP_REQUIRES_OK(c,
                   c->GetAttr("max_enqueued_batches", &max_enqueued_batches_));
    OP_REQUIRES_OK(c, c->GetAttr("allowed_batch_sizes", &allowed_batch_sizes_));

    OP_REQUIRES_OK(c, c->GetAttr("f", &func_));
    flib_ = c->function_library();
    if (num_batch_threads_ <= 0) {
      adaptive_batch_scheduler_options_ =
          absl::make_optional(AdaptiveBatchSchedulerOptions{
              kMinInflightBatchesLimit, kInitialInflightBatchesLimit,
              kBatchesToAverageOver});

      // One scheduler instance contains several queue instances;
      // `batcher_queue_` is the key used to find the queue for this batch op
      // in the graph.
      // Use `shared_name_` and name() as the prefix for `batcher_queue_`.
      // Note that name() is unique per session (from session metadata).
      batcher_queue_ = name() + "/" + shared_name_ + batcher_queue_;

      // `shared_name_` and `container_` are used to look up an instantiated
      // scheduler instance in `ComputeAsync`.
      //
      // Rewrite `container_` and `shared_name_` to pre-defined constants so
      // that one thread pool is shared across all models when the adaptive
      // shared batch scheduler is used.
      container_ = "__adapative_container";
      shared_name_ = "__adaptive_global_shared_thread_pool";
    }

    if (shared_name_.empty()) {
      // If shared_name is not supplied, use name instead (prevent collisions by
      // default).
      shared_name_ = name();
    }

    if (c->HasAttr("enable_large_batch_splitting")) {
      OP_REQUIRES_OK(c, c->GetAttr("enable_large_batch_splitting",
                                   &enable_large_batch_splitting_));
      has_attribute_enable_large_batch_splitting_ = true;
    } else {
      enable_large_batch_splitting_ = false;
      has_attribute_enable_large_batch_splitting_ = false;
    }

    OP_REQUIRES_OK(c, ValidateAllowedBatchSizes());
  }

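  // This kernel only enqueues work onto the batching queue; the heavy lifting
  // happens on the batch threads, so report it as inexpensive to the executor.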
  bool IsExpensive() override { return false; }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
    RecordBatchSplitUsage(
        has_attribute_enable_large_batch_splitting_
            ? absl::make_optional(enable_large_batch_splitting_)
            : absl::nullopt,
        GetModelName(c));
    // TODO(b/173255290): Add num_batch_threads_ parameter to TFRT batch kernel.
    RecordBatchParamNumBatchThreads(num_batch_threads_, GetModelName(c));

    std::function<Status(BatchResource**)> creator;

    FunctionLibraryRuntime::Handle handle;
    OP_REQUIRES_OK_ASYNC(c, GetOrCreateFunctionHandle(c, &handle), done);

    if (adaptive_batch_scheduler_options_ != absl::nullopt) {
      creator = [this, handle](BatchResource** r) {
        serving::AdaptiveSharedBatchScheduler<
            serving::BatchResourceBase::BatchTask>::Options
            adaptive_shared_batch_scheduler_options;
        adaptive_shared_batch_scheduler_options.thread_pool_name =
            "adaptive_batch_threads";
        adaptive_shared_batch_scheduler_options.num_batch_threads =
            kMaxInflightBatchesLimit;
        // `full_batch_scheduling_boost_micros` is left at its default value
        // (zero) intentionally, so tasks are scheduled in FIFO order. Two
        // rationales for using the default:
        // 1) FIFO scheduling (as opposed to the round-robin policy of the
        // shared batch scheduler) ensures that models with low QPS (i.e.,
        // models that enqueue fewer tasks in the shared queue) are still
        // processed in a timely manner.
        // 2) If set, `full_batch_scheduling_boost_micros` should be on the
        // order of the batch processing latency, which varies per model; a
        // non-zero value that is not chosen properly harms tail latency.
        adaptive_shared_batch_scheduler_options.min_in_flight_batches_limit =
            adaptive_batch_scheduler_options_->min_in_flight_batches_limit;
        adaptive_shared_batch_scheduler_options
            .initial_in_flight_batches_limit =
            adaptive_batch_scheduler_options_->initial_in_flight_batches_limit;
        adaptive_shared_batch_scheduler_options.batches_to_average_over =
            adaptive_batch_scheduler_options_->batches_to_average_over;
        std::unique_ptr<BatchResource> new_resource;
        TF_RETURN_IF_ERROR(BatchResource::Create(
            adaptive_shared_batch_scheduler_options, max_batch_size_,
            batch_timeout_micros_, max_enqueued_batches_, allowed_batch_sizes_,
            handle, flib_, &new_resource));
        *r = new_resource.release();
        return Status::OK();
      };
    } else {
      creator = [this, handle](BatchResource** r) {
        std::unique_ptr<BatchResource> new_resource;
        TF_RETURN_IF_ERROR(BatchResource::Create(
            num_batch_threads_, max_batch_size_, batch_timeout_micros_,
            max_enqueued_batches_, allowed_batch_sizes_, handle, flib_,
            enable_large_batch_splitting_, &new_resource));
        *r = new_resource.release();
        return Status::OK();
      };
    }

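    // Look up (or lazily create) the BatchResource shared under
    // (`container_`, `shared_name_`) and hand this op's inputs to it; the
    // resource invokes `done` once the batched function has produced outputs.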
    BatchResource* br;
    OP_REQUIRES_OK_ASYNC(c,
                         c->resource_manager()->LookupOrCreate(
                             container_, shared_name_, &br, creator),
                         done);
    const Status status =
        br->RegisterInput(random::New64(), c, batcher_queue_, done);
    br->Unref();
    OP_REQUIRES_OK_ASYNC(c, status, done);
    // Assume br calls done, so nothing to do here.
  }

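  // Instantiates `func_` as a multi-device function. All inputs and outputs
  // are placed on the host, except captured resource tensors, which stay on
  // the device recorded in their resource handle.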
  Status InstantiateFunction(OpKernelContext* c,
                             FunctionLibraryRuntime::Handle* handle) const {
    // TODO(b/173748062): Merge this instantiation logic with PartitionedCall.
    if (!flib_) {
      return errors::Internal("No function library");
    }

    FunctionLibraryRuntime::InstantiateOptions opts;
    opts.target = flib_->device() == nullptr ? "" : flib_->device()->name();
    opts.is_multi_device_function = true;
    const ConfigProto* config = flib_->config_proto();
    if (config) {
      opts.config_proto = *config;
    }

    Device* cpu_device;
    TF_RETURN_IF_ERROR(flib_->device_mgr()->LookupDevice("CPU:0", &cpu_device));

    const FunctionDef* fdef =
        flib_->GetFunctionLibraryDefinition()->Find(func_.name());
    if (!fdef) {
      return errors::NotFound("Failed to find definition for function \"",
                              func_.name(), "\"");
    }
    OpInputList in_tensors;
    TF_RETURN_IF_ERROR(c->input_list("in_tensors", &in_tensors));
    for (int i = 0; i < in_tensors.size(); i++) {
      if (in_tensors[i].dtype() == DT_RESOURCE) {
        return errors::InvalidArgument(
            "BatchFunction cannot take resource inputs but input ", i,
            " is a resource.");
      } else {
        // Currently, inputs are on CPU since they are concatenated on CPU.
        opts.input_devices.push_back(cpu_device->name());
      }
    }
    OpInputList captured_tensors;
    TF_RETURN_IF_ERROR(c->input_list("captured_tensors", &captured_tensors));
    for (const Tensor& t : captured_tensors) {
      if (t.dtype() == DT_RESOURCE) {
        const ResourceHandle& rhandle = t.flat<ResourceHandle>()(0);
        opts.input_devices.push_back(rhandle.device());
      } else {
        opts.input_devices.push_back(cpu_device->name());
      }
    }
    const OpDef& signature = fdef->signature();
    for (int i = 0; i < signature.output_arg_size(); i++) {
      // Currently, outputs must be on CPU since they are split on CPU.
      opts.output_devices.push_back(cpu_device->name());
    }
    if (opts.input_devices.size() != signature.input_arg_size()) {
      return errors::InvalidArgument(
          "Function takes ", signature.input_arg_size(), " argument(s) but ",
          opts.input_devices.size(), " argument(s) were passed");
    }
    return flib_->Instantiate(func_.name(), AttrSlice(&func_.attr()), opts,
                              handle);
  }

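  // Returns the cached function handle, instantiating the function on first
  // use. Guarded by `mu_` so concurrent ComputeAsync calls instantiate it only
  // once.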
  Status GetOrCreateFunctionHandle(OpKernelContext* c,
                                   FunctionLibraryRuntime::Handle* handle) {
    mutex_lock ml(mu_);
    if (!fhandle_) {
      TF_RETURN_IF_ERROR(InstantiateFunction(c, handle));
      fhandle_ = *handle;
    } else {
      *handle = fhandle_.value();
    }
    return Status::OK();
  }

  // Validates 'allowed_batch_sizes_'. The entries must increase monotonically.
  // If large batch splitting is not enabled, the last entry must equal
  // `max_batch_size_`; otherwise the last entry must be smaller than or equal
  // to `max_batch_size_`.
  Status ValidateAllowedBatchSizes() const {
    if (allowed_batch_sizes_.empty()) {
      return Status::OK();
    }
    int32 last_size = 0;
    for (size_t i = 0; i < allowed_batch_sizes_.size(); ++i) {
      const int32 size = allowed_batch_sizes_.at(i);
      if (i > 0 && size <= last_size) {
        return errors::InvalidArgument(
            "allowed_batch_sizes entries must be monotonically increasing");
      }

      if ((!enable_large_batch_splitting_) &&
          (i == allowed_batch_sizes_.size() - 1) && (size != max_batch_size_)) {
        return errors::InvalidArgument(
            "final entry in allowed_batch_sizes must equal max_batch_size when "
            "enable_large_batch_splitting is False");
      }

      last_size = size;
    }
    return Status::OK();
  }

 private:
  string container_;
  string shared_name_;
  string batcher_queue_;
  int32 num_batch_threads_;
  int32 max_batch_size_;
  int32 batch_timeout_micros_;
  int32 max_enqueued_batches_;
  std::vector<int32> allowed_batch_sizes_;
  NameAttrList func_;
  absl::optional<FunctionLibraryRuntime::Handle> fhandle_ TF_GUARDED_BY(mu_);
  FunctionLibraryRuntime* flib_;
  bool enable_large_batch_splitting_;
  bool has_attribute_enable_large_batch_splitting_;
  mutex mu_;

  // Parameters for the adaptive batch scheduler only.
  // Note that 'num_batch_threads_' above is shared by both batch scheduler
  // implementations.
  struct AdaptiveBatchSchedulerOptions {
    int64 min_in_flight_batches_limit;
    double initial_in_flight_batches_limit;
    int64 batches_to_average_over;
  };
  absl::optional<AdaptiveBatchSchedulerOptions>
      adaptive_batch_scheduler_options_ = absl::nullopt;
};

REGISTER_KERNEL_BUILDER(Name("BatchFunction").Device(DEVICE_CPU),
                        BatchFunctionKernel);
// Currently all inputs and outputs are on the host.
// TODO(b/173748277): Accept inputs/outputs on the device.
REGISTER_KERNEL_BUILDER(Name("BatchFunction")
                            .Device(DEVICE_GPU)
                            .HostMemory("in_tensors")
                            .HostMemory("captured_tensors")
                            .HostMemory("out_tensors"),
                        BatchFunctionKernel);

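// Kernel for the `Batch` op. It reuses BatchResource without a batch function
// (kInvalidHandle), so inputs are only concatenated into batches and emitted
// together with a batch index; no function is run on the batched tensors.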
class BatchKernel : public AsyncOpKernel {
 public:
  explicit BatchKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
    OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
    // If shared_name is not supplied, use name instead (prevent collisions by
    // default).
    if (shared_name_.empty()) {
      shared_name_ = name();
    }
    OP_REQUIRES_OK(c, c->GetAttr("batching_queue", &batcher_queue_));
    OP_REQUIRES_OK(c, c->GetAttr("num_batch_threads", &num_batch_threads_));
    OP_REQUIRES_OK(c, c->GetAttr("max_batch_size", &max_batch_size_));
    OP_REQUIRES_OK(c,
                   c->GetAttr("batch_timeout_micros", &batch_timeout_micros_));
    OP_REQUIRES_OK(c,
                   c->GetAttr("max_enqueued_batches", &max_enqueued_batches_));
    OP_REQUIRES_OK(c, c->GetAttr("allowed_batch_sizes", &allowed_batch_sizes_));
    OP_REQUIRES_OK(c, ValidateAllowedBatchSizes());
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
    BatchResource* br;
    std::function<Status(BatchResource**)> creator = [this](BatchResource** r) {
      std::unique_ptr<BatchResource> new_resource;
      TF_RETURN_IF_ERROR(BatchResource::Create(
          num_batch_threads_, max_batch_size_, batch_timeout_micros_,
          max_enqueued_batches_, allowed_batch_sizes_, kInvalidHandle,
          /*flib=*/nullptr, false, &new_resource));
      *r = new_resource.release();
      return Status::OK();
    };
    OP_REQUIRES_OK_ASYNC(c,
                         c->resource_manager()->LookupOrCreate(
                             container_, shared_name_, &br, creator),
                         done);
    const Status status =
        br->RegisterInput(random::New64(), c, batcher_queue_, done);
    br->Unref();
    OP_REQUIRES_OK_ASYNC(c, status, done);
    // Assume br calls done, so nothing to do here.
  }

  // Validates 'allowed_batch_sizes_'. The entries must increase
  // monotonically, and the last one must equal 'max_batch_size_'.
  Status ValidateAllowedBatchSizes() const {
    if (allowed_batch_sizes_.empty()) {
      return Status::OK();
    }
    int32 last_size = 0;
    for (size_t i = 0; i < allowed_batch_sizes_.size(); ++i) {
      const int32 size = allowed_batch_sizes_.at(i);
      if (i > 0 && size <= last_size) {
        return errors::InvalidArgument(
            "allowed_batch_sizes entries must be monotonically increasing");
      }
      if (i == allowed_batch_sizes_.size() - 1 && size != max_batch_size_) {
        return errors::InvalidArgument(
            "final entry in allowed_batch_sizes must equal max_batch_size");
      }
      last_size = size;
    }
    return Status::OK();
  }

 private:
  string container_;
  string shared_name_;
  string batcher_queue_;
  int32 num_batch_threads_;
  int32 max_batch_size_;
  int32 batch_timeout_micros_;
  int32 max_enqueued_batches_;
  std::vector<int32> allowed_batch_sizes_;
};

REGISTER_KERNEL_BUILDER(Name("Batch").Device(DEVICE_CPU), BatchKernel);

// A class encapsulating the state and logic for unbatching tensors.
//
// UnbatchResource keeps two data structures indexed by batch-key: one which has
// the continuations for all concurrent kernels which are waiting for tensors
// and another which has tensors which are waiting for their corresponding
// kernels to run. Whenever a kernel runs, we either grab its tensor if it's
// waiting already, or we insert it in the queue and then look at its tensor to
// see if it can be used to dispatch any stored continuations.
class UnbatchResource : public ResourceBase {
 public:
  explicit UnbatchResource(int32 timeout_micros)
      : timeout_micros_(timeout_micros),
        timeout_enforcer_(new serving::PeriodicFunction(
            [this] { EnforceTimeout(); }, 1000 /* 1 ms */)) {}

  ~UnbatchResource() override {
    // Tear down 'timeout_enforcer_' first, since it accesses other state in
    // this class.
    timeout_enforcer_ = nullptr;
  }

  string DebugString() const final { return "UnbatchResource"; }

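  // Handles one `Unbatch` invocation: if the slice for this op's batch key is
  // already available it is returned immediately; otherwise the callback is
  // parked until the batched tensor arrives or the timeout fires.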
  Status Compute(OpKernelContext* context, AsyncOpKernel::DoneCallback done) {
    const Tensor& data_t = context->input(0);
    const Tensor& batch_index_t = context->input(1);

    if (batch_index_t.shape().dim_size(0) > data_t.shape().dim_size(0)) {
      return errors::InvalidArgument(
          "Wrong shape for index tensor. Expected 0th dimension size to be no "
          "greater than ",
          data_t.shape().dim_size(0),
          "; Got: ", batch_index_t.shape().dim_size(0), ".");
    }
    if (batch_index_t.shape().dim_size(1) != 3) {
      return errors::InvalidArgument(
          "Wrong shape for index tensor. Expected 1st dimension size to be 3 ; "
          "Got: ",
          batch_index_t.shape().dim_size(1), ".");
    }

    const int64 batch_key = context->input(2).scalar<int64>()();
    const bool nonempty_input = batch_index_t.dim_size(0) > 0;

    // If we have a non-empty tensor, slice it up.
    // (It is important to do this outside of the critical section below.)
    // The following variables are populated iff 'nonempty_input==true'.
    std::vector<int64> sizes;
    std::vector<int64> batch_keys;
    std::vector<Tensor> split_inputs;
    if (nonempty_input) {
      auto batch_indices =
          batch_index_t.shaped<int64, 2>({batch_index_t.dim_size(0), 3});
      for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
        sizes.push_back(batch_indices(i, 2) - batch_indices(i, 1));
        batch_keys.push_back(batch_indices(i, 0));
      }

      TF_RETURN_IF_ERROR(Split(context, data_t, sizes, &split_inputs));
    }

    // Critical section.
    std::vector<AsyncOpKernel::DoneCallback> done_callbacks_to_call;
    Status status = [&]() -> Status {
      mutex_lock ml(mu_);

      // Check to see whether the tensor we want is already ready.
      auto tensor_it = waiting_tensors_.find(batch_key);
      if (tensor_it != waiting_tensors_.end()) {
        context->set_output(0, tensor_it->second.tensor);
        waiting_tensors_.erase(tensor_it);
        done_callbacks_to_call.push_back(done);
        return Status::OK();
      }

      const uint64 deadline_micros =
          Env::Default()->NowMicros() + timeout_micros_;

      // Add ourselves to the waitlist for tensors.
      if (!waiting_callbacks_
               .emplace(batch_key,
                        WaitingCallback{deadline_micros, context, done})
               .second) {
        return errors::AlreadyExists(
            "Multiple session runs with the same batch key.");
      }

      // If we have a non-empty tensor, finish the waitlisted runs,
      // and store any remaining pieces.
      if (nonempty_input) {
        for (size_t i = 0; i < batch_keys.size(); ++i) {
          auto runs_it = waiting_callbacks_.find(batch_keys[i]);
          if (runs_it != waiting_callbacks_.end()) {
            runs_it->second.context->set_output(0, split_inputs[i]);
            done_callbacks_to_call.push_back(runs_it->second.done);
            waiting_callbacks_.erase(runs_it);
          } else {
            // Note: the deadline here is in case we are arriving late and the
            // kernel that should rendezvous with this tensor has already waited
            // and timed out.
            if (!waiting_tensors_
                     .emplace(batch_keys[i],
                              WaitingTensor{deadline_micros, split_inputs[i]})
                     .second) {
              return errors::AlreadyExists(
                  "Multiple tensors returned for same batch key.");
            }
          }
        }
      }

      return Status::OK();
    }();

    for (const AsyncOpKernel::DoneCallback& done_callback :
         done_callbacks_to_call) {
      done_callback();
    }

    return status;
  }

 private:
  // Evicts waiting tensors and callbacks that have exceeded their deadline.
  void EnforceTimeout() {
    const uint64 now = Env::Default()->NowMicros();
    std::vector<WaitingCallback> evicted_callbacks;

    {
      mutex_lock ml(mu_);

      for (auto it = waiting_tensors_.begin(); it != waiting_tensors_.end();) {
        const WaitingTensor& waiting_tensor = it->second;
        if (waiting_tensor.deadline_micros < now) {
          it = waiting_tensors_.erase(it);
        } else {
          ++it;
        }
      }

      for (auto it = waiting_callbacks_.begin();
           it != waiting_callbacks_.end();) {
        const WaitingCallback& waiting_callback = it->second;
        if (waiting_callback.deadline_micros < now) {
          evicted_callbacks.push_back(waiting_callback);
          it = waiting_callbacks_.erase(it);
        } else {
          ++it;
        }
      }
    }

    for (const WaitingCallback& evicted_callback : evicted_callbacks) {
      evicted_callback.context->CtxFailureWithWarning(errors::DeadlineExceeded(
          "Batched data did not arrive within timeout window."));
      evicted_callback.done();
    }
  }

  struct WaitingTensor {
    uint64 deadline_micros;
    Tensor tensor;
  };

  struct WaitingCallback {
    uint64 deadline_micros;
    OpKernelContext* context;
    AsyncOpKernel::DoneCallback done;
  };

  const int32 timeout_micros_;

  mutex mu_;

  // Maps keyed by BatchKey of tensors waiting for callbacks and callbacks
  // waiting for tensors.
  std::unordered_map<int64, WaitingTensor> waiting_tensors_ TF_GUARDED_BY(mu_);
  std::unordered_map<int64, WaitingCallback> waiting_callbacks_
      TF_GUARDED_BY(mu_);

  // A thread that evicts waiting tensors and callbacks that have exceeded their
  // deadline.
  std::unique_ptr<serving::PeriodicFunction> timeout_enforcer_;
};

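// Kernel for the `Unbatch` op. It looks up (or creates) the UnbatchResource
// shared under (`container`, `shared_name`) and delegates the actual
// slicing/rendezvous logic to it.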
class UnbatchKernel : public AsyncOpKernel {
 public:
  explicit UnbatchKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
    OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
    // If shared_name is not supplied, use name instead (prevent collisions by
    // default).
    if (shared_name_.empty()) {
      shared_name_ = name();
    }
    OP_REQUIRES_OK(c, c->GetAttr("timeout_micros", &timeout_micros_));
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
    UnbatchResource* ubr;
    std::function<Status(UnbatchResource**)> creator =
        [this](UnbatchResource** r) {
          *r = new UnbatchResource(timeout_micros_);
          return Status::OK();
        };
    OP_REQUIRES_OK_ASYNC(c,
                         c->resource_manager()->LookupOrCreate(
                             container_, shared_name_, &ubr, creator),
                         done);
    auto status = ubr->Compute(c, done);
    ubr->Unref();
    OP_REQUIRES_OK_ASYNC(c, status, done);
    // Assume ubr calls done, so nothing to do here.
  }

 private:
  string container_;
  string shared_name_;
  int32 timeout_micros_;
};
REGISTER_KERNEL_BUILDER(Name("Unbatch").Device(DEVICE_CPU), UnbatchKernel);

// A class encapsulating the state and logic for batching tensors
// deterministically for the gradient of unbatch.
class UnbatchGradResource : public ResourceBase {
 public:
  UnbatchGradResource() {}

  string DebugString() const final { return "UnbatchGradResource"; }

  // Flushes the information for one batch, given its context and done
  // callback. Clears all information about it from the available_tensors_.
  Status OutputBatch(OpKernelContext* context,
                     const AsyncOpKernel::DoneCallback& done)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    const Tensor& batch_index_t = context->input(1);
    auto batch_index =
        batch_index_t.shaped<int64, 2>({batch_index_t.dim_size(0), 3});
    std::vector<Tensor> tensors;
    for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
      auto available_it = available_tensors_.find(batch_index(i, 0));
      if (available_it == available_tensors_.end()) {
        return errors::Internal("bad bookkeeping of available tensors.");
      }
      tensors.push_back(available_it->second);
      available_tensors_.erase(available_it);
    }

    const DataType type = tensors[0].dtype();
    Tensor concatenated_tensor;
    switch (type) {
#define CASE(type)                                                            \
  case DataTypeToEnum<type>::value:                                           \
    TF_RETURN_IF_ERROR(Concat<type>(context, tensors, &concatenated_tensor)); \
    context->set_output(0, concatenated_tensor);                              \
    break;
      TF_CALL_ALL_TYPES(CASE);
#undef CASE
      default:
        return errors::InvalidArgument("Unsupported data type: ", type);
    }
    done();
    return Status::OK();
  }

  // Ingests data from one invocation of the op.
  Status Compute(OpKernelContext* context,
                 const AsyncOpKernel::DoneCallback& done) {
    const Tensor& data_t = context->input(0);
    const Tensor& batch_index_t = context->input(1);
    const Tensor& grad_t = context->input(2);

    mutex_lock ml(mu_);

    const int64 batch_key = context->input(3).scalar<int64>()();
    // Mark our tensor as available.
    if (!available_tensors_.emplace(batch_key, grad_t).second) {
      return errors::InvalidArgument("Two runs with the same batch key.");
    }

    // Check whether we have a valid input tensor and, if so, create its
    // dispatch logic.
    if (data_t.NumElements() > 0) {
      if (batch_index_t.NumElements() == 0) {
        return errors::InvalidArgument(
            "batch_index is empty while the tensor isn't.");
      }
      std::unordered_set<int64> missing_tensors;
      const auto batch_index =
          batch_index_t.shaped<int64, 2>({batch_index_t.dim_size(0), 3});
      for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
        const int64 batch_key = batch_index(i, 0);
        if (available_tensors_.find(batch_key) == available_tensors_.end()) {
          missing_tensors.emplace(batch_key);
        }
      }
      if (missing_tensors.empty()) {
        return OutputBatch(context, done);
      }
      if (!available_batches_
               .emplace(batch_key, Batch{missing_tensors, context, done})
               .second) {
        return errors::InvalidArgument(
            "Batch key with valid batch used twice.");
      }
      for (const int64 i : missing_tensors) {
        if (!desired_tensor_to_batch_map_.emplace(i, batch_key).second) {
          return errors::InvalidArgument(
              "Missing tensor wanted by more than one batch.");
        }
      }
    } else {
      // If we don't have a valid input tensor we can output an empty tensor and
      // call our done closure.
      TensorShape output_shape(grad_t.shape());
      output_shape.set_dim(0, 0);
      Tensor* output = nullptr;
      TF_RETURN_IF_ERROR(context->allocate_output(0, output_shape, &output));
      done();
    }

    // Search to see whether our tensor is desired by any existing batch.
    auto desire_it = desired_tensor_to_batch_map_.find(batch_key);
    if (desire_it != desired_tensor_to_batch_map_.end()) {
      // Mark our tensor as no longer missing.
      auto batch_it = available_batches_.find(desire_it->second);
      desired_tensor_to_batch_map_.erase(desire_it);
      if (batch_it == available_batches_.end()) {
        return errors::InvalidArgument("Batch no longer exists.");
      }
      batch_it->second.missing_tensors.erase(batch_key);
      // If all tensors are available we should concatenate them and dispatch
      // the batch.
      if (batch_it->second.missing_tensors.empty()) {
        TF_RETURN_IF_ERROR(
            OutputBatch(batch_it->second.context, batch_it->second.done));
        available_batches_.erase(batch_it);
      }
    }
    return Status::OK();
  }

 private:
  mutex mu_;

  // Represents a still-incomplete batch of tensors. When all tensors become
  // available they will be concatenated in the right order and sent through the
  // context.
  struct Batch {
    // Batch keys for tensors which are still missing from this batch. When this
    // is empty the Tensors can be concatenated and forwarded.
    std::unordered_set<int64> missing_tensors;

    // Context and callback for the session responsible for finishing this
    // batch.
    OpKernelContext* context;
    AsyncOpKernel::DoneCallback done;
  };

  // Map from batch key of the session which will output the batched gradients
  // to still-incomplete batches.
  std::unordered_map<int64, Batch> available_batches_;

  // Map from batch key to tensors which are waiting for their batches to be
  // available.
  std::unordered_map<int64, Tensor> available_tensors_;

  // Map from batch key of a tensor which is not yet available to the batch key
  // of the batch to which it belongs.
  std::unordered_map<int64, int64> desired_tensor_to_batch_map_;
};

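// Kernel for the `UnbatchGrad` op. It forwards each gradient slice to the
// shared UnbatchGradResource, which concatenates the slices back into the
// batched gradient once all of them have arrived.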
class UnbatchGradKernel : public AsyncOpKernel {
 public:
  explicit UnbatchGradKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
    OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
    // If shared_name is not supplied, use name instead (prevent collisions by
    // default).
    if (shared_name_.empty()) {
      shared_name_ = name();
    }
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
    UnbatchGradResource* ubr;
    std::function<Status(UnbatchGradResource**)> creator =
        [](UnbatchGradResource** r) {
          *r = new UnbatchGradResource();
          return Status::OK();
        };
    OP_REQUIRES_OK_ASYNC(c,
                         c->resource_manager()->LookupOrCreate(
                             container_, shared_name_, &ubr, creator),
                         done);
    Status status = ubr->Compute(c, done);
    ubr->Unref();
    OP_REQUIRES_OK_ASYNC(c, status, done);
    // Assume ubr calls done, so nothing to do here.
  }

 private:
  string container_;
  string shared_name_;
};
REGISTER_KERNEL_BUILDER(Name("UnbatchGrad").Device(DEVICE_CPU),
                        UnbatchGradKernel);

}  // namespace tensorflow