/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/experimental/snapshot_dataset_op.h"

#include <random>

#include "absl/time/clock.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"  // NOLINT
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/experimental/snapshot_util.h"
#include "tensorflow/core/kernels/data/hash_utils.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/raw_coding.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/io/buffered_inputstream.h"
#include "tensorflow/core/lib/io/compression.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/io/random_inputstream.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/snappy.h"
#if !defined(IS_SLIM_BUILD)
#include "tensorflow/core/lib/io/snappy/snappy_inputbuffer.h"
#include "tensorflow/core/lib/io/snappy/snappy_outputbuffer.h"
#include "tensorflow/core/lib/io/zlib_compression_options.h"
#include "tensorflow/core/lib/io/zlib_inputstream.h"
#include "tensorflow/core/lib/io/zlib_outputbuffer.h"
#endif  // IS_SLIM_BUILD
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/base64.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/cord.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/protobuf/snapshot.pb.h"
#include "tensorflow/core/util/batch_util.h"
#include "tensorflow/core/util/ptr_util.h"

namespace tensorflow {
namespace data {
namespace experimental {

/* static */ constexpr const char* const SnapshotDatasetV2Op::kDatasetType;
/* static */ constexpr const char* const SnapshotDatasetV2Op::kOutputTypes;
/* static */ constexpr const char* const SnapshotDatasetV2Op::kOutputShapes;
/* static */ constexpr const char* const SnapshotDatasetV2Op::kCompression;
/* static */ constexpr const char* const SnapshotDatasetV2Op::kReaderPrefix;
/* static */ constexpr const char* const SnapshotDatasetV2Op::kWriterPrefix;
/* static */ constexpr const char* const SnapshotDatasetV2Op::kHashValid;
/* static */ constexpr const char* const SnapshotDatasetV2Op::kHash;
/* static */ constexpr const char* const SnapshotDatasetV2Op::kCompressionAuto;
/* static */ constexpr const char* const SnapshotDatasetV2Op::kReaderFunc;
/* static */ constexpr const char* const SnapshotDatasetV2Op::kShardFunc;
/* static */ constexpr const char* const
    SnapshotDatasetV2Op::kReaderFuncOtherArgs;
/* static */ constexpr const char* const
    SnapshotDatasetV2Op::kShardFuncOtherArgs;
/* static */ constexpr const char* const
    SnapshotDatasetV2Op::kReaderFuncTarguments;
/* static */ constexpr const char* const
    SnapshotDatasetV2Op::kShardFuncTarguments;
/* static */ constexpr const int SnapshotDatasetV2Op::kFileFormatVersion;

// ==== Snapshot Implementation ====

/* The current snapshot on-disk layout is as follows:
 *   /user/specified/path/
 *     - graphhash1/
 *       - snapshot.metadata  // metadata file
 *       - run1/
 *         - 00000000.shard/  // shard index
 *           // new checkpoint files are created on all threads at once, either
 *           // when a file gets too big, or when a TF checkpoint happens.
 *           - 00000000.snapshot  // checkpoint file 0
 *           - 00000001.snapshot  // checkpoint file 1
 *           - ...
 *         - 00000001.shard/
 *           - 00000000.snapshot
 *           - 00000001.snapshot
 *           - ...
 *         - 00000002.shard/
 *           - 00000000.snapshot
 *           - 00000001.snapshot
 *           - ...
 *         ...
 *       - run2/
 *         ...
 *     - graphhash2/
 *       ...
 *     - graphhash3/
 *       ...
 */
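
// As an illustrative sketch (the authoritative logic lives in snapshot_util),
// the directories above are composed roughly as:
//   hash_dir  = snapshot_util::HashDirectory(path, hash);       // graphhashN/
//   run_dir   = snapshot_util::RunDirectory(hash_dir, run_id);  // runN/
//   shard_dir = snapshot_util::ShardDirectory(run_dir, shard);  // N.shard/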

class SnapshotDatasetV2Op::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, const DatasetBase* input, uint64 hash,
          const std::string& path, const std::string& compression,
          const std::string& reader_prefix, const std::string& writer_prefix,
          std::unique_ptr<CapturedFunction> reader_func,
          std::unique_ptr<CapturedFunction> shard_func);

  ~Dataset() override;

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override;

  const DataTypeVector& output_dtypes() const override;

  const std::vector<PartialTensorShape>& output_shapes() const override;

  string DebugString() const override;

  int64 Cardinality() const override;

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override;

  Status CheckExternalState() const override;

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override;

 private:
  const DatasetBase* input_;
  const uint64 hash_;
  const tstring path_;
  const std::string compression_;
  const std::string reader_prefix_;
  const std::string writer_prefix_;

  std::unique_ptr<CapturedFunction> reader_func_;
  std::unique_ptr<CapturedFunction> shard_func_;

  class Iterator;
};

class SnapshotDatasetV2Op::Dataset::Iterator : public DatasetIterator<Dataset> {
 public:
  static constexpr const char* const kIteratorMode = "iterator_mode";
  static constexpr const char* const kIndex = "index";
  static constexpr const char* const kGraphHashDirectory =
      "graph_hash_directory";

  explicit Iterator(const Params& params);

  Status Initialize(IteratorContext* ctx) override;

  Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) override;

 protected:
  Status SaveInternal(SerializationContext* ctx,
                      IteratorStateWriter* writer) override;

  Status RestoreInternal(IteratorContext* ctx,
                         IteratorStateReader* reader) override;

 private:
  Status InitializeIterator(IteratorContext* ctx, IteratorStateReader* reader);

  int64 index_ TF_GUARDED_BY(mu_);
  std::unique_ptr<IteratorBase> iterator_ TF_GUARDED_BY(mu_);
  snapshot_util::Mode mode_ TF_GUARDED_BY(mu_);
  const std::string hash_dir_;

  mutex mu_;

  class Reader;
  class Writer;
  class Passthrough;
};
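
// The top-level Iterator defers choosing between reading, writing, and passing
// through until the first GetNext() call or a checkpoint restore, at which
// point it instantiates exactly one of the Reader / Writer / Passthrough
// iterators below (see InitializeIterator()).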

class SnapshotDatasetV2Op::Dataset::Iterator::Reader
    : public DatasetIterator<Dataset> {
 public:
  static constexpr const char* const kIteratorName = "Reader";

  explicit Reader(const Params& params, int64 start_index);

  ~Reader() override;

  Status Initialize(IteratorContext* ctx) override;

  Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) override;

 protected:
  Status SaveInternal(SerializationContext* ctx,
                      IteratorStateWriter* writer) override;

  Status RestoreInternal(IteratorContext* ctx,
                         IteratorStateReader* reader) override;

 private:
  const int64 start_index_;

  mutex mu_;

  std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);

  DatasetBase* input_ TF_GUARDED_BY(mu_);

  std::unique_ptr<InstantiatedCapturedFunction> instantiated_reader_func_
      TF_GUARDED_BY(mu_);
};
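
// The Reader hands the user-provided `reader_func` a nested dataset of
// per-shard snapshot files and expects a single flat dataset of elements in
// return; see Reader::Initialize() below for the exact contract.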

class SnapshotDatasetV2Op::Dataset::Iterator::Writer
    : public DatasetIterator<Dataset> {
 public:
  static constexpr const char* const kIteratorName = "Writer";
  static constexpr const char* const kRunId = "run_id";
  static constexpr const char* const kCurrentCheckpointId =
      "current_checkpoint_id";

  explicit Writer(const Params& params);

  ~Writer() override;

  Status Initialize(IteratorContext* ctx) override;

  Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) override;

 protected:
  Status SaveInternal(SerializationContext* ctx,
                      IteratorStateWriter* writer) override;

  Status RestoreInternal(IteratorContext* ctx,
                         IteratorStateReader* reader) override;

 private:
  Status GetShardIndex(IteratorContext* ctx, const std::vector<Tensor>& tensors,
                       int64* shard_index) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  Status WriteMetadataFile(Env* env, bool finalized)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  void SignalEOF(bool mark_closed) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  mutex mu_;
  mutex writer_status_mu_;
  std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);

  absl::flat_hash_map<int64, std::unique_ptr<snapshot_util::AsyncWriter>>
      writers_ TF_GUARDED_BY(mu_);
  Status writer_status_ TF_GUARDED_BY(writer_status_mu_);
  bool writers_closed_ TF_GUARDED_BY(mu_);

  uint64 run_id_ TF_GUARDED_BY(mu_);
  tstring run_dir_ TF_GUARDED_BY(mu_);

  // Stores the ID of the current checkpoint .snapshot file being written. See
  // top of this file for the directory layout.
  uint64 current_checkpoint_id_ TF_GUARDED_BY(mu_);

  std::unique_ptr<InstantiatedCapturedFunction> instantiated_shard_func_
      TF_GUARDED_BY(mu_);
};
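
// The Writer passes elements through unchanged while teeing each one to a
// per-shard snapshot_util::AsyncWriter, selected by running the user-provided
// `shard_func` on the element.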

class SnapshotDatasetV2Op::Dataset::Iterator::Passthrough
    : public DatasetIterator<Dataset> {
 public:
  static constexpr const char* const kIteratorName = "Passthrough";

  explicit Passthrough(const Params& params);

  Status Initialize(IteratorContext* ctx) override;

  Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) override;

 protected:
  Status SaveInternal(SerializationContext* ctx,
                      IteratorStateWriter* writer) override;

  Status RestoreInternal(IteratorContext* ctx,
                         IteratorStateReader* reader) override;

 private:
  std::unique_ptr<IteratorBase> input_impl_;
};

SnapshotDatasetV2Op::Dataset::Dataset(
    OpKernelContext* ctx, const DatasetBase* input, uint64 hash,
    const std::string& path, const std::string& compression,
    const std::string& reader_prefix, const std::string& writer_prefix,
    std::unique_ptr<CapturedFunction> reader_func,
    std::unique_ptr<CapturedFunction> shard_func)
    : DatasetBase(DatasetContext(ctx)),
      input_(input),
      hash_(hash),
      path_(path),
      compression_(compression),
      reader_prefix_(reader_prefix),
      writer_prefix_(writer_prefix),
      reader_func_(std::move(reader_func)),
      shard_func_(std::move(shard_func)) {
  input_->Ref();
}

SnapshotDatasetV2Op::Dataset::~Dataset() { input_->Unref(); }

std::unique_ptr<IteratorBase>
SnapshotDatasetV2Op::Dataset::MakeIteratorInternal(const string& prefix) const {
  return absl::make_unique<Iterator>(
      Iterator::Params{this, absl::StrCat(prefix, "::Snapshot")});
}

const DataTypeVector& SnapshotDatasetV2Op::Dataset::output_dtypes() const {
  return input_->output_dtypes();
}

const std::vector<PartialTensorShape>&
SnapshotDatasetV2Op::Dataset::output_shapes() const {
  return input_->output_shapes();
}

string SnapshotDatasetV2Op::Dataset::DebugString() const {
  return name_utils::DatasetDebugString(kDatasetType);
}

int64 SnapshotDatasetV2Op::Dataset::Cardinality() const {
  return input_->Cardinality();
}

Status SnapshotDatasetV2Op::Dataset::InputDatasets(
    std::vector<const DatasetBase*>* inputs) const {
  inputs->push_back(input_);
  return Status::OK();
}

Status SnapshotDatasetV2Op::Dataset::CheckExternalState() const {
  return input_->CheckExternalState();
}

Status SnapshotDatasetV2Op::Dataset::AsGraphDefInternal(
    SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const {
  Node* input_graph_node = nullptr;
  TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));

  Node* path = nullptr;
  TF_RETURN_IF_ERROR(b->AddScalar(path_, &path));

  std::vector<Node*> reader_func_other_args;
  DataTypeVector reader_func_other_args_types;
  TF_RETURN_IF_ERROR(reader_func_->AddToGraph(ctx, b, &reader_func_other_args,
                                              &reader_func_other_args_types));

  std::vector<Node*> shard_func_other_args;
  DataTypeVector shard_func_other_args_types;
  TF_RETURN_IF_ERROR(shard_func_->AddToGraph(ctx, b, &shard_func_other_args,
                                             &shard_func_other_args_types));

  AttrValue compression_attr;
  b->BuildAttrValue(compression_, &compression_attr);

  AttrValue reader_prefix_attr;
  b->BuildAttrValue(reader_prefix_, &reader_prefix_attr);

  AttrValue writer_prefix_attr;
  b->BuildAttrValue(writer_prefix_, &writer_prefix_attr);

  AttrValue hash_valid_attr;
  b->BuildAttrValue(true, &hash_valid_attr);

  AttrValue hash_attr;
  b->BuildAttrValue(static_cast<int64>(hash_), &hash_attr);

  AttrValue reader_func_attr;
  b->BuildAttrValue(reader_func_->func(), &reader_func_attr);

  AttrValue shard_func_attr;
  b->BuildAttrValue(shard_func_->func(), &shard_func_attr);

  AttrValue reader_func_arguments_types_attr;
  b->BuildAttrValue(reader_func_other_args_types,
                    &reader_func_arguments_types_attr);

  AttrValue shard_func_arguments_types_attr;
  b->BuildAttrValue(shard_func_other_args_types,
                    &shard_func_arguments_types_attr);

  return b->AddDataset(
      this,
      /*inputs=*/
      {std::make_pair(0, input_graph_node), std::make_pair(1, path)},
      /*list_inputs=*/
      {std::make_pair(2, reader_func_other_args),
       std::make_pair(3, shard_func_other_args)},
      /*attrs=*/
      {{kCompression, compression_attr},
       {kReaderPrefix, reader_prefix_attr},
       {kWriterPrefix, writer_prefix_attr},
       {kHashValid, hash_valid_attr},
       {kHash, hash_attr},
       {kReaderFunc, reader_func_attr},
       {kShardFunc, shard_func_attr},
       {kReaderFuncTarguments, reader_func_arguments_types_attr},
       {kShardFuncTarguments, shard_func_arguments_types_attr}},
      output);
}

SnapshotDatasetV2Op::Dataset::Iterator::Iterator(const Params& params)
    : DatasetIterator<Dataset>(params),
      index_(0),
      hash_dir_(
          snapshot_util::HashDirectory(dataset()->path_, dataset()->hash_)) {}

Status SnapshotDatasetV2Op::Dataset::Iterator::Initialize(
    IteratorContext* ctx) {
  return ctx->env()->RecursivelyCreateDir(
      io::JoinPath(dataset()->writer_prefix_, hash_dir_));
}

Status SnapshotDatasetV2Op::Dataset::Iterator::SaveInternal(
    SerializationContext* ctx, IteratorStateWriter* writer) {
  mutex_lock l(mu_);
  if (iterator_ != nullptr) {
    TF_RETURN_IF_ERROR(SaveInput(ctx, writer, iterator_));
    TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kIteratorMode),
                                           static_cast<int64>(mode_)));
    TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kIndex), index_));
    TF_RETURN_IF_ERROR(
        writer->WriteScalar(full_name(kGraphHashDirectory), hash_dir_));
  }

  return Status::OK();
}

Status SnapshotDatasetV2Op::Dataset::Iterator::RestoreInternal(
    IteratorContext* ctx, IteratorStateReader* reader) {
  mutex_lock l(mu_);

  if (reader->Contains(full_name(kIteratorMode))) {
    TF_RETURN_IF_ERROR(InitializeIterator(ctx, reader));
    return RestoreInput(ctx, reader, iterator_);
  }

  return Status::OK();
}

Status SnapshotDatasetV2Op::Dataset::Iterator::GetNextInternal(
    IteratorContext* ctx, std::vector<Tensor>* out_tensors,
    bool* end_of_sequence) {
  mutex_lock l(mu_);
  if (iterator_ == nullptr) {
    TF_RETURN_IF_ERROR(InitializeIterator(ctx, /*reader=*/nullptr));
  }
  index_++;
  return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
}

Status SnapshotDatasetV2Op::Dataset::Iterator::InitializeIterator(
    IteratorContext* ctx, IteratorStateReader* reader)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  if (reader != nullptr) {
    // Check whether the computed hash directory is the same.
    tstring hash_dir;
    TF_RETURN_IF_ERROR(
        reader->ReadScalar(full_name(kGraphHashDirectory), &hash_dir));
    if (hash_dir != hash_dir_) {
      return errors::DataLoss(
          "Dataset has changed while restoring from the checkpoint. Old hash "
          "directory: ",
          hash_dir, "; new hash directory: ", hash_dir_);
    }

    experimental::SnapshotMetadataRecord metadata;
    bool file_exists;
    TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(
        ctx->env(), io::JoinPath(dataset()->reader_prefix_, hash_dir_),
        &metadata, &file_exists));
    if (!file_exists) {
      return errors::DataLoss("Snapshot metadata file in ", hash_dir_,
                              " does not exist any more.");
    }

    int64 iterator_mode;
    TF_RETURN_IF_ERROR(
        reader->ReadScalar(full_name(kIteratorMode), &iterator_mode));
    mode_ = snapshot_util::Mode(iterator_mode);

    TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kIndex), &index_));
  } else {
    experimental::SnapshotMetadataRecord metadata;
    bool file_exists;
    TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(
        ctx->env(), io::JoinPath(dataset()->reader_prefix_, hash_dir_),
        &metadata, &file_exists));

    // `pending_snapshot_expiry_seconds` is a legacy option where we would not
    // write snapshots that we thought were still in progress. We decided that
    // this was not necessary as a feature for SnapshotV2, and we always write
    // a new snapshot regardless of whether someone else is currently writing
    // one. Setting this to 0 ensures that all previous snapshots will be
    // ignored and we will proceed to writing.
    TF_RETURN_IF_ERROR(snapshot_util::DetermineOpState(
        /*mode_string=*/"", file_exists, &metadata,
        /*pending_snapshot_expiry_seconds=*/0, &mode_));
  }

  switch (mode_) {
    case snapshot_util::READER:
      iterator_ = absl::make_unique<Reader>(
          Reader::Params{dataset(),
                         absl::StrCat(prefix(), Reader::kIteratorName)},
          index_);
      break;
    case snapshot_util::WRITER:
      iterator_ = absl::make_unique<Writer>(Writer::Params{
          dataset(), absl::StrCat(prefix(), Writer::kIteratorName)});
      break;
    case snapshot_util::PASSTHROUGH:
      iterator_ = absl::make_unique<Passthrough>(Passthrough::Params{
          dataset(), absl::StrCat(prefix(), Passthrough::kIteratorName)});
      break;
  }
  TF_RETURN_IF_ERROR(iterator_->InitializeBase(ctx, this));
  return iterator_->Initialize(ctx);
}

SnapshotDatasetV2Op::Dataset::Iterator::Reader::Reader(const Params& params,
                                                       int64 start_index)
    : DatasetIterator<Dataset>(params), start_index_(start_index) {}

SnapshotDatasetV2Op::Dataset::Iterator::Reader::~Reader() { input_->Unref(); }

Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize(
    IteratorContext* ctx) {
  mutex_lock l(mu_);

  TF_RETURN_IF_ERROR(
      dataset()->reader_func_->Instantiate(ctx, &instantiated_reader_func_));

  auto hash_dir = snapshot_util::HashDirectory(
      io::JoinPath(dataset()->reader_prefix_, dataset()->path_),
      dataset()->hash_);
  bool metadata_file_exists;
  experimental::SnapshotMetadataRecord metadata;
  TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(
      ctx->env(), hash_dir, &metadata, &metadata_file_exists));

  auto run_dir = snapshot_util::RunDirectory(hash_dir, metadata.run_id());

  std::vector<std::string> snapshot_shard_dirs;
  TF_RETURN_IF_ERROR(ctx->env()->GetMatchingPaths(
      io::JoinPath(
          run_dir,
          strings::Printf("%s%s", "*", snapshot_util::kShardDirectorySuffix)),
      &snapshot_shard_dirs));
  std::sort(snapshot_shard_dirs.begin(), snapshot_shard_dirs.end());

  DatasetBase* dataset_of_snapshot_files;
  TF_RETURN_IF_ERROR(snapshot_util::Reader::MakeNestedDataset(
      ctx->env(), snapshot_shard_dirs, dataset()->compression_,
      metadata.version(), dataset()->output_dtypes(),
      dataset()->output_shapes(), start_index_, &dataset_of_snapshot_files));

  Tensor input_dataset_tensor(DT_VARIANT, TensorShape({}));
  TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(dataset_of_snapshot_files,
                                                 &input_dataset_tensor));

  std::vector<Tensor> reader_input;
  std::vector<Tensor> reader_output;
  reader_input.push_back(std::move(input_dataset_tensor));

  // NOTE: We intentionally ignore resource modeling outside GetNext().
  TF_RETURN_IF_ERROR(instantiated_reader_func_->Run(
      ctx, std::move(reader_input), &reader_output, /*node=*/nullptr));
  if (reader_output.size() != 1) {
    return errors::InvalidArgument(
        "reader_func must return exactly one argument.");
  }
  TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(reader_output[0], &input_));

  // We need to take a reference here as we will use the input_ and
  // its iterator.
  input_->Ref();

  return input_->MakeIterator(ctx, this, prefix(), &input_impl_);
}

Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::GetNextInternal(
    IteratorContext* ctx, std::vector<Tensor>* out_tensors,
    bool* end_of_sequence) {
  mutex_lock l(mu_);
  return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
}

// We do not need to checkpoint the reader as we are rebuilding the reader
// datasets from information that is already saved by the main iterator.
Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::SaveInternal(
    SerializationContext* ctx, IteratorStateWriter* writer) {
  return Status::OK();
}

Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::RestoreInternal(
    IteratorContext* ctx, IteratorStateReader* reader) {
  return Status::OK();
}

SnapshotDatasetV2Op::Dataset::Iterator::Writer::Writer(const Params& params)
    : DatasetIterator<Dataset>(params),
      writers_closed_(false),
      run_id_(0),
      current_checkpoint_id_(0) {}

SnapshotDatasetV2Op::Dataset::Iterator::Writer::~Writer() {
  mutex_lock l(mu_);
  SignalEOF(true);
}

void SnapshotDatasetV2Op::Dataset::Iterator::Writer::SignalEOF(bool mark_closed)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  if (!writers_closed_) {
    // Push the end of sequence signal to each of the threads to close files.
    for (auto& writer : writers_) {
      writer.second->SignalEOF();
    }

    writers_.clear();
    writers_closed_ = mark_closed;
  }
}

Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::WriteMetadataFile(
    Env* env, bool finalized) {
  DCHECK(!run_dir_.empty());

  experimental::SnapshotMetadataRecord metadata;
  metadata.set_creation_timestamp(EnvTime::NowMicros());
  metadata.set_graph_hash(strings::StrCat(dataset()->hash_));
  metadata.set_run_id(strings::StrCat(run_id_));
  metadata.set_version(kFileFormatVersion);
  for (const auto& output_dtype : dataset()->output_dtypes()) {
    metadata.add_dtype(output_dtype);
  }
  metadata.set_finalized(finalized);
  tstring hash_directory = io::JoinPath(
      dataset()->writer_prefix_,
      snapshot_util::HashDirectory(dataset()->path_, dataset()->hash_));

  return snapshot_util::WriteMetadataFile(env, hash_directory, &metadata);
}
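
// The metadata file is first written with finalized=false when a run starts
// and rewritten with finalized=true once the input is exhausted (see
// GetNextInternal() below); snapshot_util::DetermineOpState uses this flag to
// decide whether an existing snapshot is complete enough to read.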

Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::Initialize(
    IteratorContext* ctx) {
  mutex_lock l(mu_);
  TF_RETURN_IF_ERROR(
      dataset()->shard_func_->Instantiate(ctx, &instantiated_shard_func_));

  return dataset()->input_->MakeIterator(
      ctx, this, strings::StrCat(prefix(), "::WriterIterator"), &input_impl_);
}

Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::GetShardIndex(
    IteratorContext* ctx, const std::vector<Tensor>& tensors,
    int64* shard_index) {
  std::vector<Tensor> output_tensors;

  // Run the shard function.
  TF_RETURN_IF_ERROR(instantiated_shard_func_->RunWithBorrowedArgs(
      ctx, tensors, &output_tensors, model_node()));

  if (output_tensors.size() != 1 || output_tensors[0].dtype() != DT_INT64 ||
      output_tensors[0].NumElements() != 1) {
    return errors::InvalidArgument("`shard_func` must return a scalar int64.");
  }

  // Extract the shard index; writable files for indices we have not seen yet
  // are created lazily in GetNextInternal().
  *shard_index = output_tensors[0].flat<int64>()(0);
  return Status::OK();
}
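
// Any shard function that maps an element to a scalar DT_INT64 works here; a
// hypothetical round-robin shard_func that returns `element_index % N`, for
// instance, would spread elements evenly across N shard writers.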

Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::GetNextInternal(
    IteratorContext* ctx, std::vector<Tensor>* out_tensors,
    bool* end_of_sequence) {
  *end_of_sequence = false;
  snapshot_util::AsyncWriter* current_writer;

  {
    std::vector<Tensor> output_tensors;
    mutex_lock l(mu_);

    // We initialize late here because restoring from a checkpoint comes after
    // the Initialize call. We cannot initialize within Initialize() because
    // we cannot determine whether we should overwrite an existing metadata
    // file or not before `RestoreInternal` is potentially called.
    if (run_dir_.empty()) {
      run_id_ = random::New64();

      // Creates the run directory.
      run_dir_ = snapshot_util::RunDirectory(
          snapshot_util::HashDirectory(
              io::JoinPath(dataset()->writer_prefix_, dataset()->path_),
              dataset()->hash_),
          run_id_);
      TF_RETURN_IF_ERROR(ctx->env()->RecursivelyCreateDir(run_dir_));
      TF_RETURN_IF_ERROR(WriteMetadataFile(ctx->env(), /*finalized=*/false));
    }

    // Writers have either encountered an error or are closed.
    {
      mutex_lock wsl(writer_status_mu_);
      if (!writer_status_.ok() || writers_closed_) {
        *end_of_sequence = true;
        return writer_status_;
      }
    }

    TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, out_tensors, end_of_sequence));

    // Finalize the metadata file when we are at the end of the iterator.
    if (*end_of_sequence) {
      SignalEOF(/*mark_closed=*/true);
      {
        mutex_lock wsl(writer_status_mu_);
        TF_RETURN_IF_ERROR(writer_status_);
      }
      return WriteMetadataFile(ctx->env(), /*finalized=*/true);
    }

    int64 shard_index = 0;
    TF_RETURN_IF_ERROR(GetShardIndex(ctx, *out_tensors, &shard_index));

    // If the index does not exist, we will start a new thread.
    if (writers_.count(shard_index) == 0) {
      auto snapshot_shard_directory =
          snapshot_util::ShardDirectory(run_dir_, shard_index);
      auto writer = std::make_unique<snapshot_util::AsyncWriter>(
          ctx->env(), shard_index, snapshot_shard_directory,
          current_checkpoint_id_, dataset()->compression_, kFileFormatVersion,
          dataset()->output_dtypes(), [this](Status s) {
            if (!s.ok()) {
              LOG(ERROR) << "AsyncWriter in snapshot writer failed: " << s;
              mutex_lock l(writer_status_mu_);
              writer_status_ = s;
            }
          });
      writers_.insert({shard_index, std::move(writer)});
    }
    current_writer = writers_[shard_index].get();
  }

  current_writer->Write(*out_tensors);
  return Status::OK();
}

Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::SaveInternal(
    SerializationContext* ctx, IteratorStateWriter* writer) {
  mutex_lock l(mu_);
  TF_RETURN_IF_ERROR(
      writer->WriteScalar(full_name(kRunId), static_cast<int64>(run_id_)));
  TF_RETURN_IF_ERROR(
      writer->WriteScalar(full_name(kCurrentCheckpointId),
                          static_cast<int64>(current_checkpoint_id_)));
  SignalEOF(/*mark_closed=*/false);
  writers_.clear();
  current_checkpoint_id_++;
  return SaveInput(ctx, writer, input_impl_);
}

Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::RestoreInternal(
    IteratorContext* ctx, IteratorStateReader* reader) {
  mutex_lock l(mu_);
  int64 run_id_signed;
  int64 current_checkpoint_id;

  TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kRunId), &run_id_signed));
  TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCurrentCheckpointId),
                                        &current_checkpoint_id));

  run_id_ = static_cast<uint64>(run_id_signed);
  run_dir_ = snapshot_util::RunDirectory(
      snapshot_util::HashDirectory(
          io::JoinPath(dataset()->writer_prefix_, dataset()->path_),
          dataset()->hash_),
      run_id_);
  current_checkpoint_id_ = static_cast<uint64>(current_checkpoint_id);

  return RestoreInput(ctx, reader, input_impl_);
}

SnapshotDatasetV2Op::Dataset::Iterator::Passthrough::Passthrough(
    const Params& params)
    : DatasetIterator<Dataset>(params) {}

Status SnapshotDatasetV2Op::Dataset::Iterator::Passthrough::Initialize(
    IteratorContext* ctx) {
  return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
}

Status SnapshotDatasetV2Op::Dataset::Iterator::Passthrough::GetNextInternal(
    IteratorContext* ctx, std::vector<Tensor>* out_tensors,
    bool* end_of_sequence) {
  return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
}

Status SnapshotDatasetV2Op::Dataset::Iterator::Passthrough::SaveInternal(
    SerializationContext* ctx, IteratorStateWriter* writer) {
  return SaveInput(ctx, writer, input_impl_);
}

Status SnapshotDatasetV2Op::Dataset::Iterator::Passthrough::RestoreInternal(
    IteratorContext* ctx, IteratorStateReader* reader) {
  return RestoreInput(ctx, reader, input_impl_);
}

SnapshotDatasetV2Op::SnapshotDatasetV2Op(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
  FunctionMetadata::Params reader_params;
  FunctionMetadata::Params shard_params;

  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kCompression, &compression_));

  if (ctx->HasAttr(kReaderPrefix)) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kReaderPrefix, &reader_prefix_));
  }

  if (ctx->HasAttr(kWriterPrefix)) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kWriterPrefix, &writer_prefix_));
  }
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kHashValid, &hash_valid_));
  int64 hash;
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kHash, &hash));
  hash_ = static_cast<uint64>(hash);

  OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kReaderFunc, reader_params,
                                               &reader_func_metadata_));
  OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kShardFunc, shard_params,
                                               &shard_func_metadata_));
}

void SnapshotDatasetV2Op::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                      DatasetBase** output) {
  tstring path;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "path", &path));

  std::string compression = compression_ == kCompressionAuto
                                ? io::compression::kSnappy
                                : compression_;
  uint64 hash;
  if (hash_valid_) {
    hash = hash_;
  } else {
    // Computes the hash of the preceding items in the graph.
    GraphDef graph_def;
    SerializationContext::Params params;
    std::vector<std::pair<string, Tensor>> input_list;
    params.input_list = &input_list;
    params.external_state_policy =
        SerializationContext::ExternalStatePolicy::kIgnore;
    OP_REQUIRES_OK(
        ctx, AsGraphDef(ctx, input, SerializationContext(params), &graph_def));
    OP_REQUIRES_OK(ctx, HashGraph(graph_def, &hash));
    // Different compression modes should result in different graph hashes.
    hash = Hash64Combine(hash, Hash64(compression));
  }

  std::unique_ptr<CapturedFunction> reader_func;
  OP_REQUIRES_OK(ctx,
                 CapturedFunction::Create(ctx, reader_func_metadata_,
                                          kReaderFuncOtherArgs, &reader_func));
  std::unique_ptr<CapturedFunction> shard_func;
  OP_REQUIRES_OK(ctx,
                 CapturedFunction::Create(ctx, shard_func_metadata_,
                                          kShardFuncOtherArgs, &shard_func));

  *output = new SnapshotDatasetV2Op::Dataset(
      ctx, input, hash, path, compression, reader_prefix_, writer_prefix_,
      std::move(reader_func), std::move(shard_func));
}

namespace {
REGISTER_KERNEL_BUILDER(Name("SnapshotDatasetV2").Device(DEVICE_CPU),
                        SnapshotDatasetV2Op);
}  // namespace

// ==== Legacy Snapshot Implementation (Deprecated) ====

namespace {

// Defaults to 10 GiB per shard.
const int64 kDefaultShardSizeBytes = 10LL * 1024 * 1024 * 1024;

const int64 kCurrentVersion = 1;

constexpr char kSnapshotReaderWorkerPool[] = "snapshot_reader_worker_pool";
constexpr char kSnapshotWriterWorkerPool[] = "snapshot_writer_worker_pool";
constexpr char kSeparator[] = "::";
constexpr char kBookkeeping[] = "Bookkeeping";
constexpr char kSnapshotReadElements[] = "snapshot_read_elements";
constexpr char kSnapshotReadThroughput[] = "snapshot_read_throughput";
constexpr char kSnapshotWrittenElements[] = "snapshot_written_elements";
constexpr char kSnapshotWriteThroughput[] = "snapshot_write_throughput";

constexpr char kSizeSuffix[] = "_size";
constexpr char kState[] = "state";
constexpr char kHashDir[] = "hash_dir";
constexpr char kRunId[] = "run_id";
constexpr char kRunDir[] = "run_dir";
constexpr char kVersionStr[] = "version";
constexpr char kFilenames[] = "filenames";
constexpr char kCurrentFilenames[] = "current_filenames";
constexpr char kElementsProduced[] = "elements_produced";
constexpr char kNextFileIndex[] = "next_file_index";
constexpr char kNumFilesDone[] = "num_files_done";
constexpr char kNumElementsRead[] = "num_elements_read";
constexpr char kStatus[] = "status";
constexpr char kCode[] = ".code";
constexpr char kErrorMessage[] = ".error_message";
constexpr char kEndOfSequence[] = "end_of_sequence";
constexpr char kBuffer[] = "buffer";
constexpr char kNumElementsWritten[] = "num_elements_written";
constexpr char kNextElem[] = "next_elem";
class SnapshotDatasetOp : public UnaryDatasetOpKernel {
 public:
  explicit SnapshotDatasetOp(OpKernelConstruction* ctx)
      : UnaryDatasetOpKernel(ctx),
        graph_def_version_(ctx->graph_def_version()) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));

    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("reader_path_prefix", &reader_path_prefix_));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("writer_path_prefix", &writer_path_prefix_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("compression", &compression_));

    OP_REQUIRES_OK(ctx, ctx->GetAttr("shard_size_bytes", &shard_size_bytes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("pending_snapshot_expiry_seconds",
                                     &pending_snapshot_expiry_seconds_));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("num_reader_threads", &num_reader_threads_));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("reader_buffer_size", &reader_buffer_size_));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("num_writer_threads", &num_writer_threads_));
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("writer_buffer_size", &writer_buffer_size_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shuffle_on_read", &shuffle_on_read_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("seed", &seed_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("seed2", &seed2_));

    mode_ = snapshot_util::kModeAuto;
    if (ctx->HasAttr("mode")) {
      OP_REQUIRES_OK(ctx, ctx->GetAttr("mode", &mode_));
    }

    snapshot_name_ = "";
    if (ctx->HasAttr("snapshot_name")) {
      OP_REQUIRES_OK(ctx, ctx->GetAttr("snapshot_name", &snapshot_name_));
    }

    if (shard_size_bytes_ == -1) shard_size_bytes_ = kDefaultShardSizeBytes;

    // Default to 1 day expiry for snapshots.
    if (pending_snapshot_expiry_seconds_ == -1) {
      pending_snapshot_expiry_seconds_ = 86400;
    }

    if (num_reader_threads_ == -1) num_reader_threads_ = 1;
    if (reader_buffer_size_ == -1) reader_buffer_size_ = 1;
    if (num_writer_threads_ == -1) num_writer_threads_ = 1;
    if (writer_buffer_size_ == -1) writer_buffer_size_ = 1;

    OP_REQUIRES(
        ctx,
        compression_ == io::compression::kNone ||
            compression_ == io::compression::kGzip ||
            compression_ == io::compression::kSnappy,
        errors::InvalidArgument("compression must be either '', 'GZIP' or "
                                "'SNAPPY'."));

    OP_REQUIRES(
        ctx, pending_snapshot_expiry_seconds_ >= 1,
        errors::InvalidArgument(
            "pending_snapshot_expiry_seconds must be at least 1 second."));

    OP_REQUIRES(ctx,
                mode_ == snapshot_util::kModeAuto ||
                    mode_ == snapshot_util::kModeRead ||
                    mode_ == snapshot_util::kModeWrite ||
                    mode_ == snapshot_util::kModePassthrough,
                errors::InvalidArgument(
                    "mode must be either '", snapshot_util::kModeAuto, "', '",
                    snapshot_util::kModeRead, "', '", snapshot_util::kModeWrite,
                    "', or '", snapshot_util::kModePassthrough, "'."));
  }

 protected:
  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                   DatasetBase** output) override {
    tstring path;

    OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "path", &path));

    SerializationContext::Params params;
    std::vector<std::pair<string, Tensor>> input_list;
    params.input_list = &input_list;
    params.external_state_policy =
        SerializationContext::ExternalStatePolicy::kIgnore;

    GraphDef graph_def;
    OP_REQUIRES_OK(
        ctx, AsGraphDef(ctx, input, SerializationContext(params), &graph_def));

    uint64 hash;
    OP_REQUIRES_OK(ctx, ComputeDatasetHash(graph_def, path, &hash));

    Status dump_status =
        snapshot_util::DumpDatasetGraph(ctx->env(), path, hash, &graph_def);
    if (!dump_status.ok()) {
      LOG(WARNING) << "Unable to write graphdef to disk, error: "
                   << dump_status.ToString();
    }

    std::string graph_hash =
        strings::StrCat(strings::Hex(hash, strings::kZeroPad16));
    LOG(INFO) << "Graph def serialized to hash: " << graph_hash;

    *output = new Dataset(ctx, input, path, graph_hash, reader_path_prefix_,
                          writer_path_prefix_, compression_, shard_size_bytes_,
                          pending_snapshot_expiry_seconds_, num_reader_threads_,
                          reader_buffer_size_, num_writer_threads_,
                          writer_buffer_size_, shuffle_on_read_, seed_, seed2_,
                          mode_, snapshot_name_);
  }

 private:
  class Dataset : public DatasetBase {
   public:
    Dataset(OpKernelContext* ctx, const DatasetBase* input, const string& path,
            const string& graph_hash, const string& reader_path_prefix,
            const string& writer_path_prefix, const string& compression,
            const uint64 shard_size_bytes,
            const uint64 pending_snapshot_expiry_seconds,
            const uint64 num_reader_threads, const uint64 reader_buffer_size,
            const uint64 num_writer_threads, const uint64 writer_buffer_size,
            const bool shuffle_on_read, const uint64 seed, const uint64 seed2,
            const std::string& mode, const std::string& snapshot_name)
        : DatasetBase(DatasetContext(ctx)),
          input_(input),
          dir_(path),
          graph_hash_(graph_hash),
          reader_path_prefix_(reader_path_prefix),
          writer_path_prefix_(writer_path_prefix),
          compression_(compression),
          shard_size_bytes_(shard_size_bytes),
          pending_snapshot_expiry_seconds_(pending_snapshot_expiry_seconds),
          num_reader_threads_(num_reader_threads),
          reader_buffer_size_(reader_buffer_size),
          num_writer_threads_(num_writer_threads),
          writer_buffer_size_(writer_buffer_size),
          shuffle_on_read_(shuffle_on_read),
          seed_(seed),
          seed2_(seed2),
          mode_(mode),
          snapshot_name_(snapshot_name) {
      input_->Ref();
    }

    ~Dataset() override { input_->Unref(); }

    std::unique_ptr<IteratorBase> MakeIteratorInternal(
        const string& prefix) const override {
      return absl::make_unique<Iterator>(
          Iterator::Params{this, absl::StrCat(prefix, "::Snapshot")});
    }

    const DataTypeVector& output_dtypes() const override {
      return input_->output_dtypes();
    }

    const std::vector<PartialTensorShape>& output_shapes() const override {
      return input_->output_shapes();
    }

    string DebugString() const override { return "SnapshotDatasetOp::Dataset"; }

    int64 Cardinality() const override { return input_->Cardinality(); }

    Status InputDatasets(
        std::vector<const DatasetBase*>* inputs) const override {
      inputs->push_back(input_);
      return Status::OK();
    }

    Status CheckExternalState() const override {
      return input_->CheckExternalState();
    }

   protected:
    Status AsGraphDefInternal(SerializationContext* ctx,
                              DatasetGraphDefBuilder* b,
                              Node** output) const override {
      Node* input_graph_node = nullptr;
      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));

      Node* path = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(dir_, &path));

      AttrValue compression_attr;
      b->BuildAttrValue(compression_, &compression_attr);

      AttrValue reader_path_prefix_attr;
      b->BuildAttrValue(reader_path_prefix_, &reader_path_prefix_attr);

      AttrValue writer_path_prefix_attr;
      b->BuildAttrValue(writer_path_prefix_, &writer_path_prefix_attr);

      AttrValue shard_size_bytes_attr;
      b->BuildAttrValue<int64>(shard_size_bytes_, &shard_size_bytes_attr);

      AttrValue pending_snapshot_expiry_seconds_attr;
      b->BuildAttrValue<int64>(pending_snapshot_expiry_seconds_,
                               &pending_snapshot_expiry_seconds_attr);

      AttrValue num_reader_threads_attr;
      b->BuildAttrValue<int64>(num_reader_threads_, &num_reader_threads_attr);

      AttrValue reader_buffer_size_attr;
      b->BuildAttrValue<int64>(reader_buffer_size_, &reader_buffer_size_attr);

      AttrValue num_writer_threads_attr;
      b->BuildAttrValue<int64>(num_writer_threads_, &num_writer_threads_attr);

      AttrValue writer_buffer_size_attr;
      b->BuildAttrValue<int64>(writer_buffer_size_, &writer_buffer_size_attr);

      AttrValue shuffle_on_read_attr;
      b->BuildAttrValue<bool>(shuffle_on_read_, &shuffle_on_read_attr);

      AttrValue seed_attr;
      b->BuildAttrValue<int64>(seed_, &seed_attr);

      AttrValue seed2_attr;
      b->BuildAttrValue<int64>(seed2_, &seed2_attr);

      AttrValue mode_attr;
      b->BuildAttrValue(mode_, &mode_attr);

      AttrValue snapshot_name_attr;
      b->BuildAttrValue(snapshot_name_, &snapshot_name_attr);

      TF_RETURN_IF_ERROR(b->AddDataset(
          this,
          /*inputs=*/
          {std::make_pair(0, input_graph_node), std::make_pair(1, path)},
          /*list_inputs=*/
          {},
          /*attrs=*/
          {{"compression", compression_attr},
           {"reader_path_prefix", reader_path_prefix_attr},
           {"writer_path_prefix", writer_path_prefix_attr},
           {"shard_size_bytes", shard_size_bytes_attr},
           {"pending_snapshot_expiry_seconds",
            pending_snapshot_expiry_seconds_attr},
           {"num_reader_threads", num_reader_threads_attr},
           {"reader_buffer_size", reader_buffer_size_attr},
           {"num_writer_threads", num_writer_threads_attr},
           {"writer_buffer_size", writer_buffer_size_attr},
           {"shuffle_on_read", shuffle_on_read_attr},
           {"seed", seed_attr},
           {"seed2", seed2_attr},
           {"mode", mode_attr},
           {"snapshot_name", snapshot_name_attr}},
          output));
      return Status::OK();
    }

   private:
    class Iterator : public DatasetIterator<Dataset> {
     public:
      explicit Iterator(const Params& params)
          : DatasetIterator<Dataset>(params) {
        if (dataset()->snapshot_name_.empty()) {
          hash_dir_ = io::JoinPath(dataset()->dir_, dataset()->graph_hash_);
        } else {
          hash_dir_ = io::JoinPath(
              dataset()->dir_,
              strings::StrCat("custom-", dataset()->snapshot_name_));
        }
      }

      // We have a somewhat non-traditional pattern for iterator
      // initialization for Snapshot. The protocol is that we initialize the
      // Reader / Writer iterator on the first GetNext call, and we invoke
      // the same initialization code when restoring. The reason we don't do
      // this during the Initialize call is that during Restore we call
      // Initialize first, and at that point we don't yet know which iterator
      // (Reader / Writer / Passthrough) we need to restore, as this info is
      // part of the checkpoint.
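      // Concretely, RestoreInternal() below first restores `state_` and the
      // hash directory, then calls InitializeIterator() to build the matching
      // Reader / Writer / Passthrough iterator, and finally restores that
      // iterator's own state via RestoreInput().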
      Status Initialize(IteratorContext* ctx) override { return Status::OK(); }

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        if (iterator_ == nullptr) {
          experimental::SnapshotMetadataRecord metadata;
          bool file_exists;
          TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(
              ctx->env(), hash_dir_, &metadata, &file_exists));
          TF_RETURN_IF_ERROR(snapshot_util::DetermineOpState(
              dataset()->mode_, file_exists, &metadata,
              dataset()->pending_snapshot_expiry_seconds_, &state_));
          VLOG(2) << "Snapshot state: " << state_;
          TF_RETURN_IF_ERROR(InitializeIterator(ctx, metadata));
        }
        return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
      }

     protected:
      Status SaveInternal(SerializationContext* ctx,
                          IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, iterator_));
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kState),
                                               static_cast<int64>(state_)));
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kHashDir), hash_dir_));
        VLOG(2) << "Saving Snapshot iterator: " << state_;
        return Status::OK();
      }

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        tstring hash_dir;
        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kHashDir), &hash_dir));
        if (hash_dir != hash_dir_) {
          LOG(ERROR) << "Dataset has changed while restoring from the "
                        "checkpoint. Old hash: "
                     << hash_dir << "; new hash: " << hash_dir_;
          return Status::OK();
        }
        {
          int64 temp;
          TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kState), &temp));
          state_ = snapshot_util::Mode(temp);
        }
        experimental::SnapshotMetadataRecord metadata;
        bool file_exists;
        TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(
            ctx->env(), hash_dir_, &metadata, &file_exists));
        TF_RETURN_IF_ERROR(InitializeIterator(ctx, metadata));
        VLOG(2) << "Restoring Snapshot iterator: " << state_;
        return RestoreInput(ctx, reader, iterator_);
      }

      // This method expects that state_ is populated and it will create the
      // correct Reader / Writer / Passthrough iterator and initialize it.
      Status InitializeIterator(
          IteratorContext* ctx,
          const experimental::SnapshotMetadataRecord& metadata)
          TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        std::string run_id = "";
        if (!dataset()->snapshot_name_.empty()) {
          // We have overridden the snapshot with a custom name, so we don't
          // generate random run ids, but just use the same one.
          run_id = "custom";
        }

        switch (state_) {
          case snapshot_util::WRITER:
            iterator_ = absl::make_unique<SnapshotWriterIterator>(
                SnapshotWriterIterator::Params{
                    dataset(), absl::StrCat(prefix(), "WriterImpl")},
                hash_dir_, run_id);
            break;
          case snapshot_util::READER:
            if (run_id.empty() && metadata.run_id().empty()) {
              return errors::NotFound(
                  "Could not find a valid snapshot to read.");
            }
            if (run_id.empty()) {
              run_id = metadata.run_id();
            }
            // dtypes in metadata should be the same as dataset()->output_dtypes
            if (metadata.dtype_size() != dataset()->output_dtypes().size()) {
              return errors::Internal(
                  "Expected number of dtypes: ",
                  dataset()->output_dtypes().size(),
                  " but number in snapshot: ", metadata.dtype_size());
            }
            for (int i = 0; i < metadata.dtype_size(); ++i) {
              if (metadata.dtype(i) != dataset()->output_dtypes()[i]) {
                return errors::Internal(
                    "Type: ", i,
                    " doesn't match. Snapshot: ", metadata.dtype(i),
                    "; dataset: ", dataset()->output_dtypes()[i]);
              }
            }
            iterator_ = absl::make_unique<SnapshotReaderIterator>(
                SnapshotReaderIterator::Params{
                    dataset(), absl::StrCat(prefix(), "ReaderImpl")},
                hash_dir_, run_id, metadata.version());
            break;
          case snapshot_util::PASSTHROUGH:
            iterator_ = absl::make_unique<SnapshotPassthroughIterator>(
                SnapshotPassthroughIterator::Params{
                    dataset(), absl::StrCat(prefix(), "PassthroughImpl")});
            break;
        }
        TF_RETURN_IF_ERROR(iterator_->InitializeBase(ctx, this));
        return iterator_->Initialize(ctx);
      }

     protected:
      class SnapshotReaderIterator : public DatasetIterator<Dataset> {
       public:
        static constexpr const char* const kParse = "Parse";

        explicit SnapshotReaderIterator(const Params& params,
                                        const string& hash_dir,
                                        const string& run_id, int64 version)
            : DatasetIterator<Dataset>(params),
              hash_dir_(hash_dir),
              run_id_(run_id),
              version_(version) {}

        ~SnapshotReaderIterator() override {
          mutex_lock l(mu_);
          cancelled_ = true;
          cond_var_.notify_all();
          while (num_active_threads_ > 0) {
            cond_var_.wait(l);
          }
        }

        Status Initialize(IteratorContext* ctx) override {
          mutex_lock l(mu_);
          thread_pool_ = ctx->CreateThreadPool(kSnapshotReaderWorkerPool,
                                               dataset()->num_reader_threads_);
          run_dir_ = io::JoinPath(hash_dir_, run_id_);
          // Get all the files in the run_dir.
          std::vector<std::string> filenames_str;
          TF_RETURN_IF_ERROR(ctx->env()->GetMatchingPaths(
              absl::StrCat(absl::string_view(run_dir_), "/*"),
              &filenames_str));
          filenames_.resize(filenames_str.size());
          std::copy(filenames_str.begin(), filenames_str.end(),
                    filenames_.begin());
          if (filenames_.empty()) {
            return errors::NotFound("Could not find any files in dir: ",
                                    run_dir_);
          }

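          // Shuffle the filenames; if both seeds are zero, draw a fresh
          // random seed so that each run reads the files in a different
          // order.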
          if (dataset()->shuffle_on_read_) {
            uint64 seed = dataset()->seed_ + dataset()->seed2_;
            if (dataset()->seed_ == 0 && dataset()->seed2_ == 0) {
              seed = random::New64();
            }

            std::mt19937 rng(seed);
            std::shuffle(filenames_.begin(), filenames_.end(), rng);
          } else {
            std::sort(filenames_.begin(), filenames_.end());
          }

          for (auto i = 0; i < dataset()->num_reader_threads_; ++i) {
            curr_filenames_.push_back(GetNextFilename());
          }
          return Status::OK();
        }

        Status GetNextInternal(IteratorContext* ctx,
                               std::vector<Tensor>* out_tensors,
                               bool* end_of_sequence) override {
          absl::Time start = absl::Now();
          mutex_lock l(mu_);
          if (!background_threads_started_) {
            for (int i = 0; i < dataset()->num_reader_threads_; ++i) {
              ++num_active_threads_;
              thread_pool_->Schedule(
                  [this, i, env = ctx->env()]() { ReadingFilesLoop(env, i); });
            }
            background_threads_started_ = true;
          }

          // Wait till the buffer has something in it.
          while (!cancelled_ && buffer_.empty() &&
                 !background_threads_finished_) {
            cond_var_.wait(l);
          }

          if (cancelled_) {
            return errors::Cancelled(
                "SnapshotDatasetOp::Dataset::SnapshotReaderIterator::GetNext");
          }

          const auto& stats_aggregator = ctx->stats_aggregator();
          if (stats_aggregator) {
            stats_aggregator->AddScalar(
                absl::StrCat(dataset()->node_name(), kSeparator,
                             kSnapshotReadElements),
                static_cast<float>(num_elements_read_), elements_produced_);
            stats_aggregator->AddScalar(
                absl::StrCat(dataset()->node_name(), kSeparator,
                             "snapshot_reader_buffer_size"),
                static_cast<float>(buffer_.size()), elements_produced_);
          }

          if (!buffer_.empty()) {
            Status s = buffer_.front().status;
            if (s.ok()) {
              *end_of_sequence = false;
              *out_tensors = std::move(buffer_.front().value);

              {
                profiler::TraceMe activity(
                    [&]() {
                      return absl::StrCat(prefix(), kSeparator, kBookkeeping);
                    },
                    profiler::TraceMeLevel::kInfo);
                // Printing some statistics along the way.
                int64 num_bytes = 0;
                for (int i = 0; i < out_tensors->size(); ++i) {
                  num_bytes += (*out_tensors)[i].TotalBytes();
                }
                absl::Time end = absl::Now();
                absl::Duration d = end - start;
                time_spent_micros_ += absl::ToInt64Microseconds(d);
                kbytes_read_ += static_cast<double>(num_bytes) / 1024.0;
                float read_throughput =
                    (kbytes_read_ / 1024.0) / (time_spent_micros_ / 1000000.0);
                if (stats_aggregator) {
                  stats_aggregator->AddScalar(
                      absl::StrCat(dataset()->node_name(), kSeparator,
                                   kSnapshotReadThroughput),
                      read_throughput, elements_produced_);
                }
                elements_produced_++;
                if (elements_produced_ % 10000 == 0) {
                  LOG(INFO)
                      << "Current read throughput (MBPS): " << read_throughput;
                }
              }
            }
            buffer_.pop_front();
            cond_var_.notify_all();
            return s;
          }

          if (background_threads_finished_) {
            *end_of_sequence = true;
            return Status::OK();
          }

          return errors::Internal("Unreachable point in SnapshotReader");
        }
1479
1480 protected:
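        // Checkpoints the reader's progress: the file list, each thread's
        // current file, and the bookkeeping counters.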
        Status SaveInternal(SerializationContext* ctx,
                            IteratorStateWriter* writer) override {
          mutex_lock l(mu_);
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name(kHashDir), hash_dir_));
          TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kRunId), run_id_));
          TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kRunDir), run_dir_));
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name(kVersionStr), version_));
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              full_name(strings::StrCat(kFilenames, kSizeSuffix)),
              filenames_.size()));
          for (size_t i = 0; i < filenames_.size(); ++i) {
            TF_RETURN_IF_ERROR(writer->WriteScalar(
                full_name(strings::StrCat(kFilenames, "[", i, "]")),
                filenames_[i]));
          }
          for (auto i = 0; i < dataset()->num_reader_threads_; ++i) {
            TF_RETURN_IF_ERROR(writer->WriteScalar(
                full_name(strings::StrCat(kCurrentFilenames, "[", i, "]")),
                curr_filenames_[i]));
          }
          TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kElementsProduced),
                                                 elements_produced_));
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name(kNextFileIndex), next_file_index_));
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name(kNumFilesDone), num_files_done_));
          TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kNumElementsRead),
                                                 num_elements_read_));
          VLOG(2) << "Saving SnapshotReaderIterator: " << num_elements_read_
                  << "; elements_produced: " << elements_produced_;
          return Status::OK();
        }

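        // Restores the reader's progress. If the checkpoint was taken against
        // a different run directory, restoration is skipped and snapshot
        // creation restarts from scratch.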
        Status RestoreInternal(IteratorContext* ctx,
                               IteratorStateReader* reader) override {
          mutex_lock l(mu_);
          tstring hash_dir, run_id, run_dir;
          TF_RETURN_IF_ERROR(
              reader->ReadScalar(full_name(kHashDir), &hash_dir));
          TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kRunId), &run_id));
          TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kRunDir), &run_dir));
          if (run_dir != run_dir_) {
            LOG(ERROR) << "Restoring read iterator from ckpt with old "
                       << "run_dir: " << run_dir
                       << " but new run_dir is: " << run_dir_
                       << ". We'll now restart snapshot creation.";
            return Status::OK();
          }
          TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kRunId), &run_id_));
          TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kRunDir), &run_dir_));
          TF_RETURN_IF_ERROR(
              reader->ReadScalar(full_name(kVersionStr), &version_));
          curr_filenames_.clear();
          curr_filenames_.reserve(dataset()->num_reader_threads_);
          for (auto i = 0; i < dataset()->num_reader_threads_; ++i) {
            curr_filenames_.emplace_back();
            TF_RETURN_IF_ERROR(reader->ReadScalar(
                full_name(strings::StrCat(kCurrentFilenames, "[", i, "]")),
                &curr_filenames_.back()));
          }
          size_t filenames_size;
          {
            int64 temp;
            TF_RETURN_IF_ERROR(reader->ReadScalar(
                full_name(strings::StrCat(kFilenames, kSizeSuffix)), &temp));
            filenames_size = static_cast<size_t>(temp);
          }
          if (filenames_.size() != filenames_size) {
            LOG(ERROR) << "Old filenames size: " << filenames_size
                       << "; new filenames size: " << filenames_.size();
          }
          filenames_.clear();
          filenames_.reserve(filenames_size);
          for (size_t i = 0; i < filenames_size; ++i) {
            filenames_.emplace_back();
            TF_RETURN_IF_ERROR(reader->ReadScalar(
                full_name(strings::StrCat(kFilenames, "[", i, "]")),
                &filenames_.back()));
          }
          {
            int64 temp;
            TF_RETURN_IF_ERROR(
                reader->ReadScalar(full_name(kElementsProduced), &temp));
            elements_produced_ = static_cast<uint64>(temp);
          }
          {
            int64 temp;
            TF_RETURN_IF_ERROR(
                reader->ReadScalar(full_name(kNextFileIndex), &temp));
            next_file_index_ = static_cast<uint64>(temp);
          }
          TF_RETURN_IF_ERROR(
              reader->ReadScalar(full_name(kNumFilesDone), &num_files_done_));
          TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kNumElementsRead),
                                                &num_elements_read_));
          VLOG(2) << "Restoring SnapshotReaderIterator: " << num_elements_read_
                  << "; elements_produced: " << elements_produced_;
          return Status::OK();
        }

       private:
        // Reads one file end to end.
        Status ReadFile(Env* env, const string& filename) {
          std::unique_ptr<snapshot_util::Reader> reader;
          TF_RETURN_IF_ERROR(snapshot_util::Reader::Create(
              env, filename, dataset()->compression_, version_,
              dataset()->output_dtypes(), &reader));
          while (true) {
            // Wait for a slot in the buffer.
            {
              mutex_lock l(mu_);
              while (!cancelled_ &&
                     buffer_.size() >= dataset()->reader_buffer_size_) {
                cond_var_.wait(l);
              }

              if (cancelled_) {
                return errors::Cancelled(
                    "SnapshotDatasetOp::Dataset::SnapshotReaderIterator::"
                    "ReadFile");
              }
            }
            std::vector<Tensor> read_tensors;
            Status s = reader->ReadTensors(&read_tensors);
            if (s.ok()) {
              profiler::TraceMe activity(
                  [&]() { return absl::StrCat(prefix(), kSeparator, kParse); },
                  profiler::TraceMeLevel::kInfo);
              BufferElement elem;
              elem.value = std::move(read_tensors);
              elem.status = Status::OK();
              mutex_lock l(mu_);
              buffer_.push_back(std::move(elem));
              num_elements_read_++;
              cond_var_.notify_all();
            } else if (errors::IsOutOfRange(s)) {
              return Status::OK();
            } else {
              return s;
            }
          }
          return Status::OK();
        }

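        // Returns the next file to read, or the empty string once the file
        // list has been exhausted.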
        string GetNextFilename() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
          if (next_file_index_ >= filenames_.size()) {
            return "";
          }
          string filename = io::JoinPath(dataset()->reader_path_prefix_,
                                         filenames_[next_file_index_]);
          next_file_index_++;
          return filename;
        }

        // Pulls one file off the filenames_ list and reads it through. When
        // all files are read, terminates.
        void ReadingFilesLoop(Env* env, int i) {
          auto cleanup = gtl::MakeCleanup([this]() {
            mutex_lock l(mu_);
            --num_active_threads_;
            cond_var_.notify_all();
          });
          while (true) {
            string filename = "";
            {
              mutex_lock l(mu_);
              filename = curr_filenames_[i];
              if (filename.empty()) {
                return;
              }
              VLOG(2) << "Starting to read: " << filename;
            }
            Status s = ReadFile(env, filename);
            // Reaching the end of the file is a clean termination. Once all
            // files have been processed, we mark the background threads as
            // finished and terminate the loop.
            if (s.ok()) {
              VLOG(2) << "Finished reading: " << filename;
              mutex_lock l(mu_);
              num_files_done_++;
              if (num_files_done_ >= filenames_.size()) {
                background_threads_finished_ = true;
                cond_var_.notify_all();
                return;
              }
              curr_filenames_[i] = GetNextFilename();
            } else {
              LOG(ERROR) << "Encountered an error: " << s.ToString();
              BufferElement elem;
              elem.status = s;
              mutex_lock l(mu_);
              buffer_.push_back(std::move(elem));
              cond_var_.notify_all();
              return;
            }
          }
        }

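        // Helpers for (de)serializing a Status in a checkpoint: the code is
        // always written, the error message only for non-OK statuses.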
        Status WriteStatus(IteratorStateWriter* writer, size_t index,
                           const Status& status)
            TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              CodeKey(index), static_cast<int64>(status.code())));
          if (!status.ok()) {
            TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index),
                                                   status.error_message()));
          }
          return Status::OK();
        }

        Status ReadStatus(IteratorStateReader* reader, size_t index,
                          Status* status) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
          int64 code_int;
          TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
          error::Code code = static_cast<error::Code>(code_int);

          if (code != error::Code::OK) {
            tstring error_message;
            TF_RETURN_IF_ERROR(
                reader->ReadScalar(ErrorMessageKey(index), &error_message));
            *status = Status(code, error_message);
          } else {
            *status = Status::OK();
          }
          return Status::OK();
        }

        string CodeKey(size_t index) {
          return full_name(strings::StrCat(kStatus, "[", index, "]", kCode));
        }

        string ErrorMessageKey(size_t index) {
          return full_name(
              strings::StrCat(kStatus, "[", index, "]", kErrorMessage));
        }

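        // A buffered element holds either tensors read from a snapshot file
        // or the error status produced while reading them.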
        struct BufferElement {
          Status status;
          std::vector<Tensor> value;
        };

        mutex mu_;
        condition_variable cond_var_;

        const string hash_dir_;
        tstring run_id_ TF_GUARDED_BY(mu_);
        tstring run_dir_ TF_GUARDED_BY(mu_);
        int64 version_;
        std::vector<tstring> filenames_;

        uint64 elements_produced_ TF_GUARDED_BY(mu_) = 0;
        int64 time_spent_micros_ TF_GUARDED_BY(mu_) = 0;
        double kbytes_read_ TF_GUARDED_BY(mu_) = 0;
        size_t next_file_index_ TF_GUARDED_BY(mu_) = 0;
        int64 num_files_done_ TF_GUARDED_BY(mu_) = 0;

        std::unique_ptr<thread::ThreadPool> thread_pool_;
        int64 num_active_threads_ TF_GUARDED_BY(mu_) = 0;
        std::deque<BufferElement> buffer_ TF_GUARDED_BY(mu_);
        bool cancelled_ TF_GUARDED_BY(mu_) = false;
        bool background_threads_started_ TF_GUARDED_BY(mu_) = false;
        bool background_threads_finished_ TF_GUARDED_BY(mu_) = false;
        int64 num_elements_read_ TF_GUARDED_BY(mu_) = 0;
        // curr_filenames_ tracks which file is being read by each thread.
        std::vector<tstring> curr_filenames_ TF_GUARDED_BY(mu_);
      };

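      // Writes the input elements to snapshot shard files using a pool of
      // background writer threads, while passing the elements through to the
      // downstream iterator unchanged.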
      class SnapshotWriterIterator : public DatasetIterator<Dataset> {
       public:
        static constexpr const char* const kProcessOneElement =
            "ProcessOneElement";

        explicit SnapshotWriterIterator(const Params& params,
                                        const string& hash_dir,
                                        const string& run_id)
            : DatasetIterator<Dataset>(params),
              hash_dir_(hash_dir),
              run_id_(run_id) {
          if (run_id_.empty()) {
            run_id_ = strings::StrCat(
                strings::Hex(random::New64(), strings::kZeroPad4));
          }
          run_dir_ =
              io::JoinPath(dataset()->writer_path_prefix_, hash_dir_, run_id_);
        }

        ~SnapshotWriterIterator() override {
          mutex_lock l(mu_);
          cancelled_ = true;
          cond_var_.notify_all();
          while (num_active_threads_ > 0) {
            cond_var_.wait(l);
          }
        }

        Status Initialize(IteratorContext* ctx) override {
          thread_pool_ = ctx->CreateThreadPool(kSnapshotWriterWorkerPool,
                                               dataset()->num_writer_threads_);
          return dataset()->input_->MakeIterator(ctx, this, prefix(),
                                                 &input_impl_);
        }

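        // Produces the next element while teeing it into the write buffer.
        // The first call also writes the (unfinalized) metadata file and
        // starts the background writer threads.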
        Status GetNextInternal(IteratorContext* ctx,
                               std::vector<Tensor>* out_tensors,
                               bool* end_of_sequence) override {
          absl::Time start = absl::Now();

          bool first_call;
          bool is_restored;
          {
            mutex_lock l(mu_);
            first_call = first_call_;
            is_restored = is_restored_;
            if (first_call_) {
              // If we're restoring then the directory already exists and we
              // don't want to overwrite the snapshot metadata file.
              if (!is_restored_) {
                TF_RETURN_IF_ERROR(ctx->env()->RecursivelyCreateDir(run_dir_));
                experimental::SnapshotMetadataRecord metadata;
                metadata.set_creation_timestamp(EnvTime::NowMicros());
                metadata.set_graph_hash(dataset()->graph_hash_);
                metadata.set_run_id(run_id_.data(), run_id_.size());
                metadata.set_version(kCurrentVersion);
                for (const auto& output_dtype : dataset()->output_dtypes()) {
                  metadata.add_dtype(output_dtype);
                }
                metadata.set_finalized(false);
                TF_RETURN_IF_ERROR(snapshot_util::WriteMetadataFile(
                    ctx->env(), hash_dir_, &metadata));
              }
              for (int i = 0; i < dataset()->num_writer_threads_; ++i) {
                ++num_active_threads_;
                thread_pool_->Schedule(
                    [this, env = ctx->env()]() { WriterThread(env); });
              }
              first_call_ = false;
            }
          }

          // When we reach the end of the data, we'd like to finalize the
          // snapshot and write the metadata file out. If we just check for
          // end_of_sequence on the GetNext call then we will need to make
          // N + 1 GetNext calls (if N is the total number of elements in the
          // dataset). So right now we solve this issue by prefetching the next
          // element in the data stream. Therefore the first call ends up
          // pulling two elements.
          if (first_call && !is_restored) {
            TF_RETURN_IF_ERROR(FillBuffer(ctx));
          }

          {
            mutex_lock l(mu_);
            // Populate out_tensors with the prefetched data.
            *out_tensors = std::move(next_elem_.value);
            *end_of_sequence = next_elem_.end_of_sequence;
          }

          // Update next_elem_ with the next element.
          TF_RETURN_IF_ERROR(FillBuffer(ctx));

          {
            profiler::TraceMe activity(
                [&]() {
                  return absl::StrCat(prefix(), kSeparator, kBookkeeping);
                },
                profiler::TraceMeLevel::kInfo);

            // Bookkeeping to report some statistics.
            mutex_lock l(mu_);
            int64 num_bytes = 0;
            for (const auto& out_tensor : *out_tensors) {
              num_bytes += out_tensor.TotalBytes();
            }

            const auto& stats_aggregator = ctx->stats_aggregator();
            if (stats_aggregator) {
              stats_aggregator->AddScalar(
                  absl::StrCat(dataset()->node_name(), kSeparator,
                               kSnapshotWrittenElements),
                  static_cast<float>(num_elements_written_),
                  elements_produced_);
              stats_aggregator->AddScalar(
                  absl::StrCat(dataset()->node_name(), kSeparator,
                               "snapshot_writer_buffer_size"),
                  static_cast<float>(buffer_.size()), elements_produced_);
            }

            absl::Time end = absl::Now();
            absl::Duration d = end - start;
            time_spent_micros_ += absl::ToInt64Microseconds(d);
            bytes_produced_ += num_bytes;
            float write_throughput = (bytes_produced_ * 1000000.0) /
                                     (time_spent_micros_ * 1024.0 * 1024.0);
            if (stats_aggregator) {
              stats_aggregator->AddScalar(
                  absl::StrCat(dataset()->node_name(), kSeparator,
                               kSnapshotWriteThroughput),
                  write_throughput, elements_produced_);
            }

            elements_produced_++;
            if (elements_produced_ % 10000 == 0) {
              LOG(INFO) << "Current write throughput (MBPS): "
                        << write_throughput;
            }
          }
          return Status::OK();
        }

       protected:
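        // Checkpoints the writer's state, including the buffered elements
        // that have not yet been written to disk and the prefetched next
        // element.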
        Status SaveInternal(SerializationContext* ctx,
                            IteratorStateWriter* writer) override {
          mutex_lock l(mu_);
          TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
          if (end_of_sequence_) {
            TF_RETURN_IF_ERROR(
                writer->WriteScalar(full_name(kEndOfSequence), ""));
          }
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name(kHashDir), hash_dir_));
          TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kRunId), run_id_));
          TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kRunDir), run_dir_));
          TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kElementsProduced),
                                                 elements_produced_));
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              full_name(strings::StrCat(kBuffer, kSizeSuffix)),
              buffer_.size()));
          for (size_t i = 0; i < buffer_.size(); ++i) {
            auto& buffer_element = buffer_[i];
            if (buffer_element.end_of_sequence) {
              TF_RETURN_IF_ERROR(writer->WriteScalar(
                  full_name(
                      strings::StrCat(kBuffer, "[", i, "].", kEndOfSequence)),
                  ""));
            }
            TF_RETURN_IF_ERROR(writer->WriteScalar(
                full_name(strings::StrCat(kBuffer, "[", i, "]", kSizeSuffix)),
                buffer_element.value.size()));
            for (size_t j = 0; j < buffer_element.value.size(); j++) {
              TF_RETURN_IF_ERROR(writer->WriteTensor(
                  full_name(strings::StrCat(kBuffer, "[", i, "][", j, "]")),
                  buffer_element.value[j]));
            }
          }
          TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kNumElementsWritten),
                                                 num_elements_written_));
          if (next_elem_.end_of_sequence) {
            TF_RETURN_IF_ERROR(writer->WriteScalar(
                full_name(strings::StrCat(kNextElem, ".", kEndOfSequence)),
                ""));
          }
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              full_name(strings::StrCat(kNextElem, kSizeSuffix)),
              next_elem_.value.size()));
          for (size_t i = 0; i < next_elem_.value.size(); i++) {
            TF_RETURN_IF_ERROR(writer->WriteTensor(
                full_name(strings::StrCat(kNextElem, "[", i, "]")),
                next_elem_.value[i]));
          }
          VLOG(2) << "Saving SnapshotWriterIterator: " << num_elements_written_
                  << "; elements_produced: " << elements_produced_;
          return Status::OK();
        }

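        // Restores the writer's state. If the hash directory has changed, the
        // checkpointed progress is discarded and writing restarts.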
        Status RestoreInternal(IteratorContext* ctx,
                               IteratorStateReader* reader) override {
          mutex_lock l(mu_);
          buffer_.clear();
          TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
          tstring hash_dir;
          TF_RETURN_IF_ERROR(
              reader->ReadScalar(full_name(kHashDir), &hash_dir));
          // If the hash dir has changed then we restart writing.
          if (hash_dir != hash_dir_) {
            LOG(INFO) << "Old hash dir from ckpt: " << hash_dir
                      << " is not the same as the new one: " << hash_dir_;
            return Status::OK();
          }
          is_restored_ = true;
          if (reader->Contains(full_name(kEndOfSequence))) {
            end_of_sequence_ = true;
          } else {
            end_of_sequence_ = false;
          }
          TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kRunId), &run_id_));
          TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kRunDir), &run_dir_));
          {
            int64 temp;
            TF_RETURN_IF_ERROR(
                reader->ReadScalar(full_name(kElementsProduced), &temp));
            elements_produced_ = static_cast<uint64>(temp);
          }
          size_t buffer_size;
          {
            int64 temp;
            TF_RETURN_IF_ERROR(reader->ReadScalar(
                full_name(strings::StrCat(kBuffer, kSizeSuffix)), &temp));
            buffer_size = static_cast<size_t>(temp);
          }
          for (size_t i = 0; i < buffer_size; i++) {
            buffer_.emplace_back();
            auto& buffer_element = buffer_.back();
            size_t value_size;
            {
              int64 temp;
              TF_RETURN_IF_ERROR(reader->ReadScalar(
                  full_name(strings::StrCat(kBuffer, "[", i, "]", kSizeSuffix)),
                  &temp));
              value_size = static_cast<size_t>(temp);
            }
            if (reader->Contains(full_name(
                    strings::StrCat(kBuffer, "[", i, "].", kEndOfSequence)))) {
              buffer_element.end_of_sequence = true;
            } else {
              buffer_element.end_of_sequence = false;
            }
            buffer_element.value.reserve(value_size);
            for (size_t j = 0; j < value_size; j++) {
              buffer_element.value.emplace_back();
              TF_RETURN_IF_ERROR(reader->ReadTensor(
                  full_name(strings::StrCat(kBuffer, "[", i, "][", j, "]")),
                  &buffer_element.value.back()));
            }
          }
          // Since the last save we might have written out some files. So we
          // get a list of files in the directory and take the final filename
          // written. We use the name of the snapshot file to figure out
          // next_file_index_.
          std::vector<std::string> filenames;
          TF_RETURN_IF_ERROR(ctx->env()->GetMatchingPaths(
              absl::StrCat(absl::string_view(run_dir_), "/*"), &filenames));
          std::sort(filenames.begin(), filenames.end());
          std::string final_filename = filenames.back();
          std::vector<std::string> split_filename =
              absl::StrSplit(final_filename, '/');
          std::vector<std::string> split_snapshot_filename =
              absl::StrSplit(split_filename.back(), '.');
          std::string max_num_str = split_snapshot_filename[0];
          uint64 max_num;
          if (!strings::safe_strtou64(max_num_str, &max_num)) {
            return errors::Internal("Could not parse: ", max_num_str,
                                    " as uint64");
          }
          next_file_index_ = max_num + 1;
          TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kNumElementsWritten),
                                                &num_elements_written_));
          size_t next_elem_size;
          {
            int64 temp;
            TF_RETURN_IF_ERROR(reader->ReadScalar(
                full_name(strings::StrCat(kNextElem, kSizeSuffix)), &temp));
            next_elem_size = static_cast<size_t>(temp);
          }
          if (reader->Contains(
                  full_name(strings::StrCat(kNextElem, ".", kEndOfSequence)))) {
            next_elem_.end_of_sequence = true;
          } else {
            next_elem_.end_of_sequence = false;
          }
          next_elem_.value.reserve(next_elem_size);
          for (size_t i = 0; i < next_elem_size; i++) {
            next_elem_.value.emplace_back();
            TF_RETURN_IF_ERROR(reader->ReadTensor(
                full_name(strings::StrCat(kNextElem, "[", i, "]")),
                &next_elem_.value.back()));
          }
          VLOG(2) << "Restoring SnapshotWriterIterator: "
                  << num_elements_written_
                  << "; elements_produced: " << elements_produced_;
          return Status::OK();
        }

       private:
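        // Returns the next shard filename, e.g. "00000003.snapshot" for
        // next_file_index_ == 3, and advances the index.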
        string GetSnapshotFilename() {
          mutex_lock l(mu_);
          string snapshot_data_filename = io::JoinPath(
              run_dir_, strings::Printf(
                            "%08llu.snapshot",
                            static_cast<unsigned long long>(next_file_index_)));
          next_file_index_++;
          return snapshot_data_filename;
        }

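        // Pulls one element from the input, stores it as the prefetched
        // next_elem_, and appends a copy to the write buffer, blocking while
        // the buffer is full. On end of input it instead waits for the
        // background writer threads to drain the buffer and exit.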
        Status FillBuffer(IteratorContext* ctx) TF_LOCKS_EXCLUDED(mu_) {
          snapshot_util::ElementOrEOF elem;
          TF_RETURN_IF_ERROR(
              input_impl_->GetNext(ctx, &elem.value, &elem.end_of_sequence));

          mutex_lock l(mu_);
          next_elem_ = std::move(elem);

          if (next_elem_.end_of_sequence) {
            end_of_sequence_ = true;
            cond_var_.notify_all();
            // Now we wait till all background threads finish.
            while (num_active_threads_ > 0) {
              cond_var_.wait(l);
            }
            return Status::OK();
          }

          // Wait for a space in the buffer_.
          while (!cancelled_ &&
                 buffer_.size() >= dataset()->writer_buffer_size_) {
            cond_var_.wait(l);
          }

          if (cancelled_) {
            return errors::Cancelled(
                "SnapshotDatasetOp::SnapshotWriterIterator::GetNext");
          }

          if (buffer_.size() >= dataset()->writer_buffer_size_) {
            return errors::Internal(
                "Buffer size: ", buffer_.size(), " should be smaller than ",
                "maximum size: ", dataset()->writer_buffer_size_);
          }

          snapshot_util::ElementOrEOF elem_copy = next_elem_;
          buffer_.push_back(elem_copy);
          cond_var_.notify_all();
          return Status::OK();
        }

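        // Takes one element off the buffer and writes it to the current shard
        // file, rotating to a new shard when the current one grows too large.
        // At end of processing, closes the writer and finalizes the metadata
        // file if this run won the race to write the snapshot.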
        Status ProcessOneElement(Env* env, int64* bytes_written,
                                 string* snapshot_data_filename,
                                 std::unique_ptr<snapshot_util::Writer>* writer,
                                 bool* end_of_processing) {
          profiler::TraceMe activity(
              [&]() {
                return absl::StrCat(prefix(), kSeparator, kProcessOneElement);
              },
              profiler::TraceMeLevel::kInfo);
          bool cancelled = false;
          *end_of_processing = false;
          bool produced_elem = false;
          bool snapshot_failed = false;
          snapshot_util::ElementOrEOF elem;
          {
            mutex_lock l(mu_);
            // Wait for buffer to not be empty.
            while (!cancelled_ && buffer_.empty() && !end_of_sequence_ &&
                   !snapshot_failed_) {
              cond_var_.wait(l);
            }
            cancelled = cancelled_;
            if (!buffer_.empty()) {
              produced_elem = true;
              std::swap(elem, buffer_.front());
              buffer_.pop_front();
              cond_var_.notify_all();
            } else {
              *end_of_processing = end_of_sequence_;
            }
            snapshot_failed = snapshot_failed_;
          }

          if (cancelled || snapshot_failed) {
            TF_RETURN_IF_ERROR((*writer)->Close());
            if (snapshot_failed) {
              return errors::Internal(
                  "SnapshotDataset::SnapshotWriterIterator snapshot failed");
            }
            return errors::Cancelled(
                "SnapshotDataset::SnapshotWriterIterator cancelled");
          }

          if (produced_elem) {
            for (const auto& out_tensor : elem.value) {
              *bytes_written += out_tensor.TotalBytes();
            }

            bool should_close;
            TF_RETURN_IF_ERROR(
                ShouldCloseWriter(env, *snapshot_data_filename, *bytes_written,
                                  (*writer).get(), &should_close));
            if (should_close) {
              // If we exceed the shard size, we get a new file and reset.
              TF_RETURN_IF_ERROR((*writer)->Close());
              *snapshot_data_filename = GetSnapshotFilename();

              TF_RETURN_IF_ERROR(snapshot_util::Writer::Create(
                  env, *snapshot_data_filename, dataset()->compression_,
                  kCurrentVersion, dataset()->output_dtypes(), writer));
              *bytes_written = 0;
            }
            TF_RETURN_IF_ERROR((*writer)->WriteTensors(elem.value));
            return Status::OK();
          }

          if (*end_of_processing) {
            TF_RETURN_IF_ERROR((*writer)->Close());
            mutex_lock l(mu_);
            if (!written_final_metadata_file_) {
              experimental::SnapshotMetadataRecord metadata;
              bool file_exists;
              TF_RETURN_IF_ERROR(snapshot_util::ReadMetadataFile(
                  env, hash_dir_, &metadata, &file_exists));

              if (metadata.run_id() == run_id_) {
                metadata.set_finalized(true);
                TF_RETURN_IF_ERROR(snapshot_util::WriteMetadataFile(
                    env, hash_dir_, &metadata));
              } else {
                // TODO(frankchn): We lost the race, remove all snapshots.
              }
              written_final_metadata_file_ = true;
              cond_var_.notify_all();
            }
          }
          return Status::OK();
        }

        // Just pulls off elements from the buffer and writes them.
        void WriterThread(Env* env) {
          auto cleanup = gtl::MakeCleanup([this]() {
            mutex_lock l(mu_);
            --num_active_threads_;
            cond_var_.notify_all();
          });

          int64 bytes_written = 0;
          string snapshot_data_filename = GetSnapshotFilename();
          std::unique_ptr<snapshot_util::Writer> writer;
          Status s = snapshot_util::Writer::Create(
              env, snapshot_data_filename, dataset()->compression_,
              kCurrentVersion, dataset()->output_dtypes(), &writer);
          if (!s.ok()) {
            LOG(ERROR) << "Creating " << snapshot_data_filename
                       << " failed: " << s.ToString();
            mutex_lock l(mu_);
            snapshot_failed_ = true;
            cond_var_.notify_all();
            return;
          }

          bool end_of_processing = false;
          while (!end_of_processing) {
            Status s =
                ProcessOneElement(env, &bytes_written, &snapshot_data_filename,
                                  &writer, &end_of_processing);
            if (!s.ok()) {
              LOG(INFO) << "Error while writing snapshot data to disk: "
                        << s.ToString();
              mutex_lock l(mu_);
              snapshot_failed_ = true;
              cond_var_.notify_all();
              return;
            }
            mutex_lock l(mu_);
            num_elements_written_++;
          }
        }

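        // Decides whether the current shard file should be rotated. Until the
        // compression ratio is known this is based on uncompressed bytes;
        // afterwards the threshold is scaled by the measured ratio so shards
        // land near shard_size_bytes_ on disk.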
        Status ShouldCloseWriter(Env* env, const string& filename,
                                 uint64 bytes_written,
                                 snapshot_util::Writer* writer,
                                 bool* should_close) {
          // If the compression ratio has been estimated, use it to decide
          // whether the file should be closed. We avoid estimating the
          // compression ratio repeatedly because it requires syncing the file,
          // which can be expensive.
          {
            tf_shared_lock l(mu_);
            if (compression_ratio_ > 0.0) {
              *should_close = bytes_written > (compression_ratio_ *
                                               dataset()->shard_size_bytes_);
              return Status::OK();
            }
          }
          // If the number of bytes written has not reached shard_size_bytes_
          // yet, we keep on going.
          if (bytes_written <= dataset()->shard_size_bytes_) {
            *should_close = false;
            return Status::OK();
          }
          // Use the actual file size to determine compression ratio.
          // Make sure that all bytes are written out.
          TF_RETURN_IF_ERROR(writer->Sync());
          uint64 file_size;
          TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));
          mutex_lock l(mu_);
          compression_ratio_ = static_cast<double>(bytes_written) /
                               static_cast<double>(file_size);
          LOG(INFO) << "Writing compression achieved: " << compression_ratio_;
          *should_close = file_size >= dataset()->shard_size_bytes_;
          return Status::OK();
        }


        mutex mu_;
        // This condition variable is notified
        // 1. By the background writer threads when an element from the buffer
        //    is consumed.
        // 2. By the main thread when it puts something into the buffer.
        // 3. By the main thread when the destructor is called to cancel.
        // 4. By the background writer threads when any error is encountered
        //    while writing.
        // 5. By the background threads when they finish.
        condition_variable cond_var_;

        snapshot_util::ElementOrEOF next_elem_ TF_GUARDED_BY(mu_);
        std::unique_ptr<IteratorBase> input_impl_;

        const string hash_dir_;
        tstring run_id_ TF_GUARDED_BY(mu_);
        tstring run_dir_ TF_GUARDED_BY(mu_);
        double compression_ratio_ TF_GUARDED_BY(mu_) = 0.0;
        bool is_restored_ TF_GUARDED_BY(mu_) = false;

        uint64 elements_produced_ TF_GUARDED_BY(mu_) = 0;
        int64 time_spent_micros_ TF_GUARDED_BY(mu_) = 0;
        int64 bytes_produced_ TF_GUARDED_BY(mu_) = 0;

        std::deque<snapshot_util::ElementOrEOF> buffer_ TF_GUARDED_BY(mu_);
        bool snapshot_failed_ TF_GUARDED_BY(mu_) = false;
        bool cancelled_ TF_GUARDED_BY(mu_) = false;
        bool first_call_ TF_GUARDED_BY(mu_) = true;
        bool end_of_sequence_ TF_GUARDED_BY(mu_) = false;
        bool written_final_metadata_file_ TF_GUARDED_BY(mu_) = false;
        uint64 next_file_index_ TF_GUARDED_BY(mu_) = 0;
        std::unique_ptr<thread::ThreadPool> thread_pool_;
        int64 num_active_threads_ TF_GUARDED_BY(mu_) = 0;
        int64 num_elements_written_ = 0;
      };

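      // Forwards elements directly from the input without reading or writing
      // any snapshot data.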
      class SnapshotPassthroughIterator : public DatasetIterator<Dataset> {
       public:
        explicit SnapshotPassthroughIterator(const Params& params)
            : DatasetIterator<Dataset>(params) {}

        Status Initialize(IteratorContext* ctx) override {
          return dataset()->input_->MakeIterator(ctx, this, prefix(),
                                                 &input_impl_);
        }

        Status GetNextInternal(IteratorContext* ctx,
                               std::vector<Tensor>* out_tensors,
                               bool* end_of_sequence) override {
          return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
        }

       protected:
        Status SaveInternal(SerializationContext* ctx,
                            IteratorStateWriter* writer) override {
          return SaveInput(ctx, writer, input_impl_);
        }

        Status RestoreInternal(IteratorContext* ctx,
                               IteratorStateReader* reader) override {
          return RestoreInput(ctx, reader, input_impl_);
        }

       private:
        std::unique_ptr<IteratorBase> input_impl_;
      };

      string hash_dir_ TF_GUARDED_BY(mu_);
      snapshot_util::Mode state_ TF_GUARDED_BY(mu_);
      std::unique_ptr<IteratorBase> iterator_ TF_GUARDED_BY(mu_);

      mutex mu_;
    };

    const DatasetBase* const input_;
    const tstring dir_;
    const string graph_hash_;

    const string reader_path_prefix_;
    const string writer_path_prefix_;
    const string compression_;

    const uint64 shard_size_bytes_;
    const uint64 pending_snapshot_expiry_seconds_;
    const uint64 num_reader_threads_;
    const uint64 reader_buffer_size_;
    const uint64 num_writer_threads_;
    const uint64 writer_buffer_size_;
    const bool shuffle_on_read_;

    const uint64 seed_;
    const uint64 seed2_;

    const std::string mode_;
    const std::string snapshot_name_;
  };

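  // Fingerprints the dataset so that snapshots are only reused when both the
  // input graph and the on-disk layout parameters match.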
  Status ComputeDatasetHash(const GraphDef& graph_def, const std::string& path,
                            uint64* hash) {
    TF_RETURN_IF_ERROR(HashGraph(graph_def, hash));
    // Adding path, compression, reader / writer path prefix, shard size
    // bytes to the fingerprint as they affect the data written on disk.
    *hash = Hash64Combine(*hash, Hash64(path));
    *hash = Hash64Combine(*hash, Hash64(compression_));
    *hash = Hash64Combine(*hash, Hash64(reader_path_prefix_));
    *hash = Hash64Combine(*hash, Hash64(writer_path_prefix_));
    *hash = Hash64Combine(*hash, shard_size_bytes_);
    return Status::OK();
  }

  const int graph_def_version_;
  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;

  string reader_path_prefix_;
  string writer_path_prefix_;
  string compression_;

  int64 shard_size_bytes_;
  int64 pending_snapshot_expiry_seconds_;
  int64 num_reader_threads_;
  int64 reader_buffer_size_;
  int64 num_writer_threads_;
  int64 writer_buffer_size_;
  bool shuffle_on_read_;

  int64 seed_;
  int64 seed2_;

  std::string mode_;
  std::string snapshot_name_;
};

REGISTER_KERNEL_BUILDER(Name("SnapshotDataset").Device(DEVICE_CPU),
                        SnapshotDatasetOp);

}  // namespace
}  // namespace experimental
}  // namespace data
}  // namespace tensorflow