Skip to content

Commit

Permalink
Ignore Some Messages When Transforming (#2661)
Browse files Browse the repository at this point in the history
* works

* spelling

* returned old docstring

* add cache fix

* spelling?

---------

Co-authored-by: Eric Zhu <[email protected]>
  • Loading branch information
WaelKarkoub and ekzhu committed May 22, 2024
1 parent 3e11b07 commit 4ebfb82
Show file tree
Hide file tree
Showing 4 changed files with 252 additions and 92 deletions.
42 changes: 39 additions & 3 deletions autogen/agentchat/contrib/capabilities/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from autogen import token_count_utils
from autogen.cache import AbstractCache, Cache
from autogen.oai.openai_utils import filter_config

from .text_compressors import LLMLingua, TextCompressor

Expand Down Expand Up @@ -130,6 +131,8 @@ def __init__(
max_tokens: Optional[int] = None,
min_tokens: Optional[int] = None,
model: str = "gpt-3.5-turbo-0613",
filter_dict: Optional[Dict] = None,
exclude_filter: bool = True,
):
"""
Args:
Expand All @@ -140,11 +143,17 @@ def __init__(
min_tokens (Optional[int]): Minimum number of tokens in messages to apply the transformation.
Must be greater than or equal to 0 if not None.
model (str): The target OpenAI model for tokenization alignment.
filter_dict (None or dict): A dictionary to filter out messages that you want/don't want to compress.
If None, no filters will be applied.
exclude_filter (bool): If exclude filter is True (the default value), messages that match the filter will be
excluded from token truncation. If False, messages that match the filter will be truncated.
"""
self._model = model
self._max_tokens_per_message = self._validate_max_tokens(max_tokens_per_message)
self._max_tokens = self._validate_max_tokens(max_tokens)
self._min_tokens = self._validate_min_tokens(min_tokens, max_tokens)
self._filter_dict = filter_dict
self._exclude_filter = exclude_filter

def apply_transform(self, messages: List[Dict]) -> List[Dict]:
"""Applies token truncation to the conversation history.
Expand All @@ -169,10 +178,15 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:

for msg in reversed(temp_messages):
# Some messages may not have content.
if not isinstance(msg.get("content"), (str, list)):
if not _is_content_right_type(msg.get("content")):
processed_messages.insert(0, msg)
continue

if not _should_transform_message(msg, self._filter_dict, self._exclude_filter):
processed_messages.insert(0, msg)
processed_messages_tokens += _count_tokens(msg["content"])
continue

expected_tokens_remained = self._max_tokens - processed_messages_tokens - self._max_tokens_per_message

# If adding this message would exceed the token limit, truncate the last message to meet the total token
Expand Down Expand Up @@ -282,6 +296,8 @@ def __init__(
min_tokens: Optional[int] = None,
compression_params: Dict = dict(),
cache: Optional[AbstractCache] = Cache.disk(),
filter_dict: Optional[Dict] = None,
exclude_filter: bool = True,
):
"""
Args:
Expand All @@ -293,6 +309,10 @@ def __init__(
dictionary.
cache (None or AbstractCache): The cache client to use to store and retrieve previously compressed messages.
If None, no caching will be used.
filter_dict (None or dict): A dictionary to filter out messages that you want/don't want to compress.
If None, no filters will be applied.
exclude_filter (bool): If exclude filter is True (the default value), messages that match the filter will be
excluded from compression. If False, messages that match the filter will be compressed.
"""

if text_compressor is None:
Expand All @@ -303,6 +323,8 @@ def __init__(
self._text_compressor = text_compressor
self._min_tokens = min_tokens
self._compression_args = compression_params
self._filter_dict = filter_dict
self._exclude_filter = exclude_filter
self._cache = cache

# Optimizing savings calculations to optimize log generation
Expand Down Expand Up @@ -334,7 +356,10 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]:
processed_messages = messages.copy()
for message in processed_messages:
# Some messages may not have content.
if not isinstance(message.get("content"), (str, list)):
if not _is_content_right_type(message.get("content")):
continue

if not _should_transform_message(message, self._filter_dict, self._exclude_filter):
continue

if _is_content_text_empty(message["content"]):
Expand Down Expand Up @@ -397,7 +422,7 @@ def _cache_set(
self, content: Union[str, List[Dict]], compressed_content: Union[str, List[Dict]], tokens_saved: int
):
if self._cache:
value = (tokens_saved, json.dumps(compressed_content))
value = (tokens_saved, compressed_content)
self._cache.set(self._cache_key(content), value)

def _cache_key(self, content: Union[str, List[Dict]]) -> str:
Expand Down Expand Up @@ -427,10 +452,21 @@ def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
return token_count


def _is_content_right_type(content: Any) -> bool:
return isinstance(content, (str, list))


def _is_content_text_empty(content: Union[str, List[Dict[str, Any]]]) -> bool:
if isinstance(content, str):
return content == ""
elif isinstance(content, list):
return all(_is_content_text_empty(item.get("text", "")) for item in content)
else:
return False


def _should_transform_message(message: Dict[str, Any], filter_dict: Optional[Dict[str, Any]], exclude: bool) -> bool:
if not filter_dict:
return True

return len(filter_config([message], filter_dict, exclude)) > 0
105 changes: 54 additions & 51 deletions autogen/oai/openai_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,11 +379,10 @@ def config_list_gpt4_gpt35(
def filter_config(
config_list: List[Dict[str, Any]],
filter_dict: Optional[Dict[str, Union[List[Union[str, None]], Set[Union[str, None]]]]],
exclude: bool = False,
) -> List[Dict[str, Any]]:
"""
This function filters `config_list` by checking each configuration dictionary against the
criteria specified in `filter_dict`. A configuration dictionary is retained if for every
key in `filter_dict`, see example below.
"""This function filters `config_list` by checking each configuration dictionary against the criteria specified in
`filter_dict`. A configuration dictionary is retained if for every key in `filter_dict`, see example below.
Args:
config_list (list of dict): A list of configuration dictionaries to be filtered.
Expand All @@ -394,71 +393,68 @@ def filter_config(
when it is found in the list of acceptable values. If the configuration's
field's value is a list, then a match occurs if there is a non-empty
intersection with the acceptable values.
exclude (bool): If False (the default value), configs that match the filter will be included in the returned
list. If True, configs that match the filter will be excluded in the returned list.
Returns:
list of dict: A list of configuration dictionaries that meet all the criteria specified
in `filter_dict`.
Example:
```python
# Example configuration list with various models and API types
configs = [
{'model': 'gpt-3.5-turbo'},
{'model': 'gpt-4'},
{'model': 'gpt-3.5-turbo', 'api_type': 'azure'},
{'model': 'gpt-3.5-turbo', 'tags': ['gpt35_turbo', 'gpt-35-turbo']},
]
# Define filter criteria to select configurations for the 'gpt-3.5-turbo' model
# that are also using the 'azure' API type
filter_criteria = {
'model': ['gpt-3.5-turbo'], # Only accept configurations for 'gpt-3.5-turbo'
'api_type': ['azure'] # Only accept configurations for 'azure' API type
}
# Apply the filter to the configuration list
filtered_configs = filter_config(configs, filter_criteria)
# The resulting `filtered_configs` will be:
# [{'model': 'gpt-3.5-turbo', 'api_type': 'azure', ...}]
# Define a filter to select a given tag
filter_criteria = {
'tags': ['gpt35_turbo'],
}
# Apply the filter to the configuration list
filtered_configs = filter_config(configs, filter_criteria)
# The resulting `filtered_configs` will be:
# [{'model': 'gpt-3.5-turbo', 'tags': ['gpt35_turbo', 'gpt-35-turbo']}]
```
```python
# Example configuration list with various models and API types
configs = [
{'model': 'gpt-3.5-turbo'},
{'model': 'gpt-4'},
{'model': 'gpt-3.5-turbo', 'api_type': 'azure'},
{'model': 'gpt-3.5-turbo', 'tags': ['gpt35_turbo', 'gpt-35-turbo']},
]
# Define filter criteria to select configurations for the 'gpt-3.5-turbo' model
# that are also using the 'azure' API type
filter_criteria = {
'model': ['gpt-3.5-turbo'], # Only accept configurations for 'gpt-3.5-turbo'
'api_type': ['azure'] # Only accept configurations for 'azure' API type
}
# Apply the filter to the configuration list
filtered_configs = filter_config(configs, filter_criteria)
# The resulting `filtered_configs` will be:
# [{'model': 'gpt-3.5-turbo', 'api_type': 'azure', ...}]
# Define a filter to select a given tag
filter_criteria = {
'tags': ['gpt35_turbo'],
}
# Apply the filter to the configuration list
filtered_configs = filter_config(configs, filter_criteria)
# The resulting `filtered_configs` will be:
# [{'model': 'gpt-3.5-turbo', 'tags': ['gpt35_turbo', 'gpt-35-turbo']}]
```
Note:
- If `filter_dict` is empty or None, no filtering is applied and `config_list` is returned as is.
- If a configuration dictionary in `config_list` does not contain a key specified in `filter_dict`,
it is considered a non-match and is excluded from the result.
- If the list of acceptable values for a key in `filter_dict` includes None, then configuration
dictionaries that do not have that key will also be considered a match.
"""
def _satisfies(config_value: Any, acceptable_values: Any) -> bool:
if isinstance(config_value, list):
return bool(set(config_value) & set(acceptable_values)) # Non-empty intersection
else:
return config_value in acceptable_values
"""

if filter_dict:
config_list = [
config
for config in config_list
if all(_satisfies(config.get(key), value) for key, value in filter_dict.items())
return [
item
for item in config_list
if all(_satisfies_criteria(item.get(key), values) != exclude for key, values in filter_dict.items())
]
return config_list


def _satisfies_criteria(value: Any, criteria_values: Any) -> bool:
if value is None:
return False

if isinstance(value, list):
return bool(set(value) & set(criteria_values)) # Non-empty intersection
else:
return value in criteria_values


def config_list_from_json(
env_or_file: str,
file_location: Optional[str] = "",
Expand Down Expand Up @@ -785,3 +781,10 @@ def update_gpt_assistant(client: OpenAI, assistant_id: str, assistant_config: Di
assistant_update_kwargs["file_ids"] = assistant_config["file_ids"]

return client.beta.assistants.update(assistant_id=assistant_id, **assistant_update_kwargs)


def _satisfies(config_value: Any, acceptable_values: Any) -> bool:
if isinstance(config_value, list):
return bool(set(config_value) & set(acceptable_values)) # Non-empty intersection
else:
return config_value in acceptable_values

0 comments on commit 4ebfb82

Please sign in to comment.