
Commit

Improve multiple-choice selection for the OpenAI API

The current approach is greedy, in the sense that it generates a single
token at each step, asking the API to only generate valid next tokens.
This means having to pay for the prompt tokens for every token generated.

This commit takes a more optimistic approach. It starts by allowing
all tokens present in the sequences and limiting the length of the
generation to the number of tokens in the longest sequence. If the
completion is not satisfactory, it then takes one greedy step before
switching back to the optimistic mode.
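
For illustration, a minimal standalone sketch of the optimistic step (the
choice strings and model name are made up; the bias value of 100 and the
length cap mirror the implementation in the diff below):

    from itertools import zip_longest

    import tiktoken

    choices = ["Positive", "Negative"]  # made-up choices
    tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")  # made-up model name
    encoded = [tokenizer.encode(choice) for choice in choices]

    # Allow, at every position, any token that appears at that position in any
    # choice, and cap the generation length at the longest tokenized choice.
    allowed = {
        token
        for column in zip_longest(*encoded)
        for token in column
        if token is not None
    }
    logit_bias = {token: 100 for token in allowed}
    max_tokens = max(len(tokens) for tokens in encoded)

A single call is then made with logit_bias and max_tokens; only if the
completion matches none of the choices does the code fall back to one
greedy, single-token step.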

On average, this new approach consumes fewer tokens than the current one.
HerrIvan authored and rlouf committed Nov 19, 2023
1 parent f1f5c07 commit 76cfc61
Showing 1 changed file with 97 additions and 19 deletions.
116 changes: 97 additions & 19 deletions outlines/models/openai.py
@@ -1,7 +1,9 @@
"""Integration with OpenAI's API."""
import functools
import os
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
from collections import deque
from itertools import zip_longest
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple, Union

import numpy as np

@@ -85,6 +87,39 @@ async def generate_base(

    return results

def longest_common_prefix(tokens1: List[int], tokens2: List[int]) -> List[int]:
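    """Return the longest prefix shared by two token sequences.

    For example (made-up ids): longest_common_prefix([1, 2, 3], [1, 2, 4]) == [1, 2].
    """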
    i = 0
    while i < len(tokens1) and i < len(tokens2) and tokens1[i] == tokens2[i]:
        i += 1
    return tokens1[:i]

def get_choices_with_longest_common_prefix(
    response: List[int], is_in: List[List[int]]
) -> Tuple[List[int], List[List[int]]]:
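    """Return the longest prefix of ``response`` shared with any of the choices,
    together with the remainders of the choices that attain that prefix.
    """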
    max_len_prefix = 0
    is_in_left = []
    prefix = []
    for i in range(len(is_in)):
        len_prefix = len(longest_common_prefix(response, is_in[i]))

        if len_prefix > max_len_prefix:
            max_len_prefix = len_prefix
            is_in_left = [is_in[i][len_prefix:]]
            prefix = is_in[i][:len_prefix]

        elif len_prefix == max_len_prefix:
            is_in_left.append(is_in[i][len_prefix:])

    return prefix, is_in_left

def build_optimistic_mask(transposed: deque[Set]) -> Dict:
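    """Build a logit-bias mask from as many leading token positions as fit
    within 300 entries (the cap this module assumes for the API's logit_bias).
    """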
    # build the biggest mask possible, adding tokens left to right
    to_mask: Set[int] = set()
    while len(transposed) > 0 and len(to_mask | transposed[0]) <= 300:
        to_mask = to_mask | transposed.popleft()

    return {token: 100 for token in to_mask}
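
# Illustration with made-up token ids: for encoded choices [[464, 3290], [464, 2415, 318]],
# the column-wise sets built in generate_choice below are deque([{464}, {2415, 3290}, {318}]),
# so build_optimistic_mask returns {464: 100, 3290: 100, 2415: 100, 318: 100}: every token
# that can appear at any position of any remaining choice is allowed in a single call.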

@functools.partial(outlines.vectorize, signature="(),(m),()->(s)")
async def generate_choice(
    prompt: str,
@@ -95,12 +130,11 @@ async def generate_choice(
    .. warning::
        This function will call the API once for every token generated.
        Worst case, this function may call the API as many times as tokens are in the response.
    We tokenize every choice, iterate over the token lists, create a mask
    with the current tokens and generate one token. We progressively
    eliminate the choices that don't start with the currently decoded
    sequence.
    With the optimistic approach, we activate all tokens that could form any of the answers. If the completion
    does not match any of the answers, we then call the API again with only the tokens that can be accepted as
    the next token. On average, this approach needs fewer API calls to return a solution.
    """
    try:
@@ -111,36 +145,80 @@ async def generate_choice(
        )

    tokenizer = tiktoken.encoding_for_model(model_name)
    encoded: List[List[int]] = [tokenizer.encode(word) for word in is_in]

    decoded_samples = []
    for _ in range(samples):
        is_in_left = is_in.copy()
        decoded: List[str] = []
        for i in range(max([len(word) for word in encoded])):
            mask = {}
            for word, tokenized_word in zip(is_in, encoded):
                if not word.startswith("".join(decoded)):
                    continue
                try:
                    mask[tokenized_word[i]] = 100
                except IndexError:
                    pass

        greedy = False  # we try to generate the full response at each iteration

        while len(is_in_left) > 0:
            encoded: List[List[int]] = [
                tokenizer.encode(word) for word in is_in_left
            ]

            max_tokens_left = max([len(tokens) for tokens in encoded])
            transposed: deque[Set] = deque(
                [
                    {item for item in subset if item is not None}
                    for subset in zip_longest(*encoded)
                ]
            )
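            # transposed[i] is the set of tokens that can appear at position i
            # of any remaining choice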

            if not greedy:
                mask = build_optimistic_mask(transposed)
            else:
                mask = {}
                for token in transposed.popleft():  # build greedy mask
                    mask[token] = 100

            if len(mask) == 0:
                break

            response = await call_api(
                model_name,
                format_prompt(prompt),
                1,
                max_tokens_left if not greedy else 1,
                temperature,
                [],
                mask,
                1,
            )
            decoded.append(extract_choice(response["choices"][0]))

            prompt = prompt + "".join(decoded)
            current_resp = extract_choice(response["choices"][0])

            if current_resp in is_in_left:
                decoded.append(current_resp)
                break
            else:
                # map response to tokens
                tokenized_resp = tokenizer.encode(current_resp)
                (
                    tokenized_resp,
                    encoded,
                ) = get_choices_with_longest_common_prefix(
                    tokenized_resp, encoded
                )

                if len(tokenized_resp) == 0:
                    greedy = True  # next iteration will be "greedy"
                    continue
                else:
                    decoded.append("".join(tokenizer.decode(tokenized_resp)))

                    # map back to words
                    is_in_left = [
                        "".join(tokenizer.decode(tokens)) for tokens in encoded
                    ]

                    if len(is_in_left) == 1:  # only one choice left
                        decoded.append(is_in_left[0])
                        break

                greedy = False  # after each success, stay with (or switch to) "optimistic" approach

        prompt = prompt + "".join(decoded)

        decoded_samples.append("".join(decoded))

