"""Microbenchmarks for binary operators."""

import operator_benchmark as op_bench
import torch


# ---------------------------------------------------------------------------
# Binary ops with broadcasting inputs
# ---------------------------------------------------------------------------

# Each entry pairs the name used for the benchmark with the callable to time.
binary_ops_bcast_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        ["add", torch.add],
    ],
)

# Input shapes for the broadcast case: [64, 1, 64] and [1, 64, 1] broadcast
# against each other to a [64, 64, 64] result.
binary_configs_broadcast = op_bench.config_list(
    attr_names=["in_one", "in_two"],
    attrs=[
        [[64, 1, 64], [1, 64, 1]],
    ],
    cross_product_configs={
        "device": ["cpu"],
        "dtype": [torch.float],
    },
    tags=["short"],
)


class BinaryOpBcastBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark a binary operator on two inputs of different, broadcastable shapes."""

    def init(self, in_one, in_two, dtype, device, op_func):
        """Create the random input tensors and record the operator under test.

        Args:
            in_one: Shape of the first operand, passed to ``torch.randn``.
            in_two: Shape of the second operand, passed to ``torch.randn``.
            dtype: Dtype both operands are cast to after sampling.
            device: Device the operands are created on.
            op_func: Callable invoked as ``op_func(in_one, in_two)`` by
                ``forward``.
        """
        # The keys of ``self.inputs`` mirror the parameter names of ``forward``.
        # NOTE(review): presumably the op_bench harness feeds these entries to
        # ``forward`` by name - confirm against TorchBenchmarkBase.
        self.inputs = {
            "in_one": torch.randn(in_one, device=device).to(dtype=dtype),
            "in_two": torch.randn(in_two, device=device).to(dtype=dtype),
        }
        self.op_func = op_func

    def forward(self, in_one, in_two):
        """Apply the operator under test to the two operands and return its result."""
        return self.op_func(in_one, in_two)


op_bench.generate_pt_tests_from_op_list(
    binary_ops_bcast_list, binary_configs_broadcast, BinaryOpBcastBenchmark
)


# ---------------------------------------------------------------------------
# Binary ops without broadcasting (both operands share one shape)
# ---------------------------------------------------------------------------


def copy(in1, in2):
    """Copy ``in2`` into ``in1`` in place via ``Tensor.copy_`` and return the result.

    Wraps the in-place method so it can be listed as a two-argument ``op_func``
    next to functions such as ``torch.add``. Note that ``in1`` is mutated.
    """
    return in1.copy_(in2)


binary_ops_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        ["add", torch.add],
        ["copy_", copy],
    ],
)

# Short run: a few explicit (M, N, K) shapes, int32 operands on both sides.
binary_short_configs = op_bench.config_list(
    attr_names=["M", "N", "K"],
    attrs=[
        [1, 1, 1],
        [64, 64, 64],
        [64, 64, 128],
    ],
    cross_product_configs={
        "device": ["cpu", "cuda"],
        "dtype_one": [torch.int32],
        "dtype_two": [torch.int32],
    },
    tags=["short"],
)

# Long run: shapes, devices and operand dtypes are listed independently, so
# ``dtype_one`` and ``dtype_two`` may differ within a single configuration.
binary_long_configs = op_bench.cross_product_configs(
    M=[8, 128],
    N=[32, 64],
    K=[256, 512],
    device=["cpu", "cuda"],
    dtype_one=[torch.int8, torch.int32],
    dtype_two=[torch.int8, torch.int32],
    tags=["long"],
)


class BinaryOpBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark a binary operator on two ``(M, N, K)`` tensors of the same shape."""

    def init(self, M, N, K, device, dtype_one, dtype_two, op_func):
        """Create the random input tensors and record the operator under test.

        Args:
            M: Size of the first dimension of both operands.
            N: Size of the second dimension of both operands.
            K: Size of the third dimension of both operands.
            device: Device the operands are created on.
            dtype_one: Dtype the first operand is cast to after sampling.
            dtype_two: Dtype the second operand is cast to after sampling.
            op_func: Callable invoked as ``op_func(input_one, input_two)`` by
                ``forward``.
        """
        # Values are sampled with ``torch.randn`` first and only then cast, so
        # the requested (possibly integer) dtype is applied by ``.to``.
        # The keys of ``self.inputs`` mirror the parameter names of ``forward``.
        self.inputs = {
            "input_one": torch.randn(M, N, K, device=device).to(dtype=dtype_one),
            "input_two": torch.randn(M, N, K, device=device).to(dtype=dtype_two),
        }
        self.op_func = op_func

    def forward(self, input_one, input_two):
        """Apply the operator under test to the two operands and return its result."""
        return self.op_func(input_one, input_two)

# Register every non-broadcast op against both the short and the long
# configuration sets (the two config collections are concatenated with ``+``).
op_bench.generate_pt_tests_from_op_list(
    binary_ops_list,
    binary_short_configs + binary_long_configs,
    BinaryOpBenchmark,
)


# Hand control to the benchmark runner only when executed as a script.
if __name__ == "__main__":
    op_bench.benchmark_runner.main()