import operator_benchmark as op_bench

import torch


"""Microbenchmarks for MatMul operator"""

# Configs for PT Matmul operator
mm_short_configs = op_bench.config_list(
    attr_names=["M", "N", "K", "trans_a", "trans_b"],
    attrs=[
        [1, 1, 1, True, False],
        [128, 128, 128, True, False],
        [256, 256, 256, False, True],
    ],
    cross_product_configs={
        "device": ["cpu", "cuda"],
    },
    tags=["short"],
)


mm_long_configs = op_bench.cross_product_configs(
    M=[32],
    N=[512, 128],
    K=[64],
    trans_a=[False, True],
    trans_b=[True, False],
    device=["cpu", "cuda"],
    tags=["long"],
)


class MatMulBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, M, N, K, trans_a, trans_b, device):
        # Both operands end up with the shapes matmul expects: (M, N) and
        # (N, K). When trans_a/trans_b is False, the tensor is allocated with
        # swapped dimensions and transposed via .t(), producing a
        # non-contiguous view, so the benchmark exercises both contiguous
        # and transposed memory layouts.
        self.inputs = {
            "input_one": torch.rand(M, N, device=device)
            if trans_a
            else torch.rand(N, M, device=device).t(),
            "input_two": torch.rand(N, K, device=device)
            if trans_b
            else torch.rand(K, N, device=device).t(),
        }
        self.set_module_name("matmul")

    def forward(self, input_one, input_two):
        return torch.matmul(input_one, input_two)


op_bench.generate_pt_test(mm_long_configs + mm_short_configs, MatMulBenchmark)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()
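
# Usage sketch (an assumption about the surrounding repo layout, not part of
# this file's contract): if this file lives under
# benchmarks/operator_benchmark/pt/ in the PyTorch tree, it can be run from
# benchmarks/operator_benchmark/ with:
#
#     python -m pt.matmul_test
#
# The operator_benchmark runner also accepts filtering options (e.g. to run
# only the "short" or "long" tagged configs defined above); see the
# operator_benchmark README for the exact flag names.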