/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/strings/str_cat.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/batch_resource_base.h"
#include "tensorflow/core/kernels/batching_util/concat_split_util.h"
#include "tensorflow/core/kernels/batching_util/periodic_function.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/numbers.h"
#include "tensorflow/core/platform/threadpool.h"

namespace tensorflow {
namespace {
// Op attributes.
constexpr char kEnableAdaptiveSchedulerAttr[] = "_enable_adaptive_scheduler";
constexpr char kMinInflightBatchesAttr[] = "_min_inflight_batches";
constexpr char kInitialInflightBatchesAttr[] = "_initial_inflight_batches";
constexpr char kMaxInflightBatchesAttr[] = "_max_inflight_batches";
constexpr char kBatchesToAverageOverAttr[] = "_batches_to_average_over";

// Per-model inflight batches parameters.
constexpr int64_t kMinInflightBatches = 16;
constexpr int64_t kInitialInflightBatches = 16;
constexpr int64_t kBatchesToAverageOver = 10;
constexpr int64_t kMaxInflightBatches = 64;

// The max number of threads in the per-process thread pool, shared by
// all executions of batch-op.
constexpr int64_t kBatchThreadPoolSize = 128;
}  // namespace

auto* batch_op_split_usage = monitoring::Gauge<string, 1>::New(
    "/tensorflow/serving/batching/enable_large_batch_splitting",
    "Tracks the usage of attribute `enable_large_batch_splitting` for "
    "BatchFunction kernel in a saved model.",
    "model_name");

void RecordBatchSplitUsage(
    absl::optional<bool> maybe_enable_large_batch_splitting,
    const string& model_name) {
  if (maybe_enable_large_batch_splitting.has_value()) {
    if (maybe_enable_large_batch_splitting.value()) {
      batch_op_split_usage->GetCell(model_name)->Set("true");
    } else {
      batch_op_split_usage->GetCell(model_name)->Set("false");
    }
  } else {
    batch_op_split_usage->GetCell(model_name)->Set("unset");
  }
}

void RecordBatchParamNumBatchThreads(int64_t num_batch_threads,
                                     const string& model_name) {
  static auto* cell = monitoring::Gauge<int64, 1>::New(
      "/tensorflow/serving/batching/num_batch_threads",
      "Tracks the number of batch threads of a model.", "model_name");
  cell->GetCell(model_name)->Set(num_batch_threads);
}

const string& GetModelName(OpKernelContext* ctx) {
  static string* kModelNameUnset = new string("model_name_unset");
  if (!ctx->session_metadata()) return *kModelNameUnset;
  if (ctx->session_metadata()->name().empty()) return *kModelNameUnset;
  return ctx->session_metadata()->name();
}

using ::tensorflow::concat_split_util::Concat;
using ::tensorflow::concat_split_util::Split;

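// Returns a process-wide batch thread pool, creating it on first use. Because
// the pool is a function-local static, `thread_name` and `num_batch_threads`
// only take effect on the first call; later calls return the same pool
// regardless of the arguments passed.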
static thread::ThreadPool* GetOrCreateBatchThreadsPool(
    const string& thread_name, int num_batch_threads) {
  static thread::ThreadPool* pool =
      new thread::ThreadPool(Env::Default(), thread_name, num_batch_threads);
  return pool;
}

// A class encapsulating the state and logic for batching tensors.
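// It provides two factory overloads: one backed by the (non-adaptive) shared
// batch scheduler (`BatcherT`), and one backed by the adaptive shared batch
// scheduler (`AdaptiveBatcherT`), which adjusts the number of in-flight
// batches at run time.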
class BatchResource : public serving::BatchResourceBase {
 public:
  static Status Create(int32_t num_batch_threads,
                       int32_t max_execution_batch_size,
                       int32_t batch_timeout_micros,
                       int32_t max_enqueued_batches,
                       const std::vector<int32>& allowed_batch_sizes,
                       FunctionLibraryRuntime::Handle fhandle,
                       FunctionLibraryRuntime* flib,
                       bool enable_large_batch_splitting,
                       std::unique_ptr<BatchResource>* resource) {
    BatcherT::Options batcher_options;
    batcher_options.num_batch_threads = num_batch_threads;
    std::shared_ptr<BatcherT> batcher;
    TF_RETURN_IF_ERROR(BatcherT::Create(batcher_options, &batcher));

    resource->reset(new BatchResource(
        fhandle, flib, std::move(batcher),
        GetBatcherQueueOptions(num_batch_threads, max_execution_batch_size,
                               batch_timeout_micros, max_enqueued_batches,
                               allowed_batch_sizes,
                               enable_large_batch_splitting),
        allowed_batch_sizes));
    return Status::OK();
  }

  static Status Create(
      AdaptiveBatcherT::Options adaptive_shared_batch_scheduler_options,
      int32_t max_batch_size, int32_t batch_timeout_micros,
      int32_t max_enqueued_batches,
      const std::vector<int32>& allowed_batch_sizes,
      FunctionLibraryRuntime::Handle fhandle, FunctionLibraryRuntime* flib,
      std::unique_ptr<BatchResource>* resource) {
    std::shared_ptr<AdaptiveBatcherT> batcher;
    TF_RETURN_IF_ERROR(AdaptiveBatcherT::Create(
        adaptive_shared_batch_scheduler_options, &batcher));

    resource->reset(new BatchResource(
        fhandle, flib, std::move(batcher),
        GetAdaptiveBatcherQueueOptions(
            max_batch_size, batch_timeout_micros, max_enqueued_batches,
            true /* enable large batch split */, allowed_batch_sizes),
        allowed_batch_sizes));
    return Status::OK();
  }

  string DebugString() const final { return "BatchResource"; }

 private:
  BatchResource(FunctionLibraryRuntime::Handle fhandle,
                FunctionLibraryRuntime* flib, std::shared_ptr<BatcherT> batcher,
                const BatcherT::QueueOptions& batcher_queue_options,
                std::vector<int32> allowed_batch_sizes)
      : BatchResourceBase(
            /*has_process_batch_function=*/fhandle != kInvalidHandle,
            std::move(batcher), batcher_queue_options,
            std::move(allowed_batch_sizes)),
        fhandle_(fhandle),
        flib_(flib) {}

  BatchResource(FunctionLibraryRuntime::Handle fhandle,
                FunctionLibraryRuntime* flib,
                std::shared_ptr<AdaptiveBatcherT> batcher,
                const AdaptiveBatcherT::QueueOptions& batcher_queue_options,
                std::vector<int32> allowed_batch_sizes)
      : BatchResourceBase(
            /*has_process_batch_function=*/fhandle != kInvalidHandle,
            std::move(batcher), batcher_queue_options,
            std::move(allowed_batch_sizes)),
        fhandle_(fhandle),
        flib_(flib) {}

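  // Runs the instantiated batch function on the concatenated `inputs`, reusing
  // the OpKernelContext of the last task in the batch for per-step resources
  // (step container, cancellation manager, runner, etc.). Blocks the calling
  // batch thread until the function finishes; see the note at the end of this
  // method for why.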
  void ProcessFuncBatchImpl(
      const BatchTask& last_task, absl::Span<const Tensor> inputs,
      std::vector<Tensor>* combined_outputs,
      std::function<void(const Status&)> done) const override {
    auto* last_task_context = last_task.context;
    FunctionLibraryRuntime::Options opts;
    opts.step_container = last_task_context->step_container();
    opts.cancellation_manager = last_task_context->cancellation_manager();
    opts.collective_executor = last_task_context->collective_executor();
    opts.stats_collector = last_task_context->stats_collector();
    opts.runner = last_task_context->runner();
    opts.run_all_kernels_inline = last_task_context->run_all_kernels_inline();
    // We do not set 'opts.rendezvous', since if the function is run multiple
    // times in parallel with the same rendezvous, a _Send node from one run
    // might be matched with a _Recv node of a different run. Not setting the
    // rendezvous causes a new rendezvous to be used for each run.
    Notification done_notif;

    flib_->Run(opts, fhandle_, inputs, combined_outputs,
               [&](const Status& run_status) {
                 done(run_status);
                 done_notif.Notify();
               });
    // By waiting for the notification we are ensuring that this thread isn't
    // used for processing other batches, which gives the batches time to
    // coalesce upstream. So overall the number of batches going through the
    // devices goes down, improving latency and throughput in most cases.
    done_notif.WaitForNotification();
  }

  FunctionLibraryRuntime::Handle fhandle_;
  FunctionLibraryRuntime* flib_;
};

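// Kernel for the BatchFunction op. Concurrent invocations that share a
// BatchResource (looked up by `container` and `shared_name`) have their inputs
// batched together; the function `f` runs once per assembled batch, and its
// outputs are split back to the individual invocations.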
class BatchFunctionKernel : public AsyncOpKernel {
 public:
  explicit BatchFunctionKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
    OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
    OP_REQUIRES_OK(c, c->GetAttr("batching_queue", &batcher_queue_));
    OP_REQUIRES_OK(c, c->GetAttr("num_batch_threads", &num_batch_threads_));
    OP_REQUIRES_OK(c, c->GetAttr("max_batch_size", &max_batch_size_));
    OP_REQUIRES_OK(c,
                   c->GetAttr("batch_timeout_micros", &batch_timeout_micros_));
    OP_REQUIRES_OK(c,
                   c->GetAttr("max_enqueued_batches", &max_enqueued_batches_));
    OP_REQUIRES_OK(c, c->GetAttr("allowed_batch_sizes", &allowed_batch_sizes_));

    OP_REQUIRES_OK(c, c->GetAttr("f", &func_));
    flib_ = c->function_library();

    if (c->HasAttr("enable_large_batch_splitting")) {
      OP_REQUIRES_OK(c, c->GetAttr("enable_large_batch_splitting",
                                   &enable_large_batch_splitting_));
      has_attribute_enable_large_batch_splitting_ = true;
    } else {
      enable_large_batch_splitting_ = false;
      has_attribute_enable_large_batch_splitting_ = false;
    }

    // The helper `SetAdaptiveBatchSchedulerOptions` uses `OP_REQUIRES_OK`,
    // which only returns from the helper itself on error, so explicitly check
    // the op-kernel-construction status here before continuing.
    SetAdaptiveBatchSchedulerOptions(c, num_batch_threads_);
    if (!c->status().ok()) {
      return;
    }

    if (enable_adaptive_batch_threads_) {
      // One scheduler instance can contain multiple queues; `batcher_queue_`
      // is the key used to find the queue for this batch op in the graph.
      // Use name() and `shared_name_` as a prefix for `batcher_queue_`.
      // Note that name() is unique per session (from session metadata).
      batcher_queue_ = name() + "/" + shared_name_ + batcher_queue_;
    }

    if (shared_name_.empty()) {
      // If shared_name is not supplied, use name instead (prevent collisions by
      // default).
      shared_name_ = name();
    }

    OP_REQUIRES_OK(c, ValidateAllowedBatchSizes());
  }

  bool IsExpensive() override { return false; }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
    RecordBatchSplitUsage(
        has_attribute_enable_large_batch_splitting_
            ? absl::make_optional(enable_large_batch_splitting_)
            : absl::nullopt,
        GetModelName(c));
    // TODO(b/173255290): Add num_batch_threads_ parameter to TFRT batch kernel.
    RecordBatchParamNumBatchThreads(num_batch_threads_, GetModelName(c));

    std::function<Status(BatchResource**)> creator;

    FunctionLibraryRuntime::Handle handle;
    OP_REQUIRES_OK_ASYNC(c, GetOrCreateFunctionHandle(c, &handle), done);

    if (adaptive_batch_scheduler_options_ != absl::nullopt) {
      creator = [this, handle](BatchResource** r) {
        serving::AdaptiveSharedBatchScheduler<
            serving::BatchResourceBase::BatchTask>::Options
            adaptive_shared_batch_scheduler_options;
        adaptive_shared_batch_scheduler_options.thread_pool_name =
            "adaptive_batch_threads";
        adaptive_shared_batch_scheduler_options.num_batch_threads =
            adaptive_batch_scheduler_options_->max_in_flight_batches_limit;
        adaptive_shared_batch_scheduler_options.thread_pool =
            GetOrCreateBatchThreadsPool(std::string("adaptive_batch_threads"),
                                        kBatchThreadPoolSize);
        // `adaptive_shared_batch_scheduler_options.full_batch_scheduling_boost_micros`
        // is intentionally left at 0 (the default), so tasks are scheduled in
        // FIFO order. Two rationales for using the default (zero) value:
        // 1) The scheduling policy is then FIFO. Compared with round robin
        //    (what the shared batch scheduler does), FIFO ensures that models
        //    with low QPS (i.e., models that enqueue fewer tasks in the shared
        //    queue) are still processed in a timely manner.
        // 2) If set, `full_batch_scheduling_boost_micros` should be on the
        //    order of the batch processing latency, which varies per model; a
        //    poorly chosen non-zero value hurts tail latency.
        adaptive_shared_batch_scheduler_options.min_in_flight_batches_limit =
            adaptive_batch_scheduler_options_->min_in_flight_batches_limit;
        adaptive_shared_batch_scheduler_options
            .initial_in_flight_batches_limit =
            adaptive_batch_scheduler_options_->initial_in_flight_batches_limit;
        adaptive_shared_batch_scheduler_options.batches_to_average_over =
            adaptive_batch_scheduler_options_->batches_to_average_over;
        adaptive_shared_batch_scheduler_options.fifo_scheduling = true;
        std::unique_ptr<BatchResource> new_resource;
        TF_RETURN_IF_ERROR(BatchResource::Create(
            adaptive_shared_batch_scheduler_options, max_batch_size_,
            batch_timeout_micros_, max_enqueued_batches_, allowed_batch_sizes_,
            handle, flib_, &new_resource));
        *r = new_resource.release();
        return Status::OK();
      };
    } else {
      creator = [this, handle](BatchResource** r) {
        std::unique_ptr<BatchResource> new_resource;
        TF_RETURN_IF_ERROR(BatchResource::Create(
            num_batch_threads_, max_batch_size_, batch_timeout_micros_,
            max_enqueued_batches_, allowed_batch_sizes_, handle, flib_,
            enable_large_batch_splitting_, &new_resource));
        *r = new_resource.release();
        return Status::OK();
      };
    }

    BatchResource* br;
    OP_REQUIRES_OK_ASYNC(c,
                         c->resource_manager()->LookupOrCreate(
                             container_, shared_name_, &br, creator),
                         done);
    const Status status =
        br->RegisterInput(random::New64(), c, batcher_queue_, done);
    br->Unref();
    OP_REQUIRES_OK_ASYNC(c, status, done);
    // Assume br calls done, so nothing to do here.
  }

  Status InstantiateFunction(OpKernelContext* c,
                             FunctionLibraryRuntime::Handle* handle) const {
    // TODO(b/173748062): Merge this instantiation logic with PartitionedCall.
    if (!flib_) {
      return errors::Internal("No function library");
    }

    FunctionLibraryRuntime::InstantiateOptions opts;
    opts.target = flib_->device() == nullptr ? "" : flib_->device()->name();
    opts.is_multi_device_function = true;
    const ConfigProto* config = flib_->config_proto();
    if (config) {
      opts.config_proto = *config;
    }

    Device* cpu_device;
    TF_RETURN_IF_ERROR(flib_->device_mgr()->LookupDevice("CPU:0", &cpu_device));

    const FunctionDef* fdef =
        flib_->GetFunctionLibraryDefinition()->Find(func_.name());
    if (!fdef) {
      return errors::NotFound("Failed to find definition for function \"",
                              func_.name(), "\"");
    }
    OpInputList in_tensors;
    TF_RETURN_IF_ERROR(c->input_list("in_tensors", &in_tensors));
    for (int i = 0; i < in_tensors.size(); i++) {
      if (in_tensors[i].dtype() == DT_RESOURCE) {
        return errors::InvalidArgument(
            "BatchFunction cannot take resource inputs but input ", i,
            " is a resource.");
      } else {
        // Currently, inputs are on CPU since they are concatenated on CPU.
        opts.input_devices.push_back(cpu_device->name());
      }
    }
    OpInputList captured_tensors;
    TF_RETURN_IF_ERROR(c->input_list("captured_tensors", &captured_tensors));
    for (const Tensor& t : captured_tensors) {
      if (t.dtype() == DT_RESOURCE) {
        const ResourceHandle& rhandle = t.flat<ResourceHandle>()(0);
        opts.input_devices.push_back(rhandle.device());
      } else {
        opts.input_devices.push_back(cpu_device->name());
      }
    }
    const OpDef& signature = fdef->signature();
    for (int i = 0; i < signature.output_arg_size(); i++) {
      // Currently, outputs must be on CPU since they are split on CPU.
      opts.output_devices.push_back(cpu_device->name());
    }
    if (opts.input_devices.size() != signature.input_arg_size()) {
      return errors::InvalidArgument(
          "Function takes ", signature.input_arg_size(), " argument(s) but ",
          opts.input_devices.size(), " argument(s) were passed");
    }
    return flib_->Instantiate(func_.name(), AttrSlice(&func_.attr()), opts,
                              handle);
  }

  Status GetOrCreateFunctionHandle(OpKernelContext* c,
                                   FunctionLibraryRuntime::Handle* handle) {
    mutex_lock ml(mu_);
    if (!fhandle_) {
      TF_RETURN_IF_ERROR(InstantiateFunction(c, handle));
      fhandle_ = *handle;
    } else {
      *handle = fhandle_.value();
    }
    return Status::OK();
  }

  // Validates 'allowed_batch_sizes_'. The entries must increase monotonically.
  // If large batch splitting is not enabled, the last entry must equal
  // `max_batch_size_`; otherwise the last entry must be less than or equal to
  // `max_batch_size_`.
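  // For example, with `max_batch_size_ = 32` and splitting disabled,
  // {8, 16, 32} is valid while {8, 16} and {16, 8, 32} are rejected.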
  Status ValidateAllowedBatchSizes() const {
    if (allowed_batch_sizes_.empty()) {
      return Status::OK();
    }
    int32_t last_size = 0;
    for (size_t i = 0; i < allowed_batch_sizes_.size(); ++i) {
      const int32_t size = allowed_batch_sizes_.at(i);
      if (i > 0 && size <= last_size) {
        return errors::InvalidArgument(
            "allowed_batch_sizes entries must be monotonically increasing");
      }

      if ((!enable_large_batch_splitting_) &&
          (i == allowed_batch_sizes_.size() - 1) && (size != max_batch_size_)) {
        return errors::InvalidArgument(
            "final entry in allowed_batch_sizes must equal max_batch_size when "
            "enable_large_batch_splitting is False");
      }

      last_size = size;
    }
    return Status::OK();
  }

 private:
  // Initializes member variables by reading op-kernel-construction attributes:
  // - enable_adaptive_batch_threads_
  //   True if the attribute `kEnableAdaptiveSchedulerAttr` is set to true, or
  //   if `num_batch_threads` is not positive.
  // - adaptive_batch_scheduler_options_
  //   Populated from the corresponding attributes whenever they are set.
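  // For example (illustrative): a BatchFunction node carrying
  //   _enable_adaptive_scheduler = true
  //   _min_inflight_batches = 16
  //   _initial_inflight_batches = 16
  //   _max_inflight_batches = 64
  //   _batches_to_average_over = 10
  // enables the adaptive scheduler with those limits; any attribute left unset
  // falls back to the defaults defined at the top of this file.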
  void SetAdaptiveBatchSchedulerOptions(OpKernelConstruction* c,
                                        int32_t num_batch_threads) {
    if (c->HasAttr(kEnableAdaptiveSchedulerAttr)) {
      OP_REQUIRES_OK(c, c->GetAttr(kEnableAdaptiveSchedulerAttr,
                                   &enable_adaptive_batch_threads_));
    }

    if (num_batch_threads <= 0) {
      enable_adaptive_batch_threads_ = true;
    }

    if (!enable_adaptive_batch_threads_) {
      // adaptive_batch_scheduler_options_ is nullopt.
      return;
    }

    // adaptive_batch_scheduler_options_ is not nullopt.
    AdaptiveBatchSchedulerOptions options;

    if (c->HasAttr(kBatchesToAverageOverAttr)) {
      OP_REQUIRES_OK(c, c->GetAttr(kBatchesToAverageOverAttr,
                                   &options.batches_to_average_over));
    }

    if (c->HasAttr(kMinInflightBatchesAttr)) {
      OP_REQUIRES_OK(c, c->GetAttr(kMinInflightBatchesAttr,
                                   &options.min_in_flight_batches_limit));
    }

    if (c->HasAttr(kInitialInflightBatchesAttr)) {
      OP_REQUIRES_OK(c, c->GetAttr(kInitialInflightBatchesAttr,
                                   &options.initial_in_flight_batches_limit));
    }

    if (c->HasAttr(kMaxInflightBatchesAttr)) {
      OP_REQUIRES_OK(c, c->GetAttr(kMaxInflightBatchesAttr,
                                   &options.max_in_flight_batches_limit));
    }

    adaptive_batch_scheduler_options_ = options;
  }

  string container_;
  string shared_name_;
  string batcher_queue_;
  int32 num_batch_threads_;
  int32 max_batch_size_;
  int32 batch_timeout_micros_;
  int32 max_enqueued_batches_;
  std::vector<int32> allowed_batch_sizes_;
  NameAttrList func_;
  absl::optional<FunctionLibraryRuntime::Handle> fhandle_ TF_GUARDED_BY(mu_);
  FunctionLibraryRuntime* flib_;
  bool enable_large_batch_splitting_;
  bool has_attribute_enable_large_batch_splitting_;
  bool enable_adaptive_batch_threads_ = false;
  mutex mu_;

  // Parameters for the adaptive batch scheduler only. Note that
  // 'num_batch_threads_' above is shared by both batch scheduler
  // implementations.
  struct AdaptiveBatchSchedulerOptions {
    int32 min_in_flight_batches_limit = kMinInflightBatches;
    int32 initial_in_flight_batches_limit = kInitialInflightBatches;
    int32 max_in_flight_batches_limit = kMaxInflightBatches;
    int32 batches_to_average_over = kBatchesToAverageOver;
  };
  absl::optional<AdaptiveBatchSchedulerOptions>
      adaptive_batch_scheduler_options_ = absl::nullopt;
};

REGISTER_KERNEL_BUILDER(Name("BatchFunction").Device(DEVICE_CPU),
                        BatchFunctionKernel);
// Currently all inputs and outputs are on the host.
// TODO(b/173748277): Accept inputs/outputs on the device.
REGISTER_KERNEL_BUILDER(Name("BatchFunction")
                            .Device(DEVICE_GPU)
                            .HostMemory("in_tensors")
                            .HostMemory("captured_tensors")
                            .HostMemory("out_tensors"),
                        BatchFunctionKernel);

class BatchKernel : public AsyncOpKernel {
 public:
  explicit BatchKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
    OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
    // If shared_name is not supplied, use name instead (prevent collisions by
    // default).
    if (shared_name_.empty()) {
      shared_name_ = name();
    }
    OP_REQUIRES_OK(c, c->GetAttr("batching_queue", &batcher_queue_));
    OP_REQUIRES_OK(c, c->GetAttr("num_batch_threads", &num_batch_threads_));
    OP_REQUIRES_OK(c, c->GetAttr("max_batch_size", &max_batch_size_));
    OP_REQUIRES_OK(c,
                   c->GetAttr("batch_timeout_micros", &batch_timeout_micros_));
    OP_REQUIRES_OK(c,
                   c->GetAttr("max_enqueued_batches", &max_enqueued_batches_));
    OP_REQUIRES_OK(c, c->GetAttr("allowed_batch_sizes", &allowed_batch_sizes_));
    OP_REQUIRES_OK(c, ValidateAllowedBatchSizes());
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
    BatchResource* br;
    std::function<Status(BatchResource**)> creator = [this](BatchResource** r) {
      std::unique_ptr<BatchResource> new_resource;
      TF_RETURN_IF_ERROR(BatchResource::Create(
          num_batch_threads_, max_batch_size_, batch_timeout_micros_,
          max_enqueued_batches_, allowed_batch_sizes_, kInvalidHandle,
          /*flib=*/nullptr, false, &new_resource));
      *r = new_resource.release();
      return Status::OK();
    };
    OP_REQUIRES_OK_ASYNC(c,
                         c->resource_manager()->LookupOrCreate(
                             container_, shared_name_, &br, creator),
                         done);
    const Status status =
        br->RegisterInput(random::New64(), c, batcher_queue_, done);
    br->Unref();
    OP_REQUIRES_OK_ASYNC(c, status, done);
    // Assume br calls done, so nothing to do here.
  }

  // Validates 'allowed_batch_sizes_'. The entries must increase
  // monotonically, and the last one must equal 'max_batch_size_'.
  Status ValidateAllowedBatchSizes() const {
    if (allowed_batch_sizes_.empty()) {
      return Status::OK();
    }
    int32_t last_size = 0;
    for (size_t i = 0; i < allowed_batch_sizes_.size(); ++i) {
      const int32_t size = allowed_batch_sizes_.at(i);
      if (i > 0 && size <= last_size) {
        return errors::InvalidArgument(
            "allowed_batch_sizes entries must be monotonically increasing");
      }
      if (i == allowed_batch_sizes_.size() - 1 && size != max_batch_size_) {
        return errors::InvalidArgument(
            "final entry in allowed_batch_sizes must equal max_batch_size");
      }
      last_size = size;
    }
    return Status::OK();
  }

 private:
  string container_;
  string shared_name_;
  string batcher_queue_;
  int32 num_batch_threads_;
  int32 max_batch_size_;
  int32 batch_timeout_micros_;
  int32 max_enqueued_batches_;
  std::vector<int32> allowed_batch_sizes_;
};

REGISTER_KERNEL_BUILDER(Name("Batch").Device(DEVICE_CPU), BatchKernel);

// A class encapsulating the state and logic for unbatching tensors.
//
// UnbatchResource keeps two data structures indexed by batch-key: one which has
// the continuations for all concurrent kernels which are waiting for tensors
// and another which has tensors which are waiting for their corresponding
// kernels to run. Whenever a kernel runs, we either grab its tensor if it's
// waiting already, or we insert it in the queue and then look at its tensor to
// see if it can be used to dispatch any stored continuations.
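//
// The index input is an [N, 3] int64 tensor; each row describes one original
// invocation as (batch_key, start_offset, end_offset), where the offsets index
// into the 0th dimension of the batched data tensor.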
class UnbatchResource : public ResourceBase {
 public:
  explicit UnbatchResource(int32_t timeout_micros)
      : timeout_micros_(timeout_micros),
        timeout_enforcer_(new serving::PeriodicFunction(
            [this] { EnforceTimeout(); }, 1000 /* 1 ms */)) {}

  ~UnbatchResource() override {
    // Tear down 'timeout_enforcer_' first, since it accesses other state in
    // this class.
    timeout_enforcer_ = nullptr;
  }

  string DebugString() const final { return "UnbatchResource"; }

  Status Compute(OpKernelContext* context, AsyncOpKernel::DoneCallback done) {
    const Tensor& data_t = context->input(0);
    const Tensor& batch_index_t = context->input(1);

    if (batch_index_t.shape().dim_size(0) > data_t.shape().dim_size(0)) {
      return errors::InvalidArgument(
          "Wrong shape for index tensor. Expected 0th dimension size to be no "
          "greater than ",
          data_t.shape().dim_size(0),
          "; Got: ", batch_index_t.shape().dim_size(0), ".");
    }
    if (batch_index_t.shape().dim_size(1) != 3) {
      return errors::InvalidArgument(
          "Wrong shape for index tensor. Expected 1st dimension size to be 3 ; "
          "Got: ",
          batch_index_t.shape().dim_size(1), ".");
    }

    const int64_t batch_key = context->input(2).scalar<int64>()();
    const bool nonempty_input = batch_index_t.dim_size(0) > 0;

    // If we have a non-empty tensor, slice it up.
    // (It is important to do this outside of the critical section below.)
    // The following variables are populated iff 'nonempty_input==true'.
    std::vector<int64> sizes;
    std::vector<int64> batch_keys;
    std::vector<Tensor> split_inputs;
    if (nonempty_input) {
      auto batch_indices =
          batch_index_t.shaped<int64, 2>({batch_index_t.dim_size(0), 3});
      for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
        sizes.push_back(batch_indices(i, 2) - batch_indices(i, 1));
        batch_keys.push_back(batch_indices(i, 0));
      }

      TF_RETURN_IF_ERROR(Split(context, data_t, sizes, &split_inputs));
    }

    // Critical section.
    std::vector<AsyncOpKernel::DoneCallback> done_callbacks_to_call;
    Status status = [&]() -> Status {
      mutex_lock ml(mu_);

      // Check to see whether the tensor we want is already ready.
      auto tensor_it = waiting_tensors_.find(batch_key);
      if (tensor_it != waiting_tensors_.end()) {
        context->set_output(0, tensor_it->second.tensor);
        waiting_tensors_.erase(tensor_it);
        done_callbacks_to_call.push_back(done);
        return Status::OK();
      }

      const uint64 deadline_micros =
          Env::Default()->NowMicros() + timeout_micros_;

      // Add ourselves to the waitlist for tensors.
      if (!waiting_callbacks_
               .emplace(batch_key,
                        WaitingCallback{deadline_micros, context, done})
               .second) {
        return errors::AlreadyExists(
            "Multiple session runs with the same batch key.");
      }

      // If we have a non-empty tensor, finish the waitlisted runs,
      // and store any remaining pieces.
      if (nonempty_input) {
        for (size_t i = 0; i < batch_keys.size(); ++i) {
          auto runs_it = waiting_callbacks_.find(batch_keys[i]);
          if (runs_it != waiting_callbacks_.end()) {
            runs_it->second.context->set_output(0, split_inputs[i]);
            done_callbacks_to_call.push_back(runs_it->second.done);
            waiting_callbacks_.erase(runs_it);
          } else {
            // Note: the deadline here is in case we are arriving late and the
            // kernel that should rendezvous with this tensor has already waited
            // and timed out.
            if (!waiting_tensors_
                     .emplace(batch_keys[i],
                              WaitingTensor{deadline_micros, split_inputs[i]})
                     .second) {
              return errors::AlreadyExists(
                  "Multiple tensors returned for same batch key.");
            }
          }
        }
      }

      return Status::OK();
    }();

    for (const AsyncOpKernel::DoneCallback& done_callback :
         done_callbacks_to_call) {
      done_callback();
    }

    return status;
  }

 private:
  // Evicts waiting tensors and callbacks that have exceeded their deadline.
  void EnforceTimeout() {
    const uint64 now = Env::Default()->NowMicros();
    std::vector<WaitingCallback> evicted_callbacks;

    {
      mutex_lock ml(mu_);

      for (auto it = waiting_tensors_.begin(); it != waiting_tensors_.end();) {
        const WaitingTensor& waiting_tensor = it->second;
        if (waiting_tensor.deadline_micros < now) {
          it = waiting_tensors_.erase(it);
        } else {
          ++it;
        }
      }

      for (auto it = waiting_callbacks_.begin();
           it != waiting_callbacks_.end();) {
        const WaitingCallback& waiting_callback = it->second;
        if (waiting_callback.deadline_micros < now) {
          evicted_callbacks.push_back(waiting_callback);
          it = waiting_callbacks_.erase(it);
        } else {
          ++it;
        }
      }
    }

    for (const WaitingCallback& evicted_callback : evicted_callbacks) {
      evicted_callback.context->CtxFailureWithWarning(errors::DeadlineExceeded(
          "Batched data did not arrive within timeout window."));
      evicted_callback.done();
    }
  }

  struct WaitingTensor {
    uint64 deadline_micros;
    Tensor tensor;
  };

  struct WaitingCallback {
    uint64 deadline_micros;
    OpKernelContext* context;
    AsyncOpKernel::DoneCallback done;
  };

  const int32 timeout_micros_;

  mutex mu_;

  // Maps keyed by BatchKey of tensors waiting for callbacks and callbacks
  // waiting for tensors.
  std::unordered_map<int64, WaitingTensor> waiting_tensors_ TF_GUARDED_BY(mu_);
  std::unordered_map<int64, WaitingCallback> waiting_callbacks_
      TF_GUARDED_BY(mu_);

  // A thread that evicts waiting tensors and callbacks that have exceeded their
  // deadline.
  std::unique_ptr<serving::PeriodicFunction> timeout_enforcer_;
};

class UnbatchKernel : public AsyncOpKernel {
 public:
  explicit UnbatchKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
    OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
    // If shared_name is not supplied, use name instead (prevent collisions by
    // default).
    if (shared_name_.empty()) {
      shared_name_ = name();
    }
    OP_REQUIRES_OK(c, c->GetAttr("timeout_micros", &timeout_micros_));
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
    UnbatchResource* ubr;
    std::function<Status(UnbatchResource**)> creator =
        [this](UnbatchResource** r) {
          *r = new UnbatchResource(timeout_micros_);
          return Status::OK();
        };
    OP_REQUIRES_OK_ASYNC(c,
                         c->resource_manager()->LookupOrCreate(
                             container_, shared_name_, &ubr, creator),
                         done);
    auto status = ubr->Compute(c, done);
    ubr->Unref();
    OP_REQUIRES_OK_ASYNC(c, status, done);
    // Assume ubr calls done, so nothing to do here.
  }

 private:
  string container_;
  string shared_name_;
  int32 timeout_micros_;
};
REGISTER_KERNEL_BUILDER(Name("Unbatch").Device(DEVICE_CPU), UnbatchKernel);

// A class encapsulating the state and logic for batching tensors
// deterministically for the gradient of unbatch.
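//
// Each invocation contributes the gradient for its batch key via
// `available_tensors_`. The invocation that carries a non-empty batch_index
// registers a Batch listing the keys it still needs; once every listed
// gradient has arrived, the pieces are concatenated in batch_index order and
// emitted through that invocation's context.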
class UnbatchGradResource : public ResourceBase {
 public:
  UnbatchGradResource() {}

  string DebugString() const final { return "UnbatchGradResource"; }

  // Flushes the information for one batch, given its context and done
  // callback. Clears all information about it from the available_tensors_.
  Status OutputBatch(OpKernelContext* context,
                     const AsyncOpKernel::DoneCallback& done)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    const Tensor& batch_index_t = context->input(1);
    auto batch_index =
        batch_index_t.shaped<int64, 2>({batch_index_t.dim_size(0), 3});
    std::vector<Tensor> tensors;
    for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
      auto available_it = available_tensors_.find(batch_index(i, 0));
      if (available_it == available_tensors_.end()) {
        return errors::Internal("bad bookkeeping of available tensors.");
      }
      tensors.push_back(available_it->second);
      available_tensors_.erase(available_it);
    }

    const DataType type = tensors[0].dtype();
    Tensor concatenated_tensor;
    switch (type) {
#define CASE(type)                                                            \
  case DataTypeToEnum<type>::value:                                           \
    TF_RETURN_IF_ERROR(Concat<type>(context, tensors, &concatenated_tensor)); \
    context->set_output(0, concatenated_tensor);                              \
    break;
      TF_CALL_ALL_TYPES(CASE);
#undef CASE
      default:
        return errors::InvalidArgument("Unsupported data type: ", type);
    }
    done();
    return Status::OK();
  }

  // Ingests data from one invocation of the op.
  Status Compute(OpKernelContext* context,
                 const AsyncOpKernel::DoneCallback& done) {
    const Tensor& data_t = context->input(0);
    const Tensor& batch_index_t = context->input(1);
    const Tensor& grad_t = context->input(2);

    mutex_lock ml(mu_);

    const int64_t batch_key = context->input(3).scalar<int64>()();
    // Mark our tensor as available.
    if (!available_tensors_.emplace(batch_key, grad_t).second) {
      return errors::InvalidArgument("Two runs with the same batch key.");
    }

    // Check whether we have a valid input tensor and, if so, create its
    // dispatch logic.
    if (data_t.NumElements() > 0) {
      if (batch_index_t.NumElements() == 0) {
        return errors::InvalidArgument(
            "batch_index is empty while the tensor isn't.");
      }
      std::unordered_set<int64> missing_tensors;
      const auto batch_index =
          batch_index_t.shaped<int64, 2>({batch_index_t.dim_size(0), 3});
      for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
        const int64_t batch_key = batch_index(i, 0);
        if (available_tensors_.find(batch_key) == available_tensors_.end()) {
          missing_tensors.emplace(batch_key);
        }
      }
      if (missing_tensors.empty()) {
        return OutputBatch(context, done);
      }
      if (!available_batches_
               .emplace(batch_key, Batch{missing_tensors, context, done})
               .second) {
        return errors::InvalidArgument(
            "Batch key with valid batch used twice.");
      }
      for (const int64_t i : missing_tensors) {
        if (!desired_tensor_to_batch_map_.emplace(i, batch_key).second) {
          return errors::InvalidArgument(
              "Missing tensor wanted by more than one batch.");
        }
      }
    } else {
      // If we don't have a valid input tensor we can output an empty tensor
      // and call our done closure.
      TensorShape output_shape(grad_t.shape());
      output_shape.set_dim(0, 0);
      Tensor* output = nullptr;
      TF_RETURN_IF_ERROR(context->allocate_output(0, output_shape, &output));
      done();
    }

    // Search to see whether our tensor is desired by any existing batch.
    auto desire_it = desired_tensor_to_batch_map_.find(batch_key);
    if (desire_it != desired_tensor_to_batch_map_.end()) {
      // Mark our tensor as no longer missing.
      auto batch_it = available_batches_.find(desire_it->second);
      desired_tensor_to_batch_map_.erase(desire_it);
      if (batch_it == available_batches_.end()) {
        return errors::InvalidArgument("Batch no longer exists.");
      }
      batch_it->second.missing_tensors.erase(batch_key);
      // If all tensors are available we should concatenate them and dispatch
      // the batch.
      if (batch_it->second.missing_tensors.empty()) {
        TF_RETURN_IF_ERROR(
            OutputBatch(batch_it->second.context, batch_it->second.done));
        available_batches_.erase(batch_it);
      }
    }
    return Status::OK();
  }

 private:
  mutex mu_;

  // Represents a still-incomplete batch of tensors. When all tensors become
  // available they will be concatenated in the right order and sent through
  // the context.
  struct Batch {
    // Batch keys for tensors which are still missing from this batch. When
    // this is empty the Tensors can be concatenated and forwarded.
    std::unordered_set<int64> missing_tensors;

    // Context and callback for the session responsible for finishing this
    // batch.
    OpKernelContext* context;
    AsyncOpKernel::DoneCallback done;
  };

  // Map from batch key of the session which will output the batched gradients
  // to still-incomplete batches.
  std::unordered_map<int64, Batch> available_batches_;

  // Map from batch key to tensors which are waiting for their batches to be
  // available.
  std::unordered_map<int64, Tensor> available_tensors_;

  // Map from batch key of a tensor which is not yet available to the batch key
  // of the batch to which it belongs.
  std::unordered_map<int64, int64> desired_tensor_to_batch_map_;
};

class UnbatchGradKernel : public AsyncOpKernel {
 public:
  explicit UnbatchGradKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
    OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
    // If shared_name is not supplied, use name instead (prevent collisions by
    // default).
    if (shared_name_.empty()) {
      shared_name_ = name();
    }
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
    UnbatchGradResource* ubr;
    std::function<Status(UnbatchGradResource**)> creator =
        [](UnbatchGradResource** r) {
          *r = new UnbatchGradResource();
          return Status::OK();
        };
    OP_REQUIRES_OK_ASYNC(c,
                         c->resource_manager()->LookupOrCreate(
                             container_, shared_name_, &ubr, creator),
                         done);
    Status status = ubr->Compute(c, done);
    ubr->Unref();
    OP_REQUIRES_OK_ASYNC(c, status, done);
    // Assume ubr calls done, so nothing to do here.
  }

 private:
  string container_;
  string shared_name_;
};
REGISTER_KERNEL_BUILDER(Name("UnbatchGrad").Device(DEVICE_CPU),
                        UnbatchGradKernel);

}  // namespace tensorflow