• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Benchmarks for `tf.data.Dataset.interleave()`."""
16
17from tensorflow.python.data.benchmarks import benchmark_base
18from tensorflow.python.data.experimental.ops import interleave_ops
19from tensorflow.python.data.experimental.ops import testing
20from tensorflow.python.data.ops import dataset_ops
21
# Identifiers for the interleave implementations being compared. They select
# the branch taken in `apply_interleave` and are appended to reported
# benchmark names.
NON_PARALLEL, EXPERIMENTAL_PARALLEL, CORE_PARALLEL = (
    "non_parallel",
    "experimental_parallel",
    "core_parallel",
)
25
26
27def _make_fake_dataset_fn(initial_delay_us, remainder_delay_us):
28  """Returns a dataset that emulates a remote storage data source.
29
30  Returns a dataset factory which creates a dataset with 100 elements that
31  emulates the performance characteristic of a file-based dataset stored in a
32  remote storage. In particular, the first element will take an order of
33  magnitude longer to produce than the remaining elements (100ms vs. 1ms).
34
35  Args:
36    initial_delay_us: How long to wait before producing the first element.
37    remainder_delay_us: How long to wait before producing subsequent elements.
38  """
39
40  def fake_dataset_fn(unused):
41    """Returns a function that creates a dataset with the specified delays."""
42    del unused
43
44    def make_dataset(time_us, num_elements):
45      dataset = dataset_ops.Dataset.range(num_elements)
46      if time_us > 0:
47        dataset = dataset.apply(testing.sleep(time_us))
48      return dataset
49
50    if not initial_delay_us:
51      return make_dataset(remainder_delay_us, 100)
52
53    return make_dataset(initial_delay_us,
54                        0).concatenate(make_dataset(remainder_delay_us, 100))
55
56  return fake_dataset_fn
57
58
class ParallelInterleaveBenchmark(benchmark_base.DatasetBenchmarkBase):
  """Benchmarks for the `tf.data` interleave transformations.

  Compares three implementations, selected by the module-level version
  constants:

  - `Dataset.interleave()` without parallelism (`NON_PARALLEL`).
  - `tf.data.experimental.parallel_interleave()` (`EXPERIMENTAL_PARALLEL`).
  - `Dataset.interleave()` with `num_parallel_calls` (`CORE_PARALLEL`).
  """

  def apply_interleave(self, interleave_version, dataset, interleave_fn,
                       cycle_length, num_parallel_calls):
    """Applies the interleave implementation named by `interleave_version`.

    Args:
      interleave_version: One of `NON_PARALLEL`, `EXPERIMENTAL_PARALLEL` or
        `CORE_PARALLEL`.
      dataset: The input dataset whose elements are passed to `interleave_fn`.
      interleave_fn: Function mapping an input element to a dataset.
      cycle_length: Forwarded as `cycle_length` to the chosen interleave.
      num_parallel_calls: Used only by `CORE_PARALLEL`. A falsy value (None or
        0) means "use `cycle_length`". It is ignored by the other versions.

    Returns:
      The interleaved dataset.

    Raises:
      ValueError: If `interleave_version` is not a known version.
    """
    if interleave_version == NON_PARALLEL:
      return dataset.interleave(interleave_fn, cycle_length=cycle_length)
    if interleave_version == EXPERIMENTAL_PARALLEL:
      # `num_parallel_calls` is deliberately not forwarded to this op.
      return dataset.apply(
          interleave_ops.parallel_interleave(
              interleave_fn, cycle_length=cycle_length))
    if interleave_version == CORE_PARALLEL:
      return dataset.interleave(
          interleave_fn,
          cycle_length=cycle_length,
          num_parallel_calls=num_parallel_calls or cycle_length)
    # `%s` formatting rather than `+` concatenation, so that a non-string
    # version still raises ValueError instead of a TypeError.
    raise ValueError("Unknown version: %s" % (interleave_version,))

  def make_dataset(self,
                   interleave_version,
                   initial_delay,
                   remainder_delay,
                   cycle_length,
                   num_parallel_calls=None):
    """Builds the benchmark pipeline.

    The input is `Dataset.range(1).repeat()`, so the resulting dataset is
    unbounded. Each input element is mapped to a fake remote-storage dataset
    produced by `_make_fake_dataset_fn(initial_delay, remainder_delay)`.

    Args:
      interleave_version: Which interleave implementation to apply.
      initial_delay: Delay in microseconds before a fake dataset's first
        element.
      remainder_delay: Delay in microseconds before the remaining elements.
      cycle_length: Forwarded to `apply_interleave`.
      num_parallel_calls: Forwarded to `apply_interleave`.

    Returns:
      The interleaved dataset.
    """
    dataset = dataset_ops.Dataset.range(1).repeat()
    interleave_fn = _make_fake_dataset_fn(initial_delay, remainder_delay)
    return self.apply_interleave(
        interleave_version=interleave_version,
        dataset=dataset,
        interleave_fn=interleave_fn,
        cycle_length=cycle_length,
        num_parallel_calls=num_parallel_calls)

  def _benchmark(self,
                 interleave_version,
                 num_elements,
                 benchmark_id,
                 benchmark_label,
                 initial_delay_us=0,
                 remainder_delay_us=0,
                 cycle_length=10,
                 iters=100,
                 num_parallel_calls=None,
                 name=None):
    """Builds one interleave pipeline and reports its benchmark results.

    Args:
      interleave_version: Which interleave implementation to benchmark.
      num_elements: Forwarded to `run_and_report_benchmark`. Presumably this is
        the number of elements consumed per iteration; confirm in
        `benchmark_base`.
      benchmark_id: Numeric suffix of the reported `model_name`.
      benchmark_label: Label segment of the reported `model_name`.
      initial_delay_us: Initial delay of each fake dataset, in microseconds.
      remainder_delay_us: Delay of the remaining fake-dataset elements, in
        microseconds.
      cycle_length: Interleave cycle length.
      iters: Forwarded to `run_and_report_benchmark`.
      num_parallel_calls: Forwarded to `apply_interleave`; None lets
        `CORE_PARALLEL` default to `cycle_length`.
      name: Benchmark name forwarded to `run_and_report_benchmark`.
    """
    dataset = self.make_dataset(
        interleave_version=interleave_version,
        initial_delay=initial_delay_us,
        remainder_delay=remainder_delay_us,
        cycle_length=cycle_length,
        num_parallel_calls=num_parallel_calls)

    self.run_and_report_benchmark(
        dataset=dataset,
        num_elements=num_elements,
        iters=iters,
        warmup=True,
        extras={
            "model_name":
                "interleave.benchmark.%s.%d" % (benchmark_label, benchmark_id),
            "parameters":
                "%d.%d.%d.%s" %
                (num_elements, cycle_length, iters, str(num_parallel_calls)),
        },
        name=name)

  def _benchmark_versions(self, versions, name_prefix, benchmark_label,
                          **kwargs):
    """Runs `_benchmark` once per version, in order, with ids counted from 0.

    Args:
      versions: Interleave versions to benchmark, in reporting order.
      name_prefix: Prefix of each benchmark name; the version is appended.
      benchmark_label: Label segment of every reported `model_name`.
      **kwargs: Remaining keyword arguments forwarded to `_benchmark`.
    """
    for benchmark_id, version in enumerate(versions):
      self._benchmark(
          interleave_version=version,
          name=name_prefix + version,
          benchmark_id=benchmark_id,
          benchmark_label=benchmark_label,
          **kwargs)

  def benchmark_remote_file_simulation(self):
    self._benchmark_versions(
        [EXPERIMENTAL_PARALLEL, CORE_PARALLEL],
        name_prefix="remote_file_simulation_",
        benchmark_label="remote_file",
        initial_delay_us=100 * 1000,
        remainder_delay_us=1000,
        num_elements=5000)

  def benchmark_fast_input(self):
    self._benchmark_versions(
        [EXPERIMENTAL_PARALLEL, CORE_PARALLEL],
        name_prefix="fast_input_",
        benchmark_label="fast_input",
        num_elements=200000)

  # Measure the overhead of parallel interleaves compared to non-parallel
  # interleave.
  def benchmark_single_cycle(self):
    self._benchmark_versions(
        [NON_PARALLEL, EXPERIMENTAL_PARALLEL, CORE_PARALLEL],
        name_prefix="single_cycle_",
        benchmark_label="single_cycle",
        cycle_length=1,
        num_elements=200000)

  # Compare with a more reasonable cycle length. Experimental interleave
  # cannot be compared here because it sets num_parallel_calls = cycle_length.
  def benchmark_single_parallel_call(self):
    # NOTE(review): id 1 presumably mirrors CORE_PARALLEL's id in the
    # two-version benchmarks. It is part of the reported model_name, so it is
    # kept stable.
    self._benchmark(
        interleave_version=CORE_PARALLEL,
        num_elements=200000,
        num_parallel_calls=1,
        name="single_parallel_call_" + CORE_PARALLEL,
        benchmark_id=1,
        benchmark_label="single_parallel_call")

  def benchmark_long_cycle(self):
    self._benchmark_versions(
        [EXPERIMENTAL_PARALLEL, CORE_PARALLEL],
        name_prefix="long_cycle_",
        benchmark_label="long_cycle",
        cycle_length=1000,
        num_elements=100000)
180
181
if __name__ == "__main__":
  # Script entry point: hand control to the test runner exposed by
  # `benchmark_base`.
  benchmark_base.test.main()
184