# Owner(s): ["oncall: distributed"]

import torch
import torch.nn as nn
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.state_dict import get_state_dict
from torch.distributed.device_mesh import _mesh_resources, init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_if_lt_x_gpu,
    with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
from torch.testing._internal.distributed.common_state_dict import VerifyStateDictMixin


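# Placeholder that fills the expert slots this rank does not own; it has no
# parameters and its forward is never run.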
class DummyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        raise NotImplementedError


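# A single expert: two Linear(16, 16) + ReLU blocks.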
class EPModel(nn.Module):
    def __init__(self, rank):
        super().__init__()
        self.net1 = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
        self.net2 = nn.Sequential(nn.Linear(16, 16), nn.ReLU())

    def forward(self, x):
        raise NotImplementedError


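# Four expert slots plus a shared Linear layer; each rank materializes only
# the expert at index rank % 4 and keeps placeholders in the other slots.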
class SecondTier(nn.Module):
    def __init__(self, rank):
        super().__init__()
        self.ep_layers = nn.ModuleList(
            [EPModel(rank) if rank % 4 == i else DummyModel() for i in range(4)]
        )
        self.net = nn.Sequential(nn.Linear(16, 16), nn.ReLU())

    def forward(self, x):
        raise NotImplementedError


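# Seeds the RNG before building anything so the shared (non-expert) weights
# are identical on every rank.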
class TopModel(nn.Module):
    def __init__(self, rank):
        super().__init__()
        torch.manual_seed(0)

        self.second = SecondTier(rank)
        self.net = nn.Sequential(nn.Linear(16, 16), nn.ReLU())

    def forward(self, x):
        raise NotImplementedError


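# End-to-end test of FSDP composed with expert parallelism on 8 GPUs: checks
# that get_state_dict returns DTensors on the expected device meshes.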
class TestFSDPWithEP(DTensorTestBase, VerifyStateDictMixin):
    @property
    def world_size(self) -> int:
        return min(8, torch.cuda.device_count())

    @with_comms
    @skip_if_lt_x_gpu(8)
    @with_temp_dir
    def test_e2e(self):
        model = TopModel(self.rank).cuda()

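        # Global 2 x 4 mesh over the 8 ranks: "dp" has size 2, "tp" has size 4.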
        mesh_fsdp_tp = init_device_mesh(
            self.device_type, (2, 4), mesh_dim_names=("dp", "tp")
        )
        # TODO: this relies on an internal API for now. Switch to a public API
        # once one is available.
        # Slice out this rank's 2-rank "dp" submesh, then remove the
        # child-to-parent bookkeeping so it behaves like a standalone root mesh.
        mesh_fsdp_ep = _mesh_resources.create_child_mesh(mesh_fsdp_tp, ("dp",))
        del _mesh_resources.child_to_parent_mapping[mesh_fsdp_ep]

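        # Non-expert params are sharded over a flat mesh of all 8 ranks. Wrap
        # each expert over its 2-rank "dp" submesh first (e.g. ranks (1, 5) for
        # rank 1), then the enclosing modules over the full mesh.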
        mesh_fsdp = init_device_mesh(self.device_type, (8,))
        for i, ep_layer in enumerate(model.second.ep_layers):
            model.second.ep_layers[i] = FSDP(
                ep_layer, use_orig_params=True, device_mesh=mesh_fsdp_ep
            )
        model.second = FSDP(model.second, use_orig_params=True, device_mesh=mesh_fsdp)
        model = FSDP(model, use_orig_params=True, device_mesh=mesh_fsdp)
        optim = torch.optim.Adam(model.parameters(), lr=0.1)
        msd, osd = get_state_dict(model, optim)

        # FSDP-only params: these should be DTensors sharded across the full
        # 8-rank mesh, in both the model and the optimizer state dict.
        for key in (
            "net.0.weight",
            "net.0.bias",
            "second.net.0.weight",
            "second.net.0.bias",
        ):
            msd_v = msd[key]
            osd_v = osd["state"][key]["exp_avg"]
            for v in (msd_v, osd_v):
                self.assertIsInstance(v, DTensor)
                self.assertEqual(tuple(v.device_mesh.mesh), tuple(range(8)))

        # FSDP/EP params: expert `layer` exists only on ranks (layer, layer + 4),
        # so its entries appear in the state dict only on those ranks, as
        # DTensors on the 2-rank submesh.
        layer = self.rank % 4
        ranks = (layer, layer + 4)
        for i in range(4):
            for key in (
                f"second.ep_layers.{i}.net1.0.weight",
                f"second.ep_layers.{i}.net1.0.bias",
                f"second.ep_layers.{i}.net2.0.weight",
                f"second.ep_layers.{i}.net2.0.bias",
            ):
                if layer != i:
                    self.assertNotIn(key, msd)
                else:
                    msd_v = msd[key]
                    osd_v = osd["state"][key]["exp_avg"]
                    for v in (msd_v, osd_v):
                        self.assertIsInstance(v, DTensor)
                        self.assertEqual(tuple(v.device_mesh.mesh), ranks)

        # Model and optimizer state dicts must cover the same parameter set.
        self.assertEqual(set(osd["state"].keys()), set(msd.keys()))


if __name__ == "__main__":
    run_tests()