Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FunctionTool] Use docstring_parser to infer description for FunctionTool #12864

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
21 changes: 17 additions & 4 deletions llama-index-core/llama_index/core/tools/function_tool.py
@@ -1,12 +1,12 @@
import asyncio
from inspect import signature
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Type

if TYPE_CHECKING:
from llama_index.core.bridge.langchain import StructuredTool, Tool
from llama_index.core.bridge.pydantic import BaseModel
from llama_index.core.tools.types import AsyncBaseTool, ToolMetadata, ToolOutput
from llama_index.core.tools.utils import create_schema_from_function
import docstring_parser

AsyncCallable = Callable[..., Awaitable[Any]]

Expand All @@ -22,7 +22,8 @@ async def _async_wrapped_fn(*args: Any, **kwargs: Any) -> Any:


class FunctionTool(AsyncBaseTool):
"""Function Tool.
"""
Function Tool.

A tool that takes in a function.

Expand Down Expand Up @@ -54,8 +55,20 @@ def from_defaults(
) -> "FunctionTool":
if tool_metadata is None:
name = name or fn.__name__
docstring = fn.__doc__
description = description or f"{name}{signature(fn)}\n{docstring}"
if not description:
doc = docstring_parser.parse(fn.__doc__)
doc.meta = [
i
for i in doc.meta
if not isinstance(i, docstring_parser.common.DocstringParam)
]
description = docstring_parser.compose(
doc,
docstring_parser.Style.GOOGLE,
rendering_style=docstring_parser.RenderingStyle.CLEAN,
indent=" ",
)

if fn_schema is None:
fn_schema = create_schema_from_function(
f"{name}", fn, additional_fields=None
Expand Down
3 changes: 2 additions & 1 deletion llama-index-core/llama_index/core/tools/types.py
Expand Up @@ -58,7 +58,8 @@ def get_name(self) -> str:
"Deprecated in favor of `to_openai_tool`, which should be used instead."
)
def to_openai_function(self) -> Dict[str, Any]:
"""Deprecated and replaced by `to_openai_tool`.
"""
Deprecated and replaced by `to_openai_tool`.
The name and arguments of a function that should be called, as generated by the
model.
"""
Expand Down
28 changes: 21 additions & 7 deletions llama-index-core/llama_index/core/tools/utils.py
@@ -1,7 +1,9 @@
from inspect import signature
from typing import Any, Callable, List, Optional, Tuple, Type, Union, cast
from typing import Any, Callable, List, Optional, Tuple, Type, Union, cast, Dict

from llama_index.core.bridge.pydantic import BaseModel, FieldInfo, create_model
from llama_index.core.bridge.pydantic import BaseModel, FieldInfo, create_model, Field

import docstring_parser


def create_schema_from_function(
Expand All @@ -12,23 +14,35 @@ def create_schema_from_function(
] = None,
) -> Type[BaseModel]:
"""Create schema from function."""
fields = {}
fields: Dict[str, FieldInfo] = {}
params = signature(func).parameters
doc = docstring_parser.parse(func.__doc__)

params_doc = {param.arg_name: param for param in doc.params}
for param_name in params:
param_type = params[param_name].annotation
param_default = params[param_name].default
param_desc = (
params_doc[param_name].description if param_name in params_doc else None
)

if param_type is params[param_name].empty:
param_type = Any

if param_default is params[param_name].empty:
# Required field
fields[param_name] = (param_type, FieldInfo())
field_info = Field()
elif isinstance(param_default, FieldInfo):
# Field with pydantic.Field as default value
fields[param_name] = (param_type, param_default)
# Field with pydantic.FieldInfo as default value
field_info = param_default
else:
fields[param_name] = (param_type, FieldInfo(default=param_default))
field_info = Field(default=param_default)
field_info = cast(FieldInfo, field_info)

if param_desc:
field_info.description = param_desc

fields[param_name] = (param_type, field_info)

additional_fields = additional_fields or []
for field_info in additional_fields:
Expand Down
19 changes: 17 additions & 2 deletions llama-index-core/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions llama-index-core/pyproject.toml
Expand Up @@ -85,6 +85,7 @@ pillow = ">=9.0.0"
PyYAML = ">=6.0.1"
llamaindex-py-client = "^0.1.18"
wrapt = "*"
docstring-parser = "^0.16"

[tool.poetry.extras]
gradientai = [
Expand Down