# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""For reading and writing TFRecords files."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import errors
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@tf_export(
    "io.TFRecordCompressionType",
    v1=["io.TFRecordCompressionType", "python_io.TFRecordCompressionType"])
@deprecation.deprecated_endpoints("python_io.TFRecordCompressionType")
class TFRecordCompressionType(object):
  """The type of compression for the record."""
  NONE = 0
  ZLIB = 1
  GZIP = 2


@tf_export(
    "io.TFRecordOptions",
    v1=["io.TFRecordOptions", "python_io.TFRecordOptions"])
@deprecation.deprecated_endpoints("python_io.TFRecordOptions")
class TFRecordOptions(object):
  """Options used for manipulating TFRecord files."""
  # Maps each `TFRecordCompressionType` constant to the string form that is
  # handed to the underlying reader/writer.
  compression_type_map = {
      TFRecordCompressionType.ZLIB: "ZLIB",
      TFRecordCompressionType.GZIP: "GZIP",
      TFRecordCompressionType.NONE: ""
  }

  def __init__(self,
               compression_type=None,
               flush_mode=None,
               input_buffer_size=None,
               output_buffer_size=None,
               window_bits=None,
               compression_level=None,
               compression_method=None,
               mem_level=None,
               compression_strategy=None):
    # pylint: disable=line-too-long
    """Creates a `TFRecordOptions` instance.

    Options only affect TFRecordWriter when compression_type is not `None`.
    Documentation, details, and defaults can be found in
    [`zlib_compression_options.h`](https://www.tensorflow.org/code/tensorflow/core/lib/io/zlib_compression_options.h)
    and in the [zlib manual](http://www.zlib.net/manual.html).
    Leaving an option as `None` allows C++ to set a reasonable default.

    Args:
      compression_type: `TFRecordCompressionType` or `None`.
      flush_mode: flush mode or `None`, Default: Z_NO_FLUSH.
      input_buffer_size: int or `None`.
      output_buffer_size: int or `None`.
      window_bits: int or `None`.
      compression_level: 0 to 9, or `None`.
      compression_method: compression method or `None`.
      mem_level: 1 to 9, or `None`.
      compression_strategy: strategy or `None`. Default: Z_DEFAULT_STRATEGY.

    Returns:
      A `TFRecordOptions` object.

    Raises:
      ValueError: If compression_type is invalid.
    """
    # pylint: enable=line-too-long
    # Check compression_type is valid, but for backwards compatibility don't
    # immediately convert to a string.
    self.get_compression_type_string(compression_type)
    self.compression_type = compression_type
    self.flush_mode = flush_mode
    self.input_buffer_size = input_buffer_size
    self.output_buffer_size = output_buffer_size
    self.window_bits = window_bits
    self.compression_level = compression_level
    self.compression_method = compression_method
    self.mem_level = mem_level
    self.compression_strategy = compression_strategy

  @classmethod
  def get_compression_type_string(cls, options):
    """Convert various option types to a unified string.

    Args:
      options: `TFRecordOption`, `TFRecordCompressionType`, or string.
        Any falsy value (e.g. `None`, `""` or `TFRecordCompressionType.NONE`)
        means "no compression".

    Returns:
      Compression type as string (e.g. `'ZLIB'`, `'GZIP'`, or `''`).

    Raises:
      ValueError: If compression_type is invalid.
    """
    if not options:
      return ""
    if isinstance(options, TFRecordOptions):
      return cls.get_compression_type_string(options.compression_type)

    # A `TFRecordCompressionType` constant is a key of the map. Anything that
    # is not a key (KeyError) or cannot be hashed at all (TypeError) falls
    # through, so that every invalid value ends in the documented ValueError
    # rather than leaking a lookup error.
    try:
      return cls.compression_type_map[options]
    except (KeyError, TypeError):
      pass

    # Already in string form, e.g. "ZLIB" or "GZIP".
    if options in cls.compression_type_map.values():
      return options
    raise ValueError('Not a valid compression_type: "{}"'.format(options))

  def _as_record_writer_options(self):
    """Convert to RecordWriterOptions for use with PyRecordWriter.

    Only options that were explicitly set (i.e. are not `None`) are copied
    onto the `zlib_options` of the returned object; the others are left as
    created.

    Returns:
      The object produced by
      `pywrap_tensorflow.RecordWriterOptions_CreateRecordWriterOptions`.

    Raises:
      ValueError: If `self.compression_type` is invalid.
    """
    options = pywrap_tensorflow.RecordWriterOptions_CreateRecordWriterOptions(
        compat.as_bytes(
            self.get_compression_type_string(self.compression_type)))

    # Each of these attributes has the same name on `self` and on
    # `options.zlib_options`.
    for name in ("flush_mode", "input_buffer_size", "output_buffer_size",
                 "window_bits", "compression_level", "compression_method",
                 "mem_level", "compression_strategy"):
      value = getattr(self, name)
      if value is not None:
        setattr(options.zlib_options, name, value)
    return options


@tf_export(v1=["io.tf_record_iterator", "python_io.tf_record_iterator"])
@deprecation.deprecated(
    date=None,
    instructions=("Use eager execution and: \n"
                  "`tf.data.TFRecordDataset(path)`"))
def tf_record_iterator(path, options=None):
  """An iterator that reads the records from a TFRecords file.

  This is a generator: the file is opened when iteration starts and the
  reader is closed when iteration finishes or the generator is closed.

  Args:
    path: The path to the TFRecords file.
    options: (optional) A `TFRecordOptions` object. A
      `TFRecordCompressionType` value or a compression type string is
      accepted as well.

  Yields:
    Strings.

  Raises:
    IOError: If `path` cannot be opened for reading.
    ValueError: If a valid compression type can't be determined from
      `options`.
  """
  compression_type = TFRecordOptions.get_compression_type_string(options)
  with errors.raise_exception_on_not_ok_status() as status:
    reader = pywrap_tensorflow.PyRecordReader_New(
        compat.as_bytes(path), 0, compat.as_bytes(compression_type), status)

  if reader is None:
    raise IOError("Could not open %s." % path)
  try:
    while True:
      try:
        reader.GetNext()
      except errors.OutOfRangeError:
        # End of file reached.
        break
      yield reader.record()
  finally:
    reader.Close()


@tf_export(
    "io.TFRecordWriter", v1=["io.TFRecordWriter", "python_io.TFRecordWriter"])
@deprecation.deprecated_endpoints("python_io.TFRecordWriter")
class TFRecordWriter(object):
  """A class to write records to a TFRecords file.

  This class implements `__enter__` and `__exit__`, and can be used
  in `with` blocks like a normal file.
  """

  # TODO(josh11b): Support appending?
  def __init__(self, path, options=None):
    """Opens file `path` and creates a `TFRecordWriter` writing to it.

    Args:
      path: The path to the TFRecords file.
      options: (optional) String specifying compression type,
          `TFRecordCompressionType`, or `TFRecordOptions` object.

    Raises:
      IOError: If `path` cannot be opened for writing.
      ValueError: If valid compression_type can't be determined from `options`.
    """
    # Anything that is not already a `TFRecordOptions` is treated as a
    # compression type and wrapped, which also validates it.
    if not isinstance(options, TFRecordOptions):
      options = TFRecordOptions(compression_type=options)

    with errors.raise_exception_on_not_ok_status() as status:
      # pylint: disable=protected-access
      self._writer = pywrap_tensorflow.PyRecordWriter_New(
          compat.as_bytes(path), options._as_record_writer_options(), status)
      # pylint: enable=protected-access

  def __enter__(self):
    """Enter a `with` block."""
    return self

  def __exit__(self, unused_type, unused_value, unused_traceback):
    """Exit a `with` block, closing the file."""
    self.close()

  def write(self, record):
    """Write a string record to the file.

    Args:
      record: str
    """
    with errors.raise_exception_on_not_ok_status() as status:
      self._writer.WriteRecord(record, status)

  def flush(self):
    """Flush the file."""
    with errors.raise_exception_on_not_ok_status() as status:
      self._writer.Flush(status)

  def close(self):
    """Close the file."""
    with errors.raise_exception_on_not_ok_status() as status:
      self._writer.Close(status)