Seq2SeqTrainer with DataCollatorForCompletionOnlyLM: incorrect masking for evaluation #1634

Open
adamamer20 opened this issue May 8, 2024 · 6 comments

Comments

@adamamer20

adamamer20 commented May 8, 2024

In #632, there was a discussion on labels being included in the input_ids when using DataCollatorForCompletionOnlyLM.

The stated reason was to compute the loss correctly during training, which makes sense.

However, I was trying to use trainer.evaluate with predict_with_generate=True to compute a custom evaluation metric.

from transformers import EvalPrediction, Seq2SeqTrainer, Seq2SeqTrainingArguments
from trl import DataCollatorForCompletionOnlyLM

def evaluate_predictions(eval_preds: EvalPrediction):
    # custom metric function
    ...

collator = DataCollatorForCompletionOnlyLM(
    instruction_template="<|user|>",
    response_template="<|assistant|>",
    tokenizer=tokenizer,
)

training_args = Seq2SeqTrainingArguments(
    output_dir=training_output_path,
    evaluation_strategy="epoch",
    predict_with_generate=True,
    load_best_model_at_end=True,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds["train"],
    eval_dataset=train_ds["test"],
    tokenizer=tokenizer,
    data_collator=collator,
    compute_metrics=evaluate_predictions,
)

When generating the text to evaluate, the Trainer does not take into account that the labels are already present in the input text, so the true labels are visible to the model when generating the predictions.

A simple fix would be to change the attention_mask during generation so that the true labels are ignored in the forward pass.
This is how I solved it:

import torch

collator = DataCollatorForCompletionOnlyLM(
    instruction_template="<|user|>",
    response_template="<|assistant|>",
    tokenizer=tokenizer,
)

def apply_chat_template(example):
    example["input_ids"] = tokenizer.apply_chat_template(
        example["messages"], tokenize=True, add_generation_prompt=False
    )

    collated_data = collator([example["input_ids"]])

    example["input_ids"] = collated_data["input_ids"][0]
    example["labels"] = collated_data["labels"][0]

    return example

def mask_attention_response(example):
    # Number of tokens in the assistant response (the true labels)
    response_len = len(tokenizer(example["messages"][-1]["content"])["input_ids"])

    # Attend to everything except the response tokens at the end of the sequence
    example["attention_mask"] = torch.tensor(
        [1] * (len(example["input_ids"]) - response_len) + [0] * response_len
    )

    return example

train_ds = raw_ds.map(apply_chat_template).train_test_split(split, seed=SEED)

train_ds["test"] = train_ds["test"].map(
    mask_attention_response, desc="Masking assistant response"
)
@younesbelkada
Collaborator

Hi @adamamer20
Thanks for the detailed issue and the proposed fix. Just so I understand, where exactly does the fix need to be applied?

@adamamer20
Author

adamamer20 commented May 24, 2024

Hi @younesbelkada,
I think the most elegant solution, which would also address the issues in #862, would be to make SFTTrainer subclass Seq2SeqTrainer instead of the generic Trainer, and to allow a data collator for evaluation that differs from the one used in training. Have a look at the class I created here:

import torch
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainer


class FineTuningTrainer(Seq2SeqTrainer):

    def __init__(self, *args, eval_data_collator=None, **kwargs):
        super().__init__(*args, **kwargs)
        if not eval_data_collator:
            # Fall back to a Seq2Seq collator built from the trainer's tokenizer
            eval_data_collator = DataCollatorForSeq2Seq(tokenizer=self.tokenizer)
        self.eval_data_collator = eval_data_collator

    def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")

        # If we have persistent workers, don't do a fork bomb, especially as eval datasets
        # don't change during training
        if (
            hasattr(self, "_eval_dataloader")
            and self.args.dataloader_persistent_workers
        ):
            return self.accelerator.prepare(self._eval_dataloader)
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        # Use eval_data_collator here instead of the standard data_collator
        data_collator = self.eval_data_collator

        if isinstance(eval_dataset, Dataset):
            eval_dataset = self._remove_unused_columns(
                eval_dataset, description="evaluation"
            )
        else:
            data_collator = self._get_collator_with_removed_columns(
                data_collator, description="evaluation"
            )

        dataloader_params = {
            "batch_size": self.args.eval_batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
        }

        if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
            dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

        # accelerator.free_memory() will destroy the references, so
        # we need to store the non-prepared version
        eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
        if self.args.dataloader_persistent_workers:
            self._eval_dataloader = eval_dataloader

        return self.accelerator.prepare(eval_dataloader)


trainer = FineTuningTrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds["train"],
    eval_dataset=train_ds["validation"],
    tokenizer=tokenizer,
    data_collator=DataCollatorForCompletionOnlyLM(
        instruction_template="<|start_header_id|>user<|end_header_id|>",
        response_template="<|start_header_id|>assistant<|end_header_id|>",
        tokenizer=tokenizer,
    ),
    eval_data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
    compute_metrics=evaluate_predictions,
)

In this way, both the generation problem in #862 and the attention-masking problem here in #1634 would be solved. It is not clear to me, however, what the current implementation of SFTTrainer does differently from Seq2SeqTrainer; here they would be essentially the same, apart from the data_collator.

Note that I had to preprocess the train and eval datasets differently to make this work, substituting the true response with padding in the eval_dataset. This means that the computed evaluation loss is not correct:

from transformers import AutoTokenizer

def process_messages(
    example: dict,
    system_prompt: str,
    split: str,
):
    example["messages"] = [
        {
            "role": "system",
            "content": f"{system_prompt}",
        },
        {
            "role": "user",
            "content": f"{example['text']}",
        },
    ]

    # Only the training split keeps the true assistant response in the messages
    if split == "train":
        example["messages"].append(
            {"role": "assistant", "content": example["labels"]}
        )
    return example

def apply_chat_template(example: dict, tokenizer: AutoTokenizer, split: str):
    example["input_ids"] = tokenizer.apply_chat_template(
        example["messages"],
        tokenize=True,
        # For evaluation, end the input with the generation prompt instead of the response
        add_generation_prompt=(split != "train"),
    )

    if split != "train":
        # Left-pad the labels to the prompt length so they can be batched with the inputs
        tokenizer.padding_side = "left"

        example["labels"] = tokenizer(
            example["labels"],
            padding="max_length",
            max_length=len(example["input_ids"]),
        )["input_ids"]

        tokenizer.padding_side = "right"

    return example
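For context, here is a minimal sketch of how the two functions above could be wired to the splits, assuming raw_ds is a DatasetDict with "train" and "validation" splits and that system_prompt is defined elsewhere; the wiring itself is illustrative:

# Illustrative preprocessing of the two splits with different `split` arguments,
# so only the training split keeps the assistant response in its input_ids.
for split_name in ("train", "validation"):
    raw_ds[split_name] = (
        raw_ds[split_name]
        .map(process_messages, fn_kwargs={"system_prompt": system_prompt, "split": split_name})
        .map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer, "split": split_name})
    )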

One could also do without this preprocessing step by creating a custom data collator for evaluation, but I haven't written code for that one.
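For reference, a minimal sketch of what such an evaluation collator could look like, assuming the <|assistant|> response template from the snippets above. The class name CompletionOnlyEvalCollator, the _response_start helper, and the -100 label padding are illustrative assumptions, not an existing TRL API; the idea is to keep only the prompt (left-padded) for generation and the response tokens as labels:

import torch
from transformers import PreTrainedTokenizerBase

class CompletionOnlyEvalCollator:
    """Keeps only the prompt for generation and the response tokens as labels."""

    def __init__(self, tokenizer: PreTrainedTokenizerBase, response_template: str = "<|assistant|>"):
        self.tokenizer = tokenizer
        # Token ids marking the start of the assistant response
        self.response_ids = tokenizer.encode(response_template, add_special_tokens=False)

    def _response_start(self, input_ids):
        n = len(self.response_ids)
        for i in range(len(input_ids) - n + 1):
            if input_ids[i : i + n] == self.response_ids:
                return i + n  # first token after the response template
        return len(input_ids)  # no response found: treat everything as prompt

    def __call__(self, examples):
        prompts, labels = [], []
        for ex in examples:
            ids = list(ex["input_ids"])
            start = self._response_start(ids)
            prompts.append(ids[:start])
            labels.append(ids[start:])

        # Left-pad the prompts so generation continues right after the last real token
        self.tokenizer.padding_side = "left"
        batch = self.tokenizer.pad({"input_ids": prompts}, padding=True, return_tensors="pt")
        self.tokenizer.padding_side = "right"

        # Pad labels with -100 so they are ignored by the loss but usable in compute_metrics
        max_len = max(len(label) for label in labels)
        batch["labels"] = torch.tensor([label + [-100] * (max_len - len(label)) for label in labels])
        return batch

With a collator along these lines passed as eval_data_collator, the eval split could keep the full chat-templated input_ids, and the collator rather than the preprocessing step would decide what the model sees at generation time.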

@jacklanda

jacklanda commented May 26, 2024


I've tried to set the attention_mask before calling the generate method provided by transformers in the Seq2SeqTrainer instance, but how can I make sure there is no label leakage when using predict_with_generate?

It is not working for me even though I set the attention_mask manually to mask out the label portion of each sample.

@adamamer20
Author

adamamer20 commented May 28, 2024


Are you 100% sure there is leakage? Theoretically, setting the attention_mask to 0 should indeed ensure that those tokens are ignored by the attention mechanism, similar to how padding tokens are treated. This should prevent them from affecting the generated output during multiple forward passes. Note that the generated response will still contain the true labels.

@jacklanda

jacklanda commented May 28, 2024


I got the same predicted results after setting the correct attention_mask.
The predictions generated by the Llama-3 model I am training always follow the labels that were masked out.
Here is how I prepared each data point:

tokenized_input_sequence = tokenize(input_text)
tokenized_target_sequence = tokenize(target_text)
tokenized_full_sequence = tokenize(input_text + target_text)
if not train_on_input:
    # Do not compute loss on the input sequence
    tokenized_full_sequence["labels"] = [-100] * len(
        tokenized_input_sequence["input_ids"]
    ) + tokenized_target_sequence["input_ids"]

    # Do not leak labels to the model while predicting the response
    tokenized_full_sequence["attention_mask"] = [1] * len(
        tokenized_input_sequence["input_ids"]
    ) + [0] * len(tokenized_target_sequence["input_ids"])

@adamamer20
Author


Seems correct to me. If you can, post a minimal reproducible example and I'll take a look when I have time.
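For anyone putting together such a reproduction, a minimal check could look like the sketch below, assuming a causal LM model and its tokenizer are already loaded; the prompt text, the <|user|>/<|assistant|> markers, and max_new_tokens are illustrative. It simply compares generation from the prompt alone with generation from the prompt plus attention-masked label tokens; if the second output systematically follows the labels, the mask alone is not preventing leakage.

import torch

# Prompt and true label, tokenized separately (markers are illustrative)
prompt = tokenizer("<|user|>What is 2+2?<|assistant|>", return_tensors="pt")
label_ids = tokenizer("4", add_special_tokens=False, return_tensors="pt").input_ids

# 1) Generate from the prompt only
out_prompt = model.generate(
    input_ids=prompt.input_ids,
    attention_mask=prompt.attention_mask,
    max_new_tokens=20,
)

# 2) Generate from prompt + label tokens, with the label positions masked out
full_ids = torch.cat([prompt.input_ids, label_ids], dim=-1)
masked_attention = torch.cat(
    [prompt.attention_mask, torch.zeros_like(label_ids)], dim=-1
)
out_masked = model.generate(
    input_ids=full_ids,
    attention_mask=masked_attention,
    max_new_tokens=20,
)

print(tokenizer.decode(out_prompt[0], skip_special_tokens=True))
print(tokenizer.decode(out_masked[0], skip_special_tokens=True))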
