"""Microbenchmarks for groupnorm operator."""

import operator_benchmark as op_bench

import torch
import torch.nn.functional as F


# "short"-tagged configurations built by op_bench.cross_product_configs from
# the input shapes and group counts below. F.group_norm requires the channel
# count (dims[1] == 8 for both shapes) to be divisible by num_groups; 2 and 4
# both satisfy that.
groupnorm_configs_short = op_bench.cross_product_configs(
    dims=(
        (32, 8, 16),
        (32, 8, 56, 56),
    ),
    num_groups=(2, 4),
    tags=["short"],
)


class GroupNormBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark ``F.group_norm`` with random input, weight and bias tensors."""

    def init(self, dims, num_groups):
        """Prepare the tensors and arguments for one configuration.

        Args:
            dims: Shape of the input tensor. ``dims[1]`` is used as the
                channel count, so it needs at least two entries.
            num_groups: Number of groups passed to ``F.group_norm``.
        """
        num_channels = dims[1]
        # The keys mirror forward()'s parameter names; presumably the harness
        # passes this dict to forward() as keyword arguments (TODO confirm in
        # operator_benchmark), so they must not be renamed independently.
        self.inputs = {
            # torch.rand is uniform in [0, 1); shifting and scaling yields
            # values in [-128, 128).
            "input": (torch.rand(*dims) - 0.5) * 256,
            "num_groups": num_groups,
            # Per-channel affine parameters, one value per channel.
            "weight": torch.rand(num_channels, dtype=torch.float),
            "bias": torch.rand(num_channels, dtype=torch.float),
            "eps": 1e-5,
        }

    def forward(self, input, num_groups: int, weight, bias, eps: float):
        """Apply group normalization to ``input`` and return the result.

        The parameter is named ``input`` (shadowing the builtin) because it
        must match the corresponding key in ``self.inputs``.
        """
        return F.group_norm(input, num_groups, weight=weight, bias=bias, eps=eps)


# Generate the PyTorch benchmark cases for every short config.
op_bench.generate_pt_test(groupnorm_configs_short, GroupNormBenchmark)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()