# Owner(s): ["module: unknown"]

from copy import copy

import torch
from torch.distributed._tools.mod_tracker import ModTracker
from torch.testing._internal.common_utils import run_tests, TestCase, xfailIfTorchDynamo
from torch.utils.checkpoint import checkpoint


class TestModTracker(TestCase):
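    """Tests for ModTracker, which tracks which modules are currently
    executing: ``tracker.parents`` holds the FQNs of the active modules
    (plus "Global") and ``tracker.is_bw`` reports whether autograd's
    backward pass is running."""
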
12    # "https://github.com/pytorch/pytorch/issues/127112
13    @xfailIfTorchDynamo
    def test_module_hierarchy(self):
        seen_fw = []
        seen_bw = []

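        # Foo records (parents, is_bw) when its forward runs, and again via
        # a tensor hook when the corresponding backward runs.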
        class Foo(torch.nn.Module):
            def forward(self, x):
                x = x["a"].relu_()
                seen_fw.append((copy(tracker.parents), tracker.is_bw))
                x.register_hook(
                    lambda grad: seen_bw.append((copy(tracker.parents), tracker.is_bw))
                )
                return {"a": torch.mm(x, x)}

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = Foo()
                self.b = torch.nn.ModuleDict({"nest": Foo()})
                self.c = torch.nn.ModuleList([Foo()])

            def forward(self, x):
                x = self.c[0](x)
                return self.b["nest"](self.a(x))

        mod = Mod()

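        # Run forward/backward twice to check that the recorded hierarchy
        # stays stable across iterations.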
        with ModTracker() as tracker:
            mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[
                "a"
            ].sum().backward()
            mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[
                "a"
            ].sum().backward()

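        # Forward entries follow execution order: Mod.c.0, Mod.a, Mod.b.nest.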
        self.assertEqual(
            seen_fw,
            [
                ({"Global", "Mod", "Mod.c.0"}, False),
                ({"Global", "Mod", "Mod.a"}, False),
                ({"Global", "Mod", "Mod.b.nest"}, False),
                ({"Global", "Mod", "Mod.c.0"}, False),
                ({"Global", "Mod", "Mod.a"}, False),
                ({"Global", "Mod", "Mod.b.nest"}, False),
            ],
        )

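        # Backward hooks fire in reverse order relative to the forward pass.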
        self.assertEqual(
            seen_bw,
            [
                ({"Global", "Mod", "Mod.b.nest"}, True),
                ({"Global", "Mod", "Mod.a"}, True),
                ({"Global", "Mod", "Mod.c.0"}, True),
                ({"Global", "Mod", "Mod.b.nest"}, True),
                ({"Global", "Mod", "Mod.a"}, True),
                ({"Global", "Mod", "Mod.c.0"}, True),
            ],
        )

    def test_bw_detection(self):
        mod = torch.nn.Linear(2, 2)

        with ModTracker() as tracker:
            mod(torch.rand(2, requires_grad=True)).sum().backward()
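            # After backward completes, the tracker should no longer report
            # being in backward, and only the global scope remains active.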
            self.assertFalse(tracker.is_bw)
            self.assertEqual(tracker.parents, {"Global"})

    @xfailIfTorchDynamo
    def test_user_hooks(self):
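        # register_user_hooks installs pre/post-forward and pre/post-backward
        # callbacks that receive the module being executed.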
        class Bar(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Linear(10, 10)

            def forward(self, x):
                return self.foo(x).relu_()

        mt = ModTracker()
        test_op = []

        def hook(mod, hook_name):
            mfqn = mt.get_known_fqn(mod) if mod is not None else None
            test_op.append((hook_name, mfqn, mfqn in mt.parents, mt.is_bw))

        mod = Bar()

        mt.register_user_hooks(
            lambda m, inp: hook(m, "pre_fw"),
            lambda m, inp, op: hook(m, "post_fw"),
            lambda m, gop: hook(m, "pre_bw"),
            lambda m, ginp: hook(m, "post_bw"),
        )
        with mt:
            mod(torch.rand(10, 10, requires_grad=True)).sum().backward()
        expected_op = [
            ("pre_fw", "Bar", True, False),
            ("pre_fw", "Bar.foo", True, False),
            ("post_fw", "Bar.foo", True, False),
            ("post_fw", "Bar", True, False),
            ("pre_bw", "Bar", True, True),
            ("pre_bw", "Bar.foo", True, True),
            ("post_bw", "Bar", True, True),
            ("post_bw", "Bar.foo", True, True),
        ]
        self.assertEqual(test_op, expected_op)

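        # Hooks can only be registered once; registering again raises.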
        with self.assertRaises(AssertionError):
            mt.register_user_hooks(lambda x, y: x, None, None, None)

        test_op.clear()
        with mt:
            loss = mod(torch.rand(10, 10, requires_grad=True)).sum()
            del mod
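            # With the module deleted before backward runs, the backward
            # hooks receive None in place of the module.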
            loss.backward()
        expected_op = [
            ("pre_fw", "Bar", True, False),
            ("pre_fw", "Bar.foo", True, False),
            ("post_fw", "Bar.foo", True, False),
            ("post_fw", "Bar", True, False),
            ("pre_bw", None, False, True),
            ("pre_bw", None, False, True),
            ("post_bw", None, False, True),
            ("post_bw", None, False, True),
        ]
        self.assertEqual(test_op, expected_op)

    @xfailIfTorchDynamo
    def test_ac(self):
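        # With activation checkpointing, a checkpointed block's forward is
        # re-run during backward, so its forward hooks fire with is_bw=True.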
        class Foo(torch.nn.Module):
            def __init__(self, n_layers: int, dim: int, use_ac: bool = False):
                super().__init__()
                self.linears = torch.nn.ModuleList()
                self.use_ac = use_ac
                for _ in range(n_layers):
                    self.linears.append(torch.nn.Linear(dim, dim))

            def forward(self, x):
                for i, block in enumerate(self.linears):
                    if i >= 1 and self.use_ac:
                        x = checkpoint(
                            block, x, preserve_rng_state=True, use_reentrant=False
                        )
                    else:
                        x = block(x)
                    assert x is not None
                    x = torch.nn.functional.relu(x)
                return x

        bsz = 2
        dim = 8
        n_layers = 2
        test_op = []

        def hook(mod, mt, hook_name):
            mfqn = mt.get_known_fqn(mod) if mod is not None else None
            test_op.append((hook_name, mfqn, mfqn in mt.parents, mt.is_bw))

        mt = ModTracker()
        mt.register_user_hooks(
            lambda m, i: hook(m, mt, "pre_fw"),
            lambda m, i, o: hook(m, mt, "post_fw"),
            lambda m, go: hook(m, mt, "pre_bw"),
            lambda m, gi: hook(m, mt, "post_bw"),
        )
        model = Foo(n_layers, dim, True)
        x = torch.randn(bsz, dim)
        with mt:
            model(x).sum().backward()

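        # Foo.linears.1 is checkpointed, so its forward hooks fire again
        # during backward (is_bw=True) when the block is recomputed.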
        expected_op = [
            ("pre_fw", "Foo", True, False),
            ("pre_fw", "Foo.linears.0", True, False),
            ("post_fw", "Foo.linears.0", True, False),
            ("pre_fw", "Foo.linears.1", True, False),
            ("post_fw", "Foo.linears.1", True, False),
            ("post_fw", "Foo", True, False),
            ("pre_bw", "Foo", True, True),
            ("pre_bw", "Foo.linears.1", True, True),
            ("pre_fw", "Foo.linears.1", True, True),
            ("post_fw", "Foo.linears.1", True, True),
            ("post_bw", "Foo.linears.1", True, True),
            ("pre_bw", "Foo.linears.0", True, True),
        ]
        self.assertEqual(test_op, expected_op)


if __name__ == "__main__":
    run_tests()