• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/lib/io/record_reader.h"
17 
18 #include <limits.h>
19 
20 #include "tensorflow/core/lib/core/coding.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/hash/crc32c.h"
23 #include "tensorflow/core/lib/io/buffered_inputstream.h"
24 #include "tensorflow/core/lib/io/compression.h"
25 #include "tensorflow/core/lib/io/random_inputstream.h"
26 #include "tensorflow/core/platform/env.h"
27 
28 namespace tensorflow {
29 namespace io {
30 
CreateRecordReaderOptions(const string & compression_type)31 RecordReaderOptions RecordReaderOptions::CreateRecordReaderOptions(
32     const string& compression_type) {
33   RecordReaderOptions options;
34 
35 #if defined(IS_SLIM_BUILD)
36   if (compression_type != compression::kNone) {
37     LOG(ERROR) << "Compression is not supported but compression_type is set."
38                << " No compression will be used.";
39   }
40 #else
41   if (compression_type == compression::kZlib) {
42     options.compression_type = io::RecordReaderOptions::ZLIB_COMPRESSION;
43     options.zlib_options = io::ZlibCompressionOptions::DEFAULT();
44   } else if (compression_type == compression::kGzip) {
45     options.compression_type = io::RecordReaderOptions::ZLIB_COMPRESSION;
46     options.zlib_options = io::ZlibCompressionOptions::GZIP();
47   } else if (compression_type == compression::kSnappy) {
48     options.compression_type = io::RecordReaderOptions::SNAPPY_COMPRESSION;
49   } else if (compression_type != compression::kNone) {
50     LOG(ERROR) << "Unsupported compression_type:" << compression_type
51                << ". No compression will be used.";
52   }
53 #endif
54   return options;
55 }
56 
RecordReader(RandomAccessFile * file,const RecordReaderOptions & options)57 RecordReader::RecordReader(RandomAccessFile* file,
58                            const RecordReaderOptions& options)
59     : options_(options),
60       input_stream_(new RandomAccessInputStream(file)),
61       last_read_failed_(false) {
62   if (options.buffer_size > 0) {
63     input_stream_.reset(new BufferedInputStream(input_stream_.release(),
64                                                 options.buffer_size, true));
65   }
66 #if defined(IS_SLIM_BUILD)
67   if (options.compression_type != RecordReaderOptions::NONE) {
68     LOG(FATAL) << "Compression is unsupported on mobile platforms.";
69   }
70 #else
71   if (options.compression_type == RecordReaderOptions::ZLIB_COMPRESSION) {
72     input_stream_.reset(new ZlibInputStream(
73         input_stream_.release(), options.zlib_options.input_buffer_size,
74         options.zlib_options.output_buffer_size, options.zlib_options, true));
75   } else if (options.compression_type ==
76              RecordReaderOptions::SNAPPY_COMPRESSION) {
77     input_stream_.reset(
78         new SnappyInputStream(input_stream_.release(),
79                               options.snappy_options.output_buffer_size, true));
80   } else if (options.compression_type == RecordReaderOptions::NONE) {
81     // Nothing to do.
82   } else {
83     LOG(FATAL) << "Unrecognized compression type :" << options.compression_type;
84   }
85 #endif
86 }
87 
88 // Read n+4 bytes from file, verify that checksum of first n bytes is
89 // stored in the last 4 bytes and store the first n bytes in *result.
90 //
91 // offset corresponds to the user-provided value to ReadRecord()
92 // and is used only in error messages.
ReadChecksummed(uint64 offset,size_t n,tstring * result)93 Status RecordReader::ReadChecksummed(uint64 offset, size_t n, tstring* result) {
94   if (n >= SIZE_MAX - sizeof(uint32)) {
95     return errors::DataLoss("record size too large");
96   }
97 
98   const size_t expected = n + sizeof(uint32);
99   TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(expected, result));
100 
101   if (result->size() != expected) {
102     if (result->empty()) {
103       return errors::OutOfRange("eof");
104     } else {
105       return errors::DataLoss("truncated record at ", offset);
106     }
107   }
108 
109   const uint32 masked_crc = core::DecodeFixed32(result->data() + n);
110   if (crc32c::Unmask(masked_crc) != crc32c::Value(result->data(), n)) {
111     return errors::DataLoss("corrupted record at ", offset);
112   }
113   result->resize(n);
114   return OkStatus();
115 }
116 
GetMetadata(Metadata * md)117 Status RecordReader::GetMetadata(Metadata* md) {
118   if (!md) {
119     return errors::InvalidArgument(
120         "Metadata object call to GetMetadata() was null");
121   }
122 
123   // Compute the metadata of the TFRecord file if not cached.
124   if (!cached_metadata_) {
125     TF_RETURN_IF_ERROR(input_stream_->Reset());
126 
127     int64_t data_size = 0;
128     int64_t entries = 0;
129 
130     // Within the loop, we always increment offset positively, so this
131     // loop should be guaranteed to either return after reaching EOF
132     // or encountering an error.
133     uint64 offset = 0;
134     tstring record;
135     while (true) {
136       // Read header, containing size of data.
137       Status s = ReadChecksummed(offset, sizeof(uint64), &record);
138       if (!s.ok()) {
139         if (errors::IsOutOfRange(s)) {
140           // We should reach out of range when the record file is complete.
141           break;
142         }
143         return s;
144       }
145 
146       // Read the length of the data.
147       const uint64 length = core::DecodeFixed64(record.data());
148 
149       // Skip reading the actual data since we just want the number
150       // of records and the size of the data.
151       TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(length + kFooterSize));
152       offset += kHeaderSize + length + kFooterSize;
153 
154       // Increment running stats.
155       data_size += length;
156       ++entries;
157     }
158 
159     cached_metadata_.reset(new Metadata());
160     cached_metadata_->stats.entries = entries;
161     cached_metadata_->stats.data_size = data_size;
162     cached_metadata_->stats.file_size =
163         data_size + (kHeaderSize + kFooterSize) * entries;
164   }
165 
166   md->stats = cached_metadata_->stats;
167   return OkStatus();
168 }
169 
PositionInputStream(uint64 offset)170 Status RecordReader::PositionInputStream(uint64 offset) {
171   int64_t curr_pos = input_stream_->Tell();
172   int64_t desired_pos = static_cast<int64_t>(offset);
173   if (curr_pos > desired_pos || curr_pos < 0 /* EOF */ ||
174       (curr_pos == desired_pos && last_read_failed_)) {
175     last_read_failed_ = false;
176     TF_RETURN_IF_ERROR(input_stream_->Reset());
177     TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(desired_pos));
178   } else if (curr_pos < desired_pos) {
179     TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(desired_pos - curr_pos));
180   }
181   DCHECK_EQ(desired_pos, input_stream_->Tell());
182   return OkStatus();
183 }
184 
ReadRecord(uint64 * offset,tstring * record)185 Status RecordReader::ReadRecord(uint64* offset, tstring* record) {
186   TF_RETURN_IF_ERROR(PositionInputStream(*offset));
187 
188   // Read header data.
189   Status s = ReadChecksummed(*offset, sizeof(uint64), record);
190   if (!s.ok()) {
191     last_read_failed_ = true;
192     return s;
193   }
194   const uint64 length = core::DecodeFixed64(record->data());
195 
196   // Read data
197   s = ReadChecksummed(*offset + kHeaderSize, length, record);
198   if (!s.ok()) {
199     last_read_failed_ = true;
200     if (errors::IsOutOfRange(s)) {
201       s = errors::DataLoss("truncated record at ", *offset, "' failed with ",
202                            s.error_message());
203     }
204     return s;
205   }
206 
207   *offset += kHeaderSize + length + kFooterSize;
208   DCHECK_EQ(*offset, input_stream_->Tell());
209   return OkStatus();
210 }
211 
SkipRecords(uint64 * offset,int num_to_skip,int * num_skipped)212 Status RecordReader::SkipRecords(uint64* offset, int num_to_skip,
213                                  int* num_skipped) {
214   TF_RETURN_IF_ERROR(PositionInputStream(*offset));
215 
216   Status s;
217   tstring record;
218   *num_skipped = 0;
219   for (int i = 0; i < num_to_skip; ++i) {
220     s = ReadChecksummed(*offset, sizeof(uint64), &record);
221     if (!s.ok()) {
222       last_read_failed_ = true;
223       return s;
224     }
225     const uint64 length = core::DecodeFixed64(record.data());
226 
227     // Skip data
228     s = input_stream_->SkipNBytes(length + kFooterSize);
229     if (!s.ok()) {
230       last_read_failed_ = true;
231       if (errors::IsOutOfRange(s)) {
232         s = errors::DataLoss("truncated record at ", *offset, "' failed with ",
233                              s.error_message());
234       }
235       return s;
236     }
237     *offset += kHeaderSize + length + kFooterSize;
238     DCHECK_EQ(*offset, input_stream_->Tell());
239     (*num_skipped)++;
240   }
241   return OkStatus();
242 }
243 
SequentialRecordReader(RandomAccessFile * file,const RecordReaderOptions & options)244 SequentialRecordReader::SequentialRecordReader(
245     RandomAccessFile* file, const RecordReaderOptions& options)
246     : underlying_(file, options), offset_(0) {}
247 
248 }  // namespace io
249 }  // namespace tensorflow
250