Skip to content

Commit

Permalink
Updating APIs to account for const types
Browse files Browse the repository at this point in the history
  • Loading branch information
cjnolet committed Apr 23, 2024
1 parent 59fe3fd commit 28b115c
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 36 deletions.
8 changes: 4 additions & 4 deletions cpp/src/distance/distance-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -760,8 +760,8 @@ instantiate_raft_distance_pairwise_distance(double, int);
#define instantiate_raft_distance_distance(DistT, DataT, AccT, OutT, layout, IdxT) \
extern template void cuvs::distance::distance<DistT, DataT, AccT, OutT, layout, IdxT>( \
raft::resources const& handle, \
raft::device_matrix_view<DataT, IdxT, layout> const x, \
raft::device_matrix_view<DataT, IdxT, layout> const y, \
raft::device_matrix_view<const DataT, IdxT, layout> const x, \
raft::device_matrix_view<const DataT, IdxT, layout> const y, \
raft::device_matrix_view<OutT, IdxT, layout> dist, \
DataT metric_arg)

Expand Down Expand Up @@ -1052,8 +1052,8 @@ instantiate_raft_distance_distance(cuvs::distance::DistanceType::RusselRaoExpand
#define instantiate_raft_distance_pairwise_distance(DataT, layout, IdxT) \
extern template void cuvs::distance::pairwise_distance( \
raft::resources const& handle, \
raft::device_matrix_view<DataT, IdxT, layout> const x, \
raft::device_matrix_view<DataT, IdxT, layout> const y, \
raft::device_matrix_view<const DataT, IdxT, layout> const x, \
raft::device_matrix_view<const DataT, IdxT, layout> const y, \
raft::device_matrix_view<DataT, IdxT, layout> dist, \
cuvs::distance::DistanceType metric, \
DataT metric_arg)
Expand Down
8 changes: 4 additions & 4 deletions cpp/src/distance/distance-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -394,8 +394,8 @@ template <cuvs::distance::DistanceType DistT,
typename layout = raft::layout_c_contiguous,
typename IdxT = int>
void distance(raft::resources const& handle,
raft::device_matrix_view<DataT, IdxT, layout> const x,
raft::device_matrix_view<DataT, IdxT, layout> const y,
raft::device_matrix_view<const DataT, IdxT, layout> const x,
raft::device_matrix_view<const DataT, IdxT, layout> const y,
raft::device_matrix_view<OutT, IdxT, layout> dist,
DataT metric_arg = 2.0f)
{
Expand Down Expand Up @@ -437,8 +437,8 @@ void distance(raft::resources const& handle,
*/
template <typename Type, typename layout = raft::layout_c_contiguous, typename IdxT = int>
void pairwise_distance(raft::resources const& handle,
raft::device_matrix_view<Type, IdxT, layout> const x,
raft::device_matrix_view<Type, IdxT, layout> const y,
raft::device_matrix_view<const Type, IdxT, layout> const x,
raft::device_matrix_view<const Type, IdxT, layout> const y,
raft::device_matrix_view<Type, IdxT, layout> dist,
cuvs::distance::DistanceType metric,
Type metric_arg = 2.0f)
Expand Down
8 changes: 4 additions & 4 deletions cpp/src/distance/distance.cu
Original file line number Diff line number Diff line change
Expand Up @@ -628,8 +628,8 @@ instantiate_raft_distance_pairwise_distance(double, int);
#define instantiate_raft_distance_distance(DistT, DataT, AccT, OutT, layout, IdxT) \
template void cuvs::distance::distance<DistT, DataT, AccT, OutT, layout, IdxT>( \
raft::resources const& handle, \
raft::device_matrix_view<DataT, IdxT, layout> const x, \
raft::device_matrix_view<DataT, IdxT, layout> const y, \
raft::device_matrix_view<const DataT, IdxT, layout> const x, \
raft::device_matrix_view<const DataT, IdxT, layout> const y, \
raft::device_matrix_view<OutT, IdxT, layout> dist, \
DataT metric_arg)

Expand Down Expand Up @@ -920,8 +920,8 @@ instantiate_raft_distance_distance(cuvs::distance::DistanceType::RusselRaoExpand
#define instantiate_raft_distance_pairwise_distance(DataT, layout, IdxT) \
template void cuvs::distance::pairwise_distance( \
raft::resources const& handle, \
raft::device_matrix_view<DataT, IdxT, layout> const x, \
raft::device_matrix_view<DataT, IdxT, layout> const y, \
raft::device_matrix_view<const DataT, IdxT, layout> const x, \
raft::device_matrix_view<const DataT, IdxT, layout> const y, \
raft::device_matrix_view<DataT, IdxT, layout> dist, \
cuvs::distance::DistanceType metric, \
DataT metric_arg)
Expand Down
52 changes: 28 additions & 24 deletions cpp/src/distance/pairwise_distance.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,42 +25,46 @@ namespace cuvs::distance {
* @defgroup pairwise_distance_runtime Pairwise Distances Runtime API
* @{
*/
void pairwise_distance(raft::resources const& handle,
raft::device_matrix_view<float, int, raft::layout_c_contiguous> const x,
raft::device_matrix_view<float, int, raft::layout_c_contiguous> const y,
raft::device_matrix_view<float, int, raft::layout_c_contiguous> dist,
cuvs::distance::DistanceType metric,
float metric_arg)
void pairwise_distance(
raft::resources const& handle,
raft::device_matrix_view<const float, int, raft::layout_c_contiguous> const x,
raft::device_matrix_view<const float, int, raft::layout_c_contiguous> const y,
raft::device_matrix_view<float, int, raft::layout_c_contiguous> dist,
cuvs::distance::DistanceType metric,
float metric_arg)
{
pairwise_distance<float, raft::layout_c_contiguous, int>(handle, x, y, dist, metric, metric_arg);
}

void pairwise_distance(raft::resources const& handle,
raft::device_matrix_view<double, int, raft::layout_c_contiguous> const x,
raft::device_matrix_view<double, int, raft::layout_c_contiguous> const y,
raft::device_matrix_view<double, int, raft::layout_c_contiguous> dist,
cuvs::distance::DistanceType metric,
double metric_arg)
void pairwise_distance(
raft::resources const& handle,
raft::device_matrix_view<const double, int, raft::layout_c_contiguous> const x,
raft::device_matrix_view<const double, int, raft::layout_c_contiguous> const y,
raft::device_matrix_view<double, int, raft::layout_c_contiguous> dist,
cuvs::distance::DistanceType metric,
double metric_arg)
{
pairwise_distance<double, raft::layout_c_contiguous, int>(handle, x, y, dist, metric, metric_arg);
}

void pairwise_distance(raft::resources const& handle,
raft::device_matrix_view<float, int, raft::layout_f_contiguous> const x,
raft::device_matrix_view<float, int, raft::layout_f_contiguous> const y,
raft::device_matrix_view<float, int, raft::layout_f_contiguous> dist,
cuvs::distance::DistanceType metric,
float metric_arg)
void pairwise_distance(
raft::resources const& handle,
raft::device_matrix_view<const float, int, raft::layout_f_contiguous> const x,
raft::device_matrix_view<const float, int, raft::layout_f_contiguous> const y,
raft::device_matrix_view<float, int, raft::layout_f_contiguous> dist,
cuvs::distance::DistanceType metric,
float metric_arg)
{
pairwise_distance<float, raft::layout_f_contiguous, int>(handle, x, y, dist, metric, metric_arg);
}

void pairwise_distance(raft::resources const& handle,
raft::device_matrix_view<double, int, raft::layout_f_contiguous> const x,
raft::device_matrix_view<double, int, raft::layout_f_contiguous> const y,
raft::device_matrix_view<double, int, raft::layout_f_contiguous> dist,
cuvs::distance::DistanceType metric,
double metric_arg)
void pairwise_distance(
raft::resources const& handle,
raft::device_matrix_view<const double, int, raft::layout_f_contiguous> const x,
raft::device_matrix_view<const double, int, raft::layout_f_contiguous> const y,
raft::device_matrix_view<double, int, raft::layout_f_contiguous> dist,
cuvs::distance::DistanceType metric,
double metric_arg)
{
pairwise_distance<double, raft::layout_f_contiguous, int>(handle, x, y, dist, metric, metric_arg);
}
Expand Down

0 comments on commit 28b115c

Please sign in to comment.