
FlashAttention actually does not support attention mask #116

Open
HJoonKwon opened this issue Feb 17, 2024 · 3 comments

HJoonKwon commented Feb 17, 2024

Thanks for your great work!

I'm just curious whether your code here actually uses the flash kernel when the mask is not None. My guess is that it falls back to memory-efficient attention instead, since PyTorch's flash attention kernel does not support a non-null attn_mask. In addition, if the memory-efficient kernel is the one being used, the call to half() would not be needed when the mask is not None.
Thank you!

Update: I did some experiments. Even if the flash backend is enabled, it is not dispatched when the mask is not None, and if we force PyTorch to use only the flash kernel, it raises the error shown below.

class Attention(nn.Module):
    def __init__(self, attn_dropout=0.0):
        super().__init__()
        self.attn_dropout = attn_dropout

    def forward(self, q, k, v, q_mask=None, kv_mask=None):
        if kv_mask is not None:
            attn_mask = q_mask[:, None, :, None] * kv_mask[:, None, None, :]
        else:
            attn_mask = None
        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=attn_mask, dropout_p=self.attn_dropout, is_causal=False
            )
            
        return y if attn_mask is None else y.nan_to_num()

device = 'cuda'
attn = Attention().to(device)
B = 4
L = 32 * 32
S = 24 * 24
n_embd = 32
n_heads = 4
q = torch.randn(B, n_heads, L, n_embd // n_heads).to(device)
k = torch.randn(B, n_heads, S, n_embd // n_heads).to(device)
v = torch.randn(B, n_heads, S, n_embd // n_heads).to(device)
q_mask = (torch.rand(B, L) > 0.1).to(device)
kv_mask = (torch.rand(B, S) > 0.1).to(device)
x = [x.half() for x in [q, k, v]]
y = attn(*x, q_mask, kv_mask)
/tmp/ipykernel_467687/3943656874.py:12: UserWarning: Memory efficient kernel not used because: (Triggered internally at /opt/conda/conda-bld/pytorch_1702400410390/work/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:367.)
  y = torch.nn.functional.scaled_dot_product_attention(
/tmp/ipykernel_467687/3943656874.py:12: UserWarning: Memory Efficient attention has been runtime disabled. (Triggered internally at /opt/conda/conda-bld/pytorch_1702400410390/work/aten/src/ATen/native/transformers/sdp_utils_cpp.h:437.)
  y = torch.nn.functional.scaled_dot_product_attention(
/tmp/ipykernel_467687/3943656874.py:12: UserWarning: Flash attention kernel not used because: (Triggered internally at /opt/conda/conda-bld/pytorch_1702400410390/work/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:369.)
  y = torch.nn.functional.scaled_dot_product_attention(
/tmp/ipykernel_467687/3943656874.py:12: UserWarning: Both fused kernels do not support non-null attn_mask. (Triggered internally at /opt/conda/conda-bld/pytorch_1702400410390/work/aten/src/ATen/native/transformers/sdp_utils_cpp.h:261.)
  y = torch.nn.functional.scaled_dot_product_attention(
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[34], line 12
     10 kv_mask = (torch.rand(B, S) > 0.1).to(device)
     11 x = [x.half() for x in [q, k, v]]
---> 12 y = attn(*x, q_mask, kv_mask)

File ~/miniconda3/envs/torch212/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/torch212/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

Cell In[32], line 12, in TorchNativeAttention.forward(self, q, k, v, q_mask, kv_mask)
     10     attn_mask = None
     11 with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
---> 12     y = torch.nn.functional.scaled_dot_product_attention(
     13         q, k, v, attn_mask=attn_mask, dropout_p=self.attn_dropout, is_causal=False
     14     )
     16 return y if attn_mask is None else y.nan_to_num()

RuntimeError: No available kernel.  Aborting execution.

The memory-efficient kernel, on the other hand, handles the non-null mask without error:

with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    y = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_mask, dropout_p=self.attn_dropout, is_causal=False
    )
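
For reference, a minimal sketch (not part of the original experiment, and assuming the memory-efficient backend accepts float32 inputs on this GPU) that also checks the earlier point about half(): with only the memory-efficient kernel enabled, the masked case should run without any dtype conversion.

# Sketch: reuse q, k, v, q_mask, kv_mask from the experiment above,
# but keep the tensors in float32 and still pass a non-null mask.
attn_mask = q_mask[:, None, :, None] * kv_mask[:, None, None, :]
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    y_fp32 = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
    )
print(y_fp32.dtype)  # expected: torch.float32, i.e. no half() call needed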
@Phil26AT
Collaborator

Hey @HJoonKwon! Damn, very good find, thank you! I guess this does matter in the compiled forward pass, where we are padding inputs to static dimensions. We'd need to run the benchmarks, but avoiding the call to half() might improve throughput in that case.
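
One rough way to test that (a sketch only; the benchmark harness is illustrative and not from this thread, while the shapes are taken from the experiment above) is to time the masked memory-efficient path with and without the half() conversion:

import torch
import torch.nn.functional as F
from torch.utils import benchmark

# Shapes match the experiment above; everything else here is illustrative.
B, n_heads, L, S, head_dim = 4, 4, 32 * 32, 24 * 24, 32 // 4
device = 'cuda'
q = torch.randn(B, n_heads, L, head_dim, device=device)
k = torch.randn(B, n_heads, S, head_dim, device=device)
v = torch.randn(B, n_heads, S, head_dim, device=device)
mask = torch.rand(B, 1, L, S, device=device) > 0.1

def masked_sdpa(q, k, v, mask):
    # Restrict dispatch to the memory-efficient kernel, as in the snippets above.
    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

for label, cast in [("fp32 (no half())", lambda t: t), ("fp16 (half())", lambda t: t.half())]:
    timer = benchmark.Timer(
        stmt="masked_sdpa(q, k, v, mask)",
        globals={"masked_sdpa": masked_sdpa, "q": cast(q), "k": cast(k), "v": cast(v), "mask": mask},
    )
    print(label, timer.blocked_autorange(min_run_time=1.0))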

@HJoonKwon
Author

HJoonKwon commented Feb 24, 2024

@Phil26AT Great! Thank you again for your great work; it has been very inspiring.

@LudvigDillen

On the topic of FlashAttention: you link to the FlashAttention paper rather than to FlashAttention-2 here:
[screenshot]
Isn't the second version used? If not, why? It seems to be quite a bit faster:
[screenshot]

FlashAttention: https://arxiv.org/abs/2205.14135
FlashAttention-2: https://arxiv.org/pdf/2307.08691.pdf
