• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-unsafe
8
9from unittest import TestCase
10
11from executorch import exir
12from executorch.exir import to_edge
13from executorch.exir.passes import DebugPass, HintBasedSymShapeEvalPass, SpecPropPass
14from executorch.exir.tests.models import Repeat, TensorItem
15from torch.export import export
16
17
class TestDynamicShapeProp(TestCase):
    """Spec propagation for a model exported with dynamic input shapes."""

    def test_repeat(self):
        # Export ``Repeat`` with its declared dynamic shapes, lower it to edge
        # dialect (IR validity checks disabled), then run spec propagation
        # followed by hint-based symbolic shape evaluation.
        model = Repeat()
        random_inputs = model.get_random_inputs()
        example_args = (random_inputs[0], random_inputs[1])

        exported = export(
            model, example_args, dynamic_shapes=model.get_dynamic_shape()
        )
        edge = to_edge(
            exported,
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )

        passes = [SpecPropPass(), HintBasedSymShapeEvalPass()]
        graph_module = edge.transform(passes).exported_program().graph_module

        # Dump the graph together with the propagated specs for debugging.
        DebugPass(show_spec=True)(graph_module)

        # The last node of the graph carries the specs of the graph outputs.
        output_node = list(graph_module.graph.nodes)[-1]
        output_specs = output_node.meta["spec"]

        self.assertEqual(len(output_specs), 2)
        first_spec, second_spec = output_specs
        for spec in (first_spec, second_spec):
            self.assertTrue(spec.is_upper_bound_tensor)
        self.assertEqual(first_spec.shape, [4, 5])
43
class TestUnbackedSymInt(TestCase):
    """Upper-bound spec propagation for the ``TensorItem`` model (static export)."""

    def test_unbacked_symint(self):
        # Export ``TensorItem`` without dynamic shapes, lower it to edge
        # dialect (IR validity checks disabled), then run spec propagation
        # followed by hint-based symbolic shape evaluation.
        model = TensorItem()
        random_inputs = model.get_random_inputs()
        example_args = (random_inputs[0], random_inputs[1])

        edge = to_edge(
            export(model, example_args, dynamic_shapes=None),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
        lowered = edge.transform([SpecPropPass(), HintBasedSymShapeEvalPass()])
        graph_module = lowered.exported_program().graph_module

        # Dump the graph together with the propagated specs for debugging.
        DebugPass(show_spec=True)(graph_module)

        # The last node of the graph carries the specs of the graph outputs.
        output_node = list(graph_module.graph.nodes)[-1]
        output_specs = output_node.meta["spec"]

        self.assertEqual(len(output_specs), 1)
        (only_spec,) = output_specs
        self.assertTrue(only_spec.is_upper_bound_tensor)
        # [100, 100] is the upper bound of the TensorItem model's output.
        self.assertEqual(only_spec.shape, [100, 100])
65