/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/batching_util/periodic_function.h"
#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/split_lib.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

// Concatenates 'inputs' into a single tensor along the zeroth dimension.
// Requires that all elements of 'inputs' have element type T. Allocates
// 'output' using 'context' to ensure proper device placement, and writes the
// concatenated result to it.
template <typename T>
Status Concat(OpKernelContext* context, const gtl::ArraySlice<Tensor>& inputs,
              Tensor* output) {
  const int input_dims = inputs[0].dims();
  const TensorShape& input_shape = inputs[0].shape();

  // Note that we reduce the concat of k-dimensional tensors into a two
  // dimensional concat. Assuming the dimensions of any input tensor are
  // {y0, y1,...,ym-1}, we flatten it to {1, y}, where y = Prod_i(yi).
  std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>> inputs_flat;
  inputs_flat.reserve(inputs.size());
  int64 output_dim0 = 0;
  for (size_t i = 0; i < inputs.size(); ++i) {
    const Tensor& input = inputs[i];
    if (input.dims() != input_dims) {
      return errors::InvalidArgument(
          "Ranks of all input tensors should match: shape[0] = ",
          input_shape.DebugString(), " vs. shape[", i,
          "] = ", input.shape().DebugString());
    }
    for (int j = 1; j < input_dims; ++j) {
      if (input.dim_size(j) != input_shape.dim_size(j)) {
        return errors::InvalidArgument(
            "Dimensions of inputs should match: shape[0] = ",
            input_shape.DebugString(), " vs. shape[", i,
            "] = ", input.shape().DebugString());
      }
    }
    if (input.NumElements() > 0) {
      inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
          input.shaped<T, 2>({1, input.NumElements()})));
    }
    output_dim0 += input.dim_size(0);
  }

  TensorShape output_shape(input_shape);
  output_shape.set_dim(0, output_dim0);
  TF_RETURN_IF_ERROR(
      context->allocate_temp(DataTypeToEnum<T>::value, output_shape, output));
  if (output->NumElements() > 0) {
    auto output_flat = output->shaped<T, 2>({1, output->NumElements()});
#if GOOGLE_CUDA
    if (std::is_same<Device, GPUDevice>::value) {
      ConcatGPU<T>(context, inputs_flat, output, &output_flat);
      return Status::OK();
    }
#endif  // GOOGLE_CUDA
    ConcatCPU<T>(context->device(), inputs_flat, &output_flat);
  }

  return Status::OK();
}

// The Split*() functions split 'input' with element type T into 'sizes.size()'
// tensors along the zeroth dimension, with the ith split having zeroth-
// dimension size 'sizes[i]'. They allocate the output tensors using 'context',
// for proper device placement.

// Handles special cases that are cheap. Sets 'done==true' iff it found an
// applicable special case and wrote to the outputs. Otherwise acts as a no-op.
template <typename T>
Status SplitEasyCases(OpKernelContext* context, const Tensor& input,
                      const gtl::ArraySlice<int64>& sizes,
                      std::vector<Tensor>* outputs, bool* done) {
  *done = false;

  int64 total_size = 0;
  for (const int64 size : sizes) {
    total_size += size;
  }
  if (total_size > input.shape().dim_size(0)) {
    return errors::InvalidArgument(
        "Sum of split sizes must not exceed dim0-size of input tensor");
  }

  // Special case 0: trivial 1-way split.
  if (sizes.size() == 1 && sizes.at(0) == input.shape().dim_size(0)) {
    outputs->push_back(input);
    *done = true;
    return Status::OK();
  }

  // Special case 1: input is aligned.
  if (IsInnerDimsSizeAligned<T>(input.shape())) {
    int64 position = 0;
    for (const int64 size : sizes) {
      outputs->emplace_back(input.Slice(position, position + size));
      position += size;
    }
    *done = true;
    return Status::OK();
  }

  return Status::OK();
}

// Handles the general case, on CPU.
template <typename T>
Status SplitCPU(OpKernelContext* context, const Tensor& input,
                const gtl::ArraySlice<int64>& sizes,
                std::vector<Tensor>* outputs) {
  int64 suffix_dim_size = 1;
  for (int i = 1; i < input.shape().dims(); ++i) {
    suffix_dim_size *= input.shape().dim_size(i);
  }
  auto input_reshaped =
      input.shaped<T, 2>({input.shape().dim_size(0), suffix_dim_size});

  int64 position = 0;
  for (const int64 size : sizes) {
    TensorShape output_shape = input.shape();
    output_shape.set_dim(0, size);
    Tensor output;
    TF_RETURN_IF_ERROR(
        context->allocate_temp(input.dtype(), output_shape, &output));
    auto output_shaped = output.shaped<T, 2>({size, suffix_dim_size});

    Eigen::DSizes<Eigen::DenseIndex, 2> slice_indices{position, 0};
    Eigen::DSizes<Eigen::DenseIndex, 2> slice_sizes{size, suffix_dim_size};
    functor::Split<CPUDevice, T, 2>()(context->eigen_device<CPUDevice>(),
                                      output_shaped, input_reshaped,
                                      slice_indices, slice_sizes);

    outputs->emplace_back(output);

    position += size;
  }

  return Status::OK();
}

#if GOOGLE_CUDA

// Handles the general case, on GPU.
template <typename T>
Status SplitGPU(OpKernelContext* context, const Tensor& input,
                const gtl::ArraySlice<int64>& sizes,
                std::vector<Tensor>* outputs) {
  // TODO(olston, apassos): Implement this.
  LOG(FATAL) << "Not yet implemented";  // Crash ok
}

#endif  // GOOGLE_CUDA

// The outer function that dispatches to the various Split*() functions above.
template <typename T>
Status Split(OpKernelContext* context, const Tensor& input,
             const gtl::ArraySlice<int64>& sizes,
             std::vector<Tensor>* outputs) {
  bool easy_cases_done;
  TF_RETURN_IF_ERROR(
      SplitEasyCases<T>(context, input, sizes, outputs, &easy_cases_done));
  if (easy_cases_done) {
    return Status::OK();
  }

#if GOOGLE_CUDA
  // TODO(olston, apassos): Handle non-CPU cases.
  // return SplitGPU<T>(context, input, sizes, outputs);
#endif  // GOOGLE_CUDA
  return SplitCPU<T>(context, input, sizes, outputs);
}

// A class encapsulating the state and logic for batching tensors.
class BatchResource : public ResourceBase {
 public:
  static Status Create(int32 num_batch_threads, int32 max_batch_size,
                       int32 batch_timeout_micros, int32 max_enqueued_batches,
                       const std::vector<int32>& allowed_batch_sizes,
                       FunctionLibraryRuntime::Handle fhandle,
                       std::unique_ptr<BatchResource>* resource) {
    std::unique_ptr<BatchResource> new_resource(new BatchResource);

    Batcher::Options batcher_options;
    batcher_options.num_batch_threads = num_batch_threads;
    TF_RETURN_IF_ERROR(
        Batcher::Create(batcher_options, &new_resource->batcher_));

    new_resource->batcher_queue_options_.max_batch_size = max_batch_size;
    new_resource->batcher_queue_options_.max_enqueued_batches =
        max_enqueued_batches;
    new_resource->batcher_queue_options_.batch_timeout_micros =
        batch_timeout_micros;

    new_resource->allowed_batch_sizes_ = allowed_batch_sizes;

    new_resource->fhandle_ = fhandle;

    *resource = std::move(new_resource);
    return Status::OK();
  }

  string DebugString() const final { return "BatchResource"; }

  // Ingests data from one invocation of the batch op. The data is enqueued to
  // be combined with others into a batch, asynchronously.
  Status RegisterInput(int64 guid, OpKernelContext* context,
                       const string& batcher_queue_name,
                       AsyncOpKernel::DoneCallback done_callback) {
    std::unique_ptr<BatchTask> batch_components(new BatchTask);
    batch_components->guid = guid;
    OpInputList tensors;
    TF_RETURN_IF_ERROR(context->input_list("in_tensors", &tensors));
    for (int i = 0; i < tensors.size(); ++i) {
      const Tensor& tensor = tensors[i];
      if (tensor.shape().dims() == 0) {
        return errors::InvalidArgument(
            "Batching input tensors must have at least one dimension");
      }
      if (tensors.size() >= 2 &&
          tensor.shape().dim_size(0) != tensors[0].shape().dim_size(0)) {
        return errors::InvalidArgument(
            "Batching input tensors supplied in a given op invocation must "
            "have equal 0th-dimension size");
      }
      batch_components->inputs.push_back(tensor);
    }
    OpInputList captured_tensors;
    const auto captured_status =
        context->input_list("captured_tensors", &captured_tensors);
    if (captured_status.ok()) {
      for (const Tensor& captured_tensor : captured_tensors) {
        batch_components->captured_inputs.push_back(captured_tensor);
      }
    }
    batch_components->context = context;
    batch_components->done_callback = std::move(done_callback);

    BatcherQueue* batcher_queue;
    TF_RETURN_IF_ERROR(
        LookupOrCreateBatcherQueue(batcher_queue_name, &batcher_queue));
    return batcher_queue->Schedule(&batch_components);
  }

 private:
  BatchResource() = default;

  // One input to be batched. Corresponds to one invocation of the batch op.
  struct BatchTask : public serving::BatchTask {
    // A unique ID to identify this invocation of Batch.
    int64 guid;

    std::vector<Tensor> inputs;
    std::vector<Tensor> captured_inputs;
    OpKernelContext* context;
    AsyncOpKernel::DoneCallback done_callback;

    size_t size() const override { return inputs[0].shape().dim_size(0); }
  };

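  // Aliases for the batching machinery: 'Batcher' is the shared scheduler that
  // owns the batching threads, 'BatcherQueue' is one logical queue within it
  // (one per batcher_queue_name), and 'Batch' is a group of tasks that is
  // processed together.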
  using Batcher = serving::SharedBatchScheduler<BatchTask>;
  using BatcherQueue = serving::BatchScheduler<BatchTask>;
  using Batch = serving::Batch<BatchTask>;

  // Validates that it's legal to combine the tasks in 'batch' into a batch.
  // Assumes the batch is non-empty.
  static Status ValidateBatch(const Batch& batch) {
    for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
      const BatchTask& task = batch.task(task_idx);

      if (task.inputs.size() != batch.task(0).inputs.size()) {
        return errors::InvalidArgument(
            "Batching inputs must have equal number of edges");
      }
    }

    return Status::OK();
  }

  // Returns the smallest entry in 'allowed_batch_sizes_' that is greater than
  // or equal to 'batch_size'. If 'allowed_batch_sizes_' is empty, simply
  // returns 'batch_size'.
  int RoundToLowestAllowedBatchSize(int batch_size) const {
    if (allowed_batch_sizes_.empty()) {
      return batch_size;
    }
    for (int allowed_size : allowed_batch_sizes_) {
      if (allowed_size >= batch_size) {
        return allowed_size;
      }
    }
    LOG(ERROR) << "Maximum batch size greater than largest allowed size; "
                  "ignoring allowed sizes constraint";
    return batch_size;
  }

  Status ConcatInputTensors(const Batch& batch, OpKernelContext* context,
                            std::vector<Tensor>* concatenated_tensors) const {
    if (batch.num_tasks() == 0) {
      return errors::InvalidArgument("Empty batch.");
    }

    const int padded_batch_size = RoundToLowestAllowedBatchSize(batch.size());
    const int padding_amount = padded_batch_size - batch.size();

    // All tasks should have the same number of input edges.
    const int num_inputs = batch.task(0).inputs.size();
    concatenated_tensors->reserve(num_inputs);

    // Process each input one at a time (the typical case has just one).
    for (int i = 0; i < num_inputs; ++i) {
      // Concatenate the tasks' ith input tensors into a big output tensor.
      std::vector<Tensor> to_concatenate;
      to_concatenate.reserve(batch.num_tasks());
      for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
        to_concatenate.push_back(batch.task(task_idx).inputs.at(i));
      }

      // Add padding as needed. Use the first row of the first task's tensor
      // as the data for padding.
      if (padding_amount > 0) {
        const Tensor& padding_source = batch.task(0).inputs.at(i);
        Tensor padding;
        if (padding_source.shape().dim_size(0) == 1) {
          padding = padding_source;
        } else {
          const std::vector<int64> slice_sizes = {1};
          const DataType type = padding_source.dtype();
          Status slice_status;
          std::vector<Tensor> slices;
          switch (type) {
#define CASE(type)                                                     \
  case DataTypeToEnum<type>::value:                                    \
    slice_status =                                                     \
        SplitCPU<type>(context, padding_source, slice_sizes, &slices); \
    break;
            TF_CALL_ALL_TYPES(CASE);
#undef CASE
            default:
              slice_status =
                  errors::InvalidArgument("Unsupported data type: ", type);
              break;
          }
          TF_RETURN_IF_ERROR(slice_status);
          padding = slices.at(0);
        }
        for (int row = 0; row < padding_amount; ++row) {
          to_concatenate.push_back(padding);
        }
      }

      const DataType type = to_concatenate[0].dtype();
      Status concat_status;
      Tensor concatenated_tensor;
      switch (type) {
#define CASE(type)                                                    \
  case DataTypeToEnum<type>::value:                                   \
    concat_status =                                                   \
        Concat<type>(context, to_concatenate, &concatenated_tensor);  \
    break;
        TF_CALL_ALL_TYPES(CASE);
#undef CASE
        default:
          concat_status =
              errors::InvalidArgument("Unsupported data type: ", type);
          break;
      }
      TF_RETURN_IF_ERROR(concat_status);
      concatenated_tensors->push_back(concatenated_tensor);
    }
    return Status::OK();
  }

  Status SplitOutputTensors(const std::vector<Tensor>& combined_outputs,
                            Batch* batch) const {
    DCHECK_GE(batch->num_tasks(), 1);
    if (batch->num_tasks() < 1) {
      return errors::Internal("Batch size expected to be positive; was ",
                              batch->num_tasks());
    }

    std::vector<int64> task_sizes_plus_optional_padding;
    task_sizes_plus_optional_padding.reserve(batch->num_tasks());
    for (int i = 0; i < batch->num_tasks(); ++i) {
      task_sizes_plus_optional_padding.push_back(batch->task(i).size());
    }
    const int padding_size =
        RoundToLowestAllowedBatchSize(batch->size()) - batch->size();
    if (padding_size > 0) {
      task_sizes_plus_optional_padding.push_back(padding_size);
    }

    // For each output tensor name, a divided-up tensor with one entry per
    // task.
    std::map<string, std::vector<Tensor>> split_tensors;

    DCHECK_EQ(batch->task(0).context->num_outputs(), combined_outputs.size());
    if (combined_outputs.size() != batch->task(0).context->num_outputs()) {
      return errors::Internal("Wrong number of batched output tensors");
    }

    // Generate 'split_tensors' and populate the context outputs.
    for (int i = 0; i < combined_outputs.size(); ++i) {
      const Tensor& output_tensor = combined_outputs[i];
      if (output_tensor.shape().dims() == 0) {
        return errors::FailedPrecondition(
            "Batched output tensor has 0 dimensions");
      }
      if (output_tensor.shape().dim_size(0) != batch->size() + padding_size) {
        return errors::FailedPrecondition(
            "Batched output tensor's 0th dimension does not equal the sum of "
            "the 0th dimension sizes of the input tensors");
      }

      std::vector<Tensor> split_tensor;
      const Status split_status = tensor::Split(
          output_tensor, task_sizes_plus_optional_padding, &split_tensor);
      DCHECK(split_status.ok()) << split_status.ToString();
      if (!split_status.ok()) {
        return errors::Internal("Tensor split operation failed: ",
                                split_status.ToString());
      }
      DCHECK_EQ(split_tensor.size(), task_sizes_plus_optional_padding.size());
      if (split_tensor.size() != task_sizes_plus_optional_padding.size()) {
        return errors::Internal(
            "Tensor split operation did not work as expected; got ",
            split_tensor.size(), " splits; expected ",
            task_sizes_plus_optional_padding.size());
      }

      for (int j = 0; j < batch->num_tasks(); ++j) {
        BatchTask& task = *(batch->mutable_task(j));
        task.context->set_output(i, split_tensor.at(j));
      }  // (Ignore a possible final split_tensors entry containing the
         // padding.)
    }

    return Status::OK();
  }

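  // Processes a batch by concatenating the tasks' inputs, running the batch
  // function on the combined tensors, splitting the outputs back out to the
  // individual tasks, and signaling each task's done callback.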
  void ProcessFuncBatch(std::unique_ptr<Batch> batch) const {
    if (batch->empty()) {
      return;
    }

    OpKernelContext* last_task_context =
        batch->task(batch->num_tasks() - 1).context;

    // Regardless of the outcome, we need to propagate the status to the
    // individual tasks and signal that they are done. We use MakeCleanup() to
    // ensure that this happens no matter how we exit the method below.
    Status status;
    bool cleanup_done = false;
    auto cleanup_fn = [&cleanup_done, &batch](const Status& status) {
      if (cleanup_done) {
        return;
      }
      for (int i = 0; i < batch->num_tasks(); ++i) {
        batch->mutable_task(i)->context->SetStatus(status);
        batch->mutable_task(i)->done_callback();
      }
      cleanup_done = true;
    };
    auto finally =
        gtl::MakeCleanup([&cleanup_fn, &status] { cleanup_fn(status); });

    status = ValidateBatch(*batch);
    if (!status.ok()) {
      return;
    }

    std::vector<Tensor> concatenated_tensors;
    status =
        ConcatInputTensors(*batch, last_task_context, &concatenated_tensors);
    if (!status.ok()) {
      return;
    }
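
    // Run the batched computation via the instantiated function handle,
    // forwarding per-step state (step container, cancellation manager,
    // rendezvous, etc.) from the last task's context.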
    FunctionLibraryRuntime::Options opts;
    opts.step_id = last_task_context->step_id();
    opts.step_container = last_task_context->step_container();
    opts.cancellation_manager = last_task_context->cancellation_manager();
    opts.stats_collector = last_task_context->stats_collector();
    opts.rendezvous = last_task_context->rendezvous();
    opts.runner = last_task_context->runner();

    auto* flib = last_task_context->function_library();
    std::vector<Tensor> combined_outputs;
    Notification done;
    std::vector<Tensor> args(concatenated_tensors.begin(),
                             concatenated_tensors.end());
    const auto& captured_inputs =
        batch->task(batch->num_tasks() - 1).captured_inputs;
    args.insert(args.end(), captured_inputs.begin(), captured_inputs.end());

    // Releases the cleanup method here, because the callback of the function
    // library runtime will handle it now.
    finally.release();
    flib->Run(
        opts, fhandle_, args, &combined_outputs, [&](const Status& run_status) {
          Status final_status;
          auto run_finally = gtl::MakeCleanup([&]() {
            // We do the cleanup here as an optimization, so that it runs in
            // the underlying TF inter-op threadpool. Running it in the
            // threadpool lets the ensuing ops be scheduled faster, because
            // the executor will add them to the front of the threadpool's
            // task queue rather than the end.
            cleanup_fn(final_status);
            done.Notify();
          });
          final_status = run_status;
          if (!final_status.ok()) {
            return;
          }
          final_status = SplitOutputTensors(combined_outputs, batch.get());
        });
    // By waiting for the notification we are ensuring that this thread isn't
    // used for processing other batches, which gives the batches time to
    // coalesce upstream. So overall the number of batches going through the
    // devices goes down, improving latency and throughput in most cases.
    done.WaitForNotification();
  }

  // Processes a batch of one or more BatchTask entries.
  void ProcessBatch(std::unique_ptr<Batch> batch) const {
    if (batch->empty()) {
      return;
    }

    OpKernelContext* last_task_context =
        batch->task(batch->num_tasks() - 1).context;
    AsyncOpKernel::DoneCallback last_task_callback =
        batch->task(batch->num_tasks() - 1).done_callback;

    OP_REQUIRES_OK_ASYNC(last_task_context, ValidateBatch(*batch),
                         last_task_callback);

    // All tasks should have the same number of input edges.
    const int num_input_edges = batch->task(0).inputs.size();
    std::vector<Tensor> concatenated_tensors;
    const Status concat_status =
        ConcatInputTensors(*batch, last_task_context, &concatenated_tensors);
    OP_REQUIRES_OK_ASYNC(last_task_context, concat_status, last_task_callback);

    // Process each input edge one at a time (the typical case has just one).
    for (int i = 0; i < num_input_edges; ++i) {
      last_task_context->set_output(i, concatenated_tensors.at(i));

      // Emit batch->num_tasks() - 1 empty output tensors.
      for (int task_idx = 0; task_idx < batch->num_tasks() - 1; ++task_idx) {
        const BatchTask& task = batch->task(task_idx);
        TensorShape output_shape(task.inputs.at(i).shape());
        output_shape.set_dim(0, 0);
        Tensor* output = nullptr;
        OP_REQUIRES_OK_ASYNC(
            task.context,
            task.context->allocate_output(i, output_shape, &output),
            task.done_callback);
      }
    }
    // Emit batch->num_tasks() - 1 empty index tensors.
    for (int task_idx = 0; task_idx < batch->num_tasks() - 1; ++task_idx) {
      const BatchTask& task = batch->task(task_idx);
      TensorShape index_shape({0, 3});
      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          task.context,
          task.context->allocate_output(num_input_edges, index_shape, &output),
          task.done_callback);
    }
    // Emit all ID tensors.
    for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) {
      const BatchTask& task = batch->task(task_idx);
      Tensor* id;
      OP_REQUIRES_OK_ASYNC(task.context,
                           task.context->allocate_output(num_input_edges + 1,
                                                         TensorShape({}), &id),
                           task.done_callback);
      id->scalar<int64>()() = task.guid;
    }
    OP_REQUIRES_OK_ASYNC(
        last_task_context,
        EmitIndexTensor(last_task_context, *batch, num_input_edges),
        last_task_callback);

    // Signal done for each element of the batch. (At this point, the contexts
    // are no longer guaranteed to remain live.)
    for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) {
      batch->mutable_task(task_idx)->done_callback();
    }
  }

  // Emits an index tensor, which the Unbatch op will use to un-concatenate
  // the tensor and attribute the pieces to the right batch keys. The index
  // tensor contains, for each input: [batch_key, start_offset, end_offset]
  // where start_offset and end_offset represent the range of entries in the
  // concatenated tensors that belong to that input.
  //
  // Emits the result to the output at 'output_index' using 'context'.
  static Status EmitIndexTensor(OpKernelContext* context, const Batch& batch,
                                int output_index) {
    const TensorShape index_shape({batch.num_tasks(), 3});
    Tensor* index = nullptr;
    TF_RETURN_IF_ERROR(
        context->allocate_output(output_index, index_shape, &index));
    auto index_flat = index->shaped<int64, 2>({batch.num_tasks(), 3});
    size_t offset = 0;
    for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
      const BatchTask& task = batch.task(task_idx);
      index_flat(task_idx, 0) = task.guid;
      index_flat(task_idx, 1) = offset;
      index_flat(task_idx, 2) = offset + task.size();
      offset += task.size();
    }
    return Status::OK();
  }

  // Looks up the batcher queue for 'queue_name'. If it didn't previously
  // exist, creates it.
  Status LookupOrCreateBatcherQueue(const string& queue_name,
                                    BatcherQueue** queue) {
    mutex_lock l(batcher_queues_mu_);

    auto it = batcher_queues_.find(queue_name);
    if (it != batcher_queues_.end()) {
      *queue = it->second.get();
      return Status::OK();
    }

    std::unique_ptr<BatcherQueue> new_queue;
    auto process_batch_callback = [this](std::unique_ptr<Batch> batch) {
      if (fhandle_ == kInvalidHandle) {
        ProcessBatch(std::move(batch));
      } else {
        ProcessFuncBatch(std::move(batch));
      }
    };
    TF_RETURN_IF_ERROR(batcher_->AddQueue(batcher_queue_options_,
                                          process_batch_callback, &new_queue));
    *queue = new_queue.get();
    batcher_queues_[queue_name] = std::move(new_queue);
    return Status::OK();
  }

  // A batch scheduler, and options for creating queues.
  std::shared_ptr<Batcher> batcher_;
  Batcher::QueueOptions batcher_queue_options_;

  // A collection of batcher queues, keyed on queue name.
  // TODO(olston): Garbage-collect unused queues (perhaps simply remove empty
  // ones (with a time delay?); it's okay if they get recreated later).
  mutable mutex batcher_queues_mu_;
  std::map<string, std::unique_ptr<BatcherQueue>> batcher_queues_
      GUARDED_BY(batcher_queues_mu_);

  std::vector<int32> allowed_batch_sizes_;
  FunctionLibraryRuntime::Handle fhandle_;
};

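// Kernel for the 'BatchFunction' op: batches inputs across concurrent op
// invocations, runs the captured function 'f' once on the combined batch, and
// splits the function's outputs back out to the individual invocations.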
class BatchFunctionKernel : public AsyncOpKernel {
 public:
  explicit BatchFunctionKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
    OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
    // If shared_name is not supplied, use name instead (prevent collisions by
    // default).
    if (shared_name_.empty()) {
      shared_name_ = name();
    }
    OP_REQUIRES_OK(c, c->GetAttr("batching_queue", &batcher_queue_));
    OP_REQUIRES_OK(c, c->GetAttr("num_batch_threads", &num_batch_threads_));
    OP_REQUIRES_OK(c, c->GetAttr("max_batch_size", &max_batch_size_));
    OP_REQUIRES_OK(c,
                   c->GetAttr("batch_timeout_micros", &batch_timeout_micros_));
    OP_REQUIRES_OK(c,
                   c->GetAttr("max_enqueued_batches", &max_enqueued_batches_));
    OP_REQUIRES_OK(c, c->GetAttr("allowed_batch_sizes", &allowed_batch_sizes_));
    OP_REQUIRES_OK(c, ValidateAllowedBatchSizes());

    auto lib = c->function_library();
    OP_REQUIRES(c, lib != nullptr, errors::Internal("No function library"));
    NameAttrList func;
    OP_REQUIRES_OK(c, c->GetAttr("f", &func));
    OP_REQUIRES_OK(
        c, lib->Instantiate(func.name(), AttrSlice(&func.attr()), &fhandle_));
  }

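  // ComputeAsync() only looks up the shared BatchResource and enqueues the
  // input, so report the kernel as inexpensive and let the executor run it
  // inline.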
  bool IsExpensive() override { return false; }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
    BatchResource* br;
    std::function<Status(BatchResource**)> creator = [this](BatchResource** r) {
      std::unique_ptr<BatchResource> new_resource;
      TF_RETURN_IF_ERROR(
          BatchResource::Create(num_batch_threads_, max_batch_size_,
                                batch_timeout_micros_, max_enqueued_batches_,
                                allowed_batch_sizes_, fhandle_, &new_resource));
      *r = new_resource.release();
      return Status::OK();
    };
    OP_REQUIRES_OK_ASYNC(c,
                         c->resource_manager()->LookupOrCreate(
                             container_, shared_name_, &br, creator),
                         done);
    const Status status =
        br->RegisterInput(random::New64(), c, batcher_queue_, done);
    br->Unref();
    OP_REQUIRES_OK_ASYNC(c, status, done);
    // Assume br calls done, so nothing to do here.
  }

  // Validates 'allowed_batch_sizes_'. The entries must increase monotonically,
  // and the last one must equal 'max_batch_size_'.
  Status ValidateAllowedBatchSizes() const {
    if (allowed_batch_sizes_.empty()) {
      return Status::OK();
    }
    int32 last_size = 0;
    for (size_t i = 0; i < allowed_batch_sizes_.size(); ++i) {
      const int32 size = allowed_batch_sizes_.at(i);
      if (i > 0 && size <= last_size) {
        return errors::InvalidArgument(
            "allowed_batch_sizes entries must be monotonically increasing");
      }
      if (i == allowed_batch_sizes_.size() - 1 && size != max_batch_size_) {
        return errors::InvalidArgument(
            "final entry in allowed_batch_sizes must equal max_batch_size");
      }
      last_size = size;
    }
    return Status::OK();
  }

 private:
  string container_;
  string shared_name_;
  string batcher_queue_;
  int32 num_batch_threads_;
  int32 max_batch_size_;
  int32 batch_timeout_micros_;
  int32 max_enqueued_batches_;
  std::vector<int32> allowed_batch_sizes_;
  FunctionLibraryRuntime::Handle fhandle_;
};

REGISTER_KERNEL_BUILDER(Name("BatchFunction").Device(DEVICE_CPU),
                        BatchFunctionKernel);

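// Kernel for the 'Batch' op: groups the inputs of concurrent op invocations
// into a batch. One invocation (the last task in the batch) emits the
// concatenated tensors plus an index tensor for Unbatch; the others emit
// empty tensors. Every invocation emits its own batch ID.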
class BatchKernel : public AsyncOpKernel {
 public:
  explicit BatchKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
    OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
    // If shared_name is not supplied, use name instead (prevent collisions by
    // default).
    if (shared_name_.empty()) {
      shared_name_ = name();
    }
    OP_REQUIRES_OK(c, c->GetAttr("batching_queue", &batcher_queue_));
    OP_REQUIRES_OK(c, c->GetAttr("num_batch_threads", &num_batch_threads_));
    OP_REQUIRES_OK(c, c->GetAttr("max_batch_size", &max_batch_size_));
    OP_REQUIRES_OK(c,
                   c->GetAttr("batch_timeout_micros", &batch_timeout_micros_));
    OP_REQUIRES_OK(c,
                   c->GetAttr("max_enqueued_batches", &max_enqueued_batches_));
    OP_REQUIRES_OK(c, c->GetAttr("allowed_batch_sizes", &allowed_batch_sizes_));
    OP_REQUIRES_OK(c, ValidateAllowedBatchSizes());
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
    BatchResource* br;
    std::function<Status(BatchResource**)> creator = [this](BatchResource** r) {
      std::unique_ptr<BatchResource> new_resource;
      TF_RETURN_IF_ERROR(BatchResource::Create(
          num_batch_threads_, max_batch_size_, batch_timeout_micros_,
          max_enqueued_batches_, allowed_batch_sizes_, kInvalidHandle,
          &new_resource));
      *r = new_resource.release();
      return Status::OK();
    };
    OP_REQUIRES_OK_ASYNC(c,
                         c->resource_manager()->LookupOrCreate(
                             container_, shared_name_, &br, creator),
                         done);
    const Status status =
        br->RegisterInput(random::New64(), c, batcher_queue_, done);
    br->Unref();
    OP_REQUIRES_OK_ASYNC(c, status, done);
    // Assume br calls done, so nothing to do here.
  }

  // Validates 'allowed_batch_sizes_'. The entries must increase monotonically,
  // and the last one must equal 'max_batch_size_'.
  Status ValidateAllowedBatchSizes() const {
    if (allowed_batch_sizes_.empty()) {
      return Status::OK();
    }
    int32 last_size = 0;
    for (size_t i = 0; i < allowed_batch_sizes_.size(); ++i) {
      const int32 size = allowed_batch_sizes_.at(i);
      if (i > 0 && size <= last_size) {
        return errors::InvalidArgument(
            "allowed_batch_sizes entries must be monotonically increasing");
      }
      if (i == allowed_batch_sizes_.size() - 1 && size != max_batch_size_) {
        return errors::InvalidArgument(
            "final entry in allowed_batch_sizes must equal max_batch_size");
      }
      last_size = size;
    }
    return Status::OK();
  }

 private:
  string container_;
  string shared_name_;
  string batcher_queue_;
  int32 num_batch_threads_;
  int32 max_batch_size_;
  int32 batch_timeout_micros_;
  int32 max_enqueued_batches_;
  std::vector<int32> allowed_batch_sizes_;
};

REGISTER_KERNEL_BUILDER(Name("Batch").Device(DEVICE_CPU), BatchKernel);

// A class encapsulating the state and logic for unbatching tensors.
//
// UnbatchResource keeps two data structures indexed by batch-key: one which has
// the continuations for all concurrent kernels which are waiting for tensors
// and another which has tensors which are waiting for their corresponding
// kernels to run. Whenever a kernel runs, we either grab its tensor if it's
// waiting already, or we insert it in the queue and then look at its tensor to
// see if it can be used to dispatch any stored continuations.
class UnbatchResource : public ResourceBase {
 public:
  explicit UnbatchResource(int32 timeout_micros)
      : timeout_micros_(timeout_micros),
        timeout_enforcer_(new serving::PeriodicFunction(
            [this] { EnforceTimeout(); }, 1000 /* 1 ms */)) {}

  ~UnbatchResource() override {
    // Tear down 'timeout_enforcer_' first, since it accesses other state in
    // this class.
    timeout_enforcer_ = nullptr;
  }

  string DebugString() const final { return "UnbatchResource"; }

  Status Compute(OpKernelContext* context, AsyncOpKernel::DoneCallback done) {
    const Tensor& data_t = context->input(0);
    const Tensor& batch_index_t = context->input(1);

    if (batch_index_t.shape().dim_size(0) > data_t.shape().dim_size(0)) {
      return errors::InvalidArgument(
          "Wrong shape for index tensor. Expected 0th dimension size to be no "
          "greater than ",
          data_t.shape().dim_size(0),
          "; Got: ", batch_index_t.shape().dim_size(0), ".");
    }
    if (batch_index_t.shape().dim_size(1) != 3) {
      return errors::InvalidArgument(
          "Wrong shape for index tensor. Expected 1st dimension size to be 3; "
          "Got: ", batch_index_t.shape().dim_size(1), ".");
    }

    const int64 batch_key = context->input(2).scalar<int64>()();
    const bool nonempty_input = batch_index_t.dim_size(0) > 0;

    // If we have a non-empty tensor, slice it up.
    // (It is important to do this outside of the critical section below.)
    // The following variables are populated iff 'nonempty_input==true'.
    std::vector<int64> sizes;
    std::vector<int64> batch_keys;
    std::vector<Tensor> split_inputs;
    if (nonempty_input) {
      auto batch_indices =
          batch_index_t.shaped<int64, 2>({batch_index_t.dim_size(0), 3});
      for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
        sizes.push_back(batch_indices(i, 2) - batch_indices(i, 1));
        batch_keys.push_back(batch_indices(i, 0));
      }

      const DataType type = data_t.dtype();
      switch (type) {
#define CASE(type)                                                          \
  case DataTypeToEnum<type>::value:                                         \
    TF_RETURN_IF_ERROR(Split<type>(context, data_t, sizes, &split_inputs)); \
    break;
        TF_CALL_ALL_TYPES(CASE);
#undef CASE
        default:
          return errors::InvalidArgument("Unsupported data type: ", type);
      }
    }

    // Critical section.
    std::vector<AsyncOpKernel::DoneCallback> done_callbacks_to_call;
    Status status = [&]() -> Status {
      mutex_lock ml(mu_);

      // Check to see whether the tensor we want is already ready.
      auto tensor_it = waiting_tensors_.find(batch_key);
      if (tensor_it != waiting_tensors_.end()) {
        context->set_output(0, tensor_it->second.tensor);
        waiting_tensors_.erase(tensor_it);
        done_callbacks_to_call.push_back(done);
        return Status::OK();
      }

      const uint64 deadline_micros =
          Env::Default()->NowMicros() + timeout_micros_;

      // Add ourselves to the waitlist for tensors.
      if (!waiting_callbacks_
               .emplace(batch_key,
                        WaitingCallback{deadline_micros, context, done})
               .second) {
        return errors::AlreadyExists(
            "Multiple session runs with the same batch key.");
      }

      // If we have a non-empty tensor, finish the waitlisted runs,
      // and store any remaining pieces.
      if (nonempty_input) {
        for (size_t i = 0; i < batch_keys.size(); ++i) {
          auto runs_it = waiting_callbacks_.find(batch_keys[i]);
          if (runs_it != waiting_callbacks_.end()) {
            runs_it->second.context->set_output(0, split_inputs[i]);
            done_callbacks_to_call.push_back(runs_it->second.done);
            waiting_callbacks_.erase(runs_it);
          } else {
            // Note: the deadline here is in case we are arriving late and the
            // kernel that should rendezvous with this tensor has already
            // waited and timed out.
            if (!waiting_tensors_
                     .emplace(batch_keys[i],
                              WaitingTensor{deadline_micros, split_inputs[i]})
                     .second) {
              return errors::AlreadyExists(
                  "Multiple tensors returned for same batch key.");
            }
          }
        }
      }

      return Status::OK();
    }();

    for (const AsyncOpKernel::DoneCallback& done_callback :
         done_callbacks_to_call) {
      done_callback();
    }

    return status;
  }

 private:
  // Evicts waiting tensors and callbacks that have exceeded their deadline.
  void EnforceTimeout() {
    const uint64 now = Env::Default()->NowMicros();
    std::vector<WaitingCallback> evicted_callbacks;

    {
      mutex_lock ml(mu_);

      for (auto it = waiting_tensors_.begin(); it != waiting_tensors_.end();) {
        const WaitingTensor& waiting_tensor = it->second;
        if (waiting_tensor.deadline_micros < now) {
          it = waiting_tensors_.erase(it);
        } else {
          ++it;
        }
      }

      for (auto it = waiting_callbacks_.begin();
           it != waiting_callbacks_.end();) {
        const WaitingCallback& waiting_callback = it->second;
        if (waiting_callback.deadline_micros < now) {
          evicted_callbacks.push_back(waiting_callback);
          it = waiting_callbacks_.erase(it);
        } else {
          ++it;
        }
      }
    }

    for (const WaitingCallback& evicted_callback : evicted_callbacks) {
      evicted_callback.context->CtxFailureWithWarning(errors::DeadlineExceeded(
          "Batched data did not arrive within timeout window."));
      evicted_callback.done();
    }
  }

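  // A tensor that has arrived but whose corresponding Unbatch invocation has
  // not run yet; dropped once 'deadline_micros' has passed.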
  struct WaitingTensor {
    uint64 deadline_micros;
    Tensor tensor;
  };

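  // An Unbatch invocation that is waiting for its tensor to arrive; it is
  // failed with DeadlineExceeded once 'deadline_micros' has passed.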
  struct WaitingCallback {
    uint64 deadline_micros;
    OpKernelContext* context;
    AsyncOpKernel::DoneCallback done;
  };

  const int32 timeout_micros_;

  mutex mu_;

  // Maps keyed by BatchKey of tensors waiting for callbacks and callbacks
  // waiting for tensors.
  std::unordered_map<int64, WaitingTensor> waiting_tensors_ GUARDED_BY(mu_);
  std::unordered_map<int64, WaitingCallback> waiting_callbacks_ GUARDED_BY(mu_);

  // A thread that evicts waiting tensors and callbacks that have exceeded
  // their deadline.
  std::unique_ptr<serving::PeriodicFunction> timeout_enforcer_;
};

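// Kernel for the 'Unbatch' op: given the batched tensor, the index tensor, and
// a batch key, emits the slice of the batch belonging to that key, waiting up
// to 'timeout_micros' for it to become available if necessary.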
class UnbatchKernel : public AsyncOpKernel {
 public:
  explicit UnbatchKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
    OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
    // If shared_name is not supplied, use name instead (prevent collisions by
    // default).
    if (shared_name_.empty()) {
      shared_name_ = name();
    }
    OP_REQUIRES_OK(c, c->GetAttr("timeout_micros", &timeout_micros_));
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
    UnbatchResource* ubr;
    std::function<Status(UnbatchResource**)> creator =
        [this](UnbatchResource** r) {
          *r = new UnbatchResource(timeout_micros_);
          return Status::OK();
        };
    OP_REQUIRES_OK_ASYNC(c,
                         c->resource_manager()->LookupOrCreate(
                             container_, shared_name_, &ubr, creator),
                         done);
    auto status = ubr->Compute(c, done);
    ubr->Unref();
    OP_REQUIRES_OK_ASYNC(c, status, done);
    // Assume ubr calls done, so nothing to do here.
  }

 private:
  string container_;
  string shared_name_;
  int32 timeout_micros_;
};
REGISTER_KERNEL_BUILDER(Name("Unbatch").Device(DEVICE_CPU), UnbatchKernel);

// A class encapsulating the state and logic for batching tensors
// deterministically for the gradient of unbatch.
class UnbatchGradResource : public ResourceBase {
 public:
  UnbatchGradResource() {}

  string DebugString() const final { return "UnbatchGradResource"; }

  // Flushes the information for one batch, given its context and done
  // callback. Clears all information about it from the available_tensors_.
  Status OutputBatch(OpKernelContext* context,
                     const AsyncOpKernel::DoneCallback& done)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    const Tensor& batch_index_t = context->input(1);
    auto batch_index =
        batch_index_t.shaped<int64, 2>({batch_index_t.dim_size(0), 3});
    std::vector<Tensor> tensors;
    for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
      auto available_it = available_tensors_.find(batch_index(i, 0));
      if (available_it == available_tensors_.end()) {
        return errors::Internal("bad bookkeeping of available tensors.");
      }
      tensors.push_back(available_it->second);
      available_tensors_.erase(available_it);
    }

    const DataType type = tensors[0].dtype();
    Tensor concatenated_tensor;
    switch (type) {
#define CASE(type)                                                            \
  case DataTypeToEnum<type>::value:                                           \
    TF_RETURN_IF_ERROR(Concat<type>(context, tensors, &concatenated_tensor)); \
    context->set_output(0, concatenated_tensor);                              \
    break;
      TF_CALL_ALL_TYPES(CASE);
#undef CASE
      default:
        return errors::InvalidArgument("Unsupported data type: ", type);
    }
    done();
    return Status::OK();
  }

  // Ingests data from one invocation of the op.
  Status Compute(OpKernelContext* context,
                 const AsyncOpKernel::DoneCallback& done) {
    const Tensor& data_t = context->input(0);
    const Tensor& batch_index_t = context->input(1);
    const Tensor& grad_t = context->input(2);

    mutex_lock ml(mu_);

    const int64 batch_key = context->input(3).scalar<int64>()();
    // Mark our tensor as available.
    if (!available_tensors_.emplace(batch_key, grad_t).second) {
      return errors::InvalidArgument("Two runs with the same batch key.");
    }

    // Check whether we have a valid input tensor and, if so, create its
    // dispatch logic.
    if (data_t.NumElements() > 0) {
      if (batch_index_t.NumElements() == 0) {
        return errors::InvalidArgument(
            "batch_index is empty while the tensor isn't.");
      }
      std::unordered_set<int64> missing_tensors;
      const auto batch_index =
          batch_index_t.shaped<int64, 2>({batch_index_t.dim_size(0), 3});
      for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
        const int64 batch_key = batch_index(i, 0);
        if (available_tensors_.find(batch_key) == available_tensors_.end()) {
          missing_tensors.emplace(batch_key);
        }
      }
      if (missing_tensors.empty()) {
        return OutputBatch(context, done);
      }
      if (!available_batches_
               .emplace(batch_key, Batch{missing_tensors, context, done})
               .second) {
        return errors::InvalidArgument(
            "Batch key with valid batch used twice.");
      }
      for (const int64 i : missing_tensors) {
        if (!desired_tensor_to_batch_map_.emplace(i, batch_key).second) {
          return errors::InvalidArgument(
              "Missing tensor wanted by more than one batch.");
        }
      }
    } else {
      // If we don't have a valid input tensor we can output an empty tensor
      // and call our done closure.
      TensorShape output_shape(grad_t.shape());
      output_shape.set_dim(0, 0);
      Tensor* output = nullptr;
      TF_RETURN_IF_ERROR(context->allocate_output(0, output_shape, &output));
      done();
    }

    // Search to see whether our tensor is desired by any existing batch.
    auto desire_it = desired_tensor_to_batch_map_.find(batch_key);
    if (desire_it != desired_tensor_to_batch_map_.end()) {
      // Mark our tensor as no longer missing.
      auto batch_it = available_batches_.find(desire_it->second);
      desired_tensor_to_batch_map_.erase(desire_it);
      if (batch_it == available_batches_.end()) {
        return errors::InvalidArgument("Batch no longer exists.");
      }
      batch_it->second.missing_tensors.erase(batch_key);
      // If all tensors are available we should concatenate them and dispatch
      // the batch.
      if (batch_it->second.missing_tensors.empty()) {
        TF_RETURN_IF_ERROR(
            OutputBatch(batch_it->second.context, batch_it->second.done));
        available_batches_.erase(batch_it);
      }
    }
    return Status::OK();
  }

 private:
  mutex mu_;

  // Represents a still-incomplete batch of tensors. When all tensors become
  // available they will be concatenated in the right order and sent through
  // the context.
  struct Batch {
    // Batch keys for tensors which are still missing from this batch. When
    // this is empty the Tensors can be concatenated and forwarded.
    std::unordered_set<int64> missing_tensors;

    // Context and callback for the session responsible for finishing this
    // batch.
    OpKernelContext* context;
    AsyncOpKernel::DoneCallback done;
  };

  // Map from batch key of the session which will output the batched gradients
  // to still-incomplete batches.
  std::unordered_map<int64, Batch> available_batches_;

  // Map from batch key to tensors which are waiting for their batches to be
  // available.
  std::unordered_map<int64, Tensor> available_tensors_;

  // Map from batch key of a tensor which is not yet available to the batch key
  // of the batch to which it belongs.
  std::unordered_map<int64, int64> desired_tensor_to_batch_map_;
};

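// Kernel for the 'UnbatchGrad' op: accumulates gradient tensors by batch key
// in an UnbatchGradResource and, once all pieces of a batch have arrived,
// concatenates them back in the original batch order.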
class UnbatchGradKernel : public AsyncOpKernel {
 public:
  explicit UnbatchGradKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
    OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
    // If shared_name is not supplied, use name instead (prevent collisions by
    // default).
    if (shared_name_.empty()) {
      shared_name_ = name();
    }
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
    UnbatchGradResource* ubr;
    std::function<Status(UnbatchGradResource**)> creator =
        [](UnbatchGradResource** r) {
          *r = new UnbatchGradResource();
          return Status::OK();
        };
    OP_REQUIRES_OK_ASYNC(c,
                         c->resource_manager()->LookupOrCreate(
                             container_, shared_name_, &ubr, creator),
                         done);
    Status status = ubr->Compute(c, done);
    ubr->Unref();
    OP_REQUIRES_OK_ASYNC(c, status, done);
    // Assume ubr calls done, so nothing to do here.
  }

 private:
  string container_;
  string shared_name_;
};
REGISTER_KERNEL_BUILDER(Name("UnbatchGrad").Device(DEVICE_CPU),
                        UnbatchGradKernel);

}  // namespace tensorflow