# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import tempfile
import unittest

import torch
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.extension.llm.modules import (
    replace_tile_positional_embedding,
    replace_tiled_token_positional_embedding,
    TiledTokenPositionalEmbedding,
    TilePositionalEmbedding,
)
from executorch.runtime import Runtime
from torch._inductor.package import load_package, package_aoti
from torch.testing import assert_close
from torchtune.models.clip import (
    TiledTokenPositionalEmbedding as TuneTiledTokenPositionalEmbedding,
    TilePositionalEmbedding as TuneTilePositionalEmbedding,
)


class TilePositionalEmbeddingTest(unittest.TestCase):
    def setUp(self):
        super().setUp()
        self.tpe = TilePositionalEmbedding(4, 1280)
        self.ref_tpe = TuneTilePositionalEmbedding(4, 1280)
        self.x = torch.randn(1, 4, 1600, 1280)
        self.aspect_ratio = torch.tensor([[1, 1]])
        num_tiles_dim = torch.export.Dim("num_tiles", min=1, max=4)
        num_tokens = torch.export.Dim("num_tokens", min=1, max=1600)

        self.dynamic_shape = {
            0: 1,  # batch
            1: num_tiles_dim,  # num tiles
            2: num_tokens,  # num tokens
            3: 1280,  # embedding dim
        }

    def test_tile_positional_embedding_smoke(self):
        y = self.tpe(self.x, self.aspect_ratio)
        ref_y = self.ref_tpe(self.x, self.aspect_ratio)

        self.assertTrue(torch.allclose(y, ref_y))

    def test_tile_positional_embedding_export(self):
        tpe_ep = torch.export.export(
            self.tpe,
            (self.x, self.aspect_ratio),
            dynamic_shapes=(
                self.dynamic_shape,
                None,
            ),  # assuming aspect ratio is static
        )

        y = tpe_ep.module()(self.x, self.aspect_ratio)
        ref_y = self.ref_tpe(self.x, self.aspect_ratio)

        self.assertTrue(torch.allclose(y, ref_y))

    def test_tile_positional_embedding_aoti(self):
        so = torch._export.aot_compile(
            self.tpe,
            args=(self.x, self.aspect_ratio),
            options={"aot_inductor.package": True},
            dynamic_shapes=(
                self.dynamic_shape,
                None,
            ),  # assuming aspect ratio is static
        )
        with tempfile.TemporaryDirectory() as tmpdir:
            path = package_aoti(os.path.join(tmpdir, "tpe.pt2"), so)
            tpe_aoti = load_package(path)

            y = tpe_aoti(self.x, self.aspect_ratio)
            ref_y = self.ref_tpe(self.x, self.aspect_ratio)

            self.assertTrue(torch.allclose(y, ref_y))

    def test_tile_positional_embedding_et(self):
        tpe_ep = torch.export.export(
            self.tpe,
            (self.x, self.aspect_ratio),
            dynamic_shapes=(
                self.dynamic_shape,
                None,
            ),  # assuming aspect ratio is static
        )
        et_program = to_edge(
            tpe_ep,
            compile_config=EdgeCompileConfig(
                # Ops emitted by dynamic-shape handling that are exempted
                # from core ATen verification.
                _core_aten_ops_exception_list=[
                    torch.ops.aten.sym_constrain_range_for_size.default,
                    torch.ops.aten._assert_scalar.default,
                    torch.ops.aten._local_scalar_dense.default,
                ]
            ),
        ).to_executorch()
        runtime = Runtime.get()
        program = runtime.load_program(et_program.buffer)
        method = program.load_method("forward")
        y = method.execute((self.x, self.aspect_ratio))
        ref_y = self.ref_tpe(self.x, self.aspect_ratio)

        self.assertTrue(torch.allclose(y[0], ref_y))

    def test_replace_tile_positional_embedding(self):
        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.tpe = TuneTilePositionalEmbedding(4, 1280)

            def forward(self, x, aspect_ratio):
                return self.tpe(x, aspect_ratio)

        m = Module()
        m = replace_tile_positional_embedding(m)
        self.assertTrue(isinstance(m.tpe, TilePositionalEmbedding))


class TiledTokenPositionalEmbeddingTest(unittest.TestCase):
    def setUp(self):
        super().setUp()
        self.tpe = TiledTokenPositionalEmbedding(4, 1280, 40, 1)
        self.ref_tpe = TuneTiledTokenPositionalEmbedding(4, 1280, 40, 1)
        # Copy the reference weights so both implementations produce
        # identical outputs.
        self.tpe.load_state_dict(self.ref_tpe.state_dict())
        self.x = torch.randn(1, 4, 1601, 1280)
        self.aspect_ratio = torch.tensor([[1, 2]])
        num_tiles_dim = torch.export.Dim("num_tiles", min=1, max=4)

        self.dynamic_shape = {
            0: 1,  # batch
            1: num_tiles_dim,  # num tiles
            2: 1601,  # num tokens
            3: 1280,  # embedding dim
        }

    def test_tiled_token_positional_embedding_smoke(self):
        y = self.tpe(self.x, self.aspect_ratio)
        ref_y = self.ref_tpe(self.x, self.aspect_ratio)

        assert_close(y, ref_y)

    def test_tiled_token_positional_embedding_export(self):
        tpe_ep = torch.export.export(
            self.tpe,
            (self.x, self.aspect_ratio),
            dynamic_shapes=(
                self.dynamic_shape,
                None,
            ),  # assuming aspect ratio is static
        )

        y = tpe_ep.module()(self.x, self.aspect_ratio)
        ref_y = self.ref_tpe(self.x, self.aspect_ratio)

        assert_close(y, ref_y)

    @unittest.skip(reason="TODO(T207740932): test is flaky")
    def test_tiled_token_positional_embedding_aoti(self):
        tpe_ep = torch.export.export(
            self.tpe,
            (self.x, self.aspect_ratio),
            dynamic_shapes=(
                self.dynamic_shape,
                None,
            ),  # assuming aspect ratio is static
        )

        with tempfile.TemporaryDirectory() as tmpdir:
            path = torch._inductor.aoti_compile_and_package(
                tpe_ep,
                (self.x, self.aspect_ratio),
                package_path=os.path.join(tmpdir, "tpe.pt2"),
            )
            tpe_aoti = load_package(path)

            y = tpe_aoti(self.x, self.aspect_ratio)
            ref_y = self.ref_tpe(self.x, self.aspect_ratio)

            assert_close(y, ref_y)

    def test_tiled_token_positional_embedding_et(self):
        tpe_ep = torch.export.export(
            self.tpe,
            (self.x, self.aspect_ratio),
            dynamic_shapes=(
                self.dynamic_shape,
                None,
            ),  # assuming aspect ratio is static
        )
        et_program = to_edge(
            tpe_ep,
            compile_config=EdgeCompileConfig(
                # Same exception list as above: ops emitted by dynamic-shape
                # handling that are exempted from core ATen verification.
                _core_aten_ops_exception_list=[
                    torch.ops.aten.sym_constrain_range_for_size.default,
                    torch.ops.aten._assert_scalar.default,
                    torch.ops.aten._local_scalar_dense.default,
                ]
            ),
        ).to_executorch()
        runtime = Runtime.get()
        program = runtime.load_program(et_program.buffer)
        method = program.load_method("forward")
        y = method.execute((self.x, self.aspect_ratio))
        ref_y = self.ref_tpe(self.x, self.aspect_ratio)

        assert_close(y[0], ref_y)

    def test_replace_tiled_token_positional_embedding(self):
        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.tpe = TuneTiledTokenPositionalEmbedding(4, 1280, 40, 1)

            def forward(self, x, aspect_ratio):
                return self.tpe(x, aspect_ratio)

        m = Module()
        m = replace_tiled_token_positional_embedding(m)
        self.assertTrue(isinstance(m.tpe, TiledTokenPositionalEmbedding))
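

# Conventional unittest entry point so the module can also be run directly
# with `python <path to this file>`, in addition to a pytest/unittest runner.
if __name__ == "__main__":
    unittest.main()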