# Owner(s): ["module: unknown"]

from copy import copy

import torch
from torch.distributed._tools.mod_tracker import ModTracker
from torch.testing._internal.common_utils import run_tests, TestCase, xfailIfTorchDynamo
from torch.utils.checkpoint import checkpoint


class TestModTracker(TestCase):
    # https://github.com/pytorch/pytorch/issues/127112
    @xfailIfTorchDynamo
    def test_module_hierarchy(self):
        seen_fw = []
        seen_bw = []

        class Foo(torch.nn.Module):
            def forward(self, x):
                x = x["a"].relu_()
                # Record the set of active parent modules during the forward pass.
                seen_fw.append((copy(tracker.parents), tracker.is_bw))
                # The gradient hook fires during backward, so it records the
                # parent set seen by the backward pass.
                x.register_hook(
                    lambda grad: seen_bw.append((copy(tracker.parents), tracker.is_bw))
                )
                return {"a": torch.mm(x, x)}

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = Foo()
                self.b = torch.nn.ModuleDict({"nest": Foo()})
                self.c = torch.nn.ModuleList([Foo()])

            def forward(self, x):
                x = self.c[0](x)
                return self.b["nest"](self.a(x))

        mod = Mod()

        with ModTracker() as tracker:
            # Run forward + backward twice to check that the tracked state
            # resets correctly between iterations.
            mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[
                "a"
            ].sum().backward()
            mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[
                "a"
            ].sum().backward()

        self.assertEqual(
            seen_fw,
            [
                ({"Global", "Mod", "Mod.c.0"}, False),
                ({"Global", "Mod", "Mod.a"}, False),
                ({"Global", "Mod", "Mod.b.nest"}, False),
                ({"Global", "Mod", "Mod.c.0"}, False),
                ({"Global", "Mod", "Mod.a"}, False),
                ({"Global", "Mod", "Mod.b.nest"}, False),
            ],
        )

        # Backward visits the modules in reverse order of the forward pass.
        self.assertEqual(
            seen_bw,
            [
                ({"Global", "Mod", "Mod.b.nest"}, True),
                ({"Global", "Mod", "Mod.a"}, True),
                ({"Global", "Mod", "Mod.c.0"}, True),
                ({"Global", "Mod", "Mod.b.nest"}, True),
                ({"Global", "Mod", "Mod.a"}, True),
                ({"Global", "Mod", "Mod.c.0"}, True),
            ],
        )

    def test_bw_detection(self):
        mod = torch.nn.Linear(2, 2)

        with ModTracker() as tracker:
            mod(torch.rand(2, requires_grad=True)).sum().backward()
            # Once backward has finished, the tracker should report forward
            # state with only the global scope active.
            self.assertFalse(tracker.is_bw)
            self.assertEqual(tracker.parents, {"Global"})

    @xfailIfTorchDynamo
    def test_user_hooks(self):
        class Bar(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Linear(10, 10)

            def forward(self, x):
                return self.foo(x).relu_()

        mt = ModTracker()
        test_op = []

        def hook(mod, hook_name):
            # Record (hook name, module FQN, whether the module is an active
            # parent, backward flag).
            mfqn = mt.get_known_fqn(mod) if mod is not None else None
            test_op.append((hook_name, mfqn, mfqn in mt.parents, mt.is_bw))

        mod = Bar()

        mt.register_user_hooks(
            lambda m, inp: hook(m, "pre_fw"),
            lambda m, inp, op: hook(m, "post_fw"),
            lambda m, gop: hook(m, "pre_bw"),
            lambda m, ginp: hook(m, "post_bw"),
        )
        with mt:
            mod(torch.rand(10, 10, requires_grad=True)).sum().backward()
        expected_op = [
            ("pre_fw", "Bar", True, False),
            ("pre_fw", "Bar.foo", True, False),
            ("post_fw", "Bar.foo", True, False),
            ("post_fw", "Bar", True, False),
            ("pre_bw", "Bar", True, True),
            ("pre_bw", "Bar.foo", True, True),
            ("post_bw", "Bar", True, True),
            ("post_bw", "Bar.foo", True, True),
        ]
        self.assertEqual(test_op, expected_op)

        # Registering user hooks while hooks are already installed should fail.
        with self.assertRaises(AssertionError):
            mt.register_user_hooks(lambda x, y: x, None, None, None)

        test_op.clear()
        with mt:
            loss = mod(torch.rand(10, 10, requires_grad=True)).sum()
            # Deleting the module before backward means its FQN can no longer
            # be resolved when the backward hooks fire.
            del mod
            loss.backward()
        expected_op = [
            ("pre_fw", "Bar", True, False),
            ("pre_fw", "Bar.foo", True, False),
            ("post_fw", "Bar.foo", True, False),
            ("post_fw", "Bar", True, False),
            ("pre_bw", None, False, True),
            ("pre_bw", None, False, True),
            ("post_bw", None, False, True),
            ("post_bw", None, False, True),
        ]
        self.assertEqual(test_op, expected_op)

    @xfailIfTorchDynamo
    def test_ac(self):
        class Foo(torch.nn.Module):
            def __init__(self, n_layers: int, dim: int, use_ac: bool = False):
                super().__init__()
                self.linears = torch.nn.ModuleList()
                self.use_ac = use_ac
                for _ in range(n_layers):
                    self.linears.append(torch.nn.Linear(dim, dim))

            def forward(self, x):
                for i, block in enumerate(self.linears):
                    if i >= 1 and self.use_ac:
                        x = checkpoint(
                            block, x, preserve_rng_state=True, use_reentrant=False
                        )
                    else:
                        x = block(x)
                    assert x is not None
                    x = torch.nn.functional.relu(x)
                return x

        bsz = 2
        dim = 8
        n_layers = 2
        test_op = []

        def hook(mod, mt, hook_name):
            mfqn = mt.get_known_fqn(mod) if mod is not None else None
            test_op.append((hook_name, mfqn, mfqn in mt.parents, mt.is_bw))

        mt = ModTracker()
        mt.register_user_hooks(
            lambda m, i: hook(m, mt, "pre_fw"),
            lambda m, i, o: hook(m, mt, "post_fw"),
            lambda m, go: hook(m, mt, "pre_bw"),
            lambda m, gi: hook(m, mt, "post_bw"),
        )
        model = Foo(n_layers, dim, True)
        x = torch.randn(bsz, dim)
        with mt:
            model(x).sum().backward()

        # With activation checkpointing, the checkpointed layer is recomputed
        # during backward, so pre_fw/post_fw events for Foo.linears.1 appear
        # again with is_bw=True.
        expected_op = [
            ("pre_fw", "Foo", True, False),
            ("pre_fw", "Foo.linears.0", True, False),
            ("post_fw", "Foo.linears.0", True, False),
            ("pre_fw", "Foo.linears.1", True, False),
            ("post_fw", "Foo.linears.1", True, False),
            ("post_fw", "Foo", True, False),
            ("pre_bw", "Foo", True, True),
            ("pre_bw", "Foo.linears.1", True, True),
            ("pre_fw", "Foo.linears.1", True, True),
            ("post_fw", "Foo.linears.1", True, True),
            ("post_bw", "Foo.linears.1", True, True),
            ("pre_bw", "Foo.linears.0", True, True),
        ]
        self.assertEqual(test_op, expected_op)


if __name__ == "__main__":
    run_tests()