Use a persistent Tokenizer hash for create_states_mapping cache
brandonwillard committed May 23, 2024
1 parent d7c9707 commit ba7affd
Showing 5 changed files with 117 additions and 36 deletions.
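
The point of the change: `create_states_mapping` is memoized to disk, so its cache key must be reproducible across Python processes, and the default object hash of a tokenizer is not. A minimal sketch of the idea, not part of the commit, using a GPT-2 tokenizer purely as an example:

from datasets.fingerprint import Hasher
from transformers import AutoTokenizer

hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Hasher.hash fingerprints the pickled tokenizer, so the same tokenizer state
# yields the same digest in every process, which makes it usable as a
# persistent cache key.
print(Hasher.hash(hf_tokenizer))

# The default object hash is derived from id() and changes from run to run,
# so it cannot key an on-disk cache.
print(id(hf_tokenizer))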
56 changes: 28 additions & 28 deletions outlines/fsm/guide.py
@@ -105,44 +105,44 @@ def copy(self):
return self


@cache()
def create_states_mapping(
regex_string: str, tokenizer: "Tokenizer"
) -> Tuple[dict, set, set]:
"""Create the variables related to the mapping between states and tokens
The parameters of the function are used for caching purpose
"""
regex_pattern = interegular.parse_pattern(regex_string)
byte_fsm = make_byte_level_fsm(regex_pattern.to_fsm().reduce(), keep_utf8=True)
regex_fsm, _ = make_deterministic_fsm(byte_fsm)
states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(
regex_fsm, tokenizer
)

# We make sure that it is possible to generate strings in the language
# of the regular expression with the tokens present in the model's
# vocabulary.
if not any(
regex_fsm.finals.intersection(v.values()) for v in states_to_token_maps.values()
):
raise ValueError(
"The vocabulary does not allow us to build a sequence that matches the input regex"
)

return states_to_token_maps, empty_token_ids, regex_fsm.finals


class RegexGuide(Guide):
"""Guide to generate text in the language of a regular expression."""

initial_state = 0

def __init__(self, regex_string: str, tokenizer):
@cache()
def create_states_mapping(regex_string: str) -> Tuple[dict, set, set]:
"""Create the variables related to the mapping between states and tokens
The parameters of the function are used for caching purpose
"""
regex_pattern = interegular.parse_pattern(regex_string)
byte_fsm = make_byte_level_fsm(
regex_pattern.to_fsm().reduce(), keep_utf8=True
)
regex_fsm, _ = make_deterministic_fsm(byte_fsm)
states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(
regex_fsm, tokenizer
)

# We make sure that it is possible to generate strings in the language
# of the regular expression with the tokens present in the model's
# vocabulary.
if not any(
regex_fsm.finals.intersection(v.values())
for v in states_to_token_maps.values()
):
raise ValueError(
"The vocabulary does not allow us to build a sequence that matches the input regex"
)

return states_to_token_maps, empty_token_ids, regex_fsm.finals

(
self.states_to_token_maps,
self.empty_token_ids,
fsm_finals,
) = create_states_mapping(regex_string)
) = create_states_mapping(regex_string, tokenizer)
self.eos_token_id = tokenizer.eos_token_id
self.final_states = fsm_finals | {-1}

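In short, `create_states_mapping` moves out of `RegexGuide.__init__`, where it was a nested closure keyed only on the regex string, to a module-level cached function that takes the tokenizer explicitly, so equivalent tokenizers can share cache entries across processes. A rough usage sketch against this commit; the regex and the GPT-2 tokenizer are illustrative only:

from transformers import AutoTokenizer

from outlines.fsm.guide import RegexGuide
from outlines.models.transformers import TransformerTokenizer

tokenizer = TransformerTokenizer(AutoTokenizer.from_pretrained("gpt2"))

# The cache key is now derived from the regex string together with the
# tokenizer's persistent fingerprint.
guide_a = RegexGuide(r"\d{3}", tokenizer)  # first build: cache miss, FSM index computed
guide_b = RegexGuide(r"\d{3}", tokenizer)  # same key: cache hit, mapping reused
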
11 changes: 9 additions & 2 deletions outlines/models/transformers.py
@@ -1,5 +1,7 @@
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

from datasets.fingerprint import Hasher

from outlines.models.tokenizer import Tokenizer

if TYPE_CHECKING:
@@ -109,10 +111,15 @@ def __eq__(self, other):
return NotImplemented

def __hash__(self):
from datasets.fingerprint import Hasher

return hash(Hasher.hash(self.tokenizer))

def __getstate__(self):
state = {"tokenizer": self.tokenizer}
return state

def __setstate__(self, state):
self.__init__(state["tokenizer"])


class Transformers:
"""Represents a `transformers` model."""
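`TransformerTokenizer.__hash__` keeps delegating to `datasets.fingerprint.Hasher` (the import simply moves to module level), and the new `__getstate__`/`__setstate__` reduce the wrapper to the underlying tokenizer so it pickles cleanly, which is what the fingerprinting and the cache serialization depend on. A small sketch of the expected behaviour; the equality of the two fingerprints is an expectation, not something this commit asserts:

import pickle

from transformers import AutoTokenizer

from outlines.models.transformers import TransformerTokenizer

tokenizer = TransformerTokenizer(AutoTokenizer.from_pretrained("gpt2"))

# __getstate__ reduces the wrapper to {"tokenizer": ...} and __setstate__
# rebuilds it via __init__, so the wrapper survives a pickle round trip.
restored = pickle.loads(pickle.dumps(tokenizer))

# Both hashes come from Hasher.hash(...) over the wrapped tokenizer, so an
# equivalent tokenizer state should map to the same fingerprint.
print(hash(tokenizer), hash(restored))
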
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -35,7 +35,8 @@ dependencies = [
"referencing",
"jsonschema",
"requests",
"tqdm"
"tqdm",
"datasets",
]
dynamic = ["version"]

@@ -50,7 +51,6 @@ test = [
"diff-cover",
"accelerate",
"beartype<0.16.0",
"datasets",
"responses",
"llama-cpp-python",
"huggingface_hub",
68 changes: 67 additions & 1 deletion tests/generate/test_integration_transformers.py
@@ -1,6 +1,7 @@
import datetime
import re
from enum import Enum
from importlib import reload
from typing import List, Union

import pytest
@@ -11,7 +12,28 @@
import outlines.models as models
from outlines.fsm.regex import reduced_vocabulary
from outlines.models.transformers import Transformers, TransformerTokenizer
from outlines.samplers import beam_search, multinomial
from outlines.samplers import beam_search, greedy, multinomial


@pytest.fixture
def temp_cache_dir():
import os
import tempfile

import outlines.caching
import outlines.fsm.guide

with tempfile.TemporaryDirectory() as tempdir:
os.environ["OUTLINES_CACHE_DIR"] = tempdir
outlines.caching.get_cache.cache_clear()
reload(outlines)
reload(outlines.fsm.guide)
cache_status = outlines.caching._caching_enabled
try:
outlines.caching._caching_enabled = True
yield
finally:
outlines.caching._caching_enabled = cache_status


def test_transformers_integration_text():
@@ -632,3 +654,47 @@ def test_transformers_use_existing_model_and_tokenizer():
model = Transformers(hf_model, hf_tokenizer)
sequence = generate.text(model)("Write a short sentence ", rng=rng)
assert isinstance(sequence, str)


def test_RegexGuide_caching(temp_cache_dir):
import outlines.caching
from outlines.fsm.guide import create_states_mapping

assert outlines.caching._caching_enabled

regex = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
prompt = "What is the IP address of the Google DNS servers? "

cache = outlines.caching.get_cache()

# Returns (hits, misses)
_ = cache.stats(enable=True)
assert cache.statistics

assert create_states_mapping.__memory__ is cache

model = models.transformers(
"hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM"
)
generator = generate.regex(model, regex, sampler=greedy())
assert cache.stats() == (0, 1)

model_2 = models.transformers("hf-internal-testing/tiny-random-GPTJForCausalLM")
generator_2 = generate.regex(model_2, regex, sampler=greedy())
assert cache.stats() == (0, 2)

# These two different models and tokenizers should not have the same state
# mapping results
assert generator.fsm.states_to_token_maps != generator_2.fsm.states_to_token_maps

generator_3 = generate.regex(model_2, regex, sampler=greedy())
assert cache.stats() == (1, 2)
assert generator_2.fsm.states_to_token_maps == generator_3.fsm.states_to_token_maps

# Just for fun...
structured = generator(prompt, max_tokens=30)
structured_2 = generator_2(prompt, max_tokens=30)

assert re.fullmatch(regex, structured)
assert re.fullmatch(regex, structured_2)
assert structured != structured_2
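
The accounting in the test relies on hooks that already exist in the code base: `cache.stats(enable=True)` turns on (hits, misses) counting on the cache object, and `create_states_mapping.__memory__` exposes the cache the memoized function writes to. A condensed sketch of the same check outside pytest, assuming caching is enabled and the default cache directory is acceptable:

import outlines.caching
from outlines.fsm.guide import create_states_mapping

cache = outlines.caching.get_cache()
_ = cache.stats(enable=True)                      # start counting (hits, misses)
assert create_states_mapping.__memory__ is cache  # the guide writes to this cache

# ... build two generators with equivalent model/tokenizer pairs, as in the test ...

print(cache.stats())  # the second, identical build should register as a hit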
14 changes: 11 additions & 3 deletions tests/models/test_transformers.py
@@ -107,6 +107,14 @@ def test_tokenizer_eq_hash():
tokenizer_hf = AutoTokenizer.from_pretrained("gpt2")

tokenizer = TransformerTokenizer(tokenizer_hf)
tokenizer2 = TransformerTokenizer(tokenizer_hf)
assert tokenizer == tokenizer2
assert hash(tokenizer) == hash(tokenizer2)
tokenizer_2 = TransformerTokenizer(tokenizer_hf)

assert tokenizer == tokenizer_2
assert hash(tokenizer) == hash(tokenizer_2)

tokenizer_hf_2 = AutoTokenizer.from_pretrained("gpt2")
tokenizer_hf_2.add_tokens(["test_token"])

tokenizer_3 = TransformerTokenizer(tokenizer_hf_2)
assert tokenizer != tokenizer_3
assert hash(tokenizer) != hash(tokenizer_3)
