/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"

#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <utility>

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/framework/versions.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/crc32c.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/io/table_builder.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
#include "tensorflow/core/util/tensor_bundle/byte_swap.h"
#include "tensorflow/core/util/tensor_slice_util.h"

namespace tensorflow {

// Versioning of the tensor bundle format.
const int kTensorBundleMinProducer = 0;
const int kTensorBundleMinConsumer = 0;
const int kTensorBundleVersion = 1;

// Size of our input buffer for streaming reads.
static const int kBufferSize = 1024 * 1024;

// Key to the special BundleHeaderProto entry.  Do not change this, as clients
// can make the assumption that the header is always the first entry in the
// bundle.
const char* const kHeaderEntryKey = "";
namespace {

// Reads "num_elements" string elements from file[offset, offset+size) into the
// length-N "destination".  Discards the original content of "destination".
//
// Checksums the string lengths (as restored uint32 or uint64, not varint64
// bytes) and string bytes, and stores it into "actual_crc32c".
Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements,
                        size_t offset, size_t size, tstring* destination,
                        uint32* actual_crc32c, bool need_to_swap_bytes) {
  if (size == 0) return Status::OK();
  CHECK_GT(size, 0);

  // Reads "num_elements" varint64's from "buffered_file".
  TF_RETURN_IF_ERROR(buffered_file->Seek(offset));
  std::vector<uint64> string_lengths(num_elements);
  for (size_t i = 0; i < num_elements; ++i) {
    TF_RETURN_IF_ERROR(buffered_file->ReadVarint64(&string_lengths[i]));
    if (string_lengths[i] <= UINT32_MAX) {
      // We need to do this because older checkpoints only used uint32s and we
      // should still support them.
      uint32 elem_size_uint32 = static_cast<uint32>(string_lengths[i]);
      if (need_to_swap_bytes) {
        // Checksum would have been computed on the source machine's byte
        // order.
        elem_size_uint32 = BYTE_SWAP_32(elem_size_uint32);
      }
      *actual_crc32c = crc32c::Extend(
          *actual_crc32c, reinterpret_cast<const char*>(&elem_size_uint32),
          sizeof(uint32));
    } else {
      uint64 length = string_lengths[i];
      if (need_to_swap_bytes) {
        length = BYTE_SWAP_64(length);
      }
      *actual_crc32c =
          crc32c::Extend(*actual_crc32c, reinterpret_cast<const char*>(&length),
                         sizeof(uint64));
    }
  }
  if (offset + size < buffered_file->Tell()) {
    return errors::DataLoss("String lengths longer than expected offset ",
                            offset + size);
  }

  // Reads the length-checksum.
  uint32 raw_length_checksum = 0;  // Bytes in file
  uint32 length_checksum = 0;      // In-memory representation
  size_t unused_bytes_read = 0;
  TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(
      sizeof(uint32), reinterpret_cast<char*>(&raw_length_checksum),
      &unused_bytes_read));
  length_checksum = need_to_swap_bytes ? BYTE_SWAP_32(raw_length_checksum)
                                       : raw_length_checksum;
  if (crc32c::Unmask(length_checksum) != *actual_crc32c) {
    return errors::DataLoss(
        "The length checksum does not match: expected ",
        strings::Printf("%08u", crc32c::Unmask(length_checksum)),
        " but actual is ", strings::Printf("%08u", *actual_crc32c));
  }
  *actual_crc32c = crc32c::Extend(*actual_crc32c,
                                  reinterpret_cast<char*>(&raw_length_checksum),
                                  sizeof(uint32));

  // Reads the actual string bytes.
  for (size_t i = 0; i < num_elements; ++i) {
    const uint64 string_length = string_lengths[i];
    tstring* buffer = &destination[i];

    buffer->resize(string_length);
    size_t bytes_read = 0;
    TF_RETURN_IF_ERROR(
        buffered_file->ReadNBytes(string_length, &(*buffer)[0], &bytes_read));
    *actual_crc32c = crc32c::Extend(*actual_crc32c, buffer->data(), bytes_read);
  }
  return Status::OK();
}

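// Reads "ret->NumElements()" serialized variants from file[offset,
// offset+size) into the variant tensor "ret".  Checksums the variant lengths
// (as uint64), the variant bytes, and the per-variant checksums, storing the
// result into "actual_crc32c".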
Status ReadVariantTensor(io::InputBuffer* buffered_file, Tensor* ret,
                         size_t offset, size_t size, uint32* actual_crc32c) {
  // On-disk format:
  //   [varint64 len1][bytes variant1][4 byte checksum]
  //   ..
  //   [varint64 lenN][bytes variantN][4 byte checksum]
  // Var "crc32c" checksums all the lens, variant bytes, individual variant
  // checksums (as uint32, not varint32 bytes).
  if (size == 0) return Status::OK();
  size_t num_elements = ret->NumElements();

  // Reads the actual string bytes.
  TF_RETURN_IF_ERROR(buffered_file->Seek(offset));
  for (size_t i = 0; i < num_elements; ++i) {
    // Read the serialized variant length.
    uint64 string_length = 0;
    TF_RETURN_IF_ERROR(buffered_file->ReadVarint64(&string_length));
    *actual_crc32c = crc32c::Extend(
        *actual_crc32c, reinterpret_cast<const char*>(&string_length),
        sizeof(uint64));
    // Read the actual serialized variant.
    string buffer;
    buffer.resize(string_length);
    size_t bytes_read = 0;
    TF_RETURN_IF_ERROR(
        buffered_file->ReadNBytes(string_length, &buffer[0], &bytes_read));
    *actual_crc32c = crc32c::Extend(*actual_crc32c, buffer.data(), bytes_read);
    VariantTensorDataProto proto;
    if (!proto.ParseFromString(buffer)) {
      return errors::DataLoss("Unable to parse VariantTensorDataProto from ",
                              "buffer of size ", string_length, ". ",
                              "Bundle entry offset: ", offset, " size: ", size);
    }
    Variant v = proto;
    if (!DecodeUnaryVariant(&v)) {
      return errors::Internal("Could not decode variant with type_name: \"",
                              v.TypeName(), "\".  Perhaps you forgot to ",
                              "register a decoder via ",
                              "REGISTER_UNARY_VARIANT_DECODE_FUNCTION?");
    }

    // Read the checksum.
    uint32 checksum = 0;
    size_t unused_bytes_read = 0;
    TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(
        sizeof(uint32), reinterpret_cast<char*>(&checksum),
        &unused_bytes_read));
    if (crc32c::Unmask(checksum) != *actual_crc32c) {
      return errors::DataLoss(
          "The checksum after Variant ", i, " does not match.",
          " Expected: ", strings::Printf("%08u", crc32c::Unmask(checksum)),
          " Actual: ", strings::Printf("%08u", *actual_crc32c));
    }
    *actual_crc32c = crc32c::Extend(
        *actual_crc32c, reinterpret_cast<char*>(&checksum), sizeof(uint32));

    ret->flat<Variant>()(i) = std::move(v);
  }

  return Status::OK();
}

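// Returns the raw byte buffer backing the memcpy-able tensor "val".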
char* GetBackingBuffer(const Tensor& val) {
  CHECK(DataTypeCanUseMemcpy(val.dtype())) << val.dtype();
  return const_cast<char*>(val.tensor_data().data());
}

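// Returns the tstring array backing the DT_STRING tensor "val".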
tstring* GetStringBackingBuffer(const Tensor& val) {
  CHECK_EQ(DT_STRING, val.dtype());
  return const_cast<tstring*>(val.flat<tstring>().data());
}

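// Parses "value" into the protobuf message "out"; "key" is used only for
// error reporting.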
Status ParseEntryProto(StringPiece key, StringPiece value,
                       protobuf::MessageLite* out) {
  if (!out->ParseFromArray(value.data(), value.size())) {
    return errors::DataLoss("Entry for key ", key, " not parseable.");
  }
  return Status::OK();
}

// Serializes the data bytes of the non-string tensor "val".  Discards the
// original content of "bytes_written", and on OK updates it with the number
// of bytes written.
// REQUIRES: val.dtype() != DT_STRING
Status WriteTensor(const Tensor& val, FileOutputBuffer* out,
                   size_t* bytes_written) {
  DCHECK_NE(val.dtype(), DT_STRING);
  DCHECK_NE(val.dtype(), DT_VARIANT);
  *bytes_written = val.TotalBytes();
  char* buf = GetBackingBuffer(val);
  VLOG(1) << "Appending " << *bytes_written << " bytes to file";
  return out->Append(StringPiece(buf, *bytes_written));
}

// Serializes string tensor "val".  "bytes_written" is treated in the same
// fashion as WriteTensor().
//
// Checksums all bytes written and stores it into "crc32c".
// REQUIRES: val.dtype() == DT_STRING
Status WriteStringTensor(const Tensor& val, FileOutputBuffer* out,
                         size_t* bytes_written, uint32* crc32c) {
  // On-disk format:
  //   [varint64 len0]..[varint64 lenL][4 byte cksum on lengths][string bytes]
  // Var "crc32c" checksums the string lengths (as uint64, not varint64 bytes),
  // the length-checksum, and all the string bytes.
  DCHECK_EQ(val.dtype(), DT_STRING);
  const tstring* strings = GetStringBackingBuffer(val);

  // Writes the varint lengths.
  string lengths;
  lengths.reserve(val.NumElements());  // At least 1 byte per element.
  *crc32c = 0;
  for (int64 i = 0; i < val.NumElements(); ++i) {
    const tstring* elem = &strings[i];
    DCHECK_EQ(elem->size(), static_cast<uint64>(elem->size()));
    const uint64 elem_size = static_cast<uint64>(elem->size());

    core::PutVarint64(&lengths, elem_size);
    if (elem_size <= UINT32_MAX) {
      // We need to do this because older checkpoints only used uint32s and we
      // should still support them.
      const uint32 elem_size_uint32 = static_cast<uint32>(elem_size);
      *crc32c = crc32c::Extend(*crc32c,
                               reinterpret_cast<const char*>(&elem_size_uint32),
                               sizeof(uint32));
    } else {
      *crc32c = crc32c::Extend(
          *crc32c, reinterpret_cast<const char*>(&elem_size), sizeof(uint64));
    }
  }
  TF_RETURN_IF_ERROR(out->Append(lengths));
  *bytes_written = lengths.size();

  // Writes the length checksum.
  const uint32 length_checksum = crc32c::Mask(*crc32c);
  TF_RETURN_IF_ERROR(out->Append(StringPiece(
      reinterpret_cast<const char*>(&length_checksum), sizeof(uint32))));
  *crc32c = crc32c::Extend(
      *crc32c, reinterpret_cast<const char*>(&length_checksum), sizeof(uint32));
  *bytes_written += sizeof(uint32);

  // Writes all the string bytes out.
  for (int64 i = 0; i < val.NumElements(); ++i) {
    const tstring* string = &strings[i];
    TF_RETURN_IF_ERROR(out->Append(*string));
    *bytes_written += string->size();
    *crc32c = crc32c::Extend(*crc32c, string->data(), string->size());
  }
  return Status::OK();
}

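// Serializes variant tensor "val".  "bytes_written" and "crc32c" are treated
// in the same fashion as WriteStringTensor().
// REQUIRES: val.dtype() == DT_VARIANT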
Status WriteVariantTensor(const Tensor& val, FileOutputBuffer* out,
                          size_t* bytes_written, uint32* crc32c) {
  // On-disk format:
  //   [varint64 len1][bytes variant1][4 byte checksum]
  //   ..
  //   [varint64 lenN][bytes variantN][4 byte checksum]
  // Var "crc32c" checksums all the lens, variant bytes, individual variant
  // checksums (as uint32, not varint32 bytes).
  DCHECK_EQ(val.dtype(), DT_VARIANT);

  *crc32c = 0;
  *bytes_written = 0;
  for (int64 i = 0; i < val.NumElements(); ++i) {
    VariantTensorData data;
    val.flat<Variant>()(i).Encode(&data);
    VariantTensorDataProto proto;
    data.ToProto(&proto);
    string elem;
    proto.SerializeToString(&elem);

    // Write the length of the serialized variant.
    DCHECK_EQ(elem.size(), static_cast<uint64>(elem.size()));
    const auto elem_size = static_cast<uint64>(elem.size());
    string len;
    core::PutVarint64(&len, elem_size);
    TF_RETURN_IF_ERROR(out->Append(len));
    *crc32c = crc32c::Extend(*crc32c, reinterpret_cast<const char*>(&elem_size),
                             sizeof(uint64));
    *bytes_written += len.size();

    // Write the serialized variant.
    TF_RETURN_IF_ERROR(out->Append(elem));
    *crc32c = crc32c::Extend(*crc32c, elem.data(), elem.size());
    *bytes_written += elem.size();

    // Write the checksum.
    const uint32 length_checksum = crc32c::Mask(*crc32c);
    TF_RETURN_IF_ERROR(out->Append(StringPiece(
        reinterpret_cast<const char*>(&length_checksum), sizeof(uint32))));
    *crc32c =
        crc32c::Extend(*crc32c, reinterpret_cast<const char*>(&length_checksum),
                       sizeof(uint32));
    *bytes_written += sizeof(uint32);
  }

  return Status::OK();
}

// Returns whether "slice_spec" is a full slice, with respect to the full shape.
//
// This can happen, say, when "slice_spec" is
// "TensorSlice(full_tensor_shape.dims())", or when it is "TensorSlice({{0,
// dim(0)}, ..., {0, dim(N)}})" -- a degenerate case we need to guard against.
bool IsFullSlice(const TensorSlice& slice_spec,
                 const TensorShape& full_tensor_shape) {
  if (slice_spec.IsFull()) {
    return true;
  } else {
    TensorShape sliced_shape;
    slice_spec.SliceTensorShape(full_tensor_shape, &sliced_shape).IgnoreError();
    return sliced_shape == full_tensor_shape;
  }
}

Status CorruptFileError(const Status& in_status, const string& filename,
                        const string& detail) {
  if (in_status.ok()) {
    return errors::Internal("Unable to read file (", filename,
                            "). Perhaps the file is corrupt or was produced by "
                            "a newer version of TensorFlow with format changes "
                            "(",
                            detail, ")");
  }
  return Status(
      in_status.code(),
      strings::StrCat("Unable to read file (", filename,
                      "). Perhaps the file is corrupt or was produced by a "
                      "newer version of TensorFlow with format changes (",
                      detail, "): ", in_status.error_message()));
}

table::Options TableBuilderOptions() {
  table::Options o;
  // Compressed tables cannot be read by TensorFlow releases prior to 1.1.
  // To smooth the transition, compressed writes are disabled for now
  // (version 1.2) with the intention that they will be enabled again at
  // some point (perhaps the 1.3 release?).
  o.compression = table::kNoCompression;
  return o;
}

// Writes zeros to output buffer to align the next write to the requested
// alignment. "size" is the current size of the buffer and is updated to the
// new size.
Status PadAlignment(FileOutputBuffer* out, int alignment, int64* size) {
  int bytes_over = *size % alignment;
  if (bytes_over == 0) {
    return Status::OK();
  }
  int bytes_to_write = alignment - bytes_over;
  Status status = out->Append(string(bytes_to_write, '\0'));
  if (status.ok()) {
    *size += bytes_to_write;
  }
  return status;
}

}  // namespace

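// Opens a temporary data file for writing.  All output is staged under
// temporary paths and atomically renamed into place by Finish().  Any failure
// is recorded in "status_".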
BundleWriter::BundleWriter(Env* env, StringPiece prefix, const Options& options)
    : env_(env),
      options_(options),
      prefix_(prefix),
      tmp_metadata_path_(strings::StrCat(MetaFilename(prefix_), ".tempstate",
                                         random::New64())),
      tmp_data_path_(strings::StrCat(DataFilename(prefix_, 0, 1), ".tempstate",
                                     random::New64())),
      out_(nullptr),
      size_(0) {
  status_ = env_->CreateDir(string(io::Dirname(prefix_)));
  if (!status_.ok() && !errors::IsAlreadyExists(status_)) {
    return;
  }
  const string filename = DataFilename(prefix_, 0, 1);
  std::unique_ptr<WritableFile> wrapper;
  status_ = env_->NewWritableFile(tmp_data_path_, &wrapper);
  if (!status_.ok()) return;
  out_ = std::unique_ptr<FileOutputBuffer>(
      new FileOutputBuffer(wrapper.release(), 8 << 20 /* 8MB write buffer */));

  VLOG(1) << "Writing to file " << tmp_data_path_;
}

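// Serializes "val" into the data file and records its metadata entry under
// "key".  Keys must be unique across Add() calls.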
Status BundleWriter::Add(StringPiece key, const Tensor& val) {
  if (!status_.ok()) return status_;
  CHECK_NE(key, kHeaderEntryKey);
  const string key_string(key);
  if (entries_.find(key_string) != entries_.end()) {
    status_ = errors::InvalidArgument("Adding duplicate key: ", key);
    return status_;
  }

  BundleEntryProto* entry = &entries_[key_string];
  entry->set_dtype(val.dtype());
  val.shape().AsProto(entry->mutable_shape());
  entry->set_shard_id(0);
  entry->set_offset(size_);

  // Updates the data file.
  size_t data_bytes_written = 0;
  uint32 crc32c = 0;
  out_->clear_crc32c();
  if (val.dtype() == DT_STRING) {
    status_ = WriteStringTensor(val, out_.get(), &data_bytes_written, &crc32c);
  } else if (val.dtype() == DT_VARIANT) {
    status_ = WriteVariantTensor(val, out_.get(), &data_bytes_written, &crc32c);
  } else {
    status_ = WriteTensor(val, out_.get(), &data_bytes_written);
    crc32c = out_->crc32c();
  }

  if (status_.ok()) {
    entry->set_size(data_bytes_written);
    entry->set_crc32c(crc32c::Mask(crc32c));
    size_ += data_bytes_written;
    status_ = PadAlignment(out_.get(), options_.data_alignment, &size_);
  }
  return status_;
}

Status BundleWriter::AddSlice(StringPiece full_tensor_key,
                              const TensorShape& full_tensor_shape,
                              const TensorSlice& slice_spec,
                              const Tensor& slice_tensor) {
  if (!status_.ok()) return status_;
  CHECK_NE(full_tensor_key, kHeaderEntryKey);

  // If just a singleton full slice, use the regular Add() to be more efficient.
  if (IsFullSlice(slice_spec, full_tensor_shape)) {
    return Add(full_tensor_key, slice_tensor);
  }

  // Inserts/updates the full tensor's metadata entry.
  //
  // In the case of a sharded save, MergeBundles() is responsible for merging
  // the "slices" field of multiple metadata entries corresponding to the same
  // full tensor.
  const string full_tensor_key_string(full_tensor_key);
  BundleEntryProto* full_entry = &entries_[full_tensor_key_string];
  if (full_entry->dtype() != DT_INVALID) {
    CHECK_EQ(full_entry->dtype(), slice_tensor.dtype());
  }
  if (full_entry->has_shape()) {
    CHECK(TensorShape(full_entry->shape()) == full_tensor_shape);
  }

  // Populates dtype, shape, and slices.  Intentionally leaving out shard_id and
  // offset, which do not make sense for this full tensor entry.
  full_entry->set_dtype(slice_tensor.dtype());
  full_tensor_shape.AsProto(full_entry->mutable_shape());
  TensorSliceProto* slice_proto = full_entry->add_slices();
  slice_spec.AsProto(slice_proto);

  // The slice itself is handled by a regular Add(), which includes adding its
  // own metadata entry, and writing out the slice's values.
  const string slice_name =
      checkpoint::EncodeTensorNameSlice(full_tensor_key_string, slice_spec);
  status_ = Add(slice_name, slice_tensor);
  return status_;
}

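// Flushes and closes the data file, writes the metadata table (the header
// entry first, then one entry per tensor key), and atomically renames the
// temporary files to their final names.  The writer cannot be reused after
// Finish() returns.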
// TODO(zongheng): on metadata write failure or !status_.ok(), consider removing
// the orphaned data file.
Status BundleWriter::Finish() {
  if (out_) {
    status_.Update(out_->Close());
    out_ = nullptr;
    if (status_.ok()) {
      status_ = Env::Default()->RenameFile(tmp_data_path_,
                                           DataFilename(prefix_, 0, 1));
    } else {
      Env::Default()->DeleteFile(tmp_data_path_).IgnoreError();
    }
  }
  if (!status_.ok()) return status_;
  // Build key -> BundleEntryProto table.
  std::unique_ptr<WritableFile> file;
  status_ = env_->NewWritableFile(tmp_metadata_path_, &file);
  if (!status_.ok()) return status_;
  {
    // N.B.: the default use of Snappy compression may not be supported on all
    // platforms (e.g. Android).  The metadata file is small, so this is fine.
    table::Options options;
    options.compression = table::kNoCompression;
    table::TableBuilder builder(options, file.get());
    // Header entry.
    BundleHeaderProto header;
    header.set_num_shards(1);
    header.set_endianness(BundleHeaderProto::LITTLE);
    if (!port::kLittleEndian) header.set_endianness(BundleHeaderProto::BIG);
    VersionDef* version = header.mutable_version();
    version->set_producer(kTensorBundleVersion);
    version->set_min_consumer(kTensorBundleMinConsumer);

    builder.Add(kHeaderEntryKey, header.SerializeAsString());

    // All others.
    for (const auto& p : entries_) {
      builder.Add(p.first, p.second.SerializeAsString());
    }
    status_ = builder.Finish();
  }
  status_.Update(file->Close());
  if (!status_.ok()) {
    Env::Default()->DeleteFile(tmp_metadata_path_).IgnoreError();
    return status_;
  } else {
    status_ =
        Env::Default()->RenameFile(tmp_metadata_path_, MetaFilename(prefix_));
    if (!status_.ok()) return status_;
  }
  status_ = errors::Internal("BundleWriter is closed");
  return Status::OK();
}

// Merging tensor bundles.

// Accumulator of metadata states during a merge.
struct MergeState {
  // Accumulated from the header entries.
  int num_shards = 0;

  // Derives "endianness" and "version" from the first bundle merged (hence the
  // "seen_first_bundle" guard).  The two fields must be the same for all
  // bundles in a merge.
  bool seen_first_bundle = false;
  BundleHeaderProto_Endianness endianness;
  VersionDef version;

  // Tensor key -> BundleEntryProto.
  std::map<string, BundleEntryProto> entries;
  // Data file path -> new shard id in the final merged bundle.
  std::unordered_map<string, int32> shard_ids;
};

// Merges entries of "prefix" into the accumulator state "merge".
// Returns OK iff the merge succeeds.
static Status MergeOneBundle(Env* env, StringPiece prefix,
                             MergeState* merge_state) {
  VLOG(1) << "Merging bundle:" << prefix;
  const string filename = MetaFilename(prefix);
  uint64 file_size;
  TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));
  std::unique_ptr<RandomAccessFile> file;
  TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));

  table::Table* table = nullptr;
  TF_RETURN_IF_ERROR(
      table::Table::Open(TableBuilderOptions(), file.get(), file_size, &table));
  std::unique_ptr<table::Table> table_deleter(table);
  std::unique_ptr<table::Iterator> iter(table->NewIterator());

  int num_shards;
  // Process header.
  {
    iter->Seek(kHeaderEntryKey);
    if (!iter->Valid()) {
      return CorruptFileError(iter->status(), filename,
                              "failed to seek to header entry");
    }
    BundleHeaderProto header;
    Status s = ParseEntryProto(iter->key(), iter->value(), &header);
    if (!s.ok()) return CorruptFileError(s, filename, "unable to parse header");

    merge_state->num_shards += header.num_shards();
    if (!merge_state->seen_first_bundle) {
      merge_state->seen_first_bundle = true;
      merge_state->endianness = header.endianness();
      merge_state->version = header.version();
    } else {
      // Validates "endianness".
      if (merge_state->endianness != header.endianness()) {
        return errors::InvalidArgument(
            "Merging bundles with conflicting endianness; inputs corrupted?");
      }
      // Validates "version".
      string curr_version, merge_version;
      header.version().SerializeToString(&curr_version);
      merge_state->version.SerializeToString(&merge_version);
      if (curr_version != merge_version) {
        return errors::InvalidArgument(
            "Merging bundles with different format versions: merged ",
            merge_version, " vs. curr ", curr_version);
      }
    }
    num_shards = header.num_shards();
    iter->Next();
  }

  // Loops through the non-header to-merge entries.
  BundleEntryProto to_merge_entry;
  for (; iter->Valid(); iter->Next()) {
    const string key(iter->key());
    const auto entry_iter = merge_state->entries.find(key);

    // Illegal: the duplicated entry is a non-slice tensor.
    if (entry_iter != merge_state->entries.end() &&
        entry_iter->second.slices().empty()) {
      return errors::InvalidArgument(
          "Duplicate tensor keyed by ", key,
          " encountered, when merging prefix: ", prefix);
    }

    TF_RETURN_IF_ERROR(
        ParseEntryProto(iter->key(), iter->value(), &to_merge_entry));

    // The duplicated entry holds metadata for a sliced full tensor.
    // Allows the duplication and merges "slices".
    if (entry_iter != merge_state->entries.end()) {
      BundleEntryProto& existing_entry = entry_iter->second;
      if (to_merge_entry.slices().empty()) {
        return errors::Internal(
            "Duplicate tensor keyed by ", key,
            "; attempting to merge in a non-slice bundle entry");
      }
      // Only need to merge the "slices" field (and validate dtype/shape).
      for (int i = 0; i < to_merge_entry.slices_size(); ++i) {
        TensorSliceProto* slot = existing_entry.add_slices();
        *slot = to_merge_entry.slices(i);
      }
      CHECK_EQ(existing_entry.dtype(), to_merge_entry.dtype());
      CHECK(TensorShape(existing_entry.shape()) ==
            TensorShape(to_merge_entry.shape()));
      continue;
    }

    // Key doesn't duplicate: a fresh tensor/slice entry.
    auto result = merge_state->shard_ids.insert(
        {DataFilename(prefix, to_merge_entry.shard_id(), num_shards),
         merge_state->shard_ids.size()});
    to_merge_entry.set_shard_id(result.first->second);
    merge_state->entries[key] = to_merge_entry;
  }
  return Status::OK();
}

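// Merges the bundles under "prefixes" into a single bundle under
// "merged_prefix": accumulates every metadata entry, renames the data shards,
// writes one merged metadata table, and finally deletes the input metadata
// files (best effort).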
Status MergeBundles(Env* env, gtl::ArraySlice<tstring> prefixes,
                    StringPiece merged_prefix) {
  // Merges all metadata tables.
  // TODO(zhifengc): KeyValue sorter if it becomes too big.
  MergeState merge;
  Status status = env->CreateDir(string(io::Dirname(merged_prefix)));
  if (!status.ok() && !errors::IsAlreadyExists(status)) return status;
  for (int i = 0; i < prefixes.size(); ++i) {
    TF_RETURN_IF_ERROR(MergeOneBundle(env, prefixes[i], &merge));
  }

  // Renames data files to contain the merged bundle prefix.
  for (const auto& p : merge.shard_ids) {
    VLOG(1) << "Renaming " << p.first << " to "
            << DataFilename(merged_prefix, p.second, merge.shard_ids.size());
    TF_RETURN_IF_ERROR(env->RenameFile(
        p.first,
        DataFilename(merged_prefix, p.second, merge.shard_ids.size())));
  }

  // Writes the final metadata table under the merged prefix.
  std::unique_ptr<WritableFile> merged_metadata;
  TF_RETURN_IF_ERROR(
      env->NewWritableFile(MetaFilename(merged_prefix), &merged_metadata));
  {
    table::TableBuilder builder(TableBuilderOptions(), merged_metadata.get());
    // Header entry.
    BundleHeaderProto header;
    header.set_num_shards(merge.num_shards);
    header.set_endianness(merge.endianness);
    *header.mutable_version() = merge.version;
    builder.Add(kHeaderEntryKey, header.SerializeAsString());
    // All others.
    for (const auto& p : merge.entries) {
      builder.Add(p.first, p.second.SerializeAsString());
    }
    status = builder.Finish();
  }
  status.Update(merged_metadata->Close());
  if (!status.ok()) return status;
  VLOG(1) << "Merged bundles to:" << merged_prefix;

  // Cleanup: best-effort; ignores errors.
  for (const tstring& prefix : prefixes) {
    env->DeleteFile(MetaFilename(prefix)).IgnoreError();
  }
  return status;
}

// Interface for reading a tensor bundle.

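// Opens the metadata table under "prefix" and parses the header entry to
// initialize the shard count, the byte-swap requirement, and the format
// version check.  Any failure is recorded in "status_".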
BundleReader::BundleReader(Env* env, StringPiece prefix)
    : env_(env),
      prefix_(prefix),
      metadata_(nullptr),
      table_(nullptr),
      iter_(nullptr),
      need_to_swap_bytes_(false) {
  const string filename = MetaFilename(prefix_);
  uint64 file_size;
  status_ = env_->GetFileSize(filename, &file_size);
  if (!status_.ok()) return;

  // Opens the metadata table.
  std::unique_ptr<RandomAccessFile> wrapper;
  status_ = env_->NewRandomAccessFile(filename, &wrapper);
  if (!status_.ok()) return;
  metadata_ = wrapper.release();
  status_ = table::Table::Open(table::Options(), metadata_, file_size, &table_);
  if (!status_.ok()) return;
  iter_ = table_->NewIterator();

  // Reads "num_shards_" from the first entry.
  iter_->Seek(kHeaderEntryKey);
  if (!iter_->Valid()) {
    status_ = CorruptFileError(iter_->status(), filename,
                               "failed to seek to header entry");
    return;
  }
  BundleHeaderProto header;
  status_ = ParseEntryProto(iter_->key(), iter_->value(), &header);
  if (!status_.ok()) {
    status_ = CorruptFileError(status_, filename, "unable to parse header");
    return;
  }
  num_shards_ = header.num_shards();
  if ((header.endianness() == BundleHeaderProto::BIG && port::kLittleEndian) ||
      (header.endianness() == BundleHeaderProto::LITTLE &&
       !port::kLittleEndian)) {
    need_to_swap_bytes_ = true;
  }
  status_ = CheckVersions(header.version(), kTensorBundleVersion,
                          kTensorBundleMinProducer, "Checkpoint", "checkpoint");
}

BundleReader::~BundleReader() {
  delete metadata_;
  delete iter_;
  delete table_;
  // InputBuffer does not own the underlying RandomAccessFile.
  for (auto pair : data_) {
    if (pair.second != nullptr && pair.second->file() != nullptr) {
      delete pair.second->file();
    }
  }
  for (auto& temp : data_) {
    delete temp.second;
  }
  for (auto& temp : tensor_slices_) {
    delete temp.second;
  }
  data_.clear();
  tensor_slices_.clear();
}

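// Seeks to "key" in the metadata table and parses its BundleEntryProto into
// "entry", validating the stored tensor shape.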
Status BundleReader::GetBundleEntryProto(StringPiece key,
                                         BundleEntryProto* entry) {
  entry->Clear();
  TF_CHECK_OK(status_);
  Seek(key);
  if (!iter_->Valid() || iter_->key() != key) {
    return errors::NotFound("Key ", key, " not found in checkpoint");
  }

  BundleEntryProto entry_copy;
  TF_RETURN_IF_ERROR(
      ParseEntryProto(iter_->key(), iter_->value(), &entry_copy));
  if (!TensorShape::IsValid(entry_copy.shape())) {
    return errors::DataLoss("Invalid tensor shape: ", key, " ",
                            entry_copy.shape().ShortDebugString());
  }

  *entry = entry_copy;
  return Status::OK();
}

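// Reads the tensor payload described by "entry" from its data shard into
// "val", byte-swapping fixed-width types if the bundle's endianness differs
// from the host's, and verifies the stored crc32c against the restored bytes.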
Status BundleReader::GetValue(const BundleEntryProto& entry, Tensor* val) {
  Tensor* ret = val;
  const TensorShape stored_shape(TensorShape(entry.shape()));
  if (val->NumElements() == 0) {
    ret = new Tensor(entry.dtype(), stored_shape);
  }

  // Validates the "size" field.
  if (entry.dtype() != DT_STRING && entry.dtype() != DT_VARIANT) {
    if (entry.size() != ret->TotalBytes()) {
      return errors::DataLoss("Invalid size in bundle entry: key ", key(),
                              "; stored size ", entry.size(),
                              "; expected size ", ret->TotalBytes());
    }
  } else if (entry.dtype() == DT_STRING) {
    // Relaxes the check for string tensors as follows:
    //   entry.size() == bytes(varint lengths) + bytes(data)
    //                >= NumElems + bytes(data), since size bytes(varint) >= 1.
    //   TotalBytes() == sizeof(tstring) * NumElems + bytes(data)
    // Since we don't know bytes(varint lengths), we just check an inequality.
    const size_t lower_bound = ret->NumElements() + ret->TotalBytes() -
                               sizeof(tstring) * ret->NumElements();
    if (entry.size() < lower_bound) {
      return errors::DataLoss("Invalid size in bundle entry: key ", key(),
                              "; stored size ", entry.size(),
                              "; expected size is at least ", lower_bound);
    }
  }

  // Open the data file if it has not been opened.
  io::InputBuffer* buffered_file = data_[entry.shard_id()];
  if (buffered_file == nullptr) {
    std::unique_ptr<RandomAccessFile> file = nullptr;
    TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(
        DataFilename(prefix_, entry.shard_id(), num_shards_), &file));
    buffered_file = new io::InputBuffer(file.release(), kBufferSize);
    // The InputBuffer and RandomAccessFile objects are both released in dtor.
    data_[entry.shard_id()] = buffered_file;
  }
  CHECK(buffered_file != nullptr);

  TF_RETURN_IF_ERROR(buffered_file->Seek(entry.offset()));
  uint32 actual_crc32c = 0;

  if (DataTypeCanUseMemcpy(entry.dtype())) {
    char* backing_buffer = const_cast<char*>(ret->tensor_data().data());
    size_t unused_bytes_read;
    if (entry.size() > kBufferSize) {
      StringPiece sp;
      TF_RETURN_IF_ERROR(buffered_file->file()->Read(
          entry.offset(), entry.size(), &sp, backing_buffer));
      if (sp.data() != backing_buffer) {
        memmove(backing_buffer, sp.data(), entry.size());
      }
    } else {
      TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(entry.size(), backing_buffer,
                                                   &unused_bytes_read));
    }
    // Note that we compute the checksum *before* byte-swapping. The checksum
    // should be on the bytes in the order they appear in the file.
    actual_crc32c = crc32c::Value(backing_buffer, entry.size());
    if (need_to_swap_bytes_) {
      TF_RETURN_IF_ERROR(ByteSwapTensor(ret));
    }
  } else if (entry.dtype() == DT_VARIANT) {
    if (need_to_swap_bytes_) {
      return errors::Unimplemented(
          "TensorBundle at ", prefix_,
          " is of a different endianness than this machine's hardware, and "
          "the bundle contains a variant (arbitrary C++ type) tensor. "
          "Byte-swapping of variant tensors is not currently implemented.");
    }
    // Relies on io::InputBuffer's buffering, because we issue many neighboring
    // reads for a single string tensor.
    TF_RETURN_IF_ERROR(ReadVariantTensor(buffered_file, ret, entry.offset(),
                                         entry.size(), &actual_crc32c));
  } else {
    // Relies on io::InputBuffer's buffering, because we issue many neighboring
    // reads for a single string tensor.
    TF_RETURN_IF_ERROR(ReadStringTensor(
        buffered_file, ret->NumElements(), entry.offset(), entry.size(),
        GetStringBackingBuffer(*ret), &actual_crc32c, need_to_swap_bytes_));
  }
  if (crc32c::Unmask(entry.crc32c()) != actual_crc32c) {
    return errors::DataLoss(
        "Checksum does not match: stored ",
        strings::Printf("%08u", crc32c::Unmask(entry.crc32c())),
        " vs. calculated on the restored bytes ", actual_crc32c);
  }

  *val = *ret;
  if (ret != val) delete ret;
  return Status::OK();
}

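// Looks up the metadata entry for "key" and restores the tensor into "val":
// directly if the tensor was saved whole, or via GetSliceValue() if it was
// saved as slices of a partitioned tensor.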
Status BundleReader::Lookup(StringPiece key, Tensor* val) {
  CHECK(val != nullptr);
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));

  if (entry.slices().empty()) {
    return GetValue(entry, val);
  } else {
    return GetSliceValue(
        key, entry,
        /* a full slice */ TensorSlice(TensorShape(entry.shape()).dims()), val);
  }
}

Status BundleReader::ReadCurrent(Tensor* val) {
  CHECK(val != nullptr);
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(ParseEntryProto(iter_->key(), iter_->value(), &entry));
  if (!TensorShape::IsValid(entry.shape())) {
    return errors::DataLoss("Invalid tensor shape: ", iter_->key(), " ",
                            entry.shape().ShortDebugString());
  }

  if (entry.slices().empty()) {
    return GetValue(entry, val);
  } else {
    return GetSliceValue(
        iter_->key(), entry,
        /* a full slice */ TensorSlice(TensorShape(entry.shape()).dims()), val);
  }
}

Status BundleReader::LookupTensorSlices(StringPiece key,
                                        std::vector<TensorSlice>* slices) {
  slices->clear();
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));
  slices->reserve(entry.slices_size());
  for (const auto& slice : entry.slices()) {
    slices->emplace_back(slice);
  }
  return Status::OK();
}

Status BundleReader::LookupSlice(StringPiece full_tensor_key,
                                 const TensorSlice& slice_spec, Tensor* val) {
  CHECK(val != nullptr);
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(GetBundleEntryProto(full_tensor_key, &entry));
  return GetSliceValue(full_tensor_key, entry, slice_spec, val);
}

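// Restores "slice_spec" of the tensor keyed by "full_tensor_key".  Registers
// the stored slices in a per-key TensorSliceSet (built on demand), queries
// which stored slices cover the requested slice, and copies each intersection
// into "val".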
Status BundleReader::GetSliceValue(StringPiece full_tensor_key,
                                   const BundleEntryProto& full_tensor_entry,
                                   const TensorSlice& slice_spec, Tensor* val) {
  using checkpoint::RegisterTensorSlice;
  using checkpoint::TensorSliceSet;
  DCHECK_GE(full_tensor_entry.slices_size(), 0);

  const TensorShape full_shape(TensorShape(full_tensor_entry.shape()));
  std::vector<std::pair<TensorSlice, string>> details;
  const string full_tensor_key_string(full_tensor_key);
  const TensorSliceSet* tss =
      gtl::FindPtrOrNull(tensor_slices_, full_tensor_key_string);

  // Populates the "full tensor key -> TensorSliceSet" cache.
  if (tss == nullptr) {
    if (full_tensor_entry.slices().empty()) {
      // Special case: a writer has saved a tensor fully, but the reader wants
      // to read in slices.  We therefore register the full slice on-demand here
      // without further complicating the on-disk bundle format.
      TF_RETURN_IF_ERROR(RegisterTensorSlice(
          full_tensor_key_string, full_shape, full_tensor_entry.dtype(),
          /* tag */ "",
          /* full slice */ TensorSlice(full_shape.dims()), &tensor_slices_));
    }
    for (const TensorSliceProto& slice : full_tensor_entry.slices()) {
      TF_RETURN_IF_ERROR(RegisterTensorSlice(
          full_tensor_key_string, full_shape, full_tensor_entry.dtype(),
          /* tag */ "", TensorSlice(slice), &tensor_slices_));
    }
    tss = gtl::FindPtrOrNull(tensor_slices_, full_tensor_key_string);
    CHECK_NE(tss, nullptr);
  }
  if (!tss->QueryMeta(slice_spec, &details)) {
    return errors::InvalidArgument(
        "Does not have sufficient slices for partitioned tensor ",
        full_tensor_key,
        " to restore in slice_spec: ", slice_spec.DebugString());
  }

  // The union of the slices in "details" covers "slice_spec".  Performs the
  // copies from each.
  BundleEntryProto stored_slice_entry = full_tensor_entry;
  for (const auto& slice_tag_pair : details) {
    // Seeks to the stored slice.
    const TensorSlice& stored_slice = slice_tag_pair.first;

    // We already have the entry for the full tensor, so don't query again if
    // the slice is full.
    if (!stored_slice.IsFull()) {
      const string encoded_stored_slice_name =
          checkpoint::EncodeTensorNameSlice(full_tensor_key_string,
                                            stored_slice);
      status_ =
          GetBundleEntryProto(encoded_stored_slice_name, &stored_slice_entry);
      if (!status_.ok()) return status_;
    }

    // TODO(zongheng): should we take an OpKernelContext, so that we can call
    // allocate_temp()?  Note that without major refactorings to Saver, it's
    // hard for the caller of the tensor bundle module to allocate this
    // precisely-shaped scratch storage.

    // Optimization for the common case: the stored slice can be directly
    // copied to the destination without additional slicing. This is true when
    // either the slices are equal or when they are both full slices having the
    // same shape.
    TensorShape stored_slice_shape(stored_slice_entry.shape());
    if (stored_slice == slice_spec ||
        (stored_slice_shape == val->shape() &&
         IsFullSlice(stored_slice, stored_slice_shape) &&
         IsFullSlice(slice_spec, stored_slice_shape))) {
      VLOG(1) << "Optimized for common case: directly copying into "
                 "pre-allocated buffer; spec: "
              << slice_spec.DebugString();
      status_ = GetValue(stored_slice_entry, val);
      return status_;
    }

    Tensor stored_slice_tensor(stored_slice_entry.dtype(), stored_slice_shape);
    status_ = GetValue(stored_slice_entry, &stored_slice_tensor);
    if (!status_.ok()) return status_;

    // Copies the intersection over.
    const DataType common_dtype = full_tensor_entry.dtype();
    switch (common_dtype) {
#define HANDLE_COPY(T)                                                 \
  case DataTypeToEnum<T>::value:                                       \
    CHECK(CopyDataFromTensorSliceToTensorSlice(                        \
        full_shape, stored_slice, slice_spec,                          \
        stored_slice_tensor.flat<T>().data(), val->flat<T>().data())); \
    break;

      HANDLE_COPY(float)
      HANDLE_COPY(double)
      HANDLE_COPY(int32)
      HANDLE_COPY(uint8)
      HANDLE_COPY(int16)
      HANDLE_COPY(int8)
      HANDLE_COPY(complex64)
      HANDLE_COPY(complex128)
      HANDLE_COPY(int64)
      HANDLE_COPY(bool)
      HANDLE_COPY(qint32)
      HANDLE_COPY(quint8)
      HANDLE_COPY(qint8)
      HANDLE_COPY(bfloat16)
      default:
        return errors::InvalidArgument("Dtype ", DataTypeString(common_dtype),
                                       " not supported.");
    }
#undef HANDLE_COPY
  }
  return Status::OK();
}

bool BundleReader::Contains(StringPiece key) {
  Seek(key);
  return Valid() && (this->key() == key);
}

Status BundleReader::LookupDtypeAndShape(StringPiece key, DataType* dtype,
                                         TensorShape* shape) {
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));
  *dtype = entry.dtype();
  *shape = TensorShape(entry.shape());
  return Status::OK();
}

Status BundleReader::LookupTensorShape(StringPiece key, TensorShape* shape) {
  DataType ignored;
  return LookupDtypeAndShape(key, &ignored, shape);
}

string BundleReader::DebugString() {
  // Format used below emulates that of TensorSliceReader::DebugString().
  string shape_str;
  BundleEntryProto entry;
  Seek(kHeaderEntryKey);
  for (Next(); Valid(); Next()) {
    CHECK(entry.ParseFromArray(value().data(), value().size()));
    if (entry.slices_size() > 0) continue;  // Slice of some partitioned var.

    strings::StrAppend(&shape_str, key(), " (", DataType_Name(entry.dtype()),
                       ") ", TensorShape(entry.shape()).DebugString());
    strings::StrAppend(&shape_str, "\n");
  }
  return shape_str;
}

FileOutputBuffer::~FileOutputBuffer() { delete file_; }

Status FileOutputBuffer::Append(StringPiece data) {
  // In the below, it is critical to calculate the checksum on the actually
  // copied bytes, not the source bytes.  This is because "data" typically
  // points to tensor buffers, which may be concurrently written.
  if (data.size() + position_ <= buffer_size_) {
    // Can fit into the current buffer.
    memcpy(&buffer_[position_], data.data(), data.size());
    crc32c_ = crc32c::Extend(crc32c_, &buffer_[position_], data.size());
  } else if (data.size() <= buffer_size_) {
    // Cannot fit, but can fit after flushing.
    TF_RETURN_IF_ERROR(FlushBuffer());
    memcpy(&buffer_[0], data.data(), data.size());
    crc32c_ = crc32c::Extend(crc32c_, &buffer_[0], data.size());
  } else {
    // Cannot fit even after flushing.  So we break down "data" by chunk, and
    // flush/checksum each chunk.
    TF_RETURN_IF_ERROR(FlushBuffer());
    for (size_t i = 0; i < data.size(); i += buffer_size_) {
      const size_t nbytes = std::min(data.size() - i, buffer_size_);
      memcpy(&buffer_[0], data.data() + i, nbytes);
      crc32c_ = crc32c::Extend(crc32c_, &buffer_[0], nbytes);
      position_ = nbytes;
      TF_RETURN_IF_ERROR(FlushBuffer());
    }
    return Status::OK();
  }
  position_ += data.size();
  return Status::OK();
}

Status FileOutputBuffer::Close() {
  TF_RETURN_IF_ERROR(FlushBuffer());
  return file_->Close();
}

Status FileOutputBuffer::FlushBuffer() {
  if (position_ > 0) {
    TF_RETURN_IF_ERROR(file_->Append(StringPiece(&buffer_[0], position_)));
    position_ = 0;
  }
  return Status::OK();
}

}  // namespace tensorflow
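
// A minimal usage sketch of the classes defined above (illustrative only:
// error handling is elided, and the checkpoint prefix "/tmp/ckpt" is an
// arbitrary example path, not one used by this module):
//
//   BundleWriter writer(Env::Default(), "/tmp/ckpt");
//   Tensor t(DT_FLOAT, TensorShape({2}));
//   t.flat<float>().setZero();
//   TF_CHECK_OK(writer.Add("var0", t));
//   TF_CHECK_OK(writer.Finish());
//
//   BundleReader reader(Env::Default(), "/tmp/ckpt");
//   TF_CHECK_OK(reader.status());
//   Tensor restored;
//   TF_CHECK_OK(reader.Lookup("var0", &restored));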