Skip to content

Commit

Permalink
Use less problematic whitespace token (#916)
Browse files Browse the repository at this point in the history
Fixes #839 #908 #690 #450

## Problem

A major problem, especially with smaller language models, is the
repetition problem.

For example, let's say a model is generating json and must provide 12
space tokens for indentation in json output. Often a language model will
assign a high probability to a 13th space token, and do the same for a
14th space, and then enter an infinite space generation loop.

This is a problem with NLG that has been known for half a decade, but
only has mitigations (mirostat, repetition penalty, using hundreds of
billions of weights, etc), no absolute solutions (except for
**structured generation**)

## Solution

For structured json generation, we set a sane default whitespace pattern
of `r"[ ]?"`. This removes all newlines and indentation. It disallows
any syntactic whitespace beyond a single space separator.

Users can still set the argument `whitespace_pattern=` if they want
different behavior
  • Loading branch information
lapp0 committed May 24, 2024
1 parent ba7affd commit 411eaaf
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 33 deletions.
4 changes: 2 additions & 2 deletions docs/reference/json.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ print(result)

!!! Note "JSON and whitespaces"

By default Outlines lets model choose the number of linebreaks and white spaces used to structure the JSON. Small models tend to struggle with this, in which case we recommend to set the value of the parameter `whitespace_pattern` to the empty string:
By default Outlines prevents the model from generating json with syntactic newlines, tabs, or multiple spaces. The default `whitespace_pattern` is `r"[ ]?"`. Small models tend to enter an infinite repetition loop if the `whitespace_pattern` allows infinite spacing. If you would like to allow the model to generate multiple tabs, newlines, and spaces, you can set the whitespace pattern as follows:

```python
generator = generate.json(model, User, whitespace_pattern="")
generator = generate.json(model, User, whitespace_pattern=r"[\n\t ]*")
```

!!! Note "Performance"
Expand Down
2 changes: 1 addition & 1 deletion outlines/fsm/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
NUMBER = rf"({INTEGER})(\.[0-9]+)?([eE][+-][0-9]+)?"
BOOLEAN = r"(true|false)"
NULL = r"null"
WHITESPACE = r"[\n ]*"
WHITESPACE = r"[ ]?"

type_to_regex = {
"string": STRING,
Expand Down
45 changes: 15 additions & 30 deletions tests/fsm/test_json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,8 @@ def test_match_number(pattern, does_match):
"properties": {"count": {"title": "Count", "type": "integer"}},
"required": ["count"],
},
'\\{[\\n ]*"count"[\\n ]*:[\\n ]*(-)?(0|[1-9][0-9]*)[\\n ]*\\}',
[('{\n "count": 100\n}', True)],
'\\{[ ]?"count"[ ]?:[ ]?(-)?(0|[1-9][0-9]*)[ ]?\\}',
[('{ "count": 100 }', True)],
),
# array
(
Expand Down Expand Up @@ -277,7 +277,7 @@ def test_match_number(pattern, does_match):
rf"""\{{{WHITESPACE}"test_dict"{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}{STRING}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{STRING}){{0,}})?{WHITESPACE}\}}{WHITESPACE}\}}""",
[
("""{ "test_dict":{"foo":"bar","baz": "bif"}}""", True),
("""{ "test_dict":{"foo":"bar"\n}}""", True),
("""{ "test_dict":{"foo":"bar" }}""", True),
("""{ "test_dict":{}}""", True),
("""{ "WRONG_KEY":{}}""", False),
("""{ "test_dict":{"wrong_type" 1}}""", False),
Expand Down Expand Up @@ -369,8 +369,8 @@ def test_match_number(pattern, does_match):
},
"required": ["fuzz"],
},
f'\\{{[\\n ]*"fuzz"[\\n ]*:[\\n ]*\\{{[\\n ]*"spam"[\\n ]*:[\\n ]*{INTEGER}[\\n ]*\\}}[\\n ]*\\}}',
[('{\n "fuzz": {\n "spam": 100\n }\n}', True)],
f'\\{{[ ]?"fuzz"[ ]?:[ ]?\\{{[ ]?"spam"[ ]?:[ ]?{INTEGER}[ ]?\\}}[ ]?\\}}',
[('{ "fuzz": { "spam": 100 }}', True)],
),
# Schema with a reference
(
Expand All @@ -384,7 +384,7 @@ def test_match_number(pattern, does_match):
},
"required": ["user_id", "name", "a"],
},
f'\\{{[\\n ]*"user_id"[\\n ]*:[\\n ]*{INTEGER}[\\n ]*,[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,[\\n ]*"a"[\\n ]*:[\\n ]*{STRING}[\\n ]*\\}}',
f'\\{{[ ]?"user_id"[ ]?:[ ]?{INTEGER}[ ]?,[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"a"[ ]?:[ ]?{STRING}[ ]?\\}}',
[('{"user_id": 100, "name": "John", "a": "Marc"}', True)],
),
(
Expand All @@ -399,7 +399,7 @@ def test_match_number(pattern, does_match):
},
"required": ["user_id", "name", "name2"],
},
f'\\{{[\\n ]*"user_id"[\\n ]*:[\\n ]*{INTEGER}[\\n ]*,[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,[\\n ]*"name2"[\\n ]*:[\\n ]*{STRING}[\\n ]*\\}}',
f'\\{{[ ]?"user_id"[ ]?:[ ]?{INTEGER}[ ]?,[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"name2"[ ]?:[ ]?{STRING}[ ]?\\}}',
[('{"user_id": 100, "name": "John", "name2": "Marc"}', True)],
),
(
Expand Down Expand Up @@ -441,7 +441,7 @@ def test_match_number(pattern, does_match):
}
},
},
f'\\{{[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,[\\n ]*"last_name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,[\\n ]*"address"[\\n ]*:[\\n ]*\\{{[\\n ]*"city"[\\n ]*:[\\n ]*{STRING}[\\n ]*\\}}[\\n ]*\\}}',
f'\\{{[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"last_name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"address"[ ]?:[ ]?\\{{[ ]?"city"[ ]?:[ ]?{STRING}[ ]?\\}}[ ]?\\}}',
[
(
'{"name": "John", "last_name": "Doe", "address": {"city": "Paris"}}',
Expand All @@ -462,7 +462,7 @@ def test_match_number(pattern, does_match):
"title": "Character",
"type": "object",
},
f'\\{{[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}([\\n ]*,[\\n ]*"age"[\\n ]*:[\\n ]*({INTEGER}|null))?([\\n ]*,[\\n ]*"weapon"[\\n ]*:[\\n ]*({STRING}|null))?[\\n ]*\\}}',
f'\\{{[ ]?"name"[ ]?:[ ]?{STRING}([ ]?,[ ]?"age"[ ]?:[ ]?({INTEGER}|null))?([ ]?,[ ]?"weapon"[ ]?:[ ]?({STRING}|null))?[ ]?\\}}',
[
('{ "name" : "Player" }', True),
('{ "name" : "Player", "weapon" : "sword" }', True),
Expand All @@ -482,7 +482,7 @@ def test_match_number(pattern, does_match):
"title": "Character",
"type": "object",
},
f'\\{{[\\n ]*"name"[\\n ]*:[\\n ]*{STRING}[\\n ]*,([\\n ]*"age"[\\n ]*:[\\n ]*({INTEGER}|null)[\\n ]*,)?[\\n ]*"weapon"[\\n ]*:[\\n ]*{STRING}([\\n ]*,[\\n ]*"strength"[\\n ]*:[\\n ]*({INTEGER}|null))?[\\n ]*\\}}',
f'\\{{[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,([ ]?"age"[ ]?:[ ]?({INTEGER}|null)[ ]?,)?[ ]?"weapon"[ ]?:[ ]?{STRING}([ ]?,[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?[ ]?\\}}',
[
('{ "name" : "Player" , "weapon" : "sword" }', True),
(
Expand All @@ -506,7 +506,7 @@ def test_match_number(pattern, does_match):
"title": "Character",
"type": "object",
},
f'\\{{([\\n ]*"name"[\\n ]*:[\\n ]*({STRING}|null)[\\n ]*,)?[\\n ]*"age"[\\n ]*:[\\n ]*{INTEGER}[\\n ]*,[\\n ]*"armor"[\\n ]*:[\\n ]*{STRING}[\\n ]*,([\\n ]*"strength"[\\n ]*:[\\n ]*({INTEGER}|null)[\\n ]*,)?[\\n ]*"weapon"[\\n ]*:[\\n ]*{STRING}[\\n ]*\\}}',
f'\\{{([ ]?"name"[ ]?:[ ]?({STRING}|null)[ ]?,)?[ ]?"age"[ ]?:[ ]?{INTEGER}[ ]?,[ ]?"armor"[ ]?:[ ]?{STRING}[ ]?,([ ]?"strength"[ ]?:[ ]?({INTEGER}|null)[ ]?,)?[ ]?"weapon"[ ]?:[ ]?{STRING}[ ]?\\}}',
[
(
'{ "name" : "Player", "age" : 10, "armor" : "plate", "strength" : 11, "weapon" : "sword" }',
Expand All @@ -530,7 +530,7 @@ def test_match_number(pattern, does_match):
"title": "Character",
"type": "object",
},
f'\\{{([\\n ]*"name"[\\n ]*:[\\n ]*({STRING}|null)([\\n ]*,[\\n ]*"age"[\\n ]*:[\\n ]*({INTEGER}|null))?([\\n ]*,[\\n ]*"strength"[\\n ]*:[\\n ]*({INTEGER}|null))?|([\\n ]*"name"[\\n ]*:[\\n ]*({STRING}|null)[\\n ]*,)?[\\n ]*"age"[\\n ]*:[\\n ]*({INTEGER}|null)([\\n ]*,[\\n ]*"strength"[\\n ]*:[\\n ]*({INTEGER}|null))?|([\\n ]*"name"[\\n ]*:[\\n ]*({STRING}|null)[\\n ]*,)?([\\n ]*"age"[\\n ]*:[\\n ]*({INTEGER}|null)[\\n ]*,)?[\\n ]*"strength"[\\n ]*:[\\n ]*({INTEGER}|null))?[\\n ]*\\}}',
f'\\{{([ ]?"name"[ ]?:[ ]?({STRING}|null)([ ]?,[ ]?"age"[ ]?:[ ]?({INTEGER}|null))?([ ]?,[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?|([ ]?"name"[ ]?:[ ]?({STRING}|null)[ ]?,)?[ ]?"age"[ ]?:[ ]?({INTEGER}|null)([ ]?,[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?|([ ]?"name"[ ]?:[ ]?({STRING}|null)[ ]?,)?([ ]?"age"[ ]?:[ ]?({INTEGER}|null)[ ]?,)?[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?[ ]?\\}}',
[
('{ "name" : "Player" }', True),
('{ "name" : "Player", "age" : 10, "strength" : 10 }', True),
Expand Down Expand Up @@ -710,19 +710,6 @@ def test_format(schema, regex, examples):
('{"time":20:20:39Z}', False), # missing quotes for value
],
),
# Unconstrained Object
(
{
"title": "Foo",
"type": "object",
},
[
("{}", True),
('{"a": 1, "b": null}', True),
('{"a": {"z": {"g": 4}}, "b": null}', True),
("1234", False), # not an object
],
),
],
)
def test_format_without_regex(schema, examples):
Expand All @@ -737,7 +724,7 @@ def test_format_without_regex(schema, examples):
assert match is None
@pytest.mark.parametrize("whitespace_pattern", [None, r"[\n ]?", "abc"])
@pytest.mark.parametrize("whitespace_pattern", [None, r"[\n ]*", "abc"])
def test_json_schema_custom_whitespace_pattern(whitespace_pattern):
"""assert whitespace_pattern setting respected"""
Expand All @@ -759,13 +746,11 @@ class MockModel(BaseModel):
)
mock_result_maybe_ws = """{"foo" : 4 ,"bar":"baz baz baz bar"}"""
match_default_ws = re.fullmatch(pattern, mock_result_mult_ws)
match_default_ws = re.fullmatch(pattern, mock_result_maybe_ws)
if whitespace_pattern is None:
assert match_default_ws
else:
assert match_default_ws is None
assert re.fullmatch(pattern, mock_result_maybe_ws)
assert re.fullmatch(pattern, mock_result_mult_ws)
def test_one_of_doesnt_produce_illegal_lookaround():
Expand Down

0 comments on commit 411eaaf

Please sign in to comment.