mamba2 has the error #384

Open
liangaomng opened this issue Jun 11, 2024 · 1 comment

@liangaomng

Thanks for your wonderful work!
When I run Mamba-1, it works fine.
But when I run the Mamba-2 example from your README, it shows the following error:
File "<string>", line 21, in _chunk_cumsum_fwd_kernel
KeyError: ('2-.-0-.-0-d82511111ad128294e9d31a6ac684238-d6252949da17ceb5f3a278a70250af13-1af5134066c618146d2cd009138944a0-bf290056f5caf73914ee917acc2e7230-3498c340fd4b6ee7805fd54b882a04f5-e1f133f98d04093da2078dfc51c36b72-b26258bf01f839199e39d64851821f26-d7c06e3b46e708006c15224aac7a1378-f585402118c8a136948ce0a49cfe122c', (torch.float32, torch.float32, torch.float32, torch.float32, torch.float32, 'i32', 'i32', 'i32', 'i32', 'fp32', 'fp32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), (True, True, 1, 256), (True, True, True, True, True, (False, False), (True, False), (False, False), (True, False), (False,), (False,), (True, False), (False, False), (False, True), (False, True), (False, True), (True, False), (True, False), (True, False), (False, True), (True, False), (True, False), (True, False), (False, True)))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/liujinxin/lam/le_pde/pytorch_net/mamba.py", line 15, in <module>
y = model(x)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/mamba_ssm/modules/mamba2.py", line 176, in forward
out = mamba_split_conv1d_scan_combined(
File "/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 908, in mamba_split_conv1d_scan_combined
return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
File "/opt/conda/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/opt/conda/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 98, in decorate_fwd
return fwd(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 773, in forward
out_x, _, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit)
File "/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 307, in _mamba_chunk_scan_combined_fwd
dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit)
File "/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_chunk_state.py", line 582, in _chunk_cumsum_fwd
_chunk_cumsum_fwd_kernel[grid_chunk_cs](
File "/opt/conda/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 77, in run
timings = {config: self._bench(*args, config=config, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 77, in <dictcomp>
timings = {config: self._bench(*args, config=config, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 65, in _bench
return do_bench(kernel_call)
File "/opt/conda/lib/python3.10/site-packages/triton/testing.py", line 143, in do_bench
fn()
File "/opt/conda/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 63, in kernel_call
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
File "<string>", line 41, in _chunk_cumsum_fwd_kernel
File "/opt/conda/lib/python3.10/site-packages/triton/compiler.py", line 1589, in compile
fn_cache_manager = CacheManager(make_hash(fn, **kwargs))
File "/opt/conda/lib/python3.10/site-packages/triton/compiler.py", line 1499, in make_hash
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}"
File "/opt/conda/lib/python3.10/site-packages/triton/runtime/jit.py", line 333, in cache_key
dependencies_finder.visit(self.parse())
File "/opt/conda/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/opt/conda/lib/python3.10/ast.py", line 426, in generic_visit
self.visit(item)
File "/opt/conda/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/opt/conda/lib/python3.10/ast.py", line 426, in generic_visit
self.visit(item)
File "/opt/conda/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/opt/conda/lib/python3.10/ast.py", line 428, in generic_visit
self.visit(value)
File "/opt/conda/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/opt/conda/lib/python3.10/site-packages/triton/runtime/jit.py", line 55, in visit_Call
func = self.visit(node.func)
File "/opt/conda/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/opt/conda/lib/python3.10/site-packages/triton/runtime/jit.py", line 52, in visit_Attribute
return getattr(lhs, node.attr)
AttributeError: module 'triton.language' has no attribute 'cumsum'

@tridao
Collaborator

tridao commented Jun 11, 2024

Please use triton >= 2.1.0
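A quick way to check whether an environment is affected is to compare the installed Triton version against the 2.1.0 minimum given in the reply, and to confirm that `triton.language` exposes `cumsum`, the attribute missing at the bottom of the traceback. This is a minimal sketch that assumes only that Triton publishes a `__version__` string; the helper names `parse_version` and `triton_is_new_enough` are hypothetical and not part of Triton or mamba_ssm.

```python
import importlib
import re

# Minimum version taken from the maintainer's reply ("triton >= 2.1.0").
MIN_TRITON = (2, 1, 0)


def parse_version(version: str) -> tuple:
    """Hypothetical helper: turn '2.0.0' or '2.1.0+e621604' into (2, 1, 0).

    Only leading numeric components are kept; a local suffix after '+' and
    any pre-release tail such as 'rc1' are ignored.
    """
    parts = []
    for piece in version.split("+")[0].split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    # Pad so that '2.1' compares equal to '2.1.0'.
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts[:3])


def triton_is_new_enough(version: str, minimum: tuple = MIN_TRITON) -> bool:
    """Hypothetical helper: True if `version` is at least `minimum`."""
    return parse_version(version) >= minimum


if __name__ == "__main__":
    try:
        triton = importlib.import_module("triton")
        tl = importlib.import_module("triton.language")
    except ImportError:
        print("Triton is not installed")
    else:
        verdict = "ok" if triton_is_new_enough(triton.__version__) else "too old (need >= 2.1.0)"
        print("triton", triton.__version__, verdict)
        # The AttributeError in the traceback means this attribute was missing.
        print("tl.cumsum available:", hasattr(tl, "cumsum"))
```

If the check reports an older version, upgrading (for example `pip install -U "triton>=2.1.0"`) follows the advice above. Some PyTorch builds pin a specific Triton release, so pip may report a conflict; in that case upgrading PyTorch as well may be necessary.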
