"""Microbenchmarks for index_select operator."""

import numpy

import operator_benchmark as op_bench

import torch


# An example input from this configuration is M=8, N=8, K=1, dim=1,
# device="cpu".
index_select_configs_short = op_bench.config_list(
    attr_names=["M", "N", "K", "dim"],
    attrs=[
        [8, 8, 1, 1],
        [256, 512, 1, 1],
        [512, 512, 1, 1],
        [8, 8, 2, 1],
        [256, 512, 2, 1],
        [512, 512, 2, 1],
    ],
    cross_product_configs={
        "device": ["cpu", "cuda"],
    },
    tags=["short"],
)


# Larger configs, tagged "long".
index_select_configs_long = op_bench.cross_product_configs(
    M=[128, 1024],
    N=[128, 1024],
    K=[1, 2],
    dim=[1],
    device=["cpu", "cuda"],
    tags=["long"],
)


class IndexSelectBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark ``torch.index_select`` on a random (M, N, K) tensor."""

    def init(self, M, N, K, dim, device):
        """Build the input tensor and a random index tensor for one config.

        Args:
            M, N, K: sizes of the random input tensor, created as (M, N, K).
            dim: dimension of the input to select along.
            device: device on which the input and index tensors are created.
        """
        input_one = torch.rand(M, N, K, device=device)
        # Draw index values from the length of the *selected* dimension so
        # they stay in range for any ``dim``; for dim == 1 this is exactly N.
        max_val = input_one.size(dim)
        # Fixed seed so every run benchmarks the same index tensor.
        numpy.random.seed((1 << 32) - 1)
        # Number of indices to gather, drawn from [0, N).
        # NOTE(review): this can be 0 (an empty selection); left unchanged so
        # results stay comparable with earlier runs.
        index_dim = numpy.random.randint(0, N)
        self.inputs = {
            "input_one": input_one,
            "dim": dim,
            # Pin int64: numpy's default integer is int32 on some platforms
            # (e.g. Windows with NumPy < 2), which would otherwise change the
            # index dtype being benchmarked.
            "index": torch.tensor(
                numpy.random.randint(0, max_val, index_dim),
                dtype=torch.long,
                device=device,
            ),
        }
        self.set_module_name("index_select")

    def forward(self, input_one, dim, index):
        """Return the entries of ``input_one`` at ``index`` along ``dim``."""
        return torch.index_select(input_one, dim, index)


op_bench.generate_pt_test(
    index_select_configs_short + index_select_configs_long, IndexSelectBenchmark
)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()