distributed training bug fix, to maintain same order for mem queue for all workers #642

Open
wants to merge 8 commits into dev

Conversation

@zhaoyuac09

distributed training bug fix, to maintain same order for mem queue for all workers

@KevinMusgrave changed the base branch from master to dev (June 21, 2023, 01:58)
@KevinMusgrave
Owner

Thanks @zhaoyuac09!

You mentioned you were testing this functionality. Could you add a file tests/utils/test_distributed_xbm_queue.py and paste in your testing code?

@zhaoyuac09
Author

> Thanks @zhaoyuac09!
>
> You mentioned you were testing this functionality. Could you add a file tests/utils/test_distributed_xbm_queue.py and paste in your testing code?

Thanks for checking on the issue! I can add the testing code, but it has to be run in a distributed environment to verify the correct memory queue behavior.

Also, I made another change to the code inside distributed.py. I found that the current implementation can only handle a small memory queue size compared to the official MoCo code (65536). The reason is that in the distributed setting, the current code has each GPU compute the loss over the global batch, which isn't necessary. So I made another modification, entirely within distributed.py, that allows a significantly larger memory queue. I have also verified that the modification reproduces exactly the same results as the current version. Would you like to include this change as well? If so, would another PR be a good idea?
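For readers following along, here is a minimal sketch of the kind of setup under discussion, using the documented pytorch-metric-learning API; the inner loss, embedding size, and temperature are illustrative choices, not taken from this PR:

```python
from pytorch_metric_learning import losses
from pytorch_metric_learning.utils import distributed as pml_dist

# MoCo-style cross-batch memory: embeddings from past iterations are kept in a
# queue of size memory_size and reused as extra negatives.
xbm_loss = losses.CrossBatchMemory(
    loss=losses.NTXentLoss(temperature=0.07),
    embedding_size=128,
    memory_size=65536,  # the MoCo-scale queue size mentioned above
)

# Wrapped for DDP so that every rank enqueues the gathered embeddings in the same order.
loss_fn = pml_dist.DistributedLossWrapper(xbm_loss)
```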

@KevinMusgrave
Owner

KevinMusgrave commented Jun 26, 2023

> Thanks for checking on the issue! I can add the testing code, but it has to be run in a distributed environment to verify the correct memory queue behavior.

Yes please add it.

> Also, I made another change to the code inside distributed.py. I found that the current implementation can only handle a small memory queue size compared to the official MoCo code (65536). The reason is that in the distributed setting, the current code has each GPU compute the loss over the global batch, which isn't necessary. So I made another modification, entirely within distributed.py, that allows a significantly larger memory queue. I have also verified that the modification reproduces exactly the same results as the current version. Would you like to include this change as well? If so, would another PR be a good idea?

I think it's ok to add it to this PR. I guess what you're talking about is allowing CrossBatchMemory to work with efficient=True. Here's the issue for that (you can ignore the suggested solution): #448. Here's an explanation of the efficient flag: https://kevinmusgrave.github.io/pytorch-metric-learning/distributed/#distributedlosswrapper
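For reference, a minimal sketch of how the efficient flag is used, per the linked documentation (ContrastiveLoss is just an example choice, not what this PR uses):

```python
from pytorch_metric_learning import losses
from pytorch_metric_learning.utils import distributed as pml_dist

# efficient=False: each process computes the loss over the gathered (global) batch,
# which keeps gradients closer to the non-distributed case but repeats work on every rank.
# efficient=True: each process uses only its local embeddings as anchors against the
# gathered embeddings, which is cheaper but changes the gradients for a regular loss.
loss_fn = pml_dist.DistributedLossWrapper(losses.ContrastiveLoss(), efficient=True)
```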

@zhaoyuac09
Author

> > Thanks for checking on the issue! I can add the testing code [...]
>
> Yes please add it.
>
> > Also, I made another change to the code inside distributed.py [...]
>
> I think it's ok to add it to this PR. I guess what you're talking about is allowing CrossBatchMemory to work with efficient=True. Here's the issue for that (you can ignore the suggested solution): #448. Here's an explanation of the efficient flag: https://kevinmusgrave.github.io/pytorch-metric-learning/distributed/#distributedlosswrapper

Sure, I will add the testing code later, along with the "efficient" XBM loss wrapper for DDP. However, there is one thing I want to point out: the "efficient" XBM loss wrapper for DDP WILL maintain the same gradient as non-distributed code (similar to the official MoCo implementation), as you will see once I add the code, because I override the loss forward inside the wrapper without scaling the loss by world_size. I'm not sure whether you agree, since the efficient flag on a regular loss does not maintain the same gradient. I ran a fixed-seed test between the efficient XBM loss wrapper and the non-efficient version in the current code, and I get the same weights on the trained model. Since the current non-efficient version maintains the same gradient as non-distributed code, I assume the efficient XBM loss wrapper will also have the same gradient.
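As a sanity check on the no-scaling argument (a sketch of the reasoning, assuming the global loss is the mean of the per-rank losses $L_r$ over $N$ = world_size ranks):

$$
\nabla L_{\text{global}} \;=\; \nabla\!\left(\frac{1}{N}\sum_{r=1}^{N} L_r\right) \;=\; \frac{1}{N}\sum_{r=1}^{N}\nabla L_r ,
$$

which is exactly the average that DDP's gradient all-reduce produces when each rank backpropagates its own unscaled $L_r$; multiplying by world_size would overcount.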

@KevinMusgrave
Owner

If you can make efficient=True have gradients equivalent to the non-distributed version, that is even better! 👍

… efficient xbm loss distributed wrapper will maintain same grad as non-efficient version.
@zhaoyuac09
Author

> If you can make efficient=True have gradients equivalent to the non-distributed version, that is even better! 👍

Just added the changes for both the efficient XBM loss distributed wrapper and the tester to my forked master branch.

In my opinion, it is not possible to maintain the same gradient for a regular loss in the distributed setting, because the gathered embeddings will not have gradients anyway, whether efficient or not. But you mention here (https://kevinmusgrave.github.io/pytorch-metric-learning/distributed/#:~:text=False%3A%20each%20process%20uses%20gathered%20embeddings%20for%20both%20anchors%20and%20positives/negatives.%20Gradients%20will%20be%20equal%20to%20those%20in%20non%2Ddistributed%20code%2C%20but%20at%20the%20cost%20of%20doing%20unnecessary%20operations%20(i.e.%20doing%20computations%20where%20both%20anchors%20and%20positives/negatives%20have%20no%20gradient).) that the non-efficient version maintains the same gradient as non-distributed code. I think mathematically that is not possible: the gathered embeddings have no gradient and are treated as constants during backward, which gives a different expression from the actual gradient. However, the XBM loss is different, because the embeddings in the queue do not need gradients by definition, so the same gradient can be fully realized even in distributed mode. Any thoughts on that? Please correct me if I am wrong.
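To make the point about gathered embeddings concrete, here is a common pattern (a sketch, not this PR's code): tensors filled in by torch.distributed.all_gather do not carry gradients, so the usual workaround is to put the rank's own, grad-carrying tensor back into the gathered list.

```python
import torch
import torch.distributed as dist

def gather_embeddings(local_emb: torch.Tensor) -> torch.Tensor:
    # all_gather fills `gathered` with detached copies from every rank,
    # so those tensors act as constants during backward.
    gathered = [torch.zeros_like(local_emb) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, local_emb)
    # Re-insert the local tensor so this rank's own slice keeps its gradient.
    gathered[dist.get_rank()] = local_emb
    return torch.cat(gathered, dim=0)
```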

@KevinMusgrave
Owner

I think you're right 😆

My assumption must have been based on my existing test where the distributed and non-distributed model parameters are nearly the same (though I guess there are some parameters with a larger discrepancy, which is why I had to set rtol=1e-2 on line 22).

Would it be fair to say the gradients obtained with efficient=False are closer to the real gradients than the gradients obtained with efficient=True? In other words, do you think there is any advantage in efficient=False? If not, we could just get rid of the flag and always use efficient=True behavior.
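For context, a parameter comparison along these lines might look like the sketch below; the repository's actual parameters_are_equal helper may differ.

```python
import torch

def parameters_are_equal(model_a, model_b, rtol=1e-2):
    # Compare corresponding parameters of the non-distributed and distributed models.
    for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
        if not torch.allclose(p_a, p_b.to(p_a.device), rtol=rtol):
            return False
    return True
```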

@zhaoyuac09
Author

> I think you're right 😆
>
> My assumption must have been based on my existing test where the distributed and non-distributed model parameters are nearly the same (though I guess there are some parameters with a larger discrepancy, which is why I had to set rtol=1e-2 on line 22).
>
> Would it be fair to say the gradients obtained with efficient=False are closer to the real gradients than the gradients obtained with efficient=True? In other words, do you think there is any advantage in efficient=False? If not, we could just get rid of the flag and always use efficient=True behavior.

For a regular loss, I think it may be fair to say that efficient=False gives closer gradients, since the loss with efficient=False may stay the same, but mathematically I don't know whether that holds. However, for the XBM loss we should always use efficient=True, since the gradient stays the same. That is why, as you can see, I don't have the flag in the modified version of the distributed loss wrapper.

@KevinMusgrave
Owner

KevinMusgrave commented Jun 30, 2023

Thanks for adding the test_distributed_xbm_queue test. Can you format it so that it's a unittest class? Here's a simple test file you can refer to: https://github.com/KevinMusgrave/pytorch-metric-learning/blob/master/tests/distances/test_collected_stats.py

After you've converted it to use unittest, you can run this command:

python -m unittest tests/utils/test_distributed_xbm_queue.py

and it should print something like "Ran X tests" where X is the number of tests in the unittest class. The name of each test function in the class has to start with test_
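A minimal skeleton for the requested file might look like this (structure only; the test body and any distributed launch logic are left as placeholders):

```python
import unittest

import torch

class TestDistributedXBMQueue(unittest.TestCase):
    def test_queue_order_is_identical_across_ranks(self):
        # Skip on CPU-only machines, mirroring the existing distributed tests.
        if not torch.cuda.is_available():
            self.skipTest("requires GPUs")
        # ... spawn one process per device and assert that every rank
        # ends up with the memory queue in the same order ...

if __name__ == "__main__":
    unittest.main()
```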

@zhaoyuac09
Author

zhaoyuac09 commented Jun 30, 2023

> Thanks for adding the test_distributed_xbm_queue test. Can you format it so that it's a unittest class? Here's a simple test file you can refer to: https://github.com/KevinMusgrave/pytorch-metric-learning/blob/master/tests/distances/test_collected_stats.py
>
> After you've converted it to use unittest, you can run this command:
>
> python -m unittest tests/utils/test_distributed_xbm_queue.py
>
> and it should print something like "Ran X tests" where X is the number of tests in the unittest class. The name of each test function in the class has to start with test_

Sure, I will change that file.

Earlier, I mentioned that the distributed cross-batch regular loss cannot maintain the same gradient as on a single GPU. Actually, I wrote down the math and realized I was wrong: I ran a test and confirmed that the same gradient can be maintained in distributed mode. I apologize for the confusion. The remaining issue is with efficient=True; mathematically, there should be a way to make it have the same gradient as well. I will need more time to investigate your distributed regular loss wrapper code and see why the gradient is not maintained with efficient=True. Please allow me more time, and I will see if I can provide the fixed code and the testers together. Happy July 4th!

@KevinMusgrave
Owner

Thank you for your effort!

@zhaoyuac09
Author

> Thank you for your effort!

Hello @KevinMusgrave, thank you for your patience. I have finished fixing efficient=True for the distributed regular loss. I cannot reproduce the exact same gradient as on a single GPU in the distributed setting; I suspect the difference comes from DDP-internal numerical precision, but I cannot prove it. I can only reproduce the exact results with efficient=False, which does show that the cross-batch loss can be distributed without any compromise.

However, I believe my fix is still valid: 1) for removing self-comparison, curr_batch_idx should be the indices of the current mini-batch inside the global batch; 2) similar to the efficient cross-batch memory wrapper, the loss does not need to be scaled by world_size.

I would really appreciate it if you could also test the fix for efficient=True on your side. And if possible, please let me know if you can explain why the gradient is not the same with efficient=True; I thought mathematically it should be.
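To make point 1) concrete, here is a sketch of the indexing idea (names are illustrative, not this PR's exact code): after the per-rank embeddings are gathered and concatenated in rank order, the current mini-batch occupies a contiguous slice of the global batch.

```python
import torch
import torch.distributed as dist

def local_indices_in_global_batch(local_batch_size: int) -> torch.Tensor:
    # This rank's samples occupy [rank * B, (rank + 1) * B) in the concatenated
    # global batch, so these are the indices to use when masking out
    # self-comparisons (curr_batch_idx in the discussion above).
    rank = dist.get_rank()
    return torch.arange(local_batch_size) + rank * local_batch_size
```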

@KevinMusgrave
Owner

KevinMusgrave commented Jul 21, 2023

Do you know if it's possible to test distributed training on CPU-only machines? It'd be nice to have the distributed tests run as part of the github workflow. Right now I skip it:

if TEST_DEVICE == torch.device("cpu"):
    return

But I do set the distributed backend to gloo if it's CPU:

dist_type = "gloo" if TEST_DEVICE == torch.device("cpu") else "nccl"

@zhaoyuac09
Author


> Do you know if it's possible to test distributed training on CPU-only machines? It'd be nice to have the distributed tests run as part of the github workflow. Right now I skip it:
>
>     if TEST_DEVICE == torch.device("cpu"):
>         return
>
> But I do set the distributed backend to gloo if it's CPU:
>
>     dist_type = "gloo" if TEST_DEVICE == torch.device("cpu") else "nccl"

I think testing on CPU is possible. I have tested on CPU using "gloo", but only with 1 CPU process, and the results are reproduced exactly. I'm not sure how to run the CPU test with world_size > 1, though.

Regarding my previous concerns about the "efficient" flag: I can conclude that for the cross-batch memory loss we should always use efficient, which is reflected in my latest commit, since the efficient and non-efficient versions reproduce exactly the same results. From a gradient point of view they are equivalent, because the memory bank does not require gradients. For the regular cross-batch loss we can still keep the "efficient" flag, since the gradients will never be the same in distributed settings: both items in each pair require gradients, but the gathered items lose theirs.
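On running the CPU test with world_size > 1: multiple worker processes can be spawned on one machine with the gloo backend. A minimal sketch follows (the port and world size are arbitrary choices, and the actual queue check is left as a placeholder):

```python
import os

import torch.distributed as dist
import torch.multiprocessing as mp

def _worker(rank: int, world_size: int):
    # Each spawned process joins the same gloo process group over localhost.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # ... run the distributed XBM queue check on this rank ...
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = 2  # more than one CPU process, no GPU required
    mp.spawn(_worker, args=(world_size,), nprocs=world_size, join=True)
```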

@KevinMusgrave
Owner

KevinMusgrave commented Dec 11, 2023

When I run python -m unittest tests.utils.test_distributed.TestDistributedLossWrapper.test_distributed_tuple_loss

I'm getting large differences in the distributed vs non-distributed model parameters, specifically in the XBM case:

tensor(0.0016, device='cuda:1')
tensor(0.0016, device='cuda:0')
tensor(0.0009, device='cuda:1')
tensor(0.0009, device='cuda:0')
tensor(0.0009, device='cuda:1')
tensor(0.0022, device='cuda:1')
tensor(0.0009, device='cuda:0')
tensor(0.0022, device='cuda:0')

Whereas in the old distributed code I get:

tensor(2.9802e-08, device='cuda:1')
tensor(2.9802e-08, device='cuda:0')
tensor(2.9802e-08, device='cuda:1')
tensor(2.9802e-08, device='cuda:0')
tensor(2.9802e-08, device='cuda:0')
tensor(2.9802e-08, device='cuda:1')
tensor(5.9605e-08, device='cuda:1')
tensor(5.9605e-08, device='cuda:0')

@KevinMusgrave
Owner

KevinMusgrave commented Dec 11, 2023

I pushed a change with some print statements in test_distributed for debugging, in case you want to test it out.

Basically what test_distributed checks is that after computing a loss and updating parameters for multiple iterations, the distributed model (ddp_mp_model) parameters are equal to the non-distributed (original_model) parameters. Here are lines 99-128 of test_distributed:

    for i in range(iterations):
        optimizer.zero_grad()
        outputs = ddp_mp_model(inputs[i][rank].to(device))
        curr_labels = labels[i][rank]
        ref_outputs, curr_ref_labels = None, None
        if ref_inputs:
            ref_outputs = ddp_mp_model(ref_inputs[i][rank].to(device))
            curr_ref_labels = ref_labels[i][rank]
        indices_tuple = None
        if miner_fn:
            indices_tuple = miner_fn(outputs, curr_labels, ref_outputs, curr_ref_labels)
        if miner_fn and not pass_labels_to_loss_fn:
            loss = loss_fn(outputs, indices_tuple=indices_tuple, ref_emb=ref_outputs)
        elif use_xbm_enqueue_mask and isinstance(loss_fn.loss, CrossBatchMemory):
            loss = loss_fn(
                outputs, curr_labels, indices_tuple, enqueue_mask=enqueue_mask[rank]
            )
        else:
            loss = loss_fn(
                outputs, curr_labels, indices_tuple, ref_outputs, curr_ref_labels
            )

        dist.barrier()
        loss.backward()
        dist.barrier()
        optimizer.step()

    dist.barrier()
    print("assert equal")
    assert parameters_are_equal(original_model, ddp_mp_model.module)

Development

Successfully merging this pull request may close these issues.

Bug for distributed wrapper regarding to cross batch memory loss