# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""For reading and writing TFRecords files."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.lib.io import _pywrap_record_io
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@tf_export(
    v1=["io.TFRecordCompressionType", "python_io.TFRecordCompressionType"])
@deprecation.deprecated_endpoints("io.TFRecordCompressionType",
                                  "python_io.TFRecordCompressionType")
class TFRecordCompressionType(object):
  """The type of compression for the record."""
  NONE = 0
  ZLIB = 1
  GZIP = 2


@tf_export(
    "io.TFRecordOptions",
    v1=["io.TFRecordOptions", "python_io.TFRecordOptions"])
@deprecation.deprecated_endpoints("python_io.TFRecordOptions")
class TFRecordOptions(object):
  """Options used for manipulating TFRecord files."""
  compression_type_map = {
      TFRecordCompressionType.ZLIB: "ZLIB",
      TFRecordCompressionType.GZIP: "GZIP",
      TFRecordCompressionType.NONE: ""
  }

  def __init__(self,
               compression_type=None,
               flush_mode=None,
               input_buffer_size=None,
               output_buffer_size=None,
               window_bits=None,
               compression_level=None,
               compression_method=None,
               mem_level=None,
               compression_strategy=None):
    # pylint: disable=line-too-long
    """Creates a `TFRecordOptions` instance.

    Options only affect `TFRecordWriter` when `compression_type` is not `None`.
    Documentation, details, and defaults can be found in
    [`zlib_compression_options.h`](https://www.tensorflow.org/code/tensorflow/core/lib/io/zlib_compression_options.h)
    and in the [zlib manual](http://www.zlib.net/manual.html).
    Leaving an option as `None` allows C++ to set a reasonable default.
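    For example, a minimal sketch of enabling GZIP compression when writing
    records (assuming `tensorflow` is imported as `tf`; the output path below
    is only a placeholder):

    ```py
    options = tf.io.TFRecordOptions(compression_type="GZIP",
                                    compression_level=9)
    # Options left unset keep the C++ defaults; only the compression settings
    # are overridden here.
    with tf.io.TFRecordWriter("/tmp/example.tfrecord.gz", options) as writer:
      writer.write(b"a serialized record")
    ```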
    Args:
      compression_type: `"GZIP"`, `"ZLIB"`, or `""` (no compression).
      flush_mode: flush mode or `None`. Default: Z_NO_FLUSH.
      input_buffer_size: int or `None`.
      output_buffer_size: int or `None`.
      window_bits: int or `None`.
      compression_level: 0 to 9, or `None`.
      compression_method: compression method or `None`.
      mem_level: 1 to 9, or `None`.
      compression_strategy: strategy or `None`. Default: Z_DEFAULT_STRATEGY.

    Returns:
      A `TFRecordOptions` object.

    Raises:
      ValueError: If compression_type is invalid.
    """
    # pylint: enable=line-too-long
    # Check compression_type is valid, but for backwards compatibility don't
    # immediately convert to a string.
    self.get_compression_type_string(compression_type)
    self.compression_type = compression_type
    self.flush_mode = flush_mode
    self.input_buffer_size = input_buffer_size
    self.output_buffer_size = output_buffer_size
    self.window_bits = window_bits
    self.compression_level = compression_level
    self.compression_method = compression_method
    self.mem_level = mem_level
    self.compression_strategy = compression_strategy

  @classmethod
  def get_compression_type_string(cls, options):
    """Convert various option types to a unified string.

    Args:
      options: `TFRecordOptions`, `TFRecordCompressionType`, or string.

    Returns:
      Compression type as string (e.g. `'ZLIB'`, `'GZIP'`, or `''`).

    Raises:
      ValueError: If compression_type is invalid.
    """
    if not options:
      return ""
    elif isinstance(options, TFRecordOptions):
      return cls.get_compression_type_string(options.compression_type)
    elif isinstance(options, TFRecordCompressionType):
      return cls.compression_type_map[options]
    elif options in TFRecordOptions.compression_type_map:
      return cls.compression_type_map[options]
    elif options in TFRecordOptions.compression_type_map.values():
      return options
    else:
      raise ValueError('Not a valid compression_type: "{}"'.format(options))

  def _as_record_writer_options(self):
    """Convert to RecordWriterOptions for use with PyRecordWriter."""
    options = _pywrap_record_io.RecordWriterOptions(
        compat.as_bytes(
            self.get_compression_type_string(self.compression_type)))

    if self.flush_mode is not None:
      options.zlib_options.flush_mode = self.flush_mode
    if self.input_buffer_size is not None:
      options.zlib_options.input_buffer_size = self.input_buffer_size
    if self.output_buffer_size is not None:
      options.zlib_options.output_buffer_size = self.output_buffer_size
    if self.window_bits is not None:
      options.zlib_options.window_bits = self.window_bits
    if self.compression_level is not None:
      options.zlib_options.compression_level = self.compression_level
    if self.compression_method is not None:
      options.zlib_options.compression_method = self.compression_method
    if self.mem_level is not None:
      options.zlib_options.mem_level = self.mem_level
    if self.compression_strategy is not None:
      options.zlib_options.compression_strategy = self.compression_strategy
    return options


@tf_export(v1=["io.tf_record_iterator", "python_io.tf_record_iterator"])
@deprecation.deprecated(
    date=None,
    instructions=("Use eager execution and: \n"
                  "`tf.data.TFRecordDataset(path)`"))
def tf_record_iterator(path, options=None):
  """An iterator that reads the records from a TFRecords file.
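  For illustration, a minimal sketch of iterating over the serialized records
  (assuming `tensorflow` is imported as `tf`, the path below is only a
  placeholder, and the file holds serialized `tf.train.Example` protos); new
  code should prefer `tf.data.TFRecordDataset`, as the deprecation message
  notes:

  ```py
  for serialized in tf_record_iterator("/tmp/example.tfrecords"):
    # Each item is the raw bytes of one record; here they are parsed as
    # serialized `tf.train.Example` protos.
    example = tf.train.Example.FromString(serialized)
    print(example)
  ```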
  Args:
    path: The path to the TFRecords file.
    options: (optional) A TFRecordOptions object.

  Returns:
    An iterator of serialized TFRecords.

  Raises:
    IOError: If `path` cannot be opened for reading.
  """
  compression_type = TFRecordOptions.get_compression_type_string(options)
  return _pywrap_record_io.RecordIterator(path, compression_type)


def tf_record_random_reader(path):
  """Creates a reader that allows random-access reads from a TFRecords file.

  The created reader object has the following method:

    - `read(offset)`, which returns a tuple of `(record, ending_offset)`,
      where `record` is the TFRecord read at the offset, and `ending_offset`
      is the ending offset of the read record.

      The method throws a `tf.errors.DataLossError` if data is corrupted at
      the given offset. The method throws `IndexError` if the offset is out of
      range for the TFRecords file.

  Usage example:
  ```py
  reader = tf_record_random_reader(file_path)

  record_1, offset_1 = reader.read(0)  # 0 is the initial offset.
  # offset_1 is the ending offset of the 1st record and the starting offset of
  # the next.

  record_2, offset_2 = reader.read(offset_1)
  # offset_2 is the ending offset of the 2nd record and the starting offset of
  # the next.
  # We can jump back and read the first record again if so desired.
  reader.read(0)
  ```

  Args:
    path: The path to the TFRecords file.

  Returns:
    An object that supports random-access reading of the serialized TFRecords.

  Raises:
    IOError: If `path` cannot be opened for reading.
  """
  return _pywrap_record_io.RandomRecordReader(path)


@tf_export(
    "io.TFRecordWriter", v1=["io.TFRecordWriter", "python_io.TFRecordWriter"])
@deprecation.deprecated_endpoints("python_io.TFRecordWriter")
class TFRecordWriter(_pywrap_record_io.RecordWriter):
  """A class to write records to a TFRecords file.

  [TFRecords tutorial](https://www.tensorflow.org/tutorials/load_data/tfrecord)

  TFRecords is a binary format which is optimized for high throughput data
  retrieval, generally in conjunction with `tf.data`. `TFRecordWriter` is used
  to write serialized examples to a file for later consumption. The key steps
  are:

  Ahead of time:

  - [Convert data into a serialized format](
    https://www.tensorflow.org/tutorials/load_data/tfrecord#tfexample)
  - [Write the serialized data to one or more files](
    https://www.tensorflow.org/tutorials/load_data/tfrecord#tfrecord_files_in_python)

  During training or evaluation:

  - [Read serialized examples into memory](
    https://www.tensorflow.org/tutorials/load_data/tfrecord#reading_a_tfrecord_file)
  - [Parse (deserialize) examples](
    https://www.tensorflow.org/tutorials/load_data/tfrecord#reading_a_tfrecord_file)

  A minimal example is given below:

  >>> import tempfile
  >>> example_path = os.path.join(tempfile.gettempdir(), "example.tfrecords")
  >>> np.random.seed(0)

  >>> # Write the records to a file.
  ... with tf.io.TFRecordWriter(example_path) as file_writer:
  ...   for _ in range(4):
  ...     x, y = np.random.random(), np.random.random()
  ...
  ...     record_bytes = tf.train.Example(features=tf.train.Features(feature={
  ...         "x": tf.train.Feature(float_list=tf.train.FloatList(value=[x])),
  ...         "y": tf.train.Feature(float_list=tf.train.FloatList(value=[y])),
  ...     })).SerializeToString()
  ...     file_writer.write(record_bytes)

  >>> # Read the data back out.
  >>> def decode_fn(record_bytes):
  ...   return tf.io.parse_single_example(
  ...       # Data
  ...       record_bytes,
  ...
  ...       # Schema
  ...       {"x": tf.io.FixedLenFeature([], dtype=tf.float32),
  ...        "y": tf.io.FixedLenFeature([], dtype=tf.float32)}
  ...   )

  >>> for batch in tf.data.TFRecordDataset([example_path]).map(decode_fn):
  ...   print("x = {x:.4f}, y = {y:.4f}".format(**batch))
  x = 0.5488, y = 0.7152
  x = 0.6028, y = 0.5449
  x = 0.4237, y = 0.6459
  x = 0.4376, y = 0.8918

  This class implements `__enter__` and `__exit__`, and can be used
  in `with` blocks like a normal file. (See the usage example above.)
  """

  # TODO(josh11b): Support appending?
  def __init__(self, path, options=None):
    """Opens file `path` and creates a `TFRecordWriter` writing to it.
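    For illustration, `options` may also be given as just a compression type
    string; a minimal sketch (assuming `tensorflow` is imported as `tf`; the
    path below is only a placeholder):

    ```py
    with tf.io.TFRecordWriter("/tmp/compressed.tfrecords", "GZIP") as writer:
      # "GZIP" here is shorthand for TFRecordOptions(compression_type="GZIP").
      writer.write(b"a serialized record")
    ```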
    Args:
      path: The path to the TFRecords file.
      options: (optional) String specifying compression type,
          `TFRecordCompressionType`, or `TFRecordOptions` object.

    Raises:
      IOError: If `path` cannot be opened for writing.
      ValueError: If a valid compression_type can't be determined from
        `options`.
    """
    if not isinstance(options, TFRecordOptions):
      options = TFRecordOptions(compression_type=options)

    # pylint: disable=protected-access
    super(TFRecordWriter, self).__init__(
        compat.as_bytes(path), options._as_record_writer_options())
    # pylint: enable=protected-access

  # TODO(slebedev): The following wrapper methods are there to compensate
  # for the lack of signatures in pybind11-generated classes. Switch to
  # __text_signature__ when TensorFlow drops Python 2.X support.
  # See https://github.com/pybind/pybind11/issues/945
  # pylint: disable=useless-super-delegation
  def write(self, record):
    """Write a string record to the file.

    Args:
      record: str
    """
    super(TFRecordWriter, self).write(record)

  def flush(self):
    """Flush the file."""
    super(TFRecordWriter, self).flush()

  def close(self):
    """Close the file."""
    super(TFRecordWriter, self).close()
  # pylint: enable=useless-super-delegation