# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
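"""Tests for DTensor matrix ops: addmm, mm, t, baddbmm, bmm, and scaled_dot_product_attention."""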

import itertools
from typing import cast, List, Optional

import torch
import torch.nn.functional as F
from torch.distributed._tensor import DeviceMesh, distribute_tensor
from torch.distributed._tensor.api import DTensor
from torch.distributed._tensor.placement_types import (
    Partial,
    Placement,
    Replicate,
    Shard,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_unless_torch_gpu,
    with_comms,
)


class DistMatrixOpsTest(DTensorTestBase):
    @with_comms
    def test_addmm(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        shard_spec = [Shard(0)]
        replica_spec = [Replicate()]

        tensor_to_shard = torch.randn(12, 8)
        mat1 = distribute_tensor(tensor_to_shard, device_mesh, shard_spec)
        tensor_to_replicate = torch.randn(8, 4)
        mat2 = distribute_tensor(tensor_to_replicate, device_mesh, replica_spec)
        input_tensor = torch.randn(4)
        input = distribute_tensor(input_tensor, device_mesh, replica_spec)

        dist_res = torch.addmm(input, mat1, mat2)
        local_res = torch.addmm(input_tensor, tensor_to_shard, tensor_to_replicate)
        self.assertEqual(dist_res.full_tensor(), local_res)

    @with_comms
    def test_addmm_empty_operand(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        shard_spec = [Shard(0)]
        replica_spec = [Replicate()]

        tensor_to_shard = torch.randn(12, 0)
        mat1 = distribute_tensor(tensor_to_shard, device_mesh, shard_spec)
        tensor_to_replicate = torch.randn(0, 4)
        mat2 = distribute_tensor(tensor_to_replicate, device_mesh, replica_spec)
        input_tensor = torch.randn(4)
        inp = distribute_tensor(input_tensor, device_mesh, replica_spec)

        dist_res = torch.addmm(inp, mat1, mat2)
        local_res = torch.addmm(input_tensor, tensor_to_shard, tensor_to_replicate)
        self.assertEqual(dist_res.full_tensor(), local_res)

    @with_comms
    def test_addmm_auto_redistribute(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        shard0_spec = [Shard(0)]
        shard1_spec = [Shard(1)]
        replica_spec = [Replicate()]

        tensor_to_shard1 = torch.randn(12, 8, requires_grad=True)
        mat1 = distribute_tensor(tensor_to_shard1, device_mesh, shard1_spec)
        tensor_to_shard0 = torch.randn(8, 4, requires_grad=True)
        mat2 = distribute_tensor(tensor_to_shard0, device_mesh, shard0_spec)
        input_tensor = torch.randn(4, requires_grad=True)
        input = distribute_tensor(input_tensor, device_mesh, replica_spec)

        local_res = torch.addmm(input_tensor, tensor_to_shard1, tensor_to_shard0)
        dist_res = torch.addmm(input, mat1, mat2)

        # test that the addmm output has a Partial placement
        self.assertIsInstance(dist_res, DTensor)
        self.assertIsInstance(dist_res.placements[0], Partial)

        # test that the distributed result matches the local result
        dist_local_res = dist_res.full_tensor()
        self.assertEqual(local_res, dist_local_res)

        # backward checks
        dist_local_res.sum().backward()
        local_res.sum().backward()
        self.assertIsNotNone(mat2.grad)
        self.assertEqual(mat2.grad.full_tensor(), tensor_to_shard0.grad)

    @with_comms
    def test_mm(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        shard0_spec = Shard(0)
        shard1_spec = Shard(1)
        replica_spec = Replicate()

        t1 = torch.randn(12, 8, requires_grad=True)
        t2 = torch.randn(8, 16, requires_grad=True)
        local_res = torch.mm(t1, t2)

        def test_placement_comb(
            placements1: List[Placement], placements2: List[Placement]
        ) -> None:
            dt1 = distribute_tensor(t1, device_mesh, placements1)
            dt2 = distribute_tensor(t2, device_mesh, placements2)
            dist_res: DTensor = cast(DTensor, torch.mm(dt1, dt2)).redistribute(
                device_mesh, [replica_spec]
            )
            self.assertEqual(dist_res.to_local(), local_res)
            # backward
            grad_dist_res = torch.ones_like(dist_res)
            dist_res.backward(grad_dist_res)
            self.assertIsNotNone(dt1.grad)

        placement_specs = [shard0_spec, shard1_spec, replica_spec]
        shard_specs_comb = list(itertools.product(placement_specs, placement_specs))
        for spec in shard_specs_comb:
            test_placement_comb([spec[0]], [spec[1]])

    @with_comms
    def test_t(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        shard_spec = [Shard(0)]

        tensor_to_transpose = torch.randn(12, 8, requires_grad=True)
        mat = distribute_tensor(tensor_to_transpose, device_mesh, shard_spec)
        transposed_mat = mat.t()
        self.assertEqual(transposed_mat.size(), torch.Size([8, 12]))
        self.assertEqual(transposed_mat.placements, [Shard(1)])
        transposed_mat2 = transposed_mat.t()
        self.assertEqual(transposed_mat2.size(), torch.Size([12, 8]))
        self.assertEqual(transposed_mat2.placements, shard_spec)

    @with_comms
    def test_t_partial(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))

        a = torch.randn(12, 8)
        b = torch.randn(8, 4)
        c = torch.mm(a, b).t()

        da = distribute_tensor(a, device_mesh, [Shard(1)])
        db = distribute_tensor(b, device_mesh, [Shard(0)])

        # mm(da, db) should return a Partial tensor.
        # transposing it should keep it Partial
        dc = torch.mm(da, db).t()

        self.assertTrue(isinstance(dc.placements[0], Partial))

        # check that the local and distributed op results match
        self.assertEqual(
            c,
            dc.redistribute(device_mesh, [Replicate()]).to_local(),
        )

    # baddbmm introduces nan occasionally on CPU: https://github.com/pytorch/pytorch/issues/80588
    @with_comms
    @skip_unless_torch_gpu
    def test_baddbmm(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        tensor = torch.rand(4, 4, 8, device=self.device_type, requires_grad=True)
        batch_1 = torch.rand(4, 4, 8, device=self.device_type, requires_grad=True)
        batch_2 = torch.rand(4, 8, 8, device=self.device_type, requires_grad=True)

        def test_placement_comb(
            tensor_placements: List[Placement],
            batch_1_placements: List[Placement],
            batch_2_placements: List[Placement],
            beta: float,
            alpha: float,
            batch_1_grad: Optional[torch.Tensor],
        ) -> None:
            tensor_dt = distribute_tensor(tensor, device_mesh, tensor_placements)
            batch_1_dt = distribute_tensor(batch_1, device_mesh, batch_1_placements)
            batch_2_dt = distribute_tensor(batch_2, device_mesh, batch_2_placements)
            dist_res = cast(
                DTensor,
                torch.baddbmm(
                    tensor_dt, batch_1_dt, batch_2_dt, beta=beta, alpha=alpha
                ),
            ).redistribute(device_mesh, [Replicate()])
            dist_local_res = dist_res.to_local()
            assert not torch.isnan(local_result).any()
            assert not torch.isnan(dist_local_res).any()
            self.assertEqual(dist_local_res.detach(), local_result.detach())

            # TODO: add test backward
            # grad_dist_res = torch.ones_like(dist_res)
            # dist_res.backward(grad_dist_res)
            # self.assertIsNotNone(batch_1_dt.grad)
            # batch_1_grad_local = batch_1_dt.grad.redistribute(
            #     device_mesh, [Replicate()]
            # ).to_local()
            # self.assertEqual(batch_1_grad_local, batch_1_grad)

        shard0_spec = Shard(0)
        shard1_spec = Shard(1)
        shard2_spec = Shard(2)
        replica_spec = Replicate()
        shard_specs = [shard0_spec, shard1_spec, shard2_spec, replica_spec]
        shard_specs_comb = list(
            itertools.product(shard_specs, shard_specs, shard_specs)
        )
        # If beta is 0, the input tensor is ignored
        numeric_params_comb = [
            (0.0, 0.5),  # zero-beta
            (0.8, 0.5),  # non-zero-beta
        ]

        for beta, alpha in numeric_params_comb:
            local_result = torch.baddbmm(
                tensor, batch_1, batch_2, beta=beta, alpha=alpha
            )
            grad_local_res = torch.ones_like(local_result)
            local_result.backward(grad_local_res)
            # test all combos
            for spec in shard_specs_comb:
                test_placement_comb(
                    [spec[0]], [spec[1]], [spec[2]], beta, alpha, batch_1.grad
                )

    @with_comms
    def test_bmm(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        mat1 = torch.rand(4, 8, 4, device=self.device_type, requires_grad=True)
        mat2 = torch.rand(4, 4, 8, device=self.device_type, requires_grad=True)
        local_result = torch.bmm(mat1, mat2)
        grad_local_res = torch.ones_like(local_result)
        local_result.backward(grad_local_res)

        def test_placement_comb(
            placements1: List[Placement],
            placements2: List[Placement],
        ) -> None:
            mat1_dt = distribute_tensor(mat1, device_mesh, placements1)
            mat2_dt = distribute_tensor(mat2, device_mesh, placements2)
            dist_res = cast(DTensor, torch.bmm(mat1_dt, mat2_dt)).redistribute(
                device_mesh, [Replicate()]
            )
            dist_local_res = dist_res.to_local()
            self.assertEqual(dist_local_res, local_result)

            # test backward
            # TODO: figure out why (replicate, shard1) fails on backward;
            # it generates a different grad shape
            grad_dist_res = torch.ones_like(dist_res)
            dist_res.backward(grad_dist_res)
            self.assertIsNotNone(mat1_dt.grad)
            mat1_dt_grad = cast(DTensor, mat1_dt.grad)
            mat1_grad_local = mat1_dt_grad.redistribute(
                device_mesh, [Replicate()]
            ).to_local()
            self.assertEqual(mat1_grad_local, mat1.grad)

        shard0_spec = Shard(0)
        shard1_spec = Shard(1)
        shard2_spec = Shard(2)
        replica_spec = Replicate()
        placement_specs = [shard0_spec, shard1_spec, shard2_spec, replica_spec]
        shard_specs_comb = list(itertools.product(placement_specs, placement_specs))

        # all placement combinations below currently pass
        for spec in shard_specs_comb:
            test_placement_comb([spec[0]], [spec[1]])

    @with_comms
    @skip_unless_torch_gpu
    def test_scaled_dot_product_attention(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        comm_mode = CommDebugMode()
        # bsz, n_heads, slen, head_dim
        query = torch.rand(
            (4, 8, 8, 8),
            device=self.device_type,
            dtype=torch.bfloat16,
            requires_grad=True,
        )
        key = torch.rand(
            (4, 8, 8, 8),
            device=self.device_type,
            dtype=torch.bfloat16,
            requires_grad=True,
        )
        value = torch.rand(
            (4, 8, 8, 8),
            device=self.device_type,
            dtype=torch.bfloat16,
            requires_grad=True,
        )

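        # q/k/v are sharded along the head dimension (dim 1); SDPA is expected
        # to propagate that sharding and run without any communication.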
        dist_query = distribute_tensor(query, device_mesh, [Shard(1)])
        dist_key = distribute_tensor(key, device_mesh, [Shard(1)])
        dist_value = distribute_tensor(value, device_mesh, [Shard(1)])

        from torch.nn.attention import sdpa_kernel, SDPBackend

        available_backends = []
        dropout_p = 0.0
        # TODO: Add test cases where is_causal=False and an attention mask is provided.
        #       Gaps include missing op support for aten.masked_fill_.Scalar.
        is_causal = True
        enable_gqa = False
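        # Query which fused SDPA backends the local inputs can use on this
        # device, and only exercise those backends below.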
        params = torch.backends.cuda.SDPAParams(
            query, key, value, None, dropout_p, is_causal, enable_gqa
        )
        if torch.backends.cuda.can_use_flash_attention(params, debug=False):
            available_backends.append(SDPBackend.FLASH_ATTENTION)
        if torch.backends.cuda.can_use_efficient_attention(params, debug=False):
            available_backends.append(SDPBackend.EFFICIENT_ATTENTION)

        for backend in available_backends:
            with sdpa_kernel(backends=[backend]):
                out = F.scaled_dot_product_attention(
                    query, key, value, dropout_p=dropout_p, is_causal=is_causal
                )
                with comm_mode:
                    dist_out = F.scaled_dot_product_attention(
                        dist_query,
                        dist_key,
                        dist_value,
                        dropout_p=dropout_p,
                        is_causal=is_causal,
                    )
                    self.assertEqual(comm_mode.get_total_counts(), 0)
                    self.assertTrue(dist_out.placements[0].is_shard(dim=1))
                    self.assertEqual(dist_out.full_tensor(), out)

                out.sum().backward()
                with comm_mode:
                    dist_out.sum().backward()
                    self.assertEqual(comm_mode.get_total_counts(), 0)
                    self.assertTrue(dist_query.grad.placements[0].is_shard(dim=1))
                    self.assertEqual(dist_query.grad.full_tensor(), query.grad)
                    self.assertTrue(dist_key.grad.placements[0].is_shard(dim=1))
                    self.assertEqual(dist_key.grad.full_tensor(), key.grad)
                    self.assertTrue(dist_value.grad.placements[0].is_shard(dim=1))
                    self.assertEqual(dist_value.grad.full_tensor(), value.grad)


if __name__ == "__main__":
    run_tests()