• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import unittest
8
9import torch
10from executorch.backends.xnnpack.test.tester import Tester
11from torchvision import models
12
13
class TestViT(unittest.TestCase):
    """End-to-end XNNPACK lowering tests for torchvision's ViT-B/16.

    Each test exports a model through ``Tester``, lowers it with
    ``to_edge_transform_and_lower``, checks which edge operators were absorbed
    into the delegate, then serializes the program and compares its outputs
    against the eager model.
    """

    # 1x3x224x224 float32 example input shared by both tests.
    model_inputs = (torch.randn(1, 3, 224, 224),)
    # One dict per positional input: dims 2 and 3 (height, width) of the first
    # input may vary within [224, 455] in the exported program.
    dynamic_shapes = (
        {
            2: torch.export.Dim("height", min=224, max=455),
            3: torch.export.Dim("width", min=224, max=455),
        },
    )

    @classmethod
    def setUpClass(cls):
        """Build the pretrained ViT-B/16 once for every test in this case.

        This is done here rather than in the class body so that merely
        importing or collecting the module does not construct the model and
        load its IMAGENET1K_V1 weights.
        """
        super().setUpClass()
        cls.vit = models.vision_transformer.vit_b_16(
            weights="IMAGENET1K_V1"
        ).eval()

    class DynamicViT(torch.nn.Module):
        """ViT-B/16 preceded by a bilinear resize to a fixed 224x224.

        Resizing inside ``forward`` lets the exported program accept inputs
        whose height/width vary (per ``dynamic_shapes``) while the wrapped ViT
        always receives a 224x224 image.
        """

        def __init__(self):
            super().__init__()
            self.vit = models.vision_transformer.vit_b_16(weights="IMAGENET1K_V1")
            self.vit = self.vit.eval()

        def forward(self, x):
            # Collapse the dynamic spatial dims to the fixed size fed to the ViT.
            x = torch.nn.functional.interpolate(
                x,
                size=(224, 224),
                mode="bilinear",
                align_corners=True,
                antialias=False,
            )
            return self.vit(x)

    # Edge-dialect operator names relevant to the exported ViT (presumably its
    # full op set before lowering -- TODO confirm against a fresh export).
    all_operators = {
        "executorch_exir_dialects_edge__ops_aten_expand_copy_default",
        "executorch_exir_dialects_edge__ops_aten_cat_default",
        "executorch_exir_dialects_edge__ops_aten_permute_copy_default",
        "executorch_exir_dialects_edge__ops_aten_addmm_default",
        "executorch_exir_dialects_edge__ops_aten_add_Tensor",
        "executorch_exir_dialects_edge__ops_aten_mul_Scalar",
        "executorch_exir_dialects_edge__ops_aten_gelu_default",
        "executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default",
        "executorch_exir_dialects_edge__ops_aten_clone_default",
        "executorch_exir_dialects_edge__ops_aten__softmax_default",
        "executorch_exir_dialects_edge__ops_aten_convolution_default",
        "executorch_exir_dialects_edge__ops_aten_view_copy_default",
        "executorch_exir_dialects_edge__ops_aten_squeeze_copy_dim",
        "executorch_exir_dialects_edge__ops_aten_select_copy_int",
        "executorch_exir_dialects_edge__ops_aten_native_layer_norm_default",
        "executorch_exir_dialects_edge__ops_aten_bmm_default",
    }

    # Subset of ``all_operators`` that is allowed to remain outside the
    # delegate; every other op in ``all_operators`` must be lowered.
    _non_lowered_operators = {
        "executorch_exir_dialects_edge__ops_aten_expand_copy_default",
        "executorch_exir_dialects_edge__ops_aten_gelu_default",
        "executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default",
        "executorch_exir_dialects_edge__ops_aten_clone_default",
        "executorch_exir_dialects_edge__ops_aten_view_copy_default",
        "executorch_exir_dialects_edge__ops_aten_squeeze_copy_dim",
        "executorch_exir_dialects_edge__ops_aten_select_copy_int",
        "executorch_exir_dialects_edge__ops_aten_native_layer_norm_default",
        "executorch_exir_dialects_edge__ops_aten_mul_Scalar",
        "executorch_exir_dialects_edge__ops_aten_bmm_default",
    }

    def _test_exported_vit(self, tester, check_nots=None):
        """Export, lower, run and compare the model wrapped by ``tester``.

        Args:
            tester: ``Tester`` holding the model, example inputs and optional
                dynamic shapes.
            check_nots: Optional iterable of extra edge-op names that must not
                remain in the lowered graph. ``None`` adds no extra checks.
        """
        lowerable_xnn_operators = self.all_operators - self._non_lowered_operators
        # Sorted so check_not receives (and reports) ops in a stable order;
        # iterating a set of strings varies with hash randomization.
        (
            tester.export()
            .to_edge_transform_and_lower()
            # The lowered graph must contain at least one delegate call.
            .check(["torch.ops.higher_order.executorch_call_delegate"])
            .check_not(sorted(lowerable_xnn_operators))
            .check_not(sorted(check_nots or ()))
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp32_vit(self):
        """Static-shape fp32 ViT-B/16 lowers and matches eager outputs."""
        self._test_exported_vit(Tester(self.vit, self.model_inputs))

    def test_dynamic_vit(self):
        """Dynamic-HxW ViT (bilinear resize + ViT) lowers and matches eager."""
        # Presumably the edge ops the bilinear interpolate decomposes into;
        # none of them may remain outside the delegate after lowering.
        bilinear_ops = {
            "executorch_exir_dialects_edge__ops_aten_sub_Tensor",
            "executorch_exir_dialects_edge__ops_aten_mul_Tensor",
            "executorch_exir_dialects_edge__ops_aten_index_Tensor",
            "executorch_exir_dialects_edge__ops_aten_arange_start_step",
            "executorch_exir_dialects_edge__ops_aten__to_copy_default",
            "executorch_exir_dialects_edge__ops_aten_add_Tensor",
            "executorch_exir_dialects_edge__ops_aten_clamp_default",
        }

        self._test_exported_vit(
            Tester(self.DynamicViT(), self.model_inputs, self.dynamic_shapes),
            bilinear_ops,
        )
103