• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import operator_benchmark as op_bench
2
3import torch
4import torch.nn as nn
5
6
"""
Microbenchmarks for MaxPool1d and AvgPool1d operators.
"""

# pool-1d configurations.
# The "short" set is one hand-picked (kernel, stride, N, C, L) case, expanded
# over both devices.
pool_1d_configs_short = op_bench.config_list(
    attr_names=["kernel", "stride", "N", "C", "L"],
    attrs=[[3, 1, 8, 256, 256]],
    cross_product_configs=dict(device=["cpu", "cuda"]),
    tags=["short"],
)

# The "long" set is generated by cross_product_configs from the value lists
# below (keyword order kept as-is).
pool_1d_configs_long = op_bench.cross_product_configs(
    kernel=[3],
    stride=[1, 2],
    N=[8, 16],
    C=[3],
    L=[128, 256],
    device=["cpu", "cuda"],
    tags=["long"],
)

# Each op entry is [display name, module class]; the class's own __name__
# ("MaxPool1d", "AvgPool1d") serves as the display name.
pool_1d_ops_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        [pool_cls.__name__, pool_cls]
        for pool_cls in (nn.MaxPool1d, nn.AvgPool1d)
    ],
)
40
41
class Pool1dBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark one 1-d pooling module applied to a random ``(N, C, L)`` tensor."""

    def init(self, kernel, stride, N, C, L, device, op_func):
        """Prepare a single benchmark case.

        Args:
            kernel: Kernel size passed positionally to ``op_func``.
            stride: Stride passed to ``op_func`` as the ``stride`` keyword.
            N, C, L: Dimensions of the random input tensor.
            device: Device the input tensor is allocated on.
            op_func: Pooling module class, called as
                ``op_func(kernel, stride=stride)``.
        """
        sample = torch.rand(N, C, L, device=device)
        # NOTE(review): the dict key "input" matches forward()'s parameter
        # name; presumably the framework feeds self.inputs into forward, so
        # keep the two in sync.
        self.inputs = {"input": sample}
        self.op_func = op_func(kernel, stride=stride)

    def forward(self, input):
        """Apply the configured pooling module to ``input`` and return the result."""
        pooled = self.op_func(input)
        return pooled


# Generate the benchmark cases from the 1-d op list, the combined short and
# long configs, and the benchmark class above.
op_bench.generate_pt_tests_from_op_list(
    pool_1d_ops_list,
    pool_1d_configs_short + pool_1d_configs_long,
    Pool1dBenchmark,
)
54
55
"""
Microbenchmarks for 2-d pooling operators: MaxPool2d, AvgPool2d,
AdaptiveMaxPool2d and FractionalMaxPool2d.
"""


# Configs for pool-2d ops. ``kernel`` and ``stride`` are given as
# [height, width] pairs.
pool_2d_configs_short = op_bench.config_list(
    attr_names=["kernel", "stride", "N", "C", "H", "W"],
    attrs=[
        [[3, 1], [2, 1], 1, 16, 32, 32],
    ],
    cross_product_configs={
        "device": ["cpu", "cuda"],
    },
    tags=["short"],
)

pool_2d_configs_long = op_bench.cross_product_configs(
    kernel=[[3, 2], [3, 3]],
    stride=[[2, 2]],
    N=[8, 16],
    C=[32],
    H=[32, 64],
    W=[32, 64],
    device=["cpu", "cuda"],
    tags=["long"],
)

# Every op_func is called as op_func(kernel, stride=stride). The wrappers
# below adapt modules that do not take that signature.
pool_2d_ops_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        ["MaxPool2d", nn.MaxPool2d],
        ["AvgPool2d", nn.AvgPool2d],
        # Adaptive pooling has no kernel/stride: the config's ``kernel`` value
        # is reused as the target output size and ``stride`` is ignored.
        ["AdaptiveMaxPool2d", lambda kernel, stride: nn.AdaptiveMaxPool2d(kernel)],
        # ``stride`` is ignored; the output size is fixed at output_size=2
        # regardless of the input's H and W.
        [
            "FractionalMaxPool2d",
            lambda kernel, stride: nn.FractionalMaxPool2d(kernel, output_size=2),
        ],
    ],
)
96
97
class Pool2dBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark one 2-d pooling module applied to a random ``(N, C, H, W)`` tensor."""

    def init(self, kernel, stride, N, C, H, W, device, op_func):
        """Prepare a single benchmark case.

        Args:
            kernel: Kernel argument passed positionally to ``op_func``.
            stride: Stride passed to ``op_func`` as the ``stride`` keyword.
            N, C, H, W: Dimensions of the random input tensor.
            device: Device the input tensor is allocated on.
            op_func: Callable (module class or wrapper lambda) invoked as
                ``op_func(kernel, stride=stride)`` to build the module.
        """
        sample = torch.rand(N, C, H, W, device=device)
        # NOTE(review): the dict key "input" matches forward()'s parameter
        # name; presumably the framework feeds self.inputs into forward, so
        # keep the two in sync.
        self.inputs = {"input": sample}
        self.op_func = op_func(kernel, stride=stride)

    def forward(self, input):
        """Apply the configured pooling module to ``input`` and return the result."""
        pooled = self.op_func(input)
        return pooled


# Generate the benchmark cases from the 2-d op list, the combined short and
# long configs, and the benchmark class above.
op_bench.generate_pt_tests_from_op_list(
    pool_2d_ops_list,
    pool_2d_configs_short + pool_2d_configs_long,
    Pool2dBenchmark,
)
110
111
"""
Microbenchmarks for 3-d pooling operators: MaxPool3d, AvgPool3d,
AdaptiveMaxPool3d and FractionalMaxPool3d.
"""


# Configs for pool-3d ops. ``kernel`` and ``stride`` are given as
# [depth, height, width] triples.
pool_3d_configs_short = op_bench.config_list(
    attr_names=["kernel", "stride", "N", "C", "D", "H", "W"],
    attrs=[
        [[3, 1, 3], [2, 1, 2], 1, 16, 16, 32, 32],
    ],
    cross_product_configs={
        "device": ["cpu", "cuda"],
    },
    tags=["short"],
)

# NOTE: the largest long-config input is 16 x 32 x 32 x 64 x 64
# = 67,108,864 elements (about 256 MiB, assuming torch.rand's default
# float32 dtype), so these cases need substantial host/device memory.
pool_3d_configs_long = op_bench.cross_product_configs(
    kernel=[[3, 2, 3], [3, 3, 3]],
    stride=[[2, 2, 2]],
    N=[8, 16],
    C=[32],
    D=[32],
    H=[32, 64],
    W=[32, 64],
    device=["cpu", "cuda"],
    tags=["long"],
)


# Every op_func is called as op_func(kernel, stride=stride). The wrappers
# below adapt modules that do not take that signature.
pool_3d_ops_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        ["MaxPool3d", nn.MaxPool3d],
        ["AvgPool3d", nn.AvgPool3d],
        # Adaptive pooling has no kernel/stride: the config's ``kernel`` value
        # is reused as the target output size and ``stride`` is ignored.
        ["AdaptiveMaxPool3d", lambda kernel, stride: nn.AdaptiveMaxPool3d(kernel)],
        # ``stride`` is ignored; the output size is fixed at output_size=2
        # regardless of the input's D, H and W.
        [
            "FractionalMaxPool3d",
            lambda kernel, stride: nn.FractionalMaxPool3d(kernel, output_size=2),
        ],
    ],
)
154
155
class Pool3dBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark one 3-d pooling module applied to a random ``(N, C, D, H, W)`` tensor."""

    def init(self, kernel, stride, N, C, D, H, W, device, op_func):
        """Prepare a single benchmark case.

        Args:
            kernel: Kernel argument passed positionally to ``op_func``.
            stride: Stride passed to ``op_func`` as the ``stride`` keyword.
            N, C, D, H, W: Dimensions of the random input tensor.
            device: Device the input tensor is allocated on.
            op_func: Callable (module class or wrapper lambda) invoked as
                ``op_func(kernel, stride=stride)`` to build the module.
        """
        sample = torch.rand(N, C, D, H, W, device=device)
        # NOTE(review): the dict key "input" matches forward()'s parameter
        # name; presumably the framework feeds self.inputs into forward, so
        # keep the two in sync.
        self.inputs = {"input": sample}
        self.op_func = op_func(kernel, stride=stride)

    def forward(self, input):
        """Apply the configured pooling module to ``input`` and return the result."""
        pooled = self.op_func(input)
        return pooled


# Generate the benchmark cases from the 3-d op list, the combined short and
# long configs, and the benchmark class above.
op_bench.generate_pt_tests_from_op_list(
    pool_3d_ops_list,
    pool_3d_configs_short + pool_3d_configs_long,
    Pool3dBenchmark,
)
168
169
if __name__ == "__main__":
    # When the file is executed directly (not imported), hand control to the
    # operator_benchmark runner's entry point.
    run_benchmarks = op_bench.benchmark_runner.main
    run_benchmarks()
172