• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import operator_benchmark as op_bench
2
3import torch
4
5
6"""Microbenchmarks for interpolate operator."""
7
8
class InterpolateBenchmark(op_bench.TorchBenchmarkBase):
    """Forward-pass benchmark for ``torch.nn.functional.interpolate``.

    Each generated case resizes a random integer-valued tensor to
    ``output_size`` using the requested ``mode`` and memory format.
    """

    def init(
        self,
        input_size,
        output_size,
        channels_last=False,
        mode="linear",
        dtype=torch.float,
    ):
        """Build the inputs consumed by :meth:`forward`.

        Args:
            input_size: shape of the input tensor created with ``torch.randint``.
                Its length selects the interpolation rank.
            output_size: target size forwarded to ``F.interpolate`` as ``size=``.
            channels_last: if true, convert a 4-D input to ``torch.channels_last``
                or a 5-D input to ``torch.channels_last_3d``.
            mode: interpolation mode. The generic ``"linear"`` is translated to
                ``"linear"``/``"bilinear"``/``"trilinear"`` for 3-/4-/5-D inputs.
            dtype: dtype of the input tensor (values are integers in [0, 256)).

        Raises:
            ValueError: if ``channels_last`` is requested for an input that is
                not 4-D or 5-D, or if ``mode == "linear"`` and the input is not
                3-D, 4-D or 5-D.
        """
        # NOTE(review): auto_set() is provided by TorchBenchmarkBase and
        # presumably enables requires_grad for backward runs. If it ever
        # returned True for an integer dtype (e.g. the uint8 configs), randint
        # would reject requires_grad.
        input_image = torch.randint(
            0,
            256,
            size=input_size,
            dtype=dtype,
            device="cpu",
            requires_grad=self.auto_set(),
        )
        if channels_last:
            if input_image.ndim == 4:
                input_image = input_image.contiguous(memory_format=torch.channels_last)
            elif input_image.ndim == 5:
                input_image = input_image.contiguous(
                    memory_format=torch.channels_last_3d
                )
            else:
                raise ValueError(
                    f"Can not set channels_last to the input of {input_image.ndim} dims"
                )

        # F.interpolate raises if align_corners is set (even to False) for the
        # non-interpolating modes. All other modes are benchmarked with
        # align_corners=False.
        align_corners = (
            None if mode in ("nearest", "nearest-exact", "area") else False
        )

        if mode == "linear":
            # F.interpolate needs the rank-specific name of the linear mode.
            linear_mode_by_ndim = {
                3: "linear",
                4: "bilinear",
                5: "trilinear",
            }
            if input_image.ndim not in linear_mode_by_ndim:
                raise ValueError(
                    f"mode='linear' requires a 3-D, 4-D or 5-D input, "
                    f"got {input_image.ndim} dims"
                )
            mode = linear_mode_by_ndim[input_image.ndim]

        self.inputs = {
            "input_image": input_image,
            "output_size": output_size,
            "mode": mode,
            "align_corners": align_corners,
        }

        self.set_module_name("interpolate")

    def forward(self, input_image, output_size, mode, align_corners):
        """Resize ``input_image`` with ``F.interpolate`` and return the result."""
        return torch.nn.functional.interpolate(
            input_image, size=output_size, mode=mode, align_corners=align_corners
        )
60
61
# 4-D input/output sizes shared by both parts of the "short" suite:
# 3-channel and 1-channel inputs, each resized to a smaller square output.
_SHORT_SIZES = [
    [(1, 3, 60, 40), (24, 24)],
    [(1, 3, 600, 400), (240, 240)],
    [(1, 3, 320, 320), (256, 256)],
    [(1, 1, 60, 40), (24, 24)],
    [(1, 1, 600, 400), (240, 240)],
    [(1, 1, 320, 320), (256, 256)],
]

# Default-dtype inputs across nearest / linear / bicubic, in both memory formats.
config_short = op_bench.config_list(
    tags=["short"],
    attr_names=["input_size", "output_size"],
    attrs=_SHORT_SIZES,
    cross_product_configs={
        "channels_last": [True, False],
        "mode": ["nearest", "linear", "bicubic"],
    },
)

# uint8 inputs, benchmarked with mode="nearest" only.
config_short += op_bench.config_list(
    tags=["short"],
    attr_names=["input_size", "output_size"],
    attrs=_SHORT_SIZES,
    cross_product_configs={
        "channels_last": [True, False],
        "mode": ["nearest"],
        "dtype": [torch.uint8],
    },
)
100
101
# "long" suite of 4-D cases: larger up- and down-sampling resizes, plus
# 128-channel inputs labelled as vectorization test cases.
config_long = op_bench.config_list(
    tags=["long"],
    attr_names=["input_size", "output_size"],
    attrs=[
        # 3-channel inputs
        [(1, 3, 320, 320), (512, 512)],
        [(1, 3, 500, 500), (256, 256)],
        [(1, 3, 500, 500), (800, 800)],
        # 1-channel inputs
        [(1, 1, 320, 320), (512, 512)],
        [(1, 1, 500, 500), (256, 256)],
        [(1, 1, 500, 500), (800, 800)],
        # vectorization test-case
        [(2, 128, 64, 46), (128, 128)],
        [(2, 128, 64, 46), (32, 24)],
    ],
    cross_product_configs={
        "channels_last": [True, False],
        "mode": ["nearest", "linear", "bicubic"],
    },
)
121
122
# 3-D inputs with a single resized dimension (output_size is a 1-tuple).
# channels_last is not swept here: InterpolateBenchmark.init only supports
# it for 4-D and 5-D inputs, so it stays at its default (False).
config_3d = op_bench.config_list(
    tags=["long"],
    attr_names=["input_size", "output_size"],
    attrs=[
        [(4, 512, 320), (256,)],
        [(4, 512, 320), (512,)],
    ],
    cross_product_configs={"mode": ["nearest", "linear"]},
)
135
136
# 5-D inputs with three resized dimensions, in both memory formats
# (channels_last maps to torch.channels_last_3d for 5-D inputs).
config_5d = op_bench.config_list(
    tags=["long"],
    attr_names=["input_size", "output_size"],
    attrs=[
        [(1, 3, 16, 320, 320), (8, 256, 256)],
        [(1, 3, 16, 320, 320), (32, 512, 512)],
        # vectorization test-case
        [(1, 16, 32, 64, 64), (16, 32, 32)],
        [(1, 16, 32, 64, 64), (64, 128, 128)],
    ],
    cross_product_configs={
        "channels_last": [True, False],
        "mode": ["nearest", "linear"],
    },
)
152
153
# Every config group above is registered against the same benchmark class.
_ALL_INTERPOLATE_CONFIGS = (config_short, config_long, config_3d, config_5d)
for config in _ALL_INTERPOLATE_CONFIGS:
    op_bench.generate_pt_test(config, InterpolateBenchmark)


if __name__ == "__main__":
    # Running this file directly hands control to op_bench's benchmark runner.
    runner = op_bench.benchmark_runner
    runner.main()
160