"""Microbenchmarks for diag operator"""

import operator_benchmark as op_bench

import torch


# Configs for PT diag operator.
# Each row of `attrs` is [dim, M, N, diagonal, out]:
#   dim      -- 2 benchmarks an (M, N) matrix input; any other value a length-M
#               vector input (N is then unused).
#   diagonal -- offset forwarded to torch.diag.
#   out      -- True benchmarks the out= variant of torch.diag.
diag_configs_short = op_bench.config_list(
    attr_names=["dim", "M", "N", "diagonal", "out"],
    attrs=[
        [1, 64, 64, 0, True],
        [2, 128, 128, -10, False],
        [1, 256, 256, 20, True],
    ],
    cross_product_configs={
        "device": ["cpu", "cuda"],
    },
    tags=["short"],
)


class DiagBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark ``torch.diag`` on 1-D and 2-D inputs, with and without ``out=``."""

    def init(self, dim, M, N, diagonal, out, device):
        """Build the keyword inputs that are passed to :meth:`forward`.

        Args:
            dim: 2 creates a random (M, N) matrix; any other value creates a
                random vector of length M.
            M: First (or only) size of the random input.
            N: Second size of the random input; used only when ``dim == 2``.
            diagonal: Diagonal offset forwarded to ``torch.diag``.
            out: Whether ``forward`` writes into the preallocated ``out_tensor``.
            device: Device on which all tensors are allocated.
        """
        if dim == 2:
            input_tensor = torch.rand(M, N, device=device)
        else:
            input_tensor = torch.rand(M, device=device)

        # Keys must match the parameter names of `forward`, which receives
        # these entries as keyword arguments.
        self.inputs = {
            "input": input_tensor,
            "diagonal": diagonal,
            "out": out,
            # Empty tensor so that the out= call can resize it to the
            # result shape.
            "out_tensor": torch.tensor(
                (),
                device=device,
            ),
        }
        self.set_module_name("diag")

    def forward(
        self,
        input: torch.Tensor,  # name fixed by the "input" key in self.inputs
        diagonal: int,
        out: bool,
        out_tensor: torch.Tensor,
    ) -> torch.Tensor:
        """Run ``torch.diag`` once, optionally writing into ``out_tensor``.

        Returns:
            The result of ``torch.diag``. When ``out`` is True, this is
            ``out_tensor`` filled with the result.
        """
        if out:
            return torch.diag(input, diagonal=diagonal, out=out_tensor)
        return torch.diag(input, diagonal=diagonal)


op_bench.generate_pt_test(diag_configs_short, DiagBenchmark)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()