# Owner(s): ["module: inductor"] import functools import itertools import math import torch import torch._inductor.config import torch.utils.checkpoint from torch._dynamo.debug_utils import aot_graph_input_parser from torch._dynamo.utils import counters from torch._inductor.test_case import run_tests, TestCase from torch._inductor.utils import run_and_get_code from torch.testing._internal.common_cuda import ( PLATFORM_SUPPORTS_FUSED_ATTENTION, SM80OrLater, ) from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA def checkpoint_wrapper(fn): def inner(*args): return torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=True) return inner class TestSDPAPatternRewriterTemplate(TestCase): use_static_shapes = True def _clone_inputs(self, inputs): def clone(x): if not isinstance(x, torch.Tensor): return x return x.clone() return [clone(x) for x in inputs] def _check_common( self, dot_prod_attention, args1=None, contains=True, atol=1e-5, has_fuse_pattern=True, has_dropout=False, check_train=True, override_check_equal=False, dtype=torch.float, rtol=1.3e-6, ): if args1 is None: tensor_shape = (4, 2, 16, 32) args1 = [ torch.randn(tensor_shape, device=self.device, dtype=dtype), torch.randn(tensor_shape, device=self.device, dtype=dtype), torch.randn(tensor_shape, device=self.device, dtype=dtype), ] else: args1 = list(args1) args2 = self._clone_inputs(args1) for training in [False, True] if check_train else [False]: for x in itertools.chain(args1[:], args2[:]): if isinstance(x, torch.Tensor) and x.is_floating_point(): x.requires_grad = training if not self.use_static_shapes: torch._dynamo.mark_dynamic(args2[0], 0) torch._dynamo.mark_dynamic(args2[1], 0) torch._dynamo.mark_dynamic(args2[2], 0) dropout_arg = [training] if has_dropout else [] torch.manual_seed(1234) result1 = dot_prod_attention(*(args1 + dropout_arg)) counters.clear() torch.manual_seed(1234) result2, source_code = run_and_get_code( torch.compile(dot_prod_attention, fullgraph=True), *(args2 + dropout_arg), ) source_code = "\n".join(source_code) if has_fuse_pattern: self.assertGreaterEqual(counters["inductor"]["fuse_attention"], 1) if contains: # many of the patterns get re-expanded in dispatcher self.assertIn( "aten._scaled_dot_product", source_code, ) # some tests configured with very low dropout where we still want to check equality if not has_dropout or override_check_equal: self.assertEqual(result1, result2, atol=atol, rtol=1.3e-6) if training: result1.sum().backward() result2.sum().backward() for arg1, arg2 in zip(args1, args2): if ( isinstance(arg1, torch.Tensor) and arg1.is_floating_point() and (not has_dropout or override_check_equal) ): self.assertEqual(arg1.grad, arg2.grad, atol=atol, rtol=rtol) @skipIfRocm def _test_sdpa_rewriter_1(self): def dot_prod_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor ) -> torch.Tensor: """Input tensors assumed to have shape (batch_size, n_head, seq_len, embed_dim)""" return ( torch.matmul(query, key.transpose(-2, -1)) .div(math.sqrt(key.shape[-1])) .softmax(dim=-1) .matmul(value) ) for dtype in [torch.float, torch.half]: atol = 0.001 rtol = 1.3e-6 if dtype == torch.float else 0.7 if self.device == "cpu" and dtype == torch.half: atol = 2e-3 rtol = 1e-2 self._check_common(dot_prod_attention, dtype=dtype, atol=atol, rtol=rtol) self._check_common( checkpoint_wrapper(dot_prod_attention), dtype=dtype, atol=atol, rtol=rtol, ) @skipIfRocm @torch._inductor.config.patch("freezing", True) def 
    @skipIfRocm
    @torch._inductor.config.patch("freezing", True)
    def _test_sdpa_rewriter_1_freezing(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, n_head, seq_len, embed_dim)"""
            return (
                torch.matmul(query, key.transpose(-2, -1))
                .div(math.sqrt(key.shape[-1]))
                .softmax(dim=-1)
                .matmul(value)
            )

        for dtype in [torch.float, torch.half]:
            atol = 0.001
            rtol = 1.3e-6 if dtype == torch.float else 0.7
            if self.device == "cpu" and dtype == torch.half:
                atol = 2e-3
                rtol = 1e-2
            with torch.no_grad():
                self._check_common(
                    dot_prod_attention,
                    dtype=dtype,
                    atol=atol,
                    rtol=rtol,
                    check_train=False,
                )

    @skipIfRocm
    def _test_insignificant_strides(self):
        f32 = torch.float32

        # repro taken from https://github.com/pytorch/pytorch/issues/124289
        # constant_pad_nd is a single element tensor that gets expanded
        def forward(
            permute_3: "f32[1, 32, 1, 128]",
            permute_4: "f32[1, 32, 1, 128]",
            permute_5: "f32[1, 32, 1, 128]",
            permute_6: "f32[1, 1, 64]",
            mul_2: "f32[1, 1, 1, 1]",
        ):
            cat = torch.ops.aten.cat.default([permute_6, permute_6], 2)
            permute_6 = None
            cos = torch.ops.aten.cos.default(cat)
            sin = torch.ops.aten.sin.default(cat)
            unsqueeze_10 = torch.ops.aten.unsqueeze.default(cos, 1)
            cos = None
            unsqueeze_11 = torch.ops.aten.unsqueeze.default(sin, 1)
            sin = None
            mul_5 = torch.ops.aten.mul.Tensor(permute_3, unsqueeze_10)
            slice_10 = torch.ops.aten.slice.Tensor(permute_3, 3, 0, 64)
            slice_11 = torch.ops.aten.slice.Tensor(
                permute_3, 3, 64, 9223372036854775807
            )
            permute_3 = None
            neg = torch.ops.aten.neg.default(slice_11)
            slice_11 = None
            cat_1 = torch.ops.aten.cat.default([neg, slice_10], 3)
            neg = slice_10 = None
            mul_6 = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_11)
            cat_1 = None
            add_1 = torch.ops.aten.add.Tensor(mul_5, mul_6)
            mul_5 = mul_6 = None
            mul_7 = torch.ops.aten.mul.Tensor(permute_4, unsqueeze_10)
            unsqueeze_10 = None
            slice_12 = torch.ops.aten.slice.Tensor(permute_4, 3, 0, 64)
            slice_13 = torch.ops.aten.slice.Tensor(
                permute_4, 3, 64, 9223372036854775807
            )
            permute_4 = None
            neg_1 = torch.ops.aten.neg.default(slice_13)
            slice_13 = None
            cat_2 = torch.ops.aten.cat.default([neg_1, slice_12], 3)
            neg_1 = slice_12 = None
            mul_8 = torch.ops.aten.mul.Tensor(cat_2, unsqueeze_11)
            cat_2 = unsqueeze_11 = None
            add_2 = torch.ops.aten.add.Tensor(mul_7, mul_8)
            mul_7 = mul_8 = None
            slice_14 = torch.ops.aten.slice.Tensor(mul_2, 0, 0, 9223372036854775807)
            mul_2 = None
            slice_15 = torch.ops.aten.slice.Tensor(slice_14, 1, 0, 9223372036854775807)
            slice_14 = None
            slice_16 = torch.ops.aten.slice.Tensor(slice_15, 2, 0, 9223372036854775807)
            slice_15 = None
            constant_pad_nd = torch.ops.aten.constant_pad_nd.default(
                slice_16, [0, 7], 0.0
            )
            slice_16 = None
            slice_17 = torch.ops.aten.slice.Tensor(constant_pad_nd, -1, 0, 1)
            constant_pad_nd = None
            expand_5 = torch.ops.aten.expand.default(slice_17, [1, 32, 1, 1])
            _scaled_dot_product_efficient_attention = (
                torch.ops.aten._scaled_dot_product_efficient_attention.default(
                    add_1, add_2, permute_5, expand_5, True
                )
            )
            return _scaled_dot_product_efficient_attention

        kwargs = aot_graph_input_parser(forward, device="cuda")
        # runs successfully
        out_eager = forward(**kwargs)
        out_c = torch.compile(forward)(**kwargs)

        # don't compare philox_seed/offset
        torch.testing.assert_close(out_eager[0:2], out_c[0:2])
    def _test_pattern_fails_with_reuse(self):
        """
        This test checks that the replacement is not done
        when an intermediate result is being used / returned downstream
        """

        @torch.compile(fullgraph=True)
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            attn_weights = (
                torch.matmul(query, key.transpose(-2, -1))
                .div(math.sqrt(key.shape[-1]))
                .softmax(dim=-1)
            )
            return attn_weights.matmul(value), attn_weights

        tensor_shape = (2, 4, 8, 16)
        args = [
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
        ]
        _, (source_code,) = run_and_get_code(dot_prod_attention, *args)
        self.assertNotIn("aten._scaled_dot_product_efficient_attention", source_code)

    @skipIfRocm
    def _test_sdpa_rewriter_2(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            return (
                torch.matmul(query, key.transpose(-2, -1))
                .mul(1.0 / math.sqrt(key.shape[-1]))
                .softmax(dim=-1)
                .matmul(value)
            )

        self._check_common(dot_prod_attention)
        self._check_common(checkpoint_wrapper(dot_prod_attention))

    @skipIfRocm  # AssertionError: expected size 4==4, stride 32==64 at dim=0
    def _test_sdpa_rewriter_3(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, training: bool
        ) -> torch.Tensor:
            return torch.nn.functional.dropout(
                torch.matmul(query, key.transpose(-2, -1)).div(3.0).softmax(dim=-1),
                p=0.4,
                training=training,
                inplace=False,
            ).matmul(value)

        self._check_common(dot_prod_attention, contains=False, has_dropout=True)
        self._check_common(
            checkpoint_wrapper(dot_prod_attention), contains=False, has_dropout=True
        )

    @skipIfRocm  # AssertionError: expected size 4==4, stride 32==64 at dim=0
    def _test_sdpa_rewriter_4(self):
        def dot_prod_attention(
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            training: bool,
        ) -> torch.Tensor:
            return torch.nn.functional.dropout(
                torch.matmul(query, key.transpose(-2, -1)).mul(0.4).softmax(dim=-1),
                p=0.2,
                inplace=False,
                training=training,
            ).matmul(value)

        self._check_common(dot_prod_attention, contains=False, has_dropout=True)
        self._check_common(
            checkpoint_wrapper(dot_prod_attention), contains=False, has_dropout=True
        )

    def _test_sdpa_rewriter_5(self):
        def sfdp_pattern_5_v1(query, key, value):
            attn_mask = torch.ones(
                query.size(-2), key.size(-2), dtype=torch.bool, device=query.device
            ).tril(diagonal=0)
            attn_mask = attn_mask.masked_fill(
                torch.logical_not(attn_mask), -float("inf")
            )
            attn_weight = torch.softmax(
                (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask,
                dim=-1,
            )
            return attn_weight @ value

        def sfdp_pattern_5_v2(query, key, value):
            # https://github.com/pytorch/pytorch/issues/100318.
            attn_mask = torch.zeros(
                query.size(-2), key.size(-2), dtype=torch.bool, device=query.device
            ).bool()
            attn_weight = torch.softmax(
                (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask,
                dim=-1,
            )
            return attn_weight @ value

        self._check_common(sfdp_pattern_5_v1, contains=False)
        self._check_common(checkpoint_wrapper(sfdp_pattern_5_v1), contains=False)
        self._check_common(sfdp_pattern_5_v2, contains=False)
        self._check_common(checkpoint_wrapper(sfdp_pattern_5_v2), contains=False)

    @skipIfRocm
    def _test_sdpa_rewriter_6(self):
        def sfdp_pattern_6(query, key, value, training):
            attn_mask = torch.ones(
                query.size(-2), key.size(-2), dtype=torch.bool, device=query.device
            ).tril(diagonal=0)
            attn_mask = attn_mask.masked_fill(
                torch.logical_not(attn_mask), -float("inf")
            )
            attn_weight = torch.softmax(
                (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask,
                dim=-1,
            )
            attn_weight = torch.nn.functional.dropout(attn_weight, 0.5, training)
            return attn_weight @ value

        self._check_common(sfdp_pattern_6, contains=False, has_dropout=True)
        self._check_common(
            checkpoint_wrapper(sfdp_pattern_6), contains=False, has_dropout=True
        )

    @skipIfRocm
    def _test_sdpa_rewriter_7(self):
        def sfdp_pattern_7(query, key, value, training):
            q = query.permute(0, 2, 1, 3)
            k = key.permute(0, 2, 1, 3)
            v = value.permute(0, 2, 1, 3)
            div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
            div = div.to(torch.float32)
            attn_weight = torch.softmax(div, dim=-1)
            # dropout probability is effectively zero so eager and compiled
            # results remain comparable
            attn_weight = torch.dropout(attn_weight, 0.00000000001, training)
            attn_weight = attn_weight.to(torch.float16)
            return attn_weight @ v

        args = (
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
        )
        self._check_common(
            sfdp_pattern_7,
            args,
            contains=SM80OrLater,
            has_dropout=True,
            override_check_equal=True,
            atol=2e-3,
        )

        args = (
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
        )
        self._check_common(
            checkpoint_wrapper(sfdp_pattern_7),
            args,
            contains=SM80OrLater,
            has_dropout=True,
            override_check_equal=True,
            atol=2e-3,
        )

    @skipIfRocm
    def _test_sdpa_rewriter_8(self):
        def sfdp_pattern_8(query, key, value):
            q = query.permute(0, 2, 1, 3)
            k = key.permute(0, 2, 1, 3)
            v = value.permute(0, 2, 1, 3)
            div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
            div = div.to(torch.float32)
            attn_weight = torch.softmax(div, dim=-1)
            attn_weight = attn_weight.to(torch.float16)
            return attn_weight @ v

        args = (
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
        )
        self._check_common(sfdp_pattern_8, args, atol=2e-3)

        args = (
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
        )
        self._check_common(checkpoint_wrapper(sfdp_pattern_8), args, atol=2e-3)

    @skipIfRocm
    def _test_sdpa_rewriter_9(self):
        def sfdp_pattern_9(query, key, value, training):
            q = query.permute(0, 2, 1, 3)
            k = key.permute(0, 2, 1, 3)
            v = value.permute(0, 2, 1, 3)
            q = q / math.sqrt(q.size(-1))
            div = q @ k.transpose(-2, -1)
            div = div.to(torch.float32)
            attn_weight = torch.softmax(div, dim=-1)
            # very low dropout to make test pass
            attn_weight = torch.dropout(attn_weight, 0.00000000001, training)
            attn_weight = attn_weight.to(torch.float16)
            return attn_weight @ v

        args = (
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
        )
        self._check_common(
            sfdp_pattern_9,
            args,
            contains=SM80OrLater,
            has_dropout=True,
            override_check_equal=True,
            atol=2e-3,
        )

        args = (
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
        )
        self._check_common(
            checkpoint_wrapper(sfdp_pattern_9),
            args,
            contains=SM80OrLater,
            has_dropout=True,
            override_check_equal=True,
            atol=2e-3,
        )

    @skipIfRocm
    def _test_sdpa_rewriter_10(self):
        def sfdp_pattern_10(query, key, value):
            q = query.permute(0, 2, 1, 3)
            k = key.permute(0, 2, 1, 3)
            v = value.permute(0, 2, 1, 3)
            q = q / math.sqrt(q.size(-1))
            div = q @ k.transpose(-2, -1)
            div = div.to(torch.float32)
            attn_weight = torch.softmax(div, dim=-1)
            attn_weight = attn_weight.to(torch.float16)
            return attn_weight @ v

        args = (
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
            torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half),
        )
        self._check_common(sfdp_pattern_10, args, atol=2e-3)

        args = (
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
            torch.randn((2, 8, 4, 16), device="cuda", dtype=torch.half),
        )
        self._check_common(checkpoint_wrapper(sfdp_pattern_10), args, atol=2e-3)

    def _test_pattern_fails_with_tensor_factor(self):
        # https://github.com/pytorch/pytorch/issues/99124
        class Model(torch.nn.Module):
            def __init__(self, is_inv_factor):
                super().__init__()
                self.is_inv_factor = is_inv_factor

            def forward(self, query, key, value, scale_factor) -> torch.Tensor:
                # Dividing by scale_factor makes scale_factor gradients very unstable
                scale_factor = scale_factor.detach()
                y = torch.matmul(query, key.transpose(-2, -1))
                if self.is_inv_factor:
                    y = y.div(scale_factor)
                else:
                    y = y.mul(scale_factor)
                return y.softmax(dim=-1).matmul(value)

        tensor_shape = (2, 4, 4, 4)
        for is_inv_factor in [True, False]:
            args = [
                torch.randn(tensor_shape, device=self.device),
                torch.randn(tensor_shape, device=self.device),
                torch.randn(tensor_shape, device=self.device),
                torch.randn((4, 1, 1), device=self.device),
            ]
            model = Model(is_inv_factor).eval()
            # The training path has an accuracy gap compared with eager mode.
            self._check_common(
                model, args1=args, contains=False, atol=1e-3, has_fuse_pattern=False
            )

    def _test_pattern_fails_with_unsupported_mask(self):
        if not self.use_static_shapes:
            self.skipTest("Causes shape specialization. TODO: investigate")

        # https://github.com/pytorch/pytorch/issues/100315
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, query, key, value, attn_mask) -> torch.Tensor:
                attn_weight = torch.softmax(
                    query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
                    + attn_mask,
                    dim=-1,
                )
                return attn_weight @ value

        tensor_shape = (2, 4, 4, 4)

        unsupported_masks = [
            torch.randn((2, 4, 4, 4), device=self.device).to(dtype=torch.int),
            2.0,
        ]
        for attn_mask in unsupported_masks:
            args = [
                torch.randn(tensor_shape, device=self.device),
                torch.randn(tensor_shape, device=self.device),
                torch.randn(tensor_shape, device=self.device),
                attn_mask,
            ]
            model = Model().eval()
            # The training path has an accuracy gap compared with eager mode.
            self._check_common(
                model, args1=args, contains=False, atol=1e-4, has_fuse_pattern=False
            )

    @skipIfRocm
    def _test_sdpa_rewriter_11(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)"""
            q = query.transpose(1, 2)
            k = key.transpose(1, 2)
            v = value.transpose(1, 2)
            return (
                torch.matmul(q, k.transpose(-2, -1))
                .div(math.sqrt(key.shape[-1]))
                .softmax(dim=-1)
                .matmul(v)
            )

        self._check_common(dot_prod_attention)

    def _test_sdpa_rewriter_12(self):
        def dot_prod_attention(
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            training: bool,
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)"""
            q = query.transpose(1, 2)
            k = key.transpose(1, 2)
            v = value.transpose(1, 2)
            return torch.nn.functional.dropout(
                torch.matmul(q, k.transpose(-2, -1))
                .div(math.sqrt(key.shape[-1]))
                .softmax(dim=-1)
                .matmul(v),
                p=0.4,
                training=training,
                inplace=False,
            )

        self._check_common(dot_prod_attention, contains=False, has_dropout=True)

    @skipIfRocm
    def _test_sdpa_prev_13(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, n_head, seq_len, embed_dim)"""
            return (
                torch.matmul(query, key.transpose(-2, -1))
                .div(math.sqrt(key.shape[-1]))
                .softmax(dim=-1)
                .clone()
                .matmul(value)
            )

        self._check_common(dot_prod_attention, check_train=False)
        self._check_common(checkpoint_wrapper(dot_prod_attention), check_train=False)

    @skipIfRocm
    def _test_sdpa_prev_14(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            return (
                torch.matmul(query, key.transpose(-2, -1))
                .mul(1.0 / math.sqrt(key.shape[-1]))
                .softmax(dim=-1)
                .clone()
                .matmul(value)
            )

        self._check_common(dot_prod_attention, check_train=False)
        self._check_common(checkpoint_wrapper(dot_prod_attention), check_train=False)

    @skipIfRocm
    def _test_sdpa_prev_15(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)"""
            q = query.transpose(1, 2)
            k = key.transpose(1, 2)
            v = value.transpose(1, 2)
            return (
                torch.matmul(q, k.transpose(-2, -1))
                .div(math.sqrt(key.shape[-1]))
                .softmax(dim=-1)
                .clone()
                .matmul(v)
            )

        self._check_common(dot_prod_attention, check_train=False)

    @skipIfRocm
    def _test_sdpa_rewriter_13(self, dtype):
        def dot_prod_attention(
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            training: bool,
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)"""
            attn_weight = torch.bmm(query, key.transpose(1, 2)).softmax(dim=-1)
            attn_weight = torch.nn.functional.dropout(
                attn_weight, p=0.5, training=training
            )
            return torch.bmm(attn_weight, value)

        tensor_shape = (4, 8, 16)
        args = [
            torch.randn(tensor_shape, device=self.device, dtype=dtype),
            torch.randn(tensor_shape, device=self.device, dtype=dtype),
            torch.randn(tensor_shape, device=self.device, dtype=dtype),
        ]
        self._check_common(
            dot_prod_attention,
            check_train=False,
            args1=args,
            has_dropout=True,
            override_check_equal=True,
            atol=1e-2,
            rtol=1e-2,
        )

    @skipIfRocm
    def _test_sdpa_rewriter_14(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)"""
            attn_mask = torch.ones(
                query.size(1), key.size(1), dtype=torch.bool, device=query.device
            ).tril(diagonal=0)
            attn_mask = attn_mask.masked_fill(
                torch.logical_not(attn_mask), -float("inf")
            )
            q = query.permute(0, 2, 1, 3)
            k = key.permute(0, 2, 1, 3)
            v = value.permute(0, 2, 1, 3)
            return (
                (torch.matmul(q, k.transpose(-2, -1)).div(3.0) + attn_mask)
                .softmax(dim=-1)
                .matmul(v)
            )

        self._check_common(dot_prod_attention)

    @skipIfRocm
    def _test_sdpa_rewriter_15(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)"""
            q = query.transpose(1, 2)
            k = key.transpose(1, 2)
            v = value.transpose(1, 2)
            bs = q.size(0)
            k_len = k.size(-2)
            attn_mask = torch.ones(
                bs, k_len, dtype=torch.bool, device=query.device
            ).tril(diagonal=0)
            scores = torch.matmul(q, k.transpose(-2, -1)) / 3.0
            attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores)
            scores = scores.masked_fill(attn_mask, -float("inf"))
            weights = torch.nn.functional.softmax(scores, dim=-1)
            return torch.matmul(weights, v)

        self._check_common(dot_prod_attention, check_train=False)

    @skipIfRocm
    def _test_sdpa_rewriter_16(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, training
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)"""
            attn_mask = torch.ones(
                query.size(1), key.size(1), dtype=torch.bool, device=query.device
            ).tril(diagonal=0)
            attn_mask = attn_mask.masked_fill(
                torch.logical_not(attn_mask), -float("inf")
            )
            q = query.permute(0, 2, 1, 3)
            k = key.permute(0, 2, 1, 3)
            v = value.permute(0, 2, 1, 3)
            return torch.nn.functional.dropout(
                (torch.matmul(q, k.transpose(-2, -1)).div(3.0) + attn_mask).softmax(
                    dim=-1
                ),
                p=0.4,
                training=training,
                inplace=False,
            ).matmul(v)

        self._check_common(dot_prod_attention, contains=False, has_dropout=True)

        # also check batch_size=1 because the graph is slightly different
        tensor_shape = (1, 2, 16, 32)
        args = [
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
        ]
        self._check_common(
            dot_prod_attention, args1=args, contains=False, has_dropout=True
        )

    @skipIfRocm
    def _test_sdpa_rewriter_16_fp32_mask(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, training
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)"""
            attn_mask = torch.randn(
                query.size(1), key.size(1), dtype=torch.float, device=query.device
            ).tril(diagonal=0)
            q = query.permute(0, 2, 1, 3)
            k = key.permute(0, 2, 1, 3)
            v = value.permute(0, 2, 1, 3)
            return torch.nn.functional.dropout(
                (torch.matmul(q, k.transpose(-2, -1)).div(3.0) + attn_mask).softmax(
                    dim=-1
                ),
                p=0.4,
                training=training,
                inplace=False,
            ).matmul(v)

        self._check_common(dot_prod_attention, contains=False, has_dropout=True)

        # also check batch_size=1 because the graph is slightly different
        tensor_shape = (1, 2, 16, 32)
        args = [
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
        ]
        self._check_common(
            dot_prod_attention, args1=args, contains=False, has_dropout=True
        )

    @skipIfRocm
    def _test_sdpa_rewriter_17(self):
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, training
        ) -> torch.Tensor:
            """Input tensors assumed to have shape (batch_size, seq_len, n_head, embed_dim)"""
            q = query.transpose(1, 2)
            k = key.transpose(1, 2)
            v = value.transpose(1, 2)
            bs = q.size(0)
            k_len = k.size(-2)
            attn_mask = torch.ones(
                bs, k_len, dtype=torch.bool, device=query.device
            ).tril(diagonal=0)
            scores = torch.matmul(q, k.transpose(-2, -1)) / 3.0
            attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores)
            scores = scores.masked_fill(attn_mask, -float("inf"))
            weights = torch.nn.functional.softmax(scores, dim=-1)
            weights = torch.nn.functional.dropout(
                weights,
                p=0.4,
                training=training,
                inplace=False,
            )
            return torch.matmul(weights, v)

        self._check_common(dot_prod_attention, check_train=False, has_dropout=True)

    @skipIfRocm
    def _test_sdpa_rewriter_18(self):
        def dot_prod_attention(
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            causal_mask: torch.Tensor,
        ) -> torch.Tensor:
            # for hf_GPT2 with dropout
            query = query.permute([0, 2, 1, 3])
            key = key.permute([0, 2, 1, 3])
            value = value.permute([0, 2, 1, 3])
            attn_weights = torch.matmul(query, key.permute(0, 1, 3, 2))
            inv_scale = torch.full(
                (), math.sqrt(value.size(-1)), dtype=query.dtype, device=query.device
            )
            attn_weights = attn_weights.div(inv_scale)
            causal_mask_value = torch.full(
                (), torch.finfo(query.dtype).min, dtype=query.dtype, device=query.device
            )
            attn_weights = torch.where(causal_mask, attn_weights, causal_mask_value)
            return (
                (
                    torch.nn.functional.dropout(
                        attn_weights.softmax(dim=-1), 0.0
                    ).matmul(value)
                ),
                key.permute([0, 2, 1, 3]),
                value.permute([0, 2, 1, 3]),
            )

        tensor_shape = (4, 2, 16, 32)
        causal_mask = torch.ones(2, 2, dtype=torch.bool, device=self.device).tril(
            diagonal=0
        )
        args = [
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
            causal_mask,
        ]
        self._check_common(
            dot_prod_attention,
            args1=args,
            contains=False,
            has_dropout=False,
            check_train=False,
        )

        # also check batch_size=1 because the graph is slightly different
        tensor_shape = (1, 2, 16, 32)
        args = [
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
            causal_mask,
        ]
        self._check_common(
            dot_prod_attention,
            args1=args,
            contains=False,
            has_dropout=False,
            check_train=False,
        )

    @skipIfRocm
    def _test_sdpa_rewriter_19(self):
        def dot_prod_attention(
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            causal_mask: torch.Tensor,
            attn_mask: torch.Tensor,
            training,
        ) -> torch.Tensor:
            attn_weights = torch.matmul(query, key.permute(0, 1, 3, 2))
            inv_scale = torch.full(
                (),
                math.sqrt(value.size(-1)),
                dtype=attn_weights.dtype,
                device=attn_weights.device,
            )
            attn_weights = attn_weights.div(inv_scale)
            causal_mask_value = torch.full(
                (), torch.finfo(query.dtype).min, dtype=query.dtype, device=query.device
            )
            attn_weights = torch.where(causal_mask, attn_weights, causal_mask_value)
            attn_weights = attn_weights + attn_mask
            attn_weights = attn_weights.softmax(dim=-1).type(value.dtype)
            return torch.nn.functional.dropout(
                attn_weights,
                p=0.4,
                training=training,
                inplace=False,
            ).matmul(value)

        tensor_shape = (4, 2, 16, 32)
        causal_mask = torch.ones(16, 16, dtype=torch.bool, device=self.device).tril(
            diagonal=0
        )
        attn_mask = torch.randn((16, 16), dtype=torch.float, device=self.device)
        args = [
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
            torch.randn(tensor_shape, device=self.device),
            causal_mask,
            attn_mask,
        ]
        self._check_common(
            dot_prod_attention,
            args1=args,
            contains=False,
            has_dropout=True,
            check_train=False,
        )
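
# The `_test_*` helpers above are not collected by unittest directly; the
# device-specific classes below bind them to concrete `test_*` attributes
# (via functools.partialmethod when a dtype needs to be pinned), and the
# *DynamicTests subclasses rerun the same tests with use_static_shapes = False
# so that the torch._dynamo.mark_dynamic path in _check_common is exercised.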
if HAS_CUDA and PLATFORM_SUPPORTS_FUSED_ATTENTION:

    class SDPAPatternRewriterCudaTests(TestSDPAPatternRewriterTemplate):
        device = "cuda"
        test_sdpa_rewriter_1_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_1
        )
        test_sdpa_rewriter_1_freezing = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_1_freezing
        )
        test_insignificant_strides = (
            TestSDPAPatternRewriterTemplate._test_insignificant_strides
        )
        test_pattern_fails_with_reuse_cuda = (
            TestSDPAPatternRewriterTemplate._test_pattern_fails_with_reuse
        )
        test_sdpa_rewriter_2_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_2
        )
        test_sdpa_rewriter_3_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_3
        )
        test_sdpa_rewriter_4_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_4
        )
        test_sdpa_rewriter_5_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_5
        )
        test_sdpa_rewriter_6_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_6
        )
        test_sdpa_rewriter_7_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_7
        )
        test_sdpa_rewriter_8_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_8
        )
        test_sdpa_rewriter_9_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_9
        )
        test_sdpa_rewriter_10_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_10
        )
        test_pattern_fails_with_tensor_factor_cuda = (
            TestSDPAPatternRewriterTemplate._test_pattern_fails_with_tensor_factor
        )
        test_pattern_fails_with_unsupported_mask_cuda = (
            TestSDPAPatternRewriterTemplate._test_pattern_fails_with_unsupported_mask
        )
        test_sdpa_rewriter_11_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_11
        )
        test_sdpa_rewriter_12_cuda = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_12
        )
        test_sdpa_prev_13_cuda = TestSDPAPatternRewriterTemplate._test_sdpa_prev_13
        test_sdpa_prev_14_cuda = TestSDPAPatternRewriterTemplate._test_sdpa_prev_14
        test_sdpa_prev_15_cuda = TestSDPAPatternRewriterTemplate._test_sdpa_prev_15
        test_sdpa_rewriter_13_cuda = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_13, dtype=torch.half
        )
        test_sdpa_rewriter_14_cuda = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_14
        )
        test_sdpa_rewriter_15_cuda = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_15
        )
        test_sdpa_rewriter_17_cuda = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_17
        )
        test_sdpa_rewriter_19_cuda = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_19
        )

    class SDPAPatternRewriterCudaDynamicTests(SDPAPatternRewriterCudaTests):
        use_static_shapes = False


if HAS_CPU:

    class SDPAPatternRewriterCpuTests(TestSDPAPatternRewriterTemplate):
        device = "cpu"
        test_sdpa_rewriter_1_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_1
        test_pattern_fails_with_reuse_cpu = (
            TestSDPAPatternRewriterTemplate._test_pattern_fails_with_reuse
        )
        test_sdpa_rewriter_2_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_2
        test_sdpa_rewriter_5_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_5
        test_pattern_fails_with_tensor_factor_cpu = (
            TestSDPAPatternRewriterTemplate._test_pattern_fails_with_tensor_factor
        )
        test_pattern_fails_with_unsupported_mask_cpu = (
            TestSDPAPatternRewriterTemplate._test_pattern_fails_with_unsupported_mask
        )
        test_sdpa_rewriter_11_cpu = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_11
        )
        test_sdpa_rewriter_12_cpu = (
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_12
        )
        test_sdpa_prev_13_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_prev_13
        test_sdpa_prev_14_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_prev_14
        test_sdpa_prev_15_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_prev_15
        test_sdpa_rewriter_13_cpu = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_13, dtype=torch.float32
        )
        test_sdpa_rewriter_14_cpu = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_14
        )
        test_sdpa_rewriter_15_cpu = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_15
        )
        test_sdpa_rewriter_16_cpu = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_16
        )
        test_sdpa_rewriter_16_fp32_mask_cpu = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_16_fp32_mask
        )
        test_sdpa_rewriter_17_cpu = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_17
        )
        test_sdpa_rewriter_18_cpu = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_18
        )
        test_sdpa_rewriter_19_cpu = functools.partialmethod(
            TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_19
        )

    class SDPAPatternRewriterCpuDynamicTests(SDPAPatternRewriterCpuTests):
        use_static_shapes = False


if __name__ == "__main__":
    if IS_LINUX:
        run_tests()