# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for `tf.data.TextLineDataset`."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import gzip
21import os
22import pathlib
23import zlib
24
25from absl.testing import parameterized
26
27from tensorflow.python.data.kernel_tests import checkpoint_test_base
28from tensorflow.python.data.kernel_tests import test_base
29from tensorflow.python.data.ops import dataset_ops
30from tensorflow.python.data.ops import readers
31from tensorflow.python.framework import combinations
32from tensorflow.python.platform import test
33from tensorflow.python.util import compat
34
35
try:
  import psutil  # pylint: disable=g-import-not-at-top
  psutil_import_succeeded = True
except ImportError:
  psutil_import_succeeded = False


class TextLineDatasetTestBase(test_base.DatasetTestBase):
  """Base class for setting up and testing TextLineDataset."""

  def _lineText(self, f, l):
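    """Returns the expected byte contents of line `l` (0-based) of file `f`."""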
    return compat.as_bytes("%d: %d" % (f, l))

  def _createFiles(self,
                   num_files,
                   num_lines,
                   crlf=False,
                   compression_type=None):
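    """Writes `num_files` temporary text files of `num_lines` lines each.

    Args:
      num_files: Number of files to create.
      num_lines: Number of lines to write to each file.
      crlf: If True, terminate lines with CRLF rather than LF.
      compression_type: None, "GZIP", or "ZLIB".

    Returns:
      A list of the created file paths.
    """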
    filenames = []
    for i in range(num_files):
      fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i)
      filenames.append(fn)
      contents = []
      for j in range(num_lines):
        contents.append(self._lineText(i, j))
        # Always include a newline after the record unless it is the last
        # line of the file, in which case we include it only in the first
        # file, so that both trailing-newline cases are exercised.
        if j + 1 != num_lines or i == 0:
          contents.append(b"\r\n" if crlf else b"\n")
      contents = b"".join(contents)

      if not compression_type:
        with open(fn, "wb") as f:
          f.write(contents)
      elif compression_type == "GZIP":
        with gzip.GzipFile(fn, "wb") as f:
          f.write(contents)
      elif compression_type == "ZLIB":
        contents = zlib.compress(contents)
        with open(fn, "wb") as f:
          f.write(contents)
      else:
        raise ValueError("Unsupported compression_type: %s" % compression_type)

    return filenames


class TextLineDatasetTest(TextLineDatasetTestBase, parameterized.TestCase):

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(compression_type=[None, "GZIP", "ZLIB"])))
  def testTextLineDataset(self, compression_type):
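    """Reads back files written with the given compression type."""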
    test_filenames = self._createFiles(
        2, 5, crlf=True, compression_type=compression_type)

    def dataset_fn(filenames, num_epochs, batch_size=None):
      repeat_dataset = readers.TextLineDataset(
          filenames, compression_type=compression_type).repeat(num_epochs)
      if batch_size:
        return repeat_dataset.batch(batch_size)
      return repeat_dataset

    # Basic test: read from file 0.
    expected_output = [self._lineText(0, i) for i in range(5)]
    self.assertDatasetProduces(
        dataset_fn([test_filenames[0]], 1), expected_output=expected_output)

    # Basic test: read from file 1.
    self.assertDatasetProduces(
        dataset_fn([test_filenames[1]], 1),
        expected_output=[self._lineText(1, i) for i in range(5)])

    # Basic test: read from both files.
    expected_output = [self._lineText(0, i) for i in range(5)]
    expected_output.extend(self._lineText(1, i) for i in range(5))
    self.assertDatasetProduces(
        dataset_fn(test_filenames, 1), expected_output=expected_output)

    # Test repeated iteration through both files.
    expected_output = [self._lineText(0, i) for i in range(5)]
    expected_output.extend(self._lineText(1, i) for i in range(5))
    self.assertDatasetProduces(
        dataset_fn(test_filenames, 10), expected_output=expected_output * 10)

    # Test batched and repeated iteration through both files.
    self.assertDatasetProduces(
        dataset_fn(test_filenames, 10, 5),
        expected_output=[[self._lineText(0, i) for i in range(5)],
                         [self._lineText(1, i) for i in range(5)]] * 10)

  @combinations.generate(test_base.default_test_combinations())
  def testTextLineDatasetParallelRead(self):
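    """Reads 10 files over 10 epochs with 4 parallel readers; order may vary."""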
    test_filenames = self._createFiles(10, 10)
    files = dataset_ops.Dataset.from_tensor_slices(test_filenames).repeat(10)
    expected_output = []
    for j in range(10):
      expected_output.extend(self._lineText(j, i) for i in range(10))
    dataset = readers.TextLineDataset(files, num_parallel_reads=4)
    self.assertDatasetProduces(
        dataset, expected_output=expected_output * 10, assert_items_equal=True)

  @combinations.generate(test_base.default_test_combinations())
  def testTextLineDatasetBuffering(self):
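    """Verifies correct line splitting with a small (10-byte) read buffer."""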
    test_filenames = self._createFiles(2, 5, crlf=True)

    dataset = readers.TextLineDataset(test_filenames, buffer_size=10)
    expected_output = []
    for j in range(2):
      expected_output.extend([self._lineText(j, i) for i in range(5)])
    self.assertDatasetProduces(dataset, expected_output=expected_output)

  @combinations.generate(test_base.eager_only_combinations())
  def testIteratorResourceCleanup(self):
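    """Verifies that deleting eager iterators closes their underlying files."""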
    filename = os.path.join(self.get_temp_dir(), "text.txt")
    with open(filename, "wt") as f:
      for i in range(3):
        f.write("%d\n" % (i,))
    first_iterator = iter(readers.TextLineDataset(filename))
    self.assertEqual(b"0", next(first_iterator).numpy())
    second_iterator = iter(readers.TextLineDataset(filename))
    self.assertEqual(b"0", next(second_iterator).numpy())
    # Eager kernel caching is based on op attributes, which include the
    # Dataset's output shape. Create a different kernel to test that they
    # don't create resources with the same names.
    different_kernel_iterator = iter(
        readers.TextLineDataset(filename).repeat().batch(16))
    self.assertEqual([16], next(different_kernel_iterator).shape)
    # Remove our references to the Python Iterator objects, which (assuming no
    # reference cycles) is enough to trigger DestroyResourceOp and close the
    # partially-read files.
    del first_iterator
    del second_iterator
    del different_kernel_iterator
    if not psutil_import_succeeded:
      self.skipTest(
          "psutil is required to check that we've closed our files.")
    open_files = psutil.Process().open_files()
    self.assertNotIn(filename, [open_file.path for open_file in open_files])

  @combinations.generate(test_base.default_test_combinations())
  def testTextLineDatasetPathlib(self):
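    """Verifies that `pathlib.Path` objects are accepted as filenames."""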
    files = self._createFiles(1, 5)
    files = [pathlib.Path(f) for f in files]

    expected_output = [self._lineText(0, i) for i in range(5)]
    ds = readers.TextLineDataset(files)
    self.assertDatasetProduces(
        ds, expected_output=expected_output, assert_items_equal=True)


class TextLineDatasetCheckpointTest(TextLineDatasetTestBase,
                                    checkpoint_test_base.CheckpointTestBase,
                                    parameterized.TestCase):

  def _build_iterator_graph(self, test_filenames, compression_type=None):
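    """Returns a `TextLineDataset` over `test_filenames` with a 10-byte buffer."""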
    return readers.TextLineDataset(
        test_filenames, compression_type=compression_type, buffer_size=10)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          checkpoint_test_base.default_test_combinations(),
          combinations.combine(compression_type=[None, "GZIP", "ZLIB"])))
  def test(self, verify_fn, compression_type):
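    """Verifies checkpointing and restoring an iterator over all outputs."""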
    num_files = 5
    lines_per_file = 5
    num_outputs = num_files * lines_per_file
    test_filenames = self._createFiles(
        num_files, lines_per_file, crlf=True, compression_type=compression_type)
    verify_fn(
        self,
        lambda: self._build_iterator_graph(test_filenames, compression_type),
        num_outputs)


if __name__ == "__main__":
  test.main()