Skip to content

Commit

Permalink
Use Only AuthenticationManager
Browse files Browse the repository at this point in the history
I've put together this PR to allow folks to play with the
idea of using only AuthenticationManager in the logout filter.

I've articulated what I see as the abstraction limitations of
this approach in spring-projects#13767
  • Loading branch information
jzheaux committed Sep 5, 2023
1 parent 95960e8 commit 03646fe
Show file tree
Hide file tree
Showing 9 changed files with 125 additions and 244 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,9 @@
import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer;
import org.springframework.security.oauth2.client.oidc.authentication.logout.OidcBackChannelLogoutAuthenticationProvider;
import org.springframework.security.oauth2.client.oidc.web.OidcBackChannelLogoutFilter;
import org.springframework.security.oauth2.client.oidc.web.logout.OidcBackChannelLogoutHandler;
import org.springframework.security.oauth2.client.oidc.web.logout.OidcLogoutAuthenticationConverter;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.security.web.authentication.logout.LogoutHandler;
import org.springframework.security.web.csrf.CsrfFilter;
import org.springframework.util.Assert;

Expand Down Expand Up @@ -100,10 +98,7 @@ public final class BackChannelLogoutConfigurer {

private AuthenticationConverter authenticationConverter;

private AuthenticationManager authenticationManager = new ProviderManager(
new OidcBackChannelLogoutAuthenticationProvider());

private LogoutHandler logoutHandler;
private AuthenticationManager authenticationManager;

/**
* Use this {@link AuthenticationConverter} to extract the Logout Token from the
Expand All @@ -128,17 +123,6 @@ public BackChannelLogoutConfigurer authenticationManager(AuthenticationManager a
return this;
}

/**
* Use this {@link LogoutHandler} for invalidating each session identified by the
* OIDC Back-Channel Logout Token
* @return the {@link BackChannelLogoutConfigurer} for further configuration
*/
public BackChannelLogoutConfigurer logoutHandler(LogoutHandler logoutHandler) {
Assert.notNull(logoutHandler, "logoutHandler cannot be null");
this.logoutHandler = logoutHandler;
return this;
}

private AuthenticationConverter authenticationConverter(B http) {
if (this.authenticationConverter == null) {
ClientRegistrationRepository clientRegistrationRepository = OAuth2ClientConfigurerUtils
Expand All @@ -148,23 +132,18 @@ private AuthenticationConverter authenticationConverter(B http) {
return this.authenticationConverter;
}

private AuthenticationManager authenticationManager() {
return this.authenticationManager;
}

private LogoutHandler logoutHandler(B http) {
if (this.logoutHandler == null) {
OidcBackChannelLogoutHandler logoutHandler = new OidcBackChannelLogoutHandler();
logoutHandler.setSessionRegistry(OAuth2ClientConfigurerUtils.getOidcSessionRegistry(http));
this.logoutHandler = logoutHandler;
private AuthenticationManager authenticationManager(B http) {
if (this.authenticationManager == null) {
OidcBackChannelLogoutAuthenticationProvider authenticationProvider = new OidcBackChannelLogoutAuthenticationProvider();
authenticationProvider.setSessionRegistry(OAuth2ClientConfigurerUtils.getOidcSessionRegistry(http));
this.authenticationManager = new ProviderManager(authenticationProvider);
}
return this.logoutHandler;
return this.authenticationManager;
}

void configure(B http) {
OidcBackChannelLogoutFilter filter = new OidcBackChannelLogoutFilter(authenticationConverter(http),
authenticationManager());
filter.setLogoutHandler(logoutHandler(http));
authenticationManager(http));
http.addFilterBefore(filter, CsrfFilter.class);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.mock.web.MockHttpSession;
import org.springframework.mock.web.MockServletContext;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.authentication.ProviderManager;
import org.springframework.security.config.Customizer;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
Expand All @@ -62,14 +63,13 @@
import org.springframework.security.core.userdetails.User;
import org.springframework.security.core.userdetails.UserDetailsService;
import org.springframework.security.oauth2.client.oidc.authentication.logout.LogoutTokenClaimNames;
import org.springframework.security.oauth2.client.oidc.authentication.logout.OidcBackChannelLogoutAuthentication;
import org.springframework.security.oauth2.client.oidc.authentication.logout.OidcBackChannelLogoutAuthenticationProvider;
import org.springframework.security.oauth2.client.oidc.authentication.logout.OidcLogoutAuthenticationToken;
import org.springframework.security.oauth2.client.oidc.authentication.logout.OidcLogoutToken;
import org.springframework.security.oauth2.client.oidc.authentication.logout.TestOidcLogoutTokens;
import org.springframework.security.oauth2.client.oidc.session.OidcSessionInformation;
import org.springframework.security.oauth2.client.oidc.session.OidcSessionRegistry;
import org.springframework.security.oauth2.client.oidc.session.TestOidcSessionInformations;
import org.springframework.security.oauth2.client.oidc.web.logout.OidcBackChannelLogoutHandler;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository;
Expand All @@ -78,13 +78,15 @@
import org.springframework.security.oauth2.core.oidc.TestOidcIdTokens;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.security.oauth2.jwt.JwtClaimsSet;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
import org.springframework.security.oauth2.jwt.JwtEncoder;
import org.springframework.security.oauth2.jwt.JwtEncoderParameters;
import org.springframework.security.oauth2.jwt.NimbusJwtEncoder;
import org.springframework.security.oauth2.jwt.TestJwts;
import org.springframework.security.provisioning.InMemoryUserDetailsManager;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.security.web.authentication.logout.LogoutHandler;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult;
import org.springframework.test.web.servlet.ResultActions;
Expand Down Expand Up @@ -175,19 +177,21 @@ void logoutWhenCustomComponentsThenUses() throws Exception {
String registrationId = this.clientRegistration.getRegistrationId();
AuthenticationConverter authenticationConverter = this.spring.getContext()
.getBean(AuthenticationConverter.class);
given(authenticationConverter.convert(any()))
.willReturn(new OidcLogoutAuthenticationToken("token", this.clientRegistration));
AuthenticationManager authenticationManager = this.spring.getContext().getBean(AuthenticationManager.class);
given(authenticationConverter.convert(any())).willReturn(new OidcLogoutAuthenticationToken("token",
this.clientRegistration, "http://localhost/logout/connect/back-channel/" + registrationId));
OidcLogoutToken logoutToken = TestOidcLogoutTokens.withSessionId("issuer", "provider").build();
given(authenticationManager.authenticate(any()))
.willReturn(new OidcBackChannelLogoutAuthentication(logoutToken));
OidcSessionRegistry sessionRegistry = this.spring.getContext().getBean(OidcSessionRegistry.class);
JwtDecoderFactory<ClientRegistration> decoderFactory = this.spring.getContext()
.getBean(JwtDecoderFactory.class);
JwtDecoder decoder = mock(JwtDecoder.class);
given(decoder.decode(any()))
.willReturn(TestJwts.jwt().claims((claims) -> claims.putAll(logoutToken.getClaims())).build());
given(decoderFactory.createDecoder(any())).willReturn(decoder);
Set<OidcSessionInformation> details = Set.of(TestOidcSessionInformations.create());
given(sessionRegistry.removeSessionInformation(logoutToken)).willReturn(details);
OidcSessionRegistry sessionRegistry = this.spring.getContext().getBean(OidcSessionRegistry.class);
given(sessionRegistry.removeSessionInformation(any(OidcLogoutToken.class))).willReturn(details);
this.mvc.perform(post("/logout/connect/back-channel/" + registrationId).param("logout_token", "token"))
.andExpect(status().isOk());
verify(authenticationManager).authenticate(any());
verify(this.spring.getContext().getBean(LogoutHandler.class)).logout(any(), any(), any());
verify(decoder).decode(any());
verify(sessionRegistry).removeSessionInformation(logoutToken);
}

Expand Down Expand Up @@ -241,24 +245,25 @@ static class WithCustomComponentsConfig {

AuthenticationConverter authenticationConverter = mock(AuthenticationConverter.class);

AuthenticationManager authenticationManager = mock(AuthenticationManager.class);
OidcBackChannelLogoutAuthenticationProvider authenticationProvider = spy(
new OidcBackChannelLogoutAuthenticationProvider());

OidcSessionRegistry sessionRegistry = mock(OidcSessionRegistry.class);
JwtDecoderFactory<ClientRegistration> decoderFactory = mock(JwtDecoderFactory.class);

OidcBackChannelLogoutHandler logoutHandler = spy(new OidcBackChannelLogoutHandler());
OidcSessionRegistry sessionRegistry = mock(OidcSessionRegistry.class);

@Bean
@Order(1)
SecurityFilterChain filters(HttpSecurity http) throws Exception {
this.logoutHandler.setSessionRegistry(this.sessionRegistry);
this.authenticationProvider.setSessionRegistry(this.sessionRegistry);
this.authenticationProvider.setLogoutTokenDecoderFactory(this.decoderFactory);
// @formatter:off
http
.authorizeHttpRequests((authorize) -> authorize.anyRequest().authenticated())
.oauth2Login((oauth2) -> oauth2.oidcSessionRegistry(this.sessionRegistry))
.oidcLogout((oidc) -> oidc.backChannel((logout) -> logout
.authenticationConverter(this.authenticationConverter)
.authenticationManager(this.authenticationManager)
.logoutHandler(this.logoutHandler)
.authenticationManager(new ProviderManager(this.authenticationProvider))
));
// @formatter:on

Expand All @@ -271,8 +276,8 @@ AuthenticationConverter authenticationConverter() {
}

@Bean
AuthenticationManager authenticationManager() {
return this.authenticationManager;
AuthenticationProvider authenticationProvider() {
return this.authenticationProvider;
}

@Bean
Expand All @@ -281,8 +286,8 @@ OidcSessionRegistry sessionRegistry() {
}

@Bean
LogoutHandler logoutHandler() {
return this.logoutHandler;
JwtDecoderFactory<ClientRegistration> jwtDecoderFactory() {
return this.decoderFactory;
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.util.Collections;

import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.oauth2.client.oidc.session.OidcSessionInformation;

/**
* An {@link org.springframework.security.core.Authentication} implementation that
Expand All @@ -36,13 +37,17 @@ public class OidcBackChannelLogoutAuthentication extends AbstractAuthenticationT

private final OidcLogoutToken logoutToken;

private final Iterable<OidcSessionInformation> invalidated;

/**
* Construct an {@link OidcBackChannelLogoutAuthentication}
* @param logoutToken a deserialized, verified OIDC Logout Token
*/
public OidcBackChannelLogoutAuthentication(OidcLogoutToken logoutToken) {
public OidcBackChannelLogoutAuthentication(OidcLogoutToken logoutToken,
Iterable<OidcSessionInformation> invalidated) {
super(Collections.emptyList());
this.logoutToken = logoutToken;
this.invalidated = invalidated;
setAuthenticated(true);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,23 @@

package org.springframework.security.oauth2.client.oidc.authentication.logout;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Map;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.authentication.AuthenticationServiceException;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.client.oidc.authentication.OidcIdTokenDecoderFactory;
import org.springframework.security.oauth2.client.oidc.session.InMemoryOidcSessionRegistry;
import org.springframework.security.oauth2.client.oidc.session.OidcSessionInformation;
import org.springframework.security.oauth2.client.oidc.session.OidcSessionRegistry;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
Expand All @@ -30,6 +42,10 @@
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
import org.springframework.util.Assert;
import org.springframework.web.client.RestClientException;
import org.springframework.web.client.RestOperations;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;

/**
* An {@link AuthenticationProvider} that authenticates an OIDC Logout Token; namely
Expand All @@ -49,8 +65,18 @@
*/
public final class OidcBackChannelLogoutAuthenticationProvider implements AuthenticationProvider {

private final Log logger = LogFactory.getLog(getClass());

private JwtDecoderFactory<ClientRegistration> logoutTokenDecoderFactory;

private OidcSessionRegistry sessionRegistry = new InMemoryOidcSessionRegistry();

private RestOperations restOperations = new RestTemplate();

private String logoutEndpointName = "/logout";

private String sessionCookieName = "JSESSIONID";

/**
* Construct an {@link OidcBackChannelLogoutAuthenticationProvider}
*/
Expand All @@ -73,7 +99,8 @@ public Authentication authenticate(Authentication authentication) throws Authent
Jwt jwt = decode(registration, logoutToken);
OidcLogoutToken oidcLogoutToken = OidcLogoutToken.withTokenValue(logoutToken)
.claims((claims) -> claims.putAll(jwt.getClaims())).build();
return new OidcBackChannelLogoutAuthentication(oidcLogoutToken);
Collection<OidcSessionInformation> loggedOut = logout(token.getBaseUrl(), oidcLogoutToken);
return new OidcBackChannelLogoutAuthentication(oidcLogoutToken, loggedOut);
}

/**
Expand All @@ -99,6 +126,40 @@ private Jwt decode(ClientRegistration registration, String token) {
}
}

private Collection<OidcSessionInformation> logout(String baseUrl, OidcLogoutToken token) {
Iterable<OidcSessionInformation> sessions = this.sessionRegistry.removeSessionInformation(token);
Collection<OidcSessionInformation> invalidated = new ArrayList<>();
int totalCount = 0;
int invalidatedCount = 0;
for (OidcSessionInformation session : sessions) {
totalCount++;
try {
eachLogout(baseUrl, session);
invalidated.add(session);
invalidatedCount++;
}
catch (RestClientException ex) {
this.logger.debug("Failed to invalidate session", ex);
}
}
if (this.logger.isTraceEnabled()) {
this.logger.trace(String.format("Invalidated %d out of %d sessions", invalidatedCount, totalCount));
}
return invalidated;
}

private void eachLogout(String baseUrl, OidcSessionInformation session) {
HttpHeaders headers = new HttpHeaders();
headers.add(HttpHeaders.COOKIE, this.sessionCookieName + "=" + session.getSessionId());
for (Map.Entry<String, String> credential : session.getAuthorities().entrySet()) {
headers.add(credential.getKey(), credential.getValue());
}
String logout = UriComponentsBuilder.fromHttpUrl(baseUrl).replacePath(this.logoutEndpointName).build()
.toUriString();
HttpEntity<?> entity = new HttpEntity<>(null, headers);
this.restOperations.postForEntity(logout, entity, Object.class);
}

/**
* Use this {@link JwtDecoderFactory} to generate {@link JwtDecoder}s that correspond
* to the {@link ClientRegistration} associated with the OIDC logout token.
Expand All @@ -109,4 +170,8 @@ public void setLogoutTokenDecoderFactory(JwtDecoderFactory<ClientRegistration> l
this.logoutTokenDecoderFactory = logoutTokenDecoderFactory;
}

public void setSessionRegistry(OidcSessionRegistry sessionRegistry) {
this.sessionRegistry = sessionRegistry;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,19 @@ public class OidcLogoutAuthenticationToken extends AbstractAuthenticationToken {

private final ClientRegistration clientRegistration;

private final String baseUrl;

/**
* Construct an {@link OidcLogoutAuthenticationToken}
* @param logoutToken a signed, serialized OIDC Logout token
* @param clientRegistration the {@link ClientRegistration client} associated with
* this token; this is usually derived from material in the logout HTTP request
*/
public OidcLogoutAuthenticationToken(String logoutToken, ClientRegistration clientRegistration) {
public OidcLogoutAuthenticationToken(String logoutToken, ClientRegistration clientRegistration, String baseUrl) {
super(AuthorityUtils.NO_AUTHORITIES);
this.logoutToken = logoutToken;
this.clientRegistration = clientRegistration;
this.baseUrl = baseUrl;
}

/**
Expand Down Expand Up @@ -77,4 +80,8 @@ public ClientRegistration getClientRegistration() {
return this.clientRegistration;
}

public String getBaseUrl() {
return this.baseUrl;
}

}
Loading

0 comments on commit 03646fe

Please sign in to comment.