diff --git a/lavis/models/blip_diffusion_models/modeling_ctx_clip.py b/lavis/models/blip_diffusion_models/modeling_ctx_clip.py
index 737b77d3..2e8c6efe 100644
--- a/lavis/models/blip_diffusion_models/modeling_ctx_clip.py
+++ b/lavis/models/blip_diffusion_models/modeling_ctx_clip.py
@@ -13,8 +13,8 @@
 from transformers.models.clip.modeling_clip import (
     CLIPEncoder,
     CLIPPreTrainedModel,
-    _expand_mask,
 )
+from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
 
 
 class CtxCLIPTextModel(CLIPPreTrainedModel):
@@ -136,7 +136,7 @@ def forward(
         # expand attention_mask
         if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
+            attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
 
         encoder_outputs = self.encoder(
             inputs_embeds=hidden_states,
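
Note: the patch above targets newer transformers releases, where the private `_expand_mask` helper was removed from `transformers.models.clip.modeling_clip` and its replacement `_prepare_4d_attention_mask` lives in `transformers.modeling_attn_mask_utils`. If the file also needs to keep working on older transformers versions, a try/except import fallback is one option. This is only a minimal compatibility sketch, not part of the patch; it assumes both helpers accept the same `(mask, dtype)` arguments, which holds for the versions where each helper exists.

# Compatibility sketch (illustration only, not part of the patch).
try:
    # newer transformers: the 4D attention-mask helper lives here
    from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
except ImportError:
    # older transformers: fall back to the legacy CLIP helper under the new name
    from transformers.models.clip.modeling_clip import (
        _expand_mask as _prepare_4d_attention_mask,
    )

# Either helper expands a [bsz, seq_len] padding mask into the
# [bsz, 1, tgt_seq_len, src_seq_len] additive mask the CLIP encoder expects:
# attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)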