Skip to content

Commit 3b8538e

Browse files
authored
Allow private JWT headers (#4290)
1 parent 0445c31 commit 3b8538e

2 files changed

Lines changed: 103 additions & 1 deletion

File tree

fastmcp_slim/fastmcp/server/auth/providers/jwt.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from cryptography.hazmat.primitives.asymmetric import rsa
1414
from joserfc import jwk, jwt
1515
from joserfc.errors import JoseError
16+
from joserfc.jws import JWSRegistry
17+
from joserfc.registry import JWS_HEADER_REGISTRY
1618
from pydantic import AnyHttpUrl, SecretStr
1719
from typing_extensions import TypedDict
1820

@@ -24,6 +26,7 @@
2426
logger = get_logger(__name__)
2527

2628
JWKKeyData: TypeAlias = dict[str, str | list[str]]
29+
SUPPORTED_JWS_HEADER_FIELDS = frozenset(JWS_HEADER_REGISTRY)
2730

2831

2932
def _import_key_for_algorithm(key: str | bytes | JWKKeyData, algorithm: str):
@@ -45,6 +48,21 @@ def _jwk_to_pem(key_data: JWKKeyData) -> str:
4548
raise ValueError(f"Unsupported JWK key type: {key_type!r}")
4649

4750

51+
def _has_unsupported_critical_headers(header: dict[str, Any]) -> bool:
52+
crit = header.get("crit")
53+
if crit is None:
54+
return False
55+
if not isinstance(crit, list):
56+
return True
57+
58+
return any(
59+
not isinstance(header_name, str)
60+
or header_name not in header
61+
or header_name not in SUPPORTED_JWS_HEADER_FIELDS
62+
for header_name in crit
63+
)
64+
65+
4866
class JWKData(TypedDict, total=False):
4967
"""JSON Web Key data structure."""
5068

@@ -429,7 +447,22 @@ async def load_access_token(self, token: str) -> AccessToken | None:
429447

430448
# Decode and verify the JWT token
431449
key = _import_key_for_algorithm(verification_key, self.algorithm)
432-
claims = jwt.decode(token, key, algorithms=[self.algorithm]).claims
450+
header = decode_jwt_header(token)
451+
if _has_unsupported_critical_headers(header):
452+
self.logger.debug(
453+
"Token validation failed: unsupported critical JWT header"
454+
)
455+
return None
456+
457+
claims = jwt.decode(
458+
token,
459+
key,
460+
algorithms=[self.algorithm],
461+
registry=JWSRegistry(
462+
algorithms=[self.algorithm],
463+
strict_check_header=False,
464+
),
465+
).claims
433466

434467
# Extract client ID early for logging
435468
client_id = (

tests/server/auth/test_jwt_provider.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import pytest
77
from joserfc import jwk as jose_jwk
88
from joserfc import jwt
9+
from joserfc.jws import JWSRegistry
10+
from joserfc.registry import HeaderParameter
911
from pytest_httpx import HTTPXMock
1012

1113
from fastmcp import FastMCP
@@ -182,6 +184,73 @@ def test_create_token_with_scopes(self, rsa_key_pair: RSAKeyPair):
182184
# We'll validate the scopes in the BearerToken tests
183185

184186

187+
class TestJWTVerifierHeaders:
188+
async def test_non_critical_private_header_is_allowed(
189+
self, rsa_key_pair: RSAKeyPair
190+
):
191+
signing_key = jose_jwk.import_key(
192+
rsa_key_pair.private_key.get_secret_value(),
193+
"RSA",
194+
)
195+
token = jwt.encode(
196+
{
197+
"alg": "RS256",
198+
"cat": "cl_example",
199+
},
200+
{
201+
"sub": "test-user",
202+
"iss": "https://test.example.com",
203+
"exp": int(time.time()) + 3600,
204+
},
205+
signing_key,
206+
algorithms=["RS256"],
207+
registry=JWSRegistry(strict_check_header=False),
208+
)
209+
verifier = JWTVerifier(
210+
public_key=rsa_key_pair.public_key,
211+
issuer="https://test.example.com",
212+
)
213+
214+
access_token = await verifier.verify_token(token)
215+
216+
assert access_token is not None
217+
assert access_token.client_id == "test-user"
218+
219+
async def test_critical_private_header_is_rejected(self, rsa_key_pair: RSAKeyPair):
220+
signing_key = jose_jwk.import_key(
221+
rsa_key_pair.private_key.get_secret_value(),
222+
"RSA",
223+
)
224+
token = jwt.encode(
225+
{
226+
"alg": "RS256",
227+
"crit": ["cat"],
228+
"cat": "cl_example",
229+
},
230+
{
231+
"sub": "test-user",
232+
"iss": "https://test.example.com",
233+
"exp": int(time.time()) + 3600,
234+
},
235+
signing_key,
236+
algorithms=["RS256"],
237+
registry=JWSRegistry(
238+
header_registry={
239+
"cat": HeaderParameter("Custom private header", "str")
240+
},
241+
strict_check_header=False,
242+
),
243+
)
244+
verifier = JWTVerifier(
245+
public_key=rsa_key_pair.public_key,
246+
issuer="https://test.example.com",
247+
)
248+
249+
access_token = await verifier.verify_token(token)
250+
251+
assert access_token is None
252+
253+
185254
class TestSymmetricKeyJWT:
186255
"""Tests for JWT verification using symmetric keys (HMAC algorithms)."""
187256

0 commit comments

Comments
 (0)