• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Benchmark-framework import first: op_bench pulls in torch transitively,
# so keeping this order preserves the original side-effect ordering.
import operator_benchmark as op_bench

import torch

"""Microbenchmarks for quantized layernorm operator."""
# Shapes swept by the "short" run: three small 3-D activations plus one
# large NCHW-style 4-D tensor.
_SHORT_DIMS = (
    (1, 8, 16),
    (8, 8, 16),
    (32, 8, 16),
    (64, 128, 56, 56),
)

# Cross product of input shape x quantized dtype for the short benchmark.
layernorm_configs_short = op_bench.cross_product_configs(
    dims=_SHORT_DIMS,
    dtype=(torch.qint8,),
    tags=["short"],
)
18
19
class QLayerNormBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark harness for the quantized ``layer_norm`` operator."""

    def init(self, dims, dtype):
        # Build a float tensor spanning roughly [-128, 128) and quantize it
        # with a fixed scale/zero-point so every run sees the same value range.
        float_input = (torch.rand(*dims) - 0.5) * 256
        self.qX = torch.quantize_per_tensor(
            float_input, scale=1.0, zero_point=0, dtype=dtype
        )

        # The affine weight/bias cover every dimension except the leading
        # (batch) one — the same span layer_norm normalizes over in forward().
        normalized_shape = self.qX.size()[1:]
        self.inputs = {
            "qX": self.qX,
            "weight": torch.rand(*normalized_shape, dtype=torch.float),
            "bias": torch.rand(*normalized_shape, dtype=torch.float),
            "eps": 1e-5,
            "Y_scale": 0.1,
            "Y_zero_point": 0,
        }

    def forward(self, qX, weight, bias, eps: float, Y_scale: float, Y_zero_point: int):
        # Normalize over all dims after the first; the output is re-quantized
        # with the requested output scale and zero-point.
        return torch.ops.quantized.layer_norm(
            qX,
            qX.size()[1:],
            weight=weight,
            bias=bias,
            eps=eps,
            output_scale=Y_scale,
            output_zero_point=Y_zero_point,
        )
48
49
# Register one parameterized PyTorch test per entry in the config sweep.
op_bench.generate_pt_test(layernorm_configs_short, QLayerNormBenchmark)


if __name__ == "__main__":
    # Run via the operator_benchmark CLI runner when executed as a script.
    op_bench.benchmark_runner.main()