• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1from typing import List
2
3import operator_benchmark as op_bench
4
5import torch
6
7
# Microbenchmarks for the as_strided operator.
# (Kept as a comment: a string placed after the imports is a no-op statement,
# not the module docstring.)


# Short suite: hand-picked (M, N, size, stride, storage_offset) rows, crossed
# with the device list given in ``cross_product_configs``.
as_strided_configs_short = op_bench.config_list(
    attr_names=["M", "N", "size", "stride", "storage_offset"],
    attrs=[
        # 8x8 input, 2x2 view, unit strides, no offset.
        [8, 8, (2, 2), (1, 1), 0],
        # 256x256 input, 32x32 view, unit strides, no offset.
        [256, 256, (32, 32), (1, 1), 0],
        # 512x512 input, 64x64 view stepping by 2 in both dims, offset 1.
        [512, 512, (64, 64), (2, 2), 1],
    ],
    cross_product_configs=dict(device=["cpu", "cuda"]),
    tags=["short"],
)

# Long suite: built by cross_product_configs over the value lists below
# (a fixed 512x1024 input, two view sizes, two storage offsets, two devices).
as_strided_configs_long = op_bench.cross_product_configs(
    M=[512],
    N=[1024],
    size=[(16, 16), (128, 128)],
    stride=[(1, 1)],
    storage_offset=[0, 1],
    device=["cpu", "cuda"],
    tags=["long"],
)
34
35
class As_stridedBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark ``torch.as_strided`` on a random ``M x N`` tensor.

    Each config supplies the view geometry (``size``, ``stride``,
    ``storage_offset``) and the device on which the input tensor is created.
    """

    def init(self, M, N, size, stride, storage_offset, device):
        # The source tensor is allocated here, once per config; ``forward``
        # only calls torch.as_strided on it.
        source = torch.rand(M, N, device=device)
        self.inputs = dict(
            input_one=source,
            size=size,
            stride=stride,
            storage_offset=storage_offset,
        )
        self.set_module_name("as_strided")

    def forward(
        self, input_one, size: List[int], stride: List[int], storage_offset: int
    ):
        """Return the strided view of ``input_one`` described by the config."""
        return torch.as_strided(input_one, size, stride, storage_offset)
50
51
# Pass the combined short + long config lists, together with the benchmark
# class, to the operator_benchmark test generator.
op_bench.generate_pt_test(
    as_strided_configs_short + as_strided_configs_long,
    As_stridedBenchmark,
)
55
56
# Run the registered benchmarks only when executed as a script.
if "__main__" == __name__:
    op_bench.benchmark_runner.main()
59