1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <memory>

#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/hash/crc32c.h"
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
26
27 namespace tensorflow {
28 namespace io {
29 namespace {
30
// Builds a string of exactly `n` characters by repeating `partial_string` and
// truncating the final repetition.
//
// If `partial_string` is empty there is nothing to repeat; the result is then
// `n` NUL characters (what resize() pads with) rather than looping forever.
string BigString(const string& partial_string, size_t n) {
  if (partial_string.empty()) {
    return string(n, '\0');
  }
  string result;
  result.reserve(n);
  while (result.size() < n) {
    // append(str, pos, count) clamps `count` to the length of `str`, so this
    // adds either a whole repetition or just the remaining tail.
    result.append(partial_string, 0, n - result.size());
  }
  return result;
}
41
// Formats `n` in decimal followed by a trailing '.', e.g. 7 -> "7.".
string NumberString(int n) {
  char digits[50];
  // snprintf returns the number of characters written (excluding the NUL),
  // which always fits in the buffer for an int.
  const int length = snprintf(digits, sizeof(digits), "%d.", n);
  return string(digits, static_cast<size_t>(length));
}
48
49 // Return a skewed potentially long string
RandomSkewedString(int i,random::SimplePhilox * rnd)50 string RandomSkewedString(int i, random::SimplePhilox* rnd) {
51 return BigString(NumberString(i), rnd->Skewed(17));
52 }
53
54 class StringDest : public WritableFile {
55 public:
StringDest(string * contents)56 explicit StringDest(string* contents) : contents_(contents) {}
57
Close()58 Status Close() override { return Status::OK(); }
Flush()59 Status Flush() override { return Status::OK(); }
Sync()60 Status Sync() override { return Status::OK(); }
Append(StringPiece slice)61 Status Append(StringPiece slice) override {
62 contents_->append(slice.data(), slice.size());
63 return Status::OK();
64 }
65 #if defined(PLATFORM_GOOGLE)
Append(const absl::Cord & data)66 Status Append(const absl::Cord& data) override {
67 contents_->append(data.ToString());
68 return Status::OK();
69 }
70 #endif
Tell(int64 * pos)71 Status Tell(int64* pos) override {
72 *pos = contents_->size();
73 return Status::OK();
74 }
75
76 private:
77 string* contents_;
78 };
79
80 class StringSource : public RandomAccessFile {
81 public:
StringSource(string * contents)82 explicit StringSource(string* contents)
83 : contents_(contents), force_error_(false) {}
84
Read(uint64 offset,size_t n,StringPiece * result,char * scratch) const85 Status Read(uint64 offset, size_t n, StringPiece* result,
86 char* scratch) const override {
87 if (force_error_) {
88 force_error_ = false;
89 return errors::DataLoss("read error");
90 }
91
92 if (offset >= contents_->size()) {
93 return errors::OutOfRange("end of file");
94 }
95
96 if (contents_->size() < offset + n) {
97 n = contents_->size() - offset;
98 }
99 *result = StringPiece(contents_->data() + offset, n);
100 return Status::OK();
101 }
102
force_error()103 void force_error() { force_error_ = true; }
104
105 private:
106 string* contents_;
107 mutable bool force_error_;
108 };
109
110 class RecordioTest : public ::testing::Test {
111 private:
112 string contents_;
113 StringDest dest_;
114 StringSource source_;
115 bool reading_;
116 uint64 readpos_;
117 RecordWriter* writer_;
118 RecordReader* reader_;
119
120 public:
RecordioTest()121 RecordioTest()
122 : dest_(&contents_),
123 source_(&contents_),
124 reading_(false),
125 readpos_(0),
126 writer_(new RecordWriter(&dest_)),
127 reader_(new RecordReader(&source_)) {}
128
~RecordioTest()129 ~RecordioTest() override {
130 delete writer_;
131 delete reader_;
132 }
133
Write(const string & msg)134 void Write(const string& msg) {
135 ASSERT_TRUE(!reading_) << "Write() after starting to read";
136 TF_ASSERT_OK(writer_->WriteRecord(StringPiece(msg)));
137 }
138
139 #if defined(PLATFORM_GOOGLE)
Write(const absl::Cord & msg)140 void Write(const absl::Cord& msg) {
141 ASSERT_TRUE(!reading_) << "Write() after starting to read";
142 TF_ASSERT_OK(writer_->WriteRecord(msg));
143 }
144 #endif
145
WrittenBytes() const146 size_t WrittenBytes() const { return contents_.size(); }
147
Read()148 string Read() {
149 if (!reading_) {
150 reading_ = true;
151 }
152 tstring record;
153 Status s = reader_->ReadRecord(&readpos_, &record);
154 if (s.ok()) {
155 return record;
156 } else if (errors::IsOutOfRange(s)) {
157 return "EOF";
158 } else {
159 return s.ToString();
160 }
161 }
162
IncrementByte(int offset,int delta)163 void IncrementByte(int offset, int delta) { contents_[offset] += delta; }
164
SetByte(int offset,char new_byte)165 void SetByte(int offset, char new_byte) { contents_[offset] = new_byte; }
166
ShrinkSize(int bytes)167 void ShrinkSize(int bytes) { contents_.resize(contents_.size() - bytes); }
168
FixChecksum(int header_offset,int len)169 void FixChecksum(int header_offset, int len) {
170 // Compute crc of type/len/data
171 uint32_t crc = crc32c::Value(&contents_[header_offset + 6], 1 + len);
172 crc = crc32c::Mask(crc);
173 core::EncodeFixed32(&contents_[header_offset], crc);
174 }
175
ForceError()176 void ForceError() { source_.force_error(); }
177
StartReadingAt(uint64_t initial_offset)178 void StartReadingAt(uint64_t initial_offset) { readpos_ = initial_offset; }
179
CheckOffsetPastEndReturnsNoRecords(uint64_t offset_past_end)180 void CheckOffsetPastEndReturnsNoRecords(uint64_t offset_past_end) {
181 Write("foo");
182 Write("bar");
183 Write(BigString("x", 10000));
184 reading_ = true;
185 uint64 offset = WrittenBytes() + offset_past_end;
186 tstring record;
187 Status s = reader_->ReadRecord(&offset, &record);
188 ASSERT_TRUE(errors::IsOutOfRange(s)) << s;
189 }
190 };
191
// Reading when nothing has been written reports end-of-file.
TEST_F(RecordioTest, Empty) {
  ASSERT_EQ("EOF", Read());
}
193
// Records (including an empty one) come back in the order they were written,
// followed by end-of-file.
TEST_F(RecordioTest, ReadWrite) {
  const string records[] = {"foo", "bar", "", "xxxx"};
  for (const string& record : records) {
    Write(record);
  }
  for (const string& record : records) {
    ASSERT_EQ(record, Read());
  }
  ASSERT_EQ("EOF", Read());
  ASSERT_EQ("EOF", Read());  // Make sure reads at eof work
}
206
#if defined(PLATFORM_GOOGLE)
// Same as ReadWrite, but the records are written through the Cord overload.
TEST_F(RecordioTest, ReadWriteCords) {
  const string records[] = {"foo", "bar", "", "xxxx"};
  for (const string& record : records) {
    Write(absl::Cord(record));
  }
  for (const string& record : records) {
    ASSERT_EQ(record, Read());
  }
  ASSERT_EQ("EOF", Read());
  ASSERT_EQ("EOF", Read());  // Make sure reads at eof work
}
#endif
221
// Round-trips a large number of small, distinct records.
TEST_F(RecordioTest, ManyRecords) {
  const int kNumRecords = 100000;
  for (int index = 0; index < kNumRecords; ++index) {
    Write(NumberString(index));
  }
  for (int index = 0; index < kNumRecords; ++index) {
    ASSERT_EQ(NumberString(index), Read());
  }
  ASSERT_EQ("EOF", Read());
}
231
// Round-trips records of random lengths. The generator is re-created with the
// same seeds for the read pass so it replays the lengths used when writing.
TEST_F(RecordioTest, RandomRead) {
  const int kNumRecords = 500;

  // Write pass.
  {
    random::PhiloxRandom write_philox(301, 17);
    random::SimplePhilox write_rnd(&write_philox);
    for (int index = 0; index < kNumRecords; ++index) {
      Write(RandomSkewedString(index, &write_rnd));
    }
  }

  // Read pass: regenerate the expected strings from identical seeds.
  {
    random::PhiloxRandom read_philox(301, 17);
    random::SimplePhilox read_rnd(&read_philox);
    for (int index = 0; index < kNumRecords; ++index) {
      ASSERT_EQ(RandomSkewedString(index, &read_rnd), Read());
    }
  }

  ASSERT_EQ("EOF", Read());
}
250
TestNonSequentialReads(const RecordWriterOptions & writer_options,const RecordReaderOptions & reader_options)251 void TestNonSequentialReads(const RecordWriterOptions& writer_options,
252 const RecordReaderOptions& reader_options) {
253 string contents;
254 StringDest dst(&contents);
255 RecordWriter writer(&dst, writer_options);
256 for (int i = 0; i < 10; ++i) {
257 TF_ASSERT_OK(writer.WriteRecord(NumberString(i))) << i;
258 }
259 TF_ASSERT_OK(writer.Close());
260
261 StringSource file(&contents);
262 RecordReader reader(&file, reader_options);
263
264 tstring record;
265 // First read sequentially to fill in the offsets table.
266 uint64 offsets[10] = {0};
267 uint64 offset = 0;
268 for (int i = 0; i < 10; ++i) {
269 offsets[i] = offset;
270 TF_ASSERT_OK(reader.ReadRecord(&offset, &record)) << i;
271 }
272
273 // Read randomly: First go back to record #3 then forward to #8.
274 offset = offsets[3];
275 TF_ASSERT_OK(reader.ReadRecord(&offset, &record));
276 EXPECT_EQ("3.", record);
277 EXPECT_EQ(offsets[4], offset);
278
279 offset = offsets[8];
280 TF_ASSERT_OK(reader.ReadRecord(&offset, &record));
281 EXPECT_EQ("8.", record);
282 EXPECT_EQ(offsets[9], offset);
283 }
284
// Non-sequential reads with default writer and reader options.
TEST_F(RecordioTest, NonSequentialReads) {
  RecordWriterOptions writer_options;
  RecordReaderOptions reader_options;
  TestNonSequentialReads(writer_options, reader_options);
}
288
// Non-sequential reads with the reader's buffer_size set to 1 KiB.
TEST_F(RecordioTest, NonSequentialReadsWithReadBuffer) {
  RecordWriterOptions writer_options;
  RecordReaderOptions reader_options;
  reader_options.buffer_size = 1 << 10;
  TestNonSequentialReads(writer_options, reader_options);
}
294
// Non-sequential reads with "ZLIB" options on both the writer and the reader.
TEST_F(RecordioTest, NonSequentialReadsWithCompression) {
  const RecordWriterOptions writer_options =
      RecordWriterOptions::CreateRecordWriterOptions("ZLIB");
  const RecordReaderOptions reader_options =
      RecordReaderOptions::CreateRecordReaderOptions("ZLIB");
  TestNonSequentialReads(writer_options, reader_options);
}
300
// Tests of the RecordReader error paths follow:
AssertHasSubstr(StringPiece s,StringPiece expected)302 void AssertHasSubstr(StringPiece s, StringPiece expected) {
303 EXPECT_TRUE(absl::StrContains(s, expected))
304 << s << " does not contain " << expected;
305 }
306
TestReadError(const RecordWriterOptions & writer_options,const RecordReaderOptions & reader_options)307 void TestReadError(const RecordWriterOptions& writer_options,
308 const RecordReaderOptions& reader_options) {
309 const string wrote = BigString("well hello there!", 100);
310 string contents;
311 StringDest dst(&contents);
312 TF_ASSERT_OK(RecordWriter(&dst, writer_options).WriteRecord(wrote));
313
314 StringSource file(&contents);
315 RecordReader reader(&file, reader_options);
316
317 uint64 offset = 0;
318 tstring read;
319 file.force_error();
320 Status status = reader.ReadRecord(&offset, &read);
321 ASSERT_TRUE(errors::IsDataLoss(status));
322 ASSERT_EQ(0, offset);
323
324 // A failed Read() shouldn't update the offset, and thus a retry shouldn't
325 // lose the record.
326 status = reader.ReadRecord(&offset, &read);
327 ASSERT_TRUE(status.ok()) << status;
328 EXPECT_GT(offset, 0);
329 EXPECT_EQ(wrote, read);
330 }
331
// Read-error recovery with default writer and reader options.
TEST_F(RecordioTest, ReadError) {
  RecordWriterOptions writer_options;
  RecordReaderOptions reader_options;
  TestReadError(writer_options, reader_options);
}
335
// Read-error recovery with the reader's buffer_size set to 1 MiB.
TEST_F(RecordioTest, ReadErrorWithBuffering) {
  RecordWriterOptions writer_options;
  RecordReaderOptions reader_options;
  reader_options.buffer_size = 1 << 20;
  TestReadError(writer_options, reader_options);
}
341
// Read-error recovery with "ZLIB" options on both the writer and the reader.
TEST_F(RecordioTest, ReadErrorWithCompression) {
  const RecordWriterOptions writer_options =
      RecordWriterOptions::CreateRecordWriterOptions("ZLIB");
  const RecordReaderOptions reader_options =
      RecordReaderOptions::CreateRecordReaderOptions("ZLIB");
  TestReadError(writer_options, reader_options);
}
346
// Altering byte 6 of a written record (part of the length field, going by the
// test name) must surface as a data-loss error on read.
TEST_F(RecordioTest, CorruptLength) {
  Write("foo");
  const int kCorruptOffset = 6;
  IncrementByte(kCorruptOffset, 100);
  const string outcome = Read();
  AssertHasSubstr(outcome, "Data loss");
}
352
// Altering byte 10 of a written record (part of the length checksum, going by
// the test name) must surface as a data-loss error on read.
TEST_F(RecordioTest, CorruptLengthCrc) {
  Write("foo");
  const int kCorruptOffset = 10;
  IncrementByte(kCorruptOffset, 100);
  const string outcome = Read();
  AssertHasSubstr(outcome, "Data loss");
}
358
// Altering byte 14 of a written record (part of the payload, going by the
// test name) must surface as a data-loss error on read.
TEST_F(RecordioTest, CorruptData) {
  Write("foo");
  const int kCorruptOffset = 14;
  IncrementByte(kCorruptOffset, 10);
  const string outcome = Read();
  AssertHasSubstr(outcome, "Data loss");
}
364
// Altering the very last byte written (part of the payload checksum, going by
// the test name) must surface as a data-loss error on read.
TEST_F(RecordioTest, CorruptDataCrc) {
  Write("foo");
  const size_t last_byte = WrittenBytes() - 1;
  IncrementByte(last_byte, 10);
  const string outcome = Read();
  AssertHasSubstr(outcome, "Data loss");
}
370
// Reading at exactly the end of the written data yields no record.
TEST_F(RecordioTest, ReadEnd) {
  CheckOffsetPastEndReturnsNoRecords(0);
}
372
// Reading five bytes beyond the end of the written data yields no record.
TEST_F(RecordioTest, ReadPastEnd) {
  CheckOffsetPastEndReturnsNoRecords(5);
}
374
375 } // namespace
376 } // namespace io
377 } // namespace tensorflow
378