Skip to content

Commit 25bb4d9

Browse files
committed
Fix websocket.disconnect not delivered after server-initiated close
Signed-off-by: JaeHyuck Sa <jaehyuck.sa.dev@gmail.com>
1 parent 189f1be commit 25bb4d9

3 files changed

Lines changed: 37 additions & 2 deletions

File tree

src/asgi/io.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -434,13 +434,12 @@ impl ASGIWebsocketProtocol {
434434
#[inline(always)]
435435
fn close<'p>(&self, py: Python<'p>, frame: Option<wsframe::CloseFrame>) -> PyResult<Bound<'p, PyAny>> {
436436
let closed = self.closed.clone();
437-
let ws_rx = self.ws_rx.clone();
438437
let ws_tx = self.ws_tx.clone();
439438

440439
future_into_py_futlike(self.rt.clone(), py, async move {
441440
if let Some(tx) = ws_tx.lock().await.take() {
442441
closed.store(true, atomic::Ordering::Release);
443-
WebsocketDetachedTransport::new(true, ws_rx.lock().await.take(), Some(tx), frame)
442+
WebsocketDetachedTransport::new(true, None, Some(tx), frame)
444443
.close()
445444
.await;
446445
}

tests/apps/asgi.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import json
33
import pathlib
4+
import tempfile
45

56
import sniffio
67

@@ -109,6 +110,15 @@ async def ws_echo(scope, receive, send):
109110
await send({'type': 'websocket.close'})
110111

111112

113+
async def ws_server_close(scope, receive, send):
114+
await receive()
115+
await send({'type': 'websocket.accept'})
116+
await receive()
117+
await send({'type': 'websocket.close', 'code': 1000})
118+
await asyncio.wait_for(receive(), timeout=5)
119+
pathlib.Path(tempfile.gettempdir(), 'granian_ws_test_result').touch()
120+
121+
112122
async def ws_push(scope, receive, send):
113123
await send({'type': 'websocket.accept'})
114124

@@ -188,6 +198,7 @@ def app(scope, receive, send):
188198
'/ws_reject': ws_reject,
189199
'/ws_info': ws_info,
190200
'/ws_echo': ws_echo,
201+
'/ws_server_close': ws_server_close,
191202
'/ws_push': ws_push,
192203
'/err_app': err_app,
193204
'/err_proto/type': err_proto_msg,

tests/test_ws.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
import asyncio
12
import json
23
import os
4+
import pathlib
5+
import tempfile
36

47
import pytest
58
import websockets
@@ -32,6 +35,28 @@ async def test_reject(server, runtime_mode):
3235
assert exc.value.response.status_code == 403
3336

3437

38+
@pytest.mark.asyncio
39+
@pytest.mark.parametrize('runtime_mode', ['mt', 'st'])
40+
async def test_server_initiated_close(asgi_server, runtime_mode):
41+
result_path = pathlib.Path(tempfile.gettempdir(), 'granian_ws_test_result')
42+
result_path.unlink(missing_ok=True)
43+
44+
async with asgi_server(runtime_mode) as port:
45+
ws = await websockets.connect(f'ws://localhost:{port}/ws_server_close')
46+
await ws.send('hello')
47+
for _ in range(50):
48+
if result_path.exists():
49+
break
50+
await asyncio.sleep(0.1)
51+
try:
52+
await ws.close()
53+
except Exception:
54+
pass
55+
56+
assert result_path.exists(), 'Server did not write result file'
57+
result_path.unlink(missing_ok=True)
58+
59+
3560
@pytest.mark.asyncio
3661
@pytest.mark.skipif(bool(os.getenv('PGO_RUN')), reason='PGO build')
3762
@pytest.mark.parametrize('runtime_mode', ['mt', 'st'])

0 commit comments

Comments
 (0)