Skip to content

Commit 4c46425

Browse files
authored
Merge pull request #22607 from dannon/agent-router-message-history
Pass chat history to agents as structured messages
2 parents 21408ef + 2216df4 commit 4c46425

4 files changed

Lines changed: 282 additions & 47 deletions

File tree

lib/galaxy/agents/base.py

Lines changed: 136 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99
ABC,
1010
abstractmethod,
1111
)
12-
from collections.abc import Callable
12+
from collections.abc import (
13+
Callable,
14+
Sequence,
15+
)
1316
from dataclasses import dataclass
1417
from typing import (
1518
Any,
@@ -37,6 +40,14 @@
3740

3841
from pydantic_ai import Agent
3942
from pydantic_ai.exceptions import UnexpectedModelBehavior
43+
from pydantic_ai.messages import (
44+
ModelMessage,
45+
ModelRequest,
46+
ModelResponse,
47+
SystemPromptPart,
48+
TextPart,
49+
UserPromptPart,
50+
)
4051
from pydantic_ai.models.openai import OpenAIChatModel
4152
from pydantic_ai.providers.openai import OpenAIProvider
4253
from pydantic_ai.settings import ModelSettings
@@ -66,6 +77,21 @@
6677
# Literal inlines enum values in JSON schema, avoiding $defs that vLLM can't handle
6778
ConfidenceLiteral = Literal["low", "medium", "high"]
6879

80+
MAX_HISTORY_MESSAGES = 40
81+
"""Cap on prior messages passed as pydantic-ai ``message_history``.
82+
83+
~20 turn-pairs; tool-heavy turns produce 3-5 messages each. Bounds total token
84+
load while preserving enough recent context to keep multi-turn conversations
85+
coherent.
86+
"""
87+
88+
TOOL_HELPER_HISTORY_MESSAGES = 8
89+
"""Tighter cap when a sub-agent is invoked from inside a `@agent.tool` call.
90+
91+
Tool-context turns burn token budget faster (tool call + tool return +
92+
follow-up), so we hand the sub-agent a smaller window.
93+
"""
94+
6995
__all__ = [
7096
"ActionSuggestion",
7197
"ActionType",
@@ -78,11 +104,66 @@
78104
"extract_structured_output",
79105
"extract_usage_info",
80106
"GalaxyAgentDependencies",
107+
"MAX_HISTORY_MESSAGES",
81108
"normalize_llm_text",
82109
"SimpleGalaxyAgent",
110+
"TOOL_HELPER_HISTORY_MESSAGES",
111+
"truncate_message_history",
83112
]
84113

85114

115+
def truncate_message_history(history: list[ModelMessage], limit: int = MAX_HISTORY_MESSAGES) -> list[ModelMessage]:
    """Trim ``history`` to its first message plus the ``limit`` most recent ones.

    The opening message (typically the user's original request) is kept so the
    conversation's intent stays anchored; anything older than the newest
    ``limit`` entries is dropped. Short histories are returned unchanged.
    """
    total = len(history)
    if total <= limit:
        return history
    log.info(
        "Truncating conversation history from %d to %d messages (first + last %d)",
        total,
        limit + 1,
        limit,
    )
    # No overlap is possible here: total > limit guarantees history[0] is
    # outside the trailing window.
    return [history[0], *history[-limit:]]
130+
131+
132+
def _coerce_message_history(history: Sequence[Any]) -> list[ModelMessage]:
    """Normalize API-formatted and legacy role/content chat history."""

    def convert(entry: Any) -> Optional[ModelMessage]:
        # Already-structured pydantic-ai messages pass straight through.
        if isinstance(entry, (ModelRequest, ModelResponse)):
            return entry
        if not isinstance(entry, dict):
            return None
        text = entry.get("content")
        if text is None:
            return None
        role = str(entry.get("role", "")).lower()
        if role == "assistant":
            return ModelResponse(parts=[TextPart(content=str(text))])
        if role == "user":
            return ModelRequest(parts=[UserPromptPart(content=str(text))])
        if role == "system":
            return ModelRequest(parts=[SystemPromptPart(content=str(text))])
        # Unknown role (e.g. "tool") — treated as unsupported.
        return None

    converted = [convert(item) for item in history]
    coerced = [msg for msg in converted if msg is not None]
    dropped = len(converted) - len(coerced)
    if dropped:
        log.warning("Ignored %d unsupported conversation_history message(s)", dropped)
    return coerced
165+
166+
86167
def extract_result_content(result: Any) -> str:
87168
"""Extract text content from a pydantic-ai result (.output or .data)."""
88169
if hasattr(result, "output"):
@@ -259,15 +340,56 @@ async def process(self, query: str, context: Optional[dict[str, Any]] = None) ->
259340
return self._validation_error_response(validation_error)
260341

261342
try:
262-
full_prompt = self._prepare_prompt(query, context or {})
263-
result = await self._run_with_retry(full_prompt)
264-
return self._format_response(result, query, context or {})
343+
ctx = context or {}
344+
message_history = self._extract_message_history(ctx)
345+
full_prompt = self._prepare_prompt(query, self._strip_history_from_context(ctx))
346+
result = await self._run_with_retry(full_prompt, message_history=message_history)
347+
return self._format_response(result, query, ctx)
265348

266349
except (UnexpectedModelBehavior, OSError, ValueError) as e:
267350
log.warning(f"Error in {self.agent_type} agent: {e}")
268351
return self._get_fallback_response(query, str(e))
269352

270-
async def _run_with_retry(self, prompt: str, max_retries: int = 3, base_delay: float = 1.0):
353+
@staticmethod
def _extract_message_history(
    context: Optional[dict[str, Any]],
    limit: int = MAX_HISTORY_MESSAGES,
) -> Optional[list[ModelMessage]]:
    """Pull ``conversation_history`` out of context, normalize it, and truncate it.

    Returns None when history is missing, empty, or malformed, so callers can
    hand the result straight to ``agent.run(..., message_history=...)``
    without branching.
    """
    raw = (context or {}).get("conversation_history")
    if not raw:
        return None
    # Strings/bytes are Sequences too, but a bare string is never a valid history.
    if isinstance(raw, (str, bytes)) or not isinstance(raw, Sequence):
        log.warning("Ignoring unsupported conversation_history value of type %s", type(raw).__name__)
        return None
    coerced = _coerce_message_history(raw)
    return truncate_message_history(coerced, limit=limit) if coerced else None
375+
376+
@staticmethod
377+
def _strip_history_from_context(context: dict[str, Any]) -> dict[str, Any]:
378+
"""Drop ``conversation_history`` before rendering context as text.
379+
380+
``_prepare_prompt`` stringifies whatever's in the context dict; the raw
381+
``ModelMessage`` repr is noise once we're passing the history through
382+
the structured ``message_history`` channel.
383+
"""
384+
return {k: v for k, v in context.items() if k != "conversation_history"}
385+
386+
async def _run_with_retry(
387+
self,
388+
prompt: str,
389+
max_retries: int = 3,
390+
base_delay: float = 1.0,
391+
message_history: Optional[list[ModelMessage]] = None,
392+
):
271393
"""Run the agent with exponential backoff for retryable errors."""
272394
last_exception = None
273395

@@ -278,7 +400,12 @@ async def _run_with_retry(self, prompt: str, max_retries: int = 3, base_delay: f
278400

279401
for attempt in range(max_retries + 1):
280402
try:
281-
return await self.agent.run(prompt, deps=self.deps, model_settings=model_settings)
403+
return await self.agent.run(
404+
prompt,
405+
deps=self.deps,
406+
model_settings=model_settings,
407+
message_history=message_history,
408+
)
282409

283410
except Exception as e:
284411
last_exception = e
@@ -551,27 +678,19 @@ async def _call_agent_from_tool(
551678

552679
target_agent = ctx.deps.get_agent(agent_type, ctx.deps)
553680

554-
full_query = query
555-
if context and "conversation_history" in context:
556-
history = context["conversation_history"]
557-
if history and len(history) > 0:
558-
history_text = "Previous conversation:\n"
559-
for msg in history[-4:]:
560-
role = msg.get("role", "unknown")
561-
content = msg.get("content", "")[:200]
562-
history_text += f"{role}: {content}\n"
563-
full_query = f"{history_text}\nCurrent request: {query}"
681+
message_history = self._extract_message_history(context, limit=TOOL_HELPER_HISTORY_MESSAGES)
564682

565683
target_model_settings = {
566684
"temperature": target_agent._get_temperature(),
567685
"max_tokens": target_agent._get_max_tokens(),
568686
}
569687

570688
result = await target_agent.agent.run(
571-
full_query,
689+
query,
572690
deps=ctx.deps,
573691
usage=usage or ctx.usage,
574692
model_settings=target_model_settings,
693+
message_history=message_history,
575694
)
576695

577696
response_data = extract_result_content(result)

lib/galaxy/agents/router.py

Lines changed: 5 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -272,15 +272,13 @@ async def process(self, query: str, context: Optional[dict[str, Any]] = None) ->
272272
return self._validation_error_response(validation_error)
273273

274274
try:
275-
if context and context.get("conversation_history"):
276-
log.info(f"Router: Conversation has {len(context['conversation_history'])} messages")
275+
message_history = self._extract_message_history(context)
276+
if message_history:
277+
log.info(f"Router: passing {len(message_history)} prior messages as message_history")
277278
else:
278-
log.info("Router: Processing query with no conversation history")
279+
log.info("Router: processing query with no conversation history")
279280

280-
full_query = self._build_query_with_context(query, context)
281-
log.info(f"Router: Full query length={len(full_query)} (original={len(query)})")
282-
283-
result = await self._run_with_retry(full_query)
281+
result = await self._run_with_retry(query, message_history=message_history)
284282
content = extract_result_content(result)
285283

286284
try:
@@ -311,27 +309,6 @@ async def process(self, query: str, context: Optional[dict[str, Any]] = None) ->
311309
log.warning(f"Router agent error, using fallback: {e}")
312310
return self._handle_fallback(query, context, str(e))
313311

314-
def _build_query_with_context(self, query: str, context: Optional[dict[str, Any]]) -> str:
315-
if not context or "conversation_history" not in context:
316-
return query
317-
318-
history = context["conversation_history"]
319-
if not history:
320-
return query
321-
322-
max_history = 6
323-
if len(history) > max_history:
324-
log.debug(f"Router: Truncating conversation history from {len(history)} to {max_history} messages")
325-
326-
history_text = "Previous conversation:\n"
327-
for msg in history[-max_history:]:
328-
role = msg.get("role", "unknown")
329-
content = msg.get("content", "")
330-
history_text += f"{role}: {content}\n"
331-
history_text += f"\nCurrent query: {query}"
332-
333-
return history_text
334-
335312
def _handle_fallback(self, query: str, context: Optional[dict[str, Any]], error_msg: str) -> AgentResponse:
336313
query_lower = query.lower()
337314

lib/galaxy/webapps/galaxy/api/chat.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,10 +178,12 @@ async def query(
178178
# Build context with conversation history
179179
full_context: dict[str, Any] = query_context.copy() if query_context else {}
180180

181-
# If we have an exchange_id, ALWAYS load conversation history from database (source of truth)
181+
# If we have an exchange_id, ALWAYS load conversation history from database (source of truth).
182+
# Use structured pydantic-ai message format so the router can pass it through as
183+
# ``message_history`` rather than flattening it into a text blob.
182184
if exchange_id:
183185
db_history = await anyio.to_thread.run_sync(
184-
partial(self.chat_manager.get_chat_history, trans, exchange_id, format_for_pydantic_ai=False)
186+
partial(self.chat_manager.get_chat_history, trans, exchange_id, format_for_pydantic_ai=True)
185187
)
186188
if db_history:
187189
full_context["conversation_history"] = db_history

0 commit comments

Comments
 (0)