1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <memory>

#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/hash/crc32c.h"
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
26
27 namespace tensorflow {
28 namespace io {
29 namespace {
30
// Returns a string of exactly `n` bytes made by repeating `partial_string`
// and truncating the final repetition. An empty `partial_string` cannot be
// repeated up to length `n`, so in that case the result is `n` NUL bytes.
string BigString(const string& partial_string, size_t n) {
  string result;
  if (partial_string.empty()) {
    // Appending "" would never grow the string; avoid an infinite loop.
    result.resize(n);
    return result;
  }
  // At most one extra repetition is appended before truncation.
  result.reserve(n + partial_string.size());
  while (result.size() < n) {
    result.append(partial_string);
  }
  result.resize(n);
  return result;
}
41
// Formats `n` in decimal followed by a trailing '.', e.g. 7 -> "7.".
string NumberString(int n) {
  char digits[50];
  // snprintf reports how many characters it wrote (excluding the NUL);
  // any int plus '.' fits comfortably in the buffer.
  const int length = snprintf(digits, sizeof(digits), "%d.", n);
  return string(digits, length);
}
48
49 // Return a skewed potentially long string
RandomSkewedString(int i,random::SimplePhilox * rnd)50 string RandomSkewedString(int i, random::SimplePhilox* rnd) {
51 return BigString(NumberString(i), rnd->Skewed(17));
52 }
53
54 class StringDest : public WritableFile {
55 public:
StringDest(string * contents)56 explicit StringDest(string* contents) : contents_(contents) {}
57
Close()58 Status Close() override { return Status::OK(); }
Flush()59 Status Flush() override { return Status::OK(); }
Sync()60 Status Sync() override { return Status::OK(); }
Append(StringPiece slice)61 Status Append(StringPiece slice) override {
62 contents_->append(slice.data(), slice.size());
63 return Status::OK();
64 }
Tell(int64 * pos)65 Status Tell(int64* pos) override {
66 *pos = contents_->size();
67 return Status::OK();
68 }
69
70 private:
71 string* contents_;
72 };
73
74 class StringSource : public RandomAccessFile {
75 public:
StringSource(string * contents)76 explicit StringSource(string* contents)
77 : contents_(contents), force_error_(false) {}
78
Read(uint64 offset,size_t n,StringPiece * result,char * scratch) const79 Status Read(uint64 offset, size_t n, StringPiece* result,
80 char* scratch) const override {
81 if (force_error_) {
82 force_error_ = false;
83 return errors::DataLoss("read error");
84 }
85
86 if (offset >= contents_->size()) {
87 return errors::OutOfRange("end of file");
88 }
89
90 if (contents_->size() < offset + n) {
91 n = contents_->size() - offset;
92 }
93 *result = StringPiece(contents_->data() + offset, n);
94 return Status::OK();
95 }
96
force_error()97 void force_error() { force_error_ = true; }
98
99 private:
100 string* contents_;
101 mutable bool force_error_;
102 };
103
104 class RecordioTest : public ::testing::Test {
105 private:
106 string contents_;
107 StringDest dest_;
108 StringSource source_;
109 bool reading_;
110 uint64 readpos_;
111 RecordWriter* writer_;
112 RecordReader* reader_;
113
114 public:
RecordioTest()115 RecordioTest()
116 : dest_(&contents_),
117 source_(&contents_),
118 reading_(false),
119 readpos_(0),
120 writer_(new RecordWriter(&dest_)),
121 reader_(new RecordReader(&source_)) {}
122
~RecordioTest()123 ~RecordioTest() override {
124 delete writer_;
125 delete reader_;
126 }
127
Write(const string & msg)128 void Write(const string& msg) {
129 ASSERT_TRUE(!reading_) << "Write() after starting to read";
130 TF_ASSERT_OK(writer_->WriteRecord(StringPiece(msg)));
131 }
132
WrittenBytes() const133 size_t WrittenBytes() const { return contents_.size(); }
134
Read()135 string Read() {
136 if (!reading_) {
137 reading_ = true;
138 }
139 string record;
140 Status s = reader_->ReadRecord(&readpos_, &record);
141 if (s.ok()) {
142 return record;
143 } else if (errors::IsOutOfRange(s)) {
144 return "EOF";
145 } else {
146 return s.ToString();
147 }
148 }
149
IncrementByte(int offset,int delta)150 void IncrementByte(int offset, int delta) { contents_[offset] += delta; }
151
SetByte(int offset,char new_byte)152 void SetByte(int offset, char new_byte) { contents_[offset] = new_byte; }
153
ShrinkSize(int bytes)154 void ShrinkSize(int bytes) { contents_.resize(contents_.size() - bytes); }
155
FixChecksum(int header_offset,int len)156 void FixChecksum(int header_offset, int len) {
157 // Compute crc of type/len/data
158 uint32_t crc = crc32c::Value(&contents_[header_offset + 6], 1 + len);
159 crc = crc32c::Mask(crc);
160 core::EncodeFixed32(&contents_[header_offset], crc);
161 }
162
ForceError()163 void ForceError() { source_.force_error(); }
164
StartReadingAt(uint64_t initial_offset)165 void StartReadingAt(uint64_t initial_offset) { readpos_ = initial_offset; }
166
CheckOffsetPastEndReturnsNoRecords(uint64_t offset_past_end)167 void CheckOffsetPastEndReturnsNoRecords(uint64_t offset_past_end) {
168 Write("foo");
169 Write("bar");
170 Write(BigString("x", 10000));
171 reading_ = true;
172 uint64 offset = WrittenBytes() + offset_past_end;
173 string record;
174 Status s = reader_->ReadRecord(&offset, &record);
175 ASSERT_TRUE(errors::IsOutOfRange(s)) << s;
176 }
177 };
178
TEST_F(RecordioTest, Empty) {
  // With nothing written, the very first read must report end-of-file.
  const string first = Read();
  ASSERT_EQ("EOF", first);
}
180
TEST_F(RecordioTest, ReadWrite) {
  // Includes an empty record so a zero-length payload is round-tripped too.
  const char* const kRecords[] = {"foo", "bar", "", "xxxx"};
  for (const char* rec : kRecords) {
    Write(rec);
  }
  for (const char* rec : kRecords) {
    ASSERT_EQ(rec, Read());
  }
  // Reading again at EOF must keep reporting EOF.
  for (int attempt = 0; attempt < 2; ++attempt) {
    ASSERT_EQ("EOF", Read());
  }
}
193
TEST_F(RecordioTest, ManyRecords) {
  const int kNumRecords = 100000;
  for (int rec = 0; rec < kNumRecords; ++rec) {
    Write(NumberString(rec));
  }
  for (int rec = 0; rec < kNumRecords; ++rec) {
    ASSERT_EQ(NumberString(rec), Read());
  }
  ASSERT_EQ("EOF", Read());
}
203
TEST_F(RecordioTest, RandomRead) {
  const int kNumRecords = 500;

  // The write and read passes use identically seeded generators, so they
  // produce the same sequence of payloads.
  random::PhiloxRandom write_philox(301, 17);
  random::SimplePhilox write_rnd(&write_philox);
  for (int rec = 0; rec < kNumRecords; ++rec) {
    Write(RandomSkewedString(rec, &write_rnd));
  }

  random::PhiloxRandom read_philox(301, 17);
  random::SimplePhilox read_rnd(&read_philox);
  for (int rec = 0; rec < kNumRecords; ++rec) {
    ASSERT_EQ(RandomSkewedString(rec, &read_rnd), Read());
  }

  ASSERT_EQ("EOF", Read());
}
222
TestNonSequentialReads(const RecordWriterOptions & writer_options,const RecordReaderOptions & reader_options)223 void TestNonSequentialReads(const RecordWriterOptions& writer_options,
224 const RecordReaderOptions& reader_options) {
225 string contents;
226 StringDest dst(&contents);
227 RecordWriter writer(&dst, writer_options);
228 for (int i = 0; i < 10; ++i) {
229 TF_ASSERT_OK(writer.WriteRecord(NumberString(i))) << i;
230 }
231 TF_ASSERT_OK(writer.Close());
232
233 StringSource file(&contents);
234 RecordReader reader(&file, reader_options);
235
236 string record;
237 // First read sequentially to fill in the offsets table.
238 uint64 offsets[10] = {0};
239 uint64 offset = 0;
240 for (int i = 0; i < 10; ++i) {
241 offsets[i] = offset;
242 TF_ASSERT_OK(reader.ReadRecord(&offset, &record)) << i;
243 }
244
245 // Read randomly: First go back to record #3 then forward to #8.
246 offset = offsets[3];
247 TF_ASSERT_OK(reader.ReadRecord(&offset, &record));
248 EXPECT_EQ("3.", record);
249 EXPECT_EQ(offsets[4], offset);
250
251 offset = offsets[8];
252 TF_ASSERT_OK(reader.ReadRecord(&offset, &record));
253 EXPECT_EQ("8.", record);
254 EXPECT_EQ(offsets[9], offset);
255 }
256
TEST_F(RecordioTest, NonSequentialReads) {
  // Default-constructed writer and reader options.
  const auto writer_options = RecordWriterOptions();
  const auto reader_options = RecordReaderOptions();
  TestNonSequentialReads(writer_options, reader_options);
}
260
TEST_F(RecordioTest, NonSequentialReadsWithReadBuffer) {
  // Same scenario as NonSequentialReads, with a 1 KiB reader buffer.
  RecordReaderOptions buffered_options;
  buffered_options.buffer_size = 1024;
  TestNonSequentialReads(RecordWriterOptions(), buffered_options);
}
266
TEST_F(RecordioTest, NonSequentialReadsWithCompression) {
  // Both the writer and the reader are configured for ZLIB.
  const RecordWriterOptions zlib_writer =
      RecordWriterOptions::CreateRecordWriterOptions("ZLIB");
  const RecordReaderOptions zlib_reader =
      RecordReaderOptions::CreateRecordReaderOptions("ZLIB");
  TestNonSequentialReads(zlib_writer, zlib_reader);
}
272
// Tests of RecordReader's error paths follow:
AssertHasSubstr(StringPiece s,StringPiece expected)274 void AssertHasSubstr(StringPiece s, StringPiece expected) {
275 EXPECT_TRUE(str_util::StrContains(s, expected))
276 << s << " does not contain " << expected;
277 }
278
TestReadError(const RecordWriterOptions & writer_options,const RecordReaderOptions & reader_options)279 void TestReadError(const RecordWriterOptions& writer_options,
280 const RecordReaderOptions& reader_options) {
281 const string wrote = BigString("well hello there!", 100);
282 string contents;
283 StringDest dst(&contents);
284 TF_ASSERT_OK(RecordWriter(&dst, writer_options).WriteRecord(wrote));
285
286 StringSource file(&contents);
287 RecordReader reader(&file, reader_options);
288
289 uint64 offset = 0;
290 string read;
291 file.force_error();
292 Status status = reader.ReadRecord(&offset, &read);
293 ASSERT_TRUE(errors::IsDataLoss(status));
294 ASSERT_EQ(0, offset);
295
296 // A failed Read() shouldn't update the offset, and thus a retry shouldn't
297 // lose the record.
298 status = reader.ReadRecord(&offset, &read);
299 ASSERT_TRUE(status.ok()) << status;
300 EXPECT_GT(offset, 0);
301 EXPECT_EQ(wrote, read);
302 }
303
TEST_F(RecordioTest, ReadError) {
  // Default-constructed writer and reader options.
  const auto writer_options = RecordWriterOptions();
  const auto reader_options = RecordReaderOptions();
  TestReadError(writer_options, reader_options);
}
307
TEST_F(RecordioTest, ReadErrorWithBuffering) {
  // Same scenario as ReadError, with a 1 MiB reader buffer.
  RecordReaderOptions buffered_options;
  buffered_options.buffer_size = 1024 * 1024;
  TestReadError(RecordWriterOptions(), buffered_options);
}
313
TEST_F(RecordioTest, ReadErrorWithCompression) {
  // Both the writer and the reader are configured for ZLIB.
  const RecordWriterOptions zlib_writer =
      RecordWriterOptions::CreateRecordWriterOptions("ZLIB");
  const RecordReaderOptions zlib_reader =
      RecordReaderOptions::CreateRecordReaderOptions("ZLIB");
  TestReadError(zlib_writer, zlib_reader);
}
318
TEST_F(RecordioTest, CorruptLength) {
  Write("foo");
  // Damage a byte of the record's length field.
  const int kCorruptedByte = 6;
  IncrementByte(kCorruptedByte, 100);
  const string result = Read();
  AssertHasSubstr(result, "Data loss");
}
324
TEST_F(RecordioTest, CorruptLengthCrc) {
  Write("foo");
  // Damage a byte of the checksum that protects the length field.
  const int kCorruptedByte = 10;
  IncrementByte(kCorruptedByte, 100);
  const string result = Read();
  AssertHasSubstr(result, "Data loss");
}
330
TEST_F(RecordioTest, CorruptData) {
  Write("foo");
  // Damage a byte of the record payload itself.
  const int kCorruptedByte = 14;
  IncrementByte(kCorruptedByte, 10);
  const string result = Read();
  AssertHasSubstr(result, "Data loss");
}
336
TEST_F(RecordioTest, CorruptDataCrc) {
  Write("foo");
  // Damage the final byte written, which belongs to the payload checksum.
  const int last_byte = static_cast<int>(WrittenBytes() - 1);
  IncrementByte(last_byte, 10);
  const string result = Read();
  AssertHasSubstr(result, "Data loss");
}
342
TEST_F(RecordioTest, ReadEnd) {
  // Starting a read exactly at the end of the data yields no record.
  const uint64_t kBytesPastEnd = 0;
  CheckOffsetPastEndReturnsNoRecords(kBytesPastEnd);
}
344
TEST_F(RecordioTest, ReadPastEnd) {
  // Starting a read beyond the end of the data yields no record.
  const uint64_t kBytesPastEnd = 5;
  CheckOffsetPastEndReturnsNoRecords(kBytesPastEnd);
}
346
347 } // namespace
348 } // namespace io
349 } // namespace tensorflow
350