"""Microbenchmarks for channel_shuffle operator."""

import operator_benchmark as op_bench

import torch


# Configs for PT channel_shuffle operator.
# "long" configs: every keyword's value list is passed to
# cross_product_configs (presumably all combinations are generated).
channel_shuffle_long_configs = op_bench.cross_product_configs(
    batch_size=[4, 8],
    channels_per_group=[32, 64],
    height=[32, 64],
    width=[32, 64],
    groups=[4, 8],
    channel_last=[True, False],
    tags=["long"],
)


# "short" configs: explicit rows of
# (batch_size, channels_per_group, height, width, groups), each additionally
# paired with the values listed under cross_product_configs ("channel_last").
channel_shuffle_short_configs = op_bench.config_list(
    attr_names=["batch_size", "channels_per_group", "height", "width", "groups"],
    attrs=[
        [2, 16, 16, 16, 2],
        [2, 32, 32, 32, 2],
        [4, 32, 32, 32, 4],
        [4, 64, 64, 64, 4],
        [8, 64, 64, 64, 8],
        [16, 64, 64, 64, 16],
    ],
    cross_product_configs={
        "channel_last": [True, False],
    },
    tags=["short"],
)


class ChannelShuffleBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark ``torch.channel_shuffle`` on a random 4-D input tensor."""

    def init(self, batch_size, channels_per_group, height, width, groups, channel_last):
        """Build the input tensor and argument dict for one config.

        Args:
            batch_size: size of dimension 0 of the input.
            channels_per_group: channels in each group; the total channel
                count (dimension 1) is ``channels_per_group * groups``.
            height: size of dimension 2 of the input.
            width: size of dimension 3 of the input.
            groups: number of groups passed to ``torch.channel_shuffle``.
            channel_last: if True, the input is laid out in
                ``torch.channels_last`` memory format instead of the default
                contiguous format.
        """
        # The channel count must be an exact multiple of `groups`, so derive it
        # from channels_per_group rather than configuring it directly.
        channels = channels_per_group * groups
        data_shape = (batch_size, channels, height, width)
        input_data = torch.rand(data_shape)
        if channel_last:
            # Same values, different memory layout; this measures the op's
            # performance on NHWC-strided inputs.
            input_data = input_data.contiguous(memory_format=torch.channels_last)
        # Keys must match forward()'s parameter names; presumably the
        # framework calls forward(**self.inputs) — confirm in operator_benchmark.
        self.inputs = {"input_data": input_data, "groups": groups}
        self.set_module_name("channel_shuffle")

    def forward(self, input_data, groups: int):
        """Run the operator under test and return its output tensor."""
        return torch.channel_shuffle(input_data, groups)


# Backward-compatible alias for the original (misspelled) class name, so any
# external code that referenced it keeps working.
ChannelSHuffleBenchmark = ChannelShuffleBenchmark


op_bench.generate_pt_test(
    channel_shuffle_short_configs + channel_shuffle_long_configs,
    ChannelShuffleBenchmark,
)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()