# Owner(s): ["oncall: distributed"]

from copy import deepcopy

import torch
import torch.distributed.checkpoint as dcp
from torch.distributed._tensor import init_device_mesh
from torch.distributed.checkpoint.default_planner import (
    DefaultLoadPlanner,
    DefaultSavePlanner,
)
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    MLPModule,
    skip_if_lt_x_gpu,
    with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir


class UnevenShardedModel(torch.nn.Module):
    """Three-layer MLP (5 -> 10 -> 15 -> 1) used by the meta-device load test.

    NOTE(review): the name suggests the layer sizes were picked so that they do
    not divide evenly across ranks (e.g. the final output dimension of 1) --
    confirm against the tensor-parallel sharding rules.
    """

    def __init__(self, device):
        super().__init__()
        # Fixed seed so every construction yields identical initial weights.
        torch.manual_seed(5)
        self.net1 = torch.nn.Linear(5, 10, device=device)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 15, device=device)
        self.net3 = torch.nn.Linear(15, 1, device=device)

    def forward(self, x):
        """Apply net1 -> ReLU -> net2 -> net3 to ``x``."""
        hidden = self.relu(self.net1(x))
        return self.net3(self.net2(hidden))


class TestTpCheckpoint(DTensorTestBase):
    """Save/load round-trip tests for tensor-parallel models via ``dcp``."""

    @with_comms
    @skip_if_lt_x_gpu(2)
    @with_temp_dir
    def test_tp_checkpoint(self):
        """Save a parallelized MLP, perturb its weights, then load and compare.

        The checkpoint is written before an optimizer step; after loading, the
        live state dict must match the snapshot taken before the step.
        """
        checkpoint_dir = self.temp_dir
        mesh_shape = (self.world_size,)
        tp_mesh = init_device_mesh(self.device_type, mesh_shape)

        # create model and move it to GPU with id rank
        model = MLPModule(self.device_type).cuda(self.rank)
        # Parallelize the module based on the given Parallel Style.
        parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
        }
        model = parallelize_module(model, tp_mesh, parallelize_plan)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.25)
        # Deep copy so the snapshot is unaffected by the optimizer step below.
        original_state_dict = deepcopy(model.state_dict())

        dcp.save(
            state_dict=original_state_dict,
            storage_writer=dcp.FileSystemWriter(checkpoint_dir),
            planner=DefaultSavePlanner(),
        )

        # Update the parameters so model.state_dict() will be different from
        # original_state_dict.
        torch.manual_seed(0)
        inp = torch.rand(20, 10).cuda(self.rank)
        output = model(inp)
        output.sum().backward()
        optimizer.step()
        state_dict = model.state_dict()

        # Ensure the current model parameters are different from
        # original_state_dict before loading from checkpoint. Entries are
        # matched by name rather than by position so that a key mismatch
        # fails loudly instead of being silently truncated by zip().
        for name, saved in original_state_dict.items():
            self.assertNotEqual(saved.to_local(), state_dict[name].to_local())

        dcp.load(
            state_dict=state_dict,
            storage_reader=dcp.FileSystemReader(checkpoint_dir),
            planner=DefaultLoadPlanner(),
        )

        # Now that the checkpoint is loaded, the current model parameters must
        # be the same as original_state_dict.
        for name, saved in original_state_dict.items():
            self.assertEqual(saved.to_local(), state_dict[name].to_local())

    @with_comms
    @skip_if_lt_x_gpu(2)
    @with_temp_dir
    def test_tp_checkpoint_load_on_meta_device(self):
        """Load a checkpoint into a model that was constructed on ``meta``.

        Verifies that loading is in place, that no loaded tensor remains on the
        meta device, and that the loaded values match what was saved.
        """
        checkpoint_dir = self.temp_dir
        mesh_shape = (self.world_size,)
        tp_mesh = init_device_mesh(self.device_type, mesh_shape)

        # create model and move it to GPU with id rank
        model = UnevenShardedModel(self.device_type).cuda(self.rank)
        # Parallelize the module based on the given Parallel Style.
        parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
            "net3": ColwiseParallel(),
        }
        model = parallelize_module(model, tp_mesh, parallelize_plan=parallelize_plan)
        original_state_dict = {
            "model": model.state_dict(),
        }

        dcp.save(
            state_dict=original_state_dict,
            storage_writer=dcp.FileSystemWriter(checkpoint_dir),
        )

        model2 = parallelize_module(
            UnevenShardedModel("meta"), tp_mesh, parallelize_plan=parallelize_plan
        )
        model2_sd_before_load = model2.state_dict()
        state_dict_to_load = {"model": model2_sd_before_load}

        dcp.load(
            state_dict=state_dict_to_load,
            storage_reader=dcp.FileSystemReader(checkpoint_dir),
        )
        # We need to make sure state_dict_to_load["model"] is still the very
        # same object as before the load, since we are doing in-place loading.
        self.assertIs(state_dict_to_load["model"], model2_sd_before_load)

        model2.load_state_dict(state_dict_to_load["model"], assign=True)
        state_dict_after_load = {"model": model2.state_dict()}

        self.assertEqual(
            len(original_state_dict["model"]), len(state_dict_to_load["model"])
        )
        self.assertEqual(
            len(original_state_dict["model"]), len(state_dict_after_load["model"])
        )

        for name, param in original_state_dict["model"].items():
            param_to_load = state_dict_to_load["model"][name]
            param_after_load = state_dict_after_load["model"][name]

            # We need to explicitly check the device is not meta as the
            # assertEqual check currently doesn't handle DTensor with meta
            # device.
            self.assertFalse(param_to_load.is_meta)
            self.assertFalse(param_after_load.is_meta)
            self.assertEqual(param.to_local(), param_to_load.to_local())
            self.assertEqual(param.to_local(), param_after_load.to_local())


if __name__ == "__main__":
    run_tests()