
Birds 400-Species Image Classification using Pytorch Metric Learning (Triplet Margin Loss)



imneonizer/pytorch-triplet-loss


Pytorch Triplet Loss (Metric Learning)

This repository provides a hands-on approach to training a PyTorch model with metric learning. In short, the model learns to differentiate between images of different classes. For each image it returns an N-dimensional embedding vector, and the distances between embeddings can be used to find the most similar image or label.
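The training signal behind this is the triplet margin loss named in the title. For an anchor image, a positive (same class) and a negative (different class), the loss is max(d(a, p) - d(a, n) + margin, 0). Its effect is to pull same-class embeddings together and push other classes at least `margin` farther away. The sketch below shows this formula for a single triplet in plain Python. It matches the formula used by `torch.nn.TripletMarginLoss` with Euclidean distance, apart from that class's small numerical epsilon. The helper names and the `margin=1.0` value are illustrative assumptions. They are not necessarily the loss configuration this repository uses.

```python
import math

def euclidean(u, v):
    """L2 distance between two equal-length embedding vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

def triplet_margin_loss(anchor, positive, negative, margin=1.0):
    """Hinge loss for one (anchor, positive, negative) triplet.

    The loss is zero once the negative is at least `margin` farther from the
    anchor than the positive; otherwise it is positive and training pushes
    the embeddings apart. `margin=1.0` is an illustrative assumption.
    """
    d_pos = euclidean(anchor, positive)
    d_neg = euclidean(anchor, negative)
    return max(d_pos - d_neg + margin, 0.0)

anchor = [0.0, 0.0]
positive = [0.0, 1.0]  # same class, distance 1.0 from the anchor
print(triplet_margin_loss(anchor, positive, [3.0, 0.0]))  # 1 - 3 + 1 < 0, prints 0.0
print(triplet_margin_loss(anchor, positive, [1.5, 0.0]))  # 1 - 1.5 + 1, prints 0.5
```

In real training, the vectors would be the model's N-dimensional outputs. The loss would also be averaged over many triplets per batch rather than computed for a single hand-picked one.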

  • Metric learning is especially useful when the number of classes is very large and each class may have only a few images.
  • Before training the model, make sure the number of images per class (the class distribution) is balanced. You can use augmentation to oversample classes that have fewer images.
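One way to balance the classes before training is random oversampling. Smaller classes are padded with repeated samples until every class matches the largest one, and augmentation is then applied at load time so the repeats are not identical images. Below is a minimal sketch of that idea. The helper `oversample_to_balance`, the (path, label) input format and the file names are hypothetical, not part of this repository.

```python
import random
from collections import defaultdict

def oversample_to_balance(samples, seed=0):
    """Repeat items from smaller classes until every class matches the largest.

    `samples` is a list of (item, label) pairs, e.g. (image_path, class_name).
    The repeated items are meant to be augmented (random crop, flip, colour
    jitter, ...) when loaded, so the model does not see exact copies.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for item, label in samples:
        by_class[label].append(item)

    target = max(len(items) for items in by_class.values())
    balanced = []
    for label, items in by_class.items():
        # Draw extra items (with replacement) to reach the target count.
        extra = [rng.choice(items) for _ in range(target - len(items))]
        balanced.extend((item, label) for item in items + extra)

    rng.shuffle(balanced)
    return balanced

# Hypothetical data: three albatross images but only one kiwi image.
data = [("a1.jpg", "albatross"), ("a2.jpg", "albatross"),
        ("a3.jpg", "albatross"), ("k1.jpg", "kiwi")]
balanced = oversample_to_balance(data)
```

After this call, both classes have three entries. The kiwi entries all point to `k1.jpg`, which is why per-load augmentation matters here.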

Companion YouTube Video

YouTube Video

Start by installing the modules listed in requirements.txt.

Some modules, such as PyTorch and torchvision, are not included in requirements.txt, so that you can install the latest version yourself.
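Because PyTorch and torchvision have to be installed separately, it can help to check that they are importable before starting a training run. The sketch below uses only the standard library. The helper name `missing_modules` is hypothetical, and the check only confirms that the packages can be found, not that their versions suit this repository.

```python
import importlib.util

def missing_modules(names):
    """Return the names from `names` that cannot be found by the import system."""
    return [name for name in names if importlib.util.find_spec(name) is None]

missing = missing_modules(["torch", "torchvision"])
if missing:
    print("Install separately (not in requirements.txt):", ", ".join(missing))
else:
    print("PyTorch and torchvision are available.")
```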

Prediction from trained model

Prediction
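Prediction with a metric-learning model usually works by comparing embeddings rather than reading a class score. One common approach is to embed a reference set of labelled images with the trained model, embed the query image, and return the label of its nearest reference embeddings. The sketch below shows that lookup with k-nearest-neighbour voting. The toy 2-D vectors, the `predict_label` helper and the bird labels are illustrative assumptions. Real embeddings would be the model's N-dimensional outputs, and the repository's own prediction code may differ.

```python
import math

def predict_label(query, reference_embeddings, reference_labels, k=1):
    """Label a query embedding by majority vote among its k nearest references."""
    # Sort references by Euclidean distance to the query (math.dist, Python 3.8+).
    distances = sorted(
        (math.dist(query, ref), label)
        for ref, label in zip(reference_embeddings, reference_labels)
    )
    votes = {}
    for _, label in distances[:k]:
        votes[label] = votes.get(label, 0) + 1
    # Ties in vote count go to the label whose closest member is nearest,
    # because votes were inserted in order of increasing distance.
    return max(votes, key=votes.get)

# Hypothetical reference embeddings and labels.
refs = [[0.0, 0.0], [0.1, 0.0], [5.0, 5.0]]
labels = ["robin", "robin", "eagle"]
print(predict_label([0.2, 0.1], refs, labels))  # prints robin
print(predict_label([4.8, 5.1], refs, labels))  # prints eagle
```

With k=1 this returns the label of the single most similar image. A larger k trades some sensitivity to individual references for robustness against mislabelled or atypical ones.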