Update caching and add tokenizer to create_states_mapping #911

Merged

Changes from all commits
129 changes: 84 additions & 45 deletions outlines/caching.py
@@ -1,15 +1,40 @@
 import asyncio
 import functools
-import hashlib
 import os
 from typing import Callable, Optional
 
 import cloudpickle
-from diskcache import Cache
+from diskcache import Cache, Disk
+from diskcache.core import ENOVAL, UNKNOWN, args_to_key, full_name
 
 _caching_enabled = True
 
 
+class CloudpickleDisk(Disk):
+    def __init__(self, directory, compress_level=1, **kwargs):
+        self.compress_level = compress_level
+        super().__init__(directory, **kwargs)
+
+    def put(self, key):
+        data = cloudpickle.dumps(key)
+        return super().put(data)
+
+    def get(self, key, raw):
+        data = super().get(key, raw)
+        return cloudpickle.loads(data)
+
+    def store(self, value, read, key=UNKNOWN):
+        if not read:
+            value = cloudpickle.dumps(value)
+        return super().store(value, read, key=key)
+
+    def fetch(self, mode, filename, value, read):
+        data = super().fetch(mode, filename, value, read)
+        if not read:
+            data = cloudpickle.loads(data)
+        return data
+
+
 @functools.lru_cache(1)
 def get_cache():
     """Get the context object that contains previously-computed return values.
@@ -26,7 +51,12 @@ def get_cache():

     home_dir = os.path.expanduser("~")
     cache_dir = os.environ.get("OUTLINES_CACHE_DIR", f"{home_dir}/.cache/outlines")
-    memory = Cache(cache_dir, eviction_policy="none", cull_limit=0)
+    memory = Cache(
+        cache_dir,
+        eviction_policy="none",
+        cull_limit=0,
+        disk=CloudpickleDisk,
+    )
 
     # ensure if version upgrade occurs, old cache is pruned
     if outlines_version != memory.get("__version__"):
@@ -36,63 +66,72 @@ def get_cache():
     return memory
 
 
-def hash_arguments(*args, **kwargs) -> str:
-    """Create a hash out of the args and kwargs provided"""
-    result = hashlib.md5()
-    for item in list(args) + sorted(kwargs.items()):
-        result.update(cloudpickle.dumps(item))
-    return result.hexdigest()
-
-
-def cache(key_function: Optional[Callable] = None):
+def cache(expire: Optional[float] = None, typed=False, ignore=()):
     """Caching decorator for memoizing function calls.
 
-    The cache key is created based on the values returned by the key_function callable
-    if provided or based on the arguments of the decorated function directly otherwise
+    This is based on `diskcache`'s `memoize`.
 
     Parameters
     ----------
-    key_function
-        A callable function used to generate a unique key for each function call. It's
-        called with the arguments of the decorated function as arguments
+    expire
+        Seconds until arguments expire.
+    typed
+        Cache different types separately.
+    ignore
+        Positional or keyword arguments to ignore.
 
     Returns
     -------
-    A decorator function that can be applied to other functions.
+        A decorator function that can be applied to other functions.
     """
 
     def decorator(cached_function: Callable):
         memory = get_cache()
 
-        def wrapper(*args, **kwargs):
-            if not _caching_enabled:
-                return cached_function(*args, **kwargs)
-            if key_function:
-                key_args = key_function(*args, **kwargs)
-                cache_key = hash_arguments(*key_args)
-            else:
-                cache_key = hash_arguments(*args, **kwargs)
-            if cache_key in memory:
-                return memory[cache_key]
-            result = cached_function(*args, **kwargs)
-            memory[cache_key] = result
-            return result
-
-        async def async_wrapper(*args, **kwargs):
-            if not _caching_enabled:
-                return await cached_function(*args, **kwargs)
-            if key_function:
-                key_args = key_function(*args, **kwargs)
-                cache_key = hash_arguments(*key_args)
-            else:
-                cache_key = hash_arguments(*args, **kwargs)
-            if cache_key in memory:
-                return memory[cache_key]
-            result = await cached_function(*args, **kwargs)
-            memory[cache_key] = result
-            return result
+        base = (full_name(cached_function),)
 
         if asyncio.iscoroutinefunction(cached_function):
-            return async_wrapper
+
+            async def wrapper(*args, **kwargs):
+                if not _caching_enabled:
+                    return await cached_function(*args, **kwargs)
+
+                cache_key = wrapper.__cache_key__(*args, **kwargs)
+                result = wrapper.__memory__.get(cache_key, default=ENOVAL, retry=True)
+
+                if result is ENOVAL:
+                    result = await cached_function(*args, **kwargs)
+                    wrapper.__memory__.set(cache_key, result, expire, retry=True)
+
+                return result
+
         else:
-            return wrapper
+
+            def wrapper(*args, **kwargs):
+                if not _caching_enabled:
+                    return cached_function(*args, **kwargs)
+
+                cache_key = wrapper.__cache_key__(*args, **kwargs)
+                result = wrapper.__memory__.get(cache_key, default=ENOVAL, retry=True)
+
+                if result is ENOVAL:
+                    result = cached_function(*args, **kwargs)
+                    wrapper.__memory__.set(cache_key, result, expire, retry=True)
+
+                return result
+
+        def __cache_key__(*args, **kwargs):
+            """Make key for cache given function arguments."""
+            return args_to_key(base, args, kwargs, typed, ignore)
+
+        wrapper.__cache_key__ = __cache_key__  # type: ignore
+        wrapper.__memory__ = memory  # type: ignore
+        wrapper.__wrapped__ = cached_function  # type: ignore
+
+        return wrapper
 
     return decorator

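Note on usage: `cache` now mirrors `diskcache.memoize`'s options (`expire`, `typed`, `ignore`) instead of taking a `key_function`, and keys and values round-trip through cloudpickle via `CloudpickleDisk`, so arguments such as tokenizers can participate in cache keys. A minimal sketch of the intended usage; `slow_fsm_build` is a hypothetical example function, not part of this PR:

from outlines.caching import cache

@cache(expire=None)  # entries never expire
def slow_fsm_build(regex_string, tokenizer):
    ...  # expensive work keyed on both arguments

# diskcache-style introspection is available on the wrapper:
# slow_fsm_build.__cache_key__(...) builds the key diskcache will use,
# and slow_fsm_build.__memory__ is the shared Cache returned by get_cache().
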
56 changes: 28 additions & 28 deletions outlines/fsm/guide.py
@@ -105,44 +105,44 @@ def copy(self):
         return self
 
 
+@cache()
+def create_states_mapping(
+    regex_string: str, tokenizer: "Tokenizer"
+) -> Tuple[dict, set, set]:
+    """Create the variables related to the mapping between states and tokens
+    The parameters of the function are used for caching purpose
+    """
+    regex_pattern = interegular.parse_pattern(regex_string)
+    byte_fsm = make_byte_level_fsm(regex_pattern.to_fsm().reduce(), keep_utf8=True)
+    regex_fsm, _ = make_deterministic_fsm(byte_fsm)
+    states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(
+        regex_fsm, tokenizer
+    )
+
+    # We make sure that it is possible to generate strings in the language
+    # of the regular expression with the tokens present in the model's
+    # vocabulary.
+    if not any(
+        regex_fsm.finals.intersection(v.values()) for v in states_to_token_maps.values()
+    ):
+        raise ValueError(
+            "The vocabulary does not allow us to build a sequence that matches the input regex"
+        )
+
+    return states_to_token_maps, empty_token_ids, regex_fsm.finals
+
+
 class RegexGuide(Guide):
     """Guide to generate text in the language of a regular expression."""
 
     initial_state = 0
 
     def __init__(self, regex_string: str, tokenizer):
-        @cache()
-        def create_states_mapping(regex_string: str) -> Tuple[dict, set, set]:
-            """Create the variables related to the mapping between states and tokens
-            The parameters of the function are used for caching purpose
-            """
-            regex_pattern = interegular.parse_pattern(regex_string)
-            byte_fsm = make_byte_level_fsm(
-                regex_pattern.to_fsm().reduce(), keep_utf8=True
-            )
-            regex_fsm, _ = make_deterministic_fsm(byte_fsm)
-            states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(
-                regex_fsm, tokenizer
-            )
-
-            # We make sure that it is possible to generate strings in the language
-            # of the regular expression with the tokens present in the model's
-            # vocabulary.
-            if not any(
-                regex_fsm.finals.intersection(v.values())
-                for v in states_to_token_maps.values()
-            ):
-                raise ValueError(
-                    "The vocabulary does not allow us to build a sequence that matches the input regex"
-                )
-
-            return states_to_token_maps, empty_token_ids, regex_fsm.finals
-
         (
             self.states_to_token_maps,
             self.empty_token_ids,
             fsm_finals,
-        ) = create_states_mapping(regex_string)
+        ) = create_states_mapping(regex_string, tokenizer)
         self.eos_token_id = tokenizer.eos_token_id
         self.final_states = fsm_finals | {-1}

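Because `create_states_mapping` is now a module-level function that receives the tokenizer explicitly, the cache key covers both the regex and the tokenizer, so models with different vocabularies no longer share cached state mappings. A sketch of a direct call, assuming a local `transformers` model (the model name is only an example):

import outlines.models as models
from outlines.fsm.guide import create_states_mapping

model = models.transformers("hf-internal-testing/tiny-random-GPTJForCausalLM")
states_to_token_maps, empty_token_ids, finals = create_states_mapping(
    r"[0-9]{2}", model.tokenizer
)
# A second call with the same regex and tokenizer is served from the disk cache.
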
11 changes: 9 additions & 2 deletions outlines/models/transformers.py
@@ -1,5 +1,7 @@
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
+from datasets.fingerprint import Hasher
+
 from outlines.models.tokenizer import Tokenizer
 
 if TYPE_CHECKING:
@@ -109,10 +111,15 @@ def __eq__(self, other):
         return NotImplemented
 
     def __hash__(self):
-        from datasets.fingerprint import Hasher
-
         return hash(Hasher.hash(self.tokenizer))
 
+    def __getstate__(self):
+        state = {"tokenizer": self.tokenizer}
+        return state
+
+    def __setstate__(self, state):
+        self.__init__(state["tokenizer"])
+
 
 class Transformers:
     """Represents a `transformers` model."""
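
The new `__getstate__`/`__setstate__` pair reduces the wrapper to its underlying `transformers` tokenizer, which is what lets `CloudpickleDisk` serialize it into cache keys. A round-trip sketch, assuming the `gpt2` tokenizer is available locally:

import cloudpickle
from transformers import AutoTokenizer
from outlines.models.transformers import TransformerTokenizer

tokenizer = TransformerTokenizer(AutoTokenizer.from_pretrained("gpt2"))
restored = cloudpickle.loads(cloudpickle.dumps(tokenizer))

# __setstate__ re-runs __init__ on the wrapped tokenizer, so the restored
# wrapper should be fully functional and hash identically (hashing goes
# through datasets' Hasher, which fingerprints content rather than identity).
assert hash(restored) == hash(tokenizer)
assert restored.eos_token_id == tokenizer.eos_token_id
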
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -35,7 +35,8 @@ dependencies = [
     "referencing",
     "jsonschema",
     "requests",
-    "tqdm"
+    "tqdm",
+    "datasets",
 ]
 dynamic = ["version"]
 
@@ -50,7 +51,6 @@ test = [
     "diff-cover",
     "accelerate",
     "beartype<0.16.0",
-    "datasets",
     "responses",
     "llama-cpp-python",
     "huggingface_hub",
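
`datasets` moves from the test extras into the runtime dependencies because `outlines/models/transformers.py` now imports `datasets.fingerprint.Hasher` at module level. A small sketch of what `Hasher` provides, assuming the `gpt2` tokenizer is available locally:

from datasets.fingerprint import Hasher
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Hasher.hash fingerprints the object's content, so the value is reproducible
# across processes, unlike Python's default id-based hash().
fingerprint = Hasher.hash(tokenizer)
assert isinstance(fingerprint, str)
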
68 changes: 67 additions & 1 deletion tests/generate/test_integration_transformers.py
@@ -1,6 +1,7 @@
 import datetime
 import re
 from enum import Enum
+from importlib import reload
 from typing import List, Union
 
 import pytest
@@ -11,7 +12,28 @@
 import outlines.models as models
 from outlines.fsm.regex import reduced_vocabulary
 from outlines.models.transformers import Transformers, TransformerTokenizer
-from outlines.samplers import beam_search, multinomial
+from outlines.samplers import beam_search, greedy, multinomial
+
+
+@pytest.fixture
+def temp_cache_dir():
+    import os
+    import tempfile
+
+    import outlines.caching
+    import outlines.fsm.guide
+
+    with tempfile.TemporaryDirectory() as tempdir:
+        os.environ["OUTLINES_CACHE_DIR"] = tempdir
+        outlines.caching.get_cache.cache_clear()
+        reload(outlines)
+        reload(outlines.fsm.guide)
+        cache_status = outlines.caching._caching_enabled
+        try:
+            outlines.caching._caching_enabled = True
+            yield
+        finally:
+            outlines.caching._caching_enabled = cache_status
 
 
 def test_transformers_integration_text():
@@ -632,3 +654,47 @@ def test_transformers_use_existing_model_and_tokenizer():
     model = Transformers(hf_model, hf_tokenizer)
     sequence = generate.text(model)("Write a short sentence ", rng=rng)
     assert isinstance(sequence, str)
+
+
+def test_RegexGuide_caching(temp_cache_dir):
+    import outlines.caching
+    from outlines.fsm.guide import create_states_mapping
+
+    assert outlines.caching._caching_enabled
+
+    regex = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
+    prompt = "What is the IP address of the Google DNS servers? "
+
+    cache = outlines.caching.get_cache()
+
+    # Returns (hits, misses)
+    _ = cache.stats(enable=True)
+    assert cache.statistics
+
+    assert create_states_mapping.__memory__ is cache
+
+    model = models.transformers(
+        "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM"
+    )
+    generator = generate.regex(model, regex, sampler=greedy())
+    assert cache.stats() == (0, 1)
+
+    model_2 = models.transformers("hf-internal-testing/tiny-random-GPTJForCausalLM")
+    generator_2 = generate.regex(model_2, regex, sampler=greedy())
+    assert cache.stats() == (0, 2)
+
+    # These two different models and tokenizers should not have the same state
+    # mapping results
+    assert generator.fsm.states_to_token_maps != generator_2.fsm.states_to_token_maps
+
+    generator_3 = generate.regex(model_2, regex, sampler=greedy())
+    assert cache.stats() == (1, 2)
+    assert generator_2.fsm.states_to_token_maps == generator_3.fsm.states_to_token_maps
+
+    # Just for fun...
+    structured = generator(prompt, max_tokens=30)
+    structured_2 = generator_2(prompt, max_tokens=30)
+
+    assert re.fullmatch(regex, structured)
+    assert re.fullmatch(regex, structured_2)
+    assert structured != structured_2
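
The `(hits, misses)` assertions above rely on `diskcache`'s statistics API: `stats(enable=True)` turns the counters on and each later `stats()` call returns cumulative totals. A standalone sketch of the pattern, independent of outlines:

from diskcache import Cache

cache = Cache()                 # anonymous cache in a temporary directory
cache.stats(enable=True)        # enable counters; returns the prior (hits, misses)
cache.set("key", "value")
cache.get("key")                # hit
cache.get("missing")            # miss (returns None)
assert cache.stats() == (1, 1)
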
14 changes: 11 additions & 3 deletions tests/models/test_transformers.py
@@ -107,6 +107,14 @@ def test_tokenizer_eq_hash():
     tokenizer_hf = AutoTokenizer.from_pretrained("gpt2")
 
     tokenizer = TransformerTokenizer(tokenizer_hf)
-    tokenizer2 = TransformerTokenizer(tokenizer_hf)
-    assert tokenizer == tokenizer2
-    assert hash(tokenizer) == hash(tokenizer2)
+    tokenizer_2 = TransformerTokenizer(tokenizer_hf)
+
+    assert tokenizer == tokenizer_2
+    assert hash(tokenizer) == hash(tokenizer_2)
+
+    tokenizer_hf_2 = AutoTokenizer.from_pretrained("gpt2")
+    tokenizer_hf_2.add_tokens(["test_token"])
+
+    tokenizer_3 = TransformerTokenizer(tokenizer_hf_2)
+    assert tokenizer != tokenizer_3
+    assert hash(tokenizer) != hash(tokenizer_3)