"""Microbenchmarks for as_strided operator.

Defines a short and a long set of benchmark configurations for
``torch.as_strided`` and registers them with the operator benchmark
harness. Running this file as a script starts the benchmark runner.
"""

from typing import List

import operator_benchmark as op_bench
import torch


# Short configs for the PT as_strided operator.
# Each row of `attrs` gives values in `attr_names` order:
# M, N (dimensions of the random input tensor), size, stride, storage_offset.
# NOTE(review): each row is presumably combined with every entry of
# `cross_product_configs["device"]` by the harness -- confirm against
# op_bench.config_list.
as_strided_configs_short = op_bench.config_list(
    attr_names=["M", "N", "size", "stride", "storage_offset"],
    attrs=[
        [8, 8, (2, 2), (1, 1), 0],
        [256, 256, (32, 32), (1, 1), 0],
        [512, 512, (64, 64), (2, 2), 1],
    ],
    cross_product_configs={
        "device": ["cpu", "cuda"],
    },
    tags=["short"],
)

# Long configs for the PT as_strided operator: a single 512 x 1024 input
# shape with two view sizes and two storage offsets, on both devices.
as_strided_configs_long = op_bench.cross_product_configs(
    M=[512],
    N=[1024],
    size=[(16, 16), (128, 128)],
    stride=[(1, 1)],
    storage_offset=[0, 1],
    device=["cpu", "cuda"],
    tags=["long"],
)


class As_stridedBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark case that calls ``torch.as_strided`` on a random 2-D tensor."""

    def init(self, M, N, size, stride, storage_offset, device):
        """Build the inputs for one benchmark configuration.

        Args:
            M: First dimension of the random input tensor.
            N: Second dimension of the random input tensor.
            size: ``size`` argument forwarded to ``torch.as_strided``.
            stride: ``stride`` argument forwarded to ``torch.as_strided``.
            storage_offset: ``storage_offset`` argument forwarded to
                ``torch.as_strided``.
            device: Device on which the input tensor is allocated.
        """
        # The keys match the parameter names of `forward`.
        self.inputs = {
            "input_one": torch.rand(M, N, device=device),
            "size": size,
            "stride": stride,
            "storage_offset": storage_offset,
        }
        self.set_module_name("as_strided")

    # NOTE(review): the `typing.List` annotations are kept exactly as written;
    # presumably the harness relies on them (e.g. for TorchScript) -- confirm
    # before modernizing them to `list[int]`.
    def forward(
        self, input_one, size: List[int], stride: List[int], storage_offset: int
    ):
        """Return ``torch.as_strided(input_one, size, stride, storage_offset)``."""
        return torch.as_strided(input_one, size, stride, storage_offset)


# Register one benchmark test per configuration (short followed by long).
op_bench.generate_pt_test(
    as_strided_configs_short + as_strided_configs_long, As_stridedBenchmark
)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()