Fix starcoder2 accuracy issue and optimize performance with fused RoPE
mandy-li committed Jun 25, 2024
1 parent 9e1319f commit 00a3546
Showing 1 changed file with 32 additions and 2 deletions.
@@ -22,6 +22,13 @@
 )
 
 
+try:
+    from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
+except ImportError:
+    print("Not using HPU fused kernel for apply_rotary_pos_emb")
+    FusedRoPE = None
+
+
 logger = logging.get_logger(__name__)
 
 
@@ -47,6 +54,7 @@ def gaudi_starcoder2_attention_forward(
             "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
         )
     bsz, q_len, _ = hidden_states.size()
+    self.norm_factor = 1.0 / math.sqrt(self.head_dim)
 
     query_states = self.q_proj(hidden_states)
     key_states = self.k_proj(hidden_states)
@@ -70,7 +78,7 @@
     else:
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+    query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids, self.training)
 
     if past_key_value is not None:
         if token_idx is not None:
@@ -90,7 +98,7 @@
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.norm_factor
 
     if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
         raise ValueError(
@@ -470,3 +478,25 @@ def prepare_inputs_for_generation(
         }
     )
     return model_inputs
+
+
+def apply_customized_rope(q, k, cos, sin, position_ids, is_training):
+    if q.device.type == "hpu" and FusedRoPE:
+        if not is_training and (q.dtype == torch.bfloat16 or k.dtype == torch.bfloat16):
+            return FusedRoPE.apply(
+                q,
+                cos.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
+                sin.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
+                position_ids,
+            ), FusedRoPE.apply(
+                k,
+                cos.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
+                sin.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
+                position_ids,
+            )
+        else:
+            return FusedRoPE.apply(
+                q, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids
+            ), FusedRoPE.apply(k, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids)
+    else:
+        return apply_rotary_pos_emb(q, k, cos, sin, position_ids)

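For anyone reading the patch on a machine without Gaudi hardware, the guarded import at the top of the file is what lets apply_customized_rope degrade gracefully: FusedRoPE stays None and the stock apply_rotary_pos_emb path runs instead. A minimal, self-contained sketch of that pattern (rope_backend_name is an illustrative helper, not part of the patch):

try:
    from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
except ImportError:
    FusedRoPE = None  # no HPU software stack installed; the eager RoPE path is used


def rope_backend_name():
    # Report which RoPE implementation this machine would take.
    return "fused-hpu" if FusedRoPE is not None else "eager-fallback"


print(rope_backend_name())  # prints "eager-fallback" without habana_frameworks installed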
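The other half of the change trades a division by math.sqrt(self.head_dim) on every attention call for a multiply by self.norm_factor, a reciprocal computed once per forward. A quick CPU-only sanity check that the two scalings agree, with small dummy shapes assumed purely for illustration:

import math

import torch

bsz, num_heads, q_len, head_dim = 1, 4, 8, 16
query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_heads, q_len, head_dim)

norm_factor = 1.0 / math.sqrt(head_dim)  # precomputed once, as in the patch
attn_mul = torch.matmul(query_states, key_states.transpose(2, 3)) * norm_factor
attn_div = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
assert torch.allclose(attn_mul, attn_div)  # identical up to float rounding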