Skip to content

Commit

Permalink
Add first working test
Browse files Browse the repository at this point in the history
  • Loading branch information
eolivelli committed Nov 21, 2023
1 parent f2d57e3 commit 695480c
Show file tree
Hide file tree
Showing 11 changed files with 516 additions and 144 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,28 @@ public static List<Float> toListOfFloat(Object input) {
result.add(JstlTypeConverter.INSTANCE.coerceToFloat(o));
}
return result;
} else if (input instanceof float[] a) {
List<Float> result = new ArrayList<>(a.length);
for (Object o : a) {
result.add(JstlTypeConverter.INSTANCE.coerceToFloat(o));
}
return result;
} else {
throw new IllegalArgumentException("Cannot convert " + input + " to list of float");
}
}

public static float[] toArrayOfFloat(Object input) {
if (input == null) {
return null;
}
if (input instanceof Collection<?> collection) {
float[] result = new float[collection.size()];
int i = 0;
for (Object o : collection) {
result[i++] = JstlTypeConverter.INSTANCE.coerceToFloat(o);
}
return result;
} else {
throw new IllegalArgumentException("Cannot convert " + input + " to list of float");
}
Expand Down

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class AstraCollectionsAssetsManagerProvider implements AssetManagerProvider {
public class AstraVectorDBAssetsManagerProvider implements AssetManagerProvider {

@Override
public boolean supports(String assetType) {
Expand All @@ -44,7 +44,7 @@ public AssetManager createInstance(String assetType) {

private abstract static class BaseAstraAssetManager implements AssetManager {

AstraCollectionsDataSource datasource;
AstraVectorDBDataSource datasource;
AssetDefinition assetDefinition;

@Override
Expand Down Expand Up @@ -80,7 +80,8 @@ public void deployAsset() throws Exception {
}

private String getCollection() {
return ConfigurationUtils.getString("collection", null, assetDefinition.getConfig());
return ConfigurationUtils.getString(
"collection-name", null, assetDefinition.getConfig());
}

private int getVectorDimension() {
Expand All @@ -105,8 +106,8 @@ public boolean deleteAssetIfExists() throws Exception {
}
}

private static AstraCollectionsDataSource buildDataSource(AssetDefinition assetDefinition) {
AstraCollectionsDataSource dataSource = new AstraCollectionsDataSource();
private static AstraVectorDBDataSource buildDataSource(AssetDefinition assetDefinition) {
AstraVectorDBDataSource dataSource = new AstraVectorDBDataSource();
Map<String, Object> datasourceDefinition =
ConfigurationUtils.getMap("datasource", Map.of(), assetDefinition.getConfig());
Map<String, Object> configuration =
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.langstream.agents.vector.astra;

import ai.langstream.agents.vector.InterpolationUtils;
import ai.langstream.ai.agents.commons.jstl.JstlFunctions;
import ai.langstream.api.util.ConfigurationUtils;
import com.datastax.oss.streaming.ai.datasource.QueryStepDataSource;
import com.dtsx.astra.sdk.AstraDB;
import io.stargate.sdk.json.CollectionClient;
import io.stargate.sdk.json.domain.Filter;
import io.stargate.sdk.json.domain.JsonResult;
import io.stargate.sdk.json.domain.SelectQuery;
import io.stargate.sdk.json.domain.SelectQueryBuilder;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;

/**
 * {@link QueryStepDataSource} implementation that runs read queries against a collection
 * through the {@link AstraDB} client.
 *
 * <p>Queries are JSON documents. After interpolation of the positional parameters the resulting
 * map is interpreted as follows:
 *
 * <ul>
 *   <li>{@code collection-name} (required): the collection to query;
 *   <li>{@code vector} (optional): when present a similarity search is executed;
 *   <li>{@code max} (optional): maximum number of results for the similarity search;
 *   <li>every other entry: an equality condition on the field with that name.
 * </ul>
 *
 * <p>Write statements are not supported, see {@link #executeStatement(String, List, List)}.
 */
@Slf4j
public class AstraVectorDBDataSource implements QueryStepDataSource {

    /** Client created by {@link #initialize(Map)}; {@code null} until then. */
    AstraDB astraDB;

    /**
     * Creates the {@link AstraDB} client from the {@code token} and {@code endpoint} entries of
     * the configuration. Missing entries default to the empty string.
     *
     * @param dataSourceConfig the datasource configuration
     */
    @Override
    public void initialize(Map<String, Object> dataSourceConfig) {
        log.info(
                "Initializing AstraVectorDBDataSource with config {}",
                ConfigurationUtils.redactSecrets(dataSourceConfig));
        String astraToken = ConfigurationUtils.getString("token", "", dataSourceConfig);
        String astraEndpoint = ConfigurationUtils.getString("endpoint", "", dataSourceConfig);
        this.astraDB = new AstraDB(astraToken, astraEndpoint);
    }

    /** No-op: this class does not release anything on close. */
    @Override
    public void close() {}

    /**
     * Executes a query against a collection and returns one map per matching document.
     *
     * <p>Each returned map contains the document data (when present), plus a {@code similarity}
     * entry and a {@code vector} entry (as a list of floats) when the result carries them.
     *
     * @param query the JSON query template
     * @param params the values interpolated into the template
     * @return the matching documents, one map per document
     * @throws UnsupportedOperationException if the interpolated query is empty or does not define
     *     {@code collection-name}
     * @throws IllegalArgumentException if {@code max} is present but is not a number
     */
    @Override
    public List<Map<String, Object>> fetchData(String query, List<Object> params) {
        if (log.isDebugEnabled()) {
            log.debug(
                    "Executing query {} with params {} ({})",
                    query,
                    params,
                    params.stream()
                            .map(v -> v == null ? "null" : v.getClass().toString())
                            .collect(Collectors.joining(",")));
        }
        Map<String, Object> queryMap =
                InterpolationUtils.buildObjectFromJson(query, Map.class, params);
        if (queryMap.isEmpty()) {
            throw new UnsupportedOperationException("Query is empty");
        }
        // The control keys are removed from the map so that only the user supplied
        // conditions are left when the filter is built below. Leaving "collection-name"
        // in place would add an equality condition on a field with that name.
        String collectionName = (String) queryMap.remove("collection-name");
        if (collectionName == null) {
            throw new UnsupportedOperationException("collection-name is not defined");
        }
        CollectionClient collection = this.getAstraDB().collection(collectionName);
        List<JsonResult> result;

        float[] vector = JstlFunctions.toArrayOfFloat(queryMap.remove("vector"));
        Integer max = toMaxResults(queryMap.remove("max"));

        if (vector != null) {
            Filter filter = new Filter();
            queryMap.forEach((k, v) -> filter.where(k).isEqualsTo(v));
            log.info(
                    "doing similarity search with filter {} max {} and vector {}",
                    filter,
                    max,
                    vector);
            result = collection.similaritySearch(vector, filter, max);
        } else {
            SelectQueryBuilder selectQueryBuilder =
                    SelectQuery.builder().includeSimilarity().select("*");
            queryMap.forEach((k, v) -> selectQueryBuilder.where(k).isEqualsTo(v));

            SelectQuery selectQuery = selectQueryBuilder.build();
            log.info("doing query {}", selectQuery);

            result = collection.query(selectQuery).toList();
        }

        return result.stream()
                .map(
                        m -> {
                            Map<String, Object> r = new HashMap<>();
                            if (m.getData() != null) {
                                r.putAll(m.getData());
                            }
                            if (m.getSimilarity() != null) {
                                r.put("similarity", m.getSimilarity());
                            }
                            if (m.getVector() != null) {
                                r.put("vector", JstlFunctions.toListOfFloat(m.getVector()));
                            }
                            return r;
                        })
                .collect(Collectors.toList());
    }

    /**
     * Converts the raw {@code max} entry of a query into a result limit.
     *
     * @param value the raw value, may be {@code null}
     * @return {@link Integer#MAX_VALUE} when the value is {@code null}, otherwise its int value
     * @throws IllegalArgumentException if the value is not a {@link Number}
     */
    private static Integer toMaxResults(Object value) {
        if (value == null) {
            return Integer.MAX_VALUE;
        }
        // Accept any Number subtype (e.g. Long), not only Integer.
        if (value instanceof Number) {
            return ((Number) value).intValue();
        }
        throw new IllegalArgumentException("max must be a number, found " + value);
    }

    /**
     * Write statements are not supported by this datasource.
     *
     * @param query the statement
     * @param generatedKeys unused
     * @param params the statement parameters, only used for debug logging
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public Map<String, Object> executeStatement(
            String query, List<String> generatedKeys, List<Object> params) {
        if (log.isDebugEnabled()) {
            log.debug(
                    "Executing statement {} with params {} ({})",
                    query,
                    params,
                    params.stream()
                            .map(v -> v == null ? "null" : v.getClass().toString())
                            .collect(Collectors.joining(",")));
        }
        throw new UnsupportedOperationException();
    }

    /**
     * Returns the client created by {@link #initialize(Map)}.
     *
     * @return the client, or {@code null} if {@link #initialize(Map)} has not been called
     */
    public AstraDB getAstraDB() {
        return astraDB;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class AstraCollectionsDataSourceProvider implements DataSourceProvider {
public class AstraVectorDBDataSourceProvider implements DataSourceProvider {

@Override
public boolean supports(Map<String, Object> dataSourceConfig) {
String service = (String) dataSourceConfig.get("service");
return "astra-collections".equals(service);
return "astra-vector-db".equals(service);
}

@Override
public QueryStepDataSource createDataSourceImplementation(
Map<String, Object> dataSourceConfig) {
return new AstraCollectionsDataSource();
return new AstraVectorDBDataSource();
}
}
Loading

0 comments on commit 695480c

Please sign in to comment.