1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
17
#include <random>
#include <utility>
#include <vector>
20
21 #include "tensorflow/core/framework/tensor_testutil.h"
22 #include "tensorflow/core/framework/types.pb.h"
23 #include "tensorflow/core/framework/variant.h"
24 #include "tensorflow/core/framework/variant_op_registry.h"
25 #include "tensorflow/core/framework/versions.pb.h"
26 #include "tensorflow/core/lib/core/status_test_util.h"
27 #include "tensorflow/core/lib/io/path.h"
28 #include "tensorflow/core/lib/io/table_builder.h"
29 #include "tensorflow/core/lib/strings/str_util.h"
30 #include "tensorflow/core/lib/strings/strcat.h"
31 #include "tensorflow/core/platform/test.h"
32 #include "tensorflow/core/platform/test_benchmark.h"
33
34 namespace tensorflow {
35
36 namespace {
37
// Returns the path `<test tmp dir>/<prefix>` used as a bundle prefix.
string Prefix(const string& prefix) {
  const string tmp_dir = testing::TmpDir();
  return strings::StrCat(tmp_dir, "/", prefix);
}
41
// Returns the path of checked-in bundle test data `prefix`, located under
// core/util/tensor_bundle/testdata/ in the TensorFlow source tree.
string TestdataPrefix(const string& prefix) {
  const string src_root = testing::TensorFlowSrcRoot();
  return strings::StrCat(src_root, "/core/util/tensor_bundle/testdata/",
                         prefix);
}
46
47 template <typename T>
Constant(T v,TensorShape shape)48 Tensor Constant(T v, TensorShape shape) {
49 Tensor ret(DataTypeToEnum<T>::value, shape);
50 ret.flat<T>().setConstant(v);
51 return ret;
52 }
53
54 template <typename T>
Constant_2x3(T v)55 Tensor Constant_2x3(T v) {
56 return Constant(v, TensorShape({2, 3}));
57 }
58
59 template <typename T>
Expect(BundleReader * reader,const string & key,const Tensor & expected_val)60 void Expect(BundleReader* reader, const string& key,
61 const Tensor& expected_val) {
62 // Tests for Contains().
63 EXPECT_TRUE(reader->Contains(key));
64 // Tests for LookupDtypeAndShape().
65 DataType dtype;
66 TensorShape shape;
67 TF_ASSERT_OK(reader->LookupDtypeAndShape(key, &dtype, &shape));
68 EXPECT_EQ(expected_val.dtype(), dtype);
69 EXPECT_EQ(expected_val.shape(), shape);
70 // Tests for Lookup(), checking tensor contents.
71 Tensor val(expected_val.dtype(), shape);
72 TF_ASSERT_OK(reader->Lookup(key, &val));
73 test::ExpectTensorEqual<T>(val, expected_val);
74 }
75
76 template <class T>
ExpectVariant(BundleReader * reader,const string & key,const Tensor & expected_t)77 void ExpectVariant(BundleReader* reader, const string& key,
78 const Tensor& expected_t) {
79 // Tests for Contains().
80 EXPECT_TRUE(reader->Contains(key));
81 // Tests for LookupDtypeAndShape().
82 DataType dtype;
83 TensorShape shape;
84 TF_ASSERT_OK(reader->LookupDtypeAndShape(key, &dtype, &shape));
85 // Tests for Lookup(), checking tensor contents.
86 EXPECT_EQ(expected_t.dtype(), dtype);
87 EXPECT_EQ(expected_t.shape(), shape);
88 Tensor actual_t(dtype, shape);
89 TF_ASSERT_OK(reader->Lookup(key, &actual_t));
90 for (int i = 0; i < expected_t.NumElements(); i++) {
91 Variant actual_var = actual_t.flat<Variant>()(i);
92 Variant expected_var = expected_t.flat<Variant>()(i);
93 EXPECT_EQ(actual_var.TypeName(), expected_var.TypeName());
94 auto* actual_val = actual_var.get<T>();
95 auto* expected_val = expected_var.get<T>();
96 EXPECT_EQ(*expected_val, *actual_val);
97 }
98 }
99
100 template <typename T>
ExpectNext(BundleReader * reader,const Tensor & expected_val)101 void ExpectNext(BundleReader* reader, const Tensor& expected_val) {
102 EXPECT_TRUE(reader->Valid());
103 reader->Next();
104 TF_ASSERT_OK(reader->status());
105 Tensor val;
106 TF_ASSERT_OK(reader->ReadCurrent(&val));
107 test::ExpectTensorEqual<T>(val, expected_val);
108 }
109
AllTensorKeys(BundleReader * reader)110 std::vector<string> AllTensorKeys(BundleReader* reader) {
111 std::vector<string> ret;
112 reader->Seek(kHeaderEntryKey);
113 reader->Next();
114 for (; reader->Valid(); reader->Next()) {
115 ret.emplace_back(reader->key());
116 }
117 return ret;
118 }
119
120 // Writes out the metadata file of a bundle again, with the endianness marker
121 // bit flipped.
FlipEndiannessBit(const string & prefix)122 Status FlipEndiannessBit(const string& prefix) {
123 Env* env = Env::Default();
124 const string metadata_tmp_path = Prefix("some_tmp_path");
125 std::unique_ptr<WritableFile> file;
126 TF_RETURN_IF_ERROR(env->NewWritableFile(metadata_tmp_path, &file));
127 table::TableBuilder builder(table::Options(), file.get());
128
129 // Reads the existing metadata file, and fills the builder.
130 {
131 const string filename = MetaFilename(prefix);
132 uint64 file_size;
133 TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));
134 std::unique_ptr<RandomAccessFile> file;
135 TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));
136
137 table::Table* table = nullptr;
138 TF_RETURN_IF_ERROR(
139 table::Table::Open(table::Options(), file.get(), file_size, &table));
140 std::unique_ptr<table::Table> table_deleter(table);
141 std::unique_ptr<table::Iterator> iter(table->NewIterator());
142
143 // Reads the header entry.
144 iter->Seek(kHeaderEntryKey);
145 CHECK(iter->Valid());
146 BundleHeaderProto header;
147 CHECK(header.ParseFromArray(iter->value().data(), iter->value().size()));
148 // Flips the endianness.
149 if (header.endianness() == BundleHeaderProto::LITTLE) {
150 header.set_endianness(BundleHeaderProto::BIG);
151 } else {
152 header.set_endianness(BundleHeaderProto::LITTLE);
153 }
154 builder.Add(iter->key(), header.SerializeAsString());
155 iter->Next();
156
157 // Adds the non-header entries unmodified.
158 for (; iter->Valid(); iter->Next()) builder.Add(iter->key(), iter->value());
159 }
160 TF_RETURN_IF_ERROR(builder.Finish());
161 TF_RETURN_IF_ERROR(env->RenameFile(metadata_tmp_path, MetaFilename(prefix)));
162 return file->Close();
163 }
164
165 template <typename T>
TestBasic()166 void TestBasic() {
167 {
168 BundleWriter writer(Env::Default(), Prefix("foo"));
169 TF_EXPECT_OK(writer.Add("foo_003", Constant_2x3<T>(3)));
170 TF_EXPECT_OK(writer.Add("foo_000", Constant_2x3<T>(0)));
171 TF_EXPECT_OK(writer.Add("foo_002", Constant_2x3<T>(2)));
172 TF_EXPECT_OK(writer.Add("foo_001", Constant_2x3<T>(1)));
173 TF_ASSERT_OK(writer.Finish());
174 }
175 {
176 BundleReader reader(Env::Default(), Prefix("foo"));
177 TF_ASSERT_OK(reader.status());
178 EXPECT_EQ(
179 AllTensorKeys(&reader),
180 std::vector<string>({"foo_000", "foo_001", "foo_002", "foo_003"}));
181 Expect<T>(&reader, "foo_000", Constant_2x3<T>(0));
182 Expect<T>(&reader, "foo_001", Constant_2x3<T>(1));
183 Expect<T>(&reader, "foo_002", Constant_2x3<T>(2));
184 Expect<T>(&reader, "foo_003", Constant_2x3<T>(3));
185 }
186 {
187 BundleReader reader(Env::Default(), Prefix("foo"));
188 TF_ASSERT_OK(reader.status());
189 ExpectNext<T>(&reader, Constant_2x3<T>(0));
190 ExpectNext<T>(&reader, Constant_2x3<T>(1));
191 ExpectNext<T>(&reader, Constant_2x3<T>(2));
192 ExpectNext<T>(&reader, Constant_2x3<T>(3));
193 EXPECT_TRUE(reader.Valid());
194 reader.Next();
195 EXPECT_FALSE(reader.Valid());
196 }
197 {
198 BundleWriter writer(Env::Default(), Prefix("bar"));
199 TF_EXPECT_OK(writer.Add("bar_003", Constant_2x3<T>(3)));
200 TF_EXPECT_OK(writer.Add("bar_000", Constant_2x3<T>(0)));
201 TF_EXPECT_OK(writer.Add("bar_002", Constant_2x3<T>(2)));
202 TF_EXPECT_OK(writer.Add("bar_001", Constant_2x3<T>(1)));
203 TF_ASSERT_OK(writer.Finish());
204 }
205 {
206 BundleReader reader(Env::Default(), Prefix("bar"));
207 TF_ASSERT_OK(reader.status());
208 EXPECT_EQ(
209 AllTensorKeys(&reader),
210 std::vector<string>({"bar_000", "bar_001", "bar_002", "bar_003"}));
211 Expect<T>(&reader, "bar_003", Constant_2x3<T>(3));
212 Expect<T>(&reader, "bar_002", Constant_2x3<T>(2));
213 Expect<T>(&reader, "bar_001", Constant_2x3<T>(1));
214 Expect<T>(&reader, "bar_000", Constant_2x3<T>(0));
215 }
216 {
217 BundleReader reader(Env::Default(), Prefix("bar"));
218 TF_ASSERT_OK(reader.status());
219 ExpectNext<T>(&reader, Constant_2x3<T>(0));
220 ExpectNext<T>(&reader, Constant_2x3<T>(1));
221 ExpectNext<T>(&reader, Constant_2x3<T>(2));
222 ExpectNext<T>(&reader, Constant_2x3<T>(3));
223 EXPECT_TRUE(reader.Valid());
224 reader.Next();
225 EXPECT_FALSE(reader.Valid());
226 }
227 TF_ASSERT_OK(MergeBundles(Env::Default(), {Prefix("foo"), Prefix("bar")},
228 Prefix("merged")));
229 {
230 BundleReader reader(Env::Default(), Prefix("merged"));
231 TF_ASSERT_OK(reader.status());
232 EXPECT_EQ(
233 AllTensorKeys(&reader),
234 std::vector<string>({"bar_000", "bar_001", "bar_002", "bar_003",
235 "foo_000", "foo_001", "foo_002", "foo_003"}));
236 Expect<T>(&reader, "bar_000", Constant_2x3<T>(0));
237 Expect<T>(&reader, "bar_001", Constant_2x3<T>(1));
238 Expect<T>(&reader, "bar_002", Constant_2x3<T>(2));
239 Expect<T>(&reader, "bar_003", Constant_2x3<T>(3));
240 Expect<T>(&reader, "foo_000", Constant_2x3<T>(0));
241 Expect<T>(&reader, "foo_001", Constant_2x3<T>(1));
242 Expect<T>(&reader, "foo_002", Constant_2x3<T>(2));
243 Expect<T>(&reader, "foo_003", Constant_2x3<T>(3));
244 }
245 {
246 BundleReader reader(Env::Default(), Prefix("merged"));
247 TF_ASSERT_OK(reader.status());
248 ExpectNext<T>(&reader, Constant_2x3<T>(0));
249 ExpectNext<T>(&reader, Constant_2x3<T>(1));
250 ExpectNext<T>(&reader, Constant_2x3<T>(2));
251 ExpectNext<T>(&reader, Constant_2x3<T>(3));
252 ExpectNext<T>(&reader, Constant_2x3<T>(0));
253 ExpectNext<T>(&reader, Constant_2x3<T>(1));
254 ExpectNext<T>(&reader, Constant_2x3<T>(2));
255 ExpectNext<T>(&reader, Constant_2x3<T>(3));
256 EXPECT_TRUE(reader.Valid());
257 reader.Next();
258 EXPECT_FALSE(reader.Valid());
259 }
260 }
261
262 template <typename T>
TestNonStandardShapes()263 void TestNonStandardShapes() {
264 {
265 BundleWriter writer(Env::Default(), Prefix("nonstandard"));
266 TF_EXPECT_OK(writer.Add("scalar", Constant<T>(0, TensorShape())));
267 TF_EXPECT_OK(
268 writer.Add("non_standard0", Constant<T>(0, TensorShape({0, 1618}))));
269 TF_EXPECT_OK(
270 writer.Add("non_standard1", Constant<T>(0, TensorShape({16, 0, 18}))));
271 TF_ASSERT_OK(writer.Finish());
272 }
273 {
274 BundleReader reader(Env::Default(), Prefix("nonstandard"));
275 TF_ASSERT_OK(reader.status());
276 Expect<T>(&reader, "scalar", Constant<T>(0, TensorShape()));
277 Expect<T>(&reader, "non_standard0", Constant<T>(0, TensorShape({0, 1618})));
278 Expect<T>(&reader, "non_standard1",
279 Constant<T>(0, TensorShape({16, 0, 18})));
280 }
281 }
282
283 // Writes a bundle to disk with a bad "version"; checks for "expected_error".
VersionTest(const VersionDef & version,StringPiece expected_error)284 void VersionTest(const VersionDef& version, StringPiece expected_error) {
285 const string path = Prefix("version_test");
286 {
287 // Prepare an empty bundle with the given version information.
288 BundleHeaderProto header;
289 *header.mutable_version() = version;
290
291 // Write the metadata file to disk.
292 std::unique_ptr<WritableFile> file;
293 TF_ASSERT_OK(Env::Default()->NewWritableFile(MetaFilename(path), &file));
294 table::TableBuilder builder(table::Options(), file.get());
295 builder.Add(kHeaderEntryKey, header.SerializeAsString());
296 TF_ASSERT_OK(builder.Finish());
297 }
298 // Read it back in and verify that we get the expected error.
299 BundleReader reader(Env::Default(), path);
300 EXPECT_TRUE(errors::IsInvalidArgument(reader.status()));
301 EXPECT_TRUE(
302 str_util::StartsWith(reader.status().error_message(), expected_error));
303 }
304
305 } // namespace
306
// Runs the write / lookup / iterate / merge round-trip (TestBasic) once per
// element type listed below.
TEST(TensorBundleTest, Basic) {
  TestBasic<float>();
  TestBasic<double>();
  TestBasic<int32>();
  TestBasic<uint8>();
  TestBasic<int16>();
  TestBasic<int8>();
  TestBasic<complex64>();
  TestBasic<complex128>();
  TestBasic<int64>();
  TestBasic<bool>();
  TestBasic<qint32>();
  TestBasic<quint8>();
  TestBasic<qint8>();
}
322
// Stores a 5x10 float tensor as two slices under the same key, then reads it
// back in full, as a slice list, and as sub-slices cutting one or both of the
// stored slices. Slice specs use "start,length" per dimension; "-" means the
// full extent.
TEST(TensorBundleTest, PartitionedVariables) {
  const TensorShape kFullShape({5, 10});
  // Adds two slices that together cover the full tensor.
  // First slice: column 0, all zeros.
  // Second slice: columns 1 through 9, all ones.
  TensorSlice slice1 = TensorSlice::ParseOrDie("-:0,1");
  TensorSlice slice2 = TensorSlice::ParseOrDie("-:1,9");
  {
    BundleWriter writer(Env::Default(), Prefix("foo"));

    TF_ASSERT_OK(writer.AddSlice("foo", kFullShape, slice1,
                                 Constant<float>(0., TensorShape({5, 1}))));
    TF_ASSERT_OK(writer.AddSlice("foo", kFullShape, slice2,
                                 Constant<float>(1., TensorShape({5, 9}))));
    TF_ASSERT_OK(writer.Finish());
  }
  // Reads in full.
  {
    BundleReader reader(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(reader.status());

    Tensor expected_val(DT_FLOAT, kFullShape);
    test::FillFn<float>(&expected_val, [](int offset) -> float {
      if (offset % 10 == 0) {
        return 0;  // First column zeros.
      }
      return 1;  // Other columns ones.
    });

    Tensor val(DT_FLOAT, kFullShape);
    TF_ASSERT_OK(reader.Lookup("foo", &val));
    test::ExpectTensorEqual<float>(val, expected_val);
  }
  // Reads all slices.
  {
    BundleReader reader(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(reader.status());

    std::vector<TensorSlice> slices;
    TF_ASSERT_OK(reader.LookupTensorSlices("foo", &slices));

    // Fatal: slices[0] and slices[1] are indexed below.
    ASSERT_EQ(2u, slices.size());
    EXPECT_EQ(slice1.DebugString(), slices[0].DebugString());
    EXPECT_EQ(slice2.DebugString(), slices[1].DebugString());
  }
  // Reads a slice consisting of first two columns, "cutting" both slices.
  {
    BundleReader reader(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(reader.status());

    // First two columns, "cutting" both slices.
    const TensorSlice distinct_slice = TensorSlice::ParseOrDie("-:0,2");
    Tensor expected_val(DT_FLOAT, TensorShape({5, 2}));
    test::FillFn<float>(&expected_val, [](int offset) -> float {
      if (offset % 2 == 0) {
        return 0;  // First column zeros.
      }
      return 1;  // Second column ones.
    });

    Tensor val(DT_FLOAT, TensorShape({5, 2}));
    TF_ASSERT_OK(reader.LookupSlice("foo", distinct_slice, &val));
    test::ExpectTensorEqual<float>(val, expected_val);
  }
  // Reads columns 2-3 ("-:2,2" = start 2, length 2), which lie entirely
  // inside the second slice.
  {
    BundleReader reader(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(reader.status());

    const TensorSlice distinct_slice = TensorSlice::ParseOrDie("-:2,2");
    Tensor val(DT_FLOAT, TensorShape({5, 2}));
    TF_ASSERT_OK(reader.LookupSlice("foo", distinct_slice, &val));
    test::ExpectTensorEqual<float>(val,
                                   Constant<float>(1., TensorShape({5, 2})));
  }
}
399
// A full-tensor slice may be written either with abbreviated extents ("-")
// or with explicit ones. Lookups must treat both spellings as equivalent in
// either direction.
TEST(TensorBundleTest, EquivalentSliceTest) {
  const TensorShape kFullShape({5, 10});
  const Tensor kExpected(Constant<float>(1., kFullShape));
  {
    BundleWriter writer(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(writer.AddSlice("no_extents", kFullShape,
                                 TensorSlice::ParseOrDie("-:-"), kExpected));
    TF_ASSERT_OK(writer.AddSlice("both_extents", kFullShape,
                                 TensorSlice::ParseOrDie("0,5:0,10"),
                                 kExpected));
    TF_ASSERT_OK(writer.Finish());
  }

  // (stored key, requested slice spec) pairs; each must yield the full tensor.
  struct LookupCase {
    const char* key;
    const char* spec;
  };
  const LookupCase kCases[] = {
      {"no_extents", "-:-"},         // Exact match, fully abbreviated.
      {"both_extents", "0,5:0,10"},  // Exact match, fully specified.
      {"no_extents", "0,5:0,10"},    // Stored abbreviated, requested explicit.
      {"both_extents", "-:-"},       // Stored explicit, requested abbreviated.
  };
  for (const LookupCase& c : kCases) {
    BundleReader reader(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(reader.status());
    const TensorSlice slice = TensorSlice::ParseOrDie(c.spec);
    Tensor val(DT_FLOAT, kFullShape);
    TF_ASSERT_OK(reader.LookupSlice(c.key, slice, &val));
    test::ExpectTensorEqual<float>(val, kExpected);
  }
}
449
// Runs the scalar / zero-sized-dimension round-trip (TestNonStandardShapes)
// once per element type listed below.
TEST(TensorBundleTest, NonStandardShapes) {
  TestNonStandardShapes<float>();
  TestNonStandardShapes<double>();
  TestNonStandardShapes<int32>();
  TestNonStandardShapes<uint8>();
  TestNonStandardShapes<int16>();
  TestNonStandardShapes<int8>();
  TestNonStandardShapes<complex64>();
  TestNonStandardShapes<complex128>();
  TestNonStandardShapes<int64>();
  TestNonStandardShapes<bool>();
  TestNonStandardShapes<qint32>();
  TestNonStandardShapes<quint8>();
  TestNonStandardShapes<qint8>();
}
465
// Backward compatibility: the checked-in bundle was produced by an earlier
// writer that stored string lengths as varint32s (current code uses
// varint64s). The current reader must still decode it.
TEST(TensorBundleTest, StringTensorsOldFormat) {
  BundleReader reader(Env::Default(), TestdataPrefix("old_string_tensors/foo"));
  TF_ASSERT_OK(reader.status());

  const std::vector<string> kExpectedKeys = {"floats", "scalar",
                                             "string_tensor", "strs"};
  EXPECT_EQ(AllTensorKeys(&reader), kExpectedKeys);

  const Tensor kSingleEmptyString(DT_STRING, TensorShape({1}));
  Expect<string>(&reader, "string_tensor", kSingleEmptyString);
  Expect<string>(&reader, "scalar", test::AsTensor<string>({"hello"}));
  const string kKibiC(1 << 10, 'c');
  Expect<string>(&reader, "strs",
                 test::AsTensor<string>({"hello", "", "x01", kKibiC}));
  Expect<float>(&reader, "floats", Constant_2x3<float>(16.18));
}
481
// Round-trips string tensors, including one string longer than UINT32_MAX
// bytes (which needs a 64-bit length), mixed with a float tensor.
TEST(TensorBundleTest, StringTensors) {
  constexpr size_t kLongLength = static_cast<size_t>(UINT32_MAX) + 1;
  Tensor long_string_tensor(DT_STRING, TensorShape({1}));

  {
    BundleWriter writer(Env::Default(), Prefix("foo"));
    TF_EXPECT_OK(writer.Add("string_tensor",
                            Tensor(DT_STRING, TensorShape({1}))));  // Empty.
    TF_EXPECT_OK(writer.Add("scalar", test::AsTensor<string>({"hello"})));
    TF_EXPECT_OK(writer.Add(
        "strs",
        test::AsTensor<string>({"hello", "", "x01", string(1 << 25, 'c')})));

    // Requires a 64-bit length: one byte more than UINT32_MAX.
    string* backing_string = long_string_tensor.flat<string>().data();
    backing_string->assign(kLongLength, 'd');
    TF_EXPECT_OK(writer.Add("long_scalar", long_string_tensor));

    // Mixes in some floats.
    TF_EXPECT_OK(writer.Add("floats", Constant_2x3<float>(16.18)));
    TF_ASSERT_OK(writer.Finish());
  }
  {
    BundleReader reader(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(reader.status());
    EXPECT_EQ(AllTensorKeys(&reader),
              std::vector<string>({"floats", "long_scalar", "scalar",
                                   "string_tensor", "strs"}));

    Expect<string>(&reader, "string_tensor",
                   Tensor(DT_STRING, TensorShape({1})));
    Expect<string>(&reader, "scalar", test::AsTensor<string>({"hello"}));
    Expect<string>(
        &reader, "strs",
        test::AsTensor<string>({"hello", "", "x01", string(1 << 25, 'c')}));

    Expect<float>(&reader, "floats", Constant_2x3<float>(16.18));

    // We don't use the Expect function so we can re-use the
    // `long_string_tensor` buffer for reading out long_scalar to keep memory
    // usage reasonable.
    EXPECT_TRUE(reader.Contains("long_scalar"));
    DataType dtype;
    TensorShape shape;
    TF_ASSERT_OK(reader.LookupDtypeAndShape("long_scalar", &dtype, &shape));
    EXPECT_EQ(DT_STRING, dtype);
    EXPECT_EQ(TensorShape({1}), shape);

    // Zero-out the string so that we can be sure the new one is read in.
    string* backing_string = long_string_tensor.flat<string>().data();
    backing_string->assign("");

    // Read long_scalar and check it contains exactly kLongLength 'd's.
    TF_ASSERT_OK(reader.Lookup("long_scalar", &long_string_tensor));
    ASSERT_EQ(backing_string, long_string_tensor.flat<string>().data());
    EXPECT_EQ(kLongLength, backing_string->length());
    // One linear scan over the ~4 GiB buffer, instead of a per-character
    // gtest assertion.
    EXPECT_EQ(string::npos, backing_string->find_first_not_of('d'))
        << "long_scalar is not full of 'd's as expected.";
  }
}
548
549 class VariantObject {
550 public:
VariantObject()551 VariantObject() {}
VariantObject(const string & metadata,int64 value)552 VariantObject(const string& metadata, int64 value)
553 : metadata_(metadata), value_(value) {}
554
TypeName() const555 string TypeName() const { return "TEST VariantObject"; }
Encode(VariantTensorData * data) const556 void Encode(VariantTensorData* data) const {
557 data->set_type_name(TypeName());
558 data->set_metadata(metadata_);
559 Tensor val_t = Tensor(DT_INT64, TensorShape({}));
560 val_t.scalar<int64>()() = value_;
561 *(data->add_tensors()) = val_t;
562 }
Decode(const VariantTensorData & data)563 bool Decode(const VariantTensorData& data) {
564 EXPECT_EQ(data.type_name(), TypeName());
565 data.get_metadata(&metadata_);
566 EXPECT_EQ(data.tensors_size(), 1);
567 value_ = data.tensors(0).scalar<int64>()();
568 return true;
569 }
operator ==(const VariantObject other) const570 bool operator==(const VariantObject other) const {
571 return metadata_ == other.metadata_ && value_ == other.value_;
572 }
573 string metadata_;
574 int64 value_;
575 };
576
// Registers a decode function for VariantObject under the same type-name
// string that VariantObject::TypeName() returns. Presumably this is what lets
// Variant elements read back in the VariantTensors test be decoded into
// VariantObject (TODO confirm against the variant registry).
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(VariantObject, "TEST VariantObject");
578
// Round-trips a two-element DT_VARIANT tensor whose elements hold
// VariantObject values.
TEST(TensorBundleTest, VariantTensors) {
  const Tensor expected = test::AsTensor<Variant>(
      {VariantObject("test", 10), VariantObject("test1", 20)});
  {
    BundleWriter writer(Env::Default(), Prefix("foo"));
    TF_EXPECT_OK(writer.Add("variant_tensor", expected));
    TF_ASSERT_OK(writer.Finish());
  }
  {
    BundleReader reader(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(reader.status());
    ExpectVariant<VariantObject>(&reader, "variant_tensor", expected);
  }
}
597
// Verifies the on-disk file layout of single-shard bundles, of a trivial
// one-input merge, and of a two-input merge.
TEST(TensorBundleTest, DirectoryStructure) {
  Env* env = Env::Default();
  // Writes two single-tensor bundles, "worker0" and "worker1".
  const std::vector<string> kBundlePrefixes = {Prefix("worker0"),
                                               Prefix("worker1")};
  for (size_t worker = 0; worker < kBundlePrefixes.size(); ++worker) {
    BundleWriter writer(env, kBundlePrefixes[worker]);
    TF_EXPECT_OK(writer.Add(strings::StrCat("tensor", worker),
                            Constant_2x3<float>(0.)));
    TF_ASSERT_OK(writer.Finish());
  }

  // Expects each named file to exist next to (same directory as) the prefix.
  auto ExpectFilesNextTo = [env](const string& bundle_prefix,
                                 gtl::ArraySlice<string> file_names) {
    const StringPiece parent_dir = io::Dirname(bundle_prefix);
    for (const string& name : file_names) {
      TF_EXPECT_OK(env->FileExists(io::JoinPath(parent_dir, name)));
    }
  };

  // A bundle is one index file plus one data file per shard.
  ExpectFilesNextTo(kBundlePrefixes[0],
                    {"worker0.index", "worker0.data-00000-of-00001"});
  ExpectFilesNextTo(kBundlePrefixes[1],
                    {"worker1.index", "worker1.data-00000-of-00001"});

  // Merging a single bundle amounts to renaming it.
  const string kAnotherPrefix = Prefix("another");
  TF_ASSERT_OK(MergeBundles(env, {kBundlePrefixes[0]}, kAnotherPrefix));
  ExpectFilesNextTo(kAnotherPrefix,
                    {"another.index", "another.data-00000-of-00001"});

  // Merging two bundles keeps each input's data as its own shard.
  const string kMerged = Prefix("merged");
  TF_ASSERT_OK(
      MergeBundles(env, {kAnotherPrefix, kBundlePrefixes[1]}, kMerged));
  ExpectFilesNextTo(kMerged, {"merged.index", "merged.data-00000-of-00002",
                              "merged.data-00001-of-00002"});
}
643
TEST(TensorBundleTest,Error)644 TEST(TensorBundleTest, Error) {
645 { // Dup keys.
646 BundleWriter writer(Env::Default(), Prefix("dup"));
647 TF_EXPECT_OK(writer.Add("foo", Constant_2x3(1.f)));
648 EXPECT_FALSE(writer.Add("foo", Constant_2x3(2.f)).ok());
649 EXPECT_TRUE(
650 str_util::StrContains(writer.status().ToString(), "duplicate key"));
651 EXPECT_FALSE(writer.Finish().ok());
652 }
653 { // Double finish
654 BundleWriter writer(Env::Default(), Prefix("bad"));
655 EXPECT_TRUE(writer.Finish().ok());
656 EXPECT_FALSE(writer.Finish().ok());
657 }
658 { // Not found.
659 BundleReader reader(Env::Default(), Prefix("nonexist"));
660 EXPECT_TRUE(str_util::StrContains(reader.status().ToString(), "Not found"));
661 }
662 }
663
// Corrupts bytes of a bundle's data file and checks that lookups report
// DataLoss with a checksum-related message.
TEST(TensorBundleTest, Checksum) {
  // Bitwise-negates one byte of the single data shard of bundle `prefix`:
  // a pseudo-randomly chosen byte in [pos_lhs, end of data file), or exactly
  // byte `pos_lhs` when `exact_pos` is true.
  auto FlipByte = [](const string& prefix, int pos_lhs,
                     bool exact_pos = false) {
    ASSERT_GE(pos_lhs, 0);
    const string datafile = DataFilename(Prefix(prefix), 0, 1);
    string data;
    TF_ASSERT_OK(ReadFileToString(Env::Default(), datafile, &data));
    // `pos_lhs` must address an existing byte. This also rules out an empty
    // file, which would otherwise underflow `data.size() - 1` below.
    ASSERT_LT(static_cast<size_t>(pos_lhs), data.size());

    int byte_pos = pos_lhs;
    if (!exact_pos) {
      // Default-seeded engine, so the chosen byte is the same on every run.
      std::mt19937 rng;
      std::uniform_int_distribution<int> dist(
          pos_lhs, static_cast<int>(data.size() - 1));
      byte_pos = dist(rng);
    }
    data[byte_pos] = ~data[byte_pos];
    TF_ASSERT_OK(WriteStringToFile(Env::Default(), datafile, data));
  };
  // The lookup should fail with a checksum-related message.
  auto ExpectLookupFails = [](const string& prefix, const string& key,
                              const string& expected_msg, Tensor& val) {
    BundleReader reader(Env::Default(), Prefix(prefix));
    Status status = reader.Lookup(key, &val);
    EXPECT_TRUE(errors::IsDataLoss(status));
    EXPECT_TRUE(str_util::StrContains(status.ToString(), expected_msg));
  };

  // Corrupts a float tensor.
  {
    BundleWriter writer(Env::Default(), Prefix("singleton"));
    TF_EXPECT_OK(writer.Add("foo", Constant_2x3(1.f)));
    TF_ASSERT_OK(writer.Finish());

    FlipByte("singleton", 0 /* corrupts any byte */);
    Tensor val(DT_FLOAT, TensorShape({2, 3}));
    ExpectLookupFails("singleton", "foo",
                      "Checksum does not match" /* expected fail msg */, val);
  }
  // Corrupts a string tensor.
  {
    auto WriteStrings = []() {
      BundleWriter writer(Env::Default(), Prefix("strings"));
      TF_EXPECT_OK(
          writer.Add("foo", test::AsTensor<string>({"hello", "world"})));
      TF_ASSERT_OK(writer.Finish());
    };
    // Corrupts the first two bytes, which are the varint64-encoded lengths
    // of the two string elements (one byte each for these short strings).
    // Should hit a mismatch on the length checksum.
    for (int i = 0; i < 2; ++i) {
      WriteStrings();
      FlipByte("strings", i, true /* corrupts exactly byte i */);
      Tensor val(DT_STRING, TensorShape({2}));
      ExpectLookupFails(
          "strings", "foo",
          "length checksum does not match" /* expected fail msg */, val);
    }
    // Corrupts the string bytes, should hit an overall cksum mismatch.
    WriteStrings();
    FlipByte("strings", 2 /* corrupts starting from byte 2 */);
    Tensor val(DT_STRING, TensorShape({2}));
    ExpectLookupFails("strings", "foo",
                      "Checksum does not match" /* expected fail msg */, val);
  }
}
731
// A bundle whose header claims the opposite endianness must be rejected by
// the reader with Unimplemented.
TEST(TensorBundleTest, Endianness) {
  const string prefix = Prefix("end");
  BundleWriter writer(Env::Default(), prefix);
  TF_EXPECT_OK(writer.Add("key", Constant_2x3<float>(1.0)));
  TF_ASSERT_OK(writer.Finish());

  // Rewrite the metadata with the endianness marker flipped.
  TF_ASSERT_OK(FlipEndiannessBit(prefix));

  BundleReader reader(Env::Default(), prefix);
  const Status& open_status = reader.status();
  EXPECT_TRUE(errors::IsUnimplemented(open_status));
  EXPECT_TRUE(str_util::StrContains(open_status.ToString(),
                                    "different endianness from the reader"));
}
745
// Dropping the final byte of the data file makes the lookup run past EOF,
// which must surface as OutOfRange.
TEST(TensorBundleTest, TruncatedTensorContents) {
  Env* env = Env::Default();
  const string prefix = Prefix("end");
  BundleWriter writer(env, prefix);
  TF_EXPECT_OK(writer.Add("key", Constant_2x3<float>(1.0)));
  TF_ASSERT_OK(writer.Finish());

  const string datafile = DataFilename(prefix, 0, 1);
  string contents;
  TF_ASSERT_OK(ReadFileToString(env, datafile, &contents));
  ASSERT_FALSE(contents.empty());
  const StringPiece truncated(contents.data(), contents.size() - 1);
  TF_ASSERT_OK(WriteStringToFile(env, datafile, truncated));

  BundleReader reader(env, prefix);
  TF_ASSERT_OK(reader.status());
  Tensor val(DT_FLOAT, TensorShape({2, 3}));
  const Status lookup_status = reader.Lookup("key", &val);
  EXPECT_TRUE(errors::IsOutOfRange(lookup_status));
}
765
// Checks the header entry written by BundleWriter: shard count, the host's
// endianness, and the producer/min-consumer version fields.
TEST(TensorBundleTest, HeaderEntry) {
  const string prefix = Prefix("b");
  {
    BundleWriter writer(Env::Default(), prefix);
    TF_EXPECT_OK(writer.Add("key", Constant_2x3<float>(1.0)));
    TF_ASSERT_OK(writer.Finish());
  }

  // Parse the raw header entry out of the metadata.
  BundleHeaderProto header;
  {
    BundleReader reader(Env::Default(), prefix);
    TF_ASSERT_OK(reader.status());
    reader.Seek(kHeaderEntryKey);
    ASSERT_TRUE(reader.Valid());
    const StringPiece raw = reader.value();
    ASSERT_TRUE(ParseProtoUnlimited(&header, raw.data(), raw.size()));
  }

  EXPECT_EQ(1, header.num_shards());

  const auto host_endianness = port::kLittleEndian ? BundleHeaderProto::LITTLE
                                                   : BundleHeaderProto::BIG;
  EXPECT_EQ(host_endianness, header.endianness());

  EXPECT_GT(kTensorBundleVersion, 0);
  const auto& version = header.version();
  EXPECT_EQ(kTensorBundleVersion, version.producer());
  EXPECT_EQ(kTensorBundleMinConsumer, version.min_consumer());
}
797
// Each incompatible VersionDef must be rejected with its specific message.
TEST(TensorBundleTest, VersionTest) {
  const auto kNextVersion = kTensorBundleVersion + 1;
  const auto kTooOldProducer = kTensorBundleMinProducer - 1;
  {
    // min_consumer is newer than this reader.
    VersionDef versions;
    versions.set_producer(kNextVersion);
    versions.set_min_consumer(kNextVersion);
    const string expected = strings::StrCat(
        "Checkpoint min consumer version ", kNextVersion,
        " above current version ", kTensorBundleVersion, " for TensorFlow");
    VersionTest(versions, expected);
  }
  {
    // Producer is older than the oldest supported producer.
    VersionDef versions;
    versions.set_producer(kTooOldProducer);
    const string expected = strings::StrCat(
        "Checkpoint producer version ", kTooOldProducer,
        " below min producer ", kTensorBundleMinProducer,
        " supported by TensorFlow");
    VersionTest(versions, expected);
  }
  {
    // This reader's version is explicitly listed as a bad consumer.
    VersionDef versions;
    versions.set_producer(kNextVersion);
    versions.add_bad_consumers(kTensorBundleVersion);
    const string expected = strings::StrCat(
        "Checkpoint disallows consumer version ", kTensorBundleVersion,
        ". Please upgrade TensorFlow: this version is likely buggy.");
    VersionTest(versions, expected);
  }
}
832
833 class TensorBundleAlignmentTest : public ::testing::Test {
834 protected:
835 template <typename T>
ExpectAlignment(BundleReader * reader,const string & key,int alignment)836 void ExpectAlignment(BundleReader* reader, const string& key, int alignment) {
837 BundleEntryProto full_tensor_entry;
838 TF_ASSERT_OK(reader->GetBundleEntryProto(key, &full_tensor_entry));
839 EXPECT_EQ(0, full_tensor_entry.offset() % alignment);
840 }
841 };
842
// Writes a bundle with data_alignment = 42. Checks it still round-trips
// (keyed lookup and iteration) and that every entry's offset is 42-aligned.
TEST_F(TensorBundleAlignmentTest, AlignmentTest) {
  constexpr int kAlignment = 42;
  // Added out of order; readers see them sorted by key.
  const std::vector<int> kInsertionOrder = {3, 0, 2, 1};
  const std::vector<int> kAscending = {0, 1, 2, 3};
  {
    BundleWriter::Options opts;
    opts.data_alignment = kAlignment;
    BundleWriter writer(Env::Default(), Prefix("foo"), opts);
    for (int i : kInsertionOrder) {
      TF_EXPECT_OK(
          writer.Add(strings::StrCat("foo_00", i), Constant_2x3<float>(i)));
    }
    TF_ASSERT_OK(writer.Finish());
  }
  {
    BundleReader reader(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(reader.status());
    EXPECT_EQ(
        AllTensorKeys(&reader),
        std::vector<string>({"foo_000", "foo_001", "foo_002", "foo_003"}));
    for (int i : kAscending) {
      Expect<float>(&reader, strings::StrCat("foo_00", i),
                    Constant_2x3<float>(i));
    }
  }
  {
    BundleReader reader(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(reader.status());
    for (int i : kAscending) {
      ExpectNext<float>(&reader, Constant_2x3<float>(i));
    }
    EXPECT_TRUE(reader.Valid());
    reader.Next();
    EXPECT_FALSE(reader.Valid());
  }
  {
    BundleReader reader(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(reader.status());
    for (int i : kAscending) {
      ExpectAlignment<float>(&reader, strings::StrCat("foo_00", i),
                             kAlignment);
    }
  }
}
885
BM_BundleAlignmentByteOff(int iters,int alignment,int tensor_size)886 static void BM_BundleAlignmentByteOff(int iters, int alignment,
887 int tensor_size) {
888 testing::StopTiming();
889 {
890 BundleWriter::Options opts;
891 opts.data_alignment = alignment;
892 BundleWriter writer(Env::Default(), Prefix("foo"), opts);
893 TF_CHECK_OK(writer.Add("small", Constant(true, TensorShape({1}))));
894 TF_CHECK_OK(writer.Add("big", Constant(32.1, TensorShape({tensor_size}))));
895 TF_CHECK_OK(writer.Finish());
896 }
897 BundleReader reader(Env::Default(), Prefix("foo"));
898 TF_CHECK_OK(reader.status());
899 testing::StartTiming();
900 for (int i = 0; i < iters; ++i) {
901 Tensor t;
902 TF_CHECK_OK(reader.Lookup("big", &t));
903 }
904 testing::StopTiming();
905 }
906
// Defines a benchmark function BM_BundleAlignment_<ALIGN>_<SIZE> that calls
// BM_BundleAlignmentByteOff with data alignment ALIGN and a "big" tensor of
// SIZE elements, and registers it with BENCHMARK.
#define BM_BundleAlignment(ALIGN, SIZE)                        \
  static void BM_BundleAlignment_##ALIGN##_##SIZE(int iters) { \
    BM_BundleAlignmentByteOff(iters, ALIGN, SIZE);             \
  }                                                            \
  BENCHMARK(BM_BundleAlignment_##ALIGN##_##SIZE)

// Alignment 1 vs 4096, each with 512, 4096 and 1048576 elements.
BM_BundleAlignment(1, 512);
BM_BundleAlignment(1, 4096);
BM_BundleAlignment(1, 1048576);
BM_BundleAlignment(4096, 512);
BM_BundleAlignment(4096, 4096);
BM_BundleAlignment(4096, 1048576);
919
920 } // namespace tensorflow
921