# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from unittest import TestCase

from executorch import exir
from executorch.exir import to_edge
from executorch.exir.passes import DebugPass, HintBasedSymShapeEvalPass, SpecPropPass
from executorch.exir.tests.models import Repeat, TensorItem
from torch.export import export


def _lower_and_get_output_specs(eager_model, use_dynamic_shapes=False):
    """Lower *eager_model* to edge, run spec passes, and return the output specs.

    The first two entries of ``eager_model.get_random_inputs()`` are used as
    the example inputs. When *use_dynamic_shapes* is true, the model's own
    ``get_dynamic_shape()`` is passed to ``torch.export.export``; otherwise no
    dynamic shapes are supplied. IR validity checking is disabled for
    ``to_edge``. The edge program is then transformed with ``SpecPropPass``
    followed by ``HintBasedSymShapeEvalPass``.

    Returns:
        The value stored in ``meta["spec"]`` on the last node of the resulting
        graph (the graph's output node).
    """
    sample = eager_model.get_random_inputs()
    example_inputs = (sample[0], sample[1])
    # Queried after the random inputs, matching the original evaluation order.
    shapes = eager_model.get_dynamic_shape() if use_dynamic_shapes else None

    edge_program = to_edge(
        export(eager_model, example_inputs, dynamic_shapes=shapes),
        compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
    )
    transformed = edge_program.transform(
        [SpecPropPass(), HintBasedSymShapeEvalPass()]
    )
    graph_module = transformed.exported_program().graph_module

    # Dump the graph with per-node specs; useful output when an assertion fails.
    DebugPass(show_spec=True)(graph_module)

    output_node = list(graph_module.graph.nodes)[-1]
    return output_node.meta["spec"]


class TestDynamicShapeProp(TestCase):
    """Spec propagation for a model exported with dynamic input shapes."""

    def test_repeat(self):
        """Both ``Repeat`` outputs become upper-bound tensors; the first is [4, 5]."""
        specs = _lower_and_get_output_specs(Repeat(), use_dynamic_shapes=True)
        self.assertEqual(len(specs), 2)

        first, second = specs
        self.assertTrue(first.is_upper_bound_tensor)
        self.assertTrue(second.is_upper_bound_tensor)
        self.assertEqual(first.shape, [4, 5])


class TestUnbackedSymInt(TestCase):
    """Spec propagation for a model exported without dynamic input shapes."""

    def test_unbacked_symint(self):
        """The single ``TensorItem`` output is an upper-bound tensor of [100, 100]."""
        specs = _lower_and_get_output_specs(TensorItem())
        self.assertEqual(len(specs), 1)

        (only_spec,) = specs
        self.assertTrue(only_spec.is_upper_bound_tensor)
        # upper bound of TensorItem model
        self.assertEqual(only_spec.shape, [100, 100])