/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/prefetch_dataset_op.h"

#include <deque>

#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/kernels/data/stats_utils.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"

namespace tensorflow {
namespace data {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.
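//
// For orientation: this kernel backs the `tf.data` prefetch transformation,
// which is typically used from Python roughly as follows (illustrative
// sketch, not taken from this file):
//
//   dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
//
// Passing tf.data.AUTOTUNE (model::kAutotune here) requests that the buffer
// size be tuned automatically; see MakeDataset at the bottom of this file.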

/* static */ constexpr const char* const PrefetchDatasetOp::kDatasetType;
/* static */ constexpr const char* const PrefetchDatasetOp::kInputDataset;
/* static */ constexpr const char* const PrefetchDatasetOp::kBufferSize;
/* static */ constexpr const char* const PrefetchDatasetOp::kOutputTypes;
/* static */ constexpr const char* const PrefetchDatasetOp::kOutputShapes;
/* static */ constexpr const char* const PrefetchDatasetOp::kSlackPeriod;
/* static */ constexpr const char* const PrefetchDatasetOp::kLegacyAutotune;
/* static */ constexpr const char* const PrefetchDatasetOp::kBufferSizeMin;

namespace {

// Determines the fraction of slack time by which to delay prefetching of data.
constexpr double kSleepFactor = 0.2;
constexpr char kBuffer[] = "buffer";
constexpr char kStatus[] = "status";
constexpr char kSizeSuffix[] = ".size";
constexpr char kCodeSuffix[] = ".code";
constexpr char kErrorMessageSuffix[] = ".error_message";
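// These pieces are concatenated into the checkpoint key names used by
// SaveInternal/RestoreInternal below (e.g. "buffer.size" or "status.code").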

}  // namespace

class PrefetchDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size,
          int64 slack_period, bool legacy_autotune, int64 buffer_size_min)
      : DatasetBase(DatasetContext(ctx)),
        input_(input),
        buffer_size_(buffer_size),
        slack_period_(slack_period),
        legacy_autotune_(legacy_autotune),
        buffer_size_min_(buffer_size_min) {
    input_->Ref();
  }

  ~Dataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return absl::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix)});
  }

  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return input_->output_shapes();
  }

  string DebugString() const override {
    return name_utils::DatasetDebugString(kDatasetType);
  }

  int64 Cardinality() const override { return input_->Cardinality(); }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return Status::OK();
  }

  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
    Node* buffer_size = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
    AttrValue slack_period_attr;
    b->BuildAttrValue(slack_period_, &slack_period_attr);
    AttrValue legacy_autotune_attr;
    b->BuildAttrValue(legacy_autotune_, &legacy_autotune_attr);
    AttrValue buffer_size_min_attr;
    b->BuildAttrValue(buffer_size_min_, &buffer_size_min_attr);

    TF_RETURN_IF_ERROR(
        b->AddDataset(this, {input_graph_node, buffer_size},
                      {std::make_pair(kSlackPeriod, slack_period_attr),
                       std::make_pair(kLegacyAutotune, legacy_autotune_attr),
                       std::make_pair(kBufferSizeMin, buffer_size_min_attr)},
                      output));
    return Status::OK();
  }

 private:
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params),
          mu_(std::make_shared<mutex>()),
          cond_var_(std::make_shared<condition_variable>()),
          buffer_size_min_(params.dataset->buffer_size_min_),
          auto_tuner_(params.dataset->buffer_size_, buffer_size_min_),
          legacy_autotune_(params.dataset->legacy_autotune_),
          // If `legacy_autotune_` is enabled, initialize `buffer_size_` to 0
          // so that the node created for this iterator is not picked up as a
          // tunable parameter by the autotuning optimization.
          buffer_size_(std::make_shared<model::SharedState>(
              legacy_autotune_ ? 0 : params.dataset->buffer_size_, mu_,
              cond_var_)) {
      slack_us_ = 0;
    }

    ~Iterator() override {
      CancelThreads();
      if (deregister_fn_) deregister_fn_();
    }

    Status Initialize(IteratorContext* ctx) override {
      mutex_lock l(*mu_);
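      // If autotuning was requested, start the tunable buffer size at the
      // configured minimum; it is subsequently adjusted at runtime, either by
      // the autotuning model or, in legacy mode, by the PrefetchAutotuner.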
      if (buffer_size_->value == model::kAutotune) {
        buffer_size_->value = buffer_size_min_;
      }
      TF_RETURN_IF_ERROR(RegisterCancellationCallback(
          ctx->cancellation_manager(), [this]() { CancelThreads(); },
          &deregister_fn_));
      return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      const auto& stats_aggregator = ctx->stats_aggregator();
      {
        mutex_lock l(*mu_);
        TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx));
        // Wait until the next element in the buffer has been
        // produced, or we are shutting down.
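        // In legacy autotune mode the buffer limit is controlled by the
        // PrefetchAutotuner; otherwise it is the (possibly model-tuned)
        // `buffer_size_` parameter.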
        if (legacy_autotune_) {
          while (!cancelled_ && buffer_.empty() && !prefetch_thread_finished_ &&
                 auto_tuner_.buffer_limit() != 0) {
            auto_tuner_.RecordEmpty();
            buffer_size_->value = auto_tuner_.buffer_limit();
            RecordStop(ctx);
            cond_var_->wait(l);
            RecordStart(ctx);
          }
        } else {
          while (!cancelled_ && buffer_.empty() && !prefetch_thread_finished_ &&
                 buffer_size_->value != 0) {
            RecordStop(ctx);
            cond_var_->wait(l);
            RecordStart(ctx);
          }
        }

        if (cancelled_) {
          return errors::Cancelled("Iterator was cancelled");
        }

        if (!buffer_.empty()) {
          return Consume(ctx, out_tensors, end_of_sequence);
        }

        if (prefetch_thread_finished_) {
          *end_of_sequence = true;
          return Status::OK();
        }

        DCHECK_EQ(buffer_limit(), 0);
      }

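      // The buffer limit is 0 (prefetching is effectively disabled) and the
      // prefetch thread is still running, so read the next element
      // synchronously from the input iterator instead.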
      mutex_lock input_l(input_mu_);
      {
        mutex_lock l(*mu_);
        if (stats_aggregator) {
          stats_aggregator->AddScalar(
              stats_utils::BufferSizeScalarName(dataset()->node_name()),
              static_cast<float>(buffer_.size()), num_elements());
          stats_aggregator->AddScalar(
              stats_utils::BufferCapacityScalarName(dataset()->node_name()),
              static_cast<float>(buffer_limit()), num_elements());
        }
        // Release mu_
      }
      return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
    }

   protected:
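    // Exposes `buffer_size` to the tf.data autotuning model as a tunable
    // parameter in [buffer_size_min_, int64 max]. The prefetch node has a
    // known ratio of 1: it produces one output element per input element.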
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeAsyncKnownRatioNode(
          std::move(args),
          /*ratio=*/1,
          {model::MakeParameter(kBufferSize, buffer_size_,
                                /*min=*/buffer_size_min_,
                                /*max=*/std::numeric_limits<int64>::max())});
    }

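    // Roughly, the checkpoint written below has the following layout
    // (illustrative; the exact key strings are built from the constants
    // defined at the top of this file):
    //   <prefix>      / "buffer_size"          -> number of buffered elements
    //   <prefix>::<i> / "status.code"          -> status code of element i
    //   <prefix>::<i> / "status.error_message" -> only present if not OK
    //   <prefix>::<i> / "buffer.size"          -> number of tensors in element i
    //   <prefix>::<i> / "buffer[<j>]"          -> j-th tensor of element i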
    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      // Acquire both locks to ensure that the prefetch thread and
      // all GetNext threads are blocked.
      mutex_lock input_l(input_mu_);
      mutex_lock l(*mu_);
      TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(prefix(), kBufferSize, buffer_.size()));
      for (size_t i = 0; i < buffer_.size(); i++) {
        auto& buffer_element = buffer_[i];
        TF_RETURN_IF_ERROR(WriteStatus(writer, i, buffer_element.status));
        if (buffer_element.status.ok()) {
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              absl::StrCat(prefix(), "::", i),
              absl::StrCat(kBuffer, kSizeSuffix), buffer_element.value.size()));
          for (size_t j = 0; j < buffer_element.value.size(); j++) {
            TF_RETURN_IF_ERROR(writer->WriteTensor(
                absl::StrCat(prefix(), "::", i),
                absl::StrCat(kBuffer, "[", j, "]"), buffer_element.value[j]));
          }
        }
      }
      return Status::OK();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock input_l(input_mu_);
      mutex_lock l(*mu_);
      DCHECK(buffer_.empty());
      TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
      size_t buffer_size;
      {
        int64 temp;
        TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kBufferSize, &temp));
        buffer_size = static_cast<size_t>(temp);
      }
      for (size_t i = 0; i < buffer_size; i++) {
        buffer_.emplace_back();
        auto& buffer_element = buffer_.back();
        TF_RETURN_IF_ERROR(ReadStatus(reader, i, &buffer_element.status));
        if (buffer_element.status.ok()) {
          size_t value_size;
          {
            int64 temp;
            TF_RETURN_IF_ERROR(
                reader->ReadScalar(absl::StrCat(prefix(), "::", i),
                                   absl::StrCat(kBuffer, kSizeSuffix), &temp));
            value_size = static_cast<size_t>(temp);
          }
          buffer_element.value.reserve(value_size);
          for (size_t j = 0; j < value_size; j++) {
            buffer_element.value.emplace_back();
            TF_RETURN_IF_ERROR(
                reader->ReadTensor(absl::StrCat(prefix(), "::", i),
                                   absl::StrCat(kBuffer, "[", j, "]"),
                                   &buffer_element.value.back()));
          }
        }
        RecordBufferEnqueue(ctx, buffer_element.value);
      }
      return Status::OK();
    }

    data::TraceMeMetadata GetTraceMeMetadata() const override {
      int64 limit = -1, size = -1;
      data::TraceMeMetadata result;
      // NOTE: We only report the buffer limit, buffer size, and element
      // shapes if the lock can be acquired right away, to avoid introducing
      // tracing overhead.
      if (mu_->try_lock()) {
        limit = buffer_limit();
        size = buffer_.size();
        if (!buffer_.empty()) {
          // Reserve (rather than pre-size) the vector so that the joined
          // string does not start with empty placeholder entries.
          std::vector<std::string> shapes;
          shapes.reserve(buffer_.front().value.size());
          for (const auto& component : buffer_.front().value) {
            shapes.push_back(component.shape().DebugString());
          }
          result.push_back(std::make_pair("next_element_shapes",
                                          absl::StrJoin(shapes, ",")));
        }
        mu_->unlock();
      }
      result.push_back(std::make_pair(
          "buffer_limit",
          strings::Printf("%lld", static_cast<long long>(limit))));
      result.push_back(std::make_pair(
          "buffer_size",
          strings::Printf("%lld", static_cast<long long>(size))));
      result.push_back(std::make_pair(
          "autotune",
          dataset()->buffer_size_ == model::kAutotune ? "true" : "false"));
      result.push_back(std::make_pair(
          "autotune_mode", legacy_autotune_ ? "legacy" : "performance"));
      if (dataset()->slack_period_ > 0) {
        result.push_back(std::make_pair(
            "slack",
            strings::Printf("%lld", static_cast<long long>(slack_us_.load()))));
      }
      return result;
    }

   private:
    // A buffer element comprises a status and (if that status is
    // OK) a vector of tensors, representing an element of the input dataset.
    struct BufferElement {
      BufferElement() : uid(tensorflow::EnvTime::NowNanos()) {}

      // The producer sets `status` if getting the input element fails.
      Status status;
      // The buffered data element.
      std::vector<Tensor> value;
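      // Time, in microseconds, at which the element was added to the buffer;
      // used to compute the "slack" between production and consumption.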
      int64 created_us;
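      // Identifier used to pair the PrefetchProduce/PrefetchConsume trace
      // events for this element.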
      const uint64 uid;
    };

    int64 buffer_limit() const TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      if (legacy_autotune_) {
        return auto_tuner_.buffer_limit();
      }
      return buffer_size_->value;
    }

    void CancelThreads() TF_LOCKS_EXCLUDED(mu_) {
      mutex_lock l(*mu_);
      cancelled_ = true;
      cond_var_->notify_all();
    }

    Status Consume(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                   bool* end_of_sequence) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      const auto& stats_aggregator = ctx->stats_aggregator();
      if (stats_aggregator) {
        double buffer_limit_ = buffer_limit();
        stats_aggregator->AddToHistogram(
            stats_utils::BufferUtilizationHistogramName(dataset()->node_name()),
            {static_cast<float>(buffer_.size()) /
             static_cast<float>(buffer_limit_)},
            num_elements());
        stats_aggregator->AddScalar(
            stats_utils::BufferSizeScalarName(dataset()->node_name()),
            static_cast<float>(buffer_.size()), num_elements());
        stats_aggregator->AddScalar(
            stats_utils::BufferCapacityScalarName(dataset()->node_name()),
            static_cast<float>(buffer_limit_), num_elements());
      }
      // A new element is available. Forward the status from computing it, and
      // (if we successfully got an element) the output values.
      Status s = buffer_.front().status;
      if (s.ok()) {
        int64 buffer_element_id = buffer_.front().uid;
        profiler::TraceMe traceme(
            [&] {
              return profiler::TraceMeEncode(
                  "PrefetchConsume", {{"element_id", buffer_element_id}});
            },
            profiler::kInfo);
        if (dataset()->slack_period_ > 0 &&
            (num_elements() + 1) % dataset()->slack_period_ == 0) {
          // TODO(rachelim): Consider doing something more sophisticated
          // to decide how long to sleep for; e.g. using a kalman filter.
          int64 slack_us = EnvTime::NowMicros() - buffer_.front().created_us;
          // Every slack_period_-th element, update the most recent slack time,
          // measured by the duration between when the element is prefetched
          // and when it is consumed. We add kSleepFactor * slack_us_ to the
          // measurement because we slept for that duration before prefetching
          // the element.
          slack_us_ = kSleepFactor * slack_us_ + slack_us;
          VLOG(2) << "Setting slack_us_: " << slack_us_;
        }
        *out_tensors = std::move(buffer_.front().value);
        RecordBufferDequeue(ctx, *out_tensors);
      } else {
        // If status not ok, we still record the dequeue event to make sure
        // each enqueue event is paired with a dequeue event even in the
        // presence of errors.
        RecordBufferDequeue(ctx, buffer_.front().value);
      }
      if (legacy_autotune_) {
        auto_tuner_.RecordConsumption(buffer_.size());
        buffer_size_->value = auto_tuner_.buffer_limit();
      }
      buffer_.pop_front();
      *end_of_sequence = false;

      // Wake the prefetch thread, in case it has been waiting for space
      // in the buffer. Also wake up threads from other calls to GetNext.
      //
      // TODO(mrry): Consider using different condition variables for
      // GetNext and Prefetch.
      cond_var_->notify_all();
      return s;
    }

    Status EnsurePrefetchThreadStarted(IteratorContext* ctx)
        TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      if (!prefetch_thread_) {
        std::shared_ptr<IteratorContext> new_ctx =
            std::make_shared<IteratorContext>(*ctx);
        prefetch_thread_ = ctx->StartThread(
            "tf_data_prefetch", [this, new_ctx]() { PrefetchThread(new_ctx); });
      }
      return Status::OK();
    }

    // Prefetches elements of the input, storing results in an internal buffer.
    //
    // The thread shares ownership of the iterator context passed to it.
    void PrefetchThread(const std::shared_ptr<IteratorContext>& ctx) {
      RecordStart(ctx.get());
      auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
      // Keep track of where we are in an iteration "burst".
      int num_produced = 0;
      while (true) {
        // 1. Wait for a slot in the buffer.
        {
          mutex_lock l(*mu_);
          while (!cancelled_ && buffer_.size() >= buffer_limit()) {
            RecordStop(ctx.get());
            cond_var_->wait(l);
            RecordStart(ctx.get());
          }

          if (cancelled_) {
            prefetch_thread_finished_ = true;
            cond_var_->notify_all();
            return;
          }
        }

        if (dataset()->slack_period_ > 0 &&
            num_produced % dataset()->slack_period_ == 0) {
          // For the first element in the "burst", sleep for a bit if there is
          // slack.
          VLOG(2) << "Sleeping for: " << slack_us_ * kSleepFactor;
          ctx->env()->SleepForMicroseconds(slack_us_ * kSleepFactor);
        }

        // 2. Read the next element.
        // Acquire the input mutex since we will be reading an element from the
        // input iterator. Note that we do not wish to release this mutex till
        // we have added the fetched element to the `buffer_` else there will be
        // local state that may be missed by SaveInternal.
        mutex_lock input_l(input_mu_);
        bool end_of_sequence;
        BufferElement buffer_element;
        {
          profiler::TraceMe traceme(
              [&] {
                return profiler::TraceMeEncode(
                    "PrefetchProduce", {{"element_id", buffer_element.uid}});
              },
              profiler::kInfo);
          buffer_element.status = input_impl_->GetNext(
              ctx.get(), &buffer_element.value, &end_of_sequence);
        }
        if (buffer_element.status.ok() && end_of_sequence) {
          mutex_lock l(*mu_);
          prefetch_thread_finished_ = true;
          cond_var_->notify_all();
          return;
        }

        // 3. Signal that the element has been produced.
        {
          mutex_lock l(*mu_);
          RecordBufferEnqueue(ctx.get(), buffer_element.value);
          buffer_element.created_us = EnvTime::NowMicros();
          buffer_.push_back(std::move(buffer_element));
          cond_var_->notify_all();
        }
        ++num_produced;
      }
    }

    Status WriteStatus(IteratorStateWriter* writer, size_t index,
                       const Status& status) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(absl::StrCat(prefix(), "::", index), CodeKey(),
                              static_cast<int64>(status.code())));
      if (!status.ok()) {
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(absl::StrCat(prefix(), "::", index),
                                ErrorMessageKey(), status.error_message()));
      }
      return Status::OK();
    }

    Status ReadStatus(IteratorStateReader* reader, size_t index, Status* status)
        TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      int64 code_int;
      TF_RETURN_IF_ERROR(reader->ReadScalar(absl::StrCat(prefix(), "::", index),
                                            CodeKey(), &code_int));
      error::Code code = static_cast<error::Code>(code_int);

      if (code != error::Code::OK) {
        tstring error_message;
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(absl::StrCat(prefix(), "::", index),
                               ErrorMessageKey(), &error_message));
        *status = Status(code, error_message);
      } else {
        *status = Status::OK();
      }
      return Status::OK();
    }

    string CodeKey() { return absl::StrCat(kStatus, kCodeSuffix); }

    string ErrorMessageKey() {
      return absl::StrCat(kStatus, kErrorMessageSuffix);
    }

    // This mutex is used to ensure exclusivity between multiple threads
    // reading/writing this iterator's local state.
    //
    // NOTE: We should never call GetNext on the input while holding this mutex.
    const std::shared_ptr<mutex> mu_;
    // This mutex is used to ensure exclusivity between multiple threads
    // accessing the input iterator. We keep this separate from `mu_` to allow
    // prefetching to run in parallel with GetNext calls.
    mutex input_mu_ TF_ACQUIRED_BEFORE(*mu_);
    std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(input_mu_);
    const std::shared_ptr<condition_variable> cond_var_;
    const int64 buffer_size_min_;
    PrefetchAutotuner auto_tuner_ TF_GUARDED_BY(*mu_);
    std::deque<BufferElement> buffer_ TF_GUARDED_BY(*mu_);
    std::unique_ptr<Thread> prefetch_thread_ TF_GUARDED_BY(*mu_);
    bool cancelled_ TF_GUARDED_BY(*mu_) = false;
    bool prefetch_thread_finished_ TF_GUARDED_BY(*mu_) = false;
    const bool legacy_autotune_;

    std::atomic<int64> slack_us_;

    // If legacy_autotune_ is false, identifies the maximum size of the buffer.
    const std::shared_ptr<model::SharedState> buffer_size_;

    // Method for deregistering the cancellation callback.
    std::function<void()> deregister_fn_;
  };
  const DatasetBase* const input_;
  const int64 buffer_size_;

  // If non-zero, determines the period between injecting "slack" into the
  // execution.
  const int64 slack_period_;

  // Determines whether legacy autotuning should be used.
  const bool legacy_autotune_ = true;

  // If autotune is enabled, determines the minimum value of the `buffer_size`
  // parameter.
  const int64 buffer_size_min_ = 0;

  TraceMeMetadata traceme_metadata_;
};

PrefetchDatasetOp::PrefetchDatasetOp(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx) {
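  // These attrs are optional so that graphs serialized before a given attr
  // was introduced can still be loaded; when an attr is absent, the
  // corresponding member keeps the default declared in the header (not shown
  // in this file).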
  if (ctx->HasAttr(kSlackPeriod)) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kSlackPeriod, &slack_period_));
  }
  if (ctx->HasAttr(kLegacyAutotune)) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kLegacyAutotune, &legacy_autotune_));
  }
  if (ctx->HasAttr(kBufferSizeMin)) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kBufferSizeMin, &buffer_size_min_));
  }
}

void PrefetchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                    DatasetBase** output) {
  int64 buffer_size = 0;
  OP_REQUIRES_OK(ctx,
                 ParseScalarArgument<int64>(ctx, kBufferSize, &buffer_size));
  OP_REQUIRES(ctx, buffer_size >= 0 || buffer_size == model::kAutotune,
              errors::InvalidArgument("buffer_size must be >= 0 or set to ",
                                      model::kAutotune, " for auto-tuning"));

  if (buffer_size == model::kAutotune) {
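    // `model::kAutotune` is a sentinel value (surfaced in Python as
    // tf.data.AUTOTUNE) that does not satisfy `buffer_size >= 0`, which is
    // why the validation above accepts it explicitly.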
    metrics::RecordTFDataAutotune(kDatasetType);
  }

  *output = new Dataset(ctx, input, buffer_size, slack_period_,
                        legacy_autotune_, buffer_size_min_);
}

namespace {
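// Kernel registrations. The GPU registration keeps `buffer_size`,
// `input_dataset`, and `handle` in host memory, since these are host-side
// scalar/variant values; the CPU kernel is registered with the higher
// priority.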
REGISTER_KERNEL_BUILDER(Name("PrefetchDataset").Device(DEVICE_CPU).Priority(2),
                        PrefetchDatasetOp);
REGISTER_KERNEL_BUILDER(Name("PrefetchDataset")
                            .Device(DEVICE_GPU)
                            .HostMemory("buffer_size")
                            .HostMemory("input_dataset")
                            .HostMemory("handle")
                            .Priority(1),
                        PrefetchDatasetOp);
}  // namespace

}  // namespace data
}  // namespace tensorflow