Triton Error [CUDA]: device kernel image is invalid #386

Open

rationalspark opened this issue Jun 12, 2024 · 7 comments

@rationalspark

Thanks for the wonderful work.

When running Mamba2, I encountered the error "Triton Error [CUDA]: device kernel image is invalid".

Would you be so kind as to provide some advice?

My environment is:
torch 2.3.0+cu118
triton 2.3.0
The GPU is an RTX 3090.

The code is:

```python
import torch
from mamba_ssm import Mamba2

batch, length, dim = 2, 512, 256
x = torch.randn(batch, length, dim).to("cuda:2")
model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim,  # Model dimension d_model
    d_state=64,   # SSM state expansion factor, typically 64 or 128
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
    headdim=128,
).to("cuda:2")
y = model(x)
assert y.shape == x.shape
```

The error message is:


RuntimeError Traceback (most recent call last)
Cell In[3], line 13
4 x = torch.randn(batch, length, dim).to("cuda:2")
5 model = Mamba2(
6 # This module uses roughly 3 * expand * d_model^2 parameters
7 d_model=dim, # Model dimension d_model
(...)
11 headdim=128
12 ).to("cuda:2")
---> 13 y = model(x)
14 assert y.shape == x.shape

File ~/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)

File ~/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None

File ~/work/ts/mamba/mamba_ssm/modules/mamba2.py:176, in Mamba2.forward(self, u, seqlen, seq_idx, inference_params)
174 dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
175 if self.use_mem_eff_path and inference_params is None:
--> 176 out = mamba_split_conv1d_scan_combined(
177 zxbcdt,
178 rearrange(self.conv1d.weight, "d 1 w -> d w"),
179 self.conv1d.bias,
180 self.dt_bias,
181 A,
182 D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D,
183 chunk_size=self.chunk_size,
184 seq_idx=seq_idx,
185 activation=self.activation,
186 rmsnorm_weight=self.norm.weight if self.rmsnorm else None,
187 rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6,
188 outproj_weight=self.out_proj.weight,
189 outproj_bias=self.out_proj.bias,
190 headdim=None if self.D_has_hdim else self.headdim,
191 ngroups=self.ngroups,
192 norm_before_gate=self.norm_before_gate,
193 **dt_limit_kwargs,
194 )
195 if seqlen_og is not None:
196 out = rearrange(out, "b l d -> (b l) d")

File ~/work/ts/mamba/mamba_ssm/ops/triton/ssd_combined.py:908, in mamba_split_conv1d_scan_combined(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
889 def mamba_split_conv1d_scan_combined(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu", rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, ngroups=1, norm_before_gate=True):
890 """
891 Argument:
892 zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
(...)
906 out: (batch, seqlen, dim)
907 """
--> 908 return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)

File ~/anaconda3/lib/python3.9/site-packages/torch/autograd/function.py:598, in Function.apply(cls, *args, **kwargs)
595 if not torch._C._are_functorch_transforms_active():
596 # See NOTE: [functorch vjp and autograd interaction]
597 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 598 return super().apply(*args, **kwargs) # type: ignore[misc]
600 if not is_setup_ctx_defined:
601 raise RuntimeError(
602 "In order to use an autograd.Function with functorch transforms "
603 "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
604 "staticmethod. For more details, please see "
605 "https://pytorch.org/docs/master/notes/extending.func.html"
606 )

File ~/anaconda3/lib/python3.9/site-packages/torch/cuda/amp/autocast_mode.py:115, in custom_fwd.<locals>.decorate_fwd(*args, **kwargs)
113 if cast_inputs is None:
114 args[0]._fwd_used_autocast = torch.is_autocast_enabled()
--> 115 return fwd(*args, **kwargs)
116 else:
117 autocast_context = torch.is_autocast_enabled()

File ~/work/ts/mamba/mamba_ssm/ops/triton/ssd_combined.py:773, in MambaSplitConv1dScanCombinedFn.forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
771 out = torch.cat([_swiglu_fwd(zx0), out], dim=-1)
772 else:
--> 773 out_x, _, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit)
774 # reshape input data into 2D tensor
775 x_rms = rearrange(out_x, "b s h p -> (b s) (h p)")

File ~/work/ts/mamba/mamba_ssm/ops/triton/ssd_combined.py:307, in _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, dt_softplus, dt_limit)
302 assert initial_states.shape == (batch, nheads, headdim, dstate)
303 # # (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, nheads, chunk_size, chunk_size)
304 # dA_cumsum_tmp0, dt_tmp0 = _chunk_cumsum_fwd(dt[:, :147], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
305 # dA_cumsum_tmp1, dt_tmp1 = _chunk_cumsum_fwd(dt[:, 147:], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
306 # dA_cumsum_tmp2, dt_tmp2 = _chunk_cumsum_fwd(dt[:, 147:256], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
--> 307 dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit)
308 states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
309 # states_tmp0 = _chunk_state_fwd(B[:, :147], x[:, :147], dt_tmp0, dA_cumsum_tmp0, states_in_fp32=True)
310 # states_tmp1 = _chunk_state_fwd(B[:, 147:], x[:, 147:], dt_tmp1, dA_cumsum_tmp1, states_in_fp32=True)
311 # states_tmp2 = _chunk_state_fwd(B[:, 147:256], x[:, 147:256], dt_tmp2, dA_cumsum_tmp2, states_in_fp32=True)

File ~/work/ts/mamba/mamba_ssm/ops/triton/ssd_chunk_state.py:582, in _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias, dt_softplus, dt_limit)
580 grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H']))
581 with torch.cuda.device(dt.device.index):
--> 582 _chunk_cumsum_fwd_kernel[grid_chunk_cs](
583 dt, A, dt_bias, dt_out, dA_cumsum,
584 batch, seqlen, nheads, chunk_size,
585 dt_limit[0], dt_limit[1],
586 dt.stride(0), dt.stride(1), dt.stride(2),
587 A.stride(0),
588 dt_bias.stride(0) if dt_bias is not None else 0,
589 dt_out.stride(0), dt_out.stride(2), dt_out.stride(1), dt_out.stride(3),
590 dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
591 dt_softplus,
592 HAS_DT_BIAS=dt_bias is not None,
593 BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
594 )
595 return dA_cumsum, dt_out

File ~/anaconda3/lib/python3.9/site-packages/triton/runtime/jit.py:167, in KernelInterface.__getitem__.<locals>.<lambda>(*args, **kwargs)
161 def __getitem__(self, grid) -> T:
162 """
163 A JIT function is launched with: fn[grid](*args, **kwargs).
164 Hence JITFunction.__getitem__ returns a callable proxy that
165 memorizes the grid.
166 """
--> 167 return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)

File ~/anaconda3/lib/python3.9/site-packages/triton/runtime/autotuner.py:143, in Autotuner.run(self, *args, **kwargs)
141 pruned_configs = self.prune_configs(kwargs)
142 bench_start = time.time()
--> 143 timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
144 bench_end = time.time()
145 self.bench_time = bench_end - bench_start

File ~/anaconda3/lib/python3.9/site-packages/triton/runtime/autotuner.py:143, in <dictcomp>(.0)
141 pruned_configs = self.prune_configs(kwargs)
142 bench_start = time.time()
--> 143 timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
144 bench_end = time.time()
145 self.bench_time = bench_end - bench_start

File ~/anaconda3/lib/python3.9/site-packages/triton/runtime/autotuner.py:122, in Autotuner._bench(self, config, *args, **meta)
119 self.post_hook(args)
121 try:
--> 122 return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
123 except OutOfResources:
124 return [float("inf"), float("inf"), float("inf")]

File ~/anaconda3/lib/python3.9/site-packages/triton/testing.py:102, in do_bench(fn, warmup, rep, grad_to_none, quantiles, fast_flush, return_mode)
83 import torch
84 """
85 Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
86 the 20-th and 80-th performance percentile.
(...)
99 :type fast_flush: bool
100 """
--> 102 fn()
103 torch.cuda.synchronize()
105 # We maintain a buffer of 256 MB that we clear
106 # before each kernel call to make sure that the L2
107 # doesn't contain any input data before the run

File ~/anaconda3/lib/python3.9/site-packages/triton/runtime/autotuner.py:110, in Autotuner._bench.<locals>.kernel_call()
108 config.pre_hook(full_nargs)
109 self.pre_hook(args)
--> 110 self.fn.run(
111 *args,
112 num_warps=config.num_warps,
113 num_stages=config.num_stages,
114 num_ctas=config.num_ctas,
115 enable_warp_specialization=config.enable_warp_specialization,
116 # enable_persistent=False,
117 **current,
118 )
119 self.post_hook(args)

File ~/anaconda3/lib/python3.9/site-packages/triton/runtime/jit.py:425, in JITFunction.run(self, grid, warmup, *args, **kwargs)
423 if not warmup:
424 args = [arg.value for arg in args if not arg.param.is_constexpr]
--> 425 kernel.run(grid_0, grid_1, grid_2, kernel.num_warps, kernel.num_ctas, # number of warps/ctas per instance
426 kernel.cluster_dims[0], kernel.cluster_dims[1], kernel.cluster_dims[2], # cluster
427 kernel.shared, stream, kernel.function, CompiledKernel.launch_enter_hook,
428 CompiledKernel.launch_exit_hook, kernel,
429 *driver.assemble_tensormap_to_arg(kernel.metadata["tensormaps_info"], args))
430 return kernel

File ~/anaconda3/lib/python3.9/site-packages/triton/compiler/compiler.py:255, in CompiledKernel.__getattribute__(self, name)
253 def __getattribute__(self, name):
254 if name == 'run':
--> 255 self._init_handles()
256 return super().__getattribute__(name)

File ~/anaconda3/lib/python3.9/site-packages/triton/compiler/compiler.py:250, in CompiledKernel._init_handles(self)
248 raise OutOfResources(self.shared, max_shared, "shared memory")
249 # TODO: n_regs, n_spills should be metadata generated when calling ptxas
--> 250 self.module, self.function, self.n_regs, self.n_spills = driver.utils.load_binary(
251 self.name, self.kernel, self.shared, device)

RuntimeError: Triton Error [CUDA]: device kernel image is invalid

Thank you very much for all your assistance.

@tridao
Collaborator

tridao commented Jun 12, 2024

It's a Triton error. I don't know how to fix it, but you can search the Triton repo issues.

@catalpaaa

catalpaaa commented Jun 12, 2024

try to build mamba and causal conv1d yourself with pip install -e, iirc

or maybe build triton from source
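
For reference, the from-source install usually looks something like this (just a sketch of the suggestion above, assuming the upstream state-spaces/mamba and Dao-AILab/causal-conv1d repos; adjust paths and versions to your setup):

```
# sketch: build causal-conv1d and mamba-ssm against the local CUDA toolkit
git clone https://github.com/Dao-AILab/causal-conv1d
cd causal-conv1d && pip install -e . && cd ..

git clone https://github.com/state-spaces/mamba
cd mamba && pip install -e .
```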

@jsie7

jsie7 commented Jun 12, 2024

I encountered this before as well. It seems as if mamba/triton is using binaries for a different cuda version, hence the invalidity error. Either build from source to ensure the correct version or manually select the correct binaries.

@catalpaaa

> I encountered this before as well. It seems as if mamba/triton is using binaries for a different cuda version, hence the invalidity error. Either build from source to ensure the correct version or manually select the correct binaries.

Yes, check whether your CUDA_HOME is pointing to a different CUDA installation.
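
A quick way to see which CUDA toolkit is actually being picked up (a minimal diagnostic sketch; the values in the comments are only examples):

```python
# Minimal diagnostic sketch: compare the CUDA version PyTorch was built with,
# the toolkit that CUDA_HOME points to, and the GPU's compute capability.
import os
import torch
from torch.utils.cpp_extension import CUDA_HOME

print("torch:", torch.__version__)
print("torch built with CUDA:", torch.version.cuda)        # e.g. '11.8'
print("CUDA_HOME (env):", os.environ.get("CUDA_HOME"))
print("CUDA_HOME (torch):", CUDA_HOME)                      # e.g. /usr/local/cuda-11.8
print("GPU:", torch.cuda.get_device_name(0),
      torch.cuda.get_device_capability(0))                  # RTX 3090 -> (8, 6)
```

If the toolkit under CUDA_HOME differs from the CUDA version torch reports, that is consistent with the version mismatch described above.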

@JHChen1

JHChen1 commented Jun 13, 2024

> I encountered this before as well. It seems as if mamba/triton is using binaries for a different CUDA version, hence the invalidity error. Either build from source to ensure the correct version or manually select the correct binaries.

Hello, I am currently using torch==2.0.1, triton==2.3.0, mamba_ssm==2.0.3 with CUDA 11.8 on a V100, and I have this problem: the code runs correctly on cuda:0 but reports an error on cuda:1: "RuntimeError: Triton Error [CUDA]: context is destroyed". Can you give me some advice?

@tridao
Collaborator

tridao commented Jun 13, 2024

You can try upgrading PyTorch, though I don't think Triton supports V100 very well in general.
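
If upgrading does not help, a workaround that is often suggested for Triton errors that only show up on a non-default GPU (an untested sketch, not something verified in this thread) is to expose only the target GPU so it becomes cuda:0 inside the process:

```python
# Untested workaround sketch: make physical GPU 1 the only visible device so
# all Triton kernels are launched on what the process sees as cuda:0.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # must be set before CUDA is initialized

import torch
from mamba_ssm import Mamba2

x = torch.randn(2, 512, 256, device="cuda")
model = Mamba2(d_model=256, d_state=64, d_conv=4, expand=2, headdim=128).to("cuda")
y = model(x)
assert y.shape == x.shape
```

The same effect can be had by launching the script with CUDA_VISIBLE_DEVICES=1 set in the environment.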

@rationalspark
Author

Thanks for all the replies. I tried building Mamba from source and installing the latest triton 2.3.0, but the "Triton Error [CUDA]: device kernel image is invalid" still occurs. I also tried building Triton from source, but the compilation fails.
