Skip to content

Commit ca1711f

Browse files
authored
Support Debug extension (#1991)
1 parent 3697c8d commit ca1711f

4 files changed

Lines changed: 64 additions & 11 deletions

File tree

starlette/middleware/base.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22

33
import anyio
44

5+
from starlette.background import BackgroundTask
56
from starlette.requests import Request
6-
from starlette.responses import Response, StreamingResponse
7+
from starlette.responses import ContentStream, Response, StreamingResponse
78
from starlette.types import ASGIApp, Message, Receive, Scope, Send
89

910
RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
@@ -75,6 +76,9 @@ async def coro() -> None:
7576

7677
try:
7778
message = await recv_stream.receive()
79+
info = message.get("info", None)
80+
if message["type"] == "http.response.debug" and info is not None:
81+
message = await recv_stream.receive()
7882
except anyio.EndOfStream:
7983
if app_exc is not None:
8084
raise app_exc
@@ -93,8 +97,8 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
9397
if app_exc is not None:
9498
raise app_exc
9599

96-
response = StreamingResponse(
97-
status_code=message["status"], content=body_stream()
100+
response = _StreamingResponse(
101+
status_code=message["status"], content=body_stream(), info=info
98102
)
99103
response.raw_headers = message["headers"]
100104
return response
@@ -109,3 +113,22 @@ async def dispatch(
109113
self, request: Request, call_next: RequestResponseEndpoint
110114
) -> Response:
111115
raise NotImplementedError() # pragma: no cover
116+
117+
118+
class _StreamingResponse(StreamingResponse):
119+
def __init__(
120+
self,
121+
content: ContentStream,
122+
status_code: int = 200,
123+
headers: typing.Optional[typing.Mapping[str, str]] = None,
124+
media_type: typing.Optional[str] = None,
125+
background: typing.Optional[BackgroundTask] = None,
126+
info: typing.Optional[typing.Mapping[str, typing.Any]] = None,
127+
) -> None:
128+
self._info = info
129+
super().__init__(content, status_code, headers, media_type, background)
130+
131+
async def stream_response(self, send: Send) -> None:
132+
if self._info:
133+
await send({"type": "http.response.debug", "info": self._info})
134+
return await super().stream_response(send)

starlette/templating.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,14 @@ def __init__(
4141
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
4242
request = self.context.get("request", {})
4343
extensions = request.get("extensions", {})
44-
if "http.response.template" in extensions:
44+
if "http.response.debug" in extensions:
4545
await send(
4646
{
47-
"type": "http.response.template",
48-
"template": self.template,
49-
"context": self.context,
47+
"type": "http.response.debug",
48+
"info": {
49+
"template": self.template,
50+
"context": self.context,
51+
},
5052
}
5153
)
5254
await super().__call__(scope, receive, send)

starlette/testclient.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
259259
"headers": headers,
260260
"client": ["testclient", 50000],
261261
"server": [host, port],
262-
"extensions": {"http.response.template": {}},
262+
"extensions": {"http.response.debug": {}},
263263
}
264264

265265
request_complete = False
@@ -324,9 +324,9 @@ async def send(message: Message) -> None:
324324
if not more_body:
325325
raw_kwargs["stream"].seek(0)
326326
response_complete.set()
327-
elif message["type"] == "http.response.template":
328-
template = message["template"]
329-
context = message["context"]
327+
elif message["type"] == "http.response.debug":
328+
template = message["info"]["template"]
329+
context = message["info"]["context"]
330330

331331
try:
332332
with self.portal_factory() as portal:

tests/test_templates.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import pytest
44

55
from starlette.applications import Starlette
6+
from starlette.middleware import Middleware
7+
from starlette.middleware.base import BaseHTTPMiddleware
68
from starlette.routing import Route
79
from starlette.templating import Jinja2Templates
810

@@ -60,3 +62,29 @@ def hello_world_processor(request):
6062
assert response.text == "<html>Hello World</html>"
6163
assert response.template.name == "index.html"
6264
assert set(response.context.keys()) == {"request", "username"}
65+
66+
67+
def test_template_with_middleware(tmpdir, test_client_factory):
68+
path = os.path.join(tmpdir, "index.html")
69+
with open(path, "w") as file:
70+
file.write("<html>Hello, <a href='{{ url_for('homepage') }}'>world</a></html>")
71+
72+
async def homepage(request):
73+
return templates.TemplateResponse("index.html", {"request": request})
74+
75+
class CustomMiddleware(BaseHTTPMiddleware):
76+
async def dispatch(self, request, call_next):
77+
return await call_next(request)
78+
79+
app = Starlette(
80+
debug=True,
81+
routes=[Route("/", endpoint=homepage)],
82+
middleware=[Middleware(CustomMiddleware)],
83+
)
84+
templates = Jinja2Templates(directory=str(tmpdir))
85+
86+
client = test_client_factory(app)
87+
response = client.get("/")
88+
assert response.text == "<html>Hello, <a href='http://testserver/'>world</a></html>"
89+
assert response.template.name == "index.html"
90+
assert set(response.context.keys()) == {"request"}

0 commit comments

Comments
 (0)