How to load a trained model from a checkpoint #43

Open
Violonur-PavelBI opened this issue Aug 3, 2022 · 0 comments
Violonur-PavelBI commented Aug 3, 2022

How do I load a trained model from a checkpoint when the model is defined in the settings like this?
classification_model:
  _target_: autoalbument.faster_autoaugment.models.ClassificationModel
  num_classes: 10
  architecture: resnet50 # mobilenetv2_100
  pretrained: False
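In a Hydra-style config, the `_target_` entry names the class to build and the remaining entries are passed to its constructor, so the settings above describe roughly `ClassificationModel(num_classes=10, architecture="resnet50", pretrained=False)`. The sketch below is a simplified, hypothetical stand-in for `hydra.utils.instantiate` (the name `instantiate_from_config` is made up for illustration). It ignores Hydra features such as nested configs and interpolation, and it is demonstrated on a standard-library class so it runs without autoalbument installed.

```python
import importlib
from typing import Any, Dict


def instantiate_from_config(config: Dict[str, Any]) -> Any:
    """Hypothetical, simplified stand-in for hydra.utils.instantiate.

    Imports the class named by the "_target_" key and calls it with the
    remaining entries as keyword arguments.
    """
    params = dict(config)  # copy so the caller's dict is not modified
    target = params.pop("_target_")
    module_name, _, attr_name = target.rpartition(".")
    cls = getattr(importlib.import_module(module_name), attr_name)
    return cls(**params)


if __name__ == "__main__":
    # The config in this issue would resolve to something like:
    #   ClassificationModel(num_classes=10, architecture="resnet50", pretrained=False)
    # Shown here with a standard-library class so the example runs anywhere.
    demo = instantiate_from_config(
        {"_target_": "fractions.Fraction", "numerator": 3, "denominator": 4}
    )
    print(demo)  # 3/4
```

Rebuilding the model from the same config values, rather than from a copy of the class with hard-coded arguments, keeps the architecture and number of classes consistent with what was trained.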

I copied the model class from autoalbument/autoalbument/faster_autoaugment/models/main_model.py:

from typing import Tuple

import pytorch_lightning as pl
import timm
from torch import Tensor, nn

class BaseDiscriminator(pl.LightningModule):
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]:
        raise NotImplementedError


class ClassificationModel(BaseDiscriminator):
    def __init__(self):
        super().__init__()
        # Hard-coded to match the config: resnet50 backbone, 10 classes.
        # pretrained=False as in the config; the trained weights come from the checkpoint.
        self.base_model = timm.create_model("resnet50", pretrained=False)
        self.base_model.reset_classifier(10)
        self.classifier = self.base_model.get_classifier()
        num_features = self.classifier.in_features
        # Discriminator head that outputs one score per image.
        self.discriminator = nn.Sequential(
            nn.Linear(num_features, num_features), nn.ReLU(), nn.Linear(num_features, 1)
        )

    def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]:
        # Returns (class logits, discriminator score per image).
        x = self.base_model.forward_features(input)
        x = self.base_model.global_pool(x).flatten(1)
        return self.classifier(x), self.discriminator(x).view(-1)

I created the model:

model_ft = ClassificationModel()

When I load the weights into the model, I get an error:

model_ft_afteFsAA = model_ft.load_from_checkpoint(
    "/workspace/proj/autoalbument/examples/imageWoof/resnet50/outputs/2022-07-18/10-12-02/checkpoints/last.ckpt",
    batch_size=32,
    num_workers=10,
)

RuntimeError: Error(s) in loading state_dict for ClassificationModel:
Missing key(s) in state_dict: "base_model.conv1.weight" ... "policy_model.sub_policies.99.stages.3.operations.13.saved_image_shape".
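In PyTorch Lightning, `load_from_checkpoint` is a classmethod: it builds a new instance and then loads the checkpoint's `state_dict` (strictly by default), so calling it on `model_ft` does not load anything into `model_ft` itself. The error shows that keys the model expects, such as `base_model.conv1.weight`, are not found as-is, and it also mentions `policy_model.*` keys, which belong to the augmentation policy rather than to `ClassificationModel`. This suggests the checkpoint was saved from the whole search module. The sketch below is one way to diagnose and work around that, assuming the classifier's parameters sit under a single prefix; the prefix `main_model.` and all helper names are hypothetical, so print the prefixes of your own checkpoint first. It uses plain dicts so it runs without PyTorch.

```python
from collections import Counter
from typing import Any, Dict, Iterable, Tuple


def top_level_prefixes(state_dict: Dict[str, Any]) -> Counter:
    """Count state_dict keys by their first dotted component."""
    return Counter(key.split(".", 1)[0] for key in state_dict)


def strip_prefix(state_dict: Dict[str, Any], prefix: str) -> Dict[str, Any]:
    """Keep only keys that start with `prefix` and remove it from them.

    Keys outside the prefix (e.g. "policy_model.*") are dropped.
    """
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


def key_mismatch(expected: Iterable[str], found: Iterable[str]) -> Tuple[set, set]:
    """Return (missing, unexpected) keys, as a strict load would report them."""
    expected_set, found_set = set(expected), set(found)
    return expected_set - found_set, found_set - expected_set


if __name__ == "__main__":
    # Toy stand-ins for real tensors; the "main_model." layout is an assumption.
    checkpoint_state = {
        "main_model.base_model.conv1.weight": "w0",
        "main_model.discriminator.0.weight": "w1",
        "policy_model.sub_policies.99.stages.3.operations.13.saved_image_shape": "p0",
    }
    model_keys = ["base_model.conv1.weight", "discriminator.0.weight"]

    print(top_level_prefixes(checkpoint_state))
    # Counter({'main_model': 2, 'policy_model': 1})

    print(key_mismatch(model_keys, checkpoint_state))  # nothing lines up

    cleaned = strip_prefix(checkpoint_state, "main_model.")
    print(key_mismatch(model_keys, cleaned))  # (set(), set())

    # With a real checkpoint (hypothetical prefix, verify it first):
    # ckpt = torch.load(ckpt_path, map_location="cpu")
    # model_ft.load_state_dict(strip_prefix(ckpt["state_dict"], "main_model."))
```

Loading the filtered dict with `load_state_dict` on the model instance avoids both the classmethod pitfall and the extra constructor arguments (`batch_size`, `num_workers`) that `ClassificationModel.__init__` does not accept.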
