1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""Utilities to run benchmarks.""" 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numbers 22import os 23import re 24import sys 25import time 26 27import six 28 29from tensorflow.core.protobuf import config_pb2 30from tensorflow.core.protobuf import rewriter_config_pb2 31from tensorflow.core.util import test_log_pb2 32from tensorflow.python.client import timeline 33from tensorflow.python.framework import ops 34from tensorflow.python.platform import app 35from tensorflow.python.platform import gfile 36from tensorflow.python.platform import tf_logging as logging 37from tensorflow.python.util import tf_inspect 38from tensorflow.python.util.tf_export import tf_export 39 40 41# When a subclass of the Benchmark class is created, it is added to 42# the registry automatically 43GLOBAL_BENCHMARK_REGISTRY = set() 44 45# Environment variable that determines whether benchmarks are written. 46# See also tensorflow/core/util/reporter.h TestReporter::kTestReporterEnv. 47TEST_REPORTER_TEST_ENV = "TEST_REPORT_FILE_PREFIX" 48 49# Environment variable that lets the TensorFlow runtime allocate a new 50# threadpool for each benchmark. 51OVERRIDE_GLOBAL_THREADPOOL = "TF_OVERRIDE_GLOBAL_THREADPOOL" 52 53 54def _global_report_benchmark( 55 name, iters=None, cpu_time=None, wall_time=None, 56 throughput=None, extras=None): 57 """Method for recording a benchmark directly. 58 59 Args: 60 name: The BenchmarkEntry name. 61 iters: (optional) How many iterations were run 62 cpu_time: (optional) Total cpu time in seconds 63 wall_time: (optional) Total wall time in seconds 64 throughput: (optional) Throughput (in MB/s) 65 extras: (optional) Dict mapping string keys to additional benchmark info. 66 67 Raises: 68 TypeError: if extras is not a dict. 69 IOError: if the benchmark output file already exists. 70 """ 71 if extras is not None: 72 if not isinstance(extras, dict): 73 raise TypeError("extras must be a dict") 74 75 logging.info("Benchmark [%s] iters: %d, wall_time: %g, cpu_time: %g," 76 "throughput: %g %s", name, iters if iters is not None else -1, 77 wall_time if wall_time is not None else -1, cpu_time if 78 cpu_time is not None else -1, throughput if 79 throughput is not None else -1, str(extras) if extras else "") 80 81 entries = test_log_pb2.BenchmarkEntries() 82 entry = entries.entry.add() 83 entry.name = name 84 if iters is not None: 85 entry.iters = iters 86 if cpu_time is not None: 87 entry.cpu_time = cpu_time 88 if wall_time is not None: 89 entry.wall_time = wall_time 90 if throughput is not None: 91 entry.throughput = throughput 92 if extras is not None: 93 for (k, v) in extras.items(): 94 if isinstance(v, numbers.Number): 95 entry.extras[k].double_value = v 96 else: 97 entry.extras[k].string_value = str(v) 98 99 test_env = os.environ.get(TEST_REPORTER_TEST_ENV, None) 100 if test_env is None: 101 # Reporting was not requested, just print the proto 102 print(str(entries)) 103 return 104 105 serialized_entry = entries.SerializeToString() 106 107 mangled_name = name.replace("/", "__") 108 output_path = "%s%s" % (test_env, mangled_name) 109 if gfile.Exists(output_path): 110 raise IOError("File already exists: %s" % output_path) 111 with gfile.GFile(output_path, "wb") as out: 112 out.write(serialized_entry) 113 114 115class _BenchmarkRegistrar(type): 116 """The Benchmark class registrar. Used by abstract Benchmark class.""" 117 118 def __new__(mcs, clsname, base, attrs): 119 newclass = super(mcs, _BenchmarkRegistrar).__new__( 120 mcs, clsname, base, attrs) 121 if not newclass.is_abstract(): 122 GLOBAL_BENCHMARK_REGISTRY.add(newclass) 123 return newclass 124 125 126class Benchmark(six.with_metaclass(_BenchmarkRegistrar, object)): 127 """Abstract class that provides helper functions for running benchmarks. 128 129 Any class subclassing this one is immediately registered in the global 130 benchmark registry. 131 132 Only methods whose names start with the word "benchmark" will be run during 133 benchmarking. 134 """ 135 136 @classmethod 137 def is_abstract(cls): 138 # mro: (_BenchmarkRegistrar, Benchmark) means this is Benchmark 139 return len(cls.mro()) <= 2 140 141 def _get_name(self, overwrite_name=None): 142 """Returns full name of class and method calling report_benchmark.""" 143 144 # Find the caller method (outermost Benchmark class) 145 stack = tf_inspect.stack() 146 calling_class = None 147 name = None 148 for frame in stack[::-1]: 149 f_locals = frame[0].f_locals 150 f_self = f_locals.get("self", None) 151 if isinstance(f_self, Benchmark): 152 calling_class = f_self # Get the outermost stack Benchmark call 153 name = frame[3] # Get the method name 154 break 155 if calling_class is None: 156 raise ValueError("Unable to determine calling Benchmark class.") 157 158 # Use the method name, or overwrite_name is provided. 159 name = overwrite_name or name 160 # Prefix the name with the class name. 161 class_name = type(calling_class).__name__ 162 name = "%s.%s" % (class_name, name) 163 return name 164 165 def report_benchmark( 166 self, 167 iters=None, 168 cpu_time=None, 169 wall_time=None, 170 throughput=None, 171 extras=None, 172 name=None): 173 """Report a benchmark. 174 175 Args: 176 iters: (optional) How many iterations were run 177 cpu_time: (optional) median or mean cpu time in seconds. 178 wall_time: (optional) median or mean wall time in seconds. 179 throughput: (optional) Throughput (in MB/s) 180 extras: (optional) Dict mapping string keys to additional benchmark info. 181 Values may be either floats or values that are convertible to strings. 182 name: (optional) Override the BenchmarkEntry name with `name`. 183 Otherwise it is inferred from the top-level method name. 184 """ 185 name = self._get_name(overwrite_name=name) 186 _global_report_benchmark( 187 name=name, iters=iters, cpu_time=cpu_time, wall_time=wall_time, 188 throughput=throughput, extras=extras) 189 190 191@tf_export("test.benchmark_config") 192def benchmark_config(): 193 """Returns a tf.ConfigProto for disabling the dependency optimizer. 194 195 Returns: 196 A TensorFlow ConfigProto object. 197 """ 198 config = config_pb2.ConfigProto() 199 config.graph_options.rewrite_options.dependency_optimization = ( 200 rewriter_config_pb2.RewriterConfig.OFF) 201 return config 202 203 204@tf_export("test.Benchmark") 205class TensorFlowBenchmark(Benchmark): 206 """Abstract class that provides helpers for TensorFlow benchmarks.""" 207 208 def __init__(self): 209 # Allow TensorFlow runtime to allocate a new threadpool with different 210 # number of threads for each new benchmark. 211 os.environ[OVERRIDE_GLOBAL_THREADPOOL] = "1" 212 super(TensorFlowBenchmark, self).__init__() 213 214 @classmethod 215 def is_abstract(cls): 216 # mro: (_BenchmarkRegistrar, Benchmark, TensorFlowBenchmark) means 217 # this is TensorFlowBenchmark. 218 return len(cls.mro()) <= 3 219 220 def run_op_benchmark(self, 221 sess, 222 op_or_tensor, 223 feed_dict=None, 224 burn_iters=2, 225 min_iters=10, 226 store_trace=False, 227 store_memory_usage=True, 228 name=None, 229 extras=None, 230 mbs=0): 231 """Run an op or tensor in the given session. Report the results. 232 233 Args: 234 sess: `Session` object to use for timing. 235 op_or_tensor: `Operation` or `Tensor` to benchmark. 236 feed_dict: A `dict` of values to feed for each op iteration (see the 237 `feed_dict` parameter of `Session.run`). 238 burn_iters: Number of burn-in iterations to run. 239 min_iters: Minimum number of iterations to use for timing. 240 store_trace: Boolean, whether to run an extra untimed iteration and 241 store the trace of iteration in returned extras. 242 The trace will be stored as a string in Google Chrome trace format 243 in the extras field "full_trace_chrome_format". Note that trace 244 will not be stored in test_log_pb2.TestResults proto. 245 store_memory_usage: Boolean, whether to run an extra untimed iteration, 246 calculate memory usage, and store that in extras fields. 247 name: (optional) Override the BenchmarkEntry name with `name`. 248 Otherwise it is inferred from the top-level method name. 249 extras: (optional) Dict mapping string keys to additional benchmark info. 250 Values may be either floats or values that are convertible to strings. 251 mbs: (optional) The number of megabytes moved by this op, used to 252 calculate the ops throughput. 253 254 Returns: 255 A `dict` containing the key-value pairs that were passed to 256 `report_benchmark`. If `store_trace` option is used, then 257 `full_chrome_trace_format` will be included in return dictionary even 258 though it is not passed to `report_benchmark` with `extras`. 259 """ 260 for _ in range(burn_iters): 261 sess.run(op_or_tensor, feed_dict=feed_dict) 262 263 deltas = [None] * min_iters 264 265 for i in range(min_iters): 266 start_time = time.time() 267 sess.run(op_or_tensor, feed_dict=feed_dict) 268 end_time = time.time() 269 delta = end_time - start_time 270 deltas[i] = delta 271 272 extras = extras if extras is not None else {} 273 unreported_extras = {} 274 if store_trace or store_memory_usage: 275 run_options = config_pb2.RunOptions( 276 trace_level=config_pb2.RunOptions.FULL_TRACE) 277 run_metadata = config_pb2.RunMetadata() 278 sess.run(op_or_tensor, feed_dict=feed_dict, 279 options=run_options, run_metadata=run_metadata) 280 tl = timeline.Timeline(run_metadata.step_stats) 281 282 if store_trace: 283 unreported_extras["full_trace_chrome_format"] = ( 284 tl.generate_chrome_trace_format()) 285 286 if store_memory_usage: 287 step_stats_analysis = tl.analyze_step_stats(show_memory=True) 288 allocator_maximums = step_stats_analysis.allocator_maximums 289 for k, v in allocator_maximums.items(): 290 extras["allocator_maximum_num_bytes_%s" % k] = v.num_bytes 291 292 def _median(x): 293 if not x: 294 return -1 295 s = sorted(x) 296 l = len(x) 297 lm1 = l - 1 298 return (s[l//2] + s[lm1//2]) / 2.0 299 300 median_delta = _median(deltas) 301 302 benchmark_values = { 303 "iters": min_iters, 304 "wall_time": median_delta, 305 "extras": extras, 306 "name": name, 307 "throughput": mbs / median_delta 308 } 309 self.report_benchmark(**benchmark_values) 310 benchmark_values["extras"].update(unreported_extras) 311 return benchmark_values 312 313 def evaluate(self, tensors): 314 """Evaluates tensors and returns numpy values. 315 316 Args: 317 tensors: A Tensor or a nested list/tuple of Tensors. 318 319 Returns: 320 tensors numpy values. 321 """ 322 sess = ops.get_default_session() or self.cached_session() 323 return sess.run(tensors) 324 325 326def _run_benchmarks(regex): 327 """Run benchmarks that match regex `regex`. 328 329 This function goes through the global benchmark registry, and matches 330 benchmark class and method names of the form 331 `module.name.BenchmarkClass.benchmarkMethod` to the given regex. 332 If a method matches, it is run. 333 334 Args: 335 regex: The string regular expression to match Benchmark classes against. 336 """ 337 registry = list(GLOBAL_BENCHMARK_REGISTRY) 338 339 # Match benchmarks in registry against regex 340 for benchmark in registry: 341 benchmark_name = "%s.%s" % (benchmark.__module__, benchmark.__name__) 342 attrs = dir(benchmark) 343 # Don't instantiate the benchmark class unless necessary 344 benchmark_instance = None 345 346 for attr in attrs: 347 if not attr.startswith("benchmark"): 348 continue 349 candidate_benchmark_fn = getattr(benchmark, attr) 350 if not callable(candidate_benchmark_fn): 351 continue 352 full_benchmark_name = "%s.%s" % (benchmark_name, attr) 353 if regex == "all" or re.search(regex, full_benchmark_name): 354 # Instantiate the class if it hasn't been instantiated 355 benchmark_instance = benchmark_instance or benchmark() 356 # Get the method tied to the class 357 instance_benchmark_fn = getattr(benchmark_instance, attr) 358 # Call the instance method 359 instance_benchmark_fn() 360 361 362def benchmarks_main(true_main, argv=None): 363 """Run benchmarks as declared in argv. 364 365 Args: 366 true_main: True main function to run if benchmarks are not requested. 367 argv: the command line arguments (if None, uses sys.argv). 368 """ 369 if argv is None: 370 argv = sys.argv 371 found_arg = [arg for arg in argv 372 if arg.startswith("--benchmarks=") 373 or arg.startswith("-benchmarks=")] 374 if found_arg: 375 # Remove --benchmarks arg from sys.argv 376 argv.remove(found_arg[0]) 377 378 regex = found_arg[0].split("=")[1] 379 app.run(lambda _: _run_benchmarks(regex), argv=argv) 380 else: 381 true_main() 382