• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <deque>
17 #include <vector>
18 
19 #include "tensorflow/core/framework/partial_tensor_shape.h"
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/kernels/data/dataset.h"
22 #include "tensorflow/core/lib/random/philox_random.h"
23 #include "tensorflow/core/lib/random/random.h"
24 #include "tensorflow/core/lib/random/random_distributions.h"
25 
26 namespace tensorflow {
27 
28 namespace {
29 
// Interval between "filling up shuffle buffer" progress log messages.
const int64 kLogIntervalMicros = 10 * 1000000;  // 10 seconds.
31 
32 // See documentation in ../ops/dataset_ops.cc for a high-level
33 // description of the following op.
34 
// Shared base kernel for the shuffle-style dataset ops. Concrete subclasses
// (ShuffleDatasetOp, ShuffleAndRepeatDatasetOp) construct datasets derived
// from the nested ShuffleDatasetBase.
class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
 public:
  explicit ShuffleDatasetOpBase(OpKernelConstruction* ctx)
      : UnaryDatasetOpKernel(ctx) {}

 protected:
  // Abstract base dataset that implements a shuffling iterator.
  class ShuffleDatasetBase : public GraphDatasetBase {
   public:
    // `count` is the number of epochs to iterate over the input before
    // producing end-of-sequence; -1 means repeat indefinitely.
    ShuffleDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
                       int64 buffer_size, int64 count)
        : GraphDatasetBase(ctx),
          input_(input),
          buffer_size_(buffer_size),
          count_(count) {
      // Keep the input dataset alive for the lifetime of this dataset.
      input_->Ref();
    }

    ~ShuffleDatasetBase() override { input_->Unref(); }

    const DataTypeVector& output_dtypes() const override {
      return input_->output_dtypes();
    }

    const std::vector<PartialTensorShape>& output_shapes() const override {
      return input_->output_shapes();
    }

   protected:
    // Iterator that maintains a fixed-size ring buffer (`buffer_`) of input
    // elements and, on each GetNext() call, emits one element chosen
    // uniformly at random from the portion of the buffer belonging to the
    // oldest epoch. Elements from different epochs are tracked via `slices_`
    // so that epochs are never interleaved in the output.
    class Iterator : public DatasetIterator<ShuffleDatasetBase> {
     public:
      explicit Iterator(const Params& params, int64 seed, int64 seed2)
          : DatasetIterator<ShuffleDatasetBase>(params),
            input_impl_(nullptr),
            seed_(seed),
            seed2_(seed2),
            epoch_(0),
            num_elements_(0),
            parent_generator_(seed, seed2),
            generator_(&parent_generator_) {
        // Allocate the ring buffer eagerly; each slot holds one input
        // element (a vector of component tensors).
        buffer_.reset(new std::vector<Tensor>[params.dataset->buffer_size_]);
        // Start with a single empty slice for epoch 0.
        slices_.emplace_back(new Slice{0, 0});
      }

      // Fills the shuffle buffer (up to `buffer_size_` elements), then
      // produces one element sampled at random from the oldest non-empty
      // slice. Sets `*end_of_sequence` once all epochs are exhausted and
      // the buffer has been drained.
      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        int64 start_micros = ctx->env()->NowMicros();
        int64 num_log_entries = 0;
        bool first_call = false;
        // Lazily create the input iterator on the very first call.
        if (!input_impl_ && epoch_ == 0) {
          first_call = true;
          input_impl_ = dataset()->input_->MakeIterator(prefix());
        }
        // Fill the buffer until it is full or the input is exhausted.
        while (input_impl_ && num_elements_ < dataset()->buffer_size_) {
          // Periodically log progress, since filling a large buffer can
          // take a long time.
          if (ctx->env()->NowMicros() >
              ((num_log_entries + 1) * kLogIntervalMicros) + start_micros) {
            num_log_entries++;
            LOG(INFO) << "Filling up shuffle buffer (this may take a while): "
                      << num_elements_ << " of " << dataset()->buffer_size_;
          }
          std::vector<Tensor> input_element;
          bool end_of_input_sequence = false;
          // Pull the next element, restarting the input iterator at each
          // epoch boundary until `count_` epochs have been consumed
          // (or forever when count_ == -1).
          while (dataset()->count_ == -1 || epoch_ < dataset()->count_) {
            TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &input_element,
                                                    &end_of_input_sequence));
            if (!end_of_input_sequence) {
              first_call = false;
              break;
            }
            if (first_call && dataset()->count_ == -1) {
              // If the first call to GetNext() fails because the end
              // of sequence has been reached, we terminate the
              // iteration immediately. (Otherwise, this iterator
              // would loop infinitely and never produce a value.)
              *end_of_sequence = true;
              return Status::OK();
            }
            epoch_++;
            // Open a new (empty) slice for the next epoch, contiguous
            // with the previous one in ring-buffer coordinates.
            int64 n = slices_.back()->end;
            slices_.emplace_back(new Slice{n, n});
            input_impl_ = dataset()->input_->MakeIterator(prefix());
          }
          if (!end_of_input_sequence) {
            // Store the element in the ring buffer slot addressed by the
            // current slice's logical end position.
            buffer_[slices_.back()->end % dataset()->buffer_size_] =
                std::move(input_element);
            num_elements_++;
            slices_.back()->end++;
          } else {
            // All requested epochs are exhausted; release the input.
            input_impl_.reset();
          }
        }
        if (num_log_entries > 0) {
          LOG(INFO) << "Shuffle buffer filled.";
        }

        if (num_elements_ > 0) {
          *end_of_sequence = false;
          // Garbage collect all empty slices.
          while (!slices_.empty() &&
                 slices_.front()->start == slices_.front()->end) {
            slices_.pop_front();
          }
          DCHECK(!slices_.empty());
          // Choose an element to produce uniformly at random from the first
          // slice, and then remove the element from the slice.
          int64 offset =
              Random() % (slices_.front()->end - slices_.front()->start);
          int64 index =
              (slices_.front()->start + offset) % dataset()->buffer_size_;
          *out_tensors = std::move(buffer_[index]);
          // Swap the chosen slot with the slice's front slot so the slice
          // stays contiguous after advancing `start`.
          std::swap(buffer_[index],
                    buffer_[slices_.front()->start % dataset()->buffer_size_]);
          slices_.front()->start++;
          num_elements_--;
        } else {
          DCHECK(input_impl_ == nullptr);
          *end_of_sequence = true;
        }
        return Status::OK();
      }

     protected:
      // Serializes the RNG sample count, the input iterator (or an
      // end-of-input marker), the epoch/element counters, and the buffered
      // elements slice-by-slice so iteration can be resumed exactly.
      Status SaveInternal(IteratorStateWriter* writer) override {
        mutex_lock l(mu_);

        // Save state needed to restore the random number generators.
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_random_samples"),
                                               num_random_samples_));

        // Save input iterator if it hasn't been exhausted else write
        // "end_of_input_sequence".
        if (!input_impl_) {
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name("end_of_input_sequence"), ""));
        } else {
          TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
        }

        // Save the epoch counter, buffer, and buffer slices.
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("epoch"), epoch_));
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name("num_elements"), num_elements_));
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name("slices_size"), slices_.size()));
        for (size_t i = 0; i < slices_.size(); ++i) {
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              full_name(strings::StrCat("slices_start_", i)),
              slices_[i]->start));
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              full_name(strings::StrCat("slices_end_", i)), slices_[i]->end));
          // Buffered elements are keyed by their ring-buffer index.
          for (size_t j = slices_[i]->start; j < slices_[i]->end; ++j) {
            size_t index = j % dataset()->buffer_size_;
            TF_RETURN_IF_ERROR(writer->WriteScalar(
                full_name(strings::StrCat("buffer_", index, "_size")),
                buffer_[index].size()));
            for (size_t k = 0; k < buffer_[index].size(); ++k) {
              TF_RETURN_IF_ERROR(writer->WriteTensor(
                  full_name(strings::StrCat("buffer_", index, "_", k)),
                  buffer_[index][k]));
            }
          }
        }

        return Status::OK();
      }

      // Mirror of SaveInternal(): restores RNG position, input iterator,
      // counters, slices, and buffered elements.
      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);

        // Restore the random number generators.
        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_random_samples"),
                                              &num_random_samples_));
        ResetRngs();

        // Restore the input iterator if it wasn't already exhausted.
        if (!reader->Contains(full_name("end_of_input_sequence"))) {
          input_impl_ = dataset()->input_->MakeIterator(prefix());
          TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
        } else {
          input_impl_.reset();
        }

        // Restore the epoch counter, buffer, and buffer slices.
        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("epoch"), &epoch_));
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(full_name("num_elements"), &num_elements_));
        size_t slices_size;
        {
          int64 temp;
          TF_RETURN_IF_ERROR(
              reader->ReadScalar(full_name("slices_size"), &temp));
          slices_size = static_cast<size_t>(temp);
        }
        buffer_.reset(new std::vector<Tensor>[dataset()->buffer_size_]);
        for (size_t i = 0; i < slices_size; ++i) {
          int64 start;
          TF_RETURN_IF_ERROR(reader->ReadScalar(
              full_name(strings::StrCat("slices_start_", i)), &start));
          int64 end;
          TF_RETURN_IF_ERROR(reader->ReadScalar(
              full_name(strings::StrCat("slices_end_", i)), &end));
          slices_.emplace_back(new Slice{start, end});
          for (size_t j = start; j < end; ++j) {
            size_t index = j % dataset()->buffer_size_;
            int64 list_size;
            TF_RETURN_IF_ERROR(reader->ReadScalar(
                full_name(strings::StrCat("buffer_", index, "_size")),
                &list_size));
            buffer_[index] = std::vector<Tensor>(list_size);
            for (int k = 0; k < list_size; ++k) {
              TF_RETURN_IF_ERROR(reader->ReadTensor(
                  full_name(strings::StrCat("buffer_", index, "_", k)),
                  &buffer_[index][k]));
            }
          }
        }

        return Status::OK();
      }

     private:
      // Used to represent slices of `buffer_` that belong to different epochs.
      // The invariant maintained by the implementation is: `start` <= `end`.
      // When using `start` and `end` to index into `buffer_`, their values
      // should be taken modulo the size of `buffer_` as their absolute value
      // can be greater than the range of `buffer_`.
      struct Slice {
        Slice(int64 start, int64 end) : start(start), end(end) {}

        int64 start;
        int64 end;
      };

      // Draws one sample from `generator_`, counting samples so that
      // ResetRngs() can reproduce the generator position after a restore.
      random::SingleSampleAdapter<random::PhiloxRandom>::ResultType Random()
          EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        num_random_samples_++;
        auto out = generator_();
        return out;
      }

      void ResetRngs() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        // Reset the generators based on the current iterator seeds.
        parent_generator_ = random::PhiloxRandom(seed_, seed2_);
        generator_ = random::SingleSampleAdapter<random::PhiloxRandom>(
            &parent_generator_);
        // Fast-forward past the samples already consumed before the save.
        generator_.Skip(num_random_samples_);
      }

      mutex mu_;
      // Ring buffer of buffered input elements; length is buffer_size_.
      std::unique_ptr<std::vector<Tensor>[]> buffer_ GUARDED_BY(mu_);
      std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
      const int64 seed_ GUARDED_BY(mu_);
      const int64 seed2_ GUARDED_BY(mu_);
      int64 epoch_ GUARDED_BY(mu_);
      int64 num_elements_ GUARDED_BY(mu_);
      // Per-epoch slices of `buffer_`, oldest epoch at the front.
      std::deque<std::unique_ptr<Slice>> slices_ GUARDED_BY(mu_);
      random::PhiloxRandom parent_generator_ GUARDED_BY(mu_);
      random::SingleSampleAdapter<random::PhiloxRandom> generator_
          GUARDED_BY(mu_);
      // Number of samples drawn from `generator_`; serialized on save.
      int64 num_random_samples_ GUARDED_BY(mu_) = 0;
    };

    const DatasetBase* const input_;
    const int64 buffer_size_;
    const int64 count_;
  };
};
305 
306 class ShuffleDatasetOp : public ShuffleDatasetOpBase {
307  public:
ShuffleDatasetOp(OpKernelConstruction * ctx)308   explicit ShuffleDatasetOp(OpKernelConstruction* ctx)
309       : ShuffleDatasetOpBase(ctx) {
310     OP_REQUIRES_OK(ctx, ctx->GetAttr("reshuffle_each_iteration",
311                                      &reshuffle_each_iteration_));
312   }
313 
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)314   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
315                    DatasetBase** output) override {
316     int64 buffer_size;
317     OP_REQUIRES_OK(
318         ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size));
319     OP_REQUIRES(
320         ctx, buffer_size > 0,
321         errors::InvalidArgument("buffer_size must be greater than zero."));
322 
323     int64 seed;
324     OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "seed", &seed));
325 
326     int64 seed2;
327     OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "seed2", &seed2));
328 
329     // By TensorFlow convention, passing 0 for both seeds indicates
330     // that the shuffling should be seeded non-deterministically.
331     if (seed == 0 && seed2 == 0) {
332       seed = random::New64();
333       seed2 = random::New64();
334     }
335 
336     int64 count = 1;
337     if (reshuffle_each_iteration_) {
338       *output =
339           new ReshufflingDataset(ctx, input, buffer_size, seed, seed2, count);
340     } else {
341       *output =
342           new FixedSeedDataset(ctx, input, buffer_size, seed, seed2, count);
343     }
344   }
345 
346  private:
347   // A dataset that uses a pseduorandom sequence of seeds for the iterators
348   // created from it. Used when `reshuffle_each_iteration` is true.
349   class ReshufflingDataset : public ShuffleDatasetBase {
350    public:
ReshufflingDataset(OpKernelContext * ctx,const DatasetBase * input,int64 buffer_size,int64 seed,int64 seed2,int64 count)351     ReshufflingDataset(OpKernelContext* ctx, const DatasetBase* input,
352                        int64 buffer_size, int64 seed, int64 seed2, int64 count)
353         : ShuffleDatasetBase(ctx, input, buffer_size, count),
354           seed_(seed),
355           seed2_(seed2),
356           parent_generator_(seed, seed2),
357           generator_(&parent_generator_) {}
358 
DebugString()359     string DebugString() override {
360       return strings::StrCat("ShuffleDatasetOp(", buffer_size_, ", ", seed_,
361                              ", ", seed2_, ")::ReshufflingDataset");
362     }
363 
MakeIterator(const string & prefix) const364     std::unique_ptr<IteratorBase> MakeIterator(
365         const string& prefix) const override {
366       int64 iterator_seed;
367       int64 iterator_seed2;
368       {
369         mutex_lock l(mu_);
370         iterator_seed = generator_();
371         iterator_seed2 = generator_();
372       }
373       return std::unique_ptr<IteratorBase>(new ShuffleDatasetBase::Iterator(
374           {this, strings::StrCat(prefix, "::Shuffle")}, iterator_seed,
375           iterator_seed2));
376     }
377 
378    private:
379     const int64 seed_;
380     const int64 seed2_;
381     mutable mutex mu_;
382     mutable random::PhiloxRandom parent_generator_ GUARDED_BY(mu_);
383     mutable random::SingleSampleAdapter<random::PhiloxRandom> generator_
384         GUARDED_BY(mu_);
385   };
386 
387   // A dataset that uses the same fixed seed for all iterators created from it.
388   // Used when `reshuffle_each_iteration` is false.
389   class FixedSeedDataset : public ShuffleDatasetBase {
390    public:
FixedSeedDataset(OpKernelContext * ctx,const DatasetBase * input,int64 buffer_size,int64 seed,int64 seed2,int64 count)391     FixedSeedDataset(OpKernelContext* ctx, const DatasetBase* input,
392                      int64 buffer_size, int64 seed, int64 seed2, int64 count)
393         : ShuffleDatasetBase(ctx, input, buffer_size, count),
394           seed_(seed),
395           seed2_(seed) {}
396 
DebugString()397     string DebugString() override {
398       return strings::StrCat("ShuffleDatasetOp(", buffer_size_, ", ", seed_,
399                              ", ", seed2_, ")::FixedSeedDataset");
400     }
401 
MakeIterator(const string & prefix) const402     std::unique_ptr<IteratorBase> MakeIterator(
403         const string& prefix) const override {
404       return std::unique_ptr<IteratorBase>(new ShuffleDatasetBase::Iterator(
405           {this, strings::StrCat(prefix, "::Shuffle")}, seed_, seed2_));
406     }
407 
408    protected:
AsGraphDefInternal(OpKernelContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const409     Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
410                               Node** output) const override {
411       Node* input_graph_node = nullptr;
412       TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
413       Node* buffer_size = nullptr;
414       Node* seed = nullptr;
415       Node* seed2 = nullptr;
416       AttrValue reshuffle_each_iteration;
417 
418       TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
419       TF_RETURN_IF_ERROR(b->AddScalar(seed_, &seed));
420       TF_RETURN_IF_ERROR(b->AddScalar(seed2_, &seed2));
421       b->BuildAttrValue(false, &reshuffle_each_iteration);
422       TF_RETURN_IF_ERROR(b->AddDataset(
423           this, {input_graph_node, buffer_size, seed, seed2},  // Inputs
424           {std::make_pair("reshuffle_each_iteration",
425                           reshuffle_each_iteration)},  // Attrs
426           output));
427       return Status::OK();
428     }
429 
430    private:
431     const int64 seed_;
432     const int64 seed2_;
433   };
434 
435   bool reshuffle_each_iteration_;
436 };
437 
438 class ShuffleAndRepeatDatasetOp : public ShuffleDatasetOpBase {
439  public:
ShuffleAndRepeatDatasetOp(OpKernelConstruction * ctx)440   explicit ShuffleAndRepeatDatasetOp(OpKernelConstruction* ctx)
441       : ShuffleDatasetOpBase(ctx) {}
442 
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)443   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
444                    DatasetBase** output) override {
445     int64 buffer_size;
446     OP_REQUIRES_OK(
447         ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size));
448     OP_REQUIRES(
449         ctx, buffer_size > 0,
450         errors::InvalidArgument("buffer_size must be greater than zero."));
451 
452     int64 seed;
453     OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "seed", &seed));
454 
455     int64 seed2;
456     OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "seed2", &seed2));
457 
458     int64 count;
459     OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count));
460 
461     // By TensorFlow convention, if both seeds are 0, then shuffling should be
462     // seeded non-deterministically.
463     if (seed == 0 && seed2 == 0) {
464       seed = random::New64();
465       seed2 = random::New64();
466     }
467 
468     *output = new Dataset(ctx, input, buffer_size, seed, seed2, count);
469   }
470 
471  private:
472   class Dataset : public ShuffleDatasetBase {
473    public:
Dataset(OpKernelContext * ctx,const DatasetBase * input,int64 buffer_size,int64 seed,int64 seed2,int64 count)474     Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size,
475             int64 seed, int64 seed2, int64 count)
476         : ShuffleDatasetBase(ctx, input, buffer_size, count),
477           seed_(seed),
478           seed2_(seed2) {}
479 
DebugString()480     string DebugString() override {
481       return strings::StrCat("ShuffleAndRepeatDatasetOp(", buffer_size_, ", ",
482                              seed_, ", ", seed2_, ", ", count_, ")::Dataset");
483     }
484 
MakeIterator(const string & prefix) const485     std::unique_ptr<IteratorBase> MakeIterator(
486         const string& prefix) const override {
487       return std::unique_ptr<IteratorBase>(new ShuffleDatasetBase::Iterator(
488           {this, strings::StrCat(prefix, "::ShuffleAndRepeat")}, seed_,
489           seed2_));
490     }
491 
492    protected:
AsGraphDefInternal(OpKernelContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const493     Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
494                               Node** output) const override {
495       Node* input_graph_node = nullptr;
496       TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
497       Node* buffer_size = nullptr;
498       Node* seed = nullptr;
499       Node* seed2 = nullptr;
500       Node* count = nullptr;
501 
502       TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
503       TF_RETURN_IF_ERROR(b->AddScalar(seed_, &seed));
504       TF_RETURN_IF_ERROR(b->AddScalar(seed2_, &seed2));
505       TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));
506       TF_RETURN_IF_ERROR(b->AddDataset(
507           this, {input_graph_node, buffer_size, seed, seed2, count},  // Inputs
508           {},                                                         // Attrs
509           output));
510       return Status::OK();
511     }
512 
513    private:
514     const int64 seed_;
515     const int64 seed2_;
516   };
517 };
518 
// Register both shuffle kernels for CPU execution.
REGISTER_KERNEL_BUILDER(Name("ShuffleDataset").Device(DEVICE_CPU),
                        ShuffleDatasetOp);

REGISTER_KERNEL_BUILDER(Name("ShuffleAndRepeatDataset").Device(DEVICE_CPU),
                        ShuffleAndRepeatDatasetOp);
524 
525 }  // namespace
526 
527 }  // namespace tensorflow
528