import operator_benchmark as op_bench

import torch
import torch.ao.nn.quantized as nnq
import torch.ao.quantization as tq
import torch.nn as nn


"""Microbenchmarks for general quantization operations."""

# 'mode' selects the direction of the benchmark:
# 'Q' benchmarks quantization, 'D' benchmarks dequantization.

quantize_configs_short_dict = {
    "attr_names": ["C", "M", "N", "dtype", "mode"],
    "attrs": [
        [3, 512, 512, torch.quint8, "Q"],
        [3, 512, 512, torch.quint8, "D"],
    ],
    "tags": ["short"],
}

quantize_configs_long_dict = {
    "C": [3, 5, 8],  # reused for per-channel: avoids a single-channel test
    "M": [256, 1024],
    "N": [256, 1024],
    "dtype": [torch.quint8, torch.qint8, torch.qint32],
    "mode": ["D", "Q"],
    "tags": ["long"],
}


quantize_per_tensor_configs_short = op_bench.config_list(**quantize_configs_short_dict)

quantize_per_tensor_configs_long = op_bench.cross_product_configs(
    **quantize_configs_long_dict
)


class QuantizePerTensorBenchmark(op_bench.TorchBenchmarkBase):
    r"""Benchmarks both quantization and dequantization."""

    def init(self, C, M, N, dtype, mode):
        assert mode in ("Q", "D")
        self.input = torch.rand(C, M, N)
        self.dtype = dtype
        self.op = nnq.Quantize(scale=1.0, zero_point=0, dtype=dtype)
        self.set_module_name("QuantizePerTensor")

        if mode == "D":
            self.input = self.op(self.input)
            self.op = nnq.DeQuantize()
            self.set_module_name("DequantizePerTensor")

        self.inputs = {"input": self.input}

    def forward(self, input):
        return self.op(input)


op_bench.generate_pt_test(
    quantize_per_tensor_configs_short + quantize_per_tensor_configs_long,
    QuantizePerTensorBenchmark,
)

# === Per Channel quantization ===

quantize_per_channel_configs_short = op_bench.config_list(
    cross_product_configs={"axis": (0,)}, **quantize_configs_short_dict
)

quantize_per_channel_configs_long = op_bench.cross_product_configs(
    axis=(0, 1, 2), **quantize_configs_long_dict
)


class QuantizePerChannelBenchmark(op_bench.TorchBenchmarkBase):
    r"""Benchmarks both quantization and dequantization."""

    def init(self, C, M, N, dtype, axis, mode):
        assert mode in ("Q", "D")
        self.input = torch.rand(C, M, N)
        self.op = torch.quantize_per_channel

        channel_len = (C, M, N)[axis]

        self.kwargs = {
            "scales": torch.tensor([1.0] * channel_len),
            "zero_points": torch.tensor([0] * channel_len),
            "dtype": dtype,
            "axis": axis,
        }

        self.set_module_name("QuantizePerChannel")

        if mode == "D":
            self.input = self.op(self.input, **self.kwargs)

            def dequant(input, scales, zero_points, axis: int, dtype: int):
                return input.dequantize()

            self.op = dequant
            self.set_module_name("DequantizePerChannel")

        self.inputs = {
            "input": self.input,
            "scales": torch.tensor([1.0] * channel_len),
            "zero_points": torch.tensor([0] * channel_len),
            "axis": axis,
            "dtype": dtype,
        }

    def forward(self, input, scales, zero_points, axis: int, dtype: int):
        return self.op(
            input, scales=scales, zero_points=zero_points, axis=axis, dtype=dtype
        )


op_bench.generate_pt_test(
    quantize_per_channel_configs_short + quantize_per_channel_configs_long,
    QuantizePerChannelBenchmark,
)
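
# A minimal standalone sketch of the ops the two benchmarks above time.
# `_example_quantize_roundtrip` is a hypothetical helper added purely for
# illustration; the benchmark harness never calls it.
def _example_quantize_roundtrip():
    x = torch.rand(3, 512, 512)
    # Per-tensor: a single scale/zero_point pair for the whole tensor
    # (this is what nnq.Quantize / nnq.DeQuantize wrap).
    q = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8)
    # Per-channel: one scale/zero_point per slice along `axis`.
    qc = torch.quantize_per_channel(
        x,
        scales=torch.ones(3),
        zero_points=torch.zeros(3, dtype=torch.long),
        axis=0,
        dtype=torch.quint8,
    )
    # Dequantization maps both back to ordinary float tensors.
    return q.dequantize(), qc.dequantize()
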
# === Fake Quantization ===
# Generated benchmark names start with 'learnable_kernel' or 'original_kernel',
# e.g. 'original_kernel_nbits8_cpu_N1_C1_H256_W256_zero_point_dtypetorch.int32_bwdall'.

fake_quantize_configs_short_dict = {
    "attr_names": ["N", "C", "H", "W", "zero_point_dtype"],
    "attrs": [
        [1, 3, 512, 512, torch.int32],
    ],
    "tags": ["short"],
}

fake_quantize_configs_long_dict = {
    "N": [1],
    "C": [1, 3, 8, 32],
    "H": [256, 1024],
    "W": [256, 1024],
    "zero_point_dtype": [torch.int32],
    "tags": ["long"],
}

fake_quantize_configs_short = op_bench.config_list(
    cross_product_configs={
        "device": ("cpu", "cuda"),
    },
    **fake_quantize_configs_short_dict,
)

fake_quantize_configs_long = op_bench.cross_product_configs(
    device=("cpu", "cuda"), **fake_quantize_configs_long_dict
)


class FakeQuantizeBenchmark(op_bench.TorchBenchmarkBase):
    r"""Benchmarks fake quantization with default parameters."""

    def init(self, N, C, H, W, zero_point_dtype, device):
        self.inputs = {"input": torch.rand(N, C, H, W).to(device)}
        self.op = tq.FakeQuantize().to(device)
        self.set_module_name("FakeQuantize")

    def forward(self, input):
        return self.op(input)


op_bench.generate_pt_test(
    fake_quantize_configs_short + fake_quantize_configs_long, FakeQuantizeBenchmark
)


# op_name describes the kernel being benchmarked:
# 'learnable_kernel' is the C++ kernel that can backpropagate through
# scale and zero_point; 'original_kernel' is the original fake quantize
# C++ kernel.


def fakeQuantizePerTensorLearnableKernel(
    input, scale, zero_point, quant_min: int, quant_max: int
):
    return torch._fake_quantize_learnable_per_tensor_affine(
        input, scale, zero_point, quant_min, quant_max
    )


def fakeQuantizePerTensorOriginalKernel(
    input, scale, zero_point, quant_min: int, quant_max: int
):
    # The original kernel takes a Python float scale and int zero_point,
    # so fixed constants are used here in place of the tensor arguments.
    return torch.fake_quantize_per_tensor_affine(input, 1.0, 0, quant_min, quant_max)


fake_quantize_per_tensor_ops = op_bench.op_list(
    attrs=(
        ("learnable_kernel", fakeQuantizePerTensorLearnableKernel),
        ("original_kernel", fakeQuantizePerTensorOriginalKernel),
    ),
    attr_names=("op_name", "op_func"),
)

fake_quantize_operator_configs_short = op_bench.config_list(
    cross_product_configs={
        "nbits": (4, 8),
        "device": ("cpu", "cuda"),
    },
    **fake_quantize_configs_short_dict,
)

fake_quantize_operator_configs_long = op_bench.cross_product_configs(
    nbits=(4, 8), device=("cpu", "cuda"), **fake_quantize_configs_long_dict
)
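
# A minimal sketch (illustrative only, never run by the harness) of what the
# per tensor fake quantize kernels compute: a quantize->dequantize roundtrip
# simulated in floating point. `_example_fake_quant_formula` is a hypothetical
# helper added for exposition.
def _example_fake_quant_formula():
    x = torch.rand(2, 3)
    scale, zero_point, quant_min, quant_max = 0.1, 0, 0, 255
    # out = (clamp(round(x / scale) + zero_point, quant_min, quant_max)
    #        - zero_point) * scale
    manual = (
        torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max)
        - zero_point
    ) * scale
    builtin = torch.fake_quantize_per_tensor_affine(
        x, scale, zero_point, quant_min, quant_max
    )
    return torch.allclose(manual, builtin)  # expected to be True
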
# TODO(future PR): Combine the floating point zero_point configs below with the
# other configs once float zero_point is fully supported in all fake quantize
# operators and devices, see
# https://github.com/pytorch/pytorch/issues/61866.
fake_quantize_configs_long_dict_float_zero_point = (
    fake_quantize_configs_long_dict.copy()
)
fake_quantize_configs_long_dict_float_zero_point["zero_point_dtype"] = [
    torch.float32,
    torch.half,
]

fake_quantize_operator_configs_long_float_zero_point = op_bench.cross_product_configs(
    nbits=(8,),
    device=("cpu", "cuda"),
    **fake_quantize_configs_long_dict_float_zero_point,
)


class FakeQuantizePerTensorBaseOpBenchmark(op_bench.TorchBenchmarkBase):
    r"""Benchmarks the learnable and original fake quantize per tensor kernels."""

    def init(self, N, C, H, W, zero_point_dtype, nbits, device, op_func):
        self.quant_min = 0
        self.quant_max = 2**nbits - 1
        self.quant_range = 2**nbits
        self.input = nn.Parameter(
            torch.rand(N, C, H, W, dtype=torch.float, device=device),
            requires_grad=self.auto_set(),
        )
        self.scale = nn.Parameter(
            torch.tensor([1.0]).to(device), requires_grad=self.auto_set()
        )
        if op_func.__name__ == "fakeQuantizePerChannelOriginalKernel":
            self.zero_point = nn.Parameter(
                torch.tensor([0.0]).to(device).to(zero_point_dtype),
                requires_grad=self.auto_set(),
            )
        else:
            self.zero_point = nn.Parameter(
                torch.tensor([0.0]).to(device), requires_grad=self.auto_set()
            )

        self.inputs = {
            "input": self.input,
            "scale": self.scale,
            "zero_point": self.zero_point,
            "quant_min": self.quant_min,
            "quant_max": self.quant_max,
        }
        self.op_func = op_func

    def forward(self, input, scale, zero_point, quant_min: int, quant_max: int):
        return self.op_func(input, scale, zero_point, quant_min, quant_max)


op_bench.generate_pt_tests_from_op_list(
    fake_quantize_per_tensor_ops,
    fake_quantize_operator_configs_short + fake_quantize_operator_configs_long,
    FakeQuantizePerTensorBaseOpBenchmark,
)

op_bench.generate_pt_gradient_tests_from_op_list(
    fake_quantize_per_tensor_ops,
    fake_quantize_operator_configs_short + fake_quantize_operator_configs_long,
    FakeQuantizePerTensorBaseOpBenchmark,
)


def fakeQuantizePerChannelLearnableKernel(
    input, scale, zero_point, axis: int, quant_min: int, quant_max: int
):
    return torch._fake_quantize_learnable_per_channel_affine(
        input, scale, zero_point, axis, quant_min, quant_max
    )


def fakeQuantizePerChannelOriginalKernel(
    input, scale, zero_point, axis: int, quant_min: int, quant_max: int
):
    return torch.fake_quantize_per_channel_affine(
        input, scale, zero_point, axis, quant_min, quant_max
    )


fake_quantize_per_channel_ops = op_bench.op_list(
    attrs=(
        ("learnable_kernel", fakeQuantizePerChannelLearnableKernel),
        ("original_kernel", fakeQuantizePerChannelOriginalKernel),
    ),
    attr_names=("op_name", "op_func"),
)

fake_quantize_per_channel_float_zero_point_ops = op_bench.op_list(
    attrs=(("original_kernel", fakeQuantizePerChannelOriginalKernel),),
    attr_names=("op_name", "op_func"),
)
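
# A minimal sketch of the difference between the two kernels listed above:
# the learnable kernel produces gradients for scale and zero_point (which is
# what the gradient tests below exercise), while the original kernel treats
# them as fixed. `_example_learnable_per_channel_grads` is a hypothetical
# helper for illustration only; the harness never calls it.
def _example_learnable_per_channel_grads():
    x = torch.rand(2, 3, 4, 4)
    scale = torch.ones(3, requires_grad=True)  # one scale per channel
    zero_point = torch.zeros(3, requires_grad=True)  # one zero_point per channel
    out = torch._fake_quantize_learnable_per_channel_affine(
        x, scale, zero_point, 1, 0, 255  # axis=1, quant_min=0, quant_max=255
    )
    out.sum().backward()
    # Both leaves now have a populated .grad tensor.
    return scale.grad, zero_point.grad
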
class FakeQuantizePerChannelOpBenchmark(op_bench.TorchBenchmarkBase):
    r"""Benchmarks the learnable and original fake quantize per channel kernels."""

    def init(self, N, C, H, W, zero_point_dtype, nbits, device, op_func):
        self.quant_min = 0
        self.quant_max = 2**nbits - 1
        self.quant_range = 2**nbits
        # Axis is chosen with respect to the number of channels: C.
        self.axis = 1
        # requires_grad is passed to nn.Parameter (not torch.rand) so that
        # auto_set() actually controls it; nn.Parameter defaults to True.
        self.input = nn.Parameter(
            torch.rand(N, C, H, W, dtype=torch.float, device=device),
            requires_grad=self.auto_set(),
        )

        if op_func.__name__ == "fakeQuantizePerChannelOriginalKernel":
            self.scale = torch.ones(
                C, device=device, dtype=torch.float32, requires_grad=False
            )
            self.zero_point = torch.zeros(
                C, device=device, dtype=zero_point_dtype, requires_grad=False
            )
        else:
            self.scale = nn.Parameter(
                torch.ones(C, device=device, dtype=torch.float32),
                requires_grad=self.auto_set(),
            )
            self.zero_point = nn.Parameter(
                torch.zeros(C, device=device, dtype=torch.float32),
                requires_grad=self.auto_set(),
            )

        self.inputs = {
            "input": self.input,
            "scale": self.scale,
            "zero_point": self.zero_point,
            "axis": self.axis,
            "quant_min": self.quant_min,
            "quant_max": self.quant_max,
        }

        self.op_func = op_func

    def forward(
        self, input, scale, zero_point, axis: int, quant_min: int, quant_max: int
    ):
        return self.op_func(input, scale, zero_point, axis, quant_min, quant_max)


op_bench.generate_pt_tests_from_op_list(
    fake_quantize_per_channel_ops,
    fake_quantize_operator_configs_short + fake_quantize_operator_configs_long,
    FakeQuantizePerChannelOpBenchmark,
)

op_bench.generate_pt_tests_from_op_list(
    fake_quantize_per_channel_float_zero_point_ops,
    fake_quantize_operator_configs_long_float_zero_point,
    FakeQuantizePerChannelOpBenchmark,
)

op_bench.generate_pt_gradient_tests_from_op_list(
    fake_quantize_per_channel_ops,
    fake_quantize_operator_configs_short + fake_quantize_operator_configs_long,
    FakeQuantizePerChannelOpBenchmark,
)

if __name__ == "__main__":
    op_bench.benchmark_runner.main()
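
# Usage sketch (assumption: this file lives in operator_benchmark's pt/
# directory and the standard operator_benchmark runner flags apply; check
# `--help` for the exact spellings in your version), e.g.:
#
#   python -m pt.quantization_test --tag_filter short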