Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate groups when collect tool packages #3085

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/promptflow-core/promptflow/_core/tools_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
get_prompt_param_name_from_func,
load_function_from_function_path,
validate_tool_func_result,
validate_groups_if_exist_in_tool_spec
)
from promptflow._utils.yaml_utils import load_yaml
from promptflow.contracts.flow import InputAssignment, InputValueType, Node, ToolSourceType
Expand Down Expand Up @@ -97,6 +98,8 @@ def collect_package_tools(keys: Optional[List[str]] = None) -> dict:

m = tool["module"]
importlib.import_module(m) # Import the module to make sure it is valid
if "groups" in tool.keys():
validate_groups_if_exist_in_tool_spec(tool)
tool["package"] = entry_point.dist.metadata["Name"]
tool["package_version"] = entry_point.dist.version
assign_tool_input_index_for_ux_order_if_needed(tool)
Expand Down Expand Up @@ -128,6 +131,8 @@ def collect_package_tools_and_connections(keys: Optional[List[str]] = None) -> d
continue
m = tool["module"]
module = importlib.import_module(m) # Import the module to make sure it is valid
if "groups" in tool.keys():
validate_groups_if_exist_in_tool_spec(tool)
tool["package"] = entry_point.dist.metadata["Name"]
tool["package_version"] = entry_point.dist.version
assign_tool_input_index_for_ux_order_if_needed(tool)
Expand Down
49 changes: 48 additions & 1 deletion src/promptflow-core/promptflow/_utils/tool_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from promptflow._utils.context_utils import _change_working_dir
from promptflow._utils.utils import is_json_serializable
from promptflow.core._model_configuration import MODEL_CONFIG_NAME_2_CLASS
from promptflow.exceptions import ErrorTarget, UserErrorException
from promptflow.exceptions import ErrorTarget, UserErrorException, ValidationException

from ..contracts.tool import (
ConnectionType,
Expand Down Expand Up @@ -470,6 +470,53 @@ def _get_function_path(function):
return func, func_path


def validate_groups_if_exist_in_tool_spec(tool):
"""Validate groups if exist in tool spec."""
tool_name = tool.get("name", "")
groups = tool.get("groups", "")
invalid_group_names = ["advanced"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

invalid -> reserved

used_group_names = set()
used_inputs = set()
for group in groups:
group_name = group.get("name", "")
group_inputs = set(group.get("inputs", []))
# Group must have name and inputs
if not group_name or not group_inputs:
message_format = "Group must have name and inputs, please check the tool '{0}' and rename the group."
raise ValidationException(
message=message_format.format(tool_name),
message_format=message_format,
ErrorTarget=ErrorTarget.TOOL)

# Some group names cannot be used like advanced
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please explain why.

if group_name.lower() in invalid_group_names:
message_format = "The group name '{0}' cannot be used, please check the tool '{1}' and rename the group."
raise ValidationException(
message=message_format.format(group_name, tool_name),
message_format=message_format,
ErrorTarget=ErrorTarget.TOOL)

# Group name should be unique
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think duplicate group name may hit yaml load error? Why do we not need such logic for tool input name duplication?

if group_name in used_group_names:
message_format = "Group name should be unique, please check the tool '{0}' and rename the group."
raise ValidationException(
message=message_format.format(tool_name),
message_format=message_format,
ErrorTarget=ErrorTarget.TOOL)
else:
used_group_names.add(group_name)

# Each input shouldn't appear in multiple groups
if len(group_inputs.intersection(used_inputs)) > 0:
message_format = "Each input shouldn't appear in multiple groups, please check the tool '{0}'."
raise ValidationException(
message=message_format.format(tool_name),
message_format=message_format,
ErrorTarget=ErrorTarget.TOOL)
else:
used_inputs.update(group_inputs)


class RetrieveToolFuncResultError(UserErrorException):
"""Base exception raised for retrieve tool func result errors."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
param_to_definition,
validate_dynamic_list_func_response_type,
validate_tool_func_result,
validate_groups_if_exist_in_tool_spec
)
from promptflow.connections import AzureOpenAIConnection, CustomConnection
from promptflow.contracts.tool import Tool, ToolFuncCallScenario, ToolType, ValueType
Expand Down Expand Up @@ -407,3 +408,44 @@ def test_find_deprecated_tools(self):
}
with pytest.raises(DuplicateToolMappingError, match="secure operation"):
_find_deprecated_tools(package_tools)

@pytest.mark.parametrize(
"is_valid_spec, tool_spec, expected_error_message",
[
(
False,
{"name": "tool1", "groups": [{"inputs": ["tools", "tool_choice"]}]},
"Group must have name and inputs"
),
(
False,
{"name": "tool1", "groups": [{"name": "Tools", "inputs": ["tools", "tool_choice"]},
{"name": "Tools", "inputs": ["tools1", "tool_choice1"]}]},
"Group name should be unique"
),
(
False,
{"name": "tool1", "groups": [{"name": "Advanced", "inputs": ["tools", "tool_choice"]}]},
"The group name 'Advanced' cannot be used"
),
(
False,
{"name": "tool1", "groups": [{"name": "Tools", "inputs": ["tools", "tool_choice"]},
{"name": "Tools1", "inputs": ["tools", "tool_choice1"]}]},
"Each input shouldn't appear in multiple groups"
),
(
True,
{"name": "tool1", "groups": [{"name": "Tools", "inputs": ["tools", "tool_choice"]}]},
None
)
],
)
def test_validate_groups_if_exist_in_tool_spec(self, is_valid_spec, tool_spec, expected_error_message):
if not is_valid_spec:
from promptflow.exceptions import ValidationException
with pytest.raises(ValidationException) as e:
validate_groups_if_exist_in_tool_spec(tool_spec)
assert expected_error_message in str(e.value)
else:
validate_groups_if_exist_in_tool_spec(tool_spec)