# Owner(s): ["oncall: pt2"]
import dataclasses
import functools

import torch
from torch import nn
from torch._dynamo import compiled_autograd
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import CompileCounter
from torch.testing._internal.common_utils import IS_MACOS, skipIfRocm, skipIfXpu
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, requires_gpu


# Fake distributed
WORLD_SIZE = 2


def init_fake_distributed(device="cpu"):
    @torch.no_grad
    def all_gather(t):
        return torch.cat([t] * WORLD_SIZE, 0)

    @torch.no_grad
    def reduce_scatter(t):
        # clone since reduce_scatter input and output should not be aliases.
        return t.narrow(0, 0, t.size(0) // WORLD_SIZE).clone()

    def fw_pre_hook(mod, inp):
        if not compiled_autograd.compiled_autograd_enabled:
            # torch.ops.fsdp.set_ doesn't work well in eager mode, so use the slow copy_ path instead.
            mod.unsharded_weight.untyped_storage().resize_(
                mod.unsharded_weight.nelement() * mod.unsharded_weight.element_size()
            )
            with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
                mod.unsharded_weight
            ):
                mod.unsharded_weight.copy_(all_gather(mod.sharded_weight))
        else:
            with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
                mod.unsharded_weight
            ):
                torch.ops.fsdp.set_(
                    mod.unsharded_weight, all_gather(mod.sharded_weight)
                )
        mod._parameters["weight"] = mod.unsharded_weight

    # Forward:
    #   mod.sharded_weight = local_shard (always)
    #   Before:
    #     mod.weight = local_shard
    #     mod.unsharded_weight = zero-sized allgather
    #   After:
    #     mod.weight = local_shard
    #     mod.unsharded_weight = zero-sized allgather
    def fw_post_hook(mod, inp, out):
        mod._parameters["weight"] = mod.sharded_weight
        mod.unsharded_weight.untyped_storage().resize_(0)

    def bw_pre_hook(mod, gO):
        if not compiled_autograd.compiled_autograd_enabled:
            # torch.ops.fsdp.set_ doesn't work well in eager mode, so use the slow copy_ path instead.
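            # NOTE: resizing the storage back up and copy_-ing the all-gather result
            # into it mimics FSDP's unshard step; the post-hooks resize the storage
            # back to zero so the unsharded weight only lives for the duration of
            # the forward/backward pass.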
            mod.unsharded_weight.untyped_storage().resize_(
                mod.unsharded_weight.nelement() * mod.unsharded_weight.element_size()
            )
            with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
                mod.unsharded_weight
            ):
                mod.unsharded_weight.copy_(all_gather(mod.sharded_weight))
        else:
            with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
                mod.unsharded_weight
            ):
                torch.ops.fsdp.set_(
                    mod.unsharded_weight, all_gather(mod.sharded_weight)
                )
        mod._parameters["weight"] = mod.unsharded_weight

    # Backward:
    #   mod.sharded_weight = local_shard (always)
    #   Before:
    #     mod.weight = local_shard
    #     mod.unsharded_weight = zero-sized allgather
    #   After:
    #     mod.weight = local_shard
    #     mod.unsharded_weight = zero-sized allgather
    def bw_post_hook(mod, gI, gO):
        grad = mod.weight.grad
        new_grad = reduce_scatter(grad)
        mod._parameters["weight"] = mod.sharded_weight
        mod.weight.grad = new_grad
        mod.unsharded_weight.untyped_storage().resize_(0)

    torch.manual_seed(1234)
    m = nn.Linear(20, 10, bias=False, device=device)

    # Mimics eager 1st iteration
    m.sharded_weight = nn.Parameter(reduce_scatter(m.weight))
    m.unsharded_weight = nn.Parameter(all_gather(m.sharded_weight))
    m.unsharded_weight.untyped_storage().resize_(0)

    m.register_full_backward_pre_hook(bw_pre_hook)
    m.register_full_backward_hook(bw_post_hook)
    m.register_forward_pre_hook(fw_pre_hook)
    m.register_forward_hook(fw_post_hook)
    return m, torch.rand(2, 20, requires_grad=True, device=device)


def init_module_bw_hooks(allow_eager):
    def bw_pre_hook(mod, gO):
        assert allow_eager or torch._dynamo.is_compiling()
        assert mod.weight.size() == (10, 10)
        mod.hook_count_pre.add_(1)
        return (torch.sin(gO[0] + 1.2),)

    def bw_post_hook(mod, gI, gO):
        assert allow_eager or torch._dynamo.is_compiling()
        assert mod.weight.size() == (10, 10)
        mod.hook_count_post.add_(1)
        return (torch.sin(gI[0] + 3.4),)

    torch.manual_seed(1234)
    m = nn.Linear(10, 10)
    m.hook_count_pre = torch.tensor(0)
    m.hook_count_post = torch.tensor(0)
    m.register_full_backward_pre_hook(bw_pre_hook)
    m.register_full_backward_hook(bw_post_hook)
    return m, torch.rand(2, 10, requires_grad=True)


def steps(m, inp):
    for _ in range(4):
        out = m(inp)
        out.sum().backward()
    return out


class DistributedPatternTests(TestCase):
    def test_intermediate_hook_with_closure(self):
        @dataclasses.dataclass
        class CustomObj:
            val: torch.Tensor

        def fn(x, obj):
            y = x.sin()
            closure_var = y + 1
            y.register_hook(lambda grad: grad + obj.val + closure_var)
            z = y.sin()
            return z

        opt = torch.compile(fn, fullgraph=True)

        obj1 = CustomObj(torch.tensor(88))
        obj2 = CustomObj(torch.tensor(99))

        x0 = torch.ones(4, requires_grad=True)
        x1 = torch.ones(4, requires_grad=True)
        x2 = torch.ones(4, requires_grad=True)
        x3 = torch.ones(4, requires_grad=True)

        fn(x0, obj1).sum().backward()
        fn(x1, obj2).sum().backward()

        with compiled_autograd.enable(
            functools.partial(torch.compile, fullgraph=True)
        ):
            opt(x2, obj1).sum().backward()
            opt(x3, obj2).sum().backward()

        self.assertEqual(x0.grad, x2.grad)
        self.assertEqual(x1.grad, x3.grad)

    @torch.no_grad()
    def _test_storage_resize_zero(self, device):
        @torch.compile(fullgraph=True)
        def fn(x):
            y = torch.sin(x)
            x.untyped_storage().resize_(0)
            return torch.cos(y)

        x = torch.randn(10, device=device)
        expected = torch.cos(torch.sin(x))
        y = fn(x)
        self.assertEqual(y, expected)
        self.assertEqual(x.untyped_storage().size(), 0)

    def test_storage_resize_zero_cpu(self):
        self._test_storage_resize_zero("cpu")

    @skipIfRocm
    @requires_gpu()
    def test_storage_resize_zero_gpu(self):
        self._test_storage_resize_zero(GPU_TYPE)

    @torch.no_grad()
    def _test_storage_resize_nonzero(self, device):
        @torch.compile(fullgraph=True)
        def fn(x, out):
            y = torch.sin(x)
            assert out.untyped_storage().size() == 0
            out.untyped_storage().resize_(x.untyped_storage().size())
            out.copy_(y.cos())

        x = torch.randn(10, device=device)
        out = torch.randn(10, device=device)
        expected = torch.cos(torch.sin(x))
        out.untyped_storage().resize_(0)
        fn(x, out)
        self.assertEqual(out.untyped_storage().size(), x.untyped_storage().size())
        self.assertEqual(out, expected)

    def test_storage_resize_nonzero_cpu(self):
        self._test_storage_resize_nonzero("cpu")

    @skipIfRocm
    @requires_gpu()
    def test_storage_resize_nonzero_gpu(self):
        self._test_storage_resize_nonzero(GPU_TYPE)

    @torch.no_grad()
    def test_unsafe_set_version_counter1(self):
        cnt = CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(w, x):
            x = x.sin()
            v = w._version
            w.copy_(x + 1)
            torch._C._autograd._unsafe_set_version_counter(w, v)
            return w, v

        for v in (3, 0, 1):
            w1 = torch.randn(16)
            for i in range(v):
                w1.fill_(i)  # bump w1._version
            self.assertEqual(w1._version, v)
            x1 = torch.randn(16)
            w2, v2 = fn(w1, x1)
            self.assertIs(w1, w2)
            self.assertEqual(w1, x1.sin() + 1)
            self.assertEqual(v2, v)
            self.assertEqual(w1._version, v)
            self.assertEqual(cnt.frame_count, 1)

    def test_unsafe_set_version_counter2(self):
        @torch.compile(backend="inductor", fullgraph=True)
        def fn(w, x):
            r = w.sin()
            with torch.no_grad():
                v = w._version
                w.copy_(x)
                torch._C._autograd._unsafe_set_version_counter(w, v)
            return r

        w1 = torch.randn(1, requires_grad=True)
        x1 = torch.randn(1)
        expected_r1 = w1.detach().sin()

        r1 = fn(w1, x1)
        r1.backward()
        self.assertEqual(r1, expected_r1)
        self.assertEqual(w1, x1)
        self.assertEqual(w1.grad, x1.cos())

    @torch.no_grad()
    def test_unsafe_preserve_version_counter1(self):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(w, x):
            x = x.sin()
            with torch.autograd._unsafe_preserve_version_counter(w):
                w.copy_(x + 1)
            return w

        w1 = torch.randn(16).fill_(0).fill_(1)
        x1 = torch.randn(16)
        v1 = w1._version
        w2 = fn(w1, x1)
        v2 = w1._version
        self.assertIs(w1, w2)
        self.assertEqual(w1, x1.sin() + 1)
        self.assertEqual(v1, v2)

    def test_unsafe_preserve_version_counter2(self):
        @torch.compile(backend="inductor", fullgraph=True)
        def fn(w, x):
            r = w.sin()
            with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(w):
                w.copy_(x)
            return r

        w1 = torch.randn(1, requires_grad=True)
        x1 = torch.randn(1)
        expected_r1 = w1.detach().sin()

        r1 = fn(w1, x1)
        r1.backward()
        self.assertEqual(r1, expected_r1)
        self.assertEqual(w1, x1)
        self.assertEqual(w1.grad, x1.cos())

    def test_module_backward_hooks_eager(self):
        m1, inp1 = init_module_bw_hooks(True)
        out1 = steps(m1, inp1)

        m2, inp2 = init_module_bw_hooks(False)
        fw_cnt = CompileCounter()
        bw_cnt = CompileCounter()
        with compiled_autograd.enable(torch.compile(backend=bw_cnt, fullgraph=True)):
            m2 = torch.compile(m2, backend=fw_cnt, fullgraph=True)
            out2 = steps(m2, inp2)

        self.assertEqual(m1.hook_count_pre, m2.hook_count_pre)
        self.assertEqual(m1.hook_count_post, m2.hook_count_post)
        self.assertEqual(out1, out2)
        self.assertEqual(inp1.grad, inp2.grad)
        self.assertEqual(m1.weight.grad, m2.weight.grad)
        self.assertEqual(m1.bias.grad, m2.bias.grad)

        self.assertEqual(fw_cnt.frame_count, 1)
        self.assertEqual(fw_cnt.op_count, 5)
        self.assertEqual(bw_cnt.frame_count, 2)  # grad=None and grad!=None
        self.assertEqual(bw_cnt.op_count, 48)

    def test_module_backward_hooks_aot(self):
        m1, inp1 = init_module_bw_hooks(True)
        out1 = steps(m1, inp1)

        m2, inp2 = init_module_bw_hooks(True)
        m2 = torch.compile(m2, backend="aot_eager", fullgraph=True)
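        # NOTE: `lambda gm: gm` returns the captured backward graph unchanged, so
        # compiled autograd executes it as a plain GraphModule; this exercises the
        # aot_eager forward path without also compiling the backward graph.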
        with compiled_autograd.enable(lambda gm: gm):
            out2 = steps(m2, inp2)

        self.assertEqual(m1.hook_count_pre, m2.hook_count_pre)
        self.assertEqual(m1.hook_count_post, m2.hook_count_post)
        self.assertEqual(out1, out2)
        self.assertEqual(inp1.grad, inp2.grad)
        self.assertEqual(m1.weight.grad, m2.weight.grad)
        self.assertEqual(m1.bias.grad, m2.bias.grad)

    def test_module_backward_hooks_inductor(self):
        m1, inp1 = init_module_bw_hooks(True)
        out1 = steps(m1, inp1)

        m2, inp2 = init_module_bw_hooks(False)
        m2 = torch.compile(m2, fullgraph=True)
        with compiled_autograd.enable(torch.compile(fullgraph=True)):
            out2 = steps(m2, inp2)

        self.assertEqual(m1.hook_count_pre, m2.hook_count_pre)
        self.assertEqual(m1.hook_count_post, m2.hook_count_post)
        self.assertEqual(out1, out2)
        self.assertEqual(inp1.grad, inp2.grad)
        self.assertEqual(m1.weight.grad, m2.weight.grad)
        self.assertEqual(m1.bias.grad, m2.bias.grad)

    def test_module_backward_hooks_multi_layers(self):
        a1, inp1 = init_module_bw_hooks(True)
        b1, _ = init_module_bw_hooks(True)
        out1 = steps(torch.nn.Sequential(a1, b1), inp1)

        a2, inp2 = init_module_bw_hooks(False)
        b2, _ = init_module_bw_hooks(False)
        with compiled_autograd.enable(torch.compile(fullgraph=True)):
            out2 = steps(
                torch.compile(torch.nn.Sequential(a2, b2), fullgraph=True), inp2
            )

        self.assertEqual(a1.hook_count_pre, a2.hook_count_pre)
        self.assertEqual(a1.hook_count_post, a2.hook_count_post)
        self.assertEqual(b1.hook_count_pre, b2.hook_count_pre)
        self.assertEqual(b1.hook_count_post, b2.hook_count_post)
        self.assertEqual(out1, out2)
        self.assertEqual(inp1.grad, inp2.grad)
        self.assertEqual(a1.weight.grad, a2.weight.grad)
        self.assertEqual(a1.bias.grad, a2.bias.grad)
        self.assertEqual(b1.weight.grad, b2.weight.grad)
        self.assertEqual(b1.bias.grad, b2.bias.grad)

    # TODO(jansel): support bw hooks with graph break

    def _assert_same_grad(self, a, b):
        self.assertEqual(type(a), type(b))
        self.assertEqual(a, b)
        self.assertEqual(a.grad, b.grad)
        self.assertEqual(a.requires_grad, b.requires_grad)

    def test_nn_param_return1(self):
        def fn(x):
            p = torch.nn.Parameter(x)
            return p, p.sin()

        opt = torch.compile(fn, fullgraph=True)
        x1 = torch.randn(16)
        x2 = x1.clone()

        p1, r1 = fn(x1)
        r1.sum().backward()
        p2, r2 = opt(x2)
        r2.sum().backward()
        self._assert_same_grad(r1, r2)
        self._assert_same_grad(p1, p2)

    def test_nn_param_return2(self):
        def fn(x):
            p = torch.nn.Parameter(x, requires_grad=False)
            return p, x + 1

        opt = torch.compile(fn, fullgraph=True)
        x1 = torch.randn(16)
        x2 = x1.clone()

        p1, r1 = fn(x1)
        p2, r2 = opt(x2)
        self._assert_same_grad(r1, r2)
        self._assert_same_grad(p1, p2)

    def test_nn_param_return3(self):
        def fn(x):
            p = torch.nn.Parameter(x + 123)
            return p, p.sin()

        opt = torch.compile(fn, fullgraph=True)
        x1 = torch.randn(16)
        x2 = x1.clone()

        p1, r1 = fn(x1)
        r1.sum().backward()
        p2, r2 = opt(x2)
        r2.sum().backward()
        self._assert_same_grad(r1, r2)
        self._assert_same_grad(p1, p2)

    def test_nn_param_return4(self):
        def fn(x):
            p = torch.nn.Parameter(x + 123, requires_grad=False)
            return p, x + 1

        opt = torch.compile(fn, fullgraph=True)
        x1 = torch.randn(16)
        x2 = x1.clone()

        p1, r1 = fn(x1)
        p2, r2 = opt(x2)
        self._assert_same_grad(r1, r2)
        self._assert_same_grad(p1, p2)

    @torch._functorch.config.patch(recompute_views=True)
    def test_fake_distributed_aot_eager(self):
        m1, inp1 = init_fake_distributed()
        out1 = steps(m1, inp1)

        m2, inp2 = init_fake_distributed()
        m2 = torch.compile(m2, backend="aot_eager", fullgraph=True)
        bw_cnt = CompileCounter()
        with compiled_autograd.enable(torch.compile(backend=bw_cnt, fullgraph=True)):
            out2 = steps(m2, inp2)
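
        # The compiled module must reproduce the eager fake-distributed run:
        # same sharded weight gradient, input gradient, and output.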
        self._assert_same_grad(m1.weight, m2.weight)
        self._assert_same_grad(inp1, inp2)
        self._assert_same_grad(out1, out2)
        # Recompile on grad==None/grad!=None
        self.assertEqual(bw_cnt.frame_count, 2)

    @skipIfRocm
    @skipIfXpu
    @requires_gpu()
    @torch._functorch.config.patch(recompute_views=True)
    def test_fake_distributed_inductor(self):
        # TODO: fix .set_ lowering in CPU inductor, and enable the CPU test.
        m1, inp1 = init_fake_distributed(GPU_TYPE)
        out1 = steps(m1, inp1)

        m2, inp2 = init_fake_distributed(GPU_TYPE)
        m2 = torch.compile(m2, fullgraph=True)
        with compiled_autograd.enable(torch.compile(fullgraph=True)):
            out2 = steps(m2, inp2)

        self._assert_same_grad(m1.weight, m2.weight)
        self._assert_same_grad(inp1, inp2)
        self._assert_same_grad(out1, out2)


if __name__ == "__main__":
    if HAS_CPU and not IS_MACOS:
        run_tests(needs="filelock")