"""Microbenchmarks for the quantized Cat operator."""

from typing import List

import operator_benchmark as op_bench

import torch
import torch.ao.nn.quantized as nnq


# Configs for the PT quantized Cat operator
qcat_configs_short = op_bench.config_list(
    attr_names=["M", "N", "K", "L", "dim"],
    attrs=[
        [256, 512, 1, 2, 0],
        [512, 512, 2, 1, 1],
    ],
    cross_product_configs={
        "contig": ("all", "one", "none"),
        "dtype": (torch.quint8, torch.qint8, torch.qint32),
    },
    tags=["short"],
)

qcat_configs_long = op_bench.cross_product_configs(
    M=[128, 1024],
    N=[128, 1024],
    K=[1, 2],
    L=[5, 7],
    dim=[0, 1, 2],
    contig=["all", "one", "none"],
    dtype=[torch.quint8],
    tags=["long"],
)


class QCatBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, M, N, K, L, dim, contig, dtype):
        # NOTE: L is part of the config space but is not used below; every
        # benchmark instance concatenates exactly two tensors.
        f_input = (torch.rand(M, N, K) - 0.5) * 256
        self.qf = nnq.QFunctional()
        # Requantization parameters applied to the QFunctional output.
        scale = 1.0
        zero_point = 0
        self.qf.scale = scale
        self.qf.zero_point = zero_point

        assert contig in ("none", "one", "all")
        q_input = torch.quantize_per_tensor(f_input, scale, zero_point, dtype)
        # Build a non-contiguous tensor with the same shape: materialize a
        # reversed-dims copy, then permute back, so only the memory layout
        # (strides) differs from q_input.
        permute_dims = tuple(range(q_input.ndim - 1, -1, -1))
        q_input_non_contig = q_input.permute(permute_dims).contiguous()
        q_input_non_contig = q_input_non_contig.permute(permute_dims)
        # `contig` controls how many of the two inputs are contiguous.
        if contig == "all":
            self.input = (q_input, q_input)
        elif contig == "one":
            self.input = (q_input, q_input_non_contig)
        elif contig == "none":
            self.input = (q_input_non_contig, q_input_non_contig)

        self.inputs = {"input": self.input, "dim": dim}
        self.set_module_name("qcat")

    def forward(self, input: List[torch.Tensor], dim: int):
        return self.qf.cat(input, dim=dim)


op_bench.generate_pt_test(qcat_configs_short + qcat_configs_long, QCatBenchmark)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()
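
# ---------------------------------------------------------------------------
# Usage sketch (assumptions: this file lives in the operator_benchmark `pt`
# package of the PyTorch repo and is run from benchmarks/operator_benchmark;
# check the runner's --help for the exact tag-filtering flags):
#
#   python -m pt.qcat_test
#
# What the benchmark exercises, as a minimal standalone snippet. This is
# illustrative only, not part of the harness:
#
#   import torch
#   import torch.ao.nn.quantized as nnq
#
#   qf = nnq.QFunctional()
#   qf.scale, qf.zero_point = 1.0, 0
#   x = torch.quantize_per_tensor(torch.rand(4, 8), 1.0, 0, torch.quint8)
#   y = qf.cat([x, x], dim=0)  # concatenates, requantizing to qf's params
#   assert y.shape == (8, 8)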