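"""Microbenchmark: torch.convolution with contiguous (NCHW) versus
channels-last (NHWC) inputs and weights on GPU.

Runs each variant once under the profiler (trace written to /tmp/chrome.json),
checks the outputs agree, then times both with inductor's benchmarker.
Requires CUDA.
"""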
import torch
from torch._inductor import ir
from torch._inductor.runtime.benchmarking import benchmarker


def to_channels_last(x):
    """Return a copy of a 4D NCHW tensor laid out channels-last in memory."""
    assert x.dim() == 4

    # Same logical NCHW shape, but NHWC (channels-last) memory layout:
    # stride_order[d] gives the rank of dim d's stride from innermost (0)
    # to outermost (3), so [3, 0, 2, 1] puts channels at stride 1.
    stride_order = [3, 0, 2, 1]
    y = x.clone().as_strided(
        x.shape,
        ir.FlexibleLayout.stride_ordered(x.shape, stride_order),
    )
    y.copy_(x)
    assert torch.allclose(x, y)
    return y
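

# A minimal eager-mode equivalent, included as a hedged sketch rather than as
# part of the original benchmark: the public memory-format API should yield
# the same layout without going through inductor's FlexibleLayout helper.
def to_channels_last_eager(x):
    assert x.dim() == 4
    y = x.contiguous(memory_format=torch.channels_last)
    assert y.is_contiguous(memory_format=torch.channels_last)
    return y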


def bench_conv(with_stack=True):
    # A 7x7/stride-2 conv over a 256x3x224x224 batch (ResNet-stem-like shapes).
    x = torch.rand(256, 3, 224, 224).cuda()
    weight = torch.rand(64, 3, 7, 7).cuda()

    x_chan = to_channels_last(x)
    weight_chan = to_channels_last(weight)
    kwargs = {
        "stride": [2, 2],
        "padding": [3, 3],
        "dilation": [1, 1],
        "transposed": False,
        "output_padding": [0, 0],
        "groups": 1,
    }
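
    # For reference: with these arguments, torch.convolution computes the same
    # result as torch.nn.functional.conv2d(x, weight, stride=2, padding=3).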

    def baseline_fn():
        return torch.convolution(x, weight, bias=None, **kwargs)

    def test_fn():
        return torch.convolution(x_chan, weight_chan, bias=None, **kwargs)

    # Warm up both paths so one-time costs (kernel selection, CUDA context
    # init) stay out of the profile.
    baseline_fn()
    test_fn()

    torch.cuda.synchronize()
    # Capture one run of each variant in a single profiler trace.
    with torch.profiler.profile(with_stack=with_stack) as p:
        baseline_out = baseline_fn()
        test_out = test_fn()
        torch.cuda.synchronize()
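
    # The exported trace can be viewed in chrome://tracing or ui.perfetto.dev.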
    p.export_chrome_trace("/tmp/chrome.json")
    # Different layouts can select different conv kernels, so compare with
    # loose tolerances rather than expecting bitwise-equal results.
    assert torch.allclose(baseline_out, test_out, atol=1e-3, rtol=1e-3), (
        baseline_out[0][0][0][:32],
        test_out[0][0][0][:32],
    )

    baseline_ms = benchmarker.benchmark_gpu(baseline_fn, rep=40)
    test_ms = benchmarker.benchmark_gpu(test_fn, rep=40)
    print(f"baseline {baseline_ms:.3f}ms test {test_ms:.3f}ms speedup {baseline_ms / test_ms:.3f}x")


def main():
    bench_conv()


if __name__ == "__main__":
    main()