|
7 | 7 |
|
8 | 8 | import static org.junit.Assert.*; |
9 | 9 | import static org.mockito.Mockito.*; |
| 10 | +import static org.opensearch.ml.common.CommonValue.MCP_CONNECTORS_FIELD; |
10 | 11 | import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_CONTEXT; |
| 12 | +import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_DISABLED_MESSAGE; |
11 | 13 | import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.QUESTION; |
12 | 14 |
|
13 | 15 | import java.io.IOException; |
|
17 | 19 | import java.util.HashMap; |
18 | 20 | import java.util.List; |
19 | 21 | import java.util.Map; |
| 22 | +import java.util.concurrent.CompletionStage; |
20 | 23 | import java.util.concurrent.atomic.AtomicBoolean; |
| 24 | +import java.util.function.BiConsumer; |
21 | 25 |
|
22 | 26 | import org.junit.Assert; |
23 | 27 | import org.junit.Before; |
|
26 | 30 | import org.mockito.Captor; |
27 | 31 | import org.mockito.Mock; |
28 | 32 | import org.mockito.MockitoAnnotations; |
| 33 | +import org.opensearch.OpenSearchException; |
29 | 34 | import org.opensearch.OpenSearchStatusException; |
30 | 35 | import org.opensearch.action.get.GetResponse; |
31 | 36 | import org.opensearch.cluster.ClusterState; |
32 | 37 | import org.opensearch.cluster.metadata.Metadata; |
| 38 | +import org.opensearch.cluster.node.DiscoveryNode; |
33 | 39 | import org.opensearch.cluster.service.ClusterService; |
34 | 40 | import org.opensearch.common.settings.Settings; |
35 | 41 | import org.opensearch.common.util.concurrent.ThreadContext; |
36 | 42 | import org.opensearch.common.xcontent.XContentFactory; |
| 43 | +import org.opensearch.common.xcontent.XContentType; |
37 | 44 | import org.opensearch.core.action.ActionListener; |
38 | 45 | import org.opensearch.core.common.bytes.BytesReference; |
39 | 46 | import org.opensearch.core.rest.RestStatus; |
|
65 | 72 | import org.opensearch.ml.common.spi.tools.Tool; |
66 | 73 | import org.opensearch.ml.engine.encryptor.Encryptor; |
67 | 74 | import org.opensearch.ml.engine.memory.ConversationIndexMemory; |
| 75 | +import org.opensearch.remote.metadata.client.GetDataObjectResponse; |
68 | 76 | import org.opensearch.remote.metadata.client.SdkClient; |
69 | 77 | import org.opensearch.threadpool.ThreadPool; |
70 | 78 | import org.opensearch.transport.TransportChannel; |
@@ -131,6 +139,9 @@ public class MLAgentExecutorTest { |
131 | 139 | @Mock |
132 | 140 | private ActionListener<Output> agentActionListener; |
133 | 141 |
|
| 142 | + @Mock |
| 143 | + private DiscoveryNode localNode; |
| 144 | + |
134 | 145 | MLAgent mlAgent; |
135 | 146 |
|
136 | 147 | @Before |
@@ -1681,4 +1692,289 @@ public void test_SupportsStructuredMessages_PlanExecuteAndReflect() { |
1681 | 1692 | // PER agents don't support structured messages |
1682 | 1693 | assertFalse(mlAgentExecutor.supportsStructuredMessages(agent)); |
1683 | 1694 | } |
| 1695 | + |
| 1696 | + @Test |
| 1697 | + public void test_ExecuteAgent_V2Agent_WithMemoryAndMessages_RoutesToExecuteV2Agent() throws IOException { |
| 1698 | + // Setup: agent index exists (needed for MLIndicesHandler.doesMultiTenantIndexExist to return true) |
| 1699 | + when(clusterService.state().metadata().hasIndex(anyString())).thenReturn(true); |
| 1700 | + when(clusterService.localNode()).thenReturn(localNode); |
| 1701 | + when(localNode.getId()).thenReturn("test-node-id"); |
| 1702 | + when(mlFeatureEnabledSetting.isRemoteAgenticMemoryEnabled()).thenReturn(false); |
| 1703 | + |
| 1704 | + // Build a CONVERSATIONAL_V2 agent with a conversation_index memory spec |
| 1705 | + MLAgent v2Agent = MLAgent |
| 1706 | + .builder() |
| 1707 | + .name("test_v2_agent") |
| 1708 | + .type(MLAgentType.CONVERSATIONAL_V2.name()) |
| 1709 | + .description("Test V2 agent") |
| 1710 | + .memory(new MLMemorySpec("conversation_index", null, 0, null)) |
| 1711 | + .createdTime(Instant.now()) |
| 1712 | + .lastUpdateTime(Instant.now()) |
| 1713 | + .build(); |
| 1714 | + |
| 1715 | + // Serialize agent → GetResponse JSON so the sdkClient mock can return it |
| 1716 | + XContentBuilder agentContent = v2Agent.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); |
| 1717 | + GetResult getResult = new GetResult( |
| 1718 | + ".plugins-ml-agent", |
| 1719 | + "test-agent-id", |
| 1720 | + 1L, |
| 1721 | + 1L, |
| 1722 | + 1L, |
| 1723 | + true, |
| 1724 | + BytesReference.bytes(agentContent), |
| 1725 | + null, |
| 1726 | + null |
| 1727 | + ); |
| 1728 | + GetResponse getAgentResponseObj = new GetResponse(getResult); |
| 1729 | + XContentBuilder responseBuilder = XContentFactory.jsonBuilder(); |
| 1730 | + getAgentResponseObj.toXContent(responseBuilder, ToXContent.EMPTY_PARAMS); |
| 1731 | + String getResponseJson = BytesReference.bytes(responseBuilder).utf8ToString(); |
| 1732 | + |
| 1733 | + // Mock sdkClient.getDataObjectAsync to return the V2 agent (two-arg overload used in execute) |
| 1734 | + when(sdkClient.getDataObjectAsync(any(), any())).thenAnswer(inv -> { |
| 1735 | + GetDataObjectResponse resp = mock(GetDataObjectResponse.class); |
| 1736 | + when(resp.parser()) |
| 1737 | + .thenReturn(XContentType.JSON.xContent().createParser(NamedXContentRegistry.EMPTY, null, getResponseJson)); |
| 1738 | + CompletionStage<GetDataObjectResponse> stage = mock(CompletionStage.class); |
| 1739 | + when(stage.whenComplete(any())).thenAnswer(cbInv -> { |
| 1740 | + BiConsumer<GetDataObjectResponse, Throwable> cb = cbInv.getArgument(0); |
| 1741 | + cb.accept(resp, null); |
| 1742 | + return stage; |
| 1743 | + }); |
| 1744 | + return stage; |
| 1745 | + }); |
| 1746 | + |
| 1747 | + // Register the memory factory so the execute flow creates a memory instance |
| 1748 | + memoryFactoryMap.put("CONVERSATION_INDEX", mockMemoryFactory); |
| 1749 | + when(memory.getId()).thenReturn("test-memory-id"); |
| 1750 | + doAnswer(invocation -> { |
| 1751 | + ActionListener<ConversationIndexMemory> memListener = invocation.getArgument(1); |
| 1752 | + memListener.onResponse(memory); |
| 1753 | + return null; |
| 1754 | + }).when(mockMemoryFactory).create(any(), any()); |
| 1755 | + |
| 1756 | + // Mock getStructuredMessages (first call inside executeV2Agent) to verify the path was reached |
| 1757 | + doAnswer(invocation -> { |
| 1758 | + ActionListener<List<Message>> msgListener = invocation.getArgument(0); |
| 1759 | + msgListener.onFailure(new RuntimeException("reached_executeV2Agent")); |
| 1760 | + return null; |
| 1761 | + }).when(memory).getStructuredMessages(any()); |
| 1762 | + |
| 1763 | + // Create input with MESSAGES type so inputMessages is non-null in saveRootInteractionAndExecute |
| 1764 | + ContentBlock textBlock = new ContentBlock(); |
| 1765 | + textBlock.setType(ContentType.TEXT); |
| 1766 | + textBlock.setText("What is machine learning?"); |
| 1767 | + Message message = new Message("user", Collections.singletonList(textBlock)); |
| 1768 | + AgentInput agentInput = new AgentInput(); |
| 1769 | + agentInput.setInput(Collections.singletonList(message)); |
| 1770 | + |
| 1771 | + Map<String, String> params = new HashMap<>(); |
| 1772 | + params.put(QUESTION, "What is machine learning?"); |
| 1773 | + RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); |
| 1774 | + AgentMLInput agentMLInput = new AgentMLInput("test-agent-id", null, FunctionName.AGENT, agentInput, dataset, false); |
| 1775 | + |
| 1776 | + mlAgentExecutor.execute(agentMLInput, listener, channel); |
| 1777 | + |
| 1778 | + // executeV2Agent delegates immediately to memory.getStructuredMessages — verify it was called |
| 1779 | + verify(memory, timeout(5000).atLeastOnce()).getStructuredMessages(any()); |
| 1780 | + } |
| 1781 | + |
| 1782 | + private String serializeAgentToGetResponseJson(MLAgent agent) throws IOException { |
| 1783 | + XContentBuilder agentContent = agent.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); |
| 1784 | + GetResult getResult = new GetResult( |
| 1785 | + ".plugins-ml-agent", |
| 1786 | + "test-agent-id", |
| 1787 | + 1L, |
| 1788 | + 1L, |
| 1789 | + 1L, |
| 1790 | + true, |
| 1791 | + BytesReference.bytes(agentContent), |
| 1792 | + null, |
| 1793 | + null |
| 1794 | + ); |
| 1795 | + GetResponse getAgentResponseObj = new GetResponse(getResult); |
| 1796 | + XContentBuilder responseBuilder = XContentFactory.jsonBuilder(); |
| 1797 | + getAgentResponseObj.toXContent(responseBuilder, ToXContent.EMPTY_PARAMS); |
| 1798 | + return BytesReference.bytes(responseBuilder).utf8ToString(); |
| 1799 | + } |
| 1800 | + |
| 1801 | + private void mockSdkClientWithAgent(String getResponseJson) { |
| 1802 | + when(clusterService.localNode()).thenReturn(localNode); |
| 1803 | + when(localNode.getId()).thenReturn("test-node-id"); |
| 1804 | + when(sdkClient.getDataObjectAsync(any(), any())).thenAnswer(inv -> { |
| 1805 | + GetDataObjectResponse resp = mock(GetDataObjectResponse.class); |
| 1806 | + when(resp.parser()) |
| 1807 | + .thenReturn(XContentType.JSON.xContent().createParser(NamedXContentRegistry.EMPTY, null, getResponseJson)); |
| 1808 | + CompletionStage<GetDataObjectResponse> stage = mock(CompletionStage.class); |
| 1809 | + when(stage.whenComplete(any())).thenAnswer(cbInv -> { |
| 1810 | + BiConsumer<GetDataObjectResponse, Throwable> cb = cbInv.getArgument(0); |
| 1811 | + cb.accept(resp, null); |
| 1812 | + return stage; |
| 1813 | + }); |
| 1814 | + return stage; |
| 1815 | + }); |
| 1816 | + } |
| 1817 | + |
| 1818 | + private AgentMLInput buildV2AgentMLInput() { |
| 1819 | + ContentBlock textBlock = new ContentBlock(); |
| 1820 | + textBlock.setType(ContentType.TEXT); |
| 1821 | + textBlock.setText("What is machine learning?"); |
| 1822 | + Message message = new Message("user", Collections.singletonList(textBlock)); |
| 1823 | + AgentInput agentInput = new AgentInput(); |
| 1824 | + agentInput.setInput(Collections.singletonList(message)); |
| 1825 | + |
| 1826 | + Map<String, String> params = new HashMap<>(); |
| 1827 | + params.put(QUESTION, "What is machine learning?"); |
| 1828 | + RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); |
| 1829 | + return new AgentMLInput("test-agent-id", null, FunctionName.AGENT, agentInput, dataset, false); |
| 1830 | + } |
| 1831 | + |
| 1832 | + /** |
| 1833 | + * Covers the false branch of condition 1: agentType.isV2() == false. |
| 1834 | + * A non-V2 agent (CONVERSATIONAL) with no memory spec reaches executeAgent via the |
| 1835 | + * line-484 path (null memory, null inputMessages). The V2 condition fails immediately |
| 1836 | + * and execution falls through to the normal runner path. |
| 1837 | + */ |
| 1838 | + @Test |
| 1839 | + public void test_ExecuteAgent_NonV2Agent_DoesNotRouteToExecuteV2Agent() throws IOException { |
| 1840 | + when(clusterService.state().metadata().hasIndex(anyString())).thenReturn(true); |
| 1841 | + when(mlFeatureEnabledSetting.isRemoteAgenticMemoryEnabled()).thenReturn(false); |
| 1842 | + |
| 1843 | + MLAgent convAgent = MLAgent |
| 1844 | + .builder() |
| 1845 | + .name("test_conv_agent") |
| 1846 | + .type(MLAgentType.CONVERSATIONAL.name()) |
| 1847 | + .llm(LLMSpec.builder().modelId("gpt-4").build()) |
| 1848 | + // no memory spec → executeAgent called with null memory and null inputMessages |
| 1849 | + .createdTime(Instant.now()) |
| 1850 | + .lastUpdateTime(Instant.now()) |
| 1851 | + .build(); |
| 1852 | + |
| 1853 | + mockSdkClientWithAgent(serializeAgentToGetResponseJson(convAgent)); |
| 1854 | + |
| 1855 | + Map<String, String> params = new HashMap<>(); |
| 1856 | + params.put(QUESTION, "What is ML?"); |
| 1857 | + RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); |
| 1858 | + AgentMLInput agentMLInput = new AgentMLInput("test-agent-id", null, FunctionName.AGENT, dataset); |
| 1859 | + |
| 1860 | + mlAgentExecutor.execute(agentMLInput, listener, channel); |
| 1861 | + |
| 1862 | + // agentType.isV2() = false → V2 execution path must not be entered |
| 1863 | + verify(memory, never()).getStructuredMessages(any()); |
| 1864 | + verify(listener, timeout(5000)).onFailure(any()); |
| 1865 | + } |
| 1866 | + |
| 1867 | + /** |
| 1868 | + * Covers the false branch of condition 2: inputMessages == null (with agentType.isV2() true). |
| 1869 | + * A V2 agent with no memory spec reaches executeAgent via the line-484 path where |
| 1870 | + * inputMessages is hardcoded null. The condition fails at the second && operand. |
| 1871 | + */ |
| 1872 | + @Test |
| 1873 | + public void test_ExecuteAgent_V2Agent_WithNullInputMessages_DoesNotRouteToExecuteV2Agent() throws IOException { |
| 1874 | + when(clusterService.state().metadata().hasIndex(anyString())).thenReturn(true); |
| 1875 | + when(mlFeatureEnabledSetting.isRemoteAgenticMemoryEnabled()).thenReturn(false); |
| 1876 | + |
| 1877 | + MLAgent v2Agent = MLAgent |
| 1878 | + .builder() |
| 1879 | + .name("test_v2_no_memory") |
| 1880 | + .type(MLAgentType.CONVERSATIONAL_V2.name()) |
| 1881 | + // no memory spec → executeAgent called with null memory and null inputMessages |
| 1882 | + .createdTime(Instant.now()) |
| 1883 | + .lastUpdateTime(Instant.now()) |
| 1884 | + .build(); |
| 1885 | + |
| 1886 | + mockSdkClientWithAgent(serializeAgentToGetResponseJson(v2Agent)); |
| 1887 | + |
| 1888 | + // Old-style dataset input (no AgentInput) → inputMessages stays null inside executeAgent |
| 1889 | + Map<String, String> params = new HashMap<>(); |
| 1890 | + params.put(QUESTION, "What is ML?"); |
| 1891 | + RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); |
| 1892 | + AgentMLInput agentMLInput = new AgentMLInput("test-agent-id", null, FunctionName.AGENT, dataset); |
| 1893 | + |
| 1894 | + mlAgentExecutor.execute(agentMLInput, listener, channel); |
| 1895 | + |
| 1896 | + // agentType.isV2() = true, but inputMessages = null → V2 execution path must not be entered |
| 1897 | + verify(memory, never()).getStructuredMessages(any()); |
| 1898 | + verify(listener, timeout(5000)).onFailure(any()); |
| 1899 | + } |
| 1900 | + |
| 1901 | + @Test |
| 1902 | + public void test_ExecuteAgent_V2Agent_WithMcpConnector_McpDisabled_FailsWithMcpError() throws IOException { |
| 1903 | + when(clusterService.state().metadata().hasIndex(anyString())).thenReturn(true); |
| 1904 | + when(mlFeatureEnabledSetting.isRemoteAgenticMemoryEnabled()).thenReturn(false); |
| 1905 | + when(mlFeatureEnabledSetting.isMcpConnectorEnabled()).thenReturn(false); |
| 1906 | + |
| 1907 | + Map<String, String> agentParams = new HashMap<>(); |
| 1908 | + agentParams.put(MCP_CONNECTORS_FIELD, "[{\"id\":\"mcp-conn-1\"}]"); |
| 1909 | + |
| 1910 | + MLAgent v2Agent = MLAgent |
| 1911 | + .builder() |
| 1912 | + .name("test_v2_mcp_agent") |
| 1913 | + .type(MLAgentType.CONVERSATIONAL_V2.name()) |
| 1914 | + .description("V2 agent with MCP connector") |
| 1915 | + .parameters(agentParams) |
| 1916 | + .memory(new MLMemorySpec("conversation_index", null, 0, null)) |
| 1917 | + .createdTime(Instant.now()) |
| 1918 | + .lastUpdateTime(Instant.now()) |
| 1919 | + .build(); |
| 1920 | + |
| 1921 | + mockSdkClientWithAgent(serializeAgentToGetResponseJson(v2Agent)); |
| 1922 | + memoryFactoryMap.put("CONVERSATION_INDEX", mockMemoryFactory); |
| 1923 | + when(memory.getId()).thenReturn("test-memory-id"); |
| 1924 | + doAnswer(invocation -> { |
| 1925 | + ActionListener<ConversationIndexMemory> memListener = invocation.getArgument(1); |
| 1926 | + memListener.onResponse(memory); |
| 1927 | + return null; |
| 1928 | + }).when(mockMemoryFactory).create(any(), any()); |
| 1929 | + |
| 1930 | + mlAgentExecutor.execute(buildV2AgentMLInput(), listener, channel); |
| 1931 | + |
| 1932 | + // MCP check must fire before V2 routing — listener.onFailure called with MCP error |
| 1933 | + verify(listener, timeout(5000)).onFailure(argThat(e -> e instanceof OpenSearchException |
| 1934 | + && e.getMessage().contains(ML_COMMONS_MCP_CONNECTOR_DISABLED_MESSAGE))); |
| 1935 | + // executeV2Agent must NOT have been reached |
| 1936 | + verify(memory, never()).getStructuredMessages(any()); |
| 1937 | + } |
| 1938 | + |
| 1939 | + @Test |
| 1940 | + public void test_ExecuteAgent_V2Agent_WithMcpConnector_McpEnabled_RoutesToExecuteV2Agent() throws IOException { |
| 1941 | + when(clusterService.state().metadata().hasIndex(anyString())).thenReturn(true); |
| 1942 | + when(mlFeatureEnabledSetting.isRemoteAgenticMemoryEnabled()).thenReturn(false); |
| 1943 | + when(mlFeatureEnabledSetting.isMcpConnectorEnabled()).thenReturn(true); |
| 1944 | + |
| 1945 | + Map<String, String> agentParams = new HashMap<>(); |
| 1946 | + agentParams.put(MCP_CONNECTORS_FIELD, "[{\"id\":\"mcp-conn-1\"}]"); |
| 1947 | + |
| 1948 | + MLAgent v2Agent = MLAgent |
| 1949 | + .builder() |
| 1950 | + .name("test_v2_mcp_agent") |
| 1951 | + .type(MLAgentType.CONVERSATIONAL_V2.name()) |
| 1952 | + .description("V2 agent with MCP connector, MCP enabled") |
| 1953 | + .parameters(agentParams) |
| 1954 | + .memory(new MLMemorySpec("conversation_index", null, 0, null)) |
| 1955 | + .createdTime(Instant.now()) |
| 1956 | + .lastUpdateTime(Instant.now()) |
| 1957 | + .build(); |
| 1958 | + |
| 1959 | + mockSdkClientWithAgent(serializeAgentToGetResponseJson(v2Agent)); |
| 1960 | + memoryFactoryMap.put("CONVERSATION_INDEX", mockMemoryFactory); |
| 1961 | + when(memory.getId()).thenReturn("test-memory-id"); |
| 1962 | + doAnswer(invocation -> { |
| 1963 | + ActionListener<ConversationIndexMemory> memListener = invocation.getArgument(1); |
| 1964 | + memListener.onResponse(memory); |
| 1965 | + return null; |
| 1966 | + }).when(mockMemoryFactory).create(any(), any()); |
| 1967 | + |
| 1968 | + // Fail fast inside executeV2Agent to verify the path was reached |
| 1969 | + doAnswer(invocation -> { |
| 1970 | + ActionListener<List<Message>> msgListener = invocation.getArgument(0); |
| 1971 | + msgListener.onFailure(new RuntimeException("reached_executeV2Agent")); |
| 1972 | + return null; |
| 1973 | + }).when(memory).getStructuredMessages(any()); |
| 1974 | + |
| 1975 | + mlAgentExecutor.execute(buildV2AgentMLInput(), listener, channel); |
| 1976 | + |
| 1977 | + // MCP check passes (enabled), so execution must reach executeV2Agent |
| 1978 | + verify(memory, timeout(5000).atLeastOnce()).getStructuredMessages(any()); |
| 1979 | + } |
1684 | 1980 | } |
0 commit comments