class_weights cannot be passed via config file as a tensor is expected #2060
Comments
Care to make a PR to accept a list and convert it to a tensor? If not, I can take it on this weekend.
You will get to it way before me!
For a bit of history, I added this in #1221, and it initially only supported lists. In #1413, @ntw-au modified it to support lists, NumPy arrays, and torch tensors. Then in #1541, I modified it to only accept torch tensors. I agree we need a way to support …
If you want to use it with `hydra.utils.instantiate` and OmegaConf, you would only need the following:

```yaml
class_weights:
  _target_: torch.tensor
  data: [0.5, 0.5]
```
I haven't looked at the Lightning CLI in a while, but I wonder if it supports recursive instantiation like:

```yaml
class_weights:
  class_path: torch.tensor
  init_args:
    data: [0.5, 0.5]
```
Description
Using the Lightning CLI, we can train the `SemanticSegmentationTask`, but we cannot use `class_weights` without an error. The solution is to accept a list of ints in addition to a tensor.
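Below is a minimal sketch of that idea, assuming the task ultimately forwards `class_weights` to `torch.nn.CrossEntropyLoss`. `SegmentationLossSketch` is a hypothetical stand-in, not torchgeo's actual `SemanticSegmentationTask`. It is also an assumption that widening the type annotation to include a sequence is enough for the Lightning CLI (via jsonargparse) to accept a plain YAML list. `torch.as_tensor` converts lists, NumPy arrays, and tensors alike. The explicit `float32` dtype lets integer lists such as `[1, 2]` work too, since the loss expects weights in the same floating-point dtype as the logits.

```python
from collections.abc import Sequence
from typing import Optional, Union

import numpy as np
import torch
from torch import Tensor, nn


class SegmentationLossSketch(nn.Module):
    """Hypothetical stand-in for a task that accepts ``class_weights``.

    The annotation lists ``Sequence[float]`` next to ``Tensor`` so that a
    YAML list such as ``[0.5, 0.5]`` could type-check in a config parser
    (an assumption), while tensors and NumPy arrays passed from Python
    keep working.
    """

    def __init__(
        self,
        class_weights: Optional[Union[Tensor, np.ndarray, Sequence[float]]] = None,
    ) -> None:
        super().__init__()
        weight = None
        if class_weights is not None:
            # torch.as_tensor handles lists, tuples, NumPy arrays and tensors;
            # casting to float32 matches the default dtype of the logits.
            weight = torch.as_tensor(class_weights, dtype=torch.float32)
        # With weight=None the loss is unweighted.
        self.criterion = nn.CrossEntropyLoss(weight=weight)

    def forward(self, logits: Tensor, target: Tensor) -> Tensor:
        # logits: (N, C, H, W); target: (N, H, W) with integer class indices.
        return self.criterion(logits, target)
```

For example, `SegmentationLossSketch([0.5, 0.5]).criterion.weight` is a `float32` tensor, while `SegmentationLossSketch()` leaves the loss unweighted.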
Steps to reproduce
In Lightning CLI YAML:
Will result in
Version
main