refactor
nerdai committed May 1, 2024
1 parent 23cb6ec commit 542f9d1
Showing 5 changed files with 66 additions and 109 deletions.
@@ -305,7 +305,7 @@ def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
         )
 
         agent_response = AgentChatResponse(
-            response=str(response), sources=task.extra_state["sources"]
+            response=response.message.content, sources=task.extra_state["sources"]
         )
 
         return TaskStepOutput(
139 changes: 35 additions & 104 deletions llama-index-core/llama_index/core/agent/introspective/step.py
@@ -3,7 +3,6 @@
 import logging
 import uuid
 from typing import Any, List, Optional, cast
-import asyncio
 
 from llama_index.core.agent.types import (
     BaseAgentWorker,
@@ -52,9 +51,8 @@ def __init__(
         self,
         tools: List[BaseTool],
         llm: FunctionCallingLLM,
-        main_agent_worker: BaseAgentWorker,
         reflective_agent_worker: BaseAgentWorker,
-        prefix_messages: List[ChatMessage],
+        main_agent_worker: Optional[BaseAgentWorker] = None,
         verbose: bool = False,
         max_function_calls: int = 5,
         callback_manager: Optional[CallbackManager] = None,
@@ -71,7 +69,6 @@ def __init__(
         self._max_function_calls = max_function_calls
         self._main_agent_worker = main_agent_worker
         self._reflective_agent_worker = reflective_agent_worker
-        self.prefix_messages = prefix_messages
         self.callback_manager = callback_manager or self._llm.callback_manager
         self.allow_parallel_tool_calls = allow_parallel_tool_calls
 
@@ -89,16 +86,15 @@ def __init__(
     @classmethod
     def from_args(
         cls,
-        main_agent_worker: BaseAgentWorker,
         reflective_agent_worker: BaseAgentWorker,
+        main_agent_worker: Optional[BaseAgentWorker] = None,
         tools: Optional[List[BaseTool]] = None,
         tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
         llm: Optional[FunctionCallingLLM] = None,
         verbose: bool = False,
         max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS,
         callback_manager: Optional[CallbackManager] = None,
         system_prompt: Optional[str] = None,
-        prefix_messages: Optional[List[ChatMessage]] = None,
         **kwargs: Any,
     ) -> "IntrospectiveAgentWorker":
         """Create an IntrospectiveAgentWorker from a list of tools.
@@ -121,15 +117,12 @@ def from_args(
                 )
             prefix_messages = [ChatMessage(content=system_prompt, role="system")]
 
-        prefix_messages = prefix_messages or []
-
         return cls(
             tools=tools,
             tool_retriever=tool_retriever,
             main_agent_worker=main_agent_worker,
             reflective_agent_worker=reflective_agent_worker,
             llm=llm,
-            prefix_messages=prefix_messages,
             verbose=verbose,
             max_function_calls=max_function_calls,
             callback_manager=callback_manager,
@@ -142,6 +135,11 @@ def initialize_step(self, task: Task, **kwargs: Any) -> TaskStep:
         main_memory = ChatMemoryBuffer.from_defaults()
         reflective_memory = ChatMemoryBuffer.from_defaults()
 
+        # put current history in new memory
+        messages = task.memory.get()
+        for message in messages:
+            main_memory.put(message)
+
         # initialize task state
         task_state = {
             "main": {
@@ -174,20 +172,36 @@ def get_all_messages(self, task: Task) -> List[ChatMessage]:
     def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
         """Run step."""
         # run main agent
-        main_agent = self._main_agent_worker.as_agent()
-        main_agent_response = main_agent.chat(
-            task.input
-        )  # or should i use step.input here?
-        task.extra_state["main"]["sources"] = main_agent_response.sources
-        task.extra_state["main"]["memory"] = main_agent.memory
-        print(f"MAIN AGENT MEMORY: {main_agent.memory}", flush=True)
+        if self._main_agent_worker is not None:
+            main_agent_messages = task.extra_state["main"]["memory"].get()
+            main_agent = self._main_agent_worker.as_agent(
+                chat_history=main_agent_messages
+            )
+            main_agent_response = main_agent.chat(
+                task.input, parent_task_id=task.task_id
+            )
+            original_response = main_agent_response.response
+            task.extra_state["main"]["sources"] = main_agent_response.sources
+            task.extra_state["main"]["memory"] = main_agent.memory
+        else:
+            add_user_step_to_memory(
+                step, task.extra_state["main"]["memory"], verbose=self._verbose
+            )
+            original_response = step.input
+            task.extra_state["main"]["memory"].put(
+                ChatMessage(content=original_response, role="assistant")
+            )
 
         # run reflective agent
-        reflective_agent = self._reflective_agent_worker.as_agent()
-        reflective_agent_response = reflective_agent.chat(main_agent_response.response)
+        reflective_agent_messages = task.extra_state["main"]["memory"].get()
+        reflective_agent = self._reflective_agent_worker.as_agent(
+            chat_history=reflective_agent_messages
+        )
+        reflective_agent_response = reflective_agent.chat(
+            original_response, parent_task_id=task.task_id
+        )
         task.extra_state["reflection"]["sources"] = reflective_agent_response.sources
         task.extra_state["reflection"]["memory"] = reflective_agent.memory
-        print(f"REFLECTIVE AGENT MEMORY: {reflective_agent.memory}", flush=True)
 
         agent_response = AgentChatResponse(
             response=str(reflective_agent_response.response),
@@ -207,86 +221,7 @@ async def arun_step(
         self, step: TaskStep, task: Task, **kwargs: Any
     ) -> TaskStepOutput:
         """Run step (async)."""
-        if step.input is not None:
-            add_user_step_to_memory(
-                step, task.extra_state["new_memory"], verbose=self._verbose
-            )
-        # TODO: see if we want to do step-based inputs
-        tools = self.get_tools(task.input)
-
-        # get response and tool call (if exists)
-        response = await self._llm.achat_with_tools(
-            tools=tools,
-            user_msg=None,
-            chat_history=self.get_all_messages(task),
-            verbose=self._verbose,
-            allow_parallel_tool_calls=self.allow_parallel_tool_calls,
-        )
-        tool_calls = self._llm.get_tool_calls_from_response(
-            response, error_on_no_tool_call=False
-        )
-
-        if self._verbose and response.message.content:
-            print("=== LLM Response ===")
-            print(str(response.message.content))
-
-        if not self.allow_parallel_tool_calls and len(tool_calls) > 1:
-            raise ValueError(
-                "Parallel tool calls not supported for synchronous function calling agent"
-            )
-
-        # call all tools, gather responses
-        task.extra_state["new_memory"].put(response.message)
-        if (
-            len(tool_calls) == 0
-            or task.extra_state["n_function_calls"] >= self._max_function_calls
-        ):
-            # we are done
-            is_done = True
-            new_steps = []
-        else:
-            is_done = False
-            tasks = [
-                self._acall_function(
-                    tools,
-                    tool_call,
-                    task.extra_state["new_memory"],
-                    task.extra_state["sources"],
-                    verbose=self._verbose,
-                )
-                for tool_call in tool_calls
-            ]
-            return_directs = await asyncio.gather(*tasks)
-
-            # check if any of the tools return directly -- only works if there is one tool call
-            if len(return_directs) == 1 and return_directs[0]:
-                is_done = True
-                response = task.extra_state["sources"][-1].content
-
-            task.extra_state["n_function_calls"] += len(tool_calls)
-        # put tool output in sources and memory
-        new_steps = (
-            [
-                step.get_next_step(
-                    step_id=str(uuid.uuid4()),
-                    # NOTE: input is unused
-                    input=None,
-                )
-            ]
-            if not is_done
-            else []
-        )
-
-        agent_response = AgentChatResponse(
-            response=str(response), sources=task.extra_state["sources"]
-        )
-
-        return TaskStepOutput(
-            output=agent_response,
-            task_step=step,
-            is_last=is_done,
-            next_steps=new_steps,
-        )
+        ...
 
     @trace_method("run_step")
     def stream_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
@@ -303,10 +238,6 @@ async def astream_step(
     def finalize_task(self, task: Task, **kwargs: Any) -> None:
        """Finalize task, after all the steps are completed."""
        # add new messages to memory
-        task.memory.set(
-            task.memory.get_all()
-            + task.extra_state["main"]["memory"].get_all()
-            + task.extra_state["reflection"]["memory"].get_all()
-        )
+        task.memory.set(task.extra_state["reflection"]["memory"].get_all())
        # reset new memory
        task.extra_state["main"]["memory"].reset()
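Taken together, the changes to this file make main_agent_worker optional, seed each sub-agent with the accumulated chat history instead of printing its memory, and pass parent_task_id through the nested chat calls. A minimal usage sketch under those assumptions — reflective_worker stands in for any BaseAgentWorker constructed elsewhere and is not part of this diff:

    from llama_index.core.agent.introspective.step import IntrospectiveAgentWorker

    # reflective_worker: any BaseAgentWorker that critiques and corrects a
    # response; assumed to be built elsewhere (not defined in this commit).
    introspective_worker = IntrospectiveAgentWorker.from_args(
        reflective_agent_worker=reflective_worker,
        main_agent_worker=None,  # optional after this commit
        verbose=True,
    )
    agent = introspective_worker.as_agent()

    # With main_agent_worker=None, run_step records the user message as the
    # "main" response and hands it directly to the reflective agent.
    response = agent.chat("Here is my draft answer: ...")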
23 changes: 19 additions & 4 deletions llama-index-core/llama_index/core/agent/runner/base.py
@@ -171,6 +171,9 @@ class AgentState(BaseModel):
         default_factory=dict, description="Task dictionary."
     )
 
+    class Config:
+        arbitrary_types_allowed = True
+
     def get_task(self, task_id: str) -> Task:
         """Get task state."""
         return self.task_dict[task_id].task
@@ -294,7 +297,9 @@ def reset(self) -> None:
         self.memory.reset()
         self.state.reset()
 
-    def create_task(self, input: str, **kwargs: Any) -> Task:
+    def create_task(
+        self, input: str, parent_task_id: Optional[str] = None, **kwargs: Any
+    ) -> Task:
         """Create task."""
         if not self.init_task_state_kwargs:
             extra_state = kwargs.pop("extra_state", {})
@@ -309,6 +314,7 @@ def create_task(self, input: str, **kwargs: Any) -> Task:
         callback_manager = kwargs.pop("callback_manager", self.callback_manager)
         task = Task(
             input=input,
+            parent_task_id=parent_task_id,
             memory=self.memory,
             extra_state=extra_state,
             callback_manager=callback_manager,
@@ -532,13 +538,14 @@ def _chat(
         chat_history: Optional[List[ChatMessage]] = None,
         tool_choice: Union[str, dict] = "auto",
         mode: ChatResponseMode = ChatResponseMode.WAIT,
+        parent_task_id: Optional[str] = None,
     ) -> AGENT_CHAT_RESPONSE_TYPE:
         """Chat with step executor."""
         dispatch_event = dispatcher.get_dispatch_event()
 
         if chat_history is not None:
             self.memory.set(chat_history)
-        task = self.create_task(message)
+        task = self.create_task(message, parent_task_id)
 
         result_output = None
         dispatch_event(AgentChatWithStepStartEvent())
@@ -569,20 +576,24 @@ async def _achat(
         chat_history: Optional[List[ChatMessage]] = None,
         tool_choice: Union[str, dict] = "auto",
         mode: ChatResponseMode = ChatResponseMode.WAIT,
+        parent_task_id: Optional[str] = None,
     ) -> AGENT_CHAT_RESPONSE_TYPE:
         """Chat with step executor."""
         dispatch_event = dispatcher.get_dispatch_event()
 
         if chat_history is not None:
             self.memory.set(chat_history)
-        task = self.create_task(message)
+        task = self.create_task(message, parent_task_id)
 
         result_output = None
         dispatch_event(AgentChatWithStepStartEvent())
         while True:
             # pass step queue in as argument, assume step executor is stateless
             cur_step_output = await self._arun_step(
-                task.task_id, mode=mode, tool_choice=tool_choice
+                task.task_id,
+                mode=mode,
+                tool_choice=tool_choice,
+                parent_task_id=parent_task_id,
             )
 
             if cur_step_output.is_last:
@@ -606,6 +617,7 @@ def chat(
         message: str,
         chat_history: Optional[List[ChatMessage]] = None,
         tool_choice: Optional[Union[str, dict]] = None,
+        parent_task_id: Optional[str] = None,
     ) -> AgentChatResponse:
         # override tool choice is provided as input.
         if tool_choice is None:
@@ -619,6 +631,7 @@ def chat(
                 chat_history=chat_history,
                 tool_choice=tool_choice,
                 mode=ChatResponseMode.WAIT,
+                parent_task_id=parent_task_id,
             )
             assert isinstance(chat_response, AgentChatResponse)
             e.on_end(payload={EventPayload.RESPONSE: chat_response})
@@ -631,6 +644,7 @@ async def achat(
         message: str,
         chat_history: Optional[List[ChatMessage]] = None,
         tool_choice: Optional[Union[str, dict]] = None,
+        parent_task_id: Optional[str] = None,
     ) -> AgentChatResponse:
         # override tool choice is provided as input.
         if tool_choice is None:
@@ -644,6 +658,7 @@ async def achat(
                 chat_history=chat_history,
                 tool_choice=tool_choice,
                 mode=ChatResponseMode.WAIT,
+                parent_task_id=parent_task_id,
             )
             assert isinstance(chat_response, AgentChatResponse)
             e.on_end(payload={EventPayload.RESPONSE: chat_response})
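The parent_task_id plumbing added in this file lets a nested agent's task point back at the task that spawned it, which is exactly how the introspective worker's run_step links its sub-agents above. A short sketch — outer_task and sub_agent are assumed to exist already and are not defined in this diff:

    # Link a sub-agent's chat to the task that spawned it; the keyword flows
    # through chat() -> _chat() -> create_task() -> Task(parent_task_id=...).
    sub_response = sub_agent.chat(
        "Please refine this draft ...",
        parent_task_id=outer_task.task_id,
    )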
1 change: 1 addition & 0 deletions llama-index-core/llama_index/core/agent/types.py
@@ -161,6 +161,7 @@ class Config:
         default_factory=lambda: str(uuid.uuid4()), type=str, description="Task ID"
     )
     input: str = Field(..., type=str, description="User input")
+    parent_task_id: Optional[str] = Field(default=None, description="Parent Task ID")
 
     # NOTE: this is state that may be modified throughout the course of execution of the task
     memory: BaseMemory = Field(
10 changes: 10 additions & 0 deletions llama-index-core/llama_index/core/agent/utils.py
@@ -13,3 +13,13 @@ def add_user_step_to_memory(
     memory.put(user_message)
     if verbose:
         print(f"Added user message to memory: {step.input}")
+
+
+def add_assistant_step_to_memory(
+    step: TaskStep, memory: BaseMemory, verbose: bool = False
+) -> None:
+    """Add assistant step to memory."""
+    asst_message = ChatMessage(content=step.input, role=MessageRole.ASSISTANT)
+    memory.put(asst_message)
+    if verbose:
+        print(f"Added assistant message to memory: {step.input}")
