/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"

namespace tensorflow {
namespace data {
namespace {

// See documentation in ../../ops/dataset_ops.cc for a high-level description
// of the following op.

class CacheDatasetOp : public UnaryDatasetOpKernel {
 public:
  explicit CacheDatasetOp(OpKernelConstruction* ctx)
      : UnaryDatasetOpKernel(ctx) {}

  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                   DatasetBase** output) override {
    // Parse out the filename tensor.
    string filename;
    OP_REQUIRES_OK(ctx,
                   ParseScalarArgument<string>(ctx, "filename", &filename));

    if (filename.empty()) {
      *output = new MemoryDataset(ctx, input);
    } else {
      *output = new FileDataset(ctx, input, filename, ctx->env());
    }
  }
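
  // Illustrative note: an empty `filename` selects the in-memory cache
  // (`MemoryDataset`), while a non-empty value selects the file-backed cache
  // (`FileDataset`). For a hypothetical prefix such as "/tmp/cache", the
  // file-backed variant would eventually leave a TensorBundle checkpoint
  // behind (roughly an index file plus one or more data files derived from
  // that prefix); the exact file names come from the tensor_bundle library.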

 private:
  class FileDataset : public DatasetBase {
   public:
    explicit FileDataset(OpKernelContext* ctx, const DatasetBase* input,
                         string filename, Env* env)
        : DatasetBase(DatasetContext(ctx)),
          input_(input),
          filename_(std::move(filename)),
          env_(env),
          num_tensors_(input->output_dtypes().size()),
          tensor_index_padding_size_(StringPaddingSize(num_tensors_)),
          item_index_padding_size_(StringPaddingSize(kMaxItems)),
          tensor_format_string_(strings::Printf("%%%zuzu_%%%zuzu",
                                                item_index_padding_size_,
                                                tensor_index_padding_size_)) {
      input_->Ref();
      DCHECK_EQ(item_index_padding_size_, 7);
    }

    ~FileDataset() override { input_->Unref(); }

    std::unique_ptr<IteratorBase> MakeIteratorInternal(
        const string& prefix) const override {
      return absl::make_unique<FileIterator>(
          FileIterator::Params{this, strings::StrCat(prefix, "::FileCache")});
    }

    const DataTypeVector& output_dtypes() const override {
      return input_->output_dtypes();
    }

    const std::vector<PartialTensorShape>& output_shapes() const override {
      return input_->output_shapes();
    }

    string DebugString() const override {
      return "CacheDatasetOp::FileDataset";
    }

    int64 Cardinality() const override { return input_->Cardinality(); }

   protected:
    Status AsGraphDefInternal(SerializationContext* ctx,
                              DatasetGraphDefBuilder* b,
                              Node** output) const override {
      Node* input_graph = nullptr;
      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph));
      Node* filename = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(filename_, &filename));
      TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph, filename}, output));
      return Status::OK();
    }

   private:
    static size_t StringPaddingSize(size_t num_tensors) {
      return strings::Printf("%zu", num_tensors - 1).size();
    }

    string FormatName(size_t item_index, size_t tensor_index) const {
      return strings::Printf(tensor_format_string_.c_str(), item_index,
                             tensor_index);
    }
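
    // Worked example (illustrative): with kMaxItems = 10000000 the item index
    // is padded to 7 characters, and with e.g. two output components the
    // tensor index is padded to 1 character, so `tensor_format_string_` is
    // "%7zu_%1zu". FormatName(42, 1) then yields the key "     42_1". The
    // fixed width keeps the lexicographic order of bundle keys consistent
    // with the numeric (item, tensor) order.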

    class FileIterator : public DatasetIterator<FileDataset> {
     public:
      explicit FileIterator(const Params& params)
          : DatasetIterator<FileDataset>(params) {
        if (params.dataset->env_
                ->FileExists(MetaFilename(params.dataset->filename_))
                .ok()) {
          mode_ = Mode::read;
        } else {
          mode_ = Mode::write;
        }
        InitializeIterator();
      }

      Status Initialize(IteratorContext* ctx) override {
        mutex_lock l(mu_);
        return iterator_->Initialize(ctx);
      }

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
      }

     protected:
      std::shared_ptr<model::Node> CreateNode(
          IteratorContext* ctx, model::Node::Args args) const override {
        return model::MakeKnownRatioNode(std::move(args),
                                         /*ratio=*/1);
      }

      Status SaveInternal(IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("mode"), mode_));
        return SaveInput(writer, iterator_);
      }

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        {
          int64 temp;
          TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("mode"), &temp));
          mode_ = static_cast<Mode>(temp);
        }
        if (mode_ == Mode::write &&
            dataset()
                ->env_->FileExists(MetaFilename(dataset()->filename_))
                .ok()) {
          // This could happen if the cache was completely written after the
          // checkpoint was saved.
          LOG(WARNING)
              << "It looks like the cache was already completely written ("
              << MetaFilename(dataset()->filename_)
              << ") after the last checkpoint was saved. "
              << "Attempting to read the cache instead of continuing to "
              << "write. If this is a mistake, please remove the above file "
              << "and try running again.";
          mode_ = Mode::read;
        }
        InitializeIterator();
        TF_RETURN_IF_ERROR(iterator_->Initialize(ctx));
        return RestoreInput(ctx, reader, iterator_);
      }

     private:
      // FileWriterIterator passes through and caches items from the input
      // FileDataset.
      //
      // This iterator is used when the cache directory is not found on disk.
      // It creates the cache directory, and passes on the underlying
      // iterator's elements.
      //
      // Caching is performed by writing the input tensors to disk using the
      // `BundleWriter`. Note that the cache gets fully flushed to disk only
      // after the input iterator has been fully exhausted. If the program
      // exits before completion of an epoch, the cached state will be lost.
      // To ensure that the partial cache persists across sessions, one should
      // checkpoint the input pipeline. On each call to `SaveInternal` the
      // partial cache gets flushed to disk in files with prefix
      // <filename>_<shard_id>, where shard_id is unique for each checkpoint.
      // When all elements have been produced, these shards get coalesced.
      class FileWriterIterator : public DatasetIterator<FileDataset> {
       public:
        explicit FileWriterIterator(const Params& params)
            : DatasetIterator<FileDataset>(params),
              cur_index_(0),
              shard_id_(0),
              filename_(
                  strings::StrCat(params.dataset->filename_, "_", shard_id_)),
              lockfile_(strings::StrCat(filename_, ".lockfile")),
              lockfile_created_(false),
              iteration_completed_(false) {}

        Status Initialize(IteratorContext* ctx) override {
          return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
        }

        Status GetNextInternal(IteratorContext* ctx,
                               std::vector<Tensor>* out_tensors,
                               bool* end_of_sequence) override {
          mutex_lock l(mu_);
          TF_RETURN_IF_ERROR(EnsureLockFileExists());
          TF_RETURN_IF_ERROR(writer_->status());
          if (cur_index_ >= kMaxItems) {
            // As a courtesy, close the [truncated] cache file.
            Status s = Finish();
            if (!s.ok()) {
              LOG(ERROR) << s;
            }
            return errors::InvalidArgument(
                "Upstream iterator is producing more than ", kMaxItems,
                " items, which is more than the cache limit.");
          }

          TF_RETURN_IF_ERROR(
              input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
          if (*end_of_sequence && out_tensors->empty()) {
            TF_RETURN_IF_ERROR(Finish());
            cur_index_++;
            return Status::OK();
          }
          if (out_tensors->size() != dataset()->num_tensors_) {
            return errors::Internal(
                "Upstream iterator returned invalid number of tensors. "
                "Expected ",
                dataset()->num_tensors_, " got: ", out_tensors->size());
          }
          size_t tensor_index = 0;
          for (const Tensor& t : *out_tensors) {
            DCHECK_LT(tensor_index, dataset()->num_tensors_);
            string key = dataset()->FormatName(cur_index_, tensor_index++);
            TF_RETURN_IF_ERROR(writer_->Add(key, t));
          }
          if (*end_of_sequence) {
            TF_RETURN_IF_ERROR(Finish());
          }
          cur_index_++;
          return Status::OK();
        }

       protected:
        std::shared_ptr<model::Node> CreateNode(
            IteratorContext* ctx, model::Node::Args args) const override {
          return model::MakeKnownRatioNode(std::move(args),
                                           /*ratio=*/1);
        }

        Status SaveInternal(IteratorStateWriter* writer) override {
          mutex_lock l(mu_);
          if (iteration_completed_) {
            TF_RETURN_IF_ERROR(
                writer->WriteScalar(full_name("iteration_completed"), ""));
            return Status::OK();
          }

          // The lockfile is created on the first call to GetNextInternal. The
          // absence of a lockfile means that GetNextInternal was not called
          // and hence nothing was written to cache. So we don't need to worry
          // about flushing the current shard. This ensures that we never
          // write empty shards.
          if (lockfile_created_) {
            // Flush the current bundle.
            TF_RETURN_IF_ERROR(writer_->Finish());

            // Note: We do not delete the lockfile here. We keep lockfiles of
            // all shards around until the entire cache has been written to
            // prevent concurrent iterators from corrupting any of the shards.

            // Start caching to a new shard.
            shard_id_++;
            filename_ = strings::StrCat(dataset()->filename_, "_", shard_id_);
            lockfile_ = strings::StrCat(filename_, ".lockfile");
            lockfile_created_ = false;
          }
          TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name("cur_index"), cur_index_));
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name("shard_id"), shard_id_));
          return Status::OK();
        }
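
        // Illustration (hypothetical prefix "/tmp/cache"): if the pipeline is
        // checkpointed twice before the input is exhausted, the partial cache
        // lives in the bundles "/tmp/cache_0" and "/tmp/cache_1" (plus their
        // ".lockfile" markers), the in-progress shard is "/tmp/cache_2", and
        // Finish() later merges all "/tmp/cache_<i>" bundles into the final
        // "/tmp/cache" bundle that FileReaderIterator reads.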

        Status RestoreInternal(IteratorContext* ctx,
                               IteratorStateReader* reader) override {
          mutex_lock l(mu_);
          if (reader->Contains(full_name("iteration_completed"))) {
            iteration_completed_ = true;
            return Status::OK();
          }

          TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
          int64 temp;
          // TODO(b/78048575): Update this when saving size_t tensors directly
          // is supported.
          {
            TF_RETURN_IF_ERROR(
                reader->ReadScalar(full_name("cur_index"), &temp));
            cur_index_ = static_cast<size_t>(temp);
            if (cur_index_ != temp) {
              return errors::Internal("Invalid value for cur_index ", temp);
            }
          }
          // TODO(b/78048575): Update this when saving size_t tensors directly
          // is supported.
          {
            TF_RETURN_IF_ERROR(
                reader->ReadScalar(full_name("shard_id"), &temp));
            shard_id_ = static_cast<size_t>(temp);
            if (shard_id_ != temp) {
              return errors::Internal("Invalid value for shard_id ", temp);
            }
          }
          filename_ = strings::StrCat(dataset()->filename_, "_", shard_id_);
          lockfile_ = strings::StrCat(filename_, ".lockfile");
          writer_ =
              absl::make_unique<BundleWriter>(dataset()->env_, filename_);
          return Status::OK();
        }

       private:
        Status EnsureLockFileExists() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
          if (iteration_completed_)
            return errors::OutOfRange(
                "Attempting to call get_next after iteration should have "
                "finished.");
          if (lockfile_created_ && !iteration_completed_) return Status::OK();

          // Perform rudimentary locking to help catch concurrent writes to
          // the same cache files.

          // 1. Check that a checkpoint for the shard has not already been
          //    written.
          if (dataset()->env_->FileExists(MetaFilename(filename_)).ok()) {
            return errors::AlreadyExists(
                "Existing cache files found: \n", MetaFilename(filename_),
                "\n", DataFilename(filename_, 0, 1), "\n",
                "To continue delete the above files.");
          }

          // 2. Check that there isn't a concurrent iterator that is writing
          //    to cache.
          if (dataset()->env_->FileExists(lockfile_).ok()) {
            // Attempt to read the contents of the lockfile.
            char contents_scratch[151] = {0};  // Initialize all to 0.
            StringPiece contents;
            std::unique_ptr<RandomAccessFile> file;
            if (dataset()->env_->NewRandomAccessFile(lockfile_, &file).ok()) {
              file->Read(0, 150, &contents, contents_scratch).IgnoreError();
            }
            return errors::AlreadyExists(
                "There appears to be a concurrent caching iterator running - "
                "cache lockfile already exists ('",
                lockfile_,
                "'). If you are sure no other running TF computations are "
                "using this cache prefix, delete the lockfile and "
                "re-initialize the iterator. Lockfile contents: ",
                contents);
          } else {
            // Create the file, and write some basic contents.
            std::unique_ptr<WritableFile> lockfile;
            TF_RETURN_IF_ERROR(
                dataset()->env_->NewWritableFile(lockfile_, &lockfile));
            TF_RETURN_IF_ERROR(lockfile->Append(strings::StrCat(
                "Created at: ", dataset()->env_->NowSeconds())));

            // At this point we know that
            // 1. There is no conflicting checkpoint with prefix `filename_`.
            // 2. There is no concurrent session that is trying to write a
            //    checkpoint to `filename_`.
            // So it is safe to create a BundleWriter here. Note that it is
            // unsafe to initialize the BundleWriter anywhere the above
            // conditions are not met since BundleWriter's constructor creates
            // new temp files which can delete the temp files created by a
            // BundleWriter in another Session.
            writer_ =
                absl::make_unique<BundleWriter>(dataset()->env_, filename_);
            lockfile_created_ = true;
            return Status::OK();
          }
        }
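
        // For example (illustrative): for prefix "/tmp/cache" and shard 0,
        // the lockfile is "/tmp/cache_0.lockfile" and holds a single line of
        // the form "Created at: <unix time in seconds>". It is a best-effort
        // guard against two writers racing on the same prefix, not a true
        // filesystem lock.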

        Status Finish() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
          iteration_completed_ = true;
          // Flush the current bundle.
          TF_RETURN_IF_ERROR(writer_->Finish());
          // Merge all the bundles.
          // Currently there are `shard_id_ + 1` bundles, one for each
          // checkpoint. Each bundle has prefix <filename>_<id> where `id` is
          // an integer starting at 0 and incremented by 1 for each new
          // checkpoint. We merge all these bundles into a bundle with prefix
          // <filename> so that the next call to `MakeIterator` can build a
          // `FileReaderIterator`.
          {
            std::vector<string> prefixes;
            prefixes.reserve(shard_id_ + 1);
            for (size_t i = 0; i <= shard_id_; ++i) {
              prefixes.emplace_back(
                  strings::StrCat(dataset()->filename_, "_", i));
            }
            TF_RETURN_IF_ERROR(
                MergeBundles(dataset()->env_, prefixes, dataset()->filename_));
          }
          // Delete all lockfiles.
          for (size_t i = 0; i <= shard_id_; ++i) {
            TF_RETURN_IF_ERROR(dataset()->env_->DeleteFile(
                strings::StrCat(dataset()->filename_, "_", i, ".lockfile")));
          }
          return Status::OK();
        }

        mutex mu_;
        size_t cur_index_ GUARDED_BY(mu_);
        // Index of the current shard. This gets incremented whenever a new
        // cache shard is saved.
        size_t shard_id_ GUARDED_BY(mu_);
        std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
        // The current prefix for the cache file. This is equal to
        // `StrCat(dataset()->filename_, "_", shard_id_)`.
        string filename_;
        std::unique_ptr<BundleWriter> writer_ GUARDED_BY(mu_);
        string lockfile_ GUARDED_BY(mu_);
        bool lockfile_created_ GUARDED_BY(mu_);
        bool iteration_completed_ GUARDED_BY(mu_);
      };  // FileWriterIterator

      class FileReaderIterator : public DatasetIterator<FileDataset> {
       public:
        explicit FileReaderIterator(const Params& params)
            : DatasetIterator<FileDataset>(params),
              cur_index_(0),
              reader_(dataset()->env_, dataset()->filename_),
              iterator_restored_(false) {}

        Status GetNextInternal(IteratorContext* ctx,
                               std::vector<Tensor>* out_tensors,
                               bool* end_of_sequence) override {
          mutex_lock l(mu_);
          *end_of_sequence = false;
          TF_RETURN_IF_ERROR(reader_.status());
          if (!reader_.Valid()) {
            return errors::Internal(
                "Cache iterator is in an invalid state. (Perhaps GetNext "
                "called after end_of_sequence?)");
          }
          out_tensors->clear();
          out_tensors->resize(dataset()->num_tensors_);

          for (size_t i = 0; i < dataset()->num_tensors_; ++i) {
            // When the iterator is restored from the checkpoint, `reader_` is
            // already pointing at `key` so we do not need to skip the header
            // entry.
            if (!iterator_restored_) {
              // The first entry in the table is a header entry.
              reader_.Next();
            } else {
              iterator_restored_ = false;
            }
            if (!reader_.Valid()) {
              out_tensors->clear();
              *end_of_sequence = true;
              return Status::OK();
            }
            StringPiece key = reader_.key();
            DCHECK_EQ(key, dataset()->FormatName(cur_index_, i));
            TF_RETURN_IF_ERROR(reader_.ReadCurrent(&(*out_tensors)[i]));
            TF_RETURN_IF_ERROR(reader_.status());
          }
          cur_index_++;
          return Status::OK();
        }
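
        // Reading order, for illustration: for a dataset with two components,
        // the reader is expected to visit the bundle keys roughly in the
        // order "      0_0", "      0_1", "      1_0", "      1_1", and so
        // on, i.e. FormatName(item, component) for consecutive items, one
        // entry per component, after skipping the bundle's header entry.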

       protected:
        std::shared_ptr<model::Node> CreateNode(
            IteratorContext* ctx, model::Node::Args args) const override {
          return model::MakeKnownRatioNode(std::move(args),
                                           /*ratio=*/1);
        }

        Status SaveInternal(IteratorStateWriter* writer) override {
          mutex_lock l(mu_);
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name("cur_index"), cur_index_));
          return Status::OK();
        }

        Status RestoreInternal(
            IteratorContext* ctx,
            IteratorStateReader* iterator_state_reader) override {
          mutex_lock l(mu_);
          {
            // TODO(b/78048575): Update this when saving size_t tensors
            // directly is supported.
            int64 temp;
            TF_RETURN_IF_ERROR(iterator_state_reader->ReadScalar(
                full_name("cur_index"), &temp));
            cur_index_ = static_cast<size_t>(temp);
            if (cur_index_ != temp) {
              return errors::Internal("Invalid value for cur_index ", temp);
            }
          }
          if (!reader_.Valid()) {
            return errors::Internal("Error initializing BundleReader.");
          }
          reader_.Seek(dataset()->FormatName(cur_index_, 0));
          iterator_restored_ = true;
          return Status::OK();
        }

       private:
        mutex mu_;
        size_t cur_index_ GUARDED_BY(mu_);
        BundleReader reader_ GUARDED_BY(mu_);
        bool iterator_restored_ GUARDED_BY(mu_);
      };  // FileReaderIterator

      void InitializeIterator() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        // We intentionally use the same prefix for both `FileReaderIterator`
        // and `FileWriterIterator`. Since at any time there will be at most
        // one of them alive, there should be no conflicts. This allows both
        // iterators to use a common key for `cur_index`. We leverage this
        // in the corner case when this iterator is restored from an old
        // checkpoint in `write` mode and the cache has been completely
        // flushed to disk since then. In that case we simply build a
        // `FileReaderIterator` and seek to the `cur_index`.
        switch (mode_) {
          case Mode::read:
            iterator_ = absl::make_unique<FileReaderIterator>(
                FileReaderIterator::Params{dataset(),
                                           strings::StrCat(prefix(), "Impl")});
            break;
          case Mode::write:
            iterator_ = absl::make_unique<FileWriterIterator>(
                FileWriterIterator::Params{dataset(),
                                           strings::StrCat(prefix(), "Impl")});
        }
      }

      mutex mu_;
      enum Mode { read, write };
      Mode mode_ GUARDED_BY(mu_);
      std::unique_ptr<IteratorBase> iterator_ GUARDED_BY(mu_);
    };  // FileIterator

    const DatasetBase* const input_;
    const string filename_;
    Env* const env_;
    const size_t num_tensors_;
    const size_t tensor_index_padding_size_;
    static const size_t kMaxItems = 10000000;  // 10 million
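
    // The padding sizes below are derived in the constructor from kMaxItems
    // and from the number of output components; e.g. the largest permitted
    // item index, 9999999, is 7 characters wide, which is what the DCHECK in
    // the constructor asserts. They feed the printf-style format string used
    // by `FormatName` to build fixed-width bundle keys.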
    const size_t item_index_padding_size_;
    const string tensor_format_string_;
  };  // FileDataset

  class MemoryDataset : public DatasetBase {
   public:
    explicit MemoryDataset(OpKernelContext* ctx, const DatasetBase* input)
        : DatasetBase(DatasetContext(ctx)), input_(input) {
      input->Ref();
    }

    ~MemoryDataset() override { input_->Unref(); }

    std::unique_ptr<IteratorBase> MakeIteratorInternal(
        const string& prefix) const override {
      return absl::make_unique<MemoryIterator>(MemoryIterator::Params{
          this, strings::StrCat(prefix, "::MemoryCache")});
    }

    const DataTypeVector& output_dtypes() const override {
      return input_->output_dtypes();
    }

    const std::vector<PartialTensorShape>& output_shapes() const override {
      return input_->output_shapes();
    }

    string DebugString() const override {
      return "CacheDatasetOp::MemoryDataset";
    }

    int64 Cardinality() const override { return input_->Cardinality(); }

   protected:
    Status AsGraphDefInternal(SerializationContext* ctx,
                              DatasetGraphDefBuilder* b,
                              Node** output) const override {
      Node* input_node = nullptr;
      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
      Node* filename_node = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(string(""), &filename_node));
      TF_RETURN_IF_ERROR(
          b->AddDataset(this, {input_node, filename_node}, output));
      return Status::OK();
    }

   private:
    // A thread-safe data structure for caching dataset elements.
    //
    // The expected use is that a single `MemoryWriterIterator` populates the
    // cache with dataset elements. Once all elements are cached, the cache
    // can be used by one or more `MemoryReaderIterator`s.
    class MemoryCache : public ResourceBase {
     public:
      MemoryCache() = default;

      string DebugString() const override {
        return "CacheDataset::MemoryCache";
      }

      // Marks the cache as completed.
      void Complete() {
        mutex_lock l(mu_);
        completed_ = true;
      }

      // Returns whether the cache is claimed.
      bool IsClaimed() {
        tf_shared_lock l(mu_);
        return claimed_;
      }

      // Returns whether the cache is completed.
      bool IsCompleted() {
        tf_shared_lock l(mu_);
        return completed_;
      }

      // Attempts to claim the cache, returning whether the cache was claimed.
      bool MaybeClaim() {
        mutex_lock l(mu_);
        if (!claimed_) {
          claimed_ = true;
          return true;
        }
        return false;
      }

      // Resets the cache.
      void Reset() {
        mutex_lock l(mu_);
        claimed_ = false;
        completed_ = false;
        cache_.clear();
      }

      // Returns the element at the given index.
      const std::vector<Tensor>& at(int64 index) {
        tf_shared_lock l(mu_);
        DCHECK(index < cache_.size());
        return cache_[index];
      }

      // Adds the element to the cache.
      void emplace_back(std::vector<Tensor> element) {
        mutex_lock l(mu_);
        cache_.emplace_back(std::move(element));
      }

      // Returns the size of the cache.
      size_t size() {
        tf_shared_lock l(mu_);
        return cache_.size();
      }

     private:
      mutex mu_;
      // Determines whether a writer has claimed the cache.
      bool claimed_ GUARDED_BY(mu_) = false;
      // Determines whether all elements of the dataset have been cached.
      bool completed_ GUARDED_BY(mu_) = false;
      std::vector<std::vector<Tensor>> cache_ GUARDED_BY(mu_);
    };  // MemoryCache
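
    // Usage sketch (as implemented below): the first MemoryIterator to call
    // MaybeClaim() on the shared cache becomes the writer and wraps the input
    // with a MemoryWriterIterator; every later iterator becomes a reader.
    // Readers are only expected to be created once Complete() has been
    // called, which Initialize() below checks for.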

    class MemoryIterator : public DatasetIterator<MemoryDataset> {
     public:
      explicit MemoryIterator(const Params& params)
          : DatasetIterator<MemoryDataset>(params) {}

      ~MemoryIterator() override { cache_->Unref(); }

      Status Initialize(IteratorContext* ctx) override {
        mutex_lock l(mu_);
        // Use the resource manager in the iterator context to get / create
        // a cache.
        ResourceMgr* mgr = ctx->resource_mgr();
        const string name = strings::StrCat(
            prefix(), "::", dataset()->node_name(), "::MemoryCache");
        TF_RETURN_IF_ERROR(mgr->LookupOrCreate<MemoryCache>(
            "tf_data", name, &cache_, [](MemoryCache** cache) {
              *cache = new MemoryCache();
              return Status::OK();
            }));
        mode_ = cache_->MaybeClaim() ? Mode::write : Mode::read;
        InitializeIterator();
        if (mode_ == Mode::read && !cache_->IsCompleted()) {
          return errors::Internal(
              "Cache should only be read after it has been completed.");
        }
        return iterator_->Initialize(ctx);
      }

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
      }

     protected:
      std::shared_ptr<model::Node> CreateNode(
          IteratorContext* ctx, model::Node::Args args) const override {
        return model::MakeKnownRatioNode(std::move(args),
                                         /*ratio=*/1);
      }

      Status SaveInternal(IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("mode"), mode_));
        if (cache_->IsClaimed()) {
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name("cache_claimed"), ""));
          size_t cache_size = cache_->size();
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name("cache_size"), cache_size));
          for (size_t i = 0; i < cache_size; i++) {
            auto& element = cache_->at(i);
            TF_RETURN_IF_ERROR(writer->WriteScalar(
                full_name(strings::StrCat("cache[", i, "].size")),
                element.size()));
            for (size_t j = 0; j < element.size(); ++j) {
              TF_RETURN_IF_ERROR(writer->WriteTensor(
                  full_name(strings::StrCat("cache[", i, "][", j, "]")),
                  element[j]));
            }
          }
          if (cache_->IsCompleted()) {
            TF_RETURN_IF_ERROR(
                writer->WriteScalar(full_name("cache_completed"), ""));
          }
        }
        return SaveInput(writer, iterator_);
      }
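
      // Checkpoint layout, for illustration: a claimed cache holding two
      // elements of one component each is saved roughly under the keys
      // "cache_claimed", "cache_size" = 2, "cache[0].size" = 1, "cache[0][0]",
      // "cache[1].size" = 1, "cache[1][0]", plus "cache_completed" once the
      // writer has seen end_of_sequence (all under this iterator's full_name
      // prefix).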

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        iterator_.reset();
        cache_->Reset();
        {
          int64 temp;
          TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("mode"), &temp));
          mode_ = static_cast<Mode>(temp);
        }
        if (reader->Contains(full_name("cache_claimed"))) {
          CHECK(cache_->MaybeClaim());
          size_t cache_size;
          {
            int64 temp;
            TF_RETURN_IF_ERROR(
                reader->ReadScalar(full_name("cache_size"), &temp));
            cache_size = static_cast<size_t>(temp);
          }
          for (size_t i = 0; i < cache_size; ++i) {
            std::vector<Tensor> element;
            size_t element_size;
            {
              int64 temp;
              TF_RETURN_IF_ERROR(reader->ReadScalar(
                  full_name(strings::StrCat("cache[", i, "].size")), &temp));
              element_size = static_cast<size_t>(temp);
            }
            element.reserve(element_size);
            for (size_t j = 0; j < element_size; ++j) {
              element.emplace_back();
              TF_RETURN_IF_ERROR(reader->ReadTensor(
                  full_name(strings::StrCat("cache[", i, "][", j, "]")),
                  &element.back()));
            }
            cache_->emplace_back(std::move(element));
          }
          if (reader->Contains(full_name("cache_completed"))) {
            cache_->Complete();
          }
        }
        InitializeIterator();
        TF_RETURN_IF_ERROR(iterator_->Initialize(ctx));
        return RestoreInput(ctx, reader, iterator_);
      }

     private:
      class MemoryWriterIterator : public DatasetIterator<MemoryDataset> {
       public:
        explicit MemoryWriterIterator(const Params& params, MemoryCache* cache)
            : DatasetIterator<MemoryDataset>(params), cache_(cache) {
          CHECK(cache_);
        }

        ~MemoryWriterIterator() override {
          mutex_lock l(mu_);
          if (cache_->size() > 0 && !cache_->IsCompleted()) {
            LOG(WARNING)
                << "The calling iterator did not fully read the dataset being "
                   "cached. In order to avoid unexpected truncation of the "
                   "dataset, the partially cached contents of the dataset "
                   "will be discarded. This can happen if you have an input "
                   "pipeline similar to `dataset.cache().take(k).repeat()`. "
                   "You should use `dataset.take(k).cache().repeat()` instead.";
            cache_->Reset();
          }
        }

        Status Initialize(IteratorContext* ctx) override {
          return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
        }

        Status GetNextInternal(IteratorContext* ctx,
                               std::vector<Tensor>* out_tensors,
                               bool* end_of_sequence) override {
          mutex_lock l(mu_);
          TF_RETURN_IF_ERROR(
              input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
          if (*end_of_sequence) {
            cache_->Complete();
            return Status::OK();
          }
          RecordBufferEnqueue(ctx, *out_tensors);
          cache_->emplace_back(*out_tensors);
          return Status::OK();
        }

       protected:
        std::shared_ptr<model::Node> CreateNode(
            IteratorContext* ctx, model::Node::Args args) const override {
          return model::MakeKnownRatioNode(std::move(args),
                                           /*ratio=*/1);
        }

        Status SaveInternal(IteratorStateWriter* writer) override {
          mutex_lock l(mu_);
          return SaveInput(writer, input_impl_);
        }

        Status RestoreInternal(IteratorContext* ctx,
                               IteratorStateReader* reader) override {
          mutex_lock l(mu_);
          return RestoreInput(ctx, reader, input_impl_);
        }

       private:
        mutex mu_;
        std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
        MemoryCache* const cache_ GUARDED_BY(mu_);  // not owned.
      };  // MemoryWriterIterator

      class MemoryReaderIterator : public DatasetIterator<MemoryDataset> {
       public:
        explicit MemoryReaderIterator(const Params& params, MemoryCache* cache)
            : DatasetIterator<MemoryDataset>(params),
              cache_(cache),
              index_(0) {
          CHECK(cache);
        }

        Status Initialize(IteratorContext* ctx) override {
          // The memory allocated for the cache is owned by the parent
          // dataset but performance modeling uses the iterator abstraction
          // and thus we record the memory allocated for the cache here. The
          // caveat is that this is incorrect if there are concurrent
          // instances of this iterator.
          tf_shared_lock l(mu_);
          for (size_t i = 0; i < cache_->size(); ++i) {
            RecordBufferEnqueue(ctx, cache_->at(i));
          }
          return Status::OK();
        }

        Status GetNextInternal(IteratorContext* ctx,
                               std::vector<Tensor>* out_tensors,
                               bool* end_of_sequence) override {
          mutex_lock l(mu_);
          if (index_ < cache_->size()) {
            const std::vector<Tensor>& cache_tensors = cache_->at(index_);
            out_tensors->insert(out_tensors->begin(), cache_tensors.begin(),
                                cache_tensors.end());
            index_++;
            *end_of_sequence = false;
            return Status::OK();
          } else {
            *end_of_sequence = true;
            return Status::OK();
          }
        }

       protected:
        std::shared_ptr<model::Node> CreateNode(
            IteratorContext* ctx, model::Node::Args args) const override {
          return model::MakeKnownRatioNode(std::move(args),
                                           /*ratio=*/1);
        }

        Status SaveInternal(IteratorStateWriter* writer) override {
          mutex_lock l(mu_);
          TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("index"), index_));
          return Status::OK();
        }

        Status RestoreInternal(IteratorContext* ctx,
                               IteratorStateReader* reader) override {
          mutex_lock l(mu_);
          {
            int64 temp;
            TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("index"), &temp));
            index_ = static_cast<size_t>(temp);
          }
          return Status::OK();
        }

       private:
        mutex mu_;
        MemoryCache* const cache_ GUARDED_BY(mu_);  // not owned.
        size_t index_ GUARDED_BY(mu_);
      };  // MemoryReaderIterator

      void InitializeIterator() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        switch (mode_) {
          case Mode::read:
            iterator_ = absl::make_unique<MemoryReaderIterator>(
                MemoryReaderIterator::Params{
                    dataset(), strings::StrCat(prefix(), "Impl")},
                cache_);
            break;
          case Mode::write:
            iterator_ = absl::make_unique<MemoryWriterIterator>(
                MemoryWriterIterator::Params{
                    dataset(), strings::StrCat(prefix(), "Impl")},
                cache_);
        }
      }

      mutex mu_;
      MemoryCache* cache_ GUARDED_BY(mu_);  // not owned.
      enum Mode { read, write };
      Mode mode_ GUARDED_BY(mu_);
      std::unique_ptr<IteratorBase> iterator_ GUARDED_BY(mu_);
    };  // MemoryIterator

    const DatasetBase* const input_;
  };  // MemoryDataset
};  // CacheDatasetOp

REGISTER_KERNEL_BUILDER(Name("CacheDataset").Device(DEVICE_CPU),
                        CacheDatasetOp);

}  // namespace
}  // namespace data
}  // namespace tensorflow