/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"

#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <utility>

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/framework/versions.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/crc32c.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/io/table_builder.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/bfloat16.h"
#include "tensorflow/core/platform/cord.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
#include "tensorflow/core/util/tensor_bundle/byte_swap.h"
#include "tensorflow/core/util/tensor_slice_util.h"

#ifdef PLATFORM_WINDOWS
#undef DeleteFile
#endif

namespace tensorflow {

// Versioning of the tensor bundle format.
const int kTensorBundleMinProducer = 0;
const int kTensorBundleMinConsumer = 0;
const int kTensorBundleVersion = 1;

// Size of our input buffer for streaming reads
static const int kBufferSize = 1024 * 1024;

// Key to the special BundleHeaderProto entry.  Do not change this, as clients
// can make the assumption that the header is always the first entry in the
// bundle.
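// (The empty string sorts before every other key in the lexicographically
// ordered metadata table, which is what guarantees that property.)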
const char* const kHeaderEntryKey = "";

namespace {

// Reads "num_elements" string elements from file[offset, offset+size) into the
// length-N "destination".  Discards the original content of "destination".
//
// Checksums the string lengths (as restored uint32 or uint64, not varint64
// bytes) and string bytes, and stores it into "actual_crc32c".
Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements,
                        size_t offset, size_t size, tstring* destination,
                        uint32* actual_crc32c, bool need_to_swap_bytes) {
  if (size == 0) return Status::OK();
  CHECK_GT(size, 0);

  // Reads "num_elements" varint64's from "buffered_file".
  TF_RETURN_IF_ERROR(buffered_file->Seek(offset));
  TF_RETURN_IF_ERROR(buffered_file->Hint(size));
  std::vector<uint64> string_lengths(num_elements);
  for (size_t i = 0; i < num_elements; ++i) {
    TF_RETURN_IF_ERROR(buffered_file->ReadVarint64(&string_lengths[i]));
    if (string_lengths[i] <= UINT32_MAX) {
      // We need to do this because older checkpoints only used uint32s and we
      // should still support them.
      uint32 elem_size_uint32 = static_cast<uint32>(string_lengths[i]);
      if (need_to_swap_bytes) {
        // Checksum would have been computed on the source machine's byte order
        elem_size_uint32 = BYTE_SWAP_32(elem_size_uint32);
      }
      *actual_crc32c = crc32c::Extend(
          *actual_crc32c, reinterpret_cast<const char*>(&elem_size_uint32),
          sizeof(uint32));
    } else {
      uint64 length = string_lengths[i];
      if (need_to_swap_bytes) {
        length = BYTE_SWAP_64(length);
      }
      *actual_crc32c =
          crc32c::Extend(*actual_crc32c, reinterpret_cast<const char*>(&length),
                         sizeof(uint64));
    }
  }
  if (offset + size < buffered_file->Tell()) {
    return errors::DataLoss("String lengths longer than expected offset ",
                            offset + size);
  }

  // Reads the length-checksum.
  uint32 raw_length_checksum = 0;  // Bytes in file
  uint32 length_checksum = 0;      // In-memory representation
  size_t unused_bytes_read = 0;
  TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(
      sizeof(uint32), reinterpret_cast<char*>(&raw_length_checksum),
      &unused_bytes_read));
  length_checksum = need_to_swap_bytes ? BYTE_SWAP_32(raw_length_checksum)
                                       : raw_length_checksum;
  if (crc32c::Unmask(length_checksum) != *actual_crc32c) {
    return errors::DataLoss(
        "The length checksum does not match: expected ",
        strings::Printf("%08u", crc32c::Unmask(length_checksum)),
        " but actual is ", strings::Printf("%08u", *actual_crc32c));
  }
  *actual_crc32c = crc32c::Extend(*actual_crc32c,
                                  reinterpret_cast<char*>(&raw_length_checksum),
                                  sizeof(uint32));

  // Reads the actual string bytes.
  for (size_t i = 0; i < num_elements; ++i) {
    const uint64 string_length = string_lengths[i];
    tstring* buffer = &destination[i];

    buffer->resize(string_length);
    size_t bytes_read = 0;
    TF_RETURN_IF_ERROR(
        buffered_file->ReadNBytes(string_length, &(*buffer)[0], &bytes_read));
    *actual_crc32c = crc32c::Extend(*actual_crc32c, buffer->data(), bytes_read);
  }
  return Status::OK();
}

Status ReadVariantTensor(io::InputBuffer* buffered_file, Tensor* ret,
                         size_t offset, size_t size, uint32* actual_crc32c) {
  // On-disk format:
  //   [varint64 len1][bytes variant1][4 byte checksum]
  //   ..
  //   [varint64 lenN][bytes variantN][4 byte checksum]
  // Var "crc32c" checksums all the lens, variant bytes, individual variant
  // checksums (as uint32, not varint32 bytes).
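  //
  // Illustrative layout (for exposition only): two variants whose serialized
  // VariantTensorDataProtos are 3 and 5 bytes long occupy
  //   [varint64: 3][3 proto bytes][4-byte masked running crc32c]
  //   [varint64: 5][5 proto bytes][4-byte masked running crc32c]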
  if (size == 0) return Status::OK();
  size_t num_elements = ret->NumElements();

  // Reads the actual string bytes.
  TF_RETURN_IF_ERROR(buffered_file->Seek(offset));
  TF_RETURN_IF_ERROR(buffered_file->Hint(size));
  for (size_t i = 0; i < num_elements; ++i) {
    // Read the serialized variant length.
    uint64 string_length = 0;
    TF_RETURN_IF_ERROR(buffered_file->ReadVarint64(&string_length));
    *actual_crc32c = crc32c::Extend(
        *actual_crc32c, reinterpret_cast<const char*>(&string_length),
        sizeof(uint64));
    // Read the actual serialized variant.
    string buffer;
    buffer.resize(string_length);
    size_t bytes_read = 0;
    TF_RETURN_IF_ERROR(
        buffered_file->ReadNBytes(string_length, &buffer[0], &bytes_read));
    *actual_crc32c = crc32c::Extend(*actual_crc32c, buffer.data(), bytes_read);
    VariantTensorDataProto proto;
    if (!proto.ParseFromString(buffer)) {
      return errors::DataLoss("Unable to parse VariantTensorDataProto from ",
                              "buffer of size ", string_length, ". ",
                              "Bundle entry offset: ", offset, " size: ", size);
    }
    Variant v = proto;
    if (!DecodeUnaryVariant(&v)) {
      return errors::Internal("Could not decode variant with type_name: \"",
                              v.TypeName(), "\".  Perhaps you forgot to ",
                              "register a decoder via ",
                              "REGISTER_UNARY_VARIANT_DECODE_FUNCTION?");
    }

    // Read the checksum.
    uint32 checksum = 0;
    size_t unused_bytes_read = 0;
    TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(
        sizeof(uint32), reinterpret_cast<char*>(&checksum),
        &unused_bytes_read));
    if (crc32c::Unmask(checksum) != *actual_crc32c) {
      return errors::DataLoss(
          "The checksum after Variant ", i, " does not match.",
          " Expected: ", strings::Printf("%08u", crc32c::Unmask(checksum)),
          " Actual: ", strings::Printf("%08u", *actual_crc32c));
    }
    *actual_crc32c = crc32c::Extend(
        *actual_crc32c, reinterpret_cast<char*>(&checksum), sizeof(uint32));

    ret->flat<Variant>()(i) = std::move(v);
  }

  return Status::OK();
}

char* GetBackingBuffer(const Tensor& val) {
  CHECK(DataTypeCanUseMemcpy(val.dtype())) << val.dtype();
  return const_cast<char*>(val.tensor_data().data());
}

tstring* GetStringBackingBuffer(const Tensor& val) {
  CHECK_EQ(DT_STRING, val.dtype());
  return const_cast<tstring*>(val.flat<tstring>().data());
}

Status ParseEntryProto(StringPiece key, StringPiece value,
                       protobuf::MessageLite* out) {
  if (!out->ParseFromArray(value.data(), value.size())) {
    return errors::DataLoss("Entry for key ", key, " not parseable.");
  }
  return Status::OK();
}

// Serializes the data bytes of the non-string tensor "val".  Discards the
// original content of "bytes_written", and on OK updates it with number of
// bytes written.
// REQUIRES: val.dtype() != DT_STRING
Status WriteTensor(const Tensor& val, FileOutputBuffer* out,
                   size_t* bytes_written) {
  DCHECK_NE(val.dtype(), DT_STRING);
  DCHECK_NE(val.dtype(), DT_VARIANT);
  *bytes_written = val.TotalBytes();
  char* buf = GetBackingBuffer(val);
  VLOG(1) << "Appending " << *bytes_written << " bytes to file";
  return out->Append(StringPiece(buf, *bytes_written));
}

// Serializes string tensor "val".  "bytes_written" is treated in the same
// fashion as WriteTensor().
//
// Checksums all bytes written and stores it into "crc32c".
// REQUIRES: val.dtype() == DT_STRING
Status WriteStringTensor(const Tensor& val, FileOutputBuffer* out,
                         size_t* bytes_written, uint32* crc32c) {
  // On-disk format:
  //   [varint64 len0]..[varint64 lenL][4 byte cksum on lengths][string bytes]
  // Var "crc32c" checksums the string lengths (as uint64, not varint64 bytes),
  // the length-checksum, and all the string bytes.
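  //
  // Illustrative layout (for exposition only): a tensor with elements
  // {"ab", "c"} is written as
  //   [varint64: 2][varint64: 1][4-byte masked crc32c of the lengths]
  //   ['a']['b']['c']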
  DCHECK_EQ(val.dtype(), DT_STRING);
  const tstring* strings = GetStringBackingBuffer(val);

  // Writes the varint lengths.
  string lengths;
  lengths.reserve(val.NumElements());  // At least 1 byte per element.
  *crc32c = 0;
  for (int64_t i = 0; i < val.NumElements(); ++i) {
    const tstring* elem = &strings[i];
    DCHECK_EQ(elem->size(), static_cast<uint64>(elem->size()));
    const uint64 elem_size = static_cast<uint64>(elem->size());

    core::PutVarint64(&lengths, elem_size);
    if (elem_size <= UINT32_MAX) {
      // We need to do this because older checkpoints only used uint32s and we
      // should still support them.
      const uint32 elem_size_uint32 = static_cast<uint32>(elem_size);
      *crc32c = crc32c::Extend(*crc32c,
                               reinterpret_cast<const char*>(&elem_size_uint32),
                               sizeof(uint32));
    } else {
      *crc32c = crc32c::Extend(
          *crc32c, reinterpret_cast<const char*>(&elem_size), sizeof(uint64));
    }
  }
  TF_RETURN_IF_ERROR(out->Append(lengths));
  *bytes_written = lengths.size();

  // Writes the length checksum.
  const uint32 length_checksum = crc32c::Mask(*crc32c);
  TF_RETURN_IF_ERROR(out->Append(StringPiece(
      reinterpret_cast<const char*>(&length_checksum), sizeof(uint32))));
  *crc32c = crc32c::Extend(
      *crc32c, reinterpret_cast<const char*>(&length_checksum), sizeof(uint32));
  *bytes_written += sizeof(uint32);

  // Writes all the string bytes out.
  for (int64_t i = 0; i < val.NumElements(); ++i) {
    const tstring* string = &strings[i];
    TF_RETURN_IF_ERROR(out->Append(*string));
    *bytes_written += string->size();
    *crc32c = crc32c::Extend(*crc32c, string->data(), string->size());
  }
  return Status::OK();
}

Status WriteVariantTensor(const Tensor& val, FileOutputBuffer* out,
                          size_t* bytes_written, uint32* crc32c) {
  // On-disk format:
  //   [varint64 len1][bytes variant1][4 byte checksum]
  //   ..
  //   [varint64 lenN][bytes variantN][4 byte checksum]
  // Var "crc32c" checksums all the lens, variant bytes, individual variant
  // checksums (as uint32, not varint32 bytes).
  DCHECK_EQ(val.dtype(), DT_VARIANT);

  *crc32c = 0;
  *bytes_written = 0;
  for (int64_t i = 0; i < val.NumElements(); ++i) {
    VariantTensorData data;
    val.flat<Variant>()(i).Encode(&data);
    VariantTensorDataProto proto;
    data.ToProto(&proto);
    string elem;
    if (!proto.SerializeToString(&elem)) {
      return errors::Unknown(
          "Failed to serialize tensor data of size ", proto.ByteSizeLong(),
          ". Tensor: ", val.flat<Variant>()(i).DebugString());
    }

    // Write the length of the serialized variant.
    DCHECK_EQ(elem.size(), static_cast<uint64>(elem.size()));
    const auto elem_size = static_cast<uint64>(elem.size());
    string len;
    core::PutVarint64(&len, elem_size);
    TF_RETURN_IF_ERROR(out->Append(len));
    *crc32c = crc32c::Extend(*crc32c, reinterpret_cast<const char*>(&elem_size),
                             sizeof(uint64));
    *bytes_written += len.size();

    // Write the serialized variant.
    TF_RETURN_IF_ERROR(out->Append(elem));
    *crc32c = crc32c::Extend(*crc32c, elem.data(), elem.size());
    *bytes_written += elem.size();

    // Write the checksum.
    const uint32 length_checksum = crc32c::Mask(*crc32c);
    TF_RETURN_IF_ERROR(out->Append(StringPiece(
        reinterpret_cast<const char*>(&length_checksum), sizeof(uint32))));
    *crc32c =
        crc32c::Extend(*crc32c, reinterpret_cast<const char*>(&length_checksum),
                       sizeof(uint32));
    *bytes_written += sizeof(uint32);
  }

  return Status::OK();
}

// Returns whether "slice_spec" is a full slice, with respect to the full shape.
//
// This can happen say, when "slice_spec" is
// "TensorSlice(full_tensor_shape.dims())", or when it is "TensorSlice({{0,
// dim(0)}, ..., {0, dim(N)}})" -- a degenerate case we need to guard against.
bool IsFullSlice(const TensorSlice& slice_spec,
                 const TensorShape& full_tensor_shape) {
  if (slice_spec.IsFull()) {
    return true;
  } else {
    TensorShape sliced_shape;
    slice_spec.SliceTensorShape(full_tensor_shape, &sliced_shape).IgnoreError();
    return sliced_shape == full_tensor_shape;
  }
}

Status CorruptFileError(const Status& in_status, const string& filename,
                        const string& detail) {
  if (in_status.ok()) {
    return errors::Internal("Unable to read file (", filename,
                            "). Perhaps the file is corrupt or was produced by "
                            "a newer version of TensorFlow with format changes "
                            "(",
                            detail, ")");
  }
  return Status(
      in_status.code(),
      strings::StrCat("Unable to read file (", filename,
                      "). Perhaps the file is corrupt or was produced by a "
                      "newer version of TensorFlow with format changes (",
                      detail, "): ", in_status.error_message()));
}

table::Options TableBuilderOptions() {
  table::Options o;
  // Compressed tables cannot be read by TensorFlow releases prior to 1.1.
  // To smoothen the transition, compressed writes are disabled for now
  // (version 1.2) with the intention that they will be enabled again at
  // some point (perhaps the 1.3 release?).
  o.compression = table::kNoCompression;
  return o;
}

// Writes zeros to output buffer to align the next write to the requested
// alignment. "size" is the current size of the buffer and is updated to the
// new size.
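// For example, with *size == 10 and alignment == 8, two bytes spill past the
// alignment boundary, so six zero bytes are appended and *size becomes 16.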
Status PadAlignment(FileOutputBuffer* out, int alignment, int64* size) {
  int bytes_over = *size % alignment;
  if (bytes_over == 0) {
    return Status::OK();
  }
  int bytes_to_write = alignment - bytes_over;
  Status status = out->Append(string(bytes_to_write, '\0'));
  if (status.ok()) {
    *size += bytes_to_write;
  }
  return status;
}

}  // namespace

BundleWriter::BundleWriter(Env* env, StringPiece prefix, const Options& options)
    : env_(env), options_(options), prefix_(prefix), out_(nullptr), size_(0) {
  status_ = env_->HasAtomicMove(prefix_, &use_temp_file_);
  if (!status_.ok()) return;

  data_path_ = DataFilename(prefix_, 0, 1);
  metadata_path_ = MetaFilename(prefix_);
  if (use_temp_file_) {
    data_path_ = strings::StrCat(data_path_, ".tempstate", random::New64());
    metadata_path_ =
        strings::StrCat(metadata_path_, ".tempstate", random::New64());
  }

  status_ = env_->CreateDir(string(io::Dirname(prefix_)));
  if (!status_.ok() && !errors::IsAlreadyExists(status_)) {
    return;
  }

  std::unique_ptr<WritableFile> wrapper;
  status_ = env_->NewWritableFile(data_path_, &wrapper);
  if (!status_.ok()) return;
  out_ = std::unique_ptr<FileOutputBuffer>(
      new FileOutputBuffer(wrapper.release(), 8 << 20 /* 8MB write buffer */));

  VLOG(1) << "Writing to file " << data_path_;
}

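// A minimal usage sketch of the writer (illustrative; tensor_bundle.h carries
// the authoritative API comments). The prefix and key below are hypothetical:
//
//   BundleWriter writer(Env::Default(), "/tmp/model.ckpt");
//   TF_CHECK_OK(writer.Add("weights", weights_tensor));
//   TF_CHECK_OK(writer.Finish());
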
Status BundleWriter::Add(StringPiece key, const Tensor& val) {
  if (!status_.ok()) return status_;
  CHECK_NE(key, kHeaderEntryKey);
  const string key_string(key);
  if (entries_.find(key_string) != entries_.end()) {
    status_ = errors::InvalidArgument("Adding duplicate key: ", key);
    return status_;
  }

  BundleEntryProto* entry = &entries_[key_string];
  entry->set_dtype(val.dtype());
  val.shape().AsProto(entry->mutable_shape());
  entry->set_shard_id(0);
  entry->set_offset(size_);

  // Updates the data file.
  size_t data_bytes_written = 0;
  uint32 crc32c = 0;
  out_->clear_crc32c();
  if (val.dtype() == DT_STRING) {
    status_ = WriteStringTensor(val, out_.get(), &data_bytes_written, &crc32c);
  } else if (val.dtype() == DT_VARIANT) {
    status_ = WriteVariantTensor(val, out_.get(), &data_bytes_written, &crc32c);
  } else {
    status_ = WriteTensor(val, out_.get(), &data_bytes_written);
    crc32c = out_->crc32c();
  }

  if (status_.ok()) {
    entry->set_size(data_bytes_written);
    entry->set_crc32c(crc32c::Mask(crc32c));
    size_ += data_bytes_written;
    status_ = PadAlignment(out_.get(), options_.data_alignment, &size_);
  }
  return status_;
}

Status BundleWriter::AddSlice(StringPiece full_tensor_key,
                              const TensorShape& full_tensor_shape,
                              const TensorSlice& slice_spec,
                              const Tensor& slice_tensor) {
  if (!status_.ok()) return status_;
  CHECK_NE(full_tensor_key, kHeaderEntryKey);

  // If just a singleton full slice, use the regular Add() to be more efficient.
  if (IsFullSlice(slice_spec, full_tensor_shape)) {
    return Add(full_tensor_key, slice_tensor);
  }

  // Inserts/updates the full tensor's metadata entry.
  //
  // In the case of a sharded save, MergeBundles() is responsible for merging
  // the "slices" field of multiple metadata entries corresponding to the same
  // full tensor.
  const string full_tensor_key_string(full_tensor_key);
  BundleEntryProto* full_entry = &entries_[full_tensor_key_string];
  if (full_entry->dtype() != DT_INVALID) {
    CHECK_EQ(full_entry->dtype(), slice_tensor.dtype());
  }
  if (full_entry->has_shape()) {
    CHECK(TensorShape(full_entry->shape()) == full_tensor_shape);
  }

  // Populates dtype, shape, and slices.  Intentionally leaving out shard_id and
  // offset, which do not make sense for this full tensor entry.
  full_entry->set_dtype(slice_tensor.dtype());
  full_tensor_shape.AsProto(full_entry->mutable_shape());
  TensorSliceProto* slice_proto = full_entry->add_slices();
  slice_spec.AsProto(slice_proto);

  // The slice itself is handled by a regular Add(), which includes adding its
  // own metadata entry, and writing out the slice's values.
  const string slice_name =
      checkpoint::EncodeTensorNameSlice(full_tensor_key_string, slice_spec);
  status_ = Add(slice_name, slice_tensor);
  return status_;
}

// TODO(zongheng): on metadata write failure or !status_.ok(), consider removing
// the orphaned data file.
Status BundleWriter::Finish() {
  if (out_) {
    status_.Update(out_->Close());
    out_ = nullptr;
    if (status_.ok()) {
      if (use_temp_file_) {
        status_ =
            Env::Default()->RenameFile(data_path_, DataFilename(prefix_, 0, 1));
      }
    } else {
      Env::Default()->DeleteFile(data_path_).IgnoreError();
    }
  }
  if (!status_.ok()) return status_;
  // Build key -> BundleEntryProto table.
  std::unique_ptr<WritableFile> file;
  status_ = env_->NewWritableFile(metadata_path_, &file);
  if (!status_.ok()) return status_;
  {
    // N.B.: the default use of Snappy compression may not be supported on all
    // platforms (e.g. Android).  The metadata file is small, so this is fine.
    table::Options options;
    options.compression = table::kNoCompression;
    table::TableBuilder builder(options, file.get());
    // Header entry.
    BundleHeaderProto header;
    header.set_num_shards(1);
    header.set_endianness(BundleHeaderProto::LITTLE);
    if (!port::kLittleEndian) header.set_endianness(BundleHeaderProto::BIG);
    VersionDef* version = header.mutable_version();
    version->set_producer(kTensorBundleVersion);
    version->set_min_consumer(kTensorBundleMinConsumer);

    builder.Add(kHeaderEntryKey, header.SerializeAsString());

    // All others.
    for (const auto& p : entries_) {
      builder.Add(p.first, p.second.SerializeAsString());
    }
    status_ = builder.Finish();
  }
  status_.Update(file->Close());
  if (!status_.ok()) {
    Env::Default()->DeleteFile(metadata_path_).IgnoreError();
    return status_;
  } else if (use_temp_file_) {
    status_ = Env::Default()->RenameFile(metadata_path_, MetaFilename(prefix_));
    if (!status_.ok()) return status_;
  }
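  // Marks the writer as used up: any subsequent Add()/AddSlice()/Finish() call
  // sees this non-OK status and fails, while this Finish() itself still
  // reports success below.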
  status_ = errors::Internal("BundleWriter is closed");
  return Status::OK();
}

// Merging tensor bundles.

// Accumulator of metadata states during a merge.
struct MergeState {
  // Accumulated from the header entries.
  int num_shards = 0;

  // Derives "endianness" and "version" from the first bundle merged (hence the
  // "seen_first_bundle" guard).  The two fields must be the same for all
  // bundles in a merge.
  bool seen_first_bundle = false;
  BundleHeaderProto_Endianness endianness;
  VersionDef version;

  // Tensor key -> BundleEntryProto.
  std::map<string, BundleEntryProto> entries;
  // Data file path -> new shard id in the final merged bundle.
  std::unordered_map<string, int32> shard_ids;
};

// Merges entries of "prefix" into the accumulator state "merge".
// Returns OK iff the merge succeeds.
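//
// For example (illustrative): merging two single-shard bundles yields a header
// with num_shards == 2, and each input data file is assigned a fresh shard id
// via "shard_ids" so it can later be renamed under the merged prefix.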
static Status MergeOneBundle(Env* env, StringPiece prefix,
                             MergeState* merge_state) {
  VLOG(1) << "Merging bundle:" << prefix;
  const string filename = MetaFilename(prefix);
  uint64 file_size;
  TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));
  std::unique_ptr<RandomAccessFile> file;
  TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));

  table::Table* table = nullptr;
  TF_RETURN_IF_ERROR(
      table::Table::Open(TableBuilderOptions(), file.get(), file_size, &table));
  std::unique_ptr<table::Table> table_deleter(table);
  std::unique_ptr<table::Iterator> iter(table->NewIterator());

  int num_shards;
  // Process header.
  {
    iter->Seek(kHeaderEntryKey);
    if (!iter->Valid()) {
      return CorruptFileError(iter->status(), filename,
                              "failed to seek to header entry");
    }
    BundleHeaderProto header;
    Status s = ParseEntryProto(iter->key(), iter->value(), &header);
    if (!s.ok()) return CorruptFileError(s, filename, "unable to parse header");

    merge_state->num_shards += header.num_shards();
    if (!merge_state->seen_first_bundle) {
      merge_state->seen_first_bundle = true;
      merge_state->endianness = header.endianness();
      merge_state->version = header.version();
    } else {
      // Validates "endianness".
      if (merge_state->endianness != header.endianness()) {
        return errors::InvalidArgument(
            "Merging bundles with conflicting endianness; inputs corrupted?");
      }
      // Validates "version".
      string curr_version, merge_version;
      header.version().SerializeToString(&curr_version);
      merge_state->version.SerializeToString(&merge_version);
      if (curr_version != merge_version) {
        return errors::InvalidArgument(
            "Merging bundles with different format versions: merged ",
            merge_version, " vs. curr ", curr_version);
      }
    }
    num_shards = header.num_shards();
    iter->Next();
  }

  // Loops through the non-header to-merge entries.
  BundleEntryProto to_merge_entry;
  for (; iter->Valid(); iter->Next()) {
    const string key(iter->key());
    const auto entry_iter = merge_state->entries.find(key);

    // Illegal: the duplicated entry is a non-slice tensor.
    if (entry_iter != merge_state->entries.end() &&
        entry_iter->second.slices().empty()) {
      return errors::InvalidArgument(
          "Duplicate tensor keyed by ", key,
          " encountered, when merging prefix: ", prefix);
    }

    TF_RETURN_IF_ERROR(
        ParseEntryProto(iter->key(), iter->value(), &to_merge_entry));

    // The duplicated entry holds metadata for a sliced full tensor.
    // Allows the duplication and merges "slices".
    if (entry_iter != merge_state->entries.end()) {
      BundleEntryProto& existing_entry = entry_iter->second;
      if (to_merge_entry.slices().empty()) {
        return errors::Internal(
            "Duplicate tensor keyed by ", key,
            "; attempting to merge in a non-slice bundle entry");
      }
      // Only needs to merge the "slices" field (and validate dtype/shape).
      for (int i = 0; i < to_merge_entry.slices_size(); ++i) {
        TensorSliceProto* slot = existing_entry.add_slices();
        *slot = to_merge_entry.slices(i);
      }
      CHECK_EQ(existing_entry.dtype(), to_merge_entry.dtype());
      CHECK(TensorShape(existing_entry.shape()) ==
            TensorShape(to_merge_entry.shape()));
      continue;
    }

    // Key doesn't duplicate: a fresh tensor/slice entry.
    auto result = merge_state->shard_ids.insert(
        {DataFilename(prefix, to_merge_entry.shard_id(), num_shards),
         merge_state->shard_ids.size()});
    to_merge_entry.set_shard_id(result.first->second);
    merge_state->entries[key] = to_merge_entry;
  }
  return Status::OK();
}

Status MergeBundles(Env* env, gtl::ArraySlice<tstring> prefixes,
                    StringPiece merged_prefix) {
  // Merges all metadata tables.
  // TODO(zhifengc): KeyValue sorter if it becomes too big.
  MergeState merge;
  Status status = env->CreateDir(string(io::Dirname(merged_prefix)));
  if (!status.ok() && !errors::IsAlreadyExists(status)) return status;
  for (int i = 0; i < prefixes.size(); ++i) {
    TF_RETURN_IF_ERROR(MergeOneBundle(env, prefixes[i], &merge));
  }

  // Renames data files to contain the merged bundle prefix.
  for (const auto& p : merge.shard_ids) {
    VLOG(1) << "Renaming " << p.first << " to "
            << DataFilename(merged_prefix, p.second, merge.shard_ids.size());
    TF_RETURN_IF_ERROR(env->RenameFile(
        p.first,
        DataFilename(merged_prefix, p.second, merge.shard_ids.size())));
  }

  // Writes the final metadata table under the merged prefix.
  std::unique_ptr<WritableFile> merged_metadata;
  TF_RETURN_IF_ERROR(
      env->NewWritableFile(MetaFilename(merged_prefix), &merged_metadata));
  {
    table::TableBuilder builder(TableBuilderOptions(), merged_metadata.get());
    // Header entry.
    BundleHeaderProto header;
    header.set_num_shards(merge.num_shards);
    header.set_endianness(merge.endianness);
    *header.mutable_version() = merge.version;
    builder.Add(kHeaderEntryKey, header.SerializeAsString());
    // All others.
    for (const auto& p : merge.entries) {
      builder.Add(p.first, p.second.SerializeAsString());
    }
    status = builder.Finish();
  }
  status.Update(merged_metadata->Close());
  if (!status.ok()) return status;
  VLOG(1) << "Merged bundles to:" << merged_prefix;

  // Cleanup: best effort based and ignores errors.
  for (const tstring& prefix : prefixes) {
    env->DeleteFile(MetaFilename(prefix)).IgnoreError();
  }
  return status;
}

// Interface for reading a tensor bundle.

BundleReader::BundleReader(Env* env, StringPiece prefix)
    : env_(env),
      prefix_(prefix),
      metadata_(nullptr),
      table_(nullptr),
      index_cache_(nullptr),
      iter_(nullptr),
      need_to_swap_bytes_(false) {
  const string filename = MetaFilename(prefix_);
  uint64 file_size;
  status_ = env_->GetFileSize(filename, &file_size);
  if (!status_.ok()) return;

  // Opens the metadata table.
  std::unique_ptr<RandomAccessFile> wrapper;
  status_ = env_->NewRandomAccessFile(filename, &wrapper);
  if (!status_.ok()) return;
  metadata_ = wrapper.release();

  table::Options o;
  int64_t cache_size;
  Status s =
      ReadInt64FromEnvVar("TF_TABLE_INDEX_CACHE_SIZE_IN_MB", 0, &cache_size);
  if (s.ok() && cache_size > 0) {
    index_cache_ = table::NewLRUCache(cache_size << 20);
    o.block_cache = index_cache_;
  }

  status_ = table::Table::Open(o, metadata_, file_size, &table_);
  if (!status_.ok()) return;
  iter_ = table_->NewIterator();

  // Reads "num_shards_" from the first entry.
  iter_->Seek(kHeaderEntryKey);
  if (!iter_->Valid()) {
    status_ = CorruptFileError(iter_->status(), filename,
                               "failed to seek to header entry");
    return;
  }
  BundleHeaderProto header;
  status_ = ParseEntryProto(iter_->key(), iter_->value(), &header);
  if (!status_.ok()) {
    status_ = CorruptFileError(status_, filename, "unable to parse header");
    return;
  }
  num_shards_ = header.num_shards();
  if ((header.endianness() == BundleHeaderProto::BIG && port::kLittleEndian) ||
      (header.endianness() == BundleHeaderProto::LITTLE &&
       !port::kLittleEndian)) {
    need_to_swap_bytes_ = true;
  }
  status_ = CheckVersions(header.version(), kTensorBundleVersion,
                          kTensorBundleMinProducer, "Checkpoint", "checkpoint");
}

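// A minimal usage sketch of the reader (illustrative; tensor_bundle.h carries
// the authoritative API comments). The prefix and key below are hypothetical:
//
//   BundleReader reader(Env::Default(), "/tmp/model.ckpt");
//   TF_CHECK_OK(reader.status());
//   Tensor t;
//   TF_CHECK_OK(reader.Lookup("weights", &t));
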
BundleReader::~BundleReader() {
  delete metadata_;
  delete iter_;
  delete table_;
  if (index_cache_) {
    delete index_cache_;
  }
  // InputBuffer does not own the underlying RandomAccessFile.
  for (auto pair : data_) {
    if (pair.second != nullptr && pair.second->file() != nullptr) {
      delete pair.second->file();
    }
  }
  for (auto& temp : data_) {
    delete temp.second;
  }
  for (auto& temp : tensor_slices_) {
    delete temp.second;
  }
  data_.clear();
  tensor_slices_.clear();
}

Status BundleReader::GetBundleEntryProto(StringPiece key,
                                         BundleEntryProto* entry) {
  entry->Clear();
  TF_CHECK_OK(status_);
  Seek(key);
  if (!iter_->Valid() || iter_->key() != key) {
    return errors::NotFound("Key ", key, " not found in checkpoint");
  }

  BundleEntryProto entry_copy;
  TF_RETURN_IF_ERROR(
      ParseEntryProto(iter_->key(), iter_->value(), &entry_copy));
  if (!TensorShape::IsValid(entry_copy.shape())) {
    return errors::DataLoss("Invalid tensor shape: ", key, " ",
                            entry_copy.shape().ShortDebugString());
  }

  entry->Swap(&entry_copy);
  return Status::OK();
}

Status BundleReader::GetValue(const BundleEntryProto& entry, Tensor* val) {
  Tensor* ret = val;
  const TensorShape stored_shape(TensorShape(entry.shape()));
  if (val->NumElements() == 0) {
    ret = new Tensor(entry.dtype(), stored_shape);
  }

  // Validates the "size" field.
  if (entry.dtype() != DT_STRING && entry.dtype() != DT_VARIANT) {
    if (entry.size() != ret->TotalBytes()) {
      return errors::DataLoss("Invalid size in bundle entry: key ", key(),
                              "; stored size ", entry.size(),
                              "; expected size ", ret->TotalBytes());
    }
  } else if (entry.dtype() == DT_STRING) {
    // Relaxes the check for string tensors as follows:
    //   entry.size() == bytes(varint lengths) + bytes(data)
    //                >= NumElems + bytes(data), since size bytes(varint) >= 1.
    //   TotalBytes() == sizeof(tstring) * NumElems + bytes(data)
    // Since we don't know bytes(varint lengths), we just check an inequality.
    const size_t lower_bound = ret->NumElements() + ret->TotalBytes() -
                               sizeof(tstring) * ret->NumElements();
    if (entry.size() < lower_bound) {
      return errors::DataLoss("Invalid size in bundle entry: key ", key(),
                              "; stored size ", entry.size(),
                              "; expected size is at least ", lower_bound);
    }
  }

  // Open the data file if it has not been opened.
  io::InputBuffer* buffered_file = data_[entry.shard_id()];
  if (buffered_file == nullptr) {
    std::unique_ptr<RandomAccessFile> file = nullptr;
    TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(
        DataFilename(prefix_, entry.shard_id(), num_shards_), &file));
    buffered_file = new io::InputBuffer(file.release(), kBufferSize);
    // The InputBuffer and RandomAccessFile objects are both released in dtor.
    data_[entry.shard_id()] = buffered_file;
  }
  CHECK(buffered_file != nullptr);

  TF_RETURN_IF_ERROR(buffered_file->Seek(entry.offset()));
  uint32 actual_crc32c = 0;

  if (DataTypeCanUseMemcpy(entry.dtype())) {
    char* backing_buffer = const_cast<char*>((ret->tensor_data().data()));
    size_t unused_bytes_read;
    if (entry.size() > kBufferSize) {
      StringPiece sp;
      TF_RETURN_IF_ERROR(buffered_file->file()->Read(
          entry.offset(), entry.size(), &sp, backing_buffer));
      if (sp.data() != backing_buffer) {
        memmove(backing_buffer, sp.data(), entry.size());
      }
    } else {
      TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(entry.size(), backing_buffer,
                                                   &unused_bytes_read));
    }
    // Note that we compute the checksum *before* byte-swapping. The checksum
    // should be on the bytes in the order they appear in the file.
    actual_crc32c = crc32c::Value(backing_buffer, entry.size());
    if (need_to_swap_bytes_) {
      TF_RETURN_IF_ERROR(ByteSwapTensor(ret));
    }
  } else if (entry.dtype() == DT_VARIANT) {
    if (need_to_swap_bytes_) {
      return errors::Unimplemented(
          "TensorBundle at ", prefix_,
          " is of a different endianness than this machine's hardware, and "
          "the bundle contains a variant (arbitrary C++ type) tensor. "
          "Byte-swapping of variant tensors is not currently implemented.");
    }
    // Relies on io::InputBuffer's buffering, because we issue many neighboring
    // reads for a single string tensor.
    TF_RETURN_IF_ERROR(ReadVariantTensor(buffered_file, ret, entry.offset(),
                                         entry.size(), &actual_crc32c));
  } else {
    // Relies on io::InputBuffer's buffering, because we issue many neighboring
    // reads for a single string tensor.
    TF_RETURN_IF_ERROR(ReadStringTensor(
        buffered_file, ret->NumElements(), entry.offset(), entry.size(),
        GetStringBackingBuffer(*ret), &actual_crc32c, need_to_swap_bytes_));
  }
  if (crc32c::Unmask(entry.crc32c()) != actual_crc32c) {
    return errors::DataLoss(
        "TensorBundle at ", prefix_, " shard ", entry.shard_id(), " (",
        entry.size(), " bytes): Checksum does not match: stored ",
        strings::Printf("%08u", crc32c::Unmask(entry.crc32c())),
        " vs. calculated on the restored bytes ", actual_crc32c);
  }

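  // If a scratch tensor was allocated above (because "val" was empty), this
  // assignment cheaply shares the scratch tensor's buffer with "val"; the
  // scratch Tensor object itself is then deleted.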
  *val = *ret;
  if (ret != val) delete ret;
  return Status::OK();
}

Status BundleReader::Lookup(StringPiece key, Tensor* val) {
  CHECK(val != nullptr);
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));

  if (entry.slices().empty()) {
    return GetValue(entry, val);
  } else {
    return GetSliceValue(
        key, entry,
        /* a full slice */ TensorSlice(TensorShape(entry.shape()).dims()), val);
  }
}

Status BundleReader::ReadCurrent(Tensor* val) {
  CHECK(val != nullptr);
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(ParseEntryProto(iter_->key(), iter_->value(), &entry));
  if (!TensorShape::IsValid(entry.shape())) {
    return errors::DataLoss("Invalid tensor shape: ", iter_->key(), " ",
                            entry.shape().ShortDebugString());
  }

  if (entry.slices().empty()) {
    return GetValue(entry, val);
  } else {
    return GetSliceValue(
        iter_->key(), entry,
        /* a full slice */ TensorSlice(TensorShape(entry.shape()).dims()), val);
  }
}

Status BundleReader::LookupTensorSlices(StringPiece key,
                                        std::vector<TensorSlice>* slices) {
  slices->clear();
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));
  slices->reserve(entry.slices_size());
  for (const auto& slice : entry.slices()) {
    slices->emplace_back(slice);
  }
  return Status::OK();
}

Status BundleReader::LookupSlice(StringPiece full_tensor_key,
                                 const TensorSlice& slice_spec, Tensor* val) {
  CHECK(val != nullptr);
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(GetBundleEntryProto(full_tensor_key, &entry));
  return GetSliceValue(full_tensor_key, entry, slice_spec, val);
}

Status BundleReader::GetSliceValue(StringPiece full_tensor_key,
                                   const BundleEntryProto& full_tensor_entry,
                                   const TensorSlice& slice_spec, Tensor* val) {
  using checkpoint::RegisterTensorSlice;
  using checkpoint::TensorSliceSet;
  DCHECK_GE(full_tensor_entry.slices_size(), 0);

  const TensorShape full_shape(TensorShape(full_tensor_entry.shape()));
  std::vector<std::pair<TensorSlice, string>> details;
  const string full_tensor_key_string(full_tensor_key);
  const TensorSliceSet* tss =
      gtl::FindPtrOrNull(tensor_slices_, full_tensor_key_string);

  // Populates the "full tensor key -> TensorSliceSet" cache.
  if (tss == nullptr) {
    if (full_tensor_entry.slices().empty()) {
      // Special case: a writer has saved a tensor fully, but the reader wants
      // to read in slices.  We therefore register the full slice on-demand here
      // without further complicating the on-disk bundle format.
      TF_RETURN_IF_ERROR(RegisterTensorSlice(
          full_tensor_key_string, full_shape, full_tensor_entry.dtype(),
          /* tag */ "",
          /* full slice */ TensorSlice(full_shape.dims()), &tensor_slices_));
    }
    for (const TensorSliceProto& slice : full_tensor_entry.slices()) {
      TF_RETURN_IF_ERROR(RegisterTensorSlice(
          full_tensor_key_string, full_shape, full_tensor_entry.dtype(),
          /* tag */ "", TensorSlice(slice), &tensor_slices_));
    }
    tss = gtl::FindPtrOrNull(tensor_slices_, full_tensor_key_string);
    CHECK_NE(tss, nullptr);
  }
  if (!tss->QueryMeta(slice_spec, &details)) {
    return errors::InvalidArgument(
        "Does not have sufficient slices for partitioned tensor ",
        full_tensor_key,
        " to restore in slice_spec: ", slice_spec.DebugString());
  }

  // The union of the slices in "details" covers "slice_spec".  Performs the
  // copies from each.
  BundleEntryProto stored_slice_entry = full_tensor_entry;
  for (const auto& slice_tag_pair : details) {
    // Seeks for the stored slice.
    const TensorSlice& stored_slice = slice_tag_pair.first;

    // We already have the entry for the full tensor, so don't query again if
    // the slice is full.
    if (!stored_slice.IsFull()) {
      const string encoded_stored_slice_name =
          checkpoint::EncodeTensorNameSlice(full_tensor_key_string,
                                            stored_slice);
      status_ =
          GetBundleEntryProto(encoded_stored_slice_name, &stored_slice_entry);
      if (!status_.ok()) return status_;
    }

    // TODO(zongheng): should we take an OpKernelContext, so that we can call
    // allocate_temp()?  Note that without major refactorings to Saver, it's
    // hard for the caller of the tensor bundle module to allocate these
    // precisely-shaped scratch storage.

    // Optimization for the common case: the stored slice can be directly
    // copied to the destination without additional slicing. This is true when
    // either the slices are equal or when they are both full slices having the
    // same shape.
    TensorShape stored_slice_shape(stored_slice_entry.shape());
    if (stored_slice == slice_spec ||
        (stored_slice_shape == val->shape() &&
         IsFullSlice(stored_slice, stored_slice_shape) &&
         IsFullSlice(slice_spec, stored_slice_shape))) {
      VLOG(1) << "Optimized for common case: directly copying into "
                 "pre-allocated buffer; spec: "
              << slice_spec.DebugString();
      status_ = GetValue(stored_slice_entry, val);
      return status_;
    }

    Tensor stored_slice_tensor(stored_slice_entry.dtype(), stored_slice_shape);
    status_ = GetValue(stored_slice_entry, &stored_slice_tensor);
    if (!status_.ok()) return status_;

    // Copies the intersection over.
    const DataType common_dtype = full_tensor_entry.dtype();
    switch (common_dtype) {
#define HANDLE_COPY(T)                                                 \
  case DataTypeToEnum<T>::value:                                       \
    CHECK(CopyDataFromTensorSliceToTensorSlice(                        \
        full_shape, stored_slice, slice_spec,                          \
        stored_slice_tensor.flat<T>().data(), val->flat<T>().data())); \
    break;

      HANDLE_COPY(float)
      HANDLE_COPY(double)
      HANDLE_COPY(int32)
      HANDLE_COPY(uint8)
      HANDLE_COPY(int16)
      HANDLE_COPY(int8)
      HANDLE_COPY(complex64)
      HANDLE_COPY(complex128)
      HANDLE_COPY(int64)
      HANDLE_COPY(bool)
      HANDLE_COPY(qint32)
      HANDLE_COPY(quint8)
      HANDLE_COPY(qint8)
      HANDLE_COPY(bfloat16)
      default:
        return errors::InvalidArgument("Dtype ", DataTypeString(common_dtype),
                                       " not supported.");
    }
#undef HANDLE_COPY
  }
  return Status::OK();
}

bool BundleReader::Contains(StringPiece key) {
  Seek(key);
  return Valid() && (this->key() == key);
}

Status BundleReader::LookupDtypeAndShape(StringPiece key, DataType* dtype,
                                         TensorShape* shape) {
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));
  *dtype = entry.dtype();
  *shape = TensorShape(entry.shape());
  return Status::OK();
}

Status BundleReader::LookupTensorShape(StringPiece key, TensorShape* shape) {
  DataType ignored;
  return LookupDtypeAndShape(key, &ignored, shape);
}

string BundleReader::DebugString() {
  // Format used below emulates that of TensorSliceReader::DebugString().
  string shape_str;
  BundleEntryProto entry;
  Seek(kHeaderEntryKey);
  for (Next(); Valid(); Next()) {
    CHECK(entry.ParseFromArray(value().data(), value().size()));
    if (entry.slices_size() > 0) continue;  // Slice of some partitioned var.

    strings::StrAppend(&shape_str, key(), " (", DataType_Name(entry.dtype()),
                       ") ", TensorShape(entry.shape()).DebugString());
    strings::StrAppend(&shape_str, "\n");
  }
  return shape_str;
}

namespace {
inline char* AlignedMalloc(size_t size) {
  char* buffer = static_cast<char*>(port::AlignedMalloc(size, 64));
  DCHECK(buffer);
  return buffer;
}
}  // namespace

FileOutputBuffer::FileOutputBuffer(WritableFile* file, size_t buffer_size)
    : file_(file), position_(0), buffer_size_(buffer_size) {
  DCHECK_GT(buffer_size, 0);
  buffer_ptr_ = AlignedMalloc(buffer_size);
}

FileOutputBuffer::~FileOutputBuffer() {
  if (buffer_ptr_) port::AlignedFree(buffer_ptr_);
  delete file_;
}

Status FileOutputBuffer::Append(StringPiece data) {
  // In the below, it is critical to calculate the checksum on the actually
  // copied bytes, not the source bytes.  This is because "data" typically
  // points to tensor buffers, which may be concurrently written.
  if (data.size() + position_ <= buffer_size_) {
    // Can fit into the current buffer.
    memcpy(buffer_ptr_ + position_, data.data(), data.size());
    crc32c_ = crc32c::Extend(crc32c_, buffer_ptr_ + position_, data.size());
  } else if (data.size() <= buffer_size_) {
    // Cannot fit, but can fit after flushing.
    TF_RETURN_IF_ERROR(FlushBuffer(false));
    memcpy(buffer_ptr_, data.data(), data.size());
    crc32c_ = crc32c::Extend(crc32c_, buffer_ptr_, data.size());
  } else {
    // Cannot fit even after flushing.  So we break down "data" by chunk, and
    // flush/checksum each chunk.
    TF_RETURN_IF_ERROR(FlushBuffer(false));
    for (size_t i = 0; i < data.size(); i += buffer_size_) {
      const size_t nbytes = std::min(data.size() - i, buffer_size_);
      memcpy(buffer_ptr_, data.data() + i, nbytes);
      crc32c_ = crc32c::Extend(crc32c_, buffer_ptr_, nbytes);
      position_ = nbytes;
      TF_RETURN_IF_ERROR(FlushBuffer(false));
    }
    return Status::OK();
  }
  position_ += data.size();
  return Status::OK();
}
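
// Worked example of the chunked path above (illustrative): with
// buffer_size_ == 8, appending 20 bytes first flushes the existing buffer,
// then copies, checksums, and flushes chunks of 8, 8, and 4 bytes in turn.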

Status FileOutputBuffer::Close() {
  TF_RETURN_IF_ERROR(FlushBuffer(true));
  return file_->Close();
}

Status FileOutputBuffer::FlushBuffer(bool closing) {
  if (position_ > 0) {
    // Use Cord to avoid extra data copy for some WritableFile implementations.
    absl::Cord buffer = absl::MakeCordFromExternal(
        StringPiece(buffer_ptr_, position_),
        [ptr = buffer_ptr_](StringPiece) { port::AlignedFree(ptr); });
    buffer_ptr_ = closing ? nullptr : AlignedMalloc(buffer_size_);
    TF_RETURN_IF_ERROR(file_->Append(buffer));
    position_ = 0;
  }
  return Status::OK();
}

}  // namespace tensorflow