Skip to content

Commit 4e49360

Browse files
feat: extend subscription with middleware
1 parent 0f0b033 commit 4e49360

16 files changed

Lines changed: 291 additions & 33 deletions

File tree

ariadne/asgi/handlers/base.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from abc import ABC, abstractmethod
22
from inspect import isawaitable
33
from logging import Logger, LoggerAdapter
4-
from typing import Any
4+
from typing import Any, cast
55

66
from graphql import DocumentNode, ExecutionContext, GraphQLSchema, MiddlewareManager
77
from starlette.types import Receive, Scope, Send
@@ -11,7 +11,11 @@
1111
from ...types import (
1212
ContextValue,
1313
ErrorFormatter,
14+
ExtensionList,
15+
Extensions,
1416
GraphQLResult,
17+
MiddlewareList,
18+
Middlewares,
1519
OnComplete,
1620
OnConnect,
1721
OnDisconnect,
@@ -168,6 +172,9 @@ def __init__(
168172
on_disconnect: OnDisconnect | None = None,
169173
on_operation: OnOperation | None = None,
170174
on_complete: OnComplete | None = None,
175+
extensions: Extensions | None = None,
176+
middleware: Middlewares | None = None,
177+
middleware_manager_class: type[MiddlewareManager] | None = None,
171178
) -> None:
172179
"""Initialize websocket handler with optional options specific to it.
173180
@@ -183,6 +190,19 @@ def __init__(
183190
184191
`on_complete`: an `OnComplete` callback, used when GraphQL operation
185192
received over the websocket connection was completed.
193+
194+
`extensions`: an `Extensions` list or callable returning a
195+
list of extensions server should use during subscription execution.
196+
Defaults to no extensions.
197+
198+
`middleware`: a `Middlewares` list or callable returning a list of
199+
middlewares server should use during subscription execution. Defaults
200+
to no middlewares.
201+
202+
`middleware_manager_class`: a `MiddlewareManager` type or subclass to
203+
use for combining provided middlewares into single wrapper for resolvers
204+
by the server. Defaults to `graphql.MiddlewareManager`. Is only used
205+
if `extensions` or `middleware` options are set.
186206
"""
187207
super().__init__()
188208
self.http_handler: GraphQLHttpHandlerBase | None = None
@@ -191,11 +211,57 @@ def __init__(
191211
self.on_disconnect: OnDisconnect | None = on_disconnect
192212
self.on_operation: OnOperation | None = on_operation
193213
self.on_complete: OnComplete | None = on_complete
214+
self.extensions = extensions
215+
self.middleware = middleware
216+
if middleware_manager_class is not None:
217+
self.middleware_manager_class = middleware_manager_class
194218

195219
@abstractmethod
196220
async def handle_websocket(self, websocket: Any):
197221
"""Abstract method for handling the websocket connection."""
198222

223+
async def get_extensions_for_request(
224+
self, request: Any, context: ContextValue | None
225+
) -> ExtensionList:
226+
"""Returns extensions to use when handling the GraphQL request.
227+
228+
Returns `ExtensionList`, a list of extensions to use or `None`.
229+
230+
# Required arguments
231+
232+
`request`: the `WebSocket` instance from Starlette or FastAPI.
233+
234+
`context`: a `ContextValue` for this request.
235+
"""
236+
if callable(self.extensions):
237+
extensions = self.extensions(request, context) # ty: ignore
238+
if isawaitable(extensions):
239+
extensions = await extensions
240+
return cast(ExtensionList, extensions)
241+
return self.extensions
242+
243+
async def get_middleware_for_request(
244+
self, request: Any, context: ContextValue | None
245+
) -> MiddlewareList:
246+
"""Returns GraphQL middlewares to use when handling the GraphQL request.
247+
248+
Returns `MiddlewareList`, a list of middlewares to use or `None`.
249+
250+
# Required arguments
251+
252+
`request`: the `WebSocket` instance from Starlette or FastAPI.
253+
254+
`context`: a `ContextValue` for this request.
255+
"""
256+
middleware = self.middleware
257+
if callable(middleware):
258+
middleware = middleware(request, context) # ty: ignore
259+
if isawaitable(middleware):
260+
middleware = await middleware
261+
if middleware:
262+
return cast(MiddlewareList, middleware)
263+
return None
264+
199265
def configure(
200266
self,
201267
*args,

ariadne/asgi/handlers/graphql_transport_ws.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,9 @@ async def handle_websocket_subscribe(
339339
"schema is not set, call configure method to initialize it"
340340
)
341341

342+
extensions = await self.get_extensions_for_request(websocket, context_value)
343+
middleware = await self.get_middleware_for_request(websocket, context_value)
344+
342345
success, results_producer = await subscribe(
343346
self.schema,
344347
data,
@@ -351,6 +354,9 @@ async def handle_websocket_subscribe(
351354
introspection=self.introspection,
352355
logger=self.logger,
353356
error_formatter=self.error_formatter,
357+
extensions=extensions,
358+
middleware=middleware,
359+
middleware_manager_class=self.middleware_manager_class,
354360
)
355361
else:
356362
if self.http_handler is None:

ariadne/asgi/handlers/graphql_ws.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,9 @@ async def start_websocket_operation(
296296
if self.schema is None:
297297
raise TypeError("schema is not set, call configure method to initialize it")
298298

299+
extensions = await self.get_extensions_for_request(websocket, context_value)
300+
middleware = await self.get_middleware_for_request(websocket, context_value)
301+
299302
success, results = await subscribe(
300303
self.schema,
301304
data,
@@ -308,6 +311,9 @@ async def start_websocket_operation(
308311
introspection=self.introspection,
309312
logger=self.logger,
310313
error_formatter=self.error_formatter,
314+
extensions=extensions,
315+
middleware=middleware,
316+
middleware_manager_class=self.middleware_manager_class,
311317
)
312318

313319
if not success:

ariadne/asgi/handlers/http.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,12 @@ async def graphql_http_server(self, request: Request) -> Response:
205205
for handler in self.subscription_handlers:
206206
if handler.supports(request, data):
207207
context_value = await self.get_context_for_request(request, data)
208+
extensions = await self.get_extensions_for_request(
209+
request, context_value
210+
)
211+
middleware = await self.get_middleware_for_request(
212+
request, context_value
213+
)
208214
return await handler.handle(
209215
request,
210216
data,
@@ -218,6 +224,9 @@ async def graphql_http_server(self, request: Request) -> Response:
218224
introspection=self.introspection,
219225
logger=self.logger,
220226
error_formatter=self.error_formatter,
227+
extensions=extensions,
228+
middleware=middleware,
229+
middleware_manager_class=self.middleware_manager_class,
221230
)
222231

223232
success, result = await self.execute_graphql_query(request, data)

ariadne/contrib/federation/utils.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
from inspect import isawaitable
23
from typing import Any, cast
34

@@ -49,29 +50,40 @@ def _purge_directive_nodes(nodes: tuple[Node, ...]) -> tuple[Node, ...]:
4950
)
5051

5152

52-
def _purge_type_directives(definition: Node):
53-
# Recursively check every field defined on the Node definition
54-
# and remove any directives found.
53+
def _purge_type_directives(definition: Node) -> Node:
54+
"""Recursively check every field defined on the Node definition
55+
and remove any directives found. Returns a new node with purged directives."""
56+
changes: dict[str, Any] = {}
5557
for key in definition.keys:
5658
value = getattr(definition, key, None)
5759
if isinstance(value, tuple):
5860
# Remove directive nodes from the tuple
5961
# e.g. doc -> definitions [DirectiveDefinitionNode]
6062
next_value = _purge_directive_nodes(cast(tuple[Node, ...], value))
61-
for item in next_value:
62-
if isinstance(item, Node):
63-
# Look for directive nodes on sub-nodes, e.g.: doc ->
64-
# definitions [ObjectTypeDefinitionNode] -> fields -> directives
65-
_purge_type_directives(item)
66-
setattr(definition, key, next_value)
63+
# Look for directive nodes on sub-nodes, e.g.: doc ->
64+
# definitions [ObjectTypeDefinitionNode] -> fields -> directives
65+
next_value = tuple(
66+
_purge_type_directives(item) if isinstance(item, Node) else item
67+
for item in next_value
68+
)
69+
if next_value != value:
70+
changes[key] = next_value
6771
elif isinstance(value, Node):
68-
_purge_type_directives(value)
72+
new_value = _purge_type_directives(value)
73+
if new_value is not value:
74+
changes[key] = new_value
75+
if changes:
76+
new_node = copy.copy(definition)
77+
for key, value in changes.items():
78+
object.__setattr__(new_node, key, value)
79+
return new_node
80+
return definition
6981

7082

7183
def purge_schema_directives(joined_type_defs: str) -> str:
7284
"""Remove custom schema directives from federation."""
7385
ast_document = parse(joined_type_defs)
74-
_purge_type_directives(ast_document)
86+
ast_document = _purge_type_directives(ast_document)
7587
return print_ast(ast_document)
7688

7789

ariadne/contrib/sse.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@
4141
from ..subscription_handlers.handlers import SubscriptionHandler
4242
from ..types import (
4343
ErrorFormatter,
44+
ExtensionList,
4445
Extensions,
46+
MiddlewareList,
4547
Middlewares,
4648
QueryParser,
4749
QueryValidator,
@@ -386,6 +388,9 @@ async def handle(
386388
introspection: bool,
387389
logger: None | str | Logger | LoggerAdapter,
388390
error_formatter: ErrorFormatter,
391+
middleware: MiddlewareList = None,
392+
middleware_manager_class: type[MiddlewareManager] | None = None,
393+
extensions: ExtensionList | None = None,
389394
) -> Response:
390395
"""Handle the subscription request via Server-Sent Events.
391396
@@ -405,6 +410,9 @@ async def handle(
405410
introspection=introspection,
406411
logger=logger,
407412
error_formatter=error_formatter,
413+
middleware=middleware,
414+
middleware_manager_class=middleware_manager_class,
415+
extensions=extensions,
408416
),
409417
ping_interval=self.ping_interval,
410418
send_timeout=self.send_timeout,

ariadne/contrib/tracing/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,10 @@ def repr_upload_file(upload_file: UploadFile | File) -> str:
5959

6060
def format_path(path: ResponsePath):
6161
elements = []
62-
while path:
63-
elements.append(path.key)
64-
path = path.prev
62+
current: ResponsePath | None = path
63+
while current:
64+
elements.append(current.key)
65+
current = current.prev
6566
return elements[::-1]
6667

6768

ariadne/graphql.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,9 @@ async def subscribe(
441441
logger: None | str | Logger | LoggerAdapter = None,
442442
validation_rules: ValidationRules | None = None,
443443
error_formatter: ErrorFormatter = format_error,
444+
middleware: MiddlewareList = None,
445+
middleware_manager_class: type[MiddlewareManager] | None = None,
446+
extensions: ExtensionList | None = None,
444447
**kwargs,
445448
) -> SubscriptionResult:
446449
"""Subscribe to GraphQL updates.
@@ -491,9 +494,20 @@ async def subscribe(
491494
`error_formatter`: an `ErrorFormatter` callable to use to convert GraphQL
492495
errors encountered during query execution to JSON-serializable format.
493496
497+
`middleware`: a `list` of or callable returning list of GraphQL middleware
498+
to use by query executor.
499+
500+
`middleware_manager_class`: a `MiddlewareManager` class to use by query
501+
executor.
502+
503+
`extensions`: a `list` of or callable returning list of extensions
504+
to use during query execution.
505+
494506
`**kwargs`: any kwargs not used by `subscribe` are passed to
495507
`graphql.subscribe`.
496508
"""
509+
extension_manager = ExtensionManager(extensions, context_value)
510+
497511
try:
498512
validate_data(data)
499513
variables, operation_name = (
@@ -533,15 +547,21 @@ async def subscribe(
533547
if isawaitable(root_value):
534548
root_value = await root_value
535549

536-
result = await _subscribe(
550+
result = _subscribe(
537551
schema,
538552
document,
539553
root_value=root_value,
540554
context_value=context_value,
541555
variable_values=variables,
542556
operation_name=operation_name,
557+
middleware=extension_manager.as_middleware_manager(
558+
middleware, middleware_manager_class
559+
),
543560
**kwargs,
544561
)
562+
563+
if isawaitable(result):
564+
result = await result
545565
except GraphQLError as error:
546566
log_error(error, logger)
547567
return False, [error_formatter(error, debug)]

ariadne/schema_visitor.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -725,9 +725,8 @@ def heal_type(
725725
# any `GraphQLNamedType` with a `name`, then it must end up identical
726726
# to `schema.get_type(name)`, since `schema.type_map` is the source
727727
# of truth for all named schema types.
728-
named_type = cast(GraphQLNamedType, type_)
729-
official_type = schema.get_type(named_type.name)
730-
if official_type and named_type != official_type:
728+
official_type = schema.get_type(type_.name)
729+
if official_type and type_ != official_type:
731730
return official_type
732731

733732
return type_

ariadne/subscription_handlers/handlers.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,22 @@
55
from logging import Logger, LoggerAdapter
66
from typing import Any, cast
77

8-
from graphql import DocumentNode, ExecutionResult, GraphQLError, GraphQLSchema
8+
from graphql import (
9+
DocumentNode,
10+
ExecutionResult,
11+
GraphQLError,
12+
GraphQLSchema,
13+
MiddlewareManager,
14+
)
915
from starlette.requests import Request
1016
from starlette.responses import Response
1117

1218
from ..graphql import subscribe, validate_data
1319
from ..logger import log_error
1420
from ..types import (
1521
ErrorFormatter,
22+
ExtensionList,
23+
MiddlewareList,
1624
QueryParser,
1725
QueryValidator,
1826
RootValue,
@@ -59,6 +67,9 @@ async def handle(
5967
introspection: bool,
6068
logger: None | str | Logger | LoggerAdapter,
6169
error_formatter: ErrorFormatter,
70+
middleware: MiddlewareList = None,
71+
middleware_manager_class: type[MiddlewareManager] | None = None,
72+
extensions: ExtensionList | None = None,
6273
) -> Response:
6374
"""Handle the subscription request."""
6475

@@ -77,6 +88,9 @@ async def generate_events(
7788
introspection: bool,
7889
logger: None | str | Logger | LoggerAdapter,
7990
error_formatter: ErrorFormatter,
91+
middleware: MiddlewareList = None,
92+
middleware_manager_class: type[MiddlewareManager] | None = None,
93+
extensions: ExtensionList | None = None,
8094
) -> AsyncGenerator[SubscriptionEvent, None]:
8195
"""Execute subscription and yield events.
8296
@@ -105,6 +119,9 @@ async def generate_events(
105119
introspection=introspection,
106120
logger=logger,
107121
error_formatter=error_formatter,
122+
middleware=middleware,
123+
middleware_manager_class=middleware_manager_class,
124+
extensions=extensions,
108125
)
109126

110127
if not success:

0 commit comments

Comments
 (0)