From 96921d5bda7ea5d42d4033ff32031bbf018d7eea Mon Sep 17 00:00:00 2001
From: namsaraeva
Date: Wed, 29 May 2024 16:26:04 +0200
Subject: [PATCH] add variable loss function

---
 src/sparcscore/ml/plmodels.py | 24 +++++++++++++++++++++---
 1 file changed, 21 insertions(+), 3 deletions(-)

diff --git a/src/sparcscore/ml/plmodels.py b/src/sparcscore/ml/plmodels.py
index 0a9300d..6c97da4 100644
--- a/src/sparcscore/ml/plmodels.py
+++ b/src/sparcscore/ml/plmodels.py
@@ -168,7 +168,7 @@ def test_step(self, batch, batch_idx):
 
 
 class RegressionModel(pl.LightningModule):
-    def __init__(self, model_type="VGG2_regression", **kwargs):
+    def __init__(self, model_type="VGG2_regression", **kwargs):
         super().__init__()
 
         self.save_hyperparameters()
@@ -198,15 +198,33 @@ def configure_optimizers(self):
             optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams["learning_rate"], weight_decay=self.hparams["weight_decay"])
         else:
-            raise ValueError("No optimizer specified in hparams")
+            raise ValueError("No optimizer specified in hparams.")
 
         return optimizer
 
+    def configure_loss(self):
+        # Select the loss function named in hparams; fall back to the default Huber delta if none was given
+        if self.hparams["loss"] == "mse":
+            loss_fn = F.mse_loss
+        elif self.hparams["loss"] == "huber":
+            if self.hparams.get("huber_delta") is None:
+                self.hparams["huber_delta"] = 1.0  # default delta: transition point between quadratic and linear penalty
+            loss_fn = F.huber_loss
+        else:
+            raise ValueError("No valid loss function specified in hparams.")
+
+        return loss_fn
+
     def training_step(self, batch):
         data, target = batch
         target = target.unsqueeze(1)
         output = self.network(data)  # Forward pass, only one output
-        loss = F.huber_loss(output, target, delta=1.0, reduction='mean') # consider looking at parameters again
+
+        loss_fn = self.configure_loss()
+        if self.hparams["loss"] == "huber":  # Huber loss takes an extra delta argument
+            loss = loss_fn(output, target, delta=self.hparams["huber_delta"], reduction='mean')
+        else:  # MSE
+            loss = loss_fn(output, target)
 
         self.log('loss/train', loss, on_step=False, on_epoch=True, prog_bar=True)
         self.log('mse/train', self.mse(output, target), on_epoch=True, prog_bar=True)
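
A minimal usage sketch of the new hparams-driven loss selection, for reviewers. The keys "loss", "huber_delta", "learning_rate", and "weight_decay" appear in the patch; the "optimizer" key name and the routing of constructor kwargs into hparams via save_hyperparameters() are assumptions not confirmed by this diff.

    # Hypothetical usage sketch; anything marked "assumed" is not confirmed by the patch.
    model = RegressionModel(
        model_type="VGG2_regression",
        optimizer="AdamW",      # assumed key name; the diff only shows the AdamW branch and its error message
        learning_rate=1e-4,     # read as hparams["learning_rate"] in configure_optimizers()
        weight_decay=1e-2,      # read as hparams["weight_decay"] in configure_optimizers()
        loss="huber",           # new: selects F.huber_loss in configure_loss()
        huber_delta=1.0,        # new: optional; configure_loss() falls back to 1.0
    )

Passing loss="mse" instead selects F.mse_loss, which takes no extra parameters.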
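
As background for choosing between the two losses the patch exposes, a self-contained comparison of F.huber_loss and F.mse_loss on the same tensors (both are standard torch.nn.functional API; the values are purely illustrative):

    import torch
    import torch.nn.functional as F

    output = torch.tensor([[2.5], [0.0], [4.0]])
    target = torch.tensor([[3.0], [2.0], [0.0]])

    # Huber is quadratic for residuals within +/- delta and linear beyond,
    # so the outlier residual (4.0) is penalized far less than under MSE.
    print(F.huber_loss(output, target, delta=1.0, reduction='mean'))  # tensor(1.7083)
    print(F.mse_loss(output, target))                                 # tensor(6.7500)

This is why exposing huber_delta as an hparam is useful: it sets the residual magnitude at which the loss switches from MSE-like to MAE-like behavior.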