1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""For reading and writing TFRecords files.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22from tensorflow.python import pywrap_tensorflow 23from tensorflow.python.framework import errors 24from tensorflow.python.util import compat 25from tensorflow.python.util.tf_export import tf_export 26 27 28@tf_export("python_io.TFRecordCompressionType") 29class TFRecordCompressionType(object): 30 """The type of compression for the record.""" 31 NONE = 0 32 ZLIB = 1 33 GZIP = 2 34 35 36# NOTE(vrv): This will eventually be converted into a proto. to match 37# the interface used by the C++ RecordWriter. 
@tf_export("python_io.TFRecordOptions")
class TFRecordOptions(object):
  """Options used for manipulating TFRecord files.

  Currently the only option is the compression type, which is translated into
  a string and handed to the native record reader/writer constructors.
  """
  # Maps each TFRecordCompressionType code to the compression-type string
  # passed (as bytes) to the native reader/writer constructors.
  compression_type_map = {
      TFRecordCompressionType.ZLIB: "ZLIB",
      TFRecordCompressionType.GZIP: "GZIP",
      TFRecordCompressionType.NONE: ""
  }

  def __init__(self, compression_type):
    """Creates options using the given compression type.

    Args:
      compression_type: One of the `TFRecordCompressionType` constants.
    """
    self.compression_type = compression_type

  @classmethod
  def get_compression_type_string(cls, options):
    """Returns the compression-type string for `options`.

    Args:
      options: A `TFRecordOptions` object, or a falsy value such as `None`.

    Returns:
      The entry of `compression_type_map` for `options.compression_type`, or
      "" when `options` is falsy.

    Raises:
      KeyError: If `options.compression_type` is not a key of
        `compression_type_map`.
    """
    if not options:
      return ""
    return cls.compression_type_map[options.compression_type]


@tf_export("python_io.tf_record_iterator")
def tf_record_iterator(path, options=None):
  """An iterator that reads the records from a TFRecords file.

  This is a generator: the file is only opened when iteration starts, so the
  `IOError` below is raised on the first `next()` call, not at call time. The
  underlying native reader is closed when iteration finishes, when a read
  error propagates, and when the generator is closed (explicitly or by
  garbage collection) before being exhausted.

  Args:
    path: The path to the TFRecords file.
    options: (optional) A TFRecordOptions object.

  Yields:
    Strings.

  Raises:
    IOError: If `path` cannot be opened for reading.
  """
  compression_type = TFRecordOptions.get_compression_type_string(options)
  with errors.raise_exception_on_not_ok_status() as status:
    reader = pywrap_tensorflow.PyRecordReader_New(
        compat.as_bytes(path), 0, compat.as_bytes(compression_type), status)

  if reader is None:
    raise IOError("Could not open %s." % path)
  try:
    while True:
      try:
        with errors.raise_exception_on_not_ok_status() as status:
          reader.GetNext(status)
      except errors.OutOfRangeError:
        # OutOfRangeError is treated as the end of the input.
        break
      yield reader.record()
  finally:
    # Previously Close() was only reached after OutOfRangeError, so a consumer
    # that stopped early, or any other read error, leaked the native reader.
    reader.Close()


@tf_export("python_io.TFRecordWriter")
class TFRecordWriter(object):
  """A class to write records to a TFRecords file.

  This class implements `__enter__` and `__exit__`, and can be used
  in `with` blocks like a normal file.
  """

  # TODO(josh11b): Support appending?
  def __init__(self, path, options=None):
    """Opens file `path` and creates a `TFRecordWriter` writing to it.

    Args:
      path: The path to the TFRecords file.
      options: (optional) A TFRecordOptions object.

    Raises:
      IOError: If `path` cannot be opened for writing.
    """
    compression_type = TFRecordOptions.get_compression_type_string(options)

    with errors.raise_exception_on_not_ok_status() as status:
      self._writer = pywrap_tensorflow.PyRecordWriter_New(
          compat.as_bytes(path), compat.as_bytes(compression_type), status)

  def __enter__(self):
    """Enter a `with` block."""
    return self

  def __exit__(self, unused_type, unused_value, unused_traceback):
    """Exit a `with` block, closing the file."""
    self.close()

  def write(self, record):
    """Write a string record to the file.

    Args:
      record: str
    """
    # NOTE(review): unlike flush() and close(), no status is checked here, so
    # a failed write may go unreported by this call -- confirm against the
    # PyRecordWriter wrapper's WriteRecord signature.
    self._writer.WriteRecord(record)

  def flush(self):
    """Flush the file."""
    with errors.raise_exception_on_not_ok_status() as status:
      self._writer.Flush(status)

  def close(self):
    """Close the file."""
    with errors.raise_exception_on_not_ok_status() as status:
      self._writer.Close(status)