diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcLogoutConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcLogoutConfigurer.java index ed6d9dbfa83..39c50265f08 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcLogoutConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcLogoutConfigurer.java @@ -32,6 +32,8 @@ import org.springframework.context.event.GenericApplicationListenerAdapter; import org.springframework.context.event.SmartApplicationListener; import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.authentication.ProviderManager; +import org.springframework.security.config.Customizer; import org.springframework.security.config.annotation.web.HttpSecurityBuilder; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer; @@ -41,16 +43,14 @@ import org.springframework.security.core.session.AbstractSessionEvent; import org.springframework.security.core.session.SessionDestroyedEvent; import org.springframework.security.core.session.SessionIdChangedEvent; -import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; -import org.springframework.security.oauth2.client.oidc.authentication.logout.OidcBackChannelLogoutAuthenticationManager; -import org.springframework.security.oauth2.client.oidc.authentication.session.InMemoryOidcSessionRegistry; -import org.springframework.security.oauth2.client.oidc.authentication.session.OidcSessionRegistration; -import org.springframework.security.oauth2.client.oidc.authentication.session.OidcSessionRegistry; -import org.springframework.security.oauth2.client.oidc.web.authentication.logout.OidcBackChannelLogoutFilter; -import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.oidc.authentication.logout.OidcBackChannelLogoutAuthenticationProvider; +import org.springframework.security.oauth2.client.oidc.session.InMemoryOidcSessionRegistry; +import org.springframework.security.oauth2.client.oidc.session.OidcSessionInformation; +import org.springframework.security.oauth2.client.oidc.session.OidcSessionRegistry; +import org.springframework.security.oauth2.client.oidc.web.OidcBackChannelLogoutFilter; +import org.springframework.security.oauth2.client.oidc.web.logout.OidcBackChannelLogoutHandler; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.core.oidc.user.OidcUser; -import org.springframework.security.web.authentication.logout.BackchannelLogoutHandler; import org.springframework.security.web.authentication.logout.LogoutHandler; import org.springframework.security.web.authentication.session.SessionAuthenticationException; import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy; @@ -59,11 +59,11 @@ import org.springframework.util.Assert; /** - * An {@link AbstractHttpConfigurer} for OAuth 2.0 Logout flows + * An {@link AbstractHttpConfigurer} for OIDC Logout flows * *

- * OAuth 2.0 Logout provides an application with the capability to have users log out by - * using their existing account at an OAuth 2.0 or OpenID Connect 1.0 Provider. + * OIDC Logout provides an application with the capability to have users log out by using + * their existing account at an OAuth 2.0 or OpenID Connect 1.0 Provider. * * *

Security Filters

@@ -83,7 +83,7 @@ * * * @author Josh Cummings - * @since 6.1 + * @since 6.2 * @see HttpSecurity#oidcLogout() * @see OidcBackChannelLogoutFilter * @see ClientRegistrationRepository @@ -97,11 +97,11 @@ public final class OidcLogoutConfigurer> * Configure OIDC Back-Channel Logout using the provided {@link Consumer} * @return the {@link OidcLogoutConfigurer} for further configuration */ - public OidcLogoutConfigurer backChannel(Consumer backChannelLogoutConfigurer) { + public OidcLogoutConfigurer backChannel(Customizer backChannelLogoutConfigurer) { if (this.backChannel == null) { this.backChannel = new BackChannelLogoutConfigurer(); } - backChannelLogoutConfigurer.accept(this.backChannel); + backChannelLogoutConfigurer.customize(this.backChannel); return this; } @@ -139,26 +139,46 @@ private T getBeanOrNull(Class type) { } } + /** + * A configurer for configuring OIDC Back-Channel Logout + */ public final class BackChannelLogoutConfigurer { - private LogoutHandler logoutHandler = new BackchannelLogoutHandler(); - - private AuthenticationManager authenticationManager; + private AuthenticationManager authenticationManager = new ProviderManager( + new OidcBackChannelLogoutAuthenticationProvider()); private OidcSessionRegistry sessionRegistry = new InMemoryOidcSessionRegistry(); + private LogoutHandler logoutHandler; + + /** + * Use this {@link AuthenticationManager} to authenticate the OIDC Logout Token + * @param authenticationManager the {@link AuthenticationManager} to use + * @return the {@link BackChannelLogoutConfigurer} for further configuration + */ public BackChannelLogoutConfigurer authenticationManager(AuthenticationManager authenticationManager) { Assert.notNull(authenticationManager, "authenticationManager cannot be null"); this.authenticationManager = authenticationManager; return this; } - public BackChannelLogoutConfigurer sessionRegistry(OidcSessionRegistry sessionRegistry) { + /** + * Use this {@link OidcSessionRegistry} for managing the client-provider session + * link + * @param sessionRegistry the {@link OidcSessionRegistry} to use + * @return the {@link BackChannelLogoutConfigurer} for further configuration + */ + public BackChannelLogoutConfigurer oidcSessionRegistry(OidcSessionRegistry sessionRegistry) { Assert.notNull(sessionRegistry, "sessionRegistry cannot be null"); this.sessionRegistry = sessionRegistry; return this; } + /** + * Use this {@link LogoutHandler} for invalidating each session identified by the + * OIDC Back-Channel Logout Token + * @return the {@link BackChannelLogoutConfigurer} for further configuration + */ public BackChannelLogoutConfigurer logoutHandler(LogoutHandler logoutHandler) { Assert.notNull(logoutHandler, "logoutHandler cannot be null"); this.logoutHandler = logoutHandler; @@ -166,27 +186,25 @@ public BackChannelLogoutConfigurer logoutHandler(LogoutHandler logoutHandler) { } private AuthenticationManager authenticationManager() { - if (this.authenticationManager == null) { - OidcBackChannelLogoutAuthenticationManager authenticationManager = new OidcBackChannelLogoutAuthenticationManager(); - authenticationManager.setSessionRegistry(sessionRegistry()); - this.authenticationManager = authenticationManager; - } return this.authenticationManager; } - private OidcSessionRegistry sessionRegistry() { + private OidcSessionRegistry oidcSessionRegistry() { return this.sessionRegistry; } private LogoutHandler logoutHandler() { + if (this.logoutHandler == null) { + OidcBackChannelLogoutHandler logoutHandler = new OidcBackChannelLogoutHandler(); + logoutHandler.setSessionRegistry(this.sessionRegistry); + this.logoutHandler = logoutHandler; + } return this.logoutHandler; } - private SessionAuthenticationStrategy sessionAuthenticationStrategy( - ClientRegistrationRepository clientRegistrationRepository) { - OidcSessionRegistryAuthenticationStrategy strategy = new OidcSessionRegistryAuthenticationStrategy( - clientRegistrationRepository); - strategy.setSessionRegistry(sessionRegistry()); + private SessionAuthenticationStrategy sessionAuthenticationStrategy() { + OidcSessionRegistryAuthenticationStrategy strategy = new OidcSessionRegistryAuthenticationStrategy(); + strategy.setSessionRegistry(oidcSessionRegistry()); return strategy; } @@ -195,13 +213,11 @@ void configure(B http) { .getClientRegistrationRepository(http); OidcBackChannelLogoutFilter filter = new OidcBackChannelLogoutFilter(clientRegistrationRepository, authenticationManager()); - LogoutHandler expiredStrategy = logoutHandler(); - filter.setLogoutHandler(expiredStrategy); + filter.setLogoutHandler(logoutHandler()); http.addFilterBefore(filter, CsrfFilter.class); SessionManagementConfigurer sessionConfigurer = http.getConfigurer(SessionManagementConfigurer.class); if (sessionConfigurer != null) { - sessionConfigurer - .addSessionAuthenticationStrategy(sessionAuthenticationStrategy(clientRegistrationRepository)); + sessionConfigurer.addSessionAuthenticationStrategy(sessionAuthenticationStrategy()); } OidcClientSessionEventListener listener = new OidcClientSessionEventListener(); listener.setSessionRegistry(this.sessionRegistry); @@ -221,12 +237,17 @@ static final class OidcClientSessionEventListener implements ApplicationListener public void onApplicationEvent(AbstractSessionEvent event) { if (event instanceof SessionDestroyedEvent destroyed) { this.logger.debug("Received SessionDestroyedEvent"); - this.sessionRegistry.deregister(destroyed.getId()); + this.sessionRegistry.removeSessionInformation(destroyed.getId()); return; } if (event instanceof SessionIdChangedEvent changed) { this.logger.debug("Received SessionIdChangedEvent"); - this.sessionRegistry.register(changed.getOldSessionId(), changed.getNewSessionId()); + OidcSessionInformation information = this.sessionRegistry.removeSessionInformation(changed.getOldSessionId()); + if (information == null) { + this.logger.debug("Failed to register new session id since old session id was not found in registry"); + return; + } + this.sessionRegistry.saveSessionInformation(information.withSessionId(changed.getNewSessionId())); } } @@ -246,14 +267,8 @@ static final class OidcSessionRegistryAuthenticationStrategy implements SessionA private final Log logger = LogFactory.getLog(getClass()); - private final ClientRegistrationRepository clientRegistrationRepository; - private OidcSessionRegistry sessionRegistry = new InMemoryOidcSessionRegistry(); - OidcSessionRegistryAuthenticationStrategy(ClientRegistrationRepository clientRegistrationRepository) { - this.clientRegistrationRepository = clientRegistrationRepository; - } - /** * {@inheritDoc} */ @@ -263,28 +278,22 @@ public void onAuthentication(Authentication authentication, HttpServletRequest r if (session == null) { return; } - if (!(authentication instanceof OAuth2AuthenticationToken token)) { - return; - } if (!(authentication.getPrincipal() instanceof OidcUser user)) { return; } - String registrationId = token.getAuthorizedClientRegistrationId(); - ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(registrationId); - String clientId = clientRegistration.getClientId(); String sessionId = session.getId(); CsrfToken csrfToken = (CsrfToken) request.getAttribute(CsrfToken.class.getName()); Map headers = (csrfToken != null) ? Map.of(csrfToken.getHeaderName(), csrfToken.getToken()) : Collections.emptyMap(); - OidcSessionRegistration registration = new OidcSessionRegistration(clientId, sessionId, headers, user); + OidcSessionInformation registration = new OidcSessionInformation(sessionId, headers, user); if (this.logger.isTraceEnabled()) { this.logger.trace(String.format("Linking a provider [%s] session to this client's session", user.getIssuer())); } - this.sessionRegistry.register(registration); + this.sessionRegistry.saveSessionInformation(registration); } /** * The registration for linking OIDC Provider Session information to the - * Client's session. Defaults to in-memory. + * Client's session. Defaults to in-memory storage. * @param sessionRegistry the {@link OidcSessionRegistry} to use */ void setSessionRegistry(OidcSessionRegistry sessionRegistry) { diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcLogoutConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcLogoutConfigurerTests.java index 1282bcb9323..4b0bcfc08ab 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcLogoutConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcLogoutConfigurerTests.java @@ -62,12 +62,13 @@ import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.oauth2.client.oidc.authentication.logout.LogoutTokenClaimNames; -import org.springframework.security.oauth2.client.oidc.authentication.logout.OidcBackChannelLogoutAuthenticationManager; +import org.springframework.security.oauth2.client.oidc.authentication.logout.OidcBackChannelLogoutAuthentication; import org.springframework.security.oauth2.client.oidc.authentication.logout.OidcLogoutToken; import org.springframework.security.oauth2.client.oidc.authentication.logout.TestOidcLogoutTokens; -import org.springframework.security.oauth2.client.oidc.authentication.session.OidcSessionRegistration; -import org.springframework.security.oauth2.client.oidc.authentication.session.OidcSessionRegistry; -import org.springframework.security.oauth2.client.oidc.authentication.session.TestOidcSessionRegistrations; +import org.springframework.security.oauth2.client.oidc.session.OidcSessionInformation; +import org.springframework.security.oauth2.client.oidc.session.OidcSessionRegistry; +import org.springframework.security.oauth2.client.oidc.session.TestOidcSessionInformations; +import org.springframework.security.oauth2.client.oidc.web.logout.OidcBackChannelLogoutHandler; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository; @@ -81,7 +82,6 @@ import org.springframework.security.oauth2.jwt.NimbusJwtEncoder; import org.springframework.security.provisioning.InMemoryUserDetailsManager; import org.springframework.security.web.SecurityFilterChain; -import org.springframework.security.web.authentication.logout.BackchannelLogoutAuthentication; import org.springframework.security.web.authentication.logout.LogoutHandler; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; @@ -96,6 +96,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.httpBasic; @@ -103,6 +104,9 @@ import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; +/** + * Tests for {@link OidcLogoutConfigurer} + */ @ExtendWith(SpringTestContextExtension.class) public class OidcLogoutConfigurerTests { @@ -113,7 +117,7 @@ public class OidcLogoutConfigurerTests { private MockWebServer web; @Autowired - private ClientRegistration registration; + private ClientRegistration clientRegistration; public final SpringTestContext spring = new SpringTestContext(this); @@ -122,14 +126,14 @@ void logoutWhenDefaultsThenRemotelyInvalidatesSessions() throws Exception { this.spring.register(WebServerConfig.class, OidcProviderConfig.class, DefaultConfig.class).autowire(); MockMvcDispatcher dispatcher = (MockMvcDispatcher) this.web.getDispatcher(); this.mvc.perform(get("/token/logout")).andExpect(status().isUnauthorized()); - String registrationId = this.registration.getRegistrationId(); + String registrationId = this.clientRegistration.getRegistrationId(); MvcResult result = this.mvc.perform(get("/oauth2/authorization/" + registrationId)) .andExpect(status().isFound()).andReturn(); MockHttpSession session = (MockHttpSession) result.getRequest().getSession(); String redirectUrl = UrlUtils.decode(result.getResponse().getRedirectedUrl()); String state = this.mvc - .perform(get(redirectUrl) - .with(httpBasic(this.registration.getClientId(), this.registration.getClientSecret()))) + .perform(get(redirectUrl).with( + httpBasic(this.clientRegistration.getClientId(), this.clientRegistration.getClientSecret()))) .andReturn().getResponse().getContentAsString(); result = this.mvc.perform(get("/login/oauth2/code/" + registrationId).param("code", "code") .param("state", state).session(session)).andExpect(status().isFound()).andReturn(); @@ -146,14 +150,14 @@ void logoutWhenDefaultsThenRemotelyInvalidatesSessions() throws Exception { void logoutWhenInvalidLogoutTokenThenBadRequest() throws Exception { this.spring.register(WebServerConfig.class, OidcProviderConfig.class, DefaultConfig.class).autowire(); this.mvc.perform(get("/token/logout")).andExpect(status().isUnauthorized()); - String registrationId = this.registration.getRegistrationId(); + String registrationId = this.clientRegistration.getRegistrationId(); MvcResult result = this.mvc.perform(get("/oauth2/authorization/" + registrationId)) .andExpect(status().isFound()).andReturn(); MockHttpSession session = (MockHttpSession) result.getRequest().getSession(); String redirectUrl = UrlUtils.decode(result.getResponse().getRedirectedUrl()); String state = this.mvc - .perform(get(redirectUrl) - .with(httpBasic(this.registration.getClientId(), this.registration.getClientSecret()))) + .perform(get(redirectUrl).with( + httpBasic(this.clientRegistration.getClientId(), this.clientRegistration.getClientSecret()))) .andReturn().getResponse().getContentAsString(); result = this.mvc.perform(get("/login/oauth2/code/" + registrationId).param("code", "code") .param("state", state).session(session)).andExpect(status().isFound()).andReturn(); @@ -166,18 +170,19 @@ void logoutWhenInvalidLogoutTokenThenBadRequest() throws Exception { @Test void logoutWhenCustomComponentsThenUses() throws Exception { this.spring.register(WithCustomComponentsConfig.class).autowire(); - String registrationId = this.registration.getRegistrationId(); + String registrationId = this.clientRegistration.getRegistrationId(); AuthenticationManager authenticationManager = this.spring.getContext().getBean(AuthenticationManager.class); OidcLogoutToken logoutToken = TestOidcLogoutTokens.withSessionId("issuer", "provider").build(); - Set details = Set.of(TestOidcSessionRegistrations.create()); given(authenticationManager.authenticate(any())) - .willReturn(new BackchannelLogoutAuthentication(logoutToken, logoutToken, details)); - LogoutHandler logoutHandler = this.spring.getContext().getBean(LogoutHandler.class); + .willReturn(new OidcBackChannelLogoutAuthentication(logoutToken)); + OidcSessionRegistry sessionRegistry = this.spring.getContext().getBean(OidcSessionRegistry.class); + Set details = Set.of(TestOidcSessionInformations.create()); + given(sessionRegistry.removeSessionInformation(logoutToken)).willReturn(details); this.mvc.perform(post("/logout/connect/back-channel/" + registrationId).param("logout_token", "token")) .andExpect(status().isOk()); - // verify(registry).deregister(any(OidcLogoutToken.class)); verify(authenticationManager).authenticate(any()); - verify(logoutHandler).logout(any(), any(), any()); + verify(this.spring.getContext().getBean(LogoutHandler.class)).logout(any(), any(), any()); + verify(sessionRegistry).removeSessionInformation(logoutToken); } @Configuration @@ -187,7 +192,7 @@ static class RegistrationConfig { MockWebServer web; @Bean - ClientRegistration registration() { + ClientRegistration clientRegistration() { if (this.web == null) { return TestClientRegistrations.clientRegistration().build(); } @@ -197,8 +202,8 @@ ClientRegistration registration() { } @Bean - ClientRegistrationRepository registrations(ClientRegistration registration) { - return new InMemoryClientRegistrationRepository(registration); + ClientRegistrationRepository clientRegistrationRepository(ClientRegistration clientRegistration) { + return new InMemoryClientRegistrationRepository(clientRegistration); } } @@ -215,9 +220,7 @@ SecurityFilterChain filters(HttpSecurity http) throws Exception { http .authorizeHttpRequests((authorize) -> authorize.anyRequest().authenticated()) .oauth2Login(Customizer.withDefaults()) - .oidcLogout((oauth2) -> oauth2. - backChannel((backchannel) -> { }) - ); + .oidcLogout((oidc) -> oidc.backChannel(Customizer.withDefaults())); // @formatter:on return http.build(); @@ -232,25 +235,23 @@ static class WithCustomComponentsConfig { AuthenticationManager authenticationManager = mock(AuthenticationManager.class); - LogoutHandler logoutHandler = mock(LogoutHandler.class); + OidcSessionRegistry sessionRegistry = mock(OidcSessionRegistry.class); - OidcSessionRegistry registry = mock(OidcSessionRegistry.class); + OidcBackChannelLogoutHandler logoutHandler = spy(new OidcBackChannelLogoutHandler()); @Bean @Order(1) SecurityFilterChain filters(HttpSecurity http) throws Exception { - OidcBackChannelLogoutAuthenticationManager authenticationManager = new OidcBackChannelLogoutAuthenticationManager(); - authenticationManager.setSessionRegistry(this.registry); + this.logoutHandler.setSessionRegistry(this.sessionRegistry); // @formatter:off http .authorizeHttpRequests((authorize) -> authorize.anyRequest().authenticated()) .oauth2Login(Customizer.withDefaults()) - .oidcLogout((oauth2) -> oauth2. - backChannel((backchannel) -> backchannel - .logoutHandler(this.logoutHandler) - .authenticationManager(this.authenticationManager) - ) - ); + .oidcLogout((oidc) -> oidc.backChannel((logout) -> logout + .authenticationManager(this.authenticationManager) + .oidcSessionRegistry(this.sessionRegistry) + .logoutHandler(this.logoutHandler) + )); // @formatter:on return http.build(); @@ -262,13 +263,13 @@ AuthenticationManager authenticationManager() { } @Bean - LogoutHandler logoutHandler() { - return this.logoutHandler; + OidcSessionRegistry sessionRegistry() { + return this.sessionRegistry; } @Bean - OidcSessionRegistry providerSessionRegistry() { - return this.registry; + LogoutHandler logoutHandler() { + return this.logoutHandler; } } @@ -325,7 +326,7 @@ SecurityFilterChain authorizationServer(HttpSecurity http, ClientRegistration re ) .httpBasic(Customizer.withDefaults()) .oauth2ResourceServer((oauth2) -> oauth2 - .jwt().jwkSetUri(registration.getProviderDetails().getJwkSetUri()) + .jwt((jwt) -> jwt.jwkSetUri(registration.getProviderDetails().getJwkSetUri())) ); // @formatter:off diff --git a/core/src/main/java/org/springframework/security/core/session/SessionInformation.java b/core/src/main/java/org/springframework/security/core/session/SessionInformation.java index db53d4bfe1b..54b05bbbb08 100644 --- a/core/src/main/java/org/springframework/security/core/session/SessionInformation.java +++ b/core/src/main/java/org/springframework/security/core/session/SessionInformation.java @@ -18,8 +18,6 @@ import java.io.Serializable; import java.util.Date; -import java.util.LinkedHashMap; -import java.util.Map; import org.springframework.security.core.SpringSecurityCoreVersion; import org.springframework.util.Assert; @@ -51,8 +49,6 @@ public class SessionInformation implements Serializable { private boolean expired = false; - private Map headers = new LinkedHashMap<>(); - public SessionInformation(Object principal, String sessionId, Date lastRequest) { Assert.notNull(principal, "Principal required"); Assert.hasText(sessionId, "SessionId required"); @@ -62,15 +58,6 @@ public SessionInformation(Object principal, String sessionId, Date lastRequest) this.lastRequest = lastRequest; } - public SessionInformation(Object principal, String sessionId, Map headers) { - Assert.notNull(principal, "Principal required"); - Assert.hasText(sessionId, "SessionId required"); - this.principal = principal; - this.sessionId = sessionId; - this.lastRequest = new Date(); - this.headers = headers; - } - public void expireNow() { this.expired = true; } @@ -87,10 +74,6 @@ public String getSessionId() { return this.sessionId; } - public Map getHeaders() { - return this.headers; - } - public boolean isExpired() { return this.expired; } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/DefaultOidcLogoutTokenValidatorFactory.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/DefaultOidcLogoutTokenValidatorFactory.java index 847b67f827a..d9bc0c944e6 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/DefaultOidcLogoutTokenValidatorFactory.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/DefaultOidcLogoutTokenValidatorFactory.java @@ -24,12 +24,12 @@ import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.JwtTimestampValidator; -class DefaultOidcLogoutTokenValidatorFactory implements Function> { +final class DefaultOidcLogoutTokenValidatorFactory implements Function> { @Override public OAuth2TokenValidator apply(ClientRegistration clientRegistration) { return new DelegatingOAuth2TokenValidator<>(new JwtTimestampValidator(), - new OidcLogoutTokenValidator(clientRegistration)); + new OidcBackChannelLogoutTokenValidator(clientRegistration)); } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/LogoutTokenClaimAccessor.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/LogoutTokenClaimAccessor.java index 1fe9c78c466..49aeff4c3cc 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/LogoutTokenClaimAccessor.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/LogoutTokenClaimAccessor.java @@ -24,15 +24,15 @@ import org.springframework.security.oauth2.core.ClaimAccessor; /** - * A {@link ClaimAccessor} for the "claims" that can be returned in OIDC - * Backchannel Logout Tokens + * A {@link ClaimAccessor} for the "claims" that can be returned in OIDC Logout + * Tokens * * @author Josh Cummings - * @since 6.1 + * @since 6.2 * @see OidcLogoutToken * @see Logout - * Token + * "https://openid.net/specs/openid-connect-backchannel-1_0.html#LogoutToken">OIDC + * Back-Channel Logout Token */ public interface LogoutTokenClaimAccessor extends ClaimAccessor { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/LogoutTokenClaimNames.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/LogoutTokenClaimNames.java index 5f00470ba37..9893aa350ab 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/LogoutTokenClaimNames.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/LogoutTokenClaimNames.java @@ -17,17 +17,16 @@ package org.springframework.security.oauth2.client.oidc.authentication.logout; /** - * The names of the "claims" defined by the OpenID Backchannel Logout 1.0 + * The names of the "claims" defined by the OpenID Back-Channel Logout 1.0 * specification that can be returned in a Logout Token. * * @author Josh Cummings - * @since 6.1 + * @since 6.2 * @see OidcLogoutToken * @see Logout - * Token + * "https://openid.net/specs/openid-connect-backchannel-1_0.html#LogoutToken">OIDC + * Back-Channel Logout Token */ - public final class LogoutTokenClaimNames { /** diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcBackChannelLogoutAuthentication.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcBackChannelLogoutAuthentication.java new file mode 100644 index 00000000000..0f12ad3f06e --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcBackChannelLogoutAuthentication.java @@ -0,0 +1,65 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.oidc.authentication.logout; + +import java.util.Collections; + +import org.springframework.security.authentication.AbstractAuthenticationToken; + +/** + * An {@link org.springframework.security.core.Authentication} implementation that + * represents the result of authenticating an OIDC Logout token for the purposes of + * performing Back-Channel Logout. + * + * @author Josh Cummings + * @since 6.2 + * @see OidcLogoutAuthenticationToken + * @see OIDC Back-Channel + * Logout + */ +public class OidcBackChannelLogoutAuthentication extends AbstractAuthenticationToken { + + private final OidcLogoutToken logoutToken; + + /** + * Construct an {@link OidcBackChannelLogoutAuthentication} + * @param logoutToken a deserialized, verified OIDC Logout Token + */ + public OidcBackChannelLogoutAuthentication(OidcLogoutToken logoutToken) { + super(Collections.emptyList()); + this.logoutToken = logoutToken; + setAuthenticated(true); + } + + /** + * {@inheritDoc} + */ + @Override + public OidcLogoutToken getPrincipal() { + return this.logoutToken; + } + + /** + * {@inheritDoc} + */ + @Override + public OidcLogoutToken getCredentials() { + return this.logoutToken; + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcBackChannelLogoutAuthenticationManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcBackChannelLogoutAuthenticationProvider.java similarity index 60% rename from oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcBackChannelLogoutAuthenticationManager.java rename to oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcBackChannelLogoutAuthenticationProvider.java index dd05f34ef91..8c417cc1143 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcBackChannelLogoutAuthenticationManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcBackChannelLogoutAuthenticationProvider.java @@ -16,39 +16,56 @@ package org.springframework.security.oauth2.client.oidc.authentication.logout; -import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.AuthenticationServiceException; -import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; -import org.springframework.security.core.session.SessionInformation; import org.springframework.security.oauth2.client.oidc.authentication.OidcIdTokenDecoderFactory; -import org.springframework.security.oauth2.client.oidc.authentication.session.InMemoryOidcSessionRegistry; -import org.springframework.security.oauth2.client.oidc.authentication.session.OidcSessionRegistry; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.jwt.BadJwtException; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.JwtDecoder; import org.springframework.security.oauth2.jwt.JwtDecoderFactory; -import org.springframework.security.oauth2.jwt.JwtException; -import org.springframework.security.web.authentication.logout.BackchannelLogoutAuthentication; import org.springframework.util.Assert; -public final class OidcBackChannelLogoutAuthenticationManager implements AuthenticationManager { +/** + * An {@link AuthenticationProvider} that authenticates an OIDC Logout Token; namely + * deserializing it, verifying its signature, and validating its claims. + * + *

+ * Intended to be included in a + * {@link org.springframework.security.authentication.ProviderManager} + * + * @author Josh Cummings + * @since 6.2 + * @see OidcLogoutAuthenticationToken + * @see org.springframework.security.authentication.ProviderManager + * @see OIDC Back-Channel + * Logout + */ +public final class OidcBackChannelLogoutAuthenticationProvider implements AuthenticationProvider { private JwtDecoderFactory logoutTokenDecoderFactory; - private OidcSessionRegistry sessionRegistry = new InMemoryOidcSessionRegistry(); - - public OidcBackChannelLogoutAuthenticationManager() { + /** + * Construct an {@link OidcBackChannelLogoutAuthenticationProvider} + */ + public OidcBackChannelLogoutAuthenticationProvider() { OidcIdTokenDecoderFactory logoutTokenDecoderFactory = new OidcIdTokenDecoderFactory(); logoutTokenDecoderFactory.setJwtValidatorFactory(new DefaultOidcLogoutTokenValidatorFactory()); this.logoutTokenDecoderFactory = logoutTokenDecoderFactory; } + /** + * {@inheritDoc} + */ @Override public Authentication authenticate(Authentication authentication) throws AuthenticationException { - if (!(authentication instanceof LogoutTokenAuthenticationToken token)) { + if (!(authentication instanceof OidcLogoutAuthenticationToken token)) { return null; } String logoutToken = token.getLogoutToken(); @@ -56,8 +73,15 @@ public Authentication authenticate(Authentication authentication) throws Authent Jwt jwt = decode(registration, logoutToken); OidcLogoutToken oidcLogoutToken = OidcLogoutToken.withTokenValue(logoutToken) .claims((claims) -> claims.putAll(jwt.getClaims())).build(); - Iterable sessions = this.sessionRegistry.deregister(oidcLogoutToken); - return new BackchannelLogoutAuthentication(oidcLogoutToken, oidcLogoutToken, sessions); + return new OidcBackChannelLogoutAuthentication(oidcLogoutToken); + } + + /** + * {@inheritDoc} + */ + @Override + public boolean supports(Class authentication) { + return OidcLogoutAuthenticationToken.class.isAssignableFrom(authentication); } private Jwt decode(ClientRegistration registration, String token) { @@ -66,21 +90,23 @@ private Jwt decode(ClientRegistration registration, String token) { return logoutTokenDecoder.decode(token); } catch (BadJwtException failed) { - throw new BadCredentialsException(failed.getMessage(), failed); + OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST, failed.getMessage(), + "https://openid.net/specs/openid-connect-backchannel-1_0.html#Validation"); + throw new OAuth2AuthenticationException(error, failed); } - catch (JwtException failed) { + catch (Exception failed) { throw new AuthenticationServiceException(failed.getMessage(), failed); } } + /** + * Use this {@link JwtDecoderFactory} to generate {@link JwtDecoder}s that correspond + * to the {@link ClientRegistration} associated with the OIDC logout token. + * @param logoutTokenDecoderFactory the {@link JwtDecoderFactory} to use + */ public void setLogoutTokenDecoderFactory(JwtDecoderFactory logoutTokenDecoderFactory) { Assert.notNull(logoutTokenDecoderFactory, "logoutTokenDecoderFactory cannot be null"); this.logoutTokenDecoderFactory = logoutTokenDecoderFactory; } - public void setSessionRegistry(OidcSessionRegistry sessionRegistry) { - Assert.notNull(sessionRegistry, "sessionRegistry cannot be null"); - this.sessionRegistry = sessionRegistry; - } - } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcLogoutTokenValidator.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcBackChannelLogoutTokenValidator.java similarity index 89% rename from oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcLogoutTokenValidator.java rename to oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcBackChannelLogoutTokenValidator.java index 680820c134a..d0bd9408560 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcLogoutTokenValidator.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcBackChannelLogoutTokenValidator.java @@ -28,10 +28,10 @@ import org.springframework.security.oauth2.core.OAuth2TokenValidator; import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult; import org.springframework.security.oauth2.jwt.Jwt; -import org.springframework.security.oauth2.jwt.JwtDecoderFactory; /** - * A {@link JwtDecoderFactory} that decodes and verifies OIDC Logout Tokens. + * A {@link OAuth2TokenValidator} that validates OIDC Logout Token claims in conformance + * with the OIDC Back-Channel Logout Spec. * * @author Josh Cummings * @since 6.2 @@ -39,8 +39,11 @@ * @see Logout * Token + * @see the OIDC + * Back-Channel Logout spec */ -public final class OidcLogoutTokenValidator implements OAuth2TokenValidator { +public final class OidcBackChannelLogoutTokenValidator implements OAuth2TokenValidator { private static final String LOGOUT_VALIDATION_URL = "https://openid.net/specs/openid-connect-backchannel-1_0.html#Validation"; @@ -50,7 +53,7 @@ public final class OidcLogoutTokenValidator implements OAuth2TokenValidator private final String issuer; - OidcLogoutTokenValidator(ClientRegistration clientRegistration) { + public OidcBackChannelLogoutTokenValidator(ClientRegistration clientRegistration) { this.audience = clientRegistration.getClientId(); this.issuer = clientRegistration.getProviderDetails().getIssuerUri(); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/LogoutTokenAuthenticationToken.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcLogoutAuthenticationToken.java similarity index 61% rename from oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/LogoutTokenAuthenticationToken.java rename to oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcLogoutAuthenticationToken.java index c0e2690d79a..8912dc52df1 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/LogoutTokenAuthenticationToken.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcLogoutAuthenticationToken.java @@ -20,32 +20,59 @@ import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.oauth2.client.registration.ClientRegistration; -public class LogoutTokenAuthenticationToken extends AbstractAuthenticationToken { +/** + * An {@link org.springframework.security.core.Authentication} instance that represents a + * request to authenticate an OIDC Logout Token. + * + * @author Josh Cummings + * @since 6.2 + */ +public class OidcLogoutAuthenticationToken extends AbstractAuthenticationToken { private final String logoutToken; private final ClientRegistration clientRegistration; - public LogoutTokenAuthenticationToken(String logoutToken, ClientRegistration clientRegistration) { + /** + * Construct an {@link OidcLogoutAuthenticationToken} + * @param logoutToken a signed, serialized OIDC Logout token + * @param clientRegistration the {@link ClientRegistration client} associated with + * this token; this is usually derived from material in the logout HTTP request + */ + public OidcLogoutAuthenticationToken(String logoutToken, ClientRegistration clientRegistration) { super(AuthorityUtils.NO_AUTHORITIES); this.logoutToken = logoutToken; this.clientRegistration = clientRegistration; } + /** + * {@inheritDoc} + */ @Override public String getCredentials() { return this.logoutToken; } + /** + * {@inheritDoc} + */ @Override public String getPrincipal() { return this.logoutToken; } + /** + * Get the signed, serialized OIDC Logout token + * @return the logout token + */ public String getLogoutToken() { return this.logoutToken; } + /** + * Get the {@link ClientRegistration} associated with this logout token + * @return the {@link ClientRegistration} + */ public ClientRegistration getClientRegistration() { return this.clientRegistration; } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcLogoutToken.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcLogoutToken.java index 3407e0da945..41b425bf408 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcLogoutToken.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcLogoutToken.java @@ -25,7 +25,6 @@ import org.springframework.security.oauth2.core.AbstractOAuth2Token; import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; -import org.springframework.security.oauth2.core.oidc.OidcIdToken; import org.springframework.util.Assert; /** @@ -37,7 +36,7 @@ * terminating sessions for a given OIDC Provider session id or End User. * * @author Josh Cummings - * @since 6.1 + * @since 6.2 * @see AbstractOAuth2Token * @see LogoutTokenClaimAccessor * @see claims; /** - * Constructs a {@code OidcLogoutToken} using the provided parameters. + * Constructs a {@link OidcLogoutToken} using the provided parameters. * @param tokenValue the Logout Token value * @param issuedAt the time at which the Logout Token was issued {@code (iat)} * @param claims the claims about the logout statement @@ -90,11 +89,11 @@ public static final class Builder { private Builder(String tokenValue) { this.tokenValue = tokenValue; this.claims.put(LogoutTokenClaimNames.EVENTS, - Collections.singletonMap(LOGOUT_TOKEN_EVENT_NAME, Collections.emptyMap())); + Collections.singletonMap(BACKCHANNEL_LOGOUT_TOKEN_EVENT_NAME, Collections.emptyMap())); } /** - * Use this token value in the resulting {@link OidcIdToken} + * Use this token value in the resulting {@link OidcLogoutToken} * @param tokenValue The token value to use * @return the {@link Builder} for further configurations */ @@ -104,7 +103,7 @@ public Builder tokenValue(String tokenValue) { } /** - * Use this claim in the resulting {@link OidcIdToken} + * Use this claim in the resulting {@link OidcLogoutToken} * @param name The claim name * @param value The claim value * @return the {@link Builder} for further configurations @@ -126,7 +125,7 @@ public Builder claims(Consumer> claimsConsumer) { } /** - * Use this audience in the resulting {@link OidcIdToken} + * Use this audience in the resulting {@link OidcLogoutToken} * @param audience The audience(s) to use * @return the {@link Builder} for further configurations */ @@ -135,7 +134,7 @@ public Builder audience(Collection audience) { } /** - * Use this issued-at timestamp in the resulting {@link OidcIdToken} + * Use this issued-at timestamp in the resulting {@link OidcLogoutToken} * @param issuedAt The issued-at timestamp to use * @return the {@link Builder} for further configurations */ @@ -144,7 +143,7 @@ public Builder issuedAt(Instant issuedAt) { } /** - * Use this issuer in the resulting {@link OidcIdToken} + * Use this issuer in the resulting {@link OidcLogoutToken} * @param issuer The issuer to use * @return the {@link Builder} for further configurations */ @@ -152,12 +151,17 @@ public Builder issuer(String issuer) { return claim(LogoutTokenClaimNames.ISS, issuer); } - public Builder jti(String id) { - return claim(LogoutTokenClaimNames.JTI, id); + /** + * Use this id to identify the resulting {@link OidcLogoutToken} + * @param jti The unique identifier to use + * @return the {@link Builder} for further configurations + */ + public Builder jti(String jti) { + return claim(LogoutTokenClaimNames.JTI, jti); } /** - * Use this subject in the resulting {@link OidcIdToken} + * Use this subject in the resulting {@link OidcLogoutToken} * @param subject The subject to use * @return the {@link Builder} for further configurations */ @@ -190,8 +194,8 @@ public OidcLogoutToken build() { Assert.notEmpty((Collection) this.claims.get(LogoutTokenClaimNames.AUD), "audience must not be empty"); Assert.notNull(this.claims.get(LogoutTokenClaimNames.JTI), "jti must not be null"); Assert.isTrue(hasLogoutTokenIdentifyingMember(), - "logout token must contain an events claim that contains a member called " - + "'http://schemas.openid.net/event/backchannel-logout' whose value is an empty Map"); + "logout token must contain an events claim that contains a member called " + "'" + + BACKCHANNEL_LOGOUT_TOKEN_EVENT_NAME + "' whose value is an empty Map"); Assert.isNull(this.claims.get("nonce"), "logout token must not contain a nonce claim"); Instant iat = toInstant(this.claims.get(IdTokenClaimNames.IAT)); return new OidcLogoutToken(this.tokenValue, iat, this.claims); @@ -201,7 +205,7 @@ private boolean hasLogoutTokenIdentifyingMember() { if (!(this.claims.get(LogoutTokenClaimNames.EVENTS) instanceof Map events)) { return false; } - if (!(events.get("http://schemas.openid.net/event/backchannel-logout") instanceof Map object)) { + if (!(events.get(BACKCHANNEL_LOGOUT_TOKEN_EVENT_NAME) instanceof Map object)) { return false; } return object.isEmpty(); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/session/OidcSessionRegistration.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/session/OidcSessionRegistration.java deleted file mode 100644 index 8113fe1935c..00000000000 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/session/OidcSessionRegistration.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright 2002-2023 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.oauth2.client.oidc.authentication.session; - -import java.util.Map; - -import org.springframework.security.core.session.SessionInformation; -import org.springframework.security.oauth2.core.oidc.user.OidcUser; - -/** - * The default implementation for {@link OidcSessionRegistration}. Handy for in-memory - * registries. - * - * @author Josh Cummings - * @since 6.2 - */ -public class OidcSessionRegistration extends SessionInformation { - - private String clientRegistrationId; - - /** - * Construct an {@link OidcSessionRegistration} - * @param sessionId the Client's session id - * @param additionalHeaders any additional headers needed to authenticate session - * ownership - * @param user the OIDC Provider's session and end user - */ - public OidcSessionRegistration(String clientId, String sessionId, Map additionalHeaders, - OidcUser user) { - super(user, sessionId, additionalHeaders); - this.clientRegistrationId = clientId; - } - - public String getClientId() { - return this.clientRegistrationId; - } - - @Override - public OidcUser getPrincipal() { - return (OidcUser) super.getPrincipal(); - } - -} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/session/InMemoryOidcSessionRegistry.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/session/InMemoryOidcSessionRegistry.java similarity index 65% rename from oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/session/InMemoryOidcSessionRegistry.java rename to oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/session/InMemoryOidcSessionRegistry.java index 1215bc5b709..f5bb6235df3 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/session/InMemoryOidcSessionRegistry.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/session/InMemoryOidcSessionRegistry.java @@ -14,8 +14,9 @@ * limitations under the License. */ -package org.springframework.security.oauth2.client.oidc.authentication.session; +package org.springframework.security.oauth2.client.oidc.session; +import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -39,40 +40,29 @@ public final class InMemoryOidcSessionRegistry implements OidcSessionRegistry { private final Log logger = LogFactory.getLog(InMemoryOidcSessionRegistry.class); - private final Map sessions = new ConcurrentHashMap<>(); + private final Map sessions = new ConcurrentHashMap<>(); @Override - public void register(OidcSessionRegistration registration) { - this.sessions.put(registration.getSessionId(), registration); + public void saveSessionInformation(OidcSessionInformation info) { + this.sessions.put(info.getSessionId(), info); } @Override - public void register(String oldClientSessionId, String newClientSessionId) { - OidcSessionRegistration old = this.sessions.remove(oldClientSessionId); - if (old == null) { - this.logger.debug("Failed to register new session id since old session id was not found in registry"); - return; - } - register(new OidcSessionRegistration(old.getClientId(), newClientSessionId, old.getHeaders(), - old.getPrincipal())); - } - - @Override - public OidcSessionRegistration deregister(String clientSessionId) { - OidcSessionRegistration details = this.sessions.remove(clientSessionId); - if (details != null) { + public OidcSessionInformation removeSessionInformation(String clientSessionId) { + OidcSessionInformation information = this.sessions.remove(clientSessionId); + if (information != null) { this.logger.trace("Removed client session"); } - return details; + return information; } @Override - public Iterable deregister(OidcLogoutToken token) { + public Iterable removeSessionInformation(OidcLogoutToken token) { List audience = token.getAudience(); String issuer = token.getIssuer().toString(); String subject = token.getSubject(); String providerSessionId = token.getSessionId(); - Predicate matcher = (providerSessionId != null) + Predicate matcher = (providerSessionId != null) ? sessionIdMatcher(audience, issuer, providerSessionId) : subjectMatcher(audience, issuer, subject); if (this.logger.isTraceEnabled()) { String message = "Looking up sessions by issuer [%s] and %s [%s]"; @@ -84,7 +74,7 @@ public Iterable deregister(OidcLogoutToken token) { } } int size = this.sessions.size(); - Set infos = new HashSet<>(); + Set infos = new HashSet<>(); this.sessions.values().removeIf((info) -> { boolean result = matcher.test(info); if (result) { @@ -102,24 +92,31 @@ else if (this.logger.isTraceEnabled()) { return infos; } - private static Predicate sessionIdMatcher(List audience, String issuer, + private static Predicate sessionIdMatcher(List audience, String issuer, String sessionId) { return (session) -> { - String thatRegistrationId = session.getClientId(); + List thatAudience = session.getPrincipal().getAudience(); String thatIssuer = session.getPrincipal().getIssuer().toString(); String thatSessionId = session.getPrincipal().getClaimAsString(LogoutTokenClaimNames.SID); - return audience.contains(thatRegistrationId) && issuer.equals(thatIssuer) + if (thatAudience == null) { + return false; + } + return !Collections.disjoint(audience, thatAudience) && issuer.equals(thatIssuer) && sessionId.equals(thatSessionId); }; } - private static Predicate subjectMatcher(List audience, String issuer, + private static Predicate subjectMatcher(List audience, String issuer, String subject) { return (session) -> { - String thatRegistrationId = session.getClientId(); + List thatAudience = session.getPrincipal().getAudience(); String thatIssuer = session.getPrincipal().getIssuer().toString(); String thatSubject = session.getPrincipal().getSubject(); - return audience.contains(thatRegistrationId) && issuer.equals(thatIssuer) && subject.equals(thatSubject); + if (thatAudience == null) { + return false; + } + return !Collections.disjoint(audience, thatAudience) && issuer.equals(thatIssuer) + && subject.equals(thatSubject); }; } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/session/OidcSessionInformation.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/session/OidcSessionInformation.java new file mode 100644 index 00000000000..e51ae90b2dc --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/session/OidcSessionInformation.java @@ -0,0 +1,73 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.oidc.session; + +import java.util.Collections; +import java.util.Date; +import java.util.LinkedHashMap; +import java.util.Map; + +import org.springframework.security.core.session.SessionInformation; +import org.springframework.security.oauth2.core.oidc.user.OidcUser; + +/** + * A {@link SessionInformation} extension that enforces the principal be of type + * {@link OidcUser}. + * + * @author Josh Cummings + * @since 6.2 + */ +public class OidcSessionInformation extends SessionInformation { + + private final Map authorities; + + /** + * Construct an {@link OidcSessionInformation} + * @param sessionId the Client's session id + * @param user the OIDC Provider's session and end user + */ + public OidcSessionInformation(String sessionId, Map authorities, OidcUser user) { + super(user, sessionId, new Date()); + this.authorities = (authorities != null) ? new LinkedHashMap<>(authorities) : Collections.emptyMap(); + } + + /** + * Any headers needed to authorize operations on this session + * @return the {@link Map} of headers + */ + public Map getAuthorities() { + return this.authorities; + } + + /** + * {@inheritDoc} + */ + @Override + public OidcUser getPrincipal() { + return (OidcUser) super.getPrincipal(); + } + + /** + * Copy this {@link OidcSessionInformation}, using a new session identifier + * @param sessionId the new session identifier to use + * @return a new {@link OidcSessionInformation} instance + */ + public OidcSessionInformation withSessionId(String sessionId) { + return new OidcSessionInformation(sessionId, this.authorities, getPrincipal()); + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/session/OidcSessionRegistry.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/session/OidcSessionRegistry.java similarity index 65% rename from oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/session/OidcSessionRegistry.java rename to oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/session/OidcSessionRegistry.java index a8c0a793204..26bae499db3 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/session/OidcSessionRegistry.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/session/OidcSessionRegistry.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.security.oauth2.client.oidc.authentication.session; +package org.springframework.security.oauth2.client.oidc.session; import org.springframework.security.oauth2.client.oidc.authentication.logout.OidcLogoutToken; @@ -24,7 +24,7 @@ * session or the End User. * * @author Josh Cummings - * @since 6.1 + * @since 6.2 * @see Logout * Token @@ -34,32 +34,26 @@ public interface OidcSessionRegistry { /** * Register a OIDC Provider session with the provided client session. Generally * speaking, the client session should be the session tied to the current login. - * @param details the {@link OidcSessionRegistration} to use + * @param info the {@link OidcSessionInformation} to use */ - void register(OidcSessionRegistration details); - - /** - * Update the entry for a Client when their session id changes. This is handy, for - * example, when the id changes for session fixation protection. - * @param oldClientSessionId the Client's old session id - * @param newClientSessionId the Client's new session id - */ - void register(String oldClientSessionId, String newClientSessionId); + void saveSessionInformation(OidcSessionInformation info); /** * Deregister the OIDC Provider session tied to the provided client session. Generally * speaking, the client session should be the session tied to the current logout. * @param clientSessionId the client session - * @return any found {@link OidcSessionRegistration}, could be {@code null} + * @return any found {@link OidcSessionInformation}, could be {@code null} */ - OidcSessionRegistration deregister(String clientSessionId); + OidcSessionInformation removeSessionInformation(String clientSessionId); /** * Deregister the OIDC Provider sessions referenced by the provided OIDC Logout Token - * by its session id or its subject. + * by its session id or its subject. Note that the issuer and audience should also + * match the corresponding values found in each {@link OidcSessionInformation} + * returned. * @param logoutToken the {@link OidcLogoutToken} - * @return any found {@link OidcSessionRegistration}s, could be empty + * @return any found {@link OidcSessionInformation}s, could be empty */ - Iterable deregister(OidcLogoutToken logoutToken); + Iterable removeSessionInformation(OidcLogoutToken logoutToken); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/web/authentication/logout/OidcBackChannelLogoutFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/web/OidcBackChannelLogoutFilter.java similarity index 86% rename from oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/web/authentication/logout/OidcBackChannelLogoutFilter.java rename to oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/web/OidcBackChannelLogoutFilter.java index 324e0229eff..3c49535bf14 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/web/authentication/logout/OidcBackChannelLogoutFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/web/OidcBackChannelLogoutFilter.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.security.oauth2.client.oidc.web.authentication.logout; +package org.springframework.security.oauth2.client.oidc.web; import java.io.IOException; @@ -30,13 +30,14 @@ import org.springframework.security.authentication.AuthenticationServiceException; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; -import org.springframework.security.oauth2.client.oidc.authentication.logout.LogoutTokenAuthenticationToken; +import org.springframework.security.oauth2.client.oidc.authentication.logout.OidcLogoutAuthenticationToken; +import org.springframework.security.oauth2.client.oidc.web.logout.OidcBackChannelLogoutHandler; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter; -import org.springframework.security.web.authentication.logout.BackchannelLogoutHandler; import org.springframework.security.web.authentication.logout.LogoutHandler; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; @@ -66,7 +67,7 @@ public class OidcBackChannelLogoutFilter extends OncePerRequestFilter { private RequestMatcher requestMatcher = new AntPathRequestMatcher(DEFAULT_LOGOUT_URI, "POST"); - private LogoutHandler logoutHandler = new BackchannelLogoutHandler(); + private LogoutHandler logoutHandler = new OidcBackChannelLogoutHandler(); /** * Construct an {@link OidcBackChannelLogoutFilter} @@ -95,8 +96,8 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse return; } String registrationId = result.getVariables().get("registrationId"); - ClientRegistration registration = this.clientRegistrationRepository.findByRegistrationId(registrationId); - if (registration == null) { + ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(registrationId); + if (clientRegistration == null) { this.logger.debug("Did not process OIDC Back-Channel Logout since no ClientRegistration was found"); response.sendError(HttpServletResponse.SC_BAD_REQUEST); return; @@ -107,11 +108,16 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response.sendError(HttpServletResponse.SC_BAD_REQUEST); return; } - LogoutTokenAuthenticationToken token = new LogoutTokenAuthenticationToken(logoutToken, registration); + OidcLogoutAuthenticationToken token = new OidcLogoutAuthenticationToken(logoutToken, clientRegistration); try { Authentication authentication = this.authenticationManager.authenticate(token); this.logoutHandler.logout(request, response, authentication); } + catch (OAuth2AuthenticationException ex) { + this.logger.debug("Failed to process OIDC Back-Channel Logout", ex); + response.setStatus(HttpServletResponse.SC_BAD_REQUEST); + this.errorHttpMessageConverter.write(ex.getError(), null, new ServletServerHttpResponse(response)); + } catch (AuthenticationServiceException ex) { this.logger.debug("Failed to process OIDC Back-Channel Logout", ex); response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, ex.getMessage()); @@ -137,7 +143,7 @@ public void setRequestMatcher(RequestMatcher requestMatcher) { /** * The strategy for expiring all Client sessions indicated by the logout request. - * Defaults to {@link BackchannelLogoutHandler}. + * Defaults to {@link OidcBackChannelLogoutHandler}. * @param logoutHandler the {@link LogoutHandler} to use */ public void setLogoutHandler(LogoutHandler logoutHandler) { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/web/logout/OidcBackChannelLogoutHandler.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/web/logout/OidcBackChannelLogoutHandler.java new file mode 100644 index 00000000000..0263ae649f9 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/web/logout/OidcBackChannelLogoutHandler.java @@ -0,0 +1,147 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.oidc.web.logout; + +import java.util.Map; + +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.oidc.authentication.logout.OidcBackChannelLogoutAuthentication; +import org.springframework.security.oauth2.client.oidc.authentication.logout.OidcLogoutToken; +import org.springframework.security.oauth2.client.oidc.session.InMemoryOidcSessionRegistry; +import org.springframework.security.oauth2.client.oidc.session.OidcSessionInformation; +import org.springframework.security.oauth2.client.oidc.session.OidcSessionRegistry; +import org.springframework.security.web.authentication.logout.LogoutHandler; +import org.springframework.util.Assert; +import org.springframework.web.client.RestClientException; +import org.springframework.web.client.RestOperations; +import org.springframework.web.client.RestTemplate; +import org.springframework.web.util.UriComponentsBuilder; + +/** + * A {@link LogoutHandler} that locates the sessions associated with a given OIDC + * Back-Channel Logout Token and invalidates each one. + * + * @author Josh Cummings + * @since 6.2 + * @see OIDC Back-Channel Logout + * Spec + */ +public final class OidcBackChannelLogoutHandler implements LogoutHandler { + + private final Log logger = LogFactory.getLog(getClass()); + + private OidcSessionRegistry sessionRegistry = new InMemoryOidcSessionRegistry(); + + private RestOperations restOperations = new RestTemplate(); + + private String logoutEndpointName = "/logout"; + + private String sessionCookieName = "JSESSIONID"; + + @Override + public void logout(HttpServletRequest request, HttpServletResponse response, Authentication authentication) { + if (!(authentication instanceof OidcBackChannelLogoutAuthentication token)) { + if (this.logger.isDebugEnabled()) { + String message = "Did not perform OIDC Back-Channel Logout since authentication [%s] was of the wrong type"; + this.logger.debug(String.format(message, authentication.getClass().getSimpleName())); + } + return; + } + Iterable sessions = this.sessionRegistry.removeSessionInformation(token.getPrincipal()); + int totalCount = 0; + int invalidatedCount = 0; + for (OidcSessionInformation session : sessions) { + totalCount++; + try { + eachLogout(request, session); + invalidatedCount++; + } + catch (RestClientException ex) { + this.logger.debug("Failed to invalidate session", ex); + } + } + if (this.logger.isTraceEnabled()) { + this.logger.trace(String.format("Invalidated %d out of %d sessions", invalidatedCount, totalCount)); + } + } + + private void eachLogout(HttpServletRequest request, OidcSessionInformation session) { + HttpHeaders headers = new HttpHeaders(); + headers.add(HttpHeaders.COOKIE, this.sessionCookieName + "=" + session.getSessionId()); + for (Map.Entry credential : session.getAuthorities().entrySet()) { + headers.add(credential.getKey(), credential.getValue()); + } + String url = request.getRequestURL().toString(); + String logout = UriComponentsBuilder.fromHttpUrl(url).replacePath(this.logoutEndpointName).build() + .toUriString(); + HttpEntity entity = new HttpEntity<>(null, headers); + this.restOperations.postForEntity(logout, entity, Object.class); + } + + /** + * Use this {@link OidcSessionRegistry} to identify sessions to invalidate. Note that + * this class uses + * {@link OidcSessionRegistry#removeSessionInformation(OidcLogoutToken)} to identify + * sessions. + * @param sessionRegistry the {@link OidcSessionRegistry} to use + */ + public void setSessionRegistry(OidcSessionRegistry sessionRegistry) { + Assert.notNull(sessionRegistry, "sessionRegistry cannot be null"); + this.sessionRegistry = sessionRegistry; + } + + /** + * Use this {@link RestOperations} to perform the per-session back-channel logout + * @param restOperations the {@link RestOperations} to use + */ + public void setRestOperations(RestOperations restOperations) { + Assert.notNull(restOperations, "restOperations cannot be null"); + this.restOperations = restOperations; + } + + /** + * Use this logout URI for performing per-session logout. Defaults to {@code /logout} + * since that is the default URI for + * {@link org.springframework.security.web.authentication.logout.LogoutFilter}. + * @param logoutUri the URI to use + */ + public void setLogoutUri(String logoutUri) { + Assert.hasText(logoutUri, "logoutUri cannot be empty"); + this.logoutEndpointName = logoutUri; + } + + /** + * Use this cookie name for the session identifier. Defaults to {@code JSESSIONID}. + * + *

+ * Note that if you are using Spring Session, this likely needs to change to SESSION. + * @param sessionCookieName the cookie name to use + */ + public void setSessionCookieName(String sessionCookieName) { + Assert.hasText(sessionCookieName, "clientSessionCookieName cannot be empty"); + this.sessionCookieName = sessionCookieName; + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcBackChannelLogoutTokenValidatorTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcBackChannelLogoutTokenValidatorTests.java new file mode 100644 index 00000000000..c4747837ad4 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcBackChannelLogoutTokenValidatorTests.java @@ -0,0 +1,88 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.oidc.authentication.logout; + +import java.util.List; + +import org.junit.jupiter.api.Test; + +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.jwt.Jwt; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link OidcBackChannelLogoutTokenValidator} + */ +public class OidcBackChannelLogoutTokenValidatorTests { + + // @formatter:off + private final ClientRegistration clientRegistration = TestClientRegistrations + .clientRegistration() + .issuerUri("https://issuer") + .scope("openid").build(); + // @formatter:on + + private final OidcBackChannelLogoutTokenValidator logoutTokenValidator = new OidcBackChannelLogoutTokenValidator( + this.clientRegistration); + + @Test + public void createDecoderWhenTokenValidThenNoErrors() { + Jwt valid = valid(this.clientRegistration).build(); + assertThat(this.logoutTokenValidator.validate(valid).hasErrors()).isFalse(); + } + + @Test + public void createDecoderWhenInvalidAudienceThenErrors() { + Jwt valid = valid(this.clientRegistration).audience(List.of("wrong")).build(); + assertThat(this.logoutTokenValidator.validate(valid).hasErrors()).isTrue(); + } + + @Test + public void createDecoderWhenMissingEventsThenErrors() { + Jwt valid = valid(this.clientRegistration).claims((claims) -> claims.remove(LogoutTokenClaimNames.EVENTS)) + .build(); + assertThat(this.logoutTokenValidator.validate(valid).hasErrors()).isTrue(); + } + + @Test + public void createDecoderWhenInvalidIssuerThenErrors() { + Jwt valid = valid(this.clientRegistration).issuer("https://wrong").build(); + assertThat(this.logoutTokenValidator.validate(valid).hasErrors()).isTrue(); + } + + @Test + public void createDecoderWhenMissingSubjectThenErrors() { + Jwt valid = valid(this.clientRegistration).claims((claims) -> claims.remove(LogoutTokenClaimNames.SUB)).build(); + assertThat(this.logoutTokenValidator.validate(valid).hasErrors()).isTrue(); + } + + @Test + public void createDecoderWhenMissingAudienceThenErrors() { + Jwt valid = valid(this.clientRegistration).claims((claims) -> claims.remove(LogoutTokenClaimNames.AUD)).build(); + assertThat(this.logoutTokenValidator.validate(valid).hasErrors()).isTrue(); + } + + private Jwt.Builder valid(ClientRegistration clientRegistration) { + String issuerUri = clientRegistration.getProviderDetails().getIssuerUri(); + OidcLogoutToken logoutToken = TestOidcLogoutTokens.withSubject(issuerUri, "subject").build(); + return Jwt.withTokenValue(logoutToken.getTokenValue()).header("header", "value") + .claims((claims) -> claims.putAll(logoutToken.getClaims())); + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcLogoutTokenValidatorTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcLogoutTokenValidatorTests.java deleted file mode 100644 index dcff8a917f2..00000000000 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/logout/OidcLogoutTokenValidatorTests.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright 2002-2023 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.oauth2.client.oidc.authentication.logout; - -import org.junit.jupiter.api.Test; - -import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.client.registration.TestClientRegistrations; - -import static org.assertj.core.api.Assertions.assertThat; - -public class OidcLogoutTokenValidatorTests { - - // @formatter:off - private ClientRegistration.Builder registration = TestClientRegistrations - .clientRegistration() - .scope("openid"); - // @formatter:on - - @Test - public void createDecoderWhenClientRegistrationValidThenReturnDecoder() { - OidcLogoutTokenValidator validator = new OidcLogoutTokenValidator(this.registration.build()); - assertThat(validator).isNotNull(); - } - -} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/session/InMemoryOidcSessionRegistryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/session/InMemoryOidcSessionRegistryTests.java deleted file mode 100644 index ca495d4181d..00000000000 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/session/InMemoryOidcSessionRegistryTests.java +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Copyright 2002-2023 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.oauth2.client.oidc.authentication.session; - -import org.junit.jupiter.api.Test; - -import org.springframework.security.core.authority.AuthorityUtils; -import org.springframework.security.oauth2.client.oidc.authentication.logout.OidcLogoutToken; -import org.springframework.security.oauth2.client.oidc.authentication.logout.TestOidcLogoutTokens; -import org.springframework.security.oauth2.core.oidc.OidcIdToken; -import org.springframework.security.oauth2.core.oidc.TestOidcIdTokens; -import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser; -import org.springframework.security.oauth2.core.oidc.user.OidcUser; - -import static org.assertj.core.api.Assertions.assertThat; - -public class InMemoryOidcSessionRegistryTests { - - @Test - public void registerWhenDefaultsThenStoresSessionInformation() { - InMemoryOidcSessionRegistry registry = new InMemoryOidcSessionRegistry(); - String sessionId = "client"; - OidcSessionRegistration info = TestOidcSessionRegistrations.create(sessionId); - registry.register(info); - OidcLogoutToken token = TestOidcLogoutTokens.withUser(info.getPrincipal()).build(); - Iterable infos = registry.deregister(token); - assertThat(infos).containsExactly(info); - } - - @Test - public void registerWhenIdTokenHasSessionIdThenStoresSessionInformation() { - InMemoryOidcSessionRegistry registry = new InMemoryOidcSessionRegistry(); - OidcIdToken token = TestOidcIdTokens.idToken().claim("sid", "provider").build(); - OidcUser user = new DefaultOidcUser(AuthorityUtils.NO_AUTHORITIES, token); - OidcSessionRegistration info = TestOidcSessionRegistrations.create("client", user); - registry.register(info); - OidcLogoutToken logoutToken = TestOidcLogoutTokens.withSessionId(token.getIssuer().toString(), "provider") - .build(); - Iterable infos = registry.deregister(logoutToken); - assertThat(infos).containsExactly(info); - } - - @Test - public void unregisterWhenMultipleSessionsThenRemovesAllMatching() { - InMemoryOidcSessionRegistry registry = new InMemoryOidcSessionRegistry(); - OidcIdToken token = TestOidcIdTokens.idToken().claim("sid", "providerOne").subject("otheruser").build(); - OidcUser user = new DefaultOidcUser(AuthorityUtils.NO_AUTHORITIES, token); - OidcSessionRegistration one = TestOidcSessionRegistrations.create("clientOne", user); - registry.register(one); - token = TestOidcIdTokens.idToken().claim("sid", "providerTwo").build(); - user = new DefaultOidcUser(AuthorityUtils.NO_AUTHORITIES, token); - OidcSessionRegistration two = TestOidcSessionRegistrations.create("clientTwo", user); - registry.register(two); - token = TestOidcIdTokens.idToken().claim("sid", "providerThree").build(); - user = new DefaultOidcUser(AuthorityUtils.NO_AUTHORITIES, token); - OidcSessionRegistration three = TestOidcSessionRegistrations.create("clientThree", user); - registry.register(three); - OidcLogoutToken logoutToken = TestOidcLogoutTokens.withSubject(token.getIssuer().toString(), token.getSubject()) - .build(); - Iterable infos = registry.deregister(logoutToken); - assertThat(infos).containsExactlyInAnyOrder(two, three); - logoutToken = TestOidcLogoutTokens.withSubject(token.getIssuer().toString(), "otheruser").build(); - infos = registry.deregister(logoutToken); - assertThat(infos).containsExactly(one); - } - - @Test - public void unregisterWhenNoSessionsThenEmptyList() { - InMemoryOidcSessionRegistry registry = new InMemoryOidcSessionRegistry(); - OidcIdToken token = TestOidcIdTokens.idToken().claim("sid", "provider").build(); - OidcUser user = new DefaultOidcUser(AuthorityUtils.NO_AUTHORITIES, token); - OidcSessionRegistration registration = TestOidcSessionRegistrations.create("client", user); - registry.register(registration); - OidcLogoutToken logoutToken = TestOidcLogoutTokens.withSessionId(token.getIssuer().toString(), "wrong").build(); - Iterable infos = registry.deregister(logoutToken); - assertThat(infos).isNotNull(); - assertThat(infos).isEmpty(); - logoutToken = TestOidcLogoutTokens.withSessionId("https://wrong", "provider").build(); - infos = registry.deregister(logoutToken); - assertThat(infos).isNotNull(); - assertThat(infos).isEmpty(); - } - -} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/session/InMemoryOidcSessionRegistryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/session/InMemoryOidcSessionRegistryTests.java new file mode 100644 index 00000000000..861eccce7ea --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/session/InMemoryOidcSessionRegistryTests.java @@ -0,0 +1,102 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.oidc.session; + +import org.junit.jupiter.api.Test; + +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.oauth2.client.oidc.authentication.logout.OidcLogoutToken; +import org.springframework.security.oauth2.client.oidc.authentication.logout.TestOidcLogoutTokens; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.TestOidcIdTokens; +import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser; +import org.springframework.security.oauth2.core.oidc.user.OidcUser; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link InMemoryOidcSessionRegistry} + */ +public class InMemoryOidcSessionRegistryTests { + + @Test + public void registerWhenDefaultsThenStoresSessionInformation() { + InMemoryOidcSessionRegistry sessionRegistry = new InMemoryOidcSessionRegistry(); + String sessionId = "client"; + OidcSessionInformation info = TestOidcSessionInformations.create(sessionId); + sessionRegistry.saveSessionInformation(info); + OidcLogoutToken logoutToken = TestOidcLogoutTokens.withUser(info.getPrincipal()).build(); + Iterable infos = sessionRegistry.removeSessionInformation(logoutToken); + assertThat(infos).containsExactly(info); + } + + @Test + public void registerWhenIdTokenHasSessionIdThenStoresSessionInformation() { + InMemoryOidcSessionRegistry sessionRegistry = new InMemoryOidcSessionRegistry(); + OidcIdToken idToken = TestOidcIdTokens.idToken().claim("sid", "provider").build(); + OidcUser user = new DefaultOidcUser(AuthorityUtils.NO_AUTHORITIES, idToken); + OidcSessionInformation info = TestOidcSessionInformations.create("client", user); + sessionRegistry.saveSessionInformation(info); + OidcLogoutToken logoutToken = TestOidcLogoutTokens.withSessionId(idToken.getIssuer().toString(), "provider") + .build(); + Iterable infos = sessionRegistry.removeSessionInformation(logoutToken); + assertThat(infos).containsExactly(info); + } + + @Test + public void unregisterWhenMultipleSessionsThenRemovesAllMatching() { + InMemoryOidcSessionRegistry sessionRegistry = new InMemoryOidcSessionRegistry(); + OidcIdToken idToken = TestOidcIdTokens.idToken().claim("sid", "providerOne").subject("otheruser").build(); + OidcUser user = new DefaultOidcUser(AuthorityUtils.NO_AUTHORITIES, idToken); + OidcSessionInformation oneSession = TestOidcSessionInformations.create("clientOne", user); + sessionRegistry.saveSessionInformation(oneSession); + idToken = TestOidcIdTokens.idToken().claim("sid", "providerTwo").build(); + user = new DefaultOidcUser(AuthorityUtils.NO_AUTHORITIES, idToken); + OidcSessionInformation twoSession = TestOidcSessionInformations.create("clientTwo", user); + sessionRegistry.saveSessionInformation(twoSession); + idToken = TestOidcIdTokens.idToken().claim("sid", "providerThree").build(); + user = new DefaultOidcUser(AuthorityUtils.NO_AUTHORITIES, idToken); + OidcSessionInformation threeSession = TestOidcSessionInformations.create("clientThree", user); + sessionRegistry.saveSessionInformation(threeSession); + OidcLogoutToken logoutToken = TestOidcLogoutTokens + .withSubject(idToken.getIssuer().toString(), idToken.getSubject()).build(); + Iterable infos = sessionRegistry.removeSessionInformation(logoutToken); + assertThat(infos).containsExactlyInAnyOrder(twoSession, threeSession); + logoutToken = TestOidcLogoutTokens.withSubject(idToken.getIssuer().toString(), "otheruser").build(); + infos = sessionRegistry.removeSessionInformation(logoutToken); + assertThat(infos).containsExactly(oneSession); + } + + @Test + public void unregisterWhenNoSessionsThenEmptyList() { + InMemoryOidcSessionRegistry sessionRegistry = new InMemoryOidcSessionRegistry(); + OidcIdToken idToken = TestOidcIdTokens.idToken().claim("sid", "provider").build(); + OidcUser user = new DefaultOidcUser(AuthorityUtils.NO_AUTHORITIES, idToken); + OidcSessionInformation info = TestOidcSessionInformations.create("client", user); + sessionRegistry.saveSessionInformation(info); + OidcLogoutToken logoutToken = TestOidcLogoutTokens.withSessionId(idToken.getIssuer().toString(), "wrong") + .build(); + Iterable infos = sessionRegistry.removeSessionInformation(logoutToken); + assertThat(infos).isNotNull(); + assertThat(infos).isEmpty(); + logoutToken = TestOidcLogoutTokens.withSessionId("https://wrong", "provider").build(); + infos = sessionRegistry.removeSessionInformation(logoutToken); + assertThat(infos).isNotNull(); + assertThat(infos).isEmpty(); + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/session/TestOidcSessionRegistrations.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/session/TestOidcSessionInformations.java similarity index 64% rename from oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/session/TestOidcSessionRegistrations.java rename to oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/session/TestOidcSessionInformations.java index 3beb9b9bc5d..47f64868de1 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/session/TestOidcSessionRegistrations.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/session/TestOidcSessionInformations.java @@ -14,28 +14,31 @@ * limitations under the License. */ -package org.springframework.security.oauth2.client.oidc.authentication.session; +package org.springframework.security.oauth2.client.oidc.session; import java.util.Map; import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.core.oidc.user.TestOidcUsers; -public final class TestOidcSessionRegistrations { +/** + * Sample {@link OidcSessionInformation} instances + */ +public final class TestOidcSessionInformations { - public static OidcSessionRegistration create() { + public static OidcSessionInformation create() { return create("sessionId"); } - public static OidcSessionRegistration create(String sessionId) { + public static OidcSessionInformation create(String sessionId) { return create(sessionId, TestOidcUsers.create()); } - public static OidcSessionRegistration create(String sessionId, OidcUser user) { - return new OidcSessionRegistration("client-id", sessionId, Map.of("_csrf", "token"), user); + public static OidcSessionInformation create(String sessionId, OidcUser user) { + return new OidcSessionInformation(sessionId, Map.of("_csrf", "token"), user); } - private TestOidcSessionRegistrations() { + private TestOidcSessionInformations() { } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/web/authentication/logout/OidcBackChannelLogoutFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/web/OidcBackChannelLogoutFilterTests.java similarity index 53% rename from oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/web/authentication/logout/OidcBackChannelLogoutFilterTests.java rename to oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/web/OidcBackChannelLogoutFilterTests.java index 0da06be9f2f..5d87f8baeb4 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/web/authentication/logout/OidcBackChannelLogoutFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/web/OidcBackChannelLogoutFilterTests.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.security.oauth2.client.oidc.web.authentication.logout; +package org.springframework.security.oauth2.client.oidc.web; import java.util.Set; @@ -25,14 +25,16 @@ import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.BadCredentialsException; +import org.springframework.security.oauth2.client.oidc.authentication.logout.OidcBackChannelLogoutAuthentication; import org.springframework.security.oauth2.client.oidc.authentication.logout.OidcLogoutToken; import org.springframework.security.oauth2.client.oidc.authentication.logout.TestOidcLogoutTokens; -import org.springframework.security.oauth2.client.oidc.authentication.session.OidcSessionRegistration; -import org.springframework.security.oauth2.client.oidc.authentication.session.TestOidcSessionRegistrations; +import org.springframework.security.oauth2.client.oidc.session.OidcSessionInformation; +import org.springframework.security.oauth2.client.oidc.session.OidcSessionRegistry; +import org.springframework.security.oauth2.client.oidc.session.TestOidcSessionInformations; +import org.springframework.security.oauth2.client.oidc.web.logout.OidcBackChannelLogoutHandler; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; -import org.springframework.security.web.authentication.logout.BackchannelLogoutAuthentication; import org.springframework.security.web.authentication.logout.LogoutHandler; import static org.assertj.core.api.Assertions.assertThat; @@ -46,91 +48,99 @@ public class OidcBackChannelLogoutFilterTests { @Test public void doFilterRequestDoesNotMatchThenDoesNotRun() throws Exception { - ClientRegistrationRepository clients = mock(ClientRegistrationRepository.class); - AuthenticationManager factory = mock(AuthenticationManager.class); - OidcBackChannelLogoutFilter filter = new OidcBackChannelLogoutFilter(clients, factory); + ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class); + AuthenticationManager authenticationManager = mock(AuthenticationManager.class); + OidcBackChannelLogoutFilter backChannelLogoutFilter = new OidcBackChannelLogoutFilter( + clientRegistrationRepository, authenticationManager); MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain chain = mock(FilterChain.class); - filter.doFilter(request, response, chain); - verifyNoInteractions(clients, factory); + backChannelLogoutFilter.doFilter(request, response, chain); + verifyNoInteractions(clientRegistrationRepository, authenticationManager); verify(chain).doFilter(request, response); } @Test public void doFilterRequestDoesNotMatchContainLogoutTokenThenBadRequest() throws Exception { - ClientRegistration registration = TestClientRegistrations.clientRegistration().build(); - ClientRegistrationRepository clients = mock(ClientRegistrationRepository.class); - given(clients.findByRegistrationId(any())).willReturn(registration); - AuthenticationManager factory = mock(AuthenticationManager.class); - OidcBackChannelLogoutFilter filter = new OidcBackChannelLogoutFilter(clients, factory); + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class); + given(clientRegistrationRepository.findByRegistrationId(any())).willReturn(clientRegistration); + AuthenticationManager authenticationManager = mock(AuthenticationManager.class); + OidcBackChannelLogoutFilter filter = new OidcBackChannelLogoutFilter(clientRegistrationRepository, + authenticationManager); MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/connect/back-channel/id"); request.setServletPath("/logout/connect/back-channel/id"); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain chain = mock(FilterChain.class); filter.doFilter(request, response, chain); - verifyNoInteractions(factory, chain); + verifyNoInteractions(authenticationManager, chain); assertThat(response.getStatus()).isEqualTo(400); } @Test public void doFilterWithNoMatchingClientThenBadRequest() throws Exception { - ClientRegistrationRepository clients = mock(ClientRegistrationRepository.class); - AuthenticationManager factory = mock(AuthenticationManager.class); - OidcBackChannelLogoutFilter filter = new OidcBackChannelLogoutFilter(clients, factory); + ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class); + AuthenticationManager authenticationManager = mock(AuthenticationManager.class); + OidcBackChannelLogoutFilter backChannelLogoutFilter = new OidcBackChannelLogoutFilter( + clientRegistrationRepository, authenticationManager); MockHttpServletRequest request = new MockHttpServletRequest("POST", "/logout/connect/back-channel/id"); request.setServletPath("/logout/connect/back-channel/id"); request.setParameter("logout_token", "logout_token"); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain chain = mock(FilterChain.class); - filter.doFilter(request, response, chain); - verify(clients).findByRegistrationId("id"); - verifyNoInteractions(factory, chain); + backChannelLogoutFilter.doFilter(request, response, chain); + verify(clientRegistrationRepository).findByRegistrationId("id"); + verifyNoInteractions(authenticationManager, chain); assertThat(response.getStatus()).isEqualTo(400); } @Test public void doFilterWithSessionMatchingLogoutTokenThenInvalidates() throws Exception { - ClientRegistration registration = TestClientRegistrations.clientRegistration().build(); - ClientRegistrationRepository clients = mock(ClientRegistrationRepository.class); - given(clients.findByRegistrationId(any())).willReturn(registration); - AuthenticationManager factory = mock(AuthenticationManager.class); + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class); + given(clientRegistrationRepository.findByRegistrationId(any())).willReturn(clientRegistration); + AuthenticationManager authenticationManager = mock(AuthenticationManager.class); OidcLogoutToken token = TestOidcLogoutTokens.withSessionId("issuer", "provider").build(); - Iterable infos = Set.of(TestOidcSessionRegistrations.create("clientOne"), - TestOidcSessionRegistrations.create("clientTwo")); - given(factory.authenticate(any())).willReturn(new BackchannelLogoutAuthentication(token, token, infos)); - LogoutHandler logoutHandler = mock(LogoutHandler.class); - OidcBackChannelLogoutFilter filter = new OidcBackChannelLogoutFilter(clients, factory); - filter.setLogoutHandler(logoutHandler); + Iterable infos = Set.of(TestOidcSessionInformations.create("clientOne"), + TestOidcSessionInformations.create("clientTwo")); + given(authenticationManager.authenticate(any())).willReturn(new OidcBackChannelLogoutAuthentication(token)); + OidcBackChannelLogoutHandler backChannelLogoutHandler = new OidcBackChannelLogoutHandler(); + OidcSessionRegistry sessionRegistry = mock(OidcSessionRegistry.class); + given(sessionRegistry.removeSessionInformation(any(OidcLogoutToken.class))).willReturn(infos); + backChannelLogoutHandler.setSessionRegistry(sessionRegistry); + OidcBackChannelLogoutFilter filter = new OidcBackChannelLogoutFilter(clientRegistrationRepository, + authenticationManager); + filter.setLogoutHandler(backChannelLogoutHandler); MockHttpServletRequest request = new MockHttpServletRequest("POST", - "/oauth2/" + registration.getRegistrationId() + "/logout"); + "/oauth2/" + clientRegistration.getRegistrationId() + "/logout"); request.setServletPath("/logout/connect/back-channel/id"); request.setParameter("logout_token", "logout_token"); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain chain = mock(FilterChain.class); filter.doFilter(request, response, chain); - verify(logoutHandler).logout(any(), any(), any()); + verify(sessionRegistry).removeSessionInformation(token); verifyNoInteractions(chain); assertThat(response.getStatus()).isEqualTo(200); } @Test public void doFilterWhenInvalidJwtThenBadRequest() throws Exception { - ClientRegistration registration = TestClientRegistrations.clientRegistration().build(); - ClientRegistrationRepository clients = mock(ClientRegistrationRepository.class); - given(clients.findByRegistrationId(any())).willReturn(registration); - AuthenticationManager factory = mock(AuthenticationManager.class); - given(factory.authenticate(any())).willThrow(new BadCredentialsException("bad")); + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class); + given(clientRegistrationRepository.findByRegistrationId(any())).willReturn(clientRegistration); + AuthenticationManager authenticationManager = mock(AuthenticationManager.class); + given(authenticationManager.authenticate(any())).willThrow(new BadCredentialsException("bad")); LogoutHandler logoutHandler = mock(LogoutHandler.class); - OidcBackChannelLogoutFilter filter = new OidcBackChannelLogoutFilter(clients, factory); - filter.setLogoutHandler(logoutHandler); + OidcBackChannelLogoutFilter backChannelLogoutFilter = new OidcBackChannelLogoutFilter( + clientRegistrationRepository, authenticationManager); + backChannelLogoutFilter.setLogoutHandler(logoutHandler); MockHttpServletRequest request = new MockHttpServletRequest("POST", - "/oauth2/" + registration.getRegistrationId() + "/logout"); + "/oauth2/" + clientRegistration.getRegistrationId() + "/logout"); request.setServletPath("/logout/connect/back-channel/id"); request.setParameter("logout_token", "logout_token"); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain chain = mock(FilterChain.class); - filter.doFilter(request, response, chain); + backChannelLogoutFilter.doFilter(request, response, chain); verifyNoInteractions(logoutHandler, chain); assertThat(response.getStatus()).isEqualTo(400); assertThat(response.getContentAsString()).contains("bad"); diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/TestOidcIdTokens.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/TestOidcIdTokens.java index ca859473d1e..2271a52e00f 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/TestOidcIdTokens.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/TestOidcIdTokens.java @@ -17,6 +17,7 @@ package org.springframework.security.oauth2.core.oidc; import java.time.Instant; +import java.util.List; /** * Test {@link OidcIdToken}s @@ -32,6 +33,7 @@ public static OidcIdToken.Builder idToken() { // @formatter:off return OidcIdToken.withTokenValue("id-token") .issuer("https://example.com") + .audience(List.of("client-id")) .subject("subject") .issuedAt(Instant.now()) .expiresAt(Instant.now() diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/user/TestOidcUsers.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/user/TestOidcUsers.java index 3bda7ec32d7..ca2c37abf78 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/user/TestOidcUsers.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/user/TestOidcUsers.java @@ -50,7 +50,7 @@ private static OidcIdToken idToken() { .expiresAt(expiresAt) .subject("subject") .issuer("http://localhost/issuer") - .audience(Collections.unmodifiableSet(new LinkedHashSet<>(Collections.singletonList("client")))) + .audience(Collections.unmodifiableSet(new LinkedHashSet<>(Collections.singletonList("client-id")))) .authorizedParty("client") .build(); // @formatter:on diff --git a/web/src/main/java/org/springframework/security/web/authentication/logout/BackchannelLogoutAuthentication.java b/web/src/main/java/org/springframework/security/web/authentication/logout/BackchannelLogoutAuthentication.java deleted file mode 100644 index c61488bddb7..00000000000 --- a/web/src/main/java/org/springframework/security/web/authentication/logout/BackchannelLogoutAuthentication.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright 2002-2023 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.web.authentication.logout; - -import java.util.Collections; - -import org.springframework.security.authentication.AbstractAuthenticationToken; -import org.springframework.security.core.session.SessionInformation; -import org.springframework.util.Assert; - -public class BackchannelLogoutAuthentication extends AbstractAuthenticationToken { - - private final Object principal; - - private final Object credentials; - - private final Iterable sessions; - - public BackchannelLogoutAuthentication(Object principal, Object credentials, - Iterable sessions) { - super(Collections.emptyList()); - Assert.notNull(sessions, "sessions cannot be null"); - this.sessions = sessions; - this.principal = principal; - this.credentials = credentials; - setAuthenticated(true); - } - - @Override - public Object getPrincipal() { - return this.principal; - } - - @Override - public Object getCredentials() { - return this.credentials; - } - - public Iterable getSessions() { - return this.sessions; - } - -} diff --git a/web/src/main/java/org/springframework/security/web/authentication/logout/BackchannelLogoutHandler.java b/web/src/main/java/org/springframework/security/web/authentication/logout/BackchannelLogoutHandler.java deleted file mode 100644 index 2e16a339e67..00000000000 --- a/web/src/main/java/org/springframework/security/web/authentication/logout/BackchannelLogoutHandler.java +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Copyright 2002-2023 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.web.authentication.logout; - -import java.util.Map; - -import jakarta.servlet.http.HttpServletRequest; -import jakarta.servlet.http.HttpServletResponse; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; - -import org.springframework.http.HttpEntity; -import org.springframework.http.HttpHeaders; -import org.springframework.security.core.Authentication; -import org.springframework.security.core.session.SessionInformation; -import org.springframework.util.Assert; -import org.springframework.web.client.RestClientException; -import org.springframework.web.client.RestOperations; -import org.springframework.web.client.RestTemplate; -import org.springframework.web.util.UriComponentsBuilder; - -public final class BackchannelLogoutHandler implements LogoutHandler { - - private final Log logger = LogFactory.getLog(getClass()); - - private RestOperations rest = new RestTemplate(); - - private String logoutEndpointName = "/logout"; - - private String clientSessionCookieName = "JSESSIONID"; - - @Override - public void logout(HttpServletRequest request, HttpServletResponse response, Authentication authentication) { - if (!(authentication instanceof BackchannelLogoutAuthentication token)) { - if (this.logger.isDebugEnabled()) { - String message = "Did not perform Backchannel Logout since authentication [%s] was of the wrong type"; - this.logger.debug(String.format(message, authentication.getClass().getSimpleName())); - } - return; - } - Iterable sessions = token.getSessions(); - for (SessionInformation session : sessions) { - eachLogout(request, session); - } - } - - private void eachLogout(HttpServletRequest request, SessionInformation session) { - HttpHeaders headers = new HttpHeaders(); - headers.add(HttpHeaders.COOKIE, this.clientSessionCookieName + "=" + session.getSessionId()); - for (Map.Entry credential : session.getHeaders().entrySet()) { - headers.add(credential.getKey(), credential.getValue()); - } - String url = request.getRequestURL().toString(); - String logout = UriComponentsBuilder.fromHttpUrl(url).replacePath(this.logoutEndpointName).build() - .toUriString(); - HttpEntity entity = new HttpEntity<>(null, headers); - try { - this.rest.postForEntity(logout, entity, Object.class); - if (this.logger.isTraceEnabled()) { - this.logger.trace("Invalidated session"); - } - } - catch (RestClientException ex) { - this.logger.debug("Failed to invalidate session", ex); - } - } - - public void setRestOperations(RestOperations rest) { - Assert.notNull(rest, "rest cannot be null"); - this.rest = rest; - } - - public void setLogoutEndpointName(String logoutEndpointName) { - Assert.hasText(logoutEndpointName, "logoutEndpointName cannot be empty"); - this.logoutEndpointName = logoutEndpointName; - } - - public void setClientSessionCookieName(String clientSessionCookieName) { - Assert.hasText(clientSessionCookieName, "clientSessionCookieName cannot be empty"); - this.clientSessionCookieName = clientSessionCookieName; - } - -}