Skip to content

Commit 49093b0

Browse files
committed
refactor(ws): Clean up websocket listener DI handling
Clean up some rough edges in the websocket listener DI handling
1 parent f45503f commit 49093b0

2 files changed

Lines changed: 12 additions & 7 deletions

File tree

litestar/handlers/websocket_handlers/_utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,17 @@ def create_handler_signature(callback_signature: Signature) -> Signature:
141141
return callback_signature.replace(parameters=new_params)
142142

143143

144-
def create_stub_dependency(src: AnyCallable) -> Provide:
144+
def create_stub_dependency(src: AnyCallable | None) -> Provide:
145145
"""Create a stub dependency, accepting any kwargs defined in ``src``, and
146146
wrap it in ``Provide``
147147
"""
148+
if src is None:
149+
150+
async def empty_stub() -> None:
151+
return None
152+
153+
return Provide(empty_stub)
154+
148155
src = unwrap_partial(src)
149156

150157
@wraps(src)

litestar/handlers/websocket_handlers/listener.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -203,11 +203,9 @@ def __init__(
203203
connection_lifespan or self.default_connection_lifespan
204204
)
205205

206-
if self.on_accept:
207-
listener_dependencies["on_accept_dependencies"] = create_stub_dependency(self.on_accept)
206+
listener_dependencies["on_accept_dependencies"] = create_stub_dependency(self.on_accept)
208207

209-
if self.on_disconnect:
210-
listener_dependencies["on_disconnect_dependencies"] = create_stub_dependency(self.on_disconnect)
208+
listener_dependencies["on_disconnect_dependencies"] = create_stub_dependency(self.on_disconnect)
211209

212210
super().__init__(
213211
path=path,
@@ -376,8 +374,8 @@ def __init__(self, owner: Router) -> None:
376374
self._owner = owner
377375

378376
def to_handler(self) -> WebsocketListenerRouteHandler:
379-
on_accept = self.on_accept if self.on_accept != WebsocketListener.on_accept else None
380-
on_disconnect = self.on_disconnect if self.on_disconnect != WebsocketListener.on_disconnect else None
377+
on_accept = self.on_accept if type(self).on_accept != WebsocketListener.on_accept else None
378+
on_disconnect = self.on_disconnect if type(self).on_disconnect != WebsocketListener.on_disconnect else None
381379

382380
handler = WebsocketListenerRouteHandler(
383381
dependencies=self.dependencies,

0 commit comments

Comments
 (0)