Exposing kernel gram APIs
cjnolet committed Nov 13, 2024
1 parent 9458ae9 commit 557c2aa
Showing 10 changed files with 1,695 additions and 1,053 deletions.
3 changes: 3 additions & 0 deletions cpp/CMakeLists.txt
@@ -324,6 +324,9 @@ if(BUILD_SHARED_LIBS)
src/cluster/kmeans_transform_float.cu
src/cluster/single_linkage_float.cu
src/core/bitset.cu
src/distance/detail/kernels/gram_matrix.cu
src/distance/detail/kernels/kernel_factory.cu
src/distance/detail/kernels/kernel_matrices.cu
src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu
src/distance/detail/pairwise_matrix/dispatch_canberra_half_float_float_int.cu
src/distance/detail/pairwise_matrix/dispatch_canberra_double_double_double_int.cu
478 changes: 478 additions & 0 deletions cpp/src/distance/detail/kernels/gram_matrix.cu

Large diffs are not rendered by default.
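The 478-line gram_matrix.cu is not rendered, but the header changes below show its role: the inline template bodies move out of the header, so the .cu file must supply out-of-line definitions plus explicit instantiations for the supported types. A minimal sketch of that pattern follows; the included header name, the forwarding body, and the float/double instantiations are all assumptions, since the actual file is not shown on this page.

// Hypothetical sketch of gram_matrix.cu (assumed structure, not the
// rendered diff): define the members out of line, then explicitly
// instantiate so dependent translation units still link.
#include "gram_matrix.hpp"  // assumed header name

namespace cuvs::distance::kernels::detail {

template <typename math_t>
void GramMatrixBase<math_t>::operator()(raft::resources const& handle,
                                        dense_input_matrix_view_t<math_t> x1,
                                        dense_input_matrix_view_t<math_t> x2,
                                        dense_output_matrix_view_t<math_t> out,
                                        math_t* norm_x1,
                                        math_t* norm_x2)
{
  // Forward to the virtual evaluate(), as the removed header body did.
  evaluate(handle, x1, x2, out, norm_x1, norm_x2);
}

// Explicit instantiations: without these, moving the template bodies out
// of the header would leave unresolved symbols in dependent code.
template class GramMatrixBase<float>;
template class GramMatrixBase<double>;

}  // namespace cuvs::distance::kernels::detail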

@@ -16,23 +16,19 @@

#pragma once

#include "../../distance.cuh"
#include "cublas.h"
#include <cuvs/distance/distance.hpp>
#include <raft/core/device_csr_matrix.hpp>
#include <raft/core/mdspan.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
// #include <raft/sparse/detail/cusparse_wrappers.h>
#include <raft/linalg/detail/cublas_wrappers.hpp>
#include <raft/linalg/gemm.cuh>
#include <raft/sparse/distance/distance.cuh>
#include <raft/sparse/linalg/spmm.hpp>

namespace cuvs::distance::kernels::detail {

template <typename math_t>
using dense_input_matrix_view_t = raft::device_matrix_view<const math_t, int, layout_stride>;
using dense_input_matrix_view_t = raft::device_matrix_view<const math_t, int, raft::layout_stride>;
template <typename math_t>
using dense_output_matrix_view_t = raft::device_matrix_view<math_t, int, layout_stride>;
using dense_output_matrix_view_t = raft::device_matrix_view<math_t, int, raft::layout_stride>;
template <typename math_t>
using csr_input_matrix_view_t = raft::device_csr_matrix_view<const math_t, int, int, int>;

@@ -76,10 +72,7 @@ class GramMatrixBase {
dense_input_matrix_view_t<math_t> x2,
dense_output_matrix_view_t<math_t> out,
math_t* norm_x1 = nullptr,
math_t* norm_x2 = nullptr)
{
evaluate(handle, x1, x2, out, norm_x1, norm_x2);
}
math_t* norm_x2 = nullptr);

/** Convenience function to evaluate the Gram matrix for two vector sets.
* Vector sets are provided in Matrix format
@@ -96,10 +89,7 @@
dense_input_matrix_view_t<math_t> x2,
dense_output_matrix_view_t<math_t> out,
math_t* norm_x1 = nullptr,
math_t* norm_x2 = nullptr)
{
evaluate(handle, x1, x2, out, norm_x1, norm_x2);
}
math_t* norm_x2 = nullptr);

/** Convenience function to evaluate the Gram matrix for two vector sets.
* Vector sets are provided in Matrix format
@@ -116,10 +106,7 @@
csr_input_matrix_view_t<math_t> x2,
dense_output_matrix_view_t<math_t> out,
math_t* norm_x1 = nullptr,
math_t* norm_x2 = nullptr)
{
evaluate(handle, x1, x2, out, norm_x1, norm_x2);
}
math_t* norm_x2 = nullptr);

// unfortunately, 'evaluate' cannot be templatized as it needs to be virtual
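Before the next hunk, a hedged usage sketch of the convenience operator declared above: it evaluates the linear Gram matrix out = x1 * x2^T over row-major device buffers. The strided-view construction mirrors the pattern cuML's SVM code uses with these view types, and treating a default-constructed GramMatrixBase as the plain dot-product kernel is an assumption about this API.

// Usage sketch (see assumptions in the lead-in above).
#include <raft/core/device_mdspan.hpp>
#include <raft/core/resources.hpp>

void gram_example(raft::resources const& handle,
                  const float* x1, int n1,
                  const float* x2, int n2,
                  int n_cols, float* out)
{
  // layout_c_contiguous marks the buffers as row-major; the last argument
  // is the leading dimension (no padding here).
  auto x1_view = raft::make_device_strided_matrix_view<const float, int, raft::layout_c_contiguous>(
    x1, n1, n_cols, n_cols);
  auto x2_view = raft::make_device_strided_matrix_view<const float, int, raft::layout_c_contiguous>(
    x2, n2, n_cols, n_cols);
  auto out_view = raft::make_device_strided_matrix_view<float, int, raft::layout_c_contiguous>(
    out, n1, n2, n2);

  GramMatrixBase<float> kernel;                // base class: linear kernel
  kernel(handle, x1_view, x2_view, out_view);  // norms default to nullptr
}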

@@ -137,10 +124,8 @@
dense_input_matrix_view_t<math_t> x2,
dense_output_matrix_view_t<math_t> out,
math_t* norm_x1,
math_t* norm_x2)
{
linear(handle, x1, x2, out);
}
math_t* norm_x2);

/** Evaluate the Gram matrix for two vector sets using simple dot product.
*
* @param [in] handle raft handle
@@ -155,10 +140,8 @@
dense_input_matrix_view_t<math_t> x2,
dense_output_matrix_view_t<math_t> out,
math_t* norm_x1,
math_t* norm_x2)
{
linear(handle, x1, x2, out);
}
math_t* norm_x2);

/** Evaluate the Gram matrix for two vector sets using simple dot product.
*
* @param [in] handle raft handle
@@ -173,10 +156,7 @@
csr_input_matrix_view_t<math_t> x2,
dense_output_matrix_view_t<math_t> out,
math_t* norm_x1,
math_t* norm_x2)
{
linear(handle, x1, x2, out);
}
math_t* norm_x2);

/** Evaluate the Gram matrix for two vector sets using simple dot product.
*
@@ -203,10 +183,7 @@
cudaStream_t stream,
int ld1,
int ld2,
int ld_out)
{
linear(x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out);
}
int ld_out);

/** Convenience function to evaluate the Gram matrix for two vector sets.
*
@@ -233,14 +210,7 @@
cudaStream_t stream,
int ld1 = 0,
int ld2 = 0,
int ld_out = 0)
{
ASSERT(legacy_interface, "Legacy interface can only be used with legacy ctor.");
if (ld1 <= 0) { ld1 = is_row_major ? n_cols : n1; }
if (ld2 <= 0) { ld2 = is_row_major ? n_cols : n2; }
if (ld_out <= 0) { ld_out = is_row_major ? n2 : n1; }
evaluate(x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out);
}
int ld_out = 0);

protected:
/** Calculates the Gram matrix using simple dot product between vector sets.
@@ -272,67 +242,13 @@ class GramMatrixBase {
cudaStream_t stream,
int ld1,
int ld2,
int ld_out)
{
math_t alpha = 1.0;
math_t beta = 0.0;
if (is_row_major) {
// #TODO: Call from public API when ready
RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemm(cublas_handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
n2,
n1,
n_cols,
&alpha,
x2,
ld2,
x1,
ld1,
&beta,
out,
ld_out,
stream));
} else {
// #TODO: Call from public API when ready
RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemm(cublas_handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
n1,
n2,
n_cols,
&alpha,
x1,
ld1,
x2,
ld2,
&beta,
out,
ld_out,
stream));
}
}
int ld_out);
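The body removed above deserves a note, because the mdspan overload further down reuses the same trick: cuBLAS GEMM assumes column-major storage, so a row-major buffer meant to receive $K$ is, from cuBLAS's point of view, $K^{\top}$. The row-major branch therefore swaps the operands and output extents and computes the transposed product, relying on

  K = X_1 X_2^{\top} \quad\Longleftrightarrow\quad K^{\top} = X_2 X_1^{\top},

which is why the removed code issues gemm(OP_T, OP_N, n2, n1, n_cols, x2, x1, out) in the row-major case and the direct gemm(OP_N, OP_T, n1, n2, n_cols, x1, x2, out) in the column-major case.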

protected:
bool get_is_row_major(dense_output_matrix_view_t<math_t> matrix)
{
return (matrix.stride(1) == 1);
}

bool get_is_row_major(dense_input_matrix_view_t<math_t> matrix)
{
return (matrix.stride(1) == 1);
}

bool get_is_col_major(dense_output_matrix_view_t<math_t> matrix)
{
return (matrix.stride(0) == 1);
}

bool get_is_col_major(dense_input_matrix_view_t<math_t> matrix)
{
return (matrix.stride(0) == 1);
}
bool get_is_row_major(dense_output_matrix_view_t<math_t> matrix);
bool get_is_row_major(dense_input_matrix_view_t<math_t> matrix);
bool get_is_col_major(dense_output_matrix_view_t<math_t> matrix);
bool get_is_col_major(dense_input_matrix_view_t<math_t> matrix);

/** Calculates the Gram matrix using simple dot product between vector sets.
*
@@ -348,67 +264,7 @@ class GramMatrixBase {
void linear(raft::resources const& handle,
dense_input_matrix_view_t<math_t> x1,
dense_input_matrix_view_t<math_t> x2,
dense_output_matrix_view_t<math_t> out)
{
// check is_row_major consistency
bool is_row_major = get_is_row_major(x1) && get_is_row_major(x2) && get_is_row_major(out);
bool is_col_major = get_is_col_major(x1) && get_is_col_major(x2) && get_is_col_major(out);
ASSERT(is_row_major || is_col_major,
"GramMatrix leading dimensions for x1, x2 and out do not match");

// check dimensions
int n1 = out.extent(0);
int n2 = out.extent(1);
int n_cols = x1.extent(1);
ASSERT(x1.extent(0) == n1, "GramMatrix input matrix dimensions for x1 and out do not match");
ASSERT(x2.extent(0) == n2, "GramMatrix input matrix dimensions for x2 and out do not match");
ASSERT(x2.extent(1) == n_cols, "GramMatrix input matrix dimensions for x1 and x2 do not match");

// extract major stride
int ld1 = is_row_major ? x1.stride(0) : x1.stride(1);
int ld2 = is_row_major ? x2.stride(0) : x2.stride(1);
int ld_out = is_row_major ? out.stride(0) : out.stride(1);

math_t alpha = 1.0;
math_t beta = 0.0;
if (is_row_major) {
// #TODO: Use mdspan-based API when stride-capable
// https://github.com/rapidsai/raft/issues/875
raft::linalg::gemm(handle,
true,
false,
n2,
n1,
n_cols,
&alpha,
x2.data_handle(),
ld2,
x1.data_handle(),
ld1,
&beta,
out.data_handle(),
ld_out,
resource::get_cuda_stream(handle));
} else {
// #TODO: Use mdspan-based API when stride-capable
// https://github.com/rapidsai/raft/issues/875
raft::linalg::gemm(handle,
false,
true,
n1,
n2,
n_cols,
&alpha,
x1.data_handle(),
ld1,
x2.data_handle(),
ld2,
&beta,
out.data_handle(),
ld_out,
resource::get_cuda_stream(handle));
}
}
dense_output_matrix_view_t<math_t> out);

/** Calculates the Gram matrix using simple dot product between vector sets.
*
@@ -424,28 +280,7 @@ class GramMatrixBase {
void linear(raft::resources const& handle,
csr_input_matrix_view_t<math_t> x1,
dense_input_matrix_view_t<math_t> x2,
dense_output_matrix_view_t<math_t> out)
{
// check is_row_major consistency
bool is_row_major = get_is_row_major(x2) && get_is_row_major(out);
bool is_col_major = get_is_col_major(x2) && get_is_col_major(out);
ASSERT(is_row_major || is_col_major,
"GramMatrix leading dimensions for x2 and out do not match");

// check dimensions
auto x1_structure = x1.structure_view();
ASSERT(x1_structure.get_n_rows() == out.extent(0),
"GramMatrix input matrix dimensions for x1 and out do not match");
ASSERT(x2.extent(0) == out.extent(1),
"GramMatrix input matrix dimensions for x2 and out do not match");
ASSERT(x2.extent(1) == x1_structure.get_n_cols(),
"GramMatrix input matrix dimensions for x1 and x2 do not match");

math_t alpha = 1.0;
math_t beta = 0.0;

raft::sparse::linalg::spmm(handle, false, true, &alpha, x1, x2, &beta, out);
}
dense_output_matrix_view_t<math_t> out);
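The CSR × dense body that moved out of the header called raft::sparse::linalg::spmm(handle, false, true, &alpha, x1, x2, &beta, out). Reading the transpose flags (false, true) with $\alpha = 1$, $\beta = 0$, that is the Gram product

  K = \alpha\, X_1 X_2^{\top} + \beta\, K = X_1 X_2^{\top},

evaluated directly from the CSR representation of x1 without densifying it.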

/** Calculates the Gram matrix using simple dot product between vector sets.
*
@@ -461,28 +296,6 @@ class GramMatrixBase {
void linear(raft::resources const& handle,
csr_input_matrix_view_t<math_t> x1,
csr_input_matrix_view_t<math_t> x2,
dense_output_matrix_view_t<math_t> out)
{
// check layout consistency (w.r.t. strides a matrix might be both row & col major)
bool is_row_major_nopad = get_is_row_major(out) && out.stride(0) == out.extent(1);
bool is_col_major_nopad = get_is_col_major(out) && out.stride(1) == out.extent(0);

ASSERT(is_row_major_nopad || is_col_major_nopad,
"Sparse linear Kernel distance does not support ld_out parameter");

// switch a,b based on is_row_major
if (is_col_major_nopad) {
auto out_row_major = raft::make_device_matrix_view<math_t, int, raft::row_major>(
out.data_handle(), out.extent(1), out.extent(0));
raft::sparse::distance::pairwise_distance(
handle, x2, x1, out_row_major, cuvs::distance::DistanceType::InnerProduct, 0.0);
} else {
auto out_row_major = raft::make_device_matrix_view<math_t, int, raft::row_major>(
out.data_handle(), out.extent(0), out.extent(1));
raft::sparse::distance::pairwise_distance(
handle, x1, x2, out_row_major, cuvs::distance::DistanceType::InnerProduct, 0.0);
}
}
dense_output_matrix_view_t<math_t> out);
};

}; // end namespace cuvs::distance::kernels::detail
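One closing note on the CSR × CSR path: raft::sparse::distance::pairwise_distance with DistanceType::InnerProduct produces a row-major result, which is why the removed body asserted that an arbitrary ld_out is unsupported. For an unpadded column-major out it instead reinterpreted the buffer as a row-major matrix with swapped extents and swapped the inputs, using the fact that a column-major $K$ occupies the same memory as the row-major $K^{\top}$:

  K^{\top} = X_2 X_1^{\top} \quad\Longrightarrow\quad K = X_1 X_2^{\top}.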