import operator_benchmark as op_bench

import torch


"""Microbenchmarks for the quantized interpolate op.

Note: We are not benchmarking `upsample`, as it is being deprecated and calls
`interpolate` anyway.
"""

qinterpolate_long_configs = op_bench.config_list(
    attr_names=["M", "N", "K"],
    attrs=[
        [512, 512, 512],
    ],
    cross_product_configs={
        "dtype": [torch.quint8, torch.qint8, torch.qint32],
        "mode": ["nearest", "bilinear"],
        "scale": [0.5, 1.0, 2.0],
        "contig": [True],  # TODO: Add `False` after #29435
    },
    tags=["long"],
)


qinterpolate_short_configs = op_bench.config_list(
    attr_names=["M", "N", "K", "dtype", "mode", "scale", "contig"],
    attrs=[
        [32, 32, 32, torch.quint8, "nearest", 0.5, True],  # Downsample
        [32, 32, 32, torch.quint8, "bilinear", 0.5, True],  # Downsample
        [32, 32, 32, torch.quint8, "nearest", 2.0, True],  # Upsample
        [32, 32, 32, torch.quint8, "bilinear", 2.0, True],  # Upsample
        [3, 720, 1280, torch.quint8, "bilinear", 0.83333, True],  # Downsample
    ],
    tags=["short"],
)


class QInterpolateBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, M, N, K, dtype, mode, scale, contig):
        f_input = (torch.rand(1, M, N, K) - 0.5) * 256
        # Quantization parameters for the input tensor. Use a distinct name so
        # the `scale` config argument (the interpolation scale factor) is not
        # shadowed; otherwise `scale_factor` below would always be 0.1.
        q_scale = 0.1
        zero_point = 42
        self.q_input = torch.quantize_per_tensor(
            f_input, scale=q_scale, zero_point=zero_point, dtype=dtype
        )
        if not contig:
            # Make the input non-contiguous by reversing its dimension order.
            permute_dims = list(range(self.q_input.ndim))[::-1]
            self.q_input = self.q_input.permute(permute_dims)

        self.inputs = {"q_input": self.q_input, "scale_factor": scale, "mode": mode}
        self.set_module_name("q_interpolate")

    def forward(self, q_input, scale_factor: float, mode: str):
        return torch.nn.functional.interpolate(
            q_input, scale_factor=scale_factor, mode=mode
        )


op_bench.generate_pt_test(
    qinterpolate_short_configs + qinterpolate_long_configs, QInterpolateBenchmark
)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()
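
# A minimal standalone sketch of the op under test, separate from the harness
# above. Assumes a PyTorch build with eager-mode quantization; shapes and
# quantization parameters here are illustrative only:
#
#   import torch
#   import torch.nn.functional as F
#
#   x = torch.quantize_per_tensor(
#       torch.rand(1, 3, 32, 32), scale=0.1, zero_point=42, dtype=torch.quint8
#   )
#   y = F.interpolate(x, scale_factor=2.0, mode="nearest")
#   assert y.shape == (1, 3, 64, 64)  # spatial dims are upsampled 2x
#
# To run only a subset of these benchmarks, pass a tag filter to the runner,
# e.g. `python -m pt.qinterpolate_test --tag-filter short`. The exact flag
# spelling (`--tag-filter` vs. `--tag_filter`) may vary across PyTorch
# versions; check `benchmark_runner`'s --help output.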