L2I Enhancement (temperature scaling on classifier's logits) #2311

Draft · wants to merge 4 commits into base: develop
1 change: 1 addition & 0 deletions recipes/ESC50/interpret/extra_requirements.txt
@@ -1,3 +1,4 @@
matplotlib
pandas
scikit-learn

3 changes: 2 additions & 1 deletion recipes/ESC50/interpret/hparams/l2i_conv2dclassifier.yaml
@@ -133,8 +133,9 @@ nmf_decoder: !new:speechbrain.lobes.models.L2I.NMFDecoderAudio
n_freq: !ref <n_freq>

alpha: 10 # applied to NMF loss
- beta: 0.8 # L1 regularization to time activations
+ beta: 0.3 # L1 regularization to time activations
k_fidelity: 3 # top-k fidelity
+ gamma: 1 # soft cross-entropy temperature

modules:
compute_stft: !ref <compute_stft>
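The new `gamma` hyperparameter acts as a softmax temperature on the classifier's logits when building the soft targets for the fidelity loss (see the train_l2i.py changes below). A minimal, self-contained sketch of that scaling; the logits tensor and the batch size are placeholders for illustration:

```python
import torch
import torch.nn.functional as F

gamma = 1.0  # soft cross-entropy temperature (gamma: 1 in the YAML above)
logits = torch.randn(4, 50)  # placeholder batch of classifier logits (50 ESC50 classes)

# gamma > 1 flattens the soft targets, gamma < 1 sharpens them;
# gamma = 1 reduces to a plain softmax over the classifier's logits.
c_soft = F.softmax(logits / gamma, dim=1).detach()
print(c_soft.sum(dim=1))  # each row sums to 1
```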
29 changes: 26 additions & 3 deletions recipes/ESC50/interpret/train_l2i.py
@@ -385,6 +385,7 @@ def compute_objectives(self, pred, batch, stage):
self.top_3_fidelity.append(batch.id, theta_out, classification_out)
self.input_fidelity.append(batch.id, wavs, classification_out)
self.faithfulness.append(batch.id, wavs, classification_out)

self.acc_metric.append(
uttid, predict=classification_out, target=classid, length=lens
)
@@ -393,7 +394,11 @@ def compute_objectives(self, pred, batch, stage):

loss_nmf = ((reconstructions - X_stft_logpower) ** 2).mean()
loss_nmf = self.hparams.alpha * loss_nmf
+ self.rec_loss.append(uttid, loss_nmf)
+
+ prev = loss_nmf.clone().detach()
loss_nmf += self.hparams.beta * (time_activations).abs().mean()
+ self.reg_loss.append(uttid, loss_nmf - prev)

if stage != sb.Stage.TEST:
if hasattr(self.hparams.lr_annealing, "on_batch_end"):
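The hunk above splits the composite NMF objective into the two pieces that are now tracked separately: the alpha-weighted reconstruction error (rec_loss) and the beta-weighted L1 penalty on the time activations (reg_loss). A standalone sketch of that decomposition, using the alpha/beta values from the YAML above and random placeholder tensors whose shapes are purely illustrative:

```python
import torch

alpha, beta = 10.0, 0.3  # weights from the YAML hparams above
reconstructions = torch.rand(4, 513, 100)   # placeholder NMF reconstructions
X_stft_logpower = torch.rand(4, 513, 100)   # placeholder log-power spectrogram
time_activations = torch.rand(4, 100, 20)   # placeholder NMF time activations

rec_term = alpha * ((reconstructions - X_stft_logpower) ** 2).mean()  # logged as rec_loss
reg_term = beta * time_activations.abs().mean()                       # logged as reg_loss

# The optimizer still sees the sum; only the logging is split.
loss_nmf = rec_term + reg_term
print(rec_term.item(), reg_term.item(), loss_nmf.item())
```

In the diff itself the regularization component is logged as `loss_nmf - prev`, which keeps the tracked value identical to the term actually added to the objective.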
@@ -402,10 +407,15 @@ def compute_objectives(self, pred, batch, stage):
self.last_batch = batch
self.batch_to_plot = (reconstructions.clone(), X_stft_logpower.clone())

- theta_out = -torch.log(theta_out)
- loss_fdi = (F.softmax(classification_out, dim=1) * theta_out).mean()
+ c_soft = F.softmax(
+     classification_out / self.hparams.gamma, dim=1
+ ).detach()
+ theta_out = torch.log(theta_out)
+ loss_fid = -(theta_out * c_soft).sum(1).mean()

+ self.fid_loss.append(uttid, loss_fid)

- return loss_nmf + loss_fdi
+ return loss_nmf + loss_fid
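The rewritten fidelity term is a soft cross-entropy between the interpreter's class posterior `theta_out` and the classifier's temperature-scaled posterior `c_soft`. Compared with the removed `loss_fdi`, it applies the `gamma` temperature and sums over the class dimension before averaging over the batch. A minimal sketch with placeholder tensors (shapes are illustrative only):

```python
import torch
import torch.nn.functional as F

gamma = 1.0
classification_out = torch.randn(4, 50)            # placeholder classifier logits
theta_out = F.softmax(torch.randn(4, 50), dim=1)   # placeholder interpreter posterior

# Temperature-scaled soft targets from the classifier; detach() keeps this
# term from backpropagating into the classifier's outputs.
c_soft = F.softmax(classification_out / gamma, dim=1).detach()
loss_fid = -(torch.log(theta_out) * c_soft).sum(1).mean()
print(loss_fid.item())
```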

def on_stage_start(self, stage, epoch=None):
def accuracy_value(predict, target, length):
@@ -511,12 +521,19 @@ def compute_faithfulness(wavs, predictions):

return faithfulness

def save(x):
return x[None]

self.top_3_fidelity = MetricStats(metric=compute_fidelity)
self.input_fidelity = MetricStats(metric=compute_inp_fidelity)
self.faithfulness = MetricStats(metric=compute_faithfulness)
self.acc_metric = sb.utils.metric_stats.MetricStats(
metric=accuracy_value, n_jobs=1
)
self.rec_loss = MetricStats(metric=save)
self.reg_loss = MetricStats(metric=save)
self.fid_loss = MetricStats(metric=save)

return super().on_stage_start(stage, epoch)
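The new `save` helper just wraps a 0-dim loss tensor into a length-1 tensor so that `MetricStats` can store one score per batch and average them with `summarize("average")`. A small usage sketch, assuming only the `MetricStats` API already used in this recipe; the ids and values are made up:

```python
import torch
from speechbrain.utils.metric_stats import MetricStats

def save(x):
    # Wrap a 0-dim loss tensor into shape (1,) so MetricStats keeps
    # one score per batch and can average them later.
    return x[None]

rec_loss = MetricStats(metric=save)
rec_loss.append(["batch_1"], torch.tensor(0.42))  # one scalar per batch
rec_loss.append(["batch_2"], torch.tensor(0.38))
print(rec_loss.summarize("average"))  # roughly 0.40
```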

def on_stage_end(self, stage, stage_loss, epoch=None):
@@ -528,6 +545,9 @@ def on_stage_end(self, stage, stage_loss, epoch=None):
self.train_loss = stage_loss
self.train_stats = {
"loss": self.train_loss,
"rec_loss": self.rec_loss.summarize("average"),
"reg_loss": self.reg_loss.summarize("average"),
"fid_loss": self.fid_loss.summarize("average"),
"acc": self.acc_metric.summarize("average"),
}

@@ -540,6 +560,9 @@ def on_stage_end(self, stage, stage_loss, epoch=None):
sb.nnet.schedulers.update_learning_rate(self.optimizer, new_lr)
valid_stats = {
"loss": stage_loss,
"rec_loss": self.rec_loss.summarize("average"),
"reg_loss": self.reg_loss.summarize("average"),
"fid_loss": self.fid_loss.summarize("average"),
"acc": self.acc_metric.summarize("average"),
"top-3_fid": current_fid,
"input-fidelity": current_inpfid,
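For completeness, a hypothetical illustration (not part of this diff) of how the enlarged train/valid stats dictionaries might be passed to SpeechBrain's FileTrainLogger; the logger choice, file path, and all numeric values are assumptions for the example:

```python
from speechbrain.utils.train_logger import FileTrainLogger

logger = FileTrainLogger(save_file="train_log.txt")  # assumed logger and file name
logger.log_stats(
    stats_meta={"epoch": 1, "lr": 2.0e-4},
    train_stats={"loss": 1.23, "rec_loss": 0.51, "reg_loss": 0.07, "fid_loss": 0.65, "acc": 0.71},
    valid_stats={"loss": 1.30, "rec_loss": 0.55, "reg_loss": 0.08, "fid_loss": 0.67, "acc": 0.69},
)
```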