/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"

#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <utility>

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb_text.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb_text.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/framework/versions.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/hash/crc32c.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/io/table_builder.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
#include "tensorflow/core/util/tensor_slice_util.h"

namespace tensorflow {

// Versioning of the tensor bundle format.
const int kTensorBundleMinProducer = 0;
const int kTensorBundleMinConsumer = 0;
const int kTensorBundleVersion = 1;
// Size of our input buffer for streaming reads.
static const int kBufferSize = 1024 * 1024;

// Key to the special BundleHeaderProto entry. Do not change this, as clients
// can make the assumption that the header is always the first entry in the
// bundle.
const char* const kHeaderEntryKey = "";
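
// (Why this assumption holds: the metadata file is a lexicographically sorted
// table, and the empty string sorts before every other key, so an in-order
// scan always hits the header entry first.)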

namespace {

// Reads "num_elements" string elements from file[offset, offset+size) into the
// length-N "destination". Discards the original content of "destination".
//
// Checksums the string lengths (as restored uint32 or uint64, not varint64
// bytes) and string bytes, and stores it into "actual_crc32c".
Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements,
                        size_t offset, size_t size, string* destination,
                        uint32* actual_crc32c) {
  if (size == 0) return Status::OK();
  CHECK_GT(size, 0);

  // Reads "num_elements" varint64's from "buffered_file".
  TF_RETURN_IF_ERROR(buffered_file->Seek(offset));
  std::vector<uint64> string_lengths(num_elements);
  for (size_t i = 0; i < num_elements; ++i) {
    TF_RETURN_IF_ERROR(buffered_file->ReadVarint64(&string_lengths[i]));
    if (string_lengths[i] <= UINT32_MAX) {
      // We need to do this because older checkpoints only used uint32s and we
      // should still support them.
      const uint32 elem_size_uint32 = static_cast<uint32>(string_lengths[i]);
      *actual_crc32c = crc32c::Extend(
          *actual_crc32c, reinterpret_cast<const char*>(&elem_size_uint32),
          sizeof(uint32));
    } else {
      *actual_crc32c = crc32c::Extend(
          *actual_crc32c, reinterpret_cast<const char*>(&string_lengths[i]),
          sizeof(uint64));
    }
  }
  if (offset + size < buffered_file->Tell()) {
    return errors::DataLoss("String lengths longer than expected offset ",
                            offset + size);
  }

  // Reads the length-checksum.
  uint32 length_checksum = 0;
  size_t unused_bytes_read = 0;
  TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(
      sizeof(uint32), reinterpret_cast<char*>(&length_checksum),
      &unused_bytes_read));
  if (crc32c::Unmask(length_checksum) != *actual_crc32c) {
    return errors::DataLoss(
        "The length checksum does not match: expected ",
        strings::Printf("%08u", crc32c::Unmask(length_checksum)),
        " but actual is ", strings::Printf("%08u", *actual_crc32c));
  }
  *actual_crc32c =
      crc32c::Extend(*actual_crc32c, reinterpret_cast<char*>(&length_checksum),
                     sizeof(uint32));

  // Reads the actual string bytes.
  for (size_t i = 0; i < num_elements; ++i) {
    const uint64 string_length = string_lengths[i];
    string* buffer = &destination[i];

    buffer->resize(string_length);
    size_t bytes_read = 0;
    TF_RETURN_IF_ERROR(
        buffered_file->ReadNBytes(string_length, &(*buffer)[0], &bytes_read));
    *actual_crc32c = crc32c::Extend(*actual_crc32c, buffer->data(), bytes_read);
  }
  return Status::OK();
}

Status ReadVariantTensor(io::InputBuffer* buffered_file, Tensor* ret,
                         size_t offset, size_t size, uint32* actual_crc32c) {
  // On-disk format:
  //   [varint64 len1][bytes variant1][4 byte checksum]
  //   ..
  //   [varint64 lenN][bytes variantN][4 byte checksum]
  // Var "crc32c" checksums all the lens, variant bytes, individual variant
  // checksums (as uint32, not varint32 bytes).
  if (size == 0) return Status::OK();
  size_t num_elements = ret->NumElements();

  // Reads the actual variant bytes.
  TF_RETURN_IF_ERROR(buffered_file->Seek(offset));
  for (size_t i = 0; i < num_elements; ++i) {
    // Read the serialized variant length.
    uint64 string_length = 0;
    TF_RETURN_IF_ERROR(buffered_file->ReadVarint64(&string_length));
    *actual_crc32c = crc32c::Extend(
        *actual_crc32c, reinterpret_cast<const char*>(&string_length),
        sizeof(uint64));
    // Read the actual serialized variant.
    string buffer;
    buffer.resize(string_length);
    size_t bytes_read = 0;
    TF_RETURN_IF_ERROR(
        buffered_file->ReadNBytes(string_length, &buffer[0], &bytes_read));
    *actual_crc32c = crc32c::Extend(*actual_crc32c, buffer.data(), bytes_read);
    VariantTensorDataProto proto;
    if (!proto.ParseFromString(buffer)) {
      return errors::DataLoss("Unable to parse VariantTensorDataProto from ",
                              "buffer of size ", string_length, ". ",
                              "Bundle entry offset: ", offset, " size: ", size);
    }
    Variant v = proto;
    if (!DecodeUnaryVariant(&v)) {
      return errors::Internal("Could not decode variant with type_name: \"",
                              v.TypeName(), "\". Perhaps you forgot to ",
                              "register a decoder via ",
                              "REGISTER_UNARY_VARIANT_DECODE_FUNCTION?");
    }

    // Read the checksum.
    uint32 checksum = 0;
    size_t unused_bytes_read = 0;
    TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(
        sizeof(uint32), reinterpret_cast<char*>(&checksum),
        &unused_bytes_read));
    if (crc32c::Unmask(checksum) != *actual_crc32c) {
      return errors::DataLoss(
          "The checksum after Variant ", i, " does not match.",
          " Expected: ", strings::Printf("%08u", crc32c::Unmask(checksum)),
          " Actual: ", strings::Printf("%08u", *actual_crc32c));
    }
    *actual_crc32c = crc32c::Extend(
        *actual_crc32c, reinterpret_cast<char*>(&checksum), sizeof(uint32));

    ret->flat<Variant>()(i) = std::move(v);
  }

  return Status::OK();
}

char* GetBackingBuffer(const Tensor& val) {
  CHECK(DataTypeCanUseMemcpy(val.dtype())) << val.dtype();
  return const_cast<char*>(val.tensor_data().data());
}

string* GetStringBackingBuffer(const Tensor& val) {
  CHECK_EQ(DT_STRING, val.dtype());
  return const_cast<string*>(val.flat<string>().data());
}

Status ParseEntryProto(StringPiece key, StringPiece value,
                       protobuf::MessageLite* out) {
  if (!out->ParseFromArray(value.data(), value.size())) {
    return errors::DataLoss("Entry for key ", key, " not parseable.");
  }
  return Status::OK();
}

// Serializes the data bytes of the non-string tensor "val". Discards the
// original content of "bytes_written", and on OK updates it with number of
// bytes written.
// REQUIRES: val.dtype() != DT_STRING
Status WriteTensor(const Tensor& val, FileOutputBuffer* out,
                   size_t* bytes_written) {
  DCHECK_NE(val.dtype(), DT_STRING);
  DCHECK_NE(val.dtype(), DT_VARIANT);
  *bytes_written = val.TotalBytes();
  char* buf = GetBackingBuffer(val);
  VLOG(1) << "Appending " << *bytes_written << " bytes to file";
  return out->Append(StringPiece(buf, *bytes_written));
}

// Serializes string tensor "val". "bytes_written" is treated in the same
// fashion as WriteTensor().
//
// Checksums all bytes written and stores it into "crc32c".
// REQUIRES: val.dtype() == DT_STRING
Status WriteStringTensor(const Tensor& val, FileOutputBuffer* out,
                         size_t* bytes_written, uint32* crc32c) {
  // On-disk format:
  //   [varint64 len0]..[varint64 lenL][4 byte cksum on lengths][string bytes]
  // Var "crc32c" checksums the string lengths (as uint64, not varint64 bytes),
  // the length-checksum, and all the string bytes.
  DCHECK_EQ(val.dtype(), DT_STRING);
  const string* strings = GetStringBackingBuffer(val);

  // Writes the varint lengths.
  string lengths;
  lengths.reserve(val.NumElements());  // At least 1 byte per element.
  *crc32c = 0;
  for (int64 i = 0; i < val.NumElements(); ++i) {
    const string* elem = &strings[i];
    DCHECK_EQ(elem->size(), static_cast<uint64>(elem->size()));
    const uint64 elem_size = static_cast<uint64>(elem->size());

    core::PutVarint64(&lengths, elem_size);
    if (elem_size <= UINT32_MAX) {
      // We need to do this because older checkpoints only used uint32s and we
      // should still support them.
      const uint32 elem_size_uint32 = static_cast<uint32>(elem_size);
      *crc32c = crc32c::Extend(*crc32c,
                               reinterpret_cast<const char*>(&elem_size_uint32),
                               sizeof(uint32));
    } else {
      *crc32c = crc32c::Extend(
          *crc32c, reinterpret_cast<const char*>(&elem_size), sizeof(uint64));
    }
  }
  TF_RETURN_IF_ERROR(out->Append(lengths));
  *bytes_written = lengths.size();

  // Writes the length checksum.
  const uint32 length_checksum = crc32c::Mask(*crc32c);
  TF_RETURN_IF_ERROR(out->Append(StringPiece(
      reinterpret_cast<const char*>(&length_checksum), sizeof(uint32))));
  *crc32c = crc32c::Extend(
      *crc32c, reinterpret_cast<const char*>(&length_checksum), sizeof(uint32));
  *bytes_written += sizeof(uint32);

  // Writes all the string bytes out.
  for (int64 i = 0; i < val.NumElements(); ++i) {
    const string* elem = &strings[i];
    TF_RETURN_IF_ERROR(out->Append(*elem));
    *bytes_written += elem->size();
    *crc32c = crc32c::Extend(*crc32c, elem->data(), elem->size());
  }
  return Status::OK();
}
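
// For illustration, a 2-element DT_STRING tensor {"ab", "c"} is laid out as:
//   0x02 0x01                   (varint64 lengths 2 and 1)
//   [4-byte masked crc32c of the two lengths]
//   'a' 'b' 'c'                 (the concatenated string bytes)
// so "bytes_written" (and hence the entry's size) is 2 + 4 + 3 = 9 bytes.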

Status WriteVariantTensor(const Tensor& val, FileOutputBuffer* out,
                          size_t* bytes_written, uint32* crc32c) {
  // On-disk format:
  //   [varint64 len1][bytes variant1][4 byte checksum]
  //   ..
  //   [varint64 lenN][bytes variantN][4 byte checksum]
  // Var "crc32c" checksums all the lens, variant bytes, individual variant
  // checksums (as uint32, not varint32 bytes).
  DCHECK_EQ(val.dtype(), DT_VARIANT);

  *crc32c = 0;
  *bytes_written = 0;
  for (int64 i = 0; i < val.NumElements(); ++i) {
    VariantTensorData data;
    val.flat<Variant>()(i).Encode(&data);
    VariantTensorDataProto proto;
    data.ToProto(&proto);
    string elem;
    proto.SerializeToString(&elem);

    // Write the length of the serialized variant.
    DCHECK_EQ(elem.size(), static_cast<uint64>(elem.size()));
    const auto elem_size = static_cast<uint64>(elem.size());
    string len;
    core::PutVarint64(&len, elem_size);
    TF_RETURN_IF_ERROR(out->Append(len));
    *crc32c = crc32c::Extend(*crc32c, reinterpret_cast<const char*>(&elem_size),
                             sizeof(uint64));
    *bytes_written += len.size();

    // Write the serialized variant.
    TF_RETURN_IF_ERROR(out->Append(elem));
    *crc32c = crc32c::Extend(*crc32c, elem.data(), elem.size());
    *bytes_written += elem.size();

    // Write the checksum.
    const uint32 length_checksum = crc32c::Mask(*crc32c);
    TF_RETURN_IF_ERROR(out->Append(StringPiece(
        reinterpret_cast<const char*>(&length_checksum), sizeof(uint32))));
    *crc32c =
        crc32c::Extend(*crc32c, reinterpret_cast<const char*>(&length_checksum),
                       sizeof(uint32));
    *bytes_written += sizeof(uint32);
  }

  return Status::OK();
}

// Returns whether "slice_spec" is a full slice, with respect to the full shape.
//
// This can happen say, when "slice_spec" is
// "TensorSlice(full_tensor_shape.dims())", or when it is "TensorSlice({{0,
// dim(0)}, ..., {0, dim(N)}})" -- a degenerate case we need to guard against.
bool IsFullSlice(const TensorSlice& slice_spec,
                 const TensorShape& full_tensor_shape) {
  if (slice_spec.IsFull()) {
    return true;
  } else {
    TensorShape sliced_shape;
    slice_spec.SliceTensorShape(full_tensor_shape, &sliced_shape).IgnoreError();
    return sliced_shape == full_tensor_shape;
  }
}
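
// For example, for a full shape of [4, 5], both TensorSlice(2) and
// TensorSlice({{0, 4}, {0, 5}}) are full slices: the first is full by
// construction, the second is caught by the sliced-shape comparison above.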

Status CorruptFileError(const Status& in_status, const string& filename,
                        const string& detail) {
  if (in_status.ok()) {
    return errors::Internal("Unable to read file (", filename,
                            "). Perhaps the file is corrupt or was produced by "
                            "a newer version of TensorFlow with format changes "
                            "(",
                            detail, ")");
  }
  return Status(
      in_status.code(),
      strings::StrCat("Unable to read file (", filename,
                      "). Perhaps the file is corrupt or was produced by a "
                      "newer version of TensorFlow with format changes (",
                      detail, "): ", in_status.error_message()));
}

table::Options TableBuilderOptions() {
  table::Options o;
  // Compressed tables cannot be read by TensorFlow releases prior to 1.1.
  // To smooth the transition, compressed writes are disabled for now
  // (version 1.2), with the intention that they will be enabled again at
  // some point (perhaps the 1.3 release?).
  o.compression = table::kNoCompression;
  return o;
}

// Writes zeros to output buffer to align the next write to the requested
// alignment. "size" is the current size of the buffer and is updated to the
// new size.
Status PadAlignment(FileOutputBuffer* out, int alignment, int64* size) {
  int bytes_over = *size % alignment;
  if (bytes_over == 0) {
    return Status::OK();
  }
  int bytes_to_write = alignment - bytes_over;
  Status status = out->Append(string(bytes_to_write, '\0'));
  if (status.ok()) {
    *size += bytes_to_write;
  }
  return status;
}
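
// Worked example: with alignment = 4096 and *size = 5000, bytes_over is
// 5000 % 4096 = 904, so 4096 - 904 = 3192 zero bytes are appended and *size
// becomes 8192, i.e. the next tensor starts on a 4096-byte boundary.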

}  // namespace

BundleWriter::BundleWriter(Env* env, StringPiece prefix, const Options& options)
    : env_(env),
      options_(options),
      prefix_(prefix),
      tmp_metadata_path_(strings::StrCat(MetaFilename(prefix_), ".tempstate",
                                         random::New64())),
      tmp_data_path_(strings::StrCat(DataFilename(prefix_, 0, 1), ".tempstate",
                                     random::New64())),
      out_(nullptr),
      size_(0) {
  status_ = env_->CreateDir(string(io::Dirname(prefix_)));
  if (!status_.ok() && !errors::IsAlreadyExists(status_)) {
    return;
  }
  const string filename = DataFilename(prefix_, 0, 1);
  std::unique_ptr<WritableFile> wrapper;
  status_ = env_->NewWritableFile(tmp_data_path_, &wrapper);
  if (!status_.ok()) return;
  out_ = std::unique_ptr<FileOutputBuffer>(
      new FileOutputBuffer(wrapper.release(), 8 << 20 /* 8MB write buffer */));

  VLOG(1) << "Writing to file " << tmp_data_path_;
}

Status BundleWriter::Add(StringPiece key, const Tensor& val) {
  if (!status_.ok()) return status_;
  CHECK_NE(key, kHeaderEntryKey);
  const string key_string(key);
  if (entries_.find(key_string) != entries_.end()) {
    status_ = errors::InvalidArgument("Adding duplicate key: ", key);
    return status_;
  }

  BundleEntryProto* entry = &entries_[key_string];
  entry->set_dtype(val.dtype());
  val.shape().AsProto(entry->mutable_shape());
  entry->set_shard_id(0);
  entry->set_offset(size_);

  // Updates the data file.
  size_t data_bytes_written = 0;
  uint32 crc32c = 0;
  out_->clear_crc32c();
  if (val.dtype() == DT_STRING) {
    status_ = WriteStringTensor(val, out_.get(), &data_bytes_written, &crc32c);
  } else if (val.dtype() == DT_VARIANT) {
    status_ = WriteVariantTensor(val, out_.get(), &data_bytes_written, &crc32c);
  } else {
    status_ = WriteTensor(val, out_.get(), &data_bytes_written);
    crc32c = out_->crc32c();
  }

  if (status_.ok()) {
    entry->set_size(data_bytes_written);
    entry->set_crc32c(crc32c::Mask(crc32c));
    size_ += data_bytes_written;
    status_ = PadAlignment(out_.get(), options_.data_alignment, &size_);
  }
  return status_;
}

Status BundleWriter::AddSlice(StringPiece full_tensor_key,
                              const TensorShape& full_tensor_shape,
                              const TensorSlice& slice_spec,
                              const Tensor& slice_tensor) {
  if (!status_.ok()) return status_;
  CHECK_NE(full_tensor_key, kHeaderEntryKey);

  // If just a singleton full slice, use the regular Add() to be more efficient.
  if (IsFullSlice(slice_spec, full_tensor_shape)) {
    return Add(full_tensor_key, slice_tensor);
  }

  // Inserts/updates the full tensor's metadata entry.
  //
  // In the case of a sharded save, MergeBundles() is responsible for merging
  // the "slices" field of multiple metadata entries corresponding to the same
  // full tensor.
  const string full_tensor_key_string(full_tensor_key);
  BundleEntryProto* full_entry = &entries_[full_tensor_key_string];
  if (full_entry->dtype() != DT_INVALID) {
    CHECK_EQ(full_entry->dtype(), slice_tensor.dtype());
  }
  if (full_entry->has_shape()) {
    CHECK(TensorShape(full_entry->shape()) == full_tensor_shape);
  }

  // Populates dtype, shape, and slices. Intentionally leaving out shard_id and
  // offset, which do not make sense for this full tensor entry.
  full_entry->set_dtype(slice_tensor.dtype());
  full_tensor_shape.AsProto(full_entry->mutable_shape());
  TensorSliceProto* slice_proto = full_entry->add_slices();
  slice_spec.AsProto(slice_proto);

  // The slice itself is handled by a regular Add(), which includes adding its
  // own metadata entry, and writing out the slice's values.
  const string slice_name =
      checkpoint::EncodeTensorNameSlice(full_tensor_key_string, slice_spec);
  status_ = Add(slice_name, slice_tensor);
  return status_;
}

// TODO(zongheng): on metadata write failure or !status_.ok(), consider removing
// the orphaned data file.
Status BundleWriter::Finish() {
  if (out_) {
    status_.Update(out_->Close());
    out_ = nullptr;
    if (status_.ok()) {
      status_ = Env::Default()->RenameFile(tmp_data_path_,
                                           DataFilename(prefix_, 0, 1));
    } else {
      Env::Default()->DeleteFile(tmp_data_path_).IgnoreError();
    }
  }
  if (!status_.ok()) return status_;
  // Build key -> BundleEntryProto table.
  std::unique_ptr<WritableFile> file;
  status_ = env_->NewWritableFile(tmp_metadata_path_, &file);
  if (!status_.ok()) return status_;
  {
    // N.B.: the default use of Snappy compression may not be supported on all
    // platforms (e.g. Android). The metadata file is small, so this is fine.
    table::Options options;
    options.compression = table::kNoCompression;
    table::TableBuilder builder(options, file.get());
    // Header entry.
    BundleHeaderProto header;
    header.set_num_shards(1);
    header.set_endianness(BundleHeaderProto::LITTLE);
    if (!port::kLittleEndian) header.set_endianness(BundleHeaderProto::BIG);
    VersionDef* version = header.mutable_version();
    version->set_producer(kTensorBundleVersion);
    version->set_min_consumer(kTensorBundleMinConsumer);

    builder.Add(kHeaderEntryKey, header.SerializeAsString());

    // All others.
    for (const auto& p : entries_) {
      builder.Add(p.first, p.second.SerializeAsString());
    }
    status_ = builder.Finish();
  }
  status_.Update(file->Close());
  if (!status_.ok()) {
    Env::Default()->DeleteFile(tmp_metadata_path_).IgnoreError();
    return status_;
  } else {
    status_ =
        Env::Default()->RenameFile(tmp_metadata_path_, MetaFilename(prefix_));
    if (!status_.ok()) return status_;
  }
  status_ = errors::Internal("BundleWriter is closed");
  return Status::OK();
}
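
// A minimal write-side sketch (illustrative only; the prefix path is
// hypothetical, default Options are assumed, and error handling is elided):
//
//   BundleWriter writer(Env::Default(), "/tmp/my_bundle");
//   TF_CHECK_OK(writer.Add("my_tensor", some_tensor));
//   TF_CHECK_OK(writer.Finish());
//
// Because Finish() renames the temp data/metadata files into place, a bundle
// only becomes visible under "prefix" once the write fully succeeds.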

// Merging tensor bundles.

// Accumulator of metadata states during a merge.
struct MergeState {
  // Accumulated from the header entries.
  int num_shards = 0;

  // Derives "endianness" and "version" from the first bundle merged (hence the
  // "seen_first_bundle" guard). The two fields must be the same for all
  // bundles in a merge.
  bool seen_first_bundle = false;
  BundleHeaderProto_Endianness endianness;
  VersionDef version;

  // Tensor key -> BundleEntryProto.
  std::map<string, BundleEntryProto> entries;
  // Data file path -> new shard id in the final merged bundle.
  std::unordered_map<string, int32> shard_ids;
};

// Merges entries of "prefix" into the accumulator state "merge".
// Returns OK iff the merge succeeds.
static Status MergeOneBundle(Env* env, StringPiece prefix,
                             MergeState* merge_state) {
  VLOG(1) << "Merging bundle: " << prefix;
  const string filename = MetaFilename(prefix);
  uint64 file_size;
  TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));
  std::unique_ptr<RandomAccessFile> file;
  TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));

  table::Table* table = nullptr;
  TF_RETURN_IF_ERROR(
      table::Table::Open(TableBuilderOptions(), file.get(), file_size, &table));
  std::unique_ptr<table::Table> table_deleter(table);
  std::unique_ptr<table::Iterator> iter(table->NewIterator());

  int num_shards;
  // Process header.
  {
    iter->Seek(kHeaderEntryKey);
    if (!iter->Valid()) {
      return CorruptFileError(iter->status(), filename,
                              "failed to seek to header entry");
    }
    BundleHeaderProto header;
    Status s = ParseEntryProto(iter->key(), iter->value(), &header);
    if (!s.ok()) return CorruptFileError(s, filename, "unable to parse header");

    merge_state->num_shards += header.num_shards();
    if (!merge_state->seen_first_bundle) {
      merge_state->seen_first_bundle = true;
      merge_state->endianness = header.endianness();
      merge_state->version = header.version();
    } else {
      // Validates "endianness".
      if (merge_state->endianness != header.endianness()) {
        return errors::InvalidArgument(
            "Merging bundles with conflicting endianness; inputs corrupted?");
      }
      // Validates "version".
      string curr_version, merge_version;
      header.version().SerializeToString(&curr_version);
      merge_state->version.SerializeToString(&merge_version);
      if (curr_version != merge_version) {
        return errors::InvalidArgument(
            "Merging bundles with different format versions: merged ",
            merge_version, " vs. curr ", curr_version);
      }
    }
    num_shards = header.num_shards();
    iter->Next();
  }

  // Loops through the non-header to-merge entries.
  BundleEntryProto to_merge_entry;
  for (; iter->Valid(); iter->Next()) {
    const string key(iter->key());
    const auto entry_iter = merge_state->entries.find(key);

    // Illegal: the duplicated entry is a non-slice tensor.
    if (entry_iter != merge_state->entries.end() &&
        entry_iter->second.slices().empty()) {
      return errors::InvalidArgument(
          "Duplicate tensor keyed by ", key,
          " encountered, when merging prefix: ", prefix);
    }

    TF_RETURN_IF_ERROR(
        ParseEntryProto(iter->key(), iter->value(), &to_merge_entry));

    // The duplicated entry holds metadata for a sliced full tensor.
    // Allows the duplication and merges "slices".
    if (entry_iter != merge_state->entries.end()) {
      BundleEntryProto& existing_entry = entry_iter->second;
      if (to_merge_entry.slices().empty()) {
        return errors::Internal(
            "Duplicate tensor keyed by ", key,
            "; attempting to merge in a non-slice bundle entry");
      }
      // Only needs to merge the "slices" field (and validate dtype/shape).
      for (int i = 0; i < to_merge_entry.slices_size(); ++i) {
        TensorSliceProto* slot = existing_entry.add_slices();
        *slot = to_merge_entry.slices(i);
      }
      CHECK_EQ(existing_entry.dtype(), to_merge_entry.dtype());
      CHECK(TensorShape(existing_entry.shape()) ==
            TensorShape(to_merge_entry.shape()));
      continue;
    }

    // Key doesn't duplicate: a fresh tensor/slice entry.
    auto result = merge_state->shard_ids.insert(
        {DataFilename(prefix, to_merge_entry.shard_id(), num_shards),
         merge_state->shard_ids.size()});
    to_merge_entry.set_shard_id(result.first->second);
    merge_state->entries[key] = to_merge_entry;
  }
  return Status::OK();
}

Status MergeBundles(Env* env, gtl::ArraySlice<string> prefixes,
                    StringPiece merged_prefix) {
  // Merges all metadata tables.
  // TODO(zhifengc): KeyValue sorter if it becomes too big.
  MergeState merge;
  Status status = env->CreateDir(string(io::Dirname(merged_prefix)));
  if (!status.ok() && !errors::IsAlreadyExists(status)) return status;
  for (int i = 0; i < prefixes.size(); ++i) {
    TF_RETURN_IF_ERROR(MergeOneBundle(env, prefixes[i], &merge));
  }

  // Renames data files to contain the merged bundle prefix.
  for (const auto& p : merge.shard_ids) {
    VLOG(1) << "Renaming " << p.first << " to "
            << DataFilename(merged_prefix, p.second, merge.shard_ids.size());
    TF_RETURN_IF_ERROR(env->RenameFile(
        p.first,
        DataFilename(merged_prefix, p.second, merge.shard_ids.size())));
  }

  // Writes the final metadata table under the merged prefix.
  std::unique_ptr<WritableFile> merged_metadata;
  TF_RETURN_IF_ERROR(
      env->NewWritableFile(MetaFilename(merged_prefix), &merged_metadata));
  {
    table::TableBuilder builder(TableBuilderOptions(), merged_metadata.get());
    // Header entry.
    BundleHeaderProto header;
    header.set_num_shards(merge.num_shards);
    header.set_endianness(merge.endianness);
    *header.mutable_version() = merge.version;
    builder.Add(kHeaderEntryKey, header.SerializeAsString());
    // All others.
    for (const auto& p : merge.entries) {
      builder.Add(p.first, p.second.SerializeAsString());
    }
    status = builder.Finish();
  }
  status.Update(merged_metadata->Close());
  if (!status.ok()) return status;
  VLOG(1) << "Merged bundles to: " << merged_prefix;

  // Cleanup: best-effort; ignores errors.
  for (const string& prefix : prefixes) {
    env->DeleteFile(MetaFilename(prefix)).IgnoreError();
  }
  return status;
}
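
// Merge sketch (prefixes are hypothetical): after two workers each write a
// bundle under "worker-0" and "worker-1", a single process can produce one
// logical bundle with
//
//   TF_CHECK_OK(MergeBundles(Env::Default(), {"worker-0", "worker-1"},
//                            "merged"));
//
// Data shards are renamed (not copied) under the merged prefix; only the
// metadata tables are re-read and rewritten.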

// Interface for reading a tensor bundle.

BundleReader::BundleReader(Env* env, StringPiece prefix)
    : env_(env),
      prefix_(prefix),
      metadata_(nullptr),
      table_(nullptr),
      iter_(nullptr) {
  const string filename = MetaFilename(prefix_);
  uint64 file_size;
  status_ = env_->GetFileSize(filename, &file_size);
  if (!status_.ok()) return;

  // Opens the metadata table.
  std::unique_ptr<RandomAccessFile> wrapper;
  status_ = env_->NewRandomAccessFile(filename, &wrapper);
  if (!status_.ok()) return;
  metadata_ = wrapper.release();
  status_ = table::Table::Open(table::Options(), metadata_, file_size, &table_);
  if (!status_.ok()) return;
  iter_ = table_->NewIterator();

  // Reads "num_shards_" from the first entry.
  iter_->Seek(kHeaderEntryKey);
  if (!iter_->Valid()) {
    status_ = CorruptFileError(iter_->status(), filename,
                               "failed to seek to header entry");
    return;
  }
  BundleHeaderProto header;
  status_ = ParseEntryProto(iter_->key(), iter_->value(), &header);
  if (!status_.ok()) {
    status_ = CorruptFileError(status_, filename, "unable to parse header");
    return;
  }
  num_shards_ = header.num_shards();
  if ((header.endianness() == BundleHeaderProto::BIG && port::kLittleEndian) ||
      (header.endianness() == BundleHeaderProto::LITTLE &&
       !port::kLittleEndian)) {
    status_ = errors::Unimplemented(
        "Reading a bundle with different endianness from the reader");
    return;
  }
  status_ = CheckVersions(header.version(), kTensorBundleVersion,
                          kTensorBundleMinProducer, "Checkpoint", "checkpoint");
}

BundleReader::~BundleReader() {
  delete metadata_;
  delete iter_;
  delete table_;
  // InputBuffer does not own the underlying RandomAccessFile.
  for (auto pair : data_) {
    if (pair.second != nullptr && pair.second->file() != nullptr) {
      delete pair.second->file();
    }
  }
  gtl::STLDeleteValues(&data_);
  gtl::STLDeleteValues(&tensor_slices_);
}

Status BundleReader::GetBundleEntryProto(StringPiece key,
                                         BundleEntryProto* entry) {
  entry->Clear();
  TF_CHECK_OK(status_);
  Seek(key);
  if (!iter_->Valid() || iter_->key() != key) {
    return errors::NotFound("Key ", key, " not found in checkpoint");
  }

  BundleEntryProto entry_copy;
  TF_RETURN_IF_ERROR(
      ParseEntryProto(iter_->key(), iter_->value(), &entry_copy));
  if (!TensorShape::IsValid(entry_copy.shape())) {
    return errors::DataLoss("Invalid tensor shape: ", key, " ",
                            ProtoShortDebugString(entry_copy.shape()));
  }

  *entry = entry_copy;
  return Status::OK();
}

Status BundleReader::GetValue(const BundleEntryProto& entry, Tensor* val) {
  Tensor* ret = val;
  const TensorShape stored_shape(TensorShape(entry.shape()));
  if (val->NumElements() == 0) {
    ret = new Tensor(entry.dtype(), stored_shape);
  }

  // Validates the "size" field.
  if (entry.dtype() != DT_STRING && entry.dtype() != DT_VARIANT) {
    if (entry.size() != ret->TotalBytes()) {
      return errors::DataLoss("Invalid size in bundle entry: key ", key(),
                              "; stored size ", entry.size(),
                              "; expected size ", ret->TotalBytes());
    }
  } else if (entry.dtype() == DT_STRING) {
    // Relaxes the check for string tensors as follows:
    //   entry.size() == bytes(varint lengths) + bytes(data)
    //                >= NumElems + bytes(data), since each varint length
    //                   takes at least 1 byte.
    //   TotalBytes() == sizeof(string) * NumElems + bytes(data)
    // Since we don't know bytes(varint lengths), we just check an inequality.
    const size_t lower_bound = ret->NumElements() + ret->TotalBytes() -
                               sizeof(string) * ret->NumElements();
    if (entry.size() < lower_bound) {
      return errors::DataLoss("Invalid size in bundle entry: key ", key(),
                              "; stored size ", entry.size(),
                              "; expected size is at least ", lower_bound);
    }
  }

  // Open the data file if it has not been opened.
  io::InputBuffer* buffered_file = data_[entry.shard_id()];
  if (buffered_file == nullptr) {
    std::unique_ptr<RandomAccessFile> file = nullptr;
    TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(
        DataFilename(prefix_, entry.shard_id(), num_shards_), &file));
    buffered_file = new io::InputBuffer(file.release(), kBufferSize);
    // The InputBuffer and RandomAccessFile objects are both released in dtor.
    data_[entry.shard_id()] = buffered_file;
  }
  CHECK(buffered_file != nullptr);

  TF_RETURN_IF_ERROR(buffered_file->Seek(entry.offset()));
  uint32 actual_crc32c = 0;

  if (DataTypeCanUseMemcpy(entry.dtype())) {
    char* backing_buffer = const_cast<char*>((ret->tensor_data().data()));
    size_t unused_bytes_read;
    if (entry.size() > kBufferSize) {
      StringPiece sp;
      TF_RETURN_IF_ERROR(buffered_file->file()->Read(
          entry.offset(), entry.size(), &sp, backing_buffer));
      if (sp.data() != backing_buffer) {
        memmove(backing_buffer, sp.data(), entry.size());
      }
    } else {
      TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(entry.size(), backing_buffer,
                                                   &unused_bytes_read));
    }
    actual_crc32c = crc32c::Value(backing_buffer, entry.size());
  } else if (entry.dtype() == DT_VARIANT) {
    // Relies on io::InputBuffer's buffering, because we issue many neighboring
    // reads for a single variant tensor.
    TF_RETURN_IF_ERROR(ReadVariantTensor(buffered_file, ret, entry.offset(),
                                         entry.size(), &actual_crc32c));
  } else {
    // Relies on io::InputBuffer's buffering, because we issue many neighboring
    // reads for a single string tensor.
    TF_RETURN_IF_ERROR(ReadStringTensor(
        buffered_file, ret->NumElements(), entry.offset(), entry.size(),
        GetStringBackingBuffer(*ret), &actual_crc32c));
  }
  if (crc32c::Unmask(entry.crc32c()) != actual_crc32c) {
    return errors::DataLoss(
        "Checksum does not match: stored ",
        strings::Printf("%08u", crc32c::Unmask(entry.crc32c())),
        " vs. calculated on the restored bytes ", actual_crc32c);
  }

  *val = *ret;
  if (ret != val) delete ret;
  return Status::OK();
}

Status BundleReader::Lookup(StringPiece key, Tensor* val) {
  CHECK(val != nullptr);
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));

  if (entry.slices().empty()) {
    return GetValue(entry, val);
  } else {
    return GetSliceValue(
        key, entry,
        /* a full slice */ TensorSlice(TensorShape(entry.shape()).dims()), val);
  }
}
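
// Read-side sketch (illustrative; the prefix is hypothetical and error
// handling is elided):
//
//   BundleReader reader(Env::Default(), "/tmp/my_bundle");
//   Tensor t;
//   TF_CHECK_OK(reader.Lookup("my_tensor", &t));
//
// Lookup() transparently reassembles a tensor that was saved as slices, via
// the GetSliceValue() path below.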

Status BundleReader::ReadCurrent(Tensor* val) {
  CHECK(val != nullptr);
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(ParseEntryProto(iter_->key(), iter_->value(), &entry));
  if (!TensorShape::IsValid(entry.shape())) {
    return errors::DataLoss("Invalid tensor shape: ", iter_->key(), " ",
                            ProtoShortDebugString(entry.shape()));
  }

  if (entry.slices().empty()) {
    return GetValue(entry, val);
  } else {
    return GetSliceValue(
        iter_->key(), entry,
        /* a full slice */ TensorSlice(TensorShape(entry.shape()).dims()), val);
  }
}

Status BundleReader::LookupTensorSlices(StringPiece key,
                                        std::vector<TensorSlice>* slices) {
  slices->clear();
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));
  slices->reserve(entry.slices_size());
  for (const auto& slice : entry.slices()) {
    slices->emplace_back(slice);
  }
  return Status::OK();
}

Status BundleReader::LookupSlice(StringPiece full_tensor_key,
                                 const TensorSlice& slice_spec, Tensor* val) {
  CHECK(val != nullptr);
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(GetBundleEntryProto(full_tensor_key, &entry));
  return GetSliceValue(full_tensor_key, entry, slice_spec, val);
}

Status BundleReader::GetSliceValue(StringPiece full_tensor_key,
                                   const BundleEntryProto& full_tensor_entry,
                                   const TensorSlice& slice_spec, Tensor* val) {
  using checkpoint::RegisterTensorSlice;
  using checkpoint::TensorSliceSet;
  DCHECK_GE(full_tensor_entry.slices_size(), 0);

  const TensorShape full_shape(TensorShape(full_tensor_entry.shape()));
  std::vector<std::pair<TensorSlice, string>> details;
  const string full_tensor_key_string(full_tensor_key);
  const TensorSliceSet* tss =
      gtl::FindPtrOrNull(tensor_slices_, full_tensor_key_string);

  // Populates the "full tensor key -> TensorSliceSet" cache.
  if (tss == nullptr) {
    if (full_tensor_entry.slices().empty()) {
      // Special case: a writer has saved a tensor fully, but the reader wants
      // to read in slices. We therefore register the full slice on-demand here
      // without further complicating the on-disk bundle format.
      TF_RETURN_IF_ERROR(RegisterTensorSlice(
          full_tensor_key_string, full_shape, full_tensor_entry.dtype(),
          /* tag */ "",
          /* full slice */ TensorSlice(full_shape.dims()), &tensor_slices_));
    }
    for (const TensorSliceProto& slice : full_tensor_entry.slices()) {
      TF_RETURN_IF_ERROR(RegisterTensorSlice(
          full_tensor_key_string, full_shape, full_tensor_entry.dtype(),
          /* tag */ "", TensorSlice(slice), &tensor_slices_));
    }
    tss = gtl::FindPtrOrNull(tensor_slices_, full_tensor_key_string);
    CHECK_NE(tss, nullptr);
  }
  if (!tss->QueryMeta(slice_spec, &details)) {
    return errors::InvalidArgument(
        "Does not have sufficient slices for partitioned tensor ",
        full_tensor_key,
        " to restore in slice_spec: ", slice_spec.DebugString());
  }

  // The union of the slices in "details" covers "slice_spec". Performs the
  // copies from each.
  BundleEntryProto stored_slice_entry = full_tensor_entry;
  for (const auto& slice_tag_pair : details) {
    // Seeks for the stored slice.
    const TensorSlice& stored_slice = slice_tag_pair.first;

    // We already have the entry for the full tensor, so don't query again if
    // the slice is full.
    if (!stored_slice.IsFull()) {
      const string encoded_stored_slice_name =
          checkpoint::EncodeTensorNameSlice(full_tensor_key_string,
                                            stored_slice);
      status_ =
          GetBundleEntryProto(encoded_stored_slice_name, &stored_slice_entry);
      if (!status_.ok()) return status_;
    }

    // TODO(zongheng): should we take an OpKernelContext, so that we can call
    // allocate_temp()? Note that without major refactorings to Saver, it's
    // hard for the caller of the tensor bundle module to allocate these
    // precisely-shaped scratch storage.

    // Optimization for the common case: the stored slice can be directly
    // copied to the destination without additional slicing. This is true when
    // either the slices are equal or when they are both full slices having the
    // same shape.
    TensorShape stored_slice_shape(stored_slice_entry.shape());
    if (stored_slice == slice_spec ||
        (stored_slice_shape == val->shape() &&
         IsFullSlice(stored_slice, stored_slice_shape) &&
         IsFullSlice(slice_spec, stored_slice_shape))) {
      VLOG(1) << "Optimized for common case: directly copying into "
                 "pre-allocated buffer; spec: "
              << slice_spec.DebugString();
      status_ = GetValue(stored_slice_entry, val);
      return status_;
    }

    Tensor stored_slice_tensor(stored_slice_entry.dtype(), stored_slice_shape);
    status_ = GetValue(stored_slice_entry, &stored_slice_tensor);
    if (!status_.ok()) return status_;

    // Copies the intersection over.
    const DataType common_dtype = full_tensor_entry.dtype();
    switch (common_dtype) {
#define HANDLE_COPY(T)                                                 \
  case DataTypeToEnum<T>::value:                                       \
    CHECK(CopyDataFromTensorSliceToTensorSlice(                        \
        full_shape, stored_slice, slice_spec,                          \
        stored_slice_tensor.flat<T>().data(), val->flat<T>().data())); \
    break;

      HANDLE_COPY(float)
      HANDLE_COPY(double)
      HANDLE_COPY(int32)
      HANDLE_COPY(uint8)
      HANDLE_COPY(int16)
      HANDLE_COPY(int8)
      HANDLE_COPY(complex64)
      HANDLE_COPY(complex128)
      HANDLE_COPY(int64)
      HANDLE_COPY(bool)
      HANDLE_COPY(qint32)
      HANDLE_COPY(quint8)
      HANDLE_COPY(qint8)
      default:
        return errors::InvalidArgument("Dtype ", DataTypeString(common_dtype),
                                       " not supported.");
    }
#undef HANDLE_COPY
  }
  return Status::OK();
}

bool BundleReader::Contains(StringPiece key) {
  Seek(key);
  return Valid() && (this->key() == key);
}

Status BundleReader::LookupDtypeAndShape(StringPiece key, DataType* dtype,
                                         TensorShape* shape) {
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));
  *dtype = entry.dtype();
  *shape = TensorShape(entry.shape());
  return Status::OK();
}

Status BundleReader::LookupTensorShape(StringPiece key, TensorShape* shape) {
  DataType ignored;
  return LookupDtypeAndShape(key, &ignored, shape);
}

string BundleReader::DebugString() {
  // Format used below emulates that of TensorSliceReader::DebugString().
  string shape_str;
  BundleEntryProto entry;
  Seek(kHeaderEntryKey);
  for (Next(); Valid(); Next()) {
    CHECK(entry.ParseFromArray(value().data(), value().size()));
    if (entry.slices_size() > 0) continue;  // Slice of some partitioned var.

    strings::StrAppend(&shape_str, key(), " (",
                       EnumName_DataType(entry.dtype()), ") ",
                       TensorShape(entry.shape()).DebugString());
    strings::StrAppend(&shape_str, "\n");
  }
  return shape_str;
}

FileOutputBuffer::~FileOutputBuffer() { delete file_; }

Status FileOutputBuffer::Append(StringPiece data) {
  // In the below, it is critical to calculate the checksum on the actually
  // copied bytes, not the source bytes. This is because "data" typically
  // points to tensor buffers, which may be concurrently written.
  if (data.size() + position_ <= buffer_size_) {
    // Can fit into the current buffer.
    memcpy(&buffer_[position_], data.data(), data.size());
    crc32c_ = crc32c::Extend(crc32c_, &buffer_[position_], data.size());
  } else if (data.size() <= buffer_size_) {
    // Cannot fit, but can fit after flushing.
    TF_RETURN_IF_ERROR(FlushBuffer());
    memcpy(&buffer_[0], data.data(), data.size());
    crc32c_ = crc32c::Extend(crc32c_, &buffer_[0], data.size());
  } else {
    // Cannot fit even after flushing. So we break down "data" by chunk, and
    // flush/checksum each chunk.
    TF_RETURN_IF_ERROR(FlushBuffer());
    for (size_t i = 0; i < data.size(); i += buffer_size_) {
      const size_t nbytes = std::min(data.size() - i, buffer_size_);
      memcpy(&buffer_[0], data.data() + i, nbytes);
      crc32c_ = crc32c::Extend(crc32c_, &buffer_[0], nbytes);
      position_ = nbytes;
      TF_RETURN_IF_ERROR(FlushBuffer());
    }
    return Status::OK();
  }
  position_ += data.size();
  return Status::OK();
}
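
// Note on the chunked path above: a payload of 2.5x the buffer size is
// written as three flushes (1.0x, 1.0x, 0.5x), and in every branch the
// crc32c is extended over the bytes already copied into buffer_, never over
// the caller's (possibly concurrently mutated) source.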

Status FileOutputBuffer::Close() {
  TF_RETURN_IF_ERROR(FlushBuffer());
  return file_->Close();
}

Status FileOutputBuffer::FlushBuffer() {
  if (position_ > 0) {
    TF_RETURN_IF_ERROR(file_->Append(StringPiece(&buffer_[0], position_)));
    position_ = 0;
  }
  return Status::OK();
}

}  // namespace tensorflow