import operator_benchmark as op_bench

import torch


"""
Microbenchmarks for batch matrix mult with einsum and torch.bmm.
"""

# Short configs: small-to-moderate problem sizes, run on both CPU and CUDA.
batch_mm_configs_short = op_bench.config_list(
    attr_names=["B", "M", "N", "K"],
    attrs=[
        [4, 5, 3, 2],
        [32, 25, 20, 30],
        [128, 100, 120, 110],
    ],
    cross_product_configs={
        "device": ["cpu", "cuda"],
    },
    tags=["short"],
)

# Long configs: larger batched matmuls.
batch_mm_configs_long = op_bench.config_list(
    attr_names=["B", "M", "N", "K"],
    attrs=[
        [128, 256, 128, 256],
        [512, 1024, 1024, 512],
    ],
    cross_product_configs={
        "device": ["cpu", "cuda"],
    },
    tags=["long"],
)

# Each entry becomes a separately named benchmark; op_func is handed to init().
batch_mm_op_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        ["einsum_bmm", torch.einsum],
        ["bmm", torch.bmm],
    ],
)


class BatchMatrixMultBenchmark(op_bench.TorchBenchmarkBase):
    """Batched matmul (B, M, N) x (B, N, K) -> (B, M, K) via einsum or bmm."""

    def init(self, B, M, N, K, device, op_func):
        """Allocate random operands on ``device`` and pick the kernel to time.

        Args:
            B: batch size.
            M, N, K: ``input_one`` is (B, M, N) and ``input_two`` is (B, N, K).
            device: device passed to ``torch.rand`` ("cpu" or "cuda" here).
            op_func: ``torch.einsum`` or ``torch.bmm`` from ``batch_mm_op_list``.
        """
        self.inputs = {
            "input_one": torch.rand(B, M, N, device=device),
            "input_two": torch.rand(B, N, K, device=device),
        }
        self.op_func = op_func
        # Resolve the dispatch once here instead of on every timed forward()
        # call, so the name lookup/string compare stays out of the measurement.
        self._use_einsum = op_func.__name__ == "einsum"

    def forward(self, input_one, input_two):
        """Run one batched matmul with the kernel selected in ``init``."""
        if self._use_einsum:
            return torch.einsum("bij,bjk->bik", input_one, input_two)
        return torch.bmm(input_one, input_two)


"""
Microbenchmarks for element-wise matrix mult with einsum and torch.mul.
"""

# Short configs: small element-wise products, run on both CPU and CUDA.
batch_elementwise_configs_short = op_bench.config_list(
    attr_names=["B", "M", "N"],
    attrs=[
        [4, 5, 3],
        [32, 25, 20],
        [100, 90, 110],
    ],
    cross_product_configs={
        "device": ["cpu", "cuda"],
    },
    tags=["short"],
)


# Long configs: full cross product of B, M, N and device.
# NOTE: this includes B = M = N = 1024, i.e. 2**30 elements per tensor
# (~4 GiB each with float32 from torch.rand). Two inputs plus the output means
# roughly 12 GiB, which may exceed available device memory.
batch_elementwise_configs_long = op_bench.cross_product_configs(
    B=[128, 512, 1024],
    M=[128, 512, 1024],
    N=[128, 512, 1024],
    device=["cpu", "cuda"],
    tags=["long"],
)

# Each entry becomes a separately named benchmark; op_func is handed to init().
batch_elementwise_op_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        ["einsum_elementwise", torch.einsum],
        ["mul", torch.mul],
    ],
)


class BatchElementWiseBenchmark(op_bench.TorchBenchmarkBase):
    """Element-wise product of two (B, M, N) tensors via einsum or mul."""

    def init(self, B, M, N, device, op_func):
        """Allocate two same-shaped random operands and pick the kernel.

        Args:
            B, M, N: both inputs have shape (B, M, N).
            device: device passed to ``torch.rand`` ("cpu" or "cuda" here).
            op_func: ``torch.einsum`` or ``torch.mul`` from
                ``batch_elementwise_op_list``.
        """
        self.inputs = {
            "input_one": torch.rand(B, M, N, device=device),
            "input_two": torch.rand(B, M, N, device=device),
        }
        self.op_func = op_func
        # Resolve the dispatch once here instead of on every timed forward()
        # call, so the name lookup/string compare stays out of the measurement.
        self._use_einsum = op_func.__name__ == "einsum"

    def forward(self, input_one, input_two):
        """Run one element-wise product with the kernel selected in ``init``."""
        if self._use_einsum:
            return torch.einsum("bij,bij->bij", input_one, input_two)
        return torch.mul(input_one, input_two)


# Register every (op, config) pair with the benchmark framework.
op_bench.generate_pt_tests_from_op_list(
    batch_mm_op_list,
    batch_mm_configs_short + batch_mm_configs_long,
    BatchMatrixMultBenchmark,
)

op_bench.generate_pt_tests_from_op_list(
    batch_elementwise_op_list,
    batch_elementwise_configs_short + batch_elementwise_configs_long,
    BatchElementWiseBenchmark,
)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()