From 695480c500c6e1c24c46cf9630b5f7b95b24e69f Mon Sep 17 00:00:00 2001 From: Enrico Olivelli Date: Tue, 21 Nov 2023 17:47:11 +0100 Subject: [PATCH] Add first working test --- .../ai/agents/commons/jstl/JstlFunctions.java | 22 +++ .../astra/AstraCollectionsDataSource.java | 76 -------- .../vector/astra/AstraCollectionsWriter.java | 57 ------ ...> AstraVectorDBAssetsManagerProvider.java} | 11 +- .../vector/astra/AstraVectorDBDataSource.java | 137 ++++++++++++++ ...a => AstraVectorDBDataSourceProvider.java} | 6 +- .../vector/astra/AstraVectorDBWriter.java | 178 ++++++++++++++++++ ...am.ai.agents.datasource.DataSourceProvider | 2 +- ...eam.api.runner.assets.AssetManagerProvider | 2 +- .../datasource/impl/AstraVectorDBTest.java | 167 ++++++++++++++++ .../api/database/VectorDatabaseWriter.java | 2 +- 11 files changed, 516 insertions(+), 144 deletions(-) delete mode 100644 langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/astra/AstraCollectionsDataSource.java delete mode 100644 langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/astra/AstraCollectionsWriter.java rename langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/astra/{AstraCollectionsAssetsManagerProvider.java => AstraVectorDBAssetsManagerProvider.java} (90%) create mode 100644 langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/astra/AstraVectorDBDataSource.java rename langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/astra/{AstraCollectionsDataSourceProvider.java => AstraVectorDBDataSourceProvider.java} (86%) create mode 100644 langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/astra/AstraVectorDBWriter.java create mode 100644 langstream-agents/langstream-vector-agents/src/test/java/ai/langstream/agents/vector/datasource/impl/AstraVectorDBTest.java diff --git a/langstream-agents/langstream-agents-commons/src/main/java/ai/langstream/ai/agents/commons/jstl/JstlFunctions.java b/langstream-agents/langstream-agents-commons/src/main/java/ai/langstream/ai/agents/commons/jstl/JstlFunctions.java index cedf455d7..409abaf37 100644 --- a/langstream-agents/langstream-agents-commons/src/main/java/ai/langstream/ai/agents/commons/jstl/JstlFunctions.java +++ b/langstream-agents/langstream-agents-commons/src/main/java/ai/langstream/ai/agents/commons/jstl/JstlFunctions.java @@ -100,6 +100,28 @@ public static List toListOfFloat(Object input) { result.add(JstlTypeConverter.INSTANCE.coerceToFloat(o)); } return result; + } else if (input instanceof float[] a) { + List result = new ArrayList<>(a.length); + for (Object o : a) { + result.add(JstlTypeConverter.INSTANCE.coerceToFloat(o)); + } + return result; + } else { + throw new IllegalArgumentException("Cannot convert " + input + " to list of float"); + } + } + + public static float[] toArrayOfFloat(Object input) { + if (input == null) { + return null; + } + if (input instanceof Collection collection) { + float[] result = new float[collection.size()]; + int i = 0; + for (Object o : collection) { + result[i++] = JstlTypeConverter.INSTANCE.coerceToFloat(o); + } + return result; } else { throw new IllegalArgumentException("Cannot convert " + input + " to list of float"); } diff --git a/langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/astra/AstraCollectionsDataSource.java b/langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/astra/AstraCollectionsDataSource.java deleted file mode 100644 index 1ed0caee2..000000000 --- a/langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/astra/AstraCollectionsDataSource.java +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Copyright DataStax, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package ai.langstream.agents.vector.astra; - -import ai.langstream.api.util.ConfigurationUtils; -import com.datastax.oss.streaming.ai.datasource.QueryStepDataSource; -import com.dtsx.astra.sdk.AstraDB; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; -import lombok.extern.slf4j.Slf4j; - -@Slf4j -public class AstraCollectionsDataSource implements QueryStepDataSource { - - AstraDB astraDB; - - @Override - public void initialize(Map dataSourceConfig) { - log.info( - "Initializing CassandraDataSource with config {}", - ConfigurationUtils.redactSecrets(dataSourceConfig)); - String astraToken = ConfigurationUtils.getString("token", "", dataSourceConfig); - String astraEndpoint = ConfigurationUtils.getString("endpoint", "", dataSourceConfig); - this.astraDB = new AstraDB(astraToken, astraEndpoint); - } - - @Override - public void close() {} - - @Override - public List> fetchData(String query, List params) { - if (log.isDebugEnabled()) { - log.debug( - "Executing query {} with params {} ({})", - query, - params, - params.stream() - .map(v -> v == null ? "null" : v.getClass().toString()) - .collect(Collectors.joining(","))); - } - throw new UnsupportedOperationException(); - } - - @Override - public Map executeStatement( - String query, List generatedKeys, List params) { - if (log.isDebugEnabled()) { - log.debug( - "Executing statement {} with params {} ({})", - query, - params, - params.stream() - .map(v -> v == null ? "null" : v.getClass().toString()) - .collect(Collectors.joining(","))); - } - throw new UnsupportedOperationException(); - } - - public AstraDB getAstraDB() { - return astraDB; - } -} diff --git a/langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/astra/AstraCollectionsWriter.java b/langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/astra/AstraCollectionsWriter.java deleted file mode 100644 index 366dbadf3..000000000 --- a/langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/astra/AstraCollectionsWriter.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright DataStax, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package ai.langstream.agents.vector.astra; - -import ai.langstream.api.database.VectorDatabaseWriter; -import ai.langstream.api.database.VectorDatabaseWriterProvider; -import ai.langstream.api.runner.code.Record; -import java.util.Map; -import java.util.concurrent.CompletableFuture; -import lombok.extern.slf4j.Slf4j; - -@Slf4j -public class AstraCollectionsWriter implements VectorDatabaseWriterProvider { - - @Override - public boolean supports(Map dataSourceConfig) { - return "astra-collections".equals(dataSourceConfig.get("service")); - } - - @Override - public VectorDatabaseWriter createImplementation(Map datasourceConfig) { - return new AstraCollectionsDatabaseWriter(datasourceConfig); - } - - private static class AstraCollectionsDatabaseWriter implements VectorDatabaseWriter { - - private final Map datasourceConfig; - - public AstraCollectionsDatabaseWriter(Map datasourceConfig) { - this.datasourceConfig = datasourceConfig; - } - - @Override - public void initialise(Map agentConfiguration) {} - - @Override - public CompletableFuture upsert(Record record, Map context) { - return CompletableFuture.failedFuture(new UnsupportedOperationException()); - } - - @Override - public void close() {} - } -} diff --git a/langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/astra/AstraCollectionsAssetsManagerProvider.java b/langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/astra/AstraVectorDBAssetsManagerProvider.java similarity index 90% rename from langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/astra/AstraCollectionsAssetsManagerProvider.java rename to langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/astra/AstraVectorDBAssetsManagerProvider.java index c915b4ab3..1b1a69870 100644 --- a/langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/astra/AstraCollectionsAssetsManagerProvider.java +++ b/langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/astra/AstraVectorDBAssetsManagerProvider.java @@ -24,7 +24,7 @@ import lombok.extern.slf4j.Slf4j; @Slf4j -public class AstraCollectionsAssetsManagerProvider implements AssetManagerProvider { +public class AstraVectorDBAssetsManagerProvider implements AssetManagerProvider { @Override public boolean supports(String assetType) { @@ -44,7 +44,7 @@ public AssetManager createInstance(String assetType) { private abstract static class BaseAstraAssetManager implements AssetManager { - AstraCollectionsDataSource datasource; + AstraVectorDBDataSource datasource; AssetDefinition assetDefinition; @Override @@ -80,7 +80,8 @@ public void deployAsset() throws Exception { } private String getCollection() { - return ConfigurationUtils.getString("collection", null, assetDefinition.getConfig()); + return ConfigurationUtils.getString( + "collection-name", null, assetDefinition.getConfig()); } private int getVectorDimension() { @@ -105,8 +106,8 @@ public boolean deleteAssetIfExists() throws Exception { } } - private static AstraCollectionsDataSource buildDataSource(AssetDefinition assetDefinition) { - AstraCollectionsDataSource dataSource = new AstraCollectionsDataSource(); + private static AstraVectorDBDataSource buildDataSource(AssetDefinition assetDefinition) { + AstraVectorDBDataSource dataSource = new AstraVectorDBDataSource(); Map datasourceDefinition = ConfigurationUtils.getMap("datasource", Map.of(), assetDefinition.getConfig()); Map configuration = diff --git a/langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/astra/AstraVectorDBDataSource.java b/langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/astra/AstraVectorDBDataSource.java new file mode 100644 index 000000000..22bb1542d --- /dev/null +++ b/langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/astra/AstraVectorDBDataSource.java @@ -0,0 +1,137 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.langstream.agents.vector.astra; + +import ai.langstream.agents.vector.InterpolationUtils; +import ai.langstream.ai.agents.commons.jstl.JstlFunctions; +import ai.langstream.api.util.ConfigurationUtils; +import com.datastax.oss.streaming.ai.datasource.QueryStepDataSource; +import com.dtsx.astra.sdk.AstraDB; +import io.stargate.sdk.json.CollectionClient; +import io.stargate.sdk.json.domain.Filter; +import io.stargate.sdk.json.domain.JsonResult; +import io.stargate.sdk.json.domain.SelectQuery; +import io.stargate.sdk.json.domain.SelectQueryBuilder; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +public class AstraVectorDBDataSource implements QueryStepDataSource { + + AstraDB astraDB; + + @Override + public void initialize(Map dataSourceConfig) { + log.info( + "Initializing CassandraDataSource with config {}", + ConfigurationUtils.redactSecrets(dataSourceConfig)); + String astraToken = ConfigurationUtils.getString("token", "", dataSourceConfig); + String astraEndpoint = ConfigurationUtils.getString("endpoint", "", dataSourceConfig); + this.astraDB = new AstraDB(astraToken, astraEndpoint); + } + + @Override + public void close() {} + + @Override + public List> fetchData(String query, List params) { + if (log.isDebugEnabled()) { + log.debug( + "Executing query {} with params {} ({})", + query, + params, + params.stream() + .map(v -> v == null ? "null" : v.getClass().toString()) + .collect(Collectors.joining(","))); + } + Map queryMap = + InterpolationUtils.buildObjectFromJson(query, Map.class, params); + if (queryMap.isEmpty()) { + throw new UnsupportedOperationException("Query is empty"); + } + String collectionName = (String) queryMap.get("collection-name"); + if (collectionName == null) { + throw new UnsupportedOperationException("collection-name is not defined"); + } + CollectionClient collection = this.getAstraDB().collection(collectionName); + List result; + + float[] vector = JstlFunctions.toArrayOfFloat(queryMap.remove("vector")); + Integer max = (Integer) queryMap.remove("max"); + + if (max == null) { + max = Integer.MAX_VALUE; + } + if (vector != null) { + Filter filter = new Filter(); + queryMap.forEach((k, v) -> filter.where(k).isEqualsTo(v)); + log.info( + "doing similarity search with filter {} max {} and vector {}", + filter, + max, + vector); + result = collection.similaritySearch(vector, filter, max); + } else { + SelectQueryBuilder selectQueryBuilder = + SelectQuery.builder().includeSimilarity().select("*"); + queryMap.forEach((k, v) -> selectQueryBuilder.where(k).isEqualsTo(v)); + + SelectQuery selectQuery = selectQueryBuilder.build(); + log.info("doing query {}", selectQuery); + + result = collection.query(selectQuery).toList(); + } + + return result.stream() + .map( + m -> { + Map r = new HashMap<>(); + if (m.getData() != null) { + r.putAll(m.getData()); + } + if (m.getSimilarity() != null) { + r.put("similarity", m.getSimilarity()); + } + if (m.getVector() != null) { + r.put("vector", JstlFunctions.toListOfFloat(m.getVector())); + } + return r; + }) + .collect(Collectors.toList()); + } + + @Override + public Map executeStatement( + String query, List generatedKeys, List params) { + if (log.isDebugEnabled()) { + log.debug( + "Executing statement {} with params {} ({})", + query, + params, + params.stream() + .map(v -> v == null ? "null" : v.getClass().toString()) + .collect(Collectors.joining(","))); + } + throw new UnsupportedOperationException(); + } + + public AstraDB getAstraDB() { + return astraDB; + } +} diff --git a/langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/astra/AstraCollectionsDataSourceProvider.java b/langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/astra/AstraVectorDBDataSourceProvider.java similarity index 86% rename from langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/astra/AstraCollectionsDataSourceProvider.java rename to langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/astra/AstraVectorDBDataSourceProvider.java index 8971b8162..3360ce61b 100644 --- a/langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/astra/AstraCollectionsDataSourceProvider.java +++ b/langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/astra/AstraVectorDBDataSourceProvider.java @@ -21,17 +21,17 @@ import lombok.extern.slf4j.Slf4j; @Slf4j -public class AstraCollectionsDataSourceProvider implements DataSourceProvider { +public class AstraVectorDBDataSourceProvider implements DataSourceProvider { @Override public boolean supports(Map dataSourceConfig) { String service = (String) dataSourceConfig.get("service"); - return "astra-collections".equals(service); + return "astra-vector-db".equals(service); } @Override public QueryStepDataSource createDataSourceImplementation( Map dataSourceConfig) { - return new AstraCollectionsDataSource(); + return new AstraVectorDBDataSource(); } } diff --git a/langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/astra/AstraVectorDBWriter.java b/langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/astra/AstraVectorDBWriter.java new file mode 100644 index 000000000..ad0aa5516 --- /dev/null +++ b/langstream-agents/langstream-vector-agents/src/main/java/ai/langstream/agents/vector/astra/AstraVectorDBWriter.java @@ -0,0 +1,178 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.langstream.agents.vector.astra; + +import static ai.langstream.ai.agents.commons.MutableRecord.recordToMutableRecord; + +import ai.langstream.ai.agents.commons.MutableRecord; +import ai.langstream.ai.agents.commons.jstl.JstlEvaluator; +import ai.langstream.ai.agents.commons.jstl.JstlFunctions; +import ai.langstream.api.database.VectorDatabaseWriter; +import ai.langstream.api.database.VectorDatabaseWriterProvider; +import ai.langstream.api.runner.code.Record; +import ai.langstream.api.util.ConfigurationUtils; +import io.stargate.sdk.json.CollectionClient; +import io.stargate.sdk.json.domain.JsonDocument; +import io.stargate.sdk.json.domain.UpdateQuery; +import io.stargate.sdk.json.exception.ApiException; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.function.BiConsumer; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +public class AstraVectorDBWriter implements VectorDatabaseWriterProvider { + + @Override + public boolean supports(Map dataSourceConfig) { + return "astra-vector-db".equals(dataSourceConfig.get("service")); + } + + @Override + public VectorDatabaseWriter createImplementation(Map datasourceConfig) { + return new AstraCollectionsDatabaseWriter(datasourceConfig); + } + + private static class AstraCollectionsDatabaseWriter implements VectorDatabaseWriter { + + AstraVectorDBDataSource dataSource; + private final Map datasourceConfig; + private String collectionName; + private CollectionClient collection; + + private final LinkedHashMap fields = new LinkedHashMap<>(); + + public AstraCollectionsDatabaseWriter(Map datasourceConfig) { + this.datasourceConfig = datasourceConfig; + this.dataSource = new AstraVectorDBDataSource(); + } + + @Override + public void initialise(Map agentConfiguration) { + collectionName = + ConfigurationUtils.getString("collection-name", "", agentConfiguration); + dataSource.initialize(datasourceConfig); + collection = dataSource.getAstraDB().collection(collectionName); + + List> fields = + (List>) + agentConfiguration.getOrDefault("fields", List.of()); + fields.forEach( + field -> { + this.fields.put( + field.get("name").toString(), + buildEvaluator(field, "expression", Object.class)); + }); + } + + @Override + public CompletableFuture upsert(Record record, Map context) { + MutableRecord mutableRecord = recordToMutableRecord(record, true); + JsonDocument document = new JsonDocument(); + try { + computeFields( + mutableRecord, + fields, + (name, value) -> { + if (value != null) { + log.info("Field {} value {}", name, value); + switch (name) { + case "vector": + document.vector(JstlFunctions.toArrayOfFloat(value)); + break; + case "id": + document.id(value.toString()); + break; + case "data": + document.data(value); + break; + default: + document.put(name, value); + break; + } + } + }); + if (record.value() == null) { + int count = collection.deleteById(document.getId()); + if (count > 0) { + log.info("Deleted document with id {}", document.getId()); + } else { + log.info("No document with id {} to delete", document.getId()); + } + return CompletableFuture.completedFuture(document.getId()); + } else { + + try { + String id = collection.insertOne(document); + log.info("Inserted document with id {}", id); + return CompletableFuture.completedFuture(id); + } catch (ApiException e) { + String message = e.getMessage() + ""; + // TODO: have a way to get the error code + if (message.contains("Document already exists")) { + collection. // Already Exist + findOneAndReplace( + UpdateQuery.builder() + .where("_id") + .isEqualsTo(document.getId()) + .replaceBy(document) + .build()); + return CompletableFuture.completedFuture(document.getId()); + } else { + return CompletableFuture.failedFuture(e); + } + } + } + + } catch (Throwable e) { + log.error("Error while inserting document", e); + return CompletableFuture.failedFuture(e); + } + } + + @Override + public void close() {} + + private void computeFields( + MutableRecord mutableRecord, + Map fields, + BiConsumer acceptor) { + fields.forEach( + (name, evaluator) -> { + Object value = evaluator.evaluate(mutableRecord); + if (log.isDebugEnabled()) { + log.debug( + "setting value {} ({}) for field {}", + value, + value.getClass(), + name); + } + acceptor.accept(name, value); + }); + } + + private static JstlEvaluator buildEvaluator( + Map agentConfiguration, String param, Class type) { + String expression = agentConfiguration.getOrDefault(param, "").toString(); + if (expression == null || expression.isEmpty()) { + return null; + } + return new JstlEvaluator("${" + expression + "}", type); + } + } +} diff --git a/langstream-agents/langstream-vector-agents/src/main/resources/META-INF/services/ai.langstream.ai.agents.datasource.DataSourceProvider b/langstream-agents/langstream-vector-agents/src/main/resources/META-INF/services/ai.langstream.ai.agents.datasource.DataSourceProvider index e227398ea..f269c5d74 100644 --- a/langstream-agents/langstream-vector-agents/src/main/resources/META-INF/services/ai.langstream.ai.agents.datasource.DataSourceProvider +++ b/langstream-agents/langstream-vector-agents/src/main/resources/META-INF/services/ai.langstream.ai.agents.datasource.DataSourceProvider @@ -2,4 +2,4 @@ ai.langstream.agents.vector.pinecone.PineconeDataSource ai.langstream.agents.vector.milvus.MilvusDataSource ai.langstream.agents.vector.solr.SolrDataSource ai.langstream.agents.vector.opensearch.OpenSearchDataSource -ai.langstream.agents.vector.astra.AstraCollectionsDataSourceProvider \ No newline at end of file +ai.langstream.agents.vector.astra.AstraVectorDBDataSourceProvider \ No newline at end of file diff --git a/langstream-agents/langstream-vector-agents/src/main/resources/META-INF/services/ai.langstream.api.runner.assets.AssetManagerProvider b/langstream-agents/langstream-vector-agents/src/main/resources/META-INF/services/ai.langstream.api.runner.assets.AssetManagerProvider index fd6eb76be..98573dc73 100644 --- a/langstream-agents/langstream-vector-agents/src/main/resources/META-INF/services/ai.langstream.api.runner.assets.AssetManagerProvider +++ b/langstream-agents/langstream-vector-agents/src/main/resources/META-INF/services/ai.langstream.api.runner.assets.AssetManagerProvider @@ -3,4 +3,4 @@ ai.langstream.agents.vector.milvus.MilvusAssetsManagerProvider ai.langstream.agents.vector.jdbc.JdbcAssetsManagerProvider ai.langstream.agents.vector.solr.SolrAssetsManagerProvider ai.langstream.agents.vector.opensearch.OpenSearchAssetsManagerProvider -ai.langstream.agents.vector.astra.AstraCollectionsAssetsManagerProvider \ No newline at end of file +ai.langstream.agents.vector.astra.AstraVectorDBAssetsManagerProvider \ No newline at end of file diff --git a/langstream-agents/langstream-vector-agents/src/test/java/ai/langstream/agents/vector/datasource/impl/AstraVectorDBTest.java b/langstream-agents/langstream-vector-agents/src/test/java/ai/langstream/agents/vector/datasource/impl/AstraVectorDBTest.java new file mode 100644 index 000000000..b5987d166 --- /dev/null +++ b/langstream-agents/langstream-vector-agents/src/test/java/ai/langstream/agents/vector/datasource/impl/AstraVectorDBTest.java @@ -0,0 +1,167 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.langstream.agents.vector.datasource.impl; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import ai.langstream.agents.vector.astra.AstraVectorDBAssetsManagerProvider; +import ai.langstream.agents.vector.astra.AstraVectorDBDataSourceProvider; +import ai.langstream.agents.vector.astra.AstraVectorDBWriter; +import ai.langstream.api.database.VectorDatabaseWriter; +import ai.langstream.api.model.AssetDefinition; +import ai.langstream.api.runner.assets.AssetManager; +import ai.langstream.api.runner.assets.AssetManagerProvider; +import ai.langstream.api.runner.code.SimpleRecord; +import com.datastax.oss.streaming.ai.datasource.QueryStepDataSource; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import lombok.extern.slf4j.Slf4j; +import org.junit.jupiter.api.Test; + +@Slf4j +public class AstraVectorDBTest { + + private static final String TOKEN = + "AstraCS:HQKZyFwTNcNQFPhsLHPHlyYq:0fd08e29b7e7c590e947ac8fa9a4d6d785a4661a8eb1b3c011e2a0d19c2ecd7c"; + private static final String ENDPOINT = + "https://18bdf302-901f-4245-af09-061ebdb480d2-us-east1.apps.astra.datastax.com"; + + @Test + void testWrite() throws Exception { + AstraVectorDBDataSourceProvider dataSourceProvider = new AstraVectorDBDataSourceProvider(); + Map config = Map.of("token", TOKEN, "endpoint", ENDPOINT); + + String collectionName = "documents"; + int dimension = 32; + List vector = new ArrayList<>(); + List vector2 = new ArrayList<>(); + for (int i = 0; i < dimension; i++) { + vector.add(i * 1f / dimension); + vector2.add((i + 1) * 1f / dimension); + } + String vectorAsString = vector.toString(); + String vector2AsString = vector2.toString(); + + try (QueryStepDataSource datasource = + dataSourceProvider.createDataSourceImplementation(config); + VectorDatabaseWriter writer = + new AstraVectorDBWriter().createImplementation(config)) { + datasource.initialize(config); + + AssetManagerProvider assetsManagerProvider = new AstraVectorDBAssetsManagerProvider(); + try (AssetManager tableManager = + assetsManagerProvider.createInstance("astra-collection"); ) { + AssetDefinition assetDefinition = new AssetDefinition(); + assetDefinition.setAssetType("astra-collection"); + assetDefinition.setConfig( + Map.of( + "collection-name", + collectionName, + "datasource", + Map.of("configuration", config), + "vector-dimension", + vector.size())); + tableManager.initialize(assetDefinition); + tableManager.deleteAssetIfExists(); + + assertFalse(tableManager.assetExists()); + tableManager.deployAsset(); + + List> fields = + List.of( + Map.of( + "name", + "id", + "expression", + "fn:concat(key.name,'-',key.chunk_id)"), + Map.of("name", "name", "expression", "key.name"), + Map.of("name", "chunk_id", "expression", "key.chunk_id"), + Map.of( + "name", + "vector", + "expression", + "fn:toListOfFloat(value.vector)"), + Map.of("name", "text", "expression", "value.text")); + + writer.initialise(Map.of("collection-name", collectionName, "fields", fields)); + + // the PK contains a single quote in order to test escaping values in deletion + SimpleRecord record = + SimpleRecord.of( + "{\"name\": \"do'c1\", \"chunk_id\": 1}", + """ + { + "vector": %s, + "text": "Lorem ipsum..." + } + """ + .formatted(vectorAsString)); + writer.upsert(record, Map.of()).get(); + + String query = + """ + { + "collection-name": "%s", + "vector": ?, + "max": 10 + } + """ + .formatted(collectionName); + List params = List.of(vector); + List> results = datasource.fetchData(query, params); + log.info("Results: {}", results); + + assertEquals(1, results.size()); + assertEquals("do'c1", results.get(0).get("name")); + assertEquals("Lorem ipsum...", results.get(0).get("text")); + + SimpleRecord recordUpdated = + SimpleRecord.of( + "{\"name\": \"do'c1\", \"chunk_id\": 1}", + """ + { + "vector": %s, + "text": "Lorem ipsum changed..." + } + """ + .formatted(vector2AsString)); + writer.upsert(recordUpdated, Map.of()).get(); + + List params2 = List.of(vector2); + List> results2 = datasource.fetchData(query, params2); + log.info("Results: {}", results2); + + assertEquals(1, results2.size()); + assertEquals("do'c1", results2.get(0).get("name")); + assertEquals("Lorem ipsum changed...", results2.get(0).get("text")); + + SimpleRecord recordDelete = + SimpleRecord.of("{\"name\": \"do'c1\", \"chunk_id\": 1}", null); + writer.upsert(recordDelete, Map.of()).get(); + + List> results3 = datasource.fetchData(query, params2); + log.info("Results: {}", results3); + assertEquals(0, results3.size()); + + assertTrue(tableManager.assetExists()); + tableManager.deleteAssetIfExists(); + } + } + } +} diff --git a/langstream-api/src/main/java/ai/langstream/api/database/VectorDatabaseWriter.java b/langstream-api/src/main/java/ai/langstream/api/database/VectorDatabaseWriter.java index f890fa5a4..fb9715af2 100644 --- a/langstream-api/src/main/java/ai/langstream/api/database/VectorDatabaseWriter.java +++ b/langstream-api/src/main/java/ai/langstream/api/database/VectorDatabaseWriter.java @@ -23,7 +23,7 @@ * This is the interface for writing to a vector database. this interface is really simple by * intention. For advanced usages users should use Kafka Connect connectors. */ -public interface VectorDatabaseWriter { +public interface VectorDatabaseWriter extends AutoCloseable { default void initialise(Map agentConfiguration) throws Exception {}