• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1from typing import List
2
3import operator_benchmark as op_bench
4
5import torch
6import torch.ao.nn.quantized as nnq
7
8
# Microbenchmarks for the quantized Cat operator.

# Short configs for the PT Cat operator. Each row of `attrs` lists
# (M, N, K, L, dim); cross_product_configs pairs every row with each
# contiguity mode and each quantized dtype below.
qcat_configs_short = op_bench.config_list(
    attr_names=["M", "N", "K", "L", "dim"],
    attrs=[
        #  M,   N,  K, L, dim
        [256, 512, 1, 2, 0],
        [512, 512, 2, 1, 1],
    ],
    cross_product_configs={
        "contig": ("all", "one", "none"),
        "dtype": (torch.quint8, torch.qint8, torch.qint32),
    },
    tags=["short"],
)
24
# Long configs: every listed value of each attribute is combined via
# cross_product_configs, restricted to quint8. Keyword order is kept as-is
# since it presumably determines the ordering of the generated configs.
# NOTE: `L` only multiplies the number of configs; QCatBenchmark.init never
# reads it.
qcat_configs_long = op_bench.cross_product_configs(
    M=[128, 1024],
    N=[128, 1024],
    K=[1, 2],
    L=[5, 7],
    dim=[0, 1, 2],
    contig=["all", "one", "none"],
    dtype=[torch.quint8],
    tags=["long"],
)
35
36
class QCatBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark ``QFunctional.cat`` on two quantized 3-D tensors.

    Both concatenated inputs hold the same quantized values; the ``contig``
    config selects whether each is passed as the contiguous tensor or as a
    non-contiguous view with identical shape and values.
    """

    def init(self, M, N, K, L, dim, contig, dtype):
        """Build the quantized inputs for one benchmark configuration.

        Args:
            M, N, K: Sizes of the ``(M, N, K)`` input tensor.
            L: Present in every config but not read by this method.
            dim: Dimension along which the two inputs are concatenated.
            contig: ``"all"`` -- both inputs contiguous; ``"one"`` -- first
                contiguous, second non-contiguous; ``"none"`` -- both
                non-contiguous.
            dtype: Quantized dtype passed to ``torch.quantize_per_tensor``.

        Raises:
            ValueError: If ``contig`` is not ``"none"``, ``"one"`` or ``"all"``.
        """
        # Explicit check instead of ``assert``: under ``python -O`` an assert is
        # stripped, and an unknown mode would leave ``self.input`` unset and fail
        # later with a confusing AttributeError.
        if contig not in ("none", "one", "all"):
            raise ValueError(
                f"contig must be one of 'none', 'one', 'all'; got {contig!r}"
            )

        # Float values in [-128, 128) before quantization.
        f_input = (torch.rand(M, N, K) - 0.5) * 256
        self.qf = nnq.QFunctional()
        scale = 1.0
        zero_point = 0
        # Presumably used by qf.cat as the output quantization parameters.
        self.qf.scale = scale
        self.qf.zero_point = zero_point

        q_input = torch.quantize_per_tensor(f_input, scale, zero_point, dtype)
        # Materialize the tensor in reversed-dimension order, then reverse the
        # dimensions back: same shape and values as q_input, but with reversed
        # strides, i.e. non-contiguous whenever two or more dims exceed 1.
        permute_dims = tuple(range(q_input.ndim - 1, -1, -1))
        q_input_non_contig = q_input.permute(permute_dims).contiguous()
        q_input_non_contig = q_input_non_contig.permute(permute_dims)

        if contig == "all":
            self.input = (q_input, q_input)
        elif contig == "one":
            self.input = (q_input, q_input_non_contig)
        else:  # contig == "none" (validated above)
            self.input = (q_input_non_contig, q_input_non_contig)

        self.inputs = {"input": self.input, "dim": dim}
        self.set_module_name("qcat")

    def forward(self, input: List[torch.Tensor], dim: int):
        """Concatenate the quantized tensors in ``input`` along ``dim``."""
        # ``input`` shadows the builtin, but it must match the "input" key in
        # ``self.inputs`` (presumably passed to forward by keyword).
        return self.qf.cat(input, dim=dim)
63
64
# Register QCatBenchmark against both the short and the long config sets.
op_bench.generate_pt_test(
    qcat_configs_short + qcat_configs_long,
    QCatBenchmark,
)


if __name__ == "__main__":
    # Run the registered benchmarks when executed as a script.
    op_bench.benchmark_runner.main()
70