Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add PKCE to SessionRefresh middleware #515

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion mozilla_django_oidc/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from mozilla_django_oidc.utils import (
absolutify,
add_state_and_verifier_and_nonce_to_session,
generate_code_challenge,
import_from_settings,
)

Expand Down Expand Up @@ -152,7 +153,35 @@ def process_request(self, request):
nonce = get_random_string(self.OIDC_NONCE_SIZE)
params.update({"nonce": nonce})

add_state_and_verifier_and_nonce_to_session(request, state, params)
if self.get_settings("OIDC_USE_PKCE", False):
code_verifier_length = self.get_settings("OIDC_PKCE_CODE_VERIFIER_SIZE", 64)
# Check that code_verifier_length is between the min and max length
# defined in https://datatracker.ietf.org/doc/html/rfc7636#section-4.1
if not (43 <= code_verifier_length <= 128):
raise ValueError("code_verifier_length must be between 43 and 128")

# Generate code_verifier and code_challenge pair
code_verifier = get_random_string(code_verifier_length)
code_challenge_method = self.get_settings(
"OIDC_PKCE_CODE_CHALLENGE_METHOD", "S256"
)
code_challenge = generate_code_challenge(
code_verifier, code_challenge_method
)

# Append code_challenge to authentication request parameters
params.update(
{
"code_challenge": code_challenge,
"code_challenge_method": code_challenge_method,
}
)
else:
code_verifier = None

add_state_and_verifier_and_nonce_to_session(
request, state, params, code_verifier
)

request.session["oidc_login_next"] = request.get_full_path()

Expand Down
121 changes: 121 additions & 0 deletions tests/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,36 @@ def test_is_ajax(self, mock_middleware_random):
json_payload = json.loads(response.content.decode("utf-8"))
self.assertEqual(json_payload["refresh_url"], response["refresh_url"])

@override_settings(OIDC_USE_PKCE=True)
def test_is_ajax_with_pkce(self, mock_middleware_random):
mock_middleware_random.return_value = "examplestring"

request = self.factory.get("/foo", HTTP_X_REQUESTED_WITH="XMLHttpRequest")
request.session = {}
request.user = self.user

response = self.middleware.process_request(request)
self.assertEqual(response.status_code, 403)
# The URL to go to is available both as a header and as a key
# in the JSON response.
self.assertTrue(response["refresh_url"])
url, qs = response["refresh_url"].split("?")
self.assertEqual(url, "http://example.com/authorize")
expected_query = {
"response_type": ["code"],
"redirect_uri": ["http://testserver/callback/"],
"client_id": ["foo"],
"nonce": ["examplestring"],
"prompt": ["none"],
"scope": ["openid email"],
"state": ["examplestring"],
"code_challenge_method": ["S256"],
"code_challenge": ["m8yog7rVNdOd7hYIoUg6yl5mk_IYauWdSIBUjoPJHB0"],
}
self.assertEqual(expected_query, parse_qs(qs))
json_payload = json.loads(response.content.decode("utf-8"))
self.assertEqual(json_payload["refresh_url"], response["refresh_url"])

def test_no_oidc_token_expiration_forces_renewal(self, mock_middleware_random):
mock_middleware_random.return_value = "examplestring"

Expand All @@ -101,6 +131,34 @@ def test_no_oidc_token_expiration_forces_renewal(self, mock_middleware_random):
}
self.assertEqual(expected_query, parse_qs(qs))

@override_settings(OIDC_USE_PKCE=True)
def test_no_oidc_token_expiration_forces_renewal_with_pkce(
self, mock_middleware_random
):
mock_middleware_random.return_value = "examplestring"

request = self.factory.get("/foo")
request.user = self.user
request.session = {}

response = self.middleware.process_request(request)

self.assertEqual(response.status_code, 302)
url, qs = response.url.split("?")
self.assertEqual(url, "http://example.com/authorize")
expected_query = {
"response_type": ["code"],
"redirect_uri": ["http://testserver/callback/"],
"client_id": ["foo"],
"nonce": ["examplestring"],
"prompt": ["none"],
"scope": ["openid email"],
"state": ["examplestring"],
"code_challenge_method": ["S256"],
"code_challenge": ["m8yog7rVNdOd7hYIoUg6yl5mk_IYauWdSIBUjoPJHB0"],
}
self.assertEqual(expected_query, parse_qs(qs))

def test_expired_token_forces_renewal(self, mock_middleware_random):
mock_middleware_random.return_value = "examplestring"

Expand All @@ -124,6 +182,32 @@ def test_expired_token_forces_renewal(self, mock_middleware_random):
}
self.assertEqual(expected_query, parse_qs(qs))

@override_settings(OIDC_USE_PKCE=True)
def test_expired_token_forces_renewal_with_pkce(self, mock_middleware_random):
mock_middleware_random.return_value = "examplestring"

request = self.factory.get("/foo")
request.user = self.user
request.session = {"oidc_id_token_expiration": time.time() - 10}

response = self.middleware.process_request(request)

self.assertEqual(response.status_code, 302)
url, qs = response.url.split("?")
self.assertEqual(url, "http://example.com/authorize")
expected_query = {
"response_type": ["code"],
"redirect_uri": ["http://testserver/callback/"],
"client_id": ["foo"],
"nonce": ["examplestring"],
"prompt": ["none"],
"scope": ["openid email"],
"state": ["examplestring"],
"code_challenge_method": ["S256"],
"code_challenge": ["m8yog7rVNdOd7hYIoUg6yl5mk_IYauWdSIBUjoPJHB0"],
}
self.assertEqual(expected_query, parse_qs(qs))


# This adds a "home page" we can test against.
def fakeview(req):
Expand Down Expand Up @@ -306,6 +390,43 @@ def test_expired_token_redirects_to_sso(self, mock_middleware_random):
}
self.assertEqual(expected_query, parse_qs(qs))

@override_settings(OIDC_OP_AUTHORIZATION_ENDPOINT="http://example.com/authorize")
@override_settings(OIDC_RP_CLIENT_ID="foo")
@override_settings(OIDC_RENEW_ID_TOKEN_EXPIRY_SECONDS=120)
@override_settings(OIDC_USE_PKCE=True)
@patch("mozilla_django_oidc.middleware.get_random_string")
def test_expired_token_redirects_to_sso_with_pkce(self, mock_middleware_random):
mock_middleware_random.return_value = "examplestring"

client = ClientWithUser()
client.login(username=self.user.username, password="password")

# Set expiration to some time in the past
session = client.session
session["oidc_id_token_expiration"] = time.time() - 100
session[
"_auth_user_backend"
] = "mozilla_django_oidc.auth.OIDCAuthenticationBackend"
session.save()

resp = client.get("/mdo_fake_view/")
self.assertEqual(resp.status_code, 302)

url, qs = resp.url.split("?")
self.assertEqual(url, "http://example.com/authorize")
expected_query = {
"response_type": ["code"],
"redirect_uri": ["http://testserver/callback/"],
"client_id": ["foo"],
"nonce": ["examplestring"],
"prompt": ["none"],
"scope": ["openid email"],
"state": ["examplestring"],
"code_challenge_method": ["S256"],
"code_challenge": ["m8yog7rVNdOd7hYIoUg6yl5mk_IYauWdSIBUjoPJHB0"],
}
self.assertEqual(expected_query, parse_qs(qs))

@override_settings(OIDC_OP_AUTHORIZATION_ENDPOINT="http://example.com/authorize")
@override_settings(OIDC_RP_CLIENT_ID="foo")
@override_settings(OIDC_RENEW_ID_TOKEN_EXPIRY_SECONDS=120)
Expand Down
Loading