/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <deque>
#include <vector>

#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/random/random_distributions.h"

namespace tensorflow {
namespace data {
namespace {

// Interval between "filling up shuffle buffer" progress log messages.
const int64 kLogIntervalMicros = 10 * 1000000;  // 10 seconds.

// Upper bound on the number of input epochs whose elements may coexist in
// the shuffle buffer at once; see the comment in GetNextInternal below.
const int64 kMaxEpochsInBuffer = 3;

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.

// Common kernel base for ShuffleDataset and ShuffleAndRepeatDataset.
class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
 public:
  explicit ShuffleDatasetOpBase(OpKernelConstruction* ctx)
      : UnaryDatasetOpKernel(ctx) {}

 protected:
  // Abstract base dataset that implements a shuffling iterator.
45 class ShuffleDatasetBase : public DatasetBase { 46 public: ShuffleDatasetBase(OpKernelContext * ctx,const DatasetBase * input,int64 buffer_size,int64 count)47 ShuffleDatasetBase(OpKernelContext* ctx, const DatasetBase* input, 48 int64 buffer_size, int64 count) 49 : DatasetBase(DatasetContext(ctx)), 50 input_(input), 51 buffer_size_(buffer_size), 52 count_(count) { 53 input_->Ref(); 54 } 55 ~ShuffleDatasetBase()56 ~ShuffleDatasetBase() override { input_->Unref(); } 57 output_dtypes() const58 const DataTypeVector& output_dtypes() const override { 59 return input_->output_dtypes(); 60 } 61 output_shapes() const62 const std::vector<PartialTensorShape>& output_shapes() const override { 63 return input_->output_shapes(); 64 } 65 Cardinality() const66 int64 Cardinality() const override { return input_->Cardinality(); } 67 68 protected: 69 template <class T> 70 class Iterator : public DatasetIterator<T> { 71 public: Iterator(const typename DatasetIterator<T>::Params & params,int64 seed,int64 seed2)72 explicit Iterator(const typename DatasetIterator<T>::Params& params, 73 int64 seed, int64 seed2) 74 : DatasetIterator<T>(params), 75 seed_(seed), 76 seed2_(seed2), 77 input_impl_(nullptr), 78 epoch_(0), 79 num_elements_(0), 80 parent_generator_(seed, seed2), 81 generator_(&parent_generator_) { 82 buffer_ = absl::make_unique<std::vector<Tensor>[]>( 83 params.dataset->buffer_size_); 84 slices_.push_back(absl::make_unique<Slice>(0, 0)); 85 } 86 GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)87 Status GetNextInternal(IteratorContext* ctx, 88 std::vector<Tensor>* out_tensors, 89 bool* end_of_sequence) override { 90 mutex_lock l(mu_); 91 int64 start_micros = ctx->env()->NowMicros(); 92 int64 num_log_entries = 0; 93 bool first_call = false; 94 if (!input_impl_ && epoch_ == 0) { 95 first_call = true; 96 TF_RETURN_IF_ERROR(this->dataset()->input_->MakeIterator( 97 ctx, this->prefix(), &input_impl_)); 98 } 99 while (input_impl_ && 
               num_elements_ < this->dataset()->buffer_size_) {
          // Periodically log progress, since filling a large buffer can take
          // a long time on the first GetNext call.
          if (ctx->env()->NowMicros() >
              ((num_log_entries + 1) * kLogIntervalMicros) + start_micros) {
            num_log_entries++;
            LOG(INFO) << "Filling up shuffle buffer (this may take a while): "
                      << num_elements_ << " of "
                      << this->dataset()->buffer_size_;
          }
          std::vector<Tensor> input_element;
          bool end_of_input_sequence = false;
          // Pull the next element, re-creating the input iterator at each
          // epoch boundary while repetitions remain (count_ == -1 means
          // repeat indefinitely).
          while (this->dataset()->count_ == -1 ||
                 epoch_ < this->dataset()->count_) {
            TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &input_element,
                                                    &end_of_input_sequence));
            if (!end_of_input_sequence) {
              first_call = false;
              break;
            }
            if (first_call && this->dataset()->count_ == -1) {
              // If the first call to GetNext() fails because the end
              // of sequence has been reached, we terminate the
              // iteration immediately. (Otherwise, this iterator
              // would loop infinitely and never produce a value.)
              *end_of_sequence = true;
              return Status::OK();
            }
            // Epoch boundary: start a new (empty) slice so that elements of
            // different epochs are never interleaved when sampling.
            epoch_++;
            int64 n = slices_.back()->end;
            slices_.push_back(absl::make_unique<Slice>(n, n));
            TF_RETURN_IF_ERROR(this->dataset()->input_->MakeIterator(
                ctx, this->prefix(), &input_impl_));
          }
          if (!end_of_input_sequence) {
            // Enqueue the element at the logical end of the current slice
            // (physical position is taken modulo the circular buffer size).
            this->RecordBufferEnqueue(ctx, input_element);
            buffer_[slices_.back()->end % this->dataset()->buffer_size_] =
                std::move(input_element);
            num_elements_++;
            slices_.back()->end++;
          } else {
            // All repetitions exhausted; drop the input iterator so we only
            // drain what remains in the buffer.
            input_impl_.reset();
          }
          if (slices_.size() > kMaxEpochsInBuffer) {
            // When the elements stored in `buffer_` span more than
            // `kMaxEpochsInBuffer` epochs, we do not fill the buffer further to
            // conserve memory. This means that the upper bound on the size of
            // `buffer_` is `kMaxEpochsInBuffer * cardinality(input_dataset) +
            // 1`.
            break;
          }
        }
        if (num_log_entries > 0) {
          LOG(INFO) << "Shuffle buffer filled.";
        }

        if (num_elements_ > 0) {
          *end_of_sequence = false;
          // Garbage collect all empty slices.
          while (!slices_.empty() &&
                 slices_.front()->start == slices_.front()->end) {
            slices_.pop_front();
          }
          DCHECK(!slices_.empty());
          // Choose an element to produce uniformly at random from the first
          // slice, and then remove the element from the slice.
          int64 offset =
              Random() % (slices_.front()->end - slices_.front()->start);
          int64 index =
              (slices_.front()->start + offset) % this->dataset()->buffer_size_;
          *out_tensors = std::move(buffer_[index]);
          this->RecordBufferDequeue(ctx, *out_tensors);
          // Swap the vacated sampled position with the slice's front element
          // so the slice stays contiguous, then shrink it from the front.
          std::swap(
              buffer_[index],
              buffer_[slices_.front()->start % this->dataset()->buffer_size_]);
          slices_.front()->start++;
          num_elements_--;
        } else {
          DCHECK(input_impl_ == nullptr);
          *end_of_sequence = true;
        }
        return Status::OK();
      }

     protected:
      std::shared_ptr<model::Node> CreateNode(
          IteratorContext* ctx, model::Node::Args args) const override {
        // One output element is produced per input element consumed.
        return model::MakeKnownRatioNode(std::move(args),
                                         /*ratio=*/1);
      }

      // Re-seeds the Philox generator and fast-forwards it by
      // `num_random_samples_` so that a restored iterator continues the
      // exact random sequence it was checkpointed with.
      void ResetRngs() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        // Reset the generators based on the current iterator seeds.
        parent_generator_ = random::PhiloxRandom(seed_, seed2_);
        generator_ = random::SingleSampleAdapter<random::PhiloxRandom>(
            &parent_generator_);
        generator_.Skip(num_random_samples_);
      }

      Status SaveInternal(IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        // Save state needed to restore the random number generators.
        // NOTE: the scalar key names written here form the checkpoint format;
        // RestoreInternal must read them back in the same layout.
        TF_RETURN_IF_ERROR(writer->WriteScalar(
            this->full_name("num_random_samples"), num_random_samples_));
        TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name("seed"), seed_));
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(this->full_name("seed2"), seed2_));

        // Save input iterator if it hasn't been exhausted else write
        // "end_of_input_sequence".
        if (!input_impl_) {
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              this->full_name("end_of_input_sequence"), ""));
        } else {
          TF_RETURN_IF_ERROR(this->SaveInput(writer, input_impl_));
        }

        // Save the epoch counter, buffer, and buffer slices.
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(this->full_name("epoch"), epoch_));
        TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name("num_elements"),
                                               num_elements_));
        TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name("slices_size"),
                                               slices_.size()));
        for (size_t i = 0; i < slices_.size(); ++i) {
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              this->full_name(strings::StrCat("slices_start_", i)),
              slices_[i]->start));
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              this->full_name(strings::StrCat("slices_end_", i)),
              slices_[i]->end));
          // Persist each live buffered element, keyed by its physical
          // (modulo buffer_size_) position in the circular buffer.
          for (size_t j = slices_[i]->start; j < slices_[i]->end; ++j) {
            size_t index = j % this->dataset()->buffer_size_;
            TF_RETURN_IF_ERROR(writer->WriteScalar(
                this->full_name(strings::StrCat("buffer_", index, "_size")),
                buffer_[index].size()));
            for (size_t k = 0; k < buffer_[index].size(); ++k) {
              TF_RETURN_IF_ERROR(writer->WriteTensor(
                  this->full_name(strings::StrCat("buffer_", index, "_", k)),
                  buffer_[index][k]));
            }
          }
        }

        return Status::OK();
      }

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        // Restore the random number generators.
        TF_RETURN_IF_ERROR(reader->ReadScalar(
            this->full_name("num_random_samples"), &num_random_samples_));
        TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name("seed"), &seed_));
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(this->full_name("seed2"), &seed2_));
        // Replays the RNG forward by num_random_samples_ to resume the exact
        // random sequence from the checkpoint.
        ResetRngs();

        // Restore the input iterator if it wasn't already exhausted.
        if (!reader->Contains(this->full_name("end_of_input_sequence"))) {
          TF_RETURN_IF_ERROR(this->dataset()->input_->MakeIterator(
              ctx, this->prefix(), &input_impl_));
          TF_RETURN_IF_ERROR(this->RestoreInput(ctx, reader, input_impl_));
        } else {
          input_impl_.reset();
        }

        // Restore the epoch counter, buffer, and buffer slices.
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(this->full_name("epoch"), &epoch_));
        TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name("num_elements"),
                                              &num_elements_));
        size_t slices_size;
        {
          int64 temp;
          TF_RETURN_IF_ERROR(
              reader->ReadScalar(this->full_name("slices_size"), &temp));
          slices_size = static_cast<size_t>(temp);
        }
        buffer_ = absl::make_unique<std::vector<Tensor>[]>(
            this->dataset()->buffer_size_);
        // NOTE(review): restored slices are appended after the initial (0, 0)
        // slice pushed by the constructor; that empty front slice is later
        // garbage-collected in GetNextInternal, so this appears benign.
        for (size_t i = 0; i < slices_size; ++i) {
          int64 start;
          TF_RETURN_IF_ERROR(reader->ReadScalar(
              this->full_name(strings::StrCat("slices_start_", i)), &start));
          int64 end;
          TF_RETURN_IF_ERROR(reader->ReadScalar(
              this->full_name(strings::StrCat("slices_end_", i)), &end));
          slices_.push_back(absl::make_unique<Slice>(start, end));
          for (size_t j = start; j < end; ++j) {
            size_t index = j % this->dataset()->buffer_size_;
            int64 list_size;
            TF_RETURN_IF_ERROR(reader->ReadScalar(
                this->full_name(strings::StrCat("buffer_", index, "_size")),
                &list_size));
            buffer_[index] = std::vector<Tensor>(list_size);
            for (int k = 0; k < list_size; ++k) {
              TF_RETURN_IF_ERROR(reader->ReadTensor(
                  this->full_name(strings::StrCat("buffer_", index, "_", k)),
                  &buffer_[index][k]));
            }
          }
        }

        return Status::OK();
      }

      mutex mu_;
      // Current iterator seeds; mutable because ReshufflingDataset's iterator
      // replaces them with generated seeds on each repetition.
      int64 seed_ GUARDED_BY(mu_);
      int64 seed2_ GUARDED_BY(mu_);

     private:
      // Used to represent slices of `buffer_` that belong to different epochs.
      // The invariant maintained by the implementation is: `start` <= `end`.
      // When using `start` and `end` to index into `buffer_`, their values
      // should be taken modulo the size of `buffer_` as their absolute value
      // can be greater than the range of `buffer_`.
      struct Slice {
        Slice(int64 start, int64 end) : start(start), end(end) {}

        int64 start;
        int64 end;
      };

      // Draws one sample, counting it so checkpoint/restore can replay the
      // generator to the same position (see ResetRngs).
      random::SingleSampleAdapter<random::PhiloxRandom>::ResultType Random()
          EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        num_random_samples_++;
        auto out = generator_();
        return out;
      }

      std::unique_ptr<std::vector<Tensor>[]> buffer_ GUARDED_BY(mu_);
      std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
      int64 epoch_ GUARDED_BY(mu_);
      int64 num_elements_ GUARDED_BY(mu_);
      std::deque<std::unique_ptr<Slice>> slices_ GUARDED_BY(mu_);
      random::PhiloxRandom parent_generator_ GUARDED_BY(mu_);
      random::SingleSampleAdapter<random::PhiloxRandom> generator_
          GUARDED_BY(mu_);
      int64 num_random_samples_ GUARDED_BY(mu_) = 0;
    };

    const DatasetBase* const input_;
    const int64 buffer_size_;
    const int64 count_;
  };
};

// Kernel for ShuffleDataset. Depending on the `reshuffle_each_iteration`
// attr, produces either a ReshufflingDataset or a FixedSeedDataset.
class ShuffleDatasetOp : public ShuffleDatasetOpBase {
 public:
  explicit ShuffleDatasetOp(OpKernelConstruction* ctx)
      : ShuffleDatasetOpBase(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("reshuffle_each_iteration",
                                     &reshuffle_each_iteration_));
  }

  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                   DatasetBase** output) override {
    int64 buffer_size;
    OP_REQUIRES_OK(
        ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size));
    OP_REQUIRES(
        ctx, buffer_size > 0,
        errors::InvalidArgument("buffer_size must be greater than zero."));

    int64 seed;
    OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "seed", &seed));

    int64 seed2;
    OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "seed2", &seed2));

    // By TensorFlow convention, passing 0 for both seeds indicates
    // that the shuffling should be seeded non-deterministically.
    if (seed == 0 && seed2 == 0) {
      seed = random::New64();
      seed2 = random::New64();
    }

    // ShuffleDataset never repeats its input, so the repeat count is 1.
    int64 count = 1;
    if (reshuffle_each_iteration_) {
      *output =
          new ReshufflingDataset(ctx, input, buffer_size, seed, seed2, count);
    } else {
      *output =
          new FixedSeedDataset(ctx, input, buffer_size, seed, seed2, count);
    }
  }

 private:
  // A dataset that uses a pseudorandom sequence of seeds for the iterators
  // created from it. Used when `reshuffle_each_iteration` is true.
388 class ReshufflingDataset : public ShuffleDatasetBase { 389 public: ReshufflingDataset(OpKernelContext * ctx,const DatasetBase * input,int64 buffer_size,int64 seed,int64 seed2,int64 count)390 ReshufflingDataset(OpKernelContext* ctx, const DatasetBase* input, 391 int64 buffer_size, int64 seed, int64 seed2, int64 count) 392 : ShuffleDatasetBase(ctx, input, buffer_size, count), 393 seed_(seed), 394 seed2_(seed2) {} 395 DebugString() const396 string DebugString() const override { 397 return strings::StrCat("ShuffleDatasetOp(", buffer_size_, ", ", seed_, 398 ", ", seed2_, ")::ReshufflingDataset"); 399 } 400 MakeIteratorInternal(const string & prefix) const401 std::unique_ptr<IteratorBase> MakeIteratorInternal( 402 const string& prefix) const override { 403 return absl::make_unique<Iterator>( 404 Iterator::Params{this, strings::StrCat(prefix, "::Shuffle")}, seed_, 405 seed2_); 406 } 407 408 protected: 409 class RandomSeedGenerator : public ResourceBase { 410 public: RandomSeedGenerator(int64 seed,int64 seed2)411 RandomSeedGenerator(int64 seed, int64 seed2) 412 : seed_(seed), 413 seed2_(seed2), 414 parent_generator_(seed, seed2), 415 generator_(&parent_generator_) {} 416 DebugString() const417 string DebugString() const override { 418 return "ReshufflingDataset::RandomSeedGenerator"; 419 } 420 GenerateRandomSeeds(int64 * seed1,int64 * seed2)421 void GenerateRandomSeeds(int64* seed1, int64* seed2) { 422 mutex_lock l(mu_); 423 num_random_samples_++; 424 *seed1 = generator_(); 425 num_random_samples_++; 426 *seed2 = generator_(); 427 } 428 num_random_samples()429 int64 num_random_samples() { 430 tf_shared_lock l(mu_); 431 return num_random_samples_; 432 } 433 set_num_random_samples(int64 num_random_samples)434 void set_num_random_samples(int64 num_random_samples) { 435 mutex_lock l(mu_); 436 num_random_samples_ = num_random_samples; 437 } 438 Reset()439 void Reset() { 440 mutex_lock l(mu_); 441 // Reset the generators based on the current seeds. 
442 parent_generator_ = random::PhiloxRandom(seed_, seed2_); 443 generator_ = random::SingleSampleAdapter<random::PhiloxRandom>( 444 &parent_generator_); 445 generator_.Skip(num_random_samples_); 446 } 447 448 private: 449 const int64 seed_; 450 const int64 seed2_; 451 mutex mu_; 452 random::PhiloxRandom parent_generator_ GUARDED_BY(mu_); 453 random::SingleSampleAdapter<random::PhiloxRandom> generator_ 454 GUARDED_BY(mu_); 455 int64 num_random_samples_ GUARDED_BY(mu_) = 0; 456 }; 457 458 class Iterator : public ShuffleDatasetBase::Iterator<ReshufflingDataset> { 459 public: Iterator(const Params & params,int64 seed,int64 seed2)460 explicit Iterator(const Params& params, int64 seed, int64 seed2) 461 : ShuffleDatasetBase::Iterator<ReshufflingDataset>(params, seed, 462 seed2) {} 463 ~Iterator()464 ~Iterator() override { seed_generator_->Unref(); } 465 Initialize(IteratorContext * ctx)466 Status Initialize(IteratorContext* ctx) override { 467 // Firstly, lookup or create a seed generator from the IteratorResource 468 // resource_mgr. 469 ResourceMgr* mgr = ctx->resource_mgr(); 470 RandomSeedGenerator* seed_generator; 471 const string name = strings::StrCat( 472 prefix(), "::", dataset()->type_string(), "::RandomSeedGenerator"); 473 474 int64 dataset_seed, dataset_seed2; 475 { 476 tf_shared_lock l(mu_); 477 // Ideally we'd like to hold this lock in the LookupOrCreate method, 478 // but that trips up our Deadlock detection code. 479 dataset_seed = seed_; 480 dataset_seed2 = seed2_; 481 } 482 TF_RETURN_IF_ERROR(mgr->LookupOrCreate<RandomSeedGenerator>( 483 "tf_data", name, &seed_generator, 484 [dataset_seed, 485 dataset_seed2](RandomSeedGenerator** seed_generator) { 486 // On the first iterator creation, use the original seeds from the 487 // dataset to seed a `RandomSeedGenerator` that will provide seeds 488 // for subsequent repetitions of the same dataset. 
489 *seed_generator = 490 new RandomSeedGenerator(dataset_seed, dataset_seed2); 491 return Status::OK(); 492 })); 493 // Now use the seed generator to update the base class Iterator seeds 494 // and random number generator with generated seeds for the current 495 // repetition. 496 mutex_lock l(mu_); 497 seed_generator->GenerateRandomSeeds(&seed_, &seed2_); 498 ResetRngs(); 499 seed_generator_ = seed_generator; 500 return Status::OK(); 501 } 502 503 protected: CreateNode(IteratorContext * ctx,model::Node::Args args) const504 std::shared_ptr<model::Node> CreateNode( 505 IteratorContext* ctx, model::Node::Args args) const override { 506 return model::MakeKnownRatioNode(std::move(args), 507 /*ratio=*/1); 508 } 509 SaveInternal(IteratorStateWriter * writer)510 Status SaveInternal(IteratorStateWriter* writer) override { 511 // Save RNG state of Dataset. 512 TF_RETURN_IF_ERROR( 513 writer->WriteScalar(full_name("ds_num_random_samples"), 514 seed_generator_->num_random_samples())); 515 516 // Save the Iterator. 517 return ShuffleDatasetBase::Iterator<ReshufflingDataset>::SaveInternal( 518 writer); 519 } 520 RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)521 Status RestoreInternal(IteratorContext* ctx, 522 IteratorStateReader* reader) override { 523 // Restore RNG state of Dataset. 524 int64 num_random_samples; 525 TF_RETURN_IF_ERROR(reader->ReadScalar( 526 full_name("ds_num_random_samples"), &num_random_samples)); 527 seed_generator_->set_num_random_samples(num_random_samples); 528 seed_generator_->Reset(); 529 530 // Restore the Iterator. 
531 return ShuffleDatasetBase::Iterator< 532 ReshufflingDataset>::RestoreInternal(ctx, reader); 533 } 534 535 private: 536 RandomSeedGenerator* seed_generator_; 537 }; 538 AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const539 Status AsGraphDefInternal(SerializationContext* ctx, 540 DatasetGraphDefBuilder* b, 541 Node** output) const override { 542 Node* input_graph_node = nullptr; 543 TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); 544 Node* buffer_size = nullptr; 545 Node* seed = nullptr; 546 Node* seed2 = nullptr; 547 AttrValue reshuffle_each_iteration; 548 549 TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size)); 550 TF_RETURN_IF_ERROR(b->AddScalar(seed_, &seed)); 551 TF_RETURN_IF_ERROR(b->AddScalar(seed2_, &seed2)); 552 b->BuildAttrValue(true, &reshuffle_each_iteration); 553 TF_RETURN_IF_ERROR(b->AddDataset( 554 this, {input_graph_node, buffer_size, seed, seed2}, // Inputs 555 {std::make_pair("reshuffle_each_iteration", 556 reshuffle_each_iteration)}, // Attrs 557 output)); 558 return Status::OK(); 559 } 560 561 private: 562 const int64 seed_; 563 const int64 seed2_; 564 }; 565 566 // A dataset that uses the same fixed seed for all iterators created from it. 567 // Used when `reshuffle_each_iteration` is false. 
568 class FixedSeedDataset : public ShuffleDatasetBase { 569 public: FixedSeedDataset(OpKernelContext * ctx,const DatasetBase * input,int64 buffer_size,int64 seed,int64 seed2,int64 count)570 FixedSeedDataset(OpKernelContext* ctx, const DatasetBase* input, 571 int64 buffer_size, int64 seed, int64 seed2, int64 count) 572 : ShuffleDatasetBase(ctx, input, buffer_size, count), 573 seed_(seed), 574 seed2_(seed2) {} 575 DebugString() const576 string DebugString() const override { 577 return strings::StrCat("ShuffleDatasetOp(", buffer_size_, ", ", seed_, 578 ", ", seed2_, ")::FixedSeedDataset"); 579 } 580 MakeIteratorInternal(const string & prefix) const581 std::unique_ptr<IteratorBase> MakeIteratorInternal( 582 const string& prefix) const override { 583 return absl::make_unique< 584 ShuffleDatasetBase::Iterator<ShuffleDatasetBase>>( 585 ShuffleDatasetBase::Iterator<ShuffleDatasetBase>::Params{ 586 this, strings::StrCat(prefix, "::Shuffle")}, 587 seed_, seed2_); 588 } 589 590 protected: AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const591 Status AsGraphDefInternal(SerializationContext* ctx, 592 DatasetGraphDefBuilder* b, 593 Node** output) const override { 594 Node* input_graph_node = nullptr; 595 TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); 596 Node* buffer_size = nullptr; 597 Node* seed = nullptr; 598 Node* seed2 = nullptr; 599 AttrValue reshuffle_each_iteration; 600 601 TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size)); 602 TF_RETURN_IF_ERROR(b->AddScalar(seed_, &seed)); 603 TF_RETURN_IF_ERROR(b->AddScalar(seed2_, &seed2)); 604 b->BuildAttrValue(false, &reshuffle_each_iteration); 605 TF_RETURN_IF_ERROR(b->AddDataset( 606 this, {input_graph_node, buffer_size, seed, seed2}, // Inputs 607 {std::make_pair("reshuffle_each_iteration", 608 reshuffle_each_iteration)}, // Attrs 609 output)); 610 return Status::OK(); 611 } 612 613 private: 614 const int64 seed_; 615 const int64 seed2_; 616 }; 617 
618 bool reshuffle_each_iteration_; 619 }; 620 621 class ShuffleAndRepeatDatasetOp : public ShuffleDatasetOpBase { 622 public: ShuffleAndRepeatDatasetOp(OpKernelConstruction * ctx)623 explicit ShuffleAndRepeatDatasetOp(OpKernelConstruction* ctx) 624 : ShuffleDatasetOpBase(ctx) {} 625 MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)626 void MakeDataset(OpKernelContext* ctx, DatasetBase* input, 627 DatasetBase** output) override { 628 int64 buffer_size; 629 OP_REQUIRES_OK( 630 ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size)); 631 OP_REQUIRES( 632 ctx, buffer_size > 0, 633 errors::InvalidArgument("buffer_size must be greater than zero.")); 634 635 int64 seed; 636 OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "seed", &seed)); 637 638 int64 seed2; 639 OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "seed2", &seed2)); 640 641 int64 count; 642 OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count)); 643 644 // By TensorFlow convention, if both seeds are 0, then shuffling should be 645 // seeded non-deterministically. 
    if (seed == 0 && seed2 == 0) {
      seed = random::New64();
      seed2 = random::New64();
    }

    *output = new Dataset(ctx, input, buffer_size, seed, seed2, count);
  }

 private:
  // Dataset produced by ShuffleAndRepeatDatasetOp; relies entirely on the
  // base-class iterator, which interprets `count` as the repeat count.
  class Dataset : public ShuffleDatasetBase {
   public:
    Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size,
            int64 seed, int64 seed2, int64 count)
        : ShuffleDatasetBase(ctx, input, buffer_size, count),
          seed_(seed),
          seed2_(seed2) {}

    string DebugString() const override {
      return strings::StrCat("ShuffleAndRepeatDatasetOp(", buffer_size_, ", ",
                             seed_, ", ", seed2_, ", ", count_, ")::Dataset");
    }

    std::unique_ptr<IteratorBase> MakeIteratorInternal(
        const string& prefix) const override {
      return absl::make_unique<
          ShuffleDatasetBase::Iterator<ShuffleDatasetBase>>(
          ShuffleDatasetBase::Iterator<ShuffleDatasetBase>::Params{
              this, strings::StrCat(prefix, "::ShuffleAndRepeat")},
          seed_, seed2_);
    }

   protected:
    // Serializes as a ShuffleAndRepeatDataset node; unlike ShuffleDataset,
    // `count` is a graph input and there is no reshuffle attr.
    Status AsGraphDefInternal(SerializationContext* ctx,
                              DatasetGraphDefBuilder* b,
                              Node** output) const override {
      Node* input_graph_node = nullptr;
      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
      Node* buffer_size = nullptr;
      Node* seed = nullptr;
      Node* seed2 = nullptr;
      Node* count = nullptr;

      TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
      TF_RETURN_IF_ERROR(b->AddScalar(seed_, &seed));
      TF_RETURN_IF_ERROR(b->AddScalar(seed2_, &seed2));
      TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));
      TF_RETURN_IF_ERROR(b->AddDataset(
          this, {input_graph_node, buffer_size, seed, seed2, count},  // Inputs
          {},                                                         // Attrs
          output));
      return Status::OK();
    }

   private:
    const int64 seed_;
    const int64 seed2_;
  };
};

REGISTER_KERNEL_BUILDER(Name("ShuffleDataset").Device(DEVICE_CPU),
                        ShuffleDatasetOp);

REGISTER_KERNEL_BUILDER(Name("ShuffleAndRepeatDataset").Device(DEVICE_CPU),
                        ShuffleAndRepeatDatasetOp);

}  // namespace
}  // namespace data
}  // namespace tensorflow