diff --git a/litestar/security/jwt/token.py b/litestar/security/jwt/token.py index db7fbc9021..344853d39b 100644 --- a/litestar/security/jwt/token.py +++ b/litestar/security/jwt/token.py @@ -90,7 +90,7 @@ def __post_init__(self) -> None: def decode_payload( cls, encoded_token: str, - secret: str, + secret: str | bytes, algorithms: list[str], issuer: list[str] | None = None, audience: str | Sequence[str] | None = None, @@ -110,7 +110,7 @@ def decode_payload( def decode( cls, encoded_token: str, - secret: str, + secret: str | bytes, algorithm: str, audience: str | Sequence[str] | None = None, issuer: str | Sequence[str] | None = None, @@ -194,12 +194,18 @@ def decode( ) as e: raise NotAuthorizedException("Invalid token") from e - def encode(self, secret: str, algorithm: str) -> str: + def encode( + self, + secret: str | bytes, + algorithm: str, + headers: dict[str, Any] | None = None, + ) -> str: """Encode the token instance into a string. Args: secret: The secret with which the JWT is encoded. algorithm: The algorithm used to encode the JWT. + headers: Optional headers to include in the JWT (e.g., {"kid": "..."}). Returns: An encoded token string. @@ -212,6 +218,7 @@ def encode(self, secret: str, algorithm: str) -> str: payload={k: v for k, v in asdict(self).items() if v is not None}, key=secret, algorithm=algorithm, + headers=headers, ) except (jwt.DecodeError, NotImplementedError) as e: raise ImproperlyConfiguredException("Failed to encode token") from e diff --git a/tests/unit/test_security/test_jwt/test_token.py b/tests/unit/test_security/test_jwt/test_token.py index f93d719a6d..34328646d5 100644 --- a/tests/unit/test_security/test_jwt/test_token.py +++ b/tests/unit/test_security/test_jwt/test_token.py @@ -206,7 +206,7 @@ class CustomToken(Token): def decode_payload( cls, encoded_token: str, - secret: str, + secret: str | bytes, algorithms: list[str], issuer: list[str] | None = None, audience: str | Sequence[str] | None = None, @@ -223,3 +223,14 @@ def decode_payload( _secret = secrets.token_hex() encoded = CustomToken(exp=datetime.now() + timedelta(days=1), sub="foo").encode(_secret, "HS256") assert CustomToken.decode(encoded, secret=_secret, algorithm="HS256").sub == "some-random-value" + + +def test_token_encode_includes_custom_headers() -> None: + token = Token(exp=datetime.now() + timedelta(days=1), sub="some-random-value") + custom_headers = {"kid": "key-id"} + encoded = token.encode(secret=secrets.token_hex(), algorithm="HS256", headers=custom_headers) + header = jwt.get_unverified_header(encoded) + + assert header["alg"] == "HS256" + assert "kid" in header + assert header["kid"] == custom_headers["kid"]