/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"

#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <utility>

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/framework/versions.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/crc32c.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/io/table_builder.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/bfloat16.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
#include "tensorflow/core/util/tensor_bundle/byte_swap.h"
#include "tensorflow/core/util/tensor_slice_util.h"

namespace tensorflow {

// Versioning of the tensor bundle format.
const int kTensorBundleMinProducer = 0;
const int kTensorBundleMinConsumer = 0;
const int kTensorBundleVersion = 1;

// Size of our input buffer for streaming reads.
static const int kBufferSize = 1024 * 1024;

// Key to the special BundleHeaderProto entry. Do not change this, as clients
// can make the assumption that the header is always the first entry in the
// bundle.
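// (The metadata table orders its keys lexicographically, and the empty string
// compares before every other key, so the header entry always sorts first.)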
const char* const kHeaderEntryKey = "";

namespace {

// Reads "num_elements" string elements from file[offset, offset+size) into the
// length-"num_elements" array "destination". Discards the original content of
// "destination".
//
// Checksums the string lengths (as restored uint32 or uint64, not varint64
// bytes) and string bytes, and stores it into "actual_crc32c".
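// The on-disk layout consumed here is the one produced by WriteStringTensor()
// below: [varint64 len0]..[varint64 lenL][4 byte masked cksum on lengths]
// [string bytes].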
Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements,
                        size_t offset, size_t size, tstring* destination,
                        uint32* actual_crc32c, bool need_to_swap_bytes) {
  if (size == 0) return Status::OK();
  CHECK_GT(size, 0);

  // Reads "num_elements" varint64's from "buffered_file".
  TF_RETURN_IF_ERROR(buffered_file->Seek(offset));
  TF_RETURN_IF_ERROR(buffered_file->Hint(size));
  std::vector<uint64> string_lengths(num_elements);
  for (size_t i = 0; i < num_elements; ++i) {
    TF_RETURN_IF_ERROR(buffered_file->ReadVarint64(&string_lengths[i]));
    if (string_lengths[i] <= UINT32_MAX) {
      // We need to do this because older checkpoints only used uint32s and we
      // should still support them.
      uint32 elem_size_uint32 = static_cast<uint32>(string_lengths[i]);
      if (need_to_swap_bytes) {
        // Checksum would have been computed on the source machine's byte order
        elem_size_uint32 = BYTE_SWAP_32(elem_size_uint32);
      }
      *actual_crc32c = crc32c::Extend(
          *actual_crc32c, reinterpret_cast<const char*>(&elem_size_uint32),
          sizeof(uint32));
    } else {
      uint64 length = string_lengths[i];
      if (need_to_swap_bytes) {
        length = BYTE_SWAP_64(length);
      }
      *actual_crc32c =
          crc32c::Extend(*actual_crc32c, reinterpret_cast<const char*>(&length),
                         sizeof(uint64));
    }
  }
  if (offset + size < buffered_file->Tell()) {
    return errors::DataLoss("String lengths longer than expected offset ",
                            offset + size);
  }

  // Reads the length-checksum.
  uint32 raw_length_checksum = 0;  // Bytes in file
  uint32 length_checksum = 0;      // In-memory representation
  size_t unused_bytes_read = 0;
  TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(
      sizeof(uint32), reinterpret_cast<char*>(&raw_length_checksum),
      &unused_bytes_read));
  length_checksum = need_to_swap_bytes ? BYTE_SWAP_32(raw_length_checksum)
                                       : raw_length_checksum;
  if (crc32c::Unmask(length_checksum) != *actual_crc32c) {
    return errors::DataLoss(
        "The length checksum does not match: expected ",
        strings::Printf("%08u", crc32c::Unmask(length_checksum)),
        " but actual is ", strings::Printf("%08u", *actual_crc32c));
  }
  *actual_crc32c = crc32c::Extend(*actual_crc32c,
                                  reinterpret_cast<char*>(&raw_length_checksum),
                                  sizeof(uint32));

  // Reads the actual string bytes.
  for (size_t i = 0; i < num_elements; ++i) {
    const uint64 string_length = string_lengths[i];
    tstring* buffer = &destination[i];

    buffer->resize(string_length);
    size_t bytes_read = 0;
    TF_RETURN_IF_ERROR(
        buffered_file->ReadNBytes(string_length, &(*buffer)[0], &bytes_read));
    *actual_crc32c = crc32c::Extend(*actual_crc32c, buffer->data(), bytes_read);
  }
  return Status::OK();
}

Status ReadVariantTensor(io::InputBuffer* buffered_file, Tensor* ret,
                         size_t offset, size_t size, uint32* actual_crc32c) {
  // On-disk format:
  // [varint64 len1][bytes variant1][4 byte checksum]
  // ..
  // [varint64 lenN][bytes variantN][4 byte checksum]
  // Var "crc32c" checksums all the lens, variant bytes, individual variant
  // checksums (as uint32, not varint32 bytes).
  if (size == 0) return Status::OK();
  size_t num_elements = ret->NumElements();

  // Reads the actual variant bytes.
  TF_RETURN_IF_ERROR(buffered_file->Seek(offset));
  TF_RETURN_IF_ERROR(buffered_file->Hint(size));
  for (size_t i = 0; i < num_elements; ++i) {
    // Read the serialized variant length.
    uint64 string_length = 0;
    TF_RETURN_IF_ERROR(buffered_file->ReadVarint64(&string_length));
    *actual_crc32c = crc32c::Extend(
        *actual_crc32c, reinterpret_cast<const char*>(&string_length),
        sizeof(uint64));
    // Read the actual serialized variant.
    string buffer;
    buffer.resize(string_length);
    size_t bytes_read = 0;
    TF_RETURN_IF_ERROR(
        buffered_file->ReadNBytes(string_length, &buffer[0], &bytes_read));
    *actual_crc32c = crc32c::Extend(*actual_crc32c, buffer.data(), bytes_read);
    VariantTensorDataProto proto;
    if (!proto.ParseFromString(buffer)) {
      return errors::DataLoss("Unable to parse VariantTensorDataProto from ",
                              "buffer of size ", string_length, ". ",
                              "Bundle entry offset: ", offset, " size: ", size);
    }
    Variant v = proto;
    if (!DecodeUnaryVariant(&v)) {
      return errors::Internal("Could not decode variant with type_name: \"",
                              v.TypeName(), "\". Perhaps you forgot to ",
                              "register a decoder via ",
                              "REGISTER_UNARY_VARIANT_DECODE_FUNCTION?");
    }

    // Read the checksum.
    uint32 checksum = 0;
    size_t unused_bytes_read = 0;
    TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(
        sizeof(uint32), reinterpret_cast<char*>(&checksum),
        &unused_bytes_read));
    if (crc32c::Unmask(checksum) != *actual_crc32c) {
      return errors::DataLoss(
          "The checksum after Variant ", i, " does not match.",
          " Expected: ", strings::Printf("%08u", crc32c::Unmask(checksum)),
          " Actual: ", strings::Printf("%08u", *actual_crc32c));
    }
    *actual_crc32c = crc32c::Extend(
        *actual_crc32c, reinterpret_cast<char*>(&checksum), sizeof(uint32));

    ret->flat<Variant>()(i) = std::move(v);
  }

  return Status::OK();
}

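// Returns a mutable pointer to the flat data buffer of "val". Only valid for
// dtypes whose elements can be copied with raw memcpy (checked below).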
char* GetBackingBuffer(const Tensor& val) {
  CHECK(DataTypeCanUseMemcpy(val.dtype())) << val.dtype();
  return const_cast<char*>(val.tensor_data().data());
}

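// Returns a mutable pointer to the tstring elements of the DT_STRING tensor
// "val".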
tstring* GetStringBackingBuffer(const Tensor& val) {
  CHECK_EQ(DT_STRING, val.dtype());
  return const_cast<tstring*>(val.flat<tstring>().data());
}

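// Parses "value" as a serialized protobuf into "out"; returns a DataLoss
// error naming "key" if parsing fails.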
Status ParseEntryProto(StringPiece key, StringPiece value,
                       protobuf::MessageLite* out) {
  if (!out->ParseFromArray(value.data(), value.size())) {
    return errors::DataLoss("Entry for key ", key, " not parseable.");
  }
  return Status::OK();
}

// Serializes the data bytes of the non-string tensor "val". Discards the
// original content of "bytes_written", and on OK updates it with the number
// of bytes written.
// REQUIRES: val.dtype() != DT_STRING
Status WriteTensor(const Tensor& val, FileOutputBuffer* out,
                   size_t* bytes_written) {
  DCHECK_NE(val.dtype(), DT_STRING);
  DCHECK_NE(val.dtype(), DT_VARIANT);
  *bytes_written = val.TotalBytes();
  char* buf = GetBackingBuffer(val);
  VLOG(1) << "Appending " << *bytes_written << " bytes to file";
  return out->Append(StringPiece(buf, *bytes_written));
}

// Serializes string tensor "val". "bytes_written" is treated in the same
// fashion as WriteTensor().
//
// Checksums all bytes written and stores it into "crc32c".
// REQUIRES: val.dtype() == DT_STRING
Status WriteStringTensor(const Tensor& val, FileOutputBuffer* out,
                         size_t* bytes_written, uint32* crc32c) {
  // On-disk format:
  // [varint64 len0]..[varint64 lenL][4 byte cksum on lengths][string bytes]
  // Var "crc32c" checksums the string lengths (as uint64, not varint64 bytes),
  // the length-checksum, and all the string bytes.
  DCHECK_EQ(val.dtype(), DT_STRING);
  const tstring* strings = GetStringBackingBuffer(val);

  // Writes the varint lengths.
  string lengths;
  lengths.reserve(val.NumElements());  // At least 1 byte per element.
  *crc32c = 0;
  for (int64 i = 0; i < val.NumElements(); ++i) {
    const tstring* elem = &strings[i];
    DCHECK_EQ(elem->size(), static_cast<uint64>(elem->size()));
    const uint64 elem_size = static_cast<uint64>(elem->size());

    core::PutVarint64(&lengths, elem_size);
    if (elem_size <= UINT32_MAX) {
      // We need to do this because older checkpoints only used uint32s and we
      // should still support them.
      const uint32 elem_size_uint32 = static_cast<uint32>(elem_size);
      *crc32c = crc32c::Extend(*crc32c,
                               reinterpret_cast<const char*>(&elem_size_uint32),
                               sizeof(uint32));
    } else {
      *crc32c = crc32c::Extend(
          *crc32c, reinterpret_cast<const char*>(&elem_size), sizeof(uint64));
    }
  }
  TF_RETURN_IF_ERROR(out->Append(lengths));
  *bytes_written = lengths.size();

  // Writes the length checksum.
  const uint32 length_checksum = crc32c::Mask(*crc32c);
  TF_RETURN_IF_ERROR(out->Append(StringPiece(
      reinterpret_cast<const char*>(&length_checksum), sizeof(uint32))));
  *crc32c = crc32c::Extend(
      *crc32c, reinterpret_cast<const char*>(&length_checksum), sizeof(uint32));
  *bytes_written += sizeof(uint32);

  // Writes all the string bytes out.
  for (int64 i = 0; i < val.NumElements(); ++i) {
    const tstring* string = &strings[i];
    TF_RETURN_IF_ERROR(out->Append(*string));
    *bytes_written += string->size();
    *crc32c = crc32c::Extend(*crc32c, string->data(), string->size());
  }
  return Status::OK();
}

Status WriteVariantTensor(const Tensor& val, FileOutputBuffer* out,
                          size_t* bytes_written, uint32* crc32c) {
  // On-disk format:
  // [varint64 len1][bytes variant1][4 byte checksum]
  // ..
  // [varint64 lenN][bytes variantN][4 byte checksum]
  // Var "crc32c" checksums all the lens, variant bytes, individual variant
  // checksums (as uint32, not varint32 bytes).
  DCHECK_EQ(val.dtype(), DT_VARIANT);

  *crc32c = 0;
  *bytes_written = 0;
  for (int64 i = 0; i < val.NumElements(); ++i) {
    VariantTensorData data;
    val.flat<Variant>()(i).Encode(&data);
    VariantTensorDataProto proto;
    data.ToProto(&proto);
    string elem;
    if (!proto.SerializeToString(&elem)) {
      return errors::Unknown(
          "Failed to serialize tensor data of size ", proto.ByteSizeLong(),
          ". Tensor: ", val.flat<Variant>()(i).DebugString());
    }

    // Write the length of the serialized variant.
    DCHECK_EQ(elem.size(), static_cast<uint64>(elem.size()));
    const auto elem_size = static_cast<uint64>(elem.size());
    string len;
    core::PutVarint64(&len, elem_size);
    TF_RETURN_IF_ERROR(out->Append(len));
    *crc32c = crc32c::Extend(*crc32c, reinterpret_cast<const char*>(&elem_size),
                             sizeof(uint64));
    *bytes_written += len.size();

    // Write the serialized variant.
    TF_RETURN_IF_ERROR(out->Append(elem));
    *crc32c = crc32c::Extend(*crc32c, elem.data(), elem.size());
    *bytes_written += elem.size();

    // Write the checksum.
    const uint32 length_checksum = crc32c::Mask(*crc32c);
    TF_RETURN_IF_ERROR(out->Append(StringPiece(
        reinterpret_cast<const char*>(&length_checksum), sizeof(uint32))));
    *crc32c =
        crc32c::Extend(*crc32c, reinterpret_cast<const char*>(&length_checksum),
                       sizeof(uint32));
    *bytes_written += sizeof(uint32);
  }

  return Status::OK();
}

// Returns whether "slice_spec" is a full slice, with respect to the full shape.
//
// This can happen, say, when "slice_spec" is
// "TensorSlice(full_tensor_shape.dims())", or when it is "TensorSlice({{0,
// dim(0)}, ..., {0, dim(N)}})" -- a degenerate case we need to guard against.
bool IsFullSlice(const TensorSlice& slice_spec,
                 const TensorShape& full_tensor_shape) {
  if (slice_spec.IsFull()) {
    return true;
  } else {
    TensorShape sliced_shape;
    slice_spec.SliceTensorShape(full_tensor_shape, &sliced_shape).IgnoreError();
    return sliced_shape == full_tensor_shape;
  }
}

Status CorruptFileError(const Status& in_status, const string& filename,
                        const string& detail) {
  if (in_status.ok()) {
    return errors::Internal("Unable to read file (", filename,
                            "). Perhaps the file is corrupt or was produced by "
                            "a newer version of TensorFlow with format changes "
                            "(",
                            detail, ")");
  }
  return Status(
      in_status.code(),
      strings::StrCat("Unable to read file (", filename,
                      "). Perhaps the file is corrupt or was produced by a "
                      "newer version of TensorFlow with format changes (",
                      detail, "): ", in_status.error_message()));
}

table::Options TableBuilderOptions() {
  table::Options o;
  // Compressed tables cannot be read by TensorFlow releases prior to 1.1.
  // To smooth the transition, compressed writes are disabled for now
  // (version 1.2) with the intention that they will be enabled again at
  // some point (perhaps the 1.3 release?).
  o.compression = table::kNoCompression;
  return o;
}

// Writes zeros to output buffer to align the next write to the requested
// alignment. "size" is the current size of the buffer and is updated to the
// new size.
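// For example, with alignment == 8 and *size == 13, three '\0' bytes are
// appended and *size becomes 16.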
Status PadAlignment(FileOutputBuffer* out, int alignment, int64* size) {
  int bytes_over = *size % alignment;
  if (bytes_over == 0) {
    return Status::OK();
  }
  int bytes_to_write = alignment - bytes_over;
  Status status = out->Append(string(bytes_to_write, '\0'));
  if (status.ok()) {
    *size += bytes_to_write;
  }
  return status;
}

}  // namespace

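// A minimal write-side usage sketch (error handling elided; "my_tensor" and
// the prefix path are hypothetical):
//
//   BundleWriter writer(Env::Default(), "/path/to/checkpoint_prefix");
//   TF_CHECK_OK(writer.Add("my_key", my_tensor));
//   TF_CHECK_OK(writer.Finish());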
BundleWriter::BundleWriter(Env* env, StringPiece prefix, const Options& options)
    : env_(env),
      options_(options),
      prefix_(prefix),
      out_(nullptr),
      size_(0) {
  status_ = env_->HasAtomicMove(prefix_, &use_temp_file_);
  if (!status_.ok()) return;

  data_path_ = DataFilename(prefix_, 0, 1);
  metadata_path_ = MetaFilename(prefix_);
  if (use_temp_file_) {
    data_path_ = strings::StrCat(data_path_, ".tempstate", random::New64());
    metadata_path_ =
        strings::StrCat(metadata_path_, ".tempstate", random::New64());
  }

  status_ = env_->CreateDir(string(io::Dirname(prefix_)));
  if (!status_.ok() && !errors::IsAlreadyExists(status_)) {
    return;
  }

  std::unique_ptr<WritableFile> wrapper;
  status_ = env_->NewWritableFile(data_path_, &wrapper);
  if (!status_.ok()) return;
  out_ = std::unique_ptr<FileOutputBuffer>(
      new FileOutputBuffer(wrapper.release(), 8 << 20 /* 8MB write buffer */));

  VLOG(1) << "Writing to file " << data_path_;
}

Status BundleWriter::Add(StringPiece key, const Tensor& val) {
  if (!status_.ok()) return status_;
  CHECK_NE(key, kHeaderEntryKey);
  const string key_string(key);
  if (entries_.find(key_string) != entries_.end()) {
    status_ = errors::InvalidArgument("Adding duplicate key: ", key);
    return status_;
  }

  BundleEntryProto* entry = &entries_[key_string];
  entry->set_dtype(val.dtype());
  val.shape().AsProto(entry->mutable_shape());
  entry->set_shard_id(0);
  entry->set_offset(size_);

  // Updates the data file.
  size_t data_bytes_written = 0;
  uint32 crc32c = 0;
  out_->clear_crc32c();
  if (val.dtype() == DT_STRING) {
    status_ = WriteStringTensor(val, out_.get(), &data_bytes_written, &crc32c);
  } else if (val.dtype() == DT_VARIANT) {
    status_ = WriteVariantTensor(val, out_.get(), &data_bytes_written, &crc32c);
  } else {
    status_ = WriteTensor(val, out_.get(), &data_bytes_written);
    crc32c = out_->crc32c();
  }

  if (status_.ok()) {
    entry->set_size(data_bytes_written);
    entry->set_crc32c(crc32c::Mask(crc32c));
    size_ += data_bytes_written;
    status_ = PadAlignment(out_.get(), options_.data_alignment, &size_);
  }
  return status_;
}

Status BundleWriter::AddSlice(StringPiece full_tensor_key,
                              const TensorShape& full_tensor_shape,
                              const TensorSlice& slice_spec,
                              const Tensor& slice_tensor) {
  if (!status_.ok()) return status_;
  CHECK_NE(full_tensor_key, kHeaderEntryKey);

  // If just a singleton full slice, use the regular Add() to be more efficient.
  if (IsFullSlice(slice_spec, full_tensor_shape)) {
    return Add(full_tensor_key, slice_tensor);
  }

  // Inserts/updates the full tensor's metadata entry.
  //
  // In the case of a sharded save, MergeBundles() is responsible for merging
  // the "slices" field of multiple metadata entries corresponding to the same
  // full tensor.
  const string full_tensor_key_string(full_tensor_key);
  BundleEntryProto* full_entry = &entries_[full_tensor_key_string];
  if (full_entry->dtype() != DT_INVALID) {
    CHECK_EQ(full_entry->dtype(), slice_tensor.dtype());
  }
  if (full_entry->has_shape()) {
    CHECK(TensorShape(full_entry->shape()) == full_tensor_shape);
  }

  // Populates dtype, shape, and slices. Intentionally leaving out shard_id and
  // offset, which do not make sense for this full tensor entry.
  full_entry->set_dtype(slice_tensor.dtype());
  full_tensor_shape.AsProto(full_entry->mutable_shape());
  TensorSliceProto* slice_proto = full_entry->add_slices();
  slice_spec.AsProto(slice_proto);

  // The slice itself is handled by a regular Add(), which includes adding its
  // own metadata entry, and writing out the slice's values.
  const string slice_name =
      checkpoint::EncodeTensorNameSlice(full_tensor_key_string, slice_spec);
  status_ = Add(slice_name, slice_tensor);
  return status_;
}

// TODO(zongheng): on metadata write failure or !status_.ok(), consider removing
// the orphaned data file.
Status BundleWriter::Finish() {
  if (out_) {
    status_.Update(out_->Close());
    out_ = nullptr;
    if (status_.ok()) {
      if (use_temp_file_) {
        status_ =
            Env::Default()->RenameFile(data_path_, DataFilename(prefix_, 0, 1));
      }
    } else {
      Env::Default()->DeleteFile(data_path_).IgnoreError();
    }
  }
  if (!status_.ok()) return status_;
  // Build key -> BundleEntryProto table.
  std::unique_ptr<WritableFile> file;
  status_ = env_->NewWritableFile(metadata_path_, &file);
  if (!status_.ok()) return status_;
  {
    // N.B.: the default use of Snappy compression may not be supported on all
    // platforms (e.g. Android). The metadata file is small, so this is fine.
    table::Options options;
    options.compression = table::kNoCompression;
    table::TableBuilder builder(options, file.get());
    // Header entry.
    BundleHeaderProto header;
    header.set_num_shards(1);
    header.set_endianness(BundleHeaderProto::LITTLE);
    if (!port::kLittleEndian) header.set_endianness(BundleHeaderProto::BIG);
    VersionDef* version = header.mutable_version();
    version->set_producer(kTensorBundleVersion);
    version->set_min_consumer(kTensorBundleMinConsumer);

    builder.Add(kHeaderEntryKey, header.SerializeAsString());

    // All others.
    for (const auto& p : entries_) {
      builder.Add(p.first, p.second.SerializeAsString());
    }
    status_ = builder.Finish();
  }
  status_.Update(file->Close());
  if (!status_.ok()) {
    Env::Default()->DeleteFile(metadata_path_).IgnoreError();
    return status_;
  } else if (use_temp_file_) {
    status_ = Env::Default()->RenameFile(metadata_path_, MetaFilename(prefix_));
    if (!status_.ok()) return status_;
  }
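  // Marks the writer as closed: any subsequent Add()/Finish() calls will fail,
  // while this (successful) Finish() still returns OK below.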
  status_ = errors::Internal("BundleWriter is closed");
  return Status::OK();
}

// Merging tensor bundles.

// Accumulator of metadata states during a merge.
struct MergeState {
  // Accumulated from the header entries.
  int num_shards = 0;

  // Derives "endianness" and "version" from the first bundle merged (hence the
  // "seen_first_bundle" guard). The two fields must be the same for all
  // bundles in a merge.
  bool seen_first_bundle = false;
  BundleHeaderProto_Endianness endianness;
  VersionDef version;

  // Tensor key -> BundleEntryProto.
  std::map<string, BundleEntryProto> entries;
  // Data file path -> new shard id in the final merged bundle.
  std::unordered_map<string, int32> shard_ids;
};

// Merges entries of "prefix" into the accumulator state "merge".
// Returns OK iff the merge succeeds.
static Status MergeOneBundle(Env* env, StringPiece prefix,
                             MergeState* merge_state) {
  VLOG(1) << "Merging bundle: " << prefix;
  const string filename = MetaFilename(prefix);
  uint64 file_size;
  TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));
  std::unique_ptr<RandomAccessFile> file;
  TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));

  table::Table* table = nullptr;
  TF_RETURN_IF_ERROR(
      table::Table::Open(TableBuilderOptions(), file.get(), file_size, &table));
  std::unique_ptr<table::Table> table_deleter(table);
  std::unique_ptr<table::Iterator> iter(table->NewIterator());

  int num_shards;
  // Process header.
  {
    iter->Seek(kHeaderEntryKey);
    if (!iter->Valid()) {
      return CorruptFileError(iter->status(), filename,
                              "failed to seek to header entry");
    }
    BundleHeaderProto header;
    Status s = ParseEntryProto(iter->key(), iter->value(), &header);
    if (!s.ok()) return CorruptFileError(s, filename, "unable to parse header");

    merge_state->num_shards += header.num_shards();
    if (!merge_state->seen_first_bundle) {
      merge_state->seen_first_bundle = true;
      merge_state->endianness = header.endianness();
      merge_state->version = header.version();
    } else {
      // Validates "endianness".
      if (merge_state->endianness != header.endianness()) {
        return errors::InvalidArgument(
            "Merging bundles with conflicting endianness; inputs corrupted?");
      }
      // Validates "version".
      string curr_version, merge_version;
      header.version().SerializeToString(&curr_version);
      merge_state->version.SerializeToString(&merge_version);
      if (curr_version != merge_version) {
        return errors::InvalidArgument(
            "Merging bundles with different format versions: merged ",
            merge_version, " vs. curr ", curr_version);
      }
    }
    num_shards = header.num_shards();
    iter->Next();
  }

  // Loops through the non-header to-merge entries.
  BundleEntryProto to_merge_entry;
  for (; iter->Valid(); iter->Next()) {
    const string key(iter->key());
    const auto entry_iter = merge_state->entries.find(key);

    // Illegal: the duplicated entry is a non-slice tensor.
    if (entry_iter != merge_state->entries.end() &&
        entry_iter->second.slices().empty()) {
      return errors::InvalidArgument(
          "Duplicate tensor keyed by ", key,
          " encountered, when merging prefix: ", prefix);
    }

    TF_RETURN_IF_ERROR(
        ParseEntryProto(iter->key(), iter->value(), &to_merge_entry));

    // The duplicated entry holds metadata for a sliced full tensor.
    // Allows the duplication and merges "slices".
    if (entry_iter != merge_state->entries.end()) {
      BundleEntryProto& existing_entry = entry_iter->second;
      if (to_merge_entry.slices().empty()) {
        return errors::Internal(
            "Duplicate tensor keyed by ", key,
            "; attempting to merge in a non-slice bundle entry");
      }
      // Only needs to merge the "slices" field (and validate dtype/shape).
      for (int i = 0; i < to_merge_entry.slices_size(); ++i) {
        TensorSliceProto* slot = existing_entry.add_slices();
        *slot = to_merge_entry.slices(i);
      }
      CHECK_EQ(existing_entry.dtype(), to_merge_entry.dtype());
      CHECK(TensorShape(existing_entry.shape()) ==
            TensorShape(to_merge_entry.shape()));
      continue;
    }

    // Key doesn't duplicate: a fresh tensor/slice entry.
    auto result = merge_state->shard_ids.insert(
        {DataFilename(prefix, to_merge_entry.shard_id(), num_shards),
         merge_state->shard_ids.size()});
    to_merge_entry.set_shard_id(result.first->second);
    merge_state->entries[key] = to_merge_entry;
  }
  return Status::OK();
}

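// A usage sketch for merging per-worker bundles into a single bundle (the
// prefixes below are hypothetical):
//
//   std::vector<tstring> prefixes = {"/tmp/worker0/ckpt", "/tmp/worker1/ckpt"};
//   TF_CHECK_OK(MergeBundles(Env::Default(), prefixes, "/tmp/merged/ckpt"));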
Status MergeBundles(Env* env, gtl::ArraySlice<tstring> prefixes,
                    StringPiece merged_prefix) {
  // Merges all metadata tables.
  // TODO(zhifengc): KeyValue sorter if it becomes too big.
  MergeState merge;
  Status status = env->CreateDir(string(io::Dirname(merged_prefix)));
  if (!status.ok() && !errors::IsAlreadyExists(status)) return status;
  for (int i = 0; i < prefixes.size(); ++i) {
    TF_RETURN_IF_ERROR(MergeOneBundle(env, prefixes[i], &merge));
  }

  // Renames data files to contain the merged bundle prefix.
  for (const auto& p : merge.shard_ids) {
    VLOG(1) << "Renaming " << p.first << " to "
            << DataFilename(merged_prefix, p.second, merge.shard_ids.size());
    TF_RETURN_IF_ERROR(env->RenameFile(
        p.first,
        DataFilename(merged_prefix, p.second, merge.shard_ids.size())));
  }

  // Writes the final metadata table under the merged prefix.
  std::unique_ptr<WritableFile> merged_metadata;
  TF_RETURN_IF_ERROR(
      env->NewWritableFile(MetaFilename(merged_prefix), &merged_metadata));
  {
    table::TableBuilder builder(TableBuilderOptions(), merged_metadata.get());
    // Header entry.
    BundleHeaderProto header;
    header.set_num_shards(merge.num_shards);
    header.set_endianness(merge.endianness);
    *header.mutable_version() = merge.version;
    builder.Add(kHeaderEntryKey, header.SerializeAsString());
    // All others.
    for (const auto& p : merge.entries) {
      builder.Add(p.first, p.second.SerializeAsString());
    }
    status = builder.Finish();
  }
  status.Update(merged_metadata->Close());
  if (!status.ok()) return status;
  VLOG(1) << "Merged bundles to: " << merged_prefix;

  // Cleanup: best-effort; ignores errors.
  for (const tstring& prefix : prefixes) {
    env->DeleteFile(MetaFilename(prefix)).IgnoreError();
  }
  return status;
}

// Interface for reading a tensor bundle.

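// A minimal read-side usage sketch (error handling elided; the key and prefix
// are hypothetical and must match what a BundleWriter wrote):
//
//   BundleReader reader(Env::Default(), "/path/to/checkpoint_prefix");
//   TF_CHECK_OK(reader.status());
//   Tensor t;
//   TF_CHECK_OK(reader.Lookup("my_key", &t));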
BundleReader::BundleReader(Env* env, StringPiece prefix)
    : env_(env),
      prefix_(prefix),
      metadata_(nullptr),
      table_(nullptr),
      index_cache_(nullptr),
      iter_(nullptr),
      need_to_swap_bytes_(false) {
  const string filename = MetaFilename(prefix_);
  uint64 file_size;
  status_ = env_->GetFileSize(filename, &file_size);
  if (!status_.ok()) return;

  // Opens the metadata table.
  std::unique_ptr<RandomAccessFile> wrapper;
  status_ = env_->NewRandomAccessFile(filename, &wrapper);
  if (!status_.ok()) return;
  metadata_ = wrapper.release();

  table::Options o;
  int64 cache_size;
  Status s =
      ReadInt64FromEnvVar("TF_TABLE_INDEX_CACHE_SIZE_IN_MB", 0, &cache_size);
  if (s.ok() && cache_size > 0) {
    index_cache_ = table::NewLRUCache(cache_size << 20);
    o.block_cache = index_cache_;
  }

  status_ = table::Table::Open(o, metadata_, file_size, &table_);
  if (!status_.ok()) return;
  iter_ = table_->NewIterator();

  // Reads "num_shards_" from the first entry.
  iter_->Seek(kHeaderEntryKey);
  if (!iter_->Valid()) {
    status_ = CorruptFileError(iter_->status(), filename,
                               "failed to seek to header entry");
    return;
  }
  BundleHeaderProto header;
  status_ = ParseEntryProto(iter_->key(), iter_->value(), &header);
  if (!status_.ok()) {
    status_ = CorruptFileError(status_, filename, "unable to parse header");
    return;
  }
  num_shards_ = header.num_shards();
  if ((header.endianness() == BundleHeaderProto::BIG && port::kLittleEndian) ||
      (header.endianness() == BundleHeaderProto::LITTLE &&
       !port::kLittleEndian)) {
    need_to_swap_bytes_ = true;
  }
  status_ = CheckVersions(header.version(), kTensorBundleVersion,
                          kTensorBundleMinProducer, "Checkpoint", "checkpoint");
}

BundleReader::~BundleReader() {
  delete metadata_;
  delete iter_;
  delete table_;
  if (index_cache_) {
    delete index_cache_;
  }
  // InputBuffer does not own the underlying RandomAccessFile.
  for (auto pair : data_) {
    if (pair.second != nullptr && pair.second->file() != nullptr) {
      delete pair.second->file();
    }
  }
  for (auto& temp : data_) {
    delete temp.second;
  }
  for (auto& temp : tensor_slices_) {
    delete temp.second;
  }
  data_.clear();
  tensor_slices_.clear();
}

Status BundleReader::GetBundleEntryProto(StringPiece key,
                                         BundleEntryProto* entry) {
  entry->Clear();
  TF_CHECK_OK(status_);
  Seek(key);
  if (!iter_->Valid() || iter_->key() != key) {
    return errors::NotFound("Key ", key, " not found in checkpoint");
  }

  BundleEntryProto entry_copy;
  TF_RETURN_IF_ERROR(
      ParseEntryProto(iter_->key(), iter_->value(), &entry_copy));
  if (!TensorShape::IsValid(entry_copy.shape())) {
    return errors::DataLoss("Invalid tensor shape: ", key, " ",
                            entry_copy.shape().ShortDebugString());
  }

  entry->Swap(&entry_copy);
  return Status::OK();
}

Status BundleReader::GetValue(const BundleEntryProto& entry, Tensor* val) {
  Tensor* ret = val;
  const TensorShape stored_shape(TensorShape(entry.shape()));
  if (val->NumElements() == 0) {
    ret = new Tensor(entry.dtype(), stored_shape);
  }

  // Validates the "size" field.
  if (entry.dtype() != DT_STRING && entry.dtype() != DT_VARIANT) {
    if (entry.size() != ret->TotalBytes()) {
      return errors::DataLoss("Invalid size in bundle entry: key ", key(),
                              "; stored size ", entry.size(),
                              "; expected size ", ret->TotalBytes());
    }
  } else if (entry.dtype() == DT_STRING) {
    // Relaxes the check for string tensors as follows:
    //   entry.size() == bytes(varint lengths) + bytes(data)
    //                >= NumElems + bytes(data), since each varint length
    //                   takes at least 1 byte.
    //   TotalBytes() == sizeof(tstring) * NumElems + bytes(data)
    // Since we don't know bytes(varint lengths), we just check an inequality.
    const size_t lower_bound = ret->NumElements() + ret->TotalBytes() -
                               sizeof(tstring) * ret->NumElements();
    if (entry.size() < lower_bound) {
      return errors::DataLoss("Invalid size in bundle entry: key ", key(),
                              "; stored size ", entry.size(),
                              "; expected size is at least ", lower_bound);
    }
  }

  // Open the data file if it has not been opened.
  io::InputBuffer* buffered_file = data_[entry.shard_id()];
  if (buffered_file == nullptr) {
    std::unique_ptr<RandomAccessFile> file = nullptr;
    TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(
        DataFilename(prefix_, entry.shard_id(), num_shards_), &file));
    buffered_file = new io::InputBuffer(file.release(), kBufferSize);
    // The InputBuffer and RandomAccessFile objects are both released in dtor.
    data_[entry.shard_id()] = buffered_file;
  }
  CHECK(buffered_file != nullptr);

  TF_RETURN_IF_ERROR(buffered_file->Seek(entry.offset()));
  uint32 actual_crc32c = 0;

  if (DataTypeCanUseMemcpy(entry.dtype())) {
    char* backing_buffer = const_cast<char*>((ret->tensor_data().data()));
    size_t unused_bytes_read;
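    // For reads larger than the input buffer, bypass the buffered path and
    // read directly from the underlying file into the tensor's backing
    // buffer, avoiding an extra copy.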
    if (entry.size() > kBufferSize) {
      StringPiece sp;
      TF_RETURN_IF_ERROR(buffered_file->file()->Read(
          entry.offset(), entry.size(), &sp, backing_buffer));
      if (sp.data() != backing_buffer) {
        memmove(backing_buffer, sp.data(), entry.size());
      }
    } else {
      TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(entry.size(), backing_buffer,
                                                   &unused_bytes_read));
    }
    // Note that we compute the checksum *before* byte-swapping. The checksum
    // should be on the bytes in the order they appear in the file.
    actual_crc32c = crc32c::Value(backing_buffer, entry.size());
    if (need_to_swap_bytes_) {
      TF_RETURN_IF_ERROR(ByteSwapTensor(ret));
    }
  } else if (entry.dtype() == DT_VARIANT) {
    if (need_to_swap_bytes_) {
      return errors::Unimplemented(
          "TensorBundle at ", prefix_,
          " is of a different endianness than this machine's hardware, and "
          "the bundle contains a variant (arbitrary C++ type) tensor. "
          "Byte-swapping of variant tensors is not currently implemented.");
    }
    // Relies on io::InputBuffer's buffering, because we issue many neighboring
    // reads for a single variant tensor.
    TF_RETURN_IF_ERROR(ReadVariantTensor(buffered_file, ret, entry.offset(),
                                         entry.size(), &actual_crc32c));
  } else {
    // Relies on io::InputBuffer's buffering, because we issue many neighboring
    // reads for a single string tensor.
    TF_RETURN_IF_ERROR(ReadStringTensor(
        buffered_file, ret->NumElements(), entry.offset(), entry.size(),
        GetStringBackingBuffer(*ret), &actual_crc32c, need_to_swap_bytes_));
  }
  if (crc32c::Unmask(entry.crc32c()) != actual_crc32c) {
    return errors::DataLoss(
        "TensorBundle at ", prefix_, " shard ", entry.shard_id(), " (",
        entry.size(), " bytes): Checksum does not match: stored ",
        strings::Printf("%08u", crc32c::Unmask(entry.crc32c())),
        " vs. calculated on the restored bytes ", actual_crc32c);
  }

  *val = *ret;
  if (ret != val) delete ret;
  return Status::OK();
}

Status BundleReader::Lookup(StringPiece key, Tensor* val) {
  CHECK(val != nullptr);
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));

  if (entry.slices().empty()) {
    return GetValue(entry, val);
  } else {
    return GetSliceValue(
        key, entry,
        /* a full slice */ TensorSlice(TensorShape(entry.shape()).dims()), val);
  }
}

Status BundleReader::ReadCurrent(Tensor* val) {
  CHECK(val != nullptr);
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(ParseEntryProto(iter_->key(), iter_->value(), &entry));
  if (!TensorShape::IsValid(entry.shape())) {
    return errors::DataLoss("Invalid tensor shape: ", iter_->key(), " ",
                            entry.shape().ShortDebugString());
  }

  if (entry.slices().empty()) {
    return GetValue(entry, val);
  } else {
    return GetSliceValue(
        iter_->key(), entry,
        /* a full slice */ TensorSlice(TensorShape(entry.shape()).dims()), val);
  }
}

Status BundleReader::LookupTensorSlices(StringPiece key,
                                        std::vector<TensorSlice>* slices) {
  slices->clear();
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));
  slices->reserve(entry.slices_size());
  for (const auto& slice : entry.slices()) {
    slices->emplace_back(slice);
  }
  return Status::OK();
}

Status BundleReader::LookupSlice(StringPiece full_tensor_key,
                                 const TensorSlice& slice_spec, Tensor* val) {
  CHECK(val != nullptr);
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(GetBundleEntryProto(full_tensor_key, &entry));
  return GetSliceValue(full_tensor_key, entry, slice_spec, val);
}

Status BundleReader::GetSliceValue(StringPiece full_tensor_key,
                                   const BundleEntryProto& full_tensor_entry,
                                   const TensorSlice& slice_spec, Tensor* val) {
  using checkpoint::RegisterTensorSlice;
  using checkpoint::TensorSliceSet;
  DCHECK_GE(full_tensor_entry.slices_size(), 0);

  const TensorShape full_shape(TensorShape(full_tensor_entry.shape()));
  std::vector<std::pair<TensorSlice, string>> details;
  const string full_tensor_key_string(full_tensor_key);
  const TensorSliceSet* tss =
      gtl::FindPtrOrNull(tensor_slices_, full_tensor_key_string);

  // Populates the "full tensor key -> TensorSliceSet" cache.
  if (tss == nullptr) {
    if (full_tensor_entry.slices().empty()) {
      // Special case: a writer has saved a tensor fully, but the reader wants
      // to read in slices. We therefore register the full slice on-demand here
      // without further complicating the on-disk bundle format.
      TF_RETURN_IF_ERROR(RegisterTensorSlice(
          full_tensor_key_string, full_shape, full_tensor_entry.dtype(),
          /* tag */ "",
          /* full slice */ TensorSlice(full_shape.dims()), &tensor_slices_));
    }
    for (const TensorSliceProto& slice : full_tensor_entry.slices()) {
      TF_RETURN_IF_ERROR(RegisterTensorSlice(
          full_tensor_key_string, full_shape, full_tensor_entry.dtype(),
          /* tag */ "", TensorSlice(slice), &tensor_slices_));
    }
    tss = gtl::FindPtrOrNull(tensor_slices_, full_tensor_key_string);
    CHECK_NE(tss, nullptr);
  }
  if (!tss->QueryMeta(slice_spec, &details)) {
    return errors::InvalidArgument(
        "Does not have sufficient slices for partitioned tensor ",
        full_tensor_key,
        " to restore in slice_spec: ", slice_spec.DebugString());
  }

  // The union of the slices in "details" covers "slice_spec". Performs the
  // copies from each.
  BundleEntryProto stored_slice_entry = full_tensor_entry;
  for (const auto& slice_tag_pair : details) {
    // Seeks for the stored slice.
    const TensorSlice& stored_slice = slice_tag_pair.first;

    // We already have the entry for the full tensor, so don't query again if
    // the slice is full.
    if (!stored_slice.IsFull()) {
      const string encoded_stored_slice_name =
          checkpoint::EncodeTensorNameSlice(full_tensor_key_string,
                                            stored_slice);
      status_ =
          GetBundleEntryProto(encoded_stored_slice_name, &stored_slice_entry);
      if (!status_.ok()) return status_;
    }

    // TODO(zongheng): should we take an OpKernelContext, so that we can call
    // allocate_temp()? Note that without major refactorings to Saver, it's
    // hard for the caller of the tensor bundle module to allocate these
    // precisely-shaped scratch storage.

    // Optimization for the common case: the stored slice can be directly
    // copied to the destination without additional slicing. This is true when
    // either the slices are equal or when they are both full slices having the
    // same shape.
    TensorShape stored_slice_shape(stored_slice_entry.shape());
    if (stored_slice == slice_spec ||
        (stored_slice_shape == val->shape() &&
         IsFullSlice(stored_slice, stored_slice_shape) &&
         IsFullSlice(slice_spec, stored_slice_shape))) {
      VLOG(1) << "Optimized for common case: directly copying into "
                 "pre-allocated buffer; spec: "
              << slice_spec.DebugString();
      status_ = GetValue(stored_slice_entry, val);
      return status_;
    }

    Tensor stored_slice_tensor(stored_slice_entry.dtype(), stored_slice_shape);
    status_ = GetValue(stored_slice_entry, &stored_slice_tensor);
    if (!status_.ok()) return status_;

    // Copies the intersection over.
    const DataType common_dtype = full_tensor_entry.dtype();
    switch (common_dtype) {
#define HANDLE_COPY(T)                                                 \
  case DataTypeToEnum<T>::value:                                       \
    CHECK(CopyDataFromTensorSliceToTensorSlice(                        \
        full_shape, stored_slice, slice_spec,                          \
        stored_slice_tensor.flat<T>().data(), val->flat<T>().data())); \
    break;

      HANDLE_COPY(float)
      HANDLE_COPY(double)
      HANDLE_COPY(int32)
      HANDLE_COPY(uint8)
      HANDLE_COPY(int16)
      HANDLE_COPY(int8)
      HANDLE_COPY(complex64)
      HANDLE_COPY(complex128)
      HANDLE_COPY(int64)
      HANDLE_COPY(bool)
      HANDLE_COPY(qint32)
      HANDLE_COPY(quint8)
      HANDLE_COPY(qint8)
      HANDLE_COPY(bfloat16)
      default:
        return errors::InvalidArgument("Dtype ", DataTypeString(common_dtype),
                                       " not supported.");
    }
#undef HANDLE_COPY
  }
  return Status::OK();
}

bool BundleReader::Contains(StringPiece key) {
  Seek(key);
  return Valid() && (this->key() == key);
}

Status BundleReader::LookupDtypeAndShape(StringPiece key, DataType* dtype,
                                         TensorShape* shape) {
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));
  *dtype = entry.dtype();
  *shape = TensorShape(entry.shape());
  return Status::OK();
}

Status BundleReader::LookupTensorShape(StringPiece key, TensorShape* shape) {
  DataType ignored;
  return LookupDtypeAndShape(key, &ignored, shape);
}

string BundleReader::DebugString() {
  // Format used below emulates that of TensorSliceReader::DebugString().
  string shape_str;
  BundleEntryProto entry;
  Seek(kHeaderEntryKey);
  for (Next(); Valid(); Next()) {
    CHECK(entry.ParseFromArray(value().data(), value().size()));
    if (entry.slices_size() > 0) continue;  // Slice of some partitioned var.

    strings::StrAppend(&shape_str, key(), " (", DataType_Name(entry.dtype()),
                       ") ", TensorShape(entry.shape()).DebugString());
    strings::StrAppend(&shape_str, "\n");
  }
  return shape_str;
}

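// Note: the destructor below only releases the underlying WritableFile; it
// does not flush, so callers must invoke Close() first to persist any
// buffered bytes.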
FileOutputBuffer::~FileOutputBuffer() { delete file_; }

Status FileOutputBuffer::Append(StringPiece data) {
  // In the below, it is critical to calculate the checksum on the actually
  // copied bytes, not the source bytes. This is because "data" typically
  // points to tensor buffers, which may be concurrently written.
  if (data.size() + position_ <= buffer_size_) {
    // Can fit into the current buffer.
    memcpy(&buffer_[position_], data.data(), data.size());
    crc32c_ = crc32c::Extend(crc32c_, &buffer_[position_], data.size());
  } else if (data.size() <= buffer_size_) {
    // Cannot fit, but can fit after flushing.
    TF_RETURN_IF_ERROR(FlushBuffer());
    memcpy(&buffer_[0], data.data(), data.size());
    crc32c_ = crc32c::Extend(crc32c_, &buffer_[0], data.size());
  } else {
    // Cannot fit even after flushing. So we break down "data" by chunk, and
    // flush/checksum each chunk.
    TF_RETURN_IF_ERROR(FlushBuffer());
    for (size_t i = 0; i < data.size(); i += buffer_size_) {
      const size_t nbytes = std::min(data.size() - i, buffer_size_);
      memcpy(&buffer_[0], data.data() + i, nbytes);
      crc32c_ = crc32c::Extend(crc32c_, &buffer_[0], nbytes);
      position_ = nbytes;
      TF_RETURN_IF_ERROR(FlushBuffer());
    }
    return Status::OK();
  }
  position_ += data.size();
  return Status::OK();
}

Status FileOutputBuffer::Close() {
  TF_RETURN_IF_ERROR(FlushBuffer());
  return file_->Close();
}

Status FileOutputBuffer::FlushBuffer() {
  if (position_ > 0) {
    TF_RETURN_IF_ERROR(file_->Append(StringPiece(&buffer_[0], position_)));
    position_ = 0;
  }
  return Status::OK();
}

}  // namespace tensorflow