Skip to content

Commit e5690af

Browse files
authored
restore AGUI context for legacy interface agent (#4720)
Signed-off-by: Jiaping Zeng <jpz@amazon.com>
1 parent c1edcd8 commit e5690af

File tree

2 files changed

+78
-6
lines changed

2 files changed

+78
-6
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1347,6 +1347,13 @@ void processAgentInput(AgentMLInput agentMLInput, MLAgent mlAgent) {
13471347
remoteDataSet.getParameters().putAll(parameters);
13481348
} else {
13491349
// For old-style AG_UI agents without model field
1350+
// Prepend context to question if available
1351+
if (agentType == MLAgentType.AG_UI) {
1352+
String context = remoteDataSet.getParameters().get(AGUI_PARAM_CONTEXT);
1353+
if (context != null && !context.isEmpty()) {
1354+
question = "Context: " + context + "\nQuestion: " + question;
1355+
}
1356+
}
13501357
remoteDataSet.getParameters().putIfAbsent(QUESTION, question);
13511358
}
13521359
} catch (Exception e) {

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java

Lines changed: 71 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import static org.junit.Assert.*;
99
import static org.mockito.Mockito.*;
10+
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_CONTEXT;
1011
import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.QUESTION;
1112

1213
import java.io.IOException;
@@ -595,33 +596,97 @@ public void test_ProcessAgentInput_AGUIAgent_WithoutContext() {
595596

596597
@Test
597598
public void test_ProcessAgentInput_AGUIAgent_WithContext_LegacyInterface() {
598-
// AGUI agent with legacy LLM interface
599-
// Context has already been appended by AGUIInputConverter before reaching MLAgentExecutor
599+
// AGUI agent with legacy LLM interface (no model field)
600+
// Context is passed via AGUI_PARAM_CONTEXT and should be prepended to question
600601
MLAgent agent = MLAgent
601602
.builder()
602603
.name("agui_agent_legacy_context")
603604
.type(MLAgentType.AG_UI.name())
604605
.llm(LLMSpec.builder().modelId("gpt-4").build())
605606
.build();
606607

607-
// Simulate message with context already appended (as done by AGUIInputConverter)
608608
ContentBlock textBlock = new ContentBlock();
609609
textBlock.setType(ContentType.TEXT);
610-
textBlock.setText("Context:\n- Location: San Francisco\n\nWhat is the weather?");
610+
textBlock.setText("What is the weather?");
611611
Message message = new Message("user", Collections.singletonList(textBlock));
612612
AgentInput agentInput = new AgentInput();
613613
agentInput.setInput(Collections.singletonList(message));
614614
AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, agentInput, null, false);
615615

616+
// Set context via params (as AGUIInputConverter stores it)
617+
Map<String, String> params = new HashMap<>();
618+
params.put(AGUI_PARAM_CONTEXT, "[{\"description\":\"Location\",\"value\":\"San Francisco\"}]");
619+
agentMLInput.setInputDataset(new RemoteInferenceInputDataSet(params));
620+
616621
mlAgentExecutor.processAgentInput(agentMLInput, agent);
617622

618-
// Verify question contains context that was already appended
623+
// Verify question contains context prepended by the new code path
619624
RemoteInferenceInputDataSet dataset = (RemoteInferenceInputDataSet) agentMLInput.getInputDataset();
620625
String question = dataset.getParameters().get(QUESTION);
621626
Assert.assertNotNull(question);
622-
Assert.assertTrue(question.contains("Context:"));
627+
Assert.assertTrue(question.startsWith("Context: "));
623628
Assert.assertTrue(question.contains("San Francisco"));
629+
Assert.assertTrue(question.contains("Question: "));
630+
Assert.assertTrue(question.contains("What is the weather?"));
631+
}
632+
633+
@Test
634+
public void test_ProcessAgentInput_AGUIAgent_NoContext_LegacyInterface() {
635+
// AGUI agent with legacy LLM interface, no context param
636+
MLAgent agent = MLAgent
637+
.builder()
638+
.name("agui_agent_legacy_no_context")
639+
.type(MLAgentType.AG_UI.name())
640+
.llm(LLMSpec.builder().modelId("gpt-4").build())
641+
.build();
642+
643+
ContentBlock textBlock = new ContentBlock();
644+
textBlock.setType(ContentType.TEXT);
645+
textBlock.setText("What is the weather?");
646+
Message message = new Message("user", Collections.singletonList(textBlock));
647+
AgentInput agentInput = new AgentInput();
648+
agentInput.setInput(Collections.singletonList(message));
649+
AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, agentInput, null, false);
650+
651+
mlAgentExecutor.processAgentInput(agentMLInput, agent);
652+
653+
RemoteInferenceInputDataSet dataset = (RemoteInferenceInputDataSet) agentMLInput.getInputDataset();
654+
String question = dataset.getParameters().get(QUESTION);
655+
Assert.assertNotNull(question);
624656
Assert.assertTrue(question.contains("What is the weather?"));
657+
Assert.assertFalse(question.contains("Context:"));
658+
}
659+
660+
@Test
661+
public void test_ProcessAgentInput_AGUIAgent_EmptyContext_LegacyInterface() {
662+
// AGUI agent with legacy LLM interface, empty context string
663+
MLAgent agent = MLAgent
664+
.builder()
665+
.name("agui_agent_legacy_bad_context")
666+
.type(MLAgentType.AG_UI.name())
667+
.llm(LLMSpec.builder().modelId("gpt-4").build())
668+
.build();
669+
670+
ContentBlock textBlock = new ContentBlock();
671+
textBlock.setType(ContentType.TEXT);
672+
textBlock.setText("What is the weather?");
673+
Message message = new Message("user", Collections.singletonList(textBlock));
674+
AgentInput agentInput = new AgentInput();
675+
agentInput.setInput(Collections.singletonList(message));
676+
AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, agentInput, null, false);
677+
678+
// Empty context should not be prepended
679+
Map<String, String> params = new HashMap<>();
680+
params.put(AGUI_PARAM_CONTEXT, "");
681+
agentMLInput.setInputDataset(new RemoteInferenceInputDataSet(params));
682+
683+
mlAgentExecutor.processAgentInput(agentMLInput, agent);
684+
685+
RemoteInferenceInputDataSet dataset = (RemoteInferenceInputDataSet) agentMLInput.getInputDataset();
686+
String question = dataset.getParameters().get(QUESTION);
687+
Assert.assertNotNull(question);
688+
Assert.assertTrue(question.contains("What is the weather?"));
689+
Assert.assertFalse(question.contains("Context:"));
625690
}
626691

627692
@Test

0 commit comments

Comments
 (0)