• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/shuffle_dataset_op.h"
16 
17 #include <cstdint>
18 #include <deque>
19 #include <string>
20 #include <tuple>
21 #include <utility>
22 #include <vector>
23 
24 #include "tensorflow/core/data/dataset_utils.h"
25 #include "tensorflow/core/data/name_utils.h"
26 #include "tensorflow/core/data/serialization_utils.h"
27 #include "tensorflow/core/framework/dataset.h"
28 #include "tensorflow/core/framework/partial_tensor_shape.h"
29 #include "tensorflow/core/framework/resource_mgr.h"
30 #include "tensorflow/core/framework/tensor.h"
31 #include "tensorflow/core/kernels/data/random_seed_ops.h"
32 #include "tensorflow/core/lib/core/errors.h"
33 #include "tensorflow/core/lib/random/philox_random.h"
34 #include "tensorflow/core/lib/random/random.h"
35 #include "tensorflow/core/lib/random/random_distributions.h"
36 #include "tensorflow/core/platform/errors.h"
37 #include "tensorflow/core/platform/stringprintf.h"
38 
39 namespace tensorflow {
40 namespace data {
41 
42 // See documentation in ../../ops/dataset_ops.cc for a high-level
43 // description of the following op.
44 
// Namespace-scope definitions for the `static constexpr` string constants of
// the shuffle op classes.
// NOTE(review): presumably these exist so the constants can be ODR-used when
// building in pre-C++17 modes, where static constexpr data members are not
// implicitly inline — confirm against the supported toolchains before
// removing them.
/* static */ constexpr const char* const ShuffleDatasetOpBase::kInputDataset;
/* static */ constexpr const char* const ShuffleDatasetOpBase::kBufferSize;
/* static */ constexpr const char* const ShuffleDatasetOpBase::kSeed;
/* static */ constexpr const char* const ShuffleDatasetOpBase::kSeed2;
/* static */ constexpr const char* const ShuffleDatasetOpBase::kOutputTypes;
/* static */ constexpr const char* const ShuffleDatasetOpBase::kOutputShapes;
/* static */ constexpr const char* const
    ShuffleDatasetOpBase::kReshuffleEachIteration;

/* static */ constexpr const char* const ShuffleDatasetOp::kDatasetType;

/* static */ constexpr const char* const
    ShuffleAndRepeatDatasetOp::kDatasetType;
/* static */ constexpr const char* const ShuffleAndRepeatDatasetOp::kCount;
59 
// Minimum spacing between two "Filling up shuffle buffer" progress log lines.
constexpr int64_t kLogIntervalMicros = 10 * 1000000;  // 10 seconds.
// Once the buffer holds elements from more than this many epochs (and is not
// empty), `ShouldFillBuffer` stops requesting further input.
constexpr int64_t kMaxEpochsInBuffer = 3;

// Keys under which the iterator state is written to and read from a
// checkpoint (see `SaveInternal` / `RestoreInternal`).
constexpr char kNumRandomSamples[] = "num_random_samples";
constexpr char kDataProduced[] = "data_produced";
constexpr char kEndOfInputSequence[] = "end_of_input_sequence";
constexpr char kEpoch[] = "epoch";
constexpr char kNumElements[] = "num_elements";
constexpr char kSlicesSize[] = "slices_size";
constexpr char kSlicesStart[] = "slices_start";
constexpr char kSlicesEnd[] = "slices_end";
constexpr char kEpochNumRandomSamples[] = "epoch_num_random_samples";

// Name component used when building the seed generator resource name.
constexpr char kSeedGenerator[] = "SeedGenerator";

// Registered op names; compared against the kernel's op name to determine
// which op version is being constructed.
constexpr char kShuffleDatasetV1[] = "ShuffleDataset";
constexpr char kShuffleDatasetV2[] = "ShuffleDatasetV2";
constexpr char kShuffleDatasetV3[] = "ShuffleDatasetV3";
constexpr char kShuffleAndRepeatDatasetV1[] = "ShuffleAndRepeatDataset";
constexpr char kShuffleAndRepeatDatasetV2[] = "ShuffleAndRepeatDatasetV2";
78 
// Constructs the common base kernel for the shuffle ops; it only forwards
// `ctx` to `UnaryDatasetOpKernel`.
ShuffleDatasetOpBase::ShuffleDatasetOpBase(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx) {}
81 
82 // Abstract base dataset that implements a shuffling iterator.
83 class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
84  public:
  // Creates a shuffle dataset over `input`, taking a reference on it that is
  // released by the destructor.
  //
  // `buffer_size` is the number of slots each iterator allocates for its
  // shuffle buffer, `seed_generator` is the source of random seeds, and
  // `count` is the number of epochs to produce (-1 is treated as "repeat
  // indefinitely" by the cardinality and buffer-filling logic).
  ShuffleDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
                     int64_t buffer_size,
                     std::shared_ptr<SeedGenerator> seed_generator,
                     int64_t count)
      : DatasetBase(DatasetContext(ctx)),
        input_(input),
        buffer_size_(buffer_size),
        seed_generator_(std::move(seed_generator)),
        count_(count),
        // The buffer size is recorded as metadata that the iterator returns
        // from `GetTraceMeMetadata()`.
        traceme_metadata_(
            {{"buffer_size",
              strings::Printf("%lld", static_cast<long long>(buffer_size))}}) {
    input_->Ref();
  }

  // Drops the reference on the input dataset taken in the constructor.
  ~ShuffleDatasetBase() override { input_->Unref(); }
101 
102   virtual string op_type() const = 0;
103 
  // Shuffling does not change the element signature, so the output types are
  // exactly those of the input dataset.
  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }

  // Likewise, the output shapes are forwarded unchanged from the input
  // dataset.
  const std::vector<PartialTensorShape>& output_shapes() const override {
    return input_->output_shapes();
  }
111 
CardinalityInternal() const112   int64_t CardinalityInternal() const override {
113     if (count_ == -1 || input_->Cardinality() == kInfiniteCardinality) {
114       return kInfiniteCardinality;
115     } else if (input_->Cardinality() == kUnknownCardinality) {
116       return kUnknownCardinality;
117     } else {
118       return input_->Cardinality() * count_;
119     }
120   }
121 
CardinalityInternal(CardinalityOptions options) const122   int64_t CardinalityInternal(CardinalityOptions options) const override {
123     if (count_ == -1 || input_->Cardinality(options) == kInfiniteCardinality) {
124       return kInfiniteCardinality;
125     } else if (input_->Cardinality(options) == kUnknownCardinality) {
126       return kUnknownCardinality;
127     } else {
128       return input_->Cardinality(options) * count_;
129     }
130   }
131 
  // Appends the single input dataset to `inputs`. Always succeeds.
  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return OkStatus();
  }

  // Delegates the external-state check to the input dataset; this dataset adds
  // no check of its own here.
  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }
140 
Get(OpKernelContext * ctx,int64 index,std::vector<Tensor> * out_tensors) const141   Status Get(OpKernelContext* ctx, int64 index,
142              std::vector<Tensor>* out_tensors) const override {
143     TF_RETURN_IF_ERROR(CheckRandomAccessCompatible(index));
144     {
145       mutex_lock l(mu_);
146       if (shuffled_indices_.empty()) {
147         InitializeRandomAccessIndices();
148       }
149     }
150     int64 shuffled_index;
151     {
152       tf_shared_lock l(mu_);
153       shuffled_index = shuffled_indices_[index];
154     }
155     TF_RETURN_IF_ERROR(input_->Get(ctx, shuffled_index, out_tensors));
156     return OkStatus();
157   }
158 
  // Builds a human-readable description from the op type and this dataset's
  // arguments: buffer size, the seed generator's two seeds, and the epoch
  // count.
  string DebugString() const override {
    name_utils::DatasetDebugStringParams params;
    params.set_args(buffer_size_, seed_generator_->seed(),
                    seed_generator_->seed2(), count_);
    return name_utils::DatasetDebugString(op_type(), params);
  }
165 
MakeIteratorInternal(const string & prefix) const166   std::unique_ptr<IteratorBase> MakeIteratorInternal(
167       const string& prefix) const override {
168     return std::make_unique<Iterator>(
169         Iterator::Params{this, name_utils::IteratorPrefix(op_type(), prefix)},
170         seed_generator_.get());
171   }
172 
InitializeRandomAccessIndices() const173   void InitializeRandomAccessIndices() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
174     const int64 cardinality = Cardinality();
175     shuffled_indices_ = std::vector<std::int64_t>(cardinality);
176     std::iota(shuffled_indices_.begin(), shuffled_indices_.end(), 0);
177     int64_t shuffled_index = 0;
178     random::PhiloxRandom parent_generator =
179         random::PhiloxRandom(seed_generator_->seed(), seed_generator_->seed2());
180     random::SingleSampleAdapter<random::PhiloxRandom> generator =
181         random::SingleSampleAdapter<random::PhiloxRandom>(&parent_generator);
182 
183     while (shuffled_index < cardinality) {
184       int64_t offset = generator() % (cardinality - shuffled_index);
185       std::swap(shuffled_indices_[shuffled_index + offset],
186                 shuffled_indices_[shuffled_index]);
187       shuffled_index += 1;
188     }
189   }
190 
191  protected:
192   class Iterator : public DatasetIterator<ShuffleDatasetBase> {
193    public:
Iterator(const Params & params,SeedGenerator * seed_generator)194     explicit Iterator(const Params& params, SeedGenerator* seed_generator)
195         : DatasetIterator<ShuffleDatasetBase>(params),
196           seed_generator_(seed_generator),
197           parent_generator_(seed_generator->seed(), seed_generator->seed2()),
198           generator_(&parent_generator_) {
199       buffer_ = std::make_unique<std::vector<std::vector<Tensor>>>(
200           params.dataset->buffer_size_);
201     }
202 
    // Obtains this iterator's seeds from the seed generator and re-creates the
    // random generators from them. Always returns OK; `ctx` is not used.
    Status Initialize(IteratorContext* ctx) override {
      mutex_lock l(mu_);
      seed_generator_->GenerateSeeds(&seed_, &seed2_);
      ResetRngs();
      return OkStatus();
    }
209 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)210     Status GetNextInternal(IteratorContext* ctx,
211                            std::vector<Tensor>* out_tensors,
212                            bool* end_of_sequence) override {
213       mutex_lock l(mu_);
214       TF_RETURN_IF_ERROR(FillBuffer(ctx));
215       if (num_elements_ == 0) {
216         DCHECK(input_impl_ == nullptr);
217         *end_of_sequence = true;
218         return OkStatus();
219       }
220 
221       *end_of_sequence = false;
222       ClearEmptySlices();
223       DCHECK(!slices_.empty());
224       // Choose an element to produce uniformly at random from the first
225       // slice, and then remove the element from the slice.
226       int64_t offset =
227           Random() % (slices_.front()->end - slices_.front()->start);
228       int64_t index = (slices_.front()->start + offset) % buffer_->size();
229       *out_tensors = std::move(buffer_->at(index));
230       this->RecordBufferDequeue(ctx, *out_tensors);
231       std::swap(buffer_->at(index),
232                 buffer_->at(slices_.front()->start % buffer_->size()));
233       slices_.front()->start++;
234       num_elements_--;
235       return OkStatus();
236     }
237 
238    protected:
    // Creates the performance-model node for this iterator: a known-ratio
    // node with ratio 1.
    // NOTE(review): presumably ratio 1 means one input element is consumed
    // per output element — confirm against model::MakeKnownRatioNode.
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeKnownRatioNode(std::move(args),
                                       /*ratio=*/1);
    }
244 
    // Rebuilds both random generators from `seed_` / `seed2_` and then skips
    // `num_random_samples_` samples, so the generator resumes at the position
    // it had after that many draws through `Random()`.
    void ResetRngs() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      // Reset the generators based on the current iterator seeds.
      parent_generator_ = random::PhiloxRandom(seed_, seed2_);
      generator_ =
          random::SingleSampleAdapter<random::PhiloxRandom>(&parent_generator_);
      generator_.Skip(num_random_samples_);
    }
252 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)253     Status SaveInternal(SerializationContext* ctx,
254                         IteratorStateWriter* writer) override {
255       mutex_lock l(mu_);
256       // Save state needed to restore the random number generators.
257       TF_RETURN_IF_ERROR(
258           writer->WriteScalar(full_name(kEpochNumRandomSamples),
259                               seed_generator_->num_random_samples()));
260       TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name(kNumRandomSamples),
261                                              num_random_samples_));
262       TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name(kSeed), seed_));
263       TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name(kSeed2), seed2_));
264 
265       // Save input iterator if it hasn't been exhausted else write
266       // "end_of_input_sequence".
267       if (!input_impl_) {
268         TF_RETURN_IF_ERROR(
269             writer->WriteScalar(this->full_name(kEndOfInputSequence), ""));
270       } else {
271         TF_RETURN_IF_ERROR(this->SaveInput(ctx, writer, input_impl_));
272       }
273 
274       // Save the epoch counter, buffer, and buffer slices.
275       TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name(kEpoch), epoch_));
276       TF_RETURN_IF_ERROR(
277           writer->WriteScalar(this->full_name(kNumElements), num_elements_));
278       TF_RETURN_IF_ERROR(WriteElementsToCheckpoint(writer, prefix(), *buffer_));
279       TF_RETURN_IF_ERROR(
280           writer->WriteScalar(this->full_name(kSlicesSize), slices_.size()));
281       for (size_t i = 0; i < slices_.size(); ++i) {
282         TF_RETURN_IF_ERROR(
283             writer->WriteScalar(this->full_name(absl::StrJoin(
284                                     std::make_tuple(kSlicesStart, i), "_")),
285                                 slices_[i]->start));
286         TF_RETURN_IF_ERROR(writer->WriteScalar(
287             this->full_name(absl::StrJoin(std::make_tuple(kSlicesEnd, i), "_")),
288             slices_[i]->end));
289       }
290       if (data_produced_) {
291         TF_RETURN_IF_ERROR(
292             writer->WriteScalar(this->full_name(kDataProduced), ""));
293       }
294 
295       return OkStatus();
296     }
297 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)298     Status RestoreInternal(IteratorContext* ctx,
299                            IteratorStateReader* reader) override {
300       mutex_lock l(mu_);
301       // Restore the random number generators.
302       int64_t num_random_samples;
303       TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kEpochNumRandomSamples),
304                                             &num_random_samples));
305       seed_generator_->set_num_random_samples(num_random_samples);
306       seed_generator_->Reset();
307       TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name(kNumRandomSamples),
308                                             &num_random_samples_));
309       TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name(kSeed), &seed_));
310       TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name(kSeed2), &seed2_));
311       ResetRngs();
312 
313       // Restore the input iterator if it wasn't already exhausted.
314       if (!reader->Contains(this->full_name(kEndOfInputSequence))) {
315         TF_RETURN_IF_ERROR(this->dataset()->input_->MakeIterator(
316             ctx, this, this->prefix(), &input_impl_));
317         TF_RETURN_IF_ERROR(this->RestoreInput(ctx, reader, input_impl_));
318       } else {
319         input_impl_.reset();
320       }
321 
322       // Restore the epoch counter, buffer, and buffer slices.
323       TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name(kEpoch), &epoch_));
324       TF_RETURN_IF_ERROR(
325           reader->ReadScalar(this->full_name(kNumElements), &num_elements_));
326       size_t slices_size;
327       {
328         int64_t temp;
329         TF_RETURN_IF_ERROR(
330             reader->ReadScalar(this->full_name(kSlicesSize), &temp));
331         slices_size = static_cast<size_t>(temp);
332       }
333       buffer_ = std::make_unique<std::vector<std::vector<Tensor>>>();
334       TF_RETURN_IF_ERROR(
335           ReadElementsFromCheckpoint(ctx, reader, prefix(), buffer_.get()));
336       for (const auto& element : *buffer_) {
337         RecordBufferEnqueue(ctx, element);
338       }
339       buffer_->resize(dataset()->buffer_size_);
340       slices_.clear();
341       for (size_t i = 0; i < slices_size; ++i) {
342         int64_t start;
343         TF_RETURN_IF_ERROR(
344             reader->ReadScalar(this->full_name(absl::StrJoin(
345                                    std::make_tuple(kSlicesStart, i), "_")),
346                                &start));
347         int64_t end;
348         TF_RETURN_IF_ERROR(reader->ReadScalar(
349             this->full_name(absl::StrJoin(std::make_tuple(kSlicesEnd, i), "_")),
350             &end));
351         slices_.push_back(std::make_unique<Slice>(start, end));
352       }
353       data_produced_ = reader->Contains(this->full_name(kDataProduced));
354 
355       return OkStatus();
356     }
357 
    // Returns the metadata built by the dataset's constructor (a single
    // "buffer_size" entry).
    TraceMeMetadata GetTraceMeMetadata() const override {
      return this->dataset()->traceme_metadata_;
    }
361 
362    private:
363     // Used to represent slices of `buffer_` that belong to different epochs.
364     // The invariant maintained by the implementation is: `start` <= `end`.
365     // When using `start` and `end` to index into `buffer_`, their values
366     // should be taken modulo the size of `buffer_` as their absolute value
367     // can be greater than the range of `buffer_`.
368     struct Slice {
Slicetensorflow::data::ShuffleDatasetOpBase::ShuffleDatasetBase::Iterator::Slice369       Slice(int64_t start, int64_t end) : start(start), end(end) {}
370 
371       int64_t start;
372       int64_t end;
373     };
374 
Random()375     random::SingleSampleAdapter<random::PhiloxRandom>::ResultType Random()
376         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
377       num_random_samples_++;
378       auto out = generator_();
379       return out;
380     }
381 
382     // Fills the shuffle buffer, preparing the buffer for sampling.
FillBuffer(IteratorContext * ctx)383     Status FillBuffer(IteratorContext* ctx) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
384       int64_t start_micros = EnvTime::NowMicros();
385       int64_t num_log_entries = 0;
386       while (ShouldFillBuffer()) {
387         if (EnvTime::NowMicros() >
388             ((num_log_entries + 1) * kLogIntervalMicros) + start_micros) {
389           num_log_entries++;
390           LOG(INFO) << "Filling up shuffle buffer (this may take a while): "
391                     << num_elements_ << " of " << BufferSizeString();
392         }
393         if (!input_impl_) {
394           TF_RETURN_IF_ERROR(PrepareNextEpoch(ctx));
395         }
396         std::vector<Tensor> input_element;
397         bool end_of_input_sequence = false;
398         TF_RETURN_IF_ERROR(
399             input_impl_->GetNext(ctx, &input_element, &end_of_input_sequence));
400         if (!end_of_input_sequence) {
401           AddToShuffleBuffer(ctx, std::move(input_element));
402           continue;
403         }
404         input_impl_.reset();
405         // Reached end of input_impl_.
406         if (ctx->split_providers().empty() && !data_produced_ &&
407             this->dataset()->count_ == -1) {
408           // If we encounter the end of sequence without producing data, we
409           // terminate the iteration immediately. (Otherwise, this iterator
410           // would loop infinitely and never produce a value.)
411           return OkStatus();
412         }
413       }
414       if (num_log_entries > 0) {
415         LOG(INFO) << "Shuffle buffer filled.";
416       }
417       return OkStatus();
418     }
419 
ShouldFillBuffer()420     bool ShouldFillBuffer() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
421       if (!input_impl_ && dataset()->count_ != -1 &&
422           epoch_ >= dataset()->count_) {
423         return false;
424       }
425       if (slices_.size() > kMaxEpochsInBuffer && num_elements_ > 0) {
426         // When the elements stored in `buffer_` span more than
427         // `kMaxEpochsInBuffer` epochs, we do not fill the buffer further to
428         // conserve memory. This means that the upper bound on the size of
429         // `buffer_` is `kMaxEpochsInBuffer * cardinality(input_dataset) +
430         // 1`.
431         return false;
432       }
433       return num_elements_ < buffer_->size();
434     }
435 
PrepareNextEpoch(IteratorContext * ctx)436     Status PrepareNextEpoch(IteratorContext* ctx)
437         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
438       if (epoch_ == 0) {
439         slices_.push_back(std::make_unique<Slice>(0, 0));
440       } else {
441         int64_t n = slices_.back()->end;
442         slices_.push_back(std::make_unique<Slice>(n, n));
443         for (const auto& provider : ctx->split_providers()) {
444           TF_RETURN_IF_ERROR(provider->Reset());
445         }
446       }
447       TF_RETURN_IF_ERROR(this->dataset()->input_->MakeIterator(
448           ctx, this, this->prefix(), &input_impl_));
449       epoch_++;
450       return OkStatus();
451     }
452 
AddToShuffleBuffer(IteratorContext * ctx,std::vector<Tensor> && element)453     void AddToShuffleBuffer(IteratorContext* ctx, std::vector<Tensor>&& element)
454         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
455       data_produced_ = true;
456       if (num_elements_ == 0) {
457         VLOG(1) << "Starting to fill up shuffle buffer of size: "
458                 << BufferSizeString();
459       }
460       this->RecordBufferEnqueue(ctx, element);
461       size_t index = slices_.back()->end % buffer_->size();
462       buffer_->at(index) = std::move(element);
463       num_elements_++;
464       slices_.back()->end++;
465     }
466 
ClearEmptySlices()467     void ClearEmptySlices() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
468       // Garbage collect all empty slices.
469       while (slices_.front()->start == slices_.front()->end) {
470         slices_.pop_front();
471         // Reinitialize the RNG state for the next epoch.
472         num_random_samples_ = 0;
473         seed_generator_->GenerateSeeds(&seed_, &seed2_);
474         ResetRngs();
475       }
476     }
477 
    // Returns the dataset's configured buffer size as a string, for use in log
    // messages.
    std::string BufferSizeString() {
      return absl::StrCat(dataset()->buffer_size_);
    }
481 
482     mutex mu_;
483     SeedGenerator* const seed_generator_ TF_GUARDED_BY(mu_);  // Not owned.
484     std::unique_ptr<std::vector<std::vector<Tensor>>> buffer_
485         TF_GUARDED_BY(mu_);
486     std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_) = nullptr;
487     int64_t epoch_ TF_GUARDED_BY(mu_) = 0;
488     int64_t num_elements_ TF_GUARDED_BY(mu_) = 0;
489     int64_t seed_ TF_GUARDED_BY(mu_) = 0;
490     int64_t seed2_ TF_GUARDED_BY(mu_) = 0;
491     // Indices into `buffer_` indicating which data belongs to which epoch.
492     // The slice at the front of the deque references data from the earliest
493     // buffered epoch. It is an invariant that all slices reference
494     // non-overlapping sections of `buffer_`.
495     std::deque<std::unique_ptr<Slice>> slices_ TF_GUARDED_BY(mu_);
496     random::PhiloxRandom parent_generator_ TF_GUARDED_BY(mu_);
497     random::SingleSampleAdapter<random::PhiloxRandom> generator_
498         TF_GUARDED_BY(mu_);
499     int64_t num_random_samples_ TF_GUARDED_BY(mu_) = 0;
500     bool data_produced_ TF_GUARDED_BY(mu_) = false;
501   };
502 
503   const DatasetBase* const input_;
504   const int64_t buffer_size_;
505   const std::shared_ptr<SeedGenerator> seed_generator_;
506   // The number of epochs to run for. Normally this is just 1, but sometimes we
507   // fuse shuffle and repeat together, and make the shuffle dataset op
508   // responsible for repeating as well.
509   const int64_t count_;
510   const TraceMeMetadata traceme_metadata_;
511   mutable mutex mu_;
512   mutable std::vector<std::int64_t> shuffled_indices_ TF_GUARDED_BY(mu_);
513 };  // ShuffleDatasetBase
514 
515 // This version of memory dataset has an exclusive ownership of the seed
516 // generator resource. It supports sharing of the seed generator across
517 // different iterations of the `repeat` transformation but not across different
518 // iterators.
519 class ShuffleDatasetOp::Dataset : public ShuffleDatasetBase {
520  public:
Dataset(OpKernelContext * ctx,const DatasetBase * input,int64_t buffer_size,int64_t count,RandomSeeds && seeds,SeedGeneratorManager * manager,ResourceHandle && resource_handle)521   Dataset(OpKernelContext* ctx, const DatasetBase* input, int64_t buffer_size,
522           int64_t count, RandomSeeds&& seeds, SeedGeneratorManager* manager,
523           ResourceHandle&& resource_handle)
524       : ShuffleDatasetBase(ctx, input, buffer_size, manager->get(), count),
525         manager_(manager),
526         resource_handle_(std::move(resource_handle)),
527         resource_mgr_(ctx->resource_manager()),
528         seeds_(std::move(seeds)) {}
529 
~Dataset()530   ~Dataset() override {
531     manager_->Unref();
532     Status s = resource_mgr_->Delete<SeedGeneratorManager>(
533         resource_handle_.container(), resource_handle_.name());
534     if (!s.ok()) {
535       LOG(WARNING) << "Failed to delete RNG resource: " << s.ToString();
536     }
537   }
538 
op_type() const539   string op_type() const override { return kDatasetType; }
540 
541  protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const542   Status AsGraphDefInternal(SerializationContext* ctx,
543                             DatasetGraphDefBuilder* b,
544                             Node** output) const override {
545     Node* input_graph_node = nullptr;
546     TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
547     Node* buffer_size_node = nullptr;
548     Node* seed_node = nullptr;
549     Node* seed2_node = nullptr;
550     AttrValue reshuffle_each_iteration;
551 
552     TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size_node));
553     TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed(), &seed_node));
554     TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed2(), &seed2_node));
555     b->BuildAttrValue(seed_generator_->reshuffle_each_iteration(),
556                       &reshuffle_each_iteration);
557     TF_RETURN_IF_ERROR(b->AddDataset(
558         this,
559         {input_graph_node, buffer_size_node, seed_node, seed2_node},  // Inputs
560         {std::make_pair(kReshuffleEachIteration,
561                         reshuffle_each_iteration)},  // Attrs
562         output));
563     return OkStatus();
564   }
565 
566  private:
567   SeedGeneratorManager* const manager_;  // Owned.
568   const ResourceHandle resource_handle_;
569   ResourceMgr* const resource_mgr_;  // Not owned.
570   const RandomSeeds seeds_;
571 };
572 
573 // This version of shuffle dataset has a shared ownership of the seed generator
574 // resource. It supports sharing of the generator state across different
575 // iterations of the `repeat` transformation and also across different
576 // iterators.
577 class ShuffleDatasetOp::DatasetV2 : public ShuffleDatasetBase {
578  public:
DatasetV2(OpKernelContext * ctx,const DatasetBase * input,int64_t buffer_size,int64_t count,SeedGeneratorManager * manager,ResourceHandle && resource_handle,bool owns_resource)579   DatasetV2(OpKernelContext* ctx, const DatasetBase* input, int64_t buffer_size,
580             int64_t count, SeedGeneratorManager* manager,
581             ResourceHandle&& resource_handle, bool owns_resource)
582       : ShuffleDatasetBase(ctx, input, buffer_size, manager->get(), count),
583         manager_(manager),
584         owns_resource_(owns_resource),
585         resource_handle_(std::move(resource_handle)),
586         resource_mgr_(ctx->resource_manager()) {}
587 
~DatasetV2()588   ~DatasetV2() override {
589     manager_->Unref();
590     if (owns_resource_) {
591       Status s = resource_mgr_->Delete<SeedGeneratorManager>(
592           resource_handle_.container(), resource_handle_.name());
593       if (!s.ok()) {
594         LOG(WARNING) << "Failed to delete RNG resource: " << s.ToString();
595       }
596     }
597   }
598 
op_type() const599   string op_type() const override { return kDatasetType; }
600 
601  protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const602   Status AsGraphDefInternal(SerializationContext* ctx,
603                             DatasetGraphDefBuilder* b,
604                             Node** output) const override {
605     Node* input_graph_node = nullptr;
606     TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
607     Node* buffer_size_node = nullptr;
608     TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size_node));
609     Node* resource_handle_node = nullptr;
610     Tensor handle(DT_RESOURCE, TensorShape({}));
611     handle.scalar<ResourceHandle>()() = resource_handle_;
612     TF_RETURN_IF_ERROR(b->AddTensor(handle, &resource_handle_node));
613     TF_RETURN_IF_ERROR(b->AddDataset(
614         this,
615         {input_graph_node, buffer_size_node, resource_handle_node},  // Inputs
616         {},                                                          // Attrs
617         output));
618     return OkStatus();
619   }
620 
621  private:
622   SeedGeneratorManager* const manager_;  // Owned.
623   const bool owns_resource_;
624   const ResourceHandle resource_handle_;
625   ResourceMgr* const resource_mgr_;  // Not owned.
626 };
627 
628 // This version of shuffle dataset extends the functionality of DatasetV2 with
629 // the ability to preserve seed generator configuration (i.e. initial seeds and
630 // whether to reshuffle each iteration) across serialization of the dataset.
631 class ShuffleDatasetOp::DatasetV3 : public ShuffleDatasetBase {
632  public:
DatasetV3(OpKernelContext * ctx,const DatasetBase * input,int64_t buffer_size,int64_t count,RandomSeeds && seeds,SeedGeneratorManager * manager,ResourceHandle && resource_handle,bool owns_resource)633   DatasetV3(OpKernelContext* ctx, const DatasetBase* input, int64_t buffer_size,
634             int64_t count, RandomSeeds&& seeds, SeedGeneratorManager* manager,
635             ResourceHandle&& resource_handle, bool owns_resource)
636       : ShuffleDatasetBase(ctx, input, buffer_size, manager->get(), count),
637         manager_(manager),
638         owns_resource_(owns_resource),
639         resource_handle_(std::move(resource_handle)),
640         resource_mgr_(ctx->resource_manager()),
641         seeds_(std::move(seeds)) {}
642 
~DatasetV3()643   ~DatasetV3() override {
644     manager_->Unref();
645     if (owns_resource_) {
646       Status s = resource_mgr_->Delete<SeedGeneratorManager>(
647           resource_handle_.container(), resource_handle_.name());
648       if (!s.ok()) {
649         LOG(WARNING) << "Failed to delete RNG resource: " << s.ToString();
650       }
651     }
652   }
653 
op_type() const654   string op_type() const override { return kDatasetType; }
655 
656  protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const657   Status AsGraphDefInternal(SerializationContext* ctx,
658                             DatasetGraphDefBuilder* b,
659                             Node** output) const override {
660     Node* input_graph_node = nullptr;
661     TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
662     Node* buffer_size_node = nullptr;
663     Node* seed_node = nullptr;
664     Node* seed2_node = nullptr;
665     TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size_node));
666     TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed(), &seed_node));
667     TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed2(), &seed2_node));
668     Node* resource_handle_node = nullptr;
669     Tensor handle(DT_RESOURCE, TensorShape({}));
670     handle.scalar<ResourceHandle>()() = resource_handle_;
671     TF_RETURN_IF_ERROR(b->AddTensor(handle, &resource_handle_node));
672     AttrValue reshuffle_each_iteration;
673     b->BuildAttrValue(seed_generator_->reshuffle_each_iteration(),
674                       &reshuffle_each_iteration);
675     TF_RETURN_IF_ERROR(
676         b->AddDataset(this,
677                       {input_graph_node, buffer_size_node, seed_node,
678                        seed2_node, resource_handle_node},  // Inputs
679                       {std::make_pair(kReshuffleEachIteration,
680                                       reshuffle_each_iteration)},  // Attrs
681                       output));
682     return OkStatus();
683   }
684 
685  private:
686   SeedGeneratorManager* const manager_;  // Owned
687   const bool owns_resource_;
688   const ResourceHandle resource_handle_;
689   ResourceMgr* const resource_mgr_;  // Not owned.
690   const RandomSeeds seeds_;
691 };
692 
ShuffleDatasetOp(OpKernelConstruction * ctx)693 ShuffleDatasetOp::ShuffleDatasetOp(OpKernelConstruction* ctx)
694     : ShuffleDatasetOpBase(ctx) {
695   auto& op_name = ctx->def().op();
696   if (op_name == kShuffleDatasetV3) {
697     op_version_ = 3;
698   } else if (op_name == kShuffleDatasetV2) {
699     op_version_ = 2;
700   } else if (op_name == kShuffleDatasetV1) {
701     op_version_ = 1;
702   }
703   if (ctx->HasAttr(kReshuffleEachIteration)) {
704     OP_REQUIRES_OK(
705         ctx, ctx->GetAttr(kReshuffleEachIteration, &reshuffle_each_iteration_));
706   }
707 }
708 
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)709 void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
710                                    DatasetBase** output) {
711   int64_t buffer_size = 0;
712   OP_REQUIRES_OK(ctx,
713                  ParseScalarArgument<int64_t>(ctx, kBufferSize, &buffer_size));
714   OP_REQUIRES(
715       ctx, buffer_size > 0,
716       errors::InvalidArgument("buffer_size must be greater than zero."));
717 
718   int64_t count = 1;
719   static std::atomic<int64_t> resource_id_counter(0);
720   const string& container = ctx->resource_manager()->default_container();
721   auto name = strings::StrCat(ctx->op_kernel().name(), "/", kSeedGenerator, "_",
722                               resource_id_counter.fetch_add(1));
723   if (op_version_ == 3) {
724     auto handle = HandleFromInput(ctx, 4);
725     SeedGeneratorManager* manager = nullptr;
726     Status s = ctx->resource_manager()->Lookup<SeedGeneratorManager>(
727         handle.container(), handle.name(), &manager);
728     int64_t seed;
729     OP_REQUIRES_OK(ctx, ParseScalarArgument<int64_t>(ctx, kSeed, &seed));
730     int64_t seed2;
731     OP_REQUIRES_OK(ctx, ParseScalarArgument<int64_t>(ctx, kSeed2, &seed2));
732     RandomSeeds seeds(seed, seed2);
733     bool owns_resource = false;
734     if (errors::IsNotFound(s)) {
735       owns_resource = true;
736       OP_REQUIRES_OK(
737           ctx,
738           ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
739               container, name, &manager,
740               [reshuffle = reshuffle_each_iteration_,
741                &seeds](SeedGeneratorManager** manager) {
742                 if (reshuffle) {
743                   *manager =
744                       new SeedGeneratorManager(new RandomSeedGenerator(seeds));
745                 } else {
746                   *manager =
747                       new SeedGeneratorManager(new FixedSeedGenerator(seeds));
748                 }
749                 return OkStatus();
750               }));
751       handle = MakeResourceHandle<SeedGenerator>(ctx, container, name);
752     } else {
753       OP_REQUIRES_OK(ctx, s);
754     }
755 
756     // Ownership of manager is transferred onto `DatasetV3`.
757     *output = new ShuffleDatasetOp::DatasetV3(ctx, input, buffer_size, count,
758                                               std::move(seeds), manager,
759                                               std::move(handle), owns_resource);
760   } else if (op_version_ == 2) {
761     auto handle = HandleFromInput(ctx, 2);
762     SeedGeneratorManager* manager = nullptr;
763     Status s = ctx->resource_manager()->Lookup<SeedGeneratorManager>(
764         handle.container(), handle.name(), &manager);
765     bool owns_resource = false;
766     if (errors::IsNotFound(s)) {
767       owns_resource = true;
768       LOG(WARNING) << "Failed to find seed generator resource. Falling back to "
769                       "using a non-deterministically seeded generator and "
770                       "reshuffling each iteration.";
771       RandomSeeds seeds(0, 0);
772       OP_REQUIRES_OK(
773           ctx, ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
774                    container, name, &manager,
775                    [&seeds](SeedGeneratorManager** manager) {
776                      *manager = new SeedGeneratorManager(
777                          new RandomSeedGenerator(seeds));
778                      return OkStatus();
779                    }));
780       handle = MakeResourceHandle<SeedGeneratorManager>(ctx, container, name);
781     } else {
782       OP_REQUIRES_OK(ctx, s);
783     }
784 
785     // Ownership of manager is transferred onto `DatasetV2`.
786     *output =
787         new ShuffleDatasetOp::DatasetV2(ctx, input, buffer_size, count, manager,
788                                         std::move(handle), owns_resource);
789   } else {
790     if (op_version_ != 1) {
791       LOG(WARNING) << "Unsupported version of shuffle dataset op: "
792                    << op_version_ << ". Defaulting to version 1.";
793     }
794     int64_t seed;
795     OP_REQUIRES_OK(ctx, ParseScalarArgument<int64_t>(ctx, kSeed, &seed));
796     int64_t seed2;
797     OP_REQUIRES_OK(ctx, ParseScalarArgument<int64_t>(ctx, kSeed2, &seed2));
798     RandomSeeds seeds(seed, seed2);
799     SeedGeneratorManager* manager;
800     OP_REQUIRES_OK(
801         ctx,
802         ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
803             container, name, &manager,
804             [reshuffle = reshuffle_each_iteration_,
805              &seeds](SeedGeneratorManager** manager) {
806               if (reshuffle) {
807                 *manager =
808                     new SeedGeneratorManager(new RandomSeedGenerator(seeds));
809               } else {
810                 *manager =
811                     new SeedGeneratorManager(new FixedSeedGenerator(seeds));
812               }
813               return OkStatus();
814             }));
815     auto handle =
816         MakeResourceHandle<SeedGeneratorManager>(ctx, container, name);
817 
818     // Ownership of manager is transferred onto `Dataset`.
819     *output = new ShuffleDatasetOp::Dataset(ctx, input, buffer_size, count,
820                                             std::move(seeds), manager,
821                                             std::move(handle));
822   }
823 }
824 
825 class ShuffleAndRepeatDatasetOp::Dataset : public ShuffleDatasetBase {
826  public:
Dataset(OpKernelContext * ctx,const DatasetBase * input,int64_t buffer_size,RandomSeeds && seeds,SeedGeneratorManager * manager,int64_t count,ResourceHandle && resource_handle)827   Dataset(OpKernelContext* ctx, const DatasetBase* input, int64_t buffer_size,
828           RandomSeeds&& seeds, SeedGeneratorManager* manager, int64_t count,
829           ResourceHandle&& resource_handle)
830       : ShuffleDatasetBase(ctx, input, buffer_size, manager->get(), count),
831         manager_(manager),
832         resource_handle_(std::move(resource_handle)),
833         resource_mgr_(ctx->resource_manager()),
834         seeds_(std::move(seeds)) {}
835 
~Dataset()836   ~Dataset() override {
837     manager_->Unref();
838     Status s = resource_mgr_->Delete<SeedGeneratorManager>(
839         resource_handle_.container(), resource_handle_.name());
840     if (!s.ok()) {
841       LOG(WARNING) << "Failed to delete RNG resource: " << s.ToString();
842     }
843   }
844 
op_type() const845   string op_type() const override { return kDatasetType; }
846 
847  protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const848   Status AsGraphDefInternal(SerializationContext* ctx,
849                             DatasetGraphDefBuilder* b,
850                             Node** output) const override {
851     Node* input_graph_node = nullptr;
852     TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
853     Node* buffer_size = nullptr;
854     Node* seed = nullptr;
855     Node* seed2 = nullptr;
856     Node* count = nullptr;
857 
858     TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
859     TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed(), &seed));
860     TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed2(), &seed2));
861     TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));
862     AttrValue reshuffle_each_iteration;
863     b->BuildAttrValue(seed_generator_->reshuffle_each_iteration(),
864                       &reshuffle_each_iteration);
865     TF_RETURN_IF_ERROR(b->AddDataset(
866         this, {input_graph_node, buffer_size, seed, seed2, count},  // Inputs
867         {std::make_pair(kReshuffleEachIteration,
868                         reshuffle_each_iteration)},  // Attrs
869         output));
870     return OkStatus();
871   }
872 
873  private:
874   SeedGeneratorManager* const manager_;  // Owned.
875   const ResourceHandle resource_handle_;
876   ResourceMgr* const resource_mgr_;  // Not owned.
877   const RandomSeeds seeds_;
878 };
879 
880 class ShuffleAndRepeatDatasetOp::DatasetV2 : public ShuffleDatasetBase {
881  public:
DatasetV2(OpKernelContext * ctx,const DatasetBase * input,int64_t buffer_size,int64_t count,RandomSeeds && seeds,SeedGeneratorManager * manager,ResourceHandle && resource_handle,bool owns_resource)882   DatasetV2(OpKernelContext* ctx, const DatasetBase* input, int64_t buffer_size,
883             int64_t count, RandomSeeds&& seeds, SeedGeneratorManager* manager,
884             ResourceHandle&& resource_handle, bool owns_resource)
885       : ShuffleDatasetBase(ctx, input, buffer_size, manager->get(), count),
886         manager_(manager),
887         owns_resource_(owns_resource),
888         resource_handle_(std::move(resource_handle)),
889         resource_mgr_(ctx->resource_manager()),
890         seeds_(std::move(seeds)) {}
891 
~DatasetV2()892   ~DatasetV2() override {
893     manager_->Unref();
894     if (owns_resource_) {
895       Status s = resource_mgr_->Delete<SeedGeneratorManager>(
896           resource_handle_.container(), resource_handle_.name());
897       if (!s.ok()) {
898         LOG(WARNING) << "Failed to delete RNG resource: " << s.ToString();
899       }
900     }
901   }
902 
op_type() const903   string op_type() const override { return kDatasetType; }
904 
905  protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const906   Status AsGraphDefInternal(SerializationContext* ctx,
907                             DatasetGraphDefBuilder* b,
908                             Node** output) const override {
909     Node* input_graph_node = nullptr;
910     TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
911     Node* buffer_size_node = nullptr;
912     Node* seed_node = nullptr;
913     Node* seed2_node = nullptr;
914     Node* count_node = nullptr;
915     TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size_node));
916     TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed(), &seed_node));
917     TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed2(), &seed2_node));
918     TF_RETURN_IF_ERROR(b->AddScalar(count_, &count_node));
919     Node* resource_handle_node = nullptr;
920     Tensor handle(DT_RESOURCE, TensorShape({}));
921     handle.scalar<ResourceHandle>()() = resource_handle_;
922     TF_RETURN_IF_ERROR(b->AddTensor(handle, &resource_handle_node));
923     AttrValue reshuffle_each_iteration;
924     b->BuildAttrValue(seed_generator_->reshuffle_each_iteration(),
925                       &reshuffle_each_iteration);
926     TF_RETURN_IF_ERROR(
927         b->AddDataset(this,
928                       {input_graph_node, buffer_size_node, seed_node,
929                        seed2_node, count_node, resource_handle_node},  // Inputs
930                       {std::make_pair(kReshuffleEachIteration,
931                                       reshuffle_each_iteration)},  // Attrs
932                       output));
933     return OkStatus();
934   }
935 
936  private:
937   SeedGeneratorManager* const manager_;  // Owned
938   const bool owns_resource_;
939   const ResourceHandle resource_handle_;
940   ResourceMgr* const resource_mgr_;  // Not owned.
941   const RandomSeeds seeds_;
942 };
943 
ShuffleAndRepeatDatasetOp(OpKernelConstruction * ctx)944 ShuffleAndRepeatDatasetOp::ShuffleAndRepeatDatasetOp(OpKernelConstruction* ctx)
945     : ShuffleDatasetOpBase(ctx) {
946   auto& op_name = ctx->def().op();
947   if (op_name == kShuffleAndRepeatDatasetV2) {
948     op_version_ = 2;
949   } else if (op_name == kShuffleAndRepeatDatasetV1) {
950     op_version_ = 1;
951   }
952   if (ctx->HasAttr(kReshuffleEachIteration)) {
953     OP_REQUIRES_OK(
954         ctx, ctx->GetAttr(kReshuffleEachIteration, &reshuffle_each_iteration_));
955   }
956 }
957 
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)958 void ShuffleAndRepeatDatasetOp::MakeDataset(OpKernelContext* ctx,
959                                             DatasetBase* input,
960                                             DatasetBase** output) {
961   int64_t buffer_size = 0;
962   OP_REQUIRES_OK(ctx,
963                  ParseScalarArgument<int64_t>(ctx, kBufferSize, &buffer_size));
964   OP_REQUIRES(
965       ctx, buffer_size > 0,
966       errors::InvalidArgument("buffer_size must be greater than zero."));
967 
968   int64_t seed;
969   OP_REQUIRES_OK(ctx, ParseScalarArgument<int64_t>(ctx, kSeed, &seed));
970 
971   int64_t seed2;
972   OP_REQUIRES_OK(ctx, ParseScalarArgument<int64_t>(ctx, kSeed2, &seed2));
973 
974   int64_t count;
975   OP_REQUIRES_OK(ctx, ParseScalarArgument<int64_t>(ctx, kCount, &count));
976 
977   OP_REQUIRES(ctx, count > 0 || count == -1,
978               errors::InvalidArgument(
979                   "count must be greater than zero or equal to -1."));
980 
981   RandomSeeds seeds(seed, seed2);
982 
983   static std::atomic<int64_t> resource_id_counter(0);
984   const string& container = ctx->resource_manager()->default_container();
985   auto name = strings::StrCat(ctx->op_kernel().name(), "/", kSeedGenerator, "_",
986                               resource_id_counter.fetch_add(1));
987   if (op_version_ == 2) {
988     auto handle = HandleFromInput(ctx, 5);
989     SeedGeneratorManager* manager = nullptr;
990     Status s = ctx->resource_manager()->Lookup<SeedGeneratorManager>(
991         handle.container(), handle.name(), &manager);
992     bool owns_resource = false;
993     if (errors::IsNotFound(s)) {
994       owns_resource = true;
995       OP_REQUIRES_OK(
996           ctx,
997           ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
998               container, name, &manager,
999               [reshuffle = reshuffle_each_iteration_,
1000                &seeds](SeedGeneratorManager** manager) {
1001                 if (reshuffle) {
1002                   *manager =
1003                       new SeedGeneratorManager(new RandomSeedGenerator(seeds));
1004                 } else {
1005                   *manager =
1006                       new SeedGeneratorManager(new FixedSeedGenerator(seeds));
1007                 }
1008                 return OkStatus();
1009               }));
1010       handle = MakeResourceHandle<SeedGenerator>(ctx, container, name);
1011     } else {
1012       OP_REQUIRES_OK(ctx, s);
1013     }
1014 
1015     // Ownership of manager is transferred onto `DatasetV2`.
1016     *output = new ShuffleAndRepeatDatasetOp::DatasetV2(
1017         ctx, input, buffer_size, count, std::move(seeds), manager,
1018         std::move(handle), owns_resource);
1019   } else {
1020     if (op_version_ != 1) {
1021       LOG(WARNING) << "Unsupported version of shuffle dataset op: "
1022                    << op_version_ << ". Defaulting to version 1.";
1023     }
1024     SeedGeneratorManager* manager;
1025     OP_REQUIRES_OK(
1026         ctx,
1027         ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
1028             container, name, &manager,
1029             [reshuffle = reshuffle_each_iteration_,
1030              &seeds](SeedGeneratorManager** manager) {
1031               if (reshuffle) {
1032                 *manager =
1033                     new SeedGeneratorManager(new RandomSeedGenerator(seeds));
1034               } else {
1035                 *manager =
1036                     new SeedGeneratorManager(new FixedSeedGenerator(seeds));
1037               }
1038               return OkStatus();
1039             }));
1040     auto handle =
1041         MakeResourceHandle<SeedGeneratorManager>(ctx, container, name);
1042 
1043     // Ownership of manager is transferred onto `Dataset`.
1044     *output = new Dataset(ctx, input, buffer_size, std::move(seeds), manager,
1045                           count, std::move(handle));
1046   }
1047 }
1048 
1049 namespace {
1050 REGISTER_KERNEL_BUILDER(Name("ShuffleDataset").Device(DEVICE_CPU),
1051                         ShuffleDatasetOp);
1052 
1053 REGISTER_KERNEL_BUILDER(Name("ShuffleDatasetV2").Device(DEVICE_CPU),
1054                         ShuffleDatasetOp);
1055 
1056 REGISTER_KERNEL_BUILDER(Name("ShuffleDatasetV3").Device(DEVICE_CPU),
1057                         ShuffleDatasetOp);
1058 
1059 REGISTER_KERNEL_BUILDER(Name("ShuffleAndRepeatDataset").Device(DEVICE_CPU),
1060                         ShuffleAndRepeatDatasetOp);
1061 
1062 REGISTER_KERNEL_BUILDER(Name("ShuffleAndRepeatDatasetV2").Device(DEVICE_CPU),
1063                         ShuffleAndRepeatDatasetOp);
1064 }  // namespace
1065 }  // namespace data
1066 }  // namespace tensorflow
1067