• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
import operator_benchmark as op_bench

import torch
import torch.nn as nn
5
6
# Microbenchmarks for the hardswish operators.


# "short" configs: two (N, C, H, W) input shapes -- batch 1 and batch 4 of a
# 3 x 256 x 256 tensor -- run on CPU only and tagged "short".
hardswish_configs_short = op_bench.config_list(
    attr_names=["N", "C", "H", "W"],
    attrs=[[batch, 3, 256, 256] for batch in (1, 4)],
    cross_product_configs={"device": ["cpu"]},
    tags=["short"],
)
24
25
# "long" configs: built by op_bench.cross_product_configs from the value lists
# below (presumably every combination: 2 batch sizes x 2 heights x 2 widths,
# all with 3 channels on CPU), tagged "long".
# Keyword order is kept unchanged on purpose -- it may determine the attribute
# order (and hence naming) of the generated configs.
hardswish_configs_long = op_bench.cross_product_configs(
    N=[8, 16],
    C=[3],
    H=[256, 512],
    W=[256, 512],
    device=["cpu"],
    tags=["long"],
)
29
30
# Operators under test: a single entry pairing the display name "Hardswish"
# with the nn.Hardswish module *class*. The class itself (not an instance) is
# stored here; HardswishBenchmark.init calls it to build the module.
hardswish_ops_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[["Hardswish", nn.Hardswish]],
)
37
38
class HardswishBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark a hardswish module applied to a random 4-D tensor.

    This class is handed to ``op_bench.generate_pt_tests_from_op_list``; the
    framework presumably calls ``init`` once per config with that config's
    N/C/H/W/device values and an ``op_func`` taken from the op list.
    """

    def init(self, N, C, H, W, device, op_func):
        """Build the input tensor and the module under test.

        Args:
            N, C, H, W: sizes of the input tensor, in that order.
            device: device the input tensor is allocated on.
            op_func: zero-argument callable (here the ``nn.Hardswish`` class)
                whose return value is the module to benchmark.
        """
        shape = (N, C, H, W)
        # torch.rand samples uniformly from [0, 1). The dict key mirrors the
        # parameter name of forward() -- presumably the framework passes the
        # entries as keyword arguments; confirm against TorchBenchmarkBase.
        self.inputs = {"input_one": torch.rand(*shape, device=device)}
        self.op_func = op_func()

    def forward(self, input_one):
        """Run the module on ``input_one`` and return its output."""
        module = self.op_func
        return module(input_one)
46
47
# Register the hardswish benchmarks with op_bench, combining the op list with
# both the "short" and the "long" config sets.
op_bench.generate_pt_tests_from_op_list(
    hardswish_ops_list,
    [*hardswish_configs_short, *hardswish_configs_long],
    HardswishBenchmark,
)
53
54
if __name__ == "__main__":
    # When executed as a script, delegate to the shared op_bench runner.
    runner = op_bench.benchmark_runner
    runner.main()
57