How to train your inpainting model using my own dataset? #352

Open
dreamlychina opened this issue Jun 13, 2023 · 1 comment

Comments

@dreamlychina

Thanks for sharing this amazing work. I want to train your inpainting model on my own dataset. Could you share a training script and explain how to prepare the data?

@swayampragnya-malla

from imagen_pytorch import Unet, Imagen, ImagenTrainer
from imagen_pytorch.data import Dataset

output_path="/content/drive/MyDrive/imgen_pytorch/output"

# unets for unconditional imagen

unet = Unet(
    dim = 32,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 1,
    layer_attns = (False, False, False, True),
    layer_cross_attns = False
)

# imagen, which contains the unet above

imagen = Imagen(
    condition_on_text = False,  # this must be set to False for unconditional Imagen
    unets = unet,
    image_sizes = 256,
    timesteps = 1000
)

trainer = ImagenTrainer(
    imagen = imagen,
    split_valid_from_train = True  # whether to split the validation dataset from the training
).cuda()

# instantiate your dataloader, which returns the necessary inputs to the DDPM as a tuple
# in the order of images, text embeddings, then text masks.
# in this case, only images are returned, as this is unconditional training

dataset = Dataset('/content/drive/MyDrive/unconditional_generation/dataset_256', image_size = 256)

trainer.add_train_dataset(dataset, batch_size = 16)

# working training loop

for i in range(20000):
    loss = trainer.train_step(unet_number = 1, max_batch_size = 4)
    print(f'loss: {loss}')

    if not (i % 50):
        valid_loss = trainer.valid_step(unet_number = 1, max_batch_size = 4)
        print(f'valid loss: {valid_loss}')

    if not (i % 100) and trainer.is_main:  # is_main makes sure this can run in distributed
        images = trainer.sample(batch_size = 1, return_pil_images = True)  # returns List[Image]
        images[0].save(f'{output_path}/{i // 100}.png')
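To make the cadence of the loop above easier to see, here is a small standalone sketch in plain Python (no imagen_pytorch needed). The helper names `micro_batches` and `actions_at` are hypothetical and only for illustration. It assumes, as I understand the library's README, that `max_batch_size` splits each batch into smaller chunks whose gradients are accumulated so that the step fits in memory.

```python
import math

def micro_batches(batch_size, max_batch_size):
    """Number of chunks a batch is split into when each chunk holds at most max_batch_size items."""
    return math.ceil(batch_size / max_batch_size)

def actions_at(step, valid_every=50, sample_every=100):
    """List what the loop does at a given step, mirroring `not (i % 50)` and `not (i % 100)`."""
    actions = ["train"]
    if step % valid_every == 0:
        actions.append("validate")
    if step % sample_every == 0:
        # the sample is saved under the integer index step // sample_every
        actions.append(f"sample -> {step // sample_every}.png")
    return actions

print(micro_batches(16, 4))   # batch_size = 16, max_batch_size = 4
for step in (0, 1, 50, 150, 200):
    print(step, actions_at(step))
```

Note that step 0 triggers both validation and sampling, so the first saved image (0.png) comes from an essentially untrained model.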

This is the training code for unconditional training on a custom dataset.
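On the data preparation side, the `Dataset` class used above is pointed at a folder of images. Before starting a long run it can help to check what that folder actually contains. The sketch below is a hypothetical pre-flight helper (`find_training_images` is not part of imagen_pytorch), and the extension list is my assumption about the library's defaults, so adjust it to match your version.

```python
import tempfile
from pathlib import Path

# Assumption: these mirror the default extensions accepted by imagen_pytorch.data.Dataset.
ASSUMED_EXTS = ("jpg", "jpeg", "png", "tiff")

def find_training_images(folder, exts=ASSUMED_EXTS):
    """Recursively collect files under `folder` whose extension is in `exts` (hypothetical helper).

    Matching is exact, so an uppercase extension such as '.JPG' is not counted.
    """
    root = Path(folder)
    if not root.is_dir():
        raise FileNotFoundError(f"dataset folder not found: {root}")
    return sorted(
        p for p in root.rglob("*")
        if p.is_file() and p.suffix.lstrip(".") in exts
    )

# Small demonstration on a throwaway directory with two images and one non-image file.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "sub").mkdir()
    for name in ("a.png", "sub/b.jpg", "notes.txt"):
        (root / name).write_bytes(b"")
    found = find_training_images(root)
    names = [p.relative_to(root).as_posix() for p in found]

print(names)
```

If the list comes back empty or much shorter than expected, the folder path or file extensions are the first things to check before looking at the training code.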
