
Hello, is this package able to train a multi-label dataset with one-hot encoding? #669

Open
Mahiro2211 opened this issue Oct 14, 2023 · 2 comments
Labels
enhancement (New feature or request)

Comments

@Mahiro2211

This is a great package and it has improved my efficiency. When I test on CIFAR-10 with one-hot labels, I can use

label = torch.argmax(label, dim=1)

to transform the one-hot labels into class indices. But when I test on a multi-label dataset, I can't find a nice way to deal with the multi-hot labels.
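
For example, with some made-up labels:

import torch

one_hot = torch.tensor([[0, 0, 1, 0],
                        [0, 1, 0, 0]])
print(torch.argmax(one_hot, dim=1))  # tensor([2, 1]) -- fine when each sample has one label

multi_hot = torch.tensor([[1, 0, 1, 0],
                          [0, 1, 1, 0]])
# argmax would keep only one of the labels here, so it doesn't apply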

At first, I saw this issue, which describes a way to pass in multi-labels, but I want to customize it further because I need to construct a similarity matrix:

label = torch.matmul(label, label.t())
# For a multi-label dataset: if two samples share at least one label, I mark them as similar
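
For example, with made-up multi-hot labels for three samples and four classes:

import torch

label = torch.tensor([[1, 0, 1, 0],
                      [0, 1, 1, 0],
                      [0, 0, 0, 1]]).float()

sim = torch.matmul(label, label.t())
# sim[i, j] counts the labels shared by samples i and j:
# tensor([[2., 1., 0.],
#         [1., 2., 0.],
#         [0., 0., 1.]])
# so (sim > 0) marks every pair that shares at least one label as similar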

I hope to receive a response from you soon. Thank you.

KevinMusgrave added the enhancement label on Oct 15, 2023
KevinMusgrave (Owner) commented Oct 15, 2023

Unfortunately, there isn't a way to pass a custom label-comparison function into miners or loss functions. It would be a good feature to add, though, so I will keep this issue open.

Edit:

Actually, I think you can write a miner to accomplish what you're talking about:

import torch
from pytorch_metric_learning.miners import BaseMiner

class CustomMiner(BaseMiner):
    def mine(self, embeddings, labels, ref_emb, ref_labels):
        # compare labels and ref_labels however you want,
        # then return a tuple (a1, p, a2, n),
        # where (a1, p) are the positive pair indices
        # and (a2, n) are the negative pair indices.
        # For example, to treat pairs that share at least one label as positive:
        sim = torch.matmul(labels.float(), ref_labels.t().float())
        matches = sim > 0
        if labels is ref_labels:
            matches.fill_diagonal_(False)  # a sample is not its own positive
        a1, p = torch.where(matches)
        a2, n = torch.where(sim == 0)
        return a1, p, a2, n


miner = CustomMiner()
pairs = miner(embeddings, labels)
loss = loss_fn(embeddings, indices_tuple=pairs)

It's not ideal, but it's the only workaround I can think of.
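
For instance, a self-contained run could look something like this (random tensors just for illustration, and ContrastiveLoss standing in for whatever loss you're using):

import torch
from pytorch_metric_learning import losses

embeddings = torch.randn(8, 128)            # made-up batch of embeddings
labels = (torch.rand(8, 10) > 0.7).float()  # made-up multi-hot labels

loss_fn = losses.ContrastiveLoss()
miner = CustomMiner()
# if your version rejects 2D labels in forward, call
# miner.mine(embeddings, labels, embeddings, labels) directly instead
pairs = miner(embeddings, labels)
loss = loss_fn(embeddings, indices_tuple=pairs)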

@Mahiro2211 (Author)

Thank you for your response. I will try it.
