diff --git a/optimum/habana/transformers/models/mpt/configuration_mpt.py b/optimum/habana/transformers/models/mpt/configuration_mpt.py
new file mode 100644
index 000000000..8bf21a6db
--- /dev/null
+++ b/optimum/habana/transformers/models/mpt/configuration_mpt.py
@@ -0,0 +1,99 @@
+from typing import TYPE_CHECKING, Optional, Union
+
+if TYPE_CHECKING:
+    pass
+
+import copy
+from transformers.models.mpt.configuration_mpt import PretrainedConfig
+from transformers.models.mpt.configuration_mpt import MptAttentionConfig
+
+
+class MptConfig(PretrainedConfig):
+    """
+    Copied from: https://github.com/huggingface/transformers/blob/v4.32.0/src/transformers/models/mpt/configuration_mpt.py
+    Changes:
+    - add `rope_scaling`, `rope_theta` and `_rope_scaling_validation` (inspired by Llama)
+    """
+
+    model_type = "mpt"
+    attribute_map = {
+        "num_attention_heads": "n_heads",
+        "hidden_size": "d_model",
+        "num_hidden_layers": "n_layers",
+    }
+
+    def __init__(
+        self,
+        d_model: int = 2048,
+        n_heads: int = 16,
+        n_layers: int = 24,
+        expansion_ratio: int = 4,
+        max_seq_len: int = 2048,
+        vocab_size: int = 50368,
+        resid_pdrop: float = 0.0,
+        layer_norm_epsilon: float = 1e-5,
+        emb_pdrop: float = 0.0,
+        learned_pos_emb: bool = True,
+        attn_config: MptAttentionConfig = None,
+        init_device: str = "cpu",
+        logit_scale: Optional[Union[float, str]] = None,
+        no_bias: bool = True,
+        verbose: int = 0,
+        embedding_fraction: float = 1.0,
+        norm_type: str = "low_precision_layernorm",
+        use_cache: bool = False,
+        initializer_range=0.02,
+        rope_scaling=None,
+        rope_theta=10000,
+        **kwargs,
+    ):
+        if attn_config is None:
+            self.attn_config = MptAttentionConfig()
+        elif isinstance(attn_config, dict):
+            self.attn_config = MptAttentionConfig(**attn_config)
+        else:
+            self.attn_config = attn_config
+        self.d_model = d_model
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.expansion_ratio = expansion_ratio
+        self.max_seq_len = max_seq_len
+        self.vocab_size = vocab_size
+        self.resid_pdrop = resid_pdrop
+        self.emb_pdrop = emb_pdrop
+        self.learned_pos_emb = learned_pos_emb
+        self.init_device = init_device
+        self.logit_scale = logit_scale
+        self.no_bias = no_bias
+        self.verbose = verbose
+        self.embedding_fraction = embedding_fraction
+        self.norm_type = norm_type
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.use_cache = use_cache
+        self.initializer_range = initializer_range
+        super().__init__(**kwargs)
+
+        self.rope_scaling = rope_scaling
+        self.rope_theta = rope_theta
+        self._rope_scaling_validation()
+
+    def _rope_scaling_validation(self):
+        """
+        Taken from: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/configuration_llama.py#L172
+        """
+        if self.rope_scaling is None:
+            return
+
+        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
+            raise ValueError(
+                "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
+                f"got {self.rope_scaling}"
+            )
+        rope_scaling_type = self.rope_scaling.get("type", None)
+        rope_scaling_factor = self.rope_scaling.get("factor", None)
+        if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
+            raise ValueError(
+                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
+            )
+        if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
+            raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
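For reference, a minimal usage sketch of the configuration file added above (not part of the diff): it assumes the module is importable as `optimum.habana.transformers.models.mpt.configuration_mpt`, matching the new file path, and only exercises the new `rope_scaling` / `rope_theta` fields and their validation.

# Hedged usage sketch; illustrative only, not part of the patch.
from optimum.habana.transformers.models.mpt.configuration_mpt import MptConfig

# Linear RoPE scaling: `factor` must be a float strictly greater than 1.
config = MptConfig(rope_scaling={"type": "linear", "factor": 2.0}, rope_theta=10000)
print(config.rope_scaling)  # {'type': 'linear', 'factor': 2.0}

# Unsupported scaling types are rejected by `_rope_scaling_validation`.
try:
    MptConfig(rope_scaling={"type": "ntk", "factor": 2.0})
except ValueError as err:
    print(err)  # `rope_scaling`'s type field must be one of ['linear', 'dynamic'], got ntk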
diff --git a/optimum/habana/transformers/models/mpt/modeling_mpt.py b/optimum/habana/transformers/models/mpt/modeling_mpt.py
index ed470f165..b3e069714 100644
--- a/optimum/habana/transformers/models/mpt/modeling_mpt.py
+++ b/optimum/habana/transformers/models/mpt/modeling_mpt.py
@@ -22,13 +22,24 @@
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
 from transformers.models.mpt.modeling_mpt import MptForCausalLM, MptModel
+from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb
 from transformers.utils import logging
-
+from .configuration_mpt import MptConfig
 from ...modeling_attn_mask_utils import _gaudi_prepare_4d_causal_attention_mask
-
+from ..llama.modeling_llama import (
+    GaudiLlamaDynamicNTKScalingRotaryEmbedding,
+    GaudiLlamaLinearScalingRotaryEmbedding,
+    GaudiLlamaRotaryEmbedding,
+)
 
 logger = logging.get_logger(__name__)
 
+try:
+    from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
+    has_fused_rope = True
+except ImportError:
+    has_fused_rope = False
+    print("Not using HPU fused kernel for apply_rotary_pos_emb")
 
 def gaudi_mpt_attention_forward(
     self,
@@ -45,6 +56,37 @@ def gaudi_mpt_attention_forward(
     - optimize KV cache
     """
 
+    def init_rope():
+        """
+        Copied from: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L294
+        """
+        config = MptConfig()
+        if config.rope_scaling is None:
+            self.rotary_emb = GaudiLlamaRotaryEmbedding(
+                self.head_dim,
+                max_position_embeddings=self.max_seq_length,
+                base=config.rope_theta,
+            )
+        else:
+            scaling_type = config.rope_scaling["type"]
+            scaling_factor = config.rope_scaling["factor"]
+            if scaling_type == "linear":
+                self.rotary_emb = GaudiLlamaLinearScalingRotaryEmbedding(
+                    self.head_dim,
+                    max_position_embeddings=self.max_seq_length,
+                    scaling_factor=scaling_factor,
+                    base=config.rope_theta,
+                )
+            elif scaling_type == "dynamic":
+                self.rotary_emb = GaudiLlamaDynamicNTKScalingRotaryEmbedding(
+                    self.head_dim,
+                    max_position_embeddings=self.max_seq_length,
+                    scaling_factor=scaling_factor,
+                    base=config.rope_theta,
+                )
+            else:
+                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+
     batch_size, seq_length = hidden_states.shape[:2]
 
     mixed_qkv = self.Wqkv(hidden_states)
@@ -64,9 +106,11 @@
             past_key_value[1].index_copy_(2, token_idx - 1, value_states)
             key_states = past_key_value[0]
             value_states = past_key_value[1]
+            kv_seq_len = past_key_value[0].shape[-2]
         else:
             key_states = torch.cat([past_key_value[0], key_states], dim=2)
             value_states = torch.cat([past_key_value[1], value_states], dim=2)
+            kv_seq_len += past_key_value[0].shape[-2]
         past_key_value = (key_states, value_states)
     else:
         past_key_value = (key_states, value_states)
@@ -74,6 +118,10 @@
     attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale
 
     query_length = seq_length if past_key_value is None else seq_length + past_key_value[0].shape[2]
+
+    init_rope()
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, None)
 
     if position_bias is not None:
         if len(position_bias.shape) != 3:
@@ -382,3 +430,23 @@ def forward(
         hidden_states=transformer_outputs.hidden_states,
         attentions=transformer_outputs.attentions,
     )
+
+def apply_customized_rope(q, k, cos, sin, position_ids):
+    if q.device.type == "hpu" and has_fused_rope:
+        # TODO: remove `.clone()` when SynapseAI v1.15 is released
+        if k.dtype == torch.bfloat16:
+            return FusedRoPE.apply(
+                q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
+            ), FusedRoPE.apply(
+                k,
+                cos.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16),
+                sin.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16),
+                position_ids,
+            )
+        return FusedRoPE.apply(
+            q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
+        ), FusedRoPE.apply(
+            k, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
+        )
+    else:
+        return apply_rotary_pos_emb(q, k, cos, sin, position_ids)