Enable Rope for Mpt #1097

Draft · wants to merge 1 commit into main
99 changes: 99 additions & 0 deletions optimum/habana/transformers/models/mpt/configuration_mpt.py
@@ -0,0 +1,99 @@
from typing import Optional, Union

from transformers.models.mpt.configuration_mpt import MptAttentionConfig, PretrainedConfig


class MptConfig(PretrainedConfig):
"""
Copied from: https://github.com/huggingface/transformers/blob/v4.32.0/src/transformers/models/mpt/configuration_mpt.py
Changes:
- add `rope_scaling` `rope_theta` and `_rope_scaling_validation` (inspired from Llama)
"""

model_type = "mpt"
attribute_map = {
"num_attention_heads": "n_heads",
"hidden_size": "d_model",
"num_hidden_layers": "n_layers",
}

def __init__(
self,
d_model: int = 2048,
n_heads: int = 16,
n_layers: int = 24,
expansion_ratio: int = 4,
max_seq_len: int = 2048,
vocab_size: int = 50368,
resid_pdrop: float = 0.0,
layer_norm_epsilon: float = 1e-5,
emb_pdrop: float = 0.0,
learned_pos_emb: bool = True,
attn_config: MptAttentionConfig = None,
init_device: str = "cpu",
logit_scale: Optional[Union[float, str]] = None,
no_bias: bool = True,
verbose: int = 0,
embedding_fraction: float = 1.0,
norm_type: str = "low_precision_layernorm",
use_cache: bool = False,
initializer_range=0.02,
rope_scaling=None,
rope_theta=10000,
**kwargs,
):
if attn_config is None:
self.attn_config = MptAttentionConfig()
elif isinstance(attn_config, dict):
self.attn_config = MptAttentionConfig(**attn_config)
else:
self.attn_config = attn_config
self.d_model = d_model
self.n_heads = n_heads
self.n_layers = n_layers
self.expansion_ratio = expansion_ratio
self.max_seq_len = max_seq_len
self.vocab_size = vocab_size
self.resid_pdrop = resid_pdrop
self.emb_pdrop = emb_pdrop
self.learned_pos_emb = learned_pos_emb
self.init_device = init_device
self.logit_scale = logit_scale
self.no_bias = no_bias
self.verbose = verbose
self.embedding_fraction = embedding_fraction
self.norm_type = norm_type
self.layer_norm_epsilon = layer_norm_epsilon
self.use_cache = use_cache
self.initializer_range = initializer_range
super().__init__(**kwargs)

self.rope_scaling = rope_scaling
self.rope_theta = rope_theta
self._rope_scaling_validation()

def _rope_scaling_validation(self):
"""
Taken from: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/configuration_llama.py#L172
"""
if self.rope_scaling is None:
return

if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
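For reference, a minimal usage sketch of the new config fields added above (the scaling values below are illustrative, not taken from this PR): `rope_scaling` must be either None or a two-field dict with `type` in ["linear", "dynamic"] and a float `factor` greater than 1, while `rope_theta` keeps the usual default of 10000.

from optimum.habana.transformers.models.mpt.configuration_mpt import MptConfig

# Default: no RoPE scaling, rope_theta = 10000.
config = MptConfig()

# Linear position-interpolation scaling by 2x (accepted by _rope_scaling_validation).
config = MptConfig(rope_scaling={"type": "linear", "factor": 2.0})

# Dynamic NTK scaling with a custom base (illustrative values).
config = MptConfig(rope_scaling={"type": "dynamic", "factor": 4.0}, rope_theta=500000)

# Anything else is rejected, e.g. an unknown type or a factor <= 1:
# MptConfig(rope_scaling={"type": "yarn", "factor": 2.0})  # raises ValueError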
73 changes: 71 additions & 2 deletions optimum/habana/transformers/models/mpt/modeling_mpt.py
@@ -22,13 +22,24 @@
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from transformers.models.mpt.modeling_mpt import MptForCausalLM, MptModel
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb
from transformers.utils import logging

from .configuration_mpt import MptConfig
from ...modeling_attn_mask_utils import _gaudi_prepare_4d_causal_attention_mask

from ..llama.modeling_llama import (
GaudiLlamaDynamicNTKScalingRotaryEmbedding,
GaudiLlamaLinearScalingRotaryEmbedding,
GaudiLlamaRotaryEmbedding,
)

logger = logging.get_logger(__name__)

try:
from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
has_fused_rope = True
except ImportError:
has_fused_rope = False
print("Not using HPU fused kernel for apply_rotary_pos_emb")

def gaudi_mpt_attention_forward(
self,
@@ -45,6 +56,37 @@ def gaudi_mpt_attention_forward(
- optimize KV cache
"""

def init_rope():
"""
Copied from: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L294
"""
# Note: reads the RoPE settings from a freshly constructed (default) MptConfig
config = MptConfig()
if config.rope_scaling is None:
self.rotary_emb = GaudiLlamaRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_seq_length,
base=config.rope_theta,
)
else:
scaling_type = config.rope_scaling["type"]
scaling_factor = config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = GaudiLlamaLinearScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_seq_length,
scaling_factor=scaling_factor,
base=config.rope_theta,
)
elif scaling_type == "dynamic":
self.rotary_emb = GaudiLlamaDynamicNTKScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_seq_length,
scaling_factor=scaling_factor,
base=config.rope_theta,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

batch_size, seq_length = hidden_states.shape[:2]

mixed_qkv = self.Wqkv(hidden_states)
@@ -57,23 +99,30 @@ def gaudi_mpt_attention_forward(
mixed_qkv[:, 2 * self.n_heads :, ...],
)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if len(past_key_value) != 0:
if token_idx is not None:
past_key_value[0].index_copy_(2, token_idx - 1, key_states)
past_key_value[1].index_copy_(2, token_idx - 1, value_states)
key_states = past_key_value[0]
value_states = past_key_value[1]
kv_seq_len = past_key_value[0].shape[-2]
else:
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
kv_seq_len += past_key_value[0].shape[-2]
past_key_value = (key_states, value_states)
else:
past_key_value = (key_states, value_states)

attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale

query_length = seq_length if past_key_value is None else seq_length + past_key_value[0].shape[2]

init_rope()
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, None)

if position_bias is not None:
if len(position_bias.shape) != 3:
@@ -382,3 +431,23 @@ def forward(
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)

def apply_customized_rope(q, k, cos, sin, position_ids):
if q.device.type == "hpu" and has_fused_rope:
# TODO: remove `.clone()` when SynapseAI v1.15 is released
if k.dtype == torch.bfloat16:
return FusedRoPE.apply(
q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
), FusedRoPE.apply(
k,
cos.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16),
sin.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16),
position_ids,
)
return FusedRoPE.apply(
q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
), FusedRoPE.apply(
k, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
)
else:
return apply_rotary_pos_emb(q, k, cos, sin, position_ids)
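For context, a minimal sketch of the math the non-fused fallback path relies on: the standard rotate-half formulation behind `apply_rotary_pos_emb`. The helper names below are illustrative, and the broadcasting/position-id indexing details of the real transformers implementation are omitted.

import torch

def rotate_half(x):
    # Swap the two halves of the last dimension and negate the second half.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def reference_rope(q, k, cos, sin):
    # Element-wise rotation of q and k by the precomputed cos/sin tables
    # produced by the rotary embedding module.
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

The FusedRoPE branch in apply_customized_rope is intended to compute the same rotation in a single HPU kernel, which is why the two paths can be dispatched interchangeably based on device and kernel availability.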