# Owner(s): ["oncall: distributed"] import torch import torch.distributed.checkpoint as dist_cp from torch.distributed._shard.sharded_tensor import ShardedTensor from torch.distributed._state_dict_utils import _all_gather_sharded_tensor from torch.distributed._tensor import DTensor, init_device_mesh, Replicate from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType from torch.distributed.tensor.parallel import ( ColwiseParallel, parallelize_module, RowwiseParallel, ) from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, MLPModule, skip_if_lt_x_gpu, with_comms, ) from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir # TODO: modularize this test and add test for checkpoint conversion in both direction. class TestFsdpTpCheckpointConversion(DTensorTestBase): @with_comms @skip_if_lt_x_gpu(2) @with_temp_dir def test_fsdp_to_tp(self): CHECKPOINT_DIR = self.temp_dir model = MLPModule(self.device_type).cuda(self.rank) # create a FSDP wrapped model fsdp_model = FSDP(model, use_orig_params=True) FSDP.set_state_dict_type( fsdp_model, StateDictType.SHARDED_STATE_DICT, ) fsdp_state_dict = fsdp_model.state_dict() # save fsdp_state_dict to storage dist_cp.save( state_dict=fsdp_state_dict, storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR), ) # create a TP wrapped model mesh_shape = (self.world_size,) device_mesh = init_device_mesh(self.device_type, mesh_shape) model = MLPModule(self.device_type).cuda(self.rank) # Parallelize the module based on the given Parallel Style. parallelize_plan = { "net1": ColwiseParallel(), "net2": RowwiseParallel(), } tp_model = parallelize_module(model, device_mesh, parallelize_plan) optimizer = torch.optim.SGD(tp_model.parameters(), lr=0.25) # Update the parameters so tp_model.state_dict() will be different from fsdp_model.state_dict(). torch.manual_seed(0) inp = torch.rand(20, 10).cuda(self.rank) output = tp_model(inp) output.sum().backward() optimizer.step() tp_state_dict = tp_model.state_dict() # Check parameters are indeed different prior to loading. for fsdp_item, tp_item in zip(fsdp_state_dict.items(), tp_state_dict.items()): fsdp_k, fsdp_v = fsdp_item tp_k, tp_v = tp_item self.assertEqual(fsdp_k, tp_k) if isinstance(fsdp_v, ShardedTensor) and isinstance(tp_v, DTensor): fsdp_redistributed = _all_gather_sharded_tensor(fsdp_v) tp_redistributed = tp_v.redistribute( device_mesh, placements=[Replicate()] ).to_local() self.assertNotEqual(fsdp_redistributed, tp_redistributed) dist_cp.load_state_dict( state_dict=tp_state_dict, storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR), ) tp_model.load_state_dict(tp_state_dict) # Check parameters are equal after loading. tp_state_dict_after_load = tp_model.state_dict() for fsdp_item, tp_item in zip(fsdp_state_dict.items(), tp_state_dict.items()): fsdp_k, fsdp_v = fsdp_item tp_k, tp_v = tp_item self.assertEqual(fsdp_k, tp_k) if isinstance(fsdp_v, ShardedTensor) and isinstance(tp_v, DTensor): fsdp_redistributed = _all_gather_sharded_tensor(fsdp_v) tp_redistributed = tp_v.redistribute( device_mesh, placements=[Replicate()] ).to_local() self.assertEqual(fsdp_redistributed, tp_redistributed) if __name__ == "__main__": run_tests()