Skip to content

Commit

Permalink
Move OidcSessionRegistry Login Configuration to oauth2Login
Browse files Browse the repository at this point in the history
  • Loading branch information
jzheaux committed Jul 25, 2023
1 parent 556b242 commit 3e14022
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 163 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer;
import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.oidc.session.InMemoryOidcSessionRegistry;
import org.springframework.security.oauth2.client.oidc.session.OidcSessionRegistry;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
Expand Down Expand Up @@ -112,4 +114,13 @@ private static <B extends HttpSecurityBuilder<B>> OAuth2AuthorizedClientService
return (!authorizedClientServiceMap.isEmpty() ? authorizedClientServiceMap.values().iterator().next() : null);
}

static <B extends HttpSecurityBuilder<B>> OidcSessionRegistry getOidcSessionRegistry(B builder) {
OidcSessionRegistry sessionRegistry = builder.getSharedObject(OidcSessionRegistry.class);
if (sessionRegistry == null) {
sessionRegistry = new InMemoryOidcSessionRegistry();
builder.setSharedObject(OidcSessionRegistry.class, sessionRegistry);
}
return sessionRegistry;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,43 @@
import java.util.LinkedHashMap;
import java.util.Map;

import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.HttpSession;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.beans.factory.BeanFactoryUtils;
import org.springframework.beans.factory.NoUniqueBeanDefinitionException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationListener;
import org.springframework.context.event.GenericApplicationListenerAdapter;
import org.springframework.context.event.SmartApplicationListener;
import org.springframework.core.ResolvableType;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.config.Customizer;
import org.springframework.security.config.annotation.web.HttpSecurityBuilder;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configurers.AbstractAuthenticationFilterConfigurer;
import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer;
import org.springframework.security.config.annotation.web.configurers.SessionManagementConfigurer;
import org.springframework.security.context.DelegatingApplicationListener;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
import org.springframework.security.core.session.AbstractSessionEvent;
import org.springframework.security.core.session.SessionDestroyedEvent;
import org.springframework.security.core.session.SessionIdChangedEvent;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationProvider;
import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken;
import org.springframework.security.oauth2.client.endpoint.DefaultAuthorizationCodeTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
import org.springframework.security.oauth2.client.oidc.authentication.OidcAuthorizationCodeAuthenticationProvider;
import org.springframework.security.oauth2.client.oidc.session.InMemoryOidcSessionRegistry;
import org.springframework.security.oauth2.client.oidc.session.OidcSessionInformation;
import org.springframework.security.oauth2.client.oidc.session.OidcSessionRegistry;
import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest;
import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserService;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
Expand All @@ -67,7 +84,10 @@
import org.springframework.security.web.RedirectStrategy;
import org.springframework.security.web.authentication.DelegatingAuthenticationEntryPoint;
import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint;
import org.springframework.security.web.authentication.session.SessionAuthenticationException;
import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy;
import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.savedrequest.RequestCache;
import org.springframework.security.web.util.matcher.AndRequestMatcher;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
Expand Down Expand Up @@ -124,6 +144,7 @@
* <li>{@link DefaultLoginPageGeneratingFilter} - if {@link #loginPage(String)} is not
* configured and {@code DefaultLoginPageGeneratingFilter} is available, then a default
* login page will be made available</li>
* <li>{@link OidcSessionRegistry}</li>
* </ul>
*
* @author Joe Grandja
Expand Down Expand Up @@ -202,6 +223,17 @@ public OAuth2LoginConfigurer<B> loginProcessingUrl(String loginProcessingUrl) {
return this;
}

/**
* Sets the registry for managing the OIDC client-provider session link
* @param sessionRegistry the {@link OidcSessionRegistry} to use
* @return the {@link OAuth2LoginConfigurer} for further configuration
*/
public OAuth2LoginConfigurer<B> oidcSessionRegistry(OidcSessionRegistry sessionRegistry) {
Assert.notNull(sessionRegistry, "sessionRegistry cannot be null");
this.getBuilder().setSharedObject(OidcSessionRegistry.class, sessionRegistry);
return this;
}

/**
* Returns the {@link AuthorizationEndpointConfig} for configuring the Authorization
* Server's Authorization Endpoint.
Expand Down Expand Up @@ -400,6 +432,7 @@ public void configure(B http) throws Exception {
authenticationFilter
.setAuthorizationRequestRepository(this.authorizationEndpointConfig.authorizationRequestRepository);
}
configureOidcSessionRegistry(http);
super.configure(http);
}

Expand Down Expand Up @@ -539,6 +572,29 @@ private RequestMatcher getFormLoginNotEnabledRequestMatcher(B http) {
return AnyRequestMatcher.INSTANCE;
}

private void configureOidcSessionRegistry(B http) {
OidcSessionRegistry sessionRegistry = OAuth2ClientConfigurerUtils.getOidcSessionRegistry(http);
SessionManagementConfigurer<B> sessionConfigurer = http.getConfigurer(SessionManagementConfigurer.class);
if (sessionConfigurer != null) {
OidcSessionRegistryAuthenticationStrategy sessionAuthenticationStrategy = new OidcSessionRegistryAuthenticationStrategy();
sessionAuthenticationStrategy.setSessionRegistry(sessionRegistry);
sessionConfigurer.addSessionAuthenticationStrategy(sessionAuthenticationStrategy);
}
OidcClientSessionEventListener listener = new OidcClientSessionEventListener();
listener.setSessionRegistry(sessionRegistry);
registerDelegateApplicationListener(listener);
}

private void registerDelegateApplicationListener(ApplicationListener<?> delegate) {
DelegatingApplicationListener delegating = getBeanOrNull(
ResolvableType.forType(DelegatingApplicationListener.class));
if (delegating == null) {
return;
}
SmartApplicationListener smartListener = new GenericApplicationListenerAdapter(delegate);
delegating.addListener(smartListener);
}

/**
* Configuration options for the Authorization Server's Authorization Endpoint.
*/
Expand Down Expand Up @@ -786,4 +842,83 @@ public boolean supports(Class<?> authentication) {

}

private static final class OidcClientSessionEventListener implements ApplicationListener<AbstractSessionEvent> {

private final Log logger = LogFactory.getLog(OidcClientSessionEventListener.class);

private OidcSessionRegistry sessionRegistry = new InMemoryOidcSessionRegistry();

/**
* {@inheritDoc}
*/
@Override
public void onApplicationEvent(AbstractSessionEvent event) {
if (event instanceof SessionDestroyedEvent destroyed) {
this.logger.debug("Received SessionDestroyedEvent");
this.sessionRegistry.removeSessionInformation(destroyed.getId());
return;
}
if (event instanceof SessionIdChangedEvent changed) {
this.logger.debug("Received SessionIdChangedEvent");
OidcSessionInformation information = this.sessionRegistry.removeSessionInformation(changed.getOldSessionId());
if (information == null) {
this.logger.debug("Failed to register new session id since old session id was not found in registry");
return;
}
this.sessionRegistry.saveSessionInformation(information.withSessionId(changed.getNewSessionId()));
}
}

/**
* The registry where OIDC Provider sessions are linked to the Client session.
* Defaults to in-memory storage.
* @param sessionRegistry the {@link OidcSessionRegistry} to use
*/
void setSessionRegistry(OidcSessionRegistry sessionRegistry) {
Assert.notNull(sessionRegistry, "sessionRegistry cannot be null");
this.sessionRegistry = sessionRegistry;
}

}

private static final class OidcSessionRegistryAuthenticationStrategy implements SessionAuthenticationStrategy {

private final Log logger = LogFactory.getLog(getClass());

private OidcSessionRegistry sessionRegistry = new InMemoryOidcSessionRegistry();

/**
* {@inheritDoc}
*/
@Override
public void onAuthentication(Authentication authentication, HttpServletRequest request, HttpServletResponse response) throws SessionAuthenticationException {
HttpSession session = request.getSession(false);
if (session == null) {
return;
}
if (!(authentication.getPrincipal() instanceof OidcUser user)) {
return;
}
String sessionId = session.getId();
CsrfToken csrfToken = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
Map<String, String> headers = (csrfToken != null) ? Map.of(csrfToken.getHeaderName(), csrfToken.getToken()) : Collections.emptyMap();
OidcSessionInformation registration = new OidcSessionInformation(sessionId, headers, user);
if (this.logger.isTraceEnabled()) {
this.logger.trace(String.format("Linking a provider [%s] session to this client's session", user.getIssuer()));
}
this.sessionRegistry.saveSessionInformation(registration);
}

/**
* The registration for linking OIDC Provider Session information to the Client's
* session. Defaults to in-memory storage.
* @param sessionRegistry the {@link OidcSessionRegistry} to use
*/
void setSessionRegistry(OidcSessionRegistry sessionRegistry) {
Assert.notNull(sessionRegistry, "sessionRegistry cannot be null");
this.sessionRegistry = sessionRegistry;
}

}

}
Loading

0 comments on commit 3e14022

Please sign in to comment.