Skip to content

Commit 6c235e7

Browse files
authored
Fix: Respect MCP connector setting for Agent V2 (#4739)
* fix: move agentv2 execution post mcp flag check Signed-off-by: Pavan Yekbote <pybot@amazon.com> * tests: add test case for mcp connector check Signed-off-by: Pavan Yekbote <pybot@amazon.com> * test: add another case Signed-off-by: Pavan Yekbote <pybot@amazon.com> --------- Signed-off-by: Pavan Yekbote <pybot@amazon.com>
1 parent cc33c8d commit 6c235e7

File tree

2 files changed

+303
-7
lines changed

2 files changed

+303
-7
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -915,20 +915,20 @@ private void executeAgent(
915915
) {
916916
MLAgentType agentType = MLAgentType.from(mlAgent.getType());
917917

918-
// V2 agents follow pure message-centric execution path
919-
// TODO: Refactor to separate MLAgentExecutorV2 class for cleaner separation
920-
if (agentType.isV2() && inputMessages != null && memory != null) {
921-
executeV2Agent(inputDataSet, tenantId, mlTask, isAsync, mlAgent, listener, memory, channel, hookRegistry, inputMessages);
922-
return;
923-
}
924-
925918
String mcpConnectorConfigJSON = (mlAgent.getParameters() != null) ? mlAgent.getParameters().get(MCP_CONNECTORS_FIELD) : null;
926919
if (mcpConnectorConfigJSON != null && !mlFeatureEnabledSetting.isMcpConnectorEnabled()) {
927920
// MCP connector provided as tools but MCP feature is disabled, so abort.
928921
listener.onFailure(new OpenSearchException(ML_COMMONS_MCP_CONNECTOR_DISABLED_MESSAGE));
929922
return;
930923
}
931924

925+
// V2 agents follow pure message-centric execution path
926+
// TODO: Refactor to separate MLAgentExecutorV2 class for cleaner separation
927+
if (agentType.isV2() && inputMessages != null && memory != null) {
928+
executeV2Agent(inputDataSet, tenantId, mlTask, isAsync, mlAgent, listener, memory, channel, hookRegistry, inputMessages);
929+
return;
930+
}
931+
932932
// Check for agent-level context management configuration (following connector
933933
// pattern)
934934
if (mlAgent.hasContextManagement()) {

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java

Lines changed: 296 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77

88
import static org.junit.Assert.*;
99
import static org.mockito.Mockito.*;
10+
import static org.opensearch.ml.common.CommonValue.MCP_CONNECTORS_FIELD;
1011
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_CONTEXT;
12+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_DISABLED_MESSAGE;
1113
import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.QUESTION;
1214

1315
import java.io.IOException;
@@ -17,7 +19,9 @@
1719
import java.util.HashMap;
1820
import java.util.List;
1921
import java.util.Map;
22+
import java.util.concurrent.CompletionStage;
2023
import java.util.concurrent.atomic.AtomicBoolean;
24+
import java.util.function.BiConsumer;
2125

2226
import org.junit.Assert;
2327
import org.junit.Before;
@@ -26,14 +30,17 @@
2630
import org.mockito.Captor;
2731
import org.mockito.Mock;
2832
import org.mockito.MockitoAnnotations;
33+
import org.opensearch.OpenSearchException;
2934
import org.opensearch.OpenSearchStatusException;
3035
import org.opensearch.action.get.GetResponse;
3136
import org.opensearch.cluster.ClusterState;
3237
import org.opensearch.cluster.metadata.Metadata;
38+
import org.opensearch.cluster.node.DiscoveryNode;
3339
import org.opensearch.cluster.service.ClusterService;
3440
import org.opensearch.common.settings.Settings;
3541
import org.opensearch.common.util.concurrent.ThreadContext;
3642
import org.opensearch.common.xcontent.XContentFactory;
43+
import org.opensearch.common.xcontent.XContentType;
3744
import org.opensearch.core.action.ActionListener;
3845
import org.opensearch.core.common.bytes.BytesReference;
3946
import org.opensearch.core.rest.RestStatus;
@@ -65,6 +72,7 @@
6572
import org.opensearch.ml.common.spi.tools.Tool;
6673
import org.opensearch.ml.engine.encryptor.Encryptor;
6774
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
75+
import org.opensearch.remote.metadata.client.GetDataObjectResponse;
6876
import org.opensearch.remote.metadata.client.SdkClient;
6977
import org.opensearch.threadpool.ThreadPool;
7078
import org.opensearch.transport.TransportChannel;
@@ -131,6 +139,9 @@ public class MLAgentExecutorTest {
131139
@Mock
132140
private ActionListener<Output> agentActionListener;
133141

142+
@Mock
143+
private DiscoveryNode localNode;
144+
134145
MLAgent mlAgent;
135146

136147
@Before
@@ -1681,4 +1692,289 @@ public void test_SupportsStructuredMessages_PlanExecuteAndReflect() {
16811692
// PER agents don't support structured messages
16821693
assertFalse(mlAgentExecutor.supportsStructuredMessages(agent));
16831694
}
1695+
1696+
@Test
1697+
public void test_ExecuteAgent_V2Agent_WithMemoryAndMessages_RoutesToExecuteV2Agent() throws IOException {
1698+
// Setup: agent index exists (needed for MLIndicesHandler.doesMultiTenantIndexExist to return true)
1699+
when(clusterService.state().metadata().hasIndex(anyString())).thenReturn(true);
1700+
when(clusterService.localNode()).thenReturn(localNode);
1701+
when(localNode.getId()).thenReturn("test-node-id");
1702+
when(mlFeatureEnabledSetting.isRemoteAgenticMemoryEnabled()).thenReturn(false);
1703+
1704+
// Build a CONVERSATIONAL_V2 agent with a conversation_index memory spec
1705+
MLAgent v2Agent = MLAgent
1706+
.builder()
1707+
.name("test_v2_agent")
1708+
.type(MLAgentType.CONVERSATIONAL_V2.name())
1709+
.description("Test V2 agent")
1710+
.memory(new MLMemorySpec("conversation_index", null, 0, null))
1711+
.createdTime(Instant.now())
1712+
.lastUpdateTime(Instant.now())
1713+
.build();
1714+
1715+
// Serialize agent → GetResponse JSON so the sdkClient mock can return it
1716+
XContentBuilder agentContent = v2Agent.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
1717+
GetResult getResult = new GetResult(
1718+
".plugins-ml-agent",
1719+
"test-agent-id",
1720+
1L,
1721+
1L,
1722+
1L,
1723+
true,
1724+
BytesReference.bytes(agentContent),
1725+
null,
1726+
null
1727+
);
1728+
GetResponse getAgentResponseObj = new GetResponse(getResult);
1729+
XContentBuilder responseBuilder = XContentFactory.jsonBuilder();
1730+
getAgentResponseObj.toXContent(responseBuilder, ToXContent.EMPTY_PARAMS);
1731+
String getResponseJson = BytesReference.bytes(responseBuilder).utf8ToString();
1732+
1733+
// Mock sdkClient.getDataObjectAsync to return the V2 agent (two-arg overload used in execute)
1734+
when(sdkClient.getDataObjectAsync(any(), any())).thenAnswer(inv -> {
1735+
GetDataObjectResponse resp = mock(GetDataObjectResponse.class);
1736+
when(resp.parser())
1737+
.thenReturn(XContentType.JSON.xContent().createParser(NamedXContentRegistry.EMPTY, null, getResponseJson));
1738+
CompletionStage<GetDataObjectResponse> stage = mock(CompletionStage.class);
1739+
when(stage.whenComplete(any())).thenAnswer(cbInv -> {
1740+
BiConsumer<GetDataObjectResponse, Throwable> cb = cbInv.getArgument(0);
1741+
cb.accept(resp, null);
1742+
return stage;
1743+
});
1744+
return stage;
1745+
});
1746+
1747+
// Register the memory factory so the execute flow creates a memory instance
1748+
memoryFactoryMap.put("CONVERSATION_INDEX", mockMemoryFactory);
1749+
when(memory.getId()).thenReturn("test-memory-id");
1750+
doAnswer(invocation -> {
1751+
ActionListener<ConversationIndexMemory> memListener = invocation.getArgument(1);
1752+
memListener.onResponse(memory);
1753+
return null;
1754+
}).when(mockMemoryFactory).create(any(), any());
1755+
1756+
// Mock getStructuredMessages (first call inside executeV2Agent) to verify the path was reached
1757+
doAnswer(invocation -> {
1758+
ActionListener<List<Message>> msgListener = invocation.getArgument(0);
1759+
msgListener.onFailure(new RuntimeException("reached_executeV2Agent"));
1760+
return null;
1761+
}).when(memory).getStructuredMessages(any());
1762+
1763+
// Create input with MESSAGES type so inputMessages is non-null in saveRootInteractionAndExecute
1764+
ContentBlock textBlock = new ContentBlock();
1765+
textBlock.setType(ContentType.TEXT);
1766+
textBlock.setText("What is machine learning?");
1767+
Message message = new Message("user", Collections.singletonList(textBlock));
1768+
AgentInput agentInput = new AgentInput();
1769+
agentInput.setInput(Collections.singletonList(message));
1770+
1771+
Map<String, String> params = new HashMap<>();
1772+
params.put(QUESTION, "What is machine learning?");
1773+
RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build();
1774+
AgentMLInput agentMLInput = new AgentMLInput("test-agent-id", null, FunctionName.AGENT, agentInput, dataset, false);
1775+
1776+
mlAgentExecutor.execute(agentMLInput, listener, channel);
1777+
1778+
// executeV2Agent delegates immediately to memory.getStructuredMessages — verify it was called
1779+
verify(memory, timeout(5000).atLeastOnce()).getStructuredMessages(any());
1780+
}
1781+
1782+
private String serializeAgentToGetResponseJson(MLAgent agent) throws IOException {
1783+
XContentBuilder agentContent = agent.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);
1784+
GetResult getResult = new GetResult(
1785+
".plugins-ml-agent",
1786+
"test-agent-id",
1787+
1L,
1788+
1L,
1789+
1L,
1790+
true,
1791+
BytesReference.bytes(agentContent),
1792+
null,
1793+
null
1794+
);
1795+
GetResponse getAgentResponseObj = new GetResponse(getResult);
1796+
XContentBuilder responseBuilder = XContentFactory.jsonBuilder();
1797+
getAgentResponseObj.toXContent(responseBuilder, ToXContent.EMPTY_PARAMS);
1798+
return BytesReference.bytes(responseBuilder).utf8ToString();
1799+
}
1800+
1801+
private void mockSdkClientWithAgent(String getResponseJson) {
1802+
when(clusterService.localNode()).thenReturn(localNode);
1803+
when(localNode.getId()).thenReturn("test-node-id");
1804+
when(sdkClient.getDataObjectAsync(any(), any())).thenAnswer(inv -> {
1805+
GetDataObjectResponse resp = mock(GetDataObjectResponse.class);
1806+
when(resp.parser())
1807+
.thenReturn(XContentType.JSON.xContent().createParser(NamedXContentRegistry.EMPTY, null, getResponseJson));
1808+
CompletionStage<GetDataObjectResponse> stage = mock(CompletionStage.class);
1809+
when(stage.whenComplete(any())).thenAnswer(cbInv -> {
1810+
BiConsumer<GetDataObjectResponse, Throwable> cb = cbInv.getArgument(0);
1811+
cb.accept(resp, null);
1812+
return stage;
1813+
});
1814+
return stage;
1815+
});
1816+
}
1817+
1818+
private AgentMLInput buildV2AgentMLInput() {
1819+
ContentBlock textBlock = new ContentBlock();
1820+
textBlock.setType(ContentType.TEXT);
1821+
textBlock.setText("What is machine learning?");
1822+
Message message = new Message("user", Collections.singletonList(textBlock));
1823+
AgentInput agentInput = new AgentInput();
1824+
agentInput.setInput(Collections.singletonList(message));
1825+
1826+
Map<String, String> params = new HashMap<>();
1827+
params.put(QUESTION, "What is machine learning?");
1828+
RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build();
1829+
return new AgentMLInput("test-agent-id", null, FunctionName.AGENT, agentInput, dataset, false);
1830+
}
1831+
1832+
/**
1833+
* Covers the false branch of condition 1: agentType.isV2() == false.
1834+
* A non-V2 agent (CONVERSATIONAL) with no memory spec reaches executeAgent via the
1835+
* line-484 path (null memory, null inputMessages). The V2 condition fails immediately
1836+
* and execution falls through to the normal runner path.
1837+
*/
1838+
@Test
1839+
public void test_ExecuteAgent_NonV2Agent_DoesNotRouteToExecuteV2Agent() throws IOException {
1840+
when(clusterService.state().metadata().hasIndex(anyString())).thenReturn(true);
1841+
when(mlFeatureEnabledSetting.isRemoteAgenticMemoryEnabled()).thenReturn(false);
1842+
1843+
MLAgent convAgent = MLAgent
1844+
.builder()
1845+
.name("test_conv_agent")
1846+
.type(MLAgentType.CONVERSATIONAL.name())
1847+
.llm(LLMSpec.builder().modelId("gpt-4").build())
1848+
// no memory spec → executeAgent called with null memory and null inputMessages
1849+
.createdTime(Instant.now())
1850+
.lastUpdateTime(Instant.now())
1851+
.build();
1852+
1853+
mockSdkClientWithAgent(serializeAgentToGetResponseJson(convAgent));
1854+
1855+
Map<String, String> params = new HashMap<>();
1856+
params.put(QUESTION, "What is ML?");
1857+
RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build();
1858+
AgentMLInput agentMLInput = new AgentMLInput("test-agent-id", null, FunctionName.AGENT, dataset);
1859+
1860+
mlAgentExecutor.execute(agentMLInput, listener, channel);
1861+
1862+
// agentType.isV2() = false → V2 execution path must not be entered
1863+
verify(memory, never()).getStructuredMessages(any());
1864+
verify(listener, timeout(5000)).onFailure(any());
1865+
}
1866+
1867+
/**
1868+
* Covers the false branch of condition 2: inputMessages == null (with agentType.isV2() true).
1869+
* A V2 agent with no memory spec reaches executeAgent via the line-484 path where
1870+
* inputMessages is hardcoded null. The condition fails at the second && operand.
1871+
*/
1872+
@Test
1873+
public void test_ExecuteAgent_V2Agent_WithNullInputMessages_DoesNotRouteToExecuteV2Agent() throws IOException {
1874+
when(clusterService.state().metadata().hasIndex(anyString())).thenReturn(true);
1875+
when(mlFeatureEnabledSetting.isRemoteAgenticMemoryEnabled()).thenReturn(false);
1876+
1877+
MLAgent v2Agent = MLAgent
1878+
.builder()
1879+
.name("test_v2_no_memory")
1880+
.type(MLAgentType.CONVERSATIONAL_V2.name())
1881+
// no memory spec → executeAgent called with null memory and null inputMessages
1882+
.createdTime(Instant.now())
1883+
.lastUpdateTime(Instant.now())
1884+
.build();
1885+
1886+
mockSdkClientWithAgent(serializeAgentToGetResponseJson(v2Agent));
1887+
1888+
// Old-style dataset input (no AgentInput) → inputMessages stays null inside executeAgent
1889+
Map<String, String> params = new HashMap<>();
1890+
params.put(QUESTION, "What is ML?");
1891+
RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build();
1892+
AgentMLInput agentMLInput = new AgentMLInput("test-agent-id", null, FunctionName.AGENT, dataset);
1893+
1894+
mlAgentExecutor.execute(agentMLInput, listener, channel);
1895+
1896+
// agentType.isV2() = true, but inputMessages = null → V2 execution path must not be entered
1897+
verify(memory, never()).getStructuredMessages(any());
1898+
verify(listener, timeout(5000)).onFailure(any());
1899+
}
1900+
1901+
@Test
1902+
public void test_ExecuteAgent_V2Agent_WithMcpConnector_McpDisabled_FailsWithMcpError() throws IOException {
1903+
when(clusterService.state().metadata().hasIndex(anyString())).thenReturn(true);
1904+
when(mlFeatureEnabledSetting.isRemoteAgenticMemoryEnabled()).thenReturn(false);
1905+
when(mlFeatureEnabledSetting.isMcpConnectorEnabled()).thenReturn(false);
1906+
1907+
Map<String, String> agentParams = new HashMap<>();
1908+
agentParams.put(MCP_CONNECTORS_FIELD, "[{\"id\":\"mcp-conn-1\"}]");
1909+
1910+
MLAgent v2Agent = MLAgent
1911+
.builder()
1912+
.name("test_v2_mcp_agent")
1913+
.type(MLAgentType.CONVERSATIONAL_V2.name())
1914+
.description("V2 agent with MCP connector")
1915+
.parameters(agentParams)
1916+
.memory(new MLMemorySpec("conversation_index", null, 0, null))
1917+
.createdTime(Instant.now())
1918+
.lastUpdateTime(Instant.now())
1919+
.build();
1920+
1921+
mockSdkClientWithAgent(serializeAgentToGetResponseJson(v2Agent));
1922+
memoryFactoryMap.put("CONVERSATION_INDEX", mockMemoryFactory);
1923+
when(memory.getId()).thenReturn("test-memory-id");
1924+
doAnswer(invocation -> {
1925+
ActionListener<ConversationIndexMemory> memListener = invocation.getArgument(1);
1926+
memListener.onResponse(memory);
1927+
return null;
1928+
}).when(mockMemoryFactory).create(any(), any());
1929+
1930+
mlAgentExecutor.execute(buildV2AgentMLInput(), listener, channel);
1931+
1932+
// MCP check must fire before V2 routing — listener.onFailure called with MCP error
1933+
verify(listener, timeout(5000)).onFailure(argThat(e -> e instanceof OpenSearchException
1934+
&& e.getMessage().contains(ML_COMMONS_MCP_CONNECTOR_DISABLED_MESSAGE)));
1935+
// executeV2Agent must NOT have been reached
1936+
verify(memory, never()).getStructuredMessages(any());
1937+
}
1938+
1939+
@Test
1940+
public void test_ExecuteAgent_V2Agent_WithMcpConnector_McpEnabled_RoutesToExecuteV2Agent() throws IOException {
1941+
when(clusterService.state().metadata().hasIndex(anyString())).thenReturn(true);
1942+
when(mlFeatureEnabledSetting.isRemoteAgenticMemoryEnabled()).thenReturn(false);
1943+
when(mlFeatureEnabledSetting.isMcpConnectorEnabled()).thenReturn(true);
1944+
1945+
Map<String, String> agentParams = new HashMap<>();
1946+
agentParams.put(MCP_CONNECTORS_FIELD, "[{\"id\":\"mcp-conn-1\"}]");
1947+
1948+
MLAgent v2Agent = MLAgent
1949+
.builder()
1950+
.name("test_v2_mcp_agent")
1951+
.type(MLAgentType.CONVERSATIONAL_V2.name())
1952+
.description("V2 agent with MCP connector, MCP enabled")
1953+
.parameters(agentParams)
1954+
.memory(new MLMemorySpec("conversation_index", null, 0, null))
1955+
.createdTime(Instant.now())
1956+
.lastUpdateTime(Instant.now())
1957+
.build();
1958+
1959+
mockSdkClientWithAgent(serializeAgentToGetResponseJson(v2Agent));
1960+
memoryFactoryMap.put("CONVERSATION_INDEX", mockMemoryFactory);
1961+
when(memory.getId()).thenReturn("test-memory-id");
1962+
doAnswer(invocation -> {
1963+
ActionListener<ConversationIndexMemory> memListener = invocation.getArgument(1);
1964+
memListener.onResponse(memory);
1965+
return null;
1966+
}).when(mockMemoryFactory).create(any(), any());
1967+
1968+
// Fail fast inside executeV2Agent to verify the path was reached
1969+
doAnswer(invocation -> {
1970+
ActionListener<List<Message>> msgListener = invocation.getArgument(0);
1971+
msgListener.onFailure(new RuntimeException("reached_executeV2Agent"));
1972+
return null;
1973+
}).when(memory).getStructuredMessages(any());
1974+
1975+
mlAgentExecutor.execute(buildV2AgentMLInput(), listener, channel);
1976+
1977+
// MCP check passes (enabled), so execution must reach executeV2Agent
1978+
verify(memory, timeout(5000).atLeastOnce()).getStructuredMessages(any());
1979+
}
16841980
}

0 commit comments

Comments
 (0)