Skip to content

Commit aa7b8e4

Browse files
authored
Merge commit from fork
Fix CSRF issue with starlette client
2 parents ef09aeb + 401a770 commit aa7b8e4

4 files changed

Lines changed: 85 additions & 14 deletions

File tree

authlib/integrations/starlette_client/apps.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,7 @@ class StarletteAppMixin:
1414
async def save_authorize_data(self, request, **kwargs):
1515
state = kwargs.pop("state", None)
1616
if state:
17-
if self.framework.cache:
18-
session = None
19-
else:
20-
session = request.session
21-
await self.framework.set_state_data(session, state, kwargs)
17+
await self.framework.set_state_data(request.session, state, kwargs)
2218
else:
2319
raise RuntimeError("Missing state value")
2420

@@ -80,13 +76,10 @@ async def authorize_access_token(self, request, **kwargs):
8076
"state": form.get("state"),
8177
}
8278

83-
if self.framework.cache:
84-
session = None
85-
else:
86-
session = request.session
87-
88-
state_data = await self.framework.get_state_data(session, params.get("state"))
89-
await self.framework.clear_state_data(session, params.get("state"))
79+
state_data = await self.framework.get_state_data(
80+
request.session, params.get("state")
81+
)
82+
await self.framework.clear_state_data(request.session, params.get("state"))
9083
params = self._format_state_params(state_data, params)
9184

9285
claims_options = kwargs.pop("claims_options", None)

authlib/integrations/starlette_client/integration.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ async def get_state_data(
2222
) -> dict[str, Any]:
2323
key = f"_state_{self.name}_{state}"
2424
if self.cache:
25+
# require a session-bound marker to prove the callback originates
26+
# from the user-agent that started the flow (RFC 6749 §10.12)
27+
if session is None or session.get(key) is None:
28+
return None
2529
value = await self._get_cache_data(key)
2630
elif session is not None:
2731
value = session.get(key)
@@ -37,21 +41,27 @@ async def set_state_data(
3741
):
3842
key_prefix = f"_state_{self.name}_"
3943
key = f"{key_prefix}{state}"
44+
now = time.time()
4045
if self.cache:
4146
await self.cache.set(key, json.dumps({"data": data}), self.expires_in)
47+
if session is not None:
48+
# clear old state data to avoid session size growing
49+
for old_key in list(session.keys()):
50+
if old_key.startswith(key_prefix):
51+
session.pop(old_key)
52+
session[key] = {"exp": now + self.expires_in}
4253
elif session is not None:
4354
# clear old state data to avoid session size growing
4455
for old_key in list(session.keys()):
4556
if old_key.startswith(key_prefix):
4657
session.pop(old_key)
47-
now = time.time()
4858
session[key] = {"data": data, "exp": now + self.expires_in}
4959

5060
async def clear_state_data(self, session: Optional[dict[str, Any]], state: str):
5161
key = f"_state_{self.name}_{state}"
5262
if self.cache:
5363
await self.cache.delete(key)
54-
elif session is not None:
64+
if session is not None:
5565
session.pop(key, None)
5666
self._clear_session_state(session)
5767

docs/changelog.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,14 @@ Changelog
66

77
Here you can see the full list of changes between each Authlib release.
88

9+
Version 1.6.11
10+
--------------
11+
12+
**Released on Apr 15, 2026**
13+
14+
- Fix CSRF vulnerability in the Starlette OAuth client when a ``cache`` is
15+
configured.
16+
917
Version 1.6.10
1018
--------------
1119

tests/clients/test_starlette/test_oauth_client.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,66 @@ async def test_oauth2_authorize():
118118
assert token["access_token"] == "a"
119119

120120

121+
class _FakeAsyncCache:
122+
"""Minimal async cache implementing the authlib framework cache protocol."""
123+
124+
def __init__(self):
125+
self.store = {}
126+
127+
async def get(self, key):
128+
return self.store.get(key)
129+
130+
async def set(self, key, value, expires=None):
131+
self.store[key] = value
132+
133+
async def delete(self, key):
134+
self.store.pop(key, None)
135+
136+
137+
@pytest.mark.asyncio
138+
async def test_oauth2_authorize_csrf_with_cache():
139+
"""When a cache is configured, the state must still be bound to the
140+
session that initiated the flow. Otherwise an attacker can start an
141+
authorization request, stop before the callback, and trick a victim into
142+
completing the flow — logging the victim into the attacker's account
143+
(RFC 6749 §10.12)."""
144+
transport = ASGITransport(
145+
AsyncPathMapDispatch({"/token": {"body": get_bearer_token()}})
146+
)
147+
oauth = OAuth(cache=_FakeAsyncCache())
148+
client = oauth.register(
149+
"dev",
150+
client_id="dev",
151+
client_secret="dev",
152+
api_base_url="https://resource.test/api",
153+
access_token_url="https://provider.test/token",
154+
authorize_url="https://provider.test/authorize",
155+
client_kwargs={
156+
"transport": transport,
157+
},
158+
)
159+
160+
# Attacker initiates an auth flow from their own session.
161+
attacker_req = Request({"type": "http", "session": {}})
162+
resp = await client.authorize_redirect(attacker_req, "https://client.test/callback")
163+
assert resp.status_code == 302
164+
url = resp.headers.get("Location")
165+
state = dict(url_decode(urlparse.urlparse(url).query))["state"]
166+
167+
# Victim is tricked into hitting the callback URL. The victim's browser
168+
# carries a *different* session — they never initiated this flow.
169+
victim_req = Request(
170+
{
171+
"type": "http",
172+
"path": "/",
173+
"query_string": f"code=a&state={state}".encode(),
174+
"session": {},
175+
}
176+
)
177+
with pytest.raises(OAuthError):
178+
await client.authorize_access_token(victim_req)
179+
180+
121181
@pytest.mark.asyncio
122182
async def test_oauth2_authorize_access_denied():
123183
oauth = OAuth()

0 commit comments

Comments
 (0)