/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/batching_util/batch_resource_base.h"

#include "absl/types/optional.h"
#include "tensorflow/core/framework/ops_util.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/kernels/batching_util/concat_split_util.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/percentile_sampler.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"
#include "tensorflow/core/util/incremental_barrier.h"

namespace tensorflow {
namespace serving {
namespace {

void RecordPaddingSize(int32 padding_size, const string& model_name,
                       int32 execution_batch_size, const string& op_name) {
  static auto* cell = tensorflow::monitoring::PercentileSampler<3>::New(
      {"/tensorflow/serving/batching/padding_size",
       "Tracks the padding size distribution on batches by model_name (if "
       "available).",
       "model_name", "execution_batch_size", "op_name"},
      /*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0},
      /*max_samples=*/1024, tensorflow::monitoring::UnitOfMeasure::kNumber);
  cell->GetCell(model_name, absl::StrCat(execution_batch_size), op_name)
      ->Add(static_cast<double>(padding_size));
}

void RecordInputBatchSize(int32 batch_size, const string& model_name,
                          const string& op_name) {
  static auto* cell = tensorflow::monitoring::PercentileSampler<2>::New(
      {"/tensorflow/serving/batching/input_batch_size",
       "Tracks the batch size distribution on the inputs by model_name (if "
       "available).",
       "model_name", "op_name"},
      /*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0},
      /*max_samples=*/1024, tensorflow::monitoring::UnitOfMeasure::kNumber);
  cell->GetCell(model_name, op_name)->Add(static_cast<double>(batch_size));
}

void RecordProcessedBatchSize(int32 batch_size, const string& model_name,
                              const string& op_name) {
  static auto* cell = tensorflow::monitoring::PercentileSampler<2>::New(
      {"/tensorflow/serving/batching/processed_batch_size",
       "Tracks the batch size distribution on processing by model_name (if "
       "available).",
       "model_name", "op_name"},
      /*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0},
      /*max_samples=*/1024, tensorflow::monitoring::UnitOfMeasure::kNumber);
  cell->GetCell(model_name, op_name)->Add(static_cast<double>(batch_size));
}

void RecordBatchDelayUs(int64 batch_delay_us, const string& model_name,
                        const string& op_name) {
  static auto* cell = monitoring::PercentileSampler<2>::New(
      {"/tensorflow/serving/batching/batch_delay_us",
       "Tracks the batching delay (in microseconds) for inputs by model_name "
       "(if available).",
       "model_name", "op_name"},
      /*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0},
      /*max_samples=*/1024, monitoring::UnitOfMeasure::kTime);
  cell->GetCell(model_name, op_name)->Add(static_cast<double>(batch_delay_us));
}

void RecordBatchParamBatchTimeoutMicros(int64 batch_timeout_micros,
                                        const string& model_name,
                                        const string& op_name) {
  static auto* cell = monitoring::Gauge<int64, 2>::New(
      "/tensorflow/serving/batching/batch_timeout_micros",
      "Tracks how long a request can wait before being processed by a batch.",
      "model_name", "op_name");
  cell->GetCell(model_name, op_name)->Set(batch_timeout_micros);
}

void RecordBatchParamMaxBatchSize(int64 max_batch_size,
                                  const string& model_name,
                                  const string& op_name) {
  static auto* cell = monitoring::Gauge<int64, 2>::New(
      "/tensorflow/serving/batching/max_batch_size",
      "Tracks the maximum size of a batch.", "model_name", "op_name");
  cell->GetCell(model_name, op_name)->Set(max_batch_size);
}

void RecordBatchParamMaxEnqueuedBatches(int64 max_enqueued_batches,
                                        const string& model_name,
                                        const string& op_name) {
  static auto* cell = monitoring::Gauge<int64, 2>::New(
      "/tensorflow/serving/batching/max_enqueued_batches",
      "Tracks the maximum number of enqueued batches.", "model_name",
      "op_name");
  cell->GetCell(model_name, op_name)->Set(max_enqueued_batches);
}

void RecordBatchParamAllowedBatchSizes(const string& allowed_batch_sizes,
                                       const string& model_name,
                                       const string& op_name) {
  static auto* cell = monitoring::Gauge<string, 2>::New(
      "/tensorflow/serving/batching/allowed_batch_sizes",
      "Tracks the sizes that are allowed to form a batch.", "model_name",
      "op_name");
  cell->GetCell(model_name, op_name)->Set(allowed_batch_sizes);
}

const string& GetModelName(OpKernelContext* ctx) {
  static string* kModelNameUnset = new string("model_name_unset");
  if (!ctx->session_metadata()) return *kModelNameUnset;
  if (ctx->session_metadata()->name().empty()) return *kModelNameUnset;
  return ctx->session_metadata()->name();
}

}  // namespace

std::unique_ptr<BatchResourceBase::BatchTask>
BatchResourceBase::BatchTask::CreateSplitTask(
    int split_index, AsyncOpKernel::DoneCallback done_callback) {
  std::unique_ptr<BatchTask> task = CreateDerivedTask();

  task->guid = this->guid;
  task->propagated_context = Context(ContextKind::kThread);
  task->inputs.reserve(this->inputs.size());
  task->captured_inputs = this->captured_inputs;
  task->context = this->context;
  task->done_callback = done_callback;
  task->split_index = split_index;
  // All splits share the parent's output matrix and status so the final
  // outputs can be assembled once every split has completed.
  task->output = this->output;
  task->status = this->status;
  task->is_partial = true;
  task->start_time = this->start_time;

  return task;
}

using ::tensorflow::concat_split_util::Concat;
using ::tensorflow::concat_split_util::Split;
using TensorMatrix = std::vector<std::vector<Tensor>>;

Status BatchResourceBase::RegisterInput(
    int64 guid, OpKernelContext* context, const string& batcher_queue_name,
    AsyncOpKernel::DoneCallback done_callback) {
  std::unique_ptr<BatchTask> batch_components;
  TF_RETURN_IF_ERROR(CreateBatchTask(context, &batch_components));
  batch_components->start_time = EnvTime::NowNanos();
  batch_components->guid = guid;
  batch_components->propagated_context = Context(ContextKind::kThread);
  OpInputList tensors;
  TF_RETURN_IF_ERROR(context->input_list("in_tensors", &tensors));
  batch_components->inputs.reserve(tensors.size());
  for (const Tensor& tensor : tensors) {
    if (tensor.shape().dims() == 0) {
      return errors::InvalidArgument(
          "Batching input tensors must have at least one dimension");
    }
    if (tensors.size() >= 2 &&
        tensor.shape().dim_size(0) != tensors[0].shape().dim_size(0)) {
      return errors::InvalidArgument(
          "Batching input tensors supplied in a given op invocation must "
          "have equal 0th-dimension size");
    }
    batch_components->inputs.push_back(tensor);
  }
  RecordInputBatchSize(tensors[0].shape().dim_size(0), GetModelName(context),
                       context->op_kernel().name_view().data());
  RecordBatchParamBatchTimeoutMicros(
      batcher_queue_options_.batch_timeout_micros, GetModelName(context),
      context->op_kernel().name_view().data());
  RecordBatchParamMaxBatchSize(batcher_queue_options_.max_execution_batch_size,
                               GetModelName(context),
                               context->op_kernel().name_view().data());
  RecordBatchParamMaxEnqueuedBatches(
      batcher_queue_options_.max_enqueued_batches, GetModelName(context),
      context->op_kernel().name_view().data());
  RecordBatchParamAllowedBatchSizes(allowed_batch_sizes_str_,
                                    GetModelName(context),
                                    context->op_kernel().name_view().data());

  // Degenerate case where the input is empty. Just return an empty tensor.
  if (tensors[0].shape().dim_size(0) == 0) {
    for (int i = 0; i < context->num_outputs(); i++) {
      Tensor* empty_output;
      AllocatorAttributes cpu_alloc;
      cpu_alloc.set_on_host(true);
      TF_RETURN_IF_ERROR(context->allocate_output(i, TensorShape({0}),
                                                  &empty_output, cpu_alloc));
    }
    done_callback();
    return Status::OK();
  }
  OpInputList captured_tensors;
  const auto captured_status =
      context->input_list("captured_tensors", &captured_tensors);
  if (captured_status.ok()) {
    batch_components->captured_inputs.reserve(captured_tensors.size());
    for (const Tensor& captured_tensor : captured_tensors) {
      batch_components->captured_inputs.push_back(captured_tensor);
    }
  }
  batch_components->context = context;
  batch_components->done_callback = std::move(done_callback);
  batch_components->split_index = 0;
  batch_components->output = std::make_shared<TensorMatrix>();
  batch_components->status = std::make_shared<ThreadSafeStatus>();

  BatcherQueueT* batcher_queue;
  TF_RETURN_IF_ERROR(
      LookupOrCreateBatcherQueue(batcher_queue_name, &batcher_queue));
  return batcher_queue->Schedule(&batch_components);
}

/*static*/ BatchResourceBase::BatcherT::QueueOptions
BatchResourceBase::GetBatcherQueueOptions(
    int32 num_batch_threads, int32 max_batch_size, int32 batch_timeout_micros,
    int32 max_enqueued_batches, const std::vector<int32>& allowed_batch_sizes,
    bool enable_large_batch_splitting) {
  BatcherT::QueueOptions batcher_queue_options;
  batcher_queue_options.input_batch_size_limit = max_batch_size;
  batcher_queue_options.max_enqueued_batches = max_enqueued_batches;
  batcher_queue_options.batch_timeout_micros = batch_timeout_micros;
  batcher_queue_options.enable_large_batch_splitting =
      enable_large_batch_splitting;
  if (enable_large_batch_splitting) {
    batcher_queue_options.split_input_task_func =
        [](std::unique_ptr<BatchTask>* input_task,
           int open_batch_remaining_slot, int max_batch_size,
           std::vector<std::unique_ptr<BatchTask>>* output_tasks) -> Status {
      return SplitInputTask(input_task, open_batch_remaining_slot,
                            max_batch_size, output_tasks);
    };

    if (allowed_batch_sizes.empty()) {
      batcher_queue_options.max_execution_batch_size = max_batch_size;
    } else {
      batcher_queue_options.max_execution_batch_size =
          *allowed_batch_sizes.rbegin();
    }
  }

  return batcher_queue_options;
}
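
// Example (illustrative; values are hypothetical, not taken from any caller):
// with max_batch_size = 10, allowed_batch_sizes = {2, 4, 8}, and
// enable_large_batch_splitting = true, the returned options accept input
// tasks of up to 10 examples (input_batch_size_limit) but execute batches of
// at most 8 (max_execution_batch_size = *allowed_batch_sizes.rbegin()),
// relying on SplitInputTask to split larger inputs. When splitting is
// disabled, max_execution_batch_size is left untouched here.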

/*static*/ BatchResourceBase::AdaptiveBatcherT::QueueOptions
BatchResourceBase::GetAdaptiveBatcherQueueOptions(
    int32 max_batch_size, int32 batch_timeout_micros,
    int32 max_enqueued_batches, bool enable_large_batch_splitting,
    const std::vector<int32>& allowed_batch_sizes) {
  AdaptiveBatcherT::QueueOptions batcher_queue_options;
  batcher_queue_options.max_input_task_size =
      absl::make_optional(max_batch_size);
  batcher_queue_options.max_enqueued_batches = max_enqueued_batches;
  batcher_queue_options.batch_timeout_micros = batch_timeout_micros;
  if (allowed_batch_sizes.empty()) {
    batcher_queue_options.max_batch_size = max_batch_size;
  } else {
    batcher_queue_options.max_batch_size = *allowed_batch_sizes.rbegin();
  }

  if (enable_large_batch_splitting) {
    batcher_queue_options.split_input_task_func =
        [](std::unique_ptr<BatchTask>* input_task,
           int open_batch_remaining_slot, int max_batch_size,
           std::vector<std::unique_ptr<BatchTask>>* output_tasks) -> Status {
      return SplitInputTask(input_task, open_batch_remaining_slot,
                            max_batch_size, output_tasks);
    };
  }

  return batcher_queue_options;
}

/*static*/ Status BatchResourceBase::ValidateBatch(const BatchT& batch) {
  for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
    const BatchResourceBase::BatchTask& task = batch.task(task_idx);

    if (task.inputs.size() != batch.task(0).inputs.size()) {
      return errors::InvalidArgument(
          "Batching inputs must have equal number of edges");
    }
  }

  return Status::OK();
}

// Returns the smallest entry in 'allowed_batch_sizes_' that is greater than
// or equal to 'batch_size'. If 'allowed_batch_sizes_' is empty, simply
// returns 'batch_size'.
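// For example (hypothetical values): with allowed_batch_sizes_ = {2, 4, 8},
// a batch_size of 3 rounds up to 4, 8 stays at 8, and 9 exceeds the largest
// allowed size, so it is returned unchanged after logging an error.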
int BatchResourceBase::RoundToLowestAllowedBatchSize(int batch_size) const {
  if (allowed_batch_sizes_.empty()) {
    return batch_size;
  }
  for (int allowed_size : allowed_batch_sizes_) {
    if (allowed_size >= batch_size) {
      return allowed_size;
    }
  }
  LOG(ERROR) << "Batch size " << batch_size
             << " is greater than largest allowed size; "
                "ignoring allowed sizes constraint.";
  return batch_size;
}

Status BatchResourceBase::ConcatInputTensors(
    const BatchT& batch, OpKernelContext* context,
    std::vector<Tensor>* concatenated_tensors) const {
  if (batch.num_tasks() == 0) {
    return errors::InvalidArgument("Empty batch.");
  }

  const int padded_batch_size = RoundToLowestAllowedBatchSize(batch.size());
  const int padding_amount = padded_batch_size - batch.size();
  profiler::TraceMe trace_me([padded_batch_size, padding_amount]() {
    return profiler::TraceMeEncode(
        "ConcatInputTensors", {{"batch_size_after_padding", padded_batch_size},
                               {"padding_amount", padding_amount}});
  });
  RecordPaddingSize(padding_amount, GetModelName(context), padded_batch_size,
                    context->op_kernel().name_view().data());
  RecordProcessedBatchSize(padded_batch_size, GetModelName(context),
                           context->op_kernel().name_view().data());

  // All tasks should have the same number of input edges.
  const int num_inputs = batch.task(0).inputs.size();
  concatenated_tensors->reserve(num_inputs);

  // Process each input one at a time (the typical case has just one).
  for (int i = 0; i < num_inputs; ++i) {
    // Concatenate the tasks' ith input tensors into a big output tensor.
    std::vector<Tensor> to_concatenate;
    to_concatenate.reserve(batch.num_tasks());
    for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
      to_concatenate.push_back(batch.task(task_idx).inputs.at(i));
    }

    // Add padding as needed. Use the first row of the first task's tensor as
    // the data for padding.
    if (padding_amount > 0) {
      const Tensor& padding_source = batch.task(0).inputs.at(i);
      Tensor padding;
      if (padding_source.shape().dim_size(0) == 0) {
        return errors::InvalidArgument(
            "Cannot use an empty tensor with zero rows as padding when "
            "batching. (Input ",
            i, " got shape ", padding_source.shape().DebugString(), ".)");
      }
      if (padding_source.shape().dim_size(0) == 1) {
        padding = padding_source;
      } else {
        padding = padding_source.Slice(0, 1);
      }
      for (int i = 0; i < padding_amount; ++i) {
        to_concatenate.push_back(padding);
      }
    }

    Tensor concatenated_tensor;
    Status concat_status =
        Concat(context, to_concatenate, &concatenated_tensor);
    TF_RETURN_IF_ERROR(concat_status);
    concatenated_tensors->push_back(concatenated_tensor);
  }
  return Status::OK();
}
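
// Illustrative example (hypothetical sizes): a batch holding tasks of sizes
// {3, 2} with allowed_batch_sizes_ = {2, 4, 8} is padded from 5 rows to 8;
// the three padding rows are copies of row 0 of the first task's tensor and
// are discarded again in SplitOutputTensors.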

/*static*/ Status BatchResourceBase::SplitInputTask(
    std::unique_ptr<BatchTask>* input_task_ptr, int open_batch_remaining_slot,
    int max_batch_size, std::vector<std::unique_ptr<BatchTask>>* output_tasks) {
  BatchTask& input_task = *(*input_task_ptr);
  const int64 input_task_size = input_task.size();

  DCHECK_GT(input_task_size, open_batch_remaining_slot);

  std::shared_ptr<ThreadSafeStatus> shared_status = input_task.status;

  // `split_task_done_callback` runs only after all split tasks are complete.
  std::function<void()> split_task_done_callback =
      [done_callback = input_task.done_callback, output = input_task.output,
       op_kernel_context = input_task.context, status = shared_status]() {
        const int num_output = op_kernel_context->num_outputs();
        for (int i = 0; i < num_output; ++i) {
          Tensor output_tensor;

          // Concat would memcpy each input tensor to one output tensor.
          // In this context, Concat can be further optimized to get rid of
          // some (probably all) memcpy when input tensors are slices of
          // another copy.
          std::vector<Tensor> to_concatenate;
          to_concatenate.reserve(output->size());
          for (int j = 0; j < output->size(); ++j) {
            to_concatenate.push_back(std::move((*output)[j][i]));
          }
          const auto concat_status =
              Concat(op_kernel_context, to_concatenate, &output_tensor);
          if (!concat_status.ok()) {
            status->Update(concat_status);
          }

          op_kernel_context->set_output(i, std::move(output_tensor));
        }
        op_kernel_context->SetStatus(status->status());
        done_callback();
      };
  IncrementalBarrier barrier(split_task_done_callback);

  std::vector<int64> output_task_sizes;

  if (open_batch_remaining_slot > 0) {
    output_task_sizes.push_back(open_batch_remaining_slot);
  }

  for (int left_task_size = input_task_size - open_batch_remaining_slot;
       left_task_size > 0; left_task_size -= max_batch_size) {
    int next_task_size = std::min(left_task_size, max_batch_size);
    output_task_sizes.push_back(next_task_size);
  }

  const int output_task_num = output_task_sizes.size();
  input_task.output->resize(output_task_num);

  for (int i = 0; i < output_task_num; ++i) {
    (*input_task.output)[i].resize(input_task.context->num_outputs());
  }

  output_tasks->reserve(output_task_num);
  for (int i = 0; i < output_task_num; i++) {
    output_tasks->push_back(input_task.CreateSplitTask(i, barrier.Inc()));
  }

  const int num_input_tensors = input_task.inputs.size();

  // Splits each input tensor according to `output_task_sizes`, and
  // initializes input of `output_tasks` with split results.
  for (int i = 0; i < num_input_tensors; ++i) {
    std::vector<Tensor> split_tensors;
    const Tensor& input_tensor = input_task.inputs[i];
    // TODO(b/154140947):
    // Figure out the optimal implementation of Split, by using
    // 'Tensor::Slice' and eliminating unnecessary memcpy as much as possible.
    const Status split_status = Split(input_task.context, input_tensor,
                                      output_task_sizes, &split_tensors);
    if (!split_status.ok()) {
      return errors::Internal(
          "When splitting input, Tensor split operation failed: ",
          split_status.ToString());
    }
    if (split_tensors.size() != output_task_sizes.size()) {
      return errors::Internal(
          "When splitting input, tensor split operation did not work as "
          "expected; got ",
          split_tensors.size(), " splits; expected ", output_task_sizes.size());
    }
    for (int j = 0; j < output_tasks->size(); ++j) {
      BatchTask& output_task = *((*output_tasks)[j]);
      auto moved_tensor_iter = std::next(split_tensors.begin(), j);
      std::move(moved_tensor_iter, moved_tensor_iter + 1,
                std::back_inserter(output_task.inputs));
    }
  }
  return Status::OK();
}
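
// Worked example for SplitInputTask (hypothetical values): an input task of
// size 10 arriving when the open batch has 3 free slots and max_batch_size
// is 4 yields output_task_sizes = {3, 4, 3}. Each split shares the parent's
// output matrix and status, and the IncrementalBarrier fires
// `split_task_done_callback` only after every split's done callback has run,
// at which point the per-split outputs are concatenated back in order.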

Status BatchResourceBase::SplitOutputTensors(
    const std::vector<Tensor>& combined_outputs, BatchT* batch) const {
  DCHECK_GE(batch->num_tasks(), 1);
  if (batch->num_tasks() < 1) {
    return errors::Internal("Batch size expected to be positive; was ",
                            batch->num_tasks());
  }

  std::vector<int64> task_sizes_plus_optional_padding;
  task_sizes_plus_optional_padding.reserve(batch->num_tasks());
  for (int i = 0; i < batch->num_tasks(); ++i) {
    task_sizes_plus_optional_padding.push_back(batch->task(i).size());
  }
  const int padding_size =
      RoundToLowestAllowedBatchSize(batch->size()) - batch->size();
  if (padding_size > 0) {
    task_sizes_plus_optional_padding.push_back(padding_size);
  }

  // For each output tensor name, a divided-up tensor with one entry per task.
  std::map<string, std::vector<Tensor>> split_tensors;

  DCHECK_EQ(batch->task(0).context->num_outputs(), combined_outputs.size());
  int combined_outputs_size = combined_outputs.size();
  if (combined_outputs_size != batch->task(0).context->num_outputs()) {
    return errors::Internal("Wrong number of batched output tensors");
  }

  // Generate 'split_tensors' and populate the context outputs.
  for (int i = 0, iter_limit = combined_outputs.size(); i < iter_limit; ++i) {
    const Tensor& output_tensor = combined_outputs[i];
    if (output_tensor.shape().dims() == 0) {
      return errors::FailedPrecondition(
          "Batched output tensor has 0 dimensions");
    }
    if (output_tensor.shape().dim_size(0) !=
        static_cast<int64>(batch->size() + padding_size)) {
      return errors::FailedPrecondition(
          "Batched output tensor's 0th dimension does not equal the sum of "
          "the 0th dimension sizes of the input tensors");
    }

    std::vector<Tensor> split_tensor;
    const Status split_status = tensor::Split(
        output_tensor, task_sizes_plus_optional_padding, &split_tensor);
    DCHECK(split_status.ok()) << split_status.ToString();
    if (!split_status.ok()) {
      return errors::Internal("Tensor split operation failed: ",
                              split_status.ToString());
    }
    DCHECK_EQ(split_tensor.size(), task_sizes_plus_optional_padding.size());
    if (split_tensor.size() != task_sizes_plus_optional_padding.size()) {
      return errors::Internal(
          "Tensor split operation did not work as expected; got ",
          split_tensor.size(), " splits; expected ",
          task_sizes_plus_optional_padding.size());
    }

    // Ignore a possible final split_tensors entry containing the padding.
    for (int j = 0; j < batch->num_tasks(); ++j) {
      BatchTask& task = *(batch->mutable_task(j));
      if (task.is_partial) {
        std::vector<Tensor>& tensor_vector = (*task.output)[task.split_index];
        tensor_vector[i] = std::move(split_tensor[j]);
      } else {
        task.context->set_output(i, split_tensor[j]);
      }
    }
  }

  return Status::OK();
}
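
// Illustrative example (hypothetical sizes): for tasks of sizes {3, 2} padded
// to an execution batch of 8, task_sizes_plus_optional_padding is {3, 2, 3};
// each combined output is split into three pieces, the first two are routed
// to their tasks (or into the shared output matrix for partial tasks), and
// the trailing padding piece is dropped.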

void BatchResourceBase::ProcessFuncBatch(std::unique_ptr<BatchT> batch) const {
  if (batch->empty()) {
    return;
  }

  // We use the 'propagated_context' from one of the threads which set up one
  // of the tasks. This will propagate any common context over all the threads
  // which are running this Session, of which this BatchOp is a part.
  WithContext wc(batch->task(batch->num_tasks() - 1).propagated_context);

  auto& last_task = batch->task(batch->num_tasks() - 1);
  OpKernelContext* last_task_context = last_task.context;

  // Regardless of the outcome, we need to propagate the status to the
  // individual tasks and signal that they are done. We use MakeCleanup() to
  // ensure that this happens no matter how we exit the method below.
  Status status;
  bool cleanup_done = false;
  auto cleanup_fn = [&cleanup_done, &batch](const Status& status) {
    if (cleanup_done) {
      return;
    }
    for (int i = 0; i < batch->num_tasks(); ++i) {
      if (batch->task(i).is_partial) {
        batch->mutable_task(i)->status->Update(status);
      } else {
        batch->mutable_task(i)->context->SetStatus(status);
      }

      batch->mutable_task(i)->done_callback();
    }
    cleanup_done = true;
  };

  auto finally =
      gtl::MakeCleanup([&cleanup_fn, &status] { cleanup_fn(status); });

  status = ValidateBatch(*batch);
  if (!status.ok()) {
    return;
  }

  std::vector<Tensor> concatenated_tensors;
  status = ConcatInputTensors(*batch, last_task_context, &concatenated_tensors);
  if (!status.ok()) {
    return;
  }

  std::vector<Tensor> combined_outputs;
  std::vector<Tensor> args(concatenated_tensors.begin(),
                           concatenated_tensors.end());
  const auto& captured_inputs =
      batch->task(batch->num_tasks() - 1).captured_inputs;
  args.insert(args.end(), captured_inputs.begin(), captured_inputs.end());

  uint64 current_time = EnvTime::NowNanos();
  const string& model_name = GetModelName(last_task_context);
  for (int i = 0; i < batch->num_tasks(); ++i) {
    RecordBatchDelayUs((current_time - batch->task(i).start_time) * 1e-3,
                       model_name,
                       last_task_context->op_kernel().name_view().data());
  }
  // Releases the cleanup method here, because the callback of the function
  // library runtime will handle it now.
  finally.release();
  ProcessFuncBatchImpl(
      last_task, args, &combined_outputs, [&](const Status& run_status) {
        Status final_status;
        auto run_finally = gtl::MakeCleanup([&]() {
          // We do the cleanup here as an optimization, so that it runs in the
          // underlying TF inter-op threadpool. Running it in the threadpool
          // lets the ensuing ops be scheduled faster, because the executor
          // will add them to the front of the threadpool's task queue rather
          // than the end.
          cleanup_fn(final_status);
        });
        final_status = run_status;
        if (!final_status.ok()) {
          return;
        }
        final_status = SplitOutputTensors(combined_outputs, batch.get());
      });
}

// Processes a batch of one or more BatchTask entries.
void BatchResourceBase::ProcessBatch(std::unique_ptr<BatchT> batch) const {
  if (batch->empty()) {
    return;
  }

  WithContext wc(batch->task(batch->num_tasks() - 1).propagated_context);

  OpKernelContext* last_task_context =
      batch->task(batch->num_tasks() - 1).context;
  AsyncOpKernel::DoneCallback last_task_callback =
      batch->task(batch->num_tasks() - 1).done_callback;

  OP_REQUIRES_OK_ASYNC(last_task_context, ValidateBatch(*batch),
                       last_task_callback);

  // All tasks should have the same number of input edges.
  const int num_input_edges = batch->task(0).inputs.size();
  std::vector<Tensor> concatenated_tensors;
  const Status concat_status =
      ConcatInputTensors(*batch, last_task_context, &concatenated_tensors);
  OP_REQUIRES_OK_ASYNC(last_task_context, concat_status, last_task_callback);

  // Process each input edge one at a time (the typical case has just one).
  for (int i = 0; i < num_input_edges; ++i) {
    last_task_context->set_output(i, concatenated_tensors[i]);

    // Emit batch->num_tasks() - 1 empty output tensors.
    for (int task_idx = 0; task_idx < batch->num_tasks() - 1; ++task_idx) {
      const BatchTask& task = batch->task(task_idx);
      TensorShape output_shape(task.inputs[i].shape());
      output_shape.set_dim(0, 0);
      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          task.context, task.context->allocate_output(i, output_shape, &output),
          task.done_callback);
    }
  }
  // Emit batch->num_tasks() - 1 empty index tensors.
  for (int task_idx = 0; task_idx < batch->num_tasks() - 1; ++task_idx) {
    const BatchTask& task = batch->task(task_idx);
    TensorShape index_shape({0, 3});
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(
        task.context,
        task.context->allocate_output(num_input_edges, index_shape, &output),
        task.done_callback);
  }
  // Emit all ID tensors.
  for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) {
    const BatchTask& task = batch->task(task_idx);
    Tensor* id;
    OP_REQUIRES_OK_ASYNC(task.context,
                         task.context->allocate_output(num_input_edges + 1,
                                                       TensorShape({}), &id),
                         task.done_callback);
    id->scalar<int64>()() = task.guid;
  }
  OP_REQUIRES_OK_ASYNC(
      last_task_context,
      EmitIndexTensor(last_task_context, *batch, num_input_edges),
      last_task_callback);

  // Signal done for each element of the batch. (At this point, the contexts
  // are no longer guaranteed to remain live.)
  for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) {
    batch->mutable_task(task_idx)->done_callback();
  }
}
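
// Summary of the output contract above (descriptive only): the last task's
// context receives the concatenated input tensors and the populated index
// tensor, every other task receives empty (0-row) output and index tensors,
// and each task receives its own guid as the scalar ID output.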

/*static*/ Status BatchResourceBase::EmitIndexTensor(OpKernelContext* context,
                                                     const BatchT& batch,
                                                     int output_index) {
  const TensorShape index_shape({batch.num_tasks(), 3});
  Tensor* index = nullptr;
  TF_RETURN_IF_ERROR(
      context->allocate_output(output_index, index_shape, &index));
  auto index_flat = index->shaped<int64, 2>({batch.num_tasks(), 3});
  size_t offset = 0;
  for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
    const BatchTask& task = batch.task(task_idx);
    index_flat(task_idx, 0) = task.guid;
    index_flat(task_idx, 1) = offset;
    index_flat(task_idx, 2) = offset + task.size();
    offset += task.size();
  }
  return Status::OK();
}
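
// Illustrative index tensor (hypothetical guids): for tasks of sizes {3, 2},
// the emitted [num_tasks, 3] tensor is
//   [[guid_0, 0, 3],
//    [guid_1, 3, 5]]
// i.e. each row holds a task's guid and its half-open [start, end) row range
// within the concatenated batch.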

// Looks up the batcher queue for 'queue_name'. If it didn't previously exist,
// creates it.
Status BatchResourceBase::LookupOrCreateBatcherQueue(const string& queue_name,
                                                     BatcherQueueT** queue) {
  mutex_lock l(batcher_queues_mu_);

  auto it = batcher_queues_.find(queue_name);
  if (it != batcher_queues_.end()) {
    *queue = it->second.get();
    return Status::OK();
  }

  std::unique_ptr<BatcherQueueT> new_queue;
  auto process_batch_callback = [this](std::unique_ptr<BatchT> batch) {
    if (!has_process_batch_function_) {
      ProcessBatch(std::move(batch));
    } else {
      ProcessFuncBatch(std::move(batch));
    }
  };
  if (batcher_) {
    TF_RETURN_IF_ERROR(batcher_->AddQueue(batcher_queue_options_,
                                          process_batch_callback, &new_queue));
  } else if (adaptive_batcher_) {
    TF_RETURN_IF_ERROR(adaptive_batcher_->AddQueue(
        adaptive_batcher_queue_options_, process_batch_callback, &new_queue));
  } else {
    return errors::Internal("No batcher defined.");
  }
  *queue = new_queue.get();
  batcher_queues_[queue_name] = std::move(new_queue);
  return Status::OK();
}

Status BatchResourceBase::CreateBatchTask(
    OpKernelContext* context,
    std::unique_ptr<BatchResourceBase::BatchTask>* output) const {
  *output = absl::make_unique<BatchResourceBase::BatchTask>();
  return Status::OK();
}

}  // namespace serving
}  // namespace tensorflow