/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/batching_util/batch_resource_base.h"

#include "absl/time/time.h"
#include "absl/types/optional.h"
#include "tensorflow/core/common_runtime/cost_measurement_registry.h"
#include "tensorflow/core/framework/ops_util.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/kernels/batching_util/concat_split_util.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/percentile_sampler.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"
#include "tensorflow/core/util/incremental_barrier.h"

namespace tensorflow {
namespace serving {
namespace {

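// Returns the cost measurement type configured via the
// TF_COST_MEASUREMENT_TYPE environment variable, or nullptr if it is unset.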
const char* GetCostMeasurementType() {
  return std::getenv("TF_COST_MEASUREMENT_TYPE");
}

// TODO(b/181883417): Replace with RecordPaddingSizeV2.
void RecordPaddingSize(int32_t padding_size, const string& model_name,
                       int32_t execution_batch_size, const string& op_name) {
  static auto* cell = tensorflow::monitoring::PercentileSampler<3>::New(
      {"/tensorflow/serving/batching/padding_size",
       "Tracks the padding size distribution on batches by model_name (if "
       "available).",
       "model_name", "execution_batch_size", "op_name"},
      /*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0},
      /*max_samples=*/1024, tensorflow::monitoring::UnitOfMeasure::kNumber);
  cell->GetCell(model_name, absl::StrCat(execution_batch_size), op_name)
      ->Add(static_cast<double>(padding_size));
}

void RecordPaddingSizeV2(int32_t padding_size, const string& model_name,
                         int32_t execution_batch_size, const string& op_name) {
  static auto* cell = tensorflow::monitoring::Sampler<3>::New(
      {"/tensorflow/serving/batching/padding_size_v2",
       "Tracks the padding size distribution on batches by model_name (if "
       "available).",
       "model_name", "execution_batch_size", "op_name"},
      // It's 14 buckets with the last bucket being 2^13 to DBL_MAX;
      // so the limits are [1, 2, 4, 8, ..., 8 * 1024, DBL_MAX].
      monitoring::Buckets::Exponential(1, 2, 14));
  cell->GetCell(model_name, absl::StrCat(execution_batch_size), op_name)
      ->Add(static_cast<double>(padding_size));
}

// TODO(b/181883417): Replace with RecordInputBatchSizeV2.
void RecordInputBatchSize(int32_t batch_size, const string& model_name,
                          const string& op_name) {
  static auto* cell = tensorflow::monitoring::PercentileSampler<2>::New(
      {"/tensorflow/serving/batching/input_batch_size",
       "Tracks the batch size distribution on the inputs by model_name (if "
       "available).",
       "model_name", "op_name"},
      /*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0},
      /*max_samples=*/1024, tensorflow::monitoring::UnitOfMeasure::kNumber);
  cell->GetCell(model_name, op_name)->Add(static_cast<double>(batch_size));
}

void RecordInputBatchSizeV2(int32_t batch_size, const string& model_name,
                            const string& op_name) {
  static auto* cell = tensorflow::monitoring::Sampler<2>::New(
      {"/tensorflow/serving/batching/input_batch_size_v2",
       "Tracks the batch size distribution on the inputs by model_name (if "
       "available).",
       "model_name", "op_name"},
      // It's 14 buckets with the last bucket being 2^13 to DBL_MAX;
      // so the limits are [1, 2, 4, 8, ..., 8 * 1024, DBL_MAX].
      monitoring::Buckets::Exponential(1, 2, 14));
  cell->GetCell(model_name, op_name)->Add(static_cast<double>(batch_size));
}

void RecordProcessedBatchSize(int32_t batch_size, const string& model_name,
                              const string& op_name) {
  static auto* cell = tensorflow::monitoring::PercentileSampler<2>::New(
      {"/tensorflow/serving/batching/processed_batch_size",
       "Tracks the batch size distribution on processing by model_name (if "
       "available).",
       "model_name", "op_name"},
      /*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0},
      /*max_samples=*/1024, tensorflow::monitoring::UnitOfMeasure::kNumber);
  cell->GetCell(model_name, op_name)->Add(static_cast<double>(batch_size));
}

// Export the exact number instead of the distribution of processed batch size.
void RecordProcessedBatchSizeV2(int32_t batch_size, const string& model_name,
                                const string& op_name) {
  static auto* cell = monitoring::Counter<3>::New(
      "/tensorflow/serving/batching/processed_batch_size_v2",
      "Tracks the batch size on processing by model_name and op name (if "
      "available).",
      "model_name", "op_name", "batch_size");
  cell->GetCell(model_name, op_name, std::to_string(batch_size))
      ->IncrementBy(1);
}

// TODO(b/181883417): Replace with RecordBatchDelayUsV2.
void RecordBatchDelayUs(int64_t batch_delay_us, const string& model_name,
                        const string& op_name) {
  static auto* cell = monitoring::PercentileSampler<2>::New(
      {"/tensorflow/serving/batching/batch_delay_us",
       "Tracks the batching delay (in microseconds) for inputs by model_name "
       "(if available).",
       "model_name", "op_name"},
      /*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0},
      /*max_samples=*/1024, monitoring::UnitOfMeasure::kTime);
  cell->GetCell(model_name, op_name)->Add(static_cast<double>(batch_delay_us));
}

void RecordBatchDelayUsV2(int64_t batch_delay_us, const string& model_name,
                          const string& op_name) {
  static auto* cell = tensorflow::monitoring::Sampler<2>::New(
      {"/tensorflow/serving/batching/batch_delay_us_v2",
       "Tracks the batching delay (in microseconds) for inputs by model_name "
       "(if available).",
       "model_name", "op_name"},
      // It's 27 buckets with the last bucket being 2^26 to DBL_MAX;
      // so the limits are [1, 2, 4, 8, ..., 64 * 1024 * 1024, DBL_MAX].
      monitoring::Buckets::Exponential(1, 2, 27));
  cell->GetCell(model_name, op_name)->Add(static_cast<double>(batch_delay_us));
}

void RecordBatchParamBatchTimeoutMicros(int64_t batch_timeout_micros,
                                        const string& model_name,
                                        const string& op_name) {
  static auto* cell = monitoring::Gauge<int64, 2>::New(
      "/tensorflow/serving/batching/batch_timeout_micros",
      "Tracks how long a request can wait before being processed by a batch.",
      "model_name", "op_name");
  cell->GetCell(model_name, op_name)->Set(batch_timeout_micros);
}

void RecordBatchParamMaxBatchSize(int64_t max_batch_size,
                                  const string& model_name,
                                  const string& op_name) {
  static auto* cell = monitoring::Gauge<int64, 2>::New(
      "/tensorflow/serving/batching/max_batch_size",
      "Tracks the maximum size of a batch.", "model_name", "op_name");
  cell->GetCell(model_name, op_name)->Set(max_batch_size);
}

void RecordBatchParamMaxEnqueuedBatches(int64_t max_enqueued_batches,
                                        const string& model_name,
                                        const string& op_name) {
  static auto* cell = monitoring::Gauge<int64, 2>::New(
      "/tensorflow/serving/batching/max_enqueued_batches",
      "Tracks the maximum number of enqueued batches.", "model_name",
      "op_name");
  cell->GetCell(model_name, op_name)->Set(max_enqueued_batches);
}

void RecordBatchParamAllowedBatchSizes(const string& allowed_batch_sizes,
                                       const string& model_name,
                                       const string& op_name) {
  static auto* cell = monitoring::Gauge<string, 2>::New(
      "/tensorflow/serving/batching/allowed_batch_sizes",
      "Tracks the sizes that are allowed to form a batch.", "model_name",
      "op_name");
  cell->GetCell(model_name, op_name)->Set(allowed_batch_sizes);
}

const string& GetModelName(OpKernelContext* ctx) {
  static string* kModelNameUnset = new string("model_name_unset");
  if (!ctx->session_metadata()) return *kModelNameUnset;
  if (ctx->session_metadata()->name().empty()) return *kModelNameUnset;
  return ctx->session_metadata()->name();
}

}  // namespace

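// Creates a partial task that mirrors this task's bookkeeping (guid, context,
// shared output matrix, shared status, start time, etc.) so that one
// oversized input can be processed as several smaller tasks whose results are
// later re-assembled; the split task's inputs are filled in by
// SplitInputTask().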
std::unique_ptr<BatchResourceBase::BatchTask>
BatchResourceBase::BatchTask::CreateSplitTask(
    int split_index, AsyncOpKernel::DoneCallback done_callback) {
  std::unique_ptr<BatchTask> task = CreateDerivedTask();

  task->guid = this->guid;
  task->propagated_context = Context(ContextKind::kThread);
  task->inputs.reserve(this->inputs.size());
  task->captured_inputs = this->captured_inputs;
  task->context = this->context;
  task->done_callback = done_callback;
  task->split_index = split_index;
  task->output = this->output;
  task->status = this->status;
  task->is_partial = true;
  task->start_time = this->start_time;
  task->request_cost = this->request_cost;

  return task;
}

using ::tensorflow::concat_split_util::Concat;
using ::tensorflow::concat_split_util::Split;
using TensorMatrix = std::vector<std::vector<Tensor>>;

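// Creates a BatchTask for the op invocation in `context`, validates the input
// tensors, records the batching parameters for monitoring, and schedules the
// task on the batcher queue named `batcher_queue_name`. Empty inputs are
// short-circuited with empty outputs and never enqueued.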
Status BatchResourceBase::RegisterInput(
    int64_t guid, OpKernelContext* context, const string& batcher_queue_name,
    AsyncOpKernel::DoneCallback done_callback) {
  std::unique_ptr<BatchTask> batch_components;
  TF_RETURN_IF_ERROR(CreateBatchTask(context, &batch_components));
  batch_components->start_time = EnvTime::NowNanos();
  batch_components->guid = guid;
  batch_components->propagated_context = Context(ContextKind::kThread);
  OpInputList tensors;
  TF_RETURN_IF_ERROR(context->input_list("in_tensors", &tensors));
  batch_components->inputs.reserve(tensors.size());
  for (const Tensor& tensor : tensors) {
    if (tensor.shape().dims() == 0) {
      return errors::InvalidArgument(
          "Batching input tensors must have at least one dimension");
    }
    if (tensors.size() >= 2 &&
        tensor.shape().dim_size(0) != tensors[0].shape().dim_size(0)) {
      return errors::InvalidArgument(
          "Batching input tensors supplied in a given op invocation must "
          "have equal 0th-dimension size");
    }
    batch_components->inputs.push_back(tensor);
  }
  RecordInputBatchSize(tensors[0].shape().dim_size(0), GetModelName(context),
                       context->op_kernel().name_view().data());
  RecordInputBatchSizeV2(tensors[0].shape().dim_size(0), GetModelName(context),
                         context->op_kernel().name());
  RecordBatchParamBatchTimeoutMicros(
      batcher_queue_options_.batch_timeout_micros, GetModelName(context),
      context->op_kernel().name_view().data());
  RecordBatchParamMaxBatchSize(batcher_queue_options_.max_execution_batch_size,
                               GetModelName(context),
                               context->op_kernel().name_view().data());
  RecordBatchParamMaxEnqueuedBatches(
      batcher_queue_options_.max_enqueued_batches, GetModelName(context),
      context->op_kernel().name_view().data());
  RecordBatchParamAllowedBatchSizes(allowed_batch_sizes_str_,
                                    GetModelName(context),
                                    context->op_kernel().name_view().data());

  // Degenerate case where the input is empty. Just return an empty tensor.
  if (tensors[0].shape().dim_size(0) == 0) {
    for (int i = 0; i < context->num_outputs(); i++) {
      Tensor* empty_output;
      AllocatorAttributes cpu_alloc;
      cpu_alloc.set_on_host(true);
      TF_RETURN_IF_ERROR(context->allocate_output(i, TensorShape({0}),
                                                  &empty_output, cpu_alloc));
    }
    done_callback();
    return Status::OK();
  }
  OpInputList captured_tensors;
  const auto captured_status =
      context->input_list("captured_tensors", &captured_tensors);
  if (captured_status.ok()) {
    batch_components->captured_inputs.reserve(captured_tensors.size());
    for (const Tensor& captured_tensor : captured_tensors) {
      batch_components->captured_inputs.push_back(captured_tensor);
    }
  }
  batch_components->context = context;
  batch_components->done_callback = std::move(done_callback);
  batch_components->split_index = 0;
  batch_components->output = std::make_shared<TensorMatrix>();
  batch_components->status = std::make_shared<ThreadSafeStatus>();

  BatcherQueueT* batcher_queue;
  TF_RETURN_IF_ERROR(
      LookupOrCreateBatcherQueue(batcher_queue_name, &batcher_queue));
  return batcher_queue->Schedule(&batch_components);
}

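// Translates the op's batching attributes into shared-batcher queue options.
// When large batch splitting is enabled, oversized inputs are split with
// SplitInputTask() and the execution batch size is capped at the last entry
// of `allowed_batch_sizes` (or at `max_batch_size` if the list is empty).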
/*static*/ BatchResourceBase::BatcherT::QueueOptions
BatchResourceBase::GetBatcherQueueOptions(
    int32_t num_batch_threads, int32_t max_batch_size,
    int32_t batch_timeout_micros, int32_t max_enqueued_batches,
    const std::vector<int32>& allowed_batch_sizes,
    bool enable_large_batch_splitting) {
  BatcherT::QueueOptions batcher_queue_options;
  batcher_queue_options.input_batch_size_limit = max_batch_size;
  batcher_queue_options.max_enqueued_batches = max_enqueued_batches;
  batcher_queue_options.batch_timeout_micros = batch_timeout_micros;
  batcher_queue_options.enable_large_batch_splitting =
      enable_large_batch_splitting;
  if (enable_large_batch_splitting) {
    batcher_queue_options.split_input_task_func =
        [](std::unique_ptr<BatchTask>* input_task,
           int open_batch_remaining_slot, int max_batch_size,
           std::vector<std::unique_ptr<BatchTask>>* output_tasks) -> Status {
      return SplitInputTask(input_task, open_batch_remaining_slot,
                            max_batch_size, output_tasks);
    };

    if (allowed_batch_sizes.empty()) {
      batcher_queue_options.max_execution_batch_size = max_batch_size;
    } else {
      batcher_queue_options.max_execution_batch_size =
          *allowed_batch_sizes.rbegin();
    }
  }

  return batcher_queue_options;
}

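// Same as above, but for the adaptive shared batcher: `max_batch_size` bounds
// the size of an individual input task, and the effective batch size again
// comes from the last entry of `allowed_batch_sizes` when that list is
// non-empty.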
/*static*/ BatchResourceBase::AdaptiveBatcherT::QueueOptions
BatchResourceBase::GetAdaptiveBatcherQueueOptions(
    int32_t max_batch_size, int32_t batch_timeout_micros,
    int32_t max_enqueued_batches, bool enable_large_batch_splitting,
    const std::vector<int32>& allowed_batch_sizes) {
  AdaptiveBatcherT::QueueOptions batcher_queue_options;
  batcher_queue_options.max_input_task_size =
      absl::make_optional(max_batch_size);
  batcher_queue_options.max_enqueued_batches = max_enqueued_batches;
  batcher_queue_options.batch_timeout_micros = batch_timeout_micros;
  if (allowed_batch_sizes.empty()) {
    batcher_queue_options.max_batch_size = max_batch_size;
  } else {
    batcher_queue_options.max_batch_size = *allowed_batch_sizes.rbegin();
  }

  if (enable_large_batch_splitting) {
    batcher_queue_options.split_input_task_func =
        [](std::unique_ptr<BatchTask>* input_task,
           int open_batch_remaining_slot, int max_batch_size,
           std::vector<std::unique_ptr<BatchTask>>* output_tasks) -> Status {
      return SplitInputTask(input_task, open_batch_remaining_slot,
                            max_batch_size, output_tasks);
    };
  }

  return batcher_queue_options;
}

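// Checks that every task in the batch has the same number of input tensors as
// the first task.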
/*static*/ Status BatchResourceBase::ValidateBatch(const BatchT& batch) {
  for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
    const BatchResourceBase::BatchTask& task = batch.task(task_idx);

    if (task.inputs.size() != batch.task(0).inputs.size()) {
      return errors::InvalidArgument(
          "Batching inputs must have equal number of edges");
    }
  }

  return Status::OK();
}

// Returns the smallest entry in 'allowed_batch_sizes_' that is greater than
// or equal to 'batch_size'. If 'allowed_batch_sizes_' is empty, simply
// returns 'batch_size'.
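// For example, with 'allowed_batch_sizes_' = {16, 32, 64}, a batch of size 20
// is rounded up to 32, while a batch of size 100 exceeds every allowed size
// and is returned unchanged (after logging an error).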
int BatchResourceBase::RoundToLowestAllowedBatchSize(int batch_size) const {
  if (allowed_batch_sizes_.empty()) {
    return batch_size;
  }
  for (int allowed_size : allowed_batch_sizes_) {
    if (allowed_size >= batch_size) {
      return allowed_size;
    }
  }
  LOG(ERROR) << "Batch size " << batch_size
             << " is greater than largest allowed size; "
                "ignoring allowed sizes constraint.";
  return batch_size;
}

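// Concatenates the i-th input tensor of every task in `batch` into one tensor
// per input edge, padding each result up to the next allowed batch size by
// repeating the first row of the first task's tensor, and records padding and
// processed-batch-size metrics.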
Status BatchResourceBase::ConcatInputTensors(
    const BatchT& batch, OpKernelContext* context,
    std::vector<Tensor>* concatenated_tensors) const {
  if (batch.num_tasks() == 0) {
    return errors::InvalidArgument("Empty batch.");
  }

  const int padded_batch_size = RoundToLowestAllowedBatchSize(batch.size());
  const int padding_amount = padded_batch_size - batch.size();
  profiler::TraceMe trace_me([padded_batch_size, padding_amount]() {
    return profiler::TraceMeEncode(
        "ConcatInputTensors", {{"batch_size_after_padding", padded_batch_size},
                               {"padding_amount", padding_amount}});
  });
  RecordPaddingSize(padding_amount, GetModelName(context), padded_batch_size,
                    context->op_kernel().name_view().data());
  RecordPaddingSizeV2(padding_amount, GetModelName(context), padded_batch_size,
                      context->op_kernel().name());
  RecordProcessedBatchSize(padded_batch_size, GetModelName(context),
                           context->op_kernel().name_view().data());
  RecordProcessedBatchSizeV2(padded_batch_size, GetModelName(context),
                             string(context->op_kernel().name_view()));

  // All tasks should have the same number of input edges.
  const int num_inputs = batch.task(0).inputs.size();
  concatenated_tensors->reserve(num_inputs);

  // Process each input one at a time (the typical case has just one).
  for (int i = 0; i < num_inputs; ++i) {
    // Concatenate the tasks' ith input tensors into a big output tensor.
    std::vector<Tensor> to_concatenate;
    to_concatenate.reserve(batch.num_tasks());
    for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
      to_concatenate.push_back(batch.task(task_idx).inputs.at(i));
    }

    // Add padding as needed. Use the first row of the first task's tensor as
    // the data for padding.
    if (padding_amount > 0) {
      const Tensor& padding_source = batch.task(0).inputs.at(i);
      Tensor padding;
      if (padding_source.shape().dim_size(0) == 0) {
        return errors::InvalidArgument(
            "Cannot use an empty tensor with zero rows as padding when "
            "batching. (Input ",
            i, " got shape ", padding_source.shape().DebugString(), ".)");
      }
      if (padding_source.shape().dim_size(0) == 1) {
        padding = padding_source;
      } else {
        padding = padding_source.Slice(0, 1);
      }
      for (int i = 0; i < padding_amount; ++i) {
        to_concatenate.push_back(padding);
      }
    }

    Tensor concatenated_tensor;
    Status concat_status =
        Concat(context, to_concatenate, &concatenated_tensor);
    TF_RETURN_IF_ERROR(concat_status);
    concatenated_tensors->push_back(concatenated_tensor);
  }
  return Status::OK();
}

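// Splits `*input_task_ptr` into tasks of at most `max_batch_size` rows (plus
// an initial task of `open_batch_remaining_slot` rows, if any, to top off the
// currently open batch). The split tasks share the parent's output matrix and
// status, and an IncrementalBarrier re-concatenates their outputs and invokes
// the original done callback once every split task has completed.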
/*static*/ Status BatchResourceBase::SplitInputTask(
    std::unique_ptr<BatchTask>* input_task_ptr, int open_batch_remaining_slot,
    int max_batch_size, std::vector<std::unique_ptr<BatchTask>>* output_tasks) {
  BatchTask& input_task = *(*input_task_ptr);
  const int64_t input_task_size = input_task.size();

  DCHECK_GT(input_task_size, open_batch_remaining_slot);

  std::shared_ptr<ThreadSafeStatus> shared_status = input_task.status;

  // `split_task_done_callback` runs only after all split tasks are complete.
  std::function<void()> split_task_done_callback =
      [done_callback = input_task.done_callback, output = input_task.output,
       op_kernel_context = input_task.context, status = shared_status]() {
        const int num_output = op_kernel_context->num_outputs();
        for (int i = 0; i < num_output; ++i) {
          Tensor output_tensor;

          // Concat would memcpy each input tensor to one output tensor.
          // In this context, Concat can be further optimized to get rid of
          // some (probably all) memcpy when input tensors are slices of
          // another copy.
          std::vector<Tensor> to_concatenate;
          to_concatenate.reserve(output->size());
          for (int j = 0; j < output->size(); ++j) {
            to_concatenate.push_back(std::move((*output)[j][i]));
          }
          const auto concat_status =
              Concat(op_kernel_context, to_concatenate, &output_tensor);
          if (!concat_status.ok()) {
            status->Update(concat_status);
          }

          op_kernel_context->set_output(i, std::move(output_tensor));
        }
        op_kernel_context->SetStatus(status->status());
        done_callback();
      };
  IncrementalBarrier barrier(split_task_done_callback);

  std::vector<int64> output_task_sizes;

  if (open_batch_remaining_slot > 0) {
    output_task_sizes.push_back(open_batch_remaining_slot);
  }

  for (int left_task_size = input_task_size - open_batch_remaining_slot;
       left_task_size > 0; left_task_size -= max_batch_size) {
    int next_task_size = std::min(left_task_size, max_batch_size);
    output_task_sizes.push_back(next_task_size);
  }

  const int output_task_num = output_task_sizes.size();
  input_task.output->resize(output_task_num);

  for (int i = 0; i < output_task_num; ++i) {
    (*input_task.output)[i].resize(input_task.context->num_outputs());
  }

  output_tasks->reserve(output_task_num);
  for (int i = 0; i < output_task_num; i++) {
    output_tasks->push_back(input_task.CreateSplitTask(i, barrier.Inc()));
  }

  const int num_input_tensors = input_task.inputs.size();

  // Splits each input tensor according to `output_task_sizes`, and
  // initializes input of `output_tasks` with split results.
  for (int i = 0; i < num_input_tensors; ++i) {
    std::vector<Tensor> split_tensors;
    const Tensor& input_tensor = input_task.inputs[i];
    // TODO(b/154140947):
    // Figure out the optimal implementation of Split, by using
    // 'Tensor::Slice' and eliminating unnecessary memcpy as much as possible.
    const Status split_status = Split(input_task.context, input_tensor,
                                      output_task_sizes, &split_tensors);
    if (!split_status.ok()) {
      return errors::Internal(
          "When splitting input, Tensor split operation failed: ",
          split_status.ToString());
    }
    if (split_tensors.size() != output_task_sizes.size()) {
      return errors::Internal(
          "When splitting input, tensor split operation did not work as "
          "expected; got ",
          split_tensors.size(), " splits; expected ", output_task_sizes.size());
    }
    for (int j = 0; j < output_tasks->size(); ++j) {
      BatchTask& output_task = *((*output_tasks)[j]);
      auto moved_tensor_iter = std::next(split_tensors.begin(), j);
      std::move(moved_tensor_iter, moved_tensor_iter + 1,
                std::back_inserter(output_task.inputs));
    }
  }
  return Status::OK();
}

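// Splits each combined output tensor along dimension 0 back into per-task
// slices (dropping the trailing padding slice, if any) and either stores the
// slice into the task's shared output matrix (for partial tasks) or sets it
// directly as the task's op output.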
Status BatchResourceBase::SplitOutputTensors(
    const std::vector<Tensor>& combined_outputs, BatchT* batch) const {
  DCHECK_GE(batch->num_tasks(), 1);
  if (batch->num_tasks() < 1) {
    return errors::Internal("Batch size expected to be positive; was ",
                            batch->num_tasks());
  }

  std::vector<int64> task_sizes_plus_optional_padding;
  task_sizes_plus_optional_padding.reserve(batch->num_tasks());
  for (int i = 0; i < batch->num_tasks(); ++i) {
    task_sizes_plus_optional_padding.push_back(batch->task(i).size());
  }
  const int padding_size =
      RoundToLowestAllowedBatchSize(batch->size()) - batch->size();
  if (padding_size > 0) {
    task_sizes_plus_optional_padding.push_back(padding_size);
  }

  // For each output tensor name, a divided-up tensor with one entry per task.
  std::map<string, std::vector<Tensor>> split_tensors;

  DCHECK_EQ(batch->task(0).context->num_outputs(), combined_outputs.size());
  int combined_outputs_size = combined_outputs.size();
  if (combined_outputs_size != batch->task(0).context->num_outputs()) {
    return errors::Internal("Wrong number of batched output tensors");
  }

  // Generate 'split_tensors' and populate the context outputs.
  for (int i = 0, iter_limit = combined_outputs.size(); i < iter_limit; ++i) {
    const Tensor& output_tensor = combined_outputs[i];
    if (output_tensor.shape().dims() == 0) {
      return errors::FailedPrecondition(
          "Batched output tensor has 0 dimensions");
    }
    if (output_tensor.shape().dim_size(0) !=
        static_cast<int64>(batch->size() + padding_size)) {
      return errors::FailedPrecondition(
          "Batched output tensor's 0th dimension does not equal the sum of "
          "the 0th dimension sizes of the input tensors");
    }

    std::vector<Tensor> split_tensor;
    const Status split_status = tensor::Split(
        output_tensor, task_sizes_plus_optional_padding, &split_tensor);
    DCHECK(split_status.ok()) << split_status.ToString();
    if (!split_status.ok()) {
      return errors::Internal("Tensor split operation failed: ",
                              split_status.ToString());
    }
    DCHECK_EQ(split_tensor.size(), task_sizes_plus_optional_padding.size());
    if (split_tensor.size() != task_sizes_plus_optional_padding.size()) {
      return errors::Internal(
          "Tensor split operation did not work as expected; got ",
          split_tensor.size(), " splits; expected ",
          task_sizes_plus_optional_padding.size());
    }

    // Ignore a possible final split_tensors entry containing the padding.
    for (int j = 0; j < batch->num_tasks(); ++j) {
      BatchTask& task = *(batch->mutable_task(j));
      if (task.is_partial) {
        std::vector<Tensor>& tensor_vector = (*task.output)[task.split_index];
        tensor_vector[i] = std::move(split_tensor[j]);
      } else {
        task.context->set_output(i, split_tensor[j]);
      }
    }
  }

  return Status::OK();
}

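// Processes a batch by concatenating its inputs, invoking the batch function
// via ProcessFuncBatchImpl(), and splitting the function's outputs back out
// to the individual tasks. On any failure the error status is propagated to
// every task and all done callbacks are still invoked.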
void BatchResourceBase::ProcessFuncBatch(std::unique_ptr<BatchT> batch) const {
  if (batch->empty()) {
    return;
  }

  const char* cost_measurement_type = GetCostMeasurementType();
  auto batch_cost_measurement =
      cost_measurement_type
          ? CostMeasurementRegistry::CreateByNameOrNull(cost_measurement_type)
          : nullptr;
  int64_t processed_size = batch->size();
  auto batch_cost_split_cleanup = gtl::MakeCleanup([&] {
    SplitBatchCost(batch_cost_measurement.get(), processed_size, *batch);
  });

  // We use the 'propagated_context' from one of the threads that set up one
  // of the tasks. This will propagate any common context over all the threads
  // which are running this Session, of which this BatchOp is a part.
  WithContext wc(batch->task(batch->num_tasks() - 1).propagated_context);

  auto& last_task = batch->task(batch->num_tasks() - 1);
  OpKernelContext* last_task_context = last_task.context;

  // Regardless of the outcome, we need to propagate the status to the
  // individual tasks and signal that they are done. We use MakeCleanup() to
  // ensure that this happens no matter how we exit the method below.
  Status status;
  bool cleanup_done = false;
  auto cleanup_fn = [&cleanup_done, &batch](const Status& status) {
    if (cleanup_done) {
      return;
    }
    for (int i = 0; i < batch->num_tasks(); ++i) {
      if (batch->task(i).is_partial) {
        batch->mutable_task(i)->status->Update(status);
      } else {
        batch->mutable_task(i)->context->SetStatus(status);
      }

      batch->mutable_task(i)->done_callback();
    }
    cleanup_done = true;
  };

  auto finally =
      gtl::MakeCleanup([&cleanup_fn, &status] { cleanup_fn(status); });

  status = ValidateBatch(*batch);
  if (!status.ok()) {
    return;
  }

  std::vector<Tensor> concatenated_tensors;
  status =
      ConcatInputTensors(*batch, last_task_context, &concatenated_tensors);
  processed_size = RoundToLowestAllowedBatchSize(batch->size());
  if (!status.ok()) {
    return;
  }

  std::vector<Tensor> combined_outputs;
  std::vector<Tensor> args(concatenated_tensors.begin(),
                           concatenated_tensors.end());
  const auto& captured_inputs =
      batch->task(batch->num_tasks() - 1).captured_inputs;
  args.insert(args.end(), captured_inputs.begin(), captured_inputs.end());

  uint64 current_time = EnvTime::NowNanos();
  const string& model_name = GetModelName(last_task_context);
  for (int i = 0; i < batch->num_tasks(); ++i) {
    RecordBatchDelayUs((current_time - batch->task(i).start_time) * 1e-3,
                       model_name,
                       last_task_context->op_kernel().name_view().data());
    RecordBatchDelayUsV2((current_time - batch->task(i).start_time) * 1e-3,
                         model_name, last_task_context->op_kernel().name());
  }
  // Releases the cleanup method here, because the callback of the function
  // library runtime will handle it now.
  finally.release();
  ProcessFuncBatchImpl(
      last_task, args, &combined_outputs, [&](const Status& run_status) {
        Status final_status;
        auto run_finally = gtl::MakeCleanup([&]() {
          // We do the cleanup here as an optimization, so that it runs in
          // the underlying TF inter-op threadpool. Running it in the
          // threadpool lets the ensuing ops be scheduled faster, because
          // the executor will add them to the front of the threadpool's
          // task queue rather than the end.
          cleanup_fn(final_status);
        });
        final_status = run_status;
        if (!final_status.ok()) {
          return;
        }
        final_status = SplitOutputTensors(combined_outputs, batch.get());
      });
}

// Processes a batch of one or more BatchTask entries.
void BatchResourceBase::ProcessBatch(std::unique_ptr<BatchT> batch) const {
  if (batch->empty()) {
    return;
  }

  const char* cost_measurement_type = GetCostMeasurementType();
  auto batch_cost_measurement =
      cost_measurement_type
          ? CostMeasurementRegistry::CreateByNameOrNull(cost_measurement_type)
          : nullptr;
  int64_t processed_size = batch->size();
  auto batch_cost_split_cleaner = gtl::MakeCleanup([&] {
    SplitBatchCost(batch_cost_measurement.get(), processed_size, *batch);
  });

  WithContext wc(batch->task(batch->num_tasks() - 1).propagated_context);

  OpKernelContext* last_task_context =
      batch->task(batch->num_tasks() - 1).context;
  AsyncOpKernel::DoneCallback last_task_callback =
      batch->task(batch->num_tasks() - 1).done_callback;

  OP_REQUIRES_OK_ASYNC(last_task_context, ValidateBatch(*batch),
                       last_task_callback);

  // All tasks should have the same number of input edges.
  const int num_input_edges = batch->task(0).inputs.size();
  std::vector<Tensor> concatenated_tensors;
  const Status concat_status =
      ConcatInputTensors(*batch, last_task_context, &concatenated_tensors);
  processed_size = RoundToLowestAllowedBatchSize(batch->size());
  OP_REQUIRES_OK_ASYNC(last_task_context, concat_status, last_task_callback);

  // Process each input edge one at a time (the typical case has just one).
  for (int i = 0; i < num_input_edges; ++i) {
    last_task_context->set_output(i, concatenated_tensors[i]);

    // Emit batch->num_tasks() - 1 empty output tensors.
    for (int task_idx = 0; task_idx < batch->num_tasks() - 1; ++task_idx) {
      const BatchTask& task = batch->task(task_idx);
      TensorShape output_shape(task.inputs[i].shape());
      output_shape.set_dim(0, 0);
      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          task.context,
          task.context->allocate_output(i, output_shape, &output),
          task.done_callback);
    }
  }
  // Emit batch->num_tasks() - 1 empty index tensors.
  for (int task_idx = 0; task_idx < batch->num_tasks() - 1; ++task_idx) {
    const BatchTask& task = batch->task(task_idx);
    TensorShape index_shape({0, 3});
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(
        task.context,
        task.context->allocate_output(num_input_edges, index_shape, &output),
        task.done_callback);
  }
  // Emit all ID tensors.
  for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) {
    const BatchTask& task = batch->task(task_idx);
    Tensor* id;
    OP_REQUIRES_OK_ASYNC(task.context,
                         task.context->allocate_output(num_input_edges + 1,
                                                       TensorShape({}), &id),
                         task.done_callback);
    id->scalar<int64>()() = task.guid;
  }
  OP_REQUIRES_OK_ASYNC(
      last_task_context,
      EmitIndexTensor(last_task_context, *batch, num_input_edges),
      last_task_callback);

  // Signal done for each element of the batch. (At this point, the contexts
  // are no longer guaranteed to remain live.)
  for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) {
    batch->mutable_task(task_idx)->done_callback();
  }
}

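// Emits a [num_tasks, 3] int64 index tensor whose rows are
// (task guid, start offset, end offset), describing where each task's rows
// live within the concatenated batch.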
/*static*/ Status BatchResourceBase::EmitIndexTensor(OpKernelContext* context,
                                                     const BatchT& batch,
                                                     int output_index) {
  const TensorShape index_shape({batch.num_tasks(), 3});
  Tensor* index = nullptr;
  TF_RETURN_IF_ERROR(
      context->allocate_output(output_index, index_shape, &index));
  auto index_flat = index->shaped<int64, 2>({batch.num_tasks(), 3});
  size_t offset = 0;
  for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
    const BatchTask& task = batch.task(task_idx);
    index_flat(task_idx, 0) = task.guid;
    index_flat(task_idx, 1) = offset;
    index_flat(task_idx, 2) = offset + task.size();
    offset += task.size();
  }
  return Status::OK();
}

// Looks up the batcher queue for 'queue_name'. If it didn't previously exist,
// creates it.
Status BatchResourceBase::LookupOrCreateBatcherQueue(const string& queue_name,
                                                     BatcherQueueT** queue) {
  mutex_lock l(batcher_queues_mu_);

  auto it = batcher_queues_.find(queue_name);
  if (it != batcher_queues_.end()) {
    *queue = it->second.get();
    return Status::OK();
  }

  std::unique_ptr<BatcherQueueT> new_queue;
  auto process_batch_callback = [this](std::unique_ptr<BatchT> batch) {
    if (!has_process_batch_function_) {
      ProcessBatch(std::move(batch));
    } else {
      ProcessFuncBatch(std::move(batch));
    }
  };
  if (batcher_) {
    TF_RETURN_IF_ERROR(batcher_->AddQueue(batcher_queue_options_,
                                          process_batch_callback, &new_queue));
  } else if (adaptive_batcher_) {
    TF_RETURN_IF_ERROR(adaptive_batcher_->AddQueue(
        adaptive_batcher_queue_options_, process_batch_callback, &new_queue));
  } else {
    return errors::Internal("No batcher defined.");
  }
  *queue = new_queue.get();
  batcher_queues_[queue_name] = std::move(new_queue);
  return Status::OK();
}

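// Default factory for the BatchTask instances used by RegisterInput();
// derived batch resources may supply their own task type here.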
Status BatchResourceBase::CreateBatchTask(
    OpKernelContext* context,
    std::unique_ptr<BatchResourceBase::BatchTask>* output) const {
  *output = absl::make_unique<BatchResourceBase::BatchTask>();
  return Status::OK();
}

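// Intended to attribute the measured batch cost back to the individual tasks.
// Currently it only bails out when no cost was measured; the actual split
// logic is still to be implemented (see the TODO below).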
void BatchResourceBase::SplitBatchCost(CostMeasurement* batch_cost_measurement,
                                       const int64_t processed_size,
                                       BatchT& batch) const {
  if (batch_cost_measurement == nullptr ||
      batch_cost_measurement->GetTotalCost() == absl::ZeroDuration()) {
    return;
  }
  // TODO(b/1858529900): Split the cost across tasks: define RequestCost for
  // each inference request, add it as a field of BatchTask, and implement the
  // cost-split algorithms for the cases where padding does and does not share
  // the cost.
}

}  // namespace serving
}  // namespace tensorflow