class_weights cannot be passed via config file as a tensor is expected #2060

Open
robmarkcole opened this issue May 15, 2024 · 5 comments
Labels
scripts Training and evaluation scripts trainers PyTorch Lightning trainers

Comments

@robmarkcole
Contributor

Description

Using the Lightning CLI, we can train the SemanticSegmentationTask, but passing class_weights results in an error. The solution is to accept a list of ints in addition to a tensor.

Steps to reproduce

In Lightning CLI Yaml:

model:
  class_path: SemanticSegmentationTask
  init_args:
    model: unet
    backbone: resnet50
    weights: null
    lr: 0.001
    in_channels: 6
    num_classes: 2
    class_weights:
      - 1
      - 50

Will result in

      Does not validate against any of the Union subtypes
      Subtypes: (<class 'torch.Tensor'>, <class 'NoneType'>)
      Errors:
        - Not a valid subclass of Tensor
          Subclass types expect one of:
          - a class path (str)
          - a dict with class_path entry
          - a dict without class_path but with init_args entry (class path given previously)
        - Expected a <class 'NoneType'>
      Given value type: <class 'list'>
      Given value: [1, 50]

Version

main

@isaaccorley
Collaborator

Care to make a PR to accept a list and convert it to a tensor? If not, I can take it on this weekend.
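A minimal sketch of the change being proposed, under a few assumptions: the task keeps a `class_weights` constructor argument, its annotation is widened to something like `Optional[Union[Sequence[float], Tensor]]` so the CLI parser can accept a YAML list, and whatever arrives is normalized to a `float32` tensor before it reaches the loss. The helper name `as_class_weights` is hypothetical, not an existing torchgeo function. I have not checked how jsonargparse coerces integer entries such as `[1, 50]` against a `float` annotation; writing `[1.0, 50.0]` in the YAML avoids the question.

```python
from collections.abc import Sequence
from typing import Optional, Union

import torch
from torch import Tensor


def as_class_weights(
    class_weights: Optional[Union[Sequence[float], Tensor]],
) -> Optional[Tensor]:
    """Normalize class weights from a YAML list or a tensor to a float tensor.

    Hypothetical helper: a task's __init__ could call this so that a list
    annotation lets the CLI accept a YAML list while the loss still gets a tensor.
    """
    if class_weights is None:
        return None
    # as_tensor accepts both Python sequences and existing tensors;
    # integer weights such as [1, 50] are cast to float32 here.
    return torch.as_tensor(class_weights, dtype=torch.float32)


# Usage: a plain list from the config ends up as a float weight tensor.
weights = as_class_weights([1, 50])
loss_fn = torch.nn.CrossEntropyLoss(weight=weights)
```

Using `torch.as_tensor` means an existing tensor passes through (cast to `float32` if needed), so callers that already pass a tensor are unaffected.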

@robmarkcole
Contributor Author

You will get to it way before me!

@adamjstewart
Collaborator

For a bit of history, I added this in #1221 and it initially only supported lists. In #1413, @ntw-au modified this to support lists, numpy arrays, and torch tensors. Then in #1541, I modified it to only accept torch tensors. I agree we need a way to support class_weights in a YAML file (and preferably also on the command line). If omegaconf supports this, we could also easily enable omegaconf as a parser: https://lightning.ai/docs/pytorch/stable/cli/lightning_cli_advanced_2.html#enable-variable-interpolation.

@adamjstewart adamjstewart added trainers PyTorch Lightning trainers scripts Training and evaluation scripts labels May 15, 2024
@isaaccorley
Collaborator

isaaccorley commented May 15, 2024

If you want to use it with hydra.utils.instantiate and omegaconf, you would only need the following:

class_weights:
   _target_: torch.tensor
   data: [0.5, 0.5]
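Roughly speaking, `instantiate` imports the dotted path under `_target_` and calls it with the remaining keys as keyword arguments, so the snippet above amounts to `torch.tensor(data=[0.5, 0.5])`. The sketch below is a deliberately simplified, hypothetical re-implementation of that single step (`resolve_target` is not Hydra's code; the real `instantiate` also handles recursion, positional `_args_`, and OmegaConf container conversion). It uses standard-library targets so it runs without torch or Hydra installed.

```python
import importlib
from typing import Any


def resolve_target(config: dict[str, Any]) -> Any:
    """Simplified, hypothetical stand-in for Hydra's ``_target_`` handling."""
    params = dict(config)  # copy so the caller's config is not mutated
    module_path, _, attr_name = params.pop("_target_").rpartition(".")
    # e.g. import "torch" from "torch.tensor", then fetch its "tensor" attribute
    target = getattr(importlib.import_module(module_path), attr_name)
    # Every remaining key becomes a keyword argument to the target.
    return target(**params)


# Standard-library example; with torch installed,
# {"_target_": "torch.tensor", "data": [0.5, 0.5]} would resolve the same way.
ratio = resolve_target({"_target_": "fractions.Fraction", "numerator": 1, "denominator": 3})
```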

@isaaccorley
Collaborator

isaaccorley commented May 15, 2024

I haven't looked at the Lightning CLI in a while, but I wonder if it supports recursive instantiation like this:

class_weights:
   class_path: torch.tensor
   init_args:
      data: [0.5, 0.5]
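One possible catch, going only by the parser error in the original report: for a `Tensor` annotation the CLI appears to treat `class_path` as naming a subclass of `Tensor`. If that is how the check works, `torch.tensor` (a factory function, not a class) would likely be rejected even though the `init_args` layout is right. The sketch below models that idea with a hypothetical `instantiate_class_path` helper and standard-library types; it is an illustration, not jsonargparse's implementation, and the actual Lightning CLI behavior has not been verified here.

```python
import importlib
import inspect
from typing import Any


def instantiate_class_path(spec: dict[str, Any], expected: type) -> Any:
    """Hypothetical model of a class_path/init_args resolver with a subclass check."""
    module_path, _, name = spec["class_path"].rpartition(".")
    obj = getattr(importlib.import_module(module_path), name)
    # Mirror the "Not a valid subclass of ..." error from the issue:
    # plain functions and unrelated classes are refused.
    if not (inspect.isclass(obj) and issubclass(obj, expected)):
        raise TypeError(f"{spec['class_path']} is not a subclass of {expected.__name__}")
    return obj(**spec.get("init_args", {}))


# A subclass of the expected type is accepted and built from init_args...
ok = instantiate_class_path({"class_path": "collections.OrderedDict", "init_args": {"a": 1}}, dict)

# ...while a factory function is refused, analogous to torch.tensor vs. Tensor.
try:
    instantiate_class_path({"class_path": "builtins.sorted", "init_args": {}}, list)
except TypeError as err:
    rejected = str(err)
```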
