• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/kernels/data/experimental/snapshot_util.h"
17 
18 #include <queue>
19 
20 #include "absl/memory/memory.h"
21 #include "tensorflow/core/common_runtime/dma_helper.h"
22 #include "tensorflow/core/framework/dataset.h"
23 #include "tensorflow/core/framework/graph.pb.h"
24 #include "tensorflow/core/framework/tensor.pb.h"
25 #include "tensorflow/core/kernels/data/name_utils.h"
26 #include "tensorflow/core/lib/io/buffered_inputstream.h"
27 #include "tensorflow/core/lib/io/random_inputstream.h"
28 #include "tensorflow/core/lib/io/record_writer.h"
29 #include "tensorflow/core/lib/io/snappy/snappy_inputbuffer.h"
30 #include "tensorflow/core/lib/io/snappy/snappy_outputbuffer.h"
31 #include "tensorflow/core/lib/io/zlib_compression_options.h"
32 #include "tensorflow/core/lib/io/zlib_inputstream.h"
33 #include "tensorflow/core/lib/io/zlib_outputbuffer.h"
34 #include "tensorflow/core/platform/coding.h"
35 #include "tensorflow/core/platform/errors.h"
36 #include "tensorflow/core/platform/file_system.h"
37 #include "tensorflow/core/platform/path.h"
38 #include "tensorflow/core/platform/random.h"
39 #include "tensorflow/core/platform/stringprintf.h"
40 #include "tensorflow/core/profiler/lib/traceme.h"
41 #include "tensorflow/core/protobuf/snapshot.pb.h"
42 
43 namespace tensorflow {
44 namespace data {
45 namespace snapshot_util {
46 
47 /* static */ constexpr const int64
48     CustomReader::kSnappyReaderInputBufferSizeBytes;
49 /* static */ constexpr const int64
50     CustomReader::kSnappyReaderOutputBufferSizeBytes;
51 
HashDirectory(const std::string & path,uint64 hash)52 std::string HashDirectory(const std::string& path, uint64 hash) {
53   return io::JoinPath(
54       path, strings::Printf("%llu", static_cast<unsigned long long>(hash)));
55 }
56 
RunDirectory(const std::string & hash_directory,uint64 run_id)57 std::string RunDirectory(const std::string& hash_directory, uint64 run_id) {
58   return RunDirectory(
59       hash_directory,
60       strings::Printf("%llu", static_cast<unsigned long long>(run_id)));
61 }
62 
RunDirectory(const std::string & hash_directory,const std::string & run_id)63 std::string RunDirectory(const std::string& hash_directory,
64                          const std::string& run_id) {
65   return io::JoinPath(hash_directory, run_id);
66 }
67 
ShardDirectory(const std::string & run_directory,int64 shard_id)68 std::string ShardDirectory(const std::string& run_directory, int64 shard_id) {
69   return io::JoinPath(
70       run_directory,
71       strings::Printf("%08llu%s", static_cast<unsigned long long>(shard_id),
72                       kShardDirectorySuffix));
73 }
GetCheckpointFileName(const std::string & shard_directory,uint64 checkpoint_id)74 std::string GetCheckpointFileName(const std::string& shard_directory,
75                                   uint64 checkpoint_id) {
76   return io::JoinPath(
77       shard_directory,
78       strings::Printf("%08llu.snapshot",
79                       static_cast<unsigned long long>(checkpoint_id)));
80 }
81 
Create(Env * env,const std::string & filename,const std::string & compression_type,int version,const DataTypeVector & dtypes,std::unique_ptr<Writer> * out_writer)82 Status Writer::Create(Env* env, const std::string& filename,
83                       const std::string& compression_type, int version,
84                       const DataTypeVector& dtypes,
85                       std::unique_ptr<Writer>* out_writer) {
86   switch (version) {
87     case 1:
88       *out_writer =
89           absl::make_unique<CustomWriter>(filename, compression_type, dtypes);
90       break;
91     case 2:
92       *out_writer =
93           absl::make_unique<TFRecordWriter>(filename, compression_type);
94       break;
95     default:
96       return errors::InvalidArgument("Snapshot writer version: ", version,
97                                      " is not supported.");
98   }
99 
100   return (*out_writer)->Initialize(env);
101 }
102 
TFRecordWriter(const std::string & filename,const std::string & compression_type)103 TFRecordWriter::TFRecordWriter(const std::string& filename,
104                                const std::string& compression_type)
105     : filename_(filename), compression_type_(compression_type) {}
106 
Initialize(tensorflow::Env * env)107 Status TFRecordWriter::Initialize(tensorflow::Env* env) {
108   TF_RETURN_IF_ERROR(env->NewAppendableFile(filename_, &dest_));
109 
110   record_writer_ = absl::make_unique<io::RecordWriter>(
111       dest_.get(), io::RecordWriterOptions::CreateRecordWriterOptions(
112                        /*compression_type=*/compression_type_));
113   return Status::OK();
114 }
115 
WriteTensors(const std::vector<Tensor> & tensors)116 Status TFRecordWriter::WriteTensors(const std::vector<Tensor>& tensors) {
117   for (const auto& tensor : tensors) {
118     TensorProto proto;
119     tensor.AsProtoTensorContent(&proto);
120 #if defined(TF_CORD_SUPPORT)
121     // Creating raw pointer here because std::move() in a releases in OSS TF
122     // will result in a smart pointer being moved upon function creation, which
123     // will result in proto_buffer == nullptr when WriteRecord happens.
124     auto proto_buffer = new std::string();
125     proto.SerializeToString(proto_buffer);
126     absl::Cord proto_serialized = absl::MakeCordFromExternal(
127         *proto_buffer,
128         [proto_buffer](absl::string_view) { delete proto_buffer; });
129     TF_RETURN_IF_ERROR(record_writer_->WriteRecord(proto_serialized));
130 #else   // TF_CORD_SUPPORT
131     TF_RETURN_IF_ERROR(record_writer_->WriteRecord(proto.SerializeAsString()));
132 #endif  // TF_CORD_SUPPORT
133   }
134   return Status::OK();
135 }
136 
Sync()137 Status TFRecordWriter::Sync() {
138   TF_RETURN_IF_ERROR(record_writer_->Flush());
139   return dest_->Flush();
140 }
141 
Close()142 Status TFRecordWriter::Close() {
143   if (record_writer_ != nullptr) {
144     TF_RETURN_IF_ERROR(Sync());
145     TF_RETURN_IF_ERROR(record_writer_->Close());
146     TF_RETURN_IF_ERROR(dest_->Close());
147     record_writer_ = nullptr;
148     dest_ = nullptr;
149   }
150   return Status::OK();
151 }
152 
~TFRecordWriter()153 TFRecordWriter::~TFRecordWriter() {
154   Status s = Close();
155   if (!s.ok()) {
156     LOG(ERROR) << "Failed to close snapshot file " << filename_ << ": " << s;
157   }
158 }
159 
CustomWriter(const std::string & filename,const std::string & compression_type,const DataTypeVector & dtypes)160 CustomWriter::CustomWriter(const std::string& filename,
161                            const std::string& compression_type,
162                            const DataTypeVector& dtypes)
163     : filename_(filename),
164       compression_type_(compression_type),
165       dtypes_(dtypes) {}
166 
Initialize(tensorflow::Env * env)167 Status CustomWriter::Initialize(tensorflow::Env* env) {
168   TF_RETURN_IF_ERROR(env->NewAppendableFile(filename_, &dest_));
169 #if defined(IS_SLIM_BUILD)
170   if (compression_type_ != io::compression::kNone) {
171     LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
172                << "off compression.";
173   }
174 #else   // IS_SLIM_BUILD
175   if (compression_type_ == io::compression::kGzip) {
176     zlib_underlying_dest_.swap(dest_);
177     io::ZlibCompressionOptions zlib_options;
178     zlib_options = io::ZlibCompressionOptions::GZIP();
179 
180     io::ZlibOutputBuffer* zlib_output_buffer = new io::ZlibOutputBuffer(
181         zlib_underlying_dest_.get(), zlib_options.input_buffer_size,
182         zlib_options.output_buffer_size, zlib_options);
183     TF_CHECK_OK(zlib_output_buffer->Init());
184     dest_.reset(zlib_output_buffer);
185   }
186 #endif  // IS_SLIM_BUILD
187   simple_tensor_mask_.reserve(dtypes_.size());
188   for (const auto& dtype : dtypes_) {
189     if (DataTypeCanUseMemcpy(dtype)) {
190       simple_tensor_mask_.push_back(true);
191       num_simple_++;
192     } else {
193       simple_tensor_mask_.push_back(false);
194       num_complex_++;
195     }
196   }
197 
198   return Status::OK();
199 }
200 
WriteTensors(const std::vector<Tensor> & tensors)201 Status CustomWriter::WriteTensors(const std::vector<Tensor>& tensors) {
202   if (compression_type_ != io::compression::kSnappy) {
203     experimental::SnapshotRecord record;
204     for (const auto& tensor : tensors) {
205       TensorProto* t = record.add_tensor();
206       tensor.AsProtoTensorContent(t);
207     }
208 #if defined(TF_CORD_SUPPORT)
209     auto record_buffer = new std::string();
210     record.SerializeToString(record_buffer);
211     absl::Cord record_serialized = absl::MakeCordFromExternal(
212         *record_buffer,
213         [record_buffer](absl::string_view) { delete record_buffer; });
214     return WriteRecord(record_serialized);
215 #else   // TF_CORD_SUPPORT
216     return WriteRecord(record.SerializeAsString());
217 #endif  // TF_CORD_SUPPORT
218   }
219 
220   std::vector<const TensorBuffer*> tensor_buffers;
221   tensor_buffers.reserve(num_simple_);
222   std::vector<TensorProto> tensor_protos;
223   tensor_protos.reserve(num_complex_);
224   experimental::SnapshotTensorMetadata metadata;
225   int64 total_size = 0;
226   for (int i = 0, end = tensors.size(); i < end; ++i) {
227     const Tensor& tensor = tensors[i];
228     experimental::TensorMetadata* tensor_metadata =
229         metadata.add_tensor_metadata();
230     tensor.shape().AsProto(tensor_metadata->mutable_tensor_shape());
231     int64 size = 0;
232     if (simple_tensor_mask_[i]) {
233       auto tensor_buffer = DMAHelper::buffer(&tensor);
234       tensor_buffers.push_back(tensor_buffer);
235       size = tensor_buffer->size();
236     } else {
237       TensorProto proto;
238       tensor.AsProtoTensorContent(&proto);
239       size = proto.ByteSizeLong();
240       tensor_protos.push_back(std::move(proto));
241     }
242     tensor_metadata->set_tensor_size_bytes(size);
243     total_size += size;
244   }
245 
246   std::vector<char> uncompressed(total_size);
247   char* position = uncompressed.data();
248   int buffer_index = 0;
249   int proto_index = 0;
250   for (int i = 0, end = tensors.size(); i < end; ++i) {
251     const auto& tensor_metadata = metadata.tensor_metadata(i);
252     if (simple_tensor_mask_[i]) {
253       memcpy(position, tensor_buffers[buffer_index]->data(),
254              tensor_metadata.tensor_size_bytes());
255       buffer_index++;
256     } else {
257       tensor_protos[proto_index].SerializeToArray(
258           position, tensor_metadata.tensor_size_bytes());
259       proto_index++;
260     }
261     position += tensor_metadata.tensor_size_bytes();
262   }
263   DCHECK_EQ(position, uncompressed.data() + total_size);
264 
265   string output;
266   if (!port::Snappy_Compress(uncompressed.data(), total_size, &output)) {
267     return errors::Internal("Failed to compress using snappy.");
268   }
269 
270 #if defined(TF_CORD_SUPPORT)
271   auto metadata_buffer = new std::string();
272   metadata.SerializeToString(metadata_buffer);
273   absl::Cord metadata_serialized = absl::MakeCordFromExternal(
274       *metadata_buffer,
275       [metadata_buffer](absl::string_view) { delete metadata_buffer; });
276 #else
277   std::string metadata_serialized = metadata.SerializeAsString();
278 #endif  // TF_CORD_SUPPORT
279   TF_RETURN_IF_ERROR(WriteRecord(metadata_serialized));
280   TF_RETURN_IF_ERROR(WriteRecord(output));
281   return Status::OK();
282 }
283 
Sync()284 Status CustomWriter::Sync() { return dest_->Sync(); }
285 
Close()286 Status CustomWriter::Close() {
287   if (dest_ != nullptr) {
288     TF_RETURN_IF_ERROR(dest_->Close());
289     dest_ = nullptr;
290   }
291   if (zlib_underlying_dest_ != nullptr) {
292     TF_RETURN_IF_ERROR(zlib_underlying_dest_->Close());
293     zlib_underlying_dest_ = nullptr;
294   }
295   return Status::OK();
296 }
297 
~CustomWriter()298 CustomWriter::~CustomWriter() {
299   Status s = Close();
300   if (!s.ok()) {
301     LOG(ERROR) << "Could not finish writing file: " << s;
302   }
303 }
304 
WriteRecord(const StringPiece & data)305 Status CustomWriter::WriteRecord(const StringPiece& data) {
306   char header[kHeaderSize];
307   core::EncodeFixed64(header, data.size());
308   TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
309   return dest_->Append(data);
310 }
311 
#if defined(TF_CORD_SUPPORT)
// Cord overload: writes the same framing as the StringPiece overload (a
// kHeaderSize-byte length header encoded with core::EncodeFixed64, then the
// payload) without flattening the Cord first.
Status CustomWriter::WriteRecord(const absl::Cord& data) {
  char length_header[kHeaderSize];
  core::EncodeFixed64(length_header, data.size());
  const StringPiece header_piece(length_header, sizeof(length_header));
  TF_RETURN_IF_ERROR(dest_->Append(header_piece));
  return dest_->Append(data);
}
#endif  // TF_CORD_SUPPORT
320 
Create(Env * env,const std::string & filename,const string & compression_type,int version,const DataTypeVector & dtypes,std::unique_ptr<Reader> * out_reader)321 Status Reader::Create(Env* env, const std::string& filename,
322                       const string& compression_type, int version,
323                       const DataTypeVector& dtypes,
324                       std::unique_ptr<Reader>* out_reader) {
325   switch (version) {
326     // CustomReader is able to read a legacy snapshot file format (v0) though
327     // custom writer doesn't have the ability to write it any more since it is
328     // strictly worse than V1.
329     case 0:
330     case 1:
331       *out_reader = absl::make_unique<CustomReader>(filename, compression_type,
332                                                     version, dtypes);
333       break;
334     case 2:
335       *out_reader =
336           absl::make_unique<TFRecordReader>(filename, compression_type, dtypes);
337       break;
338     default:
339       return errors::InvalidArgument("Snapshot reader version: ", version,
340                                      " is not supported.");
341   }
342 
343   return (*out_reader)->Initialize(env);
344 }
345 
SkipRecords(int64 num_records)346 Status Reader::SkipRecords(int64 num_records) {
347   // TODO(frankchn): Optimize to not parse the entire Tensor and actually skip.
348   for (int i = 0; i < num_records; ++i) {
349     std::vector<Tensor> unused_tensors;
350     TF_RETURN_IF_ERROR(ReadTensors(&unused_tensors));
351   }
352   return Status::OK();
353 }
354 
355 class Reader::Dataset : public DatasetBase {
356  public:
Dataset(const std::string & shard_dir,const std::string & compression,const int64 version,const DataTypeVector & dtypes,const std::vector<PartialTensorShape> & shapes,const int64 start_index,DatasetContext::Params params)357   explicit Dataset(const std::string& shard_dir, const std::string& compression,
358                    const int64 version, const DataTypeVector& dtypes,
359                    const std::vector<PartialTensorShape>& shapes,
360                    const int64 start_index, DatasetContext::Params params)
361       : DatasetBase(DatasetContext(std::move(params))),
362         shard_dir_(shard_dir),
363         compression_(compression),
364         version_(version),
365         dtypes_(dtypes),
366         shapes_(shapes),
367         start_index_(start_index) {}
368 
output_dtypes() const369   const DataTypeVector& output_dtypes() const override { return dtypes_; }
370 
output_shapes() const371   const std::vector<PartialTensorShape>& output_shapes() const override {
372     return shapes_;
373   }
374 
DebugString() const375   std::string DebugString() const override {
376     return "snapshot_util::Reader::Dataset";
377   }
378 
InputDatasets(std::vector<const DatasetBase * > * inputs) const379   Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
380     return Status::OK();
381   }
382 
CheckExternalState() const383   Status CheckExternalState() const override { return Status::OK(); }
384 
385  protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** node) const386   Status AsGraphDefInternal(SerializationContext* ctx,
387                             DatasetGraphDefBuilder* b,
388                             Node** node) const override {
389     // Not necessary perform any serialization as this dataset is only
390     // constructed at runtime in C++ and will be reconstructed every time.
391     return Status::OK();
392   }
393 
MakeIteratorInternal(const string & prefix) const394   std::unique_ptr<IteratorBase> MakeIteratorInternal(
395       const string& prefix) const override {
396     return absl::make_unique<Iterator>(Iterator::Params{
397         this, name_utils::IteratorPrefix(node_name(), prefix)});
398   }
399 
400  private:
401   class Iterator : public DatasetIterator<Dataset> {
402    public:
Iterator(const Params & params)403     explicit Iterator(const Params& params)
404         : DatasetIterator<Dataset>(params), current_checkpoint_id_(0) {}
405 
Initialize(IteratorContext * ctx)406     Status Initialize(IteratorContext* ctx) override {
407       TF_RETURN_IF_ERROR(Reader::Create(
408           ctx->env(), GetCurrentFilename(), dataset()->compression_,
409           dataset()->version_, dataset()->dtypes_, &reader_));
410       bool end_of_sequence;
411       for (int64 i = 0; i < dataset()->start_index_; ++i) {
412         // TODO(frankchn): Optimize this to not parse every single element.
413         std::vector<Tensor> unused;
414         TF_RETURN_IF_ERROR(GetNextInternal(ctx, &unused, &end_of_sequence));
415       }
416       return Status::OK();
417     }
418 
419    protected:
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)420     Status GetNextInternal(IteratorContext* ctx,
421                            std::vector<Tensor>* out_tensors,
422                            bool* end_of_sequence) override {
423       *end_of_sequence = false;
424       Status s = reader_->ReadTensors(out_tensors);
425       if (!errors::IsOutOfRange(s)) {
426         return s;
427       }
428       Status status = AdvanceToNextFile(ctx->env());
429       if (errors::IsNotFound(status)) {
430         *end_of_sequence = true;
431         return Status::OK();
432       } else {
433         return status;
434       }
435     }
436 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)437     Status SaveInternal(SerializationContext* ctx,
438                         IteratorStateWriter* writer) override {
439       // Not necessary to save any state as this iterator will be reconstructed
440       // from scratch when the parent snapshot dataset is restored from
441       // checkpoint.
442       return Status::OK();
443     }
444 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)445     Status RestoreInternal(IteratorContext* ctx,
446                            IteratorStateReader* reader) override {
447       // Not necessary to restore any state as this iterator will be
448       // reconstructed from scratch when the parent snapshot dataset is restored
449       // from checkpoint.
450       return Status::OK();
451     }
452 
453    private:
GetCurrentFilename()454     std::string GetCurrentFilename() {
455       return GetCheckpointFileName(dataset()->shard_dir_,
456                                    current_checkpoint_id_);
457     }
458 
AdvanceToNextFile(Env * env)459     Status AdvanceToNextFile(Env* env) {
460       current_checkpoint_id_++;
461       TF_RETURN_IF_ERROR(env->FileExists(GetCurrentFilename()));
462       return Reader::Create(env, GetCurrentFilename(), dataset()->compression_,
463                             dataset()->version_, dataset()->dtypes_, &reader_);
464     }
465 
466     std::unique_ptr<Reader> reader_;
467 
468     // Stores the id current checkpoint file that we are in the process of
469     // reading (e.g. if the file is currently 00000001.snapshot, then this will
470     // be 1).
471     uint64 current_checkpoint_id_;
472   };
473 
474   const std::string shard_dir_;
475   const std::string compression_;
476   const int64 version_;
477   const DataTypeVector dtypes_;
478   const std::vector<PartialTensorShape> shapes_;
479   const int64 start_index_;
480 };
481 
482 class Reader::NestedDataset : public DatasetBase {
483  public:
NestedDataset(std::vector<DatasetBase * > datasets,DatasetContext::Params params)484   explicit NestedDataset(std::vector<DatasetBase*> datasets,
485                          DatasetContext::Params params)
486       : DatasetBase(DatasetContext(std::move(params))), datasets_(datasets) {
487     dtypes_.push_back(DT_VARIANT);
488     gtl::InlinedVector<int64, 1> element_dim_sizes;
489     element_dim_sizes.push_back(1);
490     partial_shapes_.emplace_back(element_dim_sizes);
491   }
492 
output_dtypes() const493   const DataTypeVector& output_dtypes() const override { return dtypes_; }
494 
output_shapes() const495   const std::vector<PartialTensorShape>& output_shapes() const override {
496     return partial_shapes_;
497   }
498 
DebugString() const499   std::string DebugString() const override {
500     return "snapshot_util::Reader::NestedDataset";
501   }
502 
InputDatasets(std::vector<const DatasetBase * > * inputs) const503   Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
504     inputs->clear();
505     return Status::OK();
506   }
507 
CheckExternalState() const508   Status CheckExternalState() const override { return Status::OK(); }
509 
510  protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** node) const511   Status AsGraphDefInternal(SerializationContext* ctx,
512                             DatasetGraphDefBuilder* b,
513                             Node** node) const override {
514     // Not necessary perform any serialization as this dataset is only
515     // constructed at runtime in C++ and will be reconstructed every time.
516     return Status::OK();
517   }
518 
MakeIteratorInternal(const string & prefix) const519   std::unique_ptr<IteratorBase> MakeIteratorInternal(
520       const string& prefix) const override {
521     return absl::make_unique<Iterator>(Iterator::Params{
522         this, name_utils::IteratorPrefix(node_name(), prefix)});
523   }
524 
525  private:
526   std::vector<DatasetBase*> datasets_;
527   DataTypeVector dtypes_;
528   std::vector<PartialTensorShape> partial_shapes_;
529 
530   class Iterator : public DatasetIterator<NestedDataset> {
531    public:
Iterator(const Params & params)532     explicit Iterator(const Params& params)
533         : DatasetIterator<NestedDataset>(params), index_(0) {}
534 
535    protected:
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)536     Status GetNextInternal(IteratorContext* ctx,
537                            std::vector<Tensor>* out_tensors,
538                            bool* end_of_sequence) override {
539       const int64 num_datasets = dataset()->datasets_.size();
540       *end_of_sequence = num_datasets == index_;
541       if (!*end_of_sequence) {
542         Tensor tensor(DT_VARIANT, TensorShape({}));
543 
544         TF_RETURN_IF_ERROR(
545             StoreDatasetInVariantTensor(dataset()->datasets_[index_], &tensor));
546         out_tensors->clear();
547         out_tensors->push_back(std::move(tensor));
548 
549         index_++;
550       }
551       return Status::OK();
552     }
553 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)554     Status SaveInternal(SerializationContext* ctx,
555                         IteratorStateWriter* writer) override {
556       // Not necessary to save any state as this iterator will be reconstructed
557       // from scratch when the parent snapshot dataset is restored from
558       // checkpoint.
559       return Status::OK();
560     }
561 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)562     Status RestoreInternal(IteratorContext* ctx,
563                            IteratorStateReader* reader) override {
564       // Not necessary to restore any state as this iterator will be
565       // reconstructed from scratch when the parent snapshot dataset is restored
566       // from checkpoint.
567       return Status::OK();
568     }
569 
570    private:
571     int64 index_;
572   };
573 };
574 
MakeNestedDataset(Env * env,const std::vector<std::string> & shard_dirs,const string & compression_type,int version,const DataTypeVector & dtypes,const std::vector<PartialTensorShape> & shapes,const int64 start_index,DatasetBase ** output)575 Status Reader::MakeNestedDataset(Env* env,
576                                  const std::vector<std::string>& shard_dirs,
577                                  const string& compression_type, int version,
578                                  const DataTypeVector& dtypes,
579                                  const std::vector<PartialTensorShape>& shapes,
580                                  const int64 start_index,
581                                  DatasetBase** output) {
582   std::vector<DatasetBase*> datasets;
583 
584   datasets.reserve(shard_dirs.size());
585   for (const auto& shard_dir : shard_dirs) {
586     // TODO(frankchn): The reading pattern could be controlled in a non-round
587     // robin fashion, so we cannot assume a round-robin manner when restoring.
588     int64 dataset_start_index = start_index / shard_dirs.size();
589     if (start_index % shard_dirs.size() > datasets.size()) {
590       dataset_start_index++;
591     }
592 
593     datasets.push_back(
594         new Dataset(shard_dir, compression_type, version, dtypes, shapes,
595                     dataset_start_index,
596                     DatasetContext::Params({"snapshot_util::Reader::Dataset",
597                                             "snapshot_util_reader_Dataset"})));
598   }
599 
600   // Rotate the vector such that the first dataset contains the next element
601   // to be produced, but not if there are no shards at all (then we just
602   // construct an empty dataset).
603   if (!shard_dirs.empty()) {
604     std::rotate(datasets.begin(),
605                 datasets.begin() + (start_index % shard_dirs.size()),
606                 datasets.end());
607   }
608 
609   *output = new NestedDataset(
610       datasets, DatasetContext::Params({"snapshot_util::Reader::NestedDataset",
611                                         "snapshot_util_reader_NestedDataset"}));
612   return Status::OK();
613 }
614 
TFRecordReader(const std::string & filename,const string & compression_type,const DataTypeVector & dtypes)615 TFRecordReader::TFRecordReader(const std::string& filename,
616                                const string& compression_type,
617                                const DataTypeVector& dtypes)
618     : filename_(filename),
619       offset_(0),
620       compression_type_(compression_type),
621       dtypes_(dtypes) {}
622 
Initialize(Env * env)623 Status TFRecordReader::Initialize(Env* env) {
624   TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename_, &file_));
625 
626   record_reader_ = absl::make_unique<io::RecordReader>(
627       file_.get(), io::RecordReaderOptions::CreateRecordReaderOptions(
628                        /*compression_type=*/compression_type_));
629   return Status::OK();
630 }
631 
ReadTensors(std::vector<Tensor> * read_tensors)632 Status TFRecordReader::ReadTensors(std::vector<Tensor>* read_tensors) {
633   read_tensors->reserve(dtypes_.size());
634   for (int i = 0; i < dtypes_.size(); ++i) {
635     tstring record;
636     TF_RETURN_IF_ERROR(record_reader_->ReadRecord(&offset_, &record));
637 
638     TensorProto proto;
639     proto.ParseFromArray(record.data(), record.size());
640 
641     Tensor tensor;
642     if (!tensor.FromProto(proto)) {
643       return errors::DataLoss("Unable to parse tensor from stored proto.");
644     }
645 
646     read_tensors->push_back(std::move(tensor));
647   }
648   return Status::OK();
649 }
650 
CustomReader(const std::string & filename,const string & compression_type,const int version,const DataTypeVector & dtypes)651 CustomReader::CustomReader(const std::string& filename,
652                            const string& compression_type, const int version,
653                            const DataTypeVector& dtypes)
654     : filename_(filename),
655       compression_type_(compression_type),
656       version_(version),
657       dtypes_(dtypes) {}
658 
Initialize(Env * env)659 Status CustomReader::Initialize(Env* env) {
660   TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename_, &file_));
661   input_stream_ = std::make_unique<io::RandomAccessInputStream>(file_.get());
662 
663 #if defined(IS_SLIM_BUILD)
664   if (compression_type_ != io::compression::kNone) {
665     LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
666                << "off compression.";
667   }
668 #else   // IS_SLIM_BUILD
669   if (compression_type_ == io::compression::kGzip) {
670     io::ZlibCompressionOptions zlib_options;
671     zlib_options = io::ZlibCompressionOptions::GZIP();
672 
673     input_stream_ = absl::make_unique<io::ZlibInputStream>(
674         input_stream_.release(), zlib_options.input_buffer_size,
675         zlib_options.output_buffer_size, zlib_options, true);
676   } else if (compression_type_ == io::compression::kSnappy) {
677     if (version_ == 0) {
678       input_stream_ = absl::make_unique<io::SnappyInputBuffer>(
679           file_.get(), /*input_buffer_bytes=*/kSnappyReaderInputBufferSizeBytes,
680           /*output_buffer_bytes=*/kSnappyReaderOutputBufferSizeBytes);
681     } else {
682       input_stream_ =
683           absl::make_unique<io::BufferedInputStream>(file_.get(), 64 << 20);
684     }
685   }
686 #endif  // IS_SLIM_BUILD
687   simple_tensor_mask_.reserve(dtypes_.size());
688   for (const auto& dtype : dtypes_) {
689     if (DataTypeCanUseMemcpy(dtype)) {
690       simple_tensor_mask_.push_back(true);
691       num_simple_++;
692     } else {
693       simple_tensor_mask_.push_back(false);
694       num_complex_++;
695     }
696   }
697 
698   return Status::OK();
699 }
700 
ReadTensors(std::vector<Tensor> * read_tensors)701 Status CustomReader::ReadTensors(std::vector<Tensor>* read_tensors) {
702   profiler::TraceMe activity(
703       [&]() { return absl::StrCat(kClassName, kSeparator, "ReadTensors"); },
704       profiler::TraceMeLevel::kInfo);
705   if (version_ == 0 || compression_type_ != io::compression::kSnappy) {
706     return ReadTensorsV0(read_tensors);
707   }
708   if (version_ != 1) {
709     return errors::InvalidArgument("Version: ", version_, " is not supported.");
710   }
711   if (compression_type_ != io::compression::kSnappy) {
712     return errors::InvalidArgument("Compression ", compression_type_,
713                                    " is not supported.");
714   }
715 
716   experimental::SnapshotTensorMetadata metadata;
717   tstring metadata_str;
718   TF_RETURN_IF_ERROR(ReadRecord(&metadata_str));
719   if (!metadata.ParseFromArray(metadata_str.data(), metadata_str.size())) {
720     return errors::DataLoss("Could not parse SnapshotTensorMetadata");
721   }
722   read_tensors->reserve(metadata.tensor_metadata_size());
723 
724   std::vector<Tensor> simple_tensors;
725   simple_tensors.reserve(num_simple_);
726   std::vector<std::pair<std::unique_ptr<char[]>, size_t>> tensor_proto_strs;
727   tensor_proto_strs.reserve(num_complex_);
728   TF_RETURN_IF_ERROR(
729       SnappyUncompress(&metadata, &simple_tensors, &tensor_proto_strs));
730 
731   int simple_index = 0;
732   int complex_index = 0;
733   for (int i = 0, end = simple_tensor_mask_.size(); i < end; ++i) {
734     if (simple_tensor_mask_[i]) {
735       read_tensors->push_back(std::move(simple_tensors[simple_index]));
736       simple_index++;
737     } else {
738       auto tensor_proto_str = std::move(tensor_proto_strs[complex_index].first);
739       size_t tensor_proto_size = tensor_proto_strs[complex_index].second;
740       TensorProto tp;
741       if (!tp.ParseFromArray(tensor_proto_str.get(), tensor_proto_size)) {
742         return errors::Internal("Could not parse TensorProto");
743       }
744       Tensor t;
745       if (!t.FromProto(tp)) {
746         return errors::Internal("Could not parse Tensor");
747       }
748       read_tensors->push_back(std::move(t));
749       complex_index++;
750     }
751   }
752   return Status::OK();
753 }
754 
ReadTensorsV0(std::vector<Tensor> * read_tensors)755 Status CustomReader::ReadTensorsV0(std::vector<Tensor>* read_tensors) {
756   experimental::SnapshotRecord record;
757 #if defined(PLATFORM_GOOGLE)
758   absl::Cord c;
759   TF_RETURN_IF_ERROR(ReadRecord(&c));
760   record.ParseFromCord(c);
761 #else   // PLATFORM_GOOGLE
762   tstring record_bytes;
763   TF_RETURN_IF_ERROR(ReadRecord(&record_bytes));
764   record.ParseFromArray(record_bytes.data(), record_bytes.size());
765 #endif  // PLATFORM_GOOGLE
766   read_tensors->reserve(record.tensor_size());
767   for (int i = 0; i < record.tensor_size(); ++i) {
768     read_tensors->emplace_back();
769     if (!read_tensors->back().FromProto(record.tensor(i))) {
770       return errors::DataLoss("Unable to parse tensor from proto.");
771     }
772   }
773   return Status::OK();
774 }
775 
SnappyUncompress(const experimental::SnapshotTensorMetadata * metadata,std::vector<Tensor> * simple_tensors,std::vector<std::pair<std::unique_ptr<char[]>,size_t>> * tensor_proto_strs)776 Status CustomReader::SnappyUncompress(
777     const experimental::SnapshotTensorMetadata* metadata,
778     std::vector<Tensor>* simple_tensors,
779     std::vector<std::pair<std::unique_ptr<char[]>, size_t>>*
780         tensor_proto_strs) {
781   tstring compressed;
782   TF_RETURN_IF_ERROR(ReadRecord(&compressed));
783   size_t size;
784   if (!port::Snappy_GetUncompressedLength(compressed.data(), compressed.size(),
785                                           &size)) {
786     return errors::Internal("Could not get snappy uncompressed length");
787   }
788 
789   int num_tensors = metadata->tensor_metadata_size();
790   std::vector<struct iovec> iov(num_tensors);
791   int index = 0;
792   int64 total_size = 0;
793   for (int i = 0, end = simple_tensor_mask_.size(); i < end; ++i) {
794     const auto& tensor_metadata = metadata->tensor_metadata(i);
795     if (simple_tensor_mask_[i]) {
796       TensorShape shape(tensor_metadata.tensor_shape());
797       Tensor simple_tensor(dtypes_[i], shape);
798       TensorBuffer* buffer = DMAHelper::buffer(&simple_tensor);
799       iov[index].iov_base = buffer->data();
800       iov[index].iov_len = buffer->size();
801       simple_tensors->push_back(std::move(simple_tensor));
802     } else {
803       auto tensor_proto_str =
804           absl::make_unique<char[]>(tensor_metadata.tensor_size_bytes());
805       iov[index].iov_base = tensor_proto_str.get();
806       iov[index].iov_len = tensor_metadata.tensor_size_bytes();
807       tensor_proto_strs->push_back(std::make_pair(
808           std::move(tensor_proto_str), tensor_metadata.tensor_size_bytes()));
809     }
810     total_size += iov[index].iov_len;
811     index++;
812   }
813   const int64 size_int = size;
814   if (size_int != total_size) {
815     return errors::Internal("Uncompressed size mismatch. Snappy expects ", size,
816                             " whereas the tensor metadata suggests ",
817                             total_size);
818   }
819   if (!port::Snappy_UncompressToIOVec(compressed.data(), compressed.size(),
820                                       iov.data(), num_tensors)) {
821     return errors::Internal("Failed to perform snappy decompression.");
822   }
823   return Status::OK();
824 }
825 
ReadRecord(tstring * record)826 Status CustomReader::ReadRecord(tstring* record) {
827   tstring header;
828   TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
829   uint64 length = core::DecodeFixed64(header.data());
830   return input_stream_->ReadNBytes(length, record);
831 }
832 
833 #if defined(TF_CORD_SUPPORT)
ReadRecord(absl::Cord * record)834 Status CustomReader::ReadRecord(absl::Cord* record) {
835   tstring header;
836   TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
837   uint64 length = core::DecodeFixed64(header.data());
838   if (compression_type_ == io::compression::kNone) {
839     return input_stream_->ReadNBytes(length, record);
840   } else {
841     auto tmp_str = new tstring();
842     TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(length, tmp_str));
843     absl::string_view tmp_str_view(*tmp_str);
844     record->Append(absl::MakeCordFromExternal(
845         tmp_str_view, [tmp_str](absl::string_view) { delete tmp_str; }));
846     return Status::OK();
847   }
848 }
849 #endif  // TF_CORD_SUPPORT
850 
WriteMetadataFile(Env * env,const string & dir,const experimental::SnapshotMetadataRecord * metadata)851 Status WriteMetadataFile(Env* env, const string& dir,
852                          const experimental::SnapshotMetadataRecord* metadata) {
853   string metadata_filename = io::JoinPath(dir, kMetadataFilename);
854   TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(dir));
855   std::string tmp_filename =
856       absl::StrCat(metadata_filename, "-tmp-", random::New64());
857   TF_RETURN_IF_ERROR(WriteBinaryProto(env, tmp_filename, *metadata));
858   return env->RenameFile(tmp_filename, metadata_filename);
859 }
860 
ReadMetadataFile(Env * env,const string & dir,experimental::SnapshotMetadataRecord * metadata,bool * file_exists)861 Status ReadMetadataFile(Env* env, const string& dir,
862                         experimental::SnapshotMetadataRecord* metadata,
863                         bool* file_exists) {
864   string metadata_filename = io::JoinPath(dir, kMetadataFilename);
865   Status s = env->FileExists(metadata_filename);
866   *file_exists = s.ok();
867 
868   if (*file_exists) {
869     return ReadBinaryProto(env, metadata_filename, metadata);
870   } else {
871     return Status::OK();
872   }
873 }
874 
DumpDatasetGraph(Env * env,const std::string & path,uint64 hash,const GraphDef * graph)875 Status DumpDatasetGraph(Env* env, const std::string& path, uint64 hash,
876                         const GraphDef* graph) {
877   std::string hash_hex =
878       strings::StrCat(strings::Hex(hash, strings::kZeroPad16));
879   std::string graph_file =
880       io::JoinPath(path, absl::StrCat(hash_hex, "-graph.pbtxt"));
881 
882   LOG(INFO) << "Graph hash is " << hash_hex << ", writing to " << graph_file;
883   TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(path));
884   return WriteTextProto(env, graph_file, *graph);
885 }
886 
DetermineOpState(const std::string & mode_string,bool file_exists,const experimental::SnapshotMetadataRecord * metadata,const uint64 pending_snapshot_expiry_seconds,Mode * mode)887 Status DetermineOpState(const std::string& mode_string, bool file_exists,
888                         const experimental::SnapshotMetadataRecord* metadata,
889                         const uint64 pending_snapshot_expiry_seconds,
890                         Mode* mode) {
891   if (mode_string == kModeRead) {
892     // In read mode, we should expect a metadata file is written.
893     if (!file_exists) {
894       return errors::NotFound("Metadata file does not exist.");
895     }
896     LOG(INFO) << "Overriding mode to reader.";
897     *mode = READER;
898     return Status::OK();
899   }
900 
901   if (mode_string == kModeWrite) {
902     LOG(INFO) << "Overriding mode to writer.";
903     *mode = WRITER;
904     return Status::OK();
905   }
906 
907   if (mode_string == kModePassthrough) {
908     LOG(INFO) << "Overriding mode to passthrough.";
909     *mode = PASSTHROUGH;
910     return Status::OK();
911   }
912 
913   if (!file_exists) {
914     *mode = WRITER;
915     return Status::OK();
916   }
917 
918   if (metadata->finalized()) {
919     // File found, snapshot has been finalized.
920     *mode = READER;
921     return Status::OK();
922   }
923 
924   int64 expiration_timer = static_cast<int64>(EnvTime::NowMicros()) -
925                            pending_snapshot_expiry_seconds * 1000000;
926 
927   if (metadata->creation_timestamp() >= expiration_timer) {
928     // Someone else is already writing and time has not expired.
929     *mode = PASSTHROUGH;
930     return Status::OK();
931   } else {
932     // Time has expired, we write regardless.
933     *mode = WRITER;
934     return Status::OK();
935   }
936 }
937 
AsyncWriter(Env * env,int64 file_index,const std::string & shard_directory,uint64 checkpoint_id,const std::string & compression,int64 version,const DataTypeVector & output_types,std::function<void (Status)> done)938 AsyncWriter::AsyncWriter(Env* env, int64 file_index,
939                          const std::string& shard_directory,
940                          uint64 checkpoint_id, const std::string& compression,
941                          int64 version, const DataTypeVector& output_types,
942                          std::function<void(Status)> done) {
943   thread_ = absl::WrapUnique(env->StartThread(
944       ThreadOptions(), absl::StrCat("writer_thread_", file_index),
945       [this, env, shard_directory, checkpoint_id, compression, version,
946        &output_types, done = std::move(done)] {
947         done(WriterThread(env, shard_directory, checkpoint_id, compression,
948                           version, output_types));
949       }));
950 }
951 
Write(const std::vector<Tensor> & tensors)952 void AsyncWriter::Write(const std::vector<Tensor>& tensors) {
953   mutex_lock l(mu_);
954   ElementOrEOF element;
955   element.value = tensors;
956   deque_.push_back(std::move(element));
957 }
958 
SignalEOF()959 void AsyncWriter::SignalEOF() {
960   mutex_lock l(mu_);
961   ElementOrEOF be;
962   be.end_of_sequence = true;
963   deque_.push_back(std::move(be));
964 }
965 
Consume(ElementOrEOF * be)966 void AsyncWriter::Consume(ElementOrEOF* be) {
967   mutex_lock l(mu_);
968   mu_.Await(tensorflow::Condition(this, &AsyncWriter::ElementAvailable));
969   *be = deque_.front();
970   deque_.pop_front();
971 }
972 
ElementAvailable()973 bool AsyncWriter::ElementAvailable() { return !deque_.empty(); }
974 
WriterThread(Env * env,const std::string & shard_directory,uint64 checkpoint_id,const std::string & compression,int64 version,DataTypeVector output_types)975 Status AsyncWriter::WriterThread(Env* env, const std::string& shard_directory,
976                                  uint64 checkpoint_id,
977                                  const std::string& compression, int64 version,
978                                  DataTypeVector output_types) {
979   std::unique_ptr<snapshot_util::Writer> writer;
980   TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(shard_directory));
981 
982   TF_RETURN_IF_ERROR(snapshot_util::Writer::Create(
983       env, GetCheckpointFileName(shard_directory, checkpoint_id), compression,
984       version, std::move(output_types), &writer));
985 
986   while (true) {
987     ElementOrEOF be;
988     Consume(&be);
989 
990     if (be.end_of_sequence) {
991       TF_RETURN_IF_ERROR(writer->Close());
992       break;
993     }
994 
995     TF_RETURN_IF_ERROR(writer->WriteTensors(be.value));
996   }
997   return Status::OK();
998 }
999 
1000 }  // namespace snapshot_util
1001 }  // namespace data
1002 }  // namespace tensorflow
1003