Timmvit and laion data #966

Open · wants to merge 2 commits into base: main

2 changes: 2 additions & 0 deletions llava/model/llava_arch.py
@@ -44,6 +44,7 @@ def initialize_vision_modules(self, model_args, fsdp=None):
        mm_vision_select_layer = model_args.mm_vision_select_layer
        mm_vision_select_feature = model_args.mm_vision_select_feature
        pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
        use_timm_vision_tower = model_args.use_timm_vision_tower

        self.config.mm_vision_tower = vision_tower

@@ -66,6 +67,7 @@ def initialize_vision_modules(self, model_args, fsdp=None):
        self.config.mm_hidden_size = vision_tower.hidden_size
        self.config.mm_vision_select_layer = mm_vision_select_layer
        self.config.mm_vision_select_feature = mm_vision_select_feature
        self.config.use_timm_vision_tower = use_timm_vision_tower

        if getattr(self, 'mm_projector', None) is None:
            self.mm_projector = build_vision_projector(self.config)
4 changes: 3 additions & 1 deletion llava/model/multimodal_encoder/builder.py
@@ -1,10 +1,12 @@
import os
from .clip_encoder import CLIPVisionTower

from .timm_clip_encoder import TIMMVisionTower

def build_vision_tower(vision_tower_cfg, **kwargs):
    vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
    is_absolute_path_exists = os.path.exists(vision_tower)
    if vision_tower_cfg.use_timm_vision_tower:
        return TIMMVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
    if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"):
        return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)

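For context, a minimal sketch (not part of the diff) of how the new flag routes tower construction; the config object and path below are hypothetical, carrying only the attributes build_vision_tower reads:

from types import SimpleNamespace
from llava.model.multimodal_encoder.builder import build_vision_tower

# Hypothetical config: the use_timm_vision_tower check runs before the
# openai/laion CLIP branch, so the tower name/path is handed to
# TIMMVisionTower unchanged.
cfg = SimpleNamespace(
    mm_vision_tower="/path/to/timm_checkpoint_dir",  # hypothetical local path
    use_timm_vision_tower=True,                      # new flag added by this PR
)
tower = build_vision_tower(cfg)  # -> TIMMVisionTower
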
88 changes: 88 additions & 0 deletions llava/model/multimodal_encoder/timm_clip_encoder.py
@@ -0,0 +1,88 @@
import torch
import torch.nn as nn
import timm

class TIMMImageProcessor:
    def __init__(self, vision_tower):
        self.image_mean = list(vision_tower.default_cfg["mean"])  # [0.48145466, 0.4578275, 0.40821073]
        crop_size = vision_tower.default_cfg["input_size"]
        self.crop_size = {
            'height': crop_size[1],
            'width': crop_size[2]
        }
        # get model specific transforms (normalization, resize)
        data_config = timm.data.resolve_model_data_config(vision_tower)
        self.transforms = timm.data.create_transform(**data_config, is_training=False)

    def preprocess(self, img, return_tensors='pt'):
        if img.mode != 'RGB':
            img = img.convert('RGB')
        transformed_img = self.transforms(img)
        return {'pixel_values': [transformed_img]}


class TIMMVisionTower(nn.Module):
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False
        self.vision_tower_name = vision_tower
        self.load_model()

    def load_model(self):
        import os
        files = os.listdir(self.vision_tower_name)
        for file_name in files:
            if file_name.endswith('.bin'):
                bin_file = os.path.join(self.vision_tower_name, file_name)
        assert os.path.exists(bin_file)
        self.vision_tower = timm.create_model(
            self.vision_tower_name,
            pretrained=True,
            # features_only=True,
            pretrained_cfg_overlay=dict(file=bin_file),
            num_classes=0,  # remove classifier nn.Linear
            global_pool=''
        )
        self.vision_tower = self.vision_tower.eval()
        print("loaded!")

        self.image_processor = TIMMImageProcessor(self.vision_tower)
        self.is_loaded = True


    # @torch.no_grad()
    def forward(self, images):
        if type(images) is list:
            image_features = []
            for image in images:
                image_feature = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0)).to(image.dtype)[:, 1:, :]  # remove CLS token
                image_features.append(image_feature)
        else:
            image_features = self.vision_tower(images.to(device=self.device, dtype=self.dtype)).to(images[0].dtype)[:, 1:, :]  # remove CLS token

        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower._parameters[list(self.vision_tower._parameters.keys())[0]].dtype

    @property
    def device(self):
        return self.vision_tower._parameters[list(self.vision_tower._parameters.keys())[0]].device

    # @property
    # def config(self):
    #     if self.is_loaded:
    #         return self.vision_tower.config
    #     else:
    #         return self.cfg_only

    @property
    def hidden_size(self):
        return self.vision_tower.embed_dim
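
A rough usage sketch of the new tower in isolation (illustrative, not part of the diff). It assumes a local checkpoint directory containing a *.bin file and a model whose forward pass returns a token sequence with a leading CLS token:

import torch
from PIL import Image
from llava.model.multimodal_encoder.timm_clip_encoder import TIMMVisionTower

# Hypothetical checkpoint directory; load_model() looks for a *.bin file inside.
tower = TIMMVisionTower("/path/to/timm_checkpoint_dir", args=None)
image = Image.open("example.jpg")
pixels = tower.image_processor.preprocess(image)['pixel_values'][0]
with torch.no_grad():
    feats = tower(pixels.unsqueeze(0))  # (1, num_patches, hidden_size), CLS stripped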

111 changes: 108 additions & 3 deletions llava/train/train.py
@@ -35,7 +35,7 @@
from llava.mm_utils import tokenizer_image_token

from PIL import Image

import random

local_rank = None

@@ -58,6 +58,7 @@ class ModelArguments:
    mm_use_im_start_end: bool = field(default=False)
    mm_use_im_patch_token: bool = field(default=True)
    mm_vision_select_feature: Optional[str] = field(default="patch")
    use_timm_vision_tower: bool = field(default=False)


@dataclass
@@ -68,6 +69,10 @@ class DataArguments:
    is_multimodal: bool = False
    image_folder: Optional[str] = field(default=None)
    image_aspect_ratio: str = 'square'
    ##### use laion_data
    laion_path: str = field(default=None,
                            metadata={"help": "path for laion"})
    laion_amount: int = 0


@dataclass
@@ -623,6 +628,30 @@ def get_tokenize_len(prompts):
    return dict(input_ids=input_ids, labels=targets)


import webdataset as wds
import glob
import io
def create_laion_dataset(
    data_dir,
    cache_path=None,
):
    """Create a WebDataset reader; it can read a webdataset of image, text and json."""
    meta_file_list = glob.glob(os.path.join(data_dir, "split*", "*.tar"))
    print("Found", len(meta_file_list), "tar files")
    image_dataset = wds.WebDataset(meta_file_list, cache_dir=cache_path,
                                   cache_size=20,
                                   handler=wds.handlers.warn_and_continue,
                                   resampled=True)
    image_dataloader = wds.WebLoader(
        dataset=image_dataset,
        batch_size=1,
        num_workers=0,
        pin_memory=True,
        prefetch_factor=None,
    )
    data_iterator = image_dataloader.iterator()
    return data_iterator

class LazySupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""

@@ -635,10 +664,30 @@ def __init__(self, data_path: str,
        rank0_print("Formatting inputs...Skip in lazy mode")
        self.tokenizer = tokenizer
        self.list_data_dict = list_data_dict

        self.caption_prompt = [
            'Describe the image concisely.',
            'Provide a brief description of the given image.',
            'Offer a succinct explanation of the picture presented.',
            'Summarize the visual content of the image.',
            'Give a short and clear explanation of the subsequent image.',
            'Share a concise interpretation of the image provided.',
            """Present a compact description of the photo's key features.""",
            'Relay a brief, clear account of the picture shown.',
            'Render a clear and concise summary of the photo.',
            'Write a terse but informative summary of the picture.',
            'Create a compact narrative representing the image presented.'
        ]
        self.laion_iter = None
        ###### add laion data
        if data_args.laion_path is not None:
            self.laion_iter = create_laion_dataset(data_args.laion_path)

        self.data_args = data_args

    def __len__(self):
        return len(self.list_data_dict)
        # return len(self.list_data_dict)
        return len(self.list_data_dict) + self.data_args.laion_amount

    @property
    def lengths(self):
@@ -657,7 +706,63 @@ def modality_lengths(self):
            length_list.append(cur_len)
        return length_list

    def get_laion_item(self):
        valid_flag = False
        while not valid_flag:
            cur_data = next(self.laion_iter)
            main_data = json.loads(cur_data['json'][0])
            if main_data['width'] >= 50 and main_data['height'] >= 50 and main_data['caption'] is not None and len(main_data['caption']) < 200:
                valid_flag = True
                break
        image_file = io.BytesIO(cur_data['jpg'][0])
        image = Image.open(image_file).convert('RGB')
        caption = json.loads(cur_data['json'][0])['caption']

        processor = self.data_args.image_processor
        if self.data_args.image_aspect_ratio == 'pad':
            def expand2square(pil_img, background_color):
                width, height = pil_img.size
                if width == height:
                    return pil_img
                elif width > height:
                    result = Image.new(pil_img.mode, (width, width), background_color)
                    result.paste(pil_img, (0, (width - height) // 2))
                    return result
                else:
                    result = Image.new(pil_img.mode, (height, height), background_color)
                    result.paste(pil_img, ((height - width) // 2, 0))
                    return result
            image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
            image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        else:
            image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]

        cur_prompt = random.choice(self.caption_prompt)
        format_item = {}
        format_item['conversations'] = []

        format_item['conversations'].append({'from': 'human', 'value': f'{cur_prompt}\n{DEFAULT_IMAGE_TOKEN}'})
        format_item['conversations'].append({'from': 'gpt', 'value': f'{caption}'})

        sources = preprocess_multimodal(
            copy.deepcopy([format_item['conversations']]),
            self.data_args)

        data_dict = preprocess(
            sources,
            self.tokenizer,
            has_image=True)

        data_dict = dict(input_ids=data_dict["input_ids"][0],
                         labels=data_dict["labels"][0])
        data_dict['image'] = image

        return data_dict

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        if i >= len(self.list_data_dict):
            return self.get_laion_item()

        sources = self.list_data_dict[i]
        if isinstance(i, int):
            sources = [sources]
@@ -876,7 +981,7 @@ def make_inputs_require_grad(module, input, output):
            model_args=model_args,
            fsdp=training_args.fsdp
        )

        vision_tower = model.get_vision_tower()
        vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)

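How the LAION mixing is driven by the new data arguments, as a sketch (the paths and counts are hypothetical, and training normally fills these fields via HfArgumentParser rather than direct construction): __len__ reports len(json data) + laion_amount, and __getitem__ falls through to get_laion_item() for any index past the json data, drawing the next sample from the resampled webdataset stream and formatting it as a single-turn captioning conversation.

from llava.train.train import DataArguments

# Hypothetical data arguments for mixing LAION captions into fine-tuning.
# laion_path should contain split*/ subdirectories of *.tar shards;
# laion_amount sets how many extra "virtual" samples the dataset exposes.
data_args = DataArguments(
    data_path="/path/to/llava_instruct.json",
    image_folder="/path/to/images",
    laion_path="/path/to/laion_shards",
    laion_amount=100000,
)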