Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 29 additions & 25 deletions src/asgi/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,31 +218,35 @@ impl ASGIHTTPProtocol {
more,
self.response_chunked.load(atomic::Ordering::Relaxed),
) {
(true, false, false) => {
let (status, headers) = self.response_intent.lock().unwrap().take().unwrap();
self.send_response(
status,
headers,
http_body_util::Full::new(body::Bytes::from(body))
.map_err(std::convert::Into::into)
.boxed(),
);
self.flow_tx_waiter.notify_one();
empty_future_into_py(py)
}
(true, true, false) => {
self.response_chunked.store(true, atomic::Ordering::Relaxed);
let (status, headers) = self.response_intent.lock().unwrap().take().unwrap();
let (body_tx, body_rx) = mpsc::unbounded_channel::<body::Bytes>();
let body_stream = http_body_util::StreamBody::new(
tokio_stream::wrappers::UnboundedReceiverStream::new(body_rx)
.map(body::Frame::data)
.map(Result::Ok),
);
*self.body_tx.lock().unwrap() = Some(body_tx.clone());
self.send_response(status, headers, BodyExt::boxed(body_stream));
self.send_body(py, &body_tx, body, false)
}
(true, false, false) => match self.response_intent.lock().unwrap().take() {
Some((status, headers)) => {
self.send_response(
status,
headers,
http_body_util::Full::new(body::Bytes::from(body))
.map_err(std::convert::Into::into)
.boxed(),
);
self.flow_tx_waiter.notify_one();
empty_future_into_py(py)
}
_ => error_flow!("Response already finished"),
},
(true, true, false) => match self.response_intent.lock().unwrap().take() {
Some((status, headers)) => {
self.response_chunked.store(true, atomic::Ordering::Relaxed);
let (body_tx, body_rx) = mpsc::unbounded_channel::<body::Bytes>();
let body_stream = http_body_util::StreamBody::new(
tokio_stream::wrappers::UnboundedReceiverStream::new(body_rx)
.map(body::Frame::data)
.map(Result::Ok),
);
*self.body_tx.lock().unwrap() = Some(body_tx.clone());
self.send_response(status, headers, BodyExt::boxed(body_stream));
self.send_body(py, &body_tx, body, false)
}
_ => error_flow!("Response already finished"),
},
(true, true, true) => match &*self.body_tx.lock().unwrap() {
Some(tx) => self.send_body(py, tx, body, false),
_ => error_flow!("Transport not initialized or closed"),
Expand Down
25 changes: 22 additions & 3 deletions tests/apps/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,26 @@ async def err_app(scope, receive, send):
1 / 0


async def err_proto(scope, receive, send):
await send({'type': 'wrong.msg'})
async def err_proto_msg(scope, receive, send):
await send(PLAINTEXT_RESPONSE)
try:
await send({'type': 'wrong.msg'})
except Exception as e:
msg = e.args[0]
await send({'type': 'http.response.body', 'body': msg.encode('utf8'), 'more_body': False})


async def err_proto_flow(scope, receive, send):
await send(PLAINTEXT_RESPONSE)
await send({'type': 'http.response.body', 'body': b'msg1', 'more_body': False})
try:
await send({'type': 'http.response.body', 'body': b'msg2', 'more_body': True})
except Exception:
pass
try:
await send({'type': 'http.response.body', 'body': b'msg3', 'more_body': False})
except Exception:
pass


async def timeout_n(scope, receive, send):
Expand Down Expand Up @@ -172,7 +190,8 @@ def app(scope, receive, send):
'/ws_echo': ws_echo,
'/ws_push': ws_push,
'/err_app': err_app,
'/err_proto': err_proto,
'/err_proto/type': err_proto_msg,
'/err_proto/flow': err_proto_flow,
'/timeout_n': timeout_n,
'/timeout_w': timeout_w,
}.get(scope['path'], info)(scope, receive, send)
8 changes: 6 additions & 2 deletions tests/test_asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,13 @@ async def test_app_error(asgi_server, runtime_mode):
@pytest.mark.parametrize('runtime_mode', ['mt', 'st'])
async def test_protocol_error(asgi_server, runtime_mode):
async with asgi_server(runtime_mode, ws=False) as port:
res = httpx.get(f'http://localhost:{port}/err_proto')
res = httpx.get(f'http://localhost:{port}/err_proto/type')
assert res.status_code == 200
assert res.text == 'Unsupported ASGI message'

assert res.status_code == 500
res = httpx.get(f'http://localhost:{port}/err_proto/flow')
assert res.status_code == 200
assert res.text == 'msg1'


@pytest.mark.asyncio
Expand Down