"""Microbenchmarks for binary operators."""

import operator_benchmark as op_bench

import torch


# Benchmark ops performance with broadcast
binary_ops_bcast_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        ["add", torch.add],
    ],
)

# Configs with broadcast
# The two shapes are mutually broadcastable: [64, 1, 64] against [1, 64, 1].
binary_configs_broadcast = op_bench.config_list(
    attr_names=["in_one", "in_two"],
    attrs=[
        [[64, 1, 64], [1, 64, 1]],
    ],
    cross_product_configs={
        "device": ["cpu"],
        "dtype": [torch.float],
    },
    tags=["short"],
)


class BinaryOpBcastBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark a binary op on two inputs whose shapes differ (broadcast case)."""

    def init(self, in_one, in_two, dtype, device, op_func):
        """Create the input tensors and remember the op under test.

        Args:
            in_one: Shape of the first operand, forwarded to ``torch.randn``.
            in_two: Shape of the second operand, forwarded to ``torch.randn``.
            dtype: dtype both operands are converted to after sampling.
            device: Device the operands are created on.
            op_func: Callable invoked as ``op_func(in_one, in_two)``.
        """
        # The keys mirror the parameter names of ``forward``; presumably the
        # harness passes these tensors by name -- confirm against op_bench.
        self.inputs = {
            "in_one": torch.randn(in_one, device=device).to(dtype=dtype),
            "in_two": torch.randn(in_two, device=device).to(dtype=dtype),
        }
        self.op_func = op_func

    def forward(self, in_one, in_two):
        """Apply the configured op to the two operands and return its result."""
        return self.op_func(in_one, in_two)


op_bench.generate_pt_tests_from_op_list(
    binary_ops_bcast_list, binary_configs_broadcast, BinaryOpBcastBenchmark
)


def copy(in1, in2):
    """Return ``in1.copy_(in2)``.

    Wraps the ``copy_`` tensor method in a two-argument function so it can be
    listed next to ``torch.add`` in ``binary_ops_list``.
    """
    return in1.copy_(in2)


# Benchmark ops performance without broadcast
binary_ops_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        ["add", torch.add],
        ["copy_", copy],
    ],
)

# Small hand-picked (M, N, K) shapes with int32 operands, tagged "short".
binary_short_configs = op_bench.config_list(
    attr_names=["M", "N", "K"],
    attrs=[
        [1, 1, 1],
        [64, 64, 64],
        [64, 64, 128],
    ],
    cross_product_configs={
        "device": ["cpu", "cuda"],
        "dtype_one": [torch.int32],
        "dtype_two": [torch.int32],
    },
    tags=["short"],
)

# Larger sweep, tagged "long"; dtype_one and dtype_two vary independently, so
# mixed int8/int32 operand pairs are part of the listed values.
binary_long_configs = op_bench.cross_product_configs(
    M=[8, 128],
    N=[32, 64],
    K=[256, 512],
    device=["cpu", "cuda"],
    dtype_one=[torch.int8, torch.int32],
    dtype_two=[torch.int8, torch.int32],
    tags=["long"],
)


class BinaryOpBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark a binary op on two same-shape ``(M, N, K)`` inputs."""

    def init(self, M, N, K, device, dtype_one, dtype_two, op_func):
        """Create the input tensors and remember the op under test.

        Args:
            M: Size of the first dimension of both operands.
            N: Size of the second dimension of both operands.
            K: Size of the third dimension of both operands.
            device: Device the operands are created on.
            dtype_one: dtype the first operand is converted to.
            dtype_two: dtype the second operand is converted to.
            op_func: Callable invoked as ``op_func(input_one, input_two)``.
        """
        # NOTE(review): the operands are sampled with randn and then cast to
        # the integer dtypes listed in the configs above; presumably the
        # resulting values are irrelevant because only timing is measured --
        # confirm.
        self.inputs = {
            "input_one": torch.randn(M, N, K, device=device).to(dtype=dtype_one),
            "input_two": torch.randn(M, N, K, device=device).to(dtype=dtype_two),
        }
        self.op_func = op_func

    def forward(self, input_one, input_two):
        """Apply the configured op to the two operands and return its result."""
        return self.op_func(input_one, input_two)


op_bench.generate_pt_tests_from_op_list(
    binary_ops_list, binary_short_configs + binary_long_configs, BinaryOpBenchmark
)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()