"""
Microbenchmarks for the softmax operators.

Two families of benchmarks are generated:

* ``SoftmaxBenchmark`` runs ``nn.Softmax``, ``nn.Softmax2d`` and
  ``nn.LogSoftmax`` on random 4-D ``(N, C, H, W)`` inputs, normalizing over
  dimension 1.
* ``Softmax2DimsBenchmark`` runs ``nn.Softmax`` and ``nn.LogSoftmax`` on
  random 2-D ``(M, N)`` inputs along the configured ``dim``.
"""

import operator_benchmark as op_bench

import torch
import torch.nn as nn


# Configs for the 4-D softmax ops (tagged "short"): two fixed shapes, each
# crossed with every device below.
softmax_configs_short = op_bench.config_list(
    attr_names=["N", "C", "H", "W"],
    attrs=[
        [1, 3, 256, 256],
        [4, 3, 256, 256],
    ],
    cross_product_configs={
        "device": ["cpu", "cuda"],
    },
    tags=["short"],
)


# Configs for the 4-D softmax ops (tagged "long"): full cross product of the
# listed batch sizes, spatial sizes and devices.
softmax_configs_long = op_bench.cross_product_configs(
    N=[8, 16],
    C=[3],
    H=[256, 512],
    W=[256, 512],
    device=["cpu", "cuda"],
    tags=["long"],
)


# Module classes benchmarked on 4-D inputs by SoftmaxBenchmark.
softmax_ops_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        ["Softmax", nn.Softmax],
        ["Softmax2d", nn.Softmax2d],
        ["LogSoftmax", nn.LogSoftmax],
    ],
)

# Module classes benchmarked on 2-D inputs by Softmax2DimsBenchmark.
# Softmax2d is excluded here because it has no ``dim`` argument.
softmax_two_dims_ops_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        ["Softmax", nn.Softmax],
        ["LogSoftmax", nn.LogSoftmax],
    ],
)


# Configs for the 2-D ops: (M, N) input shape plus the reduction dim,
# crossed with every device below.
softmax_two_dims_configs = op_bench.config_list(
    attr_names=["M", "N", "dim"],
    attrs=[
        [700, 23258, 0],
        [700, 23258, 1],
        [1024, 23258, 1],
        [128, 128, 1],
        [48, 128, 1],
        [16, 1024, 1],
        [32, 1024, 1],
        [48, 1024, 1],
        [16, 512, 1],
        [32, 512, 1],
        [48, 512, 1],
        [16, 256, 1],
        [32, 256, 1],
        [48, 256, 1],
    ],
    cross_product_configs={
        "device": ["cpu", "cuda"],
    },
    tags=["long"],
)


class SoftmaxBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark a softmax-style module on a random ``(N, C, H, W)`` tensor."""

    # Dimension normalized over for the 4-D input. This equals the dimension
    # PyTorch's deprecated implicit-``dim`` rule selects for 4-D tensors, so
    # passing it explicitly does not change what is being measured.
    _SOFTMAX_DIM = 1

    def init(self, N, C, H, W, device, op_func):
        self.inputs = {"input": torch.rand(N, C, H, W, device=device)}
        if op_func is nn.Softmax2d:
            # Softmax2d takes no ``dim`` argument.
            self.op_func = op_func()
        else:
            # Constructing nn.Softmax / nn.LogSoftmax without ``dim`` relies
            # on the deprecated implicit dimension choice and emits a
            # UserWarning, so pass the dimension explicitly.
            self.op_func = op_func(dim=self._SOFTMAX_DIM)

    def forward(self, input):
        return self.op_func(input)


class Softmax2DimsBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark a softmax-style module on a random ``(M, N)`` tensor along ``dim``."""

    def init(self, M, N, dim, device, op_func):
        self.inputs = {"input": torch.rand(M, N, device=device)}
        self.op_func = op_func(dim=dim)

    def forward(self, input):
        return self.op_func(input)


op_bench.generate_pt_tests_from_op_list(
    softmax_ops_list, softmax_configs_short + softmax_configs_long, SoftmaxBenchmark
)


op_bench.generate_pt_tests_from_op_list(
    softmax_two_dims_ops_list, softmax_two_dims_configs, Softmax2DimsBenchmark
)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()