Fix flash attention for ROCm #7011
base: master
Conversation
I didn't close that other PR by accident. As I said before, I don't think we should be adding a dependency on rocWMMA when the performance is no better than master and we have no dev to test and support it. I will also do an implementation of FlashAttention without any tensor cores at all, which may end up being faster anyway.
I don't know how to get this to compile on Windows :(
Sorry, I didn't realize it had been closed on purpose. Is the dependency that bad, though? rocWMMA is header-only, so there is no link-time requirement, and it enables sharing the existing CUDA code. The performance is not better, but the VRAM savings can be very significant: 1 GB in one case. The PR is not ready to merge as-is anyway; I need to disable flash attention in CMake by default for AMD GPUs, or enable it only if rocWMMA is detected as installed (a sketch of that idea follows below). I may not be a ROCm expert, but I am a C++ dev and I own a 7900 XTX; if this isn't merged, I might maintain this fork anyway. Of course, if you already plan to work on that other implementation soon, this whole comment is irrelevant, but having access to a rocWMMA-based version as a comparison could be useful, I don't know. Please let me know what you think.
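A minimal sketch of the detection idea (not the PR's actual code): the rocWMMA path could be gated at compile time on the header actually being present, so builds without the library fall back silently. `GGML_HIP_ROCWMMA_FATTN` is a hypothetical macro name used purely for illustration.

```cpp
// Hypothetical compile-time gate: enable the rocWMMA-backed flash-attention
// path only when the rocWMMA header is installed on the build machine.
// GGML_HIP_ROCWMMA_FATTN is an illustrative name, not an existing macro.
#if defined(__HIP_PLATFORM_AMD__) && defined(__has_include)
#  if __has_include(<rocwmma/rocwmma.hpp>)
#    define GGML_HIP_ROCWMMA_FATTN
#  endif
#endif

#ifdef GGML_HIP_ROCWMMA_FATTN
#  include <rocwmma/rocwmma.hpp>  // header-only: no extra link step needed
#endif
```

The equivalent check could instead live in CMake (probing for the header before defining the flag), which is closer to what the comment proposes; the preprocessor version above just shows the fallback logic in one place.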
I haven't been able to test flash attention on Windows with the 7900 XTX yet.
So I can say that for CDNA this makes a big difference.

This PR: [llama-bench table]

Latest master: [llama-bench table]

Both of those are still terrible compared to ExLlama, but this PR does make a big difference in the right direction. Note that I had to make some trivial changes to this PR to make it choose the WMMA path for gfx908 (a sketch of that kind of change is below).
(Benchmarks above were run with llama-bench; "buffer" = ROCm0 compute buffer size.)
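For reference, a hedged sketch of the kind of trivial change described: widening the architecture guard so the CDNA targets also select the WMMA path. The `FP16_MMA_AVAILABLE` guard name is an assumption here, not necessarily what this PR uses; the `__gfx*__` target macros are defined by the HIP/Clang toolchain when compiling device code for the corresponding GPU.

```cpp
// Illustrative only: extend the arch gate so gfx908 (CDNA) takes the
// WMMA-backed flash-attention path alongside the RDNA3 targets.
// FP16_MMA_AVAILABLE is an assumed guard name.
#if defined(__HIP_PLATFORM_AMD__)
#  if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) /* RDNA3 */ \
   || defined(__gfx908__)  || defined(__gfx90a__)                          /* CDNA, CDNA2 */
#    define FP16_MMA_AVAILABLE
#  endif
#endif
```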