• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Owner(s): ["module: fx"]
2
3import os
4import tempfile
5
6import torch
7from torch.fx import subgraph_rewriter, symbolic_trace
8from torch.fx.passes.graph_transform_observer import GraphTransformObserver
9from torch.testing._internal.common_utils import TestCase
10
11
if __name__ == "__main__":
    # This module only defines test cases; the message directs users to run
    # them through the aggregate test/test_fx.py entry point instead.
    usage_hint = "".join(
        (
            "This test file is not meant to be run directly, use:\n\n",
            "\tpython test/test_fx.py TESTNAME\n\n",
            "instead.",
        )
    )
    raise RuntimeError(usage_hint)
18
19
20class TestGraphTransformObserver(TestCase):
21    def test_graph_transform_observer(self):
22        class M(torch.nn.Module):
23            def forward(self, x):
24                val = torch.neg(x)
25                return torch.add(val, val)
26
27        def pattern(x):
28            return torch.neg(x)
29
30        def replacement(x):
31            return torch.relu(x)
32
33        traced = symbolic_trace(M())
34
35        log_url = tempfile.mkdtemp()
36
37        with GraphTransformObserver(traced, "replace_neg_with_relu", log_url) as ob:
38            subgraph_rewriter.replace_pattern(traced, pattern, replacement)
39
40            self.assertTrue("relu" in ob.created_nodes)
41            self.assertTrue("neg" in ob.erased_nodes)
42
43        current_pass_count = GraphTransformObserver.get_current_pass_count()
44
45        self.assertTrue(
46            os.path.isfile(
47                os.path.join(
48                    log_url,
49                    f"pass_{current_pass_count}_replace_neg_with_relu_input_graph.dot",
50                )
51            )
52        )
53        self.assertTrue(
54            os.path.isfile(
55                os.path.join(
56                    log_url,
57                    f"pass_{current_pass_count}_replace_neg_with_relu_output_graph.dot",
58                )
59            )
60        )
61