Skip to content

Commit

Permalink
Load model stored in GitHub
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Dec 13, 2023
1 parent 0c5503a commit 4f34714
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 5 deletions.
73 changes: 70 additions & 3 deletions outlines/function.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import importlib.util
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Optional, Union
from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union

import requests

from outlines import generate, models

if TYPE_CHECKING:
from outlines.generate.api import SequenceGenerator
from outlines.prompts import Prompt


@dataclass
Expand All @@ -18,11 +22,19 @@ class Function:
"""

prompt_template: Callable
model_name: str
prompt_template: "Prompt"
schema: Union[str, Callable, object]
model_name: str
generator: Optional["SequenceGenerator"] = None

@classmethod
def from_github(cls, program_path: str, function_name: str = "fn"):
"""Load a function stored on GitHub"""
program_content = download_from_github(program_path)
function = extract_function_from_file(program_content, function_name)

return function

def init_generator(self):
"""Load the model and initialize the generator."""
model = models.transformers(self.model_name)
Expand All @@ -48,3 +60,58 @@ def __call__(self, *args, **kwargs):

prompt = self.prompt_template(*args, **kwargs)
return self.generator(prompt)


def download_from_github(short_path: str):
"""Download the file in which the function is stored on GitHub."""
GITHUB_BASE_URL = "https://raw.githubusercontent.com"
BRANCH = "main"

path = short_path.split("/")
if len(path) < 3:
raise ValueError(
"Please provide a valid path in the form {USERNAME}/{REPO_NAME}/{PATH_TO_FILE}."
)
elif short_path[-3:] == ".py":
raise ValueError("Do not append the `.py` extension to the program name.")

username = path[0]
repo = path[1]
path_to_file = path[2:]

url = "/".join([GITHUB_BASE_URL, username, repo, BRANCH] + path_to_file) + ".py"
result = requests.get(url)

if result.status_code == 200:
return result.text
elif result.status_code == 404:
raise ValueError(
f"Program could not be found at {url}. Please make sure you entered the GitHub username, repository name and path to the program correctly."
)
else:
result.raise_for_status()


def extract_function_from_file(content: str, function_name: str) -> Tuple[Callable]:
"""Extract a function object from a downloaded file."""

spec = importlib.util.spec_from_loader(
"outlines_function", loader=None, origin="github"
)
if spec is not None:
module = importlib.util.module_from_spec(spec)
exec(content, module.__dict__)

try:
fn = getattr(module, function_name)
except AttributeError:
raise AttributeError(
"Could not find an `outlines.Function` instance in the remote file. Make sure that the path you specified is correct."
)

if not isinstance(fn, module.outlines.Function):
raise TypeError(
f"The `{function_name}` variable in the program must be an instance of `outlines.Function`"
)

return fn
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ dependencies = [
"joblib",
"referencing",
"jsonschema",
"requests",
]
dynamic = ["version"]

Expand All @@ -52,6 +53,7 @@ test = [
"accelerate",
"beartype<0.16.0",
"datasets",
"responses",
]

[project.urls]
Expand Down Expand Up @@ -111,6 +113,8 @@ module = [
"interegular.*",
"datasets.*",
"numba.*",
"requests.*",
"responses.*",
]
ignore_missing_imports = true

Expand Down
117 changes: 115 additions & 2 deletions tests/test_function.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import pytest
import responses
from pydantic import BaseModel
from requests.exceptions import HTTPError

import outlines
from outlines.function import Function
from outlines.function import Function, download_from_github, extract_function_from_file


def test_function_basic():
Expand All @@ -12,9 +15,119 @@ def test_template(text: str):
class Foo(BaseModel):
id: int

fn = Function(test_template, "hf-internal-testing/tiny-random-GPTJForCausalLM", Foo)
fn = Function(test_template, Foo, "hf-internal-testing/tiny-random-GPTJForCausalLM")

assert fn.generator is None

result = fn("test")
assert isinstance(result, BaseModel)


def test_download_from_github_invalid():
with pytest.raises(ValueError, match="Please provide"):
download_from_github("outlines/program")

with pytest.raises(ValueError, match="Do not append"):
download_from_github("outlines-dev/outlines/program.py")


@responses.activate
def test_download_from_github_success():
responses.add(
responses.GET,
"https://raw.githubusercontent.com/outlines-dev/outlines/main/program.py",
body="import outlines\n",
status=200,
)

file = download_from_github("outlines-dev/outlines/program")
assert file == "import outlines\n"

responses.add(
responses.GET,
"https://raw.githubusercontent.com/outlines-dev/outlines/main/foo/bar/program.py",
body="import outlines\n",
status=200,
)

file = download_from_github("outlines-dev/outlines/foo/bar/program")
assert file == "import outlines\n"


@responses.activate
def test_download_from_github_error():
responses.add(
responses.GET,
"https://raw.githubusercontent.com/foo/bar/main/program.py",
json={"error": "not found"},
status=404,
)

with pytest.raises(ValueError, match="Program could not be found at"):
download_from_github("foo/bar/program")

responses.add(
responses.GET,
"https://raw.githubusercontent.com/foo/bar/main/program.py",
json={"error": "Internal Server Error"},
status=500,
)

with pytest.raises(HTTPError, match="500 Server Error"):
download_from_github("foo/bar/program")


def test_extract_function_from_file():
content = """
import outlines
from pydantic import BaseModel
model = "gpt2"
@outlines.prompt
def prompt():
'''Hello'''
class User(BaseModel):
id: int
name: str
function = outlines.Function(
prompt,
User,
"gpt2",
)
"""

fn = extract_function_from_file(content, "function")
assert (
str(type(fn)) == "<class 'outlines.function.Function'>"
) # because imported via `exec`


def test_extract_function_from_file_no_function():
content = """
import outlines
from pydantic import BaseModel
@outlines.prompt
def prompt():
'''Hello'''
class User(BaseModel):
id: int
name: str
program = outlines.Function(
prompt,
User,
"gpt2",
)
"""

with pytest.raises(AttributeError, match="Could not find"):
extract_function_from_file(content, "function")

0 comments on commit 4f34714

Please sign in to comment.