# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

from copy import deepcopy

import torch
import torch.nn as nn
from torch.distributed._tensor import (
    distribute_tensor,
    DTensor,
    init_device_mesh,
    Replicate,
    Shard,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import parallelize_module
from torch.distributed.tensor.parallel.style import (
    ColwiseParallel,
    PrepareModuleInput,
    PrepareModuleOutput,
    RowwiseParallel,
    SequenceParallel,
)
from torch.distributed.tensor.placement_types import _Partial
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    NUM_DEVICES,
    RMSNormPython,
    with_comms,
)


c10d_functional = torch.ops.c10d_functional


class TensorParallelStyleTest(DTensorTestBase):
    @property
    def world_size(self):
        return NUM_DEVICES

    @with_comms
    def test_colwise_parallel_style(self):
        mesh = init_device_mesh(self.device_type, (self.world_size,))

        comm_mode = CommDebugMode()
        tensor = torch.rand(8, 16, device=self.device_type, requires_grad=True)
        model = nn.Linear(16, 16, device=self.device_type)

        default_col_parallel = ColwiseParallel()
        colwise_mod = parallelize_module(deepcopy(model), mesh, default_col_parallel)
        with comm_mode:
            out = colwise_mod(tensor)
            # ensure output shard on the last dim
            self.assertEqual(out.shape, (8, 16 // self.world_size))
            # ensure no communication happened in fwd
            self.assertEqual(comm_mode.get_total_counts(), 0)

            out.sum().backward()
            # allreduce in bwd
            self.assertEqual(comm_mode.get_comm_counts()[c10d_functional.all_reduce], 1)
            self.assertEqual(comm_mode.get_total_counts(), 1)

        sharded_col_parallel = ColwiseParallel(input_layouts=Shard(0))
        colwise_mod = parallelize_module(deepcopy(model), mesh, sharded_col_parallel)
        with comm_mode:
            out = colwise_mod(tensor)
            # ensure output shard on the last dim
            self.assertEqual(out.shape, (8 * self.world_size, 16 // self.world_size))
            # allgather in fwd
            self.assertEqual(
                comm_mode.get_comm_counts()[c10d_functional.all_gather_into_tensor], 1
            )
            self.assertEqual(comm_mode.get_total_counts(), 1)

            out.sum().backward()
            # reduce_scatter in bwd
            self.assertEqual(
                comm_mode.get_comm_counts()[c10d_functional.reduce_scatter_tensor], 1
            )
            self.assertEqual(comm_mode.get_total_counts(), 2)

    @with_comms
    def test_colwise_parallel_embedding(self):
        mesh = init_device_mesh(self.device_type, (self.world_size,))

        comm_mode = CommDebugMode()
        tensor = torch.arange(8, device=self.device_type).reshape(4, 2)
        model = nn.Embedding(16, 16, device=self.device_type)

        default_col_parallel = ColwiseParallel()
        colwise_mod = parallelize_module(deepcopy(model), mesh, default_col_parallel)
        with comm_mode:
            out = colwise_mod(tensor)
            # ensure output shard on the last dim
            self.assertEqual(out.shape, (4, 2, 16 // self.world_size))
            # ensure no communication happened in fwd
            self.assertEqual(comm_mode.get_total_counts(), 0)

            out.sum().backward()
            # no comm in bwd
            self.assertEqual(comm_mode.get_total_counts(), 0)

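    # NOTE: the rowwise tests below mirror the colwise ones above: the default
    # RowwiseParallel consumes an input already sharded on the last dim and
    # allreduces the partial results in fwd (no comm in bwd), while
    # output_layouts=Shard(0) turns that into a reduce_scatter in fwd plus an
    # allgather in bwd, which is exactly what the comm counts below assert.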
    @with_comms
    def test_rowwise_parallel_style(self):
        mesh = init_device_mesh(self.device_type, (self.world_size,))

        comm_mode = CommDebugMode()
        tensor = torch.rand(
            8, 16 // self.world_size, device=self.device_type, requires_grad=True
        )
        model = nn.Linear(16, 16, device=self.device_type)

        default_row_parallel = RowwiseParallel()
        rowwise_mod = parallelize_module(deepcopy(model), mesh, default_row_parallel)
        with comm_mode:
            out = rowwise_mod(tensor)
            # ensure output replicated
            self.assertEqual(out.shape, (8, 16))
            # allreduce in fwd
            self.assertEqual(comm_mode.get_comm_counts()[c10d_functional.all_reduce], 1)
            self.assertEqual(comm_mode.get_total_counts(), 1)

            out.sum().backward()
            # no op in bwd
            self.assertEqual(comm_mode.get_total_counts(), 1)

        sharded_row_parallel = RowwiseParallel(output_layouts=Shard(0))
        rowwise_mod = parallelize_module(deepcopy(model), mesh, sharded_row_parallel)
        with comm_mode:
            out = rowwise_mod(tensor)
            # ensure output sharded on dim 0
            self.assertEqual(out.shape, (8 // self.world_size, 16))
            # reduce_scatter in fwd
            self.assertEqual(
                comm_mode.get_comm_counts()[c10d_functional.reduce_scatter_tensor], 1
            )
            self.assertEqual(comm_mode.get_total_counts(), 1)

            out.sum().backward()
            # allgather in bwd
            self.assertEqual(
                comm_mode.get_comm_counts()[c10d_functional.all_gather_into_tensor], 1
            )
            self.assertEqual(comm_mode.get_total_counts(), 2)

    @with_comms
    def test_rowwise_parallel_embedding(self):
        mesh = init_device_mesh(self.device_type, (self.world_size,))

        comm_mode = CommDebugMode()
        tensor = torch.arange(8, device=self.device_type).reshape(4, 2)
        model = nn.Embedding(16, 16, device=self.device_type)

        rowwise_mod = parallelize_module(
            deepcopy(model), mesh, RowwiseParallel(input_layouts=Replicate())
        )
        with comm_mode:
            out = rowwise_mod(tensor)
            # ensure output replicated
            self.assertEqual(out.shape, (4, 2, 16))
            # ensure allreduce communication happened in fwd
            self.assertEqual(comm_mode.get_total_counts(), 1)
            self.assertEqual(comm_mode.get_comm_counts()[c10d_functional.all_reduce], 1)

            out.sum().backward()
            # no comm in bwd
            self.assertEqual(comm_mode.get_total_counts(), 1)

        sharded_row_parallel = RowwiseParallel(
            input_layouts=Replicate(), output_layouts=Shard(1)
        )

        rowwise_mod = parallelize_module(deepcopy(model), mesh, sharded_row_parallel)

        inp_indices = torch.arange(8, device=self.device_type)
        with comm_mode:
            out = rowwise_mod(inp_indices)
            # ensure output shard on the last dim
            self.assertEqual(out.shape, (8, 16 // self.world_size))
            # reduce scatter in fwd
            self.assertEqual(comm_mode.get_total_counts(), 1)
            self.assertEqual(
                comm_mode.get_comm_counts()[c10d_functional.reduce_scatter_tensor], 1
            )
            out.sum().backward()
            # allgather comm in bwd
            self.assertEqual(comm_mode.get_total_counts(), 2)
            self.assertEqual(
                comm_mode.get_comm_counts()[c10d_functional.all_gather_into_tensor], 1
            )

    @with_comms
    def test_prepare_module_input(self):
        mesh = init_device_mesh(self.device_type, (self.world_size,))

        tensor = torch.ones(2, 16, device=self.device_type)
        expected_tensor = torch.ones(2 * self.world_size, 16, device=self.device_type)
        prepare_inp_style = PrepareModuleInput(
            input_layouts=Shard(0), desired_input_layouts=Replicate()
        )

        model = nn.Identity()
        allgather_mod = parallelize_module(model, mesh, prepare_inp_style)
        output = allgather_mod(tensor).full_tensor()
        self.assertEqual(output, expected_tensor)

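    # NOTE: PrepareModuleInput annotates module inputs with input_layouts and
    # redistributes them to desired_input_layouts; a None entry leaves that input
    # untouched. The test below covers multiple positional inputs as well as the
    # error paths for mismatched layout/input lengths.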
    @with_comms
    def test_prepare_module_input_multiple_inputs(self):
        mesh = init_device_mesh(self.device_type, (self.world_size,))

        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(8, 8)

            def forward(self, x, y):
                return self.linear(x) + y

        # Raise assertion error if input_layouts and desired_input_layouts do not have same length.
        test_mod = TestModule().to(self.device_type)
        with self.assertRaisesRegex(
            AssertionError,
            "input_layouts and desired_input_layouts should have same length!",
        ):
            prepare_inps_dimension_mismatch = PrepareModuleInput(
                input_layouts=Shard(0), desired_input_layouts=(Replicate(), None)
            )
        # Raise assertion error if module inputs and input_layouts do not have same length.
        prepare_inps_short_dimension = PrepareModuleInput(
            input_layouts=Shard(0), desired_input_layouts=Replicate()
        )
        parallelize_module(test_mod.linear, mesh, ColwiseParallel())
        parallelize_module(test_mod, mesh, prepare_inps_short_dimension)
        with self.assertRaisesRegex(
            ValueError, "module inputs and input_layouts should have same length!"
        ):
            output = test_mod(
                torch.randn(2, 8, device=self.device_type),
                torch.ones(
                    self.world_size * 2, 8 // self.world_size, device=self.device_type
                ),
            )

        test_mod = TestModule().to(self.device_type)
        prepare_inps = PrepareModuleInput(
            input_layouts=(Shard(0), None), desired_input_layouts=(Replicate(), None)
        )

        parallelize_module(test_mod.linear, mesh, ColwiseParallel())
        parallelize_module(test_mod, mesh, prepare_inps)
        output = test_mod(
            torch.randn(2, 8, device=self.device_type),
            torch.ones(
                self.world_size * 2, 8 // self.world_size, device=self.device_type
            ),
        )
        self.assertEqual(output.shape, (self.world_size * 2, 8 // self.world_size))

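    # NOTE: keyword arguments are prepared via the input_kwarg_layouts /
    # desired_input_kwarg_layouts dicts; the last case below checks that a kwarg
    # which is already a DTensor goes through the same redistribution path
    # (same comm count) as a plain torch.Tensor.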
    @with_comms
    def test_prepare_module_kwargs_input(self):
        mesh = init_device_mesh(self.device_type, (self.world_size,))

        class TestKwargModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(8, 8)

            def forward(self, x, *, y, z=2):
                return self.linear(x) + y + z

        test_mod = TestKwargModule().to(self.device_type)
        prepare_inps_simple = PrepareModuleInput(
            input_kwarg_layouts={"y": Shard(0)},
            desired_input_kwarg_layouts={"y": Replicate()},
        )
        parallelize_module(
            test_mod.linear, mesh, ColwiseParallel(use_local_output=False)
        )
        parallelize_module(test_mod, mesh, prepare_inps_simple)

        comm_mode = CommDebugMode()
        with comm_mode:
            output = test_mod(
                torch.randn(1 * self.world_size, 8, device=self.device_type),
                y=torch.ones(1, 8, device=self.device_type),
            )

        self.assertEqual(comm_mode.get_total_counts(), 1)
        self.assertEqual(output.shape, (1 * self.world_size, 8))

        class TestKwargOnlyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(8, 8)

            def forward(self, *, x, y=2, z=None):
                return self.linear(x) + y + z

        test_kwonly_mod = TestKwargOnlyModule().to(self.device_type)
        prepare_inps_simple = PrepareModuleInput(
            input_kwarg_layouts={"x": Shard(0), "z": Shard(0)},
            desired_input_kwarg_layouts={"x": Replicate(), "z": Replicate()},
        )
        parallelize_module(
            test_kwonly_mod.linear, mesh, ColwiseParallel(use_local_output=False)
        )
        parallelize_module(test_kwonly_mod, mesh, prepare_inps_simple)

        with comm_mode:
            output = test_kwonly_mod(
                x=torch.randn(1, 8, device=self.device_type),
                z=torch.ones(1, 8, device=self.device_type),
            )

        self.assertEqual(comm_mode.get_total_counts(), 2)
        self.assertEqual(output.shape, (1 * self.world_size, 8))

        # test the case where x is a DTensor
        x_dt = DTensor.from_local(
            torch.randn(1, 8, device=self.device_type), mesh, [Shard(0)]
        )
        with comm_mode:
            output = test_kwonly_mod(
                x=x_dt, z=torch.ones(1, 8, device=self.device_type)
            )

        self.assertEqual(comm_mode.get_total_counts(), 2)
        self.assertEqual(output.shape, (1 * self.world_size, 8))

    @with_comms
    def test_prepare_module_output(self):
        mesh = init_device_mesh(self.device_type, (self.world_size,))

        tensor = torch.ones(8, 16, device=self.device_type)
        expected_tensor = torch.ones(8 // self.world_size, 16, device=self.device_type)
        prepare_out_style = PrepareModuleOutput(
            output_layouts=Replicate(), desired_output_layouts=Shard(0)
        )

        model = nn.Identity()
        chunk_mod = parallelize_module(model, mesh, prepare_out_style)
        output = chunk_mod(tensor)
        self.assertEqual(output, expected_tensor)

    @with_comms
    def test_sequence_parallel_style(self):
        mesh = init_device_mesh(self.device_type, (self.world_size,))

        comm_mode = CommDebugMode()
        batch, N, embedding_dim = 20, 8, 12

        global_input = torch.rand(
            batch,
            N * self.world_size,
            embedding_dim,
            device=self.device_type,
            requires_grad=True,
        )
        sharded_input = distribute_tensor(global_input, mesh, [Shard(1)])

        # test LayerNorm
        for elementwise_affine in [True, False]:
            norm = nn.LayerNorm(
                embedding_dim,
                elementwise_affine=elementwise_affine,
                device=self.device_type,
            )
            sp_norm = parallelize_module(deepcopy(norm), mesh, SequenceParallel())

            output = norm(global_input)
            output.sum().backward()

            with comm_mode:
                sharded_out = sp_norm(sharded_input)
                grad_out = torch.ones_like(sharded_out)
                sharded_out.backward(grad_out)
                self.assertIsInstance(sharded_out, DTensor)
                self.assertEqual(sharded_out.placements, (Shard(1),))
                self.assertEqual(comm_mode.get_total_counts(), 0)
                self.assertEqual(
                    comm_mode.get_comm_counts()[c10d_functional.all_reduce], 0
                )
                if elementwise_affine:
                    self.assertEqual(sp_norm.weight.grad.placements, (_Partial(),))
                    self.assertEqual(sp_norm.bias.grad.placements, (_Partial(),))

                self.assertEqual(sharded_out.full_tensor(), output)

        # test RMSNorm
        rmsnorm = RMSNormPython(embedding_dim).to(self.device_type)
        sp_rmsnorm = parallelize_module(deepcopy(rmsnorm), mesh, SequenceParallel())

        output = rmsnorm(global_input)
        output.sum().backward()

        with comm_mode:
            sharded_out = sp_rmsnorm(sharded_input)
            grad_out = torch.ones_like(sharded_out)
            sharded_out.backward(grad_out)
            self.assertIsInstance(sharded_out, DTensor)
            self.assertEqual(sharded_out.placements, (Shard(1),))
            self.assertEqual(sp_rmsnorm.weight.grad.placements, (_Partial(),))
            self.assertEqual(comm_mode.get_total_counts(), 0)
            self.assertEqual(comm_mode.get_comm_counts()[c10d_functional.all_reduce], 0)

            self.assertEqual(sharded_out.full_tensor(), output)

        # test dropout
        dropout = nn.Dropout(0.5).to(self.device_type)
        sp_dropout = parallelize_module(deepcopy(dropout), mesh, SequenceParallel())

        output = dropout(global_input)
        output.sum().backward()
        with comm_mode:
            sharded_out = sp_dropout(sharded_input)
            grad_out = torch.ones_like(sharded_out)
            sharded_out.backward(grad_out)
            self.assertIsInstance(sharded_out, DTensor)
            self.assertEqual(sharded_out.placements, (Shard(1),))
            self.assertEqual(comm_mode.get_total_counts(), 0)

        # test sharded on non-sequence dim input
        sharded_batch_input = distribute_tensor(global_input, mesh, [Shard(0)])
        rmsnorm = RMSNormPython(embedding_dim).to(self.device_type)
        sp_rmsnorm = parallelize_module(deepcopy(rmsnorm), mesh, SequenceParallel())

        with comm_mode:
            sharded_out = sp_rmsnorm(sharded_batch_input)
            grad_out = torch.ones_like(sharded_out)
            sharded_out.backward(grad_out)
            self.assertIsInstance(sharded_out, DTensor)
            # output still sharded on sequence dimension
            self.assertEqual(sharded_out.placements, (Shard(1),))
            self.assertEqual(sp_rmsnorm.weight.grad.placements, (_Partial(),))
            # communication happens in both fwd/bwd to redistribute input
            self.assertEqual(comm_mode.get_total_counts(), 2)


if __name__ == "__main__":
    run_tests()