Skip to content

Commit

Permalink
[vector-database] Add support for Astra Collections
Browse files Browse the repository at this point in the history
  • Loading branch information
eolivelli committed Nov 21, 2023
1 parent 5695641 commit f2d57e3
Show file tree
Hide file tree
Showing 12 changed files with 320 additions and 22 deletions.
12 changes: 11 additions & 1 deletion langstream-agents/langstream-ai-agents/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,17 @@
<dependency>
<groupId>com.datastax.astra</groupId>
<artifactId>astra-sdk-devops</artifactId>
<version>0.6.9</version>
<version>1.0</version>
</dependency>
<dependency>
<groupId>com.datastax.astra</groupId>
<artifactId>astra-db-client</artifactId>
<version>1.0</version>
</dependency>
<dependency>
<groupId>com.datastax.astra</groupId>
<artifactId>astra-sdk</artifactId>
<version>1.0</version>
</dependency>
<dependency>
<groupId>io.netty.incubator</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import com.datastax.oss.streaming.ai.datasource.QueryStepDataSource;
import java.util.Map;

public class AstraDataSource implements DataSourceProvider {
public class CassandraDataSourceProvider implements DataSourceProvider {

@Override
public boolean supports(Map<String, Object> dataSourceConfig) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@
import com.datastax.oss.driver.api.core.type.reflect.GenericType;
import com.datastax.oss.driver.internal.core.type.codec.CqlVectorCodec;
import com.datastax.oss.driver.internal.core.type.codec.registry.DefaultCodecRegistry;
import com.dtsx.astra.sdk.db.AstraDbClient;
import com.dtsx.astra.sdk.db.DatabaseClient;
import com.dtsx.astra.sdk.utils.ApiLocator;
import com.dtsx.astra.sdk.db.AstraDBOpsClient;
import com.dtsx.astra.sdk.db.DbOpsClient;
import com.dtsx.astra.sdk.utils.AstraEnvironment;
import edu.umd.cs.findbugs.annotations.Nullable;
import java.io.ByteArrayInputStream;
import java.net.InetSocketAddress;
Expand Down Expand Up @@ -241,13 +241,13 @@ private CqlSession buildCqlSession(Map<String, Object> dataSourceConfig) {
log.info(
"Automatically downloading the secure bundle for database name {} from AstraDB",
astraDatabase);
DatabaseClient databaseClient = this.buildAstraClient();
DbOpsClient databaseClient = this.buildAstraClient();
secureBundleDecoded = downloadSecureBundle(databaseClient);
} else if (!astraDatabaseId.isEmpty() && !astraToken.isEmpty()) {
log.info(
"Automatically downloading the secure bundle for database id {} from AstraDB",
astraDatabaseId);
DatabaseClient databaseClient = this.buildAstraClient();
DbOpsClient databaseClient = this.buildAstraClient();
secureBundleDecoded = downloadSecureBundle(databaseClient);
} else {
log.info("No secure bundle provided, using the default CQL driver for Cassandra");
Expand Down Expand Up @@ -281,21 +281,20 @@ public CqlSession getSession() {
return session;
}

public DatabaseClient buildAstraClient() {
public DbOpsClient buildAstraClient() {
return buildAstraClient(astraToken, astraDatabase, astraDatabaseId, astraEnvironment);
}

public static DatabaseClient buildAstraClient(
public static DbOpsClient buildAstraClient(
String astraToken,
String astraDatabase,
String astraDatabaseId,
String astraEnvironment) {
if (astraToken.isEmpty()) {
throw new IllegalArgumentException("You must configure the AstraDB token");
}
AstraDbClient astraDbClient =
new AstraDbClient(
astraToken, ApiLocator.AstraEnvironment.valueOf(astraEnvironment));
AstraDBOpsClient astraDbClient =
new AstraDBOpsClient(astraToken, AstraEnvironment.valueOf(astraEnvironment));
if (!astraDatabase.isEmpty()) {
return astraDbClient.databaseByName(astraDatabase);
} else if (!astraDatabaseId.isEmpty()) {
Expand All @@ -306,7 +305,7 @@ public static DatabaseClient buildAstraClient(
}
}

public static byte[] downloadSecureBundle(DatabaseClient databaseClient) {
public static byte[] downloadSecureBundle(DbOpsClient databaseClient) {
long start = System.currentTimeMillis();
byte[] secureBundleDecoded = databaseClient.downloadDefaultSecureConnectBundle();
long delta = System.currentTimeMillis() - start;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
ai.langstream.ai.agents.datasource.impl.AstraDataSource
ai.langstream.ai.agents.datasource.impl.CassandraDataSourceProvider
ai.langstream.ai.agents.datasource.impl.JdbcDataSourceProvider
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.langstream.agents.vector.astra;

import ai.langstream.api.model.AssetDefinition;
import ai.langstream.api.runner.assets.AssetManager;
import ai.langstream.api.runner.assets.AssetManagerProvider;
import ai.langstream.api.util.ConfigurationUtils;
import io.stargate.sdk.doc.exception.CollectionNotFoundException;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;

/**
 * {@link AssetManagerProvider} for assets of type {@code astra-collection}.
 *
 * <p>The managers created here check for, create and delete a collection through the {@code
 * AstraDB} client exposed by {@link AstraCollectionsDataSource#getAstraDB()}.
 */
@Slf4j
public class AstraCollectionsAssetsManagerProvider implements AssetManagerProvider {

    /** The only asset type handled by this provider. */
    private static final String ASTRA_COLLECTION_ASSET_TYPE = "astra-collection";

    /**
     * @param assetType the asset type declared in the asset definition
     * @return {@code true} only for {@code astra-collection}
     */
    @Override
    public boolean supports(String assetType) {
        return ASTRA_COLLECTION_ASSET_TYPE.equals(assetType);
    }

    /**
     * Creates the manager for the given asset type.
     *
     * @param assetType the asset type, expected to be {@code astra-collection}
     * @return a new, not yet initialized, asset manager
     * @throws IllegalArgumentException if the asset type is not supported (including {@code
     *     null})
     */
    @Override
    public AssetManager createInstance(String assetType) {
        // equals() rather than a String switch: a null asset type is reported as unsupported,
        // consistently with supports(), instead of failing with a NullPointerException.
        if (ASTRA_COLLECTION_ASSET_TYPE.equals(assetType)) {
            return new AstraDBCollectionAssetManager();
        }
        throw new IllegalArgumentException("Unsupported asset type: " + assetType);
    }

    /** Common lifecycle handling: builds the datasource on initialize and closes it on close. */
    private abstract static class BaseAstraAssetManager implements AssetManager {

        AstraCollectionsDataSource datasource;
        AssetDefinition assetDefinition;

        @Override
        public void initialize(AssetDefinition assetDefinition) {
            this.datasource = buildDataSource(assetDefinition);
            this.assetDefinition = assetDefinition;
        }

        @Override
        public void close() throws Exception {
            // initialize() may never have been called, or may have failed
            if (datasource != null) {
                datasource.close();
            }
        }
    }

    /** Manages a single collection, named by the {@code collection} entry of the asset config. */
    private static class AstraDBCollectionAssetManager extends BaseAstraAssetManager {

        /** @return whether the configured collection is reported as existing by the client */
        @Override
        public boolean assetExists() throws Exception {
            String collection = getCollection();
            log.info("Checking if collection {} exists", collection);
            return datasource.getAstraDB().isCollectionExists(collection);
        }

        /** Creates the configured collection with the configured vector dimension. */
        @Override
        public void deployAsset() throws Exception {
            int vectorDimension = getVectorDimension();

            String collection = getCollection();
            log.info("Create collection {} with vector dimension {}", collection, vectorDimension);
            datasource.getAstraDB().createCollection(collection, vectorDimension);
        }

        /** Reads the {@code collection} entry; {@code null} is the fallback passed to the lookup. */
        private String getCollection() {
            return ConfigurationUtils.getString("collection", null, assetDefinition.getConfig());
        }

        /** Reads the {@code vector-dimension} entry; 1536 is the fallback passed to the lookup. */
        private int getVectorDimension() {
            return ConfigurationUtils.getInt("vector-dimension", 1536, assetDefinition.getConfig());
        }

        /**
         * Deletes the configured collection.
         *
         * @return {@code true} if the delete call completed, {@code false} if the client
         *     reported that the collection was not found
         */
        @Override
        public boolean deleteAssetIfExists() throws Exception {
            String collection = getCollection();

            log.info("Deleting collection {}", collection);

            try {
                datasource.getAstraDB().deleteCollection(collection);
                return true;
            } catch (CollectionNotFoundException e) {
                // Not an error: the desired end state (no collection) is already reached.
                log.info(
                        "collection does not exist, maybe it was deleted by another agent ({})",
                        e.toString());
                return false;
            }
        }
    }

    /**
     * Builds and initializes the datasource from the {@code datasource.configuration} map of the
     * asset config; missing maps are replaced by empty ones.
     */
    private static AstraCollectionsDataSource buildDataSource(AssetDefinition assetDefinition) {
        AstraCollectionsDataSource dataSource = new AstraCollectionsDataSource();
        Map<String, Object> datasourceDefinition =
                ConfigurationUtils.getMap("datasource", Map.of(), assetDefinition.getConfig());
        Map<String, Object> configuration =
                ConfigurationUtils.getMap("configuration", Map.of(), datasourceDefinition);
        dataSource.initialize(configuration);
        return dataSource;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.langstream.agents.vector.astra;

import ai.langstream.api.util.ConfigurationUtils;
import com.datastax.oss.streaming.ai.datasource.QueryStepDataSource;
import com.dtsx.astra.sdk.AstraDB;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;

/**
 * {@link QueryStepDataSource} that holds an {@link AstraDB} client built from the {@code token}
 * and {@code endpoint} configuration entries.
 *
 * <p>Queries and statements are not implemented: {@link #fetchData} and {@link
 * #executeStatement} always throw {@link UnsupportedOperationException}. The class currently
 * only gives access to the client via {@link #getAstraDB()}.
 */
@Slf4j
public class AstraCollectionsDataSource implements QueryStepDataSource {

    AstraDB astraDB;

    /**
     * Creates the {@link AstraDB} client.
     *
     * @param dataSourceConfig configuration map; {@code token} and {@code endpoint} are read,
     *     each falling back to the empty string
     */
    @Override
    public void initialize(Map<String, Object> dataSourceConfig) {
        // Secrets are redacted before logging the configuration.
        log.info(
                "Initializing AstraCollectionsDataSource with config {}",
                ConfigurationUtils.redactSecrets(dataSourceConfig));
        String astraToken = ConfigurationUtils.getString("token", "", dataSourceConfig);
        String astraEndpoint = ConfigurationUtils.getString("endpoint", "", dataSourceConfig);
        this.astraDB = new AstraDB(astraToken, astraEndpoint);
    }

    /** Nothing to release here. */
    @Override
    public void close() {}

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public List<Map<String, Object>> fetchData(String query, List<Object> params) {
        if (log.isDebugEnabled()) {
            log.debug(
                    "Executing query {} with params {} ({})",
                    query,
                    params,
                    describeParamTypes(params));
        }
        throw new UnsupportedOperationException(
                "Queries are not supported by AstraCollectionsDataSource");
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Map<String, Object> executeStatement(
            String query, List<String> generatedKeys, List<Object> params) {
        if (log.isDebugEnabled()) {
            log.debug(
                    "Executing statement {} with params {} ({})",
                    query,
                    params,
                    describeParamTypes(params));
        }
        throw new UnsupportedOperationException(
                "Statements are not supported by AstraCollectionsDataSource");
    }

    /** @return the client created by {@link #initialize}, or {@code null} before that call */
    public AstraDB getAstraDB() {
        return astraDB;
    }

    /**
     * Renders the runtime class of each parameter as a comma separated list, for debug logs.
     * A null list is rendered as {@code "null"} so that the outcome of a call does not depend
     * on whether debug logging is enabled.
     */
    private static String describeParamTypes(List<Object> params) {
        if (params == null) {
            return "null";
        }
        return params.stream()
                .map(v -> v == null ? "null" : v.getClass().toString())
                .collect(Collectors.joining(","));
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.langstream.agents.vector.astra;

import ai.langstream.ai.agents.datasource.DataSourceProvider;
import com.datastax.oss.streaming.ai.datasource.QueryStepDataSource;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;

/**
 * {@link DataSourceProvider} selected when the datasource configuration declares {@code
 * service: astra-collections}.
 */
@Slf4j
public class AstraCollectionsDataSourceProvider implements DataSourceProvider {

    /**
     * @param dataSourceConfig the datasource configuration map
     * @return {@code true} only if the {@code service} entry equals {@code astra-collections};
     *     a missing or non-string value yields {@code false}
     */
    @Override
    public boolean supports(Map<String, Object> dataSourceConfig) {
        // No cast to String: a non-string "service" value must simply not match, instead of
        // failing with a ClassCastException while providers are being probed.
        return "astra-collections".equals(dataSourceConfig.get("service"));
    }

    /**
     * @param dataSourceConfig not used here; the returned instance is created without
     *     configuration (presumably the caller invokes {@code initialize} on it - verify
     *     against the caller)
     * @return a new {@link AstraCollectionsDataSource}
     */
    @Override
    public QueryStepDataSource createDataSourceImplementation(
            Map<String, Object> dataSourceConfig) {
        return new AstraCollectionsDataSource();
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.langstream.agents.vector.astra;

import ai.langstream.api.database.VectorDatabaseWriter;
import ai.langstream.api.database.VectorDatabaseWriterProvider;
import ai.langstream.api.runner.code.Record;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import lombok.extern.slf4j.Slf4j;

/**
 * {@link VectorDatabaseWriterProvider} matched by datasources whose {@code service} entry is
 * {@code astra-collections}.
 *
 * <p>The writer it hands out is a placeholder: every upsert completes exceptionally with an
 * {@link UnsupportedOperationException}.
 */
@Slf4j
public class AstraCollectionsWriter implements VectorDatabaseWriterProvider {

    /** @return {@code true} only when {@code service} equals {@code astra-collections} */
    @Override
    public boolean supports(Map<String, Object> dataSourceConfig) {
        final Object service = dataSourceConfig.get("service");
        return "astra-collections".equals(service);
    }

    /** @return a new writer that keeps a reference to the given datasource configuration */
    @Override
    public VectorDatabaseWriter createImplementation(Map<String, Object> datasourceConfig) {
        return new AstraCollectionsDatabaseWriter(datasourceConfig);
    }

    /** Writer stub; it stores the datasource configuration but performs no I/O. */
    private static class AstraCollectionsDatabaseWriter implements VectorDatabaseWriter {

        private final Map<String, Object> datasourceConfig;

        public AstraCollectionsDatabaseWriter(Map<String, Object> config) {
            this.datasourceConfig = config;
        }

        /** No agent-level setup is performed. */
        @Override
        public void initialise(Map<String, Object> agentConfiguration) {}

        /** @return a future that is already completed exceptionally; nothing is written */
        @Override
        public CompletableFuture<?> upsert(Record record, Map<String, Object> context) {
            CompletableFuture<Object> notSupported = new CompletableFuture<>();
            notSupported.completeExceptionally(new UnsupportedOperationException());
            return notSupported;
        }

        /** Nothing to release. */
        @Override
        public void close() {}
    }
}
Loading

0 comments on commit f2d57e3

Please sign in to comment.