# Owner(s): ["oncall: distributed"]
import torch
import torch.distributed.checkpoint as dist_cp
from torch.distributed._tensor import (
    distribute_tensor,
    init_device_mesh,
    Replicate,
    Shard,
    zeros,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_if_lt_x_gpu,
    with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir


CHECKPOINT_DIR = "checkpoint"

ONE_D_PLACEMENTS = [
    [Shard(0)],
    [Replicate()],
]
ONE_D_TO_ONE_D_PLACEMENTS = [
    ([Replicate()], [Shard(0)]),
    ([Shard(0)], [Replicate()]),
]

TWO_D_PLACEMENTS = [
    [Replicate(), Replicate()],
    [Replicate(), Shard(0)],
    [Shard(0), Replicate()],
    [Shard(0), Shard(0)],
]
# All ordered pairs of distinct 2-D placements (original placement -> new placement).
TWO_D_TO_TWO_D_PLACEMENTS = []
for p1 in TWO_D_PLACEMENTS:
    for p2 in TWO_D_PLACEMENTS:
        if p1 != p2:
            TWO_D_TO_TWO_D_PLACEMENTS.append((p1, p2))


class TestDTensorReshardPlacementChange(DTensorTestBase):
    """
    Test DCP resharding for DTensor with placement changes, but without
    changes to world_size or mesh_tensor.
    """

    @with_comms
    @skip_if_lt_x_gpu(2)
    @with_temp_dir
    def test_1d_to_1d_reshard_placement_change(self) -> None:
        CHECKPOINT_DIR = self.temp_dir

        for one_d_to_one_d_placements in ONE_D_TO_ONE_D_PLACEMENTS:
            original_placement, new_placement = one_d_to_one_d_placements

            global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
            mesh_shape = (self.world_size,)
            device_mesh = init_device_mesh(self.device_type, mesh_shape)
            dtensor = distribute_tensor(
                global_tensor, device_mesh, placements=original_placement
            )
            state_dict_to_save = {"dtensor": dtensor}

            dist_cp.save_state_dict(
                state_dict=state_dict_to_save,
                storage_writer=dist_cp.FileSystemWriter(path=CHECKPOINT_DIR),
                planner=dist_cp.DefaultSavePlanner(),
            )

            zero_dtensor = zeros(
                [4, 4], device_mesh=device_mesh, placements=new_placement
            )
            state_dict_to_load = {"dtensor": zero_dtensor}

            dist_cp.load_state_dict(
                state_dict=state_dict_to_load,
                storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
                planner=dist_cp.DefaultLoadPlanner(),
            )

            # materialize the whole tensor to compare with the original global_tensor
            state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
                device_mesh,
                placements=[Replicate()],
            )
            self.assertEqual(global_tensor, state_dict_to_load["dtensor"].to_local())

            # redistribute the tensor back to its original placement for comparison.
            state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
                device_mesh,
                placements=original_placement,
            )
            self.assertEqual(
                state_dict_to_save["dtensor"].to_local(),
                state_dict_to_load["dtensor"].to_local(),
            )

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    def test_2d_to_2d_reshard_placement_change(self) -> None:
        CHECKPOINT_DIR = self.temp_dir
        for two_d_to_two_d_placements in TWO_D_TO_TWO_D_PLACEMENTS:
            original_placement, new_placement = two_d_to_two_d_placements

            global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
            mesh_shape = (2, self.world_size // 2)
            mesh_2d = init_device_mesh(self.device_type, mesh_shape)
            dtensor = distribute_tensor(
                global_tensor,
                mesh_2d,
                placements=original_placement,
            )
            state_dict_to_save = {"dtensor": dtensor}

            dist_cp.save_state_dict(
                state_dict=state_dict_to_save,
                storage_writer=dist_cp.FileSystemWriter(path=CHECKPOINT_DIR),
                planner=dist_cp.DefaultSavePlanner(),
            )

            zero_dtensor = zeros([4, 4], device_mesh=mesh_2d, placements=new_placement)
            state_dict_to_load = {"dtensor": zero_dtensor}

            dist_cp.load_state_dict(
                state_dict=state_dict_to_load,
                storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
                planner=dist_cp.DefaultLoadPlanner(),
            )

            # materialize the whole tensor to compare with the original global_tensor
            state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
                mesh_2d,
                placements=[Replicate(), Replicate()],
            )
            self.assertEqual(global_tensor, state_dict_to_load["dtensor"].to_local())

            # redistribute the tensor back to its original placement for comparison.
            state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
                mesh_2d,
                placements=original_placement,
            )
            self.assertEqual(
                state_dict_to_save["dtensor"].to_local(),
                state_dict_to_load["dtensor"].to_local(),
            )


class TestDTensorReshardMeshChange(DTensorTestBase):
    """
    Test DCP resharding for DTensor with both placement and mesh_tensor changes.
    """

    @with_comms
    @with_temp_dir
    @skip_if_lt_x_gpu(2)
    def test_1d_to_2d_reshard_mesh_change(self) -> None:
        CHECKPOINT_DIR = self.temp_dir
        for placements_1d in ONE_D_PLACEMENTS:
            global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
            mesh_shape = (self.world_size,)
            mesh_1d = init_device_mesh(self.device_type, mesh_shape)
            dtensor = distribute_tensor(
                global_tensor, mesh_1d, placements=placements_1d
            )
            state_dict_to_save = {"dtensor": dtensor}

            dist_cp.save_state_dict(
                state_dict=state_dict_to_save,
                storage_writer=dist_cp.FileSystemWriter(path=CHECKPOINT_DIR),
                planner=dist_cp.DefaultSavePlanner(),
            )

            for placements_2d in TWO_D_PLACEMENTS:
                mesh_shape = (2, self.world_size // 2)
                mesh_2d = init_device_mesh(self.device_type, mesh_shape)

                zero_dtensor = zeros(
                    [4, 4], device_mesh=mesh_2d, placements=placements_2d
                )
                state_dict_to_load = {"dtensor": zero_dtensor}

                dist_cp.load_state_dict(
                    state_dict=state_dict_to_load,
                    storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
                    planner=dist_cp.DefaultLoadPlanner(),
                )

                # materialize the whole tensor to compare with the original global_tensor
                state_dict_to_load["dtensor"] = state_dict_to_load[
                    "dtensor"
                ].redistribute(
                    mesh_2d,
                    placements=[Replicate(), Replicate()],
                )
                self.assertEqual(
                    global_tensor, state_dict_to_load["dtensor"].to_local()
                )

    @with_comms
    @with_temp_dir
    @skip_if_lt_x_gpu(4)
    def test_2d_to_1d_reshard_mesh_change(self) -> None:
        CHECKPOINT_DIR = self.temp_dir
        for placements_2d in TWO_D_PLACEMENTS:
            global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
            mesh_shape = (2, self.world_size // 2)
            mesh_2d = init_device_mesh(self.device_type, mesh_shape)
            dtensor = distribute_tensor(
                global_tensor, mesh_2d, placements=placements_2d
            )
            state_dict_to_save = {"dtensor": dtensor}

            dist_cp.save_state_dict(
                state_dict=state_dict_to_save,
                storage_writer=dist_cp.FileSystemWriter(path=CHECKPOINT_DIR),
                planner=dist_cp.DefaultSavePlanner(),
            )

            for placements_1d in ONE_D_PLACEMENTS:
                mesh_shape = (self.world_size,)
                mesh_1d = init_device_mesh(self.device_type, mesh_shape)

                zero_dtensor = zeros(
                    [4, 4], device_mesh=mesh_1d, placements=placements_1d
                )
                state_dict_to_load = {"dtensor": zero_dtensor}

                dist_cp.load_state_dict(
                    state_dict=state_dict_to_load,
                    storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
                    planner=dist_cp.DefaultLoadPlanner(),
                )

                # materialize the whole tensor to compare with the original global_tensor
                state_dict_to_load["dtensor"] = state_dict_to_load[
                    "dtensor"
                ].redistribute(
                    mesh_1d,
                    placements=[Replicate()],
                )
                self.assertEqual(
                    global_tensor, state_dict_to_load["dtensor"].to_local()
                )

    @with_comms
    @with_temp_dir
    @skip_if_lt_x_gpu(2)
    def test_dtensor_checkpoint_resharding_with_empty_shard(self):
        """
        Test DTensor checkpoint resharding for a DTensor containing empty shards.
        """
        # Shard a 1-element tensor across the 1-D mesh: only rank 0 holds data,
        # every other rank holds an empty shard.
        tensor = torch.rand(1).cuda()
        mesh = init_device_mesh(self.device_type, (self.world_size,))
        dtensor = distribute_tensor(tensor, mesh, [Shard(0)])
        ref_state_dict = {"dtensor": dtensor}

        dist_cp.save_state_dict(
            state_dict=ref_state_dict,
            storage_writer=dist_cp.FileSystemWriter(path=self.temp_dir),
        )

        # Load back into a DTensor sharded on both dims of a 2-D mesh.
        tensor = torch.rand(1).cuda()
        mesh_2 = init_device_mesh(self.device_type, (2, self.world_size // 2))
        dtensor = distribute_tensor(tensor, mesh_2, [Shard(0), Shard(0)])
        state_dict = {"dtensor": dtensor}
        dist_cp.load_state_dict(
            state_dict=state_dict,
            storage_reader=dist_cp.FileSystemReader(self.temp_dir),
        )

        # TODO: Add an assertEqual for ref_state_dict["dtensor"].full_tensor()
        # and state_dict["dtensor"].full_tensor() after we fix the size mismatch
        # issue for unevenly sharded DTensors.


# TODO: Add a DTensor resharding test for when the world size changes.
if __name__ == "__main__":
    run_tests()