"""Microbenchmarks for quantized instancenorm operator."""

# The module docstring must be the first statement of the module to become
# ``__doc__``; placed after the imports it was only a discarded expression.
import operator_benchmark as op_bench

import torch


# Configurations for the "short" benchmark tag.
#   * dims: two input shapes, one rank-3 and one rank-4. Dimension 1 of each
#     shape sizes the per-channel weight/bias built in the benchmark's init.
#   * dtype: only torch.qint8 is exercised.
# Keyword order is kept as-is; cross_product_configs may preserve it when
# building each config.
instancenorm_configs_short = op_bench.cross_product_configs(
    dims=((32, 8, 16), (32, 8, 56, 56)),
    dtype=(torch.qint8,),
    tags=["short"],
)


class QInstanceNormBenchmark(op_bench.TorchBenchmarkBase):
    """Time ``torch.ops.quantized.instance_norm`` on a per-tensor quantized input.

    ``init`` prepares the operator inputs for a single config, and ``forward``
    performs the call being measured. The keys stored in ``self.inputs`` must
    match the parameter names of ``forward``. Presumably the harness passes
    them as keyword arguments (NOTE(review): confirm against
    ``TorchBenchmarkBase``).
    """

    def init(self, dims, dtype):
        """Create the quantized input tensor and the affine parameters.

        Args:
            dims: shape of the float tensor to generate. ``dims[1]`` sizes
                ``weight`` and ``bias``, so it is treated as the channel
                dimension.
            dtype: quantized dtype handed to ``torch.quantize_per_tensor``.
        """
        # torch.rand is uniform in [0, 1), so this yields floats in
        # [-128, 128). With scale 1.0 and zero point 0 that roughly covers
        # the integer range of torch.qint8.
        samples = torch.rand(*dims)
        samples = (samples - 0.5) * 256
        quantized_input = torch.quantize_per_tensor(
            samples, scale=1.0, zero_point=0, dtype=dtype
        )

        channel_count = dims[1]
        affine_weight = torch.rand(channel_count, dtype=torch.float)
        affine_bias = torch.rand(channel_count, dtype=torch.float)

        self.inputs = dict(
            qX=quantized_input,
            weight=affine_weight,
            bias=affine_bias,
            eps=1e-5,
            # Passed to the operator as output_scale / output_zero_point.
            Y_scale=0.1,
            Y_zero_point=0,
        )

    def forward(self, qX, weight, bias, eps: float, Y_scale: float, Y_zero_point: int):
        """Apply quantized instance norm once and return the operator's result.

        The scalar type annotations are kept explicit. The harness presumably
        compiles this method with TorchScript (TODO confirm), so the call
        stays a direct ``torch.ops`` invocation.
        """
        return torch.ops.quantized.instance_norm(
            qX,
            weight=weight,
            bias=bias,
            eps=eps,
            output_zero_point=Y_zero_point,
            output_scale=Y_scale,
        )


# Register the benchmark class with the operator_benchmark framework for
# every configuration in instancenorm_configs_short.
op_bench.generate_pt_test(instancenorm_configs_short, QInstanceNormBenchmark)


if __name__ == "__main__":
    # When executed as a script, hand control to the operator_benchmark runner.
    op_bench.benchmark_runner.main()
