
Union typehint in conjunction with coders causes error for PyTorchModelHubMixin #2283

Open
gorold opened this issue May 20, 2024 · 4 comments
Labels
bug Something isn't working


@gorold

gorold commented May 20, 2024

### Describe the bug

Using a union type hint together with a custom encoder/decoder (`coders`) raises the error below when the model is loaded with `from_pretrained`.

### Reproduction

```python
from torch import nn
from huggingface_hub import PyTorchModelHubMixin


class CustomArg:
    @classmethod
    def encode(cls, arg): return "custom"

    @classmethod
    def decode(cls, arg): return CustomArg()


class OKModel(
    nn.Module,
    PyTorchModelHubMixin,
    coders={CustomArg: (CustomArg.encode, CustomArg.decode)}
):
    def __init__(self, a: int):
        super().__init__()
        self.a = a


class NotOKModel(
    nn.Module,
    PyTorchModelHubMixin,
    coders={CustomArg: (CustomArg.encode, CustomArg.decode)}
):
    def __init__(self, a: int | float):
        super().__init__()
        self.a = a


ok_model = OKModel(1)
ok_model.save_pretrained("model")
ok_model = OKModel.from_pretrained("model")

not_ok_model = NotOKModel(1)
not_ok_model.save_pretrained("model")
not_ok_model = NotOKModel.from_pretrained("model")
```

### Logs

```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[8], line 26
     24 not_ok_model = NotOKModel(1)
     25 not_ok_model.save_pretrained("model")
---> 26 not_ok_model = NotOKModel.from_pretrained("model")

File .../lib/python3.10/site-packages/huggingface_hub/utils/_validators.py:114, in validate_hf_hub_args.<locals>._inner_fn(*args, **kwargs)
    111 if check_use_auth_token:
    112     kwargs = smoothly_deprecate_use_auth_token(fn_name=fn.__name__, has_token=has_token, kwargs=kwargs)
--> 114 return fn(*args, **kwargs)

File .../lib/python3.10/site-packages/huggingface_hub/hub_mixin.py:472, in ModelHubMixin.from_pretrained(cls, pretrained_model_name_or_path, force_download, resume_download, proxies, token, cache_dir, local_files_only, revision, **model_kwargs)
    470         expected_type = cls._hub_mixin_init_parameters[key].annotation
    471         if expected_type is not inspect.Parameter.empty:
--> 472             config[key] = cls._decode_arg(expected_type, value)
    474 # Populate model_kwargs from config
    475 for param in cls._hub_mixin_init_parameters.values():

File .../lib/python3.10/site-packages/huggingface_hub/hub_mixin.py:317, in ModelHubMixin._decode_arg(cls, expected_type, value)
    315 """Decode a JSON serializable value into an argument."""
    316 for type_, (_, decoder) in cls._hub_mixin_coders.items():
--> 317     if issubclass(expected_type, type_):
    318         return decoder(value)
    319 return value

TypeError: issubclass() arg 1 must be a class
```


### System info

```shell
Copy-and-paste the text below in your GitHub issue.

- huggingface_hub version: 0.23.0
- Platform: Linux-6.1.58+-x86_64-with-glibc2.31
- Python version: 3.10.11
- Running in iPython ?: No
- Running in notebook ?: No
- Running in Google Colab ?: No
- Token path ?: /root/.cache/huggingface/token
- Has saved token ?: True
- Who am I ?: gorold
- Configured git credential helpers: 
- FastAI: N/A
- Tensorflow: N/A
- Torch: 2.2.0
- Jinja2: 3.1.3
- Graphviz: N/A
- keras: N/A
- Pydot: N/A
- Pillow: 10.2.0
- hf_transfer: N/A
- gradio: N/A
- tensorboard: N/A
- numpy: 1.26.4
- pydantic: 2.6.1
- aiohttp: 3.9.3
- ENDPOINT: https://huggingface.co
- HF_HUB_CACHE: /root/.cache/huggingface/hub
- HF_ASSETS_CACHE: /root/.cache/huggingface/assets
- HF_TOKEN_PATH: /root/.cache/huggingface/token
- HF_HUB_OFFLINE: False
- HF_HUB_DISABLE_TELEMETRY: False
- HF_HUB_DISABLE_PROGRESS_BARS: None
- HF_HUB_DISABLE_SYMLINKS_WARNING: False
- HF_HUB_DISABLE_EXPERIMENTAL_WARNING: False
- HF_HUB_DISABLE_IMPLICIT_TOKEN: False
- HF_HUB_ENABLE_HF_TRANSFER: False
- HF_HUB_ETAG_TIMEOUT: 10
- HF_HUB_DOWNLOAD_TIMEOUT: 10
```
gorold added the bug label (Something isn't working) May 20, 2024
@Wauplin
Contributor

Wauplin commented May 21, 2024

Good catch @gorold, thanks for reporting! Would you like to open a PR to fix this? The line `if issubclass(expected_type, type_):` should be fixed to work with unions as well. You will probably have to use `typing.get_args` for this.

@gorold
Author

gorold commented May 21, 2024

Sure, I could take a stab at it, but there are some undefined behaviours that I'd like to clarify first.

  1. If `coders` contains both a parent and a child class, and an argument is type-hinted as the child class, would the parent decoder always be used? `_decode_arg` returns the decoder of the first entry in `coders` for which `issubclass(expected_type, type_)` holds, so if `ParentArg` is listed first, as below, it wins.
```python
class ParentArg: ...

class ChildArg(ParentArg): ...

class Model(
    nn.Module,
    PyTorchModelHubMixin,
    coders={ParentArg: ..., ChildArg: ...}  # (encoder, decoder) pairs elided
):
    def __init__(self, a: ChildArg):
        super().__init__()
        self.a = a
```

This could be straightforward to solve: just check whether `expected_type in cls._hub_mixin_coders` before looping through the dict. But that may not be correct for complex inheritance structures; I guess the ideal behaviour is to use the nearest ancestor.

  2. How should we actually handle unions that have multiple candidate coders?
```python
class Arg1: ...

class Arg2: ...

class Model(
    nn.Module,
    PyTorchModelHubMixin,
    coders={Arg1: ..., Arg2: ...}  # (encoder, decoder) pairs elided
):
    def __init__(self, a: Arg1 | Arg2 | int):
        super().__init__()
        self.a = a
```

This case seems more challenging, since we can't tell for sure which decoder to use.
My proposal would be to use the first one that works:

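```python
# Sketch of the loop inside ModelHubMixin._decode_arg (needs `import typing`)
for etype in typing.get_args(expected_type):
    if etype in cls._hub_mixin_coders:
        _, decoder = cls._hub_mixin_coders[etype]
        try:
            # Use the first decoder that succeeds
            return decoder(value)
        except Exception:
            # This decoder rejected the value; try the next union member
            continue
return value
```

This doesn't check for subclasses, though; I need to think about it a little more.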

@Wauplin
Contributor

Wauplin commented May 24, 2024

Thanks for asking the right questions @gorold! I've dug a bit more into it and I think we should:

  1. Fix the `NotOKModel` case from your example. The annotation `int | float` is not even related to the encoder/decoder, so raising an error here is definitely a bug. I think this can be fixed by first checking whether the annotation is a class: `if inspect.isclass(expected_type) and issubclass(expected_type, type_):`
  2. Handle optional type annotations, i.e. `CustomArg`, `Optional[CustomArg]` and `CustomArg | None` (the last two being equivalent). This is straightforward to do deterministically for both encoding and decoding, and it covers most use cases.
  3. Not handle more complex types. Union/list/tuple/... annotations should simply be ignored, even if the encoder/decoder could have been used. If we later get feedback that this is needed in a practical use case, we can reassess. But let's avoid adding complex logic if it's not necessary.

What do you think?

@gorold
Author

gorold commented May 27, 2024

That makes sense! I'll make a PR based on this
