DOC TST Document and test reproducibility with models using batch norm #1734

Conversation

BenjaminBossan
Member

Fixes #1732

After loading a model that was trained with PEFT on a base model with some kind of batch norm layer, the loaded model should produce the same output. Right now, this does not happen.

The reason is that during training, buffers for running mean etc. are updated, but they are not saved when calling save_pretrained on the PeftModel instance. Normally in PEFT, we assume that during training, the base model parameters are kept constant, which is not the case with batch norm. We only save the PEFT parameters and assume that when the user loads the base model, all parameters are restored exactly. That way, the information in the buffers is lost completely.

This PR fixes this issue by saving the buffers of the batch norm layers. They are identified by checking for the presence of the track_running_stats attribute.
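
For illustration, here is a minimal sketch of the buffer-collection approach described above (the function name and overall structure are illustrative, not the actual PEFT implementation):

import torch

def collect_batch_norm_buffers(model: torch.nn.Module) -> dict:
    # Collect running-stats buffers (running_mean, running_var, num_batches_tracked)
    # from modules that expose track_running_stats, i.e. BatchNorm* layers.
    buffers = {}
    for module_name, module in model.named_modules():
        if not hasattr(module, "track_running_stats"):
            continue
        for buffer_name, buffer in module.named_buffers():
            buffers[f"{module_name}.{buffer_name}"] = buffer
    return buffers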

Note: One test for BOFT is currently failing, see the comment in the test file.


# currently, we only deal with running stats from BatchNorm* modules
if not hasattr(module, "track_running_stats"):
    continue
for buffer_name, buffer in module.named_buffers():

Contributor

Do we want to check whether track_running_stats is set to True?

Member Author

Good point. I considered this, but was hesitant. Say someone trains a model with tracking enabled, and then turns it off for some reason before saving. Then we would not save these buffers even though they're needed, do I see that right? Ideally, we would have a check if they changed vis-à-vis the base model, but I don't see a way to monitor this except for storing a copy of all these buffers, requiring extra memory.

I guess we could decide only to save them if getattr(module, "track_running_stats", None) is True, and issue a warning that they're not saved if getattr(module, "track_running_stats", None) is False (and for None, just ignore).
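
A rough sketch of that alternative, checking the attribute's value explicitly (names are illustrative, not the PEFT implementation):

import warnings

import torch

def collect_tracked_buffers(model: torch.nn.Module) -> dict:
    buffers = {}
    for module_name, module in model.named_modules():
        track = getattr(module, "track_running_stats", None)
        if track is True:
            # tracking enabled: save running_mean, running_var, num_batches_tracked
            for buffer_name, buffer in module.named_buffers():
                buffers[f"{module_name}.{buffer_name}"] = buffer
        elif track is False:
            # tracking was turned off: warn that the running stats are not saved
            warnings.warn(f"{module_name} has track_running_stats=False; its buffers are not saved")
        # track is None: not a batch norm layer, nothing to do
    return buffers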


Contributor

Yes, true. OK, it's safer to save them, you're right: someone could turn tracking off afterwards, and then the layer would use the batch's summary statistics in inference mode.

@kashif
Contributor

kashif commented May 15, 2024

@BenjaminBossan when merging the weights, I suppose it currently works for nn.Parameters, but now the same also needs to be done for the buffers. How is that handled?

@BenjaminBossan
Member Author

when merging the weights, I suppose it currently works for nn.Parameters, but now the same also needs to be done for the buffers. How is that handled?

Hmm, not sure if I understand. If I train a LoRA adapter and then merge it, its weights will be fused with the base model weights. When we load the LoRA adapter, the base model's running stats are replaced by whatever is saved in the LoRA checkpoint. As the running stats buffers are part of the base model and no LoRA is applied to them, they are not further affected by merging.

Based on your comment, I could think of another problematic case though: A user adds 2 adapters, first adapter A, then B. Let's call the running stats buffers rs. They train A first, updating the running stats to rs_A. Then they switch to B and train, which updates the running stats further to rs_A_B. When they now save the adapter, we will store rs_A_B for both adapters, when in reality we want rs_A and rs_B.

There are probably more edge cases that result from the fact that we implicitly assume that only the PEFT parameters ever change, whereas the base model parameters are fixed. I think for this particular one, we can accept this failure case for now, as the scenario I describe should be very rare (users would typically train A and B separately, not in the same process).
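
A toy illustration of why this scenario is lossy: the base model holds a single set of running-stats buffers, so after training A and then B in the same process, only the combined statistics remain (plain PyTorch, no PEFT involved):

import torch

bn = torch.nn.BatchNorm1d(4)
bn.train()

# "training adapter A": the running stats drift towards A's data distribution
for _ in range(10):
    bn(torch.randn(8, 4) + 1.0)
rs_A = bn.running_mean.clone()

# "training adapter B" in the same process: the same buffers keep updating
for _ in range(10):
    bn(torch.randn(8, 4) - 1.0)
rs_A_B = bn.running_mean.clone()

# saving now could only store rs_A_B for both adapters; rs_A is no longer available
print(rs_A)
print(rs_A_B)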

@BenjaminBossan
Member Author

Oh, now I wonder if there isn't a different solution: Ask users to add the batch norm layers to modules_to_save. I haven't thought this through completely, but this may solve all the problems we discussed with batch norm. The disadvantage is that users need to explicitly set modules_to_save.

@kashif
Contributor

kashif commented May 15, 2024

Yes, that was my initial solution, I believe. Sorry, I was thinking of something else.

@BenjaminBossan
Member Author

Just tested the modules_to_save solution: with the config LoraConfig(target_modules=["convolution"], modules_to_save=["classifier", "normalization"]), the test passes even without changing the code to explicitly add buffers to the state_dict.
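
For reference, a minimal sketch of this workaround (the base model checkpoint is an assumption chosen to match the module names above, not necessarily the one used in the test):

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForImageClassification

# assumed base model: a ResNet-style image classifier with "convolution",
# "normalization" (batch norm), and "classifier" module names
base_model = AutoModelForImageClassification.from_pretrained("microsoft/resnet-18")
config = LoraConfig(
    target_modules=["convolution"],
    # saving the batch norm layers ("normalization") as full copies keeps
    # their running stats in the PEFT checkpoint
    modules_to_save=["classifier", "normalization"],
)
peft_model = get_peft_model(base_model, config)
# ... train ...
peft_model.save_pretrained("my-batchnorm-adapter")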


@pacman100 (Contributor) left a comment

Thank you @BenjaminBossan for the nice minimal reproduction of the issue with batch norm and this PR to fix it. After going through the thread I see that passing batch norm layers to modules_to_save fixes this issue, right?

No need to add extra code to save buffers in the checkpoint.
@BenjaminBossan
Member Author

@pacman100 I removed the new functionality and instead changed the test to add batch norm layers to modules_to_save and the tests pass. I also added a troubleshooting section about this.


@pacman100 (Contributor) left a comment

Thank you, @BenjaminBossan, for adding detailed docs on handling BatchNorm, comments on the different sections of the checkpoint-saving utility, and extensive tests! ✨

@BenjaminBossan BenjaminBossan changed the title FIX Store batch norm buffers in PEFT checkpoint DOC TST Document and test reproducibility with models using batch norm May 22, 2024
@BenjaminBossan BenjaminBossan merged commit 1fec231 into huggingface:main May 22, 2024
14 checks passed
@BenjaminBossan BenjaminBossan deleted the fix-checkpoint-for-batchnorm-models branch May 22, 2024 08:43
# TODO: cannot use BOFT because some convolutional kernel dimensions are even (64) and others odd (147). There is no
# common denominator for the boft_block_size except 1, but using 1 results in an error in the fbd_cuda kernel:
# > Error in forward_fast_block_diag_cuda_kernel: an illegal memory access was encountered
# "boft": BOFTConfig(target_modules=["convolution"], modules_to_save=["classifier", "normalization"], boft_block_size=2),

Member Author

@yfeng95 @Zeju1997 @YuliangXiu Any idea how I could fix the described issue?

Successfully merging this pull request may close these issues:

Reproducibility when using a model with batch norm