1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <deque> 17 #include <vector> 18 19 #include "tensorflow/core/framework/partial_tensor_shape.h" 20 #include "tensorflow/core/framework/tensor.h" 21 #include "tensorflow/core/kernels/data/dataset.h" 22 #include "tensorflow/core/lib/random/philox_random.h" 23 #include "tensorflow/core/lib/random/random.h" 24 #include "tensorflow/core/lib/random/random_distributions.h" 25 26 namespace tensorflow { 27 28 namespace { 29 30 const int64 kLogIntervalMicros = 10 * 1000000; // 10 seconds. 31 32 // See documentation in ../ops/dataset_ops.cc for a high-level 33 // description of the following op. 34 35 class ShuffleDatasetOpBase : public UnaryDatasetOpKernel { 36 public: ShuffleDatasetOpBase(OpKernelConstruction * ctx)37 explicit ShuffleDatasetOpBase(OpKernelConstruction* ctx) 38 : UnaryDatasetOpKernel(ctx) {} 39 40 protected: 41 // Abstract base dataset that implements a shuffling iterator. 42 class ShuffleDatasetBase : public GraphDatasetBase { 43 public: ShuffleDatasetBase(OpKernelContext * ctx,const DatasetBase * input,int64 buffer_size,int64 count)44 ShuffleDatasetBase(OpKernelContext* ctx, const DatasetBase* input, 45 int64 buffer_size, int64 count) 46 : GraphDatasetBase(ctx), 47 input_(input), 48 buffer_size_(buffer_size), 49 count_(count) { 50 input_->Ref(); 51 } 52 ~ShuffleDatasetBase()53 ~ShuffleDatasetBase() override { input_->Unref(); } 54 output_dtypes() const55 const DataTypeVector& output_dtypes() const override { 56 return input_->output_dtypes(); 57 } 58 output_shapes() const59 const std::vector<PartialTensorShape>& output_shapes() const override { 60 return input_->output_shapes(); 61 } 62 63 protected: 64 class Iterator : public DatasetIterator<ShuffleDatasetBase> { 65 public: Iterator(const Params & params,int64 seed,int64 seed2)66 explicit Iterator(const Params& params, int64 seed, int64 seed2) 67 : DatasetIterator<ShuffleDatasetBase>(params), 68 input_impl_(nullptr), 69 seed_(seed), 70 seed2_(seed2), 71 epoch_(0), 72 num_elements_(0), 73 parent_generator_(seed, seed2), 74 generator_(&parent_generator_) { 75 buffer_.reset(new std::vector<Tensor>[params.dataset->buffer_size_]); 76 slices_.emplace_back(new Slice{0, 0}); 77 } 78 GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)79 Status GetNextInternal(IteratorContext* ctx, 80 std::vector<Tensor>* out_tensors, 81 bool* end_of_sequence) override { 82 mutex_lock l(mu_); 83 int64 start_micros = ctx->env()->NowMicros(); 84 int64 num_log_entries = 0; 85 bool first_call = false; 86 if (!input_impl_ && epoch_ == 0) { 87 first_call = true; 88 input_impl_ = dataset()->input_->MakeIterator(prefix()); 89 } 90 while (input_impl_ && num_elements_ < dataset()->buffer_size_) { 91 if (ctx->env()->NowMicros() > 92 ((num_log_entries + 1) * kLogIntervalMicros) + start_micros) { 93 num_log_entries++; 94 LOG(INFO) << "Filling up shuffle buffer (this may take a while): " 95 << num_elements_ << " of " << dataset()->buffer_size_; 96 } 97 std::vector<Tensor> input_element; 98 bool end_of_input_sequence = false; 99 while (dataset()->count_ == -1 || epoch_ < dataset()->count_) { 100 TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &input_element, 101 &end_of_input_sequence)); 102 if (!end_of_input_sequence) { 103 first_call = false; 104 break; 105 } 106 if (first_call && dataset()->count_ == -1) { 107 // If the first call to GetNext() fails because the end 108 // of sequence has been reached, we terminate the 109 // iteration immediately. (Otherwise, this iterator 110 // would loop infinitely and never produce a value.) 111 *end_of_sequence = true; 112 return Status::OK(); 113 } 114 epoch_++; 115 int64 n = slices_.back()->end; 116 slices_.emplace_back(new Slice{n, n}); 117 input_impl_ = dataset()->input_->MakeIterator(prefix()); 118 } 119 if (!end_of_input_sequence) { 120 buffer_[slices_.back()->end % dataset()->buffer_size_] = 121 std::move(input_element); 122 num_elements_++; 123 slices_.back()->end++; 124 } else { 125 input_impl_.reset(); 126 } 127 } 128 if (num_log_entries > 0) { 129 LOG(INFO) << "Shuffle buffer filled."; 130 } 131 132 if (num_elements_ > 0) { 133 *end_of_sequence = false; 134 // Garbage collect all empty slices. 135 while (!slices_.empty() && 136 slices_.front()->start == slices_.front()->end) { 137 slices_.pop_front(); 138 } 139 DCHECK(!slices_.empty()); 140 // Choose an element to produce uniformly at random from the first 141 // slice, and then remove the element from the slice. 142 int64 offset = 143 Random() % (slices_.front()->end - slices_.front()->start); 144 int64 index = 145 (slices_.front()->start + offset) % dataset()->buffer_size_; 146 *out_tensors = std::move(buffer_[index]); 147 std::swap(buffer_[index], 148 buffer_[slices_.front()->start % dataset()->buffer_size_]); 149 slices_.front()->start++; 150 num_elements_--; 151 } else { 152 DCHECK(input_impl_ == nullptr); 153 *end_of_sequence = true; 154 } 155 return Status::OK(); 156 } 157 158 protected: SaveInternal(IteratorStateWriter * writer)159 Status SaveInternal(IteratorStateWriter* writer) override { 160 mutex_lock l(mu_); 161 162 // Save state needed to restore the random number generators. 163 TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_random_samples"), 164 num_random_samples_)); 165 166 // Save input iterator if it hasn't been exhausted else write 167 // "end_of_input_sequence". 168 if (!input_impl_) { 169 TF_RETURN_IF_ERROR( 170 writer->WriteScalar(full_name("end_of_input_sequence"), "")); 171 } else { 172 TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); 173 } 174 175 // Save the epoch counter, buffer, and buffer slices. 176 TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("epoch"), epoch_)); 177 TF_RETURN_IF_ERROR( 178 writer->WriteScalar(full_name("num_elements"), num_elements_)); 179 TF_RETURN_IF_ERROR( 180 writer->WriteScalar(full_name("slices_size"), slices_.size())); 181 for (size_t i = 0; i < slices_.size(); ++i) { 182 TF_RETURN_IF_ERROR(writer->WriteScalar( 183 full_name(strings::StrCat("slices_start_", i)), 184 slices_[i]->start)); 185 TF_RETURN_IF_ERROR(writer->WriteScalar( 186 full_name(strings::StrCat("slices_end_", i)), slices_[i]->end)); 187 for (size_t j = slices_[i]->start; j < slices_[i]->end; ++j) { 188 size_t index = j % dataset()->buffer_size_; 189 TF_RETURN_IF_ERROR(writer->WriteScalar( 190 full_name(strings::StrCat("buffer_", index, "_size")), 191 buffer_[index].size())); 192 for (size_t k = 0; k < buffer_[index].size(); ++k) { 193 TF_RETURN_IF_ERROR(writer->WriteTensor( 194 full_name(strings::StrCat("buffer_", index, "_", k)), 195 buffer_[index][k])); 196 } 197 } 198 } 199 200 return Status::OK(); 201 } 202 RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)203 Status RestoreInternal(IteratorContext* ctx, 204 IteratorStateReader* reader) override { 205 mutex_lock l(mu_); 206 207 // Restore the random number generators. 208 TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_random_samples"), 209 &num_random_samples_)); 210 ResetRngs(); 211 212 // Restore the input iterator if it wasn't already exhausted. 213 if (!reader->Contains(full_name("end_of_input_sequence"))) { 214 input_impl_ = dataset()->input_->MakeIterator(prefix()); 215 TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); 216 } else { 217 input_impl_.reset(); 218 } 219 220 // Restore the epoch counter, buffer, and buffer slices. 221 TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("epoch"), &epoch_)); 222 TF_RETURN_IF_ERROR( 223 reader->ReadScalar(full_name("num_elements"), &num_elements_)); 224 size_t slices_size; 225 { 226 int64 temp; 227 TF_RETURN_IF_ERROR( 228 reader->ReadScalar(full_name("slices_size"), &temp)); 229 slices_size = static_cast<size_t>(temp); 230 } 231 buffer_.reset(new std::vector<Tensor>[dataset()->buffer_size_]); 232 for (size_t i = 0; i < slices_size; ++i) { 233 int64 start; 234 TF_RETURN_IF_ERROR(reader->ReadScalar( 235 full_name(strings::StrCat("slices_start_", i)), &start)); 236 int64 end; 237 TF_RETURN_IF_ERROR(reader->ReadScalar( 238 full_name(strings::StrCat("slices_end_", i)), &end)); 239 slices_.emplace_back(new Slice{start, end}); 240 for (size_t j = start; j < end; ++j) { 241 size_t index = j % dataset()->buffer_size_; 242 int64 list_size; 243 TF_RETURN_IF_ERROR(reader->ReadScalar( 244 full_name(strings::StrCat("buffer_", index, "_size")), 245 &list_size)); 246 buffer_[index] = std::vector<Tensor>(list_size); 247 for (int k = 0; k < list_size; ++k) { 248 TF_RETURN_IF_ERROR(reader->ReadTensor( 249 full_name(strings::StrCat("buffer_", index, "_", k)), 250 &buffer_[index][k])); 251 } 252 } 253 } 254 255 return Status::OK(); 256 } 257 258 private: 259 // Used to represent slices of `buffer_` that belong to different epochs. 260 // The invariant maintained by the implementation is: `start` <= `end`. 261 // When using `start` and `end` to index into `buffer_`, their values 262 // should be taken modulo the size of `buffer_` as their absolute value 263 // can be greater than the range of `buffer_`. 264 struct Slice { Slicetensorflow::__anon090660da0111::ShuffleDatasetOpBase::ShuffleDatasetBase::Iterator::Slice265 Slice(int64 start, int64 end) : start(start), end(end) {} 266 267 int64 start; 268 int64 end; 269 }; 270 Random()271 random::SingleSampleAdapter<random::PhiloxRandom>::ResultType Random() 272 EXCLUSIVE_LOCKS_REQUIRED(mu_) { 273 num_random_samples_++; 274 auto out = generator_(); 275 return out; 276 } 277 ResetRngs()278 void ResetRngs() EXCLUSIVE_LOCKS_REQUIRED(mu_) { 279 // Reset the generators based on the current iterator seeds. 280 parent_generator_ = random::PhiloxRandom(seed_, seed2_); 281 generator_ = random::SingleSampleAdapter<random::PhiloxRandom>( 282 &parent_generator_); 283 generator_.Skip(num_random_samples_); 284 } 285 286 mutex mu_; 287 std::unique_ptr<std::vector<Tensor>[]> buffer_ GUARDED_BY(mu_); 288 std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); 289 const int64 seed_ GUARDED_BY(mu_); 290 const int64 seed2_ GUARDED_BY(mu_); 291 int64 epoch_ GUARDED_BY(mu_); 292 int64 num_elements_ GUARDED_BY(mu_); 293 std::deque<std::unique_ptr<Slice>> slices_ GUARDED_BY(mu_); 294 random::PhiloxRandom parent_generator_ GUARDED_BY(mu_); 295 random::SingleSampleAdapter<random::PhiloxRandom> generator_ 296 GUARDED_BY(mu_); 297 int64 num_random_samples_ GUARDED_BY(mu_) = 0; 298 }; 299 300 const DatasetBase* const input_; 301 const int64 buffer_size_; 302 const int64 count_; 303 }; 304 }; 305 306 class ShuffleDatasetOp : public ShuffleDatasetOpBase { 307 public: ShuffleDatasetOp(OpKernelConstruction * ctx)308 explicit ShuffleDatasetOp(OpKernelConstruction* ctx) 309 : ShuffleDatasetOpBase(ctx) { 310 OP_REQUIRES_OK(ctx, ctx->GetAttr("reshuffle_each_iteration", 311 &reshuffle_each_iteration_)); 312 } 313 MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)314 void MakeDataset(OpKernelContext* ctx, DatasetBase* input, 315 DatasetBase** output) override { 316 int64 buffer_size; 317 OP_REQUIRES_OK( 318 ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size)); 319 OP_REQUIRES( 320 ctx, buffer_size > 0, 321 errors::InvalidArgument("buffer_size must be greater than zero.")); 322 323 int64 seed; 324 OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "seed", &seed)); 325 326 int64 seed2; 327 OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "seed2", &seed2)); 328 329 // By TensorFlow convention, passing 0 for both seeds indicates 330 // that the shuffling should be seeded non-deterministically. 331 if (seed == 0 && seed2 == 0) { 332 seed = random::New64(); 333 seed2 = random::New64(); 334 } 335 336 int64 count = 1; 337 if (reshuffle_each_iteration_) { 338 *output = 339 new ReshufflingDataset(ctx, input, buffer_size, seed, seed2, count); 340 } else { 341 *output = 342 new FixedSeedDataset(ctx, input, buffer_size, seed, seed2, count); 343 } 344 } 345 346 private: 347 // A dataset that uses a pseduorandom sequence of seeds for the iterators 348 // created from it. Used when `reshuffle_each_iteration` is true. 349 class ReshufflingDataset : public ShuffleDatasetBase { 350 public: ReshufflingDataset(OpKernelContext * ctx,const DatasetBase * input,int64 buffer_size,int64 seed,int64 seed2,int64 count)351 ReshufflingDataset(OpKernelContext* ctx, const DatasetBase* input, 352 int64 buffer_size, int64 seed, int64 seed2, int64 count) 353 : ShuffleDatasetBase(ctx, input, buffer_size, count), 354 seed_(seed), 355 seed2_(seed2), 356 parent_generator_(seed, seed2), 357 generator_(&parent_generator_) {} 358 DebugString()359 string DebugString() override { 360 return strings::StrCat("ShuffleDatasetOp(", buffer_size_, ", ", seed_, 361 ", ", seed2_, ")::ReshufflingDataset"); 362 } 363 MakeIterator(const string & prefix) const364 std::unique_ptr<IteratorBase> MakeIterator( 365 const string& prefix) const override { 366 int64 iterator_seed; 367 int64 iterator_seed2; 368 { 369 mutex_lock l(mu_); 370 iterator_seed = generator_(); 371 iterator_seed2 = generator_(); 372 } 373 return std::unique_ptr<IteratorBase>(new ShuffleDatasetBase::Iterator( 374 {this, strings::StrCat(prefix, "::Shuffle")}, iterator_seed, 375 iterator_seed2)); 376 } 377 378 private: 379 const int64 seed_; 380 const int64 seed2_; 381 mutable mutex mu_; 382 mutable random::PhiloxRandom parent_generator_ GUARDED_BY(mu_); 383 mutable random::SingleSampleAdapter<random::PhiloxRandom> generator_ 384 GUARDED_BY(mu_); 385 }; 386 387 // A dataset that uses the same fixed seed for all iterators created from it. 388 // Used when `reshuffle_each_iteration` is false. 389 class FixedSeedDataset : public ShuffleDatasetBase { 390 public: FixedSeedDataset(OpKernelContext * ctx,const DatasetBase * input,int64 buffer_size,int64 seed,int64 seed2,int64 count)391 FixedSeedDataset(OpKernelContext* ctx, const DatasetBase* input, 392 int64 buffer_size, int64 seed, int64 seed2, int64 count) 393 : ShuffleDatasetBase(ctx, input, buffer_size, count), 394 seed_(seed), 395 seed2_(seed) {} 396 DebugString()397 string DebugString() override { 398 return strings::StrCat("ShuffleDatasetOp(", buffer_size_, ", ", seed_, 399 ", ", seed2_, ")::FixedSeedDataset"); 400 } 401 MakeIterator(const string & prefix) const402 std::unique_ptr<IteratorBase> MakeIterator( 403 const string& prefix) const override { 404 return std::unique_ptr<IteratorBase>(new ShuffleDatasetBase::Iterator( 405 {this, strings::StrCat(prefix, "::Shuffle")}, seed_, seed2_)); 406 } 407 408 protected: AsGraphDefInternal(OpKernelContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const409 Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, 410 Node** output) const override { 411 Node* input_graph_node = nullptr; 412 TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node)); 413 Node* buffer_size = nullptr; 414 Node* seed = nullptr; 415 Node* seed2 = nullptr; 416 AttrValue reshuffle_each_iteration; 417 418 TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size)); 419 TF_RETURN_IF_ERROR(b->AddScalar(seed_, &seed)); 420 TF_RETURN_IF_ERROR(b->AddScalar(seed2_, &seed2)); 421 b->BuildAttrValue(false, &reshuffle_each_iteration); 422 TF_RETURN_IF_ERROR(b->AddDataset( 423 this, {input_graph_node, buffer_size, seed, seed2}, // Inputs 424 {std::make_pair("reshuffle_each_iteration", 425 reshuffle_each_iteration)}, // Attrs 426 output)); 427 return Status::OK(); 428 } 429 430 private: 431 const int64 seed_; 432 const int64 seed2_; 433 }; 434 435 bool reshuffle_each_iteration_; 436 }; 437 438 class ShuffleAndRepeatDatasetOp : public ShuffleDatasetOpBase { 439 public: ShuffleAndRepeatDatasetOp(OpKernelConstruction * ctx)440 explicit ShuffleAndRepeatDatasetOp(OpKernelConstruction* ctx) 441 : ShuffleDatasetOpBase(ctx) {} 442 MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)443 void MakeDataset(OpKernelContext* ctx, DatasetBase* input, 444 DatasetBase** output) override { 445 int64 buffer_size; 446 OP_REQUIRES_OK( 447 ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size)); 448 OP_REQUIRES( 449 ctx, buffer_size > 0, 450 errors::InvalidArgument("buffer_size must be greater than zero.")); 451 452 int64 seed; 453 OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "seed", &seed)); 454 455 int64 seed2; 456 OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "seed2", &seed2)); 457 458 int64 count; 459 OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count)); 460 461 // By TensorFlow convention, if both seeds are 0, then shuffling should be 462 // seeded non-deterministically. 463 if (seed == 0 && seed2 == 0) { 464 seed = random::New64(); 465 seed2 = random::New64(); 466 } 467 468 *output = new Dataset(ctx, input, buffer_size, seed, seed2, count); 469 } 470 471 private: 472 class Dataset : public ShuffleDatasetBase { 473 public: Dataset(OpKernelContext * ctx,const DatasetBase * input,int64 buffer_size,int64 seed,int64 seed2,int64 count)474 Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size, 475 int64 seed, int64 seed2, int64 count) 476 : ShuffleDatasetBase(ctx, input, buffer_size, count), 477 seed_(seed), 478 seed2_(seed2) {} 479 DebugString()480 string DebugString() override { 481 return strings::StrCat("ShuffleAndRepeatDatasetOp(", buffer_size_, ", ", 482 seed_, ", ", seed2_, ", ", count_, ")::Dataset"); 483 } 484 MakeIterator(const string & prefix) const485 std::unique_ptr<IteratorBase> MakeIterator( 486 const string& prefix) const override { 487 return std::unique_ptr<IteratorBase>(new ShuffleDatasetBase::Iterator( 488 {this, strings::StrCat(prefix, "::ShuffleAndRepeat")}, seed_, 489 seed2_)); 490 } 491 492 protected: AsGraphDefInternal(OpKernelContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const493 Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, 494 Node** output) const override { 495 Node* input_graph_node = nullptr; 496 TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node)); 497 Node* buffer_size = nullptr; 498 Node* seed = nullptr; 499 Node* seed2 = nullptr; 500 Node* count = nullptr; 501 502 TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size)); 503 TF_RETURN_IF_ERROR(b->AddScalar(seed_, &seed)); 504 TF_RETURN_IF_ERROR(b->AddScalar(seed2_, &seed2)); 505 TF_RETURN_IF_ERROR(b->AddScalar(count_, &count)); 506 TF_RETURN_IF_ERROR(b->AddDataset( 507 this, {input_graph_node, buffer_size, seed, seed2, count}, // Inputs 508 {}, // Attrs 509 output)); 510 return Status::OK(); 511 } 512 513 private: 514 const int64 seed_; 515 const int64 seed2_; 516 }; 517 }; 518 519 REGISTER_KERNEL_BUILDER(Name("ShuffleDataset").Device(DEVICE_CPU), 520 ShuffleDatasetOp); 521 522 REGISTER_KERNEL_BUILDER(Name("ShuffleAndRepeatDataset").Device(DEVICE_CPU), 523 ShuffleAndRepeatDatasetOp); 524 525 } // namespace 526 527 } // namespace tensorflow 528