re-enable llama-adapter
Signed-off-by: Wang, Yi A <[email protected]>
sywangyi committed Jun 25, 2024
1 parent 3ca8f94 commit da125f0
Showing 6 changed files with 83 additions and 32 deletions.
4 changes: 2 additions & 2 deletions examples/language-modeling/README.md
@@ -362,7 +362,7 @@ python run_clm.py \

## PEFT

-### LORA/ADALORA/IA3
+### LORA/ADALORA/IA3/LLAMA_ADAPTER

To run LoRA finetuning, you can use `run_lora_clm.py`.
Here are single-/multi-device command examples for Llama1-7B, Falcon-40B, Llama2-70B, Llama3-8B and Llama3-70B.
@@ -653,7 +653,7 @@ DEEPSPEED_HPU_ZERO3_SYNC_MARK_STEP_REQUIRED=1 LOWER_LIST=ops_bf16.txt python3 ..
--validation_split_percentage 5 \
--deepspeed ds_falcon_180b_z3.json
```
-Default `peft_type` is `lora`, you could enable adalora or ia3 using `--peft_type adalora` or `--peft_type ia3`.
+Default `peft_type` is `lora`, you could enable adalora or ia3 using `--peft_type adalora` or `--peft_type ia3`, or enable llama-adapter for llama model using `--peft_type llama-adapter`.

### Prompt/Prefix/P-tuning

8 changes: 2 additions & 6 deletions examples/language-modeling/run_lora_clm.py
@@ -794,16 +794,12 @@ def compute_metrics(eval_preds):
task_type=TaskType.CAUSAL_LM,
)
from optimum.habana.peft.layer import (
-GaudiAdaptedAttentionAttentionAllReduce,
-GaudiAdaptedAttentionPostAttnForward,
+GaudiAdaptedAttention_getattr,
GaudiAdaptedAttentionPreAttnForward,
)

tuners.adaption_prompt.layer.AdaptedAttention.pre_attn_forward = GaudiAdaptedAttentionPreAttnForward
-tuners.adaption_prompt.layer.AdaptedAttention.post_attn_forward = GaudiAdaptedAttentionPostAttnForward
-tuners.adaption_prompt.layer.AdaptedAttention.attention_all_reduce = (
-    GaudiAdaptedAttentionAttentionAllReduce
-)
+tuners.adaption_prompt.layer.AdaptedAttention.__getattr__ = GaudiAdaptedAttention_getattr
if training_args.gradient_checkpointing:
model.enable_input_require_grads()
lora_model = get_peft_model(model, peft_config)
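
For orientation, here is a minimal sketch, not part of this commit, of how the patched pieces above fit together when building a llama-adapter PEFT model for fine-tuning. The checkpoint name and the `AdaptionPromptConfig` values are illustrative assumptions rather than values taken from `run_lora_clm.py`:

```python
from peft import AdaptionPromptConfig, TaskType, get_peft_model, tuners
from transformers import AutoModelForCausalLM

from optimum.habana.peft.layer import (
    GaudiAdaptedAttention_getattr,
    GaudiAdaptedAttentionPreAttnForward,
)

# Patch PEFT's AdaptedAttention with the Gaudi-friendly pre-attention forward
# and with attribute forwarding, mirroring what run_lora_clm.py does above.
tuners.adaption_prompt.layer.AdaptedAttention.pre_attn_forward = GaudiAdaptedAttentionPreAttnForward
tuners.adaption_prompt.layer.AdaptedAttention.__getattr__ = GaudiAdaptedAttention_getattr

# Illustrative checkpoint; any Llama-family causal LM follows the same path.
model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")
model.enable_input_require_grads()  # only needed when gradient checkpointing is enabled

peft_config = AdaptionPromptConfig(
    adapter_layers=2,  # illustrative values, not taken from the example script
    adapter_len=4,
    task_type=TaskType.CAUSAL_LM,
)
lora_model = get_peft_model(model, peft_config)
lora_model.print_trainable_parameters()
```
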
8 changes: 2 additions & 6 deletions examples/text-generation/utils.py
@@ -378,16 +378,12 @@ def peft_model(args, model_dtype, logger, **model_kwargs):
from peft import tuners

from optimum.habana.peft.layer import (
-GaudiAdaptedAttentionAttentionAllReduce,
-GaudiAdaptedAttentionPostAttnForward,
+GaudiAdaptedAttention_getattr,
GaudiAdaptedAttentionPreAttnForward,
)

tuners.adaption_prompt.layer.AdaptedAttention.pre_attn_forward = GaudiAdaptedAttentionPreAttnForward
-tuners.adaption_prompt.layer.AdaptedAttention.post_attn_forward = GaudiAdaptedAttentionPostAttnForward
-tuners.adaption_prompt.layer.AdaptedAttention.attention_all_reduce = (
-    GaudiAdaptedAttentionAttentionAllReduce
-)
+tuners.adaption_prompt.layer.AdaptedAttention.__getattr__ = GaudiAdaptedAttention_getattr

return model

3 changes: 1 addition & 2 deletions optimum/habana/peft/__init__.py
@@ -1,7 +1,6 @@
from .layer import (
GaudiAdaloraLayerSVDLinearForward,
-GaudiAdaptedAttentionAttentionAllReduce,
-GaudiAdaptedAttentionPostAttnForward,
+GaudiAdaptedAttention_getattr,
GaudiAdaptedAttentionPreAttnForward,
)
from .peft_model import gaudi_generate, gaudi_prepare_inputs_for_generation
82 changes: 72 additions & 10 deletions optimum/habana/peft/layer.py
@@ -1,9 +1,11 @@
import inspect
import math
from typing import Any

import torch
import torch.nn.functional as F
from peft.tuners.adaption_prompt.config import TRANSFORMERS_MODEL_CONFIG
from peft.tuners.adaption_prompt.utils import llama_apply_rotary_pos_emb, llama_rotate_half
from peft.utils.other import transpose


@@ -37,6 +39,68 @@ def GaudiAdaloraLayerSVDLinearForward(self, x: torch.Tensor, *args: Any, **kwarg
return result


def compute_query_states(model: torch.nn.Module, **kwargs) -> torch.Tensor:
    """
    Copied from https://github.com/huggingface/peft/blob/v0.10.0/src/peft/tuners/adaption_prompt/utils.py#L60
    The only differences are:
    -add reuse cache support.
    -add past key value list support
    """
    hidden_states = kwargs.get("hidden_states")
    position_ids = kwargs.get("position_ids")
    past_key_value = kwargs.get("past_key_value")
    bsz, q_len, _ = hidden_states.size()
    query_states = model.q_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2)

    factor = model.k_proj.in_features // model.k_proj.out_features
    value_states = (
        model.v_proj(hidden_states).view(bsz, q_len, (model.num_heads // factor), model.head_dim).transpose(1, 2)
    )

    seq_len = q_len

    if past_key_value is not None:
        if kwargs.get("reuse_cache", False):
            seq_len += past_key_value[0][-2]
        elif isinstance(past_key_value, tuple) or isinstance(past_key_value, list):
            # for transformers <= 4.35
            seq_len += past_key_value[0].shape[-2]
        else:
            # since transformers 4.36, this is a DynamicCache instance
            seq_len += past_key_value.get_seq_length(model.layer_idx)

    # For transformers > 4.37.2 `position_ids` became a required arguments in the rotary embedding's forward pass.
    if "position_ids" not in inspect.signature(model.rotary_emb.forward).parameters:
        # TODO we assume that position_ids is not None here, not sure if that is safe but the old code also did that
        cos, sin = model.rotary_emb(value_states, seq_len=seq_len)
        return llama_apply_rotary_pos_emb(query_states, cos, sin, position_ids)

    past_seen_tokens = 0
    if position_ids is None:
        # Compute position_ids, since they are required for transformers > 4.37.2
        if past_key_value is None:
            new_cache_positions = torch.arange(q_len, q_len + q_len, device=value_states.device)
        else:
            past_seen_tokens = past_key_value.get_usable_length(q_len, model.layer_idx)
            new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=value_states.device)
        position_ids = new_cache_positions.unsqueeze(0)

    rotary_emb_kwargs = {"position_ids": position_ids}
    # The `seq_len` argument has been officially removed in transformers >= 4.39.0
    if "seq_len" in inspect.signature(model.rotary_emb.forward).parameters:
        rotary_emb_kwargs["seq_len"] = q_len + past_seen_tokens

    cos, sin = model.rotary_emb(value_states, **rotary_emb_kwargs)

    # For batched inference unsqueeze it on the correct dim
    # since: https://github.com/huggingface/transformers/pull/29109
    if len(cos.shape) == 3:
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)

    return (query_states * cos) + (llama_rotate_half(query_states) * sin)


def GaudiAdaptedAttentionPreAttnForward(self, *args, **kwargs):
"""
Copied from AdaptedAttention.forward: https://github.com/huggingface/peft/blob/v0.10.0/src/peft/tuners/adaption_prompt/layer.py#L57
@@ -79,7 +143,6 @@ def GaudiAdaptedAttentionPreAttnForward(self, *args, **kwargs):
adapter_k = torch.repeat_interleave(adapter_k, repeats=factor, dim=1)
adapter_v = torch.repeat_interleave(adapter_v, repeats=factor, dim=1)
# Recompute query states.
-compute_query_states = TRANSFORMERS_MODEL_CONFIG[self.model_type].compute_query_states
# (bsz, num_heads, q_len, head_dim)
query_states = compute_query_states(model=self.model, **kwargs)

@@ -105,15 +168,14 @@ def GaudiAdaptedAttentionPreAttnForward(self, *args, **kwargs):
return output, None, past_key_value


-def GaudiAdaptedAttentionAttentionAllReduce(self, attn_output):
-    if hasattr(self.model.o_proj, "all_reduce"):
-        self.model.o_proj.all_reduce(attn_output)


-def GaudiAdaptedAttentionPostAttnForward(self, attn_output):
-    if hasattr(self.model.o_proj, "post_all_reduce"):
-        self.model.o_proj.post_all_reduce(attn_output)
-    return attn_output
+def GaudiAdaptedAttention_getattr(self, name: str):
+    """Forward missing attributes to the wrapped module."""
+    try:
+        return super(self.__class__, self).__getattr__(name)
+    except AttributeError:
+        # This is necessary as e.g. causal models have various methods that we
+        # don't want to re-implement here.
+        return getattr(self.model, name)


class LoRALinear:
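
The removed `attention_all_reduce`/`post_attn_forward` hooks are replaced here by a single `__getattr__` override that forwards any attribute the adapter wrapper does not define itself (for example `o_proj`, `num_heads`, or `config`) to the wrapped attention module. A toy sketch of the same forwarding pattern, not taken from the repository, shows why the `try`/`except` around `nn.Module.__getattr__` is needed:

```python
import torch


class AttentionWrapper(torch.nn.Module):
    """Toy stand-in for AdaptedAttention: wraps a module and forwards unknown attributes."""

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def __getattr__(self, name: str):
        try:
            # nn.Module keeps submodules/parameters in _modules/_parameters,
            # so they are only reachable through nn.Module.__getattr__.
            return super().__getattr__(name)
        except AttributeError:
            # Anything else (plain attributes, methods) lives on the wrapped module.
            return getattr(self.model, name)


inner = torch.nn.Linear(8, 8)
wrapped = AttentionWrapper(inner)
print(wrapped.out_features)  # 8 -- not defined on the wrapper, fetched from the wrapped Linear
print(wrapped.weight.shape)  # torch.Size([8, 8]) -- parameters of the wrapped module resolve the same way
```
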
10 changes: 4 additions & 6 deletions tests/test_peft_inference.py
@@ -46,6 +46,8 @@ def _text_generation(self, model, tokenizer, extra_kwargs=None):
"max_new_tokens": 128,
"ignore_eos": True,
}
+if extra_kwargs:
+    generate_kwargs.update(extra_kwargs)
generator = pipeline(
"text-generation",
model=model,
@@ -76,16 +78,12 @@ def _test_text_generation(self, model_name_or_path, peft_method):
)
elif peft_method == "llama-adapter":
from optimum.habana.peft.layer import (
-GaudiAdaptedAttentionAttentionAllReduce,
-GaudiAdaptedAttentionPostAttnForward,
+GaudiAdaptedAttention_getattr,
GaudiAdaptedAttentionPreAttnForward,
)

tuners.adaption_prompt.layer.AdaptedAttention.pre_attn_forward = GaudiAdaptedAttentionPreAttnForward
-tuners.adaption_prompt.layer.AdaptedAttention.post_attn_forward = GaudiAdaptedAttentionPostAttnForward
-tuners.adaption_prompt.layer.AdaptedAttention.attention_all_reduce = (
-    GaudiAdaptedAttentionAttentionAllReduce
-)
+tuners.adaption_prompt.layer.AdaptedAttention.__getattr__ = GaudiAdaptedAttention_getattr
config = AdaptionPromptConfig(
adapter_layers=2,
adapter_len=4,
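
To close the loop, here is a rough sketch, not taken from the test file, of running a llama-adapter model through the same `text-generation` pipeline the test uses. The checkpoint name and prompt are illustrative assumptions, the Gaudi patches shown earlier are omitted for brevity, and Gaudi-specific generation flags such as `ignore_eos` are left out so the sketch stays generic:

```python
from peft import AdaptionPromptConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_name = "huggyllama/llama-7b"  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# adapter_layers/adapter_len match the test above; task_type is an assumption here.
config = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type=TaskType.CAUSAL_LM)
model = get_peft_model(model, config)

# The test likewise routes generation through a text-generation pipeline.
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
print(generator("Hello, my dog is", max_new_tokens=32)[0]["generated_text"])
```
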
