Mamba2 assertion error #414

Open · wyc1997 opened this issue Jun 21, 2024 · 0 comments
wyc1997 commented Jun 21, 2024

Hi, when running the example inference benchmark on Mamba2:

python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba2-2.7b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2

An assertion error on the shape of dt is raised:

$ python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba2-2.7b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2

Loading model state-spaces/mamba2-2.7b
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Number of parameters: 2702599680
Traceback (most recent call last):
  File "/home/ec2-user/workspace/mamba/benchmarks/benchmark_generation_mamba_simple.py", line 82, in <module>
    out = fn()
  File "/home/ec2-user/workspace/mamba/benchmarks/benchmark_generation_mamba_simple.py", line 56, in <lambda>
    fn = lambda: model.generate(
  File "/home/ec2-user/workspace/mamba/mamba_ssm/utils/generation.py", line 260, in generate
    output = decode(
  File "/opt/conda/envs/mamba/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ec2-user/workspace/mamba/mamba_ssm/utils/generation.py", line 221, in decode
    scores.append(get_logits(sequences[-1], inference_params))
  File "/home/ec2-user/workspace/mamba/mamba_ssm/utils/generation.py", line 184, in get_logits
    logits = model(
  File "/opt/conda/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ec2-user/workspace/mamba/mamba_ssm/models/mixer_seq_simple.py", line 281, in forward
    hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs)
  File "/opt/conda/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ec2-user/workspace/mamba/mamba_ssm/models/mixer_seq_simple.py", line 195, in forward
    hidden_states, residual = layer(
  File "/opt/conda/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ec2-user/workspace/mamba/mamba_ssm/modules/block.py", line 67, in forward
    hidden_states = self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs)
  File "/opt/conda/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ec2-user/workspace/mamba/mamba_ssm/modules/mamba2.py", line 226, in forward
    y = mamba_chunk_scan_combined(
  File "/home/ec2-user/workspace/mamba/mamba_ssm/ops/triton/ssd_combined.py", line 563, in mamba_chunk_scan_combined
    return MambaChunkScanCombinedFn.apply(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, dt_softplus, dt_limit, return_final_states)
  File "/opt/conda/envs/mamba/lib/python3.10/site-packages/torch/autograd/function.py", line 598, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/ec2-user/workspace/mamba/mamba_ssm/ops/triton/ssd_combined.py", line 529, in forward
    out, out_x, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=dt_softplus, dt_limit=dt_limit)
  File "/home/ec2-user/workspace/mamba/mamba_ssm/ops/triton/ssd_combined.py", line 286, in _mamba_chunk_scan_combined_fwd
    assert dt.shape == (batch, seqlen, nheads)
AssertionError

Further inspection shows that at the point of the AssertionError:
dt.shape == torch.Size([1, 14, 80])
x.shape == torch.Size([1, 17, 80, 64])
B.shape == torch.Size([1, 17, 1, 128])
The seqlen of x is 3 larger than that of dt, which is what trips the assertion. I wonder if anyone else is getting this error and what could be causing it.
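For reference, the off-by-3 equals d_conv - 1 for the default d_conv = 4. Below is a minimal sketch of how such a mismatch could arise; it is my own illustration, not code from this repo, and whether this is the actual code path taken here is an assumption. The idea: dt is not convolved, so if a causal depthwise conv with padding = d_conv - 1 is applied to x and its output is never sliced back to seqlen, x ends up exactly 3 steps longer than dt.

```python
# Illustrative sketch only -- not code from this repo. It shows how a causal
# depthwise conv with padding = d_conv - 1, whose output is never sliced back
# to seqlen, yields exactly the 3-step mismatch observed above (d_conv = 4).
import torch
import torch.nn as nn

batch, seqlen, nheads, headdim, d_conv = 1, 14, 80, 64, 4
d_inner = nheads * headdim  # 5120 for mamba2-2.7b

dt = torch.randn(batch, seqlen, nheads)   # dt does not go through the conv
x_in = torch.randn(batch, seqlen, d_inner)

# Conv output length = seqlen + 2*padding - (d_conv - 1) = seqlen + 3
conv = nn.Conv1d(d_inner, d_inner, kernel_size=d_conv,
                 padding=d_conv - 1, groups=d_inner)
x = conv(x_in.transpose(1, 2)).transpose(1, 2)  # no [:, :seqlen] truncation

print(x.shape)   # torch.Size([1, 17, 5120]) -> 17 = 14 + 3
print(dt.shape)  # torch.Size([1, 14, 80])

# Mirrors the failing check in _mamba_chunk_scan_combined_fwd:
assert dt.shape == (batch, x.shape[1], nheads)  # raises AssertionError
```

If the conv output were truncated back to seqlen, x and dt would agree again, so the fallback (non-causal_conv1d) conv path may be worth checking.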
