Prototype Dataset Processor #1646

Open · wants to merge 11 commits into main

Conversation

vwxyzjn (Collaborator) commented on May 16, 2024

This PR attempts to refactor and pull all tokenization logic out of the Trainer class. Having a separate tokenization step gives us higher visibility into what is actually used in training, makes the logic clearer, and reduces bugs. It attempts to do the following things:

# 1. PPO (prompt)
# 2. SFT (prompt + demonstration), there is also packing.
# 3. ✅ RM / DPO (chosen and rejected)
# 4. ✅ Visualization of length distributions?
# 5. ✅ Filter?
#   * Smart truncation?
# 6. ✅ dataset_num_proc
# 7. check EOS token
# 8. dataset mixer?
# 9. ✅ pretty print that shows tokenization?
# 10. hashable tokenization?
# 11. inputs / labels / attention_mask
# 12. always set a `tokenizer.pad_token_id`?

Why?

Currently, the Trainer is also responsible for tokenization. This causes several issues:

  1. Duplicate tokenization steps. For example, alignment-handbook calls apply_chat_template(tokenize=False) on the dataset, and then the SFT/DPO trainer tokenizes it again. To remove the duplication, we only need to go through the dataset once by calling apply_chat_template(tokenize=True) (see the first sketch after this list).

  2. Truncation logic happens in various places and is hard to predict. SFTTrainer calls it max_seq_length, RewardTrainer calls it max_length, and the DPO/KTO trainers use max_length, max_prompt_length, and max_target_length. There are also different truncation behaviours, e.g. truncating the prompt if prompt + chosen is too long:

    # if combined sequence is too long, truncate the prompt
    for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
        if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
            ...

    This causes issues like https://huggingface.slack.com/archives/C04EX6W3QSY/p1715255460198239, as raised by @abhishekkrthakur.

    • The hard truncation logic seems debatable: if the sequence is too long, shouldn't we filter it out instead of returning a truncated response? A truncated response could be an incomplete code snippet or summary (basically bad data). If truncation is really desired, we should do some kind of smart truncation, e.g. truncate at the last paragraph boundary, so the sentences are still complete.
  3. Learning to generate EOS tokens. Learning to generate EOS tokens #1623 (comment) suggested that EOS tokens always 1) correspond to -100 in the labels, and 2) if the dataset contains the EOS token before collating, the attention mask of the EOS token is also 1. It's possible that the model may never learn to generate EOS tokens (see the second sketch after this list).

    • What's a bit unclear to me is how zephyr learns to output EOS tokens, despite all the EOS labels being marked with -100 and masked out. My suspicion is that attention_mask=1 plays some role in it.
  4. dataset_num_proc is not uniformly applied; as a result, [ORPO] Enable batched tokenization & multiprocessing to process large datasets #1624 is needed. There is also the question of hashable tokenization.

  5. The dataset mixer (e.g., the one in our h4 codebase) should be more widely available in TRL and could be combined with this class.
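
For point 1, a minimal sketch of what a single tokenization pass could look like, assuming a chat-formatted dataset with a messages column (the model and dataset names below are only illustrative):

from datasets import load_dataset
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")  # illustrative model
ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")  # illustrative dataset

# One pass over the data: render the chat template and tokenize in the same map call,
# instead of apply_chat_template(tokenize=False) followed by a second tokenization.
def tokenize_fn(example):
    return {"input_ids": tok.apply_chat_template(example["messages"], tokenize=True)}

ds = ds.map(tokenize_fn, num_proc=4)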
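
For point 3, a small illustrative sketch of the failure mode (not TRL's actual collator code); it assumes a causal-LM setup where label positions set to -100 are ignored by the loss:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # illustrative model
input_ids = tok("The answer is 42.")["input_ids"] + [tok.eos_token_id]

# If collation sets the EOS label to -100, the model gets no loss signal for
# producing EOS, even though attention_mask=1 keeps it visible as context.
labels = list(input_ids)
labels[-1] = -100
attention_mask = [1] * len(input_ids)

# One possible fix: keep the EOS label so the model is explicitly trained to stop.
labels[-1] = input_ids[-1]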

The current design

The current design roughly looks like this. Note that we can still put it in Trainer.__init__ so users don't have to configure it directly.

# configure tokenization/truncation once, outside the Trainer
dataset_config = DatasetConfig(max_token_length=1024, max_prompt_token_length=128)
dataset_processor = PreferenceDatasetProcessor(tokenizer=tok, config=dataset_config)

# tokenize, then inspect token-length statistics
train_dataset = dataset_processor.tokenize(preference_datasets["train"])
stats = dataset_processor.get_token_length_stats(train_dataset)
pprint.pp(stats)

# filter out over-long examples and re-check the statistics
train_dataset = dataset_processor.filter(train_dataset)
stats = dataset_processor.get_token_length_stats(train_dataset)
pprint.pp(stats)

# visualize the length distributions and an individual tokenized example
dataset_processor.get_token_length_visualization(train_dataset)
print(tok.decode(train_dataset[0]["chosen"]))
visualize_token(train_dataset[0]["chosen"], tok)
[screenshots: token length distribution visualization and pretty-printed tokenization of a "chosen" example]
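
To keep the user-facing API unchanged, here is a hypothetical sketch of the note above about hiding this behind Trainer.__init__ (the class and argument names are illustrative, not the actual TRL API):

class PreferenceTrainer:  # hypothetical wrapper, not an existing TRL class
    def __init__(self, model, tokenizer, train_dataset, dataset_config=None):
        # DatasetConfig / PreferenceDatasetProcessor are the classes from the example above
        dataset_config = dataset_config or DatasetConfig(max_token_length=1024)
        processor = PreferenceDatasetProcessor(tokenizer=tokenizer, config=dataset_config)
        # tokenize and drop over-long examples before training starts
        self.train_dataset = processor.filter(processor.tokenize(train_dataset))
        self.model = model
        self.tokenizer = tokenizer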

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

kashif (Collaborator) commented on May 16, 2024

very cool! thanks! checking

edbeeching (Collaborator)

> What's a bit unclear to me is how zephyr learns to output EOS tokens, despite all the EOS labels being marked with -100 and masked out. My suspicion is that attention_mask=1 plays some role in it.

I think for zephyr we used packing, and there is a concat token (= eos) that is not masked / ignored.
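
A rough sketch of the packing behaviour being described, with illustrative names (not the actual ConstantLengthDataset implementation): examples are concatenated with EOS as the separator and then sliced into fixed-length chunks, so the separator EOS keeps a normal label and is seen by the loss.

def pack(tokenized_examples, eos_token_id, seq_length):
    # concatenate all examples, inserting EOS between them as the "concat token"
    buffer = []
    for ids in tokenized_examples:
        buffer.extend(ids + [eos_token_id])
    # slice the stream into fixed-length chunks; labels are a copy of input_ids,
    # so the EOS separators are not set to -100 and the model can learn to emit them
    for start in range(0, len(buffer) - seq_length + 1, seq_length):
        chunk = buffer[start : start + seq_length]
        yield {"input_ids": chunk, "labels": list(chunk)}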

vwxyzjn (Collaborator, Author) commented on May 18, 2024

@edbeeching you are right, because the data collator is not called when using the packed dataset! The output below confirms it.

[screenshot: output confirming the above]

vwxyzjn marked this pull request as ready for review on June 21, 2024, 15:21