1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/shuffle_dataset_op.h"
16
17 #include <deque>
18 #include <tuple>
19 #include <vector>
20
21 #include "tensorflow/core/framework/dataset.h"
22 #include "tensorflow/core/framework/partial_tensor_shape.h"
23 #include "tensorflow/core/framework/resource_mgr.h"
24 #include "tensorflow/core/framework/tensor.h"
25 #include "tensorflow/core/kernels/data/dataset_utils.h"
26 #include "tensorflow/core/kernels/data/name_utils.h"
27 #include "tensorflow/core/kernels/data/random_seed_ops.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/lib/random/philox_random.h"
30 #include "tensorflow/core/lib/random/random.h"
31 #include "tensorflow/core/lib/random/random_distributions.h"
32 #include "tensorflow/core/platform/errors.h"
33 #include "tensorflow/core/platform/stringprintf.h"
34
35 namespace tensorflow {
36 namespace data {
37
38 // See documentation in ../../ops/dataset_ops.cc for a high-level
39 // description of the following op.
40
// Out-of-class definitions (no initializers) for the `static constexpr` string
// constants of the shuffle op kernels; the initializers live with the in-class
// declarations, which are not part of this file.
// NOTE(review): presumably required so the constants can be ODR-used when the
// file is built as pre-C++17 -- confirm against the build configuration.

// Input / attribute names shared by all shuffle ops.
/* static */ constexpr const char* const ShuffleDatasetOpBase::kInputDataset;
/* static */ constexpr const char* const ShuffleDatasetOpBase::kBufferSize;
/* static */ constexpr const char* const ShuffleDatasetOpBase::kSeed;
/* static */ constexpr const char* const ShuffleDatasetOpBase::kSeed2;
/* static */ constexpr const char* const ShuffleDatasetOpBase::kOutputTypes;
/* static */ constexpr const char* const ShuffleDatasetOpBase::kOutputShapes;
/* static */ constexpr const char* const
    ShuffleDatasetOpBase::kReshuffleEachIteration;

// Dataset type name reported by the shuffle datasets via `op_type()`.
/* static */ constexpr const char* const ShuffleDatasetOp::kDatasetType;

// Dataset type name and `count` input name of the fused shuffle-and-repeat op.
/* static */ constexpr const char* const
    ShuffleAndRepeatDatasetOp::kDatasetType;
/* static */ constexpr const char* const ShuffleAndRepeatDatasetOp::kCount;
55
// Minimum time between two successive "Filling up shuffle buffer" log
// messages emitted while the buffer is being filled.
const int64 kLogIntervalMicros = 10 * 1000000;  // 10 seconds.
// Once the buffered elements are spread over more than this many epoch slices,
// the iterator stops filling the buffer further (see `GetNextInternal`).
const int64 kMaxEpochsInBuffer = 3;

// Keys under which the iterator state is written to / read from a checkpoint.
constexpr char kNumRandomSamples[] = "num_random_samples";
constexpr char kDataProduced[] = "data_produced";
constexpr char kEndOfInputSequence[] = "end_of_input_sequence";
constexpr char kEpoch[] = "epoch";
constexpr char kNumElements[] = "num_elements";
constexpr char kSlicesSize[] = "slices_size";
constexpr char kSlicesStart[] = "slices_start";
constexpr char kSlicesEnd[] = "slices_end";
// NOTE(review): `kBuffer`, `kSize` and `kTFData` are not referenced in the
// visible part of this file -- confirm whether they are still needed.
constexpr char kBuffer[] = "buffer";
constexpr char kSize[] = "size";
// Component of the name given to seed generator resources created by the op.
constexpr char kSeedGenerator[] = "SeedGenerator";
constexpr char kTFData[] = "tf_data";
// Checkpoint key for the sample count of the shared seed generator.
constexpr char kEpochNumRandomSamples[] = "epoch_num_random_samples";
// Op names; the shuffle op kernel compares its own op name against the
// `kShuffleDataset*` values to determine which version it implements.
constexpr char kShuffleDatasetV1[] = "ShuffleDataset";
constexpr char kShuffleDatasetV2[] = "ShuffleDatasetV2";
constexpr char kShuffleDatasetV3[] = "ShuffleDatasetV3";
constexpr char kShuffleAndRepeatDatasetV1[] = "ShuffleAndRepeatDataset";
constexpr char kShuffleAndRepeatDatasetV2[] = "ShuffleAndRepeatDatasetV2";
77
// Constructs the common base of the shuffle op kernels. All work is delegated
// to the `UnaryDatasetOpKernel` constructor; nothing else is done here.
ShuffleDatasetOpBase::ShuffleDatasetOpBase(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx) {}
80
81 // Abstract base dataset that implements a shuffling iterator.
82 class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
83 public:
// Creates the state shared by all shuffle dataset variants.
//
// Args:
//   ctx: kernel context from which the `DatasetContext` is built.
//   input: upstream dataset. A reference is taken here and dropped again in
//     the destructor.
//   buffer_size: number of slots in each iterator's shuffle buffer.
//   seed_generator: source of the per-epoch seeds handed to iterators.
//   count: number of epochs to produce; -1 is treated as "repeat forever".
ShuffleDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
                   int64 buffer_size,
                   std::shared_ptr<SeedGenerator> seed_generator, int64 count)
    : DatasetBase(DatasetContext(ctx)),
      input_(input),
      buffer_size_(buffer_size),
      seed_generator_(std::move(seed_generator)),
      count_(count),
      // The buffer size is exposed (in decimal) as profiler trace metadata.
      traceme_metadata_({{"buffer_size", strings::StrCat(buffer_size)}}) {
  input_->Ref();
}
97
// Drops the reference on the input dataset that the constructor acquired.
~ShuffleDatasetBase() override { input_->Unref(); }
99
// Returns the type name of the concrete dataset. It is used to build both the
// debug string and the iterator prefix.
virtual string op_type() const = 0;
101
// Element types are forwarded unchanged from the input dataset.
const DataTypeVector& output_dtypes() const override {
  return input_->output_dtypes();
}
105
// Element shapes are forwarded unchanged from the input dataset.
const std::vector<PartialTensorShape>& output_shapes() const override {
  return input_->output_shapes();
}
109
Cardinality() const110 int64 Cardinality() const override {
111 if (count_ == -1 || input_->Cardinality() == kInfiniteCardinality) {
112 return kInfiniteCardinality;
113 } else if (input_->Cardinality() == kUnknownCardinality) {
114 return kUnknownCardinality;
115 } else {
116 return input_->Cardinality() * count_;
117 }
118 }
119
// Appends the single upstream dataset to `inputs`; always returns OK.
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
  inputs->push_back(input_);
  return Status::OK();
}
124
// Delegates the external-state check to the input dataset and returns its
// result unchanged.
Status CheckExternalState() const override {
  return input_->CheckExternalState();
}
128
// Returns a human-readable description built from the op type and the
// arguments (buffer size, both generator seeds, epoch count).
string DebugString() const override {
  name_utils::DatasetDebugStringParams params;
  params.set_args(buffer_size_, seed_generator_->seed(),
                  seed_generator_->seed2(), count_);
  return name_utils::DatasetDebugString(op_type(), params);
}
135
MakeIteratorInternal(const string & prefix) const136 std::unique_ptr<IteratorBase> MakeIteratorInternal(
137 const string& prefix) const override {
138 return absl::make_unique<Iterator>(
139 Iterator::Params{this, name_utils::IteratorPrefix(op_type(), prefix)},
140 seed_generator_.get());
141 }
142
143 protected:
144 class Iterator : public DatasetIterator<ShuffleDatasetBase> {
145 public:
// Creates an iterator with a buffer of `buffer_size_` empty slots and a single
// empty slice [0, 0) for the first epoch.
//
// `seed_generator` is borrowed, not owned. The random number generators are
// provisionally seeded from the generator's base seeds here and re-seeded in
// `Initialize()`.
explicit Iterator(const Params& params, SeedGenerator* seed_generator)
    : DatasetIterator<ShuffleDatasetBase>(params),
      seed_generator_(seed_generator),
      parent_generator_(seed_generator->seed(), seed_generator->seed2()),
      generator_(&parent_generator_) {
  buffer_ = absl::make_unique<std::vector<std::vector<Tensor>>>(
      params.dataset->buffer_size_);
  slices_.push_back(absl::make_unique<Slice>(0, 0));
}
155
// Draws the seeds for the first epoch from the seed generator and resets the
// random number generators accordingly. `ctx` is unused. Always returns OK.
Status Initialize(IteratorContext* ctx) override {
  mutex_lock l(mu_);
  seed_generator_->GenerateSeeds(&seed_, &seed2_);
  ResetRngs();
  return Status::OK();
}
162
// Produces the next shuffled element.
//
// The call works in two phases:
//   1. Fill: pull elements from the input into the circular `buffer_` until it
//      holds `buffer_size_` elements, the input is exhausted for the last
//      epoch, or the buffered elements span too many epochs.
//   2. Sample: pick one element uniformly at random from the oldest non-empty
//      epoch slice, hand it to the caller and compact that slice.
//
// Args:
//   ctx: iterator context, forwarded to the input iterator.
//   out_tensors: receives the produced element when `*end_of_sequence` is
//     false.
//   end_of_sequence: set to true when no element could be produced.
//
// Returns the first error reported by the input pipeline, OK otherwise.
Status GetNextInternal(IteratorContext* ctx,
                       std::vector<Tensor>* out_tensors,
                       bool* end_of_sequence) override {
  mutex_lock l(mu_);
  int64 start_micros = EnvTime::NowMicros();
  int64 num_log_entries = 0;
  // Lazily create the input iterator on the very first call. After the last
  // epoch `input_impl_` is reset while `epoch_` is non-zero, so the input is
  // not reopened.
  if (!input_impl_ && epoch_ == 0) {
    TF_RETURN_IF_ERROR(this->dataset()->input_->MakeIterator(
        ctx, this, this->prefix(), &input_impl_));
  }
  // Phase 1: fill the buffer.
  while (input_impl_ && num_elements_ < this->dataset()->buffer_size_) {
    // Emit a progress message at most once per `kLogIntervalMicros`.
    if (EnvTime::NowMicros() >
        ((num_log_entries + 1) * kLogIntervalMicros) + start_micros) {
      num_log_entries++;
      LOG(INFO) << "Filling up shuffle buffer (this may take a while): "
                << num_elements_ << " of " << this->dataset()->buffer_size_;
    }
    std::vector<Tensor> input_element;
    bool end_of_input_sequence = false;
    // Fetch one element, restarting the input for a new epoch whenever the
    // current one is exhausted and more epochs remain.
    while (this->dataset()->count_ == -1 ||
           epoch_ < this->dataset()->count_) {
      TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &input_element,
                                              &end_of_input_sequence));
      if (!end_of_input_sequence) {
        data_produced_ = true;
        break;
      }
      if (ctx->split_provider() == nullptr && !data_produced_ &&
          this->dataset()->count_ == -1) {
        // If we encounter the end of sequence without producing data, we
        // terminate the iteration immediately. (Otherwise, this iterator
        // would loop infinitely and never produce a value.)
        *end_of_sequence = true;
        return Status::OK();
      }
      // The epoch is over: open a new, initially empty slice that starts
      // where the previous one ended, then restart the input.
      epoch_++;
      int64 n = slices_.back()->end;
      slices_.push_back(absl::make_unique<Slice>(n, n));
      if (ctx->split_provider()) {
        TF_RETURN_IF_ERROR(ctx->split_provider()->Reset());
      }
      TF_RETURN_IF_ERROR(this->dataset()->input_->MakeIterator(
          ctx, this, this->prefix(), &input_impl_));
    }
    if (!end_of_input_sequence) {
      if (num_elements_ == 0) {
        VLOG(1) << "Starting to fill up shuffle buffer of size: "
                << this->dataset()->buffer_size_;
      }
      this->RecordBufferEnqueue(ctx, input_element);
      // Slice bounds grow without wrapping; reduce modulo the buffer size to
      // obtain the actual slot.
      buffer_->at(slices_.back()->end % this->dataset()->buffer_size_) =
          std::move(input_element);
      num_elements_++;
      slices_.back()->end++;
    } else {
      // All epochs have been consumed; dropping the input iterator ends the
      // fill loop for this and all later calls.
      input_impl_.reset();
    }
    if (slices_.size() > kMaxEpochsInBuffer) {
      // When the elements stored in `buffer_` span more than
      // `kMaxEpochsInBuffer` epochs, we do not fill the buffer further to
      // conserve memory. This means that the upper bound on the size of
      // `buffer_` is `kMaxEpochsInBuffer * cardinality(input_dataset) +
      // 1`.
      break;
    }
  }
  if (num_log_entries > 0) {
    LOG(INFO) << "Shuffle buffer filled.";
  }

  // Phase 2: sample one element from the oldest buffered epoch.
  if (num_elements_ > 0) {
    *end_of_sequence = false;
    // Garbage collect all empty slices.
    while (!slices_.empty() &&
           slices_.front()->start == slices_.front()->end) {
      slices_.pop_front();
      // Reinitialize the RNG state for the next epoch.
      num_random_samples_ = 0;
      seed_generator_->GenerateSeeds(&seed_, &seed2_);
      ResetRngs();
    }
    DCHECK(!slices_.empty());
    // Choose an element to produce uniformly at random from the first
    // slice, and then remove the element from the slice.
    int64 offset =
        Random() % (slices_.front()->end - slices_.front()->start);
    int64 index =
        (slices_.front()->start + offset) % this->dataset()->buffer_size_;
    *out_tensors = std::move(buffer_->at(index));
    this->RecordBufferDequeue(ctx, *out_tensors);
    // Move the slice's first element into the slot that was just vacated so
    // that the remaining elements stay contiguous once `start` advances.
    std::swap(buffer_->at(index),
              buffer_->at(slices_.front()->start %
                          this->dataset()->buffer_size_));
    slices_.front()->start++;
    num_elements_--;
  } else {
    // Nothing is buffered, which is only expected once the input is done.
    DCHECK(input_impl_ == nullptr);
    *end_of_sequence = true;
  }
  return Status::OK();
}
264
265 protected:
// Creates the performance-model node for this iterator: a known-ratio node
// with ratio 1. `ctx` is unused.
// NOTE(review): ratio 1 presumably means one input element is consumed per
// output element -- confirm against the `model::MakeKnownRatioNode` contract.
std::shared_ptr<model::Node> CreateNode(
    IteratorContext* ctx, model::Node::Args args) const override {
  return model::MakeKnownRatioNode(std::move(args),
                                   /*ratio=*/1);
}
271
// Rebuilds both random number generators from `seed_` / `seed2_` and then
// skips `num_random_samples_` samples, so that a generator restored from a
// checkpoint resumes at the position at which it was saved. Requires `mu_`.
void ResetRngs() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  // Reset the generators based on the current iterator seeds.
  parent_generator_ = random::PhiloxRandom(seed_, seed2_);
  generator_ =
      random::SingleSampleAdapter<random::PhiloxRandom>(&parent_generator_);
  generator_.Skip(num_random_samples_);
}
279
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)280 Status SaveInternal(SerializationContext* ctx,
281 IteratorStateWriter* writer) override {
282 mutex_lock l(mu_);
283 // Save state needed to restore the random number generators.
284 TF_RETURN_IF_ERROR(
285 writer->WriteScalar(full_name(kEpochNumRandomSamples),
286 seed_generator_->num_random_samples()));
287 TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name(kNumRandomSamples),
288 num_random_samples_));
289 TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name(kSeed), seed_));
290 TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name(kSeed2), seed2_));
291
292 // Save input iterator if it hasn't been exhausted else write
293 // "end_of_input_sequence".
294 if (!input_impl_) {
295 TF_RETURN_IF_ERROR(
296 writer->WriteScalar(this->full_name(kEndOfInputSequence), ""));
297 } else {
298 TF_RETURN_IF_ERROR(this->SaveInput(ctx, writer, input_impl_));
299 }
300
301 // Save the epoch counter, buffer, and buffer slices.
302 TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name(kEpoch), epoch_));
303 TF_RETURN_IF_ERROR(
304 writer->WriteScalar(this->full_name(kNumElements), num_elements_));
305 TF_RETURN_IF_ERROR(WriteElementsToCheckpoint(writer, prefix(), *buffer_));
306 TF_RETURN_IF_ERROR(
307 writer->WriteScalar(this->full_name(kSlicesSize), slices_.size()));
308 for (size_t i = 0; i < slices_.size(); ++i) {
309 TF_RETURN_IF_ERROR(
310 writer->WriteScalar(this->full_name(absl::StrJoin(
311 std::make_tuple(kSlicesStart, i), "_")),
312 slices_[i]->start));
313 TF_RETURN_IF_ERROR(writer->WriteScalar(
314 this->full_name(absl::StrJoin(std::make_tuple(kSlicesEnd, i), "_")),
315 slices_[i]->end));
316 }
317 if (data_produced_) {
318 TF_RETURN_IF_ERROR(
319 writer->WriteScalar(this->full_name(kDataProduced), ""));
320 }
321
322 return Status::OK();
323 }
324
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)325 Status RestoreInternal(IteratorContext* ctx,
326 IteratorStateReader* reader) override {
327 mutex_lock l(mu_);
328 // Restore the random number generators.
329 int64 num_random_samples;
330 TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kEpochNumRandomSamples),
331 &num_random_samples));
332 seed_generator_->set_num_random_samples(num_random_samples);
333 seed_generator_->Reset();
334 TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name(kNumRandomSamples),
335 &num_random_samples_));
336 TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name(kSeed), &seed_));
337 TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name(kSeed2), &seed2_));
338 ResetRngs();
339
340 // Restore the input iterator if it wasn't already exhausted.
341 if (!reader->Contains(this->full_name(kEndOfInputSequence))) {
342 TF_RETURN_IF_ERROR(this->dataset()->input_->MakeIterator(
343 ctx, this, this->prefix(), &input_impl_));
344 TF_RETURN_IF_ERROR(this->RestoreInput(ctx, reader, input_impl_));
345 } else {
346 input_impl_.reset();
347 }
348
349 // Restore the epoch counter, buffer, and buffer slices.
350 TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name(kEpoch), &epoch_));
351 TF_RETURN_IF_ERROR(
352 reader->ReadScalar(this->full_name(kNumElements), &num_elements_));
353 size_t slices_size;
354 {
355 int64 temp;
356 TF_RETURN_IF_ERROR(
357 reader->ReadScalar(this->full_name(kSlicesSize), &temp));
358 slices_size = static_cast<size_t>(temp);
359 }
360 buffer_ = absl::make_unique<std::vector<std::vector<Tensor>>>(
361 this->dataset()->buffer_size_);
362 TF_RETURN_IF_ERROR(
363 ReadElementsFromCheckpoint(reader, prefix(), buffer_.get()));
364 slices_.clear();
365 for (size_t i = 0; i < slices_size; ++i) {
366 int64 start;
367 TF_RETURN_IF_ERROR(
368 reader->ReadScalar(this->full_name(absl::StrJoin(
369 std::make_tuple(kSlicesStart, i), "_")),
370 &start));
371 int64 end;
372 TF_RETURN_IF_ERROR(reader->ReadScalar(
373 this->full_name(absl::StrJoin(std::make_tuple(kSlicesEnd, i), "_")),
374 &end));
375 slices_.push_back(absl::make_unique<Slice>(start, end));
376 }
377 data_produced_ = reader->Contains(this->full_name(kDataProduced));
378
379 return Status::OK();
380 }
381
// Returns the trace metadata computed once by the dataset (the buffer size).
TraceMeMetadata GetTraceMeMetadata() const override {
  return this->dataset()->traceme_metadata_;
}
385
386 private:
// Used to represent slices of `buffer_` that belong to different epochs.
// The invariant maintained by the implementation is: `start` <= `end`.
// When using `start` and `end` to index into `buffer_`, their values
// should be taken modulo the size of `buffer_` as their absolute value
// can be greater than the range of `buffer_`.
struct Slice {
  Slice(int64 start, int64 end) : start(start), end(end) {}

  // Position of the first element still buffered for this epoch; advanced
  // each time an element of the slice is produced.
  int64 start;
  // One past the position of the last buffered element of this epoch;
  // advanced each time an element is added to the slice.
  int64 end;
};
398
Random()399 random::SingleSampleAdapter<random::PhiloxRandom>::ResultType Random()
400 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
401 num_random_samples_++;
402 auto out = generator_();
403 return out;
404 }
405
// Guards all mutable iterator state below.
mutex mu_;
// Source of the per-epoch seeds.
SeedGenerator* const seed_generator_ TF_GUARDED_BY(mu_);  // Not owned.
// Circular storage with `buffer_size_` slots for the buffered elements; slots
// are addressed by slice positions taken modulo the buffer size.
std::unique_ptr<std::vector<std::vector<Tensor>>> buffer_
    TF_GUARDED_BY(mu_);
// Iterator over the input dataset. Null before the first `GetNext` call and
// again once the input has been exhausted for the final epoch.
std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_) = nullptr;
// Number of times the end of the input has been reached and counted.
int64 epoch_ TF_GUARDED_BY(mu_) = 0;
// Number of elements currently held in `buffer_`.
int64 num_elements_ TF_GUARDED_BY(mu_) = 0;
// Seeds the random number generators were last reset with.
int64 seed_ TF_GUARDED_BY(mu_) = 0;
int64 seed2_ TF_GUARDED_BY(mu_) = 0;
// Indices into `buffer_` indicating which data belongs to which epoch.
// The slice at the front of the deque references data from the earliest
// buffered epoch. It is an invariant that all slices reference
// non-overlapping sections of `buffer_`.
std::deque<std::unique_ptr<Slice>> slices_ TF_GUARDED_BY(mu_);
// Underlying generator and the single-sample adapter reading from it.
random::PhiloxRandom parent_generator_ TF_GUARDED_BY(mu_);
random::SingleSampleAdapter<random::PhiloxRandom> generator_
    TF_GUARDED_BY(mu_);
// Number of samples drawn from `generator_` since its seeds last changed;
// checkpointed so the generator position can be restored.
int64 num_random_samples_ TF_GUARDED_BY(mu_) = 0;
// Whether the input has ever produced an element. Used to stop an infinitely
// repeating shuffle over an empty input.
bool data_produced_ TF_GUARDED_BY(mu_) = false;
425 };
426
// Upstream dataset; a reference is held for the lifetime of this dataset.
const DatasetBase* const input_;
// Number of slots in each iterator's shuffle buffer.
const int64 buffer_size_;
// Seed generator shared by all iterators created from this dataset.
const std::shared_ptr<SeedGenerator> seed_generator_;
// The number of epochs to run for. Normally this is just 1, but sometimes we
// fuse shuffle and repeat together, and make the shuffle dataset op
// responsible for repeating as well.
const int64 count_;
// Trace metadata (the buffer size) reported by the iterators.
const TraceMeMetadata traceme_metadata_;
435 }; // ShuffleDatasetBase
436
437 // This version of memory dataset has an exclusive ownership of the seed
438 // generator resource. It supports sharing of the seed generator across
439 // different iterations of the `repeat` transformation but not across different
440 // iterators.
441 class ShuffleDatasetOp::Dataset : public ShuffleDatasetBase {
442 public:
Dataset(OpKernelContext * ctx,const DatasetBase * input,int64 buffer_size,int64 count,RandomSeeds && seeds,SeedGeneratorManager * manager,ResourceHandle && resource_handle)443 Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size,
444 int64 count, RandomSeeds&& seeds, SeedGeneratorManager* manager,
445 ResourceHandle&& resource_handle)
446 : ShuffleDatasetBase(ctx, input, buffer_size, manager->get(), count),
447 manager_(manager),
448 resource_handle_(std::move(resource_handle)),
449 resource_mgr_(ctx->resource_manager()),
450 seeds_(std::move(seeds)) {}
451
~Dataset()452 ~Dataset() override {
453 manager_->Unref();
454 Status s = resource_mgr_->Delete<SeedGeneratorManager>(
455 resource_handle_.container(), resource_handle_.name());
456 if (!s.ok()) {
457 LOG(WARNING) << "Failed to delete RNG resource: " << s.ToString();
458 }
459 }
460
op_type() const461 string op_type() const override { return kDatasetType; }
462
463 protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const464 Status AsGraphDefInternal(SerializationContext* ctx,
465 DatasetGraphDefBuilder* b,
466 Node** output) const override {
467 Node* input_graph_node = nullptr;
468 TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
469 Node* buffer_size_node = nullptr;
470 Node* seed_node = nullptr;
471 Node* seed2_node = nullptr;
472 AttrValue reshuffle_each_iteration;
473
474 TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size_node));
475 TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed(), &seed_node));
476 TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed2(), &seed2_node));
477 b->BuildAttrValue(seed_generator_->reshuffle_each_iteration(),
478 &reshuffle_each_iteration);
479 TF_RETURN_IF_ERROR(b->AddDataset(
480 this,
481 {input_graph_node, buffer_size_node, seed_node, seed2_node}, // Inputs
482 {std::make_pair(kReshuffleEachIteration,
483 reshuffle_each_iteration)}, // Attrs
484 output));
485 return Status::OK();
486 }
487
488 private:
489 SeedGeneratorManager* const manager_; // Owned.
490 const ResourceHandle resource_handle_;
491 ResourceMgr* const resource_mgr_; // Not owned.
492 const RandomSeeds seeds_;
493 };
494
495 // This version of shuffle dataset has a shared ownership of the seed generator
496 // resource. It supports sharing of the generator state across different
497 // iterations of the `repeat` transformation and also across different
498 // iterators.
499 class ShuffleDatasetOp::DatasetV2 : public ShuffleDatasetBase {
500 public:
DatasetV2(OpKernelContext * ctx,const DatasetBase * input,int64 buffer_size,int64 count,SeedGeneratorManager * manager,ResourceHandle && resource_handle,bool owns_resource)501 DatasetV2(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size,
502 int64 count, SeedGeneratorManager* manager,
503 ResourceHandle&& resource_handle, bool owns_resource)
504 : ShuffleDatasetBase(ctx, input, buffer_size, manager->get(), count),
505 manager_(manager),
506 owns_resource_(owns_resource),
507 resource_handle_(std::move(resource_handle)),
508 resource_mgr_(ctx->resource_manager()) {}
509
~DatasetV2()510 ~DatasetV2() override {
511 manager_->Unref();
512 if (owns_resource_) {
513 Status s = resource_mgr_->Delete<SeedGeneratorManager>(
514 resource_handle_.container(), resource_handle_.name());
515 if (!s.ok()) {
516 LOG(WARNING) << "Failed to delete RNG resource: " << s.ToString();
517 }
518 }
519 }
520
op_type() const521 string op_type() const override { return kDatasetType; }
522
523 protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const524 Status AsGraphDefInternal(SerializationContext* ctx,
525 DatasetGraphDefBuilder* b,
526 Node** output) const override {
527 Node* input_graph_node = nullptr;
528 TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
529 Node* buffer_size_node = nullptr;
530 TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size_node));
531 Node* resource_handle_node = nullptr;
532 Tensor handle(DT_RESOURCE, TensorShape({}));
533 handle.scalar<ResourceHandle>()() = resource_handle_;
534 TF_RETURN_IF_ERROR(b->AddTensor(handle, &resource_handle_node));
535 TF_RETURN_IF_ERROR(b->AddDataset(
536 this,
537 {input_graph_node, buffer_size_node, resource_handle_node}, // Inputs
538 {}, // Attrs
539 output));
540 return Status::OK();
541 }
542
543 private:
544 SeedGeneratorManager* const manager_; // Owned.
545 const bool owns_resource_;
546 const ResourceHandle resource_handle_;
547 ResourceMgr* const resource_mgr_; // Not owned.
548 };
549
550 // This version of shuffle dataset extends the functionality of DatasetV2 with
551 // the ability to preserve seed generator configuration (i.e. initial seeds and
552 // whether to reshuffle each iteration) across serialization of the dataset.
553 class ShuffleDatasetOp::DatasetV3 : public ShuffleDatasetBase {
554 public:
DatasetV3(OpKernelContext * ctx,const DatasetBase * input,int64 buffer_size,int64 count,RandomSeeds && seeds,SeedGeneratorManager * manager,ResourceHandle && resource_handle,bool owns_resource)555 DatasetV3(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size,
556 int64 count, RandomSeeds&& seeds, SeedGeneratorManager* manager,
557 ResourceHandle&& resource_handle, bool owns_resource)
558 : ShuffleDatasetBase(ctx, input, buffer_size, manager->get(), count),
559 manager_(manager),
560 owns_resource_(owns_resource),
561 resource_handle_(std::move(resource_handle)),
562 resource_mgr_(ctx->resource_manager()),
563 seeds_(std::move(seeds)) {}
564
~DatasetV3()565 ~DatasetV3() override {
566 manager_->Unref();
567 if (owns_resource_) {
568 Status s = resource_mgr_->Delete<SeedGeneratorManager>(
569 resource_handle_.container(), resource_handle_.name());
570 if (!s.ok()) {
571 LOG(WARNING) << "Failed to delete RNG resource: " << s.ToString();
572 }
573 }
574 }
575
op_type() const576 string op_type() const override { return kDatasetType; }
577
578 protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const579 Status AsGraphDefInternal(SerializationContext* ctx,
580 DatasetGraphDefBuilder* b,
581 Node** output) const override {
582 Node* input_graph_node = nullptr;
583 TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
584 Node* buffer_size_node = nullptr;
585 Node* seed_node = nullptr;
586 Node* seed2_node = nullptr;
587 TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size_node));
588 TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed(), &seed_node));
589 TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed2(), &seed2_node));
590 Node* resource_handle_node = nullptr;
591 Tensor handle(DT_RESOURCE, TensorShape({}));
592 handle.scalar<ResourceHandle>()() = resource_handle_;
593 TF_RETURN_IF_ERROR(b->AddTensor(handle, &resource_handle_node));
594 AttrValue reshuffle_each_iteration;
595 b->BuildAttrValue(seed_generator_->reshuffle_each_iteration(),
596 &reshuffle_each_iteration);
597 TF_RETURN_IF_ERROR(
598 b->AddDataset(this,
599 {input_graph_node, buffer_size_node, seed_node,
600 seed2_node, resource_handle_node}, // Inputs
601 {std::make_pair(kReshuffleEachIteration,
602 reshuffle_each_iteration)}, // Attrs
603 output));
604 return Status::OK();
605 }
606
607 private:
608 SeedGeneratorManager* const manager_; // Owned
609 const bool owns_resource_;
610 const ResourceHandle resource_handle_;
611 ResourceMgr* const resource_mgr_; // Not owned.
612 const RandomSeeds seeds_;
613 };
614
ShuffleDatasetOp(OpKernelConstruction * ctx)615 ShuffleDatasetOp::ShuffleDatasetOp(OpKernelConstruction* ctx)
616 : ShuffleDatasetOpBase(ctx) {
617 auto& op_name = ctx->def().op();
618 if (op_name == kShuffleDatasetV3) {
619 op_version_ = 3;
620 } else if (op_name == kShuffleDatasetV2) {
621 op_version_ = 2;
622 } else if (op_name == kShuffleDatasetV1) {
623 op_version_ = 1;
624 }
625 if (ctx->HasAttr(kReshuffleEachIteration)) {
626 OP_REQUIRES_OK(
627 ctx, ctx->GetAttr(kReshuffleEachIteration, &reshuffle_each_iteration_));
628 }
629 }
630
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)631 void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
632 DatasetBase** output) {
633 int64 buffer_size = 0;
634 OP_REQUIRES_OK(ctx,
635 ParseScalarArgument<int64>(ctx, kBufferSize, &buffer_size));
636 OP_REQUIRES(
637 ctx, buffer_size > 0,
638 errors::InvalidArgument("buffer_size must be greater than zero."));
639
640 int64 count = 1;
641 static std::atomic<int64> resource_id_counter(0);
642 const string& container = ctx->resource_manager()->default_container();
643 auto name = strings::StrCat(ctx->op_kernel().name(), "/", kSeedGenerator, "_",
644 resource_id_counter.fetch_add(1));
645 if (op_version_ == 3) {
646 auto handle = HandleFromInput(ctx, 4);
647 SeedGeneratorManager* manager = nullptr;
648 Status s = ctx->resource_manager()->Lookup<SeedGeneratorManager>(
649 handle.container(), handle.name(), &manager);
650 int64 seed;
651 OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed, &seed));
652 int64 seed2;
653 OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed2, &seed2));
654 RandomSeeds seeds(seed, seed2);
655 bool owns_resource = false;
656 if (errors::IsNotFound(s)) {
657 owns_resource = true;
658 OP_REQUIRES_OK(
659 ctx,
660 ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
661 container, name, &manager,
662 [reshuffle = reshuffle_each_iteration_,
663 &seeds](SeedGeneratorManager** manager) {
664 if (reshuffle) {
665 *manager =
666 new SeedGeneratorManager(new RandomSeedGenerator(seeds));
667 } else {
668 *manager =
669 new SeedGeneratorManager(new FixedSeedGenerator(seeds));
670 }
671 return Status::OK();
672 }));
673 handle = MakeResourceHandle<SeedGenerator>(ctx, container, name);
674 } else {
675 OP_REQUIRES_OK(ctx, s);
676 }
677
678 // Ownership of manager is transferred onto `DatasetV3`.
679 *output = new ShuffleDatasetOp::DatasetV3(ctx, input, buffer_size, count,
680 std::move(seeds), manager,
681 std::move(handle), owns_resource);
682 } else if (op_version_ == 2) {
683 auto handle = HandleFromInput(ctx, 2);
684 SeedGeneratorManager* manager = nullptr;
685 Status s = ctx->resource_manager()->Lookup<SeedGeneratorManager>(
686 handle.container(), handle.name(), &manager);
687 bool owns_resource = false;
688 if (errors::IsNotFound(s)) {
689 owns_resource = true;
690 LOG(WARNING) << "Failed to find seed generator resource. Falling back to "
691 "using a non-deterministically seeded generator and "
692 "reshuffling each iteration.";
693 RandomSeeds seeds(0, 0);
694 OP_REQUIRES_OK(
695 ctx, ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
696 container, name, &manager,
697 [&seeds](SeedGeneratorManager** manager) {
698 *manager = new SeedGeneratorManager(
699 new RandomSeedGenerator(seeds));
700 return Status::OK();
701 }));
702 handle = MakeResourceHandle<SeedGeneratorManager>(ctx, container, name);
703 } else {
704 OP_REQUIRES_OK(ctx, s);
705 }
706
707 // Ownership of manager is transferred onto `DatasetV2`.
708 *output =
709 new ShuffleDatasetOp::DatasetV2(ctx, input, buffer_size, count, manager,
710 std::move(handle), owns_resource);
711 } else {
712 if (op_version_ != 1) {
713 LOG(WARNING) << "Unsupported version of shuffle dataset op: "
714 << op_version_ << ". Defaulting to version 1.";
715 }
716 int64 seed;
717 OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed, &seed));
718 int64 seed2;
719 OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed2, &seed2));
720 RandomSeeds seeds(seed, seed2);
721 SeedGeneratorManager* manager;
722 OP_REQUIRES_OK(
723 ctx,
724 ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
725 container, name, &manager,
726 [reshuffle = reshuffle_each_iteration_,
727 &seeds](SeedGeneratorManager** manager) {
728 if (reshuffle) {
729 *manager =
730 new SeedGeneratorManager(new RandomSeedGenerator(seeds));
731 } else {
732 *manager =
733 new SeedGeneratorManager(new FixedSeedGenerator(seeds));
734 }
735 return Status::OK();
736 }));
737 auto handle =
738 MakeResourceHandle<SeedGeneratorManager>(ctx, container, name);
739
740 // Ownership of manager is transferred onto `Dataset`.
741 *output = new ShuffleDatasetOp::Dataset(ctx, input, buffer_size, count,
742 std::move(seeds), manager,
743 std::move(handle));
744 }
745 }
746
747 class ShuffleAndRepeatDatasetOp::Dataset : public ShuffleDatasetBase {
748 public:
Dataset(OpKernelContext * ctx,const DatasetBase * input,int64 buffer_size,RandomSeeds && seeds,SeedGeneratorManager * manager,int64 count,ResourceHandle && resource_handle)749 Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size,
750 RandomSeeds&& seeds, SeedGeneratorManager* manager, int64 count,
751 ResourceHandle&& resource_handle)
752 : ShuffleDatasetBase(ctx, input, buffer_size, manager->get(), count),
753 manager_(manager),
754 resource_handle_(std::move(resource_handle)),
755 resource_mgr_(ctx->resource_manager()),
756 seeds_(std::move(seeds)) {}
757
~Dataset()758 ~Dataset() override {
759 manager_->Unref();
760 Status s = resource_mgr_->Delete<SeedGeneratorManager>(
761 resource_handle_.container(), resource_handle_.name());
762 if (!s.ok()) {
763 LOG(WARNING) << "Failed to delete RNG resource: " << s.ToString();
764 }
765 }
766
op_type() const767 string op_type() const override { return kDatasetType; }
768
769 protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const770 Status AsGraphDefInternal(SerializationContext* ctx,
771 DatasetGraphDefBuilder* b,
772 Node** output) const override {
773 Node* input_graph_node = nullptr;
774 TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
775 Node* buffer_size = nullptr;
776 Node* seed = nullptr;
777 Node* seed2 = nullptr;
778 Node* count = nullptr;
779
780 TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
781 TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed(), &seed));
782 TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed2(), &seed2));
783 TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));
784 AttrValue reshuffle_each_iteration;
785 b->BuildAttrValue(seed_generator_->reshuffle_each_iteration(),
786 &reshuffle_each_iteration);
787 TF_RETURN_IF_ERROR(b->AddDataset(
788 this, {input_graph_node, buffer_size, seed, seed2, count}, // Inputs
789 {std::make_pair(kReshuffleEachIteration,
790 reshuffle_each_iteration)}, // Attrs
791 output));
792 return Status::OK();
793 }
794
795 private:
796 SeedGeneratorManager* const manager_; // Owned.
797 const ResourceHandle resource_handle_;
798 ResourceMgr* const resource_mgr_; // Not owned.
799 const RandomSeeds seeds_;
800 };
801
802 class ShuffleAndRepeatDatasetOp::DatasetV2 : public ShuffleDatasetBase {
803 public:
DatasetV2(OpKernelContext * ctx,const DatasetBase * input,int64 buffer_size,int64 count,RandomSeeds && seeds,SeedGeneratorManager * manager,ResourceHandle && resource_handle,bool owns_resource)804 DatasetV2(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size,
805 int64 count, RandomSeeds&& seeds, SeedGeneratorManager* manager,
806 ResourceHandle&& resource_handle, bool owns_resource)
807 : ShuffleDatasetBase(ctx, input, buffer_size, manager->get(), count),
808 manager_(manager),
809 owns_resource_(owns_resource),
810 resource_handle_(std::move(resource_handle)),
811 resource_mgr_(ctx->resource_manager()),
812 seeds_(std::move(seeds)) {}
813
~DatasetV2()814 ~DatasetV2() override {
815 manager_->Unref();
816 if (owns_resource_) {
817 Status s = resource_mgr_->Delete<SeedGeneratorManager>(
818 resource_handle_.container(), resource_handle_.name());
819 if (!s.ok()) {
820 LOG(WARNING) << "Failed to delete RNG resource: " << s.ToString();
821 }
822 }
823 }
824
op_type() const825 string op_type() const override { return kDatasetType; }
826
827 protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const828 Status AsGraphDefInternal(SerializationContext* ctx,
829 DatasetGraphDefBuilder* b,
830 Node** output) const override {
831 Node* input_graph_node = nullptr;
832 TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
833 Node* buffer_size_node = nullptr;
834 Node* seed_node = nullptr;
835 Node* seed2_node = nullptr;
836 Node* count_node = nullptr;
837 TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size_node));
838 TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed(), &seed_node));
839 TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed2(), &seed2_node));
840 TF_RETURN_IF_ERROR(b->AddScalar(count_, &count_node));
841 Node* resource_handle_node = nullptr;
842 Tensor handle(DT_RESOURCE, TensorShape({}));
843 handle.scalar<ResourceHandle>()() = resource_handle_;
844 TF_RETURN_IF_ERROR(b->AddTensor(handle, &resource_handle_node));
845 AttrValue reshuffle_each_iteration;
846 b->BuildAttrValue(seed_generator_->reshuffle_each_iteration(),
847 &reshuffle_each_iteration);
848 TF_RETURN_IF_ERROR(
849 b->AddDataset(this,
850 {input_graph_node, buffer_size_node, seed_node,
851 seed2_node, count_node, resource_handle_node}, // Inputs
852 {std::make_pair(kReshuffleEachIteration,
853 reshuffle_each_iteration)}, // Attrs
854 output));
855 return Status::OK();
856 }
857
858 private:
859 SeedGeneratorManager* const manager_; // Owned
860 const bool owns_resource_;
861 const ResourceHandle resource_handle_;
862 ResourceMgr* const resource_mgr_; // Not owned.
863 const RandomSeeds seeds_;
864 };
865
ShuffleAndRepeatDatasetOp(OpKernelConstruction * ctx)866 ShuffleAndRepeatDatasetOp::ShuffleAndRepeatDatasetOp(OpKernelConstruction* ctx)
867 : ShuffleDatasetOpBase(ctx) {
868 auto& op_name = ctx->def().op();
869 if (op_name == kShuffleAndRepeatDatasetV2) {
870 op_version_ = 2;
871 } else if (op_name == kShuffleAndRepeatDatasetV1) {
872 op_version_ = 1;
873 }
874 if (ctx->HasAttr(kReshuffleEachIteration)) {
875 OP_REQUIRES_OK(
876 ctx, ctx->GetAttr(kReshuffleEachIteration, &reshuffle_each_iteration_));
877 }
878 }
879
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)880 void ShuffleAndRepeatDatasetOp::MakeDataset(OpKernelContext* ctx,
881 DatasetBase* input,
882 DatasetBase** output) {
883 int64 buffer_size = 0;
884 OP_REQUIRES_OK(ctx,
885 ParseScalarArgument<int64>(ctx, kBufferSize, &buffer_size));
886 OP_REQUIRES(
887 ctx, buffer_size > 0,
888 errors::InvalidArgument("buffer_size must be greater than zero."));
889
890 int64 seed;
891 OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed, &seed));
892
893 int64 seed2;
894 OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kSeed2, &seed2));
895
896 int64 count;
897 OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kCount, &count));
898
899 OP_REQUIRES(ctx, count > 0 || count == -1,
900 errors::InvalidArgument(
901 "count must be greater than zero or equal to -1."));
902
903 RandomSeeds seeds(seed, seed2);
904
905 static std::atomic<int64> resource_id_counter(0);
906 const string& container = ctx->resource_manager()->default_container();
907 auto name = strings::StrCat(ctx->op_kernel().name(), "/", kSeedGenerator, "_",
908 resource_id_counter.fetch_add(1));
909 if (op_version_ == 2) {
910 auto handle = HandleFromInput(ctx, 5);
911 SeedGeneratorManager* manager = nullptr;
912 Status s = ctx->resource_manager()->Lookup<SeedGeneratorManager>(
913 handle.container(), handle.name(), &manager);
914 bool owns_resource = false;
915 if (errors::IsNotFound(s)) {
916 owns_resource = true;
917 OP_REQUIRES_OK(
918 ctx,
919 ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
920 container, name, &manager,
921 [reshuffle = reshuffle_each_iteration_,
922 &seeds](SeedGeneratorManager** manager) {
923 if (reshuffle) {
924 *manager =
925 new SeedGeneratorManager(new RandomSeedGenerator(seeds));
926 } else {
927 *manager =
928 new SeedGeneratorManager(new FixedSeedGenerator(seeds));
929 }
930 return Status::OK();
931 }));
932 handle = MakeResourceHandle<SeedGenerator>(ctx, container, name);
933 } else {
934 OP_REQUIRES_OK(ctx, s);
935 }
936
937 // Ownership of manager is transferred onto `DatasetV2`.
938 *output = new ShuffleAndRepeatDatasetOp::DatasetV2(
939 ctx, input, buffer_size, count, std::move(seeds), manager,
940 std::move(handle), owns_resource);
941 } else {
942 if (op_version_ != 1) {
943 LOG(WARNING) << "Unsupported version of shuffle dataset op: "
944 << op_version_ << ". Defaulting to version 1.";
945 }
946 SeedGeneratorManager* manager;
947 OP_REQUIRES_OK(
948 ctx,
949 ctx->resource_manager()->LookupOrCreate<SeedGeneratorManager>(
950 container, name, &manager,
951 [reshuffle = reshuffle_each_iteration_,
952 &seeds](SeedGeneratorManager** manager) {
953 if (reshuffle) {
954 *manager =
955 new SeedGeneratorManager(new RandomSeedGenerator(seeds));
956 } else {
957 *manager =
958 new SeedGeneratorManager(new FixedSeedGenerator(seeds));
959 }
960 return Status::OK();
961 }));
962 auto handle =
963 MakeResourceHandle<SeedGeneratorManager>(ctx, container, name);
964
965 // Ownership of manager is transferred onto `Dataset`.
966 *output = new Dataset(ctx, input, buffer_size, std::move(seeds), manager,
967 count, std::move(handle));
968 }
969 }
970
971 namespace {
972 REGISTER_KERNEL_BUILDER(Name("ShuffleDataset").Device(DEVICE_CPU),
973 ShuffleDatasetOp);
974
975 REGISTER_KERNEL_BUILDER(Name("ShuffleDatasetV2").Device(DEVICE_CPU),
976 ShuffleDatasetOp);
977
978 REGISTER_KERNEL_BUILDER(Name("ShuffleDatasetV3").Device(DEVICE_CPU),
979 ShuffleDatasetOp);
980
981 REGISTER_KERNEL_BUILDER(Name("ShuffleAndRepeatDataset").Device(DEVICE_CPU),
982 ShuffleAndRepeatDatasetOp);
983
984 REGISTER_KERNEL_BUILDER(Name("ShuffleAndRepeatDatasetV2").Device(DEVICE_CPU),
985 ShuffleAndRepeatDatasetOp);
986 } // namespace
987 } // namespace data
988 } // namespace tensorflow
989