# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, List, Tuple

import numpy as np
import PIL
import pytest
import torch

# Import these first. Otherwise, the custom ops are not registered.
from executorch.extension.pybindings import portable_lib  # noqa # usort: skip
from executorch.extension.llm.custom_ops import op_tile_crop_aot  # noqa # usort: skip

from executorch.examples.models.llama3_2_vision.preprocess.model import (
    CLIPImageTransformModel,
    PreprocessConfig,
)

from executorch.exir import EdgeCompileConfig, to_edge

from executorch.extension.pybindings.portable_lib import (
    _load_for_executorch_from_buffer,
)

from PIL import Image
from torch._inductor.package import package_aoti

from torchtune.models.clip.inference._transform import CLIPImageTransform

from torchtune.modules.transforms.vision_utils.get_canvas_best_fit import (
    find_supported_resolutions,
    get_canvas_best_fit,
)

from torchtune.modules.transforms.vision_utils.get_inscribed_size import (
    get_inscribed_size,
)

from torchvision.transforms.v2 import functional as F


def initialize_models(resize_to_max_canvas: bool) -> Dict[str, Any]:
    config = PreprocessConfig(resize_to_max_canvas=resize_to_max_canvas)

    reference_model = CLIPImageTransform(
        image_mean=config.image_mean,
        image_std=config.image_std,
        resample=config.resample,
        antialias=config.antialias,
        tile_size=config.tile_size,
        max_num_tiles=config.max_num_tiles,
        resize_to_max_canvas=config.resize_to_max_canvas,
        possible_resolutions=None,
    )

    # Eager model.
    model = CLIPImageTransformModel(config)

    # Exported model.
    exported_model = torch.export.export(
        model.get_eager_model(),
        model.get_example_inputs(),
        dynamic_shapes=model.get_dynamic_shapes(),
        strict=False,
    )

    # AOTInductor model.
    so = torch._export.aot_compile(
        exported_model.module(),
        args=model.get_example_inputs(),
        options={"aot_inductor.package": True},
        dynamic_shapes=model.get_dynamic_shapes(),
    )
    aoti_path = "preprocess.pt2"
    package_aoti(aoti_path, so)
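    # The .pt2 package written above is reloaded in the tests via
    # torch._inductor.aoti_load_package(aoti_path); see run_preprocess below.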

    edge_program = to_edge(
        exported_model, compile_config=EdgeCompileConfig(_check_ir_validity=False)
    )
    executorch_model = edge_program.to_executorch()
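    # The serialized program is loaded in the tests via
    # _load_for_executorch_from_buffer(executorch_model.buffer); see run_preprocess below.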

    # Re-export, as the ExecuTorch lowering above mutates the ExportedProgram.
    exported_model = torch.export.export(
        model.get_eager_model(),
        model.get_example_inputs(),
        dynamic_shapes=model.get_dynamic_shapes(),
        strict=False,
    )

    return {
        "config": config,
        "reference_model": reference_model,
        "model": model,
        "exported_model": exported_model,
        "aoti_path": aoti_path,
        "executorch_model": executorch_model,
    }
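
# Example usage (mirroring the class-scope initialization below):
#   models = initialize_models(resize_to_max_canvas=False)
#   executorch_model = models["executorch_model"]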


# From https://github.com/pytorch/torchtune/blob/main/tests/test_utils.py#L231
def assert_expected(
    actual: Any,
    expected: Any,
    rtol: float = 1e-5,
    atol: float = 1e-8,
    check_device: bool = True,
):
    torch.testing.assert_close(
        actual,
        expected,
        rtol=rtol,
        atol=atol,
        check_device=check_device,
        msg=f"actual: {actual}, expected: {expected}",
    )


class TestImageTransform:
    """
    This test checks that the exported image transform model produces the
    same output as the reference model.

    Reference model: CLIPImageTransform
        https://github.com/pytorch/torchtune/blob/main/torchtune/models/clip/inference/_transforms.py#L115
    Eager and exported models: _CLIPImageTransform
        https://github.com/pytorch/torchtune/blob/main/torchtune/models/clip/inference/_transforms.py#L26
    """

    models_no_resize = initialize_models(resize_to_max_canvas=False)
    models_resize = initialize_models(resize_to_max_canvas=True)

    @pytest.fixture(autouse=True)
    def setup_function(self):
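        # Seed NumPy so the randomly generated test images are deterministic
        # and the hard-coded expected tile statistics below remain valid.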
        np.random.seed(0)

    def prepare_inputs(
        self, image: Image.Image, config: PreprocessConfig
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Prepare inputs for the eager and exported models:
        - Convert the PIL image to a tensor.
        - Calculate the best resolution: a canvas with height and width divisible by tile_size.
        - Calculate the inscribed size: the size of the image inscribed within best_resolution,
            without distortion.

        The reference model performs these calculations inside __init__ and __call__:
        https://github.com/pytorch/torchtune/blob/main/torchtune/models/clip/inference/_transforms.py#L115
        """
        image_tensor = F.to_dtype(
            F.grayscale_to_rgb_image(F.to_image(image)), scale=True
        )

        # The above produces a torchvision tv_tensor. Adding zero converts it
        # into a plain torch.Tensor, since arithmetic ops on tv_tensors return
        # plain tensors.
        image_tensor = image_tensor + 0

        # Ensure the tensor is contiguous for ExecuTorch.
        image_tensor = image_tensor.contiguous()

        # Calculate possible resolutions.
        possible_resolutions = config.possible_resolutions
        if possible_resolutions is None:
            possible_resolutions = find_supported_resolutions(
                max_num_tiles=config.max_num_tiles, tile_size=config.tile_size
            )
        possible_resolutions = torch.tensor(possible_resolutions).reshape(-1, 2)
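        # Illustrative, assuming the defaults implied by these tests
        # (tile_size=224, max_num_tiles=4): find_supported_resolutions yields the
        # (height, width) canvases for every tile grid of at most four tiles,
        # e.g. (224, 224), (224, 448), (448, 448), (224, 896), (896, 224), ...
        # possible_resolutions is then an (N, 2) tensor of those canvases.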

        # Limit resizing: when not resizing to the max canvas, cap the longest
        # resized side at tile_size.
        max_size = None if config.resize_to_max_canvas else config.tile_size

        # Find the best canvas to fit the image without distortion.
        best_resolution = get_canvas_best_fit(
            image=image_tensor,
            possible_resolutions=possible_resolutions,
            resize_to_max_canvas=config.resize_to_max_canvas,
        )
        best_resolution = torch.tensor(best_resolution)

        # Find the dimensions of the image such that it is inscribed within
        # best_resolution without distortion.
        inscribed_size = get_inscribed_size(
            image_tensor.shape[-2:], best_resolution, max_size
        )
        inscribed_size = torch.tensor(inscribed_size)

        return image_tensor, inscribed_size, best_resolution

    def run_preprocess(
        self,
        image_size: Tuple[int, int, int],
        expected_shape: torch.Size,
        resize_to_max_canvas: bool,
        expected_tile_means: List[float],
        expected_tile_max: List[float],
        expected_tile_min: List[float],
        expected_ar: List[int],
    ) -> None:
        models = self.models_resize if resize_to_max_canvas else self.models_no_resize
        # Prepare image input.
        image = (
            np.random.randint(0, 256, np.prod(image_size))
            .reshape(image_size)
            .astype(np.uint8)
        )
        image = PIL.Image.fromarray(image)

        # Run reference model.
        reference_model = models["reference_model"]
        reference_output = reference_model(image=image)
        reference_image = reference_output["image"]
        reference_ar = reference_output["aspect_ratio"].tolist()

        # Check output shape and aspect ratio match expected values.
        assert (
            reference_image.shape == expected_shape
        ), f"Expected shape {expected_shape} but got {reference_image.shape}"

        assert (
            reference_ar == expected_ar
        ), f"Expected ar {expected_ar} but got {reference_ar}"

        # Check pixel values are within the expected range [0, 1].
        assert (
            0 <= reference_image.min() <= reference_image.max() <= 1
        ), f"Expected pixel values in range [0, 1] but got {reference_image.min()} to {reference_image.max()}"

        # Check the mean, max, and min values of the tiles match expected values.
        for i, tile in enumerate(reference_image):
            assert_expected(
                tile.mean().item(), expected_tile_means[i], rtol=0, atol=1e-4
            )
            assert_expected(tile.max().item(), expected_tile_max[i], rtol=0, atol=1e-4)
            assert_expected(tile.min().item(), expected_tile_min[i], rtol=0, atol=1e-4)

        # Check the number of tiles matches the product of the aspect ratio.
        expected_num_tiles = reference_ar[0] * reference_ar[1]
        assert (
            expected_num_tiles == reference_image.shape[0]
        ), f"Expected {expected_num_tiles} tiles but got {reference_image.shape[0]}"

        # Pre-work for the eager and exported models. The reference model performs
        # these calculations and passes the results to _CLIPImageTransform, the
        # exportable model.
        image_tensor, inscribed_size, best_resolution = self.prepare_inputs(
            image=image, config=models["config"]
        )

        # Run eager model and check it matches the reference model.
        eager_model = models["model"].get_eager_model()
        eager_image, eager_ar = eager_model(
            image_tensor, inscribed_size, best_resolution
        )
        eager_ar = eager_ar.tolist()
        assert_expected(eager_image, reference_image, rtol=0, atol=1e-4)
        assert (
            reference_ar == eager_ar
        ), f"Eager model: expected {reference_ar} but got {eager_ar}"

        # Run exported model and check it matches the reference model.
        exported_model = models["exported_model"]
        exported_image, exported_ar = exported_model.module()(
            image_tensor, inscribed_size, best_resolution
        )
        exported_ar = exported_ar.tolist()
        assert_expected(exported_image, reference_image, rtol=0, atol=1e-4)
        assert (
            reference_ar == exported_ar
        ), f"Exported model: expected {reference_ar} but got {exported_ar}"

        # Run ExecuTorch model and check it matches the reference model.
        executorch_model = models["executorch_model"]
        executorch_module = _load_for_executorch_from_buffer(executorch_model.buffer)
        et_image, et_ar = executorch_module.forward(
            (image_tensor, inscribed_size, best_resolution)
        )
        assert_expected(et_image, reference_image, rtol=0, atol=1e-4)
        assert (
            reference_ar == et_ar.tolist()
        ), f"ExecuTorch model: expected {reference_ar} but got {et_ar.tolist()}"

        # Run AOTI model and check it matches the reference model.
        aoti_path = models["aoti_path"]
        aoti_model = torch._inductor.aoti_load_package(aoti_path)
        aoti_image, aoti_ar = aoti_model(image_tensor, inscribed_size, best_resolution)
        assert_expected(aoti_image, reference_image, rtol=0, atol=1e-4)
        assert (
            reference_ar == aoti_ar.tolist()
        ), f"AOTI model: expected {reference_ar} but got {aoti_ar.tolist()}"

    # This test setup mirrors the one in torchtune:
    # https://github.com/pytorch/torchtune/blob/main/tests/torchtune/models/clip/test_clip_image_transform.py
    # The expected values differ slightly, as torchtune uses antialias=True while
    # this test uses antialias=False, which is exportable (has a portable kernel).
    def test_preprocess1(self):
        self.run_preprocess(
            (100, 400, 3),  # image_size
            torch.Size([2, 3, 224, 224]),  # expected_shape
            False,  # resize_to_max_canvas
            [0.2230, 0.1763],  # expected_tile_means
            [1.0, 1.0],  # expected_tile_max
            [0.0, 0.0],  # expected_tile_min
            [1, 2],  # expected_ar
        )

    def test_preprocess2(self):
        self.run_preprocess(
            (1000, 300, 3),  # image_size
            torch.Size([4, 3, 224, 224]),  # expected_shape
            True,  # resize_to_max_canvas
            [0.5005, 0.4992, 0.5004, 0.1651],  # expected_tile_means
            [0.9976, 0.9940, 0.9936, 0.9906],  # expected_tile_max
            [0.0037, 0.0047, 0.0039, 0.0],  # expected_tile_min
            [4, 1],  # expected_ar
        )

    def test_preprocess3(self):
        self.run_preprocess(
            (200, 200, 3),  # image_size
            torch.Size([4, 3, 224, 224]),  # expected_shape
            True,  # resize_to_max_canvas
            [0.5012, 0.5020, 0.5010, 0.4991],  # expected_tile_means
            [0.9921, 0.9925, 0.9969, 0.9908],  # expected_tile_max
            [0.0056, 0.0069, 0.0059, 0.0032],  # expected_tile_min
            [2, 2],  # expected_ar
        )

    def test_preprocess4(self):
        self.run_preprocess(
            (600, 200, 3),  # image_size
            torch.Size([3, 3, 224, 224]),  # expected_shape
            False,  # resize_to_max_canvas
            [0.4472, 0.4468, 0.3031],  # expected_tile_means
            [1.0, 1.0, 1.0],  # expected_tile_max
            [0.0, 0.0, 0.0],  # expected_tile_min
            [3, 1],  # expected_ar
        )