diff --git a/llama-index-core/llama_index/core/agent/function_calling/step.py b/llama-index-core/llama_index/core/agent/function_calling/step.py
index 5e7c16805a7fd0..f2d4ca16e906aa 100644
--- a/llama-index-core/llama_index/core/agent/function_calling/step.py
+++ b/llama-index-core/llama_index/core/agent/function_calling/step.py
@@ -305,7 +305,7 @@ def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
             )

         agent_response = AgentChatResponse(
-            response=str(response), sources=task.extra_state["sources"]
+            response=response.message.content, sources=task.extra_state["sources"]
         )

         return TaskStepOutput(
diff --git a/llama-index-core/llama_index/core/agent/introspective/step.py b/llama-index-core/llama_index/core/agent/introspective/step.py
index d0bd5d1511df89..a4885a39535a57 100644
--- a/llama-index-core/llama_index/core/agent/introspective/step.py
+++ b/llama-index-core/llama_index/core/agent/introspective/step.py
@@ -3,7 +3,6 @@
 import logging
 import uuid
 from typing import Any, List, Optional, cast
-import asyncio

 from llama_index.core.agent.types import (
     BaseAgentWorker,
@@ -52,9 +51,8 @@ def __init__(
         self,
         tools: List[BaseTool],
         llm: FunctionCallingLLM,
-        main_agent_worker: BaseAgentWorker,
         reflective_agent_worker: BaseAgentWorker,
-        prefix_messages: List[ChatMessage],
+        main_agent_worker: Optional[BaseAgentWorker] = None,
         verbose: bool = False,
         max_function_calls: int = 5,
         callback_manager: Optional[CallbackManager] = None,
@@ -71,7 +69,6 @@ def __init__(
         self._max_function_calls = max_function_calls
         self._main_agent_worker = main_agent_worker
         self._reflective_agent_worker = reflective_agent_worker
-        self.prefix_messages = prefix_messages
         self.callback_manager = callback_manager or self._llm.callback_manager
         self.allow_parallel_tool_calls = allow_parallel_tool_calls

@@ -89,8 +86,8 @@ def __init__(
     @classmethod
     def from_args(
         cls,
-        main_agent_worker: BaseAgentWorker,
         reflective_agent_worker: BaseAgentWorker,
+        main_agent_worker: Optional[BaseAgentWorker] = None,
         tools: Optional[List[BaseTool]] = None,
         tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
         llm: Optional[FunctionCallingLLM] = None,
@@ -98,7 +95,6 @@ def from_args(
         max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS,
         callback_manager: Optional[CallbackManager] = None,
         system_prompt: Optional[str] = None,
-        prefix_messages: Optional[List[ChatMessage]] = None,
         **kwargs: Any,
     ) -> "IntrospectiveAgentWorker":
         """Create an IntrospectiveAgentWorker from a list of tools.
@@ -121,15 +117,12 @@ def from_args(
                 )
             prefix_messages = [ChatMessage(content=system_prompt, role="system")]

-        prefix_messages = prefix_messages or []
-
         return cls(
             tools=tools,
             tool_retriever=tool_retriever,
             main_agent_worker=main_agent_worker,
             reflective_agent_worker=reflective_agent_worker,
             llm=llm,
-            prefix_messages=prefix_messages,
             verbose=verbose,
             max_function_calls=max_function_calls,
             callback_manager=callback_manager,
@@ -142,6 +135,11 @@ def initialize_step(self, task: Task, **kwargs: Any) -> TaskStep:
         main_memory = ChatMemoryBuffer.from_defaults()
         reflective_memory = ChatMemoryBuffer.from_defaults()

+        # put current history in new memory
+        messages = task.memory.get()
+        for message in messages:
+            main_memory.put(message)
+
         # initialize task state
         task_state = {
             "main": {
@@ -174,20 +172,36 @@ def get_all_messages(self, task: Task) -> List[ChatMessage]:
     def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
         """Run step."""
         # run main agent
-        main_agent = self._main_agent_worker.as_agent()
-        main_agent_response = main_agent.chat(
-            task.input
-        )  # or should i use step.input here?
-        task.extra_state["main"]["sources"] = main_agent_response.sources
-        task.extra_state["main"]["memory"] = main_agent.memory
-        print(f"MAIN AGENT MEMORY: {main_agent.memory}", flush=True)
+        if self._main_agent_worker is not None:
+            main_agent_messages = task.extra_state["main"]["memory"].get()
+            main_agent = self._main_agent_worker.as_agent(
+                chat_history=main_agent_messages
+            )
+            main_agent_response = main_agent.chat(
+                task.input, parent_task_id=task.task_id
+            )
+            original_response = main_agent_response.response
+            task.extra_state["main"]["sources"] = main_agent_response.sources
+            task.extra_state["main"]["memory"] = main_agent.memory
+        else:
+            add_user_step_to_memory(
+                step, task.extra_state["main"]["memory"], verbose=self._verbose
+            )
+            original_response = step.input
+            task.extra_state["main"]["memory"].put(
+                ChatMessage(content=original_response, role="assistant")
+            )

         # run reflective agent
-        reflective_agent = self._reflective_agent_worker.as_agent()
-        reflective_agent_response = reflective_agent.chat(main_agent_response.response)
+        reflective_agent_messages = task.extra_state["main"]["memory"].get()
+        reflective_agent = self._reflective_agent_worker.as_agent(
+            chat_history=reflective_agent_messages
+        )
+        reflective_agent_response = reflective_agent.chat(
+            original_response, parent_task_id=task.task_id
+        )
         task.extra_state["reflection"]["sources"] = reflective_agent_response.sources
         task.extra_state["reflection"]["memory"] = reflective_agent.memory
-        print(f"REFLECTIVE AGENT MEMORY: {reflective_agent.memory}", flush=True)

         agent_response = AgentChatResponse(
             response=str(reflective_agent_response.response),
@@ -207,86 +221,7 @@ async def arun_step(
         self, step: TaskStep, task: Task, **kwargs: Any
     ) -> TaskStepOutput:
         """Run step (async)."""
-        if step.input is not None:
-            add_user_step_to_memory(
-                step, task.extra_state["new_memory"], verbose=self._verbose
-            )
-        # TODO: see if we want to do step-based inputs
-        tools = self.get_tools(task.input)
-
-        # get response and tool call (if exists)
-        response = await self._llm.achat_with_tools(
-            tools=tools,
-            user_msg=None,
-            chat_history=self.get_all_messages(task),
-            verbose=self._verbose,
-            allow_parallel_tool_calls=self.allow_parallel_tool_calls,
-        )
-        tool_calls = self._llm.get_tool_calls_from_response(
-            response, error_on_no_tool_call=False
-        )
-
-        if self._verbose and response.message.content:
-            print("=== LLM Response ===")
-            print(str(response.message.content))
-
-        if not self.allow_parallel_tool_calls and len(tool_calls) > 1:
-            raise ValueError(
-                "Parallel tool calls not supported for synchronous function calling agent"
-            )
-
-        # call all tools, gather responses
-        task.extra_state["new_memory"].put(response.message)
-        if (
-            len(tool_calls) == 0
-            or task.extra_state["n_function_calls"] >= self._max_function_calls
-        ):
-            # we are done
-            is_done = True
-            new_steps = []
-        else:
-            is_done = False
-            tasks = [
-                self._acall_function(
-                    tools,
-                    tool_call,
-                    task.extra_state["new_memory"],
-                    task.extra_state["sources"],
-                    verbose=self._verbose,
-                )
-                for tool_call in tool_calls
-            ]
-            return_directs = await asyncio.gather(*tasks)
-
-            # check if any of the tools return directly -- only works if there is one tool call
-            if len(return_directs) == 1 and return_directs[0]:
-                is_done = True
-                response = task.extra_state["sources"][-1].content
-
-        task.extra_state["n_function_calls"] += len(tool_calls)
-        # put tool output in sources and memory
-        new_steps = (
-            [
-                step.get_next_step(
-                    step_id=str(uuid.uuid4()),
-                    # NOTE: input is unused
-                    input=None,
-                )
-            ]
-            if not is_done
-            else []
-        )
-
-        agent_response = AgentChatResponse(
-            response=str(response), sources=task.extra_state["sources"]
-        )
-
-        return TaskStepOutput(
-            output=agent_response,
-            task_step=step,
-            is_last=is_done,
-            next_steps=new_steps,
-        )
+        ...

     @trace_method("run_step")
     def stream_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput:
@@ -303,10 +238,6 @@ async def astream_step(
     def finalize_task(self, task: Task, **kwargs: Any) -> None:
         """Finalize task, after all the steps are completed."""
         # add new messages to memory
-        task.memory.set(
-            task.memory.get_all()
-            + task.extra_state["main"]["memory"].get_all()
-            + task.extra_state["reflection"]["memory"].get_all()
-        )
+        task.memory.set(task.extra_state["reflection"]["memory"].get_all())
         # reset new memory
         task.extra_state["main"]["memory"].reset()
diff --git a/llama-index-core/llama_index/core/agent/runner/base.py b/llama-index-core/llama_index/core/agent/runner/base.py
index 81a21b144ea474..1bdb2c3c638c5c 100644
--- a/llama-index-core/llama_index/core/agent/runner/base.py
+++ b/llama-index-core/llama_index/core/agent/runner/base.py
@@ -171,6 +171,9 @@ class AgentState(BaseModel):
         default_factory=dict, description="Task dictionary."
     )

+    class Config:
+        arbitrary_types_allowed = True
+
     def get_task(self, task_id: str) -> Task:
         """Get task state."""
         return self.task_dict[task_id].task
@@ -294,7 +297,9 @@ def reset(self) -> None:
         self.memory.reset()
         self.state.reset()

-    def create_task(self, input: str, **kwargs: Any) -> Task:
+    def create_task(
+        self, input: str, parent_task_id: Optional[str] = None, **kwargs: Any
+    ) -> Task:
         """Create task."""
         if not self.init_task_state_kwargs:
             extra_state = kwargs.pop("extra_state", {})
@@ -309,6 +314,7 @@ def create_task(self, input: str, **kwargs: Any) -> Task:
         callback_manager = kwargs.pop("callback_manager", self.callback_manager)
         task = Task(
             input=input,
+            parent_task_id=parent_task_id,
             memory=self.memory,
             extra_state=extra_state,
             callback_manager=callback_manager,
@@ -532,13 +538,14 @@ def _chat(
         chat_history: Optional[List[ChatMessage]] = None,
         tool_choice: Union[str, dict] = "auto",
         mode: ChatResponseMode = ChatResponseMode.WAIT,
+        parent_task_id: Optional[str] = None,
     ) -> AGENT_CHAT_RESPONSE_TYPE:
         """Chat with step executor."""
         dispatch_event = dispatcher.get_dispatch_event()

         if chat_history is not None:
             self.memory.set(chat_history)
-        task = self.create_task(message)
+        task = self.create_task(message, parent_task_id)

         result_output = None
         dispatch_event(AgentChatWithStepStartEvent())
@@ -569,20 +576,24 @@ async def _achat(
         chat_history: Optional[List[ChatMessage]] = None,
         tool_choice: Union[str, dict] = "auto",
         mode: ChatResponseMode = ChatResponseMode.WAIT,
+        parent_task_id: Optional[str] = None,
     ) -> AGENT_CHAT_RESPONSE_TYPE:
         """Chat with step executor."""
         dispatch_event = dispatcher.get_dispatch_event()

         if chat_history is not None:
             self.memory.set(chat_history)
-        task = self.create_task(message)
+        task = self.create_task(message, parent_task_id)

         result_output = None
         dispatch_event(AgentChatWithStepStartEvent())
         while True:
             # pass step queue in as argument, assume step executor is stateless
             cur_step_output = await self._arun_step(
-                task.task_id, mode=mode, tool_choice=tool_choice
+                task.task_id,
+                mode=mode,
+                tool_choice=tool_choice,
+                parent_task_id=parent_task_id,
             )

             if cur_step_output.is_last:
@@ -606,6 +617,7 @@ def chat(
         message: str,
         chat_history: Optional[List[ChatMessage]] = None,
         tool_choice: Optional[Union[str, dict]] = None,
+        parent_task_id: Optional[str] = None,
     ) -> AgentChatResponse:
         # override tool choice is provided as input.
         if tool_choice is None:
@@ -619,6 +631,7 @@ def chat(
                 chat_history=chat_history,
                 tool_choice=tool_choice,
                 mode=ChatResponseMode.WAIT,
+                parent_task_id=parent_task_id,
             )
             assert isinstance(chat_response, AgentChatResponse)
             e.on_end(payload={EventPayload.RESPONSE: chat_response})
@@ -631,6 +644,7 @@ async def achat(
         message: str,
         chat_history: Optional[List[ChatMessage]] = None,
         tool_choice: Optional[Union[str, dict]] = None,
+        parent_task_id: Optional[str] = None,
     ) -> AgentChatResponse:
         # override tool choice is provided as input.
         if tool_choice is None:
@@ -644,6 +658,7 @@ async def achat(
                 chat_history=chat_history,
                 tool_choice=tool_choice,
                 mode=ChatResponseMode.WAIT,
+                parent_task_id=parent_task_id,
             )
             assert isinstance(chat_response, AgentChatResponse)
             e.on_end(payload={EventPayload.RESPONSE: chat_response})
diff --git a/llama-index-core/llama_index/core/agent/types.py b/llama-index-core/llama_index/core/agent/types.py
index 4b271629021457..4d68c46528c603 100644
--- a/llama-index-core/llama_index/core/agent/types.py
+++ b/llama-index-core/llama_index/core/agent/types.py
@@ -161,6 +161,7 @@ class Config:
         default_factory=lambda: str(uuid.uuid4()), type=str, description="Task ID"
     )
     input: str = Field(..., type=str, description="User input")
+    parent_task_id: Optional[str] = Field(default=None, description="Parent Task ID")

     # NOTE: this is state that may be modified throughout the course of execution of the task
     memory: BaseMemory = Field(
diff --git a/llama-index-core/llama_index/core/agent/utils.py b/llama-index-core/llama_index/core/agent/utils.py
index 0fcb2d03963a07..cd82631626c43f 100644
--- a/llama-index-core/llama_index/core/agent/utils.py
+++ b/llama-index-core/llama_index/core/agent/utils.py
@@ -13,3 +13,13 @@ def add_user_step_to_memory(
     memory.put(user_message)
     if verbose:
         print(f"Added user message to memory: {step.input}")
+
+
+def add_assistant_step_to_memory(
+    step: TaskStep, memory: BaseMemory, verbose: bool = False
+) -> None:
+    """Add assistant step to memory."""
+    asst_message = ChatMessage(content=step.input, role=MessageRole.ASSISTANT)
+    memory.put(asst_message)
+    if verbose:
+        print(f"Added assistant message to memory: {step.input}")
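
Usage note: with main_agent_worker now optional, the introspective worker can either delegate the first draft to a main agent or, when no main agent is supplied, reflect directly on the user input (the new else branch in run_step). A minimal sketch of construction under assumed setup; the names reflective_worker, main_worker, and llm are illustrative and built elsewhere, they are not part of this diff:

# Assumptions (hypothetical setup, not in this diff): `llm` is a
# FunctionCallingLLM and `reflective_worker` / `main_worker` are
# BaseAgentWorker instances constructed separately.
worker = IntrospectiveAgentWorker.from_args(
    reflective_agent_worker=reflective_worker,
    main_agent_worker=main_worker,  # optional; omit to reflect on the raw user input
    llm=llm,
    verbose=True,
)
agent = worker.as_agent()
response = agent.chat("Draft a reply to this customer complaint ...")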
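
Usage note: the parent_task_id plumbing lets a nested agent call (such as the sub-agent chats in run_step above) record which task spawned it; AgentRunner.chat/achat forward it through _chat/_achat into create_task, which stores it on the new Task.parent_task_id field. A rough sketch, where `agent` is an assumed AgentRunner-based agent (e.g. built via some worker's as_agent()):

# Assumption: `agent` is any AgentRunner-based agent instance.
parent_task = agent.create_task("Plan the trip")
child_response = agent.chat(
    "Book the first leg",
    parent_task_id=parent_task.task_id,  # recorded on Task.parent_task_id
)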