• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7 http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
17 
18 #include <random>
19 #include <vector>
20 
21 #include "tensorflow/core/framework/tensor_testutil.h"
22 #include "tensorflow/core/framework/types.pb.h"
23 #include "tensorflow/core/framework/variant.h"
24 #include "tensorflow/core/framework/variant_op_registry.h"
25 #include "tensorflow/core/framework/versions.pb.h"
26 #include "tensorflow/core/lib/core/status_test_util.h"
27 #include "tensorflow/core/lib/io/path.h"
28 #include "tensorflow/core/lib/io/table_builder.h"
29 #include "tensorflow/core/lib/strings/str_util.h"
30 #include "tensorflow/core/lib/strings/strcat.h"
31 #include "tensorflow/core/platform/test.h"
32 #include "tensorflow/core/platform/test_benchmark.h"
33 
34 namespace tensorflow {
35 
36 namespace {
37 
// Returns `prefix` located under the test's temporary directory.
string Prefix(const string& prefix) {
  const string tmp_dir = testing::TmpDir();
  return strings::StrCat(tmp_dir, "/", prefix);
}
41 
// Returns `prefix` located under the tensor_bundle testdata directory of the
// TensorFlow source tree.
string TestdataPrefix(const string& prefix) {
  const string src_root = testing::TensorFlowSrcRoot();
  return strings::StrCat(src_root, "/core/util/tensor_bundle/testdata/",
                         prefix);
}
46 
47 template <typename T>
Constant(T v,TensorShape shape)48 Tensor Constant(T v, TensorShape shape) {
49   Tensor ret(DataTypeToEnum<T>::value, shape);
50   ret.flat<T>().setConstant(v);
51   return ret;
52 }
53 
54 template <typename T>
Constant_2x3(T v)55 Tensor Constant_2x3(T v) {
56   return Constant(v, TensorShape({2, 3}));
57 }
58 
59 template <typename T>
Expect(BundleReader * reader,const string & key,const Tensor & expected_val)60 void Expect(BundleReader* reader, const string& key,
61             const Tensor& expected_val) {
62   // Tests for Contains().
63   EXPECT_TRUE(reader->Contains(key));
64   // Tests for LookupDtypeAndShape().
65   DataType dtype;
66   TensorShape shape;
67   TF_ASSERT_OK(reader->LookupDtypeAndShape(key, &dtype, &shape));
68   EXPECT_EQ(expected_val.dtype(), dtype);
69   EXPECT_EQ(expected_val.shape(), shape);
70   // Tests for Lookup(), checking tensor contents.
71   Tensor val(expected_val.dtype(), shape);
72   TF_ASSERT_OK(reader->Lookup(key, &val));
73   test::ExpectTensorEqual<T>(val, expected_val);
74 }
75 
76 template <class T>
ExpectVariant(BundleReader * reader,const string & key,const Tensor & expected_t)77 void ExpectVariant(BundleReader* reader, const string& key,
78                    const Tensor& expected_t) {
79   // Tests for Contains().
80   EXPECT_TRUE(reader->Contains(key));
81   // Tests for LookupDtypeAndShape().
82   DataType dtype;
83   TensorShape shape;
84   TF_ASSERT_OK(reader->LookupDtypeAndShape(key, &dtype, &shape));
85   // Tests for Lookup(), checking tensor contents.
86   EXPECT_EQ(expected_t.dtype(), dtype);
87   EXPECT_EQ(expected_t.shape(), shape);
88   Tensor actual_t(dtype, shape);
89   TF_ASSERT_OK(reader->Lookup(key, &actual_t));
90   for (int i = 0; i < expected_t.NumElements(); i++) {
91     Variant actual_var = actual_t.flat<Variant>()(i);
92     Variant expected_var = expected_t.flat<Variant>()(i);
93     EXPECT_EQ(actual_var.TypeName(), expected_var.TypeName());
94     auto* actual_val = actual_var.get<T>();
95     auto* expected_val = expected_var.get<T>();
96     EXPECT_EQ(*expected_val, *actual_val);
97   }
98 }
99 
100 template <typename T>
ExpectNext(BundleReader * reader,const Tensor & expected_val)101 void ExpectNext(BundleReader* reader, const Tensor& expected_val) {
102   EXPECT_TRUE(reader->Valid());
103   reader->Next();
104   TF_ASSERT_OK(reader->status());
105   Tensor val;
106   TF_ASSERT_OK(reader->ReadCurrent(&val));
107   test::ExpectTensorEqual<T>(val, expected_val);
108 }
109 
AllTensorKeys(BundleReader * reader)110 std::vector<string> AllTensorKeys(BundleReader* reader) {
111   std::vector<string> ret;
112   reader->Seek(kHeaderEntryKey);
113   reader->Next();
114   for (; reader->Valid(); reader->Next()) {
115     ret.emplace_back(reader->key());
116   }
117   return ret;
118 }
119 
120 // Writes out the metadata file of a bundle again, with the endianness marker
121 // bit flipped.
FlipEndiannessBit(const string & prefix)122 Status FlipEndiannessBit(const string& prefix) {
123   Env* env = Env::Default();
124   const string metadata_tmp_path = Prefix("some_tmp_path");
125   std::unique_ptr<WritableFile> file;
126   TF_RETURN_IF_ERROR(env->NewWritableFile(metadata_tmp_path, &file));
127   table::TableBuilder builder(table::Options(), file.get());
128 
129   // Reads the existing metadata file, and fills the builder.
130   {
131     const string filename = MetaFilename(prefix);
132     uint64 file_size;
133     TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));
134     std::unique_ptr<RandomAccessFile> file;
135     TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));
136 
137     table::Table* table = nullptr;
138     TF_RETURN_IF_ERROR(
139         table::Table::Open(table::Options(), file.get(), file_size, &table));
140     std::unique_ptr<table::Table> table_deleter(table);
141     std::unique_ptr<table::Iterator> iter(table->NewIterator());
142 
143     // Reads the header entry.
144     iter->Seek(kHeaderEntryKey);
145     CHECK(iter->Valid());
146     BundleHeaderProto header;
147     CHECK(header.ParseFromArray(iter->value().data(), iter->value().size()));
148     // Flips the endianness.
149     if (header.endianness() == BundleHeaderProto::LITTLE) {
150       header.set_endianness(BundleHeaderProto::BIG);
151     } else {
152       header.set_endianness(BundleHeaderProto::LITTLE);
153     }
154     builder.Add(iter->key(), header.SerializeAsString());
155     iter->Next();
156 
157     // Adds the non-header entries unmodified.
158     for (; iter->Valid(); iter->Next()) builder.Add(iter->key(), iter->value());
159   }
160   TF_RETURN_IF_ERROR(builder.Finish());
161   TF_RETURN_IF_ERROR(env->RenameFile(metadata_tmp_path, MetaFilename(prefix)));
162   return file->Close();
163 }
164 
165 template <typename T>
TestBasic()166 void TestBasic() {
167   {
168     BundleWriter writer(Env::Default(), Prefix("foo"));
169     TF_EXPECT_OK(writer.Add("foo_003", Constant_2x3<T>(3)));
170     TF_EXPECT_OK(writer.Add("foo_000", Constant_2x3<T>(0)));
171     TF_EXPECT_OK(writer.Add("foo_002", Constant_2x3<T>(2)));
172     TF_EXPECT_OK(writer.Add("foo_001", Constant_2x3<T>(1)));
173     TF_ASSERT_OK(writer.Finish());
174   }
175   {
176     BundleReader reader(Env::Default(), Prefix("foo"));
177     TF_ASSERT_OK(reader.status());
178     EXPECT_EQ(
179         AllTensorKeys(&reader),
180         std::vector<string>({"foo_000", "foo_001", "foo_002", "foo_003"}));
181     Expect<T>(&reader, "foo_000", Constant_2x3<T>(0));
182     Expect<T>(&reader, "foo_001", Constant_2x3<T>(1));
183     Expect<T>(&reader, "foo_002", Constant_2x3<T>(2));
184     Expect<T>(&reader, "foo_003", Constant_2x3<T>(3));
185   }
186   {
187     BundleReader reader(Env::Default(), Prefix("foo"));
188     TF_ASSERT_OK(reader.status());
189     ExpectNext<T>(&reader, Constant_2x3<T>(0));
190     ExpectNext<T>(&reader, Constant_2x3<T>(1));
191     ExpectNext<T>(&reader, Constant_2x3<T>(2));
192     ExpectNext<T>(&reader, Constant_2x3<T>(3));
193     EXPECT_TRUE(reader.Valid());
194     reader.Next();
195     EXPECT_FALSE(reader.Valid());
196   }
197   {
198     BundleWriter writer(Env::Default(), Prefix("bar"));
199     TF_EXPECT_OK(writer.Add("bar_003", Constant_2x3<T>(3)));
200     TF_EXPECT_OK(writer.Add("bar_000", Constant_2x3<T>(0)));
201     TF_EXPECT_OK(writer.Add("bar_002", Constant_2x3<T>(2)));
202     TF_EXPECT_OK(writer.Add("bar_001", Constant_2x3<T>(1)));
203     TF_ASSERT_OK(writer.Finish());
204   }
205   {
206     BundleReader reader(Env::Default(), Prefix("bar"));
207     TF_ASSERT_OK(reader.status());
208     EXPECT_EQ(
209         AllTensorKeys(&reader),
210         std::vector<string>({"bar_000", "bar_001", "bar_002", "bar_003"}));
211     Expect<T>(&reader, "bar_003", Constant_2x3<T>(3));
212     Expect<T>(&reader, "bar_002", Constant_2x3<T>(2));
213     Expect<T>(&reader, "bar_001", Constant_2x3<T>(1));
214     Expect<T>(&reader, "bar_000", Constant_2x3<T>(0));
215   }
216   {
217     BundleReader reader(Env::Default(), Prefix("bar"));
218     TF_ASSERT_OK(reader.status());
219     ExpectNext<T>(&reader, Constant_2x3<T>(0));
220     ExpectNext<T>(&reader, Constant_2x3<T>(1));
221     ExpectNext<T>(&reader, Constant_2x3<T>(2));
222     ExpectNext<T>(&reader, Constant_2x3<T>(3));
223     EXPECT_TRUE(reader.Valid());
224     reader.Next();
225     EXPECT_FALSE(reader.Valid());
226   }
227   TF_ASSERT_OK(MergeBundles(Env::Default(), {Prefix("foo"), Prefix("bar")},
228                             Prefix("merged")));
229   {
230     BundleReader reader(Env::Default(), Prefix("merged"));
231     TF_ASSERT_OK(reader.status());
232     EXPECT_EQ(
233         AllTensorKeys(&reader),
234         std::vector<string>({"bar_000", "bar_001", "bar_002", "bar_003",
235                              "foo_000", "foo_001", "foo_002", "foo_003"}));
236     Expect<T>(&reader, "bar_000", Constant_2x3<T>(0));
237     Expect<T>(&reader, "bar_001", Constant_2x3<T>(1));
238     Expect<T>(&reader, "bar_002", Constant_2x3<T>(2));
239     Expect<T>(&reader, "bar_003", Constant_2x3<T>(3));
240     Expect<T>(&reader, "foo_000", Constant_2x3<T>(0));
241     Expect<T>(&reader, "foo_001", Constant_2x3<T>(1));
242     Expect<T>(&reader, "foo_002", Constant_2x3<T>(2));
243     Expect<T>(&reader, "foo_003", Constant_2x3<T>(3));
244   }
245   {
246     BundleReader reader(Env::Default(), Prefix("merged"));
247     TF_ASSERT_OK(reader.status());
248     ExpectNext<T>(&reader, Constant_2x3<T>(0));
249     ExpectNext<T>(&reader, Constant_2x3<T>(1));
250     ExpectNext<T>(&reader, Constant_2x3<T>(2));
251     ExpectNext<T>(&reader, Constant_2x3<T>(3));
252     ExpectNext<T>(&reader, Constant_2x3<T>(0));
253     ExpectNext<T>(&reader, Constant_2x3<T>(1));
254     ExpectNext<T>(&reader, Constant_2x3<T>(2));
255     ExpectNext<T>(&reader, Constant_2x3<T>(3));
256     EXPECT_TRUE(reader.Valid());
257     reader.Next();
258     EXPECT_FALSE(reader.Valid());
259   }
260 }
261 
262 template <typename T>
TestNonStandardShapes()263 void TestNonStandardShapes() {
264   {
265     BundleWriter writer(Env::Default(), Prefix("nonstandard"));
266     TF_EXPECT_OK(writer.Add("scalar", Constant<T>(0, TensorShape())));
267     TF_EXPECT_OK(
268         writer.Add("non_standard0", Constant<T>(0, TensorShape({0, 1618}))));
269     TF_EXPECT_OK(
270         writer.Add("non_standard1", Constant<T>(0, TensorShape({16, 0, 18}))));
271     TF_ASSERT_OK(writer.Finish());
272   }
273   {
274     BundleReader reader(Env::Default(), Prefix("nonstandard"));
275     TF_ASSERT_OK(reader.status());
276     Expect<T>(&reader, "scalar", Constant<T>(0, TensorShape()));
277     Expect<T>(&reader, "non_standard0", Constant<T>(0, TensorShape({0, 1618})));
278     Expect<T>(&reader, "non_standard1",
279               Constant<T>(0, TensorShape({16, 0, 18})));
280   }
281 }
282 
283 // Writes a bundle to disk with a bad "version"; checks for "expected_error".
VersionTest(const VersionDef & version,StringPiece expected_error)284 void VersionTest(const VersionDef& version, StringPiece expected_error) {
285   const string path = Prefix("version_test");
286   {
287     // Prepare an empty bundle with the given version information.
288     BundleHeaderProto header;
289     *header.mutable_version() = version;
290 
291     // Write the metadata file to disk.
292     std::unique_ptr<WritableFile> file;
293     TF_ASSERT_OK(Env::Default()->NewWritableFile(MetaFilename(path), &file));
294     table::TableBuilder builder(table::Options(), file.get());
295     builder.Add(kHeaderEntryKey, header.SerializeAsString());
296     TF_ASSERT_OK(builder.Finish());
297   }
298   // Read it back in and verify that we get the expected error.
299   BundleReader reader(Env::Default(), path);
300   EXPECT_TRUE(errors::IsInvalidArgument(reader.status()));
301   EXPECT_TRUE(
302       str_util::StartsWith(reader.status().error_message(), expected_error));
303 }
304 
305 }  // namespace
306 
// Runs the basic write / read / merge round trip for every supported element
// type.
TEST(TensorBundleTest, Basic) {
  // Floating point and integer types.
  TestBasic<float>();
  TestBasic<double>();
  TestBasic<int32>();
  TestBasic<uint8>();
  TestBasic<int16>();
  TestBasic<int8>();

  // Complex types.
  TestBasic<complex64>();
  TestBasic<complex128>();

  // 64-bit integer and boolean.
  TestBasic<int64>();
  TestBasic<bool>();

  // Quantized types.
  TestBasic<qint32>();
  TestBasic<quint8>();
  TestBasic<qint8>();
}
322 
TEST(TensorBundleTest,PartitionedVariables)323 TEST(TensorBundleTest, PartitionedVariables) {
324   const TensorShape kFullShape({5, 10});
325   // Adds two slices.
326   // First slice: column 0, all zeros.
327   // Second slice: column 1 to rest, all ones.
328   TensorSlice slice1 = TensorSlice::ParseOrDie("-:0,1");
329   TensorSlice slice2 = TensorSlice::ParseOrDie("-:1,9");
330   {
331     BundleWriter writer(Env::Default(), Prefix("foo"));
332 
333     TF_ASSERT_OK(writer.AddSlice("foo", kFullShape, slice1,
334                                  Constant<float>(0., TensorShape({5, 1}))));
335     TF_ASSERT_OK(writer.AddSlice("foo", kFullShape, slice2,
336                                  Constant<float>(1., TensorShape({5, 9}))));
337     TF_ASSERT_OK(writer.Finish());
338   }
339   // Reads in full.
340   {
341     BundleReader reader(Env::Default(), Prefix("foo"));
342     TF_ASSERT_OK(reader.status());
343 
344     Tensor expected_val(DT_FLOAT, kFullShape);
345     test::FillFn<float>(&expected_val, [](int offset) -> float {
346       if (offset % 10 == 0) {
347         return 0;  // First column zeros.
348       }
349       return 1;  // Other columns ones.
350     });
351 
352     Tensor val(DT_FLOAT, kFullShape);
353     TF_ASSERT_OK(reader.Lookup("foo", &val));
354     test::ExpectTensorEqual<float>(val, expected_val);
355   }
356   // Reads all slices.
357   {
358     BundleReader reader(Env::Default(), Prefix("foo"));
359     TF_ASSERT_OK(reader.status());
360 
361     std::vector<TensorSlice> slices;
362     TF_ASSERT_OK(reader.LookupTensorSlices("foo", &slices));
363 
364     EXPECT_EQ(2, slices.size());
365     EXPECT_EQ(slice1.DebugString(), slices[0].DebugString());
366     EXPECT_EQ(slice2.DebugString(), slices[1].DebugString());
367   }
368   // Reads a slice consisting of first two columns, "cutting" both slices.
369   {
370     BundleReader reader(Env::Default(), Prefix("foo"));
371     TF_ASSERT_OK(reader.status());
372 
373     // First two columns, "cutting" both slices.
374     const TensorSlice distinct_slice = TensorSlice::ParseOrDie("-:0,2");
375     Tensor expected_val(DT_FLOAT, TensorShape({5, 2}));
376     test::FillFn<float>(&expected_val, [](int offset) -> float {
377       if (offset % 2 == 0) {
378         return 0;  // First column zeros.
379       }
380       return 1;  // Other columns ones.
381     });
382 
383     Tensor val(DT_FLOAT, TensorShape({5, 2}));
384     TF_ASSERT_OK(reader.LookupSlice("foo", distinct_slice, &val));
385     test::ExpectTensorEqual<float>(val, expected_val);
386   }
387   // Reads a slice consisting of columns 2-4, "cutting" the second slice only.
388   {
389     BundleReader reader(Env::Default(), Prefix("foo"));
390     TF_ASSERT_OK(reader.status());
391 
392     const TensorSlice distinct_slice = TensorSlice::ParseOrDie("-:2,2");
393     Tensor val(DT_FLOAT, TensorShape({5, 2}));
394     TF_ASSERT_OK(reader.LookupSlice("foo", distinct_slice, &val));
395     test::ExpectTensorEqual<float>(val,
396                                    Constant<float>(1., TensorShape({5, 2})));
397   }
398 }
399 
// A full tensor can be described either with abbreviated extents ("-:-") or
// with explicit ones ("0,5:0,10").  Checks that every combination of stored
// and requested spelling reads back the same data.
TEST(TensorBundleTest, EquivalentSliceTest) {
  const TensorShape kFullShape({5, 10});
  const Tensor kExpected(Constant<float>(1., kFullShape));
  {
    BundleWriter writer(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(writer.AddSlice("no_extents", kFullShape,
                                 TensorSlice::ParseOrDie("-:-"), kExpected));
    TF_ASSERT_OK(writer.AddSlice("both_extents", kFullShape,
                                 TensorSlice::ParseOrDie("0,5:0,10"),
                                 kExpected));
    TF_ASSERT_OK(writer.Finish());
  }

  struct LookupCase {
    const char* key;   // Which stored tensor to read.
    const char* spec;  // Slice spec used for the read.
  };
  const LookupCase kCases[] = {
      // Slices match exactly and are fully abbreviated.
      {"no_extents", "-:-"},
      // Slices match exactly and are fully specified.
      {"both_extents", "0,5:0,10"},
      // Stored slice has no extents, spec has extents.
      {"no_extents", "0,5:0,10"},
      // Stored slice has both extents, spec has no extents.
      {"both_extents", "-:-"},
  };
  for (const LookupCase& lookup : kCases) {
    BundleReader reader(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(reader.status());
    const TensorSlice slice = TensorSlice::ParseOrDie(lookup.spec);
    Tensor val(DT_FLOAT, TensorShape(kFullShape));
    TF_ASSERT_OK(reader.LookupSlice(lookup.key, slice, &val));
    test::ExpectTensorEqual<float>(val, kExpected);
  }
}
449 
// Runs the scalar / zero-sized-dimension round trip for every supported
// element type.
TEST(TensorBundleTest, NonStandardShapes) {
  // Floating point and integer types.
  TestNonStandardShapes<float>();
  TestNonStandardShapes<double>();
  TestNonStandardShapes<int32>();
  TestNonStandardShapes<uint8>();
  TestNonStandardShapes<int16>();
  TestNonStandardShapes<int8>();

  // Complex types.
  TestNonStandardShapes<complex64>();
  TestNonStandardShapes<complex128>();

  // 64-bit integer and boolean.
  TestNonStandardShapes<int64>();
  TestNonStandardShapes<bool>();

  // Quantized types.
  TestNonStandardShapes<qint32>();
  TestNonStandardShapes<quint8>();
  TestNonStandardShapes<qint8>();
}
465 
// Reads a checked-in bundle produced by an earlier version of the code, which
// stored string lengths as varint32s (varint64s are used now), and checks
// that its contents are still readable.
TEST(TensorBundleTest, StringTensorsOldFormat) {
  const string legacy_prefix = TestdataPrefix("old_string_tensors/foo");
  BundleReader reader(Env::Default(), legacy_prefix);
  TF_ASSERT_OK(reader.status());

  const std::vector<string> want_keys(
      {"floats", "scalar", "string_tensor", "strs"});
  EXPECT_EQ(AllTensorKeys(&reader), want_keys);

  // A one-element tensor holding a default-constructed string.
  Expect<string>(&reader, "string_tensor",
                 Tensor(DT_STRING, TensorShape({1})));
  Expect<string>(&reader, "scalar", test::AsTensor<string>({"hello"}));
  Expect<string>(
      &reader, "strs",
      test::AsTensor<string>({"hello", "", "x01", string(1 << 10, 'c')}));
  Expect<float>(&reader, "floats", Constant_2x3<float>(16.18));
}
481 
TEST(TensorBundleTest,StringTensors)482 TEST(TensorBundleTest, StringTensors) {
483   constexpr size_t kLongLength = static_cast<size_t>(UINT32_MAX) + 1;
484   Tensor long_string_tensor(DT_STRING, TensorShape({1}));
485 
486   {
487     BundleWriter writer(Env::Default(), Prefix("foo"));
488     TF_EXPECT_OK(writer.Add("string_tensor",
489                             Tensor(DT_STRING, TensorShape({1}))));  // Empty.
490     TF_EXPECT_OK(writer.Add("scalar", test::AsTensor<string>({"hello"})));
491     TF_EXPECT_OK(writer.Add(
492         "strs",
493         test::AsTensor<string>({"hello", "", "x01", string(1 << 25, 'c')})));
494 
495     // Requires a 64-bit length.
496     string* backing_string = long_string_tensor.flat<string>().data();
497     backing_string->assign(kLongLength, 'd');
498     TF_EXPECT_OK(writer.Add("long_scalar", long_string_tensor));
499 
500     // Mixes in some floats.
501     TF_EXPECT_OK(writer.Add("floats", Constant_2x3<float>(16.18)));
502     TF_ASSERT_OK(writer.Finish());
503   }
504   {
505     BundleReader reader(Env::Default(), Prefix("foo"));
506     TF_ASSERT_OK(reader.status());
507     EXPECT_EQ(AllTensorKeys(&reader),
508               std::vector<string>({"floats", "long_scalar", "scalar",
509                                    "string_tensor", "strs"}));
510 
511     Expect<string>(&reader, "string_tensor",
512                    Tensor(DT_STRING, TensorShape({1})));
513     Expect<string>(&reader, "scalar", test::AsTensor<string>({"hello"}));
514     Expect<string>(
515         &reader, "strs",
516         test::AsTensor<string>({"hello", "", "x01", string(1 << 25, 'c')}));
517 
518     Expect<float>(&reader, "floats", Constant_2x3<float>(16.18));
519 
520     // We don't use the Expect function so we can re-use the
521     // `long_string_tensor` buffer for reading out long_scalar to keep memory
522     // usage reasonable.
523     EXPECT_TRUE(reader.Contains("long_scalar"));
524     DataType dtype;
525     TensorShape shape;
526     TF_ASSERT_OK(reader.LookupDtypeAndShape("long_scalar", &dtype, &shape));
527     EXPECT_EQ(DT_STRING, dtype);
528     EXPECT_EQ(TensorShape({1}), shape);
529 
530     // Zero-out the string so that we can be sure the new one is read in.
531     string* backing_string = long_string_tensor.flat<string>().data();
532     backing_string->assign("");
533 
534     // Read long_scalar and check it contains kLongLength 'd's.
535     TF_ASSERT_OK(reader.Lookup("long_scalar", &long_string_tensor));
536     ASSERT_EQ(backing_string, long_string_tensor.flat<string>().data());
537     EXPECT_EQ(kLongLength, backing_string->length());
538     for (char c : *backing_string) {
539       // Not using ASSERT_EQ('d', c) because this way is twice as fast due to
540       // compiler optimizations.
541       if (c != 'd') {
542         FAIL() << "long_scalar is not full of 'd's as expected.";
543         break;
544       }
545     }
546   }
547 }
548 
549 class VariantObject {
550  public:
VariantObject()551   VariantObject() {}
VariantObject(const string & metadata,int64 value)552   VariantObject(const string& metadata, int64 value)
553       : metadata_(metadata), value_(value) {}
554 
TypeName() const555   string TypeName() const { return "TEST VariantObject"; }
Encode(VariantTensorData * data) const556   void Encode(VariantTensorData* data) const {
557     data->set_type_name(TypeName());
558     data->set_metadata(metadata_);
559     Tensor val_t = Tensor(DT_INT64, TensorShape({}));
560     val_t.scalar<int64>()() = value_;
561     *(data->add_tensors()) = val_t;
562   }
Decode(const VariantTensorData & data)563   bool Decode(const VariantTensorData& data) {
564     EXPECT_EQ(data.type_name(), TypeName());
565     data.get_metadata(&metadata_);
566     EXPECT_EQ(data.tensors_size(), 1);
567     value_ = data.tensors(0).scalar<int64>()();
568     return true;
569   }
operator ==(const VariantObject other) const570   bool operator==(const VariantObject other) const {
571     return metadata_ == other.metadata_ && value_ == other.value_;
572   }
573   string metadata_;
574   int64 value_;
575 };
576 
577 REGISTER_UNARY_VARIANT_DECODE_FUNCTION(VariantObject, "TEST VariantObject");
578 
// Round-trips a two-element DT_VARIANT tensor of VariantObject through a
// bundle.
TEST(TensorBundleTest, VariantTensors) {
  const string bundle_prefix = Prefix("foo");
  // Builds the tensor used both as write input and as read expectation.
  const auto MakeVariants = []() {
    return test::AsTensor<Variant>(
        {VariantObject("test", 10), VariantObject("test1", 20)});
  };

  {
    BundleWriter writer(Env::Default(), bundle_prefix);
    TF_EXPECT_OK(writer.Add("variant_tensor", MakeVariants()));
    TF_ASSERT_OK(writer.Finish());
  }
  {
    BundleReader reader(Env::Default(), bundle_prefix);
    TF_ASSERT_OK(reader.status());
    ExpectVariant<VariantObject>(&reader, "variant_tensor", MakeVariants());
  }
}
597 
// Checks which files exist on disk after writing two single-shard bundles,
// after "merging" a single bundle to a new location, and after merging two
// bundles into one.
TEST(TensorBundleTest, DirectoryStructure) {
  Env* env = Env::Default();

  // Writes two bundles, each holding one tensor.
  const std::vector<string> kBundlePrefixes = {Prefix("worker0"),
                                               Prefix("worker1")};
  for (int worker = 0; worker < 2; ++worker) {
    BundleWriter writer(env, kBundlePrefixes[worker]);
    const string tensor_key = strings::StrCat("tensor", worker);
    TF_EXPECT_OK(writer.Add(tensor_key, Constant_2x3<float>(0.)));
    TF_ASSERT_OK(writer.Finish());
  }

  // Expects every name in `expected_files` to exist in the directory that
  // contains `bundle_prefix`.
  auto ExpectFilesNextTo = [env](const string& bundle_prefix,
                                 gtl::ArraySlice<string> expected_files) {
    StringPiece bundle_dir = io::Dirname(bundle_prefix);
    for (const string& file_name : expected_files) {
      const string full_path = io::JoinPath(bundle_dir, file_name);
      TF_EXPECT_OK(env->FileExists(full_path));
    }
  };

  // Each worker bundle has an index file and a single data shard.
  ExpectFilesNextTo(kBundlePrefixes[0],
                    {"worker0.index", "worker0.data-00000-of-00001"});
  ExpectFilesNextTo(kBundlePrefixes[1],
                    {"worker1.index", "worker1.data-00000-of-00001"});

  // Trivially "merge" one bundle to some other location (i.e., a renaming).
  const string kAnotherPrefix = Prefix("another");
  TF_ASSERT_OK(MergeBundles(env, {kBundlePrefixes[0]}, kAnotherPrefix));
  ExpectFilesNextTo(kAnotherPrefix,
                    {"another.index", "another.data-00000-of-00001"});

  // A real merge of two bundles yields one index file and two data shards.
  const string kMerged = Prefix("merged");
  TF_ASSERT_OK(
      MergeBundles(env, {kAnotherPrefix, kBundlePrefixes[1]}, kMerged));
  ExpectFilesNextTo(kMerged,
                    {"merged.index", "merged.data-00000-of-00002",
                     "merged.data-00001-of-00002"});
}
643 
TEST(TensorBundleTest,Error)644 TEST(TensorBundleTest, Error) {
645   {  // Dup keys.
646     BundleWriter writer(Env::Default(), Prefix("dup"));
647     TF_EXPECT_OK(writer.Add("foo", Constant_2x3(1.f)));
648     EXPECT_FALSE(writer.Add("foo", Constant_2x3(2.f)).ok());
649     EXPECT_TRUE(
650         str_util::StrContains(writer.status().ToString(), "duplicate key"));
651     EXPECT_FALSE(writer.Finish().ok());
652   }
653   {  // Double finish
654     BundleWriter writer(Env::Default(), Prefix("bad"));
655     EXPECT_TRUE(writer.Finish().ok());
656     EXPECT_FALSE(writer.Finish().ok());
657   }
658   {  // Not found.
659     BundleReader reader(Env::Default(), Prefix("nonexist"));
660     EXPECT_TRUE(str_util::StrContains(reader.status().ToString(), "Not found"));
661   }
662 }
663 
TEST(TensorBundleTest,Checksum)664 TEST(TensorBundleTest, Checksum) {
665   // Randomly flips a byte in [pos_lhs, end of data file), or exactly byte
666   // pos_lhs if exact_pos == True.
667   auto FlipByte = [](const string& prefix, int pos_lhs,
668                      bool exact_pos = false) {
669     DCHECK_GE(pos_lhs, 0);
670     const string& datafile = DataFilename(Prefix(prefix), 0, 1);
671     string data;
672     TF_ASSERT_OK(ReadFileToString(Env::Default(), datafile, &data));
673 
674     int byte_pos = 0;
675     if (!exact_pos) {
676       std::mt19937 rng;
677       std::uniform_int_distribution<int> dist(pos_lhs, data.size() - 1);
678       byte_pos = dist(rng);
679     } else {
680       byte_pos = pos_lhs;
681     }
682     data[byte_pos] = ~data[byte_pos];
683     TF_ASSERT_OK(WriteStringToFile(Env::Default(), datafile, data));
684   };
685   // The lookup should fail with a checksum-related message.
686   auto ExpectLookupFails = [](const string& prefix, const string& key,
687                               const string& expected_msg, Tensor& val) {
688     BundleReader reader(Env::Default(), Prefix(prefix));
689     Status status = reader.Lookup(key, &val);
690     EXPECT_TRUE(errors::IsDataLoss(status));
691     EXPECT_TRUE(str_util::StrContains(status.ToString(), expected_msg));
692   };
693 
694   // Corrupts a float tensor.
695   {
696     BundleWriter writer(Env::Default(), Prefix("singleton"));
697     TF_EXPECT_OK(writer.Add("foo", Constant_2x3(1.f)));
698     TF_ASSERT_OK(writer.Finish());
699 
700     FlipByte("singleton", 0 /* corrupts any byte */);
701     Tensor val(DT_FLOAT, TensorShape({2, 3}));
702     ExpectLookupFails("singleton", "foo",
703                       "Checksum does not match" /* expected fail msg */, val);
704   }
705   // Corrupts a string tensor.
706   {
707     auto WriteStrings = []() {
708       BundleWriter writer(Env::Default(), Prefix("strings"));
709       TF_EXPECT_OK(
710           writer.Add("foo", test::AsTensor<string>({"hello", "world"})));
711       TF_ASSERT_OK(writer.Finish());
712     };
713     // Corrupts the first two bytes, which are the varint32-encoded lengths
714     // of the two string elements.  Should hit mismatch on length cksum.
715     for (int i = 0; i < 2; ++i) {
716       WriteStrings();
717       FlipByte("strings", i, true /* corrupts exactly byte i */);
718       Tensor val(DT_STRING, TensorShape({2}));
719       ExpectLookupFails(
720           "strings", "foo",
721           "length checksum does not match" /* expected fail msg */, val);
722     }
723     // Corrupts the string bytes, should hit an overall cksum mismatch.
724     WriteStrings();
725     FlipByte("strings", 2 /* corrupts starting from byte 2 */);
726     Tensor val(DT_STRING, TensorShape({2}));
727     ExpectLookupFails("strings", "foo",
728                       "Checksum does not match" /* expected fail msg */, val);
729   }
730 }
731 
// A bundle whose header claims the opposite endianness must be rejected by
// the reader with an Unimplemented error.
TEST(TensorBundleTest, Endianness) {
  const string bundle_prefix = Prefix("end");
  BundleWriter writer(Env::Default(), bundle_prefix);
  TF_EXPECT_OK(writer.Add("key", Constant_2x3<float>(1.0)));
  TF_ASSERT_OK(writer.Finish());

  // Rewrites the metadata with the endianness marker flipped.
  TF_ASSERT_OK(FlipEndiannessBit(bundle_prefix));

  BundleReader reader(Env::Default(), bundle_prefix);
  const Status open_status = reader.status();
  EXPECT_TRUE(errors::IsUnimplemented(open_status));
  EXPECT_TRUE(str_util::StrContains(open_status.ToString(),
                                    "different endianness from the reader"));
}
745 
// Looking up a tensor whose data file has lost its final byte must fail with
// an OutOfRange error.
TEST(TensorBundleTest, TruncatedTensorContents) {
  Env* env = Env::Default();
  const string bundle_prefix = Prefix("end");
  BundleWriter writer(env, bundle_prefix);
  TF_EXPECT_OK(writer.Add("key", Constant_2x3<float>(1.0)));
  TF_ASSERT_OK(writer.Finish());

  // Rewrites the data file without its last byte, so that the read hits EOF.
  const string datafile = DataFilename(bundle_prefix, 0, 1);
  string contents;
  TF_ASSERT_OK(ReadFileToString(env, datafile, &contents));
  ASSERT_FALSE(contents.empty());
  const StringPiece truncated(contents.data(), contents.size() - 1);
  TF_ASSERT_OK(WriteStringToFile(env, datafile, truncated));

  // The metadata is intact, so opening succeeds; only the lookup fails.
  BundleReader reader(env, bundle_prefix);
  TF_ASSERT_OK(reader.status());
  Tensor val(DT_FLOAT, TensorShape({2, 3}));
  const Status lookup_status = reader.Lookup("key", &val);
  EXPECT_TRUE(errors::IsOutOfRange(lookup_status));
}
765 
// Writes a one-tensor bundle, reads back the entry stored under
// kHeaderEntryKey, and checks the shard count, endianness and version fields
// of the parsed BundleHeaderProto.
TEST(TensorBundleTest, HeaderEntry) {
  Env* const env = Env::Default();
  const string bundle_prefix = Prefix("b");

  // The writer lives in its own scope, as does the reader below.
  {
    BundleWriter writer(env, bundle_prefix);
    TF_EXPECT_OK(writer.Add("key", Constant_2x3<float>(1.0)));
    TF_ASSERT_OK(writer.Finish());
  }

  // Position the reader on the header entry and parse its raw value.
  BundleHeaderProto header;
  {
    BundleReader reader(env, bundle_prefix);
    TF_ASSERT_OK(reader.status());
    reader.Seek(kHeaderEntryKey);
    ASSERT_TRUE(reader.Valid());
    const auto serialized_header = reader.value();
    ASSERT_TRUE(ParseProtoUnlimited(&header, serialized_header.data(),
                                    serialized_header.size()));
  }

  // Shard count.
  EXPECT_EQ(1, header.num_shards());

  // Endianness must match the byte order of the host running the test.
  const auto expected_endianness =
      port::kLittleEndian ? BundleHeaderProto::LITTLE : BundleHeaderProto::BIG;
  EXPECT_EQ(expected_endianness, header.endianness());

  // Version fields.
  EXPECT_GT(kTensorBundleVersion, 0);
  const auto& version = header.version();
  EXPECT_EQ(kTensorBundleVersion, version.producer());
  EXPECT_EQ(kTensorBundleMinConsumer, version.min_consumer());
}
797 
// Feeds three unacceptable VersionDef configurations to the VersionTest
// helper (defined elsewhere in this file) together with the error text each
// one is expected to produce.
TEST(TensorBundleTest, VersionTest) {
  // Case 1: min_consumer is one above the current bundle version.
  {
    const auto newer_version = kTensorBundleVersion + 1;
    VersionDef versions;
    versions.set_producer(newer_version);
    versions.set_min_consumer(newer_version);
    const string expected_error = strings::StrCat(
        "Checkpoint min consumer version ", newer_version,
        " above current version ", kTensorBundleVersion, " for TensorFlow");
    VersionTest(versions, expected_error);
  }

  // Case 2: producer is one below the minimum supported producer.
  {
    const auto stale_producer = kTensorBundleMinProducer - 1;
    VersionDef versions;
    versions.set_producer(stale_producer);
    const string expected_error = strings::StrCat(
        "Checkpoint producer version ", stale_producer, " below min producer ",
        kTensorBundleMinProducer, " supported by TensorFlow");
    VersionTest(versions, expected_error);
  }

  // Case 3: the current bundle version is listed as a bad consumer.
  {
    VersionDef versions;
    versions.set_producer(kTensorBundleVersion + 1);
    versions.add_bad_consumers(kTensorBundleVersion);
    const string expected_error = strings::StrCat(
        "Checkpoint disallows consumer version ", kTensorBundleVersion,
        ".  Please upgrade TensorFlow: this version is likely buggy.");
    VersionTest(versions, expected_error);
  }
}
832 
833 class TensorBundleAlignmentTest : public ::testing::Test {
834  protected:
835   template <typename T>
ExpectAlignment(BundleReader * reader,const string & key,int alignment)836   void ExpectAlignment(BundleReader* reader, const string& key, int alignment) {
837     BundleEntryProto full_tensor_entry;
838     TF_ASSERT_OK(reader->GetBundleEntryProto(key, &full_tensor_entry));
839     EXPECT_EQ(0, full_tensor_entry.offset() % alignment);
840   }
841 };
842 
// Writes four tensors with data_alignment set to 42, then checks keyed
// lookup, ordered iteration, and the alignment of every entry's offset.
TEST_F(TensorBundleAlignmentTest, AlignmentTest) {
  Env* const env = Env::Default();
  const string bundle_prefix = Prefix("foo");
  const int kAlignment = 42;
  const int kNumTensors = 4;

  // Write the tensors, deliberately not in key order. Tensor "foo_00<id>"
  // is a 2x3 float constant filled with <id>.
  {
    BundleWriter::Options opts;
    opts.data_alignment = kAlignment;
    BundleWriter writer(env, bundle_prefix, opts);
    const int insertion_order[] = {3, 0, 2, 1};
    for (const int id : insertion_order) {
      TF_EXPECT_OK(writer.Add(strings::StrCat("foo_00", id),
                              Constant_2x3<float>(id)));
    }
    TF_ASSERT_OK(writer.Finish());
  }

  // Keys are listed in sorted order and each one resolves to its tensor.
  {
    BundleReader reader(env, bundle_prefix);
    TF_ASSERT_OK(reader.status());
    const std::vector<string> expected_keys(
        {"foo_000", "foo_001", "foo_002", "foo_003"});
    EXPECT_EQ(AllTensorKeys(&reader), expected_keys);
    for (int id = 0; id < kNumTensors; ++id) {
      Expect<float>(&reader, strings::StrCat("foo_00", id),
                    Constant_2x3<float>(id));
    }
  }

  // Iteration yields the tensors in key order. After the last one the reader
  // is still valid; a further Next() invalidates it.
  {
    BundleReader reader(env, bundle_prefix);
    TF_ASSERT_OK(reader.status());
    for (int id = 0; id < kNumTensors; ++id) {
      ExpectNext<float>(&reader, Constant_2x3<float>(id));
    }
    EXPECT_TRUE(reader.Valid());
    reader.Next();
    EXPECT_FALSE(reader.Valid());
  }

  // Every entry's offset honours the requested alignment.
  {
    BundleReader reader(env, bundle_prefix);
    TF_ASSERT_OK(reader.status());
    for (int id = 0; id < kNumTensors; ++id) {
      ExpectAlignment<float>(&reader, strings::StrCat("foo_00", id),
                             kAlignment);
    }
  }
}
885 
BM_BundleAlignmentByteOff(int iters,int alignment,int tensor_size)886 static void BM_BundleAlignmentByteOff(int iters, int alignment,
887                                       int tensor_size) {
888   testing::StopTiming();
889   {
890     BundleWriter::Options opts;
891     opts.data_alignment = alignment;
892     BundleWriter writer(Env::Default(), Prefix("foo"), opts);
893     TF_CHECK_OK(writer.Add("small", Constant(true, TensorShape({1}))));
894     TF_CHECK_OK(writer.Add("big", Constant(32.1, TensorShape({tensor_size}))));
895     TF_CHECK_OK(writer.Finish());
896   }
897   BundleReader reader(Env::Default(), Prefix("foo"));
898   TF_CHECK_OK(reader.status());
899   testing::StartTiming();
900   for (int i = 0; i < iters; ++i) {
901     Tensor t;
902     TF_CHECK_OK(reader.Lookup("big", &t));
903   }
904   testing::StopTiming();
905 }
906 
907 #define BM_BundleAlignment(ALIGN, SIZE)                        \
908   static void BM_BundleAlignment_##ALIGN##_##SIZE(int iters) { \
909     BM_BundleAlignmentByteOff(iters, ALIGN, SIZE);             \
910   }                                                            \
911   BENCHMARK(BM_BundleAlignment_##ALIGN##_##SIZE)
912 
913 BM_BundleAlignment(1, 512);
914 BM_BundleAlignment(1, 4096);
915 BM_BundleAlignment(1, 1048576);
916 BM_BundleAlignment(4096, 512);
917 BM_BundleAlignment(4096, 4096);
918 BM_BundleAlignment(4096, 1048576);
919 
920 }  // namespace tensorflow
921