# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
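
"""Tests for ExecuTorch's exportable positional-embedding modules.

TilePositionalEmbedding and TiledTokenPositionalEmbedding are checked against
their torchtune references in eager mode, through torch.export, through
AOTInductor, and on the ExecuTorch runtime, along with the source-transform
helpers that swap the torchtune modules for the exportable ones.
"""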

import os
import tempfile
import unittest

import torch
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.extension.llm.modules import (
    replace_tile_positional_embedding,
    replace_tiled_token_positional_embedding,
    TiledTokenPositionalEmbedding,
    TilePositionalEmbedding,
)
from executorch.runtime import Runtime
from torch._inductor.package import load_package, package_aoti
from torch.testing import assert_close
from torchtune.models.clip import (
    TiledTokenPositionalEmbedding as TuneTiledTokenPositionalEmbedding,
    TilePositionalEmbedding as TuneTilePositionalEmbedding,
)


class TilePositionalEmbeddingTest(unittest.TestCase):
    def setUp(self):
        super().setUp()
        self.tpe = TilePositionalEmbedding(4, 1280)
        self.ref_tpe = TuneTilePositionalEmbedding(4, 1280)
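        # No state_dict copy is needed here: the randomly initialized tile
        # embedding is applied through a zero-initialized tanh gate, so both
        # modules reduce to the identity at initialization.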
        self.x = torch.randn(1, 4, 1600, 1280)
        self.aspect_ratio = torch.tensor([[1, 1]])
        num_tiles_dim = torch.export.Dim("num_tiles", min=1, max=4)
        num_tokens = torch.export.Dim("num_tokens", min=1, max=1600)

        self.dynamic_shape = {
            0: 1,  # batch
            1: num_tiles_dim,  # num tiles
            2: num_tokens,  # num tokens
            3: 1280,  # embedding dim
        }

    def test_tile_positional_embedding_smoke(self):
        y = self.tpe(self.x, self.aspect_ratio)
        ref_y = self.ref_tpe(self.x, self.aspect_ratio)

        self.assertTrue(torch.allclose(y, ref_y))

    def test_tile_positional_embedding_export(self):
        tpe_ep = torch.export.export(
            self.tpe,
            (self.x, self.aspect_ratio),
            dynamic_shapes=(
                self.dynamic_shape,
                None,
            ),  # assuming aspect ratio is static
        )

        y = tpe_ep.module()(self.x, self.aspect_ratio)
        ref_y = self.ref_tpe(self.x, self.aspect_ratio)

        self.assertTrue(torch.allclose(y, ref_y))

    def test_tile_positional_embedding_aoti(self):
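        # This test uses the older private torch._export.aot_compile plus
        # package_aoti flow; the TiledTokenPositionalEmbedding test below
        # exercises torch._inductor.aoti_compile_and_package instead.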
        so = torch._export.aot_compile(
            self.tpe,
            args=(self.x, self.aspect_ratio),
            options={"aot_inductor.package": True},
            dynamic_shapes=(
                self.dynamic_shape,
                None,
            ),  # assuming aspect ratio is static
        )
        with tempfile.TemporaryDirectory() as tmpdir:
            path = package_aoti(os.path.join(tmpdir, "tpe.pt2"), so)
            tpe_aoti = load_package(path)

            y = tpe_aoti(self.x, self.aspect_ratio)
            ref_y = self.ref_tpe(self.x, self.aspect_ratio)

            self.assertTrue(torch.allclose(y, ref_y))

    def test_tile_positional_embedding_et(self):
        tpe_ep = torch.export.export(
            self.tpe,
            (self.x, self.aspect_ratio),
            dynamic_shapes=(
                self.dynamic_shape,
                None,
            ),  # assuming aspect ratio is static
        )
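        # The exception list admits the range-constraint and scalar-assert ops
        # that export emits for dynamic shapes but that are outside the core
        # ATen opset checked by the Edge dialect verifier.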
        et_program = to_edge(
            tpe_ep,
            compile_config=EdgeCompileConfig(
                _core_aten_ops_exception_list=[
                    torch.ops.aten.sym_constrain_range_for_size.default,
                    torch.ops.aten._assert_scalar.default,
                    torch.ops.aten._local_scalar_dense.default,
                ]
            ),
        ).to_executorch()
        runtime = Runtime.get()
        program = runtime.load_program(et_program.buffer)
        method = program.load_method("forward")
        y = method.execute((self.x, self.aspect_ratio))
        ref_y = self.ref_tpe(self.x, self.aspect_ratio)

        self.assertTrue(torch.allclose(y[0], ref_y))

    def test_replace_tile_positional_embedding(self):
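        # replace_tile_positional_embedding returns the module with torchtune's
        # TilePositionalEmbedding submodules swapped for the exportable
        # ExecuTorch implementation.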
        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.tpe = TuneTilePositionalEmbedding(4, 1280)

            def forward(self, x, aspect_ratio):
                return self.tpe(x, aspect_ratio)

        m = Module()
        m = replace_tile_positional_embedding(m)
        self.assertTrue(isinstance(m.tpe, TilePositionalEmbedding))


class TiledTokenPositionalEmbeddingTest(unittest.TestCase):
    def setUp(self):
        super().setUp()
        self.tpe = TiledTokenPositionalEmbedding(4, 1280, 40, 1)
        self.ref_tpe = TuneTiledTokenPositionalEmbedding(4, 1280, 40, 1)
        self.tpe.load_state_dict(self.ref_tpe.state_dict())
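        # Unlike the tile embedding, part of the token positional embedding is
        # applied ungated at initialization, so the weights must be synced for
        # the two modules to agree.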
        self.x = torch.randn(1, 4, 1601, 1280)  # 1601 = (40 / 1) ** 2 patches + 1 CLS token
        self.aspect_ratio = torch.tensor([[1, 2]])
        num_tiles_dim = torch.export.Dim("num_tiles", min=1, max=4)

        self.dynamic_shape = {
            0: 1,  # batch
            1: num_tiles_dim,  # num tiles
            2: 1601,  # num tokens
            3: 1280,  # embedding dim
        }

    def test_tiled_token_positional_embedding_smoke(self):
        y = self.tpe(self.x, self.aspect_ratio)
        ref_y = self.ref_tpe(self.x, self.aspect_ratio)

        assert_close(y, ref_y)

    def test_tiled_token_positional_embedding_export(self):
        tpe_ep = torch.export.export(
            self.tpe,
            (self.x, self.aspect_ratio),
            dynamic_shapes=(
                self.dynamic_shape,
                None,
            ),  # assuming aspect ratio is static
        )

        y = tpe_ep.module()(self.x, self.aspect_ratio)
        ref_y = self.ref_tpe(self.x, self.aspect_ratio)

        assert_close(y, ref_y)

    @unittest.skip(reason="TODO(T207740932): test is flaky")
    def test_tiled_token_positional_embedding_aoti(self):
        tpe_ep = torch.export.export(
            self.tpe,
            (self.x, self.aspect_ratio),
            dynamic_shapes=(
                self.dynamic_shape,
                None,
            ),  # assuming aspect ratio is static
        )

        with tempfile.TemporaryDirectory() as tmpdir:
            path = torch._inductor.aoti_compile_and_package(
                tpe_ep,
                (self.x, self.aspect_ratio),
                package_path=os.path.join(tmpdir, "tpe.pt2"),
            )
            tpe_aoti = load_package(path)

            y = tpe_aoti(self.x, self.aspect_ratio)
            ref_y = self.ref_tpe(self.x, self.aspect_ratio)

            assert_close(y, ref_y)

    def test_tiled_token_positional_embedding_et(self):
        tpe_ep = torch.export.export(
            self.tpe,
            (self.x, self.aspect_ratio),
            dynamic_shapes=(
                self.dynamic_shape,
                None,
            ),  # assuming aspect ratio is static
        )
        et_program = to_edge(
            tpe_ep,
            compile_config=EdgeCompileConfig(
                _core_aten_ops_exception_list=[
                    torch.ops.aten.sym_constrain_range_for_size.default,
                    torch.ops.aten._assert_scalar.default,
                    torch.ops.aten._local_scalar_dense.default,
                ]
            ),
        ).to_executorch()
        runtime = Runtime.get()
        program = runtime.load_program(et_program.buffer)
        method = program.load_method("forward")
        y = method.execute((self.x, self.aspect_ratio))
        ref_y = self.ref_tpe(self.x, self.aspect_ratio)

        assert_close(y[0], ref_y)

    def test_replace_tiled_token_positional_embedding(self):
        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.tpe = TuneTiledTokenPositionalEmbedding(4, 1280, 40, 1)

            def forward(self, x, aspect_ratio):
                return self.tpe(x, aspect_ratio)

        m = Module()
        m = replace_tiled_token_positional_embedding(m)
        self.assertTrue(isinstance(m.tpe, TiledTokenPositionalEmbedding))
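

if __name__ == "__main__":
    # Minimal entry point so the suite can also be run directly with Python.
    unittest.main()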