# Owner(s): ["module: unknown"] from copy import copy import torch from torch.distributed._tools.mod_tracker import ModTracker from torch.testing._internal.common_utils import run_tests, TestCase, xfailIfTorchDynamo from torch.utils.checkpoint import checkpoint class TestModTracker(TestCase): # "https://github.com/pytorch/pytorch/issues/127112 @xfailIfTorchDynamo def test_module_hierarchy(self): seen_fw = [] seen_bw = [] class Foo(torch.nn.Module): def forward(self, x): x = x["a"].relu_() seen_fw.append((copy(tracker.parents), tracker.is_bw)) x.register_hook( lambda grad: seen_bw.append((copy(tracker.parents), tracker.is_bw)) ) return {"a": torch.mm(x, x)} class Mod(torch.nn.Module): def __init__(self) -> None: super().__init__() self.a = Foo() self.b = torch.nn.ModuleDict({"nest": Foo()}) self.c = torch.nn.ModuleList([Foo()]) def forward(self, x): x = self.c[0](x) return self.b["nest"](self.a(x)) mod = Mod() with ModTracker() as tracker: mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[ "a" ].sum().backward() mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[ "a" ].sum().backward() self.assertEqual( seen_fw, [ ({"Global", "Mod", "Mod.c.0"}, False), ({"Global", "Mod", "Mod.a"}, False), ({"Global", "Mod", "Mod.b.nest"}, False), ({"Global", "Mod", "Mod.c.0"}, False), ({"Global", "Mod", "Mod.a"}, False), ({"Global", "Mod", "Mod.b.nest"}, False), ], ) self.assertEqual( seen_bw, [ ({"Global", "Mod", "Mod.b.nest"}, True), ({"Global", "Mod", "Mod.a"}, True), ({"Global", "Mod", "Mod.c.0"}, True), ({"Global", "Mod", "Mod.b.nest"}, True), ({"Global", "Mod", "Mod.a"}, True), ({"Global", "Mod", "Mod.c.0"}, True), ], ) def test_bw_detection(self): mod = torch.nn.Linear(2, 2) with ModTracker() as tracker: mod(torch.rand(2, requires_grad=True)).sum().backward() self.assertFalse(tracker.is_bw) self.assertEqual(tracker.parents, {"Global"}) @xfailIfTorchDynamo def test_user_hooks(self): class Bar(torch.nn.Module): def __init__(self) -> None: super().__init__() self.foo = torch.nn.Linear(10, 10) def forward(self, x): return self.foo(x).relu_() mt = ModTracker() test_op = [] def hook(mod, hook_name): mfqn = mt.get_known_fqn(mod) if mod is not None else None test_op.append((hook_name, mfqn, mfqn in mt.parents, mt.is_bw)) mod = Bar() mt.register_user_hooks( lambda m, inp: hook(m, "pre_fw"), lambda m, inp, op: hook(m, "post_fw"), lambda m, gop: hook(m, "pre_bw"), lambda m, ginp: hook(m, "post_bw"), ) with mt: mod(torch.rand(10, 10, requires_grad=True)).sum().backward() expected_op = [ ("pre_fw", "Bar", True, False), ("pre_fw", "Bar.foo", True, False), ("post_fw", "Bar.foo", True, False), ("post_fw", "Bar", True, False), ("pre_bw", "Bar", True, True), ("pre_bw", "Bar.foo", True, True), ("post_bw", "Bar", True, True), ("post_bw", "Bar.foo", True, True), ] self.assertEqual(test_op, expected_op) with self.assertRaises(AssertionError): mt.register_user_hooks(lambda x, y: x, None, None, None) test_op.clear() with mt: loss = mod(torch.rand(10, 10, requires_grad=True)).sum() del mod loss.backward() expected_op = [ ("pre_fw", "Bar", True, False), ("pre_fw", "Bar.foo", True, False), ("post_fw", "Bar.foo", True, False), ("post_fw", "Bar", True, False), ("pre_bw", None, False, True), ("pre_bw", None, False, True), ("post_bw", None, False, True), ("post_bw", None, False, True), ] self.assertEqual(test_op, expected_op) @xfailIfTorchDynamo def test_ac(self): class Foo(torch.nn.Module): def __init__(self, n_layers: int, dim: int, use_ac: bool = False): super().__init__() self.linears = torch.nn.ModuleList() 
                self.use_ac = use_ac
                for _ in range(n_layers):
                    self.linears.append(torch.nn.Linear(dim, dim))

            def forward(self, x):
                for i, block in enumerate(self.linears):
                    if i >= 1 and self.use_ac:
                        x = checkpoint(
                            block, x, preserve_rng_state=True, use_reentrant=False
                        )
                    else:
                        x = block(x)
                    assert x is not None
                    x = torch.nn.functional.relu(x)
                return x

        bsz = 2
        dim = 8
        n_layers = 2
        test_op = []

        def hook(mod, mt, hook_name):
            mfqn = mt.get_known_fqn(mod) if mod is not None else None
            test_op.append((hook_name, mfqn, mfqn in mt.parents, mt.is_bw))

        mt = ModTracker()
        mt.register_user_hooks(
            lambda m, i: hook(m, mt, "pre_fw"),
            lambda m, i, o: hook(m, mt, "post_fw"),
            lambda m, go: hook(m, mt, "pre_bw"),
            lambda m, gi: hook(m, mt, "post_bw"),
        )
        model = Foo(n_layers, dim, True)
        x = torch.randn(bsz, dim)
        with mt:
            model(x).sum().backward()
        # Non-reentrant activation checkpointing re-runs the checkpointed
        # block's forward during backward, so pre_fw/post_fw fire again for
        # Foo.linears.1 with is_bw=True.
        expected_op = [
            ("pre_fw", "Foo", True, False),
            ("pre_fw", "Foo.linears.0", True, False),
            ("post_fw", "Foo.linears.0", True, False),
            ("pre_fw", "Foo.linears.1", True, False),
            ("post_fw", "Foo.linears.1", True, False),
            ("post_fw", "Foo", True, False),
            ("pre_bw", "Foo", True, True),
            ("pre_bw", "Foo.linears.1", True, True),
            ("pre_fw", "Foo.linears.1", True, True),
            ("post_fw", "Foo.linears.1", True, True),
            ("post_bw", "Foo.linears.1", True, True),
            ("pre_bw", "Foo.linears.0", True, True),
        ]
        self.assertEqual(test_op, expected_op)


if __name__ == "__main__":
    run_tests()