/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"

namespace tensorflow {
namespace data {
namespace {

// See documentation in ../../ops/dataset_ops.cc for a high-level description of
// the following op.

class CacheDatasetOp : public UnaryDatasetOpKernel {
 public:
  explicit CacheDatasetOp(OpKernelConstruction* ctx)
      : UnaryDatasetOpKernel(ctx) {}

  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                   DatasetBase** output) override {
    // Parse out the filenames tensor.
    string filename;
    OP_REQUIRES_OK(ctx,
                   ParseScalarArgument<string>(ctx, "filename", &filename));

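    // An empty `filename` selects the in-memory cache (MemoryDataset below);
    // a non-empty `filename` selects the file-backed cache (FileDataset
    // below), which persists elements to disk via the tensor_bundle
    // reader/writer. At the Python level this roughly corresponds to
    // `dataset.cache()` versus `dataset.cache(filename)`.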
    if (filename.empty()) {
      *output = new MemoryDataset(ctx, input);
    } else {
      *output = new FileDataset(ctx, input, filename, ctx->env());
    }
  }

 private:
  class FileDataset : public DatasetBase {
   public:
    explicit FileDataset(OpKernelContext* ctx, const DatasetBase* input,
                         string filename, Env* env)
        : DatasetBase(DatasetContext(ctx)),
          input_(input),
          filename_(std::move(filename)),
          env_(env),
          num_tensors_(input->output_dtypes().size()),
          tensor_index_padding_size_(StringPaddingSize(num_tensors_)),
          item_index_padding_size_(StringPaddingSize(kMaxItems)),
          tensor_format_string_(strings::Printf("%%%zuzu_%%%zuzu",
                                                item_index_padding_size_,
                                                tensor_index_padding_size_)) {
      input_->Ref();
      DCHECK_EQ(item_index_padding_size_, 7);
    }

    ~FileDataset() override { input_->Unref(); }

    std::unique_ptr<IteratorBase> MakeIteratorInternal(
        const string& prefix) const override {
      return absl::make_unique<FileIterator>(
          FileIterator::Params{this, strings::StrCat(prefix, "::FileCache")});
    }

    const DataTypeVector& output_dtypes() const override {
      return input_->output_dtypes();
    }

    const std::vector<PartialTensorShape>& output_shapes() const override {
      return input_->output_shapes();
    }

    string DebugString() const override {
      return "CacheDatasetOp::FileDataset";
    }

    int64 Cardinality() const override { return input_->Cardinality(); }

   protected:
    Status AsGraphDefInternal(SerializationContext* ctx,
                              DatasetGraphDefBuilder* b,
                              Node** output) const override {
      Node* input_graph = nullptr;
      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph));
      Node* filename = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(filename_, &filename));
      TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph, filename}, output));
      return Status::OK();
    }

   private:
    static size_t StringPaddingSize(size_t num_tensors) {
      return strings::Printf("%zu", num_tensors - 1).size();
    }

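    // FormatName() below builds the bundle key for a given (element index,
    // tensor component index) pair using tensor_format_string_.
    // StringPaddingSize(n) above is the number of characters needed to print
    // the largest valid index (n - 1); with kMaxItems == 10000000 the item
    // padding is 7 (hence the DCHECK in the constructor), so for a dataset
    // with a single output component tensor_format_string_ is "%7zu_%1zu" and
    // FormatName(12, 0) yields the space-padded key "     12_0". These keys
    // are what the BundleWriter/BundleReader use to store and look up each
    // cached tensor.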
    string FormatName(size_t item_index, size_t tensor_index) const {
      return strings::Printf(tensor_format_string_.c_str(), item_index,
                             tensor_index);
    }

    class FileIterator : public DatasetIterator<FileDataset> {
     public:
      explicit FileIterator(const Params& params)
          : DatasetIterator<FileDataset>(params) {
        if (params.dataset->env_
                ->FileExists(MetaFilename(params.dataset->filename_))
                .ok()) {
          mode_ = Mode::read;
        } else {
          mode_ = Mode::write;
        }
        InitializeIterator();
      }

      Status Initialize(IteratorContext* ctx) override {
        mutex_lock l(mu_);
        return iterator_->Initialize(ctx);
      }

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
      }

     protected:
      std::shared_ptr<model::Node> CreateNode(
          IteratorContext* ctx, model::Node::Args args) const override {
        return model::MakeKnownRatioNode(std::move(args),
                                         /*ratio=*/1);
      }

      Status SaveInternal(IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("mode"), mode_));
        return SaveInput(writer, iterator_);
      }
      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        {
          int64 temp;
          TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("mode"), &temp));
          mode_ = static_cast<Mode>(temp);
        }
        if (mode_ == Mode::write &&
            dataset()
                ->env_->FileExists(MetaFilename(dataset()->filename_))
                .ok()) {
          // This could happen if the cache was completely written after the
          // checkpoint was saved.
          LOG(WARNING)
              << "It looks like the cache was already completely written("
              << MetaFilename(dataset()->filename_)
              << ") after the last checkpoint was saved. "
              << "Attempting to read the cache instead of continuing to "
              << "write. If this is a mistake, please remove the above file "
              << "and try running again.";
          mode_ = Mode::read;
        }
        InitializeIterator();
        TF_RETURN_IF_ERROR(iterator_->Initialize(ctx));
        return RestoreInput(ctx, reader, iterator_);
      }

     private:
      // FileWriterIterator passes through and caches items from the input
      // FileDataset.
      //
      // This iterator is used when the cache files are not found on disk. It
      // creates the cache files, and passes on the underlying iterator's
      // elements.
      //
      // Caching is performed by writing the input tensors to disk using the
      // `BundleWriter`. Note that the cache gets fully flushed to disk only
      // after the input iterator has been fully exhausted. If the program
      // exits before completion of an epoch, the cached state would be lost.
      // To ensure that the partial cache persists across sessions, one should
      // checkpoint the input pipeline. On each call to `SaveInternal` the
      // partial cache gets flushed to disk in files with prefix
      // <filename>_<shard_id>, where shard_id is unique for each checkpoint.
      // When all elements have been produced, these shards get coalesced.
      class FileWriterIterator : public DatasetIterator<FileDataset> {
       public:
        explicit FileWriterIterator(const Params& params)
            : DatasetIterator<FileDataset>(params),
              cur_index_(0),
              shard_id_(0),
              filename_(
                  strings::StrCat(params.dataset->filename_, "_", shard_id_)),
              lockfile_(strings::StrCat(filename_, ".lockfile")),
              lockfile_created_(false),
              iteration_completed_(false) {}

        Status Initialize(IteratorContext* ctx) override {
          return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
        }

        Status GetNextInternal(IteratorContext* ctx,
                               std::vector<Tensor>* out_tensors,
                               bool* end_of_sequence) override {
          mutex_lock l(mu_);
          TF_RETURN_IF_ERROR(EnsureLockFileExists());
          TF_RETURN_IF_ERROR(writer_->status());
          if (cur_index_ >= kMaxItems) {
            // As a courtesy, close the [truncated] cache file.
            Status s = Finish();
            if (!s.ok()) {
              LOG(ERROR) << s;
            }
            return errors::InvalidArgument(
                "Upstream iterator is producing more than ", kMaxItems,
                " items, which is more than the cache limit.");
          }

          TF_RETURN_IF_ERROR(
              input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
          if (*end_of_sequence && out_tensors->empty()) {
            TF_RETURN_IF_ERROR(Finish());
            cur_index_++;
            return Status::OK();
          }
          if (out_tensors->size() != dataset()->num_tensors_) {
            return errors::Internal(
                "Upstream iterator returned invalid number of tensors. "
                "Expected ",
                dataset()->num_tensors_, " got: ", out_tensors->size());
          }
          size_t tensor_index = 0;
          for (const Tensor& t : *out_tensors) {
            DCHECK_LT(tensor_index, dataset()->num_tensors_);
            string key = dataset()->FormatName(cur_index_, tensor_index++);
            TF_RETURN_IF_ERROR(writer_->Add(key, t));
          }
          if (*end_of_sequence) {
            TF_RETURN_IF_ERROR(Finish());
          }
          cur_index_++;
          return Status::OK();
        }

       protected:
        std::shared_ptr<model::Node> CreateNode(
            IteratorContext* ctx, model::Node::Args args) const override {
          return model::MakeKnownRatioNode(std::move(args),
                                           /*ratio=*/1);
        }

        Status SaveInternal(IteratorStateWriter* writer) override {
          mutex_lock l(mu_);
          if (iteration_completed_) {
            TF_RETURN_IF_ERROR(
                writer->WriteScalar(full_name("iteration_completed"), ""));
            return Status::OK();
          }

          // lockfile is created on the first call to GetNextInternal. The
          // absence of a lockfile means that GetNextInternal was not called
          // and hence nothing was written to cache. So we don't need to worry
          // about flushing the current shard. This ensures that we never write
          // empty shards.
          if (lockfile_created_) {
            // Flush the current bundle.
            TF_RETURN_IF_ERROR(writer_->Finish());

            // Note: We do not delete the lockfile here. We keep lockfiles of
            // all shards around until the entire cache has been written to
            // prevent concurrent iterators from corrupting any of the shards.

            // Start caching to a new shard.
            shard_id_++;
            filename_ = strings::StrCat(dataset()->filename_, "_", shard_id_);
            lockfile_ = strings::StrCat(filename_, ".lockfile");
            lockfile_created_ = false;
          }
          TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name("cur_index"), cur_index_));
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name("shard_id"), shard_id_));
          return Status::OK();
        }
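
        // Shard rotation example (illustrative, with a hypothetical cache
        // filename of "/tmp/cache"): the first elements are written to the
        // bundle "/tmp/cache_0". The first checkpoint flushes that bundle and
        // bumps shard_id_, so subsequent elements go to "/tmp/cache_1", the
        // next checkpoint rotates to "/tmp/cache_2", and so on; RestoreInternal
        // below recreates the BundleWriter for the restored shard_id_.
        // Finish() later merges all of these shards into the single bundle
        // "/tmp/cache".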

        Status RestoreInternal(IteratorContext* ctx,
                               IteratorStateReader* reader) override {
          mutex_lock l(mu_);
          if (reader->Contains(full_name("iteration_completed"))) {
            iteration_completed_ = true;
            return Status::OK();
          }

          TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
          int64 temp;
          // TODO(b/78048575): Update this when saving size_t tensors directly
          // is supported.
          {
            TF_RETURN_IF_ERROR(
                reader->ReadScalar(full_name("cur_index"), &temp));
            cur_index_ = static_cast<size_t>(temp);
            if (cur_index_ != temp) {
              return errors::Internal("Invalid value for cur_index ", temp);
            }
          }
          // TODO(b/78048575): Update this when saving size_t tensors directly
          // is supported.
          {
            TF_RETURN_IF_ERROR(
                reader->ReadScalar(full_name("shard_id"), &temp));
            shard_id_ = static_cast<size_t>(temp);
            if (shard_id_ != temp) {
              return errors::Internal("Invalid value for shard_id ", temp);
            }
          }
          filename_ = strings::StrCat(dataset()->filename_, "_", shard_id_);
          lockfile_ = strings::StrCat(filename_, ".lockfile");
          writer_ = absl::make_unique<BundleWriter>(dataset()->env_, filename_);
          return Status::OK();
        }

       private:
        Status EnsureLockFileExists() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
          if (iteration_completed_)
            return errors::OutOfRange(
                "Attempting to call get_next after iteration should have "
                "finished.");
          if (lockfile_created_ && !iteration_completed_) return Status::OK();

          // Perform rudimentary locking to help catch concurrent writes to the
          // same cache files.

          // 1. Check that a checkpoint for the shard has not already been
          // written.
          if (dataset()->env_->FileExists(MetaFilename(filename_)).ok()) {
            return errors::AlreadyExists("Existing cache files found: \n",
                                         MetaFilename(filename_), "\n",
                                         DataFilename(filename_, 0, 1), "\n",
                                         "To continue delete the above files.");
          }

          // 2. Check that there isn't a concurrent iterator that is writing
          // to cache.
          if (dataset()->env_->FileExists(lockfile_).ok()) {
            // Attempt to read the contents of the lockfile.
            char contents_scratch[151] = {0};  // Initialize all to 0.
            StringPiece contents;
            std::unique_ptr<RandomAccessFile> file;
            if (dataset()->env_->NewRandomAccessFile(lockfile_, &file).ok()) {
              file->Read(0, 150, &contents, contents_scratch).IgnoreError();
            }
            return errors::AlreadyExists(
                "There appears to be a concurrent caching iterator running - "
                "cache lockfile already exists ('",
                lockfile_,
                "'). If you are sure no other running TF computations are "
                "using "
                "this cache prefix, delete the lockfile and re-initialize the "
                "iterator. Lockfile contents: ",
                contents);
          } else {
            // Create the file, and write some basic contents.
            std::unique_ptr<WritableFile> lockfile;
            TF_RETURN_IF_ERROR(
                dataset()->env_->NewWritableFile(lockfile_, &lockfile));
            TF_RETURN_IF_ERROR(lockfile->Append(strings::StrCat(
                "Created at: ", dataset()->env_->NowSeconds())));

            // At this point we know that
            // 1. There is no conflicting checkpoint with prefix `filename_`.
            // 2. There is no concurrent session that is trying to write a
            //    checkpoint to `filename_`.
            // So it is safe to create a BundleWriter here. Note that it is
            // unsafe to initialize the BundleWriter anywhere the above
            // conditions are not met since BundleWriter's constructor creates
            // new temp files which can delete the temp files created by a
            // BundleWriter in another session.
            writer_ =
                absl::make_unique<BundleWriter>(dataset()->env_, filename_);
            lockfile_created_ = true;
            return Status::OK();
          }
        }
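
        // Lockfile naming follows the shard prefix: for example, while writing
        // shard 0 of a cache named "/tmp/cache" (hypothetical path), the file
        // "/tmp/cache_0.lockfile" exists and contains a "Created at: <seconds>"
        // line; the lockfiles are only removed by Finish() once every shard
        // has been written and merged.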

        Status Finish() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
          iteration_completed_ = true;
          // Flush the current bundle.
          TF_RETURN_IF_ERROR(writer_->Finish());
          // Merge all the bundles.
          // Currently there are `shard_id_ + 1` bundles, one for each
          // checkpoint. Each bundle has prefix <filename>_<id> where `id` is an
          // integer starting at 0 and incremented by 1 for each new checkpoint.
          // We merge all these bundles into a bundle with prefix <filename> so
          // that the next call to `MakeIterator` can build a
          // `FileReaderIterator`.
          {
            std::vector<string> prefixes;
            prefixes.reserve(shard_id_ + 1);
            for (size_t i = 0; i <= shard_id_; ++i) {
              prefixes.emplace_back(
                  strings::StrCat(dataset()->filename_, "_", i));
            }
            TF_RETURN_IF_ERROR(
                MergeBundles(dataset()->env_, prefixes, dataset()->filename_));
          }
          // Delete all lockfiles.
          for (size_t i = 0; i <= shard_id_; ++i) {
            TF_RETURN_IF_ERROR(dataset()->env_->DeleteFile(
                strings::StrCat(dataset()->filename_, "_", i, ".lockfile")));
          }
          return Status::OK();
        }
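
        // After a successful Finish(), MetaFilename(dataset()->filename_)
        // exists on disk; that is exactly the condition FileIterator's
        // constructor (and its RestoreInternal) checks to decide between read
        // and write mode, so the next iterator instantiated for this dataset
        // reads from the merged cache instead of rebuilding it.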

        mutex mu_;
        size_t cur_index_ GUARDED_BY(mu_);
        // Index of the current shard. This gets incremented whenever a new
        // cache shard is saved.
        size_t shard_id_ GUARDED_BY(mu_);
        std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
        // The current prefix for the cache file. This is equal to
        // `StrCat(dataset()->filename_, "_", shard_id_)`.
        string filename_;
        std::unique_ptr<BundleWriter> writer_ GUARDED_BY(mu_);
        string lockfile_ GUARDED_BY(mu_);
        bool lockfile_created_ GUARDED_BY(mu_);
        bool iteration_completed_ GUARDED_BY(mu_);
      };  // FileWriterIterator

      class FileReaderIterator : public DatasetIterator<FileDataset> {
       public:
        explicit FileReaderIterator(const Params& params)
            : DatasetIterator<FileDataset>(params),
              cur_index_(0),
              reader_(dataset()->env_, dataset()->filename_),
              iterator_restored_(false) {}

        Status GetNextInternal(IteratorContext* ctx,
                               std::vector<Tensor>* out_tensors,
                               bool* end_of_sequence) override {
          mutex_lock l(mu_);
          *end_of_sequence = false;
          TF_RETURN_IF_ERROR(reader_.status());
          if (!reader_.Valid()) {
            return errors::Internal(
                "Cache iterator is in an invalid state. (Perhaps GetNext "
                "called "
                "after end_of_sequence?)");
          }
          out_tensors->clear();
          out_tensors->resize(dataset()->num_tensors_);

          for (size_t i = 0; i < dataset()->num_tensors_; ++i) {
            // When the iterator is restored from the checkpoint, `reader_` is
            // already pointing at `key` so we do not need to skip the header
            // entry.
            if (!iterator_restored_) {
              // The first entry in the table is a header entry.
              reader_.Next();
            } else {
              iterator_restored_ = false;
            }
            if (!reader_.Valid()) {
              out_tensors->clear();
              *end_of_sequence = true;
              return Status::OK();
            }
            StringPiece key = reader_.key();
            DCHECK_EQ(key, dataset()->FormatName(cur_index_, i));
            TF_RETURN_IF_ERROR(reader_.ReadCurrent(&(*out_tensors)[i]));
            TF_RETURN_IF_ERROR(reader_.status());
          }
          cur_index_++;
          return Status::OK();
        }

       protected:
        std::shared_ptr<model::Node> CreateNode(
            IteratorContext* ctx, model::Node::Args args) const override {
          return model::MakeKnownRatioNode(std::move(args),
                                           /*ratio=*/1);
        }

        Status SaveInternal(IteratorStateWriter* writer) override {
          mutex_lock l(mu_);
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name("cur_index"), cur_index_));
          return Status::OK();
        }

        Status RestoreInternal(
            IteratorContext* ctx,
            IteratorStateReader* iterator_state_reader) override {
          mutex_lock l(mu_);
          {
            // TODO(b/78048575): Update this when saving size_t tensors directly
            // is supported.
            int64 temp;
            TF_RETURN_IF_ERROR(iterator_state_reader->ReadScalar(
                full_name("cur_index"), &temp));
            cur_index_ = static_cast<size_t>(temp);
            if (cur_index_ != temp) {
              return errors::Internal("Invalid value for cur_index ", temp);
            }
          }
          if (!reader_.Valid()) {
            return errors::Internal("Error initializing BundleReader.");
          }
          reader_.Seek(dataset()->FormatName(cur_index_, 0));
          iterator_restored_ = true;
          return Status::OK();
        }

       private:
        mutex mu_;
        size_t cur_index_ GUARDED_BY(mu_);
        BundleReader reader_ GUARDED_BY(mu_);
        bool iterator_restored_ GUARDED_BY(mu_);
      };  // FileReaderIterator

      void InitializeIterator() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        // We intentionally use the same prefix for both `FileReaderIterator`
        // and `FileWriterIterator`. Since at any time there will be at most
        // one of them alive, there should be no conflicts. This allows both
        // iterators to use a common key for `cur_index`. We leverage this
        // in the corner case when this iterator is restored from an old
        // checkpoint in `write` mode and the cache has been completely
        // flushed to disk since then. In that case we simply build a
        // `FileReaderIterator` and seek to the `cur_index`.
        switch (mode_) {
          case Mode::read:
            iterator_ = absl::make_unique<FileReaderIterator>(
                FileReaderIterator::Params{dataset(),
                                           strings::StrCat(prefix(), "Impl")});
            break;
          case Mode::write:
            iterator_ = absl::make_unique<FileWriterIterator>(
                FileWriterIterator::Params{dataset(),
                                           strings::StrCat(prefix(), "Impl")});
        }
      }

      mutex mu_;
      enum Mode { read, write };
      Mode mode_ GUARDED_BY(mu_);
      std::unique_ptr<IteratorBase> iterator_ GUARDED_BY(mu_);
    };  // FileIterator

    const DatasetBase* const input_;
    const string filename_;
    Env* const env_;
    const size_t num_tensors_;
    const size_t tensor_index_padding_size_;
    static const size_t kMaxItems = 10000000;  // 10 million
    const size_t item_index_padding_size_;
    const string tensor_format_string_;
  };  // FileDataset

  class MemoryDataset : public DatasetBase {
   public:
    explicit MemoryDataset(OpKernelContext* ctx, const DatasetBase* input)
        : DatasetBase(DatasetContext(ctx)), input_(input) {
      input->Ref();
    }

    ~MemoryDataset() override { input_->Unref(); }

    std::unique_ptr<IteratorBase> MakeIteratorInternal(
        const string& prefix) const override {
      return absl::make_unique<MemoryIterator>(MemoryIterator::Params{
          this, strings::StrCat(prefix, "::MemoryCache")});
    }

    const DataTypeVector& output_dtypes() const override {
      return input_->output_dtypes();
    }

    const std::vector<PartialTensorShape>& output_shapes() const override {
      return input_->output_shapes();
    }

    string DebugString() const override {
      return "CacheDatasetOp::MemoryDataset";
    }

    int64 Cardinality() const override { return input_->Cardinality(); }

   protected:
    Status AsGraphDefInternal(SerializationContext* ctx,
                              DatasetGraphDefBuilder* b,
                              Node** output) const override {
      Node* input_node = nullptr;
      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
      Node* filename_node = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(string(""), &filename_node));
      TF_RETURN_IF_ERROR(
          b->AddDataset(this, {input_node, filename_node}, output));
      return Status::OK();
    }

   private:
    // A thread-safe data structure for caching dataset elements.
    //
    // The expected use is that a single `MemoryWriterIterator` populates the
    // cache with dataset elements. Once all elements are cached, the cache can
    // be used by one or more `MemoryReaderIterator`s.
    class MemoryCache : public ResourceBase {
     public:
      MemoryCache() = default;

      string DebugString() const override {
        return "CacheDataset::MemoryCache";
      }

      // Marks the cache as completed.
      void Complete() {
        mutex_lock l(mu_);
        completed_ = true;
      }

      // Returns whether the cache is claimed.
      bool IsClaimed() {
        tf_shared_lock l(mu_);
        return claimed_;
      }

      // Returns whether the cache is completed.
      bool IsCompleted() {
        tf_shared_lock l(mu_);
        return completed_;
      }

      // Attempts to claim the cache, returning whether the cache was claimed.
      bool MaybeClaim() {
        mutex_lock l(mu_);
        if (!claimed_) {
          claimed_ = true;
          return true;
        }
        return false;
      }

      // Resets the cache.
      void Reset() {
        mutex_lock l(mu_);
        claimed_ = false;
        completed_ = false;
        cache_.clear();
      }

      // Returns the element at the given index.
      const std::vector<Tensor>& at(int64 index) {
        tf_shared_lock l(mu_);
        DCHECK(index < cache_.size());
        return cache_[index];
      }

      // Adds the element to the cache.
      void emplace_back(std::vector<Tensor> element) {
        mutex_lock l(mu_);
        cache_.emplace_back(std::move(element));
      }

      // Returns the size of the cache.
      size_t size() {
        tf_shared_lock l(mu_);
        return cache_.size();
      }

     private:
      mutex mu_;
      // Determines whether a writer has claimed the cache.
      bool claimed_ GUARDED_BY(mu_) = false;
      // Determines whether all elements of the dataset have been cached.
      bool completed_ GUARDED_BY(mu_) = false;
      std::vector<std::vector<Tensor>> cache_ GUARDED_BY(mu_);
    };
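
    // Typical MemoryCache lifecycle, as used by the iterators below: the
    // first MemoryIterator to call MaybeClaim() becomes the writer and
    // appends elements with emplace_back() as they are produced, calling
    // Complete() once the input is exhausted. Later iterators fail to claim
    // the cache, verify IsCompleted(), and then serve elements directly via
    // size() and at().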

    class MemoryIterator : public DatasetIterator<MemoryDataset> {
     public:
      explicit MemoryIterator(const Params& params)
          : DatasetIterator<MemoryDataset>(params) {}

      ~MemoryIterator() override { cache_->Unref(); }

      Status Initialize(IteratorContext* ctx) override {
        mutex_lock l(mu_);
        // Use the resource manager in the iterator context to get / create
        // a cache.
        ResourceMgr* mgr = ctx->resource_mgr();
        const string name = strings::StrCat(
            prefix(), "::", dataset()->node_name(), "::MemoryCache");
        TF_RETURN_IF_ERROR(mgr->LookupOrCreate<MemoryCache>(
            "tf_data", name, &cache_, [](MemoryCache** cache) {
              *cache = new MemoryCache();
              return Status::OK();
            }));
        mode_ = cache_->MaybeClaim() ? Mode::write : Mode::read;
        InitializeIterator();
        if (mode_ == Mode::read && !cache_->IsCompleted()) {
          return errors::Internal(
              "Cache should only be read after it has been completed.");
        }
        return iterator_->Initialize(ctx);
      }
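
      // Because the cache lives in the resource manager under a name derived
      // from the iterator prefix and the dataset node name, MemoryIterators
      // created for the same dataset instance under the same prefix (for
      // example across `repeat()` epochs) look up the same MemoryCache: the
      // first one claims it and writes, subsequent ones read the completed
      // cache.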

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
      }

     protected:
      std::shared_ptr<model::Node> CreateNode(
          IteratorContext* ctx, model::Node::Args args) const override {
        return model::MakeKnownRatioNode(std::move(args),
                                         /*ratio=*/1);
      }

      Status SaveInternal(IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("mode"), mode_));
        if (cache_->IsClaimed()) {
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name("cache_claimed"), ""));
          size_t cache_size = cache_->size();
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name("cache_size"), cache_size));
          for (size_t i = 0; i < cache_size; i++) {
            auto& element = cache_->at(i);
            TF_RETURN_IF_ERROR(writer->WriteScalar(
                full_name(strings::StrCat("cache[", i, "].size")),
                element.size()));
            for (size_t j = 0; j < element.size(); ++j) {
              TF_RETURN_IF_ERROR(writer->WriteTensor(
                  full_name(strings::StrCat("cache[", i, "][", j, "]")),
                  element[j]));
            }
          }
          if (cache_->IsCompleted()) {
            TF_RETURN_IF_ERROR(
                writer->WriteScalar(full_name("cache_completed"), ""));
          }
        }
        return SaveInput(writer, iterator_);
      }

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        iterator_.reset();
        cache_->Reset();
        {
          int64 temp;
          TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("mode"), &temp));
          mode_ = static_cast<Mode>(temp);
        }
        if (reader->Contains(full_name("cache_claimed"))) {
          CHECK(cache_->MaybeClaim());
          size_t cache_size;
          {
            int64 temp;
            TF_RETURN_IF_ERROR(
                reader->ReadScalar(full_name("cache_size"), &temp));
            cache_size = static_cast<size_t>(temp);
          }
          for (size_t i = 0; i < cache_size; ++i) {
            std::vector<Tensor> element;
            size_t element_size;
            {
              int64 temp;
              TF_RETURN_IF_ERROR(reader->ReadScalar(
                  full_name(strings::StrCat("cache[", i, "].size")), &temp));
              element_size = static_cast<size_t>(temp);
            }
            element.reserve(element_size);
            for (size_t j = 0; j < element_size; ++j) {
              element.emplace_back();
              TF_RETURN_IF_ERROR(reader->ReadTensor(
                  full_name(strings::StrCat("cache[", i, "][", j, "]")),
                  &element.back()));
            }
            cache_->emplace_back(std::move(element));
          }
          if (reader->Contains(full_name("cache_completed"))) {
            cache_->Complete();
          }
        }
        InitializeIterator();
        TF_RETURN_IF_ERROR(iterator_->Initialize(ctx));
        return RestoreInput(ctx, reader, iterator_);
      }

     private:
      class MemoryWriterIterator : public DatasetIterator<MemoryDataset> {
       public:
        explicit MemoryWriterIterator(const Params& params, MemoryCache* cache)
            : DatasetIterator<MemoryDataset>(params), cache_(cache) {
          CHECK(cache_);
        }

        ~MemoryWriterIterator() override {
          mutex_lock l(mu_);
          if (cache_->size() > 0 && !cache_->IsCompleted()) {
            LOG(WARNING)
                << "The calling iterator did not fully read the dataset being "
                   "cached. In order to avoid unexpected truncation of the "
                   "dataset, the partially cached contents of the dataset "
                   "will be discarded. This can happen if you have an input "
                   "pipeline similar to `dataset.cache().take(k).repeat()`. "
                   "You should use `dataset.take(k).cache().repeat()` instead.";
            cache_->Reset();
          }
        }

        Status Initialize(IteratorContext* ctx) override {
          return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
        }

        Status GetNextInternal(IteratorContext* ctx,
                               std::vector<Tensor>* out_tensors,
                               bool* end_of_sequence) override {
          mutex_lock l(mu_);
          TF_RETURN_IF_ERROR(
              input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
          if (*end_of_sequence) {
            cache_->Complete();
            return Status::OK();
          }
          RecordBufferEnqueue(ctx, *out_tensors);
          cache_->emplace_back(*out_tensors);
          return Status::OK();
        }

       protected:
        std::shared_ptr<model::Node> CreateNode(
            IteratorContext* ctx, model::Node::Args args) const override {
          return model::MakeKnownRatioNode(std::move(args),
                                           /*ratio=*/1);
        }

        Status SaveInternal(IteratorStateWriter* writer) override {
          mutex_lock l(mu_);
          return SaveInput(writer, input_impl_);
        }

        Status RestoreInternal(IteratorContext* ctx,
                               IteratorStateReader* reader) override {
          mutex_lock l(mu_);
          return RestoreInput(ctx, reader, input_impl_);
        }

       private:
        mutex mu_;
        std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
        MemoryCache* const cache_ GUARDED_BY(mu_);  // not owned.
      };  // MemoryWriterIterator

      class MemoryReaderIterator : public DatasetIterator<MemoryDataset> {
       public:
        explicit MemoryReaderIterator(const Params& params, MemoryCache* cache)
            : DatasetIterator<MemoryDataset>(params), cache_(cache), index_(0) {
          CHECK(cache);
        }

        Status Initialize(IteratorContext* ctx) override {
          // The memory allocated for the cache is owned by the parent
          // dataset but performance modeling uses the iterator abstraction and
          // thus we record the memory allocated for the cache here. The caveat
          // is that this is incorrect if there are concurrent instances of this
          // iterator.
          tf_shared_lock l(mu_);
          for (size_t i = 0; i < cache_->size(); ++i) {
            RecordBufferEnqueue(ctx, cache_->at(i));
          }
          return Status::OK();
        }

        Status GetNextInternal(IteratorContext* ctx,
                               std::vector<Tensor>* out_tensors,
                               bool* end_of_sequence) override {
          mutex_lock l(mu_);
          if (index_ < cache_->size()) {
            const std::vector<Tensor>& cache_tensors = cache_->at(index_);
            out_tensors->insert(out_tensors->begin(), cache_tensors.begin(),
                                cache_tensors.end());
            index_++;
            *end_of_sequence = false;
            return Status::OK();
          } else {
            *end_of_sequence = true;
            return Status::OK();
          }
        }

       protected:
        std::shared_ptr<model::Node> CreateNode(
            IteratorContext* ctx, model::Node::Args args) const override {
          return model::MakeKnownRatioNode(std::move(args),
                                           /*ratio=*/1);
        }

        Status SaveInternal(IteratorStateWriter* writer) override {
          mutex_lock l(mu_);
          TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("index"), index_));
          return Status::OK();
        }

        Status RestoreInternal(IteratorContext* ctx,
                               IteratorStateReader* reader) override {
          mutex_lock l(mu_);
          {
            int64 temp;
            TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("index"), &temp));
            index_ = static_cast<size_t>(temp);
          }
          return Status::OK();
        }

       private:
        mutex mu_;
        MemoryCache* const cache_ GUARDED_BY(mu_);  // not owned.
        size_t index_ GUARDED_BY(mu_);
      };  // MemoryReaderIterator

      void InitializeIterator() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        switch (mode_) {
          case Mode::read:
            iterator_ = absl::make_unique<MemoryReaderIterator>(
                MemoryReaderIterator::Params{dataset(),
                                             strings::StrCat(prefix(), "Impl")},
                cache_);
            break;
          case Mode::write:
            iterator_ = absl::make_unique<MemoryWriterIterator>(
                MemoryWriterIterator::Params{dataset(),
                                             strings::StrCat(prefix(), "Impl")},
                cache_);
        }
      }

      mutex mu_;
      MemoryCache* cache_ GUARDED_BY(mu_);  // not owned.
      enum Mode { read, write };
      Mode mode_ GUARDED_BY(mu_);
      std::unique_ptr<IteratorBase> iterator_ GUARDED_BY(mu_);
    };  // MemoryIterator

    const DatasetBase* const input_;
  };  // MemoryDataset
};    // CacheDatasetOp

REGISTER_KERNEL_BUILDER(Name("CacheDataset").Device(DEVICE_CPU),
                        CacheDatasetOp);

}  // namespace
}  // namespace data
}  // namespace tensorflow