Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Reward Backpropogation Support #1585

Open
wants to merge 13 commits into
base: main
Choose a base branch
from

Conversation

mihirp1998
Copy link

@mihirp1998 mihirp1998 commented Apr 25, 2024

Hi,

I have added support for AlignProp (https://align-prop.github.io/) for finetuning Stable Diffusion model using reward gradients.

AlignProp directly backpropagate gradients from the reward model to the diffusion weights. Thus is about 25x more sample and compute efficient than policy gradient based methods like DDPO.

The current implementation seems to train effectively, almost within an hour on a single A100 while using Aesthetic reward model. Please find the attached loss and reward curves + some qualitative results after training.

huggingface/diffusers#7312

Difference between DDPO and AlignProp:

  • DDPO uses PPO, which is a policy gradient method for aligning diffusion models. AlignProp doesn't use policy gradients instead it directly backpropagates gradients from the reward function to diffusion denoising process, to maximize reward.

  • AlignProp can only work when the reward function is differentiable, DDPO on other hand can handle non-differentiable reward functions, as it never backpropagates gradients from the reward function weights.

  • As AlignProp takes benefit of the differentiability of the reward function as it backpropagates gradient. It is significantly more sample efficient than DDPO.

  • The loss function in AlignProp is simply the negative of the reward value outputed by the reward function, while in DDPO it's the PPO loss function.

  • As the reward function is sitting on the RGB images. AlignProp requires to do the full denoising chain from Noise to RGB during training, while DDPO can instead sample random denoising timesteps, similar to diffusion training.

  • DDPO and AlignProp both use LoRA and gradient checkpointing.

CC: @parthos86 @sayakpaul @lvwerra @younesbelkada

W B Chart 4_25_2024, 3_02_39 AM
W B Chart 4_25_2024, 3_02_45 AM

Image Generations post training:
Screenshot 2024-04-25 at 3 02 30 AM

@mihirp1998 mihirp1998 changed the title Added AlignProp Support Added Reward Backpropogation Support May 1, 2024
@huggingface huggingface deleted a comment from github-actions bot May 28, 2024
Copy link
Collaborator

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot @mihirp1998 for your hardwork ! In principle this looks good !
I just have few questions with respect to the differences between this method and DDPO, could you clearly highlight either on the documentation or in this PR what are the major differences between DDPO and this algorithm ? 🙏
I would also like to have a review from @sayakpaul if possible, what do you think of Reward Backpropagation ?
Thanks !

Comment on lines 5 to 9
| Before | After finetuning |
| --- | --- |
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_squirrel.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_squirrel.png"/></div> |
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_crab.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_crab.png"/></div> |
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_starfish.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_starfish.png"/></div> |
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These images are the ones generated from DDPO no?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, i wanted to update them, although I wasn't sure how to do it, as they linked to a huggingface internal webpage https://huggingface.co/datasets/trl-internal-testing/

If you can guide me on how to do it, i can update them.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can open a PR to https://huggingface.co/datasets/trl-internal-testing/ repository adding the resultant images you want.

library. A reason for stating this is that getting started requires a bit of familiarity with the `diffusers` library concepts, mainly two of them - pipelines and schedulers.
Right out of the box (`diffusers` library), there isn't a `Pipeline` nor a `Scheduler` instance that is suitable for finetuning with reinforcement learning. Some adjustments need to made.

There is a pipeline interface that is provided by this library that is required to be implemented to be used with the `DDPOTrainer`, which is the main machinery for fine-tuning Stable Diffusion with reinforcement learning. **Note: Only the StableDiffusion architecture is supported at this point.**
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here it references DDPO trainer

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing this. I have fixed this.

@mihirp1998
Copy link
Author

Thanks a lot @mihirp1998 for your hardwork ! In principle this looks good ! I just have few questions with respect to the differences between this method and DDPO, could you clearly highlight either on the documentation or in this PR what are the major differences between DDPO and this algorithm ? 🙏 I would also like to have a review from @sayakpaul if possible, what do you think of Reward Backpropagation ? Thanks !

I have added the differences in the pull request, let me know if u have some doubts or think something is missing.

@@ -0,0 +1,117 @@
# Aligning Text-to-Image Diffusion Models with Reward Backpropagation

## The why
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think as a reader I understand if the following table justifies the name of this section. Would you mind elaborating?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I added a better why statement.

| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_starfish.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_starfish.png"/></div> |


## Getting started with Stable Diffusion finetuning with reinforcement learning
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is needed. We should strive to keep the API documentation lean and precise.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes i removed it.

```python

import torch
from trl import DefaultDDPOStableDiffusionPipeline
Copy link
Member

@sayakpaul sayakpaul May 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we have to use a non-diffusers pipeline here? Does DiffusionPipeline from diffusers not work here?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes indeed, i changed it to StableDiffusionPipeline from diffusers

Comment on lines 97 to 104
pipeline = DefaultDDPOStableDiffusionPipeline("metric-space/alignprop-finetuned-sd-model")

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# memory optimization
pipeline.vae.to(device, torch.float16)
pipeline.text_encoder.to(device, torch.float16)
pipeline.unet.to(device, torch.float16)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These LoCs could be reduce if we do:

pipeline = DefaultDDPOStableDiffusionPipeline("metric-space/alignprop-finetuned-sd-model", torch_dtype=torch.float16)
pipeline = pipeline.to("cuda")

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additionally, https://huggingface.co/metric-space/alignprop-finetuned-sd-model is not available. Let's make sure we're using the right checkpoint ids here.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes i reduced it and fixed the checkpoint ids.

Comment on lines 62 to 84
class MLP(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(768, 1024),
nn.Dropout(0.2),
nn.Linear(1024, 128),
nn.Dropout(0.2),
nn.Linear(128, 64),
nn.Dropout(0.1),
nn.Linear(64, 16),
nn.Linear(16, 1),
)

def forward(self, embed):
return self.layers(embed)


class AestheticScorer(torch.nn.Module):
"""
This model attempts to predict the aesthetic score of an image. The aesthetic score
is a numerical approximation of how much a specific image is liked by humans on average.
This is from https://github.com/christophschuhmann/improved-aesthetic-predictor
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we copy-pasting these modules from the DDPO script?

@younesbelkada would it make sense to have a separate module for these (auxiliary_modules, perhaps)?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They are not exactly copy pasted, as DDPO had clamp and no_grad operations within them, which were preventing gradients from backpropagating.

Anyhow I still transfered the above reward function code from alignprop.py to trl/models/auxiliary_modules.py, as u suggested.

Comment on lines 715 to 722
if truncated_backprop:
if truncated_backprop_rand:
rand_timestep = random.randint(truncated_rand_backprop_minmax[0],truncated_rand_backprop_minmax[1])
if i < rand_timestep:
noise_pred = noise_pred.detach()
else:
if i < truncated_backprop_timestep:
noise_pred = noise_pred.detach()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We would want to supplement this code block with comments.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes added comments.

@@ -527,6 +527,243 @@ def pipeline_step(

return DDPOPipelineOutput(image, all_latents, all_log_probs)

def pipeline_step_with_grad(
self,
Copy link
Member

@sayakpaul sayakpaul May 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self, could be replaced with pipeline as that is what we're passing down the line, IIUC?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes i changed it to pipeline.


# {model_name}

This is a pipeline that finetunes a diffusion model with reward gradients. The model can be used for image generation conditioned with text.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what is the norm is within the library but I think it could be nice to also include a link to the AlignProp paper here.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes i added.

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your contributions. I left a couple of comments.

I would love to see some concrete comparisons to DDPO (training time, reward dynamics, convergence of the validation samples, etc.).

@mihirp1998
Copy link
Author

mihirp1998 commented Jun 2, 2024

Thanks for your contributions. I left a couple of comments.

I would love to see some concrete comparisons to DDPO (training time, reward dynamics, convergence of the validation samples, etc.).

I have made concrete comparisions with DDPO here. I ran the DDPO default code in TRL with batch size 128, while AlignProp also uses the same batch size. As can be seen AlignProp is significantly more sample efficient, here i train both the models for a few hours. Here x-axis is the epochs and y-axis is the reward achieved.

W B Chart 6_1_2024, 10_35_51 PM (1)

Although i ran the above experiments for a few hours, AlignProp only takes about 30 minutes to converge to a good solution. So i early stopped at the 8th epoch in training. Below is the comparision with DDPO after training both models for 30 minutes. Here x-axis is the training time and y-axis is the reward achieved.

W B Chart 6_1_2024, 10_35_51 PM

Both the curves are similar to the curves in the AlignProp paper.

The above curves were with the same set of prompts during training/testing. In the curve below i show AlignProp results on unseen prompts. As can be seen there is not much gap in results between seen/unseen prompts. Here dotted lines is the unseen prompts while solid line is the seen prompts.

W B Chart 6_1_2024, 10_34_58 PM

Finally here are some generated images from AlignProp, for seen/unseen animals after training.

media_images_llama_8_d86f465673c821a39d26 (1)
media_images_lion_8_9e9a947367e031acdceb (1)
squirrel

CC: @sayakpaul @younesbelkada

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants