/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/cache_dataset_ops.h"

#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/cache_ops.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"

namespace tensorflow {
namespace data {

// See documentation in ../../ops/dataset_ops.cc for a high-level description
// of the following op.
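//
// At the Python level this kernel backs `tf.data.Dataset.cache()`. For
// illustration (Python, not part of this file):
//
//   ds = tf.data.Dataset.range(10).cache()             # in-memory cache
//   ds = tf.data.Dataset.range(10).cache("/tmp/data")  # file-backed cache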

/* static */ constexpr const char* const CacheDatasetOp::kDatasetType;
/* static */ constexpr const char* const CacheDatasetOp::kInputDataset;
/* static */ constexpr const char* const CacheDatasetOp::kFileName;
/* static */ constexpr const char* const CacheDatasetOp::kOutputTypes;
/* static */ constexpr const char* const CacheDatasetOp::kOutputShapes;

namespace {

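// kKeyStrFormat is itself a printf format that yields another printf format:
// "%%" escapes to a literal '%', so e.g. Printf(kKeyStrFormat, 7, 1) produces
// "%7zu_%1zu" (see tensor_format_string_ below).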
constexpr char kKeyStrFormat[] = "%%%zuzu_%%%zuzu";
constexpr char kPaddingSizeStrFormat[] = "%zu";
constexpr char kFileDatasetPrefix[] = "File";
constexpr char kMode[] = "Mode";
constexpr char kLockFileSuffix[] = ".lockfile";
constexpr char kIterationCompleted[] = "iteration_completed";
constexpr char kCurIndex[] = "cur_index";
constexpr char kShardId[] = "shard_id";
constexpr char kCreatedAt[] = "Created at";
constexpr char kMemoryDatasetPrefix[] = "Memory";
constexpr char kMemoryCache[] = "MemoryCache";
constexpr char kCacheClaimed[] = "cache_claimed";
constexpr char kCacheSize[] = "cache_size";
constexpr char kCache[] = "cache";
constexpr char kSizeSuffix[] = ".size";
constexpr char kCacheCompleted[] = "cache_completed";
constexpr char kIndex[] = "index";
constexpr char kImpl[] = "Impl";
constexpr char kCacheDataset[] = "CacheDataset";
constexpr char kIncompleteCacheErrorMessage[] =
    "The calling iterator did not fully read the dataset being cached. In "
    "order to avoid unexpected truncation of the dataset, the partially cached "
    "contents of the dataset will be discarded. This can happen if you have "
    "an input pipeline similar to `dataset.cache().take(k).repeat()`. You "
    "should use `dataset.take(k).cache().repeat()` instead.";

}  // namespace

class CacheDatasetOp::FileDatasetBase : public DatasetBase {
 public:
  FileDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
                  string filename, Env* env)
      : DatasetBase(DatasetContext(ctx)),
        input_(input),
        filename_(std::move(filename)),
        env_(env),
        num_tensors_(input->output_dtypes().size()),
        tensor_index_padding_size_(StringPaddingSize(num_tensors_)),
        item_index_padding_size_(StringPaddingSize(kMaxItems)),
        tensor_format_string_(strings::Printf(kKeyStrFormat,
                                              item_index_padding_size_,
                                              tensor_index_padding_size_)) {
    input_->Ref();
    DCHECK_EQ(item_index_padding_size_, 7);
  }

  ~FileDatasetBase() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    name_utils::IteratorPrefixParams params;
    params.dataset_prefix = kFileDatasetPrefix;
    return absl::make_unique<FileIterator>(FileIterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix, params)});
  }

  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return input_->output_shapes();
  }

  string DebugString() const override {
    name_utils::DatasetDebugStringParams params;
    params.dataset_prefix = kFileDatasetPrefix;
    return name_utils::DatasetDebugString(kDatasetType, params);
  }

  int64 Cardinality() const override { return input_->Cardinality(); }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return Status::OK();
  }

  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }

 protected:
  const DatasetBase* const input_;
  const tstring filename_;

 private:
  static size_t StringPaddingSize(size_t num_tensors) {
    return strings::Printf(kPaddingSizeStrFormat, num_tensors - 1).size();
  }

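  // Produces the bundle key for element `item_index`, tensor `tensor_index`.
  // For example, with the default kMaxItems and a single-tensor dataset,
  // tensor_format_string_ is "%7zu_%1zu", so FormatName(12, 0) yields
  // "     12_0" (space padded).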
  string FormatName(size_t item_index, size_t tensor_index) const {
    return strings::Printf(tensor_format_string_.c_str(), item_index,
                           tensor_index);
  }

  class FileIterator : public DatasetIterator<FileDatasetBase> {
   public:
    explicit FileIterator(const Params& params)
        : DatasetIterator<FileDatasetBase>(params) {
      if (params.dataset->env_
              ->FileExists(MetaFilename(params.dataset->filename_))
              .ok()) {
        mode_ = Mode::read;
      } else {
        mode_ = Mode::write;
      }
    }

    Status Initialize(IteratorContext* ctx) override {
      mutex_lock l(mu_);
      return InitializeIterator(ctx);
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      mutex_lock l(mu_);
      return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
    }

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeKnownRatioNode(std::move(args),
                                       /*ratio=*/1);
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kMode), mode_));
      return SaveInput(ctx, writer, iterator_);
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      {
        int64 temp;
        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kMode), &temp));
        mode_ = static_cast<Mode>(temp);
      }
      if (mode_ == Mode::write &&
          dataset()
              ->env_->FileExists(MetaFilename(dataset()->filename_))
              .ok()) {
        // This could happen if the cache was completely written after the
        // checkpoint was saved.
        LOG(WARNING)
            << "It looks like the cache was already completely written ("
            << MetaFilename(dataset()->filename_)
            << ") after the last checkpoint was saved. Attempting to read "
            << "the cache instead of continuing to write. If this is a "
            << "mistake, please remove the above file and try running again.";
        mode_ = Mode::read;
      }
      TF_RETURN_IF_ERROR(InitializeIterator(ctx));
      return RestoreInput(ctx, reader, iterator_);
    }

   private:
    // FileWriterIterator passes through and caches items from the input
    // FileDatasetBase.
    //
    // This iterator is used when the cache directory is not found on disk. It
    // creates the cache directory, and passes on the underlying iterator's
    // elements.
    //
    // Caching is performed by writing the input tensors to disk using the
    // `BundleWriter`. Note that the cache gets fully flushed to disk only
    // after the input iterator has been fully exhausted. If the program
    // exits before completion of an epoch, the cached state will be lost.
    // To ensure that the partial cache persists across sessions, one should
    // checkpoint the input pipeline. On each call to `SaveInternal` the
    // partial cache gets flushed to disk in files with prefix
    // <filename>_<shard_id> where shard_id is unique for each checkpoint.
    // When all elements have been produced, these shards get coalesced.
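    //
    // For example (illustrative), with filename_ == "/tmp/cache", three
    // checkpoints during the first epoch leave bundles with prefixes
    // /tmp/cache_0, /tmp/cache_1, and /tmp/cache_2; Finish() then merges
    // them into a single bundle with prefix /tmp/cache.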
    class FileWriterIterator : public DatasetIterator<FileDatasetBase> {
     public:
      explicit FileWriterIterator(const Params& params)
          : DatasetIterator<FileDatasetBase>(params),
            cur_index_(0),
            shard_id_(0),
            filename_(
                strings::StrCat(params.dataset->filename_, "_", shard_id_)),
            lockfile_(strings::StrCat(filename_, kLockFileSuffix)),
            lockfile_created_(false),
            iteration_completed_(false) {}

      ~FileWriterIterator() override {
        if (!dataset()->env_->FileExists(MetaFilename(filename_)).ok()) {
          LOG(WARNING) << kIncompleteCacheErrorMessage;
          std::vector<string> cache_files;
          Status s = dataset()->env_->GetMatchingPaths(
              strings::StrCat(filename_, "*"), &cache_files);
          if (!s.ok()) {
            LOG(WARNING) << "Failed to get matching files on " << filename_
                         << "* : " << s.ToString();
          }
          for (const string& path : cache_files) {
            s = dataset()->env_->DeleteFile(path);
            if (!s.ok()) {
              LOG(WARNING) << "Failed to delete " << path << " : "
                           << s.ToString();
            }
          }
        }
      }

      Status Initialize(IteratorContext* ctx) override {
        return dataset()->input_->MakeIterator(ctx, this, prefix(),
                                               &input_impl_);
      }

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        *end_of_sequence = false;
        TF_RETURN_IF_ERROR(EnsureLockFileExists(end_of_sequence));
        if (*end_of_sequence) {
          return Status::OK();
        }
        TF_RETURN_IF_ERROR(writer_->status());
        if (cur_index_ >= kMaxItems) {
          // As a courtesy, close the [truncated] cache file.
          Status s = Finish();
          if (!s.ok()) {
            LOG(ERROR) << s;
          }
          return errors::InvalidArgument(
              "Upstream iterator is producing more than ", kMaxItems,
              " items, which is more than the cache limit.");
        }

        TF_RETURN_IF_ERROR(
            input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
        if (*end_of_sequence && out_tensors->empty()) {
          TF_RETURN_IF_ERROR(Finish());
          cur_index_++;
          return Status::OK();
        }
        if (out_tensors->size() != dataset()->num_tensors_) {
          return errors::Internal(
              "Upstream iterator returned invalid number of tensors. "
              "Expected ",
              dataset()->num_tensors_, " got: ", out_tensors->size());
        }
        size_t tensor_index = 0;
        for (const Tensor& t : *out_tensors) {
          DCHECK_LT(tensor_index, dataset()->num_tensors_);
          string key = dataset()->FormatName(cur_index_, tensor_index++);
          TF_RETURN_IF_ERROR(writer_->Add(key, t));
        }
        if (*end_of_sequence) {
          TF_RETURN_IF_ERROR(Finish());
        }
        cur_index_++;
        return Status::OK();
      }

     protected:
      std::shared_ptr<model::Node> CreateNode(
          IteratorContext* ctx, model::Node::Args args) const override {
        return model::MakeKnownRatioNode(std::move(args),
                                         /*ratio=*/1);
      }

      Status SaveInternal(SerializationContext* ctx,
                          IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name(kCurIndex), cur_index_));

        if (iteration_completed_) {
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name(kIterationCompleted), ""));
          return Status::OK();
        }

        // lockfile is created on the first call to GetNextInternal. The
        // absence of a lockfile means that GetNextInternal was not called
        // and hence nothing was written to cache. So we don't need to worry
        // about flushing the current shard. This ensures that we never write
        // empty shards.
        if (lockfile_created_) {
          // Flush the current bundle.
          TF_RETURN_IF_ERROR(writer_->Finish());

          // Note: We do not delete the lockfile here. We keep lockfiles of
          // all shards around until the entire cache has been written to
          // prevent concurrent iterators from corrupting any of the shards.

          // Start caching to a new shard.
          shard_id_++;
          filename_ = strings::StrCat(dataset()->filename_, "_", shard_id_);
          lockfile_ = strings::StrCat(filename_, kLockFileSuffix);
          lockfile_created_ = false;
        }
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kShardId), shard_id_));
        return Status::OK();
      }

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        int64 temp;
        // TODO(b/78048575): Update this when saving size_t tensors directly
        // is supported.
        {
          TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCurIndex), &temp));
          cur_index_ = static_cast<size_t>(temp);
          if (cur_index_ != temp) {
            return errors::Internal("Invalid value for cur_index ", temp);
          }
        }

        if (reader->Contains(full_name(kIterationCompleted))) {
          iteration_completed_ = true;
          return Status::OK();
        }

        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));

        // TODO(b/78048575): Update this when saving size_t tensors directly
        // is supported.
        {
          TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kShardId), &temp));
          shard_id_ = static_cast<size_t>(temp);
          if (shard_id_ != temp) {
            return errors::Internal("Invalid value for shard_id ", temp);
          }
        }
        filename_ = strings::StrCat(dataset()->filename_, "_", shard_id_);
        lockfile_ = strings::StrCat(filename_, kLockFileSuffix);
        writer_ = absl::make_unique<BundleWriter>(dataset()->env_, filename_);
        return Status::OK();
      }

     private:
      Status EnsureLockFileExists(bool* end_of_sequence)
          TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        if (iteration_completed_) {
          *end_of_sequence = true;
          return Status::OK();
        }
        if (lockfile_created_) {
          return Status::OK();
        }

        // Perform rudimentary locking to help catch concurrent writes to the
        // same cache files.

        // 1. Check that a checkpoint for the shard has not already been
        // written.
        if (dataset()->env_->FileExists(MetaFilename(filename_)).ok()) {
          return errors::AlreadyExists("Existing cache files found: \n",
                                       MetaFilename(filename_), "\n",
                                       DataFilename(filename_, 0, 1), "\n",
                                       "To continue delete the above files.");
        }

        // 2. Check that there isn't a concurrent iterator that is writing
        // to cache.
        if (dataset()->env_->FileExists(lockfile_).ok()) {
          // Attempt to read the contents of the lockfile.
          char contents_scratch[151] = {0};  // Initialize all to 0.
          StringPiece contents;
          std::unique_ptr<RandomAccessFile> file;
          if (dataset()->env_->NewRandomAccessFile(lockfile_, &file).ok()) {
            file->Read(0, 150, &contents, contents_scratch).IgnoreError();
          }
          return errors::AlreadyExists(
              "There appears to be a concurrent caching iterator running - "
              "cache lockfile already exists ('",
              lockfile_,
              "'). If you are sure no other running TF computations are "
              "using this cache prefix, delete the lockfile and "
              "re-initialize the iterator. Lockfile contents: ",
              contents);
        }
        // Create the file, and write some basic contents.
        std::unique_ptr<WritableFile> lockfile;
        TF_RETURN_IF_ERROR(
            dataset()->env_->NewWritableFile(lockfile_, &lockfile));
        TF_RETURN_IF_ERROR(lockfile->Append(
            strings::StrCat(kCreatedAt, ": ", EnvTime::NowSeconds())));

        // At this point we know that
        // 1. There is no conflicting checkpoint with prefix `filename_`.
        // 2. There is no concurrent session that is trying to write a ckpt
        //    to filename.
        // So it is safe to create a BundleWriter here. Note that it is
        // unsafe to initialize the BundleWriter anywhere the above
        // conditions are not met since BundleWriter's constructor creates
        // new temp files which can delete the temp files created by a
        // BundleWriter in another Session.
        writer_ = absl::make_unique<BundleWriter>(dataset()->env_, filename_);
        lockfile_created_ = true;
        return Status::OK();
      }

      Status Finish() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        iteration_completed_ = true;
        // Flush the current bundle.
        TF_RETURN_IF_ERROR(writer_->Finish());
        // Merge all the bundles.
        // Currently there are `shard_id_ + 1` bundles, one for each
        // checkpoint. Each bundle has prefix <filename>_<id> where `id` is an
        // integer starting at 0 and incremented by 1 for each new checkpoint.
        // We merge all these bundles into a bundle with prefix <filename> so
        // that the next call to `MakeIterator` can build a
        // `FileReaderIterator`.
        {
          std::vector<tstring> prefixes;
          prefixes.reserve(shard_id_ + 1);
          for (size_t i = 0; i <= shard_id_; ++i) {
            prefixes.emplace_back(
                strings::StrCat(dataset()->filename_, "_", i));
          }
          TF_RETURN_IF_ERROR(
              MergeBundles(dataset()->env_, prefixes, dataset()->filename_));
        }
        // Delete all lockfiles.
        for (size_t i = 0; i <= shard_id_; ++i) {
          TF_RETURN_IF_ERROR(dataset()->env_->DeleteFile(
              strings::StrCat(dataset()->filename_, "_", i, kLockFileSuffix)));
        }
        return Status::OK();
      }

      mutex mu_;
      size_t cur_index_ TF_GUARDED_BY(mu_);
      // Index of the current shard. This gets incremented whenever a new
      // cache shard is saved.
      size_t shard_id_ TF_GUARDED_BY(mu_);
      std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
      // The current prefix for the cache file. This is equal to
      // `StrCat(dataset()->filename_, "_", shard_id_)`.
      string filename_;
      std::unique_ptr<BundleWriter> writer_ TF_GUARDED_BY(mu_);
      string lockfile_ TF_GUARDED_BY(mu_);
      bool lockfile_created_ TF_GUARDED_BY(mu_);
      bool iteration_completed_ TF_GUARDED_BY(mu_);
    };  // FileWriterIterator

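    // FileReaderIterator reads the cached elements back from the finalized
    // TensorBundle at `dataset()->filename_`, in the same order in which
    // they were written.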
    class FileReaderIterator : public DatasetIterator<FileDatasetBase> {
     public:
      explicit FileReaderIterator(const Params& params)
          : DatasetIterator<FileDatasetBase>(params),
            cur_index_(0),
            reader_(dataset()->env_, dataset()->filename_),
            iterator_restored_(false) {}

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        *end_of_sequence = false;
        TF_RETURN_IF_ERROR(reader_.status());
        if (!reader_.Valid()) {
          *end_of_sequence = true;
          return Status::OK();
        }
        out_tensors->clear();
        out_tensors->resize(dataset()->num_tensors_);

        for (size_t i = 0; i < dataset()->num_tensors_; ++i) {
          // When the iterator is restored from the checkpoint, `reader_` is
          // already pointing at `key` so we do not need to skip the header
          // entry.
          if (!iterator_restored_) {
            reader_.Next();  // The first entry in the table is a header.
          } else {
            iterator_restored_ = false;
          }
          if (!reader_.Valid()) {
            out_tensors->clear();
            *end_of_sequence = true;
            return Status::OK();
          }
          StringPiece key = reader_.key();
          DCHECK_EQ(key, dataset()->FormatName(cur_index_, i));
          TF_RETURN_IF_ERROR(reader_.ReadCurrent(&(*out_tensors)[i]));
          TF_RETURN_IF_ERROR(reader_.status());
        }
        cur_index_++;
        return Status::OK();
      }

     protected:
      std::shared_ptr<model::Node> CreateNode(
          IteratorContext* ctx, model::Node::Args args) const override {
        return model::MakeKnownRatioNode(std::move(args),
                                         /*ratio=*/1);
      }

      Status SaveInternal(SerializationContext* ctx,
                          IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name(kCurIndex), cur_index_));
        return Status::OK();
      }

      Status RestoreInternal(
          IteratorContext* ctx,
          IteratorStateReader* iterator_state_reader) override {
        mutex_lock l(mu_);
        {
          // TODO(b/78048575): Update this when saving size_t tensors directly
          // is supported.
          int64 temp;
          TF_RETURN_IF_ERROR(
              iterator_state_reader->ReadScalar(full_name(kCurIndex), &temp));
          cur_index_ = static_cast<size_t>(temp);
          if (cur_index_ != temp) {
            return errors::Internal("Invalid value for cur_index ", temp);
          }
        }
        if (!reader_.Valid()) {
          return errors::Internal("Error initializing BundleReader.");
        }
        reader_.Seek(dataset()->FormatName(cur_index_, 0));
        iterator_restored_ = true;
        return Status::OK();
      }

     private:
      mutex mu_;
      size_t cur_index_ TF_GUARDED_BY(mu_);
      BundleReader reader_ TF_GUARDED_BY(mu_);
      bool iterator_restored_ TF_GUARDED_BY(mu_);
    };  // FileReaderIterator

    Status InitializeIterator(IteratorContext* ctx)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      // We intentionally use the same prefix for both `FileReaderIterator` and
      // `FileWriterIterator`. Since at any time there will be at most one of
      // them alive, there should be no conflicts. This allows both iterators to
      // use a common key for `cur_index`. We leverage this in the corner case
      // when this iterator is restored from an old checkpoint in `write` mode
      // and the cache has been completely flushed to disk since then. In that
      // case we simply build a `FileReaderIterator` and seek to the
      // `cur_index`.
      switch (mode_) {
        case Mode::read:
          iterator_ =
              absl::make_unique<FileReaderIterator>(FileReaderIterator::Params{
                  dataset(), strings::StrCat(prefix(), kImpl)});
          break;
        case Mode::write:
          iterator_ =
              absl::make_unique<FileWriterIterator>(FileWriterIterator::Params{
                  dataset(), strings::StrCat(prefix(), kImpl)});
      }
      TF_RETURN_IF_ERROR(iterator_->InitializeBase(ctx, this));
      return iterator_->Initialize(ctx);
    }

    mutex mu_;
    enum Mode { read, write };
    Mode mode_ TF_GUARDED_BY(mu_);
    std::unique_ptr<IteratorBase> iterator_ TF_GUARDED_BY(mu_);
  };  // FileIterator

  Env* const env_;
  const size_t num_tensors_;
  const size_t tensor_index_padding_size_;
  static constexpr size_t kMaxItems = 10000000;  // 10 million
  const size_t item_index_padding_size_;
  const string tensor_format_string_;
};  // FileDatasetBase

class CacheDatasetOp::FileDataset : public CacheDatasetOp::FileDatasetBase {
 public:
  using FileDatasetBase::FileDatasetBase;

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph));
    Node* filename = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(filename_, &filename));
    TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph, filename}, output));
    return Status::OK();
  }
};

class CacheDatasetOp::FileDatasetV2 : public CacheDatasetOp::FileDatasetBase {
 public:
  explicit FileDatasetV2(OpKernelContext* ctx, const DatasetBase* input,
                         string filename, Env* env,
                         const Tensor& resource_handle)
      : FileDatasetBase(ctx, input, filename, env),
        resource_handle_(resource_handle) {}

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
    Node* filename_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(filename_, &filename_node));
    Node* resource_handle_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddTensor(resource_handle_, &resource_handle_node));
    TF_RETURN_IF_ERROR(b->AddDataset(
        this, {input_node, filename_node, resource_handle_node}, output));
    return Status::OK();
  }

 private:
  const Tensor resource_handle_;
};

class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
 public:
  explicit MemoryDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
                             std::shared_ptr<MemoryCache> cache)
      : DatasetBase(DatasetContext(ctx)),
        input_(input),
        cache_(std::move(cache)) {
    input_->Ref();
  }

  ~MemoryDatasetBase() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    name_utils::IteratorPrefixParams params;
    params.dataset_prefix = kMemoryDatasetPrefix;
    return absl::make_unique<MemoryIterator>(
        MemoryIterator::Params{
            this, name_utils::IteratorPrefix(kDatasetType, prefix, params)},
        cache_.get());
  }

  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return input_->output_shapes();
  }

  string DebugString() const override {
    name_utils::DatasetDebugStringParams params;
    params.dataset_prefix = kMemoryDatasetPrefix;
    return name_utils::DatasetDebugString(kDatasetType, params);
  }

  int64 Cardinality() const override { return input_->Cardinality(); }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return Status::OK();
  }

  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }

 protected:
  class MemoryIterator : public DatasetIterator<MemoryDatasetBase> {
   public:
    explicit MemoryIterator(const Params& params, MemoryCache* cache)
        : DatasetIterator<MemoryDatasetBase>(params), cache_(cache) {}

    Status Initialize(IteratorContext* ctx) override {
      mutex_lock l(mu_);
      return InitializeIterator(ctx);
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      mutex_lock l(mu_);
      return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
    }

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeKnownRatioNode(std::move(args),
                                       /*ratio=*/1);
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      if (cache_->IsCompleted()) {
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCacheCompleted), ""));
        TF_RETURN_IF_ERROR(
            WriteElementsToCheckpoint(writer, prefix(), cache_->data()));
      }
      return SaveInput(ctx, writer, iterator_);
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      iterator_.reset();
      cache_->Reset();
      if (reader->Contains(full_name(kCacheCompleted))) {
        std::vector<std::vector<Tensor>> temp_cache;
        TF_RETURN_IF_ERROR(
            ReadElementsFromCheckpoint(reader, prefix(), &temp_cache));
        cache_->Complete(std::move(temp_cache));
      }
      TF_RETURN_IF_ERROR(InitializeIterator(ctx));
      return RestoreInput(ctx, reader, iterator_);
    }

   private:
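    // MemoryWriterIterator passes through elements from the input while
    // appending copies of them to `temp_cache_`; once the input is exhausted
    // (or the expected cardinality is reached) the temporary cache is handed
    // off to the shared `MemoryCache`.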
    class MemoryWriterIterator : public DatasetIterator<MemoryDatasetBase> {
     public:
      explicit MemoryWriterIterator(const Params& params, MemoryCache* cache)
          : DatasetIterator<MemoryDatasetBase>(params), cache_(cache) {}

      ~MemoryWriterIterator() override {
        mutex_lock l(mu_);
        if (!temp_cache_.empty() && !cache_->IsCompleted()) {
          LOG(WARNING) << kIncompleteCacheErrorMessage;
          cache_->Reset();
        }
      }

      Status Initialize(IteratorContext* ctx) override {
        return dataset()->input_->MakeIterator(ctx, this, prefix(),
                                               &input_impl_);
      }

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(
            input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
        if (*end_of_sequence) {
          if (!cache_->IsCompleted()) {
            VLOG(2) << "Finalizing the cache because EOF has been reached.";
            cache_->Complete(std::move(temp_cache_));
          }
          return Status::OK();
        }
        RecordBufferEnqueue(ctx, *out_tensors);
        temp_cache_.emplace_back(*out_tensors);
        if (temp_cache_.size() == dataset()->input_->Cardinality()) {
          VLOG(2) << "Finalizing the cache because its size matches the "
                     "expected input cardinality.";
          cache_->Complete(std::move(temp_cache_));
        }
        return Status::OK();
      }

     protected:
      std::shared_ptr<model::Node> CreateNode(
          IteratorContext* ctx, model::Node::Args args) const override {
        return model::MakeKnownRatioNode(std::move(args),
                                         /*ratio=*/1);
      }

      Status SaveInternal(SerializationContext* ctx,
                          IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        if (!cache_->IsCompleted()) {
          TF_RETURN_IF_ERROR(
              WriteElementsToCheckpoint(writer, prefix(), temp_cache_));
        }
        return SaveInput(ctx, writer, input_impl_);
      }

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        if (!reader->Contains(full_name(kCacheCompleted))) {
          TF_RETURN_IF_ERROR(
              ReadElementsFromCheckpoint(reader, prefix(), &temp_cache_));
        }
        return RestoreInput(ctx, reader, input_impl_);
      }

     private:
      mutex mu_;
      std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
      MemoryCache* const cache_ TF_GUARDED_BY(mu_);  // not owned.
      std::vector<std::vector<Tensor>> temp_cache_ TF_GUARDED_BY(mu_);
    };  // MemoryWriterIterator

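    // MemoryReaderIterator serves elements directly out of the completed
    // in-memory cache, without touching the input dataset.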
    class MemoryReaderIterator : public DatasetIterator<MemoryDatasetBase> {
     public:
      explicit MemoryReaderIterator(const Params& params, MemoryCache* cache)
          : DatasetIterator<MemoryDatasetBase>(params),
            cache_(cache),
            index_(0) {}

      Status Initialize(IteratorContext* ctx) override {
        // The memory allocated for the cache is owned by the parent
        // dataset but performance modeling uses the iterator abstraction and
        // thus we record the memory allocated for the cache here. The caveat
        // is that this is incorrect if there are concurrent instances of this
        // iterator.
        tf_shared_lock l(mu_);
        for (size_t i = 0; i < cache_->size(); ++i) {
          RecordBufferEnqueue(ctx, cache_->at(i));
        }
        return Status::OK();
      }

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        if (index_ < cache_->size()) {
          const std::vector<Tensor>& cache_tensors = cache_->at(index_);
          out_tensors->insert(out_tensors->begin(), cache_tensors.begin(),
                              cache_tensors.end());
          index_++;
          *end_of_sequence = false;
          return Status::OK();
        } else {
          *end_of_sequence = true;
          return Status::OK();
        }
      }

     protected:
      std::shared_ptr<model::Node> CreateNode(
          IteratorContext* ctx, model::Node::Args args) const override {
        return model::MakeKnownRatioNode(std::move(args),
                                         /*ratio=*/1);
      }

      Status SaveInternal(SerializationContext* ctx,
                          IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kIndex), index_));
        return Status::OK();
      }

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        {
          int64 temp;
          TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kIndex), &temp));
          index_ = static_cast<size_t>(temp);
        }
        return Status::OK();
      }

     private:
      mutex mu_;
      MemoryCache* const cache_ TF_GUARDED_BY(mu_);  // not owned.
      size_t index_ TF_GUARDED_BY(mu_);
    };  // MemoryReaderIterator

    Status InitializeIterator(IteratorContext* ctx)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      if (cache_->IsCompleted()) {
        iterator_ = absl::make_unique<MemoryReaderIterator>(
            MemoryReaderIterator::Params{dataset(),
                                         strings::StrCat(prefix(), kImpl)},
            cache_);
      } else {
        iterator_ = absl::make_unique<MemoryWriterIterator>(
            MemoryWriterIterator::Params{dataset(),
                                         strings::StrCat(prefix(), kImpl)},
            cache_);
      }
      TF_RETURN_IF_ERROR(iterator_->InitializeBase(ctx, this));
      return iterator_->Initialize(ctx);
    }

    mutex mu_;
    MemoryCache* cache_ TF_GUARDED_BY(mu_);  // not owned.
    std::unique_ptr<IteratorBase> iterator_ TF_GUARDED_BY(mu_);
  };  // MemoryIterator

  const DatasetBase* const input_;
  const std::shared_ptr<MemoryCache> cache_;
};  // MemoryDatasetBase

// This version of memory dataset has an exclusive ownership of the memory
// cache resource. It supports sharing of the cache across different iterations
// of the `repeat` transformation but not across different iterators.
class CacheDatasetOp::MemoryDataset : public CacheDatasetOp::MemoryDatasetBase {
 public:
  MemoryDataset(OpKernelContext* ctx, const DatasetBase* input,
                MemoryCacheManager* manager, ResourceHandle&& resource_handle)
      : MemoryDatasetBase(ctx, input, manager->get()),
        manager_(manager),
        resource_handle_(std::move(resource_handle)),
        resource_mgr_(ctx->resource_manager()) {}

  ~MemoryDataset() override {
    manager_->Unref();
    Status s = resource_mgr_->Delete<MemoryCacheManager>(
        resource_handle_.container(), resource_handle_.name());
    if (!s.ok()) {
      LOG(WARNING) << "Failed to delete cache resource: " << s.ToString();
    }
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
    Node* filename_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(tstring(""), &filename_node));
    TF_RETURN_IF_ERROR(
        b->AddDataset(this, {input_node, filename_node}, output));
    return Status::OK();
  }

 private:
  MemoryCacheManager* const manager_;  // Owned.
  const ResourceHandle resource_handle_;
  ResourceMgr* const resource_mgr_;  // Not owned.
};

// This version of memory dataset has a shared ownership of the memory cache
// resource. It supports sharing of the cache across different iterations of
// the `repeat` transformation and also across different iterators.
class CacheDatasetOp::MemoryDatasetV2
    : public CacheDatasetOp::MemoryDatasetBase {
 public:
  MemoryDatasetV2(OpKernelContext* ctx, const DatasetBase* input,
                  MemoryCacheManager* manager, ResourceHandle&& resource_handle,
                  bool owns_resource)
      : MemoryDatasetBase(ctx, input, manager->get()),
        manager_(manager),
        owns_resource_(owns_resource),
        resource_handle_(std::move(resource_handle)),
        resource_mgr_(ctx->resource_manager()) {}

  ~MemoryDatasetV2() override {
    manager_->Unref();
    if (owns_resource_) {
      Status s = resource_mgr_->Delete<MemoryCacheManager>(
          resource_handle_.container(), resource_handle_.name());
      if (!s.ok()) {
        LOG(WARNING) << "Failed to delete cache resource: " << s.ToString();
      }
    }
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
    Node* filename_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(tstring(""), &filename_node));
    Node* resource_handle_node = nullptr;
    Tensor handle(DT_RESOURCE, TensorShape({}));
    handle.scalar<ResourceHandle>()() = resource_handle_;
    TF_RETURN_IF_ERROR(b->AddTensor(handle, &resource_handle_node));
    TF_RETURN_IF_ERROR(b->AddDataset(
        this, {input_node, filename_node, resource_handle_node}, output));
    return Status::OK();
  }

 private:
  MemoryCacheManager* const manager_;  // Owned.
  const bool owns_resource_;
  const ResourceHandle resource_handle_;
  ResourceMgr* const resource_mgr_;  // Not owned.
};

CacheDatasetOp::CacheDatasetOp(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx),
      op_version_(ctx->def().op() == kCacheDataset ? 1 : 2) {}

void CacheDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                 DatasetBase** output) {
  // Parse out the filename tensor.
  tstring filename;
  OP_REQUIRES_OK(ctx, ParseScalarArgument<tstring>(ctx, kFileName, &filename));
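  // An empty filename selects the in-memory cache (MemoryDataset*); a
  // non-empty filename selects the file-backed cache (FileDataset*).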
  if (filename.empty()) {
    static std::atomic<int64> resource_id_counter(0);
    const string& container = ctx->resource_manager()->default_container();
    auto name = strings::StrCat(ctx->op_kernel().name(), "/", kMemoryCache, "_",
                                resource_id_counter.fetch_add(1));
    if (op_version_ == 2) {
      bool owns_resource = false;
      MemoryCacheManager* manager = nullptr;
      auto handle = HandleFromInput(ctx, 2);
      Status s = ctx->resource_manager()->Lookup<MemoryCacheManager>(
          handle.container(), handle.name(), &manager);
      if (errors::IsNotFound(s)) {
        owns_resource = true;
        OP_REQUIRES_OK(
            ctx,
            ctx->resource_manager()->LookupOrCreate<MemoryCacheManager>(
                container, name, &manager, [](MemoryCacheManager** manager) {
                  *manager = new MemoryCacheManager();
                  return Status::OK();
                }));
        handle = MakeResourceHandle<MemoryCacheManager>(ctx, container, name);
      } else {
        OP_REQUIRES_OK(ctx, s);
      }
      // Ownership of manager is transferred onto `MemoryDatasetV2`.
      *output = new MemoryDatasetV2(ctx, input, manager, std::move(handle),
                                    owns_resource);
    } else {
      MemoryCacheManager* manager;
      OP_REQUIRES_OK(
          ctx, ctx->resource_manager()->LookupOrCreate<MemoryCacheManager>(
                   container, name, &manager, [](MemoryCacheManager** manager) {
                     *manager = new MemoryCacheManager();
                     return Status::OK();
                   }));
      auto handle =
          MakeResourceHandle<MemoryCacheManager>(ctx, container, name);
      // Ownership of manager is transferred onto `MemoryDataset`.
      *output = new MemoryDataset(ctx, input, manager, std::move(handle));
    }
  } else {
    if (op_version_ == 2) {
      *output =
          new FileDatasetV2(ctx, input, filename, ctx->env(), ctx->input(2));
    } else {
      *output = new FileDataset(ctx, input, filename, ctx->env());
    }
  }
}

namespace {
REGISTER_KERNEL_BUILDER(Name("CacheDataset").Device(DEVICE_CPU),
                        CacheDatasetOp);
REGISTER_KERNEL_BUILDER(Name("CacheDatasetV2").Device(DEVICE_CPU),
                        CacheDatasetOp);
}  // namespace
}  // namespace data
}  // namespace tensorflow