"""Microbenchmarks for quantized batchnorm operator."""

import operator_benchmark as op_bench

import torch


# One "short" configuration: the (M, N, K) triple below is crossed with the
# listed device and quantized dtype.
batchnorm_configs_short = op_bench.config_list(
    attr_names=["M", "N", "K"],
    attrs=[
        [1, 256, 3136],
    ],
    cross_product_configs={
        "device": ["cpu"],
        "dtype": (torch.qint8,),
    },
    tags=["short"],
)


class QBatchNormBenchmark(op_bench.TorchBenchmarkBase):
    """Shared input setup for the quantized batch-norm benchmarks.

    Subclasses override ``_init`` (which must assign ``self.input_one``)
    and ``forward`` (which calls the operator being measured).
    """

    def init(self, M, N, K, device, dtype):
        """Populate ``self.inputs`` for one benchmark configuration.

        Args:
            M, N, K: Size parameters forwarded to ``_init``; ``N`` is also
                the length of the ``mean``/``var``/``weight``/``bias``
                tensors.
            device: Device on which the tensors are created.
            dtype: Quantized dtype passed to ``torch.quantize_per_tensor``.
        """
        # The subclass hook creates the float tensor ``self.input_one``.
        self._init(M, N, K, device)
        x_scale = 0.1
        x_zero_point = 0
        self.inputs = {
            "q_input_one": torch.quantize_per_tensor(
                self.input_one, scale=x_scale, zero_point=x_zero_point, dtype=dtype
            ),
            # Create the length-N tensors on the same device as the input so
            # the configured device applies to every tensor argument.
            "mean": torch.rand(N, device=device),
            "var": torch.rand(N, device=device),
            "weight": torch.rand(N, device=device),
            "bias": torch.rand(N, device=device),
            "eps": 1e-5,
            "Y_scale": 0.1,
            "Y_zero_point": 0,
        }

    def _init(self, M, N, K, device):
        """Hook for subclasses: set the module name and ``self.input_one``."""
        pass

    def forward(self):
        """Hook for subclasses: call the operator under test."""
        pass


class QBatchNorm1dBenchmark(QBatchNormBenchmark):
    """Benchmark for ``torch.ops.quantized.batch_norm1d``."""

    def _init(self, M, N, K, device):
        """Create a random float input of shape ``(M, N, K)``."""
        self.set_module_name("QBatchNorm1d")
        self.input_one = torch.rand(
            M, N, K, device=device, requires_grad=self.auto_set()
        )

    # NOTE(review): the explicit scalar annotations look intended for
    # TorchScript; confirm before removing them.
    def forward(
        self,
        q_input_one,
        weight,
        bias,
        mean,
        var,
        eps: float,
        Y_scale: float,
        Y_zero_point: int,
    ):
        """Run the quantized 1d batch-norm op and return its result."""
        return torch.ops.quantized.batch_norm1d(
            q_input_one, weight, bias, mean, var, eps, Y_scale, Y_zero_point
        )


class QBatchNorm2dBenchmark(QBatchNormBenchmark):
    """Benchmark for ``torch.ops.quantized.batch_norm2d``."""

    def _init(self, M, N, K, device):
        """Create a random float input of shape ``(M, N, K, 1)``."""
        self.set_module_name("QBatchNorm2d")
        # Note: quantized implementation requires rank 4, which is why we
        # add a 1 as the last dimension
        self.input_one = torch.rand(
            M, N, K, 1, device=device, requires_grad=self.auto_set()
        )

    # NOTE(review): the explicit scalar annotations look intended for
    # TorchScript; confirm before removing them.
    def forward(
        self,
        q_input_one,
        weight,
        bias,
        mean,
        var,
        eps: float,
        Y_scale: float,
        Y_zero_point: int,
    ):
        """Run the quantized 2d batch-norm op and return its result."""
        return torch.ops.quantized.batch_norm2d(
            q_input_one, weight, bias, mean, var, eps, Y_scale, Y_zero_point
        )


# Register one PyTorch benchmark per operator variant, 1d first and then 2d,
# all sharing the short configuration list.
for _bench_cls in (QBatchNorm1dBenchmark, QBatchNorm2dBenchmark):
    op_bench.generate_pt_test(batchnorm_configs_short, _bench_cls)

# Run the registered benchmarks when this file is executed as a script.
if __name__ == "__main__":
    op_bench.benchmark_runner.main()