• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import itertools
2
3from benchmark_helper import time_with_torch_timer
4
5import torch
6import torch._dynamo
7
8
@torch._dynamo.optimize("inductor", nopython=True)
def inductor_scatter_add(dst, src, index):
    """Scatter-add ``src`` into a copy of ``dst`` along dim 1, compiled by TorchDynamo/Inductor.

    Computes the same result as the eager ``torch_scatter_add`` in this file; the
    decorator routes it through the "inductor" backend with ``nopython=True``
    (whole-graph capture rather than falling back to eager on a graph break).
    ``dst`` is not modified; a new tensor is returned.
    """
    return torch.scatter_add(dst, dim=1, index=index, src=src)
12
13
def torch_scatter_add(dst, src, index):
    """Eager-mode reference: scatter-add ``src`` into a copy of ``dst`` along dim 1.

    For 2-D inputs this computes ``out[i][index[i][j]] += src[i][j]`` for every
    position ``(i, j)`` of ``index``, starting from ``out = dst.clone()``.
    ``dst`` itself is left unmodified.
    """
    return dst.scatter_add(1, index, src)
16
17
def _mean_ms(fn, args):
    """Return the mean run time of ``fn(*args)`` in milliseconds via ``time_with_torch_timer``."""
    # Assumes ``time_with_torch_timer(...).mean`` is in seconds (hence ``* 1000``), as the
    # original ``*_ms`` naming implies — TODO confirm against benchmark_helper.
    return time_with_torch_timer(fn, args).mean * 1000


def test_total_time(shapes, types):
    """Benchmark eager vs. Inductor ``scatter_add`` on CPU and print a ``;``-separated table.

    For every ``(shape, dtype)`` pair, random ``src``/``dst``/``index`` tensors are built
    (seeded with 1), the eager and compiled results are checked for equality, and the
    mean times in ms are printed for a random index and for a worst-case index where
    every element of a row collides on column 0.

    Args:
        shapes: iterable of 2-D shapes; dim 1 is the scatter dimension.
        types: iterable of torch dtypes used for ``src`` and ``dst``.
    """
    print(
        "shape; type; torch scatter_add; inductor scatter_add; torch scatter_add (worst case); inductor scatter_add (worst case)"
    )
    for shape, dtype in itertools.product(shapes, types):
        print(shape, dtype, sep="; ", end="; ")

        torch.manual_seed(1)
        if dtype.is_floating_point:
            src = torch.randn(shape, device="cpu", dtype=dtype)
            dst = torch.randn(shape, device="cpu", dtype=dtype)
        else:
            src = torch.randint(0, shape[1], shape, device="cpu", dtype=dtype)
            dst = torch.randint(0, shape[1], shape, device="cpu", dtype=dtype)
        index = torch.randint(0, shape[1], shape, device="cpu", dtype=torch.int64)
        # Worst case: all writes in a row hit column 0. Shaped like ``index`` so every row
        # is covered for any shape[0] (a hard-coded single row only covered row 0).
        # NOTE(review): for integer dtypes with large shape[1] the column-0 sum can exceed
        # int32 range; these results are only timed, never compared.
        worst_index = torch.zeros_like(index)

        torch_result = torch_scatter_add(dst, src, index)
        # First call also triggers compilation, so the timings below exclude compile time.
        inductor_result = inductor_scatter_add(dst, src, index)
        torch.testing.assert_close(torch_result, inductor_result)

        torch_ms = _mean_ms(torch_scatter_add, (dst, src, index))
        inductor_ms = _mean_ms(inductor_scatter_add, (dst, src, index))
        torch_worst_ms = _mean_ms(torch_scatter_add, (dst, src, worst_index))
        inductor_worst_ms = _mean_ms(inductor_scatter_add, (dst, src, worst_index))

        print(torch_ms, inductor_ms, torch_worst_ms, inductor_worst_ms, sep="; ")

        # Drop Dynamo's compiled artifacts so the next configuration compiles fresh.
        torch._dynamo.reset()
57
58
if __name__ == "__main__":
    # Each shape is [rows, cols]; the benchmark scatters along dim 1 (cols).
    shapes = [
        [1, 4096],
        [1, 65536],
    ]
    types = [torch.float32, torch.int32]

    print("test total time")
    test_total_time(shapes=shapes, types=types)
70
71# Sample results from a CPU run on a 5800H (all timing columns in milliseconds)
72"""
73test total time
74shape; type; torch scatter_add; inductor scatter_add; torch scatter_add (worst case); inductor scatter_add (worst case)
75[1, 4096]; torch.float32; 0.14733232000025964; 0.05388864999986254; 0.1451428800010035; 0.06496850000075938
76[1, 4096]; torch.int32; 0.1440268700002889; 0.05882900999949925; 0.1429359899998417; 0.07036211000013282
77[1, 65536]; torch.float32; 1.3435545300012564; 0.15207924000151252; 1.2523296799986383; 3.1408327299982375
78[1, 65536]; torch.int32; 1.3407247500003905; 0.12999147000073208; 1.2956029100018895; 0.853825209999286
79"""
80