
[NT] Implementing Multi-Head Attention with NestedTensors #125214

Open
clessig opened this issue Apr 30, 2024 · 11 comments


clessig commented Apr 30, 2024

🚀 The feature, motivation and pitch

Nested tensors are supported by PyTorch's flash attention implementation (cf. https://gist.github.com/victoroliv2/3668f07e11a0757febb6e55a8d78592a), and this gives a marked (approx. 25%) speedup compared to alternative options. But extending this example to a full multi-head attention implementation does not work at the moment, since flash attention expects 3D tensors in the nested_tensor while nn.Linear requires 2D tensors.

RuntimeError: Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 4. Dense tensor dim: 2

This restriction on nn.Linear also seems odd to me. One could in principle construct the nested_tensor only after the projection, but since this involves a copy operation it is rather inefficient and would likely negate any benefit of flash attention with nested tensors.
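
For illustration, the copy-based alternative would look roughly like this (a sketch only; seqs and proj_q are placeholder names for a list of per-sequence 2D tensors and a projection layer):

# Project each variable-length sequence as a plain 2D tensor (which nn.Linear
# accepts), then pack the results into a nested tensor afterwards; the packing
# step copies the data.
q_list = [proj_q(s) for s in seqs]          # seqs: list of (seq_len_i, dim) tensors
q_nt = torch.nested.nested_tensor(q_list)   # extra copy, likely negating the speedup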

Alternatives

No response

Additional context

Here's a minimal example:

import torch

class AttentionHead(torch.nn.Module):

    def __init__(self, proj_dims):
        '''Attention head'''
        super().__init__()

        # fused alternative:
        # self.proj = torch.nn.Linear(proj_dims[0], 3 * proj_dims[1], bias=False)
        self.proj_q = torch.nn.Linear(proj_dims[0], proj_dims[1], bias=False)
        self.proj_k = torch.nn.Linear(proj_dims[0], proj_dims[1], bias=False)
        self.proj_v = torch.nn.Linear(proj_dims[0], proj_dims[1], bias=False)

        self.softmax = torch.nn.Softmax(dim=-1)

        self.lnorm_q = torch.nn.LayerNorm(proj_dims[1], elementwise_affine=False)
        self.lnorm_k = torch.nn.LayerNorm(proj_dims[1], elementwise_affine=False)

        self.att = torch.nn.functional.scaled_dot_product_attention

    def forward(self, xs):
        # xs: nested tensor of variable-length sequences; the projections below
        # trigger the RuntimeError quoted above
        # q, k, v = torch.tensor_split(self.proj(xs), 3, dim=-1)
        q, k, v = self.proj_q(xs), self.proj_k(xs), self.proj_v(xs)
        q, k = self.lnorm_q(q), self.lnorm_k(k)

        # restrict SDPA to the flash attention kernel
        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False,
                                            enable_mem_efficient=False):
            q_out = self.att(q, k, v)

        return q_out

cc @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer @erichan1 @mikaylagawarecki

cpuhrsch added the oncall: transformer/mha label Apr 30, 2024
@cpuhrsch

@clessig Did you try this with the torch.jagged layout? As in, torch.nested.nested_tensor([...], layout=torch.jagged)?


clessig commented May 3, 2024

@cpuhrsch: thanks, torch.jagged helped.

Next error:

a = torch.nested.as_nested_tensor(tokens_cells[0][:10], layout=torch.jagged)
lin = torch.nn.Linear(1024, 1024, bias=False).to('cuda')
q = lin(a)
[t.shape for t in q]
# [torch.Size([19, 1024]), torch.Size([18, 1024]), torch.Size([22, 1024]), torch.Size([19, 1024]), torch.Size([13, 1024]), torch.Size([18, 1024]), torch.Size([21, 1024]), torch.Size([17, 1024]), torch.Size([22, 1024]), torch.Size([19, 1024])]
p = torch.reshape(q, [-1, 8, 128])

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/nested/_internal/nested_tensor.py", line 232, in __torch_function__
    return func(*args, **kwargs)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/nested/_internal/nested_tensor.py", line 216, in __torch_dispatch__
    return fn(*args, **kwargs)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/nested/_internal/ops.py", line 182, in inner
    return func(aten_op, *args, **kwargs)
  File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/nested/_internal/ops.py", line 868, in view_default
    raise RuntimeError(f"view(): cannot view shape {inp._size} as {size}")
RuntimeError: view(): cannot view shape (10, j2, 1024) as [-1, 8, 128]

The documentation for reshape isn't very detailed, but this is what I would have expected to work.

jbschlosser added the module: nestedtensor and triaged labels May 6, 2024
@jbschlosser

@clessig it looks like you want to reshape a 3D nested tensor -> a 4D shape:

# 3D -> 4D
(batch_size, ragged_seq_len, dim) -> (batch_size, ragged_seq_len, num_heads, head_dim)

Something like this will work:

import torch

shapes = [(19, 1024), (18, 1024), (22, 1024), (19, 1024), (13, 1024),
          (18, 1024), (21, 1024), (17, 1024), (22, 1024), (19, 1024)]

a = torch.nested.as_nested_tensor(
    [torch.randn(*shape, device="cuda") for shape in shapes],
    layout=torch.jagged
)

print(a.shape, a.dim())  # torch.Size([10, j1, 1024]) 3

# do projection
lin = torch.nn.Linear(1024, 1024, bias=False, device="cuda")
q = lin(a)

print(q.shape, q.dim())  # torch.Size([10, j1, 1024]) 3

# split heads
p = q.unflatten(-1, [8, 128])
# alternative reshape() calls:
# p = q.reshape(-1, -1, 8, 128)
# p = q.reshape(10, -1, 8, 128)

print(p.shape, p.dim())  # torch.Size([10, j1, 8, 128]) 4

I suggest using unflatten() since it conceptually matches what you want, but reshape() works fine as well, as long as you indicate that you want a 4D output shape.
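
As a follow-up sketch (not part of the original comment): to feed the per-head tensors into scaled_dot_product_attention, the head dimension is typically moved in front of the ragged dimension. Continuing from p above:

# move heads in front of the ragged sequence dim: (10, j1, 8, 128) -> (10, 8, j1, 128)
q = k = v = p.transpose(1, 2)
# self-attention on the jagged nested tensor; output keeps the (10, 8, j1, 128) shape
out = torch.nn.functional.scaled_dot_product_attention(q, k, v)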


clessig commented May 15, 2024

Hi @jbschlosser,

Many thanks for the example. This works now; I also extended it to use flash attention.

I have also found your example here: https://github.com/pytorch/tutorials/blob/main/prototype_source/nestedtensor.py. Unfortunately, torch.compile is broken with the inductor backend:

torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised: RuntimeError: aten::sym_size() Expected a value of type 'Tensor' for argument 'self' but instead found type 'Proxy'. Position: 0 Value: Proxy(primals_7) Declaration: aten::sym_size.int(Tensor self, int dim) -> SymInt Cast error details: Unable to cast Proxy(primals_7) to Tensor

When I try to compile my code I get a different error:

File "/etc/ecmwf/nfs/dh2_perm_a/nacl/research/obs/lessig-dev-kas-cell/ai-obs-experimental-transformer/pyenv/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 176, in is_concrete_int if isinstance(a.node.expr, sympy.core.numbers.Integer): torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised: AttributeError: 'torch._C._SymNode' object has no attribute 'expr'

Is there already an issue for the problems with the inductor backend and nested_tensors? Any ETA for a fix or workaround?

Thanks!

@jbschlosser

Hey @clessig, sorry for the trouble - the first problem you mentioned above stemmed from some bad interaction between PT2 tracing + subclass __torch_function__() impls. It was addressed in #121981. Do you still see this when running on a recent version of PyTorch that includes that PR?

The second problem is likely related to #118446, and we are actively working on addressing this. In the meantime, a workaround is to move nested tensor construction outside of the compiled region and send them as inputs to the compiled region instead. This should work fine. Let me know if you run into other issues - thanks!
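
As a concrete (hypothetical) sketch of that workaround, with mha_module standing in for whatever module is being compiled:

# Build the nested tensor eagerly, outside of torch.compile, and pass it in
# as a plain input to the compiled module.
compiled_mha = torch.compile(mha_module)

nt = torch.nested.nested_tensor(
    [torch.randn(n, 1024, device="cuda") for n in (19, 18, 22)],
    layout=torch.jagged,
)  # constructed outside the compiled region

out = compiled_mha(nt)  # enters the compiled region as an input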


clessig commented May 15, 2024

Hi @jbschlosser, yes, the first issue is fixed with the latest nightly (I had even pulled it before posting, but then only re-ran my code and not your example :( ). Performance is, however, not what one would expect:

=== with torch.compile ===
nested and padded calculations differ by 0.0
nested tensor multi-head attention takes 0.0053791539976373315 seconds
padded tensor multi-head attention takes 0.001179479993879795 seconds
Nested speedup: 0.219

That's on an A100 with the latest CUDA. Did you see a speed-up?

Thanks!

@jbschlosser

That's on an A100 with the latest CUDA. Did you see a speed-up?

Yes, I saw a speedup somewhere in the 3-5x range locally. I'll investigate this on my A100 machine; there may have been some graph-break-related regression. Does passing fullgraph=True to torch.compile work or error out on your end?


clessig commented May 15, 2024

fullgraph=True breaks the compile (as you suspected).

I had modified your example to use float16 to try out flash attention. Switching back to float32, I see a speedup of 2x:

nested tensor multi-head attention takes 0.005227703019045293 seconds
padded tensor multi-head attention takes 0.010594977997243404 seconds

For reference, float16:
nested tensor multi-head attention takes 0.005226842942647636 seconds
padded tensor multi-head attention takes 0.0011759699555113912 seconds

Should flash attention work or will it fall back to the regular implementation?

@jbschlosser

Should flash attention work or will it fall back to the regular implementation?

Yes, flash should work here as long as all inputs match what it supports. I'd expect flash to be selected if it was compiled for you and if the inputs / MHA module are all converted to float16; this worked locally for me. You can verify that flash is being selected with:

from torch.nn.attention import sdpa_kernel, SDPBackend

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    ...

which will error out if flash isn't selectable as the backend.
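
For instance, a minimal check along these lines (shapes and head dims here are illustrative only, not taken from the thread) should run without error when flash is selectable:

import torch
from torch.nn.attention import sdpa_kernel, SDPBackend

nheads, head_dim = 8, 128
x = torch.nested.nested_tensor(
    [torch.randn(n, nheads * head_dim, device="cuda", dtype=torch.float16)
     for n in (19, 18, 22)],
    layout=torch.jagged,
)
# split heads and move them in front of the ragged dim: (B, j, H*D) -> (B, H, j, D)
q = k = v = x.unflatten(-1, [nheads, head_dim]).transpose(1, 2)

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)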

Running locally with fullgraph=True, I see a graph break related to recomputing min / max sequence length to pass to SDPA. This problem is being addressed in #122836, which should land soon.


clessig commented May 15, 2024

OK, this is what I am doing. I was wondering because, for me, the performance is identical with and without flash attention:

with flash:
nested tensor multi-head attention takes 0.005371901206672192 seconds
padded tensor multi-head attention takes 0.0011916616931557655 seconds

without flash:
nested tensor multi-head attention takes 0.005209321156144142 seconds
padded tensor multi-head attention takes 0.0011920034885406494 seconds

Great to see that there's so much work on this! (very useful for what I am doing)

@jbschlosser

I was wondering because, for me, the performance is identical with and without flash attention

If flash is available, it's the first priority to be selected. So I'd expect the same results for the same inputs with or without the use of sdpa_kernel(SDPBackend.FLASH_ATTENTION).

Great to see that there's so much work on this! (very useful for what I am doing)

It's a work in progress but it's coming along! If you can provide any more details on the types of support you'll need (any op coverage gaps you run into, etc.), we can use that to help prioritize our efforts :)
