|
45 | 45 | CustomToolAgent, |
46 | 46 | ErrorAnalysisAgent, |
47 | 47 | GalaxyAgentDependencies, |
| 48 | + HistoryAgent, |
48 | 49 | QueryRouterAgent, |
49 | 50 | ) |
50 | 51 | from galaxy.agents.base import truncate_message_history |
51 | 52 | from galaxy.agents.registry import build_default_registry |
52 | 53 |
|
53 | 54 | agent_registry = build_default_registry() |
| 55 | +from galaxy.agents.base import ( |
| 56 | + AgentResponse, |
| 57 | + AgentRunState, |
| 58 | + AgentType, |
| 59 | +) |
54 | 60 | from galaxy.agents.error_analysis import ErrorAnalysisResult |
55 | 61 | from galaxy.agents.orchestrator import ( |
56 | 62 | AgentPlan, |
@@ -642,6 +648,181 @@ async def test_workflow_orchestrator_generic_fallback_behavior(self): |
642 | 648 | assert response.agent_type == "orchestrator" |
643 | 649 | assert "having trouble" in response.content |
644 | 650 |
|
| 651 | + def test_agent_run_state_record_and_get_prior(self): |
| 652 | + run_state = AgentRunState() |
| 653 | + assert run_state.get_prior("history") is None |
| 654 | + |
| 655 | + history_response = AgentResponse( |
| 656 | + content="Found a failed job in the BRC history.", |
| 657 | + confidence=ConfidenceLevel.HIGH, |
| 658 | + agent_type="history", |
| 659 | + ) |
| 660 | + run_state.record("history", history_response) |
| 661 | + |
| 662 | + retrieved = run_state.get_prior("history") |
| 663 | + assert retrieved is history_response |
| 664 | + assert retrieved.content == "Found a failed job in the BRC history." |
| 665 | + assert run_state.get_prior("error_analysis") is None |
| 666 | + |
| 667 | + @pytest.mark.asyncio |
| 668 | + async def test_orchestrator_sequential_attaches_run_state_to_context(self): |
| 669 | + agent = WorkflowOrchestratorAgent(self.deps) |
| 670 | + |
| 671 | + captured_contexts: list[dict[str, Any]] = [] |
| 672 | + |
| 673 | + async def capture_history(query, context): |
| 674 | + captured_contexts.append(dict(context)) |
| 675 | + return MagicMock( |
| 676 | + content="History summary content", |
| 677 | + agent_type="history", |
| 678 | + confidence=ConfidenceLevel.HIGH, |
| 679 | + ) |
| 680 | + |
| 681 | + async def capture_error(query, context): |
| 682 | + captured_contexts.append(dict(context)) |
| 683 | + return MagicMock( |
| 684 | + content="Error analysis content", |
| 685 | + agent_type="error_analysis", |
| 686 | + confidence=ConfidenceLevel.HIGH, |
| 687 | + ) |
| 688 | + |
| 689 | + mock_history_agent = MagicMock() |
| 690 | + mock_history_agent.process = AsyncMock(side_effect=capture_history) |
| 691 | + mock_error_agent = MagicMock() |
| 692 | + mock_error_agent.process = AsyncMock(side_effect=capture_error) |
| 693 | + |
| 694 | + def get_agent_side_effect(agent_type, deps): |
| 695 | + if agent_type == "history": |
| 696 | + return mock_history_agent |
| 697 | + if agent_type == "error_analysis": |
| 698 | + return mock_error_agent |
| 699 | + raise ValueError(f"Unexpected agent type: {agent_type}") |
| 700 | + |
| 701 | + self.deps.get_agent = MagicMock(side_effect=get_agent_side_effect) |
| 702 | + |
| 703 | + with patch.object(agent, "_get_agent_plan") as mock_get_plan: |
| 704 | + mock_get_plan.return_value = AgentPlan( |
| 705 | + agents=["history", "error_analysis"], |
| 706 | + sequential=True, |
| 707 | + reasoning="Find failed job, then diagnose it", |
| 708 | + ) |
| 709 | + |
| 710 | + await agent.process("Why did my job fail?") |
| 711 | + |
| 712 | + assert len(captured_contexts) == 2 |
| 713 | + |
| 714 | + first_run_state = captured_contexts[0].get("run_state") |
| 715 | + second_run_state = captured_contexts[1].get("run_state") |
| 716 | + assert isinstance(first_run_state, AgentRunState) |
| 717 | + assert isinstance(second_run_state, AgentRunState) |
| 718 | + # Same run_state instance is reused across the sequential flow |
| 719 | + assert first_run_state is second_run_state |
| 720 | + |
| 721 | + # First agent saw an empty run_state; second agent saw history recorded |
| 722 | + history_prior = second_run_state.get_prior("history") |
| 723 | + assert history_prior is not None |
| 724 | + assert history_prior.content == "History summary content" |
| 725 | + |
| 726 | + @pytest.mark.asyncio |
| 727 | + async def test_orchestrator_sequential_passes_original_query(self): |
| 728 | + agent = WorkflowOrchestratorAgent(self.deps) |
| 729 | + original_query = "Why did my job fail?" |
| 730 | + captured_queries: list[str] = [] |
| 731 | + |
| 732 | + async def capture_query(query, context): |
| 733 | + captured_queries.append(query) |
| 734 | + return MagicMock( |
| 735 | + content="some response", |
| 736 | + agent_type="history", |
| 737 | + confidence=ConfidenceLevel.HIGH, |
| 738 | + ) |
| 739 | + |
| 740 | + mock_history_agent = MagicMock() |
| 741 | + mock_history_agent.process = AsyncMock(side_effect=capture_query) |
| 742 | + mock_error_agent = MagicMock() |
| 743 | + mock_error_agent.process = AsyncMock(side_effect=capture_query) |
| 744 | + |
| 745 | + def get_agent_side_effect(agent_type, deps): |
| 746 | + if agent_type == "history": |
| 747 | + return mock_history_agent |
| 748 | + if agent_type == "error_analysis": |
| 749 | + return mock_error_agent |
| 750 | + raise ValueError(f"Unexpected agent type: {agent_type}") |
| 751 | + |
| 752 | + self.deps.get_agent = MagicMock(side_effect=get_agent_side_effect) |
| 753 | + |
| 754 | + with patch.object(agent, "_get_agent_plan") as mock_get_plan: |
| 755 | + mock_get_plan.return_value = AgentPlan( |
| 756 | + agents=["history", "error_analysis"], |
| 757 | + sequential=True, |
| 758 | + reasoning="Find failed job, then diagnose it", |
| 759 | + ) |
| 760 | + |
| 761 | + await agent.process(original_query) |
| 762 | + |
| 763 | + assert len(captured_queries) == 2 |
| 764 | + for q in captured_queries: |
| 765 | + assert q == original_query |
| 766 | + assert "Previous analysis from" not in q |
| 767 | + |
| 768 | + @pytest.mark.asyncio |
| 769 | + async def test_error_analysis_reads_history_from_run_state(self): |
| 770 | + self.mock_config.ai_model = "gpt-4o" |
| 771 | + agent = ErrorAnalysisAgent(self.deps) |
| 772 | + |
| 773 | + run_state = AgentRunState() |
| 774 | + history_response = AgentResponse( |
| 775 | + content="Found failing job 'select_first1' in BRC history; stderr says 'AssertionError'.", |
| 776 | + confidence=ConfidenceLevel.HIGH, |
| 777 | + agent_type=AgentType.HISTORY, |
| 778 | + ) |
| 779 | + run_state.record(AgentType.HISTORY, history_response) |
| 780 | + |
| 781 | + captured_prompts: list[str] = [] |
| 782 | + |
| 783 | + async def fake_run_with_retry(prompt, *args, **kwargs): |
| 784 | + captured_prompts.append(prompt) |
| 785 | + mock_result = mock.Mock() |
| 786 | + mock_result.output = ErrorAnalysisResult( |
| 787 | + error_category="tool_failure", |
| 788 | + error_severity="medium", |
| 789 | + likely_cause="Bad input", |
| 790 | + solution_steps=["Re-run"], |
| 791 | + confidence="high", |
| 792 | + requires_admin=False, |
| 793 | + ) |
| 794 | + return mock_result |
| 795 | + |
| 796 | + with mock.patch.object(agent, "_run_with_retry", side_effect=fake_run_with_retry): |
| 797 | + await agent.process("Why did my job fail?", context={"run_state": run_state}) |
| 798 | + |
| 799 | + assert len(captured_prompts) == 1 |
| 800 | + prompt = captured_prompts[0] |
| 801 | + assert "Context from history analysis:" in prompt |
| 802 | + assert "select_first1" in prompt |
| 803 | + assert "AssertionError" in prompt |
| 804 | + |
| 805 | + @pytest.mark.asyncio |
| 806 | + async def test_internal_run_state_is_not_rendered_in_default_prompt(self): |
| 807 | + agent = HistoryAgent(self.deps) |
| 808 | + run_state = AgentRunState() |
| 809 | + captured_prompts: list[str] = [] |
| 810 | + |
| 811 | + async def fake_run_with_retry(prompt, *args, **kwargs): |
| 812 | + captured_prompts.append(prompt) |
| 813 | + mock_result = mock.Mock() |
| 814 | + mock_result.output = "History summary" |
| 815 | + return mock_result |
| 816 | + |
| 817 | + with mock.patch.object(agent, "_run_with_retry", side_effect=fake_run_with_retry): |
| 818 | + await agent.process("Summarize my history", context={"run_state": run_state, "history_id": "abc123"}) |
| 819 | + |
| 820 | + assert len(captured_prompts) == 1 |
| 821 | + prompt = captured_prompts[0] |
| 822 | + assert "history_id: abc123" in prompt |
| 823 | + assert "run_state" not in prompt |
| 824 | + assert "AgentRunState" not in prompt |
| 825 | + |
645 | 826 | def _orchestrator_agent(self): |
646 | 827 | agent = WorkflowOrchestratorAgent(self.deps) |
647 | 828 | return agent |
|
0 commit comments