Skip to content

Commit

Permalink
[agents] Add support for AstraDB Document API
Browse files Browse the repository at this point in the history
  • Loading branch information
eolivelli committed Nov 20, 2023
1 parent 5695641 commit 3014195
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 31 deletions.
8 changes: 6 additions & 2 deletions langstream-agents/langstream-ai-agents/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,15 @@
<groupId>com.datastax.oss</groupId>
<artifactId>java-driver-core-shaded</artifactId>
</dependency>

<dependency>
<groupId>com.datastax.astra</groupId>
<artifactId>astra-sdk-devops</artifactId>
<version>0.6.9</version>
<version>1.0</version>
</dependency>
<dependency>
<groupId>com.datastax.astra</groupId>
<artifactId>astra-sdk</artifactId>
<version>1.0</version>
</dependency>
<dependency>
<groupId>io.netty.incubator</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
package com.datastax.oss.streaming.ai.datasource;

import ai.langstream.api.util.ConfigurationUtils;
import com.datastax.astra.sdk.AstraClient;
import com.datastax.astra.sdk.config.AstraClientConfig;
import com.datastax.oss.driver.api.core.CqlSession;
import com.datastax.oss.driver.api.core.CqlSessionBuilder;
import com.datastax.oss.driver.api.core.cql.BoundStatement;
Expand All @@ -31,9 +33,9 @@
import com.datastax.oss.driver.api.core.type.reflect.GenericType;
import com.datastax.oss.driver.internal.core.type.codec.CqlVectorCodec;
import com.datastax.oss.driver.internal.core.type.codec.registry.DefaultCodecRegistry;
import com.dtsx.astra.sdk.db.AstraDbClient;
import com.dtsx.astra.sdk.db.DatabaseClient;
import com.dtsx.astra.sdk.utils.ApiLocator;
import com.dtsx.astra.sdk.db.AstraDBOpsClient;
import com.dtsx.astra.sdk.db.DbOpsClient;
import com.dtsx.astra.sdk.utils.AstraEnvironment;
import edu.umd.cs.findbugs.annotations.Nullable;
import java.io.ByteArrayInputStream;
import java.net.InetSocketAddress;
Expand Down Expand Up @@ -211,21 +213,11 @@ private CqlSession buildCqlSession(Map<String, Object> dataSourceConfig) {
if (password == null) {
password = ConfigurationUtils.getString("secret", null, dataSourceConfig);
}

// in AstraDB you can use "token" as clientId and the AstraCS token as password
if (username == null && astraToken != null && !astraToken.isEmpty()) {
username = "token";
}
if (password == null) {
password = astraToken;
}

String secureBundle = ConfigurationUtils.getString("secureBundle", "", dataSourceConfig);
List<String> contactPoints = ConfigurationUtils.getList("contact-points", dataSourceConfig);
String loadBalancingLocalDc =
ConfigurationUtils.getString("loadBalancing-localDc", "", dataSourceConfig);
int port = ConfigurationUtils.getInteger("port", 9042, dataSourceConfig);
log.info("Username/ClientId: {}", username);
log.info("Contact points: {}", contactPoints);
log.info("Secure Bundle: {}", secureBundle);

Expand All @@ -241,13 +233,13 @@ private CqlSession buildCqlSession(Map<String, Object> dataSourceConfig) {
log.info(
"Automatically downloading the secure bundle for database name {} from AstraDB",
astraDatabase);
DatabaseClient databaseClient = this.buildAstraClient();
DbOpsClient databaseClient = this.buildAstraClient();
secureBundleDecoded = downloadSecureBundle(databaseClient);
} else if (!astraDatabaseId.isEmpty() && !astraToken.isEmpty()) {
log.info(
"Automatically downloading the secure bundle for database id {} from AstraDB",
astraDatabaseId);
DatabaseClient databaseClient = this.buildAstraClient();
DbOpsClient databaseClient = this.buildAstraClient();
secureBundleDecoded = downloadSecureBundle(databaseClient);
} else {
log.info("No secure bundle provided, using the default CQL driver for Cassandra");
Expand Down Expand Up @@ -281,21 +273,20 @@ public CqlSession getSession() {
return session;
}

public DatabaseClient buildAstraClient() {
public DbOpsClient buildAstraClient() {
return buildAstraClient(astraToken, astraDatabase, astraDatabaseId, astraEnvironment);
}

public static DatabaseClient buildAstraClient(
public static DbOpsClient buildAstraClient(
String astraToken,
String astraDatabase,
String astraDatabaseId,
String astraEnvironment) {
if (astraToken.isEmpty()) {
throw new IllegalArgumentException("You must configure the AstraDB token");
}
AstraDbClient astraDbClient =
new AstraDbClient(
astraToken, ApiLocator.AstraEnvironment.valueOf(astraEnvironment));
AstraDBOpsClient astraDbClient =
new AstraDBOpsClient(astraToken, AstraEnvironment.valueOf(astraEnvironment));
if (!astraDatabase.isEmpty()) {
return astraDbClient.databaseByName(astraDatabase);
} else if (!astraDatabaseId.isEmpty()) {
Expand All @@ -306,7 +297,19 @@ public static DatabaseClient buildAstraClient(
}
}

public static byte[] downloadSecureBundle(DatabaseClient databaseClient) {
public AstraClient buildAstraAPIClient() {
return buildAstraAPIClient(astraToken);
}

public static AstraClient buildAstraAPIClient(String astraToken) {
if (astraToken.isEmpty()) {
throw new IllegalArgumentException("You must configure the AstraDB token");
}
AstraClientConfig astraClientConfig = new AstraClientConfig().withToken(astraToken);
return new AstraClient(astraClientConfig);
}

public static byte[] downloadSecureBundle(DbOpsClient databaseClient) {
long start = System.currentTimeMillis();
byte[] secureBundleDecoded = databaseClient.downloadDefaultSecureConnectBundle();
long delta = System.currentTimeMillis() - start;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,17 @@
import ai.langstream.api.runner.assets.AssetManager;
import ai.langstream.api.runner.assets.AssetManagerProvider;
import ai.langstream.api.util.ConfigurationUtils;
import com.datastax.astra.sdk.AstraClient;
import com.datastax.oss.driver.api.core.CqlSession;
import com.datastax.oss.driver.api.core.cql.SimpleStatement;
import com.datastax.oss.driver.api.core.metadata.schema.KeyspaceMetadata;
import com.datastax.oss.driver.api.core.metadata.schema.TableMetadata;
import com.datastax.oss.driver.api.core.servererrors.AlreadyExistsException;
import com.datastax.oss.streaming.ai.datasource.CassandraDataSource;
import com.dtsx.astra.sdk.db.DatabaseClient;
import com.dtsx.astra.sdk.db.DbOpsClient;
import io.stargate.sdk.doc.exception.CollectionNotFoundException;
import io.stargate.sdk.json.ApiClient;
import io.stargate.sdk.json.exception.NamespaceNotFoundException;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand All @@ -38,7 +42,9 @@ public class CassandraAssetsManagerProvider implements AssetManagerProvider {
public boolean supports(String assetType) {
return "cassandra-table".equals(assetType)
|| "cassandra-keyspace".equals(assetType)
|| "astra-keyspace".equals(assetType);
|| "astra-keyspace".equals(assetType)
|| "astra-namespace".equals(assetType)
|| "astra-collection".equals(assetType);
}

@Override
Expand All @@ -51,6 +57,10 @@ public AssetManager createInstance(String assetType) {
return new CassandraKeyspaceAssetManager();
case "astra-keyspace":
return new AstraDBKeyspaceAssetManager();
case "astra-namespace":
return new AstraDBNamespaceAssetManager();
case "astra-collection":
return new AstraDBCollectionAssetManager();
default:
throw new IllegalArgumentException();
}
Expand Down Expand Up @@ -211,7 +221,7 @@ private static class AstraDBKeyspaceAssetManager extends BaseCassandraAssetManag
public boolean assetExists() throws Exception {
String keySpace = getKeyspace();
log.info("Checking if keyspace {} exists", keySpace);
DatabaseClient astraDbClient = datasource.buildAstraClient();
DbOpsClient astraDbClient = datasource.buildAstraClient();
boolean exist = astraDbClient.keyspaces().exist(keySpace);
log.info("Result: {}", exist);
return exist;
Expand All @@ -220,7 +230,7 @@ public boolean assetExists() throws Exception {
@Override
public void deployAsset() throws Exception {
String keySpace = getKeyspace();
DatabaseClient astraDbClient = datasource.buildAstraClient();
DbOpsClient astraDbClient = datasource.buildAstraClient();
try {
astraDbClient.keyspaces().create(keySpace);
} catch (com.dtsx.astra.sdk.db.exception.KeyspaceAlreadyExistException e) {
Expand Down Expand Up @@ -250,7 +260,7 @@ public boolean deleteAssetIfExists() throws Exception {
String keySpace = getKeyspace();

log.info("Deleting keyspace {}", keySpace);
DatabaseClient astraDbClient = datasource.buildAstraClient();
DbOpsClient astraDbClient = datasource.buildAstraClient();
try {
astraDbClient.keyspaces().delete(keySpace);
return true;
Expand All @@ -263,6 +273,114 @@ public boolean deleteAssetIfExists() throws Exception {
}
}

private static class AstraDBCollectionAssetManager extends BaseCassandraAssetManager {

@Override
public boolean assetExists() throws Exception {
String namespace = getNamespace();
String collection = getCollection();
log.info("Checking if collection {} exists in namespace {}", collection, namespace);
AstraClient astraDbClient = datasource.buildAstraAPIClient();
ApiClient apiClient = astraDbClient.apiStargateJson();
if (!apiClient.isNamespaceExists(namespace)) {
log.info("Namespace {} does not exist", namespace);
return false;
}
boolean exist = apiClient.namespace(namespace).isCollectionExists(collection);
log.info("Result: {}", exist);
return exist;
}

@Override
public void deployAsset() throws Exception {
String namespace = getNamespace();
String collection = getCollection();
AstraClient astraDbClient = datasource.buildAstraAPIClient();
ApiClient apiClient = astraDbClient.apiStargateJson();

if (!apiClient.isNamespaceExists(namespace)) {
apiClient.createNamespace(namespace);
}
apiClient.namespace(namespace).createCollection(collection);
}

private String getNamespace() {
return ConfigurationUtils.getString("namespace", null, assetDefinition.getConfig());
}

private String getCollection() {
return ConfigurationUtils.getString("collection", null, assetDefinition.getConfig());
}

@Override
public boolean deleteAssetIfExists() throws Exception {
String namespace = getNamespace();
String collection = getCollection();

log.info("Deleting collection {} in namespace {}", collection, namespace);
AstraClient astraDbClient = datasource.buildAstraAPIClient();
ApiClient apiClient = astraDbClient.apiStargateJson();
try {
apiClient.namespace(namespace).deleteCollection(collection);
return true;
} catch (CollectionNotFoundException e) {
log.info(
"Collection does not exist, maybe it was deleted by another agent ({})",
e.toString());
return false;
}
}
}

private static class AstraDBNamespaceAssetManager extends BaseCassandraAssetManager {

@Override
public boolean assetExists() throws Exception {
String namespace = getNamespace();
log.info("Checking if namespace {} exists", namespace);
AstraClient astraDbClient = datasource.buildAstraAPIClient();
ApiClient apiClient = astraDbClient.apiStargateJson();
if (!apiClient.isNamespaceExists(namespace)) {
log.info("Namespace {} does not exist", namespace);
return false;
}
return true;
}

@Override
public void deployAsset() throws Exception {
String namespace = getNamespace();
AstraClient astraDbClient = datasource.buildAstraAPIClient();
ApiClient apiClient = astraDbClient.apiStargateJson();

if (!apiClient.isNamespaceExists(namespace)) {
apiClient.createNamespace(namespace);
}
}

private String getNamespace() {
return ConfigurationUtils.getString("namespace", null, assetDefinition.getConfig());
}

@Override
public boolean deleteAssetIfExists() throws Exception {
String namespace = getNamespace();

log.info("Deleting namespace {} ", namespace);
AstraClient astraDbClient = datasource.buildAstraAPIClient();
ApiClient apiClient = astraDbClient.apiStargateJson();
try {
apiClient.dropNamespace(namespace);
return true;
} catch (NamespaceNotFoundException e) {
log.info(
"Namespace does not exist, maybe it was deleted by another agent ({})",
e.toString());
return false;
}
}
}

private static CassandraDataSource buildDataSource(AssetDefinition assetDefinition) {
CassandraDataSource dataSource = new CassandraDataSource();
Map<String, Object> datasourceDefinition =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import com.datastax.oss.common.sink.config.CassandraSinkConfig;
import com.datastax.oss.common.sink.util.SinkUtil;
import com.datastax.oss.streaming.ai.datasource.CassandraDataSource;
import com.dtsx.astra.sdk.db.DatabaseClient;
import com.dtsx.astra.sdk.db.DbOpsClient;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.HashMap;
Expand Down Expand Up @@ -121,7 +121,7 @@ public void initialise(Map<String, Object> agentConfiguration) {
"environment", "PROD", datasource);
if (!token.isEmpty()
&& (!database.isEmpty() || !databaseId.isEmpty())) {
DatabaseClient databaseClient =
DbOpsClient databaseClient =
CassandraDataSource.buildAstraClient(
token,
database,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
cassandra-table
cassandra-keyspace
astra-keyspace
astra-namespace
astra-collection
milvus-collection
jdbc-table
solr-collection
Expand Down
Loading

0 comments on commit 3014195

Please sign in to comment.