• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/lib/core/coding.h"
17 #include "tensorflow/core/lib/core/errors.h"
18 #include "tensorflow/core/lib/core/status_test_util.h"
19 #include "tensorflow/core/lib/hash/crc32c.h"
20 #include "tensorflow/core/lib/io/record_reader.h"
21 #include "tensorflow/core/lib/io/record_writer.h"
22 #include "tensorflow/core/lib/random/simple_philox.h"
23 #include "tensorflow/core/lib/strings/str_util.h"
24 #include "tensorflow/core/platform/env.h"
25 #include "tensorflow/core/platform/test.h"
26 
27 namespace tensorflow {
28 namespace io {
29 namespace {
30 
// Builds a string of exactly `n` characters by repeating `partial_string`
// and clipping the final repetition as needed.
//
// If `partial_string` is empty there is nothing to repeat; instead of looping
// forever, the result is `n` NUL ('\0') characters so that the requested
// length is still honored.
string BigString(const string& partial_string, size_t n) {
  string result;
  if (partial_string.empty()) {
    result.resize(n);
    return result;
  }
  result.reserve(n);
  while (result.size() < n) {
    // append(str, pos, len) copies at most `len` characters, so the last
    // repetition is clipped and the result never grows past `n`.
    result.append(partial_string, 0, n - result.size());
  }
  return result;
}
41 
// Formats `n` in decimal followed by a '.', e.g. 7 -> "7.".
string NumberString(int n) {
  char formatted[50];
  const size_t capacity = sizeof(formatted);
  snprintf(formatted, capacity, "%d.", n);
  string text(formatted);
  return text;
}
48 
49 // Return a skewed potentially long string
RandomSkewedString(int i,random::SimplePhilox * rnd)50 string RandomSkewedString(int i, random::SimplePhilox* rnd) {
51   return BigString(NumberString(i), rnd->Skewed(17));
52 }
53 
54 class StringDest : public WritableFile {
55  public:
StringDest(string * contents)56   explicit StringDest(string* contents) : contents_(contents) {}
57 
Close()58   Status Close() override { return Status::OK(); }
Flush()59   Status Flush() override { return Status::OK(); }
Sync()60   Status Sync() override { return Status::OK(); }
Append(StringPiece slice)61   Status Append(StringPiece slice) override {
62     contents_->append(slice.data(), slice.size());
63     return Status::OK();
64   }
Tell(int64 * pos)65   Status Tell(int64* pos) override {
66     *pos = contents_->size();
67     return Status::OK();
68   }
69 
70  private:
71   string* contents_;
72 };
73 
74 class StringSource : public RandomAccessFile {
75  public:
StringSource(string * contents)76   explicit StringSource(string* contents)
77       : contents_(contents), force_error_(false) {}
78 
Read(uint64 offset,size_t n,StringPiece * result,char * scratch) const79   Status Read(uint64 offset, size_t n, StringPiece* result,
80               char* scratch) const override {
81     if (force_error_) {
82       force_error_ = false;
83       return errors::DataLoss("read error");
84     }
85 
86     if (offset >= contents_->size()) {
87       return errors::OutOfRange("end of file");
88     }
89 
90     if (contents_->size() < offset + n) {
91       n = contents_->size() - offset;
92     }
93     *result = StringPiece(contents_->data() + offset, n);
94     return Status::OK();
95   }
96 
force_error()97   void force_error() { force_error_ = true; }
98 
99  private:
100   string* contents_;
101   mutable bool force_error_;
102 };
103 
104 class RecordioTest : public ::testing::Test {
105  private:
106   string contents_;
107   StringDest dest_;
108   StringSource source_;
109   bool reading_;
110   uint64 readpos_;
111   RecordWriter* writer_;
112   RecordReader* reader_;
113 
114  public:
RecordioTest()115   RecordioTest()
116       : dest_(&contents_),
117         source_(&contents_),
118         reading_(false),
119         readpos_(0),
120         writer_(new RecordWriter(&dest_)),
121         reader_(new RecordReader(&source_)) {}
122 
~RecordioTest()123   ~RecordioTest() override {
124     delete writer_;
125     delete reader_;
126   }
127 
Write(const string & msg)128   void Write(const string& msg) {
129     ASSERT_TRUE(!reading_) << "Write() after starting to read";
130     TF_ASSERT_OK(writer_->WriteRecord(StringPiece(msg)));
131   }
132 
WrittenBytes() const133   size_t WrittenBytes() const { return contents_.size(); }
134 
Read()135   string Read() {
136     if (!reading_) {
137       reading_ = true;
138     }
139     string record;
140     Status s = reader_->ReadRecord(&readpos_, &record);
141     if (s.ok()) {
142       return record;
143     } else if (errors::IsOutOfRange(s)) {
144       return "EOF";
145     } else {
146       return s.ToString();
147     }
148   }
149 
IncrementByte(int offset,int delta)150   void IncrementByte(int offset, int delta) { contents_[offset] += delta; }
151 
SetByte(int offset,char new_byte)152   void SetByte(int offset, char new_byte) { contents_[offset] = new_byte; }
153 
ShrinkSize(int bytes)154   void ShrinkSize(int bytes) { contents_.resize(contents_.size() - bytes); }
155 
FixChecksum(int header_offset,int len)156   void FixChecksum(int header_offset, int len) {
157     // Compute crc of type/len/data
158     uint32_t crc = crc32c::Value(&contents_[header_offset + 6], 1 + len);
159     crc = crc32c::Mask(crc);
160     core::EncodeFixed32(&contents_[header_offset], crc);
161   }
162 
ForceError()163   void ForceError() { source_.force_error(); }
164 
StartReadingAt(uint64_t initial_offset)165   void StartReadingAt(uint64_t initial_offset) { readpos_ = initial_offset; }
166 
CheckOffsetPastEndReturnsNoRecords(uint64_t offset_past_end)167   void CheckOffsetPastEndReturnsNoRecords(uint64_t offset_past_end) {
168     Write("foo");
169     Write("bar");
170     Write(BigString("x", 10000));
171     reading_ = true;
172     uint64 offset = WrittenBytes() + offset_past_end;
173     string record;
174     Status s = reader_->ReadRecord(&offset, &record);
175     ASSERT_TRUE(errors::IsOutOfRange(s)) << s;
176   }
177 };
178 
// Reading when nothing has been written reports end-of-file.
TEST_F(RecordioTest, Empty) {
  const string first_read = Read();
  ASSERT_EQ("EOF", first_read);
}
180 
// Records (including an empty one) are read back in the order written, and
// reads keep reporting EOF once the data is exhausted.
TEST_F(RecordioTest, ReadWrite) {
  const string kRecords[] = {"foo", "bar", "", "xxxx"};
  for (const string& outgoing : kRecords) {
    Write(outgoing);
  }
  for (const string& expected : kRecords) {
    ASSERT_EQ(expected, Read());
  }
  ASSERT_EQ("EOF", Read());
  ASSERT_EQ("EOF", Read());  // Make sure reads at eof work
}
193 
// Round-trips a large number of small, distinct records.
TEST_F(RecordioTest, ManyRecords) {
  const int kRecordCount = 100000;
  for (int idx = 0; idx != kRecordCount; ++idx) {
    const string payload = NumberString(idx);
    Write(payload);
  }
  for (int idx = 0; idx != kRecordCount; ++idx) {
    const string expected = NumberString(idx);
    ASSERT_EQ(expected, Read());
  }
  ASSERT_EQ("EOF", Read());
}
203 
// Round-trips records of pseudo-random lengths. The generator is re-created
// with the same arguments for the read phase so that it reproduces the exact
// strings produced during the write phase.
TEST_F(RecordioTest, RandomRead) {
  const int kRecordCount = 500;

  // Write phase.
  {
    random::PhiloxRandom write_philox(301, 17);
    random::SimplePhilox write_rnd(&write_philox);
    for (int idx = 0; idx != kRecordCount; ++idx) {
      const string payload = RandomSkewedString(idx, &write_rnd);
      Write(payload);
    }
  }

  // Read phase.
  {
    random::PhiloxRandom read_philox(301, 17);
    random::SimplePhilox read_rnd(&read_philox);
    for (int idx = 0; idx != kRecordCount; ++idx) {
      const string expected = RandomSkewedString(idx, &read_rnd);
      ASSERT_EQ(expected, Read());
    }
  }

  ASSERT_EQ("EOF", Read());
}
222 
TestNonSequentialReads(const RecordWriterOptions & writer_options,const RecordReaderOptions & reader_options)223 void TestNonSequentialReads(const RecordWriterOptions& writer_options,
224                             const RecordReaderOptions& reader_options) {
225   string contents;
226   StringDest dst(&contents);
227   RecordWriter writer(&dst, writer_options);
228   for (int i = 0; i < 10; ++i) {
229     TF_ASSERT_OK(writer.WriteRecord(NumberString(i))) << i;
230   }
231   TF_ASSERT_OK(writer.Close());
232 
233   StringSource file(&contents);
234   RecordReader reader(&file, reader_options);
235 
236   string record;
237   // First read sequentially to fill in the offsets table.
238   uint64 offsets[10] = {0};
239   uint64 offset = 0;
240   for (int i = 0; i < 10; ++i) {
241     offsets[i] = offset;
242     TF_ASSERT_OK(reader.ReadRecord(&offset, &record)) << i;
243   }
244 
245   // Read randomly: First go back to record #3 then forward to #8.
246   offset = offsets[3];
247   TF_ASSERT_OK(reader.ReadRecord(&offset, &record));
248   EXPECT_EQ("3.", record);
249   EXPECT_EQ(offsets[4], offset);
250 
251   offset = offsets[8];
252   TF_ASSERT_OK(reader.ReadRecord(&offset, &record));
253   EXPECT_EQ("8.", record);
254   EXPECT_EQ(offsets[9], offset);
255 }
256 
// Non-sequential reads with default writer and reader options.
TEST_F(RecordioTest, NonSequentialReads) {
  const RecordWriterOptions default_writer_options = RecordWriterOptions();
  const RecordReaderOptions default_reader_options = RecordReaderOptions();
  TestNonSequentialReads(default_writer_options, default_reader_options);
}
260 
// Non-sequential reads with the reader's buffer_size set to 1 << 10.
TEST_F(RecordioTest, NonSequentialReadsWithReadBuffer) {
  RecordReaderOptions buffered_reader_options;
  buffered_reader_options.buffer_size = 1 << 10;
  const RecordWriterOptions default_writer_options = RecordWriterOptions();
  TestNonSequentialReads(default_writer_options, buffered_reader_options);
}
266 
// Non-sequential reads with writer and reader options both created for
// "ZLIB".
TEST_F(RecordioTest, NonSequentialReadsWithCompression) {
  const RecordWriterOptions zlib_writer_options =
      RecordWriterOptions::CreateRecordWriterOptions("ZLIB");
  const RecordReaderOptions zlib_reader_options =
      RecordReaderOptions::CreateRecordReaderOptions("ZLIB");
  TestNonSequentialReads(zlib_writer_options, zlib_reader_options);
}
272 
273 // Tests of the error paths in record_reader.cc follow:
AssertHasSubstr(StringPiece s,StringPiece expected)274 void AssertHasSubstr(StringPiece s, StringPiece expected) {
275   EXPECT_TRUE(str_util::StrContains(s, expected))
276       << s << " does not contain " << expected;
277 }
278 
TestReadError(const RecordWriterOptions & writer_options,const RecordReaderOptions & reader_options)279 void TestReadError(const RecordWriterOptions& writer_options,
280                    const RecordReaderOptions& reader_options) {
281   const string wrote = BigString("well hello there!", 100);
282   string contents;
283   StringDest dst(&contents);
284   TF_ASSERT_OK(RecordWriter(&dst, writer_options).WriteRecord(wrote));
285 
286   StringSource file(&contents);
287   RecordReader reader(&file, reader_options);
288 
289   uint64 offset = 0;
290   string read;
291   file.force_error();
292   Status status = reader.ReadRecord(&offset, &read);
293   ASSERT_TRUE(errors::IsDataLoss(status));
294   ASSERT_EQ(0, offset);
295 
296   // A failed Read() shouldn't update the offset, and thus a retry shouldn't
297   // lose the record.
298   status = reader.ReadRecord(&offset, &read);
299   ASSERT_TRUE(status.ok()) << status;
300   EXPECT_GT(offset, 0);
301   EXPECT_EQ(wrote, read);
302 }
303 
// Read-error recovery with default writer and reader options.
TEST_F(RecordioTest, ReadError) {
  const RecordWriterOptions default_writer_options = RecordWriterOptions();
  const RecordReaderOptions default_reader_options = RecordReaderOptions();
  TestReadError(default_writer_options, default_reader_options);
}
307 
// Read-error recovery with the reader's buffer_size set to 1 << 20.
TEST_F(RecordioTest, ReadErrorWithBuffering) {
  RecordReaderOptions buffered_reader_options;
  buffered_reader_options.buffer_size = 1 << 20;
  const RecordWriterOptions default_writer_options = RecordWriterOptions();
  TestReadError(default_writer_options, buffered_reader_options);
}
313 
// Read-error recovery with writer and reader options both created for "ZLIB".
TEST_F(RecordioTest, ReadErrorWithCompression) {
  const RecordWriterOptions zlib_writer_options =
      RecordWriterOptions::CreateRecordWriterOptions("ZLIB");
  const RecordReaderOptions zlib_reader_options =
      RecordReaderOptions::CreateRecordReaderOptions("ZLIB");
  TestReadError(zlib_writer_options, zlib_reader_options);
}
318 
// Altering stored byte 6 of a written record must make the read report
// "Data loss".
// NOTE(review): per the test name, byte 6 presumably falls inside the
// record's length field - confirm against the writer's on-disk layout.
TEST_F(RecordioTest, CorruptLength) {
  const int kCorruptOffset = 6;
  const int kDelta = 100;
  Write("foo");
  IncrementByte(kCorruptOffset, kDelta);
  const string outcome = Read();
  AssertHasSubstr(outcome, "Data loss");
}
324 
// Altering stored byte 10 of a written record must make the read report
// "Data loss".
// NOTE(review): per the test name, byte 10 presumably falls inside the
// checksum of the length field - confirm against the writer's layout.
TEST_F(RecordioTest, CorruptLengthCrc) {
  const int kCorruptOffset = 10;
  const int kDelta = 100;
  Write("foo");
  IncrementByte(kCorruptOffset, kDelta);
  const string outcome = Read();
  AssertHasSubstr(outcome, "Data loss");
}
330 
// Altering stored byte 14 of a written record must make the read report
// "Data loss".
// NOTE(review): per the test name, byte 14 presumably falls inside the
// record payload - confirm against the writer's layout.
TEST_F(RecordioTest, CorruptData) {
  const int kCorruptOffset = 14;
  const int kDelta = 10;
  Write("foo");
  IncrementByte(kCorruptOffset, kDelta);
  const string outcome = Read();
  AssertHasSubstr(outcome, "Data loss");
}
336 
// Altering the very last stored byte of a written record must make the read
// report "Data loss".
// NOTE(review): per the test name, the last byte presumably belongs to the
// checksum of the payload - confirm against the writer's layout.
TEST_F(RecordioTest, CorruptDataCrc) {
  const int kDelta = 10;
  Write("foo");
  const size_t last_byte = WrittenBytes() - 1;
  IncrementByte(last_byte, kDelta);
  const string outcome = Read();
  AssertHasSubstr(outcome, "Data loss");
}
342 
// Reading at exactly the end of the written data yields no record.
TEST_F(RecordioTest, ReadEnd) {
  const uint64_t kBytesPastEnd = 0;
  CheckOffsetPastEndReturnsNoRecords(kBytesPastEnd);
}
344 
// Reading five bytes beyond the end of the written data yields no record.
TEST_F(RecordioTest, ReadPastEnd) {
  const uint64_t kBytesPastEnd = 5;
  CheckOffsetPastEndReturnsNoRecords(kBytesPastEnd);
}
346 
347 }  // namespace
348 }  // namespace io
349 }  // namespace tensorflow
350