
PPO / Reinforce Trainers #1540

Merged: 59 commits merged into huggingface:main on May 22, 2024
Conversation

@vwxyzjn (Collaborator) commented Apr 15, 2024

This PR adds support for the REINFORCE / RLOO trainers from https://arxiv.org/pdf/2402.14740.pdf.

Note that REINFORCE's loss is a special case of PPO's: if the completion is treated as a single action and the ratio is left unclipped, the PPO objective reduces to the REINFORCE loss presented in the Cohere paper (PPO uses the advantage estimate Â, while REINFORCE uses the RLHF reward R(y, x)).
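As a rough illustration of that relationship (my own variable names and simplification, not the code in this PR or the papers' exact equations):

import torch

def ppo_loss(new_logprobs, old_logprobs, advantages, cliprange=0.2):
    # PPO's clipped surrogate objective on the advantage estimates (A hat)
    ratio = torch.exp(new_logprobs - old_logprobs)
    losses = torch.max(-advantages * ratio,
                       -advantages * torch.clamp(ratio, 1 - cliprange, 1 + cliprange))
    return losses.mean()

def reinforce_loss(logprobs, rlhf_rewards):
    # special case: the whole completion is one action, no clipping,
    # and the advantage is replaced by the RLHF reward R(y, x)
    return -(rlhf_rewards * logprobs).mean()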

We add the following files:

  • trl/trainer/ppov2_trainer.py
  • trl/trainer/ppov2_bandit_rloo_trainer.py
    • a PPO variant that implements 1) modeling the completion as a single joint action and 2) the RLOO loss, which does not use a value network
    • I copied this file directly from ppov2_trainer.py, so feel free to do a file diff to see the changes (the key change is how the RLOO loss is computed; a rough sketch follows this list)
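For reference, a minimal sketch of the leave-one-out baseline behind the RLOO loss (my own paraphrase with illustrative names, not a copy of the trainer code):

import torch

def rloo_advantages(rlhf_rewards: torch.Tensor) -> torch.Tensor:
    # rlhf_rewards has shape (k, batch_size): k sampled completions per prompt
    k = rlhf_rewards.shape[0]
    # each sample's baseline is the mean reward of the other k - 1 samples
    # for the same prompt, so no value network is needed
    baseline = (rlhf_rewards.sum(0, keepdim=True) - rlhf_rewards) / (k - 1)
    return rlhf_rewards - baseline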

Two more examples show how they work with dummy reward models:

  • examples/scripts/minimal/ppo.py: a preliminary experiment shows the RLHF reward going up, so from an optimization standpoint it works as intended
  • examples/scripts/minimal/ppo_bandit_rloo.py: a preliminary experiment also shows the RLHF reward going up, so from an optimization standpoint it works as intended; the KL kind of exploded, though, so we may need a larger beta for stronger regularization (see the sketch after this list for where beta enters)
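For context, here is a rough sketch of where beta enters the per-sample RLHF reward in a typical KL-regularized setup (illustrative variable names, not the exact code in this PR):

# logprobs / ref_logprobs: per-token log-probs under the policy and the frozen reference model
# score: scalar reward-model score per completion (padding handling omitted)
kl = logprobs - ref_logprobs             # per-token KL estimate
non_score_reward = -beta * kl.sum(1)     # larger beta => stronger pull toward the reference model
rlhf_reward = score + non_score_reward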

@vwxyzjn requested a review from lewtun on April 15, 2024 16:06
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@lapp0 left a comment

Thanks so much for this contribution. I've been hoping to experiment with REINFORCE on transformers for a while now, but didn't have the time to roll my own implementation.

This is a great foundation in terms of functionality. I'll be playing around with it soon.

I think we should reduce the repetition, and use inheritance of existing classes so we can take advantage of the great infrastructure built out by huggingface/transformers and huggingface/trl.

Happy to help if you're interested in collaborating, let me know.

@lewtun (Member) left a comment

Epic work on adding these new RL trainers @vwxyzjn ! I've left some high level feedback on the RLOO trainer for now and will do a more fine-grained review when we've iterated a bit on the design.

Overall looks super clean

@vwxyzjn (Collaborator, Author) commented Apr 23, 2024

Hi @lapp0, thanks for the review! I will look into these comments more closely. I started running some experiments and noticed that the KL of RLOO was orders of magnitude higher than that of the new PPO trainer. I'm not sure of the reason yet, but I will investigate further.

The PPO / vanilla PG trainer actually seems quite stable now, with the RLHF reward going up and the model getting good scores and reasonable completions. Some implementation details proved particularly helpful, such as truncating at the EOS token (i.e., --truncate_token eos), and I suspect the same technique could work nicely for RLOO; a rough sketch of the idea follows.
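A minimal sketch of that truncation idea (an illustrative helper under my own assumptions, not the PR's exact implementation):

import torch

def truncate_at_eos(responses: torch.Tensor, eos_token_id: int, pad_token_id: int):
    # keep everything up to and including the first EOS; pad and mask out everything after it
    is_eos = responses == eos_token_id
    after_eos = (is_eos.long().cumsum(dim=1) - is_eos.long()) > 0  # True strictly after the first EOS
    truncated = responses.masked_fill(after_eos, pad_token_id)
    loss_mask = ~after_eos
    return truncated, loss_mask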

Right now I am focused on a Zephyr PPO / vanilla PG recipe for the next couple of days, and will look into RLOO right after.


@lapp0 left a comment

I have a WIP fork. The main difference is that it lets transformers.Trainer set everything up (batch creation, accelerate / deepspeed, etc.) and overrides training_step instead; a rough sketch of that pattern is at the end of this comment.

https://github.com/lapp0/trl/blob/onpolicy/trl/trainer/rloo_trainer.py

The main behavior difference is that it generates once per batch and runs for num_train_epochs rather than generating once per update and running for num_train_epochs * num_updates. Have you experimented with updating once per batch, and if so, does this harm stability? Is it important that I retain the ability to update once and run multiple epochs based on the model outputs from the start of the generation?

It's possible that generating once per batch instead of per update would improve KL now that you mention it.
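For anyone following along, the training_step-override pattern described above might look roughly like this (a sketch under my own assumptions, not the code in the linked fork; compute_policy_loss is a hypothetical helper):

import torch
from transformers import Trainer

class OnPolicyTrainer(Trainer):
    def training_step(self, model, inputs):
        model.train()
        queries = inputs["input_ids"]
        # sample completions once for this batch, then optimize against them
        unwrapped = self.accelerator.unwrap_model(model)
        with torch.no_grad():
            responses = unwrapped.generate(input_ids=queries, max_new_tokens=64, do_sample=True)
        # score the completions with a reward model and build the policy-gradient loss
        loss = self.compute_policy_loss(model, queries, responses)  # hypothetical helper
        self.accelerator.backward(loss)
        return loss.detach()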

@lapp0 mentioned this pull request on Apr 25, 2024
@vwxyzjn (Collaborator, Author) left a comment

@lapp0 @lewtun thanks so much for the review! I put down some comments and TODO items.

@vwxyzjn (Collaborator, Author) commented Apr 25, 2024

After some refactoring and bug fixes, the new RLOO also seems much more stable. I will report back when I have newer results.

@vwxyzjn (Collaborator, Author) commented May 8, 2024

Per Arash's suggestion, I also ran experiments on TL;DR summarization to show it works. There should be more results in https://wandb.ai/costa-huang/huggingface/reports/ppo-rloo-tldr--Vmlldzo3ODUzNDEx by tomorrow.

@younesbelkada (Collaborator) left a comment

Thanks a lot for this huge work! I left some comments. I think the new classes should also be exposed in TRL's main init; LMK what you think about my suggestions below 🙏

Comment on lines 123 to 156
def masked_mean(values, mask, axis=None):
    """Compute the mean of a tensor with masked values."""
    if axis is not None:
        return (values * mask).sum(axis=axis) / mask.sum(axis=axis)
    else:
        return (values * mask).sum() / mask.sum()


def masked_var(values, mask, unbiased=True):
    """Compute the variance of a tensor with masked values."""
    mean = masked_mean(values, mask)
    centered_values = values - mean
    variance = masked_mean(centered_values**2, mask)
    if unbiased:
        mask_sum = mask.sum()
        if mask_sum == 0:
            raise ValueError(
                "The sum of the mask is zero, which can happen when `mini_batch_size=1`; "
                "try increasing the `mini_batch_size` or `gradient_accumulation_steps`"
            )
        # note that if mask_sum == 1, there is a division-by-zero issue;
        # to avoid it, use a larger minibatch_size
        bessel_correction = mask_sum / (mask_sum - 1)
        variance = variance * bessel_correction
    return variance


def masked_whiten(values, mask, shift_mean=True):
    """Whiten values, respecting the mask."""
    mean, var = masked_mean(values, mask), masked_var(values, mask, False)
    whitened = (values - mean) * torch.rsqrt(var + 1e-8)
    if not shift_mean:
        whitened += mean
    return whitened

Those look the same as in trl/core.py (line 152 at 3b4c249):

def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[bool] = None) -> torch.Tensor:

Can't you re-use them from trl.core?

return whitened


def get_reward(model, query_responses, tokenizer, context_length):

can you move this method to trl.core?


def get_reward(model, query_responses, tokenizer, context_length):
    attention_mask = query_responses != tokenizer.pad_token_id
    # position_ids = attention_mask.cumsum(1) - attention_mask.long()  # exclusive cumsum

Suggested change (delete this commented-out line):
- # position_ids = attention_mask.cumsum(1) - attention_mask.long()  # exclusive cumsum


Should we move all PPO-related minimal scripts under a new ppo/ dir and the RLOO ones under an rloo/ dir? What do you think?

@lewtun (Member) left a comment

Thanks for iterating on this epic PR @vwxyzjn! Overall it looks quite close to finished; I think the main remaining points to address are splitting the configs off into their own modules and seeing whether we can hide config variables like world_size from the end user.

parser = HfArgumentParser((PPOConfig, ModelConfig))
config, model_config = parser.parse_args_into_dataclasses()
# remove output_dir if exists
shutil.rmtree(config.output_dir, ignore_errors=True)

FYI you can set overwrite_output_dir in PPOConfig (via TrainingArguments)


I gave it a quick test, but it does not seem to remove the output_dir.

A quick search suggests the removal logic is no longer there.

@@ -150,7 +150,7 @@ def unwrap_model_for_generation(
     if accelerator.state.deepspeed_plugin is not None and accelerator.state.deepspeed_plugin.zero_stage == 3:
         with deepspeed.zero.GatheredParameters(model.parameters()):
             remove_hooks(model)
-            yield model
+            yield accelerator.unwrap_model(model)

I believe models wrapped with the DeepSpeedEngine can still generate, so I'm curious why this is needed


Ah, it causes issues for RLOO; we get errors like "DeepSpeedEngine has no attribute generate", so we still need to unwrap it.

"""Whether to use deepspeed to train the model"""

# various batch sizes
world_size: Optional[int] = None

I see, but why do we only seem to need this for the RL trainers and not the other ones like SFTTrainer? In general, I'd like to avoid exposing this distributed stuff to the user if we can because it might not be clear if they should set the value manually or let accelerate handle it for them
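As an illustration of the alternative being suggested (a minimal sketch, assuming the trainer creates an accelerate.Accelerator internally; the config field name comes from the hunk above):

from accelerate import Accelerator

accelerator = Accelerator()
# derive the process count from the launcher instead of asking the user to set it
config.world_size = accelerator.num_processes  # `config` is the trainer's config dataclass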

@vwxyzjn (Collaborator, Author) commented May 15, 2024

Thank you @lewtun @younesbelkada @lapp0 for the review. I have addressed most of the concerns and also added some docs and benchmarks. Let me know if there is anything else needed :D

@younesbelkada (Collaborator) left a comment

Huge work! Thanks @vwxyzjn! Good for me to merge once @lewtun is happy with the latest changes and CI is green! 🚀

@lapp0 commented May 18, 2024

Great work @vwxyzjn! Really impressed with the vLLM integration along with the other components you've introduced here.

I'll be working on a follow-up PR for quantized training using ppo_v2 once Unsloth's numerical stability issue is resolved, and hopefully incorporate a few structural changes as well, so I don't have any further comments on structure right now.

Did any of your RLOO runs result in improved benchmarks, or at least improved score metrics? I was able to reproduce improving scores with PPOv2 in my refactor of your branch (with BnB / PEFT support), but I never managed to do the same with RLOO.

PPOv2 metrics: [chart]

@vwxyzjn (Collaborator, Author) commented May 21, 2024

@lapp0 Very nice to hear about your great results with PPOv2 and PEFT! I was able to get good results with a 1B RLOO model on TL;DR summarization. See https://moon-ci-docs.huggingface.co/docs/trl/pr_1540/en/rloo_trainer#benchmark-experiments.

@vwxyzjn merged commit 13454d2 into huggingface:main on May 22, 2024 (9 checks passed)
@lapp0 commented May 22, 2024

Great work @vwxyzjn, really exciting research and implementation you've put together. Feel free to ping me on any other PRs.
