This PR refactors and pulls all tokenization logic out of the Trainer class. Having a separate tokenization step gives us higher visibility into what is actually used in training, clarifies the logic, and reduces bugs. Concretely, it attempts to do the following.
Why?
Currently, the Trainer is also responsible for tokenization, which causes several issues:
- Duplicate tokenization steps. For example, alignment-handbook calls `apply_chat_template(tokenize=False)` on the dataset, and then the SFT/DPO trainer tokenizes it again. To remove the duplication, we only need to go through the dataset once by calling `apply_chat_template(tokenize=True)` (see the first sketch after this list).
- Truncation logic happens in various places and is hard to predict. `SFTTrainer` calls it `max_seq_length`, reward modeling calls it `max_length`, and the DPO/KTO trainers call it `max_length`, `max_prompt_length`, and `max_target_length`. The truncation logic itself also differs, e.g., truncate the prompt if prompt + chosen is too long (`trl/trainer/dpo_trainer.py`, lines 797 to 799 at 99f2c94); the second sketch after this list illustrates that rule.
- Learning to generate EOS tokens. #1623 (comment) suggested that EOS tokens always 1) correspond to -100 in the labels, and 2) only have an attention mask of 1 if the dataset already contains the EOS token before collating. It's possible that the model never learns to generate EOS tokens (see the third sketch after this list).
- `dataset_num_proc` is not uniformly applied; as a result, [ORPO] Enable batched tokenization & multiprocessing to process large datasets #1624 is needed (the first sketch after this list applies `num_proc` at the single tokenization site). There is also the question of hashable tokenization.
- The dataset mixer (e.g., in our h4 codebase) should be more widely available for use in TRL and could be combined with this class.
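A minimal sketch of the single-pass approach. The model and dataset names are just examples, and `tokenize_row` is a hypothetical helper, not TRL code:

```python
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
dataset = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")

def tokenize_row(example):
    # tokenize=True applies the chat template *and* tokenizes in one pass,
    # so the trainer no longer needs its own tokenization step.
    return {"input_ids": tokenizer.apply_chat_template(example["messages"], tokenize=True)}

# dataset_num_proc applied once, uniformly, at the single tokenization site.
dataset = dataset.map(tokenize_row, num_proc=4)
```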
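The second sketch is an illustrative reconstruction of that prompt-first truncation rule, reusing the DPO parameter names; it is not TRL's actual implementation:

```python
def truncate(prompt_ids: list, chosen_ids: list,
             max_length: int, max_prompt_length: int):
    # Only shorten the prompt if the combined sequence exceeds the budget,
    # keeping the most recent tokens of the prompt.
    if len(prompt_ids) + len(chosen_ids) > max_length:
        prompt_ids = prompt_ids[-max_prompt_length:]
    # If still too long, truncate the response from the right.
    if len(prompt_ids) + len(chosen_ids) > max_length:
        chosen_ids = chosen_ids[: max_length - len(prompt_ids)]
    return prompt_ids, chosen_ids
```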
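A third sketch demonstrates the EOS pitfall with stock `transformers` components, assuming the common `pad_token = eos_token` workaround for models that ship without a pad token:

```python
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # widespread workaround

collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
batch = collator([tokenizer("Hello world" + tokenizer.eos_token)])

# Because pad and EOS share an id, the collator rewrites the EOS label to -100,
# so the loss never rewards the model for emitting EOS.
print(batch["labels"][0][-1])  # tensor(-100)
```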
The current design
The current design roughly looks like this. Note that we can still put it in `Trainer.__init__` so users don't have to configure it directly.
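As a rough, hypothetical illustration of the shape such a component could take (the class name `DatasetProcessor`, its fields, and its behavior below are assumptions for illustration, not the API actually proposed in this PR):

```python
from dataclasses import dataclass
from datasets import Dataset
from transformers import PreTrainedTokenizerBase

@dataclass
class DatasetProcessor:
    tokenizer: PreTrainedTokenizerBase
    max_length: int = 1024       # one truncation knob instead of per-trainer names
    dataset_num_proc: int = 4    # applied uniformly to every map call

    def __call__(self, dataset: Dataset) -> Dataset:
        def tokenize(batch):
            # Template + tokenize in a single pass over the dataset.
            input_ids = [
                self.tokenizer.apply_chat_template(m, tokenize=True)[: self.max_length]
                for m in batch["messages"]
            ]
            return {"input_ids": input_ids}

        return dataset.map(
            tokenize,
            batched=True,
            num_proc=self.dataset_num_proc,
            remove_columns=dataset.column_names,
        )
```

A trainer could then build one of these inside `Trainer.__init__`, so existing users would not need to configure it directly.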