1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/function.h"
17 #include "tensorflow/core/framework/op_kernel.h"
18 #include "tensorflow/core/framework/register_types.h"
19 #include "tensorflow/core/framework/resource_mgr.h"
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/framework/tensor_util.h"
22 #include "tensorflow/core/framework/types.h"
23 #include "tensorflow/core/kernels/batching_util/periodic_function.h"
24 #include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"
25 #include "tensorflow/core/kernels/concat_lib.h"
26 #include "tensorflow/core/kernels/ops_util.h"
27 #include "tensorflow/core/kernels/split_lib.h"
28 #include "tensorflow/core/lib/gtl/cleanup.h"
29 #include "tensorflow/core/lib/random/random.h"
30 #include "tensorflow/core/platform/context.h"
31 #include "tensorflow/core/platform/errors.h"
32 #include "tensorflow/core/platform/macros.h"
33 
34 namespace tensorflow {
35 
36 typedef Eigen::ThreadPoolDevice CPUDevice;
37 typedef Eigen::GpuDevice GPUDevice;
38 #ifdef TENSORFLOW_USE_SYCL
39 typedef Eigen::SyclDevice SYCLDevice;
40 #endif  // TENSORFLOW_USE_SYCL
41 
42 // Concatenates 'inputs' into a single tensor along the zeroth dimension.
43 // Requires that all elements of 'inputs' have element type T. Writes the
44 // result to '*output', using 'context' for the allocation to ensure proper
45 // device placement.
46 template <typename T>
47 Status Concat(OpKernelContext* context, const gtl::ArraySlice<Tensor>& inputs,
48               Tensor* output) {
49   const int input_dims = inputs[0].dims();
50   const TensorShape& input_shape = inputs[0].shape();
51 
52   // Note that we reduce the concat of k-dimensional tensors into a
53   // two-dimensional concat. Assuming the dimensions of any input tensor are
54   // {y0, y1,...,ym-1}, we flatten it to {1, y}, where y = Prod_i(yi).
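  // For illustration (hypothetical tensors a of shape {2, 3} and b of shape
  // {4, 3}): a is viewed as a {1, 6} matrix and b as a {1, 12} matrix, the
  // flat pieces are concatenated along the second dimension, and the output of
  //   Concat<float>(context, {a, b}, &out);
  // has shape {6, 3} (internally viewed as {1, 18}).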
55   std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>> inputs_flat;
56   inputs_flat.reserve(inputs.size());
57   int64 output_dim0 = 0;
58   for (size_t i = 0; i < inputs.size(); ++i) {
59     const Tensor& input = inputs[i];
60     if (input.dims() != input_dims) {
61       return errors::InvalidArgument(
62           "Ranks of all input tensors should match: shape[0] = ",
63           input_shape.DebugString(), " vs. shape[", i,
64           "] = ", input.shape().DebugString());
65     }
66     for (int j = 1; j < input_dims; ++j) {
67       if (input.dim_size(j) != input_shape.dim_size(j)) {
68         return errors::InvalidArgument(
69             "Dimensions of inputs should match: shape[0] = ",
70             input_shape.DebugString(), " vs. shape[", i,
71             "] = ", input.shape().DebugString());
72       }
73     }
74     if (input.NumElements() > 0) {
75       inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
76           input.shaped<T, 2>({1, input.NumElements()})));
77     }
78     output_dim0 += input.dim_size(0);
79   }
80 
81   TensorShape output_shape(input_shape);
82   output_shape.set_dim(0, output_dim0);
83   TF_RETURN_IF_ERROR(
84       context->allocate_temp(DataTypeToEnum<T>::value, output_shape, output));
85   if (output->NumElements() > 0) {
86     auto output_flat = output->shaped<T, 2>({1, output->NumElements()});
87 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
88     (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
89     if (std::is_same<Device, GPUDevice>::value) {
90       ConcatGPU<T>(context, inputs_flat, output, &output_flat);
91       return Status::OK();
92     }
93 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
94     ConcatCPU<T>(context->device(), inputs_flat, &output_flat);
95   }
96 
97   return Status::OK();
98 }
99 
100 // The Split*() functions split 'input' with element type T into 'sizes.size()'
101 // tensors along the zeroth dimension, with the ith split having zeroth-
102 // dimension size 'sizes[i]'. They allocate the output tensors using 'context',
103 // for proper device placement.
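// For example, splitting an input of shape {6, 3} with sizes {2, 4} yields two
// tensors of shapes {2, 3} and {4, 3}; a hypothetical call
//   Split<float>(context, input, {2, 4}, &outputs);
// dispatches to one of the helpers below.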
104 
105 // Handles special cases that are cheap. Sets 'done==true' iff it found an
106 // applicable special case and wrote to the outputs. Otherwise acts as a no-op.
107 template <typename T>
108 Status SplitEasyCases(OpKernelContext* context, const Tensor& input,
109                       const gtl::ArraySlice<int64>& sizes,
110                       std::vector<Tensor>* outputs, bool* done) {
111   *done = false;
112 
113   int64 total_size = 0;
114   for (const int64 size : sizes) {
115     total_size += size;
116   }
117   if (total_size > input.shape().dim_size(0)) {
118     return errors::InvalidArgument(
119         "Sum of split sizes must not exceed dim0-size of input tensor");
120   }
121 
122   // Special case 0: trivial 1-way split.
123   if (sizes.size() == 1 && sizes.at(0) == input.shape().dim_size(0)) {
124     outputs->push_back(input);
125     *done = true;
126     return Status::OK();
127   }
128 
129   // Special case 1: input is aligned.
130   if (IsInnerDimsSizeAligned<T>(input.shape())) {
131     int64 position = 0;
132     for (const int64 size : sizes) {
133       outputs->emplace_back(input.Slice(position, position + size));
134       position += size;
135     }
136     *done = true;
137     return Status::OK();
138   }
139 
140   return Status::OK();
141 }
142 
143 // Handles the general case, on CPU.
144 template <typename T>
145 Status SplitCPU(OpKernelContext* context, const Tensor& input,
146                 const gtl::ArraySlice<int64>& sizes,
147                 std::vector<Tensor>* outputs) {
148   int64 suffix_dim_size = 1;
149   for (int i = 1; i < input.shape().dims(); ++i) {
150     suffix_dim_size *= input.shape().dim_size(i);
151   }
152   auto input_reshaped =
153       input.shaped<T, 2>({input.shape().dim_size(0), suffix_dim_size});
154 
155   int64 position = 0;
156   for (const int64 size : sizes) {
157     TensorShape output_shape = input.shape();
158     output_shape.set_dim(0, size);
159     Tensor output;
160     TF_RETURN_IF_ERROR(
161         context->allocate_temp(input.dtype(), output_shape, &output));
162     auto output_shaped = output.shaped<T, 2>({size, suffix_dim_size});
163 
164     Eigen::DSizes<Eigen::DenseIndex, 2> slice_indices{position, 0};
165     Eigen::DSizes<Eigen::DenseIndex, 2> slice_sizes{size, suffix_dim_size};
166     functor::Split<CPUDevice, T, 2>()(context->eigen_device<CPUDevice>(),
167                                       output_shaped, input_reshaped,
168                                       slice_indices, slice_sizes);
169 
170     outputs->emplace_back(output);
171 
172     position += size;
173   }
174 
175   return Status::OK();
176 }
177 
178 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
179     (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
180 
181 // Handles the general case, on GPU.
182 template <typename T>
183 Status SplitGPU(OpKernelContext* context, const Tensor& input,
184                 const gtl::ArraySlice<int64>& sizes,
185                 std::vector<Tensor>* outputs) {
186   // TODO(olston, apassos): Implement this.
187   LOG(FATAL) << "Not yet implemented";  // Crash ok
188 }
189 
190 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
191 
192 // The outer function that dispatches to the various Split*() functions above.
193 template <typename T>
194 Status Split(OpKernelContext* context, const Tensor& input,
195              const gtl::ArraySlice<int64>& sizes,
196              std::vector<Tensor>* outputs) {
197   bool easy_cases_done;
198   TF_RETURN_IF_ERROR(
199       SplitEasyCases<T>(context, input, sizes, outputs, &easy_cases_done));
200   if (easy_cases_done) {
201     return Status::OK();
202   }
203 
204 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
205     (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
206 // TODO(olston, apassos): Handle non-CPU cases.
207 // return SplitGPU<T>(context, input, sizes, outputs);
208 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
209   return SplitCPU<T>(context, input, sizes, outputs);
210 }
211 
212 // A class encapsulating the state and logic for batching tensors.
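// Rough usage sketch: a kernel creates or looks up a shared instance via
// BatchResource::Create(), then calls RegisterInput(guid, context,
// batcher_queue_name, done) once per op invocation; the resource enqueues the
// invocation's tensors and later runs ProcessBatch() or ProcessFuncBatch() on
// each full (or timed-out) batch.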
213 class BatchResource : public ResourceBase {
214  public:
215   static Status Create(int32 num_batch_threads, int32 max_batch_size,
216                        int32 batch_timeout_micros, int32 max_enqueued_batches,
217                        const std::vector<int32>& allowed_batch_sizes,
218                        FunctionLibraryRuntime::Handle fhandle,
219                        std::unique_ptr<BatchResource>* resource) {
220     std::unique_ptr<BatchResource> new_resource(new BatchResource);
221 
222     Batcher::Options batcher_options;
223     batcher_options.num_batch_threads = num_batch_threads;
224     TF_RETURN_IF_ERROR(
225         Batcher::Create(batcher_options, &new_resource->batcher_));
226 
227     new_resource->batcher_queue_options_.max_batch_size = max_batch_size;
228     new_resource->batcher_queue_options_.max_enqueued_batches =
229         max_enqueued_batches;
230     new_resource->batcher_queue_options_.batch_timeout_micros =
231         batch_timeout_micros;
232 
233     new_resource->allowed_batch_sizes_ = allowed_batch_sizes;
234 
235     new_resource->fhandle_ = fhandle;
236 
237     *resource = std::move(new_resource);
238     return Status::OK();
239   }
240 
241   string DebugString() const final { return "BatchResource"; }
242 
243   // Ingests data from one invocation of the batch op. The data is enqueued to
244   // be combined with others into a batch, asynchronously.
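  // For example, two concurrent invocations whose "in_tensors" have shapes
  // {3, 10} and {5, 10} may be combined into a single batch with a zeroth-
  // dimension size of 8 before the batched computation runs.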
245   Status RegisterInput(int64 guid, OpKernelContext* context,
246                        const string& batcher_queue_name,
247                        AsyncOpKernel::DoneCallback done_callback) {
248     std::unique_ptr<BatchTask> batch_components(new BatchTask);
249     batch_components->guid = guid;
250     batch_components->propagated_context = Context(ContextKind::kThread);
251     OpInputList tensors;
252     TF_RETURN_IF_ERROR(context->input_list("in_tensors", &tensors));
253     for (int i = 0; i < tensors.size(); ++i) {
254       const Tensor& tensor = tensors[i];
255       if (tensor.shape().dims() == 0) {
256         return errors::InvalidArgument(
257             "Batching input tensors must have at least one dimension");
258       }
259       if (tensors.size() >= 2 &&
260           tensor.shape().dim_size(0) != tensors[0].shape().dim_size(0)) {
261         return errors::InvalidArgument(
262             "Batching input tensors supplied in a given op invocation must "
263             "have equal 0th-dimension size");
264       }
265       batch_components->inputs.push_back(tensor);
266     }
267     OpInputList captured_tensors;
268     const auto captured_status =
269         context->input_list("captured_tensors", &captured_tensors);
270     if (captured_status.ok()) {
271       for (const Tensor& captured_tensor : captured_tensors) {
272         batch_components->captured_inputs.push_back(captured_tensor);
273       }
274     }
275     batch_components->context = context;
276     batch_components->done_callback = std::move(done_callback);
277 
278     BatcherQueue* batcher_queue;
279     TF_RETURN_IF_ERROR(
280         LookupOrCreateBatcherQueue(batcher_queue_name, &batcher_queue));
281     return batcher_queue->Schedule(&batch_components);
282   }
283 
284  private:
285   BatchResource() = default;
286 
287   // One input to be batched. Corresponds to one invocation of the batch op.
288   struct BatchTask : public serving::BatchTask {
289     // A unique ID to identify this invocation of Batch.
290     int64 guid;
291 
292     Context propagated_context;
293 
294     std::vector<Tensor> inputs;
295     std::vector<Tensor> captured_inputs;
296     OpKernelContext* context;
297     AsyncOpKernel::DoneCallback done_callback;
298 
299     size_t size() const override { return inputs[0].shape().dim_size(0); }
300   };
301 
302   using Batcher = serving::SharedBatchScheduler<BatchTask>;
303   using BatcherQueue = serving::BatchScheduler<BatchTask>;
304   using Batch = serving::Batch<BatchTask>;
305 
306   // Validates that it's legal to combine the tasks in 'batch' into a batch.
307   // Assumes the batch is non-empty.
308   static Status ValidateBatch(const Batch& batch) {
309     for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
310       const BatchTask& task = batch.task(task_idx);
311 
312       if (task.inputs.size() != batch.task(0).inputs.size()) {
313         return errors::InvalidArgument(
314             "Batching inputs must have equal number of edges");
315       }
316     }
317 
318     return Status::OK();
319   }
320 
321   // Returns the smallest entry in 'allowed_batch_sizes_' that is greater than
322   // or equal to 'batch_size'. If 'allowed_batch_sizes_' is empty, simply
323   // returns 'batch_size'.
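  // E.g., with allowed_batch_sizes_ = {2, 4, 8}: a batch_size of 3 rounds up
  // to 4, 8 stays 8, and 9 exceeds the largest allowed size, so it is logged
  // and returned unchanged.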
324   int RoundToLowestAllowedBatchSize(int batch_size) const {
325     if (allowed_batch_sizes_.empty()) {
326       return batch_size;
327     }
328     for (int allowed_size : allowed_batch_sizes_) {
329       if (allowed_size >= batch_size) {
330         return allowed_size;
331       }
332     }
333     LOG(ERROR) << "Maximum batch size greater than largest allowed size; "
334                   "ignoring allowed sizes constraint";
335     return batch_size;
336   }
337 
338   Status ConcatInputTensors(const Batch& batch, OpKernelContext* context,
339                             std::vector<Tensor>* concatenated_tensors) const {
340     if (batch.num_tasks() == 0) {
341       return errors::InvalidArgument("Empty batch.");
342     }
343 
344     const int padded_batch_size = RoundToLowestAllowedBatchSize(batch.size());
345     const int padding_amount = padded_batch_size - batch.size();
346 
347     // All tasks should have the same number of input edges.
348     const int num_inputs = batch.task(0).inputs.size();
349     concatenated_tensors->reserve(num_inputs);
350 
351     // Process each input one at a time (the typical case has just one).
352     for (int i = 0; i < num_inputs; ++i) {
353       // Concatenate the tasks' ith input tensors into a big output tensor.
354       std::vector<Tensor> to_concatenate;
355       to_concatenate.reserve(batch.num_tasks());
356       for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
357         to_concatenate.push_back(batch.task(task_idx).inputs.at(i));
358       }
359 
360       // Add padding as needed. Use the first row of the first task's tensor as
361       // the data for padding.
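      // For instance, if the batch holds 5 rows and the padded size is 8, the
      // first row of the first task's ith tensor is appended 3 times.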
362       if (padding_amount > 0) {
363         const Tensor& padding_source = batch.task(0).inputs.at(i);
364         Tensor padding;
365         if (padding_source.shape().dim_size(0) == 0) {
366           return errors::InvalidArgument(
367               "Cannot use an empty tensor with zero rows as padding when "
368               "batching.");
369         }
370         if (padding_source.shape().dim_size(0) == 1) {
371           padding = padding_source;
372         } else {
373           padding = padding_source.Slice(0, 1);
374         }
375         for (int i = 0; i < padding_amount; ++i) {
376           to_concatenate.push_back(padding);
377         }
378       }
379 
380       const DataType type = to_concatenate[0].dtype();
381       Status concat_status;
382       Tensor concatenated_tensor;
383       switch (type) {
384 #define CASE(type)                                                   \
385   case DataTypeToEnum<type>::value:                                  \
386     concat_status =                                                  \
387         Concat<type>(context, to_concatenate, &concatenated_tensor); \
388     break;
389         TF_CALL_ALL_TYPES(CASE);
390 #undef CASE
391         default:
392           concat_status =
393               errors::InvalidArgument("Unsupported data type: ", type);
394           break;
395       }
396       TF_RETURN_IF_ERROR(concat_status);
397       concatenated_tensors->push_back(concatenated_tensor);
398     }
399     return Status::OK();
400   }
401 
402   Status SplitOutputTensors(const std::vector<Tensor>& combined_outputs,
403                             Batch* batch) const {
404     DCHECK_GE(batch->num_tasks(), 1);
405     if (batch->num_tasks() < 1) {
406       return errors::Internal("Batch size expected to be positive; was ",
407                               batch->num_tasks());
408     }
409 
410     std::vector<int64> task_sizes_plus_optional_padding;
411     task_sizes_plus_optional_padding.reserve(batch->num_tasks());
412     for (int i = 0; i < batch->num_tasks(); ++i) {
413       task_sizes_plus_optional_padding.push_back(batch->task(i).size());
414     }
415     const int padding_size =
416         RoundToLowestAllowedBatchSize(batch->size()) - batch->size();
417     if (padding_size > 0) {
418       task_sizes_plus_optional_padding.push_back(padding_size);
419     }
420 
421     // For each output tensor name, a divided-up tensor with one entry per task.
422     std::map<string, std::vector<Tensor>> split_tensors;
423 
424     DCHECK_EQ(batch->task(0).context->num_outputs(), combined_outputs.size());
425     if (combined_outputs.size() != batch->task(0).context->num_outputs()) {
426       return errors::Internal("Wrong number of batched output tensors");
427     }
428 
429     // Generate 'split_tensors' and populate the context outputs.
430     for (int i = 0; i < combined_outputs.size(); ++i) {
431       const Tensor& output_tensor = combined_outputs[i];
432       if (output_tensor.shape().dims() == 0) {
433         return errors::FailedPrecondition(
434             "Batched output tensor has 0 dimensions");
435       }
436       if (output_tensor.shape().dim_size(0) != batch->size() + padding_size) {
437         return errors::FailedPrecondition(
438             "Batched output tensor's 0th dimension does not equal the sum of "
439             "the 0th dimension sizes of the input tensors");
440       }
441 
442       std::vector<Tensor> split_tensor;
443       const Status split_status = tensor::Split(
444           output_tensor, task_sizes_plus_optional_padding, &split_tensor);
445       DCHECK(split_status.ok()) << split_status.ToString();
446       if (!split_status.ok()) {
447         return errors::Internal("Tensor split operation failed: ",
448                                 split_status.ToString());
449       }
450       DCHECK_EQ(split_tensor.size(), task_sizes_plus_optional_padding.size());
451       if (split_tensor.size() != task_sizes_plus_optional_padding.size()) {
452         return errors::Internal(
453             "Tensor split operation did not work as expected; got ",
454             split_tensor.size(), " splits; expected ",
455             task_sizes_plus_optional_padding.size());
456       }
457 
458       for (int j = 0; j < batch->num_tasks(); ++j) {
459         BatchTask& task = *(batch->mutable_task(j));
460         task.context->set_output(i, split_tensor.at(j));
461       }  // (Ignore a possible final split_tensors entry containing the
462          // padding.)
463     }
464 
465     return Status::OK();
466   }
467 
468   void ProcessFuncBatch(std::unique_ptr<Batch> batch) const {
469     if (batch->empty()) {
470       return;
471     }
472 
473     // We use the 'propagated_context' from one of the threads that set up one
474     // of the tasks. This propagates any common context over all the threads
475     // that are running this Session, of which this BatchOp is a part.
476     WithContext wc(batch->task(batch->num_tasks() - 1).propagated_context);
477 
478     OpKernelContext* last_task_context =
479         batch->task(batch->num_tasks() - 1).context;
480 
481     // Regardless of the outcome, we need to propagate the status to the
482     // individual tasks and signal that they are done. We use MakeCleanup() to
483     // ensure that this happens no matter how we exit the method below.
484     Status status;
485     bool cleanup_done = false;
486     auto cleanup_fn = [&cleanup_done, &batch](const Status& status) {
487       if (cleanup_done) {
488         return;
489       }
490       for (int i = 0; i < batch->num_tasks(); ++i) {
491         batch->mutable_task(i)->context->SetStatus(status);
492         batch->mutable_task(i)->done_callback();
493       }
494       cleanup_done = true;
495     };
496     auto finally =
497         gtl::MakeCleanup([&cleanup_fn, &status] { cleanup_fn(status); });
498 
499     status = ValidateBatch(*batch);
500     if (!status.ok()) {
501       return;
502     }
503 
504     std::vector<Tensor> concatenated_tensors;
505     status =
506         ConcatInputTensors(*batch, last_task_context, &concatenated_tensors);
507     if (!status.ok()) {
508       return;
509     }
510     FunctionLibraryRuntime::Options opts;
511     opts.step_container = last_task_context->step_container();
512     opts.cancellation_manager = last_task_context->cancellation_manager();
513     opts.stats_collector = last_task_context->stats_collector();
514     opts.rendezvous = last_task_context->rendezvous();
515     opts.runner = last_task_context->runner();
516 
517     auto* flib = last_task_context->function_library();
518     std::vector<Tensor> combined_outputs;
519     Notification done;
520     std::vector<Tensor> args(concatenated_tensors.begin(),
521                              concatenated_tensors.end());
522     const auto& captured_inputs =
523         batch->task(batch->num_tasks() - 1).captured_inputs;
524     args.insert(args.end(), captured_inputs.begin(), captured_inputs.end());
525 
526     // Release the cleanup here, because the function library runtime's
527     // callback will handle it from now on.
528     finally.release();
529     flib->Run(
530         opts, fhandle_, args, &combined_outputs, [&](const Status& run_status) {
531           Status final_status;
532           auto run_finally = gtl::MakeCleanup([&]() {
533             // We do the cleanup here as an optimization, so that it runs in
534             // the underlying TF inter-op threadpool. Running it in the
535             // threadpool lets the ensuing ops be scheduled faster,
536             // because the executor will add them to the front of the
537             // threadpool's task queue rather than the end.
538             cleanup_fn(final_status);
539             done.Notify();
540           });
541           final_status = run_status;
542           if (!final_status.ok()) {
543             return;
544           }
545           final_status = SplitOutputTensors(combined_outputs, batch.get());
546         });
547     // By waiting for the notification we are ensuring that this thread isn't
548     // used for processing other batches, which gives the batches time to
549     // coalesce upstream. So overall the number of batches going through the
550     // devices goes down, improving latency and throughput in most cases.
551     done.WaitForNotification();
552   }
553 
554   // Processes a batch of one or more BatchTask entries.
555   void ProcessBatch(std::unique_ptr<Batch> batch) const {
556     if (batch->empty()) {
557       return;
558     }
559 
560     WithContext wc(batch->task(batch->num_tasks() - 1).propagated_context);
561 
562     OpKernelContext* last_task_context =
563         batch->task(batch->num_tasks() - 1).context;
564     AsyncOpKernel::DoneCallback last_task_callback =
565         batch->task(batch->num_tasks() - 1).done_callback;
566 
567     OP_REQUIRES_OK_ASYNC(last_task_context, ValidateBatch(*batch),
568                          last_task_callback);
569 
570     // All tasks should have the same number of input edges.
571     const int num_input_edges = batch->task(0).inputs.size();
572     std::vector<Tensor> concatenated_tensors;
573     const Status concat_status =
574         ConcatInputTensors(*batch, last_task_context, &concatenated_tensors);
575     OP_REQUIRES_OK_ASYNC(last_task_context, concat_status, last_task_callback);
576 
577     // Process each input edge one at a time (the typical case has just one).
578     for (int i = 0; i < num_input_edges; ++i) {
579       last_task_context->set_output(i, concatenated_tensors.at(i));
580 
581       // Emit batch->num_tasks() - 1 empty output tensors.
582       for (int task_idx = 0; task_idx < batch->num_tasks() - 1; ++task_idx) {
583         const BatchTask& task = batch->task(task_idx);
584         TensorShape output_shape(task.inputs.at(i).shape());
585         output_shape.set_dim(0, 0);
586         Tensor* output = nullptr;
587         OP_REQUIRES_OK_ASYNC(
588             task.context,
589             task.context->allocate_output(i, output_shape, &output),
590             task.done_callback);
591       }
592     }
593     // Emit batch->num_tasks() - 1 empty index tensors.
594     for (int task_idx = 0; task_idx < batch->num_tasks() - 1; ++task_idx) {
595       const BatchTask& task = batch->task(task_idx);
596       TensorShape index_shape({0, 3});
597       Tensor* output = nullptr;
598       OP_REQUIRES_OK_ASYNC(
599           task.context,
600           task.context->allocate_output(num_input_edges, index_shape, &output),
601           task.done_callback);
602     }
603     // Emit all ID tensors.
604     for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) {
605       const BatchTask& task = batch->task(task_idx);
606       Tensor* id;
607       OP_REQUIRES_OK_ASYNC(task.context,
608                            task.context->allocate_output(num_input_edges + 1,
609                                                          TensorShape({}), &id),
610                            task.done_callback);
611       id->scalar<int64>()() = task.guid;
612     }
613     OP_REQUIRES_OK_ASYNC(
614         last_task_context,
615         EmitIndexTensor(last_task_context, *batch, num_input_edges),
616         last_task_callback);
617 
618     // Signal done for each element of the batch. (At this point, the contexts
619     // are no longer guaranteed to remain live.)
620     for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) {
621       batch->mutable_task(task_idx)->done_callback();
622     }
623   }
624 
625   // Emits an index tensor, which the Unbatch op will use to un-concatenate
626   // the tensor and attribute the pieces to the right batch keys. The index
627   // tensor contains, for each input: [batch_key, start_offset, end_offset]
628   // where start_offset and end_offset represent the range of entries in the
629   // concatenated tensors that belong to that input.
630   //
631   // Emits the result to the output at 'output_index' using 'context'.
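  // For example, a batch of three tasks with sizes {2, 3, 1} and guids
  // {g0, g1, g2} yields the index tensor
  //   [[g0, 0, 2],
  //    [g1, 2, 5],
  //    [g2, 5, 6]].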
632   static Status EmitIndexTensor(OpKernelContext* context, const Batch& batch,
633                                 int output_index) {
634     const TensorShape index_shape({batch.num_tasks(), 3});
635     Tensor* index = nullptr;
636     TF_RETURN_IF_ERROR(
637         context->allocate_output(output_index, index_shape, &index));
638     auto index_flat = index->shaped<int64, 2>({batch.num_tasks(), 3});
639     size_t offset = 0;
640     for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
641       const BatchTask& task = batch.task(task_idx);
642       index_flat(task_idx, 0) = task.guid;
643       index_flat(task_idx, 1) = offset;
644       index_flat(task_idx, 2) = offset + task.size();
645       offset += task.size();
646     }
647     return Status::OK();
648   }
649 
650   // Looks up the batcher queue for 'queue_name'. If it didn't previously exist,
651   // creates it.
652   Status LookupOrCreateBatcherQueue(const string& queue_name,
653                                     BatcherQueue** queue) {
654     mutex_lock l(batcher_queues_mu_);
655 
656     auto it = batcher_queues_.find(queue_name);
657     if (it != batcher_queues_.end()) {
658       *queue = it->second.get();
659       return Status::OK();
660     }
661 
662     std::unique_ptr<BatcherQueue> new_queue;
663     auto process_batch_callback = [this](std::unique_ptr<Batch> batch) {
664       if (fhandle_ == kInvalidHandle) {
665         ProcessBatch(std::move(batch));
666       } else {
667         ProcessFuncBatch(std::move(batch));
668       }
669     };
670     TF_RETURN_IF_ERROR(batcher_->AddQueue(batcher_queue_options_,
671                                           process_batch_callback, &new_queue));
672     *queue = new_queue.get();
673     batcher_queues_[queue_name] = std::move(new_queue);
674     return Status::OK();
675   }
676 
677   // A batch scheduler, and options for creating queues.
678   std::shared_ptr<Batcher> batcher_;
679   Batcher::QueueOptions batcher_queue_options_;
680 
681   // A collection of batcher queues, keyed on queue name.
682   // TODO(olston): Garbage-collect unused queues (perhaps simply remove empty
683   // ones (with a time delay?); it's okay if they get recreated later).
684   mutable mutex batcher_queues_mu_;
685   std::map<string, std::unique_ptr<BatcherQueue>> batcher_queues_
686       GUARDED_BY(batcher_queues_mu_);
687 
688   std::vector<int32> allowed_batch_sizes_;
689   FunctionLibraryRuntime::Handle fhandle_;
690 };
691 
692 class BatchFunctionKernel : public AsyncOpKernel {
693  public:
694   explicit BatchFunctionKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
695     OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
696     OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
697     // If shared_name is not supplied, use name instead (prevent collisions by
698     // default).
699     if (shared_name_.empty()) {
700       shared_name_ = name();
701     }
702     OP_REQUIRES_OK(c, c->GetAttr("batching_queue", &batcher_queue_));
703     OP_REQUIRES_OK(c, c->GetAttr("num_batch_threads", &num_batch_threads_));
704     OP_REQUIRES_OK(c, c->GetAttr("max_batch_size", &max_batch_size_));
705     OP_REQUIRES_OK(c,
706                    c->GetAttr("batch_timeout_micros", &batch_timeout_micros_));
707     OP_REQUIRES_OK(c,
708                    c->GetAttr("max_enqueued_batches", &max_enqueued_batches_));
709     OP_REQUIRES_OK(c, c->GetAttr("allowed_batch_sizes", &allowed_batch_sizes_));
710     OP_REQUIRES_OK(c, ValidateAllowedBatchSizes());
711 
712     auto lib = c->function_library();
713     OP_REQUIRES(c, lib != nullptr, errors::Internal("No function library"));
714     NameAttrList func;
715     OP_REQUIRES_OK(c, c->GetAttr("f", &func));
716     OP_REQUIRES_OK(
717         c, lib->Instantiate(func.name(), AttrSlice(&func.attr()), &fhandle_));
718   }
719 
720   bool IsExpensive() override { return false; }
721 
722   void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
723     BatchResource* br;
724     std::function<Status(BatchResource**)> creator = [this](BatchResource** r) {
725       std::unique_ptr<BatchResource> new_resource;
726       TF_RETURN_IF_ERROR(
727           BatchResource::Create(num_batch_threads_, max_batch_size_,
728                                 batch_timeout_micros_, max_enqueued_batches_,
729                                 allowed_batch_sizes_, fhandle_, &new_resource));
730       *r = new_resource.release();
731       return Status::OK();
732     };
733     OP_REQUIRES_OK_ASYNC(c,
734                          c->resource_manager()->LookupOrCreate(
735                              container_, shared_name_, &br, creator),
736                          done);
737     const Status status =
738         br->RegisterInput(random::New64(), c, batcher_queue_, done);
739     br->Unref();
740     OP_REQUIRES_OK_ASYNC(c, status, done);
741     // Assume br calls done, so nothing to do here.
742   }
743 
744   // Validates 'allowed_batch_sizes_'. The entries must increase monotonically,
745   // and the last one must equal 'max_batch_size_'.
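  // For example, with max_batch_size_ = 8, {2, 4, 8} is accepted, while
  // {4, 2, 8} (not increasing) and {2, 4} (last entry != max_batch_size_) are
  // rejected.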
746   Status ValidateAllowedBatchSizes() const {
747     if (allowed_batch_sizes_.empty()) {
748       return Status::OK();
749     }
750     int32 last_size = 0;
751     for (size_t i = 0; i < allowed_batch_sizes_.size(); ++i) {
752       const int32 size = allowed_batch_sizes_.at(i);
753       if (i > 0 && size <= last_size) {
754         return errors::InvalidArgument(
755             "allowed_batch_sizes entries must be monotonically increasing");
756       }
757       if (i == allowed_batch_sizes_.size() - 1 && size != max_batch_size_) {
758         return errors::InvalidArgument(
759             "final entry in allowed_batch_sizes must equal max_batch_size");
760       }
761       last_size = size;
762     }
763     return Status::OK();
764   }
765 
766  private:
767   string container_;
768   string shared_name_;
769   string batcher_queue_;
770   int32 num_batch_threads_;
771   int32 max_batch_size_;
772   int32 batch_timeout_micros_;
773   int32 max_enqueued_batches_;
774   std::vector<int32> allowed_batch_sizes_;
775   FunctionLibraryRuntime::Handle fhandle_;
776 };
777 
778 REGISTER_KERNEL_BUILDER(Name("BatchFunction").Device(DEVICE_CPU),
779                         BatchFunctionKernel);
780 
781 class BatchKernel : public AsyncOpKernel {
782  public:
783   explicit BatchKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
784     OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
785     OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
786     // If shared_name is not supplied, use name instead (prevent collisions by
787     // default).
788     if (shared_name_.empty()) {
789       shared_name_ = name();
790     }
791     OP_REQUIRES_OK(c, c->GetAttr("batching_queue", &batcher_queue_));
792     OP_REQUIRES_OK(c, c->GetAttr("num_batch_threads", &num_batch_threads_));
793     OP_REQUIRES_OK(c, c->GetAttr("max_batch_size", &max_batch_size_));
794     OP_REQUIRES_OK(c,
795                    c->GetAttr("batch_timeout_micros", &batch_timeout_micros_));
796     OP_REQUIRES_OK(c,
797                    c->GetAttr("max_enqueued_batches", &max_enqueued_batches_));
798     OP_REQUIRES_OK(c, c->GetAttr("allowed_batch_sizes", &allowed_batch_sizes_));
799     OP_REQUIRES_OK(c, ValidateAllowedBatchSizes());
800   }
801 
802   void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
803     BatchResource* br;
804     std::function<Status(BatchResource**)> creator = [this](BatchResource** r) {
805       std::unique_ptr<BatchResource> new_resource;
806       TF_RETURN_IF_ERROR(BatchResource::Create(
807           num_batch_threads_, max_batch_size_, batch_timeout_micros_,
808           max_enqueued_batches_, allowed_batch_sizes_, kInvalidHandle,
809           &new_resource));
810       *r = new_resource.release();
811       return Status::OK();
812     };
813     OP_REQUIRES_OK_ASYNC(c,
814                          c->resource_manager()->LookupOrCreate(
815                              container_, shared_name_, &br, creator),
816                          done);
817     const Status status =
818         br->RegisterInput(random::New64(), c, batcher_queue_, done);
819     br->Unref();
820     OP_REQUIRES_OK_ASYNC(c, status, done);
821     // Assume br calls done, so nothing to do here.
822   }
823 
824   // Validates 'allowed_batch_sizes_'. The entries must increase monotonically,
825   // and the last one must equal 'max_batch_size_'.
826   Status ValidateAllowedBatchSizes() const {
827     if (allowed_batch_sizes_.empty()) {
828       return Status::OK();
829     }
830     int32 last_size = 0;
831     for (size_t i = 0; i < allowed_batch_sizes_.size(); ++i) {
832       const int32 size = allowed_batch_sizes_.at(i);
833       if (i > 0 && size <= last_size) {
834         return errors::InvalidArgument(
835             "allowed_batch_sizes entries must be monotonically increasing");
836       }
837       if (i == allowed_batch_sizes_.size() - 1 && size != max_batch_size_) {
838         return errors::InvalidArgument(
839             "final entry in allowed_batch_sizes must equal max_batch_size");
840       }
841       last_size = size;
842     }
843     return Status::OK();
844   }
845 
846  private:
847   string container_;
848   string shared_name_;
849   string batcher_queue_;
850   int32 num_batch_threads_;
851   int32 max_batch_size_;
852   int32 batch_timeout_micros_;
853   int32 max_enqueued_batches_;
854   std::vector<int32> allowed_batch_sizes_;
855 };
856 
857 REGISTER_KERNEL_BUILDER(Name("Batch").Device(DEVICE_CPU), BatchKernel);
858 
859 // A class encapsulating the state and logic for unbatching tensors.
860 //
861 // UnbatchResource keeps two data structures indexed by batch key: one holds
862 // the continuations of all concurrent kernels that are waiting for tensors,
863 // and the other holds tensors that are waiting for their corresponding
864 // kernels to run. Whenever a kernel runs, we either grab its tensor if it is
865 // already waiting, or we add the kernel to the waitlist and then check whether
866 // its input tensor can be used to dispatch any stored continuations.
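// For example, if an Unbatch kernel asks for batch key 7 before the batched
// tensor containing that key has arrived, its continuation is parked in
// waiting_callbacks_; when the tensor arrives later (possibly via another
// kernel invocation), the matching slice is emitted and the parked callback
// runs. Conversely, slices that arrive before their kernels are parked in
// waiting_tensors_ until the kernel shows up or the timeout fires.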
867 class UnbatchResource : public ResourceBase {
868  public:
869   explicit UnbatchResource(int32 timeout_micros)
870       : timeout_micros_(timeout_micros),
871         timeout_enforcer_(new serving::PeriodicFunction(
872             [this] { EnforceTimeout(); }, 1000 /* 1 ms */)) {}
873 
874   ~UnbatchResource() override {
875     // Tear down 'timeout_enforcer_' first, since it accesses other state in
876     // this class.
877     timeout_enforcer_ = nullptr;
878   }
879 
880   string DebugString() const final { return "UnbatchResource"; }
881 
882   Status Compute(OpKernelContext* context, AsyncOpKernel::DoneCallback done) {
883     const Tensor& data_t = context->input(0);
884     const Tensor& batch_index_t = context->input(1);
885 
886     if (batch_index_t.shape().dim_size(0) > data_t.shape().dim_size(0)) {
887       return errors::InvalidArgument(
888           "Wrong shape for index tensor. Expected 0th dimension size to be no "
889           "greater than ",
890           data_t.shape().dim_size(0),
891           "; Got: ", batch_index_t.shape().dim_size(0), ".");
892     }
893     if (batch_index_t.shape().dim_size(1) != 3) {
894           "Wrong shape for index tensor. Expected 1st dimension size to be 3; "
895           "Wrong shape for index tensor. Expected 1st dimension size to be 3 ; "
896           "Got: ",
897           batch_index_t.shape().dim_size(1), ".");
898     }
899 
900     const int64 batch_key = context->input(2).scalar<int64>()();
901     const bool nonempty_input = batch_index_t.dim_size(0) > 0;
902 
903     // If we have a non-empty tensor, slice it up.
904     // (It is important to do this outside of the critical section below.)
905     // The following variables are populated iff 'nonempty_input==true'.
906     std::vector<int64> sizes;
907     std::vector<int64> batch_keys;
908     std::vector<Tensor> split_inputs;
909     if (nonempty_input) {
910       auto batch_indices =
911           batch_index_t.shaped<int64, 2>({batch_index_t.dim_size(0), 3});
912       for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
913         sizes.push_back(batch_indices(i, 2) - batch_indices(i, 1));
914         batch_keys.push_back(batch_indices(i, 0));
915       }
916 
917       const DataType type = data_t.dtype();
918       switch (type) {
919 #define CASE(type)                                                          \
920   case DataTypeToEnum<type>::value:                                         \
921     TF_RETURN_IF_ERROR(Split<type>(context, data_t, sizes, &split_inputs)); \
922     break;
923         TF_CALL_ALL_TYPES(CASE);
924 #undef CASE
925         default:
926           return errors::InvalidArgument("Unsupported data type: ", type);
927       }
928     }
929 
930     // Critical section.
931     std::vector<AsyncOpKernel::DoneCallback> done_callbacks_to_call;
932     Status status = [&]() -> Status {
933       mutex_lock ml(mu_);
934 
935       // Check to see whether the tensor we want is already ready.
936       auto tensor_it = waiting_tensors_.find(batch_key);
937       if (tensor_it != waiting_tensors_.end()) {
938         context->set_output(0, tensor_it->second.tensor);
939         waiting_tensors_.erase(tensor_it);
940         done_callbacks_to_call.push_back(done);
941         return Status::OK();
942       }
943 
944       const uint64 deadline_micros =
945           Env::Default()->NowMicros() + timeout_micros_;
946 
947       // Add ourselves to the waitlist for tensors.
948       if (!waiting_callbacks_
949                .emplace(batch_key,
950                         WaitingCallback{deadline_micros, context, done})
951                .second) {
952         return errors::AlreadyExists(
953             "Multiple session runs with the same batch key.");
954       }
955 
956       // If we have a non-empty tensor, finish the waitlisted runs,
957       // and store any remaining pieces.
958       if (nonempty_input) {
959         for (size_t i = 0; i < batch_keys.size(); ++i) {
960           auto runs_it = waiting_callbacks_.find(batch_keys[i]);
961           if (runs_it != waiting_callbacks_.end()) {
962             runs_it->second.context->set_output(0, split_inputs[i]);
963             done_callbacks_to_call.push_back(runs_it->second.done);
964             waiting_callbacks_.erase(runs_it);
965           } else {
966             // Note: the deadline here is in case we are arriving late and the
967             // kernel that should rendezvous with this tensor has already waited
968             // and timed out.
969             if (!waiting_tensors_
970                      .emplace(batch_keys[i],
971                               WaitingTensor{deadline_micros, split_inputs[i]})
972                      .second) {
973               return errors::AlreadyExists(
974                   "Multiple tensors returned for same batch key.");
975             }
976           }
977         }
978       }
979 
980       return Status::OK();
981     }();
982 
983     for (const AsyncOpKernel::DoneCallback& done_callback :
984          done_callbacks_to_call) {
985       done_callback();
986     }
987 
988     return status;
989   }
990 
991  private:
992   // Evicts waiting tensors and callbacks that have exceeded their deadline.
993   void EnforceTimeout() {
994     const uint64 now = Env::Default()->NowMicros();
995     std::vector<WaitingCallback> evicted_callbacks;
996 
997     {
998       mutex_lock ml(mu_);
999 
1000       for (auto it = waiting_tensors_.begin(); it != waiting_tensors_.end();) {
1001         const WaitingTensor& waiting_tensor = it->second;
1002         if (waiting_tensor.deadline_micros < now) {
1003           it = waiting_tensors_.erase(it);
1004         } else {
1005           ++it;
1006         }
1007       }
1008 
1009       for (auto it = waiting_callbacks_.begin();
1010            it != waiting_callbacks_.end();) {
1011         const WaitingCallback& waiting_callback = it->second;
1012         if (waiting_callback.deadline_micros < now) {
1013           evicted_callbacks.push_back(waiting_callback);
1014           it = waiting_callbacks_.erase(it);
1015         } else {
1016           ++it;
1017         }
1018       }
1019     }
1020 
1021     for (const WaitingCallback& evicted_callback : evicted_callbacks) {
1022       evicted_callback.context->CtxFailureWithWarning(errors::DeadlineExceeded(
1023           "Batched data did not arrive within timeout window."));
1024       evicted_callback.done();
1025     }
1026   }
1027 
1028   struct WaitingTensor {
1029     uint64 deadline_micros;
1030     Tensor tensor;
1031   };
1032 
1033   struct WaitingCallback {
1034     uint64 deadline_micros;
1035     OpKernelContext* context;
1036     AsyncOpKernel::DoneCallback done;
1037   };
1038 
1039   const int32 timeout_micros_;
1040 
1041   mutex mu_;
1042 
1043   // Maps keyed by BatchKey of tensors waiting for callbacks and callbacks
1044   // waiting for tensors.
1045   std::unordered_map<int64, WaitingTensor> waiting_tensors_ GUARDED_BY(mu_);
1046   std::unordered_map<int64, WaitingCallback> waiting_callbacks_ GUARDED_BY(mu_);
1047 
1048   // A thread that evicts waiting tensors and callbacks that have exceeded their
1049   // deadline.
1050   std::unique_ptr<serving::PeriodicFunction> timeout_enforcer_;
1051 };
1052 
1053 class UnbatchKernel : public AsyncOpKernel {
1054  public:
1055   explicit UnbatchKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
1056     OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
1057     OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
1058     // If shared_name is not supplied, use name instead (prevent collisions by
1059     // default).
1060     if (shared_name_.empty()) {
1061       shared_name_ = name();
1062     }
1063     OP_REQUIRES_OK(c, c->GetAttr("timeout_micros", &timeout_micros_));
1064   }
1065 
1066   void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
1067     UnbatchResource* ubr;
1068     std::function<Status(UnbatchResource**)> creator =
1069         [this](UnbatchResource** r) {
1070           *r = new UnbatchResource(timeout_micros_);
1071           return Status::OK();
1072         };
1073     OP_REQUIRES_OK_ASYNC(c,
1074                          c->resource_manager()->LookupOrCreate(
1075                              container_, shared_name_, &ubr, creator),
1076                          done);
1077     auto status = ubr->Compute(c, done);
1078     ubr->Unref();
1079     OP_REQUIRES_OK_ASYNC(c, status, done);
1080     // Assume ubr calls done, so nothing to do here.
1081   }
1082 
1083  private:
1084   string container_;
1085   string shared_name_;
1086   int32 timeout_micros_;
1087 };
1088 REGISTER_KERNEL_BUILDER(Name("Unbatch").Device(DEVICE_CPU), UnbatchKernel);
1089 
1090 // A class encapsulating the state and logic for batching tensors
1091 // deterministically for the gradient of unbatch.
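// Each UnbatchGrad invocation supplies (data, batch_index, grad, id). The
// resource stores the per-key gradient slices in available_tensors_ and, once
// every key listed in a non-empty batch_index has arrived, concatenates them
// in index order and emits the batched gradient for that batch key.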
1092 class UnbatchGradResource : public ResourceBase {
1093  public:
1094   UnbatchGradResource() {}
1095 
1096   string DebugString() const final { return "UnbatchGradResource"; }
1097 
1098   // Flushes the information for one batch, given its context and done
1099   // callback. Clears all information about it from available_tensors_.
1100   Status OutputBatch(OpKernelContext* context,
1101                      const AsyncOpKernel::DoneCallback& done)
1102       EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1103     const Tensor& batch_index_t = context->input(1);
1104     auto batch_index =
1105         batch_index_t.shaped<int64, 2>({batch_index_t.dim_size(0), 3});
1106     std::vector<Tensor> tensors;
1107     for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
1108       auto available_it = available_tensors_.find(batch_index(i, 0));
1109       if (available_it == available_tensors_.end()) {
1110         return errors::Internal("bad bookkeeping of available tensors.");
1111       }
1112       tensors.push_back(available_it->second);
1113       available_tensors_.erase(available_it);
1114     }
1115 
1116     const DataType type = tensors[0].dtype();
1117     Tensor concatenated_tensor;
1118     switch (type) {
1119 #define CASE(type)                                                            \
1120   case DataTypeToEnum<type>::value:                                           \
1121     TF_RETURN_IF_ERROR(Concat<type>(context, tensors, &concatenated_tensor)); \
1122     context->set_output(0, concatenated_tensor);                              \
1123     break;
1124       TF_CALL_ALL_TYPES(CASE);
1125 #undef CASE
1126       default:
1127         return errors::InvalidArgument("Unsupported data type: ", type);
1128     }
1129     done();
1130     return Status::OK();
1131   }
1132 
1133   // Ingests data from one invocation of the op.
1134   Status Compute(OpKernelContext* context,
1135                  const AsyncOpKernel::DoneCallback& done) {
1136     const Tensor& data_t = context->input(0);
1137     const Tensor& batch_index_t = context->input(1);
1138     const Tensor& grad_t = context->input(2);
1139 
1140     mutex_lock ml(mu_);
1141 
1142     const int64 batch_key = context->input(3).scalar<int64>()();
1143     // Mark our tensor as available.
1144     if (!available_tensors_.emplace(batch_key, grad_t).second) {
1145       return errors::InvalidArgument("Two runs with the same batch key.");
1146     }
1147 
1148     // Check whether we have a valid input tensor and, if so, create its
1149     // dispatch logic.
1150     if (data_t.NumElements() > 0) {
1151       if (batch_index_t.NumElements() == 0) {
1152         return errors::InvalidArgument(
1153             "batch_index is empty while the tensor isn't.");
1154       }
1155       std::unordered_set<int64> missing_tensors;
1156       const auto batch_index =
1157           batch_index_t.shaped<int64, 2>({batch_index_t.dim_size(0), 3});
1158       for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
1159         const int64 batch_key = batch_index(i, 0);
1160         if (available_tensors_.find(batch_key) == available_tensors_.end()) {
1161           missing_tensors.emplace(batch_key);
1162         }
1163       }
1164       if (missing_tensors.empty()) {
1165         return OutputBatch(context, done);
1166       }
1167       if (!available_batches_
1168                .emplace(batch_key, Batch{missing_tensors, context, done})
1169                .second) {
1170         return errors::InvalidArgument(
1171             "Batch key with valid batch used twice.");
1172       }
1173       for (const int64 i : missing_tensors) {
1174         if (!desired_tensor_to_batch_map_.emplace(i, batch_key).second) {
1175           return errors::InvalidArgument(
1176               "Missing tensor wanted by more than one batch.");
1177         }
1178       }
1179     } else {
1180       // If we don't have a valid input tensor we can output an empty tensor and
1181       // call our done closure.
1182       TensorShape output_shape(grad_t.shape());
1183       output_shape.set_dim(0, 0);
1184       Tensor* output = nullptr;
1185       TF_RETURN_IF_ERROR(context->allocate_output(0, output_shape, &output));
1186       done();
1187     }
1188 
1189     // Search to see whether our tensor is desired by any existing batch.
1190     auto desire_it = desired_tensor_to_batch_map_.find(batch_key);
1191     if (desire_it != desired_tensor_to_batch_map_.end()) {
1192       // Mark our tensor as no longer missing.
1193       auto batch_it = available_batches_.find(desire_it->second);
1194       desired_tensor_to_batch_map_.erase(desire_it);
1195       if (batch_it == available_batches_.end()) {
1196         return errors::InvalidArgument("Batch no longer exists.");
1197       }
1198       batch_it->second.missing_tensors.erase(batch_key);
1199       // If all tensors are available we should concatenate them and dispatch
1200       // the batch.
1201       if (batch_it->second.missing_tensors.empty()) {
1202         TF_RETURN_IF_ERROR(
1203             OutputBatch(batch_it->second.context, batch_it->second.done));
1204         available_batches_.erase(batch_it);
1205       }
1206     }
1207     return Status::OK();
1208   }
1209 
1210  private:
1211   mutex mu_;
1212 
1213   // Represents a still-incomplete batch of tensors. When all tensors become
1214   // available they will be concatenated in the right order and sent through the
1215   // context.
1216   struct Batch {
1217     // Batch keys for tensors which are still missing from this batch. When this
1218     // is empty the Tensors can be concatenated and forwarded.
1219     std::unordered_set<int64> missing_tensors;
1220 
1221     // Context and callback for the session responsible for finishing this
1222     // batch.
1223     OpKernelContext* context;
1224     AsyncOpKernel::DoneCallback done;
1225   };
1226 
1227   // Map from batch key of the session which will output the batched gradients
1228   // to still-incomplete batches.
1229   std::unordered_map<int64, Batch> available_batches_;
1230 
1231   // Map from batch key to tensors which are waiting for their batches to be
1232   // available.
1233   std::unordered_map<int64, Tensor> available_tensors_;
1234 
1235   // Map from batch key of a tensor which is not yet available to the batch key
1236   // of the batch to which it belongs.
1237   std::unordered_map<int64, int64> desired_tensor_to_batch_map_;
1238 };
1239 
1240 class UnbatchGradKernel : public AsyncOpKernel {
1241  public:
1242   explicit UnbatchGradKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
1243     OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
1244     OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
1245     // If shared_name is not supplied, use name instead (prevent collisions by
1246     // default).
1247     if (shared_name_.empty()) {
1248       shared_name_ = name();
1249     }
1250   }
1251 
1252   void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
1253     UnbatchGradResource* ubr;
1254     std::function<Status(UnbatchGradResource**)> creator =
1255         [](UnbatchGradResource** r) {
1256           *r = new UnbatchGradResource();
1257           return Status::OK();
1258         };
1259     OP_REQUIRES_OK_ASYNC(c,
1260                          c->resource_manager()->LookupOrCreate(
1261                              container_, shared_name_, &ubr, creator),
1262                          done);
1263     Status status = ubr->Compute(c, done);
1264     ubr->Unref();
1265     OP_REQUIRES_OK_ASYNC(c, status, done);
1266     // Assume ubr calls done, so nothing to do here.
1267   }
1268 
1269  private:
1270   string container_;
1271   string shared_name_;
1272 };
1273 REGISTER_KERNEL_BUILDER(Name("UnbatchGrad").Device(DEVICE_CPU),
1274                         UnbatchGradKernel);
1275 
1276 }  // namespace tensorflow
1277