1# Owner(s): ["module: fx"] 2 3import unittest 4from typing import Mapping 5 6import torch 7from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner 8from torch.fx.passes.operator_support import OperatorSupport 9from torch.testing._internal.common_utils import TestCase 10 11 12class DummyDevOperatorSupport(OperatorSupport): 13 def is_node_supported( 14 self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node 15 ) -> bool: 16 return True 17 18 19class DummyPartitioner(CapabilityBasedPartitioner): 20 def __init__(self, graph_module: torch.fx.GraphModule): 21 super().__init__( 22 graph_module, 23 DummyDevOperatorSupport(), 24 allows_single_node_partition=True, 25 ) 26 27 28class AddModule(torch.nn.Module): 29 def forward(self, x): 30 y = torch.add(x, x) 31 z = torch.add(y, x) 32 return z 33 34 35class TestPartitionerOrder(TestCase): 36 # partitoner test to check graph node order 37 def test_partitioner_order(self): 38 m = AddModule() 39 traced_m = torch.fx.symbolic_trace(m) 40 partions = DummyPartitioner(traced_m).propose_partitions() 41 partion_nodes = [list(partition.nodes) for partition in partions] 42 node_order = [n.name for n in partion_nodes[0]] 43 for _ in range(10): 44 traced_m = torch.fx.symbolic_trace(m) 45 new_partion = DummyPartitioner(traced_m).propose_partitions() 46 new_partion_nodes = [list(partition.nodes) for partition in new_partion] 47 new_node_order = [n.name for n in new_partion_nodes[0]] 48 self.assertTrue(node_order == new_node_order) 49 50 51if __name__ == "__main__": 52 unittest.main() 53