"""Microbenchmarks for the ``torch.nn.functional.interpolate`` operator.

Covers nearest, linear-family (linear / bilinear / trilinear, picked from the
input rank) and bicubic resampling on 3D, 4D and 5D inputs, with contiguous
and channels-last memory formats for 4D/5D inputs. Float inputs are used
throughout; an extra short config group benchmarks uint8 inputs in nearest
mode.
"""

import operator_benchmark as op_bench

import torch


class InterpolateBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark ``F.interpolate`` resizing an input tensor to ``output_size``."""

    def init(
        self,
        input_size,
        output_size,
        channels_last=False,
        mode="linear",
        dtype=torch.float,
    ):
        """Build the input tensor and interpolate arguments for one config.

        Args:
            input_size: Full shape of the input tensor (rank 3, 4 or 5 in the
                configs below).
            output_size: Target size forwarded as ``size=`` to interpolate.
            channels_last: If True, store the input in channels-last memory
                format (``channels_last`` for 4D, ``channels_last_3d`` for 5D).
            mode: Interpolation mode. ``"linear"`` is resolved to
                ``linear`` / ``bilinear`` / ``trilinear`` from the input rank;
                any other value is passed through unchanged.
            dtype: Element dtype of the input tensor.

        Raises:
            ValueError: If ``channels_last`` is requested for an input whose
                rank is not 4 or 5, or if ``mode="linear"`` is used with an
                input whose rank is not 3, 4 or 5.
        """
        # Values in [0, 256) fit the uint8 range, so the same generator serves
        # both the uint8 and the floating-point configs.
        input_image = torch.randint(
            0,
            256,
            size=input_size,
            dtype=dtype,
            device="cpu",
            requires_grad=self.auto_set(),
        )
        if channels_last:
            if input_image.ndim == 4:
                input_image = input_image.contiguous(memory_format=torch.channels_last)
            elif input_image.ndim == 5:
                input_image = input_image.contiguous(
                    memory_format=torch.channels_last_3d
                )
            else:
                raise ValueError(
                    f"Can not set channels_last to the input of {input_image.ndim} dims"
                )

        # interpolate rejects an explicit align_corners for nearest mode, so it
        # must stay None there; every other mode gets an explicit False.
        align_corners = None if mode == "nearest" else False

        if mode == "linear":
            # "linear" is a rank-agnostic placeholder in the configs; map it to
            # the concrete mode name interpolate expects for this rank.
            linear_modes = {
                3: "linear",
                4: "bilinear",
                5: "trilinear",
            }
            if input_image.ndim not in linear_modes:
                raise ValueError(
                    f"Unsupported input rank {input_image.ndim} for linear "
                    "interpolation; expected 3, 4 or 5 dims"
                )
            mode = linear_modes[input_image.ndim]

        self.inputs = {
            "input_image": input_image,
            "output_size": output_size,
            "mode": mode,
            "align_corners": align_corners,
        }

        self.set_module_name("interpolate")

    def forward(self, input_image, output_size, mode, align_corners):
        """Run one interpolate call with the arguments prepared in ``init``."""
        return torch.nn.functional.interpolate(
            input_image, size=output_size, mode=mode, align_corners=align_corners
        )


# 4D (input_size, output_size) pairs shared by both "short" config groups.
_SHORT_ATTRS = [
    [(1, 3, 60, 40), (24, 24)],
    [(1, 3, 600, 400), (240, 240)],
    [(1, 3, 320, 320), (256, 256)],
    [(1, 1, 60, 40), (24, 24)],
    [(1, 1, 600, 400), (240, 240)],
    [(1, 1, 320, 320), (256, 256)],
]

# Default-dtype (float) inputs across all 4D modes and memory formats.
config_short = op_bench.config_list(
    attr_names=["input_size", "output_size"],
    attrs=_SHORT_ATTRS,
    cross_product_configs={
        "channels_last": [True, False],
        "mode": ["nearest", "linear", "bicubic"],
    },
    tags=["short"],
)

# Same shapes with uint8 inputs; only nearest mode is benchmarked for uint8.
config_short += op_bench.config_list(
    attr_names=["input_size", "output_size"],
    attrs=_SHORT_ATTRS,
    cross_product_configs={
        "channels_last": [True, False],
        "mode": [
            "nearest",
        ],
        "dtype": [
            torch.uint8,
        ],
    },
    tags=["short"],
)


config_long = op_bench.config_list(
    attr_names=["input_size", "output_size"],
    attrs=[
        [(1, 3, 320, 320), (512, 512)],
        [(1, 3, 500, 500), (256, 256)],
        [(1, 3, 500, 500), (800, 800)],
        [(1, 1, 320, 320), (512, 512)],
        [(1, 1, 500, 500), (256, 256)],
        [(1, 1, 500, 500), (800, 800)],
        # vectorization test-case
        [(2, 128, 64, 46), (128, 128)],
        [(2, 128, 64, 46), (32, 24)],
    ],
    cross_product_configs={
        "channels_last": [True, False],
        "mode": ["nearest", "linear", "bicubic"],
    },
    tags=["long"],
)


config_3d = op_bench.config_list(
    # no channels_last for 3D tensors
    attr_names=["input_size", "output_size"],
    attrs=[
        [(4, 512, 320), (256,)],
        [(4, 512, 320), (512,)],
    ],
    cross_product_configs={
        "mode": ["nearest", "linear"],
    },
    tags=["long"],
)


config_5d = op_bench.config_list(
    attr_names=["input_size", "output_size"],
    attrs=[
        [(1, 3, 16, 320, 320), (8, 256, 256)],
        [(1, 3, 16, 320, 320), (32, 512, 512)],
        # vectorization test-case
        [(1, 16, 32, 64, 64), (16, 32, 32)],
        [(1, 16, 32, 64, 64), (64, 128, 128)],
    ],
    cross_product_configs={
        "channels_last": [True, False],
        "mode": ["nearest", "linear"],
    },
    tags=["long"],
)


for config in (config_short, config_long, config_3d, config_5d):
    op_bench.generate_pt_test(config, InterpolateBenchmark)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()