1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/function.h"
17 #include "tensorflow/core/framework/op_kernel.h"
18 #include "tensorflow/core/framework/register_types.h"
19 #include "tensorflow/core/framework/resource_mgr.h"
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/framework/tensor_util.h"
22 #include "tensorflow/core/framework/types.h"
23 #include "tensorflow/core/kernels/batching_util/periodic_function.h"
24 #include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"
25 #include "tensorflow/core/kernels/concat_lib.h"
26 #include "tensorflow/core/kernels/ops_util.h"
27 #include "tensorflow/core/kernels/split_lib.h"
28 #include "tensorflow/core/lib/gtl/cleanup.h"
29 #include "tensorflow/core/lib/random/random.h"
30 #include "tensorflow/core/platform/macros.h"
31 
32 namespace tensorflow {
33 
34 typedef Eigen::ThreadPoolDevice CPUDevice;
35 typedef Eigen::GpuDevice GPUDevice;
36 #ifdef TENSORFLOW_USE_SYCL
37 typedef Eigen::SyclDevice SYCLDevice;
38 #endif  // TENSORFLOW_USE_SYCL
39 
40 // Concatenates 'inputs' into a single tensor along the zeroth dimension.
41 // Requires that all elements of 'inputs' have element type T. Writes the
42 // result to the allocated 'output' tensor, using 'context' for the allocation
43 // to ensure proper device placement.
44 template <typename T>
45 Status Concat(OpKernelContext* context, const gtl::ArraySlice<Tensor>& inputs,
46               Tensor* output) {
47   const int input_dims = inputs[0].dims();
48   const TensorShape& input_shape = inputs[0].shape();
49 
50   // Note that we reduce the concat of k-dimensional tensors into a two
51   // dimensional concat. Assuming the dimensions of any input tensor are
52   // {y0, y1,...,ym-1}, we flatten it to {1, y}, where y = Prod_i(yi).
53   std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>> inputs_flat;
54   inputs_flat.reserve(inputs.size());
55   int64 output_dim0 = 0;
56   for (size_t i = 0; i < inputs.size(); ++i) {
57     const Tensor& input = inputs[i];
58     if (input.dims() != input_dims) {
59       return errors::InvalidArgument(
60           "Ranks of all input tensors should match: shape[0] = ",
61           input_shape.DebugString(), " vs. shape[", i,
62           "] = ", input.shape().DebugString());
63     }
64     for (int j = 1; j < input_dims; ++j) {
65       if (input.dim_size(j) != input_shape.dim_size(j)) {
66         return errors::InvalidArgument(
67             "Dimensions of inputs should match: shape[0] = ",
68             input_shape.DebugString(), " vs. shape[", i,
69             "] = ", input.shape().DebugString());
70       }
71     }
72     if (input.NumElements() > 0) {
73       inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
74           input.shaped<T, 2>({1, input.NumElements()})));
75     }
76     output_dim0 += input.dim_size(0);
77   }
78 
79   TensorShape output_shape(input_shape);
80   output_shape.set_dim(0, output_dim0);
81   TF_RETURN_IF_ERROR(
82       context->allocate_temp(DataTypeToEnum<T>::value, output_shape, output));
83   if (output->NumElements() > 0) {
84     auto output_flat = output->shaped<T, 2>({1, output->NumElements()});
85 #if GOOGLE_CUDA
86     if (std::is_same<Device, GPUDevice>::value) {
87       ConcatGPU<T>(context, inputs_flat, output, &output_flat);
88       return Status::OK();
89     }
90 #endif  // GOOGLE_CUDA
91     ConcatCPU<T>(context->device(), inputs_flat, &output_flat);
92   }
93 
94   return Status::OK();
95 }
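// Illustrative sketch (added for exposition; hypothetical helper, not part of
// the original file): the flattening above reduces a dim-0 concat of
// k-dimensional tensors to appending contiguous rows. For example, inputs of
// shape {2, 3} and {4, 3} are viewed as {1, 6} and {1, 12} matrices and
// appended into an 18-element buffer that is interpreted as shape {6, 3}.
// A minimal model of that copy, using only standard containers:
inline std::vector<float> ConcatFlattenedSketch(
    const std::vector<std::vector<float>>& flattened_inputs) {
  std::vector<float> out;
  for (const auto& rows : flattened_inputs) {
    // Row-major layout makes appending elements equivalent to concatenating
    // along the zeroth dimension.
    out.insert(out.end(), rows.begin(), rows.end());
  }
  return out;
}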
96 
97 // The Split*() functions split 'input' with element type T into 'sizes.size()'
98 // tensors along the zeroth dimension, with the ith split having zeroth-
99 // dimension size 'sizes[i]'. They allocate the output tensors using 'context',
100 // for proper device placement.
101 
102 // Handles special cases that are cheap. Sets 'done==true' iff it found an
103 // applicable special case and wrote to the outputs. Otherwise acts as a no-op.
104 template <typename T>
105 Status SplitEasyCases(OpKernelContext* context, const Tensor& input,
106                       const gtl::ArraySlice<int64>& sizes,
107                       std::vector<Tensor>* outputs, bool* done) {
108   *done = false;
109 
110   int64 total_size = 0;
111   for (const int64 size : sizes) {
112     total_size += size;
113   }
114   if (total_size > input.shape().dim_size(0)) {
115     return errors::InvalidArgument(
116         "Sum of split sizes must not exceed dim0-size of input tensor");
117   }
118 
119   // Special case 0: trivial 1-way split.
120   if (sizes.size() == 1 && sizes.at(0) == input.shape().dim_size(0)) {
121     outputs->push_back(input);
122     *done = true;
123     return Status::OK();
124   }
125 
126   // Special case 1: input is aligned.
127   if (IsInnerDimsSizeAligned<T>(input.shape())) {
128     int64 position = 0;
129     for (const int64 size : sizes) {
130       outputs->emplace_back(input.Slice(position, position + size));
131       position += size;
132     }
133     *done = true;
134     return Status::OK();
135   }
136 
137   return Status::OK();
138 }
139 
140 // Handles the general case, on CPU.
141 template <typename T>
142 Status SplitCPU(OpKernelContext* context, const Tensor& input,
143                 const gtl::ArraySlice<int64>& sizes,
144                 std::vector<Tensor>* outputs) {
145   int64 suffix_dim_size = 1;
146   for (int i = 1; i < input.shape().dims(); ++i) {
147     suffix_dim_size *= input.shape().dim_size(i);
148   }
149   auto input_reshaped =
150       input.shaped<T, 2>({input.shape().dim_size(0), suffix_dim_size});
151 
152   int64 position = 0;
153   for (const int64 size : sizes) {
154     TensorShape output_shape = input.shape();
155     output_shape.set_dim(0, size);
156     Tensor output;
157     TF_RETURN_IF_ERROR(
158         context->allocate_temp(input.dtype(), output_shape, &output));
159     auto output_shaped = output.shaped<T, 2>({size, suffix_dim_size});
160 
161     Eigen::DSizes<Eigen::DenseIndex, 2> slice_indices{position, 0};
162     Eigen::DSizes<Eigen::DenseIndex, 2> slice_sizes{size, suffix_dim_size};
163     functor::Split<CPUDevice, T, 2>()(context->eigen_device<CPUDevice>(),
164                                       output_shaped, input_reshaped,
165                                       slice_indices, slice_sizes);
166 
167     outputs->emplace_back(output);
168 
169     position += size;
170   }
171 
172   return Status::OK();
173 }
174 
175 #if GOOGLE_CUDA
176 
177 // Handles the general case, on GPU.
178 template <typename T>
179 Status SplitGPU(OpKernelContext* context, const Tensor& input,
180                 const gtl::ArraySlice<int64>& sizes,
181                 std::vector<Tensor>* outputs) {
182   // TODO(olston, apassos): Implement this.
183   LOG(FATAL) << "Not yet implemented";  // Crash ok
184 }
185 
186 #endif  // GOOGLE_CUDA
187 
188 // The outer function that dispatches to the various Split*() functions above.
189 template <typename T>
190 Status Split(OpKernelContext* context, const Tensor& input,
191              const gtl::ArraySlice<int64>& sizes,
192              std::vector<Tensor>* outputs) {
193   bool easy_cases_done;
194   TF_RETURN_IF_ERROR(
195       SplitEasyCases<T>(context, input, sizes, outputs, &easy_cases_done));
196   if (easy_cases_done) {
197     return Status::OK();
198   }
199 
200 #if GOOGLE_CUDA
201 // TODO(olston, apassos): Handle non-CPU cases.
202 // return SplitGPU<T>(context, input, sizes, outputs);
203 #endif  // GOOGLE_CUDA
204   return SplitCPU<T>(context, input, sizes, outputs);
205 }
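// Illustrative sketch (hypothetical helper, not part of the original file):
// Split() partitions 'input' along the zeroth dimension into pieces of the
// requested sizes; e.g. an input with dim0 == 6 and sizes == {2, 4} yields
// outputs with dim0 == 2 and dim0 == 4. Modeled on a flattened input whose
// rows have 'row_size' elements, it is a sequence of contiguous copies:
inline std::vector<std::vector<float>> SplitFlattenedSketch(
    const std::vector<float>& flat_input, int64 row_size,
    const std::vector<int64>& sizes) {
  std::vector<std::vector<float>> pieces;
  int64 position = 0;
  for (const int64 size : sizes) {
    const auto begin = flat_input.begin() + position * row_size;
    // Copy 'size' rows starting at row 'position'.
    pieces.emplace_back(begin, begin + size * row_size);
    position += size;
  }
  return pieces;
}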
206 
207 // A class encapsulating the state and logic for batching tensors.
208 class BatchResource : public ResourceBase {
209  public:
210   static Status Create(int32 num_batch_threads, int32 max_batch_size,
211                        int32 batch_timeout_micros, int32 max_enqueued_batches,
212                        const std::vector<int32>& allowed_batch_sizes,
213                        FunctionLibraryRuntime::Handle fhandle,
214                        std::unique_ptr<BatchResource>* resource) {
215     std::unique_ptr<BatchResource> new_resource(new BatchResource);
216 
217     Batcher::Options batcher_options;
218     batcher_options.num_batch_threads = num_batch_threads;
219     TF_RETURN_IF_ERROR(
220         Batcher::Create(batcher_options, &new_resource->batcher_));
221 
222     new_resource->batcher_queue_options_.max_batch_size = max_batch_size;
223     new_resource->batcher_queue_options_.max_enqueued_batches =
224         max_enqueued_batches;
225     new_resource->batcher_queue_options_.batch_timeout_micros =
226         batch_timeout_micros;
227 
228     new_resource->allowed_batch_sizes_ = allowed_batch_sizes;
229 
230     new_resource->fhandle_ = fhandle;
231 
232     *resource = std::move(new_resource);
233     return Status::OK();
234   }
235 
236   string DebugString() const final { return "BatchResource"; }
237 
238   // Ingests data from one invocation of the batch op. The data is enqueued to
239   // be combined with others into a batch, asynchronously.
240   Status RegisterInput(int64 guid, OpKernelContext* context,
241                        const string& batcher_queue_name,
242                        AsyncOpKernel::DoneCallback done_callback) {
243     std::unique_ptr<BatchTask> batch_components(new BatchTask);
244     batch_components->guid = guid;
245     OpInputList tensors;
246     TF_RETURN_IF_ERROR(context->input_list("in_tensors", &tensors));
247     for (int i = 0; i < tensors.size(); ++i) {
248       const Tensor& tensor = tensors[i];
249       if (tensor.shape().dims() == 0) {
250         return errors::InvalidArgument(
251             "Batching input tensors must have at least one dimension");
252       }
253       if (tensors.size() >= 2 &&
254           tensor.shape().dim_size(0) != tensors[0].shape().dim_size(0)) {
255         return errors::InvalidArgument(
256             "Batching input tensors supplied in a given op invocation must "
257             "have equal 0th-dimension size");
258       }
259       batch_components->inputs.push_back(tensor);
260     }
261     OpInputList captured_tensors;
262     const auto captured_status =
263         context->input_list("captured_tensors", &captured_tensors);
264     if (captured_status.ok()) {
265       for (const Tensor& captured_tensor : captured_tensors) {
266         batch_components->captured_inputs.push_back(captured_tensor);
267       }
268     }
269     batch_components->context = context;
270     batch_components->done_callback = std::move(done_callback);
271 
272     BatcherQueue* batcher_queue;
273     TF_RETURN_IF_ERROR(
274         LookupOrCreateBatcherQueue(batcher_queue_name, &batcher_queue));
275     return batcher_queue->Schedule(&batch_components);
276   }
277 
278  private:
279   BatchResource() = default;
280 
281   // One input to be batched. Corresponds to one invocation of the batch op.
282   struct BatchTask : public serving::BatchTask {
283     // A unique ID to identify this invocation of Batch.
284     int64 guid;
285 
286     std::vector<Tensor> inputs;
287     std::vector<Tensor> captured_inputs;
288     OpKernelContext* context;
289     AsyncOpKernel::DoneCallback done_callback;
290 
291     size_t size() const override { return inputs[0].shape().dim_size(0); }
292   };
293 
294   using Batcher = serving::SharedBatchScheduler<BatchTask>;
295   using BatcherQueue = serving::BatchScheduler<BatchTask>;
296   using Batch = serving::Batch<BatchTask>;
297 
298   // Validates that it's legal to combine the tasks in 'batch' into a batch.
299   // Assumes the batch is non-empty.
300   static Status ValidateBatch(const Batch& batch) {
301     for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
302       const BatchTask& task = batch.task(task_idx);
303 
304       if (task.inputs.size() != batch.task(0).inputs.size()) {
305         return errors::InvalidArgument(
306             "Batching inputs must have equal number of edges");
307       }
308     }
309 
310     return Status::OK();
311   }
312 
313   // Returns the smallest entry in 'allowed_batch_sizes_' that is greater than
314   // or equal to 'batch_size'. If 'allowed_batch_sizes_' is empty, simply
315   // returns 'batch_size'.
316   int RoundToLowestAllowedBatchSize(int batch_size) const {
317     if (allowed_batch_sizes_.empty()) {
318       return batch_size;
319     }
320     for (int allowed_size : allowed_batch_sizes_) {
321       if (allowed_size >= batch_size) {
322         return allowed_size;
323       }
324     }
325     LOG(ERROR) << "Maximum batch size greater than largest allowed size; "
326                   "ignoring allowed sizes constraint";
327     return batch_size;
328   }
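  // Illustrative sketch (added for exposition; not used by this class): with
  // allowed_batch_sizes_ == {4, 8, 16}, a batch of 5 tasks is rounded up to 8,
  // while a batch of 20 exceeds the largest allowed size and stays at 20
  // (after the logged error above). A standalone model of the same rule:
  static int RoundUpToAllowedSizeSketch(int batch_size,
                                        const std::vector<int32>& allowed) {
    for (const int32 allowed_size : allowed) {
      if (allowed_size >= batch_size) return allowed_size;
    }
    return batch_size;  // Empty list, or no allowed size is large enough.
  }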
329 
330   Status ConcatInputTensors(const Batch& batch, OpKernelContext* context,
331                             std::vector<Tensor>* concatenated_tensors) const {
332     if (batch.num_tasks() == 0) {
333       return errors::InvalidArgument("Empty batch.");
334     }
335 
336     const int padded_batch_size = RoundToLowestAllowedBatchSize(batch.size());
337     const int padding_amount = padded_batch_size - batch.size();
338 
339     // All tasks should have the same number of input edges.
340     const int num_inputs = batch.task(0).inputs.size();
341     concatenated_tensors->reserve(num_inputs);
342 
343     // Process each input one at a time (the typical case has just one).
344     for (int i = 0; i < num_inputs; ++i) {
345       // Concatenate the tasks' ith input tensors into a big output tensor.
346       std::vector<Tensor> to_concatenate;
347       to_concatenate.reserve(batch.num_tasks());
348       for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
349         to_concatenate.push_back(batch.task(task_idx).inputs.at(i));
350       }
351 
352       // Add padding as needed. Use the first row of the first task's tensor as
353       // the data for padding.
354       if (padding_amount > 0) {
355         const Tensor& padding_source = batch.task(0).inputs.at(i);
356         Tensor padding;
357         if (padding_source.shape().dim_size(0) == 1) {
358           padding = padding_source;
359         } else {
360           const std::vector<int64> slice_sizes = {1};
361           const DataType type = padding_source.dtype();
362           Status slice_status;
363           std::vector<Tensor> slices;
364           switch (type) {
365 #define CASE(type)                                                     \
366   case DataTypeToEnum<type>::value:                                    \
367     slice_status =                                                     \
368         SplitCPU<type>(context, padding_source, slice_sizes, &slices); \
369     break;
370             TF_CALL_ALL_TYPES(CASE);
371 #undef CASE
372             default:
373               slice_status =
374                   errors::InvalidArgument("Unsupported data type: ", type);
375               break;
376           }
377           TF_RETURN_IF_ERROR(slice_status);
378           padding = slices.at(0);
379         }
380         for (int i = 0; i < padding_amount; ++i) {
381           to_concatenate.push_back(padding);
382         }
383       }
384 
385       const DataType type = to_concatenate[0].dtype();
386       Status concat_status;
387       Tensor concatenated_tensor;
388       switch (type) {
389 #define CASE(type)                                                   \
390   case DataTypeToEnum<type>::value:                                  \
391     concat_status =                                                  \
392         Concat<type>(context, to_concatenate, &concatenated_tensor); \
393     break;
394         TF_CALL_ALL_TYPES(CASE);
395 #undef CASE
396         default:
397           concat_status =
398               errors::InvalidArgument("Unsupported data type: ", type);
399           break;
400       }
401       TF_RETURN_IF_ERROR(concat_status);
402       concatenated_tensors->push_back(concatenated_tensor);
403     }
404     return Status::OK();
405   }
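  // Illustrative sketch (hypothetical, for exposition only): the padding logic
  // above appends 'padding_amount' copies of a single row sliced from the
  // first task's tensor. E.g. a batch of size 5 rounded up to 8 gets three
  // padding rows appended after the five real rows before the function runs.
  // A minimal model over flattened rows:
  static void PadBatchSketch(std::vector<std::vector<float>>* rows,
                             int64 padded_batch_size) {
    if (rows->empty()) return;
    // Reuse the first row as filler; its values are ignored downstream.
    const std::vector<float> padding_row = rows->front();
    while (static_cast<int64>(rows->size()) < padded_batch_size) {
      rows->push_back(padding_row);
    }
  }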
406 
407   Status SplitOutputTensors(const std::vector<Tensor>& combined_outputs,
408                             Batch* batch) const {
409     DCHECK_GE(batch->num_tasks(), 1);
410     if (batch->num_tasks() < 1) {
411       return errors::Internal("Batch size expected to be positive; was ",
412                               batch->num_tasks());
413     }
414 
415     std::vector<int64> task_sizes_plus_optional_padding;
416     task_sizes_plus_optional_padding.reserve(batch->num_tasks());
417     for (int i = 0; i < batch->num_tasks(); ++i) {
418       task_sizes_plus_optional_padding.push_back(batch->task(i).size());
419     }
420     const int padding_size =
421         RoundToLowestAllowedBatchSize(batch->size()) - batch->size();
422     if (padding_size > 0) {
423       task_sizes_plus_optional_padding.push_back(padding_size);
424     }
425 
426     // For each output tensor name, a divided-up tensor with one entry per task.
427     std::map<string, std::vector<Tensor>> split_tensors;
428 
429     DCHECK_EQ(batch->task(0).context->num_outputs(), combined_outputs.size());
430     if (combined_outputs.size() != batch->task(0).context->num_outputs()) {
431       return errors::Internal("Wrong number of batched output tensors");
432     }
433 
434     // Generate 'split_tensors' and populate the context outputs.
435     for (int i = 0; i < combined_outputs.size(); ++i) {
436       const Tensor& output_tensor = combined_outputs[i];
437       if (output_tensor.shape().dims() == 0) {
438         return errors::FailedPrecondition(
439             "Batched output tensor has 0 dimensions");
440       }
441       if (output_tensor.shape().dim_size(0) != batch->size() + padding_size) {
442         return errors::FailedPrecondition(
443             "Batched output tensor's 0th dimension does not equal the sum of "
444             "the 0th dimension sizes of the input tensors");
445       }
446 
447       std::vector<Tensor> split_tensor;
448       const Status split_status = tensor::Split(
449           output_tensor, task_sizes_plus_optional_padding, &split_tensor);
450       DCHECK(split_status.ok()) << split_status.ToString();
451       if (!split_status.ok()) {
452         return errors::Internal("Tensor split operation failed: ",
453                                 split_status.ToString());
454       }
455       DCHECK_EQ(split_tensor.size(), task_sizes_plus_optional_padding.size());
456       if (split_tensor.size() != task_sizes_plus_optional_padding.size()) {
457         return errors::Internal(
458             "Tensor split operation did not work as expected; got ",
459             split_tensor.size(), " splits; expected ",
460             task_sizes_plus_optional_padding.size());
461       }
462 
463       for (int j = 0; j < batch->num_tasks(); ++j) {
464         BatchTask& task = *(batch->mutable_task(j));
465         task.context->set_output(i, split_tensor.at(j));
466       }  // (Ignore a possible final split_tensors entry containing the
467          // padding.)
468     }
469 
470     return Status::OK();
471   }
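  // Illustrative sketch (hypothetical, for exposition only): the split sizes
  // used above are the per-task sizes plus one optional trailing entry for the
  // padding, which is never handed back to any task. E.g. task sizes {2, 3}
  // rounded up to a batch of 8 produce split sizes {2, 3, 3}; pieces 0 and 1
  // go to tasks 0 and 1 and the final piece is discarded.
  static std::vector<int64> SplitSizesWithPaddingSketch(
      const std::vector<int64>& task_sizes, int64 padded_batch_size) {
    std::vector<int64> sizes = task_sizes;
    int64 total = 0;
    for (const int64 size : task_sizes) total += size;
    if (padded_batch_size > total) {
      sizes.push_back(padded_batch_size - total);  // Trailing padding entry.
    }
    return sizes;
  }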
472 
473   void ProcessFuncBatch(std::unique_ptr<Batch> batch) const {
474     if (batch->empty()) {
475       return;
476     }
477 
478     OpKernelContext* last_task_context =
479         batch->task(batch->num_tasks() - 1).context;
480 
481     // Regardless of the outcome, we need to propagate the status to the
482     // individual tasks and signal that they are done. We use MakeCleanup() to
483     // ensure that this happens no matter how we exit the method below.
484     Status status;
485     bool cleanup_done = false;
486     auto cleanup_fn = [&cleanup_done, &batch](const Status& status) {
487       if (cleanup_done) {
488         return;
489       }
490       for (int i = 0; i < batch->num_tasks(); ++i) {
491         batch->mutable_task(i)->context->SetStatus(status);
492         batch->mutable_task(i)->done_callback();
493       }
494       cleanup_done = true;
495     };
496     auto finally =
497         gtl::MakeCleanup([&cleanup_fn, &status] { cleanup_fn(status); });
498 
499     status = ValidateBatch(*batch);
500     if (!status.ok()) {
501       return;
502     }
503 
504     std::vector<Tensor> concatenated_tensors;
505     status =
506         ConcatInputTensors(*batch, last_task_context, &concatenated_tensors);
507     if (!status.ok()) {
508       return;
509     }
510     FunctionLibraryRuntime::Options opts;
511     opts.step_id = last_task_context->step_id();
512     opts.step_container = last_task_context->step_container();
513     opts.cancellation_manager = last_task_context->cancellation_manager();
514     opts.stats_collector = last_task_context->stats_collector();
515     opts.rendezvous = last_task_context->rendezvous();
516     opts.runner = last_task_context->runner();
517 
518     auto* flib = last_task_context->function_library();
519     std::vector<Tensor> combined_outputs;
520     Notification done;
521     std::vector<Tensor> args(concatenated_tensors.begin(),
522                              concatenated_tensors.end());
523     const auto& captured_inputs =
524         batch->task(batch->num_tasks() - 1).captured_inputs;
525     args.insert(args.end(), captured_inputs.begin(), captured_inputs.end());
526 
527     // Release the cleanup here, because the function library runtime's
528     // callback will take over responsibility for it from now on.
529     finally.release();
530     flib->Run(
531         opts, fhandle_, args, &combined_outputs, [&](const Status& run_status) {
532           Status final_status;
533           auto run_finally = gtl::MakeCleanup([&]() {
534             // We do the cleanup here as an optimization, so that it runs in
535             // the underlying TF inter-op threadpool. Running it in the
536             // threadpool lets the ensuing ops be scheduled faster,
537             // because the executor will add them to the front of the
538             // threadpool's task queue rather than the end.
539             cleanup_fn(final_status);
540             done.Notify();
541           });
542           final_status = run_status;
543           if (!final_status.ok()) {
544             return;
545           }
546           final_status = SplitOutputTensors(combined_outputs, batch.get());
547         });
548     // By waiting for the notification we are ensuring that this thread isn't
549     // used for processing other batches, which gives the batches time to
550     // coalesce upstream. So overall the number of batches going through the
551     // devices goes down, improving latency and throughput in most cases.
552     done.WaitForNotification();
553   }
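  // Illustrative sketch (added for exposition; not called anywhere): the
  // gtl::MakeCleanup() idiom above is a scope guard. It fires on every exit
  // path unless release() is called, which is why ProcessFuncBatch() releases
  // it just before handing responsibility to the function library runtime's
  // callback.
  static void ScopeGuardSketch(bool hand_off_to_callback) {
    bool finished = false;
    auto guard = gtl::MakeCleanup([&finished] { finished = true; });
    if (hand_off_to_callback) {
      guard.release();  // Disarm: someone else will mark completion.
      return;           // 'finished' stays false on this path.
    }
    // Otherwise the guard runs when it goes out of scope and sets 'finished'.
  }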
554 
555   // Processes a batch of one or more BatchTask entries.
556   void ProcessBatch(std::unique_ptr<Batch> batch) const {
557     if (batch->empty()) {
558       return;
559     }
560 
561     OpKernelContext* last_task_context =
562         batch->task(batch->num_tasks() - 1).context;
563     AsyncOpKernel::DoneCallback last_task_callback =
564         batch->task(batch->num_tasks() - 1).done_callback;
565 
566     OP_REQUIRES_OK_ASYNC(last_task_context, ValidateBatch(*batch),
567                          last_task_callback);
568 
569     // All tasks should have the same number of input edges.
570     const int num_input_edges = batch->task(0).inputs.size();
571     std::vector<Tensor> concatenated_tensors;
572     const Status concat_status =
573         ConcatInputTensors(*batch, last_task_context, &concatenated_tensors);
574     OP_REQUIRES_OK_ASYNC(last_task_context, concat_status, last_task_callback);
575 
576     // Process each input edge one at a time (the typical case has just one).
577     for (int i = 0; i < num_input_edges; ++i) {
578       last_task_context->set_output(i, concatenated_tensors.at(i));
579 
580       // Emit batch->num_tasks() - 1 empty output tensors.
581       for (int task_idx = 0; task_idx < batch->num_tasks() - 1; ++task_idx) {
582         const BatchTask& task = batch->task(task_idx);
583         TensorShape output_shape(task.inputs.at(i).shape());
584         output_shape.set_dim(0, 0);
585         Tensor* output = nullptr;
586         OP_REQUIRES_OK_ASYNC(
587             task.context,
588             task.context->allocate_output(i, output_shape, &output),
589             task.done_callback);
590       }
591     }
592     // Emit batch->num_tasks() - 1 empty index tensors.
593     for (int task_idx = 0; task_idx < batch->num_tasks() - 1; ++task_idx) {
594       const BatchTask& task = batch->task(task_idx);
595       TensorShape index_shape({0, 3});
596       Tensor* output = nullptr;
597       OP_REQUIRES_OK_ASYNC(
598           task.context,
599           task.context->allocate_output(num_input_edges, index_shape, &output),
600           task.done_callback);
601     }
602     // Emit all ID tensors.
603     for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) {
604       const BatchTask& task = batch->task(task_idx);
605       Tensor* id;
606       OP_REQUIRES_OK_ASYNC(task.context,
607                            task.context->allocate_output(num_input_edges + 1,
608                                                          TensorShape({}), &id),
609                            task.done_callback);
610       id->scalar<int64>()() = task.guid;
611     }
612     OP_REQUIRES_OK_ASYNC(
613         last_task_context,
614         EmitIndexTensor(last_task_context, *batch, num_input_edges),
615         last_task_callback);
616 
617     // Signal done for each element of the batch. (At this point, the contexts
618     // are no longer guaranteed to remain live.)
619     for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) {
620       batch->mutable_task(task_idx)->done_callback();
621     }
622   }
623 
624   // Emits an index tensor, which the Unbatch op will use to un-concatenate
625   // the tensor and attribute the pieces to the right batch keys. The index
626   // tensor contains, for each input: [batch_key, start_offset, end_offset]
627   // where start_offset and end_offset represent the range of entries in the
628   // concatenated tensors that belong to that input.
629   //
630   // Emits the result to the output at 'output_index' using 'context'.
631   static Status EmitIndexTensor(OpKernelContext* context, const Batch& batch,
632                                 int output_index) {
633     const TensorShape index_shape({batch.num_tasks(), 3});
634     Tensor* index = nullptr;
635     TF_RETURN_IF_ERROR(
636         context->allocate_output(output_index, index_shape, &index));
637     auto index_flat = index->shaped<int64, 2>({batch.num_tasks(), 3});
638     size_t offset = 0;
639     for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
640       const BatchTask& task = batch.task(task_idx);
641       index_flat(task_idx, 0) = task.guid;
642       index_flat(task_idx, 1) = offset;
643       index_flat(task_idx, 2) = offset + task.size();
644       offset += task.size();
645     }
646     return Status::OK();
647   }
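  // Illustrative sketch (hypothetical types, for exposition only): for tasks
  // of sizes 2 and 3 with guids g0 and g1, EmitIndexTensor() above produces
  // the rows {g0, 0, 2} and {g1, 2, 5}. Unbatch uses these half-open ranges to
  // slice the concatenated tensor back apart and attribute each slice to its
  // originating batch key.
  struct IndexRowSketch {
    int64 batch_key;
    int64 start_offset;
    int64 end_offset;
  };
  static std::vector<IndexRowSketch> BuildIndexRowsSketch(
      const std::vector<std::pair<int64, int64>>& guid_and_size) {
    std::vector<IndexRowSketch> rows;
    int64 offset = 0;
    for (const auto& entry : guid_and_size) {
      rows.push_back({entry.first, offset, offset + entry.second});
      offset += entry.second;
    }
    return rows;
  }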
648 
649   // Looks up the batcher queue for 'queue_name'. If it didn't previously exist,
650   // creates it.
651   Status LookupOrCreateBatcherQueue(const string& queue_name,
652                                     BatcherQueue** queue) {
653     mutex_lock l(batcher_queues_mu_);
654 
655     auto it = batcher_queues_.find(queue_name);
656     if (it != batcher_queues_.end()) {
657       *queue = it->second.get();
658       return Status::OK();
659     }
660 
661     std::unique_ptr<BatcherQueue> new_queue;
662     auto process_batch_callback = [this](std::unique_ptr<Batch> batch) {
663       if (fhandle_ == kInvalidHandle) {
664         ProcessBatch(std::move(batch));
665       } else {
666         ProcessFuncBatch(std::move(batch));
667       }
668     };
669     TF_RETURN_IF_ERROR(batcher_->AddQueue(batcher_queue_options_,
670                                           process_batch_callback, &new_queue));
671     *queue = new_queue.get();
672     batcher_queues_[queue_name] = std::move(new_queue);
673     return Status::OK();
674   }
675 
676   // A batch scheduler, and options for creating queues.
677   std::shared_ptr<Batcher> batcher_;
678   Batcher::QueueOptions batcher_queue_options_;
679 
680   // A collection of batcher queues, keyed on queue name.
681   // TODO(olston): Garbage-collect unused queues (perhaps simply remove empty
682   // ones (with a time delay?); it's okay if they get recreated later).
683   mutable mutex batcher_queues_mu_;
684   std::map<string, std::unique_ptr<BatcherQueue>> batcher_queues_
685       GUARDED_BY(batcher_queues_mu_);
686 
687   std::vector<int32> allowed_batch_sizes_;
688   FunctionLibraryRuntime::Handle fhandle_;
689 };
690 
691 class BatchFunctionKernel : public AsyncOpKernel {
692  public:
693   explicit BatchFunctionKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
694     OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
695     OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
696     // If shared_name is not supplied, use name instead (prevent collisions by
697     // default).
698     if (shared_name_.empty()) {
699       shared_name_ = name();
700     }
701     OP_REQUIRES_OK(c, c->GetAttr("batching_queue", &batcher_queue_));
702     OP_REQUIRES_OK(c, c->GetAttr("num_batch_threads", &num_batch_threads_));
703     OP_REQUIRES_OK(c, c->GetAttr("max_batch_size", &max_batch_size_));
704     OP_REQUIRES_OK(c,
705                    c->GetAttr("batch_timeout_micros", &batch_timeout_micros_));
706     OP_REQUIRES_OK(c,
707                    c->GetAttr("max_enqueued_batches", &max_enqueued_batches_));
708     OP_REQUIRES_OK(c, c->GetAttr("allowed_batch_sizes", &allowed_batch_sizes_));
709     OP_REQUIRES_OK(c, ValidateAllowedBatchSizes());
710 
711     auto lib = c->function_library();
712     OP_REQUIRES(c, lib != nullptr, errors::Internal("No function library"));
713     NameAttrList func;
714     OP_REQUIRES_OK(c, c->GetAttr("f", &func));
715     OP_REQUIRES_OK(
716         c, lib->Instantiate(func.name(), AttrSlice(&func.attr()), &fhandle_));
717   }
718 
719   bool IsExpensive() override { return false; }
720 
721   void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
722     BatchResource* br;
723     std::function<Status(BatchResource**)> creator = [this](BatchResource** r) {
724       std::unique_ptr<BatchResource> new_resource;
725       TF_RETURN_IF_ERROR(
726           BatchResource::Create(num_batch_threads_, max_batch_size_,
727                                 batch_timeout_micros_, max_enqueued_batches_,
728                                 allowed_batch_sizes_, fhandle_, &new_resource));
729       *r = new_resource.release();
730       return Status::OK();
731     };
732     OP_REQUIRES_OK_ASYNC(c,
733                          c->resource_manager()->LookupOrCreate(
734                              container_, shared_name_, &br, creator),
735                          done);
736     const Status status =
737         br->RegisterInput(random::New64(), c, batcher_queue_, done);
738     br->Unref();
739     OP_REQUIRES_OK_ASYNC(c, status, done);
740     // Assume br calls done, so nothing to do here.
741   }
742 
743   // Validates 'allowed_batch_sizes_'. The entries must increase monotonically,
744   // and the last one must equal 'max_batch_size_'.
745   Status ValidateAllowedBatchSizes() const {
746     if (allowed_batch_sizes_.empty()) {
747       return Status::OK();
748     }
749     int32 last_size = 0;
750     for (size_t i = 0; i < allowed_batch_sizes_.size(); ++i) {
751       const int32 size = allowed_batch_sizes_.at(i);
752       if (i > 0 && size <= last_size) {
753         return errors::InvalidArgument(
754             "allowed_batch_sizes entries must be monotonically increasing");
755       }
756       if (i == allowed_batch_sizes_.size() - 1 && size != max_batch_size_) {
757         return errors::InvalidArgument(
758             "final entry in allowed_batch_sizes must equal max_batch_size");
759       }
760       last_size = size;
761     }
762     return Status::OK();
763   }
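  // Illustrative examples (added for exposition; not executed): with
  // max_batch_size == 8, allowed_batch_sizes == {2, 4, 8} passes the check
  // above and an empty list passes trivially, while {4, 2, 8} fails (not
  // increasing) and {2, 4} fails (last entry differs from max_batch_size).
  // A standalone boolean model of the same rule:
  static bool AllowedBatchSizesValidSketch(const std::vector<int32>& sizes,
                                           int32 max_batch_size) {
    if (sizes.empty()) return true;
    for (size_t i = 0; i < sizes.size(); ++i) {
      if (i > 0 && sizes[i] <= sizes[i - 1]) return false;
      if (i == sizes.size() - 1 && sizes[i] != max_batch_size) return false;
    }
    return true;
  }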
764 
765  private:
766   string container_;
767   string shared_name_;
768   string batcher_queue_;
769   int32 num_batch_threads_;
770   int32 max_batch_size_;
771   int32 batch_timeout_micros_;
772   int32 max_enqueued_batches_;
773   std::vector<int32> allowed_batch_sizes_;
774   FunctionLibraryRuntime::Handle fhandle_;
775 };
776 
777 REGISTER_KERNEL_BUILDER(Name("BatchFunction").Device(DEVICE_CPU),
778                         BatchFunctionKernel);
779 
780 class BatchKernel : public AsyncOpKernel {
781  public:
782   explicit BatchKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
783     OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
784     OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
785     // If shared_name is not supplied, use name instead (prevent collisions by
786     // default).
787     if (shared_name_.empty()) {
788       shared_name_ = name();
789     }
790     OP_REQUIRES_OK(c, c->GetAttr("batching_queue", &batcher_queue_));
791     OP_REQUIRES_OK(c, c->GetAttr("num_batch_threads", &num_batch_threads_));
792     OP_REQUIRES_OK(c, c->GetAttr("max_batch_size", &max_batch_size_));
793     OP_REQUIRES_OK(c,
794                    c->GetAttr("batch_timeout_micros", &batch_timeout_micros_));
795     OP_REQUIRES_OK(c,
796                    c->GetAttr("max_enqueued_batches", &max_enqueued_batches_));
797     OP_REQUIRES_OK(c, c->GetAttr("allowed_batch_sizes", &allowed_batch_sizes_));
798     OP_REQUIRES_OK(c, ValidateAllowedBatchSizes());
799   }
800 
801   void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
802     BatchResource* br;
803     std::function<Status(BatchResource**)> creator = [this](BatchResource** r) {
804       std::unique_ptr<BatchResource> new_resource;
805       TF_RETURN_IF_ERROR(BatchResource::Create(
806           num_batch_threads_, max_batch_size_, batch_timeout_micros_,
807           max_enqueued_batches_, allowed_batch_sizes_, kInvalidHandle,
808           &new_resource));
809       *r = new_resource.release();
810       return Status::OK();
811     };
812     OP_REQUIRES_OK_ASYNC(c,
813                          c->resource_manager()->LookupOrCreate(
814                              container_, shared_name_, &br, creator),
815                          done);
816     const Status status =
817         br->RegisterInput(random::New64(), c, batcher_queue_, done);
818     br->Unref();
819     OP_REQUIRES_OK_ASYNC(c, status, done);
820     // Assume br calls done, so nothing to do here.
821   }
822 
823   // Validates 'allowed_batch_sizes_'. The entries must increase monotonically,
824   // and the last one must equal 'max_batch_size_'.
825   Status ValidateAllowedBatchSizes() const {
826     if (allowed_batch_sizes_.empty()) {
827       return Status::OK();
828     }
829     int32 last_size = 0;
830     for (size_t i = 0; i < allowed_batch_sizes_.size(); ++i) {
831       const int32 size = allowed_batch_sizes_.at(i);
832       if (i > 0 && size <= last_size) {
833         return errors::InvalidArgument(
834             "allowed_batch_sizes entries must be monotonically increasing");
835       }
836       if (i == allowed_batch_sizes_.size() - 1 && size != max_batch_size_) {
837         return errors::InvalidArgument(
838             "final entry in allowed_batch_sizes must equal max_batch_size");
839       }
840       last_size = size;
841     }
842     return Status::OK();
843   }
844 
845  private:
846   string container_;
847   string shared_name_;
848   string batcher_queue_;
849   int32 num_batch_threads_;
850   int32 max_batch_size_;
851   int32 batch_timeout_micros_;
852   int32 max_enqueued_batches_;
853   std::vector<int32> allowed_batch_sizes_;
854 };
855 
856 REGISTER_KERNEL_BUILDER(Name("Batch").Device(DEVICE_CPU), BatchKernel);
857 
858 // A class encapsulating the state and logic for unbatching tensors.
859 //
860 // UnbatchResource keeps two data structures indexed by batch-key: one which has
861 // the continuations for all concurrent kernels which are waiting for tensors
862 // and another which has tensors which are waiting for their corresponding
863 // kernels to run. Whenever a kernel runs, we either grab its tensor if it's
864 // waiting already, or we insert it in the queue and then look at its tensor to
865 // see if it can be used to dispatch any stored continuations.
866 class UnbatchResource : public ResourceBase {
867  public:
868   explicit UnbatchResource(int32 timeout_micros)
869       : timeout_micros_(timeout_micros),
870         timeout_enforcer_(new serving::PeriodicFunction(
871             [this] { EnforceTimeout(); }, 1000 /* 1 ms */)) {}
872 
873   ~UnbatchResource() override {
874     // Tear down 'timeout_enforcer_' first, since it accesses other state in
875     // this class.
876     timeout_enforcer_ = nullptr;
877   }
878 
879   string DebugString() const final { return "UnbatchResource"; }
880 
881   Status Compute(OpKernelContext* context, AsyncOpKernel::DoneCallback done) {
882     const Tensor& data_t = context->input(0);
883     const Tensor& batch_index_t = context->input(1);
884 
885     if (batch_index_t.shape().dim_size(0) > data_t.shape().dim_size(0)) {
886       return errors::InvalidArgument(
887           "Wrong shape for index tensor. Expected 0th dimension size to be no "
888           "greater than ",
889           data_t.shape().dim_size(0),
890           "; Got: ", batch_index_t.shape().dim_size(0), ".");
891     }
892     if (batch_index_t.shape().dim_size(1) != 3) {
893       return errors::InvalidArgument(
894           "Wrong shape for index tensor. Expected 1st dimension size to be 3 ; "
895           "Got: ",
896           batch_index_t.shape().dim_size(1), ".");
897     }
898 
899     const int64 batch_key = context->input(2).scalar<int64>()();
900     const bool nonempty_input = batch_index_t.dim_size(0) > 0;
901 
902     // If we have a non-empty tensor, slice it up.
903     // (It is important to do this outside of the critical section below.)
904     // The following variables are populated iff 'nonempty_input==true'.
905     std::vector<int64> sizes;
906     std::vector<int64> batch_keys;
907     std::vector<Tensor> split_inputs;
908     if (nonempty_input) {
909       auto batch_indices =
910           batch_index_t.shaped<int64, 2>({batch_index_t.dim_size(0), 3});
911       for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
912         sizes.push_back(batch_indices(i, 2) - batch_indices(i, 1));
913         batch_keys.push_back(batch_indices(i, 0));
914       }
915 
916       const DataType type = data_t.dtype();
917       switch (type) {
918 #define CASE(type)                                                          \
919   case DataTypeToEnum<type>::value:                                         \
920     TF_RETURN_IF_ERROR(Split<type>(context, data_t, sizes, &split_inputs)); \
921     break;
922         TF_CALL_ALL_TYPES(CASE);
923 #undef CASE
924         default:
925           return errors::InvalidArgument("Unsupported data type: ", type);
926       }
927     }
928 
929     // Critical section.
930     std::vector<AsyncOpKernel::DoneCallback> done_callbacks_to_call;
931     Status status = [&]() -> Status {
932       mutex_lock ml(mu_);
933 
934       // Check to see whether the tensor we want is already ready.
935       auto tensor_it = waiting_tensors_.find(batch_key);
936       if (tensor_it != waiting_tensors_.end()) {
937         context->set_output(0, tensor_it->second.tensor);
938         waiting_tensors_.erase(tensor_it);
939         done_callbacks_to_call.push_back(done);
940         return Status::OK();
941       }
942 
943       const uint64 deadline_micros =
944           Env::Default()->NowMicros() + timeout_micros_;
945 
946       // Add ourselves to the waitlist for tensors.
947       if (!waiting_callbacks_
948                .emplace(batch_key,
949                         WaitingCallback{deadline_micros, context, done})
950                .second) {
951         return errors::AlreadyExists(
952             "Multiple session runs with the same batch key.");
953       }
954 
955       // If we have a non-empty tensor, finish the waitlisted runs,
956       // and store any remaining pieces.
957       if (nonempty_input) {
958         for (size_t i = 0; i < batch_keys.size(); ++i) {
959           auto runs_it = waiting_callbacks_.find(batch_keys[i]);
960           if (runs_it != waiting_callbacks_.end()) {
961             runs_it->second.context->set_output(0, split_inputs[i]);
962             done_callbacks_to_call.push_back(runs_it->second.done);
963             waiting_callbacks_.erase(runs_it);
964           } else {
965             // Note: the deadline here is in case we are arriving late and the
966             // kernel that should rendezvous with this tensor has already waited
967             // and timed out.
968             if (!waiting_tensors_
969                      .emplace(batch_keys[i],
970                               WaitingTensor{deadline_micros, split_inputs[i]})
971                      .second) {
972               return errors::AlreadyExists(
973                   "Multiple tensors returned for same batch key.");
974             }
975           }
976         }
977       }
978 
979       return Status::OK();
980     }();
981 
982     for (const AsyncOpKernel::DoneCallback& done_callback :
983          done_callbacks_to_call) {
984       done_callback();
985     }
986 
987     return status;
988   }
989 
990  private:
991   // Evicts waiting tensors and callbacks that have exceeded their deadline.
992   void EnforceTimeout() {
993     const uint64 now = Env::Default()->NowMicros();
994     std::vector<WaitingCallback> evicted_callbacks;
995 
996     {
997       mutex_lock ml(mu_);
998 
999       for (auto it = waiting_tensors_.begin(); it != waiting_tensors_.end();) {
1000         const WaitingTensor& waiting_tensor = it->second;
1001         if (waiting_tensor.deadline_micros < now) {
1002           it = waiting_tensors_.erase(it);
1003         } else {
1004           ++it;
1005         }
1006       }
1007 
1008       for (auto it = waiting_callbacks_.begin();
1009            it != waiting_callbacks_.end();) {
1010         const WaitingCallback& waiting_callback = it->second;
1011         if (waiting_callback.deadline_micros < now) {
1012           evicted_callbacks.push_back(waiting_callback);
1013           it = waiting_callbacks_.erase(it);
1014         } else {
1015           ++it;
1016         }
1017       }
1018     }
1019 
1020     for (const WaitingCallback& evicted_callback : evicted_callbacks) {
1021       evicted_callback.context->CtxFailureWithWarning(errors::DeadlineExceeded(
1022           "Batched data did not arrive within timeout window."));
1023       evicted_callback.done();
1024     }
1025   }
1026 
1027   struct WaitingTensor {
1028     uint64 deadline_micros;
1029     Tensor tensor;
1030   };
1031 
1032   struct WaitingCallback {
1033     uint64 deadline_micros;
1034     OpKernelContext* context;
1035     AsyncOpKernel::DoneCallback done;
1036   };
1037 
1038   const int32 timeout_micros_;
1039 
1040   mutex mu_;
1041 
1042   // Maps keyed by BatchKey of tensors waiting for callbacks and callbacks
1043   // waiting for tensors.
1044   std::unordered_map<int64, WaitingTensor> waiting_tensors_ GUARDED_BY(mu_);
1045   std::unordered_map<int64, WaitingCallback> waiting_callbacks_ GUARDED_BY(mu_);
1046 
1047   // A thread that evicts waiting tensors and callbacks that have exceeded their
1048   // deadline.
1049   std::unique_ptr<serving::PeriodicFunction> timeout_enforcer_;
1050 };
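// Illustrative sketch (hypothetical helper, not part of the original file):
// the rendezvous implemented by UnbatchResource boils down to two maps keyed
// by batch key. An arriving kernel either finds its tensor already waiting and
// finishes immediately, or parks its done-callback until the tensor arrives
// (or the timeout enforcer evicts it). A minimal single-threaded model of one
// step (the real class holds mu_ while doing this):
inline bool UnbatchRendezvousSketch(
    int64 batch_key, std::unordered_map<int64, Tensor>* waiting_tensors,
    std::unordered_map<int64, std::function<void()>>* waiting_callbacks,
    std::function<void()> on_ready) {
  auto it = waiting_tensors->find(batch_key);
  if (it != waiting_tensors->end()) {
    waiting_tensors->erase(it);  // Tensor already arrived: consume and finish.
    on_ready();
    return true;
  }
  // Otherwise park the callback; a later arrival (or the timeout) resolves it.
  (*waiting_callbacks)[batch_key] = std::move(on_ready);
  return false;
}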
1051 
1052 class UnbatchKernel : public AsyncOpKernel {
1053  public:
1054   explicit UnbatchKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
1055     OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
1056     OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
1057     // If shared_name is not supplied, use name instead (prevent collisions by
1058     // default).
1059     if (shared_name_.empty()) {
1060       shared_name_ = name();
1061     }
1062     OP_REQUIRES_OK(c, c->GetAttr("timeout_micros", &timeout_micros_));
1063   }
1064 
1065   void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
1066     UnbatchResource* ubr;
1067     std::function<Status(UnbatchResource**)> creator =
1068         [this](UnbatchResource** r) {
1069           *r = new UnbatchResource(timeout_micros_);
1070           return Status::OK();
1071         };
1072     OP_REQUIRES_OK_ASYNC(c,
1073                          c->resource_manager()->LookupOrCreate(
1074                              container_, shared_name_, &ubr, creator),
1075                          done);
1076     auto status = ubr->Compute(c, done);
1077     ubr->Unref();
1078     OP_REQUIRES_OK_ASYNC(c, status, done);
1079     // Assume ubr calls done, so nothing to do here.
1080   }
1081 
1082  private:
1083   string container_;
1084   string shared_name_;
1085   int32 timeout_micros_;
1086 };
1087 REGISTER_KERNEL_BUILDER(Name("Unbatch").Device(DEVICE_CPU), UnbatchKernel);
1088 
1089 // A class encapsulating the state and logic for batching tensors
1090 // deterministically for the gradient of unbatch.
1091 class UnbatchGradResource : public ResourceBase {
1092  public:
1093   UnbatchGradResource() {}
1094 
1095   string DebugString() const final { return "UnbatchGradResource"; }
1096 
1097   // Flushes the information for one batch, given its context and done
1098   // callback. Clears all information about it from the available_tensors_.
1099   Status OutputBatch(OpKernelContext* context,
1100                      const AsyncOpKernel::DoneCallback& done)
1101       EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1102     const Tensor& batch_index_t = context->input(1);
1103     auto batch_index =
1104         batch_index_t.shaped<int64, 2>({batch_index_t.dim_size(0), 3});
1105     std::vector<Tensor> tensors;
1106     for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
1107       auto available_it = available_tensors_.find(batch_index(i, 0));
1108       if (available_it == available_tensors_.end()) {
1109         return errors::Internal("bad bookkeeping of available tensors.");
1110       }
1111       tensors.push_back(available_it->second);
1112       available_tensors_.erase(available_it);
1113     }
1114 
1115     const DataType type = tensors[0].dtype();
1116     Tensor concatenated_tensor;
1117     switch (type) {
1118 #define CASE(type)                                                            \
1119   case DataTypeToEnum<type>::value:                                           \
1120     TF_RETURN_IF_ERROR(Concat<type>(context, tensors, &concatenated_tensor)); \
1121     context->set_output(0, concatenated_tensor);                              \
1122     break;
1123       TF_CALL_ALL_TYPES(CASE);
1124 #undef CASE
1125       default:
1126         return errors::InvalidArgument("Unsupported data type: ", type);
1127     }
1128     done();
1129     return Status::OK();
1130   }
1131 
1132   // Ingests data from one invocation of the op.
1133   Status Compute(OpKernelContext* context,
1134                  const AsyncOpKernel::DoneCallback& done) {
1135     const Tensor& data_t = context->input(0);
1136     const Tensor& batch_index_t = context->input(1);
1137     const Tensor& grad_t = context->input(2);
1138 
1139     mutex_lock ml(mu_);
1140 
1141     const int64 batch_key = context->input(3).scalar<int64>()();
1142     // Mark our tensor as available.
1143     if (!available_tensors_.emplace(batch_key, grad_t).second) {
1144       return errors::InvalidArgument("Two runs with the same batch key.");
1145     }
1146 
1147     // Check whether we have a valid input tensor and, if so, create its
1148     // dispatch logic.
1149     if (data_t.NumElements() > 0) {
1150       if (batch_index_t.NumElements() == 0) {
1151         return errors::InvalidArgument(
1152             "batch_index is empty while the tensor isn't.");
1153       }
1154       std::unordered_set<int64> missing_tensors;
1155       const auto batch_index =
1156           batch_index_t.shaped<int64, 2>({batch_index_t.dim_size(0), 3});
1157       for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
1158         const int64 batch_key = batch_index(i, 0);
1159         if (available_tensors_.find(batch_key) == available_tensors_.end()) {
1160           missing_tensors.emplace(batch_key);
1161         }
1162       }
1163       if (missing_tensors.empty()) {
1164         return OutputBatch(context, done);
1165       }
1166       if (!available_batches_
1167                .emplace(batch_key, Batch{missing_tensors, context, done})
1168                .second) {
1169         return errors::InvalidArgument(
1170             "Batch key with valid batch used twice.");
1171       }
1172       for (const int64 i : missing_tensors) {
1173         if (!desired_tensor_to_batch_map_.emplace(i, batch_key).second) {
1174           return errors::InvalidArgument(
1175               "Missing tensor wanted by more than one batch.");
1176         }
1177       }
1178     } else {
1179       // If we don't have a valid input tensor we can output an empty tensor and
1180       // call our done closure.
1181       TensorShape output_shape(grad_t.shape());
1182       output_shape.set_dim(0, 0);
1183       Tensor* output = nullptr;
1184       TF_RETURN_IF_ERROR(context->allocate_output(0, output_shape, &output));
1185       done();
1186     }
1187 
1188     // Search to see whether our tensor is desired by any existing batch.
1189     auto desire_it = desired_tensor_to_batch_map_.find(batch_key);
1190     if (desire_it != desired_tensor_to_batch_map_.end()) {
1191       // Mark our tensor as no longer missing.
1192       auto batch_it = available_batches_.find(desire_it->second);
1193       desired_tensor_to_batch_map_.erase(desire_it);
1194       if (batch_it == available_batches_.end()) {
1195         return errors::InvalidArgument("Batch no longer exists.");
1196       }
1197       batch_it->second.missing_tensors.erase(batch_key);
1198       // If all tensors are available we should concatenate them and dispatch
1199       // the batch.
1200       if (batch_it->second.missing_tensors.empty()) {
1201         TF_RETURN_IF_ERROR(
1202             OutputBatch(batch_it->second.context, batch_it->second.done));
1203         available_batches_.erase(batch_it);
1204       }
1205     }
1206     return Status::OK();
1207   }
1208 
1209  private:
1210   mutex mu_;
1211 
1212   // Represents a still-incomplete batch of tensors. When all tensors become
1213   // available they will be concatenated in the right order and sent through the
1214   // context.
1215   struct Batch {
1216     // Batch keys for tensors which are still missing from this batch. When this
1217     // is empty the Tensors can be concatenated and forwarded.
1218     std::unordered_set<int64> missing_tensors;
1219 
1220     // Context and callback for the session responsible for finishing this
1221     // batch.
1222     OpKernelContext* context;
1223     AsyncOpKernel::DoneCallback done;
1224   };
1225 
1226   // Map from batch key of the session which will output the batched gradients
1227   // to still-incomplete batches.
1228   std::unordered_map<int64, Batch> available_batches_;
1229 
1230   // Map from batch key to tensors which are waiting for their batches to be
1231   // available.
1232   std::unordered_map<int64, Tensor> available_tensors_;
1233 
1234   // Map from batch key of a tensor which is not yet available to the batch key
1235   // of the batch to which it belongs.
1236   std::unordered_map<int64, int64> desired_tensor_to_batch_map_;
1237 };
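// Illustrative sketch (hypothetical helper, not part of the original file):
// UnbatchGradResource waits until every batch key listed in a batch's index
// tensor has a gradient available, then concatenates them in index order. A
// minimal model of the "is this batch complete yet?" check that the class
// tracks via 'missing_tensors' and 'desired_tensor_to_batch_map_':
inline bool BatchGradCompleteSketch(
    const std::vector<int64>& required_batch_keys,
    const std::unordered_map<int64, Tensor>& available_tensors) {
  for (const int64 key : required_batch_keys) {
    if (available_tensors.find(key) == available_tensors.end()) {
      return false;  // At least one gradient has not arrived yet.
    }
  }
  return true;
}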
1238 
1239 class UnbatchGradKernel : public AsyncOpKernel {
1240  public:
1241   explicit UnbatchGradKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
1242     OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
1243     OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
1244     // If shared_name is not supplied, use name instead (prevent collisions by
1245     // default).
1246     if (shared_name_.empty()) {
1247       shared_name_ = name();
1248     }
1249   }
1250 
1251   void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
1252     UnbatchGradResource* ubr;
1253     std::function<Status(UnbatchGradResource**)> creator =
1254         [](UnbatchGradResource** r) {
1255           *r = new UnbatchGradResource();
1256           return Status::OK();
1257         };
1258     OP_REQUIRES_OK_ASYNC(c,
1259                          c->resource_manager()->LookupOrCreate(
1260                              container_, shared_name_, &ubr, creator),
1261                          done);
1262     Status status = ubr->Compute(c, done);
1263     ubr->Unref();
1264     OP_REQUIRES_OK_ASYNC(c, status, done);
1265     // Assume ubr calls done, so nothing to do here.
1266   }
1267 
1268  private:
1269   string container_;
1270   string shared_name_;
1271 };
1272 REGISTER_KERNEL_BUILDER(Name("UnbatchGrad").Device(DEVICE_CPU),
1273                         UnbatchGradKernel);
1274 
1275 }  // namespace tensorflow
1276