# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import copy
import itertools
from pprint import pformat
from typing import NamedTuple

import torch
from torch.distributed._tensor import (
    DeviceMesh,
    distribute_module,
    distribute_tensor,
    DTensor,
)
from torch.distributed._tensor.placement_types import Replicate, Shard
from torch.distributed.tensor._ops.utils import is_tensor_partial, normalize_dim
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
    SequenceParallel,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_unless_torch_gpu,
    with_comms,
)


funcol = torch.ops.c10d_functional


class DistMathOpsTest(DTensorTestBase):
    def _check_module(self, m1, m2, check_grad=False):
        named_parameters = dict(m1.named_parameters())
        for name, param_m2 in m2.named_parameters():
            self.assertTrue(name in named_parameters)
            param_m1 = named_parameters[name]
            if check_grad:
                param_m2 = param_m2.grad
                param_m1 = param_m1.grad
            if isinstance(param_m2, DTensor):
                replicate = [Replicate()]
                param_m2 = param_m2.redistribute(
                    device_mesh=param_m2.device_mesh, placements=replicate
                ).to_local()
            self.assertEqual(param_m2, param_m1)

    def linear_op_reductions(self, op_str):
        device_mesh = self.build_device_mesh()
        shard_spec = [Shard(0)]

        tensor = torch.randn(12, 8, 8)
        # TODO: check `all` correctness and test `all` on a bool tensor
        if op_str in ("any",):
            # test out a bool tensor for `any`
            tensor = tensor < 0
        dtensor = distribute_tensor(tensor, device_mesh, shard_spec)

        op = getattr(tensor, op_str)
        op_dt = getattr(dtensor, op_str)

        keep_dim_or_not = [True, False, None]
        for dim in range(tensor.ndim):
            for keep_dim in keep_dim_or_not:
                args = (dim, keep_dim) if keep_dim is not None else (dim,)
                if op_str in ("max", "min"):
                    # min and max return a tuple when a dim is specified
                    dim_reduced_tensor, _ = op(*args)
                    dt_reduced, _ = op_dt(*args)
                else:
                    dim_reduced_tensor = op(*args)
                    dt_reduced = op_dt(*args)
                dt_dim_reduced_tensor = dt_reduced.full_tensor()
                self.assertEqual(dt_dim_reduced_tensor, dim_reduced_tensor)

        full_reduced_tensor = op()
        dt_full_reduced = op_dt().full_tensor()
        self.assertEqual(dt_full_reduced, full_reduced_tensor)

    @with_comms
    def test_linear_op_reductions(self):
        for op_str in ("all", "sum", "prod", "max", "min", "any"):
            self.linear_op_reductions(op_str)

    @with_comms
    @skip_unless_torch_gpu
    def test_mean(self):
        self.linear_op_reductions("mean")

    # TODO: the forward test can be removed once test_softmax_with_bwd passes on CPU
    @with_comms
    def test_softmax_fwd(self):
        device_mesh = self.build_device_mesh()
        x = torch.rand(8, 12, 16, device=self.device_type)
        dims = range(3)  # used to convert -1 to the actual dim
        softmax_dims = [-1, 0, 1, 2]
        shard_dims = [-1, 0, 1, 2]
        test_list = list(itertools.product(softmax_dims, shard_dims))

        for softmax_dim, shard_dim in test_list:
            local_y = torch.nn.functional.softmax(
                x, dim=softmax_dim, dtype=torch.float32
            )
            dist_x = distribute_tensor(x, device_mesh, [Shard(shard_dim)])
            dist_y = torch.nn.functional.softmax(
                dist_x, dim=softmax_dim, dtype=torch.float32
            )
            shard_dim = normalize_dim(shard_dim, dist_x.ndim)
            if dims[shard_dim] == dims[softmax_dim]:
                self.assertTrue(dist_y.placements[0].is_replicate())
                self.assertEqual(dist_y.to_local(), local_y)
            else:
                self.assertTrue(dist_y.placements[0].is_shard(dim=shard_dim))
                self.assertEqual(dist_y.full_tensor(), local_y)
    # TODO: get test_softmax_with_bwd to pass on CPU
    # DTensor's _softmax_backward_data produces wrong results on CPU for certain dims.
    # fail_on_cpu_list = [(0, -1), (1, -1)]
    @with_comms
    @skip_unless_torch_gpu
    def test_softmax_with_bwd(self):
        device_mesh = self.build_device_mesh()
        dims = range(3)  # used to convert -1 to the actual dim
        softmax_dims = [-1, 0, 1, 2]
        shard_dims = [-1, 0, 1, 2]
        test_list = list(itertools.product(softmax_dims, shard_dims))

        for params in test_list:
            softmax_dim, shard_dim = params
            x = torch.rand(8, 12, 16, device=self.device_type, requires_grad=True)
            self.assertTrue(x.requires_grad)
            local_y = torch.nn.functional.softmax(
                x, dim=softmax_dim, dtype=torch.float32
            ).sum()
            local_y.backward()

            dist_x = distribute_tensor(x, device_mesh, [Shard(shard_dim)])
            self.assertTrue(dist_x.requires_grad)
            dist_softmax = dist_x.softmax(dim=softmax_dim)
            shard_dim = normalize_dim(shard_dim, dist_x.ndim)
            if dims[softmax_dim] == dims[shard_dim]:
                self.assertTrue(dist_softmax.placements[0].is_replicate())
            else:
                self.assertTrue(dist_softmax.placements[0].is_shard(dim=shard_dim))

            dist_y = dist_softmax.sum()
            if dims[softmax_dim] == dims[shard_dim]:
                self.assertTrue(dist_y.placements[0].is_replicate())
            else:
                self.assertTrue(dist_y.placements[0].is_partial())
                dist_y = dist_y.redistribute(device_mesh, [Replicate()])
            self.assertEqual(dist_y.to_local(), local_y)

            self.assertIsNone(dist_x.grad)
            dist_y.backward()
            self.assertIsNotNone(dist_x.grad)
            if dims[softmax_dim] == dims[shard_dim]:
                self.assertTrue(dist_x.grad.placements[0].is_replicate())
            else:
                self.assertTrue(dist_x.grad.placements[0].is_shard(dim=shard_dim))
            self.assertEqual(dist_x.grad.full_tensor(), x.grad)

    @with_comms
    @skip_unless_torch_gpu
    def test_nll_loss_and_cross_entropy(self):
        device_mesh = self.build_device_mesh()
        comm_mode = CommDebugMode()

        channel_size, channel_dim = 16, 1
        test_setup = [
            (2, (8, channel_size), (8,)),  # calling aten.nll_loss_forward
            (3, (8, channel_size, 12), (8, 12)),  # calling aten.nll_loss2d_forward
        ]
        for input_ndim, input_size, target_size in test_setup:
            x = torch.rand(*input_size, device=self.device_type, requires_grad=True)
            target = torch.randint(channel_size, target_size, device=self.device_type)
            dist_target = distribute_tensor(target, device_mesh, [Replicate()])

            shard_dims = list(range(input_ndim))
            reductions = ["none", "mean", "sum"]
            # Compared with nll_loss, cross_entropy additionally calls log_softmax first.
            # Test them together since the test code can be reused.
            loss_functions = [
                torch.nn.functional.nll_loss,
                torch.nn.functional.cross_entropy,
            ]
            for shard_dim, reduction, loss_fn in itertools.product(
                shard_dims, reductions, loss_functions
            ):
                dist_x = distribute_tensor(x, device_mesh, [Shard(shard_dim)])
                y = loss_fn(x, target, reduction=reduction)
                if reduction == "none":
                    y.sum().backward()
                else:
                    y.backward()

                with comm_mode:
                    dist_y = loss_fn(dist_x, dist_target, reduction=reduction)
                    if shard_dim == channel_dim:
                        self.assertEqual(comm_mode.get_total_counts(), 1)
                        self.assertEqual(
                            comm_mode.get_comm_counts()[funcol.all_gather_into_tensor],
                            1,
                        )
                        self.assertTrue(dist_y.placements[0].is_replicate())
                        self.assertEqual(dist_y.to_local(), y)
                    else:
                        self.assertEqual(comm_mode.get_total_counts(), 0)
                        if reduction == "none":
                            output_shard_dim = (
                                shard_dim if shard_dim < channel_dim else shard_dim - 1
                            )
                            self.assertTrue(
                                dist_y.placements[0].is_shard(dim=output_shard_dim)
                            )
                        else:
                            self.assertTrue(dist_y.placements[0].is_partial())
                        self.assertEqual(dist_y.full_tensor(), y)

                    if reduction == "none":
                        dist_y.sum().backward()
                    else:
                        dist_y.backward()
                    if shard_dim == channel_dim:
                        self.assertTrue(dist_x.grad.placements[0].is_replicate())
                        self.assertEqual(dist_x.grad.to_local(), x.grad)
                    else:
                        self.assertTrue(
                            dist_x.grad.placements[0].is_shard(dim=shard_dim)
                        )
                        self.assertEqual(dist_x.grad.full_tensor(), x.grad)
                    x.grad.zero_()

    @with_comms
    def test_shard_math_ops(self):
        mesh_shape = (2, self.world_size // 2)
        mesh = DeviceMesh(
            self.device_type,
            torch.arange(self.world_size).reshape(*mesh_shape),
        )
        global_tensor = torch.ones(4, 4)
        double_shard_tensor = distribute_tensor(
            global_tensor, mesh, [Shard(0), Shard(0)]
        )
        fully_shard_tensor = distribute_tensor(
            global_tensor, mesh, [Shard(0), Shard(1)]
        )

        for op in [torch.add, torch.sub, torch.mul, torch.div]:
            expect_rs = op(global_tensor, 2)
            double_shard_full_tensor = op(double_shard_tensor, 2).full_tensor()
            self.assertEqual(double_shard_full_tensor, expect_rs)

            fully_shard_full_tensor = op(fully_shard_tensor, 2).full_tensor()
            self.assertEqual(fully_shard_full_tensor, expect_rs)

    @with_comms
    def test_layer_norm_fwd(self):
        device_mesh = self.build_device_mesh()

        # NLP example from pytorch docs
        # https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
        batch, sentence_length, embedding_dim = 20, 5, 10
        x = torch.rand(batch, sentence_length, embedding_dim, device=self.device_type)

        norm_shape_idx_list = list(range(x.ndim))
        shard_dims = [-1, 0, 1, 2]
        elementwise_affine_list = [False, True]
        test_config_list = list(
            itertools.product(shard_dims, norm_shape_idx_list, elementwise_affine_list)
        )

        # normalized shape is a torch.Size object
        for shard_dim, norm_idx, elementwise_affine in test_config_list:
            normalized_shape = x.shape[norm_idx:]
            layer_norm = torch.nn.LayerNorm(
                normalized_shape,
                elementwise_affine=elementwise_affine,
                device=self.device_type,
            )
            layer_norm_local = copy.deepcopy(layer_norm).to(self.device_type)

            def _replicate_fn(name, module, device_mesh):
                for name, param in module.named_parameters():
                    if name in ["weight", "bias"]:
                        param_dist = torch.nn.Parameter(
                            distribute_tensor(param, device_mesh, [Replicate()])
                        )
                        module.register_parameter(name, param_dist)

            layer_norm_dist = distribute_module(layer_norm, device_mesh, _replicate_fn)

            x_local = x
            x_dist = distribute_tensor(x, device_mesh, [Shard(shard_dim)])

            y_local = layer_norm_local(x_local)
            # make sure that forward layer norm does not introduce extra collectives
            comm_mode = CommDebugMode()
            with comm_mode:
                y_dist = layer_norm_dist(x_dist)
            self.assertLessEqual(
                comm_mode.get_total_counts(),
                1,  # TODO: This should be 0!
                f"comm count={comm_mode.get_total_counts()}, "
                f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}",
            )

            from torch.distributed._tensor.placement_types import TensorMeta

            dtensor_meta = y_dist._spec.tensor_meta
            assert isinstance(dtensor_meta, TensorMeta)
            # make sure the right shape is propagated in sharding prop
            self.assertEqual(y_local.shape, dtensor_meta.shape)
            self.assertEqual(y_local, y_dist.full_tensor())

    @with_comms
    def test_layer_norm_bwd(self):
        device_mesh = self.build_device_mesh()

        # NLP example from pytorch docs
        # https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
        batch, sentence_length, embedding_dim = 20, 5, 10

        norm_shape_idx_list = list(range(3))
        shard_dims = [0, 1, 2]
        elementwise_affine_list = [False, True]
        test_config_list = list(
            itertools.product(shard_dims, norm_shape_idx_list, elementwise_affine_list)
        )

        # normalized shape is a torch.Size object
        for shard_dim, norm_idx, elementwise_affine in test_config_list:
            x = torch.rand(
                batch,
                sentence_length,
                embedding_dim,
                device=self.device_type,
                requires_grad=True,
            )
            normalized_shape = x.shape[norm_idx:]
            layer_norm = torch.nn.LayerNorm(
                normalized_shape,
                elementwise_affine=elementwise_affine,
                device=self.device_type,
            )
            layer_norm_local = copy.deepcopy(layer_norm).to(self.device_type)

            def _replicate_fn(name, module, device_mesh):
                for name, param in module.named_parameters():
                    if name in ["weight", "bias"]:
                        param_dist = torch.nn.Parameter(
                            distribute_tensor(param, device_mesh, [Replicate()])
                        )
                        module.register_parameter(name, param_dist)

            layer_norm_dist = distribute_module(layer_norm, device_mesh, _replicate_fn)

            if elementwise_affine:
                self.assertEqual(
                    layer_norm_local.weight, layer_norm_dist.weight.full_tensor()
                )
                self.assertEqual(
                    layer_norm_local.bias, layer_norm_dist.bias.full_tensor()
                )

            x_local = x.detach().clone().requires_grad_(True)
            x_dist = distribute_tensor(x, device_mesh, [Shard(shard_dim)])
            self.assertEqual(x_local, x_dist.full_tensor())

            y_local = layer_norm_local(x_local)
            # make sure that backward layer norm does not introduce extra collectives
            comm_mode = CommDebugMode()
            with comm_mode:
                y_dist = layer_norm_dist(x_dist)
                y_dist.sum().backward()

            expected_fwd_comm = 0 if shard_dim < norm_idx else 1
            self.assertEqual(
                sum(comm_mode.comm_module_counts["Global"]["forward"].values()),
                expected_fwd_comm,
                f"comm count={comm_mode.get_total_counts()}, "
                f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}",
            )

            self.assertEqual(y_local, y_dist.full_tensor())

            # backward step
            y_local.sum().backward()

            expected_bwd_comm = 0 if shard_dim < norm_idx else 1
            self.assertEqual(
                sum(comm_mode.comm_module_counts["Global"]["backward"].values()),
                expected_bwd_comm,
                f"comm count={comm_mode.get_total_counts()}, "
                f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}",
            )

            if elementwise_affine:
                # if input is sharded on any outer dimension, the gradient of weight
                # and bias should be Partial
                dim_map = x_dist._spec.dim_map
                outer_dims = range(norm_idx)
                needs_reduction = any(dim_map[d] >= 0 for d in outer_dims)
                self.assertEqual(
                    is_tensor_partial(layer_norm_dist.weight.grad._spec),
                    needs_reduction,
                )
                self.assertEqual(
                    is_tensor_partial(layer_norm_dist.bias.grad._spec),
                    needs_reduction,
                )
                self.assertEqual(
                    layer_norm_local.weight.grad,
                    layer_norm_dist.weight.grad.full_tensor(),
                )
                self.assertEqual(
                    layer_norm_local.bias.grad,
                    layer_norm_dist.bias.grad.full_tensor(),
                )
            self.assertEqual(x_local.grad, x_dist.grad.full_tensor())

    @with_comms
    def test_layer_norm_bwd_req_grad(self):
        device_mesh = self.build_device_mesh()

        batch, seq_len, embedding_dim, vocab_size = 8, 8, 10, 32

        # build our subtest configurations and filter out invalid ones
        class SubTest(NamedTuple):
            multidim_norm: bool
            elementwise_affine: bool
            emb_req_grad: bool
            ln_req_grad: bool
            out_req_grad: bool

        subtest_fails = {}
        valid_filter = lambda cfg: not (  # noqa: E731
            cfg.ln_req_grad and not cfg.elementwise_affine
        ) and any(cfg[2:])
        subtest_cfgs = list(
            filter(
                valid_filter,
                [SubTest(*cfg) for cfg in itertools.product(*(((False, True),) * 5))],
            )
        )

        for subtest_cfg in subtest_cfgs:
            try:
                (
                    multidim_norm,
                    elementwise_affine,
                    emb_req_grad,
                    ln_req_grad,
                    out_req_grad,
                ) = subtest_cfg
                normalized_shape = (
                    (seq_len, embedding_dim) if multidim_norm else (embedding_dim,)
                )

                # configure our local and parallelized models for this subtest
                class LnTpBlock(torch.nn.Module):
                    def __init__(self):
                        super().__init__()
                        self.preln_embeddings = torch.nn.Embedding(
                            vocab_size, embedding_dim
                        )
                        self.layer_norm = torch.nn.LayerNorm(
                            normalized_shape, elementwise_affine=elementwise_affine
                        )
                        self.postln_linear = torch.nn.Linear(
                            embedding_dim, embedding_dim
                        )

                    def forward(self, tokens):
                        h = self.preln_embeddings(tokens)
                        h = self.layer_norm(h)
                        output = self.postln_linear(h)
                        return output

                parallel_plan = {
                    "preln_embeddings": RowwiseParallel(
                        input_layouts=Replicate(), output_layouts=Shard(1)
                    ),
                    "layer_norm": SequenceParallel(),
                    "postln_linear": ColwiseParallel(
                        input_layouts=Shard(1),
                        output_layouts=Replicate(),
                    ),
                }
                model = LnTpBlock()
                model_local = copy.deepcopy(model).to(device=self.device_type)
                model_dist = parallelize_module(model, device_mesh, parallel_plan)
                req_grad_map = {
                    "preln_embeddings": emb_req_grad,
                    "postln_linear": out_req_grad,
                    "layer_norm": ln_req_grad,
                }

                # apply the relevant `requires_grad` mask for this subtest to both models
                for target_model in [model_local, model_dist]:
                    for n, p in target_model.named_parameters():
                        if not req_grad_map.get(n.rpartition(".")[0], False):
                            p.requires_grad_(False)
                            assert not p.requires_grad
                        else:
                            assert p.requires_grad

                # forward step for both local and distributed models
                x = torch.randint(
                    vocab_size, (batch, seq_len), device=self.device_type
                )
                x_local = x.detach().clone()
                output_local = model_local(x_local)
                with CommDebugMode() as comm_mode:
                    output_dist = model_dist(x)
                self.assertEqual(output_local, output_dist)

                # all requires_grad patterns should have the same forward comm counts
                expected_fwd_comm = {
                    funcol.reduce_scatter_tensor: 1,
                    funcol.all_gather_into_tensor: 2,
                }
                self.assertDictEqual(
                    comm_mode.comm_module_counts["Global"]["forward"],
                    expected_fwd_comm,
                )

                # backward step
                output_local.sum().backward()
                with CommDebugMode() as comm_mode:
                    output_dist.sum().backward()

                # ensure gradients (and parameters) remain equal between local and distributed models
                self._check_module(model_local, model_dist, check_grad=True)

                # different requires_grad patterns will have different bwd comm counts
                if out_req_grad and not any((emb_req_grad, ln_req_grad)):
                    expected_bwd_comm = {}
                elif ln_req_grad and not any((emb_req_grad, multidim_norm)):
                    expected_bwd_comm = {funcol.reduce_scatter_tensor: 1}
                elif multidim_norm:
                    expected_bwd_comm = {funcol.all_reduce: 1}
                    expected_bwd_comm[funcol.all_gather_into_tensor] = (
                        2 if emb_req_grad else 1
                    )
                else:
                    expected_bwd_comm = {
                        funcol.reduce_scatter_tensor: 1,
                        funcol.all_gather_into_tensor: 1,
                    }

                self.assertDictEqual(
                    comm_mode.comm_module_counts["Global"]["backward"],
                    expected_bwd_comm,
                )
                self.assertEqual(output_local, output_dist)
            except Exception as e:
                subtest_fails[subtest_cfg] = e

        # if any subtest fails, provide the failed subtests and report the overall failure
        assert (
            not subtest_fails
        ), f"{len(subtest_fails)}/{len(subtest_cfgs)} subtests failed: {pformat(subtest_fails)}"

    @with_comms
    def test_topk(self):
        device_mesh = self.build_device_mesh()
        placement_combs = [Shard(0), Shard(1), Shard(2), Replicate()]
        comm_mode = CommDebugMode()

        tensor = torch.randn(12, 8, 8, requires_grad=True)
        global_topk = tensor.topk(3, dim=0)

        for placement in placement_combs:
            dtensor = distribute_tensor(tensor, device_mesh, (placement,))
            with comm_mode:
                out_dt = dtensor.topk(3, dim=0)
            if placement.is_shard(0):
                self.assertEqual(comm_mode.get_total_counts(), 1)
                self.assertEqual(
                    comm_mode.get_comm_counts()[funcol.all_gather_into_tensor],
                    1,
                )
            out_full_values = out_dt.values.full_tensor()
            self.assertEqual(global_topk.values, out_full_values)

            # TODO: support backward scatter
            # global_topk.values.sum().backward()
            # out_full_values.sum().backward()

    @with_comms
    def test_shard0_svd(self):
        device_mesh = self.build_device_mesh()
        torch.manual_seed(42)
        replicated_x = torch.randn((8, 8), device=self.device_type)
        sharded_x = distribute_tensor(replicated_x, device_mesh, (Shard(0),))
        with CommDebugMode() as comm_mode:
            U, S, V = torch.linalg.svd(sharded_x, full_matrices=False)
        ref_U, ref_S, ref_V = torch.linalg.svd(replicated_x, full_matrices=False)
        self.assertEqual(U.to_local(), ref_U)
        self.assertEqual(S.to_local(), ref_S)
        self.assertEqual(V.to_local(), ref_V)
        comm_counts = comm_mode.get_comm_counts()
        self.assertEqual(len(comm_counts), 1)
        self.assertEqual(comm_counts[funcol.all_gather_into_tensor], 1)

    @with_comms
    def test_foreach_norm(self):
        device_mesh = self.build_device_mesh()

        grad0 = torch.randn(12, 8)
        grad1 = torch.randn(8, 8)

        sharded_grad0 = distribute_tensor(grad0, device_mesh, [Shard(0)])
        sharded_grad1 = distribute_tensor(grad1, device_mesh, [Shard(0)])

        # non-sharded op
        out = torch.ops.aten._foreach_norm([grad0, grad1], 2)

        # sharded op
        sharded_out = torch.ops.aten._foreach_norm([sharded_grad0, sharded_grad1], 2)

        for o, so in zip(out, sharded_out):
            self.assertEqual(so.full_tensor(), o)

    @with_comms
    def test_linalg_eigh(self):
        A = torch.randn(2, 2, dtype=torch.float64)
        mesh = self.build_device_mesh()
        dtensor_A = distribute_tensor(A, device_mesh=mesh, placements=[Replicate()])
        dtensor_A = dtensor_A + dtensor_A.mT
        dtensor_L, dtensor_Q = torch.linalg.eigh(dtensor_A)
        # TODO: we need to convert A, L, Q to local because we don't have a
        # sharding strategy registered for aten.dist.default yet.
        local_A, local_L, local_Q = (
            dtensor_A.to_local(),
            dtensor_L.to_local(),
            dtensor_Q.to_local(),
        )
        distance = torch.dist(local_Q @ torch.diag(local_L) @ local_Q.mT, local_A)
        self.assertEqual(distance.item(), 0.0)


if __name__ == "__main__":
    run_tests()