/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/lib/io/record_reader.h"

#include <limits.h>

#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/crc32c.h"
#include "tensorflow/core/lib/io/buffered_inputstream.h"
#include "tensorflow/core/lib/io/compression.h"
#include "tensorflow/core/lib/io/random_inputstream.h"
#include "tensorflow/core/platform/env.h"

namespace tensorflow {
namespace io {

RecordReaderOptions RecordReaderOptions::CreateRecordReaderOptions(
    const string& compression_type) {
  RecordReaderOptions options;
  if (compression_type == "ZLIB") {
    options.compression_type = io::RecordReaderOptions::ZLIB_COMPRESSION;
#if defined(IS_SLIM_BUILD)
    LOG(ERROR) << "Compression is not supported but compression_type is set."
               << " No compression will be used.";
#else
    options.zlib_options = io::ZlibCompressionOptions::DEFAULT();
#endif  // IS_SLIM_BUILD
  } else if (compression_type == compression::kGzip) {
    options.compression_type = io::RecordReaderOptions::ZLIB_COMPRESSION;
#if defined(IS_SLIM_BUILD)
    LOG(ERROR) << "Compression is not supported but compression_type is set."
               << " No compression will be used.";
#else
    options.zlib_options = io::ZlibCompressionOptions::GZIP();
#endif  // IS_SLIM_BUILD
  } else if (compression_type != compression::kNone) {
    LOG(ERROR) << "Unsupported compression_type: " << compression_type
               << ". No compression will be used.";
  }
  return options;
}
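
// Illustrative usage sketch (not part of this library; the file path and
// setup below are hypothetical): building a reader for a gzip-compressed
// TFRecord file.
//
//   std::unique_ptr<RandomAccessFile> file;
//   TF_CHECK_OK(Env::Default()->NewRandomAccessFile("/path/to/data.tfrecord",
//                                                   &file));
//   RecordReaderOptions opts =
//       RecordReaderOptions::CreateRecordReaderOptions(compression::kGzip);
//   RecordReader reader(file.get(), opts);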

RecordReader::RecordReader(RandomAccessFile* file,
                           const RecordReaderOptions& options)
    : options_(options),
      input_stream_(new RandomAccessInputStream(file)),
      last_read_failed_(false) {
  if (options.buffer_size > 0) {
    input_stream_.reset(new BufferedInputStream(input_stream_.release(),
                                                options.buffer_size, true));
  }
  if (options.compression_type == RecordReaderOptions::ZLIB_COMPRESSION) {
    // We don't have zlib available on all embedded platforms, so fail.
#if defined(IS_SLIM_BUILD)
    LOG(FATAL) << "Zlib compression is unsupported on mobile platforms.";
#else  // IS_SLIM_BUILD
    input_stream_.reset(new ZlibInputStream(
        input_stream_.release(), options.zlib_options.input_buffer_size,
        options.zlib_options.output_buffer_size, options.zlib_options, true));
#endif  // IS_SLIM_BUILD
  } else if (options.compression_type == RecordReaderOptions::NONE) {
    // Nothing to do.
  } else {
    LOG(FATAL) << "Unrecognized compression type: "
               << options.compression_type;
  }
}

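// On-disk layout of a single record (mirroring the format comment in
// record_writer.cc):
//
//   uint64 length
//   uint32 masked crc32 of length
//   byte   data[length]
//   uint32 masked crc32 of data
//
// kHeaderSize covers the length field plus its crc; kFooterSize covers the
// trailing data crc.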

// Read n+4 bytes from the file, verify that the checksum of the first n bytes
// is stored in the last 4 bytes, and store the first n bytes in *result.
//
// offset corresponds to the user-provided value passed to ReadRecord()
// and is used only in error messages.
Status RecordReader::ReadChecksummed(uint64 offset, size_t n, string* result) {
  if (n >= SIZE_MAX - sizeof(uint32)) {
    return errors::DataLoss("record size too large");
  }

  const size_t expected = n + sizeof(uint32);
  TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(expected, result));

  if (result->size() != expected) {
    if (result->empty()) {
      return errors::OutOfRange("eof");
    } else {
      return errors::DataLoss("truncated record at ", offset);
    }
  }

  const uint32 masked_crc = core::DecodeFixed32(result->data() + n);
  if (crc32c::Unmask(masked_crc) != crc32c::Value(result->data(), n)) {
    return errors::DataLoss("corrupted record at ", offset);
  }
  result->resize(n);
  return Status::OK();
}

Status RecordReader::GetMetadata(Metadata* md) {
  if (!md) {
    return errors::InvalidArgument(
        "Metadata object passed to GetMetadata() was null");
  }

  // Compute the metadata of the TFRecord file if not cached.
  if (!cached_metadata_) {
    TF_RETURN_IF_ERROR(input_stream_->Reset());

    int64 data_size = 0;
    int64 entries = 0;

    // Within the loop, offset only moves forward, so the loop is guaranteed
    // to terminate either by reaching EOF or by encountering an error.
    uint64 offset = 0;
    string record;
    while (true) {
      // Read the header, which contains the size of the data.
      Status s = ReadChecksummed(offset, sizeof(uint64), &record);
      if (!s.ok()) {
        if (errors::IsOutOfRange(s)) {
          // Out of range is expected once the record file is exhausted.
          break;
        }
        return s;
      }

      // Decode the length of the data from the header.
      const uint64 length = core::DecodeFixed64(record.data());

      // Skip reading the actual data since we just want the number
      // of records and the size of the data.
      TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(length + kFooterSize));
      offset += kHeaderSize + length + kFooterSize;

      // Increment running stats.
      data_size += length;
      ++entries;
    }

    cached_metadata_.reset(new Metadata());
    cached_metadata_->stats.entries = entries;
    cached_metadata_->stats.data_size = data_size;
    cached_metadata_->stats.file_size =
        data_size + (kHeaderSize + kFooterSize) * entries;
  }

  md->stats = cached_metadata_->stats;
  return Status::OK();
}
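
// Illustrative usage sketch (assumes a `reader` built as in the sketch above):
// fetching whole-file statistics without materializing any record payloads.
//
//   RecordReader::Metadata md;
//   TF_CHECK_OK(reader.GetMetadata(&md));
//   LOG(INFO) << "records: " << md.stats.entries
//             << ", payload bytes: " << md.stats.data_size;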

Status RecordReader::ReadRecord(uint64* offset, string* record) {
  // Position the input stream.
  int64 curr_pos = input_stream_->Tell();
  int64 desired_pos = static_cast<int64>(*offset);
  if (curr_pos > desired_pos || curr_pos < 0 /* EOF */ ||
      (curr_pos == desired_pos && last_read_failed_)) {
    last_read_failed_ = false;
    TF_RETURN_IF_ERROR(input_stream_->Reset());
    TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(desired_pos));
  } else if (curr_pos < desired_pos) {
    TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(desired_pos - curr_pos));
  }
  DCHECK_EQ(desired_pos, input_stream_->Tell());

  // Read header data.
  Status s = ReadChecksummed(*offset, sizeof(uint64), record);
  if (!s.ok()) {
    last_read_failed_ = true;
    return s;
  }
  const uint64 length = core::DecodeFixed64(record->data());

  // Read data.
  s = ReadChecksummed(*offset + kHeaderSize, length, record);
  if (!s.ok()) {
    last_read_failed_ = true;
    if (errors::IsOutOfRange(s)) {
      s = errors::DataLoss("truncated record at ", *offset);
    }
    return s;
  }

  *offset += kHeaderSize + length + kFooterSize;
  DCHECK_EQ(*offset, input_stream_->Tell());
  return Status::OK();
}
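
// Illustrative usage sketch (not part of this library; `reader` is assumed to
// exist as in the first sketch above): reading every record in a file with an
// explicit offset. A real caller would distinguish OutOfRange (clean EOF)
// from DataLoss.
//
//   uint64 offset = 0;
//   string record;
//   Status s;
//   while ((s = reader.ReadRecord(&offset, &record)).ok()) {
//     // Process `record` here.
//   }
//   if (!errors::IsOutOfRange(s)) TF_CHECK_OK(s);  // A real error, not EOF.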

SequentialRecordReader::SequentialRecordReader(
    RandomAccessFile* file, const RecordReaderOptions& options)
    : underlying_(file, options), offset_(0) {}
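
// Illustrative usage sketch (not part of this library; file setup as in the
// first sketch above): SequentialRecordReader tracks its own offset, so
// callers simply pull records in order via its ReadRecord(string*) helper
// declared in record_reader.h.
//
//   SequentialRecordReader reader(file.get(), RecordReaderOptions());
//   string record;
//   while (reader.ReadRecord(&record).ok()) {
//     // Process `record` here.
//   }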

}  // namespace io
}  // namespace tensorflow