Investigate a new SDPA IR node #2278

Open
Priya2698 opened this issue May 21, 2024 · 5 comments

@Priya2698 (Collaborator)

Add a new IR node for SDPA (scaled dot product attention), which is currently not supported within nvFuser.
CC: @cowanmeg @kevinstephano

@Priya2698 (Collaborator, Author) commented May 28, 2024

Summarizing discussion from today's meeting and an offline discussion with @jjsjann123:

For training, we need two nodes, `SdpaOpFwd` and `SdpaOpBwd`. PR #2294 currently uses at::scaled_dot_product_attention, which does not return any intermediate values to be stored for backward, so it is effectively an inference-only node. We can merge this with `SdpaOpFwd` and potentially expose a different API if we don't want to return all the outputs.
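For reference, a rough sketch (in Python, via the ATen bindings) of the difference between the two entry points; the exact output arity and names of `_scaled_dot_product_flash_attention` vary across PyTorch versions, so treat this as illustrative rather than authoritative:

```python
import torch
import torch.nn.functional as F

# Flash attention requires half-precision inputs on a supported GPU.
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Inference-style entry point: returns only the attention output,
# nothing is stashed for backward.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Flash-attention forward (ATen internal op): additionally returns
# intermediates such as logsumexp and the RNG state, which the
# backward kernel consumes instead of recomputing them.
(out, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k,
 philox_seed, philox_offset, debug_mask) = torch.ops.aten._scaled_dot_product_flash_attention(
    q, k, v, dropout_p=0.0, is_causal=True)
```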

There are different variants of SDPA in use (flash attention, memory-efficient) with slightly different function signatures. We will initially start with one variant (likely flash attention, after verifying that it is indeed the one used in models like nanoGPT).

CC: @IvanYashchuk

@IvanYashchuk (Collaborator)

Which signature is the right one to target here? PyTorch itself has many different variants, each with a different signature:
https://github.com/pytorch/pytorch/blob/a6b994ed5467d4df8320cbae51cba6a98ffb139c/aten/src/ATen/native/transformers/attention.cpp#L665-L706
https://github.com/pytorch/pytorch/blob/a6b994ed5467d4df8320cbae51cba6a98ffb139c/tools/autograd/derivatives.yaml#L2806-L2829

There's no right or wrong signature; it depends on how you want to do the backward computation, and that dictates the output signature of the forward function. You need to decide for yourself what needs to be stashed for backward and what can be recomputed. If you want fallbacks to ATen, there's no choice other than directly mimicking ATen's function signatures. It's important to remember that flash attention doesn't work for all input cases, and the "memory efficient" variant doesn't cover all input cases either.
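To make the "what to stash" question concrete, here is a purely illustrative sketch (not nvFuser's actual node interface) of the kind of saved-for-backward contract a flash-attention-style forward node implies:

```python
from dataclasses import dataclass
import torch

# Purely illustrative: one possible "stash" contract between a forward
# and a backward SDPA node, mirroring the kind of values ATen's
# flash-attention backward consumes (attention output, logsumexp, and
# the RNG state for dropout). This is NOT nvFuser's actual
# SdpaFwdOp/SdpaBwdOp interface.
@dataclass
class SdpaFwdSaved:
    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor
    out: torch.Tensor            # attention output
    logsumexp: torch.Tensor      # softmax statistics, avoids recompute in backward
    philox_seed: torch.Tensor    # RNG state so the dropout mask can be regenerated
    philox_offset: torch.Tensor
    dropout_p: float
    is_causal: bool
```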

Is the flash attention kernel representable in nvFuser primitives?

@jjsjann123 (Collaborator)

> There's no right or wrong signature; it depends on how you want to do the backward computation, and that dictates the output signature of the forward function. You need to decide for yourself what needs to be stashed for backward and what can be recomputed.

Yes. The question here is mostly for @cowanmeg: which implementation we are targeting in codegen determines what signature we would want to have.

@Priya2698 (Collaborator, Author)

[Update from the Jun 4 meeting]
At the moment we only plan to support flash attention, in order to unblock multi-GPU development. Once flash attention is supported, we can revisit whether memory-efficient attention needs to be added as well. There are a few possible approaches:

  1. Plumb the backend info down from Thunder and use it within our nodes: while the two implementations have different function signatures, there is overlap, so one possibility is to use a superset of the inputs and outputs. The alternative design would be distinct nodes for each implementation.
  2. Make the decision about the backend within nvFuser using the same logic as ATen/Thunder (a rough sketch of such selection logic is below). See: https://github.com/Lightning-AI/lightning-thunder/blob/9f0c50cc6df187cf5fd2e31240690fe2b5e9ccc1/thunder/executors/sdpaex.py#L618-L680
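For illustration only, a simplified sketch of the kind of backend selection option 2 implies; the actual conditions live in the linked Thunder sdpaex.py and in ATen's SDP dispatch, and the checks and thresholds below are placeholders, not the real rules:

```python
import torch

# Illustrative only: simplified backend selection in the spirit of
# ATen/Thunder. Real implementations also consider GPU architecture,
# dropout, masks, sequence lengths, etc.
def pick_sdpa_backend(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                      attn_mask=None) -> str:
    head_dim = q.shape[-1]
    if (q.dtype in (torch.float16, torch.bfloat16)
            and attn_mask is None
            and head_dim <= 128):          # placeholder limit
        return "flash_attention"
    if q.dtype in (torch.float16, torch.bfloat16, torch.float32):
        return "memory_efficient"
    return "math"  # reference fallback
```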

Priya2698 added a commit that referenced this issue Jun 10, 2024
Based on the PR discussions, this PR is repurposed to introduce a new IR
node `SdpaFwdOp` for scaled dot product flash attention forward (see
#2278 for details).
This PR does not include changes to the scheduler.