From da125f0ed5b63283c980411b6d252cc09746e122 Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Tue, 25 Jun 2024 00:44:30 -0700
Subject: [PATCH] re-enable llama-adapter

Signed-off-by: Wang, Yi A
---
 examples/language-modeling/README.md       |  4 +-
 examples/language-modeling/run_lora_clm.py |  8 +--
 examples/text-generation/utils.py          |  8 +--
 optimum/habana/peft/__init__.py            |  3 +-
 optimum/habana/peft/layer.py               | 82 +++++++++++++++++++---
 tests/test_peft_inference.py               | 10 ++-
 6 files changed, 83 insertions(+), 32 deletions(-)

diff --git a/examples/language-modeling/README.md b/examples/language-modeling/README.md
index 8c77f0e81..a164cdf28 100644
--- a/examples/language-modeling/README.md
+++ b/examples/language-modeling/README.md
@@ -362,7 +362,7 @@ python run_clm.py \
 
 ## PEFT
 
-### LORA/ADALORA/IA3
+### LORA/ADALORA/IA3/LLAMA_ADAPTER
 
 To run LoRA finetuning, you can use `run_lora_clm.py`.
 Here are single-/multi-device command examples for Llama1-7B, Falcon-40B, Llama2-70B, Llama3-8B and Llama3-70B.
@@ -653,7 +653,7 @@ DEEPSPEED_HPU_ZERO3_SYNC_MARK_STEP_REQUIRED=1 LOWER_LIST=ops_bf16.txt python3 ..
     --validation_split_percentage 5 \
     --deepspeed ds_falcon_180b_z3.json
 ```
-Default `peft_type` is `lora`, you could enable adalora or ia3 using `--peft_type adalora` or `--peft_type ia3`.
+Default `peft_type` is `lora`; you can enable adalora or ia3 with `--peft_type adalora` or `--peft_type ia3`, or enable llama-adapter for Llama models with `--peft_type llama-adapter`.
 
 ### Prompt/Prefix/P-tuning
 
diff --git a/examples/language-modeling/run_lora_clm.py b/examples/language-modeling/run_lora_clm.py
index 96f1df011..d7619a2a7 100644
--- a/examples/language-modeling/run_lora_clm.py
+++ b/examples/language-modeling/run_lora_clm.py
@@ -794,16 +794,12 @@ def compute_metrics(eval_preds):
             task_type=TaskType.CAUSAL_LM,
         )
         from optimum.habana.peft.layer import (
-            GaudiAdaptedAttentionAttentionAllReduce,
-            GaudiAdaptedAttentionPostAttnForward,
+            GaudiAdaptedAttention_getattr,
             GaudiAdaptedAttentionPreAttnForward,
         )
 
         tuners.adaption_prompt.layer.AdaptedAttention.pre_attn_forward = GaudiAdaptedAttentionPreAttnForward
-        tuners.adaption_prompt.layer.AdaptedAttention.post_attn_forward = GaudiAdaptedAttentionPostAttnForward
-        tuners.adaption_prompt.layer.AdaptedAttention.attention_all_reduce = (
-            GaudiAdaptedAttentionAttentionAllReduce
-        )
+        tuners.adaption_prompt.layer.AdaptedAttention.__getattr__ = GaudiAdaptedAttention_getattr
     if training_args.gradient_checkpointing:
         model.enable_input_require_grads()
     lora_model = get_peft_model(model, peft_config)
diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py
index 4b46ba175..76708b88e 100644
--- a/examples/text-generation/utils.py
+++ b/examples/text-generation/utils.py
@@ -378,16 +378,12 @@ def peft_model(args, model_dtype, logger, **model_kwargs):
         from peft import tuners
 
         from optimum.habana.peft.layer import (
-            GaudiAdaptedAttentionAttentionAllReduce,
-            GaudiAdaptedAttentionPostAttnForward,
+            GaudiAdaptedAttention_getattr,
             GaudiAdaptedAttentionPreAttnForward,
         )
 
         tuners.adaption_prompt.layer.AdaptedAttention.pre_attn_forward = GaudiAdaptedAttentionPreAttnForward
-        tuners.adaption_prompt.layer.AdaptedAttention.post_attn_forward = GaudiAdaptedAttentionPostAttnForward
-        tuners.adaption_prompt.layer.AdaptedAttention.attention_all_reduce = (
-            GaudiAdaptedAttentionAttentionAllReduce
-        )
+        tuners.adaption_prompt.layer.AdaptedAttention.__getattr__ = GaudiAdaptedAttention_getattr
 
     return model
 
diff --git a/optimum/habana/peft/__init__.py b/optimum/habana/peft/__init__.py
index 24b73d097..912681ac9 100644
--- a/optimum/habana/peft/__init__.py
+++ b/optimum/habana/peft/__init__.py
@@ -1,7 +1,6 @@
 from .layer import (
     GaudiAdaloraLayerSVDLinearForward,
-    GaudiAdaptedAttentionAttentionAllReduce,
-    GaudiAdaptedAttentionPostAttnForward,
+    GaudiAdaptedAttention_getattr,
     GaudiAdaptedAttentionPreAttnForward,
 )
 from .peft_model import gaudi_generate, gaudi_prepare_inputs_for_generation
diff --git a/optimum/habana/peft/layer.py b/optimum/habana/peft/layer.py
index 95af9a8d0..d9731a269 100755
--- a/optimum/habana/peft/layer.py
+++ b/optimum/habana/peft/layer.py
@@ -1,9 +1,11 @@
+import inspect
 import math
 from typing import Any
 
 import torch
 import torch.nn.functional as F
 from peft.tuners.adaption_prompt.config import TRANSFORMERS_MODEL_CONFIG
+from peft.tuners.adaption_prompt.utils import llama_apply_rotary_pos_emb, llama_rotate_half
 from peft.utils.other import transpose
 
 
@@ -37,6 +39,68 @@ def GaudiAdaloraLayerSVDLinearForward(self, x: torch.Tensor, *args: Any, **kwarg
     return result
 
 
+def compute_query_states(model: torch.nn.Module, **kwargs) -> torch.Tensor:
+    """
+    Copied from https://github.com/huggingface/peft/blob/v0.10.0/src/peft/tuners/adaption_prompt/utils.py#L60
+    The only differences are:
+    - add reuse cache support
+    - add past key value list support
+    """
+    hidden_states = kwargs.get("hidden_states")
+    position_ids = kwargs.get("position_ids")
+    past_key_value = kwargs.get("past_key_value")
+    bsz, q_len, _ = hidden_states.size()
+    query_states = model.q_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2)
+
+    factor = model.k_proj.in_features // model.k_proj.out_features
+    value_states = (
+        model.v_proj(hidden_states).view(bsz, q_len, (model.num_heads // factor), model.head_dim).transpose(1, 2)
+    )
+
+    seq_len = q_len
+
+    if past_key_value is not None:
+        if kwargs.get("reuse_cache", False):
+            seq_len += past_key_value[0][-2]
+        elif isinstance(past_key_value, (tuple, list)):
+            # for transformers <= 4.35
+            seq_len += past_key_value[0].shape[-2]
+        else:
+            # since transformers 4.36, this is a DynamicCache instance
+            seq_len += past_key_value.get_seq_length(model.layer_idx)
+
+    # For transformers > 4.37.2, `position_ids` became a required argument in the rotary embedding's forward pass.
+    if "position_ids" not in inspect.signature(model.rotary_emb.forward).parameters:
+        # TODO we assume that position_ids is not None here, not sure if that is safe but the old code also did that
+        cos, sin = model.rotary_emb(value_states, seq_len=seq_len)
+        return llama_apply_rotary_pos_emb(query_states, cos, sin, position_ids)
+
+    past_seen_tokens = 0
+    if position_ids is None:
+        # Compute position_ids, since they are required for transformers > 4.37.2
+        if past_key_value is None:
+            new_cache_positions = torch.arange(q_len, q_len + q_len, device=value_states.device)
+        else:
+            past_seen_tokens = past_key_value.get_usable_length(q_len, model.layer_idx)
+            new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=value_states.device)
+        position_ids = new_cache_positions.unsqueeze(0)
+
+    rotary_emb_kwargs = {"position_ids": position_ids}
+    # The `seq_len` argument has been officially removed in transformers >= 4.39.0
+    if "seq_len" in inspect.signature(model.rotary_emb.forward).parameters:
+        rotary_emb_kwargs["seq_len"] = q_len + past_seen_tokens
+
+    cos, sin = model.rotary_emb(value_states, **rotary_emb_kwargs)
+
+    # For batched inference unsqueeze it on the correct dim
+    # since: https://github.com/huggingface/transformers/pull/29109
+    if len(cos.shape) == 3:
+        cos = cos.unsqueeze(1)
+        sin = sin.unsqueeze(1)
+
+    return (query_states * cos) + (llama_rotate_half(query_states) * sin)
+
+
 def GaudiAdaptedAttentionPreAttnForward(self, *args, **kwargs):
     """
     Copied from AdaptedAttention.forward: https://github.com/huggingface/peft/blob/v0.10.0/src/peft/tuners/adaption_prompt/layer.py#L57
@@ -79,7 +143,6 @@ def GaudiAdaptedAttentionPreAttnForward(self, *args, **kwargs):
         adapter_k = torch.repeat_interleave(adapter_k, repeats=factor, dim=1)
         adapter_v = torch.repeat_interleave(adapter_v, repeats=factor, dim=1)
     # Recompute query states.
-    compute_query_states = TRANSFORMERS_MODEL_CONFIG[self.model_type].compute_query_states
     # (bsz, num_heads, q_len, head_dim)
     query_states = compute_query_states(model=self.model, **kwargs)
 
@@ -105,15 +168,14 @@ def GaudiAdaptedAttentionPreAttnForward(self, *args, **kwargs):
     return output, None, past_key_value
 
 
-def GaudiAdaptedAttentionAttentionAllReduce(self, attn_output):
-    if hasattr(self.model.o_proj, "all_reduce"):
-        self.model.o_proj.all_reduce(attn_output)
-
-
-def GaudiAdaptedAttentionPostAttnForward(self, attn_output):
-    if hasattr(self.model.o_proj, "post_all_reduce"):
-        self.model.o_proj.post_all_reduce(attn_output)
-    return attn_output
+def GaudiAdaptedAttention_getattr(self, name: str):
+    """Forward missing attributes to the wrapped module."""
+    try:
+        return super(self.__class__, self).__getattr__(name)
+    except AttributeError:
+        # This is necessary as e.g. causal models have various methods that we
+        # don't want to re-implement here.
+        return getattr(self.model, name)
 
 
 class LoRALinear:
diff --git a/tests/test_peft_inference.py b/tests/test_peft_inference.py
index bf014ff86..05e005851 100644
--- a/tests/test_peft_inference.py
+++ b/tests/test_peft_inference.py
@@ -46,6 +46,8 @@ def _text_generation(self, model, tokenizer, extra_kwargs=None):
             "max_new_tokens": 128,
             "ignore_eos": True,
         }
+        if extra_kwargs:
+            generate_kwargs.update(extra_kwargs)
         generator = pipeline(
             "text-generation",
             model=model,
@@ -76,16 +78,12 @@ def _test_text_generation(self, model_name_or_path, peft_method):
             )
         elif peft_method == "llama-adapter":
             from optimum.habana.peft.layer import (
-                GaudiAdaptedAttentionAttentionAllReduce,
-                GaudiAdaptedAttentionPostAttnForward,
+                GaudiAdaptedAttention_getattr,
                 GaudiAdaptedAttentionPreAttnForward,
             )
 
             tuners.adaption_prompt.layer.AdaptedAttention.pre_attn_forward = GaudiAdaptedAttentionPreAttnForward
-            tuners.adaption_prompt.layer.AdaptedAttention.post_attn_forward = GaudiAdaptedAttentionPostAttnForward
-            tuners.adaption_prompt.layer.AdaptedAttention.attention_all_reduce = (
-                GaudiAdaptedAttentionAttentionAllReduce
-            )
+            tuners.adaption_prompt.layer.AdaptedAttention.__getattr__ = GaudiAdaptedAttention_getattr
             config = AdaptionPromptConfig(
                 adapter_layers=2,
                 adapter_len=4,
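
A minimal usage sketch (not part of the diff) of the llama-adapter path this patch re-enables. It mirrors the wiring in `examples/text-generation/utils.py` and `tests/test_peft_inference.py`; the checkpoint name, prompt, and generation length are illustrative placeholders, and actually exercising the Gaudi `pre_attn_forward` path assumes the model runs through optimum-habana's adapted attention on an HPU.

```python
# Sketch only: exercises the re-enabled llama-adapter (AdaptionPrompt) path the same way the
# patched example/test code does. Checkpoint, prompt, and max_new_tokens are placeholders.
from peft import AdaptionPromptConfig, TaskType, get_peft_model, tuners
from transformers import AutoModelForCausalLM, AutoTokenizer

from optimum.habana.peft.layer import (
    GaudiAdaptedAttention_getattr,
    GaudiAdaptedAttentionPreAttnForward,
)

# Monkey-patch peft's AdaptedAttention exactly as the diff does, before building the PEFT model.
tuners.adaption_prompt.layer.AdaptedAttention.pre_attn_forward = GaudiAdaptedAttentionPreAttnForward
tuners.adaption_prompt.layer.AdaptedAttention.__getattr__ = GaudiAdaptedAttention_getattr

model_id = "huggyllama/llama-7b"  # placeholder Llama checkpoint
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Same adapter hyperparameters as tests/test_peft_inference.py.
peft_config = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type=TaskType.CAUSAL_LM)
model = get_peft_model(model, peft_config)

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

The `__getattr__` override stands in for the removed `post_attn_forward`/`attention_all_reduce` hooks: rather than re-implementing them, the wrapper now forwards any attribute it does not define (e.g. `o_proj`) to the wrapped attention module.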