# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import itertools
from typing import cast, List, Optional

import torch
import torch.nn.functional as F
from torch.distributed._tensor import DeviceMesh, distribute_tensor
from torch.distributed._tensor.api import DTensor
from torch.distributed._tensor.placement_types import (
    Partial,
    Placement,
    Replicate,
    Shard,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_unless_torch_gpu,
    with_comms,
)


class DistMatrixOpsTest(DTensorTestBase):
    @with_comms
    def test_addmm(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        shard_spec = [Shard(0)]
        replica_spec = [Replicate()]

        tensor_to_shard = torch.randn(12, 8)
        mat1 = distribute_tensor(tensor_to_shard, device_mesh, shard_spec)
        tensor_to_replicate = torch.randn(8, 4)
        mat2 = distribute_tensor(tensor_to_replicate, device_mesh, replica_spec)
        input_tensor = torch.randn(4)
        input = distribute_tensor(input_tensor, device_mesh, replica_spec)

        dist_res = torch.addmm(input, mat1, mat2)
        local_res = torch.addmm(input_tensor, tensor_to_shard, tensor_to_replicate)
        self.assertEqual(dist_res.full_tensor(), local_res)

    @with_comms
    def test_addmm_empty_operand(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        shard_spec = [Shard(0)]
        replica_spec = [Replicate()]

        tensor_to_shard = torch.randn(12, 0)
        mat1 = distribute_tensor(tensor_to_shard, device_mesh, shard_spec)
        tensor_to_replicate = torch.randn(0, 4)
        mat2 = distribute_tensor(tensor_to_replicate, device_mesh, replica_spec)
        input_tensor = torch.randn(4)
        inp = distribute_tensor(input_tensor, device_mesh, replica_spec)

        dist_res = torch.addmm(inp, mat1, mat2)
        local_res = torch.addmm(input_tensor, tensor_to_shard, tensor_to_replicate)
        self.assertEqual(dist_res.full_tensor(), local_res)

    @with_comms
    def test_addmm_auto_redistribute(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        shard0_spec = [Shard(0)]
        shard1_spec = [Shard(1)]
        replica_spec = [Replicate()]

        tensor_to_shard1 = torch.randn(12, 8, requires_grad=True)
        mat1 = distribute_tensor(tensor_to_shard1, device_mesh, shard1_spec)
        tensor_to_shard0 = torch.randn(8, 4, requires_grad=True)
        mat2 = distribute_tensor(tensor_to_shard0, device_mesh, shard0_spec)
        input_tensor = torch.randn(4, requires_grad=True)
        input = distribute_tensor(input_tensor, device_mesh, replica_spec)

        local_res = torch.addmm(input_tensor, tensor_to_shard1, tensor_to_shard0)
        dist_res = torch.addmm(input, mat1, mat2)

        # test that the addmm output carries a Partial placement
        self.assertIsInstance(dist_res, DTensor)
        self.assertIsInstance(dist_res.placements[0], Partial)

        # test that the result matches the local (single-device) computation
        dist_local_res = dist_res.full_tensor()
        self.assertEqual(local_res, dist_local_res)

        # backward checks
        dist_local_res.sum().backward()
        local_res.sum().backward()
        self.assertIsNotNone(mat2.grad)
        self.assertEqual(mat2.grad.full_tensor(), tensor_to_shard0.grad)

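    # Sharding-propagation note for the matmul tests in this class: in
    # test_addmm_auto_redistribute above, mat1 is Shard(1) and mat2 is Shard(0),
    # so the contraction dimension is sharded on both operands. Each rank then
    # holds only a partial sum of the product, which is why the output DTensor
    # carries a Partial placement, and full_tensor() (or a redistribute to
    # Replicate) performs the cross-rank reduction. Illustrative sketch only,
    # assuming the same device_mesh setup used throughout this class:
    #
    #   da = distribute_tensor(a, device_mesh, [Shard(1)])  # shard contraction dim of a
    #   db = distribute_tensor(b, device_mesh, [Shard(0)])  # shard contraction dim of b
    #   dc = torch.mm(da, db)       # placements: [Partial()], i.e. a pending reduction
    #   full = dc.full_tensor()     # reduces across ranks to match torch.mm(a, b)
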
    @with_comms
    def test_mm(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        shard0_spec = Shard(0)
        shard1_spec = Shard(1)
        replica_spec = Replicate()

        t1 = torch.randn(12, 8, requires_grad=True)
        t2 = torch.randn(8, 16, requires_grad=True)
        local_res = torch.mm(t1, t2)

        def test_placement_comb(
            placements1: List[Placement], placements2: List[Placement]
        ) -> None:
            dt1 = distribute_tensor(t1, device_mesh, placements1)
            dt2 = distribute_tensor(t2, device_mesh, placements2)
            dist_res: DTensor = cast(DTensor, torch.mm(dt1, dt2)).redistribute(
                device_mesh, [replica_spec]
            )
            self.assertEqual(dist_res.to_local(), local_res)
            # backward
            grad_dist_res = torch.ones_like(dist_res)
            dist_res.backward(grad_dist_res)
            self.assertIsNotNone(dt1.grad)

        placement_specs = [shard0_spec, shard1_spec, replica_spec]
        shard_specs_comb = list(itertools.product(placement_specs, placement_specs))
        for spec in shard_specs_comb:
            test_placement_comb([spec[0]], [spec[1]])

    @with_comms
    def test_t(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        shard_spec = [Shard(0)]

        tensor_to_transpose = torch.randn(12, 8, requires_grad=True)
        mat = distribute_tensor(tensor_to_transpose, device_mesh, shard_spec)
        transposed_mat = mat.t()
        self.assertEqual(transposed_mat.size(), torch.Size([8, 12]))
        self.assertEqual(transposed_mat.placements, [Shard(1)])
        transposed_mat2 = transposed_mat.t()
        self.assertEqual(transposed_mat2.size(), torch.Size([12, 8]))
        self.assertEqual(transposed_mat2.placements, shard_spec)

    @with_comms
    def test_t_partial(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

        a = torch.randn(12, 8)
        b = torch.randn(8, 4)
        c = torch.mm(a, b).t()

        da = distribute_tensor(a, device_mesh, [Shard(1)])
        db = distribute_tensor(b, device_mesh, [Shard(0)])

        # mm(da, db) should return a Partial tensor;
        # transposing it should keep it Partial
        dc = torch.mm(da, db).t()

        self.assertTrue(isinstance(dc.placements[0], Partial))

        # check that the local and distributed op results match
        self.assertEqual(
            c,
            dc.redistribute(device_mesh, [Replicate()]).to_local(),
        )

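    # Note on the transpose tests above: t() only needs to remap the placement,
    # so a Shard(0) placement becomes Shard(1) (and vice versa) because the
    # sharded dimension moves with the transpose, while a Partial placement
    # stays Partial since the pending reduction does not depend on the layout.
    # Illustrative sketch only, assuming the 1-D device_mesh used above:
    #
    #   mat = distribute_tensor(torch.randn(12, 8), device_mesh, [Shard(0)])
    #   mat.t().placements       # Shard(1), as asserted in test_t
    #   torch.mm(da, db).t()     # stays Partial when the contraction dim is sharded
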
    # baddbmm introduces nan occasionally on CPU: https://github.com/pytorch/pytorch/issues/80588
    @with_comms
    @skip_unless_torch_gpu
    def test_baddbmm(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        tensor = torch.rand(4, 4, 8, device=self.device_type, requires_grad=True)
        batch_1 = torch.rand(4, 4, 8, device=self.device_type, requires_grad=True)
        batch_2 = torch.rand(4, 8, 8, device=self.device_type, requires_grad=True)

        def test_placement_comb(
            tensor_placements: List[Placement],
            batch_1_placements: List[Placement],
            batch_2_placements: List[Placement],
            beta: float,
            alpha: float,
            batch_1_grad: Optional[torch.Tensor],
        ) -> None:
            tensor_dt = distribute_tensor(tensor, device_mesh, tensor_placements)
            batch_1_dt = distribute_tensor(batch_1, device_mesh, batch_1_placements)
            batch_2_dt = distribute_tensor(batch_2, device_mesh, batch_2_placements)
            dist_res = cast(
                DTensor,
                torch.baddbmm(
                    tensor_dt, batch_1_dt, batch_2_dt, beta=beta, alpha=alpha
                ),
            ).redistribute(device_mesh, [Replicate()])
            dist_local_res = dist_res.to_local()
            assert not torch.isnan(local_result).any()
            assert not torch.isnan(dist_local_res).any()
            self.assertEqual(dist_local_res.detach(), local_result.detach())

            # TODO: add backward test
            # grad_dist_res = torch.ones_like(dist_res)
            # dist_res.backward(grad_dist_res)
            # self.assertIsNotNone(batch_1_dt.grad)
            # batch_1_grad_local = batch_1_dt.grad.redistribute(
            #     device_mesh, [Replicate()]
            # ).to_local()
            # self.assertEqual(batch_1_grad_local, batch_1_grad)

        shard0_spec = Shard(0)
        shard1_spec = Shard(1)
        shard2_spec = Shard(2)
        replica_spec = Replicate()
        shard_specs = [shard0_spec, shard1_spec, shard2_spec, replica_spec]
        shard_specs_comb = list(
            itertools.product(shard_specs, shard_specs, shard_specs)
        )
        # If beta is 0, the input tensor is ignored
        numeric_params_comb = [
            (0.0, 0.5),  # zero-beta
            (0.8, 0.5),  # non-zero-beta
        ]

        for beta, alpha in numeric_params_comb:
            local_result = torch.baddbmm(
                tensor, batch_1, batch_2, beta=beta, alpha=alpha
            )
            grad_local_res = torch.ones_like(local_result)
            local_result.backward(grad_local_res)
            # test all combos
            for spec in shard_specs_comb:
                test_placement_comb(
                    [spec[0]], [spec[1]], [spec[2]], beta, alpha, batch_1.grad
                )

    @with_comms
    def test_bmm(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        mat1 = torch.rand(4, 8, 4, device=self.device_type, requires_grad=True)
        mat2 = torch.rand(4, 4, 8, device=self.device_type, requires_grad=True)
        local_result = torch.bmm(mat1, mat2)
        grad_local_res = torch.ones_like(local_result)
        local_result.backward(grad_local_res)

        def test_placement_comb(
            placements1: List[Placement],
            placements2: List[Placement],
        ) -> None:
            mat1_dt = distribute_tensor(mat1, device_mesh, placements1)
            mat2_dt = distribute_tensor(mat2, device_mesh, placements2)
            dist_res = cast(DTensor, torch.bmm(mat1_dt, mat2_dt)).redistribute(
                device_mesh, [Replicate()]
            )
            dist_local_res = dist_res.to_local()
            self.assertEqual(dist_local_res, local_result)

            # test backward
            # TODO: figure out why (replicate, shard1) fails on backward;
            # it generates a different grad shape
            grad_dist_res = torch.ones_like(dist_res)
            dist_res.backward(grad_dist_res)
            self.assertIsNotNone(mat1_dt.grad)
            mat1_dt_grad = cast(DTensor, mat1_dt.grad)
            mat1_grad_local = mat1_dt_grad.redistribute(
                device_mesh, [Replicate()]
            ).to_local()
            self.assertEqual(mat1_grad_local, mat1.grad)

        shard0_spec = Shard(0)
        shard1_spec = Shard(1)
        shard2_spec = Shard(2)
        replica_spec = Replicate()
        placement_specs = [shard0_spec, shard1_spec, shard2_spec, replica_spec]
        shard_specs_comb = list(itertools.product(placement_specs, placement_specs))

        # tests that currently pass
        for spec in shard_specs_comb:
            test_placement_comb([spec[0]], [spec[1]])

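    # CommDebugMode note for the attention test below: query/key/value are
    # sharded on the head dimension (Shard(1)); each head's attention is
    # independent, so the sharded forward and backward are expected to run
    # without any collectives. CommDebugMode records the communication ops
    # issued inside its context, and get_total_counts() returns how many were
    # seen. Minimal usage sketch (some_dtensor_op is a stand-in, not a real API):
    #
    #   comm_mode = CommDebugMode()
    #   with comm_mode:
    #       dist_out = some_dtensor_op(dist_query, dist_key, dist_value)
    #   comm_mode.get_total_counts()   # 0 means no communication happened
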
    @with_comms
    @skip_unless_torch_gpu
    def test_scaled_dot_product_attention(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        comm_mode = CommDebugMode()
        # bsz, n_heads, slen, head_dim
        query = torch.rand(
            (4, 8, 8, 8),
            device=self.device_type,
            dtype=torch.bfloat16,
            requires_grad=True,
        )
        key = torch.rand(
            (4, 8, 8, 8),
            device=self.device_type,
            dtype=torch.bfloat16,
            requires_grad=True,
        )
        value = torch.rand(
            (4, 8, 8, 8),
            device=self.device_type,
            dtype=torch.bfloat16,
            requires_grad=True,
        )

        dist_query = distribute_tensor(query, device_mesh, [Shard(1)])
        dist_key = distribute_tensor(key, device_mesh, [Shard(1)])
        dist_value = distribute_tensor(value, device_mesh, [Shard(1)])

        from torch.nn.attention import sdpa_kernel, SDPBackend

        available_backends = []
        dropout_p = 0.0
        # TODO: Add test cases where is_causal=False and an attention mask is provided.
        # Gaps include missing op support for aten.masked_fill_.Scalar.
        is_causal = True
        enable_gqa = False
        params = torch.backends.cuda.SDPAParams(
            query, key, value, None, dropout_p, is_causal, enable_gqa
        )
        if torch.backends.cuda.can_use_flash_attention(params, debug=False):
            available_backends.append(SDPBackend.FLASH_ATTENTION)
        if torch.backends.cuda.can_use_efficient_attention(params, debug=False):
            available_backends.append(SDPBackend.EFFICIENT_ATTENTION)

        for backend in available_backends:
            with sdpa_kernel(backends=[backend]):
                out = F.scaled_dot_product_attention(
                    query, key, value, dropout_p=dropout_p, is_causal=is_causal
                )
                with comm_mode:
                    dist_out = F.scaled_dot_product_attention(
                        dist_query,
                        dist_key,
                        dist_value,
                        dropout_p=dropout_p,
                        is_causal=is_causal,
                    )
                    self.assertEqual(comm_mode.get_total_counts(), 0)
                    self.assertTrue(dist_out.placements[0].is_shard(dim=1))
                    self.assertEqual(dist_out.full_tensor(), out)

                out.sum().backward()
                with comm_mode:
                    dist_out.sum().backward()
                    self.assertEqual(comm_mode.get_total_counts(), 0)
                    self.assertTrue(dist_query.grad.placements[0].is_shard(dim=1))
                    self.assertEqual(dist_query.grad.full_tensor(), query.grad)
                    self.assertTrue(dist_key.grad.placements[0].is_shard(dim=1))
                    self.assertEqual(dist_key.grad.full_tensor(), key.grad)
                    self.assertTrue(dist_value.grad.placements[0].is_shard(dim=1))
                    self.assertEqual(dist_value.grad.full_tensor(), value.grad)


if __name__ == "__main__":
    run_tests()