• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1from pt import configs
2
3import operator_benchmark as op_bench
4
5import torch
6import torch.ao.nn.quantized as nnq
7import torch.ao.nn.quantized.dynamic as nnqd
8
9
10"""
11Microbenchmarks for Quantized Linear operators.
12"""
13
14
class _QLinearBenchmarkBase(op_bench.TorchBenchmarkBase):
    """Shared setup for the quantized ``Linear`` microbenchmarks.

    A subclass constructs the module under test, hands it to :meth:`init`,
    and then sets ``self.inputs`` to the tensor that :meth:`forward` receives.
    The subclasses in this file use a static ``nnq.Linear`` or a dynamic
    ``nnqd.Linear``.
    """

    def init(self, N, IN, OUT, linear_under_test):
        """Build the inputs and install a random quantized weight.

        Args:
            N: number of rows of the input ``(N, IN)``.
            IN: input feature count.
            OUT: output feature count.
            linear_under_test: quantized Linear module built with
                ``(IN, OUT)``; stored as ``self.qlinear``.

        Sets ``self.X`` (float32 input of shape ``(N, IN)``) and ``self.qX``
        (its ``quint8`` per-tensor quantization) so each subclass can pick the
        input it needs.
        """
        # Plain Python scalars rather than 0-dim tensors: the module's output
        # qparams are a float/int, and tensor arguments would have to be
        # converted on every forward call, inside the timed region.
        scale = 1.0 / 255
        zero_point = 0
        self.X = torch.randn(N, IN, dtype=torch.float32)
        self.qX = torch.quantize_per_tensor(
            self.X, scale=scale, zero_point=zero_point, dtype=torch.quint8
        )
        W = torch.randn(OUT, IN, dtype=torch.float32)
        qW = torch.quantize_per_tensor(W, scale=scale, zero_point=0, dtype=torch.qint8)

        self.qlinear = linear_under_test
        # ``weight`` is a method on quantized Linear modules, not a settable
        # attribute. Assigning to it would only shadow the method and leave the
        # constructor's placeholder weight packed. Repack with the random
        # weight and keep the module's existing bias.
        self.qlinear.set_weight_bias(qW, self.qlinear.bias())
        self.qlinear.scale = scale
        self.qlinear.zero_point = zero_point

    def forward(self, input):
        """Run the module under test on ``input``.

        ``input`` corresponds to the ``"input"`` key of ``self.inputs`` set by
        the subclass (presumably the harness passes ``self.inputs`` as keyword
        arguments).
        """
        return self.qlinear(input)
35
36
class QLinearBenchmark(_QLinearBenchmarkBase):
    """Benchmark the static quantized ``nnq.Linear`` on a quantized input."""

    def init(self, N, IN, OUT, device):
        # ``device`` is unused here; presumably it is supplied by the configs.
        module = nnq.Linear(IN, OUT)
        super().init(N, IN, OUT, module)
        # The static module consumes the pre-quantized tensor.
        self.inputs = {"input": self.qX}
        self.set_module_name("QLinear")
42
43
class QDynamicLinearBenchmark(_QLinearBenchmarkBase):
    """Benchmark the dynamic quantized ``nnqd.Linear`` on a float input."""

    def init(self, N, IN, OUT, device):
        # ``device`` is unused here; presumably it is supplied by the configs.
        module = nnqd.Linear(IN, OUT)
        super().init(N, IN, OUT, module)
        # The dynamic module is fed the float tensor, not ``self.qX``.
        self.inputs = {"input": self.X}
        self.set_module_name("QDynamicLinear")
49
50
# Register both benchmark classes over the same CPU-only config set.
# ``remove_cuda`` is called once per class, as before, so each registration
# receives its own config list.
for _benchmark_class in (QLinearBenchmark, QDynamicLinearBenchmark):
    op_bench.generate_pt_test(
        configs.remove_cuda(configs.linear_configs_short + configs.linear_configs_long),
        _benchmark_class,
    )
del _benchmark_class
59
60
if __name__ == "__main__":
    # Script entry point: hand control to the operator_benchmark runner,
    # which presumably executes the tests registered via generate_pt_test.
    runner = op_bench.benchmark_runner
    runner.main()
63