Skip to content

Commit

Permalink
Add chat templating support for KeyDataset in text-generation pipeline (#30558)
Browse files Browse the repository at this point in the history

* Added chat templating support for KeyDataset in the text-generation pipeline

* fixed and improved test

* fix formatting test failures

* Fix tests

* Fix tests
  • Loading branch information
DarshanDeshpande committed Apr 30, 2024
1 parent 0cdb6b3 commit 2ecefc3
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/transformers/pipelines/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from .pt_utils import KeyDataset

if is_tf_available():
import tensorflow as tf
Expand Down Expand Up @@ -243,7 +244,9 @@ def __call__(self, text_inputs, **kwargs):
- **generated_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The token
ids of the generated text.
"""
if isinstance(text_inputs, (list, tuple)) and isinstance(text_inputs[0], (list, tuple, dict)):
if isinstance(
text_inputs, (list, tuple, KeyDataset) if is_torch_available() else (list, tuple)
) and isinstance(text_inputs[0], (list, tuple, dict)):
# We have one or more prompts in list-of-dicts format, so this is chat mode
if isinstance(text_inputs[0], dict):
return super().__call__(Chat(text_inputs), **kwargs)
Expand Down Expand Up @@ -380,7 +383,8 @@ def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_
if isinstance(prompt_text, str):
all_text = prompt_text + all_text
elif isinstance(prompt_text, Chat):
all_text = prompt_text.messages + [{"role": "assistant", "content": all_text}]
# Explicit list parsing is necessary for parsing chat datasets
all_text = list(prompt_text.messages) + [{"role": "assistant", "content": all_text}]

record = {"generated_text": all_text}
records.append(record)
Expand Down
42 changes: 42 additions & 0 deletions tests/pipelines/test_pipelines_text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,48 @@ def test_small_chat_model_pt(self):
],
)

@require_torch
def test_small_chat_model_with_dataset_pt(self):
    """Run a chat-format conversation through the text-generation pipeline via a ``KeyDataset``.

    A one-item torch ``Dataset`` yields ``{"text": <list of message dicts>}``;
    ``KeyDataset`` selects the ``"text"`` field, and each pipeline result is
    expected to be the original conversation followed by a generated
    assistant message.
    """
    from torch.utils.data import Dataset

    from transformers.pipelines.pt_utils import KeyDataset

    class MyDataset(Dataset):
        # A single conversation expressed as a list of role/content messages.
        data = [
            [
                {"role": "system", "content": "This is a system message."},
                {"role": "user", "content": "This is a test"},
                {"role": "assistant", "content": "This is a reply"},
            ],
        ]

        def __len__(self):
            # `data` holds exactly one conversation.
            return 1

        def __getitem__(self, i):
            # Wrap the conversation in a dict so KeyDataset can select it by key.
            return {"text": self.data[i]}

    generator = pipeline(
        task="text-generation",
        model="rocketknight1/tiny-gpt2-with-chatml-template",
        framework="pt",
    )

    chat_dataset = MyDataset()
    keyed_dataset = KeyDataset(chat_dataset, "text")

    # The expected value does not depend on the loop variable, so build it once:
    # the input conversation plus the assistant turn produced by greedy decoding.
    assistant_turn = {
        "role": "assistant",
        "content": " factors factors factors factors factors factors factors factors factors factors",
    }
    expected_chat = chat_dataset.data[0] + [assistant_turn]
    expected_outputs = [{"generated_text": expected_chat}]

    for generated in generator(keyed_dataset, do_sample=False, max_new_tokens=10):
        self.assertEqual(generated, expected_outputs)

@require_tf
def test_small_model_tf(self):
text_generator = pipeline(task="text-generation", model="sshleifer/tiny-ctrl", framework="tf")
Expand Down

0 comments on commit 2ecefc3

Please sign in to comment.