Skip to content

Commit

Permalink
allow lmql.F with positionals
Browse files Browse the repository at this point in the history
  • Loading branch information
lbeurerkellner committed Jul 25, 2023
1 parent 5d09885 commit 16c8a1a
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 9 deletions.
10 changes: 7 additions & 3 deletions src/lmql/algorithms/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import inspect

from lmql.runtime.lmql_runtime import LMQLQueryFunction
from lmql import LMQLResult
from lmql import LMQLResult, F

# cache query results by query code and arguments
global cache_file
Expand Down Expand Up @@ -46,9 +46,13 @@ def persist_cache():
with open(cache_file, "wb") as f:
pickle.dump(cache, f)

async def apply(q, *args):
async def apply(q, *args, **kwargs):
global cache

if type(q) is str:
where = kwargs.pop("where", None)
q = F(q, constraints=where, is_async=True)

# handle non-LMQL queries
if type(q) is not LMQLQueryFunction:
if inspect.iscoroutinefunction(q):
Expand All @@ -74,7 +78,7 @@ async def apply(q, *args):
return cache[key]
else:
try:
result = await q(*args)
result = await q(*args, **kwargs)
if len(result) == 1:
result = result[0]
if type(result) is LMQLResult:
Expand Down
2 changes: 1 addition & 1 deletion src/lmql/algorithms/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ async def map(q, items, chunksize=None, progress=False, **kwargs):
chunks = tqdm.tqdm(chunks, file=sys.stdout)

for chunk in chunks:
results = await asyncio.gather(*[apply(q, x) for x in chunk])
results = await asyncio.gather(*[apply(q, x, **kwargs) for x in chunk])
total_results += results

return total_results
Expand Down
35 changes: 30 additions & 5 deletions src/lmql/runtime/lmql_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,23 @@ def output_keys(self) -> List[str]:
def force_model(self, model):
self.model = model

def try_bind_positional_to_kwargs(self, signature, *args, **query_kwargs):
    """
    Best-effort attempt to bind positional arguments to keyword arguments in order of self.args.

    Only enabled for lmql.F for now; may have unexpected effects depending on the order of
    query arguments as determined by the compiler.

    Args:
        signature: the inspect.Signature of the query function. Binding is only attempted
            when it declares no parameters at all.
        *args: positional values, paired one-to-one with the names in self.args.
        **query_kwargs: additional keyword arguments, merged into the bound arguments.

    Returns:
        An inspect.BoundArguments whose signature has one POSITIONAL_OR_KEYWORD parameter
        per name in self.args, or None if positional binding is not applicable.
    """
    # only bind if no explicit signature is provided and every query argument
    # receives exactly one positional value
    if len(signature.parameters) != 0 or len(self.args) != len(args):
        return None

    # a name supplied both positionally and by keyword is ambiguous (plain Python rejects
    # it as "multiple values for argument"); refuse to bind instead of silently
    # discarding the positional value
    if any(name in query_kwargs for name in self.args):
        return None

    bound = dict(zip(self.args, args))
    bound.update(query_kwargs)

    parameters = [
        inspect.Parameter(name=name, kind=inspect.Parameter.POSITIONAL_OR_KEYWORD)
        for name in self.args
    ]
    return inspect.BoundArguments(
        signature=inspect.Signature(parameters=parameters),
        arguments=bound,
    )

def make_kwargs(self, *args, **kwargs):
"""
Binds args and kwargs to the function signature and returns a dict of all user-defined kwargs.
Expand All @@ -122,11 +139,19 @@ def make_kwargs(self, *args, **kwargs):
try:
signature = signature.bind(*args, **query_kwargs)
except TypeError as e:
if len(e.args) == 1 and e.args[0].startswith("missing "):
e.args = (f"Call to @lmql.query function is " + e.args[0] + "." + f" Expecting {signature}, but got positional args {args} and {kwargs}.",)
elif len(e.args) == 1:
e.args = (e.args[0] + "." + f" Expecting {signature}, but got positional args {args} and {kwargs}.",)
raise e
if "too many positional arguments" in str(e):
# this is different from Python behavior, but we allow it for lmql.F and lmql.run
pos_as_kw = self.try_bind_positional_to_kwargs(signature, *args, **kwargs)
if pos_as_kw is not None:
signature = pos_as_kw
else:
raise e
else:
if len(e.args) == 1 and e.args[0].startswith("missing "):
e.args = (f"Call to @lmql.query function is " + e.args[0] + "." + f" Expecting {signature}, but got positional args {args} and {kwargs}.",)
elif len(e.args) == 1:
e.args = (e.args[0] + "." + f" Expecting {signature}, but got positional args {args} and {kwargs}.",)
raise e

# special case, if signature is empty (no input variables provided)
if len(signature.arguments) == 0:
Expand Down
53 changes: 53 additions & 0 deletions src/lmql/tests/test_f.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import lmql
import lmql.algorithms as la

from lmql.tests.expr_test_utils import run_all_tests

def test_f_functions():
    """Calling an lmql.F query with a keyword argument produces a str."""
    summarize = lmql.F(
        "Summarize this text: {text}: [SUMMARY]",
        "len(TOKENS(SUMMARY)) < 10",
        model=lmql.model("random", seed=123),
    )
    output = summarize(text="This is a test.")
    assert type(output) is str

def test_f_positional_args():
    """Calling an lmql.F query with a single positional argument produces a str."""
    summarize = lmql.F(
        "Summarize this text: {text}: [SUMMARY]",
        "len(TOKENS(SUMMARY)) < 10",
        model=lmql.model("random", seed=123),
    )
    output = summarize("This is a test.")
    assert type(output) is str

def test_f_multiple_positional_args():
    """
    Two positional arguments must be bound in template order, so that the resulting
    prompt (and therefore the output) matches the single-keyword-argument variant.
    """
    a = lmql.F("Summarize this text: {text1}{text2}: [SUMMARY]", "len(TOKENS(SUMMARY)) < 10", model=lmql.model("random", seed=123))
    b = lmql.F("Summarize this text: {text}: [SUMMARY]", "len(TOKENS(SUMMARY)) < 10", model=lmql.model("random", seed=123))

    # compare the actual outputs (not just their types): "This is " + "a test." yields the
    # same prompt as the keyword call below; relies on the fixed seed being deterministic
    r1 = a("This is ", "a test.")
    r2 = b(text="This is a test.")

    assert type(r1) is str, "output should be a string"
    assert r1 == r2, "using positional and keyword args should result in the same output"


async def test_async_f_functions():
    """An lmql.F query built with is_async=True can be awaited and produces a str."""
    summarize = lmql.F(
        "Summarize this text: {text}: [SUMMARY]",
        "len(TOKENS(SUMMARY)) < 10",
        model=lmql.model("random", seed=123),
        is_async=True,
    )
    output = await summarize(text="This is a test.")
    assert type(output) is str

async def test_map_async_f_functions():
    """la.map returns one result per item, for a prebuilt async query and for a raw prompt string."""
    texts = [
        "A dog walks into a bar.",
        "A cat goes for a walk.",
    ]

    # variant 1: map a prebuilt async lmql.F query over the inputs
    summarize = lmql.F(
        "Summarize this text: {text}: [SUMMARY]",
        "len(TOKENS(SUMMARY)) < 10",
        model=lmql.model("random", seed=123),
        is_async=True,
    )
    results = await la.map(summarize, texts)
    assert len(results) == 2

    # variant 2: pass the prompt string directly, with the constraint given via where=
    results = await la.map(
        "Summarize this text: {text}: [SUMMARY]",
        texts,
        where="len(TOKENS(SUMMARY)) < 10",
        model=lmql.model("random", seed=123),
    )
    assert len(results) == 2

def test_run_pos_args():
    """lmql.run_sync should give the same result for positional and keyword query arguments."""
    # NOTE(review): the literal's internal whitespace is reproduced as it appears in this
    # view; confirm against the original file before relying on it.
    source = '''lmql
"Summarize this text: {t1}{t2}: [SUMMARY]" where len(TOKENS(SUMMARY)) < 4
return SUMMARY
'''

    # same inputs, once positionally and once by name (t1, t2), each with a freshly
    # constructed "random" model using the same seed
    r1 = lmql.run_sync(source, "This is ", "a test.", model=lmql.model("random", seed=123))
    r2 = lmql.run_sync(source, t1="This is ", t2="a test.", model=lmql.model("random", seed=123))

    assert r1 == r2, "using positional and keyword args should result in the same output"

# When executed as a script, hand this module's globals to run_all_tests
# (presumably it discovers and runs the test_* functions defined above — confirm
# against lmql.tests.expr_test_utils).
if __name__ == "__main__":
    run_all_tests(globals())

0 comments on commit 16c8a1a

Please sign in to comment.