import operator_benchmark as op_bench

import torch


"""Microbenchmarks for diag operator"""


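# Example invocation (a sketch; assumes this file lives in the standard
# operator_benchmark layout, e.g. benchmarks/operator_benchmark/pt/diag_test.py,
# and is run from the benchmarks/operator_benchmark directory):
#   python -m pt.diag_test
# The runner's default settings should pick up the configs tagged "short" below.
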
# Configs for PT diag operator
diag_configs_short = op_bench.config_list(
    attr_names=["dim", "M", "N", "diagonal", "out"],
    attrs=[
        [1, 64, 64, 0, True],
        [2, 128, 128, -10, False],
        [1, 256, 256, 20, True],
    ],
    cross_product_configs={
        "device": ["cpu", "cuda"],
    },
    tags=["short"],
)


class DiagBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, dim, M, N, diagonal, out, device):
        self.inputs = {
            # 1-D input when dim == 1, 2-D (M x N) input when dim == 2
            "input": torch.rand(M, N, device=device)
            if dim == 2
            else torch.rand(M, device=device),
            "diagonal": diagonal,
            "out": out,
            "out_tensor": torch.tensor(
                (),
                device=device,
            ),
        }
        self.set_module_name("diag")

    def forward(self, input, diagonal: int, out: bool, out_tensor):
        # Exercise the out= variant when requested; otherwise let diag allocate.
        if out:
            return torch.diag(input, diagonal=diagonal, out=out_tensor)
        else:
            return torch.diag(input, diagonal=diagonal)


op_bench.generate_pt_test(diag_configs_short, DiagBenchmark)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()