diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcLogoutConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcLogoutConfigurer.java index 570821b82ba..15512cbbb3e 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcLogoutConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcLogoutConfigurer.java @@ -26,11 +26,9 @@ import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer; import org.springframework.security.oauth2.client.oidc.authentication.logout.OidcBackChannelLogoutAuthenticationProvider; import org.springframework.security.oauth2.client.oidc.web.OidcBackChannelLogoutFilter; -import org.springframework.security.oauth2.client.oidc.web.logout.OidcBackChannelLogoutHandler; import org.springframework.security.oauth2.client.oidc.web.logout.OidcLogoutAuthenticationConverter; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.web.authentication.AuthenticationConverter; -import org.springframework.security.web.authentication.logout.LogoutHandler; import org.springframework.security.web.csrf.CsrfFilter; import org.springframework.util.Assert; @@ -100,10 +98,7 @@ public final class BackChannelLogoutConfigurer { private AuthenticationConverter authenticationConverter; - private AuthenticationManager authenticationManager = new ProviderManager( - new OidcBackChannelLogoutAuthenticationProvider()); - - private LogoutHandler logoutHandler; + private AuthenticationManager authenticationManager; /** * Use this {@link AuthenticationConverter} to extract the Logout Token from the @@ -128,17 +123,6 @@ public BackChannelLogoutConfigurer authenticationManager(AuthenticationManager a return this; } - /** - * Use this {@link LogoutHandler} for invalidating each session identified by the - * OIDC Back-Channel Logout Token - * @return the {@link BackChannelLogoutConfigurer} for further configuration - */ - public BackChannelLogoutConfigurer logoutHandler(LogoutHandler logoutHandler) { - Assert.notNull(logoutHandler, "logoutHandler cannot be null"); - this.logoutHandler = logoutHandler; - return this; - } - private AuthenticationConverter authenticationConverter(B http) { if (this.authenticationConverter == null) { ClientRegistrationRepository clientRegistrationRepository = OAuth2ClientConfigurerUtils @@ -148,23 +132,18 @@ private AuthenticationConverter authenticationConverter(B http) { return this.authenticationConverter; } - private AuthenticationManager authenticationManager() { - return this.authenticationManager; - } - - private LogoutHandler logoutHandler(B http) { - if (this.logoutHandler == null) { - OidcBackChannelLogoutHandler logoutHandler = new OidcBackChannelLogoutHandler(); - logoutHandler.setSessionRegistry(OAuth2ClientConfigurerUtils.getOidcSessionRegistry(http)); - this.logoutHandler = logoutHandler; + private AuthenticationManager authenticationManager(B http) { + if (this.authenticationManager == null) { + OidcBackChannelLogoutAuthenticationProvider authenticationProvider = new OidcBackChannelLogoutAuthenticationProvider(); + authenticationProvider.setSessionRegistry(OAuth2ClientConfigurerUtils.getOidcSessionRegistry(http)); + this.authenticationManager = new ProviderManager(authenticationProvider); } - return this.logoutHandler; + return this.authenticationManager; } void configure(B http) { OidcBackChannelLogoutFilter filter = new OidcBackChannelLogoutFilter(authenticationConverter(http), - authenticationManager()); - filter.setLogoutHandler(logoutHandler(http)); + authenticationManager(http)); http.addFilterBefore(filter, CsrfFilter.class); } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcLogoutConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcLogoutConfigurerTests.java index 3fd3aed8ad6..8993fe1b529 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcLogoutConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcLogoutConfigurerTests.java @@ -52,7 +52,8 @@ import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpSession; import org.springframework.mock.web.MockServletContext; -import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.authentication.AuthenticationProvider; +import org.springframework.security.authentication.ProviderManager; import org.springframework.security.config.Customizer; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; @@ -62,14 +63,13 @@ import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.oauth2.client.oidc.authentication.logout.LogoutTokenClaimNames; -import org.springframework.security.oauth2.client.oidc.authentication.logout.OidcBackChannelLogoutAuthentication; +import org.springframework.security.oauth2.client.oidc.authentication.logout.OidcBackChannelLogoutAuthenticationProvider; import org.springframework.security.oauth2.client.oidc.authentication.logout.OidcLogoutAuthenticationToken; import org.springframework.security.oauth2.client.oidc.authentication.logout.OidcLogoutToken; import org.springframework.security.oauth2.client.oidc.authentication.logout.TestOidcLogoutTokens; import org.springframework.security.oauth2.client.oidc.session.OidcSessionInformation; import org.springframework.security.oauth2.client.oidc.session.OidcSessionRegistry; import org.springframework.security.oauth2.client.oidc.session.TestOidcSessionInformations; -import org.springframework.security.oauth2.client.oidc.web.logout.OidcBackChannelLogoutHandler; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository; @@ -78,13 +78,15 @@ import org.springframework.security.oauth2.core.oidc.TestOidcIdTokens; import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.jwt.JwtClaimsSet; +import org.springframework.security.oauth2.jwt.JwtDecoder; +import org.springframework.security.oauth2.jwt.JwtDecoderFactory; import org.springframework.security.oauth2.jwt.JwtEncoder; import org.springframework.security.oauth2.jwt.JwtEncoderParameters; import org.springframework.security.oauth2.jwt.NimbusJwtEncoder; +import org.springframework.security.oauth2.jwt.TestJwts; import org.springframework.security.provisioning.InMemoryUserDetailsManager; import org.springframework.security.web.SecurityFilterChain; import org.springframework.security.web.authentication.AuthenticationConverter; -import org.springframework.security.web.authentication.logout.LogoutHandler; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; import org.springframework.test.web.servlet.ResultActions; @@ -175,19 +177,21 @@ void logoutWhenCustomComponentsThenUses() throws Exception { String registrationId = this.clientRegistration.getRegistrationId(); AuthenticationConverter authenticationConverter = this.spring.getContext() .getBean(AuthenticationConverter.class); - given(authenticationConverter.convert(any())) - .willReturn(new OidcLogoutAuthenticationToken("token", this.clientRegistration)); - AuthenticationManager authenticationManager = this.spring.getContext().getBean(AuthenticationManager.class); + given(authenticationConverter.convert(any())).willReturn(new OidcLogoutAuthenticationToken("token", + this.clientRegistration, "http://localhost/logout/connect/back-channel/" + registrationId)); OidcLogoutToken logoutToken = TestOidcLogoutTokens.withSessionId("issuer", "provider").build(); - given(authenticationManager.authenticate(any())) - .willReturn(new OidcBackChannelLogoutAuthentication(logoutToken)); - OidcSessionRegistry sessionRegistry = this.spring.getContext().getBean(OidcSessionRegistry.class); + JwtDecoderFactory decoderFactory = this.spring.getContext() + .getBean(JwtDecoderFactory.class); + JwtDecoder decoder = mock(JwtDecoder.class); + given(decoder.decode(any())) + .willReturn(TestJwts.jwt().claims((claims) -> claims.putAll(logoutToken.getClaims())).build()); + given(decoderFactory.createDecoder(any())).willReturn(decoder); Set details = Set.of(TestOidcSessionInformations.create()); - given(sessionRegistry.removeSessionInformation(logoutToken)).willReturn(details); + OidcSessionRegistry sessionRegistry = this.spring.getContext().getBean(OidcSessionRegistry.class); + given(sessionRegistry.removeSessionInformation(any(OidcLogoutToken.class))).willReturn(details); this.mvc.perform(post("/logout/connect/back-channel/" + registrationId).param("logout_token", "token")) .andExpect(status().isOk()); - verify(authenticationManager).authenticate(any()); - verify(this.spring.getContext().getBean(LogoutHandler.class)).logout(any(), any(), any()); + verify(decoder).decode(any()); verify(sessionRegistry).removeSessionInformation(logoutToken); } @@ -241,24 +245,25 @@ static class WithCustomComponentsConfig { AuthenticationConverter authenticationConverter = mock(AuthenticationConverter.class); - AuthenticationManager authenticationManager = mock(AuthenticationManager.class); + OidcBackChannelLogoutAuthenticationProvider authenticationProvider = spy( + new OidcBackChannelLogoutAuthenticationProvider()); - OidcSessionRegistry sessionRegistry = mock(OidcSessionRegistry.class); + JwtDecoderFactory decoderFactory = mock(JwtDecoderFactory.class); - OidcBackChannelLogoutHandler logoutHandler = spy(new OidcBackChannelLogoutHandler()); + OidcSessionRegistry sessionRegistry = mock(OidcSessionRegistry.class); @Bean @Order(1) SecurityFilterChain filters(HttpSecurity http) throws Exception { - this.logoutHandler.setSessionRegistry(this.sessionRegistry); + this.authenticationProvider.setSessionRegistry(this.sessionRegistry); + this.authenticationProvider.setLogoutTokenDecoderFactory(this.decoderFactory); // @formatter:off http .authorizeHttpRequests((authorize) -> authorize.anyRequest().authenticated()) .oauth2Login((oauth2) -> oauth2.oidcSessionRegistry(this.sessionRegistry)) .oidcLogout((oidc) -> oidc.backChannel((logout) -> logout .authenticationConverter(this.authenticationConverter) - .authenticationManager(this.authenticationManager) - .logoutHandler(this.logoutHandler) + .authenticationManager(new ProviderManager(this.authenticationProvider)) )); // @formatter:on @@ -271,8 +276,8 @@ AuthenticationConverter authenticationConverter() { } @Bean - AuthenticationManager authenticationManager() { - return this.authenticationManager; + AuthenticationProvider authenticationProvider() { + return this.authenticationProvider; } @Bean @@ -281,8 +286,8 @@ OidcSessionRegistry sessionRegistry() { } @Bean - LogoutHandler logoutHandler() { - return this.logoutHandler; + JwtDecoderFactory jwtDecoderFactory() { + return this.decoderFactory; } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcBackChannelLogoutAuthentication.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcBackChannelLogoutAuthentication.java index 0f12ad3f06e..a475897a069 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcBackChannelLogoutAuthentication.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcBackChannelLogoutAuthentication.java @@ -19,6 +19,7 @@ import java.util.Collections; import org.springframework.security.authentication.AbstractAuthenticationToken; +import org.springframework.security.oauth2.client.oidc.session.OidcSessionInformation; /** * An {@link org.springframework.security.core.Authentication} implementation that @@ -36,13 +37,17 @@ public class OidcBackChannelLogoutAuthentication extends AbstractAuthenticationT private final OidcLogoutToken logoutToken; + private final Iterable invalidated; + /** * Construct an {@link OidcBackChannelLogoutAuthentication} * @param logoutToken a deserialized, verified OIDC Logout Token */ - public OidcBackChannelLogoutAuthentication(OidcLogoutToken logoutToken) { + public OidcBackChannelLogoutAuthentication(OidcLogoutToken logoutToken, + Iterable invalidated) { super(Collections.emptyList()); this.logoutToken = logoutToken; + this.invalidated = invalidated; setAuthenticated(true); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcBackChannelLogoutAuthenticationProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcBackChannelLogoutAuthenticationProvider.java index 8c417cc1143..089722bd44f 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcBackChannelLogoutAuthenticationProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcBackChannelLogoutAuthenticationProvider.java @@ -16,11 +16,23 @@ package org.springframework.security.oauth2.client.oidc.authentication.logout; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Map; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.AuthenticationServiceException; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.security.oauth2.client.oidc.authentication.OidcIdTokenDecoderFactory; +import org.springframework.security.oauth2.client.oidc.session.InMemoryOidcSessionRegistry; +import org.springframework.security.oauth2.client.oidc.session.OidcSessionInformation; +import org.springframework.security.oauth2.client.oidc.session.OidcSessionRegistry; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; @@ -30,6 +42,10 @@ import org.springframework.security.oauth2.jwt.JwtDecoder; import org.springframework.security.oauth2.jwt.JwtDecoderFactory; import org.springframework.util.Assert; +import org.springframework.web.client.RestClientException; +import org.springframework.web.client.RestOperations; +import org.springframework.web.client.RestTemplate; +import org.springframework.web.util.UriComponentsBuilder; /** * An {@link AuthenticationProvider} that authenticates an OIDC Logout Token; namely @@ -49,8 +65,18 @@ */ public final class OidcBackChannelLogoutAuthenticationProvider implements AuthenticationProvider { + private final Log logger = LogFactory.getLog(getClass()); + private JwtDecoderFactory logoutTokenDecoderFactory; + private OidcSessionRegistry sessionRegistry = new InMemoryOidcSessionRegistry(); + + private RestOperations restOperations = new RestTemplate(); + + private String logoutEndpointName = "/logout"; + + private String sessionCookieName = "JSESSIONID"; + /** * Construct an {@link OidcBackChannelLogoutAuthenticationProvider} */ @@ -73,7 +99,8 @@ public Authentication authenticate(Authentication authentication) throws Authent Jwt jwt = decode(registration, logoutToken); OidcLogoutToken oidcLogoutToken = OidcLogoutToken.withTokenValue(logoutToken) .claims((claims) -> claims.putAll(jwt.getClaims())).build(); - return new OidcBackChannelLogoutAuthentication(oidcLogoutToken); + Collection loggedOut = logout(token.getBaseUrl(), oidcLogoutToken); + return new OidcBackChannelLogoutAuthentication(oidcLogoutToken, loggedOut); } /** @@ -99,6 +126,40 @@ private Jwt decode(ClientRegistration registration, String token) { } } + private Collection logout(String baseUrl, OidcLogoutToken token) { + Iterable sessions = this.sessionRegistry.removeSessionInformation(token); + Collection invalidated = new ArrayList<>(); + int totalCount = 0; + int invalidatedCount = 0; + for (OidcSessionInformation session : sessions) { + totalCount++; + try { + eachLogout(baseUrl, session); + invalidated.add(session); + invalidatedCount++; + } + catch (RestClientException ex) { + this.logger.debug("Failed to invalidate session", ex); + } + } + if (this.logger.isTraceEnabled()) { + this.logger.trace(String.format("Invalidated %d out of %d sessions", invalidatedCount, totalCount)); + } + return invalidated; + } + + private void eachLogout(String baseUrl, OidcSessionInformation session) { + HttpHeaders headers = new HttpHeaders(); + headers.add(HttpHeaders.COOKIE, this.sessionCookieName + "=" + session.getSessionId()); + for (Map.Entry credential : session.getAuthorities().entrySet()) { + headers.add(credential.getKey(), credential.getValue()); + } + String logout = UriComponentsBuilder.fromHttpUrl(baseUrl).replacePath(this.logoutEndpointName).build() + .toUriString(); + HttpEntity entity = new HttpEntity<>(null, headers); + this.restOperations.postForEntity(logout, entity, Object.class); + } + /** * Use this {@link JwtDecoderFactory} to generate {@link JwtDecoder}s that correspond * to the {@link ClientRegistration} associated with the OIDC logout token. @@ -109,4 +170,8 @@ public void setLogoutTokenDecoderFactory(JwtDecoderFactory l this.logoutTokenDecoderFactory = logoutTokenDecoderFactory; } + public void setSessionRegistry(OidcSessionRegistry sessionRegistry) { + this.sessionRegistry = sessionRegistry; + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcLogoutAuthenticationToken.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcLogoutAuthenticationToken.java index 8912dc52df1..ef8d9ed287c 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcLogoutAuthenticationToken.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcLogoutAuthenticationToken.java @@ -33,16 +33,19 @@ public class OidcLogoutAuthenticationToken extends AbstractAuthenticationToken { private final ClientRegistration clientRegistration; + private final String baseUrl; + /** * Construct an {@link OidcLogoutAuthenticationToken} * @param logoutToken a signed, serialized OIDC Logout token * @param clientRegistration the {@link ClientRegistration client} associated with * this token; this is usually derived from material in the logout HTTP request */ - public OidcLogoutAuthenticationToken(String logoutToken, ClientRegistration clientRegistration) { + public OidcLogoutAuthenticationToken(String logoutToken, ClientRegistration clientRegistration, String baseUrl) { super(AuthorityUtils.NO_AUTHORITIES); this.logoutToken = logoutToken; this.clientRegistration = clientRegistration; + this.baseUrl = baseUrl; } /** @@ -77,4 +80,8 @@ public ClientRegistration getClientRegistration() { return this.clientRegistration; } + public String getBaseUrl() { + return this.baseUrl; + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/web/OidcBackChannelLogoutFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/web/OidcBackChannelLogoutFilter.java index 9179329f79e..ab356f9f3b5 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/web/OidcBackChannelLogoutFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/web/OidcBackChannelLogoutFilter.java @@ -30,13 +30,11 @@ import org.springframework.security.authentication.AuthenticationServiceException; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; -import org.springframework.security.oauth2.client.oidc.web.logout.OidcBackChannelLogoutHandler; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter; import org.springframework.security.web.authentication.AuthenticationConverter; -import org.springframework.security.web.authentication.logout.LogoutHandler; import org.springframework.util.Assert; import org.springframework.web.filter.OncePerRequestFilter; @@ -59,8 +57,6 @@ public class OidcBackChannelLogoutFilter extends OncePerRequestFilter { private final OAuth2ErrorHttpMessageConverter errorHttpMessageConverter = new OAuth2ErrorHttpMessageConverter(); - private LogoutHandler logoutHandler = new OidcBackChannelLogoutHandler(); - /** * Construct an {@link OidcBackChannelLogoutFilter} * @param authenticationConverter the {@link AuthenticationConverter} for deriving @@ -98,9 +94,8 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse chain.doFilter(request, response); return; } - Authentication authentication; try { - authentication = this.authenticationManager.authenticate(token); + this.authenticationManager.authenticate(token); } catch (AuthenticationServiceException ex) { this.logger.debug("Failed to process OIDC Back-Channel Logout", ex); @@ -108,9 +103,7 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse } catch (AuthenticationException ex) { handleAuthenticationFailure(response, ex); - return; } - this.logoutHandler.logout(request, response, authentication); } private void handleAuthenticationFailure(HttpServletResponse response, AuthenticationException ex) @@ -128,14 +121,4 @@ private OAuth2Error oauth2Error(AuthenticationException ex) { "https://openid.net/specs/openid-connect-backchannel-1_0.html#Validation"); } - /** - * The strategy for expiring all Client sessions indicated by the logout request. - * Defaults to {@link OidcBackChannelLogoutHandler}. - * @param logoutHandler the {@link LogoutHandler} to use - */ - public void setLogoutHandler(LogoutHandler logoutHandler) { - Assert.notNull(logoutHandler, "logoutHandler cannot be null"); - this.logoutHandler = logoutHandler; - } - } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/web/logout/OidcBackChannelLogoutHandler.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/web/logout/OidcBackChannelLogoutHandler.java deleted file mode 100644 index 0263ae649f9..00000000000 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/web/logout/OidcBackChannelLogoutHandler.java +++ /dev/null @@ -1,147 +0,0 @@ -/* - * Copyright 2002-2023 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.oauth2.client.oidc.web.logout; - -import java.util.Map; - -import jakarta.servlet.http.HttpServletRequest; -import jakarta.servlet.http.HttpServletResponse; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; - -import org.springframework.http.HttpEntity; -import org.springframework.http.HttpHeaders; -import org.springframework.security.core.Authentication; -import org.springframework.security.oauth2.client.oidc.authentication.logout.OidcBackChannelLogoutAuthentication; -import org.springframework.security.oauth2.client.oidc.authentication.logout.OidcLogoutToken; -import org.springframework.security.oauth2.client.oidc.session.InMemoryOidcSessionRegistry; -import org.springframework.security.oauth2.client.oidc.session.OidcSessionInformation; -import org.springframework.security.oauth2.client.oidc.session.OidcSessionRegistry; -import org.springframework.security.web.authentication.logout.LogoutHandler; -import org.springframework.util.Assert; -import org.springframework.web.client.RestClientException; -import org.springframework.web.client.RestOperations; -import org.springframework.web.client.RestTemplate; -import org.springframework.web.util.UriComponentsBuilder; - -/** - * A {@link LogoutHandler} that locates the sessions associated with a given OIDC - * Back-Channel Logout Token and invalidates each one. - * - * @author Josh Cummings - * @since 6.2 - * @see OIDC Back-Channel Logout - * Spec - */ -public final class OidcBackChannelLogoutHandler implements LogoutHandler { - - private final Log logger = LogFactory.getLog(getClass()); - - private OidcSessionRegistry sessionRegistry = new InMemoryOidcSessionRegistry(); - - private RestOperations restOperations = new RestTemplate(); - - private String logoutEndpointName = "/logout"; - - private String sessionCookieName = "JSESSIONID"; - - @Override - public void logout(HttpServletRequest request, HttpServletResponse response, Authentication authentication) { - if (!(authentication instanceof OidcBackChannelLogoutAuthentication token)) { - if (this.logger.isDebugEnabled()) { - String message = "Did not perform OIDC Back-Channel Logout since authentication [%s] was of the wrong type"; - this.logger.debug(String.format(message, authentication.getClass().getSimpleName())); - } - return; - } - Iterable sessions = this.sessionRegistry.removeSessionInformation(token.getPrincipal()); - int totalCount = 0; - int invalidatedCount = 0; - for (OidcSessionInformation session : sessions) { - totalCount++; - try { - eachLogout(request, session); - invalidatedCount++; - } - catch (RestClientException ex) { - this.logger.debug("Failed to invalidate session", ex); - } - } - if (this.logger.isTraceEnabled()) { - this.logger.trace(String.format("Invalidated %d out of %d sessions", invalidatedCount, totalCount)); - } - } - - private void eachLogout(HttpServletRequest request, OidcSessionInformation session) { - HttpHeaders headers = new HttpHeaders(); - headers.add(HttpHeaders.COOKIE, this.sessionCookieName + "=" + session.getSessionId()); - for (Map.Entry credential : session.getAuthorities().entrySet()) { - headers.add(credential.getKey(), credential.getValue()); - } - String url = request.getRequestURL().toString(); - String logout = UriComponentsBuilder.fromHttpUrl(url).replacePath(this.logoutEndpointName).build() - .toUriString(); - HttpEntity entity = new HttpEntity<>(null, headers); - this.restOperations.postForEntity(logout, entity, Object.class); - } - - /** - * Use this {@link OidcSessionRegistry} to identify sessions to invalidate. Note that - * this class uses - * {@link OidcSessionRegistry#removeSessionInformation(OidcLogoutToken)} to identify - * sessions. - * @param sessionRegistry the {@link OidcSessionRegistry} to use - */ - public void setSessionRegistry(OidcSessionRegistry sessionRegistry) { - Assert.notNull(sessionRegistry, "sessionRegistry cannot be null"); - this.sessionRegistry = sessionRegistry; - } - - /** - * Use this {@link RestOperations} to perform the per-session back-channel logout - * @param restOperations the {@link RestOperations} to use - */ - public void setRestOperations(RestOperations restOperations) { - Assert.notNull(restOperations, "restOperations cannot be null"); - this.restOperations = restOperations; - } - - /** - * Use this logout URI for performing per-session logout. Defaults to {@code /logout} - * since that is the default URI for - * {@link org.springframework.security.web.authentication.logout.LogoutFilter}. - * @param logoutUri the URI to use - */ - public void setLogoutUri(String logoutUri) { - Assert.hasText(logoutUri, "logoutUri cannot be empty"); - this.logoutEndpointName = logoutUri; - } - - /** - * Use this cookie name for the session identifier. Defaults to {@code JSESSIONID}. - * - *

- * Note that if you are using Spring Session, this likely needs to change to SESSION. - * @param sessionCookieName the cookie name to use - */ - public void setSessionCookieName(String sessionCookieName) { - Assert.hasText(sessionCookieName, "clientSessionCookieName cannot be empty"); - this.sessionCookieName = sessionCookieName; - } - -} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/web/logout/OidcLogoutAuthenticationConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/web/logout/OidcLogoutAuthenticationConverter.java index bf21f075e47..47727ffa2b2 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/web/logout/OidcLogoutAuthenticationConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/web/logout/OidcLogoutAuthenticationConverter.java @@ -70,7 +70,7 @@ public Authentication convert(HttpServletRequest request) { this.logger.debug("Failed to process OIDC Back-Channel Logout since no logout token was found"); throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST); } - return new OidcLogoutAuthenticationToken(logoutToken, clientRegistration); + return new OidcLogoutAuthenticationToken(logoutToken, clientRegistration, request.getRequestURL().toString()); } /** diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/web/OidcBackChannelLogoutFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/web/OidcBackChannelLogoutFilterTests.java index 698d92c7480..be9726caf97 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/web/OidcBackChannelLogoutFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/web/OidcBackChannelLogoutFilterTests.java @@ -16,27 +16,19 @@ package org.springframework.security.oauth2.client.oidc.web; -import java.util.Set; - import jakarta.servlet.FilterChain; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.BadCredentialsException; -import org.springframework.security.oauth2.client.oidc.authentication.logout.OidcBackChannelLogoutAuthentication; -import org.springframework.security.oauth2.client.oidc.authentication.logout.OidcLogoutToken; -import org.springframework.security.oauth2.client.oidc.authentication.logout.TestOidcLogoutTokens; -import org.springframework.security.oauth2.client.oidc.session.OidcSessionInformation; -import org.springframework.security.oauth2.client.oidc.session.OidcSessionRegistry; -import org.springframework.security.oauth2.client.oidc.session.TestOidcSessionInformations; -import org.springframework.security.oauth2.client.oidc.web.logout.OidcBackChannelLogoutHandler; +import org.springframework.security.oauth2.client.oidc.authentication.logout.OidcLogoutAuthenticationToken; import org.springframework.security.oauth2.client.oidc.web.logout.OidcLogoutAuthenticationConverter; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; -import org.springframework.security.web.authentication.logout.LogoutHandler; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; @@ -101,17 +93,8 @@ public void doFilterWithSessionMatchingLogoutTokenThenInvalidates() throws Excep ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class); given(clientRegistrationRepository.findByRegistrationId(any())).willReturn(clientRegistration); AuthenticationManager authenticationManager = mock(AuthenticationManager.class); - OidcLogoutToken token = TestOidcLogoutTokens.withSessionId("issuer", "provider").build(); - Iterable infos = Set.of(TestOidcSessionInformations.create("clientOne"), - TestOidcSessionInformations.create("clientTwo")); - given(authenticationManager.authenticate(any())).willReturn(new OidcBackChannelLogoutAuthentication(token)); - OidcBackChannelLogoutHandler backChannelLogoutHandler = new OidcBackChannelLogoutHandler(); - OidcSessionRegistry sessionRegistry = mock(OidcSessionRegistry.class); - given(sessionRegistry.removeSessionInformation(any(OidcLogoutToken.class))).willReturn(infos); - backChannelLogoutHandler.setSessionRegistry(sessionRegistry); OidcBackChannelLogoutFilter filter = new OidcBackChannelLogoutFilter( new OidcLogoutAuthenticationConverter(clientRegistrationRepository), authenticationManager); - filter.setLogoutHandler(backChannelLogoutHandler); MockHttpServletRequest request = new MockHttpServletRequest("POST", "/oauth2/" + clientRegistration.getRegistrationId() + "/logout"); request.setServletPath("/logout/connect/back-channel/id"); @@ -119,7 +102,10 @@ public void doFilterWithSessionMatchingLogoutTokenThenInvalidates() throws Excep MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain chain = mock(FilterChain.class); filter.doFilter(request, response, chain); - verify(sessionRegistry).removeSessionInformation(token); + ArgumentCaptor authentication = ArgumentCaptor + .forClass(OidcLogoutAuthenticationToken.class); + verify(authenticationManager).authenticate(authentication.capture()); + assertThat(authentication.getValue().getLogoutToken()).isEqualTo("logout_token"); verifyNoInteractions(chain); assertThat(response.getStatus()).isEqualTo(200); } @@ -131,10 +117,8 @@ public void doFilterWhenInvalidJwtThenBadRequest() throws Exception { given(clientRegistrationRepository.findByRegistrationId(any())).willReturn(clientRegistration); AuthenticationManager authenticationManager = mock(AuthenticationManager.class); given(authenticationManager.authenticate(any())).willThrow(new BadCredentialsException("bad")); - LogoutHandler logoutHandler = mock(LogoutHandler.class); OidcBackChannelLogoutFilter backChannelLogoutFilter = new OidcBackChannelLogoutFilter( new OidcLogoutAuthenticationConverter(clientRegistrationRepository), authenticationManager); - backChannelLogoutFilter.setLogoutHandler(logoutHandler); MockHttpServletRequest request = new MockHttpServletRequest("POST", "/oauth2/" + clientRegistration.getRegistrationId() + "/logout"); request.setServletPath("/logout/connect/back-channel/id"); @@ -142,7 +126,7 @@ public void doFilterWhenInvalidJwtThenBadRequest() throws Exception { MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain chain = mock(FilterChain.class); backChannelLogoutFilter.doFilter(request, response, chain); - verifyNoInteractions(logoutHandler, chain); + verifyNoInteractions(chain); assertThat(response.getStatus()).isEqualTo(400); assertThat(response.getContentAsString()).contains("bad"); }