Adding re-raise flag at OpenAI Agent level #12674

Closed
@@ -54,6 +54,7 @@ def __init__(
verbose: bool = False,
max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS,
default_tool_choice: str = "auto",
raise_error: bool = False,
Collaborator

My general feedback is that this arg is pretty generic, but its actual use only covers one very specific use case, which is kind of misleading.

Is there any way we can think of to improve this? A better name? Handling more scenarios where errors are caught?

Collaborator

Also, please double-check the unit tests.

Contributor Author

Hi Logan! Thank you for the comments! :)

So regarding this feature, absolutely! I understand there are multiple places where exceptions are raised and then captured by llama-index code.
This was motivated by an Azure content-filtering exception raised while streaming; I needed to catch the exception on my side.

I noticed the functionality (i.e. the flag to re-raise the exception) was already there, although it was not implemented end-to-end.
My main objective was to introduce the parameter at the Agent level, as this was recommended in a thread on Discord.

I'm no expert on the whole llama-index codebase, but I am sure that with this small step, anyone else who needs this kind of feature will be able to adjust the code in future PRs to surface exceptions in other scenarios!

Regarding the naming of the variable at the Agent level, wdyt of capture_exception or rethrow_exception, to be explicit that this concerns an exception rather than an error?

Finally, I checked the unit tests; they don't pass on the main branch either (at least on my setup). I'll take a look.

Please let me know your thoughts on this! ✨
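
A minimal usage sketch of the flag this PR proposes (import paths assume the llama-index v0.10 package layout; the `multiply` tool and the prompt are placeholders, not part of the PR):

```python
from llama_index.agent.openai import OpenAIAgent
from llama_index.core.tools import FunctionTool


def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b


agent = OpenAIAgent.from_tools(
    [FunctionTool.from_defaults(fn=multiply)],
    raise_error=True,  # the flag added in this PR
)

try:
    streaming_response = agent.stream_chat("What is 6 times 7?")
    for token in streaming_response.response_gen:
        print(token, end="")
except Exception as exc:
    # With raise_error=True, errors hit during streaming (e.g. an Azure
    # content-filter rejection) reach the caller instead of being
    # swallowed by the background writer thread.
    print(f"streaming failed: {exc}")
```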

callback_manager: Optional[CallbackManager] = None,
tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
tool_call_parser: Optional[Callable[[OpenAIToolCall], Dict]] = None,
@@ -66,6 +67,7 @@ def __init__(
llm=llm,
verbose=verbose,
max_function_calls=max_function_calls,
raise_error=raise_error,
callback_manager=callback_manager,
prefix_messages=prefix_messages,
tool_call_parser=tool_call_parser,
@@ -90,6 +92,7 @@ def from_tools(
verbose: bool = False,
max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS,
default_tool_choice: str = "auto",
raise_error: bool = False,
callback_manager: Optional[CallbackManager] = None,
system_prompt: Optional[str] = None,
prefix_messages: Optional[List[ChatMessage]] = None,
@@ -137,6 +140,7 @@ def from_tools(
prefix_messages=prefix_messages,
verbose=verbose,
max_function_calls=max_function_calls,
raise_error=raise_error,
callback_manager=callback_manager,
default_tool_choice=default_tool_choice,
tool_call_parser=tool_call_parser,
@@ -260,13 +260,15 @@ def __init__(
prefix_messages: List[ChatMessage],
verbose: bool = False,
max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS,
raise_error: bool = False,
callback_manager: Optional[CallbackManager] = None,
tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
tool_call_parser: Optional[Callable[[OpenAIToolCall], Dict]] = None,
):
self._llm = llm
self._verbose = verbose
self._max_function_calls = max_function_calls
self._raise_error = raise_error
self.prefix_messages = prefix_messages
self.callback_manager = callback_manager or self._llm.callback_manager
self.tool_call_parser = tool_call_parser or default_tool_call_parser
@@ -290,6 +292,7 @@ def from_tools(
llm: Optional[LLM] = None,
verbose: bool = False,
max_function_calls: int = DEFAULT_MAX_FUNCTION_CALLS,
raise_error: bool = False,
callback_manager: Optional[CallbackManager] = None,
system_prompt: Optional[str] = None,
prefix_messages: Optional[List[ChatMessage]] = None,
@@ -333,6 +336,7 @@ def from_tools(
prefix_messages=prefix_messages,
verbose=verbose,
max_function_calls=max_function_calls,
raise_error=raise_error,
callback_manager=callback_manager,
tool_call_parser=tool_call_parser,
)
@@ -384,7 +388,10 @@ def _get_stream_ai_response(
# Get the response in a separate thread so we can yield the response
thread = Thread(
target=chat_stream_response.write_response_to_history,
args=(task.extra_state["new_memory"],),
args=(
task.extra_state["new_memory"],
self._raise_error,
),
kwargs={"on_stream_end_fn": partial(self.finalize_task, task)},
)
thread.start()
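
For context, the new positional argument feeds the pre-existing `raise_error` parameter of `write_response_to_history` — the flag the author notes was already present but not wired end-to-end. A rough sketch of the consumer side, with the parameter order inferred from the positional call above rather than confirmed against the library source:

```python
from typing import Callable, Optional

# Hypothetical shape of the streaming response's writer method; the real
# implementation lives in llama-index's chat engine types and may differ.
def write_response_to_history(
    self,
    memory,  # the task's "new_memory" buffer from the diff above
    raise_error: bool = False,
    on_stream_end_fn: Optional[Callable] = None,
) -> None:
    try:
        for chat in self.chat_stream:
            ...  # push deltas onto the queue that response_gen consumes
        if on_stream_end_fn is not None:
            on_stream_end_fn()
    except Exception:
        if raise_error:
            raise  # surface the error (e.g. Azure content filtering)
        # otherwise the stream simply ends and the exception is swallowed
```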