"""
Microbenchmarks for the hardswish operators.
"""

import operator_benchmark as op_bench

import torch
import torch.nn as nn


# Configs for hardswish ops.
# Short run: two explicit (N, C, H, W) shapes on CPU.
hardswish_configs_short = op_bench.config_list(
    attr_names=["N", "C", "H", "W"],
    attrs=[
        [1, 3, 256, 256],
        [4, 3, 256, 256],
    ],
    cross_product_configs={
        "device": ["cpu"],
    },
    tags=["short"],
)


# Long run: cross product of N x H x W (2 * 2 * 2 = 8 shapes), all with C=3 on CPU.
hardswish_configs_long = op_bench.cross_product_configs(
    N=[8, 16], C=[3], H=[256, 512], W=[256, 512], device=["cpu"], tags=["long"]
)


# Operators under test: each entry pairs a display name with the module class
# that HardswishBenchmark.init instantiates (via `op_func()`).
hardswish_ops_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        ["Hardswish", nn.Hardswish],
    ],
)


class HardswishBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark a hardswish module applied to a random 4-D input tensor."""

    def init(self, N, C, H, W, device, op_func):
        """Build the benchmark input and instantiate the operator.

        Args:
            N, C, H, W: sizes of the input tensor's four dimensions.
            device: device string the input is allocated on (configs use "cpu").
            op_func: module class (e.g. ``nn.Hardswish``); called with no
                arguments to create the operator instance.
        """
        # Key "input_one" matches the parameter name of forward(); presumably
        # the harness unpacks self.inputs into forward as keyword arguments
        # (NOTE(review): confirm against TorchBenchmarkBase).
        self.inputs = {"input_one": torch.rand(N, C, H, W, device=device)}
        self.op_func = op_func()

    def forward(self, input_one):
        """Apply the operator to the prepared input and return the result."""
        return self.op_func(input_one)


# Register one benchmark per (op, config) pair from both short and long configs.
op_bench.generate_pt_tests_from_op_list(
    hardswish_ops_list,
    hardswish_configs_short + hardswish_configs_long,
    HardswishBenchmark,
)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()