# Owner(s): ["oncall: distributed"]

import shutil
import tempfile
from functools import wraps
from typing import Any, Callable, Optional

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.checkpoint.utils import CheckpointException
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._shard.sharded_tensor import (
    ShardedTensorTestBase,
    with_comms,
)


def with_temp_dir(
    func: Optional[Callable] = None,
) -> Optional[Callable]:
    """
    Wrapper to initialize a temp directory for distributed checkpoint tests.
    """
    assert func is not None

    @wraps(func)
    def wrapper(self, *args: Any, **kwargs: Any) -> None:
        # Only create the temp dir on rank 0
        if dist.get_rank() == 0:
            temp_dir = tempfile.mkdtemp()
            print(f"Using temp directory: {temp_dir}")
        else:
            temp_dir = ""
        object_list = [temp_dir]

        # Broadcast temp_dir to all the other ranks
        dist.broadcast_object_list(object_list)
        self.temp_dir = object_list[0]

        try:
            func(self, *args, **kwargs)
        finally:
            if dist.get_rank() == 0:
                shutil.rmtree(self.temp_dir, ignore_errors=True)

    return wrapper


class MyTestModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
        self.net3 = nn.Linear(32, 64)
        self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8))

    def forward(self, x):
        return self.net4(self.net3(self.net2(self.net1(x))))


class TestFSSpec(ShardedTensorTestBase):
    @property
    def world_size(self) -> int:
        return 2

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(2)
    @requires_nccl()
    @with_temp_dir
    def test_fsspec(self):
        CHECKPOINT_DIR = self.temp_dir

        model = FSDP(MyTestModule().cuda())
        optim = torch.optim.Adam(model.parameters(), lr=0.1)
        # Run one forward/backward and an optimizer step so Adam
        # materializes its (lazily created) state before saving.
        model(torch.rand(8, 8, device=dist.get_rank())).sum().backward()
        optim.step()

        with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
            state_dict = {
                "model": model.state_dict(),
                "optim": FSDP.optim_state_dict(model, optim),
            }

            dcp.save(
                state_dict=state_dict,
                storage_writer=FsspecWriter(CHECKPOINT_DIR),
                planner=dcp.DefaultSavePlanner(),
            )

        model_2 = FSDP(MyTestModule().cuda())
        optim_2 = torch.optim.Adam(model_2.parameters(), lr=0.1)

        # The two models were initialized independently, so their
        # parameters must differ before the checkpoint is loaded.
        with FSDP.summon_full_params(model):
            with FSDP.summon_full_params(model_2):
                for n_p1, n_p2 in zip(
                    model.named_parameters(), model_2.named_parameters()
                ):
                    self.assertNotEqual(n_p1[1], n_p2[1])

        # Now load the checkpoint and ensure the values are the same
        with FSDP.state_dict_type(model_2, StateDictType.SHARDED_STATE_DICT):
            state_dict = {
                "model": model_2.state_dict(),
            }

            dcp.load(
                state_dict=state_dict,
                storage_reader=FsspecReader(CHECKPOINT_DIR),
                planner=dcp.DefaultLoadPlanner(),
            )

            model_2.load_state_dict(state_dict["model"])

            # Load the sharded optimizer state, then flatten it into a
            # form that optim_2.load_state_dict accepts.
            optim_state = load_sharded_optimizer_state_dict(
                model_state_dict=state_dict["model"],
                optimizer_key="optim",
                storage_reader=FsspecReader(CHECKPOINT_DIR),
            )

            flattened_osd = FSDP.optim_state_dict_to_load(
                model_2, optim_2, optim_state["optim"]
            )
            optim_2.load_state_dict(flattened_osd)

        with FSDP.summon_full_params(model):
            with FSDP.summon_full_params(model_2):
                for n_p1, n_p2 in zip(
                    model.named_parameters(), model_2.named_parameters()
                ):
                    self.assertEqual(n_p1[1], n_p2[1])

        def opt_at(opt, idx):
            return list(opt.state.values())[idx]

        # Adam lazily creates its state, so compare the materialized entries
        self.assertEqual(opt_at(optim, 0)["exp_avg"], opt_at(optim_2, 0)["exp_avg"])
        self.assertEqual(
            opt_at(optim, 0)["exp_avg_sq"], opt_at(optim_2, 0)["exp_avg_sq"]
        )

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(2)
    @requires_nccl()
    @with_temp_dir
    def test_overwrite(self):
        t1, t2 = torch.randn(10), torch.randn(10)

        # Saving twice into the same directory succeeds only with overwrite=True
        dcp.save(
            {"random": t1}, storage_writer=FsspecWriter(self.temp_dir, overwrite=False)
        )
        dcp.save(
            {"random": t2}, storage_writer=FsspecWriter(self.temp_dir, overwrite=True)
        )

        sd = {"random": torch.zeros(10)}
        dcp.load(sd, checkpoint_id=self.temp_dir)
        self.assertTrue(torch.allclose(sd["random"], t2))

        with self.assertRaisesRegex(
            CheckpointException, ".*Checkpoint already exists.*"
        ):
            dcp.save(
                {"random": t2},
                storage_writer=FsspecWriter(self.temp_dir, overwrite=False),
            )


if __name__ == "__main__":
    run_tests()