• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# mypy: allow-untyped-defs
2"""Example use of Timer and sparse op fuzzers to measure kernel performance.
3
4$ python -m examples.sparse.op_benchmark
5"""
6
7import numpy as np
8import torch
9
10from torch.utils.benchmark import Timer
11from torch.utils.benchmark.op_fuzzers.sparse_unary import UnaryOpSparseFuzzer
12from torch.utils.benchmark.op_fuzzers.sparse_binary import BinaryOpSparseFuzzer
13import operator
14
# Passed as `min_run_time` to `Timer.blocked_autorange` for every measurement
# taken in `run`.
_MEASURE_TIME = 1.0
16
def assert_dicts_equal(dict_0, dict_1):
    """Assert that two dicts have the same keys and elementwise-equal values.

    Builtin dict comparison will not compare numpy arrays.
    e.g.
        x = {"a": np.ones((2, 1))}
        x == x  # Raises ValueError

    Values are compared with ``np.all(v0 == v1)``. The value stored under the
    "dtype" key is not compared, although the key itself must be present in
    both dicts (or absent from both).

    Args:
        dict_0: First mapping.
        dict_1: Second mapping.

    Raises:
        AssertionError: If the key sets differ, or if any value other than
            the one under "dtype" differs between the two dicts.
    """
    # Raise explicitly rather than using `assert` so the checks are not
    # stripped when Python runs with -O.
    keys_0, keys_1 = set(dict_0), set(dict_1)
    if keys_0 != keys_1:
        mismatched = sorted(keys_0 ^ keys_1, key=str)
        raise AssertionError(f"Dict keys differ: {mismatched}")

    for key, value in dict_0.items():
        if key == "dtype":
            continue
        if not np.all(value == dict_1[key]):
            raise AssertionError(f"Dict values differ for key {key!r}")
25
def _shape_to_str(shape):
    """Render a shape as "(d0, d1, ...)", writing powers of two above 1 as "2 ** k"."""
    dims = []
    for dim in shape:
        dim = int(dim)
        # Integer power-of-two test: safe for 0 and 1, unlike log2.
        if dim > 1 and dim & (dim - 1) == 0:
            dims.append(f"2 ** {dim.bit_length() - 1}")
        else:
            dims.append(str(dim))
    return "(" + ", ".join(dims) + ")"


def run(n, stmt, fuzzer_cls):
    """Time `stmt` on float32 and float64 fuzzed inputs and print a comparison.

    Two fuzzers are built from `fuzzer_cls` with the same seed (0), one with
    ``dtype=torch.float32`` and one with ``dtype=torch.float64``. For each of
    the `n` paired samples, `stmt` is timed with `Timer.blocked_autorange`
    using the sample's tensors as globals. Results are sorted by the relative
    difference between the two medians, and the 10 smallest and 10 largest
    entries are printed as a table.

    Args:
        n: Number of samples to draw from each fuzzer.
        stmt: Statement to time; it is evaluated with the fuzzed tensors
            as its globals.
        fuzzer_cls: Fuzzer class, called as ``fuzzer_cls(seed=..., dtype=...)``
            and expected to expose ``take(n)`` yielding
            ``(tensors, tensor_params, params)`` triples.

    Raises:
        AssertionError: Propagated from `assert_dicts_equal` if the float32
            and float64 samples were generated with different parameters.
    """
    float_iter = fuzzer_cls(seed=0, dtype=torch.float32).take(n)
    double_iter = fuzzer_cls(seed=0, dtype=torch.float64).take(n)
    raw_results = []
    for index, (float_values, double_values) in enumerate(zip(float_iter, double_iter)):
        float_tensors, float_tensor_params, float_params = float_values
        double_tensors, double_tensor_params, double_params = double_values

        # Both fuzzers share a seed, so apart from dtype the samples must match.
        assert_dicts_equal(float_params, double_params)
        assert_dicts_equal(float_tensor_params["x"], double_tensor_params["x"])

        float_measurement, double_measurement = (
            Timer(
                stmt,
                globals=tensors,
            ).blocked_autorange(min_run_time=_MEASURE_TIME)
            for tensors in (float_tensors, double_tensors)
        )

        descriptions = []
        for name in float_tensors:
            shape_str = _shape_to_str(float_tensors[name].shape)
            sparse_dim_str = str(float_tensor_params[name]["sparse_dim"])
            is_coalesced_str = "True" if float_tensor_params[name]["is_coalesced"] else "False"
            descriptions.append((name, shape_str, sparse_dim_str, is_coalesced_str))
        raw_results.append((float_measurement, double_measurement, descriptions))

        # Progress indicator, rewritten in place on a single line.
        print(f"\r{index + 1} / {n}", end="")
    print()

    # Compute the relative differences and the column widths for the table.
    parsed_results, name_len, shape_len, sparse_dim_len = [], 0, 0, 0
    for float_measurement, double_measurement, descriptions in raw_results:
        t_float = float_measurement.median * 1e6
        t_double = double_measurement.median * 1e6
        # Absolute difference relative to the mean of the two medians.
        rel_diff = abs(t_float - t_double) / (t_float + t_double) * 2
        parsed_results.append((t_float, t_double, rel_diff, descriptions))
        for name, shape, sparse_dim, _ in descriptions:
            name_len = max(name_len, len(name))
            shape_len = max(shape_len, len(shape))
            sparse_dim_len = max(sparse_dim_len, len(sparse_dim))

    parsed_results.sort(key=operator.itemgetter(2))

    print(f"stmt: {stmt}")
    print(f" diff    faster{'':>17}{' ' * name_len} ", end="")
    print(f"{'shape'.ljust(shape_len)}{'':>12}{'sparse_dim'.ljust(sparse_dim_len)}", end="")
    print(f"          is_coalesced\n{'-' * 100}")
    for results, spacer in [(parsed_results[:10], "..."), (parsed_results[-10:], "")]:
        for t_float, t_double, rel_diff, descriptions in results:
            # The second measurement uses float64 inputs, so label it "double".
            faster = "double" if t_double < t_float else "float"
            time_str = [f"{rel_diff * 100:>4.1f}%    {faster:<20}"]
            # Only the first row of an entry shows timing; pad the remaining rows.
            time_str.extend(["".ljust(len(time_str[0])) for _ in descriptions[:-1]])
            for t_str, (name, shape, sparse_dim, is_coalesced) in zip(time_str, descriptions):
                name = f"{name}:".ljust(name_len + 1)
                shape = shape.ljust(shape_len + 10)
                sparse_dim = sparse_dim.ljust(sparse_dim_len)
                print(f"{t_str} {name}  {shape}|     {sparse_dim}      |   {is_coalesced}")
        print(spacer)
91
92
def main():
    """Run the example benchmarks: two unary sparse ops and one binary op."""
    benchmarks = (
        ("torch.sparse.sum(x, dim=0)", UnaryOpSparseFuzzer),
        ("torch.sparse.softmax(x, dim=0)", UnaryOpSparseFuzzer),
        ("x + y", BinaryOpSparseFuzzer),
    )
    for stmt, fuzzer_cls in benchmarks:
        run(n=100, stmt=stmt, fuzzer_cls=fuzzer_cls)


# Only run the benchmarks when executed as a script, not on import.
if __name__ == "__main__":
    main()
101