Added the Binary Expected Calibration Error (ECE) Metric #3132

Open
wants to merge 10 commits into base: master
57 changes: 57 additions & 0 deletions ignite/contrib/metrics/ExpectedCalibrationError.py
@@ -0,0 +1,57 @@
import torch
Collaborator:

First, let's put it into ignite/metrics/ExpectedCalibrationError.py instead of ignite/contrib/metrics/ExpectedCalibrationError.py


from ignite.exceptions import NotComputableError
from ignite.metrics import Metric


class ExpectedCalibrationError(Metric):
    def __init__(self, num_bins=10, device=None):
        # Set these attributes before calling Metric.__init__, since the base
        # __init__ itself calls reset(), which reads self.device.
        self.num_bins = num_bins
        self.device = device
        super(ExpectedCalibrationError, self).__init__()

    def reset(self):
        self.confidences = torch.tensor([], device=self.device)
        self.corrects = torch.tensor([], device=self.device)

    def update(self, output):
        y_pred, y = output
Collaborator:

We usually call .detach() on both to stop gradient computation, like here:

y_pred, y = output[0].detach(), output[1].detach()


        assert y_pred.dim() == 2 and y_pred.shape[1] == 2, "This metric is for binary classification."
Collaborator:

Let's use the following way to raise errors instead of assert:

if not (y_pred.dim() == 2 and y_pred.shape[1] == 2):
    raise ValueError("This metric is for binary classification")

To check whether the input is binary, we previously did something like this:

def _check_binary_multilabel_cases(self, output: Sequence[torch.Tensor]) -> None:
    y_pred, y = output

    if not torch.equal(y, y**2):
        raise ValueError("For binary cases, y must be comprised of 0's and 1's.")

    if not torch.equal(y_pred, y_pred**2):
        raise ValueError("For binary cases, y_pred must be comprised of 0's and 1's.")

def _check_type(self, output: Sequence[torch.Tensor]) -> None:
    y_pred, y = output

    if y.ndimension() + 1 == y_pred.ndimension():
        num_classes = y_pred.shape[1]
        if num_classes == 1:
            update_type = "binary"
            self._check_binary_multilabel_cases((y_pred, y))
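
Taken together with the earlier .detach() note, the suggestion applied to this PR's update() might look roughly like this (an illustrative sketch, not part of the diff):

def update(self, output):
    y_pred, y = output[0].detach(), output[1].detach()

    # Raise instead of assert, per the review comments above.
    if not (y_pred.dim() == 2 and y_pred.shape[1] == 2):
        raise ValueError("This metric is for binary classification.")
    if not torch.equal(y, y**2):
        raise ValueError("For binary cases, y must be comprised of 0's and 1's.")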


        softmax_probs = torch.softmax(y_pred, dim=1)
        max_probs, predicted_class = torch.max(softmax_probs, dim=1)

        self.confidences = torch.cat((self.confidences, max_probs))
        # Cast to float so the concatenated tensor stays floating-point and can be averaged later.
        self.corrects = torch.cat((self.corrects, (predicted_class == y).float()))

    def compute(self):
        if self.confidences.numel() == 0:
            raise NotComputableError(
                "ExpectedCalibrationError must have at least one example before it can be computed."
            )

        bin_edges = torch.linspace(0, 1, self.num_bins + 1, device=self.device)

        # Map each confidence to a bin index in [0, num_bins - 1].
        bin_indices = torch.searchsorted(bin_edges, self.confidences, right=True) - 1
        bin_indices = torch.clamp(bin_indices, 0, self.num_bins - 1)

        ece = 0.0
        bin_sizes = torch.zeros(self.num_bins, device=self.device)
        bin_accuracies = torch.zeros(self.num_bins, device=self.device)

        for i in range(self.num_bins):
            mask = bin_indices == i
            bin_confidences = self.confidences[mask]
            bin_corrects = self.corrects[mask]

            bin_size = bin_confidences.numel()
            if bin_size == 0:
                # Empty bins contribute nothing to the ECE.
                continue

            accuracy = torch.mean(bin_corrects)
            avg_confidence = torch.mean(bin_confidences)

            # ECE = sum over bins of (|B_i| / N) * |accuracy(B_i) - confidence(B_i)|
            ece += (bin_size / len(self.confidences)) * abs(accuracy - avg_confidence)
            bin_sizes[i] = bin_size
            bin_accuracies[i] = accuracy

        return ece
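
For reference, a minimal usage sketch (an illustrative example, not from the PR; it assumes the class above is importable and calls update/compute directly rather than attaching the metric to an Engine):

import torch

# Hypothetical logits for four samples of a binary problem, and their targets.
y_pred = torch.tensor([[2.0, 0.5], [0.1, 1.5], [1.2, 1.0], [0.3, 2.2]])
y = torch.tensor([0, 1, 1, 1])

metric = ExpectedCalibrationError(num_bins=10)
metric.update((y_pred, y))
print(metric.compute())  # tensor holding the scalar ECE, in [0, 1]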