# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import itertools
from copy import deepcopy
from typing import Dict, NamedTuple, Optional

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed._tensor import (
    DeviceMesh,
    distribute_tensor,
    DTensor,
    Replicate,
    Shard,
)
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    loss_parallel,
    parallelize_module,
    RowwiseParallel,
)
from torch.distributed.tensor.parallel.input_reshard import input_reshard
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    MLPModule,
    ModelArgs,
    NUM_DEVICES,
    skip_unless_torch_gpu,
    Transformer,
    with_comms,
)


c10d_functional = torch.ops.c10d_functional

reduce_scatter, all_gather, all_reduce = (
    c10d_functional.reduce_scatter_tensor,
    c10d_functional.all_gather_into_tensor,
    c10d_functional.all_reduce,
)


class ExpCommCounts(NamedTuple):
    fwd: Optional[Dict] = None
    bwd: Optional[Dict] = None
    optim: Optional[Dict] = None


class DistTensorParallelExampleTest(DTensorTestBase):
    def _check_module(self, m1, m2, check_grad=False):
        named_parameters = dict(m1.named_parameters())
        for name, param_m2 in m2.named_parameters():
            self.assertTrue(name in named_parameters)
            param_m1 = named_parameters[name]
            if check_grad:
                param_m2 = param_m2.grad
                param_m1 = param_m1.grad
            if isinstance(param_m2, DTensor):
                replicate = [Replicate()]
                param_m2 = param_m2.redistribute(
                    device_mesh=param_m2.device_mesh, placements=replicate
                ).to_local()
            self.assertEqual(param_m2, param_m1)

    def _test_mlp_training_e2e(
        self, is_seq_parallel=False, recompute_activation=False
    ):
        inp_size = [8, 10]
        # For sequence parallel each rank generates its own local input; otherwise
        # all tp ranks share the same input.
        rng_seed = self.rank if is_seq_parallel else 0
        torch.manual_seed(rng_seed)
        inp = torch.rand(*inp_size, device=self.device_type)
        model = MLPModule(self.device_type)
        model_tp = deepcopy(model)

        # Ensure both models are initialized the same way.
        self._check_module(model, model_tp)

        # Shard the module and initialize the optimizers.
        LR = 0.25
        device_mesh = DeviceMesh(
            self.device_type,
            torch.arange(0, NUM_DEVICES),
        )
        parallelize_plan = {
            "net1": (
                ColwiseParallel(input_layouts=Shard(0))
                if is_seq_parallel
                else ColwiseParallel()
            ),
            "net2": (
                RowwiseParallel(output_layouts=Shard(0))
                if is_seq_parallel
                else RowwiseParallel()
            ),
        }
        model_tp = parallelize_module(model_tp, device_mesh, parallelize_plan)
        if recompute_activation:
            model_tp = input_reshard(
                checkpoint_wrapper(
                    model_tp, checkpoint_impl=CheckpointImpl.NO_REENTRANT
                ),
                device_mesh,
                None if is_seq_parallel else 0,
            )
        optim = torch.optim.SGD(model.parameters(), lr=LR)
        optim_tp = torch.optim.SGD(model_tp.parameters(), lr=LR)

        output = model(inp)
        output.sum().backward()

        comm_mode = CommDebugMode()
        with comm_mode:
            output_tp = model_tp(inp)
            output_tp.sum().backward()
        self.assertEqual(output, output_tp)

        if is_seq_parallel:
            self.assertEqual(
                comm_mode.get_comm_counts()[c10d_functional.all_gather_into_tensor], 2
            )
            self.assertEqual(
                comm_mode.get_comm_counts()[c10d_functional.reduce_scatter_tensor], 1
            )
        else:
            self.assertEqual(
                comm_mode.get_comm_counts()[c10d_functional.all_reduce], 1
            )

        if is_seq_parallel:
            # Sum gradients from different ranks, since inputs
            # are different across ranks for sequence parallel.
            dist.all_reduce(model.net1.weight.grad)
            dist.all_reduce(model.net1.bias.grad)
            dist.all_reduce(model.net2.weight.grad)
            dist.all_reduce(model.net2.bias.grad)

        # Ensure gradients are the same.
        self._check_module(model, model_tp, check_grad=True)

        optim.step()
        optim_tp.step()

        # Ensure model weights are still the same after the update.
        # Due to the trick we use for Partial aggregation, we only check the weight
        # when local_rank = 0.
        self._check_module(model, model_tp)

        inp = torch.rand(*inp_size, device=self.device_type)
        output = model(inp)
        output_tp = model_tp(inp)
        self.assertEqual(output, output_tp)

    def _test_mlp_inference(self, device_mesh):
        inp_size = [8, 10]
        # Ensure all tp ranks have the same input.
        torch.manual_seed(0)
        inp = torch.rand(*inp_size, device=self.device_type)
        model = MLPModule(self.device_type)
        model_tp = deepcopy(model)

        # Ensure both models are initialized the same way.
        self._check_module(model, model_tp)

        # Shard the module for inference.
        parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
        }
        model_tp = parallelize_module(model_tp, device_mesh, parallelize_plan)

        output = model(inp)
        output_tp = model_tp(inp)
        self.assertEqual(output, output_tp)

    @with_comms
    @parametrize("is_seq_parallel", [True, False])
    # TODO: need to revisit input_reshard API about why it failed multi-gpu tests.
    # @parametrize("recompute_activation", [True, False])
    @parametrize("recompute_activation", [False])
    def test_mlp_training(self, is_seq_parallel, recompute_activation):
        self._test_mlp_training_e2e(
            is_seq_parallel=is_seq_parallel, recompute_activation=recompute_activation
        )

    @with_comms
    def test_mlp_inference(self):
        device_mesh = DeviceMesh(
            self.device_type,
            torch.arange(0, NUM_DEVICES),
        )
        with torch.inference_mode():
            self._test_mlp_inference(device_mesh)

    def _setup_single_gpu_model(self, model_args, dtype):
        return Transformer(model_args).to(device=self.device_type, dtype=dtype)

    def _setup_tp_model(self, model, is_seq_parallel, dtype):
        model_tp = deepcopy(model)
        self._check_module(model, model_tp)
        device_mesh = DeviceMesh(self.device_type, torch.arange(0, NUM_DEVICES))
        local_output_for_attn = dtype is torch.float64
        return Transformer.parallelize(
            model_tp,
            device_mesh,
            is_seq_parallel,
            local_output_for_attn=local_output_for_attn,
        )

    def _setup_optimizer(self, model, model_tp):
        # Use the same LR for the single-gpu and tensor-parallel optimizers.
        LR = 0.25
        optim = torch.optim.Adam(model.parameters(), lr=LR)
        optim_tp = torch.optim.Adam(model_tp.parameters(), lr=LR)
        return optim, optim_tp

    def _validate_fwd(
        self, model, model_tp, inp, expected_comms_dict=None, check_comms=True
    ):
        # Compare outputs on the same input.
        output = model(inp)
        with CommDebugMode() as comm_mode:
            output_tp = model_tp(inp)
        self.assertEqual(output, output_tp)
        if check_comms:
            self.assertDictEqual(comm_mode.get_comm_counts(), expected_comms_dict or {})
        return output, output_tp

    def _validate_bwd(
        self,
        model,
        model_tp,
        output,
        output_tp,
        expected_comms_dict=None,
        check_comms=True,
    ):
        # Ensure gradients are equal.
        output.sum().backward()
        with CommDebugMode() as comm_mode:
            output_tp.sum().backward()
        self._check_module(model, model_tp, check_grad=True)
        if check_comms:
            self.assertDictEqual(comm_mode.get_comm_counts(), expected_comms_dict or {})

    def _validate_optim_step(
        self,
        model,
        model_tp,
        optim,
        optim_tp,
        expected_comms_dict=None,
        check_comms=True,
    ):
        optim.step()

        # Ensure model weights are still the same after the update.
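        # Note: implicit_replication() (imported just below) lets plain
        # torch.Tensors that the optimizer step mixes with DTensor parameters be
        # treated as replicated DTensors, so optim_tp.step() can run on the
        # parallelized model.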
        from torch.distributed._tensor.experimental import implicit_replication

        with implicit_replication():
            with CommDebugMode() as comm_mode:
                optim_tp.step()
        self._check_module(model, model_tp)
        if check_comms:
            self.assertDictEqual(comm_mode.get_comm_counts(), expected_comms_dict or {})

    @staticmethod
    def _thaw_params(thaw_params, model, model_tp):
        if not thaw_params:
            return
        for target_model in [model, model_tp]:
            for n, p in target_model.named_parameters():
                if n not in thaw_params:
                    p.requires_grad_(False)

    @with_comms
    @skip_unless_torch_gpu
    @parametrize("is_seq_parallel", [True, False])
    @parametrize("dtype", [torch.float64, torch.float32])
    def test_transformer_training(self, is_seq_parallel, dtype: torch.dtype):
        EXP_BASE_CC = ExpCommCounts(
            fwd={all_reduce: 6, all_gather: 1}, bwd={all_reduce: 9}
        )
        EXP_SEQ_PARALLEL_CC = ExpCommCounts(
            fwd={reduce_scatter: 6, all_gather: 6},
            bwd={reduce_scatter: 5, all_gather: 6},
            optim={all_reduce: 30},
        )

        # Disable dropout in the test since we cannot reproduce the same random
        # behaviors when comparing single-gpu models with multi-gpu models.
        model_args = ModelArgs(dropout_p=0.0)

        model = self._setup_single_gpu_model(
            model_args, dtype
        )  # Step 1: Initialize the single-gpu model.
        model_tp = self._setup_tp_model(
            model, is_seq_parallel, dtype
        )  # Step 2: Set up the tp model and place it onto the device mesh.
        optim, optim_tp = self._setup_optimizer(
            model, model_tp
        )  # Step 3: Set up optimizers for both models.

        # Initialize input and make sure all ranks have the same input.
        inp_size = [8, 8]  # [batch_size, seq_len]
        if is_seq_parallel:
            assert inp_size[1] % self.world_size == 0
        torch.manual_seed(0)

        # Run more iterations under float64, where accumulated numerical error
        # stays negligible when comparing against the single-gpu model.
        steps = 10 if dtype == torch.float64 else 1
        for _ in range(steps):
            inp = torch.randint(
                model_args.vocab_size, inp_size, device=self.device_type
            )

            expected_fwd_comms = (
                EXP_SEQ_PARALLEL_CC.fwd if is_seq_parallel else EXP_BASE_CC.fwd
            )
            output, output_tp = self._validate_fwd(
                model, model_tp, inp, expected_fwd_comms
            )

            expected_bwd_comms = (
                EXP_SEQ_PARALLEL_CC.bwd if is_seq_parallel else EXP_BASE_CC.bwd
            )
            self._validate_bwd(model, model_tp, output, output_tp, expected_bwd_comms)

            expected_optim_comms = (
                EXP_SEQ_PARALLEL_CC.optim if is_seq_parallel else EXP_BASE_CC.optim
            )
            self._validate_optim_step(
                model, model_tp, optim, optim_tp, expected_optim_comms
            )

    @with_comms
    @skip_unless_torch_gpu
    @parametrize(
        "thaw_params, is_seq_parallel, dtype, exp_cnts",
        [
            (
                None,  # all require grad, seq_parallel, float32 baseline
                True,
                torch.float32,
                ExpCommCounts(
                    bwd={reduce_scatter: 5, all_gather: 6}, optim={all_reduce: 30}
                ),
            ),
            (
                None,  # all require grad, no seq_parallel, float64 baseline
                False,
                torch.float64,
                ExpCommCounts(bwd={all_reduce: 9}),
            ),
            # test a subset of LayerNorm bwd output_masks
            (
                ("output.weight", "norm.weight", "norm.bias"),  # [False, True, True]
                True,
                torch.float32,
                ExpCommCounts(bwd={reduce_scatter: 1}, optim={all_reduce: 6}),
            ),
            (
                ("tok_embeddings.weight", "output.weight"),  # [True, False, False]
                True,
                torch.float32,
                ExpCommCounts(bwd={reduce_scatter: 5, all_gather: 5}),
            ),
            (
                (
                    "tok_embeddings.weight",
                    "output.weight",
                    "norm.weight",
                    "norm.bias",
                ),  # [True, True, True]
                True,
                torch.float32,
                ExpCommCounts(
                    bwd={reduce_scatter: 5, all_gather: 5}, optim={all_reduce: 6}
                ),
            ),
            (
                (
                    "tok_embeddings.weight",
                    "output.weight",
                    "norm.weight",
                    "norm.bias",
                    "layers.1.ffn_norm.weight",
                    "layers.1.ffn_norm.bias",
                ),  # a single TransformerBlock LayerNorm
                True,
                torch.float32,
                ExpCommCounts(
                    bwd={reduce_scatter: 5, all_gather: 5}, optim={all_reduce: 12}
                ),
            ),
            (
                (
                    "tok_embeddings.weight",
"layers.0.attention.wv.weight", "layers.0.feed_forward.w1.bias", "layers.1.ffn_norm.bias", "layers.1.feed_forward.w2.weight", "output.weight", ), # varied layer/param types True, torch.float32, ExpCommCounts( bwd={reduce_scatter: 5, all_gather: 5}, optim={all_reduce: 3} ), ), ], name_fn=lambda thaw, seq, dtype, *_: f"{'seq_parallel_' if seq else ''}" + f"{str(dtype).split('.')[-1]}_" + f"thaw_{'__'.join(sorted({n.rpartition('.')[0].replace('.', '_') for n in thaw})) if thaw else 'all'}", ) def test_transformer_req_grad(self, thaw_params, is_seq_parallel, dtype, exp_cnts): # Sample a subset of `requires_grad` patterns # disabling dropout to facilitate single gpu to multi-device comparison # disable weight-tying to enable more fine-tuning configurations model_args = ModelArgs(dropout_p=0.0, weight_tying=False) model = self._setup_single_gpu_model( model_args, dtype ) # Step 1: Initialize single-gpu models. model_tp = self._setup_tp_model( model, is_seq_parallel, dtype ) # Step 2: Setup tp model, place onto device mesh. optim, optim_tp = self._setup_optimizer( model, model_tp ) # Step 3: Setup optimizers for both models DistTensorParallelExampleTest._thaw_params( thaw_params, model, model_tp ) # Step 4: set `requires_grad` patterns # Initialize input and make sure all ranks have the same input. inp_size = [8, 8] # [batch_size, seq_len] if is_seq_parallel: assert inp_size[1] % self.world_size == 0 torch.manual_seed(0) inp = torch.randint(model_args.vocab_size, inp_size, device=self.device_type) output, output_tp = self._validate_fwd(model, model_tp, inp, check_comms=False) self._validate_bwd( model, model_tp, output, output_tp, exp_cnts.bwd, check_comms=True ) self._validate_optim_step( model, model_tp, optim, optim_tp, exp_cnts.optim, check_comms=True ) @with_comms def test_weight_tying(self): class TestModule(torch.nn.Module): def __init__(self) -> None: super().__init__() # Initialize different weights for embedding and fc. torch.manual_seed(1) self.embedding = torch.nn.Embedding(16, 8) torch.manual_seed(2) self.fc = torch.nn.Linear(8, 16) def forward(self, x): return self.fc(self.embedding(x)) model = TestModule().to(self.device_type) parallelize_plan = { "embedding": ColwiseParallel(), "fc": RowwiseParallel(), } device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) parallelize_module(model, device_mesh, parallelize_plan) input_size = [5] torch.manual_seed(0) inp = torch.randint(16, input_size, device=self.device_type) # Without weight tying. self.assertNotEqual( model.embedding.weight.to_local(), model.fc.weight.to_local() ) output = model(inp) output.sum().backward() self.assertNotEqual( model.embedding.weight.grad.to_local(), model.fc.weight.grad.to_local() ) model.zero_grad() # With weight tying. 
        model.fc.weight = model.embedding.weight
        self.assertEqual(model.embedding.weight, model.fc.weight)
        self.assertEqual(id(model.embedding.weight), id(model.fc.weight))
        output = model(inp)
        output.sum().backward()
        self.assertEqual(model.embedding.weight.grad, model.fc.weight.grad)
        self.assertEqual(id(model.embedding.weight.grad), id(model.fc.weight.grad))

    @with_comms
    def test_loss_parallel(self):
        device_mesh = self.build_device_mesh()
        comm_mode = CommDebugMode()

        channel_size, channel_dim = 16, 1
        test_setup = [
            (2, (8, channel_size), (8,)),  # calling aten.nll_loss_forward
            (3, (8, channel_size, 12), (8, 12)),  # calling aten.nll_loss2d_forward
        ]
        weight = torch.rand(channel_size, device=self.device_type)
        for input_ndim, input_size, target_size in test_setup:
            x = torch.rand(*input_size, device=self.device_type, requires_grad=True)
            target = torch.randint(channel_size, target_size, device=self.device_type)

            shard_dims = list(range(input_ndim))
            reductions = ["none", "mean", "sum"]
            for shard_dim, reduction in itertools.product(shard_dims, reductions):
                dist_x = distribute_tensor(x, device_mesh, [Shard(shard_dim)])
                y = F.cross_entropy(x, target, weight, reduction=reduction)
                with loss_parallel():
                    if shard_dim == channel_dim:
                        with comm_mode:
                            dist_y = F.cross_entropy(
                                dist_x, target, weight, reduction=reduction
                            )
                        self.assertEqual(comm_mode.get_total_counts(), 3)
                        self.assertEqual(
                            comm_mode.get_comm_counts()[c10d_functional.all_reduce],
                            3,
                        )
                        self.assertTrue(dist_y.placements[0].is_replicate())
                        self.assertEqual(dist_y.to_local(), y)

                        with comm_mode:
                            if reduction == "none":
                                y.sum().backward()
                                dist_y.sum().backward()
                            else:
                                y.backward()
                                dist_y.backward()
                        self.assertEqual(comm_mode.get_total_counts(), 0)
                        self.assertTrue(
                            dist_x.grad.placements[0].is_shard(shard_dim)
                        )
                        self.assertEqual(dist_x.grad.full_tensor(), x.grad)
                        x.grad.zero_()
                    else:
                        with self.assertRaisesRegex(
                            ValueError,
                            "loss_parallel",
                        ):
                            dist_y = F.cross_entropy(
                                dist_x, target, reduction=reduction
                            )


instantiate_parametrized_tests(DistTensorParallelExampleTest)

if __name__ == "__main__":
    run_tests()