Skip to content

Commit 15e95df

Browse files
committed
add graceful timeout option in WebSocket.close()
1 parent 0519f84 commit 15e95df

3 files changed

Lines changed: 39 additions & 28 deletions

File tree

tremolo/lib/http_request.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,8 @@ async def recv(self, size=-1, *, timeout=None, raw=True):
153153

154154
if self._stream is None:
155155
self._stream = self.stream(timeout, raw)
156+
elif timeout is not None: # update
157+
self.timeout = timeout
156158

157159
if size == -1:
158160
return await self._stream.__anext__()

tremolo/lib/request.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77

88

99
class Request:
10-
__slots__ = ('protocol', 'context', 'body_size')
10+
__slots__ = ('protocol', 'context', 'body_size', 'timeout')
1111

1212
def __init__(self, protocol):
1313
self.protocol = protocol
1414
self.context = RequestContext()
1515
self.body_size = 0
16+
self.timeout = protocol.options['keepalive_timeout']
1617

1718
@property
1819
def server(self):
@@ -48,12 +49,12 @@ def clear(self):
4849
self.protocol = None # cut off access to the protocol object
4950

5051
async def recv(self, timeout=None):
51-
if timeout is None:
52-
timeout = self.server.options['keepalive_timeout']
52+
if timeout is not None:
53+
self.timeout = timeout
5354

5455
while self.server.queue:
5556
try:
56-
data = await self.server.queue[0].get(timeout)
57+
data = await self.server.queue[0].get(self.timeout)
5758
except asyncio.CancelledError as exc:
5859
raise TimeoutError('recv timeout') from exc
5960

tremolo/lib/websocket.py

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@ class WebSocket:
1414
def __init__(self, request, response):
1515
self.request = request
1616
self.response = response
17+
self.max_payload_size = request.server.options['ws_max_payload_size']
18+
self.ping_interval = request.server.options['keepalive_timeout'] / 2
19+
1720
self.fin = 1
1821
self.opcode = 0
1922

20-
self._max_payload_size = request.server.options['ws_max_payload_size']
21-
self._receive_timeout = request.server.options['keepalive_timeout'] / 2
22-
2323
def __aiter__(self):
2424
return self
2525

@@ -44,18 +44,18 @@ async def accept(self):
4444
self.response.set_header(b'Sec-WebSocket-Accept', accept_key)
4545
await self.response.write()
4646

47-
async def recv(self):
47+
async def recv(self, timeout=None):
4848
try:
49-
first_byte, second_byte = await self.request.recv(2)
49+
b1, b2 = await self.request.recv(2, timeout=timeout)
5050
except ValueError as exc:
5151
raise WebSocketClientClosed(
5252
'connection closed: recv failed'
5353
) from exc
5454

55-
fin = (first_byte & 0x80) >> 7
56-
opcode = first_byte & 0x0f
57-
is_masked = (second_byte & 0x80) >> 7
58-
payload_length = second_byte & 0x7f
55+
fin = (b1 & 0x80) >> 7
56+
opcode = b1 & 0x0f
57+
is_masked = (b2 & 0x80) >> 7
58+
payload_length = b2 & 0x7f
5959

6060
if opcode != 0:
6161
if self.fin != 1:
@@ -80,10 +80,10 @@ async def recv(self):
8080
payload_length = int.from_bytes(await self.request.recv(8),
8181
byteorder='big')
8282

83-
if payload_length > self._max_payload_size:
83+
if payload_length > self.max_payload_size:
8484
raise WebSocketServerClosed(
8585
'%d exceeds maximum payload size (%d)' %
86-
(payload_length, self._max_payload_size),
86+
(payload_length, self.max_payload_size),
8787
code=1009
8888
)
8989

@@ -126,18 +126,18 @@ async def recv(self):
126126
code=1008
127127
)
128128

129-
async def receive(self):
129+
async def receive(self, timeout=None):
130130
payload = bytearray()
131131

132-
while True:
132+
while self.opcode != 8:
133133
coro = self.ping()
134134
timer = self.request.server.loop.call_at(
135-
self.request.server.loop.time() + self._receive_timeout,
135+
self.request.server.loop.time() + self.ping_interval,
136136
self.request.server.create_task, coro
137137
)
138138

139139
try:
140-
frame = await self.recv()
140+
frame = await self.recv(timeout)
141141
except TimeoutError as exc:
142142
raise WebSocketServerClosed('receive timeout',
143143
code=1000) from exc
@@ -151,12 +151,14 @@ async def receive(self):
151151

152152
payload.extend(frame)
153153

154-
if len(payload) > self._max_payload_size:
154+
if len(payload) > self.max_payload_size:
155155
raise WebSocketServerClosed('maximum payload size exceeded',
156156
code=1009)
157157

158158
if self.fin == 1:
159159
break
160+
else:
161+
raise WebSocketClientClosed('connection already closed')
160162

161163
if self.opcode == 1:
162164
return payload.decode('utf-8')
@@ -174,24 +176,24 @@ def create_frame(payload_data, fin=1, opcode=None, mask=False):
174176
if opcode == 1:
175177
payload_data = payload_data.encode('utf-8')
176178

177-
first_byte = (fin << 7) | opcode
179+
b1 = (fin << 7) | opcode
178180
payload_length = len(payload_data)
179181

180182
if payload_length < 126:
181-
second_byte = payload_length
183+
b2 = payload_length
182184
payload_length_data = b''
183185
elif payload_length < 65536:
184-
second_byte = 126
186+
b2 = 126
185187
payload_length_data = payload_length.to_bytes(2, byteorder='big')
186188
else:
187-
second_byte = 127
189+
b2 = 127
188190
payload_length_data = payload_length.to_bytes(8, byteorder='big')
189191

190192
if mask:
191-
second_byte |= (1 << 7)
193+
b2 |= (1 << 7)
192194
masking_key = os.urandom(4)
193195

194-
frame_header = bytes([first_byte, second_byte]) + payload_length_data
196+
frame_header = bytes([b1, b2]) + payload_length_data
195197

196198
if mask:
197199
masked_payload_data = bytes(
@@ -216,9 +218,15 @@ async def ping(self, data=b''):
216218
async def pong(self, data=b''):
217219
await self.send(data, opcode=10)
218220

219-
async def close(self, code=1000):
221+
async def close(self, code=1000, *, timeout=None):
220222
try:
221223
await self.send(code.to_bytes(2, byteorder='big'), opcode=8)
222-
self.response.close(keepalive=True)
223224
except RuntimeError:
225+
return
226+
227+
try:
228+
await self.receive(timeout)
229+
except (WebSocketClientClosed, WebSocketServerClosed):
224230
pass
231+
finally:
232+
self.response.close()

0 commit comments

Comments
 (0)