Skip to content

Commit a93c6af

Browse files
authored
Merge pull request #44 from rongxinzy/techdebt/runtime-bridge-extraction
refactor: extract runtime bridge and harden base agent
2 parents d378ce6 + 63cb97d commit a93c6af

File tree

5 files changed

+224
-71
lines changed

5 files changed

+224
-71
lines changed

swarmmind/agents/base.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from abc import ABC, abstractmethod
44

55
from swarmmind.layered_memory import LayeredMemory
6-
from swarmmind.models import MemoryContext, MemoryScope
6+
from swarmmind.models import MemoryContext, MemoryLayer, MemoryScope
77
from swarmmind.repositories.action_proposal import ActionProposalRepository
88
from swarmmind.repositories.agent import AgentRepository
99

@@ -12,7 +12,6 @@ class AgentError(Exception):
1212
"""Base exception for SwarmMind agent errors."""
1313

1414
pass
15-
pass
1615

1716

1817
class BaseAgent(ABC):
@@ -45,8 +44,6 @@ def _resolve_write_scope(self, ctx: MemoryContext | None) -> MemoryScope:
4544
but since ctx always provides user_id as a fallback, we fall through
4645
to the most specific available scope (never L4 for regular agents).
4746
"""
48-
from swarmmind.models import MemoryLayer
49-
5047
if ctx is None:
5148
# No context — use a default user scope (L4, but agent will be denied
5249
# unless they are soul_writer; this is intentional guardrail)

swarmmind/agents/general_agent.py

Lines changed: 7 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from __future__ import annotations
77

8-
import asyncio
98
import json
109
import logging
1110
import uuid
@@ -22,6 +21,7 @@
2221
from swarmmind.prompting import rewrite_swarmmind_identity_prompt
2322
from swarmmind.runtime import RuntimeExecutionError, ensure_default_runtime_instance
2423
from swarmmind.runtime.models import RuntimeInstance
24+
from swarmmind.services.runtime_bridge import iter_async_generator_in_thread
2525

2626
logger = logging.getLogger(__name__)
2727

@@ -524,72 +524,12 @@ def stream_events(
524524
with its own event loop. This isolates the runtime from any existing loop
525525
in the caller while preserving the synchronous generator API.
526526
"""
527-
import queue
528-
import threading
529-
530-
result_queue: queue.Queue = queue.Queue()
531-
exception_container = []
532-
stop_event = threading.Event()
533-
534-
async def _run_async_stream():
535-
"""Run the async stream and put events into the queue."""
536-
try:
537-
async for event in self._astream_events(goal, ctx=ctx, runtime_options=runtime_options):
538-
result_queue.put(("event", event))
539-
except Exception as e:
540-
exception_container.append(e)
541-
finally:
542-
result_queue.put(("done", None))
543-
stop_event.set()
544-
545-
def _run_in_thread():
546-
"""Run the async code in a new event loop in a separate thread.
547-
548-
NOTE: We create a new event loop in this thread to isolate the
549-
DeerFlow agent execution from any existing event loop. This is
550-
necessary because subagents also create their own event loops,
551-
and we need to avoid httpx client binding conflicts.
552-
"""
553-
# Set the event loop policy to create new loops per-thread
554-
loop = asyncio.new_event_loop()
555-
asyncio.set_event_loop(loop)
556-
try:
557-
loop.run_until_complete(_run_async_stream())
558-
finally:
559-
# Clean up any remaining tasks
560-
try:
561-
pending = asyncio.all_tasks(loop)
562-
if pending:
563-
for task in pending:
564-
task.cancel()
565-
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
566-
except Exception: # nosec: B110 - cleanup code, safe to ignore
567-
pass
568-
loop.close()
569-
570-
# Start the async execution in a separate thread
571-
thread = threading.Thread(target=_run_in_thread, name="deerflow-stream")
572-
thread.start()
573-
574-
# Yield events as they become available
575-
try:
576-
while not stop_event.is_set() or not result_queue.empty():
577-
try:
578-
item_type, event = result_queue.get(timeout=0.1)
579-
if item_type == "done":
580-
break
581-
if item_type == "event":
582-
yield event
583-
except queue.Empty:
584-
continue
585-
finally:
586-
thread.join(timeout=5.0)
587-
if thread.is_alive():
588-
logger.warning("Stream thread did not terminate within timeout")
589-
590-
# Re-raise any exception from the async execution
591-
if exception_container:
592-
raise exception_container[0]
527+
yield from iter_async_generator_in_thread(
528+
lambda: self._astream_events(goal, ctx=ctx, runtime_options=runtime_options),
529+
thread_name="deerflow-stream",
530+
join_timeout=5.0,
531+
bridge_logger=logger,
532+
)
593533

594534
return self._last_final_text, self._last_tool_results
595535

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
"""Helpers for bridging async runtime streams into sync iterators."""
2+
3+
from __future__ import annotations
4+
5+
import asyncio
6+
import logging
7+
import queue
8+
import threading
9+
from collections.abc import AsyncGenerator, Callable, Generator
10+
from typing import TypeVar
11+
12+
logger = logging.getLogger(__name__)
13+
14+
T = TypeVar("T")
15+
16+
17+
def iter_async_generator_in_thread[T](
18+
async_factory: Callable[[], AsyncGenerator[T, None]],
19+
*,
20+
thread_name: str = "runtime-stream",
21+
join_timeout: float = 5.0,
22+
bridge_logger: logging.Logger = logger,
23+
) -> Generator[T, None, None]:
24+
"""Run an async generator inside a worker thread and yield its items synchronously."""
25+
result_queue: queue.Queue[tuple[str, T | None]] = queue.Queue()
26+
exception_container: list[BaseException] = []
27+
stop_event = threading.Event()
28+
29+
async def _drain_async_generator() -> None:
30+
try:
31+
async for item in async_factory():
32+
result_queue.put(("event", item))
33+
except BaseException as exc:
34+
exception_container.append(exc)
35+
finally:
36+
result_queue.put(("done", None))
37+
stop_event.set()
38+
39+
def _run_in_thread() -> None:
40+
loop = asyncio.new_event_loop()
41+
asyncio.set_event_loop(loop)
42+
try:
43+
loop.run_until_complete(_drain_async_generator())
44+
finally:
45+
try:
46+
pending = asyncio.all_tasks(loop)
47+
if pending:
48+
for task in pending:
49+
task.cancel()
50+
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
51+
except Exception: # nosec: B110 - cleanup code, safe to ignore
52+
pass
53+
loop.close()
54+
55+
thread = threading.Thread(target=_run_in_thread, name=thread_name, daemon=True)
56+
thread.start()
57+
58+
try:
59+
while not stop_event.is_set() or not result_queue.empty():
60+
try:
61+
item_type, item = result_queue.get(timeout=0.1)
62+
except queue.Empty:
63+
continue
64+
65+
if item_type == "done":
66+
break
67+
if item_type == "event" and item is not None:
68+
yield item
69+
finally:
70+
thread.join(timeout=join_timeout)
71+
if thread.is_alive():
72+
bridge_logger.warning("Async bridge thread %s did not terminate within timeout", thread_name)
73+
74+
if exception_container:
75+
raise exception_container[0]

tests/test_base_agent.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
from __future__ import annotations
2+
3+
import pytest
4+
5+
from swarmmind.agents.base import AgentError, BaseAgent
6+
from swarmmind.models import MemoryContext, MemoryLayer
7+
8+
9+
class _FakeLayeredMemory:
10+
def __init__(self, agent_id: str) -> None:
11+
self.agent_id = agent_id
12+
13+
14+
class _DummyAgent(BaseAgent):
15+
@property
16+
def domain_tags(self) -> list[str]:
17+
return [self.domain]
18+
19+
20+
def test_base_agent_loads_system_prompt_from_repository(monkeypatch: pytest.MonkeyPatch) -> None:
21+
calls: list[str] = []
22+
23+
class _FakeAgentRepository:
24+
def get_system_prompt(self, agent_id: str) -> str | None:
25+
calls.append(agent_id)
26+
return "system prompt from repository"
27+
28+
monkeypatch.setattr("swarmmind.agents.base.LayeredMemory", _FakeLayeredMemory)
29+
monkeypatch.setattr("swarmmind.agents.base.AgentRepository", _FakeAgentRepository)
30+
31+
agent = _DummyAgent(agent_id="general", domain="general")
32+
33+
assert calls == ["general"]
34+
assert agent._system_prompt == "system prompt from repository"
35+
assert agent.memory.agent_id == "general"
36+
37+
38+
def test_base_agent_raises_for_missing_agent(monkeypatch: pytest.MonkeyPatch) -> None:
39+
class _FakeAgentRepository:
40+
def get_system_prompt(self, agent_id: str) -> str | None:
41+
return None
42+
43+
monkeypatch.setattr("swarmmind.agents.base.LayeredMemory", _FakeLayeredMemory)
44+
monkeypatch.setattr("swarmmind.agents.base.AgentRepository", _FakeAgentRepository)
45+
46+
with pytest.raises(AgentError, match="Agent missing not found in database\\."):
47+
_DummyAgent(agent_id="missing", domain="general")
48+
49+
50+
def test_create_rejected_proposal_delegates_to_repository(monkeypatch: pytest.MonkeyPatch) -> None:
51+
rejected_calls: list[tuple[str, str]] = []
52+
53+
class _FakeAgentRepository:
54+
def get_system_prompt(self, agent_id: str) -> str | None:
55+
return "prompt"
56+
57+
class _FakeActionProposalRepository:
58+
def reject(self, proposal_id: str, description: str) -> None:
59+
rejected_calls.append((proposal_id, description))
60+
61+
monkeypatch.setattr("swarmmind.agents.base.LayeredMemory", _FakeLayeredMemory)
62+
monkeypatch.setattr("swarmmind.agents.base.AgentRepository", _FakeAgentRepository)
63+
monkeypatch.setattr("swarmmind.agents.base.ActionProposalRepository", _FakeActionProposalRepository)
64+
65+
agent = _DummyAgent(agent_id="general", domain="general")
66+
agent._create_rejected_proposal("proposal-123", "runtime failed")
67+
68+
assert rejected_calls == [("proposal-123", "runtime failed")]
69+
70+
71+
@pytest.mark.parametrize(
72+
("ctx", "expected_layer", "expected_scope_id"),
73+
[
74+
(
75+
MemoryContext(user_id="user-1", project_id="project-1", team_id="team-1", session_id="session-1"),
76+
MemoryLayer.TMP,
77+
"session-1",
78+
),
79+
(
80+
MemoryContext(user_id="user-1", project_id="project-1", team_id="team-1"),
81+
MemoryLayer.TEAM,
82+
"team-1",
83+
),
84+
(
85+
MemoryContext(user_id="user-1", project_id="project-1"),
86+
MemoryLayer.PROJECT,
87+
"project-1",
88+
),
89+
(
90+
MemoryContext(user_id="user-1"),
91+
MemoryLayer.USER_SOUL,
92+
"user-1",
93+
),
94+
],
95+
)
96+
def test_resolve_write_scope_uses_most_specific_context(
97+
monkeypatch: pytest.MonkeyPatch,
98+
ctx: MemoryContext,
99+
expected_layer: MemoryLayer,
100+
expected_scope_id: str,
101+
) -> None:
102+
class _FakeAgentRepository:
103+
def get_system_prompt(self, agent_id: str) -> str | None:
104+
return "prompt"
105+
106+
monkeypatch.setattr("swarmmind.agents.base.LayeredMemory", _FakeLayeredMemory)
107+
monkeypatch.setattr("swarmmind.agents.base.AgentRepository", _FakeAgentRepository)
108+
109+
agent = _DummyAgent(agent_id="general", domain="general")
110+
scope = agent._resolve_write_scope(ctx)
111+
112+
assert scope.layer == expected_layer
113+
assert scope.scope_id == expected_scope_id

tests/test_runtime_bridge.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
"""Tests for async-to-sync runtime bridge helpers."""
2+
3+
from __future__ import annotations
4+
5+
from swarmmind.services.runtime_bridge import iter_async_generator_in_thread
6+
7+
8+
def test_iter_async_generator_in_thread_yields_items_in_order() -> None:
9+
async def factory():
10+
yield "first"
11+
yield "second"
12+
13+
assert list(iter_async_generator_in_thread(factory, thread_name="bridge-test")) == ["first", "second"]
14+
15+
16+
def test_iter_async_generator_in_thread_reraises_async_errors() -> None:
17+
async def factory():
18+
raise RuntimeError("bridge failed")
19+
yield # pragma: no cover
20+
21+
iterator = iter_async_generator_in_thread(factory, thread_name="bridge-test")
22+
23+
try:
24+
next(iterator)
25+
except RuntimeError as exc:
26+
assert str(exc) == "bridge failed"
27+
else: # pragma: no cover
28+
raise AssertionError("RuntimeError was not re-raised")

0 commit comments

Comments
 (0)