/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/prefetch_dataset_op.h"

#include <deque>

#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/data/stats_utils.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"

namespace tensorflow {
namespace data {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.
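//
// Prefetching is implemented by a single background thread (see
// `PrefetchThread` below) that repeatedly pulls elements from the input
// iterator into a bounded buffer, while `GetNextInternal` pops elements from
// the front of that buffer. The buffer limit is either fixed, driven by the
// legacy `PrefetchAutotuner`, or exposed to the tf.data autotuning model as a
// tunable `buffer_size` parameter.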

/* static */ constexpr const char* const PrefetchDatasetOp::kDatasetType;
/* static */ constexpr const char* const PrefetchDatasetOp::kInputDataset;
/* static */ constexpr const char* const PrefetchDatasetOp::kBufferSize;
/* static */ constexpr const char* const PrefetchDatasetOp::kOutputTypes;
/* static */ constexpr const char* const PrefetchDatasetOp::kOutputShapes;
/* static */ constexpr const char* const PrefetchDatasetOp::kSlackPeriod;
/* static */ constexpr const char* const PrefetchDatasetOp::kLegacyAutotune;
/* static */ constexpr const char* const PrefetchDatasetOp::kBufferSizeMin;

namespace {

// Determines the fraction of slack time by which to delay prefetching of data.
constexpr double kSleepFactor = 0.2;
constexpr char kBuffer[] = "buffer";
constexpr char kStatus[] = "status";
constexpr char kSizeSuffix[] = ".size";
constexpr char kCodeSuffix[] = ".code";
constexpr char kErrorMessageSuffix[] = ".error_message";

}  // namespace

class PrefetchDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, const DatasetBase* input, int64_t buffer_size,
          int64_t slack_period, bool legacy_autotune, int64_t buffer_size_min)
      : DatasetBase(DatasetContext(ctx)),
        input_(input),
        buffer_size_(buffer_size),
        slack_period_(slack_period),
        legacy_autotune_(legacy_autotune),
        buffer_size_min_(buffer_size_min) {
    input_->Ref();
  }

  ~Dataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return absl::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix)});
  }

  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return input_->output_shapes();
  }

  string DebugString() const override {
    return name_utils::DatasetDebugString(kDatasetType);
  }

  int64 Cardinality() const override { return input_->Cardinality(); }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return Status::OK();
  }

  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
    Node* buffer_size = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
    AttrValue slack_period_attr;
    b->BuildAttrValue(slack_period_, &slack_period_attr);
    AttrValue legacy_autotune_attr;
    b->BuildAttrValue(legacy_autotune_, &legacy_autotune_attr);
    AttrValue buffer_size_min_attr;
    b->BuildAttrValue(buffer_size_min_, &buffer_size_min_attr);

    TF_RETURN_IF_ERROR(
        b->AddDataset(this, {input_graph_node, buffer_size},
                      {std::make_pair(kSlackPeriod, slack_period_attr),
                       std::make_pair(kLegacyAutotune, legacy_autotune_attr),
                       std::make_pair(kBufferSizeMin, buffer_size_min_attr)},
                      output));
    return Status::OK();
  }

 private:
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params),
          mu_(std::make_shared<mutex>()),
          cond_var_(std::make_shared<condition_variable>()),
          buffer_size_min_(params.dataset->buffer_size_min_),
          auto_tuner_(params.dataset->buffer_size_, buffer_size_min_),
          legacy_autotune_(params.dataset->legacy_autotune_),
          // If `legacy_autotune_` is enabled, initialize `buffer_size_` to 0
          // so that the node created for this iterator is not collected as a
          // tunable node by the autotuning optimization.
          buffer_size_(std::make_shared<model::SharedState>(
              legacy_autotune_ ? 0 : params.dataset->buffer_size_, mu_,
              cond_var_)) {
      slack_us_ = 0;
    }

    ~Iterator() override {
      CancelThreads();
      if (deregister_fn_) deregister_fn_();
    }

    Status Initialize(IteratorContext* ctx) override {
      mutex_lock l(*mu_);
      if (buffer_size_->value == model::kAutotune) {
        buffer_size_->value = buffer_size_min_;
      }
      cancellation_manager_ = absl::make_unique<CancellationManager>();
      TF_RETURN_IF_ERROR(RegisterCancellationCallback(
          ctx->cancellation_manager(), [this]() { CancelThreads(); },
          &deregister_fn_));
      IteratorContext::Params params(ctx);
      params.cancellation_manager = cancellation_manager_.get();
      return dataset()->input_->MakeIterator(IteratorContext(params), this,
                                             prefix(), &input_impl_);
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      const auto& stats_aggregator = ctx->stats_aggregator();
      {
        mutex_lock l(*mu_);
        TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx));
        // Wait until the next element in the buffer has been
        // produced, or we are shutting down.
        if (legacy_autotune_) {
          while (!cancelled_ && buffer_.empty() && !prefetch_thread_finished_ &&
                 auto_tuner_.buffer_limit() != 0) {
            auto_tuner_.RecordEmpty();
            buffer_size_->value = auto_tuner_.buffer_limit();
            RecordStop(ctx);
            cond_var_->wait(l);
            RecordStart(ctx);
          }
        } else {
          while (!cancelled_ && buffer_.empty() && !prefetch_thread_finished_ &&
                 buffer_size_->value != 0) {
            RecordStop(ctx);
            cond_var_->wait(l);
            RecordStart(ctx);
          }
        }

        if (cancelled_) {
          return errors::Cancelled("Iterator was cancelled");
        }

        if (!buffer_.empty()) {
          return Consume(ctx, out_tensors, end_of_sequence);
        }

        if (prefetch_thread_finished_) {
          *end_of_sequence = true;
          return Status::OK();
        }

        DCHECK_EQ(buffer_limit(), 0);
      }

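      // The buffer is empty and its limit is 0, i.e. prefetching is
      // effectively disabled, so read the next element synchronously from the
      // input iterator.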
      mutex_lock input_l(input_mu_);
      {
        mutex_lock l(*mu_);
        if (stats_aggregator) {
          stats_aggregator->AddScalar(
              stats_utils::BufferSizeScalarName(dataset()->node_name()),
              static_cast<float>(buffer_.size()), num_elements());
          stats_aggregator->AddScalar(
              stats_utils::BufferCapacityScalarName(dataset()->node_name()),
              static_cast<float>(buffer_limit()), num_elements());
        }
        // Release mu_
      }
      return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
    }

   protected:
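    // Reports this iterator to the tf.data autotuning model as an
    // asynchronous node with a 1:1 input/output ratio whose `buffer_size`
    // parameter can be tuned between `buffer_size_min_` and INT64_MAX.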
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeAsyncKnownRatioNode(
          std::move(args),
          /*ratio=*/1,
          {model::MakeParameter(kBufferSize, buffer_size_,
                                /*min=*/buffer_size_min_,
                                /*max=*/std::numeric_limits<int64>::max())});
    }

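    // Checkpoints the input iterator together with the contents of `buffer_`.
    // Each buffered element `i` is written under the key prefix
    // "<iterator prefix>::<i>" as a serialized status plus, when the status
    // is OK, a "buffer.size" scalar and one "buffer[j]" tensor per component.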
    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      // Acquire both locks to ensure that the prefetch thread and
      // all GetNext threads are blocked.
      mutex_lock input_l(input_mu_);
      mutex_lock l(*mu_);
      TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(prefix(), kBufferSize, buffer_.size()));
      for (size_t i = 0; i < buffer_.size(); i++) {
        auto& buffer_element = buffer_[i];
        TF_RETURN_IF_ERROR(WriteStatus(writer, i, buffer_element.status));
        if (buffer_element.status.ok()) {
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              absl::StrCat(prefix(), "::", i),
              absl::StrCat(kBuffer, kSizeSuffix), buffer_element.value.size()));
          for (size_t j = 0; j < buffer_element.value.size(); j++) {
            TF_RETURN_IF_ERROR(writer->WriteTensor(
                absl::StrCat(prefix(), "::", i),
                absl::StrCat(kBuffer, "[", j, "]"), buffer_element.value[j]));
          }
        }
      }
      return Status::OK();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock input_l(input_mu_);
      mutex_lock l(*mu_);
      DCHECK(buffer_.empty());
      TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
      size_t buffer_size;
      {
        int64_t temp;
        TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kBufferSize, &temp));
        buffer_size = static_cast<size_t>(temp);
      }
      for (size_t i = 0; i < buffer_size; i++) {
        buffer_.emplace_back();
        auto& buffer_element = buffer_.back();
        TF_RETURN_IF_ERROR(ReadStatus(reader, i, &buffer_element.status));
        if (buffer_element.status.ok()) {
          size_t value_size;
          {
            int64_t temp;
            TF_RETURN_IF_ERROR(
                reader->ReadScalar(absl::StrCat(prefix(), "::", i),
                                   absl::StrCat(kBuffer, kSizeSuffix), &temp));
            value_size = static_cast<size_t>(temp);
          }
          buffer_element.value.reserve(value_size);
          for (size_t j = 0; j < value_size; j++) {
            buffer_element.value.emplace_back();
            TF_RETURN_IF_ERROR(
                reader->ReadTensor(ctx->flr(), absl::StrCat(prefix(), "::", i),
                                   absl::StrCat(kBuffer, "[", j, "]"),
                                   &buffer_element.value.back()));
          }
        }
        RecordBufferEnqueue(ctx, buffer_element.value);
      }
      return Status::OK();
    }

    data::TraceMeMetadata GetTraceMeMetadata() const override {
      int64_t limit = -1, size = -1;
      data::TraceMeMetadata result;
      // NOTE: We only fill in the buffer limit and size if the lock can be
      // acquired right away, to avoid introducing tracing overhead.
      if (mu_->try_lock()) {
        limit = buffer_limit();
        size = buffer_.size();
        if (!buffer_.empty()) {
          std::vector<std::string> shapes;
          shapes.reserve(buffer_.front().value.size());
          for (const auto& component : buffer_.front().value) {
            shapes.push_back(component.shape().DebugString());
          }
          result.push_back(std::make_pair("next_element_shapes",
                                          absl::StrJoin(shapes, ",")));
        }
        mu_->unlock();
      }
      result.push_back(std::make_pair(
          "buffer_limit",
          limit == -1
              ? kTraceInfoUnavailable
              : strings::Printf("%lld", static_cast<long long>(limit))));
      result.push_back(std::make_pair(
          "buffer_size",
          size == -1 ? kTraceInfoUnavailable
                     : strings::Printf("%lld", static_cast<long long>(size))));
      result.push_back(std::make_pair(
          "autotune",
          dataset()->buffer_size_ == model::kAutotune ? "true" : "false"));
      result.push_back(std::make_pair(
          "autotune_mode", legacy_autotune_ ? "legacy" : "performance"));
      if (dataset()->slack_period_ > 0) {
        result.push_back(std::make_pair(
            "slack",
            strings::Printf("%lld", static_cast<long long>(slack_us_.load()))));
      }
      return result;
    }

   private:
    // A buffer element comprises a status and (if that status is
    // OK) a vector of tensors, representing an element of the input dataset.
    struct BufferElement {
      BufferElement() : uid(tensorflow::EnvTime::NowNanos()) {}

      // The producer sets `status` if getting the input element fails.
      Status status;
      // The buffered data element.
      std::vector<Tensor> value;
      int64 created_us;
      const uint64 uid;
    };

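    // Returns the maximum number of elements that may be buffered: the legacy
    // autotuner's current limit when `legacy_autotune_` is set, otherwise the
    // shared (possibly autotuned) `buffer_size_` parameter.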
    int64 buffer_limit() const TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      if (legacy_autotune_) {
        return auto_tuner_.buffer_limit();
      }
      return buffer_size_->value;
    }

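    // Requests cancellation of the input pipeline and wakes up all threads
    // blocked on `cond_var_` so that they can observe `cancelled_`.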
    void CancelThreads() TF_LOCKS_EXCLUDED(mu_) {
      cancellation_manager_->StartCancel();
      mutex_lock l(*mu_);
      cancelled_ = true;
      cond_var_->notify_all();
    }

    Status Consume(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                   bool* end_of_sequence) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      const auto& stats_aggregator = ctx->stats_aggregator();
      if (stats_aggregator) {
        double buffer_limit_ = buffer_limit();
        stats_aggregator->AddToHistogram(
            stats_utils::BufferUtilizationHistogramName(dataset()->node_name()),
            {static_cast<float>(buffer_.size()) /
             static_cast<float>(buffer_limit_)},
            num_elements());
        stats_aggregator->AddScalar(
            stats_utils::BufferSizeScalarName(dataset()->node_name()),
            static_cast<float>(buffer_.size()), num_elements());
        stats_aggregator->AddScalar(
            stats_utils::BufferCapacityScalarName(dataset()->node_name()),
            static_cast<float>(buffer_limit_), num_elements());
      }
      // A new element is available. Forward the status from computing it, and
      // (if we successfully got an element) the output values.
      Status s = buffer_.front().status;
      if (s.ok()) {
        int64_t buffer_element_id = buffer_.front().uid;
        profiler::TraceMe traceme(
            [&] {
              return profiler::TraceMeEncode(
                  "PrefetchConsume", {{"element_id", buffer_element_id}});
            },
            profiler::kInfo);
        if (dataset()->slack_period_ > 0 &&
            (num_elements() + 1) % dataset()->slack_period_ == 0) {
          // TODO(rachelim): Consider doing something more sophisticated
          // to decide how long to sleep for; e.g. using a kalman filter.
          int64_t slack_us = EnvTime::NowMicros() - buffer_.front().created_us;
          // Every slack_period_-th element, update the most recent slack time,
          // measured by the duration between when the element is prefetched
          // and when it is consumed. We add kSleepFactor * slack_us_ to the
          // measurement because we slept for that duration before prefetching
          // the element.
          slack_us_ = kSleepFactor * slack_us_ + slack_us;
          VLOG(2) << "Setting slack_us_: " << slack_us_;
        }
        *out_tensors = std::move(buffer_.front().value);
        RecordBufferDequeue(ctx, *out_tensors);
      } else {
        // If status not ok, we still record the dequeue event to make sure
        // each enqueue event is paired with a dequeue event even in the
        // presence of errors.
        RecordBufferDequeue(ctx, buffer_.front().value);
      }
      if (legacy_autotune_) {
        auto_tuner_.RecordConsumption(buffer_.size());
        buffer_size_->value = auto_tuner_.buffer_limit();
      }
      buffer_.pop_front();
      *end_of_sequence = false;

      // Wake the prefetch thread, in case it has been waiting for space
      // in the buffer. Also wake up threads from other calls to GetNext.
      //
      // TODO(mrry): Consider using different condition variables for
      // GetNext and Prefetch.
      cond_var_->notify_all();
      return s;
    }

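    // Lazily starts the background prefetch thread on the first call to
    // `GetNextInternal`, handing it its own copy of the iterator context.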
    Status EnsurePrefetchThreadStarted(IteratorContext* ctx)
        TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      if (!prefetch_thread_) {
        std::shared_ptr<IteratorContext> new_ctx =
            std::make_shared<IteratorContext>(*ctx);
        prefetch_thread_ = ctx->StartThread(
            "tf_data_prefetch", [this, new_ctx]() { PrefetchThread(new_ctx); });
      }
      return Status::OK();
    }

    // Prefetches elements of the input, storing results in an internal buffer.
    //
    // It owns the iterator context passed to it.
    void PrefetchThread(const std::shared_ptr<IteratorContext>& ctx) {
      RecordStart(ctx.get());
      auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
      // Keep track of where we are in an iteration "burst"
      int num_produced = 0;
      while (true) {
        // 1. Wait for a slot in the buffer.
        {
          mutex_lock l(*mu_);
          while (!cancelled_ && buffer_.size() >= buffer_limit()) {
            RecordStop(ctx.get());
            cond_var_->wait(l);
            RecordStart(ctx.get());
          }

          if (cancelled_) {
            prefetch_thread_finished_ = true;
            cond_var_->notify_all();
            return;
          }
        }

        if (dataset()->slack_period_ > 0 &&
            num_produced % dataset()->slack_period_ == 0) {
          // For the first element in the "burst", sleep for a bit if there is
          // slack.
          VLOG(2) << "Sleeping for: " << slack_us_ * kSleepFactor;
          ctx->env()->SleepForMicroseconds(slack_us_ * kSleepFactor);
        }

        // 2. Read the next element.
        // Acquire the input mutex since we will be reading an element from the
        // input iterator. Note that we do not wish to release this mutex till
        // we have added the fetched element to the `buffer_` else there will
        // be local state that may be missed by SaveInternal.
        mutex_lock input_l(input_mu_);
        bool end_of_sequence;
        BufferElement buffer_element;
        {
          profiler::TraceMe traceme(
              [&] {
                return profiler::TraceMeEncode(
                    "PrefetchProduce", {{"element_id", buffer_element.uid}});
              },
              profiler::kInfo);
          buffer_element.status = input_impl_->GetNext(
              ctx.get(), &buffer_element.value, &end_of_sequence);
        }
        if (buffer_element.status.ok() && end_of_sequence) {
          mutex_lock l(*mu_);
          prefetch_thread_finished_ = true;
          cond_var_->notify_all();
          return;
        }

        // 3. Signal that the element has been produced.
        {
          mutex_lock l(*mu_);
          RecordBufferEnqueue(ctx.get(), buffer_element.value);
          buffer_element.created_us = EnvTime::NowMicros();
          buffer_.push_back(std::move(buffer_element));
          cond_var_->notify_all();
        }
        ++num_produced;
      }
    }

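    // Serializes the status of buffer element `index` as its error code plus,
    // for non-OK statuses, the error message; `ReadStatus` restores it below.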
    Status WriteStatus(IteratorStateWriter* writer, size_t index,
                       const Status& status) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(absl::StrCat(prefix(), "::", index), CodeKey(),
                              static_cast<int64>(status.code())));
      if (!status.ok()) {
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(absl::StrCat(prefix(), "::", index),
                                ErrorMessageKey(), status.error_message()));
      }
      return Status::OK();
    }

    Status ReadStatus(IteratorStateReader* reader, size_t index, Status* status)
        TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      int64_t code_int;
      TF_RETURN_IF_ERROR(reader->ReadScalar(absl::StrCat(prefix(), "::", index),
                                            CodeKey(), &code_int));
      error::Code code = static_cast<error::Code>(code_int);

      if (code != error::Code::OK) {
        tstring error_message;
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(absl::StrCat(prefix(), "::", index),
                               ErrorMessageKey(), &error_message));
        *status = Status(code, error_message);
      } else {
        *status = Status::OK();
      }
      return Status::OK();
    }

    string CodeKey() { return absl::StrCat(kStatus, kCodeSuffix); }

    string ErrorMessageKey() {
      return absl::StrCat(kStatus, kErrorMessageSuffix);
    }

    // This mutex is used to ensure exclusivity between multiple threads
    // reading/writing this iterator's local state.
    //
    // NOTE: We should never call GetNext on the input while holding this
    // mutex.
    const std::shared_ptr<mutex> mu_;
    // This mutex is used to ensure exclusivity between multiple threads
    // accessing the input iterator. We keep this separate from `mu_` to allow
    // prefetching to run in parallel with GetNext calls.
    mutex input_mu_ TF_ACQUIRED_BEFORE(*mu_);
    // Controls cancellation of `input_impl_`. Must be ordered before
    // `input_impl_` so that `input_impl_` is destroyed first.
    std::unique_ptr<CancellationManager> cancellation_manager_;
    std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(input_mu_);
    const std::shared_ptr<condition_variable> cond_var_;
    const int64 buffer_size_min_;
    PrefetchAutotuner auto_tuner_ TF_GUARDED_BY(*mu_);
    std::deque<BufferElement> buffer_ TF_GUARDED_BY(*mu_);
    std::unique_ptr<Thread> prefetch_thread_ TF_GUARDED_BY(*mu_);
    bool cancelled_ TF_GUARDED_BY(*mu_) = false;
    bool prefetch_thread_finished_ TF_GUARDED_BY(*mu_) = false;
    const bool legacy_autotune_;

    std::atomic<int64> slack_us_;

    // If legacy_autotune_ is false, identifies the maximum size of the buffer.
    const std::shared_ptr<model::SharedState> buffer_size_;

    // Method for deregistering the cancellation callback.
    std::function<void()> deregister_fn_;
  };
  const DatasetBase* const input_;
  const int64 buffer_size_;

  // If non-zero, determines the period between injecting "slack" into the
  // execution.
  const int64 slack_period_;

  // Determines whether legacy autotuning should be used.
  const bool legacy_autotune_ = true;

  // If autotune is enabled, determines the minimal value of `buffer_size`
  // parameter.
  const int64 buffer_size_min_ = 0;

  TraceMeMetadata traceme_metadata_;
};

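// Reads the optional attrs (`slack_period`, `legacy_autotune`,
// `buffer_size_min`) when they are present; graphs produced by older versions
// of TensorFlow may omit them, in which case the members keep their default
// values.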
PrefetchDatasetOp::PrefetchDatasetOp(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx) {
  if (ctx->HasAttr(kSlackPeriod)) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kSlackPeriod, &slack_period_));
  }
  if (ctx->HasAttr(kLegacyAutotune)) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kLegacyAutotune, &legacy_autotune_));
  }
  if (ctx->HasAttr(kBufferSizeMin)) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kBufferSizeMin, &buffer_size_min_));
  }
}

void PrefetchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                    DatasetBase** output) {
  int64_t buffer_size = 0;
  OP_REQUIRES_OK(ctx,
                 ParseScalarArgument<int64>(ctx, kBufferSize, &buffer_size));
  OP_REQUIRES(ctx, buffer_size >= 0 || buffer_size == model::kAutotune,
              errors::InvalidArgument("buffer_size must be >= 0 or set "
                                      "buffer_size to be ",
                                      model::kAutotune, " for auto-tuning"));

  if (buffer_size == model::kAutotune) {
    metrics::RecordTFDataAutotune(kDatasetType);
  }

  *output = new Dataset(ctx, input, buffer_size, slack_period_,
                        legacy_autotune_, buffer_size_min_);
}

namespace {
REGISTER_KERNEL_BUILDER(Name("PrefetchDataset").Device(DEVICE_CPU).Priority(2),
                        PrefetchDatasetOp);
REGISTER_KERNEL_BUILDER(Name("PrefetchDataset")
                            .Device(DEVICE_GPU)
                            .HostMemory("buffer_size")
                            .HostMemory("input_dataset")
                            .HostMemory("handle")
                            .Priority(1),
                        PrefetchDatasetOp);
}  // namespace

}  // namespace data
}  // namespace tensorflow