Skip to content

Commit

Permalink
feat(jwt): Added revoked token handler support (#3960)
Browse files Browse the repository at this point in the history
* feat(jwt): add revoked token handler to authentication middleware

* feat(jwt): implement logout functionality with revoked token handling

* feat(jwt): add example implementation for JWT token revocation and user authentication

* fix(jwt): set default value for revoked_token_handler in JWT authentication classes

* docs(jwt): add documentation for implementing token revocation

* fix(jwt): fixed type error on test_auth

* Update docs/examples/security/jwt/using_token_revocation.py

---------

Co-authored-by: Alexandr Panteleev <[email protected]>
Co-authored-by: Cody Fincher <[email protected]>
  • Loading branch information
3 people authored Jan 21, 2025
1 parent 46d69f4 commit 844c90a
Show file tree
Hide file tree
Showing 5 changed files with 195 additions and 3 deletions.
102 changes: 102 additions & 0 deletions docs/examples/security/jwt/using_token_revocation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from os import environ
from typing import Any, Dict, Optional
from uuid import UUID

from pydantic import BaseModel, EmailStr

from litestar import Litestar, Request, Response, get, post
from litestar.connection import ASGIConnection
from litestar.openapi.config import OpenAPIConfig
from litestar.security.jwt import JWTAuth, Token


# Let's assume we have a User model that is a pydantic model.
# This though is not required - we need some sort of user class -
# but it can be any arbitrary value, e.g. an SQLAlchemy model, a representation of a MongoDB etc.
class User(BaseModel):
id: UUID
name: str
email: EmailStr


MOCK_DB: Dict[str, User] = {}
BLOCKLIST: Dict[str, str] = {}


# JWTAuth requires a retrieve handler callable that receives the JWT token model and the ASGI connection
# and returns the 'User' instance correlating to it.
#
# Notes:
# - 'User' can be any arbitrary value you decide upon.
# - The callable can be either sync or async - both will work.
async def retrieve_user_handler(token: Token, connection: "ASGIConnection[Any, Any, Any, Any]") -> Optional[User]:
# logic here to retrieve the user instance
return MOCK_DB.get(token.sub)


# If you want to use JWTAuth with revoking tokens, you have to define a handler of revoked tokens
# with your custom logic.
async def revoked_token_handler(token: Token, connection: "ASGIConnection[Any, Any, Any, Any]") -> bool:
jti = token.jti # Unique token identifier (JWT ID)
if jti:
# Check if the token is already revoked in the BLOCKLIST
revoked = BLOCKLIST.get(jti)
if revoked:
return True
return False


jwt_auth = JWTAuth[User](
retrieve_user_handler=retrieve_user_handler,
revoked_token_handler=revoked_token_handler,
token_secret=environ.get("JWT_SECRET", "abcd123"),
# we are specifying which endpoints should be excluded from authentication. In this case the login endpoint
# and our openAPI docs.
exclude=["/login", "/schema"],
)


# Given an instance of 'JWTAuth' we can create a login handler function:
@post("/login")
async def login_handler(data: User) -> Response[User]:
MOCK_DB[str(data.id)] = data
# you can do whatever you want to update the response instance here
# e.g. response.set_cookie(...)
return jwt_auth.login(identifier=str(data.id), token_extras={"email": data.email}, response_body=data)


# Also we can create a logout
@post("/logout")
async def logout_handler(request: Request["User", Token, Any]) -> Dict[str, str]:
# Your custom logic here
# For example
jti = request.auth.jti
if jti:
BLOCKLIST[jti] = "revoked"
return {"message": "Token has been revoked."}
return {"message": "No valid token found."}


# We also have some other routes, for example:
@get("/some-path", sync_to_thread=False, middleware=[jwt_auth.middleware])
def some_route_handler(request: "Request[User, Token, Any]") -> Any:
# request.user is set to the instance of user returned by the middleware
assert isinstance(request.user, User)
# request.auth is the instance of 'litestar.security.jwt.Token' created from the data encoded in the auth header
assert isinstance(request.auth, Token)
# do stuff ...


# We create our OpenAPIConfig as usual - the JWT security scheme will be injected into it.
openapi_config = OpenAPIConfig(
title="My API",
version="1.0.0",
)

# We initialize the app instance and pass the jwt_auth 'on_app_init' handler to the constructor.
# The hook handler will inject the JWT middleware and openapi configuration into the app.
app = Litestar(
route_handlers=[login_handler, logout_handler, some_route_handler],
on_app_init=[jwt_auth.on_app_init],
openapi_config=openapi_config,
)
12 changes: 12 additions & 0 deletions docs/usage/security/jwt.rst
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,15 @@ dictionary representing the decoded payload, which will then used by
.. literalinclude:: /examples/security/jwt/custom_decode_payload.py
:language: python
:caption: Customizing payload decoding


Using token revocation
----------------------
Token revocation can be implemented by maintaining a list of revoked tokens and checking against this list during authentication.
When a token is revoked, it should be added to the list, and any subsequent requests with that token should be denied.

.. dropdown:: Click to see the code

.. literalinclude:: /examples/security/jwt/using_token_revocation.py
:language: python
:caption: Implementing token revocation
15 changes: 15 additions & 0 deletions litestar/security/jwt/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ class BaseJWTAuth(Generic[UserType, TokenT], AbstractSecurityConfig[UserType, To
- The callable can be sync or async. If it is sync, it will be wrapped to support async.
"""
revoked_token_handler: Callable[[Any, ASGIConnection], SyncOrAsyncUnion[bool]] | None = field(default=None)
"""Callable that receives the auth value from the authentication middleware and checks whether the token has been revoked,
returning True if revoked, False otherwise."""
algorithm: str
"""Algorithm to use for JWT hashing."""
auth_header: str
Expand Down Expand Up @@ -138,6 +141,7 @@ def middleware(self) -> DefineMiddleware:
exclude_opt_key=self.exclude_opt_key,
exclude_http_methods=self.exclude_http_methods,
retrieve_user_handler=self.retrieve_user_handler,
revoked_token_handler=self.revoked_token_handler,
scopes=self.scopes,
token_secret=self.token_secret,
token_cls=self.token_cls,
Expand Down Expand Up @@ -276,6 +280,9 @@ class JWTAuth(Generic[UserType, TokenT], BaseJWTAuth[UserType, TokenT]):
- The callable can be sync or async. If it is sync, it will be wrapped to support async.
"""
revoked_token_handler: Callable[[Any, ASGIConnection], SyncOrAsyncUnion[bool]] | None = field(default=None)
"""Callable that receives the auth value from the authentication middleware and checks whether the token has been revoked,
returning True if revoked, False otherwise."""
guards: Iterable[Guard] | None = field(default=None)
"""An iterable of guards to call for requests, providing authorization functionalities."""
exclude: str | list[str] | None = field(default=None)
Expand Down Expand Up @@ -364,6 +371,9 @@ class and adds support for passing JWT tokens ``HttpOnly`` cookies.
- The callable can be sync or async. If it is sync, it will be wrapped to support async.
"""
revoked_token_handler: Callable[[Any, ASGIConnection], SyncOrAsyncUnion[bool]] | None = field(default=None)
"""Callable that receives the auth value from the authentication middleware and checks whether the token has been revoked,
returning True if revoked, False otherwise."""
guards: Iterable[Guard] | None = field(default=None)
"""An iterable of guards to call for requests, providing authorization functionalities."""
exclude: str | list[str] | None = field(default=None)
Expand Down Expand Up @@ -477,6 +487,7 @@ def middleware(self) -> DefineMiddleware:
exclude_opt_key=self.exclude_opt_key,
exclude_http_methods=self.exclude_http_methods,
retrieve_user_handler=self.retrieve_user_handler,
revoked_token_handler=self.revoked_token_handler,
scopes=self.scopes,
token_secret=self.token_secret,
token_cls=self.token_cls,
Expand Down Expand Up @@ -598,6 +609,9 @@ class OAuth2PasswordBearerAuth(Generic[UserType, TokenT], BaseJWTAuth[UserType,
- The callable can be sync or async. If it is sync, it will be wrapped to support async.
"""
revoked_token_handler: Callable[[Any, ASGIConnection], SyncOrAsyncUnion[bool]] | None = field(default=None)
"""Callable that receives the auth value from the authentication middleware and checks whether the token has been revoked,
returning True if revoked, False otherwise."""
guards: Iterable[Guard] | None = field(default=None)
"""An iterable of guards to call for requests, providing authorization functionalities."""
exclude: str | list[str] | None = field(default=None)
Expand Down Expand Up @@ -693,6 +707,7 @@ def middleware(self) -> DefineMiddleware:
exclude_opt_key=self.exclude_opt_key,
exclude_http_methods=self.exclude_http_methods,
retrieve_user_handler=self.retrieve_user_handler,
revoked_token_handler=self.revoked_token_handler,
scopes=self.scopes,
token_secret=self.token_secret,
token_cls=self.token_cls,
Expand Down
15 changes: 14 additions & 1 deletion litestar/security/jwt/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class JWTAuthenticationMiddleware(AbstractAuthenticationMiddleware):
"auth_header",
"require_claims",
"retrieve_user_handler",
"revoked_token_handler",
"strict_audience",
"token_audience",
"token_cls",
Expand Down Expand Up @@ -57,6 +58,7 @@ def __init__(
verify_expiry: bool = True,
verify_not_before: bool = True,
strict_audience: bool = False,
revoked_token_handler: Callable[[Token, ASGIConnection[Any, Any, Any, Any]], Awaitable[Any]] | None = None,
) -> None:
"""Check incoming requests for an encoded token in the auth header specified, and if present retrieve the user
from persistence using the provided function.
Expand Down Expand Up @@ -86,6 +88,8 @@ def __init__(
strict_audience: Verify that the value of the ``aud`` (*audience*) claim is a single value, and
not a list of values, and matches ``audience`` exactly. Requires that
``accepted_audiences`` is a sequence of length 1
revoked_token_handler: A function that receives a :class:`Token <.security.jwt.Token>` and returns a boolean
indicating whether the token has been revoked.
"""
super().__init__(
app=app,
Expand All @@ -97,6 +101,7 @@ def __init__(
self.algorithm = algorithm
self.auth_header = auth_header
self.retrieve_user_handler = retrieve_user_handler
self.revoked_token_handler = revoked_token_handler
self.token_secret = token_secret
self.token_cls = token_cls
self.token_audience = token_audience
Expand Down Expand Up @@ -153,8 +158,12 @@ async def authenticate_token(
)

user = await self.retrieve_user_handler(token, connection)
token_revoked = False

if not user:
if self.revoked_token_handler:
token_revoked = await self.revoked_token_handler(token, connection)

if not user or token_revoked:
raise NotAuthorizedException()

return AuthenticationResult(user=user, auth=token)
Expand Down Expand Up @@ -184,6 +193,7 @@ def __init__(
verify_expiry: bool = True,
verify_not_before: bool = True,
strict_audience: bool = False,
revoked_token_handler: Callable[[Token, ASGIConnection[Any, Any, Any, Any]], Awaitable[Any]] | None = None,
) -> None:
"""Check incoming requests for an encoded token in the auth header or cookie name specified, and if present
retrieves the user from persistence using the provided function.
Expand Down Expand Up @@ -214,6 +224,8 @@ def __init__(
strict_audience: Verify that the value of the ``aud`` (*audience*) claim is a single value, and
not a list of values, and matches ``audience`` exactly. Requires that
``accepted_audiences`` is a sequence of length 1
revoked_token_handler: A function that receives a :class:`Token <.security.jwt.Token>` and returns a boolean
indicating whether the token has been revoked.
"""
super().__init__(
algorithm=algorithm,
Expand All @@ -223,6 +235,7 @@ def __init__(
exclude_http_methods=exclude_http_methods,
exclude_opt_key=exclude_opt_key,
retrieve_user_handler=retrieve_user_handler,
revoked_token_handler=revoked_token_handler,
scopes=scopes,
token_secret=token_secret,
token_cls=token_cls,
Expand Down
54 changes: 52 additions & 2 deletions tests/unit/test_security/test_jwt/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,19 +60,26 @@ async def test_jwt_auth(
token_unique_jwt_id: Optional[str],
token_extras: Optional[Dict[str, Any]],
) -> None:
mock_block_list: Dict[str, str] = {}
user = UserFactory.build()

await mock_db.set(str(user.id), user, 120) # type: ignore[arg-type]

async def retrieve_user_handler(token: Token, _: "ASGIConnection") -> Any:
return await mock_db.get(token.sub)

async def revoked_token_handler(token: Token, _: "ASGIConnection") -> bool:
if token.jti:
return mock_block_list.get(token.jti) == "revoked"
return False

jwt_auth = JWTAuth[Any](
algorithm=algorithm,
auth_header=auth_header,
default_token_expiration=default_token_expiration,
token_secret=token_secret,
retrieve_user_handler=retrieve_user_handler,
revoked_token_handler=revoked_token_handler,
)

@get("/my-endpoint", middleware=[jwt_auth.middleware])
Expand All @@ -94,7 +101,15 @@ def login_handler() -> Response["User"]:
token_extras=token_extras,
)

with create_test_client(route_handlers=[my_handler, login_handler]) as client:
@get("/logout", middleware=[jwt_auth.middleware])
def logout_handler(request: Request["User", Token, Any]) -> Dict[str, str]:
jti = request.auth.jti
if jti:
mock_block_list[jti] = "revoked"
return {"message": "logged out successfully"}
return {"message": f"can't logout, jti is {jti}"}

with create_test_client(route_handlers=[my_handler, login_handler, logout_handler]) as client:
response = client.get("/login")
assert response.status_code == response_status_code
_, _, encoded_token = response.headers.get(auth_header).partition(" ")
Expand All @@ -114,6 +129,14 @@ def login_handler() -> Response["User"]:
response = client.get("/my-endpoint", headers={auth_header: jwt_auth.format_auth_header(encoded_token)})
assert response.status_code == HTTP_200_OK

response = client.get("/logout", headers={auth_header: jwt_auth.format_auth_header(encoded_token)})
if decoded_token.jti:
assert response.json()["message"] == "logged out successfully"
response = client.get("/my-endpoint", headers={auth_header: jwt_auth.format_auth_header(encoded_token)})
assert response.status_code == HTTP_401_UNAUTHORIZED
else:
assert response.json()["message"] == f"can't logout, jti is {decoded_token.jti}"

response = client.get("/my-endpoint", headers={auth_header: encoded_token})
assert response.status_code == HTTP_401_UNAUTHORIZED

Expand Down Expand Up @@ -216,6 +239,7 @@ async def test_jwt_cookie_auth(
token_unique_jwt_id: Optional[str],
token_extras: Optional[Dict[str, Any]],
) -> None:
mock_block_list: Dict[str, str] = {}
user = UserFactory.build()

await mock_db.set(str(user.id), user, 120) # type: ignore[arg-type]
Expand All @@ -224,12 +248,18 @@ async def retrieve_user_handler(token: Token, connection: Any) -> Any:
assert connection
return await mock_db.get(token.sub)

async def revoked_token_handler(token: Token, _: Any) -> bool:
if token.jti:
return mock_block_list.get(token.jti) == "revoked"
return False

jwt_auth = JWTCookieAuth(
algorithm=algorithm,
key=auth_cookie,
auth_header=auth_header,
default_token_expiration=default_token_expiration,
retrieve_user_handler=retrieve_user_handler, # type: ignore[var-annotated]
revoked_token_handler=revoked_token_handler,
token_secret=token_secret,
)

Expand All @@ -252,7 +282,15 @@ def login_handler() -> Response["User"]:
token_extras=token_extras,
)

with create_test_client(route_handlers=[my_handler, login_handler]) as client:
@get("/logout", middleware=[jwt_auth.middleware])
def logout_handler(request: Request["User", Token, Any]) -> Dict[str, str]:
jti = request.auth.jti
if jti:
mock_block_list[jti] = "revoked"
return {"message": "logged out successfully"}
return {"message": f"can't logout, jti is {jti}"}

with create_test_client(route_handlers=[my_handler, login_handler, logout_handler]) as client:
response = client.get("/login")
assert response.status_code == response_status_code
_, _, encoded_token = response.headers.get(auth_header).partition(" ")
Expand Down Expand Up @@ -315,6 +353,18 @@ def login_handler() -> Response["User"]:
response = client.get("/my-endpoint")
assert response.status_code == HTTP_401_UNAUTHORIZED

client.cookies.clear()
client.cookies = {auth_cookie: jwt_auth.format_auth_header(encoded_token)} # type: ignore[assignment]
response = client.get("/my-endpoint")
assert response.status_code == HTTP_200_OK
response = client.get("/logout")
if decoded_token.jti:
assert response.json()["message"] == "logged out successfully"
response = client.get("/my-endpoint")
assert response.status_code == HTTP_401_UNAUTHORIZED
else:
assert response.json()["message"] == f"can't logout, jti is {decoded_token.jti}"


async def test_path_exclusion() -> None:
async def retrieve_user_handler(_: Token, __: "ASGIConnection") -> None:
Expand Down

0 comments on commit 844c90a

Please sign in to comment.