
Commit

bug fix
domenicoMuscill0 committed Oct 18, 2023
1 parent 663ba7a commit 6df3168
Showing 3 changed files with 11 additions and 6 deletions.
@@ -81,6 +81,11 @@ def compute_loss_without_labels(
     ):
         mat = self.distance(embeddings, ref_emb)
         r, c = mat.size()
 
+        d_pos = torch.zeros(max(r, c))
+        d_pos = c_f.to_device(d_pos, tensor=embeddings, dtype=embeddings.dtype)
+        d_pos[: min(r, c)] = mat.diag()
+        mat.fill_diagonal_(np.inf)
+
         min_a, min_p = torch.zeros(max(r, c)), torch.zeros(
             max(r, c)
@@ -90,9 +95,7 @@ def compute_loss_without_labels(
         min_a[:c], _ = torch.min(mat, dim=0)
         min_p[:r], _ = torch.min(mat, dim=1)
 
-        d_pos = torch.zeros(max(r, c))
-        d_pos = c_f.to_device(d_pos, tensor=embeddings, dtype=embeddings.dtype)
-        d_pos[: min(r, c)], d_neg = mat.diag(), torch.min(min_a, min_p)
+        d_neg = torch.min(min_a, min_p)
         return d_pos - d_neg
 
     def compute_loss_with_labels(
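The fix reorders the computation: the anchor-positive distances are copied off the diagonal before the diagonal is overwritten with np.inf, so the subsequent row and column minima can only select genuine negatives. Previously, mat.diag() was read after the minima were taken, and each anchor could pick itself as its own hardest negative. A minimal standalone sketch of why the ordering matters (the matrix values here are made up for illustration):

import torch

# Square distance matrix where mat[i, i] is the anchor-positive distance
# (illustrative values only).
mat = torch.tensor([[0.1, 0.9, 0.8],
                    [0.7, 0.2, 0.6],
                    [0.5, 0.4, 0.3]])

# Read the positive distances off the diagonal *before* masking it.
d_pos = mat.diag().clone()

# Mask the diagonal so the row/column minima can only select negatives;
# without this, d_neg could equal d_pos and the margin would collapse to 0.
mat.fill_diagonal_(float("inf"))
d_neg = torch.min(mat.min(dim=0).values, mat.min(dim=1).values)

print(d_pos - d_neg)  # tensor([-0.4000, -0.2000, -0.1000])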
1 change: 0 additions & 1 deletion src/pytorch_metric_learning/losses/ranked_list_loss.py
@@ -70,7 +70,6 @@ def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
         )
         w_n[N_star > 0] = torch.exp(self.Tn * (self.alpha - N_star[N_star > 0]))
 
-        print("w_P: ", w_p)
         loss_P = torch.sum(
             w_p * (P_star - (self.alpha - self.margin)), dim=1
         ) / torch.sum(w_p + 1e-5, dim=1)
7 changes: 5 additions & 2 deletions tests/losses/test_dynamic_soft_margin_loss.py
@@ -110,6 +110,7 @@ def _compute_l2_distances(self, x, labels=None):
             return find_hard_negatives(dmat, output_index=False, empirical_thresh=0.008)
         else:
             dmat = compute_distance_matrix_unit_l2(x, x)
+            dmat.fill_diagonal_(0)  # set each point's distance to itself to 0
             anchor_idx, positive_idx, negative_idx = lmu.convert_to_triplets(
                 None, labels, labels, t_per_anchor="all"
             )
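Zeroing the diagonal guards against floating-point residue: the diagonal of a self-distance matrix should be exactly zero, but the underlying matrix arithmetic can leave tiny positive values there, which the triplet conversion would then treat as real pair distances. A standalone illustration, with torch.cdist standing in for the test's compute_distance_matrix_unit_l2:

import torch
import torch.nn.functional as F

x = F.normalize(torch.randn(4, 8), dim=1)  # unit-norm embeddings
dmat = torch.cdist(x, x)  # self-distance matrix; diagonal should be 0
print(dmat.diag())        # may contain tiny nonzero values
dmat.fill_diagonal_(0)    # force self-distances to exactly zero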
@@ -200,8 +201,10 @@ def forward(self, x, labels=None):
         bin_idx = torch.floor((hist_var - self._min_val) / self.bin_width).long()
         weight = CDF[bin_idx]
 
-        loss = -(neg_dist * weight).mean() + (pos_dist * weight).mean()
-        return loss.to(dtype=x.dtype)
+        # Equivalent form, matching the computation in dynamic_soft_margin_loss.py:
+        # loss = -(neg_dist * weight).mean() + (pos_dist * weight).mean()
+        loss = (hist_var * weight).mean()
+        return loss.to(device=x.device, dtype=x.dtype)  # cast added to avoid device/dtype mismatches
 
 
 import unittest
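The rewritten loss is equivalent to the removed two-term form whenever hist_var = pos_dist - neg_dist elementwise, which is what the "equivalent version" comment implies: by linearity of the mean, (pos_dist * weight).mean() - (neg_dist * weight).mean() equals ((pos_dist - neg_dist) * weight).mean(). A quick numeric check under that assumption:

import torch

torch.manual_seed(0)
pos_dist, neg_dist, weight = torch.rand(6), torch.rand(6), torch.rand(6)

hist_var = pos_dist - neg_dist  # assumed relation between hist_var and the pair distances
old_loss = -(neg_dist * weight).mean() + (pos_dist * weight).mean()
new_loss = (hist_var * weight).mean()
assert torch.allclose(old_loss, new_loss)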
