# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import torch
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard, zeros
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)


class DTensorInitOpsTest(DTensorTestBase):
    def _run_init_op(self, init_op, *args, **kwargs):
        device_mesh = self.build_device_mesh()
        shard_spec = [Shard(0)]
        input_size = (8, 4)
        input_tensor = torch.randn(*input_size, device=self.device_type)
        dtensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
        local_tensor_clone = torch.clone(input_tensor)
        # seed identically before both init calls so the plain-tensor and
        # DTensor results are comparable
        torch.manual_seed(self.rank)
        local_tensor_clone = init_op(local_tensor_clone, *args, **kwargs)
        torch.manual_seed(self.rank)
        dtensor = init_op(dtensor, *args, **kwargs)
        self.assertEqual(local_tensor_clone, dtensor.to_local())

    @with_comms
    def test_init_ops(self):
        # NOTE: random init tests were moved to test_random_ops.py
        self._run_init_op(torch.nn.init.constant_, 2.4)


class DTensorConstructorTest(DTensorTestBase):
    @property
    def world_size(self):
        return 4

    def _run_init_op(self, init_op, dist_init_op, eq_op, *args, **kwargs):
        # 1d mesh test
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        placements_list = [[Shard(0)], [Shard(1)], [Shard(2)], [Replicate()]]

        # even sharding
        tensor_size = [4, 8, 12]
        for placements in placements_list:
            local_tensor_size = tensor_size.copy()
            if isinstance(placements[0], Shard):
                shard_dim = placements[0].dim
                local_tensor_size[shard_dim] //= self.world_size

            dist_tensor = dist_init_op(
                tensor_size,
                *args,
                **kwargs,
                device_mesh=device_mesh,
                placements=placements,
            )
            expected_tensor = init_op(local_tensor_size, *args, **kwargs)
            eq_op(expected_tensor, dist_tensor.to_local())

        # uneven sharding
        tensor_size = [5, 10, 15]
        for placements in placements_list:
            dist_tensor = dist_init_op(
                tensor_size,
                *args,
                **kwargs,
                device_mesh=device_mesh,
                placements=placements,
            )
            if isinstance(placements[0], Shard):
                shard_dim = placements[0].dim
                exp_tensor_list = list(
                    torch.chunk(
                        init_op(tensor_size, *args, **kwargs),
                        self.world_size,
                        dim=shard_dim,
                    )
                )
                # torch.chunk may return fewer than world_size chunks, so only
                # the first len(exp_tensor_list) ranks hold a non-empty shard
                if self.rank < len(exp_tensor_list):
                    eq_op(exp_tensor_list[self.rank], dist_tensor.to_local())
            else:
                exp_tensor = init_op(tensor_size, *args, **kwargs)
                eq_op(exp_tensor, dist_tensor.to_local())

        # empty shape
        local_tensor = dist_init_op(
            [], *args, **kwargs, device_mesh=device_mesh, placements=[Replicate()]
        ).to_local()
        expected_tensor = init_op([], *args, **kwargs)
        eq_op(expected_tensor, local_tensor)

    @with_comms
    def test_ones(self):
        self._run_init_op(
            torch.ones,
            torch.distributed._tensor.ones,
            self.assertEqual,
            requires_grad=True,
        )

    @with_comms
    def test_empty(self):
        self._run_init_op(
            torch.empty,
            torch.distributed._tensor.empty,
            # empty() leaves values uninitialized, so compare metadata only
            lambda x, y: (x.shape == y.shape)
            and (x.dtype == y.dtype)
            and (x.layout == y.layout),
            requires_grad=True,
        )

    @with_comms
    def test_full(self):
        self._run_init_op(
            torch.full,
            torch.distributed._tensor.full,
            self.assertEqual,
            123.4,
            requires_grad=True,
        )

    @with_comms
    def test_zeros(self):
        self._run_init_op(
            torch.zeros,
            torch.distributed._tensor.zeros,
            self.assertEqual,
            requires_grad=True,
        )
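
    # NOTE: the shard-size expectations in the mesh tests below follow
    # torch.chunk semantics (the same splitting _run_init_op uses for its
    # per-rank expected shards), e.g.:
    #   >>> [t.shape[0] for t in torch.chunk(torch.zeros(31, 3), 4, dim=0)]
    #   [8, 8, 8, 7]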

    @with_comms
    def test_zeros_full_mesh(self):
        # construct a 1d device mesh over all ranks
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        placements = [Shard(0)]
        size = [32, 3]
        dist_tensor = zeros(size, device_mesh=mesh, placements=placements)
        self.assertEqual(dist_tensor.size(), torch.Size(size))
        local_tensor = dist_tensor.to_local()
        self.assertEqual(local_tensor.size(), torch.Size([8, 3]))

        local_tensor = torch.zeros(8, 3)
        self.assertEqual(dist_tensor.to_local(), local_tensor)

        self.assertEqual(dist_tensor.device.type, self.device_type)

        # 1d sharded unevenly
        size = [31, 3]
        dist_tensor = zeros(size, device_mesh=mesh, placements=placements)
        self.assertEqual(dist_tensor.size(), torch.Size(size))
        local_tensor = dist_tensor.to_local()
        if self.rank <= 2:
            self.assertEqual(local_tensor.size(), torch.Size([8, 3]))
            self.assertEqual(torch.zeros(8, 3), local_tensor)
        else:
            self.assertEqual(local_tensor.size(), torch.Size([7, 3]))
            self.assertEqual(torch.zeros(7, 3), local_tensor)

        # construct a 2d device mesh: shard, replicate
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 2))
        placements = [Shard(0), Replicate()]
        size = [32, 4]
        dist_tensor = zeros(size, device_mesh=mesh, placements=placements)

        self.assertEqual(dist_tensor.size(), torch.Size(size))
        local_tensor = dist_tensor.to_local()
        self.assertEqual(local_tensor.size(), torch.Size([16, 4]))
        self.assertEqual(local_tensor, torch.zeros([16, 4]))

        # construct a 2d device mesh: shard, shard
        placements = [Shard(0), Shard(1)]
        size = [32, 4]
        dist_tensor = zeros(size, device_mesh=mesh, placements=placements)

        self.assertEqual(dist_tensor.size(), torch.Size(size))
        local_tensor = dist_tensor.to_local()
        self.assertEqual(local_tensor.size(), torch.Size([16, 2]))
        self.assertEqual(local_tensor, torch.zeros([16, 2]))

        # 2d sharded unevenly
        placements = [Shard(0), Shard(1)]
        size = [31, 3]
        dist_tensor = zeros(size, device_mesh=mesh, placements=placements)

        self.assertEqual(dist_tensor.size(), torch.Size(size))
        local_tensor = dist_tensor.to_local()
        if self.rank == 0:
            self.assertEqual(local_tensor, torch.zeros([16, 2]))
        elif self.rank == 1:
            self.assertEqual(local_tensor, torch.zeros([16, 1]))
        elif self.rank == 2:
            self.assertEqual(local_tensor, torch.zeros([15, 2]))
        elif self.rank == 3:
            self.assertEqual(local_tensor, torch.zeros([15, 1]))
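
    # NOTE: the sub-mesh cases below expect ranks that do not participate in
    # the mesh to get back an empty local tensor from to_local().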

    @with_comms
    def test_zeros_submesh(self):
        # default world_size is 4
        # construct a 1d sub-mesh, with no sub-pg initialized
        sub_mesh_list = [0, 3]
        mesh = DeviceMesh(self.device_type, sub_mesh_list)
        placements = [Shard(0)]
        size = [32, 3]
        dist_tensor = zeros(size, device_mesh=mesh, placements=placements)
        self.assertEqual(dist_tensor.size(), torch.Size(size))
        local_tensor = dist_tensor.to_local()

        if self.rank in sub_mesh_list:
            self.assertEqual(local_tensor.size(), torch.Size([16, 3]))
            self.assertEqual(local_tensor, torch.zeros([16, 3]))
        else:
            self.assertEqual(local_tensor.size(), torch.Size([0]))
            self.assertEqual(local_tensor, torch.zeros(0))

        # construct a 1d sub-mesh, sharded unevenly, with sub-pg initialized
        sub_mesh_list = [0, 1, 3]
        mesh = DeviceMesh(self.device_type, sub_mesh_list)
        placements = [Shard(0)]
        size = [32, 3]
        dist_tensor = zeros(size, device_mesh=mesh, placements=placements)
        self.assertEqual(dist_tensor.size(), torch.Size(size))
        local_tensor = dist_tensor.to_local()

        if self.rank in sub_mesh_list:
            # 32 rows over 3 ranks: ranks 0 and 1 get 11 rows, rank 3 gets 10
            if self.rank != 3:
                self.assertEqual(local_tensor.size(), torch.Size([11, 3]))
                self.assertEqual(local_tensor, torch.zeros([11, 3]))
            else:
                self.assertEqual(local_tensor.size(), torch.Size([10, 3]))
                self.assertEqual(local_tensor, torch.zeros([10, 3]))
        else:
            self.assertEqual(local_tensor.size(), torch.Size([0]))
            self.assertEqual(local_tensor, torch.tensor([]))

        # construct a 2d sub-mesh, with no sub-pg initialized
        sub_mesh_list = [[0], [3]]
        mesh = DeviceMesh(self.device_type, sub_mesh_list)
        placements = [Shard(0), Shard(1)]
        size = [32, 3]
        dist_tensor = zeros(size, device_mesh=mesh, placements=placements)
        self.assertEqual(dist_tensor.size(), torch.Size(size))
        local_tensor = dist_tensor.to_local()

        if self.rank in [0, 3]:
            self.assertEqual(local_tensor.size(), torch.Size([16, 3]))
            self.assertEqual(local_tensor, torch.zeros([16, 3]))
        else:
            self.assertEqual(local_tensor.size(), torch.Size([0]))
            self.assertEqual(local_tensor, torch.tensor([]))


if __name__ == "__main__":
    run_tests()