• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Benchmarks for `tf.data.Dataset.list_files()`."""
16import os
17import shutil
18import tempfile
19
20from tensorflow.python.data.benchmarks import benchmark_base
21from tensorflow.python.data.ops import dataset_ops
22
23
class ListFilesBenchmark(benchmark_base.DatasetBenchmarkBase):
  """Benchmarks for `tf.data.Dataset.list_files()`."""

  def benchmark_nested_directories(self):
    """Benchmarks `list_files()` over a wide, deeply nested directory tree.

    Creates a temporary tree with `width` top-level directories, each of which
    holds a chain of nested directories `depth` levels deep. Every level except
    the deepest one contains `a.py` and `b.pyc`; the deepest level contains
    `c.txt` and `d.log`. The dataset is built from glob patterns that select
    the `*.txt` and `*.log` files.

    The temporary tree is removed when the method exits, whether the benchmark
    completes or raises.
    """
    width = 1024
    depth = 16
    tmp_dir = tempfile.mkdtemp()
    try:
      for i in range(width):
        for j in range(depth):
          # Level `j` of top-level directory `i` is `<tmp>/<i>/0/1/.../<j-1>`.
          # Each iteration adds exactly one new, deeper directory, so
          # `makedirs` never hits an already existing path.
          new_base = os.path.join(tmp_dir, str(i),
                                  *[str(dir_name) for dir_name in range(j)])
          os.makedirs(new_base)
          if j < depth - 1:
            child_files = ['a.py', 'b.pyc']
          else:
            # Only the deepest level gets the files the patterns look for.
            child_files = ['c.txt', 'd.log']
          for f in child_files:
            # Create an empty file; the context manager closes the handle even
            # if an error occurs.
            with open(os.path.join(new_base, f), 'w'):
              pass

      # One '**' path component per nesting level, followed by the suffix glob.
      nested_wildcards = os.path.join(*['**'] * depth)
      patterns = [
          os.path.join(tmp_dir, nested_wildcards, suffix)
          for suffix in ['*.txt', '*.log']
      ]
      # `num_elements` depends on the patterns defined above. Files with a
      # `.txt` or `.log` suffix are created once per top-level directory (one
      # of each, at the deepest level), so the expected number of elements is:
      num_elements = width * 2

      dataset = dataset_ops.Dataset.list_files(patterns)
      self.run_and_report_benchmark(
          dataset=dataset,
          iters=3,
          num_elements=num_elements,
          extras={
              'model_name': 'list_files.benchmark.1',
              'parameters': '%d.%d' % (width, depth),
          },
          name='nested_directory(%d*%d)' % (width, depth))
    finally:
      # Always clean up: the tree holds `width * depth` directories and would
      # otherwise be leaked whenever the benchmark fails.
      shutil.rmtree(tmp_dir, ignore_errors=True)
62
63
# Run the benchmarks only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
  benchmark_base.test.main()
66