Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions litestar/security/jwt/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __post_init__(self) -> None:
def decode_payload(
cls,
encoded_token: str,
secret: str,
secret: str | bytes,
algorithms: list[str],
issuer: list[str] | None = None,
audience: str | Sequence[str] | None = None,
Expand All @@ -110,7 +110,7 @@ def decode_payload(
def decode(
cls,
encoded_token: str,
secret: str,
secret: str | bytes,
algorithm: str,
audience: str | Sequence[str] | None = None,
issuer: str | Sequence[str] | None = None,
Expand Down Expand Up @@ -194,12 +194,18 @@ def decode(
) as e:
raise NotAuthorizedException("Invalid token") from e

def encode(self, secret: str, algorithm: str) -> str:
def encode(
self,
secret: str | bytes,
algorithm: str,
headers: dict[str, Any] | None = None,
) -> str:
"""Encode the token instance into a string.

Args:
secret: The secret with which the JWT is encoded.
algorithm: The algorithm used to encode the JWT.
headers: Optional headers to include in the JWT (e.g., {"kid": "..."}).

Returns:
An encoded token string.
Expand All @@ -212,6 +218,7 @@ def encode(self, secret: str, algorithm: str) -> str:
payload={k: v for k, v in asdict(self).items() if v is not None},
key=secret,
algorithm=algorithm,
headers=headers,
)
except (jwt.DecodeError, NotImplementedError) as e:
raise ImproperlyConfiguredException("Failed to encode token") from e
13 changes: 12 additions & 1 deletion tests/unit/test_security/test_jwt/test_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ class CustomToken(Token):
def decode_payload(
cls,
encoded_token: str,
secret: str,
secret: str | bytes,
algorithms: list[str],
issuer: list[str] | None = None,
audience: str | Sequence[str] | None = None,
Expand All @@ -223,3 +223,14 @@ def decode_payload(
_secret = secrets.token_hex()
encoded = CustomToken(exp=datetime.now() + timedelta(days=1), sub="foo").encode(_secret, "HS256")
assert CustomToken.decode(encoded, secret=_secret, algorithm="HS256").sub == "some-random-value"


def test_token_encode_includes_custom_headers() -> None:
token = Token(exp=datetime.now() + timedelta(days=1), sub="some-random-value")
custom_headers = {"kid": "key-id"}
encoded = token.encode(secret=secrets.token_hex(), algorithm="HS256", headers=custom_headers)
header = jwt.get_unverified_header(encoded)

assert header["alg"] == "HS256"
assert "kid" in header
assert header["kid"] == custom_headers["kid"]
Loading