"""Microbenchmarks for the quantized group norm operator.

Benchmarks ``torch.ops.quantized.group_norm`` on per-tensor quantized inputs
over the ``groupnorm_configs_short`` parameter grid.
"""

import operator_benchmark as op_bench

import torch


# Cross product of input shapes, group counts and quantized dtypes.
# dims[1] is used as the channel count by the benchmark below, so every
# num_groups value should divide it (8 is divisible by both 2 and 4).
groupnorm_configs_short = op_bench.cross_product_configs(
    dims=(
        (32, 8, 16),
        (32, 8, 56, 56),
    ),
    num_groups=(2, 4),
    dtype=(torch.qint8,),
    tags=["short"],
)


class QGroupNormBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark ``torch.ops.quantized.group_norm`` for one config point."""

    def init(self, dims, num_groups, dtype):
        """Build the quantized input and group-norm arguments for one config.

        Args:
            dims: Shape of the input tensor; ``dims[1]`` is used as the
                channel count for the affine ``weight`` and ``bias``.
            num_groups: Number of groups passed to the operator.
            dtype: Quantized dtype for the input (``torch.qint8`` in the
                configs above).
        """
        # torch.rand samples [0, 1), so X is uniform in [-128, 128); with
        # scale 1.0 and zero point 0 that covers the full qint8 range.
        X = (torch.rand(*dims) - 0.5) * 256
        num_channels = dims[1]
        scale = 1.0
        zero_point = 0

        # Keys mirror the parameter names (and order) of ``forward``.
        self.inputs = {
            "qX": torch.quantize_per_tensor(
                X, scale=scale, zero_point=zero_point, dtype=dtype
            ),
            "num_groups": num_groups,
            "weight": torch.rand(num_channels, dtype=torch.float),
            "bias": torch.rand(num_channels, dtype=torch.float),
            "eps": 1e-5,
            "Y_scale": 0.1,
            "Y_zero_point": 0,
        }

    def forward(
        self,
        qX: torch.Tensor,
        num_groups: int,
        weight: torch.Tensor,
        bias: torch.Tensor,
        eps: float,
        Y_scale: float,
        Y_zero_point: int,
    ):
        """Run a single quantized group norm call.

        ``Y_scale`` and ``Y_zero_point`` are forwarded as the operator's
        ``output_scale`` and ``output_zero_point``; the operator's result is
        returned unchanged.
        """
        return torch.ops.quantized.group_norm(
            qX,
            num_groups,
            weight=weight,
            bias=bias,
            eps=eps,
            output_scale=Y_scale,
            output_zero_point=Y_zero_point,
        )


op_bench.generate_pt_test(groupnorm_configs_short, QGroupNormBenchmark)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()