• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# mypy: allow-untyped-defs
2"""Example of the Timer and Fuzzer APIs:
3
4$ python -m examples.fuzzer
5"""
6
7import sys
8
9import torch.utils.benchmark as benchmark_utils
10
11
def main():
    """Time ``x + y`` on fuzzed tensor pairs and print the fastest/slowest cases.

    A ``Fuzzer`` is configured to produce two tensors, ``x`` and ``y``, whose
    sizes are drawn from three shared log-uniform parameters (``k0``..``k2``)
    and which are kept between 64Ki and 128Ki elements. For each of 250
    generated configurations the statement ``x + y`` is timed with
    ``Timer.blocked_autorange``. Measurements are then ranked by median time
    per element, and the 15 best and 15 worst are printed as a table.

    Progress and results are written to stdout; nothing is returned.
    """
    # Three size parameters, each sampled log-uniformly from [16, 16384].
    size_params = [
        benchmark_utils.FuzzedParameter(
            name=f"k{i}",
            minval=16,
            maxval=16 * 1024,
            distribution="loguniform",
        )
        for i in range(3)
    ]
    # "d" takes the value 2 with weight 0.6 and 3 with weight 0.4. It is passed
    # to the tensors as `dim_parameter` -- presumably the number of dimensions
    # actually used from `size`; confirm against the Fuzzer documentation.
    dim_param = benchmark_utils.FuzzedParameter(
        name="d",
        distribution={2: 0.6, 3: 0.4},
    )
    # `x` and `y` share the same size parameters, so "x + y" is always
    # shape-compatible.
    fuzzed_tensors = [
        benchmark_utils.FuzzedTensor(
            name=name,
            size=("k0", "k1", "k2"),
            dim_parameter="d",
            probability_contiguous=0.75,
            min_elements=64 * 1024,
            max_elements=128 * 1024,
        )
        for name in ("x", "y")
    ]
    add_fuzzer = benchmark_utils.Fuzzer(
        parameters=[size_params, dim_param],
        tensors=[fuzzed_tensors],
        seed=0,  # fixed seed: the same configurations are generated every run
    )

    n = 250
    measurements = []
    for i, (tensors, tensor_properties, _) in enumerate(add_fuzzer.take(n=n)):
        x, x_order = tensors["x"], str(tensor_properties["x"]["order"])
        y, y_order = tensors["y"], str(tensor_properties["y"]["order"])
        shape = ", ".join(f"{dim:>4}" for dim in x.shape)

        # Show the reported dimension order only for non-contiguous tensors.
        x_layout = "contiguous" if x.is_contiguous() else x_order
        y_layout = "contiguous" if y.is_contiguous() else y_order
        description = (
            f"{x.numel():>7} | {shape:<16} | "
            f"{x_layout:<12} | "
            f"{y_layout:<12} | "
        )

        timer = benchmark_utils.Timer(
            stmt="x + y",
            globals=tensors,
            description=description,
        )

        measurement = timer.blocked_autorange(min_run_time=0.1)
        # Stash the element count so results can be normalised per element.
        measurement.metadata = {"numel": x.numel()}
        measurements.append(measurement)

        # Single-line progress counter: "\r" rewinds to the start of the line.
        print(f"\r{i + 1} / {n}", end="")
        sys.stdout.flush()
    print()

    print(f"Average attempts per valid config: {1. / (1. - add_fuzzer.rejection_rate):.1f}")

    def time_fn(m):
        """Return the median run time of `m` divided by its element count."""
        return m.median / m.metadata["numel"]

    measurements.sort(key=time_fn)

    # More string munging to make pretty output.
    template = f"{{:>6}}{' ' * 19}Size    Shape{' ' * 13}X order        Y order\n{'-' * 80}"

    def print_section(title, rows, leading=""):
        """Print the table header labelled `title`, then one line per row."""
        print(leading + template.format(title))
        for m in rows:
            print(f"{time_fn(m) * 1e9:>4.1f} ns / element     {m.description}")

    print_section("Best:", measurements[:15])
    print_section("Worst:", measurements[-15:], leading="\n")
84
# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
87