# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Benchmarks for `tf.data.Dataset.interleave()`."""

from tensorflow.python.data.benchmarks import benchmark_base
from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.ops import dataset_ops

NON_PARALLEL = "non_parallel"
EXPERIMENTAL_PARALLEL = "experimental_parallel"
CORE_PARALLEL = "core_parallel"


def _make_fake_dataset_fn(initial_delay_us, remainder_delay_us):
  """Returns a dataset factory that emulates a remote storage data source.

  The returned factory creates a dataset with 100 elements that emulates the
  performance characteristic of a file-based dataset stored in remote
  storage. When `initial_delay_us` is non-zero, the first element is delayed
  by `initial_delay_us` and every remaining element by `remainder_delay_us`.
  For example, `benchmark_remote_file_simulation` uses 100ms vs. 1ms.

  Args:
    initial_delay_us: How long to wait (microseconds) before producing the
      first element. A falsy value disables the separate first-element delay.
    remainder_delay_us: How long to wait (microseconds) before producing each
      subsequent element. Values <= 0 disable the delay.

  Returns:
    A function usable as an `interleave` map function. It ignores its argument
    and returns the fake dataset.
  """
  total_elements = 100

  def fake_dataset_fn(unused):
    """Returns the fake dataset with the configured delays; input is ignored."""
    del unused

    def make_dataset(time_us, num_elements):
      # `range(num_elements)`, optionally sleeping `time_us` per element.
      dataset = dataset_ops.Dataset.range(num_elements)
      if time_us > 0:
        dataset = dataset.apply(testing.sleep(time_us))
      return dataset

    if not initial_delay_us:
      return make_dataset(remainder_delay_us, total_elements)

    # The slow first element must come from a non-empty dataset. A
    # `range(0)` source yields nothing, so `sleep` would never fire and the
    # initial delay would be silently dropped. Produce one slow element
    # followed by `total_elements - 1` fast ones so the total stays the same.
    first = make_dataset(initial_delay_us, 1)
    rest = make_dataset(remainder_delay_us, total_elements - 1)
    return first.concatenate(rest)

  return fake_dataset_fn


class ParallelInterleaveBenchmark(benchmark_base.DatasetBenchmarkBase):
  """Benchmarks for the non-parallel, experimental and core interleaves.

  Compares `Dataset.interleave()` without parallelism,
  `tf.data.experimental.parallel_interleave()`, and `Dataset.interleave()`
  with `num_parallel_calls`.
  """

  def apply_interleave(self, interleave_version, dataset, interleave_fn,
                       cycle_length, num_parallel_calls):
    """Applies the requested flavor of interleave to `dataset`.

    Args:
      interleave_version: One of `NON_PARALLEL`, `EXPERIMENTAL_PARALLEL` or
        `CORE_PARALLEL`.
      dataset: The input dataset to interleave over.
      interleave_fn: Function mapping an input element to a dataset.
      cycle_length: `cycle_length` passed to the interleave transformation.
      num_parallel_calls: Used only by `CORE_PARALLEL`. A falsy value means
        "use `cycle_length`". Ignored by the other versions.

    Returns:
      The interleaved dataset.

    Raises:
      ValueError: If `interleave_version` is not one of the known versions.
    """
    if interleave_version == NON_PARALLEL:
      return dataset.interleave(interleave_fn, cycle_length=cycle_length)
    elif interleave_version == EXPERIMENTAL_PARALLEL:
      return dataset.apply(
          interleave_ops.parallel_interleave(
              interleave_fn, cycle_length=cycle_length))
    elif interleave_version == CORE_PARALLEL:
      if not num_parallel_calls:
        num_parallel_calls = cycle_length
      return dataset.interleave(
          interleave_fn,
          cycle_length=cycle_length,
          num_parallel_calls=num_parallel_calls)
    else:
      # `%s` rather than `+` so a non-str version still yields ValueError.
      raise ValueError("Unknown version: %s" % (interleave_version,))

  def make_dataset(self,
                   interleave_version,
                   initial_delay,
                   remainder_delay,
                   cycle_length,
                   num_parallel_calls=None):
    """Builds the benchmarked pipeline.

    An infinitely repeated single-element input is interleaved with fake
    "remote file" datasets built by `_make_fake_dataset_fn`.

    Args:
      interleave_version: Which interleave implementation to use.
      initial_delay: First-element delay in microseconds.
      remainder_delay: Per-element delay for the remaining elements, in
        microseconds.
      cycle_length: `cycle_length` for the interleave.
      num_parallel_calls: Optional parallelism for `CORE_PARALLEL`.

    Returns:
      The (infinite) interleaved dataset.
    """
    dataset = dataset_ops.Dataset.range(1).repeat()
    interleave_fn = _make_fake_dataset_fn(initial_delay, remainder_delay)
    return self.apply_interleave(
        interleave_version=interleave_version,
        dataset=dataset,
        interleave_fn=interleave_fn,
        cycle_length=cycle_length,
        num_parallel_calls=num_parallel_calls)

  def _benchmark(self,
                 interleave_version,
                 num_elements,
                 benchmark_id,
                 benchmark_label,
                 initial_delay_us=0,
                 remainder_delay_us=0,
                 cycle_length=10,
                 iters=100,
                 num_parallel_calls=None,
                 name=None):
    """Builds a pipeline and reports its benchmark results.

    Args:
      interleave_version: Which interleave implementation to use.
      num_elements: Number of elements consumed per iteration.
      benchmark_id: Integer id embedded in the reported model name.
      benchmark_label: Label embedded in the reported model name.
      initial_delay_us: First-element delay of each fake file, microseconds.
      remainder_delay_us: Per-element delay of the remaining elements.
      cycle_length: `cycle_length` for the interleave.
      iters: Number of benchmark iterations.
      num_parallel_calls: Optional parallelism for `CORE_PARALLEL`.
      name: Benchmark name passed to the reporter.
    """
    dataset = self.make_dataset(
        interleave_version=interleave_version,
        initial_delay=initial_delay_us,
        remainder_delay=remainder_delay_us,
        cycle_length=cycle_length,
        num_parallel_calls=num_parallel_calls)

    self.run_and_report_benchmark(
        dataset=dataset,
        num_elements=num_elements,
        iters=iters,
        warmup=True,
        extras={
            "model_name":
                "interleave.benchmark.%s.%d" % (benchmark_label, benchmark_id),
            "parameters":
                "%d.%d.%d.%s" %
                (num_elements, cycle_length, iters, str(num_parallel_calls)),
        },
        name=name)

  def benchmark_remote_file_simulation(self):
    """Slow first element (100ms), fast remainder (1ms), parallel versions."""
    for i, version in enumerate([EXPERIMENTAL_PARALLEL, CORE_PARALLEL]):
      self._benchmark(
          interleave_version=version,
          initial_delay_us=100 * 1000,
          remainder_delay_us=1000,
          num_elements=5000,
          name="remote_file_simulation_" + version,
          benchmark_id=i,
          benchmark_label="remote_file")

  def benchmark_fast_input(self):
    """No artificial delays; measures raw interleave throughput."""
    for i, version in enumerate([EXPERIMENTAL_PARALLEL, CORE_PARALLEL]):
      self._benchmark(
          interleave_version=version,
          num_elements=200000,
          name="fast_input_" + version,
          benchmark_id=i,
          benchmark_label="fast_input")

  # Measure the overhead of parallel interleaves compared to non-parallel
  # interleave.
  def benchmark_single_cycle(self):
    """`cycle_length=1` for all three versions to expose per-element overhead."""
    for i, version in enumerate(
        [NON_PARALLEL, EXPERIMENTAL_PARALLEL, CORE_PARALLEL]):
      self._benchmark(
          interleave_version=version,
          cycle_length=1,
          num_elements=200000,
          name="single_cycle_" + version,
          benchmark_id=i,
          benchmark_label="single_cycle")

  # Compare with a more reasonable cycle length. Experimental interleave
  # cannot be compared here because it sets num_parallel_calls = cycle_length.
  def benchmark_single_parallel_call(self):
    """Core interleave with the default cycle length but one parallel call."""
    self._benchmark(
        interleave_version=CORE_PARALLEL,
        num_elements=200000,
        num_parallel_calls=1,
        name="single_parallel_call_" + CORE_PARALLEL,
        benchmark_id=1,
        benchmark_label="single_parallel_call")

  def benchmark_long_cycle(self):
    """Large `cycle_length` (1000) for the parallel versions."""
    for i, version in enumerate([EXPERIMENTAL_PARALLEL, CORE_PARALLEL]):
      self._benchmark(
          interleave_version=version,
          cycle_length=1000,
          num_elements=100000,
          name="long_cycle_" + version,
          benchmark_id=i,
          benchmark_label="long_cycle")


if __name__ == "__main__":
  benchmark_base.test.main()