Official implementation of the Generalized Wasserstein Dice Loss in PyTorch

LucasFidon/GeneralizedWassersteinDiceLoss

Generalized Wasserstein Dice Loss

The Generalized Wasserstein Dice Loss (GWDL) is a loss function to train deep neural networks for applications in medical image multi-class segmentation.

The GWDL generalizes the Dice loss and the Generalized Dice loss: it can handle hierarchical classes and take advantage of known relationships between classes, which are encoded in a distance matrix between labels.
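To make the idea concrete, here is a minimal pure-Python sketch of the loss for a single image, loosely following the formulation of the paper cited below. It is not the package's implementation, and it rests on assumptions: the prediction scores are turned into probabilities with a softmax, the distance matrix M has entries in [0, 1], and the weights alpha are 0 for the background (label 0) and 1 for every other class. For an element i with ground-truth label g_i and predicted probabilities p_i, the Wasserstein distance reduces to W_i = sum_l p_i[l] * M[l][g_i]. The generalized true positives are TP = sum_i alpha[g_i] * (1 - W_i), and the loss is 1 - 2 TP / (2 TP + sum_i W_i). The function names `softmax` and `gwdl_sketch` are hypothetical.

```python
import math
from typing import List, Sequence


def softmax(scores: Sequence[float]) -> List[float]:
    """Numerically stable softmax over the class scores of one element."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def gwdl_sketch(scores: Sequence[Sequence[float]],
                labels: Sequence[int],
                dist: Sequence[Sequence[float]],
                eps: float = 1e-12) -> float:
    """Sketch of the Generalized Wasserstein Dice Loss for ONE image.

    scores: [n_classes][n_elements] prediction scores (softmax applied here: assumption).
    labels: [n_elements] ground-truth class index of each element.
    dist:   [n_classes][n_classes] distance matrix M between classes, values in [0, 1].
    """
    n_classes = len(scores)
    true_pos = 0.0   # generalized number of true positives
    all_error = 0.0  # sum of the Wasserstein distances over all elements
    for i, g in enumerate(labels):
        probs = softmax([scores[c][i] for c in range(n_classes)])
        # Wasserstein distance between the one-hot label g and the predicted probabilities
        w = sum(probs[l] * dist[l][g] for l in range(n_classes))
        all_error += w
        alpha = 0.0 if g == 0 else 1.0  # assumed weights: 0 for background, 1 otherwise
        true_pos += alpha * (1.0 - w)
    dice = (2.0 * true_pos + eps) / (2.0 * true_pos + all_error + eps)
    return 1.0 - dice
```

Because M[1][2] = 0.5 in the matrix of the example below, confidently predicting class 1 for an element labelled 2 gives a sketch loss of about 0.33, whereas confidently predicting the background gives a loss close to 1. With the inputs of that example, the sketch returns about 0.48.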

Installation

pip install git+https://github.com/LucasFidon/GeneralizedWassersteinDiceLoss.git

Example

import torch
import numpy as np
from generalized_wasserstein_dice_loss.loss import GeneralizedWassersteinDiceLoss

# Example with 3 classes (including the background: label 0).
# The distance between the background (class 0) and the other classes is the maximum, equal to 1.
# The distance between class 1 and class 2 is 0.5.
dist_mat = np.array([
    [0., 1., 1.],
    [1., 0., 0.5],
    [1., 0.5, 0.]
])
wass_loss = GeneralizedWassersteinDiceLoss(dist_matrix=dist_mat)
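# Use the GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 1D prediction scores; shape: (batch size, number of classes, number of elements)
pred = torch.tensor([[[1, 0], [0, 1], [0, 0]]], dtype=torch.float32, device=device)
# 1D ground truth; shape: (batch size, number of elements)
grnd = torch.tensor([[0, 2]], dtype=torch.int64, device=device)
loss = wass_loss(pred, grnd)
print(loss.item())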
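For hierarchical classes, the distance matrix can also be derived from a label tree. The sketch below uses one possible convention, chosen here for illustration and not taken from the package: for two different labels a and b, M[a][b] = 1 - depth(lca) / max_depth, where lca is their lowest common ancestor in the tree and max_depth is the depth of the deepest label; M[a][a] = 0. The helpers `ancestors` and `tree_distance_matrix` and the example tree are hypothetical. With a background node and a "foreground" node grouping classes 1 and 2, this convention reproduces the matrix of the example above.

```python
from typing import Dict, List, Optional


def ancestors(node: str, parent: Dict[str, Optional[str]]) -> List[str]:
    """Return [node, parent, grandparent, ..., root]."""
    path: List[str] = []
    current: Optional[str] = node
    while current is not None:
        path.append(current)
        current = parent[current]
    return path


def tree_distance_matrix(leaves: List[str],
                         parent: Dict[str, Optional[str]]) -> List[List[float]]:
    """Distance matrix indexed by label; leaves[k] is the tree node of label k."""
    depth = {n: len(ancestors(n, parent)) - 1 for n in parent}  # root has depth 0
    max_depth = max(depth[leaf] for leaf in leaves)
    n = len(leaves)
    mat = [[0.0] * n for _ in range(n)]  # diagonal stays 0
    for a in range(n):
        anc_a = set(ancestors(leaves[a], parent))
        for b in range(n):
            if a == b:
                continue
            # The first ancestor of b that is also an ancestor of a is their lowest common ancestor
            lca = next(x for x in ancestors(leaves[b], parent) if x in anc_a)
            mat[a][b] = 1.0 - depth[lca] / max_depth
    return mat


# Hypothetical tree: background vs. a "foreground" node grouping classes 1 and 2
parent = {"root": None, "background": "root", "foreground": "root",
          "class_1": "foreground", "class_2": "foreground"}
M = tree_distance_matrix(["background", "class_1", "class_2"], parent)
# M == [[0.0, 1.0, 1.0], [1.0, 0.0, 0.5], [1.0, 0.5, 0.0]]
```

The resulting list can be converted with `np.array(M)` and passed as `dist_matrix`. Grouping classes under a shared parent makes confusions between them cheaper than confusions across groups.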

How to cite

If you use the Generalized Wasserstein Dice Loss in your work, please cite:

BibTeX:

@inproceedings{fidon2017generalised,
  title={Generalised {W}asserstein dice score for imbalanced multi-class segmentation using holistic convolutional networks},
  author={Fidon, Lucas and Li, Wenqi and Garcia-Peraza-Herrera, Luis C and Ekanayake, Jinendra and Kitchen, Neil and Ourselin, S{\'e}bastien and Vercauteren, Tom},
  booktitle={International MICCAI Brainlesion Workshop},
  pages={64--76},
  year={2017},
  organization={Springer}
}

Applications of the Generalized Wasserstein Dice loss

For more examples of applications of the generalized Wasserstein Dice loss and how to define the distance matrix, you can look at:

If you find more papers using the generalized Wasserstein Dice loss, please let me know :)
