/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"

#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <utility>

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/framework/versions.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/crc32c.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/io/table_builder.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
#include "tensorflow/core/util/tensor_bundle/byte_swap.h"
#include "tensorflow/core/util/tensor_slice_util.h"

namespace tensorflow {

// Versioning of the tensor bundle format.
const int kTensorBundleMinProducer = 0;
const int kTensorBundleMinConsumer = 0;
const int kTensorBundleVersion = 1;

// Size of our input buffer for streaming reads.
static const int kBufferSize = 1024 * 1024;

// Key to the special BundleHeaderProto entry. Do not change this, as clients
// can make the assumption that the header is always the first entry in the
// bundle.
const char* const kHeaderEntryKey = "";

namespace {

// Reads "num_elements" string elements from file[offset, offset + size) into
// the length-"num_elements" array "destination". Discards the original
// content of "destination".
//
// Checksums the string lengths (as restored uint32 or uint64, not varint64
// bytes) and string bytes, and stores the result into "actual_crc32c".
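//
// For illustration (values are hypothetical, not part of a format spec):
// two elements "ab" and "c" occupy the region
//   [varint 2][varint 1][4-byte masked crc32c of lengths]["ab"]["c"],
// matching the layout produced by WriteStringTensor() below.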
Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements,
                        size_t offset, size_t size, tstring* destination,
                        uint32* actual_crc32c, bool need_to_swap_bytes) {
  if (size == 0) return Status::OK();
  CHECK_GT(size, 0);

  // Reads "num_elements" varint64's from "buffered_file".
  TF_RETURN_IF_ERROR(buffered_file->Seek(offset));
  std::vector<uint64> string_lengths(num_elements);
  for (size_t i = 0; i < num_elements; ++i) {
    TF_RETURN_IF_ERROR(buffered_file->ReadVarint64(&string_lengths[i]));
    if (string_lengths[i] <= UINT32_MAX) {
      // We need to do this because older checkpoints only used uint32s and we
      // should still support them.
      uint32 elem_size_uint32 = static_cast<uint32>(string_lengths[i]);
      if (need_to_swap_bytes) {
        // The checksum was computed in the source machine's byte order.
        elem_size_uint32 = BYTE_SWAP_32(elem_size_uint32);
      }
      *actual_crc32c = crc32c::Extend(
          *actual_crc32c, reinterpret_cast<const char*>(&elem_size_uint32),
          sizeof(uint32));
    } else {
      uint64 length = string_lengths[i];
      if (need_to_swap_bytes) {
        length = BYTE_SWAP_64(length);
      }
      *actual_crc32c =
          crc32c::Extend(*actual_crc32c,
                         reinterpret_cast<const char*>(&length),
                         sizeof(uint64));
    }
  }
  if (offset + size < buffered_file->Tell()) {
    return errors::DataLoss("String lengths longer than expected offset ",
                            offset + size);
  }

  // Reads the length-checksum.
  uint32 raw_length_checksum = 0;  // Bytes in file
  uint32 length_checksum = 0;      // In-memory representation
  size_t unused_bytes_read = 0;
  TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(
      sizeof(uint32), reinterpret_cast<char*>(&raw_length_checksum),
      &unused_bytes_read));
  length_checksum = need_to_swap_bytes ? BYTE_SWAP_32(raw_length_checksum)
                                       : raw_length_checksum;
  if (crc32c::Unmask(length_checksum) != *actual_crc32c) {
    return errors::DataLoss(
        "The length checksum does not match: expected ",
        strings::Printf("%08u", crc32c::Unmask(length_checksum)),
        " but actual is ", strings::Printf("%08u", *actual_crc32c));
  }
  *actual_crc32c = crc32c::Extend(
      *actual_crc32c, reinterpret_cast<char*>(&raw_length_checksum),
      sizeof(uint32));

  // Reads the actual string bytes.
  for (size_t i = 0; i < num_elements; ++i) {
    const uint64 string_length = string_lengths[i];
    tstring* buffer = &destination[i];

    buffer->resize(string_length);
    size_t bytes_read = 0;
    TF_RETURN_IF_ERROR(
        buffered_file->ReadNBytes(string_length, &(*buffer)[0], &bytes_read));
    *actual_crc32c = crc32c::Extend(*actual_crc32c, buffer->data(),
                                    bytes_read);
  }
  return Status::OK();
}

Status ReadVariantTensor(io::InputBuffer* buffered_file, Tensor* ret,
                         size_t offset, size_t size, uint32* actual_crc32c) {
  // On-disk format:
  //   [varint64 len1][bytes variant1][4 byte checksum]
  //   ..
  //   [varint64 lenN][bytes variantN][4 byte checksum]
  // Var "crc32c" checksums all the lens, variant bytes, individual variant
  // checksums (as uint32, not varint32 bytes).
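  //
  // For illustration (hypothetical sizes): a variant that serializes to the
  // 3 bytes "xyz" occupies [varint 3]["xyz"][4-byte masked crc32c], and the
  // running checksum folds in the uint64 length, the bytes, and the checksum
  // word itself.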
  if (size == 0) return Status::OK();
  size_t num_elements = ret->NumElements();

  // Reads the actual string bytes.
  TF_RETURN_IF_ERROR(buffered_file->Seek(offset));
  for (size_t i = 0; i < num_elements; ++i) {
    // Read the serialized variant length.
    uint64 string_length = 0;
    TF_RETURN_IF_ERROR(buffered_file->ReadVarint64(&string_length));
    *actual_crc32c = crc32c::Extend(
        *actual_crc32c, reinterpret_cast<const char*>(&string_length),
        sizeof(uint64));
    // Read the actual serialized variant.
    string buffer;
    buffer.resize(string_length);
    size_t bytes_read = 0;
    TF_RETURN_IF_ERROR(
        buffered_file->ReadNBytes(string_length, &buffer[0], &bytes_read));
    *actual_crc32c = crc32c::Extend(*actual_crc32c, buffer.data(), bytes_read);
    VariantTensorDataProto proto;
    if (!proto.ParseFromString(buffer)) {
      return errors::DataLoss("Unable to parse VariantTensorDataProto from ",
                              "buffer of size ", string_length, ". ",
                              "Bundle entry offset: ", offset, " size: ", size);
    }
    Variant v = proto;
    if (!DecodeUnaryVariant(&v)) {
      return errors::Internal("Could not decode variant with type_name: \"",
                              v.TypeName(), "\". Perhaps you forgot to ",
                              "register a decoder via ",
                              "REGISTER_UNARY_VARIANT_DECODE_FUNCTION?");
    }

    // Read the checksum.
    uint32 checksum = 0;
    size_t unused_bytes_read = 0;
    TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(
        sizeof(uint32), reinterpret_cast<char*>(&checksum),
        &unused_bytes_read));
    if (crc32c::Unmask(checksum) != *actual_crc32c) {
      return errors::DataLoss(
          "The checksum after Variant ", i, " does not match.",
          " Expected: ", strings::Printf("%08u", crc32c::Unmask(checksum)),
          " Actual: ", strings::Printf("%08u", *actual_crc32c));
    }
    *actual_crc32c = crc32c::Extend(
        *actual_crc32c, reinterpret_cast<char*>(&checksum), sizeof(uint32));

    ret->flat<Variant>()(i) = std::move(v);
  }

  return Status::OK();
}

char* GetBackingBuffer(const Tensor& val) {
  CHECK(DataTypeCanUseMemcpy(val.dtype())) << val.dtype();
  return const_cast<char*>(val.tensor_data().data());
}

tstring* GetStringBackingBuffer(const Tensor& val) {
  CHECK_EQ(DT_STRING, val.dtype());
  return const_cast<tstring*>(val.flat<tstring>().data());
}

Status ParseEntryProto(StringPiece key, StringPiece value,
                       protobuf::MessageLite* out) {
  if (!out->ParseFromArray(value.data(), value.size())) {
    return errors::DataLoss("Entry for key ", key, " not parseable.");
  }
  return Status::OK();
}

// Serializes the data bytes of the non-string tensor "val". Discards the
// original content of "bytes_written", and on OK updates it with the number
// of bytes written.
// REQUIRES: val.dtype() != DT_STRING
Status WriteTensor(const Tensor& val, FileOutputBuffer* out,
                   size_t* bytes_written) {
  DCHECK_NE(val.dtype(), DT_STRING);
  DCHECK_NE(val.dtype(), DT_VARIANT);
  *bytes_written = val.TotalBytes();
  char* buf = GetBackingBuffer(val);
  VLOG(1) << "Appending " << *bytes_written << " bytes to file";
  return out->Append(StringPiece(buf, *bytes_written));
}

// Serializes string tensor "val". "bytes_written" is treated in the same
// fashion as WriteTensor().
//
// Checksums all bytes written and stores the result into "crc32c".
// REQUIRES: val.dtype() == DT_STRING
Status WriteStringTensor(const Tensor& val, FileOutputBuffer* out,
                         size_t* bytes_written, uint32* crc32c) {
  // On-disk format:
  //   [varint64 len0]..[varint64 lenL][4 byte cksum on lengths][string bytes]
  // Var "crc32c" checksums the string lengths (as uint64, not varint64 bytes),
  // the length-checksum, and all the string bytes.
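  //
  // Worked example (hypothetical data): for the two strings "ab" and "c",
  // this writes 2 varint length bytes + 4 length-checksum bytes + 3 data
  // bytes, so *bytes_written ends up as 9.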
  DCHECK_EQ(val.dtype(), DT_STRING);
  const tstring* strings = GetStringBackingBuffer(val);

  // Writes the varint lengths.
  string lengths;
  lengths.reserve(val.NumElements());  // At least 1 byte per element.
  *crc32c = 0;
  for (int64 i = 0; i < val.NumElements(); ++i) {
    const tstring* elem = &strings[i];
    DCHECK_EQ(elem->size(), static_cast<uint64>(elem->size()));
    const uint64 elem_size = static_cast<uint64>(elem->size());

    core::PutVarint64(&lengths, elem_size);
    if (elem_size <= UINT32_MAX) {
      // We need to do this because older checkpoints only used uint32s and we
      // should still support them.
      const uint32 elem_size_uint32 = static_cast<uint32>(elem_size);
      *crc32c = crc32c::Extend(
          *crc32c, reinterpret_cast<const char*>(&elem_size_uint32),
          sizeof(uint32));
    } else {
      *crc32c = crc32c::Extend(
          *crc32c, reinterpret_cast<const char*>(&elem_size), sizeof(uint64));
    }
  }
  TF_RETURN_IF_ERROR(out->Append(lengths));
  *bytes_written = lengths.size();

  // Writes the length checksum.
  const uint32 length_checksum = crc32c::Mask(*crc32c);
  TF_RETURN_IF_ERROR(out->Append(StringPiece(
      reinterpret_cast<const char*>(&length_checksum), sizeof(uint32))));
  *crc32c = crc32c::Extend(
      *crc32c, reinterpret_cast<const char*>(&length_checksum),
      sizeof(uint32));
  *bytes_written += sizeof(uint32);

  // Writes all the string bytes out.
  for (int64 i = 0; i < val.NumElements(); ++i) {
    const tstring* string = &strings[i];
    TF_RETURN_IF_ERROR(out->Append(*string));
    *bytes_written += string->size();
    *crc32c = crc32c::Extend(*crc32c, string->data(), string->size());
  }
  return Status::OK();
}

Status WriteVariantTensor(const Tensor& val, FileOutputBuffer* out,
                          size_t* bytes_written, uint32* crc32c) {
  // On-disk format:
  //   [varint64 len1][bytes variant1][4 byte checksum]
  //   ..
  //   [varint64 lenN][bytes variantN][4 byte checksum]
  // Var "crc32c" checksums all the lens, variant bytes, individual variant
  // checksums (as uint32, not varint32 bytes).
  DCHECK_EQ(val.dtype(), DT_VARIANT);

  *crc32c = 0;
  *bytes_written = 0;
  for (int64 i = 0; i < val.NumElements(); ++i) {
    VariantTensorData data;
    val.flat<Variant>()(i).Encode(&data);
    VariantTensorDataProto proto;
    data.ToProto(&proto);
    string elem;
    proto.SerializeToString(&elem);

    // Write the length of the serialized variant.
    DCHECK_EQ(elem.size(), static_cast<uint64>(elem.size()));
    const auto elem_size = static_cast<uint64>(elem.size());
    string len;
    core::PutVarint64(&len, elem_size);
    TF_RETURN_IF_ERROR(out->Append(len));
    *crc32c = crc32c::Extend(
        *crc32c, reinterpret_cast<const char*>(&elem_size), sizeof(uint64));
    *bytes_written += len.size();

    // Write the serialized variant.
    TF_RETURN_IF_ERROR(out->Append(elem));
    *crc32c = crc32c::Extend(*crc32c, elem.data(), elem.size());
    *bytes_written += elem.size();

    // Write the checksum.
    const uint32 length_checksum = crc32c::Mask(*crc32c);
    TF_RETURN_IF_ERROR(out->Append(StringPiece(
        reinterpret_cast<const char*>(&length_checksum), sizeof(uint32))));
    *crc32c = crc32c::Extend(
        *crc32c, reinterpret_cast<const char*>(&length_checksum),
        sizeof(uint32));
    *bytes_written += sizeof(uint32);
  }

  return Status::OK();
}

// Returns whether "slice_spec" is a full slice, with respect to the full
// shape.
//
// This can happen, say, when "slice_spec" is
// "TensorSlice(full_tensor_shape.dims())", or when it is "TensorSlice({{0,
// dim(0)}, ..., {0, dim(N)}})" -- a degenerate case we need to guard against.
bool IsFullSlice(const TensorSlice& slice_spec,
                 const TensorShape& full_tensor_shape) {
  if (slice_spec.IsFull()) {
    return true;
  } else {
    TensorShape sliced_shape;
    slice_spec.SliceTensorShape(full_tensor_shape, &sliced_shape).IgnoreError();
    return sliced_shape == full_tensor_shape;
  }
}

Status CorruptFileError(const Status& in_status, const string& filename,
                        const string& detail) {
  if (in_status.ok()) {
    return errors::Internal("Unable to read file (", filename,
                            "). Perhaps the file is corrupt or was produced by "
                            "a newer version of TensorFlow with format changes "
                            "(",
                            detail, ")");
  }
  return Status(
      in_status.code(),
      strings::StrCat("Unable to read file (", filename,
                      "). Perhaps the file is corrupt or was produced by a "
                      "newer version of TensorFlow with format changes (",
                      detail, "): ", in_status.error_message()));
}

table::Options TableBuilderOptions() {
  table::Options o;
  // Compressed tables cannot be read by TensorFlow releases prior to 1.1.
  // To smooth the transition, compressed writes are disabled for now
  // (version 1.2), with the intention that they will be enabled again at
  // some point (perhaps the 1.3 release?).
  o.compression = table::kNoCompression;
  return o;
}

// Writes zeros to output buffer to align the next write to the requested
// alignment. "size" is the current size of the buffer and is updated to the
// new size.
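//
// For example (hypothetical values): with *size == 13 and alignment == 8,
// three zero bytes are appended and *size becomes 16.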
Status PadAlignment(FileOutputBuffer* out, int alignment, int64* size) {
  int bytes_over = *size % alignment;
  if (bytes_over == 0) {
    return Status::OK();
  }
  int bytes_to_write = alignment - bytes_over;
  Status status = out->Append(string(bytes_to_write, '\0'));
  if (status.ok()) {
    *size += bytes_to_write;
  }
  return status;
}

}  // namespace

BundleWriter::BundleWriter(Env* env, StringPiece prefix, const Options& options)
    : env_(env),
      options_(options),
      prefix_(prefix),
      tmp_metadata_path_(strings::StrCat(MetaFilename(prefix_), ".tempstate",
                                         random::New64())),
      tmp_data_path_(strings::StrCat(DataFilename(prefix_, 0, 1), ".tempstate",
                                     random::New64())),
      out_(nullptr),
      size_(0) {
  status_ = env_->CreateDir(string(io::Dirname(prefix_)));
  if (!status_.ok() && !errors::IsAlreadyExists(status_)) {
    return;
  }
  const string filename = DataFilename(prefix_, 0, 1);
  std::unique_ptr<WritableFile> wrapper;
  status_ = env_->NewWritableFile(tmp_data_path_, &wrapper);
  if (!status_.ok()) return;
  out_ = std::unique_ptr<FileOutputBuffer>(
      new FileOutputBuffer(wrapper.release(), 8 << 20 /* 8MB write buffer */));

  VLOG(1) << "Writing to file " << tmp_data_path_;
}

Status BundleWriter::Add(StringPiece key, const Tensor& val) {
  if (!status_.ok()) return status_;
  CHECK_NE(key, kHeaderEntryKey);
  const string key_string(key);
  if (entries_.find(key_string) != entries_.end()) {
    status_ = errors::InvalidArgument("Adding duplicate key: ", key);
    return status_;
  }

  BundleEntryProto* entry = &entries_[key_string];
  entry->set_dtype(val.dtype());
  val.shape().AsProto(entry->mutable_shape());
  entry->set_shard_id(0);
  entry->set_offset(size_);

  // Updates the data file.
  size_t data_bytes_written = 0;
  uint32 crc32c = 0;
  out_->clear_crc32c();
  if (val.dtype() == DT_STRING) {
    status_ = WriteStringTensor(val, out_.get(), &data_bytes_written, &crc32c);
  } else if (val.dtype() == DT_VARIANT) {
    status_ = WriteVariantTensor(val, out_.get(), &data_bytes_written, &crc32c);
  } else {
    status_ = WriteTensor(val, out_.get(), &data_bytes_written);
    crc32c = out_->crc32c();
  }

  if (status_.ok()) {
    entry->set_size(data_bytes_written);
    entry->set_crc32c(crc32c::Mask(crc32c));
    size_ += data_bytes_written;
    status_ = PadAlignment(out_.get(), options_.data_alignment, &size_);
  }
  return status_;
}

Status BundleWriter::AddSlice(StringPiece full_tensor_key,
                              const TensorShape& full_tensor_shape,
                              const TensorSlice& slice_spec,
                              const Tensor& slice_tensor) {
  if (!status_.ok()) return status_;
  CHECK_NE(full_tensor_key, kHeaderEntryKey);

  // If just a singleton full slice, use the regular Add() to be more
  // efficient.
  if (IsFullSlice(slice_spec, full_tensor_shape)) {
    return Add(full_tensor_key, slice_tensor);
  }

  // Inserts/updates the full tensor's metadata entry.
  //
  // In the case of a sharded save, MergeBundles() is responsible for merging
  // the "slices" field of multiple metadata entries corresponding to the same
  // full tensor.
  const string full_tensor_key_string(full_tensor_key);
  BundleEntryProto* full_entry = &entries_[full_tensor_key_string];
  if (full_entry->dtype() != DT_INVALID) {
    CHECK_EQ(full_entry->dtype(), slice_tensor.dtype());
  }
  if (full_entry->has_shape()) {
    CHECK(TensorShape(full_entry->shape()) == full_tensor_shape);
  }

  // Populates dtype, shape, and slices. Intentionally leaving out shard_id and
  // offset, which do not make sense for this full tensor entry.
  full_entry->set_dtype(slice_tensor.dtype());
  full_tensor_shape.AsProto(full_entry->mutable_shape());
  TensorSliceProto* slice_proto = full_entry->add_slices();
  slice_spec.AsProto(slice_proto);

  // The slice itself is handled by a regular Add(), which includes adding its
  // own metadata entry, and writing out the slice's values.
  const string slice_name =
      checkpoint::EncodeTensorNameSlice(full_tensor_key_string, slice_spec);
  status_ = Add(slice_name, slice_tensor);
  return status_;
}

// TODO(zongheng): on metadata write failure or !status_.ok(), consider
// removing the orphaned data file.
Status BundleWriter::Finish() {
  if (out_) {
    status_.Update(out_->Close());
    out_ = nullptr;
    if (status_.ok()) {
      status_ = Env::Default()->RenameFile(tmp_data_path_,
                                           DataFilename(prefix_, 0, 1));
    } else {
      Env::Default()->DeleteFile(tmp_data_path_).IgnoreError();
    }
  }
  if (!status_.ok()) return status_;
  // Build key -> BundleEntryProto table.
  std::unique_ptr<WritableFile> file;
  status_ = env_->NewWritableFile(tmp_metadata_path_, &file);
  if (!status_.ok()) return status_;
  {
    // N.B.: the default use of Snappy compression may not be supported on all
    // platforms (e.g. Android). The metadata file is small, so this is fine.
    table::Options options;
    options.compression = table::kNoCompression;
    table::TableBuilder builder(options, file.get());
    // Header entry.
    BundleHeaderProto header;
    header.set_num_shards(1);
    header.set_endianness(BundleHeaderProto::LITTLE);
    if (!port::kLittleEndian) header.set_endianness(BundleHeaderProto::BIG);
    VersionDef* version = header.mutable_version();
    version->set_producer(kTensorBundleVersion);
    version->set_min_consumer(kTensorBundleMinConsumer);

    builder.Add(kHeaderEntryKey, header.SerializeAsString());

    // All others.
    for (const auto& p : entries_) {
      builder.Add(p.first, p.second.SerializeAsString());
    }
    status_ = builder.Finish();
  }
  status_.Update(file->Close());
  if (!status_.ok()) {
    Env::Default()->DeleteFile(tmp_metadata_path_).IgnoreError();
    return status_;
  } else {
    status_ =
        Env::Default()->RenameFile(tmp_metadata_path_, MetaFilename(prefix_));
    if (!status_.ok()) return status_;
  }
  status_ = errors::Internal("BundleWriter is closed");
  return Status::OK();
}

// Merging tensor bundles.

// Accumulator of metadata states during a merge.
struct MergeState {
  // Accumulated from the header entries.
  int num_shards = 0;

  // Derives "endianness" and "version" from the first bundle merged (hence the
  // "seen_first_bundle" guard). The two fields must be the same for all
  // bundles in a merge.
  bool seen_first_bundle = false;
  BundleHeaderProto_Endianness endianness;
  VersionDef version;

  // Tensor key -> BundleEntryProto.
  std::map<string, BundleEntryProto> entries;
  // Data file path -> new shard id in the final merged bundle.
  std::unordered_map<string, int32> shard_ids;
};
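
// For example (hypothetical prefixes): merging two single-shard bundles "a"
// and "b" assigns their data files the new shard ids 0 and 1, so the files
// are later renamed to DataFilename(merged_prefix, 0, 2) and
// DataFilename(merged_prefix, 1, 2).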

// Merges entries of "prefix" into the accumulator state "merge".
// Returns OK iff the merge succeeds.
static Status MergeOneBundle(Env* env, StringPiece prefix,
                             MergeState* merge_state) {
  VLOG(1) << "Merging bundle: " << prefix;
  const string filename = MetaFilename(prefix);
  uint64 file_size;
  TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));
  std::unique_ptr<RandomAccessFile> file;
  TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));

  table::Table* table = nullptr;
  TF_RETURN_IF_ERROR(
      table::Table::Open(TableBuilderOptions(), file.get(), file_size, &table));
  std::unique_ptr<table::Table> table_deleter(table);
  std::unique_ptr<table::Iterator> iter(table->NewIterator());

  int num_shards;
  // Process header.
  {
    iter->Seek(kHeaderEntryKey);
    if (!iter->Valid()) {
      return CorruptFileError(iter->status(), filename,
                              "failed to seek to header entry");
    }
    BundleHeaderProto header;
    Status s = ParseEntryProto(iter->key(), iter->value(), &header);
    if (!s.ok()) return CorruptFileError(s, filename, "unable to parse header");

    merge_state->num_shards += header.num_shards();
    if (!merge_state->seen_first_bundle) {
      merge_state->seen_first_bundle = true;
      merge_state->endianness = header.endianness();
      merge_state->version = header.version();
    } else {
      // Validates "endianness".
      if (merge_state->endianness != header.endianness()) {
        return errors::InvalidArgument(
            "Merging bundles with conflicting endianness; inputs corrupted?");
      }
      // Validates "version".
      string curr_version, merge_version;
      header.version().SerializeToString(&curr_version);
      merge_state->version.SerializeToString(&merge_version);
      if (curr_version != merge_version) {
        return errors::InvalidArgument(
            "Merging bundles with different format versions: merged ",
            merge_version, " vs. curr ", curr_version);
      }
    }
    num_shards = header.num_shards();
    iter->Next();
  }

  // Loops through the non-header to-merge entries.
  BundleEntryProto to_merge_entry;
  for (; iter->Valid(); iter->Next()) {
    const string key(iter->key());
    const auto entry_iter = merge_state->entries.find(key);

    // Illegal: the duplicated entry is a non-slice tensor.
    if (entry_iter != merge_state->entries.end() &&
        entry_iter->second.slices().empty()) {
      return errors::InvalidArgument(
          "Duplicate tensor keyed by ", key,
          " encountered, when merging prefix: ", prefix);
    }

    TF_RETURN_IF_ERROR(
        ParseEntryProto(iter->key(), iter->value(), &to_merge_entry));

    // The duplicated entry holds metadata for a sliced full tensor.
    // Allows the duplication and merges "slices".
    if (entry_iter != merge_state->entries.end()) {
      BundleEntryProto& existing_entry = entry_iter->second;
      if (to_merge_entry.slices().empty()) {
        return errors::Internal(
            "Duplicate tensor keyed by ", key,
            "; attempting to merge in a non-slice bundle entry");
      }
      // Only needs to merge the "slices" field (and validate dtype/shape).
      for (int i = 0; i < to_merge_entry.slices_size(); ++i) {
        TensorSliceProto* slot = existing_entry.add_slices();
        *slot = to_merge_entry.slices(i);
      }
      CHECK_EQ(existing_entry.dtype(), to_merge_entry.dtype());
      CHECK(TensorShape(existing_entry.shape()) ==
            TensorShape(to_merge_entry.shape()));
      continue;
    }

    // Key doesn't duplicate: a fresh tensor/slice entry.
    auto result = merge_state->shard_ids.insert(
        {DataFilename(prefix, to_merge_entry.shard_id(), num_shards),
         merge_state->shard_ids.size()});
    to_merge_entry.set_shard_id(result.first->second);
    merge_state->entries[key] = to_merge_entry;
  }
  return Status::OK();
}

Status MergeBundles(Env* env, gtl::ArraySlice<tstring> prefixes,
                    StringPiece merged_prefix) {
  // Merges all metadata tables.
  // TODO(zhifengc): KeyValue sorter if it becomes too big.
  MergeState merge;
  Status status = env->CreateDir(string(io::Dirname(merged_prefix)));
  if (!status.ok() && !errors::IsAlreadyExists(status)) return status;
  for (int i = 0; i < prefixes.size(); ++i) {
    TF_RETURN_IF_ERROR(MergeOneBundle(env, prefixes[i], &merge));
  }

  // Renames data files to contain the merged bundle prefix.
  for (const auto& p : merge.shard_ids) {
    VLOG(1) << "Renaming " << p.first << " to "
            << DataFilename(merged_prefix, p.second, merge.shard_ids.size());
    TF_RETURN_IF_ERROR(env->RenameFile(
        p.first,
        DataFilename(merged_prefix, p.second, merge.shard_ids.size())));
  }

  // Writes the final metadata table under the merged prefix.
  std::unique_ptr<WritableFile> merged_metadata;
  TF_RETURN_IF_ERROR(
      env->NewWritableFile(MetaFilename(merged_prefix), &merged_metadata));
  {
    table::TableBuilder builder(TableBuilderOptions(), merged_metadata.get());
    // Header entry.
    BundleHeaderProto header;
    header.set_num_shards(merge.num_shards);
    header.set_endianness(merge.endianness);
    *header.mutable_version() = merge.version;
    builder.Add(kHeaderEntryKey, header.SerializeAsString());
    // All others.
    for (const auto& p : merge.entries) {
      builder.Add(p.first, p.second.SerializeAsString());
    }
    status = builder.Finish();
  }
  status.Update(merged_metadata->Close());
  if (!status.ok()) return status;
  VLOG(1) << "Merged bundles to: " << merged_prefix;

  // Cleanup: best-effort; ignores errors.
  for (const tstring& prefix : prefixes) {
    env->DeleteFile(MetaFilename(prefix)).IgnoreError();
  }
  return status;
}

// Interface for reading a tensor bundle.

BundleReader::BundleReader(Env* env, StringPiece prefix)
    : env_(env),
      prefix_(prefix),
      metadata_(nullptr),
      table_(nullptr),
      iter_(nullptr),
      need_to_swap_bytes_(false) {
  const string filename = MetaFilename(prefix_);
  uint64 file_size;
  status_ = env_->GetFileSize(filename, &file_size);
  if (!status_.ok()) return;

  // Opens the metadata table.
  std::unique_ptr<RandomAccessFile> wrapper;
  status_ = env_->NewRandomAccessFile(filename, &wrapper);
  if (!status_.ok()) return;
  metadata_ = wrapper.release();
  status_ = table::Table::Open(table::Options(), metadata_, file_size, &table_);
  if (!status_.ok()) return;
  iter_ = table_->NewIterator();

  // Reads "num_shards_" from the first entry.
  iter_->Seek(kHeaderEntryKey);
  if (!iter_->Valid()) {
    status_ = CorruptFileError(iter_->status(), filename,
                               "failed to seek to header entry");
    return;
  }
  BundleHeaderProto header;
  status_ = ParseEntryProto(iter_->key(), iter_->value(), &header);
  if (!status_.ok()) {
    status_ = CorruptFileError(status_, filename, "unable to parse header");
    return;
  }
  num_shards_ = header.num_shards();
  if ((header.endianness() == BundleHeaderProto::BIG && port::kLittleEndian) ||
      (header.endianness() == BundleHeaderProto::LITTLE &&
       !port::kLittleEndian)) {
    need_to_swap_bytes_ = true;
  }
  status_ = CheckVersions(header.version(), kTensorBundleVersion,
                          kTensorBundleMinProducer, "Checkpoint", "checkpoint");
}

BundleReader::~BundleReader() {
  delete metadata_;
  delete iter_;
  delete table_;
  // InputBuffer does not own the underlying RandomAccessFile.
  for (auto pair : data_) {
    if (pair.second != nullptr && pair.second->file() != nullptr) {
      delete pair.second->file();
    }
  }
  for (auto& temp : data_) {
    delete temp.second;
  }
  for (auto& temp : tensor_slices_) {
    delete temp.second;
  }
  data_.clear();
  tensor_slices_.clear();
}

Status BundleReader::GetBundleEntryProto(StringPiece key,
                                         BundleEntryProto* entry) {
  entry->Clear();
  TF_CHECK_OK(status_);
  Seek(key);
  if (!iter_->Valid() || iter_->key() != key) {
    return errors::NotFound("Key ", key, " not found in checkpoint");
  }

  BundleEntryProto entry_copy;
  TF_RETURN_IF_ERROR(
      ParseEntryProto(iter_->key(), iter_->value(), &entry_copy));
  if (!TensorShape::IsValid(entry_copy.shape())) {
    return errors::DataLoss("Invalid tensor shape: ", key, " ",
                            entry_copy.shape().ShortDebugString());
  }

  *entry = entry_copy;
  return Status::OK();
}

Status BundleReader::GetValue(const BundleEntryProto& entry, Tensor* val) {
  Tensor* ret = val;
  const TensorShape stored_shape(TensorShape(entry.shape()));
  if (val->NumElements() == 0) {
    ret = new Tensor(entry.dtype(), stored_shape);
  }

  // Validates the "size" field.
  if (entry.dtype() != DT_STRING && entry.dtype() != DT_VARIANT) {
    if (entry.size() != ret->TotalBytes()) {
      return errors::DataLoss("Invalid size in bundle entry: key ", key(),
                              "; stored size ", entry.size(),
                              "; expected size ", ret->TotalBytes());
    }
  } else if (entry.dtype() == DT_STRING) {
    // Relaxes the check for string tensors as follows:
    //   entry.size() == bytes(varint lengths) + bytes(data)
    //                >= NumElems + bytes(data), since each varint length
    //                   occupies at least 1 byte.
    //   TotalBytes() == sizeof(tstring) * NumElems + bytes(data)
    // Since we don't know bytes(varint lengths), we just check an inequality.
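    //
    // Worked example (hypothetical data): for 2 elements with 3 data bytes
    // total, lower_bound = 2 + 3 = 5, while a well-formed entry stores at
    // least 2 varint bytes + 4 length-checksum bytes + 3 data bytes = 9 >= 5.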
    const size_t lower_bound = ret->NumElements() + ret->TotalBytes() -
                               sizeof(tstring) * ret->NumElements();
    if (entry.size() < lower_bound) {
      return errors::DataLoss("Invalid size in bundle entry: key ", key(),
                              "; stored size ", entry.size(),
                              "; expected size is at least ", lower_bound);
    }
  }

  // Open the data file if it has not been opened.
  io::InputBuffer* buffered_file = data_[entry.shard_id()];
  if (buffered_file == nullptr) {
    std::unique_ptr<RandomAccessFile> file = nullptr;
    TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(
        DataFilename(prefix_, entry.shard_id(), num_shards_), &file));
    buffered_file = new io::InputBuffer(file.release(), kBufferSize);
    // The InputBuffer and RandomAccessFile objects are both released in dtor.
    data_[entry.shard_id()] = buffered_file;
  }
  CHECK(buffered_file != nullptr);

  TF_RETURN_IF_ERROR(buffered_file->Seek(entry.offset()));
  uint32 actual_crc32c = 0;

  if (DataTypeCanUseMemcpy(entry.dtype())) {
    char* backing_buffer = const_cast<char*>((ret->tensor_data().data()));
    size_t unused_bytes_read;
    if (entry.size() > kBufferSize) {
      StringPiece sp;
      TF_RETURN_IF_ERROR(buffered_file->file()->Read(
          entry.offset(), entry.size(), &sp, backing_buffer));
      if (sp.data() != backing_buffer) {
        memmove(backing_buffer, sp.data(), entry.size());
      }
    } else {
      TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(entry.size(),
                                                   backing_buffer,
                                                   &unused_bytes_read));
    }
    // Note that we compute the checksum *before* byte-swapping. The checksum
    // should be on the bytes in the order they appear in the file.
    actual_crc32c = crc32c::Value(backing_buffer, entry.size());
    if (need_to_swap_bytes_) {
      TF_RETURN_IF_ERROR(ByteSwapTensor(ret));
    }
  } else if (entry.dtype() == DT_VARIANT) {
    if (need_to_swap_bytes_) {
      return errors::Unimplemented(
          "TensorBundle at ", prefix_,
          " is of a different endianness than this machine's hardware, and "
          "the bundle contains a variant (arbitrary C++ type) tensor. "
          "Byte-swapping of variant tensors is not currently implemented.");
    }
    // Relies on io::InputBuffer's buffering, because we issue many neighboring
    // reads for a single variant tensor.
    TF_RETURN_IF_ERROR(ReadVariantTensor(buffered_file, ret, entry.offset(),
                                         entry.size(), &actual_crc32c));
  } else {
    // Relies on io::InputBuffer's buffering, because we issue many neighboring
    // reads for a single string tensor.
    TF_RETURN_IF_ERROR(ReadStringTensor(
        buffered_file, ret->NumElements(), entry.offset(), entry.size(),
        GetStringBackingBuffer(*ret), &actual_crc32c, need_to_swap_bytes_));
  }
  if (crc32c::Unmask(entry.crc32c()) != actual_crc32c) {
    return errors::DataLoss(
        "Checksum does not match: stored ",
        strings::Printf("%08u", crc32c::Unmask(entry.crc32c())),
        " vs. calculated on the restored bytes ", actual_crc32c);
  }

  *val = *ret;
  if (ret != val) delete ret;
  return Status::OK();
}

Status BundleReader::Lookup(StringPiece key, Tensor* val) {
  CHECK(val != nullptr);
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));

  if (entry.slices().empty()) {
    return GetValue(entry, val);
  } else {
    return GetSliceValue(
        key, entry,
        /* a full slice */ TensorSlice(TensorShape(entry.shape()).dims()), val);
  }
}

Status BundleReader::ReadCurrent(Tensor* val) {
  CHECK(val != nullptr);
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(ParseEntryProto(iter_->key(), iter_->value(), &entry));
  if (!TensorShape::IsValid(entry.shape())) {
    return errors::DataLoss("Invalid tensor shape: ", iter_->key(), " ",
                            entry.shape().ShortDebugString());
  }

  if (entry.slices().empty()) {
    return GetValue(entry, val);
  } else {
    return GetSliceValue(
        iter_->key(), entry,
        /* a full slice */ TensorSlice(TensorShape(entry.shape()).dims()), val);
  }
}

Status BundleReader::LookupTensorSlices(StringPiece key,
                                        std::vector<TensorSlice>* slices) {
  slices->clear();
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));
  slices->reserve(entry.slices_size());
  for (const auto& slice : entry.slices()) {
    slices->emplace_back(slice);
  }
  return Status::OK();
}

Status BundleReader::LookupSlice(StringPiece full_tensor_key,
                                 const TensorSlice& slice_spec, Tensor* val) {
  CHECK(val != nullptr);
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(GetBundleEntryProto(full_tensor_key, &entry));
  return GetSliceValue(full_tensor_key, entry, slice_spec, val);
}

Status BundleReader::GetSliceValue(StringPiece full_tensor_key,
                                   const BundleEntryProto& full_tensor_entry,
                                   const TensorSlice& slice_spec, Tensor* val) {
  using checkpoint::RegisterTensorSlice;
  using checkpoint::TensorSliceSet;
  DCHECK_GE(full_tensor_entry.slices_size(), 0);

  const TensorShape full_shape(TensorShape(full_tensor_entry.shape()));
  std::vector<std::pair<TensorSlice, string>> details;
  const string full_tensor_key_string(full_tensor_key);
  const TensorSliceSet* tss =
      gtl::FindPtrOrNull(tensor_slices_, full_tensor_key_string);

  // Populates the "full tensor key -> TensorSliceSet" cache.
  if (tss == nullptr) {
    if (full_tensor_entry.slices().empty()) {
      // Special case: a writer has saved a tensor fully, but the reader wants
      // to read in slices. We therefore register the full slice on-demand here
      // without further complicating the on-disk bundle format.
      TF_RETURN_IF_ERROR(RegisterTensorSlice(
          full_tensor_key_string, full_shape, full_tensor_entry.dtype(),
          /* tag */ "",
          /* full slice */ TensorSlice(full_shape.dims()), &tensor_slices_));
    }
    for (const TensorSliceProto& slice : full_tensor_entry.slices()) {
      TF_RETURN_IF_ERROR(RegisterTensorSlice(
          full_tensor_key_string, full_shape, full_tensor_entry.dtype(),
          /* tag */ "", TensorSlice(slice), &tensor_slices_));
    }
    tss = gtl::FindPtrOrNull(tensor_slices_, full_tensor_key_string);
    CHECK_NE(tss, nullptr);
  }
  if (!tss->QueryMeta(slice_spec, &details)) {
    return errors::InvalidArgument(
        "Does not have sufficient slices for partitioned tensor ",
        full_tensor_key,
        " to restore in slice_spec: ", slice_spec.DebugString());
  }

  // The union of the slices in "details" covers "slice_spec". Performs the
  // copies from each.
  BundleEntryProto stored_slice_entry = full_tensor_entry;
  for (const auto& slice_tag_pair : details) {
    // Seeks for the stored slice.
    const TensorSlice& stored_slice = slice_tag_pair.first;

    // We already have the entry for the full tensor, so don't query again if
    // the slice is full.
    if (!stored_slice.IsFull()) {
      const string encoded_stored_slice_name =
          checkpoint::EncodeTensorNameSlice(full_tensor_key_string,
                                            stored_slice);
      status_ =
          GetBundleEntryProto(encoded_stored_slice_name, &stored_slice_entry);
      if (!status_.ok()) return status_;
    }

    // TODO(zongheng): should we take an OpKernelContext, so that we can call
    // allocate_temp()? Note that without major refactorings to Saver, it's
    // hard for the caller of the tensor bundle module to allocate these
    // precisely-shaped scratch storage.

    // Optimization for the common case: the stored slice can be directly
    // copied to the destination without additional slicing. This is true when
    // either the slices are equal or when they are both full slices having the
    // same shape.
    TensorShape stored_slice_shape(stored_slice_entry.shape());
    if (stored_slice == slice_spec ||
        (stored_slice_shape == val->shape() &&
         IsFullSlice(stored_slice, stored_slice_shape) &&
         IsFullSlice(slice_spec, stored_slice_shape))) {
      VLOG(1) << "Optimized for common case: directly copying into "
                 "pre-allocated buffer; spec: "
              << slice_spec.DebugString();
      status_ = GetValue(stored_slice_entry, val);
      return status_;
    }

    Tensor stored_slice_tensor(stored_slice_entry.dtype(), stored_slice_shape);
    status_ = GetValue(stored_slice_entry, &stored_slice_tensor);
    if (!status_.ok()) return status_;

    // Copies the intersection over.
    const DataType common_dtype = full_tensor_entry.dtype();
    switch (common_dtype) {
#define HANDLE_COPY(T)                                                  \
  case DataTypeToEnum<T>::value:                                        \
    CHECK(CopyDataFromTensorSliceToTensorSlice(                         \
        full_shape, stored_slice, slice_spec,                           \
        stored_slice_tensor.flat<T>().data(), val->flat<T>().data()));  \
    break;

      HANDLE_COPY(float)
      HANDLE_COPY(double)
      HANDLE_COPY(int32)
      HANDLE_COPY(uint8)
      HANDLE_COPY(int16)
      HANDLE_COPY(int8)
      HANDLE_COPY(complex64)
      HANDLE_COPY(complex128)
      HANDLE_COPY(int64)
      HANDLE_COPY(bool)
      HANDLE_COPY(qint32)
      HANDLE_COPY(quint8)
      HANDLE_COPY(qint8)
      HANDLE_COPY(bfloat16)
      default:
        return errors::InvalidArgument("Dtype ", DataTypeString(common_dtype),
                                       " not supported.");
    }
#undef HANDLE_COPY
  }
  return Status::OK();
}

bool BundleReader::Contains(StringPiece key) {
  Seek(key);
  return Valid() && (this->key() == key);
}

Status BundleReader::LookupDtypeAndShape(StringPiece key, DataType* dtype,
                                         TensorShape* shape) {
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));
  *dtype = entry.dtype();
  *shape = TensorShape(entry.shape());
  return Status::OK();
}

Status BundleReader::LookupTensorShape(StringPiece key, TensorShape* shape) {
  DataType ignored;
  return LookupDtypeAndShape(key, &ignored, shape);
}

string BundleReader::DebugString() {
  // Format used below emulates that of TensorSliceReader::DebugString().
  string shape_str;
  BundleEntryProto entry;
  Seek(kHeaderEntryKey);
  for (Next(); Valid(); Next()) {
    CHECK(entry.ParseFromArray(value().data(), value().size()));
    if (entry.slices_size() > 0) continue;  // Slice of some partitioned var.

    strings::StrAppend(&shape_str, key(), " (", DataType_Name(entry.dtype()),
                       ") ", TensorShape(entry.shape()).DebugString());
    strings::StrAppend(&shape_str, "\n");
  }
  return shape_str;
}

FileOutputBuffer::~FileOutputBuffer() { delete file_; }

Status FileOutputBuffer::Append(StringPiece data) {
  // In the below, it is critical to calculate the checksum on the actually
  // copied bytes, not the source bytes. This is because "data" typically
  // points to tensor buffers, which may be concurrently written.
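  //
  // Buffering sketch (hypothetical sizes): with an 8 MB buffer, a 20 MB
  // append first flushes any pending bytes, then is copied and checksummed
  // in chunks of 8 MB, 8 MB, and 4 MB, each flushed in turn by the loop
  // below.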
  if (data.size() + position_ <= buffer_size_) {
    // Can fit into the current buffer.
    memcpy(&buffer_[position_], data.data(), data.size());
    crc32c_ = crc32c::Extend(crc32c_, &buffer_[position_], data.size());
  } else if (data.size() <= buffer_size_) {
    // Cannot fit, but can fit after flushing.
    TF_RETURN_IF_ERROR(FlushBuffer());
    memcpy(&buffer_[0], data.data(), data.size());
    crc32c_ = crc32c::Extend(crc32c_, &buffer_[0], data.size());
  } else {
    // Cannot fit even after flushing. So we break down "data" by chunk, and
    // flush/checksum each chunk.
    TF_RETURN_IF_ERROR(FlushBuffer());
    for (size_t i = 0; i < data.size(); i += buffer_size_) {
      const size_t nbytes = std::min(data.size() - i, buffer_size_);
      memcpy(&buffer_[0], data.data() + i, nbytes);
      crc32c_ = crc32c::Extend(crc32c_, &buffer_[0], nbytes);
      position_ = nbytes;
      TF_RETURN_IF_ERROR(FlushBuffer());
    }
    return Status::OK();
  }
  position_ += data.size();
  return Status::OK();
}

Status FileOutputBuffer::Close() {
  TF_RETURN_IF_ERROR(FlushBuffer());
  return file_->Close();
}

Status FileOutputBuffer::FlushBuffer() {
  if (position_ > 0) {
    TF_RETURN_IF_ERROR(file_->Append(StringPiece(&buffer_[0], position_)));
    position_ = 0;
  }
  return Status::OK();
}

}  // namespace tensorflow