Added DataCollatorForMultiTurnCompletions #1592
base: main
Conversation
Hi. Thanks for this implementation; we have also been looking at this internally. Here I describe some of the problems with the current implementation. One challenge with this templating approach is the tokenizer's "edge effects" when mapping strings to IDs. For example
I think the simplest approach is to use only special tokens such as Let me know what you think and we can iterate on a solution.
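The comment above is cut off, but the "edge effect" it refers to can be illustrated with a toy greedy longest-match tokenizer (a stand-in for real BPE, which shows the same behavior at merge boundaries): the IDs produced for a template string tokenized in isolation need not appear as a contiguous subsequence of the IDs produced when the same string appears in context.

```python
# Toy longest-match tokenizer over a hand-picked vocabulary. This is an
# illustrative stand-in, not a real BPE tokenizer.
VOCAB = {"hello wor": 0, "hello": 1, " wor": 2, "wor": 3, "ld": 4, " ": 5}

def encode(text):
    """Greedily consume the longest vocabulary piece at each position."""
    ids = []
    while text:
        for piece in sorted(VOCAB, key=len, reverse=True):
            if text.startswith(piece):
                ids.append(VOCAB[piece])
                text = text[len(piece):]
                break
        else:
            raise ValueError(f"cannot tokenize: {text!r}")
    return ids

def contains(seq, sub):
    """True if `sub` occurs as a contiguous subsequence of `seq`."""
    return any(seq[i:i + len(sub)] == sub for i in range(len(seq) - len(sub) + 1))

in_context = encode("hello world")  # merges across the word boundary: [0, 4]
standalone = encode("world")        # tokenized alone: [3, 4]
```

Here a naive search for the standalone IDs inside the in-context IDs fails, even though the string "world" is plainly present, which is why string-template matching on token IDs is fragile.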
Hi there @edbeeching, thank you for checking out the implementation. I have made a simple solution that accounts for the tokenizer's "edge effects" by splitting the prompt on the eos_token (or any special token that occurs in every turn). This lets us identify the assistant's turns and apply the mask appropriately.

```python
import warnings
from typing import Any, Dict, List, Union

from transformers import DataCollatorForLanguageModeling


class DataCollatorForMultiTurnCompletions(DataCollatorForLanguageModeling):
    """
    Data collator for multi-turn conversational data. It masks every turn in the input
    sequences with the specified `ignore_index`, except for the turns in which the
    `response_template` string is present.

    Args:
        separator_token (`Union[str, int]`): The token used to separate different turns in the conversation.
        response_template (`Union[str, List[int]]`): The string or token IDs representing the response template.
            Turns matching this template will not be masked.
        mlm (`bool`, *optional*): Whether to use masked language modeling. Defaults to `False`.
        ignore_index (`int`, *optional*): The index to use for masking responses. Defaults to `-100`.

    Example usage:

    >>> from transformers import AutoTokenizer
    >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
    >>> collator = DataCollatorForMultiTurnCompletions(
    ...     separator_token="<|eot_id|>",
    ...     response_template="<|start_header_id|>assistant<|end_header_id|>",
    ...     tokenizer=tokenizer,
    ... )
    >>> message = [
    ...     dict(role='user', content='hello there'),
    ...     dict(role='assistant', content="hi how are you"),
    ...     dict(role='user', content='i am great what about you?'),
    ...     dict(role='assistant', content='am also fine thank you'),
    ... ]
    >>> prompt = tokenizer.apply_chat_template(message, tokenize=True, add_special_tokens=False)
    >>> tokenizer.decode(prompt)
    <|begin_of_text|><|start_header_id|>user<|end_header_id|>
    hello there<|eot_id|><|start_header_id|>assistant<|end_header_id|>
    hi how are you<|eot_id|><|start_header_id|>user<|end_header_id|>
    i am great what about you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
    am also fine thank you<|eot_id|>
    >>> batch = collator([prompt])
    >>> batch["input_ids"]
    tensor([[128000, 128006, 882, 128007, 271, 15339, 1070, 128009,
             128006, 78191, 128007, 271, 6151, 1268, 527, 499, 128009,
             128006, 882, 128007, 271, 72, 1097, 2294, 1148, 922, 499, 30, 128009,
             128006, 78191, 128007, 271, 309, 1101, 7060, 9901, 499, 128009]])
    >>> batch["labels"]
    tensor([[-100, -100, -100, -100, -100, -100, -100, -100,
             128006, 78191, 128007, 271, 6151, 1268, 527, 499, 128009,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             128006, 78191, 128007, 271, 309, 1101, 7060, 9901, 499, 128009]])
    """

    def __init__(
        self,
        separator_token: Union[str, int],
        response_template: Union[str, List[int]],
        *args,
        mlm: bool = False,
        ignore_index: int = -100,
        **kwargs,
    ):
        super().__init__(*args, mlm=mlm, **kwargs)
        if isinstance(separator_token, str):
            self.separator_token = separator_token
            separator_token_ids = self.tokenizer.encode(separator_token, add_special_tokens=False)
            if len(separator_token_ids) != 1:
                raise ValueError("The separator token should be a single token.")
        else:
            self.separator_token = self.tokenizer.decode(separator_token)
        if isinstance(response_template, str):
            self.response_template = response_template
            self.response_template_ids = self.tokenizer.encode(response_template, add_special_tokens=False)
        else:
            self.response_template = self.tokenizer.decode(response_template)
            self.response_template_ids = response_template
        self.ignore_index = ignore_index

    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        batch = super().torch_call(examples)
        input_ids = batch["input_ids"]
        labels = []
        for i in range(input_ids.size(0)):
            decoded_input = self.tokenizer.decode(input_ids[i])
            splits = decoded_input.split(self.separator_token)
            instance_labels = []
            separator_token_is_eos_token = self.separator_token == self.tokenizer.eos_token
            for j, split in enumerate(splits):
                if split:
                    # Re-attach the separator so token counts match the original sequence.
                    if separator_token_is_eos_token:
                        split = split + self.separator_token
                    else:
                        split = self.separator_token + split
                    if self.response_template in split:
                        # Assistant turn: keep the token IDs as labels.
                        instance_labels.extend(self.tokenizer.encode(split, add_special_tokens=False))
                    else:
                        # Non-assistant turn: mask every token with ignore_index.
                        instance_labels.extend(
                            len(self.tokenizer.encode(split, add_special_tokens=False)) * [self.ignore_index]
                        )
            labels.append(instance_labels)
        if not any(self.response_template in self.tokenizer.decode(input_ids[i]) for i in range(input_ids.size(0))):
            warnings.warn(
                f"Could not find response template `{self.response_template}` in any instance. "
                f"These instances will be ignored in loss calculation. "
                f"Note, if this happens often, consider increasing the `max_seq_length`."
            )
        if len(input_ids) != len(labels):
            raise ValueError(
                "The lengths of input_ids and labels do not match after processing. "
                "This should not happen and may indicate a bug in the DataCollator."
            )
        return {"input_ids": input_ids, "labels": labels}
```

This modified implementation works by splitting the input sequences based on the provided `separator_token`.
Please check this out and let me know.
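The split-and-mask idea above can be sketched without `transformers`. This is a minimal illustration under toy assumptions: a whitespace "tokenizer" stands in for a real one, `<eos>` is the separator, and `<asst>` marks assistant turns (all hypothetical names).

```python
IGNORE_INDEX = -100

def mask_non_assistant_turns(decoded, separator, response_template, tokenize):
    """Split the decoded text on the separator; turns containing the response
    template keep their tokens as labels, every other turn is masked."""
    labels = []
    for chunk in decoded.split(separator):
        if not chunk:
            continue
        chunk = chunk + separator  # re-attach the separator (eos-style)
        tokens = tokenize(chunk)
        if response_template in chunk:
            labels.extend(tokens)                       # train on assistant turns
        else:
            labels.extend([IGNORE_INDEX] * len(tokens))  # mask other turns
    return labels

# Toy whitespace "tokenizer": returns word tokens standing in for token IDs.
tokenize = lambda s: s.split()

text = "<user> hi <eos> <asst> hello there <eos>"
labels = mask_non_assistant_turns(text, "<eos>", "<asst>", tokenize)
```

The user turn ends up as three `-100` labels while the assistant turn keeps its four tokens, mirroring the `batch["labels"]` output in the docstring example above.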
This is cool. Thanks @AswanthManoj. For the example, I think you have a wrong variable
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
This PR adds a new data collator, DataCollatorForMultiTurnCompletions, to the trl library. This data collator is designed for multi-turn completion tasks, where the prompts are ignored and only the assistant completions are considered in multi-turn conversation datasets.

The existing DataCollatorForCompletionOnlyLM is designed for single-turn completion tasks, where the input contains a single prompt or instruction and the model is expected to generate a single response. It does not handle multi-turn conversational scenarios effectively, where the input sequence may contain multiple user prompts and assistant responses interspersed.

The DataCollatorForMultiTurnCompletions class inherits from DataCollatorForLanguageModeling and extends its functionality to handle multi-turn conversations. It takes two additional arguments:

- user_template: a string or list of token IDs that indicates the start of a user prompt or input.
- assistant_template: a string or list of token IDs that indicates the start of an assistant response.

During the forward pass, the collator identifies the assistant responses based on the provided templates (or token IDs) and masks out the user prompts during the loss calculation. This ensures that the model is trained to generate appropriate responses based on the user input, rather than learning on the entire text.

The new data collator can be used as follows:
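The usage snippet is missing from the page. As a transformers-free sketch of the masking behavior the description outlines (the marker token IDs below are hypothetical, chosen only for illustration), labels are kept from each occurrence of the assistant template until the next user template, and everything else is set to the ignore index:

```python
IGNORE_INDEX = -100

def mask_user_turns(input_ids, user_template, assistant_template):
    """Keep labels for spans that start at an occurrence of `assistant_template`
    and run until the next occurrence of `user_template` (or the end); every
    other position is masked with IGNORE_INDEX so it is ignored in the loss."""
    labels = [IGNORE_INDEX] * len(input_ids)
    keep = False
    for i in range(len(input_ids)):
        if input_ids[i:i + len(assistant_template)] == assistant_template:
            keep = True   # entering an assistant response
        elif input_ids[i:i + len(user_template)] == user_template:
            keep = False  # entering a user prompt
        if keep:
            labels[i] = input_ids[i]
    return labels

# Hypothetical token IDs: 10 marks a user turn, 20 marks an assistant turn.
ids = [10, 1, 2, 20, 3, 4, 10, 5, 20, 6]
labels = mask_user_turns(ids, user_template=[10], assistant_template=[20])
```

As in the docstring example earlier in the thread, the assistant template tokens themselves stay in the labels while both user turns become runs of `-100`.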