From c78e86f26cf8fa813c693e8a872c0e89a4adf667 Mon Sep 17 00:00:00 2001 From: Ajit Mistry Date: Thu, 22 Aug 2024 10:31:48 +0000 Subject: [PATCH] add return tensor --- cpp/include/cuvs/neighbors/cagra.h | 1 + cpp/src/neighbors/cagra_c.cpp | 39 ++++++++++++++++++------------ cpp/test/neighbors/ann_cagra_c.cu | 17 ++++++++++++- 3 files changed, 41 insertions(+), 16 deletions(-) diff --git a/cpp/include/cuvs/neighbors/cagra.h b/cpp/include/cuvs/neighbors/cagra.h index 3596f9b8c..d76d08c06 100644 --- a/cpp/include/cuvs/neighbors/cagra.h +++ b/cpp/include/cuvs/neighbors/cagra.h @@ -385,6 +385,7 @@ cuvsError_t cuvsCagraBuild(cuvsResources_t res, cuvsError_t cuvsCagraExtend(cuvsResources_t res, cuvsCagraExtendParams_t params, DLManagedTensor* additional_dataset, + DLManagedTensor* return_dataset, cuvsCagraIndex_t index); /** diff --git a/cpp/src/neighbors/cagra_c.cpp b/cpp/src/neighbors/cagra_c.cpp index c278a75a3..8c94b4f05 100644 --- a/cpp/src/neighbors/cagra_c.cpp +++ b/cpp/src/neighbors/cagra_c.cpp @@ -87,23 +87,31 @@ template void _extend(cuvsResources_t res, cuvsCagraExtendParams params, cuvsCagraIndex index, - DLManagedTensor* additional_dataset_tensor) + DLManagedTensor* additional_dataset_tensor, + DLManagedTensor* return_tensor) { - auto dataset = additional_dataset_tensor->dl_tensor; - auto index_ptr = reinterpret_cast*>(index.addr); - auto res_ptr = reinterpret_cast(res); + auto dataset = additional_dataset_tensor->dl_tensor; + auto return_dl_tensor = return_tensor->dl_tensor; + auto index_ptr = reinterpret_cast*>(index.addr); + auto res_ptr = reinterpret_cast(res); auto extend_params = cuvs::neighbors::cagra::extend_params(); extend_params.max_chunk_size = params.max_chunk_size; - if (cuvs::core::is_dlpack_device_compatible(dataset)) { - using mdspan_type = raft::device_matrix_view; - auto mds = cuvs::core::from_dlpack(additional_dataset_tensor); - cuvs::neighbors::cagra::extend(*res_ptr, extend_params, mds, *index_ptr); - } else if (cuvs::core::is_dlpack_host_compatible(dataset)) { - using mdspan_type = raft::host_matrix_view; - auto mds = cuvs::core::from_dlpack(additional_dataset_tensor); - cuvs::neighbors::cagra::extend(*res_ptr, extend_params, mds, *index_ptr); + if (cuvs::core::is_dlpack_device_compatible(dataset) && + cuvs::core::is_dlpack_device_compatible(return_dl_tensor)) { + using mdspan_type = raft::device_matrix_view; + using mdspan_return_type = raft::device_matrix_view; + auto mds = cuvs::core::from_dlpack(additional_dataset_tensor); + auto return_mds = cuvs::core::from_dlpack(return_tensor); + cuvs::neighbors::cagra::extend(*res_ptr, extend_params, mds, *index_ptr, return_mds); + } else if (cuvs::core::is_dlpack_host_compatible(dataset) && + cuvs::core::is_dlpack_host_compatible(return_dl_tensor)) { + using mdspan_type = raft::host_matrix_view; + using mdspan_return_type = raft::device_matrix_view; + auto mds = cuvs::core::from_dlpack(additional_dataset_tensor); + auto return_mds = cuvs::core::from_dlpack(return_tensor); + cuvs::neighbors::cagra::extend(*res_ptr, extend_params, mds, *index_ptr, return_mds); } else { RAFT_FAIL("Unsupported dataset DLtensor dtype: %d and bits: %d", dataset.dtype.code, @@ -221,6 +229,7 @@ extern "C" cuvsError_t cuvsCagraBuild(cuvsResources_t res, extern "C" cuvsError_t cuvsCagraExtend(cuvsResources_t res, cuvsCagraExtendParams_t params, DLManagedTensor* additional_dataset_tensor, + DLManagedTensor* return_dataset_tensor, cuvsCagraIndex_t index_c_ptr) { return cuvs::core::translate_exceptions([=] { @@ -228,11 +237,11 @@ extern "C" cuvsError_t cuvsCagraExtend(cuvsResources_t res, auto index = *index_c_ptr; if ((dataset.dtype.code == kDLFloat) && (dataset.dtype.bits == 32)) { - _extend(res, *params, index, additional_dataset_tensor); + _extend(res, *params, index, additional_dataset_tensor, return_dataset_tensor); } else if (dataset.dtype.code == kDLInt && dataset.dtype.bits == 8) { - _extend(res, *params, index, additional_dataset_tensor); + _extend(res, *params, index, additional_dataset_tensor, return_dataset_tensor); } else if (dataset.dtype.code == kDLUInt && dataset.dtype.bits == 8) { - _extend(res, *params, index, additional_dataset_tensor); + _extend(res, *params, index, additional_dataset_tensor, return_dataset_tensor); } else { RAFT_FAIL("Unsupported dataset DLtensor dtype: %d and bits: %d", dataset.dtype.code, diff --git a/cpp/test/neighbors/ann_cagra_c.cu b/cpp/test/neighbors/ann_cagra_c.cu index 407757b93..2322ea185 100644 --- a/cpp/test/neighbors/ann_cagra_c.cu +++ b/cpp/test/neighbors/ann_cagra_c.cu @@ -100,6 +100,18 @@ TEST(CagraC, BuildSearch) additional_dataset_tensor.dl_tensor.shape = additional_dataset_shape; additional_dataset_tensor.dl_tensor.strides = nullptr; + rmm::device_uvector extend_return_d((additional_num_rows + main_num_rows) * 2, stream); + DLManagedTensor additional_dataset_return_tensor; + additional_dataset_return_tensor.dl_tensor.data = extend_return_d.data(); + additional_dataset_return_tensor.dl_tensor.device.device_type = kDLCUDA; + additional_dataset_return_tensor.dl_tensor.ndim = 2; + additional_dataset_return_tensor.dl_tensor.dtype.code = kDLFloat; + additional_dataset_return_tensor.dl_tensor.dtype.bits = 32; + additional_dataset_return_tensor.dl_tensor.dtype.lanes = 1; + int64_t additional_return_dataset_shape[2] = {additional_num_rows + main_num_rows, 2}; + additional_dataset_return_tensor.dl_tensor.shape = additional_return_dataset_shape; + additional_dataset_return_tensor.dl_tensor.strides = nullptr; + // create index cuvsCagraIndex_t index; cuvsCagraIndexCreate(&index); @@ -112,7 +124,10 @@ TEST(CagraC, BuildSearch) cuvsCagraExtendParams_t extend_params; cuvsCagraExtendParamsCreate(&extend_params); extend_params->max_chunk_size = 100; - cuvsCagraExtend(res, extend_params, &additional_dataset_tensor, index); + cuvsCagraExtend( + res, extend_params, &additional_dataset_tensor, &additional_dataset_return_tensor, index); + + extend_return_d.resize(main_num_rows * 2, stream); // create queries DLTensor rmm::device_uvector queries_d(4 * 2, stream);