• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
import operator_benchmark as op_bench

import torch
4
5
"""Microbenchmarks for binary operators."""
7
8
# Benchmark ops performance with broadcast.
# Each row supplies an "op_name" label and the "op_func" callable that
# BinaryOpBcastBenchmark.forward applies to its two inputs.
binary_ops_bcast_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        ["add", torch.add],
    ],
)
16
# Configs with broadcast.
# The two shapes are intentionally mismatched: (64, 1, 64) and (1, 64, 1)
# broadcast against each other to a (64, 64, 64) result.
binary_configs_broadcast = op_bench.config_list(
    attr_names=["in_one", "in_two"],
    attrs=[
        [[64, 1, 64], [1, 64, 1]],
    ],
    # Only CPU with torch.float is listed for the broadcast case.
    cross_product_configs={
        "device": ["cpu"],
        "dtype": [torch.float],
    },
    tags=["short"],
)
29
30
class BinaryOpBcastBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark case for a binary operator applied to broadcastable inputs."""

    def init(self, in_one, in_two, dtype, device, op_func):
        """Create the two random operands and remember the operator.

        Args:
            in_one: Shape (sequence of ints) of the first operand.
            in_two: Shape (sequence of ints) of the second operand.
            dtype: torch dtype both operands are converted to.
            device: Device the operands are allocated on.
            op_func: Binary callable invoked by ``forward``.
        """
        # NOTE(review): the keys match ``forward``'s parameter names;
        # presumably the harness passes ``self.inputs`` to ``forward`` by
        # name -- confirm against TorchBenchmarkBase.
        self.inputs = {
            "in_one": torch.randn(in_one, device=device).to(dtype=dtype),
            "in_two": torch.randn(in_two, device=device).to(dtype=dtype),
        }
        self.op_func = op_func

    def forward(self, in_one, in_two):
        """Return ``op_func(in_one, in_two)`` for the configured operator."""
        return self.op_func(in_one, in_two)
41
42
# Hand the broadcast operator list, its configs, and the benchmark class to
# the harness so it can generate the corresponding tests.
op_bench.generate_pt_tests_from_op_list(
    binary_ops_bcast_list, binary_configs_broadcast, BinaryOpBcastBenchmark
)
46
47
def copy(in1, in2):
    """Copy the contents of ``in2`` into ``in1`` in place.

    Wraps the ``copy_`` method as a plain two-argument function so it can be
    listed as an ``op_func`` next to functions such as ``torch.add``.

    Args:
        in1: Destination tensor; it is overwritten.
        in2: Source tensor.

    Returns:
        Whatever ``in1.copy_(in2)`` returns (the destination tensor).
    """
    return in1.copy_(in2)
50
51
# Benchmark ops performance without broadcast.
# Each row supplies an "op_name" label and the "op_func" callable that
# BinaryOpBenchmark.forward applies to its two inputs; "copy_" goes through
# the module-level ``copy`` wrapper.
binary_ops_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        ["add", torch.add],
        ["copy_", copy],
    ],
)
60
# Short-tagged configs: three explicit (M, N, K) shapes, listed for both CPU
# and CUDA, with int32 for both operands.
binary_short_configs = op_bench.config_list(
    attr_names=["M", "N", "K"],
    attrs=[
        [1, 1, 1],
        [64, 64, 64],
        [64, 64, 128],
    ],
    cross_product_configs={
        "device": ["cpu", "cuda"],
        "dtype_one": [torch.int32],
        "dtype_two": [torch.int32],
    },
    tags=["short"],
)
75
# Long-tagged configs: six axes with two values each. ``dtype_one`` and
# ``dtype_two`` are listed independently, so mixed int8/int32 operand pairs
# are among the listed combinations.
binary_long_configs = op_bench.cross_product_configs(
    M=[8, 128],
    N=[32, 64],
    K=[256, 512],
    device=["cpu", "cuda"],
    dtype_one=[torch.int8, torch.int32],
    dtype_two=[torch.int8, torch.int32],
    tags=["long"],
)
85
86
class BinaryOpBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark case for a binary operator on two same-shaped (M, N, K) inputs."""

    def init(self, M, N, K, device, dtype_one, dtype_two, op_func):
        """Create the two random operands and remember the operator.

        Args:
            M: Size of the first dimension of both operands.
            N: Size of the second dimension of both operands.
            K: Size of the third dimension of both operands.
            device: Device the operands are allocated on.
            dtype_one: torch dtype the first operand is converted to.
            dtype_two: torch dtype the second operand is converted to.
            op_func: Binary callable invoked by ``forward``.
        """
        # The operands are sampled with ``torch.randn`` and then converted,
        # so with the integer dtypes used in the configs the stored values
        # are small integers around zero.
        # NOTE(review): the keys match ``forward``'s parameter names;
        # presumably the harness passes ``self.inputs`` to ``forward`` by
        # name -- confirm against TorchBenchmarkBase.
        self.inputs = {
            "input_one": torch.randn(M, N, K, device=device).to(dtype=dtype_one),
            "input_two": torch.randn(M, N, K, device=device).to(dtype=dtype_two),
        }
        self.op_func = op_func

    def forward(self, input_one, input_two):
        """Return ``op_func(input_one, input_two)`` for the configured operator."""
        return self.op_func(input_one, input_two)
97
98
# Hand the non-broadcast operator list, the short configs followed by the
# long configs, and the benchmark class to the harness so it can generate
# the corresponding tests.
op_bench.generate_pt_tests_from_op_list(
    binary_ops_list, binary_short_configs + binary_long_configs, BinaryOpBenchmark
)
102
103
# Run the benchmark runner only when this file is executed as a script.
if __name__ == "__main__":
    op_bench.benchmark_runner.main()
106