• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""For reading and writing TFRecords files."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22from tensorflow.python import pywrap_tensorflow
23from tensorflow.python.framework import errors
24from tensorflow.python.util import compat
25from tensorflow.python.util import deprecation
26from tensorflow.python.util.tf_export import tf_export
27
28
@tf_export(
    "io.TFRecordCompressionType",
    v1=["io.TFRecordCompressionType", "python_io.TFRecordCompressionType"])
@deprecation.deprecated_endpoints("python_io.TFRecordCompressionType")
class TFRecordCompressionType(object):
  """Integer codes identifying how a TFRecord file is compressed."""
  # No compression.
  NONE = 0
  # ZLIB compression.
  ZLIB = 1
  # GZIP compression.
  GZIP = 2
38
39
@tf_export(
    "io.TFRecordOptions",
    v1=["io.TFRecordOptions", "python_io.TFRecordOptions"])
@deprecation.deprecated_endpoints("python_io.TFRecordOptions")
class TFRecordOptions(object):
  """Options used for manipulating TFRecord files."""
  # Maps each `TFRecordCompressionType` code to the compression-type string
  # passed down to the record reader/writer.
  compression_type_map = {
      TFRecordCompressionType.ZLIB: "ZLIB",
      TFRecordCompressionType.GZIP: "GZIP",
      TFRecordCompressionType.NONE: ""
  }

  def __init__(self,
               compression_type=None,
               flush_mode=None,
               input_buffer_size=None,
               output_buffer_size=None,
               window_bits=None,
               compression_level=None,
               compression_method=None,
               mem_level=None,
               compression_strategy=None):
    # pylint: disable=line-too-long
    """Creates a `TFRecordOptions` instance.

    Options only affect TFRecordWriter when compression_type is not `None`.
    Documentation, details, and defaults can be found in
    [`zlib_compression_options.h`](https://www.tensorflow.org/code/tensorflow/core/lib/io/zlib_compression_options.h)
    and in the [zlib manual](http://www.zlib.net/manual.html).
    Leaving an option as `None` allows C++ to set a reasonable default.

    Args:
      compression_type: `TFRecordCompressionType` or `None`.
      flush_mode: flush mode or `None`, Default: Z_NO_FLUSH.
      input_buffer_size: int or `None`.
      output_buffer_size: int or `None`.
      window_bits: int or `None`.
      compression_level: 0 to 9, or `None`.
      compression_method: compression method or `None`.
      mem_level: 1 to 9, or `None`.
      compression_strategy: strategy or `None`. Default: Z_DEFAULT_STRATEGY.

    Returns:
      A `TFRecordOptions` object.

    Raises:
      ValueError: If compression_type is invalid.
    """
    # pylint: enable=line-too-long
    # Check compression_type is valid, but for backwards compatibility don't
    # immediately convert to a string.
    self.get_compression_type_string(compression_type)
    self.compression_type = compression_type
    self.flush_mode = flush_mode
    self.input_buffer_size = input_buffer_size
    self.output_buffer_size = output_buffer_size
    self.window_bits = window_bits
    self.compression_level = compression_level
    self.compression_method = compression_method
    self.mem_level = mem_level
    self.compression_strategy = compression_strategy

  @classmethod
  def get_compression_type_string(cls, options):
    """Convert various option types to a unified string.

    Args:
      options: `TFRecordOptions`, `TFRecordCompressionType`, or string.
        Any falsy value (`None`, `""`, `TFRecordCompressionType.NONE`) means
        "no compression".

    Returns:
      Compression type as string (e.g. `'ZLIB'`, `'GZIP'`, or `''`).

    Raises:
      ValueError: If compression_type is invalid.
    """
    if not options:
      return ""
    if isinstance(options, TFRecordOptions):
      # Unwrap an options object and validate the type it carries.
      return cls.get_compression_type_string(options.compression_type)
    try:
      # Compression-type codes (the map's keys) translate to their string.
      if options in cls.compression_type_map:
        return cls.compression_type_map[options]
      # Strings that are already valid compression types pass through.
      if options in cls.compression_type_map.values():
        return options
    except TypeError:
      # Unhashable values (e.g. a non-empty list) cannot be looked up in the
      # map; report them as invalid instead of leaking a TypeError.
      pass
    raise ValueError('Not a valid compression_type: "{}"'.format(options))

  def _as_record_writer_options(self):
    """Convert to RecordWriterOptions for use with PyRecordWriter."""
    options = pywrap_tensorflow.RecordWriterOptions_CreateRecordWriterOptions(
        compat.as_bytes(
            self.get_compression_type_string(self.compression_type)))

    # Copy over only the settings the caller supplied; fields left as `None`
    # are not touched on the native options object.
    for name in ("flush_mode", "input_buffer_size", "output_buffer_size",
                 "window_bits", "compression_level", "compression_method",
                 "mem_level", "compression_strategy"):
      value = getattr(self, name)
      if value is not None:
        setattr(options.zlib_options, name, value)
    return options
151
152
@tf_export(v1=["io.tf_record_iterator", "python_io.tf_record_iterator"])
@deprecation.deprecated(
    date=None,
    instructions=("Use eager execution and: \n"
                  "`tf.data.TFRecordDataset(path)`"))
def tf_record_iterator(path, options=None):
  """Generator that reads the records of a TFRecords file one at a time.

  The reader is closed when iteration finishes or the generator is closed.

  Args:
    path: The path to the TFRecords file.
    options: (optional) Anything accepted by
      `TFRecordOptions.get_compression_type_string`: a `TFRecordOptions`
      object, a `TFRecordCompressionType` value, or a compression-type string.

  Yields:
    Strings.

  Raises:
    IOError: If `path` cannot be opened for reading.
  """
  compression = TFRecordOptions.get_compression_type_string(options)
  with errors.raise_exception_on_not_ok_status() as tf_status:
    path_bytes = compat.as_bytes(path)
    compression_bytes = compat.as_bytes(compression)
    # NOTE(review): the literal 0 is presumably the start offset -- confirm
    # against PyRecordReader_New.
    record_reader = pywrap_tensorflow.PyRecordReader_New(
        path_bytes, 0, compression_bytes, tf_status)

  if record_reader is None:
    raise IOError("Could not open %s." % path)

  try:
    exhausted = False
    while not exhausted:
      try:
        record_reader.GetNext()
      except errors.OutOfRangeError:
        # GetNext signals end-of-file by raising OutOfRangeError.
        exhausted = True
      else:
        yield record_reader.record()
  finally:
    record_reader.Close()
187
188
@tf_export(
    "io.TFRecordWriter", v1=["io.TFRecordWriter", "python_io.TFRecordWriter"])
@deprecation.deprecated_endpoints("python_io.TFRecordWriter")
class TFRecordWriter(object):
  """Writes records to a TFRecords file.

  Instances are context managers: leaving a `with` block closes the file,
  just as it would for a regular file object.
  """

  # TODO(josh11b): Support appending?
  def __init__(self, path, options=None):
    """Opens file `path` and creates a `TFRecordWriter` writing to it.

    Args:
      path: The path to the TFRecords file.
      options: (optional) String specifying compression type,
          `TFRecordCompressionType`, or `TFRecordOptions` object.

    Raises:
      IOError: If `path` cannot be opened for writing.
      ValueError: If valid compression_type can't be determined from `options`.
    """
    # Anything that is not already a TFRecordOptions is treated as a bare
    # compression type and wrapped (which also validates it).
    record_options = (
        options if isinstance(options, TFRecordOptions)
        else TFRecordOptions(compression_type=options))

    with errors.raise_exception_on_not_ok_status() as tf_status:
      path_bytes = compat.as_bytes(path)
      # pylint: disable=protected-access
      native_options = record_options._as_record_writer_options()
      # pylint: enable=protected-access
      self._writer = pywrap_tensorflow.PyRecordWriter_New(
          path_bytes, native_options, tf_status)

  def __enter__(self):
    """Enter a `with` block, returning this writer."""
    return self

  def __exit__(self, unused_type, unused_value, unused_traceback):
    """Exit a `with` block, closing the file."""
    self.close()

  def write(self, record):
    """Write a string record to the file.

    Args:
      record: str
    """
    with errors.raise_exception_on_not_ok_status() as tf_status:
      self._writer.WriteRecord(record, tf_status)

  def flush(self):
    """Flush the file."""
    with errors.raise_exception_on_not_ok_status() as tf_status:
      self._writer.Flush(tf_status)

  def close(self):
    """Close the file."""
    with errors.raise_exception_on_not_ok_status() as tf_status:
      self._writer.Close(tf_status)
247