11from abc import ABC , abstractmethod
22from inspect import isawaitable
33from logging import Logger , LoggerAdapter
4- from typing import Any
4+ from typing import Any , cast
55
66from graphql import DocumentNode , ExecutionContext , GraphQLSchema , MiddlewareManager
77from starlette .types import Receive , Scope , Send
1111from ...types import (
1212 ContextValue ,
1313 ErrorFormatter ,
14+ ExtensionList ,
15+ Extensions ,
1416 GraphQLResult ,
17+ MiddlewareList ,
18+ Middlewares ,
1519 OnComplete ,
1620 OnConnect ,
1721 OnDisconnect ,
@@ -168,6 +172,9 @@ def __init__(
168172 on_disconnect : OnDisconnect | None = None ,
169173 on_operation : OnOperation | None = None ,
170174 on_complete : OnComplete | None = None ,
175+ extensions : Extensions | None = None ,
176+ middleware : Middlewares | None = None ,
177+ middleware_manager_class : type [MiddlewareManager ] | None = None ,
171178 ) -> None :
172179 """Initialize websocket handler with optional options specific to it.
173180
@@ -183,6 +190,19 @@ def __init__(
183190
184191 `on_complete`: an `OnComplete` callback, used when GraphQL operation
185192 received over the websocket connection was completed.
193+
194+ `extensions`: an `Extensions` list or callable returning a
195+ list of extensions server should use during subscription execution.
196+ Defaults to no extensions.
197+
198+ `middleware`: a `Middlewares` list or callable returning a list of
199+ middlewares server should use during subscription execution. Defaults
200+ to no middlewares.
201+
202+ `middleware_manager_class`: a `MiddlewareManager` type or subclass to
203+ use for combining provided middlewares into single wrapper for resolvers
204+ by the server. Defaults to `graphql.MiddlewareManager`. Is only used
205+ if `extensions` or `middleware` options are set.
186206 """
187207 super ().__init__ ()
188208 self .http_handler : GraphQLHttpHandlerBase | None = None
@@ -191,11 +211,57 @@ def __init__(
191211 self .on_disconnect : OnDisconnect | None = on_disconnect
192212 self .on_operation : OnOperation | None = on_operation
193213 self .on_complete : OnComplete | None = on_complete
214+ self .extensions = extensions
215+ self .middleware = middleware
216+ if middleware_manager_class is not None :
217+ self .middleware_manager_class = middleware_manager_class
194218
195219 @abstractmethod
196220 async def handle_websocket (self , websocket : Any ):
197221 """Abstract method for handling the websocket connection."""
198222
223+ async def get_extensions_for_request (
224+ self , request : Any , context : ContextValue | None
225+ ) -> ExtensionList :
226+ """Returns extensions to use when handling the GraphQL request.
227+
228+ Returns `ExtensionList`, a list of extensions to use or `None`.
229+
230+ # Required arguments
231+
232+ `request`: the `WebSocket` instance from Starlette or FastAPI.
233+
234+ `context`: a `ContextValue` for this request.
235+ """
236+ if callable (self .extensions ):
237+ extensions = self .extensions (request , context ) # ty: ignore
238+ if isawaitable (extensions ):
239+ extensions = await extensions
240+ return cast (ExtensionList , extensions )
241+ return self .extensions
242+
243+ async def get_middleware_for_request (
244+ self , request : Any , context : ContextValue | None
245+ ) -> MiddlewareList :
246+ """Returns GraphQL middlewares to use when handling the GraphQL request.
247+
248+ Returns `MiddlewareList`, a list of middlewares to use or `None`.
249+
250+ # Required arguments
251+
252+ `request`: the `WebSocket` instance from Starlette or FastAPI.
253+
254+ `context`: a `ContextValue` for this request.
255+ """
256+ middleware = self .middleware
257+ if callable (middleware ):
258+ middleware = middleware (request , context ) # ty: ignore
259+ if isawaitable (middleware ):
260+ middleware = await middleware
261+ if middleware :
262+ return cast (MiddlewareList , middleware )
263+ return None
264+
199265 def configure (
200266 self ,
201267 * args ,
0 commit comments