Skip to content

Commit

Permalink
flush raft results ids
Browse files Browse the repository at this point in the history
Signed-off-by: Yusheng.Ma <[email protected]>
  • Loading branch information
Presburger committed Nov 30, 2023
1 parent 0cf256c commit 7c4f58f
Showing 1 changed file with 25 additions and 4 deletions.
29 changes: 25 additions & 4 deletions src/common/raft/integration/raft_knowhere_index.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -457,9 +457,15 @@ struct raft_knowhere_index<IndexKind>::impl {
device_distances, config.refine_ratio, input_indexing_type{}, dataset_view,
raft::neighbors::filtering::bitset_filter<knowhere_bitset_data_type, knowhere_bitset_indexing_type>{
device_bitset->view()});
thrust::replace(res.get_thrust_policy(), thrust::device_ptr<indexing_type>(device_ids.data_handle()),
thrust::device_ptr<indexing_type>(device_ids.data_handle() + output_size),
std::numeric_limits<indexing_type>::max(), indexing_type{-1});
// For raft_ivf_flat, setting invalid IDs to 0 here will only
// slightly affect the recall(ground truth contain 0th vector). In order to
// maintain consistency in Knowhere's results and better serve the merging of results in the upper layers of
// Milvus, it is temporarily replaced with -1
if constexpr (index_kind == raft_proto::raft_index_kind::ivf_flat) {
thrust::replace(res.get_thrust_policy(), thrust::device_ptr<indexing_type>(device_ids.data_handle()),
thrust::device_ptr<indexing_type>(device_ids.data_handle() + output_size),
indexing_type{0}, indexing_type{-1});
}

} else {
raft_index_type::search(res, *index_, search_params, raft::make_const_mdspan(device_data_storage.view()),
Expand All @@ -471,7 +477,22 @@ struct raft_knowhere_index<IndexKind>::impl {
thrust::device_ptr<indexing_type>(device_ids.data_handle() + output_size),
std::numeric_limits<indexing_type>::max(), indexing_type{-1});
}
raft::copy(res, host_ids, device_ids);
if constexpr (index_kind == raft_proto::raft_index_kind::cagra) {
if (device_bitset) {
auto tmp = raft::make_device_matrix<knowhere_indexing_type, input_indexing_type>(res, row_count, k);
raft::copy(res, tmp.view(), device_ids);
thrust::replace(res.get_thrust_policy(), thrust::device_ptr<knowhere_indexing_type>(tmp.data_handle()),
thrust::device_ptr<knowhere_indexing_type>(tmp.data_handle() + output_size),
knowhere_indexing_type{0x7fffffffu}, knowhere_indexing_type{-1});
raft::copy(res, host_ids, tmp.view());

} else {
raft::copy(res, host_ids, device_ids);
}

} else {
raft::copy(res, host_ids, device_ids);
}
raft::copy(res, host_distances, device_distances);
return std::make_tuple(ids.release(), distances.release());
}
Expand Down

0 comments on commit 7c4f58f

Please sign in to comment.