Dimension issues with predict method #8
Comments
I also encountered this issue. It arises from the rescaling line in the "forward" method of the DecoderBlock class in unet.py. I modified the code to interpolate directly to the correct shape, which ensures that the concatenation proceeds as intended. I'm not pushing this until I've tested that it doesn't break anything.
One should probably also change how the target_size tuple is defined, so that it is built by indexing from the back of skip.shape.
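For readers hitting the same error, here is a minimal sketch of the idea described in the two comments above. It is not the actual patch (which is not shown in this thread): the helper name upsample_and_concat, the "nearest" interpolation mode, and the 7x7 pre-upsampling size of x are assumptions made for illustration only.

```python
import torch
import torch.nn.functional as F


def upsample_and_concat(x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: resize x to skip's spatial size, then concatenate."""
    # Index from the back of skip.shape so that only (H, W) is picked up,
    # whatever the leading dimensions are.
    target_size = skip.shape[-2:]
    # Interpolate straight to the target size instead of using a fixed
    # scale factor of 2, which would give 14x14 here. The mode is an assumption.
    x = F.interpolate(x, size=target_size, mode="nearest")
    # Spatial sizes now agree, so concatenating along channels is valid.
    return torch.cat([x, skip], dim=1)


# Shapes modelled on the report: skip is 13x13, x is assumed to be 7x7
# before upsampling.
x = torch.randn(1, 2048, 7, 7)
skip = torch.randn(1, 1024, 13, 13)
out = upsample_and_concat(x, skip)
print(out.shape)
```

With a fixed scale factor the decoder can only produce even sizes, so taking the size from the skip tensor is what lets odd-sized feature maps line up.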
Hello,
I am trying to use your framework to extract masks from my images with the following code:
For the moment I just print the masks to check what they look like before saving them. However, I get the following error when I run the code:
Traceback (most recent call last):
  File "hidden/mask.py", line 19, in <module>
    mask = model.predict(tensor)
  File "env/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "env/lib/python3.10/site-packages/backbones_unet/model/unet.py", line 160, in predict
    x = self.forward(x)
  File "env/lib/python3.10/site-packages/backbones_unet/model/unet.py", line 142, in forward
    x = self.decoder(x)
  File "env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "env/lib/python3.10/site-packages/backbones_unet/model/unet.py", line 304, in forward
    x = b(x, skip)
  File "env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "env/lib/python3.10/site-packages/backbones_unet/model/unet.py", line 246, in forward
    x = torch.cat([x, skip], dim=1)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 14 but got size 13 for tensor number 1 in the list.
When I print the dimensions of my input tensor and of the x and skip tensors that are being concatenated, I get, in that order:
torch.Size([1, 3, 207, 204])
torch.Size([1, 2048, 14, 14])
torch.Size([1, 1024, 13, 13])
Do you know where this comes from and how to fix it?
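The 14-versus-13 mismatch can be reproduced with plain arithmetic. The sketch below assumes an encoder with five stride-2 stages, each mapping a side length n to ceil(n / 2), as in ResNet-style backbones. The backbone is not stated in this issue, although the 2048 and 1024 channel counts are consistent with that assumption. The helper names encoder_sizes and pad_to_multiple are hypothetical.

```python
import math


def encoder_sizes(n, stages=5):
    """Side length after each assumed stride-2 stage (n -> ceil(n / 2))."""
    sizes = []
    for _ in range(stages):
        n = math.ceil(n / 2)
        sizes.append(n)
    return sizes


def pad_to_multiple(n, multiple=32):
    """Smallest multiple of `multiple` that is >= n."""
    return math.ceil(n / multiple) * multiple


h_sizes = encoder_sizes(207)  # height of the input in the report
w_sizes = encoder_sizes(204)  # width of the input in the report

# The deepest feature map is doubled by the decoder and then concatenated
# with the skip tensor from the stage before it.
upsampled_h = h_sizes[-1] * 2
skip_h = h_sizes[-2]

print(h_sizes, w_sizes)
print("upsampled:", upsampled_h, "skip:", skip_h)
print("padded input:", pad_to_multiple(207), pad_to_multiple(204))
```

Under these assumptions the deepest map is 7 pixels per side, so doubling it gives 14 while the skip tensor is 13, which matches the error message. One possible workaround, untested against this library, is to pad or resize the input so both sides are multiples of 32 (224 x 224 here), which keeps every stage even, and then crop the predicted mask back to the original size.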