Skip to content

Commit

Permalink
Enable CPU execution of IncrementalPCA
Browse files Browse the repository at this point in the history
  • Loading branch information
wphicks committed Jan 23, 2025
1 parent 87727c1 commit 0d9d09d
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 93 deletions.
46 changes: 31 additions & 15 deletions python/cuml/cuml/decomposition/incremental_pca.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
# Copyright (c) 2020-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -20,6 +20,11 @@
from cuml.internals.input_utils import input_to_cupy_array
from cuml.common import input_to_cuml_array
from cuml import Base
from cuml.internals.api_decorators import (
device_interop_preparation,
enable_device_interop,
)
from cuml.internals.global_settings import GlobalSettings
from cuml.internals.safe_imports import cpu_only_import
import numbers

Expand Down Expand Up @@ -195,6 +200,9 @@ class IncrementalPCA(PCA):
0.0037122774558343763
"""

_cpu_estimator_import_path = "sklearn.decomposition.IncrementalPCA"

@device_interop_preparation
def __init__(
self,
*,
Expand All @@ -218,6 +226,7 @@ def __init__(
self.batch_size = batch_size
self._sparse_model = True

@enable_device_interop
def fit(self, X, y=None, convert_dtype=True) -> "IncrementalPCA":
"""
Fit the model with X, using minibatches of size batch_size.
Expand Down Expand Up @@ -255,10 +264,10 @@ def fit(self, X, y=None, convert_dtype=True) -> "IncrementalPCA":
check_dtype=[cp.float32, cp.float64],
)

n_samples, n_features = X.shape
n_samples, self.n_features_in_ = X.shape

if self.batch_size is None:
self.batch_size_ = 5 * n_features
self.batch_size_ = 5 * self.n_features_in_
else:
self.batch_size_ = self.batch_size

Expand Down Expand Up @@ -305,25 +314,30 @@ def partial_fit(self, X, y=None, check_input=True) -> "IncrementalPCA":

self._set_output_type(X)

X, n_samples, n_features, self.dtype = input_to_cupy_array(
(
X,
n_samples,
self.n_features_in_,
self.dtype,
) = input_to_cupy_array(
X, order="K", check_dtype=[cp.float32, cp.float64]
)
else:
n_samples, n_features = X.shape
n_samples, self.n_features_in_ = X.shape

if not hasattr(self, "components_"):
self.components_ = None

if self.n_components is None:
if self.components_ is None:
self.n_components_ = min(n_samples, n_features)
self.n_components_ = min(n_samples, self.n_features_in_)
else:
self.n_components_ = self.components_.shape[0]
elif not 1 <= self.n_components <= n_features:
elif not 1 <= self.n_components <= self.n_features_in_:
raise ValueError(
"n_components=%r invalid for n_features=%d, need "
"more rows than columns for IncrementalPCA "
"processing" % (self.n_components, n_features)
"processing" % (self.n_components, self.n_features_in_)
)
elif not self.n_components <= n_samples:
raise ValueError(
Expand Down Expand Up @@ -394,7 +408,7 @@ def partial_fit(self, X, y=None, check_input=True) -> "IncrementalPCA":
self.explained_variance_ratio_ = explained_variance_ratio[
: self.n_components_
]
if self.n_components_ < n_features:
if self.n_components_ < self.n_features_in_:
self.noise_variance_ = explained_variance[
self.n_components_ :
].mean()
Expand All @@ -403,6 +417,7 @@ def partial_fit(self, X, y=None, check_input=True) -> "IncrementalPCA":

return self

@enable_device_interop
def transform(self, X, convert_dtype=False) -> CumlArray:
"""
Apply dimensionality reduction to X.
Expand Down Expand Up @@ -678,16 +693,17 @@ def _svd_flip(u, v, u_based_decision=True):
u_adjusted, v_adjusted : arrays with the same dimensions as the input.
"""
xpy = GlobalSettings().xpy
if u_based_decision:
# columns of u, rows of v
max_abs_cols = cp.argmax(cp.abs(u), axis=0)
signs = cp.sign(u[max_abs_cols, list(range(u.shape[1]))])
max_abs_cols = xpy.argmax(xpy.abs(u), axis=0)
signs = xpy.sign(u[max_abs_cols, list(range(u.shape[1]))])
u *= signs
v *= signs[:, cp.newaxis]
v *= signs[:, xpy.newaxis]
else:
# rows of v, columns of u
max_abs_rows = cp.argmax(cp.abs(v), axis=1)
signs = cp.sign(v[list(range(v.shape[0])), max_abs_rows])
max_abs_rows = xpy.argmax(xpy.abs(v), axis=1)
signs = xpy.sign(v[list(range(v.shape[0])), max_abs_rows])
u *= signs
v *= signs[:, cp.newaxis]
v *= signs[:, xpy.newaxis]
return u, v
6 changes: 5 additions & 1 deletion python/cuml/cuml/internals/global_settings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2021-2024, NVIDIA CORPORATION.
# Copyright (c) 2021-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -134,3 +134,7 @@ def output_type(self, value):
# Read-only accessor that forwards the `xpy` attribute of the currently
# selected memory type. In this change it stands in for direct `cp.*` calls
# (e.g. `xpy.argmax`, `xpy.abs`, `xpy.sign`, `xpy.newaxis`, `xpy.linalg.svd`),
# so callers stay agnostic of where the data lives.
# NOTE(review): presumably resolves to numpy for host memory and cupy for
# device memory -- confirm against the MemoryType definition.
@property
def xpy(self):
return self.memory_type.xpy

# Read-only accessor that forwards the `xsparse` attribute of the currently
# selected memory type; the sparse counterpart of `xpy`. The tests in this
# change call `xsparse.random(...)` and `xsparse.eye(...)` where they
# previously called `cupyx.scipy.sparse` directly.
# NOTE(review): presumably resolves to scipy.sparse for host memory and
# cupyx.scipy.sparse for device memory -- confirm against the MemoryType
# definition.
@property
def xsparse(self):
return self.memory_type.xsparse
169 changes: 92 additions & 77 deletions python/cuml/cuml/tests/test_incremental_pca.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
# Copyright (c) 2020-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,7 +14,9 @@
# limitations under the License.
#

from cuml import global_settings
from cuml.common.exceptions import NotFittedError
from cuml.common.device_selection import using_device_type
from cuml.testing.utils import array_equal
from cuml.decomposition.incremental_pca import _svd_flip
from cuml.decomposition import IncrementalPCA as cuIPCA
Expand All @@ -40,6 +42,7 @@
(500, 250, 14, True, 0.07, "csr", 1, True),
],
)
@pytest.mark.parametrize("device", ["gpu", "cpu"])
@pytest.mark.no_bad_cuml_array_check
def test_fit(
nrows,
Expand All @@ -50,48 +53,53 @@ def test_fit(
sparse_format,
batch_size_divider,
whiten,
device,
):

if sparse_format == "csc":
pytest.skip(
"cupyx.scipy.sparse.csc.csc_matrix does not support"
" indexing as of cupy 7.6.0"
with using_device_type(device):

if sparse_format == "csc":
pytest.skip(
"cupyx.scipy.sparse.csc.csc_matrix does not support"
" indexing as of cupy 7.6.0"
)

if sparse_input:
X = global_settings.xsparse.random(
nrows,
ncols,
density=density,
random_state=10,
format=sparse_format,
)
else:
X, _ = make_blobs(
n_samples=nrows, n_features=ncols, random_state=10
)

cu_ipca = cuIPCA(
n_components=n_components,
whiten=whiten,
batch_size=int(nrows / batch_size_divider),
)

if sparse_input:
X = cupyx.scipy.sparse.random(
nrows,
ncols,
density=density,
random_state=10,
format=sparse_format,
cu_ipca.fit(X)
cu_t = cu_ipca.transform(X)
cu_inv = cu_ipca.inverse_transform(cu_t)

sk_ipca = skIPCA(
n_components=n_components,
whiten=whiten,
batch_size=int(nrows / batch_size_divider),
)
else:
X, _ = make_blobs(n_samples=nrows, n_features=ncols, random_state=10)
if device == "gpu":
if sparse_input:
X = X.get()
else:
X = cp.asnumpy(X)
sk_ipca.fit(X)
sk_t = sk_ipca.transform(X)
sk_inv = sk_ipca.inverse_transform(sk_t)

cu_ipca = cuIPCA(
n_components=n_components,
whiten=whiten,
batch_size=int(nrows / batch_size_divider),
)
cu_ipca.fit(X)
cu_t = cu_ipca.transform(X)
cu_inv = cu_ipca.inverse_transform(cu_t)

sk_ipca = skIPCA(
n_components=n_components,
whiten=whiten,
batch_size=int(nrows / batch_size_divider),
)
if sparse_input:
X = X.get()
else:
X = cp.asnumpy(X)
sk_ipca.fit(X)
sk_t = sk_ipca.transform(X)
sk_inv = sk_ipca.inverse_transform(sk_t)

assert array_equal(cu_inv, sk_inv, 5e-5, with_sign=True)
assert array_equal(cu_inv, sk_inv, 5e-5, with_sign=True)


@pytest.mark.parametrize(
Expand All @@ -105,62 +113,69 @@ def test_fit(
(5000, 4, 2, 0.1, 100, False),
],
)
@pytest.mark.parametrize("device", ["gpu", "cpu"])
@pytest.mark.no_bad_cuml_array_check
def test_partial_fit(
nrows, ncols, n_components, density, batch_size_divider, whiten
nrows, ncols, n_components, density, batch_size_divider, whiten, device
):

X, _ = make_blobs(n_samples=nrows, n_features=ncols, random_state=10)
with using_device_type(device):
X, _ = make_blobs(n_samples=nrows, n_features=ncols, random_state=10)

cu_ipca = cuIPCA(n_components=n_components, whiten=whiten)
cu_ipca = cuIPCA(n_components=n_components, whiten=whiten)

sample_size = int(nrows / batch_size_divider)
for i in range(0, nrows, sample_size):
cu_ipca.partial_fit(X[i : i + sample_size].copy())
sample_size = int(nrows / batch_size_divider)
for i in range(0, nrows, sample_size):
cu_ipca.partial_fit(X[i : i + sample_size].copy())

cu_t = cu_ipca.transform(X)
cu_inv = cu_ipca.inverse_transform(cu_t)
cu_t = cu_ipca.transform(X)
cu_inv = cu_ipca.inverse_transform(cu_t)

sk_ipca = skIPCA(n_components=n_components, whiten=whiten)
sk_ipca = skIPCA(n_components=n_components, whiten=whiten)

X = cp.asnumpy(X)
if device == "gpu":
X = cp.asnumpy(X)

for i in range(0, nrows, sample_size):
sk_ipca.partial_fit(X[i : i + sample_size].copy())
for i in range(0, nrows, sample_size):
sk_ipca.partial_fit(X[i : i + sample_size].copy())

sk_t = sk_ipca.transform(X)
sk_inv = sk_ipca.inverse_transform(sk_t)
sk_t = sk_ipca.transform(X)
sk_inv = sk_ipca.inverse_transform(sk_t)

assert array_equal(cu_inv, sk_inv, 6e-5, with_sign=True)
assert array_equal(cu_inv, sk_inv, 6e-5, with_sign=True)


def test_exceptions():
X = cupyx.scipy.sparse.eye(10)
ipca = cuIPCA()
with pytest.raises(TypeError):
ipca.partial_fit(X)
@pytest.mark.parametrize("device", ["gpu", "cpu"])
def test_exceptions(device):
with using_device_type(device):
X = global_settings.xsparse.eye(10)
ipca = cuIPCA()
with pytest.raises(TypeError):
ipca.partial_fit(X)

X = X.toarray()
with pytest.raises(NotFittedError):
ipca.transform(X)
X = X.toarray()
with pytest.raises(ValueError):
ipca.transform(X)

with pytest.raises(NotFittedError):
ipca.inverse_transform(X)
with pytest.raises(ValueError):
ipca.inverse_transform(X)

with pytest.raises(ValueError):
cuIPCA(n_components=8).fit(X[:5])
with pytest.raises(ValueError):
cuIPCA(n_components=8).fit(X[:5])

with pytest.raises(ValueError):
cuIPCA(n_components=8).fit(X[:, :5])
with pytest.raises(ValueError):
cuIPCA(n_components=8).fit(X[:, :5])


def test_svd_flip():
x = cp.array(range(-10, 80)).reshape((9, 10))
u, s, v = cp.linalg.svd(x, full_matrices=False)
u_true, v_true = _svd_flip(u, v, u_based_decision=True)
reco_true = cp.dot(u_true * s, v_true)
u_false, v_false = _svd_flip(u, v, u_based_decision=False)
reco_false = cp.dot(u_false * s, v_false)
@pytest.mark.parametrize("device", ["gpu", "cpu"])
def test_svd_flip(device):
with using_device_type(device):
x = global_settings.xpy.array(range(-10, 80)).reshape((9, 10))
u, s, v = global_settings.xpy.linalg.svd(x, full_matrices=False)
u_true, v_true = _svd_flip(u, v, u_based_decision=True)
reco_true = global_settings.xpy.dot(u_true * s, v_true)
u_false, v_false = _svd_flip(u, v, u_based_decision=False)
reco_false = global_settings.xpy.dot(u_false * s, v_false)

assert array_equal(reco_true, x)
assert array_equal(reco_false, x)
assert array_equal(reco_true, x)
assert array_equal(reco_false, x)

0 comments on commit 0d9d09d

Please sign in to comment.