Skip to content

Commit ab98383

Browse files
committed
Ensure only a single QUIC timer task per connection
This prevents Hypercorn calling aioquic's interface too many times, due to many tasks running concurrently, triggering exponential backoff and errors. Many thanks to @rthalley from whom's work this is based.
1 parent 81bbb32 commit ab98383

6 files changed

Lines changed: 144 additions & 85 deletions

File tree

src/hypercorn/asyncio/tcp_server.py

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
import asyncio
44
from ssl import SSLError
5-
from typing import Any, Generator, Optional
5+
from typing import Any, Generator
66

77
from .task_group import TaskGroup
8-
from .worker_context import WorkerContext
8+
from .worker_context import AsyncioSingleTask, WorkerContext
99
from ..config import Config
1010
from ..events import Closed, Event, RawData, Updated
1111
from ..protocol import ProtocolWrapper
@@ -33,9 +33,7 @@ def __init__(
3333
self.reader = reader
3434
self.writer = writer
3535
self.send_lock = asyncio.Lock()
36-
self.idle_lock = asyncio.Lock()
37-
38-
self._idle_handle: Optional[asyncio.Task] = None
36+
self.idle_task = AsyncioSingleTask()
3937

4038
def __await__(self) -> Generator[Any, None, None]:
4139
return self.run().__await__()
@@ -54,6 +52,7 @@ async def run(self) -> None:
5452
alpn_protocol = "http/1.1"
5553

5654
async with TaskGroup(self.loop) as task_group:
55+
self._task_group = task_group
5756
self.protocol = ProtocolWrapper(
5857
self.app,
5958
self.config,
@@ -66,7 +65,7 @@ async def run(self) -> None:
6665
alpn_protocol,
6766
)
6867
await self.protocol.initiate()
69-
await self._start_idle()
68+
await self.idle_task.restart(task_group, self._idle_timeout)
7069
await self._read_data()
7170
except OSError:
7271
pass
@@ -85,9 +84,9 @@ async def protocol_send(self, event: Event) -> None:
8584
await self._close()
8685
elif isinstance(event, Updated):
8786
if event.idle:
88-
await self._start_idle()
87+
await self.idle_task.restart(self._task_group, self._idle_timeout)
8988
else:
90-
await self._stop_idle()
89+
await self.idle_task.stop()
9190

9291
async def _read_data(self) -> None:
9392
while not self.reader.at_eof():
@@ -124,28 +123,13 @@ async def _close(self) -> None:
124123
):
125124
pass # Already closed
126125
finally:
127-
await self._stop_idle()
126+
await self.idle_task.stop()
128127

129128
async def _initiate_server_close(self) -> None:
130129
await self.protocol.handle(Closed())
131130
self.writer.close()
132131

133-
async def _start_idle(self) -> None:
134-
async with self.idle_lock:
135-
if self._idle_handle is None:
136-
self._idle_handle = self.loop.create_task(self._run_idle())
137-
138-
async def _stop_idle(self) -> None:
139-
async with self.idle_lock:
140-
if self._idle_handle is not None:
141-
self._idle_handle.cancel()
142-
try:
143-
await self._idle_handle
144-
except asyncio.CancelledError:
145-
pass
146-
self._idle_handle = None
147-
148-
async def _run_idle(self) -> None:
132+
async def _idle_timeout(self) -> None:
149133
try:
150134
await asyncio.wait_for(self.context.terminated.wait(), self.config.keep_alive_timeout)
151135
except asyncio.TimeoutError:

src/hypercorn/asyncio/worker_context.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,37 @@
11
from __future__ import annotations
22

33
import asyncio
4-
from typing import Optional, Type, Union
4+
from typing import Callable, Optional, Type, Union
55

6-
from ..typing import Event
6+
from ..typing import Event, SingleTask, TaskGroup
7+
8+
9+
class AsyncioSingleTask:
10+
def __init__(self) -> None:
11+
self._handle: Optional[asyncio.Task] = None
12+
self._lock = asyncio.Lock()
13+
14+
async def restart(self, task_group: TaskGroup, action: Callable) -> None:
15+
async with self._lock:
16+
if self._handle is not None:
17+
self._handle.cancel()
18+
try:
19+
await self._handle
20+
except asyncio.CancelledError:
21+
pass
22+
23+
self._handle = task_group._task_group.create_task(action()) # type: ignore
24+
25+
async def stop(self) -> None:
26+
async with self._lock:
27+
if self._handle is not None:
28+
self._handle.cancel()
29+
try:
30+
await self._handle
31+
except asyncio.CancelledError:
32+
pass
33+
34+
self._handle = None
735

836

937
class EventWrapper:
@@ -25,6 +53,7 @@ def is_set(self) -> bool:
2553

2654
class WorkerContext:
2755
event_class: Type[Event] = EventWrapper
56+
single_task_class: Type[SingleTask] = AsyncioSingleTask
2857

2958
def __init__(self, max_requests: Optional[int]) -> None:
3059
self.max_requests = max_requests

src/hypercorn/protocol/quic.py

Lines changed: 44 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from __future__ import annotations
22

3+
from dataclasses import dataclass
34
from functools import partial
4-
from typing import Awaitable, Callable, Dict, Optional, Tuple
5+
from typing import Awaitable, Callable, Dict, Optional, Set, Tuple
56

67
from aioquic.buffer import Buffer
78
from aioquic.h3.connection import H3_ALPN
@@ -22,7 +23,15 @@
2223
from .h3 import H3Protocol
2324
from ..config import Config
2425
from ..events import Closed, Event, RawData
25-
from ..typing import AppWrapper, TaskGroup, WorkerContext
26+
from ..typing import AppWrapper, SingleTask, TaskGroup, WorkerContext
27+
28+
29+
@dataclass
30+
class _Connection:
31+
cids: Set[bytes]
32+
quic: QuicConnection
33+
task: SingleTask
34+
h3: Optional[H3Protocol] = None
2635

2736

2837
class QuicProtocol:
@@ -38,8 +47,7 @@ def __init__(
3847
self.app = app
3948
self.config = config
4049
self.context = context
41-
self.connections: Dict[bytes, QuicConnection] = {}
42-
self.http_connections: Dict[QuicConnection, H3Protocol] = {}
50+
self.connections: Dict[bytes, _Connection] = {}
4351
self.send = send
4452
self.server = server
4553
self.task_group = task_group
@@ -49,7 +57,7 @@ def __init__(
4957

5058
@property
5159
def idle(self) -> bool:
52-
return len(self.connections) == 0 and len(self.http_connections) == 0
60+
return len(self.connections) == 0
5361

5462
async def handle(self, event: Event) -> None:
5563
if isinstance(event, RawData):
@@ -76,32 +84,46 @@ async def handle(self, event: Event) -> None:
7684
and header.packet_type == PACKET_TYPE_INITIAL
7785
and not self.context.terminated.is_set()
7886
):
79-
connection = QuicConnection(
87+
quic_connection = QuicConnection(
8088
configuration=self.quic_config,
8189
original_destination_connection_id=header.destination_cid,
8290
)
91+
connection = _Connection(
92+
cids={header.destination_cid, quic_connection.host_cid},
93+
quic=quic_connection,
94+
task=self.context.single_task_class(),
95+
)
8396
self.connections[header.destination_cid] = connection
84-
self.connections[connection.host_cid] = connection
97+
self.connections[quic_connection.host_cid] = connection
8598

8699
if connection is not None:
87-
connection.receive_datagram(event.data, event.address, now=self.context.time())
100+
connection.quic.receive_datagram(event.data, event.address, now=self.context.time())
88101
await self._handle_events(connection, event.address)
89102
elif isinstance(event, Closed):
90103
pass
91104

92-
async def send_all(self, connection: QuicConnection) -> None:
93-
for data, address in connection.datagrams_to_send(now=self.context.time()):
105+
async def send_all(self, connection: _Connection) -> None:
106+
for data, address in connection.quic.datagrams_to_send(now=self.context.time()):
94107
await self.send(RawData(data=data, address=address))
95108

109+
timer = connection.quic.get_timer()
110+
if timer is not None:
111+
await connection.task.restart(
112+
self.task_group, partial(self._handle_timer, timer, connection)
113+
)
114+
96115
async def _handle_events(
97-
self, connection: QuicConnection, client: Optional[Tuple[str, int]] = None
116+
self, connection: _Connection, client: Optional[Tuple[str, int]] = None
98117
) -> None:
99-
event = connection.next_event()
118+
event = connection.quic.next_event()
100119
while event is not None:
101120
if isinstance(event, ConnectionTerminated):
102-
pass
121+
await connection.task.stop()
122+
for cid in connection.cids:
123+
del self.connections[cid]
124+
connection.cids = set()
103125
elif isinstance(event, ProtocolNegotiated):
104-
self.http_connections[connection] = H3Protocol(
126+
connection.h3 = H3Protocol(
105127
self.app,
106128
self.config,
107129
self.context,
@@ -112,24 +134,22 @@ async def _handle_events(
112134
partial(self.send_all, connection),
113135
)
114136
elif isinstance(event, ConnectionIdIssued):
137+
connection.cids.add(event.connection_id)
115138
self.connections[event.connection_id] = connection
116139
elif isinstance(event, ConnectionIdRetired):
140+
connection.cids.remove(event.connection_id)
117141
del self.connections[event.connection_id]
118142

119-
if connection in self.http_connections:
120-
await self.http_connections[connection].handle(event)
143+
if connection.h3 is not None:
144+
await connection.h3.handle(event)
121145

122-
event = connection.next_event()
146+
event = connection.quic.next_event()
123147

124148
await self.send_all(connection)
125149

126-
timer = connection.get_timer()
127-
if timer is not None:
128-
self.task_group.spawn(self._handle_timer, timer, connection)
129-
130-
async def _handle_timer(self, timer: float, connection: QuicConnection) -> None:
150+
async def _handle_timer(self, timer: float, connection: _Connection) -> None:
131151
wait = max(0, timer - self.context.time())
132152
await self.context.sleep(wait)
133-
if connection._close_at is not None:
134-
connection.handle_timer(now=self.context.time())
153+
if connection.quic._close_at is not None:
154+
connection.quic.handle_timer(now=self.context.time())
135155
await self._handle_events(connection, None)

src/hypercorn/trio/tcp_server.py

Lines changed: 13 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from __future__ import annotations
22

33
from math import inf
4-
from typing import Any, Generator, Optional
4+
from typing import Any, Generator
55

66
import trio
77

88
from .task_group import TaskGroup
9-
from .worker_context import WorkerContext
9+
from .worker_context import TrioSingleTask, WorkerContext
1010
from ..config import Config
1111
from ..events import Closed, Event, RawData, Updated
1212
from ..protocol import ProtocolWrapper
@@ -25,11 +25,9 @@ def __init__(
2525
self.context = context
2626
self.protocol: ProtocolWrapper
2727
self.send_lock = trio.Lock()
28-
self.idle_lock = trio.Lock()
28+
self.idle_task = TrioSingleTask()
2929
self.stream = stream
3030

31-
self._idle_handle: Optional[trio.CancelScope] = None
32-
3331
def __await__(self) -> Generator[Any, None, None]:
3432
return self.run().__await__()
3533

@@ -66,7 +64,7 @@ async def run(self) -> None:
6664
alpn_protocol,
6765
)
6866
await self.protocol.initiate()
69-
await self._start_idle()
67+
await self.idle_task.restart(self._task_group, self._idle_timeout)
7068
await self._read_data()
7169
except OSError:
7270
pass
@@ -87,9 +85,9 @@ async def protocol_send(self, event: Event) -> None:
8785
await self.protocol.handle(Closed())
8886
elif isinstance(event, Updated):
8987
if event.idle:
90-
await self._start_idle()
88+
await self.idle_task.restart(self._task_group, self._idle_timeout)
9189
else:
92-
await self._stop_idle()
90+
await self.idle_task.stop()
9391

9492
async def _read_data(self) -> None:
9593
while True:
@@ -122,30 +120,13 @@ async def _close(self) -> None:
122120
pass
123121
await self.stream.aclose()
124122

123+
async def _idle_timeout(self) -> None:
124+
with trio.move_on_after(self.config.keep_alive_timeout):
125+
await self.context.terminated.wait()
126+
127+
with trio.CancelScope(shield=True):
128+
await self._initiate_server_close()
129+
125130
async def _initiate_server_close(self) -> None:
126131
await self.protocol.handle(Closed())
127132
await self.stream.aclose()
128-
129-
async def _start_idle(self) -> None:
130-
async with self.idle_lock:
131-
if self._idle_handle is None:
132-
self._idle_handle = await self._task_group._nursery.start(self._run_idle)
133-
134-
async def _stop_idle(self) -> None:
135-
async with self.idle_lock:
136-
if self._idle_handle is not None:
137-
self._idle_handle.cancel()
138-
self._idle_handle = None
139-
140-
async def _run_idle(
141-
self,
142-
task_status: trio._core._run._TaskStatus = trio.TASK_STATUS_IGNORED,
143-
) -> None:
144-
cancel_scope = trio.CancelScope()
145-
task_status.started(cancel_scope)
146-
with cancel_scope:
147-
with trio.move_on_after(self.config.keep_alive_timeout):
148-
await self.context.terminated.wait()
149-
150-
cancel_scope.shield = True
151-
await self._initiate_server_close()

0 commit comments

Comments
 (0)