Added DataCollatorForMultiTurnCompletions #1592

Open · wants to merge 1 commit into main
Conversation

AswanthManoj

This PR adds a new data collator, DataCollatorForMultiTurnCompletions, to the trl library. It is designed for multi-turn completion tasks: in multi-turn conversation datasets, the user prompts are ignored and only the assistant completions contribute to the loss.

The existing DataCollatorForCompletionOnlyLM is designed for single-turn completion tasks, where the input contains a single prompt or instruction and the model is expected to generate a single response. It does not handle multi-turn conversational scenarios well, where the input sequence contains multiple user prompts and assistant responses interleaved.

The DataCollatorForMultiTurnCompletions class inherits from DataCollatorForLanguageModeling and extends its functionality to handle multi-turn conversations. It takes two additional arguments:

  1. user_template: A string or list of token IDs that indicate the start of a user prompt or input.
  2. assistant_template: A string or list of token IDs that indicate the start of an assistant response.

When a batch is collated, the collator identifies the assistant responses based on the provided templates (or token IDs) and masks out the user prompts in the labels so they are excluded from the loss calculation. This ensures the model is trained to generate appropriate responses to the user input, rather than learning to reproduce the entire conversation text.

The new data collator can be used as follows:

    from datasets import load_dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import SFTTrainer, DataCollatorForMultiTurnCompletions
    
    # Load and prepare your dataset
    dataset = ...
    
    model_id = "<your model id>"

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    
    # If the dataset's text field contains multi-turn conversations in ChatML format:
    user_template = "<|im_start|>user\n"
    assistant_template = "<|im_start|>assistant\n"

    collator = DataCollatorForMultiTurnCompletions(
        user_template=user_template,
        assistant_template=assistant_template,
        tokenizer=tokenizer,
        mlm=False,
    )

    trainer = SFTTrainer(
        model,
        train_dataset=dataset,
        dataset_text_field="text",
        data_collator=collator,
    )

    trainer.train()

@edbeeching
Collaborator

Hi. Thanks for this implementation; we have also been looking at this internally. Below I describe some of the problems with the current approach.

One challenge with this templating approach is the tokenizer's "edge effects" when mapping strings to IDs. For example, "<|im_start|>user\n" may not be mapped to the same sequence of IDs depending on the surrounding text. Some examples from the issue, with a different template:

chat_template = f"### Question: {question}\n ### Answer: {answer}"
response_template = " ### Answer:"
tokenizer.encode(" ### Answer:", add_special_tokens=False)    # [28705,     774, 26307, 28747] 
tokenizer.encode("\n ### Answer:", add_special_tokens=False)  # [28705, 13, 774, 26307, 28747]
tokenizer.encode(".\n ### Answer:", add_special_tokens=False) # [842,   13, 774, 26307, 28747]

I think the simplest approach is to use only special tokens, such as <|im_start|>, or to create dedicated <mask_start> and <mask_end> tokens that indicate where the masking should be applied.
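
For illustration, a minimal sketch of the second option, assuming a generic causal LM checkpoint (the <mask_start>/<mask_end> token names and the placeholder model id are hypothetical, not an existing trl convention):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "<your model id>"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Register dedicated masking markers as special tokens so each one always maps
# to a single, stable token ID regardless of the surrounding text.
tokenizer.add_special_tokens({"additional_special_tokens": ["<mask_start>", "<mask_end>"]})
model.resize_token_embeddings(len(tokenizer))

mask_start_id = tokenizer.convert_tokens_to_ids("<mask_start>")

# Special tokens are split out before subword tokenization, so they are not
# subject to the "edge effects" shown above.
assert tokenizer.encode("\n<mask_start>", add_special_tokens=False)[-1] == mask_start_id
assert tokenizer.encode(".<mask_start>", add_special_tokens=False)[-1] == mask_start_id

A collator could then search for these IDs directly in input_ids instead of re-tokenizing substrings.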

Let me know what you think and we can iterate on a solution.

@AswanthManoj
Author

AswanthManoj commented Apr 29, 2024

Hi @edbeeching, thank you for checking out the implementation. I have made a simple solution that works around the tokenizer's "edge effects" by splitting the prompt on the eos_token (or any special token that occurs in every turn). This lets us identify the assistant's turns and the other turns, so the mask can be applied appropriately.

import warnings
from typing import Any, Dict, List, Union

from transformers import DataCollatorForLanguageModeling


class DataCollatorForMultiTurnCompletions(DataCollatorForLanguageModeling):
    """
    Data collator for multi-turn conversational data. It masks every turn in the input sequences with the
    specified `ignore_index`, except the turns that contain the `response_template` string.

    Args:
        separator_token (Union[str, int]): The token used to separate different turns in the conversation.
        response_template (Union[str, List[int]]): The string or token IDs representing the response template.
            Responses matching this template will not be masked.
        mlm (bool, optional): Whether to use Masked Language Modeling or not. Defaults to False.
        ignore_index (int, optional): The index to use for masking responses. Defaults to -100.

    Example usage:
        >>> from transformers import AutoTokenizer
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
        >>> collator = DataCollatorForMultiTurnCompletions(
        ...     separator_token="<|eot_id|>",
        ...     response_template="<|start_header_id|>assistant<|end_header_id|>",
        ...     tokenizer=tokenizer,
        ... )
        >>> message = [
        ...     dict(role='user', content='hello there'),
        ...     dict(role='assistant', content="hi how are you"),
        ...     dict(role='user', content='i am great what about you?'),
        ...     dict(role='assistant', content='am also fine thank you'),
        ... ]
        >>> prompt = tokenizer.apply_chat_template(message, tokenize=True, add_special_tokens=False)
        >>> tokenizer.decode(prompt)
            <|begin_of_text|><|start_header_id|>user<|end_header_id|>

            hello there<|eot_id|><|start_header_id|>assistant<|end_header_id|>

            hi how are you<|eot_id|><|start_header_id|>user<|end_header_id|>

            i am great what about you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

            am also fine thank you<|eot_id|>
            
        >>> batch = collator([inputs])
        >>> batch["input_ids"]
        tensor([[128000, 128006, 882, 128007, 271, 15339, 1070, 128009, 
        128006, 78191, 128007, 271, 6151, 1268, 527, 499, 128009, 
        128006, 882, 128007, 271, 72, 1097, 2294, 1148, 922, 499, 30, 128009, 
        128006, 78191, 128007, 271, 309, 1101, 7060, 9901, 499, 128009]]), 
        
        >>> batch["labels"]
        tensor([[-100, -100, -100, -100, -100, -100, -100, -100, 
                128006, 78191, 128007, 271, 6151, 1268, 527, 499, 128009, 
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 
                128006, 78191, 128007, 271, 309, 1101, 7060, 9901, 499, 128009]])
    """
    def __init__(
        self,
        separator_token: Union[str, int],
        response_template: Union[str, List[int]],
        *args,
        mlm: bool = False,
        ignore_index: int = -100,
        **kwargs,
    ):
        super().__init__(*args, mlm=mlm, **kwargs)
        
        if isinstance(separator_token, str):
            self.separator_token = separator_token
            separator_token_ids = self.tokenizer.encode(separator_token, add_special_tokens=False)
            if len(separator_token_ids) != 1:
                raise ValueError("The separator token should be a single token.")
        else:
            self.separator_token = self.tokenizer.decode(separator_token)
            
        if isinstance(response_template, str):
            self.response_template = response_template
            self.response_template_ids = self.tokenizer.encode(response_template, add_special_tokens=False)
        else:
            self.response_template = self.tokenizer.decode(response_template)
            self.response_template_ids = response_template
        self.ignore_index = ignore_index

    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        batch = super().torch_call(examples)
        input_ids = batch["input_ids"]
        labels = []

        for i in range(input_ids.size(0)):
            # Decode the full sequence and split it into turns on the separator token.
            decoded_input = self.tokenizer.decode(input_ids[i])
            splits = decoded_input.split(self.separator_token)
            instance_labels = []
            separator_token_is_eos_token = self.separator_token == self.tokenizer.eos_token
            for j, split in enumerate(splits):
                if split:
                    # Re-attach the separator so its token is accounted for when re-encoding the turn.
                    if separator_token_is_eos_token:
                        split = split + self.separator_token
                    else:
                        split = self.separator_token + split
                    if self.response_template in split:
                        # Assistant turn: keep its token IDs as labels.
                        instance_labels.extend(self.tokenizer.encode(split, add_special_tokens=False))
                    else:
                        # Any other turn: mask every token with ignore_index.
                        instance_labels.extend(len(self.tokenizer.encode(split, add_special_tokens=False)) * [self.ignore_index])

            labels.append(instance_labels)

        if not any(self.response_template in self.tokenizer.decode(input_ids[i]) for i in range(input_ids.size(0))):
            warnings.warn(
                f"Could not find response template `{self.response_template}` in any instance. "
                f"These instances will be ignored in loss calculation. "
                f"Note, if this happens often, consider increasing the `max_seq_length`."
            )

        if len(input_ids) != len(labels):
            raise ValueError(
                "The lengths of input_ids and labels do not match after processing. "
                "This should not happen and may indicate a bug in the DataCollator."
            )

        return {"input_ids": input_ids, "labels": labels}

This modified implementation works by splitting the input sequences on the provided separator_token, which separates the different turns of the conversation. For each split (see the sketch after this list):

  • If it contains the response_template, the entire split is included in the labels without any masking.
  • If it does not contain the response_template, the split is masked by replacing all tokens with the specified ignore_index.
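
As a toy illustration of this split-and-mask logic (plain Python string handling only; the chat markers are the Llama 3-style ones from the docstring example above):

separator_token = "<|eot_id|>"
response_template = "<|start_header_id|>assistant<|end_header_id|>"

decoded_input = (
    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nhello there<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\nhi how are you<|eot_id|>"
)

# Each split corresponds to one turn: assistant turns are kept, everything else is masked.
for split in decoded_input.split(separator_token):
    if split:
        action = "keep" if response_template in split else "mask"
        print(f"{action}: {split!r}")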

Please check this out and do let me know.

@hahuyhoang411

This is cool. Thanks @AswanthManoj

For the example, I think you have the wrong variable:
batch = collator([inputs]) -> batch = collator([prompt])


github-actions bot commented Jun 6, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
