# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.TextLineDataset`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip
import os
import pathlib
import zlib

from absl.testing import parameterized

from tensorflow.python.data.kernel_tests import checkpoint_test_base
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.framework import combinations
from tensorflow.python.platform import test
from tensorflow.python.util import compat


# psutil is optional: it is only used to verify that file handles are
# released after iterator destruction. Tests that need it skip themselves
# when the import fails.
try:
  import psutil  # pylint: disable=g-import-not-at-top
  psutil_import_succeeded = True
except ImportError:
  psutil_import_succeeded = False


class TextLineDatasetTestBase(test_base.DatasetTestBase):
  """Base class for setting up and testing TextLineDataset."""

  def _lineText(self, f, l):
    """Returns the canonical bytes content for line `l` of file `f`.

    For example, `_lineText(2, 3)` returns `b"2: 3"`, which makes every
    (file, line) pair distinguishable in expected-output lists.
    """
    return compat.as_bytes("%d: %d" % (f, l))

  def _createFiles(self,
                   num_files,
                   num_lines,
                   crlf=False,
                   compression_type=None):
    """Writes `num_files` test text files into the test's temp directory.

    Each file holds `num_lines` records produced by `_lineText`.

    Args:
      num_files: Number of files to create.
      num_lines: Number of text lines per file.
      crlf: If True, records are terminated with CRLF instead of LF.
      compression_type: One of None (plain text), "GZIP", or "ZLIB".

    Returns:
      A list of the created file paths, in file-index order.

    Raises:
      ValueError: If `compression_type` is not one of the supported values.
    """
    filenames = []
    for i in range(num_files):
      fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i)
      filenames.append(fn)
      contents = []
      for j in range(num_lines):
        contents.append(self._lineText(i, j))
        # Always include a newline after the record unless it is the last
        # record of the file; the trailing newline at end-of-file is written
        # only for the first file (i == 0), so both "file ends with newline"
        # and "file ends without newline" layouts are exercised.
        if j + 1 != num_lines or i == 0:
          contents.append(b"\r\n" if crlf else b"\n")
      contents = b"".join(contents)

      if not compression_type:
        with open(fn, "wb") as f:
          f.write(contents)
      elif compression_type == "GZIP":
        with gzip.GzipFile(fn, "wb") as f:
          f.write(contents)
      elif compression_type == "ZLIB":
        # zlib has no file wrapper, so compress in memory and write the
        # raw compressed bytes.
        contents = zlib.compress(contents)
        with open(fn, "wb") as f:
          f.write(contents)
      else:
        raise ValueError("Unsupported compression_type", compression_type)

    return filenames


class TextLineDatasetTest(TextLineDatasetTestBase, parameterized.TestCase):
  """Functional tests for `TextLineDataset` reading and batching."""

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(compression_type=[None, "GZIP", "ZLIB"])))
  def testTextLineDataset(self, compression_type):
    """Reads single files, multiple files, repeats, and batches."""
    test_filenames = self._createFiles(
        2, 5, crlf=True, compression_type=compression_type)

    def dataset_fn(filenames, num_epochs, batch_size=None):
      # Builds the dataset under test; batching is applied only when a
      # batch_size is given.
      repeat_dataset = readers.TextLineDataset(
          filenames, compression_type=compression_type).repeat(num_epochs)
      if batch_size:
        return repeat_dataset.batch(batch_size)
      return repeat_dataset

    # Basic test: read from file 0.
    expected_output = [self._lineText(0, i) for i in range(5)]
    self.assertDatasetProduces(
        dataset_fn([test_filenames[0]], 1), expected_output=expected_output)

    # Basic test: read from file 1.
    self.assertDatasetProduces(
        dataset_fn([test_filenames[1]], 1),
        expected_output=[self._lineText(1, i) for i in range(5)])

    # Basic test: read from both files.
    expected_output = [self._lineText(0, i) for i in range(5)]
    expected_output.extend(self._lineText(1, i) for i in range(5))
    self.assertDatasetProduces(
        dataset_fn(test_filenames, 1), expected_output=expected_output)

    # Test repeated iteration through both files.
    expected_output = [self._lineText(0, i) for i in range(5)]
    expected_output.extend(self._lineText(1, i) for i in range(5))
    self.assertDatasetProduces(
        dataset_fn(test_filenames, 10), expected_output=expected_output * 10)

    # Test batched and repeated iteration through both files.
    self.assertDatasetProduces(
        dataset_fn(test_filenames, 10, 5),
        expected_output=[[self._lineText(0, i) for i in range(5)],
                         [self._lineText(1, i) for i in range(5)]] * 10)

  @combinations.generate(test_base.default_test_combinations())
  def testTextLineDatasetParallelRead(self):
    """Reads 10 files x 10 epochs with num_parallel_reads=4.

    Parallel reads may interleave lines across files, so the output is
    checked as a multiset (assert_items_equal=True) rather than in order.
    """
    test_filenames = self._createFiles(10, 10)
    files = dataset_ops.Dataset.from_tensor_slices(test_filenames).repeat(10)
    expected_output = []
    for j in range(10):
      expected_output.extend(self._lineText(j, i) for i in range(10))
    dataset = readers.TextLineDataset(files, num_parallel_reads=4)
    self.assertDatasetProduces(
        dataset, expected_output=expected_output * 10, assert_items_equal=True)

  @combinations.generate(test_base.default_test_combinations())
  def testTextLineDatasetBuffering(self):
    """Reads CRLF files through a small (10-byte) read buffer."""
    test_filenames = self._createFiles(2, 5, crlf=True)

    repeat_dataset = readers.TextLineDataset(test_filenames, buffer_size=10)
    expected_output = []
    for j in range(2):
      expected_output.extend([self._lineText(j, i) for i in range(5)])
    self.assertDatasetProduces(repeat_dataset, expected_output=expected_output)

  @combinations.generate(test_base.eager_only_combinations())
  def testIteratorResourceCleanup(self):
    """Checks that dropping eager iterators closes their open files."""
    filename = os.path.join(self.get_temp_dir(), "text.txt")
    with open(filename, "wt") as f:
      for i in range(3):
        f.write("%d\n" % (i,))
    first_iterator = iter(readers.TextLineDataset(filename))
    self.assertEqual(b"0", next(first_iterator).numpy())
    second_iterator = iter(readers.TextLineDataset(filename))
    self.assertEqual(b"0", next(second_iterator).numpy())
    # Eager kernel caching is based on op attributes, which includes the
    # Dataset's output shape. Create a different kernel to test that they
    # don't create resources with the same names.
    different_kernel_iterator = iter(
        readers.TextLineDataset(filename).repeat().batch(16))
    self.assertEqual([16], next(different_kernel_iterator).shape)
    # Remove our references to the Python Iterator objects, which (assuming no
    # reference cycles) is enough to trigger DestroyResourceOp and close the
    # partially-read files.
    del first_iterator
    del second_iterator
    del different_kernel_iterator
    if not psutil_import_succeeded:
      self.skipTest(
          "psutil is required to check that we've closed our files.")
    open_files = psutil.Process().open_files()
    self.assertNotIn(filename, [open_file.path for open_file in open_files])

  @combinations.generate(test_base.default_test_combinations())
  def testTextLineDatasetPathlib(self):
    """Accepts `pathlib.Path` objects in place of string filenames."""
    files = self._createFiles(1, 5)
    files = [pathlib.Path(f) for f in files]

    expected_output = [self._lineText(0, i) for i in range(5)]
    ds = readers.TextLineDataset(files)
    self.assertDatasetProduces(
        ds, expected_output=expected_output, assert_items_equal=True)


class TextLineDatasetCheckpointTest(TextLineDatasetTestBase,
                                    checkpoint_test_base.CheckpointTestBase,
                                    parameterized.TestCase):
  """Save/restore (checkpointing) tests for `TextLineDataset` iterators."""

  def _build_iterator_graph(self, test_filenames, compression_type=None):
    # Dataset factory handed to the checkpoint verify_fn; buffer_size=10
    # forces checkpoints to land mid-buffer.
    return readers.TextLineDataset(
        test_filenames, compression_type=compression_type, buffer_size=10)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          checkpoint_test_base.default_test_combinations(),
          combinations.combine(compression_type=[None, "GZIP", "ZLIB"])))
  def test(self, verify_fn, compression_type):
    """Verifies checkpoint/restore across files and compression types."""
    num_files = 5
    lines_per_file = 5
    num_outputs = num_files * lines_per_file
    test_filenames = self._createFiles(
        num_files, lines_per_file, crlf=True, compression_type=compression_type)
    verify_fn(
        self,
        lambda: self._build_iterator_graph(test_filenames, compression_type),
        num_outputs)


if __name__ == "__main__":
  test.main()