/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"

#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <utility>

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/framework/versions.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/crc32c.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/io/table_builder.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/bfloat16.h"
#include "tensorflow/core/platform/cord.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
#include "tensorflow/core/util/tensor_bundle/byte_swap.h"
#include "tensorflow/core/util/tensor_slice_util.h"

#ifdef PLATFORM_WINDOWS
#undef DeleteFile
#endif

namespace tensorflow {

// Versioning of the tensor bundle format.
const int kTensorBundleMinProducer = 0;
const int kTensorBundleMinConsumer = 0;
const int kTensorBundleVersion = 1;

// Size of our input buffer for streaming reads.
static const int kBufferSize = 1024 * 1024;

// Key to the special BundleHeaderProto entry.  Do not change this, as clients
// can make the assumption that the header is always the first entry in the
// bundle.  (The empty string sorts before every other key, so the header is
// guaranteed to be the first entry in the sorted metadata table.)
const char* const kHeaderEntryKey = "";

namespace {

// Reads "num_elements" string elements from file[offset, offset+size) into the
// length-N "destination".  Discards the original content of "destination".
//
// Checksums the string lengths (as restored uint32 or uint64, not varint64
// bytes) and string bytes, and stores it into "actual_crc32c".
Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements,
                        size_t offset, size_t size, tstring* destination,
                        uint32* actual_crc32c, bool need_to_swap_bytes) {
  if (size == 0) return OkStatus();
  CHECK_GT(size, 0);

  // Reads "num_elements" varint64's from "buffered_file".
  TF_RETURN_IF_ERROR(buffered_file->Seek(offset));
  TF_RETURN_IF_ERROR(buffered_file->Hint(size));
  std::vector<uint64> string_lengths(num_elements);
  for (size_t i = 0; i < num_elements; ++i) {
    TF_RETURN_IF_ERROR(buffered_file->ReadVarint64(&string_lengths[i]));
    if (string_lengths[i] <= UINT32_MAX) {
      // We need to do this because older checkpoints only used uint32s and we
      // should still support them.
      uint32 elem_size_uint32 = static_cast<uint32>(string_lengths[i]);
      if (need_to_swap_bytes) {
        // The checksum was computed with the source machine's byte order.
        elem_size_uint32 = BYTE_SWAP_32(elem_size_uint32);
      }
      *actual_crc32c = crc32c::Extend(
          *actual_crc32c, reinterpret_cast<const char*>(&elem_size_uint32),
          sizeof(uint32));
    } else {
      uint64 length = string_lengths[i];
      if (need_to_swap_bytes) {
        length = BYTE_SWAP_64(length);
      }
      *actual_crc32c =
          crc32c::Extend(*actual_crc32c, reinterpret_cast<const char*>(&length),
                         sizeof(uint64));
    }
  }
  if (offset + size < buffered_file->Tell()) {
    return errors::DataLoss("String lengths longer than expected offset ",
                            offset + size);
  }

  // Reads the length-checksum.
  uint32 raw_length_checksum = 0;  // Bytes in file
  uint32 length_checksum = 0;      // In-memory representation
  size_t unused_bytes_read = 0;
  TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(
      sizeof(uint32), reinterpret_cast<char*>(&raw_length_checksum),
      &unused_bytes_read));
  length_checksum = need_to_swap_bytes ? BYTE_SWAP_32(raw_length_checksum)
                                       : raw_length_checksum;
  if (crc32c::Unmask(length_checksum) != *actual_crc32c) {
    return errors::DataLoss(
        "The length checksum does not match: expected ",
        strings::Printf("%08u", crc32c::Unmask(length_checksum)),
        " but actual is ", strings::Printf("%08u", *actual_crc32c));
  }
  *actual_crc32c = crc32c::Extend(*actual_crc32c,
                                  reinterpret_cast<char*>(&raw_length_checksum),
                                  sizeof(uint32));

  // Reads the actual string bytes.
  for (size_t i = 0; i < num_elements; ++i) {
    const uint64 string_length = string_lengths[i];
    tstring* buffer = &destination[i];

    buffer->resize(string_length);
    size_t bytes_read = 0;
    TF_RETURN_IF_ERROR(
        buffered_file->ReadNBytes(string_length, &(*buffer)[0], &bytes_read));
    *actual_crc32c = crc32c::Extend(*actual_crc32c, buffer->data(), bytes_read);
  }
  return OkStatus();
}

Status ReadVariantTensor(io::InputBuffer* buffered_file, Tensor* ret,
                         size_t offset, size_t size, uint32* actual_crc32c) {
  // On-disk format:
  //   [varint64 len1][bytes variant1][4 byte checksum]
  //   ..
  //   [varint64 lenN][bytes variantN][4 byte checksum]
  // Var "crc32c" checksums all the lens, variant bytes, individual variant
  // checksums (as uint32, not varint32 bytes).
  if (size == 0) return OkStatus();
  size_t num_elements = ret->NumElements();

  // Reads the actual string bytes.
  TF_RETURN_IF_ERROR(buffered_file->Seek(offset));
  TF_RETURN_IF_ERROR(buffered_file->Hint(size));
  for (size_t i = 0; i < num_elements; ++i) {
    // Read the serialized variant length.
    uint64 string_length = 0;
    TF_RETURN_IF_ERROR(buffered_file->ReadVarint64(&string_length));
    *actual_crc32c = crc32c::Extend(
        *actual_crc32c, reinterpret_cast<const char*>(&string_length),
        sizeof(uint64));
    // Read the actual serialized variant.
    string buffer;
    buffer.resize(string_length);
    size_t bytes_read = 0;
    TF_RETURN_IF_ERROR(
        buffered_file->ReadNBytes(string_length, &buffer[0], &bytes_read));
    *actual_crc32c = crc32c::Extend(*actual_crc32c, buffer.data(), bytes_read);
    VariantTensorDataProto proto;
    if (!proto.ParseFromString(buffer)) {
      return errors::DataLoss("Unable to parse VariantTensorDataProto from ",
                              "buffer of size ", string_length, ". ",
                              "Bundle entry offset: ", offset, " size: ", size);
    }
    Variant v = proto;
    if (!DecodeUnaryVariant(&v)) {
      return errors::Internal("Could not decode variant with type_name: \"",
                              v.TypeName(), "\".  Perhaps you forgot to ",
                              "register a decoder via ",
                              "REGISTER_UNARY_VARIANT_DECODE_FUNCTION?");
    }

    // Read the checksum.
    uint32 checksum = 0;
    size_t unused_bytes_read = 0;
    TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(
        sizeof(uint32), reinterpret_cast<char*>(&checksum),
        &unused_bytes_read));
    if (crc32c::Unmask(checksum) != *actual_crc32c) {
      return errors::DataLoss(
          "The checksum after Variant ", i, " does not match.",
          " Expected: ", strings::Printf("%08u", crc32c::Unmask(checksum)),
          " Actual: ", strings::Printf("%08u", *actual_crc32c));
    }
    *actual_crc32c = crc32c::Extend(
        *actual_crc32c, reinterpret_cast<char*>(&checksum), sizeof(uint32));

    ret->flat<Variant>()(i) = std::move(v);
  }

  return OkStatus();
}

char* GetBackingBuffer(const Tensor& val) {
  CHECK(DataTypeCanUseMemcpy(val.dtype())) << val.dtype();
  return const_cast<char*>(val.tensor_data().data());
}

tstring* GetStringBackingBuffer(const Tensor& val) {
  CHECK_EQ(DT_STRING, val.dtype());
  return const_cast<tstring*>(val.flat<tstring>().data());
}

Status ParseEntryProto(StringPiece key, StringPiece value,
                       protobuf::MessageLite* out) {
  if (!out->ParseFromArray(value.data(), value.size())) {
    return errors::DataLoss("Entry for key ", key, " not parseable.");
  }
  return OkStatus();
}

// Serializes the data bytes of the non-string tensor "val".  Discards the
// original content of "bytes_written", and on OK updates it with number of
// bytes written.
// REQUIRES: val.dtype() != DT_STRING
Status WriteTensor(const Tensor& val, FileOutputBuffer* out,
                   size_t* bytes_written) {
  DCHECK_NE(val.dtype(), DT_STRING);
  DCHECK_NE(val.dtype(), DT_VARIANT);
  *bytes_written = val.TotalBytes();
  char* buf = GetBackingBuffer(val);
  VLOG(1) << "Appending " << *bytes_written << " bytes to file";
  return out->Append(StringPiece(buf, *bytes_written));
}

// Serializes string tensor "val".  "bytes_written" is treated in the same
// fashion as WriteTensor().
//
// Checksums all bytes written and stores it into "crc32c".
// REQUIRES: val.dtype() == DT_STRING
Status WriteStringTensor(const Tensor& val, FileOutputBuffer* out,
                         size_t* bytes_written, uint32* crc32c) {
  // On-disk format:
  //   [varint64 len0]..[varint64 lenL][4 byte cksum on lengths][string bytes]
  // Var "crc32c" checksums the string lengths (as uint64, not varint64 bytes),
  // the length-checksum, and all the string bytes.
  DCHECK_EQ(val.dtype(), DT_STRING);
  const tstring* strings = GetStringBackingBuffer(val);

  // Writes the varint lengths.
  string lengths;
  lengths.reserve(val.NumElements());  // At least 1 byte per element.
  *crc32c = 0;
  for (int64_t i = 0; i < val.NumElements(); ++i) {
    const tstring* elem = &strings[i];
    DCHECK_EQ(elem->size(), static_cast<uint64>(elem->size()));
    const uint64 elem_size = static_cast<uint64>(elem->size());

    core::PutVarint64(&lengths, elem_size);
    if (elem_size <= UINT32_MAX) {
      // We need to do this because older checkpoints only used uint32s and we
      // should still support them.
      const uint32 elem_size_uint32 = static_cast<uint32>(elem_size);
      *crc32c = crc32c::Extend(*crc32c,
                               reinterpret_cast<const char*>(&elem_size_uint32),
                               sizeof(uint32));
    } else {
      *crc32c = crc32c::Extend(
          *crc32c, reinterpret_cast<const char*>(&elem_size), sizeof(uint64));
    }
  }
  TF_RETURN_IF_ERROR(out->Append(lengths));
  *bytes_written = lengths.size();

  // Writes the length checksum.
  const uint32 length_checksum = crc32c::Mask(*crc32c);
  TF_RETURN_IF_ERROR(out->Append(StringPiece(
      reinterpret_cast<const char*>(&length_checksum), sizeof(uint32))));
  *crc32c = crc32c::Extend(
      *crc32c, reinterpret_cast<const char*>(&length_checksum), sizeof(uint32));
  *bytes_written += sizeof(uint32);

  // Writes all the string bytes out.
  for (int64_t i = 0; i < val.NumElements(); ++i) {
    const tstring* string = &strings[i];
    TF_RETURN_IF_ERROR(out->Append(*string));
    *bytes_written += string->size();
    *crc32c = crc32c::Extend(*crc32c, string->data(), string->size());
  }
  return OkStatus();
}
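
// Illustration of the on-disk layout above (a sketch, not normative): a
// DT_STRING tensor holding {"ab", "xyz"} would be serialized as
//   varint64(2) varint64(3) [4-byte masked crc32c of the two lengths] "ab" "xyz"
// and the "crc32c" value returned to the caller additionally covers the
// length-checksum bytes and the string bytes themselves.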

Status WriteVariantTensor(const Tensor& val, FileOutputBuffer* out,
                          size_t* bytes_written, uint32* crc32c) {
  // On-disk format:
  //   [varint64 len1][bytes variant1][4 byte checksum]
  //   ..
  //   [varint64 lenN][bytes variantN][4 byte checksum]
  // Var "crc32c" checksums all the lens, variant bytes, individual variant
  // checksums (as uint32, not varint32 bytes).
  DCHECK_EQ(val.dtype(), DT_VARIANT);

  *crc32c = 0;
  *bytes_written = 0;
  for (int64_t i = 0; i < val.NumElements(); ++i) {
    VariantTensorData data;
    val.flat<Variant>()(i).Encode(&data);
    VariantTensorDataProto proto;
    data.ToProto(&proto);
    string elem;
    if (!proto.SerializeToString(&elem)) {
      return errors::Unknown(
          "Failed to serialize tensor data of size ", proto.ByteSizeLong(),
          ". Tensor: ", val.flat<Variant>()(i).DebugString());
    }

    // Write the length of the serialized variant.
    DCHECK_EQ(elem.size(), static_cast<uint64>(elem.size()));
    const auto elem_size = static_cast<uint64>(elem.size());
    string len;
    core::PutVarint64(&len, elem_size);
    TF_RETURN_IF_ERROR(out->Append(len));
    *crc32c = crc32c::Extend(*crc32c, reinterpret_cast<const char*>(&elem_size),
                             sizeof(uint64));
    *bytes_written += len.size();

    // Write the serialized variant.
    TF_RETURN_IF_ERROR(out->Append(elem));
    *crc32c = crc32c::Extend(*crc32c, elem.data(), elem.size());
    *bytes_written += elem.size();

    // Write the checksum.
    const uint32 length_checksum = crc32c::Mask(*crc32c);
    TF_RETURN_IF_ERROR(out->Append(StringPiece(
        reinterpret_cast<const char*>(&length_checksum), sizeof(uint32))));
    *crc32c =
        crc32c::Extend(*crc32c, reinterpret_cast<const char*>(&length_checksum),
                       sizeof(uint32));
    *bytes_written += sizeof(uint32);
  }

  return OkStatus();
}

// Returns whether "slice_spec" is a full slice, with respect to the full shape.
//
// This can happen, say, when "slice_spec" is
// "TensorSlice(full_tensor_shape.dims())", or when it is "TensorSlice({{0,
// dim(0)}, ..., {0, dim(N)}})" -- a degenerate case we need to guard against.
bool IsFullSlice(const TensorSlice& slice_spec,
                 const TensorShape& full_tensor_shape) {
  if (slice_spec.IsFull()) {
    return true;
  } else {
    TensorShape sliced_shape;
    slice_spec.SliceTensorShape(full_tensor_shape, &sliced_shape).IgnoreError();
    return sliced_shape == full_tensor_shape;
  }
}
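
// For example (a sketch), with full_tensor_shape = [4, 5], both
// "TensorSlice(2)" (the spec "-:-") and "TensorSlice({{0, 4}, {0, 5}})" (the
// spec "0,4:0,5") denote the whole tensor, even though only the former
// reports IsFull().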

Status CorruptFileError(const Status& in_status, const string& filename,
                        const string& detail) {
  if (in_status.ok()) {
    return errors::Internal("Unable to read file (", filename,
                            "). Perhaps the file is corrupt or was produced by "
                            "a newer version of TensorFlow with format changes "
                            "(",
                            detail, ")");
  }
  return Status(
      in_status.code(),
      strings::StrCat("Unable to read file (", filename,
                      "). Perhaps the file is corrupt or was produced by a "
                      "newer version of TensorFlow with format changes (",
                      detail, "): ", in_status.error_message()));
}

table::Options TableBuilderOptions() {
  table::Options o;
  // Compressed tables cannot be read by TensorFlow releases prior to 1.1.
  // To smooth the transition, compressed writes are disabled for now
  // (version 1.2) with the intention that they will be enabled again at
  // some point (perhaps the 1.3 release?).
  o.compression = table::kNoCompression;
  return o;
}

// Writes zeros to the output buffer to align the next write to the requested
// alignment. "size" is the current size of the buffer and is updated to the
// new size.
Status PadAlignment(FileOutputBuffer* out, int alignment, int64_t* size) {
  int bytes_over = *size % alignment;
  if (bytes_over == 0) {
    return OkStatus();
  }
  int bytes_to_write = alignment - bytes_over;
  Status status = out->Append(string(bytes_to_write, '\0'));
  if (status.ok()) {
    *size += bytes_to_write;
  }
  return status;
}
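
// For example, with *size == 10 and alignment == 8, bytes_over == 2, so six
// zero bytes are appended and *size becomes 16; a *size that is already a
// multiple of the alignment is left unchanged.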

}  // namespace

BundleWriter::BundleWriter(Env* env, StringPiece prefix, const Options& options)
    : env_(env), options_(options), prefix_(prefix), out_(nullptr), size_(0) {
  status_ = env_->HasAtomicMove(prefix_, &use_temp_file_);
  if (!status_.ok()) return;

  data_path_ = DataFilename(prefix_, 0, 1);
  metadata_path_ = MetaFilename(prefix_);
  if (use_temp_file_) {
    data_path_ = strings::StrCat(data_path_, ".tempstate", random::New64());
    metadata_path_ =
        strings::StrCat(metadata_path_, ".tempstate", random::New64());
  }

  status_ = env_->CreateDir(string(io::Dirname(prefix_)));
  if (!status_.ok() && !errors::IsAlreadyExists(status_)) {
    return;
  }

  std::unique_ptr<WritableFile> wrapper;
  status_ = env_->NewWritableFile(data_path_, &wrapper);
  if (!status_.ok()) return;
  out_ = std::unique_ptr<FileOutputBuffer>(
      new FileOutputBuffer(wrapper.release(), 8 << 20 /* 8MB write buffer */));

  VLOG(1) << "Writing to file " << data_path_;
}

Status BundleWriter::Add(StringPiece key, const Tensor& val) {
  if (!status_.ok()) return status_;
  CHECK_NE(key, kHeaderEntryKey);
  const string key_string(key);
  if (entries_.find(key_string) != entries_.end()) {
    status_ = errors::InvalidArgument("Adding duplicate key: ", key);
    return status_;
  }

  BundleEntryProto* entry = &entries_[key_string];
  entry->set_dtype(val.dtype());
  val.shape().AsProto(entry->mutable_shape());
  entry->set_shard_id(0);
  entry->set_offset(size_);

  // Updates the data file.
  size_t data_bytes_written = 0;
  uint32 crc32c = 0;
  out_->clear_crc32c();
  if (val.dtype() == DT_STRING) {
    status_ = WriteStringTensor(val, out_.get(), &data_bytes_written, &crc32c);
  } else if (val.dtype() == DT_VARIANT) {
    status_ = WriteVariantTensor(val, out_.get(), &data_bytes_written, &crc32c);
  } else {
    status_ = WriteTensor(val, out_.get(), &data_bytes_written);
    crc32c = out_->crc32c();
  }

  if (status_.ok()) {
    entry->set_size(data_bytes_written);
    entry->set_crc32c(crc32c::Mask(crc32c));
    size_ += data_bytes_written;
    status_ = PadAlignment(out_.get(), options_.data_alignment, &size_);
  }
  return status_;
}

Status BundleWriter::AddSlice(StringPiece full_tensor_key,
                              const TensorShape& full_tensor_shape,
                              const TensorSlice& slice_spec,
                              const Tensor& slice_tensor) {
  if (!status_.ok()) return status_;
  CHECK_NE(full_tensor_key, kHeaderEntryKey);

  // If just a singleton full slice, use the regular Add() to be more efficient.
  if (IsFullSlice(slice_spec, full_tensor_shape)) {
    return Add(full_tensor_key, slice_tensor);
  }

  // Inserts/updates the full tensor's metadata entry.
  //
  // In the case of a sharded save, MergeBundles() is responsible for merging
  // the "slices" field of multiple metadata entries corresponding to the same
  // full tensor.
  const string full_tensor_key_string(full_tensor_key);
  BundleEntryProto* full_entry = &entries_[full_tensor_key_string];
  if (full_entry->dtype() != DT_INVALID) {
    CHECK_EQ(full_entry->dtype(), slice_tensor.dtype());
  }
  if (full_entry->has_shape()) {
    CHECK(TensorShape(full_entry->shape()) == full_tensor_shape);
  }

  // Populates dtype, shape, and slices.  Intentionally leaving out shard_id and
  // offset, which do not make sense for this full tensor entry.
  full_entry->set_dtype(slice_tensor.dtype());
  full_tensor_shape.AsProto(full_entry->mutable_shape());
  TensorSliceProto* slice_proto = full_entry->add_slices();
  slice_spec.AsProto(slice_proto);

  // The slice itself is handled by a regular Add(), which includes adding its
  // own metadata entry, and writing out the slice's values.
  const string slice_name =
      checkpoint::EncodeTensorNameSlice(full_tensor_key_string, slice_spec);
  status_ = Add(slice_name, slice_tensor);
  return status_;
}
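
// Illustrative use of AddSlice() (a hedged sketch; the key "var" and the
// tensors part0/part1 are placeholders): saving a [4, 5] tensor in two
// row-partitioned pieces might look like
//
//   TensorShape full_shape({4, 5});
//   TensorSlice slice0 = TensorSlice::ParseOrDie("0,2:-");  // rows [0, 2)
//   TensorSlice slice1 = TensorSlice::ParseOrDie("2,2:-");  // rows [2, 4)
//   TF_CHECK_OK(writer.AddSlice("var", full_shape, slice0, part0));
//   TF_CHECK_OK(writer.AddSlice("var", full_shape, slice1, part1));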

// TODO(zongheng): on metadata write failure or !status_.ok(), consider removing
// the orphaned data file.
Status BundleWriter::Finish() {
  if (out_) {
    status_.Update(out_->Close());
    out_ = nullptr;
    if (status_.ok()) {
      if (use_temp_file_) {
        status_ =
            Env::Default()->RenameFile(data_path_, DataFilename(prefix_, 0, 1));
      }
    } else {
      Env::Default()->DeleteFile(data_path_).IgnoreError();
    }
  }
  if (!status_.ok()) return status_;
  // Build key -> BundleEntryProto table.
  std::unique_ptr<WritableFile> file;
  status_ = env_->NewWritableFile(metadata_path_, &file);
  if (!status_.ok()) return status_;
  {
    // N.B.: the default use of Snappy compression may not be supported on all
    // platforms (e.g. Android).  The metadata file is small, so this is fine.
    table::Options options;
    options.compression = table::kNoCompression;
    table::TableBuilder builder(options, file.get());
    // Header entry.
    BundleHeaderProto header;
    header.set_num_shards(1);
    header.set_endianness(BundleHeaderProto::LITTLE);
    if (!port::kLittleEndian) header.set_endianness(BundleHeaderProto::BIG);
    VersionDef* version = header.mutable_version();
    version->set_producer(kTensorBundleVersion);
    version->set_min_consumer(kTensorBundleMinConsumer);

    builder.Add(kHeaderEntryKey, header.SerializeAsString());

    // All others.
    for (const auto& p : entries_) {
      builder.Add(p.first, p.second.SerializeAsString());
    }
    status_ = builder.Finish();
  }
  status_.Update(file->Close());
  if (!status_.ok()) {
    Env::Default()->DeleteFile(metadata_path_).IgnoreError();
    return status_;
  } else if (use_temp_file_) {
    status_ = Env::Default()->RenameFile(metadata_path_, MetaFilename(prefix_));
    if (!status_.ok()) return status_;
  }
  status_ = errors::Internal("BundleWriter is closed");
  return OkStatus();
}
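
// Illustrative end-to-end use of BundleWriter (a sketch; "/tmp/ckpt" and the
// tensor names are placeholders, and the default Options are assumed):
//
//   BundleWriter writer(Env::Default(), "/tmp/ckpt");
//   TF_CHECK_OK(writer.Add("var_a", tensor_a));
//   TF_CHECK_OK(writer.Add("var_b", tensor_b));
//   TF_CHECK_OK(writer.Finish());  // Writes the metadata and data files.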

// Merging tensor bundles.

// Accumulator of metadata states during a merge.
struct MergeState {
  // Accumulated from the header entries.
  int num_shards = 0;

  // Derives "endianness" and "version" from the first bundle merged (hence the
  // "seen_first_bundle" guard).  The two fields must be the same for all
  // bundles in a merge.
  bool seen_first_bundle = false;
  BundleHeaderProto_Endianness endianness;
  VersionDef version;

  // Tensor key -> BundleEntryProto.
  std::map<string, BundleEntryProto> entries;
  // Data file path -> new shard id in the final merged bundle.
  std::unordered_map<string, int32> shard_ids;
};

// Merges entries of "prefix" into the accumulator state "merge".
// Returns OK iff the merge succeeds.
static Status MergeOneBundle(Env* env, StringPiece prefix,
                             MergeState* merge_state) {
  VLOG(1) << "Merging bundle: " << prefix;
  const string filename = MetaFilename(prefix);
  uint64 file_size;
  TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));
  std::unique_ptr<RandomAccessFile> file;
  TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));

  table::Table* table = nullptr;
  TF_RETURN_IF_ERROR(
      table::Table::Open(TableBuilderOptions(), file.get(), file_size, &table));
  std::unique_ptr<table::Table> table_deleter(table);
  std::unique_ptr<table::Iterator> iter(table->NewIterator());

  int num_shards;
  // Process header.
  {
    iter->Seek(kHeaderEntryKey);
    if (!iter->Valid()) {
      return CorruptFileError(iter->status(), filename,
                              "failed to seek to header entry");
    }
    BundleHeaderProto header;
    Status s = ParseEntryProto(iter->key(), iter->value(), &header);
    if (!s.ok()) return CorruptFileError(s, filename, "unable to parse header");

    merge_state->num_shards += header.num_shards();
    if (!merge_state->seen_first_bundle) {
      merge_state->seen_first_bundle = true;
      merge_state->endianness = header.endianness();
      merge_state->version = header.version();
    } else {
      // Validates "endianness".
      if (merge_state->endianness != header.endianness()) {
        return errors::InvalidArgument(
            "Merging bundles with conflicting endianness; inputs corrupted?");
      }
      // Validates "version".
      string curr_version, merge_version;
      header.version().SerializeToString(&curr_version);
      merge_state->version.SerializeToString(&merge_version);
      if (curr_version != merge_version) {
        return errors::InvalidArgument(
            "Merging bundles with different format versions: merged ",
            merge_version, " vs. curr ", curr_version);
      }
    }
    num_shards = header.num_shards();
    iter->Next();
  }

  // Loops through the non-header to-merge entries.
  BundleEntryProto to_merge_entry;
  for (; iter->Valid(); iter->Next()) {
    const string key(iter->key());
    const auto entry_iter = merge_state->entries.find(key);

    // Illegal: the duplicated entry is a non-slice tensor.
    if (entry_iter != merge_state->entries.end() &&
        entry_iter->second.slices().empty()) {
      return errors::InvalidArgument(
          "Duplicate tensor keyed by ", key,
          " encountered, when merging prefix: ", prefix);
    }

    TF_RETURN_IF_ERROR(
        ParseEntryProto(iter->key(), iter->value(), &to_merge_entry));

    // The duplicated entry holds metadata for a sliced full tensor.
    // Allows the duplication and merges "slices".
    if (entry_iter != merge_state->entries.end()) {
      BundleEntryProto& existing_entry = entry_iter->second;
      if (to_merge_entry.slices().empty()) {
        return errors::Internal(
            "Duplicate tensor keyed by ", key,
            "; attempting to merge in a non-slice bundle entry");
      }
      // Only needs to merge the "slices" field (and validate dtype/shape).
      for (int i = 0; i < to_merge_entry.slices_size(); ++i) {
        TensorSliceProto* slot = existing_entry.add_slices();
        *slot = to_merge_entry.slices(i);
      }
      CHECK_EQ(existing_entry.dtype(), to_merge_entry.dtype());
      CHECK(TensorShape(existing_entry.shape()) ==
            TensorShape(to_merge_entry.shape()));
      continue;
    }

    // Key doesn't duplicate: a fresh tensor/slice entry.
    auto result = merge_state->shard_ids.insert(
        {DataFilename(prefix, to_merge_entry.shard_id(), num_shards),
         merge_state->shard_ids.size()});
    to_merge_entry.set_shard_id(result.first->second);
    merge_state->entries[key] = to_merge_entry;
  }
  return OkStatus();
}

Status MergeBundles(Env* env, gtl::ArraySlice<tstring> prefixes,
                    StringPiece merged_prefix, bool allow_missing_files) {
  // Merges all metadata tables.
  // TODO(zhifengc): KeyValue sorter if it becomes too big.
  MergeState merge;
  Status status = env->CreateDir(string(io::Dirname(merged_prefix)));
  if (!status.ok() && !errors::IsAlreadyExists(status)) return status;
  bool atleast_one_file_exists = false;
  for (auto& prefix : prefixes) {
    if (!env->FileExists(MetaFilename(prefix)).ok()) {
      if (allow_missing_files) continue;
      return errors::InvalidArgument(
          "allow_missing_files was set to false and ", prefix,
          " did not exist. ", env->FileExists(prefix).ToString());
    }
    atleast_one_file_exists = true;
    TF_RETURN_IF_ERROR(MergeOneBundle(env, prefix, &merge));
  }
  if (!atleast_one_file_exists) {
    return errors::InvalidArgument(
        "At least one prefix checkpoint file must exist, but none existed.");
  }
  // Renames data files to contain the merged bundle prefix.
  for (const auto& p : merge.shard_ids) {
    VLOG(1) << "Renaming " << p.first << " to "
            << DataFilename(merged_prefix, p.second, merge.shard_ids.size());
    TF_RETURN_IF_ERROR(env->RenameFile(
        p.first,
        DataFilename(merged_prefix, p.second, merge.shard_ids.size())));
  }

  // Writes the final metadata table under the merged prefix.
  std::unique_ptr<WritableFile> merged_metadata;
  TF_RETURN_IF_ERROR(
      env->NewWritableFile(MetaFilename(merged_prefix), &merged_metadata));
  {
    table::TableBuilder builder(TableBuilderOptions(), merged_metadata.get());
    // Header entry.
    BundleHeaderProto header;
    header.set_num_shards(merge.num_shards);
    header.set_endianness(merge.endianness);
    *header.mutable_version() = merge.version;
    builder.Add(kHeaderEntryKey, header.SerializeAsString());
    // All others.
    for (const auto& p : merge.entries) {
      builder.Add(p.first, p.second.SerializeAsString());
    }
    status = builder.Finish();
  }
  status.Update(merged_metadata->Close());
  if (!status.ok()) return status;
  VLOG(1) << "Merged bundles to: " << merged_prefix;

  // Cleanup: best-effort, ignoring errors.
  for (const tstring& prefix : prefixes) {
    env->DeleteFile(MetaFilename(prefix)).IgnoreError();
  }
  return status;
}
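
// Illustrative use of MergeBundles() (a sketch with placeholder paths): after
// two workers write the partial bundles "/tmp/ckpt_part-0" and
// "/tmp/ckpt_part-1", a single process merges them under one prefix:
//
//   TF_CHECK_OK(MergeBundles(Env::Default(),
//                            {"/tmp/ckpt_part-0", "/tmp/ckpt_part-1"},
//                            "/tmp/ckpt", /*allow_missing_files=*/false));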

// Interface for reading a tensor bundle.

BundleReader::BundleReader(Env* env, StringPiece prefix)
    : env_(env),
      prefix_(prefix),
      metadata_(nullptr),
      table_(nullptr),
      index_cache_(nullptr),
      iter_(nullptr),
      need_to_swap_bytes_(false) {
  const string filename = MetaFilename(prefix_);
  uint64 file_size;
  status_ = env_->GetFileSize(filename, &file_size);
  if (!status_.ok()) return;

  // Opens the metadata table.
  std::unique_ptr<RandomAccessFile> wrapper;
  status_ = env_->NewRandomAccessFile(filename, &wrapper);
  if (!status_.ok()) return;
  metadata_ = wrapper.release();

  table::Options o;
  int64_t cache_size;
  Status s =
      ReadInt64FromEnvVar("TF_TABLE_INDEX_CACHE_SIZE_IN_MB", 0, &cache_size);
  if (s.ok() && cache_size > 0) {
    index_cache_ = table::NewLRUCache(cache_size << 20);
    o.block_cache = index_cache_;
  }

  status_ = table::Table::Open(o, metadata_, file_size, &table_);
  if (!status_.ok()) return;
  iter_ = table_->NewIterator();

  // Reads "num_shards_" from the first entry.
  iter_->Seek(kHeaderEntryKey);
  if (!iter_->Valid()) {
    status_ = CorruptFileError(iter_->status(), filename,
                               "failed to seek to header entry");
    return;
  }
  BundleHeaderProto header;
  status_ = ParseEntryProto(iter_->key(), iter_->value(), &header);
  if (!status_.ok()) {
    status_ = CorruptFileError(status_, filename, "unable to parse header");
    return;
  }
  num_shards_ = header.num_shards();
  if ((header.endianness() == BundleHeaderProto::BIG && port::kLittleEndian) ||
      (header.endianness() == BundleHeaderProto::LITTLE &&
       !port::kLittleEndian)) {
    need_to_swap_bytes_ = true;
  }
  status_ = CheckVersions(header.version(), kTensorBundleVersion,
                          kTensorBundleMinProducer, "Checkpoint", "checkpoint");
}

BundleReader::~BundleReader() {
  delete metadata_;
  delete iter_;
  delete table_;
  if (index_cache_) {
    delete index_cache_;
  }
  // InputBuffer does not own the underlying RandomAccessFile.
  for (auto pair : data_) {
    if (pair.second != nullptr && pair.second->file() != nullptr) {
      delete pair.second->file();
    }
  }
  for (auto& temp : data_) {
    delete temp.second;
  }
  for (auto& temp : tensor_slices_) {
    delete temp.second;
  }
  data_.clear();
  tensor_slices_.clear();
}

Status BundleReader::GetBundleEntryProto(StringPiece key,
                                         BundleEntryProto* entry) {
  entry->Clear();
  TF_CHECK_OK(status_);
  Seek(key);
  if (!iter_->Valid() || iter_->key() != key) {
    return errors::NotFound("Key ", key, " not found in checkpoint");
  }

  BundleEntryProto entry_copy;
  TF_RETURN_IF_ERROR(
      ParseEntryProto(iter_->key(), iter_->value(), &entry_copy));
  if (!TensorShape::IsValid(entry_copy.shape())) {
    return errors::DataLoss("Invalid tensor shape: ", key, " ",
                            entry_copy.shape().ShortDebugString());
  }

  entry->Swap(&entry_copy);
  return OkStatus();
}

Status BundleReader::GetValue(const BundleEntryProto& entry, Tensor* val) {
  Tensor* ret = val;
  const TensorShape stored_shape(TensorShape(entry.shape()));
  if (val->NumElements() == 0) {
    ret = new Tensor(entry.dtype(), stored_shape);
  }

  // Validates the "size" field.
  if (entry.dtype() != DT_STRING && entry.dtype() != DT_VARIANT) {
    if (entry.size() != ret->TotalBytes()) {
      return errors::DataLoss("Invalid size in bundle entry: key ", key(),
                              "; stored size ", entry.size(),
                              "; expected size ", ret->TotalBytes());
    }
  } else if (entry.dtype() == DT_STRING) {
    // Relaxes the check for string tensors as follows:
    //   entry.size() == bytes(varint lengths) + bytes(data)
    //                >= NumElems + bytes(data), since size bytes(varint) >= 1.
    //   TotalBytes() == sizeof(tstring) * NumElems + bytes(data)
    // Since we don't know bytes(varint lengths), we just check an inequality.
    const size_t lower_bound = ret->NumElements() + ret->TotalBytes() -
                               sizeof(tstring) * ret->NumElements();
    if (entry.size() < lower_bound) {
      return errors::DataLoss("Invalid size in bundle entry: key ", key(),
                              "; stored size ", entry.size(),
                              "; expected size is at least ", lower_bound);
    }
  }

  // Open the data file if it has not been opened.
  io::InputBuffer* buffered_file = data_[entry.shard_id()];
  if (buffered_file == nullptr) {
    std::unique_ptr<RandomAccessFile> file = nullptr;
    TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(
        DataFilename(prefix_, entry.shard_id(), num_shards_), &file));
    buffered_file = new io::InputBuffer(file.release(), kBufferSize);
    // The InputBuffer and RandomAccessFile objects are both released in dtor.
    data_[entry.shard_id()] = buffered_file;
  }
  CHECK(buffered_file != nullptr);

  TF_RETURN_IF_ERROR(buffered_file->Seek(entry.offset()));
  uint32 actual_crc32c = 0;

  if (DataTypeCanUseMemcpy(entry.dtype())) {
    char* backing_buffer = const_cast<char*>((ret->tensor_data().data()));
    size_t unused_bytes_read;
    if (entry.size() > kBufferSize) {
      StringPiece sp;
      TF_RETURN_IF_ERROR(buffered_file->file()->Read(
          entry.offset(), entry.size(), &sp, backing_buffer));
      if (sp.data() != backing_buffer) {
        memmove(backing_buffer, sp.data(), entry.size());
      }
    } else {
      TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(entry.size(), backing_buffer,
                                                   &unused_bytes_read));
    }
    // Note that we compute the checksum *before* byte-swapping. The checksum
    // should be on the bytes in the order they appear in the file.
    actual_crc32c = crc32c::Value(backing_buffer, entry.size());
    if (need_to_swap_bytes_) {
      TF_RETURN_IF_ERROR(ByteSwapTensor(ret));
    }
  } else if (entry.dtype() == DT_VARIANT) {
    if (need_to_swap_bytes_) {
      return errors::Unimplemented(
          "TensorBundle at ", prefix_,
          " is of a different endianness than this machine's hardware, and "
          "the bundle contains a variant (arbitrary C++ type) tensor. "
          "Byte-swapping of variant tensors is not currently implemented.");
    }
    // Relies on io::InputBuffer's buffering, because we issue many neighboring
    // reads for a single string tensor.
    TF_RETURN_IF_ERROR(ReadVariantTensor(buffered_file, ret, entry.offset(),
                                         entry.size(), &actual_crc32c));
  } else {
    // Relies on io::InputBuffer's buffering, because we issue many neighboring
    // reads for a single string tensor.
    TF_RETURN_IF_ERROR(ReadStringTensor(
        buffered_file, ret->NumElements(), entry.offset(), entry.size(),
        GetStringBackingBuffer(*ret), &actual_crc32c, need_to_swap_bytes_));
  }
  if (crc32c::Unmask(entry.crc32c()) != actual_crc32c) {
    return errors::DataLoss(
        "TensorBundle at ", prefix_, " shard ", entry.shard_id(), " (",
        entry.size(), " bytes): Checksum does not match: stored ",
        strings::Printf("%08u", crc32c::Unmask(entry.crc32c())),
        " vs. calculated on the restored bytes ", actual_crc32c);
  }

  *val = *ret;
  if (ret != val) delete ret;
  return OkStatus();
}

Status BundleReader::Lookup(StringPiece key, Tensor* val) {
  CHECK(val != nullptr);
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));

  if (entry.slices().empty()) {
    return GetValue(entry, val);
  } else {
    return GetSliceValue(
        key, entry,
        /* a full slice */ TensorSlice(TensorShape(entry.shape()).dims()), val);
  }
}
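
// Illustrative read path (a sketch; the prefix and key are placeholders):
//
//   BundleReader reader(Env::Default(), "/tmp/ckpt");
//   TF_CHECK_OK(reader.status());
//   Tensor t;
//   TF_CHECK_OK(reader.Lookup("var_a", &t));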

Status BundleReader::ReadCurrent(Tensor* val) {
  CHECK(val != nullptr);
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(ParseEntryProto(iter_->key(), iter_->value(), &entry));
  if (!TensorShape::IsValid(entry.shape())) {
    return errors::DataLoss("Invalid tensor shape: ", iter_->key(), " ",
                            entry.shape().ShortDebugString());
  }

  if (entry.slices().empty()) {
    return GetValue(entry, val);
  } else {
    return GetSliceValue(
        iter_->key(), entry,
        /* a full slice */ TensorSlice(TensorShape(entry.shape()).dims()), val);
  }
}

Status BundleReader::LookupTensorSlices(StringPiece key,
                                        std::vector<TensorSlice>* slices) {
  slices->clear();
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));
  slices->reserve(entry.slices_size());
  for (const auto& slice : entry.slices()) {
    slices->emplace_back(slice);
  }
  return OkStatus();
}

Status BundleReader::LookupSlice(StringPiece full_tensor_key,
                                 const TensorSlice& slice_spec, Tensor* val) {
  CHECK(val != nullptr);
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(GetBundleEntryProto(full_tensor_key, &entry));
  return GetSliceValue(full_tensor_key, entry, slice_spec, val);
}

Status BundleReader::GetSliceValue(StringPiece full_tensor_key,
                                   const BundleEntryProto& full_tensor_entry,
                                   const TensorSlice& slice_spec, Tensor* val) {
  using checkpoint::RegisterTensorSlice;
  using checkpoint::TensorSliceSet;
  DCHECK_GE(full_tensor_entry.slices_size(), 0);

  const TensorShape full_shape(TensorShape(full_tensor_entry.shape()));
  std::vector<std::pair<TensorSlice, string>> details;
  const string full_tensor_key_string(full_tensor_key);
  const TensorSliceSet* tss =
      gtl::FindPtrOrNull(tensor_slices_, full_tensor_key_string);

  // Populates the "full tensor key -> TensorSliceSet" cache.
  if (tss == nullptr) {
    if (full_tensor_entry.slices().empty()) {
      // Special case: a writer has saved a tensor fully, but the reader wants
      // to read in slices.  We therefore register the full slice on-demand here
      // without further complicating the on-disk bundle format.
      TF_RETURN_IF_ERROR(RegisterTensorSlice(
          full_tensor_key_string, full_shape, full_tensor_entry.dtype(),
          /* tag */ "",
          /* full slice */ TensorSlice(full_shape.dims()), &tensor_slices_));
    }
    for (const TensorSliceProto& slice : full_tensor_entry.slices()) {
      TF_RETURN_IF_ERROR(RegisterTensorSlice(
          full_tensor_key_string, full_shape, full_tensor_entry.dtype(),
          /* tag */ "", TensorSlice(slice), &tensor_slices_));
    }
    tss = gtl::FindPtrOrNull(tensor_slices_, full_tensor_key_string);
    CHECK_NE(tss, nullptr);
  }
  if (!tss->QueryMeta(slice_spec, &details)) {
    return errors::InvalidArgument(
        "Does not have sufficient slices for partitioned tensor ",
        full_tensor_key,
        " to restore in slice_spec: ", slice_spec.DebugString());
  }

  // The union of the slices in "details" covers "slice_spec".  Performs the
  // copies from each.
  BundleEntryProto stored_slice_entry = full_tensor_entry;
  for (const auto& slice_tag_pair : details) {
    // Seeks for the stored slice.
    const TensorSlice& stored_slice = slice_tag_pair.first;

    // We already have the entry for the full tensor, so don't query again if
    // the slice is full.
    if (!stored_slice.IsFull()) {
      const string encoded_stored_slice_name =
          checkpoint::EncodeTensorNameSlice(full_tensor_key_string,
                                            stored_slice);
      status_ =
          GetBundleEntryProto(encoded_stored_slice_name, &stored_slice_entry);
      if (!status_.ok()) return status_;
    }

    // TODO(zongheng): should we take an OpKernelContext, so that we can call
    // allocate_temp()?  Note that without major refactorings to Saver, it's
    // hard for the caller of the tensor bundle module to allocate these
    // precisely-shaped scratch storage.

    // Optimization for the common case: the stored slice can be directly
    // copied to the destination without additional slicing. This is true when
    // either the slices are equal or when they are both full slices having the
    // same shape.
    TensorShape stored_slice_shape(stored_slice_entry.shape());
    if (stored_slice == slice_spec ||
        (stored_slice_shape == val->shape() &&
         IsFullSlice(stored_slice, stored_slice_shape) &&
         IsFullSlice(slice_spec, stored_slice_shape))) {
      VLOG(1) << "Optimized for common case: directly copying into "
                 "pre-allocated buffer; spec: "
              << slice_spec.DebugString();
      status_ = GetValue(stored_slice_entry, val);
      return status_;
    }

    Tensor stored_slice_tensor(stored_slice_entry.dtype(), stored_slice_shape);
    status_ = GetValue(stored_slice_entry, &stored_slice_tensor);
    if (!status_.ok()) return status_;

    // Copies the intersection over.
    const DataType common_dtype = full_tensor_entry.dtype();
    switch (common_dtype) {
#define HANDLE_COPY(T)                                                 \
  case DataTypeToEnum<T>::value:                                       \
    CHECK(CopyDataFromTensorSliceToTensorSlice(                        \
        full_shape, stored_slice, slice_spec,                          \
        stored_slice_tensor.flat<T>().data(), val->flat<T>().data())); \
    break;

      HANDLE_COPY(float)
      HANDLE_COPY(double)
      HANDLE_COPY(int32)
      HANDLE_COPY(uint8)
      HANDLE_COPY(int16)
      HANDLE_COPY(int8)
      HANDLE_COPY(complex64)
      HANDLE_COPY(complex128)
      HANDLE_COPY(int64_t)
      HANDLE_COPY(bool)
      HANDLE_COPY(qint32)
      HANDLE_COPY(quint8)
      HANDLE_COPY(qint8)
      HANDLE_COPY(bfloat16)
      default:
        return errors::InvalidArgument("Dtype ", DataTypeString(common_dtype),
                                       " not supported.");
    }
#undef HANDLE_COPY
  }
  return OkStatus();
}

bool BundleReader::Contains(StringPiece key) {
  Seek(key);
  return Valid() && (this->key() == key);
}

Status BundleReader::LookupDtypeAndShape(StringPiece key, DataType* dtype,
                                         TensorShape* shape) {
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));
  *dtype = entry.dtype();
  *shape = TensorShape(entry.shape());
  return OkStatus();
}

Status BundleReader::LookupTensorShape(StringPiece key, TensorShape* shape) {
  DataType ignored;
  return LookupDtypeAndShape(key, &ignored, shape);
}

string BundleReader::DebugString() {
  // Format used below emulates that of TensorSliceReader::DebugString().
  string shape_str;
  BundleEntryProto entry;
  Seek(kHeaderEntryKey);
  for (Next(); Valid(); Next()) {
    CHECK(entry.ParseFromArray(value().data(), value().size()));
    if (entry.slices_size() > 0) continue;  // Slice of some partitioned var.

    strings::StrAppend(&shape_str, key(), " (", DataType_Name(entry.dtype()),
                       ") ", TensorShape(entry.shape()).DebugString());
    strings::StrAppend(&shape_str, "\n");
  }
  return shape_str;
}

namespace {
inline char* AlignedMalloc(size_t size) {
  char* buffer = static_cast<char*>(port::AlignedMalloc(size, 64));
  DCHECK(buffer);
  return buffer;
}
}  // namespace

FileOutputBuffer::FileOutputBuffer(WritableFile* file, size_t buffer_size)
    : file_(file), position_(0), buffer_size_(buffer_size) {
  DCHECK_GT(buffer_size, 0);
  buffer_ptr_ = AlignedMalloc(buffer_size);
}

FileOutputBuffer::~FileOutputBuffer() {
  if (buffer_ptr_) port::AlignedFree(buffer_ptr_);
  delete file_;
}

Status FileOutputBuffer::Append(StringPiece data) {
  // In the below, it is critical to calculate the checksum on the actually
  // copied bytes, not the source bytes.  This is because "data" typically
  // points to tensor buffers, which may be concurrently written.
  if (data.size() + position_ <= buffer_size_) {
    // Can fit into the current buffer.
    memcpy(buffer_ptr_ + position_, data.data(), data.size());
    crc32c_ = crc32c::Extend(crc32c_, buffer_ptr_ + position_, data.size());
  } else if (data.size() <= buffer_size_) {
    // Cannot fit, but can fit after flushing.
    TF_RETURN_IF_ERROR(FlushBuffer(false));
    memcpy(buffer_ptr_, data.data(), data.size());
    crc32c_ = crc32c::Extend(crc32c_, buffer_ptr_, data.size());
  } else {
    // Cannot fit even after flushing.  So we break down "data" by chunk, and
    // flush/checksum each chunk.
    TF_RETURN_IF_ERROR(FlushBuffer(false));
    for (size_t i = 0; i < data.size(); i += buffer_size_) {
      const size_t nbytes = std::min(data.size() - i, buffer_size_);
      memcpy(buffer_ptr_, data.data() + i, nbytes);
      crc32c_ = crc32c::Extend(crc32c_, buffer_ptr_, nbytes);
      position_ = nbytes;
      TF_RETURN_IF_ERROR(FlushBuffer(false));
    }
    return OkStatus();
  }
  position_ += data.size();
  return OkStatus();
}
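
// For example, with the 8MB buffer used by BundleWriter, appending 20MB of
// tensor data takes the last branch above: the pending buffer is flushed,
// then 8MB, 8MB, and 4MB chunks are copied, checksummed, and flushed in turn.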

Status FileOutputBuffer::Close() {
  TF_RETURN_IF_ERROR(FlushBuffer(true));
  return file_->Close();
}

Status FileOutputBuffer::FlushBuffer(bool closing) {
  if (position_ > 0) {
    // Use Cord to avoid extra data copy for some WritableFile implementations.
    absl::Cord buffer = absl::MakeCordFromExternal(
        StringPiece(buffer_ptr_, position_),
        [ptr = buffer_ptr_](StringPiece) { port::AlignedFree(ptr); });
    buffer_ptr_ = closing ? nullptr : AlignedMalloc(buffer_size_);
    TF_RETURN_IF_ERROR(file_->Append(buffer));
    position_ = 0;
  }
  return OkStatus();
}

}  // namespace tensorflow