"""Microbenchmarks for the ``torch.gather`` operator."""

import numpy

import operator_benchmark as op_bench

import torch


# Short configs: explicit (M, N, dim) triples, e.g. M=256, N=512, dim=0,
# each crossed with every listed device.
gather_configs_short = op_bench.config_list(
    attr_names=["M", "N", "dim"],
    attrs=[
        [256, 512, 0],
        [512, 512, 1],
    ],
    cross_product_configs={
        "device": ["cpu", "cuda"],
    },
    tags=["short"],
)


# Long configs: full cross product of the listed sizes, dims and devices.
gather_configs_long = op_bench.cross_product_configs(
    M=[128, 1024], N=[128, 1024], dim=[0, 1], device=["cpu", "cuda"], tags=["long"]
)


class GatherBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark ``torch.gather`` on a random (M, N) tensor along ``dim``."""

    def init(self, M, N, dim, device):
        """Build the inputs for one benchmark configuration.

        Args:
            M: Number of rows of the source tensor.
            N: Number of columns of the source tensor.
            dim: Dimension to gather along; 0 selects the size-M dimension,
                any other value the size-N dimension.
            device: Device on which the tensors are allocated.
        """
        # Index values must lie in [0, size of the gathered dimension).
        index_bound = M if dim == 0 else N
        # Fixed seed so every run benchmarks the same index pattern.
        numpy.random.seed((1 << 32) - 1)
        self.inputs = {
            "input_one": torch.rand(M, N, device=device),
            "dim": dim,
            # torch.gather requires an int64 index; numpy.random.randint's
            # default integer dtype is platform-dependent (int32 on Windows
            # with NumPy < 2), so force the dtype explicitly.
            "index": torch.tensor(
                numpy.random.randint(0, index_bound, (M, N)),
                dtype=torch.long,
                device=device,
            ),
        }
        self.set_module_name("gather")

    def forward(self, input_one, dim: int, index):
        """Gather values from ``input_one`` along ``dim`` at ``index``."""
        return torch.gather(input_one, dim, index)


op_bench.generate_pt_test(gather_configs_short + gather_configs_long, GatherBenchmark)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()