Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

package org.opensearch.ml.engine.algorithms.agent;

import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_RUN_ID;
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_THREAD_ID;
import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.util.ArrayList;
Expand All @@ -20,7 +18,6 @@
import org.opensearch.ml.common.agui.AGUIInputConverter;
import org.opensearch.ml.common.agui.BaseEvent;
import org.opensearch.ml.common.agui.MessagesSnapshotEvent;
import org.opensearch.ml.common.agui.RunFinishedEvent;
import org.opensearch.ml.common.agui.ToolCallResultEvent;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.execute.agent.Message;
Expand Down Expand Up @@ -164,25 +161,9 @@ public void sendBackendToolResult(String toolCallId, String toolResult, String s
}

public void sendRunFinishedAndCloseStream(String sessionId, String parentInteractionId) {
try {
String threadId = parameters.get(AGUI_PARAM_THREAD_ID);
String runId = parameters.get(AGUI_PARAM_RUN_ID);

// Ensure non-null values to avoid NPE in RunFinishedEvent.writeTo()
if (threadId == null) {
log.warn("AG-UI threadId is null, using generated value. This may cause frontend errors.");
threadId = "thread_" + System.nanoTime();
}
if (runId == null) {
log.warn("AG-UI runId is null, using generated value. This may cause frontend errors.");
runId = "run_" + System.nanoTime();
}

BaseEvent runFinishedEvent = new RunFinishedEvent(threadId, runId, null);
sendAGUIEvent(runFinishedEvent, true);
} catch (Exception e) {
log.error("Failed to send run finished event and close stream", e);
}
// Send an empty completion chunk with is_last=true.
// RestMLExecuteStreamAction will emit RUN_FINISHED when it sees the final chunk.
sendCompletionChunk(sessionId, parentInteractionId);
}

public void sendMessagesSnapshot(List<Message> history, String memoryId, ActionListener<Object> listener) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
package org.opensearch.ml.engine.algorithms.remote.streaming;

import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_MESSAGE_ID;
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_RUN_ID;
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_TEXT_MESSAGE_STARTED;
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_THREAD_ID;
import static org.opensearch.ml.common.utils.StringUtils.gson;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS;

Expand All @@ -21,7 +19,6 @@
import java.util.concurrent.atomic.AtomicBoolean;

import org.opensearch.ml.common.agui.BaseEvent;
import org.opensearch.ml.common.agui.RunFinishedEvent;
import org.opensearch.ml.common.agui.TextMessageContentEvent;
import org.opensearch.ml.common.agui.TextMessageEndEvent;
import org.opensearch.ml.common.agui.TextMessageStartEvent;
Expand Down Expand Up @@ -244,21 +241,12 @@ private void handleDoneEvent() {
&& "true".equalsIgnoreCase(parameters.get(AGUI_PARAM_TEXT_MESSAGE_STARTED));

if (textMessageStarted) {
// End any remaining text message
parameters.put(AGUI_PARAM_TEXT_MESSAGE_STARTED, "false");
BaseEvent textMessageEndEvent = new TextMessageEndEvent(messageId);
sendAGUIEvent(textMessageEndEvent, false, streamActionListener);
log.debug("AG-UI: Sent TEXT_MESSAGE_END for messageId: {} at stream end", messageId);
}

// Send RUN_FINISHED event
String threadId = parameters.get(AGUI_PARAM_THREAD_ID);
String runId = parameters.get(AGUI_PARAM_RUN_ID);
BaseEvent runFinishedEvent = new RunFinishedEvent(threadId, runId, null);
sendAGUIEvent(runFinishedEvent, false, streamActionListener);
log.debug("AG-UI: Sent RUN_FINISHED event at [DONE] - threadId={}, runId={}", threadId, runId);

// Trigger agentListener callback to save assistant structured message
streamActionListener.onResponse(createFinalAnswerResponse(accumulatedContent.toString()));
}

Expand All @@ -277,21 +265,12 @@ private void processStreamChunk(Map<String, Object> dataMap) {

if (isAGUIAgent) {
if (textMessageStarted) {
// End the current text message
parameters.put(AGUI_PARAM_TEXT_MESSAGE_STARTED, "false");
BaseEvent textMessageEndEvent = new TextMessageEndEvent(messageId);
sendAGUIEvent(textMessageEndEvent, false, streamActionListener);
log.debug("AG-UI: Sent TEXT_MESSAGE_END for messageId: {}", messageId);
}

// Send RUN_FINISHED event
String threadId = parameters.get(AGUI_PARAM_THREAD_ID);
String runId = parameters.get(AGUI_PARAM_RUN_ID);
BaseEvent runFinishedEvent = new RunFinishedEvent(threadId, runId, null);
sendAGUIEvent(runFinishedEvent, false, streamActionListener);
log.debug("AG-UI: Sent RUN_FINISHED event - threadId={}, runId={}", threadId, runId);

// Trigger agentListener callback to save assistant structured message
streamActionListener.onResponse(createFinalAnswerResponse(accumulatedContent.toString()));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -486,11 +486,12 @@ private HttpChunk convertToHttpChunk(MLTaskResponse response, boolean isAGUIAgen
}

// Forward any content events (pass false — isLast is controlled by the outer chunk)
HttpChunk contentChunk = convertToAGUIEvent(content, false);
combinedSse.append(new String(BytesReference.toBytes(contentChunk.content())));
AGUIEventResult contentResult = convertToAGUIEvent(content, false);
combinedSse.append(new String(BytesReference.toBytes(contentResult.chunk().content())));

// RunFinished is the last AG-UI event, emitted only on the final chunk
if (isLast) {
// RunFinished is the last AG-UI event, emitted only on the final chunk.
// Skip if a RUN_ERROR was already emitted — RUN_ERROR is a terminal event.
if (isLast && !contentResult.hasRunError()) {
BaseEvent runFinishedEvent = new RunFinishedEvent(threadId, runId, null);
combinedSse.append("data: ").append(runFinishedEvent.toJsonString()).append("\n\n");
}
Expand Down Expand Up @@ -571,7 +572,10 @@ private String extractTensorResult(MLTaskResponse response, String tensorName) {
return null;
}

private HttpChunk convertToAGUIEvent(String content, boolean isLast) {
private record AGUIEventResult(HttpChunk chunk, boolean hasRunError) {
}

private AGUIEventResult convertToAGUIEvent(String content, boolean isLast) {
log
.debug(
"RestMLExecuteStreamAction: convertToAGUIEvent() called - contentLength={}, isLast={}",
Expand All @@ -580,6 +584,7 @@ private HttpChunk convertToAGUIEvent(String content, boolean isLast) {
);

StringBuilder sseResponse = new StringBuilder();
boolean hasRunError = false;

if (content != null && !content.isEmpty()) {
log.debug("RestMLExecuteStreamAction: Processing content: '{}'", content);
Expand All @@ -590,25 +595,26 @@ private HttpChunk convertToAGUIEvent(String content, boolean isLast) {
sseResponse.append("data: ").append(element).append("\n\n");
log.debug("RestMLExecuteStreamAction: Processing json element: '{}'", element);
} else {
// catch unexpected content chunks such as Bedrock error
log.warn("Unexpected content received - not valid JSON: {}", content);
BaseEvent runErrorEvent = new RunErrorEvent("Unexpected chunk: " + content, null);
sseResponse.append("data: ").append(runErrorEvent.toJsonString()).append("\n\n");
isLast = true;
hasRunError = true;
}
} catch (Exception e) {
log.error("Failed to process AG-UI events chunk content {}", content, e);
BaseEvent runErrorEvent = new RunErrorEvent("Unexpected error: " + e.getMessage(), null);
sseResponse.append("data: ").append(runErrorEvent.toJsonString()).append("\n\n");
isLast = true;
hasRunError = true;
}
} else {
log.warn("Received null or empty AG-UI content chunk");
}

String finalSse = sseResponse.toString();
log.debug("RestMLExecuteStreamAction: Returning chunk - length={}", finalSse.length());
return createHttpChunk(finalSse, isLast);
return new AGUIEventResult(createHttpChunk(finalSse, isLast), hasRunError);
}

@VisibleForTesting
Expand Down
Loading