• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7 http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
17 
#include <algorithm>
#include <complex>
#include <cstring>
#include <memory>
#include <random>
#include <string>
#include <vector>
21 
22 #if defined(_WIN32)
23 #include <windows.h>
24 #endif  // _WIN32
25 
26 #include "tensorflow/core/framework/tensor_testutil.h"
27 #include "tensorflow/core/framework/tensor_util.h"
28 #include "tensorflow/core/framework/types.pb.h"
29 #include "tensorflow/core/framework/variant.h"
30 #include "tensorflow/core/framework/variant_op_registry.h"
31 #include "tensorflow/core/framework/versions.pb.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33 #include "tensorflow/core/lib/io/path.h"
34 #include "tensorflow/core/lib/io/table_builder.h"
35 #include "tensorflow/core/lib/strings/str_util.h"
36 #include "tensorflow/core/lib/strings/strcat.h"
37 #include "tensorflow/core/platform/test.h"
38 #include "tensorflow/core/platform/test_benchmark.h"
39 #include "tensorflow/core/protobuf/error_codes.pb.h"
40 #include "tensorflow/core/protobuf/tensor_bundle.pb.h"
41 #include "tensorflow/core/util/tensor_bundle/byte_swap.h"
42 
43 namespace tensorflow {
44 using ::testing::ElementsAre;
45 
46 namespace {
47 
48 // Prepend the current test case's working temporary directory to <prefix>
Prefix(const string & prefix)49 string Prefix(const string& prefix) {
50   return strings::StrCat(testing::TmpDir(), "/", prefix);
51 }
52 
53 // Construct a data input directory by prepending the test data root
54 // directory to <prefix>
TestdataPrefix(const string & prefix)55 string TestdataPrefix(const string& prefix) {
56   return strings::StrCat(testing::TensorFlowSrcRoot(),
57                          "/core/util/tensor_bundle/testdata/", prefix);
58 }
59 
60 template <typename T>
Constant(T v,TensorShape shape)61 Tensor Constant(T v, TensorShape shape) {
62   Tensor ret(DataTypeToEnum<T>::value, shape);
63   ret.flat<T>().setConstant(v);
64   return ret;
65 }
66 
67 template <typename T>
Constant_2x3(T v)68 Tensor Constant_2x3(T v) {
69   return Constant(v, TensorShape({2, 3}));
70 }
71 
ByteSwap(Tensor t)72 Tensor ByteSwap(Tensor t) {
73   Tensor ret = tensor::DeepCopy(t);
74   TF_EXPECT_OK(ByteSwapTensor(&ret));
75   return ret;
76 }
77 
78 // Assert that <reader> has a tensor under <key> matching <expected_val> in
79 // terms of both shape, dtype, and value
80 template <typename T>
Expect(BundleReader * reader,const string & key,const Tensor & expected_val)81 void Expect(BundleReader* reader, const string& key,
82             const Tensor& expected_val) {
83   // Tests for Contains().
84   EXPECT_TRUE(reader->Contains(key));
85   // Tests for LookupDtypeAndShape().
86   DataType dtype;
87   TensorShape shape;
88   TF_ASSERT_OK(reader->LookupDtypeAndShape(key, &dtype, &shape));
89   EXPECT_EQ(expected_val.dtype(), dtype);
90   EXPECT_EQ(expected_val.shape(), shape);
91   // Tests for Lookup(), checking tensor contents.
92   Tensor val(expected_val.dtype(), shape);
93   TF_ASSERT_OK(reader->Lookup(key, &val));
94   test::ExpectTensorEqual<T>(val, expected_val);
95 }
96 
97 template <class T>
ExpectVariant(BundleReader * reader,const string & key,const Tensor & expected_t)98 void ExpectVariant(BundleReader* reader, const string& key,
99                    const Tensor& expected_t) {
100   // Tests for Contains().
101   EXPECT_TRUE(reader->Contains(key));
102   // Tests for LookupDtypeAndShape().
103   DataType dtype;
104   TensorShape shape;
105   TF_ASSERT_OK(reader->LookupDtypeAndShape(key, &dtype, &shape));
106   // Tests for Lookup(), checking tensor contents.
107   EXPECT_EQ(expected_t.dtype(), dtype);
108   EXPECT_EQ(expected_t.shape(), shape);
109   Tensor actual_t(dtype, shape);
110   TF_ASSERT_OK(reader->Lookup(key, &actual_t));
111   for (int i = 0; i < expected_t.NumElements(); i++) {
112     Variant actual_var = actual_t.flat<Variant>()(i);
113     Variant expected_var = expected_t.flat<Variant>()(i);
114     EXPECT_EQ(actual_var.TypeName(), expected_var.TypeName());
115     auto* actual_val = actual_var.get<T>();
116     auto* expected_val = expected_var.get<T>();
117     EXPECT_EQ(*expected_val, *actual_val);
118   }
119 }
120 
121 template <typename T>
ExpectNext(BundleReader * reader,const Tensor & expected_val)122 void ExpectNext(BundleReader* reader, const Tensor& expected_val) {
123   EXPECT_TRUE(reader->Valid());
124   reader->Next();
125   TF_ASSERT_OK(reader->status());
126   Tensor val;
127   TF_ASSERT_OK(reader->ReadCurrent(&val));
128   test::ExpectTensorEqual<T>(val, expected_val);
129 }
130 
AllTensorKeys(BundleReader * reader)131 std::vector<string> AllTensorKeys(BundleReader* reader) {
132   std::vector<string> ret;
133   reader->Seek(kHeaderEntryKey);
134   reader->Next();
135   for (; reader->Valid(); reader->Next()) {
136     ret.emplace_back(reader->key());
137   }
138   return ret;
139 }
140 
141 // Writes out the metadata file of a bundle again, with the endianness marker
142 // bit flipped.
FlipEndiannessBit(const string & prefix)143 Status FlipEndiannessBit(const string& prefix) {
144   Env* env = Env::Default();
145   const string metadata_tmp_path = Prefix("some_tmp_path");
146   std::unique_ptr<WritableFile> metadata_file;
147   TF_RETURN_IF_ERROR(env->NewWritableFile(metadata_tmp_path, &metadata_file));
148   // We create the builder lazily in case we run into an exception earlier, in
149   // which case we'd forget to call Finish() and TableBuilder's destructor
150   // would complain.
151   std::unique_ptr<table::TableBuilder> builder;
152 
153   // Reads the existing metadata file, and fills the builder.
154   {
155     const string filename = MetaFilename(prefix);
156     uint64 file_size;
157     TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));
158     std::unique_ptr<RandomAccessFile> file;
159     TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));
160 
161     table::Table* table = nullptr;
162     TF_RETURN_IF_ERROR(
163         table::Table::Open(table::Options(), file.get(), file_size, &table));
164     std::unique_ptr<table::Table> table_deleter(table);
165     std::unique_ptr<table::Iterator> iter(table->NewIterator());
166 
167     // Reads the header entry.
168     iter->Seek(kHeaderEntryKey);
169     CHECK(iter->Valid());
170     BundleHeaderProto header;
171     CHECK(header.ParseFromArray(iter->value().data(), iter->value().size()));
172     // Flips the endianness.
173     if (header.endianness() == BundleHeaderProto::LITTLE) {
174       header.set_endianness(BundleHeaderProto::BIG);
175     } else {
176       header.set_endianness(BundleHeaderProto::LITTLE);
177     }
178     builder.reset(
179         new table::TableBuilder(table::Options(), metadata_file.get()));
180     builder->Add(iter->key(), header.SerializeAsString());
181     iter->Next();
182 
183     // Adds the non-header entries unmodified.
184     for (; iter->Valid(); iter->Next())
185       builder->Add(iter->key(), iter->value());
186   }
187   TF_RETURN_IF_ERROR(builder->Finish());
188   TF_RETURN_IF_ERROR(env->RenameFile(metadata_tmp_path, MetaFilename(prefix)));
189   return metadata_file->Close();
190 }
191 
192 template <typename T>
TestBasic()193 void TestBasic() {
194   {
195     BundleWriter writer(Env::Default(), Prefix("foo"));
196     TF_EXPECT_OK(writer.Add("foo_003", Constant_2x3(T(3))));
197     TF_EXPECT_OK(writer.Add("foo_000", Constant_2x3(T(0))));
198     TF_EXPECT_OK(writer.Add("foo_002", Constant_2x3(T(2))));
199     TF_EXPECT_OK(writer.Add("foo_001", Constant_2x3(T(1))));
200     TF_ASSERT_OK(writer.Finish());
201   }
202   {
203     BundleReader reader(Env::Default(), Prefix("foo"));
204     TF_ASSERT_OK(reader.status());
205     EXPECT_EQ(
206         AllTensorKeys(&reader),
207         std::vector<string>({"foo_000", "foo_001", "foo_002", "foo_003"}));
208     Expect<T>(&reader, "foo_000", Constant_2x3(T(0)));
209     Expect<T>(&reader, "foo_001", Constant_2x3(T(1)));
210     Expect<T>(&reader, "foo_002", Constant_2x3(T(2)));
211     Expect<T>(&reader, "foo_003", Constant_2x3(T(3)));
212   }
213   {
214     BundleReader reader(Env::Default(), Prefix("foo"));
215     TF_ASSERT_OK(reader.status());
216     ExpectNext<T>(&reader, Constant_2x3(T(0)));
217     ExpectNext<T>(&reader, Constant_2x3(T(1)));
218     ExpectNext<T>(&reader, Constant_2x3(T(2)));
219     ExpectNext<T>(&reader, Constant_2x3(T(3)));
220     EXPECT_TRUE(reader.Valid());
221     reader.Next();
222     EXPECT_FALSE(reader.Valid());
223   }
224   {
225     BundleWriter writer(Env::Default(), Prefix("bar"));
226     TF_EXPECT_OK(writer.Add("bar_003", Constant_2x3(T(3))));
227     TF_EXPECT_OK(writer.Add("bar_000", Constant_2x3(T(0))));
228     TF_EXPECT_OK(writer.Add("bar_002", Constant_2x3(T(2))));
229     TF_EXPECT_OK(writer.Add("bar_001", Constant_2x3(T(1))));
230     TF_ASSERT_OK(writer.Finish());
231   }
232   {
233     BundleReader reader(Env::Default(), Prefix("bar"));
234     TF_ASSERT_OK(reader.status());
235     EXPECT_EQ(
236         AllTensorKeys(&reader),
237         std::vector<string>({"bar_000", "bar_001", "bar_002", "bar_003"}));
238     Expect<T>(&reader, "bar_003", Constant_2x3(T(3)));
239     Expect<T>(&reader, "bar_002", Constant_2x3(T(2)));
240     Expect<T>(&reader, "bar_001", Constant_2x3(T(1)));
241     Expect<T>(&reader, "bar_000", Constant_2x3(T(0)));
242   }
243   {
244     BundleReader reader(Env::Default(), Prefix("bar"));
245     TF_ASSERT_OK(reader.status());
246     ExpectNext<T>(&reader, Constant_2x3(T(0)));
247     ExpectNext<T>(&reader, Constant_2x3(T(1)));
248     ExpectNext<T>(&reader, Constant_2x3(T(2)));
249     ExpectNext<T>(&reader, Constant_2x3(T(3)));
250     EXPECT_TRUE(reader.Valid());
251     reader.Next();
252     EXPECT_FALSE(reader.Valid());
253   }
254   TF_ASSERT_OK(MergeBundles(Env::Default(), {Prefix("foo"), Prefix("bar")},
255                             Prefix("merged")));
256   {
257     BundleReader reader(Env::Default(), Prefix("merged"));
258     TF_ASSERT_OK(reader.status());
259     EXPECT_EQ(
260         AllTensorKeys(&reader),
261         std::vector<string>({"bar_000", "bar_001", "bar_002", "bar_003",
262                              "foo_000", "foo_001", "foo_002", "foo_003"}));
263     Expect<T>(&reader, "bar_000", Constant_2x3(T(0)));
264     Expect<T>(&reader, "bar_001", Constant_2x3(T(1)));
265     Expect<T>(&reader, "bar_002", Constant_2x3(T(2)));
266     Expect<T>(&reader, "bar_003", Constant_2x3(T(3)));
267     Expect<T>(&reader, "foo_000", Constant_2x3(T(0)));
268     Expect<T>(&reader, "foo_001", Constant_2x3(T(1)));
269     Expect<T>(&reader, "foo_002", Constant_2x3(T(2)));
270     Expect<T>(&reader, "foo_003", Constant_2x3(T(3)));
271   }
272   {
273     BundleReader reader(Env::Default(), Prefix("merged"));
274     TF_ASSERT_OK(reader.status());
275     ExpectNext<T>(&reader, Constant_2x3(T(0)));
276     ExpectNext<T>(&reader, Constant_2x3(T(1)));
277     ExpectNext<T>(&reader, Constant_2x3(T(2)));
278     ExpectNext<T>(&reader, Constant_2x3(T(3)));
279     ExpectNext<T>(&reader, Constant_2x3(T(0)));
280     ExpectNext<T>(&reader, Constant_2x3(T(1)));
281     ExpectNext<T>(&reader, Constant_2x3(T(2)));
282     ExpectNext<T>(&reader, Constant_2x3(T(3)));
283     EXPECT_TRUE(reader.Valid());
284     reader.Next();
285     EXPECT_FALSE(reader.Valid());
286   }
287 }
288 
289 // Type-specific subroutine of SwapBytes test below
290 template <typename T>
TestByteSwap(const T * forward,const T * swapped,int array_len)291 void TestByteSwap(const T* forward, const T* swapped, int array_len) {
292   auto bytes_per_elem = sizeof(T);
293 
294   // Convert the entire array at once
295   std::unique_ptr<T[]> forward_copy(new T[array_len]);
296   std::memcpy(forward_copy.get(), forward, array_len * bytes_per_elem);
297   TF_EXPECT_OK(ByteSwapArray(reinterpret_cast<char*>(forward_copy.get()),
298                              bytes_per_elem, array_len));
299   for (int i = 0; i < array_len; i++) {
300     EXPECT_EQ(forward_copy.get()[i], swapped[i]);
301   }
302 
303   // Then the array wrapped in a tensor
304   auto shape = TensorShape({array_len});
305   auto dtype = DataTypeToEnum<T>::value;
306   Tensor forward_tensor(dtype, shape);
307   Tensor swapped_tensor(dtype, shape);
308   std::memcpy(const_cast<char*>(forward_tensor.tensor_data().data()), forward,
309               array_len * bytes_per_elem);
310   std::memcpy(const_cast<char*>(swapped_tensor.tensor_data().data()), swapped,
311               array_len * bytes_per_elem);
312   TF_EXPECT_OK(ByteSwapTensor(&forward_tensor));
313   test::ExpectTensorEqual<T>(forward_tensor, swapped_tensor);
314 }
315 
316 // Unit test of the byte-swapping operations that TensorBundle uses.
TEST(TensorBundleTest,SwapBytes)317 TEST(TensorBundleTest, SwapBytes) {
318   // A bug in the compiler on MacOS causes ByteSwap() and FlipEndiannessBit()
319   // to be removed from the executable if they are only called from templated
320   // functions. As a workaround, we make some dummy calls here.
321   // TODO(frreiss): Remove this workaround when the compiler bug is fixed.
322   ByteSwap(Constant_2x3<int>(42));
323   EXPECT_NE(OkStatus(), FlipEndiannessBit(Prefix("not_a_valid_prefix")));
324 
325   // Test patterns, manually swapped so that we aren't relying on the
326   // correctness of our own byte-swapping macros when testing those macros.
327   // At least one of the entries in each list has the sign bit set when
328   // interpreted as a signed int.
329   const int arr_len_16 = 4;
330   const uint16_t forward_16[] = {0x1de5, 0xd017, 0xf1ea, 0xc0a1};
331   const uint16_t swapped_16[] = {0xe51d, 0x17d0, 0xeaf1, 0xa1c0};
332   const int arr_len_32 = 2;
333   const uint32_t forward_32[] = {0x0ddba115, 0xf01dab1e};
334   const uint32_t swapped_32[] = {0x15a1db0d, 0x1eab1df0};
335   const int arr_len_64 = 2;
336   const uint64_t forward_64[] = {0xf005ba11caba1000, 0x5ca1ab1ecab005e5};
337   const uint64_t swapped_64[] = {0x0010baca11ba05f0, 0xe505b0ca1eaba15c};
338 
339   // 16-bit types
340   TestByteSwap(forward_16, swapped_16, arr_len_16);
341   TestByteSwap(reinterpret_cast<const int16_t*>(forward_16),
342                reinterpret_cast<const int16_t*>(swapped_16), arr_len_16);
343   TestByteSwap(reinterpret_cast<const bfloat16*>(forward_16),
344                reinterpret_cast<const bfloat16*>(swapped_16), arr_len_16);
345 
346   // 32-bit types
347   TestByteSwap(forward_32, swapped_32, arr_len_32);
348   TestByteSwap(reinterpret_cast<const int32_t*>(forward_32),
349                reinterpret_cast<const int32_t*>(swapped_32), arr_len_32);
350   TestByteSwap(reinterpret_cast<const float*>(forward_32),
351                reinterpret_cast<const float*>(swapped_32), arr_len_32);
352 
353   // 64-bit types
354   // Cast to uint64*/int64* to make DataTypeToEnum<T> happy
355   TestByteSwap(reinterpret_cast<const uint64*>(forward_64),
356                reinterpret_cast<const uint64*>(swapped_64), arr_len_64);
357   TestByteSwap(reinterpret_cast<const int64_t*>(forward_64),
358                reinterpret_cast<const int64_t*>(swapped_64), arr_len_64);
359   TestByteSwap(reinterpret_cast<const double*>(forward_64),
360                reinterpret_cast<const double*>(swapped_64), arr_len_64);
361 
362   // Complex types.
363   // Logic for complex number handling is only in ByteSwapTensor, so don't test
364   // ByteSwapArray
365   const float* forward_float = reinterpret_cast<const float*>(forward_32);
366   const float* swapped_float = reinterpret_cast<const float*>(swapped_32);
367   const double* forward_double = reinterpret_cast<const double*>(forward_64);
368   const double* swapped_double = reinterpret_cast<const double*>(swapped_64);
369   Tensor forward_complex64 = Constant_2x3<complex64>(
370       std::complex<float>(forward_float[0], forward_float[1]));
371   Tensor swapped_complex64 = Constant_2x3<complex64>(
372       std::complex<float>(swapped_float[0], swapped_float[1]));
373   Tensor forward_complex128 = Constant_2x3<complex128>(
374       std::complex<double>(forward_double[0], forward_double[1]));
375   Tensor swapped_complex128 = Constant_2x3<complex128>(
376       std::complex<double>(swapped_double[0], swapped_double[1]));
377 
378   TF_EXPECT_OK(ByteSwapTensor(&forward_complex64));
379   test::ExpectTensorEqual<complex64>(forward_complex64, swapped_complex64);
380 
381   TF_EXPECT_OK(ByteSwapTensor(&forward_complex128));
382   test::ExpectTensorEqual<complex128>(forward_complex128, swapped_complex128);
383 }
384 
385 // Basic test of alternate-endianness support. Generates a bundle in
386 // the opposite of the current system's endianness and attempts to
387 // read the bundle back in. Does not exercise sharding or access to
388 // nonaligned tensors. Does cover the major access types exercised
389 // in TestBasic.
390 template <typename T>
TestEndianness()391 void TestEndianness() {
392   {
393     // Write out a TensorBundle in the opposite of this host's endianness.
394     BundleWriter writer(Env::Default(), Prefix("foo"));
395     TF_EXPECT_OK(writer.Add("foo_003", ByteSwap(Constant_2x3<T>(T(3)))));
396     TF_EXPECT_OK(writer.Add("foo_000", ByteSwap(Constant_2x3<T>(T(0)))));
397     TF_EXPECT_OK(writer.Add("foo_002", ByteSwap(Constant_2x3<T>(T(2)))));
398     TF_EXPECT_OK(writer.Add("foo_001", ByteSwap(Constant_2x3<T>(T(1)))));
399     TF_ASSERT_OK(writer.Finish());
400     TF_ASSERT_OK(FlipEndiannessBit(Prefix("foo")));
401   }
402   {
403     BundleReader reader(Env::Default(), Prefix("foo"));
404     TF_ASSERT_OK(reader.status());
405     EXPECT_EQ(
406         AllTensorKeys(&reader),
407         std::vector<string>({"foo_000", "foo_001", "foo_002", "foo_003"}));
408     Expect<T>(&reader, "foo_000", Constant_2x3<T>(T(0)));
409     Expect<T>(&reader, "foo_001", Constant_2x3<T>(T(1)));
410     Expect<T>(&reader, "foo_002", Constant_2x3<T>(T(2)));
411     Expect<T>(&reader, "foo_003", Constant_2x3<T>(T(3)));
412   }
413   {
414     BundleReader reader(Env::Default(), Prefix("foo"));
415     TF_ASSERT_OK(reader.status());
416     ExpectNext<T>(&reader, Constant_2x3<T>(T(0)));
417     ExpectNext<T>(&reader, Constant_2x3<T>(T(1)));
418     ExpectNext<T>(&reader, Constant_2x3<T>(T(2)));
419     ExpectNext<T>(&reader, Constant_2x3<T>(T(3)));
420     EXPECT_TRUE(reader.Valid());
421     reader.Next();
422     EXPECT_FALSE(reader.Valid());
423   }
424   {
425     BundleWriter writer(Env::Default(), Prefix("bar"));
426     TF_EXPECT_OK(writer.Add("bar_003", ByteSwap(Constant_2x3<T>(T(3)))));
427     TF_EXPECT_OK(writer.Add("bar_000", ByteSwap(Constant_2x3<T>(T(0)))));
428     TF_EXPECT_OK(writer.Add("bar_002", ByteSwap(Constant_2x3<T>(T(2)))));
429     TF_EXPECT_OK(writer.Add("bar_001", ByteSwap(Constant_2x3<T>(T(1)))));
430     TF_ASSERT_OK(writer.Finish());
431     TF_ASSERT_OK(FlipEndiannessBit(Prefix("bar")));
432   }
433   {
434     BundleReader reader(Env::Default(), Prefix("bar"));
435     TF_ASSERT_OK(reader.status());
436     EXPECT_EQ(
437         AllTensorKeys(&reader),
438         std::vector<string>({"bar_000", "bar_001", "bar_002", "bar_003"}));
439     Expect<T>(&reader, "bar_003", Constant_2x3<T>(T(3)));
440     Expect<T>(&reader, "bar_002", Constant_2x3<T>(T(2)));
441     Expect<T>(&reader, "bar_001", Constant_2x3<T>(T(1)));
442     Expect<T>(&reader, "bar_000", Constant_2x3<T>(T(0)));
443   }
444   {
445     BundleReader reader(Env::Default(), Prefix("bar"));
446     TF_ASSERT_OK(reader.status());
447     ExpectNext<T>(&reader, Constant_2x3<T>(T(0)));
448     ExpectNext<T>(&reader, Constant_2x3<T>(T(1)));
449     ExpectNext<T>(&reader, Constant_2x3<T>(T(2)));
450     ExpectNext<T>(&reader, Constant_2x3<T>(T(3)));
451     EXPECT_TRUE(reader.Valid());
452     reader.Next();
453     EXPECT_FALSE(reader.Valid());
454   }
455   TF_ASSERT_OK(MergeBundles(Env::Default(), {Prefix("foo"), Prefix("bar")},
456                             Prefix("merged")));
457   {
458     BundleReader reader(Env::Default(), Prefix("merged"));
459     TF_ASSERT_OK(reader.status());
460     EXPECT_EQ(
461         AllTensorKeys(&reader),
462         std::vector<string>({"bar_000", "bar_001", "bar_002", "bar_003",
463                              "foo_000", "foo_001", "foo_002", "foo_003"}));
464     Expect<T>(&reader, "bar_000", Constant_2x3<T>(T(0)));
465     Expect<T>(&reader, "bar_001", Constant_2x3<T>(T(1)));
466     Expect<T>(&reader, "bar_002", Constant_2x3<T>(T(2)));
467     Expect<T>(&reader, "bar_003", Constant_2x3<T>(T(3)));
468     Expect<T>(&reader, "foo_000", Constant_2x3<T>(T(0)));
469     Expect<T>(&reader, "foo_001", Constant_2x3<T>(T(1)));
470     Expect<T>(&reader, "foo_002", Constant_2x3<T>(T(2)));
471     Expect<T>(&reader, "foo_003", Constant_2x3<T>(T(3)));
472   }
473   {
474     BundleReader reader(Env::Default(), Prefix("merged"));
475     TF_ASSERT_OK(reader.status());
476     ExpectNext<T>(&reader, Constant_2x3<T>(T(0)));
477     ExpectNext<T>(&reader, Constant_2x3<T>(T(1)));
478     ExpectNext<T>(&reader, Constant_2x3<T>(T(2)));
479     ExpectNext<T>(&reader, Constant_2x3<T>(T(3)));
480     ExpectNext<T>(&reader, Constant_2x3<T>(T(0)));
481     ExpectNext<T>(&reader, Constant_2x3<T>(T(1)));
482     ExpectNext<T>(&reader, Constant_2x3<T>(T(2)));
483     ExpectNext<T>(&reader, Constant_2x3<T>(T(3)));
484     EXPECT_TRUE(reader.Valid());
485     reader.Next();
486     EXPECT_FALSE(reader.Valid());
487   }
488 }
489 
490 template <typename T>
TestNonStandardShapes()491 void TestNonStandardShapes() {
492   {
493     BundleWriter writer(Env::Default(), Prefix("nonstandard"));
494     TF_EXPECT_OK(writer.Add("scalar", Constant(T(0), TensorShape())));
495     TF_EXPECT_OK(
496         writer.Add("non_standard0", Constant(T(0), TensorShape({0, 1618}))));
497     TF_EXPECT_OK(
498         writer.Add("non_standard1", Constant(T(0), TensorShape({16, 0, 18}))));
499     TF_ASSERT_OK(writer.Finish());
500   }
501   {
502     BundleReader reader(Env::Default(), Prefix("nonstandard"));
503     TF_ASSERT_OK(reader.status());
504     Expect<T>(&reader, "scalar", Constant(T(0), TensorShape()));
505     Expect<T>(&reader, "non_standard0", Constant(T(0), TensorShape({0, 1618})));
506     Expect<T>(&reader, "non_standard1",
507               Constant(T(0), TensorShape({16, 0, 18})));
508   }
509 }
510 
511 // Writes a bundle to disk with a bad "version"; checks for "expected_error".
VersionTest(const VersionDef & version,StringPiece expected_error)512 void VersionTest(const VersionDef& version, StringPiece expected_error) {
513   const string path = Prefix("version_test");
514   {
515     // Prepare an empty bundle with the given version information.
516     BundleHeaderProto header;
517     *header.mutable_version() = version;
518 
519     // Write the metadata file to disk.
520     std::unique_ptr<WritableFile> file;
521     TF_ASSERT_OK(Env::Default()->NewWritableFile(MetaFilename(path), &file));
522     table::TableBuilder builder(table::Options(), file.get());
523     builder.Add(kHeaderEntryKey, header.SerializeAsString());
524     TF_ASSERT_OK(builder.Finish());
525   }
526   // Read it back in and verify that we get the expected error.
527   BundleReader reader(Env::Default(), path);
528   EXPECT_TRUE(errors::IsInvalidArgument(reader.status()));
529   EXPECT_TRUE(
530       absl::StartsWith(reader.status().error_message(), expected_error));
531 }
532 
533 }  // namespace
534 
TEST(TensorBundleTest,Basic)535 TEST(TensorBundleTest, Basic) {
536   TestBasic<float>();
537   TestBasic<double>();
538   TestBasic<int32>();
539   TestBasic<uint8>();
540   TestBasic<int16>();
541   TestBasic<int8>();
542   TestBasic<complex64>();
543   TestBasic<complex128>();
544   TestBasic<int64_t>();
545   TestBasic<bool>();
546   TestBasic<qint32>();
547   TestBasic<quint8>();
548   TestBasic<qint8>();
549   TestBasic<bfloat16>();
550 }
551 
TEST(TensorBundleTest,Endianness)552 TEST(TensorBundleTest, Endianness) {
553   TestEndianness<float>();
554   TestEndianness<double>();
555   TestEndianness<int32>();
556   TestEndianness<uint8>();
557   TestEndianness<int16>();
558   TestEndianness<int8>();
559   TestEndianness<complex64>();
560   TestEndianness<complex128>();
561   TestEndianness<int64_t>();
562   TestEndianness<bool>();
563   TestEndianness<qint32>();
564   TestEndianness<quint8>();
565   TestEndianness<qint8>();
566   TestEndianness<bfloat16>();
567 }
568 
TEST(TensorBundleTest,PartitionedVariables)569 TEST(TensorBundleTest, PartitionedVariables) {
570   const TensorShape kFullShape({5, 10});
571   // Adds two slices.
572   // First slice: column 0, all zeros.
573   // Second slice: column 1 to rest, all ones.
574   TensorSlice slice1 = TensorSlice::ParseOrDie("-:0,1");
575   TensorSlice slice2 = TensorSlice::ParseOrDie("-:1,9");
576   {
577     BundleWriter writer(Env::Default(), Prefix("foo"));
578 
579     TF_ASSERT_OK(writer.AddSlice("foo", kFullShape, slice1,
580                                  Constant<float>(0., TensorShape({5, 1}))));
581     TF_ASSERT_OK(writer.AddSlice("foo", kFullShape, slice2,
582                                  Constant<float>(1., TensorShape({5, 9}))));
583     TF_ASSERT_OK(writer.Finish());
584   }
585   // Reads in full.
586   {
587     BundleReader reader(Env::Default(), Prefix("foo"));
588     TF_ASSERT_OK(reader.status());
589 
590     Tensor expected_val(DT_FLOAT, kFullShape);
591     test::FillFn<float>(&expected_val, [](int offset) -> float {
592       if (offset % 10 == 0) {
593         return 0;  // First column zeros.
594       }
595       return 1;  // Other columns ones.
596     });
597 
598     Tensor val(DT_FLOAT, kFullShape);
599     TF_ASSERT_OK(reader.Lookup("foo", &val));
600     test::ExpectTensorEqual<float>(val, expected_val);
601   }
602   // Reads all slices.
603   {
604     BundleReader reader(Env::Default(), Prefix("foo"));
605     TF_ASSERT_OK(reader.status());
606 
607     std::vector<TensorSlice> slices;
608     TF_ASSERT_OK(reader.LookupTensorSlices("foo", &slices));
609 
610     EXPECT_EQ(2, slices.size());
611     EXPECT_EQ(slice1.DebugString(), slices[0].DebugString());
612     EXPECT_EQ(slice2.DebugString(), slices[1].DebugString());
613   }
614   // Reads a slice consisting of first two columns, "cutting" both slices.
615   {
616     BundleReader reader(Env::Default(), Prefix("foo"));
617     TF_ASSERT_OK(reader.status());
618 
619     // First two columns, "cutting" both slices.
620     const TensorSlice distinct_slice = TensorSlice::ParseOrDie("-:0,2");
621     Tensor expected_val(DT_FLOAT, TensorShape({5, 2}));
622     test::FillFn<float>(&expected_val, [](int offset) -> float {
623       if (offset % 2 == 0) {
624         return 0;  // First column zeros.
625       }
626       return 1;  // Other columns ones.
627     });
628 
629     Tensor val(DT_FLOAT, TensorShape({5, 2}));
630     TF_ASSERT_OK(reader.LookupSlice("foo", distinct_slice, &val));
631     test::ExpectTensorEqual<float>(val, expected_val);
632   }
633   // Reads a slice consisting of columns 2-4, "cutting" the second slice only.
634   {
635     BundleReader reader(Env::Default(), Prefix("foo"));
636     TF_ASSERT_OK(reader.status());
637 
638     const TensorSlice distinct_slice = TensorSlice::ParseOrDie("-:2,2");
639     Tensor val(DT_FLOAT, TensorShape({5, 2}));
640     TF_ASSERT_OK(reader.LookupSlice("foo", distinct_slice, &val));
641     test::ExpectTensorEqual<float>(val,
642                                    Constant<float>(1., TensorShape({5, 2})));
643   }
644 }
645 
// Stores the same full 5x10 tensor twice: once under a slice spec with no
// extents ("-:-") and once under a fully specified one ("0,5:0,10"). Both
// spellings denote the whole tensor, so looking either entry up with either
// spelling must return the full tensor.
TEST(TensorBundleTest, EquivalentSliceTest) {
  const TensorShape kFullShape({5, 10});
  const Tensor kExpected(Constant<float>(1., kFullShape));
  {
    BundleWriter writer(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(writer.AddSlice("no_extents", kFullShape,
                                 TensorSlice::ParseOrDie("-:-"), kExpected));
    TF_ASSERT_OK(writer.AddSlice("both_extents", kFullShape,
                                 TensorSlice::ParseOrDie("0,5:0,10"),
                                 kExpected));
    TF_ASSERT_OK(writer.Finish());
  }

  // (stored key, lookup spec) pairs, checked in this order, each with a
  // freshly opened reader.
  struct LookupCase {
    const char* key;
    const char* spec;
  };
  const LookupCase kCases[] = {
      // Slices match exactly and are fully abbreviated.
      {"no_extents", "-:-"},
      // Slices match exactly and are fully specified.
      {"both_extents", "0,5:0,10"},
      // Stored slice has no extents, spec has extents.
      {"no_extents", "0,5:0,10"},
      // Stored slice has both extents, spec has no extents.
      {"both_extents", "-:-"},
  };
  for (const LookupCase& c : kCases) {
    BundleReader reader(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(reader.status());
    const TensorSlice slice = TensorSlice::ParseOrDie(c.spec);
    Tensor val(DT_FLOAT, kFullShape);
    TF_ASSERT_OK(reader.LookupSlice(c.key, slice, &val));
    test::ExpectTensorEqual<float>(val, kExpected);
  }
}
695 
TEST(TensorBundleTest,NonStandardShapes)696 TEST(TensorBundleTest, NonStandardShapes) {
697   TestNonStandardShapes<float>();
698   TestNonStandardShapes<double>();
699   TestNonStandardShapes<int32>();
700   TestNonStandardShapes<uint8>();
701   TestNonStandardShapes<int16>();
702   TestNonStandardShapes<int8>();
703   TestNonStandardShapes<complex64>();
704   TestNonStandardShapes<complex128>();
705   TestNonStandardShapes<int64_t>();
706   TestNonStandardShapes<bool>();
707   TestNonStandardShapes<qint32>();
708   TestNonStandardShapes<quint8>();
709   TestNonStandardShapes<qint8>();
710   TestNonStandardShapes<bfloat16>();
711 }
712 
// Reads a checked-in bundle written by a previous version of the code, which
// stored string lengths as varint32s (the current code uses varint64s), and
// checks that its contents still decode correctly.
TEST(TensorBundleTest, StringTensorsOldFormat) {
  BundleReader reader(Env::Default(), TestdataPrefix("old_string_tensors/foo"));
  TF_ASSERT_OK(reader.status());
  const std::vector<string> expected_keys = {"floats", "scalar",
                                             "string_tensor", "strs"};
  EXPECT_EQ(AllTensorKeys(&reader), expected_keys);

  const Tensor empty_strings(DT_STRING, TensorShape({1}));
  Expect<tstring>(&reader, "string_tensor", empty_strings);
  Expect<tstring>(&reader, "scalar", test::AsTensor<tstring>({"hello"}));
  const string one_kib_of_c(1 << 10, 'c');
  Expect<tstring>(&reader, "strs",
                  test::AsTensor<tstring>({"hello", "", "x01", one_kib_of_c}));
  Expect<float>(&reader, "floats", Constant_2x3<float>(16.18));
}
729 
// Returns the system memory page size in bytes (logic copied from absl).
// On Windows the larger of the page size and the allocation granularity is
// returned; elsewhere the value comes from getpagesize() or sysconf().
size_t GetPageSize() {
#if defined(_WIN32)
  SYSTEM_INFO info;
  GetSystemInfo(&info);
  const DWORD page = info.dwPageSize;
  const DWORD granularity = info.dwAllocationGranularity;
  return granularity > page ? granularity : page;
#elif defined(__wasm__) || defined(__asmjs__)
  return getpagesize();
#else
  const long page_size = sysconf(_SC_PAGESIZE);
  return static_cast<size_t>(page_size);
#endif
}
742 
TEST(TensorBundleTest,StringTensors)743 TEST(TensorBundleTest, StringTensors) {
744   constexpr size_t kLongLength = static_cast<size_t>(UINT32_MAX) + 1;
745   Tensor long_string_tensor(DT_STRING, TensorShape({1}));
746 
747   {
748     BundleWriter writer(Env::Default(), Prefix("foo"));
749     TF_EXPECT_OK(writer.Add("string_tensor",
750                             Tensor(DT_STRING, TensorShape({1}))));  // Empty.
751     TF_EXPECT_OK(writer.Add("scalar", test::AsTensor<tstring>({"hello"})));
752     TF_EXPECT_OK(writer.Add(
753         "strs",
754         test::AsTensor<tstring>({"hello", "", "x01", string(1 << 25, 'c')})));
755 
756     // Requires a 64-bit length.
757     tstring* backing_string = long_string_tensor.flat<tstring>().data();
758     backing_string->resize_uninitialized(kLongLength);
759     std::char_traits<char>::assign(backing_string->data(), kLongLength, 'd');
760     TF_EXPECT_OK(writer.Add("long_scalar", long_string_tensor));
761 
762     // Mixes in some floats.
763     TF_EXPECT_OK(writer.Add("floats", Constant_2x3<float>(16.18)));
764     TF_ASSERT_OK(writer.Finish());
765   }
766   {
767     BundleReader reader(Env::Default(), Prefix("foo"));
768     TF_ASSERT_OK(reader.status());
769     EXPECT_EQ(AllTensorKeys(&reader),
770               std::vector<string>({"floats", "long_scalar", "scalar",
771                                    "string_tensor", "strs"}));
772 
773     Expect<tstring>(&reader, "string_tensor",
774                     Tensor(DT_STRING, TensorShape({1})));
775     Expect<tstring>(&reader, "scalar", test::AsTensor<tstring>({"hello"}));
776     Expect<tstring>(
777         &reader, "strs",
778         test::AsTensor<tstring>({"hello", "", "x01", string(1 << 25, 'c')}));
779 
780     Expect<float>(&reader, "floats", Constant_2x3<float>(16.18));
781 
782     // We don't use the Expect function so we can re-use the
783     // `long_string_tensor` buffer for reading out long_scalar to keep memory
784     // usage reasonable.
785     EXPECT_TRUE(reader.Contains("long_scalar"));
786     DataType dtype;
787     TensorShape shape;
788     TF_ASSERT_OK(reader.LookupDtypeAndShape("long_scalar", &dtype, &shape));
789     EXPECT_EQ(DT_STRING, dtype);
790     EXPECT_EQ(TensorShape({1}), shape);
791 
792     // Fill the string differently so that we can be sure the new one is read
793     // in. Because fragmentation in tc-malloc and we have such a big tensor
794     // of 4GB, therefore it is not ideal to free the buffer right now.
795     // The rationale is to make allocation/free close to each other.
796     tstring* backing_string = long_string_tensor.flat<tstring>().data();
797     std::char_traits<char>::assign(backing_string->data(), kLongLength, 'e');
798 
799     // Read long_scalar and check it contains kLongLength 'd's.
800     TF_ASSERT_OK(reader.Lookup("long_scalar", &long_string_tensor));
801     ASSERT_EQ(backing_string, long_string_tensor.flat<tstring>().data());
802     EXPECT_EQ(kLongLength, backing_string->length());
803 
804     const size_t kPageSize = GetPageSize();
805     char* testblock = new char[kPageSize];
806     memset(testblock, 'd', sizeof(char) * kPageSize);
807     for (size_t i = 0; i < kLongLength; i += kPageSize) {
808       if (memcmp(testblock, backing_string->data() + i, kPageSize) != 0) {
809         FAIL() << "long_scalar is not full of 'd's as expected.";
810         break;
811       }
812     }
813     delete[] testblock;
814   }
815 }
816 
817 class VariantObject {
818  public:
VariantObject()819   VariantObject() {}
VariantObject(const string & metadata,int64_t value)820   VariantObject(const string& metadata, int64_t value)
821       : metadata_(metadata), value_(value) {}
822 
TypeName() const823   string TypeName() const { return "TEST VariantObject"; }
Encode(VariantTensorData * data) const824   void Encode(VariantTensorData* data) const {
825     data->set_type_name(TypeName());
826     data->set_metadata(metadata_);
827     Tensor val_t = Tensor(DT_INT64, TensorShape({}));
828     val_t.scalar<int64_t>()() = value_;
829     *(data->add_tensors()) = val_t;
830   }
Decode(const VariantTensorData & data)831   bool Decode(const VariantTensorData& data) {
832     EXPECT_EQ(data.type_name(), TypeName());
833     data.get_metadata(&metadata_);
834     EXPECT_EQ(data.tensors_size(), 1);
835     value_ = data.tensors(0).scalar<int64_t>()();
836     return true;
837   }
operator ==(const VariantObject other) const838   bool operator==(const VariantObject other) const {
839     return metadata_ == other.metadata_ && value_ == other.value_;
840   }
841   string metadata_;
842   int64_t value_;
843 };
844 
845 REGISTER_UNARY_VARIANT_DECODE_FUNCTION(VariantObject, "TEST VariantObject");
846 
TEST(TensorBundleTest, VariantTensors) {
  // Builds a fresh two-element Variant tensor of VariantObjects; called once
  // for the value written and once for the value expected on read.
  auto make_variants = []() {
    return test::AsTensor<Variant>(
        {VariantObject("test", 10), VariantObject("test1", 20)});
  };

  {
    BundleWriter writer(Env::Default(), Prefix("foo"));
    TF_EXPECT_OK(writer.Add("variant_tensor", make_variants()));
    TF_ASSERT_OK(writer.Finish());
  }

  {
    BundleReader reader(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(reader.status());
    ExpectVariant<VariantObject>(&reader, "variant_tensor", make_variants());
  }
}
865 
TEST(TensorBundleTest, DirectoryStructure) {
  Env* env = Env::Default();
  const std::vector<string> kBundlePrefixes = {Prefix("worker0"),
                                               Prefix("worker1")};

  // Write one single-tensor bundle per worker prefix.
  for (size_t w = 0; w < kBundlePrefixes.size(); ++w) {
    BundleWriter writer(env, kBundlePrefixes[w]);
    TF_EXPECT_OK(
        writer.Add(strings::StrCat("tensor", w), Constant_2x3<float>(0.)));
    TF_ASSERT_OK(writer.Finish());
  }

  // Expects every listed file to exist next to `bundle_prefix`.
  auto CheckDirFiles = [env](const string& bundle_prefix,
                             gtl::ArraySlice<string> expected_files) {
    const StringPiece directory = io::Dirname(bundle_prefix);
    for (const string& name : expected_files) {
      TF_EXPECT_OK(env->FileExists(io::JoinPath(directory, name)));
    }
  };

  // Each worker bundle has worker<i>.index and worker<i>.data-00000-of-00001.
  for (int w = 0; w < 2; ++w) {
    const string stem = strings::StrCat("worker", w);
    CheckDirFiles(kBundlePrefixes[w],
                  {strings::StrCat(stem, ".index"),
                   strings::StrCat(stem, ".data-00000-of-00001")});
  }

  // Merging a single bundle is effectively a rename to the new prefix.
  const string kAnotherPrefix = Prefix("another");
  TF_ASSERT_OK(MergeBundles(env, {kBundlePrefixes[0]}, kAnotherPrefix));
  CheckDirFiles(kAnotherPrefix,
                {"another.index", "another.data-00000-of-00001"});

  // Merging two bundles yields one index file and one data shard per input.
  const string kMerged = Prefix("merged");
  TF_ASSERT_OK(
      MergeBundles(env, {kAnotherPrefix, kBundlePrefixes[1]}, kMerged));
  CheckDirFiles(kMerged, {"merged.index", "merged.data-00000-of-00002",
                          "merged.data-00001-of-00002"});
}
911 
TEST(TensorBundleTest, SortForSequentialAccess) {
  Env* env = Env::Default();
  const std::vector<string> kBundlePrefixes = {Prefix("worker0"),
                                               Prefix("worker1")};

  // Shard 0 receives tensor-0-0, tensor-0-1, tensor-0-2 in ascending order.
  BundleWriter writer0(env, kBundlePrefixes[0]);
  for (int idx : {0, 1, 2}) {
    TF_EXPECT_OK(writer0.Add(strings::StrCat("tensor-0-", idx),
                             Constant_2x3<float>(0.)));
  }
  TF_ASSERT_OK(writer0.Finish());

  // Shard 1 receives tensor-1-2, tensor-1-1, tensor-1-0 in descending order.
  BundleWriter writer1(env, kBundlePrefixes[1]);
  for (int idx : {2, 1, 0}) {
    TF_EXPECT_OK(writer1.Add(strings::StrCat("tensor-1-", idx),
                             Constant_2x3<float>(0.)));
  }
  TF_ASSERT_OK(writer1.Finish());

  const string kMerged = Prefix("merged");
  TF_ASSERT_OK(
      MergeBundles(env, {kBundlePrefixes[0], kBundlePrefixes[1]}, kMerged));

  // After sorting, names follow shard order and, within a shard, the order
  // in which the tensors were written.
  BundleReader reader(env, kMerged);
  TF_ASSERT_OK(reader.status());
  std::vector<string> tensor_names = {"tensor-1-0", "tensor-0-1", "tensor-1-2",
                                      "tensor-0-0", "tensor-1-1", "tensor-0-2"};
  auto identity_key = [](const string& name) { return name; };
  TF_ASSERT_OK(reader.SortForSequentialAccess<string>(tensor_names,
                                                      identity_key));
  EXPECT_THAT(tensor_names,
              ElementsAre("tensor-0-0", "tensor-0-1", "tensor-0-2",
                          "tensor-1-2", "tensor-1-1", "tensor-1-0"));
}
948 
TEST(TensorBundleTest,Error)949 TEST(TensorBundleTest, Error) {
950   {  // Dup keys.
951     BundleWriter writer(Env::Default(), Prefix("dup"));
952     TF_EXPECT_OK(writer.Add("foo", Constant_2x3(1.f)));
953     EXPECT_FALSE(writer.Add("foo", Constant_2x3(2.f)).ok());
954     EXPECT_TRUE(absl::StrContains(writer.status().ToString(), "duplicate key"));
955     EXPECT_FALSE(writer.Finish().ok());
956   }
957   {  // Double finish
958     BundleWriter writer(Env::Default(), Prefix("bad"));
959     EXPECT_TRUE(writer.Finish().ok());
960     EXPECT_FALSE(writer.Finish().ok());
961   }
962   {  // Not found.
963     BundleReader reader(Env::Default(), Prefix("nonexist"));
964     EXPECT_EQ(reader.status().code(), error::NOT_FOUND);
965   }
966 }
967 
TEST(TensorBundleTest,Checksum)968 TEST(TensorBundleTest, Checksum) {
969   // Randomly flips a byte in [pos_lhs, end of data file), or exactly byte
970   // pos_lhs if exact_pos == True.
971   auto FlipByte = [](const string& prefix, int pos_lhs,
972                      bool exact_pos = false) {
973     DCHECK_GE(pos_lhs, 0);
974     const string& datafile = DataFilename(Prefix(prefix), 0, 1);
975     string data;
976     TF_ASSERT_OK(ReadFileToString(Env::Default(), datafile, &data));
977 
978     int byte_pos = 0;
979     if (!exact_pos) {
980       std::mt19937 rng;
981       std::uniform_int_distribution<int> dist(pos_lhs, data.size() - 1);
982       byte_pos = dist(rng);
983     } else {
984       byte_pos = pos_lhs;
985     }
986     data[byte_pos] = ~data[byte_pos];
987     TF_ASSERT_OK(WriteStringToFile(Env::Default(), datafile, data));
988   };
989   // The lookup should fail with a checksum-related message.
990   auto ExpectLookupFails = [](const string& prefix, const string& key,
991                               const string& expected_msg, Tensor& val) {
992     BundleReader reader(Env::Default(), Prefix(prefix));
993     Status status = reader.Lookup(key, &val);
994     EXPECT_TRUE(errors::IsDataLoss(status));
995     EXPECT_TRUE(absl::StrContains(status.ToString(), expected_msg));
996   };
997 
998   // Corrupts a float tensor.
999   {
1000     BundleWriter writer(Env::Default(), Prefix("singleton"));
1001     TF_EXPECT_OK(writer.Add("foo", Constant_2x3(1.f)));
1002     TF_ASSERT_OK(writer.Finish());
1003 
1004     FlipByte("singleton", 0 /* corrupts any byte */);
1005     Tensor val(DT_FLOAT, TensorShape({2, 3}));
1006     ExpectLookupFails("singleton", "foo",
1007                       "Checksum does not match" /* expected fail msg */, val);
1008   }
1009   // Corrupts a string tensor.
1010   {
1011     auto WriteStrings = []() {
1012       BundleWriter writer(Env::Default(), Prefix("strings"));
1013       TF_EXPECT_OK(
1014           writer.Add("foo", test::AsTensor<tstring>({"hello", "world"})));
1015       TF_ASSERT_OK(writer.Finish());
1016     };
1017     // Corrupts the first two bytes, which are the varint32-encoded lengths
1018     // of the two string elements.  Should hit mismatch on length cksum.
1019     for (int i = 0; i < 2; ++i) {
1020       WriteStrings();
1021       FlipByte("strings", i, true /* corrupts exactly byte i */);
1022       Tensor val(DT_STRING, TensorShape({2}));
1023       ExpectLookupFails(
1024           "strings", "foo",
1025           "length checksum does not match" /* expected fail msg */, val);
1026     }
1027     // Corrupts the string bytes, should hit an overall cksum mismatch.
1028     WriteStrings();
1029     FlipByte("strings", 2 /* corrupts starting from byte 2 */);
1030     Tensor val(DT_STRING, TensorShape({2}));
1031     ExpectLookupFails("strings", "foo",
1032                       "Checksum does not match" /* expected fail msg */, val);
1033   }
1034 }
1035 
TEST(TensorBundleTest, TruncatedTensorContents) {
  Env* const env = Env::Default();
  const string prefix = Prefix("end");

  BundleWriter writer(env, prefix);
  TF_EXPECT_OK(writer.Add("key", Constant_2x3<float>(1.0)));
  TF_ASSERT_OK(writer.Finish());

  // Drop the final byte of the data file so the lookup below runs into EOF.
  const string datafile = DataFilename(prefix, 0, 1);
  string contents;
  TF_ASSERT_OK(ReadFileToString(env, datafile, &contents));
  ASSERT_FALSE(contents.empty());
  contents.pop_back();
  TF_ASSERT_OK(WriteStringToFile(env, datafile, contents));

  // Opening still succeeds; only reading the truncated tensor fails.
  BundleReader reader(env, prefix);
  TF_ASSERT_OK(reader.status());
  Tensor val(DT_FLOAT, TensorShape({2, 3}));
  const Status lookup_status = reader.Lookup("key", &val);
  EXPECT_TRUE(errors::IsOutOfRange(lookup_status));
}
1055 
TEST(TensorBundleTest, HeaderEntry) {
  {
    BundleWriter writer(Env::Default(), Prefix("b"));
    TF_EXPECT_OK(writer.Add("key", Constant_2x3<float>(1.0)));
    TF_ASSERT_OK(writer.Finish());
  }

  // Seek to the header entry (kHeaderEntryKey) and parse its raw value.
  BundleHeaderProto header;
  {
    BundleReader reader(Env::Default(), Prefix("b"));
    TF_ASSERT_OK(reader.status());
    reader.Seek(kHeaderEntryKey);
    ASSERT_TRUE(reader.Valid());
    const auto raw_header = reader.value();
    ASSERT_TRUE(
        ParseProtoUnlimited(&header, raw_header.data(), raw_header.size()));
  }

  // A single writer produces a single shard.
  EXPECT_EQ(1, header.num_shards());

  // Endianness recorded in the header matches the host.
  const auto expected_endianness = port::kLittleEndian
                                       ? BundleHeaderProto::LITTLE
                                       : BundleHeaderProto::BIG;
  EXPECT_EQ(expected_endianness, header.endianness());

  // Version fields come from the current bundle-format constants.
  const VersionDef& version = header.version();
  EXPECT_GT(kTensorBundleVersion, 0);
  EXPECT_EQ(kTensorBundleVersion, version.producer());
  EXPECT_EQ(kTensorBundleMinConsumer, version.min_consumer());
}
1087 
TEST(TensorBundleTest, VersionTest) {
  const int next_version = kTensorBundleVersion + 1;
  const int too_old_producer = kTensorBundleMinProducer - 1;

  // min_consumer is newer than the current version.
  {
    VersionDef versions;
    versions.set_producer(next_version);
    versions.set_min_consumer(next_version);
    const string expected_error =
        strings::StrCat("Checkpoint min consumer version ", next_version,
                        " above current version ", kTensorBundleVersion,
                        " for TensorFlow");
    VersionTest(versions, expected_error);
  }
  // Producer is older than the minimum supported producer.
  {
    VersionDef versions;
    versions.set_producer(too_old_producer);
    const string expected_error = strings::StrCat(
        "Checkpoint producer version ", too_old_producer,
        " below min producer ", kTensorBundleMinProducer,
        " supported by TensorFlow");
    VersionTest(versions, expected_error);
  }
  // The current version is listed as a bad consumer.
  {
    VersionDef versions;
    versions.set_producer(next_version);
    versions.add_bad_consumers(kTensorBundleVersion);
    const string expected_error = strings::StrCat(
        "Checkpoint disallows consumer version ", kTensorBundleVersion,
        ".  Please upgrade TensorFlow: this version is likely buggy.");
    VersionTest(versions, expected_error);
  }
}
1122 
1123 class TensorBundleAlignmentTest : public ::testing::Test {
1124  protected:
1125   template <typename T>
ExpectAlignment(BundleReader * reader,const string & key,int alignment)1126   void ExpectAlignment(BundleReader* reader, const string& key, int alignment) {
1127     BundleEntryProto full_tensor_entry;
1128     TF_ASSERT_OK(reader->GetBundleEntryProto(key, &full_tensor_entry));
1129     EXPECT_EQ(0, full_tensor_entry.offset() % alignment);
1130   }
1131 };
1132 
TEST_F(TensorBundleAlignmentTest, AlignmentTest) {
  constexpr int kAlignment = 42;
  // Key "foo_00<i>" holds a 2x3 float tensor filled with i.
  auto key_for = [](int i) { return strings::StrCat("foo_00", i); };

  {
    BundleWriter::Options opts;
    opts.data_alignment = kAlignment;
    BundleWriter writer(Env::Default(), Prefix("foo"), opts);
    // Keys are added out of order; the reader below should list them sorted.
    for (int i : {3, 0, 2, 1}) {
      TF_EXPECT_OK(writer.Add(key_for(i), Constant_2x3<float>(i)));
    }
    TF_ASSERT_OK(writer.Finish());
  }
  // Random access by key.
  {
    BundleReader reader(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(reader.status());
    EXPECT_EQ(
        AllTensorKeys(&reader),
        std::vector<string>({"foo_000", "foo_001", "foo_002", "foo_003"}));
    for (int i = 0; i < 4; ++i) {
      Expect<float>(&reader, key_for(i), Constant_2x3<float>(i));
    }
  }
  // Sequential iteration visits the tensors in key order.
  {
    BundleReader reader(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(reader.status());
    for (int i = 0; i < 4; ++i) {
      ExpectNext<float>(&reader, Constant_2x3<float>(i));
    }
    EXPECT_TRUE(reader.Valid());
    reader.Next();
    EXPECT_FALSE(reader.Valid());
  }
  // Every entry's offset honours the requested alignment.
  {
    BundleReader reader(Env::Default(), Prefix("foo"));
    TF_ASSERT_OK(reader.status());
    for (int i = 0; i < 4; ++i) {
      ExpectAlignment<float>(&reader, key_for(i), kAlignment);
    }
  }
}
1175 
BM_BundleAlignment(::testing::benchmark::State & state)1176 static void BM_BundleAlignment(::testing::benchmark::State& state) {
1177   {
1178     const int alignment = state.range(0);
1179     const int tensor_size = state.range(1);
1180     BundleWriter::Options opts;
1181     opts.data_alignment = alignment;
1182     BundleWriter writer(Env::Default(), Prefix("foo"), opts);
1183     TF_CHECK_OK(writer.Add("small", Constant(true, TensorShape({1}))));
1184     TF_CHECK_OK(writer.Add("big", Constant(32.1, TensorShape({tensor_size}))));
1185     TF_CHECK_OK(writer.Finish());
1186   }
1187   BundleReader reader(Env::Default(), Prefix("foo"));
1188   TF_CHECK_OK(reader.status());
1189   for (auto s : state) {
1190     Tensor t;
1191     TF_CHECK_OK(reader.Lookup("big", &t));
1192   }
1193 }
1194 
1195 BENCHMARK(BM_BundleAlignment)->ArgPair(1, 512);
1196 BENCHMARK(BM_BundleAlignment)->ArgPair(1, 4096);
1197 BENCHMARK(BM_BundleAlignment)->ArgPair(1, 1048576);
1198 BENCHMARK(BM_BundleAlignment)->ArgPair(4096, 512);
1199 BENCHMARK(BM_BundleAlignment)->ArgPair(4096, 4096);
1200 BENCHMARK(BM_BundleAlignment)->ArgPair(4096, 1048576);
1201 
BM_BundleWriterSmallTensor(::testing::benchmark::State & state)1202 static void BM_BundleWriterSmallTensor(::testing::benchmark::State& state) {
1203   const int64_t bytes = state.range(0);
1204   Tensor t = Constant(static_cast<int8>('a'), TensorShape{bytes});
1205   BundleWriter writer(Env::Default(), Prefix("foo"));
1206   int suffix = 0;
1207   for (auto s : state) {
1208     TF_CHECK_OK(writer.Add(strings::StrCat("small", suffix++), t));
1209   }
1210 }
1211 
1212 BENCHMARK(BM_BundleWriterSmallTensor)->Range(1, 1 << 20);
1213 
BM_BundleWriterLargeTensor(::testing::benchmark::State & state)1214 static void BM_BundleWriterLargeTensor(::testing::benchmark::State& state) {
1215   const int mb = state.range(0);
1216   const int64_t bytes = static_cast<int64_t>(mb) * (1 << 20);
1217   Tensor t = Constant(static_cast<int8>('a'), TensorShape{bytes});
1218   for (auto s : state) {
1219     BundleWriter writer(Env::Default(), Prefix("foo"));
1220     TF_CHECK_OK(writer.Add("big", t));
1221   }
1222 }
1223 
1224 BENCHMARK(BM_BundleWriterLargeTensor)->Arg(1 << 10);
1225 BENCHMARK(BM_BundleWriterLargeTensor)->Arg(4 << 10);
1226 
1227 }  // namespace tensorflow
1228