• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <memory>

#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/hash/crc32c.h"
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
26 
27 namespace tensorflow {
28 namespace io {
29 namespace {
30 
// Builds a string of exactly `n` bytes by repeating `partial_string` and
// truncating the last repetition.
//
// An empty `partial_string` offers nothing to repeat, so the repetition loop
// is skipped (it would otherwise never terminate) and the final resize() pads
// the result to `n` bytes with '\0'.
string BigString(const string& partial_string, size_t n) {
  string result;
  if (!partial_string.empty()) {
    // The loop can overshoot `n` by up to one repetition before truncation.
    result.reserve(n + partial_string.size());
    while (result.size() < n) {
      result.append(partial_string);
    }
  }
  result.resize(n);
  return result;
}
41 
// Formats `n` in decimal followed by a trailing '.', e.g. 7 -> "7.".
string NumberString(int n) {
  // 50 bytes is far more than any int needs with the '.' and the terminator.
  char formatted[50];
  snprintf(formatted, sizeof(formatted), "%d.", n);
  return formatted;
}
48 
49 // Return a skewed potentially long string
RandomSkewedString(int i,random::SimplePhilox * rnd)50 string RandomSkewedString(int i, random::SimplePhilox* rnd) {
51   return BigString(NumberString(i), rnd->Skewed(17));
52 }
53 
54 class StringDest : public WritableFile {
55  public:
StringDest(string * contents)56   explicit StringDest(string* contents) : contents_(contents) {}
57 
Close()58   Status Close() override { return Status::OK(); }
Flush()59   Status Flush() override { return Status::OK(); }
Sync()60   Status Sync() override { return Status::OK(); }
Append(StringPiece slice)61   Status Append(StringPiece slice) override {
62     contents_->append(slice.data(), slice.size());
63     return Status::OK();
64   }
65 #if defined(PLATFORM_GOOGLE)
Append(const absl::Cord & data)66   Status Append(const absl::Cord& data) override {
67     contents_->append(data.ToString());
68     return Status::OK();
69   }
70 #endif
Tell(int64 * pos)71   Status Tell(int64* pos) override {
72     *pos = contents_->size();
73     return Status::OK();
74   }
75 
76  private:
77   string* contents_;
78 };
79 
80 class StringSource : public RandomAccessFile {
81  public:
StringSource(string * contents)82   explicit StringSource(string* contents)
83       : contents_(contents), force_error_(false) {}
84 
Read(uint64 offset,size_t n,StringPiece * result,char * scratch) const85   Status Read(uint64 offset, size_t n, StringPiece* result,
86               char* scratch) const override {
87     if (force_error_) {
88       force_error_ = false;
89       return errors::DataLoss("read error");
90     }
91 
92     if (offset >= contents_->size()) {
93       return errors::OutOfRange("end of file");
94     }
95 
96     if (contents_->size() < offset + n) {
97       n = contents_->size() - offset;
98     }
99     *result = StringPiece(contents_->data() + offset, n);
100     return Status::OK();
101   }
102 
force_error()103   void force_error() { force_error_ = true; }
104 
105  private:
106   string* contents_;
107   mutable bool force_error_;
108 };
109 
110 class RecordioTest : public ::testing::Test {
111  private:
112   string contents_;
113   StringDest dest_;
114   StringSource source_;
115   bool reading_;
116   uint64 readpos_;
117   RecordWriter* writer_;
118   RecordReader* reader_;
119 
120  public:
RecordioTest()121   RecordioTest()
122       : dest_(&contents_),
123         source_(&contents_),
124         reading_(false),
125         readpos_(0),
126         writer_(new RecordWriter(&dest_)),
127         reader_(new RecordReader(&source_)) {}
128 
~RecordioTest()129   ~RecordioTest() override {
130     delete writer_;
131     delete reader_;
132   }
133 
Write(const string & msg)134   void Write(const string& msg) {
135     ASSERT_TRUE(!reading_) << "Write() after starting to read";
136     TF_ASSERT_OK(writer_->WriteRecord(StringPiece(msg)));
137   }
138 
139 #if defined(PLATFORM_GOOGLE)
Write(const absl::Cord & msg)140   void Write(const absl::Cord& msg) {
141     ASSERT_TRUE(!reading_) << "Write() after starting to read";
142     TF_ASSERT_OK(writer_->WriteRecord(msg));
143   }
144 #endif
145 
WrittenBytes() const146   size_t WrittenBytes() const { return contents_.size(); }
147 
Read()148   string Read() {
149     if (!reading_) {
150       reading_ = true;
151     }
152     tstring record;
153     Status s = reader_->ReadRecord(&readpos_, &record);
154     if (s.ok()) {
155       return record;
156     } else if (errors::IsOutOfRange(s)) {
157       return "EOF";
158     } else {
159       return s.ToString();
160     }
161   }
162 
IncrementByte(int offset,int delta)163   void IncrementByte(int offset, int delta) { contents_[offset] += delta; }
164 
SetByte(int offset,char new_byte)165   void SetByte(int offset, char new_byte) { contents_[offset] = new_byte; }
166 
ShrinkSize(int bytes)167   void ShrinkSize(int bytes) { contents_.resize(contents_.size() - bytes); }
168 
FixChecksum(int header_offset,int len)169   void FixChecksum(int header_offset, int len) {
170     // Compute crc of type/len/data
171     uint32_t crc = crc32c::Value(&contents_[header_offset + 6], 1 + len);
172     crc = crc32c::Mask(crc);
173     core::EncodeFixed32(&contents_[header_offset], crc);
174   }
175 
ForceError()176   void ForceError() { source_.force_error(); }
177 
StartReadingAt(uint64_t initial_offset)178   void StartReadingAt(uint64_t initial_offset) { readpos_ = initial_offset; }
179 
CheckOffsetPastEndReturnsNoRecords(uint64_t offset_past_end)180   void CheckOffsetPastEndReturnsNoRecords(uint64_t offset_past_end) {
181     Write("foo");
182     Write("bar");
183     Write(BigString("x", 10000));
184     reading_ = true;
185     uint64 offset = WrittenBytes() + offset_past_end;
186     tstring record;
187     Status s = reader_->ReadRecord(&offset, &record);
188     ASSERT_TRUE(errors::IsOutOfRange(s)) << s;
189   }
190 };
191 
// With nothing written, the very first read reports end of file.
TEST_F(RecordioTest, Empty) {
  const string first_read = Read();
  ASSERT_EQ("EOF", first_read);
}
193 
// Records, including an empty one, are read back in the order written, and
// end of file keeps being reported on repeated reads.
TEST_F(RecordioTest, ReadWrite) {
  const string kRecords[] = {"foo", "bar", "", "xxxx"};
  for (const string& record : kRecords) {
    Write(record);
  }
  for (const string& record : kRecords) {
    ASSERT_EQ(record, Read());
  }
  // Read twice to make sure reads at eof work.
  for (int attempt = 0; attempt < 2; ++attempt) {
    ASSERT_EQ("EOF", Read());
  }
}
206 
207 #if defined(PLATFORM_GOOGLE)
// Same round trip as ReadWrite, but each record is written as an absl::Cord.
TEST_F(RecordioTest, ReadWriteCords) {
  const char* const kRecords[] = {"foo", "bar", "", "xxxx"};
  for (const char* record : kRecords) {
    Write(absl::Cord(record));
  }
  for (const char* record : kRecords) {
    ASSERT_EQ(string(record), Read());
  }
  // Read twice to make sure reads at eof work.
  for (int attempt = 0; attempt < 2; ++attempt) {
    ASSERT_EQ("EOF", Read());
  }
}
220 #endif
221 
// Round-trips a large number of short records.
TEST_F(RecordioTest, ManyRecords) {
  const int kNumRecords = 100000;
  for (int n = 0; n < kNumRecords; ++n) {
    Write(NumberString(n));
  }
  for (int n = 0; n < kNumRecords; ++n) {
    const string expected = NumberString(n);
    ASSERT_EQ(expected, Read());
  }
  ASSERT_EQ("EOF", Read());
}
231 
// Round-trips records of skewed random lengths. The write pass and the verify
// pass each seed their own generator identically, so both see the same
// sequence of lengths.
TEST_F(RecordioTest, RandomRead) {
  const int kNumRecords = 500;

  // Write pass.
  {
    random::PhiloxRandom write_gen(301, 17);
    random::SimplePhilox write_rnd(&write_gen);
    for (int n = 0; n < kNumRecords; ++n) {
      Write(RandomSkewedString(n, &write_rnd));
    }
  }

  // Verify pass: regenerate each expected record and compare.
  {
    random::PhiloxRandom verify_gen(301, 17);
    random::SimplePhilox verify_rnd(&verify_gen);
    for (int n = 0; n < kNumRecords; ++n) {
      const string expected = RandomSkewedString(n, &verify_rnd);
      ASSERT_EQ(expected, Read());
    }
  }

  ASSERT_EQ("EOF", Read());
}
250 
TestNonSequentialReads(const RecordWriterOptions & writer_options,const RecordReaderOptions & reader_options)251 void TestNonSequentialReads(const RecordWriterOptions& writer_options,
252                             const RecordReaderOptions& reader_options) {
253   string contents;
254   StringDest dst(&contents);
255   RecordWriter writer(&dst, writer_options);
256   for (int i = 0; i < 10; ++i) {
257     TF_ASSERT_OK(writer.WriteRecord(NumberString(i))) << i;
258   }
259   TF_ASSERT_OK(writer.Close());
260 
261   StringSource file(&contents);
262   RecordReader reader(&file, reader_options);
263 
264   tstring record;
265   // First read sequentially to fill in the offsets table.
266   uint64 offsets[10] = {0};
267   uint64 offset = 0;
268   for (int i = 0; i < 10; ++i) {
269     offsets[i] = offset;
270     TF_ASSERT_OK(reader.ReadRecord(&offset, &record)) << i;
271   }
272 
273   // Read randomly: First go back to record #3 then forward to #8.
274   offset = offsets[3];
275   TF_ASSERT_OK(reader.ReadRecord(&offset, &record));
276   EXPECT_EQ("3.", record);
277   EXPECT_EQ(offsets[4], offset);
278 
279   offset = offsets[8];
280   TF_ASSERT_OK(reader.ReadRecord(&offset, &record));
281   EXPECT_EQ("8.", record);
282   EXPECT_EQ(offsets[9], offset);
283 }
284 
// Non-sequential reads with default writer and reader options.
TEST_F(RecordioTest, NonSequentialReads) {
  RecordWriterOptions writer_options;
  RecordReaderOptions reader_options;
  TestNonSequentialReads(writer_options, reader_options);
}
288 
// Non-sequential reads with the reader's buffer_size set to 1024.
TEST_F(RecordioTest, NonSequentialReadsWithReadBuffer) {
  RecordWriterOptions writer_options;
  RecordReaderOptions reader_options;
  reader_options.buffer_size = 1024;
  TestNonSequentialReads(writer_options, reader_options);
}
294 
// Non-sequential reads with both writer and reader configured for "ZLIB".
TEST_F(RecordioTest, NonSequentialReadsWithCompression) {
  const RecordWriterOptions writer_options =
      RecordWriterOptions::CreateRecordWriterOptions("ZLIB");
  const RecordReaderOptions reader_options =
      RecordReaderOptions::CreateRecordReaderOptions("ZLIB");
  TestNonSequentialReads(writer_options, reader_options);
}
300 
// Tests of the RecordReader error paths follow:
AssertHasSubstr(StringPiece s,StringPiece expected)302 void AssertHasSubstr(StringPiece s, StringPiece expected) {
303   EXPECT_TRUE(absl::StrContains(s, expected))
304       << s << " does not contain " << expected;
305 }
306 
TestReadError(const RecordWriterOptions & writer_options,const RecordReaderOptions & reader_options)307 void TestReadError(const RecordWriterOptions& writer_options,
308                    const RecordReaderOptions& reader_options) {
309   const string wrote = BigString("well hello there!", 100);
310   string contents;
311   StringDest dst(&contents);
312   TF_ASSERT_OK(RecordWriter(&dst, writer_options).WriteRecord(wrote));
313 
314   StringSource file(&contents);
315   RecordReader reader(&file, reader_options);
316 
317   uint64 offset = 0;
318   tstring read;
319   file.force_error();
320   Status status = reader.ReadRecord(&offset, &read);
321   ASSERT_TRUE(errors::IsDataLoss(status));
322   ASSERT_EQ(0, offset);
323 
324   // A failed Read() shouldn't update the offset, and thus a retry shouldn't
325   // lose the record.
326   status = reader.ReadRecord(&offset, &read);
327   ASSERT_TRUE(status.ok()) << status;
328   EXPECT_GT(offset, 0);
329   EXPECT_EQ(wrote, read);
330 }
331 
// Read-error recovery with default writer and reader options.
TEST_F(RecordioTest, ReadError) {
  RecordWriterOptions writer_options;
  RecordReaderOptions reader_options;
  TestReadError(writer_options, reader_options);
}
335 
// Read-error recovery with the reader's buffer_size set to 1024 * 1024.
TEST_F(RecordioTest, ReadErrorWithBuffering) {
  RecordWriterOptions writer_options;
  RecordReaderOptions reader_options;
  reader_options.buffer_size = 1024 * 1024;
  TestReadError(writer_options, reader_options);
}
341 
// Read-error recovery with both writer and reader configured for "ZLIB".
TEST_F(RecordioTest, ReadErrorWithCompression) {
  const RecordWriterOptions writer_options =
      RecordWriterOptions::CreateRecordWriterOptions("ZLIB");
  const RecordReaderOptions reader_options =
      RecordReaderOptions::CreateRecordReaderOptions("ZLIB");
  TestReadError(writer_options, reader_options);
}
346 
// Altering byte 6 of a written record makes the read report data loss.
// (Going by the test name, byte 6 presumably lies in the length field.)
TEST_F(RecordioTest, CorruptLength) {
  const int kCorruptedOffset = 6;
  Write("foo");
  IncrementByte(kCorruptedOffset, 100);
  const string outcome = Read();
  AssertHasSubstr(outcome, "Data loss");
}
352 
// Altering byte 10 of a written record makes the read report data loss.
// (Going by the test name, byte 10 presumably lies in the length checksum.)
TEST_F(RecordioTest, CorruptLengthCrc) {
  const int kCorruptedOffset = 10;
  Write("foo");
  IncrementByte(kCorruptedOffset, 100);
  const string outcome = Read();
  AssertHasSubstr(outcome, "Data loss");
}
358 
// Altering byte 14 of a written record makes the read report data loss.
// (Going by the test name, byte 14 presumably lies in the record payload.)
TEST_F(RecordioTest, CorruptData) {
  const int kCorruptedOffset = 14;
  Write("foo");
  IncrementByte(kCorruptedOffset, 10);
  const string outcome = Read();
  AssertHasSubstr(outcome, "Data loss");
}
364 
// Altering the final written byte makes the read report data loss.
// (Going by the test name, that byte presumably belongs to the data checksum.)
TEST_F(RecordioTest, CorruptDataCrc) {
  Write("foo");
  const int last_byte = WrittenBytes() - 1;
  IncrementByte(last_byte, 10);
  const string outcome = Read();
  AssertHasSubstr(outcome, "Data loss");
}
370 
// Reading at an offset exactly equal to the number of written bytes yields no
// record.
TEST_F(RecordioTest, ReadEnd) {
  const uint64_t kBytesPastEnd = 0;
  CheckOffsetPastEndReturnsNoRecords(kBytesPastEnd);
}
372 
// Reading at an offset five bytes beyond the written data yields no record.
TEST_F(RecordioTest, ReadPastEnd) {
  const uint64_t kBytesPastEnd = 5;
  CheckOffsetPastEndReturnsNoRecords(kBytesPastEnd);
}
374 
375 }  // namespace
376 }  // namespace io
377 }  // namespace tensorflow
378