CASIA Face Recognition #687

Open
gauravkuppa opened this issue Mar 6, 2024 · 4 comments

gauravkuppa commented Mar 6, 2024

I am trying to extend the Metric Learning notebook to do class-disjoint metric learning for faces on the CASIA dataset. For reference, the dataset contains many face images organized into folders, where each folder holds faces of the same person. Ideally, I would like a high similarity score between faces of the same person and a low similarity score between faces of different people.
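
For concreteness, this is the kind of scoring I have in mind (a minimal sketch; trunk, embedder, and device refer to the models defined in the code below, while img_a and img_b are hypothetical preprocessed image tensors of shape (3, H, W)):

import torch.nn.functional as F

# Hypothetical inputs: img_a and img_b are images already run through val_transform
with torch.no_grad():
    emb_a = embedder(trunk(img_a.unsqueeze(0).to(device)))
    emb_b = embedder(trunk(img_b.unsqueeze(0).to(device)))
# Cosine similarity: near 1 for the same identity, lower for different identities
score = F.cosine_similarity(emb_a, emb_b).item()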

import logging
import os
from glob import glob
import datetime
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision
import umap
from cycler import cycler
from PIL import Image
from torchvision import datasets, transforms

import pytorch_metric_learning
import pytorch_metric_learning.utils.logging_presets as logging_presets
from pytorch_metric_learning import losses, miners, samplers, testers, trainers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

logging.getLogger().setLevel(logging.INFO)
logging.info("VERSION %s" % pytorch_metric_learning.__version__)

class MLP(nn.Module):
    def __init__(self, layer_sizes, final_relu=False):
        super().__init__()
        layer_list = []
        layer_sizes = [int(x) for x in layer_sizes]
        num_layers = len(layer_sizes) - 1
        final_relu_layer = num_layers if final_relu else num_layers - 1
        for i in range(len(layer_sizes) - 1):
            input_size = layer_sizes[i]
            curr_size = layer_sizes[i + 1]
            if i < final_relu_layer:
                layer_list.append(nn.ReLU(inplace=False))
            layer_list.append(nn.Linear(input_size, curr_size))
        self.net = nn.Sequential(*layer_list)
        self.last_linear = self.net[-1]

    def forward(self, x):
        return self.net(x)
    
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set trunk model and replace the softmax layer with an identity function
trunk = torchvision.models.resnet18(pretrained=True)
trunk_output_size = trunk.fc.in_features
trunk.fc = nn.Identity()
trunk = torch.nn.DataParallel(trunk.to(device))

# Set embedder model. This takes in the output of the trunk and outputs 64 dimensional embeddings
embedder = torch.nn.DataParallel(MLP([trunk_output_size, 64]).to(device))

# Set optimizers
trunk_optimizer = torch.optim.Adam(trunk.parameters(), lr=0.00001, weight_decay=0.0001)
embedder_optimizer = torch.optim.Adam(
    embedder.parameters(), lr=0.0001, weight_decay=0.0001
)

# Set the image transforms
train_transform = transforms.Compose(
    [
        transforms.Resize(64),
        # transforms.RandomResizedCrop(scale=(0.16, 1), ratio=(0.75, 1.33), size=64),
        transforms.RandomHorizontalFlip(0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

val_transform = transforms.Compose(
    [
        transforms.Resize(64),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# This will be used to create train and val sets that are class-disjoint
class ClassDisjointCASIA(torch.utils.data.Dataset):
    def __init__(self, path, train, transform):
        # Sort so that label assignment is deterministic across runs
        self.images = sorted(glob(os.path.join(path, "**", "*.jpg")))
        # Each top-level folder corresponds to one identity
        folders = sorted(glob(os.path.join(path, "*")))

        self.folder_to_label_dictionary = {}
        for i, folder in enumerate(folders):
            self.folder_to_label_dictionary[folder.split(os.sep)[-1]] = i
        # First half of the identities go to train, the second half to val
        half_way_val = len(self.folder_to_label_dictionary) // 2
        rule = (lambda x: x < half_way_val) if train else (lambda x: x >= half_way_val)
        self.filtered_idx = [
            i for i, x in enumerate(self.images) if rule(self.get_label(x, False))
        ]
        # Number of identities (classes) in this split, not the number of images
        self.num_classes = (
            half_way_val if train else len(self.folder_to_label_dictionary) - half_way_val
        )
        self.transform = transform
    
    def get_label(self, filename, is_folder):
        if is_folder:
            idx = -1
        else:
            idx = -2
        l = filename.split(os.sep)[idx]
        return self.folder_to_label_dictionary[l]
        

    def __len__(self):
        return len(self.filtered_idx)

    def __getitem__(self, index):
        img = self.images[self.filtered_idx[index]]
        label = self.get_label(img, False)
        img = Image.open(img).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        assert img.shape[0] == 3, f"expected 3 channels, got shape {img.shape}"
        return img, label

# Class disjoint training and validation set
path = "/home/<>/Downloads/CASIA-maxpy-clean/"
train_dataset = ClassDisjointCASIA(
    path, True, train_transform
)
val_dataset = ClassDisjointCASIA(path, False, val_transform)
# Sanity check that the splits are class-disjoint. Reading labels from the
# file paths avoids loading and transforming every image.
train_labels = [
    train_dataset.get_label(train_dataset.images[i], False)
    for i in train_dataset.filtered_idx
]
val_labels = [
    val_dataset.get_label(val_dataset.images[i], False)
    for i in val_dataset.filtered_idx
]
assert set(train_labels).isdisjoint(set(val_labels))

# Set the loss function. ArcFaceLoss has a learnable weight matrix, so it gets
# its own optimizer and must be moved to the same device as the embeddings.
loss_func = losses.ArcFaceLoss(
    num_classes=train_dataset.num_classes, embedding_size=64, margin=28.6, scale=30
).to(device)
loss_optimizer = torch.optim.Adam(loss_func.parameters(), lr=0.00001)


# Set other training parameters
batch_size = 512
num_epochs = 10

# Package the above stuff into dictionaries.
models = {"trunk": trunk, "embedder": embedder}
optimizers = {
    "trunk_optimizer": trunk_optimizer,
    "embedder_optimizer": embedder_optimizer,
    "metric_loss_optimizer": loss_optimizer
}
loss_funcs = {"metric_loss": loss_func}

model_records_dir = f'runs/MetricLossLogs-{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}'
model_folder = os.path.join(model_records_dir, "MetricLossLogsModel")
record_keeper, _, _ = logging_presets.get_record_keeper(
    os.path.join(model_records_dir, "MetricLossLogs"), os.path.join(model_records_dir, "MetricLossLogsTensorboard")
)
hooks = logging_presets.get_hook_container(record_keeper)
# Note: testing on the train set here (under the "val" key) to inspect how well
# the embeddings separate the training classes.
dataset_dict = {"val": train_dataset}

def visualizer_hook(umapper, umap_embeddings, labels, split_name, keyname, *args):
    logging.info(
        "UMAP plot for the {} split and label set {}".format(split_name, keyname)
    )
    label_set = np.unique(labels)
    num_classes = len(label_set)
    plt.figure(figsize=(20, 15))
    plt.gca().set_prop_cycle(
        cycler(
            "color", [plt.cm.nipy_spectral(i) for i in np.linspace(0, 0.9, num_classes)]
        )
    )
    for i in range(num_classes):
        idx = labels == label_set[i]
        plt.plot(umap_embeddings[idx, 0], umap_embeddings[idx, 1], ".", markersize=1)
        # Plot an empty series with a larger marker so the legend is readable
        plt.plot([], [], "s", markersize=8, label=f"Class {label_set[i]}")
    plt.legend(loc="upper right", fontsize="large")
    unique_identifier = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    plt.savefig(f"{model_records_dir}/UMAP_fig-{unique_identifier}.jpg", dpi=300) 
    plt.clf()

# Create the tester
tester = testers.GlobalEmbeddingSpaceTester(
    end_of_testing_hook=hooks.end_of_testing_hook,
    visualizer=umap.UMAP(),
    visualizer_hook=visualizer_hook,
    dataloader_num_workers=32,
    accuracy_calculator=AccuracyCalculator(k="max_bin_count"),
)

end_of_epoch_hook = hooks.end_of_epoch_hook(
    tester, dataset_dict, model_folder, test_interval=1, patience=1
)

trainer = trainers.MetricLossOnly(
    models,
    optimizers,
    batch_size,
    loss_funcs,
    train_dataset,
    # mining_funcs=mining_funcs,
    # sampler=sampler,
    dataloader_num_workers=32,
    end_of_iteration_hook=hooks.end_of_iteration_hook,
    end_of_epoch_hook=end_of_epoch_hook,
)

trainer.train(num_epochs=num_epochs)

I have tried to extend the notebook to this new dataset, but I am struggling to train a model with good performance. I know this is exactly the problem the notebook is intended to solve, but I am not sure where I am making a mistake.

After training for about 10 epochs, validation accuracy reaches at most 1% before plateauing. I am attaching my results and a UMAP of the train dataset's embeddings from a recent run of 10 epochs. These values do not improve considerably with more training.

[Attachments: accuracy results screenshot and UMAP_fig-20240306-154013]
I want to understand why this dataset is so difficult to learn. How can I get more separation between class embeddings on the train dataset, so that I have a better chance of generalizing to the test dataset? What options do I have to enable better learning?
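
One option I am considering is re-enabling the sampler and miner that the original notebook uses (a sketch based on the library's documented API; note that a tuple miner is typically paired with a tuple-based loss such as TripletMarginLoss, so whether it helps alongside ArcFaceLoss is something I would need to test):

# Build the label list for the train split without loading any images
train_labels = [
    train_dataset.get_label(train_dataset.images[i], False)
    for i in train_dataset.filtered_idx
]

# Guarantee m samples per class in every batch
sampler = samplers.MPerClassSampler(
    train_labels, m=4, length_before_new_iter=len(train_dataset)
)

# Mine informative pairs within each batch
miner = miners.MultiSimilarityMiner(epsilon=0.1)
mining_funcs = {"tuple_miner": miner}

trainer = trainers.MetricLossOnly(
    models,
    optimizers,
    batch_size,
    loss_funcs,
    train_dataset,
    mining_funcs=mining_funcs,
    sampler=sampler,
    dataloader_num_workers=32,
    end_of_iteration_hook=hooks.end_of_iteration_hook,
    end_of_epoch_hook=end_of_epoch_hook,
)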

@KevinMusgrave
Owner

How big are the images? You could try using the original image sizes, or resizing to something larger than 64.
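
For example, something like this (a sketch using standard torchvision transforms; 256 and 224 are just illustrative sizes):

train_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.RandomResizedCrop(size=224, scale=(0.16, 1), ratio=(0.75, 1.33)),
        transforms.RandomHorizontalFlip(0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)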

Does training accuracy keep going up, or does it also plateau?

@gauravkuppa
Author

I was able to make some improvements. Tuning the learning rates helped me reach ~90% accuracy on the training set. However, this still left me with <1% accuracy on the validation set. I tried your suggestion of increasing the image size to 256, which helped minimally; validation accuracy with larger images plateaued around 6%. I also tried tuning s and m, but did not get a meaningful validation accuracy boost.
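
For reference, tuning s and m means varying the scale and margin arguments of ArcFaceLoss (a sketch; the library's defaults are margin=28.6 degrees and scale=64, and the values below are just illustrative):

# Illustrative ArcFace settings near the paper defaults (m = 0.5 rad ~ 28.6 deg, s = 64)
loss_func = losses.ArcFaceLoss(
    num_classes=train_dataset.num_classes, embedding_size=64, margin=28.6, scale=64
).to(device)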

@gauravkuppa
Author

@KevinMusgrave bumping this message, hoping to get a response please.

@KevinMusgrave
Owner

KevinMusgrave commented Mar 20, 2024

The huge gap between training and validation accuracies is odd.

I don't have any experience using the CASIA dataset, but it looks like some other people have had issues with very low validation accuracies. Here are some discussions on that:

And here's another thread on CASIA training:

Maybe you can try out some of the hyperparameters mentioned in those threads.
