Skip to content

Commit db29be6

Browse files
test: cover stale session middleware edges
1 parent f1c963a commit db29be6

2 files changed

Lines changed: 113 additions & 5 deletions

File tree

src/godot_ai/asgi.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,6 @@ async def _send_response(
8585
body: bytes,
8686
send: Send,
8787
) -> None:
88-
if start_message is None:
89-
return
90-
9188
rewritten = self._rewrite_stale_session_body(start_message, body)
9289
response_body = rewritten if rewritten is not None else body
9390
headers = start_message.get("headers", [])
@@ -99,8 +96,6 @@ async def _send_response(
9996
await send({"type": "http.response.body", "body": response_body, "more_body": False})
10097

10198
def _rewrite_stale_session_body(self, start_message: Message, body: bytes) -> bytes | None:
102-
if start_message.get("status") != HTTPStatus.NOT_FOUND:
103-
return None
10499
try:
105100
payload = json.loads(body.decode("utf-8"))
106101
except (UnicodeDecodeError, json.JSONDecodeError):

tests/unit/test_asgi_session_diagnostics.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,19 @@ async def send(message):
2828
return sent
2929

3030

31+
async def _single_asgi_request(app, *, scope):
32+
sent = []
33+
34+
async def receive():
35+
return {"type": "websocket.disconnect"}
36+
37+
async def send(message):
38+
sent.append(message)
39+
40+
await app(scope, receive, send)
41+
return sent
42+
43+
3144
@pytest.mark.anyio
3245
async def test_stale_mcp_session_diagnostic_rewrites_sdk_session_not_found():
3346
async def sdk_stale_session_response(scope, receive, send):
@@ -76,6 +89,36 @@ async def sdk_stale_session_response(scope, receive, send):
7689
}
7790

7891

92+
@pytest.mark.anyio
93+
async def test_stale_mcp_session_diagnostic_handles_chunked_stale_session_body():
94+
async def sdk_stale_session_response(scope, receive, send):
95+
await send({"type": "http.response.start", "status": 404, "headers": []})
96+
await send(
97+
{
98+
"type": "http.response.body",
99+
"body": b'{"jsonrpc":"2.0","id":"server-error",',
100+
"more_body": True,
101+
}
102+
)
103+
await send(
104+
{
105+
"type": "http.response.body",
106+
"body": b'"error":{"code":-32600,"message":"Session not found"}}',
107+
"more_body": False,
108+
}
109+
)
110+
111+
app = StaleMcpSessionDiagnosticMiddleware(sdk_stale_session_response)
112+
113+
sent = await _single_http_request(app)
114+
115+
assert sent[0]["status"] == 404
116+
assert (b"content-type", b"application/json") in sent[0]["headers"]
117+
body = json.loads(sent[1]["body"])
118+
assert body["error"]["message"] == STALE_MCP_SESSION_MESSAGE
119+
assert body["error"]["data"]["action"] == "reinitialize_mcp_session"
120+
121+
79122
@pytest.mark.anyio
80123
async def test_stale_mcp_session_diagnostic_leaves_other_responses_unchanged():
81124
async def ok_response(scope, receive, send):
@@ -130,6 +173,40 @@ async def streaming_response(scope, receive, send):
130173
]
131174

132175

176+
@pytest.mark.anyio
177+
async def test_stale_mcp_session_diagnostic_passes_non_http_scopes_through():
178+
async def websocket_response(scope, receive, send):
179+
await send({"type": "websocket.close", "code": 1000})
180+
181+
app = StaleMcpSessionDiagnosticMiddleware(websocket_response)
182+
183+
sent = await _single_asgi_request(app, scope={"type": "websocket", "path": "/ws"})
184+
185+
assert sent == [{"type": "websocket.close", "code": 1000}]
186+
187+
188+
@pytest.mark.anyio
189+
async def test_stale_mcp_session_diagnostic_passes_unhandled_asgi_messages_through():
190+
async def extension_message_response(scope, receive, send):
191+
await send(
192+
{
193+
"type": "http.response.debug",
194+
"info": {"note": "kept for downstream middleware"},
195+
}
196+
)
197+
198+
app = StaleMcpSessionDiagnosticMiddleware(extension_message_response)
199+
200+
sent = await _single_http_request(app)
201+
202+
assert sent == [
203+
{
204+
"type": "http.response.debug",
205+
"info": {"note": "kept for downstream middleware"},
206+
}
207+
]
208+
209+
133210
@pytest.mark.anyio
134211
async def test_stale_mcp_session_diagnostic_leaves_other_404_responses_unchanged():
135212
async def not_found_response(scope, receive, send):
@@ -156,6 +233,34 @@ async def not_found_response(scope, receive, send):
156233
]
157234

158235

236+
@pytest.mark.anyio
237+
async def test_stale_mcp_session_diagnostic_leaves_other_json_rpc_404_errors_unchanged():
238+
async def json_rpc_not_found_response(scope, receive, send):
239+
body = json.dumps(
240+
{
241+
"jsonrpc": "2.0",
242+
"id": "server-error",
243+
"error": {"code": -32000, "message": "Tool not found"},
244+
}
245+
).encode()
246+
await send(
247+
{
248+
"type": "http.response.start",
249+
"status": 404,
250+
"headers": [(b"content-type", b"application/json")],
251+
}
252+
)
253+
await send({"type": "http.response.body", "body": body, "more_body": False})
254+
255+
app = StaleMcpSessionDiagnosticMiddleware(json_rpc_not_found_response)
256+
257+
sent = await _single_http_request(app)
258+
259+
body = json.loads(sent[1]["body"])
260+
assert body["error"] == {"code": -32000, "message": "Tool not found"}
261+
assert "data" not in body["error"]
262+
263+
159264
def test_create_server_wraps_streamable_http_app_with_stale_session_diagnostic():
160265
server = create_server()
161266

@@ -164,6 +269,14 @@ def test_create_server_wraps_streamable_http_app_with_stale_session_diagnostic()
164269
assert isinstance(app, StaleMcpSessionDiagnosticMiddleware)
165270

166271

272+
def test_create_server_does_not_wrap_sse_app_with_stale_session_diagnostic():
273+
server = create_server()
274+
275+
app = server.http_app(transport="sse")
276+
277+
assert not isinstance(app, StaleMcpSessionDiagnosticMiddleware)
278+
279+
167280
def test_stale_mcp_session_diagnostic_preserves_fastmcp_app_state():
168281
server = create_server()
169282

0 commit comments

Comments
 (0)