1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/core/kernels/data/shuffle_dataset_op.h"

#include <atomic>
#include <cstdint>
#include <deque>
#include <memory>
#include <numeric>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/data/serialization_utils.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/random_seed_ops.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/stringprintf.h"
38
39 namespace tensorflow {
40 namespace data {
41
42 // See documentation in ../../ops/dataset_ops.cc for a high-level
43 // description of the following op.
44
45 /* static */ constexpr const char* const ShuffleDatasetOpBase::kInputDataset;
46 /* static */ constexpr const char* const ShuffleDatasetOpBase::kBufferSize;
47 /* static */ constexpr const char* const ShuffleDatasetOpBase::kSeed;
48 /* static */ constexpr const char* const ShuffleDatasetOpBase::kSeed2;
49 /* static */ constexpr const char* const ShuffleDatasetOpBase::kOutputTypes;
50 /* static */ constexpr const char* const ShuffleDatasetOpBase::kOutputShapes;
51 /* static */ constexpr const char* const
52 ShuffleDatasetOpBase::kReshuffleEachIteration;
53
54 /* static */ constexpr const char* const ShuffleDatasetOp::kDatasetType;
55
56 /* static */ constexpr const char* const
57 ShuffleAndRepeatDatasetOp::kDatasetType;
58 /* static */ constexpr const char* const ShuffleAndRepeatDatasetOp::kCount;
59
// How often (in microseconds) progress is logged while the shuffle buffer is
// being filled.
constexpr int64_t kLogIntervalMicros = 10 * 1000000;  // 10 seconds.
// Once buffered elements span more than this many epochs, the iterator stops
// filling the buffer to bound memory use.
constexpr int64_t kMaxEpochsInBuffer = 3;

// Keys under which iterator state is written to / read from checkpoints.
constexpr char kNumRandomSamples[] = "num_random_samples";
constexpr char kEpochNumRandomSamples[] = "epoch_num_random_samples";
constexpr char kDataProduced[] = "data_produced";
constexpr char kEndOfInputSequence[] = "end_of_input_sequence";
constexpr char kEpoch[] = "epoch";
constexpr char kNumElements[] = "num_elements";
constexpr char kSlicesSize[] = "slices_size";
constexpr char kSlicesStart[] = "slices_start";
constexpr char kSlicesEnd[] = "slices_end";

// Component used when naming per-kernel seed generator resources.
constexpr char kSeedGenerator[] = "SeedGenerator";

// Registered op names, used to select the op version.
constexpr char kShuffleDatasetV1[] = "ShuffleDataset";
constexpr char kShuffleDatasetV2[] = "ShuffleDatasetV2";
constexpr char kShuffleDatasetV3[] = "ShuffleDatasetV3";
constexpr char kShuffleAndRepeatDatasetV1[] = "ShuffleAndRepeatDataset";
constexpr char kShuffleAndRepeatDatasetV2[] = "ShuffleAndRepeatDatasetV2";
78
ShuffleDatasetOpBase(OpKernelConstruction * ctx)79 ShuffleDatasetOpBase::ShuffleDatasetOpBase(OpKernelConstruction* ctx)
80 : UnaryDatasetOpKernel(ctx) {}
81
82 // Abstract base dataset that implements a shuffling iterator.
83 class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
84 public:
ShuffleDatasetBase(OpKernelContext * ctx,const DatasetBase * input,int64_t buffer_size,std::shared_ptr<SeedGenerator> seed_generator,int64_t count)85 ShuffleDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
86 int64_t buffer_size,
87 std::shared_ptr<SeedGenerator> seed_generator,
88 int64_t count)
89 : DatasetBase(DatasetContext(ctx)),
90 input_(input),
91 buffer_size_(buffer_size),
92 seed_generator_(std::move(seed_generator)),
93 count_(count),
94 traceme_metadata_(
95 {{"buffer_size",
96 strings::Printf("%lld", static_cast<long long>(buffer_size))}}) {
97 input_->Ref();
98 }
99
~ShuffleDatasetBase()100 ~ShuffleDatasetBase() override { input_->Unref(); }
101
102 virtual string op_type() const = 0;
103
output_dtypes() const104 const DataTypeVector& output_dtypes() const override {
105 return input_->output_dtypes();
106 }
107
output_shapes() const108 const std::vector<PartialTensorShape>& output_shapes() const override {
109 return input_->output_shapes();
110 }
111
CardinalityInternal() const112 int64_t CardinalityInternal() const override {
113 if (count_ == -1 || input_->Cardinality() == kInfiniteCardinality) {
114 return kInfiniteCardinality;
115 } else if (input_->Cardinality() == kUnknownCardinality) {
116 return kUnknownCardinality;
117 } else {
118 return input_->Cardinality() * count_;
119 }
120 }
121
CardinalityInternal(CardinalityOptions options) const122 int64_t CardinalityInternal(CardinalityOptions options) const override {
123 if (count_ == -1 || input_->Cardinality(options) == kInfiniteCardinality) {
124 return kInfiniteCardinality;
125 } else if (input_->Cardinality(options) == kUnknownCardinality) {
126 return kUnknownCardinality;
127 } else {
128 return input_->Cardinality(options) * count_;
129 }
130 }
131
InputDatasets(std::vector<const DatasetBase * > * inputs) const132 Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
133 inputs->push_back(input_);
134 return OkStatus();
135 }
136
CheckExternalState() const137 Status CheckExternalState() const override {
138 return input_->CheckExternalState();
139 }
140
Get(OpKernelContext * ctx,int64 index,std::vector<Tensor> * out_tensors) const141 Status Get(OpKernelContext* ctx, int64 index,
142 std::vector<Tensor>* out_tensors) const override {
143 TF_RETURN_IF_ERROR(CheckRandomAccessCompatible(index));
144 {
145 mutex_lock l(mu_);
146 if (shuffled_indices_.empty()) {
147 InitializeRandomAccessIndices();
148 }
149 }
150 int64 shuffled_index;
151 {
152 tf_shared_lock l(mu_);
153 shuffled_index = shuffled_indices_[index];
154 }
155 TF_RETURN_IF_ERROR(input_->Get(ctx, shuffled_index, out_tensors));
156 return OkStatus();
157 }
158
DebugString() const159 string DebugString() const override {
160 name_utils::DatasetDebugStringParams params;
161 params.set_args(buffer_size_, seed_generator_->seed(),
162 seed_generator_->seed2(), count_);
163 return name_utils::DatasetDebugString(op_type(), params);
164 }
165
MakeIteratorInternal(const string & prefix) const166 std::unique_ptr<IteratorBase> MakeIteratorInternal(
167 const string& prefix) const override {
168 return std::make_unique<Iterator>(
169 Iterator::Params{this, name_utils::IteratorPrefix(op_type(), prefix)},
170 seed_generator_.get());
171 }
172
InitializeRandomAccessIndices() const173 void InitializeRandomAccessIndices() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
174 const int64 cardinality = Cardinality();
175 shuffled_indices_ = std::vector<std::int64_t>(cardinality);
176 std::iota(shuffled_indices_.begin(), shuffled_indices_.end(), 0);
177 int64_t shuffled_index = 0;
178 random::PhiloxRandom parent_generator =
179 random::PhiloxRandom(seed_generator_->seed(), seed_generator_->seed2());
180 random::SingleSampleAdapter<random::PhiloxRandom> generator =
181 random::SingleSampleAdapter<random::PhiloxRandom>(&parent_generator);
182
183 while (shuffled_index < cardinality) {
184 int64_t offset = generator() % (cardinality - shuffled_index);
185 std::swap(shuffled_indices_[shuffled_index + offset],
186 shuffled_indices_[shuffled_index]);
187 shuffled_index += 1;
188 }
189 }
190
191 protected:
192 class Iterator : public DatasetIterator<ShuffleDatasetBase> {
193 public:
Iterator(const Params & params,SeedGenerator * seed_generator)194 explicit Iterator(const Params& params, SeedGenerator* seed_generator)
195 : DatasetIterator<ShuffleDatasetBase>(params),
196 seed_generator_(seed_generator),
197 parent_generator_(seed_generator->seed(), seed_generator->seed2()),
198 generator_(&parent_generator_) {
199 buffer_ = std::make_unique<std::vector<std::vector<Tensor>>>(
200 params.dataset->buffer_size_);
201 }
202
Initialize(IteratorContext * ctx)203 Status Initialize(IteratorContext* ctx) override {
204 mutex_lock l(mu_);
205 seed_generator_->GenerateSeeds(&seed_, &seed2_);
206 ResetRngs();
207 return OkStatus();
208 }
209
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)210 Status GetNextInternal(IteratorContext* ctx,
211 std::vector<Tensor>* out_tensors,
212 bool* end_of_sequence) override {
213 mutex_lock l(mu_);
214 TF_RETURN_IF_ERROR(FillBuffer(ctx));
215 if (num_elements_ == 0) {
216 DCHECK(input_impl_ == nullptr);
217 *end_of_sequence = true;
218 return OkStatus();
219 }
220
221 *end_of_sequence = false;
222 ClearEmptySlices();
223 DCHECK(!slices_.empty());
224 // Choose an element to produce uniformly at random from the first
225 // slice, and then remove the element from the slice.
226 int64_t offset =
227 Random() % (slices_.front()->end - slices_.front()->start);
228 int64_t index = (slices_.front()->start + offset) % buffer_->size();
229 *out_tensors = std::move(buffer_->at(index));
230 this->RecordBufferDequeue(ctx, *out_tensors);
231 std::swap(buffer_->at(index),
232 buffer_->at(slices_.front()->start % buffer_->size()));
233 slices_.front()->start++;
234 num_elements_--;
235 return OkStatus();
236 }
237
238 protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const239 std::shared_ptr<model::Node> CreateNode(
240 IteratorContext* ctx, model::Node::Args args) const override {
241 return model::MakeKnownRatioNode(std::move(args),
242 /*ratio=*/1);
243 }
244
ResetRngs()245 void ResetRngs() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
246 // Reset the generators based on the current iterator seeds.
247 parent_generator_ = random::PhiloxRandom(seed_, seed2_);
248 generator_ =
249 random::SingleSampleAdapter<random::PhiloxRandom>(&parent_generator_);
250 generator_.Skip(num_random_samples_);
251 }
252
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)253 Status SaveInternal(SerializationContext* ctx,
254 IteratorStateWriter* writer) override {
255 mutex_lock l(mu_);
256 // Save state needed to restore the random number generators.
257 TF_RETURN_IF_ERROR(
258 writer->WriteScalar(full_name(kEpochNumRandomSamples),
259 seed_generator_->num_random_samples()));
260 TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name(kNumRandomSamples),
261 num_random_samples_));
262 TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name(kSeed), seed_));
263 TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name(kSeed2), seed2_));
264
265 // Save input iterator if it hasn't been exhausted else write
266 // "end_of_input_sequence".
267 if (!input_impl_) {
268 TF_RETURN_IF_ERROR(
269 writer->WriteScalar(this->full_name(kEndOfInputSequence), ""));
270 } else {
271 TF_RETURN_IF_ERROR(this->SaveInput(ctx, writer, input_impl_));
272 }
273
274 // Save the epoch counter, buffer, and buffer slices.
275 TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name(kEpoch), epoch_));
276 TF_RETURN_IF_ERROR(
277 writer->WriteScalar(this->full_name(kNumElements), num_elements_));
278 TF_RETURN_IF_ERROR(WriteElementsToCheckpoint(writer, prefix(), *buffer_));
279 TF_RETURN_IF_ERROR(
280 writer->WriteScalar(this->full_name(kSlicesSize), slices_.size()));
281 for (size_t i = 0; i < slices_.size(); ++i) {
282 TF_RETURN_IF_ERROR(
283 writer->WriteScalar(this->full_name(absl::StrJoin(
284 std::make_tuple(kSlicesStart, i), "_")),
285 slices_[i]->start));
286 TF_RETURN_IF_ERROR(writer->WriteScalar(
287 this->full_name(absl::StrJoin(std::make_tuple(kSlicesEnd, i), "_")),
288 slices_[i]->end));
289 }
290 if (data_produced_) {
291 TF_RETURN_IF_ERROR(
292 writer->WriteScalar(this->full_name(kDataProduced), ""));
293 }
294
295 return OkStatus();
296 }
297
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)298 Status RestoreInternal(IteratorContext* ctx,
299 IteratorStateReader* reader) override {
300 mutex_lock l(mu_);
301 // Restore the random number generators.
302 int64_t num_random_samples;
303 TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kEpochNumRandomSamples),
304 &num_random_samples));
305 seed_generator_->set_num_random_samples(num_random_samples);
306 seed_generator_->Reset();
307 TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name(kNumRandomSamples),
308 &num_random_samples_));
309 TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name(kSeed), &seed_));
310 TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name(kSeed2), &seed2_));
311 ResetRngs();
312
313 // Restore the input iterator if it wasn't already exhausted.
314 if (!reader->Contains(this->full_name(kEndOfInputSequence))) {
315 TF_RETURN_IF_ERROR(this->dataset()->input_->MakeIterator(
316 ctx, this, this->prefix(), &input_impl_));
317 TF_RETURN_IF_ERROR(this->RestoreInput(ctx, reader, input_impl_));
318 } else {
319 input_impl_.reset();
320 }
321
322 // Restore the epoch counter, buffer, and buffer slices.
323 TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name(kEpoch), &epoch_));
324 TF_RETURN_IF_ERROR(
325 reader->ReadScalar(this->full_name(kNumElements), &num_elements_));
326 size_t slices_size;
327 {
328 int64_t temp;
329 TF_RETURN_IF_ERROR(
330 reader->ReadScalar(this->full_name(kSlicesSize), &temp));
331 slices_size = static_cast<size_t>(temp);
332 }
333 buffer_ = std::make_unique<std::vector<std::vector<Tensor>>>();
334 TF_RETURN_IF_ERROR(
335 ReadElementsFromCheckpoint(ctx, reader, prefix(), buffer_.get()));
336 for (const auto& element : *buffer_) {
337 RecordBufferEnqueue(ctx, element);
338 }
339 buffer_->resize(dataset()->buffer_size_);
340 slices_.clear();
341 for (size_t i = 0; i < slices_size; ++i) {
342 int64_t start;
343 TF_RETURN_IF_ERROR(
344 reader->ReadScalar(this->full_name(absl::StrJoin(
345 std::make_tuple(kSlicesStart, i), "_")),
346 &start));
347 int64_t end;
348 TF_RETURN_IF_ERROR(reader->ReadScalar(
349 this->full_name(absl::StrJoin(std::make_tuple(kSlicesEnd, i), "_")),
350 &end));
351 slices_.push_back(std::make_unique<Slice>(start, end));
352 }
353 data_produced_ = reader->Contains(this->full_name(kDataProduced));
354
355 return OkStatus();
356 }
357
GetTraceMeMetadata() const358 TraceMeMetadata GetTraceMeMetadata() const override {
359 return this->dataset()->traceme_metadata_;
360 }
361
362 private:
363 // Used to represent slices of `buffer_` that belong to different epochs.
364 // The invariant maintained by the implementation is: `start` <= `end`.
365 // When using `start` and `end` to index into `buffer_`, their values
366 // should be taken modulo the size of `buffer_` as their absolute value
367 // can be greater than the range of `buffer_`.
368 struct Slice {
Slicetensorflow::data::ShuffleDatasetOpBase::ShuffleDatasetBase::Iterator::Slice369 Slice(int64_t start, int64_t end) : start(start), end(end) {}
370
371 int64_t start;
372 int64_t end;
373 };
374
Random()375 random::SingleSampleAdapter<random::PhiloxRandom>::ResultType Random()
376 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
377 num_random_samples_++;
378 auto out = generator_();
379 return out;
380 }
381
382 // Fills the shuffle buffer, preparing the buffer for sampling.
FillBuffer(IteratorContext * ctx)383 Status FillBuffer(IteratorContext* ctx) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
384 int64_t start_micros = EnvTime::NowMicros();
385 int64_t num_log_entries = 0;
386 while (ShouldFillBuffer()) {
387 if (EnvTime::NowMicros() >
388 ((num_log_entries + 1) * kLogIntervalMicros) + start_micros) {
389 num_log_entries++;
390 LOG(INFO) << "Filling up shuffle buffer (this may take a while): "
391 << num_elements_ << " of " << BufferSizeString();
392 }
393 if (!input_impl_) {
394 TF_RETURN_IF_ERROR(PrepareNextEpoch(ctx));
395 }
396 std::vector<Tensor> input_element;
397 bool end_of_input_sequence = false;
398 TF_RETURN_IF_ERROR(
399 input_impl_->GetNext(ctx, &input_element, &end_of_input_sequence));
400 if (!end_of_input_sequence) {
401 AddToShuffleBuffer(ctx, std::move(input_element));
402 continue;
403 }
404 input_impl_.reset();
405 // Reached end of input_impl_.
406 if (ctx->split_providers().empty() && !data_produced_ &&
407 this->dataset()->count_ == -1) {
408 // If we encounter the end of sequence without producing data, we
409 // terminate the iteration immediately. (Otherwise, this iterator
410 // would loop infinitely and never produce a value.)
411 return OkStatus();
412 }
413 }
414 if (num_log_entries > 0) {
415 LOG(INFO) << "Shuffle buffer filled.";
416 }
417 return OkStatus();
418 }
419
ShouldFillBuffer()420 bool ShouldFillBuffer() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
421 if (!input_impl_ && dataset()->count_ != -1 &&
422 epoch_ >= dataset()->count_) {
423 return false;
424 }
425 if (slices_.size() > kMaxEpochsInBuffer && num_elements_ > 0) {
426 // When the elements stored in `buffer_` span more than
427 // `kMaxEpochsInBuffer` epochs, we do not fill the buffer further to
428 // conserve memory. This means that the upper bound on the size of
429 // `buffer_` is `kMaxEpochsInBuffer * cardinality(input_dataset) +
430 // 1`.
431 return false;
432 }
433 return num_elements_ < buffer_->size();
434 }
435
PrepareNextEpoch(IteratorContext * ctx)436 Status PrepareNextEpoch(IteratorContext* ctx)
437 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
438 if (epoch_ == 0) {
439 slices_.push_back(std::make_unique<Slice>(0, 0));
440 } else {
441 int64_t n = slices_.back()->end;
442 slices_.push_back(std::make_unique<Slice>(n, n));
443 for (const auto& provider : ctx->split_providers()) {
444 TF_RETURN_IF_ERROR(provider->Reset());
445 }
446 }
447 TF_RETURN_IF_ERROR(this->dataset()->input_->MakeIterator(
448 ctx, this, this->prefix(), &input_impl_));
449 epoch_++;
450 return OkStatus();
451 }
452
AddToShuffleBuffer(IteratorContext * ctx,std::vector<Tensor> && element)453 void AddToShuffleBuffer(IteratorContext* ctx, std::vector<Tensor>&& element)
454 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
455 data_produced_ = true;
456 if (num_elements_ == 0) {
457 VLOG(1) << "Starting to fill up shuffle buffer of size: "
458 << BufferSizeString();
459 }
460 this->RecordBufferEnqueue(ctx, element);
461 size_t index = slices_.back()->end % buffer_->size();
462 buffer_->at(index) = std::move(element);
463 num_elements_++;
464 slices_.back()->end++;
465 }
466
ClearEmptySlices()467 void ClearEmptySlices() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
468 // Garbage collect all empty slices.
469 while (slices_.front()->start == slices_.front()->end) {
470 slices_.pop_front();
471 // Reinitialize the RNG state for the next epoch.
472 num_random_samples_ = 0;
473 seed_generator_->GenerateSeeds(&seed_, &seed2_);
474 ResetRngs();
475 }
476 }
477
BufferSizeString()478 std::string BufferSizeString() {
479 return absl::StrCat(dataset()->buffer_size_);
480 }
481
482 mutex mu_;
483 SeedGenerator* const seed_generator_ TF_GUARDED_BY(mu_); // Not owned.
484 std::unique_ptr<std::vector<std::vector<Tensor>>> buffer_
485 TF_GUARDED_BY(mu_);
486 std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_) = nullptr;
487 int64_t epoch_ TF_GUARDED_BY(mu_) = 0;
488 int64_t num_elements_ TF_GUARDED_BY(mu_) = 0;
489 int64_t seed_ TF_GUARDED_BY(mu_) = 0;
490 int64_t seed2_ TF_GUARDED_BY(mu_) = 0;
491 // Indices into `buffer_` indicating which data belongs to which epoch.
492 // The slice at the front of the deque references data from the earliest
493 // buffered epoch. It is an invariant that all slices reference
494 // non-overlapping sections of `buffer_`.
495 std::deque<std::unique_ptr<Slice>> slices_ TF_GUARDED_BY(mu_);
496 random::PhiloxRandom parent_generator_ TF_GUARDED_BY(mu_);
497 random::SingleSampleAdapter<random::PhiloxRandom> generator_
498 TF_GUARDED_BY(mu_);
499 int64_t num_random_samples_ TF_GUARDED_BY(mu_) = 0;
500 bool data_produced_ TF_GUARDED_BY(mu_) = false;
501 };
502
503 const DatasetBase* const input_;
504 const int64_t buffer_size_;
505 const std::shared_ptr<SeedGenerator> seed_generator_;
506 // The number of epochs to run for. Normally this is just 1, but sometimes we
507 // fuse shuffle and repeat together, and make the shuffle dataset op
508 // responsible for repeating as well.
509 const int64_t count_;
510 const TraceMeMetadata traceme_metadata_;
511 mutable mutex mu_;
512 mutable std::vector<std::int64_t> shuffled_indices_ TF_GUARDED_BY(mu_);
513 }; // ShuffleDatasetBase
514
515 // This version of memory dataset has an exclusive ownership of the seed
516 // generator resource. It supports sharing of the seed generator across
517 // different iterations of the `repeat` transformation but not across different
518 // iterators.
519 class ShuffleDatasetOp::Dataset : public ShuffleDatasetBase {
520 public:
Dataset(OpKernelContext * ctx,const DatasetBase * input,int64_t buffer_size,int64_t count,RandomSeeds && seeds,SeedGeneratorManager * manager,ResourceHandle && resource_handle)521 Dataset(OpKernelContext* ctx, const DatasetBase* input, int64_t buffer_size,
522 int64_t count, RandomSeeds&& seeds, SeedGeneratorManager* manager,
523 ResourceHandle&& resource_handle)
524 : ShuffleDatasetBase(ctx, input, buffer_size, manager->get(), count),
525 manager_(manager),
526 resource_handle_(std::move(resource_handle)),
527 resource_mgr_(ctx->resource_manager()),
528 seeds_(std::move(seeds)) {}
529
~Dataset()530 ~Dataset() override {
531 manager_->Unref();
532 Status s = resource_mgr_->Delete<SeedGeneratorManager>(
533 resource_handle_.container(), resource_handle_.name());
534 if (!s.ok()) {
535 LOG(WARNING) << "Failed to delete RNG resource: " << s.ToString();
536 }
537 }
538
op_type() const539 string op_type() const override { return kDatasetType; }
540
541 protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const542 Status AsGraphDefInternal(SerializationContext* ctx,
543 DatasetGraphDefBuilder* b,
544 Node** output) const override {
545 Node* input_graph_node = nullptr;
546 TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
547 Node* buffer_size_node = nullptr;
548 Node* seed_node = nullptr;
549 Node* seed2_node = nullptr;
550 AttrValue reshuffle_each_iteration;
551
552 TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size_node));
553 TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed(), &seed_node));
554 TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed2(), &seed2_node));
555 b->BuildAttrValue(seed_generator_->reshuffle_each_iteration(),
556 &reshuffle_each_iteration);
557 TF_RETURN_IF_ERROR(b->AddDataset(
558 this,
559 {input_graph_node, buffer_size_node, seed_node, seed2_node}, // Inputs
560 {std::make_pair(kReshuffleEachIteration,
561 reshuffle_each_iteration)}, // Attrs
562 output));
563 return OkStatus();
564 }
565
566 private:
567 SeedGeneratorManager* const manager_; // Owned.
568 const ResourceHandle resource_handle_;
569 ResourceMgr* const resource_mgr_; // Not owned.
570 const RandomSeeds seeds_;
571 };
572
573 // This version of shuffle dataset has a shared ownership of the seed generator
574 // resource. It supports sharing of the generator state across different
575 // iterations of the `repeat` transformation and also across different
576 // iterators.
577 class ShuffleDatasetOp::DatasetV2 : public ShuffleDatasetBase {
578 public:
DatasetV2(OpKernelContext * ctx,const DatasetBase * input,int64_t buffer_size,int64_t count,SeedGeneratorManager * manager,ResourceHandle && resource_handle,bool owns_resource)579 DatasetV2(OpKernelContext* ctx, const DatasetBase* input, int64_t buffer_size,
580 int64_t count, SeedGeneratorManager* manager,
581 ResourceHandle&& resource_handle, bool owns_resource)
582 : ShuffleDatasetBase(ctx, input, buffer_size, manager->get(), count),
583 manager_(manager),
584 owns_resource_(owns_resource),
585 resource_handle_(std::move(resource_handle)),
586 resource_mgr_(ctx->resource_manager()) {}
587
~DatasetV2()588 ~DatasetV2() override {
589 manager_->Unref();
590 if (owns_resource_) {
591 Status s = resource_mgr_->Delete<SeedGeneratorManager>(
592 resource_handle_.container(), resource_handle_.name());
593 if (!s.ok()) {
594 LOG(WARNING) << "Failed to delete RNG resource: " << s.ToString();
595 }
596 }
597 }
598
op_type() const599 string op_type() const override { return kDatasetType; }
600
601 protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const602 Status AsGraphDefInternal(SerializationContext* ctx,
603 DatasetGraphDefBuilder* b,
604 Node** output) const override {
605 Node* input_graph_node = nullptr;
606 TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
607 Node* buffer_size_node = nullptr;
608 TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size_node));
609 Node* resource_handle_node = nullptr;
610 Tensor handle(DT_RESOURCE, TensorShape({}));
611 handle.scalar<ResourceHandle>()() = resource_handle_;
612 TF_RETURN_IF_ERROR(b->AddTensor(handle, &resource_handle_node));
613 TF_RETURN_IF_ERROR(b->AddDataset(
614 this,
615 {input_graph_node, buffer_size_node, resource_handle_node}, // Inputs
616 {}, // Attrs
617 output));
618 return OkStatus();
619 }
620
621 private:
622 SeedGeneratorManager* const manager_; // Owned.
623 const bool owns_resource_;
624 const ResourceHandle resource_handle_;
625 ResourceMgr* const resource_mgr_; // Not owned.
626 };
627
628 // This version of shuffle dataset extends the functionality of DatasetV2 with
629 // the ability to preserve seed generator configuration (i.e. initial seeds and
630 // whether to reshuffle each iteration) across serialization of the dataset.
631 class ShuffleDatasetOp::DatasetV3 : public ShuffleDatasetBase {
632 public:
DatasetV3(OpKernelContext * ctx,const DatasetBase * input,int64_t buffer_size,int64_t count,RandomSeeds && seeds,SeedGeneratorManager * manager,ResourceHandle && resource_handle,bool owns_resource)633 DatasetV3(OpKernelContext* ctx, const DatasetBase* input, int64_t buffer_size,
634 int64_t count, RandomSeeds&& seeds, SeedGeneratorManager* manager,
635 ResourceHandle&& resource_handle, bool owns_resource)
636 : ShuffleDatasetBase(ctx, input, buffer_size, manager->get(), count),
637 manager_(manager),
638 owns_resource_(owns_resource),
639 resource_handle_(std::move(resource_handle)),
640 resource_mgr_(ctx->resource_manager()),
641 seeds_(std::move(seeds)) {}
642
~DatasetV3()643 ~DatasetV3() override {
644 manager_->Unref();
645 if (owns_resource_) {
646 Status s = resource_mgr_->Delete<SeedGeneratorManager>(
647 resource_handle_.container(), resource_handle_.name());
648 if (!s.ok()) {
649 LOG(WARNING) << "Failed to delete RNG resource: " << s.ToString();
650 }
651 }
652 }
653
op_type() const654 string op_type() const override { return kDatasetType; }
655
656 protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const657 Status AsGraphDefInternal(SerializationContext* ctx,
658 DatasetGraphDefBuilder* b,
659 Node** output) const override {
660 Node* input_graph_node = nullptr;
661 TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
662 Node* buffer_size_node = nullptr;
663 Node* seed_node = nullptr;
664 Node* seed2_node = nullptr;
665 TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size_node));
666 TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed(), &seed_node));
667 TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed2(), &seed2_node));
668 Node* resource_handle_node = nullptr;
669 Tensor handle(DT_RESOURCE, TensorShape({}));
670 handle.scalar<ResourceHandle>()() = resource_handle_;
671 TF_RETURN_IF_ERROR(b->AddTensor(handle, &resource_handle_node));
672 AttrValue reshuffle_each_iteration;
673 b->BuildAttrValue(seed_generator_->reshuffle_each_iteration(),
674 &reshuffle_each_iteration);
675 TF_RETURN_IF_ERROR(
676 b->AddDataset(this,
677 {input_graph_node, buffer_size_node, seed_node,
678 seed2_node, resource_handle_node}, // Inputs
679 {std::make_pair(kReshuffleEachIteration,
680 reshuffle_each_iteration)}, // Attrs
681 output));
682 return OkStatus();
683 }
684
685 private:
686 SeedGeneratorManager* const manager_; // Owned
687 const bool owns_resource_;
688 const ResourceHandle resource_handle_;
689 ResourceMgr* const resource_mgr_; // Not owned.
690 const RandomSeeds seeds_;
691 };
692
ShuffleDatasetOp(OpKernelConstruction * ctx)693 ShuffleDatasetOp::ShuffleDatasetOp(OpKernelConstruction* ctx)
694 : ShuffleDatasetOpBase(ctx) {
695 auto& op_name = ctx->def().op();
696 if (op_name == kShuffleDatasetV3) {
697 op_version_ = 3;
698 } else if (op_name == kShuffleDatasetV2) {
699 op_version_ = 2;
700 } else if (op_name == kShuffleDatasetV1) {
701 op_version_ = 1;
702 }
703 if (ctx->HasAttr(kReshuffleEachIteration)) {
704 OP_REQUIRES_OK(
705 ctx, ctx->GetAttr(kReshuffleEachIteration, &reshuffle_each_iteration_));
706 }
707 }
708
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)709 void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
710 DatasetBase** output) {
711 int64_t buffer_size = 0;
712 OP_REQUIRES_OK(ctx,
713 ParseScalarArgument<int64_t>(ctx, kBufferSize, &buffer_size));
714 OP_REQUIRES(
715 ctx, buffer_size > 0,
716 errors::InvalidArgument("buffer_size must be greater than zero."));
717
718 int64_t count = 1;
719 static std::atomic<int64_t> resource_id_counter(0);
720 const string& container = ctx->resource_manager()->default_container();
721 auto name = strings::StrCat(ctx->op_kernel().name(), "/", kSeedGenerator, "_",
722 resource_id_counter.fetch_add(1));
723 if (op_version_ == 3) {
724 auto handle = HandleFromInput(ctx, 4);
725 SeedGeneratorManager* manager = nullptr;
726 Status s = ctx->resource_manager()->Lookup<SeedGeneratorManager>(
727 handle.container(), handle.name(), &manager);
728 int64_t seed;
729 OP_REQUIRES_OK(ctx, ParseScalarArgument<int64_t>(ctx, kSeed, &seed));
730 int64_t seed2;
731 OP_REQUIRES_OK(ctx, ParseScalarArgument<int64_t>(ctx, kSeed2, &seed2));
732 RandomSeeds seeds(seed, seed2);
733 bool owns_resource = false;
734 if (errors::IsNotFound(s)) {
735 owns_resource = true;
736 OP_REQUIRES_OK(
737 ctx,
738 ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
739 container, name, &manager,
740 [reshuffle = reshuffle_each_iteration_,
741 &seeds](SeedGeneratorManager** manager) {
742 if (reshuffle) {
743 *manager =
744 new SeedGeneratorManager(new RandomSeedGenerator(seeds));
745 } else {
746 *manager =
747 new SeedGeneratorManager(new FixedSeedGenerator(seeds));
748 }
749 return OkStatus();
750 }));
751 handle = MakeResourceHandle<SeedGenerator>(ctx, container, name);
752 } else {
753 OP_REQUIRES_OK(ctx, s);
754 }
755
756 // Ownership of manager is transferred onto `DatasetV3`.
757 *output = new ShuffleDatasetOp::DatasetV3(ctx, input, buffer_size, count,
758 std::move(seeds), manager,
759 std::move(handle), owns_resource);
760 } else if (op_version_ == 2) {
761 auto handle = HandleFromInput(ctx, 2);
762 SeedGeneratorManager* manager = nullptr;
763 Status s = ctx->resource_manager()->Lookup<SeedGeneratorManager>(
764 handle.container(), handle.name(), &manager);
765 bool owns_resource = false;
766 if (errors::IsNotFound(s)) {
767 owns_resource = true;
768 LOG(WARNING) << "Failed to find seed generator resource. Falling back to "
769 "using a non-deterministically seeded generator and "
770 "reshuffling each iteration.";
771 RandomSeeds seeds(0, 0);
772 OP_REQUIRES_OK(
773 ctx, ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
774 container, name, &manager,
775 [&seeds](SeedGeneratorManager** manager) {
776 *manager = new SeedGeneratorManager(
777 new RandomSeedGenerator(seeds));
778 return OkStatus();
779 }));
780 handle = MakeResourceHandle<SeedGeneratorManager>(ctx, container, name);
781 } else {
782 OP_REQUIRES_OK(ctx, s);
783 }
784
785 // Ownership of manager is transferred onto `DatasetV2`.
786 *output =
787 new ShuffleDatasetOp::DatasetV2(ctx, input, buffer_size, count, manager,
788 std::move(handle), owns_resource);
789 } else {
790 if (op_version_ != 1) {
791 LOG(WARNING) << "Unsupported version of shuffle dataset op: "
792 << op_version_ << ". Defaulting to version 1.";
793 }
794 int64_t seed;
795 OP_REQUIRES_OK(ctx, ParseScalarArgument<int64_t>(ctx, kSeed, &seed));
796 int64_t seed2;
797 OP_REQUIRES_OK(ctx, ParseScalarArgument<int64_t>(ctx, kSeed2, &seed2));
798 RandomSeeds seeds(seed, seed2);
799 SeedGeneratorManager* manager;
800 OP_REQUIRES_OK(
801 ctx,
802 ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
803 container, name, &manager,
804 [reshuffle = reshuffle_each_iteration_,
805 &seeds](SeedGeneratorManager** manager) {
806 if (reshuffle) {
807 *manager =
808 new SeedGeneratorManager(new RandomSeedGenerator(seeds));
809 } else {
810 *manager =
811 new SeedGeneratorManager(new FixedSeedGenerator(seeds));
812 }
813 return OkStatus();
814 }));
815 auto handle =
816 MakeResourceHandle<SeedGeneratorManager>(ctx, container, name);
817
818 // Ownership of manager is transferred onto `Dataset`.
819 *output = new ShuffleDatasetOp::Dataset(ctx, input, buffer_size, count,
820 std::move(seeds), manager,
821 std::move(handle));
822 }
823 }
824
825 class ShuffleAndRepeatDatasetOp::Dataset : public ShuffleDatasetBase {
826 public:
Dataset(OpKernelContext * ctx,const DatasetBase * input,int64_t buffer_size,RandomSeeds && seeds,SeedGeneratorManager * manager,int64_t count,ResourceHandle && resource_handle)827 Dataset(OpKernelContext* ctx, const DatasetBase* input, int64_t buffer_size,
828 RandomSeeds&& seeds, SeedGeneratorManager* manager, int64_t count,
829 ResourceHandle&& resource_handle)
830 : ShuffleDatasetBase(ctx, input, buffer_size, manager->get(), count),
831 manager_(manager),
832 resource_handle_(std::move(resource_handle)),
833 resource_mgr_(ctx->resource_manager()),
834 seeds_(std::move(seeds)) {}
835
~Dataset()836 ~Dataset() override {
837 manager_->Unref();
838 Status s = resource_mgr_->Delete<SeedGeneratorManager>(
839 resource_handle_.container(), resource_handle_.name());
840 if (!s.ok()) {
841 LOG(WARNING) << "Failed to delete RNG resource: " << s.ToString();
842 }
843 }
844
op_type() const845 string op_type() const override { return kDatasetType; }
846
847 protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const848 Status AsGraphDefInternal(SerializationContext* ctx,
849 DatasetGraphDefBuilder* b,
850 Node** output) const override {
851 Node* input_graph_node = nullptr;
852 TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
853 Node* buffer_size = nullptr;
854 Node* seed = nullptr;
855 Node* seed2 = nullptr;
856 Node* count = nullptr;
857
858 TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
859 TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed(), &seed));
860 TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed2(), &seed2));
861 TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));
862 AttrValue reshuffle_each_iteration;
863 b->BuildAttrValue(seed_generator_->reshuffle_each_iteration(),
864 &reshuffle_each_iteration);
865 TF_RETURN_IF_ERROR(b->AddDataset(
866 this, {input_graph_node, buffer_size, seed, seed2, count}, // Inputs
867 {std::make_pair(kReshuffleEachIteration,
868 reshuffle_each_iteration)}, // Attrs
869 output));
870 return OkStatus();
871 }
872
873 private:
874 SeedGeneratorManager* const manager_; // Owned.
875 const ResourceHandle resource_handle_;
876 ResourceMgr* const resource_mgr_; // Not owned.
877 const RandomSeeds seeds_;
878 };
879
880 class ShuffleAndRepeatDatasetOp::DatasetV2 : public ShuffleDatasetBase {
881 public:
DatasetV2(OpKernelContext * ctx,const DatasetBase * input,int64_t buffer_size,int64_t count,RandomSeeds && seeds,SeedGeneratorManager * manager,ResourceHandle && resource_handle,bool owns_resource)882 DatasetV2(OpKernelContext* ctx, const DatasetBase* input, int64_t buffer_size,
883 int64_t count, RandomSeeds&& seeds, SeedGeneratorManager* manager,
884 ResourceHandle&& resource_handle, bool owns_resource)
885 : ShuffleDatasetBase(ctx, input, buffer_size, manager->get(), count),
886 manager_(manager),
887 owns_resource_(owns_resource),
888 resource_handle_(std::move(resource_handle)),
889 resource_mgr_(ctx->resource_manager()),
890 seeds_(std::move(seeds)) {}
891
~DatasetV2()892 ~DatasetV2() override {
893 manager_->Unref();
894 if (owns_resource_) {
895 Status s = resource_mgr_->Delete<SeedGeneratorManager>(
896 resource_handle_.container(), resource_handle_.name());
897 if (!s.ok()) {
898 LOG(WARNING) << "Failed to delete RNG resource: " << s.ToString();
899 }
900 }
901 }
902
op_type() const903 string op_type() const override { return kDatasetType; }
904
905 protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const906 Status AsGraphDefInternal(SerializationContext* ctx,
907 DatasetGraphDefBuilder* b,
908 Node** output) const override {
909 Node* input_graph_node = nullptr;
910 TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
911 Node* buffer_size_node = nullptr;
912 Node* seed_node = nullptr;
913 Node* seed2_node = nullptr;
914 Node* count_node = nullptr;
915 TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size_node));
916 TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed(), &seed_node));
917 TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed2(), &seed2_node));
918 TF_RETURN_IF_ERROR(b->AddScalar(count_, &count_node));
919 Node* resource_handle_node = nullptr;
920 Tensor handle(DT_RESOURCE, TensorShape({}));
921 handle.scalar<ResourceHandle>()() = resource_handle_;
922 TF_RETURN_IF_ERROR(b->AddTensor(handle, &resource_handle_node));
923 AttrValue reshuffle_each_iteration;
924 b->BuildAttrValue(seed_generator_->reshuffle_each_iteration(),
925 &reshuffle_each_iteration);
926 TF_RETURN_IF_ERROR(
927 b->AddDataset(this,
928 {input_graph_node, buffer_size_node, seed_node,
929 seed2_node, count_node, resource_handle_node}, // Inputs
930 {std::make_pair(kReshuffleEachIteration,
931 reshuffle_each_iteration)}, // Attrs
932 output));
933 return OkStatus();
934 }
935
936 private:
937 SeedGeneratorManager* const manager_; // Owned
938 const bool owns_resource_;
939 const ResourceHandle resource_handle_;
940 ResourceMgr* const resource_mgr_; // Not owned.
941 const RandomSeeds seeds_;
942 };
943
ShuffleAndRepeatDatasetOp(OpKernelConstruction * ctx)944 ShuffleAndRepeatDatasetOp::ShuffleAndRepeatDatasetOp(OpKernelConstruction* ctx)
945 : ShuffleDatasetOpBase(ctx) {
946 auto& op_name = ctx->def().op();
947 if (op_name == kShuffleAndRepeatDatasetV2) {
948 op_version_ = 2;
949 } else if (op_name == kShuffleAndRepeatDatasetV1) {
950 op_version_ = 1;
951 }
952 if (ctx->HasAttr(kReshuffleEachIteration)) {
953 OP_REQUIRES_OK(
954 ctx, ctx->GetAttr(kReshuffleEachIteration, &reshuffle_each_iteration_));
955 }
956 }
957
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)958 void ShuffleAndRepeatDatasetOp::MakeDataset(OpKernelContext* ctx,
959 DatasetBase* input,
960 DatasetBase** output) {
961 int64_t buffer_size = 0;
962 OP_REQUIRES_OK(ctx,
963 ParseScalarArgument<int64_t>(ctx, kBufferSize, &buffer_size));
964 OP_REQUIRES(
965 ctx, buffer_size > 0,
966 errors::InvalidArgument("buffer_size must be greater than zero."));
967
968 int64_t seed;
969 OP_REQUIRES_OK(ctx, ParseScalarArgument<int64_t>(ctx, kSeed, &seed));
970
971 int64_t seed2;
972 OP_REQUIRES_OK(ctx, ParseScalarArgument<int64_t>(ctx, kSeed2, &seed2));
973
974 int64_t count;
975 OP_REQUIRES_OK(ctx, ParseScalarArgument<int64_t>(ctx, kCount, &count));
976
977 OP_REQUIRES(ctx, count > 0 || count == -1,
978 errors::InvalidArgument(
979 "count must be greater than zero or equal to -1."));
980
981 RandomSeeds seeds(seed, seed2);
982
983 static std::atomic<int64_t> resource_id_counter(0);
984 const string& container = ctx->resource_manager()->default_container();
985 auto name = strings::StrCat(ctx->op_kernel().name(), "/", kSeedGenerator, "_",
986 resource_id_counter.fetch_add(1));
987 if (op_version_ == 2) {
988 auto handle = HandleFromInput(ctx, 5);
989 SeedGeneratorManager* manager = nullptr;
990 Status s = ctx->resource_manager()->Lookup<SeedGeneratorManager>(
991 handle.container(), handle.name(), &manager);
992 bool owns_resource = false;
993 if (errors::IsNotFound(s)) {
994 owns_resource = true;
995 OP_REQUIRES_OK(
996 ctx,
997 ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
998 container, name, &manager,
999 [reshuffle = reshuffle_each_iteration_,
1000 &seeds](SeedGeneratorManager** manager) {
1001 if (reshuffle) {
1002 *manager =
1003 new SeedGeneratorManager(new RandomSeedGenerator(seeds));
1004 } else {
1005 *manager =
1006 new SeedGeneratorManager(new FixedSeedGenerator(seeds));
1007 }
1008 return OkStatus();
1009 }));
1010 handle = MakeResourceHandle<SeedGenerator>(ctx, container, name);
1011 } else {
1012 OP_REQUIRES_OK(ctx, s);
1013 }
1014
1015 // Ownership of manager is transferred onto `DatasetV2`.
1016 *output = new ShuffleAndRepeatDatasetOp::DatasetV2(
1017 ctx, input, buffer_size, count, std::move(seeds), manager,
1018 std::move(handle), owns_resource);
1019 } else {
1020 if (op_version_ != 1) {
1021 LOG(WARNING) << "Unsupported version of shuffle dataset op: "
1022 << op_version_ << ". Defaulting to version 1.";
1023 }
1024 SeedGeneratorManager* manager;
1025 OP_REQUIRES_OK(
1026 ctx,
1027 ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
1028 container, name, &manager,
1029 [reshuffle = reshuffle_each_iteration_,
1030 &seeds](SeedGeneratorManager** manager) {
1031 if (reshuffle) {
1032 *manager =
1033 new SeedGeneratorManager(new RandomSeedGenerator(seeds));
1034 } else {
1035 *manager =
1036 new SeedGeneratorManager(new FixedSeedGenerator(seeds));
1037 }
1038 return OkStatus();
1039 }));
1040 auto handle =
1041 MakeResourceHandle<SeedGeneratorManager>(ctx, container, name);
1042
1043 // Ownership of manager is transferred onto `Dataset`.
1044 *output = new Dataset(ctx, input, buffer_size, std::move(seeds), manager,
1045 count, std::move(handle));
1046 }
1047 }
1048
1049 namespace {
1050 REGISTER_KERNEL_BUILDER(Name("ShuffleDataset").Device(DEVICE_CPU),
1051 ShuffleDatasetOp);
1052
1053 REGISTER_KERNEL_BUILDER(Name("ShuffleDatasetV2").Device(DEVICE_CPU),
1054 ShuffleDatasetOp);
1055
1056 REGISTER_KERNEL_BUILDER(Name("ShuffleDatasetV3").Device(DEVICE_CPU),
1057 ShuffleDatasetOp);
1058
1059 REGISTER_KERNEL_BUILDER(Name("ShuffleAndRepeatDataset").Device(DEVICE_CPU),
1060 ShuffleAndRepeatDatasetOp);
1061
1062 REGISTER_KERNEL_BUILDER(Name("ShuffleAndRepeatDatasetV2").Device(DEVICE_CPU),
1063 ShuffleAndRepeatDatasetOp);
1064 } // namespace
1065 } // namespace data
1066 } // namespace tensorflow
1067