Skip to content

Commit

Permalink
Add support for access tokens in WSO2 client
Browse files Browse the repository at this point in the history
  • Loading branch information
caspervdw committed Mar 20, 2024
1 parent 1d25ec5 commit 7786915
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 12 deletions.
1 change: 1 addition & 0 deletions nens_auth_client/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def access_token_generator(token_generator, access_token_template):

def func(**extra_claims):
claims = {**access_token_template, **extra_claims}
claims = {k: v for (k, v) in claims.items() if v is not None}
return token_generator(**claims)

return func
Expand Down
35 changes: 30 additions & 5 deletions nens_auth_client/tests/test_wso2.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from authlib.jose.errors import JoseError
from authlib.oidc.discovery import get_well_known_url
from django.conf import settings
from nens_auth_client.wso2 import WSO2AuthClient

import pytest
Expand All @@ -19,9 +22,31 @@ def test_extract_username(claims, expected):
assert WSO2AuthClient.extract_username(claims) == expected


def test_parse_access_token_includes_claims(access_token_generator):
with pytest.raises(NotImplementedError) as e:
WSO2AuthClient.parse_access_token(None, access_token_generator())
@pytest.fixture
def wso2_client():
return WSO2AuthClient(
"foo",
server_metadata_url=get_well_known_url(
settings.NENS_AUTH_ISSUER, external=True
),
)

# error is raised with claims as arg
assert e.value.args[0]["client_id"] == "1234"

def test_parse_access_token_wso2(access_token_generator, jwks_request, wso2_client):
# disable 'token_use' (not included in WSO2 access token)
claims = wso2_client.parse_access_token(
access_token_generator(email="test@wso2", token_use=None)
)

assert claims["email"] == "test@wso2"


@pytest.mark.parametrize(
"claims_mod", [{"aud": "abc123"}, {"sub": None}, {"iss": "abc123"}, {"exp": 0}]
)
def test_parse_access_token_wso2_invalid_claims(
claims_mod, access_token_generator, jwks_request, wso2_client
):
token = access_token_generator(**claims_mod)
with pytest.raises(JoseError):
wso2_client.parse_access_token(token)
45 changes: 39 additions & 6 deletions nens_auth_client/wso2.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from authlib.integrations.django_client import DjangoOAuth2App
from authlib.jose import JsonWebKey
from authlib.jose import JsonWebToken
from django.conf import settings
from django.http.response import HttpResponseRedirect
from urllib.parse import urlencode
from urllib.parse import urlparse
Expand Down Expand Up @@ -46,17 +49,47 @@ def logout_redirect(self, request, redirect_uri=None, login_after=False):
def parse_access_token(self, token, claims_options=None, leeway=120):
"""Decode and validate a WSO2 access token and return its payload.
Note: this function just errors with the token claims in the error
message (so that we can figure out how we can actually validate the
token)
Args:
token (str): access token (base64 encoded JWT)
Returns:
claims (dict): the token payload
Raises:
NotImplementedError: always
authlib.jose.errors.JoseError: if token is invalid
ValueError: if the key id is not present in the jwks.json
"""
raise NotImplementedError(decode_jwt(token))

# this is a copy from the _parse_id_token equivalent function
def load_key(header, payload):
jwk_set = self.fetch_jwk_set()
kid = header.get("kid")
try:
return JsonWebKey.import_key_set(jwk_set).find_by_kid(kid)
except ValueError:
# re-try with new jwk set
jwk_set = self.fetch_jwk_set(force=True)
return JsonWebKey.import_key_set(jwk_set).find_by_kid(kid)

metadata = self.load_server_metadata()
claims_options = {
"aud": {"essential": True, "value": settings.NENS_AUTH_RESOURCE_SERVER_ID},
"iss": {"essential": True, "value": metadata["issuer"]},
"sub": {"essential": True},
"scope": {"essential": True},
**(claims_options or {}),
}

alg_values = metadata.get("id_token_signing_alg_values_supported")
if not alg_values:
alg_values = ["RS256"]

claims = JsonWebToken(alg_values).decode(
token, key=load_key, claims_options=claims_options
)

claims.validate(leeway=leeway)
return claims

@staticmethod
def extract_provider_name(claims):
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ force_single_line = true

[tool:pytest]
DJANGO_SETTINGS_MODULE = nens_auth_client.testsettings
addopts = --cov --cache-clear --cov-report=term-missing nens_auth_client
#addopts = --cov --cache-clear --cov-report=term-missing nens_auth_client
python_files = test_*.py
junit_family = xunit1

0 comments on commit 7786915

Please sign in to comment.