# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the private `MatchingFilesDataset`."""
import os
import shutil
import tempfile

from absl.testing import parameterized

from tensorflow.python.data.experimental.ops import matching_files
from tensorflow.python.data.kernel_tests import checkpoint_test_base
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
from tensorflow.python.util import compat


class MatchingFilesDatasetTest(test_base.DatasetTestBase,
                               parameterized.TestCase):

  def setUp(self):
    super(MatchingFilesDatasetTest, self).setUp()
    self.tmp_dir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmp_dir, ignore_errors=True)
    super(MatchingFilesDatasetTest, self).tearDown()

  def _touchTempFiles(self, filenames):
    for filename in filenames:
      open(os.path.join(self.tmp_dir, filename), 'a').close()

  @combinations.generate(test_base.default_test_combinations())
  def testNonExistingDirectory(self):
    """Test the MatchingFiles dataset with a non-existing directory."""

    self.tmp_dir = os.path.join(self.tmp_dir, 'nonexistingdir')
    dataset = matching_files.MatchingFilesDataset(
        os.path.join(self.tmp_dir, '*'))
    self.assertDatasetProduces(
        dataset, expected_error=(errors.NotFoundError, ''))

  @combinations.generate(test_base.default_test_combinations())
  def testEmptyDirectory(self):
    """Test the MatchingFiles dataset with an empty directory."""

    dataset = matching_files.MatchingFilesDataset(
        os.path.join(self.tmp_dir, '*'))
    self.assertDatasetProduces(
        dataset, expected_error=(errors.NotFoundError, ''))

  @combinations.generate(test_base.default_test_combinations())
  def testSimpleDirectory(self):
    """Test the MatchingFiles dataset with a simple directory."""

    filenames = ['a', 'b', 'c']
    self._touchTempFiles(filenames)

    dataset = matching_files.MatchingFilesDataset(
        os.path.join(self.tmp_dir, '*'))
    self.assertDatasetProduces(
        dataset,
        expected_output=[
            compat.as_bytes(os.path.join(self.tmp_dir, filename))
            for filename in filenames
        ],
        assert_items_equal=True)

  @combinations.generate(test_base.default_test_combinations())
  def testFileSuffixes(self):
    """Test the MatchingFiles dataset using filename suffixes."""

    filenames = ['a.txt', 'b.py', 'c.py', 'd.pyc']
    self._touchTempFiles(filenames)

    dataset = matching_files.MatchingFilesDataset(
        os.path.join(self.tmp_dir, '*.py'))
    self.assertDatasetProduces(
        dataset,
        expected_output=[
            compat.as_bytes(os.path.join(self.tmp_dir, filename))
            for filename in filenames[1:-1]
        ],
        assert_items_equal=True)

  @combinations.generate(test_base.default_test_combinations())
  def testFileMiddles(self):
    """Test the MatchingFiles dataset matching the middle of filenames."""

    filenames = ['aa.txt', 'bb.py', 'bbc.pyc', 'cc.pyc']
    self._touchTempFiles(filenames)

    dataset = matching_files.MatchingFilesDataset(
        os.path.join(self.tmp_dir, 'b*.py*'))
    self.assertDatasetProduces(
        dataset,
        expected_output=[
            compat.as_bytes(os.path.join(self.tmp_dir, filename))
            for filename in filenames[1:3]
        ],
        assert_items_equal=True)

  @combinations.generate(test_base.default_test_combinations())
  def testNestedDirectories(self):
    """Test the MatchingFiles dataset with nested directories."""

    # Only the deepest directories contain files matching '*.txt' and '*.log'.
    filenames = []
    width = 8
    depth = 4
    for i in range(width):
      for j in range(depth):
        new_base = os.path.join(self.tmp_dir, str(i),
                                *[str(dir_name) for dir_name in range(j)])
        os.makedirs(new_base)
        child_files = ['a.py', 'b.pyc'] if j < depth - 1 else ['c.txt', 'd.log']
        for f in child_files:
          filename = os.path.join(new_base, f)
          filenames.append(filename)
          open(filename, 'w').close()

    patterns = [
        os.path.join(self.tmp_dir, os.path.join(*['**' for _ in range(depth)]),
                     suffix) for suffix in ['*.txt', '*.log']
    ]

    dataset = matching_files.MatchingFilesDataset(patterns)
    next_element = self.getNext(dataset)
    expected_filenames = [
        compat.as_bytes(filename)
        for filename in filenames
        if filename.endswith('.txt') or filename.endswith('.log')
    ]
    actual_filenames = []
    while True:
      try:
        actual_filenames.append(compat.as_bytes(self.evaluate(next_element())))
      except errors.OutOfRangeError:
        break

    self.assertCountEqual(expected_filenames, actual_filenames)


class MatchingFilesDatasetCheckpointTest(
    checkpoint_test_base.CheckpointTestBase, parameterized.TestCase):

  def _build_iterator_graph(self, test_patterns):
    return matching_files.MatchingFilesDataset(test_patterns)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
    tmp_dir = tempfile.mkdtemp()
    width = 16
    depth = 8
    for i in range(width):
      for j in range(depth):
        new_base = os.path.join(tmp_dir, str(i),
                                *[str(dir_name) for dir_name in range(j)])
        if not os.path.exists(new_base):
          os.makedirs(new_base)
        child_files = ['a.py', 'b.pyc'] if j < depth - 1 else ['c.txt', 'd.log']
        for f in child_files:
          filename = os.path.join(new_base, f)
          open(filename, 'w').close()

    patterns = [
        os.path.join(tmp_dir, os.path.join(*['**'
                                             for _ in range(depth)]), suffix)
        for suffix in ['*.txt', '*.log']
    ]

    # Each of the `width` leaf directories contributes one file per pattern.
    num_outputs = width * len(patterns)
    verify_fn(self, lambda: self._build_iterator_graph(patterns), num_outputs)

    shutil.rmtree(tmp_dir, ignore_errors=True)


if __name__ == '__main__':
  test.main()