• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import operator_benchmark as op_bench
2
3import torch
4import torch.ao.quantization.observer as obs
5
6
# Explicit "short" rows for the min/max observers: the same (C, M, N, dtype)
# values listed once per device.
qobserver_short_configs_dict = dict(
    attr_names=("C", "M", "N", "dtype", "device"),
    attrs=tuple(
        (3, 512, 512, torch.quint8, device_name) for device_name in ("cpu", "cuda")
    ),
    tags=("short",),
)

# Explicit "short" row for the histogram observer; only a "cpu" row is listed.
q_hist_observer_short_configs_dict = dict(
    attr_names=("C", "M", "N", "dtype", "device"),
    attrs=((3, 512, 512, torch.quint8, "cpu"),),
    tags=("short",),
)

# Per-attribute value choices tagged "long" for the min/max observers.
# Key order is kept as C, M, N, device, dtype, tags.
qobserver_long_configs_dict = dict(
    C=(32, 64),
    M=(256, 1024),
    N=(256, 1024),
    device=("cpu", "cuda"),
    dtype=(torch.quint8,),  # dtype doesn't change the timing, keep the same
    tags=("long",),
)

# Per-attribute value choices tagged "long" for the histogram observer
# ("cpu" is the only device listed).
q_hist_observer_long_configs_dict = dict(
    C=(1, 3, 8),
    M=(256, 1024),
    N=(256, 1024),
    device=("cpu",),
    dtype=(torch.quint8,),  # dtype doesn't change the timing, keep the same
    tags=("long",),
)
39
40
# Each base config dict is combined with a pair of qschemes. The short
# variants pass the qschemes through config_list's ``cross_product_configs``
# argument; the long variants pass ``qscheme`` as the first keyword to
# cross_product_configs, followed by the base dict's keys in their own order.
# NOTE(review): how op_bench expands these is defined in operator_benchmark,
# not in this file - presumably every base row/combination is paired with
# every qscheme; confirm there before relying on it.

# --- min/max observers, per-tensor qschemes ---
qobserver_per_tensor_configs_short = op_bench.config_list(
    cross_product_configs=dict(
        qscheme=(torch.per_tensor_affine, torch.per_tensor_symmetric),
    ),
    **qobserver_short_configs_dict,
)

qobserver_per_tensor_configs_long = op_bench.cross_product_configs(
    qscheme=(torch.per_tensor_affine, torch.per_tensor_symmetric),
    **qobserver_long_configs_dict,
)

# --- min/max observers, per-channel qschemes ---
qobserver_per_channel_configs_short = op_bench.config_list(
    cross_product_configs=dict(
        qscheme=(torch.per_channel_affine, torch.per_channel_symmetric),
    ),
    **qobserver_short_configs_dict,
)

qobserver_per_channel_configs_long = op_bench.cross_product_configs(
    qscheme=(torch.per_channel_affine, torch.per_channel_symmetric),
    **qobserver_long_configs_dict,
)

# --- histogram observer, per-tensor qschemes ---
q_hist_observer_per_tensor_configs_short = op_bench.config_list(
    cross_product_configs=dict(
        qscheme=(torch.per_tensor_affine, torch.per_tensor_symmetric),
    ),
    **q_hist_observer_short_configs_dict,
)

q_hist_observer_per_tensor_configs_long = op_bench.cross_product_configs(
    qscheme=(torch.per_tensor_affine, torch.per_tensor_symmetric),
    **q_hist_observer_long_configs_dict,
)
76
77
# Op lists: every row is [op_name, op_func], where op_func is an observer
# class from torch.ao.quantization.observer.

# Per-tensor min/max observers; op_name is the class's attribute name on obs.
qobserver_per_tensor_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        [observer_name, getattr(obs, observer_name)]
        for observer_name in ("MinMaxObserver", "MovingAverageMinMaxObserver")
    ],
)

# Per-channel min/max observers; op_name is the class's attribute name on obs.
qobserver_per_channel_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        [observer_name, getattr(obs, observer_name)]
        for observer_name in (
            "PerChannelMinMaxObserver",
            "MovingAveragePerChannelMinMaxObserver",
        )
    ],
)

# Histogram observer, listed under two op names.
# NOTE(review): both rows use obs.HistogramObserver, so they differ only in
# op_name - confirm that is intended.
q_hist_observer_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        [op_label, obs.HistogramObserver]
        for op_label in ("HistogramObserver", "HistogramObserverCalculateQparams")
    ],
)
104
105
class QObserverBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark case that observes a tensor and then computes qparams.

    ``forward`` runs the observer on the input and calls
    ``calculate_qparams()`` on it, so both steps are part of each call.
    NOTE(review): ``forward`` is presumably the region the operator_benchmark
    harness times - confirm against TorchBenchmarkBase.
    """

    def init(self, C, M, N, dtype, qscheme, op_func, device):
        """Create the random input and the observer under test.

        Args:
            C, M, N: Sizes of the three dimensions of the random input.
            dtype: Passed to the observer constructor as ``dtype``.
            qscheme: Passed to the observer constructor as ``qscheme``.
            op_func: Observer class (or factory) to instantiate.
            device: Device for both the input tensor and the observer.
        """
        sample = torch.rand(C, M, N, device=device)
        self.inputs = {"f_input": sample}
        observer = op_func(dtype=dtype, qscheme=qscheme)
        self.op_func = observer.to(device)

    def forward(self, f_input):
        """Observe ``f_input`` and return the observer's ``calculate_qparams()``."""
        observer = self.op_func
        observer(f_input)
        return observer.calculate_qparams()
114
115
class QObserverBenchmarkCalculateQparams(op_bench.TorchBenchmarkBase):
    """Benchmark case whose ``forward`` only calls ``calculate_qparams()``.

    The observer is run once on a random tensor inside ``init``; ``forward``
    takes no inputs and just computes qparams from the already-observed state.
    """

    def init(self, C, M, N, dtype, qscheme, op_func, device):
        """Build the observer and feed it one random tensor up front.

        Args:
            C, M, N: Sizes of the three dimensions of the random input.
            dtype: Passed to the observer constructor as ``dtype``.
            qscheme: Passed to the observer constructor as ``qscheme``.
            op_func: Observer class (or factory) to instantiate.
            device: Device for both the input tensor and the observer.
        """
        self.f_input = torch.rand(C, M, N, device=device)
        observer = op_func(dtype=dtype, qscheme=qscheme)
        self.q_observer = observer.to(device)
        # Observe once here so that forward() does not have to.
        self.q_observer(self.f_input)
        # forward() takes no arguments, hence no named inputs.
        self.inputs = {}

    def forward(self):
        """Return ``calculate_qparams()`` of the pre-populated observer."""
        observer = self.q_observer
        return observer.calculate_qparams()
125
126
# Register the benchmarks. Each entry is (op list, short + long configs,
# benchmark class); entries are registered in the order listed:
# per-tensor min/max, per-channel min/max, then histogram.
for _ops, _configs, _bench_cls in (
    (
        qobserver_per_tensor_list,
        qobserver_per_tensor_configs_short + qobserver_per_tensor_configs_long,
        QObserverBenchmark,
    ),
    (
        qobserver_per_channel_list,
        qobserver_per_channel_configs_short + qobserver_per_channel_configs_long,
        QObserverBenchmark,
    ),
    (
        q_hist_observer_list,
        q_hist_observer_per_tensor_configs_short
        + q_hist_observer_per_tensor_configs_long,
        QObserverBenchmarkCalculateQparams,
    ),
):
    op_bench.generate_pt_tests_from_op_list(_ops, _configs, _bench_cls)
# Loop variables only; drop them so the module namespace gains no new names.
del _ops, _configs, _bench_cls


# Run the benchmark runner when this file is executed as a script.
if __name__ == "__main__":
    op_bench.benchmark_runner.main()
148