Token counting and litellm provider customization #1421
Changes from 4 commits
@@ -1,4 +1,5 @@
from litellm import completion as litellm_completion
import litellm
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential
from litellm.exceptions import APIConnectionError, RateLimitError, ServiceUnavailableError
from functools import partial
@@ -16,6 +17,9 @@
LLM_NUM_RETRIES = config.get(ConfigType.LLM_NUM_RETRIES)
LLM_RETRY_MIN_WAIT = config.get(ConfigType.LLM_RETRY_MIN_WAIT)
LLM_RETRY_MAX_WAIT = config.get(ConfigType.LLM_RETRY_MAX_WAIT)
LLM_MAX_INPUT_TOKENS = config.get(ConfigType.LLM_MAX_INPUT_TOKENS)
LLM_MAX_OUTPUT_TOKENS = config.get(ConfigType.LLM_MAX_OUTPUT_TOKENS)
LLM_CUSTOM_LLM_PROVIDER = config.get(ConfigType.LLM_CUSTOM_LLM_PROVIDER)


class LLM:
@@ -31,6 +35,9 @@ def __init__(self,
                 num_retries=LLM_NUM_RETRIES,
                 retry_min_wait=LLM_RETRY_MIN_WAIT,
                 retry_max_wait=LLM_RETRY_MAX_WAIT,
                 max_input_tokens=LLM_MAX_INPUT_TOKENS,
                 max_output_tokens=LLM_MAX_OUTPUT_TOKENS,
                 custom_llm_provider=LLM_CUSTOM_LLM_PROVIDER
                 ):
        """
        Args:
@@ -41,6 +48,9 @@ def __init__(self,
            num_retries (int, optional): The number of retries for API calls. Defaults to LLM_NUM_RETRIES.
            retry_min_wait (int, optional): The minimum time to wait between retries in seconds. Defaults to LLM_RETRY_MIN_WAIT.
            retry_max_wait (int, optional): The maximum time to wait between retries in seconds. Defaults to LLM_RETRY_MAX_WAIT.
            max_input_tokens (int, optional): The maximum number of input tokens to send to the LLM per task. Defaults to LLM_MAX_INPUT_TOKENS.
            max_output_tokens (int, optional): The maximum number of output tokens to receive from the LLM per task. Defaults to LLM_MAX_OUTPUT_TOKENS.
            custom_llm_provider (str, optional): A custom LLM provider name passed through to litellm. Defaults to LLM_CUSTOM_LLM_PROVIDER.

        Attributes:
            model_name (str): The name of the language model.
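For illustration, here is how the new arguments might be passed explicitly when constructing the wrapper, overriding the config defaults. This is a hypothetical sketch: the import path, the model keyword, and all values are assumptions based on the surrounding code, not part of this diff.

# Hypothetical usage sketch; import path, model name and values are illustrative only.
from opendevin.llm.llm import LLM

llm = LLM(
    model='gpt-3.5-turbo',
    max_input_tokens=8192,         # overrides LLM_MAX_INPUT_TOKENS from config
    max_output_tokens=2048,        # overrides LLM_MAX_OUTPUT_TOKENS from config
    custom_llm_provider='openai',  # overrides LLM_CUSTOM_LLM_PROVIDER from config
)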
@@ -54,9 +64,32 @@ def __init__(self,
        self.api_key = api_key
        self.base_url = base_url
        self.api_version = api_version
        self.max_input_tokens = max_input_tokens
        self.max_output_tokens = max_output_tokens
        self.custom_llm_provider = custom_llm_provider

        # litellm actually uses base Exception here for unknown model
        self.model_info = None
        try:
            self.model_info = litellm.get_model_info(self.model_name)
        # noinspection PyBroadException
        except Exception:
            logger.warning(f'Could not get model info for {self.model_name}')

        if self.max_input_tokens is None:
            if self.model_info is not None and 'max_input_tokens' in self.model_info:
                self.max_input_tokens = self.model_info['max_input_tokens']
            else:
                self.max_input_tokens = 4096

        if self.max_output_tokens is None:
            if self.model_info is not None and 'max_output_tokens' in self.model_info:
                self.max_output_tokens = self.model_info['max_output_tokens']
            else:
                self.max_output_tokens = 1024
Review comment on the 1024 default: Just curious: where does this number come from? I guess 4096 is because it's the limit of GPT 3.5, but how about this one?

Reply: I don't have a significant justification for either of these defaults, and I am interested to hear opinions on them. I regularly experienced overruns with a 512 output token limit, and therefore I usually use 1024 or higher locally.

Reply: I don't have a strong opinion either. I just feel like it would be better to have some comments explaining where these numbers are from.

Reply: I have added comments documenting this:

        # Max input tokens for gpt3.5, so this is a safe fallback for any potentially viable model
        self.max_input_tokens = 4096
        # Enough tokens for most output actions, and not too many for a bad llm to get carried away responding
        # with thousands of unwanted tokens
        self.max_output_tokens = 1024
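For context on the fallback logic above, litellm.get_model_info returns a metadata dict for models litellm knows about and raises for unknown ones, which is why the constructor wraps the lookup in try/except. A minimal inspection sketch (the model name is illustrative):

import litellm

# Minimal sketch: look up the limits the constructor falls back on.
try:
    info = litellm.get_model_info('gpt-3.5-turbo')
    print(info.get('max_input_tokens'), info.get('max_output_tokens'))
except Exception:
    print('model not known to litellm; the defaults of 4096/1024 would apply')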
        self._completion = partial(
            litellm_completion, model=self.model_name, api_key=self.api_key, base_url=self.base_url, api_version=self.api_version)
            litellm_completion, model=self.model_name, api_key=self.api_key, base_url=self.base_url, api_version=self.api_version, max_tokens=max_output_tokens, custom_llm_provider=custom_llm_provider)

        completion_unwrapped = self._completion
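The partial above pre-binds the new arguments onto every completion call. Roughly what an eventual self._completion(messages=...) expands to, as a stand-alone sketch; the model, key, base_url and provider values are illustrative assumptions, not taken from this diff.

from litellm import completion as litellm_completion

# Hypothetical stand-alone equivalent of the pre-bound call; all values illustrative.
response = litellm_completion(
    model='mistral-7b-instruct',
    messages=[{'role': 'user', 'content': 'Hello!'}],
    api_key='not-a-real-key',
    base_url='http://localhost:8000/v1',
    max_tokens=1024,               # max_output_tokens from the wrapper
    custom_llm_provider='openai',  # route the request through litellm's OpenAI-compatible path
)
print(response['choices'][0]['message']['content'])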
@@ -89,6 +122,18 @@ def completion(self):
        """
        return self._completion

    def get_token_count(self, messages):
        """
        Get the number of tokens in a list of messages.

        Args:
            messages (list): A list of messages.

        Returns:
            int: The number of tokens.
        """
        return litellm.token_counter(model=self.model_name, messages=messages)

    def __str__(self):
        if self.api_version:
            return f'LLM(model={self.model_name}, api_version={self.api_version}, base_url={self.base_url})'
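The new get_token_count method delegates to litellm's tokenizer-aware counter. A quick usage sketch of the underlying call (model and messages are illustrative):

import litellm

# litellm picks an appropriate tokenizer for known models and falls back to a default otherwise.
messages = [
    {'role': 'system', 'content': 'You are a helpful assistant.'},
    {'role': 'user', 'content': 'Summarize the last build log.'},
]
print(litellm.token_counter(model='gpt-3.5-turbo', messages=messages))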
Review comment: Is this roughly a similar number? 40 chars per token? That seems like a lot to me.

Reply: Ahh nvm--I see how it's being used differently.