"""
Microbenchmarks for PyTorch convolution operators.

Covers Conv1d and ConvTranspose1d; Conv2d, ConvTranspose2d and a pointwise
(kernel size 1) Conv2d; and Conv3d and ConvTranspose3d. Each benchmark runs
the module's forward pass on a ``torch.rand`` input tensor; the parameter
sweeps come from the config lists defined in ``pt.configs``.
"""

from pt import configs

import operator_benchmark as op_bench

import torch
import torch.nn as nn


# ---------------------------------------------------------------------------
# 1-D convolutions: Conv1d and ConvTranspose1d
# ---------------------------------------------------------------------------


class Conv1dBenchmark(op_bench.TorchBenchmarkBase):
    """Forward pass of ``nn.Conv1d`` on an ``(N, IC, L)`` input."""

    def init(self, IC, OC, kernel, stride, N, L, device):
        # requires_grad is delegated to the base class's auto_set(); it is
        # evaluated before the tensor is drawn, as in an inline argument.
        wants_grad = self.auto_set()
        signal = torch.rand((N, IC, L), device=device, requires_grad=wants_grad)
        self.inputs = {"input": signal}
        layer = nn.Conv1d(
            in_channels=IC, out_channels=OC, kernel_size=kernel, stride=stride
        )
        self.conv1d = layer.to(device=device)
        self.set_module_name("Conv1d")

    def forward(self, input):
        return self.conv1d(input)


class ConvTranspose1dBenchmark(op_bench.TorchBenchmarkBase):
    """Forward pass of ``nn.ConvTranspose1d`` on an ``(N, IC, L)`` input."""

    def init(self, IC, OC, kernel, stride, N, L, device):
        signal = torch.rand((N, IC, L), device=device)
        self.inputs = {"input": signal}
        layer = nn.ConvTranspose1d(
            in_channels=IC, out_channels=OC, kernel_size=kernel, stride=stride
        )
        self.convtranspose1d = layer.to(device=device)
        self.set_module_name("ConvTranspose1d")

    def forward(self, input):
        return self.convtranspose1d(input)


op_bench.generate_pt_test(
    configs.conv_1d_configs_short + configs.conv_1d_configs_long,
    Conv1dBenchmark,
)
# ConvTranspose1d sweeps its own short config list ahead of the Conv1d shapes.
op_bench.generate_pt_test(
    configs.convtranspose_1d_configs_short
    + configs.conv_1d_configs_short
    + configs.conv_1d_configs_long,
    ConvTranspose1dBenchmark,
)


# ---------------------------------------------------------------------------
# 2-D convolutions: Conv2d, ConvTranspose2d and Conv2dPointwise
# ---------------------------------------------------------------------------


class Conv2dBenchmark(op_bench.TorchBenchmarkBase):
    """Forward pass of ``nn.Conv2d`` on an ``(N, IC, H, W)`` input.

    ``G`` is the number of groups and ``pad`` the padding.
    """

    def init(self, IC, OC, kernel, stride, N, H, W, G, pad, device):
        image = torch.rand((N, IC, H, W), device=device)
        self.inputs = {"input": image}
        layer = nn.Conv2d(
            in_channels=IC,
            out_channels=OC,
            kernel_size=kernel,
            stride=stride,
            padding=pad,
            groups=G,
        )
        self.conv2d = layer.to(device=device)
        self.set_module_name("Conv2d")

    def forward(self, input):
        return self.conv2d(input)


class ConvTranspose2dBenchmark(op_bench.TorchBenchmarkBase):
    """Forward pass of ``nn.ConvTranspose2d`` on an ``(N, IC, H, W)`` input.

    ``G`` is the number of groups and ``pad`` the padding.
    """

    def init(self, IC, OC, kernel, stride, N, H, W, G, pad, device):
        image = torch.rand((N, IC, H, W), device=device)
        self.inputs = {"input": image}
        layer = nn.ConvTranspose2d(
            in_channels=IC,
            out_channels=OC,
            kernel_size=kernel,
            stride=stride,
            padding=pad,
            groups=G,
        )
        self.convtranspose2d = layer.to(device=device)
        self.set_module_name("ConvTranspose2d")

    def forward(self, input):
        return self.convtranspose2d(input)


class Conv2dPointwiseBenchmark(op_bench.TorchBenchmarkBase):
    """Forward pass of a pointwise ``nn.Conv2d`` on an ``(N, IC, H, W)`` input.

    The kernel size is fixed at 1, which is why ``init`` takes no ``kernel``.
    """

    def init(self, IC, OC, stride, N, H, W, G, pad, device):
        image = torch.rand((N, IC, H, W), device=device)
        self.inputs = {"input": image}
        layer = nn.Conv2d(
            in_channels=IC,
            out_channels=OC,
            kernel_size=1,
            stride=stride,
            padding=pad,
            groups=G,
        )
        self.conv2d = layer.to(device=device)
        self.set_module_name("Conv2dPointwise")

    def forward(self, input):
        return self.conv2d(input)


op_bench.generate_pt_test(
    configs.conv_2d_configs_short + configs.conv_2d_configs_long,
    Conv2dBenchmark,
)
op_bench.generate_pt_test(
    configs.conv_2d_configs_short + configs.conv_2d_configs_long,
    ConvTranspose2dBenchmark,
)
op_bench.generate_pt_test(
    configs.conv_2d_pw_configs_short + configs.conv_2d_pw_configs_long,
    Conv2dPointwiseBenchmark,
)


# ---------------------------------------------------------------------------
# 3-D convolutions: Conv3d and ConvTranspose3d
# ---------------------------------------------------------------------------


class Conv3dBenchmark(op_bench.TorchBenchmarkBase):
    """Forward pass of ``nn.Conv3d`` on an ``(N, IC, D, H, W)`` input."""

    def init(self, IC, OC, kernel, stride, N, D, H, W, device):
        volume = torch.rand((N, IC, D, H, W), device=device)
        self.inputs = {"input": volume}
        layer = nn.Conv3d(
            in_channels=IC, out_channels=OC, kernel_size=kernel, stride=stride
        )
        self.conv3d = layer.to(device=device)
        self.set_module_name("Conv3d")

    def forward(self, input):
        return self.conv3d(input)


class ConvTranspose3dBenchmark(op_bench.TorchBenchmarkBase):
    """Forward pass of ``nn.ConvTranspose3d`` on an ``(N, IC, D, H, W)`` input."""

    def init(self, IC, OC, kernel, stride, N, D, H, W, device):
        volume = torch.rand((N, IC, D, H, W), device=device)
        self.inputs = {"input": volume}
        layer = nn.ConvTranspose3d(
            in_channels=IC, out_channels=OC, kernel_size=kernel, stride=stride
        )
        self.convtranspose3d = layer.to(device=device)
        self.set_module_name("ConvTranspose3d")

    def forward(self, input):
        return self.convtranspose3d(input)


# Only the short 3-D config list is registered for these operators.
op_bench.generate_pt_test(configs.conv_3d_configs_short, Conv3dBenchmark)
op_bench.generate_pt_test(configs.conv_3d_configs_short, ConvTranspose3dBenchmark)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()