From 8bf691f7b1ea6a81f5150535020c073acbbd74a4 Mon Sep 17 00:00:00 2001 From: Ryan Johnson Date: Tue, 9 Jan 2024 10:29:45 -0800 Subject: [PATCH] add PKCE to SessionRefresh middleware --- mozilla_django_oidc/middleware.py | 31 +++++++- tests/test_middleware.py | 121 ++++++++++++++++++++++++++++++ 2 files changed, 151 insertions(+), 1 deletion(-) diff --git a/mozilla_django_oidc/middleware.py b/mozilla_django_oidc/middleware.py index 1b050325..370c6785 100644 --- a/mozilla_django_oidc/middleware.py +++ b/mozilla_django_oidc/middleware.py @@ -15,6 +15,7 @@ from mozilla_django_oidc.utils import ( absolutify, add_state_and_verifier_and_nonce_to_session, + generate_code_challenge, import_from_settings, ) @@ -152,7 +153,35 @@ def process_request(self, request): nonce = get_random_string(self.OIDC_NONCE_SIZE) params.update({"nonce": nonce}) - add_state_and_verifier_and_nonce_to_session(request, state, params) + if self.get_settings("OIDC_USE_PKCE", False): + code_verifier_length = self.get_settings("OIDC_PKCE_CODE_VERIFIER_SIZE", 64) + # Check that code_verifier_length is between the min and max length + # defined in https://datatracker.ietf.org/doc/html/rfc7636#section-4.1 + if not (43 <= code_verifier_length <= 128): + raise ValueError("code_verifier_length must be between 43 and 128") + + # Generate code_verifier and code_challenge pair + code_verifier = get_random_string(code_verifier_length) + code_challenge_method = self.get_settings( + "OIDC_PKCE_CODE_CHALLENGE_METHOD", "S256" + ) + code_challenge = generate_code_challenge( + code_verifier, code_challenge_method + ) + + # Append code_challenge to authentication request parameters + params.update( + { + "code_challenge": code_challenge, + "code_challenge_method": code_challenge_method, + } + ) + else: + code_verifier = None + + add_state_and_verifier_and_nonce_to_session( + request, state, params, code_verifier + ) request.session["oidc_login_next"] = request.get_full_path() diff --git a/tests/test_middleware.py b/tests/test_middleware.py index ac1b708f..442dfa87 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -78,6 +78,36 @@ def test_is_ajax(self, mock_middleware_random): json_payload = json.loads(response.content.decode("utf-8")) self.assertEqual(json_payload["refresh_url"], response["refresh_url"]) + @override_settings(OIDC_USE_PKCE=True) + def test_is_ajax_with_pkce(self, mock_middleware_random): + mock_middleware_random.return_value = "examplestring" + + request = self.factory.get("/foo", HTTP_X_REQUESTED_WITH="XMLHttpRequest") + request.session = {} + request.user = self.user + + response = self.middleware.process_request(request) + self.assertEqual(response.status_code, 403) + # The URL to go to is available both as a header and as a key + # in the JSON response. + self.assertTrue(response["refresh_url"]) + url, qs = response["refresh_url"].split("?") + self.assertEqual(url, "http://example.com/authorize") + expected_query = { + "response_type": ["code"], + "redirect_uri": ["http://testserver/callback/"], + "client_id": ["foo"], + "nonce": ["examplestring"], + "prompt": ["none"], + "scope": ["openid email"], + "state": ["examplestring"], + "code_challenge_method": ["S256"], + "code_challenge": ["m8yog7rVNdOd7hYIoUg6yl5mk_IYauWdSIBUjoPJHB0"], + } + self.assertEqual(expected_query, parse_qs(qs)) + json_payload = json.loads(response.content.decode("utf-8")) + self.assertEqual(json_payload["refresh_url"], response["refresh_url"]) + def test_no_oidc_token_expiration_forces_renewal(self, mock_middleware_random): mock_middleware_random.return_value = "examplestring" @@ -101,6 +131,34 @@ def test_no_oidc_token_expiration_forces_renewal(self, mock_middleware_random): } self.assertEqual(expected_query, parse_qs(qs)) + @override_settings(OIDC_USE_PKCE=True) + def test_no_oidc_token_expiration_forces_renewal_with_pkce( + self, mock_middleware_random + ): + mock_middleware_random.return_value = "examplestring" + + request = self.factory.get("/foo") + request.user = self.user + request.session = {} + + response = self.middleware.process_request(request) + + self.assertEqual(response.status_code, 302) + url, qs = response.url.split("?") + self.assertEqual(url, "http://example.com/authorize") + expected_query = { + "response_type": ["code"], + "redirect_uri": ["http://testserver/callback/"], + "client_id": ["foo"], + "nonce": ["examplestring"], + "prompt": ["none"], + "scope": ["openid email"], + "state": ["examplestring"], + "code_challenge_method": ["S256"], + "code_challenge": ["m8yog7rVNdOd7hYIoUg6yl5mk_IYauWdSIBUjoPJHB0"], + } + self.assertEqual(expected_query, parse_qs(qs)) + def test_expired_token_forces_renewal(self, mock_middleware_random): mock_middleware_random.return_value = "examplestring" @@ -124,6 +182,32 @@ def test_expired_token_forces_renewal(self, mock_middleware_random): } self.assertEqual(expected_query, parse_qs(qs)) + @override_settings(OIDC_USE_PKCE=True) + def test_expired_token_forces_renewal_with_pkce(self, mock_middleware_random): + mock_middleware_random.return_value = "examplestring" + + request = self.factory.get("/foo") + request.user = self.user + request.session = {"oidc_id_token_expiration": time.time() - 10} + + response = self.middleware.process_request(request) + + self.assertEqual(response.status_code, 302) + url, qs = response.url.split("?") + self.assertEqual(url, "http://example.com/authorize") + expected_query = { + "response_type": ["code"], + "redirect_uri": ["http://testserver/callback/"], + "client_id": ["foo"], + "nonce": ["examplestring"], + "prompt": ["none"], + "scope": ["openid email"], + "state": ["examplestring"], + "code_challenge_method": ["S256"], + "code_challenge": ["m8yog7rVNdOd7hYIoUg6yl5mk_IYauWdSIBUjoPJHB0"], + } + self.assertEqual(expected_query, parse_qs(qs)) + # This adds a "home page" we can test against. def fakeview(req): @@ -306,6 +390,43 @@ def test_expired_token_redirects_to_sso(self, mock_middleware_random): } self.assertEqual(expected_query, parse_qs(qs)) + @override_settings(OIDC_OP_AUTHORIZATION_ENDPOINT="http://example.com/authorize") + @override_settings(OIDC_RP_CLIENT_ID="foo") + @override_settings(OIDC_RENEW_ID_TOKEN_EXPIRY_SECONDS=120) + @override_settings(OIDC_USE_PKCE=True) + @patch("mozilla_django_oidc.middleware.get_random_string") + def test_expired_token_redirects_to_sso_with_pkce(self, mock_middleware_random): + mock_middleware_random.return_value = "examplestring" + + client = ClientWithUser() + client.login(username=self.user.username, password="password") + + # Set expiration to some time in the past + session = client.session + session["oidc_id_token_expiration"] = time.time() - 100 + session[ + "_auth_user_backend" + ] = "mozilla_django_oidc.auth.OIDCAuthenticationBackend" + session.save() + + resp = client.get("/mdo_fake_view/") + self.assertEqual(resp.status_code, 302) + + url, qs = resp.url.split("?") + self.assertEqual(url, "http://example.com/authorize") + expected_query = { + "response_type": ["code"], + "redirect_uri": ["http://testserver/callback/"], + "client_id": ["foo"], + "nonce": ["examplestring"], + "prompt": ["none"], + "scope": ["openid email"], + "state": ["examplestring"], + "code_challenge_method": ["S256"], + "code_challenge": ["m8yog7rVNdOd7hYIoUg6yl5mk_IYauWdSIBUjoPJHB0"], + } + self.assertEqual(expected_query, parse_qs(qs)) + @override_settings(OIDC_OP_AUTHORIZATION_ENDPOINT="http://example.com/authorize") @override_settings(OIDC_RP_CLIENT_ID="foo") @override_settings(OIDC_RENEW_ID_TOKEN_EXPIRY_SECONDS=120)