• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7 http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
17 
18 #include <algorithm>
19 #include <cstdlib>
20 #include <cstring>
21 #include <memory>
22 #include <utility>
23 
24 #include "tensorflow/core/framework/register_types.h"
25 #include "tensorflow/core/framework/tensor.pb.h"
26 #include "tensorflow/core/framework/tensor_shape.pb.h"
27 #include "tensorflow/core/framework/types.h"
28 #include "tensorflow/core/framework/types.pb.h"
29 #include "tensorflow/core/framework/variant.h"
30 #include "tensorflow/core/framework/variant_op_registry.h"
31 #include "tensorflow/core/framework/variant_tensor_data.h"
32 #include "tensorflow/core/framework/versions.h"
33 #include "tensorflow/core/framework/versions.pb.h"
34 #include "tensorflow/core/lib/core/coding.h"
35 #include "tensorflow/core/lib/core/errors.h"
36 #include "tensorflow/core/lib/gtl/map_util.h"
37 #include "tensorflow/core/lib/hash/crc32c.h"
38 #include "tensorflow/core/lib/io/path.h"
39 #include "tensorflow/core/lib/io/table_builder.h"
40 #include "tensorflow/core/lib/random/random.h"
41 #include "tensorflow/core/lib/strings/str_util.h"
42 #include "tensorflow/core/lib/strings/stringprintf.h"
43 #include "tensorflow/core/platform/bfloat16.h"
44 #include "tensorflow/core/platform/errors.h"
45 #include "tensorflow/core/util/env_var.h"
46 #include "tensorflow/core/util/saved_tensor_slice_util.h"
47 #include "tensorflow/core/util/tensor_bundle/byte_swap.h"
48 #include "tensorflow/core/util/tensor_slice_util.h"
49 
50 namespace tensorflow {
51 
// Version numbers of the tensor bundle format.  BundleWriter::Finish() stamps
// kTensorBundleVersion (as producer) and kTensorBundleMinConsumer into the
// header entry of every metadata file it writes.
const int kTensorBundleMinProducer = 0;
const int kTensorBundleMinConsumer = 0;
const int kTensorBundleVersion = 1;

// Size in bytes (1 MiB) of our input buffer for streaming reads.
static const int kBufferSize = 1 << 20;

// Key to the special BundleHeaderProto entry.  Do not change this, as clients
// can make the assumption that the header is always the first entry in the
// bundle.
const char* const kHeaderEntryKey = "";
64 
65 namespace {
66 
67 // Reads "num_elements" string elements from file[offset, offset+size) into the
68 // length-N "destination".  Discards the original content of "destination".
69 //
70 // Checksums the string lengths (as restored uint32 or uint64, not varint64
71 // bytes) and string bytes, and stores it into "actual_crc32c".
ReadStringTensor(io::InputBuffer * buffered_file,size_t num_elements,size_t offset,size_t size,tstring * destination,uint32 * actual_crc32c,bool need_to_swap_bytes)72 Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements,
73                         size_t offset, size_t size, tstring* destination,
74                         uint32* actual_crc32c, bool need_to_swap_bytes) {
75   if (size == 0) return Status::OK();
76   CHECK_GT(size, 0);
77 
78   // Reads "num_elements" varint64's from "buffered_file".
79   TF_RETURN_IF_ERROR(buffered_file->Seek(offset));
80   TF_RETURN_IF_ERROR(buffered_file->Hint(size));
81   std::vector<uint64> string_lengths(num_elements);
82   for (size_t i = 0; i < num_elements; ++i) {
83     TF_RETURN_IF_ERROR(buffered_file->ReadVarint64(&string_lengths[i]));
84     if (string_lengths[i] <= UINT32_MAX) {
85       // We need to do this because older checkpoints only used uint32s and we
86       // should still support them.
87       uint32 elem_size_uint32 = static_cast<uint32>(string_lengths[i]);
88       if (need_to_swap_bytes) {
89         // Checksum would have been computed on the source machine's byte order
90         elem_size_uint32 = BYTE_SWAP_32(elem_size_uint32);
91       }
92       *actual_crc32c = crc32c::Extend(
93           *actual_crc32c, reinterpret_cast<const char*>(&elem_size_uint32),
94           sizeof(uint32));
95     } else {
96       uint64 length = string_lengths[i];
97       if (need_to_swap_bytes) {
98         length = BYTE_SWAP_64(length);
99       }
100       *actual_crc32c =
101           crc32c::Extend(*actual_crc32c, reinterpret_cast<const char*>(&length),
102                          sizeof(uint64));
103     }
104   }
105   if (offset + size < buffered_file->Tell()) {
106     return errors::DataLoss("String lengths longer than expected offset ",
107                             offset + size);
108   }
109 
110   // Reads the length-checksum.
111   uint32 raw_length_checksum = 0;  // Bytes in file
112   uint32 length_checksum = 0;      // In-memory representation
113   size_t unused_bytes_read = 0;
114   TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(
115       sizeof(uint32), reinterpret_cast<char*>(&raw_length_checksum),
116       &unused_bytes_read));
117   length_checksum = need_to_swap_bytes ? BYTE_SWAP_32(raw_length_checksum)
118                                        : raw_length_checksum;
119   if (crc32c::Unmask(length_checksum) != *actual_crc32c) {
120     return errors::DataLoss(
121         "The length checksum does not match: expected ",
122         strings::Printf("%08u", crc32c::Unmask(length_checksum)),
123         " but actual is ", strings::Printf("%08u", *actual_crc32c));
124   }
125   *actual_crc32c = crc32c::Extend(*actual_crc32c,
126                                   reinterpret_cast<char*>(&raw_length_checksum),
127                                   sizeof(uint32));
128 
129   // Reads the actual string bytes.
130   for (size_t i = 0; i < num_elements; ++i) {
131     const uint64 string_length = string_lengths[i];
132     tstring* buffer = &destination[i];
133 
134     buffer->resize(string_length);
135     size_t bytes_read = 0;
136     TF_RETURN_IF_ERROR(
137         buffered_file->ReadNBytes(string_length, &(*buffer)[0], &bytes_read));
138     *actual_crc32c = crc32c::Extend(*actual_crc32c, buffer->data(), bytes_read);
139   }
140   return Status::OK();
141 }
142 
ReadVariantTensor(io::InputBuffer * buffered_file,Tensor * ret,size_t offset,size_t size,uint32 * actual_crc32c)143 Status ReadVariantTensor(io::InputBuffer* buffered_file, Tensor* ret,
144                          size_t offset, size_t size, uint32* actual_crc32c) {
145   // On-disk format:
146   //   [varint64 len1][bytes variant1][4 byte checksum]
147   //   ..
148   //   [varint64 lenN][bytes variantN][4 byte checksum]
149   // Var "crc32c" checksums all the lens, variant bytes, individual variant
150   // checksums (as uint32, not varint32 bytes).
151   if (size == 0) return Status::OK();
152   size_t num_elements = ret->NumElements();
153 
154   // Reads the actual string bytes.
155   TF_RETURN_IF_ERROR(buffered_file->Seek(offset));
156   TF_RETURN_IF_ERROR(buffered_file->Hint(size));
157   for (size_t i = 0; i < num_elements; ++i) {
158     // Read the serialized variant length.
159     uint64 string_length = 0;
160     TF_RETURN_IF_ERROR(buffered_file->ReadVarint64(&string_length));
161     *actual_crc32c = crc32c::Extend(
162         *actual_crc32c, reinterpret_cast<const char*>(&string_length),
163         sizeof(uint64));
164     // Read the actual serialized variant.
165     string buffer;
166     buffer.resize(string_length);
167     size_t bytes_read = 0;
168     TF_RETURN_IF_ERROR(
169         buffered_file->ReadNBytes(string_length, &buffer[0], &bytes_read));
170     *actual_crc32c = crc32c::Extend(*actual_crc32c, buffer.data(), bytes_read);
171     VariantTensorDataProto proto;
172     if (!proto.ParseFromString(buffer)) {
173       return errors::DataLoss("Unable to parse VariantTensorDataProto from ",
174                               "buffer of size ", string_length, ". ",
175                               "Bundle entry offset: ", offset, " size: ", size);
176     }
177     Variant v = proto;
178     if (!DecodeUnaryVariant(&v)) {
179       return errors::Internal("Could not decode variant with type_name: \"",
180                               v.TypeName(), "\".  Perhaps you forgot to ",
181                               "register a decoder via ",
182                               "REGISTER_UNARY_VARIANT_DECODE_FUNCTION?");
183     }
184 
185     // Read the checksum.
186     uint32 checksum = 0;
187     size_t unused_bytes_read = 0;
188     TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(
189         sizeof(uint32), reinterpret_cast<char*>(&checksum),
190         &unused_bytes_read));
191     if (crc32c::Unmask(checksum) != *actual_crc32c) {
192       return errors::DataLoss(
193           "The checksum after Variant ", i, " does not match.",
194           " Expected: ", strings::Printf("%08u", crc32c::Unmask(checksum)),
195           " Actual: ", strings::Printf("%08u", *actual_crc32c));
196     }
197     *actual_crc32c = crc32c::Extend(
198         *actual_crc32c, reinterpret_cast<char*>(&checksum), sizeof(uint32));
199 
200     ret->flat<Variant>()(i) = std::move(v);
201   }
202 
203   return Status::OK();
204 }
205 
GetBackingBuffer(const Tensor & val)206 char* GetBackingBuffer(const Tensor& val) {
207   CHECK(DataTypeCanUseMemcpy(val.dtype())) << val.dtype();
208   return const_cast<char*>(val.tensor_data().data());
209 }
210 
GetStringBackingBuffer(const Tensor & val)211 tstring* GetStringBackingBuffer(const Tensor& val) {
212   CHECK_EQ(DT_STRING, val.dtype());
213   return const_cast<tstring*>(val.flat<tstring>().data());
214 }
215 
ParseEntryProto(StringPiece key,StringPiece value,protobuf::MessageLite * out)216 Status ParseEntryProto(StringPiece key, StringPiece value,
217                        protobuf::MessageLite* out) {
218   if (!out->ParseFromArray(value.data(), value.size())) {
219     return errors::DataLoss("Entry for key ", key, " not parseable.");
220   }
221   return Status::OK();
222 }
223 
224 // Serializes the data bytes of the non-string tensor "val".  Discards the
225 // original content of "bytes_written", and on OK updates it with number of
226 // bytes written.
227 // REQUIRES: val.dtype() != DT_STRING
WriteTensor(const Tensor & val,FileOutputBuffer * out,size_t * bytes_written)228 Status WriteTensor(const Tensor& val, FileOutputBuffer* out,
229                    size_t* bytes_written) {
230   DCHECK_NE(val.dtype(), DT_STRING);
231   DCHECK_NE(val.dtype(), DT_VARIANT);
232   *bytes_written = val.TotalBytes();
233   char* buf = GetBackingBuffer(val);
234   VLOG(1) << "Appending " << *bytes_written << " bytes to file";
235   return out->Append(StringPiece(buf, *bytes_written));
236 }
237 
238 // Serializes string tensor "val".  "bytes_written" is treated in the same
239 // fashion as WriteTensor().
240 //
241 // Checksums all bytes written and stores it into "crc32c".
242 // REQUIRES: val.dtype() == DT_STRING
WriteStringTensor(const Tensor & val,FileOutputBuffer * out,size_t * bytes_written,uint32 * crc32c)243 Status WriteStringTensor(const Tensor& val, FileOutputBuffer* out,
244                          size_t* bytes_written, uint32* crc32c) {
245   // On-disk format:
246   //   [varint64 len0]..[varint64 lenL][4 byte cksum on lengths][string bytes]
247   // Var "crc32c" checksums the string lengths (as uint64, not varint64 bytes),
248   // the length-checksum, and all the string bytes.
249   DCHECK_EQ(val.dtype(), DT_STRING);
250   const tstring* strings = GetStringBackingBuffer(val);
251 
252   // Writes the varint lengths.
253   string lengths;
254   lengths.reserve(val.NumElements());  // At least 1 byte per element.
255   *crc32c = 0;
256   for (int64 i = 0; i < val.NumElements(); ++i) {
257     const tstring* elem = &strings[i];
258     DCHECK_EQ(elem->size(), static_cast<uint64>(elem->size()));
259     const uint64 elem_size = static_cast<uint64>(elem->size());
260 
261     core::PutVarint64(&lengths, elem_size);
262     if (elem_size <= UINT32_MAX) {
263       // We need to do this because older checkpoints only used uint32s and we
264       // should still support them.
265       const uint32 elem_size_uint32 = static_cast<uint32>(elem_size);
266       *crc32c = crc32c::Extend(*crc32c,
267                                reinterpret_cast<const char*>(&elem_size_uint32),
268                                sizeof(uint32));
269     } else {
270       *crc32c = crc32c::Extend(
271           *crc32c, reinterpret_cast<const char*>(&elem_size), sizeof(uint64));
272     }
273   }
274   TF_RETURN_IF_ERROR(out->Append(lengths));
275   *bytes_written = lengths.size();
276 
277   // Writes the length checksum.
278   const uint32 length_checksum = crc32c::Mask(*crc32c);
279   TF_RETURN_IF_ERROR(out->Append(StringPiece(
280       reinterpret_cast<const char*>(&length_checksum), sizeof(uint32))));
281   *crc32c = crc32c::Extend(
282       *crc32c, reinterpret_cast<const char*>(&length_checksum), sizeof(uint32));
283   *bytes_written += sizeof(uint32);
284 
285   // Writes all the string bytes out.
286   for (int64 i = 0; i < val.NumElements(); ++i) {
287     const tstring* string = &strings[i];
288     TF_RETURN_IF_ERROR(out->Append(*string));
289     *bytes_written += string->size();
290     *crc32c = crc32c::Extend(*crc32c, string->data(), string->size());
291   }
292   return Status::OK();
293 }
294 
WriteVariantTensor(const Tensor & val,FileOutputBuffer * out,size_t * bytes_written,uint32 * crc32c)295 Status WriteVariantTensor(const Tensor& val, FileOutputBuffer* out,
296                           size_t* bytes_written, uint32* crc32c) {
297   // On-disk format:
298   //   [varint64 len1][bytes variant1][4 byte checksum]
299   //   ..
300   //   [varint64 lenN][bytes variantN][4 byte checksum]
301   // Var "crc32c" checksums all the lens, variant bytes, individual variant
302   // checksums (as uint32, not varint32 bytes).
303   DCHECK_EQ(val.dtype(), DT_VARIANT);
304 
305   *crc32c = 0;
306   *bytes_written = 0;
307   for (int64 i = 0; i < val.NumElements(); ++i) {
308     VariantTensorData data;
309     val.flat<Variant>()(i).Encode(&data);
310     VariantTensorDataProto proto;
311     data.ToProto(&proto);
312     string elem;
313     if (!proto.SerializeToString(&elem)) {
314       return errors::Unknown(
315           "Failed to serialize tensor data of size ", proto.ByteSizeLong(),
316           ". Tensor: ", val.flat<Variant>()(i).DebugString());
317     }
318 
319     // Write the length of the serialized variant.
320     DCHECK_EQ(elem.size(), static_cast<uint64>(elem.size()));
321     const auto elem_size = static_cast<uint64>(elem.size());
322     string len;
323     core::PutVarint64(&len, elem_size);
324     TF_RETURN_IF_ERROR(out->Append(len));
325     *crc32c = crc32c::Extend(*crc32c, reinterpret_cast<const char*>(&elem_size),
326                              sizeof(uint64));
327     *bytes_written += len.size();
328 
329     // Write the serialized variant.
330     TF_RETURN_IF_ERROR(out->Append(elem));
331     *crc32c = crc32c::Extend(*crc32c, elem.data(), elem.size());
332     *bytes_written += elem.size();
333 
334     // Write the checksum.
335     const uint32 length_checksum = crc32c::Mask(*crc32c);
336     TF_RETURN_IF_ERROR(out->Append(StringPiece(
337         reinterpret_cast<const char*>(&length_checksum), sizeof(uint32))));
338     *crc32c =
339         crc32c::Extend(*crc32c, reinterpret_cast<const char*>(&length_checksum),
340                        sizeof(uint32));
341     *bytes_written += sizeof(uint32);
342   }
343 
344   return Status::OK();
345 }
346 
347 // Returns whether "slice_spec" is a full slice, with respect to the full shape.
348 //
349 // This can happen say, when "slice_spec" is
350 // "TensorSlice(full_tensor_shape.dims())", or when it is "TensorSlice({{0,
351 // dim(0)}, ..., {0, dim(N)}})" -- a degenerate case we need to guard against.
IsFullSlice(const TensorSlice & slice_spec,const TensorShape & full_tensor_shape)352 bool IsFullSlice(const TensorSlice& slice_spec,
353                  const TensorShape& full_tensor_shape) {
354   if (slice_spec.IsFull()) {
355     return true;
356   } else {
357     TensorShape sliced_shape;
358     slice_spec.SliceTensorShape(full_tensor_shape, &sliced_shape).IgnoreError();
359     return sliced_shape == full_tensor_shape;
360   }
361 }
362 
// Builds the error returned when a bundle file cannot be read.
//
// The message names "filename" and "detail".  If "in_status" is OK the result
// is an Internal error; otherwise the result keeps the code of "in_status" and
// appends its message after the description.
Status CorruptFileError(const Status& in_status, const string& filename,
                        const string& detail) {
  const string description = strings::StrCat(
      "Unable to read file (", filename,
      "). Perhaps the file is corrupt or was produced by a "
      "newer version of TensorFlow with format changes (",
      detail, ")");
  if (in_status.ok()) {
    return errors::Internal(description);
  }
  return Status(in_status.code(),
                strings::StrCat(description, ": ", in_status.error_message()));
}
379 
TableBuilderOptions()380 table::Options TableBuilderOptions() {
381   table::Options o;
382   // Compressed tables cannot be read by TensorFlow releases prior to 1.1.
383   // To smoothen the transition, compressed writes are disabled for now
384   // (version 1.2) with the intention that they will be enabled again at
385   // some point (perhaps the 1.3 release?).
386   o.compression = table::kNoCompression;
387   return o;
388 }
389 
390 // Writes zeros to output buffer to align the next write to the requested
391 // alignment. "size" is the current size of the buffer and is updated to the
392 // new size.
PadAlignment(FileOutputBuffer * out,int alignment,int64 * size)393 Status PadAlignment(FileOutputBuffer* out, int alignment, int64* size) {
394   int bytes_over = *size % alignment;
395   if (bytes_over == 0) {
396     return Status::OK();
397   }
398   int bytes_to_write = alignment - bytes_over;
399   Status status = out->Append(string(bytes_to_write, '\0'));
400   if (status.ok()) {
401     *size += bytes_to_write;
402   }
403   return status;
404 }
405 
406 }  // namespace
407 
BundleWriter(Env * env,StringPiece prefix,const Options & options)408 BundleWriter::BundleWriter(Env* env, StringPiece prefix, const Options& options)
409     : env_(env),
410       options_(options),
411       prefix_(prefix),
412       out_(nullptr),
413       size_(0) {
414   status_ = env_->HasAtomicMove(prefix_, &use_temp_file_);
415   if (!status_.ok()) return;
416 
417   data_path_ = DataFilename(prefix_, 0, 1);
418   metadata_path_ = MetaFilename(prefix_);
419   if (use_temp_file_) {
420     data_path_ = strings::StrCat(data_path_, ".tempstate", random::New64());
421     metadata_path_ =
422         strings::StrCat(metadata_path_, ".tempstate", random::New64());
423   }
424 
425   status_ = env_->CreateDir(string(io::Dirname(prefix_)));
426   if (!status_.ok() && !errors::IsAlreadyExists(status_)) {
427     return;
428   }
429 
430   std::unique_ptr<WritableFile> wrapper;
431   status_ = env_->NewWritableFile(data_path_, &wrapper);
432   if (!status_.ok()) return;
433   out_ = std::unique_ptr<FileOutputBuffer>(
434       new FileOutputBuffer(wrapper.release(), 8 << 20 /* 8MB write buffer */));
435 
436   VLOG(1) << "Writing to file " << data_path_;
437 }
438 
Add(StringPiece key,const Tensor & val)439 Status BundleWriter::Add(StringPiece key, const Tensor& val) {
440   if (!status_.ok()) return status_;
441   CHECK_NE(key, kHeaderEntryKey);
442   const string key_string(key);
443   if (entries_.find(key_string) != entries_.end()) {
444     status_ = errors::InvalidArgument("Adding duplicate key: ", key);
445     return status_;
446   }
447 
448   BundleEntryProto* entry = &entries_[key_string];
449   entry->set_dtype(val.dtype());
450   val.shape().AsProto(entry->mutable_shape());
451   entry->set_shard_id(0);
452   entry->set_offset(size_);
453 
454   // Updates the data file.
455   size_t data_bytes_written = 0;
456   uint32 crc32c = 0;
457   out_->clear_crc32c();
458   if (val.dtype() == DT_STRING) {
459     status_ = WriteStringTensor(val, out_.get(), &data_bytes_written, &crc32c);
460   } else if (val.dtype() == DT_VARIANT) {
461     status_ = WriteVariantTensor(val, out_.get(), &data_bytes_written, &crc32c);
462   } else {
463     status_ = WriteTensor(val, out_.get(), &data_bytes_written);
464     crc32c = out_->crc32c();
465   }
466 
467   if (status_.ok()) {
468     entry->set_size(data_bytes_written);
469     entry->set_crc32c(crc32c::Mask(crc32c));
470     size_ += data_bytes_written;
471     status_ = PadAlignment(out_.get(), options_.data_alignment, &size_);
472   }
473   return status_;
474 }
475 
AddSlice(StringPiece full_tensor_key,const TensorShape & full_tensor_shape,const TensorSlice & slice_spec,const Tensor & slice_tensor)476 Status BundleWriter::AddSlice(StringPiece full_tensor_key,
477                               const TensorShape& full_tensor_shape,
478                               const TensorSlice& slice_spec,
479                               const Tensor& slice_tensor) {
480   if (!status_.ok()) return status_;
481   CHECK_NE(full_tensor_key, kHeaderEntryKey);
482 
483   // If just a singleton full slice, use the regular Add() to be more efficient.
484   if (IsFullSlice(slice_spec, full_tensor_shape)) {
485     return Add(full_tensor_key, slice_tensor);
486   }
487 
488   // Inserts/updates the full tensor's metadata entry.
489   //
490   // In the case of a sharded save, MergeBundles() is responsible for merging
491   // the "slices" field of multiple metadata entries corresponding to the same
492   // full tensor.
493   const string full_tensor_key_string(full_tensor_key);
494   BundleEntryProto* full_entry = &entries_[full_tensor_key_string];
495   if (full_entry->dtype() != DT_INVALID) {
496     CHECK_EQ(full_entry->dtype(), slice_tensor.dtype());
497   }
498   if (full_entry->has_shape()) {
499     CHECK(TensorShape(full_entry->shape()) == full_tensor_shape);
500   }
501 
502   // Populates dtype, shape, and slices.  Intentionally leaving out shard_id and
503   // offset, which do not make sense for this full tensor entry.
504   full_entry->set_dtype(slice_tensor.dtype());
505   full_tensor_shape.AsProto(full_entry->mutable_shape());
506   TensorSliceProto* slice_proto = full_entry->add_slices();
507   slice_spec.AsProto(slice_proto);
508 
509   // The slice itself is handled by a regular Add(), which includes adding its
510   // own metadata entry, and writing out the slice's values.
511   const string slice_name =
512       checkpoint::EncodeTensorNameSlice(full_tensor_key_string, slice_spec);
513   status_ = Add(slice_name, slice_tensor);
514   return status_;
515 }
516 
517 // TODO(zongheng): on metadata write failure or !status_.ok(), consider removing
518 // the orphaned data file.
Finish()519 Status BundleWriter::Finish() {
520   if (out_) {
521     status_.Update(out_->Close());
522     out_ = nullptr;
523     if (status_.ok()) {
524       if (use_temp_file_) {
525         status_ =
526             Env::Default()->RenameFile(data_path_, DataFilename(prefix_, 0, 1));
527       }
528     } else {
529       Env::Default()->DeleteFile(data_path_).IgnoreError();
530     }
531   }
532   if (!status_.ok()) return status_;
533   // Build key -> BundleEntryProto table.
534   std::unique_ptr<WritableFile> file;
535   status_ = env_->NewWritableFile(metadata_path_, &file);
536   if (!status_.ok()) return status_;
537   {
538     // N.B.: the default use of Snappy compression may not be supported on all
539     // platforms (e.g. Android).  The metadata file is small, so this is fine.
540     table::Options options;
541     options.compression = table::kNoCompression;
542     table::TableBuilder builder(options, file.get());
543     // Header entry.
544     BundleHeaderProto header;
545     header.set_num_shards(1);
546     header.set_endianness(BundleHeaderProto::LITTLE);
547     if (!port::kLittleEndian) header.set_endianness(BundleHeaderProto::BIG);
548     VersionDef* version = header.mutable_version();
549     version->set_producer(kTensorBundleVersion);
550     version->set_min_consumer(kTensorBundleMinConsumer);
551 
552     builder.Add(kHeaderEntryKey, header.SerializeAsString());
553 
554     // All others.
555     for (const auto& p : entries_) {
556       builder.Add(p.first, p.second.SerializeAsString());
557     }
558     status_ = builder.Finish();
559   }
560   status_.Update(file->Close());
561   if (!status_.ok()) {
562     Env::Default()->DeleteFile(metadata_path_).IgnoreError();
563     return status_;
564   } else if (use_temp_file_) {
565     status_ = Env::Default()->RenameFile(metadata_path_, MetaFilename(prefix_));
566     if (!status_.ok()) return status_;
567   }
568   status_ = errors::Internal("BundleWriter is closed");
569   return Status::OK();
570 }
571 
572 // Merging tensor bundles.
573 
574 // Accumulator of metadata states during a merge.
575 struct MergeState {
576   // Accumulated from the header entries.
577   int num_shards = 0;
578 
579   // Derives "endianness" and "version" from the first bundle merged (hence the
580   // "seen_first_bundle" guard).  The two fields must be the same for all
581   // bundles in a merge.
582   bool seen_first_bundle = false;
583   BundleHeaderProto_Endianness endianness;
584   VersionDef version;
585 
586   // Tensor key -> BundleEntryProto.
587   std::map<string, BundleEntryProto> entries;
588   // Data file path -> new shard id in the final merged bundle.
589   std::unordered_map<string, int32> shard_ids;
590 };
591 
592 // Merges entries of "prefix" into the accumulator state "merge".
593 // Returns OK iff the merge succeeds.
MergeOneBundle(Env * env,StringPiece prefix,MergeState * merge_state)594 static Status MergeOneBundle(Env* env, StringPiece prefix,
595                              MergeState* merge_state) {
596   VLOG(1) << "Merging bundle:" << prefix;
597   const string filename = MetaFilename(prefix);
598   uint64 file_size;
599   TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));
600   std::unique_ptr<RandomAccessFile> file;
601   TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));
602 
603   table::Table* table = nullptr;
604   TF_RETURN_IF_ERROR(
605       table::Table::Open(TableBuilderOptions(), file.get(), file_size, &table));
606   std::unique_ptr<table::Table> table_deleter(table);
607   std::unique_ptr<table::Iterator> iter(table->NewIterator());
608 
609   int num_shards;
610   // Process header.
611   {
612     iter->Seek(kHeaderEntryKey);
613     if (!iter->Valid()) {
614       return CorruptFileError(iter->status(), filename,
615                               "failed to seek to header entry");
616     }
617     BundleHeaderProto header;
618     Status s = ParseEntryProto(iter->key(), iter->value(), &header);
619     if (!s.ok()) return CorruptFileError(s, filename, "unable to parse header");
620 
621     merge_state->num_shards += header.num_shards();
622     if (!merge_state->seen_first_bundle) {
623       merge_state->seen_first_bundle = true;
624       merge_state->endianness = header.endianness();
625       merge_state->version = header.version();
626     } else {
627       // Validates "endianness".
628       if (merge_state->endianness != header.endianness()) {
629         return errors::InvalidArgument(
630             "Merging bundles with conflicting endianness; inputs corrupted?");
631       }
632       // Validates "version".
633       string curr_version, merge_version;
634       header.version().SerializeToString(&curr_version);
635       merge_state->version.SerializeToString(&merge_version);
636       if (curr_version != merge_version) {
637         return errors::InvalidArgument(
638             "Merging bundles with different format versions: merged ",
639             merge_version, " vs. curr ", curr_version);
640       }
641     }
642     num_shards = header.num_shards();
643     iter->Next();
644   }
645 
646   // Loops through the non-header to-merge entries.
647   BundleEntryProto to_merge_entry;
648   for (; iter->Valid(); iter->Next()) {
649     const string key(iter->key());
650     const auto entry_iter = merge_state->entries.find(key);
651 
652     // Illegal: the duplicated entry is a non-slice tensor.
653     if (entry_iter != merge_state->entries.end() &&
654         entry_iter->second.slices().empty()) {
655       return errors::InvalidArgument(
656           "Duplicate tensor keyed by ", key,
657           " encountered, when merging prefix: ", prefix);
658     }
659 
660     TF_RETURN_IF_ERROR(
661         ParseEntryProto(iter->key(), iter->value(), &to_merge_entry));
662 
663     // The duplicated entry holds metadata for a sliced full tensor.
664     // Allows the duplication and merges "slices".
665     if (entry_iter != merge_state->entries.end()) {
666       BundleEntryProto& existing_entry = entry_iter->second;
667       if (to_merge_entry.slices().empty()) {
668         return errors::Internal(
669             "Duplicate tensor keyed by ", key,
670             "; attempting to merge in a non-slice bundle entry");
671       }
672       // Only needs merge the "slices" field (and validate dtype/shape).
673       for (int i = 0; i < to_merge_entry.slices_size(); ++i) {
674         TensorSliceProto* slot = existing_entry.add_slices();
675         *slot = to_merge_entry.slices(i);
676       }
677       CHECK_EQ(existing_entry.dtype(), to_merge_entry.dtype());
678       CHECK(TensorShape(existing_entry.shape()) ==
679             TensorShape(to_merge_entry.shape()));
680       continue;
681     }
682 
683     // Key doesn't duplicate: a fresh tensor/slice entry.
684     auto result = merge_state->shard_ids.insert(
685         {DataFilename(prefix, to_merge_entry.shard_id(), num_shards),
686          merge_state->shard_ids.size()});
687     to_merge_entry.set_shard_id(result.first->second);
688     merge_state->entries[key] = to_merge_entry;
689   }
690   return Status::OK();
691 }
692 
MergeBundles(Env * env,gtl::ArraySlice<tstring> prefixes,StringPiece merged_prefix)693 Status MergeBundles(Env* env, gtl::ArraySlice<tstring> prefixes,
694                     StringPiece merged_prefix) {
695   // Merges all metadata tables.
696   // TODO(zhifengc): KeyValue sorter if it becomes too big.
697   MergeState merge;
698   Status status = env->CreateDir(string(io::Dirname(merged_prefix)));
699   if (!status.ok() && !errors::IsAlreadyExists(status)) return status;
700   for (int i = 0; i < prefixes.size(); ++i) {
701     TF_RETURN_IF_ERROR(MergeOneBundle(env, prefixes[i], &merge));
702   }
703 
704   // Renames data files to contain the merged bundle prefix.
705   for (const auto& p : merge.shard_ids) {
706     VLOG(1) << "Renaming " << p.first << " to "
707             << DataFilename(merged_prefix, p.second, merge.shard_ids.size());
708     TF_RETURN_IF_ERROR(env->RenameFile(
709         p.first,
710         DataFilename(merged_prefix, p.second, merge.shard_ids.size())));
711   }
712 
713   // Writes the final metadata table under the merged prefix.
714   std::unique_ptr<WritableFile> merged_metadata;
715   TF_RETURN_IF_ERROR(
716       env->NewWritableFile(MetaFilename(merged_prefix), &merged_metadata));
717   {
718     table::TableBuilder builder(TableBuilderOptions(), merged_metadata.get());
719     // Header entry.
720     BundleHeaderProto header;
721     header.set_num_shards(merge.num_shards);
722     header.set_endianness(merge.endianness);
723     *header.mutable_version() = merge.version;
724     builder.Add(kHeaderEntryKey, header.SerializeAsString());
725     // All others.
726     for (const auto& p : merge.entries) {
727       builder.Add(p.first, p.second.SerializeAsString());
728     }
729     status = builder.Finish();
730   }
731   status.Update(merged_metadata->Close());
732   if (!status.ok()) return status;
733   VLOG(1) << "Merged bundles to:" << merged_prefix;
734 
735   // Cleanup: best effort based and ignores errors.
736   for (const tstring& prefix : prefixes) {
737     env->DeleteFile(MetaFilename(prefix)).IgnoreError();
738   }
739   return status;
740 }
741 
742 // Interface for reading a tensor bundle.
743 
BundleReader(Env * env,StringPiece prefix)744 BundleReader::BundleReader(Env* env, StringPiece prefix)
745     : env_(env),
746       prefix_(prefix),
747       metadata_(nullptr),
748       table_(nullptr),
749       index_cache_(nullptr),
750       iter_(nullptr),
751       need_to_swap_bytes_(false) {
752   const string filename = MetaFilename(prefix_);
753   uint64 file_size;
754   status_ = env_->GetFileSize(filename, &file_size);
755   if (!status_.ok()) return;
756 
757   // Opens the metadata table.
758   std::unique_ptr<RandomAccessFile> wrapper;
759   status_ = env_->NewRandomAccessFile(filename, &wrapper);
760   if (!status_.ok()) return;
761   metadata_ = wrapper.release();
762 
763   table::Options o;
764   int64 cache_size;
765   Status s =
766       ReadInt64FromEnvVar("TF_TABLE_INDEX_CACHE_SIZE_IN_MB", 0, &cache_size);
767   if (s.ok() && cache_size > 0) {
768     index_cache_ = table::NewLRUCache(cache_size << 20);
769     o.block_cache = index_cache_;
770   }
771 
772   status_ = table::Table::Open(o, metadata_, file_size, &table_);
773   if (!status_.ok()) return;
774   iter_ = table_->NewIterator();
775 
776   // Reads "num_shards_" from the first entry.
777   iter_->Seek(kHeaderEntryKey);
778   if (!iter_->Valid()) {
779     status_ = CorruptFileError(iter_->status(), filename,
780                                "failed to seek to header entry");
781     return;
782   }
783   BundleHeaderProto header;
784   status_ = ParseEntryProto(iter_->key(), iter_->value(), &header);
785   if (!status_.ok()) {
786     status_ = CorruptFileError(status_, filename, "unable to parse header");
787     return;
788   }
789   num_shards_ = header.num_shards();
790   if ((header.endianness() == BundleHeaderProto::BIG && port::kLittleEndian) ||
791       (header.endianness() == BundleHeaderProto::LITTLE &&
792        !port::kLittleEndian)) {
793     need_to_swap_bytes_ = true;
794   }
795   status_ = CheckVersions(header.version(), kTensorBundleVersion,
796                           kTensorBundleMinProducer, "Checkpoint", "checkpoint");
797 }
798 
~BundleReader()799 BundleReader::~BundleReader() {
800   delete metadata_;
801   delete iter_;
802   delete table_;
803   if (index_cache_) {
804     delete index_cache_;
805   }
806   // InputBuffer does not own the underlying RandomAccessFile.
807   for (auto pair : data_) {
808     if (pair.second != nullptr && pair.second->file() != nullptr) {
809       delete pair.second->file();
810     }
811   }
812   for (auto& temp : data_) {
813     delete temp.second;
814   }
815   for (auto& temp : tensor_slices_) {
816     delete temp.second;
817   }
818   data_.clear();
819   tensor_slices_.clear();
820 }
821 
GetBundleEntryProto(StringPiece key,BundleEntryProto * entry)822 Status BundleReader::GetBundleEntryProto(StringPiece key,
823                                          BundleEntryProto* entry) {
824   entry->Clear();
825   TF_CHECK_OK(status_);
826   Seek(key);
827   if (!iter_->Valid() || iter_->key() != key) {
828     return errors::NotFound("Key ", key, " not found in checkpoint");
829   }
830 
831   BundleEntryProto entry_copy;
832   TF_RETURN_IF_ERROR(
833       ParseEntryProto(iter_->key(), iter_->value(), &entry_copy));
834   if (!TensorShape::IsValid(entry_copy.shape())) {
835     return errors::DataLoss("Invalid tensor shape: ", key, " ",
836                             entry_copy.shape().ShortDebugString());
837   }
838 
839   entry->Swap(&entry_copy);
840   return Status::OK();
841 }
842 
GetValue(const BundleEntryProto & entry,Tensor * val)843 Status BundleReader::GetValue(const BundleEntryProto& entry, Tensor* val) {
844   Tensor* ret = val;
845   const TensorShape stored_shape(TensorShape(entry.shape()));
846   if (val->NumElements() == 0) {
847     ret = new Tensor(entry.dtype(), stored_shape);
848   }
849 
850   // Validates the "size" field.
851   if (entry.dtype() != DT_STRING && entry.dtype() != DT_VARIANT) {
852     if (entry.size() != ret->TotalBytes()) {
853       return errors::DataLoss("Invalid size in bundle entry: key ", key(),
854                               "; stored size ", entry.size(),
855                               "; expected size ", ret->TotalBytes());
856     }
857   } else if (entry.dtype() == DT_STRING) {
858     // Relaxes the check for string tensors as follows:
859     //   entry.size() == bytes(varint lengths) + bytes(data)
860     //                >= NumElems + bytes(data), since size bytes(varint) >= 1.
861     //   TotalBytes() == sizeof(tstring) * NumElems + bytes(data)
862     // Since we don't know bytes(varint lengths), we just check an inequality.
863     const size_t lower_bound = ret->NumElements() + ret->TotalBytes() -
864                                sizeof(tstring) * ret->NumElements();
865     if (entry.size() < lower_bound) {
866       return errors::DataLoss("Invalid size in bundle entry: key ", key(),
867                               "; stored size ", entry.size(),
868                               "; expected size is at least ", lower_bound);
869     }
870   }
871 
872   // Open the data file if it has not been opened.
873   io::InputBuffer* buffered_file = data_[entry.shard_id()];
874   if (buffered_file == nullptr) {
875     std::unique_ptr<RandomAccessFile> file = nullptr;
876     TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(
877         DataFilename(prefix_, entry.shard_id(), num_shards_), &file));
878     buffered_file = new io::InputBuffer(file.release(), kBufferSize);
879     // The InputBuffer and RandomAccessFile objects are both released in dtor.
880     data_[entry.shard_id()] = buffered_file;
881   }
882   CHECK(buffered_file != nullptr);
883 
884   TF_RETURN_IF_ERROR(buffered_file->Seek(entry.offset()));
885   uint32 actual_crc32c = 0;
886 
887   if (DataTypeCanUseMemcpy(entry.dtype())) {
888     char* backing_buffer = const_cast<char*>((ret->tensor_data().data()));
889     size_t unused_bytes_read;
890     if (entry.size() > kBufferSize) {
891       StringPiece sp;
892       TF_RETURN_IF_ERROR(buffered_file->file()->Read(
893           entry.offset(), entry.size(), &sp, backing_buffer));
894       if (sp.data() != backing_buffer) {
895         memmove(backing_buffer, sp.data(), entry.size());
896       }
897     } else {
898       TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(entry.size(), backing_buffer,
899                                                    &unused_bytes_read));
900     }
901     // Note that we compute the checksum *before* byte-swapping. The checksum
902     // should be on the bytes in the order they appear in the file.
903     actual_crc32c = crc32c::Value(backing_buffer, entry.size());
904     if (need_to_swap_bytes_) {
905       TF_RETURN_IF_ERROR(ByteSwapTensor(ret));
906     }
907   } else if (entry.dtype() == DT_VARIANT) {
908     if (need_to_swap_bytes_) {
909       return errors::Unimplemented(
910           "TensorBundle at ", prefix_,
911           "is of a different endianness than this machine's hardware, and "
912           "the bundle contains a variant (arbitrary C++ type) tensor. "
913           "Byte-swapping of variant tensors is not currently implemented.");
914     }
915     // Relies on io::InputBuffer's buffering, because we issue many neighboring
916     // reads for a single string tensor.
917     TF_RETURN_IF_ERROR(ReadVariantTensor(buffered_file, ret, entry.offset(),
918                                          entry.size(), &actual_crc32c));
919   } else {
920     // Relies on io::InputBuffer's buffering, because we issue many neighboring
921     // reads for a single string tensor.
922     TF_RETURN_IF_ERROR(ReadStringTensor(
923         buffered_file, ret->NumElements(), entry.offset(), entry.size(),
924         GetStringBackingBuffer(*ret), &actual_crc32c, need_to_swap_bytes_));
925   }
926   if (crc32c::Unmask(entry.crc32c()) != actual_crc32c) {
927     return errors::DataLoss(
928         "TensorBundle at ", prefix_, " shard ", entry.shard_id(), " (",
929         entry.size(), " bytes): Checksum does not match: stored ",
930         strings::Printf("%08u", crc32c::Unmask(entry.crc32c())),
931         " vs. calculated on the restored bytes ", actual_crc32c);
932   }
933 
934   *val = *ret;
935   if (ret != val) delete ret;
936   return Status::OK();
937 }
938 
Lookup(StringPiece key,Tensor * val)939 Status BundleReader::Lookup(StringPiece key, Tensor* val) {
940   CHECK(val != nullptr);
941   BundleEntryProto entry;
942   TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));
943 
944   if (entry.slices().empty()) {
945     return GetValue(entry, val);
946   } else {
947     return GetSliceValue(
948         key, entry,
949         /* a full slice */ TensorSlice(TensorShape(entry.shape()).dims()), val);
950   }
951 }
952 
ReadCurrent(Tensor * val)953 Status BundleReader::ReadCurrent(Tensor* val) {
954   CHECK(val != nullptr);
955   BundleEntryProto entry;
956   TF_RETURN_IF_ERROR(ParseEntryProto(iter_->key(), iter_->value(), &entry));
957   if (!TensorShape::IsValid(entry.shape())) {
958     return errors::DataLoss("Invalid tensor shape: ", iter_->key(), " ",
959                             entry.shape().ShortDebugString());
960   }
961 
962   if (entry.slices().empty()) {
963     return GetValue(entry, val);
964   } else {
965     return GetSliceValue(
966         iter_->key(), entry,
967         /* a full slice */ TensorSlice(TensorShape(entry.shape()).dims()), val);
968   }
969 }
970 
LookupTensorSlices(StringPiece key,std::vector<TensorSlice> * slices)971 Status BundleReader::LookupTensorSlices(StringPiece key,
972                                         std::vector<TensorSlice>* slices) {
973   slices->clear();
974   BundleEntryProto entry;
975   TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));
976   slices->reserve(entry.slices_size());
977   for (const auto& slice : entry.slices()) {
978     slices->emplace_back(slice);
979   }
980   return Status::OK();
981 }
982 
LookupSlice(StringPiece full_tensor_key,const TensorSlice & slice_spec,Tensor * val)983 Status BundleReader::LookupSlice(StringPiece full_tensor_key,
984                                  const TensorSlice& slice_spec, Tensor* val) {
985   CHECK(val != nullptr);
986   BundleEntryProto entry;
987   TF_RETURN_IF_ERROR(GetBundleEntryProto(full_tensor_key, &entry));
988   return GetSliceValue(full_tensor_key, entry, slice_spec, val);
989 }
990 
GetSliceValue(StringPiece full_tensor_key,const BundleEntryProto & full_tensor_entry,const TensorSlice & slice_spec,Tensor * val)991 Status BundleReader::GetSliceValue(StringPiece full_tensor_key,
992                                    const BundleEntryProto& full_tensor_entry,
993                                    const TensorSlice& slice_spec, Tensor* val) {
994   using checkpoint::RegisterTensorSlice;
995   using checkpoint::TensorSliceSet;
996   DCHECK_GE(full_tensor_entry.slices_size(), 0);
997 
998   const TensorShape full_shape(TensorShape(full_tensor_entry.shape()));
999   std::vector<std::pair<TensorSlice, string>> details;
1000   const string full_tensor_key_string(full_tensor_key);
1001   const TensorSliceSet* tss =
1002       gtl::FindPtrOrNull(tensor_slices_, full_tensor_key_string);
1003 
1004   // Populates the "full tensor key -> TensorSliceSet" cache.
1005   if (tss == nullptr) {
1006     if (full_tensor_entry.slices().empty()) {
1007       // Special case: a writer has saved a tensor fully, but the reader wants
1008       // to read in slices.  We therefore register the full slice on-demand here
1009       // without further complicating the on-disk bundle format.
1010       TF_RETURN_IF_ERROR(RegisterTensorSlice(
1011           full_tensor_key_string, full_shape, full_tensor_entry.dtype(),
1012           /* tag */ "",
1013           /* full slice */ TensorSlice(full_shape.dims()), &tensor_slices_));
1014     }
1015     for (const TensorSliceProto& slice : full_tensor_entry.slices()) {
1016       TF_RETURN_IF_ERROR(RegisterTensorSlice(
1017           full_tensor_key_string, full_shape, full_tensor_entry.dtype(),
1018           /* tag */ "", TensorSlice(slice), &tensor_slices_));
1019     }
1020     tss = gtl::FindPtrOrNull(tensor_slices_, full_tensor_key_string);
1021     CHECK_NE(tss, nullptr);
1022   }
1023   if (!tss->QueryMeta(slice_spec, &details)) {
1024     return errors::InvalidArgument(
1025         "Does not have sufficient slices for partitioned tensor ",
1026         full_tensor_key,
1027         " to restore in slice_spec: ", slice_spec.DebugString());
1028   }
1029 
1030   // The union of the slices in "details" covers "slice_spec".  Performs the
1031   // copies from each.
1032   BundleEntryProto stored_slice_entry = full_tensor_entry;
1033   for (const auto& slice_tag_pair : details) {
1034     // Seeks for the stored slice.
1035     const TensorSlice& stored_slice = slice_tag_pair.first;
1036 
1037     // We already have the entry for the full tensor, so don't query again if
1038     // the slice is full.
1039     if (!stored_slice.IsFull()) {
1040       const string encoded_stored_slice_name =
1041           checkpoint::EncodeTensorNameSlice(full_tensor_key_string,
1042                                             stored_slice);
1043       status_ =
1044           GetBundleEntryProto(encoded_stored_slice_name, &stored_slice_entry);
1045       if (!status_.ok()) return status_;
1046     }
1047 
1048     // TODO(zongheng): should we take an OpKernelContext, so that we can call
1049     // allocate_temp()?  Note that without major refactorings to Saver, it's
1050     // hard for the caller of the tensor bundle module to allocate these
1051     // precisely-shaped scratch storage.
1052 
1053     // Optimization for the common case: the stored slice can be directly
1054     // copied to the destination without additional slicing. This is true when
1055     // either the slices are equal or when they are both full slices having the
1056     // same shape.
1057     TensorShape stored_slice_shape(stored_slice_entry.shape());
1058     if (stored_slice == slice_spec ||
1059         (stored_slice_shape == val->shape() &&
1060          IsFullSlice(stored_slice, stored_slice_shape) &&
1061          IsFullSlice(slice_spec, stored_slice_shape))) {
1062       VLOG(1) << "Optimized for common case: directly copying into "
1063                  "pre-allocated buffer; spec: "
1064               << slice_spec.DebugString();
1065       status_ = GetValue(stored_slice_entry, val);
1066       return status_;
1067     }
1068 
1069     Tensor stored_slice_tensor(stored_slice_entry.dtype(), stored_slice_shape);
1070     status_ = GetValue(stored_slice_entry, &stored_slice_tensor);
1071     if (!status_.ok()) return status_;
1072 
1073     // Copies the intersection over.
1074     const DataType common_dtype = full_tensor_entry.dtype();
1075     switch (common_dtype) {
1076 #define HANDLE_COPY(T)                                                 \
1077   case DataTypeToEnum<T>::value:                                       \
1078     CHECK(CopyDataFromTensorSliceToTensorSlice(                        \
1079         full_shape, stored_slice, slice_spec,                          \
1080         stored_slice_tensor.flat<T>().data(), val->flat<T>().data())); \
1081     break;
1082 
1083       HANDLE_COPY(float)
1084       HANDLE_COPY(double)
1085       HANDLE_COPY(int32)
1086       HANDLE_COPY(uint8)
1087       HANDLE_COPY(int16)
1088       HANDLE_COPY(int8)
1089       HANDLE_COPY(complex64)
1090       HANDLE_COPY(complex128)
1091       HANDLE_COPY(int64)
1092       HANDLE_COPY(bool)
1093       HANDLE_COPY(qint32)
1094       HANDLE_COPY(quint8)
1095       HANDLE_COPY(qint8)
1096       HANDLE_COPY(bfloat16)
1097       default:
1098         return errors::InvalidArgument("Dtype ", DataTypeString(common_dtype),
1099                                        " not supported.");
1100     }
1101 #undef HANDLE_COPY
1102   }
1103   return Status::OK();
1104 }
1105 
Contains(StringPiece key)1106 bool BundleReader::Contains(StringPiece key) {
1107   Seek(key);
1108   return Valid() && (this->key() == key);
1109 }
1110 
LookupDtypeAndShape(StringPiece key,DataType * dtype,TensorShape * shape)1111 Status BundleReader::LookupDtypeAndShape(StringPiece key, DataType* dtype,
1112                                          TensorShape* shape) {
1113   BundleEntryProto entry;
1114   TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));
1115   *dtype = entry.dtype();
1116   *shape = TensorShape(entry.shape());
1117   return Status::OK();
1118 }
1119 
LookupTensorShape(StringPiece key,TensorShape * shape)1120 Status BundleReader::LookupTensorShape(StringPiece key, TensorShape* shape) {
1121   DataType ignored;
1122   return LookupDtypeAndShape(key, &ignored, shape);
1123 }
1124 
DebugString()1125 string BundleReader::DebugString() {
1126   // Format used below emulates that of TensorSliceReader::DebugString().
1127   string shape_str;
1128   BundleEntryProto entry;
1129   Seek(kHeaderEntryKey);
1130   for (Next(); Valid(); Next()) {
1131     CHECK(entry.ParseFromArray(value().data(), value().size()));
1132     if (entry.slices_size() > 0) continue;  // Slice of some partitioned var.
1133 
1134     strings::StrAppend(&shape_str, key(), " (", DataType_Name(entry.dtype()),
1135                        ") ", TensorShape(entry.shape()).DebugString());
1136     strings::StrAppend(&shape_str, "\n");
1137   }
1138   return shape_str;
1139 }
1140 
~FileOutputBuffer()1141 FileOutputBuffer::~FileOutputBuffer() { delete file_; }
1142 
Append(StringPiece data)1143 Status FileOutputBuffer::Append(StringPiece data) {
1144   // In the below, it is critical to calculate the checksum on the actually
1145   // copied bytes, not the source bytes.  This is because "data" typically
1146   // points to tensor buffers, which may be concurrently written.
1147   if (data.size() + position_ <= buffer_size_) {
1148     // Can fit into the current buffer.
1149     memcpy(&buffer_[position_], data.data(), data.size());
1150     crc32c_ = crc32c::Extend(crc32c_, &buffer_[position_], data.size());
1151   } else if (data.size() <= buffer_size_) {
1152     // Cannot fit, but can fit after flushing.
1153     TF_RETURN_IF_ERROR(FlushBuffer());
1154     memcpy(&buffer_[0], data.data(), data.size());
1155     crc32c_ = crc32c::Extend(crc32c_, &buffer_[0], data.size());
1156   } else {
1157     // Cannot fit even after flushing.  So we break down "data" by chunk, and
1158     // flush/checksum each chunk.
1159     TF_RETURN_IF_ERROR(FlushBuffer());
1160     for (size_t i = 0; i < data.size(); i += buffer_size_) {
1161       const size_t nbytes = std::min(data.size() - i, buffer_size_);
1162       memcpy(&buffer_[0], data.data() + i, nbytes);
1163       crc32c_ = crc32c::Extend(crc32c_, &buffer_[0], nbytes);
1164       position_ = nbytes;
1165       TF_RETURN_IF_ERROR(FlushBuffer());
1166     }
1167     return Status::OK();
1168   }
1169   position_ += data.size();
1170   return Status::OK();
1171 }
1172 
Close()1173 Status FileOutputBuffer::Close() {
1174   TF_RETURN_IF_ERROR(FlushBuffer());
1175   return file_->Close();
1176 }
1177 
FlushBuffer()1178 Status FileOutputBuffer::FlushBuffer() {
1179   if (position_ > 0) {
1180     TF_RETURN_IF_ERROR(file_->Append(StringPiece(&buffer_[0], position_)));
1181     position_ = 0;
1182   }
1183   return Status::OK();
1184 }
1185 
1186 }  // namespace tensorflow
1187