Skip to content

Commit

Permalink
Consider v0 as optional in c++
Browse files Browse the repository at this point in the history
  • Loading branch information
lowener committed Nov 14, 2024
1 parent 80dafb9 commit f8078e1
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 68 deletions.
42 changes: 30 additions & 12 deletions cpp/include/raft/sparse/solver/detail/lanczos.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -2071,22 +2071,40 @@ auto lanczos_smallest(
template <typename IndexTypeT, typename ValueTypeT>
auto lanczos_compute_smallest_eigenvectors(
raft::resources const& handle,
raft::device_csr_matrix_view<ValueTypeT, IndexTypeT, IndexTypeT, IndexTypeT> A,
lanczos_solver_config<ValueTypeT> const& config,
raft::device_vector_view<ValueTypeT, uint32_t> v0,
raft::device_csr_matrix_view<ValueTypeT, IndexTypeT, IndexTypeT, IndexTypeT> A,
std::optional<raft::device_vector_view<ValueTypeT, uint32_t>> v0,
raft::device_vector_view<ValueTypeT, uint32_t> eigenvalues,
raft::device_matrix_view<ValueTypeT, uint32_t, raft::col_major> eigenvectors) -> int
{
return lanczos_smallest(handle,
A,
config.n_components,
config.max_iterations,
config.ncv,
config.tolerance,
eigenvalues.data_handle(),
eigenvectors.data_handle(),
v0.data_handle(),
config.seed);
if (v0.has_value()) {
return lanczos_smallest(handle,
A,
config.n_components,
config.max_iterations,
config.ncv,
config.tolerance,
eigenvalues.data_handle(),
eigenvectors.data_handle(),
v0->data_handle(),
config.seed);
} else {
// Handle the optional v0 initial Lanczos vector if nullopt is used
auto n = A.structure_view().get_n_rows();
auto temp_v0 = raft::make_device_vector<ValueTypeT, uint32_t>(handle, n);
raft::random::RngState rng_state(config.seed);
raft::random::uniform(handle, rng_state, temp_v0.view(), ValueTypeT{0.0}, ValueTypeT{1.0});
return lanczos_smallest(handle,
A,
config.n_components,
config.max_iterations,
config.ncv,
config.tolerance,
eigenvalues.data_handle(),
eigenvectors.data_handle(),
temp_v0.data_handle(),
config.seed);
}
}

} // namespace raft::sparse::solver::detail
20 changes: 10 additions & 10 deletions cpp/include/raft/sparse/solver/lanczos.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -33,48 +33,48 @@ namespace raft::sparse::solver {
* @tparam index_type_t the type of data used for indexing.
* @tparam value_type_t the type of data used for weights, distances.
* @param handle the raft handle.
* @param A Sparse matrix in CSR format.
* @param config lanczos config used to set hyperparameters
* @param v0 Initial lanczos vector
* @param A Sparse matrix in CSR format.
* @param v0 Optional Initial lanczos vector
* @param eigenvalues output eigenvalues
* @param eigenvectors output eigenvectors
* @return Zero if successful. Otherwise non-zero.
*/
template <typename IndexTypeT, typename ValueTypeT>
auto lanczos_compute_smallest_eigenvectors(
raft::resources const& handle,
raft::device_csr_matrix_view<ValueTypeT, IndexTypeT, IndexTypeT, IndexTypeT> A,
lanczos_solver_config<ValueTypeT> const& config,
raft::device_vector_view<ValueTypeT, uint32_t, raft::row_major> v0,
raft::device_csr_matrix_view<ValueTypeT, IndexTypeT, IndexTypeT, IndexTypeT> A,
std::optional<raft::device_vector_view<ValueTypeT, uint32_t, raft::row_major>> v0,
raft::device_vector_view<ValueTypeT, uint32_t, raft::col_major> eigenvalues,
raft::device_matrix_view<ValueTypeT, uint32_t, raft::col_major> eigenvectors) -> int
{
return detail::lanczos_compute_smallest_eigenvectors<IndexTypeT, ValueTypeT>(
handle, A, config, v0, eigenvalues, eigenvectors);
handle, config, A, v0, eigenvalues, eigenvectors);
}

/**
* @brief Find the smallest eigenpairs using lanczos solver
* @tparam index_type_t the type of data used for indexing.
* @tparam value_type_t the type of data used for weights, distances.
* @param handle the raft handle.
* @param config lanczos config used to set hyperparameters
* @param rows Vector view of the rows of the sparse matrix.
* @param cols Vector view of the cols of the sparse matrix.
* @param vals Vector view of the vals of the sparse matrix.
* @param config lanczos config used to set hyperparameters
* @param v0 Initial lanczos vector
* @param v0 Optional Initial lanczos vector
* @param eigenvalues output eigenvalues
* @param eigenvectors output eigenvectors
* @return Zero if successful. Otherwise non-zero.
*/
template <typename IndexTypeT, typename ValueTypeT>
auto lanczos_compute_smallest_eigenvectors(
raft::resources const& handle,
lanczos_solver_config<ValueTypeT> const& config,
raft::device_vector_view<IndexTypeT, uint32_t, raft::row_major> rows,
raft::device_vector_view<IndexTypeT, uint32_t, raft::row_major> cols,
raft::device_vector_view<ValueTypeT, uint32_t, raft::row_major> vals,
lanczos_solver_config<ValueTypeT> const& config,
raft::device_vector_view<ValueTypeT, uint32_t, raft::row_major> v0,
std::optional<raft::device_vector_view<ValueTypeT, uint32_t, raft::row_major>> v0,
raft::device_vector_view<ValueTypeT, uint32_t, raft::col_major> eigenvalues,
raft::device_matrix_view<ValueTypeT, uint32_t, raft::col_major> eigenvectors) -> int
{
Expand All @@ -95,7 +95,7 @@ auto lanczos_compute_smallest_eigenvectors(
const_cast<ValueTypeT*>(vals.data_handle()), csr_structure);

return lanczos_compute_smallest_eigenvectors<IndexTypeT, ValueTypeT>(
handle, csr_matrix, config, v0, eigenvalues, eigenvectors);
handle, config, csr_matrix, v0, eigenvalues, eigenvectors);
}

/**
Expand Down
19 changes: 10 additions & 9 deletions cpp/include/raft_runtime/solver/lanczos.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,16 @@ namespace raft::runtime::solver {
* @{
*/

#define FUNC_DECL(IndexType, ValueType) \
void lanczos_solver(const raft::resources& handle, \
raft::device_vector_view<IndexType, uint32_t, raft::row_major> rows, \
raft::device_vector_view<IndexType, uint32_t, raft::row_major> cols, \
raft::device_vector_view<ValueType, uint32_t, raft::row_major> vals, \
raft::sparse::solver::lanczos_solver_config<ValueType> config, \
raft::device_vector_view<ValueType, uint32_t, raft::row_major> v0, \
raft::device_vector_view<ValueType, uint32_t, raft::col_major> eigenvalues, \
raft::device_matrix_view<ValueType, uint32_t, raft::col_major> eigenvectors)
#define FUNC_DECL(IndexType, ValueType) \
void lanczos_solver( \
const raft::resources& handle, \
raft::sparse::solver::lanczos_solver_config<ValueType> config, \
raft::device_vector_view<IndexType, uint32_t, raft::row_major> rows, \
raft::device_vector_view<IndexType, uint32_t, raft::row_major> cols, \
raft::device_vector_view<ValueType, uint32_t, raft::row_major> vals, \
std::optional<raft::device_vector_view<ValueType, uint32_t, raft::row_major>> v0, \
raft::device_vector_view<ValueType, uint32_t, raft::col_major> eigenvalues, \
raft::device_matrix_view<ValueType, uint32_t, raft::col_major> eigenvectors)

FUNC_DECL(int, float);
FUNC_DECL(int64_t, float);
Expand Down
25 changes: 13 additions & 12 deletions cpp/src/raft_runtime/solver/lanczos_solver.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,17 @@

#include <raft/sparse/solver/lanczos.cuh>

#define FUNC_DEF(IndexType, ValueType) \
void lanczos_solver(const raft::resources& handle, \
raft::device_vector_view<IndexType, uint32_t, raft::row_major> rows, \
raft::device_vector_view<IndexType, uint32_t, raft::row_major> cols, \
raft::device_vector_view<ValueType, uint32_t, raft::row_major> vals, \
raft::sparse::solver::lanczos_solver_config<ValueType> config, \
raft::device_vector_view<ValueType, uint32_t, raft::row_major> v0, \
raft::device_vector_view<ValueType, uint32_t, raft::col_major> eigenvalues, \
raft::device_matrix_view<ValueType, uint32_t, raft::col_major> eigenvectors) \
{ \
raft::sparse::solver::lanczos_compute_smallest_eigenvectors<IndexType, ValueType>( \
handle, rows, cols, vals, config, v0, eigenvalues, eigenvectors); \
#define FUNC_DEF(IndexType, ValueType) \
void lanczos_solver( \
const raft::resources& handle, \
raft::sparse::solver::lanczos_solver_config<ValueType> config, \
raft::device_vector_view<IndexType, uint32_t, raft::row_major> rows, \
raft::device_vector_view<IndexType, uint32_t, raft::row_major> cols, \
raft::device_vector_view<ValueType, uint32_t, raft::row_major> vals, \
std::optional<raft::device_vector_view<ValueType, uint32_t, raft::row_major>> v0, \
raft::device_vector_view<ValueType, uint32_t, raft::col_major> eigenvalues, \
raft::device_matrix_view<ValueType, uint32_t, raft::col_major> eigenvectors) \
{ \
raft::sparse::solver::lanczos_compute_smallest_eigenvectors<IndexType, ValueType>( \
handle, config, rows, cols, vals, v0, eigenvalues, eigenvectors); \
}
14 changes: 12 additions & 2 deletions cpp/test/sparse/solver/lanczos.cu
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,12 @@ class rmat_lanczos_tests

std::get<0>(stats) =
raft::sparse::solver::lanczos_compute_smallest_eigenvectors<IndexType, ValueType>(
handle, csr_matrix, config, v0.view(), eigenvalues.view(), eigenvectors.view());
handle,
config,
csr_matrix,
std::make_optional(v0.view()),
eigenvalues.view(),
eigenvectors.view());

ASSERT_TRUE(raft::devArrMatch<ValueType>(eigenvalues.data_handle(),
expected_eigenvalues.data_handle(),
Expand Down Expand Up @@ -278,7 +283,12 @@ class lanczos_tests : public ::testing::TestWithParam<lanczos_inputs<IndexType,

std::get<0>(stats) =
raft::sparse::solver::lanczos_compute_smallest_eigenvectors<IndexType, ValueType>(
handle, csr_matrix, config, v0.view(), eigenvalues.view(), eigenvectors.view());
handle,
config,
csr_matrix,
std::make_optional(v0.view()),
eigenvalues.view(),
eigenvectors.view());

ASSERT_TRUE(raft::devArrMatch<ValueType>(eigenvalues.data_handle(),
expected_eigenvalues.data_handle(),
Expand Down
58 changes: 35 additions & 23 deletions python/pylibraft/pylibraft/sparse/linalg/lanczos.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ from pylibraft.common.cpp.mdspan cimport (
make_device_vector_view,
row_major,
)
from pylibraft.common.cpp.optional cimport optional
from pylibraft.common.handle cimport device_resources
from pylibraft.random.cpp.rng_state cimport RngState

Expand All @@ -59,41 +60,41 @@ cdef extern from "raft_runtime/solver/lanczos.hpp" \

cdef void lanczos_solver(
const device_resources &handle,
lanczos_solver_config[double] config,
device_vector_view[int64_t, uint32_t] rows,
device_vector_view[int64_t, uint32_t] cols,
device_vector_view[double, uint32_t] vals,
lanczos_solver_config[double] config,
device_vector_view[double, uint32_t] v0,
optional[device_vector_view[double, uint32_t]] v0,
device_vector_view[double, uint32_t] eigenvalues,
device_matrix_view[double, uint32_t, col_major] eigenvectors) except +

cdef void lanczos_solver(
const device_resources &handle,
lanczos_solver_config[float] config,
device_vector_view[int64_t, uint32_t] rows,
device_vector_view[int64_t, uint32_t] cols,
device_vector_view[float, uint32_t] vals,
lanczos_solver_config[float] config,
device_vector_view[float, uint32_t] v0,
optional[device_vector_view[float, uint32_t]] v0,
device_vector_view[float, uint32_t] eigenvalues,
device_matrix_view[float, uint32_t, col_major] eigenvectors) except +

cdef void lanczos_solver(
const device_resources &handle,
lanczos_solver_config[double] config,
device_vector_view[int, uint32_t] rows,
device_vector_view[int, uint32_t] cols,
device_vector_view[double, uint32_t] vals,
lanczos_solver_config[double] config,
device_vector_view[double, uint32_t] v0,
optional[device_vector_view[double, uint32_t]] v0,
device_vector_view[double, uint32_t] eigenvalues,
device_matrix_view[double, uint32_t, col_major] eigenvectors) except +

cdef void lanczos_solver(
const device_resources &handle,
lanczos_solver_config[float] config,
device_vector_view[int, uint32_t] rows,
device_vector_view[int, uint32_t] cols,
device_vector_view[float, uint32_t] vals,
lanczos_solver_config[float] config,
device_vector_view[float, uint32_t] v0,
optional[device_vector_view[float, uint32_t]] v0,
device_vector_view[float, uint32_t] eigenvalues,
device_matrix_view[float, uint32_t, col_major] eigenvectors) except +

Expand Down Expand Up @@ -159,6 +160,8 @@ def eigsh(A, k=6, v0=None, ncv=None, maxiter=None,
rows_ptr = <uintptr_t>rows.data
cols_ptr = <uintptr_t>cols.data
vals_ptr = <uintptr_t>vals.data
cdef optional[device_vector_view[double, uint32_t]] d_v0
cdef optional[device_vector_view[float, uint32_t]] f_v0

if ncv is None:
ncv = min(n, max(2*k + 1, 20))
Expand All @@ -171,13 +174,6 @@ def eigsh(A, k=6, v0=None, ncv=None, maxiter=None,
if tol == 0:
tol = np.finfo(ValueType).eps

if v0 is None:
rng = cp.random.default_rng(seed)
v0 = rng.random((N,)).astype(vals.dtype)

v0 = cai_wrapper(v0)
v0_ptr = <uintptr_t>v0.data

eigenvectors = device_ndarray.empty((N, k), dtype=ValueType, order='F')
eigenvalues = device_ndarray.empty((k), dtype=ValueType, order='F')

Expand All @@ -196,13 +192,17 @@ def eigsh(A, k=6, v0=None, ncv=None, maxiter=None,
config_float.ncv = ncv
config_float.tolerance = tol
config_float.seed = seed
if v0 is not None:
v0 = cai_wrapper(v0)
v0_ptr = <uintptr_t>v0.data
f_v0 = make_device_vector_view(<float *>v0_ptr, <uint32_t> N)
lanczos_solver(
deref(h),
<lanczos_solver_config[float]> config_float,
make_device_vector_view(<int *>rows_ptr, <uint32_t> (N + 1)),
make_device_vector_view(<int *>cols_ptr, <uint32_t> nnz),
make_device_vector_view(<float *>vals_ptr, <uint32_t> nnz),
<lanczos_solver_config[float]> config_float,
make_device_vector_view(<float *>v0_ptr, <uint32_t> N),
f_v0,
make_device_vector_view(<float *>eigenvalues_ptr, <uint32_t> k),
make_device_matrix_view[float, uint32_t, col_major](
<float *>eigenvectors_ptr, <uint32_t> N, <uint32_t> k),
Expand All @@ -213,13 +213,17 @@ def eigsh(A, k=6, v0=None, ncv=None, maxiter=None,
config_float.ncv = ncv
config_float.tolerance = tol
config_float.seed = seed
if v0 is not None:
v0 = cai_wrapper(v0)
v0_ptr = <uintptr_t>v0.data
f_v0 = make_device_vector_view(<float *>v0_ptr, <uint32_t> N)
lanczos_solver(
deref(h),
<lanczos_solver_config[float]> config_float,
make_device_vector_view(<int64_t *>rows_ptr, <uint32_t> (N + 1)),
make_device_vector_view(<int64_t *>cols_ptr, <uint32_t> nnz),
make_device_vector_view(<float *>vals_ptr, <uint32_t> nnz),
<lanczos_solver_config[float]> config_float,
make_device_vector_view(<float *>v0_ptr, <uint32_t> N),
f_v0,
make_device_vector_view(<float *>eigenvalues_ptr, <uint32_t> k),
make_device_matrix_view[float, uint32_t, col_major](
<float *>eigenvectors_ptr, <uint32_t> N, <uint32_t> k),
Expand All @@ -230,13 +234,17 @@ def eigsh(A, k=6, v0=None, ncv=None, maxiter=None,
config_double.ncv = ncv
config_double.tolerance = tol
config_double.seed = seed
if v0 is not None:
v0 = cai_wrapper(v0)
v0_ptr = <uintptr_t>v0.data
d_v0 = make_device_vector_view(<double *>v0_ptr, <uint32_t> N)
lanczos_solver(
deref(h),
<lanczos_solver_config[double]> config_double,
make_device_vector_view(<int *>rows_ptr, <uint32_t> (N + 1)),
make_device_vector_view(<int *>cols_ptr, <uint32_t> nnz),
make_device_vector_view(<double *>vals_ptr, <uint32_t> nnz),
<lanczos_solver_config[double]> config_double,
make_device_vector_view(<double *>v0_ptr, <uint32_t> N),
d_v0,
make_device_vector_view(<double *>eigenvalues_ptr, <uint32_t> k),
make_device_matrix_view[double, uint32_t, col_major](
<double *>eigenvectors_ptr, <uint32_t> N, <uint32_t> k),
Expand All @@ -247,13 +255,17 @@ def eigsh(A, k=6, v0=None, ncv=None, maxiter=None,
config_double.ncv = ncv
config_double.tolerance = tol
config_double.seed = seed
if v0 is not None:
v0 = cai_wrapper(v0)
v0_ptr = <uintptr_t>v0.data
d_v0 = make_device_vector_view(<double *>v0_ptr, <uint32_t> N)
lanczos_solver(
deref(h),
<lanczos_solver_config[double]> config_double,
make_device_vector_view(<int64_t *>rows_ptr, <uint32_t> (N + 1)),
make_device_vector_view(<int64_t *>cols_ptr, <uint32_t> nnz),
make_device_vector_view(<double *>vals_ptr, <uint32_t> nnz),
<lanczos_solver_config[double]> config_double,
make_device_vector_view(<double *>v0_ptr, <uint32_t> N),
d_v0,
make_device_vector_view(<double *>eigenvalues_ptr, <uint32_t> k),
make_device_matrix_view[double, uint32_t, col_major](
<double *>eigenvectors_ptr, <uint32_t> N, <uint32_t> k),
Expand Down

0 comments on commit f8078e1

Please sign in to comment.