• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Benchmarks for `tf.data.Dataset.from_tensor_slices()`."""
16import numpy as np
17
18from tensorflow.python.data.benchmarks import benchmark_base
19from tensorflow.python.data.experimental.ops import get_single_element
20from tensorflow.python.data.ops import dataset_ops
21from tensorflow.python.data.ops import structured_function
22from tensorflow.python.eager import def_function
23from tensorflow.python.framework import sparse_tensor
24from tensorflow.python.ops import gen_dataset_ops
25
26
class SingleThreadedFlatMapDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that maps a function over its input and flattens the result.

  This mirrors `Dataset.flat_map()`, except that the wrapped `map_func` is
  built with `defun_kwargs={"_executor": "SINGLE_THREADED_EXECUTOR"}`
  (presumably to keep executor overhead low for the very small per-element
  functions used in these benchmarks).
  """

  def __init__(self, input_dataset, map_func):
    """See `Dataset.flat_map()` for details.

    Args:
      input_dataset: The `Dataset` whose elements are passed to `map_func`.
      map_func: A function mapping one element of `input_dataset` to a
        `Dataset`; the elements of the returned datasets form the output.

    Raises:
      TypeError: If `map_func` does not return a `Dataset`.
    """
    self._input_dataset = input_dataset
    self._map_func = structured_function.StructuredFunctionWrapper(
        map_func,
        self._transformation_name(),
        dataset=input_dataset,
        defun_kwargs={"_executor": "SINGLE_THREADED_EXECUTOR"})
    output_structure = self._map_func.output_structure
    # Flattening is only defined when `map_func` returns a dataset. Fail with a
    # clear message instead of an opaque AttributeError on `_element_spec`.
    if not isinstance(output_structure, dataset_ops.DatasetSpec):
      raise TypeError(
          "`map_func` must return a `Dataset` object, but its output "
          "structure was %r." % (output_structure,))
    # The flattened dataset yields the elements of the per-input datasets, so
    # its element spec is theirs.
    self._structure = output_structure._element_spec  # pylint: disable=protected-access
    variant_tensor = gen_dataset_ops.flat_map_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,
        f=self._map_func.function,
        **self._flat_structure)
    super(SingleThreadedFlatMapDataset, self).__init__(input_dataset,
                                                       variant_tensor)

  def _functions(self):
    """Returns the wrapped user function(s) used by this dataset."""
    return [self._map_func]

  @property
  def element_spec(self):
    """The element spec of the datasets produced by `map_func`."""
    return self._structure

  def _transformation_name(self):
    """Name under which `map_func` is wrapped."""
    return "SingleThreadedFlatMapDataset"
56
57
class FromTensorSlicesBenchmark(benchmark_base.DatasetBenchmarkBase):
  """Benchmarks for `tf.data.Dataset.from_tensor_slices()`.

  Each `benchmark_*` method builds a pipeline whose source is
  `from_tensor_slices()` and passes it to `run_and_report_benchmark()` with a
  `num_elements` count, a report name and `extras` metadata.
  """

  def benchmark_slice_repeat_batch(self):
    """Slices a random 1-D array, repeats it and batches the result."""
    input_size, batch_size, num_epochs = 10000, 100, 100

    source = np.random.randn(input_size)
    pipeline = (
        dataset_ops.Dataset.from_tensor_slices(source)
        .repeat(num_epochs)
        .batch(batch_size))

    # Batching happens after `repeat`, so batches are counted over all epochs.
    batch_count = input_size * num_epochs // batch_size
    label = "slice_repeat_batch_input_%d_batch_%d" % (input_size, batch_size)
    self.run_and_report_benchmark(
        pipeline,
        num_elements=batch_count,
        extras={
            "model_name": "from_tensor_slices.benchmark.1",
            "parameters": "%d.%d" % (input_size, batch_size),
        },
        name=label)

  def benchmark_reshape_slice_repeat(self):
    """Slices the rows of a random 2-D array and repeats them."""
    input_size = 10000
    num_rows, num_cols = 100, 100
    num_epochs = 100

    matrix = np.random.randn(input_size).reshape(num_rows, num_cols)
    pipeline = dataset_ops.Dataset.from_tensor_slices(matrix)
    pipeline = pipeline.repeat(num_epochs)

    # `from_tensor_slices` emits one element per row of `matrix`.
    self.run_and_report_benchmark(
        pipeline,
        num_elements=num_epochs * num_rows,
        extras={
            "model_name": "from_tensor_slices.benchmark.2",
            "parameters": "%d" % input_size,
        },
        name="reshape_slice_repeat_input_%d" % input_size,
    )

  def benchmark_slice_repeat_sparse(self):
    """Slices batched 1-D sparse tensors of varying density and batch size."""
    for non_zeros_per_row in (0, 1, 5, 10, 100):
      # A length-1000 sparse vector whose first `non_zeros_per_row` positions
      # hold the values 0, 1, ..., non_zeros_per_row - 1.
      tensor = sparse_tensor.SparseTensor(
          indices=np.arange(non_zeros_per_row, dtype=np.int64)[:, np.newaxis],
          values=np.arange(non_zeros_per_row, dtype=np.int64),
          dense_shape=[1000])

      for num_rows in (32, 64, 128, 1024):
        params = (non_zeros_per_row, num_rows)

        # TODO(b/147153744): Function-valued attributes with their own
        # attributes are currently only supported in graph mode.
        @def_function.function
        def make_dataset():
          # pylint: disable=cell-var-from-loop
          # Stack `num_rows` copies of `tensor` into one batched SparseTensor.
          rows = dataset_ops.Dataset.from_tensors(tensor).repeat(num_rows)
          stacked = get_single_element.get_single_element(
              rows.batch(num_rows))
          # Endlessly re-slice the stacked tensor back into its rows.
          return SingleThreadedFlatMapDataset(
              dataset_ops.Dataset.from_tensors(stacked).repeat(),
              dataset_ops.Dataset.from_tensor_slices)

        self.run_and_report_benchmark(
            make_dataset(),
            num_elements=100000,
            iters=5,
            extras={
                "model_name": "from_tensor_slices.benchmark.3",
                "parameters": "%d.%d" % params,
            },
            name="slice_repeat_sparse_elements_per_row_%d_num_rows_%d" %
            params)

  def benchmark_slice_batch_cache_repeat(self):
    """Slices and batches a random 1-D array, caches it and repeats it."""
    input_size, batch_size, num_epochs = 10000, 100, 100

    source = np.random.randn(input_size)
    pipeline = dataset_ops.Dataset.from_tensor_slices(source)
    # `cache()` comes after `batch()`, so the batched elements are what get
    # cached before `repeat()` is applied.
    pipeline = pipeline.batch(batch_size).cache().repeat(num_epochs)

    batch_count = input_size * num_epochs // batch_size
    label = "slice_batch_cache_repeat_input_%d_batch_%d" % (
        input_size, batch_size)
    self.run_and_report_benchmark(
        pipeline,
        num_elements=batch_count,
        extras={
            "model_name": "from_tensor_slices.benchmark.4",
            "parameters": "%d.%d" % (input_size, batch_size),
        },
        name=label)
158
159
def _main():
  """Hands control to the test runner re-exported by `benchmark_base`."""
  benchmark_base.test.main()


if __name__ == "__main__":
  _main()
162