PPO / Reinforce Trainers #1540
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks so much for this contribution. I've been hoping to experiment with REINFORCE on transformers for a while now, but didn't have the time to roll my own implementation.
This is a great foundation in terms of functionality. I'll be playing around with it soon.
I think we should reduce the repetition, and use inheritance of existing classes so we can take advantage of the great infrastructure built out by huggingface/transformers and huggingface/trl.
Happy to help if you're interested in collaborating, let me know.
Epic work on adding these new RL trainers @vwxyzjn ! I've left some high level feedback on the RLOO trainer for now and will do a more fine-grained review when we've iterated a bit on the design.
Overall looks super clean
Hi @lapp0, thanks for the review! I will look into these comments more closely. I started running some experiments and noticed the KL of RLOO was orders of magnitude higher than that of the new PPO trainer. I'm not exactly sure of the reason but will investigate further. The PPO / vanilla PG trainer actually seems quite stable now, with the RLHF reward going up and the model getting good scores and reasonable completions. There were some implementation details I found particularly helpful, such as truncating at the EOS token. Right now I am a bit focused on a Zephyr PPO / vanilla PG recipe for these couple of days, and will look into RLOO right after.
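The EOS-truncation detail mentioned above can be sketched roughly like this (a minimal illustration only, not the PR's actual implementation; `truncate_at_eos` is a made-up helper name):

```python
def truncate_at_eos(tokens, eos_id):
    """Drop everything sampled after the first EOS token."""
    if eos_id in tokens:
        return tokens[: tokens.index(eos_id) + 1]
    return tokens

print(truncate_at_eos([5, 8, 2, 9, 2], eos_id=2))  # [5, 8, 2]
```

This keeps the reward model from scoring garbage tokens sampled after the model has already finished its completion.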
I have a WIP fork. The main difference is that it lets `transformers.Trainer` set everything up (batch creation, accelerate / deepspeed, etc.) and instead overrides `training_step`.

https://github.com/lapp0/trl/blob/onpolicy/trl/trainer/rloo_trainer.py

The main behavioral difference is that it generates once per batch and runs for `num_train_epochs`, rather than generating once per update and running for `num_train_epochs * num_updates`. Have you experimented with updating once per batch, and if so, does this harm stability? Is it important that I retain the ability to update once and run multiple epochs based on the model outputs from the start of the generation?
It's possible that generating once per batch instead of per update would improve KL now that you mention it.
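The two generation schedules being compared can be sketched abstractly as follows (a hypothetical illustration with placeholder `generate` / `train_step` functions, not code from either branch):

```python
def generate(batch):
    # stand-in for sampling completions from the policy
    return [f"completion-for-{x}" for x in batch]

def train_step(batch, completions):
    # stand-in for one optimizer step; returns a dummy "loss"
    return len(completions)

batches = [[1, 2], [3, 4]]
num_train_epochs = 2
num_updates = 3

# (a) generate once per update, then reuse those completions for
# num_train_epochs * num_updates gradient steps (the PR's style)
for batch in batches:
    completions = generate(batch)
    for _ in range(num_train_epochs * num_updates):
        train_step(batch, completions)

# (b) generate fresh completions once per batch, one step each,
# looping over the data for num_train_epochs (the fork's Trainer style)
for _ in range(num_train_epochs):
    for batch in batches:
        completions = generate(batch)
        train_step(batch, completions)
```

Schedule (b) keeps the training data closer to on-policy, which is one plausible reason it could reduce KL drift.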
To show it works, per Arash's suggestion, I also ran experiments on TL;DR. Should have more results in https://wandb.ai/costa-huang/huggingface/reports/ppo-rloo-tldr--Vmlldzo3ODUzNDEx tomorrow.
Thanks a lot for this huge work! I left some comments! I think the new classes should also be exposed in TRL's main init. LMK wdyt about my suggestions below 🙏
trl/trainer/ppov2_trainer.py
Outdated
```python
def masked_mean(values, mask, axis=None):
    """Compute the mean of a tensor over masked values."""
    if axis is not None:
        return (values * mask).sum(axis=axis) / mask.sum(axis=axis)
    else:
        return (values * mask).sum() / mask.sum()


def masked_var(values, mask, unbiased=True):
    """Compute the variance of a tensor over masked values."""
    mean = masked_mean(values, mask)
    centered_values = values - mean
    variance = masked_mean(centered_values**2, mask)
    if unbiased:
        mask_sum = mask.sum()
        if mask_sum == 0:
            raise ValueError(
                "The sum of the mask is zero, which can happen when `mini_batch_size=1`; "
                "try increasing the `mini_batch_size` or `gradient_accumulation_steps`"
            )
        # note that if mask_sum == 1, there is a division-by-zero issue;
        # to avoid it you just need to use a larger mini_batch_size
        bessel_correction = mask_sum / (mask_sum - 1)
        variance = variance * bessel_correction
    return variance


def masked_whiten(values, mask, shift_mean=True):
    """Whiten values over masked values."""
    mean, var = masked_mean(values, mask), masked_var(values, mask, False)
    whitened = (values - mean) * torch.rsqrt(var + 1e-8)
    if not shift_mean:
        whitened += mean
    return whitened
```
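As a quick numerical sanity check of the masked statistics above, here is the same math re-implemented on plain Python lists (a toy illustration only; the `_list` names are made up and there is no torch dependency):

```python
def masked_mean_list(values, mask):
    """Mean over positions where mask == 1."""
    return sum(v * m for v, m in zip(values, mask)) / sum(mask)

def masked_var_list(values, mask, unbiased=True):
    """Variance over positions where mask == 1, with optional Bessel correction."""
    mean = masked_mean_list(values, mask)
    var = masked_mean_list([(v - mean) ** 2 for v in values], mask)
    if unbiased:
        n = sum(mask)
        var *= n / (n - 1)  # Bessel's correction, as in the trainer helper
    return var

# mask drops the last value, so the statistics are those of [1, 2, 3]
print(masked_mean_list([1.0, 2.0, 3.0, 4.0], [1, 1, 1, 0]))  # 2.0
print(masked_var_list([1.0, 2.0, 3.0, 4.0], [1, 1, 1, 0]))   # 1.0
```

The unbiased variance of [1, 2, 3] is indeed 1.0, which is what makes the Bessel correction (and the `mask_sum == 1` edge case it introduces) worth the extra branch.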
Those look the same as in:

Line 152 in 3b4c249

```python
def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[bool] = None) -> torch.Tensor:
```
trl/trainer/ppov2_trainer.py
Outdated
```python
def get_reward(model, query_responses, tokenizer, context_length):
```
Can you move this method to `trl.core`?
trl/trainer/ppov2_trainer.py
Outdated
```python
def get_reward(model, query_responses, tokenizer, context_length):
    attention_mask = query_responses != tokenizer.pad_token_id
    # position_ids = attention_mask.cumsum(1) - attention_mask.long()  # exclusive cumsum
```
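For context on the commented-out "exclusive cumsum" trick: it produces position ids that stay at 0 across left padding and only start counting at the first real token. A toy plain-Python version (illustrative only, mirroring `attention_mask.cumsum(1) - attention_mask.long()` on a single row):

```python
# attention mask for a left-padded sequence: two pad tokens, then three real ones
mask = [0, 0, 1, 1, 1]

# inclusive cumulative sum of the mask
inclusive = []
total = 0
for m in mask:
    total += m
    inclusive.append(total)

# exclusive cumsum = inclusive cumsum - mask
position_ids = [c - m for c, m in zip(inclusive, mask)]
print(position_ids)  # [0, 0, 0, 1, 2]
```

The real tokens get positions 0, 1, 2 regardless of how much left padding precedes them.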
examples/scripts/minimal/ppo.py
Outdated
Should we move all ppo-related minimal scripts under a new `ppo/` dir and the rloo ones under an `rloo/` dir? What do you think?
Thanks for iterating on this epic PR @vwxyzjn! Overall it's looking quite close to being finished, and I think the main remaining points to address are splitting off the configs into their own modules and seeing if we can hide config variables like `world_size` from the end user.
examples/scripts/minimal/ppo.py
Outdated
```python
parser = HfArgumentParser((PPOConfig, ModelConfig))
config, model_config = parser.parse_args_into_dataclasses()
# remove output_dir if exists
shutil.rmtree(config.output_dir, ignore_errors=True)
```
FYI you can set `overwrite_output_dir` in `PPOConfig` (via `TrainingArguments`)
```diff
@@ -150,7 +150,7 @@ def unwrap_model_for_generation(
     if accelerator.state.deepspeed_plugin is not None and accelerator.state.deepspeed_plugin.zero_stage == 3:
         with deepspeed.zero.GatheredParameters(model.parameters()):
             remove_hooks(model)
-            yield model
+            yield accelerator.unwrap_model(model)
```
I believe models wrapped with the `DeepSpeedEngine` can still generate, so I'm curious why this is needed.
Ah, it causes issues for RLOO, with errors like "DeepSpeedEngine has no attribute `generate`", so we still need to unwrap it.
trl/trainer/ppov2_trainer.py
Outdated
```python
    """Whether to use deepspeed to train the model"""

    # various batch sizes
    world_size: Optional[int] = None
```
I see, but why do we only seem to need this for the RL trainers and not the other ones like `SFTTrainer`? In general, I'd like to avoid exposing this distributed stuff to the user if we can, because it might not be clear whether they should set the value manually or let `accelerate` handle it for them.
Thank you @lewtun @younesbelkada @lapp0 for the review. I have addressed most of the concerns and also added some docs and benchmarks. Let me know if there is anything else needed :D
Great work @vwxyzjn! Really impressed with the vLLM integration along with the other components you've introduced here. I'll be working on a follow-up PR for quantized training using ppo_v2 once Unsloth's numerical stability issue is resolved, and hopefully incorporate a few structural changes as well, so I don't have any further comments on structure right now. Did any of your RLOO runs result in improved benchmarks, or at least improved score metrics? I was able to reproduce improving scores with ppov2 in my refactor of your branch with BnB / peft support, but I never managed to do the same with RLOO. PPOv2 metrics:
@lapp0 Very nice to hear your great results with PPOv2 and peft! I was able to get good results with 1B RLOO on TL;DR summarization. See https://moon-ci-docs.huggingface.co/docs/trl/pr_1540/en/rloo_trainer#benchmark-experiments.
Great work @vwxyzjn, really exciting research and implementation you have put together. Feel free to ping me on any other PRs.
This PR supports the REINFORCE / RLOO trainers from https://arxiv.org/pdf/2402.14740.pdf.

Note that REINFORCE's loss is a special case of PPO: it matches the REINFORCE loss presented in the Cohere paper (where PPO uses the advantage estimate Â, but REINFORCE uses the full RLHF reward R(y, x)).
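This special-case relationship can be checked numerically with a toy finite-difference test on scalars (all function names here are made up for illustration): at the first update the probability ratio is 1, so the clipped PPO surrogate's gradient equals the REINFORCE gradient whenever the advantage equals the reward.

```python
import math

def ppo_loss(logprob, old_logprob, advantage, eps=0.2):
    # clipped PPO surrogate for a single action
    ratio = math.exp(logprob - old_logprob)
    clipped_ratio = max(min(ratio, 1 + eps), 1 - eps)
    return -min(ratio * advantage, clipped_ratio * advantage)

def reinforce_loss(logprob, reward):
    return -logprob * reward

def grad(f, x, h=1e-6):
    # central finite difference
    return (f(x + h) - f(x - h)) / (2 * h)

# old_logprob == logprob at the first update, so ratio == 1
lp, R = -1.5, 0.7
g_ppo = grad(lambda x: ppo_loss(x, lp, R), lp)
g_reinforce = grad(lambda x: reinforce_loss(x, R), lp)
print(round(g_ppo, 4), round(g_reinforce, 4))  # -0.7 -0.7
```

Both gradients come out to -R, which is exactly the REINFORCE policy-gradient direction.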
We add the following files (among them `ppov2_trainer.py`), so feel free to do a file diff to see the changes (e.g., the diff shows how the RLOO loss is implemented). Two more examples show how the trainers work with dummy reward models.
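For reference, the leave-one-out baseline behind RLOO is simple to state: with k completions sampled per prompt, each sample's advantage is its reward minus the mean reward of the other k - 1 samples. A toy sketch (illustrative only; `rloo_advantages` is a made-up name, not the trainer's actual function):

```python
def rloo_advantages(rewards):
    """Leave-one-out advantages for k rewards from the same prompt."""
    k = len(rewards)
    total = sum(rewards)
    # baseline for sample i is the mean of the other k - 1 rewards
    return [r - (total - r) / (k - 1) for r in rewards]

print(rloo_advantages([1.0, 2.0, 3.0]))  # [-1.5, 0.0, 1.5]
```

The baseline is computed from the policy's own samples, so no learned value network is needed, which is what makes RLOO lighter than PPO.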