• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/shuffle_dataset_op.h"
16 
17 #include <deque>
18 #include <tuple>
19 #include <vector>
20 
21 #include "tensorflow/core/framework/dataset.h"
22 #include "tensorflow/core/framework/partial_tensor_shape.h"
23 #include "tensorflow/core/framework/resource_mgr.h"
24 #include "tensorflow/core/framework/tensor.h"
25 #include "tensorflow/core/kernels/data/dataset_utils.h"
26 #include "tensorflow/core/kernels/data/name_utils.h"
27 #include "tensorflow/core/kernels/data/random_seed_ops.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/lib/random/philox_random.h"
30 #include "tensorflow/core/lib/random/random.h"
31 #include "tensorflow/core/lib/random/random_distributions.h"
32 #include "tensorflow/core/platform/errors.h"
33 #include "tensorflow/core/platform/stringprintf.h"
34 
35 namespace tensorflow {
36 namespace data {
37 
38 // See documentation in ../../ops/dataset_ops.cc for a high-level
39 // description of the following op.
40 
41 /* static */ constexpr const char* const ShuffleDatasetOpBase::kInputDataset;
42 /* static */ constexpr const char* const ShuffleDatasetOpBase::kBufferSize;
43 /* static */ constexpr const char* const ShuffleDatasetOpBase::kSeed;
44 /* static */ constexpr const char* const ShuffleDatasetOpBase::kSeed2;
45 /* static */ constexpr const char* const ShuffleDatasetOpBase::kOutputTypes;
46 /* static */ constexpr const char* const ShuffleDatasetOpBase::kOutputShapes;
47 /* static */ constexpr const char* const
48     ShuffleDatasetOpBase::kReshuffleEachIteration;
49 
50 /* static */ constexpr const char* const ShuffleDatasetOp::kDatasetType;
51 
52 /* static */ constexpr const char* const
53     ShuffleAndRepeatDatasetOp::kDatasetType;
54 /* static */ constexpr const char* const ShuffleAndRepeatDatasetOp::kCount;
55 
56 const int64 kLogIntervalMicros = 10 * 1000000;  // 10 seconds.
57 const int64 kMaxEpochsInBuffer = 3;
58 
59 constexpr char kNumRandomSamples[] = "num_random_samples";
60 constexpr char kDataProduced[] = "data_produced";
61 constexpr char kEndOfInputSequence[] = "end_of_input_sequence";
62 constexpr char kEpoch[] = "epoch";
63 constexpr char kNumElements[] = "num_elements";
64 constexpr char kSlicesSize[] = "slices_size";
65 constexpr char kSlicesStart[] = "slices_start";
66 constexpr char kSlicesEnd[] = "slices_end";
67 constexpr char kBuffer[] = "buffer";
68 constexpr char kSize[] = "size";
69 constexpr char kSeedGenerator[] = "SeedGenerator";
70 constexpr char kTFData[] = "tf_data";
71 constexpr char kEpochNumRandomSamples[] = "epoch_num_random_samples";
72 constexpr char kShuffleDatasetV1[] = "ShuffleDataset";
73 constexpr char kShuffleDatasetV2[] = "ShuffleDatasetV2";
74 constexpr char kShuffleDatasetV3[] = "ShuffleDatasetV3";
75 constexpr char kShuffleAndRepeatDatasetV1[] = "ShuffleAndRepeatDataset";
76 constexpr char kShuffleAndRepeatDatasetV2[] = "ShuffleAndRepeatDatasetV2";
77 
ShuffleDatasetOpBase(OpKernelConstruction * ctx)78 ShuffleDatasetOpBase::ShuffleDatasetOpBase(OpKernelConstruction* ctx)
79     : UnaryDatasetOpKernel(ctx) {}
80 
81 // Abstract base dataset that implements a shuffling iterator.
82 class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
83  public:
ShuffleDatasetBase(OpKernelContext * ctx,const DatasetBase * input,int64 buffer_size,std::shared_ptr<SeedGenerator> seed_generator,int64 count)84   ShuffleDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
85                      int64 buffer_size,
86                      std::shared_ptr<SeedGenerator> seed_generator, int64 count)
87       : DatasetBase(DatasetContext(ctx)),
88         input_(input),
89         buffer_size_(buffer_size),
90         seed_generator_(std::move(seed_generator)),
91         count_(count),
92         traceme_metadata_(
93             {{"buffer_size",
94               strings::Printf("%lld", static_cast<long long>(buffer_size))}}) {
95     input_->Ref();
96   }
97 
~ShuffleDatasetBase()98   ~ShuffleDatasetBase() override { input_->Unref(); }
99 
100   virtual string op_type() const = 0;
101 
output_dtypes() const102   const DataTypeVector& output_dtypes() const override {
103     return input_->output_dtypes();
104   }
105 
output_shapes() const106   const std::vector<PartialTensorShape>& output_shapes() const override {
107     return input_->output_shapes();
108   }
109 
Cardinality() const110   int64 Cardinality() const override {
111     if (count_ == -1 || input_->Cardinality() == kInfiniteCardinality) {
112       return kInfiniteCardinality;
113     } else if (input_->Cardinality() == kUnknownCardinality) {
114       return kUnknownCardinality;
115     } else {
116       return input_->Cardinality() * count_;
117     }
118   }
119 
InputDatasets(std::vector<const DatasetBase * > * inputs) const120   Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
121     inputs->push_back(input_);
122     return Status::OK();
123   }
124 
CheckExternalState() const125   Status CheckExternalState() const override {
126     return input_->CheckExternalState();
127   }
128 
DebugString() const129   string DebugString() const override {
130     name_utils::DatasetDebugStringParams params;
131     params.set_args(buffer_size_, seed_generator_->seed(),
132                     seed_generator_->seed2(), count_);
133     return name_utils::DatasetDebugString(op_type(), params);
134   }
135 
MakeIteratorInternal(const string & prefix) const136   std::unique_ptr<IteratorBase> MakeIteratorInternal(
137       const string& prefix) const override {
138     return absl::make_unique<Iterator>(
139         Iterator::Params{this, name_utils::IteratorPrefix(op_type(), prefix)},
140         seed_generator_.get());
141   }
142 
143  protected:
144   class Iterator : public DatasetIterator<ShuffleDatasetBase> {
145    public:
Iterator(const Params & params,SeedGenerator * seed_generator)146     explicit Iterator(const Params& params, SeedGenerator* seed_generator)
147         : DatasetIterator<ShuffleDatasetBase>(params),
148           seed_generator_(seed_generator),
149           parent_generator_(seed_generator->seed(), seed_generator->seed2()),
150           generator_(&parent_generator_) {
151       buffer_ = absl::make_unique<std::vector<std::vector<Tensor>>>(
152           params.dataset->buffer_size_);
153       slices_.push_back(absl::make_unique<Slice>(0, 0));
154     }
155 
Initialize(IteratorContext * ctx)156     Status Initialize(IteratorContext* ctx) override {
157       mutex_lock l(mu_);
158       seed_generator_->GenerateSeeds(&seed_, &seed2_);
159       ResetRngs();
160       return Status::OK();
161     }
162 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)163     Status GetNextInternal(IteratorContext* ctx,
164                            std::vector<Tensor>* out_tensors,
165                            bool* end_of_sequence) override {
166       mutex_lock l(mu_);
167       int64 start_micros = EnvTime::NowMicros();
168       int64 num_log_entries = 0;
169       if (!input_impl_ && epoch_ == 0) {
170         TF_RETURN_IF_ERROR(this->dataset()->input_->MakeIterator(
171             ctx, this, this->prefix(), &input_impl_));
172       }
173       while (input_impl_ && num_elements_ < this->dataset()->buffer_size_) {
174         if (EnvTime::NowMicros() >
175             ((num_log_entries + 1) * kLogIntervalMicros) + start_micros) {
176           num_log_entries++;
177           LOG(INFO) << "Filling up shuffle buffer (this may take a while): "
178                     << num_elements_ << " of " << this->dataset()->buffer_size_;
179         }
180         std::vector<Tensor> input_element;
181         bool end_of_input_sequence = false;
182         while (this->dataset()->count_ == -1 ||
183                epoch_ < this->dataset()->count_) {
184           TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &input_element,
185                                                   &end_of_input_sequence));
186           if (!end_of_input_sequence) {
187             data_produced_ = true;
188             break;
189           }
190           if (ctx->split_provider() == nullptr && !data_produced_ &&
191               this->dataset()->count_ == -1) {
192             // If we encounter the end of sequence without producing data, we
193             // terminate the iteration immediately. (Otherwise, this iterator
194             // would loop infinitely and never produce a value.)
195             *end_of_sequence = true;
196             return Status::OK();
197           }
198           epoch_++;
199           int64 n = slices_.back()->end;
200           slices_.push_back(absl::make_unique<Slice>(n, n));
201           if (ctx->split_provider()) {
202             TF_RETURN_IF_ERROR(ctx->split_provider()->Reset());
203           }
204           TF_RETURN_IF_ERROR(this->dataset()->input_->MakeIterator(
205               ctx, this, this->prefix(), &input_impl_));
206         }
207         if (!end_of_input_sequence) {
208           if (num_elements_ == 0) {
209             VLOG(1) << "Starting to fill up shuffle buffer of size: "
210                     << this->dataset()->buffer_size_;
211           }
212           this->RecordBufferEnqueue(ctx, input_element);
213           buffer_->at(slices_.back()->end % this->dataset()->buffer_size_) =
214               std::move(input_element);
215           num_elements_++;
216           slices_.back()->end++;
217         } else {
218           input_impl_.reset();
219         }
220         if (slices_.size() > kMaxEpochsInBuffer) {
221           // When the elements stored in `buffer_` span more than
222           // `kMaxEpochsInBuffer` epochs, we do not fill the buffer further to
223           // conserve memory. This means that the upper bound on the size of
224           // `buffer_` is `kMaxEpochsInBuffer * cardinality(input_dataset) +
225           // 1`.
226           break;
227         }
228       }
229       if (num_log_entries > 0) {
230         LOG(INFO) << "Shuffle buffer filled.";
231       }
232 
233       if (num_elements_ > 0) {
234         *end_of_sequence = false;
235         // Garbage collect all empty slices.
236         while (!slices_.empty() &&
237                slices_.front()->start == slices_.front()->end) {
238           slices_.pop_front();
239           // Reinitialize the RNG state for the next epoch.
240           num_random_samples_ = 0;
241           seed_generator_->GenerateSeeds(&seed_, &seed2_);
242           ResetRngs();
243         }
244         DCHECK(!slices_.empty());
245         // Choose an element to produce uniformly at random from the first
246         // slice, and then remove the element from the slice.
247         int64 offset =
248             Random() % (slices_.front()->end - slices_.front()->start);
249         int64 index =
250             (slices_.front()->start + offset) % this->dataset()->buffer_size_;
251         *out_tensors = std::move(buffer_->at(index));
252         this->RecordBufferDequeue(ctx, *out_tensors);
253         std::swap(buffer_->at(index),
254                   buffer_->at(slices_.front()->start %
255                               this->dataset()->buffer_size_));
256         slices_.front()->start++;
257         num_elements_--;
258       } else {
259         DCHECK(input_impl_ == nullptr);
260         *end_of_sequence = true;
261       }
262       return Status::OK();
263     }
264 
265    protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const266     std::shared_ptr<model::Node> CreateNode(
267         IteratorContext* ctx, model::Node::Args args) const override {
268       return model::MakeKnownRatioNode(std::move(args),
269                                        /*ratio=*/1);
270     }
271 
ResetRngs()272     void ResetRngs() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
273       // Reset the generators based on the current iterator seeds.
274       parent_generator_ = random::PhiloxRandom(seed_, seed2_);
275       generator_ =
276           random::SingleSampleAdapter<random::PhiloxRandom>(&parent_generator_);
277       generator_.Skip(num_random_samples_);
278     }
279 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)280     Status SaveInternal(SerializationContext* ctx,
281                         IteratorStateWriter* writer) override {
282       mutex_lock l(mu_);
283       // Save state needed to restore the random number generators.
284       TF_RETURN_IF_ERROR(
285           writer->WriteScalar(full_name(kEpochNumRandomSamples),
286                               seed_generator_->num_random_samples()));
287       TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name(kNumRandomSamples),
288                                              num_random_samples_));
289       TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name(kSeed), seed_));
290       TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name(kSeed2), seed2_));
291 
292       // Save input iterator if it hasn't been exhausted else write
293       // "end_of_input_sequence".
294       if (!input_impl_) {
295         TF_RETURN_IF_ERROR(
296             writer->WriteScalar(this->full_name(kEndOfInputSequence), ""));
297       } else {
298         TF_RETURN_IF_ERROR(this->SaveInput(ctx, writer, input_impl_));
299       }
300 
301       // Save the epoch counter, buffer, and buffer slices.
302       TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name(kEpoch), epoch_));
303       TF_RETURN_IF_ERROR(
304           writer->WriteScalar(this->full_name(kNumElements), num_elements_));
305       TF_RETURN_IF_ERROR(WriteElementsToCheckpoint(writer, prefix(), *buffer_));
306       TF_RETURN_IF_ERROR(
307           writer->WriteScalar(this->full_name(kSlicesSize), slices_.size()));
308       for (size_t i = 0; i < slices_.size(); ++i) {
309         TF_RETURN_IF_ERROR(
310             writer->WriteScalar(this->full_name(absl::StrJoin(
311                                     std::make_tuple(kSlicesStart, i), "_")),
312                                 slices_[i]->start));
313         TF_RETURN_IF_ERROR(writer->WriteScalar(
314             this->full_name(absl::StrJoin(std::make_tuple(kSlicesEnd, i), "_")),
315             slices_[i]->end));
316       }
317       if (data_produced_) {
318         TF_RETURN_IF_ERROR(
319             writer->WriteScalar(this->full_name(kDataProduced), ""));
320       }
321 
322       return Status::OK();
323     }
324 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)325     Status RestoreInternal(IteratorContext* ctx,
326                            IteratorStateReader* reader) override {
327       mutex_lock l(mu_);
328       // Restore the random number generators.
329       int64 num_random_samples;
330       TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kEpochNumRandomSamples),
331                                             &num_random_samples));
332       seed_generator_->set_num_random_samples(num_random_samples);
333       seed_generator_->Reset();
334       TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name(kNumRandomSamples),
335                                             &num_random_samples_));
336       TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name(kSeed), &seed_));
337       TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name(kSeed2), &seed2_));
338       ResetRngs();
339 
340       // Restore the input iterator if it wasn't already exhausted.
341       if (!reader->Contains(this->full_name(kEndOfInputSequence))) {
342         TF_RETURN_IF_ERROR(this->dataset()->input_->MakeIterator(
343             ctx, this, this->prefix(), &input_impl_));
344         TF_RETURN_IF_ERROR(this->RestoreInput(ctx, reader, input_impl_));
345       } else {
346         input_impl_.reset();
347       }
348 
349       // Restore the epoch counter, buffer, and buffer slices.
350       TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name(kEpoch), &epoch_));
351       TF_RETURN_IF_ERROR(
352           reader->ReadScalar(this->full_name(kNumElements), &num_elements_));
353       size_t slices_size;
354       {
355         int64 temp;
356         TF_RETURN_IF_ERROR(
357             reader->ReadScalar(this->full_name(kSlicesSize), &temp));
358         slices_size = static_cast<size_t>(temp);
359       }
360       buffer_ = absl::make_unique<std::vector<std::vector<Tensor>>>(
361           this->dataset()->buffer_size_);
362       TF_RETURN_IF_ERROR(
363           ReadElementsFromCheckpoint(reader, prefix(), buffer_.get()));
364       slices_.clear();
365       for (size_t i = 0; i < slices_size; ++i) {
366         int64 start;
367         TF_RETURN_IF_ERROR(
368             reader->ReadScalar(this->full_name(absl::StrJoin(
369                                    std::make_tuple(kSlicesStart, i), "_")),
370                                &start));
371         int64 end;
372         TF_RETURN_IF_ERROR(reader->ReadScalar(
373             this->full_name(absl::StrJoin(std::make_tuple(kSlicesEnd, i), "_")),
374             &end));
375         slices_.push_back(absl::make_unique<Slice>(start, end));
376       }
377       data_produced_ = reader->Contains(this->full_name(kDataProduced));
378 
379       return Status::OK();
380     }
381 
GetTraceMeMetadata() const382     TraceMeMetadata GetTraceMeMetadata() const override {
383       return this->dataset()->traceme_metadata_;
384     }
385 
386    private:
387     // Used to represent slices of `buffer_` that belong to different epochs.
388     // The invariant maintained by the implementation is: `start` <= `end`.
389     // When using `start` and `end` to index into `buffer_`, their values
390     // should be taken modulo the size of `buffer_` as their absolute value
391     // can be greater than the range of `buffer_`.
392     struct Slice {
Slicetensorflow::data::ShuffleDatasetOpBase::ShuffleDatasetBase::Iterator::Slice393       Slice(int64 start, int64 end) : start(start), end(end) {}
394 
395       int64 start;
396       int64 end;
397     };
398 
Random()399     random::SingleSampleAdapter<random::PhiloxRandom>::ResultType Random()
400         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
401       num_random_samples_++;
402       auto out = generator_();
403       return out;
404     }
405 
406     mutex mu_;
407     SeedGenerator* const seed_generator_ TF_GUARDED_BY(mu_);  // Not owned.
408     std::unique_ptr<std::vector<std::vector<Tensor>>> buffer_
409         TF_GUARDED_BY(mu_);
410     std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_) = nullptr;
411     int64 epoch_ TF_GUARDED_BY(mu_) = 0;
412     int64 num_elements_ TF_GUARDED_BY(mu_) = 0;
413     int64 seed_ TF_GUARDED_BY(mu_) = 0;
414     int64 seed2_ TF_GUARDED_BY(mu_) = 0;
415     // Indices into `buffer_` indicating which data belongs to which epoch.
416     // The slice at the front of the deque references data from the earliest
417     // buffered epoch. It is an invariant that all slices reference
418     // non-overlapping sections of `buffer_`.
419     std::deque<std::unique_ptr<Slice>> slices_ TF_GUARDED_BY(mu_);
420     random::PhiloxRandom parent_generator_ TF_GUARDED_BY(mu_);
421     random::SingleSampleAdapter<random::PhiloxRandom> generator_
422         TF_GUARDED_BY(mu_);
423     int64 num_random_samples_ TF_GUARDED_BY(mu_) = 0;
424     bool data_produced_ TF_GUARDED_BY(mu_) = false;
425   };
426 
427   const DatasetBase* const input_;
428   const int64 buffer_size_;
429   const std::shared_ptr<SeedGenerator> seed_generator_;
430   // The number of epochs to run for. Normally this is just 1, but sometimes we
431   // fuse shuffle and repeat together, and make the shuffle dataset op
432   // responsible for repeating as well.
433   const int64 count_;
434   const TraceMeMetadata traceme_metadata_;
435 };  // ShuffleDatasetBase
436 
437 // This version of memory dataset has an exclusive ownership of the seed
438 // generator resource. It supports sharing of the seed generator across
439 // different iterations of the `repeat` transformation but not across different
440 // iterators.
441 class ShuffleDatasetOp::Dataset : public ShuffleDatasetBase {
442  public:
Dataset(OpKernelContext * ctx,const DatasetBase * input,int64 buffer_size,int64 count,RandomSeeds && seeds,SeedGeneratorManager * manager,ResourceHandle && resource_handle)443   Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size,
444           int64 count, RandomSeeds&& seeds, SeedGeneratorManager* manager,
445           ResourceHandle&& resource_handle)
446       : ShuffleDatasetBase(ctx, input, buffer_size, manager->get(), count),
447         manager_(manager),
448         resource_handle_(std::move(resource_handle)),
449         resource_mgr_(ctx->resource_manager()),
450         seeds_(std::move(seeds)) {}
451 
~Dataset()452   ~Dataset() override {
453     manager_->Unref();
454     Status s = resource_mgr_->Delete<SeedGeneratorManager>(
455         resource_handle_.container(), resource_handle_.name());
456     if (!s.ok()) {
457       LOG(WARNING) << "Failed to delete RNG resource: " << s.ToString();
458     }
459   }
460 
op_type() const461   string op_type() const override { return kDatasetType; }
462 
463  protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const464   Status AsGraphDefInternal(SerializationContext* ctx,
465                             DatasetGraphDefBuilder* b,
466                             Node** output) const override {
467     Node* input_graph_node = nullptr;
468     TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
469     Node* buffer_size_node = nullptr;
470     Node* seed_node = nullptr;
471     Node* seed2_node = nullptr;
472     AttrValue reshuffle_each_iteration;
473 
474     TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size_node));
475     TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed(), &seed_node));
476     TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed2(), &seed2_node));
477     b->BuildAttrValue(seed_generator_->reshuffle_each_iteration(),
478                       &reshuffle_each_iteration);
479     TF_RETURN_IF_ERROR(b->AddDataset(
480         this,
481         {input_graph_node, buffer_size_node, seed_node, seed2_node},  // Inputs
482         {std::make_pair(kReshuffleEachIteration,
483                         reshuffle_each_iteration)},  // Attrs
484         output));
485     return Status::OK();
486   }
487 
488  private:
489   SeedGeneratorManager* const manager_;  // Owned.
490   const ResourceHandle resource_handle_;
491   ResourceMgr* const resource_mgr_;  // Not owned.
492   const RandomSeeds seeds_;
493 };
494 
495 // This version of shuffle dataset has a shared ownership of the seed generator
496 // resource. It supports sharing of the generator state across different
497 // iterations of the `repeat` transformation and also across different
498 // iterators.
499 class ShuffleDatasetOp::DatasetV2 : public ShuffleDatasetBase {
500  public:
DatasetV2(OpKernelContext * ctx,const DatasetBase * input,int64 buffer_size,int64 count,SeedGeneratorManager * manager,ResourceHandle && resource_handle,bool owns_resource)501   DatasetV2(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size,
502             int64 count, SeedGeneratorManager* manager,
503             ResourceHandle&& resource_handle, bool owns_resource)
504       : ShuffleDatasetBase(ctx, input, buffer_size, manager->get(), count),
505         manager_(manager),
506         owns_resource_(owns_resource),
507         resource_handle_(std::move(resource_handle)),
508         resource_mgr_(ctx->resource_manager()) {}
509 
~DatasetV2()510   ~DatasetV2() override {
511     manager_->Unref();
512     if (owns_resource_) {
513       Status s = resource_mgr_->Delete<SeedGeneratorManager>(
514           resource_handle_.container(), resource_handle_.name());
515       if (!s.ok()) {
516         LOG(WARNING) << "Failed to delete RNG resource: " << s.ToString();
517       }
518     }
519   }
520 
op_type() const521   string op_type() const override { return kDatasetType; }
522 
523  protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const524   Status AsGraphDefInternal(SerializationContext* ctx,
525                             DatasetGraphDefBuilder* b,
526                             Node** output) const override {
527     Node* input_graph_node = nullptr;
528     TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
529     Node* buffer_size_node = nullptr;
530     TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size_node));
531     Node* resource_handle_node = nullptr;
532     Tensor handle(DT_RESOURCE, TensorShape({}));
533     handle.scalar<ResourceHandle>()() = resource_handle_;
534     TF_RETURN_IF_ERROR(b->AddTensor(handle, &resource_handle_node));
535     TF_RETURN_IF_ERROR(b->AddDataset(
536         this,
537         {input_graph_node, buffer_size_node, resource_handle_node},  // Inputs
538         {},                                                          // Attrs
539         output));
540     return Status::OK();
541   }
542 
543  private:
544   SeedGeneratorManager* const manager_;  // Owned.
545   const bool owns_resource_;
546   const ResourceHandle resource_handle_;
547   ResourceMgr* const resource_mgr_;  // Not owned.
548 };
549 
550 // This version of shuffle dataset extends the functionality of DatasetV2 with
551 // the ability to preserve seed generator configuration (i.e. initial seeds and
552 // whether to reshuffle each iteration) across serialization of the dataset.
553 class ShuffleDatasetOp::DatasetV3 : public ShuffleDatasetBase {
554  public:
DatasetV3(OpKernelContext * ctx,const DatasetBase * input,int64 buffer_size,int64 count,RandomSeeds && seeds,SeedGeneratorManager * manager,ResourceHandle && resource_handle,bool owns_resource)555   DatasetV3(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size,
556             int64 count, RandomSeeds&& seeds, SeedGeneratorManager* manager,
557             ResourceHandle&& resource_handle, bool owns_resource)
558       : ShuffleDatasetBase(ctx, input, buffer_size, manager->get(), count),
559         manager_(manager),
560         owns_resource_(owns_resource),
561         resource_handle_(std::move(resource_handle)),
562         resource_mgr_(ctx->resource_manager()),
563         seeds_(std::move(seeds)) {}
564 
~DatasetV3()565   ~DatasetV3() override {
566     manager_->Unref();
567     if (owns_resource_) {
568       Status s = resource_mgr_->Delete<SeedGeneratorManager>(
569           resource_handle_.container(), resource_handle_.name());
570       if (!s.ok()) {
571         LOG(WARNING) << "Failed to delete RNG resource: " << s.ToString();
572       }
573     }
574   }
575 
op_type() const576   string op_type() const override { return kDatasetType; }
577 
578  protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const579   Status AsGraphDefInternal(SerializationContext* ctx,
580                             DatasetGraphDefBuilder* b,
581                             Node** output) const override {
582     Node* input_graph_node = nullptr;
583     TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
584     Node* buffer_size_node = nullptr;
585     Node* seed_node = nullptr;
586     Node* seed2_node = nullptr;
587     TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size_node));
588     TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed(), &seed_node));
589     TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed2(), &seed2_node));
590     Node* resource_handle_node = nullptr;
591     Tensor handle(DT_RESOURCE, TensorShape({}));
592     handle.scalar<ResourceHandle>()() = resource_handle_;
593     TF_RETURN_IF_ERROR(b->AddTensor(handle, &resource_handle_node));
594     AttrValue reshuffle_each_iteration;
595     b->BuildAttrValue(seed_generator_->reshuffle_each_iteration(),
596                       &reshuffle_each_iteration);
597     TF_RETURN_IF_ERROR(
598         b->AddDataset(this,
599                       {input_graph_node, buffer_size_node, seed_node,
600                        seed2_node, resource_handle_node},  // Inputs
601                       {std::make_pair(kReshuffleEachIteration,
602                                       reshuffle_each_iteration)},  // Attrs
603                       output));
604     return Status::OK();
605   }
606 
607  private:
608   SeedGeneratorManager* const manager_;  // Owned
609   const bool owns_resource_;
610   const ResourceHandle resource_handle_;
611   ResourceMgr* const resource_mgr_;  // Not owned.
612   const RandomSeeds seeds_;
613 };
614 
ShuffleDatasetOp(OpKernelConstruction * ctx)615 ShuffleDatasetOp::ShuffleDatasetOp(OpKernelConstruction* ctx)
616     : ShuffleDatasetOpBase(ctx) {
617   auto& op_name = ctx->def().op();
618   if (op_name == kShuffleDatasetV3) {
619     op_version_ = 3;
620   } else if (op_name == kShuffleDatasetV2) {
621     op_version_ = 2;
622   } else if (op_name == kShuffleDatasetV1) {
623     op_version_ = 1;
624   }
625   if (ctx->HasAttr(kReshuffleEachIteration)) {
626     OP_REQUIRES_OK(
627         ctx, ctx->GetAttr(kReshuffleEachIteration, &reshuffle_each_iteration_));
628   }
629 }
630 
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)631 void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
632                                    DatasetBase** output) {
633   int64 buffer_size = 0;
634   OP_REQUIRES_OK(ctx,
635                  ParseScalarArgument<int64>(ctx, kBufferSize, &buffer_size));
636   OP_REQUIRES(
637       ctx, buffer_size > 0,
638       errors::InvalidArgument("buffer_size must be greater than zero."));
639 
640   int64 count = 1;
641   static std::atomic<int64> resource_id_counter(0);
642   const string& container = ctx->resource_manager()->default_container();
643   auto name = strings::StrCat(ctx->op_kernel().name(), "/", kSeedGenerator, "_",
644                               resource_id_counter.fetch_add(1));
645   if (op_version_ == 3) {
646     auto handle = HandleFromInput(ctx, 4);
647     SeedGeneratorManager* manager = nullptr;
648     Status s = ctx->resource_manager()->Lookup<SeedGeneratorManager>(
649         handle.container(), handle.name(), &manager);
650     int64 seed;
651     OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed, &seed));
652     int64 seed2;
653     OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed2, &seed2));
654     RandomSeeds seeds(seed, seed2);
655     bool owns_resource = false;
656     if (errors::IsNotFound(s)) {
657       owns_resource = true;
658       OP_REQUIRES_OK(
659           ctx,
660           ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
661               container, name, &manager,
662               [reshuffle = reshuffle_each_iteration_,
663                &seeds](SeedGeneratorManager** manager) {
664                 if (reshuffle) {
665                   *manager =
666                       new SeedGeneratorManager(new RandomSeedGenerator(seeds));
667                 } else {
668                   *manager =
669                       new SeedGeneratorManager(new FixedSeedGenerator(seeds));
670                 }
671                 return Status::OK();
672               }));
673       handle = MakeResourceHandle<SeedGenerator>(ctx, container, name);
674     } else {
675       OP_REQUIRES_OK(ctx, s);
676     }
677 
678     // Ownership of manager is transferred onto `DatasetV3`.
679     *output = new ShuffleDatasetOp::DatasetV3(ctx, input, buffer_size, count,
680                                               std::move(seeds), manager,
681                                               std::move(handle), owns_resource);
682   } else if (op_version_ == 2) {
683     auto handle = HandleFromInput(ctx, 2);
684     SeedGeneratorManager* manager = nullptr;
685     Status s = ctx->resource_manager()->Lookup<SeedGeneratorManager>(
686         handle.container(), handle.name(), &manager);
687     bool owns_resource = false;
688     if (errors::IsNotFound(s)) {
689       owns_resource = true;
690       LOG(WARNING) << "Failed to find seed generator resource. Falling back to "
691                       "using a non-deterministically seeded generator and "
692                       "reshuffling each iteration.";
693       RandomSeeds seeds(0, 0);
694       OP_REQUIRES_OK(
695           ctx, ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
696                    container, name, &manager,
697                    [&seeds](SeedGeneratorManager** manager) {
698                      *manager = new SeedGeneratorManager(
699                          new RandomSeedGenerator(seeds));
700                      return Status::OK();
701                    }));
702       handle = MakeResourceHandle<SeedGeneratorManager>(ctx, container, name);
703     } else {
704       OP_REQUIRES_OK(ctx, s);
705     }
706 
707     // Ownership of manager is transferred onto `DatasetV2`.
708     *output =
709         new ShuffleDatasetOp::DatasetV2(ctx, input, buffer_size, count, manager,
710                                         std::move(handle), owns_resource);
711   } else {
712     if (op_version_ != 1) {
713       LOG(WARNING) << "Unsupported version of shuffle dataset op: "
714                    << op_version_ << ". Defaulting to version 1.";
715     }
716     int64 seed;
717     OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed, &seed));
718     int64 seed2;
719     OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed2, &seed2));
720     RandomSeeds seeds(seed, seed2);
721     SeedGeneratorManager* manager;
722     OP_REQUIRES_OK(
723         ctx,
724         ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
725             container, name, &manager,
726             [reshuffle = reshuffle_each_iteration_,
727              &seeds](SeedGeneratorManager** manager) {
728               if (reshuffle) {
729                 *manager =
730                     new SeedGeneratorManager(new RandomSeedGenerator(seeds));
731               } else {
732                 *manager =
733                     new SeedGeneratorManager(new FixedSeedGenerator(seeds));
734               }
735               return Status::OK();
736             }));
737     auto handle =
738         MakeResourceHandle<SeedGeneratorManager>(ctx, container, name);
739 
740     // Ownership of manager is transferred onto `Dataset`.
741     *output = new ShuffleDatasetOp::Dataset(ctx, input, buffer_size, count,
742                                             std::move(seeds), manager,
743                                             std::move(handle));
744   }
745 }
746 
747 class ShuffleAndRepeatDatasetOp::Dataset : public ShuffleDatasetBase {
748  public:
Dataset(OpKernelContext * ctx,const DatasetBase * input,int64 buffer_size,RandomSeeds && seeds,SeedGeneratorManager * manager,int64 count,ResourceHandle && resource_handle)749   Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size,
750           RandomSeeds&& seeds, SeedGeneratorManager* manager, int64 count,
751           ResourceHandle&& resource_handle)
752       : ShuffleDatasetBase(ctx, input, buffer_size, manager->get(), count),
753         manager_(manager),
754         resource_handle_(std::move(resource_handle)),
755         resource_mgr_(ctx->resource_manager()),
756         seeds_(std::move(seeds)) {}
757 
~Dataset()758   ~Dataset() override {
759     manager_->Unref();
760     Status s = resource_mgr_->Delete<SeedGeneratorManager>(
761         resource_handle_.container(), resource_handle_.name());
762     if (!s.ok()) {
763       LOG(WARNING) << "Failed to delete RNG resource: " << s.ToString();
764     }
765   }
766 
op_type() const767   string op_type() const override { return kDatasetType; }
768 
769  protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const770   Status AsGraphDefInternal(SerializationContext* ctx,
771                             DatasetGraphDefBuilder* b,
772                             Node** output) const override {
773     Node* input_graph_node = nullptr;
774     TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
775     Node* buffer_size = nullptr;
776     Node* seed = nullptr;
777     Node* seed2 = nullptr;
778     Node* count = nullptr;
779 
780     TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
781     TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed(), &seed));
782     TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed2(), &seed2));
783     TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));
784     AttrValue reshuffle_each_iteration;
785     b->BuildAttrValue(seed_generator_->reshuffle_each_iteration(),
786                       &reshuffle_each_iteration);
787     TF_RETURN_IF_ERROR(b->AddDataset(
788         this, {input_graph_node, buffer_size, seed, seed2, count},  // Inputs
789         {std::make_pair(kReshuffleEachIteration,
790                         reshuffle_each_iteration)},  // Attrs
791         output));
792     return Status::OK();
793   }
794 
795  private:
796   SeedGeneratorManager* const manager_;  // Owned.
797   const ResourceHandle resource_handle_;
798   ResourceMgr* const resource_mgr_;  // Not owned.
799   const RandomSeeds seeds_;
800 };
801 
802 class ShuffleAndRepeatDatasetOp::DatasetV2 : public ShuffleDatasetBase {
803  public:
DatasetV2(OpKernelContext * ctx,const DatasetBase * input,int64 buffer_size,int64 count,RandomSeeds && seeds,SeedGeneratorManager * manager,ResourceHandle && resource_handle,bool owns_resource)804   DatasetV2(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size,
805             int64 count, RandomSeeds&& seeds, SeedGeneratorManager* manager,
806             ResourceHandle&& resource_handle, bool owns_resource)
807       : ShuffleDatasetBase(ctx, input, buffer_size, manager->get(), count),
808         manager_(manager),
809         owns_resource_(owns_resource),
810         resource_handle_(std::move(resource_handle)),
811         resource_mgr_(ctx->resource_manager()),
812         seeds_(std::move(seeds)) {}
813 
~DatasetV2()814   ~DatasetV2() override {
815     manager_->Unref();
816     if (owns_resource_) {
817       Status s = resource_mgr_->Delete<SeedGeneratorManager>(
818           resource_handle_.container(), resource_handle_.name());
819       if (!s.ok()) {
820         LOG(WARNING) << "Failed to delete RNG resource: " << s.ToString();
821       }
822     }
823   }
824 
op_type() const825   string op_type() const override { return kDatasetType; }
826 
827  protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const828   Status AsGraphDefInternal(SerializationContext* ctx,
829                             DatasetGraphDefBuilder* b,
830                             Node** output) const override {
831     Node* input_graph_node = nullptr;
832     TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
833     Node* buffer_size_node = nullptr;
834     Node* seed_node = nullptr;
835     Node* seed2_node = nullptr;
836     Node* count_node = nullptr;
837     TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size_node));
838     TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed(), &seed_node));
839     TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed2(), &seed2_node));
840     TF_RETURN_IF_ERROR(b->AddScalar(count_, &count_node));
841     Node* resource_handle_node = nullptr;
842     Tensor handle(DT_RESOURCE, TensorShape({}));
843     handle.scalar<ResourceHandle>()() = resource_handle_;
844     TF_RETURN_IF_ERROR(b->AddTensor(handle, &resource_handle_node));
845     AttrValue reshuffle_each_iteration;
846     b->BuildAttrValue(seed_generator_->reshuffle_each_iteration(),
847                       &reshuffle_each_iteration);
848     TF_RETURN_IF_ERROR(
849         b->AddDataset(this,
850                       {input_graph_node, buffer_size_node, seed_node,
851                        seed2_node, count_node, resource_handle_node},  // Inputs
852                       {std::make_pair(kReshuffleEachIteration,
853                                       reshuffle_each_iteration)},  // Attrs
854                       output));
855     return Status::OK();
856   }
857 
858  private:
859   SeedGeneratorManager* const manager_;  // Owned
860   const bool owns_resource_;
861   const ResourceHandle resource_handle_;
862   ResourceMgr* const resource_mgr_;  // Not owned.
863   const RandomSeeds seeds_;
864 };
865 
ShuffleAndRepeatDatasetOp(OpKernelConstruction * ctx)866 ShuffleAndRepeatDatasetOp::ShuffleAndRepeatDatasetOp(OpKernelConstruction* ctx)
867     : ShuffleDatasetOpBase(ctx) {
868   auto& op_name = ctx->def().op();
869   if (op_name == kShuffleAndRepeatDatasetV2) {
870     op_version_ = 2;
871   } else if (op_name == kShuffleAndRepeatDatasetV1) {
872     op_version_ = 1;
873   }
874   if (ctx->HasAttr(kReshuffleEachIteration)) {
875     OP_REQUIRES_OK(
876         ctx, ctx->GetAttr(kReshuffleEachIteration, &reshuffle_each_iteration_));
877   }
878 }
879 
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)880 void ShuffleAndRepeatDatasetOp::MakeDataset(OpKernelContext* ctx,
881                                             DatasetBase* input,
882                                             DatasetBase** output) {
883   int64 buffer_size = 0;
884   OP_REQUIRES_OK(ctx,
885                  ParseScalarArgument<int64>(ctx, kBufferSize, &buffer_size));
886   OP_REQUIRES(
887       ctx, buffer_size > 0,
888       errors::InvalidArgument("buffer_size must be greater than zero."));
889 
890   int64 seed;
891   OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed, &seed));
892 
893   int64 seed2;
894   OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed2, &seed2));
895 
896   int64 count;
897   OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kCount, &count));
898 
899   OP_REQUIRES(ctx, count > 0 || count == -1,
900               errors::InvalidArgument(
901                   "count must be greater than zero or equal to -1."));
902 
903   RandomSeeds seeds(seed, seed2);
904 
905   static std::atomic<int64> resource_id_counter(0);
906   const string& container = ctx->resource_manager()->default_container();
907   auto name = strings::StrCat(ctx->op_kernel().name(), "/", kSeedGenerator, "_",
908                               resource_id_counter.fetch_add(1));
909   if (op_version_ == 2) {
910     auto handle = HandleFromInput(ctx, 5);
911     SeedGeneratorManager* manager = nullptr;
912     Status s = ctx->resource_manager()->Lookup<SeedGeneratorManager>(
913         handle.container(), handle.name(), &manager);
914     bool owns_resource = false;
915     if (errors::IsNotFound(s)) {
916       owns_resource = true;
917       OP_REQUIRES_OK(
918           ctx,
919           ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
920               container, name, &manager,
921               [reshuffle = reshuffle_each_iteration_,
922                &seeds](SeedGeneratorManager** manager) {
923                 if (reshuffle) {
924                   *manager =
925                       new SeedGeneratorManager(new RandomSeedGenerator(seeds));
926                 } else {
927                   *manager =
928                       new SeedGeneratorManager(new FixedSeedGenerator(seeds));
929                 }
930                 return Status::OK();
931               }));
932       handle = MakeResourceHandle<SeedGenerator>(ctx, container, name);
933     } else {
934       OP_REQUIRES_OK(ctx, s);
935     }
936 
937     // Ownership of manager is transferred onto `DatasetV2`.
938     *output = new ShuffleAndRepeatDatasetOp::DatasetV2(
939         ctx, input, buffer_size, count, std::move(seeds), manager,
940         std::move(handle), owns_resource);
941   } else {
942     if (op_version_ != 1) {
943       LOG(WARNING) << "Unsupported version of shuffle dataset op: "
944                    << op_version_ << ". Defaulting to version 1.";
945     }
946     SeedGeneratorManager* manager;
947     OP_REQUIRES_OK(
948         ctx,
949         ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
950             container, name, &manager,
951             [reshuffle = reshuffle_each_iteration_,
952              &seeds](SeedGeneratorManager** manager) {
953               if (reshuffle) {
954                 *manager =
955                     new SeedGeneratorManager(new RandomSeedGenerator(seeds));
956               } else {
957                 *manager =
958                     new SeedGeneratorManager(new FixedSeedGenerator(seeds));
959               }
960               return Status::OK();
961             }));
962     auto handle =
963         MakeResourceHandle<SeedGeneratorManager>(ctx, container, name);
964 
965     // Ownership of manager is transferred onto `Dataset`.
966     *output = new Dataset(ctx, input, buffer_size, std::move(seeds), manager,
967                           count, std::move(handle));
968   }
969 }
970 
971 namespace {
972 REGISTER_KERNEL_BUILDER(Name("ShuffleDataset").Device(DEVICE_CPU),
973                         ShuffleDatasetOp);
974 
975 REGISTER_KERNEL_BUILDER(Name("ShuffleDatasetV2").Device(DEVICE_CPU),
976                         ShuffleDatasetOp);
977 
978 REGISTER_KERNEL_BUILDER(Name("ShuffleDatasetV3").Device(DEVICE_CPU),
979                         ShuffleDatasetOp);
980 
981 REGISTER_KERNEL_BUILDER(Name("ShuffleAndRepeatDataset").Device(DEVICE_CPU),
982                         ShuffleAndRepeatDatasetOp);
983 
984 REGISTER_KERNEL_BUILDER(Name("ShuffleAndRepeatDatasetV2").Device(DEVICE_CPU),
985                         ShuffleAndRepeatDatasetOp);
986 }  // namespace
987 }  // namespace data
988 }  // namespace tensorflow
989