
Fix a bug in function "prepare_inputs_labels_for_multimodal" of "LlavaMetaForCausalLM" when there are more than one image in each conversation of a batch. #967

Open
wants to merge 1 commit into main

Conversation

@zsxm1998 commented Jan 3, 2024

Dear author:

Thanks for your excellent work on LLaVA.

While reading your code, I found that in this line there is a conditional statement that, I suppose, handles the case where there is more than one image per conversation in a batch. So "images" in this case is either a tensor of shape [N, M, 3, H, W], where N is the batch size and M is the number of images per conversation, or a list containing N tensors of shape [m_n, 3, H, W], where the image count m_n may differ across the conversations in a batch.

But in this if block, the resulting "image_features" may have the wrong shape when at least one conversation contains more than one image, which raises an "IndexError" exception at this line of the file.
I have commented the shape of the result of each line in this if block for better understanding:

if type(images) is list or images.ndim == 5: # [N, M, 3, H, W], where N is the batch size and M is the number of images per conversation
    concat_images = torch.cat([image for image in images], dim=0) # [N*M, 3, H, W]
    image_features = self.encode_images(concat_images) # [N*M, L, D], where L is the sequence length and D is the embedding dim
    split_sizes = [image.shape[0] for image in images] # [M]*N, a list containing N elements, each equal to M
    image_features = torch.split(image_features, split_sizes, dim=0) # [M, L, D]*N, a tuple containing N tensors of shape [M, L, D]
    image_features = [x.flatten(0, 1).to(self.device) for x in image_features] # [M*L, D]*N, a list containing N tensors of shape [M*L, D]

I don't understand why flatten(0, 1) is used to reshape x: it concatenates the features of multiple images into only ONE feature. As a result, image_features contains only N image features, but there should have been N*M of them.
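To see the shape collapse without loading the model, here is a minimal sketch with dummy tensors (the values of N, M, L, D are illustrative, and a random tensor stands in for self.encode_images):

import torch

N, M, L, D = 2, 2, 576, 4096  # batch size, images per conversation, tokens per image, embedding dim (illustrative)
images = torch.randn(N, M, 3, 336, 336)

concat_images = torch.cat([image for image in images], dim=0)     # [N*M, 3, H, W]
image_features = torch.randn(concat_images.shape[0], L, D)        # stand-in for self.encode_images: [N*M, L, D]
split_sizes = [image.shape[0] for image in images]                # [M, M]
image_features = torch.split(image_features, split_sizes, dim=0)  # N tensors of shape [M, L, D]
image_features = [x.flatten(0, 1) for x in image_features]        # N tensors of shape [M*L, D]

print(len(image_features))      # 2: only N entries survive, not N*M = 4
print(image_features[0].shape)  # torch.Size([1152, 4096]): two images merged into ONE feature

With four <image> tokens in the batch but only two entries in image_features, the downstream per-token indexing runs past the end of the list, which is exactly the "IndexError".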

I wrote a simple script to reproduce the bug as follows:

import os
import torch
from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX
from PIL import Image

os.environ['CUDA_VISIBLE_DEVICES'] = '1'

tokenizer, model, image_processor, context_len = load_pretrained_model('checkpoints/llava-v1.5-7b', None, 'llava-v1.5-7b')

input_text = ['<image>\nThis is the first sentence<image>.', '<image>\nThis is the second sentence.\n<image>']
image_files = [['your_image_path',
                'your_image_path'],
               ['your_image_path',
                'your_image_path']]

input_ids = [tokenizer_image_token(s, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt') for s in input_text]
input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=tokenizer.pad_token_id).cuda()
print('input_ids.shape:', input_ids.shape)


image_lists = [torch.cat([image_processor.preprocess(Image.open(f), return_tensors='pt')['pixel_values'] for f in files]).half().cuda() for files in image_files]
image_lists = torch.stack(image_lists)
print('image_lists.shape:', image_lists.shape)
print('image_lists[1].shape:', image_lists[1].shape)

model(input_ids=input_ids, images=image_lists)

This will raise the "IndexError" exception at this line of the file. You can also modify "input_text" and "image_files" so that the images passed to the model form a list of N tensors of shape [m_n, 3, H, W], with the image count m_n differing across the input conversations (or sentences), and the exception is still raised; see the sketch below.
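For that variable-count case, the only change is to skip the torch.stack and keep a Python list of per-conversation tensors. A sketch under the same assumptions as the script above (it reuses the loaded tokenizer, model, and image_processor, and the paths are placeholders):

input_text = ['<image>\nOne image in this sentence.', '<image>\nTwo images in this sentence.\n<image>']
image_files = [['your_image_path'],
               ['your_image_path',
                'your_image_path']]

input_ids = [tokenizer_image_token(s, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt') for s in input_text]
input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=tokenizer.pad_token_id).cuda()

# No torch.stack here: image_lists stays a list of tensors of shape [m_n, 3, H, W], with m_n = 1 and 2.
image_lists = [torch.cat([image_processor.preprocess(Image.open(f), return_tensors='pt')['pixel_values'] for f in files]).half().cuda() for files in image_files]

model(input_ids=input_ids, images=image_lists)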

So I modified the code in this if block. If I have made any mistake or misunderstood this code, please don't hesitate to correct me.
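For reference, the idea of the change is roughly as follows (a sketch only; the commit in this PR contains the exact diff): drop the flatten(0, 1), so that image_features ends up as one [L, D] tensor per image across the whole batch, matching the downstream loop that consumes one entry per <image> token.

if type(images) is list or images.ndim == 5:
    concat_images = torch.cat([image for image in images], dim=0)  # [sum(m_n), 3, H, W]
    image_features = self.encode_images(concat_images)             # [sum(m_n), L, D]
    # No flatten(0, 1): iterating over dim 0 yields one [L, D] tensor per image.
    image_features = [x.to(self.device) for x in image_features]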

Thank you.
