import random
from typing import List

import operator_benchmark as op_bench

import torch


"""Microbenchmarks for Stack operator"""

# Configs for PT stack operator
stack_configs_static_runtime = op_bench.config_list(
    attr_names=["sizes", "N"],
    attrs=[
        [(20, 40), 5],
        [(1, 40), 5],
    ],
    cross_product_configs={"device": ["cpu", "cuda"], "dim": list(range(3))},
    tags=["static_runtime"],
)
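# config_list crosses every `attrs` row with each combination of
# cross_product_configs, so the block above expands to
# 2 sizes x 2 devices x 3 dims = 12 benchmark cases.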

stack_configs_short = op_bench.config_list(
    attr_names=["sizes", "N"],
    attrs=[
        [(1, 1, 1), 2],  # noqa: E241
        [(512, 512, 2), 2],  # noqa: E241
        [(128, 1024, 2), 2],  # noqa: E241
    ],
    cross_product_configs={"device": ["cpu", "cuda"], "dim": list(range(4))},
    tags=["short"],
)

stack_configs_long = op_bench.config_list(
    attr_names=["sizes", "N"],
    attrs=[
        [(2**10, 2**10, 2), 2],  # noqa: E241
        [(2**10 + 1, 2**10 - 1, 2), 2],  # noqa: E226,E241
        [(2**10, 2**10, 2), 2],  # noqa: E241
    ],
    cross_product_configs={"device": ["cpu", "cuda"], "dim": list(range(4))},
    tags=["long"],
)

# There is a different codepath on CUDA for >4 dimensions
stack_configs_multidim = op_bench.config_list(
    attr_names=["sizes", "N"],
    attrs=[
        [(2**6, 2**5, 2**2, 2**4, 2**5), 2],  # noqa: E241
        [(2**4, 2**5, 2**2, 2**4, 2**5), 8],  # noqa: E241
        [
            (2**3 + 1, 2**5 - 1, 2**2 + 1, 2**4 - 1, 2**5 + 1),
            17,
        ],  # noqa: E226,E241
    ],
    cross_product_configs={"device": ["cpu", "cuda"], "dim": list(range(6))},
    tags=["multidim"],
)


class StackBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, sizes, N, dim, device):
        random.seed(42)
        inputs = []
        gen_sizes = []
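        # `sizes` is either an explicit list of shapes (signaled by N == -1)
        # or a single shape template replicated N times; template entries may
        # be callables that are invoked to sample a dimension.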
        if isinstance(sizes, list) and N == -1:
            gen_sizes = sizes
        else:
            for i in range(N):
                gen_sizes.append(
                    [
                        old_size() if callable(old_size) else old_size
                        for old_size in sizes
                    ]
                )

        for s in gen_sizes:
            inputs.append(torch.rand(s, device=device))
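        # Preallocate the `out` tensor. Its shape matches a single input
        # rather than the stacked output, so torch.stack resizes it on the
        # first call; steady-state iterations then measure writes into an
        # existing, correctly-sized buffer.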
        result = torch.rand(gen_sizes[0], device=device)
        self.inputs = {"result": result, "inputs": inputs, "dim": dim}
        self.set_module_name("stack")

    def forward(self, result: torch.Tensor, inputs: List[torch.Tensor], dim: int):
        return torch.stack(inputs, dim=dim, out=result)


op_bench.generate_pt_test(
    stack_configs_static_runtime
    + stack_configs_short
    + stack_configs_long
    + stack_configs_multidim,
    StackBenchmark,
)
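
# To run a subset of these benchmarks directly (assuming this file lives at
# pt/stack_test.py and the stock operator_benchmark runner with its
# --tag_filter flag), something like:
#
#   python -m pt.stack_test --tag_filter short
#
# runs only the "short" configs.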

if __name__ == "__main__":
    op_bench.benchmark_runner.main()