From ae9ae50759b471b07dae858a537ac8bb1c7afcaa Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Louf?=
Date: Wed, 10 Apr 2024 11:32:00 +0200
Subject: [PATCH] Add integration tests for the vLLM integration

---
 outlines/models/vllm.py                     |   2 +
 tests/generate/test_integration_llamacpp.py |   2 +-
 tests/generate/test_integration_vllm.py     | 127 +++++++++++++++++++-
 3 files changed, 126 insertions(+), 5 deletions(-)

diff --git a/outlines/models/vllm.py b/outlines/models/vllm.py
index 0e186fc06..378a35d91 100644
--- a/outlines/models/vllm.py
+++ b/outlines/models/vllm.py
@@ -74,6 +74,8 @@ def generate(
         if max_tokens is not None:
             sampling_params.max_tokens = max_tokens
         if stop_at is not None:
+            if isinstance(stop_at, str):
+                stop_at = [stop_at]
             sampling_params.stop = stop_at
         if seed is not None:
             sampling_params.seed = seed
diff --git a/tests/generate/test_integration_llamacpp.py b/tests/generate/test_integration_llamacpp.py
index 0007a7362..d036b560f 100644
--- a/tests/generate/test_integration_llamacpp.py
+++ b/tests/generate/test_integration_llamacpp.py
@@ -181,7 +181,7 @@ def test_llamacpp_date(model):
     prompt = (
         "<|im_start|>user\nWhat day is it today?<|im_end|>\n<|im_start|>assistant\n"
     )
-    sequence = generate.format(model, datetime.date)(prompt, max_tokens=10)
+    sequence = generate.format(model, datetime.date)(prompt, max_tokens=20, seed=10)
 
     assert isinstance(sequence, datetime.date)
 
diff --git a/tests/generate/test_integration_vllm.py b/tests/generate/test_integration_vllm.py
index b097cb0b4..4634bc839 100644
--- a/tests/generate/test_integration_vllm.py
+++ b/tests/generate/test_integration_vllm.py
@@ -1,8 +1,13 @@
+import datetime
+import re
+
 import pytest
 import torch
+from pydantic import BaseModel, constr
 from vllm.sampling_params import SamplingParams
 
 import outlines.generate as generate
+import outlines.grammars as grammars
 import outlines.models as models
 import outlines.samplers as samplers
 
@@ -114,7 +119,121 @@ def test_vllm_beam_search(model):
     assert res[0] != res[1]
 
 
-@pytest.mark.xfail(reason="CFG logits processor not available for vLLM")
-def test_cfg_simple(model):
-    generator = generate.cfg(model)
-    _ = generator("test")
+def test_vllm_text_stop(model):
+    prompt = "Write a short sentence containing 'You': "
+    sequence = generate.text(model)(prompt, max_tokens=100, seed=10)
+    assert sequence.find("news") != -1
+
+    sequence = generate.text(model)(prompt, stop_at="news", max_tokens=100, seed=10)
+    assert isinstance(sequence, str)
+    assert sequence.find("news") == -1
+
+
+def test_vllm_regex(model):
+    prompt = "Write an email address: "
+    regex_str = r"([a-z]{10})@([a-z]{5})\.([a-z]{3})"
+    generator = generate.regex(model, regex_str)
+
+    # One prompt
+    sequence = generator(prompts=prompt)
+    assert isinstance(sequence, str)
+    assert re.fullmatch(pattern=regex_str, string=sequence) is not None
+
+
+def test_vllm_integer(model):
+    prompt = "Give me an integer: "
+    sequence = generate.format(model, int)(prompt, max_tokens=10)
+    assert isinstance(sequence, int)
+    assert sequence != ""
+    int(sequence)
+
+
+def test_vllm_float(model):
+    prompt = "Give me a floating-point number: "
+    sequence = generate.format(model, float)(prompt, max_tokens=10)
+    assert isinstance(sequence, float)
+
+    assert sequence != ""
+    float(sequence)
+
+
+def test_vllm_bool(model):
+    prompt = "Is this True or False? "
" + sequence = generate.format(model, bool)(prompt, max_tokens=10) + assert isinstance(sequence, bool) + + assert sequence != "" + bool(sequence) + + +def test_vllm_date(model): + prompt = "What day is it today? " + sequence = generate.format(model, datetime.date)(prompt, max_tokens=10) + assert isinstance(sequence, datetime.date) + + +def test_vllm_time(model): + prompt = "What time is it? " + sequence = generate.format(model, datetime.time)(prompt, max_tokens=10) + assert isinstance(sequence, datetime.time) + + +def test_vllm_datetime(model): + prompt = "What time is it? " + sequence = generate.format(model, datetime.datetime)(prompt, max_tokens=20) + assert isinstance(sequence, datetime.datetime) + + +def test_vllm_choice(model): + prompt = "Which one between 'test' and 'choice'? " + sequence = generate.choice(model, ["test", "choice"])(prompt) + assert sequence == "test" or sequence == "choice" + + +def test_vllm_json_basic(model): + prompt = "Output some JSON. " + + class Spam(BaseModel): + spam: constr(max_length=10) + fuzz: bool + + sampling_params = SamplingParams(temperature=0) + result = generate.json(model, Spam, whitespace_pattern="")( + prompt, max_tokens=100, seed=1, sampling_params=sampling_params + ) + assert isinstance(result, BaseModel) + assert isinstance(result.spam, str) + assert isinstance(result.fuzz, bool) + assert len(result.spam) <= 10 + + +def test_vllm_json_schema(model): + prompt = "Output some JSON. " + + schema = """{ + "title": "spam", + "type": "object", + "properties": { + "foo" : {"type": "boolean"}, + "bar": {"type": "string", "maxLength": 4} + }, + "required": ["foo", "bar"] + } + """ + + sampling_params = SamplingParams(temperature=0) + result = generate.json(model, schema, whitespace_pattern="")( + prompt, max_tokens=100, seed=10, sampling_params=sampling_params + ) + assert isinstance(result, dict) + assert isinstance(result["foo"], bool) + assert isinstance(result["bar"], str) + + +@pytest.mark.xfail( + reason="The CFG logits processor for vLLM has not been implemented yet." +) +def test_vllm_cfg(model): + prompt = "<|im_start|>user\nOutput a short and valid JSON object with two keys.<|im_end|>\n><|im_start|>assistant\n" + result = generate.cfg(model, grammars.arithmetic)(prompt, seed=11) + assert isinstance(result, str)