diff --git a/src/main/java/software/aws/neptune/opencypher/OpenCypherIAMRequestGenerator.java b/src/main/java/software/aws/neptune/opencypher/OpenCypherIAMRequestGenerator.java index ed7b4532..21260045 100644 --- a/src/main/java/software/aws/neptune/opencypher/OpenCypherIAMRequestGenerator.java +++ b/src/main/java/software/aws/neptune/opencypher/OpenCypherIAMRequestGenerator.java @@ -24,9 +24,11 @@ import com.amazonaws.http.HttpMethodName; import com.google.gson.Gson; import org.neo4j.driver.AuthToken; -import org.neo4j.driver.AuthTokens; +import org.neo4j.driver.Values; +import org.neo4j.driver.internal.security.InternalAuthToken; import java.net.URI; +import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -34,6 +36,11 @@ import static com.amazonaws.auth.internal.SignerConstants.HOST; import static com.amazonaws.auth.internal.SignerConstants.X_AMZ_DATE; import static com.amazonaws.auth.internal.SignerConstants.X_AMZ_SECURITY_TOKEN; +import lombok.Builder; +import lombok.Getter; +import lombok.NonNull; +import org.neo4j.driver.Value; +import org.neo4j.driver.internal.value.StringValue; /** * Class to help with IAM authentication. @@ -53,17 +60,7 @@ public class OpenCypherIAMRequestGenerator { * @return AuthToken for IAM authentication. */ public static AuthToken createAuthToken(final String url, final String region) { - final Request request = new DefaultRequest<>(SERVICE_NAME); - request.setHttpMethod(HttpMethodName.GET); - request.setEndpoint(URI.create(url)); - request.setResourcePath("/opencypher"); - - final AWS4Signer signer = new AWS4Signer(); - signer.setRegionName(region); - signer.setServiceName(request.getServiceName()); - signer.sign(request, AWS_CREDENTIALS_PROVIDER.getCredentials()); - - return AuthTokens.basic(DUMMY_USERNAME, getAuthInfoJson(request)); + return new NeptuneAuthToken(region, url, AWS_CREDENTIALS_PROVIDER); } private static String getAuthInfoJson(final Request request) { @@ -76,4 +73,63 @@ private static String getAuthInfoJson(final Request request) { return GSON.toJson(obj); } + + /** + * This class is derived from the AWS documentation on accessing Amazon Neptune with IAM authentication. + * For more details and information, please refer to: + * Using the Bolt protocol to make openCypher queries to Neptune. + */ + + public static class NeptuneAuthToken extends InternalAuthToken { + private static final String SCHEME = "basic"; + private static final String REALM = "realm"; + private static final String SERVICE_NAME = "neptune-db"; + private static final String DUMMY_USERNAME = "username"; + @NonNull + private final String region; + @NonNull + @Getter + private final String url; + @NonNull + private final AWSCredentialsProvider credentialsProvider; + + @Builder + private NeptuneAuthToken( + @NonNull final String region, + @NonNull final String url, + @NonNull final AWSCredentialsProvider credentialsProvider + ) { + // The superclass caches the result of toMap(), which we don't want + super(Collections.emptyMap()); + this.region = region; + this.url = url; + this.credentialsProvider = credentialsProvider; + } + + @Override + public Map toMap() { + final Map map = new HashMap<>(); + map.put(SCHEME_KEY, Values.value(SCHEME)); + map.put(PRINCIPAL_KEY, Values.value(DUMMY_USERNAME)); + map.put(CREDENTIALS_KEY, new StringValue(getSignedHeader())); + map.put(REALM_KEY, Values.value(REALM)); + + return map; + } + + private String getSignedHeader() { + final Request request = new DefaultRequest<>(SERVICE_NAME); + request.setHttpMethod(HttpMethodName.GET); + request.setEndpoint(URI.create(url)); + // Comment out the following line if you're using an engine version older than 1.2.0.0 + request.setResourcePath("/opencypher"); + + final AWS4Signer signer = new AWS4Signer(); + signer.setRegionName(region); + signer.setServiceName(request.getServiceName()); + signer.sign(request, credentialsProvider.getCredentials()); + + return getAuthInfoJson(request); + } + } } diff --git a/src/test/java/software/aws/neptune/opencypher/OpenCypherIAMRequestGeneratorTest.java b/src/test/java/software/aws/neptune/opencypher/OpenCypherIAMRequestGeneratorTest.java index a8db5531..1a659bfa 100644 --- a/src/test/java/software/aws/neptune/opencypher/OpenCypherIAMRequestGeneratorTest.java +++ b/src/test/java/software/aws/neptune/opencypher/OpenCypherIAMRequestGeneratorTest.java @@ -99,8 +99,7 @@ private void verifyAuthToken(final boolean useTempCreds) { assertEquals(value("basic"), internalAuthToken.get(SCHEME_KEY)); assertEquals(value(DUMMY_USERNAME), internalAuthToken.get(PRINCIPAL_KEY)); - - assertFalse(internalAuthToken.containsKey(REALM_KEY)); + assertTrue(internalAuthToken.containsKey(REALM_KEY)); assertTrue(internalAuthToken.containsKey(CREDENTIALS_KEY)); final Value credentialsValue = internalAuthToken.get(CREDENTIALS_KEY);