mamba2 raises an error #384
Comments
Please use triton >= 2.1.0.
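The reply sets the requirement at Triton 2.1.0. The last line of the traceback is `AttributeError: module 'triton.language' has no attribute 'cumsum'`. Below is a minimal sketch for checking the installed version before running Mamba2. It assumes Triton is registered under the distribution name `triton`; some PyTorch builds may ship it under another name. The helper names (`parse_release`, `triton_is_new_enough`, `has_tl_cumsum`, `check_installed_triton`) are hypothetical and are not part of mamba_ssm.

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Tuple

# Minimum Triton release suggested in the reply; tl.cumsum is missing before it.
MIN_TRITON: Tuple[int, int, int] = (2, 1, 0)


def parse_release(v: str) -> Tuple[int, int, int]:
    """Return (major, minor, patch) from strings such as '2.0.0' or '2.1.0+abc123'.

    Pre-release and local suffixes are ignored, so '2.1.0rc1' counts as 2.1.0.
    """
    m = re.match(r"(\d+)(?:\.(\d+))?(?:\.(\d+))?", v.strip())
    if m is None:
        raise ValueError(f"unrecognized version string: {v!r}")
    major, minor, patch = (int(g) if g is not None else 0 for g in m.groups())
    return major, minor, patch


def triton_is_new_enough(v: str) -> bool:
    # Tuple comparison is lexicographic: (2, 0, 0) < (2, 1, 0) <= (3, 0, 0)
    return parse_release(v) >= MIN_TRITON


def has_tl_cumsum() -> bool:
    """Feature probe: does the importable triton.language expose cumsum?"""
    try:
        import triton.language as tl
    except ImportError:
        return False
    return hasattr(tl, "cumsum")


def check_installed_triton() -> str:
    # Assumes the wheel is registered under the distribution name "triton".
    try:
        installed = version("triton")
    except PackageNotFoundError:
        return "triton is not installed under the distribution name 'triton'"
    if triton_is_new_enough(installed):
        return f"triton {installed} meets the >= 2.1.0 requirement"
    return f"triton {installed} is older than 2.1.0 and lacks tl.cumsum"


if __name__ == "__main__":
    print(check_installed_triton())
```

PyTorch wheels may pin a specific Triton release. If so, upgrading PyTorch may be a more reliable way to get a newer Triton than upgrading Triton on its own.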
Thanks for your wonderful work!
When I run Mamba1, it works fine.
But when I run the Mamba2 example from your README, I get the following error:
  File "<string>", line 21, in _chunk_cumsum_fwd_kernel
KeyError: ('2-.-0-.-0-d82511111ad128294e9d31a6ac684238-d6252949da17ceb5f3a278a70250af13-1af5134066c618146d2cd009138944a0-bf290056f5caf73914ee917acc2e7230-3498c340fd4b6ee7805fd54b882a04f5-e1f133f98d04093da2078dfc51c36b72-b26258bf01f839199e39d64851821f26-d7c06e3b46e708006c15224aac7a1378-f585402118c8a136948ce0a49cfe122c', (torch.float32, torch.float32, torch.float32, torch.float32, torch.float32, 'i32', 'i32', 'i32', 'i32', 'fp32', 'fp32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), (True, True, 1, 256), (True, True, True, True, True, (False, False), (True, False), (False, False), (True, False), (False,), (False,), (True, False), (False, False), (False, True), (False, True), (False, True), (True, False), (True, False), (True, False), (False, True), (True, False), (True, False), (True, False), (False, True)))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/liujinxin/lam/le_pde/pytorch_net/mamba.py", line 15, in <module>
    y = model(x)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/mamba_ssm/modules/mamba2.py", line 176, in forward
    out = mamba_split_conv1d_scan_combined(
  File "/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 908, in mamba_split_conv1d_scan_combined
    return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs) # type: ignore[misc]
  File "/opt/conda/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 98, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 773, in forward
    out_x, _, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit)
  File "/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 307, in _mamba_chunk_scan_combined_fwd
    dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit)
  File "/opt/conda/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_chunk_state.py", line 582, in _chunk_cumsum_fwd
    _chunk_cumsum_fwd_kernel[grid_chunk_cs](
  File "/opt/conda/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 77, in run
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 77, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 65, in _bench
    return do_bench(kernel_call)
  File "/opt/conda/lib/python3.10/site-packages/triton/testing.py", line 143, in do_bench
    fn()
  File "/opt/conda/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 63, in kernel_call
    self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
  File "<string>", line 41, in _chunk_cumsum_fwd_kernel
  File "/opt/conda/lib/python3.10/site-packages/triton/compiler.py", line 1589, in compile
    fn_cache_manager = CacheManager(make_hash(fn, **kwargs))
  File "/opt/conda/lib/python3.10/site-packages/triton/compiler.py", line 1499, in make_hash
    key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}"
  File "/opt/conda/lib/python3.10/site-packages/triton/runtime/jit.py", line 333, in cache_key
    dependencies_finder.visit(self.parse())
  File "/opt/conda/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/opt/conda/lib/python3.10/ast.py", line 426, in generic_visit
    self.visit(item)
  File "/opt/conda/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/opt/conda/lib/python3.10/ast.py", line 426, in generic_visit
    self.visit(item)
  File "/opt/conda/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/opt/conda/lib/python3.10/ast.py", line 428, in generic_visit
    self.visit(value)
  File "/opt/conda/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/opt/conda/lib/python3.10/site-packages/triton/runtime/jit.py", line 55, in visit_Call
    func = self.visit(node.func)
  File "/opt/conda/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/opt/conda/lib/python3.10/site-packages/triton/runtime/jit.py", line 52, in visit_Attribute
    return getattr(lhs, node.attr)
AttributeError: module 'triton.language' has no attribute 'cumsum'
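The traceback shows the autotuner benchmarking `_chunk_cumsum_fwd_kernel`. When Triton hashes the kernel source, its dependency finder resolves `tl.cumsum` with `getattr`, and that attribute does not exist in the installed `triton.language`. To show roughly what this step computes, here is a plain-Python reference for a single head. It assumes the kernel adds `dt_bias`, optionally applies softplus, multiplies by `A`, and keeps a running sum that restarts at every chunk of `chunk_size` positions. The function name `chunk_cumsum_reference` is hypothetical. The sketch ignores batching, multiple heads, and `dt_limit` clamping, and it is not the mamba_ssm implementation.

```python
import math
from typing import List


def softplus(x: float) -> float:
    # For large x, softplus(x) is approximately x; returning x avoids overflow in exp.
    return x if x > 20.0 else math.log1p(math.exp(x))


def chunk_cumsum_reference(dt: List[float], A: float, chunk_size: int,
                           dt_bias: float = 0.0,
                           dt_softplus: bool = False) -> List[List[float]]:
    """Hypothetical single-head reference: cumulative sum of dt * A per chunk."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    chunks: List[List[float]] = []
    for start in range(0, len(dt), chunk_size):
        running = 0.0  # the running sum restarts at each chunk boundary
        chunk: List[float] = []
        for raw in dt[start:start + chunk_size]:
            step = raw + dt_bias
            if dt_softplus:
                step = softplus(step)
            running += step * A
            chunk.append(running)
        chunks.append(chunk)
    return chunks
```

For example, `chunk_cumsum_reference([1.0, 2.0, 3.0, 4.0, 5.0], A=-1.0, chunk_size=2)` returns `[[-1.0, -3.0], [-3.0, -7.0], [-5.0]]`. The last chunk is shorter here. The GPU kernel instead performs the per-chunk running sum with `tl.cumsum`, which is why an older Triton fails at this point.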