• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""For reading and writing TFRecords files."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22from tensorflow.python.lib.io import _pywrap_record_io
23from tensorflow.python.util import compat
24from tensorflow.python.util import deprecation
25from tensorflow.python.util.tf_export import tf_export
26
27
@tf_export(v1=["io.TFRecordCompressionType",
               "python_io.TFRecordCompressionType"])
@deprecation.deprecated_endpoints(
    "io.TFRecordCompressionType", "python_io.TFRecordCompressionType")
class TFRecordCompressionType(object):
  """Integer codes naming the compression applied to a TFRecord file.

  These are plain `int` class attributes (not enum members);
  `TFRecordOptions.compression_type_map` translates them to the strings
  `""`, `"ZLIB"` and `"GZIP"`.
  """
  NONE = 0  # Uncompressed records.
  ZLIB = 1  # Maps to the "ZLIB" compression string.
  GZIP = 2  # Maps to the "GZIP" compression string.
37
38
@tf_export(
    "io.TFRecordOptions",
    v1=["io.TFRecordOptions", "python_io.TFRecordOptions"])
@deprecation.deprecated_endpoints("python_io.TFRecordOptions")
class TFRecordOptions(object):
  """Options used for manipulating TFRecord files."""
  compression_type_map = {
      TFRecordCompressionType.ZLIB: "ZLIB",
      TFRecordCompressionType.GZIP: "GZIP",
      TFRecordCompressionType.NONE: ""
  }

  def __init__(self,
               compression_type=None,
               flush_mode=None,
               input_buffer_size=None,
               output_buffer_size=None,
               window_bits=None,
               compression_level=None,
               compression_method=None,
               mem_level=None,
               compression_strategy=None):
    # pylint: disable=line-too-long
    """Creates a `TFRecordOptions` instance.

    Options only affect TFRecordWriter when compression_type is not `None`.
    Documentation, details, and defaults can be found in
    [`zlib_compression_options.h`](https://www.tensorflow.org/code/tensorflow/core/lib/io/zlib_compression_options.h)
    and in the [zlib manual](http://www.zlib.net/manual.html).
    Leaving an option as `None` allows C++ to set a reasonable default.

    Args:
      compression_type: `"GZIP"`, `"ZLIB"`, `""` or `None` (no compression),
        or a `TFRecordCompressionType` value.
      flush_mode: flush mode or `None`. Default: Z_NO_FLUSH.
      input_buffer_size: int or `None`.
      output_buffer_size: int or `None`.
      window_bits: int or `None`.
      compression_level: 0 to 9, or `None`.
      compression_method: compression method or `None`.
      mem_level: 1 to 9, or `None`.
      compression_strategy: strategy or `None`. Default: Z_DEFAULT_STRATEGY.

    Raises:
      ValueError: If compression_type is invalid.
    """
    # pylint: enable=line-too-long
    # Check compression_type is valid, but for backwards compatibility don't
    # immediately convert to a string.
    self.get_compression_type_string(compression_type)
    self.compression_type = compression_type
    self.flush_mode = flush_mode
    self.input_buffer_size = input_buffer_size
    self.output_buffer_size = output_buffer_size
    self.window_bits = window_bits
    self.compression_level = compression_level
    self.compression_method = compression_method
    self.mem_level = mem_level
    self.compression_strategy = compression_strategy

  @classmethod
  def get_compression_type_string(cls, options):
    """Convert various option types to a unified string.

    Args:
      options: `TFRecordOptions`, `TFRecordCompressionType` value, string
        (`"ZLIB"`, `"GZIP"` or `""`), or `None`.

    Returns:
      Compression type as string (e.g. `'ZLIB'`, `'GZIP'`, or `''`).

    Raises:
      ValueError: If compression_type is invalid.
    """
    if not options:
      # None, "" and TFRecordCompressionType.NONE (0) all mean "uncompressed".
      return ""
    if isinstance(options, TFRecordOptions):
      return cls.get_compression_type_string(options.compression_type)
    # TFRecordCompressionType members are plain ints, so they are resolved
    # through the map's keys rather than through an isinstance check.
    try:
      if options in cls.compression_type_map:
        return cls.compression_type_map[options]
    except TypeError:
      # Unhashable inputs (e.g. lists) can never be keys; fall through so they
      # are reported with the documented ValueError below.
      pass
    if options in cls.compression_type_map.values():
      return options
    raise ValueError('Not a valid compression_type: "{}"'.format(options))

  def _as_record_writer_options(self):
    """Convert to RecordWriterOptions for use with PyRecordWriter."""
    options = _pywrap_record_io.RecordWriterOptions(
        compat.as_bytes(
            self.get_compression_type_string(self.compression_type)))

    # Copy only the zlib settings the caller supplied; leaving one as `None`
    # lets the C++ side keep its own default. Order matches the constructor.
    for name in ("flush_mode", "input_buffer_size", "output_buffer_size",
                 "window_bits", "compression_level", "compression_method",
                 "mem_level", "compression_strategy"):
      value = getattr(self, name)
      if value is not None:
        setattr(options.zlib_options, name, value)
    return options
150
151
@tf_export(v1=["io.tf_record_iterator", "python_io.tf_record_iterator"])
@deprecation.deprecated(
    date=None,
    instructions="Use eager execution and: \n`tf.data.TFRecordDataset(path)`")
def tf_record_iterator(path, options=None):
  """Returns an iterator that reads the records from a TFRecords file.

  Args:
    path: The path to the TFRecords file.
    options: (optional) A `TFRecordOptions` object, a compression type string
      (`"ZLIB"`, `"GZIP"` or `""`), or a `TFRecordCompressionType` value.

  Returns:
    An iterator of serialized TFRecords.

  Raises:
    IOError: If `path` cannot be opened for reading.
    ValueError: If a valid compression type cannot be determined from
      `options`.
  """
  return _pywrap_record_io.RecordIterator(
      path, TFRecordOptions.get_compression_type_string(options))
172
173
def tf_record_random_reader(path):
  """Opens a TFRecords file for random-access (offset-based) reading.

  The returned reader exposes a single method:

    - `read(offset)` returns a `(record, ending_offset)` tuple, where `record`
      is the TFRecord stored at `offset` and `ending_offset` is where that
      record ends (which is also where the next record starts).

      It raises `tf.errors.DataLossError` if the data at `offset` is corrupted,
      and `IndexError` if `offset` is out of range for the file.


  Usage example:
  ```py
  reader = tf_record_random_reader(file_path)

  record_1, offset_1 = reader.read(0)  # 0 is the initial offset.
  # offset_1 is the ending offset of the 1st record and the starting offset of
  # the next.

  record_2, offset_2 = reader.read(offset_1)
  # offset_2 is the ending offset of the 2nd record and the starting offset of
  # the next.
  # We can jump back and read the first record again if so desired.
  reader.read(0)
  ```

  Args:
    path: The path to the TFRecords file.

  Returns:
    An object that supports random-access reading of the serialized TFRecords.

  Raises:
    IOError: If `path` cannot be opened for reading.
  """
  reader = _pywrap_record_io.RandomRecordReader(path)
  return reader
213
214
@tf_export(
    "io.TFRecordWriter", v1=["io.TFRecordWriter", "python_io.TFRecordWriter"])
@deprecation.deprecated_endpoints("python_io.TFRecordWriter")
class TFRecordWriter(_pywrap_record_io.RecordWriter):
  """A class to write records to a TFRecords file.

  [TFRecords tutorial](https://www.tensorflow.org/tutorials/load_data/tfrecord)

  TFRecords is a binary format which is optimized for high throughput data
  retrieval, generally in conjunction with `tf.data`. `TFRecordWriter` is used
  to write serialized examples to a file for later consumption. The key steps
  are:

  Ahead of time:

  - [Convert data into a serialized format](
  https://www.tensorflow.org/tutorials/load_data/tfrecord#tfexample)
  - [Write the serialized data to one or more files](
  https://www.tensorflow.org/tutorials/load_data/tfrecord#tfrecord_files_in_python)

  During training or evaluation:

  - [Read serialized examples into memory](
  https://www.tensorflow.org/tutorials/load_data/tfrecord#reading_a_tfrecord_file)
  - [Parse (deserialize) examples](
  https://www.tensorflow.org/tutorials/load_data/tfrecord#reading_a_tfrecord_file)

  A minimal example is given below:

  >>> import tempfile
  >>> example_path = os.path.join(tempfile.gettempdir(), "example.tfrecords")
  >>> np.random.seed(0)

  >>> # Write the records to a file.
  ... with tf.io.TFRecordWriter(example_path) as file_writer:
  ...   for _ in range(4):
  ...     x, y = np.random.random(), np.random.random()
  ...
  ...     record_bytes = tf.train.Example(features=tf.train.Features(feature={
  ...         "x": tf.train.Feature(float_list=tf.train.FloatList(value=[x])),
  ...         "y": tf.train.Feature(float_list=tf.train.FloatList(value=[y])),
  ...     })).SerializeToString()
  ...     file_writer.write(record_bytes)

  >>> # Read the data back out.
  >>> def decode_fn(record_bytes):
  ...   return tf.io.parse_single_example(
  ...       # Data
  ...       record_bytes,
  ...
  ...       # Schema
  ...       {"x": tf.io.FixedLenFeature([], dtype=tf.float32),
  ...        "y": tf.io.FixedLenFeature([], dtype=tf.float32)}
  ...   )

  >>> for batch in tf.data.TFRecordDataset([example_path]).map(decode_fn):
  ...   print("x = {x:.4f},  y = {y:.4f}".format(**batch))
  x = 0.5488,  y = 0.7152
  x = 0.6028,  y = 0.5449
  x = 0.4237,  y = 0.6459
  x = 0.4376,  y = 0.8918

  This class implements `__enter__` and `__exit__`, and can be used
  in `with` blocks like a normal file. (See the usage example above.)
  """

  # TODO(josh11b): Support appending?
  def __init__(self, path, options=None):
    """Opens file `path` and creates a `TFRecordWriter` writing to it.

    Args:
      path: The path to the TFRecords file, as a string, bytes, or a path-like
        object implementing `__fspath__` (e.g. `pathlib.Path`).
      options: (optional) String specifying compression type,
          `TFRecordCompressionType`, or `TFRecordOptions` object.

    Raises:
      IOError: If `path` cannot be opened for writing.
      ValueError: If valid compression_type can't be determined from `options`.
    """
    if not isinstance(options, TFRecordOptions):
      options = TFRecordOptions(compression_type=options)

    # Resolve path-like objects (e.g. pathlib.Path) to their str/bytes form
    # before byte-encoding, mirroring os.fspath without needing a new import.
    if hasattr(path, "__fspath__"):
      path = path.__fspath__()

    # pylint: disable=protected-access
    super(TFRecordWriter, self).__init__(
        compat.as_bytes(path), options._as_record_writer_options())
    # pylint: enable=protected-access

  # TODO(slebedev): The following wrapper methods are there to compensate
  # for lack of signatures in pybind11-generated classes. Switch to
  # __text_signature__ when TensorFlow drops Python 2.X support.
  # See https://github.com/pybind/pybind11/issues/945
  # pylint: disable=useless-super-delegation
  def write(self, record):
    """Write a serialized record to the file.

    Args:
      record: The record to write, e.g. the `bytes` produced by
        `tf.train.Example(...).SerializeToString()`.
    """
    super(TFRecordWriter, self).write(record)

  def flush(self):
    """Flush the file."""
    super(TFRecordWriter, self).flush()

  def close(self):
    """Close the file."""
    super(TFRecordWriter, self).close()
  # pylint: enable=useless-super-delegation
323