Skip to content

Commit

Permalink
ensure spark returns are consistent with cuvs when handling less than…
Browse files Browse the repository at this point in the history
… k items probed

Awaiting future updates that consolidate the behaviors of ivfflat, ivfpq, and refine.
  • Loading branch information
lijinf2 committed Nov 6, 2024
1 parent f5aed5b commit 8f290f2
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions python/tests/test_approximate_nearest_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,13 +275,13 @@ def compare_with_cuml_or_cuvs_sg(
avg_recall = self.cal_avg_recall(given_indices)
assert (avg_recall > avg_recall_cumlann) or abs(
avg_recall - avg_recall_cumlann
) < tolerance
) <= tolerance

avg_dist_gap_cumlann = self.cal_avg_dist_gap(cuvssg_distances)
avg_dist_gap = self.cal_avg_dist_gap(given_distances)
assert (avg_dist_gap <= avg_dist_gap_cumlann) or abs(
avg_dist_gap - avg_dist_gap_cumlann
) < tolerance
) <= tolerance

def get_cuml_sg_results(
self,
Expand Down Expand Up @@ -966,6 +966,7 @@ def test_return_fewer_k(
This tests the corner case where fewer than k neighbors are found because nprobe is too small.
More details can be found at the docstring of class ApproximateNearestNeighbors.
"""
assert algorithm in {"ivfpq", "ivfflat"}
metric = "euclidean"
gpu_number = 1
k = 4
Expand Down Expand Up @@ -1020,6 +1021,21 @@ def test_return_fewer_k(
int64_max = np.iinfo("int64").max
float_inf = float("inf")

# ensure consistency with cuvs for ivfflat always, and for ivfpq only when the installed cuvs version is newer than 24.10
import cuvs
from packaging import version

if algorithm == "ivfflat" or version.parse(cuvs.__version__) > version.parse(
"24.10.00"
):
ann_evaluator = ANNEvaluator(X, k, metric)
spark_indices = np.array([row["indices"] for row in knn_df_collect])
spark_distances = np.array([row["distances"] for row in knn_df_collect])
ann_evaluator.compare_with_cuml_or_cuvs_sg(
algorithm, algo_params, spark_indices, spark_distances, tolerance=0.0
)

# check result details
indices_none_probed = [int64_max, int64_max, int64_max, int64_max]
distances_none_probed = [float_inf, float_inf, float_inf, float_inf]

Expand Down

0 comments on commit 8f290f2

Please sign in to comment.