# mypy: allow-untyped-defs
"""Example of the Timer and Fuzzer APIs:

$ python -m examples.fuzzer
"""

import sys

import torch.utils.benchmark as benchmark_utils


def main(*, n: int = 250, min_run_time: float = 0.1):
    """Benchmark ``x + y`` over fuzzed tensor shapes/layouts and report timings.

    Draws ``n`` (x, y) tensor pairs from a seeded ``Fuzzer`` and times
    ``x + y`` for each pair with ``Timer.blocked_autorange``. It then prints
    the 15 fastest and 15 slowest configurations, ranked by median time
    per element.

    Args:
        n: Number of fuzzed configurations to benchmark. Defaults to 250,
            which was previously hard-coded.
        min_run_time: Seconds forwarded as ``min_run_time`` to
            ``blocked_autorange`` for every measurement. Defaults to 0.1,
            which was previously hard-coded.
    """
    add_fuzzer = benchmark_utils.Fuzzer(
        parameters=[
            [
                benchmark_utils.FuzzedParameter(
                    name=f"k{i}",
                    minval=16,
                    maxval=16 * 1024,
                    distribution="loguniform",
                ) for i in range(3)
            ],
            benchmark_utils.FuzzedParameter(
                name="d",
                distribution={2: 0.6, 3: 0.4},
            ),
        ],
        tensors=[
            [
                # NOTE(review): `dim_parameter="d"` presumably selects how many
                # of the three sizes (k0, k1, k2) are used (d is 2 or 3);
                # confirm against the FuzzedTensor docs.
                benchmark_utils.FuzzedTensor(
                    name=name,
                    size=("k0", "k1", "k2"),
                    dim_parameter="d",
                    probability_contiguous=0.75,
                    min_elements=64 * 1024,
                    max_elements=128 * 1024,
                ) for name in ("x", "y")
            ],
        ],
        seed=0,
    )

    measurements = []
    for i, (tensors, tensor_properties, _) in enumerate(add_fuzzer.take(n=n)):
        x, x_order = tensors["x"], str(tensor_properties["x"]["order"])
        y, y_order = tensors["y"], str(tensor_properties["y"]["order"])
        # Use a distinct name (not `i`) so the loop counter used in the
        # progress line below is not visually shadowed.
        shape = ", ".join(f"{size:>4}" for size in x.shape)

        description = "".join([
            f"{x.numel():>7} | {shape:<16} | ",
            f"{'contiguous' if x.is_contiguous() else x_order:<12} | ",
            f"{'contiguous' if y.is_contiguous() else y_order:<12} | ",
        ])

        timer = benchmark_utils.Timer(
            stmt="x + y",
            globals=tensors,
            description=description,
        )

        measurements.append(timer.blocked_autorange(min_run_time=min_run_time))
        # Stash the element count so results can be normalized per element.
        measurements[-1].metadata = {"numel": x.numel()}
        # Single-line progress counter: "\r" rewinds to the start of the line.
        print(f"\r{i + 1} / {n}", end="")
        sys.stdout.flush()
    print()

    # More string munging to make pretty output.
    # 1 / (1 - rejection_rate) converts the fraction of rejected draws into
    # the average number of draws needed per accepted configuration.
    print(f"Average attempts per valid config: {1. / (1. - add_fuzzer.rejection_rate):.1f}")

    def time_fn(m):
        # Median seconds per element; printed below in nanoseconds (* 1e9).
        return m.median / m.metadata["numel"]
    measurements.sort(key=time_fn)

    template = f"{{:>6}}{' ' * 19}Size    Shape{' ' * 13}X order        Y order\n{'-' * 80}"
    print(template.format("Best:"))
    for m in measurements[:15]:
        print(f"{time_fn(m) * 1e9:>4.1f} ns / element     {m.description}")

    print("\n" + template.format("Worst:"))
    for m in measurements[-15:]:
        print(f"{time_fn(m) * 1e9:>4.1f} ns / element     {m.description}")


if __name__ == "__main__":
    main()