1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // A tensor bundle is a set of immutable persistent files storing a set of named 17 // tensors. It is designed for checkpointing TensorFlow tensors. 18 // 19 // The paths of the managed files share a common prefix; e.g., with the prefix: 20 // /fs/model/train/ckpt-step/ckpt 21 // 22 // the bundle may contain a metadata file, and sharded data files: 23 // /fs/model/train/ckpt-step/ 24 // ckpt.index 25 // ckpt.data-00000-of-00020 26 // ckpt.data-00001-of-00020 27 // ... 28 // ckpt.data-00019-of-00020 29 // 30 // The ".index" file is a string-string immutable table 31 // (tensorflow::table::Table). Each key is a name of a tensor and its value is 32 // a serialized BundleEntryProto. Each BundleEntryProto describes the metadata 33 // of a tensor: which of the "data" files contains the content of a tensor, the 34 // offset into that file, checksum, some auxiliary data, etc. 35 // 36 // A tensor bundle can be accessed randomly using a BundleReader. Usage: 37 // 38 // BundleReader reader(env, "/fs/model/train/ckpt-step/ckpt"); 39 // reader.Lookup("name", &tensor); 40 // 41 // A tensor bundle can be built using BundleWriter. Each BundleWriter builds a 42 // single data file bundle. Multiple bundles can then be merged by 43 // MergeBundles() without reading and writing large chunk of data: it reads the 44 // metadata files and outputs a single merged metadata. Typical usage: 45 // 46 // worker 0: 47 // BundleWriter writer(env, "/fs/model/train/ckpt-step/tmp/worker0-step"); 48 // writer.Add(...); // Adds the tensors on this worker. 49 // writer.Finish(); // Flushes. 50 // worker 1: 51 // BundleWriter writer(env, "/fs/model/train/ckpt-step/tmp/worker1-step"); 52 // writer.Add(...); 53 // writer.Finish(); 54 // worker 2: 55 // MergeBundles(env, 56 // {"/fs/model/train/ckpt-step/tmp/worker0-step", 57 // "/fs/model/train/ckpt-step/tmp/worker1-step"}, 58 // "/fs/model/train/ckpt-step/ckpt" /* merged prefix */); 59 // 60 61 #ifndef TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_ 62 #define TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_ 63 64 #include <map> 65 #include <string> 66 #include <unordered_map> 67 68 #include "tensorflow/core/framework/tensor.h" 69 #include "tensorflow/core/framework/tensor_shape.h" 70 #include "tensorflow/core/framework/tensor_slice.h" 71 #include "tensorflow/core/lib/core/status.h" 72 #include "tensorflow/core/lib/gtl/array_slice.h" 73 #include "tensorflow/core/lib/io/cache.h" 74 #include "tensorflow/core/lib/io/inputbuffer.h" 75 #include "tensorflow/core/lib/io/table.h" 76 #include "tensorflow/core/platform/cord.h" 77 #include "tensorflow/core/platform/env.h" 78 #include "tensorflow/core/platform/file_system.h" 79 #include "tensorflow/core/platform/macros.h" 80 #include "tensorflow/core/platform/types.h" 81 #include "tensorflow/core/protobuf/tensor_bundle.pb.h" 82 #include "tensorflow/core/util/tensor_bundle/naming.h" 83 #include "tensorflow/core/util/tensor_slice_set.h" 84 85 namespace tensorflow { 86 87 class FileOutputBuffer; 88 89 // Versioning of the tensor bundle format. 90 // Follows the same rules as 3p/tf/core/public/version.h. 91 // 92 // History: 93 // 0. Any tensor bundles produced before this field was added. 94 // 1. Added this field (2016-09-14). 95 extern const int kTensorBundleMinProducer; 96 extern const int kTensorBundleMinConsumer; 97 extern const int kTensorBundleVersion; 98 99 // The empty string, hence always the first key in the metadata table. Its 100 // corresponding value is a BundleHeaderProto. 101 extern const char* const kHeaderEntryKey; 102 103 // Builds a string-string table of tensor names to BundleEntryProto (metadata). 104 // 105 // On construction, attempts to create a directory given by the dirname of 106 // "prefix", so "status()" must be checked before calling any member functions. 107 // 108 // All threads accessing the same BundleWriter must synchronize. 109 class BundleWriter { 110 public: 111 struct Options { OptionsOptions112 Options() {} 113 // Alignment, in bytes, for tensor data. 114 // Must be >= 1. The default size of 1 densely packs tensors. 115 int data_alignment{1}; 116 }; 117 BundleWriter(Env* env, StringPiece prefix, 118 const Options& options = Options()); 119 120 // Adds the tensor "val" under key "key". 121 // Across calls "key" must be unique but can be added in any order. 122 Status Add(StringPiece key, const Tensor& val); 123 124 // Partitioned variables support. 125 // A slice of a full tensor is stored in two entries in the metadata table: 126 // 127 // full_tensor_key -> BundleEntryProto, describing all stored slices 128 // of this full tensor. Does not append to the data 129 // file. 130 // encoded slice key -> BundleEntryProto, describing one particular slice. 131 // Appends values of this slice to the data file. 132 // 133 // Slices of a full tensor can be added in any order. 134 // 135 // If a full tensor has slices placed on N devices and N BundleWriter's are 136 // concurrently used, the caller must use MergeBundles() to ensure that a 137 // consistent entry for "full_tensor_key" is produced. 138 // 139 // Returns an error if the same slice is added the second time. 140 Status AddSlice(StringPiece full_tensor_key, 141 const TensorShape& full_tensor_shape, 142 const TensorSlice& slice_spec, const Tensor& slice_tensor); 143 144 // Finishes the writer and flushes. 145 Status Finish() TF_MUST_USE_RESULT; 146 status()147 Status status() const { return status_; } 148 149 private: 150 Env* const env_; // Not owned. 151 const Options options_; 152 const string prefix_; 153 string metadata_path_; 154 string data_path_; 155 bool use_temp_file_; 156 std::unique_ptr<FileOutputBuffer> out_; 157 int64 size_; // Number of bytes written into out_. 158 std::map<string, BundleEntryProto> entries_; 159 Status status_; 160 161 TF_DISALLOW_COPY_AND_ASSIGN(BundleWriter); 162 }; 163 164 // Merges a set of bundles (given their prefixes) into a single bundle with the 165 // given "merged_prefix". The merged metadata is guaranteed to be consistent. 166 // 167 // If there are N bundles in "prefixes", during the merge the data files will be 168 // renamed to contain a proper sharded file spec, with num_shards set to the sum 169 // of num_shards across the N input bundles. 170 // 171 // The caller should only rely on the metadata file of the merged bundle to 172 // query information about a tensor. In particular, this function does not 173 // guarantee not to re-order the input data files. 174 // 175 // Once merged, makes a best effort to delete the old metadata files. 176 // Returns OK iff all bundles are successfully merged. 177 Status MergeBundles(Env* env, gtl::ArraySlice<tstring> prefixes, 178 StringPiece merged_prefix); 179 180 // On construction, silently attempts to read the metadata associated with 181 // "prefix". If caller intends to call any function afterwards, "status()" 182 // must be checked. 183 // All threads accessing the same BundleReader must synchronize. 184 class BundleReader { 185 public: 186 BundleReader(Env* const env, StringPiece prefix); 187 ~BundleReader(); 188 189 // Is ok() iff the reader construction is successful (completed the read of 190 // the metadata). status()191 Status status() const { return status_; } 192 193 // Queries whether the bundle contains an entry keyed by "key". Calls Seek() 194 // internally, so this call invalidates the reader's current position. 195 // REQUIRES: status().ok() 196 bool Contains(StringPiece key); 197 198 // Looks up the dtype and the shape of the tensor keyed by "key". 199 // REQUIRES: status().ok() 200 Status LookupDtypeAndShape(StringPiece key, DataType* dtype, 201 TensorShape* shape) TF_MUST_USE_RESULT; 202 203 // Looks up the shape of the tensor keyed by "key". 204 // Clears "shape" if not found. 205 // REQUIRES: status().ok() 206 Status LookupTensorShape(StringPiece key, 207 TensorShape* shape) TF_MUST_USE_RESULT; 208 209 // Looks up the tensor keyed by "key". If "key" refers to a partitioned 210 // tensor, attempts to look up the full contents using all stored slices. 211 // 212 // Caller must make sure "val" has the same shape and dtype as the 213 // corresponding contents, so that its buffer can be filled without needing 214 // extra allocation. These can be queried via "LookupDtypeAndShape()". 215 // 216 // On error, "val" may contain nonsense data. Returns a NotFound error if 217 // tensor keyed by "key" does not exist in this bundle. 218 // 219 // Validates the stored crc32c checksum against the restored bytes. 220 // REQUIRES: status().ok() 221 Status Lookup(StringPiece key, Tensor* val) TF_MUST_USE_RESULT; 222 223 // Looks up the tensor pointed to by the internal iterator. 224 // 225 // On error, "val" may contain nonsense data. 226 // 227 // Validates the stored crc32c checksum against the restored bytes. 228 // REQUIRES: status().ok() && Valid() 229 Status ReadCurrent(Tensor* val) TF_MUST_USE_RESULT; 230 231 // Looks up the slices of the tensor keyed by "key". On OK, "slices" 232 // is non-empty if and only if the tensor is a partitioned tensor. 233 // 234 // Warning - there is no guaranteed ordering for the returned slices, so 235 // a slice with a larger start index in some dimension could come before 236 // another slice with a smaller start index in the same dimension. 237 // REQUIRES: status().ok() 238 Status LookupTensorSlices(StringPiece key, std::vector<TensorSlice>* slices) 239 TF_MUST_USE_RESULT; 240 241 // Looks up a specific slice of a partitioned tensor. 242 // It is only required that the stored slices cover the requested slice, 243 // namely "slice_spec" is a subset of the union of the stored slices. 244 // REQUIRES: status().ok() 245 Status LookupSlice(StringPiece full_tensor_key, const TensorSlice& slice_spec, 246 Tensor* val) TF_MUST_USE_RESULT; 247 248 // Seeks to the first position in the bundle whose key is no less than "key". 249 // REQUIRES: status().ok() Seek(StringPiece key)250 void Seek(StringPiece key) { return iter_->Seek(key); } 251 // Moves to the next position in the bundle. 252 // REQUIRES: status().ok() Next()253 void Next() const { iter_->Next(); } 254 // Returns true iff the reader is positioned to a key/val pair. 255 // REQUIRES: status().ok() Valid()256 bool Valid() const { return iter_->Valid(); } 257 258 // Returns the key at the current position. 259 // REQUIRES: status().ok() && Valid() key()260 StringPiece key() const { return iter_->key(); } 261 // Returns the raw value at the current position. 262 // REQUIRES: status().ok() && Valid() value()263 StringPiece value() const { return iter_->value(); } 264 265 string DebugString(); 266 267 private: 268 // Seeks for "key" and reads the metadata proto. 269 // On non-OK return, clears "entry" for the caller. 270 // REQUIRES: status().ok() 271 Status GetBundleEntryProto(StringPiece key, 272 BundleEntryProto* entry) TF_MUST_USE_RESULT; 273 274 // Reads the tensor value described by the metadata proto "entry". 275 // Usage for "val" follows the comment of "Lookup()". 276 Status GetValue(const BundleEntryProto& entry, 277 Tensor* val) TF_MUST_USE_RESULT; 278 279 // Reads the slice described by "slice_spec". The corresponding full tensor 280 // has key "ful_tensor_key" and metadata proto "full_tensor_entry". 281 // REQUIRES: full_tensor_entry.slices_size() > 0 282 Status GetSliceValue(StringPiece full_tensor_key, 283 const BundleEntryProto& full_tensor_entry, 284 const TensorSlice& slice_spec, 285 Tensor* val) TF_MUST_USE_RESULT; 286 287 Env* env_; // Not owned. 288 const string prefix_; 289 290 Status status_; 291 RandomAccessFile* metadata_; // Owned. 292 table::Table* table_; 293 table::Cache* index_cache_; 294 table::Iterator* iter_; 295 // Owned the InputBuffer objects and their underlying RandomAccessFile's. 296 std::unordered_map<int32, io::InputBuffer*> data_; 297 298 // Maps each partitioned tensor's key to its stored slices (represented in a 299 // TensorSliceSet). Populated on-demand. 300 std::unordered_map<string, checkpoint::TensorSliceSet*> tensor_slices_; 301 302 // Expected number of data file shards in the bundle. Extracted by reading 303 // the header entry in the metadata table. 304 int num_shards_; 305 306 // Flag that this class sets to true when the endianness of the target bundle 307 // differs from that of the current system's processor architecture. 308 bool need_to_swap_bytes_; 309 310 friend class TensorBundleAlignmentTest; // For testing data alignment. 311 312 TF_DISALLOW_COPY_AND_ASSIGN(BundleReader); 313 }; 314 315 // A buffering wrapper for a WritableFile. Useful if the caller wishes to issue 316 // small writes to a file (e.g. writing out a list of small varints). 317 // External synchronization must be used in the presence of concurrent callers. 318 class FileOutputBuffer { 319 public: 320 FileOutputBuffer(WritableFile* file, size_t buffer_size); 321 ~FileOutputBuffer(); 322 323 // Buffered append. 324 Status Append(StringPiece data); 325 326 // Returns the running crc32c checksum of all currently appended bytes. crc32c()327 uint32 crc32c() { return crc32c_; } 328 // Clears the running crc32c checksum. clear_crc32c()329 void clear_crc32c() { crc32c_ = 0; } 330 331 // Appends the buffered data, then closes the underlying file. 332 Status Close(); 333 334 private: 335 // Appends the buffered data to the underlying file. Does NOT flush the file. 336 Status FlushBuffer(bool closing); 337 338 WritableFile* file_; // Owned. 339 340 // buffer_ptr_[0, position_) holds the buffered data not yet appended to the 341 // underlying file. 342 size_t position_; 343 const size_t buffer_size_; 344 char* buffer_ptr_; 345 346 // Checksum of all appended bytes since construction or last clear_crc32c(). 347 uint32 crc32c_ = 0; 348 }; 349 350 } // namespace tensorflow 351 352 #endif // TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_ 353