diff --git a/modules/core/src/main/java/org/apache/synapse/endpoints/auth/oauth/AuthorizationCodeHandler.java b/modules/core/src/main/java/org/apache/synapse/endpoints/auth/oauth/AuthorizationCodeHandler.java index dc3432e41a..28576ba590 100644 --- a/modules/core/src/main/java/org/apache/synapse/endpoints/auth/oauth/AuthorizationCodeHandler.java +++ b/modules/core/src/main/java/org/apache/synapse/endpoints/auth/oauth/AuthorizationCodeHandler.java @@ -37,10 +37,11 @@ public class AuthorizationCodeHandler extends OAuthHandler { public AuthorizationCodeHandler(String tokenApiUrl, String clientId, String clientSecret, String refreshToken, String authMode, int connectionTimeout, - int connectionRequestTimeout, int socketTimeout) { + int connectionRequestTimeout, int socketTimeout, + TokenCacheProvider tokenCacheProvider) { super(tokenApiUrl, clientId, clientSecret, authMode, connectionTimeout, connectionRequestTimeout, - socketTimeout); + socketTimeout, tokenCacheProvider); this.refreshToken = refreshToken; } diff --git a/modules/core/src/main/java/org/apache/synapse/endpoints/auth/oauth/ClientCredentialsHandler.java b/modules/core/src/main/java/org/apache/synapse/endpoints/auth/oauth/ClientCredentialsHandler.java index df95cf1e81..37dfef724e 100644 --- a/modules/core/src/main/java/org/apache/synapse/endpoints/auth/oauth/ClientCredentialsHandler.java +++ b/modules/core/src/main/java/org/apache/synapse/endpoints/auth/oauth/ClientCredentialsHandler.java @@ -34,9 +34,11 @@ public class ClientCredentialsHandler extends OAuthHandler { public ClientCredentialsHandler(String tokenApiUrl, String clientId, String clientSecret, String authMode, - int connectionTimeout, int connectionRequestTimeout, int socketTimeout) { + int connectionTimeout, int connectionRequestTimeout, int socketTimeout, + TokenCacheProvider tokenCacheProvider) { - super(tokenApiUrl, clientId, clientSecret, authMode, connectionTimeout, connectionRequestTimeout, socketTimeout); + super(tokenApiUrl, clientId, clientSecret, authMode, connectionTimeout, connectionRequestTimeout, socketTimeout, + tokenCacheProvider); } @Override diff --git a/modules/core/src/main/java/org/apache/synapse/endpoints/auth/oauth/OAuthClient.java b/modules/core/src/main/java/org/apache/synapse/endpoints/auth/oauth/OAuthClient.java index 014c411688..e3c42ac9a8 100644 --- a/modules/core/src/main/java/org/apache/synapse/endpoints/auth/oauth/OAuthClient.java +++ b/modules/core/src/main/java/org/apache/synapse/endpoints/auth/oauth/OAuthClient.java @@ -44,7 +44,6 @@ import org.apache.http.impl.NoConnectionReuseStrategy; import org.apache.http.impl.client.CloseableHttpClient; import org.apache.http.impl.client.HttpClientBuilder; -import org.apache.http.impl.client.HttpClients; import org.apache.http.impl.conn.BasicHttpClientConnectionManager; import org.apache.synapse.MessageContext; import org.apache.synapse.core.axis2.Axis2MessageContext; @@ -100,28 +99,30 @@ public class OAuthClient { public static String generateToken(String tokenApiUrl, String payload, String credentials, MessageContext messageContext, Map customHeaders, int connectionTimeout, int connectionRequestTimeout, int socketTimeout) throws AuthException, IOException { - CloseableHttpClient httpClient = getSecureClient(tokenApiUrl, messageContext, connectionTimeout, - connectionRequestTimeout, socketTimeout); + if (log.isDebugEnabled()) { log.debug("Initializing token generation request: [token-endpoint] " + tokenApiUrl); } - HttpPost httpPost = new HttpPost(tokenApiUrl); - httpPost.setHeader(AuthConstants.CONTENT_TYPE_HEADER, AuthConstants.APPLICATION_X_WWW_FORM_URLENCODED); - if (!(customHeaders == null || customHeaders.isEmpty())) { - for (Map.Entry entry : customHeaders.entrySet()) { - httpPost.setHeader(entry.getKey(), entry.getValue()); + try (CloseableHttpClient httpClient = getSecureClient(tokenApiUrl, messageContext, connectionTimeout, + connectionRequestTimeout, socketTimeout)) { + HttpPost httpPost = new HttpPost(tokenApiUrl); + httpPost.setHeader(AuthConstants.CONTENT_TYPE_HEADER, AuthConstants.APPLICATION_X_WWW_FORM_URLENCODED); + if (!(customHeaders == null || customHeaders.isEmpty())) { + for (Map.Entry entry : customHeaders.entrySet()) { + httpPost.setHeader(entry.getKey(), entry.getValue()); + } } - } - if (credentials != null) { - httpPost.setHeader(AuthConstants.AUTHORIZATION_HEADER, AuthConstants.BASIC + credentials); - } - httpPost.setEntity(new StringEntity(payload)); + if (credentials != null) { + httpPost.setHeader(AuthConstants.AUTHORIZATION_HEADER, AuthConstants.BASIC + credentials); + } + httpPost.setEntity(new StringEntity(payload)); - try (CloseableHttpResponse response = httpClient.execute(httpPost)) { - return extractToken(response); - } finally { - httpPost.releaseConnection(); + try (CloseableHttpResponse response = httpClient.execute(httpPost)) { + return extractToken(response); + } finally { + httpPost.releaseConnection(); + } } } diff --git a/modules/core/src/main/java/org/apache/synapse/endpoints/auth/oauth/OAuthHandler.java b/modules/core/src/main/java/org/apache/synapse/endpoints/auth/oauth/OAuthHandler.java index b9d8f957da..a520b43000 100644 --- a/modules/core/src/main/java/org/apache/synapse/endpoints/auth/oauth/OAuthHandler.java +++ b/modules/core/src/main/java/org/apache/synapse/endpoints/auth/oauth/OAuthHandler.java @@ -34,8 +34,6 @@ import java.util.HashMap; import java.util.Map; import java.util.TreeMap; -import java.util.concurrent.Callable; -import java.util.concurrent.ExecutionException; /** * This abstract class is to be used by OAuth handlers @@ -55,9 +53,11 @@ public abstract class OAuthHandler implements AuthHandler { protected final int connectionTimeout; protected final int connectionRequestTimeout; protected final int socketTimeout; + private final TokenCacheProvider tokenCacheProvider; protected OAuthHandler(String tokenApiUrl, String clientId, String clientSecret, String authMode, - int connectionTimeout, int connectionRequestTimeout, int socketTimeout) { + int connectionTimeout, int connectionRequestTimeout, int socketTimeout, + TokenCacheProvider tokenCacheProvider) { this.id = OAuthUtils.getRandomOAuthHandlerID(); this.tokenApiUrl = tokenApiUrl; @@ -67,6 +67,7 @@ protected OAuthHandler(String tokenApiUrl, String clientId, String clientSecret, this.connectionTimeout = connectionTimeout; this.connectionRequestTimeout = connectionRequestTimeout; this.socketTimeout = socketTimeout; + this.tokenCacheProvider = tokenCacheProvider; } @Override @@ -87,18 +88,25 @@ public void setAuthHeader(MessageContext messageContext) throws AuthException { */ private String getToken(final MessageContext messageContext) throws AuthException { - try { - return TokenCache.getInstance().getToken(getId(messageContext), new Callable() { - @Override - public String call() throws AuthException, IOException { - return OAuthClient.generateToken(OAuthUtils.resolveExpression(tokenApiUrl, messageContext), + // Check if the token is already cached + String token = tokenCacheProvider.getToken(getId(messageContext)); + + synchronized (getId(messageContext).intern()) { + if (StringUtils.isEmpty(token)) { + // If no token found, generate a new one + try { + token = OAuthClient.generateToken(OAuthUtils.resolveExpression(tokenApiUrl, messageContext), buildTokenRequestPayload(messageContext), getEncodedCredentials(messageContext), - messageContext, getResolvedCustomHeadersMap(customHeadersMap, messageContext), connectionTimeout, - connectionRequestTimeout, socketTimeout); + messageContext, getResolvedCustomHeadersMap(customHeadersMap, messageContext), + connectionTimeout, connectionRequestTimeout, socketTimeout); + + // Cache the newly generated token + tokenCacheProvider.putToken(getId(messageContext), token); + } catch (IOException e) { + throw new AuthException("Error generating token", e); } - }); - } catch (ExecutionException e) { - throw new AuthException(e.getCause()); + } + return token; } } @@ -133,7 +141,7 @@ public int compare(String o1, String o2) { */ public void removeTokenFromCache(MessageContext messageContext) throws AuthException { - TokenCache.getInstance().removeToken(getId(messageContext)); + tokenCacheProvider.removeToken(getId(messageContext)); } /** @@ -141,7 +149,7 @@ public void removeTokenFromCache(MessageContext messageContext) throws AuthExcep */ public void removeTokensFromCache() { - TokenCache.getInstance().removeTokens(id.concat("_")); + tokenCacheProvider.removeTokens(id.concat("_")); } /** diff --git a/modules/core/src/main/java/org/apache/synapse/endpoints/auth/oauth/OAuthUtils.java b/modules/core/src/main/java/org/apache/synapse/endpoints/auth/oauth/OAuthUtils.java index 5c03f43095..48bb5b4809 100644 --- a/modules/core/src/main/java/org/apache/synapse/endpoints/auth/oauth/OAuthUtils.java +++ b/modules/core/src/main/java/org/apache/synapse/endpoints/auth/oauth/OAuthUtils.java @@ -129,7 +129,8 @@ private static AuthorizationCodeHandler getAuthorizationCodeHandler(OMElement au return null; } AuthorizationCodeHandler handler = new AuthorizationCodeHandler(tokenApiUrl, clientId, clientSecret, - refreshToken, authMode, connectionTimeout, connectionRequestTimeout, socketTimeout); + refreshToken, authMode, connectionTimeout, connectionRequestTimeout, socketTimeout, + TokenCacheFactory.getTokenCache()); if (hasRequestParameters(authCodeElement)) { Map requestParameters = getRequestParameters(authCodeElement); if (requestParameters == null) { @@ -170,7 +171,7 @@ private static ClientCredentialsHandler getClientCredentialsHandler( return null; } ClientCredentialsHandler handler = new ClientCredentialsHandler(tokenApiUrl, clientId, clientSecret, authMode, - connectionTimeout, connectionRequestTimeout, socketTimeout); + connectionTimeout, connectionRequestTimeout, socketTimeout, TokenCacheFactory.getTokenCache()); if (hasRequestParameters(clientCredentialsElement)) { Map requestParameters = getRequestParameters(clientCredentialsElement); if (requestParameters == null) { @@ -213,7 +214,8 @@ private static PasswordCredentialsHandler getPasswordCredentialsHandler( return null; } PasswordCredentialsHandler handler = new PasswordCredentialsHandler(tokenApiUrl, clientId, clientSecret, - username, password, authMode, connectionTimeout, connectionRequestTimeout, socketTimeout); + username, password, authMode, connectionTimeout, connectionRequestTimeout, socketTimeout, + TokenCacheFactory.getTokenCache()); if (hasRequestParameters(passwordCredentialsElement)) { Map requestParameters = getRequestParameters(passwordCredentialsElement); if (requestParameters == null) { diff --git a/modules/core/src/main/java/org/apache/synapse/endpoints/auth/oauth/PasswordCredentialsHandler.java b/modules/core/src/main/java/org/apache/synapse/endpoints/auth/oauth/PasswordCredentialsHandler.java index 07f48e2916..6d868037ad 100644 --- a/modules/core/src/main/java/org/apache/synapse/endpoints/auth/oauth/PasswordCredentialsHandler.java +++ b/modules/core/src/main/java/org/apache/synapse/endpoints/auth/oauth/PasswordCredentialsHandler.java @@ -38,9 +38,11 @@ public class PasswordCredentialsHandler extends OAuthHandler { protected PasswordCredentialsHandler(String tokenApiUrl, String clientId, String clientSecret, String username, String password, String authMode, int connectionTimeout, - int connectionRequestTimeout, int socketTimeout) { + int connectionRequestTimeout, int socketTimeout, + TokenCacheProvider tokenCacheProvider) { - super(tokenApiUrl, clientId, clientSecret, authMode, connectionTimeout, connectionRequestTimeout, socketTimeout); + super(tokenApiUrl, clientId, clientSecret, authMode, connectionTimeout, connectionRequestTimeout, socketTimeout, + tokenCacheProvider); this.username = username; this.password = password; } diff --git a/modules/core/src/main/java/org/apache/synapse/endpoints/auth/oauth/TokenCache.java b/modules/core/src/main/java/org/apache/synapse/endpoints/auth/oauth/TokenCache.java index 6537c11db3..cc2c16e5e7 100644 --- a/modules/core/src/main/java/org/apache/synapse/endpoints/auth/oauth/TokenCache.java +++ b/modules/core/src/main/java/org/apache/synapse/endpoints/auth/oauth/TokenCache.java @@ -25,8 +25,6 @@ import org.apache.synapse.config.SynapsePropertiesLoader; import org.apache.synapse.endpoints.auth.AuthConstants; -import java.util.concurrent.Callable; -import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import static org.apache.synapse.endpoints.auth.AuthConstants.TOKEN_CACHE_TIMEOUT_PROPERTY; @@ -35,7 +33,7 @@ * Token Cache Implementation * Tokens will be invalidate after a interval of TOKEN_CACHE_TIMEOUT minutes */ -public class TokenCache { +public class TokenCache implements TokenCacheProvider { private static final Log log = LogFactory.getLog(TokenCache.class); @@ -70,15 +68,27 @@ public static TokenCache getInstance() { } /** - * This method returns the value in the cache, or computes it from the specified Callable + * Stores a token in the cache with the specified ID. * - * @param id id of the oauth handler - * @param callable to generate a new token by calling oauth server - * @return Token object + * @param id the unique identifier for the token + * @param token the token to be cached */ - public String getToken(String id, Callable callable) throws ExecutionException { + @Override + public void putToken(String id, String token) { - return tokenMap.get(id, callable); + tokenMap.put(id, token); + } + + /** + * Retrieves a token from the cache using the specified ID. + * + * @param id the unique identifier for the token + * @return the cached token, or {@code null} if not found + */ + @Override + public String getToken(String id) { + + return tokenMap.getIfPresent(id); } /** @@ -86,6 +96,7 @@ public String getToken(String id, Callable callable) throws ExecutionExc * * @param id id of the endpoint */ + @Override public void removeToken(String id) { tokenMap.invalidate(id); @@ -96,6 +107,7 @@ public void removeToken(String id) { * * @param oauthHandlerId id of the OAuth handler bounded to the endpoint */ + @Override public void removeTokens(String oauthHandlerId) { tokenMap.asMap().entrySet().removeIf(entry -> entry.getKey().startsWith(oauthHandlerId)); } diff --git a/modules/core/src/main/java/org/apache/synapse/endpoints/auth/oauth/TokenCacheFactory.java b/modules/core/src/main/java/org/apache/synapse/endpoints/auth/oauth/TokenCacheFactory.java new file mode 100644 index 0000000000..7f5d836568 --- /dev/null +++ b/modules/core/src/main/java/org/apache/synapse/endpoints/auth/oauth/TokenCacheFactory.java @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2024, WSO2 LLC. (https://www.wso2.com/). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.synapse.endpoints.auth.oauth; + +import org.apache.synapse.SynapseException; +import org.apache.synapse.config.SynapsePropertiesLoader; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; + +/** + * Factory class responsible for providing the appropriate implementation of the TokenCacheProvider interface. + * This class manages the singleton instance of TokenCacheProvider, ensuring that it is only loaded once and reused + * across the application. + */ +public class TokenCacheFactory { + + /** + * Singleton instance of TokenCacheProvider. This will be initialized the first time, and the same instance will be + * returned on subsequent calls. + */ + private static TokenCacheProvider tokenCacheProvider; + + /** + * Retrieves the singleton instance of TokenCacheProvider. If the instance is not already initialized, + * it attempts to load the provider class specified in the `token.cache.class` property. If the property + * is not set or the class cannot be loaded, it defaults to the TokenCache implementation. + * + * @return the singleton instance of TokenCacheProvider + * @throws SynapseException if there is an error loading the specified class + */ + public static TokenCacheProvider getTokenCache() { + if (tokenCacheProvider != null) { + return tokenCacheProvider; + } + + String classPath = SynapsePropertiesLoader.loadSynapseProperties().getProperty("token.cache.class"); + if (classPath != null) { + tokenCacheProvider = loadTokenCacheProvider(classPath); + } else { + tokenCacheProvider = TokenCache.getInstance(); + } + return tokenCacheProvider; + } + + /** + * Loads the TokenCacheProvider implementation specified by the given class path. + * + * @param classPath the fully qualified class path of the TokenCacheProvider implementation + * @return an instance of the specified TokenCacheProvider implementation + * @throws SynapseException if there is an error loading the class or invoking the `getInstance` method + */ + private static TokenCacheProvider loadTokenCacheProvider(String classPath) { + try { + Class clazz = Class.forName(classPath); + Method getInstanceMethod = clazz.getMethod("getInstance"); + return (TokenCacheProvider) getInstanceMethod.invoke(null); + } catch (ClassNotFoundException e) { + throw new SynapseException("Error loading class: " + classPath, e); + } catch (NoSuchMethodException e) { + throw new SynapseException("getInstance method not found for class: " + classPath, e); + } catch (InvocationTargetException | IllegalAccessException e) { + throw new SynapseException("Error invoking getInstance method for class: " + classPath, e); + } + } +} diff --git a/modules/core/src/main/java/org/apache/synapse/endpoints/auth/oauth/TokenCacheProvider.java b/modules/core/src/main/java/org/apache/synapse/endpoints/auth/oauth/TokenCacheProvider.java new file mode 100644 index 0000000000..aa30810cf2 --- /dev/null +++ b/modules/core/src/main/java/org/apache/synapse/endpoints/auth/oauth/TokenCacheProvider.java @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2024, WSO2 LLC. (https://www.wso2.com/). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.synapse.endpoints.auth.oauth; + +/** + * Interface for managing token caching operations. + */ +public interface TokenCacheProvider { + + /** + * Stores a token in the cache with the specified ID. + * + * @param id the unique identifier for the token + * @param token the token to be cached + */ + void putToken(String id, String token); + + /** + * Retrieves a token from the cache using the specified ID. + * + * @param id the unique identifier for the token + * @return the cached token, or {@code null} if not found + */ + String getToken(String id); + + /** + * Removes a token from the cache using the specified ID. + * + * @param id the unique identifier for the token to be removed + */ + void removeToken(String id); + + /** + * Removes all tokens associated with the specified OAuth handler from the cache. + * + * @param id the identifier of the OAuth handler whose tokens are to be removed + */ + void removeTokens(String id); +} diff --git a/modules/core/src/test/java/org/apache/synapse/endpoints/auth/oauth/OAuthUtilsTest.java b/modules/core/src/test/java/org/apache/synapse/endpoints/auth/oauth/OAuthUtilsTest.java index 6a1405c3d9..07ce210399 100644 --- a/modules/core/src/test/java/org/apache/synapse/endpoints/auth/oauth/OAuthUtilsTest.java +++ b/modules/core/src/test/java/org/apache/synapse/endpoints/auth/oauth/OAuthUtilsTest.java @@ -303,7 +303,7 @@ public static Collection provideDataForRetryOnOauthFailureTests() throws AxisFau OAuthHandler oAuthHandler = new AuthorizationCodeHandler("oauth_server_url", "client_id", "client_secret", - "refresh_token", "header", -1, -1, -1); + "refresh_token", "header", -1, -1, -1, TokenCache.getInstance()); OAuthConfiguredHTTPEndpoint httpEndpoint = new OAuthConfiguredHTTPEndpoint(oAuthHandler);