"""Microbenchmark for Fill_ operator."""

import operator_benchmark as op_bench

import torch
from torch.testing._internal.common_device_type import get_all_device_types


# "short" configurations: three explicit sizes, each paired with the
# device/dtype values listed under ``cross_product_configs``.
# NOTE(review): the device list here is hard-coded to cpu/cuda, while the
# "long" configurations below use get_all_device_types() -- confirm that the
# difference is intentional.
fill_short_configs = op_bench.config_list(
    attr_names=["N"],
    attrs=[
        [1],
        [1024],
        [2048],
    ],
    cross_product_configs={
        "device": ["cpu", "cuda"],
        "dtype": [torch.int32],
    },
    tags=["short"],
)

# "long" configurations: two sizes crossed with every device reported by
# get_all_device_types() and a range of boolean, integer and floating-point
# dtypes.
fill_long_configs = op_bench.cross_product_configs(
    N=[10, 1000],
    device=get_all_device_types(),
    dtype=[
        torch.bool,
        torch.int8,
        torch.uint8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.half,
        torch.float,
        torch.double,
    ],
    tags=["long"],
)


class Fill_Benchmark(op_bench.TorchBenchmarkBase):
    """Benchmark ``Tensor.fill_`` on a one-dimensional tensor of ``N`` elements."""

    def init(self, N, device, dtype):
        """Create the input tensor for one benchmark configuration.

        Args:
            N: Number of elements in the 1-D input tensor.
            device: Device on which the tensor is allocated.
            dtype: ``torch.dtype`` of the input tensor.
        """
        # Allocate directly with the requested dtype instead of creating a
        # default-dtype tensor and casting it afterwards; the resulting
        # tensor is the same, without the intermediate allocation.
        self.inputs = {"input_one": torch.zeros(N, device=device, dtype=dtype)}
        self.set_module_name("fill_")

    def forward(self, input_one):
        """Fill ``input_one`` in place with the value 10 and return it."""
        return input_one.fill_(10)


op_bench.generate_pt_test(fill_short_configs + fill_long_configs, Fill_Benchmark)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()