# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import itertools

import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.tensor._random as random

from torch.distributed._tensor import DeviceMesh, DTensor
from torch.distributed._tensor._utils import compute_local_shape_and_global_offset
from torch.distributed._tensor.api import distribute_tensor
from torch.distributed._tensor.placement_types import Replicate, Shard
from torch.distributed.distributed_c10d import broadcast_object_list
from torch.distributed.tensor._random import is_rng_supported_mesh, manual_seed

from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_if_lt_x_gpu,
    skip_unless_torch_gpu,
    with_comms,
)


class DistTensorRandomInitTest(DTensorTestBase):
    def _run_init_op(self, init_op, *args, **kwargs):
        device_mesh = self.build_device_mesh()
        shard_spec = [Shard(0)]
        input_size = (8, 4)

        # NOTE: currently random initialization on cuda device has different
        # behavior from other devices. Unify the test once the behavior is unified.
        if not is_rng_supported_mesh(device_mesh):
            input_tensor = torch.randn(*input_size, device=self.device_type)
            dtensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
            local_tensor_clone = torch.clone(input_tensor)
            torch.manual_seed(self.rank)
            local_tensor_clone = init_op(local_tensor_clone, *args, **kwargs)
            torch.manual_seed(self.rank)
            dtensor = init_op(dtensor, *args, **kwargs)
            self.assertEqual(local_tensor_clone, dtensor.to_local())
        else:
            # create DTensor from Tensor
            _tensor = torch.empty(*input_size, device="cuda")
            dtensor = distribute_tensor(_tensor, device_mesh, [Shard(1)])

            # DTensor random init
            dtensor = init_op(dtensor, *args, **kwargs)
            local_tensor = dtensor.to_local()

            # compare with local tensors from other ranks
            for other_rank in range(self.world_size):
                if self.rank != other_rank:
                    slice_idx = [
                        slice(input_size[0]),
                        slice(
                            other_rank * input_size[1],
                            (other_rank + 1) * input_size[1],
                        ),
                    ]
                    # other rank should have a different local tensor
                    self.assertNotEqual(dtensor.full_tensor()[slice_idx], local_tensor)

    @with_comms
    def test_init_ops(self):
        self._run_init_op(
            torch.nn.init.kaiming_uniform_,
            a=0,
            mode="fan_in",
            nonlinearity="leaky_relu",
        )
        self._run_init_op(torch.nn.init.normal_, mean=1.5, std=0.8)
        self._run_init_op(torch.nn.init.uniform_, a=0, b=1.2)

        for dtype in (torch.float32, torch.float16):
            self._run_init_op(torch.rand_like, dtype=dtype)
            self._run_init_op(torch.randn_like, dtype=dtype)
            self._run_init_op(torch.randint_like, low=0, high=100, dtype=dtype)


class DistTensorRandomOpTest(DTensorTestBase):
    @with_comms
    @skip_unless_torch_gpu
    def test_rng_tracker_init(self):
        torch.cuda.manual_seed(self.rank)
        object_list = [torch.cuda.initial_seed()]
        broadcast_object_list(object_list)
        seed_from_rank_0 = int(object_list[0])

        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        # seed synchronization happens after the first `distribute_tensor` call
        dtensor = distribute_tensor(
            torch.empty([self.world_size], device="cuda"), device_mesh, [Shard(0)]
        )
        self.assertEqual(seed_from_rank_0, random._rng_tracker.get_seed("parallel-rng"))

    @with_comms
    @skip_unless_torch_gpu
    def test_manual_seed(self):
        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        manual_seed(1234, device_mesh)
        self.assertEqual(1234, random._rng_tracker.get_seed("parallel-rng"))
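        # NOTE: every rank in the mesh must pass the same seed value to
        # `manual_seed`; the call below intentionally passes a rank-dependent
        # seed to check that the mismatch is detected and raised.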
"different seed values"): manual_seed(self.rank, device_mesh) @with_comms @skip_unless_torch_gpu def test_deterministic_dropout_1d(self): # test suite sets each rank's seed to the same value but in actual # execution the default random seed will be different (a random value). # The DTensor random ops will use the same random seed even though the # torch random generator keeps different seeds on ranks. torch.cuda.manual_seed(self.rank) # TODO: add test before/after enabling distribute region device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) size = [4, 4] dtensor = distribute_tensor( torch.empty(*size, device="cuda"), device_mesh, [Shard(1)] ) # a random op call shifts the offset dtensor.uniform_(0, 1) # the dtensor is now replicate on all ranks dtensor = dtensor.redistribute(device_mesh, [Replicate()]) dropout = torch.nn.Dropout(p=0.2) dtensor = dropout(dtensor) # allgather the local tensors local_tensor = funcol.all_gather_tensor( dtensor.to_local(), gather_dim=0, group=(device_mesh, 0) ) # compare with local tensors from other ranks self_slice = slice(4 * self.rank, 4 * self.rank + 4) for other_rank in range(self.world_size): if self.rank != other_rank: # other rank should have an identical local tensor other_slice = slice(4 * other_rank, 4 * other_rank + 4) self.assertEqual( local_tensor[self_slice, :], local_tensor[other_slice, :], ) @with_comms @skip_unless_torch_gpu def test_deterministic_rand_1d(self): device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) size = [4, 4 * self.world_size] for fn in [ torch.distributed._tensor.rand, torch.distributed._tensor.randn, ]: dtensor = fn(size, device_mesh=device_mesh, placements=[Shard(1)]) local_tensor = funcol.all_gather_tensor( dtensor.to_local(), gather_dim=0, group=(device_mesh, 0) ) # compare with local tensors from other ranks self_slice = slice(4 * self.rank, 4 * self.rank + 4) for other_rank in range(self.world_size): if self.rank != other_rank: # other rank should have an identical local tensor other_slice = slice(4 * other_rank, 4 * other_rank + 4) self.assertNotEqual( local_tensor[self_slice, :], local_tensor[other_slice, :], ) torch.cuda.manual_seed(self.rank) dtensor = fn(size, device_mesh=device_mesh, placements=[Replicate()]) local_tensor = funcol.all_gather_tensor( dtensor.to_local(), gather_dim=0, group=(device_mesh, 0) ) # compare with local tensors from other ranks self_slice = slice(4 * self.rank, 4 * self.rank + 4) for other_rank in range(self.world_size): if self.rank != other_rank: # other rank should have an identical local tensor other_slice = slice(4 * other_rank, 4 * other_rank + 4) self.assertEqual( local_tensor[self_slice, :], local_tensor[other_slice, :], ) @with_comms @skip_if_lt_x_gpu(4) def test_deterministic_uniform_2d(self): mesh = torch.arange(self.world_size).reshape(2, 2) device_mesh = DeviceMesh(self.device_type, mesh) dtensor = distribute_tensor( torch.empty( *[self.world_size for _ in mesh.size()], device=self.device_type ), device_mesh, [Replicate(), Replicate()], ) placements_list = [ # this list of placements should be enough to cover [Shard(0), Shard(1)], [Shard(1), Shard(0)], [Shard(0), Replicate()], [Replicate(), Shard(0)], [Shard(1), Replicate()], [Replicate(), Shard(1)], [Replicate(), Replicate()], ] shard_index_list = [ {0: 0, 1: 1, 2: 2, 3: 3}, {0: 0, 1: 2, 2: 1, 3: 3}, {0: 0, 1: 0, 2: 1, 3: 1}, {0: 0, 1: 1, 2: 0, 3: 1}, {0: 0, 1: 0, 2: 1, 3: 1}, {0: 0, 1: 1, 2: 0, 3: 1}, {0: 0, 1: 0, 2: 0, 3: 0}, ] coordinate = device_mesh.get_coordinate() assert 
        for placements, shard_index in zip(placements_list, shard_index_list):
            dtensor = dtensor.redistribute(device_mesh, placements)

            # check that the shard information is correct
            shard_coord = [
                coordinate[mesh_dim] if mesh_dim >= 0 else 0
                for mesh_dim in dtensor._spec.dim_map
            ]
            shard_size = [
                device_mesh.size(mesh_dim) if mesh_dim >= 0 else 1
                for mesh_dim in dtensor._spec.dim_map
            ]
            shard_linear_idx = random._rng_tracker._calc_shard_linear_idx(
                shard_coord, shard_size
            )
            self.assertEqual(shard_linear_idx, shard_index[self.rank])

            # compute local size and offset
            _, local_shard_offset = compute_local_shape_and_global_offset(
                dtensor.shape, device_mesh, placements
            )

            # get the local shard size and local shard offset for each shard:
            # local_shard_list_on_dim[i] has the list of all shards on dim i
            # as tuples (local_shard_offset, local_shard_size)
            dtensor_shape = dtensor.shape
            local_shard_list_on_dim = [[(0, length)] for length in dtensor_shape]
            for idx, placement in enumerate(placements):
                if isinstance(placement, Shard):
                    mesh_dim_size = device_mesh.size(idx)
                    shard_dim = placement.dim
                    local_shard_list_on_dim[shard_dim] = []
                    for shard_idx_on_dim in range(mesh_dim_size):
                        shard_size, shard_offset = placement._local_shard_size_on_dim(
                            dtensor_shape[shard_dim],
                            mesh_dim_size,
                            shard_idx_on_dim,
                            return_offset=True,
                        )
                        local_shard_list_on_dim[shard_dim].append(
                            (shard_offset, shard_size)
                        )

            local_shard_comb = itertools.product(*local_shard_list_on_dim)

            # random op call
            dtensor.uniform_(0, 1)

            # the local shard
            local_tensor = dtensor.to_local()
            # allgather the local tensors
            full_tensor = dtensor.full_tensor()

            # compare the local tensor with each other shard
            for other_local_shard in local_shard_comb:
                other_local_shard_offset, _ = zip(*other_local_shard)
                slice_idx = [
                    slice(offset, offset + size) for offset, size in other_local_shard
                ]
                if local_shard_offset == other_local_shard_offset:
                    self.assertEqual(full_tensor[slice_idx], local_tensor)
                else:
                    self.assertNotEqual(full_tensor[slice_idx], local_tensor)
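
    # The next test exercises the meta-device initialization flow: build the
    # DTensor layout on the "meta" device (no storage), materialize it with
    # `torch.empty_like(..., device=...)`, then run a random init. A usage
    # sketch of that pattern (hypothetical model code, not part of this suite):
    #
    #   meta_param = distribute_tensor(
    #       torch.empty(1024, 2048, device="meta"), device_mesh, [Replicate()]
    #   )
    #   param = torch.empty_like(meta_param, device="cuda")
    #   param.uniform_()  # with the distribute region enabled, replicas match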
    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_meta_tensor_init(self):
        # The test suite sets each rank's seed to the same value, but in actual
        # execution the default random seed will be different (a random value).
        # The DTensor random ops will use the same random seed even though the
        # torch random generator keeps different seeds on ranks. This ensures
        # that a Replicate DTensor will have the same initialized results
        # across ranks.
        torch.cuda.manual_seed(self.rank)
        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        size = [1024, 2048]
        meta_dtensor = distribute_tensor(
            torch.empty(*size, device="meta"), device_mesh, [Replicate()]
        )
        self.assertTrue(meta_dtensor.is_meta)
        dtensor = torch.empty_like(meta_dtensor, device=self.device_type)

        # disable the distribute region for RNG
        random._rng_tracker.distribute_region_enabled = False
        dtensor.uniform_()

        # allgather the local tensors
        local_tensor = funcol.all_gather_tensor(
            dtensor.to_local(), gather_dim=0, group=(device_mesh, 0)
        )

        # compare with local tensors from other ranks
        self_slice = slice(1024 * self.rank, 1024 * self.rank + 1024)
        for other_rank in range(self.world_size):
            # the RNG result on each rank differs even though they're supposed
            # to be replicated
            if self.rank != other_rank:
                other_slice = slice(1024 * other_rank, 1024 * other_rank + 1024)
                self.assertNotEqual(
                    local_tensor[self_slice, :], local_tensor[other_slice, :]
                )

        # enable the distribute region for RNG
        random._rng_tracker.distribute_region_enabled = True
        self.assertTrue(meta_dtensor.is_meta)
        dtensor = torch.empty_like(meta_dtensor, device=self.device_type)
        dtensor.uniform_()

        # allgather the local tensors
        local_tensor = funcol.all_gather_tensor(
            dtensor.to_local(), gather_dim=0, group=(device_mesh, 0)
        )

        # compare with local tensors from other ranks
        for other_rank in range(self.world_size):
            # the RNG results on each rank are the same because they're replicated
            if self.rank != other_rank:
                # other rank should have an identical local tensor
                other_slice = slice(1024 * other_rank, 1024 * other_rank + 1024)
                self.assertEqual(
                    local_tensor[self_slice, :], local_tensor[other_slice, :]
                )


if __name__ == "__main__":
    run_tests()