Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Unify Policy Trainers #1586

Draft
wants to merge 320 commits into
base: main
Choose a base branch
from
Draft

Conversation

lapp0
Copy link

@lapp0 lapp0 commented Apr 25, 2024

WIP: Unify Policy Trainers

Overview / Problem

Many trainers within trl follow the same paradigm:

    1. Create a base model and a frozen reference model
    1. Given a batch,
    • 2a) Generate using one or both models
    • 2b) Get logits
    • 2c) Use trainer-specific method to calculate loss
    • 2d) backpropagate, and log metrics

Trainers following this workflow include PPOTrainer, DPOTrainer, KTOTrainer, and the new RLOOTrainer (in PR).

Despite sharing these features, each trainer has repetitive and sometimes inconsistent implementations of core components including reference model management, generation of policy output, and even model saving.

This has resulted in a number of bugs, confusion, and unnecessary redundant work when implementing new policy trainers.

PolicyTrainerBase

The goal for this PR is to introduce an abstract PolicyTrainerBase with the RLOOTrainer adapted from #1540

The adapted RLOOTrainer only implements training_step() which is provided a batch of inputs, calculates loss, applies backprop, and logs metrics.

PolicyTrainerBase takes care of everything else, primarily preparation and management of the reference model, along with preparation of the generation config, and a utility function for generation of output sequences and logits.

I'll have to consider the generation function carefully, as that is one of the most complex components of the different policy trainers (see PPOTrainers implementation

def generate(
self,
query_tensor: Union[torch.Tensor, List[torch.Tensor]],
length_sampler: Optional[Callable] = None,
batch_size: int = 4,
return_prompt: bool = True,
generate_ref_response: bool = False,
**generation_kwargs,
):
"""
Generate response with the model given the query tensor.
call the `generate` method of the model.
Args:
query_tensor (`torch.LongTensor`):
A tensor of shape (`seq_len`) containing query tokens or a list of tensors of shape (`seq_len`).
length_sampler (`Callable`, *optional*):
Callable that returns the number of newly generated tokens.
batch_size (`int`, *optional):
Batch size used for generation, defaults to `4`.
return_prompt (`bool`, *optional*):
If set to `False` the prompt is not returned but only the newly generated tokens, defaults to `True`.
generate_ref_response (`bool`, *optional*):
If set to `True` the reference response is also generated, defaults to `False`.
generation_kwargs (dict[str, Any]):
Keyword arguments for generation.
Returns:
`torch.LongTensor`: A tensor of shape (`batch_size`, `gen_len`) containing response tokens.
"""
if generate_ref_response:
ref_model = self.model if self.is_peft_model else self.ref_model
if isinstance(query_tensor, List):
response = self._generate_batched(
self.model,
query_tensor,
length_sampler=length_sampler,
batch_size=batch_size,
return_prompt=return_prompt,
**generation_kwargs,
)
if generate_ref_response:
with self.optional_peft_ctx():
ref_response = self._generate_batched(
ref_model,
query_tensor,
length_sampler=length_sampler,
batch_size=batch_size,
return_prompt=return_prompt,
**generation_kwargs,
)
else:
if len(query_tensor.shape) == 2:
raise ValueError(
"query_tensor must be a tensor of shape (`seq_len`) or a list of tensors of shape (`seq_len`)"
)
if length_sampler is not None:
generation_kwargs["max_new_tokens"] = length_sampler()
response = self.accelerator.unwrap_model(self.model).generate(
input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs
)
if generate_ref_response:
with self.optional_peft_ctx():
ref_response = ref_model.generate(input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs)
if not return_prompt and not self.is_encoder_decoder:
response = response[:, query_tensor.shape[0] :]
if generate_ref_response:
ref_response = ref_response[:, query_tensor.shape[0] :]
if generate_ref_response:
return response, ref_response
return response
def _generate_batched(
self,
model: PreTrainedModelWrapper,
query_tensors: List[torch.Tensor],
length_sampler: Optional[Callable] = None,
batch_size: int = 4,
return_prompt: bool = True,
pad_to_multiple_of: Optional[int] = None,
remove_padding: bool = True,
**generation_kwargs,
):
outputs = []
padding_side_default = self.tokenizer.padding_side
if not self.is_encoder_decoder:
self.tokenizer.padding_side = "left"
# in case we have fewer examples than bs
batch_size = min(len(query_tensors), batch_size)
for i in range(0, len(query_tensors), batch_size):
if length_sampler is not None:
generation_kwargs["max_new_tokens"] = length_sampler()
# prevent overflow if query tensors are not even multiple of bs
end_index = min(len(query_tensors), i + batch_size)
batch = query_tensors[i:end_index]
batch_mask = [torch.ones_like(element) for element in batch]
inputs = {"input_ids": batch, "attention_mask": batch_mask}
padded_inputs = self.tokenizer.pad(
inputs,
padding=True,
max_length=None,
pad_to_multiple_of=pad_to_multiple_of,
return_tensors="pt",
).to(self.current_device)
generations = self.accelerator.unwrap_model(model).generate(**padded_inputs, **generation_kwargs)
for generation, mask in zip(generations, padded_inputs["attention_mask"]):
if not self.is_encoder_decoder:
output = generation[(1 - mask).sum() :] # remove padding
else:
output = generation
if not return_prompt and not self.is_encoder_decoder:
output = output[(mask).sum() :] # remove prompt
if remove_padding and self.tokenizer.eos_token_id in output:
pad_mask = output == self.tokenizer.eos_token_id
pad_start = torch.nonzero(pad_mask, as_tuple=False)[0, 0].item()
output = output[: pad_start + 1] # keep the eos token at the end
outputs.append(output)
self.tokenizer.padding_side = padding_side_default
return outputs
)

Remaining Work:

Copy link

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@lapp0
Copy link
Author

lapp0 commented Jun 1, 2024

Awaiting unslothai/unsloth#533

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants