• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15
16"""For reading and writing TFRecords files."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22from tensorflow.python import pywrap_tensorflow
23from tensorflow.python.framework import errors
24from tensorflow.python.util import compat
25from tensorflow.python.util.tf_export import tf_export
26
27
@tf_export("python_io.TFRecordCompressionType")
class TFRecordCompressionType(object):
  """Integer codes naming the compression applied to TFRecord records.

  `TFRecordOptions.compression_type_map` translates each code into the
  compression-type string that is handed to the native record reader/writer.
  """
  # Codes: NONE -> uncompressed, ZLIB -> zlib, GZIP -> gzip.
  NONE, ZLIB, GZIP = 0, 1, 2
34
35
# NOTE(vrv): This will eventually be converted into a proto to match
# the interface used by the C++ RecordWriter.
@tf_export("python_io.TFRecordOptions")
class TFRecordOptions(object):
  """Options used for manipulating TFRecord files."""
  # Maps each `TFRecordCompressionType` code to the compression-type string
  # passed to the native record reader/writer ("" means uncompressed).
  compression_type_map = {
      TFRecordCompressionType.ZLIB: "ZLIB",
      TFRecordCompressionType.GZIP: "GZIP",
      TFRecordCompressionType.NONE: ""
  }

  def __init__(self, compression_type=None):
    """Creates a `TFRecordOptions`.

    Args:
      compression_type: (optional) A `TFRecordCompressionType` value, one of
        the strings "ZLIB", "GZIP" or "", or `None` (the default) for no
        compression.
    """
    self.compression_type = compression_type

  @classmethod
  def get_compression_type_string(cls, options):
    """Converts `options` to the compression-type string used natively.

    Args:
      options: A `TFRecordOptions` (or any object with a `compression_type`
        attribute), a bare `TFRecordCompressionType` value, one of the
        strings "ZLIB", "GZIP" or "", or `None`.

    Returns:
      "ZLIB", "GZIP", or "" when no compression applies.

    Raises:
      KeyError: If `options` does not name a known compression type. KeyError
        is kept for compatibility with the previous dict-lookup failure.
    """
    if not options:
      return ""
    # Unwrap an options object; bare codes and strings are used as-is.
    compression_type = getattr(options, "compression_type", options)
    if compression_type is None:
      return ""
    if compression_type in cls.compression_type_map:
      return cls.compression_type_map[compression_type]
    if compression_type in cls.compression_type_map.values():
      # Already one of the canonical compression-type strings.
      return compression_type
    raise KeyError("Not a valid compression_type: %r" % (compression_type,))
55
56
@tf_export("python_io.tf_record_iterator")
def tf_record_iterator(path, options=None):
  """An iterator that reads the records from a TFRecords file.

  Because this is a generator, `path` is not opened (and no error is raised)
  until the first record is requested. The underlying reader is closed when
  iteration finishes, when it fails, or when the generator is closed early
  (e.g. the caller `break`s out of a `for` loop).

  Args:
    path: The path to the TFRecords file.
    options: (optional) A TFRecordOptions object.

  Yields:
    Strings.

  Raises:
    IOError: If `path` cannot be opened for reading.
    Exceptions raised by `errors.raise_exception_on_not_ok_status` when the
    native reader reports a non-OK status also propagate; an
    `errors.OutOfRangeError` from `GetNext` simply ends the iteration.
  """
  compression_type = TFRecordOptions.get_compression_type_string(options)
  with errors.raise_exception_on_not_ok_status() as status:
    reader = pywrap_tensorflow.PyRecordReader_New(
        compat.as_bytes(path), 0, compat.as_bytes(compression_type), status)

  if reader is None:
    raise IOError("Could not open %s." % path)
  # try/finally so the reader is released even if the consumer abandons the
  # generator at `yield` (GeneratorExit) or GetNext raises a non-EOF error.
  try:
    while True:
      try:
        with errors.raise_exception_on_not_ok_status() as status:
          reader.GetNext(status)
      except errors.OutOfRangeError:
        break
      yield reader.record()
  finally:
    reader.Close()
86
87
@tf_export("python_io.TFRecordWriter")
class TFRecordWriter(object):
  """A class to write records to a TFRecords file.

  This class implements `__enter__` and `__exit__`, and can be used
  in `with` blocks like a normal file.
  """

  # TODO(josh11b): Support appending?
  def __init__(self, path, options=None):
    """Opens file `path` and creates a `TFRecordWriter` writing to it.

    Args:
      path: The path to the TFRecords file.
      options: (optional) A TFRecordOptions object.

    Raises:
      IOError: If `path` cannot be opened for writing.
    """
    compression_type = TFRecordOptions.get_compression_type_string(options)

    with errors.raise_exception_on_not_ok_status() as status:
      self._writer = pywrap_tensorflow.PyRecordWriter_New(
          compat.as_bytes(path), compat.as_bytes(compression_type), status)
    # Mirror tf_record_iterator: fail at construction instead of with an
    # AttributeError on the first write if no writer was created.
    if self._writer is None:
      raise IOError("Could not open %s for writing." % path)
    # Set once close() has succeeded so that repeated close() calls are
    # no-ops, as with a normal Python file.
    self._closed = False

  def __enter__(self):
    """Enter a `with` block."""
    return self

  def __exit__(self, unused_type, unused_value, unused_traceback):
    """Exit a `with` block, closing the file."""
    self.close()

  def write(self, record):
    """Write a string record to the file.

    Args:
      record: str
    """
    # NOTE(review): the return value of WriteRecord is discarded here;
    # confirm whether the native writer uses it to signal a failed write.
    self._writer.WriteRecord(record)

  def flush(self):
    """Flush the file.

    Raises:
      Whatever `errors.raise_exception_on_not_ok_status` raises if the native
      flush reports a non-OK status.
    """
    with errors.raise_exception_on_not_ok_status() as status:
      self._writer.Flush(status)

  def close(self):
    """Close the file.

    Calls after the first successful close do nothing, matching normal Python
    files. This matters because `__exit__` always calls `close()`, so an
    explicit `close()` inside a `with` block would otherwise close the native
    writer twice. If closing fails, the writer is not marked closed and the
    error from `errors.raise_exception_on_not_ok_status` propagates.
    """
    if self._closed:
      return
    with errors.raise_exception_on_not_ok_status() as status:
      self._writer.Close(status)
    self._closed = True
138