1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
17
18 #include <random>
19 #include <string>
20 #include <vector>
21
22 #if defined(_WIN32)
23 #include <windows.h>
24 #endif // _WIN32
25
26 #include "tensorflow/core/framework/tensor_testutil.h"
27 #include "tensorflow/core/framework/tensor_util.h"
28 #include "tensorflow/core/framework/types.pb.h"
29 #include "tensorflow/core/framework/variant.h"
30 #include "tensorflow/core/framework/variant_op_registry.h"
31 #include "tensorflow/core/framework/versions.pb.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33 #include "tensorflow/core/lib/io/path.h"
34 #include "tensorflow/core/lib/io/table_builder.h"
35 #include "tensorflow/core/lib/strings/str_util.h"
36 #include "tensorflow/core/lib/strings/strcat.h"
37 #include "tensorflow/core/platform/test.h"
38 #include "tensorflow/core/platform/test_benchmark.h"
39 #include "tensorflow/core/protobuf/error_codes.pb.h"
40 #include "tensorflow/core/protobuf/tensor_bundle.pb.h"
41 #include "tensorflow/core/util/tensor_bundle/byte_swap.h"
42
43 namespace tensorflow {
44 using ::testing::ElementsAre;
45
46 namespace {
47
48 // Prepend the current test case's working temporary directory to <prefix>
Prefix(const string & prefix)49 string Prefix(const string& prefix) {
50 return strings::StrCat(testing::TmpDir(), "/", prefix);
51 }
52
53 // Construct a data input directory by prepending the test data root
54 // directory to <prefix>
TestdataPrefix(const string & prefix)55 string TestdataPrefix(const string& prefix) {
56 return strings::StrCat(testing::TensorFlowSrcRoot(),
57 "/core/util/tensor_bundle/testdata/", prefix);
58 }
59
60 template <typename T>
Constant(T v,TensorShape shape)61 Tensor Constant(T v, TensorShape shape) {
62 Tensor ret(DataTypeToEnum<T>::value, shape);
63 ret.flat<T>().setConstant(v);
64 return ret;
65 }
66
67 template <typename T>
Constant_2x3(T v)68 Tensor Constant_2x3(T v) {
69 return Constant(v, TensorShape({2, 3}));
70 }
71
ByteSwap(Tensor t)72 Tensor ByteSwap(Tensor t) {
73 Tensor ret = tensor::DeepCopy(t);
74 TF_EXPECT_OK(ByteSwapTensor(&ret));
75 return ret;
76 }
77
78 // Assert that <reader> has a tensor under <key> matching <expected_val> in
79 // terms of both shape, dtype, and value
80 template <typename T>
Expect(BundleReader * reader,const string & key,const Tensor & expected_val)81 void Expect(BundleReader* reader, const string& key,
82 const Tensor& expected_val) {
83 // Tests for Contains().
84 EXPECT_TRUE(reader->Contains(key));
85 // Tests for LookupDtypeAndShape().
86 DataType dtype;
87 TensorShape shape;
88 TF_ASSERT_OK(reader->LookupDtypeAndShape(key, &dtype, &shape));
89 EXPECT_EQ(expected_val.dtype(), dtype);
90 EXPECT_EQ(expected_val.shape(), shape);
91 // Tests for Lookup(), checking tensor contents.
92 Tensor val(expected_val.dtype(), shape);
93 TF_ASSERT_OK(reader->Lookup(key, &val));
94 test::ExpectTensorEqual<T>(val, expected_val);
95 }
96
97 template <class T>
ExpectVariant(BundleReader * reader,const string & key,const Tensor & expected_t)98 void ExpectVariant(BundleReader* reader, const string& key,
99 const Tensor& expected_t) {
100 // Tests for Contains().
101 EXPECT_TRUE(reader->Contains(key));
102 // Tests for LookupDtypeAndShape().
103 DataType dtype;
104 TensorShape shape;
105 TF_ASSERT_OK(reader->LookupDtypeAndShape(key, &dtype, &shape));
106 // Tests for Lookup(), checking tensor contents.
107 EXPECT_EQ(expected_t.dtype(), dtype);
108 EXPECT_EQ(expected_t.shape(), shape);
109 Tensor actual_t(dtype, shape);
110 TF_ASSERT_OK(reader->Lookup(key, &actual_t));
111 for (int i = 0; i < expected_t.NumElements(); i++) {
112 Variant actual_var = actual_t.flat<Variant>()(i);
113 Variant expected_var = expected_t.flat<Variant>()(i);
114 EXPECT_EQ(actual_var.TypeName(), expected_var.TypeName());
115 auto* actual_val = actual_var.get<T>();
116 auto* expected_val = expected_var.get<T>();
117 EXPECT_EQ(*expected_val, *actual_val);
118 }
119 }
120
121 template <typename T>
ExpectNext(BundleReader * reader,const Tensor & expected_val)122 void ExpectNext(BundleReader* reader, const Tensor& expected_val) {
123 EXPECT_TRUE(reader->Valid());
124 reader->Next();
125 TF_ASSERT_OK(reader->status());
126 Tensor val;
127 TF_ASSERT_OK(reader->ReadCurrent(&val));
128 test::ExpectTensorEqual<T>(val, expected_val);
129 }
130
AllTensorKeys(BundleReader * reader)131 std::vector<string> AllTensorKeys(BundleReader* reader) {
132 std::vector<string> ret;
133 reader->Seek(kHeaderEntryKey);
134 reader->Next();
135 for (; reader->Valid(); reader->Next()) {
136 ret.emplace_back(reader->key());
137 }
138 return ret;
139 }
140
141 // Writes out the metadata file of a bundle again, with the endianness marker
142 // bit flipped.
FlipEndiannessBit(const string & prefix)143 Status FlipEndiannessBit(const string& prefix) {
144 Env* env = Env::Default();
145 const string metadata_tmp_path = Prefix("some_tmp_path");
146 std::unique_ptr<WritableFile> metadata_file;
147 TF_RETURN_IF_ERROR(env->NewWritableFile(metadata_tmp_path, &metadata_file));
148 // We create the builder lazily in case we run into an exception earlier, in
149 // which case we'd forget to call Finish() and TableBuilder's destructor
150 // would complain.
151 std::unique_ptr<table::TableBuilder> builder;
152
153 // Reads the existing metadata file, and fills the builder.
154 {
155 const string filename = MetaFilename(prefix);
156 uint64 file_size;
157 TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));
158 std::unique_ptr<RandomAccessFile> file;
159 TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));
160
161 table::Table* table = nullptr;
162 TF_RETURN_IF_ERROR(
163 table::Table::Open(table::Options(), file.get(), file_size, &table));
164 std::unique_ptr<table::Table> table_deleter(table);
165 std::unique_ptr<table::Iterator> iter(table->NewIterator());
166
167 // Reads the header entry.
168 iter->Seek(kHeaderEntryKey);
169 CHECK(iter->Valid());
170 BundleHeaderProto header;
171 CHECK(header.ParseFromArray(iter->value().data(), iter->value().size()));
172 // Flips the endianness.
173 if (header.endianness() == BundleHeaderProto::LITTLE) {
174 header.set_endianness(BundleHeaderProto::BIG);
175 } else {
176 header.set_endianness(BundleHeaderProto::LITTLE);
177 }
178 builder.reset(
179 new table::TableBuilder(table::Options(), metadata_file.get()));
180 builder->Add(iter->key(), header.SerializeAsString());
181 iter->Next();
182
183 // Adds the non-header entries unmodified.
184 for (; iter->Valid(); iter->Next())
185 builder->Add(iter->key(), iter->value());
186 }
187 TF_RETURN_IF_ERROR(builder->Finish());
188 TF_RETURN_IF_ERROR(env->RenameFile(metadata_tmp_path, MetaFilename(prefix)));
189 return metadata_file->Close();
190 }
191
192 template <typename T>
TestBasic()193 void TestBasic() {
194 {
195 BundleWriter writer(Env::Default(), Prefix("foo"));
196 TF_EXPECT_OK(writer.Add("foo_003", Constant_2x3(T(3))));
197 TF_EXPECT_OK(writer.Add("foo_000", Constant_2x3(T(0))));
198 TF_EXPECT_OK(writer.Add("foo_002", Constant_2x3(T(2))));
199 TF_EXPECT_OK(writer.Add("foo_001", Constant_2x3(T(1))));
200 TF_ASSERT_OK(writer.Finish());
201 }
202 {
203 BundleReader reader(Env::Default(), Prefix("foo"));
204 TF_ASSERT_OK(reader.status());
205 EXPECT_EQ(
206 AllTensorKeys(&reader),
207 std::vector<string>({"foo_000", "foo_001", "foo_002", "foo_003"}));
208 Expect<T>(&reader, "foo_000", Constant_2x3(T(0)));
209 Expect<T>(&reader, "foo_001", Constant_2x3(T(1)));
210 Expect<T>(&reader, "foo_002", Constant_2x3(T(2)));
211 Expect<T>(&reader, "foo_003", Constant_2x3(T(3)));
212 }
213 {
214 BundleReader reader(Env::Default(), Prefix("foo"));
215 TF_ASSERT_OK(reader.status());
216 ExpectNext<T>(&reader, Constant_2x3(T(0)));
217 ExpectNext<T>(&reader, Constant_2x3(T(1)));
218 ExpectNext<T>(&reader, Constant_2x3(T(2)));
219 ExpectNext<T>(&reader, Constant_2x3(T(3)));
220 EXPECT_TRUE(reader.Valid());
221 reader.Next();
222 EXPECT_FALSE(reader.Valid());
223 }
224 {
225 BundleWriter writer(Env::Default(), Prefix("bar"));
226 TF_EXPECT_OK(writer.Add("bar_003", Constant_2x3(T(3))));
227 TF_EXPECT_OK(writer.Add("bar_000", Constant_2x3(T(0))));
228 TF_EXPECT_OK(writer.Add("bar_002", Constant_2x3(T(2))));
229 TF_EXPECT_OK(writer.Add("bar_001", Constant_2x3(T(1))));
230 TF_ASSERT_OK(writer.Finish());
231 }
232 {
233 BundleReader reader(Env::Default(), Prefix("bar"));
234 TF_ASSERT_OK(reader.status());
235 EXPECT_EQ(
236 AllTensorKeys(&reader),
237 std::vector<string>({"bar_000", "bar_001", "bar_002", "bar_003"}));
238 Expect<T>(&reader, "bar_003", Constant_2x3(T(3)));
239 Expect<T>(&reader, "bar_002", Constant_2x3(T(2)));
240 Expect<T>(&reader, "bar_001", Constant_2x3(T(1)));
241 Expect<T>(&reader, "bar_000", Constant_2x3(T(0)));
242 }
243 {
244 BundleReader reader(Env::Default(), Prefix("bar"));
245 TF_ASSERT_OK(reader.status());
246 ExpectNext<T>(&reader, Constant_2x3(T(0)));
247 ExpectNext<T>(&reader, Constant_2x3(T(1)));
248 ExpectNext<T>(&reader, Constant_2x3(T(2)));
249 ExpectNext<T>(&reader, Constant_2x3(T(3)));
250 EXPECT_TRUE(reader.Valid());
251 reader.Next();
252 EXPECT_FALSE(reader.Valid());
253 }
254 TF_ASSERT_OK(MergeBundles(Env::Default(), {Prefix("foo"), Prefix("bar")},
255 Prefix("merged")));
256 {
257 BundleReader reader(Env::Default(), Prefix("merged"));
258 TF_ASSERT_OK(reader.status());
259 EXPECT_EQ(
260 AllTensorKeys(&reader),
261 std::vector<string>({"bar_000", "bar_001", "bar_002", "bar_003",
262 "foo_000", "foo_001", "foo_002", "foo_003"}));
263 Expect<T>(&reader, "bar_000", Constant_2x3(T(0)));
264 Expect<T>(&reader, "bar_001", Constant_2x3(T(1)));
265 Expect<T>(&reader, "bar_002", Constant_2x3(T(2)));
266 Expect<T>(&reader, "bar_003", Constant_2x3(T(3)));
267 Expect<T>(&reader, "foo_000", Constant_2x3(T(0)));
268 Expect<T>(&reader, "foo_001", Constant_2x3(T(1)));
269 Expect<T>(&reader, "foo_002", Constant_2x3(T(2)));
270 Expect<T>(&reader, "foo_003", Constant_2x3(T(3)));
271 }
272 {
273 BundleReader reader(Env::Default(), Prefix("merged"));
274 TF_ASSERT_OK(reader.status());
275 ExpectNext<T>(&reader, Constant_2x3(T(0)));
276 ExpectNext<T>(&reader, Constant_2x3(T(1)));
277 ExpectNext<T>(&reader, Constant_2x3(T(2)));
278 ExpectNext<T>(&reader, Constant_2x3(T(3)));
279 ExpectNext<T>(&reader, Constant_2x3(T(0)));
280 ExpectNext<T>(&reader, Constant_2x3(T(1)));
281 ExpectNext<T>(&reader, Constant_2x3(T(2)));
282 ExpectNext<T>(&reader, Constant_2x3(T(3)));
283 EXPECT_TRUE(reader.Valid());
284 reader.Next();
285 EXPECT_FALSE(reader.Valid());
286 }
287 }
288
289 // Type-specific subroutine of SwapBytes test below
290 template <typename T>
TestByteSwap(const T * forward,const T * swapped,int array_len)291 void TestByteSwap(const T* forward, const T* swapped, int array_len) {
292 auto bytes_per_elem = sizeof(T);
293
294 // Convert the entire array at once
295 std::unique_ptr<T[]> forward_copy(new T[array_len]);
296 std::memcpy(forward_copy.get(), forward, array_len * bytes_per_elem);
297 TF_EXPECT_OK(ByteSwapArray(reinterpret_cast<char*>(forward_copy.get()),
298 bytes_per_elem, array_len));
299 for (int i = 0; i < array_len; i++) {
300 EXPECT_EQ(forward_copy.get()[i], swapped[i]);
301 }
302
303 // Then the array wrapped in a tensor
304 auto shape = TensorShape({array_len});
305 auto dtype = DataTypeToEnum<T>::value;
306 Tensor forward_tensor(dtype, shape);
307 Tensor swapped_tensor(dtype, shape);
308 std::memcpy(const_cast<char*>(forward_tensor.tensor_data().data()), forward,
309 array_len * bytes_per_elem);
310 std::memcpy(const_cast<char*>(swapped_tensor.tensor_data().data()), swapped,
311 array_len * bytes_per_elem);
312 TF_EXPECT_OK(ByteSwapTensor(&forward_tensor));
313 test::ExpectTensorEqual<T>(forward_tensor, swapped_tensor);
314 }
315
316 // Unit test of the byte-swapping operations that TensorBundle uses.
TEST(TensorBundleTest,SwapBytes)317 TEST(TensorBundleTest, SwapBytes) {
318 // A bug in the compiler on MacOS causes ByteSwap() and FlipEndiannessBit()
319 // to be removed from the executable if they are only called from templated
320 // functions. As a workaround, we make some dummy calls here.
321 // TODO(frreiss): Remove this workaround when the compiler bug is fixed.
322 ByteSwap(Constant_2x3<int>(42));
323 EXPECT_NE(OkStatus(), FlipEndiannessBit(Prefix("not_a_valid_prefix")));
324
325 // Test patterns, manually swapped so that we aren't relying on the
326 // correctness of our own byte-swapping macros when testing those macros.
327 // At least one of the entries in each list has the sign bit set when
328 // interpreted as a signed int.
329 const int arr_len_16 = 4;
330 const uint16_t forward_16[] = {0x1de5, 0xd017, 0xf1ea, 0xc0a1};
331 const uint16_t swapped_16[] = {0xe51d, 0x17d0, 0xeaf1, 0xa1c0};
332 const int arr_len_32 = 2;
333 const uint32_t forward_32[] = {0x0ddba115, 0xf01dab1e};
334 const uint32_t swapped_32[] = {0x15a1db0d, 0x1eab1df0};
335 const int arr_len_64 = 2;
336 const uint64_t forward_64[] = {0xf005ba11caba1000, 0x5ca1ab1ecab005e5};
337 const uint64_t swapped_64[] = {0x0010baca11ba05f0, 0xe505b0ca1eaba15c};
338
339 // 16-bit types
340 TestByteSwap(forward_16, swapped_16, arr_len_16);
341 TestByteSwap(reinterpret_cast<const int16_t*>(forward_16),
342 reinterpret_cast<const int16_t*>(swapped_16), arr_len_16);
343 TestByteSwap(reinterpret_cast<const bfloat16*>(forward_16),
344 reinterpret_cast<const bfloat16*>(swapped_16), arr_len_16);
345
346 // 32-bit types
347 TestByteSwap(forward_32, swapped_32, arr_len_32);
348 TestByteSwap(reinterpret_cast<const int32_t*>(forward_32),
349 reinterpret_cast<const int32_t*>(swapped_32), arr_len_32);
350 TestByteSwap(reinterpret_cast<const float*>(forward_32),
351 reinterpret_cast<const float*>(swapped_32), arr_len_32);
352
353 // 64-bit types
354 // Cast to uint64*/int64* to make DataTypeToEnum<T> happy
355 TestByteSwap(reinterpret_cast<const uint64*>(forward_64),
356 reinterpret_cast<const uint64*>(swapped_64), arr_len_64);
357 TestByteSwap(reinterpret_cast<const int64_t*>(forward_64),
358 reinterpret_cast<const int64_t*>(swapped_64), arr_len_64);
359 TestByteSwap(reinterpret_cast<const double*>(forward_64),
360 reinterpret_cast<const double*>(swapped_64), arr_len_64);
361
362 // Complex types.
363 // Logic for complex number handling is only in ByteSwapTensor, so don't test
364 // ByteSwapArray
365 const float* forward_float = reinterpret_cast<const float*>(forward_32);
366 const float* swapped_float = reinterpret_cast<const float*>(swapped_32);
367 const double* forward_double = reinterpret_cast<const double*>(forward_64);
368 const double* swapped_double = reinterpret_cast<const double*>(swapped_64);
369 Tensor forward_complex64 = Constant_2x3<complex64>(
370 std::complex<float>(forward_float[0], forward_float[1]));
371 Tensor swapped_complex64 = Constant_2x3<complex64>(
372 std::complex<float>(swapped_float[0], swapped_float[1]));
373 Tensor forward_complex128 = Constant_2x3<complex128>(
374 std::complex<double>(forward_double[0], forward_double[1]));
375 Tensor swapped_complex128 = Constant_2x3<complex128>(
376 std::complex<double>(swapped_double[0], swapped_double[1]));
377
378 TF_EXPECT_OK(ByteSwapTensor(&forward_complex64));
379 test::ExpectTensorEqual<complex64>(forward_complex64, swapped_complex64);
380
381 TF_EXPECT_OK(ByteSwapTensor(&forward_complex128));
382 test::ExpectTensorEqual<complex128>(forward_complex128, swapped_complex128);
383 }
384
385 // Basic test of alternate-endianness support. Generates a bundle in
386 // the opposite of the current system's endianness and attempts to
387 // read the bundle back in. Does not exercise sharding or access to
388 // nonaligned tensors. Does cover the major access types exercised
389 // in TestBasic.
390 template <typename T>
TestEndianness()391 void TestEndianness() {
392 {
393 // Write out a TensorBundle in the opposite of this host's endianness.
394 BundleWriter writer(Env::Default(), Prefix("foo"));
395 TF_EXPECT_OK(writer.Add("foo_003", ByteSwap(Constant_2x3<T>(T(3)))));
396 TF_EXPECT_OK(writer.Add("foo_000", ByteSwap(Constant_2x3<T>(T(0)))));
397 TF_EXPECT_OK(writer.Add("foo_002", ByteSwap(Constant_2x3<T>(T(2)))));
398 TF_EXPECT_OK(writer.Add("foo_001", ByteSwap(Constant_2x3<T>(T(1)))));
399 TF_ASSERT_OK(writer.Finish());
400 TF_ASSERT_OK(FlipEndiannessBit(Prefix("foo")));
401 }
402 {
403 BundleReader reader(Env::Default(), Prefix("foo"));
404 TF_ASSERT_OK(reader.status());
405 EXPECT_EQ(
406 AllTensorKeys(&reader),
407 std::vector<string>({"foo_000", "foo_001", "foo_002", "foo_003"}));
408 Expect<T>(&reader, "foo_000", Constant_2x3<T>(T(0)));
409 Expect<T>(&reader, "foo_001", Constant_2x3<T>(T(1)));
410 Expect<T>(&reader, "foo_002", Constant_2x3<T>(T(2)));
411 Expect<T>(&reader, "foo_003", Constant_2x3<T>(T(3)));
412 }
413 {
414 BundleReader reader(Env::Default(), Prefix("foo"));
415 TF_ASSERT_OK(reader.status());
416 ExpectNext<T>(&reader, Constant_2x3<T>(T(0)));
417 ExpectNext<T>(&reader, Constant_2x3<T>(T(1)));
418 ExpectNext<T>(&reader, Constant_2x3<T>(T(2)));
419 ExpectNext<T>(&reader, Constant_2x3<T>(T(3)));
420 EXPECT_TRUE(reader.Valid());
421 reader.Next();
422 EXPECT_FALSE(reader.Valid());
423 }
424 {
425 BundleWriter writer(Env::Default(), Prefix("bar"));
426 TF_EXPECT_OK(writer.Add("bar_003", ByteSwap(Constant_2x3<T>(T(3)))));
427 TF_EXPECT_OK(writer.Add("bar_000", ByteSwap(Constant_2x3<T>(T(0)))));
428 TF_EXPECT_OK(writer.Add("bar_002", ByteSwap(Constant_2x3<T>(T(2)))));
429 TF_EXPECT_OK(writer.Add("bar_001", ByteSwap(Constant_2x3<T>(T(1)))));
430 TF_ASSERT_OK(writer.Finish());
431 TF_ASSERT_OK(FlipEndiannessBit(Prefix("bar")));
432 }
433 {
434 BundleReader reader(Env::Default(), Prefix("bar"));
435 TF_ASSERT_OK(reader.status());
436 EXPECT_EQ(
437 AllTensorKeys(&reader),
438 std::vector<string>({"bar_000", "bar_001", "bar_002", "bar_003"}));
439 Expect<T>(&reader, "bar_003", Constant_2x3<T>(T(3)));
440 Expect<T>(&reader, "bar_002", Constant_2x3<T>(T(2)));
441 Expect<T>(&reader, "bar_001", Constant_2x3<T>(T(1)));
442 Expect<T>(&reader, "bar_000", Constant_2x3<T>(T(0)));
443 }
444 {
445 BundleReader reader(Env::Default(), Prefix("bar"));
446 TF_ASSERT_OK(reader.status());
447 ExpectNext<T>(&reader, Constant_2x3<T>(T(0)));
448 ExpectNext<T>(&reader, Constant_2x3<T>(T(1)));
449 ExpectNext<T>(&reader, Constant_2x3<T>(T(2)));
450 ExpectNext<T>(&reader, Constant_2x3<T>(T(3)));
451 EXPECT_TRUE(reader.Valid());
452 reader.Next();
453 EXPECT_FALSE(reader.Valid());
454 }
455 TF_ASSERT_OK(MergeBundles(Env::Default(), {Prefix("foo"), Prefix("bar")},
456 Prefix("merged")));
457 {
458 BundleReader reader(Env::Default(), Prefix("merged"));
459 TF_ASSERT_OK(reader.status());
460 EXPECT_EQ(
461 AllTensorKeys(&reader),
462 std::vector<string>({"bar_000", "bar_001", "bar_002", "bar_003",
463 "foo_000", "foo_001", "foo_002", "foo_003"}));
464 Expect<T>(&reader, "bar_000", Constant_2x3<T>(T(0)));
465 Expect<T>(&reader, "bar_001", Constant_2x3<T>(T(1)));
466 Expect<T>(&reader, "bar_002", Constant_2x3<T>(T(2)));
467 Expect<T>(&reader, "bar_003", Constant_2x3<T>(T(3)));
468 Expect<T>(&reader, "foo_000", Constant_2x3<T>(T(0)));
469 Expect<T>(&reader, "foo_001", Constant_2x3<T>(T(1)));
470 Expect<T>(&reader, "foo_002", Constant_2x3<T>(T(2)));
471 Expect<T>(&reader, "foo_003", Constant_2x3<T>(T(3)));
472 }
473 {
474 BundleReader reader(Env::Default(), Prefix("merged"));
475 TF_ASSERT_OK(reader.status());
476 ExpectNext<T>(&reader, Constant_2x3<T>(T(0)));
477 ExpectNext<T>(&reader, Constant_2x3<T>(T(1)));
478 ExpectNext<T>(&reader, Constant_2x3<T>(T(2)));
479 ExpectNext<T>(&reader, Constant_2x3<T>(T(3)));
480 ExpectNext<T>(&reader, Constant_2x3<T>(T(0)));
481 ExpectNext<T>(&reader, Constant_2x3<T>(T(1)));
482 ExpectNext<T>(&reader, Constant_2x3<T>(T(2)));
483 ExpectNext<T>(&reader, Constant_2x3<T>(T(3)));
484 EXPECT_TRUE(reader.Valid());
485 reader.Next();
486 EXPECT_FALSE(reader.Valid());
487 }
488 }
489
490 template <typename T>
TestNonStandardShapes()491 void TestNonStandardShapes() {
492 {
493 BundleWriter writer(Env::Default(), Prefix("nonstandard"));
494 TF_EXPECT_OK(writer.Add("scalar", Constant(T(0), TensorShape())));
495 TF_EXPECT_OK(
496 writer.Add("non_standard0", Constant(T(0), TensorShape({0, 1618}))));
497 TF_EXPECT_OK(
498 writer.Add("non_standard1", Constant(T(0), TensorShape({16, 0, 18}))));
499 TF_ASSERT_OK(writer.Finish());
500 }
501 {
502 BundleReader reader(Env::Default(), Prefix("nonstandard"));
503 TF_ASSERT_OK(reader.status());
504 Expect<T>(&reader, "scalar", Constant(T(0), TensorShape()));
505 Expect<T>(&reader, "non_standard0", Constant(T(0), TensorShape({0, 1618})));
506 Expect<T>(&reader, "non_standard1",
507 Constant(T(0), TensorShape({16, 0, 18})));
508 }
509 }
510
511 // Writes a bundle to disk with a bad "version"; checks for "expected_error".
VersionTest(const VersionDef & version,StringPiece expected_error)512 void VersionTest(const VersionDef& version, StringPiece expected_error) {
513 const string path = Prefix("version_test");
514 {
515 // Prepare an empty bundle with the given version information.
516 BundleHeaderProto header;
517 *header.mutable_version() = version;
518
519 // Write the metadata file to disk.
520 std::unique_ptr<WritableFile> file;
521 TF_ASSERT_OK(Env::Default()->NewWritableFile(MetaFilename(path), &file));
522 table::TableBuilder builder(table::Options(), file.get());
523 builder.Add(kHeaderEntryKey, header.SerializeAsString());
524 TF_ASSERT_OK(builder.Finish());
525 }
526 // Read it back in and verify that we get the expected error.
527 BundleReader reader(Env::Default(), path);
528 EXPECT_TRUE(errors::IsInvalidArgument(reader.status()));
529 EXPECT_TRUE(
530 absl::StartsWith(reader.status().error_message(), expected_error));
531 }
532
533 } // namespace
534
// Runs the basic write/read/merge scenario once per supported element type.
TEST(TensorBundleTest, Basic) {
  // Floating point.
  TestBasic<float>();
  TestBasic<double>();

  // Small and 32-bit integers.
  TestBasic<int32>();
  TestBasic<uint8>();
  TestBasic<int16>();
  TestBasic<int8>();

  // Complex.
  TestBasic<complex64>();
  TestBasic<complex128>();

  // 64-bit integer and boolean.
  TestBasic<int64_t>();
  TestBasic<bool>();

  // Quantized.
  TestBasic<qint32>();
  TestBasic<quint8>();
  TestBasic<qint8>();

  // bfloat16.
  TestBasic<bfloat16>();
}
551
// Runs the alternate-endianness scenario once per supported element type.
TEST(TensorBundleTest, Endianness) {
  // Floating point.
  TestEndianness<float>();
  TestEndianness<double>();

  // Small and 32-bit integers.
  TestEndianness<int32>();
  TestEndianness<uint8>();
  TestEndianness<int16>();
  TestEndianness<int8>();

  // Complex.
  TestEndianness<complex64>();
  TestEndianness<complex128>();

  // 64-bit integer and boolean.
  TestEndianness<int64_t>();
  TestEndianness<bool>();

  // Quantized.
  TestEndianness<qint32>();
  TestEndianness<quint8>();
  TestEndianness<qint8>();

  // bfloat16.
  TestEndianness<bfloat16>();
}
568
// Stores a 5x10 float variable "foo" as two column-wise slices and checks
// that it can be read back in full, as its stored slices, and as new slices
// that cut across or fall inside the stored ones.
TEST(TensorBundleTest, PartitionedVariables) {
  const TensorShape kFullShape({5, 10});
  // Adds two slices.
  // First slice: column 0, all zeros.
  // Second slice: column 1 to rest, all ones.
  TensorSlice slice1 = TensorSlice::ParseOrDie("-:0,1");
  TensorSlice slice2 = TensorSlice::ParseOrDie("-:1,9");
  {
    BundleWriter writer(Env::Default(), Prefix("foo"));

    TF_ASSERT_OK(writer.AddSlice("foo", kFullShape, slice1,
                                 Constant<float>(0., TensorShape({5, 1}))));
    TF_ASSERT_OK(writer.AddSlice("foo", kFullShape, slice2,
                                 Constant<float>(1., TensorShape({5, 9}))));
    TF_ASSERT_OK(writer.Finish());
  }
  // Reads in full.
  {
    BundleReader reader(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(reader.status());

    Tensor expected_val(DT_FLOAT, kFullShape);
    test::FillFn<float>(&expected_val, [](int offset) -> float {
      if (offset % 10 == 0) {
        return 0;  // First column zeros.
      }
      return 1;  // Other columns ones.
    });

    Tensor val(DT_FLOAT, kFullShape);
    TF_ASSERT_OK(reader.Lookup("foo", &val));
    test::ExpectTensorEqual<float>(val, expected_val);
  }
  // Reads all slices.
  {
    BundleReader reader(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(reader.status());

    std::vector<TensorSlice> slices;
    TF_ASSERT_OK(reader.LookupTensorSlices("foo", &slices));

    // Fatal on mismatch: the element accesses below would otherwise index
    // past the end of `slices`.
    ASSERT_EQ(2u, slices.size());
    EXPECT_EQ(slice1.DebugString(), slices[0].DebugString());
    EXPECT_EQ(slice2.DebugString(), slices[1].DebugString());
  }
  // Reads a slice consisting of first two columns, "cutting" both slices.
  {
    BundleReader reader(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(reader.status());

    // First two columns, "cutting" both slices.
    const TensorSlice distinct_slice = TensorSlice::ParseOrDie("-:0,2");
    Tensor expected_val(DT_FLOAT, TensorShape({5, 2}));
    test::FillFn<float>(&expected_val, [](int offset) -> float {
      if (offset % 2 == 0) {
        return 0;  // First column zeros.
      }
      return 1;  // Other columns ones.
    });

    Tensor val(DT_FLOAT, TensorShape({5, 2}));
    TF_ASSERT_OK(reader.LookupSlice("foo", distinct_slice, &val));
    test::ExpectTensorEqual<float>(val, expected_val);
  }
  // Reads a slice consisting of columns 2-4, "cutting" the second slice only.
  {
    BundleReader reader(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(reader.status());

    const TensorSlice distinct_slice = TensorSlice::ParseOrDie("-:2,2");
    Tensor val(DT_FLOAT, TensorShape({5, 2}));
    TF_ASSERT_OK(reader.LookupSlice("foo", distinct_slice, &val));
    test::ExpectTensorEqual<float>(val,
                                   Constant<float>(1., TensorShape({5, 2})));
  }
}
645
// Stores the same full 5x10 tensor twice, once under a slice spec without
// extents ("-:-") and once with explicit extents ("0,5:0,10"), and checks
// that each stored entry can be read back using either spelling.
TEST(TensorBundleTest, EquivalentSliceTest) {
  const TensorShape kFullShape({5, 10});
  const Tensor kExpected(Constant<float>(1., kFullShape));
  {
    BundleWriter writer(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(writer.AddSlice("no_extents", kFullShape,
                                 TensorSlice::ParseOrDie("-:-"), kExpected));
    TF_ASSERT_OK(writer.AddSlice("both_extents", kFullShape,
                                 TensorSlice::ParseOrDie("0,5:0,10"),
                                 kExpected));
    TF_ASSERT_OK(writer.Finish());
  }

  // Each case pairs a stored key with the slice spec used to request it.
  struct LookupCase {
    const char* stored_key;
    const char* requested_spec;
  };
  const LookupCase kCases[] = {
      // Slices match exactly and are fully abbreviated.
      {"no_extents", "-:-"},
      // Slices match exactly and are fully specified.
      {"both_extents", "0,5:0,10"},
      // Stored slice has no extents, spec has extents.
      {"no_extents", "0,5:0,10"},
      // Stored slice has both extents, spec has no extents.
      {"both_extents", "-:-"},
  };

  for (const LookupCase& lookup : kCases) {
    // A fresh reader per case, as each case is checked independently.
    BundleReader reader(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(reader.status());
    const TensorSlice slice = TensorSlice::ParseOrDie(lookup.requested_spec);
    Tensor val(DT_FLOAT, TensorShape(kFullShape));
    TF_ASSERT_OK(reader.LookupSlice(lookup.stored_key, slice, &val));
    test::ExpectTensorEqual<float>(val, kExpected);
  }
}
695
// Runs the non-standard-shape round trip once per supported element type.
TEST(TensorBundleTest, NonStandardShapes) {
  // Floating point.
  TestNonStandardShapes<float>();
  TestNonStandardShapes<double>();

  // Small and 32-bit integers.
  TestNonStandardShapes<int32>();
  TestNonStandardShapes<uint8>();
  TestNonStandardShapes<int16>();
  TestNonStandardShapes<int8>();

  // Complex.
  TestNonStandardShapes<complex64>();
  TestNonStandardShapes<complex128>();

  // 64-bit integer and boolean.
  TestNonStandardShapes<int64_t>();
  TestNonStandardShapes<bool>();

  // Quantized.
  TestNonStandardShapes<qint32>();
  TestNonStandardShapes<quint8>();
  TestNonStandardShapes<qint8>();

  // bfloat16.
  TestNonStandardShapes<bfloat16>();
}
712
// Reads a checked-in string tensor bundle produced by a previous version of
// the code, which used varint32s to store string lengths (varint64s are used
// now), and verifies its contents.
TEST(TensorBundleTest, StringTensorsOldFormat) {
  const string bundle_prefix = TestdataPrefix("old_string_tensors/foo");
  BundleReader reader(Env::Default(), bundle_prefix);
  TF_ASSERT_OK(reader.status());

  const std::vector<string> expected_keys = {"floats", "scalar",
                                             "string_tensor", "strs"};
  EXPECT_EQ(AllTensorKeys(&reader), expected_keys);

  // A one-element string tensor left at its default (empty) value.
  Expect<tstring>(&reader, "string_tensor",
                  Tensor(DT_STRING, TensorShape({1})));
  Expect<tstring>(&reader, "scalar", test::AsTensor<tstring>({"hello"}));
  // Mix of short, empty, and 1 KiB strings.
  Expect<tstring>(
      &reader, "strs",
      test::AsTensor<tstring>({"hello", "", "x01", string(1 << 10, 'c')}));
  Expect<float>(&reader, "floats", Constant_2x3<float>(16.18));
}
729
// Returns the system's memory page size in bytes. Adapted from absl code.
//
// On Windows this is the larger of the page size and the allocation
// granularity. Elsewhere it is whatever the platform query reports; if that
// query fails (sysconf() returns -1), a conventional 4096 is returned rather
// than letting the error value wrap around to a huge size_t.
size_t GetPageSize() {
#ifdef _WIN32
  SYSTEM_INFO system_info;
  GetSystemInfo(&system_info);
  return std::max(system_info.dwPageSize, system_info.dwAllocationGranularity);
#elif defined(__wasm__) || defined(__asmjs__)
  return getpagesize();
#else
  const long page_size = sysconf(_SC_PAGESIZE);
  return page_size > 0 ? static_cast<size_t>(page_size) : 4096;
#endif
}
742
// Round-trips string tensors through a bundle, including a single string
// whose length does not fit in 32 bits, mixed with a float tensor.
TEST(TensorBundleTest, StringTensors) {
  constexpr size_t kLongLength = static_cast<size_t>(UINT32_MAX) + 1;
  Tensor long_string_tensor(DT_STRING, TensorShape({1}));

  {
    BundleWriter writer(Env::Default(), Prefix("foo"));
    TF_EXPECT_OK(writer.Add("string_tensor",
                            Tensor(DT_STRING, TensorShape({1}))));  // Empty.
    TF_EXPECT_OK(writer.Add("scalar", test::AsTensor<tstring>({"hello"})));
    TF_EXPECT_OK(writer.Add(
        "strs",
        test::AsTensor<tstring>({"hello", "", "x01", string(1 << 25, 'c')})));

    // Requires a 64-bit length.
    tstring* backing_string = long_string_tensor.flat<tstring>().data();
    backing_string->resize_uninitialized(kLongLength);
    std::char_traits<char>::assign(backing_string->data(), kLongLength, 'd');
    TF_EXPECT_OK(writer.Add("long_scalar", long_string_tensor));

    // Mixes in some floats.
    TF_EXPECT_OK(writer.Add("floats", Constant_2x3<float>(16.18)));
    TF_ASSERT_OK(writer.Finish());
  }
  {
    BundleReader reader(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(reader.status());
    EXPECT_EQ(AllTensorKeys(&reader),
              std::vector<string>({"floats", "long_scalar", "scalar",
                                   "string_tensor", "strs"}));

    Expect<tstring>(&reader, "string_tensor",
                    Tensor(DT_STRING, TensorShape({1})));
    Expect<tstring>(&reader, "scalar", test::AsTensor<tstring>({"hello"}));
    Expect<tstring>(
        &reader, "strs",
        test::AsTensor<tstring>({"hello", "", "x01", string(1 << 25, 'c')}));

    Expect<float>(&reader, "floats", Constant_2x3<float>(16.18));

    // We don't use the Expect function so we can re-use the
    // `long_string_tensor` buffer for reading out long_scalar to keep memory
    // usage reasonable.
    EXPECT_TRUE(reader.Contains("long_scalar"));
    DataType dtype;
    TensorShape shape;
    TF_ASSERT_OK(reader.LookupDtypeAndShape("long_scalar", &dtype, &shape));
    EXPECT_EQ(DT_STRING, dtype);
    EXPECT_EQ(TensorShape({1}), shape);

    // Fill the string differently so that we can be sure the new one is read
    // in. The existing 4GB buffer is deliberately reused rather than freed
    // and reallocated here, to avoid allocator fragmentation with such a
    // large block.
    tstring* backing_string = long_string_tensor.flat<tstring>().data();
    std::char_traits<char>::assign(backing_string->data(), kLongLength, 'e');

    // Read long_scalar and check it contains kLongLength 'd's.
    TF_ASSERT_OK(reader.Lookup("long_scalar", &long_string_tensor));
    ASSERT_EQ(backing_string, long_string_tensor.flat<tstring>().data());
    // Fatal on mismatch: the comparison loop below reads kLongLength bytes.
    ASSERT_EQ(kLongLength, backing_string->length());

    // Compare one page-sized chunk at a time against a block of 'd's. The
    // reference block is owned by a vector so it is released even when FAIL()
    // returns early from the test.
    const size_t kPageSize = GetPageSize();
    ASSERT_GT(kPageSize, 0u);
    const std::vector<char> testblock(kPageSize, 'd');
    for (size_t i = 0; i < kLongLength; i += kPageSize) {
      // Never compare past the end of the string on the final chunk.
      const size_t remaining = kLongLength - i;
      const size_t chunk = remaining < kPageSize ? remaining : kPageSize;
      if (memcmp(testblock.data(), backing_string->data() + i, chunk) != 0) {
        FAIL() << "long_scalar is not full of 'd's as expected.";
      }
    }
  }
}
816
817 class VariantObject {
818 public:
VariantObject()819 VariantObject() {}
VariantObject(const string & metadata,int64_t value)820 VariantObject(const string& metadata, int64_t value)
821 : metadata_(metadata), value_(value) {}
822
TypeName() const823 string TypeName() const { return "TEST VariantObject"; }
Encode(VariantTensorData * data) const824 void Encode(VariantTensorData* data) const {
825 data->set_type_name(TypeName());
826 data->set_metadata(metadata_);
827 Tensor val_t = Tensor(DT_INT64, TensorShape({}));
828 val_t.scalar<int64_t>()() = value_;
829 *(data->add_tensors()) = val_t;
830 }
Decode(const VariantTensorData & data)831 bool Decode(const VariantTensorData& data) {
832 EXPECT_EQ(data.type_name(), TypeName());
833 data.get_metadata(&metadata_);
834 EXPECT_EQ(data.tensors_size(), 1);
835 value_ = data.tensors(0).scalar<int64_t>()();
836 return true;
837 }
operator ==(const VariantObject other) const838 bool operator==(const VariantObject other) const {
839 return metadata_ == other.metadata_ && value_ == other.value_;
840 }
841 string metadata_;
842 int64_t value_;
843 };
844
845 REGISTER_UNARY_VARIANT_DECODE_FUNCTION(VariantObject, "TEST VariantObject");
846
// Writes a tensor of two VariantObject elements to a bundle and expects to
// read back an equal tensor.
TEST(TensorBundleTest, VariantTensors) {
  // Builds a fresh copy of the tensor used for both writing and comparison.
  const auto MakeVariants = []() {
    return test::AsTensor<Variant>(
        {VariantObject("test", 10), VariantObject("test1", 20)});
  };

  {
    BundleWriter writer(Env::Default(), Prefix("foo"));
    TF_EXPECT_OK(writer.Add("variant_tensor", MakeVariants()));
    TF_ASSERT_OK(writer.Finish());
  }
  {
    BundleReader reader(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(reader.status());
    ExpectVariant<VariantObject>(&reader, "variant_tensor", MakeVariants());
  }
}
865
// Checks which files exist on disk after writing two bundles, after a
// single-input merge, and after a two-input merge.
TEST(TensorBundleTest, DirectoryStructure) {
  Env* env = Env::Default();

  // Write two bundles, "worker0" and "worker1", each holding one tensor named
  // "tensor<i>".
  const std::vector<string> kBundlePrefixes = {Prefix("worker0"),
                                               Prefix("worker1")};
  int worker = 0;
  for (const string& worker_prefix : kBundlePrefixes) {
    BundleWriter writer(env, worker_prefix);
    TF_EXPECT_OK(writer.Add(strings::StrCat("tensor", worker),
                            Constant_2x3<float>(0.)));
    TF_ASSERT_OK(writer.Finish());
    ++worker;
  }

  // Expects every name in `basenames` to exist in the directory that
  // contains `bundle_prefix`.
  auto CheckDirFiles = [env](const string& bundle_prefix,
                             gtl::ArraySlice<string> basenames) {
    const StringPiece parent_dir = io::Dirname(bundle_prefix);
    for (const string& basename : basenames) {
      const string path = io::JoinPath(parent_dir, basename);
      TF_EXPECT_OK(env->FileExists(path));
    }
  };

  // Each written bundle should consist of:
  //   worker<i>.index
  //   worker<i>.data-00000-of-00001
  CheckDirFiles(kBundlePrefixes[0],
                {"worker0.index", "worker0.data-00000-of-00001"});
  CheckDirFiles(kBundlePrefixes[1],
                {"worker1.index", "worker1.data-00000-of-00001"});

  // A "merge" of a single bundle into a new prefix amounts to a renaming.
  const string kAnotherPrefix = Prefix("another");
  TF_ASSERT_OK(MergeBundles(env, {kBundlePrefixes[0]}, kAnotherPrefix));
  CheckDirFiles(kAnotherPrefix,
                {"another.index", "another.data-00000-of-00001"});

  // A real merge of two bundles should produce:
  //   merged.index
  //   merged.data-00000-of-00002
  //   merged.data-00001-of-00002
  const string kMerged = Prefix("merged");
  TF_ASSERT_OK(
      MergeBundles(env, {kAnotherPrefix, kBundlePrefixes[1]}, kMerged));
  CheckDirFiles(kMerged, {"merged.index", "merged.data-00000-of-00002",
                          "merged.data-00001-of-00002"});
}
911
// Merges two bundles and expects SortForSequentialAccess to reorder a
// shuffled list of tensor names into shard-by-shard, in-write-order sequence.
TEST(TensorBundleTest, SortForSequentialAccess) {
  Env* env = Env::Default();
  const std::vector<string> kBundlePrefixes = {Prefix("worker0"),
                                               Prefix("worker1")};

  // First bundle: tensors added in ascending index order.
  BundleWriter writer0(env, kBundlePrefixes[0]);
  for (const int idx : {0, 1, 2}) {
    const string name = strings::StrCat("tensor-0-", idx);
    TF_EXPECT_OK(writer0.Add(name, Constant_2x3<float>(0.)));
  }
  TF_ASSERT_OK(writer0.Finish());

  // Second bundle: tensors added in descending index order.
  BundleWriter writer1(env, kBundlePrefixes[1]);
  for (const int idx : {2, 1, 0}) {
    const string name = strings::StrCat("tensor-1-", idx);
    TF_EXPECT_OK(writer1.Add(name, Constant_2x3<float>(0.)));
  }
  TF_ASSERT_OK(writer1.Finish());

  const string kMerged = Prefix("merged");
  TF_ASSERT_OK(
      MergeBundles(env, {kBundlePrefixes[0], kBundlePrefixes[1]}, kMerged));

  // We now have:
  //   merged.data-00000-of-00002 with tensor-0-0, tensor-0-1, tensor-0-2
  //   merged.data-00001-of-00002 with tensor-1-2, tensor-1-1, tensor-1-0

  BundleReader reader(env, kMerged);
  TF_ASSERT_OK(reader.status());

  // Start from an interleaved order; each element is its own lookup key.
  std::vector<string> tensor_names = {"tensor-1-0", "tensor-0-1", "tensor-1-2",
                                      "tensor-0-0", "tensor-1-1", "tensor-0-2"};
  const auto key_of = [](const string& name) { return name; };
  TF_ASSERT_OK(reader.SortForSequentialAccess<string>(tensor_names, key_of));
  EXPECT_THAT(tensor_names,
              ElementsAre("tensor-0-0", "tensor-0-1", "tensor-0-2",
                          "tensor-1-2", "tensor-1-1", "tensor-1-0"));
}
948
TEST(TensorBundleTest,Error)949 TEST(TensorBundleTest, Error) {
950 { // Dup keys.
951 BundleWriter writer(Env::Default(), Prefix("dup"));
952 TF_EXPECT_OK(writer.Add("foo", Constant_2x3(1.f)));
953 EXPECT_FALSE(writer.Add("foo", Constant_2x3(2.f)).ok());
954 EXPECT_TRUE(absl::StrContains(writer.status().ToString(), "duplicate key"));
955 EXPECT_FALSE(writer.Finish().ok());
956 }
957 { // Double finish
958 BundleWriter writer(Env::Default(), Prefix("bad"));
959 EXPECT_TRUE(writer.Finish().ok());
960 EXPECT_FALSE(writer.Finish().ok());
961 }
962 { // Not found.
963 BundleReader reader(Env::Default(), Prefix("nonexist"));
964 EXPECT_EQ(reader.status().code(), error::NOT_FOUND);
965 }
966 }
967
// Corrupts bytes of a bundle's data file and expects lookups to fail with a
// data-loss error whose message names the checksum that caught it.
TEST(TensorBundleTest, Checksum) {
  // Inverts all bits of one byte of the (single-shard) data file of the
  // bundle `prefix`. With exact_pos == true the byte at pos_lhs is flipped;
  // otherwise a byte in [pos_lhs, end of data file) is picked with a
  // default-seeded generator, so the pick is the same on every run.
  auto FlipByte = [](const string& prefix, int pos_lhs,
                     bool exact_pos = false) {
    DCHECK_GE(pos_lhs, 0);
    const string datafile = DataFilename(Prefix(prefix), 0, 1);
    string data;
    TF_ASSERT_OK(ReadFileToString(Env::Default(), datafile, &data));
    // The position must lie inside the file: otherwise the distribution
    // range below would be invalid and the byte write would be out of
    // bounds.
    ASSERT_LT(static_cast<size_t>(pos_lhs), data.size());

    int byte_pos = 0;
    if (!exact_pos) {
      std::mt19937 rng;
      std::uniform_int_distribution<int> dist(pos_lhs, data.size() - 1);
      byte_pos = dist(rng);
    } else {
      byte_pos = pos_lhs;
    }
    data[byte_pos] = ~data[byte_pos];
    TF_ASSERT_OK(WriteStringToFile(Env::Default(), datafile, data));
  };
  // Looks up `key` in the bundle `prefix` into `val` and expects a data-loss
  // status whose text contains `expected_msg`.
  auto ExpectLookupFails = [](const string& prefix, const string& key,
                              const string& expected_msg, Tensor& val) {
    BundleReader reader(Env::Default(), Prefix(prefix));
    Status status = reader.Lookup(key, &val);
    EXPECT_TRUE(errors::IsDataLoss(status));
    EXPECT_TRUE(absl::StrContains(status.ToString(), expected_msg));
  };

  // Corrupts a float tensor.
  {
    BundleWriter writer(Env::Default(), Prefix("singleton"));
    TF_EXPECT_OK(writer.Add("foo", Constant_2x3(1.f)));
    TF_ASSERT_OK(writer.Finish());

    FlipByte("singleton", 0 /* corrupts any byte */);
    Tensor val(DT_FLOAT, TensorShape({2, 3}));
    ExpectLookupFails("singleton", "foo",
                      "Checksum does not match" /* expected fail msg */, val);
  }
  // Corrupts a string tensor.
  {
    // (Re)writes a bundle holding one two-element string tensor, so every
    // corruption below starts from a clean file.
    auto WriteStrings = []() {
      BundleWriter writer(Env::Default(), Prefix("strings"));
      TF_EXPECT_OK(
          writer.Add("foo", test::AsTensor<tstring>({"hello", "world"})));
      TF_ASSERT_OK(writer.Finish());
    };
    // Corrupts the first two bytes, which are the varint32-encoded lengths
    // of the two string elements. Should hit mismatch on length cksum.
    for (int i = 0; i < 2; ++i) {
      WriteStrings();
      FlipByte("strings", i, true /* corrupts exactly byte i */);
      Tensor val(DT_STRING, TensorShape({2}));
      ExpectLookupFails(
          "strings", "foo",
          "length checksum does not match" /* expected fail msg */, val);
    }
    // Corrupts the string bytes, should hit an overall cksum mismatch.
    WriteStrings();
    FlipByte("strings", 2 /* corrupts starting from byte 2 */);
    Tensor val(DT_STRING, TensorShape({2}));
    ExpectLookupFails("strings", "foo",
                      "Checksum does not match" /* expected fail msg */, val);
  }
}
1035
// Drops the last byte of a bundle's data file and expects the lookup of the
// tensor stored there to report an out-of-range error.
TEST(TensorBundleTest, TruncatedTensorContents) {
  Env* env = Env::Default();
  const string bundle_prefix = Prefix("end");

  BundleWriter writer(env, bundle_prefix);
  TF_EXPECT_OK(writer.Add("key", Constant_2x3<float>(1.0)));
  TF_ASSERT_OK(writer.Finish());

  // Rewrite the data file without its final byte, so that reading the tensor
  // hits EOF.
  const string datafile = DataFilename(bundle_prefix, 0, 1);
  string contents;
  TF_ASSERT_OK(ReadFileToString(env, datafile, &contents));
  ASSERT_FALSE(contents.empty());
  const StringPiece all_but_last(contents.data(), contents.size() - 1);
  TF_ASSERT_OK(WriteStringToFile(env, datafile, all_but_last));

  // Opening the bundle is still expected to succeed; only the lookup fails.
  BundleReader reader(env, bundle_prefix);
  TF_ASSERT_OK(reader.status());
  Tensor val(DT_FLOAT, TensorShape({2, 3}));
  const Status lookup_status = reader.Lookup("key", &val);
  EXPECT_TRUE(errors::IsOutOfRange(lookup_status));
}
1055
// Writes a one-tensor bundle, parses its header entry, and checks the shard
// count, endianness and version fields.
TEST(TensorBundleTest, HeaderEntry) {
  const string bundle_prefix = Prefix("b");
  {
    BundleWriter writer(Env::Default(), bundle_prefix);
    TF_EXPECT_OK(writer.Add("key", Constant_2x3<float>(1.0)));
    TF_ASSERT_OK(writer.Finish());
  }

  // Position the reader on the header key and parse the value found there.
  BundleHeaderProto header;
  {
    BundleReader reader(Env::Default(), bundle_prefix);
    TF_ASSERT_OK(reader.status());
    reader.Seek(kHeaderEntryKey);
    ASSERT_TRUE(reader.Valid());
    const auto serialized_header = reader.value();
    ASSERT_TRUE(ParseProtoUnlimited(&header, serialized_header.data(),
                                    serialized_header.size()));
  }

  // num_shards
  EXPECT_EQ(1, header.num_shards());

  // endianness: must match the byte order of the host that wrote the bundle.
  const auto expected_endianness = port::kLittleEndian
                                       ? BundleHeaderProto::LITTLE
                                       : BundleHeaderProto::BIG;
  EXPECT_EQ(expected_endianness, header.endianness());

  // version
  EXPECT_GT(kTensorBundleVersion, 0);
  EXPECT_EQ(kTensorBundleVersion, header.version().producer());
  EXPECT_EQ(kTensorBundleMinConsumer, header.version().min_consumer());
}
1087
// Feeds three incompatible VersionDefs to the VersionTest helper, each with
// the error message that is expected for it.
TEST(TensorBundleTest, VersionTest) {
  // Min consumer: the checkpoint demands a consumer newer than this build.
  {
    const auto too_new = kTensorBundleVersion + 1;
    VersionDef versions;
    versions.set_producer(too_new);
    versions.set_min_consumer(too_new);
    const string expected_error = strings::StrCat(
        "Checkpoint min consumer version ", too_new,
        " above current version ", kTensorBundleVersion, " for TensorFlow");
    VersionTest(versions, expected_error);
  }
  // Min producer: the checkpoint was produced by a version that is too old.
  {
    const auto too_old = kTensorBundleMinProducer - 1;
    VersionDef versions;
    versions.set_producer(too_old);
    const string expected_error = strings::StrCat(
        "Checkpoint producer version ", too_old, " below min producer ",
        kTensorBundleMinProducer, " supported by TensorFlow");
    VersionTest(versions, expected_error);
  }
  // Bad consumer: the checkpoint lists this build's version as disallowed.
  {
    VersionDef versions;
    versions.set_producer(kTensorBundleVersion + 1);
    versions.add_bad_consumers(kTensorBundleVersion);
    const string expected_error = strings::StrCat(
        "Checkpoint disallows consumer version ", kTensorBundleVersion,
        ". Please upgrade TensorFlow: this version is likely buggy.");
    VersionTest(versions, expected_error);
  }
}
1122
1123 class TensorBundleAlignmentTest : public ::testing::Test {
1124 protected:
1125 template <typename T>
ExpectAlignment(BundleReader * reader,const string & key,int alignment)1126 void ExpectAlignment(BundleReader* reader, const string& key, int alignment) {
1127 BundleEntryProto full_tensor_entry;
1128 TF_ASSERT_OK(reader->GetBundleEntryProto(key, &full_tensor_entry));
1129 EXPECT_EQ(0, full_tensor_entry.offset() % alignment);
1130 }
1131 };
1132
// Writes four tensors with a data alignment of 42, then checks key order,
// contents (by lookup and by iteration) and the alignment of every offset.
TEST_F(TensorBundleAlignmentTest, AlignmentTest) {
  constexpr int kAlignment = 42;
  // Keys in the order the reader is expected to list them; the tensor stored
  // under kSortedKeys[v] is filled with the value v.
  const std::vector<string> kSortedKeys = {"foo_000", "foo_001", "foo_002",
                                           "foo_003"};
  {
    BundleWriter::Options opts;
    opts.data_alignment = kAlignment;
    BundleWriter writer(Env::Default(), Prefix("foo"), opts);
    // Add order is 3, 0, 2, 1, i.e. not the sorted key order.
    for (const int v : {3, 0, 2, 1}) {
      TF_EXPECT_OK(writer.Add(kSortedKeys[v], Constant_2x3<float>(v)));
    }
    TF_ASSERT_OK(writer.Finish());
  }
  {
    // Look tensors up by key.
    BundleReader reader(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(reader.status());
    EXPECT_EQ(AllTensorKeys(&reader), kSortedKeys);
    for (int v = 0; v < 4; ++v) {
      Expect<float>(&reader, kSortedKeys[v], Constant_2x3<float>(v));
    }
  }
  {
    // Walk the tensors in iteration order; after the fourth one a single
    // further Next() must leave the reader invalid.
    BundleReader reader(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(reader.status());
    for (int v = 0; v < 4; ++v) {
      ExpectNext<float>(&reader, Constant_2x3<float>(v));
    }
    EXPECT_TRUE(reader.Valid());
    reader.Next();
    EXPECT_FALSE(reader.Valid());
  }
  {
    // Every tensor's offset must respect the requested alignment.
    BundleReader reader(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(reader.status());
    for (const string& key : kSortedKeys) {
      ExpectAlignment<float>(&reader, key, kAlignment);
    }
  }
}
1175
BM_BundleAlignment(::testing::benchmark::State & state)1176 static void BM_BundleAlignment(::testing::benchmark::State& state) {
1177 {
1178 const int alignment = state.range(0);
1179 const int tensor_size = state.range(1);
1180 BundleWriter::Options opts;
1181 opts.data_alignment = alignment;
1182 BundleWriter writer(Env::Default(), Prefix("foo"), opts);
1183 TF_CHECK_OK(writer.Add("small", Constant(true, TensorShape({1}))));
1184 TF_CHECK_OK(writer.Add("big", Constant(32.1, TensorShape({tensor_size}))));
1185 TF_CHECK_OK(writer.Finish());
1186 }
1187 BundleReader reader(Env::Default(), Prefix("foo"));
1188 TF_CHECK_OK(reader.status());
1189 for (auto s : state) {
1190 Tensor t;
1191 TF_CHECK_OK(reader.Lookup("big", &t));
1192 }
1193 }
1194
1195 BENCHMARK(BM_BundleAlignment)->ArgPair(1, 512);
1196 BENCHMARK(BM_BundleAlignment)->ArgPair(1, 4096);
1197 BENCHMARK(BM_BundleAlignment)->ArgPair(1, 1048576);
1198 BENCHMARK(BM_BundleAlignment)->ArgPair(4096, 512);
1199 BENCHMARK(BM_BundleAlignment)->ArgPair(4096, 4096);
1200 BENCHMARK(BM_BundleAlignment)->ArgPair(4096, 1048576);
1201
BM_BundleWriterSmallTensor(::testing::benchmark::State & state)1202 static void BM_BundleWriterSmallTensor(::testing::benchmark::State& state) {
1203 const int64_t bytes = state.range(0);
1204 Tensor t = Constant(static_cast<int8>('a'), TensorShape{bytes});
1205 BundleWriter writer(Env::Default(), Prefix("foo"));
1206 int suffix = 0;
1207 for (auto s : state) {
1208 TF_CHECK_OK(writer.Add(strings::StrCat("small", suffix++), t));
1209 }
1210 }
1211
1212 BENCHMARK(BM_BundleWriterSmallTensor)->Range(1, 1 << 20);
1213
BM_BundleWriterLargeTensor(::testing::benchmark::State & state)1214 static void BM_BundleWriterLargeTensor(::testing::benchmark::State& state) {
1215 const int mb = state.range(0);
1216 const int64_t bytes = static_cast<int64_t>(mb) * (1 << 20);
1217 Tensor t = Constant(static_cast<int8>('a'), TensorShape{bytes});
1218 for (auto s : state) {
1219 BundleWriter writer(Env::Default(), Prefix("foo"));
1220 TF_CHECK_OK(writer.Add("big", t));
1221 }
1222 }
1223
1224 BENCHMARK(BM_BundleWriterLargeTensor)->Arg(1 << 10);
1225 BENCHMARK(BM_BundleWriterLargeTensor)->Arg(4 << 10);
1226
1227 } // namespace tensorflow
1228