/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// A tensor bundle is a set of immutable persistent files storing a set of named
// tensors.  It is designed for checkpointing TensorFlow tensors.
//
// The paths of the managed files share a common prefix; e.g., with the prefix:
//   /fs/model/train/ckpt-step/ckpt
//
// the bundle may contain a metadata file, and sharded data files:
//   /fs/model/train/ckpt-step/
//       ckpt.index
//       ckpt.data-00000-of-00020
//       ckpt.data-00001-of-00020
//       ...
//       ckpt.data-00019-of-00020
//
// The ".index" file is a string-string immutable table
// (tensorflow::table::Table).  Each key is a name of a tensor and its value is
// a serialized BundleEntryProto.  Each BundleEntryProto describes the metadata
// of a tensor: which of the "data" files contains the content of a tensor, the
// offset into that file, checksum, some auxiliary data, etc.
//
// A tensor bundle can be accessed randomly using a BundleReader.  Usage:
//
//   BundleReader reader(env, "/fs/model/train/ckpt-step/ckpt");
//   reader.Lookup("name", &tensor);
//
// A tensor bundle can be built using BundleWriter.  Each BundleWriter builds a
// single data file bundle.
// Multiple bundles can then be merged by MergeBundles() without reading and
// writing large chunk of data: it reads the metadata files and outputs a
// single merged metadata.  Typical usage:
//
//   worker 0:
//     BundleWriter writer(env, "/fs/model/train/ckpt-step/tmp/worker0-step");
//     writer.Add(...);  // Adds the tensors on this worker.
//     writer.Finish();  // Flushes.
//   worker 1:
//     BundleWriter writer(env, "/fs/model/train/ckpt-step/tmp/worker1-step");
//     writer.Add(...);
//     writer.Finish();
//   worker 2:
//     MergeBundles(env,
//       {"/fs/model/train/ckpt-step/tmp/worker0-step",
//        "/fs/model/train/ckpt-step/tmp/worker1-step"},
//       "/fs/model/train/ckpt-step/ckpt" /* merged prefix */);
//

#ifndef TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_
#define TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_

#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/io/cache.h"
#include "tensorflow/core/lib/io/inputbuffer.h"
#include "tensorflow/core/lib/io/table.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/tensor_bundle.pb.h"
#include "tensorflow/core/util/tensor_bundle/naming.h"
#include "tensorflow/core/util/tensor_slice_set.h"

namespace tensorflow {

class FileOutputBuffer;

// Versioning of the tensor bundle format.
// Follows the same rules as 3p/tf/core/public/version.h.
//
// History:
// 0. Any tensor bundles produced before this field was added.
// 1.
Added this field (2016-09-14). 94 extern const int kTensorBundleMinProducer; 95 extern const int kTensorBundleMinConsumer; 96 extern const int kTensorBundleVersion; 97 98 // The empty string, hence always the first key in the metadata table. Its 99 // corresponding value is a BundleHeaderProto. 100 extern const char* const kHeaderEntryKey; 101 102 // Builds a string-string table of tensor names to BundleEntryProto (metadata). 103 // 104 // On construction, attempts to create a directory given by the dirname of 105 // "prefix", so "status()" must be checked before calling any member functions. 106 // 107 // All threads accessing the same BundleWriter must synchronize. 108 class BundleWriter { 109 public: 110 struct Options { OptionsOptions111 Options() {} 112 // Alignment, in bytes, for tensor data. 113 // Must be >= 1. The default size of 1 densely packs tensors. 114 int data_alignment{1}; 115 }; 116 BundleWriter(Env* env, StringPiece prefix, 117 const Options& options = Options()); 118 119 // Adds the tensor "val" under key "key". 120 // Across calls "key" must be unique but can be added in any order. 121 Status Add(StringPiece key, const Tensor& val); 122 123 // Partitioned variables support. 124 // A slice of a full tensor is stored in two entries in the metadata table: 125 // 126 // full_tensor_key -> BundleEntryProto, describing all stored slices 127 // of this full tensor. Does not append to the data 128 // file. 129 // encoded slice key -> BundleEntryProto, describing one particular slice. 130 // Appends values of this slice to the data file. 131 // 132 // Slices of a full tensor can be added in any order. 133 // 134 // If a full tensor has slices placed on N devices and N BundleWriter's are 135 // concurrently used, the caller must use MergeBundles() to ensure that a 136 // consistent entry for "full_tensor_key" is produced. 137 // 138 // Returns an error if the same slice is added the second time. 
139 Status AddSlice(StringPiece full_tensor_key, 140 const TensorShape& full_tensor_shape, 141 const TensorSlice& slice_spec, const Tensor& slice_tensor); 142 143 // Finishes the writer and flushes. 144 Status Finish() TF_MUST_USE_RESULT; 145 status()146 Status status() const { return status_; } 147 148 private: 149 Env* const env_; // Not owned. 150 const Options options_; 151 const string prefix_; 152 string metadata_path_; 153 string data_path_; 154 bool use_temp_file_; 155 std::unique_ptr<FileOutputBuffer> out_; 156 int64 size_; // Number of bytes written into out_. 157 std::map<string, BundleEntryProto> entries_; 158 Status status_; 159 160 TF_DISALLOW_COPY_AND_ASSIGN(BundleWriter); 161 }; 162 163 // Merges a set of bundles (given their prefixes) into a single bundle with the 164 // given "merged_prefix". The merged metadata is guaranteed to be consistent. 165 // 166 // If there are N bundles in "prefixes", during the merge the data files will be 167 // renamed to contain a proper sharded file spec, with num_shards set to the sum 168 // of num_shards across the N input bundles. 169 // 170 // The caller should only rely on the metadata file of the merged bundle to 171 // query information about a tensor. In particular, this function does not 172 // guarantee not to re-order the input data files. 173 // 174 // Once merged, makes a best effort to delete the old metadata files. 175 // Returns OK iff all bundles are successfully merged. 176 Status MergeBundles(Env* env, gtl::ArraySlice<tstring> prefixes, 177 StringPiece merged_prefix); 178 179 // On construction, silently attempts to read the metadata associated with 180 // "prefix". If caller intends to call any function afterwards, "status()" 181 // must be checked. 182 // All threads accessing the same BundleReader must synchronize. 
183 class BundleReader { 184 public: 185 BundleReader(Env* const env, StringPiece prefix); 186 ~BundleReader(); 187 188 // Is ok() iff the reader construction is successful (completed the read of 189 // the metadata). status()190 Status status() const { return status_; } 191 192 // Queries whether the bundle contains an entry keyed by "key". Calls Seek() 193 // internally, so this call invalidates the reader's current position. 194 // REQUIRES: status().ok() 195 bool Contains(StringPiece key); 196 197 // Looks up the dtype and the shape of the tensor keyed by "key". 198 // REQUIRES: status().ok() 199 Status LookupDtypeAndShape(StringPiece key, DataType* dtype, 200 TensorShape* shape) TF_MUST_USE_RESULT; 201 202 // Looks up the shape of the tensor keyed by "key". 203 // Clears "shape" if not found. 204 // REQUIRES: status().ok() 205 Status LookupTensorShape(StringPiece key, 206 TensorShape* shape) TF_MUST_USE_RESULT; 207 208 // Looks up the tensor keyed by "key". If "key" refers to a partitioned 209 // tensor, attempts to look up the full contents using all stored slices. 210 // 211 // Caller must make sure "val" has the same shape and dtype as the 212 // corresponding contents, so that its buffer can be filled without needing 213 // extra allocation. These can be queried via "LookupDtypeAndShape()". 214 // 215 // On error, "val" may contain nonsense data. Returns a NotFound error if 216 // tensor keyed by "key" does not exist in this bundle. 217 // 218 // Validates the stored crc32c checksum against the restored bytes. 219 // REQUIRES: status().ok() 220 Status Lookup(StringPiece key, Tensor* val) TF_MUST_USE_RESULT; 221 222 // Looks up the tensor pointed to by the internal iterator. 223 // 224 // On error, "val" may contain nonsense data. 225 // 226 // Validates the stored crc32c checksum against the restored bytes. 
227 // REQUIRES: status().ok() && Valid() 228 Status ReadCurrent(Tensor* val) TF_MUST_USE_RESULT; 229 230 // Looks up the slices of the tensor keyed by "key". On OK, "slices" 231 // is non-empty if and only if the tensor is a partitioned tensor. 232 // 233 // Warning - there is no guaranteed ordering for the returned slices, so 234 // a slice with a larger start index in some dimension could come before 235 // another slice with a smaller start index in the same dimension. 236 // REQUIRES: status().ok() 237 Status LookupTensorSlices(StringPiece key, std::vector<TensorSlice>* slices) 238 TF_MUST_USE_RESULT; 239 240 // Looks up a specific slice of a partitioned tensor. 241 // It is only required that the stored slices cover the requested slice, 242 // namely "slice_spec" is a subset of the union of the stored slices. 243 // REQUIRES: status().ok() 244 Status LookupSlice(StringPiece full_tensor_key, const TensorSlice& slice_spec, 245 Tensor* val) TF_MUST_USE_RESULT; 246 247 // Seeks to the first position in the bundle whose key is no less than "key". 248 // REQUIRES: status().ok() Seek(StringPiece key)249 void Seek(StringPiece key) { return iter_->Seek(key); } 250 // Moves to the next position in the bundle. 251 // REQUIRES: status().ok() Next()252 void Next() const { iter_->Next(); } 253 // Returns true iff the reader is positioned to a key/val pair. 254 // REQUIRES: status().ok() Valid()255 bool Valid() const { return iter_->Valid(); } 256 257 // Returns the key at the current position. 258 // REQUIRES: status().ok() && Valid() key()259 StringPiece key() const { return iter_->key(); } 260 // Returns the raw value at the current position. 261 // REQUIRES: status().ok() && Valid() value()262 StringPiece value() const { return iter_->value(); } 263 264 string DebugString(); 265 266 private: 267 // Seeks for "key" and reads the metadata proto. 268 // On non-OK return, clears "entry" for the caller. 
269 // REQUIRES: status().ok() 270 Status GetBundleEntryProto(StringPiece key, 271 BundleEntryProto* entry) TF_MUST_USE_RESULT; 272 273 // Reads the tensor value described by the metadata proto "entry". 274 // Usage for "val" follows the comment of "Lookup()". 275 Status GetValue(const BundleEntryProto& entry, 276 Tensor* val) TF_MUST_USE_RESULT; 277 278 // Reads the slice described by "slice_spec". The corresponding full tensor 279 // has key "ful_tensor_key" and metadata proto "full_tensor_entry". 280 // REQUIRES: full_tensor_entry.slices_size() > 0 281 Status GetSliceValue(StringPiece full_tensor_key, 282 const BundleEntryProto& full_tensor_entry, 283 const TensorSlice& slice_spec, 284 Tensor* val) TF_MUST_USE_RESULT; 285 286 Env* env_; // Not owned. 287 const string prefix_; 288 289 Status status_; 290 RandomAccessFile* metadata_; // Owned. 291 table::Table* table_; 292 table::Cache* index_cache_; 293 table::Iterator* iter_; 294 // Owned the InputBuffer objects and their underlying RandomAccessFile's. 295 std::unordered_map<int32, io::InputBuffer*> data_; 296 297 // Maps each partitioned tensor's key to its stored slices (represented in a 298 // TensorSliceSet). Populated on-demand. 299 std::unordered_map<string, checkpoint::TensorSliceSet*> tensor_slices_; 300 301 // Expected number of data file shards in the bundle. Extracted by reading 302 // the header entry in the metadata table. 303 int num_shards_; 304 305 // Flag that this class sets to true when the endianness of the target bundle 306 // differs from that of the current system's processor architecture. 307 bool need_to_swap_bytes_; 308 309 friend class TensorBundleAlignmentTest; // For testing data alignment. 310 311 TF_DISALLOW_COPY_AND_ASSIGN(BundleReader); 312 }; 313 314 // A buffering wrapper for a WritableFile. Useful if the caller wishes to issue 315 // small writes to a file (e.g. writing out a list of small varints). 
316 // External synchronization must be used in the presence of concurrent callers. 317 class FileOutputBuffer { 318 public: FileOutputBuffer(WritableFile * file,size_t buffer_size)319 FileOutputBuffer(WritableFile* file, size_t buffer_size) 320 : file_(file), position_(0), buffer_size_(buffer_size) { 321 DCHECK_GT(buffer_size, 0); 322 buffer_.resize(buffer_size); 323 } 324 ~FileOutputBuffer(); 325 326 // Buffered append. 327 Status Append(StringPiece data); 328 329 // Returns the running crc32c checksum of all currently appended bytes. crc32c()330 uint32 crc32c() { return crc32c_; } 331 // Clears the running crc32c checksum. clear_crc32c()332 void clear_crc32c() { crc32c_ = 0; } 333 334 // Appends the buffered data, then closes the underlying file. 335 Status Close(); 336 337 private: 338 // Appends the buffered data to the underlying file. Does NOT flush the file. 339 Status FlushBuffer(); 340 341 WritableFile* file_; // Owned. 342 343 // buffer_[0, position_) holds the buffered data not yet appended to the 344 // underlying file. 345 size_t position_; 346 const size_t buffer_size_; 347 std::vector<char> buffer_; 348 349 // Checksum of all appended bytes since construction or last clear_crc32c(). 350 uint32 crc32c_ = 0; 351 }; 352 353 } // namespace tensorflow 354 355 #endif // TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_ 356