/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/strings/str_cat.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/batch_resource_base.h"
#include "tensorflow/core/kernels/batching_util/concat_split_util.h"
#include "tensorflow/core/kernels/batching_util/periodic_function.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/numbers.h"

namespace tensorflow {
namespace {
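// Default tuning parameters for the adaptive shared batch scheduler; used by
// BatchFunctionKernel below when the `num_batch_threads` attribute is not
// positive.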
constexpr int64 kMinInflightBatchesLimit = 16;
constexpr double kInitialInflightBatchesLimit = 64;
constexpr int64 kBatchesToAverageOver = 10;
constexpr int64 kMaxInflightBatchesLimit = 128;
}  // namespace

auto* batch_op_split_usage = monitoring::Gauge<string, 1>::New(
    "/tensorflow/serving/batching/enable_large_batch_splitting",
    "Tracks the usage of attribute `enable_large_batch_splitting` for "
    "BatchFunction kernel in a saved model.",
    "model_name");

void RecordBatchSplitUsage(
    absl::optional<bool> maybe_enable_large_batch_splitting,
    const string& model_name) {
  if (maybe_enable_large_batch_splitting.has_value()) {
    if (maybe_enable_large_batch_splitting.value()) {
      batch_op_split_usage->GetCell(model_name)->Set("true");
    } else {
      batch_op_split_usage->GetCell(model_name)->Set("false");
    }
  } else {
    batch_op_split_usage->GetCell(model_name)->Set("unset");
  }
}

void RecordBatchParamNumBatchThreads(int64 num_batch_threads,
                                     const string& model_name) {
  static auto* cell = monitoring::Gauge<int64, 1>::New(
      "/tensorflow/serving/batching/num_batch_threads",
      "Tracks the number of batch threads of a model.", "model_name");
  cell->GetCell(model_name)->Set(num_batch_threads);
}

const string& GetModelName(OpKernelContext* ctx) {
  static string* kModelNameUnset = new string("model_name_unset");
  if (!ctx->session_metadata()) return *kModelNameUnset;
  if (ctx->session_metadata()->name().empty()) return *kModelNameUnset;
  return ctx->session_metadata()->name();
}

using ::tensorflow::concat_split_util::Concat;
using ::tensorflow::concat_split_util::Split;

// A class encapsulating the state and logic for batching tensors.
class BatchResource : public serving::BatchResourceBase {
 public:
  static Status Create(int32 num_batch_threads, int32 max_execution_batch_size,
                       int32 batch_timeout_micros, int32 max_enqueued_batches,
                       const std::vector<int32>& allowed_batch_sizes,
                       FunctionLibraryRuntime::Handle fhandle,
                       FunctionLibraryRuntime* flib,
                       bool enable_large_batch_splitting,
                       std::unique_ptr<BatchResource>* resource) {
    BatcherT::Options batcher_options;
    batcher_options.num_batch_threads = num_batch_threads;
    std::shared_ptr<BatcherT> batcher;
    TF_RETURN_IF_ERROR(BatcherT::Create(batcher_options, &batcher));

    resource->reset(new BatchResource(
        fhandle, flib, std::move(batcher),
        GetBatcherQueueOptions(num_batch_threads, max_execution_batch_size,
                               batch_timeout_micros, max_enqueued_batches,
                               allowed_batch_sizes,
                               enable_large_batch_splitting),
        allowed_batch_sizes));
    return Status::OK();
  }

  static Status Create(
      AdaptiveBatcherT::Options adaptive_shared_batch_scheduler_options,
      int32 max_batch_size, int32 batch_timeout_micros,
      int32 max_enqueued_batches, const std::vector<int32>& allowed_batch_sizes,
      FunctionLibraryRuntime::Handle fhandle, FunctionLibraryRuntime* flib,
      std::unique_ptr<BatchResource>* resource) {
    std::shared_ptr<AdaptiveBatcherT> batcher;
    TF_RETURN_IF_ERROR(AdaptiveBatcherT::Create(
        adaptive_shared_batch_scheduler_options, &batcher));

    resource->reset(new BatchResource(
        fhandle, flib, std::move(batcher),
        GetAdaptiveBatcherQueueOptions(
            max_batch_size, batch_timeout_micros, max_enqueued_batches,
            true /* enable large batch split */, allowed_batch_sizes),
        allowed_batch_sizes));
    return Status::OK();
  }

  string DebugString() const final { return "BatchResource"; }

 private:
  BatchResource(FunctionLibraryRuntime::Handle fhandle,
                FunctionLibraryRuntime* flib, std::shared_ptr<BatcherT> batcher,
                const BatcherT::QueueOptions& batcher_queue_options,
                std::vector<int32> allowed_batch_sizes)
      : BatchResourceBase(
            /*has_process_batch_function=*/fhandle != kInvalidHandle,
            std::move(batcher), batcher_queue_options,
            std::move(allowed_batch_sizes)),
        fhandle_(fhandle),
        flib_(flib) {}

  BatchResource(FunctionLibraryRuntime::Handle fhandle,
                FunctionLibraryRuntime* flib,
                std::shared_ptr<AdaptiveBatcherT> batcher,
                const AdaptiveBatcherT::QueueOptions& batcher_queue_options,
                std::vector<int32> allowed_batch_sizes)
      : BatchResourceBase(
            /*has_process_batch_function=*/fhandle != kInvalidHandle,
            std::move(batcher), batcher_queue_options,
            std::move(allowed_batch_sizes)),
        fhandle_(fhandle),
        flib_(flib) {}

  void ProcessFuncBatchImpl(
      const BatchTask& last_task, absl::Span<const Tensor> inputs,
      std::vector<Tensor>* combined_outputs,
      std::function<void(const Status&)> done) const override {
    auto* last_task_context = last_task.context;
    FunctionLibraryRuntime::Options opts;
    opts.step_container = last_task_context->step_container();
    opts.cancellation_manager = last_task_context->cancellation_manager();
    opts.collective_executor = last_task_context->collective_executor();
    opts.stats_collector = last_task_context->stats_collector();
    opts.runner = last_task_context->runner();
    opts.run_all_kernels_inline = last_task_context->run_all_kernels_inline();
    // We do not set 'opts.rendezvous', since if the function is run multiple
    // times in parallel with the same rendezvous, a _Send node from one run
    // might be matched with a _Recv node of a different run. Not setting the
    // rendezvous causes a new rendezvous to be used for each run.
    Notification done_notif;

    flib_->Run(opts, fhandle_, inputs, combined_outputs,
               [&](const Status& run_status) {
                 done(run_status);
                 done_notif.Notify();
               });
    // By waiting for the notification we are ensuring that this thread isn't
    // used for processing other batches, which gives the batches time to
    // coalesce upstream. So overall the number of batches going through the
    // devices goes down, improving latency and throughput in most cases.
    done_notif.WaitForNotification();
  }

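  // Handle and function library runtime used by ProcessFuncBatchImpl() to run
  // the batch function on a formed batch.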
  FunctionLibraryRuntime::Handle fhandle_;
  FunctionLibraryRuntime* flib_;
};

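// AsyncOpKernel for the `BatchFunction` op. Each invocation enqueues its
// inputs into a BatchResource shared via the resource manager; batches formed
// by that resource are processed by running the captured function `f`.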
class BatchFunctionKernel : public AsyncOpKernel {
 public:
  explicit BatchFunctionKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
    OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
    OP_REQUIRES_OK(c, c->GetAttr("batching_queue", &batcher_queue_));
    OP_REQUIRES_OK(c, c->GetAttr("num_batch_threads", &num_batch_threads_));
    OP_REQUIRES_OK(c, c->GetAttr("max_batch_size", &max_batch_size_));
    OP_REQUIRES_OK(c,
                   c->GetAttr("batch_timeout_micros", &batch_timeout_micros_));
    OP_REQUIRES_OK(c,
                   c->GetAttr("max_enqueued_batches", &max_enqueued_batches_));
    OP_REQUIRES_OK(c, c->GetAttr("allowed_batch_sizes", &allowed_batch_sizes_));

    OP_REQUIRES_OK(c, c->GetAttr("f", &func_));
    flib_ = c->function_library();
    if (num_batch_threads_ <= 0) {
      adaptive_batch_scheduler_options_ =
          absl::make_optional(AdaptiveBatchSchedulerOptions{
              kMinInflightBatchesLimit, kInitialInflightBatchesLimit,
              kBatchesToAverageOver});

      // One scheduler instance contains several queue instances;
      // `batcher_queue_` is the key used to find the queue for this batch op
      // in the graph.
      // Use `shared_name_` and name() as the prefix for `batcher_queue_`.
      // Note that name() is unique per session (from the session metadata).
      batcher_queue_ = name() + "/" + shared_name_ + batcher_queue_;

      // `shared_name_` and `container_` are used to look up an instantiated
      // scheduler instance in `ComputeAsync`.
      //
      // Rewrite `container_` and `shared_name_` to pre-defined constants so
      // that one scheduler thread pool is shared across all models when the
      // adaptive shared batch scheduler is used.
      container_ = "__adapative_container";
      shared_name_ = "__adaptive_global_shared_thread_pool";
    }

    if (shared_name_.empty()) {
      // If shared_name is not supplied, use name instead (prevent collisions by
      // default).
      shared_name_ = name();
    }

    if (c->HasAttr("enable_large_batch_splitting")) {
      OP_REQUIRES_OK(c, c->GetAttr("enable_large_batch_splitting",
                                   &enable_large_batch_splitting_));
      has_attribute_enable_large_batch_splitting_ = true;
    } else {
      enable_large_batch_splitting_ = false;
      has_attribute_enable_large_batch_splitting_ = false;
    }

    OP_REQUIRES_OK(c, ValidateAllowedBatchSizes());
  }

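  // ComputeAsync() only enqueues the inputs; the batched work runs on the
  // batch scheduler's threads, so this kernel is cheap from the executor's
  // point of view.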
  bool IsExpensive() override { return false; }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
    RecordBatchSplitUsage(
        has_attribute_enable_large_batch_splitting_
            ? absl::make_optional(enable_large_batch_splitting_)
            : absl::nullopt,
        GetModelName(c));
    // TODO(b/173255290): Add num_batch_threads_ parameter to TFRT batch kernel.
    RecordBatchParamNumBatchThreads(num_batch_threads_, GetModelName(c));

    std::function<Status(BatchResource**)> creator;

    FunctionLibraryRuntime::Handle handle;
    OP_REQUIRES_OK_ASYNC(c, GetOrCreateFunctionHandle(c, &handle), done);

    if (adaptive_batch_scheduler_options_ != absl::nullopt) {
      creator = [this, handle](BatchResource** r) {
        serving::AdaptiveSharedBatchScheduler<
            serving::BatchResourceBase::BatchTask>::Options
            adaptive_shared_batch_scheduler_options;
        adaptive_shared_batch_scheduler_options.thread_pool_name =
            "adaptive_batch_threads";
        adaptive_shared_batch_scheduler_options.num_batch_threads =
            kMaxInflightBatchesLimit;
        // adaptive_shared_batch_scheduler_options.full_batch_scheduling_boost_micros
        // is 0 (the default value) intentionally, so tasks are scheduled in a
        // FIFO way.
        // Two rationales for using the default value (zero) for
        // `full_batch_scheduling_boost_micros`:
        // 1) The task scheduling policy is then FIFO. Compared with round
        //    robin (which the shared batch scheduler uses), FIFO ensures that
        //    models with low QPS (i.e., models that enqueue fewer tasks in the
        //    shared queue) are processed in a timely manner.
        // 2) If set, `full_batch_scheduling_boost_micros` should be on the
        //    order of the batch processing latency (which varies on a model
        //    basis). If a non-zero value is not set properly, it harms tail
        //    latency.
        adaptive_shared_batch_scheduler_options.min_in_flight_batches_limit =
            adaptive_batch_scheduler_options_->min_in_flight_batches_limit;
        adaptive_shared_batch_scheduler_options
            .initial_in_flight_batches_limit =
            adaptive_batch_scheduler_options_->initial_in_flight_batches_limit;
        adaptive_shared_batch_scheduler_options.batches_to_average_over =
            adaptive_batch_scheduler_options_->batches_to_average_over;
        std::unique_ptr<BatchResource> new_resource;
        TF_RETURN_IF_ERROR(BatchResource::Create(
            adaptive_shared_batch_scheduler_options, max_batch_size_,
            batch_timeout_micros_, max_enqueued_batches_, allowed_batch_sizes_,
            handle, flib_, &new_resource));
        *r = new_resource.release();
        return Status::OK();
      };
    } else {
      creator = [this, handle](BatchResource** r) {
        std::unique_ptr<BatchResource> new_resource;
        TF_RETURN_IF_ERROR(BatchResource::Create(
            num_batch_threads_, max_batch_size_, batch_timeout_micros_,
            max_enqueued_batches_, allowed_batch_sizes_, handle, flib_,
            enable_large_batch_splitting_, &new_resource));
        *r = new_resource.release();
        return Status::OK();
      };
    }

    BatchResource* br;
    OP_REQUIRES_OK_ASYNC(c,
                         c->resource_manager()->LookupOrCreate(
                             container_, shared_name_, &br, creator),
                         done);
    const Status status =
        br->RegisterInput(random::New64(), c, batcher_queue_, done);
    br->Unref();
    OP_REQUIRES_OK_ASYNC(c, status, done);
    // Assume br calls done, so nothing to do here.
  }

  Status InstantiateFunction(OpKernelContext* c,
                             FunctionLibraryRuntime::Handle* handle) const {
    // TODO(b/173748062): Merge this instantiation logic with PartitionedCall.
    if (!flib_) {
      return errors::Internal("No function library");
    }

    FunctionLibraryRuntime::InstantiateOptions opts;
    opts.target = flib_->device() == nullptr ? "" : flib_->device()->name();
    opts.is_multi_device_function = true;
    const ConfigProto* config = flib_->config_proto();
    if (config) {
      opts.config_proto = *config;
    }

    Device* cpu_device;
    TF_RETURN_IF_ERROR(flib_->device_mgr()->LookupDevice("CPU:0", &cpu_device));

    const FunctionDef* fdef =
        flib_->GetFunctionLibraryDefinition()->Find(func_.name());
    if (!fdef) {
      return errors::NotFound("Failed to find definition for function \"",
                              func_.name(), "\"");
    }
    OpInputList in_tensors;
    TF_RETURN_IF_ERROR(c->input_list("in_tensors", &in_tensors));
    for (int i = 0; i < in_tensors.size(); i++) {
      if (in_tensors[i].dtype() == DT_RESOURCE) {
        return errors::InvalidArgument(
            "BatchFunction cannot take resource inputs but input ", i,
            " is a resource.");
      } else {
        // Currently, inputs are on CPU since they are concatenated on CPU.
        opts.input_devices.push_back(cpu_device->name());
      }
    }
    OpInputList captured_tensors;
    TF_RETURN_IF_ERROR(c->input_list("captured_tensors", &captured_tensors));
    for (const Tensor& t : captured_tensors) {
      if (t.dtype() == DT_RESOURCE) {
        const ResourceHandle& rhandle = t.flat<ResourceHandle>()(0);
        opts.input_devices.push_back(rhandle.device());
      } else {
        opts.input_devices.push_back(cpu_device->name());
      }
    }
    const OpDef& signature = fdef->signature();
    for (int i = 0; i < signature.output_arg_size(); i++) {
      // Currently, outputs must be on CPU since they are split on CPU.
      opts.output_devices.push_back(cpu_device->name());
    }
    if (opts.input_devices.size() != signature.input_arg_size()) {
      return errors::InvalidArgument(
          "Function takes ", signature.input_arg_size(), " argument(s) but ",
          opts.input_devices.size(), " argument(s) were passed");
    }
    return flib_->Instantiate(func_.name(), AttrSlice(&func_.attr()), opts,
                              handle);
  }

  Status GetOrCreateFunctionHandle(OpKernelContext* c,
                                   FunctionLibraryRuntime::Handle* handle) {
    mutex_lock ml(mu_);
    if (!fhandle_) {
      TF_RETURN_IF_ERROR(InstantiateFunction(c, handle));
      fhandle_ = *handle;
    } else {
      *handle = fhandle_.value();
    }
    return Status::OK();
  }

  // Validates 'allowed_batch_sizes_'. The entries must increase monotonically.
  // If large batch splitting is not enabled, the last entry must equal
  // `max_batch_size_`; otherwise the last entry must be less than or equal to
  // `max_batch_size_`.
  Status ValidateAllowedBatchSizes() const {
    if (allowed_batch_sizes_.empty()) {
      return Status::OK();
    }
    int32 last_size = 0;
    for (size_t i = 0; i < allowed_batch_sizes_.size(); ++i) {
      const int32 size = allowed_batch_sizes_.at(i);
      if (i > 0 && size <= last_size) {
        return errors::InvalidArgument(
            "allowed_batch_sizes entries must be monotonically increasing");
      }

      if ((!enable_large_batch_splitting_) &&
          (i == allowed_batch_sizes_.size() - 1) && (size != max_batch_size_)) {
        return errors::InvalidArgument(
            "final entry in allowed_batch_sizes must equal max_batch_size when "
            "enable_large_batch_splitting is False");
      }

      last_size = size;
    }
    return Status::OK();
  }

 private:
  string container_;
  string shared_name_;
  string batcher_queue_;
  int32 num_batch_threads_;
  int32 max_batch_size_;
  int32 batch_timeout_micros_;
  int32 max_enqueued_batches_;
  std::vector<int32> allowed_batch_sizes_;
  NameAttrList func_;
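  // Handle of the instantiated batch function; created lazily by
  // GetOrCreateFunctionHandle() on the first ComputeAsync() call.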
  absl::optional<FunctionLibraryRuntime::Handle> fhandle_ TF_GUARDED_BY(mu_);
  FunctionLibraryRuntime* flib_;
  bool enable_large_batch_splitting_;
  bool has_attribute_enable_large_batch_splitting_;
  mutex mu_;

  // Parameters for the adaptive batch scheduler only.
  // Note 'num_batch_threads_' above is shared by the two implementations of
  // batch scheduler.
  struct AdaptiveBatchSchedulerOptions {
    int64 min_in_flight_batches_limit;
    double initial_in_flight_batches_limit;
    int64 batches_to_average_over;
  };
  absl::optional<AdaptiveBatchSchedulerOptions>
      adaptive_batch_scheduler_options_ = absl::nullopt;
};

REGISTER_KERNEL_BUILDER(Name("BatchFunction").Device(DEVICE_CPU),
                        BatchFunctionKernel);
// Currently all inputs and outputs are on the host.
// TODO(b/173748277): Accept inputs/outputs on the device.
REGISTER_KERNEL_BUILDER(Name("BatchFunction")
                            .Device(DEVICE_GPU)
                            .HostMemory("in_tensors")
                            .HostMemory("captured_tensors")
                            .HostMemory("out_tensors"),
                        BatchFunctionKernel);

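// AsyncOpKernel for the `Batch` op. It batches input tensors through a shared
// BatchResource that is created without a batch function (kInvalidHandle);
// downstream ops consume the batched output.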
class BatchKernel : public AsyncOpKernel {
 public:
  explicit BatchKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
    OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
    // If shared_name is not supplied, use name instead (prevent collisions by
    // default).
    if (shared_name_.empty()) {
      shared_name_ = name();
    }
    OP_REQUIRES_OK(c, c->GetAttr("batching_queue", &batcher_queue_));
    OP_REQUIRES_OK(c, c->GetAttr("num_batch_threads", &num_batch_threads_));
    OP_REQUIRES_OK(c, c->GetAttr("max_batch_size", &max_batch_size_));
    OP_REQUIRES_OK(c,
                   c->GetAttr("batch_timeout_micros", &batch_timeout_micros_));
    OP_REQUIRES_OK(c,
                   c->GetAttr("max_enqueued_batches", &max_enqueued_batches_));
    OP_REQUIRES_OK(c, c->GetAttr("allowed_batch_sizes", &allowed_batch_sizes_));
    OP_REQUIRES_OK(c, ValidateAllowedBatchSizes());
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
    BatchResource* br;
    std::function<Status(BatchResource**)> creator = [this](BatchResource** r) {
      std::unique_ptr<BatchResource> new_resource;
      TF_RETURN_IF_ERROR(BatchResource::Create(
          num_batch_threads_, max_batch_size_, batch_timeout_micros_,
          max_enqueued_batches_, allowed_batch_sizes_, kInvalidHandle,
          /*flib=*/nullptr, false, &new_resource));
      *r = new_resource.release();
      return Status::OK();
    };
    OP_REQUIRES_OK_ASYNC(c,
                         c->resource_manager()->LookupOrCreate(
                             container_, shared_name_, &br, creator),
                         done);
    const Status status =
        br->RegisterInput(random::New64(), c, batcher_queue_, done);
    br->Unref();
    OP_REQUIRES_OK_ASYNC(c, status, done);
    // Assume br calls done, so nothing to do here.
  }

  // Validates 'allowed_batch_sizes_'. The entries must increase
  // monotonically, and the last one must equal 'max_batch_size_'.
  Status ValidateAllowedBatchSizes() const {
    if (allowed_batch_sizes_.empty()) {
      return Status::OK();
    }
    int32 last_size = 0;
    for (size_t i = 0; i < allowed_batch_sizes_.size(); ++i) {
      const int32 size = allowed_batch_sizes_.at(i);
      if (i > 0 && size <= last_size) {
        return errors::InvalidArgument(
            "allowed_batch_sizes entries must be monotonically increasing");
      }
      if (i == allowed_batch_sizes_.size() - 1 && size != max_batch_size_) {
        return errors::InvalidArgument(
            "final entry in allowed_batch_sizes must equal max_batch_size");
      }
      last_size = size;
    }
    return Status::OK();
  }

 private:
  string container_;
  string shared_name_;
  string batcher_queue_;
  int32 num_batch_threads_;
  int32 max_batch_size_;
  int32 batch_timeout_micros_;
  int32 max_enqueued_batches_;
  std::vector<int32> allowed_batch_sizes_;
};

REGISTER_KERNEL_BUILDER(Name("Batch").Device(DEVICE_CPU), BatchKernel);

// A class encapsulating the state and logic for unbatching tensors.
//
// UnbatchResource keeps two data structures indexed by batch-key: one which has
// the continuations for all concurrent kernels which are waiting for tensors
// and another which has tensors which are waiting for their corresponding
// kernels to run. Whenever a kernel runs, we either grab its tensor if it's
// waiting already, or we insert it in the queue and then look at its tensor to
// see if it can be used to dispatch any stored continuations.
class UnbatchResource : public ResourceBase {
 public:
  explicit UnbatchResource(int32 timeout_micros)
      : timeout_micros_(timeout_micros),
        timeout_enforcer_(new serving::PeriodicFunction(
            [this] { EnforceTimeout(); }, 1000 /* 1 ms */)) {}

  ~UnbatchResource() override {
    // Tear down 'timeout_enforcer_' first, since it accesses other state in
    // this class.
    timeout_enforcer_ = nullptr;
  }

  string DebugString() const final { return "UnbatchResource"; }

  Status Compute(OpKernelContext* context, AsyncOpKernel::DoneCallback done) {
    const Tensor& data_t = context->input(0);
    const Tensor& batch_index_t = context->input(1);

    if (batch_index_t.shape().dim_size(0) > data_t.shape().dim_size(0)) {
      return errors::InvalidArgument(
          "Wrong shape for index tensor. Expected 0th dimension size to be no "
          "greater than ",
          data_t.shape().dim_size(0),
          "; Got: ", batch_index_t.shape().dim_size(0), ".");
    }
    if (batch_index_t.shape().dim_size(1) != 3) {
      return errors::InvalidArgument(
          "Wrong shape for index tensor. Expected 1st dimension size to be 3; "
          "Got: ",
          batch_index_t.shape().dim_size(1), ".");
    }

    const int64 batch_key = context->input(2).scalar<int64>()();
    const bool nonempty_input = batch_index_t.dim_size(0) > 0;

    // If we have a non-empty tensor, slice it up.
    // (It is important to do this outside of the critical section below.)
    // The following variables are populated iff 'nonempty_input==true'.
    std::vector<int64> sizes;
    std::vector<int64> batch_keys;
    std::vector<Tensor> split_inputs;
    if (nonempty_input) {
      auto batch_indices =
          batch_index_t.shaped<int64, 2>({batch_index_t.dim_size(0), 3});
      for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
        sizes.push_back(batch_indices(i, 2) - batch_indices(i, 1));
        batch_keys.push_back(batch_indices(i, 0));
      }

      TF_RETURN_IF_ERROR(Split(context, data_t, sizes, &split_inputs));
    }

    // Critical section.
    std::vector<AsyncOpKernel::DoneCallback> done_callbacks_to_call;
    Status status = [&]() -> Status {
      mutex_lock ml(mu_);

      // Check to see whether the tensor we want is already ready.
      auto tensor_it = waiting_tensors_.find(batch_key);
      if (tensor_it != waiting_tensors_.end()) {
        context->set_output(0, tensor_it->second.tensor);
        waiting_tensors_.erase(tensor_it);
        done_callbacks_to_call.push_back(done);
        return Status::OK();
      }

      const uint64 deadline_micros =
          Env::Default()->NowMicros() + timeout_micros_;

      // Add ourselves to the waitlist for tensors.
      if (!waiting_callbacks_
               .emplace(batch_key,
                        WaitingCallback{deadline_micros, context, done})
               .second) {
        return errors::AlreadyExists(
            "Multiple session runs with the same batch key.");
      }

      // If we have a non-empty tensor, finish the waitlisted runs,
      // and store any remaining pieces.
      if (nonempty_input) {
        for (size_t i = 0; i < batch_keys.size(); ++i) {
          auto runs_it = waiting_callbacks_.find(batch_keys[i]);
          if (runs_it != waiting_callbacks_.end()) {
            runs_it->second.context->set_output(0, split_inputs[i]);
            done_callbacks_to_call.push_back(runs_it->second.done);
            waiting_callbacks_.erase(runs_it);
          } else {
            // Note: the deadline here is in case we are arriving late and the
            // kernel that should rendezvous with this tensor has already
            // waited and timed out.
            if (!waiting_tensors_
                     .emplace(batch_keys[i],
                              WaitingTensor{deadline_micros, split_inputs[i]})
                     .second) {
              return errors::AlreadyExists(
                  "Multiple tensors returned for same batch key.");
            }
          }
        }
      }

      return Status::OK();
    }();

    for (const AsyncOpKernel::DoneCallback& done_callback :
         done_callbacks_to_call) {
      done_callback();
    }

    return status;
  }

 private:
  // Evicts waiting tensors and callbacks that have exceeded their deadline.
  void EnforceTimeout() {
    const uint64 now = Env::Default()->NowMicros();
    std::vector<WaitingCallback> evicted_callbacks;

    {
      mutex_lock ml(mu_);

      for (auto it = waiting_tensors_.begin(); it != waiting_tensors_.end();) {
        const WaitingTensor& waiting_tensor = it->second;
        if (waiting_tensor.deadline_micros < now) {
          it = waiting_tensors_.erase(it);
        } else {
          ++it;
        }
      }

      for (auto it = waiting_callbacks_.begin();
           it != waiting_callbacks_.end();) {
        const WaitingCallback& waiting_callback = it->second;
        if (waiting_callback.deadline_micros < now) {
          evicted_callbacks.push_back(waiting_callback);
          it = waiting_callbacks_.erase(it);
        } else {
          ++it;
        }
      }
    }

    for (const WaitingCallback& evicted_callback : evicted_callbacks) {
      evicted_callback.context->CtxFailureWithWarning(errors::DeadlineExceeded(
          "Batched data did not arrive within timeout window."));
      evicted_callback.done();
    }
  }

  struct WaitingTensor {
    uint64 deadline_micros;
    Tensor tensor;
  };

  struct WaitingCallback {
    uint64 deadline_micros;
    OpKernelContext* context;
    AsyncOpKernel::DoneCallback done;
  };

  const int32 timeout_micros_;

  mutex mu_;

  // Maps keyed by BatchKey of tensors waiting for callbacks and callbacks
  // waiting for tensors.
  std::unordered_map<int64, WaitingTensor> waiting_tensors_ TF_GUARDED_BY(mu_);
  std::unordered_map<int64, WaitingCallback> waiting_callbacks_
      TF_GUARDED_BY(mu_);

  // A thread that evicts waiting tensors and callbacks that have exceeded
  // their deadline.
  std::unique_ptr<serving::PeriodicFunction> timeout_enforcer_;
};

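// AsyncOpKernel for the `Unbatch` op: looks up (or creates) the shared
// UnbatchResource and delegates the work to it.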
class UnbatchKernel : public AsyncOpKernel {
 public:
  explicit UnbatchKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
    OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
    // If shared_name is not supplied, use name instead (prevent collisions by
    // default).
    if (shared_name_.empty()) {
      shared_name_ = name();
    }
    OP_REQUIRES_OK(c, c->GetAttr("timeout_micros", &timeout_micros_));
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
    UnbatchResource* ubr;
    std::function<Status(UnbatchResource**)> creator =
        [this](UnbatchResource** r) {
          *r = new UnbatchResource(timeout_micros_);
          return Status::OK();
        };
    OP_REQUIRES_OK_ASYNC(c,
                         c->resource_manager()->LookupOrCreate(
                             container_, shared_name_, &ubr, creator),
                         done);
    auto status = ubr->Compute(c, done);
    ubr->Unref();
    OP_REQUIRES_OK_ASYNC(c, status, done);
    // Assume ubr calls done, so nothing to do here.
  }

 private:
  string container_;
  string shared_name_;
  int32 timeout_micros_;
};
REGISTER_KERNEL_BUILDER(Name("Unbatch").Device(DEVICE_CPU), UnbatchKernel);

// A class encapsulating the state and logic for batching tensors
// deterministically for the gradient of unbatch.
class UnbatchGradResource : public ResourceBase {
 public:
  UnbatchGradResource() {}

  string DebugString() const final { return "UnbatchGradResource"; }

  // Flushes the information for one batch, given its context and done
  // callback. Clears all information about it from the available_tensors_.
  Status OutputBatch(OpKernelContext* context,
                     const AsyncOpKernel::DoneCallback& done)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    const Tensor& batch_index_t = context->input(1);
    auto batch_index =
        batch_index_t.shaped<int64, 2>({batch_index_t.dim_size(0), 3});
    std::vector<Tensor> tensors;
    for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
      auto available_it = available_tensors_.find(batch_index(i, 0));
      if (available_it == available_tensors_.end()) {
        return errors::Internal("bad bookkeeping of available tensors.");
      }
      tensors.push_back(available_it->second);
      available_tensors_.erase(available_it);
    }

    const DataType type = tensors[0].dtype();
    Tensor concatenated_tensor;
    switch (type) {
#define CASE(type)                                                            \
  case DataTypeToEnum<type>::value:                                           \
    TF_RETURN_IF_ERROR(Concat<type>(context, tensors, &concatenated_tensor)); \
    context->set_output(0, concatenated_tensor);                              \
    break;
      TF_CALL_ALL_TYPES(CASE);
#undef CASE
      default:
        return errors::InvalidArgument("Unsupported data type: ", type);
    }
    done();
    return Status::OK();
  }

  // Ingests data from one invocation of the op.
  Status Compute(OpKernelContext* context,
                 const AsyncOpKernel::DoneCallback& done) {
    const Tensor& data_t = context->input(0);
    const Tensor& batch_index_t = context->input(1);
    const Tensor& grad_t = context->input(2);

    mutex_lock ml(mu_);

    const int64 batch_key = context->input(3).scalar<int64>()();
    // Mark our tensor as available.
    if (!available_tensors_.emplace(batch_key, grad_t).second) {
      return errors::InvalidArgument("Two runs with the same batch key.");
    }

    // Check whether we have a valid input tensor and, if so, create its
    // dispatch logic.
    if (data_t.NumElements() > 0) {
      if (batch_index_t.NumElements() == 0) {
        return errors::InvalidArgument(
            "batch_index is empty while the tensor isn't.");
      }
      std::unordered_set<int64> missing_tensors;
      const auto batch_index =
          batch_index_t.shaped<int64, 2>({batch_index_t.dim_size(0), 3});
      for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
        const int64 batch_key = batch_index(i, 0);
        if (available_tensors_.find(batch_key) == available_tensors_.end()) {
          missing_tensors.emplace(batch_key);
        }
      }
      if (missing_tensors.empty()) {
        return OutputBatch(context, done);
      }
      if (!available_batches_
               .emplace(batch_key, Batch{missing_tensors, context, done})
               .second) {
        return errors::InvalidArgument(
            "Batch key with valid batch used twice.");
      }
      for (const int64 i : missing_tensors) {
        if (!desired_tensor_to_batch_map_.emplace(i, batch_key).second) {
          return errors::InvalidArgument(
              "Missing tensor wanted by more than one batch.");
        }
      }
    } else {
      // If we don't have a valid input tensor we can output an empty tensor
      // and call our done closure.
      TensorShape output_shape(grad_t.shape());
      output_shape.set_dim(0, 0);
      Tensor* output = nullptr;
      TF_RETURN_IF_ERROR(context->allocate_output(0, output_shape, &output));
      done();
    }

    // Search to see whether our tensor is desired by any existing batch.
    auto desire_it = desired_tensor_to_batch_map_.find(batch_key);
    if (desire_it != desired_tensor_to_batch_map_.end()) {
      // Mark our tensor as no longer missing.
      auto batch_it = available_batches_.find(desire_it->second);
      desired_tensor_to_batch_map_.erase(desire_it);
      if (batch_it == available_batches_.end()) {
        return errors::InvalidArgument("Batch no longer exists.");
      }
      batch_it->second.missing_tensors.erase(batch_key);
      // If all tensors are available we should concatenate them and dispatch
      // the batch.
      if (batch_it->second.missing_tensors.empty()) {
        TF_RETURN_IF_ERROR(
            OutputBatch(batch_it->second.context, batch_it->second.done));
        available_batches_.erase(batch_it);
      }
    }
    return Status::OK();
  }

 private:
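  // Protects the bookkeeping maps below.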
  mutex mu_;

  // Represents a still-incomplete batch of tensors. When all tensors become
  // available they will be concatenated in the right order and sent through
  // the context.
  struct Batch {
    // Batch keys for tensors which are still missing from this batch. When
    // this is empty the Tensors can be concatenated and forwarded.
    std::unordered_set<int64> missing_tensors;

    // Context and callback for the session responsible for finishing this
    // batch.
    OpKernelContext* context;
    AsyncOpKernel::DoneCallback done;
  };

  // Map from batch key of the session which will output the batched gradients
  // to still-incomplete batches.
  std::unordered_map<int64, Batch> available_batches_;

  // Map from batch key to tensors which are waiting for their batches to be
  // available.
  std::unordered_map<int64, Tensor> available_tensors_;

  // Map from batch key of a tensor which is not yet available to the batch key
  // of the batch to which it belongs.
  std::unordered_map<int64, int64> desired_tensor_to_batch_map_;
};

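// AsyncOpKernel for the `UnbatchGrad` op: looks up (or creates) the shared
// UnbatchGradResource and delegates the work to it.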
class UnbatchGradKernel : public AsyncOpKernel {
 public:
  explicit UnbatchGradKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
    OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
    // If shared_name is not supplied, use name instead (prevent collisions by
    // default).
    if (shared_name_.empty()) {
      shared_name_ = name();
    }
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
    UnbatchGradResource* ubr;
    std::function<Status(UnbatchGradResource**)> creator =
        [](UnbatchGradResource** r) {
          *r = new UnbatchGradResource();
          return Status::OK();
        };
    OP_REQUIRES_OK_ASYNC(c,
                         c->resource_manager()->LookupOrCreate(
                             container_, shared_name_, &ubr, creator),
                         done);
    Status status = ubr->Compute(c, done);
    ubr->Unref();
    OP_REQUIRES_OK_ASYNC(c, status, done);
    // Assume ubr calls done, so nothing to do here.
  }

 private:
  string container_;
  string shared_name_;
};
REGISTER_KERNEL_BUILDER(Name("UnbatchGrad").Device(DEVICE_CPU),
                        UnbatchGradKernel);

}  // namespace tensorflow