• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import operator_benchmark as op_bench
2
3import torch
4
5
"""Microbenchmarks for channel_shuffle operator."""
7
8
# Configs for the PyTorch channel_shuffle operator
# "long" sweep: each key names a benchmark parameter and maps it to its candidate
# values; op_bench.cross_product_configs combines them (by its name, presumably
# into every combination). Keys are listed in the original keyword order.
channel_shuffle_long_configs = op_bench.cross_product_configs(
    **{
        "batch_size": [4, 8],
        "channels_per_group": [32, 64],
        "height": [32, 64],
        "width": [32, 64],
        "groups": [4, 8],
        "channel_last": [True, False],
        "tags": ["long"],
    }
)
19
20
# Hand-picked "short" configs: each row of `attrs` gives values in the order of
# `attr_names`. The `channel_last` values are presumably crossed with every row
# so both memory layouts are measured — see op_bench.config_list.
channel_shuffle_short_configs = op_bench.config_list(
    attr_names=["batch_size", "channels_per_group", "height", "width", "groups"],
    attrs=[
        # batch_size, channels_per_group, height, width, groups
        [2, 16, 16, 16, 2],
        [2, 32, 32, 32, 2],
        [4, 32, 32, 32, 4],
        [4, 64, 64, 64, 4],
        [8, 64, 64, 64, 8],
        [16, 64, 64, 64, 16],
    ],
    cross_product_configs=dict(channel_last=[True, False]),
    tags=["short"],
)
36
37
class ChannelShuffleBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark ``torch.channel_shuffle`` on a random 4-D float input.

    The input has shape ``(batch_size, channels_per_group * groups, height, width)``
    and can optionally be stored in ``torch.channels_last`` memory format, so both
    layouts can be timed.
    """

    def init(self, batch_size, channels_per_group, height, width, groups, channel_last):
        """Create the input tensor for one benchmark config.

        Args:
            batch_size: size of dimension 0 of the input.
            channels_per_group: channels in each group; the total channel count is
                ``channels_per_group * groups``, so it always divides evenly by
                ``groups``.
            height: size of dimension 2 of the input.
            width: size of dimension 3 of the input.
            groups: number of groups handed to ``torch.channel_shuffle``.
            channel_last: if true, convert the input to ``torch.channels_last``.
        """
        # Derive the channel count from the group size so the shuffle is always valid.
        channels = channels_per_group * groups
        data_shape = (batch_size, channels, height, width)
        input_data = torch.rand(data_shape)
        if channel_last:
            input_data = input_data.contiguous(memory_format=torch.channels_last)
        # Keys match forward()'s parameter names (presumably the harness calls
        # forward(**self.inputs)).
        self.inputs = {"input_data": input_data, "groups": groups}
        # Explicit label for reported results, independent of the class name.
        self.set_module_name("channel_shuffle")

    def forward(self, input_data, groups: int):
        """Run the operator under test and return the shuffled tensor."""
        return torch.channel_shuffle(input_data, groups)


# Backward-compatible alias for the original (misspelled) class name.
ChannelSHuffleBenchmark = ChannelShuffleBenchmark
51
# Hand every config (short ones first, then long ones) to op_bench together with
# the benchmark class; presumably this registers one test per config.
op_bench.generate_pt_test(
    channel_shuffle_short_configs + channel_shuffle_long_configs, ChannelSHuffleBenchmark
)


if __name__ == "__main__":
    # Only start the benchmark runner when this file is executed directly.
    op_bench.benchmark_runner.main()
60