Why is Mamba2 much slower than Mamba? #367
Hi, I have compared the inference time of Mamba and Mamba2 on the same input, and found that Mamba2 is much slower than Mamba.
My environment:
pytorch==2.1.2
pytorch-cuda=12.1
Here is the code:
```python
import torch
from mamba_ssm import Mamba, Mamba2

batch, length, dim = 8, 1024, 128
x = torch.randn(batch, length, dim).to("cuda")

model = Mamba(
    d_model=dim,  # Model dimension d_model
    d_state=64,   # SSM state expansion factor
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
).to("cuda")

model2 = Mamba2(
    d_model=dim,  # Model dimension d_model
    d_state=64,   # SSM state expansion factor, typically 64 or 128
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
    headdim=32,   # Additional parameter for Mamba2
    ngroups=1,    # Number of groups for group normalization
    sequence_parallel=False,  # Whether to use sequence parallelism
).to("cuda")

def count_parameters(model):
    return sum(p.numel() for p in model.parameters())

print(f"Mamba model parameters: {count_parameters(model)}")
print(f"Mamba2 model parameters: {count_parameters(model2)}")

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
y = model(x)
end_event.record()
torch.cuda.synchronize()
mamba_time = start_event.elapsed_time(end_event)  # Time in milliseconds
print(f"\nMamba model time: {mamba_time} ms")
print(y.shape)
assert y.shape == x.shape

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
y = model2(x)
end_event.record()
torch.cuda.synchronize()
mamba2_time = start_event.elapsed_time(end_event)  # Time in milliseconds
print(f"\nMamba2 model time: {mamba2_time} ms")
print(y.shape)
assert y.shape == x.shape
```
And the output:
```
/home/anaconda3/envs/mamba/bin/python /home/mamba/main_test.py
Mamba model parameters: 153344
Mamba2 model parameters: 117912
Mamba model time: 103.55609893798828 ms
torch.Size([8, 1024, 128])
Mamba2 model time: 8146.93310546875 ms
torch.Size([8, 1024, 128])
Process finished with exit code 0
```
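One thing worth noting about the measurement itself: each model is timed on its very first forward pass, so the numbers include one-time setup costs. Mamba2's kernels are Triton-based and are JIT-compiled on first use, which may account for most of the 8-second figure. A fairer comparison would warm both models up and average over several iterations. Below is a minimal sketch of such a helper (the `benchmark` function is my own, not part of the original script; it falls back to wall-clock timing on CPU):

```python
import time
import torch

def benchmark(model, x, n_warmup=3, n_iters=10):
    """Return the mean forward time in milliseconds, excluding one-time
    warm-up costs (e.g. lazy kernel compilation on the first call)."""
    for _ in range(n_warmup):  # warm-up passes trigger compilation/caching
        model(x)
    if x.is_cuda:
        torch.cuda.synchronize()  # drain queued GPU work before timing
    t0 = time.perf_counter()
    for _ in range(n_iters):
        model(x)
    if x.is_cuda:
        torch.cuda.synchronize()  # wait for the timed GPU work to finish
    return (time.perf_counter() - t0) * 1000 / n_iters
```

With this, `benchmark(model, x)` and `benchmark(model2, x)` would compare steady-state throughput rather than first-call latency.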
Thanks for your kind help.