From 1d3556872855b3c4592dfbe9852b450da0914b77 Mon Sep 17 00:00:00 2001
From: Sun Choi
Date: Tue, 18 Jun 2024 20:19:52 +0000
Subject: [PATCH] Enable FP8 attention for prefill with long inputs

---
 .../models/mistral/modeling_mistral.py | 61 +++++++++++++++++-------------
 1 file changed, 34 insertions(+), 27 deletions(-)

diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py
index 42b573515..67231331d 100644
--- a/optimum/habana/transformers/models/mistral/modeling_mistral.py
+++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py
@@ -135,9 +135,12 @@ def forward(self, x, y):
 
 
 # Copy from GaudiMixtralAttentionLongSequence
-class GaudiMistralAttentionLongSequence:
-    @staticmethod
-    def forward(q, k, v, mask, causal, q_block_size):
+class GaudiMistralAttentionLongSequence(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.fsdpa_module = ModuleFusedSDPA(FusedSDPA)
+
+    def forward(self, q, k, v, mask, causal, q_block_size):
         """
         Support long sequence at prompt phase
         """
@@ -153,8 +156,8 @@ def forward(q, k, v, mask, causal, q_block_size):
             s, e = i * q_block_size, (i + 1) * q_block_size
             row_q = q[:, :, s:e, :]
             row_mask = mask[:, :, s:e, :]
-            row_o = attn_output[:, :, s:e, :]
-            row_o.fill_(FusedSDPA.apply(row_q, k, v, row_mask, 0.0, causal, None))
+            # Write each attention block straight into the output slice, which saves memory.
+            attn_output[:, :, s:e, :] = self.fsdpa_module(row_q, k, v, row_mask, 0.0, causal, None)
 
         if q_padding != 0:
             attn_output = attn_output[:, :, :-q_padding, :]
@@ -229,10 +232,11 @@ def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None):
         self.matmul_qk = Matmul()
         self.matmul_av = Matmul()
         self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None
+        self.fused_sdpa_long = GaudiMistralAttentionLongSequence() if FusedSDPA else None
         self.inp_seq_len = -1
         self._init_rope()
         self.norm_factor = 1.0 / math.sqrt(self.head_dim)
-        self.block_size = 1024
+        self.block_size = 8192  # query block size for the long-sequence prefill path (was 1024)
 
     def _init_rope(self):
         """
@@ -386,29 +390,32 @@ def forward(
                     )
             else:
                 # first token
-                if flash_attention_causal_mask:
-                    # causal masking on first token requires inputs to be of the same length
-                    with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
-                        attn_output = self.fused_scaled_dot_product_attention(
-                            query_states, key_states, value_states, None, 0.0, True, None
-                        )
+                # For long prompts, tile the query and run the fused SDPA module block by
+                # block so that FP8 attention can also be used during prefill.
+                if not self.training and q_len == key_states.size(-2) and q_len > 8192:
+                    htcore.mark_step()
+                    attn_output = self.fused_sdpa_long(
+                        query_states,
+                        key_states,
+                        value_states,
+                        attention_mask,
+                        False,
+                        self.block_size,
+                    )
+                    htcore.mark_step()
                 else:
-                    with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
-                        attn_output = self.fused_scaled_dot_product_attention(
-                            query_states, key_states, value_states, attention_mask, 0.0, False, None
-                        )
-        elif FusedSDPA and not self.training and q_len == key_states.size(-2) and q_len > 8192:
-            htcore.mark_step()
-            attn_output = GaudiMistralAttentionLongSequence.forward(
-                query_states,
-                key_states,
-                value_states,
-                attention_mask,
-                False,
-                self.block_size,
-            )
-            htcore.mark_step()
+                    if flash_attention_causal_mask:
+                        # causal masking on first token requires inputs to be of the same length
+                        with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
+                            attn_output = self.fused_scaled_dot_product_attention(
+                                query_states, key_states, value_states, None, 0.0, True, None
+                            )
+                    else:
+                        with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
+                            attn_output = self.fused_scaled_dot_product_attention(
+                                query_states, key_states, value_states, attention_mask, 0.0, False, None
+                            )
         else:
             # repeat k/v heads if n_kv_heads < n_heads
             query_states, key_states, value_states, attention_mask = gaudi_mistral_repeat_kv(
                 query_states, key_states, value_states, attention_mask, self.num_key_value_groups
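
Note on the change: GaudiMistralAttentionLongSequence now owns a ModuleFusedSDPA instance and is
itself an nn.Module, so the block-by-block prefill attention goes through the same wrapped SDPA op
as the regular path, which is presumably what lets it run in FP8 when a QUANT_CONFIG is loaded. The
standalone sketch below illustrates the underlying query-tiling scheme, with stock PyTorch
scaled_dot_product_attention standing in for Habana's FusedSDPA; the function name, tensor shapes,
and block size are illustrative assumptions, not part of the patch.

    # Minimal sketch of the query-tiling scheme used by the long-sequence prefill path.
    # torch.nn.functional.scaled_dot_product_attention stands in for Gaudi's FusedSDPA here.
    import math

    import torch
    import torch.nn.functional as F


    def blocked_prefill_attention(q, k, v, mask, q_block_size):
        """q, k, v: (batch, heads, seq, head_dim); mask: additive (batch, 1, q_len, kv_len) or None."""
        q_len = q.size(-2)
        q_tiles = math.ceil(q_len / q_block_size)
        q_padding = q_tiles * q_block_size - q_len
        # Pad the query (and the query axis of the mask) up to a multiple of the block size.
        q = F.pad(q, (0, 0, 0, q_padding), "constant", 0)
        if mask is not None:
            mask = F.pad(mask, (0, 0, 0, q_padding), "constant", -10000.0)
        attn_output = torch.zeros_like(q)
        for i in range(q_tiles):
            s, e = i * q_block_size, (i + 1) * q_block_size
            row_mask = mask[:, :, s:e, :] if mask is not None else None
            # Attend one query block to the full K/V and write it straight into the output slice.
            attn_output[:, :, s:e, :] = F.scaled_dot_product_attention(q[:, :, s:e, :], k, v, attn_mask=row_mask)
        if q_padding != 0:
            attn_output = attn_output[:, :, :-q_padding, :]
        return attn_output


    if __name__ == "__main__":
        q = torch.randn(1, 4, 2100, 64)
        k = torch.randn(1, 4, 2100, 64)
        v = torch.randn(1, 4, 2100, 64)
        out = blocked_prefill_attention(q, k, v, mask=None, q_block_size=1024)
        ref = F.scaled_dot_product_attention(q, k, v)
        print(out.shape, (out - ref).abs().max())  # same shape, tiny numerical difference

The direct slice assignment in the loop mirrors the patched loop body: each tile's result is written
into the preallocated output instead of going through a separate row view, so only one block of
attention scores is live at a time while the key/value tensors stay shared across tiles.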