diff --git a/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py b/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py
index c40283331..2d66e11fa 100644
--- a/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py
+++ b/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py
@@ -22,6 +22,13 @@
 )
 
 
+try:
+    from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
+except ImportError:
+    print("Not using HPU fused kernel for apply_rotary_pos_emb")
+    FusedRoPE = None
+
+
 logger = logging.get_logger(__name__)
 
 
@@ -47,6 +54,7 @@ def gaudi_starcoder2_attention_forward(
             "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
         )
     bsz, q_len, _ = hidden_states.size()
+    self.norm_factor = 1.0 / math.sqrt(self.head_dim)
 
     query_states = self.q_proj(hidden_states)
     key_states = self.k_proj(hidden_states)
@@ -70,7 +78,7 @@ def gaudi_starcoder2_attention_forward(
     else:
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+    query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids, self.training)
 
     if past_key_value is not None:
         if token_idx is not None:
@@ -90,7 +98,7 @@ def gaudi_starcoder2_attention_forward(
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.norm_factor
 
     if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
         raise ValueError(
@@ -470,3 +478,25 @@ def prepare_inputs_for_generation(
         }
     )
     return model_inputs
+
+
+def apply_customized_rope(q, k, cos, sin, position_ids, is_training):
+    if q.device.type == "hpu" and FusedRoPE:
+        if not is_training and (q.dtype == torch.bfloat16 or k.dtype == torch.bfloat16):
+            return FusedRoPE.apply(
+                q,
+                cos.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
+                sin.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
+                position_ids,
+            ), FusedRoPE.apply(
+                k,
+                cos.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
+                sin.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
+                position_ids,
+            )
+        else:
+            return FusedRoPE.apply(
+                q, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids
+            ), FusedRoPE.apply(k, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids)
+    else:
+        return apply_rotary_pos_emb(q, k, cos, sin, position_ids)
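For reference, below is a minimal, self-contained sketch of the dispatch pattern the patch introduces: prefer the Gaudi fused RoPE kernel when running on HPU (casting `cos`/`sin` to bf16 for bf16 inference), and fall back to the plain PyTorch rotary embedding otherwise. The `rotate_half`/`apply_rotary_pos_emb` helpers here are simplified stand-ins for the `transformers` helper the patched module imports (whose exact signature varies across versions), the dispatch function condenses the patch's two fused branches into one, and the toy shapes at the end are illustrative only. On a machine without `habana_frameworks`, running this exercises the CPU fallback path.

```python
import torch

# Optional fused kernel: importable only where habana_frameworks is installed;
# everywhere else FusedRoPE stays None and the pure-PyTorch path is used.
try:
    from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
except ImportError:
    FusedRoPE = None


def rotate_half(x):
    # Standard RoPE helper: rotate the last dimension as (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # Simplified stand-in for the transformers fallback helper.
    cos = cos[position_ids].unsqueeze(1)  # (bsz, 1, seq_len, head_dim)
    sin = sin[position_ids].unsqueeze(1)
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)


def apply_customized_rope(q, k, cos, sin, position_ids, is_training):
    # Same dispatch shape as the patch: fused kernel on HPU, with cos/sin cast
    # to bf16 for bf16 inference; plain implementation everywhere else.
    if q.device.type == "hpu" and FusedRoPE:
        cos_u = cos.unsqueeze(0).unsqueeze(0)
        sin_u = sin.unsqueeze(0).unsqueeze(0)
        if not is_training and (q.dtype == torch.bfloat16 or k.dtype == torch.bfloat16):
            cos_u, sin_u = cos_u.to(torch.bfloat16), sin_u.to(torch.bfloat16)
        return (
            FusedRoPE.apply(q, cos_u, sin_u, position_ids),
            FusedRoPE.apply(k, cos_u, sin_u, position_ids),
        )
    return apply_rotary_pos_emb(q, k, cos, sin, position_ids)


# Toy shapes: batch=1, heads=2, seq_len=4, head_dim=8 (CPU takes the fallback).
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
cos = torch.randn(4, 8)
sin = torch.randn(4, 8)
position_ids = torch.arange(4).unsqueeze(0)

q_rot, k_rot = apply_customized_rope(q, k, cos, sin, position_ids, is_training=False)
print(q_rot.shape, k_rot.shape)  # torch.Size([1, 2, 4, 8]) for both
```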