
Mamba2 9 times slower inference time than Mamba1 #355

Closed
realwenlongwang opened this issue Jun 4, 2024 · 22 comments

Comments

@realwenlongwang commented Jun 4, 2024

After changing d_model, Mamba2 worked in the simple test environment provided in the README, but I noticed that Mamba2 is much slower than Mamba1. Here is the code I used to test it:

import torch
from mamba_ssm import Mamba2 as Mamba
# from mamba_ssm import Mamba

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
batch, length, dim = 2, 64, 256
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=64,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")
y = model(x)
end.record()
torch.cuda.synchronize()
inference_time = start.elapsed_time(end)
assert y.shape == x.shape
print(f'parameter number: {sum([p.numel() for p in model.parameters()])}')
print(f'inference time: {inference_time}')

The result I got is this:

Mamba1 parameter number: 511488
Mamba1 inference time: 539.1769409179688
Mamba2 parameter number: 431768
Mamba2 inference time: 4322.52294921875

I don't know whether this is a bug or whether I made a mistake. Please feel free to share your thoughts.

@Kiet0712 commented Jun 4, 2024

I think there may be some mistake in the code, because I also found that Mamba2 is quite slow :)

@tridao (Collaborator) commented Jun 4, 2024

Mamba2 is written mostly in Triton, so there's a lot of CPU overhead if the layer is so small. Two ways to get around that: (1) CUDA graph (or torch compile) (2) use a large model.
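
A minimal sketch of option (1), assuming the same layer shape as the example above (the names and the mode="reduce-overhead" choice are illustrative, not from this thread; on some torch versions compiling the module may hit the Dynamo limitation reported further down):

import torch
from mamba_ssm import Mamba2

batch, length, dim = 2, 64, 256
x = torch.randn(batch, length, dim, device="cuda")

model = Mamba2(d_model=dim, d_state=64, d_conv=4, expand=2).to("cuda")
# "reduce-overhead" asks torch.compile to use CUDA graphs, which hides the
# per-kernel CPU launch overhead that dominates for a layer this small.
compiled_model = torch.compile(model, mode="reduce-overhead")

# Warm up: the first calls trigger Triton compilation and autotuning.
for _ in range(3):
    y = compiled_model(x)
torch.cuda.synchronize()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
y = compiled_model(x)
end.record()
torch.cuda.synchronize()
print(f"inference time: {start.elapsed_time(end):.3f} ms")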

@Kiet0712 commented Jun 4, 2024

@tridao Thank you for your help. I added @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True) before the mamba_split_conv1d_scan_combined function in ssd_combined.py and got speed competitive with the original Mamba code.

@realwenlongwang (Author)

Yes, CUDA graphs work!

@dwgan

This comment was marked as duplicate.

@Kiet0712 commented Jun 6, 2024

What is in your main_test.py file, and can you give me some details about your environment?

@dwgan commented Jun 6, 2024

@Kiet0712 Thanks for your response. Here is main_test.py:

import torch
from mamba_ssm import Mamba, Mamba2

batch, length, dim = 2, 64, 256
x = torch.randn(batch, length, dim).to("cuda")

# Initialize the Mamba model
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=64,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")

# Initialize the Mamba2 model
model2 = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=64,  # SSM state expansion factor, typically 64 or 128
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
    headdim=32,  # Additional parameter for Mamba2
    ngroups=1,   # Number of groups for group normalization
    sequence_parallel=False, # Whether to use sequence parallelism
).to("cuda")

# Function to count model parameters
def count_parameters(model):
    return sum(p.numel() for p in model.parameters())

# Print the number of parameters for each model
print(f"Mamba model parameters: {count_parameters(model)}")
print(f"Mamba2 model parameters: {count_parameters(model2)}")

# Measure inference time for Mamba model
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
y = model(x)
end_event.record()

# Wait for all CUDA operations to finish
torch.cuda.synchronize()

mamba_time = start_event.elapsed_time(end_event) # Time in milliseconds

print(f"\nMamba model time: {mamba_time} ms")
print(y.shape)
assert y.shape == x.shape

# Measure inference time for Mamba2 model
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
y = model2(x)
end_event.record()

# Wait for all CUDA operations to finish
torch.cuda.synchronize()

mamba2_time = start_event.elapsed_time(end_event) # Time in milliseconds

print(f"\nMamba2 model time: {mamba2_time} ms")
print(y.shape)
assert y.shape == x.shape

@arelkeselbri commented Jun 6, 2024

I still see this issue even with the CUDA graphs compile.

I applied the ".contiguous()" patch to fix the stride issues, and also used the annotation to compile with CUDA graphs.

My test is on an H100 with:

import torch
import timeit
from mamba_ssm import Mamba, Mamba2

batch, length, dim = 2, 64, 64
x = torch.randn(batch, length, dim).to("cuda")

def try_mamba1(batch, length, dim, x):
    model = Mamba(
        # This module uses roughly 3 * expand * d_model^2 parameters
        d_model=dim, # Model dimension d_model
        d_state=16,  # SSM state expansion factor
        d_conv=4,    # Local convolution width
        expand=2,    # Block expansion factor
    ).to("cuda")
    y = model(x)
    assert y.shape == x.shape

def try_mamba2(batch, length, dim, x):
    model = Mamba2(
        # This module uses roughly 3 * expand * d_model^2 parameters
        d_model=dim, # Model dimension d_model
        d_state=64,  # SSM state expansion factor, typically 64 or 128
        d_conv=4,    # Local convolution width
        expand=2,    # Block expansion factor
    ).to("cuda")
    y = model(x)
    assert y.shape == x.shape

mamba1_time = timeit.timeit('try_mamba1(batch, length, dim, x)', number=10, globals=globals())
print(f"Mamba 1 took {mamba1_time} seconds")

mamba2_time = timeit.timeit('try_mamba2(batch, length, dim, x)', number=10, globals=globals())
print(f"Mamba 2 took {mamba2_time} seconds")

Package versions:

torch 2.3.0
causal-conv1d 1.2.2.post1
mamba-ssm 2.0.3
nvidia-cuda-runtime-cu12 12.1.105

LOG:

Traceback (most recent call last):
  File "/home/albertini/mamba/test2.py", line 33, in <module>
    mamba2_time = timeit.timeit('try_mamba2(batch, length, dim, x)', number=10, globals=globals())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/timeit.py", line 237, in timeit
    return Timer(stmt, setup, timer, globals).timeit(number)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/timeit.py", line 180, in timeit
    timing = self.inner(it, self.timer)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<timeit-src>", line 6, in inner
  File "/home/albertini/mamba/test2.py", line 27, in try_mamba2
    y = model(x)
        ^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/albertini/mamba/mamba_ssm/modules/mamba2.py", line 176, in forward
    out = mamba_split_conv1d_scan_combined(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 921, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
    return _compile(
           ^^^^^^^^^
  File "/usr/local/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 676, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 262, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 535, in compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object
    transformations(instructions, code_options)
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 165, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 500, in transform
    tracer.run()
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2149, in run
    super().run()
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 810, in run
    and self.step()
        ^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 773, in step
    getattr(self, inst.opname)(inst)
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 489, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1802, in CALL
    self.call_function(fn, args, kwargs)
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 674, in call_function
    self.push(fn.call_function(self, args, kwargs))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/variables/misc.py", line 562, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/variables/misc.py", line 420, in call_method
    return self.call_apply(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/variables/misc.py", line 368, in call_apply
    ).call_function(tx, args, kwargs)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 1533, in call_function
    (fwd_out, _), fwd_graph, fwd_freevars = speculate_subgraph(
                                            ^^^^^^^^^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 467, in speculate_subgraph
    raise ex
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 359, in speculate_subgraph
    args = validate_args_and_maybe_create_graph_inputs(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 200, in validate_args_and_maybe_create_graph_inputs
    raise unimplemented(
          ^^^^^^^^^^^^^^
  File "/home/albertini/mamba/venv/lib/python3.11/site-packages/torch/_dynamo/exc.py", line 190, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: autograd.Function with body that accepts non-Tensors as input. Got: <class 'tuple'>

from user code:
   File "/home/albertini/mamba/mamba_ssm/ops/triton/ssd_combined.py", line 909, in mamba_split_conv1d_scan_combined
    return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
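
When torch.compile cannot trace the autograd.Function like this, tridao's option (1) can still be done with a manually captured CUDA graph. A minimal inference-only sketch under the assumption of fixed input shapes (the names follow the examples above and are illustrative):

import torch
from mamba_ssm import Mamba2

batch, length, dim = 2, 64, 256
model = Mamba2(d_model=dim, d_state=64, d_conv=4, expand=2).to("cuda")
static_x = torch.randn(batch, length, dim, device="cuda")

# Warm up on a side stream so Triton compilation/autotuning happens before capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s), torch.no_grad():
    for _ in range(3):
        model(static_x)
torch.cuda.current_stream().wait_stream(s)

# Capture one forward pass; replaying the graph skips per-kernel CPU launches.
graph = torch.cuda.CUDAGraph()
with torch.no_grad(), torch.cuda.graph(graph):
    static_y = model(static_x)

# To run on new data: copy it into the static input buffer and replay.
static_x.copy_(torch.randn(batch, length, dim, device="cuda"))
graph.replay()
print(static_y.shape)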

@hengck23 commented Jun 7, 2024

For my case:

  • GPU: NVIDIA RTX 6000 Ada
  • PyTorch: 2.3.0+cu121

[added this code to ssd_combined.py]

@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
def mamba_split_conv1d_scan_combined(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu", rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, ngroups=1, norm_before_gate=True):

[error]
LLVM ERROR: pthread_join failed: Invalid argument
LLVM ERROR: pthread_join failed: Invalid argument
...

    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: 'skip function PyCapsule.causal_conv1d_fwd in file Builtin causal_conv1d_fwd'
...

  File "/home/hp/app/anaconda3.10/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 757, in forward
    causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),

@Kiet0712 commented Jun 7, 2024

@dwgan I don't actually know what your problem is. In my case, I just added @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True) before the mamba_split_conv1d_scan_combined function, then used Mamba2 in my task, and it works.

@DustinEwan

  File "/home/hp/app/anaconda3.10/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 757, in forward
    causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),

I have the same problem and am trying to work through it now. If I find a solution I'll let you know; in the meantime, any help is very much appreciated!

@AlwaysFHao

@dwgan I don't actually know what your problem is. In my case, I just added @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True) before the mamba_split_conv1d_scan_combined function, then used Mamba2 in my task, and it works.

May I ask about your Torch and Triton versions?

@dwgan commented Jun 14, 2024

@dwgan I don't actually know what your problem is. In my case, I just added @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True) before the mamba_split_conv1d_scan_combined function, then used Mamba2 in my task, and it works.

May I ask about your Torch and Triton versions?

Torch==2.1.2
Triton==2.1.0
python==3.10
ubuntu18.04

@Baijiong-Lin

I don't actually know what your problem is. In my case, I just added @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True) before the mamba_split_conv1d_scan_combined function, then used Mamba2 in my task, and it works.

@Kiet0712 Could you tell us your torch and triton versions? Thanks.

@Kiet0712

@Baijiong-Lin I use triton 2.1.0 and torch 2.1.1

@Baijiong-Lin

@Baijiong-Lin I use triton 2.1.0 and torch 2.1.1

@Kiet0712 Thanks, but it does not work for me. It still has an error after adding @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True) before the mamba_split_conv1d_scan_combined function.

@TimothyChen225

Has anyone solved this?

@TimothyChen225

Yes, CUDA graphs work!

I've tried this but I'm still getting an error. I'd appreciate it if you could show me the demo code.

@dwgan commented Jun 21, 2024

Mamba2 is written mostly in Triton, so there's a lot of CPU overhead if the layer is so small. Two ways to get around that: (1) CUDA graph (or torch compile) (2) use a large model.

Try warming up by running it once first. The first time will invoke the triton compiler & autotune so it'll be slow.

I think the problem was solved.

See my code here

import time
import torch
from mamba_ssm import Mamba2
from mamba_ssm import Mamba

repeat_num = 1000
batch, length, dim = 2, 256*256*2, 256
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=64,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")
y = model(x) # warm up first
t1 = time.time()
for i in range(repeat_num):
    y = model(x)
assert y.shape == x.shape
print(f"Time of mamba taken: {time.time() - t1:.3f} s")

model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=64,  # SSM state expansion factor, typically 64 or 128
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")
y = model(x) # warm up first
t1 = time.time()
for i in range(repeat_num):
    y = model(x)
assert y.shape == x.shape
print(f"Time of mamba2 taken: {time.time() - t1:.3f} s")

The output log

Time of mamba taken: 24.061 s
Time of mamba2 taken: 14.011 s
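
One caveat with the timing above (an observation, not from the original comment): time.time() is read without a torch.cuda.synchronize(), so queued GPU work may not be fully included in the measured interval. A sketch reusing the same model, x and repeat_num:

torch.cuda.synchronize()  # make sure warm-up work has finished
t1 = time.time()
for i in range(repeat_num):
    y = model(x)
torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
print(f"Time taken: {time.time() - t1:.3f} s")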

@AlwaysFHao

[quoting @dwgan's comment and code above]

After wrapping the model with torch.compile, I still need a warm-up to get good results. But why can you solve it here without using compile?

@dwgan commented Jun 21, 2024

[quoting @AlwaysFHao's question above]

I use the original version, without adding torch.compile.

@TimothyChen225

see #389
