# Owner(s): ["oncall: distributed"]
import torch
import torch.distributed.checkpoint as dist_cp
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._state_dict_utils import _all_gather_sharded_tensor
from torch.distributed._tensor import DTensor, init_device_mesh, Replicate
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    MLPModule,
    skip_if_lt_x_gpu,
    with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir


# TODO: modularize this test and add tests for checkpoint conversion in both directions.
class TestFsdpTpCheckpointConversion(DTensorTestBase):
    @with_comms
    @skip_if_lt_x_gpu(2)
    @with_temp_dir
    def test_fsdp_to_tp(self):
        CHECKPOINT_DIR = self.temp_dir

        model = MLPModule(self.device_type).cuda(self.rank)
        # Create an FSDP-wrapped model.
        fsdp_model = FSDP(model, use_orig_params=True)

        FSDP.set_state_dict_type(
            fsdp_model,
            StateDictType.SHARDED_STATE_DICT,
        )
        fsdp_state_dict = fsdp_model.state_dict()

        # Save fsdp_state_dict to storage.
        dist_cp.save(
            state_dict=fsdp_state_dict,
            storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
        )

        # Create a TP-wrapped model.
        mesh_shape = (self.world_size,)
        device_mesh = init_device_mesh(self.device_type, mesh_shape)
        model = MLPModule(self.device_type).cuda(self.rank)
        # Parallelize the module based on the given parallel styles.
        parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
        }
        tp_model = parallelize_module(model, device_mesh, parallelize_plan)
        optimizer = torch.optim.SGD(tp_model.parameters(), lr=0.25)

        # Update the parameters so that tp_model.state_dict() differs from fsdp_model.state_dict().
        torch.manual_seed(0)
        inp = torch.rand(20, 10).cuda(self.rank)
        output = tp_model(inp)
        output.sum().backward()
        optimizer.step()
        tp_state_dict = tp_model.state_dict()

        # Check that the parameters are indeed different prior to loading.
        for fsdp_item, tp_item in zip(fsdp_state_dict.items(), tp_state_dict.items()):
            fsdp_k, fsdp_v = fsdp_item
            tp_k, tp_v = tp_item

            self.assertEqual(fsdp_k, tp_k)

            if isinstance(fsdp_v, ShardedTensor) and isinstance(tp_v, DTensor):
                fsdp_redistributed = _all_gather_sharded_tensor(fsdp_v)
                tp_redistributed = tp_v.redistribute(
                    device_mesh, placements=[Replicate()]
                ).to_local()
                self.assertNotEqual(fsdp_redistributed, tp_redistributed)

        dist_cp.load_state_dict(
            state_dict=tp_state_dict,
            storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
        )
        tp_model.load_state_dict(tp_state_dict)

        # Check that the parameters are equal after loading.
        tp_state_dict_after_load = tp_model.state_dict()
        for fsdp_item, tp_item in zip(
            fsdp_state_dict.items(), tp_state_dict_after_load.items()
        ):
            fsdp_k, fsdp_v = fsdp_item
            tp_k, tp_v = tp_item

            self.assertEqual(fsdp_k, tp_k)

            if isinstance(fsdp_v, ShardedTensor) and isinstance(tp_v, DTensor):
                fsdp_redistributed = _all_gather_sharded_tensor(fsdp_v)
                tp_redistributed = tp_v.redistribute(
                    device_mesh, placements=[Replicate()]
                ).to_local()
                self.assertEqual(fsdp_redistributed, tp_redistributed)


if __name__ == "__main__":
    run_tests()