Skip to content

Commit

Permalink
Add support for access tokens in WSO2 client (#85)
Browse files Browse the repository at this point in the history
  • Loading branch information
caspervdw authored Mar 20, 2024
1 parent 1d25ec5 commit 572aa8b
Show file tree
Hide file tree
Showing 11 changed files with 312 additions and 168 deletions.
3 changes: 2 additions & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ Changelog of nens-auth-client
1.5.1 (unreleased)
------------------

- Nothing changed yet.
- Added Bearer token parsing to WSO2 client.

- Refactored the Cognito and WSO2 clients so they use the same baseclass.

1.5.0 (2024-02-19)
------------------
Expand Down
8 changes: 6 additions & 2 deletions nens_auth_client/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,15 @@ def check_resource_server_id(app_configs=None, **kwargs):
"AccessTokenMiddleware is used."
)
]
if not url.endswith("/"):
if (
settings.NENS_AUTH_OAUTH_BACKEND
== "nens_auth_client.cognito.CognitoOAuthClient"
and not url.endswith("/")
):
return [
Error(
"The NENS_AUTH_RESOURCE_SERVER_ID setting needs to end with a "
"slash (because AWS Cognito will automatically add one)."
"slash when using the CognitoOAuthClient."
)
]
return []
Expand Down
139 changes: 42 additions & 97 deletions nens_auth_client/cognito.py
Original file line number Diff line number Diff line change
@@ -1,66 +1,12 @@
from authlib.integrations.django_client import DjangoOAuth2App
from authlib.jose import JsonWebKey
from authlib.jose import JsonWebToken
from .oauth_base import BaseOAuthClient
from django.conf import settings
from django.http.response import HttpResponseRedirect
from urllib.parse import urlencode
from urllib.parse import urlparse
from urllib.parse import urlunparse


def preprocess_access_token(claims):
"""Convert AWS Cognito Access token claims to standard form, inplace.
AWS Cognito Access tokens are missing the "aud" (audience) claim and
instead put the audience into each scope.
This function filters the scopes on those that start with the
NENS_AUTH_RESOURCE_SERVER_ID setting. If there is any matching scope, the
"aud" claim will be set.
The resulting "scope" has no audience(s) in it anymore.
Args:
claims (dict): payload of the Access Token
Example:
>>> audience = "https://some/api/"
>>> claims = {
"scope": "https://some/api/users.readwrite https://something/else"
}
>>> preprocess_access_token(claims)
>>> claims
{
"aud": "https://some/api/",
"scopes": "users.readwrite",
...
}
"""
# Do nothing if there is an already an "aud" claim
if "aud" in claims:
return

# Get the expected "aud" claim
audience = settings.NENS_AUTH_RESOURCE_SERVER_ID

# List scopes and chop off the audience from the scope
new_scopes = []
for scope_item in claims.get("scope", "").split(" "):
if scope_item.startswith(audience):
scope_without_audience = scope_item[len(audience) :]
new_scopes.append(scope_without_audience)

# Don't set the audience if there are no scopes as Access Token is
# apparently not meant for this server.
if not new_scopes:
return

# Update the claims inplace
claims["aud"] = audience
claims["scope"] = " ".join(new_scopes)


class CognitoOAuthClient(DjangoOAuth2App):
class CognitoOAuthClient(BaseOAuthClient):
def logout_redirect(self, request, redirect_uri=None, login_after=False):
"""Create a redirect to the remote server's logout endpoint
Expand Down Expand Up @@ -97,57 +43,56 @@ def logout_redirect(self, request, redirect_uri=None, login_after=False):

return HttpResponseRedirect(logout_url)

def parse_access_token(self, token, claims_options=None, leeway=120):
"""Decode and validate a Cognito access token and return its payload.
def preprocess_access_token(self, claims):
"""Convert AWS Cognito Access token claims to standard form, inplace.
Note: this function is based on authlib.DjangoRemoteApp._parse_id_token
to make use of the same server settings and key cache. The token claims
are AWS Cognito specific.
AWS Cognito Access tokens are missing the "aud" (audience) claim and
instead put the audience into each scope.
Args:
token (str): access token (base64 encoded JWT)
This function filters the scopes on those that start with the
NENS_AUTH_RESOURCE_SERVER_ID setting. If there is any matching scope, the
"aud" claim will be set.
Returns:
claims (dict): the token payload
The resulting "scope" has no audience(s) in it anymore.
Raises:
authlib.jose.errors.JoseError: if token is invalid
ValueError: if the key id is not present in the jwks.json
"""
Args:
claims (dict): payload of the Access Token
# this is a copy from the _parse_id_token equivalent function
def load_key(header, payload):
jwk_set = self.fetch_jwk_set()
kid = header.get("kid")
try:
return JsonWebKey.import_key_set(jwk_set).find_by_kid(kid)
except ValueError:
# re-try with new jwk set
jwk_set = self.fetch_jwk_set(force=True)
return JsonWebKey.import_key_set(jwk_set).find_by_kid(kid)

metadata = self.load_server_metadata()
claims_options = {
"aud": {"essential": True, "value": settings.NENS_AUTH_RESOURCE_SERVER_ID},
"iss": {"essential": True, "value": metadata["issuer"]},
"sub": {"essential": True},
"scope": {"essential": True},
**(claims_options or {}),
Example:
>>> audience = "https://some/api/"
>>> claims = {
"scope": "https://some/api/users.readwrite https://something/else"
}
>>> preprocess_access_token(claims)
>>> claims
{
"aud": "https://some/api/",
"scopes": "users.readwrite",
...
}
"""
# Do nothing if there is an already an "aud" claim
if "aud" in claims:
return

alg_values = metadata.get("id_token_signing_alg_values_supported")
if not alg_values:
alg_values = ["RS256"]
# Get the expected "aud" claim
audience = settings.NENS_AUTH_RESOURCE_SERVER_ID

claims = JsonWebToken(alg_values).decode(
token, key=load_key, claims_options=claims_options
)
# List scopes and chop off the audience from the scope
new_scopes = []
for scope_item in claims.get("scope", "").split(" "):
if scope_item.startswith(audience):
scope_without_audience = scope_item[len(audience) :]
new_scopes.append(scope_without_audience)

# Preprocess the token (to add the "aud" claim)
preprocess_access_token(claims)
# Don't set the audience if there are no scopes as Access Token is
# apparently not meant for this server.
if not new_scopes:
return

claims.validate(leeway=leeway)
return claims
# Update the claims inplace
claims["aud"] = audience
claims["scope"] = " ".join(new_scopes)

@staticmethod
def extract_provider_name(claims):
Expand Down
91 changes: 91 additions & 0 deletions nens_auth_client/oauth_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from authlib.integrations.django_client import DjangoOAuth2App
from authlib.jose import JsonWebKey
from authlib.jose import JsonWebToken
from django.conf import settings


class BaseOAuthClient(DjangoOAuth2App):
def logout_redirect(self, request, redirect_uri=None, login_after=False):
"""Create a redirect to the remote server's logout endpoint
Note that unlike with login, there is no standardization for logout.
This function should be written for a specific authorization server.
Args:
request: The current request
redirect_uri: The absolute url to the logout success view of this app
login_after: whether to show the login screen after logout
Returns:
HttpResponseRedirect authorization server logout endpoint
"""
raise NotImplementedError()

def load_key(self, header, payload):
"""Load a JSONWebKey from the authorization server given JWT header and payload.
Source:
authlib.integrations.base_client.sync_openid.parse_id_token
"""
jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set())
try:
return jwk_set.find_by_kid(header.get("kid"))
except ValueError:
# re-try with new jwk set
jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set(force=True))
return jwk_set.find_by_kid(header.get("kid"))

def preprocess_access_token(self, claims):
"""Convert access token claims to standard form, inplace.
Args:
claims (dict): payload of the Access Token
"""

def parse_access_token(self, token, claims_options=None, leeway=120):
"""Decode and validate an access token and return its payload.
Args:
token (str): access token (base64 encoded JWT)
Returns:
claims (dict): the token payload
Raises:
authlib.jose.errors.JoseError: if token is invalid
ValueError: if the key id is not present in the jwks.json
"""
metadata = self.load_server_metadata()
claims_options = {
"aud": {"essential": True, "value": settings.NENS_AUTH_RESOURCE_SERVER_ID},
"iss": {"essential": True, "value": metadata["issuer"]},
"sub": {"essential": True},
"scope": {"essential": True},
**(claims_options or {}),
}

alg_values = metadata.get("id_token_signing_alg_values_supported")
if not alg_values:
alg_values = ["RS256"]

claims = JsonWebToken(alg_values).decode(
token, key=self.load_key, claims_options=claims_options
)

# Preprocess the token (to add the "aud" claim)
self.preprocess_access_token(claims)

claims.validate(leeway=leeway)
return claims

@staticmethod
def extract_provider_name(claims):
"""Return provider name from claim and `None` if not found"""
# Also used by backends.py
raise NotImplementedError()

@staticmethod
def extract_username(claims) -> str:
"""Return username from claims"""
# Also used by backends.py
raise NotImplementedError()
1 change: 1 addition & 0 deletions nens_auth_client/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def access_token_generator(token_generator, access_token_template):

def func(**extra_claims):
claims = {**access_token_template, **extra_claims}
claims = {k: v for (k, v) in claims.items() if v is not None}
return token_generator(**claims)

return func
Expand Down
3 changes: 1 addition & 2 deletions nens_auth_client/tests/test_cognito.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from nens_auth_client.cognito import CognitoOAuthClient
from nens_auth_client.cognito import preprocess_access_token

import pytest

Expand All @@ -19,7 +18,7 @@
)
def test_preprocess_access_token(claims, expected, settings):
settings.NENS_AUTH_RESOURCE_SERVER_ID = "api/"
preprocess_access_token(claims)
CognitoOAuthClient.preprocess_access_token(None, claims)
assert claims == expected


Expand Down
Loading

0 comments on commit 572aa8b

Please sign in to comment.