import operator_benchmark as op_bench

import torch


"""Microbenchmarks for quantized unary operators (point-wise and reduction)."""


# Configs for pointwise and reduction unary ops
qunary_ops_configs_short = op_bench.config_list(
    attr_names=["M", "N"],
    attrs=[
        [512, 512],
    ],
    cross_product_configs={
        "dtype": [torch.quint8],
    },
    tags=["short"],
)

qunary_ops_configs_long = op_bench.cross_product_configs(
    M=[256, 1024],
    N=[256, 1024],
    dtype=[torch.quint8, torch.qint8, torch.qint32],
    tags=["long"],
)


class QUnaryOpBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, M, N, dtype, op_func):
        # Quantize a random float tensor once in init; only the op itself
        # runs inside the timed forward().
        f_input = torch.rand(M, N)
        scale = 1.0
        zero_point = 0
        self.inputs = {
            "q_input": torch.quantize_per_tensor(
                f_input, scale=scale, zero_point=zero_point, dtype=dtype
            )
        }
        self.op_func = op_func

    def forward(self, q_input):
        return self.op_func(q_input)


# TODO: Uncomment the ops whenever they are implemented for quantized tensor.
qunary_ops_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        # ['q_abs', torch.abs],
        # ['q_abs_', torch.abs_],
        # ['q_acos', torch.acos],
        # ['q_acos_', torch.acos_],
        ["q_argsort", torch.argsort],
        # ['q_asin', torch.asin],
        # ['q_asin_', torch.asin_],
        # ['q_atan', torch.atan],
        # ['q_atan_', torch.atan_],
        # ['q_ceil', torch.ceil],
        # ['q_ceil_', torch.ceil_],
        ["q_clone", torch.clone],
        # ['q_cos', torch.cos],
        # ['q_cos_', torch.cos_],
        # ['q_cosh', torch.cosh],
        # ['q_digamma', torch.digamma],
        # ['q_erf', torch.erf],
        # ['q_erf_', torch.erf_],
        # ['q_erfc', torch.erfc],
        # ['q_erfc_', torch.erfc_],
        # ['q_erfinv', torch.erfinv],
        # ['q_exp', torch.exp],
        # ['q_exp_', torch.exp_],
        # ['q_expm1', torch.expm1],
        # ['q_expm1_', torch.expm1_],
        # ['q_floor', torch.floor],
        # ['q_floor_', torch.floor_],
        # ['q_frac', torch.frac],
        # ['q_frac_', torch.frac_],
        # ['q_hardshrink', torch.hardshrink],
        # ['q_lgamma', torch.lgamma],
        # ['q_log', torch.log],
        # ['q_log10', torch.log10],
        # ['q_log10_', torch.log10_],
        # ['q_log1p', torch.log1p],
        # ['q_log1p_', torch.log1p_],
        # ['q_log2', torch.log2],
        # ['q_log2_', torch.log2_],
        # ['q_log_', torch.log_],
        ["q_mean", torch.mean],
        # ['q_neg', torch.neg],
        # ['q_neg_', torch.neg_],
        # ['q_reciprocal', torch.reciprocal],
        # ['q_reciprocal_', torch.reciprocal_],
        ["q_relu", torch.relu],
        ["q_relu_", torch.relu_],
        # ['q_round', torch.round],
        # ['q_round_', torch.round_],
        # ['q_rsqrt', torch.rsqrt],
        # ['q_rsqrt_', torch.rsqrt_],
        # ['q_sigmoid', torch.sigmoid],
        # ['q_sigmoid_', torch.sigmoid_],
        # ['q_sign', torch.sign],
        # ['q_sin', torch.sin],
        # ['q_sin_', torch.sin_],
        # ['q_sinh', torch.sinh],
        ["q_sort", torch.sort],
        # ['q_sqrt', torch.sqrt],
        # ['q_sqrt_', torch.sqrt_],
        # ['q_tan', torch.tan],
        # ['q_tan_', torch.tan_],
        # ['q_tanh', torch.tanh],
        # ['q_tanh_', torch.tanh_],
        # ['q_trunc', torch.trunc],
        # ['q_trunc_', torch.trunc_],
        # ['q_unique', torch.unique],
        # ['q_zero_', torch.zero_],
        # ['q_bernoulli_', lambda t: t.bernoulli_()],
        # ['q_cauchy_', lambda t: t.cauchy_()],
        # ['q_digamma_', lambda t: t.digamma_()],
        # ['q_exponential_', lambda t: t.exponential_()],
        # ['q_normal_', lambda t: t.normal_()],
        # ['q_random_', lambda t: t.random_()],
        # ['q_sign_', lambda t: t.sign_()],
        # ['q_uniform_', lambda t: t.uniform_()],
        # ['q_half', lambda t: t.half()],
        # ['q_long', lambda t: t.long()],
    ],
)


op_bench.generate_pt_tests_from_op_list(
    qunary_ops_list,
    qunary_ops_configs_short + qunary_ops_configs_long,
    QUnaryOpBenchmark,
)


# === Other unary ops (i.e. the ones that need parameters as args) ===

# Configs for torch.topk, which additionally needs the number of values k
qunary_ops_topk_configs_short = op_bench.config_list(
    attr_names=["M", "N", "k"],
    attrs=[
        [512, 512, 5],
    ],
    cross_product_configs={
        "dtype": [torch.quint8],
    },
    tags=["short"],
)

qunary_ops_topk_configs_long = op_bench.cross_product_configs(
    M=[256, 1024],
    N=[256, 1024],
    k=[1, 3, 5],
    dtype=[torch.quint8, torch.qint8, torch.qint32],
    tags=["long"],
)


class QTopkOpBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, M, N, dtype, k):
        # Same setup as QUnaryOpBenchmark, but k is passed through to
        # forward() as an extra argument of the op.
        f_input = torch.rand(M, N)
        scale = 1.0
        zero_point = 0
        self.inputs = {
            "q_input": torch.quantize_per_tensor(
                f_input, scale=scale, zero_point=zero_point, dtype=dtype
            ),
            "k": k,
        }
        self.set_module_name("qtopk")

    def forward(self, q_input, k: int):
        return torch.topk(q_input, k)


op_bench.generate_pt_test(
    qunary_ops_topk_configs_short + qunary_ops_topk_configs_long, QTopkOpBenchmark
)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()
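

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the harness): running this file
# directly hands control to op_bench.benchmark_runner, which instantiates one
# benchmark per (op, config) pair generated above. The module path and flag
# spelling below are assumptions about the standard operator_benchmark runner;
# check its argparse setup before relying on them.
#
#     python -m pt.qunary_test --tag-filter short
#
# Each generated test then times roughly the following, shown here for the
# "q_relu" short config using only core torch:
#
#     q_input = torch.quantize_per_tensor(
#         torch.rand(512, 512), scale=1.0, zero_point=0, dtype=torch.quint8
#     )
#     torch.relu(q_input)  # the call being measured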