From 7c4f58f5304e42c941f1a2c07bfe8665c73b75f3 Mon Sep 17 00:00:00 2001 From: "Yusheng.Ma" Date: Thu, 30 Nov 2023 10:16:12 +0000 Subject: [PATCH] flush raft results ids Signed-off-by: Yusheng.Ma --- .../raft/integration/raft_knowhere_index.cuh | 29 ++++++++++++++++--- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/src/common/raft/integration/raft_knowhere_index.cuh b/src/common/raft/integration/raft_knowhere_index.cuh index 505d1e171..822b124ae 100644 --- a/src/common/raft/integration/raft_knowhere_index.cuh +++ b/src/common/raft/integration/raft_knowhere_index.cuh @@ -457,9 +457,15 @@ struct raft_knowhere_index::impl { device_distances, config.refine_ratio, input_indexing_type{}, dataset_view, raft::neighbors::filtering::bitset_filter{ device_bitset->view()}); - thrust::replace(res.get_thrust_policy(), thrust::device_ptr(device_ids.data_handle()), - thrust::device_ptr(device_ids.data_handle() + output_size), - std::numeric_limits::max(), indexing_type{-1}); + // For raft_ivf_flat, setting invalid IDs to 0 here will only + // slightly affect the recall(ground truth contain 0th vector). In order to + // maintain consistency in Knowhere's results and better serve the merging of results in the upper layers of + // Milvus, it is temporarily replaced with -1 + if constexpr (index_kind == raft_proto::raft_index_kind::ivf_flat) { + thrust::replace(res.get_thrust_policy(), thrust::device_ptr(device_ids.data_handle()), + thrust::device_ptr(device_ids.data_handle() + output_size), + indexing_type{0}, indexing_type{-1}); + } } else { raft_index_type::search(res, *index_, search_params, raft::make_const_mdspan(device_data_storage.view()), @@ -471,7 +477,22 @@ struct raft_knowhere_index::impl { thrust::device_ptr(device_ids.data_handle() + output_size), std::numeric_limits::max(), indexing_type{-1}); } - raft::copy(res, host_ids, device_ids); + if constexpr (index_kind == raft_proto::raft_index_kind::cagra) { + if (device_bitset) { + auto tmp = raft::make_device_matrix(res, row_count, k); + raft::copy(res, tmp.view(), device_ids); + thrust::replace(res.get_thrust_policy(), thrust::device_ptr(tmp.data_handle()), + thrust::device_ptr(tmp.data_handle() + output_size), + knowhere_indexing_type{0x7fffffffu}, knowhere_indexing_type{-1}); + raft::copy(res, host_ids, tmp.view()); + + } else { + raft::copy(res, host_ids, device_ids); + } + + } else { + raft::copy(res, host_ids, device_ids); + } raft::copy(res, host_distances, device_distances); return std::make_tuple(ids.release(), distances.release()); }