"""Microbenchmarks for batchnorm operator."""

import operator_benchmark as op_bench

import torch
import torch.nn.functional as F


# Benchmark cudnn if available
if torch.backends.cudnn.is_available():

    def cudnn_benchmark_configs(configs):
        # CUDA configs are benchmarked twice, once with cudnn enabled and once
        # with it disabled; CPU configs only ever run with cudnn disabled.
        result = []
        for config in configs:
            is_cuda = any("cuda" in attr.values() for attr in config)
            if is_cuda:
                result.append((*config, dict(cudnn=True)))
            result.append((*config, dict(cudnn=False)))
        return result

else:

    def cudnn_benchmark_configs(configs):
        return [(*config, dict(cudnn=False)) for config in configs]

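# A sketch of what cudnn_benchmark_configs does, assuming op_bench's usual
# config layout (each config is a sequence of single-key dicts):
#
#   cfg = ({"M": 1}, {"N": 256}, {"K": 3136}, {"device": "cuda"}, {"training": True})
#   cudnn_benchmark_configs([cfg])
#   # -> [(*cfg, {"cudnn": True}), (*cfg, {"cudnn": False})]
#
# A CPU config would appear once, with {"cudnn": False} only.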

batchnorm_configs_short = cudnn_benchmark_configs(
    op_bench.config_list(
        attr_names=["M", "N", "K"],
        attrs=[
            [1, 256, 3136],
        ],
        cross_product_configs={
            "device": ["cpu", "cuda"],
            "training": [True, False],
        },
        tags=["short"],
    )
)

batchnorm_configs_long = cudnn_benchmark_configs(
    op_bench.cross_product_configs(
        M=[2, 128],
        N=[8192, 2048],
        K=[1],
        device=["cpu", "cuda"],
        training=[True, False],
        tags=["long"],
    )
)
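
# Each config produced above reaches the benchmark below as keyword arguments
# (M, N, K, device, training, cudnn), so the short and long lists can simply
# be concatenated and handed to the same benchmark class.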


class BatchNormBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, M, N, K, device, training, cudnn):
        self.inputs = {
            "input_one": torch.rand(
                M, N, K, device=device, requires_grad=self.auto_set()
            ),
            "mean": torch.rand(N, device=device),
            "var": torch.rand(N, device=device),
            "weight": torch.rand(N, device=device),
            "bias": torch.rand(N, device=device),
            "training": training,
            "cudnn": cudnn,
        }
        self.set_module_name("batchnorm")

    def forward(self, input_one, mean, var, weight, bias, training, cudnn):
        # Toggle cudnn per config so the cudnn and native paths are timed separately.
        with torch.backends.cudnn.flags(enabled=cudnn):
            return F.batch_norm(input_one, mean, var, weight, bias, training)

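# Note: F.batch_norm normalizes over every dimension except dim 1, so in the
# (M, N, K) input above M is the batch size, N is the channel count, and K is
# the per-channel spatial extent; mean, var, weight, and bias are per-channel
# vectors of length N.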
op_bench.generate_pt_test(
    batchnorm_configs_short + batchnorm_configs_long, BatchNormBenchmark
)
op_bench.generate_pt_gradient_test(
    batchnorm_configs_short + batchnorm_configs_long, BatchNormBenchmark
)
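
# generate_pt_test registers forward-only benchmarks, while
# generate_pt_gradient_test also measures backward; in the gradient variants
# op_bench's auto_set() returns True, which is what enables requires_grad on
# input_one above.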


batchnorm1d_configs_short = cudnn_benchmark_configs(
    op_bench.config_list(
        attr_names=["N", "C"],
        attrs=[
            [3136, 256],
        ],
        cross_product_configs={
            "device": ["cpu", "cuda"],
            "training": [True, False],
        },
        tags=["short"],
    )
)

batchnorm1d_configs_long = cudnn_benchmark_configs(
    op_bench.cross_product_configs(
        N=[2, 128],
        C=[8192, 2048],
        device=["cpu", "cuda"],
        training=[True, False],
        tags=["long"],
    )
)


class BatchNorm1dBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, N, C, device, training, cudnn):
        self.inputs = {
            "input_one": torch.rand(N, C, device=device, requires_grad=self.auto_set()),
            "mean": torch.rand(C, device=device),
            "var": torch.rand(C, device=device),
            "weight": torch.rand(C, device=device),
            "bias": torch.rand(C, device=device),
            "training": training,
            "cudnn": cudnn,
        }
        self.set_module_name("batchnorm1d")

    def forward(self, input_one, mean, var, weight, bias, training, cudnn):
        with torch.backends.cudnn.flags(enabled=cudnn):
            return F.batch_norm(input_one, mean, var, weight, bias, training)


op_bench.generate_pt_test(
    batchnorm1d_configs_short + batchnorm1d_configs_long, BatchNorm1dBenchmark
)
op_bench.generate_pt_gradient_test(
    batchnorm1d_configs_short + batchnorm1d_configs_long, BatchNorm1dBenchmark
)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()
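

# Example invocation (a sketch: assumes this file lives at pt/batchnorm_test.py
# inside PyTorch's benchmarks/operator_benchmark suite, and flag spellings may
# differ across versions):
#
#   python -m pt.batchnorm_test --tag-filter short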