Skip to content

Commit 9b73513

Browse files
authored
Introduce V2 Chat Agent (#4732)
* feat: introduce v2 agents Signed-off-by: Pavan Yekbote <pybot@amazon.com> * fix: spotless Signed-off-by: Pavan Yekbote <pybot@amazon.com> * fix: executor post rebase Signed-off-by: Pavan Yekbote <pybot@amazon.com> * fix: order of setting params Signed-off-by: Pavan Yekbote <pybot@amazon.com> * fix: support tool use and store in memory, add validation for streaming & context management, remove dead code Signed-off-by: Pavan Yekbote <pybot@amazon.com> * fix: test cases Signed-off-by: Pavan Yekbote <pybot@amazon.com> * fix: add test coverage and fix issues raised by review Signed-off-by: Pavan Yekbote <pybot@amazon.com> * fix: apply spotless Signed-off-by: Pavan Yekbote <pybot@amazon.com> * fix: java doc Signed-off-by: Pavan Yekbote <pybot@amazon.com> * tests: add test coverage Signed-off-by: Pavan Yekbote <pybot@amazon.com> * tests: add test coverage for agenttype Signed-off-by: Pavan Yekbote <pybot@amazon.com> * fix: v2 runner test Signed-off-by: Pavan Yekbote <pybot@amazon.com> --------- Signed-off-by: Pavan Yekbote <pybot@amazon.com>
1 parent aa04a63 commit 9b73513

26 files changed

+4970
-114
lines changed

common/src/main/java/org/opensearch/ml/common/MLAgentType.java

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,16 @@
88
import java.util.Locale;
99

1010
public enum MLAgentType {
11+
// V1 Agent Types (legacy)
1112
FLOW,
1213
CONVERSATIONAL,
1314
CONVERSATIONAL_FLOW,
1415
PLAN_EXECUTE_AND_REFLECT,
15-
AG_UI;
16+
AG_UI,
17+
18+
// V2 Agent Types (simplified, message-centric)
19+
// Add to isV2() method when a new V2 agent is added
20+
CONVERSATIONAL_V2;
1621

1722
public static MLAgentType from(String value) {
1823
if (value == null) {
@@ -24,4 +29,15 @@ public static MLAgentType from(String value) {
2429
throw new IllegalArgumentException("Wrong Agent type");
2530
}
2631
}
32+
33+
/**
34+
* Check if this is a V2 agent type.
35+
* V2 agents use message-centric architecture, require agentic memory,
36+
* and support standardized input/output formats.
37+
*
38+
* @return true if this is a V2 agent type, false otherwise
39+
*/
40+
public boolean isV2() {
41+
return this == CONVERSATIONAL_V2;
42+
}
2743
}

common/src/main/java/org/opensearch/ml/common/agent/BedrockConverseModelProvider.java

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -410,27 +410,34 @@ private String mapSourceTypeToBedrock(SourceType sourceType, String dataUrl) {
410410
}
411411

412412
/**
413-
* Parses a Bedrock Converse response message into a unified Message object.
414-
* Handles three content item types within the "content" array:
413+
* Extracts the message JSON from a Bedrock Converse API response.
414+
* Navigates the response structure to retrieve the message object at output.message.
415415
*
416-
* 1. Text response:
417-
* {"role": "assistant", "content": [{"text": "Here is the result..."}]}
416+
* Expected response structure:
417+
* {"output": {"message": {"role": "assistant", "content": [...]}}}
418418
*
419-
* 2. Tool call request:
420-
* {"role": "assistant", "content": [
421-
* {"toolUse": {"toolUseId": "tool_abc123", "name": "get_weather",
422-
* "input": {"location": "Seattle"}}}
423-
* ]}
424-
*
425-
* 3. Tool result (stored as role=user by Bedrock, mapped to role=tool for unified format):
426-
* {"role": "user", "content": [
427-
* {"toolResult": {"toolUseId": "tool_abc123",
428-
* "content": [{"text": "72°F, sunny"}]}}
429-
* ]}
430-
*
431-
* @param json JSON string containing the Bedrock Converse response message
432-
* @return a unified Message object, or null if the input cannot be parsed
419+
* @param responseData Map containing the Bedrock Converse API response data
420+
* @return JSON string representation of the message object, or null if not found
433421
*/
422+
@Override
423+
public String extractMessageFromResponse(Map<String, ?> responseData) {
424+
if (responseData == null) {
425+
return null;
426+
}
427+
428+
Object outputObj = responseData.get("output");
429+
if (outputObj instanceof Map) {
430+
@SuppressWarnings("unchecked")
431+
Map<String, ?> outputMap = (Map<String, ?>) outputObj;
432+
Object messageObj = outputMap.get("message");
433+
if (messageObj != null) {
434+
return StringUtils.toJson(messageObj);
435+
}
436+
}
437+
438+
return null;
439+
}
440+
434441
@SuppressWarnings("unchecked")
435442
@Override
436443
public Message parseToUnifiedMessage(String json) {

common/src/main/java/org/opensearch/ml/common/agent/GeminiV1BetaGenerateContentModelProvider.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,27 @@ private String mapImageSourceTypeToGemini(SourceType sourceType) {
297297
};
298298
}
299299

300+
@Override
301+
public String extractMessageFromResponse(Map<String, ?> responseData) {
302+
if (responseData == null) {
303+
return null;
304+
}
305+
306+
Object candidatesObj = responseData.get("candidates");
307+
if (candidatesObj instanceof List) {
308+
List<?> candidatesList = (List<?>) candidatesObj;
309+
if (!candidatesList.isEmpty() && candidatesList.get(0) instanceof Map) {
310+
Map<String, ?> candidateMap = (Map<String, ?>) candidatesList.get(0);
311+
Object contentObj = candidateMap.get("content");
312+
if (contentObj != null) {
313+
return org.opensearch.ml.common.utils.StringUtils.toJson(contentObj);
314+
}
315+
}
316+
}
317+
318+
return null;
319+
}
320+
300321
@Override
301322
public Message parseToUnifiedMessage(String json) {
302323
throw new UnsupportedOperationException("parseToUnifiedMessage is not yet supported for Gemini model provider");

common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,16 @@ public Tags getTags() {
505505
return tags;
506506
}
507507

508+
/**
509+
* Check if this agent uses the unified agent interface.
510+
* Unified interface agents have the 'model' field configured and use simplified registration.
511+
*
512+
* @return true if agent was registered with unified interface (has model field)
513+
*/
514+
public boolean usesUnifiedInterface() {
515+
return model != null;
516+
}
517+
508518
/**
509519
* Check if this agent has context management configuration
510520
* @return true if agent has either context management name or inline configuration

common/src/main/java/org/opensearch/ml/common/agent/OpenaiV1ChatCompletionsModelProvider.java

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -354,24 +354,38 @@ private String mapImageSourceTypeToOpenAI(SourceType sourceType) {
354354
}
355355

356356
/**
357-
* Parses an OpenAI chat completions response message into a unified Message object.
358-
* Handles three message types:
357+
* Extracts the message JSON from an OpenAI Chat Completions API response.
358+
* Navigates the response structure to retrieve the message object at choices[0].message.
359359
*
360-
* 1. Assistant text response:
361-
* {"role": "assistant", "content": "Here is the result..."}
360+
* Expected response structure:
361+
* {"choices": [{"message": {"role": "assistant", "content": "..."}}]}
362362
*
363-
* 2. Assistant tool call request:
364-
* {"role": "assistant", "content": null, "tool_calls": [
365-
* {"id": "call_abc123", "type": "function",
366-
* "function": {"name": "get_weather", "arguments": "{\"location\":\"Seattle\"}"}}
367-
* ]}
368-
*
369-
* 3. Tool result message:
370-
* {"role": "tool", "tool_call_id": "call_abc123", "content": "72°F, sunny"}
371-
*
372-
* @param json JSON string containing the OpenAI response message
373-
* @return a unified Message object, or null if the input cannot be parsed
363+
* @param responseData Map containing the OpenAI Chat Completions API response data
364+
* @return JSON string representation of the message object, or null if not found
374365
*/
366+
@Override
367+
public String extractMessageFromResponse(Map<String, ?> responseData) {
368+
if (responseData == null) {
369+
return null;
370+
}
371+
372+
Object choicesObj = responseData.get("choices");
373+
if (choicesObj instanceof List) {
374+
@SuppressWarnings("unchecked")
375+
List<?> choicesList = (List<?>) choicesObj;
376+
if (!choicesList.isEmpty() && choicesList.get(0) instanceof Map) {
377+
@SuppressWarnings("unchecked")
378+
Map<String, ?> choiceMap = (Map<String, ?>) choicesList.get(0);
379+
Object messageObj = choiceMap.get("message");
380+
if (messageObj != null) {
381+
return StringUtils.toJson(messageObj);
382+
}
383+
}
384+
}
385+
386+
return null;
387+
}
388+
375389
@SuppressWarnings("unchecked")
376390
@Override
377391
public Message parseToUnifiedMessage(String json) {

common/src/main/java/org/opensearch/ml/common/model/ModelProvider.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,15 @@ public abstract class ModelProvider {
7474
*/
7575
public abstract Map<String, String> mapMessages(List<Message> messages, MLAgentType type);
7676

77+
/**
78+
* Extracts the message portion from a full LLM response.
79+
* Each provider knows its own response format (e.g., Bedrock: output.message, OpenAI: choices[0].message).
80+
*
81+
* @param responseData the full LLM response data map
82+
* @return JSON string of just the message portion, or null if extraction fails
83+
*/
84+
public abstract String extractMessageFromResponse(Map<String, ?> responseData);
85+
7786
/**
7887
* Parses an provider-specific format response message JSON string into a unified Message object.
7988
*

common/src/main/java/org/opensearch/ml/common/output/MLOutputType.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@ public enum MLOutputType {
1111
SAMPLE_ALGO,
1212
MODEL_TENSOR,
1313
MCORR_TENSOR,
14-
ML_TASK_OUTPUT
14+
ML_TASK_OUTPUT,
15+
AGENT_V2
1516
}
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.output.execute.agent;
7+
8+
import static org.opensearch.ml.common.utils.StringUtils.gson;
9+
10+
import java.io.IOException;
11+
import java.util.HashMap;
12+
import java.util.Map;
13+
14+
import org.opensearch.core.common.io.stream.StreamInput;
15+
import org.opensearch.core.common.io.stream.StreamOutput;
16+
import org.opensearch.core.xcontent.XContentBuilder;
17+
import org.opensearch.ml.common.annotation.MLAlgoOutput;
18+
import org.opensearch.ml.common.input.execute.agent.Message;
19+
import org.opensearch.ml.common.output.MLOutput;
20+
import org.opensearch.ml.common.output.MLOutputType;
21+
22+
import lombok.Builder;
23+
import lombok.Data;
24+
import lombok.EqualsAndHashCode;
25+
26+
/**
27+
* Standardized output format for V2 agents (CONVERSATIONAL_V2).
28+
* Follows Strands-style response format with structured fields.
29+
*/
30+
@Data
31+
@EqualsAndHashCode(callSuper = false)
32+
@MLAlgoOutput(MLOutputType.AGENT_V2)
33+
public class AgentV2Output extends MLOutput {
34+
35+
private static final MLOutputType OUTPUT_TYPE = MLOutputType.AGENT_V2;
36+
37+
public static final String STOP_REASON_FIELD = "stop_reason";
38+
public static final String MESSAGE_FIELD = "message";
39+
public static final String MEMORY_ID_FIELD = "memory_id";
40+
public static final String METRICS_FIELD = "metrics";
41+
42+
/**
43+
* The reason the agent stopped execution.
44+
* Values: "end_turn", "max_iterations", "tool_use"
45+
*/
46+
private String stopReason;
47+
48+
/**
49+
* The assistant's response message with content blocks.
50+
*/
51+
private Message message;
52+
53+
/**
54+
* The memory/conversation ID for session tracking.
55+
*/
56+
private String memoryId;
57+
58+
/**
59+
* Execution metrics (token usage, latency, etc.).
60+
*/
61+
@Builder.Default
62+
private Map<String, Object> metrics = new HashMap<>();
63+
64+
@Builder
65+
public AgentV2Output(String stopReason, Message message, String memoryId, Map<String, Object> metrics) {
66+
super(OUTPUT_TYPE);
67+
this.stopReason = stopReason;
68+
this.message = message;
69+
this.memoryId = memoryId;
70+
this.metrics = metrics != null ? metrics : new HashMap<>();
71+
}
72+
73+
public AgentV2Output(StreamInput in) throws IOException {
74+
super(OUTPUT_TYPE);
75+
this.stopReason = in.readOptionalString();
76+
// Deserialize Message from JSON string
77+
String messageJson = in.readOptionalString();
78+
this.message = messageJson != null ? gson.fromJson(messageJson, Message.class) : null;
79+
this.memoryId = in.readOptionalString();
80+
this.metrics = in.readBoolean() ? in.readMap() : new HashMap<>();
81+
}
82+
83+
@Override
84+
public void writeTo(StreamOutput out) throws IOException {
85+
super.writeTo(out);
86+
out.writeOptionalString(stopReason);
87+
// Serialize Message as JSON string
88+
String messageJson = message != null ? gson.toJson(message) : null;
89+
out.writeOptionalString(messageJson);
90+
out.writeOptionalString(memoryId);
91+
if (metrics != null && !metrics.isEmpty()) {
92+
out.writeBoolean(true);
93+
out.writeMap(metrics);
94+
} else {
95+
out.writeBoolean(false);
96+
}
97+
}
98+
99+
@Override
100+
protected MLOutputType getType() {
101+
return OUTPUT_TYPE;
102+
}
103+
104+
@Override
105+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
106+
builder.startObject();
107+
if (stopReason != null) {
108+
builder.field(STOP_REASON_FIELD, stopReason);
109+
}
110+
if (message != null) {
111+
builder.startObject(MESSAGE_FIELD);
112+
113+
// Write content blocks (simplified format without "type" field)
114+
if (message.getContent() != null && !message.getContent().isEmpty()) {
115+
builder.startArray("content");
116+
for (var contentBlock : message.getContent()) {
117+
builder.startObject();
118+
if (contentBlock.getText() != null) {
119+
builder.field("text", contentBlock.getText());
120+
}
121+
builder.endObject();
122+
}
123+
builder.endArray();
124+
}
125+
126+
// Write role
127+
if (message.getRole() != null) {
128+
builder.field("role", message.getRole());
129+
}
130+
131+
builder.endObject();
132+
}
133+
if (memoryId != null) {
134+
builder.field(MEMORY_ID_FIELD, memoryId);
135+
}
136+
if (metrics != null && !metrics.isEmpty()) {
137+
builder.field(METRICS_FIELD, metrics);
138+
}
139+
builder.endObject();
140+
return builder;
141+
}
142+
}

0 commit comments

Comments
 (0)