import operator_benchmark as op_bench

import torch
import torch.ao.nn.quantized.functional as qF


r"""Microbenchmarks for quantized activation operators."""

qactivation_long_configs = op_bench.cross_product_configs(
    dims=(
        # VGG-16 relu's with original shape: (-1, 3, 224, 224)
        (64, 224, 224),  # ReLU-1   # noqa: E201
        (128, 112, 112),  # ReLU-6
        (256, 56, 56),  # ReLU-11  # noqa: E241
        (512, 28, 28),  # ReLU-18  # noqa: E241
        (512, 14, 14),  # ReLU-25  # noqa: E241
        # Batch = 16
        (16, 64, 224, 224),  # ReLU-1   # noqa: E241
        (16, 128, 112, 112),  # ReLU-6
        (16, 256, 56, 56),  # ReLU-11  # noqa: E241
        (16, 512, 28, 28),  # ReLU-18  # noqa: E241
        (16, 512, 14, 14),  # ReLU-25  # noqa: E241
    ),
    contig=(False, True),
    inplace=(False, True),
    dtype=(torch.quint8,),
    tags=("long",),
)

qactivation_short_configs = op_bench.cross_product_configs(
    dims=(
        (3, 4, 5),  # Rank=3
        (2, 3, 4, 5),  # Rank=4
        # Dimensions from the floating point benchmarks
        (512, 512),
        (256, 1024),
    ),
    contig=(False,),
    inplace=(False,),
    dtype=(torch.quint8, torch.qint8, torch.qint32),
    tags=("short",),
)
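# Note: op_bench.cross_product_configs expands these attributes into their full
# Cartesian product (a rough count, assuming no extra filtering by the framework):
# the short set is 4 dims x 1 contig x 1 inplace x 3 dtypes = 12 configs per op,
# and the long set is 10 dims x 2 x 2 x 1 = 40 configs per op.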

qactivation_ops = op_bench.op_list(
    attrs=(
        ("relu", torch.nn.ReLU()),
        ("relu6", torch.ops.quantized.relu6),
        ("functional.hardtanh", qF.hardtanh),
        ("functional.hardsigmoid", qF.hardsigmoid),
        ("functional.leaky_relu", qF.leaky_relu),
        ("functional.sigmoid", torch.nn.functional.sigmoid),
        ("functional.tanh", torch.nn.functional.tanh),
    ),
    attr_names=("op_name", "op_func"),
)


class QActivationBenchmarkBase(op_bench.TorchBenchmarkBase):
    r"""Base class for all the activation benchmarks."""

    def _setup(self, dims, contig, dtype):
        # Random float input in [-128, 128)
        f_input = (torch.rand(*dims) - 0.5) * 256
        self.scale = 1.0
        self.zero_point = 0

        # Quantize the tensor
        q_input = torch.quantize_per_tensor(
            f_input, scale=self.scale, zero_point=self.zero_point, dtype=dtype
        )
        if not contig:
            # Make non-contiguous by permuting to the reversed dimension order
            new_shape = list(range(q_input.ndim))[::-1]
            q_input = q_input.permute(new_shape)

        self.inputs = {"q_input": q_input}

    def init(self, dims, contig, inplace, dtype, op_func):
        self._setup(dims, contig, dtype)
        self.qop = op_func


class QActivationBenchmark(QActivationBenchmarkBase):
    def forward(self, q_input):
        return self.qop(q_input)


op_bench.generate_pt_tests_from_op_list(
    qactivation_ops,
    qactivation_short_configs + qactivation_long_configs,
    QActivationBenchmark,
)
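# generate_pt_tests_from_op_list registers one benchmark case per (op, config)
# combination from the lists above; case names are built from op_name and the
# config attributes (the exact naming scheme is an operator_benchmark detail).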


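# hardswish, elu, and celu in torch.ao.nn.quantized.functional take explicit output
# quantization parameters, so they go through a separate benchmark class that
# forwards scale and zero_point (see QActivationScaleZeroPointBenchmark below).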
qactivation_scale_zero_point_ops = op_bench.op_list(
    attrs=(
        ("functional.hardswish", qF.hardswish),
        ("functional.elu", qF.elu),
        ("functional.celu", qF.celu),
    ),
    attr_names=("op_name", "op_func"),
)


class QActivationScaleZeroPointBenchmark(QActivationBenchmarkBase):
    def forward(self, q_input):
        return self.qop(q_input, scale=self.scale, zero_point=self.zero_point)


op_bench.generate_pt_tests_from_op_list(
    qactivation_scale_zero_point_ops,
    qactivation_short_configs + qactivation_long_configs,
    QActivationScaleZeroPointBenchmark,
)

if __name__ == "__main__":
    op_bench.benchmark_runner.main()
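# A sketch of how to run just this file (invoke from benchmarks/operator_benchmark;
# the runner flag spelling, e.g. --tag-filter vs. --tag_filter, varies across
# PyTorch versions, so check benchmark_runner's argparse setup in your checkout):
#
#   python -m pt.qactivation_test --tag-filter short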