
[WIP] Add LoRA multihead attention module #1324

Conversation

@BenjaminBossan (Member) commented Jan 5, 2024

First stab at adding LoRA support for nn.MultiheadAttention. See #761.

Todos:

  • For now, this only works with _qkv_same_embed_dim=True. Support for _qkv_same_embed_dim=False is out of scope for this PR and can be added in a later PR if needed.
  • Show that it works in a real world test.
  • Unit tests
  • Docs: apart from docstrings, I don't think anything else needs to be added.

Update: I now also included the out_proj to apply LoRA to.

This is a simple test that I ran successfully with the PR in its current state:

import open_clip
import requests
import torch
from torch import nn
from peft import LoraConfig, get_peft_model
from PIL import Image
from peft.tuners.lora.layer import MultiheadAttention as PeftMha

model, preprocess = open_clip.create_model_from_pretrained('hf-hub:laion/CLIP-ViT-g-14-laion2B-s12B-b42K')
tokenizer = open_clip.get_tokenizer('hf-hub:laion/CLIP-ViT-g-14-laion2B-s12B-b42K')
# Assumed config for this test (default r=8); "attn" matches the nn.MultiheadAttention modules in open_clip
config = LoraConfig(target_modules=["attn"])
peft_model = get_peft_model(model, config)
opt = torch.optim.SGD(peft_model.parameters(), 0.1)
print(len([m for m in peft_model.modules() if isinstance(m, PeftMha)]))  # 64 PEFT MHA layers
peft_model.print_trainable_parameters()  # trainable params: 2,588,672 || all params: 1,055,873,793 || trainable%: 0.24516869508096598

# text encoder
text = tokenizer(["a diagram", "a dog", "a cat"])
text_features = peft_model.encode_text(text)
loss = text_features.sum()
loss.backward()
opt.step()

# image encoder
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
image = preprocess(image).unsqueeze(0)
image_features = model.encode_image(image)
image_features.sum().backward()
opt.step()

For now, only works with _qkv_same_embed_dim=True.
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

This is no longer necessary when unloading the model because the
base_layer is already the original layer. This is just a leftover
from before we adopted the base_layer pattern.
There was a bug because the removal of the parameter resulted in it no
longer appearing in the state_dict and named_parameters. This commit
fixes this bug.

The bug also exists in the referenced lora-torch library.
@younesbelkada (Collaborator) left a comment:


Nice work! I left a few preliminary comments. I think we can go with the _restore_weights approach for now, as I don't see any other alternative.

lora_alpha: int = 1,
lora_dropout: float = 0.0,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
is_target_conv_1d_layer: bool = False,
Collaborator:

Suggested change (remove this line):
is_target_conv_1d_layer: bool = False,

I don't think this is used?


self._active_adapter = adapter_name
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)
self.is_target_conv_1d_layer = is_target_conv_1d_layer
Collaborator:

Suggested change (remove this line):
self.is_target_conv_1d_layer = is_target_conv_1d_layer

We can also just hard-code it to False

self._restore_weights()
return super().state_dict(*args, **kwargs)

def named_modules(self, *args, **kwargs):
Collaborator:

Do we also need to overwrite the modules() method?

Member (Author):

Not needed, as modules() calls named_modules() under the hood. I added a comment to that effect.
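For reference, here is a minimal sketch (illustrative only, not the PR's code) of why overriding named_modules() is enough: torch's nn.Module.modules() iterates via named_modules() under the hood, so a hook placed in named_modules() also runs when modules() is called.

from torch import nn

class RestoreDemo(nn.Module):
    # Toy stand-in for a layer that needs to restore weights before iteration.

    def _restore_weights(self):
        # Placeholder for the real weight-restoring logic.
        print("_restore_weights called")

    def named_modules(self, *args, **kwargs):
        self._restore_weights()
        return super().named_modules(*args, **kwargs)

# modules() is implemented on top of named_modules(), so this prints once:
list(RestoreDemo().modules())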

@@ -193,11 +193,6 @@ def _replace_module(self, parent, child_name, new_module, child):
if hasattr(child, "base_layer"):
child = child.base_layer

if not hasattr(new_module, "base_layer"):
Collaborator:

Why has this been removed?

Member (Author):

Sorry, I forgot to put this into the PR description.

These lines have been obsolete for some time now. They only apply when we unload the model (otherwise, the if does not match). Remember, when we made the base_layer switch, we ensured that unloading simply returns the base_layer; there is no longer any need to create a new layer (say, a new nn.Linear when using lora.Linear) and replace the new layer's weight with the parent layer's weight. The base_layer already has the original weight, so these lines are unnecessary.

I removed them now because they were getting in the way for MultiheadAttention: that layer has no weight attribute, so this line would fail.
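As a rough illustration of the base_layer pattern described above (a sketch with a made-up helper name, not PEFT's actual implementation), unloading essentially amounts to handing back the stored original module instead of rebuilding one and copying weights:

from torch import nn

def unload_module(module: nn.Module) -> nn.Module:
    # A LoRA wrapper keeps the untouched original layer as `base_layer`,
    # so unloading can simply return it; plain modules are returned as-is.
    return getattr(module, "base_layer", module)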

@pacman100 (Contributor) left a comment:

Thank you, Benjamin, for adding support for the torch MHA layer in LoRA. Interesting way to use the merge, forward and unmerge logic!

@BenjaminBossan (Member, Author):

@younesbelkada Could I address all your concerns?

I pinged the user who wanted to test it on their case. When it comes to docs, I didn't really find a place where we list all supported layers, so no update needed really.

Before, LoRA was applied only to the in_proj. Now it is also applied to
the out_proj.

Unfortunately, there is no easy way to just apply a normal lora.Linear
to the out_proj by targeting it with target_modules. If that worked, it
would be much nicer to do that, so that users can decide for themselves
if they want to apply LoRA to the out_proj or not.

The reason why it doesn't work is twofold:

1. We cannot really control the order in which LoRA is applied, so when
   the LoRA adapter is injected to out_proj, the whole MHA layer may
   already be wrapped by lora.MultiheadAttention.
2. Even if we successfully applied a normal lora.Linear to the out_proj,
   it would not work correctly. This is because the forward method of
   out_proj is not used at all by nn.MultiheadAttention. Instead, it
   just passes the weight and bias to F.multi_head_attention_forward.
   Therefore, we must ensure that the weights are merged and unmerged
   correctly, same as for in_proj, and we cannot do that with a normal
   lora.Linear (see the sketch below).
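A minimal sketch of this merge/forward/unmerge idea for the packed in_proj weight (illustrative only, assuming _qkv_same_embed_dim=True; the class and attribute names below are made up, and this is not the PR's actual code, which also handles out_proj, multiple adapters, and state_dict restoration):

import torch
from torch import nn

class LoraMHASketch(nn.Module):
    def __init__(self, base: nn.MultiheadAttention, r: int = 8, lora_alpha: int = 8):
        super().__init__()
        self.base_layer = base
        embed_dim = base.embed_dim
        # LoRA factors for the packed q/k/v projection; in_proj_weight has shape (3*embed_dim, embed_dim).
        self.lora_A = nn.Parameter(torch.randn(r, embed_dim) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(3 * embed_dim, r))
        self.scaling = lora_alpha / r

    def forward(self, query, key, value, **kwargs):
        base = self.base_layer
        orig_weight = base.in_proj_weight  # the frozen nn.Parameter
        delta = (self.lora_B @ self.lora_A) * self.scaling
        # "Merge": temporarily replace the parameter with a plain tensor that
        # includes the LoRA delta, so F.multi_head_attention_forward sees the
        # merged weight and gradients still flow into lora_A / lora_B.
        del base.in_proj_weight
        base.in_proj_weight = orig_weight + delta
        try:
            return base(query, key, value, **kwargs)
        finally:
            # "Unmerge": restore the original parameter.
            base.in_proj_weight = orig_weight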

Note that the test test_merge_layers for MHA fails. This is most likely
because of an existing bug in how merging is implemented, see PR huggingface#1355.
Once that is merged, the test should pass.
@BenjaminBossan (Member, Author):

Note: The test test_merge_layers for MHA fails. This is most likely because of an existing bug in how merging is implemented, see PR #1355. Once that is merged, the test should pass.

@ambroser53

Just want to bump a bunch of the issues I've mentioned in #761, but specifically the problem with requires_grad, reproducible in this repo.

@bghira commented Feb 26, 2024

Just wanted to bump this one because it's really the only way to tune CLIP models after they are released.

@BenjaminBossan (Member, Author):

@bghira Do you happen to have a use case where you could test if this PR works and is working well enough speed-wise? I think the implementation could be ready to be merged but ideally we'd have someone with a real use case give it a try.

@bghira commented Feb 26, 2024

I do, and I may be able to test it. Stupid question, but is the code example above complete? I don't see the hinge loss function.

@BenjaminBossan (Member, Author):

> Stupid question, but is the code example above complete? I don't see the hinge loss function.

You mean the code right at the top? No, it's not complete at all, just a quick test to show that MHA is applied and the backward pass does not fail. It is neither proper nor complete training code.

@PoopBear1 commented Mar 27, 2024

@BenjaminBossan Hi Ben,

Thank you for directing me here; it seems like the exact issue I am looking for. Since this function has not been officially merged into the main branch yet, could you kindly let me know what the config will look like for the multihead LoRA? (peft_model = get_peft_model(model, config)).

I hope to receive some instructions and test this function soon! I'm very much looking forward to it!

Here are the current issues I met.

I ran my code with:

lora_config = LoraConfig(
        r=12,
        lora_alpha=24,
        target_modules=["attn"],
        lora_dropout=0.05,
        bias="none"
    )

I found a few warnings, and the performance degradation was extremely dramatic. I will dive into this issue.

Loading evaluator: Classification
No checkpoint found, train from scratch
Initialize tensorboard (log_dir=......./tensorboard)
/home/isaac/anaconda3/envs/lib/python3.8/site-packages/peft/tuners/tuners_utils.py:711: UserWarning: All adapters are already merged, nothing to do.
warnings.warn("All adapters are already merged, nothing to do.")
/home/isaac/anaconda3/envs/lib/python3.8/site-packages/peft/tuners/lora/layer.py:439: UserWarning: Already unmerged. Nothing to do.
warnings.warn("Already unmerged. Nothing to do.")
epoch [1/5] batch [20/204] time 0.177 (0.218) data 0.000 (0.011) loss 0.6667 (0.8238) lr 1.0000e-05 eta 0:03:37
epoch [1/5] batch [40/204] time 0.178 (0.198) data 0.000 (0.006) loss 1.2822 (0.8632) lr 1.0000e-05 eta 0:03:14
epoch [1/5] batch [60/204] time 0.178 (0.192) data 0.000 (0.004) loss 1.2055 (0.7797) lr 1.0000e-05 eta 0:03:03
epoch [1/5] batch [80/204] time 0.178 (0.188) data 0.000 (0.003) loss 0.1426 (0.8225) lr 1.0000e-05 eta 0:02:57
epoch [1/5] batch [100/204] time 0.178 (0.186) data 0.000 (0.002) loss 0.1367 (0.7533) lr 1.0000e-05 eta 0:02:51
epoch [1/5] batch [120/204] time 0.179 (0.185) data 0.000 (0.002) loss 0.1386 (0.7612) lr 1.0000e-05 eta 0:02:46
epoch [1/5] batch [140/204] time 0.179 (0.184) data 0.000 (0.002) loss 0.1560 (0.7837) lr 1.0000e-05 eta 0:02:42
epoch [1/5] batch [160/204] time 0.179 (0.184) data 0.000 (0.002) loss 0.1206 (0.7639) lr 1.0000e-05 eta 0:02:37
epoch [1/5] batch [180/204] time 0.179 (0.183) data 0.000 (0.001) loss 2.1940 (0.7720) lr 1.0000e-05 eta 0:02:33
epoch [1/5] batch [200/204] time 0.178 (0.183) data 0.000 (0.001) loss 0.4894 (0.7809) lr 1.0000e-05 eta 0:02:29
epoch [2/5] batch [20/204] time 0.178 (0.188) data 0.000 (0.009) loss 3.9319 (1.8420) lr 3.5000e-03 eta 0:02:29
epoch [2/5] batch [40/204] time 0.178 (0.183) data 0.000 (0.005) loss 4.0332 (3.0835) lr 3.5000e-03 eta 0:02:22
epoch [2/5] batch [60/204] time 0.178 (0.182) data 0.000 (0.003) loss 4.0524 (3.3984) lr 3.5000e-03 eta 0:02:17
epoch [2/5] batch [80/204] time 0.178 (0.181) data 0.000 (0.002) loss 4.0278 (3.5583) lr 3.5000e-03 eta 0:02:13
epoch [2/5] batch [100/204] time 0.178 (0.181) data 0.000 (0.002) loss 4.0273 (3.6542) lr 3.5000e-03 eta 0:02:09
epoch [2/5] batch [120/204] time 0.179 (0.180) data 0.000 (0.002) loss 4.0250 (3.7172) lr 3.5000e-03 eta 0:02:05
epoch [2/5] batch [140/204] time 0.179 (0.180) data 0.000 (0.001) loss 4.0519 (3.7622) lr 3.5000e-03 eta 0:02:01
epoch [2/5] batch [160/204] time 0.179 (0.180) data 0.000 (0.001) loss 4.0429 (3.7968) lr 3.5000e-03 eta 0:01:58
epoch [2/5] batch [180/204] time 0.179 (0.180) data 0.000 (0.001) loss 4.0290 (3.8228) lr 3.5000e-03 eta 0:01:54
epoch [2/5] batch [200/204] time 0.178 (0.180) data 0.000 (0.001) loss 4.0289 (3.8440) lr 3.5000e-03 eta 0:01:50
epoch [3/5] batch [20/204] time 0.182 (0.189) data 0.000 (0.010) loss 4.0316 (4.0368) lr 3.1658e-03 eta 0:01:51
epoch [3/5] batch [40/204] time 0.179 (0.184) data 0.000 (0.005) loss 4.0338 (4.0353) lr 3.1658e-03 eta 0:01:45
epoch [3/5] batch [60/204] time 0.179 (0.182) data 0.000 (0.003) loss 4.0431 (4.0348) lr 3.1658e-03 eta 0:01:40
epoch [3/5] batch [80/204] time 0.179 (0.181) data 0.000 (0.003) loss 4.0460 (4.0352) lr 3.1658e-03 eta 0:01:36
epoch [3/5] batch [100/204] time 0.179 (0.181) data 0.000 (0.002) loss 4.0375 (4.0352) lr 3.1658e-03 eta 0:01:32
epoch [3/5] batch [120/204] time 0.179 (0.181) data 0.000 (0.002) loss 4.0183 (4.0345) lr 3.1658e-03 eta 0:01:28
epoch [3/5] batch [140/204] time 0.179 (0.181) data 0.000 (0.002) loss 4.0367 (4.0340) lr 3.1658e-03 eta 0:01:25
epoch [3/5] batch [160/204] time 0.187 (0.181) data 0.000 (0.001) loss 4.0395 (4.0336) lr 3.1658e-03 eta 0:01:21
epoch [3/5] batch [180/204] time 0.179 (0.180) data 0.000 (0.001) loss 4.0319 (4.0340) lr 3.1658e-03 eta 0:01:17
epoch [3/5] batch [200/204] time 0.183 (0.181) data 0.000 (0.001) loss 4.0406 (4.0340) lr 3.1658e-03 eta 0:01:14
epoch [4/5] batch [20/204] time 0.184 (0.189) data 0.000 (0.010) loss 4.0245 (4.0325) lr 2.2908e-03 eta 0:01:13
epoch [4/5] batch [40/204] time 0.179 (0.184) data 0.000 (0.005) loss 4.0110 (4.0332) lr 2.2908e-03 eta 0:01:07
epoch [4/5] batch [60/204] time 0.179 (0.183) data 0.000 (0.003) loss 4.0404 (4.0320) lr 2.2908e-03 eta 0:01:03
epoch [4/5] batch [80/204] time 0.179 (0.182) data 0.000 (0.002) loss 4.0287 (4.0318) lr 2.2908e-03 eta 0:00:59
epoch [4/5] batch [100/204] time 0.179 (0.181) data 0.000 (0.002) loss 4.0235 (4.0310) lr 2.2908e-03 eta 0:00:55
epoch [4/5] batch [120/204] time 0.179 (0.181) data 0.000 (0.002) loss 4.0425 (4.0309) lr 2.2908e-03 eta 0:00:52
epoch [4/5] batch [140/204] time 0.179 (0.181) data 0.000 (0.001) loss 4.0313 (4.0315) lr 2.2908e-03 eta 0:00:48
epoch [4/5] batch [160/204] time 0.284 (0.183) data 0.000 (0.001) loss 3.9778 (4.0311) lr 2.2908e-03 eta 0:00:45
epoch [4/5] batch [180/204] time 0.188 (0.184) data 0.000 (0.001) loss 4.0434 (4.0309) lr 2.2908e-03 eta 0:00:42
epoch [4/5] batch [200/204] time 0.179 (0.186) data 0.000 (0.001) loss 4.0240 (4.0313) lr 2.2908e-03 eta 0:00:38
epoch [5/5] batch [20/204] time 0.182 (0.201) data 0.000 (0.010) loss 4.0220 (4.0251) lr 1.2092e-03 eta 0:00:36
epoch [5/5] batch [40/204] time 0.204 (0.197) data 0.000 (0.005) loss 4.1016 (4.0277) lr 1.2092e-03 eta 0:00:32
epoch [5/5] batch [60/204] time 0.180 (0.192) data 0.000 (0.003) loss 4.0113 (4.0270) lr 1.2092e-03 eta 0:00:27
epoch [5/5] batch [80/204] time 0.179 (0.189) data 0.000 (0.003) loss 4.0538 (4.0266) lr 1.2092e-03 eta 0:00:23
epoch [5/5] batch [100/204] time 0.180 (0.187) data 0.000 (0.002) loss 4.0295 (4.0261) lr 1.2092e-03 eta 0:00:19
epoch [5/5] batch [120/204] time 0.179 (0.186) data 0.000 (0.002) loss 3.9695 (4.0243) lr 1.2092e-03 eta 0:00:15
epoch [5/5] batch [140/204] time 0.179 (0.185) data 0.000 (0.002) loss 4.0654 (4.0252) lr 1.2092e-03 eta 0:00:11
epoch [5/5] batch [160/204] time 0.226 (0.186) data 0.000 (0.001) loss 4.0207 (4.0259) lr 1.2092e-03 eta 0:00:08
epoch [5/5] batch [180/204] time 0.179 (0.186) data 0.000 (0.001) loss 4.0290 (4.0258) lr 1.2092e-03 eta 0:00:04
epoch [5/5] batch [200/204] time 0.179 (0.186) data 0.000 (0.001) loss 4.0228 (4.0256) lr 1.2092e-03 eta 0:00:00

@PoopBear1 commented Mar 27, 2024

Not sure if it's a problem with the lora_alpha parameter, since it works fine when lora_alpha=1. However, choosing 2*rank seems to destroy the model's performance. Perhaps a bigger alpha is not a good fit for the standard CLIP model.

@BenjaminBossan (Member, Author):

> Not sure if it's a problem with the lora_alpha parameter, since it works fine when lora_alpha=1. However, choosing 2*rank seems to destroy the model's performance. Perhaps a bigger alpha is not a good fit for the standard CLIP model.

Honestly, I don't know what lora_alpha value works well with MHA. Thanks for testing it out. Is the performance on par with your expectation for lora_alpha=1?
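For context, PEFT's default LoRA scaling weights the adapter update by lora_alpha / r (assuming the default, non-rsLoRA scaling), so the two settings differ a lot in how strongly they perturb the pretrained weights:

r = 12
print(24 / r)  # 2.0  -> lora_alpha=24 weights the LoRA update quite heavily
print(1 / r)   # ~0.08 -> lora_alpha=1 keeps the update small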

About these warnings:

/home/isaac/anaconda3/envs/lib/python3.8/site-packages/peft/tuners/tuners_utils.py:711: UserWarning: All adapters are already merged, nothing to do.
warnings.warn("All adapters are already merged, nothing to do.")
/home/isaac/anaconda3/envs/lib/python3.8/site-packages/peft/tuners/lora/layer.py:439: UserWarning: Already unmerged. Nothing to do.

Those are definitely strange, as they mean the script tried to merge and unmerge some layers, which normally shouldn't happen. You should check your training code for suspicious lines related to merging (e.g. calls to merge_adapter, unmerge_adapter, or merge_and_unload).

@PoopBear1 commented Mar 27, 2024

> > Not sure if it's a problem with the lora_alpha parameter, since it works fine when lora_alpha=1. However, choosing 2*rank seems to destroy the model's performance. Perhaps a bigger alpha is not a good fit for the standard CLIP model.
>
> Honestly, I don't know what lora_alpha value works well with MHA. Thanks for testing it out. Is the performance on par with your expectation for lora_alpha=1?
>
> About these warnings:
>
> /home/isaac/anaconda3/envs/lib/python3.8/site-packages/peft/tuners/tuners_utils.py:711: UserWarning: All adapters are already merged, nothing to do.
> warnings.warn("All adapters are already merged, nothing to do.")
> /home/isaac/anaconda3/envs/lib/python3.8/site-packages/peft/tuners/lora/layer.py:439: UserWarning: Already unmerged. Nothing to do.
>
> Those are definitely strange, as they mean the script tried to merge and unmerge some layers, which normally shouldn't happen. You should check your training code for suspicious lines related to merging (e.g. calls to merge_adapter, unmerge_adapter, or merge_and_unload).

Hi Ben,

Thank you again for your prompt reply!

For lora_alpha=1, it indeed increases model performance in classification, so I think maybe MHA only works well with small scaling. I also did some additional research; if you've heard of MultiLoRA (paper link), the scaling factor, even when trained from zero, might cause problems. I guess MHA is extremely sensitive to scaling factors.

As for the warnings, I'm unsure whether they are related to my training code, since I use the Dassl framework for the training process. But it seems they do not affect how LoRA improves performance.

Really appreciate this work!


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@BenjaminBossan (Member, Author):

not stale


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@github-actions (bot) closed this May 29, 2024