# Owner(s): ["oncall: distributed"]

import sys

import torch
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn import Linear, Module
from torch.nn.parallel import DistributedDataParallel
from torch.optim import SGD
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, get_full_params
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class Model(Module):
    def __init__(self, wrap_fsdp):
        super().__init__()
        # keep everything deterministic for model initialization
        torch.manual_seed(0)
        self.inner = Linear(4, 4)
        if wrap_fsdp:
            self.inner = FSDP(self.inner)
        self.outer = Linear(4, 5)

    def forward(self, x):
        # Run the (possibly FSDP-wrapped) inner module twice in one forward pass.
        i = self.inner(x)
        j = self.inner(x)
        return self.outer(i + j)


class TestMultiForward(FSDPTest):
    def _dist_train(self, wrap_fsdp):
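        # Train the model for a few steps with either FSDP or DDP wrapping
        # and return the resulting full parameters for comparison.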
        # keep everything deterministic for input data
        torch.manual_seed(0)

        model = Model(wrap_fsdp).cuda()
        if wrap_fsdp:
            model = FSDP(model)
        else:
            model = DistributedDataParallel(model, device_ids=[self.rank])
        optim = SGD(model.parameters(), lr=0.1)

        in_data = torch.rand(64, 4).cuda()
        in_data.requires_grad = True
        for _ in range(3):
            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        if wrap_fsdp:
            return get_full_params(model)

        return list(model.parameters())

    @skip_if_lt_x_gpu(2)
    def test_multi_forward(self):
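        # Run identical training with DDP as a reference and with FSDP; the
        # final parameters from the two runs should match.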
        # DDP
        ddp_state = self._dist_train(wrap_fsdp=False)

        # FSDP
        fsdp_state = self._dist_train(wrap_fsdp=True)

        self.assertEqual(ddp_state, fsdp_state)


if __name__ == "__main__":
    run_tests()