• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/experimental/snapshot_dataset_op.h"
16 
17 #include <random>
18 
19 #include "absl/time/clock.h"
20 #include "tensorflow/core/common_runtime/dma_helper.h"
21 #include "tensorflow/core/framework/dataset.h"
22 #include "tensorflow/core/framework/op_kernel.h"
23 #include "tensorflow/core/framework/partial_tensor_shape.h"
24 #include "tensorflow/core/framework/stats_aggregator.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/framework/tensor.pb.h"  // NOLINT
27 #include "tensorflow/core/grappler/graph_view.h"
28 #include "tensorflow/core/kernels/data/dataset_utils.h"
29 #include "tensorflow/core/kernels/data/experimental/snapshot_util.h"
30 #include "tensorflow/core/kernels/data/hash_utils.h"
31 #include "tensorflow/core/lib/core/coding.h"
32 #include "tensorflow/core/lib/core/errors.h"
33 #include "tensorflow/core/lib/core/raw_coding.h"
34 #include "tensorflow/core/lib/gtl/cleanup.h"
35 #include "tensorflow/core/lib/hash/hash.h"
36 #include "tensorflow/core/lib/io/buffered_inputstream.h"
37 #include "tensorflow/core/lib/io/compression.h"
38 #include "tensorflow/core/lib/io/path.h"
39 #include "tensorflow/core/lib/io/random_inputstream.h"
40 #include "tensorflow/core/platform/errors.h"
41 #include "tensorflow/core/platform/file_system.h"
42 #include "tensorflow/core/platform/snappy.h"
43 #if !defined(IS_SLIM_BUILD)
44 #include "tensorflow/core/lib/io/snappy/snappy_inputbuffer.h"
45 #include "tensorflow/core/lib/io/snappy/snappy_outputbuffer.h"
46 #include "tensorflow/core/lib/io/zlib_compression_options.h"
47 #include "tensorflow/core/lib/io/zlib_inputstream.h"
48 #include "tensorflow/core/lib/io/zlib_outputbuffer.h"
49 #endif  // IS_SLIM_BUILD
50 #include "tensorflow/core/lib/random/random.h"
51 #include "tensorflow/core/lib/strings/base64.h"
52 #include "tensorflow/core/lib/strings/proto_serialization.h"
53 #include "tensorflow/core/lib/strings/strcat.h"
54 #include "tensorflow/core/lib/strings/stringprintf.h"
55 #include "tensorflow/core/platform/cord.h"
56 #include "tensorflow/core/platform/macros.h"
57 #include "tensorflow/core/platform/stringprintf.h"
58 #include "tensorflow/core/profiler/lib/traceme.h"
59 #include "tensorflow/core/protobuf/snapshot.pb.h"
60 #include "tensorflow/core/util/batch_util.h"
61 #include "tensorflow/core/util/ptr_util.h"
62 
63 namespace tensorflow {
64 namespace data {
65 namespace experimental {
66 
// Out-of-line definitions for the static constexpr constants declared in
// SnapshotDatasetV2Op. Required so the constants can be ODR-used (e.g. taken
// by reference) prior to C++17 inline-variable semantics.
/* static */ constexpr const char* const SnapshotDatasetV2Op::kDatasetType;
/* static */ constexpr const char* const SnapshotDatasetV2Op::kOutputTypes;
/* static */ constexpr const char* const SnapshotDatasetV2Op::kOutputShapes;
/* static */ constexpr const char* const SnapshotDatasetV2Op::kCompression;
/* static */ constexpr const char* const SnapshotDatasetV2Op::kReaderPrefix;
/* static */ constexpr const char* const SnapshotDatasetV2Op::kWriterPrefix;
/* static */ constexpr const char* const SnapshotDatasetV2Op::kHashValid;
/* static */ constexpr const char* const SnapshotDatasetV2Op::kHash;
/* static */ constexpr const char* const SnapshotDatasetV2Op::kCompressionAuto;
/* static */ constexpr const char* const SnapshotDatasetV2Op::kReaderFunc;
/* static */ constexpr const char* const SnapshotDatasetV2Op::kShardFunc;
/* static */ constexpr const char* const
    SnapshotDatasetV2Op::kReaderFuncOtherArgs;
/* static */ constexpr const char* const
    SnapshotDatasetV2Op::kShardFuncOtherArgs;
/* static */ constexpr const char* const
    SnapshotDatasetV2Op::kReaderFuncTarguments;
/* static */ constexpr const char* const
    SnapshotDatasetV2Op::kShardFuncTarguments;
/* static */ constexpr const int SnapshotDatasetV2Op::kFileFormatVersion;
87 
88 // ==== Snapshot Implementation ====
89 
90 /* The current snapshot on-disk layout is as follows:
91  *   /user/specified/path/
92  *     - graphhash1/
93  *       - snapshot.metadata  // metadata file
94  *       - run1/
95  *         - 00000000.shard/  // shard index
96  *           // new checkpoint files are created on all threads at once, either
97  *           // when a file gets too big, or when a TF checkpoint happens.
98  *           - 00000000.snapshot  // checkpoint file 0
99  *           - 00000001.snapshot  // checkpoint file 1
100  *           - ...
101  *         - 00000001.shard/
102  *           - 00000000.snapshot
103  *           - 00000001.snapshot
104  *           - ...
105  *         - 00000002.shard/
106  *           - 00000000.snapshot
107  *           - 00000001.snapshot
108  *           - ...
109  *           ...
110  *       - run2/
111  *           ...
112  *     - graphhash2/
113  *       ...
114  *     - graphhash3/
115  *       ...
116  */
117 
// The snapshot dataset. At iteration time it inspects the on-disk state under
// `path_` and (via its Iterator) either writes the input elements out as a
// new snapshot, reads a previously written snapshot back, or passes elements
// straight through from the input (see the layout comment above).
class SnapshotDatasetV2Op::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, const DatasetBase* input, uint64 hash,
          const std::string& path, const std::string& compression,
          const std::string& reader_prefix, const std::string& writer_prefix,
          std::unique_ptr<CapturedFunction> reader_func,
          std::unique_ptr<CapturedFunction> shard_func);

  ~Dataset() override;

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override;

  const DataTypeVector& output_dtypes() const override;

  const std::vector<PartialTensorShape>& output_shapes() const override;

  string DebugString() const override;

  int64 Cardinality() const override;

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override;

  Status CheckExternalState() const override;

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override;

 private:
  const DatasetBase* input_;  // Ref() taken in ctor, Unref() in dtor.
  const uint64 hash_;         // Graph hash; names the snapshot directory.
  const tstring path_;        // User-specified snapshot base path.
  const std::string compression_;
  // Prefixes joined onto the snapshot path when reading / writing
  // respectively (e.g. to direct reads and writes at different filesystems).
  const std::string reader_prefix_;
  const std::string writer_prefix_;

  // User-supplied functions: `reader_func_` transforms the nested dataset of
  // shard files into the dataset to read from; `shard_func_` maps an element
  // to the shard index it should be written to.
  std::unique_ptr<CapturedFunction> reader_func_;
  std::unique_ptr<CapturedFunction> shard_func_;

  class Iterator;
};
161 
// Top-level iterator. Determines (lazily, on first GetNext or on restore)
// whether this run should read, write, or pass through, then delegates all
// further work to the matching sub-iterator (Reader / Writer / Passthrough).
class SnapshotDatasetV2Op::Dataset::Iterator : public DatasetIterator<Dataset> {
 public:
  static constexpr const char* const kIteratorMode = "iterator_mode";
  static constexpr const char* const kIndex = "index";
  static constexpr const char* const kGraphHashDirectory =
      "graph_hash_directory";

  explicit Iterator(const Params& params);

  Status Initialize(IteratorContext* ctx) override;

  Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) override;

 protected:
  Status SaveInternal(SerializationContext* ctx,
                      IteratorStateWriter* writer) override;

  Status RestoreInternal(IteratorContext* ctx,
                         IteratorStateReader* reader) override;

 private:
  // Builds `iterator_` (and `mode_`). If `reader` is non-null, state is
  // restored from a checkpoint; otherwise mode is determined from disk.
  Status InitializeIterator(IteratorContext* ctx, IteratorStateReader* reader);

  // Number of elements produced so far; used as the start index when a
  // Reader is (re)built.
  int64 index_ TF_GUARDED_BY(mu_);
  // The active sub-iterator; null until InitializeIterator runs.
  std::unique_ptr<IteratorBase> iterator_ TF_GUARDED_BY(mu_);
  snapshot_util::Mode mode_ TF_GUARDED_BY(mu_);
  // Hash directory computed from the dataset's path and graph hash.
  const std::string hash_dir_;

  mutex mu_;

  class Reader;
  class Writer;
  class Passthrough;
};
197 
198 class SnapshotDatasetV2Op::Dataset::Iterator::Reader
199     : public DatasetIterator<Dataset> {
200  public:
201   static constexpr const char* const kIteratorName = "Reader";
202 
203   explicit Reader(const Params& params, int64 start_index);
204 
205   ~Reader() override;
206 
207   Status Initialize(IteratorContext* ctx) override;
208 
209   Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
210                          bool* end_of_sequence) override;
211 
212  protected:
213   Status SaveInternal(SerializationContext* ctx,
214                       IteratorStateWriter* writer) override;
215 
216   Status RestoreInternal(IteratorContext* ctx,
217                          IteratorStateReader* reader) override;
218 
219  private:
220   const int64 start_index_;
221 
222   mutex mu_;
223 
224   std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
225 
226   DatasetBase* input_ TF_GUARDED_BY(mu_);
227 
228   std::unique_ptr<InstantiatedCapturedFunction> instantiated_reader_func_
229       TF_GUARDED_BY(mu_);
230 };
231 
// Sub-iterator used when no usable snapshot exists: it pulls elements from
// the input, hands each to a per-shard background AsyncWriter (shard chosen
// by the user `shard_func`), and forwards the element to the caller.
class SnapshotDatasetV2Op::Dataset::Iterator::Writer
    : public DatasetIterator<Dataset> {
 public:
  static constexpr const char* const kIteratorName = "Writer";
  static constexpr const char* const kRunId = "run_id";
  static constexpr const char* const kCurrentCheckpointId =
      "current_checkpoint_id";

  explicit Writer(const Params& params);

  ~Writer() override;

  Status Initialize(IteratorContext* ctx) override;

  Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) override;

 protected:
  Status SaveInternal(SerializationContext* ctx,
                      IteratorStateWriter* writer) override;

  Status RestoreInternal(IteratorContext* ctx,
                         IteratorStateReader* reader) override;

 private:
  // Runs `shard_func` on `tensors` and validates it returned a scalar int64.
  Status GetShardIndex(IteratorContext* ctx, const std::vector<Tensor>& tensors,
                       int64* shard_index) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Writes (or rewrites) the snapshot.metadata file; `finalized` marks the
  // snapshot complete.
  Status WriteMetadataFile(Env* env, bool finalized)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Signals end-of-sequence to all shard writers and clears them; sets
  // `writers_closed_` to `mark_closed`.
  void SignalEOF(bool mark_closed) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  mutex mu_;
  // Separate mutex so the AsyncWriter completion callback never contends
  // with (or deadlocks against) holders of `mu_`.
  mutex writer_status_mu_;
  std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);

  // One background writer per shard index.
  absl::flat_hash_map<int64, std::unique_ptr<snapshot_util::AsyncWriter>>
      writers_ TF_GUARDED_BY(mu_);
  // First error reported by any shard writer, if any.
  Status writer_status_ TF_GUARDED_BY(writer_status_mu_);
  bool writers_closed_ TF_GUARDED_BY(mu_);

  uint64 run_id_ TF_GUARDED_BY(mu_);
  tstring run_dir_ TF_GUARDED_BY(mu_);

  // Stores the ID of the current checkpoint .snapshot file being read. See top
  // of this file for the directory layout.
  uint64 current_checkpoint_id_ TF_GUARDED_BY(mu_);

  std::unique_ptr<InstantiatedCapturedFunction> instantiated_shard_func_
      TF_GUARDED_BY(mu_);
};
284 
// Sub-iterator used when the snapshot should be neither read nor written:
// simply forwards elements from the input dataset unchanged.
class SnapshotDatasetV2Op::Dataset::Iterator::Passthrough
    : public DatasetIterator<Dataset> {
 public:
  static constexpr const char* const kIteratorName = "Passthrough";

  explicit Passthrough(const Params& params);

  Status Initialize(IteratorContext* ctx) override;

  Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) override;

 protected:
  Status SaveInternal(SerializationContext* ctx,
                      IteratorStateWriter* writer) override;

  Status RestoreInternal(IteratorContext* ctx,
                         IteratorStateReader* reader) override;

 private:
  std::unique_ptr<IteratorBase> input_impl_;
};
307 
// Constructor: captures all op attributes/inputs and takes a reference on the
// input dataset for the lifetime of this dataset.
SnapshotDatasetV2Op::Dataset::Dataset(
    OpKernelContext* ctx, const DatasetBase* input, uint64 hash,
    const std::string& path, const std::string& compression,
    const std::string& reader_prefix, const std::string& writer_prefix,
    std::unique_ptr<CapturedFunction> reader_func,
    std::unique_ptr<CapturedFunction> shard_func)
    : DatasetBase(DatasetContext(ctx)),
      input_(input),
      hash_(hash),
      path_(path),
      compression_(compression),
      reader_prefix_(reader_prefix),
      writer_prefix_(writer_prefix),
      reader_func_(std::move(reader_func)),
      shard_func_(std::move(shard_func)) {
  // Matching Unref() is in the destructor.
  input_->Ref();
}
325 
~Dataset()326 SnapshotDatasetV2Op::Dataset::~Dataset() { input_->Unref(); }
327 
328 std::unique_ptr<IteratorBase>
MakeIteratorInternal(const string & prefix) const329 SnapshotDatasetV2Op::Dataset::MakeIteratorInternal(const string& prefix) const {
330   return absl::make_unique<Iterator>(
331       Iterator::Params{this, absl::StrCat(prefix, "::Snapshot")});
332 }
333 
output_dtypes() const334 const DataTypeVector& SnapshotDatasetV2Op::Dataset::output_dtypes() const {
335   return input_->output_dtypes();
336 }
337 
338 const std::vector<PartialTensorShape>&
output_shapes() const339 SnapshotDatasetV2Op::Dataset::output_shapes() const {
340   return input_->output_shapes();
341 }
342 
DebugString() const343 string SnapshotDatasetV2Op::Dataset::DebugString() const {
344   return name_utils::DatasetDebugString(kDatasetType);
345 }
346 
Cardinality() const347 int64 SnapshotDatasetV2Op::Dataset::Cardinality() const {
348   return input_->Cardinality();
349 }
350 
InputDatasets(std::vector<const DatasetBase * > * inputs) const351 Status SnapshotDatasetV2Op::Dataset::InputDatasets(
352     std::vector<const DatasetBase*>* inputs) const {
353   inputs->push_back(input_);
354   return Status::OK();
355 }
356 
CheckExternalState() const357 Status SnapshotDatasetV2Op::Dataset::CheckExternalState() const {
358   return input_->CheckExternalState();
359 }
360 
// Serializes this dataset back into a GraphDef node: the input dataset and
// path become inputs, the captured function arguments become list inputs, and
// everything else is encoded as attributes.
Status SnapshotDatasetV2Op::Dataset::AsGraphDefInternal(
    SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const {
  Node* input_graph_node = nullptr;
  TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));

  Node* path = nullptr;
  TF_RETURN_IF_ERROR(b->AddScalar(path_, &path));

  // Serialize the captured arguments of reader_func / shard_func; their types
  // are recorded separately as the T*arguments attributes below.
  std::vector<Node*> reader_func_other_args;
  DataTypeVector reader_func_other_args_types;
  TF_RETURN_IF_ERROR(reader_func_->AddToGraph(ctx, b, &reader_func_other_args,
                                              &reader_func_other_args_types));

  std::vector<Node*> shard_func_other_args;
  DataTypeVector shard_func_other_args_types;
  TF_RETURN_IF_ERROR(shard_func_->AddToGraph(ctx, b, &shard_func_other_args,
                                             &shard_func_other_args_types));

  AttrValue compression_attr;
  b->BuildAttrValue(compression_, &compression_attr);

  AttrValue reader_prefix_attr;
  b->BuildAttrValue(reader_prefix_, &reader_prefix_attr);

  AttrValue writer_prefix_attr;
  b->BuildAttrValue(writer_prefix_, &writer_prefix_attr);

  // The hash is always valid at this point (it was computed or supplied when
  // the dataset was constructed), so hash_valid is hard-coded to true.
  AttrValue hash_valid_attr;
  b->BuildAttrValue(true, &hash_valid_attr);

  // AttrValue carries signed ints only, hence the cast; the reader casts back.
  AttrValue hash_attr;
  b->BuildAttrValue(static_cast<int64>(hash_), &hash_attr);

  AttrValue reader_func_attr;
  b->BuildAttrValue(reader_func_->func(), &reader_func_attr);

  AttrValue shard_func_attr;
  b->BuildAttrValue(shard_func_->func(), &shard_func_attr);

  AttrValue reader_func_arguments_types_attr;
  b->BuildAttrValue(reader_func_other_args_types,
                    &reader_func_arguments_types_attr);

  AttrValue shard_func_arguments_types_attr;
  b->BuildAttrValue(shard_func_other_args_types,
                    &shard_func_arguments_types_attr);

  return b->AddDataset(
      this,
      /*inputs=*/
      {std::make_pair(0, input_graph_node), std::make_pair(1, path)},
      /*list_inputs=*/
      {std::make_pair(2, reader_func_other_args),
       std::make_pair(3, shard_func_other_args)},
      /*attrs=*/
      {{kCompression, compression_attr},
       {kReaderPrefix, reader_prefix_attr},
       {kWriterPrefix, writer_prefix_attr},
       {kHashValid, hash_valid_attr},
       {kHash, hash_attr},
       {kReaderFunc, reader_func_attr},
       {kShardFunc, shard_func_attr},
       {kReaderFuncTarguments, reader_func_arguments_types_attr},
       {kShardFuncTarguments, shard_func_arguments_types_attr}},
      output);
}
427 
// Constructor: precomputes the hash directory (path + graph hash); the actual
// mode decision is deferred to InitializeIterator.
SnapshotDatasetV2Op::Dataset::Iterator::Iterator(const Params& params)
    : DatasetIterator<Dataset>(params),
      index_(0),
      hash_dir_(
          snapshot_util::HashDirectory(dataset()->path_, dataset()->hash_)) {}
433 
Initialize(IteratorContext * ctx)434 Status SnapshotDatasetV2Op::Dataset::Iterator::Initialize(
435     IteratorContext* ctx) {
436   return ctx->env()->RecursivelyCreateDir(
437       io::JoinPath(dataset()->writer_prefix_, hash_dir_));
438 }
439 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)440 Status SnapshotDatasetV2Op::Dataset::Iterator::SaveInternal(
441     SerializationContext* ctx, IteratorStateWriter* writer) {
442   mutex_lock l(mu_);
443   if (iterator_ != nullptr) {
444     TF_RETURN_IF_ERROR(SaveInput(ctx, writer, iterator_));
445     TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kIteratorMode),
446                                            static_cast<int64>(mode_)));
447     TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kIndex), index_));
448     TF_RETURN_IF_ERROR(
449         writer->WriteScalar(full_name(kGraphHashDirectory), hash_dir_));
450   }
451 
452   return Status::OK();
453 }
454 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)455 Status SnapshotDatasetV2Op::Dataset::Iterator::RestoreInternal(
456     IteratorContext* ctx, IteratorStateReader* reader) {
457   mutex_lock l(mu_);
458 
459   if (reader->Contains(full_name(kIteratorMode))) {
460     TF_RETURN_IF_ERROR(InitializeIterator(ctx, reader));
461     return RestoreInput(ctx, reader, iterator_);
462   }
463 
464   return Status::OK();
465 }
466 
// Forwards GetNext to the active sub-iterator, building it lazily on the
// first call. `index_` counts GetNext calls so a Reader rebuilt after a
// checkpoint can resume at the right element.
Status SnapshotDatasetV2Op::Dataset::Iterator::GetNextInternal(
    IteratorContext* ctx, std::vector<Tensor>* out_tensors,
    bool* end_of_sequence) {
  mutex_lock l(mu_);
  if (iterator_ == nullptr) {
    TF_RETURN_IF_ERROR(InitializeIterator(ctx, nullptr));
  }
  // NOTE(review): incremented before the delegate call, so this also counts
  // the final end-of-sequence call — presumably acceptable for resume
  // purposes; confirm against Reader's start_index semantics.
  index_++;
  return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
}
477 
// Chooses the iteration mode and constructs the matching sub-iterator.
// With a checkpoint (`reader` non-null) the mode/index are restored and the
// hash directory is validated; without one the mode is determined from the
// on-disk metadata file.
Status SnapshotDatasetV2Op::Dataset::Iterator::InitializeIterator(
    IteratorContext* ctx, IteratorStateReader* reader)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  if (reader != nullptr) {
    // Check whether the computed hash directory is the same.
    tstring hash_dir;
    TF_RETURN_IF_ERROR(
        reader->ReadScalar(full_name(kGraphHashDirectory), &hash_dir));
    if (hash_dir != hash_dir_) {
      return errors::DataLoss(
          "Dataset has changed while restoring from the checkpoint. Old hash "
          "directory: ",
          hash_dir, "; new hash directory: ", hash_dir_);
    }

    // The metadata file must still exist for the restored mode to be usable.
    experimental::SnapshotMetadataRecord metadata;
    bool file_exists;
    TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(
        ctx->env(), io::JoinPath(dataset()->reader_prefix_, hash_dir_),
        &metadata, &file_exists));
    if (!file_exists) {
      return errors::DataLoss("Snapshot metadata file in ", hash_dir_,
                              " does not exist any more.");
    }

    int64 iterator_mode;
    TF_RETURN_IF_ERROR(
        reader->ReadScalar(full_name(kIteratorMode), &iterator_mode));
    mode_ = snapshot_util::Mode(iterator_mode);

    TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kIndex), &index_));
  } else {
    experimental::SnapshotMetadataRecord metadata;
    bool file_exists;
    TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(
        ctx->env(), io::JoinPath(dataset()->reader_prefix_, hash_dir_),
        &metadata, &file_exists));

    // `pending_snapshot_expiry_seconds` is a legacy option where we would not
    // write snapshots that we think were still on-going. We decided that this
    // would not be necessary as a feature for SnapshotV2, and we would always
    // write a new snapshot regardless of whether someone else is currently
    // writing one. Setting this to 0 ensures that all previous snapshots
    // will be ignored and we will proceed to writing.
    TF_RETURN_IF_ERROR(snapshot_util::DetermineOpState(
        /*mode_string=*/"", file_exists, &metadata,
        /*pending_snapshot_expiry_seconds=*/0, &mode_));
  }

  // Construct the sub-iterator matching the chosen mode.
  switch (mode_) {
    case snapshot_util::READER:
      iterator_ = absl::make_unique<Reader>(
          Reader::Params{dataset(),
                         absl::StrCat(prefix(), Reader::kIteratorName)},
          index_);
      break;
    case snapshot_util::WRITER:
      iterator_ = absl::make_unique<Writer>(Writer::Params{
          dataset(), absl::StrCat(prefix(), Writer::kIteratorName)});
      break;
    case snapshot_util::PASSTHROUGH:
      iterator_ = absl::make_unique<Passthrough>(Passthrough::Params{
          dataset(), absl::StrCat(prefix(), Passthrough::kIteratorName)});
      break;
  }
  TF_RETURN_IF_ERROR(iterator_->InitializeBase(ctx, this));
  return iterator_->Initialize(ctx);
}
546 
// Constructor: records the element index to resume reading from; the heavy
// setup happens in Initialize().
SnapshotDatasetV2Op::Dataset::Iterator::Reader::Reader(const Params& params,
                                                       int64 start_index)
    : DatasetIterator<Dataset>(params), start_index_(start_index) {}
550 
~Reader()551 SnapshotDatasetV2Op::Dataset::Iterator::Reader::~Reader() { input_->Unref(); }
552 
// Builds the read pipeline: lists the snapshot's shard directories, wraps
// them in a nested dataset of snapshot-file datasets, runs the user
// `reader_func` over it, and creates an iterator over the result.
Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize(
    IteratorContext* ctx) {
  mutex_lock l(mu_);

  TF_RETURN_IF_ERROR(
      dataset()->reader_func_->Instantiate(ctx, &instantiated_reader_func_));

  auto hash_dir = snapshot_util::HashDirectory(
      io::JoinPath(dataset()->reader_prefix_, dataset()->path_),
      dataset()->hash_);
  bool metadata_file_exists;
  experimental::SnapshotMetadataRecord metadata;
  TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(
      ctx->env(), hash_dir, &metadata, &metadata_file_exists));

  // The run to read is the one recorded in the metadata file.
  auto run_dir = snapshot_util::RunDirectory(hash_dir, metadata.run_id());

  // Collect all "<N>.shard" directories of the run; sorted so shard order is
  // deterministic.
  std::vector<std::string> snapshot_shard_dirs;
  TF_RETURN_IF_ERROR(ctx->env()->GetMatchingPaths(
      io::JoinPath(
          run_dir,
          strings::Printf("%s%s", "*", snapshot_util::kShardDirectorySuffix)),
      &snapshot_shard_dirs));
  std::sort(snapshot_shard_dirs.begin(), snapshot_shard_dirs.end());

  // Nested dataset: one inner dataset per shard, each starting at
  // start_index_ so restores resume at the right element.
  DatasetBase* dataset_of_snapshot_files;
  TF_RETURN_IF_ERROR(snapshot_util::Reader::MakeNestedDataset(
      ctx->env(), snapshot_shard_dirs, dataset()->compression_,
      metadata.version(), dataset()->output_dtypes(),
      dataset()->output_shapes(), start_index_, &dataset_of_snapshot_files));

  Tensor input_dataset_tensor(DT_VARIANT, TensorShape({}));
  TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(dataset_of_snapshot_files,
                                                 &input_dataset_tensor));

  std::vector<Tensor> reader_input;
  std::vector<Tensor> reader_output;
  reader_input.push_back(std::move(input_dataset_tensor));

  // NOTE: We intentionally ignore resource modeling outside GetNext().
  TF_RETURN_IF_ERROR(instantiated_reader_func_->Run(
      ctx, std::move(reader_input), &reader_output, /*node=*/nullptr));
  if (reader_output.size() != 1) {
    return errors::InvalidArgument(
        "reader_func returns more than one argument.");
  }
  TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(reader_output[0], &input_));

  // We need to take a reference here as we will use the input_ and
  // its iterator.
  input_->Ref();

  return input_->MakeIterator(ctx, this, prefix(), &input_impl_);
}
607 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)608 Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::GetNextInternal(
609     IteratorContext* ctx, std::vector<Tensor>* out_tensors,
610     bool* end_of_sequence) {
611   mutex_lock l(mu_);
612   return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
613 }
614 
615 // We do not need to checkpoint the reader as we are rebuilding the reader
616 // datasets from information that is already saved by the main iterator.
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)617 Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::SaveInternal(
618     SerializationContext* ctx, IteratorStateWriter* writer) {
619   return Status::OK();
620 }
621 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)622 Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::RestoreInternal(
623     IteratorContext* ctx, IteratorStateReader* reader) {
624   return Status::OK();
625 }
626 
// Constructor: run id/dir and checkpoint id start at zero; the real run
// directory is created lazily on the first GetNext (after any restore).
SnapshotDatasetV2Op::Dataset::Iterator::Writer::Writer(const Params& params)
    : DatasetIterator<Dataset>(params),
      writers_closed_(false),
      run_id_(0),
      current_checkpoint_id_(0) {}
632 
~Writer()633 SnapshotDatasetV2Op::Dataset::Iterator::Writer::~Writer() {
634   mutex_lock l(mu_);
635   SignalEOF(true);
636 }
637 
SignalEOF(bool mark_closed)638 void SnapshotDatasetV2Op::Dataset::Iterator::Writer::SignalEOF(bool mark_closed)
639     TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
640   if (!writers_closed_) {
641     // Push the end of sequence signal to each of the threads to close files.
642     for (auto& writer : writers_) {
643       writer.second->SignalEOF();
644     }
645 
646     writers_.clear();
647     writers_closed_ = mark_closed;
648   }
649 }
650 
WriteMetadataFile(Env * env,bool finalized)651 Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::WriteMetadataFile(
652     Env* env, bool finalized) {
653   DCHECK(!run_dir_.empty());
654 
655   experimental::SnapshotMetadataRecord metadata;
656   metadata.set_creation_timestamp(EnvTime::NowMicros());
657   metadata.set_graph_hash(strings::StrCat(dataset()->hash_));
658   metadata.set_run_id(strings::StrCat(run_id_));
659   metadata.set_version(kFileFormatVersion);
660   for (const auto& output_dtype : dataset()->output_dtypes()) {
661     metadata.add_dtype(output_dtype);
662   }
663   metadata.set_finalized(finalized);
664   tstring hash_directory = io::JoinPath(
665       dataset()->writer_prefix_,
666       snapshot_util::HashDirectory(dataset()->path_, dataset()->hash_));
667 
668   return snapshot_util::WriteMetadataFile(env, hash_directory, &metadata);
669 }
670 
Initialize(IteratorContext * ctx)671 Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::Initialize(
672     IteratorContext* ctx) {
673   mutex_lock l(mu_);
674   TF_RETURN_IF_ERROR(
675       dataset()->shard_func_->Instantiate(ctx, &instantiated_shard_func_));
676 
677   return dataset()->input_->MakeIterator(
678       ctx, this, strings::StrCat(prefix(), "::WriterIterator"), &input_impl_);
679 }
680 
GetShardIndex(IteratorContext * ctx,const std::vector<Tensor> & tensors,int64 * shard_index)681 Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::GetShardIndex(
682     IteratorContext* ctx, const std::vector<Tensor>& tensors,
683     int64* shard_index) {
684   std::vector<Tensor> output_tensors;
685 
686   // Run the shard function
687   TF_RETURN_IF_ERROR(instantiated_shard_func_->RunWithBorrowedArgs(
688       ctx, tensors, &output_tensors, model_node()));
689 
690   if (output_tensors.size() != 1 || output_tensors[0].dtype() != DT_INT64 ||
691       output_tensors[0].NumElements() != 1) {
692     return errors::InvalidArgument("`shard_func` must return a scalar int64.");
693   }
694 
695   // Create writable files if we see an index bigger than our current files.
696   *shard_index = output_tensors[0].flat<int64>()(0);
697   return Status::OK();
698 }
699 
// Pulls one element from the input, routes it to the shard chosen by
// `shard_func` (spawning an AsyncWriter for a shard on first use), and
// returns the element to the caller. On end-of-input the metadata file is
// finalized. The actual Write happens outside the lock so slow disk writes
// do not serialize element production.
Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::GetNextInternal(
    IteratorContext* ctx, std::vector<Tensor>* out_tensors,
    bool* end_of_sequence) {
  *end_of_sequence = false;
  snapshot_util::AsyncWriter* current_writer;

  {
    std::vector<Tensor> output_tensors;
    mutex_lock l(mu_);

    // We initialize late here because restoring from checkpoint comes after
    // the Initialize call. We cannot initialize within Initialize() because
    // we cannot determine whether we should overwrite an existing metadata
    // file or not before `RestoreInternal` is potentially called.
    if (run_dir_.empty()) {
      run_id_ = random::New64();

      // Creates the run directory.
      run_dir_ = snapshot_util::RunDirectory(
          snapshot_util::HashDirectory(
              io::JoinPath(dataset()->writer_prefix_, dataset()->path_),
              dataset()->hash_),
          run_id_);
      TF_RETURN_IF_ERROR(ctx->env()->RecursivelyCreateDir(run_dir_));
      TF_RETURN_IF_ERROR(WriteMetadataFile(ctx->env(), /*finalized=*/false));
    }

    // Writers have either encountered an error or are closed.
    {
      mutex_lock wsl(writer_status_mu_);
      if (!writer_status_.ok() || writers_closed_) {
        *end_of_sequence = true;
        return writer_status_;
      }
    }

    TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, out_tensors, end_of_sequence));

    // Finalize metadata file when we are at the end of the iterator.
    if (*end_of_sequence) {
      SignalEOF(/*mark_closed=*/true);
      // Surface any error a shard writer hit while draining.
      {
        mutex_lock wsl(writer_status_mu_);
        TF_RETURN_IF_ERROR(writer_status_);
      }
      return WriteMetadataFile(ctx->env(), /*finalized=*/true);
    }

    int64 shard_index = 0;
    TF_RETURN_IF_ERROR(GetShardIndex(ctx, *out_tensors, &shard_index));

    // If the index does not exist, we will start a new thread.
    if (writers_.count(shard_index) == 0) {
      auto snapshot_shard_directory =
          snapshot_util::ShardDirectory(run_dir_, shard_index);
      auto writer = std::make_unique<snapshot_util::AsyncWriter>(
          ctx->env(), shard_index, snapshot_shard_directory,
          current_checkpoint_id_, dataset()->compression_, kFileFormatVersion,
          dataset()->output_dtypes(), [this](Status s) {
            // Completion callback runs on the writer thread; record the first
            // failure under writer_status_mu_ (not mu_) to avoid deadlock.
            if (!s.ok()) {
              LOG(ERROR) << "AsyncWriter in snapshot writer failed: " << s;
              mutex_lock l(writer_status_mu_);
              writer_status_ = s;
            }
          });
      writers_.insert({shard_index, std::move(writer)});
    }
    current_writer = writers_[shard_index].get();
  }

  // Write outside `mu_`; AsyncWriter queues the element asynchronously.
  current_writer->Write(*out_tensors);
  return Status::OK();
}
773 
// Checkpoints the writer iterator's state.
//
// Persists `run_id_` and `current_checkpoint_id_` (stored as signed int64
// scalars), flushes the per-shard async writers without marking the snapshot
// closed, and then saves the input iterator. The checkpoint id is bumped so
// that any writers created after this point (e.g. following a restore) are
// constructed with a new checkpoint id, which AsyncWriter receives at
// creation time.
Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::SaveInternal(
    SerializationContext* ctx, IteratorStateWriter* writer) {
  mutex_lock l(mu_);
  TF_RETURN_IF_ERROR(
      writer->WriteScalar(full_name(kRunId), static_cast<int64>(run_id_)));
  TF_RETURN_IF_ERROR(
      writer->WriteScalar(full_name(kCurrentCheckpointId),
                          static_cast<int64>(current_checkpoint_id_)));
  // Signal EOF so the writers flush their files, but keep this iterator
  // usable (mark_closed=false) since this is only a checkpoint, not the end
  // of the stream.
  SignalEOF(/*mark_closed=*/false);
  writers_.clear();
  current_checkpoint_id_++;
  return SaveInput(ctx, writer, input_impl_);
}
787 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)788 Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::RestoreInternal(
789     IteratorContext* ctx, IteratorStateReader* reader) {
790   mutex_lock l(mu_);
791   int64 run_id_signed;
792   int64 current_checkpoint_id;
793 
794   TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kRunId), &run_id_signed));
795   TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCurrentCheckpointId),
796                                         &current_checkpoint_id));
797 
798   run_id_ = static_cast<uint64>(run_id_signed);
799   run_dir_ = snapshot_util::RunDirectory(
800       snapshot_util::HashDirectory(
801           io::JoinPath(dataset()->writer_prefix_, dataset()->path_),
802           dataset()->hash_),
803       run_id_);
804   current_checkpoint_id_ = static_cast<uint64>(current_checkpoint_id);
805 
806   return RestoreInput(ctx, reader, input_impl_);
807 }
808 
// Passthrough iterator: every call is forwarded verbatim to the input
// dataset's iterator; no snapshot files are read or written.
SnapshotDatasetV2Op::Dataset::Iterator::Passthrough::Passthrough(
    const Params& params)
    : DatasetIterator<Dataset>(params) {}
812 
// Creates the underlying input iterator that all subsequent calls delegate
// to.
Status SnapshotDatasetV2Op::Dataset::Iterator::Passthrough::Initialize(
    IteratorContext* ctx) {
  return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
}
817 
// Forwards GetNext straight to the input iterator.
Status SnapshotDatasetV2Op::Dataset::Iterator::Passthrough::GetNextInternal(
    IteratorContext* ctx, std::vector<Tensor>* out_tensors,
    bool* end_of_sequence) {
  return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
}
823 
// Checkpointing only needs to save the wrapped input iterator's state;
// passthrough has no state of its own.
Status SnapshotDatasetV2Op::Dataset::Iterator::Passthrough::SaveInternal(
    SerializationContext* ctx, IteratorStateWriter* writer) {
  return SaveInput(ctx, writer, input_impl_);
}
828 
// Restoring likewise only involves the wrapped input iterator.
Status SnapshotDatasetV2Op::Dataset::Iterator::Passthrough::RestoreInternal(
    IteratorContext* ctx, IteratorStateReader* reader) {
  return RestoreInput(ctx, reader, input_impl_);
}
833 
SnapshotDatasetV2Op(OpKernelConstruction * ctx)834 SnapshotDatasetV2Op::SnapshotDatasetV2Op(OpKernelConstruction* ctx)
835     : UnaryDatasetOpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
836   FunctionMetadata::Params reader_params;
837   FunctionMetadata::Params shard_params;
838 
839   OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
840   OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
841   OP_REQUIRES_OK(ctx, ctx->GetAttr(kCompression, &compression_));
842 
843   if (ctx->HasAttr(kReaderPrefix)) {
844     OP_REQUIRES_OK(ctx, ctx->GetAttr(kReaderPrefix, &reader_prefix_));
845   }
846 
847   if (ctx->HasAttr(kWriterPrefix)) {
848     OP_REQUIRES_OK(ctx, ctx->GetAttr(kWriterPrefix, &writer_prefix_));
849   }
850   OP_REQUIRES_OK(ctx, ctx->GetAttr(kHashValid, &hash_valid_));
851   int64 hash;
852   OP_REQUIRES_OK(ctx, ctx->GetAttr(kHash, &hash));
853   hash_ = static_cast<uint64>(hash);
854 
855   OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kReaderFunc, reader_params,
856                                                &reader_func_metadata_));
857   OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kShardFunc, shard_params,
858                                                &shard_func_metadata_));
859 }
860 
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)861 void SnapshotDatasetV2Op::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
862                                       DatasetBase** output) {
863   tstring path;
864   OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "path", &path));
865 
866   std::string compression = compression_ == kCompressionAuto
867                                 ? io::compression::kSnappy
868                                 : compression_;
869   uint64 hash;
870   if (hash_valid_) {
871     hash = hash_;
872   } else {
873     // Computes the hash of the preceding items in the graph.
874     GraphDef graph_def;
875     SerializationContext::Params params;
876     std::vector<std::pair<string, Tensor>> input_list;
877     params.input_list = &input_list;
878     params.external_state_policy =
879         SerializationContext::ExternalStatePolicy::kIgnore;
880     OP_REQUIRES_OK(
881         ctx, AsGraphDef(ctx, input, SerializationContext(params), &graph_def));
882     OP_REQUIRES_OK(ctx, HashGraph(graph_def, &hash));
883     // Different compression modes should result in different graph hashes.
884     hash = Hash64Combine(hash, Hash64(compression));
885   }
886 
887   std::unique_ptr<CapturedFunction> reader_func;
888   OP_REQUIRES_OK(ctx,
889                  CapturedFunction::Create(ctx, reader_func_metadata_,
890                                           kReaderFuncOtherArgs, &reader_func));
891   std::unique_ptr<CapturedFunction> shard_func;
892   OP_REQUIRES_OK(ctx,
893                  CapturedFunction::Create(ctx, shard_func_metadata_,
894                                           kShardFuncOtherArgs, &shard_func));
895 
896   *output = new SnapshotDatasetV2Op::Dataset(
897       ctx, input, hash, path, compression, reader_prefix_, writer_prefix_,
898       std::move(reader_func), std::move(shard_func));
899 }
900 
namespace {
// Registers the V2 snapshot kernel for CPU execution.
REGISTER_KERNEL_BUILDER(Name("SnapshotDatasetV2").Device(DEVICE_CPU),
                        SnapshotDatasetV2Op);
}  // namespace
905 
906 // ==== Legacy Snapshot Implementation (Deprecated) ====
907 
908 namespace {
909 
// Defaults to 10 GiB per shard.
const int64 kDefaultShardSizeBytes = 10LL * 1024 * 1024 * 1024;

// Current version of the legacy snapshot format.
const int64 kCurrentVersion = 1;

// Names given to the background reader/writer thread pools.
constexpr char kSnapshotReaderWorkerPool[] = "snapshot_reader_worker_pool";
constexpr char kSnapshotWriterWorkerPool[] = "snapshot_writer_worker_pool";
// Separator and trace-region label used when composing stats/trace names.
constexpr char kSeparator[] = "::";
constexpr char kBookkeeping[] = "Bookkeeping";
// Metric names reported to the StatsAggregator.
constexpr char kSnapshotReadElements[] = "snapshot_read_elements";
constexpr char kSnapshotReadThroughput[] = "snapshot_read_throughput";
constexpr char kSnapshotWrittenElements[] = "snapshot_written_elements";
constexpr char kSnapshotWriteThroughput[] = "snapshot_write_throughput";

// Keys (and key suffixes) under which the legacy snapshot iterators
// serialize their checkpoint state.
constexpr char kSizeSuffix[] = "_size";
constexpr char kState[] = "state";
constexpr char kHashDir[] = "hash_dir";
constexpr char kRunId[] = "run_id";
constexpr char kRunDir[] = "run_dir";
constexpr char kVersionStr[] = "version";
constexpr char kFilenames[] = "filenames";
constexpr char kCurrentFilenames[] = "current_filenames";
constexpr char kElementsProduced[] = "elements_produced";
constexpr char kNextFileIndex[] = "next_file_index";
constexpr char kNumFilesDone[] = "num_files_done";
constexpr char kNumElementsRead[] = "num_elements_read";
constexpr char kStatus[] = "status";
constexpr char kCode[] = ".code";
constexpr char kErrorMessage[] = ".error_message";
constexpr char kEndOfSequence[] = "end_of_sequence";
constexpr char kBuffer[] = "buffer";
constexpr char kNumElementsWritten[] = "num_elements_written";
constexpr char kNextElem[] = "next_elem";
943 
944 class SnapshotDatasetOp : public UnaryDatasetOpKernel {
945  public:
  // Reads and validates the attrs of the legacy snapshot op.
  //
  // "mode" and "snapshot_name" are optional and default to
  // snapshot_util::kModeAuto and "" when the attr is absent. The -1 sentinel
  // for sizes, thread counts, and buffer sizes is replaced with defaults.
  // Construction fails (via OP_REQUIRES) for an unsupported compression
  // scheme, an expiry below one second, or an unknown mode string.
  explicit SnapshotDatasetOp(OpKernelConstruction* ctx)
      : UnaryDatasetOpKernel(ctx),
        graph_def_version_(ctx->graph_def_version()) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));

    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("reader_path_prefix", &reader_path_prefix_));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("writer_path_prefix", &writer_path_prefix_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("compression", &compression_));

    OP_REQUIRES_OK(ctx, ctx->GetAttr("shard_size_bytes", &shard_size_bytes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("pending_snapshot_expiry_seconds",
                                     &pending_snapshot_expiry_seconds_));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("num_reader_threads", &num_reader_threads_));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("reader_buffer_size", &reader_buffer_size_));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("num_writer_threads", &num_writer_threads_));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("writer_buffer_size", &writer_buffer_size_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shuffle_on_read", &shuffle_on_read_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("seed", &seed_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("seed2", &seed2_));

    // "mode" is optional; default to auto when the attr is absent.
    mode_ = snapshot_util::kModeAuto;
    if (ctx->HasAttr("mode")) {
      OP_REQUIRES_OK(ctx, ctx->GetAttr("mode", &mode_));
    }

    // "snapshot_name" is likewise optional; empty means no custom name.
    snapshot_name_ = "";
    if (ctx->HasAttr("snapshot_name")) {
      OP_REQUIRES_OK(ctx, ctx->GetAttr("snapshot_name", &snapshot_name_));
    }

    if (shard_size_bytes_ == -1) shard_size_bytes_ = kDefaultShardSizeBytes;

    // Default to 1 day expiry for snapshots.
    if (pending_snapshot_expiry_seconds_ == -1) {
      pending_snapshot_expiry_seconds_ = 86400;
    }

    // -1 means "unspecified"; fall back to single-threaded, unit buffers.
    if (num_reader_threads_ == -1) num_reader_threads_ = 1;
    if (reader_buffer_size_ == -1) reader_buffer_size_ = 1;
    if (num_writer_threads_ == -1) num_writer_threads_ = 1;
    if (writer_buffer_size_ == -1) writer_buffer_size_ = 1;

    OP_REQUIRES(
        ctx,
        compression_ == io::compression::kNone ||
            compression_ == io::compression::kGzip ||
            compression_ == io::compression::kSnappy,
        errors::InvalidArgument("compression must be either '', 'GZIP' or "
                                "'SNAPPY'."));

    OP_REQUIRES(
        ctx, pending_snapshot_expiry_seconds_ >= 1,
        errors::InvalidArgument(
            "pending_snapshot_expiry_seconds must be at least 1 second."));

    OP_REQUIRES(ctx,
                mode_ == snapshot_util::kModeAuto ||
                    mode_ == snapshot_util::kModeRead ||
                    mode_ == snapshot_util::kModeWrite ||
                    mode_ == snapshot_util::kModePassthrough,
                errors::InvalidArgument(
                    "mode must be either '", snapshot_util::kModeAuto, "', '",
                    snapshot_util::kModeRead, "', '", snapshot_util::kModeWrite,
                    "', or '", snapshot_util::kModePassthrough, "'."));
  }
1018 
1019  protected:
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)1020   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
1021                    DatasetBase** output) override {
1022     tstring path;
1023 
1024     OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "path", &path));
1025 
1026     SerializationContext::Params params;
1027     std::vector<std::pair<string, Tensor>> input_list;
1028     params.input_list = &input_list;
1029     params.external_state_policy =
1030         SerializationContext::ExternalStatePolicy::kIgnore;
1031 
1032     GraphDef graph_def;
1033     OP_REQUIRES_OK(
1034         ctx, AsGraphDef(ctx, input, SerializationContext(params), &graph_def));
1035 
1036     uint64 hash;
1037     OP_REQUIRES_OK(ctx, ComputeDatasetHash(graph_def, path, &hash));
1038 
1039     Status dump_status =
1040         snapshot_util::DumpDatasetGraph(ctx->env(), path, hash, &graph_def);
1041     if (!dump_status.ok()) {
1042       LOG(WARNING) << "Unable to write graphdef to disk, error: "
1043                    << dump_status.ToString();
1044     }
1045 
1046     std::string graph_hash =
1047         strings::StrCat(strings::Hex(hash, strings::kZeroPad16));
1048     LOG(INFO) << "Graph def serialized to hash: " << graph_hash;
1049 
1050     *output = new Dataset(ctx, input, path, graph_hash, reader_path_prefix_,
1051                           writer_path_prefix_, compression_, shard_size_bytes_,
1052                           pending_snapshot_expiry_seconds_, num_reader_threads_,
1053                           reader_buffer_size_, num_writer_threads_,
1054                           writer_buffer_size_, shuffle_on_read_, seed_, seed2_,
1055                           mode_, snapshot_name_);
1056   }
1057 
1058  private:
1059   class Dataset : public DatasetBase {
1060    public:
    // Constructs the legacy snapshot dataset. Takes a reference on `input`
    // for the lifetime of this dataset (released in the destructor); all
    // other parameters are copied verbatim into members.
    Dataset(OpKernelContext* ctx, const DatasetBase* input, const string& path,
            const string& graph_hash, const string& reader_path_prefix,
            const string& writer_path_prefix, const string& compression,
            const uint64 shard_size_bytes,
            const uint64 pending_snapshot_expiry_seconds,
            const uint64 num_reader_threads, const uint64 reader_buffer_size,
            const uint64 num_writer_threads, const uint64 writer_buffer_size,
            const bool shuffle_on_read, const uint64 seed, const uint64 seed2,
            const std::string& mode, const std::string& snapshot_name)
        : DatasetBase(DatasetContext(ctx)),
          input_(input),
          dir_(path),
          graph_hash_(graph_hash),
          reader_path_prefix_(reader_path_prefix),
          writer_path_prefix_(writer_path_prefix),
          compression_(compression),
          shard_size_bytes_(shard_size_bytes),
          pending_snapshot_expiry_seconds_(pending_snapshot_expiry_seconds),
          num_reader_threads_(num_reader_threads),
          reader_buffer_size_(reader_buffer_size),
          num_writer_threads_(num_writer_threads),
          writer_buffer_size_(writer_buffer_size),
          shuffle_on_read_(shuffle_on_read),
          seed_(seed),
          seed2_(seed2),
          mode_(mode),
          snapshot_name_(snapshot_name) {
      input_->Ref();
    }
1090 
    // Releases the reference on the input dataset taken in the constructor.
    ~Dataset() override { input_->Unref(); }
1092 
    // Creates the top-level Snapshot iterator, which lazily selects a
    // Reader / Writer / Passthrough implementation on first use.
    std::unique_ptr<IteratorBase> MakeIteratorInternal(
        const string& prefix) const override {
      return absl::make_unique<Iterator>(
          Iterator::Params{this, absl::StrCat(prefix, "::Snapshot")});
    }
1098 
    // Output dtypes mirror the input dataset's.
    const DataTypeVector& output_dtypes() const override {
      return input_->output_dtypes();
    }
1102 
    // Output shapes mirror the input dataset's.
    const std::vector<PartialTensorShape>& output_shapes() const override {
      return input_->output_shapes();
    }
1106 
    // Human-readable name used in logs and error messages.
    string DebugString() const override { return "SnapshotDatasetOp::Dataset"; }
1108 
    // Forwards the input dataset's cardinality unchanged.
    int64 Cardinality() const override { return input_->Cardinality(); }
1110 
    // Reports the single upstream dataset this op wraps.
    Status InputDatasets(
        std::vector<const DatasetBase*>* inputs) const override {
      inputs->push_back(input_);
      return Status::OK();
    }
1116 
    // Delegates the external-state check to the input dataset.
    Status CheckExternalState() const override {
      return input_->CheckExternalState();
    }
1120 
1121    protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const1122     Status AsGraphDefInternal(SerializationContext* ctx,
1123                               DatasetGraphDefBuilder* b,
1124                               Node** output) const override {
1125       Node* input_graph_node = nullptr;
1126       TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
1127 
1128       Node* path = nullptr;
1129       TF_RETURN_IF_ERROR(b->AddScalar(dir_, &path));
1130 
1131       AttrValue compression_attr;
1132       b->BuildAttrValue(compression_, &compression_attr);
1133 
1134       AttrValue reader_path_prefix_attr;
1135       b->BuildAttrValue(reader_path_prefix_, &reader_path_prefix_attr);
1136 
1137       AttrValue writer_path_prefix_attr;
1138       b->BuildAttrValue(writer_path_prefix_, &writer_path_prefix_attr);
1139 
1140       AttrValue shard_size_bytes_attr;
1141       b->BuildAttrValue<int64>(shard_size_bytes_, &shard_size_bytes_attr);
1142 
1143       AttrValue pending_snapshot_expiry_seconds_attr;
1144       b->BuildAttrValue<int64>(pending_snapshot_expiry_seconds_,
1145                                &pending_snapshot_expiry_seconds_attr);
1146 
1147       AttrValue num_reader_threads_attr;
1148       b->BuildAttrValue<int64>(num_reader_threads_, &num_reader_threads_attr);
1149 
1150       AttrValue reader_buffer_size_attr;
1151       b->BuildAttrValue<int64>(reader_buffer_size_, &reader_buffer_size_attr);
1152 
1153       AttrValue num_writer_threads_attr;
1154       b->BuildAttrValue<int64>(num_writer_threads_, &num_writer_threads_attr);
1155 
1156       AttrValue writer_buffer_size_attr;
1157       b->BuildAttrValue<int64>(writer_buffer_size_, &writer_buffer_size_attr);
1158 
1159       AttrValue shuffle_on_read_attr;
1160       b->BuildAttrValue<bool>(shuffle_on_read_, &shuffle_on_read_attr);
1161 
1162       AttrValue seed_attr;
1163       b->BuildAttrValue<int64>(seed_, &seed_attr);
1164 
1165       AttrValue seed2_attr;
1166       b->BuildAttrValue<int64>(seed2_, &seed2_attr);
1167 
1168       AttrValue mode_attr;
1169       b->BuildAttrValue(mode_, &mode_attr);
1170 
1171       AttrValue snapshot_name_attr;
1172       b->BuildAttrValue(snapshot_name_, &snapshot_name_attr);
1173 
1174       TF_RETURN_IF_ERROR(b->AddDataset(
1175           this,
1176           /*inputs=*/
1177           {std::make_pair(0, input_graph_node), std::make_pair(1, path)},
1178           /*list_inputs=*/
1179           {},
1180           /*attrs=*/
1181           {{"compression", compression_attr},
1182            {"reader_path_prefix", reader_path_prefix_attr},
1183            {"writer_path_prefix", writer_path_prefix_attr},
1184            {"shard_size_bytes", shard_size_bytes_attr},
1185            {"pending_snapshot_expiry_seconds",
1186             pending_snapshot_expiry_seconds_attr},
1187            {"num_reader_threads", num_reader_threads_attr},
1188            {"reader_buffer_size", reader_buffer_size_attr},
1189            {"num_writer_threads", num_writer_threads_attr},
1190            {"writer_buffer_size", writer_buffer_size_attr},
1191            {"shuffle_on_read", shuffle_on_read_attr},
1192            {"seed", seed_attr},
1193            {"seed2", seed2_attr},
1194            {"mode", mode_attr},
1195            {"snapshot_name", snapshot_name_attr}},
1196           output));
1197       return Status::OK();
1198     }
1199 
1200    private:
1201     class Iterator : public DatasetIterator<Dataset> {
1202      public:
Iterator(const Params & params)1203       explicit Iterator(const Params& params)
1204           : DatasetIterator<Dataset>(params) {
1205         if (dataset()->snapshot_name_.empty()) {
1206           hash_dir_ = io::JoinPath(dataset()->dir_, dataset()->graph_hash_);
1207         } else {
1208           hash_dir_ = io::JoinPath(
1209               dataset()->dir_,
1210               strings::StrCat("custom-", dataset()->snapshot_name_));
1211         }
1212       }
1213 
      // We have a somewhat non-traditional pattern for iterator initialization
      // for Snapshot. The protocol is that we initialize the Reader / Writer
      // iterator on the first GetNext call. We also invoke the same
      // initialization code when restoring. The reason why we don't do
1218       // this during the Initialize call is because during Restore we call
1219       // Initialize at first and at that point we don't know which iterator
1220       // (Reader / Writer / Passthrough) we need to restore as this info is part
1221       // of the checkpoint.
      // Intentionally a no-op: the actual sub-iterator is created lazily in
      // GetNextInternal / RestoreInternal (see the comment above).
      Status Initialize(IteratorContext* ctx) override { return Status::OK(); }
1223 
      // On the first call, reads the snapshot metadata file, determines the
      // op state (READER / WRITER / PASSTHROUGH), and builds the matching
      // sub-iterator; all calls then forward to that sub-iterator.
      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        if (iterator_ == nullptr) {
          experimental::SnapshotMetadataRecord metadata;
          bool file_exists;
          TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(
              ctx->env(), hash_dir_, &metadata, &file_exists));
          // Decide whether to read, write, or pass through, based on the
          // requested mode, metadata presence, and the expiry window.
          TF_RETURN_IF_ERROR(snapshot_util::DetermineOpState(
              dataset()->mode_, file_exists, &metadata,
              dataset()->pending_snapshot_expiry_seconds_, &state_));
          VLOG(2) << "Snapshot state: " << state_;
          TF_RETURN_IF_ERROR(InitializeIterator(ctx, metadata));
        }
        return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
      }
1241 
1242      protected:
      // Saves the active sub-iterator, then records the current state and
      // hash directory so RestoreInternal can rebuild the same sub-iterator
      // and detect a changed dataset.
      Status SaveInternal(SerializationContext* ctx,
                          IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, iterator_));
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name(kState), static_cast<int64>(state_)));
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kHashDir), hash_dir_));
        VLOG(2) << "Saving Snapshot iterator: " << state_;
        return Status::OK();
      }
1253 
      // Restores the iterator from a checkpoint. If the saved hash directory
      // differs from the current one (i.e. the input pipeline changed since
      // the checkpoint was taken), restoration is deliberately skipped: an
      // error is logged but OK is returned, leaving the iterator to start
      // fresh on the next GetNext.
      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        tstring hash_dir;
        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kHashDir), &hash_dir));
        if (hash_dir != hash_dir_) {
          // Intentional: a mismatch is not a hard failure.
          LOG(ERROR) << "Dataset has changed while restoring from the "
                        "checkpoint. Old hash: "
                     << hash_dir << "; new hash: " << hash_dir_;
          return Status::OK();
        }
        {
          // State was saved as a signed int64; convert back to the Mode enum.
          int64 temp;
          TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kState), &temp));
          state_ = snapshot_util::Mode(temp);
        }
        // Re-read the metadata and rebuild the sub-iterator that was active
        // when the checkpoint was taken, then restore its state.
        experimental::SnapshotMetadataRecord metadata;
        bool file_exists;
        TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(
            ctx->env(), hash_dir_, &metadata, &file_exists));
        TF_RETURN_IF_ERROR(InitializeIterator(ctx, metadata));
        VLOG(2) << "Restoring Snapshot iterator: " << state_;
        return RestoreInput(ctx, reader, iterator_);
      }
1278 
1279       // This method expects that state_ is populated and it will create the
1280       // correct Reader / Writer / Passthrough iterator and initialize it.
InitializeIterator(IteratorContext * ctx,const experimental::SnapshotMetadataRecord & metadata)1281       Status InitializeIterator(
1282           IteratorContext* ctx,
1283           const experimental::SnapshotMetadataRecord& metadata)
1284           TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1285         std::string run_id = "";
1286         if (!dataset()->snapshot_name_.empty()) {
1287           // We have overridden the snapshot with a custom name, so we don't
1288           // generate random run ids, but just use the same one.
1289           run_id = "custom";
1290         }
1291 
1292         switch (state_) {
1293           case snapshot_util::WRITER:
1294             iterator_ = absl::make_unique<SnapshotWriterIterator>(
1295                 SnapshotWriterIterator::Params{
1296                     dataset(), absl::StrCat(prefix(), "WriterImpl")},
1297                 hash_dir_, run_id);
1298             break;
1299           case snapshot_util::READER:
1300             if (run_id.empty() && metadata.run_id().empty()) {
1301               return errors::NotFound(
1302                   "Could not find a valid snapshot to read.");
1303             }
1304             if (run_id.empty()) {
1305               run_id = metadata.run_id();
1306             }
1307             // dtypes in metadata should be the same as dataset()->output_dtypes
1308             if (metadata.dtype_size() != dataset()->output_dtypes().size()) {
1309               return errors::Internal(
1310                   "Expected number of dtypes: ",
1311                   dataset()->output_dtypes().size(),
1312                   " but number in snapshot: ", metadata.dtype_size());
1313             }
1314             for (int i = 0; i < metadata.dtype_size(); ++i) {
1315               if (metadata.dtype(i) != dataset()->output_dtypes()[i]) {
1316                 return errors::Internal(
1317                     "Type: ", i,
1318                     " doesn't match. Snapshot: ", metadata.dtype(i),
1319                     "; dataset: ", dataset()->output_dtypes()[i]);
1320               }
1321             }
1322             iterator_ = absl::make_unique<SnapshotReaderIterator>(
1323                 SnapshotReaderIterator::Params{
1324                     dataset(), absl::StrCat(prefix(), "ReaderImpl")},
1325                 hash_dir_, run_id, metadata.version());
1326             break;
1327           case snapshot_util::PASSTHROUGH:
1328             iterator_ = absl::make_unique<SnapshotPassthroughIterator>(
1329                 SnapshotPassthroughIterator::Params{
1330                     dataset(), absl::StrCat(prefix(), "PassthroughImpl")});
1331             break;
1332         }
1333         TF_RETURN_IF_ERROR(iterator_->InitializeBase(ctx, this));
1334         return iterator_->Initialize(ctx);
1335       }
1336 
1337      protected:
1338       class SnapshotReaderIterator : public DatasetIterator<Dataset> {
1339        public:
1340         static constexpr const char* const kParse = "Parse";
1341 
        // `hash_dir` and `run_id` identify the snapshot run to read;
        // `version` is the format version recorded in the snapshot metadata
        // (the caller passes metadata.version()).
        explicit SnapshotReaderIterator(const Params& params,
                                        const string& hash_dir,
                                        const string& run_id, int64 version)
            : DatasetIterator<Dataset>(params),
              hash_dir_(hash_dir),
              run_id_(run_id),
              version_(version) {}
1349 
        // Signals cancellation to the background reader threads and blocks
        // until all of them have exited, so no thread outlives this
        // iterator's state.
        ~SnapshotReaderIterator() override {
          mutex_lock l(mu_);
          cancelled_ = true;
          cond_var_.notify_all();
          // num_active_threads_ is presumably decremented (with a notify) by
          // each reader thread on exit — see ReadingFilesLoop.
          while (num_active_threads_ > 0) {
            cond_var_.wait(l);
          }
        }
1358 
        // Discovers the snapshot files for this run and fixes the read
        // order: shuffled (seeded from seed_+seed2_, or randomly when both
        // seeds are 0) or lexicographically sorted. Also primes one starting
        // filename per reader thread.
        Status Initialize(IteratorContext* ctx) override {
          mutex_lock l(mu_);
          thread_pool_ = ctx->CreateThreadPool(kSnapshotReaderWorkerPool,
                                               dataset()->num_reader_threads_);
          run_dir_ = io::JoinPath(hash_dir_, run_id_);
          // Get all the files in the run_dir.
          std::vector<std::string> filenames_str;
          TF_RETURN_IF_ERROR(ctx->env()->GetMatchingPaths(
              absl::StrCat(absl::string_view(run_dir_), "/*"), &filenames_str));
          // NOTE(review): filenames_ appears to use a different string type
          // (likely tstring), hence the element-wise copy — confirm.
          filenames_.resize(filenames_str.size());
          std::copy(filenames_str.begin(), filenames_str.end(),
                    filenames_.begin());
          if (filenames_.empty()) {
            return errors::NotFound("Could not find any files in dir: ",
                                    run_dir_);
          }

          if (dataset()->shuffle_on_read_) {
            uint64 seed = dataset()->seed_ + dataset()->seed2_;
            if (dataset()->seed_ == 0 && dataset()->seed2_ == 0) {
              // No explicit seeds: shuffle non-deterministically.
              seed = random::New64();
            }

            std::mt19937 rng(seed);
            std::shuffle(filenames_.begin(), filenames_.end(), rng);
          } else {
            std::sort(filenames_.begin(), filenames_.end());
          }

          // Hand each reader thread its first file to process.
          for (auto i = 0; i < dataset()->num_reader_threads_; ++i) {
            curr_filenames_.push_back(GetNextFilename());
          }
          return Status::OK();
        }
1393 
        // Returns the next element from the background-reader buffer.
        //
        // On the first call, launches `num_reader_threads_` background
        // threads that fill `buffer_` via ReadingFilesLoop(). Subsequent
        // calls block on `cond_var_` until the buffer is non-empty, the
        // iterator is cancelled, or all reader threads have finished
        // (end of sequence). Also records read-throughput statistics.
        Status GetNextInternal(IteratorContext* ctx,
                               std::vector<Tensor>* out_tensors,
                               bool* end_of_sequence) override {
          absl::Time start = absl::Now();
          mutex_lock l(mu_);
          if (!background_threads_started_) {
            for (int i = 0; i < dataset()->num_reader_threads_; ++i) {
              ++num_active_threads_;
              thread_pool_->Schedule(
                  [this, i, env = ctx->env()]() { ReadingFilesLoop(env, i); });
            }
            background_threads_started_ = true;
          }

          // Wait till the buffer has something in it.
          while (!cancelled_ && buffer_.empty() &&
                 !background_threads_finished_) {
            cond_var_.wait(l);
          }

          if (cancelled_) {
            return errors::Cancelled(
                "SnapshotDatasetOp::Dataset::SnapshotReaderIterator::GetNext");
          }

          const auto& stats_aggregator = ctx->stats_aggregator();
          if (stats_aggregator) {
            stats_aggregator->AddScalar(
                absl::StrCat(dataset()->node_name(), kSeparator,
                             kSnapshotReadElements),
                static_cast<float>(num_elements_read_), elements_produced_);
            stats_aggregator->AddScalar(
                absl::StrCat(dataset()->node_name(), kSeparator,
                             "snapshot_reader_buffer_size"),
                static_cast<float>(buffer_.size()), elements_produced_);
          }

          if (!buffer_.empty()) {
            Status s = buffer_.front().status;
            if (s.ok()) {
              *end_of_sequence = false;
              *out_tensors = std::move(buffer_.front().value);

              {
                profiler::TraceMe activity(
                    [&]() {
                      return absl::StrCat(prefix(), kSeparator, kBookkeeping);
                    },
                    profiler::TraceMeLevel::kInfo);
                // Printing some statistics along the way.
                int64 num_bytes = 0;
                for (int i = 0; i < out_tensors->size(); ++i) {
                  num_bytes += (*out_tensors)[i].TotalBytes();
                }
                absl::Time end = absl::Now();
                absl::Duration d = end - start;
                time_spent_micros_ += absl::ToInt64Microseconds(d);
                kbytes_read_ += static_cast<double>(num_bytes) / 1024.0;
                float read_throughput =
                    (kbytes_read_ / 1024.0) / (time_spent_micros_ / 1000000.0);
                if (stats_aggregator) {
                  stats_aggregator->AddScalar(
                      absl::StrCat(dataset()->node_name(), kSeparator,
                                   kSnapshotReadThroughput),
                      read_throughput, elements_produced_);
                }
                elements_produced_++;
                if (elements_produced_ % 10000 == 0) {
                  LOG(INFO)
                      << "Current read throughput (MBPS): " << read_throughput;
                }
              }
            }
            // Pop the element even when it carried an error status; the
            // error is returned to the caller below.
            buffer_.pop_front();
            // Wake reader threads blocked waiting for buffer space.
            cond_var_.notify_all();
            return s;
          }

          if (background_threads_finished_) {
            *end_of_sequence = true;
            return Status::OK();
          }

          return errors::Internal("Unreachable point in SnapshotReader");
        }
1479 
1480        protected:
        // Checkpoints the reader's bookkeeping state — directories, the
        // shuffled/sorted file list, each thread's current file, and
        // progress counters — so RestoreInternal() can resume reading.
        // Note: the in-memory buffer_ contents are NOT saved; restored
        // threads re-read their current files from the beginning.
        Status SaveInternal(SerializationContext* ctx,
                            IteratorStateWriter* writer) override {
          mutex_lock l(mu_);
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name(kHashDir), hash_dir_));
          TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kRunId), run_id_));
          TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kRunDir), run_dir_));
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name(kVersionStr), version_));
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              full_name(strings::StrCat(kFilenames, kSizeSuffix)),
              filenames_.size()));
          for (size_t i = 0; i < filenames_.size(); ++i) {
            TF_RETURN_IF_ERROR(writer->WriteScalar(
                full_name(strings::StrCat(kFilenames, "[", i, "]")),
                filenames_[i]));
          }
          // One "current filename" entry per reader thread.
          for (auto i = 0; i < dataset()->num_reader_threads_; ++i) {
            TF_RETURN_IF_ERROR(writer->WriteScalar(
                full_name(strings::StrCat(kCurrentFilenames, "[", i, "]")),
                curr_filenames_[i]));
          }
          TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kElementsProduced),
                                                 elements_produced_));
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name(kNextFileIndex), next_file_index_));
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name(kNumFilesDone), num_files_done_));
          TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kNumElementsRead),
                                                 num_elements_read_));
          VLOG(2) << "Saving SnapshotReaderIterator: " << num_elements_read_
                  << "; elements_produced: " << elements_produced_;
          return Status::OK();
        }
1515 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)1516         Status RestoreInternal(IteratorContext* ctx,
1517                                IteratorStateReader* reader) override {
1518           mutex_lock l(mu_);
1519           tstring hash_dir, run_id, run_dir;
1520           TF_RETURN_IF_ERROR(
1521               reader->ReadScalar(full_name(kHashDir), &hash_dir));
1522           TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kHashDir), &run_id));
1523           TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kHashDir), &run_dir));
1524           if (run_dir != run_dir_) {
1525             LOG(ERROR) << "Restoring read iterator from ckpt with old "
1526                        << "run_dir: " << run_dir
1527                        << " but new run_dir is: " << run_dir_
1528                        << ". We'll now restart snapshot creation.";
1529             return Status::OK();
1530           }
1531           TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kRunId), &run_id_));
1532           TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kRunDir), &run_dir_));
1533           TF_RETURN_IF_ERROR(
1534               reader->ReadScalar(full_name(kVersionStr), &version_));
1535           curr_filenames_.clear();
1536           curr_filenames_.reserve(dataset()->num_reader_threads_);
1537           for (auto i = 0; i < dataset()->num_reader_threads_; ++i) {
1538             curr_filenames_.emplace_back();
1539             TF_RETURN_IF_ERROR(reader->ReadScalar(
1540                 full_name(strings::StrCat(kCurrentFilenames, "[", i, "]")),
1541                 &curr_filenames_.back()));
1542           }
1543           size_t filenames_size;
1544           {
1545             int64 temp;
1546             TF_RETURN_IF_ERROR(reader->ReadScalar(
1547                 full_name(strings::StrCat(kFilenames, kSizeSuffix)), &temp));
1548             filenames_size = static_cast<size_t>(temp);
1549           }
1550           if (filenames_.size() != filenames_size) {
1551             LOG(ERROR) << "Old filenames size: " << filenames_size
1552                        << "; new filenames size: " << filenames_.size();
1553           }
1554           filenames_.clear();
1555           filenames_.reserve(filenames_size);
1556           for (size_t i = 0; i < filenames_size; ++i) {
1557             filenames_.emplace_back();
1558             TF_RETURN_IF_ERROR(reader->ReadScalar(
1559                 full_name(strings::StrCat(kFilenames, "[", i, "]")),
1560                 &filenames_.back()));
1561           }
1562           {
1563             int64 temp;
1564             TF_RETURN_IF_ERROR(
1565                 reader->ReadScalar(full_name(kElementsProduced), &temp));
1566             elements_produced_ = static_cast<uint64>(temp);
1567           }
1568           {
1569             int64 temp;
1570             TF_RETURN_IF_ERROR(
1571                 reader->ReadScalar(full_name(kNextFileIndex), &temp));
1572             next_file_index_ = static_cast<uint64>(temp);
1573           }
1574           TF_RETURN_IF_ERROR(
1575               reader->ReadScalar(full_name(kNumFilesDone), &num_files_done_));
1576           TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kNumElementsRead),
1577                                                 &num_elements_read_));
1578           VLOG(2) << "Restoring SnapshotReaderIterator: " << num_elements_read_
1579                   << "; elements_produced: " << elements_produced_;
1580           return Status::OK();
1581         }
1582 
1583        private:
        // Reads one file end to end, appending every record to buffer_.
        //
        // Blocks on cond_var_ whenever the buffer is at capacity. Returns
        // OK when the underlying reader reports OutOfRange (clean end of
        // file), Cancelled during iterator teardown, and the first read
        // error otherwise.
        Status ReadFile(Env* env, const string& filename) {
          std::unique_ptr<snapshot_util::Reader> reader;
          TF_RETURN_IF_ERROR(snapshot_util::Reader::Create(
              env, filename, dataset()->compression_, version_,
              dataset()->output_dtypes(), &reader));
          while (true) {
            // Wait for a slot in the buffer.
            {
              mutex_lock l(mu_);
              while (!cancelled_ &&
                     buffer_.size() >= dataset()->reader_buffer_size_) {
                cond_var_.wait(l);
              }

              if (cancelled_) {
                return errors::Cancelled(
                    "SnapshotDatasetOp::Dataset::SnapshotReaderIterator::"
                    "ReadFile");
              }
            }
            // NOTE(review): mu_ is released while ReadTensors() runs, so
            // with multiple reader threads buffer_ can transiently exceed
            // reader_buffer_size_ — presumably accepted slack; confirm.
            std::vector<Tensor> read_tensors;
            Status s = reader->ReadTensors(&read_tensors);
            if (s.ok()) {
              profiler::TraceMe activity(
                  [&]() { return absl::StrCat(prefix(), kSeparator, kParse); },
                  profiler::TraceMeLevel::kInfo);
              BufferElement elem;
              elem.value = std::move(read_tensors);
              elem.status = Status::OK();
              mutex_lock l(mu_);
              buffer_.push_back(std::move(elem));
              num_elements_read_++;
              cond_var_.notify_all();
            } else if (errors::IsOutOfRange(s)) {
              // OutOfRange means end of file: clean termination.
              return Status::OK();
            } else {
              return s;
            }
          }
          return Status::OK();
        }
1626 
GetNextFilename()1627         string GetNextFilename() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1628           if (next_file_index_ >= filenames_.size()) {
1629             return "";
1630           }
1631           string filename = io::JoinPath(dataset()->reader_path_prefix_,
1632                                          filenames_[next_file_index_]);
1633           next_file_index_++;
1634           return filename;
1635         }
1636 
        // Pulls one file off the filenames_ list and reads it through. When
        // all files are read, terminates.
        //
        // Runs on a background pool thread; `i` is this thread's slot in
        // curr_filenames_. On a read error, a BufferElement carrying the
        // status is pushed so GetNextInternal() surfaces it, and the loop
        // exits. The cleanup handler decrements num_active_threads_ on
        // every exit path so the destructor can wait for shutdown.
        void ReadingFilesLoop(Env* env, int i) {
          auto cleanup = gtl::MakeCleanup([this]() {
            mutex_lock l(mu_);
            --num_active_threads_;
            cond_var_.notify_all();
          });
          while (true) {
            string filename = "";
            {
              mutex_lock l(mu_);
              filename = curr_filenames_[i];
              if (filename.empty()) {
                // No more files assigned to this thread.
                return;
              }
              VLOG(2) << "Starting to read: " << filename;
            }
            Status s = ReadFile(env, filename);
            // If we get to the end of the file, it's a clean termination and
            // we are at the end of the file. If all files have been processed,
            // then we insert an end_of_sequence marker in the buffer and
            // terminate the loop.
            if (s.ok()) {
              VLOG(2) << "Finished reading: " << filename;
              mutex_lock l(mu_);
              num_files_done_++;
              if (num_files_done_ >= filenames_.size()) {
                background_threads_finished_ = true;
                cond_var_.notify_all();
                return;
              }
              curr_filenames_[i] = GetNextFilename();
            } else {
              LOG(ERROR) << "Encountered an error: " << s.ToString();
              BufferElement elem;
              elem.status = s;
              mutex_lock l(mu_);
              buffer_.push_back(std::move(elem));
              cond_var_.notify_all();
              return;
            }
          }
        }
1681 
WriteStatus(IteratorStateWriter * writer,size_t index,const Status & status)1682         Status WriteStatus(IteratorStateWriter* writer, size_t index,
1683                            const Status& status)
1684             TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1685           TF_RETURN_IF_ERROR(writer->WriteScalar(
1686               CodeKey(index), static_cast<int64>(status.code())));
1687           if (!status.ok()) {
1688             TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index),
1689                                                    status.error_message()));
1690           }
1691           return Status::OK();
1692         }
1693 
ReadStatus(IteratorStateReader * reader,size_t index,Status * status)1694         Status ReadStatus(IteratorStateReader* reader, size_t index,
1695                           Status* status) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1696           int64 code_int;
1697           TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
1698           error::Code code = static_cast<error::Code>(code_int);
1699 
1700           if (code != error::Code::OK) {
1701             tstring error_message;
1702             TF_RETURN_IF_ERROR(
1703                 reader->ReadScalar(ErrorMessageKey(index), &error_message));
1704             *status = Status(code, error_message);
1705           } else {
1706             *status = Status::OK();
1707           }
1708           return Status::OK();
1709         }
1710 
CodeKey(size_t index)1711         string CodeKey(size_t index) {
1712           return full_name(strings::StrCat(kStatus, "[", index, "]", kCode));
1713         }
1714 
ErrorMessageKey(size_t index)1715         string ErrorMessageKey(size_t index) {
1716           return full_name(
1717               strings::StrCat(kStatus, "[", index, "]", kErrorMessage));
1718         }
1719 
        // One element read from a snapshot file, or the error encountered
        // while reading it.
        struct BufferElement {
          Status status;
          std::vector<Tensor> value;
        };

        // Guards all mutable state below; paired with cond_var_ for
        // producer/consumer signaling between reader threads and GetNext.
        mutex mu_;
        condition_variable cond_var_;

        const string hash_dir_;
        tstring run_id_ TF_GUARDED_BY(mu_);
        tstring run_dir_ TF_GUARDED_BY(mu_);
        int64 version_;
        // Snapshot files to read (shuffled or sorted in Initialize).
        std::vector<tstring> filenames_;

        // Bookkeeping for the throughput statistics reported via the
        // stats aggregator and periodic LOG(INFO) messages.
        uint64 elements_produced_ TF_GUARDED_BY(mu_) = 0;
        int64 time_spent_micros_ TF_GUARDED_BY(mu_) = 0;
        double kbytes_read_ TF_GUARDED_BY(mu_) = 0;
        size_t next_file_index_ TF_GUARDED_BY(mu_) = 0;
        int64 num_files_done_ TF_GUARDED_BY(mu_) = 0;

        std::unique_ptr<thread::ThreadPool> thread_pool_;
        int64 num_active_threads_ TF_GUARDED_BY(mu_) = 0;
        // Elements produced by reader threads, consumed by GetNext.
        std::deque<BufferElement> buffer_ TF_GUARDED_BY(mu_);
        bool cancelled_ TF_GUARDED_BY(mu_) = false;
        bool background_threads_started_ TF_GUARDED_BY(mu_) = false;
        bool background_threads_finished_ TF_GUARDED_BY(mu_) = false;
        int64 num_elements_read_ TF_GUARDED_BY(mu_) = 0;
        // curr_filenames_ tracks which file is being read by each thread.
        std::vector<tstring> curr_filenames_ TF_GUARDED_BY(mu_);
1749       };
1750 
1751       class SnapshotWriterIterator : public DatasetIterator<Dataset> {
1752        public:
1753         static constexpr const char* const kProcessOneElement =
1754             "ProcessOneElement";
1755 
        // Creates a writer iterator rooted at
        // <writer_path_prefix>/<hash_dir>/<run_id>. An empty `run_id`
        // (fresh run rather than a restore) is replaced with a random
        // hex identifier.
        explicit SnapshotWriterIterator(const Params& params,
                                        const string& hash_dir,
                                        const string& run_id)
            : DatasetIterator<Dataset>(params),
              hash_dir_(hash_dir),
              run_id_(run_id) {
          if (run_id_.empty()) {
            run_id_ = strings::StrCat(
                strings::Hex(random::New64(), strings::kZeroPad4));
          }
          run_dir_ =
              io::JoinPath(dataset()->writer_path_prefix_, hash_dir_, run_id_);
        }
1769 
        // Signals cancellation and blocks until every writer thread has
        // exited, so no background thread can touch freed iterator state.
        ~SnapshotWriterIterator() override {
          mutex_lock l(mu_);
          cancelled_ = true;
          cond_var_.notify_all();
          while (num_active_threads_ > 0) {
            cond_var_.wait(l);
          }
        }
1778 
        // Creates the writer thread pool and the input iterator whose
        // elements this snapshot writer will drain and persist.
        Status Initialize(IteratorContext* ctx) override {
          thread_pool_ = ctx->CreateThreadPool(kSnapshotWriterWorkerPool,
                                               dataset()->num_writer_threads_);
          return dataset()->input_->MakeIterator(ctx, this, prefix(),
                                                 &input_impl_);
        }
1785 
        // Returns the element prefetched into next_elem_, then refills it.
        //
        // On the very first call this writes the (unfinalized) snapshot
        // metadata file — unless restoring, in which case the file already
        // exists — and starts the writer threads. Keeping one element
        // prefetched lets the iterator detect end-of-input and finalize
        // the snapshot without requiring an extra (N+1)-th GetNext call.
        Status GetNextInternal(IteratorContext* ctx,
                               std::vector<Tensor>* out_tensors,
                               bool* end_of_sequence) override {
          absl::Time start = absl::Now();

          bool first_call;
          bool is_restored;
          {
            mutex_lock l(mu_);
            first_call = first_call_;
            is_restored = is_restored_;
            if (first_call_) {
              // If we're restoring then the directory already exists and we
              // don't want to overwrite the snapshot metadata file.
              if (!is_restored_) {
                TF_RETURN_IF_ERROR(ctx->env()->RecursivelyCreateDir(run_dir_));
                experimental::SnapshotMetadataRecord metadata;
                metadata.set_creation_timestamp(EnvTime::NowMicros());
                metadata.set_graph_hash(dataset()->graph_hash_);
                metadata.set_run_id(run_id_.data(), run_id_.size());
                metadata.set_version(kCurrentVersion);
                for (const auto& output_dtype : dataset()->output_dtypes()) {
                  metadata.add_dtype(output_dtype);
                }
                // finalized=false until the full dataset has been written.
                metadata.set_finalized(false);
                TF_RETURN_IF_ERROR(snapshot_util::WriteMetadataFile(
                    ctx->env(), hash_dir_, &metadata));
              }
              for (int i = 0; i < dataset()->num_writer_threads_; ++i) {
                ++num_active_threads_;
                thread_pool_->Schedule(
                    [this, env = ctx->env()]() { WriterThread(env); });
              }
              first_call_ = false;
            }
          }

          // When we reach the end of the data, we'd like to finalize the
          // snapshot and write the metadata file out. If we just check for
          // end_of_sequence on the GetNext call then we will need to make
          // N + 1 GetNext calls (if N is the total number of elements in the
          // dataset). So right now we solve this issue by prefetching the next
          // element in the data stream. Therefore the first call ends up
          // pulling two elements.
          if (first_call && !is_restored) {
            TF_RETURN_IF_ERROR(FillBuffer(ctx));
          }

          {
            mutex_lock l(mu_);
            // Populate out_tensors with the prefetched data.
            *out_tensors = std::move(next_elem_.value);
            *end_of_sequence = next_elem_.end_of_sequence;
          }

          // Update prefetched_elem with the next element.
          TF_RETURN_IF_ERROR(FillBuffer(ctx));

          {
            profiler::TraceMe activity(
                [&]() {
                  return absl::StrCat(prefix(), kSeparator, kBookkeeping);
                },
                profiler::TraceMeLevel::kInfo);

            // Book keeping to report some statistics.
            mutex_lock l(mu_);
            int64 num_bytes = 0;
            for (const auto& out_tensor : *out_tensors) {
              num_bytes += out_tensor.TotalBytes();
            }

            const auto& stats_aggregator = ctx->stats_aggregator();
            if (stats_aggregator) {
              stats_aggregator->AddScalar(
                  absl::StrCat(dataset()->node_name(), kSeparator,
                               kSnapshotWrittenElements),
                  static_cast<float>(num_elements_written_),
                  elements_produced_);
              stats_aggregator->AddScalar(
                  absl::StrCat(dataset()->node_name(), kSeparator,
                               "snapshot_writer_buffer_size"),
                  static_cast<float>(buffer_.size()), elements_produced_);
            }

            absl::Time end = absl::Now();
            absl::Duration d = end - start;
            time_spent_micros_ += absl::ToInt64Microseconds(d);
            bytes_produced_ += num_bytes;
            float write_throughput = (bytes_produced_ * 1000000.0) /
                                     (time_spent_micros_ * 1024.0 * 1024.0);
            if (stats_aggregator) {
              stats_aggregator->AddScalar(
                  absl::StrCat(dataset()->node_name(), kSeparator,
                               kSnapshotWriteThroughput),
                  write_throughput, elements_produced_);
            }

            elements_produced_++;
            if (elements_produced_ % 10000 == 0) {
              LOG(INFO) << "Current write throughput (MBPS): "
                        << write_throughput;
            }
          }
          return Status::OK();
        }
1892 
1893        protected:
        // Checkpoints the input iterator, the identity of the snapshot
        // run (hash/run dirs), the in-flight write buffer — including the
        // buffered tensors themselves — and the prefetched next_elem_.
        Status SaveInternal(SerializationContext* ctx,
                            IteratorStateWriter* writer) override {
          mutex_lock l(mu_);
          TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
          // end_of_sequence_ is encoded by key presence, not a value.
          if (end_of_sequence_) {
            TF_RETURN_IF_ERROR(
                writer->WriteScalar(full_name(kEndOfSequence), ""));
          }
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name(kHashDir), hash_dir_));
          TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kRunId), run_id_));
          TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kRunDir), run_dir_));
          TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kElementsProduced),
                                                 elements_produced_));
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              full_name(strings::StrCat(kBuffer, kSizeSuffix)),
              buffer_.size()));
          for (size_t i = 0; i < buffer_.size(); ++i) {
            auto& buffer_element = buffer_[i];
            if (buffer_element.end_of_sequence) {
              TF_RETURN_IF_ERROR(writer->WriteScalar(
                  full_name(
                      strings::StrCat(kBuffer, "[", i, "].", kEndOfSequence)),
                  ""));
            }
            TF_RETURN_IF_ERROR(writer->WriteScalar(
                full_name(strings::StrCat(kBuffer, "[", i, "]", kSizeSuffix)),
                buffer_element.value.size()));
            for (size_t j = 0; j < buffer_element.value.size(); j++) {
              TF_RETURN_IF_ERROR(writer->WriteTensor(
                  full_name(strings::StrCat(kBuffer, "[", i, "][", j, "]")),
                  buffer_element.value[j]));
            }
          }
          TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kNumElementsWritten),
                                                 num_elements_written_));
          if (next_elem_.end_of_sequence) {
            TF_RETURN_IF_ERROR(writer->WriteScalar(
                full_name(strings::StrCat(kNextElem, ".", kEndOfSequence)),
                ""));
          }
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              full_name(strings::StrCat(kNextElem, kSizeSuffix)),
              next_elem_.value.size()));
          for (size_t i = 0; i < next_elem_.value.size(); i++) {
            TF_RETURN_IF_ERROR(writer->WriteTensor(
                full_name(strings::StrCat(kNextElem, "[", i, "]")),
                next_elem_.value[i]));
          }
          VLOG(2) << "Saving SnapshotWriterIterator: " << num_elements_written_
                  << "; elements_produced: " << elements_produced_;
          return Status::OK();
        }
1947 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)1948         Status RestoreInternal(IteratorContext* ctx,
1949                                IteratorStateReader* reader) override {
1950           mutex_lock l(mu_);
1951           buffer_.clear();
1952           TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
1953           tstring hash_dir;
1954           TF_RETURN_IF_ERROR(
1955               reader->ReadScalar(full_name(kHashDir), &hash_dir));
1956           // If the hash dir has changed then we restart writing.
1957           if (hash_dir != hash_dir_) {
1958             LOG(INFO) << "Old hash dir from ckpt: " << hash_dir
1959                       << " is not the same as the new one: " << hash_dir_;
1960             return Status::OK();
1961           }
1962           is_restored_ = true;
1963           if (reader->Contains(full_name(kEndOfSequence))) {
1964             end_of_sequence_ = true;
1965           } else {
1966             end_of_sequence_ = false;
1967           }
1968           TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kRunId), &run_id_));
1969           TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kRunDir), &run_dir_));
1970           {
1971             int64 temp;
1972             TF_RETURN_IF_ERROR(
1973                 reader->ReadScalar(full_name(kElementsProduced), &temp));
1974             elements_produced_ = static_cast<uint64>(temp);
1975           }
1976           size_t buffer_size;
1977           {
1978             int64 temp;
1979             TF_RETURN_IF_ERROR(reader->ReadScalar(
1980                 full_name(strings::StrCat(kBuffer, kSizeSuffix)), &temp));
1981             buffer_size = static_cast<size_t>(temp);
1982           }
1983           for (size_t i = 0; i < buffer_size; i++) {
1984             buffer_.emplace_back();
1985             auto& buffer_element = buffer_.back();
1986             size_t value_size;
1987             {
1988               int64 temp;
1989               TF_RETURN_IF_ERROR(reader->ReadScalar(
1990                   full_name(strings::StrCat(kBuffer, "[", i, "]", kSizeSuffix)),
1991                   &temp));
1992               value_size = static_cast<size_t>(temp);
1993             }
1994             if (reader->Contains(full_name(
1995                     strings::StrCat(kBuffer, "[", i, "].", kEndOfSequence)))) {
1996               buffer_element.end_of_sequence = true;
1997             } else {
1998               buffer_element.end_of_sequence = false;
1999             }
2000             buffer_element.value.reserve(value_size);
2001             for (size_t j = 0; j < value_size; j++) {
2002               buffer_element.value.emplace_back();
2003               TF_RETURN_IF_ERROR(reader->ReadTensor(
2004                   full_name(strings::StrCat(kBuffer, "[", i, "][", j, "]")),
2005                   &buffer_element.value.back()));
2006             }
2007           }
2008           // Since the last save we might have written out some files. So we
2009           // get a list of files in the directory and take the final filename
2010           // written. We use the name of the snapshot file to figure out
2011           // next_file_index_;
2012           std::vector<std::string> filenames;
2013           TF_RETURN_IF_ERROR(ctx->env()->GetMatchingPaths(
2014               absl::StrCat(absl::string_view(run_dir_), "/*"), &filenames));
2015           std::sort(filenames.begin(), filenames.end());
2016           std::string final_filename = filenames.back();
2017           std::vector<std::string> split_filename =
2018               absl::StrSplit(final_filename, '/');
2019           std::vector<std::string> split_snapshot_filename =
2020               absl::StrSplit(split_filename.back(), '.');
2021           std::string max_num_str = split_snapshot_filename[0];
2022           uint64 max_num;
2023           if (!strings::safe_strtou64(max_num_str, &max_num)) {
2024             return errors::Internal("Could not parse: ", max_num, " as uint64");
2025           }
2026           next_file_index_ = max_num + 1;
2027           TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kNumElementsWritten),
2028                                                 &num_elements_written_));
2029           size_t next_elem_size;
2030           {
2031             int64 temp;
2032             TF_RETURN_IF_ERROR(reader->ReadScalar(
2033                 full_name(strings::StrCat(kNextElem, kSizeSuffix)), &temp));
2034             next_elem_size = static_cast<size_t>(temp);
2035           }
2036           if (reader->Contains(
2037                   full_name(strings::StrCat(kNextElem, ".", kEndOfSequence)))) {
2038             next_elem_.end_of_sequence = true;
2039           } else {
2040             next_elem_.end_of_sequence = false;
2041           }
2042           next_elem_.value.reserve(next_elem_size);
2043           for (size_t i = 0; i < next_elem_size; i++) {
2044             next_elem_.value.emplace_back();
2045             TF_RETURN_IF_ERROR(reader->ReadTensor(
2046                 full_name(strings::StrCat(kNextElem, "[", i, "]")),
2047                 &next_elem_.value.back()));
2048           }
2049           VLOG(2) << "Restoring SnapshotWriterIterator: "
2050                   << num_elements_written_
2051                   << "; elements_produced: " << elements_produced_;
2052           return Status::OK();
2053         }
2054 
2055        private:
GetSnapshotFilename()2056         string GetSnapshotFilename() {
2057           mutex_lock l(mu_);
2058           string snapshot_data_filename = io::JoinPath(
2059               run_dir_, strings::Printf(
2060                             "%08llu.snapshot",
2061                             static_cast<unsigned long long>(next_file_index_)));
2062           next_file_index_++;
2063           return snapshot_data_filename;
2064         }
2065 
        // Prefetches one element from the input into next_elem_ and
        // queues a copy for the writer threads.
        //
        // On end of input, sets end_of_sequence_, wakes all threads, and
        // blocks until every writer thread has drained and exited.
        // Otherwise waits for space in buffer_ before pushing.
        Status FillBuffer(IteratorContext* ctx) TF_LOCKS_EXCLUDED(mu_) {
          snapshot_util::ElementOrEOF elem;
          TF_RETURN_IF_ERROR(
              input_impl_->GetNext(ctx, &elem.value, &elem.end_of_sequence));

          mutex_lock l(mu_);
          next_elem_ = std::move(elem);

          if (next_elem_.end_of_sequence) {
            end_of_sequence_ = true;
            cond_var_.notify_all();
            // Now we wait till all background threads finish.
            while (num_active_threads_ > 0) {
              cond_var_.wait(l);
            }
            return Status::OK();
          }

          // Wait for a space in the buffer_.
          while (!cancelled_ &&
                 buffer_.size() >= dataset()->writer_buffer_size_) {
            cond_var_.wait(l);
          }

          if (cancelled_) {
            return errors::Cancelled(
                "SnapshotDatasetOp::SnapshotWriterIterator::GetNext");
          }

          // Defensive check: the wait above should guarantee a free slot.
          if (buffer_.size() >= dataset()->writer_buffer_size_) {
            return errors::Internal(
                "Buffer size: ", buffer_.size(), " should be smaller than ",
                "maximum size: ", dataset()->writer_buffer_size_);
          }

          // A copy (not a move) is pushed so next_elem_ stays available
          // for the prefetch hand-off in GetNextInternal().
          snapshot_util::ElementOrEOF elem_copy = next_elem_;
          buffer_.push_back(elem_copy);
          cond_var_.notify_all();
          return Status::OK();
        }
2106 
        // Dequeues one element from `buffer_` and writes it to the current
        // snapshot shard file, rotating to a new shard file once
        // ShouldCloseWriter says the size limit is reached. Sets
        // *end_of_processing when the input is exhausted and the buffer has
        // drained; the first thread to observe that finalizes the metadata
        // file (guarded by written_final_metadata_file_).
        Status ProcessOneElement(Env* env, int64* bytes_written,
                                 string* snapshot_data_filename,
                                 std::unique_ptr<snapshot_util::Writer>* writer,
                                 bool* end_of_processing) {
          profiler::TraceMe activity(
              [&]() {
                return absl::StrCat(prefix(), kSeparator, kProcessOneElement);
              },
              profiler::TraceMeLevel::kInfo);
          bool cancelled = false;
          *end_of_processing = false;
          bool produced_elem = false;
          bool snapshot_failed = false;
          snapshot_util::ElementOrEOF elem;
          {
            mutex_lock l(mu_);
            // Wait for buffer to not be empty.
            while (!cancelled_ && buffer_.empty() && !end_of_sequence_ &&
                   !snapshot_failed_) {
              cond_var_.wait(l);
            }
            // Snapshot the shared flags into locals so all error handling and
            // file I/O below runs without holding `mu_`.
            cancelled = cancelled_;
            if (!buffer_.empty()) {
              produced_elem = true;
              std::swap(elem, buffer_.front());
              buffer_.pop_front();
              // Wake FillBuffer, which may be waiting for buffer space.
              cond_var_.notify_all();
            } else {
              *end_of_processing = end_of_sequence_;
            }
            snapshot_failed = snapshot_failed_;
          }

          if (cancelled || snapshot_failed) {
            // Close the current shard before reporting, so partial data is
            // flushed and the file handle is released.
            TF_RETURN_IF_ERROR((*writer)->Close());
            if (snapshot_failed) {
              return errors::Internal(
                  "SnapshotDataset::SnapshotWriterIterator snapshot failed");
            }
            return errors::Cancelled(
                "SnapshotDataset::SnapshotWriterIterator cancelled");
          }

          if (produced_elem) {
            for (const auto& out_tensor : elem.value) {
              *bytes_written += out_tensor.TotalBytes();
            }

            bool should_close;
            TF_RETURN_IF_ERROR(
                ShouldCloseWriter(env, *snapshot_data_filename, *bytes_written,
                                  (*writer).get(), &should_close));
            if (should_close) {
              // If we exceed the shard size, we get a new file and reset.
              TF_RETURN_IF_ERROR((*writer)->Close());
              *snapshot_data_filename = GetSnapshotFilename();

              TF_RETURN_IF_ERROR(snapshot_util::Writer::Create(
                  env, *snapshot_data_filename, dataset()->compression_,
                  kCurrentVersion, dataset()->output_dtypes(), writer));
              *bytes_written = 0;
            }
            TF_RETURN_IF_ERROR((*writer)->WriteTensors(elem.value));
            return Status::OK();
          }

          if (*end_of_processing) {
            TF_RETURN_IF_ERROR((*writer)->Close());
            mutex_lock l(mu_);
            if (!written_final_metadata_file_) {
              experimental::SnapshotMetadataRecord metadata;
              bool file_exists;
              TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(
                  env, hash_dir_, &metadata, &file_exists));

              // Only mark the snapshot finalized if the metadata on disk
              // still belongs to this run; otherwise another run overwrote
              // it and won the race.
              if (metadata.run_id() == run_id_) {
                metadata.set_finalized(true);
                TF_RETURN_IF_ERROR(snapshot_util::WriteMetadataFile(
                    env, hash_dir_, &metadata));
              } else {
                // TODO(frankchn): We lost the race, remove all snapshots.
              }
              written_final_metadata_file_ = true;
              cond_var_.notify_all();
            }
          }
          return Status::OK();
        }
2195 
2196         // Just pulls off elements from the buffer and writes them.
WriterThread(Env * env)2197         void WriterThread(Env* env) {
2198           auto cleanup = gtl::MakeCleanup([this]() {
2199             mutex_lock l(mu_);
2200             --num_active_threads_;
2201             cond_var_.notify_all();
2202           });
2203 
2204           int64 bytes_written = 0;
2205           string snapshot_data_filename = GetSnapshotFilename();
2206           std::unique_ptr<snapshot_util::Writer> writer;
2207           Status s = snapshot_util::Writer::Create(
2208               env, snapshot_data_filename, dataset()->compression_,
2209               kCurrentVersion, dataset()->output_dtypes(), &writer);
2210           if (!s.ok()) {
2211             LOG(ERROR) << "Creating " << snapshot_data_filename
2212                        << " failed: " << s.ToString();
2213             mutex_lock l(mu_);
2214             snapshot_failed_ = true;
2215             cond_var_.notify_all();
2216             return;
2217           }
2218 
2219           bool end_of_processing = false;
2220           while (!end_of_processing) {
2221             Status s =
2222                 ProcessOneElement(env, &bytes_written, &snapshot_data_filename,
2223                                   &writer, &end_of_processing);
2224             if (!s.ok()) {
2225               LOG(INFO) << "Error while writing snapshot data to disk: "
2226                         << s.ToString();
2227               mutex_lock l(mu_);
2228               snapshot_failed_ = true;
2229               cond_var_.notify_all();
2230               return;
2231             }
2232             mutex_lock l(mu_);
2233             num_elements_written_++;
2234           }
2235         }
2236 
        // Decides whether the current shard file should be closed and a new
        // one started. Once a compression ratio has been estimated it is used
        // to extrapolate on-disk size from `bytes_written`; before that, the
        // file is synced once `bytes_written` passes shard_size_bytes_ so the
        // actual ratio can be measured from the real file size.
        Status ShouldCloseWriter(Env* env, const string& filename,
                                 uint64 bytes_written,
                                 snapshot_util::Writer* writer,
                                 bool* should_close) {
          // If the compression ratio has been estimated, use it to decide
          // whether the file should be closed. We avoid estimating the
          // compression ratio repeatedly because it requires syncing the file,
          // which can be expensive.
          {
            tf_shared_lock l(mu_);
            if (compression_ratio_ > 0.0) {
              *should_close = bytes_written > (compression_ratio_ *
                                               dataset()->shard_size_bytes_);
              return Status::OK();
            }
          }
          // If the number of bytes written aren't shard_size_bytes_ yet, we
          // keep on going.
          if (bytes_written <= dataset()->shard_size_bytes_) {
            *should_close = false;
            return Status::OK();
          }
          // Use the actual file size to determine compression ratio.
          // Make sure that all bytes are written out.
          TF_RETURN_IF_ERROR(writer->Sync());
          uint64 file_size;
          TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));
          // NOTE(review): two threads can both pass the shared-lock check
          // before either sets compression_ratio_, so the estimate may be
          // computed (and overwritten) twice -- presumably benign; confirm.
          mutex_lock l(mu_);
          compression_ratio_ = static_cast<double>(bytes_written) /
                               static_cast<double>(file_size);
          LOG(INFO) << "Writing compression achieved: " << compression_ratio_;
          *should_close = file_size >= dataset()->shard_size_bytes_;
          return Status::OK();
        }
2271 
        mutex mu_;
        // This condition variable is notified
        // 1. By the background writer threads when an element from the buffer
        //    is consumed.
        // 2. By the main thread when it puts something into the buffer.
        // 3. By the main thread when the destructor is called to cancel.
        // 4. By the background writer threads when any error is encountered
        //    while writing.
        // 5. By the background threads when they finish.
        condition_variable cond_var_;

        // Most recently fetched input element; a copy is enqueued in buffer_
        // (see FillBuffer) and it is restored from checkpoints.
        snapshot_util::ElementOrEOF next_elem_ TF_GUARDED_BY(mu_);
        std::unique_ptr<IteratorBase> input_impl_;

        const string hash_dir_;
        tstring run_id_ TF_GUARDED_BY(mu_);
        tstring run_dir_ TF_GUARDED_BY(mu_);
        // Observed bytes-fed / bytes-on-disk ratio; 0.0 until first estimated
        // by ShouldCloseWriter.
        double compression_ratio_ TF_GUARDED_BY(mu_) = 0.0;
        bool is_restored_ TF_GUARDED_BY(mu_) = false;

        uint64 elements_produced_ TF_GUARDED_BY(mu_) = 0;
        int64 time_spent_micros_ TF_GUARDED_BY(mu_) = 0;
        int64 bytes_produced_ TF_GUARDED_BY(mu_) = 0;

        // Elements waiting to be written by the background writer threads;
        // bounded by dataset()->writer_buffer_size_.
        std::deque<snapshot_util::ElementOrEOF> buffer_ TF_GUARDED_BY(mu_);
        bool snapshot_failed_ TF_GUARDED_BY(mu_) = false;
        bool cancelled_ TF_GUARDED_BY(mu_) = false;
        bool first_call_ TF_GUARDED_BY(mu_) = true;
        bool end_of_sequence_ TF_GUARDED_BY(mu_) = false;
        bool written_final_metadata_file_ TF_GUARDED_BY(mu_) = false;
        // Index used to generate unique shard filenames.
        uint64 next_file_index_ TF_GUARDED_BY(mu_) = 0;
        std::unique_ptr<thread::ThreadPool> thread_pool_;
        // NOTE(review): incremented under mu_ in WriterThread but not marked
        // TF_GUARDED_BY(mu_) -- confirm whether the missing annotation is
        // intentional.
        int64 num_elements_written_ = 0;
2306       };
2307 
      // Iterator that performs no snapshot reading or writing: every call is
      // forwarded verbatim to the input dataset's iterator. Presumably
      // selected when the snapshot is neither being written nor read --
      // confirm against the mode-selection logic outside this view.
      class SnapshotPassthroughIterator : public DatasetIterator<Dataset> {
       public:
        explicit SnapshotPassthroughIterator(const Params& params)
            : DatasetIterator<Dataset>(params) {}

        Status Initialize(IteratorContext* ctx) override {
          return dataset()->input_->MakeIterator(ctx, this, prefix(),
                                                 &input_impl_);
        }

        // Delegates directly to the wrapped input iterator.
        Status GetNextInternal(IteratorContext* ctx,
                               std::vector<Tensor>* out_tensors,
                               bool* end_of_sequence) override {
          return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
        }

       protected:
        // Checkpointing also just delegates to the input iterator.
        Status SaveInternal(SerializationContext* ctx,
                            IteratorStateWriter* writer) override {
          return SaveInput(ctx, writer, input_impl_);
        }

        Status RestoreInternal(IteratorContext* ctx,
                               IteratorStateReader* reader) override {
          return RestoreInput(ctx, reader, input_impl_);
        }

       private:
        std::unique_ptr<IteratorBase> input_impl_;
      };
2338 
      // Directory containing the snapshot metadata file (see
      // ProcessOneElement); presumably derived from the dataset graph hash.
      string hash_dir_ TF_GUARDED_BY(mu_);
      // Whether this iterator is reading, writing, or passing through.
      snapshot_util::Mode state_ TF_GUARDED_BY(mu_);
      // The concrete reader / writer / passthrough iterator in use.
      std::unique_ptr<IteratorBase> iterator_ TF_GUARDED_BY(mu_);

      mutex mu_;
2344     };
2345 
    // The upstream dataset being snapshotted.
    const DatasetBase* const input_;
    // Root directory for the snapshot on disk.
    const tstring dir_;
    const string graph_hash_;

    const string reader_path_prefix_;
    const string writer_path_prefix_;
    // Compression scheme passed to snapshot_util::Writer::Create.
    const string compression_;

    // Target on-disk size at which a shard file is rotated
    // (see ShouldCloseWriter).
    const uint64 shard_size_bytes_;
    const uint64 pending_snapshot_expiry_seconds_;
    const uint64 num_reader_threads_;
    const uint64 reader_buffer_size_;
    const uint64 num_writer_threads_;
    // Maximum number of elements queued in the writer's buffer_
    // (see FillBuffer).
    const uint64 writer_buffer_size_;
    const bool shuffle_on_read_;

    const uint64 seed_;
    const uint64 seed2_;

    const std::string mode_;
    const std::string snapshot_name_;
2367   };
2368 
ComputeDatasetHash(const GraphDef & graph_def,const std::string & path,uint64 * hash)2369   Status ComputeDatasetHash(const GraphDef& graph_def, const std::string& path,
2370                             uint64* hash) {
2371     TF_RETURN_IF_ERROR(HashGraph(graph_def, hash));
2372     // Adding path, compression, reader / writer path prefix, shard size
2373     // bytes to the fp as they effect the data written on disk.
2374     *hash = Hash64Combine(*hash, Hash64(path));
2375     *hash = Hash64Combine(*hash, Hash64(compression_));
2376     *hash = Hash64Combine(*hash, Hash64(reader_path_prefix_));
2377     *hash = Hash64Combine(*hash, Hash64(writer_path_prefix_));
2378     *hash = Hash64Combine(*hash, shard_size_bytes_);
2379     return Status::OK();
2380   }
2381 
  const int graph_def_version_;
  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;

  // Op attributes mirrored into each Dataset instance; presumably parsed in
  // the kernel constructor, which is outside this view.
  string reader_path_prefix_;
  string writer_path_prefix_;
  string compression_;

  int64 shard_size_bytes_;
  int64 pending_snapshot_expiry_seconds_;
  int64 num_reader_threads_;
  int64 reader_buffer_size_;
  int64 num_writer_threads_;
  int64 writer_buffer_size_;
  bool shuffle_on_read_;

  int64 seed_;
  int64 seed2_;

  std::string mode_;
  std::string snapshot_name_;
2403 };
2404 
// Registers the (experimental, v1) SnapshotDataset kernel on CPU.
REGISTER_KERNEL_BUILDER(Name("SnapshotDataset").Device(DEVICE_CPU),
                        SnapshotDatasetOp);
2407 
2408 }  // namespace
2409 }  // namespace experimental
2410 }  // namespace data
2411 }  // namespace tensorflow
2412