• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1from pt import configs
2
3import operator_benchmark as op_bench
4
5import torch
6import torch.nn as nn
7
8
# This description follows the imports, so as a bare string literal Python
# would discard it instead of treating it as the module docstring (only the
# first statement of a module becomes __doc__). Bind it explicitly so that
# help()/pydoc still report it.
__doc__ = """Microbenchmarks for Linear operator."""
10
11
class LinearBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark an ``nn.Linear`` layer applied to a random ``(N, IN)`` tensor.

    ``init`` builds the input and the layer for one ``(N, IN, OUT, device)``
    combination (presumably supplied by the operator_benchmark harness from
    the Linear configs); ``forward`` runs the layer on that input.
    """

    def init(self, N, IN, OUT, device):
        """Create the random input and the linear layer on ``device``.

        Args:
            N: number of rows of the random input tensor.
            IN: number of input columns, used as the layer's ``in_features``.
            OUT: the layer's ``out_features``.
            device: device the input tensor and the layer are placed on.
        """
        # Keep the original order: draw the input before constructing the
        # layer, since both presumably consume the default RNG and swapping
        # them would change the generated values.
        input_one = torch.rand(N, IN, device=device)
        layer = nn.Linear(IN, OUT)
        self.inputs = {"input_one": input_one}
        self.linear = layer.to(device=device)
        self.set_module_name("linear")

    def forward(self, input_one):
        """Apply the linear layer to ``input_one`` and return the result."""
        output = self.linear(input_one)
        return output
20
21
# Register LinearBenchmark against every entry of the short Linear configs
# followed by every entry of the long Linear configs.
op_bench.generate_pt_test(
    configs.linear_configs_short + configs.linear_configs_long,
    LinearBenchmark,
)
25
26
if __name__ == "__main__":
    # Only when executed directly (not on import): hand control to the
    # operator_benchmark runner's main entry point.
    op_bench.benchmark_runner.main()
29