Investigate a new SDPA IR node #2278
Summarizing the discussion from today's meeting and an offline discussion with @jjsjann123: For training, we need two nodes (forward and backward). There are different variants of SDPA in use (flash attention, memory efficient) with slightly different function signatures; we will initially start with one (likely flash attention, after verifying that it is indeed the variant used in models like nanogpt). CC: @IvanYashchuk
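For reference, both variants sit behind the same composite call on the PyTorch side and are selected by backend flags. Below is a minimal sketch of forcing each fused kernel; the `torch.backends.cuda.sdp_kernel` flags are version-dependent (deprecated in newer PyTorch releases) and are shown only to illustrate the point, not taken from this discussion.

```python
# Sketch: route the same composite SDPA call to the flash-attention or the
# memory-efficient kernel (assumes PyTorch 2.x on a CUDA device).
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
           for _ in range(3))

# Force the flash-attention backend only.
with torch.backends.cuda.sdp_kernel(enable_flash=True,
                                     enable_math=False,
                                     enable_mem_efficient=False):
    out_flash = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Force the memory-efficient backend only.
with torch.backends.cuda.sdp_kernel(enable_flash=False,
                                     enable_math=False,
                                     enable_mem_efficient=True):
    out_mem_eff = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```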
There's no right or wrong signature; it depends on how you want to do the backward computation, and that dictates the output signature of the forward function. You need to decide for yourself what needs to be stashed for backward and what can be recomputed. If you want to have fallbacks to ATen, then there's no choice other than directly mimicking ATen's function signatures. It's also important to remember that Flash Attention doesn't work for all input cases, and the "memory efficient" one doesn't work for all input cases either. Is the flash attention kernel representable in nvFuser primitives?
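To make the stash-vs-recompute trade-off concrete, here is an illustrative `torch.autograd.Function` using the unfused math formulation of SDPA. This is not the flash-attention or ATen signature; the choice of saving the full attention matrix is only an example (a real fused kernel would instead stash compact statistics such as per-row logsumexp, or recompute).

```python
# Sketch of the "what do we save for backward?" decision for SDPA.
import math
import torch

class NaiveSDPA(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v):
        scale = 1.0 / math.sqrt(q.shape[-1])
        attn = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
        out = attn @ v
        # Decision point: stashing `attn` costs O(seq^2) memory; stashing only
        # q, k, v and recomputing `attn` in backward trades memory for FLOPs.
        ctx.save_for_backward(q, k, v, attn)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        q, k, v, attn = ctx.saved_tensors
        scale = 1.0 / math.sqrt(q.shape[-1])
        grad_v = attn.transpose(-2, -1) @ grad_out
        grad_attn = grad_out @ v.transpose(-2, -1)
        # Softmax backward: dS = P * (dP - sum(dP * P, dim=-1, keepdim=True))
        grad_scores = attn * (grad_attn - (grad_attn * attn).sum(-1, keepdim=True))
        grad_q = grad_scores @ k * scale
        grad_k = grad_scores.transpose(-2, -1) @ q * scale
        return grad_q, grad_k, grad_v

# Usage: out = NaiveSDPA.apply(q, k, v)
```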
Yes. The question here is mostly for @cowanmeg, i.e. which implementation we target in codegen determines what signature we would want to have.
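For context on "representable in nvFuser primitives": the unfused decomposition of SDPA uses only matmul, reduction, and pointwise ops, as in the sketch below (written with PyTorch ops standing in for the primitives). The tiled flash-attention schedule is the same math, additionally blocked over the sequence dimension with running max/sum statistics; this sketch is illustrative, not taken from the issue.

```python
# Sketch: SDPA decomposed into primitive-style ops with a numerically
# stable softmax.
import math
import torch

def sdpa_primitives(q, k, v):
    scale = 1.0 / math.sqrt(q.shape[-1])
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale   # matmul + pointwise
    row_max = scores.amax(dim=-1, keepdim=True)             # reduction
    exp_scores = torch.exp(scores - row_max)                 # pointwise
    row_sum = exp_scores.sum(dim=-1, keepdim=True)           # reduction
    probs = exp_scores / row_sum                              # pointwise
    return torch.matmul(probs, v)                             # matmul
```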
[Update from the June 4 meeting]
Add a new IR node for SDPA that is currently not supported within nvFuser.
CC: @cowanmeg @kevinstephano