99import logging
1010import uuid
1111from collections .abc import AsyncGenerator , Generator
12+ from dataclasses import dataclass , field
1213from typing import Any
1314
1415import deerflow .client as deerflow_client_module
2627logger = logging .getLogger (__name__ )
2728
2829
30+ @dataclass
31+ class _StreamCaptureState :
32+ """Mutable capture state for a single async DeerFlow turn."""
33+
34+ current_chunk_msg_id : str | None = None
35+ accumulated_reasoning : str = ""
36+ accumulated_content : str = ""
37+ final_text : str = ""
38+ tool_results : list [str ] = field (default_factory = list )
39+ seen_ids : set [str ] = field (default_factory = set )
40+
41+
2942class SwarmMindDeerFlowClient (DeerFlowClient ):
3043 """DeerFlow client wrapper that injects SwarmMind product identity."""
3144
@@ -364,14 +377,7 @@ async def _astream_events(
364377 state : dict [str , Any ] = {"messages" : [HumanMessage (content = goal , id = current_user_message_id )]}
365378 runtime_context = {"thread_id" : thread_id }
366379
367- seen_ids : set [str ] = set ()
368- final_text = ""
369- tool_results : list [str ] = []
370-
371- # Token-level streaming accumulators (reset per LLM invocation)
372- current_chunk_msg_id : str | None = None
373- accumulated_reasoning = ""
374- accumulated_content = ""
380+ capture_state = _StreamCaptureState ()
375381
376382 async for mode_tag , chunk in self ._client ._agent .astream (
377383 state ,
@@ -381,133 +387,27 @@ async def _astream_events(
381387 ):
382388 if mode_tag == "messages" :
383389 msg_chunk , _metadata = chunk
384- if not isinstance (msg_chunk , AIMessageChunk ):
385- continue
386-
387- chunk_id = getattr (msg_chunk , "id" , None )
388- if chunk_id and chunk_id != current_chunk_msg_id :
389- # New LLM invocation started; reset accumulators
390- current_chunk_msg_id = chunk_id
391- accumulated_reasoning = ""
392- accumulated_content = ""
393-
394- if not current_chunk_msg_id :
395- current_chunk_msg_id = str (uuid .uuid4 ())
396-
397- # Stream reasoning tokens
398- reasoning_delta = self ._extract_reasoning_delta (msg_chunk )
399- if reasoning_delta :
400- accumulated_reasoning += reasoning_delta
401- yield {
402- "type" : "assistant_reasoning" ,
403- "message_id" : current_chunk_msg_id ,
404- "content" : accumulated_reasoning ,
405- }
406-
407- # Stream content tokens
408- content_delta = self ._extract_content_delta (msg_chunk )
409- if content_delta :
410- accumulated_content += content_delta
411- yield {
412- "type" : "assistant_message" ,
413- "message_id" : current_chunk_msg_id ,
414- "content" : accumulated_content ,
415- }
390+ for event in self ._process_messages_mode_chunk (msg_chunk , capture_state ):
391+ yield event
416392
417393 elif mode_tag == "custom" :
418- # Handle custom events from task_tool (task_started, task_running, task_completed, task_failed)
419- event = chunk
420- logger .debug ("Custom event received: %s" , event )
421- if isinstance (event , dict ) and event .get ("type" ) in (
422- "task_started" ,
423- "task_running" ,
424- "task_completed" ,
425- "task_failed" ,
426- ):
427- logger .info ("Task event: type=%s, task_id=%s" , event .get ("type" ), event .get ("task_id" ))
428- yield {
429- "type" : "custom_event" ,
430- "event_type" : event ["type" ],
431- "task_id" : event .get ("task_id" ),
432- "description" : event .get ("description" ),
433- "message" : event .get ("message" ),
434- "result" : event .get ("result" ),
435- "error" : event .get ("error" ),
436- }
394+ event = self ._process_custom_mode_chunk (chunk )
395+ if event is not None :
396+ yield event
437397
438398 elif mode_tag == "values" :
439399 messages = chunk .get ("messages" , [])
440- turn_anchor_index = next (
441- (
442- index
443- for index , message in enumerate (messages )
444- if isinstance (message , HumanMessage ) and getattr (message , "id" , None ) == current_user_message_id
445- ),
446- - 1 ,
447- )
448-
449- if turn_anchor_index == - 1 :
450- continue
451-
452- for msg in messages [turn_anchor_index + 1 :]:
453- if isinstance (msg , HumanMessage ):
454- continue
455-
456- msg_id = getattr (msg , "id" , None )
457- if msg_id and msg_id in seen_ids :
458- continue
459- if msg_id :
460- seen_ids .add (msg_id )
461-
462- if isinstance (msg , AIMessage ):
463- # Tool calls (only from values mode for completeness)
464- if msg .tool_calls :
465- tool_names = [tc .get ("name" ) for tc in msg .tool_calls ]
466- logger .info ("AI tool calls: %s" , tool_names )
467- yield {
468- "type" : "assistant_tool_calls" ,
469- "message_id" : msg_id ,
470- "tool_calls" : [
471- {
472- "name" : tool_call .get ("name" ),
473- "args" : tool_call .get ("args" , {}),
474- "id" : tool_call .get ("id" ),
475- }
476- for tool_call in msg .tool_calls
477- ],
478- }
479-
480- # Track final text from complete messages
481- content = self ._client ._extract_text (msg .content )
482- if content :
483- final_text = content
484-
485- elif isinstance (msg , ToolMessage ):
486- tool_name = getattr (msg , "name" , None ) or "unknown"
487- tool_content = self ._client ._extract_text (msg .content )
488- logger .info (
489- "Tool result: name=%s, content_preview=%s" ,
490- tool_name ,
491- tool_content [:100 ] if tool_content else "(empty)" ,
492- )
493- if tool_content :
494- tool_results .append (f"[{ tool_name } ]: { tool_content [:200 ]} " )
495-
496- yield {
497- "type" : "tool_result" ,
498- "message_id" : msg_id ,
499- "tool_name" : tool_name ,
500- "tool_call_id" : getattr (msg , "tool_call_id" , None ),
501- "content" : tool_content ,
502- }
400+ for msg in self ._iter_new_turn_messages (messages , current_user_message_id , capture_state .seen_ids ):
401+ for event in self ._process_values_mode_message (msg , capture_state ):
402+ yield event
503403
504404 # Fallback: if messages mode captured content but values mode didn't
505- if not final_text and accumulated_content :
506- final_text = accumulated_content
405+ if not capture_state . final_text and capture_state . accumulated_content :
406+ capture_state . final_text = capture_state . accumulated_content
507407
508408 # Store results for the caller to retrieve
509- self ._last_final_text = final_text
510- self ._last_tool_results = tool_results
409+ self ._last_final_text = capture_state . final_text
410+ self ._last_tool_results = capture_state . tool_results
511411
512412 def stream_events (
513413 self ,
@@ -552,6 +452,157 @@ def _run_deerflow_turn(
552452
553453 return final_text , tool_results
554454
455+ def _process_messages_mode_chunk (
456+ self ,
457+ msg_chunk : object ,
458+ capture_state : _StreamCaptureState ,
459+ ) -> list [dict [str , Any ]]:
460+ """Convert a streaming AI chunk into accumulated reasoning/content events."""
461+ if not isinstance (msg_chunk , AIMessageChunk ):
462+ return []
463+
464+ chunk_id = getattr (msg_chunk , "id" , None )
465+ if chunk_id and chunk_id != capture_state .current_chunk_msg_id :
466+ capture_state .current_chunk_msg_id = chunk_id
467+ capture_state .accumulated_reasoning = ""
468+ capture_state .accumulated_content = ""
469+
470+ if not capture_state .current_chunk_msg_id :
471+ capture_state .current_chunk_msg_id = str (uuid .uuid4 ())
472+
473+ events : list [dict [str , Any ]] = []
474+ reasoning_delta = self ._extract_reasoning_delta (msg_chunk )
475+ if reasoning_delta :
476+ capture_state .accumulated_reasoning += reasoning_delta
477+ events .append (
478+ {
479+ "type" : "assistant_reasoning" ,
480+ "message_id" : capture_state .current_chunk_msg_id ,
481+ "content" : capture_state .accumulated_reasoning ,
482+ }
483+ )
484+
485+ content_delta = self ._extract_content_delta (msg_chunk )
486+ if content_delta :
487+ capture_state .accumulated_content += content_delta
488+ events .append (
489+ {
490+ "type" : "assistant_message" ,
491+ "message_id" : capture_state .current_chunk_msg_id ,
492+ "content" : capture_state .accumulated_content ,
493+ }
494+ )
495+
496+ return events
497+
498+ @staticmethod
499+ def _process_custom_mode_chunk (event : object ) -> dict [str , Any ] | None :
500+ """Normalize supported custom task events from DeerFlow."""
501+ logger .debug ("Custom event received: %s" , event )
502+ if not isinstance (event , dict ) or event .get ("type" ) not in {
503+ "task_started" ,
504+ "task_running" ,
505+ "task_completed" ,
506+ "task_failed" ,
507+ }:
508+ return None
509+
510+ logger .info ("Task event: type=%s, task_id=%s" , event .get ("type" ), event .get ("task_id" ))
511+ return {
512+ "type" : "custom_event" ,
513+ "event_type" : event ["type" ],
514+ "task_id" : event .get ("task_id" ),
515+ "description" : event .get ("description" ),
516+ "message" : event .get ("message" ),
517+ "result" : event .get ("result" ),
518+ "error" : event .get ("error" ),
519+ }
520+
521+ @staticmethod
522+ def _iter_new_turn_messages (
523+ messages : list [object ],
524+ current_user_message_id : str ,
525+ seen_ids : set [str ],
526+ ) -> Generator [object , None , None ]:
527+ """Yield unseen non-user messages after the current turn anchor."""
528+ turn_anchor_index = next (
529+ (
530+ index
531+ for index , message in enumerate (messages )
532+ if isinstance (message , HumanMessage ) and getattr (message , "id" , None ) == current_user_message_id
533+ ),
534+ - 1 ,
535+ )
536+ if turn_anchor_index == - 1 :
537+ return
538+
539+ for msg in messages [turn_anchor_index + 1 :]:
540+ if isinstance (msg , HumanMessage ):
541+ continue
542+
543+ msg_id = getattr (msg , "id" , None )
544+ if msg_id and msg_id in seen_ids :
545+ continue
546+ if msg_id :
547+ seen_ids .add (msg_id )
548+ yield msg
549+
550+ def _process_values_mode_message (
551+ self ,
552+ msg : object ,
553+ capture_state : _StreamCaptureState ,
554+ ) -> list [dict [str , Any ]]:
555+ """Convert full values-mode messages into runtime events and summaries."""
556+ msg_id = getattr (msg , "id" , None )
557+
558+ if isinstance (msg , AIMessage ):
559+ events : list [dict [str , Any ]] = []
560+ if msg .tool_calls :
561+ tool_names = [tc .get ("name" ) for tc in msg .tool_calls ]
562+ logger .info ("AI tool calls: %s" , tool_names )
563+ events .append (
564+ {
565+ "type" : "assistant_tool_calls" ,
566+ "message_id" : msg_id ,
567+ "tool_calls" : [
568+ {
569+ "name" : tool_call .get ("name" ),
570+ "args" : tool_call .get ("args" , {}),
571+ "id" : tool_call .get ("id" ),
572+ }
573+ for tool_call in msg .tool_calls
574+ ],
575+ }
576+ )
577+
578+ content = self ._client ._extract_text (msg .content )
579+ if content :
580+ capture_state .final_text = content
581+ return events
582+
583+ if isinstance (msg , ToolMessage ):
584+ tool_name = getattr (msg , "name" , None ) or "unknown"
585+ tool_content = self ._client ._extract_text (msg .content )
586+ logger .info (
587+ "Tool result: name=%s, content_preview=%s" ,
588+ tool_name ,
589+ tool_content [:100 ] if tool_content else "(empty)" ,
590+ )
591+ if tool_content :
592+ capture_state .tool_results .append (f"[{ tool_name } ]: { tool_content [:200 ]} " )
593+
594+ return [
595+ {
596+ "type" : "tool_result" ,
597+ "message_id" : msg_id ,
598+ "tool_name" : tool_name ,
599+ "tool_call_id" : getattr (msg , "tool_call_id" , None ),
600+ "content" : tool_content ,
601+ }
602+ ]
603+
604+ return []
605+
555606 def _resolve_runtime_options (
556607 self ,
557608 runtime_options : ConversationRuntimeOptions | None = None ,
0 commit comments