Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 69 additions & 62 deletions aisuite/utils/tools.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Dict, Any, Type, Optional, get_origin, get_args, Union
from typing import Callable, Dict, Any, Type, Optional
from pydantic import BaseModel, create_model, Field, ValidationError
import inspect
import json
Expand Down Expand Up @@ -114,78 +114,85 @@ def tools(self, format="openai") -> list:
return self.__convert_to_openai_format()
return [tool["spec"] for tool in self._tools.values()]

def _unwrap_optional(self, field_type: Type) -> tuple[Type, bool]:
"""
Unwrap Optional[T] to get the base type T.
@classmethod
def _normalize_json_schema(cls, schema: Dict[str, Any]) -> Dict[str, Any]:
"""Normalize Pydantic JSON Schema for provider tool definitions."""
definitions = schema.get("$defs", {})

Returns:
tuple: (base_type, is_optional)
"""
# Check if it's Optional (Union with None)
origin = get_origin(field_type)
if origin is Union:
args = get_args(field_type)
# Optional[T] is Union[T, None]
if type(None) in args:
# Get the non-None type
non_none_types = [arg for arg in args if arg is not type(None)]
if len(non_none_types) == 1:
return non_none_types[0], True
return field_type, False
def normalize(value):
if isinstance(value, list):
return [normalize(item) for item in value]

if not isinstance(value, dict):
return value

if "$ref" in value:
resolved = cls._resolve_local_json_schema_ref(
value["$ref"], definitions
)
if resolved is not None:
siblings = {
key: item for key, item in value.items() if key != "$ref"
}
return normalize({**resolved, **siblings})

normalized = {
key: normalize(item)
for key, item in value.items()
if key not in {"$defs", "title"}
}
return cls._flatten_nullable_json_schema(normalized)

return normalize(schema)

@staticmethod
def _resolve_local_json_schema_ref(ref: str, definitions: Dict[str, Any]):
prefix = "#/$defs/"
if not ref.startswith(prefix):
return None
return definitions.get(ref[len(prefix) :])

@staticmethod
def _flatten_nullable_json_schema(schema: Dict[str, Any]) -> Dict[str, Any]:
any_of = schema.get("anyOf")
if not isinstance(any_of, list) or len(any_of) != 2:
return schema

non_null_schemas = [
item
for item in any_of
if not (isinstance(item, dict) and item.get("type") == "null")
]
has_null = len(non_null_schemas) != len(any_of)
if not has_null or len(non_null_schemas) != 1:
return schema

base_schema = non_null_schemas[0]
if not isinstance(base_schema, dict):
return schema

nullable_metadata = {
key: value for key, value in schema.items() if key != "anyOf"
}
return {**base_schema, **nullable_metadata}

# Convert the function and its Pydantic model to a unified tool specification.
def _convert_to_tool_spec(
self, func: Callable, param_model: Type[BaseModel]
) -> Dict[str, Any]:
"""Convert the function and its Pydantic model to a unified tool specification."""
type_mapping = {str: "string", int: "integer", float: "number", bool: "boolean"}

properties = {}
for field_name, field in param_model.model_fields.items():
field_type = field.annotation

# Unwrap Optional[T] to get base type T
field_type, is_optional = self._unwrap_optional(field_type)

# Handle enum types
if hasattr(field_type, "__members__"): # Check if it's an enum
enum_values = [
member.value if hasattr(member, "value") else member.name
for member in field_type
]
properties[field_name] = {
"type": "string",
"enum": enum_values,
"description": field.description or "",
}
# Convert enum default value to string if it exists
if str(field.default) != "PydanticUndefined":
properties[field_name]["default"] = (
field.default.value
if hasattr(field.default, "value")
else field.default
)
else:
properties[field_name] = {
"type": type_mapping.get(field_type, str(field_type)),
"description": field.description or "",
}
# Add default if it exists and isn't PydanticUndefined
if str(field.default) != "PydanticUndefined":
properties[field_name]["default"] = field.default
parameters = self._normalize_json_schema(param_model.model_json_schema())
parameters.setdefault("type", "object")
properties = parameters.setdefault("properties", {})

for field_name in param_model.model_fields:
if field_name in properties:
properties[field_name].setdefault("description", "")

return {
"name": func.__name__,
"description": func.__doc__ or "",
"parameters": {
"type": "object",
"properties": properties,
"required": [
name
for name, field in param_model.model_fields.items()
if field.is_required and str(field.default) == "PydanticUndefined"
],
},
"parameters": parameters,
}

def __extract_param_descriptions(self, func: Callable) -> dict[str, str]:
Expand Down
75 changes: 74 additions & 1 deletion tests/utils/test_tool_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest
from pydantic import BaseModel
from typing import Dict
from typing import Dict, Literal, Optional
from aisuite.utils.tools import Tools # Import your ToolManager class
from enum import Enum

Expand Down Expand Up @@ -38,6 +38,23 @@ def get_current_temperature_v2(
return {"location": location, "unit": unit, "temperature": "72"}


def analyze_items(
tags: list[str],
scores: dict[str, int],
mode: Literal["fast", "thorough"] = "fast",
note: Optional[str] = None,
limit: int | None = None,
) -> Dict[str, object]:
"""Analyze items using complex argument types."""
return {
"tags": tags,
"scores": scores,
"mode": mode,
"note": note,
"limit": limit,
}


class TestToolManager(unittest.TestCase):
def setUp(self):
self.tool_manager = Tools()
Expand Down Expand Up @@ -195,6 +212,62 @@ def test_add_tool_with_enum(self):
tools == expected_tool_spec
), f"Expected {expected_tool_spec}, but got {tools}"

def test_add_tool_with_complex_annotations_uses_json_schema_types(self):
"""Test complex annotations generate valid JSON Schema types."""
self.tool_manager._add_tool(analyze_items)

parameters = self.tool_manager.tools()[0]["function"]["parameters"]
properties = parameters["properties"]

self.assertEqual(properties["tags"]["type"], "array")
self.assertEqual(properties["tags"]["items"], {"type": "string"})
self.assertEqual(properties["scores"]["type"], "object")
self.assertEqual(
properties["scores"]["additionalProperties"], {"type": "integer"}
)
self.assertEqual(properties["mode"]["type"], "string")
self.assertEqual(properties["mode"]["enum"], ["fast", "thorough"])
self.assertEqual(properties["mode"]["default"], "fast")
self.assertEqual(properties["note"]["type"], "string")
self.assertIsNone(properties["note"]["default"])
self.assertEqual(properties["limit"]["type"], "integer")
self.assertIsNone(properties["limit"]["default"])
self.assertEqual(parameters["required"], ["tags", "scores"])

for schema in properties.values():
self.assertNotIn("list[", str(schema))
self.assertNotIn("dict[", str(schema))
self.assertNotIn("| None", str(schema))

def test_execute_tool_with_complex_annotations(self):
"""Test Pydantic validation still executes complex-typed tools."""
self.tool_manager._add_tool(analyze_items)
tool_call = {
"id": "call_1",
"function": {
"name": "analyze_items",
"arguments": {
"tags": ["urgent", "finance"],
"scores": {"urgent": 10, "finance": 7},
"mode": "thorough",
"limit": 3,
},
},
}

result, _ = self.tool_manager.execute_tool(tool_call)

self.assertEqual(
result[0],
{
"tags": ["urgent", "finance"],
"scores": {"urgent": 10, "finance": 7},
"mode": "thorough",
"note": None,
"limit": 3,
},
)


if __name__ == "__main__":
unittest.main()