# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import unittest

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor.experimental._attention import (
    _AttentionContextParallel,
    _CausalBehavior,
    _context_parallel_buffers,
    _is_causal_behavior,
    context_parallel,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import parallelize_module
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FLASH_ATTENTION,
    PLATFORM_SUPPORTS_FUSED_ATTENTION,
    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    skipIfRocm,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    ModelArgs,
    Transformer,
    with_comms,
)


c10d_functional = torch.ops.c10d_functional
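# Collect the SDPA backends available on this platform; the ring attention
# SDPA test below is parametrized over them.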
backends = []
if PLATFORM_SUPPORTS_FLASH_ATTENTION:
    backends.append(SDPBackend.FLASH_ATTENTION)
if PLATFORM_SUPPORTS_MEM_EFF_ATTENTION:
    backends.append(SDPBackend.EFFICIENT_ATTENTION)


class RingAttentionTest(DTensorTestBase):
    @property
    def world_size(self) -> int:
        return torch.cuda.device_count()

    @skip_if_lt_x_gpu(2)
    @skipIfRocm  # Missing _c10d_functional_autograd::all_to_all_single
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FUSED_ATTENTION,
        "Does not support flash or efficient attention",
    )
    @with_comms
    @parametrize("is_causal", [True, False])
    @parametrize("compiled", [True, False])
    @parametrize("backend", backends)
    def test_ring_attention_sdpa(
        self, is_causal: bool, compiled: bool, backend: SDPBackend
    ) -> None:
        device_mesh = DeviceMesh(self.device_type, torch.arange(0, self.world_size))
        bs = 8
        query_tokens = 64
        context_tokens = 64
        dim = 32
        nheads = 8
        torch.manual_seed(10)
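        # FlashAttention kernels support only fp16/bf16 inputs, so use bfloat16
        # there and float32 otherwise, which permits the tighter tolerances below.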
        dtype = (
            torch.bfloat16 if backend == SDPBackend.FLASH_ATTENTION else torch.float32
        )

        if is_causal and compiled and self.world_size > 2:
            # TODO: Fix this after we move `wait_tensor` to use `with_effect`.
            return

        q = torch.rand(
            (bs, nheads, self.world_size * query_tokens, dim),
            device=self.device_type,
            dtype=dtype,
            requires_grad=True,
        )
        k = torch.rand(
            (bs, nheads, self.world_size * context_tokens, dim),
            device=self.device_type,
            dtype=dtype,
            requires_grad=True,
        )
        v = torch.rand(
            (bs, nheads, self.world_size * context_tokens, dim),
            device=self.device_type,
            dtype=dtype,
            requires_grad=True,
        )

        # Ensure all ranks have the same initialization data.
        with torch.no_grad():
            dist.broadcast(q, src=0)
            dist.broadcast(k, src=0)
            dist.broadcast(v, src=0)

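        # Single-device reference: run SDPA over the full, unsharded sequence
        # and backpropagate to obtain reference gradients.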
        with sdpa_kernel(backend):
            out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
            out.sum().backward()

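        # Shard the reference output and gradients along the sequence dim so
        # they can be compared with this rank's context-parallel results.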
        local_out, local_dq, local_dk, local_dv = _context_parallel_buffers(
            device_mesh,
            buffers=(out, q.grad, k.grad, v.grad),
            buffer_seq_dims=(2, 2, 2, 2),
        )

        cp_q = q.clone().detach()
        cp_k = k.clone().detach()
        cp_v = v.clone().detach()
        # In theory, context_parallel() should not be used to shard parameters,
        # because resize_ is not allowed when requires_grad is True. However,
        # requires_grad of cp_q, cp_k, and cp_v is False at this point, so we
        # can use context_parallel() to shard q, k, and v here. In practice,
        # context_parallel() should be used to shard the model inputs.
        with context_parallel(
            device_mesh, buffers=(cp_q, cp_k, cp_v), buffer_seq_dims=(2, 2, 2)
        ):
            cp_q.requires_grad = True
            cp_k.requires_grad = True
            cp_v.requires_grad = True
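            # Within context_parallel(), SDPA dispatches to the ring-attention
            # (context-parallel) implementation, which exchanges KV shards
            # across ranks.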
            with CommDebugMode() as comm_mode:
                with sdpa_kernel(backend):
                    if compiled:
                        fn = torch.compile(
                            F.scaled_dot_product_attention,
                            fullgraph=True,
                            backend="aot_eager",
                        )
                    else:
                        fn = F.scaled_dot_product_attention

                    cp_out = fn(cp_q, cp_k, cp_v, is_causal=is_causal)
                    cp_out.sum().backward()

                    if not compiled:
                        # Compiler and CommDebugMode do not work well together.
                        self.assertDictEqual(
                            comm_mode.get_comm_counts(),
                            {
                                c10d_functional.all_to_all_single: self.world_size * 3
                                - 2
                            },
                        )

            # Due to numerical error, we need to choose different atol values
            # for different attention kernels.
            atol = (
                1e-08
                if backend == SDPBackend.EFFICIENT_ATTENTION
                else 1e-3 * self.world_size
            )
            self.assertTrue(torch.allclose(local_out, cp_out, atol=atol))

            atol = (
                2e-06
                if backend == SDPBackend.EFFICIENT_ATTENTION
                else 8e-3 * self.world_size
            )
            self.assertTrue(torch.allclose(local_dq, cp_q.grad, atol=atol))
            self.assertTrue(torch.allclose(local_dk, cp_k.grad, atol=atol))
            self.assertTrue(torch.allclose(local_dv, cp_v.grad, atol=atol))

            cp_q.grad = None
            cp_k.grad = None
            cp_v.grad = None
            cp_q.requires_grad = False
            cp_k.requires_grad = False
            cp_v.requires_grad = False

    def test_is_causal_behavior(self) -> None:
        self.assertEqual(
            _is_causal_behavior(rank=0, world_size=4, i=0, is_causal=False),
            _CausalBehavior.NOT_IS_CAUSAL,
        )

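        # With is_causal=True and world_size=2, each rank attends causally to
        # its own KV shard on iteration 0. On iteration 1, rank 0 skips rank 1's
        # (future) shard, while rank 1 attends fully to rank 0's (earlier) shard.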
        ranks = [
            [_CausalBehavior.IS_CAUSAL, _CausalBehavior.SKIP],
            [_CausalBehavior.IS_CAUSAL, _CausalBehavior.NOT_IS_CAUSAL],
        ]
        for rank, iters in enumerate(ranks):
            for i, behavior in enumerate(iters):
                self.assertEqual(
                    _is_causal_behavior(rank=rank, world_size=2, i=i, is_causal=True),
                    behavior,
                )

    @skip_if_lt_x_gpu(2)
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention"
    )
    @with_comms
    @sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION])
    @parametrize("is_causal", [True, False])
    def test_ring_attention_native_transformer(self, is_causal: bool) -> None:
        device_mesh = DeviceMesh(
            self.device_type,
            torch.arange(0, self.world_size),
        )
        dtype = torch.bfloat16
        bs = 8
        ntokens = 8
        dim = 32
        nheads = 8
        num_layers = 2

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=nheads,
            dim_feedforward=dim,
            batch_first=True,
        ).to(dtype)
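        # Run each layer's self-attention with context parallelism (ring
        # attention) across the device mesh.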
        encoder_layer = parallelize_module(
            module=encoder_layer,
            device_mesh=device_mesh,
            parallelize_plan={
                "self_attn": _AttentionContextParallel(),
            },
        )
        model = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        model = model.to(self.device_type).to(dtype)

        mask = (
            nn.Transformer.generate_square_subsequent_mask(
                ntokens, device=self.device_type, dtype=dtype
            )
            if is_causal
            else None
        )
        seq = torch.rand((bs, ntokens, dim), device=self.device_type, dtype=dtype)

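        # Forward: ring attention rotates the KV shards around the mesh, which
        # costs (world_size - 1) all_to_all_single calls per attention layer.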
        with CommDebugMode() as comm_mode:
            out = model(seq, mask=mask, is_causal=is_causal)
        self.assertDictEqual(
            comm_mode.get_comm_counts(),
            {
                c10d_functional.all_to_all_single: (self.world_size - 1) * num_layers,
            },
        )

        with CommDebugMode() as comm_mode:
            out.sum().backward()
        self.assertDictEqual(
            comm_mode.get_comm_counts(),
            {
                c10d_functional.all_to_all_single: (self.world_size * 2 - 1)
                * num_layers,
            },
        )

    @skip_if_lt_x_gpu(2)
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention"
    )
    @with_comms
    @sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION])
    def test_ring_attention_custom_transformer(self) -> None:
        device_mesh = DeviceMesh(
            self.device_type,
            torch.arange(0, self.world_size),
        )
        dtype = torch.bfloat16
        bs = 2
        args = ModelArgs()

        model = Transformer(args).to(dtype).to(self.device_type)

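        # Apply context parallelism (ring attention) to every attention module
        # of the Transformer.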
        model = parallelize_module(
            module=model,
            device_mesh=device_mesh,
            parallelize_plan={
                f"layers.{i}.attention": _AttentionContextParallel()
                for i in range(args.n_layers)
            },
        )

        seq = torch.randint(
            args.vocab_size, (bs, args.max_seq_len), device=self.device_type
        )

        with CommDebugMode() as comm_mode:
            out = model(seq)
        self.assertDictEqual(
            comm_mode.get_comm_counts(),
            {
                c10d_functional.all_to_all_single: (self.world_size - 1)
                * args.n_layers,
            },
        )

        with CommDebugMode() as comm_mode:
            out.sum().backward()
        self.assertDictEqual(
            comm_mode.get_comm_counts(),
            {
                c10d_functional.all_to_all_single: (self.world_size * 2 - 1)
                * args.n_layers,
            },
        )


if backends:
    instantiate_parametrized_tests(RingAttentionTest)

if __name__ == "__main__":
    run_tests()