from typing import List import operator_benchmark as op_bench import torch import torch.ao.nn.quantized as nnq """Microbenchmarks for quantized Cat operator""" # Configs for PT Cat operator qcat_configs_short = op_bench.config_list( attr_names=["M", "N", "K", "L", "dim"], attrs=[ [256, 512, 1, 2, 0], [512, 512, 2, 1, 1], ], cross_product_configs={ "contig": ("all", "one", "none"), "dtype": (torch.quint8, torch.qint8, torch.qint32), }, tags=["short"], ) qcat_configs_long = op_bench.cross_product_configs( M=[128, 1024], N=[128, 1024], K=[1, 2], L=[5, 7], dim=[0, 1, 2], contig=["all", "one", "none"], dtype=[torch.quint8], tags=["long"], ) class QCatBenchmark(op_bench.TorchBenchmarkBase): def init(self, M, N, K, L, dim, contig, dtype): f_input = (torch.rand(M, N, K) - 0.5) * 256 self.qf = nnq.QFunctional() scale = 1.0 zero_point = 0 self.qf.scale = scale self.qf.zero_point = zero_point assert contig in ("none", "one", "all") q_input = torch.quantize_per_tensor(f_input, scale, zero_point, dtype) permute_dims = tuple(range(q_input.ndim - 1, -1, -1)) q_input_non_contig = q_input.permute(permute_dims).contiguous() q_input_non_contig = q_input_non_contig.permute(permute_dims) if contig == "all": self.input = (q_input, q_input) elif contig == "one": self.input = (q_input, q_input_non_contig) elif contig == "none": self.input = (q_input_non_contig, q_input_non_contig) self.inputs = {"input": self.input, "dim": dim} self.set_module_name("qcat") def forward(self, input: List[torch.Tensor], dim: int): return self.qf.cat(input, dim=dim) op_bench.generate_pt_test(qcat_configs_short + qcat_configs_long, QCatBenchmark) if __name__ == "__main__": op_bench.benchmark_runner.main()