Skip to content

Commit

Permalink
minor fix
Browse files: browse the repository at this point in the history
  • Loading branch information
arielnlee committed Mar 27, 2024
1 parent 1bf4898 commit 40949ba
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions llava/train/train.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -561,6 +561,14 @@ def preprocess_mpt(

if has_image:
input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt') for prompt in conversations], dim=0)
else:
input_ids = tokenizer(
conversations,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids

targets = input_ids.clone()
assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
Expand Down

0 comments on commit 40949ba

Please sign in to comment.