/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.h"

#include <atomic>
#include <utility>

#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/data/stats_utils.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/inplace_ops_functor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env_time.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"

namespace tensorflow {
namespace data {
namespace experimental {

/* static */ constexpr const char* const MapAndBatchDatasetOp::kDatasetType;
/* static */ constexpr const char* const MapAndBatchDatasetOp::kInputDataset;
/* static */ constexpr const char* const MapAndBatchDatasetOp::kOtherArguments;
/* static */ constexpr const char* const MapAndBatchDatasetOp::kBatchSize;
/* static */ constexpr const char* const
    MapAndBatchDatasetOp::kNumParallelCalls;
/* static */ constexpr const char* const MapAndBatchDatasetOp::kDropRemainder;
/* static */ constexpr const char* const MapAndBatchDatasetOp::kFunc;
/* static */ constexpr const char* const MapAndBatchDatasetOp::kTarguments;
/* static */ constexpr const char* const MapAndBatchDatasetOp::kOutputTypes;
/* static */ constexpr const char* const MapAndBatchDatasetOp::kOutputShapes;
/* static */ constexpr const char* const
    MapAndBatchDatasetOp::kPreserveCardinality;

// Maximum number of batch results to buffer.

namespace {

constexpr int64_t kMaxBatchResults = 16;
constexpr char kParallelism[] = "parallelism";
constexpr char kCallCounter[] = "call_counter";
constexpr char kBatchResultsSize[] = "batch_results_size";
constexpr char kTFDataMapAndBatch[] = "tf_data_map_and_batch";
constexpr char kBatchResults[] = "batch_results";
constexpr char kEndOfInput[] = "end_of_input";
constexpr char kNumCalls[] = "num_calls";
constexpr char kNumElements[] = "num_elements";
constexpr char kOutputAllocated[] = "output_allocated";
constexpr char kStatus[] = "status";

// Computes ceil(x / y).
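// For example, CeilDiv(10, 4) == 3 and CeilDiv(8, 4) == 2.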
inline int64_t CeilDiv(int64_t x, int64_t y) { return (x + y - 1) / y; }

}  // namespace

class MapAndBatchDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, const DatasetBase* input, int64_t batch_size,
          int64_t num_parallel_calls, bool drop_remainder,
          const DataTypeVector& output_types,
          const std::vector<PartialTensorShape>& output_shapes,
          std::unique_ptr<CapturedFunction> captured_func,
          bool preserve_cardinality)
      : DatasetBase(DatasetContext(ctx)),
        input_(input),
        batch_size_(batch_size),
        num_parallel_calls_(num_parallel_calls),
        drop_remainder_(drop_remainder),
        output_types_(output_types),
        output_shapes_(output_shapes),
        captured_func_(std::move(captured_func)),
        preserve_cardinality_(preserve_cardinality),
        traceme_metadata_(
            {{"autotune",
              num_parallel_calls == model::kAutotune ? "true" : "false"},
             {"batch_size",
              strings::Printf("%lld", static_cast<long long>(batch_size))},
             {"drop_remainder", drop_remainder ? "true" : "false"}}) {
    input_->Ref();
  }

  ~Dataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return std::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix)});
  }

  const DataTypeVector& output_dtypes() const override { return output_types_; }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return output_shapes_;
  }

  string DebugString() const override {
    return name_utils::DatasetDebugString(kDatasetType);
  }

  int64_t CardinalityInternal() const override {
    if (!preserve_cardinality_) {
      return kUnknownCardinality;
    }
    int64_t n = input_->Cardinality();
    if (n == kInfiniteCardinality || n == kUnknownCardinality) {
      return n;
    }
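    // For example, n = 10 and batch_size_ = 3 yields 4 (three full batches
    // plus a partial one), or 3 when drop_remainder_ is true.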
    return n / batch_size_ + (n % batch_size_ == 0 || drop_remainder_ ? 0 : 1);
  }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return OkStatus();
  }

  Status CheckExternalState() const override {
    TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
    return input_->CheckExternalState();
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
    Node* batch_size_node;
    TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size_node));
    Node* num_parallel_calls_node;
    TF_RETURN_IF_ERROR(
        b->AddScalar(num_parallel_calls_, &num_parallel_calls_node));
    Node* drop_remainder_node;
    TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder_node));
    std::vector<Node*> other_arguments;
    DataTypeVector other_arguments_types;
    TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments,
                                                  &other_arguments_types));
    AttrValue f;
    b->BuildAttrValue(captured_func_->func(), &f);
    AttrValue other_arguments_types_attr;
    b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
    AttrValue preserve_cardinality_attr;
    b->BuildAttrValue(preserve_cardinality_, &preserve_cardinality_attr);

    TF_RETURN_IF_ERROR(b->AddDataset(
        this,
        {std::make_pair(0, input_graph_node),
         std::make_pair(2, batch_size_node),
         std::make_pair(3, num_parallel_calls_node),
         std::make_pair(4, drop_remainder_node)},  // Single tensor inputs.
        {std::make_pair(1, other_arguments)},      // Tensor list inputs.
        {std::make_pair(kFunc, f),
         std::make_pair(kTarguments, other_arguments_types_attr),
         std::make_pair(kPreserveCardinality,
                        preserve_cardinality_attr)},  // Attrs
        output));
    return OkStatus();
  }

 private:
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params),
          mu_(std::make_shared<mutex>()),
          cond_var_(std::make_shared<condition_variable>()),
          num_parallel_calls_(std::make_shared<model::SharedState>(
              params.dataset->num_parallel_calls_, mu_, cond_var_)) {
      // To mitigate the effect of stragglers (i.e. map invocations that take
      // much longer than others), we allow the kernel to pre-compute batches
      // ahead of time and store them in an internal buffer. The maximum number
      // of batches to buffer is a trade-off between performance and memory and
      // we derive it from the degree of parallelism and the batch size.
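      // For example, with num_parallel_calls = 32 and batch_size = 10, the
      // buffer holds at most CeilDiv(32, 10) = 4 batch results, subject to
      // the kMaxBatchResults cap of 16.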
      //
      // TODO(b/178059273): If we handle RAM budget correctly, the upper bound
      // should be removed.
      max_batch_results_ = std::min(
          kMaxBatchResults,
          CeilDiv(params.dataset->num_parallel_calls_ == model::kAutotune
                      ? GetCpuBudget()  // maximum parallelism
                      : params.dataset->num_parallel_calls_,
                  params.dataset->batch_size_));
    }

    ~Iterator() override {
      CancelThreads(/*wait=*/true);
      if (deregister_fn_) deregister_fn_();
    }

    Status Initialize(IteratorContext* ctx) override {
      mutex_lock l(*mu_);
      interleave_depth_ = ctx->interleave_depth();

      if (num_parallel_calls_->value == model::kAutotune) {
        num_parallel_calls_->value = GetAutotuneDefaultParallelism(ctx);
      }
      cancellation_manager_ = std::make_unique<CancellationManager>();
      TF_RETURN_IF_ERROR(RegisterCancellationCallback(
          ctx->cancellation_manager(),
          [this]() { CancelThreads(/*wait=*/false); }, &deregister_fn_));
      IteratorContext::Params params(ctx);
      params.cancellation_manager = cancellation_manager_.get();
      TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
          IteratorContext(params), this, prefix(), &input_impl_));
      return dataset()->captured_func_->Instantiate(
          ctx, &instantiated_captured_func_);
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      std::shared_ptr<BatchResult> result;
      {
        mutex_lock l(*mu_);
        EnsureRunnerThreadStarted(ctx);
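        // Wait until the oldest buffered batch is complete, i.e. none of its
        // map calls are still in flight.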
        while (!cancelled_ && (batch_results_.empty() ||
                               batch_results_.front()->num_calls > 0)) {
          ++waiting_;
          RecordStop(ctx);
          cond_var_->wait(l);
          RecordStart(ctx);
          --waiting_;
        }
        if (cancelled_) {
          return errors::Cancelled("Iterator was cancelled");
        }
        std::swap(result, batch_results_.front());
        batch_results_.pop_front();
        cond_var_->notify_all();
      }
      profiler::TraceMe traceme([&] {
        return profiler::TraceMeEncode("MapAndBatchConsume",
                                       {{"element_id", result->uid}});
      });
      // Deallocate tensors allocated for the output.
      auto cleanup = gtl::MakeCleanup([result] { result->output.clear(); });
      mutex_lock l(result->mu);
      if (result->output_allocated) {
        RecordBufferDequeue(ctx, result->output);
      }
      TF_RETURN_IF_ERROR(
          ProcessBatch(dataset()->batch_size_, result->num_elements,
                       dataset()->drop_remainder_, result->status, ctx,
                       out_tensors, end_of_sequence, &result->output));
      return OkStatus();
    }

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeAsyncKnownRatioNode(
          std::move(args), dataset()->batch_size_,
          {model::MakeParameter(kParallelism, num_parallel_calls_, /*min=*/1,
                                /*max=*/ctx->runner_threadpool_size())});
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
          dataset()->captured_func_->CheckExternalState()));
      mutex_lock l(*mu_);
      // Wait for all in-flight calls to complete.
      while (num_calls_ > 0) {
        cond_var_->wait(l);
      }
      DCHECK_EQ(num_calls_, 0);
      TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(full_name(kCallCounter), call_counter_));
      TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kBatchResultsSize),
                                             batch_results_.size()));
      for (size_t i = 0; i < batch_results_.size(); ++i) {
        TF_RETURN_IF_ERROR(WriteBatchResult(writer, i));
      }
      return OkStatus();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(*mu_);
      TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(full_name(kCallCounter), &call_counter_));
      int64_t batch_results_size;
      TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kBatchResultsSize),
                                            &batch_results_size));
      DCHECK(batch_results_.empty());
      for (int i = 0; i < batch_results_size; ++i) {
        TF_RETURN_IF_ERROR(ReadBatchResult(ctx, reader, i));
      }
      return OkStatus();
    }

    TraceMeMetadata GetTraceMeMetadata() const override {
      int64_t parallelism = -1;
      int64_t max_batch_results = -1;
      // NOTE: We only set the parallelism value if the lock can be acquired
      // right away to avoid introducing tracing overhead.
      if (mu_->try_lock()) {
        parallelism = num_parallel_calls_->value;
        max_batch_results = max_batch_results_;
        mu_->unlock();
      }
      auto result = dataset()->traceme_metadata_;
      result.push_back(std::make_pair(
          "max_batch_results",
          strings::Printf("%lld", static_cast<long long>(max_batch_results))));
      result.push_back(std::make_pair(
          "parallelism",
          parallelism == -1
              ? kTraceInfoUnavailable
              : strings::Printf("%lld", static_cast<long long>(parallelism))));
      result.push_back(std::make_pair(
          "interleave_depth",
          strings::Printf("%lld", static_cast<long long>(interleave_depth_))));
      return result;
    }

   private:
    // BatchResult encapsulates the output batch, as well as ancillary
    // metadata required to execute the fused map-and-batch operation.
    struct BatchResult {
      explicit BatchResult(int64_t batch_size)
          : end_of_input(false),
            num_elements(0),
            output_allocated(false),
            status(OkStatus()),
            status_offset(-1),
            num_calls(batch_size),
            uid(tensorflow::EnvTime::NowNanos()) {}

      // UpdateStatus updates the batch's aggregate Status.
      //
      // In order to ensure that exactly the first non-OK status is returned
      // (required to make the behavior observably identical to a sequential
      // execution of map followed by batch), we must also keep track of the
      // offset into the batch that produced `s`.
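      // For example, if the calls at offsets 1 and 3 both fail, the error
      // from offset 1 is the one surfaced, matching what a sequential
      // execution would have reported first.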
      void UpdateStatus(const Status& s, int64_t offset) {
        if (TF_PREDICT_FALSE(!s.ok())) {
          mutex_lock l(mu);
          if (status.ok() || offset < status_offset) {
            status = s;
            status_offset = offset;
          }
        }
      }

      mutex mu;
      bool end_of_input TF_GUARDED_BY(mu);
      int64_t num_elements TF_GUARDED_BY(mu);
      std::vector<Tensor> output;
      bool output_allocated TF_GUARDED_BY(mu);
      Status status TF_GUARDED_BY(mu);
      int64_t status_offset TF_GUARDED_BY(mu);
      // Counts the number of outstanding calls for this batch.
      int64_t num_calls TF_GUARDED_BY(&Iterator::mu_);
      const uint64 uid = -1;
    };

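    // Marks one call as finished (both globally and for `result`'s batch),
    // reports thread utilization, and wakes up any waiters.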
    void CallCompleted(const std::shared_ptr<IteratorContext>& ctx,
                       const std::shared_ptr<BatchResult>& result)
        TF_LOCKS_EXCLUDED(*mu_) {
      mutex_lock l(*mu_);
      num_calls_--;
      result->num_calls--;
      const auto& stats_aggregator = ctx->stats_aggregator();
      if (stats_aggregator) {
        stats_aggregator->AddScalar(
            stats_utils::ThreadUtilizationScalarName(dataset()->node_name()),
            static_cast<float>(num_calls_) /
                static_cast<float>(num_parallel_calls_->value),
            num_elements());
      }
      cond_var_->notify_all();
    }

    void CallFunction(std::shared_ptr<IteratorContext> ctx,
                      const std::shared_ptr<BatchResult>& result,
                      int64_t offset) TF_LOCKS_EXCLUDED(*mu_) {
      profiler::TraceMe traceme([&] {
        return profiler::TraceMeEncode("MapAndBatchProduce",
                                       {{"element_id", result->uid}});
      });
      // Get the next input element.
      std::vector<Tensor> input_element;
      bool end_of_input = false;
      Status status =
          input_impl_->GetNext(ctx.get(), &input_element, &end_of_input);
      bool return_early;
      {
        mutex_lock l(result->mu);
        result->end_of_input = result->end_of_input || end_of_input;
        result->status.Update(status);
        return_early = result->end_of_input || !result->status.ok();
      }
      if (return_early) {
        CallCompleted(ctx, result);
        return;
      }

      std::shared_ptr<std::vector<Tensor>> return_values =
          std::make_shared<std::vector<Tensor>>();
      auto done = [this, ctx, result, return_values, offset](Status status) {
        if (dataset()->preserve_cardinality_ && errors::IsOutOfRange(status)) {
          // To guarantee that the transformation preserves the cardinality of
          // the dataset, we convert `OutOfRange` to `InvalidArgument` as the
          // former may be interpreted by a caller as the end of sequence.
          status = errors::InvalidArgument(
              "Function invocation produced OutOfRangeError: ",
              status.error_message());
        }
        result->UpdateStatus(status, offset);
        if (status.ok()) {
          Status allocate_status =
              EnsureOutputAllocated(ctx, result, return_values);
          if (!allocate_status.ok()) {
            result->UpdateStatus(allocate_status, offset);
          } else {
            for (size_t i = 0; i < return_values->size(); ++i) {
              Tensor& tensor = return_values->at(i);
              Tensor* batch = &(result->output)[i];
              if (tensor.NumElements() !=
                  (batch->NumElements() / batch->dim_size(0))) {
                TensorShape batch_shape = batch->shape();
                batch_shape.RemoveDim(0);
                result->UpdateStatus(
                    errors::InvalidArgument(
                        "Cannot add tensor to the batch: number of elements "
                        "does not match. Shapes are: [tensor]: ",
                        tensor.shape().DebugString(),
                        ", [batch]: ", batch_shape.DebugString()),
                    offset);
                break;
              }
              // TODO(mrry): Add a version of DoParallelConcat that allows us
              // to move `tensor` where possible, to speed up string tensor
              // batching.
              Status copy_status = batch_util::CopyElementToSlice(
                  std::move(tensor), batch, offset);
              if (!copy_status.ok()) {
                result->UpdateStatus(copy_status, offset);
                break;
              }
            }
          }
          {
            mutex_lock l(result->mu);
            result->num_elements++;
          }
        }
        CallCompleted(ctx, result);
      };

      // Apply the map function on `input_element`, storing the result in
      // `return_values`, and invoking `done` when finished.
      instantiated_captured_func_->RunAsync(ctx.get(), std::move(input_element),
                                            return_values.get(),
                                            std::move(done), model_node());
    }

    void CancelThreads(bool wait) TF_LOCKS_EXCLUDED(mu_) {
      cancellation_manager_->StartCancel();
      mutex_lock l(*mu_);
      cancelled_ = true;
      cond_var_->notify_all();
      // Wait for all in-flight calls to complete.
      while (wait && num_calls_ > 0) {
        cond_var_->wait(l);
      }
    }

    void EnsureRunnerThreadStarted(IteratorContext* ctx)
        TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      if (!runner_thread_) {
        auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
        runner_thread_ = ctx->StartThread(
            kTFDataMapAndBatch,
            std::bind(&Iterator::RunnerThread, this, ctx_copy));
      }
    }

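    // Allocates the batched output tensors for `result` (one tensor per
    // component, each with a leading `batch_size_` dimension) the first time
    // one of the batch's elements becomes available.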
    Status EnsureOutputAllocated(
        const std::shared_ptr<IteratorContext>& ctx,
        const std::shared_ptr<BatchResult>& result,
        const std::shared_ptr<std::vector<Tensor>>& return_values) {
      mutex_lock l(result->mu);
      if (result->output_allocated) {
        return OkStatus();
      }
      const size_t num_components = return_values->size();
      result->output.reserve(num_components);
      for (size_t i = 0; i < num_components; ++i) {
        TensorShape component_shape({dataset()->batch_size_});
        component_shape.AppendShape(return_values->at(i).shape());
        AllocatorAttributes attr;
        attr.set_gpu_compatible(true);
        result->output.emplace_back(ctx->allocator(attr),
                                    return_values->at(i).dtype(),
                                    component_shape);
        if (!result->output.back().IsInitialized()) {
          return errors::ResourceExhausted(
              "Failed to allocate memory for the batch of component ", i);
        }
      }
      RecordBufferEnqueue(ctx.get(), result->output);
      result->output_allocated = true;
      return OkStatus();
    }

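    // Background loop that schedules map invocations. It keeps up to
    // `num_parallel_calls_->value` calls in flight and groups consecutive
    // input elements into batches of `batch_size_` results.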
    void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
        TF_LOCKS_EXCLUDED(*mu_) {
      std::vector<std::pair<std::shared_ptr<BatchResult>, int64_t>> new_calls;
      RecordStart(ctx.get());
      auto stop_cleanup =
          gtl::MakeCleanup([this, &ctx]() { RecordStop(ctx.get()); });
      {
        tf_shared_lock l(*mu_);  // mu_ == num_parallel_calls_->mu
        new_calls.reserve(num_parallel_calls_->value);
      }
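      // Whether the runner should pause scheduling new calls: either the
      // parallelism limit has been reached, or the batch-result buffer is
      // full and the next call would have to start a new batch.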
      auto busy = [this]() TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
        int64_t num_parallel_calls = num_parallel_calls_->value;
        return num_calls_ >= num_parallel_calls ||
               (batch_results_.size() > max_batch_results_ ||
                (batch_results_.size() == max_batch_results_ &&
                 call_counter_ % dataset()->batch_size_ == 0));
      };
      while (true) {
        {
          mutex_lock l(*mu_);
          while (!cancelled_ && busy()) {
            if (waiting_ > 0 && num_calls_ < num_parallel_calls_->value &&
                max_batch_results_ < kMaxBatchResults) {
              // If there is a caller waiting for a batch and the number of
              // outstanding calls is not maxed out, it means we are out of
              // `batch_results_` slots. Instead of waiting for a slot to open
              // up, we create a new one to utilize CPU efficiently.
              max_batch_results_++;
              continue;
            }
            RecordStop(ctx.get());
            cond_var_->wait(l);
            RecordStart(ctx.get());
          }

          if (cancelled_) {
            return;
          }

          while (!busy()) {
            if (call_counter_ % dataset()->batch_size_ == 0) {
              batch_results_.push_back(
                  std::make_shared<BatchResult>(dataset()->batch_size_));
            }
            int64_t offset = call_counter_++ % dataset()->batch_size_;
            new_calls.emplace_back(batch_results_.back(), offset);
            num_calls_++;
          }
        }
        const auto& stats_aggregator = ctx->stats_aggregator();
        if (stats_aggregator) {
          mutex_lock l(*mu_);
          stats_aggregator->AddScalar(
              stats_utils::ThreadUtilizationScalarName(dataset()->node_name()),
              static_cast<float>(num_calls_) /
                  static_cast<float>(num_parallel_calls_->value),
              num_elements());
        }
        for (const auto& call : new_calls) {
          CallFunction(ctx, call.first, call.second);
        }
        new_calls.clear();
      }
    }

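    // Restores the `index`-th buffered batch result from checkpoint state
    // written by WriteBatchResult().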
    Status ReadBatchResult(IteratorContext* ctx, IteratorStateReader* reader,
                           size_t index) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      batch_results_.push_back(
          std::make_shared<BatchResult>(dataset()->batch_size_));
      std::shared_ptr<BatchResult> result = batch_results_.back();
      string batch_prefix = strings::StrCat(kBatchResults, "_", index);
      mutex_lock l(result->mu);
      result->end_of_input = reader->Contains(
          full_name(strings::StrCat(batch_prefix, "_", kEndOfInput)));
      TF_RETURN_IF_ERROR(reader->ReadScalar(
          full_name(strings::StrCat(batch_prefix, "_", kNumCalls)),
          &result->num_calls));
      TF_RETURN_IF_ERROR(reader->ReadScalar(
          full_name(strings::StrCat(batch_prefix, "_", kNumElements)),
          &result->num_elements));
      result->output_allocated = reader->Contains(
          full_name(strings::StrCat(batch_prefix, "_", kOutputAllocated)));

      TF_RETURN_IF_ERROR(ReadBatch(ctx, reader, dataset()->batch_size_,
                                   prefix(), batch_prefix, &result->output));
      TF_RETURN_IF_ERROR(ReadStatus(prefix(),
                                    strings::StrCat(batch_prefix, "_", kStatus),
                                    reader, &result->status));
      if (result->output_allocated) {
        RecordBufferEnqueue(ctx, result->output);
      }
      return OkStatus();
    }

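    // Saves the `index`-th buffered batch result so that ReadBatchResult()
    // can reconstruct it when the iterator is restored.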
    Status WriteBatchResult(IteratorStateWriter* writer, size_t index)
        TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      std::shared_ptr<BatchResult> result = batch_results_[index];
      string batch_prefix = strings::StrCat(kBatchResults, "_", index);
      mutex_lock l(result->mu);
      if (result->end_of_input) {
        TF_RETURN_IF_ERROR(writer->WriteScalar(
            full_name(strings::StrCat(batch_prefix, "_", kEndOfInput)), ""));
      }
      TF_RETURN_IF_ERROR(writer->WriteScalar(
          full_name(strings::StrCat(batch_prefix, "_", kNumCalls)),
          result->num_calls));
      TF_RETURN_IF_ERROR(writer->WriteScalar(
          full_name(strings::StrCat(batch_prefix, "_", kNumElements)),
          result->num_elements));
      if (result->output_allocated) {
        TF_RETURN_IF_ERROR(writer->WriteScalar(
            full_name(strings::StrCat(batch_prefix, "_", kOutputAllocated)),
            ""));
      }

      TF_RETURN_IF_ERROR(WriteBatch(dataset()->batch_size_,
                                    result->num_elements, prefix(),
                                    batch_prefix, writer, &result->output));
      TF_RETURN_IF_ERROR(
          WriteStatus(prefix(), strings::StrCat(batch_prefix, "_", kStatus),
                      result->status, writer));
      return OkStatus();
    }

    // Used for coordination between the main thread, the runner thread, and
    // the callback threads.
    const std::shared_ptr<mutex> mu_;
    // Used for coordination between the main thread, the runner thread, and
    // the callback threads. In particular, the runner thread should only
    // schedule new calls when the number of in-flight calls is less than
    // `num_parallel_calls_->value` and there are slots available in the
    // `batch_results_` buffer.
    const std::shared_ptr<condition_variable> cond_var_;
    // Identifies the maximum number of parallel calls.
    const std::shared_ptr<model::SharedState> num_parallel_calls_;

    // Controls cancellation of `input_impl_`. Must be ordered before
    // `input_impl_` so that `input_impl_` is destroyed first.
    std::unique_ptr<CancellationManager> cancellation_manager_;
    // Counts the number of outstanding calls across all in-flight batches.
    int64_t num_calls_ TF_GUARDED_BY(*mu_) = 0;
    // Counts the total number of calls.
    int64_t call_counter_ TF_GUARDED_BY(*mu_) = 0;
    std::unique_ptr<IteratorBase> input_impl_;
    // Buffer for storing the (intermediate) batch results. Whenever an
    // output-allocated batch result is added to or removed from
    // `batch_results_`, call `RecordBufferEnqueue` or `RecordBufferDequeue`
    // respectively.
    std::deque<std::shared_ptr<BatchResult>> batch_results_ TF_GUARDED_BY(*mu_);
    // Determines whether the transformation has been cancelled.
    bool cancelled_ TF_GUARDED_BY(*mu_) = false;
    // Identifies the number of callers currently waiting for a batch result.
    int64_t waiting_ TF_GUARDED_BY(*mu_) = 0;
    // Identifies the maximum number of batch results to store.
    int64_t max_batch_results_ TF_GUARDED_BY(*mu_);
    std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;

    // Method for deregistering the cancellation callback.
    std::function<void()> deregister_fn_;

    // Records the number of ParallelInterleave operations in the path from the
    // root node to this node (not including this node) in the input pipeline
    // tree. We record the interleave depth so that it can be included in the
    // trace metadata.
    int64 interleave_depth_ = -1;
    // Background thread used for coordinating input processing. The thread
    // should be destroyed before the variables it accesses are destroyed.
    std::unique_ptr<Thread> runner_thread_ TF_GUARDED_BY(*mu_);
  };

  const DatasetBase* const input_;
  const int64_t batch_size_;
  const int64_t num_parallel_calls_;
  const bool drop_remainder_;
  const DataTypeVector output_types_;
  const std::vector<PartialTensorShape> output_shapes_;
  const std::unique_ptr<CapturedFunction> captured_func_;
  const bool preserve_cardinality_;
  const TraceMeMetadata traceme_metadata_;
};

MapAndBatchDatasetOp::MapAndBatchDatasetOp(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx) {
  OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kFunc, /*params=*/{},
                                               &func_metadata_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  OP_REQUIRES_OK(ctx,
                 ctx->GetAttr(kPreserveCardinality, &preserve_cardinality_));
}

void MapAndBatchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                       DatasetBase** output) {
  int64_t batch_size = 0;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kBatchSize, &batch_size));
  OP_REQUIRES(ctx, batch_size > 0,
              errors::InvalidArgument("batch_size must be greater than zero."));

  int64_t num_parallel_calls = 0;
  OP_REQUIRES_OK(
      ctx, ParseScalarArgument(ctx, kNumParallelCalls, &num_parallel_calls));
  OP_REQUIRES(
      ctx, num_parallel_calls > 0 || num_parallel_calls == model::kAutotune,
      errors::InvalidArgument("num_parallel_calls must be greater than zero."));

  bool drop_remainder;
  OP_REQUIRES_OK(ctx,
                 ParseScalarArgument(ctx, kDropRemainder, &drop_remainder));

  std::unique_ptr<CapturedFunction> captured_func;
  OP_REQUIRES_OK(ctx,
                 CapturedFunction::Create(ctx, func_metadata_, kOtherArguments,
                                          &captured_func));

  if (num_parallel_calls == model::kAutotune) {
    metrics::RecordTFDataAutotune(kDatasetType);
  }

  *output = new Dataset(ctx, input, batch_size, num_parallel_calls,
                        drop_remainder, output_types_, output_shapes_,
                        std::move(captured_func), preserve_cardinality_);
}

namespace {
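// Note: these kernels back the Python-level `tf.data.experimental.map_and_batch`
// transformation (now deprecated in favor of separate `map` and `batch` calls,
// which tf.data static optimizations can fuse automatically). An illustrative
// usage would be:
//
//   dataset = dataset.apply(tf.data.experimental.map_and_batch(
//       map_func, batch_size=32, num_parallel_calls=tf.data.AUTOTUNE))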
REGISTER_KERNEL_BUILDER(Name("MapAndBatchDataset").Device(DEVICE_CPU),
                        MapAndBatchDatasetOp);
REGISTER_KERNEL_BUILDER(
    Name("ExperimentalMapAndBatchDataset").Device(DEVICE_CPU),
    MapAndBatchDatasetOp);

REGISTER_INPUT_COLOCATION_EXEMPTION("MapAndBatchDataset");
REGISTER_INPUT_COLOCATION_EXEMPTION("ExperimentalMapAndBatchDataset");
}  // namespace
}  // namespace experimental
}  // namespace data
}  // namespace tensorflow