diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py
index 6dc40a73b..c3370d8cb 100644
--- a/optimum/habana/transformers/modeling_utils.py
+++ b/optimum/habana/transformers/modeling_utils.py
@@ -88,6 +88,7 @@
     gaudi_esm_for_protein_folding_forward,
     gaudi_esmfolding_trunk_forward,
     gaudi_falcon_attention_split_heads,
+    gaudi_falcon_linear_forward,
     gaudi_generate_speech,
     gaudi_get_extended_attention_mask,
     gaudi_gpt2_block_forward,
@@ -307,6 +308,7 @@ def adapt_transformers_to_gaudi():
     transformers.models.falcon.modeling_falcon.FalconModel = GaudiFalconModel
     transformers.models.falcon.modeling_falcon.FalconDecoderLayer = GaudiFalconDecoderLayer
     transformers.models.falcon.modeling_falcon.FalconAttention._split_heads = gaudi_falcon_attention_split_heads
+    transformers.models.falcon.modeling_falcon.FalconLinear.forward = gaudi_falcon_linear_forward
 
     # Optimization for t5 on Gaudi
     transformers.models.t5.modeling_t5.T5LayerNorm.forward = gaudi_t5_layernorm_forward
diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py
index 1582d3f09..a075cd7ae 100644
--- a/optimum/habana/transformers/models/__init__.py
+++ b/optimum/habana/transformers/models/__init__.py
@@ -49,6 +49,7 @@
     GaudiFalconMLP,
     GaudiFalconModel,
     gaudi_falcon_attention_split_heads,
+    gaudi_falcon_linear_forward,
 )
 from .gpt2 import GaudiGPT2Attention, GaudiGPT2LMHeadModel, gaudi_gpt2_block_forward, gaudi_gpt2_forward
 from .gpt_bigcode import (
diff --git a/optimum/habana/transformers/models/falcon/__init__.py b/optimum/habana/transformers/models/falcon/__init__.py
index 00c73ad11..a42b846c6 100644
--- a/optimum/habana/transformers/models/falcon/__init__.py
+++ b/optimum/habana/transformers/models/falcon/__init__.py
@@ -5,4 +5,5 @@
     GaudiFalconMLP,
     GaudiFalconModel,
     gaudi_falcon_attention_split_heads,
+    gaudi_falcon_linear_forward,
 )
diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py
index 9b9a74c12..c33a8bb78 100644
--- a/optimum/habana/transformers/models/falcon/modeling_falcon.py
+++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py
@@ -83,6 +83,11 @@ def apply_customized_rope(q, k, cos, sin, position_ids):
         return apply_rotary_pos_emb(q, k, cos, sin, position_ids)
 
 
+def gaudi_falcon_linear_forward(self, input: torch.Tensor) -> torch.Tensor:
+    hidden_states = F.linear(input, self.weight, bias=self.bias)
+    return hidden_states
+
+
 def gaudi_falcon_attention_split_heads(
     self, fused_qkv: torch.Tensor
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
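
For context, below is a minimal, self-contained sketch (not part of the patch) of what the override changes. FalconLinearLike is a hypothetical stand-in for transformers' FalconLinear, whose upstream forward performs the matmul and the bias add as two separate ops; the Gaudi override installed by adapt_transformers_to_gaudi() expresses the same projection as a single F.linear call, which is presumably what lets it dispatch as one fused linear kernel on HPU. The sketch just checks that the two forms are numerically equivalent.

import torch
import torch.nn as nn
import torch.nn.functional as F


class FalconLinearLike(nn.Linear):
    """Stand-in for FalconLinear: explicit matmul followed by a separate bias add."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        hidden_states = input @ self.weight.T
        if self.bias is None:
            return hidden_states
        return hidden_states + self.bias


def gaudi_falcon_linear_forward(self, input: torch.Tensor) -> torch.Tensor:
    # Same math as above, expressed as one F.linear call (the form this patch installs).
    return F.linear(input, self.weight, bias=self.bias)


if __name__ == "__main__":
    torch.manual_seed(0)
    layer = FalconLinearLike(8, 16)
    x = torch.randn(2, 4, 8)

    before = layer(x)
    # Monkey-patch the forward, mirroring what adapt_transformers_to_gaudi() does for
    # transformers.models.falcon.modeling_falcon.FalconLinear.forward.
    FalconLinearLike.forward = gaudi_falcon_linear_forward
    after = layer(x)

    print(torch.allclose(before, after, atol=1e-6))  # True: outputs are unchanged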