Ability to store generation logits and vals for training #1535

Closed
wants to merge 1 commit into from

Conversation

ejmejm
Contributor

@ejmejm ejmejm commented Apr 13, 2024

When using the PPO trainer, a user typically generates responses with PPOTrainer.generate(). The input queries and resulting responses are then passed to PPOTrainer.step() for training. At the start of training, the initial logits and values of these input sequences are computed again in PPOTrainer.step() (line 1129 in this branch). Most of this computation repeats what was already computed during the call to PPOTrainer.generate(), which leaves an opportunity to save the cost of one forward pass per training batch. This optimization was previously requested in ticket #848.

This change adds the ability to return the values and logits during generation, so that they can be fed back to PPOTrainer.step(), and reused to save on computation. Example usage:

response_tensors, values_and_logits = ppo_trainer.generate(
    query_tensors,
    return_prompt=False,
    return_values_and_logits=True,
    **generation_kwargs,
)

ppo_trainer.step(query_tensors, response_tensors, reward, values_and_logits=values_and_logits)

This is an easy addition for anyone who wants to save compute. The return_values_and_logits and values_and_logits arguments are optional, so existing code that calls these functions keeps working unchanged.
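The reuse pattern described above can be sketched without any ML library. The toy below is not TRL code: `ToyTrainer`, `_score`, and `extra_forward_passes` are hypothetical stand-ins, and only the argument names `return_values_and_logits` and `values_and_logits` come from this PR. It illustrates an optional cache that falls back to recomputing when nothing is passed in.

```python
from typing import List, Optional, Tuple

# (values, logits), one entry per sequence. Real tensors are replaced by floats.
Cache = Tuple[List[float], List[float]]


class ToyTrainer:
    """Hypothetical stand-in for a PPO trainer; counts redundant scoring passes."""

    def __init__(self) -> None:
        self.extra_forward_passes = 0

    def _score(self, queries: List[List[int]], responses: List[List[int]]) -> Cache:
        # Placeholder for the model forward pass that yields values and logits.
        values = [float(sum(q) + sum(r)) for q, r in zip(queries, responses)]
        logits = [float(len(q) + len(r)) for q, r in zip(queries, responses)]
        return values, logits

    def generate(self, queries: List[List[int]], return_values_and_logits: bool = False):
        # Toy "generation": the response is simply the reversed query.
        responses = [list(reversed(q)) for q in queries]
        if not return_values_and_logits:
            return responses
        # Assumption: values and logits are a by-product of generation.
        return responses, self._score(queries, responses)

    def step(self, queries, responses, rewards, values_and_logits: Optional[Cache] = None):
        if values_and_logits is None:
            # No cache supplied: pay for one more scoring pass.
            self.extra_forward_passes += 1
            values_and_logits = self._score(queries, responses)
        values, _logits = values_and_logits
        # Toy update signal: reward minus value estimate.
        return [r - v for r, v in zip(rewards, values)]


trainer = ToyTrainer()
queries = [[1, 2, 3], [4, 5]]
rewards = [10.0, 20.0]

# Old call pattern: step() has to rescore the sequences itself.
responses = trainer.generate(queries)
baseline = trainer.step(queries, responses, rewards)

# New call pattern: the cached values and logits are handed back to step().
responses, cached = trainer.generate(queries, return_values_and_logits=True)
reused = trainer.step(queries, responses, rewards, values_and_logits=cached)
```

Both call patterns produce the same result in this toy; only the first one increments the redundant-pass counter.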

@vwxyzjn
Collaborator

vwxyzjn commented Apr 17, 2024

@ejmejm I believe https://github.com/vwxyzjn/trl/blob/61e39010bd660fbddb9cff6a1f50d347ae375f9e/trl/trainer/ppov2_bandit_rloo_trainer.py#L461-L481 does what you are thinking? Ah, I also realized that the latest transformers release now supports output_logits; previously I had to use output_scores.
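For context on the two flags mentioned here: as I understand the transformers generation outputs, `output_scores` returns prediction scores after logits processors (temperature, top-k, and so on) have been applied, while `output_logits` returns the unprocessed ones. The sketch below is plain Python rather than library code; `log_softmax` and `apply_temperature` are hypothetical helpers, and the numbers are made up. It only illustrates why log-probabilities derived from processed scores can differ from those derived from raw logits.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    # Numerically stable log-softmax over a single token distribution.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def apply_temperature(logits: List[float], temperature: float) -> List[float]:
    # Stand-in for one logits processor: divide by the sampling temperature.
    return [x / temperature for x in logits]


raw_logits = [2.0, 1.0, 0.0]  # illustrative values only
processed_scores = apply_temperature(raw_logits, temperature=0.5)

raw_logprobs = log_softmax(raw_logits)
processed_logprobs = log_softmax(processed_scores)
```

With a temperature below 1 the processed distribution is sharper, so the top token gets a higher log-probability than it does under the raw logits.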


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@github-actions github-actions bot closed this May 22, 2024