# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Benchmarks for `tf.data.Dataset.map()`."""
import numpy as np

from tensorflow.python.data.benchmarks import benchmark_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops


class MapBenchmark(benchmark_base.DatasetBenchmarkBase):
  """Benchmarks for `tf.data.Dataset.map()`."""

  def benchmark_chain_of_maps(self):
    """Benchmarks chains of `map` transformations of varying length."""

    def benchmark_helper(chain_length, fn, use_inter_op_parallelism, label,
                         benchmark_id):
      dataset = dataset_ops.Dataset.range(10000)
      for _ in range(chain_length):
        dataset = dataset_ops.MapDataset(
            dataset, fn, use_inter_op_parallelism=use_inter_op_parallelism)
      self.run_and_report_benchmark(
          dataset,
          num_elements=10000,
          extras={
              "model_name": "map.benchmark.%d" % benchmark_id,
              "parameters": "%d" % chain_length,
          },
          name="chain_length_%d%s" % (chain_length, label))

    chain_lengths = [0, 1, 2, 5, 10, 20, 50]
    for chain_length in chain_lengths:
      benchmark_helper(
          chain_length=chain_length,
          fn=lambda x: x + 1,
          use_inter_op_parallelism=True,
          label="",
          benchmark_id=1)
      benchmark_helper(
          chain_length=chain_length,
          fn=lambda x: x + 1,
          use_inter_op_parallelism=False,
          label="_single_threaded",
          benchmark_id=2)
      # An identity map function exercises the short-circuit (no-op map) path.
      benchmark_helper(
          chain_length=chain_length,
          fn=lambda x: x,
          use_inter_op_parallelism=True,
          label="_short_circuit",
          benchmark_id=3)

  def benchmark_map_fan_out(self):
    """Benchmarks `map` over element tuples of varying arity."""
    fan_outs = [1, 2, 5, 10, 20, 50, 100]

    def benchmark_helper(fan_out, fn, use_inter_op_parallelism, label,
                         benchmark_id):
      dataset = dataset_ops.Dataset.from_tensors(
          tuple(0 for _ in range(fan_out))).repeat(None)
      dataset = dataset_ops.MapDataset(
          dataset, fn, use_inter_op_parallelism=use_inter_op_parallelism)
      self.run_and_report_benchmark(
          dataset,
          num_elements=10000,
          extras={
              "model_name": "map.benchmark.%d" % benchmark_id,
              "parameters": "%d" % fan_out,
          },
          name="fan_out_%d%s" % (fan_out, label))

    for fan_out in fan_outs:
      benchmark_helper(
          fan_out=fan_out,
          fn=lambda *xs: [x + 1 for x in xs],
          use_inter_op_parallelism=True,
          label="",
          benchmark_id=4)
      benchmark_helper(
          fan_out=fan_out,
          fn=lambda *xs: [x + 1 for x in xs],
          use_inter_op_parallelism=False,
          label="_single_threaded",
          benchmark_id=5)
      benchmark_helper(
          fan_out=fan_out,
          fn=lambda *xs: xs,
          use_inter_op_parallelism=True,
          label="_short_circuit",
          benchmark_id=6)

  def benchmark_sequential_control_flow(self):
    """Benchmarks a map function dominated by a sequential `while_loop`."""
    dataset = dataset_ops.Dataset.from_tensors(100000)

    def fn(x):
      i = constant_op.constant(0)

      def body(i, x):
        return math_ops.add(i, 1), x

      # Counts from 0 up to the input element; `math_ops.less` serves as the
      # loop condition, so the per-element work is inherently sequential.
      return control_flow_ops.while_loop(math_ops.less, body, [i, x])

    num_elements = 1
    dataset = dataset.map(fn)
    self.run_and_report_benchmark(
        dataset,
        num_elements=num_elements,
        extras={
            "model_name": "map.benchmark.8",
            "parameters": "%d" % num_elements,
        },
        name="sequential_control_flow",
        apply_default_optimizations=True)

  def benchmark_parallel_control_flow(self):
    """Benchmarks a map function dominated by a parallel `map_fn`."""
    dataset = dataset_ops.Dataset.from_tensors(
        random_ops.random_uniform([100, 10000000]))

    def fn(x):
      return map_fn.map_fn(
          lambda y: y * array_ops.transpose(y), x, parallel_iterations=10)

    num_elements = 1
    dataset = dataset.map(fn)
    self.run_and_report_benchmark(
        dataset,
        num_elements=1,
        extras={
            "model_name": "map.benchmark.9",
            "parameters": "%d" % num_elements,
        },
        name="parallel_control_flow",
        apply_default_optimizations=True)

  def _benchmark_nested_parallel_map(self, cycle_length, num_parallel_calls):
    """Benchmarks `map` nested inside a parallel `interleave`."""
    k = 1024 * 1024
    num_map_elements = 10
    num_range_elements = 2000

    def g(_):
      return np.random.rand(50 * k).sum()

    def f(_):
      return dataset_ops.Dataset.range(num_map_elements).map(
          g, num_parallel_calls=num_parallel_calls)

    dataset = dataset_ops.Dataset.range(num_range_elements)
    dataset = dataset.interleave(
        f, cycle_length=cycle_length, num_parallel_calls=dataset_ops.AUTOTUNE)

    cycle_length_str = ("default"
                        if cycle_length is None else str(cycle_length))
    num_parallel_calls_str = ("autotune"
                              if num_parallel_calls == dataset_ops.AUTOTUNE else
                              str(num_parallel_calls))
    map_dataset_str = ("map" if num_parallel_calls is None else
                       "parallel_map_num_parallel_calls_%s" %
                       num_parallel_calls_str)

    self.run_and_report_benchmark(
        dataset,
        num_elements=num_map_elements * num_range_elements,
        extras={
            "model_name": "map.benchmark.10",
            "parameters": "%s_%s" % (cycle_length_str, num_parallel_calls_str),
        },
        name=("%s_cycle_length_%s" % (map_dataset_str, cycle_length_str)))

  def benchmark_nested_parallel_map(self):
    """Sweeps `cycle_length` and `num_parallel_calls` for the nested map."""
    cycle_lengths = [None, 100]
    nums_parallel_calls = [None, 1, 10, 100, dataset_ops.AUTOTUNE]
    for cycle_length in cycle_lengths:
      for num_parallel_calls in nums_parallel_calls:
        self._benchmark_nested_parallel_map(cycle_length, num_parallel_calls)


if __name__ == "__main__":
  benchmark_base.test.main()