Skip to content

Commit

Permalink
add return tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
ajit283 committed Aug 22, 2024
1 parent ccb5d86 commit c78e86f
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 16 deletions.
1 change: 1 addition & 0 deletions cpp/include/cuvs/neighbors/cagra.h
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,7 @@ cuvsError_t cuvsCagraBuild(cuvsResources_t res,
cuvsError_t cuvsCagraExtend(cuvsResources_t res,
cuvsCagraExtendParams_t params,
DLManagedTensor* additional_dataset,
DLManagedTensor* return_dataset,
cuvsCagraIndex_t index);

/**
Expand Down
39 changes: 24 additions & 15 deletions cpp/src/neighbors/cagra_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,23 +87,31 @@ template <typename T>
void _extend(cuvsResources_t res,
cuvsCagraExtendParams params,
cuvsCagraIndex index,
DLManagedTensor* additional_dataset_tensor)
DLManagedTensor* additional_dataset_tensor,
DLManagedTensor* return_tensor)
{
auto dataset = additional_dataset_tensor->dl_tensor;
auto index_ptr = reinterpret_cast<cuvs::neighbors::cagra::index<T, uint32_t>*>(index.addr);
auto res_ptr = reinterpret_cast<raft::resources*>(res);
auto dataset = additional_dataset_tensor->dl_tensor;
auto return_dl_tensor = return_tensor->dl_tensor;
auto index_ptr = reinterpret_cast<cuvs::neighbors::cagra::index<T, uint32_t>*>(index.addr);
auto res_ptr = reinterpret_cast<raft::resources*>(res);

auto extend_params = cuvs::neighbors::cagra::extend_params();
extend_params.max_chunk_size = params.max_chunk_size;

if (cuvs::core::is_dlpack_device_compatible(dataset)) {
using mdspan_type = raft::device_matrix_view<T const, int64_t, raft::row_major>;
auto mds = cuvs::core::from_dlpack<mdspan_type>(additional_dataset_tensor);
cuvs::neighbors::cagra::extend(*res_ptr, extend_params, mds, *index_ptr);
} else if (cuvs::core::is_dlpack_host_compatible(dataset)) {
using mdspan_type = raft::host_matrix_view<T const, int64_t, raft::row_major>;
auto mds = cuvs::core::from_dlpack<mdspan_type>(additional_dataset_tensor);
cuvs::neighbors::cagra::extend(*res_ptr, extend_params, mds, *index_ptr);
if (cuvs::core::is_dlpack_device_compatible(dataset) &&
cuvs::core::is_dlpack_device_compatible(return_dl_tensor)) {
using mdspan_type = raft::device_matrix_view<T const, int64_t, raft::row_major>;
using mdspan_return_type = raft::device_matrix_view<T, int64_t, raft::row_major>;
auto mds = cuvs::core::from_dlpack<mdspan_type>(additional_dataset_tensor);
auto return_mds = cuvs::core::from_dlpack<mdspan_return_type>(return_tensor);
cuvs::neighbors::cagra::extend(*res_ptr, extend_params, mds, *index_ptr, return_mds);
} else if (cuvs::core::is_dlpack_host_compatible(dataset) &&
cuvs::core::is_dlpack_host_compatible(return_dl_tensor)) {
using mdspan_type = raft::host_matrix_view<T const, int64_t, raft::row_major>;
using mdspan_return_type = raft::device_matrix_view<T, int64_t, raft::row_major>;
auto mds = cuvs::core::from_dlpack<mdspan_type>(additional_dataset_tensor);
auto return_mds = cuvs::core::from_dlpack<mdspan_return_type>(return_tensor);
cuvs::neighbors::cagra::extend(*res_ptr, extend_params, mds, *index_ptr, return_mds);
} else {
RAFT_FAIL("Unsupported dataset DLtensor dtype: %d and bits: %d",
dataset.dtype.code,
Expand Down Expand Up @@ -221,18 +229,19 @@ extern "C" cuvsError_t cuvsCagraBuild(cuvsResources_t res,
extern "C" cuvsError_t cuvsCagraExtend(cuvsResources_t res,
cuvsCagraExtendParams_t params,
DLManagedTensor* additional_dataset_tensor,
DLManagedTensor* return_dataset_tensor,
cuvsCagraIndex_t index_c_ptr)
{
return cuvs::core::translate_exceptions([=] {
auto dataset = additional_dataset_tensor->dl_tensor;
auto index = *index_c_ptr;

if ((dataset.dtype.code == kDLFloat) && (dataset.dtype.bits == 32)) {
_extend<float>(res, *params, index, additional_dataset_tensor);
_extend<float>(res, *params, index, additional_dataset_tensor, return_dataset_tensor);
} else if (dataset.dtype.code == kDLInt && dataset.dtype.bits == 8) {
_extend<int8_t>(res, *params, index, additional_dataset_tensor);
_extend<int8_t>(res, *params, index, additional_dataset_tensor, return_dataset_tensor);
} else if (dataset.dtype.code == kDLUInt && dataset.dtype.bits == 8) {
_extend<uint8_t>(res, *params, index, additional_dataset_tensor);
_extend<uint8_t>(res, *params, index, additional_dataset_tensor, return_dataset_tensor);
} else {
RAFT_FAIL("Unsupported dataset DLtensor dtype: %d and bits: %d",
dataset.dtype.code,
Expand Down
17 changes: 16 additions & 1 deletion cpp/test/neighbors/ann_cagra_c.cu
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,18 @@ TEST(CagraC, BuildSearch)
additional_dataset_tensor.dl_tensor.shape = additional_dataset_shape;
additional_dataset_tensor.dl_tensor.strides = nullptr;

rmm::device_uvector<float> extend_return_d((additional_num_rows + main_num_rows) * 2, stream);
DLManagedTensor additional_dataset_return_tensor;
additional_dataset_return_tensor.dl_tensor.data = extend_return_d.data();
additional_dataset_return_tensor.dl_tensor.device.device_type = kDLCUDA;
additional_dataset_return_tensor.dl_tensor.ndim = 2;
additional_dataset_return_tensor.dl_tensor.dtype.code = kDLFloat;
additional_dataset_return_tensor.dl_tensor.dtype.bits = 32;
additional_dataset_return_tensor.dl_tensor.dtype.lanes = 1;
int64_t additional_return_dataset_shape[2] = {additional_num_rows + main_num_rows, 2};
additional_dataset_return_tensor.dl_tensor.shape = additional_return_dataset_shape;
additional_dataset_return_tensor.dl_tensor.strides = nullptr;

// create index
cuvsCagraIndex_t index;
cuvsCagraIndexCreate(&index);
Expand All @@ -112,7 +124,10 @@ TEST(CagraC, BuildSearch)
cuvsCagraExtendParams_t extend_params;
cuvsCagraExtendParamsCreate(&extend_params);
extend_params->max_chunk_size = 100;
cuvsCagraExtend(res, extend_params, &additional_dataset_tensor, index);
cuvsCagraExtend(
res, extend_params, &additional_dataset_tensor, &additional_dataset_return_tensor, index);

extend_return_d.resize(main_num_rows * 2, stream);

// create queries DLTensor
rmm::device_uvector<float> queries_d(4 * 2, stream);
Expand Down

0 comments on commit c78e86f

Please sign in to comment.