1# Owner(s): ["module: fx"] 2 3import os 4import tempfile 5 6import torch 7from torch.fx import subgraph_rewriter, symbolic_trace 8from torch.fx.passes.graph_transform_observer import GraphTransformObserver 9from torch.testing._internal.common_utils import TestCase 10 11 12if __name__ == "__main__": 13 raise RuntimeError( 14 "This test file is not meant to be run directly, use:\n\n" 15 "\tpython test/test_fx.py TESTNAME\n\n" 16 "instead." 17 ) 18 19 20class TestGraphTransformObserver(TestCase): 21 def test_graph_transform_observer(self): 22 class M(torch.nn.Module): 23 def forward(self, x): 24 val = torch.neg(x) 25 return torch.add(val, val) 26 27 def pattern(x): 28 return torch.neg(x) 29 30 def replacement(x): 31 return torch.relu(x) 32 33 traced = symbolic_trace(M()) 34 35 log_url = tempfile.mkdtemp() 36 37 with GraphTransformObserver(traced, "replace_neg_with_relu", log_url) as ob: 38 subgraph_rewriter.replace_pattern(traced, pattern, replacement) 39 40 self.assertTrue("relu" in ob.created_nodes) 41 self.assertTrue("neg" in ob.erased_nodes) 42 43 current_pass_count = GraphTransformObserver.get_current_pass_count() 44 45 self.assertTrue( 46 os.path.isfile( 47 os.path.join( 48 log_url, 49 f"pass_{current_pass_count}_replace_neg_with_relu_input_graph.dot", 50 ) 51 ) 52 ) 53 self.assertTrue( 54 os.path.isfile( 55 os.path.join( 56 log_url, 57 f"pass_{current_pass_count}_replace_neg_with_relu_output_graph.dot", 58 ) 59 ) 60 ) 61