import itertools

from benchmark_helper import time_with_torch_timer

import torch
import torch._dynamo


@torch._dynamo.optimize("inductor", nopython=True)
def inductor_scatter_add(dst, src, index):
    """Inductor-compiled ``torch.scatter_add`` along dimension 1.

    Same computation as ``torch_scatter_add``; the decorator routes it through
    TorchDynamo with the "inductor" backend (``nopython=True`` is requested so
    no part of the function silently falls back to eager execution).
    """
    return torch.scatter_add(dst, 1, index, src)


def torch_scatter_add(dst, src, index):
    """Eager baseline: ``torch.scatter_add`` of ``src`` into ``dst`` along dim 1."""
    return torch.scatter_add(dst, 1, index, src)


def _mean_ms(fn, args):
    """Return the mean time of ``fn(*args)`` as reported by the timer, times 1000.

    NOTE(review): presumably ``time_with_torch_timer`` returns a measurement
    whose ``.mean`` is in seconds, making this milliseconds -- confirm against
    ``benchmark_helper``.
    """
    return time_with_torch_timer(fn, args).mean * 1000


def test_total_time(shapes, types):
    """Benchmark eager vs. inductor ``scatter_add`` and print one row per case.

    For every ``(shape, dtype)`` combination this builds random CPU inputs,
    checks that the eager and compiled results agree on a random index, and
    then times both implementations on that random index and on an all-zero
    index (reported as the "worst case" columns).

    Args:
        shapes: iterable of 2-element shapes ``[rows, cols]``.
        types: iterable of ``torch.dtype`` values to benchmark.
    """
    print(
        "shape; type; torch scatter_add; inductor scatter_add; torch scatter_add (worst case); inductor scatter_add (worst case)"
    )
    for shape, dtype in itertools.product(shapes, types):
        print(shape, dtype, sep="; ", end="; ")

        # Re-seed per case so every configuration sees reproducible inputs.
        torch.manual_seed(1)
        if dtype.is_floating_point:
            src = torch.randn(shape, device="cpu", dtype=dtype)
            dst = torch.randn(shape, device="cpu", dtype=dtype)
        else:
            src = torch.randint(0, shape[1], shape, device="cpu", dtype=dtype)
            dst = torch.randint(0, shape[1], shape, device="cpu", dtype=dtype)
        index = torch.randint(0, shape[1], shape, device="cpu", dtype=torch.int64)
        # Every element targets column 0 of its own row. Building the index
        # with the full ``shape`` (instead of a hard-coded single row) keeps
        # the worst case meaningful when ``shape[0] > 1``; for ``[1, N]``
        # shapes it is identical to the previous single-row tensor.
        worst_index = torch.zeros(shape, device="cpu", dtype=torch.int64)

        # Correctness check; this first compiled call also triggers
        # compilation before any timing starts.
        torch_result = torch_scatter_add(dst, src, index)
        inductor_result = inductor_scatter_add(dst, src, index)
        torch.testing.assert_close(torch_result, inductor_result)

        torch_ms = _mean_ms(torch_scatter_add, (dst, src, index))
        inductor_ms = _mean_ms(inductor_scatter_add, (dst, src, index))
        torch_worst_ms = _mean_ms(torch_scatter_add, (dst, src, worst_index))
        inductor_worst_ms = _mean_ms(inductor_scatter_add, (dst, src, worst_index))

        print(torch_ms, inductor_ms, torch_worst_ms, inductor_worst_ms, sep="; ")

        # Drop compiled artifacts so the next shape/dtype starts from a
        # clean compilation state.
        torch._dynamo.reset()


if __name__ == "__main__":
    shapes = [
        [1, 4096],
        [1, 65536],
    ]
    types = [
        torch.float32,
        torch.int32,
    ]
    print("test total time")
    test_total_time(shapes, types)

# Results preview on 5800H
"""
test total time
shape; type; torch scatter_add; inductor scatter_add; torch scatter_add (worst case); inductor scatter_add (worst case)
[1, 4096]; torch.float32; 0.14733232000025964; 0.05388864999986254; 0.1451428800010035; 0.06496850000075938
[1, 4096]; torch.int32; 0.1440268700002889; 0.05882900999949925; 0.1429359899998417; 0.07036211000013282
[1, 65536]; torch.float32; 1.3435545300012564; 0.15207924000151252; 1.2523296799986383; 3.1408327299982375
[1, 65536]; torch.int32; 1.3407247500003905; 0.12999147000073208; 1.2956029100018895; 0.853825209999286
"""