• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
import operator_benchmark as op_bench

import torch


"""Microbenchmarks for the bmm and matmul operators."""
class BmmBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark batched matrix multiply: (B, M, K) @ (B, K, N) -> (B, M, N).

    Depending on ``op``, the measured callable is either ``torch.bmm`` or
    ``torch.matmul`` applied to the same batched operands.
    """

    def init(self, B, M, N, K, device, op):
        # auto_set() lets the harness toggle requires_grad per input so both
        # forward-only and backward benchmarks can be generated.
        self.inputs = {
            "batch1": torch.rand(
                (B, M, K), device=device, requires_grad=self.auto_set()
            ),
            "batch2": torch.rand(
                (B, K, N), device=device, requires_grad=self.auto_set()
            ),
        }
        # Bug fix: the original f-string "bmm (actual op={op}" never closed
        # the parenthesis, producing a malformed module name in reports.
        self.set_module_name(f"bmm (actual op={op})")
        self.op = torch.bmm if op == "bmm" else torch.matmul

    def forward(self, batch1, batch2):
        return self.op(batch1, batch2)
30
31
# Full cross product of benchmark configurations: every combination of the
# batch size B, matrix dims M/N/K, CPU device, and both op variants
# ("bmm" vs. the "matmul" fallback) — 2*2*2*2*1*2 = 32 configs, all tagged
# "short" so they run in the quick benchmark suite.
bmm_configs = op_bench.cross_product_configs(
    B=[2, 100],
    M=[8, 256],
    N=[256, 16],
    K=[16, 32],
    device=["cpu"],
    tags=["short"],
    op=["bmm", "matmul"],
)

# Register a PyTorch benchmark case for each generated config.
op_bench.generate_pt_test(bmm_configs, BmmBenchmark)
43
if __name__ == "__main__":
    # Script entry point: run all registered benchmarks via the
    # operator_benchmark runner (handles CLI args, iteration counts, output).
    op_bench.benchmark_runner.main()
46