diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 8a8d5fb1c..8556e2941 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -216,7 +216,6 @@ add_library( src/distance/detail/pairwise_matrix/dispatch_russel_rao_double_double_double_int.cu src/distance/detail/pairwise_matrix/dispatch_russel_rao_float_float_float_int.cu src/distance/distance.cu - src/distance/fused_nn.cu src/distance/pairwise_distance.cu src/neighbors/brute_force_index.cu src/neighbors/brute_force.cu diff --git a/cpp/include/cuvs/distance/one_nn.hpp b/cpp/include/cuvs/distance/one_nn.hpp deleted file mode 100644 index d2bea2672..000000000 --- a/cpp/include/cuvs/distance/one_nn.hpp +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace cuvs::distance { - -/** - * @defgroup fused_distance_nn_min_arg_runtime Fused Distance 1NN Runtime API - * @{ - */ - -/** - * @brief Wrapper around fusedDistanceNN with minimum reduction operators. - * - * fusedDistanceNN cannot be compiled in the distance library due to the lambda - * operators, so this wrapper covers the most common case (minimum). - * - * @param[in] handle raft handle - * @param[out] min will contain the reduced output (Length = `m`) - * (on device) - * @param[in] x first matrix. Row major. Dim = `m x k`. - * (on device). - * @param[in] y second matrix. Row major. Dim = `n x k`. - * (on device). - * @param[in] m gemm m - * @param[in] n gemm n - * @param[in] k gemm k - * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt - * @param[in] metric Distance metric to be used (supports L2, cosine) - * @param[in] isRowMajor whether the input/output is row or column major. - * @param[in] metric_arg power argument for distances like Minkowski (not supported for now) - */ -void one_nn_argmin(raft::resources const& handle, - int* min, - const float* x, - const float* y, - int m, - int n, - int k, - bool sqrt, - cuvs::distance::DistanceType metric, - bool isRowMajor, - float metric_arg); - -/** @} */ // end group fused_distance_nn_min_arg_runtime - -} // end namespace cuvs::distance diff --git a/cpp/src/distance/fused_nn-ext.cuh b/cpp/src/distance/fused_nn-ext.cuh deleted file mode 100644 index 049f96448..000000000 --- a/cpp/src/distance/fused_nn-ext.cuh +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "fused_nn_helpers.cuh" // include initialize and reduce operations -#include // raft::KeyValuePair -#include // raft::resources -#include // RAFT_EXPLICIT - -#include // int64_t - -#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY - -namespace raft { -namespace distance { - -template -void fusedDistanceNNMinReduce(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - void* workspace, - bool sqrt, - bool initOutBuffer, - bool isRowMajor, - cuvs::distance::DistanceType metric, - float metric_arg, - cudaStream_t stream) RAFT_EXPLICIT; - -} // namespace distance -} // namespace raft - -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY - -#define instantiate_raft_distance_fusedDistanceNNMinReduce(DataT, OutT, IdxT) \ - extern template void cuvs::distance::fusedDistanceNNMinReduce( \ - OutT * min, \ - const DataT* x, \ - const DataT* y, \ - const DataT* xn, \ - const DataT* yn, \ - IdxT m, \ - IdxT n, \ - IdxT k, \ - void* workspace, \ - bool sqrt, \ - bool initOutBuffer, \ - bool isRowMajor, \ - cuvs::distance::DistanceType metric, \ - float metric_arg, \ - cudaStream_t stream) - -instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int); -instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int64_t); - -// We can't have comma's in the macro expansion, so we use the COMMA macro: -#define COMMA , - -instantiate_raft_distance_fusedDistanceNNMinReduce(float, raft::KeyValuePair, int); -instantiate_raft_distance_fusedDistanceNNMinReduce(float, - raft::KeyValuePair, - int64_t); - -#undef COMMA - -#undef instantiate_raft_distance_fusedDistanceNNMinReduce diff --git a/cpp/src/distance/fused_nn-inl.cuh b/cpp/src/distance/fused_nn-inl.cuh deleted file mode 100644 index 6cad0bd30..000000000 --- a/cpp/src/distance/fused_nn-inl.cuh +++ /dev/null @@ -1,330 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef __FUSED_DISTANCE_NN_H -#define __FUSED_DISTANCE_NN_H - -#pragma once - -#include "detail/fused_distance_nn.cuh" -#include "fused_nn_helpers.cuh" -#include -#include -#include - -#include - -#include - -#include -#include - -namespace cuvs { -namespace distance { - -/** - * \ingroup fused_l2_nn - * @{ - */ -/** - * @brief Fused L2 distance and 1-nearest-neighbor computation in a single call. - * - * The benefits of such a call are 2-fold: 1) eliminate the need for an - * intermediate buffer to store the output of gemm 2) reduce the memory read - * traffic on this intermediate buffer, otherwise needed during the reduction - * phase for 1-NN. - * - * @tparam DataT data type - * @tparam OutT output type to either store 1-NN indices and their minimum - * distances or store only the min distances. Accordingly, one - * has to pass an appropriate `ReduceOpT` - * @tparam IdxT indexing arithmetic type - * @tparam ReduceOpT A struct to perform the final needed reduction operation - * and also to initialize the output array elements with the - * appropriate initial value needed for reduction. - * @tparam KVPReduceOpT A struct providing functions for key-value pair comparison. - * - * @param[out] min will contain the reduced output (Length = `m`) - * (on device) - * @param[in] x first matrix. Row major. Dim = `m x k`. - * (on device). - * @param[in] y second matrix. Row major. Dim = `n x k`. - * (on device). - * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). - * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) - * @param[in] m gemm m - * @param[in] n gemm n - * @param[in] k gemm k - * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) - * @param[in] redOp reduction operator in the epilogue - * @param[in] pairRedOp reduction operation on key value pairs - * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt - * @param[in] initOutBuffer whether to initialize the output buffer before the - * main kernel launch - * @param[in] isRowMajor whether the input/output is row or column major. - * @param[in] metric Distance metric to be used (supports L2, cosine) - * @param[in] metric_arg power argument for distances like Minkowski (not supported for now) - * @param[in] stream cuda stream - */ -template -void fusedDistanceNN(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - void* workspace, - ReduceOpT redOp, - KVPReduceOpT pairRedOp, - bool sqrt, - bool initOutBuffer, - bool isRowMajor, - cuvs::distance::DistanceType metric, - float metric_arg, - cudaStream_t stream) -{ - ASSERT(isRowMajor, "fusedDistanceNN only supports row major inputs"); - // When k is smaller than 32, the Policy4x4 results in redundant calculations - // as it uses tiles that have k=32. Therefore, use a "skinny" policy instead - // that uses tiles with a smaller value of k. - bool is_skinny = k < 32; - - size_t bytes = sizeof(DataT) * k; - auto px = reinterpret_cast(x); - auto py = reinterpret_cast(y); - if (16 % sizeof(DataT) == 0 && bytes % 16 == 0 && px % 16 == 0 && py % 16 == 0) { - if (is_skinny) { - detail::fusedDistanceNNImpl< - DataT, - OutT, - IdxT, - typename raft::linalg::Policy4x4Skinny::Policy, - ReduceOpT>(min, - x, - y, - xn, - yn, - m, - n, - k, - (int*)workspace, - redOp, - pairRedOp, - sqrt, - initOutBuffer, - isRowMajor, - metric, - metric_arg, - stream); - } else { - detail::fusedDistanceNNImpl< - DataT, - OutT, - IdxT, - typename raft::linalg::Policy4x4::Policy, - ReduceOpT>(min, - x, - y, - xn, - yn, - m, - n, - k, - (int*)workspace, - redOp, - pairRedOp, - sqrt, - initOutBuffer, - isRowMajor, - metric, - metric_arg, - stream); - } - } else if (8 % sizeof(DataT) == 0 && bytes % 8 == 0 && px % 8 == 0 && py % 8 == 0) { - if (is_skinny) { - detail::fusedDistanceNNImpl< - DataT, - OutT, - IdxT, - typename raft::linalg::Policy4x4Skinny::Policy, - ReduceOpT>(min, - x, - y, - xn, - yn, - m, - n, - k, - (int*)workspace, - redOp, - pairRedOp, - sqrt, - initOutBuffer, - isRowMajor, - metric, - metric_arg, - stream); - } else { - detail::fusedDistanceNNImpl< - DataT, - OutT, - IdxT, - typename raft::linalg::Policy4x4::Policy, - ReduceOpT>(min, - x, - y, - xn, - yn, - m, - n, - k, - (int*)workspace, - redOp, - pairRedOp, - sqrt, - initOutBuffer, - isRowMajor, - metric, - metric_arg, - stream); - } - } else { - if (is_skinny) { - detail::fusedDistanceNNImpl::Policy, - ReduceOpT>(min, - x, - y, - xn, - yn, - m, - n, - k, - (int*)workspace, - redOp, - pairRedOp, - sqrt, - initOutBuffer, - isRowMajor, - metric, - metric_arg, - stream); - } else { - detail::fusedDistanceNNImpl::Policy, - ReduceOpT>(min, - x, - y, - xn, - yn, - m, - n, - k, - (int*)workspace, - redOp, - pairRedOp, - sqrt, - initOutBuffer, - isRowMajor, - metric, - metric_arg, - stream); - } - } -} - -/** - * @brief Wrapper around fusedDistanceNN with minimum reduction operators. - * - * fusedDistanceNN cannot be compiled in the distance library due to the lambda - * operators, so this wrapper covers the most common case (minimum). - * - * @tparam DataT data type - * @tparam OutT output type to either store 1-NN indices and their minimum - * distances (e.g. raft::KeyValuePair) or store only the min - * distances. - * @tparam IdxT indexing arithmetic type - * @param[out] min will contain the reduced output (Length = `m`) - * (on device) - * @param[in] x first matrix. Row major. Dim = `m x k`. - * (on device). - * @param[in] y second matrix. Row major. Dim = `n x k`. - * (on device). - * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). - * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) - * @param[in] m gemm m - * @param[in] n gemm n - * @param[in] k gemm k - * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) - * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt - * @param[in] initOutBuffer whether to initialize the output buffer before the - * main kernel launch - * @param[in] isRowMajor whether the input/output is row or column major. - * @param[in] metric Distance metric to be used (supports L2, cosine) - * @param[in] metric_arg power argument for distances like Minkowski (not supported for now) - * @param[in] stream cuda stream - */ -template -void fusedDistanceNNMinReduce(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - void* workspace, - bool sqrt, - bool initOutBuffer, - bool isRowMajor, - cuvs::distance::DistanceType metric, - float metric_arg, - cudaStream_t stream) -{ - MinAndDistanceReduceOp redOp; - KVPMinReduce pairRedOp; - - fusedDistanceNN(min, - x, - y, - xn, - yn, - m, - n, - k, - workspace, - redOp, - pairRedOp, - sqrt, - initOutBuffer, - isRowMajor, - metric, - metric_arg, - stream); -} - -/** @} */ - -} // namespace distance -} // namespace cuvs - -#endif diff --git a/cpp/src/distance/fused_nn.cu b/cpp/src/distance/fused_nn.cu deleted file mode 100644 index 0b191ac6f..000000000 --- a/cpp/src/distance/fused_nn.cu +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "fused_nn-inl.cuh" -#include // raft::KeyValuePair - -#include // int64_t - -#define instantiate_raft_distance_fusedDistanceNNMinReduce(DataT, OutT, IdxT) \ - template void cuvs::distance::fusedDistanceNNMinReduce( \ - OutT * min, \ - const DataT* x, \ - const DataT* y, \ - const DataT* xn, \ - const DataT* yn, \ - IdxT m, \ - IdxT n, \ - IdxT k, \ - void* workspace, \ - bool sqrt, \ - bool initOutBuffer, \ - bool isRowMajor, \ - cuvs::distance::DistanceType metric, \ - float metric_arg, \ - cudaStream_t stream) - -instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int); -instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int64_t); - -// We can't have comma's in the macro expansion, so we use the COMMA macro: -#define COMMA , - -instantiate_raft_distance_fusedDistanceNNMinReduce(float, raft::KeyValuePair, int); -instantiate_raft_distance_fusedDistanceNNMinReduce(float, - raft::KeyValuePair, - int64_t); - -#undef COMMA - -#undef instantiate_raft_distance_fusedDistanceNNMinReduce diff --git a/cpp/src/distance/fused_nn.cuh b/cpp/src/distance/fused_nn.cuh deleted file mode 100755 index d9e442c54..000000000 --- a/cpp/src/distance/fused_nn.cuh +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY -#include "fused_nn-inl.cuh" -#endif - -#ifdef RAFT_COMPILED -#include "fused_distance_nn-ext.cuh" -#endif diff --git a/cpp/src/distance/fused_nn_helpers.cuh b/cpp/src/distance/fused_nn_helpers.cuh deleted file mode 100644 index 7dd370fa0..000000000 --- a/cpp/src/distance/fused_nn_helpers.cuh +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "detail/fused_distance_nn/helper_structs.cuh" -#include - -namespace cuvs::distance { - -/** - * \defgroup fused_l2_nn Fused 1-nearest neighbors - * @{ - */ - -template -using KVPMinReduce = detail::KVPMinReduceImpl; - -template -using MinAndDistanceReduceOp = detail::MinAndDistanceReduceOpImpl; - -template -using MinReduceOp = detail::MinReduceOpImpl; - -/** @} */ - -/** - * Initialize array using init value from reduction op - */ -template -void initialize(raft::resources const& handle, OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) -{ - detail::initialize( - min, m, maxVal, redOp, raft::resource::get_cuda_stream(handle)); -} - -} // namespace cuvs::distance diff --git a/cpp/src/distance/kernels.cuh b/cpp/src/distance/kernels.cuh deleted file mode 100644 index 8fb92d46c..000000000 --- a/cpp/src/distance/kernels.cuh +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2019-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include - -namespace cuvs::distance::kernels { - -// TODO: Need to expose formal APIs for this that are more consistent w/ other APIs in RAFT -using cuvs::distance::kernels::detail::GramMatrixBase; -using cuvs::distance::kernels::detail::KernelFactory; - -}; // namespace cuvs::distance::kernels diff --git a/cpp/src/distance/masked_nn.cuh b/cpp/src/distance/masked_nn.cuh deleted file mode 100644 index 866a657d0..000000000 --- a/cpp/src/distance/masked_nn.cuh +++ /dev/null @@ -1,201 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef __MASKED_L2_NN_H -#define __MASKED_L2_NN_H - -#pragma once - -#include -#include -#include -#include - -#include - -#include - -namespace cuvs { -namespace distance { -/** - * \defgroup masked_nn Masked 1-nearest neighbors - * @{ - */ - -/** - * @brief Parameter struct for masked_l2_nn function - * - * @tparam ReduceOpT Type of reduction operator in the epilogue. - * @tparam KVPReduceOpT Type of Reduction operation on key value pairs. - * - * Usage example: - * @code{.cpp} - * #include - * - * using IdxT = int; - * using DataT = float; - * using RedOpT = cuvs::distance::MinAndDistanceReduceOp; - * using PairRedOpT = cuvs::distance::KVPMinReduce; - * using ParamT = cuvs::distance::masked_l2_nn_params; - * - * bool init_out = true; - * bool sqrt = false; - * - * ParamT masked_l2_params{RedOpT{}, PairRedOpT{}, sqrt, init_out}; - * @endcode - * - * Prescribes how to reduce a distance to an intermediate type (`redOp`), and - * how to reduce two intermediate types (`pairRedOp`). Typically, a distance is - * mapped to an (index, value) pair and (index, value) pair with the lowest - * value (distance) is selected. - * - * In addition, prescribes whether to compute the square root of the distance - * (`sqrt`) and whether to initialize the output buffer (`initOutBuffer`). - */ -template -struct masked_l2_nn_params { - /** Reduction operator in the epilogue */ - ReduceOpT redOp; - /** Reduction operation on key value pairs */ - KVPReduceOpT pairRedOp; - /** Whether the output `minDist` should contain L2-sqrt */ - bool sqrt; - /** Whether to initialize the output buffer before the main kernel launch */ - bool initOutBuffer; -}; - -/** - * @brief Masked L2 distance and 1-nearest-neighbor computation in a single call. - * - * This function enables faster computation of nearest neighbors if the - * computation of distances between certain point pairs can be skipped. - * - * We use an adjacency matrix that describes which distances to calculate. The - * points in `y` are divided into groups, and the adjacency matrix indicates - * whether to compute distances between points in `x` and groups in `y`. In other - * words, if `adj[i,k]` is true then distance between point `x_i`, and points in - * `group_k` will be calculated. - * - * **Performance considerations** - * - * The points in `x` are processed in tiles of `M` points (`M` is currently 64, - * but may change in the future). As a result, the largest compute time - * reduction occurs if all `M` points can skip a group. If only part of the `M` - * points can skip a group, then at most a minor compute time reduction and a - * modest energy use reduction can be expected. - * - * The points in `y` are also grouped into tiles of `N` points (`N` is currently - * 64, but may change in the future). As a result, group sizes should be larger - * than `N` to avoid wasting computational resources. If the group sizes are - * evenly divisible by `N`, then the computation is most efficient, although for - * larger group sizes this effect is minor. - * - * - * **Comparison to SDDM** - * - * [SDDMM](https://ieeexplore.ieee.org/document/8638042) (sampled dense-dense - * matrix multiplication) is a matrix-matrix multiplication where only part of - * the output is computed. Compared to masked_l2_nn, there are a few differences: - * - * - The output of masked_l2_nn is a single vector (of nearest neighbors) and not - * a sparse matrix. - * - * - The sampling in masked_l2_nn is expressed through intermediate "groups" - rather than a CSR format. - * - * @tparam DataT data type - * @tparam OutT output type to either store 1-NN indices and their minimum - * distances or store only the min distances. Accordingly, one - * has to pass an appropriate `ReduceOpT` - * @tparam IdxT indexing arithmetic type - * @tparam ReduceOpT A struct to perform the final needed reduction operation - * and also to initialize the output array elements with the - * appropriate initial value needed for reduction. - * - * @param handle RAFT handle for managing expensive resources - * @param params Parameter struct specifying the reduction operations. - * @param[in] x First matrix. Row major. Dim = `m x k`. - * (on device). - * @param[in] y Second matrix. Row major. Dim = `n x k`. - * (on device). - * @param[in] x_norm L2 squared norm of `x`. Length = `m`. (on device). - * @param[in] y_norm L2 squared norm of `y`. Length = `n`. (on device) - * @param[in] adj A boolean adjacency matrix indicating for each - * row of `x` and each group in `y` whether to compute the - * distance. Dim = `m x num_groups`. - * @param[in] group_idxs An array containing the *end* indices of each group - * in `y`. The value of group_idxs[j] indicates the - * start of group j + 1, i.e., it is the inclusive - * scan of the group lengths. The first group is - * always assumed to start at index 0 and the last - * group typically ends at index `n`. Length = - * `num_groups`. - * @param[out] out will contain the reduced output (Length = `m`) - * (on device) - */ -template -void masked_l2_nn(raft::resources const& handle, - cuvs::distance::masked_l2_nn_params params, - raft::device_matrix_view x, - raft::device_matrix_view y, - raft::device_vector_view x_norm, - raft::device_vector_view y_norm, - raft::device_matrix_view adj, - raft::device_vector_view group_idxs, - raft::device_vector_view out) -{ - IdxT m = x.extent(0); - IdxT n = y.extent(0); - IdxT k = x.extent(1); - IdxT num_groups = group_idxs.extent(0); - - // Match k dimension of x, y - RAFT_EXPECTS(x.extent(1) == y.extent(1), "Dimension of vectors in x and y must be equal."); - // Match x, x_norm and y, y_norm - RAFT_EXPECTS(m == x_norm.extent(0), "Length of `x_norm` must match input `x`."); - RAFT_EXPECTS(n == y_norm.extent(0), "Length of `y_norm` must match input `y` "); - // Match adj to x and group_idxs - RAFT_EXPECTS(m == adj.extent(0), "#rows in `adj` must match input `x`."); - RAFT_EXPECTS(num_groups == adj.extent(1), "#cols in `adj` must match length of `group_idxs`."); - // NOTE: We do not check if all indices in group_idxs actually points *inside* y. - - // If there is no work to be done, return immediately. - if (m == 0 || n == 0 || k == 0 || num_groups == 0) { return; } - - detail::masked_l2_nn_impl(handle, - out.data_handle(), - x.data_handle(), - y.data_handle(), - x_norm.data_handle(), - y_norm.data_handle(), - adj.data_handle(), - group_idxs.data_handle(), - num_groups, - m, - n, - k, - params.redOp, - params.pairRedOp, - params.sqrt, - params.initOutBuffer); -} - -/** @} */ - -} // namespace distance -} // namespace cuvs - -#endif