# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import unittest

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor.experimental._attention import (
    _AttentionContextParallel,
    _CausalBehavior,
    _context_parallel_buffers,
    _is_causal_behavior,
    context_parallel,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import parallelize_module
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FLASH_ATTENTION,
    PLATFORM_SUPPORTS_FUSED_ATTENTION,
    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    skipIfRocm,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    ModelArgs,
    Transformer,
    with_comms,
)


c10d_functional = torch.ops.c10d_functional
backends = []
if PLATFORM_SUPPORTS_FLASH_ATTENTION:
    backends.append(SDPBackend.FLASH_ATTENTION)
if PLATFORM_SUPPORTS_MEM_EFF_ATTENTION:
    backends.append(SDPBackend.EFFICIENT_ATTENTION)


class RingAttentionTest(DTensorTestBase):
    @property
    def world_size(self) -> int:
        return torch.cuda.device_count()

    @skip_if_lt_x_gpu(2)
    @skipIfRocm  # Missing _c10d_functional_autograd::all_to_all_single
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FUSED_ATTENTION,
        "Does not support flash nor efficient attention",
    )
    @with_comms
    @parametrize("is_causal", [True, False])
    @parametrize("compiled", [True, False])
    @parametrize("backend", backends)
    def test_ring_attention_sdpa(
        self, is_causal: bool, compiled: bool, backend: SDPBackend
    ) -> None:
        device_mesh = DeviceMesh(self.device_type, torch.arange(0, self.world_size))
        dtype = torch.bfloat16
        bs = 8
        query_tokens = 64
        context_tokens = 64
        dim = 32
        nheads = 8
        torch.manual_seed(10)
        dtype = (
            torch.bfloat16 if backend == SDPBackend.FLASH_ATTENTION else torch.float32
        )

        if is_causal and compiled and self.world_size > 2:
            # TODO: Fix this after we move `wait_tensor` to use `with_effect`.
            return

        q = torch.rand(
            (bs, nheads, self.world_size * query_tokens, dim),
            device=self.device_type,
            dtype=dtype,
            requires_grad=True,
        )
        k = torch.rand(
            (bs, nheads, self.world_size * context_tokens, dim),
            device=self.device_type,
            dtype=dtype,
            requires_grad=True,
        )
        v = torch.rand(
            (bs, nheads, self.world_size * context_tokens, dim),
            device=self.device_type,
            dtype=dtype,
            requires_grad=True,
        )

        # Ensure all ranks have the same initialization data.
        with torch.no_grad():
            dist.broadcast(q, src=0)
            dist.broadcast(k, src=0)
            dist.broadcast(v, src=0)

        with sdpa_kernel(backend):
            out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
            out.sum().backward()

        local_out, local_dq, local_dk, local_dv = _context_parallel_buffers(
            device_mesh,
            buffers=(out, q.grad, k.grad, v.grad),
            buffer_seq_dims=(2, 2, 2, 2),
        )

        cp_q = q.clone().detach()
        cp_k = k.clone().detach()
        cp_v = v.clone().detach()
        # Theoretically, context_parallel() should not be used to shard
        # parameters, because resize_ is not allowed when requires_grad is
        # True. But requires_grad of cp_q, cp_k, and cp_v is False at this
        # point, so we can just use context_parallel() to shard q, k, and v.
        # In practice, context_parallel() should be used to shard the input.
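        # A minimal sketch of that intended usage (illustrative only: `mesh`,
        # `model`, and `inp` are hypothetical names, and the sequence
        # dimension of `inp` is assumed to be 1):
        #
        #     with context_parallel(mesh, buffers=(inp,), buffer_seq_dims=(1,)):
        #         model(inp).sum().backward()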
        with context_parallel(
            device_mesh, buffers=(cp_q, cp_k, cp_v), buffer_seq_dims=(2, 2, 2)
        ):
            cp_q.requires_grad = True
            cp_k.requires_grad = True
            cp_v.requires_grad = True
            with CommDebugMode() as comm_mode:
                with sdpa_kernel(backend):
                    if compiled:
                        fn = torch.compile(
                            F.scaled_dot_product_attention,
                            fullgraph=True,
                            backend="aot_eager",
                        )
                    else:
                        fn = F.scaled_dot_product_attention

                    cp_out = fn(cp_q, cp_k, cp_v, is_causal=is_causal)
                    cp_out.sum().backward()

                if not compiled:
                    # Compiler and CommDebugMode do not work well together.
                    self.assertDictEqual(
                        comm_mode.get_comm_counts(),
                        {
                            c10d_functional.all_to_all_single: self.world_size * 3 - 2,
                        },
                    )

            # Due to numerical error, we need to choose different atol for
            # different attention kernels.
            atol = (
                1e-08
                if backend == SDPBackend.EFFICIENT_ATTENTION
                else 1e-3 * self.world_size
            )
            self.assertTrue(torch.allclose(local_out, cp_out, atol=atol))

            atol = (
                2e-06
                if backend == SDPBackend.EFFICIENT_ATTENTION
                else 8e-3 * self.world_size
            )
            self.assertTrue(torch.allclose(local_dq, cp_q.grad, atol=atol))
            self.assertTrue(torch.allclose(local_dk, cp_k.grad, atol=atol))
            self.assertTrue(torch.allclose(local_dv, cp_v.grad, atol=atol))

            cp_q.grad = None
            cp_k.grad = None
            cp_v.grad = None
            cp_q.requires_grad = False
            cp_k.requires_grad = False
            cp_v.requires_grad = False

    def test_is_causal_behavior(self) -> None:
        self.assertEqual(
            _is_causal_behavior(rank=0, world_size=4, i=0, is_causal=False),
            _CausalBehavior.NOT_IS_CAUSAL,
        )

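        # Expected schedule for world_size=2 with is_causal=True. Informal
        # intuition (assuming each rank holds a contiguous shard of the
        # sequence): on iteration 0 every rank attends to its own K/V block,
        # so the local causal mask applies; on iteration 1, rank 0 receives
        # rank 1's strictly-later K/V block and can skip it entirely, while
        # rank 1 receives rank 0's strictly-earlier K/V block and attends to
        # all of it without a mask.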
        ranks = [
            [_CausalBehavior.IS_CAUSAL, _CausalBehavior.SKIP],
            [_CausalBehavior.IS_CAUSAL, _CausalBehavior.NOT_IS_CAUSAL],
        ]
        for rank, iters in enumerate(ranks):
            for i, behavior in enumerate(iters):
                self.assertEqual(
                    _is_causal_behavior(rank=rank, world_size=2, i=i, is_causal=True),
                    behavior,
                )

    @skip_if_lt_x_gpu(2)
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention"
    )
    @with_comms
    @sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION])
    @parametrize("is_causal", [True, False])
    def test_ring_attention_native_transformer(self, is_causal: bool) -> None:
        device_mesh = DeviceMesh(
            self.device_type,
            torch.arange(0, self.world_size),
        )
        dtype = torch.bfloat16
        bs = 8
        ntokens = 8
        dim = 32
        nheads = 8
        num_layers = 2

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=nheads,
            dim_feedforward=dim,
            batch_first=True,
        ).to(dtype)
        encoder_layer = parallelize_module(
            module=encoder_layer,
            device_mesh=device_mesh,
            parallelize_plan={
                "self_attn": _AttentionContextParallel(),
            },
        )
        model = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        model = model.to(self.device_type).to(dtype)

        mask = (
            nn.Transformer.generate_square_subsequent_mask(
                ntokens, device=self.device_type, dtype=dtype
            )
            if is_causal
            else None
        )
        seq = torch.rand((bs, ntokens, dim), device=self.device_type, dtype=dtype)

        with CommDebugMode() as comm_mode:
            out = model(seq, mask=mask, is_causal=is_causal)
        self.assertDictEqual(
            comm_mode.get_comm_counts(),
            {
                c10d_functional.all_to_all_single: (self.world_size - 1) * num_layers,
            },
        )

        with CommDebugMode() as comm_mode:
            out.sum().backward()
        self.assertDictEqual(
            comm_mode.get_comm_counts(),
            {
                c10d_functional.all_to_all_single: (self.world_size * 2 - 1)
                * num_layers,
            },
        )

    @skip_if_lt_x_gpu(2)
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention"
    )
    @with_comms
    @sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION])
    def test_ring_attention_custom_transformer(self) -> None:
        device_mesh = DeviceMesh(
            self.device_type,
            torch.arange(0, self.world_size),
        )
        dtype = torch.bfloat16
        bs = 2
        args = ModelArgs()

        model = Transformer(args).to(dtype).to(self.device_type)

        model = parallelize_module(
            module=model,
            device_mesh=device_mesh,
            parallelize_plan={
                f"layers.{i}.attention": _AttentionContextParallel()
                for i in range(args.n_layers)
            },
        )

        seq = torch.randint(
            args.vocab_size, (bs, args.max_seq_len), device=self.device_type
        )

        with CommDebugMode() as comm_mode:
            out = model(seq)
        self.assertDictEqual(
            comm_mode.get_comm_counts(),
            {
                c10d_functional.all_to_all_single: (self.world_size - 1)
                * args.n_layers,
            },
        )

        with CommDebugMode() as comm_mode:
            out.sum().backward()
        self.assertDictEqual(
            comm_mode.get_comm_counts(),
            {
                c10d_functional.all_to_all_single: (self.world_size * 2 - 1)
                * args.n_layers,
            },
        )


if backends:
    instantiate_parametrized_tests(RingAttentionTest)

if __name__ == "__main__":
    run_tests()