upd Bloom _prepare_attn_mask()
younesbelkada committed Nov 21, 2023
1 parent 4295ee8 commit 4cdeed7
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion src/petals/models/bloom/block.py
@@ -6,6 +6,7 @@
 from typing import Optional, Tuple
 
 import torch
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor
 
@@ -26,7 +27,13 @@ def forward(
         attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
         if alibi is None:
             alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
-        attention_mask = BloomModel._prepare_attn_mask(None, attention_mask, (batch_size, seq_length), past_length)
+        fake_inputs_embeds = torch.tensor([42], dtype=torch.float32)
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask=attention_mask,
+            input_shape=(batch_size, seq_length),
+            inputs_embeds=fake_inputs_embeds,
+            past_key_values_length=past_length,
+        )
         return super().forward(
             hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
         )
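
For context, here is a minimal standalone sketch of what the new helper does, presumably motivated by newer transformers releases replacing the per-model _prepare_attn_mask methods with this shared utility. It assumes transformers >= 4.35 (where transformers.modeling_attn_mask_utils provides _prepare_4d_causal_attention_mask, as imported in the diff above); the shapes and the placeholder-tensor trick mirror the commit, everything else is illustrative:

    # Minimal sketch of _prepare_4d_causal_attention_mask in isolation.
    # Assumes transformers >= 4.35; the concrete sizes below are illustrative.
    import torch
    from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

    batch_size, seq_length, past_length = 1, 3, 2
    seq_length_with_past = seq_length + past_length

    # 2D padding mask: 1 = attend, 0 = padded (no padding in this toy example).
    attention_mask = torch.ones((batch_size, seq_length_with_past))

    # The helper reads only the dtype (and device) from inputs_embeds, which is
    # why the commit can pass a throwaway one-element tensor instead of real
    # embeddings.
    fake_inputs_embeds = torch.tensor([42], dtype=torch.float32)

    mask_4d = _prepare_4d_causal_attention_mask(
        attention_mask=attention_mask,
        input_shape=(batch_size, seq_length),
        inputs_embeds=fake_inputs_embeds,
        past_key_values_length=past_length,
    )

    # Additive 4D mask: 0.0 where attention is allowed, dtype-min where masked.
    print(mask_4d.shape)  # torch.Size([1, 1, 3, 5]) = (batch, 1, query_len, key_len)

Note that unlike the boolean mask produced by the old BloomModel._prepare_attn_mask, the helper returns an additive float mask, which is the format the refactored BloomBlock expects, so it can be passed straight through to super().forward().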
