• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Benchmarks for `tf.data.Dataset.batch()`."""
16import numpy as np
17
18from tensorflow.python.data.benchmarks import benchmark_base
19from tensorflow.python.data.ops import dataset_ops
20from tensorflow.python.data.ops import options as options_lib
21from tensorflow.python.framework import sparse_tensor
22from tensorflow.python.ops import random_ops
23
24
class BatchBenchmark(benchmark_base.DatasetBenchmarkBase):
  """Benchmarks for `tf.data.Dataset.batch()`.

  Every `benchmark_*` method sweeps a small parameter grid and calls
  `run_and_report_benchmark` once per grid point with `iters=1`.
  """

  def benchmark_batch_sparse(self):
    """Batches a repeated 1-D `SparseTensor` whose dense shape is [1000].

    The number of stored values is 0, 1, 5, 10 or 100. Each setting is
    batched with sizes 1, 32, 64, 128 and 1024. The element count passed to
    the reporter is `100000 // batch_size`. Presumably this makes every run
    cover roughly 100000 un-batched rows.
    """
    nnz_grid = (0, 1, 5, 10, 100)
    batch_size_grid = (1, 32, 64, 128, 1024)

    for nnz in nnz_grid:
      # Values 0..nnz-1 are stored at dense positions 0..nnz-1. The indices
      # need a trailing axis of length 1 because the sparse tensor has rank 1.
      sparse_row = sparse_tensor.SparseTensor(
          indices=np.arange(nnz, dtype=np.int64).reshape(-1, 1),
          values=np.arange(nnz, dtype=np.int64),
          dense_shape=[1000])

      for batch_size in batch_size_grid:
        batched = dataset_ops.Dataset.from_tensors(sparse_row).repeat()
        batched = batched.batch(batch_size)
        # NOTE(review): the reported name says "num_elements", but the value
        # is the number of stored values per row. It is left unchanged so
        # result names stay stable.
        self.run_and_report_benchmark(
            batched,
            num_elements=100000 // batch_size,
            iters=1,
            extras={
                "model_name": "batch.benchmark.1",
                "parameters": "%d.%d" % (batch_size, nnz),
            },
            name="sparse_num_elements_%d_batch_size_%d" % (nnz, batch_size))

  def _benchmark_batch_dense(self, parallel_copy, benchmark_id):
    """Batches repeated dense random vectors across a size grid.

    Vector lengths are 2**10, 2**12, ..., 2**22. Batch sizes are 2**3, 2**6
    and 2**9.

    Args:
      parallel_copy: Assigned to
        `options.experimental_optimization.parallel_batch`. When truthy, the
        reported name also gets a "_parallel_copy" suffix.
      benchmark_id: Integer formatted into the reported `model_name`
        ("batch.benchmark.<benchmark_id>").
    """
    name_suffix = "_parallel_copy" if parallel_copy else ""

    for element_exp in range(10, 23, 2):
      element_size = 1 << element_exp
      for batch_exp in (3, 6, 9):
        batch_size = 1 << batch_exp
        # A new random vector is drawn for every (element size, batch size)
        # pair.
        vector = np.random.rand(element_size)
        dataset = dataset_ops.Dataset.from_tensors(vector)
        dataset = dataset.repeat().batch(batch_size)
        options = options_lib.Options()
        options.experimental_optimization.parallel_batch = parallel_copy
        dataset = dataset.with_options(options)
        # The requested count shrinks with the batch size and with the square
        # root of the element size (element_exp is always even here).
        num_batches = 1 << (22 - batch_exp - element_exp // 2)
        self.run_and_report_benchmark(
            dataset,
            num_elements=num_batches,
            iters=1,
            extras={
                "model_name": "batch.benchmark.%d" % benchmark_id,
                "parameters": "%d.%d" % (batch_size, element_size),
            },
            name="batch_element_size_%d_batch_size_%d%s" %
            (element_size, batch_size, name_suffix))

  def benchmark_batch_dense(self):
    """Runs the dense batch sweep first without, then with, parallel copy."""
    for parallel_copy, benchmark_id in ((False, 2), (True, 3)):
      self._benchmark_batch_dense(
          parallel_copy=parallel_copy, benchmark_id=benchmark_id)

  def benchmark_parallel_batch(self):
    """Compares `batch()` across several `num_parallel_calls` settings.

    Each element of a 100000-element range is mapped to a uniform random
    tensor of shape [224, 224, 3]. The result is batched 128 at a time with
    `num_parallel_calls` set to None, 1, 4, 16 and AUTOTUNE.
    """
    batch_size = 128
    num_range = 100000

    def _random_element(_):
      return random_ops.random_uniform([224, 224, 3])

    for num_parallel_calls in [None, 1, 4, 16, dataset_ops.AUTOTUNE]:
      if num_parallel_calls == dataset_ops.AUTOTUNE:
        calls_label = "autotune"
      else:
        calls_label = str(num_parallel_calls)

      if num_parallel_calls is None:
        op_label = "batch"
      else:
        op_label = "parallel_batch_num_parallel_calls_%s" % calls_label

      pipeline = dataset_ops.Dataset.range(num_range).map(_random_element)
      pipeline = pipeline.batch(
          batch_size, num_parallel_calls=num_parallel_calls)
      self.run_and_report_benchmark(
          pipeline,
          num_elements=num_range // batch_size,
          iters=1,
          extras={
              "model_name": "batch.benchmark.4",
              "parameters": "%d.%s" % (batch_size, calls_label),
          },
          name="batch_size_%d_%s" % (batch_size, op_label))


if __name__ == "__main__":
  # Script entry point: delegate to the `test.main()` runner reached via
  # `benchmark_base`.
  benchmark_base.test.main()