@@ -118,6 +118,66 @@ async def test_oauth2_authorize():
118118 assert token ["access_token" ] == "a"
119119
120120
121+ class _FakeAsyncCache :
122+ """Minimal async cache implementing the authlib framework cache protocol."""
123+
124+ def __init__ (self ):
125+ self .store = {}
126+
127+ async def get (self , key ):
128+ return self .store .get (key )
129+
130+ async def set (self , key , value , expires = None ):
131+ self .store [key ] = value
132+
133+ async def delete (self , key ):
134+ self .store .pop (key , None )
135+
136+
137+ @pytest .mark .asyncio
138+ async def test_oauth2_authorize_csrf_with_cache ():
139+ """When a cache is configured, the state must still be bound to the
140+ session that initiated the flow. Otherwise an attacker can start an
141+ authorization request, stop before the callback, and trick a victim into
142+ completing the flow — logging the victim into the attacker's account
143+ (RFC 6749 §10.12)."""
144+ transport = ASGITransport (
145+ AsyncPathMapDispatch ({"/token" : {"body" : get_bearer_token ()}})
146+ )
147+ oauth = OAuth (cache = _FakeAsyncCache ())
148+ client = oauth .register (
149+ "dev" ,
150+ client_id = "dev" ,
151+ client_secret = "dev" ,
152+ api_base_url = "https://resource.test/api" ,
153+ access_token_url = "https://provider.test/token" ,
154+ authorize_url = "https://provider.test/authorize" ,
155+ client_kwargs = {
156+ "transport" : transport ,
157+ },
158+ )
159+
160+ # Attacker initiates an auth flow from their own session.
161+ attacker_req = Request ({"type" : "http" , "session" : {}})
162+ resp = await client .authorize_redirect (attacker_req , "https://client.test/callback" )
163+ assert resp .status_code == 302
164+ url = resp .headers .get ("Location" )
165+ state = dict (url_decode (urlparse .urlparse (url ).query ))["state" ]
166+
167+ # Victim is tricked into hitting the callback URL. The victim's browser
168+ # carries a *different* session — they never initiated this flow.
169+ victim_req = Request (
170+ {
171+ "type" : "http" ,
172+ "path" : "/" ,
173+ "query_string" : f"code=a&state={ state } " .encode (),
174+ "session" : {},
175+ }
176+ )
177+ with pytest .raises (OAuthError ):
178+ await client .authorize_access_token (victim_req )
179+
180+
121181@pytest .mark .asyncio
122182async def test_oauth2_authorize_access_denied ():
123183 oauth = OAuth ()
0 commit comments