1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SNAPSHOT_UTIL_H_ 17 #define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SNAPSHOT_UTIL_H_ 18 19 #include "tensorflow/core/framework/dataset.h" 20 #include "tensorflow/core/framework/tensor.h" 21 #include "tensorflow/core/framework/types.h" 22 #include "tensorflow/core/lib/io/compression.h" 23 #include "tensorflow/core/lib/io/inputstream_interface.h" 24 #include "tensorflow/core/lib/io/record_reader.h" 25 #include "tensorflow/core/lib/io/record_writer.h" 26 #include "tensorflow/core/platform/env.h" 27 #include "tensorflow/core/platform/file_system.h" 28 #include "tensorflow/core/platform/path.h" 29 #include "tensorflow/core/platform/status.h" 30 31 namespace tensorflow { 32 33 class GraphDef; 34 35 namespace data { 36 37 namespace experimental { 38 39 class SnapshotMetadataRecord; 40 class SnapshotTensorMetadata; 41 42 } // namespace experimental 43 44 namespace snapshot_util { 45 46 constexpr char kMetadataFilename[] = "snapshot.metadata"; 47 48 constexpr char kModeAuto[] = "auto"; 49 constexpr char kModeWrite[] = "write"; 50 constexpr char kModeRead[] = "read"; 51 constexpr char kModePassthrough[] = "passthrough"; 52 constexpr char kShardDirectorySuffix[] = ".shard"; 53 54 enum Mode { READER = 0, WRITER = 1, PASSTHROUGH = 2 }; 55 56 // Returns the name of the "hash" directory for the given base path and hash ID. 57 std::string HashDirectory(const std::string& path, uint64 hash); 58 59 // Returns the name of the "run" directory for the given base path and run ID. 60 std::string RunDirectory(const std::string& hash_directory, uint64 run_id); 61 std::string RunDirectory(const std::string& hash_directory, 62 const std::string& run_id); 63 64 // Returns the name of the "shard" directory for the given base path and shard 65 // ID. 66 std::string ShardDirectory(const std::string& run_directory, int64 shard_id); 67 68 // Returns the checkpoint file name for the given directory and checkpoint ID. 69 std::string GetCheckpointFileName(const std::string& shard_directory, 70 const uint64 checkpoint_id); 71 72 // This is a interface class that exposes snapshot writing functionality. 73 class Writer { 74 public: 75 // Creates a new writer object. 76 static Status Create(Env* env, const std::string& filename, 77 const std::string& compression_type, int version, 78 const DataTypeVector& dtypes, 79 std::unique_ptr<Writer>* out_writer); 80 81 // Writes a vector of tensors to the snapshot writer file. 82 virtual Status WriteTensors(const std::vector<Tensor>& tensors) = 0; 83 84 // Flushes any in-memory buffers to disk. 85 virtual Status Sync() = 0; 86 87 // Closes and finalizes the snapshot file. All calls to any other method will 88 // be invalid after this call. 89 virtual Status Close() = 0; 90 ~Writer()91 virtual ~Writer() {} 92 93 protected: 94 virtual Status Initialize(tensorflow::Env* env) = 0; 95 }; 96 97 // Writes snapshots with the standard TFRecord file format. 98 class TFRecordWriter : public Writer { 99 public: 100 TFRecordWriter(const std::string& filename, 101 const std::string& compression_type); 102 103 Status WriteTensors(const std::vector<Tensor>& tensors) override; 104 105 Status Sync() override; 106 107 Status Close() override; 108 109 ~TFRecordWriter() override; 110 111 protected: 112 Status Initialize(tensorflow::Env* env) override; 113 114 private: 115 const std::string filename_; 116 const std::string compression_type_; 117 118 std::unique_ptr<WritableFile> dest_; 119 std::unique_ptr<io::RecordWriter> record_writer_; 120 }; 121 122 // Writes snapshot with a custom (legacy) file format. 123 class CustomWriter : public Writer { 124 public: 125 static constexpr const size_t kHeaderSize = sizeof(uint64); 126 127 static constexpr const char* const kClassName = "SnapshotWriter"; 128 static constexpr const char* const kWriteStringPiece = "WriteStringPiece"; 129 static constexpr const char* const kWriteCord = "WriteCord"; 130 static constexpr const char* const kSeparator = "::"; 131 132 CustomWriter(const std::string& filename, const std::string& compression_type, 133 const DataTypeVector& dtypes); 134 135 Status WriteTensors(const std::vector<Tensor>& tensors) override; 136 137 Status Sync() override; 138 139 Status Close() override; 140 141 ~CustomWriter() override; 142 143 protected: 144 Status Initialize(tensorflow::Env* env) override; 145 146 private: 147 Status WriteRecord(const StringPiece& data); 148 149 #if defined(TF_CORD_SUPPORT) 150 Status WriteRecord(const absl::Cord& data); 151 #endif // TF_CORD_SUPPORT 152 153 std::unique_ptr<WritableFile> dest_; 154 const std::string filename_; 155 const std::string compression_type_; 156 const DataTypeVector dtypes_; 157 // We hold zlib_dest_ because we may create a ZlibOutputBuffer and put that 158 // in dest_ if we want compression. ZlibOutputBuffer doesn't own the original 159 // dest_ and so we need somewhere to store the original one. 160 std::unique_ptr<WritableFile> zlib_underlying_dest_; 161 std::vector<bool> simple_tensor_mask_; // true for simple, false for complex. 162 int num_simple_ = 0; 163 int num_complex_ = 0; 164 }; 165 166 // Interface class for reading snapshot files previous written with Writer. 167 class Reader { 168 public: 169 // Creates a new Reader object that reads data from `filename`. Note that 170 // the `version`, `compression_type`, and `dtypes` arguments passed into 171 // `Writer` and `Reader` must be the same for the reading to succeed. 172 static Status Create(Env* env, const std::string& filename, 173 const string& compression_type, int version, 174 const DataTypeVector& dtypes, 175 std::unique_ptr<Reader>* out_reader); 176 177 // Returns a nested dataset for a set of given snapshot file names. 178 // 179 // This function takes a vector of snapshot files, and returns a nested 180 // dataset. Each element within the nested dataset is itself a dataset, and 181 // contains all the elements written out to each individual snapshot file. 182 static Status MakeNestedDataset(Env* env, 183 const std::vector<std::string>& shard_dirs, 184 const string& compression_type, int version, 185 const DataTypeVector& dtypes, 186 const std::vector<PartialTensorShape>& shapes, 187 const int64 start_index, 188 DatasetBase** output); 189 190 // Reads a vector of Tensors from the snapshot file. 191 virtual Status ReadTensors(std::vector<Tensor>* read_tensors) = 0; 192 193 // Skips `num_records`. Equivalent to calling `ReadTensors` `num_records` 194 // times then discarding the results. 195 virtual Status SkipRecords(int64 num_records); 196 ~Reader()197 virtual ~Reader() {} 198 199 protected: 200 virtual Status Initialize(Env* env) = 0; 201 202 class Dataset; 203 class NestedDataset; 204 }; 205 206 // Reads snapshots previously written with `TFRecordWriter`. 207 class TFRecordReader : public Reader { 208 public: 209 TFRecordReader(const std::string& filename, const string& compression_type, 210 const DataTypeVector& dtypes); 211 212 Status ReadTensors(std::vector<Tensor>* read_tensors) override; 213 ~TFRecordReader()214 ~TFRecordReader() override {} 215 216 protected: 217 Status Initialize(Env* env) override; 218 219 private: 220 std::string filename_; 221 std::unique_ptr<RandomAccessFile> file_; 222 std::unique_ptr<io::RecordReader> record_reader_; 223 uint64 offset_; 224 225 const string compression_type_; 226 const DataTypeVector dtypes_; 227 }; 228 229 // Reads snapshots previously written with `CustomWriter`. 230 class CustomReader : public Reader { 231 public: 232 // The reader input buffer size is deliberately large because the input reader 233 // will throw an error if the compressed block length cannot fit in the input 234 // buffer. 235 static constexpr const int64 kSnappyReaderInputBufferSizeBytes = 236 1 << 30; // 1 GiB 237 // TODO(b/148804377): Set this in a smarter fashion. 238 static constexpr const int64 kSnappyReaderOutputBufferSizeBytes = 239 32 << 20; // 32 MiB 240 static constexpr const size_t kHeaderSize = sizeof(uint64); 241 242 static constexpr const char* const kClassName = "SnapshotReader"; 243 static constexpr const char* const kReadString = "ReadString"; 244 static constexpr const char* const kReadCord = "ReadCord"; 245 static constexpr const char* const kSeparator = "::"; 246 247 CustomReader(const std::string& filename, const string& compression_type, 248 const int version, const DataTypeVector& dtypes); 249 250 Status ReadTensors(std::vector<Tensor>* read_tensors) override; 251 ~CustomReader()252 ~CustomReader() override {} 253 254 protected: 255 Status Initialize(Env* env) override; 256 257 private: 258 Status ReadTensorsV0(std::vector<Tensor>* read_tensors); 259 260 Status SnappyUncompress( 261 const experimental::SnapshotTensorMetadata* metadata, 262 std::vector<Tensor>* simple_tensors, 263 std::vector<std::pair<std::unique_ptr<char[]>, size_t>>* 264 tensor_proto_strs); 265 266 Status ReadRecord(tstring* record); 267 268 #if defined(TF_CORD_SUPPORT) 269 Status ReadRecord(absl::Cord* record); 270 #endif 271 272 std::string filename_; 273 std::unique_ptr<RandomAccessFile> file_; 274 std::unique_ptr<io::InputStreamInterface> input_stream_; 275 const string compression_type_; 276 const int version_; 277 const DataTypeVector dtypes_; 278 int num_simple_ = 0; 279 int num_complex_ = 0; 280 std::vector<bool> simple_tensor_mask_; // true for simple, false for complex. 281 }; 282 283 // Writes snapshot metadata to the given directory. 284 Status WriteMetadataFile(Env* env, const string& dir, 285 const experimental::SnapshotMetadataRecord* metadata); 286 287 // Reads snapshot metadata from the given directory. 288 Status ReadMetadataFile(Env* env, const string& dir, 289 experimental::SnapshotMetadataRecord* metadata, 290 bool* file_exists); 291 292 // Writes a dataset graph to the given directory. 293 Status DumpDatasetGraph(Env* env, const std::string& path, uint64 hash, 294 const GraphDef* graph); 295 296 Status DetermineOpState(const std::string& mode_string, bool file_exists, 297 const experimental::SnapshotMetadataRecord* metadata, 298 const uint64 pending_snapshot_expiry_seconds, 299 Mode* mode); 300 301 // Represents a dataset element or EOF. 302 struct ElementOrEOF { 303 std::vector<Tensor> value; 304 bool end_of_sequence = false; 305 }; 306 307 // AsyncWriter provides API for asynchronously writing dataset elements 308 // (each represented as a vector of tensors) to a file. 309 // 310 // The expected use of this API is: 311 // 312 // std::unique_ptr<AsyncWriter> writer = absl_make_unique<AsyncWriter>(...); 313 // 314 // while (data_available()) { 315 // std::vector<Tensor> data = read_data() 316 // writer->Write(data); 317 // } 318 // writer->SignalEOF(); 319 // writer = nullptr; // This will block until writes are flushed. 320 class AsyncWriter { 321 public: 322 explicit AsyncWriter(Env* env, int64 file_index, 323 const std::string& shard_directory, uint64 checkpoint_id, 324 const std::string& compression, int64 version, 325 const DataTypeVector& output_types, 326 std::function<void(Status)> done); 327 328 // Writes the given tensors. The method is non-blocking and returns without 329 // waiting for the element to be written. 330 void Write(const std::vector<Tensor>& tensors) TF_LOCKS_EXCLUDED(mu_); 331 332 // Signals the end of input. The method is non-blocking and returns without 333 // waiting for the writer to be closed. 334 void SignalEOF() TF_LOCKS_EXCLUDED(mu_); 335 336 private: 337 void Consume(ElementOrEOF* be) TF_LOCKS_EXCLUDED(mu_); 338 bool ElementAvailable() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 339 Status WriterThread(Env* env, const std::string& shard_directory, 340 uint64 checkpoint_id, const std::string& compression, 341 int64 version, DataTypeVector output_types); 342 343 mutex mu_; 344 std::deque<ElementOrEOF> deque_ TF_GUARDED_BY(mu_); 345 346 // This has to be last. During destruction, we need to make sure that the 347 // Thread object is destroyed first as its destructor blocks on thread 348 // completion. If there are other member variables after this, they may get 349 // destroyed first before the thread finishes, potentially causing the 350 // thread to access invalid memory. 351 std::unique_ptr<Thread> thread_; 352 }; 353 354 } // namespace snapshot_util 355 } // namespace data 356 } // namespace tensorflow 357 358 #endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SNAPSHOT_UTIL_H_ 359