Skip to content

Commit f8e84e4

Browse files
feat: Automated SQLAlchemy integration
1 parent f9a95a3 commit f8e84e4

15 files changed

Lines changed: 917 additions & 12 deletions

File tree

AGENTS.md

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,19 @@ hatch test benchmark --benchmark-storage=file://benchmark/results
5353
just coverage
5454
```
5555

56+
### Running Examples
57+
58+
Use `uv` to run examples from the `examples/` directory. Refer to the docstrings within each example file for specific commands.
59+
Every example should include an example command of running a particular example with uvicorn.
60+
```bash
61+
# Run a basic query example
62+
uv run examples/basic_query_example.py
63+
64+
# Run an ASGI example with uvicorn
65+
uv run --with "uvicorn[standard]" --with ariadne \
66+
uvicorn examples.basic_query_example:app --reload
67+
```
68+
5669
## Code Style Requirements
5770

5871
- **Python 3.10+** with type hints throughout
@@ -96,4 +109,6 @@ Follow [Conventional Commits](https://www.conventionalcommits.org/):
96109
2. Ensure test coverage meets the 90% minimum requirement
97110
3. Format code with `just fmt`
98111
4. Verify type hints with `just types`
99-
5. Write a clear commit message following the conventional commits format
112+
5. Ensure the documentation is up-to-date
113+
6. Write a clear commit message following the conventional commits format
114+

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ Documentation is available [here](https://ariadnegraphql.org).
4242
- Loading schema from `.graphql`, `.gql`, and `.graphqls` files.
4343
- ASGI and WSGI support, with integrations for Django, FastAPI, Flask, and Starlette.
4444
- Opt-in automatic resolvers mapping between `camelCase123` and `snake_case_123`.
45+
- Automated integration with **SQLAlchemy 2.0** for zero-boilerplate resolvers and N+1 prevention.
4546
- [OpenTelemetry](https://opentelemetry.io/) extension for API monitoring.
4647
- Built-in [GraphiQL](https://github.com/graphql/graphiql) explorer for development and testing.
4748
- GraphQL syntax validation via `gql()` helper function.
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
try:
2+
from .dataloaders import LoaderRegistry, SQLAlchemyRelationLoader
3+
from .objects import SQLAlchemyObjectType
4+
from .query import SQLAlchemyQueryType
5+
from .utils import auto_eager_load
6+
except ImportError as ex:
7+
raise ImportError(
8+
"SQLAlchemy integration requires the 'sqlalchemy' and 'aiodataloader' "
9+
"packages. Install them using 'pip install \"ariadne[sqlalchemy]\"'."
10+
) from ex
11+
12+
__all__ = [
13+
"LoaderRegistry",
14+
"SQLAlchemyObjectType",
15+
"SQLAlchemyQueryType",
16+
"SQLAlchemyRelationLoader",
17+
"auto_eager_load",
18+
]
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import logging
2+
from collections import defaultdict
3+
from typing import Any
4+
5+
from aiodataloader import DataLoader
6+
from sqlalchemy import select, tuple_
7+
from sqlalchemy.ext.asyncio import AsyncSession
8+
from sqlalchemy.orm import RelationshipProperty
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
class SQLAlchemyRelationLoader(DataLoader):
14+
"""
15+
DataLoader for SQLAlchemy relationships supporting:
16+
- Composite Keys
17+
- Many-to-Many (secondary tables)
18+
- Result grouping via SQL columns (optimized)
19+
"""
20+
21+
def __init__(
22+
self,
23+
session: AsyncSession,
24+
relation_prop: RelationshipProperty,
25+
cache: bool = True,
26+
):
27+
super().__init__(cache=cache)
28+
self.session = session
29+
self.relation_prop = relation_prop
30+
self.target_model = relation_prop.mapper.class_
31+
self.is_list = relation_prop.uselist
32+
33+
# Identify local and remote columns (handles composite keys)
34+
if relation_prop.secondary is not None:
35+
self.local_cols = [lp.key for lp, rp in relation_prop.synchronize_pairs]
36+
self.remote_cols = [rp.key for lp, rp in relation_prop.synchronize_pairs]
37+
else:
38+
self.local_cols = [c.key for c in relation_prop.local_columns]
39+
self.remote_cols = [c.key for c in relation_prop.remote_side]
40+
41+
self.secondary = relation_prop.secondary
42+
43+
def get_query(self, keys: list[Any]):
44+
"""Builds query. Handles composite IN clause and M2M joins."""
45+
target_model = self.target_model
46+
stmt = select(target_model)
47+
48+
if self.secondary is not None:
49+
stmt = stmt.join(self.secondary)
50+
filter_cols = [self.secondary.c[k] for k in self.remote_cols] # ty: ignore[invalid-argument-type]
51+
else:
52+
filter_cols = [getattr(target_model, k) for k in self.remote_cols] # ty: ignore[invalid-argument-type]
53+
54+
# Add the filtering columns to the result to allow grouping
55+
stmt = stmt.add_columns(*filter_cols)
56+
57+
if len(filter_cols) > 1:
58+
stmt = stmt.where(tuple_(*filter_cols).in_(keys))
59+
else:
60+
# Flatten keys if they are single-element tuples
61+
flat_keys = [k[0] if isinstance(k, (list, tuple)) else k for k in keys]
62+
stmt = stmt.where(filter_cols[0].in_(flat_keys))
63+
64+
return stmt
65+
66+
async def batch_load_fn(self, keys: list[Any]) -> list[Any]:
67+
logger.debug(
68+
"SQLAlchemyRelationLoader: Fetching %s for %d parents",
69+
self.target_model.__name__,
70+
len(keys),
71+
)
72+
stmt = self.get_query(keys)
73+
result = await self.session.execute(stmt)
74+
rows = result.all()
75+
76+
num_filter_cols = len(self.remote_cols)
77+
grouped = defaultdict(list)
78+
79+
for row in rows:
80+
item = row[0]
81+
# The filter columns are appended after the model instance
82+
key_parts = row[1 : 1 + num_filter_cols]
83+
key = tuple(key_parts) if num_filter_cols > 1 else key_parts[0]
84+
grouped[key].append(item)
85+
86+
return [
87+
grouped[k] if self.is_list else (grouped[k][0] if grouped[k] else None)
88+
for k in keys
89+
]
90+
91+
92+
class LoaderRegistry:
93+
"""
94+
Keeps one DataLoader instance per relationship per request.
95+
"""
96+
97+
def __init__(self, session: AsyncSession):
98+
self.session = session
99+
self._loaders: dict[
100+
tuple[RelationshipProperty, type[DataLoader]], DataLoader
101+
] = {}
102+
103+
def get_loader(
104+
self,
105+
relation_prop: RelationshipProperty,
106+
loader_class: type[SQLAlchemyRelationLoader] = SQLAlchemyRelationLoader,
107+
) -> DataLoader:
108+
key = (relation_prop, loader_class)
109+
if key not in self._loaders:
110+
self._loaders[key] = loader_class(self.session, relation_prop)
111+
return self._loaders[key]
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
from collections.abc import Callable
2+
from typing import Any
3+
4+
from graphql import GraphQLSchema
5+
from sqlalchemy import select
6+
from sqlalchemy.orm import DeclarativeBase, RelationshipProperty, class_mapper
7+
8+
from ...objects import ObjectType
9+
from .dataloaders import LoaderRegistry, SQLAlchemyRelationLoader
10+
11+
12+
class SQLAlchemyObjectType(ObjectType):
13+
"""
14+
ObjectType specialized for SQLAlchemy models.
15+
Automatically binds resolvers for relationships using DataLoaders.
16+
"""
17+
18+
model: type[DeclarativeBase]
19+
aliases: Any
20+
strategies: dict[str, Any]
21+
max_depth: int
22+
loader_registry_key: str
23+
24+
def __init__(
25+
self,
26+
name: str,
27+
model: type[DeclarativeBase],
28+
*,
29+
aliases: dict[str, str] | Callable[[], dict[str, str]] | None = None,
30+
strategies: dict[str, Any] | None = None,
31+
max_depth: int = 3,
32+
loader_registry_key: str = "loader_registry",
33+
):
34+
super().__init__(name)
35+
self.model = model
36+
self.aliases = aliases or {}
37+
self.strategies = strategies or {}
38+
self.max_depth = max_depth
39+
self.loader_registry_key = loader_registry_key
40+
41+
def bind_to_schema(self, schema: GraphQLSchema) -> None:
42+
self._bind_auto_resolvers()
43+
super().bind_to_schema(schema)
44+
45+
def get_base_query(self, info: Any, **kwargs: Any):
46+
"""
47+
Returns the base SQLAlchemy select statement for root queries.
48+
Can be overridden to apply default filters.
49+
"""
50+
return select(self.model)
51+
52+
def _bind_auto_resolvers(self):
53+
mapper = class_mapper(self.model)
54+
55+
if callable(self.aliases):
56+
self.aliases = self.aliases()
57+
58+
# Bind field aliases
59+
for gql_field, db_attr in self.aliases.items():
60+
if callable(db_attr):
61+
self.set_field(gql_field, db_attr)
62+
else:
63+
self.set_field(
64+
gql_field, lambda obj, *_, _attr=db_attr: getattr(obj, _attr)
65+
)
66+
67+
# Bind relationships
68+
for relation in mapper.relationships:
69+
if relation.key not in self._resolvers:
70+
self.set_field(relation.key, self._create_relation_resolver(relation))
71+
72+
def _create_relation_resolver(self, relation: RelationshipProperty):
73+
async def resolve(obj: Any, info: Any, **kwargs: Any):
74+
# If the attribute is already loaded (e.g. via joinedload/selectinload),
75+
# return it
76+
if relation.key in obj.__dict__:
77+
return getattr(obj, relation.key)
78+
79+
registry: LoaderRegistry = info.context.get(self.loader_registry_key)
80+
if registry is None:
81+
raise RuntimeError(
82+
"LoaderRegistry not found in context under key "
83+
f"'{self.loader_registry_key}'"
84+
)
85+
86+
# Build the key for the loader
87+
local_cols = [c.key for c in relation.local_columns]
88+
key = tuple(getattr(obj, k) for k in local_cols) # ty: ignore[invalid-argument-type]
89+
if len(key) == 1:
90+
key = key[0]
91+
92+
loader = registry.get_loader(relation, SQLAlchemyRelationLoader)
93+
return await loader.load(key)
94+
95+
return resolve
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
from collections.abc import Sequence
2+
from typing import Any
3+
4+
from graphql import GraphQLList, GraphQLNonNull, GraphQLObjectType, GraphQLSchema
5+
6+
from ...objects import ObjectType
7+
from .objects import SQLAlchemyObjectType
8+
from .utils import auto_eager_load
9+
10+
11+
class SQLAlchemyQueryType(ObjectType):
12+
"""
13+
A custom Query type that automatically binds SQLAlchemy resolvers
14+
by inspecting the GraphQLSchema during the make_executable_schema build phase.
15+
"""
16+
17+
def __init__(
18+
self,
19+
name: str,
20+
object_types: Sequence[SQLAlchemyObjectType],
21+
*,
22+
session_key: str = "session",
23+
):
24+
super().__init__(name)
25+
self.object_types = {ot.name: ot for ot in object_types}
26+
self.session_key = session_key
27+
28+
def bind_to_schema(self, schema: GraphQLSchema) -> None:
29+
graphql_type = schema.type_map.get(self.name)
30+
if not isinstance(graphql_type, GraphQLObjectType):
31+
super().bind_to_schema(schema)
32+
return
33+
34+
for field_name, field_def in graphql_type.fields.items():
35+
is_list = False
36+
unwrapped_type = field_def.type
37+
38+
while isinstance(unwrapped_type, (GraphQLList, GraphQLNonNull)):
39+
if isinstance(unwrapped_type, GraphQLList):
40+
is_list = True
41+
unwrapped_type = unwrapped_type.of_type
42+
43+
type_name = getattr(unwrapped_type, "name", None)
44+
45+
if type_name in self.object_types and field_name not in self._resolvers:
46+
obj_type = self.object_types[type_name]
47+
self.set_field(
48+
field_name, self._create_auto_resolver(obj_type, is_list)
49+
)
50+
51+
super().bind_to_schema(schema)
52+
53+
def _create_auto_resolver(self, obj_type: SQLAlchemyObjectType, return_list: bool):
54+
async def auto_resolve(obj: Any, info: Any, **kwargs: Any):
55+
session = info.context.get(self.session_key)
56+
if session is None:
57+
raise RuntimeError(
58+
f"Session not found in context under key '{self.session_key}'"
59+
)
60+
61+
model = obj_type.model
62+
stmt = obj_type.get_base_query(info, **kwargs)
63+
64+
stmt = auto_eager_load(
65+
stmt,
66+
info,
67+
model,
68+
strategies=obj_type.strategies,
69+
aliases=obj_type.aliases,
70+
max_depth=obj_type.max_depth,
71+
)
72+
73+
for key, value in kwargs.items():
74+
db_col_name = obj_type.aliases.get(key, key)
75+
if hasattr(model, db_col_name):
76+
stmt = stmt.where(getattr(model, db_col_name) == value)
77+
78+
result = await session.execute(stmt)
79+
80+
if return_list:
81+
return result.scalars().unique().all()
82+
return result.scalars().first()
83+
84+
return auto_resolve

0 commit comments

Comments
 (0)