
Add add_weighted_adapter to IA3 adapters #1701

Merged · 7 commits · May 17, 2024

Conversation

alexrs (Contributor) commented May 1, 2024

(Partially) Resolves: #1688
See #980 for context.

What

$(IA)^3$ adapters can't be combined. This option, however, is available for other PEFT adapters such as LoRA.

Solution

Implement add_weighted_adapter method for $(IA)^3$ models supporting only weighted average of adapters for now.
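Conceptually, each (IA)³ adapter stores a learned scaling vector per targeted module, so a weighted combination boils down to a weighted sum of those vectors. A minimal sketch of the idea with plain Python lists (illustrative only; the actual PEFT implementation operates on the model's `ia3_l` parameters):

```python
def weighted_average(adapter_vectors, weights):
    """Combine per-adapter IA3-style scaling vectors into one vector.

    adapter_vectors: one list of floats per adapter (all the same length).
    weights: one weight per adapter.
    """
    if len(adapter_vectors) != len(weights):
        raise ValueError("Need exactly one weight per adapter")
    combined = [0.0] * len(adapter_vectors[0])
    for vec, w in zip(adapter_vectors, weights):
        for i, v in enumerate(vec):
            combined[i] += w * v
    return combined

# Two adapters with equal weights: the result is their element-wise mean.
merged = weighted_average([[1.0, 2.0], [3.0, 4.0]], [0.5, 0.5])
print(merged)  # [2.0, 3.0]
```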

BenjaminBossan (Member):

@alexrs Thanks a lot for reviving the feature and updating it so quickly. LMK when you think it's ready for review.

@Abdullah-kwl If you could give this a try, that would be great.

```python
if peft_config.target_modules is None:
    if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING:
        raise ValueError("Please specify `target_modules` in `peft_config`")
    peft_config.target_modules = TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING[model_config["model_type"]]
peft_config.target_modules = set(
```
alexrs (PR author):

As discussed in #980 (comment), this seemed to be a bug.

Comment on removed lines 578 to 580:

```python
for adapter in adapters:
    if adapter not in list(self.peft_config.keys()):
        raise ValueError(f"Adapter {adapter} does not exist")
```
alexrs (PR author):

This check is done in `_check_add_weighted_adapter`:

```python
for adapter in adapters:
    if adapter not in list(self.peft_config.keys()):
        raise ValueError(f"Adapter {adapter} does not exist")
```

which is called right after.

```python
else:
    continue

target_ia3_l.data = target_ia3_l.data * 0.0
```
alexrs (PR author):

I guess here we could use:

Suggested change:

```diff
-target_ia3_l.data = target_ia3_l.data * 0.0
+target_ia3_l.data = target_ia3_l.data.zero_()
```

but I tried to follow the code style used in LoRA:

```python
target_lora_A.data = target_lora_A.data * 0.0
```

https://pytorch.org/docs/stable/generated/torch.Tensor.zero_.html#torch-tensor-zero

BenjaminBossan (Member):

Probably .zero_() is more efficient (or has the potential to be so), so I'd be fine with this change.
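As a side note for readers, the allocation difference the two comments refer to can be illustrated with NumPy standing in for torch (an illustration, not PEFT code; `ndarray.fill(0)` plays the role of `Tensor.zero_()`):

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0])

# Style used in LoRA: builds a brand-new zeros array and rebinds the name.
b = a * 0.0

# In-place style: writes zeros into the existing buffer, no new allocation.
a_view = a   # second reference to the same underlying buffer
a.fill(0)

print(b.tolist())       # [0.0, 0.0, 0.0]
print(a_view.tolist())  # [0.0, 0.0, 0.0] -- the shared buffer was zeroed
```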

alexrs marked this pull request as ready for review, May 2, 2024 13:33

alexrs (PR author) commented May 2, 2024:

@BenjaminBossan Ready for review! 👀

BenjaminBossan (Member) left a review:

Thanks a lot for your work. This already looks quite good. I have a few smaller comments, and then 3 more general points:

  1. I haven't reviewed the tests yet. I see that you refactored _test_weighted_combination_of_adapters to work with both LoRA and IA³. However, all the changes make the review quite hard, and there is also the risk of accidentally breaking the LoRA tests (e.g. no longer testing something that is currently tested). Instead, it would be better IMO if _test_weighted_combination_of_adapters called submethods _test_weighted_combination_of_adapters_lora, _test_weighted_combination_of_adapters_ia3, etc., with the LoRA submethod taking over the current function body. Then you can copy and modify it for IA³. Yes, this will result in more code duplication, but for tests that matters less and it will be more maintainable, especially since the methods for the different adapter types will always differ quite a bit.
  2. IIUC, right now, if I merge, say, 10 adapters, all with weight=1, the resulting merged adapter will possibly have very large values, right? So I should probably assign weight=1/10 so that the norm does not shift. Maybe this should be documented.
  3. How about adding a section to https://github.com/huggingface/peft/blob/main/docs/source/developer_guides/model_merging.md?
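Point 2 above can be checked with a few lines of arithmetic (illustrative numbers, not PEFT code): merging N adapters with weight 1 each inflates the combined values by roughly a factor of N, while weights of 1/N keep them in the original range.

```python
# Ten hypothetical per-module scaling values, all exactly 1.0.
adapter_values = [1.0] * 10

# weight=1 for every adapter: the merged value grows by a factor of 10.
merged_unnormalized = sum(1.0 * v for v in adapter_values)

# weight=1/10 for every adapter: the merged value stays near 1.0.
merged_normalized = sum(0.1 * v for v in adapter_values)

print(merged_unnormalized)          # 10.0
print(round(merged_normalized, 6))  # 1.0
```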

src/peft/tuners/ia3/model.py (resolved comment)
```python
if str in target_module_types:
    new_target_modules = "|".join(f"({self.peft_config[adapter].target_modules})" for adapter in adapters)
else:
    new_target_modules = set.union(*(self.peft_config[adapter].target_modules for adapter in adapters))
```
BenjaminBossan (Member):

Oh yeah, simpler than the reduce op we have in LoRA 👍
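For readers following along: the two branches quoted above handle the two forms `target_modules` can take, a regex string (merged by alternation) or a set of module names (merged by set union). A standalone illustration of the same logic, with a plain dict of hypothetical configs in place of `self.peft_config`:

```python
# Hypothetical per-adapter target_modules; PEFT accepts a regex str or a set.
peft_config = {
    "adapter_a": {"k_proj", "v_proj"},
    "adapter_b": {"v_proj", "down_proj"},
}
adapters = ["adapter_a", "adapter_b"]

target_module_types = {type(peft_config[a]) for a in adapters}
if str in target_module_types:
    # Regex configs: combine by alternation so every pattern still matches.
    new_target_modules = "|".join(f"({peft_config[a]})" for a in adapters)
else:
    # Set configs: combine by set union.
    new_target_modules = set.union(*(peft_config[a] for a in adapters))

print(sorted(new_target_modules))  # ['down_proj', 'k_proj', 'v_proj']
```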


alexrs (PR author) commented May 7, 2024:

Hey @BenjaminBossan. I made some changes:

  • Refactored the tests into _test_weighted_combination_of_adapters_lora and _test_weighted_combination_of_adapters_ia3. These methods are called from _test_weighted_combination_of_adapters.
  • Added a short section to the docs. Let me know if I should expand it.

Regarding whether we should normalize the adapter weights, I'm not sure what the best approach is. On the one hand, I agree that a combination of adapters whose weights sum to more than 1 might not give the best results. On the other hand, I also think it is useful to let users decide how their adapter weights are normalized. In your example, weight=1/10 is a perfectly valid approach, but so is taking a softmax of the weights, or setting some weights to 0 for a top-k combination. I'm not sure this method should be in charge of handling all of that.

Maybe we should just specify in the docs that we recommend sum(weights) == 1? We can also print a warning if that condition is not fulfilled.
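Such a warning could be as simple as the following sketch (illustrative only, not what the PR ended up shipping; the function name and tolerance are made up):

```python
import warnings

def check_combination_weights(weights, tol=1e-6):
    """Warn if the adapter weights do not sum to (approximately) 1."""
    total = sum(weights)
    if abs(total - 1.0) > tol:
        warnings.warn(
            f"Adapter weights sum to {total}, not 1; the merged adapter's "
            "scale may shift. Consider normalizing the weights."
        )

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    check_combination_weights([1.0] * 10)  # sums to 10 -> warns
    check_combination_weights([0.25] * 4)  # sums to 1  -> silent

print(len(caught))  # 1
```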

HuggingFaceDocBuilderDev (bot) commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

BenjaminBossan (Member) left a review:

Thanks for the updates, I think we're almost good, I only had a few minor points.

Maybe we should just specify in the docs that we recommend sum(weights) == 1?

Yes, my suggestion is to document that users should take care that the weights should sum up to 1 or at least be close to 1.

docs/source/developer_guides/model_merging.md (two resolved comments, outdated)
```python
    **config_kwargs,
)

# Define a dictionary to map config types to their respective test functions
```
BenjaminBossan (Member):

I think it's more straightforward to read (though less elegant) if we do:

```python
if isinstance(config, LoraConfig):
    self._test_weighted_combination_of_adapters_lora(...)
elif ...
```

WDYT?

alexrs (PR author):

I initially did that, but if we want to call `model = get_peft_model(model, config, adapter_list[0])` only for `LoraConfig` and `IA3Config`, we'd need to:

```python
# Check if config is LoraConfig or IA3Config. Skip if not.
if not isinstance(config, (LoraConfig, IA3Config)):
    pytest.skip(f"Test not applicable for {config}")

# Initialize model ...

# Call the test method according to the config type
if isinstance(config, LoraConfig):
    self._test_weighted_combination_of_adapters_lora(...)
elif isinstance(config, IA3Config):
    self._test_weighted_combination_of_adapters_ia3(...)
else:
    pytest.skip(f"Test not applicable for {config}")
```

and it seems a bit redundant.

All in all, I don't have a strong opinion here. Let me know if you think this is more readable and I'll change it!

BenjaminBossan (Member):

I see what you mean, but honestly, for tests I don't care as much about such duplication; it's more important to me that I can quickly read from top to bottom and understand what's going on.

tests/testing_common.py (resolved comment, outdated)
BenjaminBossan (Member) left a review:

Thanks a lot Alejandro, LGTM, good to have this feature finally.

@pacman100 do you also want to review?

pacman100 (Contributor) left a review:

Thank you, @alexrs, for adding support for combining adapters when using IA³! ✨

@pacman100 pacman100 merged commit fb7f279 into huggingface:main May 17, 2024
14 checks passed