
Bug for distributed wrapper regarding to cross batch memory loss #639

Open
zhaoyuac09 opened this issue Jun 12, 2023 · 5 comments · May be fixed by #642
Labels: bug

zhaoyuac09 commented Jun 12, 2023

First of all, I really appreciate this repo. Thank you very much for the contribution! However, two functions in distributed.py do not work correctly for the loss and miner wrappers: gather_emb_and_ref and gather_enqueue_mask.
Let's take gather_enqueue_mask as an example:

def gather_enqueue_mask(enqueue_mask, device):
    if enqueue_mask is None:
        return enqueue_mask
    enqueue_mask = c_f.to_device(enqueue_mask, device=device)
    return torch.cat([enqueue_mask, all_gather(enqueue_mask)], dim=0)

def all_gather(x):
    world_size = torch.distributed.get_world_size()
    if world_size > 1:
        rank = torch.distributed.get_rank()
        x_list = [torch.ones_like(x) for _ in range(world_size)]
        torch.distributed.all_gather(x_list, x.contiguous())
        # remove curr rank
        x_list.pop(rank)
        return torch.cat(x_list, dim=0)
    return None

The all_gather function pops the current rank's entry, and that rank index is different on each GPU; the local enqueue_mask is then concatenated in front of the remaining gathered masks. As a result, the order of the gathered masks is not guaranteed to be the same across GPUs. When using cross batch memory losses, the embedding_memory therefore ends up different on different GPUs, which I have already confirmed by running some test functions.
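
To make the ordering problem concrete, here is a small illustrative simulation (not code from the repo; the 2-GPU setup and mask names are hypothetical) of the pop-then-prepend logic on each rank:

def simulated_gather_order(rank, world_size=2):
    # torch.distributed.all_gather fills the list in rank order: [mask_0, mask_1, ...]
    gathered = ["mask_{}".format(r) for r in range(world_size)]
    gathered.pop(rank)  # the wrapper's all_gather removes the current rank's entry
    return ["mask_{}".format(rank)] + gathered  # gather_enqueue_mask prepends the local mask

print(simulated_gather_order(rank=0))  # ['mask_0', 'mask_1']
print(simulated_gather_order(rank=1))  # ['mask_1', 'mask_0']  -> a different order on each GPU

Since each rank enqueues into its cross batch memory in its own order, the memories drift apart.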

Here I propose 2 changes to fix this issue:

def gather(emb, labels):
    device = emb.device
    if labels is not None:
        labels = c_f.to_device(labels, device=device)
    # Gather the embeddings from every replica.
    emb = c_f.to_device(emb, device=device)
    emb_list = [torch.ones_like(emb) for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(emb_list, emb)
    # Gathered tensors have no gradient, so we overwrite the gathered tensor for the current replica with the embeddings produced on this replica, which do have gradients.
    emb_list[torch.distributed.get_rank()] = emb
    all_emb = torch.cat(emb_list, dim=0)

    # Gather the labels from every replica.
    if labels is not None:
        labels_list = [torch.ones_like(labels) for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(labels_list, labels)
        # Overwrite the gathered tensor for the current replica with the labels produced on this replica.
        labels_list[torch.distributed.get_rank()] = labels
        all_labels = torch.cat(labels_list, dim=0)
    else:
        all_labels = None
    return all_emb, all_labels, labels

and

def gather_enqueue_mask(enqueue_mask, device):
    if enqueue_mask is None:
        return enqueue_mask
    enqueue_mask = c_f.to_device(enqueue_mask, device=device)
    # Gather the enqueue_mask from every replica.
    enqueue_mask_list = [torch.ones_like(enqueue_mask) for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(enqueue_mask_list, enqueue_mask)

    # Overwrite the gathered tensor for the current replica with the enqueue_mask produced on this replica.
    enqueue_mask_list[torch.distributed.get_rank()] = enqueue_mask

    return torch.cat(enqueue_mask_list, dim=0)
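
For context, here is a minimal usage sketch of how the fixed wrapper would be exercised (the embedding size, memory size, and loss choice are placeholders, not from this issue):

from pytorch_metric_learning import losses
from pytorch_metric_learning.utils import distributed as pml_dist

# CrossBatchMemory wrapped with the distributed loss wrapper
loss_fn = losses.CrossBatchMemory(
    losses.ContrastiveLoss(), embedding_size=128, memory_size=1024
)
loss_fn = pml_dist.DistributedLossWrapper(loss_fn)

# inside the DDP training loop, on every rank:
# embeddings = model(data)            # shape: (local_batch_size, 128)
# loss = loss_fn(embeddings, labels)
# With the proposed gather above, every rank sees the gathered embeddings,
# labels, and enqueue mask in the same rank order, so embedding_memory
# stays identical across GPUs.
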
zhaoyuac09 changed the title from "Bug for distributed wrapper regarding cross batch memory loss" to "Bug for distributed wrapper regarding to cross batch memory loss" on Jun 12, 2023
KevinMusgrave added the bug label on Jun 13, 2023
KevinMusgrave (Owner) commented:

Thanks for the code and explanation, @zhaoyuac09! I've found the distributed stuff to be quite tricky.

I'm really busy for the next few days, so I'll have to look at your code a bit later.

In the meantime, if you'd like, you can open a pull request with your code changes.

zhaoyuac09 (Author) commented Jun 13, 2023

Thank you @KevinMusgrave. I would be happy to create a pull request after I finish more test cases here. Once all the tests pass, I will wrap up the changes and open a pull request.
Another issue: when the cross batch memory loss is wrapped with the distributed wrapper, the miner cannot be wrapped as well, since the miner already has access to all embeddings, labels, etc. after the all-gather in the loss wrapper (loss = self.loss(all_emb, all_labels, indices_tuple, enqueue_mask)).
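
To illustrate, here is a minimal sketch of what I mean (the miner choice and hyperparameters are just placeholders): the miner is passed unwrapped, and only the loss is wrapped:

from pytorch_metric_learning import losses, miners
from pytorch_metric_learning.utils import distributed as pml_dist

miner = miners.MultiSimilarityMiner()  # left unwrapped: it runs inside the loss
                                       # wrapper on already-gathered tensors
loss_fn = losses.CrossBatchMemory(
    losses.ContrastiveLoss(), embedding_size=128, memory_size=1024, miner=miner
)
loss_fn = pml_dist.DistributedLossWrapper(loss_fn)  # only the loss is wrapped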

I believe this repo is really nice and almost there for distributed training support. Thanks again, and let's make it even better.

lolongcovas commented:

I am facing the same issue. @KevinMusgrave, have you reviewed @zhaoyuac09's PR?

KevinMusgrave (Owner) commented Jan 14, 2024

@lolongcovas It's not passing the existing test. See my comment: #642 (comment)
