diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java index 4738d310f6..7fa4b874ad 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java @@ -153,6 +153,7 @@ public class AgentUtils { public static final String TOKEN_USAGE_PATH = "token_usage_path"; private static final String NAME = "name"; private static final String DESCRIPTION = "description"; + private static final String INPUT_SCHEMA = "input_schema"; private static final Pattern ADDITIONAL_PROPERTIES_PATTERN = Pattern .compile(",\\s*\"additionalProperties\"\\s*:\\s*(?:false|true)", Pattern.CASE_INSENSITIVE); public static final String AGENT_LLM_MODEL_ID = "agent_llm_model_id"; @@ -240,16 +241,24 @@ public static String addToolsToFunctionCalling( toolParams.put(NAME, StringEscapeUtils.escapeJson(tool.getName())); toolParams.put(DESCRIPTION, StringEscapeUtils.escapeJson(tool.getDescription())); Map attributes = tool.getAttributes(); - if (attributes != null) { - for (String key : attributes.keySet()) { - toolParams.put("attributes." + key, attributes.get(key)); - } - // For Gemini, clean input_schema to remove additionalProperties - if (parameters.containsKey("gemini.schema.cleaner") && attributes.containsKey("input_schema")) { - String schema = String.valueOf(attributes.get("input_schema")); - String cleanedSchema = removeAdditionalPropertiesFromSchema(schema); - toolParams.put("attributes.input_schema_cleaned", cleanedSchema); - } + if (attributes == null || !attributes.containsKey(INPUT_SCHEMA)) { + throw new IllegalArgumentException( + "Tool [" + + toolName + + "] is missing [" + + INPUT_SCHEMA + + "] in its attributes. " + + "All tools used with function calling must define an input_schema." + ); + } + for (String key : attributes.keySet()) { + toolParams.put("attributes." + key, attributes.get(key)); + } + // For Gemini, clean input_schema to remove additionalProperties + if (parameters.containsKey("gemini.schema.cleaner") && attributes.containsKey(INPUT_SCHEMA)) { + String schema = String.valueOf(attributes.get(INPUT_SCHEMA)); + String cleanedSchema = removeAdditionalPropertiesFromSchema(schema); + toolParams.put("attributes.input_schema_cleaned", cleanedSchema); } StringSubstitutor substitutor = new StringSubstitutor(toolParams, "${tool.", "}"); String chatQuestionMessage = substitutor.replace(toolTemplate); @@ -1262,10 +1271,10 @@ public static Map wrapFrontendToolsAsToolObjects(List emptySchema = Map.of("type", "object", "properties", Map.of()); - toolAttributes.put("input_schema", gson.toJson(emptySchema)); + toolAttributes.put(INPUT_SCHEMA, gson.toJson(emptySchema)); } Tool frontendToolObj = new AGUIFrontendTool(toolName, toolDescription, toolAttributes); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java index 0809ec40d7..d2781d70ad 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java @@ -1230,11 +1230,11 @@ public void testAddToolsToFunctionCalling() { when(tool1.getName()).thenReturn("Tool1"); when(tool1.getDescription()).thenReturn("Description of Tool1"); - when(tool1.getAttributes()).thenReturn(Map.of("param1", "value1")); + when(tool1.getAttributes()).thenReturn(Map.of("param1", "value1", "input_schema", "{\"type\":\"object\",\"properties\":{}}")); when(tool2.getName()).thenReturn("Tool2"); when(tool2.getDescription()).thenReturn("Description of Tool2"); - when(tool2.getAttributes()).thenReturn(Map.of("param2", "value2")); + when(tool2.getAttributes()).thenReturn(Map.of("param2", "value2", "input_schema", "{\"type\":\"object\",\"properties\":{}}")); Map parameters = new HashMap<>(); String toolTemplate = "{\"name\": \"${tool.name}\", \"description\": \"${tool.description}\"}"; @@ -1252,6 +1252,54 @@ public void testAddToolsToFunctionCalling() { assertEquals(expectedTools, parameters.get(TOOLS)); } + @Test + public void testAddToolsToFunctionCalling_ToolWithNoInputSchema() { + Tool toolNoSchema = mock(Tool.class); + when(toolNoSchema.getName()).thenReturn("VectorDBTool"); + when(toolNoSchema.getDescription()).thenReturn("knn dense retrieval tool"); + when(toolNoSchema.getAttributes()).thenReturn(null); + + Map tools = new HashMap<>(); + tools.put("VectorDBTool", toolNoSchema); + + Map parameters = new HashMap<>(); + String toolTemplate = + "{\"toolSpec\":{\"name\":\"${tool.name}\",\"description\":\"${tool.description}\",\"inputSchema\":{\"json\":${tool.attributes.input_schema}}}}"; + parameters.put(TOOL_TEMPLATE, toolTemplate); + + Exception ex = Assert + .assertThrows( + IllegalArgumentException.class, + () -> AgentUtils.addToolsToFunctionCalling(tools, parameters, List.of("VectorDBTool"), "prompt") + ); + Assert.assertTrue(ex.getMessage().contains("VectorDBTool")); + Assert.assertTrue(ex.getMessage().contains("input_schema")); + } + + @Test + public void testAddToolsToFunctionCalling_ToolWithAttributesButNoInputSchema() { + Tool toolNoSchema = mock(Tool.class); + when(toolNoSchema.getName()).thenReturn("VectorDBTool"); + when(toolNoSchema.getDescription()).thenReturn("knn dense retrieval tool"); + when(toolNoSchema.getAttributes()).thenReturn(Map.of("some_other_key", "value")); + + Map tools = new HashMap<>(); + tools.put("VectorDBTool", toolNoSchema); + + Map parameters = new HashMap<>(); + String toolTemplate = + "{\"toolSpec\":{\"name\":\"${tool.name}\",\"description\":\"${tool.description}\",\"inputSchema\":{\"json\":${tool.attributes.input_schema}}}}"; + parameters.put(TOOL_TEMPLATE, toolTemplate); + + Exception ex = Assert + .assertThrows( + IllegalArgumentException.class, + () -> AgentUtils.addToolsToFunctionCalling(tools, parameters, List.of("VectorDBTool"), "prompt") + ); + Assert.assertTrue(ex.getMessage().contains("VectorDBTool")); + Assert.assertTrue(ex.getMessage().contains("input_schema")); + } + @Test public void testAddToolsToFunctionCalling_ToolNotRegistered() { Map tools = new HashMap<>();