diff --git a/pom.xml b/pom.xml index df6bb8e..37198a7 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ com.gw2auth oauth2-server - 1.80.0 + 1.81.0 jar diff --git a/src/main/java/com/gw2auth/oauth2/server/configuration/OAuth2ServerConfiguration.java b/src/main/java/com/gw2auth/oauth2/server/configuration/OAuth2ServerConfiguration.java index a0e954a..5ec4e75 100644 --- a/src/main/java/com/gw2auth/oauth2/server/configuration/OAuth2ServerConfiguration.java +++ b/src/main/java/com/gw2auth/oauth2/server/configuration/OAuth2ServerConfiguration.java @@ -2,7 +2,6 @@ import com.gw2auth.oauth2.server.adapt.CustomOAuth2ServerAuthenticationProviders; import com.gw2auth.oauth2.server.service.application.AuthorizationCodeParamAccessor; -import com.gw2auth.oauth2.server.util.ComposedMDCCloseable; import com.gw2auth.oauth2.server.util.JWKHelper; import com.nimbusds.jose.JOSEException; import com.nimbusds.jose.jwk.source.JWKSource; @@ -13,6 +12,7 @@ import jakarta.servlet.http.HttpServletResponse; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.slf4j.MDC; import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.beans.factory.annotation.Value; import org.springframework.context.annotation.Bean; @@ -28,25 +28,33 @@ import org.springframework.security.config.annotation.web.configurers.SecurityContextConfigurer; import org.springframework.security.config.annotation.web.configurers.oauth2.client.OAuth2LoginConfigurer; import org.springframework.security.config.http.SessionCreationPolicy; -import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationConsentAuthenticationProvider; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.config.annotation.web.configurers.OAuth2AuthorizationServerConfigurer; import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings; +import org.springframework.security.oauth2.server.authorization.web.authentication.DelegatingAuthenticationConverter; +import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AccessTokenResponseAuthenticationSuccessHandler; +import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2ErrorAuthenticationFailureHandler; import org.springframework.security.web.SecurityFilterChain; +import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.security.web.authentication.AuthenticationFailureHandler; +import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import org.springframework.security.web.context.SecurityContextHolderFilter; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.OrRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.web.filter.OncePerRequestFilter; -import org.springframework.web.util.UriComponents; -import org.springframework.web.util.UriComponentsBuilder; import java.io.IOException; import java.security.GeneralSecurityException; import java.security.KeyPair; import java.util.*; -import java.util.function.Function; @Configuration public class OAuth2ServerConfiguration { @@ -92,6 +100,19 @@ public OAuth2AuthorizationServerConfigurer oAuth2AuthorizationServerConfigurer(H .consentPage(OAUTH2_CONSENT_PAGE); }); + authorizationServerConfigurer.tokenEndpoint((tokenEndpoint) -> { + final OAuth2TokenResponseHandler handler = new OAuth2TokenResponseHandler(); + + tokenEndpoint.accessTokenRequestConverters((accessTokenRequestConverters) -> { + handler.setAuthenticationConverters(accessTokenRequestConverters); + accessTokenRequestConverters.clear(); + accessTokenRequestConverters.add(handler); + }); + + tokenEndpoint.accessTokenResponseHandler(handler); + tokenEndpoint.errorResponseHandler(handler); + }); + return authorizationServerConfigurer; } @@ -119,8 +140,7 @@ public SecurityFilterChain oauth2ServerHttpSecurityFilterChain(HttpSecurity http .securityContext(securityContextCustomizer) .requestCache(requestCacheCustomizer) .oauth2Login(oauth2LoginCustomizer) - .with(configurer, ignored -> {}) - .addFilterBefore(new OAuth2ServerLoggingFilter(), SecurityContextHolderFilter.class); + .with(configurer, ignored -> {}); return http.build(); } @@ -154,140 +174,69 @@ public AuthorizationCodeParamAccessor authorizationCodeParamAccessor() { return AuthorizationCodeParamAccessor.DEFAULT; } - private static class OAuth2ServerLoggingFilter extends OncePerRequestFilter { + private static class OAuth2TokenResponseHandler implements AuthenticationConverter, AuthenticationSuccessHandler, AuthenticationFailureHandler { - private static final Logger LOG = LoggerFactory.getLogger(OAuth2ServerLoggingFilter.class); + private static final Logger LOG = LoggerFactory.getLogger(OAuth2TokenResponseHandler.class); + private static final String CLIENT_ID_ATTRIBUTE_NAME = OAuth2TokenResponseHandler.class.getName() + "::CLIENT_ID"; + private static final AuthenticationSuccessHandler SUCCESS_DELEGATE = new OAuth2AccessTokenResponseAuthenticationSuccessHandler(); + private static final AuthenticationFailureHandler FAILURE_DELEGATE = new OAuth2ErrorAuthenticationFailureHandler(); - @Override - protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { - Map requestAttributes; - try { - requestAttributes = buildRequestAttributes(request); - } catch (Exception e) { - // better be safe than sorry - LOG.warn("failed to build request attributes", e); - requestAttributes = Map.of(); - } + private AuthenticationConverter authenticationConverterDelegate; - Exception exc = null; - try { - filterChain.doFilter(request, response); - } catch (Exception e) { - exc = e; - } + private OAuth2TokenResponseHandler() { + this.authenticationConverterDelegate = null; + } - Map responseAttributes; - try { - responseAttributes = buildResponseAttributes(response); - } catch (Exception e) { - LOG.warn("failed to build response attributes", e); - responseAttributes = Map.of(); - } + private void setAuthenticationConverters(List authenticationConverters) { + this.authenticationConverterDelegate = new DelegatingAuthenticationConverter(authenticationConverters); + } - try (ComposedMDCCloseable _unused = ComposedMDCCloseable.create(requestAttributes, Object::toString)) { - try (ComposedMDCCloseable __unused = ComposedMDCCloseable.create(responseAttributes, Object::toString)) { - if (exc == null) { - LOG.info("oauth2 request handled successfully"); - } else { - LOG.info("oauth2 request failed", exc); + @Override + public Authentication convert(HttpServletRequest request) { + final Authentication authentication = this.authenticationConverterDelegate.convert(request); + if (authentication != null) { + final Object principal = authentication.getPrincipal(); + if (principal instanceof OAuth2ClientAuthenticationToken token) { + final RegisteredClient client = token.getRegisteredClient(); + if (client != null) { + request.setAttribute(CLIENT_ID_ATTRIBUTE_NAME, client.getClientId()); } } } - if (exc != null) { - switch (exc) { - case ServletException e: - throw e; - - case IOException e: - throw e; - - case RuntimeException e: - throw e; - - default: - throw new RuntimeException("Unexpected error occurred while logging request", exc); - } - } + return authentication; } - private static Map buildRequestAttributes(HttpServletRequest request) { - final UriComponents uriComponents = UriComponentsBuilder.fromUriString(request.getRequestURI()) - .query(request.getQueryString()) - .build(); - - final Map attributes = new HashMap<>(); - attributes.put("request.method", request.getMethod()); - attributes.put("request.url", uriComponents.getPath()); - uriComponents.getQueryParams().forEach((key, value) -> { - if (!key.equalsIgnoreCase(OAuth2ParameterNames.CLIENT_SECRET) - && !key.equalsIgnoreCase(OAuth2ParameterNames.STATE) - && !key.equalsIgnoreCase("code_challenge") - && !key.equalsIgnoreCase("code_verifier")) { - - addMultiValue(attributes, "request.query." + key, value); + @Override + public void onAuthenticationSuccess(HttpServletRequest request, HttpServletResponse response, Authentication authentication) throws IOException, ServletException { + if (authentication instanceof OAuth2AccessTokenAuthenticationToken token) { + try (MDC.MDCCloseable _unused = MDC.putCloseable("client_id", token.getRegisteredClient().getClientId())) { + LOG.info("oauth2 token request succeeded"); } - }); - - addHeaders( - attributes, - "request", - () -> request.getHeaderNames().asIterator(), - (v) -> () -> request.getHeaders(v).asIterator(), - Set.of( - "cookie", - "authorization" - ) - ); - - return attributes; - } + } - private static Map buildResponseAttributes(HttpServletResponse response) { - final Map attributes = new HashMap<>(); - attributes.put("response.status_code", Integer.toString(response.getStatus())); - addHeaders( - attributes, - "response", - response.getHeaderNames(), - response::getHeaders, - Set.of( - "set-cookie", - "pragma", - "x-xss-protection", - "x-content-type-options", - "expires", - "cache-control", - "x-frame-options" - ) - ); - - return attributes; + SUCCESS_DELEGATE.onAuthenticationSuccess(request, response, authentication); } - private static void addHeaders(Map map, String prefix, Iterable names, Function> getHeaders, Set ignore) { - for (String header : names) { - if (!ignore.contains(header.toLowerCase())) { - final List values = new ArrayList<>(); - for (String value : getHeaders.apply(header)) { - values.add(value); + @Override + public void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response, AuthenticationException exception) throws IOException, ServletException { + if (exception instanceof OAuth2AuthenticationException oauth2AuthenticationException) { + try (MDC.MDCCloseable _unused = MDC.putCloseable("client_id", getClientId(request))) { + try (MDC.MDCCloseable __unused = MDC.putCloseable("error_code", oauth2AuthenticationException.getError().getErrorCode())) { + try (MDC.MDCCloseable ___unused = MDC.putCloseable("error_description", oauth2AuthenticationException.getError().getDescription())) { + LOG.info("oauth2 token request failed"); + } } - - addMultiValue(map, prefix + ".header." + header, values); } } + + FAILURE_DELEGATE.onAuthenticationFailure(request, response, exception); } - private static void addMultiValue(Map map, String key, List values) { - if (values.isEmpty()) { - map.put(key, ""); - } else if (values.size() == 1) { - map.put(key, values.getFirst()); - } else { - for (int i = 0; i < values.size(); i++) { - map.put(key + "." + i, values.get(i)); - } - } + private static String getClientId(HttpServletRequest request) { + return Optional.ofNullable(request.getAttribute(CLIENT_ID_ATTRIBUTE_NAME)) + .map(Object::toString) + .orElse("UNKNOWN"); } } } diff --git a/src/main/java/com/gw2auth/oauth2/server/util/ComposedMDCCloseable.java b/src/main/java/com/gw2auth/oauth2/server/util/ComposedMDCCloseable.java index 05e28c9..84e3c0c 100644 --- a/src/main/java/com/gw2auth/oauth2/server/util/ComposedMDCCloseable.java +++ b/src/main/java/com/gw2auth/oauth2/server/util/ComposedMDCCloseable.java @@ -5,6 +5,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.function.Function; public final class ComposedMDCCloseable implements AutoCloseable { @@ -32,6 +33,10 @@ public void close() { } } + public static ComposedMDCCloseable create(Map fields) { + return create(fields, Objects::toString); + } + public static ComposedMDCCloseable create(Map fields, Function toStringFunction) { final List mdcCloseables = new ArrayList<>(); for (Map.Entry entry : fields.entrySet()) { diff --git a/src/test/java/com/gw2auth/oauth2/server/oauth2/OAuth2ServerTest.java b/src/test/java/com/gw2auth/oauth2/server/oauth2/OAuth2ServerTest.java index 75edad2..7eb01ca 100644 --- a/src/test/java/com/gw2auth/oauth2/server/oauth2/OAuth2ServerTest.java +++ b/src/test/java/com/gw2auth/oauth2/server/oauth2/OAuth2ServerTest.java @@ -53,6 +53,7 @@ import org.springframework.mock.http.client.MockClientHttpResponse; import org.springframework.mock.web.MockPart; import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.server.authorization.OAuth2TokenType; import org.springframework.test.web.client.MockRestServiceServer; @@ -1793,6 +1794,58 @@ public void consentSubmitAndSubmitAgainWithLessScopes(SessionHandle sessionHandl ), Set.of(OAuth2Scope.GW2_ACCOUNT, OAuth2Scope.GW2_UNLOCKS)); } + @ParameterizedTest + @WithGw2AuthLogin + @WithOAuth2ClientApiVersion + @WithOAuth2ClientType + public void retrieveAccessTokenWithInvalidClientSecret(SessionHandle sessionHandle, OAuth2ClientApiVersion clientApiVersion, OAuth2ClientType clientType) throws Exception { + final ApplicationClientCreation applicationClientCreation = createApplicationClient(clientApiVersion, clientType); + final ApplicationClient applicationClient = applicationClientCreation.client(); + // perform authorization request (which should redirect to the consent page) + MvcResult result = performAuthorizeWithClient(sessionHandle, applicationClient, Set.of(OAuth2Scope.GW2_ACCOUNT)).andReturn(); + + // submit the consent + final String tokenA = TestHelper.randomRootToken(); + final String tokenB = TestHelper.randomRootToken(); + final String tokenC = TestHelper.randomRootToken(); + result = performSubmitConsent(sessionHandle, applicationClient, URI.create(Objects.requireNonNull(result.getResponse().getRedirectedUrl())), tokenA, tokenB, tokenC).andReturn(); + + // set testing clock to token customizer + final Clock testingClock = Clock.fixed(Instant.now(), ZoneId.systemDefault()); + this.gw2AuthClockedExtension.setClock(testingClock); + + // retrieve the initial access and refresh token + final String dummySubtokenA = TestHelper.createSubtokenJWT(this.gw2AccountId1st, Set.of(Gw2ApiPermission.ACCOUNT), testingClock.instant(), Duration.ofMinutes(30L)); + final String dummySubtokenB = TestHelper.createSubtokenJWT(this.gw2AccountId2nd, Set.of(Gw2ApiPermission.ACCOUNT), testingClock.instant(), Duration.ofMinutes(30L)); + + performRetrieveTokenByCode( + applicationClient, + "invalid_client_secret", + URI.create(Objects.requireNonNull(result.getResponse().getRedirectedUrl())), + Map.of(tokenA, dummySubtokenA, tokenB, dummySubtokenB), + Set.of(Gw2ApiPermission.ACCOUNT) + ) + .andExpect(status().isUnauthorized()) + .andExpect(jsonPath("$.error").value(OAuth2ErrorCodes.INVALID_CLIENT)); + } + + @ParameterizedTest + @WithOAuth2ClientApiVersion + @WithOAuth2ClientType + public void retrieveAccessTokenWithInvalidCode(OAuth2ClientApiVersion clientApiVersion, OAuth2ClientType clientType) throws Exception { + final ApplicationClientCreation applicationClientCreation = createApplicationClient(clientApiVersion, clientType); + final ApplicationClient applicationClient = applicationClientCreation.client(); + + performRetrieveTokenByCode( + applicationClient, + applicationClientCreation.clientSecret(), + TestHelper.first(applicationClient.redirectUris()).orElseThrow(), + "invalid_code" + ) + .andExpect(status().isBadRequest()) + .andExpect(jsonPath("$.error").value(OAuth2ErrorCodes.INVALID_GRANT)); + } + @ParameterizedTest @WithGw2AuthLogin @WithOAuth2ClientApiVersion @@ -2386,6 +2439,24 @@ private ResultActions performRetrieveTokenByCode(ApplicationClient applicationCl return this.mockMvc.perform(builder); } + private ResultActions performRetrieveTokenByCode(ApplicationClient applicationClient, String clientSecret, String redirectUri, String code) throws Exception { + MockMultipartHttpServletRequestBuilder builder = multipart(HttpMethod.POST, "/oauth2/token") + .part(part(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue())) + .part(part(OAuth2ParameterNames.CODE, code)) + .part(part(OAuth2ParameterNames.CLIENT_ID, applicationClient.id().toString())) + .part(part(OAuth2ParameterNames.REDIRECT_URI, redirectUri)); + + if (applicationClient.type() == OAuth2ClientType.CONFIDENTIAL) { + builder = builder.part(part(OAuth2ParameterNames.CLIENT_SECRET, clientSecret)); + } else { + builder = builder.part(part("code_verifier", generateCodeChallenge(applicationClient))); + } + + // retrieve an access token + // dont use the user session here! + return this.mockMvc.perform(builder); + } + private void prepareGw2RestServerForCreateSubToken(Map subtokenByGw2ApiToken) { prepareGw2RestServerForCreateSubToken(subtokenByGw2ApiToken, Set.of(Gw2ApiPermission.ACCOUNT)); }