# Owner(s): ["oncall: distributed"]

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.default_planner import (
    DefaultLoadPlanner,
    DefaultSavePlanner,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir


class FsdpModelStateCheckpoint(DTensorTestBase):
    @property
    def backend(self):
        return "cpu:gloo,cuda:nccl"

    def _test_fsdp_model_state(self, process_group) -> None:
        """Save a sharded FSDP model state dict and load it into a second FSDP model.

        When ``process_group`` differs from the default group, the load path
        exercises resharding across a different set of ranks.
        """
        CHECKPOINT_DIR = self.temp_dir

        # Initialize on the meta device and let FSDP materialize the
        # parameters; run one forward/backward so the values are realized.
        model = FSDP(torch.nn.Linear(8, 8, device="meta"))
        model(torch.rand(8, 8, device=dist.get_rank())).sum().backward()

        # Save the sharded model state dict to the checkpoint directory.
        with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
            state_dict = {
                "model": model.state_dict(),
            }

            dist_cp.save(
                state_dict=state_dict,
                storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
                planner=DefaultSavePlanner(),
            )

        # A second, independently initialized model, possibly sharded over a
        # different process group.
        model_2 = FSDP(
            torch.nn.Linear(8, 8, device="meta"), process_group=process_group
        )

        # Before loading, the two models should have different parameters.
        with FSDP.summon_full_params(model):
            with FSDP.summon_full_params(model_2):
                self.assertNotEqual(model.weight, model_2.weight)
                self.assertNotEqual(model.bias, model_2.bias)

        # Now load the checkpoint and ensure the values are the same.
        with FSDP.state_dict_type(model_2, StateDictType.SHARDED_STATE_DICT):
            state_dict = {
                "model": model_2.state_dict(),
            }

            dist_cp.load(
                state_dict=state_dict,
                storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
                planner=DefaultLoadPlanner(),
            )
            model_2.load_state_dict(state_dict["model"])

        with FSDP.summon_full_params(model):
            with FSDP.summon_full_params(model_2):
                self.assertEqual(model.weight, model_2.weight)
                self.assertEqual(model.bias, model_2.bias)

    @with_comms
    @skip_if_lt_x_gpu(2)
    @with_temp_dir
    def test_fsdp_model_state_no_resharding(self):
        self._test_fsdp_model_state(process_group=None)

    def _create_new_dist_group(self):
        # Split the world into even and odd ranks to create new FSDP groups
        # for resharding.
        world_size = dist.get_world_size()
        group1 = [i for i in range(world_size) if i % 2 == 0]
        group2 = [i for i in range(world_size) if i % 2 != 0]

        # Every rank must call new_group() for both groups, even though each
        # rank only belongs to one of them.
        fsdp_0 = dist.new_group(ranks=group1)
        fsdp_1 = dist.new_group(ranks=group2)
        if dist.get_rank() % 2 == 0:
            my_fsdp = fsdp_0
        else:
            my_fsdp = fsdp_1

        return my_fsdp

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    def test_fsdp_model_state_with_resharding(self):
        self._test_fsdp_model_state(process_group=self._create_new_dist_group())


if __name__ == "__main__":
    run_tests()