Skip to content

Commit

Permalink
Add Java JNI interface for multiple contains
Browse files Browse the repository at this point in the history
Signed-off-by: Chong Gao <[email protected]>
  • Loading branch information
Chong Gao committed Nov 8, 2024
1 parent 150d8d8 commit 3ae0ecd
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 0 deletions.
45 changes: 45 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -3332,6 +3332,44 @@ public final ColumnVector stringContains(Scalar compString) {
return new ColumnVector(stringContains(getNativeView(), compString.getScalarHandle()));
}

private static long[] toPrimitive(Long[] longs) {
long[] ret = new long[longs.length];
for (int i = 0; i < longs.length; ++i) {
ret[i] = longs[i];
}
return ret;
}

/**
* @brief Searches for the given target strings within each string in the provided column
*
* Each column in the result table corresponds to the result for the target string at the same
* ordinal. i.e. 0th column is the BOOL8 column result for the 0th target string, 1th for 1th,
* etc.
*
* If the target is not found for a string, false is returned for that entry in the output column.
* If the target is an empty string, true is returned for all non-null entries in the output column.
*
* Any null input strings return corresponding null entries in the output columns.
*
* input = ["a", "b", "c"]
* targets = ["a", "c"]
* output is a table with two boolean columns:
* column 0: [true, false, false]
* column 1: [false, false, true]
*
* @param targets UTF-8 encoded strings to search for in each string in `input`
* @return BOOL8 columns
*/
public final ColumnVector[] stringContains(ColumnView targets) {
assert type.equals(DType.STRING) : "column type must be a String";
assert targets.getType().equals(DType.STRING) : "targets type must be a string";
assert targets.getNullCount() == 0 : "targets must not be null";
assert targets.getRowCount() > 0 : "targets must not be empty";
long[] resultPointers = stringContainsMulti(getNativeView(), targets.getNativeView());
return Arrays.stream(resultPointers).mapToObj(ColumnVector::new).toArray(ColumnVector[]::new);
}

/**
* Replaces values less than `lo` in `input` with `lo`,
* and values greater than `hi` with `hi`.
Expand Down Expand Up @@ -4437,6 +4475,13 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat
*/
private static native long stringContains(long cudfViewHandle, long compString) throws CudfException;

/**
* Native method for searching for the given target strings within each string in the provided column.
* @param cudfViewHandle native handle of the cudf::column_view being operated on.
* @param targets handle of the column containing the string being searched for.
*/
private static native long[] stringContainsMulti(long cudfViewHandle, long targets) throws CudfException;

/**
* Native method for extracting results from a regex program pattern. Returns a table handle.
*
Expand Down
22 changes: 22 additions & 0 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
#include <cudf/strings/convert/convert_urls.hpp>
#include <cudf/strings/extract.hpp>
#include <cudf/strings/find.hpp>
#include <cudf/strings/find_multiple.hpp>
#include <cudf/strings/findall.hpp>
#include <cudf/strings/padding.hpp>
#include <cudf/strings/regex/regex_program.hpp>
Expand Down Expand Up @@ -2827,4 +2828,25 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_toHex(JNIEnv* env, jclass
}
CATCH_STD(env, 0);
}

JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_stringContainsMulti(JNIEnv* env,
jobject j_object,
jlong j_view_handle,
jlong comp_strings)
{
JNI_NULL_CHECK(env, j_view_handle, "column is null", 0);
JNI_NULL_CHECK(env, comp_strings, "targets is null", 0);

try {
cudf::jni::auto_set_device(env);
auto* column_view = reinterpret_cast<cudf::column_view*>(j_view_handle);
auto* targets_view = reinterpret_cast<cudf::column_view*>(comp_strings);
auto const strings_column = cudf::strings_column_view(*column_view);
auto const targets_column = cudf::strings_column_view(*targets_view);
auto contains_results = cudf::strings::contains_multiple(strings_column, targets_column);
return cudf::jni::convert_table_for_return(env, std::move(contains_results));
}
CATCH_STD(env, 0);
}

} // extern "C"
24 changes: 24 additions & 0 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -3828,6 +3828,30 @@ void testStringOpsEmpty() {
}
}

@Test
void testStringContainsMulti() {
ColumnVector[] results = null;
try (ColumnVector haystack = ColumnVector.fromStrings("All the leaves are brown",
"And the sky is grey",
"I've been for a walk",
"On a winter's day",
null,
"");
ColumnVector targets = ColumnVector.fromStrings("the", "a");
ColumnVector expected0 = ColumnVector.fromBoxedBooleans(true, true, false, false, null, false);
ColumnVector expected1 = ColumnVector.fromBoxedBooleans(true, false, true, true, null, false)) {
results = haystack.stringContains(targets);
assertColumnsAreEqual(results[0], expected0);
assertColumnsAreEqual(results[1], expected1);
} finally {
if (results != null) {
for (ColumnVector c : results) {
c.close();
}
}
}
}

@Test
void testStringFindOperations() {
try (ColumnVector testStrings = ColumnVector.fromStrings("", null, "abCD", "1a\"\u0100B1", "a\"\u0100B1", "1a\"\u0100B",
Expand Down

0 comments on commit 3ae0ecd

Please sign in to comment.