From 11a740256258a5c3a07bf9a219c9b10a9572ff65 Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 12 Nov 2024 15:37:27 -0800 Subject: [PATCH] reword conditional --- cpp/src/neighbors/detail/nn_descent.cuh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/src/neighbors/detail/nn_descent.cuh b/cpp/src/neighbors/detail/nn_descent.cuh index 7c8d4c055..c62a52540 100644 --- a/cpp/src/neighbors/detail/nn_descent.cuh +++ b/cpp/src/neighbors/detail/nn_descent.cuh @@ -473,7 +473,8 @@ RAFT_KERNEL preprocess_data_kernel( if (threadIdx.x == 0) { l2_norm = 0; } __syncthreads(); - if (metric != cuvs::distance::DistanceType::InnerProduct) { + if (metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::CosineExpanded) { int lane_id = threadIdx.x % raft::warp_size(); for (int step = 0; step < raft::ceildiv(dim, raft::warp_size()); step++) { int idx = step * raft::warp_size() + lane_id;