• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1from pt import configs
2
3import operator_benchmark as op_bench
4
5import torch
6import torch.nn as nn
7
8
9"""
10Microbenchmarks for Conv1d and ConvTranspose1d operators.
11"""
12
13
14class Conv1dBenchmark(op_bench.TorchBenchmarkBase):
15    def init(self, IC, OC, kernel, stride, N, L, device):
16        self.inputs = {
17            "input": torch.rand(N, IC, L, device=device, requires_grad=self.auto_set())
18        }
19        self.conv1d = nn.Conv1d(IC, OC, kernel, stride=stride).to(device=device)
20        self.set_module_name("Conv1d")
21
22    def forward(self, input):
23        return self.conv1d(input)
24
25
26class ConvTranspose1dBenchmark(op_bench.TorchBenchmarkBase):
27    def init(self, IC, OC, kernel, stride, N, L, device):
28        self.inputs = {"input": torch.rand(N, IC, L, device=device)}
29        self.convtranspose1d = nn.ConvTranspose1d(IC, OC, kernel, stride=stride).to(
30            device=device
31        )
32        self.set_module_name("ConvTranspose1d")
33
34    def forward(self, input):
35        return self.convtranspose1d(input)
36
37
38op_bench.generate_pt_test(
39    configs.conv_1d_configs_short + configs.conv_1d_configs_long, Conv1dBenchmark
40)
41op_bench.generate_pt_test(
42    configs.convtranspose_1d_configs_short
43    + configs.conv_1d_configs_short
44    + configs.conv_1d_configs_long,
45    ConvTranspose1dBenchmark,
46)
47
48
49"""
50Microbenchmarks for Conv2d, ConvTranspose2d, and Conv2dPointwise operators.
51"""
52
53
54class Conv2dBenchmark(op_bench.TorchBenchmarkBase):
55    def init(self, IC, OC, kernel, stride, N, H, W, G, pad, device):
56        self.inputs = {"input": torch.rand(N, IC, H, W, device=device)}
57        self.conv2d = nn.Conv2d(
58            IC, OC, kernel, stride=stride, groups=G, padding=pad
59        ).to(device=device)
60        self.set_module_name("Conv2d")
61
62    def forward(self, input):
63        return self.conv2d(input)
64
65
66class ConvTranspose2dBenchmark(op_bench.TorchBenchmarkBase):
67    def init(self, IC, OC, kernel, stride, N, H, W, G, pad, device):
68        self.inputs = {"input": torch.rand(N, IC, H, W, device=device)}
69        self.convtranspose2d = nn.ConvTranspose2d(
70            IC, OC, kernel, stride=stride, groups=G, padding=pad
71        ).to(device=device)
72        self.set_module_name("ConvTranspose2d")
73
74    def forward(self, input):
75        return self.convtranspose2d(input)
76
77
78class Conv2dPointwiseBenchmark(op_bench.TorchBenchmarkBase):
79    def init(self, IC, OC, stride, N, H, W, G, pad, device):
80        self.inputs = {"input": torch.rand(N, IC, H, W, device=device)}
81        # Use 1 as kernel for pointwise convolution
82        self.conv2d = nn.Conv2d(IC, OC, 1, stride=stride, groups=G, padding=pad).to(
83            device=device
84        )
85        self.set_module_name("Conv2dPointwise")
86
87    def forward(self, input):
88        return self.conv2d(input)
89
90
91op_bench.generate_pt_test(
92    configs.conv_2d_configs_short + configs.conv_2d_configs_long, Conv2dBenchmark
93)
94op_bench.generate_pt_test(
95    configs.conv_2d_configs_short + configs.conv_2d_configs_long,
96    ConvTranspose2dBenchmark,
97)
98op_bench.generate_pt_test(
99    configs.conv_2d_pw_configs_short + configs.conv_2d_pw_configs_long,
100    Conv2dPointwiseBenchmark,
101)
102
103
104"""
105Microbenchmarks for Conv3d and ConvTranspose3d operators.
106"""
107
108
109class Conv3dBenchmark(op_bench.TorchBenchmarkBase):
110    def init(self, IC, OC, kernel, stride, N, D, H, W, device):
111        self.inputs = {"input": torch.rand(N, IC, D, H, W, device=device)}
112        self.conv3d = nn.Conv3d(IC, OC, kernel, stride=stride).to(device=device)
113        self.set_module_name("Conv3d")
114
115    def forward(self, input):
116        return self.conv3d(input)
117
118
119class ConvTranspose3dBenchmark(op_bench.TorchBenchmarkBase):
120    def init(self, IC, OC, kernel, stride, N, D, H, W, device):
121        self.inputs = {"input": torch.rand(N, IC, D, H, W, device=device)}
122        self.convtranspose3d = nn.ConvTranspose3d(IC, OC, kernel, stride=stride).to(
123            device=device
124        )
125        self.set_module_name("ConvTranspose3d")
126
127    def forward(self, input):
128        return self.convtranspose3d(input)
129
130
131op_bench.generate_pt_test(configs.conv_3d_configs_short, Conv3dBenchmark)
132op_bench.generate_pt_test(configs.conv_3d_configs_short, ConvTranspose3dBenchmark)
133
134
135if __name__ == "__main__":
136    op_bench.benchmark_runner.main()
137