/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/batching_util/periodic_function.h"
#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/split_lib.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL
// Concatenates 'inputs' into a single tensor along the zeroth dimension.
// Requires that all elements of 'inputs' have element type T. Writes the
// result to 'output', allocating it with 'context' to ensure proper device
// placement.
template <typename T>
Status Concat(OpKernelContext* context, const gtl::ArraySlice<Tensor>& inputs,
              Tensor* output) {
  const int input_dims = inputs[0].dims();
  const TensorShape& input_shape = inputs[0].shape();

  // Note that we reduce the concat of k-dimensional tensors into a two
  // dimensional concat. Assuming the dimensions of any input tensor are
  // {y0, y1,...,ym-1}, we flatten it to {1, y}, where y = Prod_i(yi).
  std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>> inputs_flat;
  inputs_flat.reserve(inputs.size());
  int64 output_dim0 = 0;
  for (size_t i = 0; i < inputs.size(); ++i) {
    const Tensor& input = inputs[i];
    if (input.dims() != input_dims) {
      return errors::InvalidArgument(
          "Ranks of all input tensors should match: shape[0] = ",
          input_shape.DebugString(), " vs. shape[", i,
          "] = ", input.shape().DebugString());
    }
    for (int j = 1; j < input_dims; ++j) {
      if (input.dim_size(j) != input_shape.dim_size(j)) {
        return errors::InvalidArgument(
            "Dimensions of inputs should match: shape[0] = ",
            input_shape.DebugString(), " vs. shape[", i,
            "] = ", input.shape().DebugString());
      }
    }
    if (input.NumElements() > 0) {
      inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
          input.shaped<T, 2>({1, input.NumElements()})));
    }
    output_dim0 += input.dim_size(0);
  }

  TensorShape output_shape(input_shape);
  output_shape.set_dim(0, output_dim0);
  TF_RETURN_IF_ERROR(
      context->allocate_temp(DataTypeToEnum<T>::value, output_shape, output));
  if (output->NumElements() > 0) {
    auto output_flat = output->shaped<T, 2>({1, output->NumElements()});
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
    if (std::is_same<Device, GPUDevice>::value) {
      ConcatGPU<T>(context, inputs_flat, output, &output_flat);
      return Status::OK();
    }
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    ConcatCPU<T>(context->device(), inputs_flat, &output_flat);
  }

  return Status::OK();
}
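
// Example (illustrative sketch, not exercised by this file): concatenating
// three tensors of shapes [2, 5], [3, 5], and [1, 5] yields a [6, 5] tensor.
// Internally each input is viewed as a {1, N} matrix ({1, 10}, {1, 15},
// {1, 5}) and the rows are written back to back. Assuming 'ctx' is a live
// OpKernelContext on a CPU device:
//
//   std::vector<Tensor> inputs = {t0, t1, t2};  // shapes [2,5], [3,5], [1,5]
//   Tensor out;
//   TF_CHECK_OK(Concat<float>(ctx, inputs, &out));
//   // out.shape() == TensorShape({6, 5})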

// The Split*() functions split 'input' with element type T into 'sizes.size()'
// tensors along the zeroth dimension, with the ith split having zeroth-
// dimension size 'sizes[i]'. They allocate the output tensors using 'context',
// for proper device placement.

// Handles special cases that are cheap. Sets 'done==true' iff it found an
// applicable special case and wrote to the outputs. Otherwise acts as a no-op.
template <typename T>
Status SplitEasyCases(OpKernelContext* context, const Tensor& input,
                      const gtl::ArraySlice<int64>& sizes,
                      std::vector<Tensor>* outputs, bool* done) {
  *done = false;

  int64 total_size = 0;
  for (const int64 size : sizes) {
    total_size += size;
  }
  if (total_size > input.shape().dim_size(0)) {
    return errors::InvalidArgument(
        "Sum of split sizes must not exceed dim0-size of input tensor");
  }

  // Special case 0: trivial 1-way split.
  if (sizes.size() == 1 && sizes.at(0) == input.shape().dim_size(0)) {
    outputs->push_back(input);
    *done = true;
    return Status::OK();
  }

  // Special case 1: input is aligned.
  if (IsInnerDimsSizeAligned<T>(input.shape())) {
    int64 position = 0;
    for (const int64 size : sizes) {
      outputs->emplace_back(input.Slice(position, position + size));
      position += size;
    }
    *done = true;
    return Status::OK();
  }

  return Status::OK();
}

// Handles the general case, on CPU.
template <typename T>
Status SplitCPU(OpKernelContext* context, const Tensor& input,
                const gtl::ArraySlice<int64>& sizes,
                std::vector<Tensor>* outputs) {
  int64 suffix_dim_size = 1;
  for (int i = 1; i < input.shape().dims(); ++i) {
    suffix_dim_size *= input.shape().dim_size(i);
  }
  auto input_reshaped =
      input.shaped<T, 2>({input.shape().dim_size(0), suffix_dim_size});

  int64 position = 0;
  for (const int64 size : sizes) {
    TensorShape output_shape = input.shape();
    output_shape.set_dim(0, size);
    Tensor output;
    TF_RETURN_IF_ERROR(
        context->allocate_temp(input.dtype(), output_shape, &output));
    auto output_shaped = output.shaped<T, 2>({size, suffix_dim_size});

    Eigen::DSizes<Eigen::DenseIndex, 2> slice_indices{position, 0};
    Eigen::DSizes<Eigen::DenseIndex, 2> slice_sizes{size, suffix_dim_size};
    functor::Split<CPUDevice, T, 2>()(context->eigen_device<CPUDevice>(),
                                      output_shaped, input_reshaped,
                                      slice_indices, slice_sizes);

    outputs->emplace_back(output);

    position += size;
  }

  return Status::OK();
}

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)

// Handles the general case, on GPU.
template <typename T>
Status SplitGPU(OpKernelContext* context, const Tensor& input,
                const gtl::ArraySlice<int64>& sizes,
                std::vector<Tensor>* outputs) {
  // TODO(olston, apassos): Implement this.
  LOG(FATAL) << "Not yet implemented";  // Crash ok
}

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// The outer function that dispatches to the various Split*() functions above.
template <typename T>
Status Split(OpKernelContext* context, const Tensor& input,
             const gtl::ArraySlice<int64>& sizes,
             std::vector<Tensor>* outputs) {
  bool easy_cases_done;
  TF_RETURN_IF_ERROR(
      SplitEasyCases<T>(context, input, sizes, outputs, &easy_cases_done));
  if (easy_cases_done) {
    return Status::OK();
  }

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
  // TODO(olston, apassos): Handle non-CPU cases.
  // return SplitGPU<T>(context, input, sizes, outputs);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  return SplitCPU<T>(context, input, sizes, outputs);
}
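
// Example (illustrative sketch): splitting a [6, 5] float tensor 'batched'
// back into per-task pieces of sizes {2, 3, 1} produces three tensors of
// shapes [2, 5], [3, 5], and [1, 5]. Assuming 'ctx' is a live OpKernelContext:
//
//   std::vector<Tensor> pieces;
//   TF_CHECK_OK(Split<float>(ctx, batched, {2, 3, 1}, &pieces));
//   // pieces[0].shape() == TensorShape({2, 5}), etc.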

// A class encapsulating the state and logic for batching tensors.
class BatchResource : public ResourceBase {
 public:
  static Status Create(int32 num_batch_threads, int32 max_batch_size,
                       int32 batch_timeout_micros, int32 max_enqueued_batches,
                       const std::vector<int32>& allowed_batch_sizes,
                       FunctionLibraryRuntime::Handle fhandle,
                       std::unique_ptr<BatchResource>* resource) {
    std::unique_ptr<BatchResource> new_resource(new BatchResource);

    Batcher::Options batcher_options;
    batcher_options.num_batch_threads = num_batch_threads;
    TF_RETURN_IF_ERROR(
        Batcher::Create(batcher_options, &new_resource->batcher_));

    new_resource->batcher_queue_options_.max_batch_size = max_batch_size;
    new_resource->batcher_queue_options_.max_enqueued_batches =
        max_enqueued_batches;
    new_resource->batcher_queue_options_.batch_timeout_micros =
        batch_timeout_micros;

    new_resource->allowed_batch_sizes_ = allowed_batch_sizes;

    new_resource->fhandle_ = fhandle;

    *resource = std::move(new_resource);
    return Status::OK();
  }
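
  // Example (illustrative sketch): the kernels below create this resource
  // lazily, roughly as follows. The numeric values here are arbitrary
  // placeholders, not defaults taken from this file:
  //
  //   std::unique_ptr<BatchResource> br;
  //   TF_CHECK_OK(BatchResource::Create(
  //       /*num_batch_threads=*/2, /*max_batch_size=*/8,
  //       /*batch_timeout_micros=*/10000, /*max_enqueued_batches=*/100,
  //       /*allowed_batch_sizes=*/{2, 4, 8}, fhandle, &br));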

  string DebugString() const final { return "BatchResource"; }

  // Ingests data from one invocation of the batch op. The data is enqueued to
  // be combined with others into a batch, asynchronously.
  Status RegisterInput(int64 guid, OpKernelContext* context,
                       const string& batcher_queue_name,
                       AsyncOpKernel::DoneCallback done_callback) {
    std::unique_ptr<BatchTask> batch_components(new BatchTask);
    batch_components->guid = guid;
    batch_components->propagated_context = Context(ContextKind::kThread);
    OpInputList tensors;
    TF_RETURN_IF_ERROR(context->input_list("in_tensors", &tensors));
    for (int i = 0; i < tensors.size(); ++i) {
      const Tensor& tensor = tensors[i];
      if (tensor.shape().dims() == 0) {
        return errors::InvalidArgument(
            "Batching input tensors must have at least one dimension");
      }
      if (tensors.size() >= 2 &&
          tensor.shape().dim_size(0) != tensors[0].shape().dim_size(0)) {
        return errors::InvalidArgument(
            "Batching input tensors supplied in a given op invocation must "
            "have equal 0th-dimension size");
      }
      batch_components->inputs.push_back(tensor);
    }
    OpInputList captured_tensors;
    const auto captured_status =
        context->input_list("captured_tensors", &captured_tensors);
    if (captured_status.ok()) {
      for (const Tensor& captured_tensor : captured_tensors) {
        batch_components->captured_inputs.push_back(captured_tensor);
      }
    }
    batch_components->context = context;
    batch_components->done_callback = std::move(done_callback);

    BatcherQueue* batcher_queue;
    TF_RETURN_IF_ERROR(
        LookupOrCreateBatcherQueue(batcher_queue_name, &batcher_queue));
    return batcher_queue->Schedule(&batch_components);
  }

 private:
  BatchResource() = default;

  // One input to be batched. Corresponds to one invocation of the batch op.
  struct BatchTask : public serving::BatchTask {
    // A unique ID to identify this invocation of Batch.
    int64 guid;

    Context propagated_context;

    std::vector<Tensor> inputs;
    std::vector<Tensor> captured_inputs;
    OpKernelContext* context;
    AsyncOpKernel::DoneCallback done_callback;

    size_t size() const override { return inputs[0].shape().dim_size(0); }
  };

  using Batcher = serving::SharedBatchScheduler<BatchTask>;
  using BatcherQueue = serving::BatchScheduler<BatchTask>;
  using Batch = serving::Batch<BatchTask>;

  // Validates that it's legal to combine the tasks in 'batch' into a batch.
  // Assumes the batch is non-empty.
  static Status ValidateBatch(const Batch& batch) {
    for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
      const BatchTask& task = batch.task(task_idx);

      if (task.inputs.size() != batch.task(0).inputs.size()) {
        return errors::InvalidArgument(
            "Batching inputs must have equal number of edges");
      }
    }

    return Status::OK();
  }

  // Returns the smallest entry in 'allowed_batch_sizes_' that is greater than
  // or equal to 'batch_size'. If 'allowed_batch_sizes_' is empty, simply
  // returns 'batch_size'.
  int RoundToLowestAllowedBatchSize(int batch_size) const {
    if (allowed_batch_sizes_.empty()) {
      return batch_size;
    }
    for (int allowed_size : allowed_batch_sizes_) {
      if (allowed_size >= batch_size) {
        return allowed_size;
      }
    }
    LOG(ERROR) << "Maximum batch size greater than largest allowed size; "
                  "ignoring allowed sizes constraint";
    return batch_size;
  }
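
  // Example (worked by hand): with allowed_batch_sizes_ = {2, 4, 8},
  //
  //   RoundToLowestAllowedBatchSize(3) == 4
  //   RoundToLowestAllowedBatchSize(8) == 8
  //   RoundToLowestAllowedBatchSize(9) == 9  // logs an error first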

  Status ConcatInputTensors(const Batch& batch, OpKernelContext* context,
                            std::vector<Tensor>* concatenated_tensors) const {
    if (batch.num_tasks() == 0) {
      return errors::InvalidArgument("Empty batch.");
    }

    const int padded_batch_size = RoundToLowestAllowedBatchSize(batch.size());
    const int padding_amount = padded_batch_size - batch.size();

    // All tasks should have the same number of input edges.
    const int num_inputs = batch.task(0).inputs.size();
    concatenated_tensors->reserve(num_inputs);

    // Process each input one at a time (the typical case has just one).
    for (int i = 0; i < num_inputs; ++i) {
      // Concatenate the tasks' ith input tensors into a big output tensor.
      std::vector<Tensor> to_concatenate;
      to_concatenate.reserve(batch.num_tasks());
      for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
        to_concatenate.push_back(batch.task(task_idx).inputs.at(i));
      }

      // Add padding as needed. Use the first row of the first task's tensor as
      // the data for padding.
      if (padding_amount > 0) {
        const Tensor& padding_source = batch.task(0).inputs.at(i);
        Tensor padding;
        if (padding_source.shape().dim_size(0) == 0) {
          return errors::InvalidArgument(
              "Cannot use an empty tensor with zero rows as padding when "
              "batching.");
        }
        if (padding_source.shape().dim_size(0) == 1) {
          padding = padding_source;
        } else {
          padding = padding_source.Slice(0, 1);
        }
        for (int i = 0; i < padding_amount; ++i) {
          to_concatenate.push_back(padding);
        }
      }

      const DataType type = to_concatenate[0].dtype();
      Status concat_status;
      Tensor concatenated_tensor;
      switch (type) {
#define CASE(type)                                                   \
  case DataTypeToEnum<type>::value:                                  \
    concat_status =                                                  \
        Concat<type>(context, to_concatenate, &concatenated_tensor); \
    break;
        TF_CALL_ALL_TYPES(CASE);
#undef CASE
        default:
          concat_status =
              errors::InvalidArgument("Unsupported data type: ", type);
          break;
      }
      TF_RETURN_IF_ERROR(concat_status);
      concatenated_tensors->push_back(concatenated_tensor);
    }
    return Status::OK();
  }
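
  // Example (illustrative sketch): with allowed_batch_sizes_ = {8} and two
  // tasks whose inputs have shapes [3, 5] and [1, 5], the concatenated batch
  // has 4 rows and is padded to 8 by appending the first row of the first
  // task's tensor four more times, yielding a single [8, 5] tensor per input
  // edge.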

  Status SplitOutputTensors(const std::vector<Tensor>& combined_outputs,
                            Batch* batch) const {
    DCHECK_GE(batch->num_tasks(), 1);
    if (batch->num_tasks() < 1) {
      return errors::Internal("Batch size expected to be positive; was ",
                              batch->num_tasks());
    }

    std::vector<int64> task_sizes_plus_optional_padding;
    task_sizes_plus_optional_padding.reserve(batch->num_tasks());
    for (int i = 0; i < batch->num_tasks(); ++i) {
      task_sizes_plus_optional_padding.push_back(batch->task(i).size());
    }
    const int padding_size =
        RoundToLowestAllowedBatchSize(batch->size()) - batch->size();
    if (padding_size > 0) {
      task_sizes_plus_optional_padding.push_back(padding_size);
    }

    // For each output tensor name, a divided-up tensor with one entry per
    // task.
    std::map<string, std::vector<Tensor>> split_tensors;

    DCHECK_EQ(batch->task(0).context->num_outputs(), combined_outputs.size());
    if (combined_outputs.size() != batch->task(0).context->num_outputs()) {
      return errors::Internal("Wrong number of batched output tensors");
    }

    // Generate 'split_tensors' and populate the context outputs.
    for (int i = 0; i < combined_outputs.size(); ++i) {
      const Tensor& output_tensor = combined_outputs[i];
      if (output_tensor.shape().dims() == 0) {
        return errors::FailedPrecondition(
            "Batched output tensor has 0 dimensions");
      }
      if (output_tensor.shape().dim_size(0) != batch->size() + padding_size) {
        return errors::FailedPrecondition(
            "Batched output tensor's 0th dimension does not equal the sum of "
            "the 0th dimension sizes of the input tensors");
      }

      std::vector<Tensor> split_tensor;
      const Status split_status = tensor::Split(
          output_tensor, task_sizes_plus_optional_padding, &split_tensor);
      DCHECK(split_status.ok()) << split_status.ToString();
      if (!split_status.ok()) {
        return errors::Internal("Tensor split operation failed: ",
                                split_status.ToString());
      }
      DCHECK_EQ(split_tensor.size(), task_sizes_plus_optional_padding.size());
      if (split_tensor.size() != task_sizes_plus_optional_padding.size()) {
        return errors::Internal(
            "Tensor split operation did not work as expected; got ",
            split_tensor.size(), " splits; expected ",
            task_sizes_plus_optional_padding.size());
      }

      for (int j = 0; j < batch->num_tasks(); ++j) {
        BatchTask& task = *(batch->mutable_task(j));
        task.context->set_output(i, split_tensor.at(j));
      }  // (Ignore a possible final split_tensors entry containing the
         // padding.)
    }

    return Status::OK();
  }
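
  // Example (illustrative sketch): continuing the ConcatInputTensors()
  // example above, task sizes {3, 1} padded to 8 give
  // task_sizes_plus_optional_padding = {3, 1, 4}. An [8, 7] batched output is
  // split into pieces of shapes [3, 7], [1, 7], and [4, 7]; the first two are
  // routed to the two tasks' contexts and the final padding piece is dropped.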

  void ProcessFuncBatch(std::unique_ptr<Batch> batch) const {
    if (batch->empty()) {
      return;
    }

    // We use the 'propagated_context' from one of the threads that set up one
    // of the tasks. This propagates any common context over all the threads
    // running this Session, of which this BatchOp is a part.
    WithContext wc(batch->task(batch->num_tasks() - 1).propagated_context);

    OpKernelContext* last_task_context =
        batch->task(batch->num_tasks() - 1).context;

    // Regardless of the outcome, we need to propagate the status to the
    // individual tasks and signal that they are done. We use MakeCleanup() to
    // ensure that this happens no matter how we exit the method below.
    Status status;
    bool cleanup_done = false;
    auto cleanup_fn = [&cleanup_done, &batch](const Status& status) {
      if (cleanup_done) {
        return;
      }
      for (int i = 0; i < batch->num_tasks(); ++i) {
        batch->mutable_task(i)->context->SetStatus(status);
        batch->mutable_task(i)->done_callback();
      }
      cleanup_done = true;
    };
    auto finally =
        gtl::MakeCleanup([&cleanup_fn, &status] { cleanup_fn(status); });

    status = ValidateBatch(*batch);
    if (!status.ok()) {
      return;
    }

    std::vector<Tensor> concatenated_tensors;
    status =
        ConcatInputTensors(*batch, last_task_context, &concatenated_tensors);
    if (!status.ok()) {
      return;
    }
    FunctionLibraryRuntime::Options opts;
    opts.step_container = last_task_context->step_container();
    opts.cancellation_manager = last_task_context->cancellation_manager();
    opts.stats_collector = last_task_context->stats_collector();
    opts.rendezvous = last_task_context->rendezvous();
    opts.runner = last_task_context->runner();

    auto* flib = last_task_context->function_library();
    std::vector<Tensor> combined_outputs;
    Notification done;
    std::vector<Tensor> args(concatenated_tensors.begin(),
                             concatenated_tensors.end());
    const auto& captured_inputs =
        batch->task(batch->num_tasks() - 1).captured_inputs;
    args.insert(args.end(), captured_inputs.begin(), captured_inputs.end());

    // Releases the cleanup method here, because the callback of the function
    // library runtime will handle it now.
    finally.release();
    flib->Run(
        opts, fhandle_, args, &combined_outputs, [&](const Status& run_status) {
          Status final_status;
          auto run_finally = gtl::MakeCleanup([&]() {
            // We do the cleanup here as an optimization, so that it runs in
            // the underlying TF inter-op threadpool. Running it in the
            // threadpool lets the ensuing ops be scheduled faster, because
            // the executor will add them to the front of the threadpool's
            // task queue rather than the end.
            cleanup_fn(final_status);
            done.Notify();
          });
          final_status = run_status;
          if (!final_status.ok()) {
            return;
          }
          final_status = SplitOutputTensors(combined_outputs, batch.get());
        });
    // By waiting for the notification we are ensuring that this thread isn't
    // used for processing other batches, which gives the batches time to
    // coalesce upstream. So overall the number of batches going through the
    // devices goes down, improving latency and throughput in most cases.
    done.WaitForNotification();
  }

  // Processes a batch of one or more BatchTask entries.
  void ProcessBatch(std::unique_ptr<Batch> batch) const {
    if (batch->empty()) {
      return;
    }

    WithContext wc(batch->task(batch->num_tasks() - 1).propagated_context);

    OpKernelContext* last_task_context =
        batch->task(batch->num_tasks() - 1).context;
    AsyncOpKernel::DoneCallback last_task_callback =
        batch->task(batch->num_tasks() - 1).done_callback;

    OP_REQUIRES_OK_ASYNC(last_task_context, ValidateBatch(*batch),
                         last_task_callback);

    // All tasks should have the same number of input edges.
    const int num_input_edges = batch->task(0).inputs.size();
    std::vector<Tensor> concatenated_tensors;
    const Status concat_status =
        ConcatInputTensors(*batch, last_task_context, &concatenated_tensors);
    OP_REQUIRES_OK_ASYNC(last_task_context, concat_status, last_task_callback);

    // Process each input edge one at a time (the typical case has just one).
    for (int i = 0; i < num_input_edges; ++i) {
      last_task_context->set_output(i, concatenated_tensors.at(i));

      // Emit batch->num_tasks() - 1 empty output tensors.
      for (int task_idx = 0; task_idx < batch->num_tasks() - 1; ++task_idx) {
        const BatchTask& task = batch->task(task_idx);
        TensorShape output_shape(task.inputs.at(i).shape());
        output_shape.set_dim(0, 0);
        Tensor* output = nullptr;
        OP_REQUIRES_OK_ASYNC(
            task.context,
            task.context->allocate_output(i, output_shape, &output),
            task.done_callback);
      }
    }
    // Emit batch->num_tasks() - 1 empty index tensors.
    for (int task_idx = 0; task_idx < batch->num_tasks() - 1; ++task_idx) {
      const BatchTask& task = batch->task(task_idx);
      TensorShape index_shape({0, 3});
      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          task.context,
          task.context->allocate_output(num_input_edges, index_shape, &output),
          task.done_callback);
    }
    // Emit all ID tensors.
    for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) {
      const BatchTask& task = batch->task(task_idx);
      Tensor* id;
      OP_REQUIRES_OK_ASYNC(task.context,
                           task.context->allocate_output(num_input_edges + 1,
                                                         TensorShape({}), &id),
                           task.done_callback);
      id->scalar<int64>()() = task.guid;
    }
    OP_REQUIRES_OK_ASYNC(
        last_task_context,
        EmitIndexTensor(last_task_context, *batch, num_input_edges),
        last_task_callback);

    // Signal done for each element of the batch. (At this point, the contexts
    // are no longer guaranteed to remain live.)
    for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) {
      batch->mutable_task(task_idx)->done_callback();
    }
  }
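
  // Example (illustrative sketch): for a batch of two tasks with one input
  // edge each, of shapes [3, 5] and [1, 5], the last task's invocation emits
  // the concatenated [4, 5] tensor (padded if allowed_batch_sizes_ requires),
  // a [2, 3] index tensor, and its ID; the other task's invocation emits a
  // [0, 5] tensor, a [0, 3] index tensor, and its ID.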

  // Emits an index tensor, which the Unbatch op will use to un-concatenate
  // the tensor and attribute the pieces to the right batch keys. The index
  // tensor contains, for each input: [batch_key, start_offset, end_offset]
  // where start_offset and end_offset represent the range of entries in the
  // concatenated tensors that belong to that input.
  //
  // Emits the result to the output at 'output_index' using 'context'.
  static Status EmitIndexTensor(OpKernelContext* context, const Batch& batch,
                                int output_index) {
    const TensorShape index_shape({batch.num_tasks(), 3});
    Tensor* index = nullptr;
    TF_RETURN_IF_ERROR(
        context->allocate_output(output_index, index_shape, &index));
    auto index_flat = index->shaped<int64, 2>({batch.num_tasks(), 3});
    size_t offset = 0;
    for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
      const BatchTask& task = batch.task(task_idx);
      index_flat(task_idx, 0) = task.guid;
      index_flat(task_idx, 1) = offset;
      index_flat(task_idx, 2) = offset + task.size();
      offset += task.size();
    }
    return Status::OK();
  }
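
  // Example (illustrative sketch): for two tasks with guids g0 and g1 and
  // sizes 3 and 1, the emitted index tensor is
  //
  //   [[g0, 0, 3],
  //    [g1, 3, 4]]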

  // Looks up the batcher queue for 'queue_name'. If it didn't previously
  // exist, creates it.
  Status LookupOrCreateBatcherQueue(const string& queue_name,
                                    BatcherQueue** queue) {
    mutex_lock l(batcher_queues_mu_);

    auto it = batcher_queues_.find(queue_name);
    if (it != batcher_queues_.end()) {
      *queue = it->second.get();
      return Status::OK();
    }

    std::unique_ptr<BatcherQueue> new_queue;
    auto process_batch_callback = [this](std::unique_ptr<Batch> batch) {
      if (fhandle_ == kInvalidHandle) {
        ProcessBatch(std::move(batch));
      } else {
        ProcessFuncBatch(std::move(batch));
      }
    };
    TF_RETURN_IF_ERROR(batcher_->AddQueue(batcher_queue_options_,
                                          process_batch_callback, &new_queue));
    *queue = new_queue.get();
    batcher_queues_[queue_name] = std::move(new_queue);
    return Status::OK();
  }

  // A batch scheduler, and options for creating queues.
  std::shared_ptr<Batcher> batcher_;
  Batcher::QueueOptions batcher_queue_options_;

  // A collection of batcher queues, keyed on queue name.
  // TODO(olston): Garbage-collect unused queues (perhaps simply remove empty
  // ones (with a time delay?); it's okay if they get recreated later).
  mutable mutex batcher_queues_mu_;
  std::map<string, std::unique_ptr<BatcherQueue>> batcher_queues_
      GUARDED_BY(batcher_queues_mu_);

  std::vector<int32> allowed_batch_sizes_;
  FunctionLibraryRuntime::Handle fhandle_;
};

class BatchFunctionKernel : public AsyncOpKernel {
 public:
  explicit BatchFunctionKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
    OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
    // If shared_name is not supplied, use name instead (prevent collisions by
    // default).
    if (shared_name_.empty()) {
      shared_name_ = name();
    }
    OP_REQUIRES_OK(c, c->GetAttr("batching_queue", &batcher_queue_));
    OP_REQUIRES_OK(c, c->GetAttr("num_batch_threads", &num_batch_threads_));
    OP_REQUIRES_OK(c, c->GetAttr("max_batch_size", &max_batch_size_));
    OP_REQUIRES_OK(c,
                   c->GetAttr("batch_timeout_micros", &batch_timeout_micros_));
    OP_REQUIRES_OK(c,
                   c->GetAttr("max_enqueued_batches", &max_enqueued_batches_));
    OP_REQUIRES_OK(c, c->GetAttr("allowed_batch_sizes", &allowed_batch_sizes_));
    OP_REQUIRES_OK(c, ValidateAllowedBatchSizes());

    auto lib = c->function_library();
    OP_REQUIRES(c, lib != nullptr, errors::Internal("No function library"));
    NameAttrList func;
    OP_REQUIRES_OK(c, c->GetAttr("f", &func));
    OP_REQUIRES_OK(
        c, lib->Instantiate(func.name(), AttrSlice(&func.attr()), &fhandle_));
  }

  bool IsExpensive() override { return false; }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
    BatchResource* br;
    std::function<Status(BatchResource**)> creator = [this](BatchResource** r) {
      std::unique_ptr<BatchResource> new_resource;
      TF_RETURN_IF_ERROR(
          BatchResource::Create(num_batch_threads_, max_batch_size_,
                                batch_timeout_micros_, max_enqueued_batches_,
                                allowed_batch_sizes_, fhandle_, &new_resource));
      *r = new_resource.release();
      return Status::OK();
    };
    OP_REQUIRES_OK_ASYNC(c,
                         c->resource_manager()->LookupOrCreate(
                             container_, shared_name_, &br, creator),
                         done);
    const Status status =
        br->RegisterInput(random::New64(), c, batcher_queue_, done);
    br->Unref();
    OP_REQUIRES_OK_ASYNC(c, status, done);
    // Assume br calls done, so nothing to do here.
  }

  // Validates 'allowed_batch_sizes_'. The entries must increase monotonically,
  // and the last one must equal 'max_batch_size_'.
  Status ValidateAllowedBatchSizes() const {
    if (allowed_batch_sizes_.empty()) {
      return Status::OK();
    }
    int32 last_size = 0;
    for (size_t i = 0; i < allowed_batch_sizes_.size(); ++i) {
      const int32 size = allowed_batch_sizes_.at(i);
      if (i > 0 && size <= last_size) {
        return errors::InvalidArgument(
            "allowed_batch_sizes entries must be monotonically increasing");
      }
      if (i == allowed_batch_sizes_.size() - 1 && size != max_batch_size_) {
        return errors::InvalidArgument(
            "final entry in allowed_batch_sizes must equal max_batch_size");
      }
      last_size = size;
    }
    return Status::OK();
  }
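
  // Example (worked by hand): with max_batch_size_ = 8, {2, 4, 8} passes;
  // {4, 2, 8} fails the monotonicity check; and {2, 4} fails because the
  // final entry does not equal max_batch_size_.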

 private:
  string container_;
  string shared_name_;
  string batcher_queue_;
  int32 num_batch_threads_;
  int32 max_batch_size_;
  int32 batch_timeout_micros_;
  int32 max_enqueued_batches_;
  std::vector<int32> allowed_batch_sizes_;
  FunctionLibraryRuntime::Handle fhandle_;
};

REGISTER_KERNEL_BUILDER(Name("BatchFunction").Device(DEVICE_CPU),
                        BatchFunctionKernel);

class BatchKernel : public AsyncOpKernel {
 public:
  explicit BatchKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
    OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
    // If shared_name is not supplied, use name instead (prevent collisions by
    // default).
    if (shared_name_.empty()) {
      shared_name_ = name();
    }
    OP_REQUIRES_OK(c, c->GetAttr("batching_queue", &batcher_queue_));
    OP_REQUIRES_OK(c, c->GetAttr("num_batch_threads", &num_batch_threads_));
    OP_REQUIRES_OK(c, c->GetAttr("max_batch_size", &max_batch_size_));
    OP_REQUIRES_OK(c,
                   c->GetAttr("batch_timeout_micros", &batch_timeout_micros_));
    OP_REQUIRES_OK(c,
                   c->GetAttr("max_enqueued_batches", &max_enqueued_batches_));
    OP_REQUIRES_OK(c, c->GetAttr("allowed_batch_sizes", &allowed_batch_sizes_));
    OP_REQUIRES_OK(c, ValidateAllowedBatchSizes());
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
    BatchResource* br;
    std::function<Status(BatchResource**)> creator = [this](BatchResource** r) {
      std::unique_ptr<BatchResource> new_resource;
      TF_RETURN_IF_ERROR(BatchResource::Create(
          num_batch_threads_, max_batch_size_, batch_timeout_micros_,
          max_enqueued_batches_, allowed_batch_sizes_, kInvalidHandle,
          &new_resource));
      *r = new_resource.release();
      return Status::OK();
    };
    OP_REQUIRES_OK_ASYNC(c,
                         c->resource_manager()->LookupOrCreate(
                             container_, shared_name_, &br, creator),
                         done);
    const Status status =
        br->RegisterInput(random::New64(), c, batcher_queue_, done);
    br->Unref();
    OP_REQUIRES_OK_ASYNC(c, status, done);
    // Assume br calls done, so nothing to do here.
  }

  // Validates 'allowed_batch_sizes_'. The entries must increase monotonically,
  // and the last one must equal 'max_batch_size_'.
  Status ValidateAllowedBatchSizes() const {
    if (allowed_batch_sizes_.empty()) {
      return Status::OK();
    }
    int32 last_size = 0;
    for (size_t i = 0; i < allowed_batch_sizes_.size(); ++i) {
      const int32 size = allowed_batch_sizes_.at(i);
      if (i > 0 && size <= last_size) {
        return errors::InvalidArgument(
            "allowed_batch_sizes entries must be monotonically increasing");
      }
      if (i == allowed_batch_sizes_.size() - 1 && size != max_batch_size_) {
        return errors::InvalidArgument(
            "final entry in allowed_batch_sizes must equal max_batch_size");
      }
      last_size = size;
    }
    return Status::OK();
  }

 private:
  string container_;
  string shared_name_;
  string batcher_queue_;
  int32 num_batch_threads_;
  int32 max_batch_size_;
  int32 batch_timeout_micros_;
  int32 max_enqueued_batches_;
  std::vector<int32> allowed_batch_sizes_;
};

REGISTER_KERNEL_BUILDER(Name("Batch").Device(DEVICE_CPU), BatchKernel);

// A class encapsulating the state and logic for unbatching tensors.
//
// UnbatchResource keeps two data structures indexed by batch-key: one which has
// the continuations for all concurrent kernels which are waiting for tensors
// and another which has tensors which are waiting for their corresponding
// kernels to run. Whenever a kernel runs, we either grab its tensor if it's
// waiting already, or we insert it in the queue and then look at its tensor to
// see if it can be used to dispatch any stored continuations.
class UnbatchResource : public ResourceBase {
 public:
  explicit UnbatchResource(int32 timeout_micros)
      : timeout_micros_(timeout_micros),
        timeout_enforcer_(new serving::PeriodicFunction(
            [this] { EnforceTimeout(); }, 1000 /* 1 ms */)) {}

  ~UnbatchResource() override {
    // Tear down 'timeout_enforcer_' first, since it accesses other state in
    // this class.
    timeout_enforcer_ = nullptr;
  }

  string DebugString() const final { return "UnbatchResource"; }

  Status Compute(OpKernelContext* context, AsyncOpKernel::DoneCallback done) {
    const Tensor& data_t = context->input(0);
    const Tensor& batch_index_t = context->input(1);

    if (batch_index_t.shape().dim_size(0) > data_t.shape().dim_size(0)) {
      return errors::InvalidArgument(
          "Wrong shape for index tensor. Expected 0th dimension size to be no "
          "greater than ",
          data_t.shape().dim_size(0),
          "; Got: ", batch_index_t.shape().dim_size(0), ".");
    }
    if (batch_index_t.shape().dim_size(1) != 3) {
      return errors::InvalidArgument(
          "Wrong shape for index tensor. Expected 1st dimension size to be 3; "
          "Got: ",
          batch_index_t.shape().dim_size(1), ".");
    }

    const int64 batch_key = context->input(2).scalar<int64>()();
    const bool nonempty_input = batch_index_t.dim_size(0) > 0;

    // If we have a non-empty tensor, slice it up.
    // (It is important to do this outside of the critical section below.)
    // The following variables are populated iff 'nonempty_input==true'.
    std::vector<int64> sizes;
    std::vector<int64> batch_keys;
    std::vector<Tensor> split_inputs;
    if (nonempty_input) {
      auto batch_indices =
          batch_index_t.shaped<int64, 2>({batch_index_t.dim_size(0), 3});
      for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
        sizes.push_back(batch_indices(i, 2) - batch_indices(i, 1));
        batch_keys.push_back(batch_indices(i, 0));
      }

      const DataType type = data_t.dtype();
      switch (type) {
#define CASE(type)                                                          \
  case DataTypeToEnum<type>::value:                                         \
    TF_RETURN_IF_ERROR(Split<type>(context, data_t, sizes, &split_inputs)); \
    break;
        TF_CALL_ALL_TYPES(CASE);
#undef CASE
        default:
          return errors::InvalidArgument("Unsupported data type: ", type);
      }
    }

    // Critical section.
    std::vector<AsyncOpKernel::DoneCallback> done_callbacks_to_call;
    Status status = [&]() -> Status {
      mutex_lock ml(mu_);

      // Check to see whether the tensor we want is already ready.
      auto tensor_it = waiting_tensors_.find(batch_key);
      if (tensor_it != waiting_tensors_.end()) {
        context->set_output(0, tensor_it->second.tensor);
        waiting_tensors_.erase(tensor_it);
        done_callbacks_to_call.push_back(done);
        return Status::OK();
      }

      const uint64 deadline_micros =
          Env::Default()->NowMicros() + timeout_micros_;

      // Add ourselves to the waitlist for tensors.
      if (!waiting_callbacks_
               .emplace(batch_key,
                        WaitingCallback{deadline_micros, context, done})
               .second) {
        return errors::AlreadyExists(
            "Multiple session runs with the same batch key.");
      }

      // If we have a non-empty tensor, finish the waitlisted runs,
      // and store any remaining pieces.
      if (nonempty_input) {
        for (size_t i = 0; i < batch_keys.size(); ++i) {
          auto runs_it = waiting_callbacks_.find(batch_keys[i]);
          if (runs_it != waiting_callbacks_.end()) {
            runs_it->second.context->set_output(0, split_inputs[i]);
            done_callbacks_to_call.push_back(runs_it->second.done);
            waiting_callbacks_.erase(runs_it);
          } else {
            // Note: the deadline here is in case we are arriving late and the
            // kernel that should rendezvous with this tensor has already
            // waited and timed out.
            if (!waiting_tensors_
                     .emplace(batch_keys[i],
                              WaitingTensor{deadline_micros, split_inputs[i]})
                     .second) {
              return errors::AlreadyExists(
                  "Multiple tensors returned for same batch key.");
            }
          }
        }
      }

      return Status::OK();
    }();

    for (const AsyncOpKernel::DoneCallback& done_callback :
         done_callbacks_to_call) {
      done_callback();
    }

    return status;
  }

 private:
  // Evicts waiting tensors and callbacks that have exceeded their deadline.
  void EnforceTimeout() {
    const uint64 now = Env::Default()->NowMicros();
    std::vector<WaitingCallback> evicted_callbacks;

    {
      mutex_lock ml(mu_);

      for (auto it = waiting_tensors_.begin(); it != waiting_tensors_.end();) {
        const WaitingTensor& waiting_tensor = it->second;
        if (waiting_tensor.deadline_micros < now) {
          it = waiting_tensors_.erase(it);
        } else {
          ++it;
        }
      }

      for (auto it = waiting_callbacks_.begin();
           it != waiting_callbacks_.end();) {
        const WaitingCallback& waiting_callback = it->second;
        if (waiting_callback.deadline_micros < now) {
          evicted_callbacks.push_back(waiting_callback);
          it = waiting_callbacks_.erase(it);
        } else {
          ++it;
        }
      }
    }

    for (const WaitingCallback& evicted_callback : evicted_callbacks) {
      evicted_callback.context->CtxFailureWithWarning(errors::DeadlineExceeded(
          "Batched data did not arrive within timeout window."));
      evicted_callback.done();
    }
  }

  struct WaitingTensor {
    uint64 deadline_micros;
    Tensor tensor;
  };

  struct WaitingCallback {
    uint64 deadline_micros;
    OpKernelContext* context;
    AsyncOpKernel::DoneCallback done;
  };

  const int32 timeout_micros_;

  mutex mu_;

  // Maps keyed by BatchKey of tensors waiting for callbacks and callbacks
  // waiting for tensors.
  std::unordered_map<int64, WaitingTensor> waiting_tensors_ GUARDED_BY(mu_);
  std::unordered_map<int64, WaitingCallback> waiting_callbacks_ GUARDED_BY(mu_);

  // A thread that evicts waiting tensors and callbacks that have exceeded
  // their deadline.
  std::unique_ptr<serving::PeriodicFunction> timeout_enforcer_;
};
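
// Example scenario (illustrative sketch): suppose the batched result covers
// batch keys k1 and k2. If the Unbatch kernel for k1 runs first, its callback
// is parked in waiting_callbacks_; when the kernel holding the batched data
// runs, it splits the data, fires k1's callback immediately, and parks k2's
// slice in waiting_tensors_ until k2's kernel arrives or the timeout enforcer
// evicts it.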

class UnbatchKernel : public AsyncOpKernel {
 public:
  explicit UnbatchKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
    OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
    // If shared_name is not supplied, use name instead (prevent collisions by
    // default).
    if (shared_name_.empty()) {
      shared_name_ = name();
    }
    OP_REQUIRES_OK(c, c->GetAttr("timeout_micros", &timeout_micros_));
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
    UnbatchResource* ubr;
    std::function<Status(UnbatchResource**)> creator =
        [this](UnbatchResource** r) {
          *r = new UnbatchResource(timeout_micros_);
          return Status::OK();
        };
    OP_REQUIRES_OK_ASYNC(c,
                         c->resource_manager()->LookupOrCreate(
                             container_, shared_name_, &ubr, creator),
                         done);
    auto status = ubr->Compute(c, done);
    ubr->Unref();
    OP_REQUIRES_OK_ASYNC(c, status, done);
    // Assume ubr calls done, so nothing to do here.
  }

 private:
  string container_;
  string shared_name_;
  int32 timeout_micros_;
};
REGISTER_KERNEL_BUILDER(Name("Unbatch").Device(DEVICE_CPU), UnbatchKernel);

// A class encapsulating the state and logic for batching tensors
// deterministically for the gradient of unbatch.
class UnbatchGradResource : public ResourceBase {
 public:
  UnbatchGradResource() {}

  string DebugString() const final { return "UnbatchGradResource"; }

  // Flushes the information for one batch, given its context and done
  // callback. Clears all information about it from the available_tensors_.
  Status OutputBatch(OpKernelContext* context,
                     const AsyncOpKernel::DoneCallback& done)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    const Tensor& batch_index_t = context->input(1);
    auto batch_index =
        batch_index_t.shaped<int64, 2>({batch_index_t.dim_size(0), 3});
    std::vector<Tensor> tensors;
    for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
      auto available_it = available_tensors_.find(batch_index(i, 0));
      if (available_it == available_tensors_.end()) {
        return errors::Internal("bad bookkeeping of available tensors.");
      }
      tensors.push_back(available_it->second);
      available_tensors_.erase(available_it);
    }

    const DataType type = tensors[0].dtype();
    Tensor concatenated_tensor;
    switch (type) {
#define CASE(type)                                                            \
  case DataTypeToEnum<type>::value:                                           \
    TF_RETURN_IF_ERROR(Concat<type>(context, tensors, &concatenated_tensor)); \
    context->set_output(0, concatenated_tensor);                              \
    break;
      TF_CALL_ALL_TYPES(CASE);
#undef CASE
      default:
        return errors::InvalidArgument("Unsupported data type: ", type);
    }
    done();
    return Status::OK();
  }

  // Ingests data from one invocation of the op.
  Status Compute(OpKernelContext* context,
                 const AsyncOpKernel::DoneCallback& done) {
    const Tensor& data_t = context->input(0);
    const Tensor& batch_index_t = context->input(1);
    const Tensor& grad_t = context->input(2);

    mutex_lock ml(mu_);

    const int64 batch_key = context->input(3).scalar<int64>()();
    // Mark our tensor as available.
    if (!available_tensors_.emplace(batch_key, grad_t).second) {
      return errors::InvalidArgument("Two runs with the same batch key.");
    }

    // Check whether we have a valid input tensor and, if so, create its
    // dispatch logic.
    if (data_t.NumElements() > 0) {
      if (batch_index_t.NumElements() == 0) {
        return errors::InvalidArgument(
            "batch_index is empty while the tensor isn't.");
      }
      std::unordered_set<int64> missing_tensors;
      const auto batch_index =
          batch_index_t.shaped<int64, 2>({batch_index_t.dim_size(0), 3});
      for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
        const int64 batch_key = batch_index(i, 0);
        if (available_tensors_.find(batch_key) == available_tensors_.end()) {
          missing_tensors.emplace(batch_key);
        }
      }
      if (missing_tensors.empty()) {
        return OutputBatch(context, done);
      }
      if (!available_batches_
               .emplace(batch_key, Batch{missing_tensors, context, done})
               .second) {
        return errors::InvalidArgument(
            "Batch key with valid batch used twice.");
      }
      for (const int64 i : missing_tensors) {
        if (!desired_tensor_to_batch_map_.emplace(i, batch_key).second) {
          return errors::InvalidArgument(
              "Missing tensor wanted by more than one batch.");
        }
      }
    } else {
      // If we don't have a valid input tensor we can output an empty tensor
      // and call our done closure.
      TensorShape output_shape(grad_t.shape());
      output_shape.set_dim(0, 0);
      Tensor* output = nullptr;
      TF_RETURN_IF_ERROR(context->allocate_output(0, output_shape, &output));
      done();
    }

    // Search to see whether our tensor is desired by any existing batch.
    auto desire_it = desired_tensor_to_batch_map_.find(batch_key);
    if (desire_it != desired_tensor_to_batch_map_.end()) {
      // Mark our tensor as no longer missing.
      auto batch_it = available_batches_.find(desire_it->second);
      desired_tensor_to_batch_map_.erase(desire_it);
      if (batch_it == available_batches_.end()) {
        return errors::InvalidArgument("Batch no longer exists.");
      }
      batch_it->second.missing_tensors.erase(batch_key);
      // If all tensors are available we should concatenate them and dispatch
      // the batch.
      if (batch_it->second.missing_tensors.empty()) {
        TF_RETURN_IF_ERROR(
            OutputBatch(batch_it->second.context, batch_it->second.done));
        available_batches_.erase(batch_it);
      }
    }
    return Status::OK();
  }

 private:
  mutex mu_;

  // Represents a still-incomplete batch of tensors. When all tensors become
  // available they will be concatenated in the right order and sent through
  // the context.
  struct Batch {
    // Batch keys for tensors which are still missing from this batch. When
    // this is empty the Tensors can be concatenated and forwarded.
    std::unordered_set<int64> missing_tensors;

    // Context and callback for the session responsible for finishing this
    // batch.
    OpKernelContext* context;
    AsyncOpKernel::DoneCallback done;
  };

  // Map from batch key of the session which will output the batched gradients
  // to still-incomplete batches.
  std::unordered_map<int64, Batch> available_batches_;

  // Map from batch key to tensors which are waiting for their batches to be
  // available.
  std::unordered_map<int64, Tensor> available_tensors_;

  // Map from batch key of a tensor which is not yet available to the batch key
  // of the batch to which it belongs.
  std::unordered_map<int64, int64> desired_tensor_to_batch_map_;
};
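
// Example scenario (illustrative sketch): if a batch was split across keys k1
// and k2, the UnbatchGrad invocation that carries the non-empty batch_index
// registers its gradient under its own key and records any keys still missing
// in available_batches_. As the invocations for k1 and k2 arrive, their
// gradients land in available_tensors_ and are crossed off; once no keys are
// missing, the gradients are concatenated in batch_index order and emitted.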

class UnbatchGradKernel : public AsyncOpKernel {
 public:
  explicit UnbatchGradKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
    OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
    // If shared_name is not supplied, use name instead (prevent collisions by
    // default).
    if (shared_name_.empty()) {
      shared_name_ = name();
    }
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
    UnbatchGradResource* ubr;
    std::function<Status(UnbatchGradResource**)> creator =
        [](UnbatchGradResource** r) {
          *r = new UnbatchGradResource();
          return Status::OK();
        };
    OP_REQUIRES_OK_ASYNC(c,
                         c->resource_manager()->LookupOrCreate(
                             container_, shared_name_, &ubr, creator),
                         done);
    Status status = ubr->Compute(c, done);
    ubr->Unref();
    OP_REQUIRES_OK_ASYNC(c, status, done);
    // Assume ubr calls done, so nothing to do here.
  }

 private:
  string container_;
  string shared_name_;
};
REGISTER_KERNEL_BUILDER(Name("UnbatchGrad").Device(DEVICE_CPU),
                        UnbatchGradKernel);

}  // namespace tensorflow