Feature Request: Return all losses with multi-loss methods e.g., VICReg #1422

Open
RylanSchaeffer opened this issue Nov 7, 2023 · 4 comments

Comments

@RylanSchaeffer

Some SSL methods have a total loss that is a weighted combination of several component losses. For instance, the VICReg loss is a weighted sum of three terms: invariance, variance, and covariance.
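Written out, the VICReg total looks roughly like the following. This assumes the component definitions in Lightly's `lightly.loss.vicreg_loss` module as we understand them: mean squared error for invariance, a hinge on the per-dimension standard deviation for variance, and squared off-diagonal covariance entries for covariance. Here $n$ is the batch size, $d$ the embedding dimension, and $\lambda, \mu, \nu$ correspond to `lambda_param`, `mu_param`, `nu_param`.

$$
\mathcal{L} = \lambda\, s(Z_a, Z_b) + \mu\, \tfrac{1}{2}\big[v(Z_a) + v(Z_b)\big] + \nu\,\big[c(Z_a) + c(Z_b)\big]
$$

$$
s(Z_a, Z_b) = \frac{1}{nd}\sum_{i=1}^{n}\lVert z_{a,i} - z_{b,i}\rVert_2^2, \qquad
v(Z) = \frac{1}{d}\sum_{j=1}^{d}\max\!\Big(0,\ 1 - \sqrt{\operatorname{Var}(Z_{:,j}) + \epsilon}\Big), \qquad
c(Z) = \frac{1}{d}\sum_{j \neq k} C(Z)_{jk}^2
$$

with $C(Z) = \frac{1}{n-1}(Z - \bar{Z})^\top (Z - \bar{Z})$. Each of $s$, $v$, $c$ is a separate scalar, which is why it is natural to want all three back from the loss.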

It would be really nice if all the component losses were returned by the VICRegLoss function.

The minimal use case is logging/monitoring. For instance, if the VICReg loss becomes unstable, the user can't necessarily tell which of the three component losses is misbehaving.
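As a sketch of that monitoring use case: once the component losses are available as a dict, spotting the misbehaving term is simple. The helper `find_misbehaving_terms` and its `factor` threshold are hypothetical illustrations, not part of Lightly.

```python
import math
from typing import Dict, List


def find_misbehaving_terms(
    history: List[Dict[str, float]], current: Dict[str, float], factor: float = 10.0
) -> List[str]:
    """Hypothetical helper: names of terms that are non-finite or exceed `factor` x their past mean."""
    flagged = []
    for name, value in current.items():
        # NaN or inf is always flagged.
        if not math.isfinite(value):
            flagged.append(name)
            continue
        past = [step[name] for step in history if name in step]
        if past:
            mean = sum(past) / len(past)
            # A sudden spike relative to the running mean suggests instability.
            if mean > 0 and value > factor * mean:
                flagged.append(name)
    return flagged


history = [{"inv_loss": 0.5, "var_loss": 0.2, "cov_loss": 1.0}] * 3
print(find_misbehaving_terms(history, {"inv_loss": 0.6, "var_loss": 0.2, "cov_loss": 50.0}))
# -> ['cov_loss']
```

With only the summed loss, the same spike would be visible but not attributable to a term.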

@RylanSchaeffer
Author

Another example with multiple component losses is TiCo https://arxiv.org/abs/2206.10698
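For TiCo, the objective combines a transformation-invariance term and a covariance-contrast term against a running covariance estimate. Below is a minimal sketch that returns both terms. It assumes the formulation used by Lightly's `TiCoLoss` as far as we know it (including the defaults `beta=0.9`, `rho=20.0`); the class name `TiCoLossWithComponents` is hypothetical.

```python
from typing import Dict, Optional

import torch
import torch.nn.functional as F


class TiCoLossWithComponents(torch.nn.Module):
    """Hypothetical TiCo loss that also returns its two component terms."""

    def __init__(self, beta: float = 0.9, rho: float = 20.0) -> None:
        super().__init__()
        self.beta = beta  # momentum of the running covariance estimate (assumed default)
        self.rho = rho  # weight of the covariance-contrast term (assumed default)
        self.C: Optional[torch.Tensor] = None  # running covariance, created on first call

    def forward(self, z_a: torch.Tensor, z_b: torch.Tensor) -> Dict[str, torch.Tensor]:
        z_a = F.normalize(z_a, dim=1)
        z_b = F.normalize(z_b, dim=1)
        # Batch estimate of the second-moment matrix of the first view.
        B = z_a.T @ z_a / z_a.shape[0]
        if self.C is None:
            self.C = torch.zeros_like(B)
        C = self.beta * self.C + (1 - self.beta) * B
        # Transformation invariance: 1 - mean cosine similarity of the two views.
        inv_loss = 1 - (z_a * z_b).sum(dim=1).mean()
        # Covariance contrast: mean of z^T C z over the batch.
        cov_loss = ((z_a @ C) * z_a).sum(dim=1).mean()
        # Keep the running estimate out of the autograd graph.
        self.C = C.detach()
        return {
            "inv_loss": inv_loss,
            "cov_loss": cov_loss,
            "total_loss": inv_loss + self.rho * cov_loss,
        }
```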

@RylanSchaeffer
Author

A third example with multiple component losses is HypersphereLoss https://docs.lightly.ai/self-supervised-learning/lightly.loss.html#lightly.loss.hypersphere_loss.HypersphereLoss
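The hypersphere loss splits into an alignment term (positive pairs close together) and a uniformity term (embeddings spread over the sphere), following Wang and Isola's formulation. A minimal sketch that exposes both; the function name and the parameter names `t`, `lam`, `alpha` are our assumptions, not a confirmed Lightly signature.

```python
from typing import Dict

import torch
import torch.nn.functional as F


def hypersphere_loss_components(
    z_a: torch.Tensor, z_b: torch.Tensor, t: float = 1.0, lam: float = 1.0, alpha: float = 2.0
) -> Dict[str, torch.Tensor]:
    """Hypothetical helper returning the alignment and uniformity terms separately."""
    x = F.normalize(z_a, dim=1)
    y = F.normalize(z_b, dim=1)
    # Alignment: mean distance between positive pairs, raised to the power alpha.
    align_loss = (x - y).norm(dim=1).pow(alpha).mean()

    def uniformity(z: torch.Tensor) -> torch.Tensor:
        # Log of the mean Gaussian potential over all pairwise squared distances.
        sq_pdist = torch.pdist(z, p=2).pow(2)
        return sq_pdist.mul(-t).exp().mean().log()

    unif_loss = 0.5 * (uniformity(x) + uniformity(y))
    return {
        "align_loss": align_loss,
        "unif_loss": unif_loss,
        "total_loss": align_loss + lam * unif_loss,
    }
```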

@guarin
Contributor

guarin commented Nov 9, 2023

Yes, this is a common issue. We currently return only a single loss because it slightly simplifies the code and lets you swap loss functions without other code changes (you don't have to handle aggregating the different parts).
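One way to keep losses interchangeable while still exposing components is a small adapter in the training code that accepts either a single value or a dict. This is a sketch under the assumption that dict-returning losses store the backpropagated value under a `"total_loss"` key; the helper name `split_loss_output` is hypothetical.

```python
from typing import Any, Dict, Mapping, Tuple


def split_loss_output(output: Any) -> Tuple[Any, Dict[str, Any]]:
    """Hypothetical adapter: return (total_loss, components) for scalar or dict outputs."""
    if isinstance(output, Mapping):
        components = {k: v for k, v in output.items() if k != "total_loss"}
        return output["total_loss"], components
    # Single-value losses have no separate components to log.
    return output, {}
```

A training loop could then call `loss, parts = split_loss_output(criterion(z_a, z_b))`, backpropagate `loss`, and log whatever is in `parts`, regardless of which loss class `criterion` is.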

For VICReg you can compute the components individually:

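import torch
from lightly.loss.vicreg_loss import invariance_loss, variance_loss, covariance_loss

# Example embeddings of two augmented views with shape (batch_size, dim); sizes chosen arbitrarily.
z_a, z_b = torch.randn(32, 128), torch.randn(32, 128)

inv_loss = invariance_loss(x=z_a, y=z_b)
var_loss = 0.5 * (variance_loss(x=z_a) + variance_loss(x=z_b))
cov_loss = covariance_loss(x=z_a) + covariance_loss(x=z_b)
print(inv_loss, var_loss, cov_loss)
# Weights correspond to VICRegLoss's lambda_param, mu_param and nu_param (25.0, 25.0, 1.0 by default).
total_loss = 25.0 * inv_loss + 25.0 * var_loss + 1.0 * cov_loss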

See also #1161

For TiCo and HypersphereLoss we should definitely add the possibility to calculate the components individually.

Out of curiosity, when logging the individual parts, would you log them with or without the loss weights? For example, in VICReg each loss term is multiplied by a weight (lambda_param, mu_param, or nu_param).
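Logging the unweighted terms and deriving the weighted ones is cheap, so both views can be kept: the raw values show each term's own scale, the weighted ones show its share of the total. A sketch with a hypothetical helper `weighted_components`; the 25/25/1 weights mirror the VICReg snippet above.

```python
from typing import Dict


def weighted_components(
    components: Dict[str, float], weights: Dict[str, float]
) -> Dict[str, float]:
    """Hypothetical helper: scale each unweighted term by its weight (missing weights default to 1.0)."""
    return {name: weights.get(name, 1.0) * value for name, value in components.items()}


raw = {"inv_loss": 0.1, "var_loss": 0.02, "cov_loss": 0.5}
weights = {"inv_loss": 25.0, "var_loss": 25.0, "cov_loss": 1.0}
scaled = weighted_components(raw, weights)
# The weighted terms sum to the total loss that is actually backpropagated.
total = sum(scaled.values())
```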

@RylanSchaeffer
Author

RylanSchaeffer commented Nov 9, 2023

> For VICReg you can compute the components individually:

I was previously implementing all this SSL stuff myself for a research project, then discovered Lightly and thought "Awesome! Why reinvent the wheel? I'll just switch to Lightly." I can of course compute the components individually, but I would prefer to either use Lightly or not, rather than mix and match.

The pattern I was using, which works well, is for each loss function to return a dictionary by default, from which the combined loss can be accessed, e.g.:


from typing import Dict

import torch
import torch.distributed as dist
from lightly.loss import VICRegLoss
from lightly.loss.vicreg_loss import covariance_loss, invariance_loss, variance_loss
from lightly.utils.dist import gather


class VICRegLossWithComponents(VICRegLoss):
    # Hypothetical subclass: inherits lambda_param, mu_param, nu_param,
    # gather_distributed and eps from VICRegLoss, but returns every term.

    def forward(self, z_a: torch.Tensor, z_b: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Returns the unweighted VICReg loss terms and the weighted total loss.

        Args:
            z_a:
                Tensor with shape (batch_size, ..., dim).
            z_b:
                Tensor with shape (batch_size, ..., dim).
        """
        assert (
            z_a.shape[0] > 1 and z_b.shape[0] > 1
        ), f"z_a and z_b must have batch size > 1 but found {z_a.shape[0]} and {z_b.shape[0]}"
        assert (
            z_a.shape == z_b.shape
        ), f"z_a and z_b must have same shape but found {z_a.shape} and {z_b.shape}."

        # invariance term of the loss
        inv_loss = invariance_loss(x=z_a, y=z_b)

        # gather all batches
        if self.gather_distributed and dist.is_initialized():
            world_size = dist.get_world_size()
            if world_size > 1:
                z_a = torch.cat(gather(z_a), dim=0)
                z_b = torch.cat(gather(z_b), dim=0)

        # variance and covariance terms, computed on the (possibly gathered) batch
        var_loss = 0.5 * (
            variance_loss(x=z_a, eps=self.eps) + variance_loss(x=z_b, eps=self.eps)
        )
        cov_loss = covariance_loss(x=z_a) + covariance_loss(x=z_b)

        loss = (
            self.lambda_param * inv_loss
            + self.mu_param * var_loss
            + self.nu_param * cov_loss
        )

        # Unweighted terms for logging, weighted total for backpropagation.
        loss_results_dict = {
            "inv_loss": inv_loss,
            "var_loss": var_loss,
            "cov_loss": cov_loss,
            "total_loss": loss,
        }

        return loss_results_dict
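On the caller's side, a dict-returning loss only changes two lines of a training step: backpropagate the `"total_loss"` entry and turn every entry into a plain float for logging. A self-contained sketch; `toy_loss` is a hypothetical stand-in for a dict-returning loss such as the `forward()` above, not Lightly code.

```python
from typing import Dict

import torch


def toy_loss(z_a: torch.Tensor, z_b: torch.Tensor) -> Dict[str, torch.Tensor]:
    """Hypothetical stand-in for a loss that returns its components as a dict."""
    inv_loss = torch.nn.functional.mse_loss(z_a, z_b)
    var_loss = torch.relu(1.0 - z_a.std(dim=0)).mean()
    return {"inv_loss": inv_loss, "var_loss": var_loss, "total_loss": 25.0 * inv_loss + 25.0 * var_loss}


def training_step(
    model: torch.nn.Module, optimizer: torch.optim.Optimizer, x_a: torch.Tensor, x_b: torch.Tensor
) -> Dict[str, float]:
    losses = toy_loss(model(x_a), model(x_b))
    optimizer.zero_grad()
    # Only the weighted total is backpropagated.
    losses["total_loss"].backward()
    optimizer.step()
    # .item() yields plain floats, so the logged values do not keep the autograd graph alive.
    return {name: value.item() for name, value in losses.items()}
```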

> Out of curiosity, when logging the individual parts would you log them with or without the loss weights?

Without the weights, definitely.
