From 787653f550626227db081e455ca4aebef9f0f8a1 Mon Sep 17 00:00:00 2001
From: Jinfeng
Date: Fri, 18 Oct 2024 22:04:06 +0000
Subject: [PATCH] add refine to the knn.py for ivfpq

in progress for checkout

add debug info

get ivf_pq cosine passed by increasing dataset std to make it separable

get ivf_pq working after using refine

remove unnecessary test for refine

get refine to work for fewer than k items probed

replace df.withColumn with df.select to fix slowdown for df that was
initialized with wide pd.DataFrame
---
 python/src/spark_rapids_ml/knn.py             | 26 +++++++++++++-
 .../test_approximate_nearest_neighbors.py     | 34 +++++++++++++++----
 python/tests/utils.py                         | 12 ++++++-
 3 files changed, 63 insertions(+), 9 deletions(-)

diff --git a/python/src/spark_rapids_ml/knn.py b/python/src/spark_rapids_ml/knn.py
index 08500c21..3dba8886 100644
--- a/python/src/spark_rapids_ml/knn.py
+++ b/python/src/spark_rapids_ml/knn.py
@@ -1528,7 +1528,7 @@ def _transform_internal(
                 nn_object
             ):  # derived class (e.g. benchmark.bench_nearest_neighbors.CPUNearestNeighborsModel)
                 distances, indices = nn_object.kneighbors(bcast_qfeatures.value)
-            else:  # cuvs ivf_flat cagra
+            else:  # cuvs ivf_flat cagra ivf_pq
                 gpu_qfeatures = cp.array(
                     bcast_qfeatures.value, order="C", dtype="float32"
                 )
@@ -1543,9 +1543,33 @@ def _transform_internal(
                     gpu_qfeatures,
                     cuml_alg_params["n_neighbors"],
                 )
+
+                if cuml_alg_params["algorithm"] in {"ivf_pq", "ivfpq"}:
+                    from cuvs.neighbors import refine
+
+                    distances, indices = refine(
+                        dataset=item,
+                        queries=gpu_qfeatures,
+                        candidates=indices,
+                        k=cuml_alg_params["n_neighbors"],
+                        metric=cuml_alg_params["metric"],
+                    )
+
                 distances = cp.asarray(distances)
                 indices = cp.asarray(indices)
 
+            # restore inf distances in case the refine API reset them to 0.
+            if cuml_alg_params["algorithm"] in {"ivf_pq", "ivfpq"}:
+                distances[indices >= len(item)] = float("inf")
+
+                # mask duplicates of the top-1 nn that got filled into the remaining indices
+                top1_ind = indices[:, 0]
+                rest_indices = indices[:, 1:]
+                rest_distances = distances[:, 1:]
+                rest_distances[rest_indices == top1_ind[:, cp.newaxis]] = float(
+                    "inf"
+                )
+
             if isinstance(distances, cp.ndarray):
                 distances = distances.get()
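Note on the refine step added above: IVF-PQ ranks candidates by quantized (approximate) distances, so the patch re-ranks them against the raw dataset with cuvs.neighbors.refine to recover exact distances. A minimal standalone sketch of this search-then-refine pattern, assuming the cuvs Python API imported by this patch; shapes and index/search parameters are illustrative, not the values spark-rapids-ml uses:

    import cupy as cp
    from cuvs.neighbors import ivf_pq, refine

    n_items, dim, k = 10000, 32, 10
    dataset = cp.random.random((n_items, dim), dtype=cp.float32)
    queries = cp.random.random((100, dim), dtype=cp.float32)

    # build a product-quantized IVF index and fetch approximate candidates
    index = ivf_pq.build(ivf_pq.IndexParams(n_lists=100, metric="sqeuclidean"), dataset)
    _, candidates = ivf_pq.search(ivf_pq.SearchParams(n_probes=10), index, queries, k)

    # re-rank the candidates against the raw vectors to get exact distances,
    # mirroring the refine(...) call the patch adds to _transform_internal
    distances, indices = refine(
        dataset=dataset,
        queries=queries,
        candidates=candidates,
        k=k,
        metric="sqeuclidean",
    )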
diff --git a/python/tests/test_approximate_nearest_neighbors.py b/python/tests/test_approximate_nearest_neighbors.py
index 385cc500..af2210a7 100644
--- a/python/tests/test_approximate_nearest_neighbors.py
+++ b/python/tests/test_approximate_nearest_neighbors.py
@@ -265,7 +265,7 @@ def compare_with_cuml_or_cuvs_sg(
         avg_dist_gap_cumlann = self.cal_avg_dist_gap(cuvssg_distances)
         avg_dist_gap = self.cal_avg_dist_gap(given_distances)
-        assert (avg_dist_gap < avg_dist_gap_cumlann) or abs(
+        assert (avg_dist_gap <= avg_dist_gap_cumlann) or abs(
             avg_dist_gap - avg_dist_gap_cumlann
         ) < tolerance
@@ -335,6 +335,13 @@ def get_cuvs_sg_results(
             cuvs_algo.SearchParams(**search_params), index, gpu_X, self.n_neighbors
         )
 
+        if algorithm in {"ivf_pq", "ivfpq"}:
+            from cuvs.neighbors import refine
+
+            sg_distances, sg_indices = refine(
+                gpu_X, gpu_X, sg_indices, self.n_neighbors, metric=self.metric
+            )
+
         # convert results to cp array then to np array
         sg_distances = cp.array(sg_distances).get()
         sg_indices = cp.array(sg_indices).get()
@@ -351,6 +358,7 @@ def ann_algorithm_test_func(
     distances_are_exact: bool = True,
     tolerance: float = 1e-4,
     n_neighbors: int = 50,
+    cluster_std: float = 1.0,
 ) -> None:
 
     algorithm = combo[0]
@@ -382,6 +390,7 @@ def ann_algorithm_test_func(
         n_features=data_shape[1],
         centers=n_clusters,
         random_state=0,
+        cluster_std=cluster_std,
     )  # make_blobs creates a random dataset of isotropic gaussian blobs.
 
     # set average norm sq to be 1 to allow comparisons with default error thresholds
@@ -435,7 +444,7 @@ def ann_algorithm_test_func(
 
     # test kneighbors: compare top-1 nn indices(self) and distances(self)
-    if metric != "inner_product" and distances_are_exact:
+    if metric != "inner_product":
         self_index = [knn[0] for knn in indices]
         assert np.all(self_index == y)
@@ -660,21 +669,28 @@ def test_ivfpq(
     (2) ivfpq has become unstable in 24.10. It does not get passed with algoParam {"nlist" : 10, "nprobe" : 2, "M": 2, "n_bits": 4} in ci where test_ivfflat is run beforehand.
     avg_recall shows large variance, depending on the quantization accuracy. This can be fixed by increasing nlist, nprobe, M, and n_bits.
     Note ivf_pq is non-deterministic, and it seems due to kmeans initialization leveraging runtime values of GPU memory.
-    (3) If M is is too small (e.g. 2), the returned distances can be very different from the ground distances.
-    Spark rapids ml may give lower recall than cuvs sg because it aggregates local topk candidates by the returned distances.
+    (3) In ivfpq, when the dataset itself is used as queries, it is sometimes observed that the top-1 index may not be self, and the top-1 distance may not be zero.
+    This is because ivfpq internally uses approximated distances, i.e. the distance from the query vector to the center of the quantized item.
     """
 
     combo = (algorithm, feature_type, max_records_per_batch, algo_params, metric)
-    expected_avg_recall = 0.4 if metric != "cosine" else 0.1
-    distances_are_exact = False
+    expected_avg_recall = 0.4
+    distances_are_exact = True
+    expected_avg_dist_gap = 0.05
     tolerance = 0.05  # tolerance increased to be more stable due to quantization and randomness in ivfpq, especially when expected_recall is low.
 
+    cluster_std = (
+        1.0 if metric != "cosine" else 10.0
+    )  # Increase cluster_std for cosine to make the dataset more spread out while staying separable.
+
     ann_algorithm_test_func(
         combo=combo,
         data_shape=data_shape,
         data_type=data_type,
         expected_avg_recall=expected_avg_recall,
+        expected_avg_dist_gap=expected_avg_dist_gap,
         distances_are_exact=distances_are_exact,
         tolerance=tolerance,
+        cluster_std=cluster_std,
     )
@@ -823,12 +839,16 @@ def test_ivfflat_wide_matrix(
     data_shape: Tuple[int, int],
     data_type: np.dtype,
 ) -> None:
+    """
+    It seems adding a column with df.withColumn can be very slow if df already has many columns (e.g. 3000).
+    One strategy is to avoid df.withColumn on a wide df and use df.select instead.
+    """
    import time
 
     start = time.time()
     ann_algorithm_test_func(combo=combo, data_shape=data_shape, data_type=data_type)
     duration_sec = time.time() - start
-    assert duration_sec < 10 * 60
+    assert duration_sec < 3 * 60
 
 
 @pytest.mark.parametrize(
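Note on the test setup: ann_algorithm_test_func builds the dataset with make_blobs and, per the comment above, rescales it so the average squared norm is 1, which keeps the default distance-gap thresholds comparable across shapes; test_ivfpq then raises cluster_std to 10.0 for cosine so the blobs stay separable. A minimal sketch of that normalization, with illustrative sizes and std (the actual logic lives in ann_algorithm_test_func):

    import numpy as np
    from sklearn.datasets import make_blobs

    # isotropic gaussian blobs with a configurable spread
    X, _ = make_blobs(
        n_samples=10000, n_features=50, centers=10, random_state=0, cluster_std=10.0
    )

    # rescale so the mean squared row norm is 1 before computing
    # recall and average distance gaps
    X = X / np.sqrt((X**2).sum(axis=1).mean())
    assert np.isclose((X**2).sum(axis=1).mean(), 1.0)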
diff --git a/python/tests/utils.py b/python/tests/utils.py
index 3b9ffc28..25ec41de 100644
--- a/python/tests/utils.py
+++ b/python/tests/utils.py
@@ -117,7 +117,14 @@ def create_pyspark_dataframe(
         df = spark.createDataFrame(data.tolist(), ",".join(schema))
 
     if feature_type == feature_types.array:
-        df = df.withColumn("features", array(*feature_cols)).drop(*feature_cols)
+        # avoid calling df.withColumn here because a runtime slowdown is observed when df has many columns (e.g. 3000).
+        from pyspark.sql.functions import col
+
+        selected_col = [array(*feature_cols).alias("features")]
+        if label_col:
+            selected_col.append(col(label_col).alias(label_col))
+        df = df.select(selected_col)
+
         feature_cols = "features"
     elif feature_type == feature_types.vector:
         df = (
@@ -128,6 +135,9 @@ def create_pyspark_dataframe(
             .drop(*feature_cols)
         )
         feature_cols = "features"
+    else:
+        # When df has many columns (e.g. 3000), the select here avoids the runtime slowdown observed when calling df.withColumn.
+        df = df.select("*")
 
     return df, feature_cols, label_col
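Note on the df.select change: building the features column with a single select issues one projection, whereas df.withColumn on a frame that already has thousands of columns was observed to be very slow (see the docstrings above). A minimal sketch contrasting the two, assuming a local SparkSession; column and row counts are illustrative:

    import numpy as np
    import pandas as pd
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import array

    spark = SparkSession.builder.getOrCreate()
    feature_cols = [f"c{i}" for i in range(3000)]
    pdf = pd.DataFrame(np.random.rand(100, len(feature_cols)), columns=feature_cols)
    df = spark.createDataFrame(pdf)

    # observed to be slow on wide frames (the pattern this patch removes):
    # df = df.withColumn("features", array(*feature_cols)).drop(*feature_cols)

    # the pattern this patch adopts: one select that builds the array column
    df = df.select(array(*feature_cols).alias("features"))
    df.show(1, truncate=False)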