/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/batching_util/batch_resource_base.h"

#include <cstdlib>  // for std::getenv

#include "absl/time/time.h"
#include "absl/types/optional.h"
#include "tensorflow/core/common_runtime/cost_measurement_registry.h"
#include "tensorflow/core/framework/ops_util.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/kernels/batching_util/concat_split_util.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/percentile_sampler.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"
#include "tensorflow/core/util/incremental_barrier.h"

namespace tensorflow {
namespace serving {
namespace {

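// Returns the cost measurement type from the TF_COST_MEASUREMENT_TYPE
// environment variable, or nullptr if the variable is unset.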
const char* GetCostMeasurementType() {
  return std::getenv("TF_COST_MEASUREMENT_TYPE");
}

// TODO(b/181883417): Replace with RecordPaddingSizeV2.
void RecordPaddingSize(int32_t padding_size, const string& model_name,
                       int32_t execution_batch_size, const string& op_name) {
  static auto* cell = tensorflow::monitoring::PercentileSampler<3>::New(
      {"/tensorflow/serving/batching/padding_size",
       "Tracks the padding size distribution on batches by model_name (if "
       "available).",
       "model_name", "execution_batch_size", "op_name"},
      /*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0},
      /*max_samples=*/1024, tensorflow::monitoring::UnitOfMeasure::kNumber);
  cell->GetCell(model_name, absl::StrCat(execution_batch_size), op_name)
      ->Add(static_cast<double>(padding_size));
}

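// Same as RecordPaddingSize but exports the distribution through a histogram
// sampler instead of a percentile sampler.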
void RecordPaddingSizeV2(int32_t padding_size, const string& model_name,
                         int32_t execution_batch_size, const string& op_name) {
  static auto* cell = tensorflow::monitoring::Sampler<3>::New(
      {"/tensorflow/serving/batching/padding_size_v2",
       "Tracks the padding size distribution on batches by model_name (if "
       "available).",
       "model_name", "execution_batch_size", "op_name"},
      // It's 14 buckets with the last bucket being 2^13 to DBL_MAX;
      // so the limits are [1, 2, 4, 8, ..., 8 * 1024, DBL_MAX].
      monitoring::Buckets::Exponential(1, 2, 14));
  cell->GetCell(model_name, absl::StrCat(execution_batch_size), op_name)
      ->Add(static_cast<double>(padding_size));
}

// TODO(b/181883417): Replace with RecordInputBatchSizeV2.
void RecordInputBatchSize(int32_t batch_size, const string& model_name,
                          const string& op_name) {
  static auto* cell = tensorflow::monitoring::PercentileSampler<2>::New(
      {"/tensorflow/serving/batching/input_batch_size",
       "Tracks the batch size distribution on the inputs by model_name (if "
       "available).",
       "model_name", "op_name"},
      /*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0},
      /*max_samples=*/1024, tensorflow::monitoring::UnitOfMeasure::kNumber);
  cell->GetCell(model_name, op_name)->Add(static_cast<double>(batch_size));
}

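// Same as RecordInputBatchSize but exports the distribution through a
// histogram sampler instead of a percentile sampler.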
void RecordInputBatchSizeV2(int32_t batch_size, const string& model_name,
                            const string& op_name) {
  static auto* cell = tensorflow::monitoring::Sampler<2>::New(
      {"/tensorflow/serving/batching/input_batch_size_v2",
       "Tracks the batch size distribution on the inputs by model_name (if "
       "available).",
       "model_name", "op_name"},
      // It's 14 buckets with the last bucket being 2^13 to DBL_MAX;
      // so the limits are [1, 2, 4, 8, ..., 8 * 1024, DBL_MAX].
      monitoring::Buckets::Exponential(1, 2, 14));
  cell->GetCell(model_name, op_name)->Add(static_cast<double>(batch_size));
}

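// Records the distribution of processed batch sizes, keyed by model name and
// op name.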
void RecordProcessedBatchSize(int32_t batch_size, const string& model_name,
                              const string& op_name) {
  static auto* cell = tensorflow::monitoring::PercentileSampler<2>::New(
      {"/tensorflow/serving/batching/processed_batch_size",
       "Tracks the batch size distribution on processing by model_name (if "
       "available).",
       "model_name", "op_name"},
      /*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0},
      /*max_samples=*/1024, tensorflow::monitoring::UnitOfMeasure::kNumber);
  cell->GetCell(model_name, op_name)->Add(static_cast<double>(batch_size));
}

// Exports the exact processed batch size (as a counter keyed by batch size)
// instead of a distribution.
void RecordProcessedBatchSizeV2(int32_t batch_size, const string& model_name,
                                const string& op_name) {
  static auto* cell = monitoring::Counter<3>::New(
      "/tensorflow/serving/batching/processed_batch_size_v2",
      "Tracks the batch size on processing by model_name and op name (if "
      "available).",
      "model_name", "op_name", "batch_size");
  cell->GetCell(model_name, op_name, std::to_string(batch_size))
      ->IncrementBy(1);
}

// TODO(b/181883417): Replace with RecordBatchDelayUsV2.
void RecordBatchDelayUs(int64_t batch_delay_us, const string& model_name,
                        const string& op_name) {
  static auto* cell = monitoring::PercentileSampler<2>::New(
      {"/tensorflow/serving/batching/batch_delay_us",
       "Tracks the batching delay (in microseconds) for inputs by model_name "
       "(if available).",
       "model_name", "op_name"},
      /*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0},
      /*max_samples=*/1024, monitoring::UnitOfMeasure::kTime);
  cell->GetCell(model_name, op_name)->Add(static_cast<double>(batch_delay_us));
}

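// Same as RecordBatchDelayUs but exports the delay through a histogram
// sampler instead of a percentile sampler.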
void RecordBatchDelayUsV2(int64_t batch_delay_us, const string& model_name,
                          const string& op_name) {
  static auto* cell = tensorflow::monitoring::Sampler<2>::New(
      {"/tensorflow/serving/batching/batch_delay_us_v2",
       "Tracks the batching delay (in microseconds) for inputs by model_name "
       "(if available).",
       "model_name", "op_name"},
      // It's 27 buckets with the last bucket being 2^26 to DBL_MAX;
      // so the limits are [1, 2, 4, 8, ..., 64 * 1024 * 1024, DBL_MAX].
      monitoring::Buckets::Exponential(1, 2, 27));
  cell->GetCell(model_name, op_name)->Add(static_cast<double>(batch_delay_us));
}

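// The RecordBatchParam* functions below export the static batching
// parameters (timeout, batch size limits, allowed sizes) as gauges keyed by
// model name and op name.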
void RecordBatchParamBatchTimeoutMicros(int64_t batch_timeout_micros,
                                        const string& model_name,
                                        const string& op_name) {
  static auto* cell = monitoring::Gauge<int64, 2>::New(
      "/tensorflow/serving/batching/batch_timeout_micros",
      "Tracks how long a request can wait before being processed by a batch.",
      "model_name", "op_name");
  cell->GetCell(model_name, op_name)->Set(batch_timeout_micros);
}

void RecordBatchParamMaxBatchSize(int64_t max_batch_size,
                                  const string& model_name,
                                  const string& op_name) {
  static auto* cell = monitoring::Gauge<int64, 2>::New(
      "/tensorflow/serving/batching/max_batch_size",
      "Tracks the maximum size of a batch.", "model_name", "op_name");
  cell->GetCell(model_name, op_name)->Set(max_batch_size);
}

void RecordBatchParamMaxEnqueuedBatches(int64_t max_enqueued_batches,
                                        const string& model_name,
                                        const string& op_name) {
  static auto* cell = monitoring::Gauge<int64, 2>::New(
      "/tensorflow/serving/batching/max_enqueued_batches",
      "Tracks the maximum number of enqueued batches.", "model_name",
      "op_name");
  cell->GetCell(model_name, op_name)->Set(max_enqueued_batches);
}

void RecordBatchParamAllowedBatchSizes(const string& allowed_batch_sizes,
                                       const string& model_name,
                                       const string& op_name) {
  static auto* cell = monitoring::Gauge<string, 2>::New(
      "/tensorflow/serving/batching/allowed_batch_sizes",
      "Tracks the sizes that are allowed to form a batch.", "model_name",
      "op_name");
  cell->GetCell(model_name, op_name)->Set(allowed_batch_sizes);
}

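// Returns the model name from the session metadata, or a "model_name_unset"
// placeholder if the metadata or the name is unavailable.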
const string& GetModelName(OpKernelContext* ctx) {
  static string* kModelNameUnset = new string("model_name_unset");
  if (!ctx->session_metadata()) return *kModelNameUnset;
  if (ctx->session_metadata()->name().empty()) return *kModelNameUnset;
  return ctx->session_metadata()->name();
}

}  // namespace

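// Creates a partial task that shares this task's context, status and output
// buffer; its input tensors are filled in later by SplitInputTask.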
std::unique_ptr<BatchResourceBase::BatchTask>
BatchResourceBase::BatchTask::CreateSplitTask(
    int split_index, AsyncOpKernel::DoneCallback done_callback) {
  std::unique_ptr<BatchTask> task = CreateDerivedTask();

  task->guid = this->guid;
  task->propagated_context = Context(ContextKind::kThread);
  task->inputs.reserve(this->inputs.size());
  task->captured_inputs = this->captured_inputs;
  task->context = this->context;
  task->done_callback = done_callback;
  task->split_index = split_index;
  task->output = this->output;
  task->status = this->status;
  task->is_partial = true;
  task->start_time = this->start_time;
  task->request_cost = this->request_cost;

  return task;
}

using ::tensorflow::concat_split_util::Concat;
using ::tensorflow::concat_split_util::Split;
using TensorMatrix = std::vector<std::vector<Tensor>>;

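// Validates the op's inputs, records batching metrics, and enqueues a new
// BatchTask on the batcher queue named `batcher_queue_name`. An empty input
// is handled inline by emitting empty outputs and invoking the callback.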
Status BatchResourceBase::RegisterInput(
    int64_t guid, OpKernelContext* context, const string& batcher_queue_name,
    AsyncOpKernel::DoneCallback done_callback) {
  std::unique_ptr<BatchTask> batch_components;
  TF_RETURN_IF_ERROR(CreateBatchTask(context, &batch_components));
  batch_components->start_time = EnvTime::NowNanos();
  batch_components->guid = guid;
  batch_components->propagated_context = Context(ContextKind::kThread);
  OpInputList tensors;
  TF_RETURN_IF_ERROR(context->input_list("in_tensors", &tensors));
  batch_components->inputs.reserve(tensors.size());
  for (const Tensor& tensor : tensors) {
    if (tensor.shape().dims() == 0) {
      return errors::InvalidArgument(
          "Batching input tensors must have at least one dimension");
    }
    if (tensors.size() >= 2 &&
        tensor.shape().dim_size(0) != tensors[0].shape().dim_size(0)) {
      return errors::InvalidArgument(
          "Batching input tensors supplied in a given op invocation must "
          "have equal 0th-dimension size");
    }
    batch_components->inputs.push_back(tensor);
  }
  RecordInputBatchSize(tensors[0].shape().dim_size(0), GetModelName(context),
                       context->op_kernel().name_view().data());
  RecordInputBatchSizeV2(tensors[0].shape().dim_size(0), GetModelName(context),
                         context->op_kernel().name());
  RecordBatchParamBatchTimeoutMicros(
      batcher_queue_options_.batch_timeout_micros, GetModelName(context),
      context->op_kernel().name_view().data());
  RecordBatchParamMaxBatchSize(batcher_queue_options_.max_execution_batch_size,
                               GetModelName(context),
                               context->op_kernel().name_view().data());
  RecordBatchParamMaxEnqueuedBatches(
      batcher_queue_options_.max_enqueued_batches, GetModelName(context),
      context->op_kernel().name_view().data());
  RecordBatchParamAllowedBatchSizes(allowed_batch_sizes_str_,
                                    GetModelName(context),
                                    context->op_kernel().name_view().data());

  // Degenerate case where the input is empty. Just return an empty tensor.
  if (tensors[0].shape().dim_size(0) == 0) {
    for (int i = 0; i < context->num_outputs(); i++) {
      Tensor* empty_output;
      AllocatorAttributes cpu_alloc;
      cpu_alloc.set_on_host(true);
      TF_RETURN_IF_ERROR(context->allocate_output(i, TensorShape({0}),
                                                  &empty_output, cpu_alloc));
    }
    done_callback();
    return Status::OK();
  }
  OpInputList captured_tensors;
  const auto captured_status =
      context->input_list("captured_tensors", &captured_tensors);
  if (captured_status.ok()) {
    batch_components->captured_inputs.reserve(captured_tensors.size());
    for (const Tensor& captured_tensor : captured_tensors) {
      batch_components->captured_inputs.push_back(captured_tensor);
    }
  }
  batch_components->context = context;
  batch_components->done_callback = std::move(done_callback);
  batch_components->split_index = 0;
  batch_components->output = std::make_shared<TensorMatrix>();
  batch_components->status = std::make_shared<ThreadSafeStatus>();

  BatcherQueueT* batcher_queue;
  TF_RETURN_IF_ERROR(
      LookupOrCreateBatcherQueue(batcher_queue_name, &batcher_queue));
  return batcher_queue->Schedule(&batch_components);
}

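// Translates the kernel's batching attributes into queue options for the
// shared batch scheduler, wiring up SplitInputTask when large-batch
// splitting is enabled.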
/*static*/ BatchResourceBase::BatcherT::QueueOptions
BatchResourceBase::GetBatcherQueueOptions(
    int32_t num_batch_threads, int32_t max_batch_size,
    int32_t batch_timeout_micros, int32_t max_enqueued_batches,
    const std::vector<int32>& allowed_batch_sizes,
    bool enable_large_batch_splitting) {
  BatcherT::QueueOptions batcher_queue_options;
  batcher_queue_options.input_batch_size_limit = max_batch_size;
  batcher_queue_options.max_enqueued_batches = max_enqueued_batches;
  batcher_queue_options.batch_timeout_micros = batch_timeout_micros;
  batcher_queue_options.enable_large_batch_splitting =
      enable_large_batch_splitting;
  if (enable_large_batch_splitting) {
    batcher_queue_options.split_input_task_func =
        [](std::unique_ptr<BatchTask>* input_task,
           int open_batch_remaining_slot, int max_batch_size,
           std::vector<std::unique_ptr<BatchTask>>* output_tasks) -> Status {
      return SplitInputTask(input_task, open_batch_remaining_slot,
                            max_batch_size, output_tasks);
    };

    if (allowed_batch_sizes.empty()) {
      batcher_queue_options.max_execution_batch_size = max_batch_size;
    } else {
      batcher_queue_options.max_execution_batch_size =
          *allowed_batch_sizes.rbegin();
    }
  }

  return batcher_queue_options;
}

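// Same as GetBatcherQueueOptions, but for the adaptive shared batch
// scheduler.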
/*static*/ BatchResourceBase::AdaptiveBatcherT::QueueOptions
BatchResourceBase::GetAdaptiveBatcherQueueOptions(
    int32_t max_batch_size, int32_t batch_timeout_micros,
    int32_t max_enqueued_batches, bool enable_large_batch_splitting,
    const std::vector<int32>& allowed_batch_sizes) {
  AdaptiveBatcherT::QueueOptions batcher_queue_options;
  batcher_queue_options.max_input_task_size =
      absl::make_optional(max_batch_size);
  batcher_queue_options.max_enqueued_batches = max_enqueued_batches;
  batcher_queue_options.batch_timeout_micros = batch_timeout_micros;
  if (allowed_batch_sizes.empty()) {
    batcher_queue_options.max_batch_size = max_batch_size;
  } else {
    batcher_queue_options.max_batch_size = *allowed_batch_sizes.rbegin();
  }

  if (enable_large_batch_splitting) {
    batcher_queue_options.split_input_task_func =
        [](std::unique_ptr<BatchTask>* input_task,
           int open_batch_remaining_slot, int max_batch_size,
           std::vector<std::unique_ptr<BatchTask>>* output_tasks) -> Status {
      return SplitInputTask(input_task, open_batch_remaining_slot,
                            max_batch_size, output_tasks);
    };
  }

  return batcher_queue_options;
}

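// Checks that every task in the batch has the same number of input edges.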
/*static*/ Status BatchResourceBase::ValidateBatch(const BatchT& batch) {
  for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
    const BatchResourceBase::BatchTask& task = batch.task(task_idx);

    if (task.inputs.size() != batch.task(0).inputs.size()) {
      return errors::InvalidArgument(
          "Batching inputs must have equal number of edges");
    }
  }

  return Status::OK();
}

// Returns the smallest entry in 'allowed_batch_sizes_' that is greater than
// or equal to 'batch_size'. If 'allowed_batch_sizes_' is empty, simply
// returns 'batch_size'.
int BatchResourceBase::RoundToLowestAllowedBatchSize(int batch_size) const {
  if (allowed_batch_sizes_.empty()) {
    return batch_size;
  }
  for (int allowed_size : allowed_batch_sizes_) {
    if (allowed_size >= batch_size) {
      return allowed_size;
    }
  }
  LOG(ERROR) << "Batch size " << batch_size
             << " is greater than largest allowed size; "
                "ignoring allowed sizes constraint.";
  return batch_size;
}

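// Concatenates the tasks' input tensors (plus padding rows, if needed to
// reach an allowed batch size) into one tensor per input edge, and records
// padding and batch-size metrics.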
Status BatchResourceBase::ConcatInputTensors(
    const BatchT& batch, OpKernelContext* context,
    std::vector<Tensor>* concatenated_tensors) const {
  if (batch.num_tasks() == 0) {
    return errors::InvalidArgument("Empty batch.");
  }

  const int padded_batch_size = RoundToLowestAllowedBatchSize(batch.size());
  const int padding_amount = padded_batch_size - batch.size();
  profiler::TraceMe trace_me([padded_batch_size, padding_amount]() {
    return profiler::TraceMeEncode(
        "ConcatInputTensors", {{"batch_size_after_padding", padded_batch_size},
                               {"padding_amount", padding_amount}});
  });
  RecordPaddingSize(padding_amount, GetModelName(context), padded_batch_size,
                    context->op_kernel().name_view().data());
  RecordPaddingSizeV2(padding_amount, GetModelName(context), padded_batch_size,
                      context->op_kernel().name());
  RecordProcessedBatchSize(padded_batch_size, GetModelName(context),
                           context->op_kernel().name_view().data());
  RecordProcessedBatchSizeV2(padded_batch_size, GetModelName(context),
                             string(context->op_kernel().name_view()));

  // All tasks should have the same number of input edges.
  const int num_inputs = batch.task(0).inputs.size();
  concatenated_tensors->reserve(num_inputs);

  // Process each input one at a time (the typical case has just one).
  for (int i = 0; i < num_inputs; ++i) {
    // Concatenate the tasks' ith input tensors into a big output tensor.
    std::vector<Tensor> to_concatenate;
    to_concatenate.reserve(batch.num_tasks());
    for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
      to_concatenate.push_back(batch.task(task_idx).inputs.at(i));
    }

    // Add padding as needed. Use the first row of the first task's tensor as
    // the data for padding.
    if (padding_amount > 0) {
      const Tensor& padding_source = batch.task(0).inputs.at(i);
      Tensor padding;
      if (padding_source.shape().dim_size(0) == 0) {
        return errors::InvalidArgument(
            "Cannot use an empty tensor with zero rows as padding when "
            "batching. (Input ",
            i, " got shape ", padding_source.shape().DebugString(), ".)");
      }
      if (padding_source.shape().dim_size(0) == 1) {
        padding = padding_source;
      } else {
        padding = padding_source.Slice(0, 1);
      }
      for (int i = 0; i < padding_amount; ++i) {
        to_concatenate.push_back(padding);
      }
    }

    Tensor concatenated_tensor;
    Status concat_status =
        Concat(context, to_concatenate, &concatenated_tensor);
    TF_RETURN_IF_ERROR(concat_status);
    concatenated_tensors->push_back(concatenated_tensor);
  }
  return Status::OK();
}

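// Splits `input_task` into tasks of at most `max_batch_size` (the first
// split fills `open_batch_remaining_slot` if it is positive); the splits'
// outputs are re-concatenated once all of them have completed.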
/*static*/ Status BatchResourceBase::SplitInputTask(
    std::unique_ptr<BatchTask>* input_task_ptr, int open_batch_remaining_slot,
    int max_batch_size, std::vector<std::unique_ptr<BatchTask>>* output_tasks) {
  BatchTask& input_task = *(*input_task_ptr);
  const int64_t input_task_size = input_task.size();

  DCHECK_GT(input_task_size, open_batch_remaining_slot);

  std::shared_ptr<ThreadSafeStatus> shared_status = input_task.status;

  // `split_task_done_callback` runs only after all split tasks are complete.
  std::function<void()> split_task_done_callback =
      [done_callback = input_task.done_callback, output = input_task.output,
       op_kernel_context = input_task.context, status = shared_status]() {
        const int num_output = op_kernel_context->num_outputs();
        for (int i = 0; i < num_output; ++i) {
          Tensor output_tensor;

          // Concat would memcpy each input tensor to one output tensor.
          // In this context, Concat can be further optimized to get rid of
          // some (probably all) memcpy when input tensors are slices of
          // another copy.
          std::vector<Tensor> to_concatenate;
          to_concatenate.reserve(output->size());
          for (int j = 0; j < output->size(); ++j) {
            to_concatenate.push_back(std::move((*output)[j][i]));
          }
          const auto concat_status =
              Concat(op_kernel_context, to_concatenate, &output_tensor);
          if (!concat_status.ok()) {
            status->Update(concat_status);
          }

          op_kernel_context->set_output(i, std::move(output_tensor));
        }
        op_kernel_context->SetStatus(status->status());
        done_callback();
      };
  IncrementalBarrier barrier(split_task_done_callback);

  std::vector<int64> output_task_sizes;

  if (open_batch_remaining_slot > 0) {
    output_task_sizes.push_back(open_batch_remaining_slot);
  }

  for (int left_task_size = input_task_size - open_batch_remaining_slot;
       left_task_size > 0; left_task_size -= max_batch_size) {
    int next_task_size = std::min(left_task_size, max_batch_size);
    output_task_sizes.push_back(next_task_size);
  }

  const int output_task_num = output_task_sizes.size();
  input_task.output->resize(output_task_num);

  for (int i = 0; i < output_task_num; ++i) {
    (*input_task.output)[i].resize(input_task.context->num_outputs());
  }

  output_tasks->reserve(output_task_num);
  for (int i = 0; i < output_task_num; i++) {
    output_tasks->push_back(input_task.CreateSplitTask(i, barrier.Inc()));
  }

  const int num_input_tensors = input_task.inputs.size();

  // Splits each input tensor according to `output_task_sizes`, and
  // initializes input of `output_tasks` with split results.
  for (int i = 0; i < num_input_tensors; ++i) {
    std::vector<Tensor> split_tensors;
    const Tensor& input_tensor = input_task.inputs[i];
    // TODO(b/154140947):
    // Figure out the optimal implementation of Split, by using
    // 'Tensor::Slice' and eliminating unnecessary memcpy as much as possible.
    const Status split_status = Split(input_task.context, input_tensor,
                                      output_task_sizes, &split_tensors);
    if (!split_status.ok()) {
      return errors::Internal(
          "When splitting input, Tensor split operation failed: ",
          split_status.ToString());
    }
    if (split_tensors.size() != output_task_sizes.size()) {
      return errors::Internal(
          "When splitting input, tensor split operation did not work as "
          "expected; got ",
          split_tensors.size(), " splits; expected ", output_task_sizes.size());
    }
    for (int j = 0; j < output_tasks->size(); ++j) {
      BatchTask& output_task = *((*output_tasks)[j]);
      auto moved_tensor_iter = std::next(split_tensors.begin(), j);
      std::move(moved_tensor_iter, moved_tensor_iter + 1,
                std::back_inserter(output_task.inputs));
    }
  }
  return Status::OK();
}

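// Slices the batched output tensors back into per-task outputs, dropping any
// rows that were added as padding.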
Status BatchResourceBase::SplitOutputTensors(
    const std::vector<Tensor>& combined_outputs, BatchT* batch) const {
  DCHECK_GE(batch->num_tasks(), 1);
  if (batch->num_tasks() < 1) {
    return errors::Internal("Batch size expected to be positive; was ",
                            batch->num_tasks());
  }

  std::vector<int64> task_sizes_plus_optional_padding;
  task_sizes_plus_optional_padding.reserve(batch->num_tasks());
  for (int i = 0; i < batch->num_tasks(); ++i) {
    task_sizes_plus_optional_padding.push_back(batch->task(i).size());
  }
  const int padding_size =
      RoundToLowestAllowedBatchSize(batch->size()) - batch->size();
  if (padding_size > 0) {
    task_sizes_plus_optional_padding.push_back(padding_size);
  }

  // For each output tensor name, a divided-up tensor with one entry per task.
  std::map<string, std::vector<Tensor>> split_tensors;

  DCHECK_EQ(batch->task(0).context->num_outputs(), combined_outputs.size());
  int combined_outputs_size = combined_outputs.size();
  if (combined_outputs_size != batch->task(0).context->num_outputs()) {
    return errors::Internal("Wrong number of batched output tensors");
  }

  // Generate 'split_tensors' and populate the context outputs.
  for (int i = 0, iter_limit = combined_outputs.size(); i < iter_limit; ++i) {
    const Tensor& output_tensor = combined_outputs[i];
    if (output_tensor.shape().dims() == 0) {
      return errors::FailedPrecondition(
          "Batched output tensor has 0 dimensions");
    }
    if (output_tensor.shape().dim_size(0) !=
        static_cast<int64>(batch->size() + padding_size)) {
      return errors::FailedPrecondition(
          "Batched output tensor's 0th dimension does not equal the sum of "
          "the 0th dimension sizes of the input tensors");
    }

    std::vector<Tensor> split_tensor;
    const Status split_status = tensor::Split(
        output_tensor, task_sizes_plus_optional_padding, &split_tensor);
    DCHECK(split_status.ok()) << split_status.ToString();
    if (!split_status.ok()) {
      return errors::Internal("Tensor split operation failed: ",
                              split_status.ToString());
    }
    DCHECK_EQ(split_tensor.size(), task_sizes_plus_optional_padding.size());
    if (split_tensor.size() != task_sizes_plus_optional_padding.size()) {
      return errors::Internal(
          "Tensor split operation did not work as expected; got ",
          split_tensor.size(), " splits; expected ",
          task_sizes_plus_optional_padding.size());
    }

    // Ignore a possible final split_tensors entry containing the padding.
    for (int j = 0; j < batch->num_tasks(); ++j) {
      BatchTask& task = *(batch->mutable_task(j));
      if (task.is_partial) {
        std::vector<Tensor>& tensor_vector = (*task.output)[task.split_index];
        tensor_vector[i] = std::move(split_tensor[j]);
      } else {
        task.context->set_output(i, split_tensor[j]);
      }
    }
  }

  return Status::OK();
}

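// Processes a batch for the function-based batch kernels: concatenates the
// inputs, runs ProcessFuncBatchImpl, and splits the outputs back to the
// individual tasks.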
void BatchResourceBase::ProcessFuncBatch(std::unique_ptr<BatchT> batch) const {
  if (batch->empty()) {
    return;
  }

  const char* cost_measurement_type = GetCostMeasurementType();
  auto batch_cost_measurement =
      cost_measurement_type
          ? CostMeasurementRegistry::CreateByNameOrNull(cost_measurement_type)
          : nullptr;
  int64_t processed_size = batch->size();
  auto batch_cost_split_cleanup = gtl::MakeCleanup([&] {
    SplitBatchCost(batch_cost_measurement.get(), processed_size, *batch);
  });

  // We use the 'propagated_context' from one of the threads that set up one
  // of the tasks. This will propagate any common context over all the threads
  // which are running this Session, of which this BatchOp is a part.
  WithContext wc(batch->task(batch->num_tasks() - 1).propagated_context);

  auto& last_task = batch->task(batch->num_tasks() - 1);
  OpKernelContext* last_task_context = last_task.context;

  // Regardless of the outcome, we need to propagate the status to the
  // individual tasks and signal that they are done. We use MakeCleanup() to
  // ensure that this happens no matter how we exit the method below.
  Status status;
  bool cleanup_done = false;
  auto cleanup_fn = [&cleanup_done, &batch](const Status& status) {
    if (cleanup_done) {
      return;
    }
    for (int i = 0; i < batch->num_tasks(); ++i) {
      if (batch->task(i).is_partial) {
        batch->mutable_task(i)->status->Update(status);
      } else {
        batch->mutable_task(i)->context->SetStatus(status);
      }

      batch->mutable_task(i)->done_callback();
    }
    cleanup_done = true;
  };

  auto finally =
      gtl::MakeCleanup([&cleanup_fn, &status] { cleanup_fn(status); });

  status = ValidateBatch(*batch);
  if (!status.ok()) {
    return;
  }

  std::vector<Tensor> concatenated_tensors;
  status = ConcatInputTensors(*batch, last_task_context, &concatenated_tensors);
  processed_size = RoundToLowestAllowedBatchSize(batch->size());
  if (!status.ok()) {
    return;
  }

  std::vector<Tensor> combined_outputs;
  std::vector<Tensor> args(concatenated_tensors.begin(),
                           concatenated_tensors.end());
  const auto& captured_inputs =
      batch->task(batch->num_tasks() - 1).captured_inputs;
  args.insert(args.end(), captured_inputs.begin(), captured_inputs.end());

  uint64 current_time = EnvTime::NowNanos();
  const string& model_name = GetModelName(last_task_context);
  for (int i = 0; i < batch->num_tasks(); ++i) {
    RecordBatchDelayUs((current_time - batch->task(i).start_time) * 1e-3,
                       model_name,
                       last_task_context->op_kernel().name_view().data());
    RecordBatchDelayUsV2((current_time - batch->task(i).start_time) * 1e-3,
                         model_name, last_task_context->op_kernel().name());
  }
  // Releases the cleanup method here, because the callback of the function
  // library runtime will handle it now.
  finally.release();
  ProcessFuncBatchImpl(
      last_task, args, &combined_outputs, [&](const Status& run_status) {
        Status final_status;
        auto run_finally = gtl::MakeCleanup([&]() {
          // We do the cleanup here as an optimization, so that
          // it runs in the underlying TF inter-op threadpool.
          // Running it in the threadpool lets the ensuing
          // ops be scheduled faster, because the executor will
          // add them to the front of the threadpool's task
          // queue rather than the end.
          cleanup_fn(final_status);
        });
        final_status = run_status;
        if (!final_status.ok()) {
          return;
        }
        final_status = SplitOutputTensors(combined_outputs, batch.get());
      });
}

// Processes a batch of one or more BatchTask entries.
void BatchResourceBase::ProcessBatch(std::unique_ptr<BatchT> batch) const {
  if (batch->empty()) {
    return;
  }

  const char* cost_measurement_type = GetCostMeasurementType();
  auto batch_cost_measurement =
      cost_measurement_type
          ? CostMeasurementRegistry::CreateByNameOrNull(cost_measurement_type)
          : nullptr;
  int64_t processed_size = batch->size();
  auto batch_cost_split_cleaner = gtl::MakeCleanup([&] {
    SplitBatchCost(batch_cost_measurement.get(), processed_size, *batch);
  });

  WithContext wc(batch->task(batch->num_tasks() - 1).propagated_context);

  OpKernelContext* last_task_context =
      batch->task(batch->num_tasks() - 1).context;
  AsyncOpKernel::DoneCallback last_task_callback =
      batch->task(batch->num_tasks() - 1).done_callback;

  OP_REQUIRES_OK_ASYNC(last_task_context, ValidateBatch(*batch),
                       last_task_callback);

  // All tasks should have the same number of input edges.
  const int num_input_edges = batch->task(0).inputs.size();
  std::vector<Tensor> concatenated_tensors;
  const Status concat_status =
      ConcatInputTensors(*batch, last_task_context, &concatenated_tensors);
  processed_size = RoundToLowestAllowedBatchSize(batch->size());
  OP_REQUIRES_OK_ASYNC(last_task_context, concat_status, last_task_callback);

  // Process each input edge one at a time (the typical case has just one).
  for (int i = 0; i < num_input_edges; ++i) {
    last_task_context->set_output(i, concatenated_tensors[i]);

    // Emit batch->num_tasks() - 1 empty output tensors.
    for (int task_idx = 0; task_idx < batch->num_tasks() - 1; ++task_idx) {
      const BatchTask& task = batch->task(task_idx);
      TensorShape output_shape(task.inputs[i].shape());
      output_shape.set_dim(0, 0);
      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          task.context, task.context->allocate_output(i, output_shape, &output),
          task.done_callback);
    }
  }
  // Emit batch->num_tasks() - 1 empty index tensors.
  for (int task_idx = 0; task_idx < batch->num_tasks() - 1; ++task_idx) {
    const BatchTask& task = batch->task(task_idx);
    TensorShape index_shape({0, 3});
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(
        task.context,
        task.context->allocate_output(num_input_edges, index_shape, &output),
        task.done_callback);
  }
  // Emit all ID tensors.
  for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) {
    const BatchTask& task = batch->task(task_idx);
    Tensor* id;
    OP_REQUIRES_OK_ASYNC(task.context,
                         task.context->allocate_output(num_input_edges + 1,
                                                       TensorShape({}), &id),
                         task.done_callback);
    id->scalar<int64>()() = task.guid;
  }
  OP_REQUIRES_OK_ASYNC(
      last_task_context,
      EmitIndexTensor(last_task_context, *batch, num_input_edges),
      last_task_callback);

  // Signal done for each element of the batch. (At this point, the contexts
  // are no longer guaranteed to remain live.)
  for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) {
    batch->mutable_task(task_idx)->done_callback();
  }
}

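// Emits an index tensor of shape [num_tasks, 3] holding, for each task, its
// guid and the [start, end) row range it occupies in the concatenated batch.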
/*static*/ Status BatchResourceBase::EmitIndexTensor(OpKernelContext* context,
                                                     const BatchT& batch,
                                                     int output_index) {
  const TensorShape index_shape({batch.num_tasks(), 3});
  Tensor* index = nullptr;
  TF_RETURN_IF_ERROR(
      context->allocate_output(output_index, index_shape, &index));
  auto index_flat = index->shaped<int64, 2>({batch.num_tasks(), 3});
  size_t offset = 0;
  for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
    const BatchTask& task = batch.task(task_idx);
    index_flat(task_idx, 0) = task.guid;
    index_flat(task_idx, 1) = offset;
    index_flat(task_idx, 2) = offset + task.size();
    offset += task.size();
  }
  return Status::OK();
}

// Looks up the batcher queue for 'queue_name'. If it didn't previously exist,
// creates it.
Status BatchResourceBase::LookupOrCreateBatcherQueue(const string& queue_name,
                                                     BatcherQueueT** queue) {
  mutex_lock l(batcher_queues_mu_);

  auto it = batcher_queues_.find(queue_name);
  if (it != batcher_queues_.end()) {
    *queue = it->second.get();
    return Status::OK();
  }

  std::unique_ptr<BatcherQueueT> new_queue;
  auto process_batch_callback = [this](std::unique_ptr<BatchT> batch) {
    if (!has_process_batch_function_) {
      ProcessBatch(std::move(batch));
    } else {
      ProcessFuncBatch(std::move(batch));
    }
  };
  if (batcher_) {
    TF_RETURN_IF_ERROR(batcher_->AddQueue(batcher_queue_options_,
                                          process_batch_callback, &new_queue));
  } else if (adaptive_batcher_) {
    TF_RETURN_IF_ERROR(adaptive_batcher_->AddQueue(
        adaptive_batcher_queue_options_, process_batch_callback, &new_queue));
  } else {
    return errors::Internal("No batcher defined.");
  }
  *queue = new_queue.get();
  batcher_queues_[queue_name] = std::move(new_queue);
  return Status::OK();
}

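// Default factory for the BatchTask objects created in RegisterInput.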
Status BatchResourceBase::CreateBatchTask(
    OpKernelContext* context,
    std::unique_ptr<BatchResourceBase::BatchTask>* output) const {
  *output = absl::make_unique<BatchResourceBase::BatchTask>();
  return Status::OK();
}

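// Splits the batch cost measured by `batch_cost_measurement` across the
// tasks in `batch`; currently a no-op beyond the early-return checks (see
// the TODO below).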
void BatchResourceBase::SplitBatchCost(CostMeasurement* batch_cost_measurement,
                                       const int64_t processed_size,
                                       BatchT& batch) const {
  if (batch_cost_measurement == nullptr ||
      batch_cost_measurement->GetTotalCost() == absl::ZeroDuration()) {
    return;
  }
  // TODO(b/1858529900): Split the cost to each task: define RequestCost for
  // each inference request and add it as a field of BatchTask, implement the
  // cost split algorithms where the paddings share / do not share the cost.
}

}  // namespace serving
}  // namespace tensorflow