Why mamba2 is much slower than mamba? #367

Open
dwgan opened this issue Jun 5, 2024 · 2 comments

Comments

dwgan commented Jun 5, 2024

Hi, I compared the inference time of Mamba and Mamba2 on the same input and found that Mamba2 is much slower than Mamba.

My environment:
pytorch==2.1.2
pytorch-cuda=12.1

Here is the code:

```python
import torch
from mamba_ssm import Mamba, Mamba2

batch, length, dim = 8, 1024, 128
x = torch.randn(batch, length, dim).to("cuda")

model = Mamba(
    d_model=dim,  # Model dimension d_model
    d_state=64,   # SSM state expansion factor
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
).to("cuda")

model2 = Mamba2(
    d_model=dim,  # Model dimension d_model
    d_state=64,   # SSM state expansion factor, typically 64 or 128
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
    headdim=32,   # Head dimension (Mamba2 only)
    ngroups=1,    # Number of groups for the B and C projections
    sequence_parallel=False,  # Whether to use sequence parallelism
).to("cuda")


def count_parameters(model):
    """Total number of elements across all parameters of the module."""
    return sum(p.numel() for p in model.parameters())


print(f"Mamba model parameters: {count_parameters(model)}")
print(f"Mamba2 model parameters: {count_parameters(model2)}")

# Time a single forward pass of Mamba with CUDA events.
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
y = model(x)
end_event.record()

# Wait for the GPU to finish so the elapsed time is valid.
torch.cuda.synchronize()

mamba_time = start_event.elapsed_time(end_event)  # Time in milliseconds

print(f"\nMamba model time: {mamba_time} ms")
print(y.shape)
assert y.shape == x.shape

# Time a single forward pass of Mamba2 in the same way.
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
y = model2(x)
end_event.record()

torch.cuda.synchronize()

mamba2_time = start_event.elapsed_time(end_event)  # Time in milliseconds

print(f"\nMamba2 model time: {mamba2_time} ms")
print(y.shape)
assert y.shape == x.shape
```

And the output:
```
/home/anaconda3/envs/mamba/bin/python /home/mamba/main_test.py
Mamba model parameters: 153344
Mamba2 model parameters: 117912

Mamba model time: 103.55609893798828 ms
torch.Size([8, 1024, 128])

Mamba2 model time: 8146.93310546875 ms
torch.Size([8, 1024, 128])

Process finished with exit code 0
```

Thanks for your kind help.
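For reference, the two parameter counts in the output can be reproduced by hand. The sketch below is plain Python and does not import `mamba_ssm`; the per-layer shapes (for example `dt_rank = ceil(d_model / 16)`, no bias on the input and output projections, and a bias on the convolution) are assumptions about the library defaults that are not stated in this thread. The helper names `mamba_params` and `mamba2_params` are made up for illustration.

```python
import math


def mamba_params(d_model, d_state, d_conv, expand):
    """Per-layer parameter counts for a Mamba block (assumed default shapes)."""
    d_inner = expand * d_model
    dt_rank = math.ceil(d_model / 16)  # assumed default for dt_rank="auto"
    return {
        "in_proj": d_model * 2 * d_inner,             # no bias assumed
        "conv1d": d_inner * d_conv + d_inner,         # depthwise weights + bias
        "x_proj": d_inner * (dt_rank + 2 * d_state),  # produces dt, B, C
        "dt_proj": dt_rank * d_inner + d_inner,       # weights + bias
        "A_log": d_inner * d_state,
        "D": d_inner,
        "out_proj": d_inner * d_model,                # no bias assumed
    }


def mamba2_params(d_model, d_state, d_conv, expand, headdim, ngroups):
    """Per-layer parameter counts for a Mamba2 block (assumed default shapes)."""
    d_inner = expand * d_model
    nheads = d_inner // headdim
    conv_dim = d_inner + 2 * ngroups * d_state  # x, B and C share the conv
    return {
        "in_proj": d_model * (2 * d_inner + 2 * ngroups * d_state + nheads),
        "conv1d": conv_dim * d_conv + conv_dim,  # depthwise weights + bias
        "dt_bias": nheads,
        "A_log": nheads,  # one scalar per head
        "D": nheads,
        "norm": d_inner,
        "out_proj": d_inner * d_model,
    }


m1 = mamba_params(d_model=128, d_state=64, d_conv=4, expand=2)
m2 = mamba2_params(d_model=128, d_state=64, d_conv=4, expand=2,
                   headdim=32, ngroups=1)
print(sum(m1.values()))  # 153344
print(sum(m2.values()))  # 117912
```

Under these assumptions the totals match the printed counts, so Mamba2 is the smaller module here. Model size alone therefore does not account for the timing gap.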

Hprairie commented Jun 5, 2024

#355

dwgan commented Jun 5, 2024

@Hprairie Thanks a lot.
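A note on the measurement itself: the script in the original post times only the very first forward pass of each model. A first call can include one-time costs. As far as I know, the Mamba2 path in `mamba_ssm` relies on Triton kernels, which are compiled the first time they run. This thread does not confirm that this explains the gap, so treat the following as a sketch of a fairer measurement, not a diagnosis. The `benchmark` helper and the `LazyLayer` stand-in are hypothetical names made up for illustration, and the example uses only the standard library so it runs without a GPU.

```python
import time


def benchmark(fn, warmup=3, iters=10, sync=None):
    """Return (first_call_ms, steady_state_ms) for a zero-argument callable.

    `sync` is an optional callable that blocks until queued work is done,
    so that asynchronous GPU work is included in the measured time.
    """
    def timed_call():
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        return (time.perf_counter() - start) * 1000.0

    first_ms = timed_call()  # includes any one-time setup cost
    for _ in range(max(warmup - 1, 0)):
        timed_call()  # remaining warm-up calls, results discarded
    steady_ms = sum(timed_call() for _ in range(iters)) / iters
    return first_ms, steady_ms


class LazyLayer:
    """Toy stand-in for a layer whose first call pays a one-time setup cost."""

    def __init__(self):
        self.ready = False

    def __call__(self):
        if not self.ready:
            time.sleep(0.05)  # simulated one-time compilation
            self.ready = True
        time.sleep(0.001)  # simulated steady-state work


first_ms, steady_ms = benchmark(LazyLayer())
print(f"first call: {first_ms:.1f} ms, steady state: {steady_ms:.1f} ms")
```

With the models from the original post, the equivalent call would be `benchmark(lambda: model2(x), sync=torch.cuda.synchronize)`, ideally inside `torch.no_grad()` since only inference time is being compared. Comparing the steady-state numbers for `model` and `model2` is then more meaningful than comparing two first calls.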
