Skip to content

Commit

Permalink
Merge pull request #368 from agebhar1/feature/gh-366-custom-credennti…
Browse files Browse the repository at this point in the history
…al-provider

support custom AWS credential provider
  • Loading branch information
embano1 authored Aug 24, 2024
2 parents 0ad6c51 + 9e4981e commit 9d66194
Show file tree
Hide file tree
Showing 9 changed files with 257 additions and 23 deletions.
52 changes: 36 additions & 16 deletions README.md

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ public class EventBridgeSinkConfig extends AbstractConfig {
static final String AWS_RETRIES_CONFIG = "aws.eventbridge.retries.max";
static final String AWS_RETRIES_DELAY_CONFIG = "aws.eventbridge.retries.delay";
static final String AWS_PROFILE_NAME_CONFIG = "aws.eventbridge.iam.profile.name";
static final String AWS_CREDENTIAL_PROVIDER_CLASS =
"aws.eventbridge.auth.credentials_provider.class";
static final String AWS_ROLE_ARN_CONFIG = "aws.eventbridge.iam.role.arn";
static final String AWS_ROLE_EXTERNAL_ID_CONFIG = "aws.eventbridge.iam.external.id";
static final String AWS_DETAIL_TYPES_CONFIG = "aws.eventbridge.detail.types";
Expand All @@ -55,6 +57,8 @@ public class EventBridgeSinkConfig extends AbstractConfig {
private static final int AWS_RETRIES_DELAY_DEFAULT = 200; // 200ms
private static final String AWS_RETRIES_DELAY_DOC =
"The retry delay in milliseconds between each retry attempt.";
private static final String AWS_CREDENTIAL_PROVIDER_DOC =
"An optional class name of the credentials provider to use. It must implement 'software.amazon.awssdk.auth.credentials.AwsCredentialsProvider' with a no-arg constructor and optionally 'org.apache.kafka.common.Configurable' to configure the provider after instantiation.";
private static final String AWS_ROLE_ARN_DOC =
"An optional IAM role to authenticate and send events to EventBridge. "
+ "If not specified, AWS default credentials provider is used";
Expand Down Expand Up @@ -95,6 +99,7 @@ public class EventBridgeSinkConfig extends AbstractConfig {
public final String eventBusArn;
public final String endpointID;
public final String endpointURI;
public final String awsCredentialsProviderClass;
public final String roleArn;
public final String externalId;
public final String profileName;
Expand All @@ -115,6 +120,7 @@ public EventBridgeSinkConfig(final Map<?, ?> originalProps) {
this.eventBusArn = getString(AWS_EVENTBUS_ARN_CONFIG);
this.endpointID = getString(AWS_EVENTBUS_GLOBAL_ENDPOINT_ID_CONFIG);
this.endpointURI = getString(AWS_ENDPOINT_URI_CONFIG);
this.awsCredentialsProviderClass = getString(AWS_CREDENTIAL_PROVIDER_CLASS);
this.roleArn = getString(AWS_ROLE_ARN_CONFIG);
this.externalId = getString(AWS_ROLE_EXTERNAL_ID_CONFIG);
this.profileName = getString(AWS_PROFILE_NAME_CONFIG);
Expand Down Expand Up @@ -175,6 +181,12 @@ private static void addParams(final ConfigDef configDef) {
"",
Importance.MEDIUM,
AWS_EVENTBUS_ENDPOINT_ID_DOC);
configDef.define(
AWS_CREDENTIAL_PROVIDER_CLASS,
Type.STRING,
"",
Importance.MEDIUM,
AWS_CREDENTIAL_PROVIDER_DOC);
configDef.define(AWS_ROLE_ARN_CONFIG, Type.STRING, "", Importance.MEDIUM, AWS_ROLE_ARN_DOC);
configDef.define(
AWS_ROLE_EXTERNAL_ID_CONFIG,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
import org.apache.kafka.common.config.Config;
import org.apache.kafka.common.config.ConfigException;
import org.apache.kafka.common.config.ConfigValue;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.regions.RegionMetadata;
import software.amazon.awssdk.utils.StringUtils;
import software.amazon.event.kafkaconnector.offloading.S3EventBridgeEventDetailValueOffloading;

public class EventBridgeSinkConfigValidator {
Expand Down Expand Up @@ -61,6 +63,11 @@ public static void validate(ConfigValue configValue, EnvVarGetter getenv) {
validateEventBusRetries(configValue);
break;
}
case AWS_CREDENTIAL_PROVIDER_CLASS:
{
validateAwsCredentialProviderClass(configValue);
break;
}
case AWS_ROLE_ARN_CONFIG:
{
validateRoleArn(configValue);
Expand Down Expand Up @@ -135,6 +142,32 @@ private static void validateURI(ConfigValue configValue) {
// TODO: validate optional URI here or when constructing client in task?
}

private static void validateAwsCredentialProviderClass(ConfigValue configValue) {
var requiredInterface = AwsCredentialsProvider.class;
var className = (String) configValue.value();
if (StringUtils.isNotBlank(className)) {
try {
var clazz = Class.forName((String) configValue.value());
if (!requiredInterface.isAssignableFrom(clazz)) {
throw new ConfigException(
"Class '"
+ className
+ "' does not implement '"
+ requiredInterface.getCanonicalName()
+ "'.");
}
clazz.getDeclaredConstructor();
} catch (ClassNotFoundException e) {
throw new ConfigException(
"Class '"
+ className
+ "' can't be loaded. Ensure the class path you have specified is correct.");
} catch (NoSuchMethodException e) {
throw new ConfigException("Class '" + className + "' requires a no-arg constructor.");
}
}
}

private static void validateBusArn(ConfigValue configValue) {
var awsEventBusArn = (String) configValue.value();
// example: arn:aws[-partition]:events:region:account:event-bus/bus-name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
import software.amazon.awssdk.services.eventbridge.model.PutEventsResponse;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.utils.StringUtils;
import software.amazon.event.kafkaconnector.auth.EventBridgeCredentialsProvider;
import software.amazon.event.kafkaconnector.auth.EventBridgeAwsCredentialsProviderFactory;
import software.amazon.event.kafkaconnector.batch.DefaultEventBridgeBatching;
import software.amazon.event.kafkaconnector.batch.EventBridgeBatchingStrategy;
import software.amazon.event.kafkaconnector.logging.ContextAwareLoggerFactory;
Expand Down Expand Up @@ -90,7 +90,8 @@ public EventBridgeWriter(EventBridgeSinkConfig config) {
.putAdvancedOption(USER_AGENT_PREFIX, userAgentPrefix)
.build();

var credentialsProvider = EventBridgeCredentialsProvider.getCredentials(config);
var credentialsProvider =
EventBridgeAwsCredentialsProviderFactory.getAwsCredentialsProvider(config);

var client =
EventBridgeAsyncClient.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
*/
package software.amazon.event.kafkaconnector.auth;

import static software.amazon.awssdk.utils.StringUtils.isNotBlank;

import java.lang.reflect.InvocationTargetException;
import java.util.Map;
import org.apache.kafka.common.Configurable;
import org.slf4j.Logger;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
Expand All @@ -16,24 +21,52 @@
import software.amazon.event.kafkaconnector.logging.ContextAwareLoggerFactory;

/** IAMUtility offers convenience functions for creating AWS IAM credential providers. */
public class EventBridgeCredentialsProvider {
public abstract class EventBridgeAwsCredentialsProviderFactory {

private static final int stsRefreshDuration = 900; // min allowed value
private static final Logger log =
ContextAwareLoggerFactory.getLogger(EventBridgeCredentialsProvider.class);
ContextAwareLoggerFactory.getLogger(EventBridgeAwsCredentialsProviderFactory.class);

private EventBridgeAwsCredentialsProviderFactory() {
// prevent instantiation
}

/**
* Create an IAM credentials provider.
* Create an AWS credentials provider.
*
* <p>If a {@link AwsCredentialsProvider} implementing class name is provided, then an instance is
* created by no-arg constructor. If the class also implements {@link Configurable}, then {@link
* Configurable#configure(Map)} is called after instantiation.
*
* <p>If a role ARN is provided in the config, then an STS assume-role credentials provider is
* created. The provider will automatically renew the assume-role session as needed.
*
* <p>If the role ARN is empty or null, then the default AWS credentials provider is returned.
*
* @param config Configuration containing optional IAM role, session, etc.
* @param config Configuration containing optional {@link AwsCredentialsProvider} implementing
* class name, IAM role, session, etc.
* @return AWS credentials provider
*/
public static AwsCredentialsProvider getCredentials(EventBridgeSinkConfig config) {
public static AwsCredentialsProvider getAwsCredentialsProvider(EventBridgeSinkConfig config) {
if (isNotBlank(config.awsCredentialsProviderClass)) {
try {
// checks are already executed by EventBridgeSinkConnector#validate(Map)
var clazz = Class.forName(config.awsCredentialsProviderClass);
var ctor = clazz.getDeclaredConstructor();
var obj = ctor.newInstance();
if (Configurable.class.isAssignableFrom(clazz)) {
((Configurable) obj).configure(config.originals());
}
return (AwsCredentialsProvider) obj;
} catch (final ClassNotFoundException
| NoSuchMethodException
| InvocationTargetException
| InstantiationException
| IllegalArgumentException
| IllegalAccessException e) {
throw new RuntimeException(e);
}
}
if (config.roleArn.trim().isBlank()) {
log.info("Using aws default credentials provider");
return getDefaultCredentialsProvider(config);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.event.kafkaconnector;

import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;

public class AwsCredentialProviderImpl implements AwsCredentialsProvider {

@Override
public AwsCredentials resolveCredentials() {
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*/
package software.amazon.event.kafkaconnector;

import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static software.amazon.awssdk.core.SdkSystemSetting.AWS_ACCESS_KEY_ID;
import static software.amazon.awssdk.core.SdkSystemSetting.AWS_SECRET_ACCESS_KEY;
Expand All @@ -16,6 +17,7 @@
import org.apache.kafka.common.config.ConfigValue;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.ValueSource;

public class EventBridgeSinkConfigValidatorTest {
Expand Down Expand Up @@ -71,6 +73,33 @@ public void invalidRegionValidation() {
});
}

@ParameterizedTest
@ValueSource(strings = {"", "software.amazon.event.kafkaconnector.AwsCredentialProviderImpl"})
public void validAwsCredentialProviderClass(String className) {
var configValue = new ConfigValue(AWS_CREDENTIAL_PROVIDER_CLASS);
configValue.value(className);

EventBridgeSinkConfigValidator.validate(configValue);
}

@ParameterizedTest
@CsvSource(
value = {
"xyz:Class 'xyz' can't be loaded. Ensure the class path you have specified is correct.",
"software.amazon.event.kafkaconnector.TestUtils$NonAwsCredentialProvider:Class 'software.amazon.event.kafkaconnector.TestUtils$NonAwsCredentialProvider' does not implement 'software.amazon.awssdk.auth.credentials.AwsCredentialsProvider'.",
"software.amazon.event.kafkaconnector.TestUtils$NoNoArgAwsCredentialProvider:Class 'software.amazon.event.kafkaconnector.TestUtils$NoNoArgAwsCredentialProvider' requires a no-arg constructor."
},
delimiter = ':')
public void invalidAwsCredentialProviderClass(String className, String expectedExceptionMessage) {
var configValue = new ConfigValue(AWS_CREDENTIAL_PROVIDER_CLASS);
configValue.value(className);

var exception =
assertThrows(
ConfigException.class, () -> EventBridgeSinkConfigValidator.validate(configValue));
assertThat(exception).hasMessage(expectedExceptionMessage);
}

@Test
public void validBusArn() {
var configValue = new ConfigValue(AWS_EVENTBUS_ARN_CONFIG);
Expand Down
23 changes: 23 additions & 0 deletions src/test/java/software/amazon/event/kafkaconnector/TestUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import org.apache.kafka.connect.data.Struct;
import org.apache.kafka.connect.sink.SinkRecord;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.services.eventbridge.model.PutEventsResultEntry;

public abstract class TestUtils {
Expand Down Expand Up @@ -110,4 +112,25 @@ public static ListAppender of(Class<?> clazz, Level level) {
return appender;
}
}

public static class NonAwsCredentialProvider {}

public static class NoNoArgAwsCredentialProvider implements AwsCredentialsProvider {

private final String sentinel;

public NoNoArgAwsCredentialProvider(String sentinel) {
this.sentinel = sentinel;
}

@Override
public AwsCredentials resolveCredentials() {
return null;
}

@Override
public String toString() {
return "NoNoArgAwsCredentialProvider{" + "sentinel='" + sentinel + '\'' + '}';
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.event.kafkaconnector.auth;

import static org.assertj.core.api.Assertions.assertThat;

import java.util.HashMap;
import java.util.Map;
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider;
import software.amazon.event.kafkaconnector.AwsCredentialProviderImpl;
import software.amazon.event.kafkaconnector.EventBridgeSinkConfig;

public class EventBridgeCredentialsProviderFactoryTest {

private static final Map<String, String> commonProps =
Map.of(
"aws.eventbridge.connector.id", "testConnectorId",
"aws.eventbridge.region", "us-east-1",
"aws.eventbridge.eventbus.arn", "arn:aws:events:us-east-1:000000000000:event-bus/e2e");

@Test
public void shouldUseDefaultAwsCredentialsProvider() {

var provider =
EventBridgeAwsCredentialsProviderFactory.getAwsCredentialsProvider(
new EventBridgeSinkConfig(commonProps));

assertThat(provider).isInstanceOf(AwsCredentialsProvider.class);
assertThat(provider).isExactlyInstanceOf(DefaultCredentialsProvider.class);
}

@Test
public void shouldUseStsAssumeRoleCredentialsProviderIfArnIsPresent() {

var props = new HashMap<>(commonProps);
props.put(
"aws.eventbridge.iam.role.arn",
"arn:aws:iam::123456789012:oidc-provider/server.example.org");
var provider =
EventBridgeAwsCredentialsProviderFactory.getAwsCredentialsProvider(
new EventBridgeSinkConfig(props));

assertThat(provider).isInstanceOf(AwsCredentialsProvider.class);
assertThat(provider).isExactlyInstanceOf(StsAssumeRoleCredentialsProvider.class);
}

@Test
public void shouldUseAwsCredentialsProviderByProvidedClass() {

var props = new HashMap<>(commonProps);
props.put(
"aws.eventbridge.auth.credentials_provider.class",
AwsCredentialProviderImpl.class.getCanonicalName());

var provider =
EventBridgeAwsCredentialsProviderFactory.getAwsCredentialsProvider(
new EventBridgeSinkConfig(props));

assertThat(provider).isInstanceOf(AwsCredentialsProvider.class);
assertThat(provider).isExactlyInstanceOf(AwsCredentialProviderImpl.class);
}
}

0 comments on commit 9d66194

Please sign in to comment.