• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import random
2from typing import List
3
4import operator_benchmark as op_bench
5
6import torch
7
8
9"""Microbenchmarks for Cat operator"""
10
# Device choices shared by every benchmark config list in this file; each list
# receives this mapping through its `cross_product_configs` argument.
cross_product_configs = dict(device=["cpu", "cuda"])
14
# Configs for PT Cat operator
# Each attrs row is [sizes, N, dim]. With a tuple `sizes` and N >= 0,
# CatBenchmark.init creates N input tensors of shape `sizes` and the benchmark
# concatenates them along `dim`.
cat_configs_short = op_bench.config_list(
    attr_names=["sizes", "N", "dim"],
    attrs=[
        [(1, 1, 1), 2, 0],  # noqa: E241
        [(512, 512, 2), 2, 1],  # noqa: E241
        [(128, 1024, 2), 2, 1],  # noqa: E241
    ],
    cross_product_configs=cross_product_configs,
    tags=["short"],
)
26
# Configs specific to static runtime feature - a fast path runtime for pared down models
# Here `sizes` is a list of explicit per-input shapes and N is -1: that
# combination makes CatBenchmark.init use the listed shapes as-is instead of
# replicating a single shape N times. All rows concatenate along dim 1.
cat_configs_static_runtime = op_bench.config_list(
    attr_names=["sizes", "N", "dim"],
    attrs=[
        [[(1, 160), (1, 14)], -1, 1],
        [[(1, 20, 40), (1, 4, 40), (1, 5, 40)], -1, 1],
        [[(1, 580), (1, 174)], -1, 1],
        [[(20, 160), (20, 14)], -1, 1],
        [[(20, 20, 40), (20, 4, 40), (20, 5, 40)], -1, 1],
        [[(20, 580), (20, 174)], -1, 1],
    ],
    cross_product_configs=cross_product_configs,
    tags=["static_runtime"],
)
41
# Larger 3-D workloads. The first three rows use a fixed shape for all N inputs.
# In the remaining rows one entry of `sizes` is a zero-argument callable that
# CatBenchmark.init invokes once per generated input; the callable always sits
# at the index equal to `dim`, so only the concatenated dimension varies
# between inputs.
cat_configs_long = op_bench.config_list(
    attr_names=["sizes", "N", "dim"],
    attrs=[
        [(2**10, 2**10, 2), 2, 0],  # noqa: E241
        [(2**10 + 1, 2**10 - 1, 2), 2, 1],  # noqa: E226,E241
        [(2**10, 2**10, 2), 2, 2],  # noqa: E241
        # 5 inputs, random extent in [64, 128] along the concatenated dim.
        [
            [
                lambda: random.randint(2**6, 2**7),
                2**7 - 17,
                2**6 + 1,
            ],  # noqa: E201,E226,E241
            5,
            0,
        ],
        [
            [
                2**6 + 2**5,
                lambda: random.randint(2**6, 2**7),
                2**6,
            ],  # noqa: E201,E226,E241,E272
            5,
            1,
        ],
        [
            [
                2**7,
                2**6,
                lambda: random.randint(2**6, 2**7),
            ],  # noqa: E201,E241,E272
            5,
            2,
        ],
        # 50 inputs, random extent in [32, 64] along the concatenated dim.
        [[lambda: random.randint(2**5, 2**6), 2**5, 2**6], 50, 0],  # noqa: E241
        [
            [2**5, lambda: random.randint(2**5, 2**6), 2**6],  # noqa: E241,E272
            50,
            1,
        ],
        [
            [
                2**5 + 1,
                2**6 + 1,
                lambda: random.randint(2**5, 2**6),
            ],  # noqa: E226,E241,E272
            50,
            2,
        ],
    ],
    cross_product_configs=cross_product_configs,
    tags=["long"],
)
94
# There is a different codepath on CUDA for >4 dimensions
# Every shape below is 5-dimensional; rows are [sizes, N, dim] with N inputs
# of identical shape.
cat_configs_multidim = op_bench.config_list(
    attr_names=["sizes", "N", "dim"],
    attrs=[
        [(2**6, 2**5, 2**2, 2**4, 2**5), 2, 2],  # noqa: E241
        [(2**4, 2**5, 2**2, 2**4, 2**5), 8, 2],  # noqa: E241
        # Off-power-of-two extents, concatenated along the last dimension.
        [
            (2**3 + 1, 2**5 - 1, 2**2 + 1, 2**4 - 1, 2**5 + 1),
            17,
            4,
        ],  # noqa: E226,E241
    ],
    cross_product_configs=cross_product_configs,
    tags=["multidim"],
)
110
# Many 1-D inputs per call: `sizes` holds a single callable, so each of the N
# inputs (100 to 3000 of them) gets its own random length, and the inputs are
# concatenated along dim 0.
cat_configs_manyinputs = op_bench.config_list(
    attr_names=["sizes", "N", "dim"],
    attrs=[
        [[lambda: random.randint(1, 10000)], 100, 0],
        [[lambda: random.randint(1, 1000)], 1000, 0],
        [[lambda: random.randint(1, 500)], 2000, 0],
        [[lambda: random.randint(1, 300)], 3000, 0],
    ],
    cross_product_configs=cross_product_configs,
    tags=["manyinputs"],
)
122
123
class CatBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark ``torch.cat`` over a list of randomly filled input tensors."""

    def init(self, sizes, N, dim, device):
        """Build the input tensors and output buffer for one benchmark config.

        Args:
            sizes: Either a single shape whose entries are ints or
                zero-argument callables returning ints (replicated ``N``
                times), or, when ``N == -1``, a list of explicit per-input
                shapes.
            N: Number of input tensors to generate from ``sizes``, or ``-1``
                to use ``sizes`` directly as the list of input shapes.
            dim: Dimension along which ``forward`` concatenates.
            device: Device on which all tensors are allocated.
        """
        # Reseed the global RNG so callable size entries produce the same
        # sequence of shapes every time a config is initialised.
        random.seed(42)

        if isinstance(sizes, list) and N == -1:
            # Explicit per-input shapes: use them unchanged.
            gen_sizes = sizes
        else:
            # Resolve callables once per input so each input may get its own
            # extent in that position.
            gen_sizes = [
                [entry() if callable(entry) else entry for entry in sizes]
                for _ in range(N)
            ]

        inputs = [torch.rand(shape, device=device) for shape in gen_sizes]
        # Empty placeholder tensor that forward() hands to torch.cat as `out`.
        result = torch.empty(0, device=device)
        self.inputs = {"result": result, "inputs": inputs, "dim": dim}
        self.set_module_name("cat")

    def forward(self, result: torch.Tensor, inputs: List[torch.Tensor], dim: int):
        """Concatenate ``inputs`` along ``dim``, writing the output into ``result``."""
        return torch.cat(inputs, dim=dim, out=result)
148
149
# Hand the concatenation of all config lists, together with the benchmark
# class, to the operator_benchmark framework (presumably registering one test
# per config -- confirm against op_bench.generate_pt_test).
op_bench.generate_pt_test(
    cat_configs_short
    + cat_configs_long
    + cat_configs_multidim
    + cat_configs_manyinputs
    + cat_configs_static_runtime,
    CatBenchmark,
)

# Invoke the framework's benchmark runner when this file is run as a script.
if __name__ == "__main__":
    op_bench.benchmark_runner.main()
161