Skip to content

Commit 51c1de1

Browse files
adriangbKludex
andauthored
Lazily build middleware stack (#2017)
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
1 parent ca1711f commit 51c1de1

2 files changed

Lines changed: 52 additions & 16 deletions

File tree

starlette/applications.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __init__(
6565
on_startup is None and on_shutdown is None
6666
), "Use either 'lifespan' or 'on_startup'/'on_shutdown', not both."
6767

68-
self._debug = debug
68+
self.debug = debug
6969
self.state = State()
7070
self.router = Router(
7171
routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan
@@ -74,7 +74,7 @@ def __init__(
7474
{} if exception_handlers is None else dict(exception_handlers)
7575
)
7676
self.user_middleware = [] if middleware is None else list(middleware)
77-
self.middleware_stack = self.build_middleware_stack()
77+
self.middleware_stack: typing.Optional[ASGIApp] = None
7878

7979
def build_middleware_stack(self) -> ASGIApp:
8080
debug = self.debug
@@ -108,20 +108,13 @@ def build_middleware_stack(self) -> ASGIApp:
108108
def routes(self) -> typing.List[BaseRoute]:
109109
return self.router.routes
110110

111-
@property
112-
def debug(self) -> bool:
113-
return self._debug
114-
115-
@debug.setter
116-
def debug(self, value: bool) -> None:
117-
self._debug = value
118-
self.middleware_stack = self.build_middleware_stack()
119-
120111
def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath:
121112
return self.router.url_path_for(name, **path_params)
122113

123114
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
124115
scope["app"] = self
116+
if self.middleware_stack is None:
117+
self.middleware_stack = self.build_middleware_stack()
125118
await self.middleware_stack(scope, receive, send)
126119

127120
def on_event(self, event_type: str) -> typing.Callable: # pragma: nocover
@@ -137,19 +130,17 @@ def host(
137130
) -> None: # pragma: no cover
138131
self.router.host(host, app=app, name=name)
139132

140-
def add_middleware(
141-
self, middleware_class: type, **options: typing.Any
142-
) -> None: # pragma: no cover
133+
def add_middleware(self, middleware_class: type, **options: typing.Any) -> None:
134+
if self.middleware_stack is not None: # pragma: no cover
135+
raise RuntimeError("Cannot add middleware after an application has started")
143136
self.user_middleware.insert(0, Middleware(middleware_class, **options))
144-
self.middleware_stack = self.build_middleware_stack()
145137

146138
def add_exception_handler(
147139
self,
148140
exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
149141
handler: typing.Callable,
150142
) -> None: # pragma: no cover
151143
self.exception_handlers[exc_class_or_status_code] = handler
152-
self.middleware_stack = self.build_middleware_stack()
153144

154145
def add_event_handler(
155146
self, event_type: str, func: typing.Callable

tests/test_applications.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import os
22
from contextlib import asynccontextmanager
3+
from typing import Any, Callable
34

45
import anyio
6+
import httpx
57
import pytest
68

79
from starlette import status
@@ -13,6 +15,7 @@
1315
from starlette.responses import JSONResponse, PlainTextResponse
1416
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
1517
from starlette.staticfiles import StaticFiles
18+
from starlette.types import ASGIApp
1619
from starlette.websockets import WebSocket
1720

1821

@@ -486,3 +489,45 @@ async def startup():
486489

487490
app.on_event("startup")(startup)
488491
assert len(record) == 1
492+
493+
494+
def test_middleware_stack_init(test_client_factory: Callable[[ASGIApp], httpx.Client]):
495+
class NoOpMiddleware:
496+
def __init__(self, app: ASGIApp):
497+
self.app = app
498+
499+
async def __call__(self, *args: Any):
500+
await self.app(*args)
501+
502+
class SimpleInitializableMiddleware:
503+
counter = 0
504+
505+
def __init__(self, app: ASGIApp):
506+
self.app = app
507+
SimpleInitializableMiddleware.counter += 1
508+
509+
async def __call__(self, *args: Any):
510+
await self.app(*args)
511+
512+
def get_app() -> ASGIApp:
513+
app = Starlette()
514+
app.add_middleware(SimpleInitializableMiddleware)
515+
app.add_middleware(NoOpMiddleware)
516+
return app
517+
518+
app = get_app()
519+
520+
with test_client_factory(app):
521+
pass
522+
523+
assert SimpleInitializableMiddleware.counter == 1
524+
525+
test_client_factory(app).get("/foo")
526+
527+
assert SimpleInitializableMiddleware.counter == 1
528+
529+
app = get_app()
530+
531+
test_client_factory(app).get("/foo")
532+
533+
assert SimpleInitializableMiddleware.counter == 2

0 commit comments

Comments
 (0)