• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/util/tensor_slice_writer.h"

#include <array>
#include <type_traits>

#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
#include "tensorflow/core/util/tensor_slice_reader.h"
32 
33 namespace tensorflow {
34 
35 namespace checkpoint {
36 
37 class TensorSliceWriteTestHelper {
38  public:
39   static void CheckEntries(const string& fname);
40   static void GetData(TensorSliceReader::Table* table, const string& name,
41                       const TensorSlice& slice, SavedSlice* ss);
42 };
43 
44 namespace {
45 
46 // Testing that an array is what is expected
ExpectIdenticalFloatArrays(const float * expected,int size,const float * actual)47 void ExpectIdenticalFloatArrays(const float* expected, int size,
48                                 const float* actual) {
49   // TODO(yangke): copy some of the Dump* functions over
50   //  LOG(INFO) << "Expected = " << DumpFloatArray(expected, size);
51   //  LOG(INFO) << "Actual   = " << DumpFloatArray(actual, size);
52   for (int i = 0; i < size; ++i) {
53     EXPECT_NEAR(expected[i], actual[i], 1e-6);
54   }
55 }
56 
// Expects `actual[0, size)` to equal `expected[0, size)` element by element.
// `T` and `U` may differ: int16 values, for example, are stored in the
// proto's int32 `int_val` field. Both sides are therefore converted to their
// common (wider) type before comparing. Narrowing `actual` to `T` instead
// would let an out-of-range stored value (say 65546 for an int16 tensor)
// wrap around and falsely match.
template <typename T, typename U>
void ExpectIdenticalIntArrays(const T* expected, int size, const U* actual) {
  using Common = typename std::common_type<T, U>::type;
  for (int i = 0; i < size; ++i) {
    EXPECT_EQ(static_cast<Common>(expected[i]),
              static_cast<Common>(actual[i]));
  }
}
63 
// Returns the number of elements in a built-in array. Declared constexpr so
// the result can also be used in constant expressions.
template <typename T, size_t N>
constexpr size_t ArraySize(const T (&)[N]) {
  return N;
}
69 
70 // A simple test on writing a few tensor slices
71 // TODO(yangke): refactor into smaller tests: will do as we add more stuff to
72 // the writer.
TEST(TensorSliceWriteTest,SimpleWrite)73 TEST(TensorSliceWriteTest, SimpleWrite) {
74   const string filename = io::JoinPath(testing::TmpDir(), "checkpoint");
75 
76   TensorSliceWriter writer(filename, CreateTableTensorSliceBuilder);
77 
78   // Add some int32 tensor slices
79   {
80     TensorShape shape({5, 10});
81     TensorSlice slice = TensorSlice::ParseOrDie("-:0,1");
82     const int32 data[] = {0, 1, 2, 3, 4};
83     TF_CHECK_OK(writer.Add("test", shape, slice, data));
84   }
85 
86   // Two slices share the same tensor name
87   {
88     TensorShape shape({5, 10});
89     TensorSlice slice = TensorSlice::ParseOrDie("-:3,1");
90     const int32 data[] = {10, 11, 12, 13, 14};
91     TF_CHECK_OK(writer.Add("test", shape, slice, data));
92   }
93 
94   // Another slice from a different float tensor -- it has a different name and
95   // should be inserted in front of the previous tensor
96   {
97     TensorShape shape({3, 2});
98     TensorSlice slice = TensorSlice::ParseOrDie("-:-");
99     const float data[] = {1.2, 1.3, 1.4, 2.1, 2.2, 2.3};
100     TF_CHECK_OK(writer.Add("AA", shape, slice, data));
101   }
102 
103   // A slice with int64 data
104   {
105     TensorShape shape({5, 10});
106     TensorSlice slice = TensorSlice::ParseOrDie("-:3,1");
107     const int64 data[] = {10, 11, 12, 13, 14};
108     TF_CHECK_OK(writer.Add("int64", shape, slice, data));
109   }
110 
111   // A slice with int16 data
112   {
113     TensorShape shape({5, 10});
114     TensorSlice slice = TensorSlice::ParseOrDie("-:3,1");
115     const int16 data[] = {10, 11, 12, 13, 14};
116     TF_CHECK_OK(writer.Add("int16", shape, slice, data));
117   }
118 
119   TF_CHECK_OK(writer.Finish());
120 
121   // Now we examine the checkpoint file manually.
122   TensorSliceWriteTestHelper::CheckEntries(filename);
123 }
124 
125 }  // namespace
126 
GetData(TensorSliceReader::Table * table,const string & name,const TensorSlice & slice,SavedSlice * ss)127 void TensorSliceWriteTestHelper::GetData(TensorSliceReader::Table* table,
128                                          const string& name,
129                                          const TensorSlice& slice,
130                                          SavedSlice* ss) {
131   string key = EncodeTensorNameSlice(name, slice);
132   string value;
133   EXPECT_TRUE(table->Get(key, &value));
134   SavedTensorSlices sts;
135   EXPECT_TRUE(ParseProtoUnlimited(&sts, value));
136   EXPECT_FALSE(sts.has_meta());
137   *ss = sts.data();
138   EXPECT_EQ(name, ss->name());
139   TensorSlice slice2(ss->slice());
140   EXPECT_EQ(slice.DebugString(), slice2.DebugString());
141 }
142 
CheckEntries(const string & fname)143 void TensorSliceWriteTestHelper::CheckEntries(const string& fname) {
144   TensorSliceReader::Table* tptr;
145   TF_CHECK_OK(OpenTableTensorSliceReader(fname, &tptr));
146   std::unique_ptr<TensorSliceReader::Table> table(tptr);
147   CHECK_NOTNULL(table.get());
148 
149   // We expect a block of SavedTensorSlices
150   string value;
151   ASSERT_TRUE(table->Get(kSavedTensorSlicesKey, &value));
152   {
153     SavedTensorSlices sts;
154     EXPECT_TRUE(ParseProtoUnlimited(&sts, value));
155     // We also expect two entries for the tensors
156     EXPECT_TRUE(sts.has_meta());
157     EXPECT_EQ(4, sts.meta().tensor_size());
158     // We should have written nontrivial version information
159     EXPECT_LT(0, TF_CHECKPOINT_VERSION);
160     EXPECT_EQ(TF_CHECKPOINT_VERSION, sts.meta().versions().producer());
161     EXPECT_EQ(TF_CHECKPOINT_VERSION_MIN_CONSUMER,
162               sts.meta().versions().min_consumer());
163     // We don't expect any data in the first block.
164     EXPECT_FALSE(sts.has_data());
165     // The two tensors should be stored in the same order as they are first
166     // created.
167     {
168       // The two slices of the "test" tensor
169       const SavedSliceMeta& ssm = sts.meta().tensor(0);
170       EXPECT_EQ("test", ssm.name());
171       EXPECT_EQ(
172           "dim { size: 5 } "
173           "dim { size: 10 }",
174           ssm.shape().ShortDebugString());
175       EXPECT_EQ(DT_INT32, ssm.type());
176       EXPECT_EQ(2, ssm.slice_size());
177       TensorSlice s0(ssm.slice(0));
178       TensorSlice s1(ssm.slice(1));
179       EXPECT_EQ("-:0,1", s0.DebugString());
180       EXPECT_EQ("-:3,1", s1.DebugString());
181     }
182     {
183       // The "AA" tensor
184       const SavedSliceMeta& ssm = sts.meta().tensor(1);
185       EXPECT_EQ("AA", ssm.name());
186       EXPECT_EQ(
187           "dim { size: 3 } "
188           "dim { size: 2 }",
189           ssm.shape().ShortDebugString());
190       EXPECT_EQ(DT_FLOAT, ssm.type());
191       EXPECT_EQ(1, ssm.slice_size());
192       TensorSlice s0(ssm.slice(0));
193       EXPECT_EQ("-:-", s0.DebugString());
194     }
195     {
196       // The "int64" tensor
197       const SavedSliceMeta& ssm = sts.meta().tensor(2);
198       EXPECT_EQ("int64", ssm.name());
199       EXPECT_EQ(
200           "dim { size: 5 } "
201           "dim { size: 10 }",
202           ssm.shape().ShortDebugString());
203       EXPECT_EQ(DT_INT64, ssm.type());
204       EXPECT_EQ(1, ssm.slice_size());
205       TensorSlice s0(ssm.slice(0));
206       EXPECT_EQ("-:3,1", s0.DebugString());
207     }
208     {
209       // The "int16" tensor
210       const SavedSliceMeta& ssm = sts.meta().tensor(3);
211       EXPECT_EQ("int16", ssm.name());
212       EXPECT_EQ(
213           "dim { size: 5 } "
214           "dim { size: 10 }",
215           ssm.shape().ShortDebugString());
216       EXPECT_EQ(DT_INT16, ssm.type());
217       EXPECT_EQ(1, ssm.slice_size());
218       TensorSlice s0(ssm.slice(0));
219       EXPECT_EQ("-:3,1", s0.DebugString());
220     }
221   }
222 
223   // We expect 5 blocks of tensor data
224   {
225     // Block 1: we expect it to be the full slice of the "AA" tensor
226     SavedSlice ss;
227     GetData(table.get(), "AA", TensorSlice(2), &ss);
228     const float data[] = {1.2, 1.3, 1.4, 2.1, 2.2, 2.3};
229     EXPECT_EQ(ArraySize(data), ss.data().float_val_size());
230     ExpectIdenticalFloatArrays(data, ArraySize(data),
231                                ss.data().float_val().data());
232   }
233 
234   {
235     // Block 2: we expect it to be the first slice of the "test" tensor
236     SavedSlice ss;
237     GetData(table.get(), "test", TensorSlice({{0, -1}, {0, 1}}), &ss);
238     const int32 data[] = {0, 1, 2, 3, 4};
239     EXPECT_EQ(ArraySize(data), ss.data().int_val_size());
240     ExpectIdenticalIntArrays(data, ArraySize(data), ss.data().int_val().data());
241   }
242 
243   {
244     // Block 3: we expect it to be the second slice of the "test" tensor
245     SavedSlice ss;
246     GetData(table.get(), "test", TensorSlice({{0, -1}, {3, 1}}), &ss);
247     const int32 data[] = {10, 11, 12, 13, 14};
248     EXPECT_EQ(ArraySize(data), ss.data().int_val_size());
249     ExpectIdenticalIntArrays(data, ArraySize(data), ss.data().int_val().data());
250   }
251 
252   {
253     // Block 4: we expect it to be the slice of the "int64" tensor
254     SavedSlice ss;
255     GetData(table.get(), "int64", TensorSlice({{0, -1}, {3, 1}}), &ss);
256     const int64 data[] = {10, 11, 12, 13, 14};
257     EXPECT_EQ(ArraySize(data), ss.data().int64_val_size());
258     ExpectIdenticalIntArrays(data, ArraySize(data),
259                              ss.data().int64_val().data());
260   }
261 
262   {
263     // Block 5: we expect it to be the slice of the "int16" tensor
264     SavedSlice ss;
265     GetData(table.get(), "int16", TensorSlice({{0, -1}, {3, 1}}), &ss);
266     const int16 data[] = {10, 11, 12, 13, 14};
267     EXPECT_EQ(ArraySize(data), ss.data().int_val_size());
268     ExpectIdenticalIntArrays(data, ArraySize(data), ss.data().int_val().data());
269   }
270 }
271 
272 template <typename DT>
BytesPerElementHelper(DT value)273 size_t BytesPerElementHelper(DT value) {
274   SavedSlice ss;
275   std::array<DT, 1> lo_data;
276   std::fill(lo_data.begin(), lo_data.end(), value);
277   TF_EXPECT_OK(
278       TensorSliceWriter::SaveData(lo_data.data(), lo_data.size(), &ss));
279   size_t lo_byte_size = ss.ByteSizeLong();
280 
281   std::array<DT, 1001> hi_data;
282   std::fill(hi_data.begin(), hi_data.end(), value);
283   TF_EXPECT_OK(
284       TensorSliceWriter::SaveData(hi_data.data(), hi_data.size(), &ss));
285   size_t hi_byte_size = ss.ByteSizeLong();
286 
287   return (hi_byte_size - lo_byte_size) / (hi_data.size() - lo_data.size());
288 }
289 
// Checks TensorSliceWriter::MaxBytesPerElement against the per-element size
// that BytesPerElementHelper measures for each listed dtype. The probe values
// are presumably the worst case for each encoding (-1 for signed types, the
// type's maximum for unsigned ones) -- NOTE(review): confirm against writer.
TEST(TensorSliceWriteTest, CheckpointSize) {
  using Writer = TensorSliceWriter;
  const uint8 kMaxUint8 = std::numeric_limits<uint8>::max();
  const uint16 kMaxUint16 = std::numeric_limits<uint16>::max();

  // Booleans: both values are checked.
  EXPECT_EQ(Writer::MaxBytesPerElement(DT_BOOL),
            BytesPerElementHelper<bool>(false));
  EXPECT_EQ(Writer::MaxBytesPerElement(DT_BOOL),
            BytesPerElementHelper<bool>(true));

  // Floating-point and complex types.
  EXPECT_EQ(Writer::MaxBytesPerElement(DT_FLOAT),
            BytesPerElementHelper<float>(-1.0));
  EXPECT_EQ(Writer::MaxBytesPerElement(DT_DOUBLE),
            BytesPerElementHelper<double>(-1.0));
  EXPECT_EQ(Writer::MaxBytesPerElement(DT_COMPLEX64),
            BytesPerElementHelper<complex64>(-1.0));
  EXPECT_EQ(Writer::MaxBytesPerElement(DT_COMPLEX128),
            BytesPerElementHelper<complex128>(-1.0));

  // Integer types.
  EXPECT_EQ(Writer::MaxBytesPerElement(DT_INT32),
            BytesPerElementHelper<int32>(-1));
  EXPECT_EQ(Writer::MaxBytesPerElement(DT_INT64),
            BytesPerElementHelper<int64>(-1));
  EXPECT_EQ(Writer::MaxBytesPerElement(DT_UINT16),
            BytesPerElementHelper<uint16>(kMaxUint16));
  EXPECT_EQ(Writer::MaxBytesPerElement(DT_UINT8),
            BytesPerElementHelper<uint8>(kMaxUint8));
  EXPECT_EQ(Writer::MaxBytesPerElement(DT_INT8),
            BytesPerElementHelper<int8>(-1));
  EXPECT_EQ(Writer::MaxBytesPerElement(DT_INT16),
            BytesPerElementHelper<int16>(-1));

  // Quantized types.
  EXPECT_EQ(Writer::MaxBytesPerElement(DT_QINT8),
            BytesPerElementHelper<qint8>(-1));
  EXPECT_EQ(Writer::MaxBytesPerElement(DT_QUINT8),
            BytesPerElementHelper<quint8>(kMaxUint8));
  EXPECT_EQ(Writer::MaxBytesPerElement(DT_QINT32),
            BytesPerElementHelper<qint32>(-1));

  // Half precision.
  EXPECT_EQ(Writer::MaxBytesPerElement(DT_HALF),
            BytesPerElementHelper<Eigen::half>(Eigen::half(-1.0)));
}
324 
// Adding a slice whose serialized form would be too large must fail with
// INVALID_ARGUMENT and a "too large to serialize" message.
TEST(TensorSliceWriteTest, SizeErrors) {
  const string filename = io::JoinPath(testing::TmpDir(), "checkpoint");
  TensorSliceWriter writer(filename, CreateTableTensorSliceBuilder);

  const char kTooLargeMessage[] = "Tensor slice is too large to serialize";
  const TensorSlice whole = TensorSlice::ParseOrDie("-:-");

  // Each case lives in its own scope so that only one huge input buffer is
  // alive at a time.

  // A 300MB int8 tensor slice, which fails because it expands to 3GB.
  {
    const std::vector<int8> bytes(300000000, -1);
    const Status status =
        writer.Add("test1", TensorShape({300, 1000000}), whole, bytes.data());
    EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
    EXPECT_TRUE(
        str_util::StrContains(status.error_message(), kTooLargeMessage));
  }

  // A large string tensor slice (256x1024 strings of 8192 bytes), which also
  // fails.
  {
    const std::vector<string> strings(256 * 1024, std::string(8192, 'f'));
    const Status status =
        writer.Add("test2", TensorShape({256, 1024}), whole, strings.data());
    EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
    EXPECT_TRUE(
        str_util::StrContains(status.error_message(), kTooLargeMessage));
  }
}
352 
353 }  // namespace checkpoint
354 
355 }  // namespace tensorflow
356