From eedfc5ab85431724b407ad87595b9d7a9f174d3e Mon Sep 17 00:00:00 2001 From: namsaraeva Date: Thu, 16 May 2024 11:57:22 +0200 Subject: [PATCH] troubleshooting 3 --- src/sparcscore/ml/models.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/sparcscore/ml/models.py b/src/sparcscore/ml/models.py index b03533c..b7b64f5 100644 --- a/src/sparcscore/ml/models.py +++ b/src/sparcscore/ml/models.py @@ -173,7 +173,7 @@ def __init__(self, super(VGG2_regression, self).__init__() - self.norm = nn.BatchNorm2d(in_channels) + self.norm = nn.BatchNorm2d(512) self.features = self.make_layers(self.cfgs[cfg], in_channels) self.classifier = self.make_layers_MLP(self.cfgs_MLP[cfg_MLP], self.cfgs[cfg], regression=True) # regression is set to True to make the final layer a single output @@ -183,10 +183,7 @@ def vgg(cfg, in_channels, **kwargs): return model def forward(self, x): - - num_of_channels = x.shape[1] - - x = nn.BatchNorm2d(num_of_channels)(x) + x = self.norm(x) x = self.features(x) x = torch.flatten(x, 1)