import operator_benchmark as op_bench

import torch


"""
Microbenchmarks for the GELU operator.
"""

# Cross-product of input shapes (N x C x H x W) for the "long" benchmark
# suite; every combination below becomes one benchmark case on CPU.
gelu_configs_long = op_bench.cross_product_configs(
    N=[1, 4],
    C=[3],
    H=[16, 256],
    W=[16, 256],
    device=["cpu"],
    tags=["long"],
)
13
14
class GeluBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark harness timing ``torch.nn.functional.gelu`` on 4-D inputs."""

    def init(self, N, C, H, W, device):
        # One random NCHW tensor; the harness feeds it to forward() keyed by
        # the dict key, which must match the forward() parameter name.
        sample = torch.rand(N, C, H, W, device=device)
        self.inputs = {"input": sample}

    def forward(self, input):
        # The timed operation: elementwise GELU over the whole tensor.
        # NOTE: `input` shadows the builtin, but renaming would break the
        # harness's kwarg binding via the self.inputs key above.
        return torch.nn.functional.gelu(input)
22
# Register one PyTorch benchmark test per entry in gelu_configs_long.
op_bench.generate_pt_test(gelu_configs_long, GeluBenchmark)


if __name__ == "__main__":
    # Script entry point: parse CLI options and run the registered benchmarks.
    op_bench.benchmark_runner.main()