1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""Utilities to run benchmarks.""" 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import math 22import numbers 23import os 24import re 25import sys 26import time 27import types 28 29import six 30 31from tensorflow.core.protobuf import config_pb2 32from tensorflow.core.protobuf import rewriter_config_pb2 33from tensorflow.core.util import test_log_pb2 34from tensorflow.python.client import timeline 35from tensorflow.python.framework import ops 36from tensorflow.python.platform import app 37from tensorflow.python.platform import gfile 38from tensorflow.python.platform import tf_logging as logging 39from tensorflow.python.util import tf_inspect 40from tensorflow.python.util.tf_export import tf_export 41 42 43# When a subclass of the Benchmark class is created, it is added to 44# the registry automatically 45GLOBAL_BENCHMARK_REGISTRY = set() 46 47# Environment variable that determines whether benchmarks are written. 48# See also tensorflow/core/util/reporter.h TestReporter::kTestReporterEnv. 49TEST_REPORTER_TEST_ENV = "TEST_REPORT_FILE_PREFIX" 50 51# Environment variable that lets the TensorFlow runtime allocate a new 52# threadpool for each benchmark. 53OVERRIDE_GLOBAL_THREADPOOL = "TF_OVERRIDE_GLOBAL_THREADPOOL" 54 55 56def _rename_function(f, arg_num, name): 57 """Rename the given function's name appears in the stack trace.""" 58 func_code = six.get_function_code(f) 59 if six.PY2: 60 new_code = types.CodeType(arg_num, func_code.co_nlocals, 61 func_code.co_stacksize, func_code.co_flags, 62 func_code.co_code, func_code.co_consts, 63 func_code.co_names, func_code.co_varnames, 64 func_code.co_filename, name, 65 func_code.co_firstlineno, func_code.co_lnotab, 66 func_code.co_freevars, func_code.co_cellvars) 67 else: 68 if sys.version_info > (3, 8, 0, "alpha", 3): 69 # Python3.8 / PEP570 added co_posonlyargcount argument to CodeType. 70 new_code = types.CodeType(arg_num, func_code.co_posonlyargcount, 71 0, func_code.co_nlocals, 72 func_code.co_stacksize, func_code.co_flags, 73 func_code.co_code, func_code.co_consts, 74 func_code.co_names, func_code.co_varnames, 75 func_code.co_filename, name, 76 func_code.co_firstlineno, func_code.co_lnotab, 77 func_code.co_freevars, func_code.co_cellvars) 78 else: 79 new_code = types.CodeType(arg_num, 0, func_code.co_nlocals, 80 func_code.co_stacksize, func_code.co_flags, 81 func_code.co_code, func_code.co_consts, 82 func_code.co_names, func_code.co_varnames, 83 func_code.co_filename, name, 84 func_code.co_firstlineno, func_code.co_lnotab, 85 func_code.co_freevars, func_code.co_cellvars) 86 87 return types.FunctionType(new_code, f.__globals__, name, f.__defaults__, 88 f.__closure__) 89 90 91def _global_report_benchmark( 92 name, iters=None, cpu_time=None, wall_time=None, 93 throughput=None, extras=None, metrics=None): 94 """Method for recording a benchmark directly. 95 96 Args: 97 name: The BenchmarkEntry name. 98 iters: (optional) How many iterations were run 99 cpu_time: (optional) Total cpu time in seconds 100 wall_time: (optional) Total wall time in seconds 101 throughput: (optional) Throughput (in MB/s) 102 extras: (optional) Dict mapping string keys to additional benchmark info. 103 metrics: (optional) A list of dict representing metrics generated by the 104 benchmark. Each dict should contain keys 'name' and'value'. A dict 105 can optionally contain keys 'min_value' and 'max_value'. 106 107 Raises: 108 TypeError: if extras is not a dict. 109 IOError: if the benchmark output file already exists. 110 """ 111 logging.info("Benchmark [%s] iters: %d, wall_time: %g, cpu_time: %g," 112 "throughput: %g, extras: %s, metrics: %s", name, 113 iters if iters is not None else -1, 114 wall_time if wall_time is not None else -1, 115 cpu_time if cpu_time is not None else -1, 116 throughput if throughput is not None else -1, 117 str(extras) if extras else "None", 118 str(metrics) if metrics else "None") 119 120 entries = test_log_pb2.BenchmarkEntries() 121 entry = entries.entry.add() 122 entry.name = name 123 if iters is not None: 124 entry.iters = iters 125 if cpu_time is not None: 126 entry.cpu_time = cpu_time 127 if wall_time is not None: 128 entry.wall_time = wall_time 129 if throughput is not None: 130 entry.throughput = throughput 131 if extras is not None: 132 if not isinstance(extras, dict): 133 raise TypeError("extras must be a dict") 134 for (k, v) in extras.items(): 135 if isinstance(v, numbers.Number): 136 entry.extras[k].double_value = v 137 else: 138 entry.extras[k].string_value = str(v) 139 if metrics is not None: 140 if not isinstance(metrics, list): 141 raise TypeError("metrics must be a list") 142 for metric in metrics: 143 if "name" not in metric: 144 raise TypeError("metric must has a 'name' field") 145 if "value" not in metric: 146 raise TypeError("metric must has a 'value' field") 147 148 metric_entry = entry.metrics.add() 149 metric_entry.name = metric["name"] 150 metric_entry.value = metric["value"] 151 if "min_value" in metric: 152 metric_entry.min_value.value = metric["min_value"] 153 if "max_value" in metric: 154 metric_entry.max_value.value = metric["max_value"] 155 156 test_env = os.environ.get(TEST_REPORTER_TEST_ENV, None) 157 if test_env is None: 158 # Reporting was not requested, just print the proto 159 print(str(entries)) 160 return 161 162 serialized_entry = entries.SerializeToString() 163 164 mangled_name = name.replace("/", "__") 165 output_path = "%s%s" % (test_env, mangled_name) 166 if gfile.Exists(output_path): 167 raise IOError("File already exists: %s" % output_path) 168 with gfile.GFile(output_path, "wb") as out: 169 out.write(serialized_entry) 170 171 172class _BenchmarkRegistrar(type): 173 """The Benchmark class registrar. Used by abstract Benchmark class.""" 174 175 def __new__(mcs, clsname, base, attrs): 176 newclass = type.__new__(mcs, clsname, base, attrs) 177 if not newclass.is_abstract(): 178 GLOBAL_BENCHMARK_REGISTRY.add(newclass) 179 return newclass 180 181 182@tf_export("__internal__.test.ParameterizedBenchmark", v1=[]) 183class ParameterizedBenchmark(_BenchmarkRegistrar): 184 """Metaclass to generate parameterized benchmarks. 185 186 Use this class as a metaclass and override the `_benchmark_parameters` to 187 generate multiple benchmark test cases. For example: 188 189 class FooBenchmark(metaclass=tf.test.ParameterizedBenchmark, 190 tf.test.Benchmark): 191 # The `_benchmark_parameters` is expected to be a list with test cases. 192 # Each of the test case is a tuple, with the first time to be test case 193 # name, followed by any number of the parameters needed for the test case. 194 _benchmark_parameters = [ 195 ('case_1', Foo, 1, 'one'), 196 ('case_2', Bar, 2, 'two'), 197 ] 198 199 def benchmark_test(self, target_class, int_param, string_param): 200 # benchmark test body 201 202 The example above will generate two benchmark test cases: 203 "benchmark_test__case_1" and "benchmark_test__case_2". 204 """ 205 206 def __new__(mcs, clsname, base, attrs): 207 param_config_list = attrs["_benchmark_parameters"] 208 209 def create_benchmark_function(original_benchmark, params): 210 return lambda self: original_benchmark(self, *params) 211 212 for name in attrs.copy().keys(): 213 if not name.startswith("benchmark"): 214 continue 215 216 original_benchmark = attrs[name] 217 del attrs[name] 218 219 for param_config in param_config_list: 220 test_name_suffix = param_config[0] 221 params = param_config[1:] 222 benchmark_name = name + "__" + test_name_suffix 223 if benchmark_name in attrs: 224 raise Exception( 225 "Benchmark named {} already defined.".format(benchmark_name)) 226 227 benchmark = create_benchmark_function(original_benchmark, params) 228 # Renaming is important because `report_benchmark` function looks up the 229 # function name in the stack trace. 230 attrs[benchmark_name] = _rename_function(benchmark, 1, benchmark_name) 231 232 return super(mcs, ParameterizedBenchmark).__new__(mcs, clsname, base, attrs) 233 234 235class Benchmark(six.with_metaclass(_BenchmarkRegistrar, object)): 236 """Abstract class that provides helper functions for running benchmarks. 237 238 Any class subclassing this one is immediately registered in the global 239 benchmark registry. 240 241 Only methods whose names start with the word "benchmark" will be run during 242 benchmarking. 243 """ 244 245 @classmethod 246 def is_abstract(cls): 247 # mro: (_BenchmarkRegistrar, Benchmark) means this is Benchmark 248 return len(cls.mro()) <= 2 249 250 def _get_name(self, overwrite_name=None): 251 """Returns full name of class and method calling report_benchmark.""" 252 253 # Find the caller method (outermost Benchmark class) 254 stack = tf_inspect.stack() 255 calling_class = None 256 name = None 257 for frame in stack[::-1]: 258 f_locals = frame[0].f_locals 259 f_self = f_locals.get("self", None) 260 if isinstance(f_self, Benchmark): 261 calling_class = f_self # Get the outermost stack Benchmark call 262 name = frame[3] # Get the method name 263 break 264 if calling_class is None: 265 raise ValueError("Unable to determine calling Benchmark class.") 266 267 # Use the method name, or overwrite_name is provided. 268 name = overwrite_name or name 269 # Prefix the name with the class name. 270 class_name = type(calling_class).__name__ 271 name = "%s.%s" % (class_name, name) 272 return name 273 274 def report_benchmark( 275 self, 276 iters=None, 277 cpu_time=None, 278 wall_time=None, 279 throughput=None, 280 extras=None, 281 name=None, 282 metrics=None): 283 """Report a benchmark. 284 285 Args: 286 iters: (optional) How many iterations were run 287 cpu_time: (optional) Median or mean cpu time in seconds. 288 wall_time: (optional) Median or mean wall time in seconds. 289 throughput: (optional) Throughput (in MB/s) 290 extras: (optional) Dict mapping string keys to additional benchmark info. 291 Values may be either floats or values that are convertible to strings. 292 name: (optional) Override the BenchmarkEntry name with `name`. 293 Otherwise it is inferred from the top-level method name. 294 metrics: (optional) A list of dict, where each dict has the keys below 295 name (required), string, metric name 296 value (required), double, metric value 297 min_value (optional), double, minimum acceptable metric value 298 max_value (optional), double, maximum acceptable metric value 299 """ 300 name = self._get_name(overwrite_name=name) 301 _global_report_benchmark( 302 name=name, iters=iters, cpu_time=cpu_time, wall_time=wall_time, 303 throughput=throughput, extras=extras, metrics=metrics) 304 305 306@tf_export("test.benchmark_config") 307def benchmark_config(): 308 """Returns a tf.compat.v1.ConfigProto for disabling the dependency optimizer. 309 310 Returns: 311 A TensorFlow ConfigProto object. 312 """ 313 config = config_pb2.ConfigProto() 314 config.graph_options.rewrite_options.dependency_optimization = ( 315 rewriter_config_pb2.RewriterConfig.OFF) 316 return config 317 318 319@tf_export("test.Benchmark") 320class TensorFlowBenchmark(Benchmark): 321 """Abstract class that provides helpers for TensorFlow benchmarks.""" 322 323 def __init__(self): 324 # Allow TensorFlow runtime to allocate a new threadpool with different 325 # number of threads for each new benchmark. 326 os.environ[OVERRIDE_GLOBAL_THREADPOOL] = "1" 327 super(TensorFlowBenchmark, self).__init__() 328 329 @classmethod 330 def is_abstract(cls): 331 # mro: (_BenchmarkRegistrar, Benchmark, TensorFlowBenchmark) means 332 # this is TensorFlowBenchmark. 333 return len(cls.mro()) <= 3 334 335 def run_op_benchmark(self, 336 sess, 337 op_or_tensor, 338 feed_dict=None, 339 burn_iters=2, 340 min_iters=10, 341 store_trace=False, 342 store_memory_usage=True, 343 name=None, 344 extras=None, 345 mbs=0): 346 """Run an op or tensor in the given session. Report the results. 347 348 Args: 349 sess: `Session` object to use for timing. 350 op_or_tensor: `Operation` or `Tensor` to benchmark. 351 feed_dict: A `dict` of values to feed for each op iteration (see the 352 `feed_dict` parameter of `Session.run`). 353 burn_iters: Number of burn-in iterations to run. 354 min_iters: Minimum number of iterations to use for timing. 355 store_trace: Boolean, whether to run an extra untimed iteration and 356 store the trace of iteration in returned extras. 357 The trace will be stored as a string in Google Chrome trace format 358 in the extras field "full_trace_chrome_format". Note that trace 359 will not be stored in test_log_pb2.TestResults proto. 360 store_memory_usage: Boolean, whether to run an extra untimed iteration, 361 calculate memory usage, and store that in extras fields. 362 name: (optional) Override the BenchmarkEntry name with `name`. 363 Otherwise it is inferred from the top-level method name. 364 extras: (optional) Dict mapping string keys to additional benchmark info. 365 Values may be either floats or values that are convertible to strings. 366 mbs: (optional) The number of megabytes moved by this op, used to 367 calculate the ops throughput. 368 369 Returns: 370 A `dict` containing the key-value pairs that were passed to 371 `report_benchmark`. If `store_trace` option is used, then 372 `full_chrome_trace_format` will be included in return dictionary even 373 though it is not passed to `report_benchmark` with `extras`. 374 """ 375 for _ in range(burn_iters): 376 sess.run(op_or_tensor, feed_dict=feed_dict) 377 378 deltas = [None] * min_iters 379 380 for i in range(min_iters): 381 start_time = time.time() 382 sess.run(op_or_tensor, feed_dict=feed_dict) 383 end_time = time.time() 384 delta = end_time - start_time 385 deltas[i] = delta 386 387 extras = extras if extras is not None else {} 388 unreported_extras = {} 389 if store_trace or store_memory_usage: 390 run_options = config_pb2.RunOptions( 391 trace_level=config_pb2.RunOptions.FULL_TRACE) 392 run_metadata = config_pb2.RunMetadata() 393 sess.run(op_or_tensor, feed_dict=feed_dict, 394 options=run_options, run_metadata=run_metadata) 395 tl = timeline.Timeline(run_metadata.step_stats) 396 397 if store_trace: 398 unreported_extras["full_trace_chrome_format"] = ( 399 tl.generate_chrome_trace_format()) 400 401 if store_memory_usage: 402 step_stats_analysis = tl.analyze_step_stats(show_memory=True) 403 allocator_maximums = step_stats_analysis.allocator_maximums 404 for k, v in allocator_maximums.items(): 405 extras["allocator_maximum_num_bytes_%s" % k] = v.num_bytes 406 407 def _median(x): 408 if not x: 409 return -1 410 s = sorted(x) 411 l = len(x) 412 lm1 = l - 1 413 return (s[l//2] + s[lm1//2]) / 2.0 414 415 def _mean_and_stdev(x): 416 if not x: 417 return -1, -1 418 l = len(x) 419 mean = sum(x) / l 420 if l == 1: 421 return mean, -1 422 variance = sum([(e - mean) * (e - mean) for e in x]) / (l - 1) 423 return mean, math.sqrt(variance) 424 425 median_delta = _median(deltas) 426 427 benchmark_values = { 428 "iters": min_iters, 429 "wall_time": median_delta, 430 "extras": extras, 431 "name": name, 432 "throughput": mbs / median_delta 433 } 434 self.report_benchmark(**benchmark_values) 435 436 mean_delta, stdev_delta = _mean_and_stdev(deltas) 437 unreported_extras["wall_time_mean"] = mean_delta 438 unreported_extras["wall_time_stdev"] = stdev_delta 439 benchmark_values["extras"].update(unreported_extras) 440 return benchmark_values 441 442 def evaluate(self, tensors): 443 """Evaluates tensors and returns numpy values. 444 445 Args: 446 tensors: A Tensor or a nested list/tuple of Tensors. 447 448 Returns: 449 tensors numpy values. 450 """ 451 sess = ops.get_default_session() or self.cached_session() 452 return sess.run(tensors) 453 454 455def _run_benchmarks(regex): 456 """Run benchmarks that match regex `regex`. 457 458 This function goes through the global benchmark registry, and matches 459 benchmark class and method names of the form 460 `module.name.BenchmarkClass.benchmarkMethod` to the given regex. 461 If a method matches, it is run. 462 463 Args: 464 regex: The string regular expression to match Benchmark classes against. 465 466 Raises: 467 ValueError: If no benchmarks were selected by the input regex. 468 """ 469 registry = list(GLOBAL_BENCHMARK_REGISTRY) 470 471 selected_benchmarks = [] 472 # Match benchmarks in registry against regex 473 for benchmark in registry: 474 benchmark_name = "%s.%s" % (benchmark.__module__, benchmark.__name__) 475 attrs = dir(benchmark) 476 # Don't instantiate the benchmark class unless necessary 477 benchmark_instance = None 478 479 for attr in attrs: 480 if not attr.startswith("benchmark"): 481 continue 482 candidate_benchmark_fn = getattr(benchmark, attr) 483 if not callable(candidate_benchmark_fn): 484 continue 485 full_benchmark_name = "%s.%s" % (benchmark_name, attr) 486 if regex == "all" or re.search(regex, full_benchmark_name): 487 selected_benchmarks.append(full_benchmark_name) 488 # Instantiate the class if it hasn't been instantiated 489 benchmark_instance = benchmark_instance or benchmark() 490 # Get the method tied to the class 491 instance_benchmark_fn = getattr(benchmark_instance, attr) 492 # Call the instance method 493 instance_benchmark_fn() 494 495 if not selected_benchmarks: 496 raise ValueError("No benchmarks matched the pattern: '{}'".format(regex)) 497 498 499def benchmarks_main(true_main, argv=None): 500 """Run benchmarks as declared in argv. 501 502 Args: 503 true_main: True main function to run if benchmarks are not requested. 504 argv: the command line arguments (if None, uses sys.argv). 505 """ 506 if argv is None: 507 argv = sys.argv 508 found_arg = [arg for arg in argv 509 if arg.startswith("--benchmarks=") 510 or arg.startswith("-benchmarks=")] 511 if found_arg: 512 # Remove --benchmarks arg from sys.argv 513 argv.remove(found_arg[0]) 514 515 regex = found_arg[0].split("=")[1] 516 app.run(lambda _: _run_benchmarks(regex), argv=argv) 517 else: 518 true_main() 519