Skip to content

Commit 3f89c46

Browse files
committed
Avoid panics on exceeding ASGI messages (#707)
1 parent 8239129 commit 3f89c46

3 files changed

Lines changed: 57 additions & 30 deletions

File tree

src/asgi/io.rs

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -218,31 +218,35 @@ impl ASGIHTTPProtocol {
218218
more,
219219
self.response_chunked.load(atomic::Ordering::Relaxed),
220220
) {
221-
(true, false, false) => {
222-
let (status, headers) = self.response_intent.lock().unwrap().take().unwrap();
223-
self.send_response(
224-
status,
225-
headers,
226-
http_body_util::Full::new(body::Bytes::from(body))
227-
.map_err(std::convert::Into::into)
228-
.boxed(),
229-
);
230-
self.flow_tx_waiter.notify_one();
231-
empty_future_into_py(py)
232-
}
233-
(true, true, false) => {
234-
self.response_chunked.store(true, atomic::Ordering::Relaxed);
235-
let (status, headers) = self.response_intent.lock().unwrap().take().unwrap();
236-
let (body_tx, body_rx) = mpsc::unbounded_channel::<body::Bytes>();
237-
let body_stream = http_body_util::StreamBody::new(
238-
tokio_stream::wrappers::UnboundedReceiverStream::new(body_rx)
239-
.map(body::Frame::data)
240-
.map(Result::Ok),
241-
);
242-
*self.body_tx.lock().unwrap() = Some(body_tx.clone());
243-
self.send_response(status, headers, BodyExt::boxed(body_stream));
244-
self.send_body(py, &body_tx, body, false)
245-
}
221+
(true, false, false) => match self.response_intent.lock().unwrap().take() {
222+
Some((status, headers)) => {
223+
self.send_response(
224+
status,
225+
headers,
226+
http_body_util::Full::new(body::Bytes::from(body))
227+
.map_err(std::convert::Into::into)
228+
.boxed(),
229+
);
230+
self.flow_tx_waiter.notify_one();
231+
empty_future_into_py(py)
232+
}
233+
_ => error_flow!("Response already finished"),
234+
},
235+
(true, true, false) => match self.response_intent.lock().unwrap().take() {
236+
Some((status, headers)) => {
237+
self.response_chunked.store(true, atomic::Ordering::Relaxed);
238+
let (body_tx, body_rx) = mpsc::unbounded_channel::<body::Bytes>();
239+
let body_stream = http_body_util::StreamBody::new(
240+
tokio_stream::wrappers::UnboundedReceiverStream::new(body_rx)
241+
.map(body::Frame::data)
242+
.map(Result::Ok),
243+
);
244+
*self.body_tx.lock().unwrap() = Some(body_tx.clone());
245+
self.send_response(status, headers, BodyExt::boxed(body_stream));
246+
self.send_body(py, &body_tx, body, false)
247+
}
248+
_ => error_flow!("Response already finished"),
249+
},
246250
(true, true, true) => match &*self.body_tx.lock().unwrap() {
247251
Some(tx) => self.send_body(py, tx, body, false),
248252
_ => error_flow!("Transport not initialized or closed"),

tests/apps/asgi.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,26 @@ async def err_app(scope, receive, send):
123123
1 / 0
124124

125125

126-
async def err_proto(scope, receive, send):
127-
await send({'type': 'wrong.msg'})
126+
async def err_proto_msg(scope, receive, send):
127+
await send(PLAINTEXT_RESPONSE)
128+
try:
129+
await send({'type': 'wrong.msg'})
130+
except Exception as e:
131+
msg = e.args[0]
132+
await send({'type': 'http.response.body', 'body': msg.encode('utf8'), 'more_body': False})
133+
134+
135+
async def err_proto_flow(scope, receive, send):
136+
await send(PLAINTEXT_RESPONSE)
137+
await send({'type': 'http.response.body', 'body': b'msg1', 'more_body': False})
138+
try:
139+
await send({'type': 'http.response.body', 'body': b'msg2', 'more_body': True})
140+
except Exception:
141+
pass
142+
try:
143+
await send({'type': 'http.response.body', 'body': b'msg3', 'more_body': False})
144+
except Exception:
145+
pass
128146

129147

130148
async def timeout_n(scope, receive, send):
@@ -172,7 +190,8 @@ def app(scope, receive, send):
172190
'/ws_echo': ws_echo,
173191
'/ws_push': ws_push,
174192
'/err_app': err_app,
175-
'/err_proto': err_proto,
193+
'/err_proto/type': err_proto_msg,
194+
'/err_proto/flow': err_proto_flow,
176195
'/timeout_n': timeout_n,
177196
'/timeout_w': timeout_w,
178197
}.get(scope['path'], info)(scope, receive, send)

tests/test_asgi.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,13 @@ async def test_app_error(asgi_server, runtime_mode):
7575
@pytest.mark.parametrize('runtime_mode', ['mt', 'st'])
7676
async def test_protocol_error(asgi_server, runtime_mode):
7777
async with asgi_server(runtime_mode, ws=False) as port:
78-
res = httpx.get(f'http://localhost:{port}/err_proto')
78+
res = httpx.get(f'http://localhost:{port}/err_proto/type')
79+
assert res.status_code == 200
80+
assert res.text == 'Unsupported ASGI message'
7981

80-
assert res.status_code == 500
82+
res = httpx.get(f'http://localhost:{port}/err_proto/flow')
83+
assert res.status_code == 200
84+
assert res.text == 'msg1'
8185

8286

8387
@pytest.mark.asyncio

0 commit comments

Comments
 (0)