# mypy: allow-untyped-defs
"""Example of Timer and Compare APIs:

$ python -m examples.compare
"""

import pickle
import sys
import time

import torch

import torch.utils.benchmark as benchmark_utils

16class FauxTorch:
17    """Emulate different versions of pytorch.
18
19    In normal circumstances this would be done with multiple processes
20    writing serialized measurements, but this simplifies that model to
21    make the example clearer.
22    """
23    def __init__(self, real_torch, extra_ns_per_element):
24        self._real_torch = real_torch
25        self._extra_ns_per_element = extra_ns_per_element
26
27    def extra_overhead(self, result):
28        # time.sleep has a ~65 us overhead, so only fake a
29        # per-element overhead if numel is large enough.
30        numel = int(result.numel())
31        if numel > 5000:
32            time.sleep(numel * self._extra_ns_per_element * 1e-9)
33        return result
34
35    def add(self, *args, **kwargs):
36        return self.extra_overhead(self._real_torch.add(*args, **kwargs))
37
38    def mul(self, *args, **kwargs):
39        return self.extra_overhead(self._real_torch.mul(*args, **kwargs))
40
41    def cat(self, *args, **kwargs):
42        return self.extra_overhead(self._real_torch.cat(*args, **kwargs))
43
44    def matmul(self, *args, **kwargs):
45        return self.extra_overhead(self._real_torch.matmul(*args, **kwargs))
46
47
48def main():
49    tasks = [
50        ("add", "add", "torch.add(x, y)"),
51        ("add", "add (extra +0)", "torch.add(x, y + zero)"),
52    ]
53
54    serialized_results = []
55    repeats = 2
56    timers = [
57        benchmark_utils.Timer(
58            stmt=stmt,
59            globals={
60                "torch": torch if branch == "master" else FauxTorch(torch, overhead_ns),
61                "x": torch.ones((size, 4)),
62                "y": torch.ones((1, 4)),
63                "zero": torch.zeros(()),
64            },
65            label=label,
66            sub_label=sub_label,
67            description=f"size: {size}",
68            env=branch,
69            num_threads=num_threads,
70        )
71        for branch, overhead_ns in [("master", None), ("my_branch", 1), ("severe_regression", 5)]
72        for label, sub_label, stmt in tasks
73        for size in [1, 10, 100, 1000, 10000, 50000]
74        for num_threads in [1, 4]
75    ]
76
77    for i, timer in enumerate(timers * repeats):
78        serialized_results.append(pickle.dumps(
79            timer.blocked_autorange(min_run_time=0.05)
80        ))
81        print(f"\r{i + 1} / {len(timers) * repeats}", end="")
82        sys.stdout.flush()
83    print()
84
85    comparison = benchmark_utils.Compare([
86        pickle.loads(i) for i in serialized_results
87    ])
88
89    print("== Unformatted " + "=" * 80 + "\n" + "/" * 95 + "\n")
90    comparison.print()
91
92    print("== Formatted " + "=" * 80 + "\n" + "/" * 93 + "\n")
93    comparison.trim_significant_figures()
94    comparison.colorize()
95    comparison.print()
96
97
98if __name__ == "__main__":
99    main()
100