# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import tempfile
import unittest

import torch
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.extension.llm.modules import (
    replace_tile_positional_embedding,
    replace_tiled_token_positional_embedding,
    TiledTokenPositionalEmbedding,
    TilePositionalEmbedding,
)
from executorch.runtime import Runtime
from torch._inductor.package import load_package, package_aoti
from torch.testing import assert_close
from torchtune.models.clip import (
    TiledTokenPositionalEmbedding as TuneTiledTokenPositionalEmbedding,
    TilePositionalEmbedding as TuneTilePositionalEmbedding,
)


class TilePositionalEmbeddingTest(unittest.TestCase):
    def setUp(self):
        super().setUp()
        self.tpe = TilePositionalEmbedding(4, 1280)
        self.ref_tpe = TuneTilePositionalEmbedding(4, 1280)
        self.x = torch.randn(1, 4, 1600, 1280)
        self.aspect_ratio = torch.tensor([[1, 1]])
        num_tiles_dim = torch.export.Dim("num_tiles", min=1, max=4)
        num_tokens = torch.export.Dim("num_tokens", min=1, max=1600)
        self.dynamic_shape = {
            0: 1,  # batch
            1: num_tiles_dim,  # num tiles
            2: num_tokens,  # num tokens
            3: 1280,  # embedding dim
        }

    def test_tile_positional_embedding_smoke(self):
        y = self.tpe(self.x, self.aspect_ratio)
        ref_y = self.ref_tpe(self.x, self.aspect_ratio)

        self.assertTrue(torch.allclose(y, ref_y))

    def test_tile_positional_embedding_export(self):
        tpe_ep = torch.export.export(
            self.tpe,
            (self.x, self.aspect_ratio),
            dynamic_shapes=(
                self.dynamic_shape,
                None,
            ),  # assuming aspect ratio is static
        )

        y = tpe_ep.module()(self.x, self.aspect_ratio)
        ref_y = self.ref_tpe(self.x, self.aspect_ratio)

        self.assertTrue(torch.allclose(y, ref_y))

    def test_tile_positional_embedding_aoti(self):
        so = torch._export.aot_compile(
            self.tpe,
            args=(self.x, self.aspect_ratio),
            options={"aot_inductor.package": True},
            dynamic_shapes=(
                self.dynamic_shape,
                None,
            ),  # assuming aspect ratio is static
        )
        with tempfile.TemporaryDirectory() as tmpdir:
            path = package_aoti(os.path.join(tmpdir, "tpe.pt2"), so)
            tpe_aoti = load_package(path)

            y = tpe_aoti(self.x, self.aspect_ratio)
            ref_y = self.ref_tpe(self.x, self.aspect_ratio)

            self.assertTrue(torch.allclose(y, ref_y))

    def test_tile_positional_embedding_et(self):
        tpe_ep = torch.export.export(
            self.tpe,
            (self.x, self.aspect_ratio),
            dynamic_shapes=(
                self.dynamic_shape,
                None,
            ),  # assuming aspect ratio is static
        )
        et_program = to_edge(
            tpe_ep,
            compile_config=EdgeCompileConfig(
                _core_aten_ops_exception_list=[
                    torch.ops.aten.sym_constrain_range_for_size.default,
                    torch.ops.aten._assert_scalar.default,
                    torch.ops.aten._local_scalar_dense.default,
                ]
            ),
        ).to_executorch()
        runtime = Runtime.get()
        program = runtime.load_program(et_program.buffer)
        method = program.load_method("forward")
        y = method.execute((self.x, self.aspect_ratio))
        ref_y = self.ref_tpe(self.x, self.aspect_ratio)

        self.assertTrue(torch.allclose(y[0], ref_y))

    def test_replace_tile_positional_embedding(self):
        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.tpe = TuneTilePositionalEmbedding(4, 1280)

            def forward(self, x, aspect_ratio):
                return self.tpe(x, aspect_ratio)

        m = Module()
        m = replace_tile_positional_embedding(m)
        self.assertTrue(isinstance(m.tpe, TilePositionalEmbedding))


class TiledTokenPositionalEmbeddingTest(unittest.TestCase):
    def setUp(self):
        super().setUp()
        self.tpe = TiledTokenPositionalEmbedding(4, 1280, 40, 1)
        self.ref_tpe = TuneTiledTokenPositionalEmbedding(4, 1280, 40, 1)
        self.tpe.load_state_dict(self.ref_tpe.state_dict())
        self.x = torch.randn(1, 4, 1601, 1280)
        self.aspect_ratio = torch.tensor([[1, 2]])
        num_tiles_dim = torch.export.Dim("num_tiles", min=1, max=4)
        self.dynamic_shape = {
            0: 1,  # batch
            1: num_tiles_dim,  # num tiles
            2: 1601,  # num tokens
            3: 1280,  # embedding dim
        }

    def test_tiled_token_positional_embedding_smoke(self):
        y = self.tpe(self.x, self.aspect_ratio)
        ref_y = self.ref_tpe(self.x, self.aspect_ratio)

        assert_close(y, ref_y)

    def test_tiled_token_positional_embedding_export(self):
        tpe_ep = torch.export.export(
            self.tpe,
            (self.x, self.aspect_ratio),
            dynamic_shapes=(
                self.dynamic_shape,
                None,
            ),  # assuming aspect ratio is static
        )

        y = tpe_ep.module()(self.x, self.aspect_ratio)
        ref_y = self.ref_tpe(self.x, self.aspect_ratio)

        assert_close(y, ref_y)

    @unittest.skip(reason="TODO(T207740932): test is flaky")
    def test_tiled_token_positional_embedding_aoti(self):
        tpe_ep = torch.export.export(
            self.tpe,
            (self.x, self.aspect_ratio),
            dynamic_shapes=(
                self.dynamic_shape,
                None,
            ),  # assuming aspect ratio is static
        )
        with tempfile.TemporaryDirectory() as tmpdir:
            path = torch._inductor.aoti_compile_and_package(
                tpe_ep,
                (self.x, self.aspect_ratio),
                package_path=os.path.join(tmpdir, "tpe.pt2"),
            )
            tpe_aoti = load_package(path)

            y = tpe_aoti(self.x, self.aspect_ratio)
            ref_y = self.ref_tpe(self.x, self.aspect_ratio)

            assert_close(y, ref_y)

    def test_tiled_token_positional_embedding_et(self):
        tpe_ep = torch.export.export(
            self.tpe,
            (self.x, self.aspect_ratio),
            dynamic_shapes=(
                self.dynamic_shape,
                None,
            ),  # assuming aspect ratio is static
        )
        et_program = to_edge(
            tpe_ep,
            compile_config=EdgeCompileConfig(
                _core_aten_ops_exception_list=[
                    torch.ops.aten.sym_constrain_range_for_size.default,
                    torch.ops.aten._assert_scalar.default,
                    torch.ops.aten._local_scalar_dense.default,
                ]
            ),
        ).to_executorch()
        runtime = Runtime.get()
        program = runtime.load_program(et_program.buffer)
        method = program.load_method("forward")
        y = method.execute((self.x, self.aspect_ratio))
        ref_y = self.ref_tpe(self.x, self.aspect_ratio)

        assert_close(y[0], ref_y)

    def test_replace_tiled_token_positional_embedding(self):
        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.tpe = TuneTiledTokenPositionalEmbedding(4, 1280, 40, 1)

            def forward(self, x, aspect_ratio):
                return self.tpe(x, aspect_ratio)

        m = Module()
        m = replace_tiled_token_positional_embedding(m)
        self.assertTrue(isinstance(m.tpe, TiledTokenPositionalEmbedding))
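

# Optional convenience entry point (an addition, not required by the original
# module): lets the file be run directly with the standard unittest runner.
# pytest/buck-style test discovery finds the TestCase classes above without it.
if __name__ == "__main__":
    unittest.main()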