Skip to content

Commit 9c5b544

Browse files
committed
refactor asgi websocket closure
1 parent 40a7f33 commit 9c5b544

3 files changed

Lines changed: 31 additions & 18 deletions

File tree

tremolo/asgi_server.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
from .exceptions import (
99
InternalServerError,
10-
WebSocketException,
1110
WebSocketClientClosed,
1211
WebSocketServerClosed
1312
)
@@ -68,16 +67,20 @@ async def request_received(self, request, response):
6867
scope['scheme'] = request.scheme.decode('latin-1')
6968

7069
try:
71-
await ASGIAppWrapper(self, self.options['app'], scope, response)
70+
app = ASGIAppWrapper(self, self.options['app'], scope, response)
71+
await app
7272

73-
if not self.events['close'].done():
73+
if 'close' in self.events and not self.events['close'].done():
7474
self.logger.info('handler exited early (no close?)')
7575
await self.events['close']
7676
except (asyncio.CancelledError, Exception) as exc:
77-
if scope['type'] == 'websocket' and request.upgraded:
78-
exc = WebSocketServerClosed(cause=exc)
77+
if app.response is None: # sent
78+
self.print_exception(exc, 'app')
79+
else:
80+
if scope['type'] == 'websocket' and request.upgraded:
81+
exc = WebSocketServerClosed(cause=exc)
7982

80-
await response.handle_exception(exc)
83+
await response.handle_exception(exc, 'app')
8184
finally:
8285
scope.clear()
8386

@@ -122,26 +125,32 @@ async def receive(self):
122125
'bytes': payload
123126
}
124127
except Exception as exc:
125-
code = 1011
128+
code = 1005
126129

127130
if self._websocket is None or self.protocol.is_closing():
128-
code = 1005
129131
self.logger.info(
130132
'calling receive() after the connection is closed'
131133
)
132134
else:
133-
if isinstance(exc, WebSocketException):
135+
if isinstance(exc, WebSocketClientClosed):
134136
code = exc.code
137+
else:
138+
self.protocol.print_exception(exc, 'receive')
135139

136-
if not isinstance(exc, WebSocketClientClosed):
137-
self.protocol.print_exception(exc)
138-
await self._websocket.close(code)
139-
self._websocket = None
140+
if isinstance(exc, WebSocketServerClosed):
141+
await self._websocket.close(exc.code)
142+
elif code == 1005:
143+
await self._websocket.close(1011)
144+
else:
145+
await self._websocket.close()
140146

141147
self.protocol.request = None # force handler timeout
142148
self.protocol.set_handler_timeout(
143149
self.protocol.options['app_close_timeout']
144150
)
151+
self._websocket = None
152+
self.response = None
153+
145154
return {
146155
'type': 'websocket.disconnect',
147156
'code': code
@@ -181,11 +190,13 @@ async def receive(self):
181190
)
182191
await self.protocol.events['close']
183192
else:
184-
await self.response.handle_exception(exc)
193+
await self.response.handle_exception(exc, 'receive')
185194

186195
self.protocol.set_handler_timeout(
187196
self.protocol.options['app_close_timeout']
188197
)
198+
self.response = None
199+
189200
return {'type': 'http.disconnect'}
190201

191202
async def send(self, data):
@@ -245,6 +256,7 @@ async def send(self, data):
245256
await self.response.write(b'')
246257
self.response.close(keepalive=True)
247258
self.request = None # disallows further receive()
259+
self.response = None
248260
self.protocol.events['close'].cancel()
249261
elif data['type'] == 'websocket.send':
250262
if 'bytes' in data and data['bytes']:
@@ -256,6 +268,7 @@ async def send(self, data):
256268
elif data['type'] == 'websocket.close':
257269
await self._websocket.close(data.get('code', 1000))
258270
self._websocket = None
271+
self.response = None
259272
self.protocol.events['close'].cancel()
260273
elif data['type'] != 'http.response.start':
261274
raise InternalServerError('invalid ASGI message type')
@@ -269,5 +282,5 @@ async def send(self, data):
269282
self.response.request.upgraded):
270283
exc = WebSocketServerClosed(cause=exc)
271284

272-
await self.response.handle_exception(exc)
285+
await self.response.handle_exception(exc, 'send')
273286
self.response = None # disallows further send()

tremolo/lib/http_protocol.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ async def _handle_request(self):
257257
try:
258258
data = await self.error_received(exc, response)
259259
finally:
260-
await response.handle_exception(exc, data)
260+
await response.handle_exception(exc, data=data)
261261

262262
async def _receive_data(self):
263263
if 'request' in self.events:

tremolo/lib/http_response.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ def run_sync(func, *args):
438438

439439
self.close(keepalive=True)
440440

441-
async def handle_exception(self, exc, data=None):
441+
async def handle_exception(self, exc, *args, data=None):
442442
if self.request.protocol is None or self.request.transport is None:
443443
return
444444

@@ -448,7 +448,7 @@ async def handle_exception(self, exc, data=None):
448448

449449
if not isinstance(exc, asyncio.CancelledError):
450450
self.request.protocol.print_exception(
451-
exc, quote(unquote_to_bytes(self.request.path))
451+
exc, *args, quote(unquote_to_bytes(self.request.path))
452452
)
453453

454454
# WebSocket

0 commit comments

Comments
 (0)