[BUG] RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! #181

Open
dddlli opened this issue Mar 31, 2024 · 4 comments
Labels: bug (Something isn't working)


@dddlli

dddlli commented Mar 31, 2024

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_CUDA_mm)

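```python
import torch

from zeta.nn import MambaBlock

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Note: the block is never moved to `device`, so its parameters are
# created on the CPU (PyTorch's default).
block = MambaBlock(dim=64, depth=1)

x = torch.randn(1, 10, 64).to(device)

y = block(x).to(device)

print(y.shape)
```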
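The `mat2` operand of `mm` in this message is most likely one of the block's weights: the snippet moves `x` to `device` but not `block`, so the matrix multiply mixes a CUDA input with CPU parameters. A minimal sketch of the same pattern, using a plain `nn.Linear` as a stand-in for `MambaBlock` (an assumption; it does not reproduce zeta's internals):

```python
import torch
from torch import nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Stand-in for MambaBlock: its parameters are created on the CPU by default.
layer = nn.Linear(64, 64)

# Create the input directly on the target device.
x = torch.randn(1, 10, 64, device=device)

# Move the module as well as the input. nn.Module.to() moves parameters
# and buffers in place and returns the module itself.
layer = layer.to(device)

y = layer(x)  # the output is already on `device`; no extra .to() is needed
print(y.shape, y.device)
```

Moving the block alone may not be enough, though: the traceback reported further down in this thread points at a tensor created inside `selective_scan` itself.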

dddlli added the bug (Something isn't working) label Mar 31, 2024

Hello there, thank you for opening an issue! 🙏🏻 The team was notified and they will get back to you ASAP.

@trefoil0219

trefoil0219 commented Apr 9, 2024

The same bug:

```
  File "/home/miniconda3/lib/python3.11/site-packages/zeta/nn/modules/simple_mamba.py", line 205, in selective_scan
    x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
        ~~~~~~~~~~~~~~~~^~~
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```
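The failing line is the recurrence step of a sequential selective scan: the running state `x` is multiplied by `deltaA` and then `deltaB_u` is added. A minimal sketch of that loop, assuming `deltaA` and `deltaB_u` have shape `(b, d_in, l, n)` (inferred from the indexing in the traceback and from the shape in the fix proposed below, not copied from zeta's source). `selective_scan_sketch` is a hypothetical name, and the sketch returns only the stacked hidden states:

```python
import torch


def selective_scan_sketch(deltaA: torch.Tensor, deltaB_u: torch.Tensor) -> torch.Tensor:
    """Hypothetical sequential scan modeled on the line in the traceback.

    deltaA, deltaB_u: tensors of shape (b, d_in, l, n).
    Returns the hidden state after every step, shape (b, d_in, l, n).
    """
    b, d_in, l, n = deltaA.shape
    # A bare torch.zeros((b, d_in, n)) is allocated on the default (CPU)
    # device, which triggers the mismatch when deltaA lives on CUDA.
    # Allocating with deltaA's device and dtype avoids that.
    x = torch.zeros((b, d_in, n), device=deltaA.device, dtype=deltaA.dtype)
    states = []
    for i in range(l):
        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
        states.append(x)
    # Stack along the sequence axis: list of (b, d_in, n) -> (b, d_in, l, n).
    return torch.stack(states, dim=2)


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
deltaA = torch.rand(2, 4, 5, 3, device=device)  # values in [0, 1), like a decay
deltaB_u = torch.randn(2, 4, 5, 3, device=device)
h = selective_scan_sketch(deltaA, deltaB_u)
print(h.shape, h.device)
```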

@xzytza

xzytza commented Apr 10, 2024

Add `.to(x.device)`, but I think it is inefficient.
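One likely source of the inefficiency: `torch.zeros(...).to(x.device)` first allocates the tensor in host memory and then copies it to the GPU on every call. A sketch of allocating directly on the target device instead, assuming a reference tensor `ref` (hypothetical; it stands in for any tensor already on the right device, such as `deltaA` in the scan):

```python
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
ref = torch.randn(2, 4, 5, 3, device=device)  # e.g. deltaA in the scan
b, d_in, _, n = ref.shape

# Workaround from the comment: allocate on the CPU, then copy to the device.
x_copy = torch.zeros((b, d_in, n)).to(ref.device)

# Allocate directly on the target device; no host tensor and no copy.
x_direct = torch.zeros((b, d_in, n), device=ref.device, dtype=ref.dtype)

# Tensor.new_zeros inherits both device and dtype from `ref`.
x_new = ref.new_zeros((b, d_in, n))

print(x_copy.device, x_direct.device, x_new.device)
```

All three produce the same zero tensor on the same device; the last two simply skip the host allocation and transfer.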

@Bool1020
Contributor

Replace the code at line 202 of `zeta/nn/modules/simple_mamba.py` with:

```python
x = torch.zeros((b, d_in, n)).to(next(self.parameters()).device)
```
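A sketch of how the proposed line behaves, using a hypothetical `StateInit` module rather than zeta's actual class: `next(self.parameters()).device` is the device of the module's first parameter, so it matches the input only once the module itself has been moved with `.to(device)`.

```python
import torch
from torch import nn


class StateInit(nn.Module):
    """Hypothetical module that allocates a zero state like the proposed fix."""

    def __init__(self, d_in: int, n: int):
        super().__init__()
        self.proj = nn.Linear(d_in, d_in)
        self.n = n

    def forward(self, u: torch.Tensor) -> torch.Tensor:
        b, d_in = u.shape
        # Same pattern as the proposed fix: follow the module's parameters.
        x = torch.zeros((b, d_in, self.n)).to(next(self.parameters()).device)
        # (b, d_in, n) + (b, d_in, 1) broadcasts to (b, d_in, n).
        return x + self.proj(u).unsqueeze(-1)


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
m = StateInit(d_in=4, n=3).to(device)  # the module must be moved too
u = torch.randn(2, 4, device=device)
out = m(u)
print(out.shape, out.device)
```

One trade-off of this pattern: it follows the parameters rather than the inputs, so it still fails if the module stays on the CPU while the input is on CUDA (as in the original snippet), and `next()` raises `StopIteration` for a module with no parameters.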

Bool1020 added a commit to Bool1020/zeta that referenced this issue Apr 11, 2024
Development

No branches or pull requests

5 participants