# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common benchmarking code.

See https://www.tensorflow.org/community/benchmarks for usage.
"""
19
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
import timeit

import numpy as np

import tensorflow as tf
29
30
class ReportingBenchmark(tf.test.Benchmark):
  """Base class for a benchmark that reports general performance metrics.

  Subclasses call `time_execution` with the callable to measure. It runs the
  callable repeatedly, times each run, and forwards the mean and the
  per-iteration timings to `report_benchmark`.
  """

  def time_execution(self, name, target, iters, warm_up_iters=5):
    """Times repeated calls of `target` and reports the mean wall time.

    Args:
      name: Benchmark name. Either a string, or a tuple of pieces that are
        converted with `str` and joined with '_' to form the reported name.
        In the tuple case the original tuple is also kept in
        `extras['name']`.
      target: Zero-argument callable to benchmark.
      iters: Number of timed calls of `target`. Must be at least 1.
      warm_up_iters: Number of untimed calls of `target` made before timing
        starts.

    Raises:
      ValueError: If `iters` is less than 1, because no mean can be computed
        from zero samples.
    """
    # Fail before doing any work: averaging an empty list would report NaN.
    if iters < 1:
      raise ValueError('iters must be at least 1, got %r' % (iters,))

    for _ in range(warm_up_iters):
      target()

    # timeit.default_timer is the clock the standard library recommends for
    # measuring intervals; unlike time.time() on Python 3 it is not affected
    # by system clock adjustments.
    all_times = []
    for _ in range(iters):
      start = timeit.default_timer()
      target()
      all_times.append(timeit.default_timer() - start)

    extras = {'all_times': all_times}

    if isinstance(name, tuple):
      # Keep the structured name alongside the flattened string.
      extras['name'] = name
      name = '_'.join(str(piece) for piece in name)

    self.report_benchmark(
        iters=iters,
        wall_time=np.average(all_times),
        name=name,
        extras=extras)
59
60
if __name__ == '__main__':
  # Hand control to TensorFlow's test runner when executed as a script.
  tf.test.main()
63