"""Microbenchmarks for as_strided operator.

Times ``torch.as_strided`` on a random ``(M, N)`` tensor for a set of
``size`` / ``stride`` / ``storage_offset`` combinations, on both CPU and CUDA.
"""

from typing import List

import operator_benchmark as op_bench

import torch


# Configs for PT as_strided operator.
# Every config keeps storage_offset + sum((s - 1) * st for s, st in zip(size, stride))
# below M * N, so the strided view never indexes past the input's storage.
as_strided_configs_short = op_bench.config_list(
    attr_names=["M", "N", "size", "stride", "storage_offset"],
    attrs=[
        [8, 8, (2, 2), (1, 1), 0],
        [256, 256, (32, 32), (1, 1), 0],
        [512, 512, (64, 64), (2, 2), 1],
    ],
    cross_product_configs={
        "device": ["cpu", "cuda"],
    },
    tags=["short"],
)

as_strided_configs_long = op_bench.cross_product_configs(
    M=[512],
    N=[1024],
    size=[(16, 16), (128, 128)],
    stride=[(1, 1)],
    storage_offset=[0, 1],
    device=["cpu", "cuda"],
    tags=["long"],
)


class As_stridedBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark ``torch.as_strided`` over a random 2-D input tensor."""

    def init(self, M, N, size, stride, storage_offset, device):
        """Build the inputs for one benchmark configuration.

        Args:
            M, N: shape of the random source tensor ``torch.rand(M, N)``.
            size: shape of the strided view to create.
            stride: stride of the strided view, one entry per dim of ``size``.
            storage_offset: offset into the source storage where the view starts.
            device: device the source tensor is allocated on ("cpu" or "cuda").
        """
        # The keys of this dict are passed to ``forward`` as keyword arguments,
        # so they must match its parameter names.
        self.inputs = {
            "input_one": torch.rand(M, N, device=device),
            "size": size,
            "stride": stride,
            "storage_offset": storage_offset,
        }
        self.set_module_name("as_strided")

    def forward(
        self, input_one, size: List[int], stride: List[int], storage_offset: int
    ):
        # The timed operation: create a strided view of ``input_one``.
        # The type annotations are kept explicit in case the benchmark
        # framework scripts this method with TorchScript.
        return torch.as_strided(input_one, size, stride, storage_offset)


# Register one benchmark per config in both the short and long suites.
op_bench.generate_pt_test(
    as_strided_configs_short + as_strided_configs_long, As_stridedBenchmark
)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()