# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import itertools

import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.tensor._random as random
from torch.distributed._tensor import DeviceMesh, DTensor
from torch.distributed._tensor._utils import compute_local_shape_and_global_offset
from torch.distributed._tensor.api import distribute_tensor
from torch.distributed._tensor.placement_types import Replicate, Shard
from torch.distributed.distributed_c10d import broadcast_object_list
from torch.distributed.tensor._random import is_rng_supported_mesh, manual_seed
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_if_lt_x_gpu,
    skip_unless_torch_gpu,
    with_comms,
)


class DistTensorRandomInitTest(DTensorTestBase):
    def _run_init_op(self, init_op, *args, **kwargs):
        device_mesh = self.build_device_mesh()
        shard_spec = [Shard(0)]
        input_size = (8, 4)

        # NOTE: currently random initialization on cuda device has different
        # behavior from other devices. Unify the test once the behavior is unified.
        if not is_rng_supported_mesh(device_mesh):
            input_tensor = torch.randn(*input_size, device=self.device_type)
            dtensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
            local_tensor_clone = torch.clone(input_tensor)
            torch.manual_seed(self.rank)
            local_tensor_clone = init_op(local_tensor_clone, *args, **kwargs)
            torch.manual_seed(self.rank)
            dtensor = init_op(dtensor, *args, **kwargs)
            self.assertEqual(local_tensor_clone, dtensor.to_local())
        else:
            # create DTensor from Tensor
            _tensor = torch.empty(*input_size, device="cuda")
            dtensor = distribute_tensor(_tensor, device_mesh, [Shard(1)])

            # DTensor random init
            dtensor = init_op(dtensor, *args, **kwargs)
            local_tensor = dtensor.to_local()

            # compare with local tensors from other ranks
            for other_rank in range(self.world_size):
                if self.rank != other_rank:
                    slice_idx = [
                        slice(input_size[0]),
                        slice(
                            other_rank * input_size[1], (other_rank + 1) * input_size[1]
                        ),
                    ]
                    # other rank should have a different local tensor
                    self.assertNotEqual(dtensor.full_tensor()[slice_idx], local_tensor)

    @with_comms
    def test_init_ops(self):
        self._run_init_op(
            torch.nn.init.kaiming_uniform_,
            a=0,
            mode="fan_in",
            nonlinearity="leaky_relu",
        )
        self._run_init_op(torch.nn.init.normal_, mean=1.5, std=0.8)
        self._run_init_op(torch.nn.init.uniform_, a=0, b=1.2)

        for dtype in (torch.float32, torch.float16):
            self._run_init_op(torch.rand_like, dtype=dtype)
            self._run_init_op(torch.randn_like, dtype=dtype)
            self._run_init_op(torch.randint_like, low=0, high=100, dtype=dtype)


class DistTensorRandomOpTest(DTensorTestBase):
    @with_comms
    @skip_unless_torch_gpu
    def test_rng_tracker_init(self):
        torch.cuda.manual_seed(self.rank)
        object_list = [torch.cuda.initial_seed()]
        broadcast_object_list(object_list)
        seed_from_rank_0 = int(object_list[0])

        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        # seed synchronization happens after the first `distribute_tensor` call
        dtensor = distribute_tensor(
            torch.empty([self.world_size], device="cuda"), device_mesh, [Shard(0)]
        )
        self.assertEqual(seed_from_rank_0, random._rng_tracker.get_seed("parallel-rng"))

    @with_comms
    @skip_unless_torch_gpu
    def test_manual_seed(self):
        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        manual_seed(1234, device_mesh)
        self.assertEqual(1234, random._rng_tracker.get_seed("parallel-rng"))
        with self.assertRaisesRegex(RuntimeError, "different seed values"):
            manual_seed(self.rank, device_mesh)

    @with_comms
    @skip_unless_torch_gpu
    def test_deterministic_dropout_1d(self):
        # test suite sets each rank's seed to the same value but in actual
        # execution the default random seed will be different (a random value).
        # The DTensor random ops will use the same random seed even though the
        # torch random generator keeps different seeds on ranks.
        torch.cuda.manual_seed(self.rank)
        # TODO: add test before/after enabling distribute region
        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        size = [4, 4]

        dtensor = distribute_tensor(
            torch.empty(*size, device="cuda"), device_mesh, [Shard(1)]
        )

        # a random op call shifts the offset
        dtensor.uniform_(0, 1)

        # the dtensor is now replicate on all ranks
        dtensor = dtensor.redistribute(device_mesh, [Replicate()])

        dropout = torch.nn.Dropout(p=0.2)
        dtensor = dropout(dtensor)

        # allgather the local tensors
        local_tensor = funcol.all_gather_tensor(
            dtensor.to_local(), gather_dim=0, group=(device_mesh, 0)
        )

        # compare with local tensors from other ranks
        self_slice = slice(4 * self.rank, 4 * self.rank + 4)
        for other_rank in range(self.world_size):
            if self.rank != other_rank:
                # other rank should have an identical local tensor
                other_slice = slice(4 * other_rank, 4 * other_rank + 4)
                self.assertEqual(
                    local_tensor[self_slice, :],
                    local_tensor[other_slice, :],
                )

    @with_comms
    @skip_unless_torch_gpu
    def test_deterministic_rand_1d(self):
        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        size = [4, 4 * self.world_size]

        for fn in [
            torch.distributed._tensor.rand,
            torch.distributed._tensor.randn,
        ]:
            dtensor = fn(size, device_mesh=device_mesh, placements=[Shard(1)])
            local_tensor = funcol.all_gather_tensor(
                dtensor.to_local(), gather_dim=0, group=(device_mesh, 0)
            )

            # compare with local tensors from other ranks
            self_slice = slice(4 * self.rank, 4 * self.rank + 4)
            for other_rank in range(self.world_size):
                if self.rank != other_rank:
                    # other rank should have a different local tensor
                    other_slice = slice(4 * other_rank, 4 * other_rank + 4)
                    self.assertNotEqual(
                        local_tensor[self_slice, :],
                        local_tensor[other_slice, :],
                    )

            torch.cuda.manual_seed(self.rank)
            dtensor = fn(size, device_mesh=device_mesh, placements=[Replicate()])
            local_tensor = funcol.all_gather_tensor(
                dtensor.to_local(), gather_dim=0, group=(device_mesh, 0)
            )

            # compare with local tensors from other ranks
            self_slice = slice(4 * self.rank, 4 * self.rank + 4)
            for other_rank in range(self.world_size):
                if self.rank != other_rank:
                    # other rank should have an identical local tensor
                    other_slice = slice(4 * other_rank, 4 * other_rank + 4)
                    self.assertEqual(
                        local_tensor[self_slice, :],
                        local_tensor[other_slice, :],
                    )

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_deterministic_uniform_2d(self):
        mesh = torch.arange(self.world_size).reshape(2, 2)
        device_mesh = DeviceMesh(self.device_type, mesh)
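        # start from a DTensor replicated over both mesh dimensions; each
        # iteration below redistributes it to a new placement before calling
        # the random op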
        dtensor = distribute_tensor(
            torch.empty(
                *[self.world_size for _ in mesh.size()], device=self.device_type
            ),
            device_mesh,
            [Replicate(), Replicate()],
        )

        placements_list = [  # this list of placements should be enough to cover
            [Shard(0), Shard(1)],
            [Shard(1), Shard(0)],
            [Shard(0), Replicate()],
            [Replicate(), Shard(0)],
            [Shard(1), Replicate()],
            [Replicate(), Shard(1)],
            [Replicate(), Replicate()],
        ]

        shard_index_list = [
            {0: 0, 1: 1, 2: 2, 3: 3},
            {0: 0, 1: 2, 2: 1, 3: 3},
            {0: 0, 1: 0, 2: 1, 3: 1},
            {0: 0, 1: 1, 2: 0, 3: 1},
            {0: 0, 1: 0, 2: 1, 3: 1},
            {0: 0, 1: 1, 2: 0, 3: 1},
            {0: 0, 1: 0, 2: 0, 3: 0},
        ]

        coordinate = device_mesh.get_coordinate()
        assert coordinate is not None

        for placements, shard_index in zip(placements_list, shard_index_list):
            dtensor = dtensor.redistribute(device_mesh, placements)

            # check shard information is correct
            shard_coord = [
                coordinate[mesh_dim] if mesh_dim >= 0 else 0
                for mesh_dim in dtensor._spec.dim_map
            ]

            shard_size = [
                device_mesh.size(mesh_dim) if mesh_dim >= 0 else 1
                for mesh_dim in dtensor._spec.dim_map
            ]

            shard_linear_idx = random._rng_tracker._calc_shard_linear_idx(
                shard_coord, shard_size
            )
            self.assertEqual(shard_linear_idx, shard_index[self.rank])

            # compute local size and offset
            _, local_shard_offset = compute_local_shape_and_global_offset(
                dtensor.shape, device_mesh, placements
            )

            # get the local shard size and local shard offset for each shard
            # local_shard_list_on_dim[i] has the list of all shards on that dim
            # as a tuple (local_shard_offset, local_shard_size)
            dtensor_shape = dtensor.shape
            local_shard_list_on_dim = [[(0, l)] for l in dtensor_shape]
            for idx, placement in enumerate(placements):
                if isinstance(placement, Shard):
                    mesh_dim_size = device_mesh.size(idx)
                    shard_dim = placement.dim
                    local_shard_list_on_dim[shard_dim] = []
                    for shard_idx_on_dim in range(mesh_dim_size):
                        shard_size, shard_offset = placement._local_shard_size_on_dim(
                            dtensor_shape[shard_dim],
                            mesh_dim_size,
                            shard_idx_on_dim,
                            return_offset=True,
                        )
                        local_shard_list_on_dim[shard_dim].append(
                            (shard_offset, shard_size)
                        )

            local_shard_comb = itertools.product(*local_shard_list_on_dim)

            # random op call
            dtensor.uniform_(0, 1)

            # the local shard
            local_tensor = dtensor.to_local()
            # allgather the local tensors
            full_tensor = dtensor.full_tensor()

            # compare local tensor with each other shard
            for other_local_shard in local_shard_comb:
                other_local_shard_offset, _ = zip(*other_local_shard)
                slice_idx = [
                    slice(offset, offset + size) for offset, size in other_local_shard
                ]
                if local_shard_offset == other_local_shard_offset:
                    self.assertEqual(full_tensor[slice_idx], local_tensor)
                else:
                    self.assertNotEqual(full_tensor[slice_idx], local_tensor)

    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_meta_tensor_init(self):
        # test suite sets each rank's seed to the same value but in actual
        # execution the default random seed will be different (a random value).
        # The DTensor random ops will use the same random seed even though the
        # torch random generator keeps different seeds on ranks. This ensures
        # that Replicate DTensor will have the same initialized results
        # across ranks.
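        # seed the default CUDA generator differently on each rank; only the
        # DTensor RNG tracker's distribute region (toggled below) can make the
        # replicated results match across ranks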
        torch.cuda.manual_seed(self.rank)
        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        size = [1024, 2048]
        meta_dtensor = distribute_tensor(
            torch.empty(*size, device="meta"), device_mesh, [Replicate()]
        )
        self.assertTrue(meta_dtensor.is_meta)
        dtensor = torch.empty_like(meta_dtensor, device=self.device_type)

        # disable the distribute region for RNG
        random._rng_tracker.distribute_region_enabled = False
        dtensor.uniform_()

        # allgather the local tensors
        local_tensor = funcol.all_gather_tensor(
            dtensor.to_local(), gather_dim=0, group=(device_mesh, 0)
        )

        # compare with local tensors from other ranks
        self_slice = slice(1024 * self.rank, 1024 * self.rank + 1024)
        for other_rank in range(self.world_size):
            # the RNG result on each rank differs even though they are supposed
            # to be replicated
            if self.rank != other_rank:
                other_slice = slice(1024 * other_rank, 1024 * other_rank + 1024)
                self.assertNotEqual(
                    local_tensor[self_slice, :], local_tensor[other_slice, :]
                )

        # enable the distribute region for RNG
        random._rng_tracker.distribute_region_enabled = True
        self.assertTrue(meta_dtensor.is_meta)
        dtensor = torch.empty_like(meta_dtensor, device=self.device_type)
        dtensor.uniform_()

        # allgather the local tensors
        local_tensor = funcol.all_gather_tensor(
            dtensor.to_local(), gather_dim=0, group=(device_mesh, 0)
        )

        # compare with local tensors from other ranks
        for other_rank in range(self.world_size):
            # the RNG results on all ranks are the same because the tensor is replicated
            if self.rank != other_rank:
                # other rank should have an identical local tensor
                other_slice = slice(1024 * other_rank, 1024 * other_rank + 1024)
                self.assertEqual(
                    local_tensor[self_slice, :], local_tensor[other_slice, :]
                )


if __name__ == "__main__":
    run_tests()