• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/lib/io/record_reader.h"
17 
18 #include <limits.h>
19 
20 #include "tensorflow/core/lib/core/coding.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/hash/crc32c.h"
23 #include "tensorflow/core/lib/io/buffered_inputstream.h"
24 #include "tensorflow/core/lib/io/compression.h"
25 #include "tensorflow/core/lib/io/random_inputstream.h"
26 #include "tensorflow/core/platform/env.h"
27 
28 namespace tensorflow {
29 namespace io {
30 
CreateRecordReaderOptions(const string & compression_type)31 RecordReaderOptions RecordReaderOptions::CreateRecordReaderOptions(
32     const string& compression_type) {
33   RecordReaderOptions options;
34   if (compression_type == "ZLIB") {
35     options.compression_type = io::RecordReaderOptions::ZLIB_COMPRESSION;
36 #if defined(IS_SLIM_BUILD)
37     LOG(ERROR) << "Compression is not supported but compression_type is set."
38                << " No compression will be used.";
39 #else
40     options.zlib_options = io::ZlibCompressionOptions::DEFAULT();
41 #endif  // IS_SLIM_BUILD
42   } else if (compression_type == compression::kGzip) {
43     options.compression_type = io::RecordReaderOptions::ZLIB_COMPRESSION;
44 #if defined(IS_SLIM_BUILD)
45     LOG(ERROR) << "Compression is not supported but compression_type is set."
46                << " No compression will be used.";
47 #else
48     options.zlib_options = io::ZlibCompressionOptions::GZIP();
49 #endif  // IS_SLIM_BUILD
50   } else if (compression_type != compression::kNone) {
51     LOG(ERROR) << "Unsupported compression_type:" << compression_type
52                << ". No compression will be used.";
53   }
54   return options;
55 }
56 
RecordReader(RandomAccessFile * file,const RecordReaderOptions & options)57 RecordReader::RecordReader(RandomAccessFile* file,
58                            const RecordReaderOptions& options)
59     : options_(options),
60       input_stream_(new RandomAccessInputStream(file)),
61       last_read_failed_(false) {
62   if (options.buffer_size > 0) {
63     input_stream_.reset(new BufferedInputStream(input_stream_.release(),
64                                                 options.buffer_size, true));
65   }
66   if (options.compression_type == RecordReaderOptions::ZLIB_COMPRESSION) {
67 // We don't have zlib available on all embedded platforms, so fail.
68 #if defined(IS_SLIM_BUILD)
69     LOG(FATAL) << "Zlib compression is unsupported on mobile platforms.";
70 #else   // IS_SLIM_BUILD
71     input_stream_.reset(new ZlibInputStream(
72         input_stream_.release(), options.zlib_options.input_buffer_size,
73         options.zlib_options.output_buffer_size, options.zlib_options, true));
74 #endif  // IS_SLIM_BUILD
75   } else if (options.compression_type == RecordReaderOptions::NONE) {
76     // Nothing to do.
77   } else {
78     LOG(FATAL) << "Unrecognized compression type :" << options.compression_type;
79   }
80 }
81 
82 // Read n+4 bytes from file, verify that checksum of first n bytes is
83 // stored in the last 4 bytes and store the first n bytes in *result.
84 //
85 // offset corresponds to the user-provided value to ReadRecord()
86 // and is used only in error messages.
ReadChecksummed(uint64 offset,size_t n,string * result)87 Status RecordReader::ReadChecksummed(uint64 offset, size_t n, string* result) {
88   if (n >= SIZE_MAX - sizeof(uint32)) {
89     return errors::DataLoss("record size too large");
90   }
91 
92   const size_t expected = n + sizeof(uint32);
93   TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(expected, result));
94 
95   if (result->size() != expected) {
96     if (result->empty()) {
97       return errors::OutOfRange("eof");
98     } else {
99       return errors::DataLoss("truncated record at ", offset);
100     }
101   }
102 
103   const uint32 masked_crc = core::DecodeFixed32(result->data() + n);
104   if (crc32c::Unmask(masked_crc) != crc32c::Value(result->data(), n)) {
105     return errors::DataLoss("corrupted record at ", offset);
106   }
107   result->resize(n);
108   return Status::OK();
109 }
110 
GetMetadata(Metadata * md)111 Status RecordReader::GetMetadata(Metadata* md) {
112   if (!md) {
113     return errors::InvalidArgument(
114         "Metadata object call to GetMetadata() was null");
115   }
116 
117   // Compute the metadata of the TFRecord file if not cached.
118   if (!cached_metadata_) {
119     TF_RETURN_IF_ERROR(input_stream_->Reset());
120 
121     int64 data_size = 0;
122     int64 entries = 0;
123 
124     // Within the loop, we always increment offset positively, so this
125     // loop should be guaranteed to either return after reaching EOF
126     // or encountering an error.
127     uint64 offset = 0;
128     string record;
129     while (true) {
130       // Read header, containing size of data.
131       Status s = ReadChecksummed(offset, sizeof(uint64), &record);
132       if (!s.ok()) {
133         if (errors::IsOutOfRange(s)) {
134           // We should reach out of range when the record file is complete.
135           break;
136         }
137         return s;
138       }
139 
140       // Read the length of the data.
141       const uint64 length = core::DecodeFixed64(record.data());
142 
143       // Skip reading the actual data since we just want the number
144       // of records and the size of the data.
145       TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(length + kFooterSize));
146       offset += kHeaderSize + length + kFooterSize;
147 
148       // Increment running stats.
149       data_size += length;
150       ++entries;
151     }
152 
153     cached_metadata_.reset(new Metadata());
154     cached_metadata_->stats.entries = entries;
155     cached_metadata_->stats.data_size = data_size;
156     cached_metadata_->stats.file_size =
157         data_size + (kHeaderSize + kFooterSize) * entries;
158   }
159 
160   md->stats = cached_metadata_->stats;
161   return Status::OK();
162 }
163 
ReadRecord(uint64 * offset,string * record)164 Status RecordReader::ReadRecord(uint64* offset, string* record) {
165   // Position the input stream.
166   int64 curr_pos = input_stream_->Tell();
167   int64 desired_pos = static_cast<int64>(*offset);
168   if (curr_pos > desired_pos || curr_pos < 0 /* EOF */ ||
169       (curr_pos == desired_pos && last_read_failed_)) {
170     last_read_failed_ = false;
171     TF_RETURN_IF_ERROR(input_stream_->Reset());
172     TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(desired_pos));
173   } else if (curr_pos < desired_pos) {
174     TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(desired_pos - curr_pos));
175   }
176   DCHECK_EQ(desired_pos, input_stream_->Tell());
177 
178   // Read header data.
179   Status s = ReadChecksummed(*offset, sizeof(uint64), record);
180   if (!s.ok()) {
181     last_read_failed_ = true;
182     return s;
183   }
184   const uint64 length = core::DecodeFixed64(record->data());
185 
186   // Read data
187   s = ReadChecksummed(*offset + kHeaderSize, length, record);
188   if (!s.ok()) {
189     last_read_failed_ = true;
190     if (errors::IsOutOfRange(s)) {
191       s = errors::DataLoss("truncated record at ", *offset);
192     }
193     return s;
194   }
195 
196   *offset += kHeaderSize + length + kFooterSize;
197   DCHECK_EQ(*offset, input_stream_->Tell());
198   return Status::OK();
199 }
200 
SequentialRecordReader(RandomAccessFile * file,const RecordReaderOptions & options)201 SequentialRecordReader::SequentialRecordReader(
202     RandomAccessFile* file, const RecordReaderOptions& options)
203     : underlying_(file, options), offset_(0) {}
204 
205 }  // namespace io
206 }  // namespace tensorflow
207