"""Microbenchmarks for the ``torch.bmm`` and ``torch.matmul`` operators.

Both ops are run over the same cross product of input shapes (see
``bmm_configs``) using the PyTorch operator_benchmark harness.
"""

import operator_benchmark as op_bench

import torch


class BmmBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark a batched matrix multiply of (B, M, K) x (B, K, N) inputs."""

    def init(self, B, M, N, K, device, op):
        """Create the two random input batches and select the op under test.

        Args:
            B: batch size shared by both inputs.
            M: size of dim 1 of ``batch1``.
            N: size of dim 2 of ``batch2``.
            K: shared inner dimension (dim 2 of ``batch1``, dim 1 of ``batch2``).
            device: device on which the inputs are allocated.
            op: ``"bmm"`` selects ``torch.bmm``; any other value selects
                ``torch.matmul``.
        """
        self.inputs = {
            # self.auto_set() decides, per input, whether it requires grad.
            "batch1": torch.rand(
                (B, M, K), device=device, requires_grad=self.auto_set()
            ),
            "batch2": torch.rand(
                (B, K, N), device=device, requires_grad=self.auto_set()
            ),
        }
        # NOTE(review): this name has an unbalanced "(" -- kept as-is because
        # changing it would rename the benchmark entries results are keyed by.
        self.set_module_name(f"bmm (actual op={op}")
        self.op = torch.bmm if op == "bmm" else torch.matmul

    def forward(self, batch1, batch2):
        """Return ``batch1 @ batch2`` computed with the selected op."""
        return self.op(batch1, batch2)


# 2 values each of B, M, N, K and op -> 32 CPU configurations, all tagged
# "short".
bmm_configs = op_bench.cross_product_configs(
    B=[2, 100],
    M=[8, 256],
    N=[256, 16],
    K=[16, 32],
    device=["cpu"],
    tags=["short"],
    op=["bmm", "matmul"],
)

op_bench.generate_pt_test(bmm_configs, BmmBenchmark)

if __name__ == "__main__":
    op_bench.benchmark_runner.main()