Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add role to reflection with llm #2527

Merged
merged 16 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
21 changes: 18 additions & 3 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,6 +937,7 @@ def my_summary_method(
One example key is "summary_prompt", and value is a string of text used to prompt a LLM-based agent (the sender or receiver agent) to reflect
on the conversation and extract a summary when summary_method is "reflection_with_llm".
The default summary_prompt is DEFAULT_SUMMARY_PROMPT, i.e., "Summarize takeaway from the conversation. Do not add any introductory phrases. If the intended request is NOT properly addressed, please point it out."
Another available key is "summary_role", which is the role of the message sent to the agent in charge of summarizing. Default is "system".
message (str, dict or Callable): the initial message to be sent to the recipient. Needs to be provided. Otherwise, input() will be called to get the initial message.
- If a string or a dict is provided, it will be used as the initial message. `generate_init_message` is called to generate the initial message for the agent based on this string and the context.
If dict, it may contain the following reserved fields (either content or tool_calls need to be provided).
Expand Down Expand Up @@ -1168,8 +1169,13 @@ def _reflection_with_llm_as_summary(sender, recipient, summary_args):
raise ValueError("The summary_prompt must be a string.")
msg_list = recipient.chat_messages_for_summary(sender)
agent = sender if recipient is None else recipient
role = summary_args.get("summary_role", None)
if role and not isinstance(role, str):
raise ValueError("The summary_role in summary_arg must be a string.")
try:
summary = sender._reflection_with_llm(prompt, msg_list, llm_agent=agent, cache=summary_args.get("cache"))
summary = sender._reflection_with_llm(
prompt, msg_list, llm_agent=agent, cache=summary_args.get("cache"), role=role
)
except BadRequestError as e:
warnings.warn(
f"Cannot extract summary using reflection_with_llm: {e}. Using an empty str as summary.", UserWarning
Expand All @@ -1178,7 +1184,12 @@ def _reflection_with_llm_as_summary(sender, recipient, summary_args):
return summary

def _reflection_with_llm(
self, prompt, messages, llm_agent: Optional[Agent] = None, cache: Optional[AbstractCache] = None
self,
prompt,
messages,
llm_agent: Optional[Agent] = None,
cache: Optional[AbstractCache] = None,
role: Union[str, None] = None,
) -> str:
"""Get a chat summary using reflection with an llm client based on the conversation history.

Expand All @@ -1187,10 +1198,14 @@ def _reflection_with_llm(
messages (list): The messages generated as part of a chat conversation.
llm_agent: the agent with an llm client.
cache (AbstractCache or None): the cache client to be used for this conversation.
role (str): the role of the message, usually "system" or "user". Default is "system".
"""
if not role:
role = "system"

system_msg = [
{
"role": "system",
"role": role,
"content": prompt,
}
]
Expand Down
59 changes: 52 additions & 7 deletions test/agentchat/test_groupchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import json
import logging
from typing import Any, Dict, List, Optional
from unittest import mock
from unittest import TestCase, mock

import pytest
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST

import autogen
from autogen import Agent, AssistantAgent, GroupChat, GroupChatManager
Expand Down Expand Up @@ -1446,6 +1447,46 @@ def test_speaker_selection_agent_name_match():
assert result == {}


def test_role_for_reflection_summary():
llm_config = {"config_list": [{"model": "mock", "api_key": "mock"}]}
agent1 = autogen.ConversableAgent(
"alice",
max_consecutive_auto_reply=10,
human_input_mode="NEVER",
llm_config=False,
default_auto_reply="This is alice speaking.",
)
agent2 = autogen.ConversableAgent(
"bob",
max_consecutive_auto_reply=10,
human_input_mode="NEVER",
llm_config=False,
default_auto_reply="This is bob speaking.",
)
groupchat = autogen.GroupChat(
agents=[agent1, agent2], messages=[], max_round=3, speaker_selection_method="round_robin"
)
group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=llm_config)

role_name = "user"
with mock.patch.object(
autogen.ConversableAgent, "_generate_oai_reply_from_client"
) as mock_generate_oai_reply_from_client:
mock_generate_oai_reply_from_client.return_value = "Mocked summary"

agent1.initiate_chat(
group_chat_manager,
max_turns=2,
message="hello",
summary_method="reflection_with_llm",
summary_args={"summary_role": role_name},
)

mock_generate_oai_reply_from_client.assert_called_once()
args, kwargs = mock_generate_oai_reply_from_client.call_args
assert kwargs["messages"][-1]["role"] == role_name


def test_speaker_selection_auto_process_result():
"""
Tests the return result of the 2-agent chat used for speaker selection for the auto method.
Expand Down Expand Up @@ -1984,12 +2025,16 @@ def test_manager_resume_messages():
# test_role_for_select_speaker_messages()
# test_select_speaker_message_and_prompt_templates()
# test_speaker_selection_agent_name_match()
# test_role_for_reflection_summary()
# test_speaker_selection_auto_process_result()
# test_speaker_selection_validate_speaker_name()
# test_select_speaker_auto_messages()
# test_speaker_selection_auto_process_result()
# test_speaker_selection_validate_speaker_name()
# test_select_speaker_auto_messages()
test_manager_messages_to_string()
test_manager_messages_from_string()
test_manager_resume_functions()
test_manager_resume_returns()
test_manager_resume_messages()
# pass
# test_manager_messages_to_string()
# test_manager_messages_from_string()
# test_manager_resume_functions()
# test_manager_resume_returns()
# test_manager_resume_messages()
pass