"""Microbenchmarks for point-wise unary operators."""

import operator_benchmark as op_bench

import torch


# Configs for pointwise unary ops
unary_ops_configs_short = op_bench.config_list(
    attr_names=["M", "N"],
    attrs=[
        [512, 512],
    ],
    cross_product_configs={
        "device": ["cpu", "cuda"],
    },
    tags=["short"],
)

unary_ops_configs_long = op_bench.cross_product_configs(
    M=[256, 1024], N=[256, 1024], device=["cpu", "cuda"], tags=["long"]
)


class UnaryOpBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark a single unary op applied to an ``M x N`` random tensor."""

    def init(self, M: int, N: int, device: str, op_func):
        """Build the benchmark input and remember the op under test.

        Args:
            M: number of rows of the input tensor.
            N: number of columns of the input tensor.
            device: device the input tensor is created on (e.g. "cpu", "cuda").
            op_func: callable taking one tensor, taken from ``unary_ops_list``.
        """
        # The key "input" must match the parameter name of forward(), which
        # is why forward() uses a name that shadows the builtin ``input``.
        self.inputs = {"input": torch.rand(M, N, device=device)}
        self.op_func = op_func

    def forward(self, input):
        """Apply the op under test to ``input`` and return its result."""
        # NOTE(review): in-place ops (names ending in "_") mutate the tensor
        # created in init(). If the runner reuses self.inputs across
        # iterations, later iterations see already-transformed values
        # (e.g. log_/logit_ can turn the data into NaNs) — confirm that this
        # is acceptable for the timings being collected.
        return self.op_func(input)


# Thin wrappers that invoke Tensor methods so they fit the single-argument
# ``op_func(input)`` calling convention used by UnaryOpBenchmark.forward.


def bernoulli_(input: torch.Tensor) -> torch.Tensor:
    """Return ``input.bernoulli_()`` (in-place fill of ``input``)."""
    return input.bernoulli_()


def cauchy_(input: torch.Tensor) -> torch.Tensor:
    """Return ``input.cauchy_()`` (in-place fill of ``input``)."""
    return input.cauchy_()


def digamma_(input: torch.Tensor) -> torch.Tensor:
    """Return ``input.digamma_()`` (in-place on ``input``)."""
    return input.digamma_()


def exponential_(input: torch.Tensor) -> torch.Tensor:
    """Return ``input.exponential_()`` (in-place fill of ``input``)."""
    return input.exponential_()


def normal_(input: torch.Tensor) -> torch.Tensor:
    """Return ``input.normal_()`` (in-place fill of ``input``)."""
    return input.normal_()


def random_(input: torch.Tensor) -> torch.Tensor:
    """Return ``input.random_()`` (in-place fill of ``input``)."""
    return input.random_()


def sign_(input: torch.Tensor) -> torch.Tensor:
    """Return ``input.sign_()`` (in-place on ``input``)."""
    return input.sign_()


def uniform_(input: torch.Tensor) -> torch.Tensor:
    """Return ``input.uniform_()`` (in-place fill of ``input``)."""
    return input.uniform_()


# Despite the trailing underscore, the two helpers below are NOT in-place:
# they return a dtype-converted tensor and leave ``input`` untouched.


def half_(input: torch.Tensor) -> torch.Tensor:
    """Return ``input.half()`` (out-of-place dtype conversion)."""
    return input.half()


def long_(input: torch.Tensor) -> torch.Tensor:
    """Return ``input.long()`` (out-of-place dtype conversion)."""
    return input.long()


# (op_name, op_func) pairs; op_name is what appears in generated test names.
unary_ops_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        ["abs", torch.abs],
        ["abs_", torch.abs_],
        ["acos", torch.acos],
        ["acos_", torch.acos_],
        ["argsort", torch.argsort],
        ["asin", torch.asin],
        ["asin_", torch.asin_],
        ["atan", torch.atan],
        ["atan_", torch.atan_],
        ["ceil", torch.ceil],
        ["ceil_", torch.ceil_],
        ["clone", torch.clone],
        ["cos", torch.cos],
        ["cos_", torch.cos_],
        ["cosh", torch.cosh],
        ["digamma", torch.digamma],
        ["erf", torch.erf],
        ["erf_", torch.erf_],
        ["erfc", torch.erfc],
        ["erfc_", torch.erfc_],
        ["erfinv", torch.erfinv],
        ["exp", torch.exp],
        ["exp_", torch.exp_],
        ["expm1", torch.expm1],
        ["expm1_", torch.expm1_],
        ["floor", torch.floor],
        ["floor_", torch.floor_],
        ["frac", torch.frac],
        ["frac_", torch.frac_],
        ["hardshrink", torch.hardshrink],
        ["lgamma", torch.lgamma],
        ["log", torch.log],
        ["log10", torch.log10],
        ["log10_", torch.log10_],
        ["log1p", torch.log1p],
        ["log1p_", torch.log1p_],
        ["log2", torch.log2],
        ["log2_", torch.log2_],
        ["log_", torch.log_],
        ["logit", torch.logit],
        ["logit_", torch.logit_],
        ["neg", torch.neg],
        ["neg_", torch.neg_],
        ["reciprocal", torch.reciprocal],
        ["reciprocal_", torch.reciprocal_],
        ["relu", torch.relu],
        ["relu_", torch.relu_],
        ["round", torch.round],
        ["round_", torch.round_],
        ["rsqrt", torch.rsqrt],
        ["rsqrt_", torch.rsqrt_],
        ["sigmoid", torch.sigmoid],
        ["sigmoid_", torch.sigmoid_],
        ["sign", torch.sign],
        ["sgn", torch.sgn],
        ["sin", torch.sin],
        ["sin_", torch.sin_],
        ["sinh", torch.sinh],
        ["sqrt", torch.sqrt],
        ["sqrt_", torch.sqrt_],
        ["square", torch.square],
        ["square_", torch.square_],
        ["tan", torch.tan],
        ["tan_", torch.tan_],
        ["tanh", torch.tanh],
        ["tanh_", torch.tanh_],
        ["trunc", torch.trunc],
        ["trunc_", torch.trunc_],
        # NOTE(review): private torch.functional helper, presumably the
        # single-output path behind torch.unique — confirm it still exists
        # in the targeted torch version.
        ["unique", torch.functional._return_output],
        ["zero_", torch.zero_],
        ["bernoulli_", bernoulli_],
        ["cauchy_", cauchy_],
        ["digamma_", digamma_],
        ["exponential_", exponential_],
        ["normal_", normal_],
        ["random_", random_],
        ["sign_", sign_],
        ["uniform_", uniform_],
        ["half", half_],
        ["long", long_],
    ],
)


# Register one benchmark per (op, config) combination at import time.
op_bench.generate_pt_tests_from_op_list(
    unary_ops_list, unary_ops_configs_short + unary_ops_configs_long, UnaryOpBenchmark
)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()