/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/prefetch_dataset_op.h"

#include <deque>

#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/kernels/data/stats_utils.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"

namespace tensorflow {
namespace data {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.

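// Out-of-line definitions for the static constexpr data members declared in
// prefetch_dataset_op.h; prior to C++17, ODR-used static constexpr members
// still require a namespace-scope definition.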
/* static */ constexpr const char* const PrefetchDatasetOp::kDatasetType;
/* static */ constexpr const char* const PrefetchDatasetOp::kInputDataset;
/* static */ constexpr const char* const PrefetchDatasetOp::kBufferSize;
/* static */ constexpr const char* const PrefetchDatasetOp::kOutputTypes;
/* static */ constexpr const char* const PrefetchDatasetOp::kOutputShapes;
/* static */ constexpr const char* const PrefetchDatasetOp::kSlackPeriod;
/* static */ constexpr const char* const PrefetchDatasetOp::kLegacyAutotune;
/* static */ constexpr const char* const PrefetchDatasetOp::kBufferSizeMin;

namespace {

// Determines the fraction of slack time by which to delay prefetching of data.
constexpr double kSleepFactor = 0.2;
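
// Keys and suffixes under which SaveInternal/RestoreInternal (below) persist
// the contents of the prefetch buffer in a checkpoint.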
constexpr char kBuffer[] = "buffer";
constexpr char kStatus[] = "status";
constexpr char kSizeSuffix[] = ".size";
constexpr char kCodeSuffix[] = ".code";
constexpr char kErrorMessageSuffix[] = ".error_message";

}  // namespace

class PrefetchDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size,
          int64 slack_period, bool legacy_autotune, int64 buffer_size_min)
      : DatasetBase(DatasetContext(ctx)),
        input_(input),
        buffer_size_(buffer_size),
        slack_period_(slack_period),
        legacy_autotune_(legacy_autotune),
        buffer_size_min_(buffer_size_min) {
    input_->Ref();
  }

  ~Dataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return absl::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix)});
  }

  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return input_->output_shapes();
  }

  string DebugString() const override {
    return name_utils::DatasetDebugString(kDatasetType);
  }

  int64 Cardinality() const override { return input_->Cardinality(); }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return Status::OK();
  }

  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
    Node* buffer_size = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
    AttrValue slack_period_attr;
    b->BuildAttrValue(slack_period_, &slack_period_attr);
    AttrValue legacy_autotune_attr;
    b->BuildAttrValue(legacy_autotune_, &legacy_autotune_attr);
    AttrValue buffer_size_min_attr;
    b->BuildAttrValue(buffer_size_min_, &buffer_size_min_attr);

    TF_RETURN_IF_ERROR(
        b->AddDataset(this, {input_graph_node, buffer_size},
                      {std::make_pair(kSlackPeriod, slack_period_attr),
                       std::make_pair(kLegacyAutotune, legacy_autotune_attr),
                       std::make_pair(kBufferSizeMin, buffer_size_min_attr)},
                      output));
    return Status::OK();
  }

 private:
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params),
          mu_(std::make_shared<mutex>()),
          cond_var_(std::make_shared<condition_variable>()),
          buffer_size_min_(params.dataset->buffer_size_min_),
          auto_tuner_(params.dataset->buffer_size_, buffer_size_min_),
          legacy_autotune_(params.dataset->legacy_autotune_),
          // If `legacy_autotune_`, initialize `buffer_size_` to 0 so that the
          // created node is not collected as a tunable node by the autotuning
          // optimization.
          buffer_size_(std::make_shared<model::SharedState>(
              legacy_autotune_ ? 0 : params.dataset->buffer_size_, mu_,
              cond_var_)) {
      slack_us_ = 0;
    }

    ~Iterator() override {
      CancelThreads();
      if (deregister_fn_) deregister_fn_();
    }

    Status Initialize(IteratorContext* ctx) override {
      mutex_lock l(*mu_);
      if (buffer_size_->value == model::kAutotune) {
        buffer_size_->value = buffer_size_min_;
      }
      TF_RETURN_IF_ERROR(RegisterCancellationCallback(
          ctx->cancellation_manager(), [this]() { CancelThreads(); },
          &deregister_fn_));
      return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
    }

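    // Returns the next element, preferring the prefetch buffer: waits until an
    // element has been buffered, the iterator is cancelled, or the prefetch
    // thread finishes. If the buffer limit is zero, falls through to a
    // synchronous read from the input iterator.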
    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      const auto& stats_aggregator = ctx->stats_aggregator();
      {
        mutex_lock l(*mu_);
        TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx));
        // Wait until the next element in the buffer has been
        // produced, or we are shutting down.
        if (legacy_autotune_) {
          while (!cancelled_ && buffer_.empty() && !prefetch_thread_finished_ &&
                 auto_tuner_.buffer_limit() != 0) {
            auto_tuner_.RecordEmpty();
            buffer_size_->value = auto_tuner_.buffer_limit();
            RecordStop(ctx);
            cond_var_->wait(l);
            RecordStart(ctx);
          }
        } else {
          while (!cancelled_ && buffer_.empty() && !prefetch_thread_finished_ &&
                 buffer_size_->value != 0) {
            RecordStop(ctx);
            cond_var_->wait(l);
            RecordStart(ctx);
          }
        }

        if (cancelled_) {
          return errors::Cancelled("Iterator was cancelled");
        }

        if (!buffer_.empty()) {
          return Consume(ctx, out_tensors, end_of_sequence);
        }

        if (prefetch_thread_finished_) {
          *end_of_sequence = true;
          return Status::OK();
        }

        DCHECK_EQ(buffer_limit(), 0);
      }

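      // The buffer is empty and the buffer limit is 0, so prefetching is
      // effectively disabled. Read the next element synchronously from the
      // input iterator under `input_mu_`, re-acquiring `mu_` only briefly to
      // report buffer statistics.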
      mutex_lock input_l(input_mu_);
      {
        mutex_lock l(*mu_);
        if (stats_aggregator) {
          stats_aggregator->AddScalar(
              stats_utils::BufferSizeScalarName(dataset()->node_name()),
              static_cast<float>(buffer_.size()), num_elements());
          stats_aggregator->AddScalar(
              stats_utils::BufferCapacityScalarName(dataset()->node_name()),
              static_cast<float>(buffer_limit()), num_elements());
        }
        // Release mu_
      }
      return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
    }

   protected:
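    // Exposes `buffer_size_` to the tf.data performance model as a tunable
    // parameter so that (non-legacy) autotuning can resize the buffer.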
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeAsyncKnownRatioNode(
          std::move(args),
          /*ratio=*/1,
          {model::MakeParameter(kBufferSize, buffer_size_,
                                /*min=*/buffer_size_min_,
                                /*max=*/std::numeric_limits<int64>::max())});
    }

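    // Checkpoints the input iterator and the buffered elements: the buffer
    // size is written under `prefix()`, and element i is written under
    // "prefix()::i" as a status plus, when the status is OK, its component
    // tensors.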
    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      // Acquire both locks to ensure that the prefetch thread and
      // all GetNext threads are blocked.
      mutex_lock input_l(input_mu_);
      mutex_lock l(*mu_);
      TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(prefix(), kBufferSize, buffer_.size()));
      for (size_t i = 0; i < buffer_.size(); i++) {
        auto& buffer_element = buffer_[i];
        TF_RETURN_IF_ERROR(WriteStatus(writer, i, buffer_element.status));
        if (buffer_element.status.ok()) {
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              absl::StrCat(prefix(), "::", i),
              absl::StrCat(kBuffer, kSizeSuffix), buffer_element.value.size()));
          for (size_t j = 0; j < buffer_element.value.size(); j++) {
            TF_RETURN_IF_ERROR(writer->WriteTensor(
                absl::StrCat(prefix(), "::", i),
                absl::StrCat(kBuffer, "[", j, "]"), buffer_element.value[j]));
          }
        }
      }
      return Status::OK();
    }

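    // Mirrors SaveInternal: restores the input iterator, then rebuilds the
    // buffer from the checkpointed statuses and tensors.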
    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock input_l(input_mu_);
      mutex_lock l(*mu_);
      DCHECK(buffer_.empty());
      TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
      size_t buffer_size;
      {
        int64 temp;
        TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kBufferSize, &temp));
        buffer_size = static_cast<size_t>(temp);
      }
      for (size_t i = 0; i < buffer_size; i++) {
        buffer_.emplace_back();
        auto& buffer_element = buffer_.back();
        TF_RETURN_IF_ERROR(ReadStatus(reader, i, &buffer_element.status));
        if (buffer_element.status.ok()) {
          size_t value_size;
          {
            int64 temp;
            TF_RETURN_IF_ERROR(
                reader->ReadScalar(absl::StrCat(prefix(), "::", i),
                                   absl::StrCat(kBuffer, kSizeSuffix), &temp));
            value_size = static_cast<size_t>(temp);
          }
          buffer_element.value.reserve(value_size);
          for (size_t j = 0; j < value_size; j++) {
            buffer_element.value.emplace_back();
            TF_RETURN_IF_ERROR(
                reader->ReadTensor(absl::StrCat(prefix(), "::", i),
                                   absl::StrCat(kBuffer, "[", j, "]"),
                                   &buffer_element.value.back()));
          }
        }
        RecordBufferEnqueue(ctx, buffer_element.value);
      }
      return Status::OK();
    }

    data::TraceMeMetadata GetTraceMeMetadata() const override {
      int64 limit = -1, size = -1;
      data::TraceMeMetadata result;
      // NOTE: We only report the buffer limit and size if the lock can be
      // acquired right away, to avoid introducing tracing overhead.
      if (mu_->try_lock()) {
        limit = buffer_limit();
        size = buffer_.size();
        if (!buffer_.empty()) {
          std::vector<std::string> shapes;
          shapes.reserve(buffer_.front().value.size());
          for (const auto& component : buffer_.front().value) {
            shapes.push_back(component.shape().DebugString());
          }
          result.push_back(std::make_pair("next_element_shapes",
                                          absl::StrJoin(shapes, ",")));
        }
        mu_->unlock();
      }
      result.push_back(std::make_pair(
          "buffer_limit",
          strings::Printf("%lld", static_cast<long long>(limit))));
      result.push_back(std::make_pair(
          "buffer_size",
          strings::Printf("%lld", static_cast<long long>(size))));
      result.push_back(std::make_pair(
          "autotune",
          dataset()->buffer_size_ == model::kAutotune ? "true" : "false"));
      result.push_back(std::make_pair(
          "autotune_mode", legacy_autotune_ ? "legacy" : "performance"));
      if (dataset()->slack_period_ > 0) {
        result.push_back(std::make_pair(
            "slack",
            strings::Printf("%lld", static_cast<long long>(slack_us_.load()))));
      }
      return result;
    }

   private:
    // A buffer element comprises a status and (if that status is
    // OK) a vector of tensors, representing an element of the input dataset.
    struct BufferElement {
      BufferElement() : uid(tensorflow::EnvTime::NowNanos()) {}

      // The producer sets `status` if getting the input element fails.
      Status status;
      // The buffered data element.
      std::vector<Tensor> value;
      int64 created_us;
      const uint64 uid;
    };

    int64 buffer_limit() const TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      if (legacy_autotune_) {
        return auto_tuner_.buffer_limit();
      }
      return buffer_size_->value;
    }

    void CancelThreads() TF_LOCKS_EXCLUDED(mu_) {
      mutex_lock l(*mu_);
      cancelled_ = true;
      cond_var_->notify_all();
    }

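    // Pops the element at the front of the buffer: records buffer statistics,
    // updates the slack estimate on every `slack_period_`-th element, forwards
    // the element's status and values to the caller, and wakes any waiters.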
    Status Consume(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                   bool* end_of_sequence) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      const auto& stats_aggregator = ctx->stats_aggregator();
      if (stats_aggregator) {
        double buffer_limit_ = buffer_limit();
        stats_aggregator->AddToHistogram(
            stats_utils::BufferUtilizationHistogramName(dataset()->node_name()),
            {static_cast<float>(buffer_.size()) /
             static_cast<float>(buffer_limit_)},
            num_elements());
        stats_aggregator->AddScalar(
            stats_utils::BufferSizeScalarName(dataset()->node_name()),
            static_cast<float>(buffer_.size()), num_elements());
        stats_aggregator->AddScalar(
            stats_utils::BufferCapacityScalarName(dataset()->node_name()),
            static_cast<float>(buffer_limit_), num_elements());
      }
      // A new element is available. Forward the status from computing it, and
      // (if we successfully got an element) the output values.
      Status s = buffer_.front().status;
      if (s.ok()) {
        int64 buffer_element_id = buffer_.front().uid;
        profiler::TraceMe traceme(
            [&] {
              return profiler::TraceMeEncode(
                  "PrefetchConsume", {{"element_id", buffer_element_id}});
            },
            profiler::kInfo);
        if (dataset()->slack_period_ > 0 &&
            (num_elements() + 1) % dataset()->slack_period_ == 0) {
          // TODO(rachelim): Consider doing something more sophisticated
          // to decide how long to sleep for; e.g. using a Kalman filter.
          int64 slack_us = EnvTime::NowMicros() - buffer_.front().created_us;
          // Every slack_period_-th element, update the most recent slack time,
          // measured by the duration between when the element is prefetched
          // and when it is consumed. We add kSleepFactor * slack_us_ to the
          // measurement because we slept for that duration before prefetching
          // the element.
          slack_us_ = kSleepFactor * slack_us_ + slack_us;
          VLOG(2) << "Setting slack_us_: " << slack_us_;
        }
        *out_tensors = std::move(buffer_.front().value);
        RecordBufferDequeue(ctx, *out_tensors);
      } else {
        // If the status is not ok, we still record the dequeue event to ensure
        // that each enqueue event is paired with a dequeue event, even in the
        // presence of errors.
        RecordBufferDequeue(ctx, buffer_.front().value);
      }
      if (legacy_autotune_) {
        auto_tuner_.RecordConsumption(buffer_.size());
        buffer_size_->value = auto_tuner_.buffer_limit();
      }
      buffer_.pop_front();
      *end_of_sequence = false;

      // Wake the prefetch thread, in case it has been waiting for space
      // in the buffer. Also wake up threads from other calls to GetNext.
      //
      // TODO(mrry): Consider using different condition variables for
      // GetNext and Prefetch.
      cond_var_->notify_all();
      return s;
    }

    Status EnsurePrefetchThreadStarted(IteratorContext* ctx)
        TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      if (!prefetch_thread_) {
        std::shared_ptr<IteratorContext> new_ctx =
            std::make_shared<IteratorContext>(*ctx);
        prefetch_thread_ = ctx->StartThread(
            "tf_data_prefetch", [this, new_ctx]() { PrefetchThread(new_ctx); });
      }
      return Status::OK();
    }

    // Prefetches elements of the input, storing results in an internal buffer.
    //
    // The thread shares ownership of the iterator context passed to it, which
    // keeps the context alive for the lifetime of the thread.
    void PrefetchThread(const std::shared_ptr<IteratorContext>& ctx) {
      RecordStart(ctx.get());
      auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
      // Keep track of where we are in an iteration "burst".
      int num_produced = 0;
      while (true) {
        // 1. Wait for a slot in the buffer.
        {
          mutex_lock l(*mu_);
          while (!cancelled_ && buffer_.size() >= buffer_limit()) {
            RecordStop(ctx.get());
            cond_var_->wait(l);
            RecordStart(ctx.get());
          }

          if (cancelled_) {
            prefetch_thread_finished_ = true;
            cond_var_->notify_all();
            return;
          }
        }

        if (dataset()->slack_period_ > 0 &&
            num_produced % dataset()->slack_period_ == 0) {
          // For the first element in the "burst", sleep for a bit if there is
          // slack.
          VLOG(2) << "Sleeping for: " << slack_us_ * kSleepFactor;
          ctx->env()->SleepForMicroseconds(slack_us_ * kSleepFactor);
        }

        // 2. Read the next element.
        // Acquire the input mutex since we will be reading an element from the
        // input iterator. Note that we do not wish to release this mutex until
        // we have added the fetched element to `buffer_`; otherwise the
        // element would be local state that SaveInternal could miss.
        mutex_lock input_l(input_mu_);
        bool end_of_sequence;
        BufferElement buffer_element;
        {
          profiler::TraceMe traceme(
              [&] {
                return profiler::TraceMeEncode(
                    "PrefetchProduce", {{"element_id", buffer_element.uid}});
              },
              profiler::kInfo);
          buffer_element.status = input_impl_->GetNext(
              ctx.get(), &buffer_element.value, &end_of_sequence);
        }
        if (buffer_element.status.ok() && end_of_sequence) {
          mutex_lock l(*mu_);
          prefetch_thread_finished_ = true;
          cond_var_->notify_all();
          return;
        }

        // 3. Signal that the element has been produced.
        {
          mutex_lock l(*mu_);
          RecordBufferEnqueue(ctx.get(), buffer_element.value);
          buffer_element.created_us = EnvTime::NowMicros();
          buffer_.push_back(std::move(buffer_element));
          cond_var_->notify_all();
        }
        ++num_produced;
      }
    }

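    // Serializes a buffer element's status as its error code plus, for non-OK
    // statuses, the error message; ReadStatus reverses the encoding.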
    Status WriteStatus(IteratorStateWriter* writer, size_t index,
                       const Status& status) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(absl::StrCat(prefix(), "::", index), CodeKey(),
                              static_cast<int64>(status.code())));
      if (!status.ok()) {
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(absl::StrCat(prefix(), "::", index),
                                ErrorMessageKey(), status.error_message()));
      }
      return Status::OK();
    }

    Status ReadStatus(IteratorStateReader* reader, size_t index, Status* status)
        TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      int64 code_int;
      TF_RETURN_IF_ERROR(reader->ReadScalar(absl::StrCat(prefix(), "::", index),
                                            CodeKey(), &code_int));
      error::Code code = static_cast<error::Code>(code_int);

      if (code != error::Code::OK) {
        tstring error_message;
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(absl::StrCat(prefix(), "::", index),
                               ErrorMessageKey(), &error_message));
        *status = Status(code, error_message);
      } else {
        *status = Status::OK();
      }
      return Status::OK();
    }

    string CodeKey() { return absl::StrCat(kStatus, kCodeSuffix); }

    string ErrorMessageKey() {
      return absl::StrCat(kStatus, kErrorMessageSuffix);
    }

    // This mutex is used to ensure exclusivity between multiple threads
    // reading/writing this iterator's local state.
    //
    // NOTE: We should never call GetNext on the input while holding this mutex.
    const std::shared_ptr<mutex> mu_;
    // This mutex is used to ensure exclusivity between multiple threads
    // accessing the input iterator. We keep this separate from `mu_` to allow
    // prefetching to run in parallel with GetNext calls.
    mutex input_mu_ TF_ACQUIRED_BEFORE(*mu_);
    std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(input_mu_);
    const std::shared_ptr<condition_variable> cond_var_;
    const int64 buffer_size_min_;
    PrefetchAutotuner auto_tuner_ TF_GUARDED_BY(*mu_);
    std::deque<BufferElement> buffer_ TF_GUARDED_BY(*mu_);
    std::unique_ptr<Thread> prefetch_thread_ TF_GUARDED_BY(*mu_);
    bool cancelled_ TF_GUARDED_BY(*mu_) = false;
    bool prefetch_thread_finished_ TF_GUARDED_BY(*mu_) = false;
    const bool legacy_autotune_;

    std::atomic<int64> slack_us_;

    // If legacy_autotune_ is false, identifies the maximum size of the buffer.
    const std::shared_ptr<model::SharedState> buffer_size_;

    // Method for deregistering the cancellation callback.
    std::function<void()> deregister_fn_;
  };
  const DatasetBase* const input_;
  const int64 buffer_size_;

  // If non-zero, determines the period between injecting "slack" into the
  // execution.
  const int64 slack_period_;

  // Determines whether legacy autotuning should be used.
  const bool legacy_autotune_ = true;

  // If autotune is enabled, determines the minimal value of the `buffer_size`
  // parameter.
  const int64 buffer_size_min_ = 0;

  TraceMeMetadata traceme_metadata_;
};

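// Reads the op's optional attrs (slack period, legacy autotune, and minimum
// buffer size) at kernel-construction time. When an attr is absent, the
// corresponding member keeps its default value, which is declared in
// prefetch_dataset_op.h (not shown here).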
PrefetchDatasetOp::PrefetchDatasetOp(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx) {
  if (ctx->HasAttr(kSlackPeriod)) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kSlackPeriod, &slack_period_));
  }
  if (ctx->HasAttr(kLegacyAutotune)) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kLegacyAutotune, &legacy_autotune_));
  }
  if (ctx->HasAttr(kBufferSizeMin)) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kBufferSizeMin, &buffer_size_min_));
  }
}

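// Validates the `buffer_size` input (which must be non-negative, or
// model::kAutotune to request autotuning) and constructs the Dataset.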
void PrefetchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                    DatasetBase** output) {
  int64 buffer_size = 0;
  OP_REQUIRES_OK(ctx,
                 ParseScalarArgument<int64>(ctx, kBufferSize, &buffer_size));
  OP_REQUIRES(ctx, buffer_size >= 0 || buffer_size == model::kAutotune,
              errors::InvalidArgument("buffer_size must be >= 0 or set "
                                      "buffer_size to be ",
                                      model::kAutotune, " for auto-tuning"));

  if (buffer_size == model::kAutotune) {
    metrics::RecordTFDataAutotune(kDatasetType);
  }

  *output = new Dataset(ctx, input, buffer_size, slack_period_,
                        legacy_autotune_, buffer_size_min_);
}

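// Registers the kernel for both CPU and GPU. The GPU registration pins the
// op's inputs and output to host memory (dataset variants live on the host),
// and its lower priority makes the CPU kernel the preferred choice when both
// registrations match.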
namespace {
REGISTER_KERNEL_BUILDER(Name("PrefetchDataset").Device(DEVICE_CPU).Priority(2),
                        PrefetchDatasetOp);
REGISTER_KERNEL_BUILDER(Name("PrefetchDataset")
                            .Device(DEVICE_GPU)
                            .HostMemory("buffer_size")
                            .HostMemory("input_dataset")
                            .HostMemory("handle")
                            .Priority(1),
                        PrefetchDatasetOp);
}  // namespace

}  // namespace data
}  // namespace tensorflow