import operator_benchmark as op_bench

import torch


"""Microbenchmarks for Split operator"""


# Configs for PT Split operator
split_configs_short = op_bench.config_list(
    attr_names=["M", "N", "parts"],
    attrs=[
        [8, 8, 2],
        [256, 512, 2],
        [512, 512, 2],
    ],
    cross_product_configs={
        "device": ["cpu", "cuda"],
    },
    tags=["short"],
)

split_configs_long = op_bench.cross_product_configs(
    M=[128, 1024], N=[128, 1024], parts=[2, 4], device=["cpu", "cuda"], tags=["long"]
)


class SplitBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark torch.split of an (M, N) tensor into `parts` chunks along dim 0."""

    def init(self, M, N, parts, device):
        # torch.split splits along dim 0, whose extent is M, so the chunk size
        # must be derived from M alone. The previous int(M * N / parts)
        # exceeded M for every config, making torch.split return the whole
        # tensor as a single chunk — `parts` had no effect on the measured
        # work. All configured M values divide evenly by the configured parts.
        self.inputs = {
            "input": torch.rand(M, N, device=device),
            "split_size": M // parts,
        }
        self.set_module_name("split")

    def forward(self, input, split_size: int):
        return torch.split(input, split_size)


op_bench.generate_pt_test(split_configs_short + split_configs_long, SplitBenchmark)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()