Skip to content

Commit 3fa81f5

Browse files
committed
Replace default input_schema fallback with clear error
Instead of silently injecting an empty default schema when a tool is missing input_schema, throw IllegalArgumentException with a message naming the tool and the missing attribute. A default empty schema misleads the LLM into calling the tool with no arguments, causing silent failures downstream. Signed-off-by: Tyler Ohlsen <ohltyler@amazon.com>
1 parent 7b8e6c5 commit 3fa81f5

File tree

2 files changed

+32
-24
lines changed

2 files changed

+32
-24
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,6 @@ public class AgentUtils {
154154
private static final String NAME = "name";
155155
private static final String DESCRIPTION = "description";
156156
private static final String INPUT_SCHEMA = "input_schema";
157-
private static final String DEFAULT_INPUT_SCHEMA = "{\"type\":\"object\",\"properties\":{}}";
158157
private static final Pattern ADDITIONAL_PROPERTIES_PATTERN = Pattern
159158
.compile(",\\s*\"additionalProperties\"\\s*:\\s*(?:false|true)", Pattern.CASE_INSENSITIVE);
160159
public static final String AGENT_LLM_MODEL_ID = "agent_llm_model_id";
@@ -243,18 +242,23 @@ public static String addToolsToFunctionCalling(
243242
toolParams.put(DESCRIPTION, StringEscapeUtils.escapeJson(tool.getDescription()));
244243
Map<String, ?> attributes = tool.getAttributes();
245244
if (attributes == null || !attributes.containsKey(INPUT_SCHEMA)) {
246-
toolParams.put("attributes." + INPUT_SCHEMA, DEFAULT_INPUT_SCHEMA);
245+
throw new IllegalArgumentException(
246+
"Tool ["
247+
+ toolName
248+
+ "] is missing ["
249+
+ INPUT_SCHEMA
250+
+ "] in its attributes. "
251+
+ "All tools used with function calling must define an input_schema."
252+
);
247253
}
248-
if (attributes != null) {
249-
for (String key : attributes.keySet()) {
250-
toolParams.put("attributes." + key, attributes.get(key));
251-
}
252-
// For Gemini, clean input_schema to remove additionalProperties
253-
if (parameters.containsKey("gemini.schema.cleaner") && attributes.containsKey(INPUT_SCHEMA)) {
254-
String schema = String.valueOf(attributes.get(INPUT_SCHEMA));
255-
String cleanedSchema = removeAdditionalPropertiesFromSchema(schema);
256-
toolParams.put("attributes.input_schema_cleaned", cleanedSchema);
257-
}
254+
for (String key : attributes.keySet()) {
255+
toolParams.put("attributes." + key, attributes.get(key));
256+
}
257+
// For Gemini, clean input_schema to remove additionalProperties
258+
if (parameters.containsKey("gemini.schema.cleaner") && attributes.containsKey(INPUT_SCHEMA)) {
259+
String schema = String.valueOf(attributes.get(INPUT_SCHEMA));
260+
String cleanedSchema = removeAdditionalPropertiesFromSchema(schema);
261+
toolParams.put("attributes.input_schema_cleaned", cleanedSchema);
258262
}
259263
StringSubstitutor substitutor = new StringSubstitutor(toolParams, "${tool.", "}");
260264
String chatQuestionMessage = substitutor.replace(toolTemplate);

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1230,11 +1230,11 @@ public void testAddToolsToFunctionCalling() {
12301230

12311231
when(tool1.getName()).thenReturn("Tool1");
12321232
when(tool1.getDescription()).thenReturn("Description of Tool1");
1233-
when(tool1.getAttributes()).thenReturn(Map.of("param1", "value1"));
1233+
when(tool1.getAttributes()).thenReturn(Map.of("param1", "value1", "input_schema", "{\"type\":\"object\",\"properties\":{}}"));
12341234

12351235
when(tool2.getName()).thenReturn("Tool2");
12361236
when(tool2.getDescription()).thenReturn("Description of Tool2");
1237-
when(tool2.getAttributes()).thenReturn(Map.of("param2", "value2"));
1237+
when(tool2.getAttributes()).thenReturn(Map.of("param2", "value2", "input_schema", "{\"type\":\"object\",\"properties\":{}}"));
12381238

12391239
Map<String, String> parameters = new HashMap<>();
12401240
String toolTemplate = "{\"name\": \"${tool.name}\", \"description\": \"${tool.description}\"}";
@@ -1267,11 +1267,13 @@ public void testAddToolsToFunctionCalling_ToolWithNoInputSchema() {
12671267
"{\"toolSpec\":{\"name\":\"${tool.name}\",\"description\":\"${tool.description}\",\"inputSchema\":{\"json\":${tool.attributes.input_schema}}}}";
12681268
parameters.put(TOOL_TEMPLATE, toolTemplate);
12691269

1270-
AgentUtils.addToolsToFunctionCalling(tools, parameters, List.of("VectorDBTool"), "prompt");
1271-
1272-
String result = parameters.get(TOOLS);
1273-
Assert.assertFalse(result.contains("${tool.attributes.input_schema}"));
1274-
Assert.assertTrue(result.contains("\"inputSchema\":{\"json\":{\"type\":\"object\",\"properties\":{}}}"));
1270+
Exception ex = Assert
1271+
.assertThrows(
1272+
IllegalArgumentException.class,
1273+
() -> AgentUtils.addToolsToFunctionCalling(tools, parameters, List.of("VectorDBTool"), "prompt")
1274+
);
1275+
Assert.assertTrue(ex.getMessage().contains("VectorDBTool"));
1276+
Assert.assertTrue(ex.getMessage().contains("input_schema"));
12751277
}
12761278

12771279
@Test
@@ -1289,11 +1291,13 @@ public void testAddToolsToFunctionCalling_ToolWithAttributesButNoInputSchema() {
12891291
"{\"toolSpec\":{\"name\":\"${tool.name}\",\"description\":\"${tool.description}\",\"inputSchema\":{\"json\":${tool.attributes.input_schema}}}}";
12901292
parameters.put(TOOL_TEMPLATE, toolTemplate);
12911293

1292-
AgentUtils.addToolsToFunctionCalling(tools, parameters, List.of("VectorDBTool"), "prompt");
1293-
1294-
String result = parameters.get(TOOLS);
1295-
Assert.assertFalse(result.contains("${tool.attributes.input_schema}"));
1296-
Assert.assertTrue(result.contains("\"inputSchema\":{\"json\":{\"type\":\"object\",\"properties\":{}}}"));
1294+
Exception ex = Assert
1295+
.assertThrows(
1296+
IllegalArgumentException.class,
1297+
() -> AgentUtils.addToolsToFunctionCalling(tools, parameters, List.of("VectorDBTool"), "prompt")
1298+
);
1299+
Assert.assertTrue(ex.getMessage().contains("VectorDBTool"));
1300+
Assert.assertTrue(ex.getMessage().contains("input_schema"));
12971301
}
12981302

12991303
@Test

0 commit comments

Comments
 (0)