# Owner(s): ["oncall: distributed"]

from copy import deepcopy

import torch
import torch.distributed.checkpoint as dist_cp
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._tensor import init_device_mesh, Replicate
from torch.distributed.checkpoint.default_planner import (
    DefaultLoadPlanner,
    DefaultSavePlanner,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    ShardingStrategy,
    StateDictType,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir


class SimpleModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.net1 = nn.Linear(5, 8)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(8, 4)
        self.net3 = nn.Linear(4, 12)

    def forward(self, x):
        x = F.relu(self.net1(x))
        x = F.relu(self.net2(x))
        x = F.relu(self.net3(x))
        return x

    def get_input(self):
        return torch.rand(4, 5, device="cuda")


class SimpleModelUneven(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.net1 = nn.Linear(5, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 15)
        self.net3 = nn.Linear(15, 30)
        self.net4 = nn.Linear(30, 5)

    def forward(self, x):
        x = F.relu(self.net1(x))
        x = F.relu(self.net2(x))
        x = F.relu(self.net3(x))
        x = F.relu(self.net4(x))
        return x

    def get_input(self):
        return torch.rand(4, 5, device="cuda")


class TestHSDPCheckpoint(DTensorTestBase):
    @property
    def backend(self):
        return "cpu:gloo,cuda:nccl"

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    @parametrize("is_even_sharded_model", [True, False])
    def test_hsdp_checkpoint(self, is_even_sharded_model) -> None:
        CHECKPOINT_DIR = self.temp_dir
        simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven

        # 2D device mesh for HYBRID_SHARD: dim 0 replicates, dim 1 shards.
        mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
        model = FSDP(
            simple_model().cuda(),
            sharding_strategy=ShardingStrategy.HYBRID_SHARD,
            device_mesh=mesh_2d,
        )
        optim = torch.optim.Adam(model.parameters(), lr=0.1)

        FSDP.set_state_dict_type(
            model,
            StateDictType.SHARDED_STATE_DICT,
        )
        state_dict = {"model": model.state_dict()}
        state_dict_to_save = deepcopy(state_dict)

        dist_cp.save(
            state_dict=state_dict_to_save,
            storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
            planner=DefaultSavePlanner(),
        )

        # Update the parameters so the current model state_dict differs from state_dict_to_save.
        model(model.get_input()).sum().backward()
        optim.step()

        # At this point, the current state dict is different from state_dict_to_save.
        for (k1, v1), (k2, v2) in zip(
            state_dict_to_save["model"].items(), model.state_dict().items()
        ):
            self.assertEqual(k1, k2)
            self.assertEqual(v1.device_mesh, v2.device_mesh)
            self.assertEqual(v1.placements, v2.placements)
            self.assertNotEqual(v1.to_local(), v2.to_local())

        dist_cp.load(
            state_dict=state_dict_to_save,
            storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
            planner=DefaultLoadPlanner(),
        )
        model.load_state_dict(state_dict_to_save["model"])
        state_dict_after_load = model.state_dict()

        # After loading, the current model state dict should be the same as state_dict_to_save.
        for (k1, v1), (k2, v2) in zip(
            state_dict_to_save["model"].items(), state_dict_after_load.items()
        ):
            self.assertEqual(k1, k2)
            self.assertEqual(v1.device_mesh, v2.device_mesh)
            self.assertEqual(v1.placements, v2.placements)
            self.assertEqual(v1.to_local(), v2.to_local())

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    @parametrize("is_even_sharded_model", [True, False])
    def test_hsdp_fsdp_checkpoint_conversion(self, is_even_sharded_model) -> None:
        CHECKPOINT_DIR = self.temp_dir
        simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven

        # Save the HSDP model state_dict.
        mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
        hsdp_model = FSDP(
            simple_model().cuda(),
            sharding_strategy=ShardingStrategy.HYBRID_SHARD,
            device_mesh=mesh_2d,
        )
        FSDP.set_state_dict_type(
            hsdp_model,
            StateDictType.SHARDED_STATE_DICT,
        )
        hsdp_state_dict = {"model": hsdp_model.state_dict()}
        dist_cp.save_state_dict(
            state_dict=hsdp_state_dict,
            storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
            planner=DefaultSavePlanner(),
        )

        # Initialize an FSDP model on a 1D mesh to load the checkpoint into.
        mesh_1d = init_device_mesh(self.device_type, (self.world_size,))
        fsdp_model = FSDP(
            simple_model().cuda(),
            device_mesh=mesh_1d,
        )
        FSDP.set_state_dict_type(
            fsdp_model,
            StateDictType.SHARDED_STATE_DICT,
        )
        fsdp_state_dict = {"model": fsdp_model.state_dict()}

        # At this point, the HSDP model parameters are different from the FSDP model parameters.
        for (k1, v1), (k2, v2) in zip(
            hsdp_state_dict["model"].items(), fsdp_state_dict["model"].items()
        ):
            self.assertEqual(k1, k2)
            self.assertNotEqual(v1.device_mesh, v2.device_mesh)
            self.assertNotEqual(v1.placements, v2.placements)
            v1_all_gather = v1.redistribute(
                mesh_2d, placements=(Replicate(), Replicate())
            )
            v2_all_gather = v2.redistribute(mesh_1d, placements=(Replicate(),))
            self.assertNotEqual(v1_all_gather.to_local(), v2_all_gather.to_local())

        # Load the FSDP state_dict from storage.
        dist_cp.load_state_dict(
            state_dict=fsdp_state_dict,
            storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
            planner=DefaultLoadPlanner(),
        )
        fsdp_model.load_state_dict(fsdp_state_dict["model"])
        state_dict_after_load = fsdp_model.state_dict()

        # After loading, the current model state dict should be the same as hsdp_state_dict.
        for (k1, v1), (k2, v2) in zip(
            hsdp_state_dict["model"].items(), state_dict_after_load.items()
        ):
            self.assertEqual(k1, k2)
            self.assertNotEqual(v1.device_mesh, v2.device_mesh)
            self.assertNotEqual(v1.placements, v2.placements)
            v1_all_gather = v1.redistribute(
                mesh_2d, placements=(Replicate(), Replicate())
            )
            v2_all_gather = v2.redistribute(mesh_1d, placements=(Replicate(),))
            self.assertEqual(v1_all_gather.to_local(), v2_all_gather.to_local())


instantiate_parametrized_tests(TestHSDPCheckpoint)
if __name__ == "__main__":
    run_tests()