import operator_benchmark as op_bench

import torch
import torch.ao.nn.quantized as nnq
import torch.ao.quantization as tq
import torch.nn as nn


9"""Microbenchmarks for general quantization operations."""
10
11# mode is used to show the direction of the benchmark:
12# if 'Q', benchmark quantization, else dequantization
13
quantize_configs_short_dict = {
    "attr_names": ["C", "M", "N", "dtype", "mode"],
    "attrs": [
        [3, 512, 512, torch.quint8, "Q"],
        [3, 512, 512, torch.quint8, "D"],
    ],
    "tags": ["short"],
}

quantize_configs_long_dict = {
    "C": [3, 5, 8],  # this is reused for per-channel: avoid single channel test
    "M": [256, 1024],
    "N": [256, 1024],
    "dtype": [torch.quint8, torch.qint8, torch.qint32],
    "mode": ["D", "Q"],
    "tags": ["long"],
}


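# config_list pairs each row of "attrs" with "attr_names" (the two short
# configs above), while cross_product_configs expands the long dict into the
# full Cartesian product of its values.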
quantize_per_tensor_configs_short = op_bench.config_list(**quantize_configs_short_dict)

quantize_per_tensor_configs_long = op_bench.cross_product_configs(
    **quantize_configs_long_dict
)


class QuantizePerTensorBenchmark(op_bench.TorchBenchmarkBase):
    r"""Benchmarks both quantization and dequantization."""

    def init(self, C, M, N, dtype, mode):
        assert mode in ("Q", "D")
        self.input = torch.rand(C, M, N)
        self.dtype = dtype
        self.op = nnq.Quantize(scale=1.0, zero_point=0, dtype=dtype)
        self.set_module_name("QuantizePerTensor")

        if mode == "D":
            self.input = self.op(self.input)
            self.op = nnq.DeQuantize()
            self.set_module_name("DequantizePerTensor")

        self.inputs = {"input": self.input}

    def forward(self, input):
        return self.op(input)


op_bench.generate_pt_test(
    quantize_per_tensor_configs_short + quantize_per_tensor_configs_long,
    QuantizePerTensorBenchmark,
)

# === Per Channel quantization ===

quantize_per_channel_configs_short = op_bench.config_list(
    cross_product_configs={"axis": (0,)}, **quantize_configs_short_dict
)

quantize_per_channel_configs_long = op_bench.cross_product_configs(
    axis=(0, 1, 2), **quantize_configs_long_dict
)


class QuantizePerChannelBenchmark(op_bench.TorchBenchmarkBase):
    r"""Benchmarks both quantization and dequantization."""

    def init(self, C, M, N, dtype, axis, mode):
        assert mode in ("Q", "D")
        self.input = torch.rand(C, M, N)
        self.op = torch.quantize_per_channel

        channel_len = (C, M, N)[axis]

        self.kwargs = {
            "scales": torch.tensor([1.0] * channel_len),
            "zero_points": torch.tensor([0] * channel_len),
            "dtype": dtype,
            "axis": axis,
        }

        self.set_module_name("QuantizePerChannel")

        if mode == "D":
            self.input = self.op(self.input, **self.kwargs)

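            # dequant keeps the same signature as torch.quantize_per_channel so
            # that forward() below can pass the same keyword arguments in both
            # the "Q" and "D" modes; the extra arguments are simply ignored.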
            def dequant(input, scales, zero_points, axis: int, dtype: int):
                return input.dequantize()

            self.op = dequant
            self.set_module_name("DequantizePerChannel")

        self.inputs = {
            "input": self.input,
            "scales": torch.tensor([1.0] * channel_len),
            "zero_points": torch.tensor([0] * channel_len),
            "axis": axis,
            "dtype": dtype,
        }

    def forward(self, input, scales, zero_points, axis: int, dtype: int):
        return self.op(
            input, scales=scales, zero_points=zero_points, axis=axis, dtype=dtype
        )


op_bench.generate_pt_test(
    quantize_per_channel_configs_short + quantize_per_channel_configs_long,
    QuantizePerChannelBenchmark,
)

# === Fake Quantization ===
# Generated benchmark names start with 'learnable_kernel' or 'original_kernel',
#    e.g. 'original_kernel_nbits8_cpu_N1_C1_H256_W256_zero_point_dtypetorch.int32_bwdall'

fake_quantize_configs_short_dict = {
    "attr_names": ["N", "C", "H", "W", "zero_point_dtype"],
    "attrs": [
        [1, 3, 512, 512, torch.int32],
    ],
    "tags": ["short"],
}

fake_quantize_configs_long_dict = {
    "N": [1],
    "C": [1, 3, 8, 32],
    "H": [256, 1024],
    "W": [256, 1024],
    "zero_point_dtype": [torch.int32],
    "tags": ["long"],
}

fake_quantize_configs_short = op_bench.config_list(
    cross_product_configs={
        "device": ("cpu", "cuda"),
    },
    **fake_quantize_configs_short_dict,
)

fake_quantize_configs_long = op_bench.cross_product_configs(
    device=("cpu", "cuda"), **fake_quantize_configs_long_dict
)

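# tq.FakeQuantize is the observer-driven module used for quantization-aware
# training; with its default arguments it fake-quantizes per tensor over the
# quint8 range (0 to 255) using a MovingAverageMinMaxObserver.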

class FakeQuantizeBenchmark(op_bench.TorchBenchmarkBase):
    r"""Benchmarks fake quantization with default parameters."""

    def init(self, N, C, H, W, zero_point_dtype, device):
        self.inputs = {"input": torch.rand(N, C, H, W).to(device)}
        self.op = tq.FakeQuantize().to(device)
        self.set_module_name("FakeQuantize")

    def forward(self, input):
        return self.op(input)


op_bench.generate_pt_test(
    fake_quantize_configs_short + fake_quantize_configs_long, FakeQuantizeBenchmark
)


# op_type describes the type of operator used in benchmarking:
# learnable_kernel is the C++ kernel that can backpropagate through
# scale and zero_point.
# original_kernel is the original fake quantize C++ kernel.


def fakeQuantizePerTensorLearnableKernel(
    input, scale, zero_point, quant_min: int, quant_max: int
):
    return torch._fake_quantize_learnable_per_tensor_affine(
        input, scale, zero_point, quant_min, quant_max
    )


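# Note: the wrapper below accepts scale and zero_point only to match the
# learnable kernel's signature; the original kernel is always benchmarked with
# a fixed scale of 1.0 and a zero_point of 0.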
def fakeQuantizePerTensorOriginalKernel(
    input, scale, zero_point, quant_min: int, quant_max: int
):
    return torch.fake_quantize_per_tensor_affine(input, 1.0, 0, quant_min, quant_max)


fake_quantize_per_tensor_ops = op_bench.op_list(
    attrs=(
        ("learnable_kernel", fakeQuantizePerTensorLearnableKernel),
        ("original_kernel", fakeQuantizePerTensorOriginalKernel),
    ),
    attr_names=("op_name", "op_func"),
)

fake_quantize_operator_configs_short = op_bench.config_list(
    cross_product_configs={
        "nbits": (4, 8),
        "device": ("cpu", "cuda"),
    },
    **fake_quantize_configs_short_dict,
)

fake_quantize_operator_configs_long = op_bench.cross_product_configs(
    nbits=(4, 8), device=("cpu", "cuda"), **fake_quantize_configs_long_dict
)

# TODO(future PR): combine the floating point zero_point config with the other
# configs once it is fully supported in all fakeQuant operators and devices;
# see https://github.com/pytorch/pytorch/issues/61866.
fake_quantize_configs_long_dict_float_zero_point = (
    fake_quantize_configs_long_dict.copy()
)
fake_quantize_configs_long_dict_float_zero_point["zero_point_dtype"] = [
    torch.float32,
    torch.half,
]

fake_quantize_operator_configs_long_float_zero_point = op_bench.cross_product_configs(
    nbits=(8,),
    device=("cpu", "cuda"),
    **fake_quantize_configs_long_dict_float_zero_point,
)


class FakeQuantizePerTensorBaseOpBenchmark(op_bench.TorchBenchmarkBase):
    r"""Benchmarks the fake quantize per tensor operators (learnable and original kernels)."""

    def init(self, N, C, H, W, zero_point_dtype, nbits, device, op_func):
        self.quant_min = 0
        self.quant_max = 2**nbits - 1
        self.quant_range = 2**nbits
        self.input = nn.Parameter(
            torch.rand(N, C, H, W, dtype=torch.float, device=device),
            requires_grad=self.auto_set(),
        )
        self.scale = nn.Parameter(
            torch.tensor([1.0]).to(device), requires_grad=self.auto_set()
        )
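        # Note: fake_quantize_per_tensor_ops above only registers kernels named
        # fakeQuantizePerTensor*, so this branch is not exercised in this file;
        # it mirrors the per-channel benchmark below, where the original
        # kernel's zero_point is cast to zero_point_dtype.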
        if op_func.__name__ == "fakeQuantizePerChannelOriginalKernel":
            self.zero_point = nn.Parameter(
                torch.tensor([0.0]).to(device).to(zero_point_dtype),
                requires_grad=self.auto_set(),
            )
        else:
            self.zero_point = nn.Parameter(
                torch.tensor([0.0]).to(device), requires_grad=self.auto_set()
            )

        self.inputs = {
            "input": self.input,
            "scale": self.scale,
            "zero_point": self.zero_point,
            "quant_min": self.quant_min,
            "quant_max": self.quant_max,
        }
        self.op_func = op_func

    def forward(self, input, scale, zero_point, quant_min: int, quant_max: int):
        return self.op_func(input, scale, zero_point, quant_min, quant_max)


op_bench.generate_pt_tests_from_op_list(
    fake_quantize_per_tensor_ops,
    fake_quantize_operator_configs_short + fake_quantize_operator_configs_long,
    FakeQuantizePerTensorBaseOpBenchmark,
)

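# The gradient tests below benchmark the backward pass; auto_set() in init()
# lets the framework enumerate which of input / scale / zero_point require
# grad, which shows up as the bwd* suffix in the generated test names.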
op_bench.generate_pt_gradient_tests_from_op_list(
    fake_quantize_per_tensor_ops,
    fake_quantize_operator_configs_short + fake_quantize_operator_configs_long,
    FakeQuantizePerTensorBaseOpBenchmark,
)


def fakeQuantizePerChannelLearnableKernel(
    input, scale, zero_point, axis: int, quant_min: int, quant_max: int
):
    return torch._fake_quantize_learnable_per_channel_affine(
        input, scale, zero_point, axis, quant_min, quant_max
    )


def fakeQuantizePerChannelOriginalKernel(
    input, scale, zero_point, axis: int, quant_min: int, quant_max: int
):
    return torch.fake_quantize_per_channel_affine(
        input, scale, zero_point, axis, quant_min, quant_max
    )


fake_quantize_per_channel_ops = op_bench.op_list(
    attrs=(
        ("learnable_kernel", fakeQuantizePerChannelLearnableKernel),
        ("original_kernel", fakeQuantizePerChannelOriginalKernel),
    ),
    attr_names=("op_name", "op_func"),
)

fake_quantize_per_channel_float_zero_point_ops = op_bench.op_list(
    attrs=(("original_kernel", fakeQuantizePerChannelOriginalKernel),),
    attr_names=("op_name", "op_func"),
)


class FakeQuantizePerChannelOpBenchmark(op_bench.TorchBenchmarkBase):
    r"""Benchmarks 3 different fake quantize per channel operators."""

    def init(self, N, C, H, W, zero_point_dtype, nbits, device, op_func):
        self.quant_min = 0
        self.quant_max = 2**nbits - 1
        self.quant_range = 2**nbits
        # Axis is chosen with respect to the number of channels: C.
        self.axis = 1
        self.input = nn.Parameter(
            torch.rand(N, C, H, W, dtype=torch.float, device=device),
            requires_grad=self.auto_set(),
        )

        if op_func.__name__ == "fakeQuantizePerChannelOriginalKernel":
            self.scale = torch.ones(
                C, device=device, dtype=torch.float32, requires_grad=False
            )
            self.zero_point = torch.zeros(
                C, device=device, dtype=zero_point_dtype, requires_grad=False
            )
        else:
            self.scale = nn.Parameter(
                torch.ones(C, device=device, dtype=torch.float32),
                requires_grad=self.auto_set(),
            )
            self.zero_point = nn.Parameter(
                torch.zeros(C, device=device, dtype=torch.float32),
                requires_grad=self.auto_set(),
            )

        self.inputs = {
            "input": self.input,
            "scale": self.scale,
            "zero_point": self.zero_point,
            "axis": self.axis,
            "quant_min": self.quant_min,
            "quant_max": self.quant_max,
        }

        self.op_func = op_func

    def forward(
        self, input, scale, zero_point, axis: int, quant_min: int, quant_max: int
    ):
        return self.op_func(input, scale, zero_point, axis, quant_min, quant_max)


op_bench.generate_pt_tests_from_op_list(
    fake_quantize_per_channel_ops,
    fake_quantize_operator_configs_short + fake_quantize_operator_configs_long,
    FakeQuantizePerChannelOpBenchmark,
)

op_bench.generate_pt_tests_from_op_list(
    fake_quantize_per_channel_float_zero_point_ops,
    fake_quantize_operator_configs_long_float_zero_point,
    FakeQuantizePerChannelOpBenchmark,
)

op_bench.generate_pt_gradient_tests_from_op_list(
    fake_quantize_per_channel_ops,
    fake_quantize_operator_configs_short + fake_quantize_operator_configs_long,
    FakeQuantizePerChannelOpBenchmark,
)

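# When run directly, the operator_benchmark runner executes every test
# generated above and reports per-config timings; the available command-line
# filters (e.g. by tag or operator) depend on the benchmark_runner version in
# the checkout.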
if __name__ == "__main__":
    op_bench.benchmark_runner.main()