Commit

for long input, enabled fp8 attn for prefill

schoi-habana committed Jun 20, 2024
1 parent f428dd5 commit 1d35568
Showing 1 changed file with 54 additions and 27 deletions.
81 changes: 54 additions & 27 deletions optimum/habana/transformers/models/mistral/modeling_mistral.py
@@ -135,9 +135,12 @@ def forward(self, x, y):


# Copied from GaudiMixtralAttentionLongSequence
class GaudiMistralAttentionLongSequence:
@staticmethod
def forward(q, k, v, mask, causal, q_block_size):
class GaudiMistralAttentionLongSequence(torch.nn.Module):
def __init__(self):
super().__init__()
self.fsdpa_module = ModuleFusedSDPA(FusedSDPA)

def forward(self, q, k, v, mask, causal, q_block_size):
"""
Support long sequences at the prompt phase
"""
@@ -153,8 +156,10 @@ def forward(q, k, v, mask, causal, q_block_size):
s, e = i * q_block_size, (i + 1) * q_block_size
row_q = q[:, :, s:e, :]
row_mask = mask[:, :, s:e, :]
row_o = attn_output[:, :, s:e, :]
row_o.fill_(FusedSDPA.apply(row_q, k, v, row_mask, 0.0, causal, None))
# write each block's result directly into the preallocated output (saves memory versus filling a row view)
attn_output[:, :, s:e, :] = self.fsdpa_module(row_q, k, v, row_mask, 0.0, causal, None)

if q_padding != 0:
attn_output = attn_output[:, :, :-q_padding, :]
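For context, the pattern introduced above — pad the query length up to a multiple of the block size, run fused SDPA one query block at a time against the full key/value tensors, and write each block straight into a preallocated output — can be sketched with stock PyTorch as follows. This is an illustrative rewrite, not the repository's code: it uses torch.nn.functional.scaled_dot_product_attention in place of Habana's FusedSDPA and assumes an additive float attention mask.

import torch
import torch.nn.functional as F

def chunked_prefill_attention(q, k, v, mask, block_size=1024):
    # q, k, v: [batch, heads, seq, head_dim]; mask: additive float mask [batch, 1, q_len, kv_len]
    q_len = q.size(-2)
    padding = (block_size - q_len % block_size) % block_size
    if padding:
        # pad the query (and the query dimension of the mask) up to a full block
        q = F.pad(q, (0, 0, 0, padding))
        mask = F.pad(mask, (0, 0, 0, padding))
    out = torch.empty_like(q)
    for i in range(q.size(-2) // block_size):
        s, e = i * block_size, (i + 1) * block_size
        # each query block attends to the full k/v, so peak score memory is
        # block_size x kv_len instead of q_len x kv_len; writing into the
        # preallocated output avoids an extra per-block row tensor
        out[:, :, s:e, :] = F.scaled_dot_product_attention(
            q[:, :, s:e, :], k, v, attn_mask=mask[:, :, s:e, :]
        )
    return out[:, :, :q_len, :] if padding else out

In the diff, the per-block call goes through self.fsdpa_module (a ModuleFusedSDPA wrapper) rather than the static FusedSDPA.apply; turning the helper into an nn.Module that owns that wrapper is presumably what lets the fp8 quantization flow hook the long-prefill attention, in line with the commit message.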
@@ -229,10 +234,11 @@ def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None):
self.matmul_qk = Matmul()
self.matmul_av = Matmul()
self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None
self.fused_sdpa_long = GaudiMistralAttentionLongSequence() if FusedSDPA else None
self.inp_seq_len = -1
self._init_rope()
self.norm_factor = 1.0 / math.sqrt(self.head_dim)
self.block_size = 1024
self.block_size = 8192  # query block size for chunked prefill attention (was 1024)

def _init_rope(self):
"""
@@ -378,6 +384,7 @@ def forward(

if FusedSDPA and use_flash_attention:
if q_len == 1:
#print("**********388")
# next token
use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
with ht.sdp_kernel(enable_recompute=use_recompute):
@@ -386,29 +393,49 @@
)
else:
# first token
if flash_attention_causal_mask:
# causal masking on first token requires inputs to be of the same length
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, None, 0.0, True, None
)
if not self.training and q_len == key_states.size(-2) and q_len > 8192:
htcore.mark_step()
#print("*************long q attn")
attn_output = self.fused_sdpa_long(
#attn_output = GaudiMistralAttentionLongSequence.forward(
query_states,
key_states,
value_states,
attention_mask,
False,
self.block_size,
)
htcore.mark_step()

#print("***********q_len", q_len)
else:
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, 0.0, False, None
)
elif FusedSDPA and not self.training and q_len == key_states.size(-2) and q_len > 8192:
htcore.mark_step()
attn_output = GaudiMistralAttentionLongSequence.forward(
query_states,
key_states,
value_states,
attention_mask,
False,
self.block_size,
)
htcore.mark_step()
if flash_attention_causal_mask:
# causal masking on first token requires inputs to be of the same length
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, None, 0.0, True, None
)
else:
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, 0.0, False, None
)
# the former module-level long-sequence elif branch (shown removed above) is superseded
# by the fused_sdpa_long path inside the first-token branch
else:
print("***********not fused")
# repeat k/v heads if n_kv_heads < n_heads
query_states, key_states, value_states, attention_mask = gaudi_mistral_repeat_kv(
query_states, key_states, value_states, attention_mask, self.num_key_value_groups
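Taken together, the dispatch this commit arrives at for the fused-attention path can be summarized by the sketch below. It is a plain-PyTorch stand-in, not the repository's code: F.scaled_dot_product_attention replaces FusedSDPA, the ht.sdp_kernel recompute context is omitted, and the threshold constant and function names are illustrative.

import torch
import torch.nn.functional as F

LONG_PROMPT_THRESHOLD = 8192  # prompts longer than this take the chunked path

def fused_attention_dispatch(q, k, v, mask, causal_mask_only, training, chunked_attn):
    # chunked_attn: e.g. chunked_prefill_attention from the earlier sketch
    q_len = q.size(-2)
    if q_len == 1:
        # decode (next token): a single query row through regular fused SDPA
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    # prefill (first token)
    if not training and q_len == k.size(-2) and q_len > LONG_PROMPT_THRESHOLD:
        # long prompt: process the query block by block to bound memory
        return chunked_attn(q, k, v, mask)
    if causal_mask_only:
        # equal-length q/k lets the kernel apply causal masking internally
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)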
