# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import itertools

import torch
from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor
from torch.distributed._tensor.placement_types import Partial, Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor._collective_utils import shard_dim_alltoall
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)


funcol = torch.ops.c10d_functional


class RedistributeTest(DTensorTestBase):
    @property
    def world_size(self):
        return 4

    @with_comms
    def test_shard_to_replicate_forward_backward(self):
        # 1) test shard -> replicate forward
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        replica_spec = [Replicate()]

        input_sizes_and_shard_dim = [
            ((self.world_size * 3, 3), 0),
            ((self.world_size * 3 + 1, 3), 0),
            ((self.world_size * 3 + 2, 3), 0),
            ((3, self.world_size * 3), 1),
            ((3, self.world_size * 3 + 1), 1),
            ((3, self.world_size * 3 + 2), 1),
        ]

        comm_mode = CommDebugMode()
        for input_size, shard_dim in input_sizes_and_shard_dim:
            shard_spec = [Shard(shard_dim)]
            expected_tensor = torch.randn(
                input_size, device=self.device_type, requires_grad=True
            )
            dtensor = distribute_tensor(expected_tensor, device_mesh, shard_spec)
            with comm_mode:
                reshard_dtensor = dtensor.redistribute(device_mesh, replica_spec)
            self.assertEqual(reshard_dtensor.size(), torch.Size(input_size))
            self.assertEqual(expected_tensor, reshard_dtensor.to_local())
            self.assertEqual(
                comm_mode.get_comm_counts()[funcol.all_gather_into_tensor], 1
            )

            # 2) test shard -> replicate backward:
            # should give gradient as shard
            grad_output = torch.ones_like(reshard_dtensor)
            with comm_mode:
                reshard_dtensor.backward(grad_output)
            grad_input = dtensor.grad
            self.assertEqual(grad_input.placements, shard_spec)
            self.assertEqual(
                grad_input.to_local(), torch.ones(dtensor.to_local().size())
            )
            self.assertEqual(comm_mode.get_total_counts(), 0)

    @with_comms
    def test_replicate_to_replicate_forward_backward(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        replica_spec = [Replicate()]
        local_tensor = torch.randn(12, 3, device=self.device_type, requires_grad=True)

        comm_mode = CommDebugMode()

        # 1) test replicate -> replicate forward
        replica_tensor = distribute_tensor(local_tensor, device_mesh, replica_spec)
        with comm_mode:
            reshard_replica_tensor = replica_tensor.redistribute(
                device_mesh, replica_spec
            )
        self.assertEqual(replica_tensor.size(), local_tensor.size())
        self.assertEqual(replica_tensor, reshard_replica_tensor)
        self.assertEqual(comm_mode.get_total_counts(), 0)

        # 2) test replicate -> replicate backward:
        # should give gradient as replicate
        grad_output = torch.ones_like(reshard_replica_tensor)
        with comm_mode:
            reshard_replica_tensor.backward(grad_output)
        grad_input = replica_tensor.grad
        self.assertEqual(grad_input.placements, replica_spec)
        self.assertEqual(grad_input.to_local(), torch.ones(12, 3))
        self.assertEqual(comm_mode.get_total_counts(), 0)

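    # Note: in the next test, the redistribute forward (replicate -> replicate)
    # is expected to be communication-free, so the single collective asserted
    # below comes from the backward, where the Partial() grad placement passed
    # to to_local() is reduced back to a replica via all_reduce.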
    @with_comms
    def test_replicate_to_local_partial_grad(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        replica_spec = [Replicate()]
        local_tensor = torch.randn(12, 3, device=self.device_type, requires_grad=True)

        replica_tensor = distribute_tensor(local_tensor, device_mesh, replica_spec)

        comm_mode = CommDebugMode()

        with comm_mode:
            out = replica_tensor.redistribute(placements=[Replicate()]).to_local(
                grad_placements=[Partial()]
            )
            out.backward(torch.ones_like(out))

        self.assertEqual(comm_mode.get_total_counts(), 1)
        self.assertEqual(comm_mode.get_comm_counts()[funcol.all_reduce], 1)

    @with_comms
    def test_replicate_to_shard_forward_backward(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        replica_spec = [Replicate()]

        input_sizes_and_shard_dim = [
            ((self.world_size * 3, 3), 0),
            ((self.world_size * 3 + 1, 3), 0),
            ((self.world_size * 3 + 2, 3), 0),
            ((3, self.world_size * 3), 1),
            ((3, self.world_size * 3 + 1), 1),
            ((3, self.world_size * 3 + 2), 1),
        ]

        comm_mode = CommDebugMode()
        for input_size, shard_dim in input_sizes_and_shard_dim:
            shard_spec = [Shard(shard_dim)]
            # 1) test replicate -> shard forward
            local_replica = torch.randn(
                input_size, device=self.device_type, requires_grad=True
            )
            splitted_list = list(
                torch.chunk(local_replica, self.world_size, dim=shard_dim)
            )

            # take this rank's chunk as the expected local tensor
            local_tensor = splitted_list[self.rank]
            replica_tensor = distribute_tensor(local_replica, device_mesh, replica_spec)
            with comm_mode:
                reshard_tensor = replica_tensor.redistribute(device_mesh, shard_spec)
            self.assertEqual(reshard_tensor.size(), replica_tensor.size())
            self.assertEqual(reshard_tensor.placements, shard_spec)
            self.assertEqual(reshard_tensor.to_local(), local_tensor)
            self.assertEqual(comm_mode.get_total_counts(), 0)

            # 2) test replicate -> shard backward:
            # should give gradient as replicate
            grad_output = torch.ones_like(reshard_tensor)
            with comm_mode:
                reshard_tensor.backward(grad_output)
            grad_input = replica_tensor.grad
            self.assertEqual(grad_input.placements, replica_spec)
            self.assertEqual(grad_input.to_local(), torch.ones(input_size))
            self.assertEqual(comm_mode.get_total_counts(), 1)
            self.assertEqual(
                comm_mode.get_comm_counts()[funcol.all_gather_into_tensor], 1
            )

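    # Partial -> Replicate is the redistribute step that materializes the
    # pending sum: the forward below is expected to issue exactly one
    # all_reduce, while the corresponding backward (replicate -> partial)
    # should be a communication-free pass-through.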
    @with_comms
    def test_partial_to_replicate_forward_backward(self):
        # Although we don't allow users to reshard to produce a partial
        # placement (i.e. users can't reshard to partial), we do allow
        # replicate to partial internally, and the backward of partial to
        # replicate should work as expected.
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        partial_local = torch.ones(12, 3, device=self.device_type, requires_grad=True)
        partial_spec = [Partial()]
        replica_spec = [Replicate()]

        comm_mode = CommDebugMode()
        # test partial -> replicate, which triggers an all_reduce
        partial_tensor = DTensor.from_local(partial_local, device_mesh, partial_spec)
        with comm_mode:
            global_partial_tensor = partial_tensor.redistribute(
                device_mesh, replica_spec
            )

        self.assertEqual(partial_tensor.size(), partial_local.size())
        self.assertEqual(
            partial_local * self.world_size, global_partial_tensor.to_local()
        )
        self.assertEqual(comm_mode.get_comm_counts()[funcol.all_reduce], 1)

        # test backward to have replicate grad on partial
        # for the from_local backward, we want replicate() -> partial() to be
        # a pass-through.
        with comm_mode:
            global_partial_tensor.backward(torch.ones_like(global_partial_tensor))
        self.assertIsNotNone(partial_local.grad)
        self.assertEqual(partial_local.grad.size(), partial_local.size())
        self.assertEqual(partial_local.grad, torch.ones_like(partial_local))
        self.assertEqual(comm_mode.get_total_counts(), 0)

    @with_comms
    def test_replicate_to_partial(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        local_tensor = torch.randn(12, 3, device=self.device_type, requires_grad=True)
        partial_spec = Partial()
        replica_spec = Replicate()
        # 1) test replicate -> partial forward
        replica_tensor = distribute_tensor(local_tensor, device_mesh, [replica_spec])
        with self.assertRaisesRegex(RuntimeError, "Can not redistribute to Partial"):
            partial_tensor = replica_tensor.redistribute(device_mesh, [partial_spec])

        from torch.distributed.tensor._redistribute import Redistribute

        comm_mode = CommDebugMode()

        with comm_mode:
            partial_tensor = Redistribute.apply(
                replica_tensor, device_mesh, [partial_spec]
            )
        self.assertEqual(partial_tensor.size(), local_tensor.size())
        # the local value should be divided by world_size so that the pending
        # partial sum still reconstructs the replicated tensor
        self.assertEqual(
            replica_tensor.to_local() / self.world_size, partial_tensor.to_local()
        )
        self.assertEqual(comm_mode.get_total_counts(), 0)

        # replicate to partial on sub groups
        local_tensor = torch.randn(12, 3, device=self.device_type)
        device_mesh = DeviceMesh(
            self.device_type,
            torch.arange(self.world_size).reshape(self.world_size // 2, 2),
        )
        # 2) test replicate -> partial on 2d-mesh subgroups
        replica_tensor = distribute_tensor(
            local_tensor, device_mesh, [replica_spec, replica_spec]
        )
        with comm_mode:
            partial_tensor = Redistribute.apply(
                replica_tensor, device_mesh, [partial_spec, partial_spec]
            )
        self.assertEqual(partial_tensor.size(), local_tensor.size())

        self.assertEqual(
            replica_tensor.to_local() / self.world_size,
            partial_tensor.to_local(),
        )
        self.assertEqual(comm_mode.get_total_counts(), 0)

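    # Partial -> Shard folds the pending reduction and the sharding into a
    # single reduce_scatter per redistribute; the uneven input sizes below
    # also exercise chunking where trailing ranks can end up with smaller
    # (or empty) shards.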
    @with_comms
    def test_partial_to_shard(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        partial_spec = [Partial()]
        my_rank = device_mesh.get_rank()

        input_sizes_and_shard_dim = [
            ((self.world_size * 3, 3), 0),
            ((self.world_size * 3 + 1, 3), 0),
            ((self.world_size * 3 + 2, 3), 0),
            ((3, self.world_size * 3), 1),
            ((3, self.world_size * 3 + 1), 1),
            ((3, self.world_size * 3 + 2), 1),
        ]

        comm_mode = CommDebugMode()

        for input_size, shard_dim in input_sizes_and_shard_dim:
            shard_spec = [Shard(shard_dim)]

            partial_local = torch.ones(input_size, device=self.device_type)
            partial_tensor = DTensor.from_local(
                partial_local, device_mesh, partial_spec, run_check=False
            )

            full_chunk_size = (
                input_size[shard_dim] + self.world_size - 1
            ) // self.world_size
            chunk_sizes = [
                max(
                    min(input_size[shard_dim], full_chunk_size * (idx + 1))
                    - full_chunk_size * idx,
                    0,
                )
                for idx in range(self.world_size)
            ]

            local_shape = list(input_size)
            local_shape[shard_dim] = chunk_sizes[my_rank]

            # test partial to shard, which triggers a reduce_scatter
            with comm_mode:
                scatter_shard_tensor = partial_tensor.redistribute(
                    device_mesh, shard_spec
                )
            self.assertEqual(scatter_shard_tensor.size(), partial_tensor.size())
            self.assertEqual(scatter_shard_tensor.placements, shard_spec)
            self.assertEqual(
                scatter_shard_tensor.to_local(),
                torch.ones(local_shape) * self.world_size,
            )
            self.assertEqual(
                comm_mode.get_comm_counts()[funcol.reduce_scatter_tensor], 1
            )

    @with_comms
    def test_redistribute_negative_shard_dim(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        local_tensor = torch.randn(12, 3, device=self.device_type, requires_grad=True)
        shard_spec = [Shard(1)]
        shard_minus_spec = [Shard(-1)]

        shard_tensor = distribute_tensor(local_tensor, device_mesh, shard_spec)
        self.assertEqual(shard_tensor.placements[0].dim, 1)
        reshard_tensor = shard_tensor.redistribute(device_mesh, shard_minus_spec)
        self.assertEqual(reshard_tensor.placements[0].dim, 1)

    @with_comms
    def test_redistribute_uneven_sharding(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 2))
        data_to_test = [
            # uneven on last mesh dim
            torch.randn((10, 5), device=self.device_type),
            # uneven on both mesh dims
            torch.randn((9, 5), device=self.device_type),
            # smaller than mesh dim shape
            torch.randn((3, 5), device=self.device_type),
            torch.randn((1, 3), device=self.device_type),
        ]

        sharding_to_tests = [
            [Shard(0), Shard(0)],
            [Shard(0), Shard(1)],
        ]

        for input_tensor in data_to_test:
            for placements in sharding_to_tests:
                dt = distribute_tensor(input_tensor, mesh, placements)
                dt_full_tensor = dt.full_tensor()
                self.assertEqual(dt_full_tensor, input_tensor)

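    # Changing the shard dim on the same mesh dimension is expected to lower
    # to the fused shard_dim_alltoall op on CUDA and to an all_gather-based
    # fallback on other device types; the comm counts asserted below encode
    # that choice.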
    @with_comms
    def test_redistribute_shard_dim_change(self):
        # test 1d device mesh
        mesh_1d = DeviceMesh(self.device_type, torch.arange(self.world_size))
        data_to_test = [
            # evenly sharded case
            torch.randn((8, 8), device=self.device_type),
            # 3d or more dims
            torch.randn((8, 8, 8), device=self.device_type),
            # uneven case 1
            torch.randn((8, 5), device=self.device_type),
            # uneven case 2
            torch.randn((5, 8), device=self.device_type),
            # uneven case 3
            torch.randn((5, 5), device=self.device_type),
        ]

        sharding_src_dst_pairs = [([Shard(0)], [Shard(1)]), ([Shard(1)], [Shard(0)])]

        comm_mode = CommDebugMode()

        for input_data in data_to_test:
            for src, dst in sharding_src_dst_pairs:
                expected_dt = distribute_tensor(input_data.clone(), mesh_1d, dst)
                sharded_dt = distribute_tensor(input_data, mesh_1d, src)
                with comm_mode:
                    out_dt = sharded_dt.redistribute(mesh_1d, dst)
                self.assertEqual(out_dt.placements, expected_dt.placements)
                local_out_dt = out_dt.to_local()
                local_expected_dt = expected_dt.to_local()
                self.assertEqual(local_out_dt, local_expected_dt)
                if self.device_type == "cuda":
                    self.assertEqual(
                        comm_mode.get_comm_counts()[
                            torch.ops._dtensor.shard_dim_alltoall
                        ],
                        1,
                    )
                else:
                    self.assertEqual(
                        comm_mode.get_comm_counts()[funcol.all_gather_into_tensor],
                        1,
                    )

        # test 2d device mesh
        mesh_2d = DeviceMesh(
            self.device_type, torch.arange(self.world_size).reshape(2, 2)
        )
        data_to_test_2d = [
            # evenly sharded case
            torch.randn((8, 8), device=self.device_type),
            # 3d or more dims
            torch.randn((8, 8, 8), device=self.device_type),
            # uneven case 1
            torch.randn((8, 5), device=self.device_type),
            # uneven case 2
            torch.randn((5, 8), device=self.device_type),
            # uneven case 3
            torch.randn((5, 5), device=self.device_type),
        ]
        sharding_src_dst_pairs_2d = [
            ([Shard(0), Shard(1)], [Shard(0), Shard(0)]),
            ([Shard(0), Shard(1)], [Shard(1), Shard(0)]),
            ([Shard(0), Shard(0)], [Shard(1), Shard(1)]),
        ]
        comm_counts_2d = [
            1,  # 1: S1 -> S0
            2,  # 1: S1 -> R, 0: S0 -> S1, 1: R -> S0
            2,  # 1: S0 -> R, 0: S0 -> S1, 1: R -> S1
        ]

        for input_data in data_to_test_2d:
            if input_data.ndim > 2:
                sharding_spec_combs = sharding_src_dst_pairs_2d + [
                    ([Shard(0), Shard(2)], [Shard(1), Shard(0)]),
                    ([Shard(1), Shard(1)], [Shard(1), Shard(2)]),
                ]
                comm_counts_2d = comm_counts_2d + [
                    2,  # 1: S2 -> R, 0: S0 -> S1, 1: R -> S0
                    1,  # 1: S1 -> S2
                ]
            else:
                sharding_spec_combs = sharding_src_dst_pairs_2d

            for idx, (src, dst) in enumerate(sharding_spec_combs):
                expected_dt = distribute_tensor(input_data.clone(), mesh_2d, dst)
                sharded_dt = distribute_tensor(input_data, mesh_2d, src)
                with comm_mode:
                    out_dt = sharded_dt.redistribute(mesh_2d, dst)

                self.assertEqual(out_dt.placements, expected_dt.placements)
                self.assertEqual(comm_mode.get_total_counts(), comm_counts_2d[idx])

                local_out_dt = out_dt.to_local()
                local_expected_dt = expected_dt.to_local()
                self.assertEqual(local_out_dt, local_expected_dt)

    @with_comms
    def test_shard_dim_alltoall(self):
        # init 2d mesh here so we can test when group_rank != global_rank
        mesh = init_device_mesh(self.device_type, (2, 2))
        tensor = torch.randn(12, self.world_size, device=self.device_type)
        new_tensor = shard_dim_alltoall(tensor, 0, 1, mesh, 0)

        meta_tensor = torch.randn(12, self.world_size, device="meta")
        new_meta_tensor = shard_dim_alltoall(meta_tensor, 0, 1, mesh, 0)

        self.assertEqual(new_tensor.shape, new_meta_tensor.shape)


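# The tests below run on a larger 8-rank world and stress redistribute on
# multi-dimensional meshes: test_multi_dim_mesh sweeps every (input, output)
# placement combination and checks the result against full_tensor(),
# accounting for the extra sums contributed by Partial inputs.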
class MultiDimRedistributeTest(DTensorTestBase):
    @property
    def world_size(self) -> int:
        return 8

    @with_comms
    def test_multi_dim_mesh(self):
        devices = torch.arange(self.world_size)
        for mesh_shape in [devices, devices.view(4, 2), devices.view(2, 2, 2)]:
            device_mesh = DeviceMesh(self.device_type, mesh_shape)
            tensor_shape = (16, 24)

            if torch.distributed.get_rank() == 0:
                full_tensor = torch.randn(*tensor_shape)
            else:
                # these should be entirely ignored because distribute_tensor
                # is expected to override shards on ranks != 0
                full_tensor = torch.ones(*tensor_shape)

            possibilities = [Replicate()] + [Shard(i) for i in range(full_tensor.ndim)]
            all_outputs = list(itertools.product(*(mesh_shape.ndim * [possibilities])))
            all_inputs = list(
                itertools.product(*(mesh_shape.ndim * [possibilities + [Partial()]]))
            )

            for inputs in all_inputs:
                # if partial, temporarily make it Replicate, then replace the
                # replicated entries with Partial afterwards
                repl_inputs = [Replicate() if s.is_partial() else s for s in inputs]
                dt = distribute_tensor(full_tensor, device_mesh, repl_inputs)

                if repl_inputs != inputs:
                    # create a new DTensor reinterpreting some of the
                    # replicated entries as "Partial"
                    dt = DTensor.from_local(
                        dt.to_local(), device_mesh, inputs, run_check=False
                    )

                for outputs in all_outputs:
                    # redistribute on target outputs
                    dt2 = dt.redistribute(device_mesh, outputs)

                    # replicate and then get first shard
                    local_full = dt2.full_tensor()

                    if torch.distributed.get_rank() == 0:
                        self.assertEqual(local_full.shape, full_tensor.shape)

                        num_sums = 1
                        for idx, input in enumerate(inputs):
                            if input.is_partial():
                                num_sums *= mesh_shape.size(idx)
                        expected = num_sums * full_tensor
                        self.assertEqual(local_full, expected)

    @with_comms
    def test_redistribute_shard_dim_multi_dim_mesh(self):
        mesh = init_device_mesh(self.device_type, (2, 2, 2))
        input_data = torch.randn((8, 8, 8), device=self.device_type)

        sharding_src_dst_pairs_3d = [
            ([Shard(0), Shard(0), Shard(0)], [Shard(1), Shard(1), Shard(1)]),
            ([Shard(0), Shard(1), Shard(0)], [Shard(1), Shard(0), Shard(0)]),
            ([Shard(0), Shard(1), Shard(2)], [Shard(2), Shard(1), Shard(0)]),
            ([Shard(1), Shard(0), Shard(0)], [Replicate(), Shard(0), Shard(0)]),
            ([Shard(1), Replicate(), Shard(0)], [Replicate(), Shard(0), Shard(0)]),
            ([Shard(0), Shard(0), Shard(1)], [Shard(0), Shard(1), Shard(2)]),
        ]
        comm_counts_3d = [
            3,  # 2: S0 -> R, 1: S1 -> R, 0: S0 -> S1
            3,  # 2: S0 -> R, 1: S1 -> R, 0: S0 -> S1, 1: R -> S0, 2: R -> S0
            2,  # 2: S2 -> R, 0: S1 -> S2
            1,  # 0: S1 -> R
            2,  # 2: S0 -> R, 1: R -> S0, 2: R -> S0, 0: S1 -> R
            2,  # 2: S1 -> S2, 1: S0 -> S1
        ]

        comm_mode = CommDebugMode()
        for idx, (src_placement, dst_placement) in enumerate(sharding_src_dst_pairs_3d):
            expected_dt = distribute_tensor(input_data.clone(), mesh, dst_placement)
            sharded_dt = distribute_tensor(input_data, mesh, src_placement)

            with comm_mode:
                out_dt = sharded_dt.redistribute(mesh, dst_placement)

            self.assertEqual(out_dt.placements, expected_dt.placements)
            self.assertEqual(comm_mode.get_total_counts(), comm_counts_3d[idx])

            local_out_dt = out_dt.to_local()
            local_expected_dt = expected_dt.to_local()
            self.assertEqual(local_out_dt, local_expected_dt)


if __name__ == "__main__":
    run_tests()