/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"

#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <utility>

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/framework/versions.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/crc32c.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/io/table_builder.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/bfloat16.h"
#include "tensorflow/core/platform/cord.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
#include "tensorflow/core/util/tensor_bundle/byte_swap.h"
#include "tensorflow/core/util/tensor_slice_util.h"

#ifdef PLATFORM_WINDOWS
#undef DeleteFile
#endif

namespace tensorflow {

// Versioning of the tensor bundle format.
const int kTensorBundleMinProducer = 0;
const int kTensorBundleMinConsumer = 0;
const int kTensorBundleVersion = 1;

// Size of our input buffer for streaming reads
static const int kBufferSize = 1024 * 1024;

// Key to the special BundleHeaderProto entry. Do not change this, as clients
// can make the assumption that the header is always the first entry in the
// bundle.
const char* const kHeaderEntryKey = "";
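
// For orientation: a bundle "<prefix>" on disk normally consists of a
// metadata table plus one or more data shards, whose exact filenames come
// from the MetaFilename()/DataFilename() helpers used below (conventionally
// "<prefix>.index" and "<prefix>.data-NNNNN-of-MMMMM"). This is a sketch of
// the layout for the reader's benefit, not a normative spec.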

namespace {

// Reads "num_elements" string elements from file[offset, offset+size) into the
// length-N "destination". Discards the original content of "destination".
//
// Checksums the string lengths (as restored uint32 or uint64, not varint64
// bytes) and string bytes, and stores the result into "actual_crc32c".
Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements,
                        size_t offset, size_t size, tstring* destination,
                        uint32* actual_crc32c, bool need_to_swap_bytes) {
  if (size == 0) return Status::OK();
  CHECK_GT(size, 0);

  // Reads "num_elements" varint64's from "buffered_file".
  TF_RETURN_IF_ERROR(buffered_file->Seek(offset));
  TF_RETURN_IF_ERROR(buffered_file->Hint(size));
  std::vector<uint64> string_lengths(num_elements);
  for (size_t i = 0; i < num_elements; ++i) {
    TF_RETURN_IF_ERROR(buffered_file->ReadVarint64(&string_lengths[i]));
    if (string_lengths[i] <= UINT32_MAX) {
      // We need to do this because older checkpoints only used uint32s and we
      // should still support them.
      uint32 elem_size_uint32 = static_cast<uint32>(string_lengths[i]);
      if (need_to_swap_bytes) {
        // Checksum would have been computed on the source machine's byte order
        elem_size_uint32 = BYTE_SWAP_32(elem_size_uint32);
      }
      *actual_crc32c = crc32c::Extend(
          *actual_crc32c, reinterpret_cast<const char*>(&elem_size_uint32),
          sizeof(uint32));
    } else {
      uint64 length = string_lengths[i];
      if (need_to_swap_bytes) {
        length = BYTE_SWAP_64(length);
      }
      *actual_crc32c =
          crc32c::Extend(*actual_crc32c,
                         reinterpret_cast<const char*>(&length),
                         sizeof(uint64));
    }
  }
  if (offset + size < buffered_file->Tell()) {
    return errors::DataLoss("String lengths longer than expected offset ",
                            offset + size);
  }

  // Reads the length-checksum.
  uint32 raw_length_checksum = 0;  // Bytes in file
  uint32 length_checksum = 0;      // In-memory representation
  size_t unused_bytes_read = 0;
  TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(
      sizeof(uint32), reinterpret_cast<char*>(&raw_length_checksum),
      &unused_bytes_read));
  length_checksum = need_to_swap_bytes ? BYTE_SWAP_32(raw_length_checksum)
                                       : raw_length_checksum;
  if (crc32c::Unmask(length_checksum) != *actual_crc32c) {
    return errors::DataLoss(
        "The length checksum does not match: expected ",
        strings::Printf("%08u", crc32c::Unmask(length_checksum)),
        " but actual is ", strings::Printf("%08u", *actual_crc32c));
  }
  *actual_crc32c = crc32c::Extend(*actual_crc32c,
                                  reinterpret_cast<char*>(&raw_length_checksum),
                                  sizeof(uint32));

  // Reads the actual string bytes.
  for (size_t i = 0; i < num_elements; ++i) {
    const uint64 string_length = string_lengths[i];
    tstring* buffer = &destination[i];

    buffer->resize(string_length);
    size_t bytes_read = 0;
    TF_RETURN_IF_ERROR(
        buffered_file->ReadNBytes(string_length, &(*buffer)[0], &bytes_read));
    *actual_crc32c = crc32c::Extend(*actual_crc32c, buffer->data(), bytes_read);
  }
  return Status::OK();
}

Status ReadVariantTensor(io::InputBuffer* buffered_file, Tensor* ret,
                         size_t offset, size_t size, uint32* actual_crc32c) {
  // On-disk format:
  //   [varint64 len1][bytes variant1][4 byte checksum]
  //   ..
  //   [varint64 lenN][bytes variantN][4 byte checksum]
  // Var "crc32c" checksums all the lens, variant bytes, individual variant
  // checksums (as uint32, not varint32 bytes).
  if (size == 0) return Status::OK();
  size_t num_elements = ret->NumElements();

  // Reads the actual string bytes.
  TF_RETURN_IF_ERROR(buffered_file->Seek(offset));
  TF_RETURN_IF_ERROR(buffered_file->Hint(size));
  for (size_t i = 0; i < num_elements; ++i) {
    // Read the serialized variant length.
    uint64 string_length = 0;
    TF_RETURN_IF_ERROR(buffered_file->ReadVarint64(&string_length));
    *actual_crc32c = crc32c::Extend(
        *actual_crc32c, reinterpret_cast<const char*>(&string_length),
        sizeof(uint64));
    // Read the actual serialized variant.
    string buffer;
    buffer.resize(string_length);
    size_t bytes_read = 0;
    TF_RETURN_IF_ERROR(
        buffered_file->ReadNBytes(string_length, &buffer[0], &bytes_read));
    *actual_crc32c = crc32c::Extend(*actual_crc32c, buffer.data(), bytes_read);
    VariantTensorDataProto proto;
    if (!proto.ParseFromString(buffer)) {
      return errors::DataLoss("Unable to parse VariantTensorDataProto from ",
                              "buffer of size ", string_length, ". ",
                              "Bundle entry offset: ", offset, " size: ", size);
    }
    Variant v = proto;
    if (!DecodeUnaryVariant(&v)) {
      return errors::Internal("Could not decode variant with type_name: \"",
                              v.TypeName(), "\". Perhaps you forgot to ",
                              "register a decoder via ",
                              "REGISTER_UNARY_VARIANT_DECODE_FUNCTION?");
    }

    // Read the checksum.
    uint32 checksum = 0;
    size_t unused_bytes_read = 0;
    TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(
        sizeof(uint32), reinterpret_cast<char*>(&checksum),
        &unused_bytes_read));
    if (crc32c::Unmask(checksum) != *actual_crc32c) {
      return errors::DataLoss(
          "The checksum after Variant ", i, " does not match.",
          " Expected: ", strings::Printf("%08u", crc32c::Unmask(checksum)),
          " Actual: ", strings::Printf("%08u", *actual_crc32c));
    }
    *actual_crc32c = crc32c::Extend(
        *actual_crc32c, reinterpret_cast<char*>(&checksum), sizeof(uint32));

    ret->flat<Variant>()(i) = std::move(v);
  }

  return Status::OK();
}

char* GetBackingBuffer(const Tensor& val) {
  CHECK(DataTypeCanUseMemcpy(val.dtype())) << val.dtype();
  return const_cast<char*>(val.tensor_data().data());
}

tstring* GetStringBackingBuffer(const Tensor& val) {
  CHECK_EQ(DT_STRING, val.dtype());
  return const_cast<tstring*>(val.flat<tstring>().data());
}

Status ParseEntryProto(StringPiece key, StringPiece value,
                       protobuf::MessageLite* out) {
  if (!out->ParseFromArray(value.data(), value.size())) {
    return errors::DataLoss("Entry for key ", key, " not parseable.");
  }
  return Status::OK();
}

// Serializes the data bytes of the non-string tensor "val". Discards the
// original content of "bytes_written", and on OK updates it with the number
// of bytes written.
// REQUIRES: val.dtype() != DT_STRING
Status WriteTensor(const Tensor& val, FileOutputBuffer* out,
                   size_t* bytes_written) {
  DCHECK_NE(val.dtype(), DT_STRING);
  DCHECK_NE(val.dtype(), DT_VARIANT);
  *bytes_written = val.TotalBytes();
  char* buf = GetBackingBuffer(val);
  VLOG(1) << "Appending " << *bytes_written << " bytes to file";
  return out->Append(StringPiece(buf, *bytes_written));
}

// Serializes string tensor "val". "bytes_written" is treated in the same
// fashion as WriteTensor().
//
// Checksums all bytes written and stores the result into "crc32c".
// REQUIRES: val.dtype() == DT_STRING
Status WriteStringTensor(const Tensor& val, FileOutputBuffer* out,
                         size_t* bytes_written, uint32* crc32c) {
  // On-disk format:
  //   [varint64 len0]..[varint64 lenL][4 byte cksum on lengths][string bytes]
  // Var "crc32c" checksums the string lengths (as uint64, not varint64 bytes),
  // the length-checksum, and all the string bytes.
  DCHECK_EQ(val.dtype(), DT_STRING);
  const tstring* strings = GetStringBackingBuffer(val);

  // Writes the varint lengths.
  string lengths;
  lengths.reserve(val.NumElements());  // At least 1 byte per element.
  *crc32c = 0;
  for (int64_t i = 0; i < val.NumElements(); ++i) {
    const tstring* elem = &strings[i];
    DCHECK_EQ(elem->size(), static_cast<uint64>(elem->size()));
    const uint64 elem_size = static_cast<uint64>(elem->size());

    core::PutVarint64(&lengths, elem_size);
    if (elem_size <= UINT32_MAX) {
      // We need to do this because older checkpoints only used uint32s and we
      // should still support them.
      const uint32 elem_size_uint32 = static_cast<uint32>(elem_size);
      *crc32c = crc32c::Extend(*crc32c,
                               reinterpret_cast<const char*>(&elem_size_uint32),
                               sizeof(uint32));
    } else {
      *crc32c = crc32c::Extend(
          *crc32c, reinterpret_cast<const char*>(&elem_size), sizeof(uint64));
    }
  }
  TF_RETURN_IF_ERROR(out->Append(lengths));
  *bytes_written = lengths.size();

  // Writes the length checksum.
  const uint32 length_checksum = crc32c::Mask(*crc32c);
  TF_RETURN_IF_ERROR(out->Append(StringPiece(
      reinterpret_cast<const char*>(&length_checksum), sizeof(uint32))));
  *crc32c = crc32c::Extend(
      *crc32c, reinterpret_cast<const char*>(&length_checksum), sizeof(uint32));
  *bytes_written += sizeof(uint32);

  // Writes all the string bytes out.
  for (int64_t i = 0; i < val.NumElements(); ++i) {
    const tstring* string = &strings[i];
    TF_RETURN_IF_ERROR(out->Append(*string));
    *bytes_written += string->size();
    *crc32c = crc32c::Extend(*crc32c, string->data(), string->size());
  }
  return Status::OK();
}
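
// To make the on-disk layout above concrete: for a DT_STRING tensor holding
// {"ab", "c"}, WriteStringTensor() emits, in order (a sketch; the actual
// masked checksum value depends on the bytes and is not reproduced here):
//   0x02 0x01         varint64 lengths of "ab" and "c"
//   <4-byte cksum>    crc32c::Mask() over the two lengths (as uint32s here,
//                     since both fit in 32 bits)
//   'a' 'b' 'c'       the concatenated string bytes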

Status WriteVariantTensor(const Tensor& val, FileOutputBuffer* out,
                          size_t* bytes_written, uint32* crc32c) {
  // On-disk format:
  //   [varint64 len1][bytes variant1][4 byte checksum]
  //   ..
  //   [varint64 lenN][bytes variantN][4 byte checksum]
  // Var "crc32c" checksums all the lens, variant bytes, individual variant
  // checksums (as uint32, not varint32 bytes).
  DCHECK_EQ(val.dtype(), DT_VARIANT);

  *crc32c = 0;
  *bytes_written = 0;
  for (int64_t i = 0; i < val.NumElements(); ++i) {
    VariantTensorData data;
    val.flat<Variant>()(i).Encode(&data);
    VariantTensorDataProto proto;
    data.ToProto(&proto);
    string elem;
    if (!proto.SerializeToString(&elem)) {
      return errors::Unknown(
          "Failed to serialize tensor data of size ", proto.ByteSizeLong(),
          ". Tensor: ", val.flat<Variant>()(i).DebugString());
    }

    // Write the length of the serialized variant.
    DCHECK_EQ(elem.size(), static_cast<uint64>(elem.size()));
    const auto elem_size = static_cast<uint64>(elem.size());
    string len;
    core::PutVarint64(&len, elem_size);
    TF_RETURN_IF_ERROR(out->Append(len));
    *crc32c = crc32c::Extend(*crc32c, reinterpret_cast<const char*>(&elem_size),
                             sizeof(uint64));
    *bytes_written += len.size();

    // Write the serialized variant.
    TF_RETURN_IF_ERROR(out->Append(elem));
    *crc32c = crc32c::Extend(*crc32c, elem.data(), elem.size());
    *bytes_written += elem.size();

    // Write the checksum.
    const uint32 length_checksum = crc32c::Mask(*crc32c);
    TF_RETURN_IF_ERROR(out->Append(StringPiece(
        reinterpret_cast<const char*>(&length_checksum), sizeof(uint32))));
    *crc32c =
        crc32c::Extend(*crc32c, reinterpret_cast<const char*>(&length_checksum),
                       sizeof(uint32));
    *bytes_written += sizeof(uint32);
  }

  return Status::OK();
}

// Returns whether "slice_spec" is a full slice, with respect to the full shape.
//
// This can happen, say, when "slice_spec" is
// "TensorSlice(full_tensor_shape.dims())", or when it is "TensorSlice({{0,
// dim(0)}, ..., {0, dim(N)}})" -- a degenerate case we need to guard against.
bool IsFullSlice(const TensorSlice& slice_spec,
                 const TensorShape& full_tensor_shape) {
  if (slice_spec.IsFull()) {
    return true;
  } else {
    TensorShape sliced_shape;
    slice_spec.SliceTensorShape(full_tensor_shape, &sliced_shape).IgnoreError();
    return sliced_shape == full_tensor_shape;
  }
}

Status CorruptFileError(const Status& in_status, const string& filename,
                        const string& detail) {
  if (in_status.ok()) {
    return errors::Internal("Unable to read file (", filename,
                            "). Perhaps the file is corrupt or was produced by "
                            "a newer version of TensorFlow with format changes "
                            "(",
                            detail, ")");
  }
  return Status(
      in_status.code(),
      strings::StrCat("Unable to read file (", filename,
                      "). Perhaps the file is corrupt or was produced by a "
                      "newer version of TensorFlow with format changes (",
                      detail, "): ", in_status.error_message()));
}

table::Options TableBuilderOptions() {
  table::Options o;
  // Compressed tables cannot be read by TensorFlow releases prior to 1.1.
  // To smooth the transition, compressed writes are disabled for now
  // (version 1.2) with the intention that they will be enabled again at
  // some point (perhaps the 1.3 release?).
  o.compression = table::kNoCompression;
  return o;
}

// Writes zeros to output buffer to align the next write to the requested
// alignment. "size" is the current size of the buffer and is updated to the
// new size.
Status PadAlignment(FileOutputBuffer* out, int alignment, int64* size) {
  int bytes_over = *size % alignment;
  if (bytes_over == 0) {
    return Status::OK();
  }
  int bytes_to_write = alignment - bytes_over;
  Status status = out->Append(string(bytes_to_write, '\0'));
  if (status.ok()) {
    *size += bytes_to_write;
  }
  return status;
}
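
// Worked example of the padding arithmetic above: with *size == 10 and
// alignment == 8, bytes_over is 2, so 6 zero bytes are appended and *size
// becomes 16, leaving the next write 8-byte aligned.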

}  // namespace

BundleWriter::BundleWriter(Env* env, StringPiece prefix, const Options& options)
    : env_(env), options_(options), prefix_(prefix), out_(nullptr), size_(0) {
  status_ = env_->HasAtomicMove(prefix_, &use_temp_file_);
  if (!status_.ok()) return;

  data_path_ = DataFilename(prefix_, 0, 1);
  metadata_path_ = MetaFilename(prefix_);
  if (use_temp_file_) {
    data_path_ = strings::StrCat(data_path_, ".tempstate", random::New64());
    metadata_path_ =
        strings::StrCat(metadata_path_, ".tempstate", random::New64());
  }

  status_ = env_->CreateDir(string(io::Dirname(prefix_)));
  if (!status_.ok() && !errors::IsAlreadyExists(status_)) {
    return;
  }

  std::unique_ptr<WritableFile> wrapper;
  status_ = env_->NewWritableFile(data_path_, &wrapper);
  if (!status_.ok()) return;
  out_ = std::unique_ptr<FileOutputBuffer>(
      new FileOutputBuffer(wrapper.release(), 8 << 20 /* 8MB write buffer */));

  VLOG(1) << "Writing to file " << data_path_;
}

Status BundleWriter::Add(StringPiece key, const Tensor& val) {
  if (!status_.ok()) return status_;
  CHECK_NE(key, kHeaderEntryKey);
  const string key_string(key);
  if (entries_.find(key_string) != entries_.end()) {
    status_ = errors::InvalidArgument("Adding duplicate key: ", key);
    return status_;
  }

  BundleEntryProto* entry = &entries_[key_string];
  entry->set_dtype(val.dtype());
  val.shape().AsProto(entry->mutable_shape());
  entry->set_shard_id(0);
  entry->set_offset(size_);

  // Updates the data file.
  size_t data_bytes_written = 0;
  uint32 crc32c = 0;
  out_->clear_crc32c();
  if (val.dtype() == DT_STRING) {
    status_ = WriteStringTensor(val, out_.get(), &data_bytes_written, &crc32c);
  } else if (val.dtype() == DT_VARIANT) {
    status_ = WriteVariantTensor(val, out_.get(), &data_bytes_written, &crc32c);
  } else {
    status_ = WriteTensor(val, out_.get(), &data_bytes_written);
    crc32c = out_->crc32c();
  }

  if (status_.ok()) {
    entry->set_size(data_bytes_written);
    entry->set_crc32c(crc32c::Mask(crc32c));
    size_ += data_bytes_written;
    status_ = PadAlignment(out_.get(), options_.data_alignment, &size_);
  }
  return status_;
}

Status BundleWriter::AddSlice(StringPiece full_tensor_key,
                              const TensorShape& full_tensor_shape,
                              const TensorSlice& slice_spec,
                              const Tensor& slice_tensor) {
  if (!status_.ok()) return status_;
  CHECK_NE(full_tensor_key, kHeaderEntryKey);

  // If just a singleton full slice, use the regular Add() to be more efficient.
  if (IsFullSlice(slice_spec, full_tensor_shape)) {
    return Add(full_tensor_key, slice_tensor);
  }

  // Inserts/updates the full tensor's metadata entry.
  //
  // In the case of a sharded save, MergeBundles() is responsible for merging
  // the "slices" field of multiple metadata entries corresponding to the same
  // full tensor.
  const string full_tensor_key_string(full_tensor_key);
  BundleEntryProto* full_entry = &entries_[full_tensor_key_string];
  if (full_entry->dtype() != DT_INVALID) {
    CHECK_EQ(full_entry->dtype(), slice_tensor.dtype());
  }
  if (full_entry->has_shape()) {
    CHECK(TensorShape(full_entry->shape()) == full_tensor_shape);
  }

  // Populates dtype, shape, and slices. Intentionally leaving out shard_id and
  // offset, which do not make sense for this full tensor entry.
  full_entry->set_dtype(slice_tensor.dtype());
  full_tensor_shape.AsProto(full_entry->mutable_shape());
  TensorSliceProto* slice_proto = full_entry->add_slices();
  slice_spec.AsProto(slice_proto);

  // The slice itself is handled by a regular Add(), which includes adding its
  // own metadata entry, and writing out the slice's values.
  const string slice_name =
      checkpoint::EncodeTensorNameSlice(full_tensor_key_string, slice_spec);
  status_ = Add(slice_name, slice_tensor);
  return status_;
}

// TODO(zongheng): on metadata write failure or !status_.ok(), consider removing
// the orphaned data file.
Status BundleWriter::Finish() {
  if (out_) {
    status_.Update(out_->Close());
    out_ = nullptr;
    if (status_.ok()) {
      if (use_temp_file_) {
        status_ =
            Env::Default()->RenameFile(data_path_, DataFilename(prefix_, 0, 1));
      }
    } else {
      Env::Default()->DeleteFile(data_path_).IgnoreError();
    }
  }
  if (!status_.ok()) return status_;
  // Build key -> BundleEntryProto table.
  std::unique_ptr<WritableFile> file;
  status_ = env_->NewWritableFile(metadata_path_, &file);
  if (!status_.ok()) return status_;
  {
    // N.B.: the default use of Snappy compression may not be supported on all
    // platforms (e.g. Android). The metadata file is small, so this is fine.
    table::Options options;
    options.compression = table::kNoCompression;
    table::TableBuilder builder(options, file.get());
    // Header entry.
    BundleHeaderProto header;
    header.set_num_shards(1);
    header.set_endianness(BundleHeaderProto::LITTLE);
    if (!port::kLittleEndian) header.set_endianness(BundleHeaderProto::BIG);
    VersionDef* version = header.mutable_version();
    version->set_producer(kTensorBundleVersion);
    version->set_min_consumer(kTensorBundleMinConsumer);

    builder.Add(kHeaderEntryKey, header.SerializeAsString());

    // All others.
    for (const auto& p : entries_) {
      builder.Add(p.first, p.second.SerializeAsString());
    }
    status_ = builder.Finish();
  }
  status_.Update(file->Close());
  if (!status_.ok()) {
    Env::Default()->DeleteFile(metadata_path_).IgnoreError();
    return status_;
  } else if (use_temp_file_) {
    status_ = Env::Default()->RenameFile(metadata_path_, MetaFilename(prefix_));
    if (!status_.ok()) return status_;
  }
  status_ = errors::Internal("BundleWriter is closed");
  return Status::OK();
}
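
// Typical write path, as a minimal sketch (names and paths are illustrative;
// error handling beyond the TF_CHECK_OKs is elided):
//
//   BundleWriter writer(Env::Default(), "/tmp/ckpt/model");
//   TF_CHECK_OK(writer.status());
//   TF_CHECK_OK(writer.Add("var0", some_tensor));
//   TF_CHECK_OK(writer.Finish());  // The writer is unusable afterwards.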

// Merging tensor bundles.

// Accumulator of metadata states during a merge.
struct MergeState {
  // Accumulated from the header entries.
  int num_shards = 0;

  // Derives "endianness" and "version" from the first bundle merged (hence the
  // "seen_first_bundle" guard). The two fields must be the same for all
  // bundles in a merge.
  bool seen_first_bundle = false;
  BundleHeaderProto_Endianness endianness;
  VersionDef version;

  // Tensor key -> BundleEntryProto.
  std::map<string, BundleEntryProto> entries;
  // Data file path -> new shard id in the final merged bundle.
  std::unordered_map<string, int32> shard_ids;
};

// Merges entries of "prefix" into the accumulator state "merge".
// Returns OK iff the merge succeeds.
static Status MergeOneBundle(Env* env, StringPiece prefix,
                             MergeState* merge_state) {
  VLOG(1) << "Merging bundle:" << prefix;
  const string filename = MetaFilename(prefix);
  uint64 file_size;
  TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));
  std::unique_ptr<RandomAccessFile> file;
  TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));

  table::Table* table = nullptr;
  TF_RETURN_IF_ERROR(
      table::Table::Open(TableBuilderOptions(), file.get(), file_size, &table));
  std::unique_ptr<table::Table> table_deleter(table);
  std::unique_ptr<table::Iterator> iter(table->NewIterator());

  int num_shards;
  // Process header.
  {
    iter->Seek(kHeaderEntryKey);
    if (!iter->Valid()) {
      return CorruptFileError(iter->status(), filename,
                              "failed to seek to header entry");
    }
    BundleHeaderProto header;
    Status s = ParseEntryProto(iter->key(), iter->value(), &header);
    if (!s.ok()) return CorruptFileError(s, filename, "unable to parse header");

    merge_state->num_shards += header.num_shards();
    if (!merge_state->seen_first_bundle) {
      merge_state->seen_first_bundle = true;
      merge_state->endianness = header.endianness();
      merge_state->version = header.version();
    } else {
      // Validates "endianness".
      if (merge_state->endianness != header.endianness()) {
        return errors::InvalidArgument(
            "Merging bundles with conflicting endianness; inputs corrupted?");
      }
      // Validates "version".
      string curr_version, merge_version;
      header.version().SerializeToString(&curr_version);
      merge_state->version.SerializeToString(&merge_version);
      if (curr_version != merge_version) {
        return errors::InvalidArgument(
            "Merging bundles with different format versions: merged ",
            merge_version, " vs. curr ", curr_version);
      }
    }
    num_shards = header.num_shards();
    iter->Next();
  }

  // Loops through the non-header to-merge entries.
  BundleEntryProto to_merge_entry;
  for (; iter->Valid(); iter->Next()) {
    const string key(iter->key());
    const auto entry_iter = merge_state->entries.find(key);

    // Illegal: the duplicated entry is a non-slice tensor.
    if (entry_iter != merge_state->entries.end() &&
        entry_iter->second.slices().empty()) {
      return errors::InvalidArgument(
          "Duplicate tensor keyed by ", key,
          " encountered, when merging prefix: ", prefix);
    }

    TF_RETURN_IF_ERROR(
        ParseEntryProto(iter->key(), iter->value(), &to_merge_entry));

    // The duplicated entry holds metadata for a sliced full tensor.
    // Allows the duplication and merges "slices".
    if (entry_iter != merge_state->entries.end()) {
      BundleEntryProto& existing_entry = entry_iter->second;
      if (to_merge_entry.slices().empty()) {
        return errors::Internal(
            "Duplicate tensor keyed by ", key,
            "; attempting to merge in a non-slice bundle entry");
      }
      // Only needs to merge the "slices" field (and validate dtype/shape).
      for (int i = 0; i < to_merge_entry.slices_size(); ++i) {
        TensorSliceProto* slot = existing_entry.add_slices();
        *slot = to_merge_entry.slices(i);
      }
      CHECK_EQ(existing_entry.dtype(), to_merge_entry.dtype());
      CHECK(TensorShape(existing_entry.shape()) ==
            TensorShape(to_merge_entry.shape()));
      continue;
    }

    // Key doesn't duplicate: a fresh tensor/slice entry.
    auto result = merge_state->shard_ids.insert(
        {DataFilename(prefix, to_merge_entry.shard_id(), num_shards),
         merge_state->shard_ids.size()});
    to_merge_entry.set_shard_id(result.first->second);
    merge_state->entries[key] = to_merge_entry;
  }
  return Status::OK();
}

Status MergeBundles(Env* env, gtl::ArraySlice<tstring> prefixes,
                    StringPiece merged_prefix) {
  // Merges all metadata tables.
  // TODO(zhifengc): KeyValue sorter if it becomes too big.
  MergeState merge;
  Status status = env->CreateDir(string(io::Dirname(merged_prefix)));
  if (!status.ok() && !errors::IsAlreadyExists(status)) return status;
  for (int i = 0; i < prefixes.size(); ++i) {
    TF_RETURN_IF_ERROR(MergeOneBundle(env, prefixes[i], &merge));
  }

  // Renames data files to contain the merged bundle prefix.
  for (const auto& p : merge.shard_ids) {
    VLOG(1) << "Renaming " << p.first << " to "
            << DataFilename(merged_prefix, p.second, merge.shard_ids.size());
    TF_RETURN_IF_ERROR(env->RenameFile(
        p.first,
        DataFilename(merged_prefix, p.second, merge.shard_ids.size())));
  }

  // Writes the final metadata table under the merged prefix.
  std::unique_ptr<WritableFile> merged_metadata;
  TF_RETURN_IF_ERROR(
      env->NewWritableFile(MetaFilename(merged_prefix), &merged_metadata));
  {
    table::TableBuilder builder(TableBuilderOptions(), merged_metadata.get());
    // Header entry.
    BundleHeaderProto header;
    header.set_num_shards(merge.num_shards);
    header.set_endianness(merge.endianness);
    *header.mutable_version() = merge.version;
    builder.Add(kHeaderEntryKey, header.SerializeAsString());
    // All others.
    for (const auto& p : merge.entries) {
      builder.Add(p.first, p.second.SerializeAsString());
    }
    status = builder.Finish();
  }
  status.Update(merged_metadata->Close());
  if (!status.ok()) return status;
  VLOG(1) << "Merged bundles to:" << merged_prefix;

  // Cleanup: best-effort; ignores errors.
  for (const tstring& prefix : prefixes) {
    env->DeleteFile(MetaFilename(prefix)).IgnoreError();
  }
  return status;
}
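
// Sketch of a sharded-save merge (illustrative prefixes; each worker would
// first have produced its own bundle with BundleWriter):
//
//   TF_CHECK_OK(MergeBundles(Env::Default(),
//                            {"/tmp/ckpt/worker_0", "/tmp/ckpt/worker_1"},
//                            "/tmp/ckpt/merged"));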

// Interface for reading a tensor bundle.

BundleReader::BundleReader(Env* env, StringPiece prefix)
    : env_(env),
      prefix_(prefix),
      metadata_(nullptr),
      table_(nullptr),
      index_cache_(nullptr),
      iter_(nullptr),
      need_to_swap_bytes_(false) {
  const string filename = MetaFilename(prefix_);
  uint64 file_size;
  status_ = env_->GetFileSize(filename, &file_size);
  if (!status_.ok()) return;

  // Opens the metadata table.
  std::unique_ptr<RandomAccessFile> wrapper;
  status_ = env_->NewRandomAccessFile(filename, &wrapper);
  if (!status_.ok()) return;
  metadata_ = wrapper.release();

  table::Options o;
  int64_t cache_size;
  Status s =
      ReadInt64FromEnvVar("TF_TABLE_INDEX_CACHE_SIZE_IN_MB", 0, &cache_size);
  if (s.ok() && cache_size > 0) {
    index_cache_ = table::NewLRUCache(cache_size << 20);
    o.block_cache = index_cache_;
  }

  status_ = table::Table::Open(o, metadata_, file_size, &table_);
  if (!status_.ok()) return;
  iter_ = table_->NewIterator();

  // Reads "num_shards_" from the first entry.
  iter_->Seek(kHeaderEntryKey);
  if (!iter_->Valid()) {
    status_ = CorruptFileError(iter_->status(), filename,
                               "failed to seek to header entry");
    return;
  }
  BundleHeaderProto header;
  status_ = ParseEntryProto(iter_->key(), iter_->value(), &header);
  if (!status_.ok()) {
    status_ = CorruptFileError(status_, filename, "unable to parse header");
    return;
  }
  num_shards_ = header.num_shards();
  if ((header.endianness() == BundleHeaderProto::BIG && port::kLittleEndian) ||
      (header.endianness() == BundleHeaderProto::LITTLE &&
       !port::kLittleEndian)) {
    need_to_swap_bytes_ = true;
  }
  status_ = CheckVersions(header.version(), kTensorBundleVersion,
                          kTensorBundleMinProducer, "Checkpoint", "checkpoint");
}

BundleReader::~BundleReader() {
  delete metadata_;
  delete iter_;
  delete table_;
  if (index_cache_) {
    delete index_cache_;
  }
  // InputBuffer does not own the underlying RandomAccessFile.
  for (auto pair : data_) {
    if (pair.second != nullptr && pair.second->file() != nullptr) {
      delete pair.second->file();
    }
  }
  for (auto& temp : data_) {
    delete temp.second;
  }
  for (auto& temp : tensor_slices_) {
    delete temp.second;
  }
  data_.clear();
  tensor_slices_.clear();
}

Status BundleReader::GetBundleEntryProto(StringPiece key,
                                         BundleEntryProto* entry) {
  entry->Clear();
  TF_CHECK_OK(status_);
  Seek(key);
  if (!iter_->Valid() || iter_->key() != key) {
    return errors::NotFound("Key ", key, " not found in checkpoint");
  }

  BundleEntryProto entry_copy;
  TF_RETURN_IF_ERROR(
      ParseEntryProto(iter_->key(), iter_->value(), &entry_copy));
  if (!TensorShape::IsValid(entry_copy.shape())) {
    return errors::DataLoss("Invalid tensor shape: ", key, " ",
                            entry_copy.shape().ShortDebugString());
  }

  entry->Swap(&entry_copy);
  return Status::OK();
}

Status BundleReader::GetValue(const BundleEntryProto& entry, Tensor* val) {
  Tensor* ret = val;
  const TensorShape stored_shape(TensorShape(entry.shape()));
  if (val->NumElements() == 0) {
    ret = new Tensor(entry.dtype(), stored_shape);
  }

  // Validates the "size" field.
  if (entry.dtype() != DT_STRING && entry.dtype() != DT_VARIANT) {
    if (entry.size() != ret->TotalBytes()) {
      return errors::DataLoss("Invalid size in bundle entry: key ", key(),
                              "; stored size ", entry.size(),
                              "; expected size ", ret->TotalBytes());
    }
  } else if (entry.dtype() == DT_STRING) {
    // Relaxes the check for string tensors as follows:
    //   entry.size() == bytes(varint lengths) + bytes(data)
    //                >= NumElems + bytes(data), since size bytes(varint) >= 1.
    //   TotalBytes() == sizeof(tstring) * NumElems + bytes(data)
    // Since we don't know bytes(varint lengths), we just check an inequality.
    const size_t lower_bound = ret->NumElements() + ret->TotalBytes() -
                               sizeof(tstring) * ret->NumElements();
    if (entry.size() < lower_bound) {
      return errors::DataLoss("Invalid size in bundle entry: key ", key(),
                              "; stored size ", entry.size(),
                              "; expected size is at least ", lower_bound);
    }
  }

  // Open the data file if it has not been opened.
  io::InputBuffer* buffered_file = data_[entry.shard_id()];
  if (buffered_file == nullptr) {
    std::unique_ptr<RandomAccessFile> file = nullptr;
    TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(
        DataFilename(prefix_, entry.shard_id(), num_shards_), &file));
    buffered_file = new io::InputBuffer(file.release(), kBufferSize);
    // The InputBuffer and RandomAccessFile objects are both released in dtor.
    data_[entry.shard_id()] = buffered_file;
  }
  CHECK(buffered_file != nullptr);

  TF_RETURN_IF_ERROR(buffered_file->Seek(entry.offset()));
  uint32 actual_crc32c = 0;

  if (DataTypeCanUseMemcpy(entry.dtype())) {
    char* backing_buffer = const_cast<char*>((ret->tensor_data().data()));
    size_t unused_bytes_read;
    if (entry.size() > kBufferSize) {
      StringPiece sp;
      TF_RETURN_IF_ERROR(buffered_file->file()->Read(
          entry.offset(), entry.size(), &sp, backing_buffer));
      if (sp.data() != backing_buffer) {
        memmove(backing_buffer, sp.data(), entry.size());
      }
    } else {
      TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(entry.size(), backing_buffer,
                                                   &unused_bytes_read));
    }
    // Note that we compute the checksum *before* byte-swapping. The checksum
    // should be on the bytes in the order they appear in the file.
    actual_crc32c = crc32c::Value(backing_buffer, entry.size());
    if (need_to_swap_bytes_) {
      TF_RETURN_IF_ERROR(ByteSwapTensor(ret));
    }
  } else if (entry.dtype() == DT_VARIANT) {
    if (need_to_swap_bytes_) {
      return errors::Unimplemented(
          "TensorBundle at ", prefix_,
          " is of a different endianness than this machine's hardware, and "
          "the bundle contains a variant (arbitrary C++ type) tensor. "
          "Byte-swapping of variant tensors is not currently implemented.");
    }
    // Relies on io::InputBuffer's buffering, because we issue many neighboring
    // reads for a single string tensor.
    TF_RETURN_IF_ERROR(ReadVariantTensor(buffered_file, ret, entry.offset(),
                                         entry.size(), &actual_crc32c));
  } else {
    // Relies on io::InputBuffer's buffering, because we issue many neighboring
    // reads for a single string tensor.
    TF_RETURN_IF_ERROR(ReadStringTensor(
        buffered_file, ret->NumElements(), entry.offset(), entry.size(),
        GetStringBackingBuffer(*ret), &actual_crc32c, need_to_swap_bytes_));
  }
  if (crc32c::Unmask(entry.crc32c()) != actual_crc32c) {
    return errors::DataLoss(
        "TensorBundle at ", prefix_, " shard ", entry.shard_id(), " (",
        entry.size(), " bytes): Checksum does not match: stored ",
        strings::Printf("%08u", crc32c::Unmask(entry.crc32c())),
        " vs. calculated on the restored bytes ", actual_crc32c);
  }

  *val = *ret;
  if (ret != val) delete ret;
  return Status::OK();
}

Status BundleReader::Lookup(StringPiece key, Tensor* val) {
  CHECK(val != nullptr);
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));

  if (entry.slices().empty()) {
    return GetValue(entry, val);
  } else {
    return GetSliceValue(
        key, entry,
        /* a full slice */ TensorSlice(TensorShape(entry.shape()).dims()), val);
  }
}
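
// Typical read path, as a minimal sketch (the prefix is illustrative):
//
//   BundleReader reader(Env::Default(), "/tmp/ckpt/model");
//   TF_CHECK_OK(reader.status());
//   Tensor val;
//   TF_CHECK_OK(reader.Lookup("var0", &val));
//
// Passing a pre-allocated, correctly-shaped "val" lets GetValue() read
// directly into its buffer; an empty "val" makes the reader allocate one.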

Status BundleReader::ReadCurrent(Tensor* val) {
  CHECK(val != nullptr);
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(ParseEntryProto(iter_->key(), iter_->value(), &entry));
  if (!TensorShape::IsValid(entry.shape())) {
    return errors::DataLoss("Invalid tensor shape: ", iter_->key(), " ",
                            entry.shape().ShortDebugString());
  }

  if (entry.slices().empty()) {
    return GetValue(entry, val);
  } else {
    return GetSliceValue(
        iter_->key(), entry,
        /* a full slice */ TensorSlice(TensorShape(entry.shape()).dims()), val);
  }
}

Status BundleReader::LookupTensorSlices(StringPiece key,
                                        std::vector<TensorSlice>* slices) {
  slices->clear();
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));
  slices->reserve(entry.slices_size());
  for (const auto& slice : entry.slices()) {
    slices->emplace_back(slice);
  }
  return Status::OK();
}

Status BundleReader::LookupSlice(StringPiece full_tensor_key,
                                 const TensorSlice& slice_spec, Tensor* val) {
  CHECK(val != nullptr);
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(GetBundleEntryProto(full_tensor_key, &entry));
  return GetSliceValue(full_tensor_key, entry, slice_spec, val);
}

Status BundleReader::GetSliceValue(StringPiece full_tensor_key,
                                   const BundleEntryProto& full_tensor_entry,
                                   const TensorSlice& slice_spec, Tensor* val) {
  using checkpoint::RegisterTensorSlice;
  using checkpoint::TensorSliceSet;
  DCHECK_GE(full_tensor_entry.slices_size(), 0);

  const TensorShape full_shape(TensorShape(full_tensor_entry.shape()));
  std::vector<std::pair<TensorSlice, string>> details;
  const string full_tensor_key_string(full_tensor_key);
  const TensorSliceSet* tss =
      gtl::FindPtrOrNull(tensor_slices_, full_tensor_key_string);

  // Populates the "full tensor key -> TensorSliceSet" cache.
  if (tss == nullptr) {
    if (full_tensor_entry.slices().empty()) {
      // Special case: a writer has saved a tensor fully, but the reader wants
      // to read in slices. We therefore register the full slice on-demand here
      // without further complicating the on-disk bundle format.
      TF_RETURN_IF_ERROR(RegisterTensorSlice(
          full_tensor_key_string, full_shape, full_tensor_entry.dtype(),
          /* tag */ "",
          /* full slice */ TensorSlice(full_shape.dims()), &tensor_slices_));
    }
    for (const TensorSliceProto& slice : full_tensor_entry.slices()) {
      TF_RETURN_IF_ERROR(RegisterTensorSlice(
          full_tensor_key_string, full_shape, full_tensor_entry.dtype(),
          /* tag */ "", TensorSlice(slice), &tensor_slices_));
    }
    tss = gtl::FindPtrOrNull(tensor_slices_, full_tensor_key_string);
    CHECK_NE(tss, nullptr);
  }
  if (!tss->QueryMeta(slice_spec, &details)) {
    return errors::InvalidArgument(
        "Does not have sufficient slices for partitioned tensor ",
        full_tensor_key,
        " to restore in slice_spec: ", slice_spec.DebugString());
  }

  // The union of the slices in "details" covers "slice_spec". Performs the
  // copies from each.
  BundleEntryProto stored_slice_entry = full_tensor_entry;
  for (const auto& slice_tag_pair : details) {
    // Seeks for the stored slice.
    const TensorSlice& stored_slice = slice_tag_pair.first;

    // We already have the entry for the full tensor, so don't query again if
    // the slice is full.
    if (!stored_slice.IsFull()) {
      const string encoded_stored_slice_name =
          checkpoint::EncodeTensorNameSlice(full_tensor_key_string,
                                            stored_slice);
      status_ =
          GetBundleEntryProto(encoded_stored_slice_name, &stored_slice_entry);
      if (!status_.ok()) return status_;
    }

    // TODO(zongheng): should we take an OpKernelContext, so that we can call
    // allocate_temp()? Note that without major refactorings to Saver, it's
    // hard for the caller of the tensor bundle module to allocate these
    // precisely-shaped scratch storage.

    // Optimization for the common case: the stored slice can be directly
    // copied to the destination without additional slicing. This is true when
    // either the slices are equal or when they are both full slices having the
    // same shape.
    TensorShape stored_slice_shape(stored_slice_entry.shape());
    if (stored_slice == slice_spec ||
        (stored_slice_shape == val->shape() &&
         IsFullSlice(stored_slice, stored_slice_shape) &&
         IsFullSlice(slice_spec, stored_slice_shape))) {
      VLOG(1) << "Optimized for common case: directly copying into "
                 "pre-allocated buffer; spec: "
              << slice_spec.DebugString();
      status_ = GetValue(stored_slice_entry, val);
      return status_;
    }

    Tensor stored_slice_tensor(stored_slice_entry.dtype(), stored_slice_shape);
    status_ = GetValue(stored_slice_entry, &stored_slice_tensor);
    if (!status_.ok()) return status_;

    // Copies the intersection over.
    const DataType common_dtype = full_tensor_entry.dtype();
    switch (common_dtype) {
#define HANDLE_COPY(T)                                                 \
  case DataTypeToEnum<T>::value:                                       \
    CHECK(CopyDataFromTensorSliceToTensorSlice(                        \
        full_shape, stored_slice, slice_spec,                          \
        stored_slice_tensor.flat<T>().data(), val->flat<T>().data())); \
    break;

      HANDLE_COPY(float)
      HANDLE_COPY(double)
      HANDLE_COPY(int32)
      HANDLE_COPY(uint8)
      HANDLE_COPY(int16)
      HANDLE_COPY(int8)
      HANDLE_COPY(complex64)
      HANDLE_COPY(complex128)
      HANDLE_COPY(int64)
      HANDLE_COPY(bool)
      HANDLE_COPY(qint32)
      HANDLE_COPY(quint8)
      HANDLE_COPY(qint8)
      HANDLE_COPY(bfloat16)
      default:
        return errors::InvalidArgument("Dtype ", DataTypeString(common_dtype),
                                       " not supported.");
    }
#undef HANDLE_COPY
  }
  return Status::OK();
}

bool BundleReader::Contains(StringPiece key) {
  Seek(key);
  return Valid() && (this->key() == key);
}

Status BundleReader::LookupDtypeAndShape(StringPiece key, DataType* dtype,
                                         TensorShape* shape) {
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));
  *dtype = entry.dtype();
  *shape = TensorShape(entry.shape());
  return Status::OK();
}

Status BundleReader::LookupTensorShape(StringPiece key, TensorShape* shape) {
  DataType ignored;
  return LookupDtypeAndShape(key, &ignored, shape);
}

string BundleReader::DebugString() {
  // Format used below emulates that of TensorSliceReader::DebugString().
  string shape_str;
  BundleEntryProto entry;
  Seek(kHeaderEntryKey);
  for (Next(); Valid(); Next()) {
    CHECK(entry.ParseFromArray(value().data(), value().size()));
    if (entry.slices_size() > 0) continue;  // Slice of some partitioned var.

    strings::StrAppend(&shape_str, key(), " (", DataType_Name(entry.dtype()),
                       ") ", TensorShape(entry.shape()).DebugString());
    strings::StrAppend(&shape_str, "\n");
  }
  return shape_str;
}

namespace {
inline char* AlignedMalloc(size_t size) {
  char* buffer = static_cast<char*>(port::AlignedMalloc(size, 64));
  DCHECK(buffer);
  return buffer;
}
}  // namespace

FileOutputBuffer::FileOutputBuffer(WritableFile* file, size_t buffer_size)
    : file_(file), position_(0), buffer_size_(buffer_size) {
  DCHECK_GT(buffer_size, 0);
  buffer_ptr_ = AlignedMalloc(buffer_size);
}

FileOutputBuffer::~FileOutputBuffer() {
  if (buffer_ptr_) port::AlignedFree(buffer_ptr_);
  delete file_;
}

Status FileOutputBuffer::Append(StringPiece data) {
  // In the below, it is critical to calculate the checksum on the actually
  // copied bytes, not the source bytes. This is because "data" typically
  // points to tensor buffers, which may be concurrently written.
  if (data.size() + position_ <= buffer_size_) {
    // Can fit into the current buffer.
    memcpy(buffer_ptr_ + position_, data.data(), data.size());
    crc32c_ = crc32c::Extend(crc32c_, buffer_ptr_ + position_, data.size());
  } else if (data.size() <= buffer_size_) {
    // Cannot fit, but can fit after flushing.
    TF_RETURN_IF_ERROR(FlushBuffer(false));
    memcpy(buffer_ptr_, data.data(), data.size());
    crc32c_ = crc32c::Extend(crc32c_, buffer_ptr_, data.size());
  } else {
    // Cannot fit even after flushing. So we break down "data" by chunk, and
    // flush/checksum each chunk.
    TF_RETURN_IF_ERROR(FlushBuffer(false));
    for (size_t i = 0; i < data.size(); i += buffer_size_) {
      const size_t nbytes = std::min(data.size() - i, buffer_size_);
      memcpy(buffer_ptr_, data.data() + i, nbytes);
      crc32c_ = crc32c::Extend(crc32c_, buffer_ptr_, nbytes);
      position_ = nbytes;
      TF_RETURN_IF_ERROR(FlushBuffer(false));
    }
    return Status::OK();
  }
  position_ += data.size();
  return Status::OK();
}
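
// Worked example of the chunked path above: with buffer_size_ == 8 MB and a
// 20 MB append, the buffer is flushed, then three chunks (8 MB, 8 MB, 4 MB)
// are each memcpy'd, folded into crc32c_, and flushed in turn, so the
// checksum always covers the copied bytes rather than the caller's buffer.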

Status FileOutputBuffer::Close() {
  TF_RETURN_IF_ERROR(FlushBuffer(true));
  return file_->Close();
}

Status FileOutputBuffer::FlushBuffer(bool closing) {
  if (position_ > 0) {
    // Use Cord to avoid extra data copy for some WritableFile implementations.
    absl::Cord buffer = absl::MakeCordFromExternal(
        StringPiece(buffer_ptr_, position_),
        [ptr = buffer_ptr_](StringPiece) { port::AlignedFree(ptr); });
    buffer_ptr_ = closing ? nullptr : AlignedMalloc(buffer_size_);
    TF_RETURN_IF_ERROR(file_->Append(buffer));
    position_ = 0;
  }
  return Status::OK();
}

}  // namespace tensorflow