
triton error while running Mamba2 with slow path #369

Open
Seeker98 opened this issue Jun 6, 2024 · 7 comments

Comments

@Seeker98

Seeker98 commented Jun 6, 2024

As in #355, I added "@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)" to the "mamba_chunk_scan_combined" function in "ssd_combined.py", and it failed with the following error:

Unsupported: autograd.Function with body that accepts non-Tensors as input. Got: <class 'tuple'>

from user code:
   File "/home/hit/.conda/envs/torch2/lib/python3.9/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 560, in mamba_chunk_scan_combined
    return MambaChunkScanCombinedFn.apply(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, dt_softplus, dt_limit, return_final_states)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
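
The "Got: <class 'tuple'>" part seems to come from the non-Tensor arguments (the chunk_size int, the dt_limit tuple, the dt_softplus bool) that MambaChunkScanCombinedFn.apply receives in the traceback above: Dynamo on this PyTorch version refuses to trace an autograd.Function whose forward takes such inputs under fullgraph=True. A minimal sketch that should reproduce just that Dynamo limitation, independent of mamba_ssm (ToyFn is made up purely for illustration):

import torch

class ToyFn(torch.autograd.Function):
    # forward takes a tuple ("limits"), mirroring how
    # MambaChunkScanCombinedFn.apply receives dt_limit above
    @staticmethod
    def forward(ctx, x, limits):
        lo, hi = limits
        return x.clamp(lo, hi)

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out, None

@torch.compile(fullgraph=True)
def fn(x):
    return ToyFn.apply(x, (0.0, 1.0))

fn(torch.randn(4, requires_grad=True))  # on torch 2.3 this should hit the same Unsupported error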

Code to reproduce:

import torch
from mamba_ssm import Mamba2
batch, length, dim = 8, 1024, 128
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=64,  # SSM state expansion factor, typically 64 or 128
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
    headdim=32,
    use_mem_eff_path=False
).to("cuda")
y = model(x)
assert y.shape == x.shape

I'm not sure what to provide, but my packages are:
mamba-ssm 2.0.3
causal-conv1d 1.2.2.post1
pytorch 2.3.1 with py39_cu121_cudnn8.9.2_0
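
One workaround that may be worth trying (a sketch, not verified here): drop fullgraph=True and compile the surrounding module instead, so that Dynamo can insert a graph break at the unsupported autograd.Function call rather than raising:

import torch
from mamba_ssm import Mamba2

model = Mamba2(
    d_model=128, d_state=64, d_conv=4, expand=2,
    headdim=32, use_mem_eff_path=False,
).to("cuda")

# without fullgraph=True, Dynamo can fall back to eager around the
# unsupported autograd.Function instead of raising Unsupported
compiled = torch.compile(model)

x = torch.randn(8, 1024, 128, device="cuda")
y = compiled(x)
assert y.shape == x.shape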

@arelkeselbri

Tried the following and the time doesn't seem to change. Maybe it's just an initial compilation delay:


for i in range(10):
    x = torch.randn(batch, length, dim).to("cuda")
    y = model2(x)
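
If the point is to check whether compilation helps at all, keep in mind that the first calls pay the compile/autotuning cost and CUDA kernels launch asynchronously, so timing needs warm-up and synchronization. A rough sketch, reusing model2 and the batch/length/dim values from the snippets above:

import time
import torch

x = torch.randn(batch, length, dim, device="cuda")

# warm-up: the first calls include compilation / Triton autotuning
for _ in range(3):
    model2(x)
torch.cuda.synchronize()

start = time.time()
for _ in range(10):
    model2(x)
torch.cuda.synchronize()  # wait for queued CUDA work before reading the clock
print(f"avg forward: {(time.time() - start) / 10:.4f} s")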

@Seeker98

Seeker98 commented Jun 6, 2024

Well, I'm wondering why adding torch.compile as in the #355 discussion makes the code fail, since the author mentioned it could speed things up a lot.

@Baijiong-Lin

the same issue

@yaosi-ym

Same issue.

@zizheng-guo

the same issue

@JHChen1

JHChen1 commented Jun 16, 2024

(Quoted the original issue report and traceback above.)

Hi, I have the same problem. Have you solved it?

@TimothyChen225

the same issue
