• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SNAPSHOT_UTIL_H_
17 #define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SNAPSHOT_UTIL_H_
18 
19 #include "tensorflow/core/framework/dataset.h"
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/framework/types.h"
22 #include "tensorflow/core/lib/io/compression.h"
23 #include "tensorflow/core/lib/io/inputstream_interface.h"
24 #include "tensorflow/core/lib/io/record_reader.h"
25 #include "tensorflow/core/lib/io/record_writer.h"
26 #include "tensorflow/core/platform/env.h"
27 #include "tensorflow/core/platform/file_system.h"
28 #include "tensorflow/core/platform/path.h"
29 #include "tensorflow/core/platform/status.h"
30 
31 namespace tensorflow {
32 
33 class GraphDef;
34 
35 namespace data {
36 
namespace experimental {
// Forward declarations only: both types are used in this header solely through
// pointers, so their complete definitions are not required here.
class SnapshotTensorMetadata;
class SnapshotMetadataRecord;
}  // namespace experimental
43 
44 namespace snapshot_util {
45 
// File name under which snapshot metadata is stored.
constexpr char kMetadataFilename[] = "snapshot.metadata";

// String spellings of the snapshot modes. NOTE(review): presumably these are
// the values accepted for the `mode_string` argument of DetermineOpState.
constexpr char kModeAuto[] = "auto";
constexpr char kModeWrite[] = "write";
constexpr char kModeRead[] = "read";
constexpr char kModePassthrough[] = "passthrough";

// Suffix used for shard directory names.
constexpr char kShardDirectorySuffix[] = ".shard";

// Operating mode of a snapshot op. The numeric values are explicit and must
// not change.
enum Mode {
  READER = 0,
  WRITER = 1,
  PASSTHROUGH = 2,
};
55 
56 // Returns the name of the "hash" directory for the given base path and hash ID.
57 std::string HashDirectory(const std::string& path, uint64 hash);
58 
59 // Returns the name of the "run" directory for the given base path and run ID.
60 std::string RunDirectory(const std::string& hash_directory, uint64 run_id);
61 std::string RunDirectory(const std::string& hash_directory,
62                          const std::string& run_id);
63 
64 // Returns the name of the "shard" directory for the given base path and shard
65 // ID.
66 std::string ShardDirectory(const std::string& run_directory, int64 shard_id);
67 
68 // Returns the checkpoint file name for the given directory and checkpoint ID.
69 std::string GetCheckpointFileName(const std::string& shard_directory,
70                                   const uint64 checkpoint_id);
71 
72 // This is a interface class that exposes snapshot writing functionality.
73 class Writer {
74  public:
75   // Creates a new writer object.
76   static Status Create(Env* env, const std::string& filename,
77                        const std::string& compression_type, int version,
78                        const DataTypeVector& dtypes,
79                        std::unique_ptr<Writer>* out_writer);
80 
81   // Writes a vector of tensors to the snapshot writer file.
82   virtual Status WriteTensors(const std::vector<Tensor>& tensors) = 0;
83 
84   // Flushes any in-memory buffers to disk.
85   virtual Status Sync() = 0;
86 
87   // Closes and finalizes the snapshot file. All calls to any other method will
88   // be invalid after this call.
89   virtual Status Close() = 0;
90 
~Writer()91   virtual ~Writer() {}
92 
93  protected:
94   virtual Status Initialize(tensorflow::Env* env) = 0;
95 };
96 
97 // Writes snapshots with the standard TFRecord file format.
98 class TFRecordWriter : public Writer {
99  public:
100   TFRecordWriter(const std::string& filename,
101                  const std::string& compression_type);
102 
103   Status WriteTensors(const std::vector<Tensor>& tensors) override;
104 
105   Status Sync() override;
106 
107   Status Close() override;
108 
109   ~TFRecordWriter() override;
110 
111  protected:
112   Status Initialize(tensorflow::Env* env) override;
113 
114  private:
115   const std::string filename_;
116   const std::string compression_type_;
117 
118   std::unique_ptr<WritableFile> dest_;
119   std::unique_ptr<io::RecordWriter> record_writer_;
120 };
121 
122 // Writes snapshot with a custom (legacy) file format.
123 class CustomWriter : public Writer {
124  public:
125   static constexpr const size_t kHeaderSize = sizeof(uint64);
126 
127   static constexpr const char* const kClassName = "SnapshotWriter";
128   static constexpr const char* const kWriteStringPiece = "WriteStringPiece";
129   static constexpr const char* const kWriteCord = "WriteCord";
130   static constexpr const char* const kSeparator = "::";
131 
132   CustomWriter(const std::string& filename, const std::string& compression_type,
133                const DataTypeVector& dtypes);
134 
135   Status WriteTensors(const std::vector<Tensor>& tensors) override;
136 
137   Status Sync() override;
138 
139   Status Close() override;
140 
141   ~CustomWriter() override;
142 
143  protected:
144   Status Initialize(tensorflow::Env* env) override;
145 
146  private:
147   Status WriteRecord(const StringPiece& data);
148 
149 #if defined(TF_CORD_SUPPORT)
150   Status WriteRecord(const absl::Cord& data);
151 #endif  // TF_CORD_SUPPORT
152 
153   std::unique_ptr<WritableFile> dest_;
154   const std::string filename_;
155   const std::string compression_type_;
156   const DataTypeVector dtypes_;
157   // We hold zlib_dest_ because we may create a ZlibOutputBuffer and put that
158   // in dest_ if we want compression. ZlibOutputBuffer doesn't own the original
159   // dest_ and so we need somewhere to store the original one.
160   std::unique_ptr<WritableFile> zlib_underlying_dest_;
161   std::vector<bool> simple_tensor_mask_;  // true for simple, false for complex.
162   int num_simple_ = 0;
163   int num_complex_ = 0;
164 };
165 
166 // Interface class for reading snapshot files previous written with Writer.
167 class Reader {
168  public:
169   // Creates a new Reader object that reads data from `filename`. Note that
170   // the `version`, `compression_type`, and `dtypes` arguments passed into
171   // `Writer` and `Reader` must be the same for the reading to succeed.
172   static Status Create(Env* env, const std::string& filename,
173                        const string& compression_type, int version,
174                        const DataTypeVector& dtypes,
175                        std::unique_ptr<Reader>* out_reader);
176 
177   // Returns a nested dataset for a set of given snapshot file names.
178   //
179   // This function takes a vector of snapshot files, and returns a nested
180   // dataset. Each element within the nested dataset is itself a dataset, and
181   // contains all the elements written out to each individual snapshot file.
182   static Status MakeNestedDataset(Env* env,
183                                   const std::vector<std::string>& shard_dirs,
184                                   const string& compression_type, int version,
185                                   const DataTypeVector& dtypes,
186                                   const std::vector<PartialTensorShape>& shapes,
187                                   const int64 start_index,
188                                   DatasetBase** output);
189 
190   // Reads a vector of Tensors from the snapshot file.
191   virtual Status ReadTensors(std::vector<Tensor>* read_tensors) = 0;
192 
193   // Skips `num_records`. Equivalent to calling `ReadTensors` `num_records`
194   // times then discarding the results.
195   virtual Status SkipRecords(int64 num_records);
196 
~Reader()197   virtual ~Reader() {}
198 
199  protected:
200   virtual Status Initialize(Env* env) = 0;
201 
202   class Dataset;
203   class NestedDataset;
204 };
205 
206 // Reads snapshots previously written with `TFRecordWriter`.
207 class TFRecordReader : public Reader {
208  public:
209   TFRecordReader(const std::string& filename, const string& compression_type,
210                  const DataTypeVector& dtypes);
211 
212   Status ReadTensors(std::vector<Tensor>* read_tensors) override;
213 
~TFRecordReader()214   ~TFRecordReader() override {}
215 
216  protected:
217   Status Initialize(Env* env) override;
218 
219  private:
220   std::string filename_;
221   std::unique_ptr<RandomAccessFile> file_;
222   std::unique_ptr<io::RecordReader> record_reader_;
223   uint64 offset_;
224 
225   const string compression_type_;
226   const DataTypeVector dtypes_;
227 };
228 
229 // Reads snapshots previously written with `CustomWriter`.
230 class CustomReader : public Reader {
231  public:
232   // The reader input buffer size is deliberately large because the input reader
233   // will throw an error if the compressed block length cannot fit in the input
234   // buffer.
235   static constexpr const int64 kSnappyReaderInputBufferSizeBytes =
236       1 << 30;  // 1 GiB
237   // TODO(b/148804377): Set this in a smarter fashion.
238   static constexpr const int64 kSnappyReaderOutputBufferSizeBytes =
239       32 << 20;  // 32 MiB
240   static constexpr const size_t kHeaderSize = sizeof(uint64);
241 
242   static constexpr const char* const kClassName = "SnapshotReader";
243   static constexpr const char* const kReadString = "ReadString";
244   static constexpr const char* const kReadCord = "ReadCord";
245   static constexpr const char* const kSeparator = "::";
246 
247   CustomReader(const std::string& filename, const string& compression_type,
248                const int version, const DataTypeVector& dtypes);
249 
250   Status ReadTensors(std::vector<Tensor>* read_tensors) override;
251 
~CustomReader()252   ~CustomReader() override {}
253 
254  protected:
255   Status Initialize(Env* env) override;
256 
257  private:
258   Status ReadTensorsV0(std::vector<Tensor>* read_tensors);
259 
260   Status SnappyUncompress(
261       const experimental::SnapshotTensorMetadata* metadata,
262       std::vector<Tensor>* simple_tensors,
263       std::vector<std::pair<std::unique_ptr<char[]>, size_t>>*
264           tensor_proto_strs);
265 
266   Status ReadRecord(tstring* record);
267 
268 #if defined(TF_CORD_SUPPORT)
269   Status ReadRecord(absl::Cord* record);
270 #endif
271 
272   std::string filename_;
273   std::unique_ptr<RandomAccessFile> file_;
274   std::unique_ptr<io::InputStreamInterface> input_stream_;
275   const string compression_type_;
276   const int version_;
277   const DataTypeVector dtypes_;
278   int num_simple_ = 0;
279   int num_complex_ = 0;
280   std::vector<bool> simple_tensor_mask_;  // true for simple, false for complex.
281 };
282 
283 // Writes snapshot metadata to the given directory.
284 Status WriteMetadataFile(Env* env, const string& dir,
285                          const experimental::SnapshotMetadataRecord* metadata);
286 
287 // Reads snapshot metadata from the given directory.
288 Status ReadMetadataFile(Env* env, const string& dir,
289                         experimental::SnapshotMetadataRecord* metadata,
290                         bool* file_exists);
291 
292 // Writes a dataset graph to the given directory.
293 Status DumpDatasetGraph(Env* env, const std::string& path, uint64 hash,
294                         const GraphDef* graph);
295 
296 Status DetermineOpState(const std::string& mode_string, bool file_exists,
297                         const experimental::SnapshotMetadataRecord* metadata,
298                         const uint64 pending_snapshot_expiry_seconds,
299                         Mode* mode);
300 
301 // Represents a dataset element or EOF.
302 struct ElementOrEOF {
303   std::vector<Tensor> value;
304   bool end_of_sequence = false;
305 };
306 
307 // AsyncWriter provides API for asynchronously writing dataset elements
308 // (each represented as a vector of tensors) to a file.
309 //
310 // The expected use of this API is:
311 //
312 // std::unique_ptr<AsyncWriter> writer = absl_make_unique<AsyncWriter>(...);
313 //
314 // while (data_available()) {
315 //   std::vector<Tensor> data = read_data()
316 //   writer->Write(data);
317 // }
318 // writer->SignalEOF();
319 // writer = nullptr;  // This will block until writes are flushed.
320 class AsyncWriter {
321  public:
322   explicit AsyncWriter(Env* env, int64 file_index,
323                        const std::string& shard_directory, uint64 checkpoint_id,
324                        const std::string& compression, int64 version,
325                        const DataTypeVector& output_types,
326                        std::function<void(Status)> done);
327 
328   // Writes the given tensors. The method is non-blocking and returns without
329   // waiting for the element to be written.
330   void Write(const std::vector<Tensor>& tensors) TF_LOCKS_EXCLUDED(mu_);
331 
332   // Signals the end of input. The method is non-blocking and returns without
333   // waiting for the writer to be closed.
334   void SignalEOF() TF_LOCKS_EXCLUDED(mu_);
335 
336  private:
337   void Consume(ElementOrEOF* be) TF_LOCKS_EXCLUDED(mu_);
338   bool ElementAvailable() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
339   Status WriterThread(Env* env, const std::string& shard_directory,
340                       uint64 checkpoint_id, const std::string& compression,
341                       int64 version, DataTypeVector output_types);
342 
343   mutex mu_;
344   std::deque<ElementOrEOF> deque_ TF_GUARDED_BY(mu_);
345 
346   // This has to be last. During destruction, we need to make sure that the
347   // Thread object is destroyed first as its destructor blocks on thread
348   // completion. If there are other member variables after this, they may get
349   // destroyed first before the thread finishes, potentially causing the
350   // thread to access invalid memory.
351   std::unique_ptr<Thread> thread_;
352 };
353 
354 }  // namespace snapshot_util
355 }  // namespace data
356 }  // namespace tensorflow
357 
358 #endif  // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SNAPSHOT_UTIL_H_
359