• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import operator_benchmark as op_bench
2
3import torch
4
5
6"""Microbenchmarks for quantized unary operators (point-wise and reduction)."""
7
8
# Configs for pointwise and reduction unary ops.
# Short run: a single 512 x 512 input, quint8 only.
qunary_ops_configs_short = op_bench.config_list(
    attr_names=["M", "N"],
    attrs=[[512, 512]],
    cross_product_configs=dict(dtype=[torch.quint8]),
    tags=["short"],
)

# Long run: cross product of the M, N and quantized-dtype values below.
qunary_ops_configs_long = op_bench.cross_product_configs(
    M=[256, 1024],
    N=[256, 1024],
    dtype=[torch.quint8, torch.qint8, torch.qint32],
    tags=["long"],
)
27
28
class QUnaryOpBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark a single-argument op on a per-tensor quantized input.

    The op under test (``op_func``) comes from the op list handed to
    ``op_bench.generate_pt_tests_from_op_list``; each test feeds it one
    ``M x N`` tensor quantized to ``dtype``.
    """

    def init(self, M, N, dtype, op_func):
        # NOTE(review): torch.rand samples from [0, 1), so with scale=1.0 and
        # zero_point=0 every quantized value is 0 or 1. Confirm that this
        # near-constant input is intended (it matters for sort/argsort).
        quantized = torch.quantize_per_tensor(
            torch.rand(M, N), scale=1.0, zero_point=0, dtype=dtype
        )
        self.inputs = {"q_input": quantized}
        self.op_func = op_func

    def forward(self, q_input):
        """Apply the op under test to the pre-quantized input."""
        return self.op_func(q_input)
43
44
# Quantized unary ops to benchmark, as [op_name, op_func] pairs.
# The commented-out entries are a to-do list of ops to enable later.
# TODO: Uncomment the ops whenever they are implemented for quantized tensor.
qunary_ops_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        # ["q_abs", torch.abs],
        # ["q_abs_", torch.abs_],
        # ["q_acos", torch.acos],
        # ["q_acos_", torch.acos_],
        ["q_argsort", torch.argsort],
        # ["q_asin", torch.asin],
        # ["q_asin_", torch.asin_],
        # ["q_atan", torch.atan],
        # ["q_atan_", torch.atan_],
        # ["q_ceil", torch.ceil],
        # ["q_ceil_", torch.ceil_],
        ["q_clone", torch.clone],
        # ["q_cos", torch.cos],
        # ["q_cos_", torch.cos_],
        # ["q_cosh", torch.cosh],
        # ["q_digamma", torch.digamma],
        # ["q_erf", torch.erf],
        # ["q_erf_", torch.erf_],
        # ["q_erfc", torch.erfc],
        # ["q_erfc_", torch.erfc_],
        # ["q_erfinv", torch.erfinv],
        # ["q_exp", torch.exp],
        # ["q_exp_", torch.exp_],
        # ["q_expm1", torch.expm1],
        # ["q_expm1_", torch.expm1_],
        # ["q_floor", torch.floor],
        # ["q_floor_", torch.floor_],
        # ["q_frac", torch.frac],
        # ["q_frac_", torch.frac_],
        # ["q_hardshrink", torch.hardshrink],
        # ["q_lgamma", torch.lgamma],
        # ["q_log", torch.log],
        # ["q_log10", torch.log10],
        # ["q_log10_", torch.log10_],
        # ["q_log1p", torch.log1p],
        # ["q_log1p_", torch.log1p_],
        # ["q_log2", torch.log2],
        # ["q_log2_", torch.log2_],
        # ["q_log_", torch.log_],
        ["q_mean", torch.mean],
        # ["q_neg", torch.neg],
        # ["q_neg_", torch.neg_],
        # ["q_reciprocal", torch.reciprocal],
        # ["q_reciprocal_", torch.reciprocal_],
        ["q_relu", torch.relu],
        ["q_relu_", torch.relu_],
        # ["q_round", torch.round],
        # ["q_round_", torch.round_],
        # ["q_rsqrt", torch.rsqrt],
        # ["q_rsqrt_", torch.rsqrt_],
        # ["q_sigmoid", torch.sigmoid],
        # ["q_sigmoid_", torch.sigmoid_],
        # ["q_sign", torch.sign],
        # ["q_sin", torch.sin],
        # ["q_sin_", torch.sin_],
        # ["q_sinh", torch.sinh],
        ["q_sort", torch.sort],
        # ["q_sqrt", torch.sqrt],
        # ["q_sqrt_", torch.sqrt_],
        # ["q_tan", torch.tan],
        # ["q_tan_", torch.tan_],
        # ["q_tanh", torch.tanh],
        # ["q_tanh_", torch.tanh_],
        # ["q_trunc", torch.trunc],
        # ["q_trunc_", torch.trunc_],
        # ["q_unique", torch.unique],
        # ["q_zero_", torch.zero_],
        # ["q_bernoulli_", lambda t: t.bernoulli_()],
        # ["q_cauchy_", lambda t: t.cauchy_()],
        # ["q_digamma_", lambda t: t.digamma_()],
        # ["q_exponential_", lambda t: t.exponential_()],
        # ["q_normal_", lambda t: t.normal_()],
        # ["q_random_", lambda t: t.random_()],
        # ["q_sign_", lambda t: t.sign_()],
        # ["q_uniform_", lambda t: t.uniform_()],
        # ["q_half", lambda t: t.half()],
        # ["q_long", lambda t: t.long()],
    ],
)


# Generate the PyTorch tests for every enabled op above, over both the short
# and the long config sets, using QUnaryOpBenchmark as the harness.
op_bench.generate_pt_tests_from_op_list(
    qunary_ops_list,
    qunary_ops_configs_short + qunary_ops_configs_long,
    QUnaryOpBenchmark,
)
135
136
# === Other unary ops (i.e. the ones that need parameters as args) ===

# Configs for quantized topk, which takes the extra integer argument ``k``.
# Short run: a single 512 x 512 input with k=5, quint8 only.
qunary_ops_topk_configs_short = op_bench.config_list(
    attr_names=["M", "N", "k"],
    attrs=[[512, 512, 5]],
    cross_product_configs=dict(dtype=[torch.quint8]),
    tags=["short"],
)

# Long run: cross product of the M, N, k and quantized-dtype values below.
qunary_ops_topk_configs_long = op_bench.cross_product_configs(
    M=[256, 1024],
    N=[256, 1024],
    k=[1, 3, 5],
    dtype=[torch.quint8, torch.qint8, torch.qint32],
    tags=["long"],
)
158
159
class QTopkOpBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark ``torch.topk`` on a per-tensor quantized ``M x N`` input."""

    def init(self, M, N, dtype, k):
        # NOTE(review): rand in [0, 1) with scale=1.0 and zero_point=0 yields
        # only 0/1 quantized values; confirm this input is intended for topk.
        q_input = torch.quantize_per_tensor(
            torch.rand(M, N), scale=1.0, zero_point=0, dtype=dtype
        )
        self.inputs = {"q_input": q_input, "k": k}
        self.set_module_name("qtopk")

    def forward(self, q_input, k: int):
        """Return the ``k`` largest entries (values, indices) along the last dim."""
        return torch.topk(q_input, k)
175
176
# Generate the topk test for both the short and the long config sets.
op_bench.generate_pt_test(
    qunary_ops_topk_configs_short + qunary_ops_topk_configs_long,
    QTopkOpBenchmark,
)


if __name__ == "__main__":
    # Run the registered benchmarks when this file is executed directly.
    op_bench.benchmark_runner.main()
184