• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import operator_benchmark as op_bench
2
3import torch
4from torch._ops import ops
5
6
# Config sweep shared by both benchmark families below: square N x N inputs,
# three quantized dtypes, and both contiguous / non-contiguous layouts.
# Keys are listed in the same order as before so the generated cross product
# (and therefore test ordering/naming) is unchanged.
qarithmetic_binary_configs = op_bench.cross_product_configs(
    **{
        "N": (2, 8, 64, 512),
        "dtype": (torch.quint8, torch.qint8, torch.qint32),
        "contig": (False, True),
        "tags": ("short",),
    }
)
13
14
# Tensor-tensor quantized ops; each benchmark label matches the name of the
# op under torch.ops.quantized, so the pair is derived from the name.
qarithmetic_binary_ops = op_bench.op_list(
    attrs=tuple(
        (op_name, getattr(ops.quantized, op_name))
        for op_name in ("add", "add_relu", "mul")
    ),
    attr_names=("op_name", "op_func"),
)
23
# Tensor-scalar quantized ops, labelled by their torch.ops.quantized name.
qarithmetic_binary_scalar_ops = op_bench.op_list(
    attrs=tuple(
        (op_name, getattr(ops.quantized, op_name))
        for op_name in ("add_scalar", "mul_scalar")
    ),
    attr_names=("op_name", "op_func"),
)
31
32
class _QFunctionalBinaryArithmeticBenchmarkBase(op_bench.TorchBenchmarkBase):
    """Shared input preparation for the quantized arithmetic benchmarks.

    Subclasses call ``setup`` from their ``init`` to get a quantized square
    tensor in ``self.q_input_a`` plus the per-tensor quantization parameters
    (``self.scale``, ``self.zero_point``) that were used to build it.
    """

    def setup(self, N, dtype, contig):
        """Quantize a random ``N x N`` tensor and store it on ``self``.

        Args:
            N: side length of the square input tensor.
            dtype: quantized dtype passed to ``torch.quantize_per_tensor``.
            contig: when False, the stored tensor is a dimension-reversed
                (i.e. transposed) view of the quantized data instead of the
                freshly quantized contiguous tensor.
        """
        # NOTE(review): not read by any forward() visible in this file.
        self.qfunctional = torch.ao.nn.quantized.QFunctional()

        # Fixed per-tensor affine parameters; kept on self so subclasses can
        # forward them to ops that take explicit scale/zero_point arguments.
        self.scale, self.zero_point = 1.0, 0

        # Uniform float samples in [-128, 128).
        # TODO: Consider more diverse shapes
        raw = (torch.rand(N, N) - 0.5) * 256
        quantized = torch.quantize_per_tensor(
            raw, scale=self.scale, zero_point=self.zero_point, dtype=dtype
        )

        if not contig:
            # Reverse the dimension order to get a non-contiguous view.
            quantized = quantized.permute(list(reversed(range(raw.ndim))))

        self.q_input_a = quantized
48
49
class QFunctionalBenchmark(_QFunctionalBinaryArithmeticBenchmarkBase):
    """Benchmark a quantized tensor-tensor op from ``qarithmetic_binary_ops``.

    The same quantized tensor is used for both operands, and the input's
    scale / zero_point are forwarded as the op's ``scale`` / ``zero_point``
    arguments.
    """

    def init(self, N, dtype, contig, op_func):
        super().setup(N, dtype, contig)
        self.op_func = op_func

        operand = self.q_input_a
        self.inputs = dict(
            q_input_a=operand,
            q_input_b=operand,
            scale=self.scale,
            zero_point=self.zero_point,
        )

    def forward(self, q_input_a, q_input_b, scale: float, zero_point: int):
        return self.op_func(
            q_input_a, q_input_b, zero_point=zero_point, scale=scale
        )
63
64
# Register one benchmark per (tensor-tensor op, config) combination.
op_bench.generate_pt_tests_from_op_list(
    qarithmetic_binary_ops,
    qarithmetic_binary_configs,
    QFunctionalBenchmark,
)
68
69
class QFunctionalScalarBenchmark(_QFunctionalBinaryArithmeticBenchmarkBase):
    """Benchmark a quantized tensor-scalar op from
    ``qarithmetic_binary_scalar_ops`` with a constant Python int operand.
    """

    def init(self, N, dtype, contig, op_func):
        super().setup(N, dtype, contig)
        self.op_func = op_func
        # The scalar operand is the fixed value 42 for every config.
        self.inputs = dict(q_input=self.q_input_a, scalar_input=42)

    def forward(self, q_input, scalar_input: int):
        result = self.op_func(q_input, scalar_input)
        return result
78
79
# Register one benchmark per (tensor-scalar op, config) combination.
op_bench.generate_pt_tests_from_op_list(
    qarithmetic_binary_scalar_ops, qarithmetic_binary_configs, QFunctionalScalarBenchmark
)
85
86
if __name__ == "__main__":
    # Hand control to the shared operator_benchmark runner entry point.
    op_bench.benchmark_runner.main()
89