Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Initial cut for a cuVS Java API #450

Open
wants to merge 8 commits into
base: branch-24.12
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
## common
__pycache__
.gitignore
*.pyc
*~
\#*
Expand Down Expand Up @@ -82,3 +83,8 @@ ivf_pq_index
# cuvs_bench
datasets/
/*.json

# java
.classpath


12 changes: 10 additions & 2 deletions build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@ ARGS=$*
# scripts, and that this script resides in the repo dir!
REPODIR=$(cd $(dirname $0); pwd)

VALIDARGS="clean libcuvs python rust docs tests bench-ann examples --uninstall -v -g -n --compile-static-lib --allgpuarch --no-mg --no-cpu --cpu-only --no-shared-libs --no-nvtx --show_depr_warn --incl-cache-stats --time -h"
VALIDARGS="clean libcuvs python rust java docs tests bench-ann examples --uninstall -v -g -n --compile-static-lib --allgpuarch --no-mg --no-cpu --cpu-only --no-shared-libs --no-nvtx --show_depr_warn --incl-cache-stats --time -h"
HELP="$0 [<target> ...] [<flag> ...] [--cmake-args=\"<args>\"] [--cache-tool=<tool>] [--limit-tests=<targets>] [--limit-bench-ann=<targets>] [--build-metrics=<filename>]
where <target> is:
clean - remove all existing build artifacts and configuration (start over)
libcuvs - build the cuvs C++ code only. Also builds the C-wrapper library
around the C++ code.
python - build the cuvs Python package
rust - build the cuvs Rust bindings
java - build the cuvs Java bindings
docs - build the documentation
tests - build the tests
bench-ann - build end-to-end ann benchmarks
Expand Down Expand Up @@ -61,7 +62,8 @@ SPHINX_BUILD_DIR=${REPODIR}/docs
DOXYGEN_BUILD_DIR=${REPODIR}/cpp/doxygen
PYTHON_BUILD_DIR=${REPODIR}/python/cuvs/_skbuild
RUST_BUILD_DIR=${REPODIR}/rust/target
BUILD_DIRS="${LIBCUVS_BUILD_DIR} ${PYTHON_BUILD_DIR} ${RUST_BUILD_DIR}"
JAVA_BUILD_DIR=${REPODIR}/java/cuvs-java/target
BUILD_DIRS="${LIBCUVS_BUILD_DIR} ${PYTHON_BUILD_DIR} ${RUST_BUILD_DIR} ${JAVA_BUILD_DIR}"

# Set defaults for vars modified by flags to this script
CMAKE_LOG_LEVEL=""
Expand Down Expand Up @@ -459,6 +461,12 @@ if (( ${NUMARGS} == 0 )) || hasArg rust; then
cargo test
fi

# Build the cuvs Java bindings.
# Delegates to java/build.sh, which must be run from inside the java/ directory.
if (( ${NUMARGS} == 0 )) || hasArg java; then
    # Run in a subshell so this script's working directory is left untouched for
    # later steps; quote the path so a checkout location containing spaces works,
    # and only invoke the nested script if the directory change succeeded.
    (cd "${REPODIR}/java" && ./build.sh)
fi

export RAPIDS_VERSION="$(sed -E -e 's/^([0-9]{2})\.([0-9]{2})\.([0-9]{2}).*$/\1.\2.\3/' "${REPODIR}/VERSION")"
export RAPIDS_VERSION_MAJOR_MINOR="$(sed -E -e 's/^([0-9]{2})\.([0-9]{2})\.([0-9]{2}).*$/\1.\2/' "${REPODIR}/VERSION")"

Expand Down
12 changes: 12 additions & 0 deletions java/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
Prerequisites
-------------

* JDK 24
* Maven 3.9.6 or later

Please build libcuvs (`./build.sh libcuvs` from the top-level directory) before building the Java API.

Building
--------

`./build.sh` generates the `libcuvs_java.so` file in the `internal/` directory, and then builds the final jar file for the cuVS Java API in the `cuvs-java/` directory.
9 changes: 9 additions & 0 deletions java/build.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#!/bin/bash
#
# Builds the cuVS Java bindings in three steps:
#   1. compile the native helper library (internal/libcuvs_java.so) with CMake,
#   2. install that .so into the local Maven repository as
#      com.nvidia.cuvs:cuvs-java-internal:0.1 (packaging "so"),
#   3. package the cuvs-java Maven module.
#
# libcuvs must already have been built under ../cpp/build.

# Abort on the first failing command, on use of an unset variable, and on a
# failure anywhere in a pipeline, instead of carrying on with stale artifacts.
set -e -u -o pipefail

# Resolve every path from this script's own location so the build does not
# depend on the directory it was invoked from.
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)

# Let CMake find the libcuvs build tree.
export CMAKE_PREFIX_PATH="${SCRIPT_DIR}/../cpp/build"

# Step 1: in-source CMake configure + build of the native library.
cd "${SCRIPT_DIR}/internal"
cmake .
cmake --build .

# Step 2: publish the native library to the local Maven repository so the
# cuvs-java module can resolve it as a regular artifact.
cd "${SCRIPT_DIR}"
mvn install:install-file -DgroupId=com.nvidia.cuvs -DartifactId=cuvs-java-internal -Dversion=0.1 -Dpackaging=so -Dfile=./internal/libcuvs_java.so

# Step 3: build the Java API jar.
cd "${SCRIPT_DIR}/cuvs-java"
mvn package
135 changes: 135 additions & 0 deletions java/cuvs-java/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
<!--
Maven build descriptor for the cuVS Java API (com.nvidia.cuvs:cuvs-java).
Produces a plain jar plus a jar-with-dependencies assembly, and bundles the
native libcuvs_java.so by copying it into the compiled classes directory.
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.nvidia.cuvs</groupId>
<artifactId>cuvs-java</artifactId>
<version>0.1</version>
<name>cuvs-java</name>
<packaging>jar</packaging>

<!--
Compiler source/target level is Java 22.
NOTE(review): java/README.md lists JDK 24 as the prerequisite; confirm which
JDK level is actually intended.
-->
<properties>
<maven.compiler.target>22</maven.compiler.target>
<maven.compiler.source>22</maven.compiler.source>
</properties>

<dependencies>
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
<version>2.15.1</version>
</dependency>

<dependency>
<groupId>com.github.fommil</groupId>
<artifactId>jniloader</artifactId>
<version>1.1</version>
</dependency>

<!-- Logging facade used by the Java sources (see ExampleApp). -->
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>2.0.13</version>
</dependency>

<!-- Concrete SLF4J backend; runtime scope keeps it off the compile classpath. -->
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-simple</artifactId>
<version>2.0.13</version>
<scope>runtime</scope>
</dependency>

<!--
NOTE(review): no scope is declared here, so Maven applies the default
(compile) scope and JUnit is pulled into the jar-with-dependencies assembly.
Confirm whether test scope was intended.
-->
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-api</artifactId>
<version>5.10.0</version>
</dependency>

</dependencies>

<build>
<plugins>
<!--
Test runner. Points java.library.path at the classes directory, which is
where the dependency plugin below places libcuvs_java.so.
NOTE(review): Surefire 2.7 predates the JUnit Platform support that was
added in Surefire 2.22.0, so JUnit 5 tests may not be discovered; confirm.
-->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<version>2.7</version>
<configuration>
<systemPropertyVariables>
<java.library.path>${project.build.directory}/classes</java.library.path>
</systemPropertyVariables>
</configuration>
</plugin>
<!--
During the compile phase, copies the com.nvidia.cuvs:cuvs-java-internal:0.1
artifact of type "so" (installed into the local repository by java/build.sh)
into the classes directory under the name libcuvs_java.so, so that it ends
up inside the jar.
NOTE(review): overWrite is false, so a copy already present in the classes
directory is presumably kept even after the native library is rebuilt;
verify whether a clean build is needed in that case.
-->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-dependency-plugin</artifactId>
<version>2.10</version>
<executions>
<execution>
<id>copy</id>
<phase>compile</phase>
<goals>
<goal>copy</goal>
</goals>
<configuration>
<artifactItems>
<artifactItem>
<groupId>com.nvidia.cuvs</groupId>
<artifactId>cuvs-java-internal</artifactId>
<version>0.1</version>
<type>so</type>
<overWrite>false</overWrite>
<outputDirectory>
${project.build.directory}/classes</outputDirectory>
<destFileName>libcuvs_java.so</destFileName>
</artifactItem>
</artifactItems>
</configuration>
</execution>
</executions>
</plugin>

<!--
During the package phase, builds an additional jar-with-dependencies
archive whose manifest main class is com.nvidia.cuvs.ExampleApp.
-->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-assembly-plugin</artifactId>
<version>3.4.2</version>
<configuration>
<descriptorRefs>
<descriptorRef>jar-with-dependencies</descriptorRef>
</descriptorRefs>
<archiverConfig>
<duplicateBehavior>add</duplicateBehavior>
</archiverConfig>
<archive>
<manifest>
<mainClass>
com.nvidia.cuvs.ExampleApp</mainClass>
</manifest>
</archive>
</configuration>
<executions>
<execution>
<id>assemble-all</id>
<phase>package</phase>
<goals>
<goal>single</goal>
</goals>
</execution>
</executions>
</plugin>
<!--
Regular (thin) jar: the manifest gets a Class-Path entry for the
dependencies and the same ExampleApp main class.
-->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<version>2.2</version>
<configuration>
<archive>
<manifest>
<addClasspath>true</addClasspath>
<mainClass>
com.nvidia.cuvs.ExampleApp</mainClass>
</manifest>
</archive>
</configuration>
</plugin>
</plugins>
</build>
</project>
75 changes: 75 additions & 0 deletions java/cuvs-java/src/main/java/com/nvidia/cuvs/ExampleApp.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package com.nvidia.cuvs;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.InputStream;
import java.util.Map;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.nvidia.cuvs.cagra.CagraIndex;
import com.nvidia.cuvs.cagra.CagraIndexParams;
import com.nvidia.cuvs.cagra.CagraIndexParams.CuvsCagraGraphBuildAlgo;
import com.nvidia.cuvs.cagra.CagraSearchParams;
import com.nvidia.cuvs.cagra.CuVSQuery;
import com.nvidia.cuvs.cagra.CuVSResources;
import com.nvidia.cuvs.cagra.SearchResult;

/**
 * Small end-to-end walkthrough of the cuVS Java API.
 *
 * <p>The example builds a CAGRA index from an in-memory dataset, serializes it
 * to a file, loads a second index back from that file, and runs the same
 * query against both indexes, logging the results.
 */
public class ExampleApp {

  private static final Logger LOGGER = LoggerFactory.getLogger(ExampleApp.class);

  /** File (relative to the working directory) the index is written to and read back from. */
  private static final String INDEX_FILE = "abc.cag";

  /**
   * Runs the example.
   *
   * @param args command-line arguments; not used
   * @throws Throwable if index creation, (de)serialization, file I/O or search fails
   */
  public static void main(String[] args) throws Throwable {

    // Sample data and query
    float[][] dataset = {
        { 0.74021935f, 0.9209938f },
        { 0.03902049f, 0.9689629f },
        { 0.92514056f, 0.4463501f },
        { 0.6673192f, 0.10993068f } };
    Map<Integer, Integer> map = Map.of(0, 0, 1, 1, 2, 2, 3, 3);
    float[][] queries = {
        { 0.48216683f, 0.0428398f },
        { 0.5084142f, 0.6545497f },
        { 0.51260436f, 0.2643005f },
        { 0.05198065f, 0.5789965f } };

    CuVSResources res = new CuVSResources();

    CagraIndexParams cagraIndexParams = new CagraIndexParams.Builder()
        .withIntermediateGraphDegree(10)
        .withBuildAlgo(CuvsCagraGraphBuildAlgo.NN_DESCENT)
        .build();

    CagraSearchParams cagraSearchParams = new CagraSearchParams
        .Builder()
        .build();

    // Creating a new CAGRA index
    CagraIndex index = new CagraIndex.Builder(res)
        .withDataset(dataset)
        .withIndexParams(cagraIndexParams)
        .build();

    // Saving the index on to the disk. The stream is closed before the file is
    // reopened for reading so the file handle is not leaked.
    try (FileOutputStream out = new FileOutputStream(INDEX_FILE)) {
      index.serialize(out);
    }

    // Loading a CAGRA index from disk. The input stream stays open until both
    // searches have finished, so the loaded index never sees a closed stream,
    // and it is closed reliably afterwards even if a search throws.
    try (InputStream fin = new FileInputStream(new File(INDEX_FILE))) {
      CagraIndex index2 = new CagraIndex.Builder(res)
          .from(fin)
          .build();

      // Query
      CuVSQuery query = new CuVSQuery.Builder()
          .withTopK(3)
          .withSearchParams(cagraSearchParams)
          .withQueryVectors(queries)
          .withMapping(map)
          .build();

      // Search
      SearchResult rslt = index.search(query);
      LOGGER.info(rslt.getResults().toString());

      // Search from de-serialized index
      SearchResult rslt2 = index2.search(query);
      LOGGER.info(rslt2.getResults().toString());
    }
  }
}
Loading
Loading