Skip to content

Commit c180663

Browse files
authored
Merge pull request #22608 from dannon/agent-typed-orchestration-state
Pass orchestration state between agents as a typed object
2 parents 1af9307 + 19dd901 commit c180663

4 files changed

Lines changed: 223 additions & 13 deletions

File tree

lib/galaxy/agents/base.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
Callable,
1414
Sequence,
1515
)
16-
from dataclasses import dataclass
16+
from dataclasses import (
17+
dataclass,
18+
field,
19+
)
1720
from typing import (
1821
Any,
1922
Literal,
@@ -96,6 +99,7 @@
9699
"ActionSuggestion",
97100
"ActionType",
98101
"AgentResponse",
102+
"AgentRunState",
99103
"AgentType",
100104
"BaseGalaxyAgent",
101105
"ConfidenceLevel",
@@ -258,6 +262,24 @@ def __init__(
258262
self.reasoning = reasoning
259263

260264

265+
@dataclass
class AgentRunState:
    """Per-invocation state shared across sequential multi-agent flows.

    The orchestrator creates a fresh instance per user query and attaches it
    to each agent's context. Sequential agents read prior agents' responses
    from here instead of parsing them out of a text-concatenated prompt.
    """

    # Responses recorded so far, keyed by the agent type that produced them.
    # default_factory gives every instance its own dict (no shared mutable default).
    prior_responses: dict[str, "AgentResponse"] = field(default_factory=dict)

    def get_prior(self, agent_type: str) -> Optional["AgentResponse"]:
        """Return the response recorded for *agent_type*, or ``None`` if there is none."""
        return self.prior_responses.get(agent_type)

    def record(self, agent_type: str, response: "AgentResponse") -> None:
        """Store *response* under *agent_type*, replacing any earlier entry for that key."""
        self.prior_responses[agent_type] = response
281+
282+
261283
@dataclass
262284
class GalaxyAgentDependencies:
263285
"""Dependencies passed to Galaxy agents via dependency injection."""
@@ -280,6 +302,7 @@ class BaseGalaxyAgent(ABC):
280302

281303
agent_type: str
282304
agent: Agent[GalaxyAgentDependencies, Any]
305+
_INTERNAL_CONTEXT_KEYS = frozenset({"run_state"})
283306

284307
def __init__(self, deps: GalaxyAgentDependencies):
285308
self.deps = deps
@@ -446,7 +469,9 @@ def _prepare_prompt(self, query: str, context: dict[str, Any]) -> str:
446469
prompt_parts = [query]
447470

448471
if context:
449-
context_str = "\n".join([f"{k}: {v}" for k, v in context.items() if v])
472+
context_str = "\n".join(
473+
[f"{k}: {v}" for k, v in context.items() if v and k not in self._INTERNAL_CONTEXT_KEYS]
474+
)
450475
if context_str:
451476
prompt_parts.insert(0, f"Context:\n{context_str}\n")
452477

lib/galaxy/agents/error_analysis.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
ActionSuggestion,
2121
ActionType,
2222
AgentResponse,
23+
AgentRunState,
2324
AgentType,
2425
BaseGalaxyAgent,
2526
ConfidenceLiteral,
@@ -108,10 +109,15 @@ async def process(self, query: str, context: Optional[dict[str, Any]] = None) ->
108109
try:
109110
log.info(f"ErrorAnalysis: Received query (length={len(query)})")
110111
log.info(f"ErrorAnalysis: Query preview: {query[:800]}...")
111-
if "Previous analysis" in query:
112-
log.info("ErrorAnalysis: Query contains previous analysis context")
113112

114113
enhanced_query = query
114+
run_state = context.get("run_state") if context else None
115+
if isinstance(run_state, AgentRunState):
116+
prior = run_state.get_prior(AgentType.HISTORY)
117+
if prior is not None:
118+
log.info("ErrorAnalysis: Found prior history analysis in run_state")
119+
enhanced_query += f"\n\nContext from history analysis:\n{prior.content}"
120+
115121
if context and context.get("job_id"):
116122
job_details = await self.get_job_details(context["job_id"])
117123
if "error" not in job_details:

lib/galaxy/agents/orchestrator.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from galaxy.schema.agents import ConfidenceLevel
1818
from .base import (
1919
AgentResponse,
20+
AgentRunState,
2021
AgentType,
2122
BaseGalaxyAgent,
2223
extract_result_content,
@@ -190,26 +191,23 @@ async def _execute_sequential(
190191
self, agents: list[str], query: str, context: Optional[dict[str, Any]] = None
191192
) -> dict[str, AgentResponse]:
192193
"""Execute agents sequentially with timeout protection."""
193-
responses = {}
194-
current_query = query
194+
responses: dict[str, AgentResponse] = {}
195195
timeout = self._get_agent_timeout()
196+
run_state = AgentRunState()
197+
ctx: dict[str, Any] = {**(context or {}), "run_state": run_state}
196198

197199
log.info(f"Orchestrator: Running agents in SEQUENTIAL mode: {agents}")
198200
for agent_name in agents:
199201
try:
200-
log.info(f"Orchestrator: Starting agent '{agent_name}' with query length {len(current_query)}")
202+
log.info(f"Orchestrator: Starting agent '{agent_name}' with query length {len(query)}")
201203
agent = self.deps.get_agent(agent_name, self.deps)
202-
response = await asyncio.wait_for(agent.process(current_query, context or {}), timeout=timeout)
204+
response = await asyncio.wait_for(agent.process(query, ctx), timeout=timeout)
203205
responses[agent_name] = response
206+
run_state.record(agent_name, response)
204207

205208
log.debug(f"Orchestrator: Agent '{agent_name}' completed. Response length: {len(response.content)}")
206209
log.debug(f"Orchestrator: Agent '{agent_name}' response preview: {response.content[:500]}...")
207210

208-
# Cap previous response to avoid unbounded query growth
209-
prev_content = response.content[:2000]
210-
current_query = f"{query}\n\nPrevious analysis from {agent_name}: {prev_content}"
211-
log.debug(f"Orchestrator: Updated query for next agent, total length: {len(current_query)}")
212-
213211
except asyncio.TimeoutError:
214212
log.error(f"Agent {agent_name} timed out after {timeout}s")
215213
responses[agent_name] = _create_error_response(

test/unit/app/test_agents.py

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,18 @@
4545
CustomToolAgent,
4646
ErrorAnalysisAgent,
4747
GalaxyAgentDependencies,
48+
HistoryAgent,
4849
QueryRouterAgent,
4950
)
5051
from galaxy.agents.base import truncate_message_history
5152
from galaxy.agents.registry import build_default_registry
5253

5354
agent_registry = build_default_registry()
55+
from galaxy.agents.base import (
56+
AgentResponse,
57+
AgentRunState,
58+
AgentType,
59+
)
5460
from galaxy.agents.error_analysis import ErrorAnalysisResult
5561
from galaxy.agents.orchestrator import (
5662
AgentPlan,
@@ -642,6 +648,181 @@ async def test_workflow_orchestrator_generic_fallback_behavior(self):
642648
assert response.agent_type == "orchestrator"
643649
assert "having trouble" in response.content
644650

651+
def test_agent_run_state_record_and_get_prior(self):
    """A recorded response is returned by identity; unrecorded agent types yield None."""
    state = AgentRunState()
    assert state.get_prior("history") is None

    expected_content = "Found a failed job in the BRC history."
    recorded = AgentResponse(
        content=expected_content,
        confidence=ConfidenceLevel.HIGH,
        agent_type="history",
    )
    state.record("history", recorded)

    fetched = state.get_prior("history")
    assert fetched is recorded
    assert fetched.content == expected_content
    # Recording one agent type must not make other types appear.
    assert state.get_prior("error_analysis") is None
666+
667+
@pytest.mark.asyncio
async def test_orchestrator_sequential_attaches_run_state_to_context(self):
    """Sequential runs share one AgentRunState and expose earlier responses to later agents."""
    agent = WorkflowOrchestratorAgent(self.deps)

    captured_contexts: list[dict[str, Any]] = []
    # run_state is a single object mutated in place, so inspecting it after the
    # run cannot tell what each agent saw. Snapshot the recorded keys at the
    # moment each agent is invoked instead.
    priors_seen_at_call: list[set[str]] = []

    def make_capture(agent_type: str, content: str):
        async def capture(query, context):
            captured_contexts.append(dict(context))
            run_state = context.get("run_state")
            priors_seen_at_call.append(
                set(run_state.prior_responses) if isinstance(run_state, AgentRunState) else set()
            )
            return MagicMock(
                content=content,
                agent_type=agent_type,
                confidence=ConfidenceLevel.HIGH,
            )

        return capture

    mock_history_agent = MagicMock()
    mock_history_agent.process = AsyncMock(side_effect=make_capture("history", "History summary content"))
    mock_error_agent = MagicMock()
    mock_error_agent.process = AsyncMock(side_effect=make_capture("error_analysis", "Error analysis content"))

    def get_agent_side_effect(agent_type, deps):
        if agent_type == "history":
            return mock_history_agent
        if agent_type == "error_analysis":
            return mock_error_agent
        raise ValueError(f"Unexpected agent type: {agent_type}")

    self.deps.get_agent = MagicMock(side_effect=get_agent_side_effect)

    with patch.object(agent, "_get_agent_plan") as mock_get_plan:
        mock_get_plan.return_value = AgentPlan(
            agents=["history", "error_analysis"],
            sequential=True,
            reasoning="Find failed job, then diagnose it",
        )

        await agent.process("Why did my job fail?")

    assert len(captured_contexts) == 2

    first_run_state = captured_contexts[0].get("run_state")
    second_run_state = captured_contexts[1].get("run_state")
    assert isinstance(first_run_state, AgentRunState)
    assert isinstance(second_run_state, AgentRunState)
    # Same run_state instance is reused across the sequential flow
    assert first_run_state is second_run_state

    # First agent saw an empty run_state; second agent saw history recorded
    assert priors_seen_at_call[0] == set()
    assert priors_seen_at_call[1] == {"history"}
    history_prior = second_run_state.get_prior("history")
    assert history_prior is not None
    assert history_prior.content == "History summary content"
725+
726+
@pytest.mark.asyncio
async def test_orchestrator_sequential_passes_original_query(self):
    """Every agent in a sequential plan receives the user's query unmodified."""
    orchestrator = WorkflowOrchestratorAgent(self.deps)
    original_query = "Why did my job fail?"
    seen_queries: list[str] = []

    async def record_query(query, context):
        seen_queries.append(query)
        return MagicMock(
            content="some response",
            agent_type="history",
            confidence=ConfidenceLevel.HIGH,
        )

    history_stub = MagicMock()
    history_stub.process = AsyncMock(side_effect=record_query)
    error_stub = MagicMock()
    error_stub.process = AsyncMock(side_effect=record_query)
    stubs_by_type = {"history": history_stub, "error_analysis": error_stub}

    def resolve_agent(agent_type, deps):
        # Only the two planned agents are expected to be requested.
        if agent_type not in stubs_by_type:
            raise ValueError(f"Unexpected agent type: {agent_type}")
        return stubs_by_type[agent_type]

    self.deps.get_agent = MagicMock(side_effect=resolve_agent)

    plan = AgentPlan(
        agents=["history", "error_analysis"],
        sequential=True,
        reasoning="Find failed job, then diagnose it",
    )
    with patch.object(orchestrator, "_get_agent_plan") as mock_get_plan:
        mock_get_plan.return_value = plan
        await orchestrator.process(original_query)

    assert len(seen_queries) == 2
    assert all(q == original_query for q in seen_queries)
    # The old text-concatenation marker must no longer be appended to queries.
    assert not any("Previous analysis from" in q for q in seen_queries)
767+
768+
@pytest.mark.asyncio
async def test_error_analysis_reads_history_from_run_state(self):
    """The prompt built by ErrorAnalysisAgent includes the history response held in run_state."""
    self.mock_config.ai_model = "gpt-4o"
    error_agent = ErrorAnalysisAgent(self.deps)

    shared_state = AgentRunState()
    shared_state.record(
        AgentType.HISTORY,
        AgentResponse(
            content="Found failing job 'select_first1' in BRC history; stderr says 'AssertionError'.",
            confidence=ConfidenceLevel.HIGH,
            agent_type=AgentType.HISTORY,
        ),
    )

    prompts_sent: list[str] = []

    async def fake_run_with_retry(prompt, *args, **kwargs):
        # Capture the prompt and hand back a canned structured result.
        prompts_sent.append(prompt)
        canned = ErrorAnalysisResult(
            error_category="tool_failure",
            error_severity="medium",
            likely_cause="Bad input",
            solution_steps=["Re-run"],
            confidence="high",
            requires_admin=False,
        )
        return mock.Mock(output=canned)

    with mock.patch.object(error_agent, "_run_with_retry", side_effect=fake_run_with_retry):
        await error_agent.process("Why did my job fail?", context={"run_state": shared_state})

    assert len(prompts_sent) == 1
    sent = prompts_sent[0]
    for expected_fragment in ("Context from history analysis:", "select_first1", "AssertionError"):
        assert expected_fragment in sent
804+
805+
@pytest.mark.asyncio
async def test_internal_run_state_is_not_rendered_in_default_prompt(self):
    """The default prompt shows ordinary context keys but not the internal run_state entry."""
    history_agent = HistoryAgent(self.deps)
    prompts_sent: list[str] = []

    async def fake_run_with_retry(prompt, *args, **kwargs):
        prompts_sent.append(prompt)
        return mock.Mock(output="History summary")

    context = {"run_state": AgentRunState(), "history_id": "abc123"}
    with mock.patch.object(history_agent, "_run_with_retry", side_effect=fake_run_with_retry):
        await history_agent.process("Summarize my history", context=context)

    assert len(prompts_sent) == 1
    sent = prompts_sent[0]
    assert "history_id: abc123" in sent
    # Neither the context key nor the object's repr may reach the prompt.
    assert "run_state" not in sent
    assert "AgentRunState" not in sent
825+
645826
def _orchestrator_agent(self):
    """Build a WorkflowOrchestratorAgent wired to this test's dependencies."""
    return WorkflowOrchestratorAgent(self.deps)

0 commit comments

Comments
 (0)