• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import numpy
2
3import operator_benchmark as op_bench
4
5import torch
6
7
8"""Microbenchmarks for gather operator."""
9
# An example input from this configuration is M=256, N=512, dim=0.
gather_configs_short = op_bench.config_list(
    # Each row of `attrs` supplies one value per name in `attr_names`.
    attr_names=["M", "N", "dim"],
    attrs=[[256, 512, 0], [512, 512, 1]],
    # NOTE(review): presumably every row above is combined with each device
    # listed here -- confirm against op_bench.config_list.
    cross_product_configs={"device": ["cpu", "cuda"]},
    tags=["short"],
)
22
23
# Configurations tagged "long", built from the value lists given for
# M, N, dim and device.
gather_configs_long = op_bench.cross_product_configs(
    M=[128, 1024],
    N=[128, 1024],
    dim=[0, 1],
    device=["cpu", "cuda"],
    tags=["long"],
)
27
28
class GatherBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark ``torch.gather`` on an (M, N) input with an (M, N) index."""

    def init(self, M, N, dim, device):
        """Build the input tensors for one benchmark configuration.

        Args:
            M: Size of the first dimension of the input and index tensors.
            N: Size of the second dimension of the input and index tensors.
            dim: Dimension along which values are gathered. ``0`` bounds the
                indices by ``M``; any other value bounds them by ``N``.
            device: Device on which both tensors are created.
        """
        # Every index must be smaller than the input's size along `dim`.
        index_bound = M if dim == 0 else N

        # Fixed seed so the generated index pattern is the same on every run.
        numpy.random.seed((1 << 32) - 1)

        # torch.gather requires an int64 index. numpy's default integer type
        # is 32-bit on some platforms (Windows with NumPy < 2.0), so request
        # int64 explicitly instead of relying on the platform default.
        index_values = numpy.random.randint(
            0, index_bound, (M, N), dtype=numpy.int64
        )

        self.inputs = {
            "input_one": torch.rand(M, N, device=device),
            "dim": dim,
            "index": torch.tensor(index_values, device=device),
        }
        self.set_module_name("gather")

    def forward(self, input_one, dim: int, index):
        """Return ``torch.gather(input_one, dim, index)``."""
        return torch.gather(input_one, dim, index)
44
45
# Hand the short and long configuration lists, together with the benchmark
# class, to the op_bench test generator.
op_bench.generate_pt_test(
    gather_configs_short + gather_configs_long,
    GatherBenchmark,
)


# Run the benchmark runner only when this file is executed as a script.
if __name__ == "__main__":
    op_bench.benchmark_runner.main()
51