# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Benchmarks for low-level eager execution primitives.

To run CPU benchmarks:
  bazel run -c opt benchmarks_test -- --benchmarks=.

To run GPU benchmarks:
  bazel run --config=cuda -c opt --copt="-mavx" benchmarks_test -- \
    --benchmarks=.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

import numpy as np
import six
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop  # pylint: disable=unused-import
from tensorflow.python.eager import context
from tensorflow.python.eager import core
from tensorflow.python.eager import execute
from tensorflow.python.eager import function
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops

# Device name constants used to pin benchmark ops to a particular device.
CPU = "/device:CPU:0"
GPU = "/device:GPU:0"
51
52
def c_tfe_py_fastpath_execute(a,
                              b,
                              transpose_a=False,
                              transpose_b=False,
                              name=None):
  """Executes a MatMul through the C fast-path, eager mode only.

  Any C-level not-OK status is converted to the corresponding Python
  exception, with `name` appended to the message when provided.
  """
  ctx = context.context()
  assert not ctx.in_graph_mode(), (
      "The prototype doesn't contain C code for graph construction")
  try:
    # pylint: disable=protected-access
    return pywrap_tensorflow.TFE_Py_FastPathExecute(
        ctx._handle, ctx.device_name, "MatMul", execute.record_gradient, name,
        ctx._post_execution_callbacks, a, b, "transpose_a", transpose_a,
        "transpose_b", transpose_b)
  except core._NotOkStatusException as e:  # pylint: disable=protected-access
    message = e.message if name is None else e.message + " name: " + name
    # pylint: disable=protected-access
    six.raise_from(core._status_to_exception(e.code, message), None)
72
73
class MicroBenchmarks(test.Benchmark):
  """Micro-benchmarks measuring per-op overhead of eager primitives.

  Each `benchmark_*` method times a small operation (tensor creation,
  multiply, matmul, variable reads, gradient-tape bookkeeping) at one of
  several API levels — numpy baseline, raw `TFE_Py_Execute`, generated op
  wrappers, and the public `math_ops` API — on CPU and GPU.

  GPU variants return early (silently skipping the benchmark) when no GPU
  is available.
  """

  def __init__(self):
    # Used for multiply benchmarks.
    self._m_2 = random_ops.random_uniform([2])

    # Used for matmul benchmarks.
    self._m_2_by_2 = random_ops.random_uniform((2, 2))
    self._m_100_by_784 = random_ops.random_uniform((100, 784))
    self._num_iters_2_by_2 = 30000
    self._num_iters_100_by_784 = 1000

  def _run(self, func, num_iters):
    """Calls `func` `num_iters` times, reporting the mean wall time in us.

    Args:
      func: Zero-argument callable being benchmarked.
      num_iters: Number of timed iterations.
    """
    # Call func once untimed so one-time costs (e.g. GPU warmup, kernel
    # compilation) do not skew the measurement.
    func()
    start = time.time()
    for _ in xrange(num_iters):
      func()
    end = time.time()
    mean_us = (end - start) * 1e6 / num_iters
    self.report_benchmark(iters=num_iters, wall_time=mean_us,
                          extras={"examples_per_sec": num_iters/(end-start)})

  def benchmark_create_np_array(self):
    func = lambda: np.array([3.0])
    self._run(func, 30000)

  def _benchmark_create_tensor(self, value, dtype, device):
    """Benchmark overheads of creating a Tensor object."""
    ctx = context.context()
    handle = ctx._handle  # pylint: disable=protected-access
    if device == GPU:
      # Warmup the GPU
      ops.EagerTensor(value, context=handle, device=device)

    def func():
      ops.EagerTensor(value, context=handle, device=device, dtype=dtype)
    self._run(func, 30000)

  def benchmark_create_float_tensor_from_list_CPU(self):
    self._benchmark_create_tensor([[3.0]], dtypes.float32.as_datatype_enum, CPU)

  def benchmark_create_float_tensor_from_np_array_CPU(self):
    self._benchmark_create_tensor(
        np.array([[3.0]], dtype=np.float32), dtypes.float32.as_datatype_enum,
        CPU)

  def benchmark_create_int32_tensor_from_list_CPU(self):
    self._benchmark_create_tensor([[3]], dtypes.int32.as_datatype_enum, CPU)

  def benchmark_create_int32_tensor_from_np_array_CPU(self):
    self._benchmark_create_tensor(
        np.array([[3]], dtype=np.int32), dtypes.int32.as_datatype_enum, CPU)

  def benchmark_create_float_tensor_from_list_GPU(self):
    if not context.num_gpus():
      return
    self._benchmark_create_tensor([[3.0]], dtypes.float32.as_datatype_enum, GPU)

  def benchmark_create_float_tensor_from_np_array_GPU(self):
    if not context.num_gpus():
      return
    self._benchmark_create_tensor(
        np.array([[3.0]], dtype=np.float32), dtypes.float32.as_datatype_enum,
        GPU)

  def benchmark_create_int32_tensor_from_list_GPU(self):
    # int32's are kept on host memory even when executing on GPU.
    if not context.num_gpus():
      return
    self._benchmark_create_tensor([[3]], dtypes.int32.as_datatype_enum, GPU)

  def benchmark_create_int32_tensor_from_np_array_GPU(self):
    # int32's are kept on host memory even when executing on GPU.
    if not context.num_gpus():
      return
    self._benchmark_create_tensor(
        np.array([[3]], dtype=np.int32), dtypes.int32.as_datatype_enum, GPU)

  def _benchmark_np_multiply(self, m, num_iters):
    """numpy baseline for elementwise multiply."""
    a = m.cpu().numpy()
    func = lambda: a * a
    self._run(func, num_iters)

  def _benchmark_tf_multiply(self, m, num_iters):
    """Multiply via the overloaded `*` operator."""
    func = lambda: m * m
    self._run(func, num_iters)

  def _benchmark_tf_multiply_op(self, m, num_iters):
    """Multiply via the public math_ops API."""
    func = lambda: math_ops.multiply(m, m)
    self._run(func, num_iters)

  def benchmark_np_multiply(self):
    self._benchmark_np_multiply(self._m_2, 30000)

  def benchmark_tf_multiply_CPU(self):
    with context.device(CPU):
      m = self._m_2.cpu()
      self._benchmark_tf_multiply(m, 30000)

  def benchmark_tf_multiply_GPU(self):
    if not context.num_gpus():
      return
    with context.device(GPU):
      m = self._m_2.gpu()
      self._benchmark_tf_multiply(m, 30000)

  def benchmark_tf_multiply_op_CPU(self):
    with context.device(CPU):
      m = self._m_2.cpu()
      self._benchmark_tf_multiply_op(m, 30000)

  def benchmark_tf_multiply_op_GPU(self):
    if not context.num_gpus():
      return
    with context.device(GPU):
      m = self._m_2.gpu()
      self._benchmark_tf_multiply_op(m, 30000)

  def benchmark_tf_identity(self):
    m = self._m_2
    self._run(lambda: gen_array_ops.identity(m), 30000)

  def benchmark_tfe_py_execute_identity(self):
    """Identity via the raw TFE_Py_Execute entry point."""
    m = self._m_2
    ctx_handle = context.context()._handle  # pylint: disable=protected-access
    attrs = ("T", self._m_2.dtype.as_datatype_enum)
    inputs = [m]

    def f():
      pywrap_tensorflow.TFE_Py_Execute(
          ctx_handle, None, "Identity", inputs, attrs, 1)

    self._run(f, 30000)

  def benchmark_tf_gradient_function_identity(self):
    m = self._m_2
    self._run(
        lambda: backprop.gradients_function(gen_array_ops.identity, [0])(m),
        30000)

  def benchmark_tf_gradient_forward_identity(self):
    # Times the forward pass only, with an active tape watching the input.
    with backprop.GradientTape() as tape:
      m = self._m_2
      tape.watch(m)
      self._run(lambda: gen_array_ops.identity(m), 30000)

  def benchmark_tf_gradient_tape_push_pop(self):
    # Isolates the cost of tape push/pop with no ops recorded.

    def f():
      with backprop.GradientTape():
        pass
    self._run(f, 30000)

  def benchmark_tf_gradient_function_no_op(self):
    m = self._m_2
    self._run(
        lambda: backprop.gradients_function(lambda x: x, [0])(m),
        30000)

  def _benchmark_np_matmul(self, m, transpose_b, num_iters):
    """numpy baseline for matmul."""
    a = m.cpu().numpy()
    b = a.T if transpose_b else a
    func = lambda: np.dot(a, b)
    self._run(func, num_iters)

  def _benchmark_tf_matmul(self, m, transpose_b, num_iters):
    func = lambda: math_ops.matmul(m, m, transpose_b=transpose_b)
    self._run(func, num_iters)

  def _benchmark_gen_math_ops_matmul(self, m, transpose_b, num_iters):
    def func():
      gen_math_ops._mat_mul(m, m, transpose_b=transpose_b)  # pylint: disable=protected-access
    self._run(func, num_iters)

  def _benchmark_tfe_py_fastpath_execute_matmul(self, m, transpose_b,
                                                num_iters):

    def func():
      c_tfe_py_fastpath_execute(m, m, transpose_b=transpose_b)

    self._run(func, num_iters)

  def _benchmark_tfe_py_execute_matmul(self, m, transpose_b, num_iters):
    inputs = [m, m]
    # pylint: disable=protected-access
    ctx_handle = context.context()._handle
    # pylint: enable=protected-access
    attrs = ("transpose_a", False, "transpose_b", transpose_b, "T",
             m.dtype.as_datatype_enum)
    def func():
      pywrap_tensorflow.TFE_Py_Execute(ctx_handle, None, "MatMul", inputs,
                                       attrs, 1)

    self._run(func, num_iters)

  def _benchmark_defun_matmul(self, m, transpose_b, num_iters):
    f = function.defun(math_ops.matmul)
    # matmul's positional signature is (a, b, transpose_a, transpose_b);
    # passing `transpose_b` as the third positional argument would set
    # `transpose_a` instead, so explicitly pass False for transpose_a first.
    func = lambda: f(m, m, False, transpose_b)
    self._run(func, num_iters)

  def _benchmark_read_variable(self, m, num_iters):
    self._run(m.value, num_iters)

  def _benchmark_read_variable_with_tape(self, m, num_iters):
    with backprop.GradientTape() as tape:
      tape.watch(m)
      self._run(m.value, num_iters)

  # Benchmarks for A^2, A of dimension 2 by 2.
  def benchmark_np_matmul_2_by_2(self):
    self._benchmark_np_matmul(
        self._m_2_by_2, transpose_b=False, num_iters=self._num_iters_2_by_2)

  def benchmark_tf_matmul_2_by_2_CPU(self):
    with context.device(CPU):
      m = self._m_2_by_2.cpu()
      self._benchmark_tf_matmul(
          m, transpose_b=False, num_iters=self._num_iters_2_by_2)

  def benchmark_gen_math_ops_matmul_2_by_2_CPU(self):
    with context.device(CPU):
      m = self._m_2_by_2.cpu()
      self._benchmark_gen_math_ops_matmul(
          m, transpose_b=False, num_iters=self._num_iters_2_by_2)

  def benchmark_tfe_py_fastpath_execute_matmul_2_by_2_CPU(self):
    with context.device(CPU):
      m = self._m_2_by_2.cpu()
      self._benchmark_tfe_py_fastpath_execute_matmul(
          m, transpose_b=False, num_iters=self._num_iters_2_by_2)

  def benchmark_tfe_py_execute_matmul_2_by_2_CPU(self):
    with context.device(CPU):
      m = self._m_2_by_2.cpu()
      self._benchmark_tfe_py_execute_matmul(
          m, transpose_b=False, num_iters=self._num_iters_2_by_2)

  def benchmark_defun_matmul_2_by_2_CPU(self):
    with context.device(CPU):
      m = self._m_2_by_2.cpu()
      self._benchmark_defun_matmul(
          m, transpose_b=False, num_iters=self._num_iters_2_by_2)

  def benchmark_tf_matmul_2_by_2_GPU(self):
    if not context.num_gpus():
      return
    with context.device(GPU):
      m = self._m_2_by_2.gpu()
      self._benchmark_tf_matmul(
          m, transpose_b=False, num_iters=self._num_iters_2_by_2)

  def benchmark_gen_math_ops_matmul_2_by_2_GPU(self):
    if not context.num_gpus():
      return
    with context.device(GPU):
      m = self._m_2_by_2.gpu()
      self._benchmark_gen_math_ops_matmul(
          m, transpose_b=False, num_iters=self._num_iters_2_by_2)

  def benchmark_tfe_py_execute_matmul_2_by_2_GPU(self):
    if not context.num_gpus():
      return
    with context.device(GPU):
      m = self._m_2_by_2.gpu()
      self._benchmark_tfe_py_execute_matmul(
          m, transpose_b=False, num_iters=self._num_iters_2_by_2)

  def benchmark_defun_matmul_2_by_2_GPU(self):
    if not context.num_gpus():
      return
    with context.device(GPU):
      m = self._m_2_by_2.gpu()
      self._benchmark_defun_matmul(
          m, transpose_b=False, num_iters=self._num_iters_2_by_2)

  # Benchmarks for AA.T, A of dimension 100 by 784.
  def benchmark_np_matmul_100_by_784(self):
    self._benchmark_np_matmul(
        self._m_100_by_784,
        transpose_b=True,
        num_iters=self._num_iters_100_by_784)

  def benchmark_tf_matmul_100_by_784_CPU(self):
    with context.device(CPU):
      m = self._m_100_by_784.cpu()
      self._benchmark_tf_matmul(
          m, transpose_b=True, num_iters=self._num_iters_100_by_784)

  def benchmark_gen_math_ops_matmul_100_by_784_CPU(self):
    with context.device(CPU):
      m = self._m_100_by_784.cpu()
      self._benchmark_gen_math_ops_matmul(
          m, transpose_b=True, num_iters=self._num_iters_100_by_784)

  def benchmark_tfe_py_fastpath_execute_matmul_100_by_784_CPU(self):
    with context.device(CPU):
      m = self._m_100_by_784.cpu()
      self._benchmark_tfe_py_fastpath_execute_matmul(
          m, transpose_b=True, num_iters=self._num_iters_100_by_784)

  def benchmark_tfe_py_execute_matmul_100_by_784_CPU(self):
    with context.device(CPU):
      m = self._m_100_by_784.cpu()
      self._benchmark_tfe_py_execute_matmul(
          m, transpose_b=True, num_iters=self._num_iters_100_by_784)

  def benchmark_defun_matmul_100_by_784_CPU(self):
    with context.device(CPU):
      m = self._m_100_by_784.cpu()
      self._benchmark_defun_matmul(
          m, transpose_b=True, num_iters=self._num_iters_100_by_784)

  def benchmark_tf_matmul_100_by_784_GPU(self):
    if not context.num_gpus():
      return
    with context.device(GPU):
      m = self._m_100_by_784.gpu()
      self._benchmark_tf_matmul(
          m, transpose_b=True, num_iters=self._num_iters_100_by_784)

  def benchmark_gen_math_ops_matmul_100_by_784_GPU(self):
    if not context.num_gpus():
      return
    with context.device(GPU):
      m = self._m_100_by_784.gpu()
      self._benchmark_gen_math_ops_matmul(
          m, transpose_b=True, num_iters=self._num_iters_100_by_784)

  def benchmark_tfe_py_execute_matmul_100_by_784_GPU(self):
    if not context.num_gpus():
      return
    with context.device(GPU):
      m = self._m_100_by_784.gpu()
      self._benchmark_tfe_py_execute_matmul(
          m, transpose_b=True, num_iters=self._num_iters_100_by_784)

  def benchmark_defun_matmul_100_by_784_GPU(self):
    if not context.num_gpus():
      return
    with context.device(GPU):
      m = self._m_100_by_784.gpu()
      self._benchmark_defun_matmul(
          m, transpose_b=True, num_iters=self._num_iters_100_by_784)

  def benchmark_read_variable_op_2_by_2_CPU(self):
    with context.device(CPU):
      m = resource_variable_ops.ResourceVariable(self._m_2_by_2)
      self._benchmark_read_variable(m, num_iters=self._num_iters_2_by_2)

  def benchmark_read_variable_op_2_by_2_GPU(self):
    if not context.num_gpus():
      return
    with context.device(GPU):
      m = resource_variable_ops.ResourceVariable(self._m_2_by_2.gpu())
      self._benchmark_read_variable(m, num_iters=self._num_iters_2_by_2)

  def benchmark_read_variable_op_with_tape_2_by_2_CPU(self):
    with context.device(CPU):
      m = resource_variable_ops.ResourceVariable(self._m_2_by_2)
      self._benchmark_read_variable_with_tape(
          m, num_iters=self._num_iters_2_by_2)

  def benchmark_read_variable_op_with_tape_2_by_2_GPU(self):
    if not context.num_gpus():
      return
    with context.device(GPU):
      m = resource_variable_ops.ResourceVariable(self._m_2_by_2.gpu())
      self._benchmark_read_variable_with_tape(
          m, num_iters=self._num_iters_2_by_2)
444
445
# Run all benchmarks when executed as a script.
if __name__ == "__main__":
  test.main()
448