"""Microbenchmarks for quantized batchnorm operator."""

import operator_benchmark as op_bench

import torch

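# A single "short" configuration: one (M, N, K) shape, cross-producted with
# device and quantized dtype by op_bench.config_list.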
batchnorm_configs_short = op_bench.config_list(
    attr_names=["M", "N", "K"],
    attrs=[
        [1, 256, 3136],
    ],
    cross_product_configs={
        "device": ["cpu"],
        "dtype": (torch.qint8,),
    },
    tags=["short"],
)

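# Shared benchmark base: quantizes a random float input and stores the
# remaining batch_norm arguments (running stats, affine params, output qparams).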
class QBatchNormBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, M, N, K, device, dtype):
        self._init(M, N, K, device)
        x_scale = 0.1
        x_zero_point = 0
        self.inputs = {
            "q_input_one": torch.quantize_per_tensor(
                self.input_one, scale=x_scale, zero_point=x_zero_point, dtype=dtype
            ),
            "mean": torch.rand(N),
            "var": torch.rand(N),
            "weight": torch.rand(N),
            "bias": torch.rand(N),
            "eps": 1e-5,
            "Y_scale": 0.1,
            "Y_zero_point": 0,
        }

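    # Hooks for subclasses: _init builds self.input_one with the rank the op
    # expects, and forward dispatches to the matching quantized kernel.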
    def _init(self, M, N, K, device):
        pass

    def forward(self):
        pass

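# 1d variant: feeds a rank-3 (M, N, K) input, which quantized.batch_norm1d
# treats as (batch, channels, length); the per-channel stats above are sized N
# to match dimension 1.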
class QBatchNorm1dBenchmark(QBatchNormBenchmark):
    def _init(self, M, N, K, device):
        self.set_module_name("QBatchNorm1d")
        self.input_one = torch.rand(
            M, N, K, device=device, requires_grad=self.auto_set()
        )

    def forward(
        self,
        q_input_one,
        weight,
        bias,
        mean,
        var,
        eps: float,
        Y_scale: float,
        Y_zero_point: int,
    ):
        return torch.ops.quantized.batch_norm1d(
            q_input_one, weight, bias, mean, var, eps, Y_scale, Y_zero_point
        )

class QBatchNorm2dBenchmark(QBatchNormBenchmark):
    def _init(self, M, N, K, device):
        self.set_module_name("QBatchNorm2d")
        # Note: the quantized implementation requires a rank-4 input, which is
        # why we add a 1 as the last dimension
        self.input_one = torch.rand(
            M, N, K, 1, device=device, requires_grad=self.auto_set()
        )

    def forward(
        self,
        q_input_one,
        weight,
        bias,
        mean,
        var,
        eps: float,
        Y_scale: float,
        Y_zero_point: int,
    ):
        return torch.ops.quantized.batch_norm2d(
            q_input_one, weight, bias, mean, var, eps, Y_scale, Y_zero_point
        )

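# Register one generated test per (config, benchmark class) pair.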
op_bench.generate_pt_test(batchnorm_configs_short, QBatchNorm1dBenchmark)
op_bench.generate_pt_test(batchnorm_configs_short, QBatchNorm2dBenchmark)

if __name__ == "__main__":
    op_bench.benchmark_runner.main()
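# Example invocation, a sketch assuming the standard operator_benchmark runner
# flags (the file name is hypothetical; tag filtering selects the "short"
# configs defined above):
#   python qbatchnorm_test.py --tag_filter short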