diff --git a/src/asgi/io.rs b/src/asgi/io.rs index a7068f41..775c22fd 100644 --- a/src/asgi/io.rs +++ b/src/asgi/io.rs @@ -218,31 +218,35 @@ impl ASGIHTTPProtocol { more, self.response_chunked.load(atomic::Ordering::Relaxed), ) { - (true, false, false) => { - let (status, headers) = self.response_intent.lock().unwrap().take().unwrap(); - self.send_response( - status, - headers, - http_body_util::Full::new(body::Bytes::from(body)) - .map_err(std::convert::Into::into) - .boxed(), - ); - self.flow_tx_waiter.notify_one(); - empty_future_into_py(py) - } - (true, true, false) => { - self.response_chunked.store(true, atomic::Ordering::Relaxed); - let (status, headers) = self.response_intent.lock().unwrap().take().unwrap(); - let (body_tx, body_rx) = mpsc::unbounded_channel::(); - let body_stream = http_body_util::StreamBody::new( - tokio_stream::wrappers::UnboundedReceiverStream::new(body_rx) - .map(body::Frame::data) - .map(Result::Ok), - ); - *self.body_tx.lock().unwrap() = Some(body_tx.clone()); - self.send_response(status, headers, BodyExt::boxed(body_stream)); - self.send_body(py, &body_tx, body, false) - } + (true, false, false) => match self.response_intent.lock().unwrap().take() { + Some((status, headers)) => { + self.send_response( + status, + headers, + http_body_util::Full::new(body::Bytes::from(body)) + .map_err(std::convert::Into::into) + .boxed(), + ); + self.flow_tx_waiter.notify_one(); + empty_future_into_py(py) + } + _ => error_flow!("Response already finished"), + }, + (true, true, false) => match self.response_intent.lock().unwrap().take() { + Some((status, headers)) => { + self.response_chunked.store(true, atomic::Ordering::Relaxed); + let (body_tx, body_rx) = mpsc::unbounded_channel::(); + let body_stream = http_body_util::StreamBody::new( + tokio_stream::wrappers::UnboundedReceiverStream::new(body_rx) + .map(body::Frame::data) + .map(Result::Ok), + ); + *self.body_tx.lock().unwrap() = Some(body_tx.clone()); + self.send_response(status, headers, BodyExt::boxed(body_stream)); + self.send_body(py, &body_tx, body, false) + } + _ => error_flow!("Response already finished"), + }, (true, true, true) => match &*self.body_tx.lock().unwrap() { Some(tx) => self.send_body(py, tx, body, false), _ => error_flow!("Transport not initialized or closed"), diff --git a/tests/apps/asgi.py b/tests/apps/asgi.py index a62ab93c..82068bfb 100644 --- a/tests/apps/asgi.py +++ b/tests/apps/asgi.py @@ -123,8 +123,26 @@ async def err_app(scope, receive, send): 1 / 0 -async def err_proto(scope, receive, send): - await send({'type': 'wrong.msg'}) +async def err_proto_msg(scope, receive, send): + await send(PLAINTEXT_RESPONSE) + try: + await send({'type': 'wrong.msg'}) + except Exception as e: + msg = e.args[0] + await send({'type': 'http.response.body', 'body': msg.encode('utf8'), 'more_body': False}) + + +async def err_proto_flow(scope, receive, send): + await send(PLAINTEXT_RESPONSE) + await send({'type': 'http.response.body', 'body': b'msg1', 'more_body': False}) + try: + await send({'type': 'http.response.body', 'body': b'msg2', 'more_body': True}) + except Exception: + pass + try: + await send({'type': 'http.response.body', 'body': b'msg3', 'more_body': False}) + except Exception: + pass async def timeout_n(scope, receive, send): @@ -172,7 +190,8 @@ def app(scope, receive, send): '/ws_echo': ws_echo, '/ws_push': ws_push, '/err_app': err_app, - '/err_proto': err_proto, + '/err_proto/type': err_proto_msg, + '/err_proto/flow': err_proto_flow, '/timeout_n': timeout_n, '/timeout_w': timeout_w, }.get(scope['path'], info)(scope, receive, send) diff --git a/tests/test_asgi.py b/tests/test_asgi.py index f8a76e0d..2489d1d7 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -75,9 +75,13 @@ async def test_app_error(asgi_server, runtime_mode): @pytest.mark.parametrize('runtime_mode', ['mt', 'st']) async def test_protocol_error(asgi_server, runtime_mode): async with asgi_server(runtime_mode, ws=False) as port: - res = httpx.get(f'http://localhost:{port}/err_proto') + res = httpx.get(f'http://localhost:{port}/err_proto/type') + assert res.status_code == 200 + assert res.text == 'Unsupported ASGI message' - assert res.status_code == 500 + res = httpx.get(f'http://localhost:{port}/err_proto/flow') + assert res.status_code == 200 + assert res.text == 'msg1' @pytest.mark.asyncio