import operator_benchmark as op_bench import torch import torch.ao.quantization.observer as obs qobserver_short_configs_dict = { "attr_names": ("C", "M", "N", "dtype", "device"), "attrs": ( (3, 512, 512, torch.quint8, "cpu"), (3, 512, 512, torch.quint8, "cuda"), ), "tags": ("short",), } q_hist_observer_short_configs_dict = { "attr_names": ("C", "M", "N", "dtype", "device"), "attrs": ((3, 512, 512, torch.quint8, "cpu"),), "tags": ("short",), } qobserver_long_configs_dict = { "C": (32, 64), "M": (256, 1024), "N": (256, 1024), "device": ("cpu", "cuda"), "dtype": (torch.quint8,), # dtype doesn't change the timing, keep the same "tags": ("long",), } q_hist_observer_long_configs_dict = { "C": (1, 3, 8), "M": (256, 1024), "N": (256, 1024), "device": ("cpu",), "dtype": (torch.quint8,), # dtype doesn't change the timing, keep the same "tags": ("long",), } qobserver_per_tensor_configs_short = op_bench.config_list( cross_product_configs={ "qscheme": (torch.per_tensor_affine, torch.per_tensor_symmetric) }, **qobserver_short_configs_dict, ) qobserver_per_tensor_configs_long = op_bench.cross_product_configs( qscheme=(torch.per_tensor_affine, torch.per_tensor_symmetric), **qobserver_long_configs_dict, ) qobserver_per_channel_configs_short = op_bench.config_list( cross_product_configs={ "qscheme": (torch.per_channel_affine, torch.per_channel_symmetric) }, **qobserver_short_configs_dict, ) qobserver_per_channel_configs_long = op_bench.cross_product_configs( qscheme=(torch.per_channel_affine, torch.per_channel_symmetric), **qobserver_long_configs_dict, ) q_hist_observer_per_tensor_configs_short = op_bench.config_list( cross_product_configs={ "qscheme": (torch.per_tensor_affine, torch.per_tensor_symmetric) }, **q_hist_observer_short_configs_dict, ) q_hist_observer_per_tensor_configs_long = op_bench.cross_product_configs( qscheme=(torch.per_tensor_affine, torch.per_tensor_symmetric), **q_hist_observer_long_configs_dict, ) qobserver_per_tensor_list = op_bench.op_list( attr_names=["op_name", "op_func"], attrs=[ ["MinMaxObserver", obs.MinMaxObserver], ["MovingAverageMinMaxObserver", obs.MovingAverageMinMaxObserver], ], ) qobserver_per_channel_list = op_bench.op_list( attr_names=["op_name", "op_func"], attrs=[ ["PerChannelMinMaxObserver", obs.PerChannelMinMaxObserver], [ "MovingAveragePerChannelMinMaxObserver", obs.MovingAveragePerChannelMinMaxObserver, ], ], ) q_hist_observer_list = op_bench.op_list( attr_names=["op_name", "op_func"], attrs=[ ["HistogramObserver", obs.HistogramObserver], ["HistogramObserverCalculateQparams", obs.HistogramObserver], ], ) class QObserverBenchmark(op_bench.TorchBenchmarkBase): def init(self, C, M, N, dtype, qscheme, op_func, device): self.inputs = {"f_input": torch.rand(C, M, N, device=device)} self.op_func = op_func(dtype=dtype, qscheme=qscheme).to(device) def forward(self, f_input): self.op_func(f_input) return self.op_func.calculate_qparams() class QObserverBenchmarkCalculateQparams(op_bench.TorchBenchmarkBase): def init(self, C, M, N, dtype, qscheme, op_func, device): self.f_input = torch.rand(C, M, N, device=device) self.q_observer = op_func(dtype=dtype, qscheme=qscheme).to(device) self.q_observer(self.f_input) self.inputs = {} def forward(self): return self.q_observer.calculate_qparams() op_bench.generate_pt_tests_from_op_list( qobserver_per_tensor_list, qobserver_per_tensor_configs_short + qobserver_per_tensor_configs_long, QObserverBenchmark, ) op_bench.generate_pt_tests_from_op_list( qobserver_per_channel_list, qobserver_per_channel_configs_short + qobserver_per_channel_configs_long, QObserverBenchmark, ) op_bench.generate_pt_tests_from_op_list( q_hist_observer_list, q_hist_observer_per_tensor_configs_short + q_hist_observer_per_tensor_configs_long, QObserverBenchmarkCalculateQparams, ) if __name__ == "__main__": op_bench.benchmark_runner.main()