/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"

#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <utility>

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/framework/versions.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/crc32c.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/io/table_builder.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/bfloat16.h"
#include "tensorflow/core/platform/cord.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
#include "tensorflow/core/util/tensor_bundle/byte_swap.h"
#include "tensorflow/core/util/tensor_slice_util.h"

#ifdef PLATFORM_WINDOWS
#undef DeleteFile
#endif

namespace tensorflow {

// Versioning of the tensor bundle format.
const int kTensorBundleMinProducer = 0;
const int kTensorBundleMinConsumer = 0;
const int kTensorBundleVersion = 1;
// Size of our input buffer for streaming reads.
static const int kBufferSize = 1024 * 1024;

// Key to the special BundleHeaderProto entry. Do not change this, as clients
// can make the assumption that the header is always the first entry in the
// bundle.
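// (The empty string sorts before every other key in the sorted metadata
// table, which is what makes that guarantee hold.)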
const char* const kHeaderEntryKey = "";

namespace {

// Reads "num_elements" string elements from file[offset, offset+size) into the
// length-"num_elements" array "destination". Discards the original content of
// "destination".
//
// Checksums the string lengths (as restored uint32 or uint64, not varint64
// bytes) and the string bytes, and stores the result in "actual_crc32c".
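// The on-disk layout parsed here is the one produced by WriteStringTensor()
// below:
//   [varint64 len0]..[varint64 lenL][4 byte cksum on lengths][string bytes]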
Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements,
                        size_t offset, size_t size, tstring* destination,
                        uint32* actual_crc32c, bool need_to_swap_bytes) {
  if (size == 0) return OkStatus();
  CHECK_GT(size, 0);

  // Reads "num_elements" varint64's from "buffered_file".
  TF_RETURN_IF_ERROR(buffered_file->Seek(offset));
  TF_RETURN_IF_ERROR(buffered_file->Hint(size));
  std::vector<uint64> string_lengths(num_elements);
  for (size_t i = 0; i < num_elements; ++i) {
    TF_RETURN_IF_ERROR(buffered_file->ReadVarint64(&string_lengths[i]));
    if (string_lengths[i] <= UINT32_MAX) {
      // We need to do this because older checkpoints only used uint32s and we
      // should still support them.
      uint32 elem_size_uint32 = static_cast<uint32>(string_lengths[i]);
      if (need_to_swap_bytes) {
        // Checksum would have been computed on the source machine's byte
        // order.
        elem_size_uint32 = BYTE_SWAP_32(elem_size_uint32);
      }
      *actual_crc32c = crc32c::Extend(
          *actual_crc32c, reinterpret_cast<const char*>(&elem_size_uint32),
          sizeof(uint32));
    } else {
      uint64 length = string_lengths[i];
      if (need_to_swap_bytes) {
        length = BYTE_SWAP_64(length);
      }
      *actual_crc32c =
          crc32c::Extend(*actual_crc32c, reinterpret_cast<const char*>(&length),
                         sizeof(uint64));
    }
  }
  if (offset + size < buffered_file->Tell()) {
    return errors::DataLoss("String lengths longer than expected offset ",
                            offset + size);
  }

  // Reads the length-checksum.
  uint32 raw_length_checksum = 0;  // Bytes in file
  uint32 length_checksum = 0;      // In-memory representation
  size_t unused_bytes_read = 0;
  TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(
      sizeof(uint32), reinterpret_cast<char*>(&raw_length_checksum),
      &unused_bytes_read));
  length_checksum = need_to_swap_bytes ? BYTE_SWAP_32(raw_length_checksum)
                                       : raw_length_checksum;
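  // Note: stored checksums are "masked" (see crc32c::Mask()/Unmask()), a
  // convention borrowed from LevelDB that keeps CRCs of data containing
  // embedded CRCs well-behaved.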
  if (crc32c::Unmask(length_checksum) != *actual_crc32c) {
    return errors::DataLoss(
        "The length checksum does not match: expected ",
        strings::Printf("%08u", crc32c::Unmask(length_checksum)),
        " but actual is ", strings::Printf("%08u", *actual_crc32c));
  }
  *actual_crc32c = crc32c::Extend(*actual_crc32c,
                                  reinterpret_cast<char*>(&raw_length_checksum),
                                  sizeof(uint32));

  // Reads the actual string bytes.
  for (size_t i = 0; i < num_elements; ++i) {
    const uint64 string_length = string_lengths[i];
    tstring* buffer = &destination[i];

    buffer->resize(string_length);
    size_t bytes_read = 0;
    TF_RETURN_IF_ERROR(
        buffered_file->ReadNBytes(string_length, &(*buffer)[0], &bytes_read));
    *actual_crc32c = crc32c::Extend(*actual_crc32c, buffer->data(), bytes_read);
  }
  return OkStatus();
}

Status ReadVariantTensor(io::InputBuffer* buffered_file, Tensor* ret,
                         size_t offset, size_t size, uint32* actual_crc32c) {
  // On-disk format:
  //   [varint64 len1][bytes variant1][4 byte checksum]
  //   ..
  //   [varint64 lenN][bytes variantN][4 byte checksum]
  // Var "crc32c" checksums all the lengths, the variant bytes, and the
  // individual variant checksums (as uint32, not varint32 bytes).
  if (size == 0) return OkStatus();
  size_t num_elements = ret->NumElements();

  // Reads the actual variant bytes.
  TF_RETURN_IF_ERROR(buffered_file->Seek(offset));
  TF_RETURN_IF_ERROR(buffered_file->Hint(size));
  for (size_t i = 0; i < num_elements; ++i) {
    // Read the serialized variant length.
    uint64 string_length = 0;
    TF_RETURN_IF_ERROR(buffered_file->ReadVarint64(&string_length));
    *actual_crc32c = crc32c::Extend(
        *actual_crc32c, reinterpret_cast<const char*>(&string_length),
        sizeof(uint64));
    // Read the actual serialized variant.
    string buffer;
    buffer.resize(string_length);
    size_t bytes_read = 0;
    TF_RETURN_IF_ERROR(
        buffered_file->ReadNBytes(string_length, &buffer[0], &bytes_read));
    *actual_crc32c = crc32c::Extend(*actual_crc32c, buffer.data(), bytes_read);
    VariantTensorDataProto proto;
    if (!proto.ParseFromString(buffer)) {
      return errors::DataLoss("Unable to parse VariantTensorDataProto from ",
                              "buffer of size ", string_length, ". ",
                              "Bundle entry offset: ", offset, " size: ", size);
    }
    Variant v = proto;
    if (!DecodeUnaryVariant(&v)) {
      return errors::Internal("Could not decode variant with type_name: \"",
                              v.TypeName(), "\". Perhaps you forgot to ",
                              "register a decoder via ",
                              "REGISTER_UNARY_VARIANT_DECODE_FUNCTION?");
    }

    // Read the checksum.
    uint32 checksum = 0;
    size_t unused_bytes_read = 0;
    TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(
        sizeof(uint32), reinterpret_cast<char*>(&checksum),
        &unused_bytes_read));
    if (crc32c::Unmask(checksum) != *actual_crc32c) {
      return errors::DataLoss(
          "The checksum after Variant ", i, " does not match.",
          " Expected: ", strings::Printf("%08u", crc32c::Unmask(checksum)),
          " Actual: ", strings::Printf("%08u", *actual_crc32c));
    }
    *actual_crc32c = crc32c::Extend(
        *actual_crc32c, reinterpret_cast<char*>(&checksum), sizeof(uint32));

    ret->flat<Variant>()(i) = std::move(v);
  }

  return OkStatus();
}

char* GetBackingBuffer(const Tensor& val) {
  CHECK(DataTypeCanUseMemcpy(val.dtype())) << val.dtype();
  return const_cast<char*>(val.tensor_data().data());
}

tstring* GetStringBackingBuffer(const Tensor& val) {
  CHECK_EQ(DT_STRING, val.dtype());
  return const_cast<tstring*>(val.flat<tstring>().data());
}

Status ParseEntryProto(StringPiece key, StringPiece value,
                       protobuf::MessageLite* out) {
  if (!out->ParseFromArray(value.data(), value.size())) {
    return errors::DataLoss("Entry for key ", key, " not parseable.");
  }
  return OkStatus();
}

// Serializes the data bytes of the non-string tensor "val". Discards the
// original content of "bytes_written", and on OK updates it with the number of
// bytes written.
// REQUIRES: val.dtype() != DT_STRING
Status WriteTensor(const Tensor& val, FileOutputBuffer* out,
                   size_t* bytes_written) {
  DCHECK_NE(val.dtype(), DT_STRING);
  DCHECK_NE(val.dtype(), DT_VARIANT);
  *bytes_written = val.TotalBytes();
  char* buf = GetBackingBuffer(val);
  VLOG(1) << "Appending " << *bytes_written << " bytes to file";
  return out->Append(StringPiece(buf, *bytes_written));
}

// Serializes string tensor "val". "bytes_written" is treated in the same
// fashion as WriteTensor().
//
// Checksums all bytes written and stores the result in "crc32c".
// REQUIRES: val.dtype() == DT_STRING
Status WriteStringTensor(const Tensor& val, FileOutputBuffer* out,
                         size_t* bytes_written, uint32* crc32c) {
  // On-disk format:
  //   [varint64 len0]..[varint64 lenL][4 byte cksum on lengths][string bytes]
  // Var "crc32c" checksums the string lengths (as uint64, not varint64 bytes),
  // the length-checksum, and all the string bytes.
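  //
  // For example, a tensor holding {"ab", "cde"} is laid out as (an
  // illustration; byte values in hex):
  //   02 03            varint lengths 2 and 3
  //   xx xx xx xx      masked crc32c of the two length records
  //   61 62 63 64 65   "abcde"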
  DCHECK_EQ(val.dtype(), DT_STRING);
  const tstring* strings = GetStringBackingBuffer(val);

  // Writes the varint lengths.
  string lengths;
  lengths.reserve(val.NumElements());  // At least 1 byte per element.
  *crc32c = 0;
  for (int64_t i = 0; i < val.NumElements(); ++i) {
    const tstring* elem = &strings[i];
    DCHECK_EQ(elem->size(), static_cast<uint64>(elem->size()));
    const uint64 elem_size = static_cast<uint64>(elem->size());

    core::PutVarint64(&lengths, elem_size);
    if (elem_size <= UINT32_MAX) {
      // We need to do this because older checkpoints only used uint32s and we
      // should still support them.
      const uint32 elem_size_uint32 = static_cast<uint32>(elem_size);
      *crc32c = crc32c::Extend(*crc32c,
                               reinterpret_cast<const char*>(&elem_size_uint32),
                               sizeof(uint32));
    } else {
      *crc32c = crc32c::Extend(
          *crc32c, reinterpret_cast<const char*>(&elem_size), sizeof(uint64));
    }
  }
  TF_RETURN_IF_ERROR(out->Append(lengths));
  *bytes_written = lengths.size();

  // Writes the length checksum.
  const uint32 length_checksum = crc32c::Mask(*crc32c);
  TF_RETURN_IF_ERROR(out->Append(StringPiece(
      reinterpret_cast<const char*>(&length_checksum), sizeof(uint32))));
  *crc32c = crc32c::Extend(
      *crc32c, reinterpret_cast<const char*>(&length_checksum), sizeof(uint32));
  *bytes_written += sizeof(uint32);

  // Writes all the string bytes out.
  for (int64_t i = 0; i < val.NumElements(); ++i) {
    const tstring* string = &strings[i];
    TF_RETURN_IF_ERROR(out->Append(*string));
    *bytes_written += string->size();
    *crc32c = crc32c::Extend(*crc32c, string->data(), string->size());
  }
  return OkStatus();
}

Status WriteVariantTensor(const Tensor& val, FileOutputBuffer* out,
                          size_t* bytes_written, uint32* crc32c) {
  // On-disk format:
  //   [varint64 len1][bytes variant1][4 byte checksum]
  //   ..
  //   [varint64 lenN][bytes variantN][4 byte checksum]
  // Var "crc32c" checksums all the lengths, the variant bytes, and the
  // individual variant checksums (as uint32, not varint32 bytes).
  DCHECK_EQ(val.dtype(), DT_VARIANT);

  *crc32c = 0;
  *bytes_written = 0;
  for (int64_t i = 0; i < val.NumElements(); ++i) {
    VariantTensorData data;
    val.flat<Variant>()(i).Encode(&data);
    VariantTensorDataProto proto;
    data.ToProto(&proto);
    string elem;
    if (!proto.SerializeToString(&elem)) {
      return errors::Unknown(
          "Failed to serialize tensor data of size ", proto.ByteSizeLong(),
          ". Tensor: ", val.flat<Variant>()(i).DebugString());
    }

    // Write the length of the serialized variant.
    DCHECK_EQ(elem.size(), static_cast<uint64>(elem.size()));
    const auto elem_size = static_cast<uint64>(elem.size());
    string len;
    core::PutVarint64(&len, elem_size);
    TF_RETURN_IF_ERROR(out->Append(len));
    *crc32c = crc32c::Extend(*crc32c, reinterpret_cast<const char*>(&elem_size),
                             sizeof(uint64));
    *bytes_written += len.size();

    // Write the serialized variant.
    TF_RETURN_IF_ERROR(out->Append(elem));
    *crc32c = crc32c::Extend(*crc32c, elem.data(), elem.size());
    *bytes_written += elem.size();

    // Write the checksum.
    const uint32 length_checksum = crc32c::Mask(*crc32c);
    TF_RETURN_IF_ERROR(out->Append(StringPiece(
        reinterpret_cast<const char*>(&length_checksum), sizeof(uint32))));
    *crc32c =
        crc32c::Extend(*crc32c, reinterpret_cast<const char*>(&length_checksum),
                       sizeof(uint32));
    *bytes_written += sizeof(uint32);
  }

  return OkStatus();
}

// Returns whether "slice_spec" is a full slice, with respect to the full shape.
//
// This can happen, say, when "slice_spec" is
// "TensorSlice(full_tensor_shape.dims())", or when it is "TensorSlice({{0,
// dim(0)}, ..., {0, dim(N)}})" -- a degenerate case we need to guard against.
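// For example, for a full shape [4, 5], the explicit spec "0,4:0,5" covers
// every element and is therefore treated as a full slice.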
bool IsFullSlice(const TensorSlice& slice_spec,
                 const TensorShape& full_tensor_shape) {
  if (slice_spec.IsFull()) {
    return true;
  } else {
    TensorShape sliced_shape;
    slice_spec.SliceTensorShape(full_tensor_shape, &sliced_shape).IgnoreError();
    return sliced_shape == full_tensor_shape;
  }
}

Status CorruptFileError(const Status& in_status, const string& filename,
                        const string& detail) {
  if (in_status.ok()) {
    return errors::Internal("Unable to read file (", filename,
                            "). Perhaps the file is corrupt or was produced by "
                            "a newer version of TensorFlow with format changes "
                            "(",
                            detail, ")");
  }
  return Status(
      in_status.code(),
      strings::StrCat("Unable to read file (", filename,
                      "). Perhaps the file is corrupt or was produced by a "
                      "newer version of TensorFlow with format changes (",
                      detail, "): ", in_status.error_message()));
}

table::Options TableBuilderOptions() {
  table::Options o;
  // Compressed tables cannot be read by TensorFlow releases prior to 1.1.
  // To smooth the transition, compressed writes are disabled for now
  // (version 1.2) with the intention that they will be enabled again at
  // some point (perhaps the 1.3 release?).
  o.compression = table::kNoCompression;
  return o;
}

// Writes zeros to output buffer to align the next write to the requested
// alignment. "size" is the current size of the buffer and is updated to the
// new size.
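// For example, with alignment == 8 and *size == 13, three zero bytes are
// appended and *size becomes 16.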
Status PadAlignment(FileOutputBuffer* out, int alignment, int64_t* size) {
  int bytes_over = *size % alignment;
  if (bytes_over == 0) {
    return OkStatus();
  }
  int bytes_to_write = alignment - bytes_over;
  Status status = out->Append(string(bytes_to_write, '\0'));
  if (status.ok()) {
    *size += bytes_to_write;
  }
  return status;
}

}  // namespace

BundleWriter::BundleWriter(Env* env, StringPiece prefix, const Options& options)
    : env_(env), options_(options), prefix_(prefix), out_(nullptr), size_(0) {
  status_ = env_->HasAtomicMove(prefix_, &use_temp_file_);
  if (!status_.ok()) return;

  data_path_ = DataFilename(prefix_, 0, 1);
  metadata_path_ = MetaFilename(prefix_);
  if (use_temp_file_) {
    data_path_ = strings::StrCat(data_path_, ".tempstate", random::New64());
    metadata_path_ =
        strings::StrCat(metadata_path_, ".tempstate", random::New64());
  }

  status_ = env_->CreateDir(string(io::Dirname(prefix_)));
  if (!status_.ok() && !errors::IsAlreadyExists(status_)) {
    return;
  }

  std::unique_ptr<WritableFile> wrapper;
  status_ = env_->NewWritableFile(data_path_, &wrapper);
  if (!status_.ok()) return;
  out_ = std::unique_ptr<FileOutputBuffer>(
      new FileOutputBuffer(wrapper.release(), 8 << 20 /* 8MB write buffer */));

  VLOG(1) << "Writing to file " << data_path_;
}

Status BundleWriter::Add(StringPiece key, const Tensor& val) {
  if (!status_.ok()) return status_;
  CHECK_NE(key, kHeaderEntryKey);
  const string key_string(key);
  if (entries_.find(key_string) != entries_.end()) {
    status_ = errors::InvalidArgument("Adding duplicate key: ", key);
    return status_;
  }

  BundleEntryProto* entry = &entries_[key_string];
  entry->set_dtype(val.dtype());
  val.shape().AsProto(entry->mutable_shape());
  entry->set_shard_id(0);
  entry->set_offset(size_);

  // Updates the data file.
  size_t data_bytes_written = 0;
  uint32 crc32c = 0;
  out_->clear_crc32c();
  if (val.dtype() == DT_STRING) {
    status_ = WriteStringTensor(val, out_.get(), &data_bytes_written, &crc32c);
  } else if (val.dtype() == DT_VARIANT) {
    status_ = WriteVariantTensor(val, out_.get(), &data_bytes_written, &crc32c);
  } else {
    status_ = WriteTensor(val, out_.get(), &data_bytes_written);
    crc32c = out_->crc32c();
  }

  if (status_.ok()) {
    entry->set_size(data_bytes_written);
    entry->set_crc32c(crc32c::Mask(crc32c));
    size_ += data_bytes_written;
    status_ = PadAlignment(out_.get(), options_.data_alignment, &size_);
  }
  return status_;
}

Status BundleWriter::AddSlice(StringPiece full_tensor_key,
                              const TensorShape& full_tensor_shape,
                              const TensorSlice& slice_spec,
                              const Tensor& slice_tensor) {
  if (!status_.ok()) return status_;
  CHECK_NE(full_tensor_key, kHeaderEntryKey);

  // If just a singleton full slice, use the regular Add() to be more efficient.
  if (IsFullSlice(slice_spec, full_tensor_shape)) {
    return Add(full_tensor_key, slice_tensor);
  }

  // Inserts/updates the full tensor's metadata entry.
  //
  // In the case of a sharded save, MergeBundles() is responsible for merging
  // the "slices" field of multiple metadata entries corresponding to the same
  // full tensor.
  const string full_tensor_key_string(full_tensor_key);
  BundleEntryProto* full_entry = &entries_[full_tensor_key_string];
  if (full_entry->dtype() != DT_INVALID) {
    CHECK_EQ(full_entry->dtype(), slice_tensor.dtype());
  }
  if (full_entry->has_shape()) {
    CHECK(TensorShape(full_entry->shape()) == full_tensor_shape);
  }

  // Populates dtype, shape, and slices. Intentionally leaving out shard_id and
  // offset, which do not make sense for this full tensor entry.
  full_entry->set_dtype(slice_tensor.dtype());
  full_tensor_shape.AsProto(full_entry->mutable_shape());
  TensorSliceProto* slice_proto = full_entry->add_slices();
  slice_spec.AsProto(slice_proto);

  // The slice itself is handled by a regular Add(), which includes adding its
  // own metadata entry, and writing out the slice's values.
  const string slice_name =
      checkpoint::EncodeTensorNameSlice(full_tensor_key_string, slice_spec);
  status_ = Add(slice_name, slice_tensor);
  return status_;
}

// TODO(zongheng): on metadata write failure or !status_.ok(), consider removing
// the orphaned data file.
Status BundleWriter::Finish() {
  if (out_) {
    status_.Update(out_->Close());
    out_ = nullptr;
    if (status_.ok()) {
      if (use_temp_file_) {
        status_ =
            Env::Default()->RenameFile(data_path_, DataFilename(prefix_, 0, 1));
      }
    } else {
      Env::Default()->DeleteFile(data_path_).IgnoreError();
    }
  }
  if (!status_.ok()) return status_;
  // Build key -> BundleEntryProto table.
  std::unique_ptr<WritableFile> file;
  status_ = env_->NewWritableFile(metadata_path_, &file);
  if (!status_.ok()) return status_;
  {
    // N.B.: the default use of Snappy compression may not be supported on all
    // platforms (e.g. Android). The metadata file is small, so this is fine.
    table::Options options;
    options.compression = table::kNoCompression;
    table::TableBuilder builder(options, file.get());
    // Header entry.
    BundleHeaderProto header;
    header.set_num_shards(1);
    header.set_endianness(BundleHeaderProto::LITTLE);
    if (!port::kLittleEndian) header.set_endianness(BundleHeaderProto::BIG);
    VersionDef* version = header.mutable_version();
    version->set_producer(kTensorBundleVersion);
    version->set_min_consumer(kTensorBundleMinConsumer);

    builder.Add(kHeaderEntryKey, header.SerializeAsString());

    // All others.
    for (const auto& p : entries_) {
      builder.Add(p.first, p.second.SerializeAsString());
    }
    status_ = builder.Finish();
  }
  status_.Update(file->Close());
  if (!status_.ok()) {
    Env::Default()->DeleteFile(metadata_path_).IgnoreError();
    return status_;
  } else if (use_temp_file_) {
    status_ = Env::Default()->RenameFile(metadata_path_, MetaFilename(prefix_));
    if (!status_.ok()) return status_;
  }
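  // Marks the writer as closed: any subsequent Add()/AddSlice()/Finish() call
  // will fail with this status, while this successful Finish() returns OK.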
  status_ = errors::Internal("BundleWriter is closed");
  return OkStatus();
}

// Merging tensor bundles.
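//
// A typical call site merges several per-task bundles into one, e.g.
// (a sketch; the paths are illustrative):
//
//   std::vector<tstring> prefixes = {"/tmp/worker0", "/tmp/worker1"};
//   TF_CHECK_OK(MergeBundles(Env::Default(), prefixes, "/tmp/merged",
//                            /*allow_missing_files=*/false));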

// Accumulator of metadata states during a merge.
struct MergeState {
  // Accumulated from the header entries.
  int num_shards = 0;

  // Derives "endianness" and "version" from the first bundle merged (hence the
  // "seen_first_bundle" guard). The two fields must be the same for all
  // bundles in a merge.
  bool seen_first_bundle = false;
  BundleHeaderProto_Endianness endianness;
  VersionDef version;

  // Tensor key -> BundleEntryProto.
  std::map<string, BundleEntryProto> entries;
  // Data file path -> new shard id in the final merged bundle.
  std::unordered_map<string, int32> shard_ids;
};

// Merges entries of "prefix" into the accumulator state "merge".
// Returns OK iff the merge succeeds.
static Status MergeOneBundle(Env* env, StringPiece prefix,
                             MergeState* merge_state) {
  VLOG(1) << "Merging bundle: " << prefix;
  const string filename = MetaFilename(prefix);
  uint64 file_size;
  TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));
  std::unique_ptr<RandomAccessFile> file;
  TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));

  table::Table* table = nullptr;
  TF_RETURN_IF_ERROR(
      table::Table::Open(TableBuilderOptions(), file.get(), file_size, &table));
  std::unique_ptr<table::Table> table_deleter(table);
  std::unique_ptr<table::Iterator> iter(table->NewIterator());

  int num_shards;
  // Process header.
  {
    iter->Seek(kHeaderEntryKey);
    if (!iter->Valid()) {
      return CorruptFileError(iter->status(), filename,
                              "failed to seek to header entry");
    }
    BundleHeaderProto header;
    Status s = ParseEntryProto(iter->key(), iter->value(), &header);
    if (!s.ok()) return CorruptFileError(s, filename, "unable to parse header");

    merge_state->num_shards += header.num_shards();
    if (!merge_state->seen_first_bundle) {
      merge_state->seen_first_bundle = true;
      merge_state->endianness = header.endianness();
      merge_state->version = header.version();
    } else {
      // Validates "endianness".
      if (merge_state->endianness != header.endianness()) {
        return errors::InvalidArgument(
            "Merging bundles with conflicting endianness; inputs corrupted?");
      }
      // Validates "version".
      string curr_version, merge_version;
      header.version().SerializeToString(&curr_version);
      merge_state->version.SerializeToString(&merge_version);
      if (curr_version != merge_version) {
        return errors::InvalidArgument(
            "Merging bundles with different format versions: merged ",
            merge_version, " vs. curr ", curr_version);
      }
    }
    num_shards = header.num_shards();
    iter->Next();
  }

  // Loops through the non-header to-merge entries.
  BundleEntryProto to_merge_entry;
  for (; iter->Valid(); iter->Next()) {
    const string key(iter->key());
    const auto entry_iter = merge_state->entries.find(key);

    // Illegal: the duplicated entry is a non-slice tensor.
    if (entry_iter != merge_state->entries.end() &&
        entry_iter->second.slices().empty()) {
      return errors::InvalidArgument(
          "Duplicate tensor keyed by ", key,
          " encountered, when merging prefix: ", prefix);
    }

    TF_RETURN_IF_ERROR(
        ParseEntryProto(iter->key(), iter->value(), &to_merge_entry));

    // The duplicated entry holds metadata for a sliced full tensor.
    // Allows the duplication and merges "slices".
    if (entry_iter != merge_state->entries.end()) {
      BundleEntryProto& existing_entry = entry_iter->second;
      if (to_merge_entry.slices().empty()) {
        return errors::Internal(
            "Duplicate tensor keyed by ", key,
            "; attempting to merge in a non-slice bundle entry");
      }
      // Only needs to merge the "slices" field (and validate dtype/shape).
      for (int i = 0; i < to_merge_entry.slices_size(); ++i) {
        TensorSliceProto* slot = existing_entry.add_slices();
        *slot = to_merge_entry.slices(i);
      }
      CHECK_EQ(existing_entry.dtype(), to_merge_entry.dtype());
      CHECK(TensorShape(existing_entry.shape()) ==
            TensorShape(to_merge_entry.shape()));
      continue;
    }

    // Key doesn't duplicate: a fresh tensor/slice entry.
    auto result = merge_state->shard_ids.insert(
        {DataFilename(prefix, to_merge_entry.shard_id(), num_shards),
         merge_state->shard_ids.size()});
    to_merge_entry.set_shard_id(result.first->second);
    merge_state->entries[key] = to_merge_entry;
  }
  return OkStatus();
}

Status MergeBundles(Env* env, gtl::ArraySlice<tstring> prefixes,
                    StringPiece merged_prefix, bool allow_missing_files) {
  // Merges all metadata tables.
  // TODO(zhifengc): KeyValue sorter if it becomes too big.
  MergeState merge;
  Status status = env->CreateDir(string(io::Dirname(merged_prefix)));
  if (!status.ok() && !errors::IsAlreadyExists(status)) return status;
  bool at_least_one_file_exists = false;
  for (auto& prefix : prefixes) {
    if (!env->FileExists(MetaFilename(prefix)).ok()) {
      if (allow_missing_files) continue;
      return errors::InvalidArgument(
          "allow_missing_files was set to false and ", prefix,
          " did not exist. ", env->FileExists(prefix).ToString());
    }
    at_least_one_file_exists = true;
    TF_RETURN_IF_ERROR(MergeOneBundle(env, prefix, &merge));
  }
  if (!at_least_one_file_exists) {
    return errors::InvalidArgument(
        "At least one prefix checkpoint file must exist, but none existed.");
  }
  // Renames data files to contain the merged bundle prefix.
  for (const auto& p : merge.shard_ids) {
    VLOG(1) << "Renaming " << p.first << " to "
            << DataFilename(merged_prefix, p.second, merge.shard_ids.size());
    TF_RETURN_IF_ERROR(env->RenameFile(
        p.first,
        DataFilename(merged_prefix, p.second, merge.shard_ids.size())));
  }

  // Writes the final metadata table under the merged prefix.
  std::unique_ptr<WritableFile> merged_metadata;
  TF_RETURN_IF_ERROR(
      env->NewWritableFile(MetaFilename(merged_prefix), &merged_metadata));
  {
    table::TableBuilder builder(TableBuilderOptions(), merged_metadata.get());
    // Header entry.
    BundleHeaderProto header;
    header.set_num_shards(merge.num_shards);
    header.set_endianness(merge.endianness);
    *header.mutable_version() = merge.version;
    builder.Add(kHeaderEntryKey, header.SerializeAsString());
    // All others.
    for (const auto& p : merge.entries) {
      builder.Add(p.first, p.second.SerializeAsString());
    }
    status = builder.Finish();
  }
  status.Update(merged_metadata->Close());
  if (!status.ok()) return status;
  VLOG(1) << "Merged bundles to: " << merged_prefix;

  // Cleanup: best-effort; ignores errors.
  for (const tstring& prefix : prefixes) {
    env->DeleteFile(MetaFilename(prefix)).IgnoreError();
  }
  return status;
}

// Interface for reading a tensor bundle.
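//
// A minimal read-side sketch (the prefix and key are illustrative):
//
//   BundleReader reader(Env::Default(), "/tmp/merged");
//   TF_CHECK_OK(reader.status());
//   Tensor t;
//   TF_CHECK_OK(reader.Lookup("my_variable", &t));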

BundleReader::BundleReader(Env* env, StringPiece prefix)
    : env_(env),
      prefix_(prefix),
      metadata_(nullptr),
      table_(nullptr),
      index_cache_(nullptr),
      iter_(nullptr),
      need_to_swap_bytes_(false) {
  const string filename = MetaFilename(prefix_);
  uint64 file_size;
  status_ = env_->GetFileSize(filename, &file_size);
  if (!status_.ok()) return;

  // Opens the metadata table.
  std::unique_ptr<RandomAccessFile> wrapper;
  status_ = env_->NewRandomAccessFile(filename, &wrapper);
  if (!status_.ok()) return;
  metadata_ = wrapper.release();

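  // Optionally places the table's blocks in an LRU cache, sized (in MB) by
  // the TF_TABLE_INDEX_CACHE_SIZE_IN_MB environment variable.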
  table::Options o;
  int64_t cache_size;
  Status s =
      ReadInt64FromEnvVar("TF_TABLE_INDEX_CACHE_SIZE_IN_MB", 0, &cache_size);
  if (s.ok() && cache_size > 0) {
    index_cache_ = table::NewLRUCache(cache_size << 20);
    o.block_cache = index_cache_;
  }

  status_ = table::Table::Open(o, metadata_, file_size, &table_);
  if (!status_.ok()) return;
  iter_ = table_->NewIterator();

  // Reads "num_shards_" from the first entry.
  iter_->Seek(kHeaderEntryKey);
  if (!iter_->Valid()) {
    status_ = CorruptFileError(iter_->status(), filename,
                               "failed to seek to header entry");
    return;
  }
  BundleHeaderProto header;
  status_ = ParseEntryProto(iter_->key(), iter_->value(), &header);
  if (!status_.ok()) {
    status_ = CorruptFileError(status_, filename, "unable to parse header");
    return;
  }
  num_shards_ = header.num_shards();
  if ((header.endianness() == BundleHeaderProto::BIG && port::kLittleEndian) ||
      (header.endianness() == BundleHeaderProto::LITTLE &&
       !port::kLittleEndian)) {
    need_to_swap_bytes_ = true;
  }
  status_ = CheckVersions(header.version(), kTensorBundleVersion,
                          kTensorBundleMinProducer, "Checkpoint", "checkpoint");
}

BundleReader::~BundleReader() {
  delete metadata_;
  delete iter_;
  delete table_;
  if (index_cache_) {
    delete index_cache_;
  }
  // InputBuffer does not own the underlying RandomAccessFile.
  for (auto pair : data_) {
    if (pair.second != nullptr && pair.second->file() != nullptr) {
      delete pair.second->file();
    }
  }
  for (auto& temp : data_) {
    delete temp.second;
  }
  for (auto& temp : tensor_slices_) {
    delete temp.second;
  }
  data_.clear();
  tensor_slices_.clear();
}

Status BundleReader::GetBundleEntryProto(StringPiece key,
                                         BundleEntryProto* entry) {
  entry->Clear();
  TF_CHECK_OK(status_);
  Seek(key);
  if (!iter_->Valid() || iter_->key() != key) {
    return errors::NotFound("Key ", key, " not found in checkpoint");
  }

  BundleEntryProto entry_copy;
  TF_RETURN_IF_ERROR(
      ParseEntryProto(iter_->key(), iter_->value(), &entry_copy));
  if (!TensorShape::IsValid(entry_copy.shape())) {
    return errors::DataLoss("Invalid tensor shape: ", key, " ",
                            entry_copy.shape().ShortDebugString());
  }

  entry->Swap(&entry_copy);
  return OkStatus();
}

Status BundleReader::GetValue(const BundleEntryProto& entry, Tensor* val) {
  Tensor* ret = val;
  const TensorShape stored_shape(TensorShape(entry.shape()));
  if (val->NumElements() == 0) {
    ret = new Tensor(entry.dtype(), stored_shape);
  }

  // Validates the "size" field.
  if (entry.dtype() != DT_STRING && entry.dtype() != DT_VARIANT) {
    if (entry.size() != ret->TotalBytes()) {
      return errors::DataLoss("Invalid size in bundle entry: key ", key(),
                              "; stored size ", entry.size(),
                              "; expected size ", ret->TotalBytes());
    }
  } else if (entry.dtype() == DT_STRING) {
    // Relaxes the check for string tensors as follows:
    //   entry.size() == bytes(varint lengths) + bytes(data)
    //                >= NumElems + bytes(data), since size(varint) >= 1.
    //   TotalBytes() == sizeof(tstring) * NumElems + bytes(data)
    // Since we don't know bytes(varint lengths), we just check an inequality.
    const size_t lower_bound = ret->NumElements() + ret->TotalBytes() -
                               sizeof(tstring) * ret->NumElements();
    if (entry.size() < lower_bound) {
      return errors::DataLoss("Invalid size in bundle entry: key ", key(),
                              "; stored size ", entry.size(),
                              "; expected size is at least ", lower_bound);
    }
  }

  // Open the data file if it has not been opened.
  io::InputBuffer* buffered_file = data_[entry.shard_id()];
  if (buffered_file == nullptr) {
    std::unique_ptr<RandomAccessFile> file = nullptr;
    TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(
        DataFilename(prefix_, entry.shard_id(), num_shards_), &file));
    buffered_file = new io::InputBuffer(file.release(), kBufferSize);
    // The InputBuffer and RandomAccessFile objects are both released in dtor.
    data_[entry.shard_id()] = buffered_file;
  }
  CHECK(buffered_file != nullptr);

  TF_RETURN_IF_ERROR(buffered_file->Seek(entry.offset()));
  uint32 actual_crc32c = 0;

  if (DataTypeCanUseMemcpy(entry.dtype())) {
    char* backing_buffer = const_cast<char*>((ret->tensor_data().data()));
    size_t unused_bytes_read;
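    // Reads larger than the input buffer's capacity bypass io::InputBuffer
    // and go directly to the underlying RandomAccessFile.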
    if (entry.size() > kBufferSize) {
      StringPiece sp;
      TF_RETURN_IF_ERROR(buffered_file->file()->Read(
          entry.offset(), entry.size(), &sp, backing_buffer));
      if (sp.data() != backing_buffer) {
        memmove(backing_buffer, sp.data(), entry.size());
      }
    } else {
      TF_RETURN_IF_ERROR(buffered_file->ReadNBytes(entry.size(), backing_buffer,
                                                   &unused_bytes_read));
    }
    // Note that we compute the checksum *before* byte-swapping. The checksum
    // should be on the bytes in the order they appear in the file.
    actual_crc32c = crc32c::Value(backing_buffer, entry.size());
    if (need_to_swap_bytes_) {
      TF_RETURN_IF_ERROR(ByteSwapTensor(ret));
    }
  } else if (entry.dtype() == DT_VARIANT) {
    if (need_to_swap_bytes_) {
      return errors::Unimplemented(
          "TensorBundle at ", prefix_,
          " is of a different endianness than this machine's hardware, and "
          "the bundle contains a variant (arbitrary C++ type) tensor. "
          "Byte-swapping of variant tensors is not currently implemented.");
    }
    // Relies on io::InputBuffer's buffering, because we issue many neighboring
    // reads for a single variant tensor.
    TF_RETURN_IF_ERROR(ReadVariantTensor(buffered_file, ret, entry.offset(),
                                         entry.size(), &actual_crc32c));
  } else {
    // Relies on io::InputBuffer's buffering, because we issue many neighboring
    // reads for a single string tensor.
    TF_RETURN_IF_ERROR(ReadStringTensor(
        buffered_file, ret->NumElements(), entry.offset(), entry.size(),
        GetStringBackingBuffer(*ret), &actual_crc32c, need_to_swap_bytes_));
  }
  if (crc32c::Unmask(entry.crc32c()) != actual_crc32c) {
    return errors::DataLoss(
        "TensorBundle at ", prefix_, " shard ", entry.shard_id(), " (",
        entry.size(), " bytes): Checksum does not match: stored ",
        strings::Printf("%08u", crc32c::Unmask(entry.crc32c())),
        " vs. calculated on the restored bytes ", actual_crc32c);
  }

  *val = *ret;
  if (ret != val) delete ret;
  return OkStatus();
}

Status BundleReader::Lookup(StringPiece key, Tensor* val) {
  CHECK(val != nullptr);
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));

  if (entry.slices().empty()) {
    return GetValue(entry, val);
  } else {
    return GetSliceValue(
        key, entry,
        /* a full slice */ TensorSlice(TensorShape(entry.shape()).dims()), val);
  }
}

Status BundleReader::ReadCurrent(Tensor* val) {
  CHECK(val != nullptr);
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(ParseEntryProto(iter_->key(), iter_->value(), &entry));
  if (!TensorShape::IsValid(entry.shape())) {
    return errors::DataLoss("Invalid tensor shape: ", iter_->key(), " ",
                            entry.shape().ShortDebugString());
  }

  if (entry.slices().empty()) {
    return GetValue(entry, val);
  } else {
    return GetSliceValue(
        iter_->key(), entry,
        /* a full slice */ TensorSlice(TensorShape(entry.shape()).dims()), val);
  }
}

Status BundleReader::LookupTensorSlices(StringPiece key,
                                        std::vector<TensorSlice>* slices) {
  slices->clear();
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));
  slices->reserve(entry.slices_size());
  for (const auto& slice : entry.slices()) {
    slices->emplace_back(slice);
  }
  return OkStatus();
}

Status BundleReader::LookupSlice(StringPiece full_tensor_key,
                                 const TensorSlice& slice_spec, Tensor* val) {
  CHECK(val != nullptr);
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(GetBundleEntryProto(full_tensor_key, &entry));
  return GetSliceValue(full_tensor_key, entry, slice_spec, val);
}

Status BundleReader::GetSliceValue(StringPiece full_tensor_key,
                                   const BundleEntryProto& full_tensor_entry,
                                   const TensorSlice& slice_spec, Tensor* val) {
  using checkpoint::RegisterTensorSlice;
  using checkpoint::TensorSliceSet;
  DCHECK_GE(full_tensor_entry.slices_size(), 0);

  const TensorShape full_shape(TensorShape(full_tensor_entry.shape()));
  std::vector<std::pair<TensorSlice, string>> details;
  const string full_tensor_key_string(full_tensor_key);
  const TensorSliceSet* tss =
      gtl::FindPtrOrNull(tensor_slices_, full_tensor_key_string);

  // Populates the "full tensor key -> TensorSliceSet" cache.
  if (tss == nullptr) {
    if (full_tensor_entry.slices().empty()) {
      // Special case: a writer has saved a tensor fully, but the reader wants
      // to read in slices. We therefore register the full slice on-demand here
      // without further complicating the on-disk bundle format.
      TF_RETURN_IF_ERROR(RegisterTensorSlice(
          full_tensor_key_string, full_shape, full_tensor_entry.dtype(),
          /* tag */ "",
          /* full slice */ TensorSlice(full_shape.dims()), &tensor_slices_));
    }
    for (const TensorSliceProto& slice : full_tensor_entry.slices()) {
      TF_RETURN_IF_ERROR(RegisterTensorSlice(
          full_tensor_key_string, full_shape, full_tensor_entry.dtype(),
          /* tag */ "", TensorSlice(slice), &tensor_slices_));
    }
    tss = gtl::FindPtrOrNull(tensor_slices_, full_tensor_key_string);
    CHECK_NE(tss, nullptr);
  }
  if (!tss->QueryMeta(slice_spec, &details)) {
    return errors::InvalidArgument(
        "Does not have sufficient slices for partitioned tensor ",
        full_tensor_key,
        " to restore in slice_spec: ", slice_spec.DebugString());
  }

  // The union of the slices in "details" covers "slice_spec". Performs the
  // copies from each.
  BundleEntryProto stored_slice_entry = full_tensor_entry;
  for (const auto& slice_tag_pair : details) {
    // Seeks for the stored slice.
    const TensorSlice& stored_slice = slice_tag_pair.first;

    // We already have the entry for the full tensor, so don't query again if
    // the slice is full.
    if (!stored_slice.IsFull()) {
      const string encoded_stored_slice_name =
          checkpoint::EncodeTensorNameSlice(full_tensor_key_string,
                                            stored_slice);
      status_ =
          GetBundleEntryProto(encoded_stored_slice_name, &stored_slice_entry);
      if (!status_.ok()) return status_;
    }

    // TODO(zongheng): should we take an OpKernelContext, so that we can call
    // allocate_temp()? Note that without major refactorings to Saver, it's
    // hard for the caller of the tensor bundle module to allocate these
    // precisely-shaped scratch storage.

    // Optimization for the common case: the stored slice can be directly
    // copied to the destination without additional slicing. This is true when
    // either the slices are equal or when they are both full slices having the
    // same shape.
    TensorShape stored_slice_shape(stored_slice_entry.shape());
    if (stored_slice == slice_spec ||
        (stored_slice_shape == val->shape() &&
         IsFullSlice(stored_slice, stored_slice_shape) &&
         IsFullSlice(slice_spec, stored_slice_shape))) {
      VLOG(1) << "Optimized for common case: directly copying into "
                 "pre-allocated buffer; spec: "
              << slice_spec.DebugString();
      status_ = GetValue(stored_slice_entry, val);
      return status_;
    }

    Tensor stored_slice_tensor(stored_slice_entry.dtype(), stored_slice_shape);
    status_ = GetValue(stored_slice_entry, &stored_slice_tensor);
    if (!status_.ok()) return status_;

    // Copies the intersection over.
    const DataType common_dtype = full_tensor_entry.dtype();
    switch (common_dtype) {
#define HANDLE_COPY(T)                                                 \
  case DataTypeToEnum<T>::value:                                       \
    CHECK(CopyDataFromTensorSliceToTensorSlice(                        \
        full_shape, stored_slice, slice_spec,                          \
        stored_slice_tensor.flat<T>().data(), val->flat<T>().data())); \
    break;

      HANDLE_COPY(float)
      HANDLE_COPY(double)
      HANDLE_COPY(int32)
      HANDLE_COPY(uint8)
      HANDLE_COPY(int16)
      HANDLE_COPY(int8)
      HANDLE_COPY(complex64)
      HANDLE_COPY(complex128)
      HANDLE_COPY(int64_t)
      HANDLE_COPY(bool)
      HANDLE_COPY(qint32)
      HANDLE_COPY(quint8)
      HANDLE_COPY(qint8)
      HANDLE_COPY(bfloat16)
      default:
        return errors::InvalidArgument("Dtype ", DataTypeString(common_dtype),
                                       " not supported.");
    }
#undef HANDLE_COPY
  }
  return OkStatus();
}

bool BundleReader::Contains(StringPiece key) {
  Seek(key);
  return Valid() && (this->key() == key);
}

Status BundleReader::LookupDtypeAndShape(StringPiece key, DataType* dtype,
                                         TensorShape* shape) {
  BundleEntryProto entry;
  TF_RETURN_IF_ERROR(GetBundleEntryProto(key, &entry));
  *dtype = entry.dtype();
  *shape = TensorShape(entry.shape());
  return OkStatus();
}

Status BundleReader::LookupTensorShape(StringPiece key, TensorShape* shape) {
  DataType ignored;
  return LookupDtypeAndShape(key, &ignored, shape);
}

string BundleReader::DebugString() {
  // Format used below emulates that of TensorSliceReader::DebugString().
  string shape_str;
  BundleEntryProto entry;
  Seek(kHeaderEntryKey);
  for (Next(); Valid(); Next()) {
    CHECK(entry.ParseFromArray(value().data(), value().size()));
    if (entry.slices_size() > 0) continue;  // Slice of some partitioned var.

    strings::StrAppend(&shape_str, key(), " (", DataType_Name(entry.dtype()),
                       ") ", TensorShape(entry.shape()).DebugString());
    strings::StrAppend(&shape_str, "\n");
  }
  return shape_str;
}

namespace {
inline char* AlignedMalloc(size_t size) {
  char* buffer = static_cast<char*>(port::AlignedMalloc(size, 64));
  DCHECK(buffer);
  return buffer;
}
}  // namespace

FileOutputBuffer::FileOutputBuffer(WritableFile* file, size_t buffer_size)
    : file_(file), position_(0), buffer_size_(buffer_size) {
  DCHECK_GT(buffer_size, 0);
  buffer_ptr_ = AlignedMalloc(buffer_size);
}

FileOutputBuffer::~FileOutputBuffer() {
  if (buffer_ptr_) port::AlignedFree(buffer_ptr_);
  delete file_;
}

Status FileOutputBuffer::Append(StringPiece data) {
  // In the below, it is critical to calculate the checksum on the actually
  // copied bytes, not the source bytes. This is because "data" typically
  // points to tensor buffers, which may be concurrently written.
  if (data.size() + position_ <= buffer_size_) {
    // Can fit into the current buffer.
    memcpy(buffer_ptr_ + position_, data.data(), data.size());
    crc32c_ = crc32c::Extend(crc32c_, buffer_ptr_ + position_, data.size());
  } else if (data.size() <= buffer_size_) {
    // Cannot fit, but can fit after flushing.
    TF_RETURN_IF_ERROR(FlushBuffer(false));
    memcpy(buffer_ptr_, data.data(), data.size());
    crc32c_ = crc32c::Extend(crc32c_, buffer_ptr_, data.size());
  } else {
    // Cannot fit even after flushing. So we break down "data" by chunk, and
    // flush/checksum each chunk.
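    // E.g., a 20MB append into an 8MB buffer is written as three chunks of
    // 8MB, 8MB, and 4MB, with a flush after each.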
    TF_RETURN_IF_ERROR(FlushBuffer(false));
    for (size_t i = 0; i < data.size(); i += buffer_size_) {
      const size_t nbytes = std::min(data.size() - i, buffer_size_);
      memcpy(buffer_ptr_, data.data() + i, nbytes);
      crc32c_ = crc32c::Extend(crc32c_, buffer_ptr_, nbytes);
      position_ = nbytes;
      TF_RETURN_IF_ERROR(FlushBuffer(false));
    }
    return OkStatus();
  }
  position_ += data.size();
  return OkStatus();
}

Status FileOutputBuffer::Close() {
  TF_RETURN_IF_ERROR(FlushBuffer(true));
  return file_->Close();
}

Status FileOutputBuffer::FlushBuffer(bool closing) {
  if (position_ > 0) {
    // Use Cord to avoid extra data copy for some WritableFile implementations.
    absl::Cord buffer = absl::MakeCordFromExternal(
        StringPiece(buffer_ptr_, position_),
        [ptr = buffer_ptr_](StringPiece) { port::AlignedFree(ptr); });
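    // Ownership of the old buffer moves into the Cord's releaser, which frees
    // it once the data is no longer referenced; a fresh buffer is allocated
    // below unless we are closing.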
    buffer_ptr_ = closing ? nullptr : AlignedMalloc(buffer_size_);
    TF_RETURN_IF_ERROR(file_->Append(buffer));
    position_ = 0;
  }
  return OkStatus();
}

}  // namespace tensorflow