Enable Rope for Mpt
Jianhong-Zhang committed Jun 25, 2024
1 parent 9aa739b commit fdfe366
Showing 2 changed files with 169 additions and 2 deletions.
99 changes: 99 additions & 0 deletions optimum/habana/transformers/models/mpt/configuration_mpt.py
@@ -0,0 +1,99 @@
from typing import TYPE_CHECKING, Optional, Union

if TYPE_CHECKING:
    pass

import copy

from transformers.models.mpt.configuration_mpt import MptAttentionConfig, PretrainedConfig


class MptConfig(PretrainedConfig):
    """
    Copied from: https://github.com/huggingface/transformers/blob/v4.32.0/src/transformers/models/mpt/configuration_mpt.py
    Changes:
    - add `rope_scaling`, `rope_theta` and `_rope_scaling_validation` (inspired by Llama)
    """

    model_type = "mpt"
    attribute_map = {
        "num_attention_heads": "n_heads",
        "hidden_size": "d_model",
        "num_hidden_layers": "n_layers",
    }

    def __init__(
        self,
        d_model: int = 2048,
        n_heads: int = 16,
        n_layers: int = 24,
        expansion_ratio: int = 4,
        max_seq_len: int = 2048,
        vocab_size: int = 50368,
        resid_pdrop: float = 0.0,
        layer_norm_epsilon: float = 1e-5,
        emb_pdrop: float = 0.0,
        learned_pos_emb: bool = True,
        attn_config: MptAttentionConfig = None,
        init_device: str = "cpu",
        logit_scale: Optional[Union[float, str]] = None,
        no_bias: bool = True,
        verbose: int = 0,
        embedding_fraction: float = 1.0,
        norm_type: str = "low_precision_layernorm",
        use_cache: bool = False,
        initializer_range=0.02,
        rope_scaling=None,
        rope_theta=10000,
        **kwargs,
    ):
        if attn_config is None:
            self.attn_config = MptAttentionConfig()
        elif isinstance(attn_config, dict):
            self.attn_config = MptAttentionConfig(**attn_config)
        else:
            self.attn_config = attn_config
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.expansion_ratio = expansion_ratio
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.resid_pdrop = resid_pdrop
        self.emb_pdrop = emb_pdrop
        self.learned_pos_emb = learned_pos_emb
        self.init_device = init_device
        self.logit_scale = logit_scale
        self.no_bias = no_bias
        self.verbose = verbose
        self.embedding_fraction = embedding_fraction
        self.norm_type = norm_type
        self.layer_norm_epsilon = layer_norm_epsilon
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        super().__init__(**kwargs)

        self.rope_scaling = rope_scaling
        self.rope_theta = rope_theta
        self._rope_scaling_validation()

    def _rope_scaling_validation(self):
        """
        Taken from: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/configuration_llama.py#L172
        """
        if self.rope_scaling is None:
            return

        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
            raise ValueError(
                "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
                f"got {self.rope_scaling}"
            )
        rope_scaling_type = self.rope_scaling.get("type", None)
        rope_scaling_factor = self.rope_scaling.get("factor", None)
        if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
            raise ValueError(
                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
            )
        if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
            raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
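To illustrate the new options, here is a hypothetical usage sketch (not part of the commit), assuming the package layout shown above; the validator only accepts a two-field dict with `type` in ['linear', 'dynamic'] and a float `factor` strictly greater than 1.

    from optimum.habana.transformers.models.mpt.configuration_mpt import MptConfig

    # Linear RoPE scaling by 2x on top of the default rotary base frequency.
    config = MptConfig(rope_scaling={"type": "linear", "factor": 2.0}, rope_theta=10000)

    # Both of these would raise ValueError: unknown scaling type, and factor not a float > 1.
    # MptConfig(rope_scaling={"type": "ntk", "factor": 2.0})
    # MptConfig(rope_scaling={"type": "linear", "factor": 1})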
72 changes: 70 additions & 2 deletions optimum/habana/transformers/models/mpt/modeling_mpt.py
@@ -22,13 +24,24 @@
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from transformers.models.mpt.modeling_mpt import MptForCausalLM, MptModel
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb
from transformers.utils import logging

from .configuration_mpt import MptConfig
from ...modeling_attn_mask_utils import _gaudi_prepare_4d_causal_attention_mask
from ..llama.modeling_llama import (
    GaudiLlamaDynamicNTKScalingRotaryEmbedding,
    GaudiLlamaLinearScalingRotaryEmbedding,
    GaudiLlamaRotaryEmbedding,
)

logger = logging.get_logger(__name__)

try:
    from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
    has_fused_rope = True
except ImportError:
    has_fused_rope = False
    print("Not using HPU fused kernel for apply_rotary_pos_emb")


def gaudi_mpt_attention_forward(
    self,
@@ -45,6 +56,37 @@ def gaudi_mpt_attention_forward(
    - optimize KV cache
    """

    def init_rope():
        """
        Copied from: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L294
        """
        # Pick the RoPE variant (plain, linear-scaled, or dynamic NTK) from the MptConfig settings.
        config = MptConfig()
        if config.rope_scaling is None:
            self.rotary_emb = GaudiLlamaRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_seq_length,
                base=config.rope_theta,
            )
        else:
            scaling_type = config.rope_scaling["type"]
            scaling_factor = config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = GaudiLlamaLinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_seq_length,
                    scaling_factor=scaling_factor,
                    base=config.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = GaudiLlamaDynamicNTKScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_seq_length,
                    scaling_factor=scaling_factor,
                    base=config.rope_theta,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    batch_size, seq_length = hidden_states.shape[:2]

    mixed_qkv = self.Wqkv(hidden_states)
@@ -64,16 +106,22 @@ def gaudi_mpt_attention_forward(
                past_key_value[1].index_copy_(2, token_idx - 1, value_states)
                key_states = past_key_value[0]
                value_states = past_key_value[1]
                kv_seq_len = past_key_value[0].shape[-2]
            else:
                key_states = torch.cat([past_key_value[0], key_states], dim=2)
                value_states = torch.cat([past_key_value[1], value_states], dim=2)
                kv_seq_len += past_key_value[0].shape[-2]
            past_key_value = (key_states, value_states)
        else:
            past_key_value = (key_states, value_states)

    attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale

    query_length = seq_length if past_key_value is None else seq_length + past_key_value[0].shape[2]

    # Build the rotary embedding and rotate the query/key states over the full KV length.
    init_rope()
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, None)

    if position_bias is not None:
        if len(position_bias.shape) != 3:
@@ -382,3 +430,23 @@ def forward(
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


def apply_customized_rope(q, k, cos, sin, position_ids):
    if q.device.type == "hpu" and has_fused_rope:
        # TODO: remove `.clone()` when SynapseAI v1.15 is released
        if k.dtype == torch.bfloat16:
            return FusedRoPE.apply(
                q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
            ), FusedRoPE.apply(
                k,
                cos.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16),
                sin.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16),
                position_ids,
            )
        return FusedRoPE.apply(
            q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
        ), FusedRoPE.apply(
            k, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
        )
    else:
        return apply_rotary_pos_emb(q, k, cos, sin, position_ids)
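For readers unfamiliar with RoPE, the non-HPU fallback (`apply_rotary_pos_emb`) and the fused HPU kernel both implement the standard rotary-position-embedding rotation. The following is a minimal, self-contained sketch of that math, not code from this commit or from transformers; the function name `rope_rotate` and the tensor shapes are illustrative assumptions.

    import torch

    def rope_rotate(x, cos, sin):
        # Rotate half the head dimension: [x1, x2] -> [-x2, x1].
        half = x.shape[-1] // 2
        x1, x2 = x[..., :half], x[..., half:]
        rotated = torch.cat((-x2, x1), dim=-1)
        # Standard RoPE mixing of the original and rotated halves with the cos/sin tables.
        return (x * cos) + (rotated * sin)

    # Build cos/sin tables the way rotary embeddings usually do (assumed sizes for the example).
    head_dim, seq_len, base = 64, 8, 10000.0
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    cos, sin = emb.cos(), emb.sin()  # shape: (seq_len, head_dim)

    # Apply to query/key states of shape (batch, num_heads, seq_len, head_dim).
    q = torch.randn(1, 16, seq_len, head_dim)
    k = torch.randn(1, 16, seq_len, head_dim)
    q_rot, k_rot = rope_rotate(q, cos, sin), rope_rotate(k, cos, sin)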
