Reproducibility when using a model with batch norm #1732
Labels: help wanted (Extra attention is needed)

Comments
BenjaminBossan added a commit to BenjaminBossan/peft that referenced this issue on May 15, 2024:
Fixes huggingface#1732

After loading a model that was trained with PEFT on a base model with some kind of batch norm layer, the loaded model should produce the same output. Right now, this does not happen.

The reason is that during training, buffers for running mean etc. are updated, but they are not saved when calling save_pretrained on the PeftModel instance. Normally in PEFT, we assume that during training, the base model parameters are kept constant, which is not the case with batch norm. We only save the PEFT parameters and assume that when the user loads the base model, all parameters are restored exactly. That way, the information in the buffers is lost completely.

This PR fixes this issue by saving the buffers of the batch norm layers. They are identified by checking for the presence of the track_running_stats attribute.

Note: One test for BOFT is currently failing, see the comment in the test file.
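The detection strategy from that commit can be sketched in plain PyTorch. This is an illustrative sketch, not the actual PEFT implementation; the function name `collect_batch_norm_buffers` is an assumption for this example.

```python
# Sketch of the approach described in the commit above: identify batch norm
# layers by the presence of the `track_running_stats` attribute and collect
# their buffers so they could be stored alongside the adapter weights.
# This is illustrative, not the actual PEFT code.
import torch
import torch.nn as nn

def collect_batch_norm_buffers(model: nn.Module) -> dict:
    """Gather buffers of all layers that track running statistics."""
    buffers = {}
    for module_name, module in model.named_modules():
        # BatchNorm1d/2d/3d (and SyncBatchNorm) all expose this attribute
        if hasattr(module, "track_running_stats"):
            for buffer_name, buffer in module.named_buffers(recurse=False):
                buffers[f"{module_name}.{buffer_name}"] = buffer
    return buffers

model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU())
bn_buffers = collect_batch_norm_buffers(model)
# BatchNorm1d registers running_mean, running_var and num_batches_tracked
print(sorted(bn_buffers))  # ['1.num_batches_tracked', '1.running_mean', '1.running_var']
```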
BenjaminBossan added a commit that referenced this issue on May 22, 2024:
Fixes #1732

After loading a model that was trained with PEFT on a base model with some kind of batch norm layer, the loaded model should produce the same output. Right now, this does not happen.

The reason is that during training, buffers for running mean etc. are updated, but they are not saved when calling save_pretrained on the PeftModel instance. Normally in PEFT, we assume that during training, the base model parameters are kept constant, which is not the case with batch norm. We only save the PEFT parameters and assume that when the user loads the base model, all parameters are restored exactly. That way, the information in the buffers is lost completely.

The fix is to add the batch norm layers to modules_to_save. This fix is now documented and tested.
System Info
Latest version of PEFT
Who can help?
No response
Information
Tasks
An officially supported task in the examples folder

Reproduction
Expected behavior
After loading a model that was trained with PEFT on a base model with some kind of batch norm layer, the loaded model should produce the same output. Right now, this does not happen.
The reason is that during training, buffers for the running mean etc. are updated, but they are not saved when calling `save_pretrained` on the `PeftModel` instance. Normally in PEFT, we assume that the base model parameters are kept constant during training, which is not the case with batch norm. We only save the PEFT parameters and assume that when the user loads the base model, all parameters are restored exactly. As a result, the information in the buffers is lost completely.

One possible solution would be to include the buffers in the PEFT adapter, which is not very pretty. For this to work, we would need a way to identify buffers that were updated vs. those that are static. If someone knows a way to achieve this, or has a better idea of how to fix this, please let us know.
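The problem described above can be demonstrated in a few lines of plain PyTorch: batch norm running statistics are buffers, not parameters, so they keep updating even when every parameter is frozen.

```python
# Minimal demonstration of the issue: even with all parameters frozen (as
# PEFT assumes for the base model), batch norm running statistics are
# buffers, not parameters, and still update on every forward pass in
# training mode.
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)
for param in bn.parameters():
    param.requires_grad = False  # "freeze" the layer, PEFT-style

before = bn.running_mean.clone()
bn.train()
bn(torch.randn(8, 4))  # a single forward pass in training mode
after = bn.running_mean.clone()

# The running mean changed despite all parameters being frozen, so saving
# only the trainable PEFT weights loses this information.
print(torch.equal(before, after))  # False
```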
Edit: Best suggestion so far, by @kashif: check for `track_running_stats` and, if it's `True`, save the module's buffers. This will not cover all possible corner cases, but hopefully most.