• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <deque>
17 #include <vector>
18 
19 #include "tensorflow/core/framework/dataset.h"
20 #include "tensorflow/core/framework/partial_tensor_shape.h"
21 #include "tensorflow/core/framework/resource_mgr.h"
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/lib/random/philox_random.h"
24 #include "tensorflow/core/lib/random/random.h"
25 #include "tensorflow/core/lib/random/random_distributions.h"
26 
27 namespace tensorflow {
28 namespace data {
29 namespace {
30 
31 const int64 kLogIntervalMicros = 10 * 1000000;  // 10 seconds.
32 
33 const int64 kMaxEpochsInBuffer = 3;
34 
// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.
37 
38 class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
39  public:
ShuffleDatasetOpBase(OpKernelConstruction * ctx)40   explicit ShuffleDatasetOpBase(OpKernelConstruction* ctx)
41       : UnaryDatasetOpKernel(ctx) {}
42 
43  protected:
44   // Abstract base dataset that implements a shuffling iterator.
45   class ShuffleDatasetBase : public DatasetBase {
46    public:
ShuffleDatasetBase(OpKernelContext * ctx,const DatasetBase * input,int64 buffer_size,int64 count)47     ShuffleDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
48                        int64 buffer_size, int64 count)
49         : DatasetBase(DatasetContext(ctx)),
50           input_(input),
51           buffer_size_(buffer_size),
52           count_(count) {
53       input_->Ref();
54     }
55 
~ShuffleDatasetBase()56     ~ShuffleDatasetBase() override { input_->Unref(); }
57 
output_dtypes() const58     const DataTypeVector& output_dtypes() const override {
59       return input_->output_dtypes();
60     }
61 
output_shapes() const62     const std::vector<PartialTensorShape>& output_shapes() const override {
63       return input_->output_shapes();
64     }
65 
Cardinality() const66     int64 Cardinality() const override { return input_->Cardinality(); }
67 
68    protected:
69     template <class T>
70     class Iterator : public DatasetIterator<T> {
71      public:
Iterator(const typename DatasetIterator<T>::Params & params,int64 seed,int64 seed2)72       explicit Iterator(const typename DatasetIterator<T>::Params& params,
73                         int64 seed, int64 seed2)
74           : DatasetIterator<T>(params),
75             seed_(seed),
76             seed2_(seed2),
77             input_impl_(nullptr),
78             epoch_(0),
79             num_elements_(0),
80             parent_generator_(seed, seed2),
81             generator_(&parent_generator_) {
82         buffer_ = absl::make_unique<std::vector<Tensor>[]>(
83             params.dataset->buffer_size_);
84         slices_.push_back(absl::make_unique<Slice>(0, 0));
85       }
86 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)87       Status GetNextInternal(IteratorContext* ctx,
88                              std::vector<Tensor>* out_tensors,
89                              bool* end_of_sequence) override {
90         mutex_lock l(mu_);
91         int64 start_micros = ctx->env()->NowMicros();
92         int64 num_log_entries = 0;
93         bool first_call = false;
94         if (!input_impl_ && epoch_ == 0) {
95           first_call = true;
96           TF_RETURN_IF_ERROR(this->dataset()->input_->MakeIterator(
97               ctx, this->prefix(), &input_impl_));
98         }
99         while (input_impl_ && num_elements_ < this->dataset()->buffer_size_) {
100           if (ctx->env()->NowMicros() >
101               ((num_log_entries + 1) * kLogIntervalMicros) + start_micros) {
102             num_log_entries++;
103             LOG(INFO) << "Filling up shuffle buffer (this may take a while): "
104                       << num_elements_ << " of "
105                       << this->dataset()->buffer_size_;
106           }
107           std::vector<Tensor> input_element;
108           bool end_of_input_sequence = false;
109           while (this->dataset()->count_ == -1 ||
110                  epoch_ < this->dataset()->count_) {
111             TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &input_element,
112                                                     &end_of_input_sequence));
113             if (!end_of_input_sequence) {
114               first_call = false;
115               break;
116             }
117             if (first_call && this->dataset()->count_ == -1) {
118               // If the first call to GetNext() fails because the end
119               // of sequence has been reached, we terminate the
120               // iteration immediately. (Otherwise, this iterator
121               // would loop infinitely and never produce a value.)
122               *end_of_sequence = true;
123               return Status::OK();
124             }
125             epoch_++;
126             int64 n = slices_.back()->end;
127             slices_.push_back(absl::make_unique<Slice>(n, n));
128             TF_RETURN_IF_ERROR(this->dataset()->input_->MakeIterator(
129                 ctx, this->prefix(), &input_impl_));
130           }
131           if (!end_of_input_sequence) {
132             this->RecordBufferEnqueue(ctx, input_element);
133             buffer_[slices_.back()->end % this->dataset()->buffer_size_] =
134                 std::move(input_element);
135             num_elements_++;
136             slices_.back()->end++;
137           } else {
138             input_impl_.reset();
139           }
140           if (slices_.size() > kMaxEpochsInBuffer) {
141             // When the elements stored in `buffer_` span more than
142             // `kMaxEpochsInBuffer` epochs, we do not fill the buffer further to
143             // conserve memory. This means that the upper bound on the size of
144             // `buffer_` is `kMaxEpochsInBuffer * cardinality(input_dataset) +
145             // 1`.
146             break;
147           }
148         }
149         if (num_log_entries > 0) {
150           LOG(INFO) << "Shuffle buffer filled.";
151         }
152 
153         if (num_elements_ > 0) {
154           *end_of_sequence = false;
155           // Garbage collect all empty slices.
156           while (!slices_.empty() &&
157                  slices_.front()->start == slices_.front()->end) {
158             slices_.pop_front();
159           }
160           DCHECK(!slices_.empty());
161           // Choose an element to produce uniformly at random from the first
162           // slice, and then remove the element from the slice.
163           int64 offset =
164               Random() % (slices_.front()->end - slices_.front()->start);
165           int64 index =
166               (slices_.front()->start + offset) % this->dataset()->buffer_size_;
167           *out_tensors = std::move(buffer_[index]);
168           this->RecordBufferDequeue(ctx, *out_tensors);
169           std::swap(
170               buffer_[index],
171               buffer_[slices_.front()->start % this->dataset()->buffer_size_]);
172           slices_.front()->start++;
173           num_elements_--;
174         } else {
175           DCHECK(input_impl_ == nullptr);
176           *end_of_sequence = true;
177         }
178         return Status::OK();
179       }
180 
181      protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const182       std::shared_ptr<model::Node> CreateNode(
183           IteratorContext* ctx, model::Node::Args args) const override {
184         return model::MakeKnownRatioNode(std::move(args),
185                                          /*ratio=*/1);
186       }
187 
ResetRngs()188       void ResetRngs() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
189         // Reset the generators based on the current iterator seeds.
190         parent_generator_ = random::PhiloxRandom(seed_, seed2_);
191         generator_ = random::SingleSampleAdapter<random::PhiloxRandom>(
192             &parent_generator_);
193         generator_.Skip(num_random_samples_);
194       }
195 
SaveInternal(IteratorStateWriter * writer)196       Status SaveInternal(IteratorStateWriter* writer) override {
197         mutex_lock l(mu_);
198         // Save state needed to restore the random number generators.
199         TF_RETURN_IF_ERROR(writer->WriteScalar(
200             this->full_name("num_random_samples"), num_random_samples_));
201         TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name("seed"), seed_));
202         TF_RETURN_IF_ERROR(
203             writer->WriteScalar(this->full_name("seed2"), seed2_));
204 
205         // Save input iterator if it hasn't been exhausted else write
206         // "end_of_input_sequence".
207         if (!input_impl_) {
208           TF_RETURN_IF_ERROR(writer->WriteScalar(
209               this->full_name("end_of_input_sequence"), ""));
210         } else {
211           TF_RETURN_IF_ERROR(this->SaveInput(writer, input_impl_));
212         }
213 
214         // Save the epoch counter, buffer, and buffer slices.
215         TF_RETURN_IF_ERROR(
216             writer->WriteScalar(this->full_name("epoch"), epoch_));
217         TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name("num_elements"),
218                                                num_elements_));
219         TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name("slices_size"),
220                                                slices_.size()));
221         for (size_t i = 0; i < slices_.size(); ++i) {
222           TF_RETURN_IF_ERROR(writer->WriteScalar(
223               this->full_name(strings::StrCat("slices_start_", i)),
224               slices_[i]->start));
225           TF_RETURN_IF_ERROR(writer->WriteScalar(
226               this->full_name(strings::StrCat("slices_end_", i)),
227               slices_[i]->end));
228           for (size_t j = slices_[i]->start; j < slices_[i]->end; ++j) {
229             size_t index = j % this->dataset()->buffer_size_;
230             TF_RETURN_IF_ERROR(writer->WriteScalar(
231                 this->full_name(strings::StrCat("buffer_", index, "_size")),
232                 buffer_[index].size()));
233             for (size_t k = 0; k < buffer_[index].size(); ++k) {
234               TF_RETURN_IF_ERROR(writer->WriteTensor(
235                   this->full_name(strings::StrCat("buffer_", index, "_", k)),
236                   buffer_[index][k]));
237             }
238           }
239         }
240 
241         return Status::OK();
242       }
243 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)244       Status RestoreInternal(IteratorContext* ctx,
245                              IteratorStateReader* reader) override {
246         mutex_lock l(mu_);
247         // Restore the random number generators.
248         TF_RETURN_IF_ERROR(reader->ReadScalar(
249             this->full_name("num_random_samples"), &num_random_samples_));
250         TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name("seed"), &seed_));
251         TF_RETURN_IF_ERROR(
252             reader->ReadScalar(this->full_name("seed2"), &seed2_));
253         ResetRngs();
254 
255         // Restore the input iterator if it wasn't already exhausted.
256         if (!reader->Contains(this->full_name("end_of_input_sequence"))) {
257           TF_RETURN_IF_ERROR(this->dataset()->input_->MakeIterator(
258               ctx, this->prefix(), &input_impl_));
259           TF_RETURN_IF_ERROR(this->RestoreInput(ctx, reader, input_impl_));
260         } else {
261           input_impl_.reset();
262         }
263 
264         // Restore the epoch counter, buffer, and buffer slices.
265         TF_RETURN_IF_ERROR(
266             reader->ReadScalar(this->full_name("epoch"), &epoch_));
267         TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name("num_elements"),
268                                               &num_elements_));
269         size_t slices_size;
270         {
271           int64 temp;
272           TF_RETURN_IF_ERROR(
273               reader->ReadScalar(this->full_name("slices_size"), &temp));
274           slices_size = static_cast<size_t>(temp);
275         }
276         buffer_ = absl::make_unique<std::vector<Tensor>[]>(
277             this->dataset()->buffer_size_);
278         for (size_t i = 0; i < slices_size; ++i) {
279           int64 start;
280           TF_RETURN_IF_ERROR(reader->ReadScalar(
281               this->full_name(strings::StrCat("slices_start_", i)), &start));
282           int64 end;
283           TF_RETURN_IF_ERROR(reader->ReadScalar(
284               this->full_name(strings::StrCat("slices_end_", i)), &end));
285           slices_.push_back(absl::make_unique<Slice>(start, end));
286           for (size_t j = start; j < end; ++j) {
287             size_t index = j % this->dataset()->buffer_size_;
288             int64 list_size;
289             TF_RETURN_IF_ERROR(reader->ReadScalar(
290                 this->full_name(strings::StrCat("buffer_", index, "_size")),
291                 &list_size));
292             buffer_[index] = std::vector<Tensor>(list_size);
293             for (int k = 0; k < list_size; ++k) {
294               TF_RETURN_IF_ERROR(reader->ReadTensor(
295                   this->full_name(strings::StrCat("buffer_", index, "_", k)),
296                   &buffer_[index][k]));
297             }
298           }
299         }
300 
301         return Status::OK();
302       }
303 
304       mutex mu_;
305       int64 seed_ GUARDED_BY(mu_);
306       int64 seed2_ GUARDED_BY(mu_);
307 
308      private:
309       // Used to represent slices of `buffer_` that belong to different epochs.
310       // The invariant maintained by the implementation is: `start` <= `end`.
311       // When using `start` and `end` to index into `buffer_`, their values
312       // should be taken modulo the size of `buffer_` as their absolute value
313       // can be greater than the range of `buffer_`.
314       struct Slice {
Slicetensorflow::data::__anon0b80ed660111::ShuffleDatasetOpBase::ShuffleDatasetBase::Iterator::Slice315         Slice(int64 start, int64 end) : start(start), end(end) {}
316 
317         int64 start;
318         int64 end;
319       };
320 
Random()321       random::SingleSampleAdapter<random::PhiloxRandom>::ResultType Random()
322           EXCLUSIVE_LOCKS_REQUIRED(mu_) {
323         num_random_samples_++;
324         auto out = generator_();
325         return out;
326       }
327 
328       std::unique_ptr<std::vector<Tensor>[]> buffer_ GUARDED_BY(mu_);
329       std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
330       int64 epoch_ GUARDED_BY(mu_);
331       int64 num_elements_ GUARDED_BY(mu_);
332       std::deque<std::unique_ptr<Slice>> slices_ GUARDED_BY(mu_);
333       random::PhiloxRandom parent_generator_ GUARDED_BY(mu_);
334       random::SingleSampleAdapter<random::PhiloxRandom> generator_
335           GUARDED_BY(mu_);
336       int64 num_random_samples_ GUARDED_BY(mu_) = 0;
337     };
338 
339     const DatasetBase* const input_;
340     const int64 buffer_size_;
341     const int64 count_;
342   };
343 };
344 
345 class ShuffleDatasetOp : public ShuffleDatasetOpBase {
346  public:
ShuffleDatasetOp(OpKernelConstruction * ctx)347   explicit ShuffleDatasetOp(OpKernelConstruction* ctx)
348       : ShuffleDatasetOpBase(ctx) {
349     OP_REQUIRES_OK(ctx, ctx->GetAttr("reshuffle_each_iteration",
350                                      &reshuffle_each_iteration_));
351   }
352 
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)353   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
354                    DatasetBase** output) override {
355     int64 buffer_size;
356     OP_REQUIRES_OK(
357         ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size));
358     OP_REQUIRES(
359         ctx, buffer_size > 0,
360         errors::InvalidArgument("buffer_size must be greater than zero."));
361 
362     int64 seed;
363     OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "seed", &seed));
364 
365     int64 seed2;
366     OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "seed2", &seed2));
367 
368     // By TensorFlow convention, passing 0 for both seeds indicates
369     // that the shuffling should be seeded non-deterministically.
370     if (seed == 0 && seed2 == 0) {
371       seed = random::New64();
372       seed2 = random::New64();
373     }
374 
375     int64 count = 1;
376     if (reshuffle_each_iteration_) {
377       *output =
378           new ReshufflingDataset(ctx, input, buffer_size, seed, seed2, count);
379     } else {
380       *output =
381           new FixedSeedDataset(ctx, input, buffer_size, seed, seed2, count);
382     }
383   }
384 
385  private:
386   // A dataset that uses a pseudorandom sequence of seeds for the iterators
387   // created from it. Used when `reshuffle_each_iteration` is true.
388   class ReshufflingDataset : public ShuffleDatasetBase {
389    public:
ReshufflingDataset(OpKernelContext * ctx,const DatasetBase * input,int64 buffer_size,int64 seed,int64 seed2,int64 count)390     ReshufflingDataset(OpKernelContext* ctx, const DatasetBase* input,
391                        int64 buffer_size, int64 seed, int64 seed2, int64 count)
392         : ShuffleDatasetBase(ctx, input, buffer_size, count),
393           seed_(seed),
394           seed2_(seed2) {}
395 
DebugString() const396     string DebugString() const override {
397       return strings::StrCat("ShuffleDatasetOp(", buffer_size_, ", ", seed_,
398                              ", ", seed2_, ")::ReshufflingDataset");
399     }
400 
MakeIteratorInternal(const string & prefix) const401     std::unique_ptr<IteratorBase> MakeIteratorInternal(
402         const string& prefix) const override {
403       return absl::make_unique<Iterator>(
404           Iterator::Params{this, strings::StrCat(prefix, "::Shuffle")}, seed_,
405           seed2_);
406     }
407 
408    protected:
409     class RandomSeedGenerator : public ResourceBase {
410      public:
RandomSeedGenerator(int64 seed,int64 seed2)411       RandomSeedGenerator(int64 seed, int64 seed2)
412           : seed_(seed),
413             seed2_(seed2),
414             parent_generator_(seed, seed2),
415             generator_(&parent_generator_) {}
416 
DebugString() const417       string DebugString() const override {
418         return "ReshufflingDataset::RandomSeedGenerator";
419       }
420 
GenerateRandomSeeds(int64 * seed1,int64 * seed2)421       void GenerateRandomSeeds(int64* seed1, int64* seed2) {
422         mutex_lock l(mu_);
423         num_random_samples_++;
424         *seed1 = generator_();
425         num_random_samples_++;
426         *seed2 = generator_();
427       }
428 
num_random_samples()429       int64 num_random_samples() {
430         tf_shared_lock l(mu_);
431         return num_random_samples_;
432       }
433 
set_num_random_samples(int64 num_random_samples)434       void set_num_random_samples(int64 num_random_samples) {
435         mutex_lock l(mu_);
436         num_random_samples_ = num_random_samples;
437       }
438 
Reset()439       void Reset() {
440         mutex_lock l(mu_);
441         // Reset the generators based on the current seeds.
442         parent_generator_ = random::PhiloxRandom(seed_, seed2_);
443         generator_ = random::SingleSampleAdapter<random::PhiloxRandom>(
444             &parent_generator_);
445         generator_.Skip(num_random_samples_);
446       }
447 
448      private:
449       const int64 seed_;
450       const int64 seed2_;
451       mutex mu_;
452       random::PhiloxRandom parent_generator_ GUARDED_BY(mu_);
453       random::SingleSampleAdapter<random::PhiloxRandom> generator_
454           GUARDED_BY(mu_);
455       int64 num_random_samples_ GUARDED_BY(mu_) = 0;
456     };
457 
458     class Iterator : public ShuffleDatasetBase::Iterator<ReshufflingDataset> {
459      public:
Iterator(const Params & params,int64 seed,int64 seed2)460       explicit Iterator(const Params& params, int64 seed, int64 seed2)
461           : ShuffleDatasetBase::Iterator<ReshufflingDataset>(params, seed,
462                                                              seed2) {}
463 
~Iterator()464       ~Iterator() override { seed_generator_->Unref(); }
465 
Initialize(IteratorContext * ctx)466       Status Initialize(IteratorContext* ctx) override {
467         // Firstly, lookup or create a seed generator from the IteratorResource
468         // resource_mgr.
469         ResourceMgr* mgr = ctx->resource_mgr();
470         RandomSeedGenerator* seed_generator;
471         const string name = strings::StrCat(
472             prefix(), "::", dataset()->type_string(), "::RandomSeedGenerator");
473 
474         int64 dataset_seed, dataset_seed2;
475         {
476           tf_shared_lock l(mu_);
477           // Ideally we'd like to hold this lock in the LookupOrCreate method,
478           // but that trips up our Deadlock detection code.
479           dataset_seed = seed_;
480           dataset_seed2 = seed2_;
481         }
482         TF_RETURN_IF_ERROR(mgr->LookupOrCreate<RandomSeedGenerator>(
483             "tf_data", name, &seed_generator,
484             [dataset_seed,
485              dataset_seed2](RandomSeedGenerator** seed_generator) {
486               // On the first iterator creation, use the original seeds from the
487               // dataset to seed a `RandomSeedGenerator` that will provide seeds
488               // for subsequent repetitions of the same dataset.
489               *seed_generator =
490                   new RandomSeedGenerator(dataset_seed, dataset_seed2);
491               return Status::OK();
492             }));
493         // Now use the seed generator to update the base class Iterator seeds
494         // and random number generator with generated seeds for the current
495         // repetition.
496         mutex_lock l(mu_);
497         seed_generator->GenerateRandomSeeds(&seed_, &seed2_);
498         ResetRngs();
499         seed_generator_ = seed_generator;
500         return Status::OK();
501       }
502 
503      protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const504       std::shared_ptr<model::Node> CreateNode(
505           IteratorContext* ctx, model::Node::Args args) const override {
506         return model::MakeKnownRatioNode(std::move(args),
507                                          /*ratio=*/1);
508       }
509 
SaveInternal(IteratorStateWriter * writer)510       Status SaveInternal(IteratorStateWriter* writer) override {
511         // Save RNG state of Dataset.
512         TF_RETURN_IF_ERROR(
513             writer->WriteScalar(full_name("ds_num_random_samples"),
514                                 seed_generator_->num_random_samples()));
515 
516         // Save the Iterator.
517         return ShuffleDatasetBase::Iterator<ReshufflingDataset>::SaveInternal(
518             writer);
519       }
520 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)521       Status RestoreInternal(IteratorContext* ctx,
522                              IteratorStateReader* reader) override {
523         // Restore RNG state of Dataset.
524         int64 num_random_samples;
525         TF_RETURN_IF_ERROR(reader->ReadScalar(
526             full_name("ds_num_random_samples"), &num_random_samples));
527         seed_generator_->set_num_random_samples(num_random_samples);
528         seed_generator_->Reset();
529 
530         // Restore the Iterator.
531         return ShuffleDatasetBase::Iterator<
532             ReshufflingDataset>::RestoreInternal(ctx, reader);
533       }
534 
535      private:
536       RandomSeedGenerator* seed_generator_;
537     };
538 
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const539     Status AsGraphDefInternal(SerializationContext* ctx,
540                               DatasetGraphDefBuilder* b,
541                               Node** output) const override {
542       Node* input_graph_node = nullptr;
543       TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
544       Node* buffer_size = nullptr;
545       Node* seed = nullptr;
546       Node* seed2 = nullptr;
547       AttrValue reshuffle_each_iteration;
548 
549       TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
550       TF_RETURN_IF_ERROR(b->AddScalar(seed_, &seed));
551       TF_RETURN_IF_ERROR(b->AddScalar(seed2_, &seed2));
552       b->BuildAttrValue(true, &reshuffle_each_iteration);
553       TF_RETURN_IF_ERROR(b->AddDataset(
554           this, {input_graph_node, buffer_size, seed, seed2},  // Inputs
555           {std::make_pair("reshuffle_each_iteration",
556                           reshuffle_each_iteration)},  // Attrs
557           output));
558       return Status::OK();
559     }
560 
561    private:
562     const int64 seed_;
563     const int64 seed2_;
564   };
565 
566   // A dataset that uses the same fixed seed for all iterators created from it.
567   // Used when `reshuffle_each_iteration` is false.
568   class FixedSeedDataset : public ShuffleDatasetBase {
569    public:
FixedSeedDataset(OpKernelContext * ctx,const DatasetBase * input,int64 buffer_size,int64 seed,int64 seed2,int64 count)570     FixedSeedDataset(OpKernelContext* ctx, const DatasetBase* input,
571                      int64 buffer_size, int64 seed, int64 seed2, int64 count)
572         : ShuffleDatasetBase(ctx, input, buffer_size, count),
573           seed_(seed),
574           seed2_(seed2) {}
575 
DebugString() const576     string DebugString() const override {
577       return strings::StrCat("ShuffleDatasetOp(", buffer_size_, ", ", seed_,
578                              ", ", seed2_, ")::FixedSeedDataset");
579     }
580 
MakeIteratorInternal(const string & prefix) const581     std::unique_ptr<IteratorBase> MakeIteratorInternal(
582         const string& prefix) const override {
583       return absl::make_unique<
584           ShuffleDatasetBase::Iterator<ShuffleDatasetBase>>(
585           ShuffleDatasetBase::Iterator<ShuffleDatasetBase>::Params{
586               this, strings::StrCat(prefix, "::Shuffle")},
587           seed_, seed2_);
588     }
589 
590    protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const591     Status AsGraphDefInternal(SerializationContext* ctx,
592                               DatasetGraphDefBuilder* b,
593                               Node** output) const override {
594       Node* input_graph_node = nullptr;
595       TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
596       Node* buffer_size = nullptr;
597       Node* seed = nullptr;
598       Node* seed2 = nullptr;
599       AttrValue reshuffle_each_iteration;
600 
601       TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
602       TF_RETURN_IF_ERROR(b->AddScalar(seed_, &seed));
603       TF_RETURN_IF_ERROR(b->AddScalar(seed2_, &seed2));
604       b->BuildAttrValue(false, &reshuffle_each_iteration);
605       TF_RETURN_IF_ERROR(b->AddDataset(
606           this, {input_graph_node, buffer_size, seed, seed2},  // Inputs
607           {std::make_pair("reshuffle_each_iteration",
608                           reshuffle_each_iteration)},  // Attrs
609           output));
610       return Status::OK();
611     }
612 
613    private:
614     const int64 seed_;
615     const int64 seed2_;
616   };
617 
618   bool reshuffle_each_iteration_;
619 };
620 
621 class ShuffleAndRepeatDatasetOp : public ShuffleDatasetOpBase {
622  public:
ShuffleAndRepeatDatasetOp(OpKernelConstruction * ctx)623   explicit ShuffleAndRepeatDatasetOp(OpKernelConstruction* ctx)
624       : ShuffleDatasetOpBase(ctx) {}
625 
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)626   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
627                    DatasetBase** output) override {
628     int64 buffer_size;
629     OP_REQUIRES_OK(
630         ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size));
631     OP_REQUIRES(
632         ctx, buffer_size > 0,
633         errors::InvalidArgument("buffer_size must be greater than zero."));
634 
635     int64 seed;
636     OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "seed", &seed));
637 
638     int64 seed2;
639     OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "seed2", &seed2));
640 
641     int64 count;
642     OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count));
643 
644     // By TensorFlow convention, if both seeds are 0, then shuffling should be
645     // seeded non-deterministically.
646     if (seed == 0 && seed2 == 0) {
647       seed = random::New64();
648       seed2 = random::New64();
649     }
650 
651     *output = new Dataset(ctx, input, buffer_size, seed, seed2, count);
652   }
653 
654  private:
655   class Dataset : public ShuffleDatasetBase {
656    public:
Dataset(OpKernelContext * ctx,const DatasetBase * input,int64 buffer_size,int64 seed,int64 seed2,int64 count)657     Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size,
658             int64 seed, int64 seed2, int64 count)
659         : ShuffleDatasetBase(ctx, input, buffer_size, count),
660           seed_(seed),
661           seed2_(seed2) {}
662 
DebugString() const663     string DebugString() const override {
664       return strings::StrCat("ShuffleAndRepeatDatasetOp(", buffer_size_, ", ",
665                              seed_, ", ", seed2_, ", ", count_, ")::Dataset");
666     }
667 
MakeIteratorInternal(const string & prefix) const668     std::unique_ptr<IteratorBase> MakeIteratorInternal(
669         const string& prefix) const override {
670       return absl::make_unique<
671           ShuffleDatasetBase::Iterator<ShuffleDatasetBase>>(
672           ShuffleDatasetBase::Iterator<ShuffleDatasetBase>::Params{
673               this, strings::StrCat(prefix, "::ShuffleAndRepeat")},
674           seed_, seed2_);
675     }
676 
677    protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const678     Status AsGraphDefInternal(SerializationContext* ctx,
679                               DatasetGraphDefBuilder* b,
680                               Node** output) const override {
681       Node* input_graph_node = nullptr;
682       TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
683       Node* buffer_size = nullptr;
684       Node* seed = nullptr;
685       Node* seed2 = nullptr;
686       Node* count = nullptr;
687 
688       TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
689       TF_RETURN_IF_ERROR(b->AddScalar(seed_, &seed));
690       TF_RETURN_IF_ERROR(b->AddScalar(seed2_, &seed2));
691       TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));
692       TF_RETURN_IF_ERROR(b->AddDataset(
693           this, {input_graph_node, buffer_size, seed, seed2, count},  // Inputs
694           {},                                                         // Attrs
695           output));
696       return Status::OK();
697     }
698 
699    private:
700     const int64 seed_;
701     const int64 seed2_;
702   };
703 };
704 
705 REGISTER_KERNEL_BUILDER(Name("ShuffleDataset").Device(DEVICE_CPU),
706                         ShuffleDatasetOp);
707 
708 REGISTER_KERNEL_BUILDER(Name("ShuffleAndRepeatDataset").Device(DEVICE_CPU),
709                         ShuffleAndRepeatDatasetOp);
710 
711 }  // namespace
712 }  // namespace data
713 }  // namespace tensorflow
714