|
4 | 4 | import uuid |
5 | 5 | import asyncio |
6 | 6 |
|
| 7 | +import httpx |
| 8 | + |
7 | 9 | from asyncio import Queue |
8 | 10 | from typing import ( |
9 | 11 | Dict, |
|
26 | 28 | OutputType, |
27 | 29 | UnexpectedEndOfExecution, |
28 | 30 | ) |
| 31 | +from consts import JUPYTER_BASE_URL |
29 | 32 | from errors import ExecutionError |
30 | 33 | from envs import get_envs |
31 | 34 |
|
32 | 35 | logger = logging.getLogger(__name__) |
33 | 36 |
|
34 | 37 | MAX_RECONNECT_RETRIES = 3 |
35 | 38 | PING_TIMEOUT = 30 |
| 39 | +KEEPALIVE_INTERVAL = 5 # seconds between keepalive pings during streaming |
36 | 40 |
|
37 | 41 |
|
38 | 42 | class Execution: |
@@ -97,6 +101,22 @@ async def connect(self): |
97 | 101 | name="receive_message", |
98 | 102 | ) |
99 | 103 |
|
| 104 | + async def interrupt(self): |
| 105 | + """Interrupt the current kernel execution via the Jupyter REST API.""" |
| 106 | + try: |
| 107 | + async with httpx.AsyncClient() as client: |
| 108 | + response = await client.post( |
| 109 | + f"{JUPYTER_BASE_URL}/api/kernels/{self.context_id}/interrupt" |
| 110 | + ) |
| 111 | + if response.is_success: |
| 112 | + logger.info(f"Kernel {self.context_id} interrupted successfully") |
| 113 | + else: |
| 114 | + logger.error( |
| 115 | + f"Failed to interrupt kernel {self.context_id}: {response.status_code}" |
| 116 | + ) |
| 117 | + except Exception as e: |
| 118 | + logger.error(f"Error interrupting kernel {self.context_id}: {e}") |
| 119 | + |
100 | 120 | def _get_execute_request( |
101 | 121 | self, msg_id: str, code: Union[str, StrictStr], background: bool |
102 | 122 | ) -> str: |
@@ -238,8 +258,24 @@ async def _cleanup_env_vars(self, env_vars: Dict[StrictStr, str]): |
238 | 258 | async def _wait_for_result(self, message_id: str): |
239 | 259 | queue = self._executions[message_id].queue |
240 | 260 |
|
| 261 | + # Use a timeout on queue.get() to periodically send keepalives. |
| 262 | + # Without keepalives, the generator blocks indefinitely waiting for |
| 263 | + # kernel output. If the client silently disappears (e.g. network |
| 264 | + # failure), uvicorn can only detect the broken connection when it |
| 265 | + # tries to write — so we force a write every KEEPALIVE_INTERVAL |
| 266 | + # seconds. This ensures timely disconnect detection and kernel |
| 267 | + # interrupt for abandoned executions (see #213). |
241 | 268 | while True: |
242 | | - output = await queue.get() |
| 269 | + try: |
| 270 | + output = await asyncio.wait_for(queue.get(), timeout=KEEPALIVE_INTERVAL) |
| 271 | + except asyncio.TimeoutError: |
| 272 | + # Yield a keepalive so Starlette writes to the socket. |
| 273 | + # If the client has disconnected, the write fails and |
| 274 | + # uvicorn delivers http.disconnect, which cancels this |
| 275 | + # generator via CancelledError. |
| 276 | + yield {"type": "keepalive"} |
| 277 | + continue |
| 278 | + |
243 | 279 | if output.type == OutputType.END_OF_EXECUTION: |
244 | 280 | break |
245 | 281 |
|
@@ -362,11 +398,26 @@ async def execute( |
362 | 398 | ) |
363 | 399 | await execution.queue.put(UnexpectedEndOfExecution()) |
364 | 400 |
|
365 | | - # Stream the results |
366 | | - async for item in self._wait_for_result(message_id): |
367 | | - yield item |
368 | | - |
369 | | - del self._executions[message_id] |
| 401 | + # Stream the results. |
| 402 | + # If the client disconnects (Starlette cancels the task), we |
| 403 | + # interrupt the kernel so the next execution isn't blocked (#213). |
| 404 | + try: |
| 405 | + async for item in self._wait_for_result(message_id): |
| 406 | + yield item |
| 407 | + except (asyncio.CancelledError, GeneratorExit): |
| 408 | + logger.warning( |
| 409 | + f"Client disconnected during execution ({message_id}), interrupting kernel" |
| 410 | + ) |
| 411 | + # Shield the interrupt from the ongoing cancellation so |
| 412 | + # the HTTP request to the kernel actually completes. |
| 413 | + try: |
| 414 | + await asyncio.shield(self.interrupt()) |
| 415 | + except asyncio.CancelledError: |
| 416 | + pass |
| 417 | + raise |
| 418 | + finally: |
| 419 | + if message_id in self._executions: |
| 420 | + del self._executions[message_id] |
370 | 421 |
|
371 | 422 | # Clean up env vars in a separate request after the main code has run |
372 | 423 | if env_vars: |
|
0 commit comments