• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Owner(s): ["module: fx"]
2
3import unittest
4from typing import Mapping
5
6import torch
7from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
8from torch.fx.passes.operator_support import OperatorSupport
9from torch.testing._internal.common_utils import TestCase
10
11
class DummyDevOperatorSupport(OperatorSupport):
    """Operator-support policy for a fake device that claims to support every node."""

    def is_node_supported(
        self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
    ) -> bool:
        """Unconditionally report ``node`` as supported.

        Neither ``submodules`` nor ``node`` is inspected; the answer is always
        ``True``.
        """
        # No filtering at all: this dummy backend accepts whatever it is asked about.
        supported = True
        return supported
17
18
class DummyPartitioner(CapabilityBasedPartitioner):
    """Capability-based partitioner driven by ``DummyDevOperatorSupport``.

    The base partitioner is configured with ``allows_single_node_partition=True``.
    """

    def __init__(self, graph_module: torch.fx.GraphModule):
        """Build a partitioner for ``graph_module`` using the accept-everything policy."""
        accept_all = DummyDevOperatorSupport()
        super().__init__(graph_module, accept_all, allows_single_node_partition=True)
26
27
class AddModule(torch.nn.Module):
    """Tiny module computing ``(x + x) + x`` via two chained ``torch.add`` calls."""

    def forward(self, x):
        """Return ``x`` added to itself, then added to ``x`` once more."""
        doubled = torch.add(x, x)
        return torch.add(doubled, x)
33
34
class TestPartitionerOrder(TestCase):
    """Check that ``CapabilityBasedPartitioner`` proposes partitions deterministically."""

    # partitioner test to check graph node order
    def test_partitioner_order(self):
        """Repeated trace + partition rounds must yield identical node orders.

        The first round establishes a reference list (one list of node names per
        proposed partition); ten further rounds must reproduce it exactly.
        """

        def partition_node_names(module):
            # Trace afresh on every call so each round partitions a newly
            # built GraphModule rather than the previous round's graph.
            traced = torch.fx.symbolic_trace(module)
            partitions = DummyPartitioner(traced).propose_partitions()
            return [[node.name for node in partition.nodes] for partition in partitions]

        m = AddModule()
        expected = partition_node_names(m)
        # Fail clearly (instead of an IndexError or a vacuous pass) if no
        # partition is proposed at all.
        self.assertGreater(len(expected), 0)
        for _ in range(10):
            # assertEqual reports the differing node orders on failure,
            # unlike assertTrue(a == b).
            self.assertEqual(partition_node_names(m), expected)


if __name__ == "__main__":
    unittest.main()
53