• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import operator_benchmark as op_bench
2
3import torch
4
5
"""Microbenchmarks for point-wise unary operators."""
7
8
# Configs for point-wise unary ops.
#
# "short" set: a single 512 x 512 shape, run on both "cpu" and "cuda".
unary_ops_configs_short = op_bench.config_list(
    attr_names=["M", "N"],
    attrs=[[512, 512]],
    cross_product_configs={"device": ["cpu", "cuda"]},
    tags=["short"],
)

# "long" set: built by op_bench.cross_product_configs from M and N in
# {256, 1024} and the same two devices.
unary_ops_configs_long = op_bench.cross_product_configs(
    M=[256, 1024],
    N=[256, 1024],
    device=["cpu", "cuda"],
    tags=["long"],
)
24
25
class UnaryOpBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark that applies one point-wise unary op to a random M x N tensor.

    ``init`` is the set-up hook used by this benchmark (note: ``init``, not
    ``__init__``); ``forward`` is the timed call.
    """

    def init(self, M, N, device, op_func):
        """Build the input tensor and remember the op under test.

        Args:
            M: number of rows of the random input.
            N: number of columns of the random input.
            device: device string forwarded to ``torch.rand``.
            op_func: callable taking one tensor; it is what ``forward`` runs.
        """
        sample = torch.rand(M, N, device=device)
        # Presumably the harness passes these to ``forward`` as keyword
        # arguments, so the key must stay "input" to match its parameter.
        # NOTE(review): in-place ops (abs_, exp_, log_, ...) overwrite this
        # tensor; if the harness reuses it across iterations, later runs see
        # already-transformed values. Confirm this is intended.
        self.inputs = {"input": sample}
        self.op_func = op_func

    def forward(self, input):
        """Run the stored op on ``input`` and return its result."""
        fn = self.op_func
        return fn(input)
33
34
def bernoulli_(input):
    """Call the in-place ``Tensor.bernoulli_()`` on ``input`` with default arguments.

    Wraps the tensor method as a plain function so it can be listed next to
    the ``torch.*`` callables in the op list. Returns the method's result.
    """
    sample_inplace = input.bernoulli_
    return sample_inplace()
37
38
def cauchy_(input):
    """Call the in-place ``Tensor.cauchy_()`` on ``input`` with default arguments.

    Function wrapper so the tensor method can appear in the op list.
    Returns the method's result.
    """
    sample_inplace = input.cauchy_
    return sample_inplace()
41
42
def digamma_(input):
    """Call the in-place ``Tensor.digamma_()`` on ``input``.

    Function wrapper so the tensor method can appear in the op list.
    Returns the method's result.
    """
    apply_inplace = input.digamma_
    return apply_inplace()
45
46
def exponential_(input):
    """Call the in-place ``Tensor.exponential_()`` on ``input`` with default arguments.

    Function wrapper so the tensor method can appear in the op list.
    Returns the method's result.
    """
    sample_inplace = input.exponential_
    return sample_inplace()
49
50
def normal_(input):
    """Call the in-place ``Tensor.normal_()`` on ``input`` with default arguments.

    Function wrapper so the tensor method can appear in the op list.
    Returns the method's result.
    """
    sample_inplace = input.normal_
    return sample_inplace()
53
54
def random_(input):
    """Call the in-place ``Tensor.random_()`` on ``input`` with default arguments.

    Function wrapper so the tensor method can appear in the op list.
    Returns the method's result.
    """
    sample_inplace = input.random_
    return sample_inplace()
57
58
def sign_(input):
    """Call the in-place ``Tensor.sign_()`` on ``input``.

    Function wrapper so the tensor method can appear in the op list.
    Returns the method's result.
    """
    apply_inplace = input.sign_
    return apply_inplace()
61
62
def uniform_(input):
    """Call the in-place ``Tensor.uniform_()`` on ``input`` with default arguments.

    Function wrapper so the tensor method can appear in the op list.
    Returns the method's result.
    """
    sample_inplace = input.uniform_
    return sample_inplace()
65
66
def half_(input):
    """Return ``input.half()``, i.e. ``input`` converted to float16.

    Despite the trailing underscore, this is NOT an in-place op. It calls
    the non-underscore ``Tensor.half()`` conversion and returns its result.
    """
    convert = input.half
    return convert()
69
70
def long_(input):
    """Return ``input.long()``, i.e. ``input`` converted to int64.

    Despite the trailing underscore, this is NOT an in-place op. It calls
    the non-underscore ``Tensor.long()`` conversion and returns its result.
    """
    convert = input.long
    return convert()
73
74
# Op catalogue: each row pairs a benchmark name ("op_name") with the callable
# ("op_func") that is presumably handed to UnaryOpBenchmark.init.
unary_ops_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        # Ops taken straight from the ``torch`` namespace; for these the
        # benchmark name is exactly the attribute name.
        *[
            [op_name, getattr(torch, op_name)]
            for op_name in (
                "abs", "abs_", "acos", "acos_", "argsort",
                "asin", "asin_", "atan", "atan_",
                "ceil", "ceil_", "clone",
                "cos", "cos_", "cosh", "digamma",
                "erf", "erf_", "erfc", "erfc_", "erfinv",
                "exp", "exp_", "expm1", "expm1_",
                "floor", "floor_", "frac", "frac_",
                "hardshrink", "lgamma",
                "log", "log10", "log10_", "log1p", "log1p_",
                "log2", "log2_", "log_", "logit", "logit_",
                "neg", "neg_", "reciprocal", "reciprocal_",
                "relu", "relu_", "round", "round_",
                "rsqrt", "rsqrt_", "sigmoid", "sigmoid_",
                "sign", "sgn", "sin", "sin_", "sinh",
                "sqrt", "sqrt_", "square", "square_",
                "tan", "tan_", "tanh", "tanh_", "trunc", "trunc_",
            )
        ],
        # NOTE(review): private torch helper; the public entry point is
        # torch.unique. Confirm bypassing it is intentional.
        ["unique", torch.functional._return_output],
        ["zero_", torch.zero_],
        # Function wrappers around Tensor methods, defined in this module.
        ["bernoulli_", bernoulli_],
        ["cauchy_", cauchy_],
        ["digamma_", digamma_],
        ["exponential_", exponential_],
        ["normal_", normal_],
        ["random_", random_],
        ["sign_", sign_],
        ["uniform_", uniform_],
        ["half", half_],
        ["long", long_],
    ],
)
160
161
# Register benchmarks for every op in unary_ops_list over the combined
# "short" and "long" config sets, all using UnaryOpBenchmark.
op_bench.generate_pt_tests_from_op_list(
    unary_ops_list,
    unary_ops_configs_short + unary_ops_configs_long,
    UnaryOpBenchmark,
)
165
166
if __name__ == "__main__":
    # Executed as a script: delegate to the shared benchmark runner entry point.
    runner = op_bench.benchmark_runner
    runner.main()
169