"""Microbenchmarks for the instance_norm operator.

Times ``torch.nn.functional.instance_norm`` with a per-channel weight and
bias on rank-3 and rank-4 inputs, using the ``operator_benchmark`` harness.
"""

import operator_benchmark as op_bench

import torch
import torch.nn.functional as F


# Input shapes benchmarked under the "short" tag. For every shape, dims[1] is
# used as the channel count for the weight/bias tensors built in
# InstanceNormBenchmark.init.
instancenorm_configs_short = op_bench.cross_product_configs(
    dims=(
        (32, 8, 16),
        (32, 8, 56, 56),
    ),
    tags=["short"],
)


class InstanceNormBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark ``F.instance_norm`` with a weight and bias of size dims[1]."""

    def init(self, dims):
        """Build the benchmark inputs for one input shape.

        Args:
            dims: full input shape; ``dims[1]`` is taken as the channel count.
        """
        num_channels = dims[1]
        # Keys mirror the parameter names of forward() below.
        self.inputs = {
            # torch.rand samples [0, 1); shifting by -0.5 and scaling by 256
            # gives input values in [-128, 128).
            "input": (torch.rand(*dims) - 0.5) * 256,
            "weight": torch.rand(num_channels, dtype=torch.float),
            "bias": torch.rand(num_channels, dtype=torch.float),
            "eps": 1e-5,
        }

    # NOTE: `input` shadows the builtin, but it must match the "input" key in
    # self.inputs, so it is kept as-is.
    def forward(self, input, weight, bias, eps: float):
        """Run a single instance_norm call on the prepared inputs."""
        return F.instance_norm(input, weight=weight, bias=bias, eps=eps)


op_bench.generate_pt_test(instancenorm_configs_short, InstanceNormBenchmark)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()