From 77869154e4674199bbb176b00c8e7a6bead7b613 Mon Sep 17 00:00:00 2001 From: Casper van der Wel Date: Wed, 20 Mar 2024 09:24:54 +0100 Subject: [PATCH] Add support for access tokens in WSO2 client --- nens_auth_client/tests/conftest.py | 1 + nens_auth_client/tests/test_wso2.py | 35 ++++++++++++++++++---- nens_auth_client/wso2.py | 45 +++++++++++++++++++++++++---- setup.cfg | 2 +- 4 files changed, 71 insertions(+), 12 deletions(-) diff --git a/nens_auth_client/tests/conftest.py b/nens_auth_client/tests/conftest.py index 82455fe..0d4208a 100644 --- a/nens_auth_client/tests/conftest.py +++ b/nens_auth_client/tests/conftest.py @@ -145,6 +145,7 @@ def access_token_generator(token_generator, access_token_template): def func(**extra_claims): claims = {**access_token_template, **extra_claims} + claims = {k: v for (k, v) in claims.items() if v is not None} return token_generator(**claims) return func diff --git a/nens_auth_client/tests/test_wso2.py b/nens_auth_client/tests/test_wso2.py index 9f95333..c91a0c6 100644 --- a/nens_auth_client/tests/test_wso2.py +++ b/nens_auth_client/tests/test_wso2.py @@ -1,3 +1,6 @@ +from authlib.jose.errors import JoseError +from authlib.oidc.discovery import get_well_known_url +from django.conf import settings from nens_auth_client.wso2 import WSO2AuthClient import pytest @@ -19,9 +22,31 @@ def test_extract_username(claims, expected): assert WSO2AuthClient.extract_username(claims) == expected -def test_parse_access_token_includes_claims(access_token_generator): - with pytest.raises(NotImplementedError) as e: - WSO2AuthClient.parse_access_token(None, access_token_generator()) +@pytest.fixture +def wso2_client(): + return WSO2AuthClient( + "foo", + server_metadata_url=get_well_known_url( + settings.NENS_AUTH_ISSUER, external=True + ), + ) - # error is raised with claims as arg - assert e.value.args[0]["client_id"] == "1234" + +def test_parse_access_token_wso2(access_token_generator, jwks_request, wso2_client): + # disable 'token_use' (not included in WSO2 access token) + claims = wso2_client.parse_access_token( + access_token_generator(email="test@wso2", token_use=None) + ) + + assert claims["email"] == "test@wso2" + + +@pytest.mark.parametrize( + "claims_mod", [{"aud": "abc123"}, {"sub": None}, {"iss": "abc123"}, {"exp": 0}] +) +def test_parse_access_token_wso2_invalid_claims( + claims_mod, access_token_generator, jwks_request, wso2_client +): + token = access_token_generator(**claims_mod) + with pytest.raises(JoseError): + wso2_client.parse_access_token(token) diff --git a/nens_auth_client/wso2.py b/nens_auth_client/wso2.py index 48ffe72..597b4b7 100644 --- a/nens_auth_client/wso2.py +++ b/nens_auth_client/wso2.py @@ -1,4 +1,7 @@ from authlib.integrations.django_client import DjangoOAuth2App +from authlib.jose import JsonWebKey +from authlib.jose import JsonWebToken +from django.conf import settings from django.http.response import HttpResponseRedirect from urllib.parse import urlencode from urllib.parse import urlparse @@ -46,17 +49,47 @@ def logout_redirect(self, request, redirect_uri=None, login_after=False): def parse_access_token(self, token, claims_options=None, leeway=120): """Decode and validate a WSO2 access token and return its payload. - Note: this function just errors with the token claims in the error - message (so that we can figure out how we can actually validate the - token) - Args: token (str): access token (base64 encoded JWT) + Returns: + claims (dict): the token payload + Raises: - NotImplementedError: always + authlib.jose.errors.JoseError: if token is invalid + ValueError: if the key id is not present in the jwks.json """ - raise NotImplementedError(decode_jwt(token)) + + # this is a copy from the _parse_id_token equivalent function + def load_key(header, payload): + jwk_set = self.fetch_jwk_set() + kid = header.get("kid") + try: + return JsonWebKey.import_key_set(jwk_set).find_by_kid(kid) + except ValueError: + # re-try with new jwk set + jwk_set = self.fetch_jwk_set(force=True) + return JsonWebKey.import_key_set(jwk_set).find_by_kid(kid) + + metadata = self.load_server_metadata() + claims_options = { + "aud": {"essential": True, "value": settings.NENS_AUTH_RESOURCE_SERVER_ID}, + "iss": {"essential": True, "value": metadata["issuer"]}, + "sub": {"essential": True}, + "scope": {"essential": True}, + **(claims_options or {}), + } + + alg_values = metadata.get("id_token_signing_alg_values_supported") + if not alg_values: + alg_values = ["RS256"] + + claims = JsonWebToken(alg_values).decode( + token, key=load_key, claims_options=claims_options + ) + + claims.validate(leeway=leeway) + return claims @staticmethod def extract_provider_name(claims): diff --git a/setup.cfg b/setup.cfg index 9830613..efa77a3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -13,6 +13,6 @@ force_single_line = true [tool:pytest] DJANGO_SETTINGS_MODULE = nens_auth_client.testsettings -addopts = --cov --cache-clear --cov-report=term-missing nens_auth_client +#addopts = --cov --cache-clear --cov-report=term-missing nens_auth_client python_files = test_*.py junit_family = xunit1