• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for the private `MatchingFilesDataset`."""
16import os
17import shutil
18import tempfile
19
20from absl.testing import parameterized
21
22from tensorflow.python.data.experimental.ops import matching_files
23from tensorflow.python.data.kernel_tests import checkpoint_test_base
24from tensorflow.python.data.kernel_tests import test_base
25from tensorflow.python.framework import combinations
26from tensorflow.python.framework import errors
27from tensorflow.python.platform import test
28from tensorflow.python.util import compat
29
30
class MatchingFilesDatasetTest(test_base.DatasetTestBase,
                               parameterized.TestCase):
  """Tests `MatchingFilesDataset` against files in a per-test temp directory."""

  def setUp(self):
    super(MatchingFilesDatasetTest, self).setUp()
    # Fresh scratch directory for each test; removed again in `tearDown`.
    self.tmp_dir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmp_dir, ignore_errors=True)
    super(MatchingFilesDatasetTest, self).tearDown()

  def _touchTempFiles(self, filenames):
    """Creates an empty file for each name in `filenames` under `tmp_dir`.

    Files are opened in append mode, so existing files are not truncated.
    """
    for filename in filenames:
      with open(os.path.join(self.tmp_dir, filename), 'a'):
        pass

  @combinations.generate(test_base.default_test_combinations())
  def testNonExistingDirectory(self):
    """Test the MatchingFiles dataset with a non-existing directory."""

    # Keep this path local instead of reassigning `self.tmp_dir`: overwriting
    # the attribute would make `tearDown` delete the non-existent path and
    # leak the directory created in `setUp`.
    non_existing_dir = os.path.join(self.tmp_dir, 'nonexistingdir')
    dataset = matching_files.MatchingFilesDataset(
        os.path.join(non_existing_dir, '*'))
    self.assertDatasetProduces(
        dataset, expected_error=(errors.NotFoundError, ''))

  @combinations.generate(test_base.default_test_combinations())
  def testEmptyDirectory(self):
    """Test the MatchingFiles dataset with an empty directory."""

    dataset = matching_files.MatchingFilesDataset(
        os.path.join(self.tmp_dir, '*'))
    # An existing directory with no matches is also reported as NotFound.
    self.assertDatasetProduces(
        dataset, expected_error=(errors.NotFoundError, ''))

  @combinations.generate(test_base.default_test_combinations())
  def testSimpleDirectory(self):
    """Test the MatchingFiles dataset with a simple directory."""

    filenames = ['a', 'b', 'c']
    self._touchTempFiles(filenames)

    dataset = matching_files.MatchingFilesDataset(
        os.path.join(self.tmp_dir, '*'))
    self.assertDatasetProduces(
        dataset,
        expected_output=[
            compat.as_bytes(os.path.join(self.tmp_dir, filename))
            for filename in filenames
        ],
        assert_items_equal=True)

  @combinations.generate(test_base.default_test_combinations())
  def testFileSuffixes(self):
    """Test the MatchingFiles dataset using the suffixes of filename."""

    filenames = ['a.txt', 'b.py', 'c.py', 'd.pyc']
    self._touchTempFiles(filenames)

    dataset = matching_files.MatchingFilesDataset(
        os.path.join(self.tmp_dir, '*.py'))
    # Only 'b.py' and 'c.py' end in '.py'.
    self.assertDatasetProduces(
        dataset,
        expected_output=[
            compat.as_bytes(os.path.join(self.tmp_dir, filename))
            for filename in filenames[1:-1]
        ],
        assert_items_equal=True)

  @combinations.generate(test_base.default_test_combinations())
  def testFileMiddles(self):
    """Test the MatchingFiles dataset using the middles of filename."""

    filenames = ['aa.txt', 'bb.py', 'bbc.pyc', 'cc.pyc']
    self._touchTempFiles(filenames)

    dataset = matching_files.MatchingFilesDataset(
        os.path.join(self.tmp_dir, 'b*.py*'))
    # Only 'bb.py' and 'bbc.pyc' start with 'b' and contain '.py'.
    self.assertDatasetProduces(
        dataset,
        expected_output=[
            compat.as_bytes(os.path.join(self.tmp_dir, filename))
            for filename in filenames[1:3]
        ],
        assert_items_equal=True)

  @combinations.generate(test_base.default_test_combinations())
  def testNestedDirectories(self):
    """Test the MatchingFiles dataset with nested directories.

    Builds `width` top-level directories, each `depth` levels deep. Every
    level except the deepest holds `a.py` and `b.pyc`. The deepest level holds
    `c.txt` and `d.log`, which are the only files the patterns select.
    """

    filenames = []
    width = 8
    depth = 4
    for i in range(width):
      for j in range(depth):
        new_base = os.path.join(self.tmp_dir, str(i),
                                *[str(dir_name) for dir_name in range(j)])
        os.makedirs(new_base)
        child_files = ['a.py', 'b.pyc'] if j < depth - 1 else ['c.txt', 'd.log']
        for f in child_files:
          filename = os.path.join(new_base, f)
          filenames.append(filename)
          with open(filename, 'w'):
            pass

    # One '**' path component per directory level below `tmp_dir`, followed
    # by the file-name suffix pattern.
    patterns = [
        os.path.join(self.tmp_dir, *(['**'] * depth), suffix)
        for suffix in ['*.txt', '*.log']
    ]

    dataset = matching_files.MatchingFilesDataset(patterns)
    expected_filenames = [
        compat.as_bytes(filename)
        for filename in filenames
        if filename.endswith(('.txt', '.log'))
    ]
    # Same assertion style as the sibling tests. It also verifies that the
    # dataset is exhausted after the expected elements.
    self.assertDatasetProduces(
        dataset, expected_output=expected_filenames, assert_items_equal=True)
154
155
class MatchingFilesDatasetCheckpointTest(
    checkpoint_test_base.CheckpointTestBase, parameterized.TestCase):
  """Checkpoint/restore tests for `MatchingFilesDataset`."""

  def _build_iterator_graph(self, test_patterns):
    """Returns a `MatchingFilesDataset` over `test_patterns`."""
    return matching_files.MatchingFilesDataset(test_patterns)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
    """Checkpoints iteration over a nested directory tree.

    Args:
      verify_fn: Verification callback from the generated test combinations.
        It is called with this test case, a zero-argument dataset factory and
        the expected number of outputs.
    """
    tmp_dir = tempfile.mkdtemp()
    # Always remove the tree, even if `verify_fn` fails. Previously cleanup ran
    # only on success, so a failing test leaked the directory.
    try:
      width = 16
      depth = 8
      for i in range(width):
        for j in range(depth):
          new_base = os.path.join(tmp_dir, str(i),
                                  *[str(dir_name) for dir_name in range(j)])
          os.makedirs(new_base, exist_ok=True)
          child_files = (['a.py', 'b.pyc'] if j < depth - 1
                         else ['c.txt', 'd.log'])
          for f in child_files:
            with open(os.path.join(new_base, f), 'w'):
              pass

      # One '**' path component per directory level below `tmp_dir`, followed
      # by the file-name suffix pattern.
      patterns = [
          os.path.join(tmp_dir, *(['**'] * depth), suffix)
          for suffix in ['*.txt', '*.log']
      ]

      # Each top-level directory holds exactly one `c.txt` and one `d.log`
      # (at its deepest level), i.e. one match per pattern.
      num_outputs = width * len(patterns)
      verify_fn(self, lambda: self._build_iterator_graph(patterns),
                num_outputs)
    finally:
      shutil.rmtree(tmp_dir, ignore_errors=True)
190
191
if __name__ == '__main__':
  # Entry point when executed directly: hand control to TensorFlow's test
  # runner, which discovers and runs the test cases defined above.
  test.main()
194