/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/prefetch_dataset_op.h"

#include <deque>

#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/data/stats_utils.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"

namespace tensorflow {
namespace data {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.

/* static */ constexpr const char* const PrefetchDatasetOp::kDatasetType;
/* static */ constexpr const char* const PrefetchDatasetOp::kInputDataset;
/* static */ constexpr const char* const PrefetchDatasetOp::kBufferSize;
/* static */ constexpr const char* const PrefetchDatasetOp::kOutputTypes;
/* static */ constexpr const char* const PrefetchDatasetOp::kOutputShapes;
/* static */ constexpr const char* const PrefetchDatasetOp::kSlackPeriod;
/* static */ constexpr const char* const PrefetchDatasetOp::kLegacyAutotune;
/* static */ constexpr const char* const PrefetchDatasetOp::kBufferSizeMin;

namespace {

// Determines the fraction of slack time by which to delay prefetching of data.
constexpr double kSleepFactor = 0.2;
constexpr char kBuffer[] = "buffer";
constexpr char kStatus[] = "status";
constexpr char kSizeSuffix[] = ".size";
constexpr char kCodeSuffix[] = ".code";
constexpr char kErrorMessageSuffix[] = ".error_message";

}  // namespace

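// A `DatasetBase` that asynchronously prefetches elements from its input into
// an internal buffer, decoupling the production of input elements from their
// consumption.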
class PrefetchDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, const DatasetBase* input, int64_t buffer_size,
          int64_t slack_period, bool legacy_autotune, int64_t buffer_size_min)
      : DatasetBase(DatasetContext(ctx)),
        input_(input),
        buffer_size_(buffer_size),
        slack_period_(slack_period),
        legacy_autotune_(legacy_autotune),
        buffer_size_min_(buffer_size_min) {
    input_->Ref();
  }

  ~Dataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return absl::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix)});
  }

  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return input_->output_shapes();
  }

  string DebugString() const override {
    return name_utils::DatasetDebugString(kDatasetType);
  }

  int64 Cardinality() const override { return input_->Cardinality(); }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return Status::OK();
  }

  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
    Node* buffer_size = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
    AttrValue slack_period_attr;
    b->BuildAttrValue(slack_period_, &slack_period_attr);
    AttrValue legacy_autotune_attr;
    b->BuildAttrValue(legacy_autotune_, &legacy_autotune_attr);
    AttrValue buffer_size_min_attr;
    b->BuildAttrValue(buffer_size_min_, &buffer_size_min_attr);

    TF_RETURN_IF_ERROR(
        b->AddDataset(this, {input_graph_node, buffer_size},
                      {std::make_pair(kSlackPeriod, slack_period_attr),
                       std::make_pair(kLegacyAutotune, legacy_autotune_attr),
                       std::make_pair(kBufferSizeMin, buffer_size_min_attr)},
                      output));
    return Status::OK();
  }

 private:
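  // An iterator that fills a buffer of elements on a background prefetch
  // thread and serves buffered elements to callers of `GetNext()`.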
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params),
          mu_(std::make_shared<mutex>()),
          cond_var_(std::make_shared<condition_variable>()),
          buffer_size_min_(params.dataset->buffer_size_min_),
          auto_tuner_(params.dataset->buffer_size_, buffer_size_min_),
          legacy_autotune_(params.dataset->legacy_autotune_),
          // If `legacy_autotune_` is enabled, initialize `buffer_size_` to 0
          // so that the created node is not collected as a tunable node by
          // the autotuning optimization.
          buffer_size_(std::make_shared<model::SharedState>(
              legacy_autotune_ ? 0 : params.dataset->buffer_size_, mu_,
              cond_var_)) {
      slack_us_ = 0;
    }

    ~Iterator() override {
      CancelThreads();
      if (deregister_fn_) deregister_fn_();
    }

    Status Initialize(IteratorContext* ctx) override {
      mutex_lock l(*mu_);
      if (buffer_size_->value == model::kAutotune) {
        buffer_size_->value = buffer_size_min_;
      }
      cancellation_manager_ = absl::make_unique<CancellationManager>();
      TF_RETURN_IF_ERROR(RegisterCancellationCallback(
          ctx->cancellation_manager(), [this]() { CancelThreads(); },
          &deregister_fn_));
      IteratorContext::Params params(ctx);
      params.cancellation_manager = cancellation_manager_.get();
      return dataset()->input_->MakeIterator(IteratorContext(params), this,
                                             prefix(), &input_impl_);
    }

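    // Returns the element at the front of the buffer if one is available.
    // If the buffer limit is zero (prefetching is effectively disabled),
    // falls back to reading directly from the input iterator.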
    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      const auto& stats_aggregator = ctx->stats_aggregator();
      {
        mutex_lock l(*mu_);
        TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx));
        // Wait until the next element in the buffer has been
        // produced, or we are shutting down.
        if (legacy_autotune_) {
          while (!cancelled_ && buffer_.empty() && !prefetch_thread_finished_ &&
                 auto_tuner_.buffer_limit() != 0) {
            auto_tuner_.RecordEmpty();
            buffer_size_->value = auto_tuner_.buffer_limit();
            RecordStop(ctx);
            cond_var_->wait(l);
            RecordStart(ctx);
          }
        } else {
          while (!cancelled_ && buffer_.empty() && !prefetch_thread_finished_ &&
                 buffer_size_->value != 0) {
            RecordStop(ctx);
            cond_var_->wait(l);
            RecordStart(ctx);
          }
        }

        if (cancelled_) {
          return errors::Cancelled("Iterator was cancelled");
        }

        if (!buffer_.empty()) {
          return Consume(ctx, out_tensors, end_of_sequence);
        }

        if (prefetch_thread_finished_) {
          *end_of_sequence = true;
          return Status::OK();
        }

        DCHECK_EQ(buffer_limit(), 0);
      }

      mutex_lock input_l(input_mu_);
      {
        mutex_lock l(*mu_);
        if (stats_aggregator) {
          stats_aggregator->AddScalar(
              stats_utils::BufferSizeScalarName(dataset()->node_name()),
              static_cast<float>(buffer_.size()), num_elements());
          stats_aggregator->AddScalar(
              stats_utils::BufferCapacityScalarName(dataset()->node_name()),
              static_cast<float>(buffer_limit()), num_elements());
        }
        // Release mu_
      }
      return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
    }

   protected:
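    // Models this iterator as an asynchronous known-ratio node whose tunable
    // `buffer_size` parameter ranges from `buffer_size_min_` to INT64_MAX.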
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeAsyncKnownRatioNode(
          std::move(args),
          /*ratio=*/1,
          {model::MakeParameter(kBufferSize, buffer_size_,
                                /*min=*/buffer_size_min_,
                                /*max=*/std::numeric_limits<int64>::max())});
    }

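    // Serializes the input iterator state and every buffered element: each
    // element's status and, if the status is OK, its tensors.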
    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      // Acquire both locks to ensure that the prefetch thread and
      // all GetNext threads are blocked.
      mutex_lock input_l(input_mu_);
      mutex_lock l(*mu_);
      TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(prefix(), kBufferSize, buffer_.size()));
      for (size_t i = 0; i < buffer_.size(); i++) {
        auto& buffer_element = buffer_[i];
        TF_RETURN_IF_ERROR(WriteStatus(writer, i, buffer_element.status));
        if (buffer_element.status.ok()) {
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              absl::StrCat(prefix(), "::", i),
              absl::StrCat(kBuffer, kSizeSuffix), buffer_element.value.size()));
          for (size_t j = 0; j < buffer_element.value.size(); j++) {
            TF_RETURN_IF_ERROR(writer->WriteTensor(
                absl::StrCat(prefix(), "::", i),
                absl::StrCat(kBuffer, "[", j, "]"), buffer_element.value[j]));
          }
        }
      }
      return Status::OK();
    }

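    // Restores the state written by `SaveInternal`, re-populating the buffer.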
    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock input_l(input_mu_);
      mutex_lock l(*mu_);
      DCHECK(buffer_.empty());
      TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
      size_t buffer_size;
      {
        int64_t temp;
        TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kBufferSize, &temp));
        buffer_size = static_cast<size_t>(temp);
      }
      for (size_t i = 0; i < buffer_size; i++) {
        buffer_.emplace_back();
        auto& buffer_element = buffer_.back();
        TF_RETURN_IF_ERROR(ReadStatus(reader, i, &buffer_element.status));
        if (buffer_element.status.ok()) {
          size_t value_size;
          {
            int64_t temp;
            TF_RETURN_IF_ERROR(
                reader->ReadScalar(absl::StrCat(prefix(), "::", i),
                                   absl::StrCat(kBuffer, kSizeSuffix), &temp));
            value_size = static_cast<size_t>(temp);
          }
          buffer_element.value.reserve(value_size);
          for (size_t j = 0; j < value_size; j++) {
            buffer_element.value.emplace_back();
            TF_RETURN_IF_ERROR(
                reader->ReadTensor(ctx->flr(), absl::StrCat(prefix(), "::", i),
                                   absl::StrCat(kBuffer, "[", j, "]"),
                                   &buffer_element.value.back()));
          }
        }
        RecordBufferEnqueue(ctx, buffer_element.value);
      }
      return Status::OK();
    }

    data::TraceMeMetadata GetTraceMeMetadata() const override {
      int64_t limit = -1, size = -1;
      data::TraceMeMetadata result;
      // NOTE: We only set the parallelism value if the lock can be acquired
      // right away to avoid introducing tracing overhead.
      if (mu_->try_lock()) {
        limit = buffer_limit();
        size = buffer_.size();
        if (!buffer_.empty()) {
          std::vector<std::string> shapes;
          shapes.reserve(buffer_.front().value.size());
          for (const auto& component : buffer_.front().value) {
            shapes.push_back(component.shape().DebugString());
          }
          result.push_back(std::make_pair("next_element_shapes",
                                          absl::StrJoin(shapes, ",")));
        }
        mu_->unlock();
      }
      result.push_back(std::make_pair(
          "buffer_limit",
          limit == -1
              ? kTraceInfoUnavailable
              : strings::Printf("%lld", static_cast<long long>(limit))));
      result.push_back(std::make_pair(
          "buffer_size",
          size == -1 ? kTraceInfoUnavailable
                     : strings::Printf("%lld", static_cast<long long>(size))));
      result.push_back(std::make_pair(
          "autotune",
          dataset()->buffer_size_ == model::kAutotune ? "true" : "false"));
      result.push_back(std::make_pair(
          "autotune_mode", legacy_autotune_ ? "legacy" : "performance"));
      if (dataset()->slack_period_ > 0) {
        result.push_back(std::make_pair(
            "slack",
            strings::Printf("%lld", static_cast<long long>(slack_us_.load()))));
      }
      return result;
    }

   private:
    // A buffer element comprises a status and (if that status is
    // OK) a vector of tensors, representing an element of the input dataset.
    struct BufferElement {
      BufferElement() : uid(tensorflow::EnvTime::NowNanos()) {}

      // The producer sets `status` if getting the input element fails.
      Status status;
      // The buffered data element.
      std::vector<Tensor> value;
      int64 created_us;
      const uint64 uid;
    };

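    // Returns the current buffer capacity: the autotuner's limit under legacy
    // autotuning, otherwise the shared `buffer_size_` value.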
    int64 buffer_limit() const TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      if (legacy_autotune_) {
        return auto_tuner_.buffer_limit();
      }
      return buffer_size_->value;
    }

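    // Cancels the input pipeline and wakes up all threads waiting on
    // `cond_var_`.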
    void CancelThreads() TF_LOCKS_EXCLUDED(mu_) {
      cancellation_manager_->StartCancel();
      mutex_lock l(*mu_);
      cancelled_ = true;
      cond_var_->notify_all();
    }

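    // Pops the element at the front of the buffer, forwarding its status and
    // (on success) moving its tensors into `out_tensors`.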
    Status Consume(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                   bool* end_of_sequence) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      const auto& stats_aggregator = ctx->stats_aggregator();
      if (stats_aggregator) {
        double buffer_limit_ = buffer_limit();
        stats_aggregator->AddToHistogram(
            stats_utils::BufferUtilizationHistogramName(dataset()->node_name()),
            {static_cast<float>(buffer_.size()) /
             static_cast<float>(buffer_limit_)},
            num_elements());
        stats_aggregator->AddScalar(
            stats_utils::BufferSizeScalarName(dataset()->node_name()),
            static_cast<float>(buffer_.size()), num_elements());
        stats_aggregator->AddScalar(
            stats_utils::BufferCapacityScalarName(dataset()->node_name()),
            static_cast<float>(buffer_limit_), num_elements());
      }
      // A new element is available. Forward the status from computing it, and
      // (if we successfully got an element) the output values.
      Status s = buffer_.front().status;
      if (s.ok()) {
        int64_t buffer_element_id = buffer_.front().uid;
        profiler::TraceMe traceme(
            [&] {
              return profiler::TraceMeEncode(
                  "PrefetchConsume", {{"element_id", buffer_element_id}});
            },
            profiler::kInfo);
        if (dataset()->slack_period_ > 0 &&
            (num_elements() + 1) % dataset()->slack_period_ == 0) {
          // TODO(rachelim): Consider doing something more sophisticated
          // to decide how long to sleep for; e.g. using a Kalman filter.
          int64_t slack_us = EnvTime::NowMicros() - buffer_.front().created_us;
          // Every slack_period_-th element, update the most recent slack time,
          // measured by the duration between when the element is prefetched
          // and when it is consumed. We add kSleepFactor * slack_us_ to the
          // measurement because we slept for that duration before prefetching
          // the element.
          slack_us_ = kSleepFactor * slack_us_ + slack_us;
          VLOG(2) << "Setting slack_us_: " << slack_us_;
        }
        *out_tensors = std::move(buffer_.front().value);
        RecordBufferDequeue(ctx, *out_tensors);
      } else {
        // If the status is not OK, we still record the dequeue event so that
        // each enqueue event is paired with a dequeue event, even in the
        // presence of errors.
        RecordBufferDequeue(ctx, buffer_.front().value);
      }
      if (legacy_autotune_) {
        auto_tuner_.RecordConsumption(buffer_.size());
        buffer_size_->value = auto_tuner_.buffer_limit();
      }
      buffer_.pop_front();
      *end_of_sequence = false;

      // Wake the prefetch thread, in case it has been waiting for space
      // in the buffer. Also wake up threads from other calls to GetNext.
      //
      // TODO(mrry): Consider using different condition variables for
      // GetNext and Prefetch.
      cond_var_->notify_all();
      return s;
    }

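    // Starts the background prefetch thread if it is not already running.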
    Status EnsurePrefetchThreadStarted(IteratorContext* ctx)
        TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      if (!prefetch_thread_) {
        std::shared_ptr<IteratorContext> new_ctx =
            std::make_shared<IteratorContext>(*ctx);
        prefetch_thread_ = ctx->StartThread(
            "tf_data_prefetch", [this, new_ctx]() { PrefetchThread(new_ctx); });
      }
      return Status::OK();
    }

    // Prefetches elements of the input, storing results in an internal buffer.
    //
    // It owns the iterator context passed to it.
    void PrefetchThread(const std::shared_ptr<IteratorContext>& ctx) {
      RecordStart(ctx.get());
      auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
      // Keep track of where we are in an iteration "burst".
      int num_produced = 0;
      while (true) {
        // 1. Wait for a slot in the buffer.
        {
          mutex_lock l(*mu_);
          while (!cancelled_ && buffer_.size() >= buffer_limit()) {
            RecordStop(ctx.get());
            cond_var_->wait(l);
            RecordStart(ctx.get());
          }

          if (cancelled_) {
            prefetch_thread_finished_ = true;
            cond_var_->notify_all();
            return;
          }
        }

        if (dataset()->slack_period_ > 0 &&
            num_produced % dataset()->slack_period_ == 0) {
          // For the first element in the "burst", sleep for a bit if there is
          // slack.
          VLOG(2) << "Sleeping for: " << slack_us_ * kSleepFactor;
          ctx->env()->SleepForMicroseconds(slack_us_ * kSleepFactor);
        }

        // 2. Read the next element.
        // Acquire the input mutex since we will be reading an element from the
        // input iterator. We do not release this mutex until the fetched
        // element has been added to `buffer_`; otherwise there would be local
        // state that SaveInternal could miss.
        mutex_lock input_l(input_mu_);
        bool end_of_sequence;
        BufferElement buffer_element;
        {
          profiler::TraceMe traceme(
              [&] {
                return profiler::TraceMeEncode(
                    "PrefetchProduce", {{"element_id", buffer_element.uid}});
              },
              profiler::kInfo);
          buffer_element.status = input_impl_->GetNext(
              ctx.get(), &buffer_element.value, &end_of_sequence);
        }
        if (buffer_element.status.ok() && end_of_sequence) {
          mutex_lock l(*mu_);
          prefetch_thread_finished_ = true;
          cond_var_->notify_all();
          return;
        }

        // 3. Signal that the element has been produced.
        {
          mutex_lock l(*mu_);
          RecordBufferEnqueue(ctx.get(), buffer_element.value);
          buffer_element.created_us = EnvTime::NowMicros();
          buffer_.push_back(std::move(buffer_element));
          cond_var_->notify_all();
        }
        ++num_produced;
      }
    }

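    // Writes the status of buffer element `index` to the checkpoint: its code
    // and, if the status is not OK, its error message.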
    Status WriteStatus(IteratorStateWriter* writer, size_t index,
                       const Status& status) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(absl::StrCat(prefix(), "::", index), CodeKey(),
                              static_cast<int64>(status.code())));
      if (!status.ok()) {
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(absl::StrCat(prefix(), "::", index),
                                ErrorMessageKey(), status.error_message()));
      }
      return Status::OK();
    }

    Status ReadStatus(IteratorStateReader* reader, size_t index, Status* status)
        TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      int64_t code_int;
      TF_RETURN_IF_ERROR(reader->ReadScalar(absl::StrCat(prefix(), "::", index),
                                            CodeKey(), &code_int));
      error::Code code = static_cast<error::Code>(code_int);

      if (code != error::Code::OK) {
        tstring error_message;
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(absl::StrCat(prefix(), "::", index),
                               ErrorMessageKey(), &error_message));
        *status = Status(code, error_message);
      } else {
        *status = Status::OK();
      }
      return Status::OK();
    }

    string CodeKey() { return absl::StrCat(kStatus, kCodeSuffix); }

    string ErrorMessageKey() {
      return absl::StrCat(kStatus, kErrorMessageSuffix);
    }

    // This mutex is used to ensure exclusivity between multiple threads
    // reading/writing this iterator's local state.
    //
    // NOTE: We should never call GetNext on the input while holding this mutex.
    const std::shared_ptr<mutex> mu_;
    // This mutex is used to ensure exclusivity between multiple threads
    // accessing the input iterator. We keep this separate from `mu_` to allow
    // prefetching to run in parallel with GetNext calls.
    mutex input_mu_ TF_ACQUIRED_BEFORE(*mu_);
    // Controls cancellation of `input_impl_`. Must be ordered before
    // `input_impl_` so that `input_impl_` is destroyed first.
    std::unique_ptr<CancellationManager> cancellation_manager_;
    std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(input_mu_);
    const std::shared_ptr<condition_variable> cond_var_;
    const int64 buffer_size_min_;
    PrefetchAutotuner auto_tuner_ TF_GUARDED_BY(*mu_);
    std::deque<BufferElement> buffer_ TF_GUARDED_BY(*mu_);
    std::unique_ptr<Thread> prefetch_thread_ TF_GUARDED_BY(*mu_);
    bool cancelled_ TF_GUARDED_BY(*mu_) = false;
    bool prefetch_thread_finished_ TF_GUARDED_BY(*mu_) = false;
    const bool legacy_autotune_;

    std::atomic<int64> slack_us_;

    // If legacy_autotune_ is false, identifies the maximum size of the buffer.
    const std::shared_ptr<model::SharedState> buffer_size_;

    // Method for deregistering the cancellation callback.
    std::function<void()> deregister_fn_;
  };
  const DatasetBase* const input_;
  const int64 buffer_size_;

  // If non-zero, determines the period between injecting "slack" into the
  // execution.
  const int64 slack_period_;

  // Determines whether legacy autotuning should be used.
  const bool legacy_autotune_ = true;

  // If autotune is enabled, determines the minimal value of the `buffer_size`
  // parameter.
  const int64 buffer_size_min_ = 0;

  TraceMeMetadata traceme_metadata_;
};

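// Reads the optional `slack_period`, `legacy_autotune`, and `buffer_size_min`
// attributes when they are present on the op.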
PrefetchDatasetOp::PrefetchDatasetOp(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx) {
  if (ctx->HasAttr(kSlackPeriod)) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kSlackPeriod, &slack_period_));
  }
  if (ctx->HasAttr(kLegacyAutotune)) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kLegacyAutotune, &legacy_autotune_));
  }
  if (ctx->HasAttr(kBufferSizeMin)) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kBufferSizeMin, &buffer_size_min_));
  }
}

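// Parses and validates the `buffer_size` input and constructs the dataset.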
void PrefetchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                    DatasetBase** output) {
  int64_t buffer_size = 0;
  OP_REQUIRES_OK(ctx,
                 ParseScalarArgument<int64>(ctx, kBufferSize, &buffer_size));
  OP_REQUIRES(ctx, buffer_size >= 0 || buffer_size == model::kAutotune,
              errors::InvalidArgument("buffer_size must be >= 0 or set "
                                      "buffer_size to be ",
                                      model::kAutotune, " for auto-tuning"));

  if (buffer_size == model::kAutotune) {
    metrics::RecordTFDataAutotune(kDatasetType);
  }

  *output = new Dataset(ctx, input, buffer_size, slack_period_,
                        legacy_autotune_, buffer_size_min_);
}

namespace {
REGISTER_KERNEL_BUILDER(Name("PrefetchDataset").Device(DEVICE_CPU).Priority(2),
                        PrefetchDatasetOp);
REGISTER_KERNEL_BUILDER(Name("PrefetchDataset")
                            .Device(DEVICE_GPU)
                            .HostMemory("buffer_size")
                            .HostMemory("input_dataset")
                            .HostMemory("handle")
                            .Priority(1),
                        PrefetchDatasetOp);
}  // namespace

}  // namespace data
}  // namespace tensorflow