/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/data/experimental/snapshot_util.h"

#include <queue>

#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/lib/io/buffered_inputstream.h"
#include "tensorflow/core/lib/io/random_inputstream.h"
#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/lib/io/snappy/snappy_inputbuffer.h"
#include "tensorflow/core/lib/io/snappy/snappy_outputbuffer.h"
#include "tensorflow/core/lib/io/zlib_compression_options.h"
#include "tensorflow/core/lib/io/zlib_inputstream.h"
#include "tensorflow/core/lib/io/zlib_outputbuffer.h"
#include "tensorflow/core/platform/coding.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/protobuf/snapshot.pb.h"

namespace tensorflow {
namespace data {
namespace snapshot_util {

/* static */ constexpr const int64
    CustomReader::kSnappyReaderInputBufferSizeBytes;
/* static */ constexpr const int64
    CustomReader::kSnappyReaderOutputBufferSizeBytes;
std::string HashDirectory(const std::string& path, uint64 hash) {
  return io::JoinPath(
      path, strings::Printf("%llu", static_cast<unsigned long long>(hash)));
}

std::string RunDirectory(const std::string& hash_directory, uint64 run_id) {
  return RunDirectory(
      hash_directory,
      strings::Printf("%llu", static_cast<unsigned long long>(run_id)));
}

std::string RunDirectory(const std::string& hash_directory,
                         const std::string& run_id) {
  return io::JoinPath(hash_directory, run_id);
}

std::string ShardDirectory(const std::string& run_directory, int64 shard_id) {
  return io::JoinPath(
      run_directory,
      strings::Printf("%08llu%s", static_cast<unsigned long long>(shard_id),
                      kShardDirectorySuffix));
}

std::string GetCheckpointFileName(const std::string& shard_directory,
                                  uint64 checkpoint_id) {
  return io::JoinPath(
      shard_directory,
      strings::Printf("%08llu.snapshot",
                      static_cast<unsigned long long>(checkpoint_id)));
}
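
// Illustrative sketch (values hypothetical): for path "/tmp/snap", hash 4660,
// run_id 2, shard_id 3 and checkpoint_id 4, the helpers above compose paths
// along the lines of
//   HashDirectory         -> /tmp/snap/4660
//   RunDirectory          -> /tmp/snap/4660/2
//   ShardDirectory        -> /tmp/snap/4660/2/00000003<kShardDirectorySuffix>
//   GetCheckpointFileName -> <shard directory>/00000004.snapshot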

Status Writer::Create(Env* env, const std::string& filename,
                      const std::string& compression_type, int version,
                      const DataTypeVector& dtypes,
                      std::unique_ptr<Writer>* out_writer) {
  switch (version) {
    case 1:
      *out_writer =
          absl::make_unique<CustomWriter>(filename, compression_type, dtypes);
      break;
    case 2:
      *out_writer =
          absl::make_unique<TFRecordWriter>(filename, compression_type);
      break;
    default:
      return errors::InvalidArgument("Snapshot writer version: ", version,
                                     " is not supported.");
  }

  return (*out_writer)->Initialize(env);
}
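
// Minimal usage sketch (assumes a valid Env* env, a DataTypeVector dtypes
// matching the tensors, and a hypothetical filename): write one element with
// the V2 (TFRecord) writer, then close the file.
//   std::unique_ptr<Writer> writer;
//   TF_RETURN_IF_ERROR(Writer::Create(env, "/tmp/00000000.snapshot",
//                                     io::compression::kSnappy,
//                                     /*version=*/2, dtypes, &writer));
//   TF_RETURN_IF_ERROR(writer->WriteTensors(tensors));
//   TF_RETURN_IF_ERROR(writer->Close());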

TFRecordWriter::TFRecordWriter(const std::string& filename,
                               const std::string& compression_type)
    : filename_(filename), compression_type_(compression_type) {}

Status TFRecordWriter::Initialize(tensorflow::Env* env) {
  TF_RETURN_IF_ERROR(env->NewAppendableFile(filename_, &dest_));

  record_writer_ = absl::make_unique<io::RecordWriter>(
      dest_.get(), io::RecordWriterOptions::CreateRecordWriterOptions(
                       /*compression_type=*/compression_type_));
  return Status::OK();
}

Status TFRecordWriter::WriteTensors(const std::vector<Tensor>& tensors) {
  for (const auto& tensor : tensors) {
    TensorProto proto;
    tensor.AsProtoTensorContent(&proto);
#if defined(TF_CORD_SUPPORT)
    // Use a raw pointer here: in OSS TF, moving a smart pointer into the
    // releaser lambda would move it when the lambda is created, leaving
    // proto_buffer == nullptr by the time WriteRecord runs.
    auto proto_buffer = new std::string();
    proto.SerializeToString(proto_buffer);
    absl::Cord proto_serialized = absl::MakeCordFromExternal(
        *proto_buffer,
        [proto_buffer](absl::string_view) { delete proto_buffer; });
    TF_RETURN_IF_ERROR(record_writer_->WriteRecord(proto_serialized));
#else  // TF_CORD_SUPPORT
    TF_RETURN_IF_ERROR(record_writer_->WriteRecord(proto.SerializeAsString()));
#endif  // TF_CORD_SUPPORT
  }
  return Status::OK();
}

Status TFRecordWriter::Sync() {
  TF_RETURN_IF_ERROR(record_writer_->Flush());
  return dest_->Flush();
}

Status TFRecordWriter::Close() {
  if (record_writer_ != nullptr) {
    TF_RETURN_IF_ERROR(Sync());
    TF_RETURN_IF_ERROR(record_writer_->Close());
    TF_RETURN_IF_ERROR(dest_->Close());
    record_writer_ = nullptr;
    dest_ = nullptr;
  }
  return Status::OK();
}

TFRecordWriter::~TFRecordWriter() {
  Status s = Close();
  if (!s.ok()) {
    LOG(ERROR) << "Failed to close snapshot file " << filename_ << ": " << s;
  }
}

CustomWriter::CustomWriter(const std::string& filename,
                           const std::string& compression_type,
                           const DataTypeVector& dtypes)
    : filename_(filename),
      compression_type_(compression_type),
      dtypes_(dtypes) {}

Status CustomWriter::Initialize(tensorflow::Env* env) {
  TF_RETURN_IF_ERROR(env->NewAppendableFile(filename_, &dest_));
#if defined(IS_SLIM_BUILD)
  if (compression_type_ != io::compression::kNone) {
    LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
               << "off compression.";
  }
#else  // IS_SLIM_BUILD
  if (compression_type_ == io::compression::kGzip) {
    zlib_underlying_dest_.swap(dest_);
    io::ZlibCompressionOptions zlib_options;
    zlib_options = io::ZlibCompressionOptions::GZIP();

    io::ZlibOutputBuffer* zlib_output_buffer = new io::ZlibOutputBuffer(
        zlib_underlying_dest_.get(), zlib_options.input_buffer_size,
        zlib_options.output_buffer_size, zlib_options);
    TF_CHECK_OK(zlib_output_buffer->Init());
    dest_.reset(zlib_output_buffer);
  }
#endif  // IS_SLIM_BUILD
  simple_tensor_mask_.reserve(dtypes_.size());
  for (const auto& dtype : dtypes_) {
    if (DataTypeCanUseMemcpy(dtype)) {
      simple_tensor_mask_.push_back(true);
      num_simple_++;
    } else {
      simple_tensor_mask_.push_back(false);
      num_complex_++;
    }
  }

  return Status::OK();
}

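// On-disk layout produced below (and consumed by CustomReader): for
// non-snappy compression, each element is a single length-prefixed record
// holding a serialized SnapshotRecord proto. For snappy compression (V1),
// each element is two records: a SnapshotTensorMetadata proto describing
// per-tensor sizes, then one snappy-compressed blob with the concatenated
// tensor payloads.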
Status CustomWriter::WriteTensors(const std::vector<Tensor>& tensors) {
  if (compression_type_ != io::compression::kSnappy) {
    experimental::SnapshotRecord record;
    for (const auto& tensor : tensors) {
      TensorProto* t = record.add_tensor();
      tensor.AsProtoTensorContent(t);
    }
#if defined(TF_CORD_SUPPORT)
    auto record_buffer = new std::string();
    record.SerializeToString(record_buffer);
    absl::Cord record_serialized = absl::MakeCordFromExternal(
        *record_buffer,
        [record_buffer](absl::string_view) { delete record_buffer; });
    return WriteRecord(record_serialized);
#else  // TF_CORD_SUPPORT
    return WriteRecord(record.SerializeAsString());
#endif  // TF_CORD_SUPPORT
  }

  std::vector<const TensorBuffer*> tensor_buffers;
  tensor_buffers.reserve(num_simple_);
  std::vector<TensorProto> tensor_protos;
  tensor_protos.reserve(num_complex_);
  experimental::SnapshotTensorMetadata metadata;
  int64 total_size = 0;
  for (int i = 0, end = tensors.size(); i < end; ++i) {
    const Tensor& tensor = tensors[i];
    experimental::TensorMetadata* tensor_metadata =
        metadata.add_tensor_metadata();
    tensor.shape().AsProto(tensor_metadata->mutable_tensor_shape());
    int64 size = 0;
    if (simple_tensor_mask_[i]) {
      auto tensor_buffer = DMAHelper::buffer(&tensor);
      tensor_buffers.push_back(tensor_buffer);
      size = tensor_buffer->size();
    } else {
      TensorProto proto;
      tensor.AsProtoTensorContent(&proto);
      size = proto.ByteSizeLong();
      tensor_protos.push_back(std::move(proto));
    }
    tensor_metadata->set_tensor_size_bytes(size);
    total_size += size;
  }

  std::vector<char> uncompressed(total_size);
  char* position = uncompressed.data();
  int buffer_index = 0;
  int proto_index = 0;
  for (int i = 0, end = tensors.size(); i < end; ++i) {
    const auto& tensor_metadata = metadata.tensor_metadata(i);
    if (simple_tensor_mask_[i]) {
      memcpy(position, tensor_buffers[buffer_index]->data(),
             tensor_metadata.tensor_size_bytes());
      buffer_index++;
    } else {
      tensor_protos[proto_index].SerializeToArray(
          position, tensor_metadata.tensor_size_bytes());
      proto_index++;
    }
    position += tensor_metadata.tensor_size_bytes();
  }
  DCHECK_EQ(position, uncompressed.data() + total_size);

  string output;
  if (!port::Snappy_Compress(uncompressed.data(), total_size, &output)) {
    return errors::Internal("Failed to compress using snappy.");
  }

#if defined(TF_CORD_SUPPORT)
  auto metadata_buffer = new std::string();
  metadata.SerializeToString(metadata_buffer);
  absl::Cord metadata_serialized = absl::MakeCordFromExternal(
      *metadata_buffer,
      [metadata_buffer](absl::string_view) { delete metadata_buffer; });
#else
  std::string metadata_serialized = metadata.SerializeAsString();
#endif  // TF_CORD_SUPPORT
  TF_RETURN_IF_ERROR(WriteRecord(metadata_serialized));
  TF_RETURN_IF_ERROR(WriteRecord(output));
  return Status::OK();
}

Status CustomWriter::Sync() { return dest_->Sync(); }

Status CustomWriter::Close() {
  if (dest_ != nullptr) {
    TF_RETURN_IF_ERROR(dest_->Close());
    dest_ = nullptr;
  }
  if (zlib_underlying_dest_ != nullptr) {
    TF_RETURN_IF_ERROR(zlib_underlying_dest_->Close());
    zlib_underlying_dest_ = nullptr;
  }
  return Status::OK();
}

CustomWriter::~CustomWriter() {
  Status s = Close();
  if (!s.ok()) {
    LOG(ERROR) << "Could not finish writing file: " << s;
  }
}

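// Record framing used by WriteRecord (and CustomReader::ReadRecord): a
// fixed64 little-endian length header of kHeaderSize bytes, written with
// core::EncodeFixed64, followed by the raw payload bytes.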
Status CustomWriter::WriteRecord(const StringPiece& data) {
  char header[kHeaderSize];
  core::EncodeFixed64(header, data.size());
  TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
  return dest_->Append(data);
}

#if defined(TF_CORD_SUPPORT)
Status CustomWriter::WriteRecord(const absl::Cord& data) {
  char header[kHeaderSize];
  core::EncodeFixed64(header, data.size());
  TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
  return dest_->Append(data);
}
#endif  // TF_CORD_SUPPORT

Status Reader::Create(Env* env, const std::string& filename,
                      const string& compression_type, int version,
                      const DataTypeVector& dtypes,
                      std::unique_ptr<Reader>* out_reader) {
  switch (version) {
    // CustomReader can still read the legacy snapshot file format (v0), but
    // CustomWriter no longer writes it, since v0 is strictly worse than v1.
    case 0:
    case 1:
      *out_reader = absl::make_unique<CustomReader>(filename, compression_type,
                                                    version, dtypes);
      break;
    case 2:
      *out_reader =
          absl::make_unique<TFRecordReader>(filename, compression_type, dtypes);
      break;
    default:
      return errors::InvalidArgument("Snapshot reader version: ", version,
                                     " is not supported.");
  }

  return (*out_reader)->Initialize(env);
}
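
// Minimal read-back sketch (same assumptions as the writer example above;
// the filename is hypothetical):
//   std::unique_ptr<Reader> reader;
//   TF_RETURN_IF_ERROR(Reader::Create(env, "/tmp/00000000.snapshot",
//                                     io::compression::kSnappy,
//                                     /*version=*/2, dtypes, &reader));
//   std::vector<Tensor> tensors;
//   TF_RETURN_IF_ERROR(reader->ReadTensors(&tensors));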

Status Reader::SkipRecords(int64 num_records) {
  // TODO(frankchn): Optimize to not parse the entire Tensor and actually skip.
  for (int i = 0; i < num_records; ++i) {
    std::vector<Tensor> unused_tensors;
    TF_RETURN_IF_ERROR(ReadTensors(&unused_tensors));
  }
  return Status::OK();
}

class Reader::Dataset : public DatasetBase {
 public:
  explicit Dataset(const std::string& shard_dir, const std::string& compression,
                   const int64 version, const DataTypeVector& dtypes,
                   const std::vector<PartialTensorShape>& shapes,
                   const int64 start_index, DatasetContext::Params params)
      : DatasetBase(DatasetContext(std::move(params))),
        shard_dir_(shard_dir),
        compression_(compression),
        version_(version),
        dtypes_(dtypes),
        shapes_(shapes),
        start_index_(start_index) {}

  const DataTypeVector& output_dtypes() const override { return dtypes_; }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return shapes_;
  }

  std::string DebugString() const override {
    return "snapshot_util::Reader::Dataset";
  }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    return Status::OK();
  }

  Status CheckExternalState() const override { return Status::OK(); }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** node) const override {
    // There is no need to perform any serialization, as this dataset is only
    // constructed at runtime in C++ and will be reconstructed every time.
    return Status::OK();
  }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return absl::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(node_name(), prefix)});
  }

 private:
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params), current_checkpoint_id_(0) {}

    Status Initialize(IteratorContext* ctx) override {
      TF_RETURN_IF_ERROR(Reader::Create(
          ctx->env(), GetCurrentFilename(), dataset()->compression_,
          dataset()->version_, dataset()->dtypes_, &reader_));
      bool end_of_sequence;
      for (int64 i = 0; i < dataset()->start_index_; ++i) {
        // TODO(frankchn): Optimize this to not parse every single element.
        std::vector<Tensor> unused;
        TF_RETURN_IF_ERROR(GetNextInternal(ctx, &unused, &end_of_sequence));
      }
      return Status::OK();
    }

   protected:
    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      *end_of_sequence = false;
      Status s = reader_->ReadTensors(out_tensors);
      if (!errors::IsOutOfRange(s)) {
        return s;
      }
      Status status = AdvanceToNextFile(ctx->env());
      if (errors::IsNotFound(status)) {
        *end_of_sequence = true;
        return Status::OK();
      } else {
        return status;
      }
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      // There is no need to save any state, as this iterator will be
      // reconstructed from scratch when the parent snapshot dataset is
      // restored from checkpoint.
      return Status::OK();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      // There is no need to restore any state, as this iterator will be
      // reconstructed from scratch when the parent snapshot dataset is
      // restored from checkpoint.
      return Status::OK();
    }

   private:
    std::string GetCurrentFilename() {
      return GetCheckpointFileName(dataset()->shard_dir_,
                                   current_checkpoint_id_);
    }

    Status AdvanceToNextFile(Env* env) {
      current_checkpoint_id_++;
      TF_RETURN_IF_ERROR(env->FileExists(GetCurrentFilename()));
      return Reader::Create(env, GetCurrentFilename(), dataset()->compression_,
                            dataset()->version_, dataset()->dtypes_, &reader_);
    }

    std::unique_ptr<Reader> reader_;

    // Stores the id of the checkpoint file that we are currently reading
    // (e.g. if the file is 00000001.snapshot, then this will be 1).
    uint64 current_checkpoint_id_;
  };

  const std::string shard_dir_;
  const std::string compression_;
  const int64 version_;
  const DataTypeVector dtypes_;
  const std::vector<PartialTensorShape> shapes_;
  const int64 start_index_;
};

class Reader::NestedDataset : public DatasetBase {
 public:
  explicit NestedDataset(std::vector<DatasetBase*> datasets,
                         DatasetContext::Params params)
      : DatasetBase(DatasetContext(std::move(params))), datasets_(datasets) {
    dtypes_.push_back(DT_VARIANT);
    gtl::InlinedVector<int64, 1> element_dim_sizes;
    element_dim_sizes.push_back(1);
    partial_shapes_.emplace_back(element_dim_sizes);
  }

  const DataTypeVector& output_dtypes() const override { return dtypes_; }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return partial_shapes_;
  }

  std::string DebugString() const override {
    return "snapshot_util::Reader::NestedDataset";
  }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->clear();
    return Status::OK();
  }

  Status CheckExternalState() const override { return Status::OK(); }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** node) const override {
    // There is no need to perform any serialization, as this dataset is only
    // constructed at runtime in C++ and will be reconstructed every time.
    return Status::OK();
  }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return absl::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(node_name(), prefix)});
  }

 private:
  std::vector<DatasetBase*> datasets_;
  DataTypeVector dtypes_;
  std::vector<PartialTensorShape> partial_shapes_;

  class Iterator : public DatasetIterator<NestedDataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<NestedDataset>(params), index_(0) {}

   protected:
    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      const int64 num_datasets = dataset()->datasets_.size();
      *end_of_sequence = num_datasets == index_;
      if (!*end_of_sequence) {
        Tensor tensor(DT_VARIANT, TensorShape({}));

        TF_RETURN_IF_ERROR(
            StoreDatasetInVariantTensor(dataset()->datasets_[index_], &tensor));
        out_tensors->clear();
        out_tensors->push_back(std::move(tensor));

        index_++;
      }
      return Status::OK();
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      // There is no need to save any state, as this iterator will be
      // reconstructed from scratch when the parent snapshot dataset is
      // restored from checkpoint.
      return Status::OK();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      // There is no need to restore any state, as this iterator will be
      // reconstructed from scratch when the parent snapshot dataset is
      // restored from checkpoint.
      return Status::OK();
    }

   private:
    int64 index_;
  };
};

Status Reader::MakeNestedDataset(Env* env,
                                 const std::vector<std::string>& shard_dirs,
                                 const string& compression_type, int version,
                                 const DataTypeVector& dtypes,
                                 const std::vector<PartialTensorShape>& shapes,
                                 const int64 start_index,
                                 DatasetBase** output) {
  std::vector<DatasetBase*> datasets;

  datasets.reserve(shard_dirs.size());
  for (const auto& shard_dir : shard_dirs) {
    // TODO(frankchn): The reading pattern could be controlled in a
    // non-round-robin fashion, in which case we cannot assume a round-robin
    // manner when restoring.
    int64 dataset_start_index = start_index / shard_dirs.size();
    if (start_index % shard_dirs.size() > datasets.size()) {
      dataset_start_index++;
    }
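    // Worked example of the round-robin assumption: with 5 shards and
    // start_index == 12, every shard starts at 12 / 5 == 2 elements, except
    // the shards at positions 0 and 1 (positions below 12 % 5 == 2), which
    // have already produced a third element and therefore start at 3.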

    datasets.push_back(
        new Dataset(shard_dir, compression_type, version, dtypes, shapes,
                    dataset_start_index,
                    DatasetContext::Params({"snapshot_util::Reader::Dataset",
                                            "snapshot_util_reader_Dataset"})));
  }

  // Rotate the vector such that the first dataset contains the next element
  // to be produced, but not if there are no shards at all (then we just
  // construct an empty dataset).
  if (!shard_dirs.empty()) {
    std::rotate(datasets.begin(),
                datasets.begin() + (start_index % shard_dirs.size()),
                datasets.end());
  }

  *output = new NestedDataset(
      datasets, DatasetContext::Params({"snapshot_util::Reader::NestedDataset",
                                        "snapshot_util_reader_NestedDataset"}));
  return Status::OK();
}

TFRecordReader::TFRecordReader(const std::string& filename,
                               const string& compression_type,
                               const DataTypeVector& dtypes)
    : filename_(filename),
      offset_(0),
      compression_type_(compression_type),
      dtypes_(dtypes) {}

Status TFRecordReader::Initialize(Env* env) {
  TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename_, &file_));

  record_reader_ = absl::make_unique<io::RecordReader>(
      file_.get(), io::RecordReaderOptions::CreateRecordReaderOptions(
                       /*compression_type=*/compression_type_));
  return Status::OK();
}

Status TFRecordReader::ReadTensors(std::vector<Tensor>* read_tensors) {
  read_tensors->reserve(dtypes_.size());
  for (int i = 0; i < dtypes_.size(); ++i) {
    tstring record;
    TF_RETURN_IF_ERROR(record_reader_->ReadRecord(&offset_, &record));

    TensorProto proto;
    proto.ParseFromArray(record.data(), record.size());

    Tensor tensor;
    if (!tensor.FromProto(proto)) {
      return errors::DataLoss("Unable to parse tensor from stored proto.");
    }

    read_tensors->push_back(std::move(tensor));
  }
  return Status::OK();
}

CustomReader::CustomReader(const std::string& filename,
                           const string& compression_type, const int version,
                           const DataTypeVector& dtypes)
    : filename_(filename),
      compression_type_(compression_type),
      version_(version),
      dtypes_(dtypes) {}

Status CustomReader::Initialize(Env* env) {
  TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename_, &file_));
  input_stream_ = std::make_unique<io::RandomAccessInputStream>(file_.get());

#if defined(IS_SLIM_BUILD)
  if (compression_type_ != io::compression::kNone) {
    LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
               << "off compression.";
  }
#else  // IS_SLIM_BUILD
  if (compression_type_ == io::compression::kGzip) {
    io::ZlibCompressionOptions zlib_options;
    zlib_options = io::ZlibCompressionOptions::GZIP();

    input_stream_ = absl::make_unique<io::ZlibInputStream>(
        input_stream_.release(), zlib_options.input_buffer_size,
        zlib_options.output_buffer_size, zlib_options, true);
  } else if (compression_type_ == io::compression::kSnappy) {
    if (version_ == 0) {
      input_stream_ = absl::make_unique<io::SnappyInputBuffer>(
          file_.get(), /*input_buffer_bytes=*/kSnappyReaderInputBufferSizeBytes,
          /*output_buffer_bytes=*/kSnappyReaderOutputBufferSizeBytes);
    } else {
      input_stream_ =
          absl::make_unique<io::BufferedInputStream>(file_.get(), 64 << 20);
    }
  }
#endif  // IS_SLIM_BUILD
  simple_tensor_mask_.reserve(dtypes_.size());
  for (const auto& dtype : dtypes_) {
    if (DataTypeCanUseMemcpy(dtype)) {
      simple_tensor_mask_.push_back(true);
      num_simple_++;
    } else {
      simple_tensor_mask_.push_back(false);
      num_complex_++;
    }
  }

  return Status::OK();
}

Status CustomReader::ReadTensors(std::vector<Tensor>* read_tensors) {
  profiler::TraceMe activity(
      [&]() { return absl::StrCat(kClassName, kSeparator, "ReadTensors"); },
      profiler::TraceMeLevel::kInfo);
  if (version_ == 0 || compression_type_ != io::compression::kSnappy) {
    return ReadTensorsV0(read_tensors);
  }
  if (version_ != 1) {
    return errors::InvalidArgument("Version: ", version_, " is not supported.");
  }
  if (compression_type_ != io::compression::kSnappy) {
    return errors::InvalidArgument("Compression ", compression_type_,
                                   " is not supported.");
  }

  experimental::SnapshotTensorMetadata metadata;
  tstring metadata_str;
  TF_RETURN_IF_ERROR(ReadRecord(&metadata_str));
  if (!metadata.ParseFromArray(metadata_str.data(), metadata_str.size())) {
    return errors::DataLoss("Could not parse SnapshotTensorMetadata");
  }
  read_tensors->reserve(metadata.tensor_metadata_size());

  std::vector<Tensor> simple_tensors;
  simple_tensors.reserve(num_simple_);
  std::vector<std::pair<std::unique_ptr<char[]>, size_t>> tensor_proto_strs;
  tensor_proto_strs.reserve(num_complex_);
  TF_RETURN_IF_ERROR(
      SnappyUncompress(&metadata, &simple_tensors, &tensor_proto_strs));

  int simple_index = 0;
  int complex_index = 0;
  for (int i = 0, end = simple_tensor_mask_.size(); i < end; ++i) {
    if (simple_tensor_mask_[i]) {
      read_tensors->push_back(std::move(simple_tensors[simple_index]));
      simple_index++;
    } else {
      auto tensor_proto_str = std::move(tensor_proto_strs[complex_index].first);
      size_t tensor_proto_size = tensor_proto_strs[complex_index].second;
      TensorProto tp;
      if (!tp.ParseFromArray(tensor_proto_str.get(), tensor_proto_size)) {
        return errors::Internal("Could not parse TensorProto");
      }
      Tensor t;
      if (!t.FromProto(tp)) {
        return errors::Internal("Could not parse Tensor");
      }
      read_tensors->push_back(std::move(t));
      complex_index++;
    }
  }
  return Status::OK();
}

Status CustomReader::ReadTensorsV0(std::vector<Tensor>* read_tensors) {
  experimental::SnapshotRecord record;
#if defined(PLATFORM_GOOGLE)
  absl::Cord c;
  TF_RETURN_IF_ERROR(ReadRecord(&c));
  record.ParseFromCord(c);
#else  // PLATFORM_GOOGLE
  tstring record_bytes;
  TF_RETURN_IF_ERROR(ReadRecord(&record_bytes));
  record.ParseFromArray(record_bytes.data(), record_bytes.size());
#endif  // PLATFORM_GOOGLE
  read_tensors->reserve(record.tensor_size());
  for (int i = 0; i < record.tensor_size(); ++i) {
    read_tensors->emplace_back();
    if (!read_tensors->back().FromProto(record.tensor(i))) {
      return errors::DataLoss("Unable to parse tensor from proto.");
    }
  }
  return Status::OK();
}

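// Decompresses the single snappy blob of a V1 element directly into its
// final destinations via an iovec scatter list: memcpy-able ("simple")
// tensors receive their bytes in place through their TensorBuffer, while
// complex tensors land in temporary proto strings that ReadTensors parses
// afterwards.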
Status CustomReader::SnappyUncompress(
    const experimental::SnapshotTensorMetadata* metadata,
    std::vector<Tensor>* simple_tensors,
    std::vector<std::pair<std::unique_ptr<char[]>, size_t>>*
        tensor_proto_strs) {
  tstring compressed;
  TF_RETURN_IF_ERROR(ReadRecord(&compressed));
  size_t size;
  if (!port::Snappy_GetUncompressedLength(compressed.data(), compressed.size(),
                                          &size)) {
    return errors::Internal("Could not get snappy uncompressed length");
  }

  int num_tensors = metadata->tensor_metadata_size();
  std::vector<struct iovec> iov(num_tensors);
  int index = 0;
  int64 total_size = 0;
  for (int i = 0, end = simple_tensor_mask_.size(); i < end; ++i) {
    const auto& tensor_metadata = metadata->tensor_metadata(i);
    if (simple_tensor_mask_[i]) {
      TensorShape shape(tensor_metadata.tensor_shape());
      Tensor simple_tensor(dtypes_[i], shape);
      TensorBuffer* buffer = DMAHelper::buffer(&simple_tensor);
      iov[index].iov_base = buffer->data();
      iov[index].iov_len = buffer->size();
      simple_tensors->push_back(std::move(simple_tensor));
    } else {
      auto tensor_proto_str =
          absl::make_unique<char[]>(tensor_metadata.tensor_size_bytes());
      iov[index].iov_base = tensor_proto_str.get();
      iov[index].iov_len = tensor_metadata.tensor_size_bytes();
      tensor_proto_strs->push_back(std::make_pair(
          std::move(tensor_proto_str), tensor_metadata.tensor_size_bytes()));
    }
    total_size += iov[index].iov_len;
    index++;
  }
  const int64 size_int = size;
  if (size_int != total_size) {
    return errors::Internal("Uncompressed size mismatch. Snappy expects ", size,
                            " whereas the tensor metadata suggests ",
                            total_size);
  }
  if (!port::Snappy_UncompressToIOVec(compressed.data(), compressed.size(),
                                      iov.data(), num_tensors)) {
    return errors::Internal("Failed to perform snappy decompression.");
  }
  return Status::OK();
}

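// Reads one record using the framing documented above
// CustomWriter::WriteRecord: a fixed64 little-endian length header of
// kHeaderSize bytes, then that many payload bytes.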
Status CustomReader::ReadRecord(tstring* record) {
  tstring header;
  TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
  uint64 length = core::DecodeFixed64(header.data());
  return input_stream_->ReadNBytes(length, record);
}

#if defined(TF_CORD_SUPPORT)
Status CustomReader::ReadRecord(absl::Cord* record) {
  tstring header;
  TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
  uint64 length = core::DecodeFixed64(header.data());
  if (compression_type_ == io::compression::kNone) {
    return input_stream_->ReadNBytes(length, record);
  } else {
    auto tmp_str = new tstring();
    TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(length, tmp_str));
    absl::string_view tmp_str_view(*tmp_str);
    record->Append(absl::MakeCordFromExternal(
        tmp_str_view, [tmp_str](absl::string_view) { delete tmp_str; }));
    return Status::OK();
  }
}
#endif  // TF_CORD_SUPPORT

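// Writes the metadata proto atomically: serialize to a uniquely named
// temporary file, then rename it over the final path so that concurrent
// readers never observe a partially written file.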
Status WriteMetadataFile(Env* env, const string& dir,
                         const experimental::SnapshotMetadataRecord* metadata) {
  string metadata_filename = io::JoinPath(dir, kMetadataFilename);
  TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(dir));
  std::string tmp_filename =
      absl::StrCat(metadata_filename, "-tmp-", random::New64());
  TF_RETURN_IF_ERROR(WriteBinaryProto(env, tmp_filename, *metadata));
  return env->RenameFile(tmp_filename, metadata_filename);
}

Status ReadMetadataFile(Env* env, const string& dir,
                        experimental::SnapshotMetadataRecord* metadata,
                        bool* file_exists) {
  string metadata_filename = io::JoinPath(dir, kMetadataFilename);
  Status s = env->FileExists(metadata_filename);
  *file_exists = s.ok();

  if (*file_exists) {
    return ReadBinaryProto(env, metadata_filename, metadata);
  } else {
    return Status::OK();
  }
}

Status DumpDatasetGraph(Env* env, const std::string& path, uint64 hash,
                        const GraphDef* graph) {
  std::string hash_hex =
      strings::StrCat(strings::Hex(hash, strings::kZeroPad16));
  std::string graph_file =
      io::JoinPath(path, absl::StrCat(hash_hex, "-graph.pbtxt"));

  LOG(INFO) << "Graph hash is " << hash_hex << ", writing to " << graph_file;
  TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(path));
  return WriteTextProto(env, graph_file, *graph);
}

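// Resolves the effective snapshot mode, in order: an explicit mode string
// always wins; otherwise a missing metadata file selects WRITER, a finalized
// snapshot selects READER, a still-pending (unexpired) snapshot selects
// PASSTHROUGH, and an expired pending snapshot is taken over as WRITER.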
Status DetermineOpState(const std::string& mode_string, bool file_exists,
                        const experimental::SnapshotMetadataRecord* metadata,
                        const uint64 pending_snapshot_expiry_seconds,
                        Mode* mode) {
  if (mode_string == kModeRead) {
    // In read mode, we expect a metadata file to have been written already.
    if (!file_exists) {
      return errors::NotFound("Metadata file does not exist.");
    }
    LOG(INFO) << "Overriding mode to reader.";
    *mode = READER;
    return Status::OK();
  }

  if (mode_string == kModeWrite) {
    LOG(INFO) << "Overriding mode to writer.";
    *mode = WRITER;
    return Status::OK();
  }

  if (mode_string == kModePassthrough) {
    LOG(INFO) << "Overriding mode to passthrough.";
    *mode = PASSTHROUGH;
    return Status::OK();
  }

  if (!file_exists) {
    *mode = WRITER;
    return Status::OK();
  }

  if (metadata->finalized()) {
    // File found, and the snapshot has been finalized.
    *mode = READER;
    return Status::OK();
  }

  int64 expiration_timer = static_cast<int64>(EnvTime::NowMicros()) -
                           pending_snapshot_expiry_seconds * 1000000;

  if (metadata->creation_timestamp() >= expiration_timer) {
    // Someone else is already writing and the time has not yet expired.
    *mode = PASSTHROUGH;
    return Status::OK();
  } else {
    // The time has expired, so we write regardless.
    *mode = WRITER;
    return Status::OK();
  }
}

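// The writer thread owns the underlying snapshot writer: producers enqueue
// elements under mu_, the thread drains the deque until it sees an
// end-of-sequence marker, and the done callback is invoked with the final
// status once the thread finishes.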
AsyncWriter::AsyncWriter(Env* env, int64 file_index,
                         const std::string& shard_directory,
                         uint64 checkpoint_id, const std::string& compression,
                         int64 version, const DataTypeVector& output_types,
                         std::function<void(Status)> done) {
  thread_ = absl::WrapUnique(env->StartThread(
      ThreadOptions(), absl::StrCat("writer_thread_", file_index),
      [this, env, shard_directory, checkpoint_id, compression, version,
       &output_types, done = std::move(done)] {
        done(WriterThread(env, shard_directory, checkpoint_id, compression,
                          version, output_types));
      }));
}

void AsyncWriter::Write(const std::vector<Tensor>& tensors) {
  mutex_lock l(mu_);
  ElementOrEOF element;
  element.value = tensors;
  deque_.push_back(std::move(element));
}

void AsyncWriter::SignalEOF() {
  mutex_lock l(mu_);
  ElementOrEOF be;
  be.end_of_sequence = true;
  deque_.push_back(std::move(be));
}

void AsyncWriter::Consume(ElementOrEOF* be) {
  mutex_lock l(mu_);
  mu_.Await(tensorflow::Condition(this, &AsyncWriter::ElementAvailable));
  *be = deque_.front();
  deque_.pop_front();
}

bool AsyncWriter::ElementAvailable() { return !deque_.empty(); }

Status AsyncWriter::WriterThread(Env* env, const std::string& shard_directory,
                                 uint64 checkpoint_id,
                                 const std::string& compression, int64 version,
                                 DataTypeVector output_types) {
  std::unique_ptr<snapshot_util::Writer> writer;
  TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(shard_directory));

  TF_RETURN_IF_ERROR(snapshot_util::Writer::Create(
      env, GetCheckpointFileName(shard_directory, checkpoint_id), compression,
      version, std::move(output_types), &writer));

  while (true) {
    ElementOrEOF be;
    Consume(&be);

    if (be.end_of_sequence) {
      TF_RETURN_IF_ERROR(writer->Close());
      break;
    }

    TF_RETURN_IF_ERROR(writer->WriteTensors(be.value));
  }
  return Status::OK();
}

}  // namespace snapshot_util
}  // namespace data
}  // namespace tensorflow