Skip to content

Commit e66cce2

Browse files
authored
fix: task.wait() hangs indefinitely when task enters input_required (#3798)
* fix: resolve OpenAPI 3.x server variables in _create_default_client When an OpenAPI spec defines server variables (e.g. `https://{region}.api.example.com/v1`), the default values are now substituted before constructing the httpx client base URL. Previously, the URL was used as-is, causing all requests to fail for specs that use server variable templating. Fixes #1681 * fix: use str.replace instead of format_map for server variable substitution format_map applies Python string formatting rules, so variable names like {api.version} would be treated as attribute access and raise errors. Literal token replacement handles all valid OpenAPI variable names safely. * fix: task.wait() now returns on input_required instead of hanging Previously, wait() used a terminal-state allowlist (completed, failed, cancelled), so tasks entering input_required would hang until timeout. Replaced with inverse logic: return whenever the task exits the 'working' state. This handles input_required and any future blocking states without needing to update the allowlist. Fixes #3779 * fix: include submitted in in_progress_states to avoid premature return * fix: revert submitted, update state docstring to match MCP spec * fix: add _wait_terminal() so result() waits for completed/failed/cancelled wait() correctly returns on input_required for human-in-the-loop use cases, but result() needs to wait until the task fully resolves. Add a private _wait_terminal() helper that loops through non-terminal states and use it in all result() implementations.
1 parent db6d7a8 commit e66cce2

2 files changed

Lines changed: 48 additions & 7 deletions

File tree

src/fastmcp/client/tasks.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,8 @@ async def wait(
216216
on status changes when server sends notifications/tasks/status.
217217
218218
Args:
219-
state: Desired state ('submitted', 'working', 'completed', 'failed').
220-
If None, waits for any terminal state (completed/failed)
219+
state: Desired state ('working', 'input_required', 'completed', 'failed', 'cancelled').
220+
If None, waits until the task exits the 'working' state (completed, failed, cancelled, input_required, etc.)
221221
timeout: Maximum time to wait in seconds
222222
223223
Returns:
@@ -237,15 +237,15 @@ async def wait(
237237
self._status_event = asyncio.Event()
238238

239239
start = time.time()
240-
terminal_states = {"completed", "failed", "cancelled"}
240+
in_progress_states = {"working"}
241241
poll_interval = 0.5 # Fallback polling interval (500ms)
242242

243243
while True:
244244
# Check cached status first (updated by notifications)
245245
if self._status_cache:
246246
current = self._status_cache.status
247247
if state is None:
248-
if current in terminal_states:
248+
if current not in in_progress_states:
249249
return self._status_cache
250250
elif current == state:
251251
return self._status_cache
@@ -269,6 +269,21 @@ async def wait(
269269
# Fallback: poll server (notification didn't arrive in time)
270270
self._status_cache = await self._client.get_task_status(self._task_id)
271271

272+
async def _wait_terminal(self, timeout: float = 300.0) -> GetTaskResult:
273+
"""Wait until task reaches a terminal state (completed, failed, cancelled).
274+
275+
Unlike wait(), this will not return on input_required — it continues
276+
waiting until the task fully resolves. Used internally by result().
277+
"""
278+
terminal_states = {"completed", "failed", "cancelled"}
279+
status = await self.wait(timeout=timeout)
280+
while status.status not in terminal_states:
281+
# Task is in a non-terminal state (e.g. input_required) — reset
282+
# cache so the next wait() call blocks instead of returning immediately.
283+
self._status_cache = None
284+
status = await self.wait(timeout=timeout)
285+
return status
286+
272287
async def cancel(self) -> None:
273288
"""Cancel this task, transitioning it to cancelled state.
274289
@@ -354,7 +369,7 @@ async def result(self) -> CallToolResult:
354369
self._check_client_connected()
355370

356371
# Wait for completion using event-based wait (respects notifications)
357-
await self.wait()
372+
await self._wait_terminal()
358373

359374
# Get the raw result (dict or CallToolResult)
360375
raw_result = await self._client.get_task_result(self._task_id)
@@ -445,7 +460,7 @@ async def result(self) -> mcp.types.GetPromptResult:
445460
self._check_client_connected()
446461

447462
# Wait for completion using event-based wait (respects notifications)
448-
await self.wait()
463+
await self._wait_terminal()
449464

450465
# Get the raw MCP result
451466
mcp_result = await self._client.get_task_result(self._task_id)
@@ -517,7 +532,7 @@ async def result(
517532
self._check_client_connected()
518533

519534
# Wait for completion using event-based wait (respects notifications)
520-
await self.wait()
535+
await self._wait_terminal()
521536

522537
# Get the raw MCP result
523538
mcp_result = await self._client.get_task_result(self._task_id)

tests/client/tasks/test_client_task_notifications.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import asyncio
99
import time
10+
from datetime import datetime, timezone
1011

1112
import pytest
1213
from mcp.types import GetTaskResult
@@ -207,3 +208,28 @@ async def test_notification_with_failed_task(task_notification_server):
207208
assert (
208209
status.statusMessage is not None
209210
) # Error details in statusMessage per spec
211+
212+
213+
async def test_wait_returns_on_input_required(task_notification_server):
214+
"""wait() should return immediately when task enters input_required, not hang."""
215+
async with Client(task_notification_server) as client:
216+
task = await client.call_tool("quick_task", {"value": 1}, task=True)
217+
218+
# Directly inject an input_required status into the cache and signal the event
219+
now = datetime.now(timezone.utc)
220+
input_required_status = GetTaskResult(
221+
taskId=task._task_id,
222+
status="input_required",
223+
statusMessage="Waiting for user input",
224+
createdAt=now,
225+
lastUpdatedAt=now,
226+
ttl=None,
227+
)
228+
task._status_cache = input_required_status
229+
if task._status_event is None:
230+
task._status_event = asyncio.Event()
231+
task._status_event.set()
232+
233+
# Should return immediately with input_required, not hang for 300s
234+
status = await task.wait(timeout=2.0)
235+
assert status.status == "input_required"

0 commit comments

Comments
 (0)