"""Microbenchmarks for Linear operator.

Each generated benchmark case times a forward pass of ``torch.nn.Linear`` on a
random input, using the configurations in ``configs.linear_configs_short`` and
``configs.linear_configs_long``.
"""

from pt import configs

import operator_benchmark as op_bench

import torch
import torch.nn as nn


class LinearBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark ``nn.Linear(IN, OUT)`` applied to a random ``(N, IN)`` input."""

    def init(self, N, IN, OUT, device):
        """Build the input tensor and the Linear module for one configuration.

        Args:
            N: Number of rows of the random input tensor.
            IN: Input feature size of the Linear layer; also the number of
                columns of the input tensor.
            OUT: Output feature size of the Linear layer.
            device: Device on which both the input tensor and the module's
                parameters are placed.
        """
        # NOTE(review): the dict key presumably has to match forward()'s
        # parameter name, as the framework feeds these inputs to forward().
        self.inputs = {"input_one": torch.rand(N, IN, device=device)}
        self.linear = nn.Linear(IN, OUT).to(device=device)
        # Label for this module in the benchmark framework (presumably used
        # when reporting results).
        self.set_module_name("linear")

    def forward(self, input_one):
        """Apply the Linear layer to ``input_one`` and return its output."""
        return self.linear(input_one)


# Generate benchmark tests from the combined short and long Linear configs.
op_bench.generate_pt_test(
    configs.linear_configs_short + configs.linear_configs_long, LinearBenchmark
)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()