how to convert custom ckpt to bin? #35

Open · SaltedSlark opened this issue Nov 3, 2023 · 3 comments

@SaltedSlark

Hi @hubertsiuzdak, I am trying to figure out how to convert my ckpt to pytorch_model.bin so that I can load the model with vocos.pretrained. Or is there a way to load the ckpt directly for inference?

@taalua commented Nov 14, 2023

Hey,

You can do it this way:

    import torch

    # a Lightning checkpoint also stores optimizer state, schedulers, etc.;
    # only the model weights under 'state_dict' are needed for inference
    ckpt = torch.load('last.ckpt', map_location='cpu')
    torch.save(ckpt['state_dict'], 'pytorch.bin')
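
If load_state_dict later reports missing or unexpected keys, it can help to compare the checkpoint keys against a freshly built model before saving. A minimal sketch, assuming the training config is available as config.yaml (the file names here are just examples):

    import torch
    from vocos import Vocos

    # build an untrained model from the same config that was used for training
    model = Vocos.from_hparams("config.yaml")

    ckpt = torch.load("last.ckpt", map_location="cpu")
    ckpt_keys = set(ckpt["state_dict"].keys())
    model_keys = set(model.state_dict().keys())

    # any mismatch listed here is silently skipped by load_state_dict(strict=False) below
    print("unexpected keys:", sorted(ckpt_keys - model_keys))
    print("missing keys:", sorted(model_keys - ckpt_keys))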

Then, when loading the model, add a classmethod like this to the Vocos class:

    @classmethod
    def from_model(cls, model_path: str, config_path: str) -> Vocos:
        """
        Class method to create a Vocos model from a locally saved checkpoint and config.
        """
        model = cls.from_hparams(config_path)
        state_dict = torch.load(model_path, map_location="cpu")
        if isinstance(model.feature_extractor, EncodecFeatures):
            # the Encodec weights are not stored in the trained checkpoint,
            # so take them from the freshly instantiated feature extractor
            encodec_parameters = {
                "feature_extractor.encodec." + key: value
                for key, value in model.feature_extractor.encodec.state_dict().items()
            }
            state_dict.update(encodec_parameters)
        model.load_state_dict(state_dict, strict=False)
        model.eval()
        return model
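
Once that classmethod is in place, inference could look roughly like this (file names and the feature shape are placeholders; pass a mel spectrogram with the n_mels from your own config):

    import torch
    from vocos import Vocos

    # from_model is the classmethod sketched above (not part of the released package)
    vocos = Vocos.from_model("pytorch.bin", "config.yaml")

    # placeholder input: mel features of shape (batch, n_mels, frames)
    mel = torch.randn(1, 100, 256)
    with torch.no_grad():
        audio = vocos.decode(mel)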

@SaltedSlark (Author)

Thanks for your reply! I'll give it a try 👍

@dgo2dance

+1
