# Owner(s): ["oncall: distributed"]

import os
import sys

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
import torch.nn as nn
from torch.distributed._tensor import init_device_mesh
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    get_state_dict,
    set_model_state_dict,
    set_state_dict,
    StateDictOptions,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


DIM = 500


class PreTrainedModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layer1 = nn.Linear(DIM, DIM)
        self.layer2 = nn.Linear(DIM, DIM)
        self.layer3 = nn.Linear(DIM, DIM)
        self.sequential = nn.Sequential(nn.Linear(DIM, DIM), nn.ReLU())
        self.module_list = nn.ModuleList([nn.Linear(DIM, DIM), nn.ReLU()])
        self.relu = nn.ReLU()

    def forward(self, batch):
        x = self.relu(self.layer1(batch))
        x = self.relu(self.layer2(x))
        x = self.relu(self.layer3(x))
        x = self.sequential(x)
        x = self.module_list[1](self.module_list[0](x))
        return x


class FineTuningModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.pretrain = PreTrainedModel()
        for p in self.pretrain.parameters():
            p.requires_grad = False

        self.layer1 = nn.Linear(DIM, DIM)
        self.layer2 = nn.Linear(DIM, DIM)
        self.layer3 = nn.Linear(DIM, DIM)
        self.relu = nn.ReLU()

    def forward(self, batch):
        x = self.relu(self.pretrain(batch))
        x = self.relu(self.layer1(x))
        x = self.relu(self.layer2(x))
        x = self.relu(self.layer3(x))
        return x


class TestFineTuning(DTensorTestBase):
    @property
    def world_size(self) -> int:
        return min(4, torch.cuda.device_count())

    @property
    def backend(self):
        return "cpu:gloo,cuda:nccl"

    def pretrain(self, pretrain_dir: str) -> None:
        device_mesh = init_device_mesh(self.device_type, (self.world_size,))

        model = PreTrainedModel().cuda()
        model = FSDP(model, device_mesh=device_mesh)
        optim = torch.optim.Adam(model.parameters(), lr=1e-3)

        # Training
        for i in range(3):
            batch = torch.rand(32, DIM, device="cuda")
            loss = model(batch).sum()
            loss.backward()
            optim.step()
            optim.zero_grad()

        # Save state_dict
        model_state_dict, optim_state_dict = get_state_dict(model, optimizers=optim)
        saved_state_dict = {"model": model_state_dict, "optim": optim_state_dict}
        dist_cp.save_state_dict(
            state_dict=saved_state_dict,
            storage_writer=dist_cp.FileSystemWriter(pretrain_dir),
        )

    def finetune(self, pretrain_dir: str, finetune_dir: str) -> None:
        device_mesh = init_device_mesh(self.device_type, (self.world_size,))

        model = FineTuningModel().cuda()
        # TODO: make the parallelism more complicated, e.g., using 2D + DDP.
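        # Note: use_orig_params=True keeps the original (unflattened) parameters
        # visible through FSDP. This matters here because the wrapped model mixes
        # frozen pretrain parameters with trainable fine-tuning parameters, and
        # the state_dict calls below rely on per-parameter requires_grad (via
        # ignore_frozen_params) to skip the frozen ones.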
        model = FSDP(model, use_orig_params=True, device_mesh=device_mesh)
        optim = torch.optim.Adam(model.parameters(), lr=1e-3)

        # Simulate a fine-tuning run that is interrupted and restarted: each
        # outer round trains for 3 iterations and then saves a checkpoint.
        for i in range(2):
            # Load the pretrained submodule's checkpoint.
            pretrain_state_dict = get_model_state_dict(
                model,
                submodules={model.pretrain},
                options=StateDictOptions(keep_submodule_prefixes=False),
            )
            dist_cp.load_state_dict(
                {"model": pretrain_state_dict},
                storage_reader=dist_cp.FileSystemReader(pretrain_dir),
            )
            set_model_state_dict(
                model,
                model_state_dict={model.pretrain: pretrain_state_dict},
                options=StateDictOptions(strict=False),
            )

            try:
                # Load the fine-tuning checkpoint (trainable parameters and
                # optimizer state only).
                model_state_dict, optim_state_dict = get_state_dict(
                    model,
                    optimizers=optim,
                    options=StateDictOptions(ignore_frozen_params=True),
                )
                dist_cp.load_state_dict(
                    {"model": model_state_dict, "optim": optim_state_dict},
                    storage_reader=dist_cp.FileSystemReader(finetune_dir),
                )
                set_state_dict(
                    model,
                    optimizers=optim,
                    model_state_dict=model_state_dict,
                    optim_state_dict=optim_state_dict,
                    options=StateDictOptions(strict=False),
                )
            except KeyError:
                # On the first round of fine-tuning nothing has been saved yet,
                # so the load is expected to fail. On a restart the checkpoint
                # should exist.
                self.assertEqual(i, 0)

            # Training
            for j in range(3):
                batch = torch.rand(32, DIM, device="cuda")
                loss = model(batch).sum()
                loss.backward()
                optim.step()
                optim.zero_grad()

            # Save state_dict
            model_state_dict, optim_state_dict = get_state_dict(
                model,
                optimizers=optim,
                options=StateDictOptions(ignore_frozen_params=True),
            )
            saved_state_dict = {"model": model_state_dict, "optim": optim_state_dict}
            dist_cp.save_state_dict(
                state_dict=saved_state_dict,
                storage_writer=dist_cp.FileSystemWriter(finetune_dir),
            )

    @skip_if_lt_x_gpu(4)
    @with_comms
    @with_temp_dir
    def test_fine_tuning(self) -> None:
        self.assertTrue(os.path.exists(self.temp_dir))
        pretrain_dir = os.path.join(self.temp_dir, "pretrain")
        finetune_dir = os.path.join(self.temp_dir, "finetune")
        print(pretrain_dir, finetune_dir)
        if self.rank == 0:
            os.mkdir(pretrain_dir)
            os.mkdir(finetune_dir)
        dist.barrier()
        os.sync()
        self.assertTrue(os.path.exists(pretrain_dir))
        self.assertTrue(os.path.exists(finetune_dir))

        self.pretrain(pretrain_dir)
        self.finetune(pretrain_dir, finetune_dir)


if __name__ == "__main__":
    run_tests()