/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/lib/io/record_reader.h"

#include <limits.h>

#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/crc32c.h"
#include "tensorflow/core/lib/io/buffered_inputstream.h"
#include "tensorflow/core/lib/io/compression.h"
#include "tensorflow/core/lib/io/random_inputstream.h"
// These headers declare ZlibInputStream and SnappyInputStream, which are
// used in the non-slim compression paths below.
#if !defined(IS_SLIM_BUILD)
#include "tensorflow/core/lib/io/snappy/snappy_inputstream.h"
#include "tensorflow/core/lib/io/zlib_compression_options.h"
#include "tensorflow/core/lib/io/zlib_inputstream.h"
#endif  // IS_SLIM_BUILD
#include "tensorflow/core/platform/env.h"

namespace tensorflow {
namespace io {

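// For reference, each record in a TFRecord file is stored as the following
// checksummed frame (this is the layout that ReadChecksummed() and the
// kHeaderSize/kFooterSize constants used below assume):
//
//   uint64 length                 // 8-byte little-endian payload size
//   uint32 masked_crc32c(length)  // header checksum (header = 12 bytes)
//   byte   data[length]           // the record payload
//   uint32 masked_crc32c(data)    // footer checksum (footer = 4 bytes)
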
RecordReaderOptions RecordReaderOptions::CreateRecordReaderOptions(
    const string& compression_type) {
  RecordReaderOptions options;

#if defined(IS_SLIM_BUILD)
  if (compression_type != compression::kNone) {
    LOG(ERROR) << "Compression is not supported but compression_type is set."
               << " No compression will be used.";
  }
#else
  if (compression_type == compression::kZlib) {
    options.compression_type = io::RecordReaderOptions::ZLIB_COMPRESSION;
    options.zlib_options = io::ZlibCompressionOptions::DEFAULT();
  } else if (compression_type == compression::kGzip) {
    options.compression_type = io::RecordReaderOptions::ZLIB_COMPRESSION;
    options.zlib_options = io::ZlibCompressionOptions::GZIP();
  } else if (compression_type == compression::kSnappy) {
    options.compression_type = io::RecordReaderOptions::SNAPPY_COMPRESSION;
  } else if (compression_type != compression::kNone) {
    LOG(ERROR) << "Unsupported compression_type: " << compression_type
               << ". No compression will be used.";
  }
#endif
  return options;
}

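// Example usage, as a sketch with error handling elided (the file name is
// illustrative):
//
//   std::unique_ptr<RandomAccessFile> file;
//   TF_CHECK_OK(Env::Default()->NewRandomAccessFile("input.tfrecord", &file));
//   RecordReader reader(
//       file.get(),
//       RecordReaderOptions::CreateRecordReaderOptions(compression::kNone));
//   uint64 offset = 0;
//   tstring record;
//   while (reader.ReadRecord(&offset, &record).ok()) {
//     // ... consume `record` ...
//   }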
RecordReader::RecordReader(RandomAccessFile* file,
                           const RecordReaderOptions& options)
    : options_(options),
      input_stream_(new RandomAccessInputStream(file)),
      last_read_failed_(false) {
  if (options.buffer_size > 0) {
    input_stream_.reset(new BufferedInputStream(input_stream_.release(),
                                                options.buffer_size, true));
  }
#if defined(IS_SLIM_BUILD)
  if (options.compression_type != RecordReaderOptions::NONE) {
    LOG(FATAL) << "Compression is unsupported on mobile platforms.";
  }
#else
  if (options.compression_type == RecordReaderOptions::ZLIB_COMPRESSION) {
    input_stream_.reset(new ZlibInputStream(
        input_stream_.release(), options.zlib_options.input_buffer_size,
        options.zlib_options.output_buffer_size, options.zlib_options, true));
  } else if (options.compression_type ==
             RecordReaderOptions::SNAPPY_COMPRESSION) {
    input_stream_.reset(
        new SnappyInputStream(input_stream_.release(),
                              options.snappy_options.output_buffer_size, true));
  } else if (options.compression_type == RecordReaderOptions::NONE) {
    // Nothing to do.
  } else {
    LOG(FATAL) << "Unrecognized compression type: "
               << options.compression_type;
  }
#endif
}

// Reads n+4 bytes from the file, verifies that the checksum of the first n
// bytes is stored in the last 4 bytes, and stores the first n bytes in
// *result.
//
// offset corresponds to the user-provided value to ReadRecord()
// and is used only in error messages.
Status RecordReader::ReadChecksummed(uint64 offset, size_t n, tstring* result) {
  if (n >= SIZE_MAX - sizeof(uint32)) {
    return errors::DataLoss("record size too large");
  }

  const size_t expected = n + sizeof(uint32);
  TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(expected, result));

  if (result->size() != expected) {
    if (result->empty()) {
      return errors::OutOfRange("eof");
    } else {
      return errors::DataLoss("truncated record at ", offset);
    }
  }

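  // Note: the stored CRC is "masked". The writer applies crc32c::Mask() (a
  // rotation plus a constant, a scheme inherited from LevelDB) so that
  // computing a CRC over data that itself embeds CRCs remains well-behaved;
  // Unmask() inverts that transformation before the comparison below.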
  const uint32 masked_crc = core::DecodeFixed32(result->data() + n);
  if (crc32c::Unmask(masked_crc) != crc32c::Value(result->data(), n)) {
    return errors::DataLoss("corrupted record at ", offset);
  }
  result->resize(n);
  return Status::OK();
}

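// Example usage, as a sketch (Metadata is the nested struct declared in
// record_reader.h; `reader` is as in the example above):
//
//   RecordReader::Metadata md;
//   TF_CHECK_OK(reader.GetMetadata(&md));
//   LOG(INFO) << "entries: " << md.stats.entries
//             << " data bytes: " << md.stats.data_size
//             << " file bytes: " << md.stats.file_size;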
Status RecordReader::GetMetadata(Metadata* md) {
  if (!md) {
    return errors::InvalidArgument(
        "Metadata object passed to GetMetadata() was null");
  }

  // Compute the metadata of the TFRecord file if not cached.
  if (!cached_metadata_) {
    TF_RETURN_IF_ERROR(input_stream_->Reset());

    int64 data_size = 0;
    int64 entries = 0;

    // Within the loop, we always increment offset positively, so this
    // loop is guaranteed to either return after reaching EOF or after
    // encountering an error.
    uint64 offset = 0;
    tstring record;
    while (true) {
      // Read the header, which contains the size of the data.
      Status s = ReadChecksummed(offset, sizeof(uint64), &record);
      if (!s.ok()) {
        if (errors::IsOutOfRange(s)) {
          // We should hit OutOfRange exactly when a complete record file
          // has been fully consumed.
          break;
        }
        return s;
      }

      // Decode the length of the data.
      const uint64 length = core::DecodeFixed64(record.data());

      // Skip reading the actual data since we just want the number
      // of records and the size of the data.
      TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(length + kFooterSize));
      offset += kHeaderSize + length + kFooterSize;

      // Increment running stats.
      data_size += length;
      ++entries;
    }

    cached_metadata_.reset(new Metadata());
    cached_metadata_->stats.entries = entries;
    cached_metadata_->stats.data_size = data_size;
    cached_metadata_->stats.file_size =
        data_size + (kHeaderSize + kFooterSize) * entries;
  }

  md->stats = cached_metadata_->stats;
  return Status::OK();
}

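// Positions the underlying input stream at `offset`. The input streams are
// forward-only: seeking backwards (or retrying at the same position after a
// failed read) requires a Reset() followed by a skip from the beginning of
// the stream, while seeking forwards is a plain skip of the difference.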
Status RecordReader::PositionInputStream(uint64 offset) {
  int64 curr_pos = input_stream_->Tell();
  int64 desired_pos = static_cast<int64>(offset);
  if (curr_pos > desired_pos || curr_pos < 0 /* EOF */ ||
      (curr_pos == desired_pos && last_read_failed_)) {
    last_read_failed_ = false;
    TF_RETURN_IF_ERROR(input_stream_->Reset());
    TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(desired_pos));
  } else if (curr_pos < desired_pos) {
    TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(desired_pos - curr_pos));
  }
  DCHECK_EQ(desired_pos, input_stream_->Tell());
  return Status::OK();
}

Status RecordReader::ReadRecord(uint64* offset, tstring* record) {
  TF_RETURN_IF_ERROR(PositionInputStream(*offset));

  // Read the header data.
  Status s = ReadChecksummed(*offset, sizeof(uint64), record);
  if (!s.ok()) {
    last_read_failed_ = true;
    return s;
  }
  const uint64 length = core::DecodeFixed64(record->data());

  // Read the data.
  s = ReadChecksummed(*offset + kHeaderSize, length, record);
  if (!s.ok()) {
    last_read_failed_ = true;
    if (errors::IsOutOfRange(s)) {
      s = errors::DataLoss("truncated record at ", *offset, "; failed with ",
                           s.error_message());
    }
    return s;
  }

  *offset += kHeaderSize + length + kFooterSize;
  DCHECK_EQ(*offset, input_stream_->Tell());
  return Status::OK();
}

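// Skips over num_to_skip records starting at *offset, reading each header to
// learn the payload length but skipping (rather than checksumming) the
// payload itself. Example usage, as a sketch:
//
//   int skipped = 0;
//   TF_CHECK_OK(reader.SkipRecords(&offset, /*num_to_skip=*/10, &skipped));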
Status RecordReader::SkipRecords(uint64* offset, int num_to_skip,
                                 int* num_skipped) {
  TF_RETURN_IF_ERROR(PositionInputStream(*offset));

  Status s;
  tstring record;
  *num_skipped = 0;
  for (int i = 0; i < num_to_skip; ++i) {
    s = ReadChecksummed(*offset, sizeof(uint64), &record);
    if (!s.ok()) {
      last_read_failed_ = true;
      return s;
    }
    const uint64 length = core::DecodeFixed64(record.data());

    // Skip the data.
    s = input_stream_->SkipNBytes(length + kFooterSize);
    if (!s.ok()) {
      last_read_failed_ = true;
      if (errors::IsOutOfRange(s)) {
        s = errors::DataLoss("truncated record at ", *offset, "; failed with ",
                             s.error_message());
      }
      return s;
    }
    *offset += kHeaderSize + length + kFooterSize;
    DCHECK_EQ(*offset, input_stream_->Tell());
    (*num_skipped)++;
  }
  return Status::OK();
}

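// SequentialRecordReader wraps a RecordReader and tracks the read offset
// internally. A sketch of the intended iteration, assuming the ReadNext()
// accessor declared in record_reader.h:
//
//   SequentialRecordReader seq_reader(file.get());
//   tstring record;
//   while (seq_reader.ReadNext(&record).ok()) {
//     // ... consume `record` ...
//   }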
SequentialRecordReader::SequentialRecordReader(
    RandomAccessFile* file, const RecordReaderOptions& options)
    : underlying_(file, options), offset_(0) {}

}  // namespace io
}  // namespace tensorflow