"""Operator benchmarks for the observers in ``torch.ao.quantization.observer``.

The module builds "short" and "long" input configurations, pairs them with
lists of observer classes, and registers the resulting test cases through
``operator_benchmark``.
"""

import operator_benchmark as op_bench

import torch
import torch.ao.quantization.observer as obs


# Quantization schemes swept by the per-tensor and per-channel benchmarks.
_PER_TENSOR_QSCHEMES = (torch.per_tensor_affine, torch.per_tensor_symmetric)
_PER_CHANNEL_QSCHEMES = (torch.per_channel_affine, torch.per_channel_symmetric)


# ---------------------------------------------------------------------------
# Base shape / dtype / device configurations
# ---------------------------------------------------------------------------

# Explicit (C, M, N, dtype, device) tuples for the min/max style observers.
qobserver_short_configs_dict = {
    "attr_names": ("C", "M", "N", "dtype", "device"),
    "attrs": (
        (3, 512, 512, torch.quint8, "cpu"),
        (3, 512, 512, torch.quint8, "cuda"),
    ),
    "tags": ("short",),
}

# Same layout as above, but the histogram observer only lists a CPU entry.
q_hist_observer_short_configs_dict = {
    "attr_names": ("C", "M", "N", "dtype", "device"),
    "attrs": ((3, 512, 512, torch.quint8, "cpu"),),
    "tags": ("short",),
}

# Per-attribute value sets; these are expanded via
# ``op_bench.cross_product_configs`` further down.
qobserver_long_configs_dict = {
    "C": (32, 64),
    "M": (256, 1024),
    "N": (256, 1024),
    "device": ("cpu", "cuda"),
    "dtype": (torch.quint8,),  # dtype doesn't change the timing, keep the same
    "tags": ("long",),
}

q_hist_observer_long_configs_dict = {
    "C": (1, 3, 8),
    "M": (256, 1024),
    "N": (256, 1024),
    "device": ("cpu",),
    "dtype": (torch.quint8,),  # dtype doesn't change the timing, keep the same
    "tags": ("long",),
}


# ---------------------------------------------------------------------------
# Configurations combined with quantization schemes
# ---------------------------------------------------------------------------

# Per-tensor observers.
qobserver_per_tensor_configs_short = op_bench.config_list(
    cross_product_configs={"qscheme": _PER_TENSOR_QSCHEMES},
    **qobserver_short_configs_dict,
)

qobserver_per_tensor_configs_long = op_bench.cross_product_configs(
    qscheme=_PER_TENSOR_QSCHEMES,
    **qobserver_long_configs_dict,
)

# Per-channel observers.
qobserver_per_channel_configs_short = op_bench.config_list(
    cross_product_configs={"qscheme": _PER_CHANNEL_QSCHEMES},
    **qobserver_short_configs_dict,
)

qobserver_per_channel_configs_long = op_bench.cross_product_configs(
    qscheme=_PER_CHANNEL_QSCHEMES,
    **qobserver_long_configs_dict,
)

# Histogram observer (per-tensor schemes only).
q_hist_observer_per_tensor_configs_short = op_bench.config_list(
    cross_product_configs={"qscheme": _PER_TENSOR_QSCHEMES},
    **q_hist_observer_short_configs_dict,
)

q_hist_observer_per_tensor_configs_long = op_bench.cross_product_configs(
    qscheme=_PER_TENSOR_QSCHEMES,
    **q_hist_observer_long_configs_dict,
)


# ---------------------------------------------------------------------------
# Observer classes under test
# ---------------------------------------------------------------------------

qobserver_per_tensor_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        ["MinMaxObserver", obs.MinMaxObserver],
        ["MovingAverageMinMaxObserver", obs.MovingAverageMinMaxObserver],
    ],
)

qobserver_per_channel_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        ["PerChannelMinMaxObserver", obs.PerChannelMinMaxObserver],
        [
            "MovingAveragePerChannelMinMaxObserver",
            obs.MovingAveragePerChannelMinMaxObserver,
        ],
    ],
)

# NOTE(review): both entries map to ``obs.HistogramObserver`` and this list is
# registered below with ``QObserverBenchmarkCalculateQparams`` only, so the two
# op names appear to exercise the same code path -- confirm this is intended.
q_hist_observer_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        ["HistogramObserver", obs.HistogramObserver],
        ["HistogramObserverCalculateQparams", obs.HistogramObserver],
    ],
)


# ---------------------------------------------------------------------------
# Benchmark definitions
# ---------------------------------------------------------------------------


class QObserverBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark that observes an input and then computes qparams."""

    def init(self, C, M, N, dtype, qscheme, op_func, device):
        """Create a random ``(C, M, N)`` input and the observer on ``device``.

        ``op_func`` is the observer class; it is instantiated with the given
        ``dtype`` and ``qscheme`` and then moved to ``device``.
        """
        sample = torch.rand(C, M, N, device=device)
        self.inputs = {"f_input": sample}

        observer = op_func(dtype=dtype, qscheme=qscheme)
        self.op_func = observer.to(device)

    def forward(self, f_input):
        """Run the observer on ``f_input`` and return ``calculate_qparams()``."""
        self.op_func(f_input)
        return self.op_func.calculate_qparams()


class QObserverBenchmarkCalculateQparams(op_bench.TorchBenchmarkBase):
    """Benchmark whose ``forward`` only calls ``calculate_qparams()``."""

    def init(self, C, M, N, dtype, qscheme, op_func, device):
        """Build the observer and feed it one random ``(C, M, N)`` input.

        The observation happens here, during setup, so ``forward`` takes no
        inputs (``self.inputs`` is left empty).
        """
        sample = torch.rand(C, M, N, device=device)
        observer = op_func(dtype=dtype, qscheme=qscheme).to(device)
        observer(sample)

        self.f_input = sample
        self.q_observer = observer
        self.inputs = {}

    def forward(self):
        """Return the qparams computed by the already-populated observer."""
        return self.q_observer.calculate_qparams()


# ---------------------------------------------------------------------------
# Test registration
# ---------------------------------------------------------------------------

op_bench.generate_pt_tests_from_op_list(
    qobserver_per_tensor_list,
    qobserver_per_tensor_configs_short + qobserver_per_tensor_configs_long,
    QObserverBenchmark,
)

op_bench.generate_pt_tests_from_op_list(
    qobserver_per_channel_list,
    qobserver_per_channel_configs_short + qobserver_per_channel_configs_long,
    QObserverBenchmark,
)

op_bench.generate_pt_tests_from_op_list(
    q_hist_observer_list,
    q_hist_observer_per_tensor_configs_short + q_hist_observer_per_tensor_configs_long,
    QObserverBenchmarkCalculateQparams,
)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()