Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Setup Steps #279

Open
bahmanian opened this issue Mar 5, 2023 · 2 comments
Open

Setup Steps #279

bahmanian opened this issue Mar 5, 2023 · 2 comments

Comments

@bahmanian
Copy link

Could someone please explain exactly how to use it?

@u1ug
Copy link

u1ug commented Nov 5, 2023

Here is some sample code for training:

import torch
from dalle2_pytorch.tokenizer import SimpleTokenizer
from torch.utils.data import DataLoader
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, OpenAIClipAdapter, Unet, Decoder, \
    DecoderTrainer
from torchvision.utils import make_grid
import torchvision.transforms as T
from torchvision.utils import save_image
from PIL import Image
from datetime import datetime
import os
import torch.utils.data as data
import json
from torchvision.utils import save_image


def read_metadata(text: str) -> list[dict]:
    """Parse JSON Lines text into a list of records.

    Each non-blank line of *text* is decoded as one JSON document, in order.
    Blank and whitespace-only lines (a trailing newline, a stray carriage
    return, an indented empty line) are skipped instead of being handed to
    the JSON parser.

    Args:
        text: The full contents of a ``.jsonl`` file.

    Returns:
        The decoded records, one per non-blank line.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Split on '\n' only: str.splitlines() would also break on characters
    # such as U+2028, which may appear unescaped inside a JSON string.
    return [json.loads(line) for line in text.split('\n') if line.strip()]


# ImgTextDataset returns an image tensor and the text caption for it. It works like the Hugging Face text-to-image dataset implementation.
# __init__ reads the dataset description from a `metadata.jsonl` file in which image paths and captions are specified, one JSON object per line:
# {"file_name": "/path/1.png", "text": "sample text 1"}
# {"file_name": "/path/2.png", "text": "sample text 2"}
# ...
# To make a custom dataset, inherit from data.Dataset and implement the __len__ and __getitem__ methods.
class ImgTextDataset(data.Dataset):
    """Image/caption dataset described by a ``metadata.jsonl`` file.

    The folder ``fp`` must contain ``metadata.jsonl``; every record in it
    provides a ``file_name`` (image path, joined onto ``fp``) and a ``text``
    (caption). Indexing returns ``(image_tensor, caption)`` where the image
    has been converted to RGB, resized to 256 x 256 and turned into a tensor.

    The returned tensor stays on the CPU. Moving it to the GPU is left to the
    training loop, so the dataset also works with ``DataLoader`` worker
    processes, where creating CUDA tensors is not recommended.
    """

    def __init__(self, fp: str):
        """Load image paths and captions from ``<fp>/metadata.jsonl``.

        Args:
            fp: Folder that contains ``metadata.jsonl``.

        Raises:
            OSError: If the metadata file cannot be opened.
            KeyError: If a record lacks ``file_name`` or ``text``.
        """
        self.fp = fp
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(os.path.join(fp, 'metadata.jsonl'), 'r', encoding='utf-8') as file:
            metadata = read_metadata(file.read())

        # Both lists are built from the same records, so every image is
        # captioned by construction.
        self.img_paths = [record['file_name'] for record in metadata]
        self.captions = [record['text'] for record in metadata]

        # Apply required image transforms: RGB images with 256 x 256 dimensions.
        # NOTE: the attribute name keeps its original spelling so existing
        # users of the class are not broken.
        self.image_tranform = T.Compose([
            # A named function rather than a lambda, so the dataset can be
            # pickled when it has to be sent to worker processes.
            T.Lambda(self._to_rgb),
            T.Resize((256, 256)),
            T.ToTensor()
        ])

    @staticmethod
    def _to_rgb(img):
        """Return *img* unchanged if it is already RGB, else an RGB copy."""
        return img if img.mode == 'RGB' else img.convert('RGB')

    def __len__(self) -> int:
        """Return the number of image/caption pairs."""
        return len(self.img_paths)

    def __getitem__(self, idx: int):
        """Return ``(image_tensor, caption)`` for the sample at *idx*."""
        image_path = os.path.join(self.fp, self.img_paths[idx])

        # Image.open keeps the file open until the pixel data is read; the
        # context manager closes the handle once the tensor has been built.
        with Image.open(image_path) as image:
            image_pt = self.image_tranform(image)

        return image_pt, self.captions[idx]


# Parameters
# Every value below is actually consumed further down. `learning_rate` and
# `log_image_interval` are set to the values the trainer and the loop were
# effectively using before (3e-4 and 5000), so behaviour is unchanged.
image_size = 256  # Image dimension passed to the decoder
batch_size = 1  # Batch size for training, adjust based on GPU memory
learning_rate = 3e-4  # Learning rate for the optimizer
num_epochs = 50  # Number of epochs for training
loss_log_interval = 100  # Interval (in batches) for printing the loss
log_image_interval = 5000  # Interval (in batches) for logging images
save_dir = "./log_images"  # Directory to save log images
os.makedirs(save_dir, exist_ok=True)  # Create save directory if it doesn't exist

# Setup device
device = torch.device("cuda")  # Not recommended to train on cpu

# Define your image-text dataset
dataset = ImgTextDataset('path to folder with metadata.jsonl file')
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize OpenAI CLIP model adapter
clip = OpenAIClipAdapter()
# Create models for training
unet1 = Unet(
    dim=128,
    image_embed_dim=512,
    text_embed_dim=512,
    cond_dim=128,
    channels=3,
    dim_mults=(1, 2, 4, 8),
    cond_on_text_encodings=True,
).cuda()

decoder = Decoder(
    unet=unet1,
    image_size=image_size,
    clip=clip,
    timesteps=1000
).cuda()

decoder_trainer = DecoderTrainer(
    decoder,
    lr=learning_rate,
    wd=1e-2,
    ema_beta=0.99,
    ema_update_after_step=1000,
    ema_update_every=10,
).cuda()

# Use built-in tokenizer. You can use others like GPT2, YTTM etc.
t = SimpleTokenizer()

# Training loop.
# Iterate over the dataloader and pass image tensors and tokenized text to the training wrapper.
# Repeat the process for `num_epochs` epochs.

for epoch in range(num_epochs):
    for batch_idx, (images, texts) in enumerate(dataloader):
        # Move the batch to the GPU and tokenize the captions once; both are
        # reused by the periodic sampling step below.
        images = images.to(device)
        tokens = t.tokenize(texts).to(device)

        loss = decoder_trainer(
            images,
            text=tokens,
            unet_number=1,
            max_batch_size=4
        )
        decoder_trainer.update(1)

        if batch_idx % loss_log_interval == 0:
            print(f"epoch {epoch}, step {batch_idx}, loss {loss}")

        # Skip batch 0 so no sample is generated before any training happened.
        if batch_idx % log_image_interval == 0 and batch_idx != 0:
            # Element 0 of the value returned by embed_image is what gets
            # passed on as the image embedding.
            embedded = clip.embed_image(images)
            sample = decoder_trainer.sample(image_embed=embedded[0], text=tokens)
            save_image(sample, os.path.join(save_dir, f'{epoch}_{batch_idx}.png'))
    # Save the model at the end of every epoch.
    torch.save(decoder_trainer.state_dict(), f'model_{epoch}.pt')

@Felix-FN
Copy link

Felix-FN commented Apr 12, 2024

Hi @u1ug, do you also have a similar example for training a prior?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants