Mamba2 9 times slower inference time than Mamba1 #355
I think there are some mistakes in the code, because I also found that Mamba2 is quite slow :)
Mamba2 is written mostly in Triton, so there's a lot of CPU overhead if the layer is small. Two ways to get around that: (1) CUDA graphs (or torch.compile), (2) use a larger model.
@tridao Thank you for your help. I added the line `@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)` before the `mamba_split_conv1d_scan_combined` function in `ssd_combined` and got speed competitive with the original Mamba code.
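A minimal sketch of the workaround described in that comment: wrap the Triton-heavy launcher in `torch.compile` with CUDA graphs enabled so the per-launch CPU overhead is captured once and replayed. The real decorator goes on `mamba_split_conv1d_scan_combined` inside the library; here a toy `scan_combined` stands in for it, and the compile step is guarded so the sketch also runs on CPU-only machines.

```python
import torch

def scan_combined(x: torch.Tensor) -> torch.Tensor:
    # Placeholder for the real fused conv1d + scan kernel; a cumulative
    # sum is used here only so the sketch is self-contained.
    return torch.cumsum(x, dim=-1)

if torch.cuda.is_available():
    # Compile with CUDA graphs only when a GPU is present; this is the
    # decorator form from the comment above, applied as a function call.
    scan_combined = torch.compile(
        scan_combined,
        options={"triton.cudagraphs": True},
        fullgraph=True,
    )

out = scan_combined(torch.ones(4))
```

Note that `torch.compile` traces on the first call, so the first invocation is slow; benchmark only after a few warm-up calls.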
Yes, CUDA graphs work!
What is in your main_test.py file, and can you give me some details about your environment?
@Kiet0712 Thanks for your response. Here is the main_test.py:
I still see this issue even with the CUDA-graphs compile. I applied the `.contiguous()` patch to fix the stride issues, and used the annotation to compile with CUDA graphs. My test is on an H100 with:
Package versions: torch 2.3.0
Log:
For my case:
@dwgan I don't actually know what your problem is. In my case, I simply added `@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)` before the `mamba_split_conv1d_scan_combined` function and then used Mamba2 in my task, and it works.
I have the same problem and am trying to work through it now. If I find a solution I'll let you know; in the meantime any help is very much appreciated!
May I ask about your Torch and Triton versions?
Torch==2.1.2
@Kiet0712 Could you tell us your torch and triton versions? Thanks.
@Baijiong-Lin I use triton 2.1.0 and torch 2.1.1.
@Kiet0712 Thanks, but it does not work for me. It still raises an error after adding `@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)` before the `mamba_split_conv1d_scan_combined` function.
Has anyone solved this?
I've tried this but I'm still getting an error; I'd appreciate it if you could show me the demo code.
I think the problem is solved. See my code here.
The output log:
After adding `torch.compile` to precompile the model, I actually need a warm-up to achieve good results. But how were you able to solve it here without using compile?
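The warm-up point above matters for any benchmark in this thread: with `torch.compile` (or Triton autotuning), the first call pays a one-time compilation cost, so timing must discard the first few invocations. A minimal, framework-agnostic sketch of that methodology (the function names here are hypothetical, and for GPU code you would additionally synchronize the device around each timed call):

```python
import time
from statistics import median

def bench(fn, *args, warmup=3, iters=10):
    """Run `fn` a few times untimed to absorb one-time compile/cache
    costs, then return the median of `iters` timed runs in seconds."""
    for _ in range(warmup):
        fn(*args)
    samples = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn(*args)
        samples.append(time.perf_counter() - t0)
    return median(samples)

# Toy workload whose first call is artificially slow, mimicking a JIT
# compile: without warm-up, that cost would dominate the measurement.
_cache = {}
def work(n):
    if "ready" not in _cache:
        time.sleep(0.05)  # one-time "compilation" cost
        _cache["ready"] = True
    return sum(range(n))

t = bench(work, 10_000)
```

Using the median rather than the mean also makes the result robust to a stray slow iteration.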
I use the original version, without adding `torch.compile`.
see #389 |
After changing the `d_model`, Mamba2 worked in the simple test environment provided in the README. But I noticed that Mamba2 is much slower than Mamba1. I tested it; here is my code. The result I got is this:
I don't know if it is a bug or if I made a mistake. Please feel free to share your thoughts.