/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/cache_dataset_ops.h"

#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/cache_ops.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"

namespace tensorflow {
namespace data {

// See documentation in ../../ops/dataset_ops.cc for a high-level description
// of the following op.

/* static */ constexpr const char* const CacheDatasetOp::kDatasetType;
/* static */ constexpr const char* const CacheDatasetOp::kInputDataset;
/* static */ constexpr const char* const CacheDatasetOp::kFileName;
/* static */ constexpr const char* const CacheDatasetOp::kOutputTypes;
/* static */ constexpr const char* const CacheDatasetOp::kOutputShapes;

namespace {

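// `kKeyStrFormat` is itself a printf format: each `%%` escapes to a literal
// `%` and each `%zu` is filled with a padding width, so formatting it with
// the two padding sizes computed in `FileDatasetBase` yields a second format
// string such as "%7zu_%1zu". `FormatName` then uses that derived string to
// produce fixed-width "<item_index>_<tensor_index>" bundle keys, so that the
// keys sort in the same order as the indices they encode.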
constexpr char kKeyStrFormat[] = "%%%zuzu_%%%zuzu";
constexpr char kPaddingSizeStrFormat[] = "%zu";
constexpr char kFileDatasetPrefix[] = "File";
constexpr char kMode[] = "Mode";
constexpr char kLockFileSuffix[] = ".lockfile";
constexpr char kIterationCompleted[] = "iteration_completed";
constexpr char kCurIndex[] = "cur_index";
constexpr char kShardId[] = "shard_id";
constexpr char kCreatedAt[] = "Created at";
constexpr char kMemoryDatasetPrefix[] = "Memory";
constexpr char kMemoryCache[] = "MemoryCache";
constexpr char kCacheClaimed[] = "cache_claimed";
constexpr char kCacheSize[] = "cache_size";
constexpr char kCache[] = "cache";
constexpr char kSizeSuffix[] = ".size";
constexpr char kCacheCompleted[] = "cache_completed";
constexpr char kIndex[] = "index";
constexpr char kImpl[] = "Impl";
constexpr char kCacheDataset[] = "CacheDataset";
constexpr char kIncompleteCacheErrorMessage[] =
    "The calling iterator did not fully read the dataset being cached. In "
    "order to avoid unexpected truncation of the dataset, the partially cached "
    "contents of the dataset will be discarded. This can happen if you have "
    "an input pipeline similar to `dataset.cache().take(k).repeat()`. You "
    "should use `dataset.take(k).cache().repeat()` instead.";

}  // namespace

class CacheDatasetOp::FileDatasetBase : public DatasetBase {
 public:
  FileDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
                  string filename, Env* env)
      : DatasetBase(DatasetContext(ctx)),
        input_(input),
        filename_(std::move(filename)),
        env_(env),
        num_tensors_(input->output_dtypes().size()),
        tensor_index_padding_size_(StringPaddingSize(num_tensors_)),
        item_index_padding_size_(StringPaddingSize(kMaxItems)),
        tensor_format_string_(strings::Printf(kKeyStrFormat,
                                              item_index_padding_size_,
                                              tensor_index_padding_size_)) {
    input_->Ref();
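    // `kMaxItems` is 10 million, so the largest possible item index
    // (kMaxItems - 1 = 9999999) renders as a 7-character decimal string.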
    DCHECK_EQ(item_index_padding_size_, 7);
  }

  ~FileDatasetBase() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    name_utils::IteratorPrefixParams params;
    params.dataset_prefix = kFileDatasetPrefix;
    return absl::make_unique<FileIterator>(FileIterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix, params)});
  }

  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return input_->output_shapes();
  }

  string DebugString() const override {
    name_utils::DatasetDebugStringParams params;
    params.dataset_prefix = kFileDatasetPrefix;
    return name_utils::DatasetDebugString(kDatasetType, params);
  }

  int64 Cardinality() const override { return input_->Cardinality(); }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return Status::OK();
  }

  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }

 protected:
  const DatasetBase* const input_;
  const tstring filename_;

 private:
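  // Returns the width, in characters, of the largest index
  // (`num_tensors - 1`) printed in decimal; used to size fixed-width keys.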
  static size_t StringPaddingSize(size_t num_tensors) {
    return strings::Printf(kPaddingSizeStrFormat, num_tensors - 1).size();
  }

  string FormatName(size_t item_index, size_t tensor_index) const {
    return strings::Printf(tensor_format_string_.c_str(), item_index,
                           tensor_index);
  }

  class FileIterator : public DatasetIterator<FileDatasetBase> {
   public:
    explicit FileIterator(const Params& params)
        : DatasetIterator<FileDatasetBase>(params) {
      if (params.dataset->env_
              ->FileExists(MetaFilename(params.dataset->filename_))
              .ok()) {
        mode_ = Mode::read;
      } else {
        mode_ = Mode::write;
      }
    }

    Status Initialize(IteratorContext* ctx) override {
      mutex_lock l(mu_);
      return InitializeIterator(ctx);
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      mutex_lock l(mu_);
      return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
    }

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeKnownRatioNode(std::move(args),
                                       /*ratio=*/1);
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kMode), mode_));
      return SaveInput(ctx, writer, iterator_);
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      {
        int64 temp;
        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kMode), &temp));
        mode_ = static_cast<Mode>(temp);
      }
      if (mode_ == Mode::write &&
          dataset()
              ->env_->FileExists(MetaFilename(dataset()->filename_))
              .ok()) {
        // This could happen if the cache was completely written after the
        // checkpoint was saved.
        LOG(WARNING)
            << "It looks like the cache was already completely written ("
            << MetaFilename(dataset()->filename_)
            << ") after the last checkpoint was saved. Attempting to read "
            << "the cache instead of continuing to write. If this is a "
            << "mistake, please remove the above file and try running again.";
        mode_ = Mode::read;
      }
      TF_RETURN_IF_ERROR(InitializeIterator(ctx));
      return RestoreInput(ctx, reader, iterator_);
    }

   private:
    // FileWriterIterator passes through and caches items from the input
    // FileDatasetBase.
    //
    // This iterator is used when the cache directory is not found on disk. It
    // creates the cache directory, and passes on the underlying iterator's
    // elements.
    //
    // Caching is performed by writing the input tensors to disk using the
    // `BundleWriter`. Note that the cache gets fully flushed to disk only
    // after the input iterator has been fully exhausted. If the program
    // exits before completion of an epoch, the cached state will be lost.
    // To ensure that the partial cache persists across sessions, one should
    // checkpoint the input pipeline. On each call to `SaveInternal`, the
    // partial cache gets flushed to disk in files with prefix
    // <filename>_<shard_id>, where shard_id is unique for each checkpoint.
    // When all elements have been produced, these shards get coalesced.
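    //
    // For example (illustrative paths): with filename "/tmp/cache" and two
    // checkpoints taken before the epoch completes, the partial bundles would
    // be "/tmp/cache_0" and "/tmp/cache_1"; `Finish` merges them into a
    // single bundle with prefix "/tmp/cache", which a later
    // `FileReaderIterator` can then read.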
    class FileWriterIterator : public DatasetIterator<FileDatasetBase> {
     public:
      explicit FileWriterIterator(const Params& params)
          : DatasetIterator<FileDatasetBase>(params),
            cur_index_(0),
            shard_id_(0),
            filename_(
                strings::StrCat(params.dataset->filename_, "_", shard_id_)),
            lockfile_(strings::StrCat(filename_, kLockFileSuffix)),
            lockfile_created_(false),
            iteration_completed_(false) {}

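      // If the cache was never completely written (no metadata file exists
      // for `filename_`), delete any partial shard and lockfile artifacts so
      // that a later run can rebuild the cache from scratch.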
      ~FileWriterIterator() override {
        if (!dataset()->env_->FileExists(MetaFilename(filename_)).ok()) {
          LOG(WARNING) << kIncompleteCacheErrorMessage;
          std::vector<string> cache_files;
          Status s = dataset()->env_->GetMatchingPaths(
              strings::StrCat(filename_, "*"), &cache_files);
          if (!s.ok()) {
            LOG(WARNING) << "Failed to get matching files on " << filename_
                         << "* : " << s.ToString();
          }
          for (const string& path : cache_files) {
            s = dataset()->env_->DeleteFile(path);
            if (!s.ok()) {
              LOG(WARNING) << "Failed to delete " << path << " : "
                           << s.ToString();
            }
          }
        }
      }

      Status Initialize(IteratorContext* ctx) override {
        return dataset()->input_->MakeIterator(ctx, this, prefix(),
                                               &input_impl_);
      }

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        *end_of_sequence = false;
        TF_RETURN_IF_ERROR(EnsureLockFileExists(end_of_sequence));
        if (*end_of_sequence) {
          return Status::OK();
        }
        TF_RETURN_IF_ERROR(writer_->status());
        if (cur_index_ >= kMaxItems) {
          // As a courtesy, close the [truncated] cache file.
          Status s = Finish();
          if (!s.ok()) {
            LOG(ERROR) << s;
          }
          return errors::InvalidArgument(
              "Upstream iterator is producing more than ", kMaxItems,
              " items, which is more than the cache limit.");
        }

        TF_RETURN_IF_ERROR(
            input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
        if (*end_of_sequence && out_tensors->empty()) {
          TF_RETURN_IF_ERROR(Finish());
          cur_index_++;
          return Status::OK();
        }
        if (out_tensors->size() != dataset()->num_tensors_) {
          return errors::Internal(
              "Upstream iterator returned an invalid number of tensors. "
              "Expected ",
              dataset()->num_tensors_, ", got: ", out_tensors->size());
        }
        size_t tensor_index = 0;
        for (const Tensor& t : *out_tensors) {
          DCHECK_LT(tensor_index, dataset()->num_tensors_);
          string key = dataset()->FormatName(cur_index_, tensor_index++);
          TF_RETURN_IF_ERROR(writer_->Add(key, t));
        }
        if (*end_of_sequence) {
          TF_RETURN_IF_ERROR(Finish());
        }
        cur_index_++;
        return Status::OK();
      }

     protected:
      std::shared_ptr<model::Node> CreateNode(
          IteratorContext* ctx, model::Node::Args args) const override {
        return model::MakeKnownRatioNode(std::move(args),
                                         /*ratio=*/1);
      }

      Status SaveInternal(SerializationContext* ctx,
                          IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name(kCurIndex), cur_index_));

        if (iteration_completed_) {
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name(kIterationCompleted), ""));
          return Status::OK();
        }

        // The lockfile is created on the first call to GetNextInternal. The
        // absence of a lockfile means that GetNextInternal was not called,
        // and hence nothing was written to the cache, so we don't need to
        // worry about flushing the current shard. This ensures that we never
        // write empty shards.
        if (lockfile_created_) {
          // Flush the current bundle.
          TF_RETURN_IF_ERROR(writer_->Finish());

          // Note: We do not delete the lockfile here. We keep lockfiles of
          // all shards around until the entire cache has been written to
          // prevent concurrent iterators from corrupting any of the shards.

          // Start caching to a new shard.
          shard_id_++;
          filename_ = strings::StrCat(dataset()->filename_, "_", shard_id_);
          lockfile_ = strings::StrCat(filename_, kLockFileSuffix);
          lockfile_created_ = false;
        }
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kShardId), shard_id_));
        return Status::OK();
      }

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        int64 temp;
        // TODO(b/78048575): Update this when saving size_t tensors directly
        // is supported.
        {
          TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCurIndex), &temp));
          cur_index_ = static_cast<size_t>(temp);
          if (cur_index_ != temp) {
            return errors::Internal("Invalid value for cur_index ", temp);
          }
        }

        if (reader->Contains(full_name(kIterationCompleted))) {
          iteration_completed_ = true;
          return Status::OK();
        }

        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));

        // TODO(b/78048575): Update this when saving size_t tensors directly
        // is supported.
        {
          TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kShardId), &temp));
          shard_id_ = static_cast<size_t>(temp);
          if (shard_id_ != temp) {
            return errors::Internal("Invalid value for shard_id ", temp);
          }
        }
        filename_ = strings::StrCat(dataset()->filename_, "_", shard_id_);
        lockfile_ = strings::StrCat(filename_, kLockFileSuffix);
        writer_ = absl::make_unique<BundleWriter>(dataset()->env_, filename_);
        return Status::OK();
      }

     private:
      Status EnsureLockFileExists(bool* end_of_sequence)
          TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        if (iteration_completed_) {
          *end_of_sequence = true;
          return Status::OK();
        }
        if (lockfile_created_) {
          return Status::OK();
        }

        // Perform rudimentary locking to help catch concurrent writes to the
        // same cache files.

        // 1. Check that a checkpoint for the shard has not already been
        // written.
        if (dataset()->env_->FileExists(MetaFilename(filename_)).ok()) {
          return errors::AlreadyExists("Existing cache files found:\n",
                                       MetaFilename(filename_), "\n",
                                       DataFilename(filename_, 0, 1), "\n",
                                       "To continue, delete the above files.");
        }

        // 2. Check that there isn't a concurrent iterator that is writing
        // to the cache.
        if (dataset()->env_->FileExists(lockfile_).ok()) {
          // Attempt to read the contents of the lockfile.
          char contents_scratch[151] = {0};  // Initialize all to 0.
          StringPiece contents;
          std::unique_ptr<RandomAccessFile> file;
          if (dataset()->env_->NewRandomAccessFile(lockfile_, &file).ok()) {
            file->Read(0, 150, &contents, contents_scratch).IgnoreError();
          }
          return errors::AlreadyExists(
              "There appears to be a concurrent caching iterator running - "
              "cache lockfile already exists ('",
              lockfile_,
              "'). If you are sure no other running TF computations are "
              "using this cache prefix, delete the lockfile and "
              "re-initialize the iterator. Lockfile contents: ",
              contents);
        }
        // Create the lockfile and write some basic contents.
        std::unique_ptr<WritableFile> lockfile;
        TF_RETURN_IF_ERROR(
            dataset()->env_->NewWritableFile(lockfile_, &lockfile));
        TF_RETURN_IF_ERROR(lockfile->Append(
            strings::StrCat(kCreatedAt, ": ", EnvTime::NowSeconds())));

        // At this point we know that
        // 1. There is no conflicting checkpoint with prefix `filename_`.
        // 2. No concurrent session is trying to write a checkpoint to
        //    `filename_`.
        // So it is safe to create a BundleWriter here. Note that it is
        // unsafe to initialize the BundleWriter anywhere the above
        // conditions are not met since BundleWriter's constructor creates
        // new temp files which can delete the temp files created by a
        // BundleWriter in another session.
        writer_ = absl::make_unique<BundleWriter>(dataset()->env_, filename_);
        lockfile_created_ = true;
        return Status::OK();
      }

      Status Finish() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        iteration_completed_ = true;
        // Flush the current bundle.
        TF_RETURN_IF_ERROR(writer_->Finish());
        // Merge all the bundles.
        // Currently there are `shard_id_ + 1` bundles, one for each
        // checkpoint. Each bundle has prefix <filename>_<id> where `id` is an
        // integer starting at 0 and incremented by 1 for each new checkpoint.
        // We merge all these bundles into a bundle with prefix <filename> so
        // that the next call to `MakeIterator` can build a
        // `FileReaderIterator`.
        {
          std::vector<tstring> prefixes;
          prefixes.reserve(shard_id_ + 1);
          for (size_t i = 0; i <= shard_id_; ++i) {
            prefixes.emplace_back(
                strings::StrCat(dataset()->filename_, "_", i));
          }
          TF_RETURN_IF_ERROR(
              MergeBundles(dataset()->env_, prefixes, dataset()->filename_));
        }
        // Delete all lockfiles.
        for (size_t i = 0; i <= shard_id_; ++i) {
          TF_RETURN_IF_ERROR(dataset()->env_->DeleteFile(
              strings::StrCat(dataset()->filename_, "_", i, kLockFileSuffix)));
        }
        return Status::OK();
      }

      mutex mu_;
      size_t cur_index_ TF_GUARDED_BY(mu_);
      // Index of the current shard. This gets incremented whenever a new
      // cache shard is saved.
      size_t shard_id_ TF_GUARDED_BY(mu_);
      std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
      // The current prefix for the cache file. This is equal to
      // `StrCat(dataset()->filename_, "_", shard_id_)`.
      string filename_;
      std::unique_ptr<BundleWriter> writer_ TF_GUARDED_BY(mu_);
      string lockfile_ TF_GUARDED_BY(mu_);
      bool lockfile_created_ TF_GUARDED_BY(mu_);
      bool iteration_completed_ TF_GUARDED_BY(mu_);
    };  // FileWriterIterator

    class FileReaderIterator : public DatasetIterator<FileDatasetBase> {
     public:
      explicit FileReaderIterator(const Params& params)
          : DatasetIterator<FileDatasetBase>(params),
            cur_index_(0),
            reader_(dataset()->env_, dataset()->filename_),
            iterator_restored_(false) {}

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        *end_of_sequence = false;
        TF_RETURN_IF_ERROR(reader_.status());
        if (!reader_.Valid()) {
          *end_of_sequence = true;
          return Status::OK();
        }
        out_tensors->clear();
        out_tensors->resize(dataset()->num_tensors_);

        for (size_t i = 0; i < dataset()->num_tensors_; ++i) {
          // When the iterator is restored from the checkpoint, `reader_` is
          // already pointing at `key` so we do not need to skip the header
          // entry.
          if (!iterator_restored_) {
            reader_.Next();  // The first entry in the table is a header.
          } else {
            iterator_restored_ = false;
          }
          if (!reader_.Valid()) {
            out_tensors->clear();
            *end_of_sequence = true;
            return Status::OK();
          }
          StringPiece key = reader_.key();
          DCHECK_EQ(key, dataset()->FormatName(cur_index_, i));
          TF_RETURN_IF_ERROR(reader_.ReadCurrent(&(*out_tensors)[i]));
          TF_RETURN_IF_ERROR(reader_.status());
        }
        cur_index_++;
        return Status::OK();
      }

     protected:
      std::shared_ptr<model::Node> CreateNode(
          IteratorContext* ctx, model::Node::Args args) const override {
        return model::MakeKnownRatioNode(std::move(args),
                                         /*ratio=*/1);
      }

      Status SaveInternal(SerializationContext* ctx,
                          IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name(kCurIndex), cur_index_));
        return Status::OK();
      }

      Status RestoreInternal(
          IteratorContext* ctx,
          IteratorStateReader* iterator_state_reader) override {
        mutex_lock l(mu_);
        {
          // TODO(b/78048575): Update this when saving size_t tensors directly
          // is supported.
          int64 temp;
          TF_RETURN_IF_ERROR(
              iterator_state_reader->ReadScalar(full_name(kCurIndex), &temp));
          cur_index_ = static_cast<size_t>(temp);
          if (cur_index_ != temp) {
            return errors::Internal("Invalid value for cur_index ", temp);
          }
        }
        if (!reader_.Valid()) {
          return errors::Internal("Error initializing BundleReader.");
        }
        reader_.Seek(dataset()->FormatName(cur_index_, 0));
        iterator_restored_ = true;
        return Status::OK();
      }

     private:
      mutex mu_;
      size_t cur_index_ TF_GUARDED_BY(mu_);
      BundleReader reader_ TF_GUARDED_BY(mu_);
      bool iterator_restored_ TF_GUARDED_BY(mu_);
    };  // FileReaderIterator

    Status InitializeIterator(IteratorContext* ctx)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      // We intentionally use the same prefix for both `FileReaderIterator`
      // and `FileWriterIterator`. Since at any time there will be at most one
      // of them alive, there should be no conflicts. This allows both
      // iterators to use a common key for `cur_index`. We leverage this in
      // the corner case when this iterator is restored from an old checkpoint
      // in `write` mode and the cache has been completely flushed to disk
      // since then. In that case we simply build a `FileReaderIterator` and
      // seek to the `cur_index`.
      switch (mode_) {
        case Mode::read:
          iterator_ =
              absl::make_unique<FileReaderIterator>(FileReaderIterator::Params{
                  dataset(), strings::StrCat(prefix(), kImpl)});
          break;
        case Mode::write:
          iterator_ =
              absl::make_unique<FileWriterIterator>(FileWriterIterator::Params{
                  dataset(), strings::StrCat(prefix(), kImpl)});
      }
      TF_RETURN_IF_ERROR(iterator_->InitializeBase(ctx, this));
      return iterator_->Initialize(ctx);
    }

    mutex mu_;
    enum Mode { read, write };
    Mode mode_ TF_GUARDED_BY(mu_);
    std::unique_ptr<IteratorBase> iterator_ TF_GUARDED_BY(mu_);
  };  // FileIterator

  Env* const env_;
  const size_t num_tensors_;
  const size_t tensor_index_padding_size_;
  static constexpr size_t kMaxItems = 10000000;  // 10 million
  const size_t item_index_padding_size_;
  const string tensor_format_string_;
};  // FileDatasetBase

class CacheDatasetOp::FileDataset : public CacheDatasetOp::FileDatasetBase {
 public:
  using FileDatasetBase::FileDatasetBase;

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph));
    Node* filename = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(filename_, &filename));
    TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph, filename}, output));
    return Status::OK();
  }
};

class CacheDatasetOp::FileDatasetV2 : public CacheDatasetOp::FileDatasetBase {
 public:
  explicit FileDatasetV2(OpKernelContext* ctx, const DatasetBase* input,
                         string filename, Env* env,
                         const Tensor& resource_handle)
      : FileDatasetBase(ctx, input, std::move(filename), env),
        resource_handle_(resource_handle) {}

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
    Node* filename_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(filename_, &filename_node));
    Node* resource_handle_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddTensor(resource_handle_, &resource_handle_node));
    TF_RETURN_IF_ERROR(b->AddDataset(
        this, {input_node, filename_node, resource_handle_node}, output));
    return Status::OK();
  }

 private:
  const Tensor resource_handle_;
};

class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
 public:
  explicit MemoryDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
                             std::shared_ptr<MemoryCache> cache)
      : DatasetBase(DatasetContext(ctx)),
        input_(input),
        cache_(std::move(cache)) {
    input_->Ref();
  }

  ~MemoryDatasetBase() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    name_utils::IteratorPrefixParams params;
    params.dataset_prefix = kMemoryDatasetPrefix;
    return absl::make_unique<MemoryIterator>(
        MemoryIterator::Params{
            this, name_utils::IteratorPrefix(kDatasetType, prefix, params)},
        cache_.get());
  }

  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return input_->output_shapes();
  }

  string DebugString() const override {
    name_utils::DatasetDebugStringParams params;
    params.dataset_prefix = kMemoryDatasetPrefix;
    return name_utils::DatasetDebugString(kDatasetType, params);
  }

  int64 Cardinality() const override { return input_->Cardinality(); }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return Status::OK();
  }

  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }

 protected:
  class MemoryIterator : public DatasetIterator<MemoryDatasetBase> {
   public:
    explicit MemoryIterator(const Params& params, MemoryCache* cache)
        : DatasetIterator<MemoryDatasetBase>(params), cache_(cache) {}

    Status Initialize(IteratorContext* ctx) override {
      mutex_lock l(mu_);
      return InitializeIterator(ctx);
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      mutex_lock l(mu_);
      return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
    }

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeKnownRatioNode(std::move(args),
                                       /*ratio=*/1);
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      if (cache_->IsCompleted()) {
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name(kCacheCompleted), ""));
        TF_RETURN_IF_ERROR(
            WriteElementsToCheckpoint(writer, prefix(), cache_->data()));
      }
      return SaveInput(ctx, writer, iterator_);
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      iterator_.reset();
      cache_->Reset();
      if (reader->Contains(full_name(kCacheCompleted))) {
        std::vector<std::vector<Tensor>> temp_cache;
        TF_RETURN_IF_ERROR(
            ReadElementsFromCheckpoint(reader, prefix(), &temp_cache));
        cache_->Complete(std::move(temp_cache));
      }
      TF_RETURN_IF_ERROR(InitializeIterator(ctx));
      return RestoreInput(ctx, reader, iterator_);
    }

   private:
    class MemoryWriterIterator : public DatasetIterator<MemoryDatasetBase> {
     public:
      explicit MemoryWriterIterator(const Params& params, MemoryCache* cache)
          : DatasetIterator<MemoryDatasetBase>(params), cache_(cache) {}

      ~MemoryWriterIterator() override {
        mutex_lock l(mu_);
        if (!temp_cache_.empty() && !cache_->IsCompleted()) {
          LOG(WARNING) << kIncompleteCacheErrorMessage;
          cache_->Reset();
        }
      }

      Status Initialize(IteratorContext* ctx) override {
        return dataset()->input_->MakeIterator(ctx, this, prefix(),
                                               &input_impl_);
      }

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(
            input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
        if (*end_of_sequence) {
          if (!cache_->IsCompleted()) {
            VLOG(2) << "Finalizing the cache because EOF has been reached.";
            cache_->Complete(std::move(temp_cache_));
          }
          return Status::OK();
        }
        RecordBufferEnqueue(ctx, *out_tensors);
        temp_cache_.emplace_back(*out_tensors);
        if (temp_cache_.size() == dataset()->input_->Cardinality()) {
          VLOG(2) << "Finalizing the cache because its size matches the "
                     "expected input cardinality.";
          cache_->Complete(std::move(temp_cache_));
        }
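        // Note: if the input cardinality is unknown or infinite (reported as
        // a negative sentinel value), the size comparison above never
        // matches, and the cache is finalized only once end-of-sequence is
        // reached.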
        return Status::OK();
      }

     protected:
      std::shared_ptr<model::Node> CreateNode(
          IteratorContext* ctx, model::Node::Args args) const override {
        return model::MakeKnownRatioNode(std::move(args),
                                         /*ratio=*/1);
      }

      Status SaveInternal(SerializationContext* ctx,
                          IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        if (!cache_->IsCompleted()) {
          TF_RETURN_IF_ERROR(
              WriteElementsToCheckpoint(writer, prefix(), temp_cache_));
        }
        return SaveInput(ctx, writer, input_impl_);
      }

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        if (!reader->Contains(full_name(kCacheCompleted))) {
          TF_RETURN_IF_ERROR(
              ReadElementsFromCheckpoint(reader, prefix(), &temp_cache_));
        }
        return RestoreInput(ctx, reader, input_impl_);
      }

     private:
      mutex mu_;
      std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
      MemoryCache* const cache_ TF_GUARDED_BY(mu_);  // not owned.
      std::vector<std::vector<Tensor>> temp_cache_ TF_GUARDED_BY(mu_);
    };  // MemoryWriterIterator

    class MemoryReaderIterator : public DatasetIterator<MemoryDatasetBase> {
     public:
      explicit MemoryReaderIterator(const Params& params, MemoryCache* cache)
          : DatasetIterator<MemoryDatasetBase>(params),
            cache_(cache),
            index_(0) {}

      Status Initialize(IteratorContext* ctx) override {
        // The memory allocated for the cache is owned by the parent dataset,
        // but performance modeling uses the iterator abstraction, so we
        // record the memory allocated for the cache here. The caveat is that
        // this accounting is incorrect if there are concurrent instances of
        // this iterator.
        tf_shared_lock l(mu_);
        for (size_t i = 0; i < cache_->size(); ++i) {
          RecordBufferEnqueue(ctx, cache_->at(i));
        }
        return Status::OK();
      }

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        if (index_ < cache_->size()) {
          const std::vector<Tensor>& cache_tensors = cache_->at(index_);
          out_tensors->insert(out_tensors->begin(), cache_tensors.begin(),
                              cache_tensors.end());
          index_++;
          *end_of_sequence = false;
          return Status::OK();
        } else {
          *end_of_sequence = true;
          return Status::OK();
        }
      }

     protected:
      std::shared_ptr<model::Node> CreateNode(
          IteratorContext* ctx, model::Node::Args args) const override {
        return model::MakeKnownRatioNode(std::move(args),
                                         /*ratio=*/1);
      }

      Status SaveInternal(SerializationContext* ctx,
                          IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kIndex), index_));
        return Status::OK();
      }

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        {
          int64 temp;
          TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kIndex), &temp));
          index_ = static_cast<size_t>(temp);
        }
        return Status::OK();
      }

     private:
      mutex mu_;
      MemoryCache* const cache_ TF_GUARDED_BY(mu_);  // not owned.
      size_t index_ TF_GUARDED_BY(mu_);
    };  // MemoryReaderIterator

    Status InitializeIterator(IteratorContext* ctx)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      if (cache_->IsCompleted()) {
        iterator_ = absl::make_unique<MemoryReaderIterator>(
            MemoryReaderIterator::Params{dataset(),
                                         strings::StrCat(prefix(), kImpl)},
            cache_);
      } else {
        iterator_ = absl::make_unique<MemoryWriterIterator>(
            MemoryWriterIterator::Params{dataset(),
                                         strings::StrCat(prefix(), kImpl)},
            cache_);
      }
      TF_RETURN_IF_ERROR(iterator_->InitializeBase(ctx, this));
      return iterator_->Initialize(ctx);
    }

    mutex mu_;
    MemoryCache* cache_ TF_GUARDED_BY(mu_);  // not owned.
    std::unique_ptr<IteratorBase> iterator_ TF_GUARDED_BY(mu_);
  };  // MemoryIterator

  const DatasetBase* const input_;
  const std::shared_ptr<MemoryCache> cache_;
};  // MemoryDatasetBase

// This version of the memory dataset has exclusive ownership of the memory
// cache resource. It supports sharing of the cache across different
// iterations of the `repeat` transformation but not across different
// iterators.
class CacheDatasetOp::MemoryDataset : public CacheDatasetOp::MemoryDatasetBase {
 public:
  MemoryDataset(OpKernelContext* ctx, const DatasetBase* input,
                MemoryCacheManager* manager, ResourceHandle&& resource_handle)
      : MemoryDatasetBase(ctx, input, manager->get()),
        manager_(manager),
        resource_handle_(std::move(resource_handle)),
        resource_mgr_(ctx->resource_manager()) {}

  ~MemoryDataset() override {
    manager_->Unref();
    Status s = resource_mgr_->Delete<MemoryCacheManager>(
        resource_handle_.container(), resource_handle_.name());
    if (!s.ok()) {
      LOG(WARNING) << "Failed to delete cache resource: " << s.ToString();
    }
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
    Node* filename_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(tstring(""), &filename_node));
    TF_RETURN_IF_ERROR(
        b->AddDataset(this, {input_node, filename_node}, output));
    return Status::OK();
  }

 private:
  MemoryCacheManager* const manager_;  // Owned.
  const ResourceHandle resource_handle_;
  ResourceMgr* const resource_mgr_;  // Not owned.
};

// This version of the memory dataset has shared ownership of the memory cache
// resource. It supports sharing of the cache across different iterations of
// the `repeat` transformation and also across different iterators.
class CacheDatasetOp::MemoryDatasetV2
    : public CacheDatasetOp::MemoryDatasetBase {
 public:
  MemoryDatasetV2(OpKernelContext* ctx, const DatasetBase* input,
                  MemoryCacheManager* manager, ResourceHandle&& resource_handle,
                  bool owns_resource)
      : MemoryDatasetBase(ctx, input, manager->get()),
        manager_(manager),
        owns_resource_(owns_resource),
        resource_handle_(std::move(resource_handle)),
        resource_mgr_(ctx->resource_manager()) {}

  ~MemoryDatasetV2() override {
    manager_->Unref();
    if (owns_resource_) {
      Status s = resource_mgr_->Delete<MemoryCacheManager>(
          resource_handle_.container(), resource_handle_.name());
      if (!s.ok()) {
        LOG(WARNING) << "Failed to delete cache resource: " << s.ToString();
      }
    }
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
    Node* filename_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(tstring(""), &filename_node));
    Node* resource_handle_node = nullptr;
    Tensor handle(DT_RESOURCE, TensorShape({}));
    handle.scalar<ResourceHandle>()() = resource_handle_;
    TF_RETURN_IF_ERROR(b->AddTensor(handle, &resource_handle_node));
    TF_RETURN_IF_ERROR(b->AddDataset(
        this, {input_node, filename_node, resource_handle_node}, output));
    return Status::OK();
  }

 private:
  MemoryCacheManager* const manager_;  // Owned.
  const bool owns_resource_;
  const ResourceHandle resource_handle_;
  ResourceMgr* const resource_mgr_;  // Not owned.
};
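
// `kCacheDataset` names the V1 op; any other registered name (here,
// "CacheDatasetV2") selects the V2 behavior, which additionally receives a
// cache resource handle as its third input (see `HandleFromInput(ctx, 2)`
// and `ctx->input(2)` below).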
CacheDatasetOp::CacheDatasetOp(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx),
      op_version_(ctx->def().op() == kCacheDataset ? 1 : 2) {}

void CacheDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                 DatasetBase** output) {
  // Parse out the filename tensor.
  tstring filename;
  OP_REQUIRES_OK(ctx, ParseScalarArgument<tstring>(ctx, kFileName, &filename));
  if (filename.empty()) {
    static std::atomic<int64> resource_id_counter(0);
    const string& container = ctx->resource_manager()->default_container();
    auto name = strings::StrCat(ctx->op_kernel().name(), "/", kMemoryCache, "_",
                                resource_id_counter.fetch_add(1));
    if (op_version_ == 2) {
      bool owns_resource = false;
      MemoryCacheManager* manager = nullptr;
      auto handle = HandleFromInput(ctx, 2);
      Status s = ctx->resource_manager()->Lookup<MemoryCacheManager>(
          handle.container(), handle.name(), &manager);
      if (errors::IsNotFound(s)) {
        owns_resource = true;
        OP_REQUIRES_OK(
            ctx,
            ctx->resource_manager()->LookupOrCreate<MemoryCacheManager>(
                container, name, &manager, [](MemoryCacheManager** manager) {
                  *manager = new MemoryCacheManager();
                  return Status::OK();
                }));
        handle = MakeResourceHandle<MemoryCacheManager>(ctx, container, name);
      } else {
        OP_REQUIRES_OK(ctx, s);
      }
      // Ownership of `manager` is transferred to `MemoryDatasetV2`.
      *output = new MemoryDatasetV2(ctx, input, manager, std::move(handle),
                                    owns_resource);
    } else {
      MemoryCacheManager* manager;
      OP_REQUIRES_OK(
          ctx, ctx->resource_manager()->LookupOrCreate<MemoryCacheManager>(
                   container, name, &manager, [](MemoryCacheManager** manager) {
                     *manager = new MemoryCacheManager();
                     return Status::OK();
                   }));
      auto handle =
          MakeResourceHandle<MemoryCacheManager>(ctx, container, name);
      // Ownership of `manager` is transferred to `MemoryDataset`.
      *output = new MemoryDataset(ctx, input, manager, std::move(handle));
    }
  } else {
    if (op_version_ == 2) {
      *output =
          new FileDatasetV2(ctx, input, filename, ctx->env(), ctx->input(2));
    } else {
      *output = new FileDataset(ctx, input, filename, ctx->env());
    }
  }
}
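
// A minimal sketch of how these kernels are typically reached from Python
// (assuming the standard tf.data API; the cache path is illustrative):
//
//   ds = tf.data.Dataset.range(10).cache()          # empty filename ->
//                                                   # in-memory cache
//   ds = tf.data.Dataset.range(10).cache("/tmp/c")  # file-backed cache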

namespace {
REGISTER_KERNEL_BUILDER(Name("CacheDataset").Device(DEVICE_CPU),
                        CacheDatasetOp);
REGISTER_KERNEL_BUILDER(Name("CacheDatasetV2").Device(DEVICE_CPU),
                        CacheDatasetOp);
}  // namespace
}  // namespace data
}  // namespace tensorflow