• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7 http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // A tensor bundle is a set of immutable persistent files storing a set of named
17 // tensors.  It is designed for checkpointing TensorFlow tensors.
18 //
19 // The paths of the managed files share a common prefix; e.g., with the prefix:
20 //   /fs/model/train/ckpt-step/ckpt
21 //
22 // the bundle may contain a metadata file, and sharded data files:
23 //   /fs/model/train/ckpt-step/
24 //       ckpt.index
25 //       ckpt.data-00000-of-00020
26 //       ckpt.data-00001-of-00020
27 //       ...
28 //       ckpt.data-00019-of-00020
29 //
30 // The ".index" file is a string-string immutable table
31 // (tensorflow::table::Table).  Each key is a name of a tensor and its value is
32 // a serialized BundleEntryProto.  Each BundleEntryProto describes the metadata
33 // of a tensor: which of the "data" files contains the content of a tensor, the
34 // offset into that file, checksum, some auxiliary data, etc.
35 //
36 // A tensor bundle can be accessed randomly using a BundleReader.  Usage:
37 //
38 //   BundleReader reader(env, "/fs/model/train/ckpt-step/ckpt");
39 //   reader.Lookup("name", &tensor);
40 //
// A tensor bundle can be built using BundleWriter.  Each BundleWriter builds a
// single data file bundle.  Multiple bundles can then be merged by
// MergeBundles() without reading and writing large chunks of data: it reads the
// metadata files and outputs a single merged metadata file.  Typical usage:
45 //
46 //   worker 0:
47 //     BundleWriter writer(env, "/fs/model/train/ckpt-step/tmp/worker0-step");
48 //     writer.Add(...);  // Adds the tensors on this worker.
49 //     writer.Finish();  // Flushes.
50 //   worker 1:
51 //     BundleWriter writer(env, "/fs/model/train/ckpt-step/tmp/worker1-step");
52 //     writer.Add(...);
53 //     writer.Finish();
54 //   worker 2:
55 //     MergeBundles(env,
56 //       {"/fs/model/train/ckpt-step/tmp/worker0-step",
57 //        "/fs/model/train/ckpt-step/tmp/worker1-step"},
58 //       "/fs/model/train/ckpt-step/ckpt" /* merged prefix */);
59 //
60 
61 #ifndef TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_
62 #define TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_
63 
64 #include <map>
65 #include <string>
66 #include <unordered_map>
67 
68 #include "tensorflow/core/framework/tensor.h"
69 #include "tensorflow/core/framework/tensor_shape.h"
70 #include "tensorflow/core/framework/tensor_slice.h"
71 #include "tensorflow/core/lib/core/status.h"
72 #include "tensorflow/core/lib/gtl/array_slice.h"
73 #include "tensorflow/core/lib/io/cache.h"
74 #include "tensorflow/core/lib/io/inputbuffer.h"
75 #include "tensorflow/core/lib/io/table.h"
76 #include "tensorflow/core/platform/cord.h"
77 #include "tensorflow/core/platform/env.h"
78 #include "tensorflow/core/platform/file_system.h"
79 #include "tensorflow/core/platform/macros.h"
80 #include "tensorflow/core/platform/types.h"
81 #include "tensorflow/core/protobuf/tensor_bundle.pb.h"
82 #include "tensorflow/core/util/tensor_bundle/naming.h"
83 #include "tensorflow/core/util/tensor_slice_set.h"
84 
85 namespace tensorflow {
86 
// Defined at the end of this header; forward-declared here because
// BundleWriter holds a std::unique_ptr<FileOutputBuffer>.
class FileOutputBuffer;

// Versioning of the tensor bundle format.
// Follows the same rules as tensorflow/core/public/version.h.
//
// History:
// 0. Any tensor bundles produced before this field was added.
// 1. Added this field (2016-09-14).
extern const int kTensorBundleMinProducer;
extern const int kTensorBundleMinConsumer;
extern const int kTensorBundleVersion;

// The empty string, hence always the first key in the metadata table.  Its
// corresponding value is a BundleHeaderProto.
extern const char* const kHeaderEntryKey;
102 
103 // Builds a string-string table of tensor names to BundleEntryProto (metadata).
104 //
105 // On construction, attempts to create a directory given by the dirname of
106 // "prefix", so "status()" must be checked before calling any member functions.
107 //
108 // All threads accessing the same BundleWriter must synchronize.
109 class BundleWriter {
110  public:
111   struct Options {
OptionsOptions112     Options() {}
113     // Alignment, in bytes, for tensor data.
114     // Must be >= 1. The default size of 1 densely packs tensors.
115     int data_alignment{1};
116   };
117   BundleWriter(Env* env, StringPiece prefix,
118                const Options& options = Options());
119 
120   // Adds the tensor "val" under key "key".
121   // Across calls "key" must be unique but can be added in any order.
122   Status Add(StringPiece key, const Tensor& val);
123 
124   // Partitioned variables support.
125   // A slice of a full tensor is stored in two entries in the metadata table:
126   //
127   //   full_tensor_key   -> BundleEntryProto, describing all stored slices
128   //                        of this full tensor.  Does not append to the data
129   //                        file.
130   //   encoded slice key -> BundleEntryProto, describing one particular slice.
131   //                        Appends values of this slice to the data file.
132   //
133   // Slices of a full tensor can be added in any order.
134   //
135   // If a full tensor has slices placed on N devices and N BundleWriter's are
136   // concurrently used, the caller must use MergeBundles() to ensure that a
137   // consistent entry for "full_tensor_key" is produced.
138   //
139   // Returns an error if the same slice is added the second time.
140   Status AddSlice(StringPiece full_tensor_key,
141                   const TensorShape& full_tensor_shape,
142                   const TensorSlice& slice_spec, const Tensor& slice_tensor);
143 
144   // Finishes the writer and flushes.
145   Status Finish() TF_MUST_USE_RESULT;
146 
status()147   Status status() const { return status_; }
148 
149  private:
150   Env* const env_;  // Not owned.
151   const Options options_;
152   const string prefix_;
153   string metadata_path_;
154   string data_path_;
155   bool use_temp_file_;
156   std::unique_ptr<FileOutputBuffer> out_;
157   int64 size_;  // Number of bytes written into out_.
158   std::map<string, BundleEntryProto> entries_;
159   Status status_;
160 
161   TF_DISALLOW_COPY_AND_ASSIGN(BundleWriter);
162 };
163 
164 // Merges a set of bundles (given their prefixes) into a single bundle with the
165 // given "merged_prefix".  The merged metadata is guaranteed to be consistent.
166 //
167 // If there are N bundles in "prefixes", during the merge the data files will be
168 // renamed to contain a proper sharded file spec, with num_shards set to the sum
169 // of num_shards across the N input bundles.
170 //
171 // The caller should only rely on the metadata file of the merged bundle to
172 // query information about a tensor.  In particular, this function does not
173 // guarantee not to re-order the input data files.
174 //
175 // Once merged, makes a best effort to delete the old metadata files.
176 // Returns OK iff all bundles are successfully merged.
177 Status MergeBundles(Env* env, gtl::ArraySlice<tstring> prefixes,
178                     StringPiece merged_prefix);
179 
180 // On construction, silently attempts to read the metadata associated with
181 // "prefix".  If caller intends to call any function afterwards, "status()"
182 // must be checked.
183 // All threads accessing the same BundleReader must synchronize.
184 class BundleReader {
185  public:
186   BundleReader(Env* const env, StringPiece prefix);
187   ~BundleReader();
188 
189   // Is ok() iff the reader construction is successful (completed the read of
190   // the metadata).
status()191   Status status() const { return status_; }
192 
193   // Queries whether the bundle contains an entry keyed by "key".  Calls Seek()
194   // internally, so this call invalidates the reader's current position.
195   // REQUIRES: status().ok()
196   bool Contains(StringPiece key);
197 
198   // Looks up the dtype and the shape of the tensor keyed by "key".
199   // REQUIRES: status().ok()
200   Status LookupDtypeAndShape(StringPiece key, DataType* dtype,
201                              TensorShape* shape) TF_MUST_USE_RESULT;
202 
203   // Looks up the shape of the tensor keyed by "key".
204   // Clears "shape" if not found.
205   // REQUIRES: status().ok()
206   Status LookupTensorShape(StringPiece key,
207                            TensorShape* shape) TF_MUST_USE_RESULT;
208 
209   // Looks up the tensor keyed by "key".  If "key" refers to a partitioned
210   // tensor, attempts to look up the full contents using all stored slices.
211   //
212   // Caller must make sure "val" has the same shape and dtype as the
213   // corresponding contents, so that its buffer can be filled without needing
214   // extra allocation.  These can be queried via "LookupDtypeAndShape()".
215   //
216   // On error, "val" may contain nonsense data.  Returns a NotFound error if
217   // tensor keyed by "key" does not exist in this bundle.
218   //
219   // Validates the stored crc32c checksum against the restored bytes.
220   // REQUIRES: status().ok()
221   Status Lookup(StringPiece key, Tensor* val) TF_MUST_USE_RESULT;
222 
223   // Looks up the tensor pointed to by the internal iterator.
224   //
225   // On error, "val" may contain nonsense data.
226   //
227   // Validates the stored crc32c checksum against the restored bytes.
228   // REQUIRES: status().ok() && Valid()
229   Status ReadCurrent(Tensor* val) TF_MUST_USE_RESULT;
230 
231   // Looks up the slices of the tensor keyed by "key".  On OK, "slices"
232   // is non-empty if and only if the tensor is a partitioned tensor.
233   //
234   // Warning - there is no guaranteed ordering for the returned slices, so
235   // a slice with a larger start index in some dimension could come before
236   // another slice with a smaller start index in the same dimension.
237   // REQUIRES: status().ok()
238   Status LookupTensorSlices(StringPiece key, std::vector<TensorSlice>* slices)
239       TF_MUST_USE_RESULT;
240 
241   // Looks up a specific slice of a partitioned tensor.
242   // It is only required that the stored slices cover the requested slice,
243   // namely "slice_spec" is a subset of the union of the stored slices.
244   // REQUIRES: status().ok()
245   Status LookupSlice(StringPiece full_tensor_key, const TensorSlice& slice_spec,
246                      Tensor* val) TF_MUST_USE_RESULT;
247 
248   // Seeks to the first position in the bundle whose key is no less than "key".
249   // REQUIRES: status().ok()
Seek(StringPiece key)250   void Seek(StringPiece key) { return iter_->Seek(key); }
251   // Moves to the next position in the bundle.
252   // REQUIRES: status().ok()
Next()253   void Next() const { iter_->Next(); }
254   // Returns true iff the reader is positioned to a key/val pair.
255   // REQUIRES: status().ok()
Valid()256   bool Valid() const { return iter_->Valid(); }
257 
258   // Returns the key at the current position.
259   // REQUIRES: status().ok() && Valid()
key()260   StringPiece key() const { return iter_->key(); }
261   // Returns the raw value at the current position.
262   // REQUIRES: status().ok() && Valid()
value()263   StringPiece value() const { return iter_->value(); }
264 
265   string DebugString();
266 
267  private:
268   // Seeks for "key" and reads the metadata proto.
269   // On non-OK return, clears "entry" for the caller.
270   // REQUIRES: status().ok()
271   Status GetBundleEntryProto(StringPiece key,
272                              BundleEntryProto* entry) TF_MUST_USE_RESULT;
273 
274   // Reads the tensor value described by the metadata proto "entry".
275   // Usage for "val" follows the comment of "Lookup()".
276   Status GetValue(const BundleEntryProto& entry,
277                   Tensor* val) TF_MUST_USE_RESULT;
278 
279   // Reads the slice described by "slice_spec".  The corresponding full tensor
280   // has key "ful_tensor_key" and metadata proto "full_tensor_entry".
281   // REQUIRES: full_tensor_entry.slices_size() > 0
282   Status GetSliceValue(StringPiece full_tensor_key,
283                        const BundleEntryProto& full_tensor_entry,
284                        const TensorSlice& slice_spec,
285                        Tensor* val) TF_MUST_USE_RESULT;
286 
287   Env* env_;  // Not owned.
288   const string prefix_;
289 
290   Status status_;
291   RandomAccessFile* metadata_;  // Owned.
292   table::Table* table_;
293   table::Cache* index_cache_;
294   table::Iterator* iter_;
295   // Owned the InputBuffer objects and their underlying RandomAccessFile's.
296   std::unordered_map<int32, io::InputBuffer*> data_;
297 
298   // Maps each partitioned tensor's key to its stored slices (represented in a
299   // TensorSliceSet).  Populated on-demand.
300   std::unordered_map<string, checkpoint::TensorSliceSet*> tensor_slices_;
301 
302   // Expected number of data file shards in the bundle.  Extracted by reading
303   // the header entry in the metadata table.
304   int num_shards_;
305 
306   // Flag that this class sets to true when the endianness of the target bundle
307   // differs from that of the current system's processor architecture.
308   bool need_to_swap_bytes_;
309 
310   friend class TensorBundleAlignmentTest;  // For testing data alignment.
311 
312   TF_DISALLOW_COPY_AND_ASSIGN(BundleReader);
313 };
314 
315 // A buffering wrapper for a WritableFile.  Useful if the caller wishes to issue
316 // small writes to a file (e.g. writing out a list of small varints).
317 // External synchronization must be used in the presence of concurrent callers.
318 class FileOutputBuffer {
319  public:
320   FileOutputBuffer(WritableFile* file, size_t buffer_size);
321   ~FileOutputBuffer();
322 
323   // Buffered append.
324   Status Append(StringPiece data);
325 
326   // Returns the running crc32c checksum of all currently appended bytes.
crc32c()327   uint32 crc32c() { return crc32c_; }
328   // Clears the running crc32c checksum.
clear_crc32c()329   void clear_crc32c() { crc32c_ = 0; }
330 
331   // Appends the buffered data, then closes the underlying file.
332   Status Close();
333 
334  private:
335   // Appends the buffered data to the underlying file. Does NOT flush the file.
336   Status FlushBuffer(bool closing);
337 
338   WritableFile* file_;  // Owned.
339 
340   // buffer_ptr_[0, position_) holds the buffered data not yet appended to the
341   // underlying file.
342   size_t position_;
343   const size_t buffer_size_;
344   char* buffer_ptr_;
345 
346   // Checksum of all appended bytes since construction or last clear_crc32c().
347   uint32 crc32c_ = 0;
348 };
349 
350 }  // namespace tensorflow
351 
352 #endif  // TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_
353