Skip to content

Commit 9c3dd69

Browse files
authored
Merge pull request #45 from rongxinzy/techdebt/general-agent-mainline
Refactor general agent event streaming
2 parents a93c6af + b68be11 commit 9c3dd69

File tree

2 files changed

+337
-128
lines changed

2 files changed

+337
-128
lines changed

swarmmind/agents/general_agent.py

Lines changed: 177 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import logging
1010
import uuid
1111
from collections.abc import AsyncGenerator, Generator
12+
from dataclasses import dataclass, field
1213
from typing import Any
1314

1415
import deerflow.client as deerflow_client_module
@@ -26,6 +27,18 @@
2627
logger = logging.getLogger(__name__)
2728

2829

30+
@dataclass
class _StreamCaptureState:
    """Mutable capture state for a single async DeerFlow turn."""

    # Id of the LLM invocation currently being streamed; None until the
    # first chunk (or a synthesized uuid) establishes one.
    current_chunk_msg_id: str | None = None
    # Running concatenation of reasoning-token deltas for the current invocation.
    accumulated_reasoning: str = ""
    # Running concatenation of content-token deltas for the current invocation.
    accumulated_content: str = ""
    # Last complete assistant text observed; falls back to accumulated_content.
    final_text: str = ""
    # Human-readable "[tool]: preview" summaries collected from tool messages.
    tool_results: list[str] = field(default_factory=list)
    # Message ids already emitted from "values" snapshots, to avoid duplicates.
    seen_ids: set[str] = field(default_factory=set)
40+
41+
2942
class SwarmMindDeerFlowClient(DeerFlowClient):
3043
"""DeerFlow client wrapper that injects SwarmMind product identity."""
3144

@@ -364,14 +377,7 @@ async def _astream_events(
364377
state: dict[str, Any] = {"messages": [HumanMessage(content=goal, id=current_user_message_id)]}
365378
runtime_context = {"thread_id": thread_id}
366379

367-
seen_ids: set[str] = set()
368-
final_text = ""
369-
tool_results: list[str] = []
370-
371-
# Token-level streaming accumulators (reset per LLM invocation)
372-
current_chunk_msg_id: str | None = None
373-
accumulated_reasoning = ""
374-
accumulated_content = ""
380+
capture_state = _StreamCaptureState()
375381

376382
async for mode_tag, chunk in self._client._agent.astream(
377383
state,
@@ -381,133 +387,27 @@ async def _astream_events(
381387
):
382388
if mode_tag == "messages":
383389
msg_chunk, _metadata = chunk
384-
if not isinstance(msg_chunk, AIMessageChunk):
385-
continue
386-
387-
chunk_id = getattr(msg_chunk, "id", None)
388-
if chunk_id and chunk_id != current_chunk_msg_id:
389-
# New LLM invocation started; reset accumulators
390-
current_chunk_msg_id = chunk_id
391-
accumulated_reasoning = ""
392-
accumulated_content = ""
393-
394-
if not current_chunk_msg_id:
395-
current_chunk_msg_id = str(uuid.uuid4())
396-
397-
# Stream reasoning tokens
398-
reasoning_delta = self._extract_reasoning_delta(msg_chunk)
399-
if reasoning_delta:
400-
accumulated_reasoning += reasoning_delta
401-
yield {
402-
"type": "assistant_reasoning",
403-
"message_id": current_chunk_msg_id,
404-
"content": accumulated_reasoning,
405-
}
406-
407-
# Stream content tokens
408-
content_delta = self._extract_content_delta(msg_chunk)
409-
if content_delta:
410-
accumulated_content += content_delta
411-
yield {
412-
"type": "assistant_message",
413-
"message_id": current_chunk_msg_id,
414-
"content": accumulated_content,
415-
}
390+
for event in self._process_messages_mode_chunk(msg_chunk, capture_state):
391+
yield event
416392

417393
elif mode_tag == "custom":
418-
# Handle custom events from task_tool (task_started, task_running, task_completed, task_failed)
419-
event = chunk
420-
logger.debug("Custom event received: %s", event)
421-
if isinstance(event, dict) and event.get("type") in (
422-
"task_started",
423-
"task_running",
424-
"task_completed",
425-
"task_failed",
426-
):
427-
logger.info("Task event: type=%s, task_id=%s", event.get("type"), event.get("task_id"))
428-
yield {
429-
"type": "custom_event",
430-
"event_type": event["type"],
431-
"task_id": event.get("task_id"),
432-
"description": event.get("description"),
433-
"message": event.get("message"),
434-
"result": event.get("result"),
435-
"error": event.get("error"),
436-
}
394+
event = self._process_custom_mode_chunk(chunk)
395+
if event is not None:
396+
yield event
437397

438398
elif mode_tag == "values":
439399
messages = chunk.get("messages", [])
440-
turn_anchor_index = next(
441-
(
442-
index
443-
for index, message in enumerate(messages)
444-
if isinstance(message, HumanMessage) and getattr(message, "id", None) == current_user_message_id
445-
),
446-
-1,
447-
)
448-
449-
if turn_anchor_index == -1:
450-
continue
451-
452-
for msg in messages[turn_anchor_index + 1 :]:
453-
if isinstance(msg, HumanMessage):
454-
continue
455-
456-
msg_id = getattr(msg, "id", None)
457-
if msg_id and msg_id in seen_ids:
458-
continue
459-
if msg_id:
460-
seen_ids.add(msg_id)
461-
462-
if isinstance(msg, AIMessage):
463-
# Tool calls (only from values mode for completeness)
464-
if msg.tool_calls:
465-
tool_names = [tc.get("name") for tc in msg.tool_calls]
466-
logger.info("AI tool calls: %s", tool_names)
467-
yield {
468-
"type": "assistant_tool_calls",
469-
"message_id": msg_id,
470-
"tool_calls": [
471-
{
472-
"name": tool_call.get("name"),
473-
"args": tool_call.get("args", {}),
474-
"id": tool_call.get("id"),
475-
}
476-
for tool_call in msg.tool_calls
477-
],
478-
}
479-
480-
# Track final text from complete messages
481-
content = self._client._extract_text(msg.content)
482-
if content:
483-
final_text = content
484-
485-
elif isinstance(msg, ToolMessage):
486-
tool_name = getattr(msg, "name", None) or "unknown"
487-
tool_content = self._client._extract_text(msg.content)
488-
logger.info(
489-
"Tool result: name=%s, content_preview=%s",
490-
tool_name,
491-
tool_content[:100] if tool_content else "(empty)",
492-
)
493-
if tool_content:
494-
tool_results.append(f"[{tool_name}]: {tool_content[:200]}")
495-
496-
yield {
497-
"type": "tool_result",
498-
"message_id": msg_id,
499-
"tool_name": tool_name,
500-
"tool_call_id": getattr(msg, "tool_call_id", None),
501-
"content": tool_content,
502-
}
400+
for msg in self._iter_new_turn_messages(messages, current_user_message_id, capture_state.seen_ids):
401+
for event in self._process_values_mode_message(msg, capture_state):
402+
yield event
503403

504404
# Fallback: if messages mode captured content but values mode didn't
505-
if not final_text and accumulated_content:
506-
final_text = accumulated_content
405+
if not capture_state.final_text and capture_state.accumulated_content:
406+
capture_state.final_text = capture_state.accumulated_content
507407

508408
# Store results for the caller to retrieve
509-
self._last_final_text = final_text
510-
self._last_tool_results = tool_results
409+
self._last_final_text = capture_state.final_text
410+
self._last_tool_results = capture_state.tool_results
511411

512412
def stream_events(
513413
self,
@@ -552,6 +452,157 @@ def _run_deerflow_turn(
552452

553453
return final_text, tool_results
554454

455+
def _process_messages_mode_chunk(
    self,
    msg_chunk: object,
    capture_state: _StreamCaptureState,
) -> list[dict[str, Any]]:
    """Turn one token-level AI chunk into zero or more accumulated stream events.

    Non-AIMessageChunk payloads are ignored. A chunk id that differs from
    the one currently tracked marks a fresh LLM invocation, so both
    accumulators are reset before the new deltas are folded in. Returns the
    events to emit (reasoning first, then content), each carrying the full
    accumulated text so far.
    """
    if not isinstance(msg_chunk, AIMessageChunk):
        return []

    incoming_id = getattr(msg_chunk, "id", None)
    if incoming_id and incoming_id != capture_state.current_chunk_msg_id:
        # A new LLM invocation started; restart the per-invocation accumulators.
        capture_state.current_chunk_msg_id = incoming_id
        capture_state.accumulated_reasoning = ""
        capture_state.accumulated_content = ""
    if not capture_state.current_chunk_msg_id:
        # No id on the wire yet; synthesize one so events stay correlated.
        capture_state.current_chunk_msg_id = str(uuid.uuid4())

    produced: list[dict[str, Any]] = []

    reasoning_piece = self._extract_reasoning_delta(msg_chunk)
    if reasoning_piece:
        capture_state.accumulated_reasoning += reasoning_piece
        produced.append(
            {
                "type": "assistant_reasoning",
                "message_id": capture_state.current_chunk_msg_id,
                "content": capture_state.accumulated_reasoning,
            }
        )

    content_piece = self._extract_content_delta(msg_chunk)
    if content_piece:
        capture_state.accumulated_content += content_piece
        produced.append(
            {
                "type": "assistant_message",
                "message_id": capture_state.current_chunk_msg_id,
                "content": capture_state.accumulated_content,
            }
        )

    return produced
497+
498+
@staticmethod
def _process_custom_mode_chunk(event: object) -> dict[str, Any] | None:
    """Normalize a supported DeerFlow custom task event; return None otherwise.

    Only dict events whose "type" is one of the task lifecycle states are
    forwarded; anything else is dropped after a debug log.
    """
    logger.debug("Custom event received: %s", event)
    recognized = ("task_started", "task_running", "task_completed", "task_failed")
    if isinstance(event, dict) and event.get("type") in recognized:
        logger.info("Task event: type=%s, task_id=%s", event.get("type"), event.get("task_id"))
        # Flatten the task payload into the runtime's custom_event envelope;
        # absent keys surface as None rather than raising.
        return {
            "type": "custom_event",
            "event_type": event["type"],
            "task_id": event.get("task_id"),
            "description": event.get("description"),
            "message": event.get("message"),
            "result": event.get("result"),
            "error": event.get("error"),
        }
    return None
520+
521+
@staticmethod
def _iter_new_turn_messages(
    messages: list[object],
    current_user_message_id: str,
    seen_ids: set[str],
) -> Generator[object, None, None]:
    """Yield unseen non-user messages that follow the current turn's anchor.

    The anchor is the HumanMessage whose id equals *current_user_message_id*;
    if no such message exists, nothing is yielded. Ids of yielded messages
    are recorded in *seen_ids* (mutated in place) so repeated "values"
    snapshots do not emit the same message twice; id-less messages are
    always yielded.
    """
    anchor = -1
    for position, candidate in enumerate(messages):
        if isinstance(candidate, HumanMessage) and getattr(candidate, "id", None) == current_user_message_id:
            anchor = position
            break
    if anchor == -1:
        return

    for candidate in messages[anchor + 1 :]:
        if isinstance(candidate, HumanMessage):
            # User messages never become runtime events.
            continue
        candidate_id = getattr(candidate, "id", None)
        if candidate_id:
            if candidate_id in seen_ids:
                continue
            seen_ids.add(candidate_id)
        yield candidate
549+
550+
def _process_values_mode_message(
    self,
    msg: object,
    capture_state: _StreamCaptureState,
) -> list[dict[str, Any]]:
    """Convert one complete values-mode message into runtime events.

    AIMessage: emits an assistant_tool_calls event when tool calls are
    present and records the latest non-empty text as the turn's final text.
    ToolMessage: appends a truncated summary to capture_state.tool_results
    and emits a tool_result event. Any other message type yields no events.
    """
    message_id = getattr(msg, "id", None)

    if isinstance(msg, AIMessage):
        collected: list[dict[str, Any]] = []
        if msg.tool_calls:
            logger.info("AI tool calls: %s", [call.get("name") for call in msg.tool_calls])
            call_payloads = [
                {
                    "name": call.get("name"),
                    "args": call.get("args", {}),
                    "id": call.get("id"),
                }
                for call in msg.tool_calls
            ]
            collected.append(
                {
                    "type": "assistant_tool_calls",
                    "message_id": message_id,
                    "tool_calls": call_payloads,
                }
            )

        # Track the final answer text from complete assistant messages.
        text = self._client._extract_text(msg.content)
        if text:
            capture_state.final_text = text
        return collected

    if isinstance(msg, ToolMessage):
        name = getattr(msg, "name", None) or "unknown"
        body = self._client._extract_text(msg.content)
        preview = body[:100] if body else "(empty)"
        logger.info("Tool result: name=%s, content_preview=%s", name, preview)
        if body:
            # Keep a short summary for the caller-facing tool_results list.
            capture_state.tool_results.append(f"[{name}]: {body[:200]}")
        return [
            {
                "type": "tool_result",
                "message_id": message_id,
                "tool_name": name,
                "tool_call_id": getattr(msg, "tool_call_id", None),
                "content": body,
            }
        ]

    return []
605+
555606
def _resolve_runtime_options(
556607
self,
557608
runtime_options: ConversationRuntimeOptions | None = None,

0 commit comments

Comments
 (0)