# Owner(s): ["oncall: distributed"]

import torch
import torch.nn as nn
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.state_dict import get_state_dict
from torch.distributed.device_mesh import _mesh_resources, init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_if_lt_x_gpu,
    with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
from torch.testing._internal.distributed.common_state_dict import VerifyStateDictMixin


class Dummymodel(nn.Module):
    """Parameter-free placeholder for expert layers not owned by this rank."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        raise NotImplementedError


class EPModel(nn.Module):
    """A single expert-parallel (EP) layer owned by a subset of ranks."""

    def __init__(self, rank):
        super().__init__()
        self.net1 = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
        self.net2 = nn.Sequential(nn.Linear(16, 16), nn.ReLU())

    def forward(self, x):
        raise NotImplementedError


class SecondTier(nn.Module):
    def __init__(self, rank):
        super().__init__()
        # Each rank materializes only the expert layer it owns (rank % 4);
        # the remaining slots are parameter-free placeholders.
        self.ep_layers = nn.ModuleList(
            [EPModel(rank) if rank % 4 == i else Dummymodel() for i in range(4)]
        )
        self.net = nn.Sequential(nn.Linear(16, 16), nn.ReLU())

    def forward(self, x):
        raise NotImplementedError


class TopModel(nn.Module):
    def __init__(self, rank):
        super().__init__()
        torch.manual_seed(0)

        self.second = SecondTier(rank)
        self.net = nn.Sequential(nn.Linear(16, 16), nn.ReLU())

    def forward(self, x):
        raise NotImplementedError


class TestFSDPWithEP(DTensorTestBase, VerifyStateDictMixin):
    @property
    def world_size(self) -> int:
        return min(8, torch.cuda.device_count())

    @with_comms
    @skip_if_lt_x_gpu(8)
    @with_temp_dir
    def test_e2e(self):
        model = TopModel(self.rank).cuda()

        mesh_fsdp_tp = init_device_mesh(
            self.device_type, (2, 4), mesh_dim_names=("dp", "tp")
        )
        # TODO: we are using an internal API for now. Change to a public API once it is ready.
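        # The "dp" child mesh groups the two ranks that share the same tp
        # coordinate (rank % 4, rank % 4 + 4); the owned EP layer is
        # FSDP-sharded only across that 2-rank submesh.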
        mesh_fsdp_ep = _mesh_resources.create_child_mesh(mesh_fsdp_tp, ("dp",))
        del _mesh_resources.child_to_parent_mapping[mesh_fsdp_ep]

        mesh_fsdp = init_device_mesh(self.device_type, (8,))
        for i, ep_layer in enumerate(model.second.ep_layers):
            model.second.ep_layers[i] = FSDP(
                ep_layer, use_orig_params=True, device_mesh=mesh_fsdp_ep
            )
        model.second = FSDP(model.second, use_orig_params=True, device_mesh=mesh_fsdp)
        model = FSDP(model, use_orig_params=True, device_mesh=mesh_fsdp)
        optim = torch.optim.Adam(model.parameters(), lr=0.1)
        msd, osd = get_state_dict(model, optim)

        # FSDP-only params: DTensors placed on the full 8-rank mesh.
        for key in (
            "net.0.weight",
            "net.0.bias",
            "second.net.0.weight",
            "second.net.0.bias",
        ):
            msd_v = msd[key]
            osd_v = osd["state"][key]["exp_avg"]
            for v in (msd_v, osd_v):
                self.assertTrue(isinstance(v, DTensor))
                self.assertEqual(tuple(v.device_mesh.mesh), tuple(range(8)))

        # FSDP/EP params: only the owned expert appears in the state dict,
        # and its DTensors live on the 2-rank EP submesh.
        layer = self.rank % 4
        ranks = (layer, layer + 4)
        for i in range(4):
            for key in (
                f"second.ep_layers.{i}.net1.0.weight",
                f"second.ep_layers.{i}.net1.0.bias",
                f"second.ep_layers.{i}.net2.0.weight",
                f"second.ep_layers.{i}.net2.0.bias",
            ):
                if layer != i:
                    self.assertTrue(key not in msd)
                else:
                    msd_v = msd[key]
                    osd_v = osd["state"][key]["exp_avg"]
                    for v in (msd_v, osd_v):
                        self.assertTrue(isinstance(v, DTensor))
                        self.assertEqual(tuple(v.device_mesh.mesh), ranks)

        self.assertEqual(set(osd["state"].keys()), set(msd.keys()))


if __name__ == "__main__":
    run_tests()