import operator_benchmark as op_bench

import torch


"""Microbenchmarks for Chunk operator"""


# Configs for PT Chunk operator
# Short run: three explicit (M, N, chunks) triples, each combined with the
# "device" values listed under cross_product_configs.
chunk_short_configs = op_bench.config_list(
    attr_names=["M", "N", "chunks"],
    attrs=[
        [8, 8, 2],
        [256, 512, 2],
        [512, 512, 2],
    ],
    cross_product_configs={
        "device": ["cpu", "cuda"],
    },
    tags=["short"],
)

# Long run: built from per-attribute value lists for M, N, chunks and device.
chunks_long_configs = op_bench.cross_product_configs(
    M=[128, 1024],
    N=[128, 1024],
    chunks=[2, 4],
    device=["cpu", "cuda"],
    tags=["long"],
)


class ChunkBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark case that calls ``torch.chunk`` on a random ``(M, N)`` tensor."""

    def init(self, M, N, chunks, device):
        """Create the benchmark inputs and register the module name.

        Args:
            M: Size of the first dimension of the input tensor.
            N: Size of the second dimension of the input tensor.
            chunks: Chunk count handed to ``torch.chunk``.
            device: Device on which the input tensor is allocated.
        """
        sample = torch.rand(M, N, device=device)
        # Keys mirror the parameter names of ``forward``.
        self.inputs = {"input_one": sample, "chunks": chunks}
        self.set_module_name("chunk")

    def forward(self, input_one, chunks: int):
        # No ``dim`` argument is passed, so torch.chunk's default dimension
        # is the one being split.
        pieces = torch.chunk(input_one, chunks)
        return pieces


op_bench.generate_pt_test(chunk_short_configs + chunks_long_configs, ChunkBenchmark)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()