
supports learning the combination weights of pre-trained LoRA modules #1666

Draft · mahdibeit wants to merge 6 commits into base: main
Conversation

mahdibeit (Author):

Based on #1655

Adds a use_wlora config option to LoraLayer that allows learning the combination weights (i.e., `wlora_weights`) of pre-trained LoRAs.

self.use_wlora[adapter_name] = True
# remove `lora_A` and `lora_B` from the list of trainable parameters
self.adapter_layer_names = tuple(
    layer for layer in self.adapter_layer_names if layer != "lora_A" and layer != "lora_B"
)
Author:

I am not sure if this is the best way to freeze lora_A and lora_B.

Member:

So you want to make sure that requires_grad is False for those layers? This is not the way to do it. The easiest way may be to call self.lora_A.requires_grad_(False) etc.
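
For illustration, a minimal sketch of that suggestion (simplified, hypothetical layer state rather than the PR's actual code; in PEFT, LoraLayer stores lora_A and lora_B as nn.ModuleDict objects keyed by adapter name):

import torch.nn as nn

# Hypothetical stand-in for the per-adapter LoRA matrices of one layer.
lora_A = nn.ModuleDict({"default": nn.Linear(768, 8, bias=False)})
lora_B = nn.ModuleDict({"default": nn.Linear(8, 768, bias=False)})

# Freeze the pre-trained LoRA matrices so that only the learnable
# combination weight remains trainable.
lora_A["default"].requires_grad_(False)
lora_B["default"].requires_grad_(False)

assert all(not p.requires_grad for p in lora_A.parameters())
assert all(not p.requires_grad for p in lora_B.parameters())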

@BenjaminBossan (Member):
Thanks for the PR. For me to be able to review it, could you provide an example of how it should be used?

@mahdibeit (Author):
@BenjaminBossan Yes, for sure.

The code implements the learned composition from Does Combining Parameter-efficient Modules Improve Few-shot Transfer Accuracy? (Asadi et al., 2024). More specifically, it learns the weighting vector $v$ for the weighted sum of LoRA modules as follows:

$$\hat{\mathbf{W}} = \mathbf{W}_{base} + \sum_{n=1}^{N} \hat{v}_n \left( \frac{\alpha_n}{r_n} \mathbf{A}_n \mathbf{B}_n\right),$$ $$\sum_{n=1}^{N}\hat{v}_n=1,$$

where $\hat{v}$ is the softmax operation applied on the weighting vector $v$, i.e.,

$$\hat{v}_n=e^{v_n}/\left(\sum_{j=1}^{N}e^{v_j}\right)$$

We named the parameter $v$ wlora_weights in the model parameters.
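
For illustration, a minimal sketch (not the PR's actual code) of how the normalized weights $\hat{v}$ are obtained from the raw per-adapter wlora_weights via a softmax; the adapter names here are just examples:

import torch

# Hypothetical raw weighting parameters v_n, one scalar per adapter.
wlora_weights = {
    "lora_1": torch.nn.Parameter(torch.tensor([1.0])),
    "lora_2": torch.nn.Parameter(torch.tensor([1.0])),
}

# Softmax across the adapters gives the normalized weights v_hat, which sum
# to 1 and scale each adapter's (alpha_n / r_n) * A_n B_n contribution.
v = torch.cat(list(wlora_weights.values()))
v_hat = torch.softmax(v, dim=0)
print(dict(zip(wlora_weights, v_hat.tolist())))  # {'lora_1': 0.5, 'lora_2': 0.5}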

Usage example

The following script is an example of how to load two pre-trained LoRA modules and learn their combination weights for an LLM.

First, we load the base model

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftConfig

# Load base model
base_model = "facebook/opt-350m"
tokenizer = AutoTokenizer.from_pretrained(base_model)
model = AutoModelForCausalLM.from_pretrained(base_model)

Then, we add the two LoRAs and make their combination weights trainable.

# Add the first LoRA with learnable weight to the base model
lora_1 = "varun-v-rao/opt-350m-lora-1.57M-squad-model3"
lora_1_config = PeftConfig.from_pretrained(lora_1)
lora_1_config.use_wlora = True
model.add_adapter(adapter_config=lora_1_config, adapter_name='lora_1')

# Add the second LoRA
lora_2 = "varun-v-rao/opt-350m-lora-1.57M-squad-model3"
lora_2_config = PeftConfig.from_pretrained(lora_2)
lora_2_config.use_wlora = True
model.add_adapter(adapter_config=lora_2_config, adapter_name='lora_2')

# Activate LoRA modules as trainable
model.set_adapter(['lora_1', 'lora_2'])

The modules are now loaded, and you can treat the model like any Hugging Face model or torch.nn.Module and use any training method. The following is an example using the Hugging Face Trainer.

# Train the combination weights of the LoRA modules
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir="wlora-model",
    evaluation_strategy="epoch",
    learning_rate=1e-4,
    weight_decay=0.01,
    push_to_hub=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_dataset["train"],
    eval_dataset=lm_dataset["test"],
    data_collator=data_collator,
)

trainer.train()
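
After training, the learned combination can be inspected along these lines (a sketch that assumes the parameter name wlora_weights introduced in this PR; lm_dataset and data_collator above are placeholders you have to provide yourself):

# Print the raw weighting parameters v; the effective combination weights are
# the softmax of these values across the adapters of each layer.
for name, param in model.named_parameters():
    if "wlora_weights" in name:
        print(name, param.detach().tolist())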

@BenjaminBossan (Member) left a comment:

Thanks for working on this PR and for providing an example. I also see now that you're one of the paper authors :)

I left a couple of comments on this PR. On top of that, we should probably also add a section to the docs (here), because it is not quite trivial for a user to figure out how to use this.

Moreover, I tried to come up with a test for this method. When I tried something based on the example you provided, I ran into an error though. Could you please check?

import torch

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftConfig, LoraConfig, get_peft_model, PeftModel

torch.manual_seed(0)
base_model_id = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(base_model_id)
model = AutoModelForCausalLM.from_pretrained(base_model_id)

config = LoraConfig(init_lora_weights=False, use_wlora=True)
model = get_peft_model(model, config)
model.add_adapter("other", config)
model.base_model.set_adapter(['lora_1', 'lora_2'])

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer.zero_grad()
output = model(torch.arange(10).reshape(-1, 1))
loss = output.logits.sum()
loss.backward()
# this causes: RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Maybe this is related to the comment about how to set requires_grad, not sure.


    x = dropout(x)
    result = result + self._apply_dora(x, lora_A, lora_B, scaling, active_adapter)
elif self.use_wlora[active_adapter]:
    wlora_scale = self._cal_wlora_scale(active_adapter)
Member:

IIUC, right now, wlora and dora would be incompatible (which we should directly check in the config's __post_init__ and raise an error). However, I wonder if we can make them work together. In a sense, wlora_scale is just a scaling factor like scaling. Would it work if we remove this elif block and instead, further above in line 560, we do:

scaling = self.scaling[active_adapter]
if self.use_wlora[active_adapter]:
    scaling = scaling * self._cal_wlora_scale(active_adapter)

I haven't thought this through completely, e.g. if this also works for merging, but WDYT?

Author:

I think I understand your suggestion. There are two ways we can combine dora with wlora:

wlora and dora working separately

In this case (the case the PR already supports), we can have, say, three LoRA modules, e.g. lora_1, lora_2, and lora_3. We can set lora_1 as dora and lora_2 and lora_3 as wlora. The forward formulation then looks like this:

result = result
        + self._apply_dora(x, lora_1_A, lora_1_B, scaling_1, lora_1)
        + lora_2_B(lora_2_A(dropout_2(x))) * scaling_2 * wlora_scale_2
        + lora_3_B(lora_3_A(dropout_3(x))) * scaling_3 * wlora_scale_3

where wlora_scale_2 + wlora_scale_3 = 1

wlora and dora working together

In this case (which we do not currently support, and I am not sure it helps dora at all), we use two LoRAs (e.g. lora_1 and lora_2) inside dora, and we need to change the _apply_dora function to accept wlora parameters like this:

result = result
        + self._apply_dora_and_wlora(x, lora_1_A, lora_1_B, scaling_1, wlora_scale_1, lora_2_A, lora_2_B, scaling_2, wlora_scale_2,)

Overall, I think with the current implementation, wlora works with other modules and we don't have to worry about conflicts since it can be trained and combined with other PEFT methods.

Member:

Not quite sure if I understand your suggestion completely. Probably you mean that if we "unroll" the for loop over the different adapters inside of forward, this is what it would look like?

Would it not work if we just apply the wlora scale to the existing scaling argument and the rest takes care of itself?

)
},
)
use_wlora: bool = field(
Member:

About the name: Was wlora used in the paper?

Author:

No, we did not use any name in the paper. I am not sure what the best way to describe this method in the config is. We could use WeightedLoRATrainer, AdaptiveScaling, or LearnableScaling. Do you have any suggestions?

Member:

I don't have any specific suggestion other than probably not wlora if it's not used in the paper. E.g. imagine that a new method comes out with the same name (which, given the dozens of existing LoRA variants, is not unlikely), this could cause a lot of confusion. Your latter 2 suggestions sound good though, choose whatever fits the paper better.

@mahdibeit (Author) commented on May 1, 2024:

> Thanks for working on this PR and for providing an example. [...] Moreover, I tried to come up with a test for this method. When I tried something based on the example you provided, I ran into an error though. Could you please check? [...] Maybe this is related to the comment about how to set requires_grad, not sure.

Thanks for taking the time to read the PR. Yes, I am one of the authors. I was hoping to create an easy method for the community to combine pre-trained LoRAs :)

This is a great test script. Yes, there was an issue regarding requires_grad, and I changed the method as you mentioned. Also, model.base_model.set_adapter(['lora_1', 'lora_2']) should be given the adapter names, so it should be model.base_model.set_adapter(['default', 'other']).

@mahdibeit (Author) commented on May 1, 2024:

So with this update, we can now write the following scripts as a test and an example usage. I added the example usage to the docs.

Test

import torch

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftConfig, LoraConfig, get_peft_model, PeftModel

torch.manual_seed(0)
base_model_id = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(base_model_id)
model = AutoModelForCausalLM.from_pretrained(base_model_id)

config = LoraConfig(init_lora_weights=False, use_wlora=True)
model = get_peft_model(model, config)
model.add_adapter("other", config)
model.base_model.set_adapter(['default', 'other'])

# Freeze lora_A and lora_B
for name, param in model.named_parameters():
    if 'lora_A' in name or 'lora_B' in name:
        param.requires_grad = False
        
# Print number of trainable parameters 
print('n_trainable_parameters', model.get_nb_trainable_parameters()) # 12 (layers) * 2 (lora) * 2 (q,v)

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer.zero_grad()
output = model(torch.arange(10).reshape(-1, 1))
loss = output.logits.sum()
loss.backward()

Example usage

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftConfig

# Load base model
base_model = "facebook/opt-350m"
tokenizer = AutoTokenizer.from_pretrained(base_model)
model = AutoModelForCausalLM.from_pretrained(base_model)

# Add the first LoRA with learnable weight to the base model
lora_1 = "varun-v-rao/opt-350m-lora-1.57M-squad-model3"
lora_1_config = PeftConfig.from_pretrained(lora_1)
lora_1_config.use_wlora = True
model.add_adapter(adapter_config=lora_1_config, adapter_name='lora_1')

# Add the second LoRA
lora_2 = "varun-v-rao/opt-350m-lora-1.57M-squad-model3"
lora_2_config = PeftConfig.from_pretrained(lora_2)
lora_2_config.use_wlora = True
model.add_adapter(adapter_config=lora_2_config, adapter_name='lora_2')

# Activate LoRA modules as trainable
model.set_adapter(['lora_1', 'lora_2'])

# Freeze lora_A and lora_B
for name, param in model.named_parameters():
    if 'lora_A' in name or 'lora_B' in name:
        param.requires_grad = False
        

Here, I am using model.base_model.set_adapter(['default', 'other']) to activate the two modules, and I am using

for name, param in model.named_parameters():
    if 'lora_A' in name or 'lora_B' in name:
        param.requires_grad = False

to freeze the lora_A and lora_B layers so that only the wlora_weights remain trainable.

@mahdibeit mahdibeit marked this pull request as ready for review May 1, 2024 02:36
@mahdibeit mahdibeit marked this pull request as draft May 1, 2024 02:36
@BenjaminBossan (Member) left a comment:

Great, thanks a lot for the updates.

> So with this update, we can now write the following scripts as a test and an example usage. I added the example usage to the docs.

The next step is to add a unit test (or a few) based on the examples. We should ensure in these tests that training works (it can be a very simple training loop). At the end, we should ideally see that only the wlora weights were updated and the LoRA weights stayed the same. For this test, we should also generate LoRA adapters with init_lora_weights=False.

LMK if you need help with adding these unit tests.
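
For reference, a rough sketch of what such a test could look like (it assumes the use_wlora flag from this PR and that the LoRA matrices are frozen automatically; the parameter names are assumptions, not the final API):

import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

torch.manual_seed(0)
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
config = LoraConfig(init_lora_weights=False, use_wlora=True)
model = get_peft_model(model, config)
model.add_adapter("other", config)
model.base_model.set_adapter(["default", "other"])

# Snapshot all LoRA-related parameters before training.
before = {n: p.detach().clone() for n, p in model.named_parameters() if "lora_" in n}

optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.1)
for _ in range(3):  # very small dummy training loop
    optimizer.zero_grad()
    output = model(torch.arange(10).reshape(-1, 1))
    output.logits.sum().backward()
    optimizer.step()

# Only the combination weights should have changed; lora_A and lora_B stay frozen.
for name, param in model.named_parameters():
    if "wlora_weights" in name:
        assert not torch.allclose(param, before[name])
    elif "lora_A" in name or "lora_B" in name:
        assert torch.allclose(param, before[name])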


@@ -239,6 +261,22 @@ def _apply_dora(self, x, lora_A, lora_B, scaling, active_adapter):

return result_dora

def _cal_wlora_scale(self, active_adapter: str) -> torch.Tensor:
Member:

Let's rename this to _calculate_wlora_scale or _get_wlora_scale (and replace "wlora" with whatever other name you choose).


model = AutoModelForCausalLM.from_pretrained(base_model)

# Add the first LoRA with learnable weight to the base model
lora_1 = "varun-v-rao/opt-350m-lora-1.57M-squad-model3"
Member:

Ideally, we can move these examples over to our testing repo https://huggingface.co/peft-internal-testing. I can do that. Are these weights actually trained on something and do they produce interesting outputs, or are they just dummy weights?

# Freeze lora_A and lora_B
for name, param in model.named_parameters():
    if "lora_A" in name or "lora_B" in name:
        param.requires_grad = False
Member:

It would be really great if we could automate this step. Check my other comment below.

self.wlora_weights[adapter_name] = nn.Parameter(torch.tensor([1.0]))
# add `wlora_weights` to the list of learnable parameters
self.adapter_layer_names = self.adapter_layer_names[:] + ("wlora_weights",)
self.use_wlora[adapter_name] = True
Member:

I wonder if we could not automatically deactivate the gradients for the normal LoRA weights if we use wlora. So basically: self.lora_A[adapter_name].requires_grad_(False) etc. WDYT?

# initialize wlora_weights
if not self.wlora_weights:
    self.wlora_weights = nn.ParameterDict()
self.wlora_weights[adapter_name] = nn.Parameter(torch.tensor([1.0]))
Member:

When this is renamed, please ensure that the parameter name still contains the substring (lora_), since this makes it much easier for users to quickly identify all LoRA-related parameters.


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@mahdibeit (Author):
Hi @BenjaminBossan, sorry, I was busy over the last three weeks. I will address your comments and push the changes this week.
