Skip to content

Commit 2216df4

Browse files
committed
Annotate test history lists as list[ModelMessage]
Mypy on CI runs against test files too (not just lib/), and list invariance tripped it up: ``history = [ModelRequest(...), ModelResponse(...)]`` gets inferred as ``list[ModelRequest]``, which can't be passed to a function expecting ``list[ModelMessage]``. Annotate the four test fixtures explicitly.
1 parent 6c79ba8 commit 2216df4

1 file changed

Lines changed: 6 additions & 5 deletions

File tree

test/unit/app/test_agents.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
pydantic_ai = pytest.importorskip("pydantic_ai")
3333
from pydantic_ai import Agent
3434
from pydantic_ai.messages import (
35+
ModelMessage,
3536
ModelRequest,
3637
ModelResponse,
3738
SystemPromptPart,
@@ -344,15 +345,15 @@ async def test_router_rejects_prompt_injection_query(self):
344345
assert "rephrase your question" in response.content.lower()
345346

346347
def test_truncate_message_history_under_limit_returns_unchanged(self):
347-
history = [
348+
history: list[ModelMessage] = [
348349
ModelRequest(parts=[UserPromptPart(content="hello")]),
349350
ModelResponse(parts=[TextPart(content="hi")]),
350351
]
351352

352353
assert truncate_message_history(history, limit=40) is history
353354

354355
def test_truncate_message_history_keeps_first_plus_last_n(self):
355-
history = []
356+
history: list[ModelMessage] = []
356357
for i in range(50):
357358
history.append(ModelRequest(parts=[UserPromptPart(content=f"q{i}")]))
358359
history.append(ModelResponse(parts=[TextPart(content=f"r{i}")]))
@@ -364,7 +365,7 @@ def test_truncate_message_history_keeps_first_plus_last_n(self):
364365
assert truncated[-10:] == history[-10:] # most recent preserved
365366

366367
def test_truncate_message_history_at_exact_boundary(self):
367-
history = [ModelRequest(parts=[UserPromptPart(content=f"m{i}")]) for i in range(10)]
368+
history: list[ModelMessage] = [ModelRequest(parts=[UserPromptPart(content=f"m{i}")]) for i in range(10)]
368369

369370
# At-boundary: returned as-is, not truncated to first+last-10 (which would lose nothing here)
370371
assert truncate_message_history(history, limit=10) is history
@@ -375,7 +376,7 @@ def test_extract_message_history_returns_none_for_empty_context(self):
375376
assert QueryRouterAgent._extract_message_history({"conversation_history": []}) is None
376377

377378
def test_extract_message_history_truncates(self):
378-
history = [ModelRequest(parts=[UserPromptPart(content=f"m{i}")]) for i in range(50)]
379+
history: list[ModelMessage] = [ModelRequest(parts=[UserPromptPart(content=f"m{i}")]) for i in range(50)]
379380

380381
# Default limit is MAX_HISTORY_MESSAGES (40), so 50 -> 41 (first + last 40)
381382
result = QueryRouterAgent._extract_message_history({"conversation_history": history})
@@ -435,7 +436,7 @@ def test_extract_message_history_returns_none_for_unsupported_history_shape(self
435436
async def test_router_passes_message_history_to_run(self):
436437
"""Router should hand the structured history to ``agent.run`` via ``message_history``."""
437438
router = QueryRouterAgent(self.deps)
438-
history = [
439+
history: list[ModelMessage] = [
439440
ModelRequest(parts=[UserPromptPart(content="What histories do I have?")]),
440441
ModelResponse(parts=[TextPart(content="You have 3.")]),
441442
]

0 commit comments

Comments
 (0)