[WIP] Add LoRA multihead attention module #1324
Conversation
For now, only works with _qkv_same_embed_dim=True.
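For context, a plain-PyTorch illustration (not part of the PR): `_qkv_same_embed_dim` is True when query, key, and value share the same embedding dimension, in which case nn.MultiheadAttention stores a single fused in_proj_weight instead of separate q/k/v projection weights.

```python
import torch.nn as nn

# When embed_dim == kdim == vdim, q/k/v share a single fused projection:
mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)
assert mha._qkv_same_embed_dim
assert mha.in_proj_weight.shape == (3 * 8, 8)  # fused q/k/v weight

# With differing key/value dims, separate projection weights are used instead:
mha2 = nn.MultiheadAttention(embed_dim=8, num_heads=2, kdim=4, vdim=4)
assert not mha2._qkv_same_embed_dim
assert mha2.in_proj_weight is None
assert mha2.q_proj_weight.shape == (8, 8)
```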
This is no longer necessary when unloading the model because the base_layer is already the original layer. This is just a leftover from before we adopted the base_layer pattern.
There was a bug because the removal of the parameter resulted in it no longer appearing in the state_dict and named_parameters. This commit fixes this bug. The bug also exists in the referenced lora-torch library.
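A minimal sketch of the underlying behavior (plain PyTorch, module name assumed): deleting a registered parameter removes it from both state_dict and named_parameters until it is re-registered, which is what caused the bug described above.

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(2, 2))

m = Demo()
assert "weight" in m.state_dict()
assert "weight" in dict(m.named_parameters())

# Temporarily removing the parameter (as the merge logic does) makes it
# disappear from both state_dict and named_parameters:
del m.weight
assert "weight" not in m.state_dict()
assert "weight" not in dict(m.named_parameters())

# Re-registering the parameter restores it:
m.weight = nn.Parameter(torch.zeros(2, 2))
assert "weight" in m.state_dict()
```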
Nice work! I left a few preliminary comments. I think we can go with the _restore_weights approach for now, as I don't see any other alternative.
src/peft/tuners/lora/layer.py (outdated)

    lora_alpha: int = 1,
    lora_dropout: float = 0.0,
    fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
    is_target_conv_1d_layer: bool = False,
    is_target_conv_1d_layer: bool = False,

I don't think this is used?
src/peft/tuners/lora/layer.py (outdated)

    self._active_adapter = adapter_name
    self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)
    self.is_target_conv_1d_layer = is_target_conv_1d_layer
    self.is_target_conv_1d_layer = is_target_conv_1d_layer

We can also just hard-code it to False.
    self._restore_weights()
    return super().state_dict(*args, **kwargs)

    def named_modules(self, *args, **kwargs):
Do we also need to overwrite the modules() method?
Not needed, as modules() calls named_modules() under the hood. I added a comment to that effect.
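A small sketch (not from the PR) illustrating why overriding only named_modules suffices: nn.Module.modules delegates to named_modules, so an override of the latter is picked up by both.

```python
import torch.nn as nn

class Tracked(nn.Module):
    def __init__(self):
        super().__init__()
        self.calls = 0  # counts how often named_modules() is invoked

    def named_modules(self, *args, **kwargs):
        self.calls += 1
        return super().named_modules(*args, **kwargs)

t = Tracked()
found = list(t.modules())
assert t.calls == 1  # modules() delegated to our named_modules() override
assert found == [t]
```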
@@ -193,11 +193,6 @@ def _replace_module(self, parent, child_name, new_module, child):
    if hasattr(child, "base_layer"):
        child = child.base_layer

    if not hasattr(new_module, "base_layer"):
Why has this been removed?
Sorry, forgot to put this into the description of the PR.

These lines have been obsolete for some time now. They only apply when we unload the model (otherwise, the if does not match). Remember, when we made the base_layer switch, we ensured that when unloading, we simply return the base_layer; there is no more need to create a new layer (say, a new nn.Linear when using lora.Linear) and replace the new layer's weight with the parent layer's weight. The base_layer already has the original weight. Therefore, these lines are unnecessary.

I removed them now because they were annoying with MultiheadAttention: that layer has no weight attribute, so this line would fail.
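A hypothetical, stripped-down sketch of the base_layer pattern described above (the real lora.Linear carries adapters, dropout, etc.; class and function names here are assumed for illustration):

```python
import torch.nn as nn

class LoraWrapper(nn.Module):
    """Hypothetical, simplified wrapper following the base_layer pattern."""
    def __init__(self, base_layer):
        super().__init__()
        self.base_layer = base_layer  # original layer, weights untouched
        # ... LoRA parameters would be registered here ...

def unload(module):
    # Unloading just returns the stored original layer; no fresh nn.Linear
    # is created and no weight copying is needed.
    return module.base_layer

lin = nn.Linear(4, 4)
wrapped = LoraWrapper(lin)
assert unload(wrapped) is lin  # identical object, original weight intact
```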
- Some clarifying comments
- Remove fan_in_fan_out

Also:
- Raise proper error instead of assert
Thank you Benjamin for adding support for the torch MHA layer in LoRA, interesting way to use the merge, forward, and unmerge logic!
@younesbelkada Did I address all your concerns? I pinged the user who wanted to test it on their use case. When it comes to docs, I didn't really find a place where we list all supported layers, so no update is needed really.
Before, LoRA was applied only to the in_proj. Now it is also applied to the out_proj.

Unfortunately, there is no easy way to just apply a normal lora.Linear to the out_proj by targeting it with target_modules. If that worked, it would be much nicer, so that users could decide for themselves whether to apply LoRA to the out_proj or not. The reason why it doesn't work is twofold:

1. We cannot really control the order in which LoRA is applied, so when the LoRA adapter is injected into out_proj, the whole MHA layer may already be wrapped by lora.MultiheadAttention.
2. Even if we successfully applied a normal lora.Linear to the out_proj, it would not work correctly. This is because the forward method of out_proj is not used at all by nn.MultiheadAttention. Instead, it just passes the weight and bias to F.multi_head_attention_forward. Therefore, we must ensure that the weights are merged and unmerged correctly, same as for in_proj, and we cannot do that if we use a normal lora.Linear.

Note that the test test_merge_layers for MHA fails. This is most likely because of an existing bug in how merging is implemented, see PR huggingface#1355. Once that is merged, the test should pass.
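Point 2 can be checked directly with a small plain-PyTorch demonstration (not from the PR): patching out_proj.forward has no effect on the attention output, confirming that nn.MultiheadAttention consumes only its weight and bias.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)

# If nn.MultiheadAttention called out_proj.forward, this patch would change
# the output shape to (1,). It does not, because forward() only passes
# out_proj.weight and out_proj.bias to F.multi_head_attention_forward.
mha.out_proj.forward = lambda *args, **kwargs: torch.zeros(1)

x = torch.randn(3, 1, 8)
out, _ = mha(x, x, x)
assert out.shape == (3, 1, 8)  # patched forward was never used
```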
just wanted to bump this one because it's really the only way for tuning CLIP models after they are released.
@bghira Do you happen to have a use case where you could test whether this PR works and is fast enough? I think the implementation could be ready to be merged, but ideally we'd have someone with a real use case give it a try.
i do and i may be able to test it. stupid question, but is the code example above complete? i don't see the hinge loss function
You mean the code right at the top? No, it's not complete at all, just a quick test to show that MHA is applied and the backward pass does not fail. This is neither proper nor complete training code.
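For readers without access to that snippet, a hypothetical sketch of such a smoke test (plain PyTorch; the PEFT wrapping step is only indicated in a comment, with config names assumed and not verified against the PR):

```python
import torch
import torch.nn as nn

class TinyClf(nn.Module):
    """Tiny model with an MHA layer. In the actual test one would wrap it,
    e.g. get_peft_model(model, LoraConfig(target_modules=["attn"], r=8)),
    before running forward/backward (assumed names, for illustration only)."""
    def __init__(self):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=8, num_heads=2)
        self.head = nn.Linear(8, 2)

    def forward(self, x):
        out, _ = self.attn(x, x, x)
        return self.head(out.mean(dim=0))

model = TinyClf()
x = torch.randn(5, 1, 8)  # (seq_len, batch, embed_dim)
loss = model(x).sum()
loss.backward()  # the smoke test only checks that this does not fail
assert model.attn.in_proj_weight.grad is not None
```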
@BenjaminBossan Hi Ben, thank you for directing me here; it seems like the exact issue I am looking for. Since this function has not been officially merged into the main branch yet, could you kindly let me know what the config will look like for the multihead LoRA (peft_model = get_peft_model(model, config))? I hope to receive some instructions and test this function soon! I'm very much looking forward to it! Here are the current issues I met. I ran my code with
I found a few warnings, and the performance degradation was extremely dramatic. I will dive into this issue. Loading evaluator: Classification
Not sure if it's a problem with the lora_alpha parameter, since it works fine when lora_alpha=1. However, choosing 2*rank seems to destroy the model's performance. Perhaps a bigger alpha is not a good fit for the standard CLIP model.
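For context on why lora_alpha matters so much here: LoRA scales its weight update as delta_W = (lora_alpha / r) * B A, so relative to lora_alpha=1 a value of 2*r multiplies the update by 2*r. A quick illustration of the scaling factor alone:

```python
# The lora_alpha / r scaling factor applied to the LoRA update delta_W.
r = 8
scalings = {alpha: alpha / r for alpha in (1, r, 2 * r)}

assert scalings[1] == 0.125        # lora_alpha = 1
assert scalings[r] == 1.0          # lora_alpha = r leaves the update unscaled
assert scalings[2 * r] == 2.0      # lora_alpha = 2 * r doubles it
# lora_alpha = 2 * r yields a 16x larger update than lora_alpha = 1:
assert scalings[2 * r] / scalings[1] == 16.0
```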
Honestly, I don't know. About these warnings:
Those are definitely strange, as it means the script tried to merge and unmerge some layers, which normally shouldn't happen. You should check your training code for suspicious lines related to merging.
Hi Ben, thank you again for your prompt reply! For lora_alpha=1, it indeed increases model performance in classification, so I think maybe MHA only fits with small scaling. And I did additional research here; if you've heard about MultiLoRA (paper link), the scaling factor, even when trained from zero, might cause problems. I guess MHA is extremely sensitive to scaling factors. As for the warnings, I'm unsure if they are related to my training code, since I use the Dassl framework for the training process. But it seems it does not affect how LoRA improves performance. Really appreciate this work!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
not stale
First stab at adding LoRA support for nn.MultiheadAttention. See #761.

Todos:
- For now, only works with _qkv_same_embed_dim=True -- make it work with False too. Update: _qkv_same_embed_dim=False is out of scope for this PR and can be added in a later PR if needed.
- Docs: Apart from docstrings, I don't think anything else needs to be added.

Update: I now also included the out_proj to apply LoRA to.

This is a simple test that I ran successfully with the PR in its current state: