Skip to content

Commit

Permalink
Use estimator tags to improve sparse error handling (#6151)
Browse files Browse the repository at this point in the history
Authors:
  - Dante Gama Dessavre (https://github.com/dantegd)

Approvers:
  - William Hicks (https://github.com/wphicks)

URL: #6151
  • Loading branch information
dantegd authored Dec 13, 2024
1 parent 811e18b commit 7211507
Show file tree
Hide file tree
Showing 14 changed files with 219 additions and 30 deletions.
30 changes: 26 additions & 4 deletions python/cuml/cuml/internals/base.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ except ImportError:

import cuml
import cuml.common
from cuml.common.sparse_utils import is_sparse
import cuml.internals.logger as logger
import cuml.internals
import cuml.internals.input_utils
Expand All @@ -47,6 +48,7 @@ from cuml.internals.input_utils import (
determine_array_type,
input_to_cuml_array,
input_to_host_array,
input_to_host_array_with_sparse_support,
is_array_like
)
from cuml.internals.memory_utils import determine_array_memtype
Expand Down Expand Up @@ -679,19 +681,23 @@ class UniversalBase(Base):

def args_to_cpu(self, *args, **kwargs):
# put all the args on host
new_args = tuple(input_to_host_array(arg)[0] for arg in args)
new_args = tuple(
input_to_host_array_with_sparse_support(arg) for arg in args
)

# put all the kwargs on host
new_kwargs = dict()
for kw, arg in kwargs.items():
# if array-like, ensure array-like is on the host
if is_array_like(arg):
new_kwargs[kw] = input_to_host_array(arg)[0]
new_kwargs[kw] = input_to_host_array_with_sparse_support(arg)
# if Real or string, pass as is
elif isinstance(arg, (numbers.Real, str)):
new_kwargs[kw] = arg
else:
raise ValueError(f"Unable to process argument {kw}")

new_kwargs.pop("convert_dtype", None)
return new_args, new_kwargs

def dispatch_func(self, func_name, gpu_func, *args, **kwargs):
Expand Down Expand Up @@ -739,9 +745,9 @@ class UniversalBase(Base):
# ensure args and kwargs are on the CPU
args, kwargs = self.args_to_cpu(*args, **kwargs)

# get the function from the GPU estimator
# get the function from the CPU estimator
cpu_func = getattr(self._cpu_model, func_name)
# call the function from the GPU estimator
# call the function from the CPU estimator
logger.info(f"cuML: Performing {func_name} in CPU")
res = cpu_func(*args, **kwargs)

Expand All @@ -764,6 +770,22 @@ class UniversalBase(Base):
def _dispatch_selector(self, func_name, *args, **kwargs):
"""
"""
# check for sparse inputs and whether estimator supports them
sparse_support = "sparse" in self._get_tags()["X_types_gpu"]

if args and is_sparse(args[0]):
if sparse_support:
return DeviceType.device
elif GlobalSettings().accelerator_active and not sparse_support:
logger.info(
f"cuML: Estimator {self} does not support sparse inputs in GPU."
)
return DeviceType.host
else:
raise NotImplementedError(
"Estimator does not support sparse inputs currently"
)

# if not using accelerator, then return global device
if not hasattr(self, "_gpuaccel"):
return cuml.global_settings.device_type
Expand Down
5 changes: 5 additions & 0 deletions python/cuml/cuml/internals/input_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,11 @@ def input_to_host_array(


def input_to_host_array_with_sparse_support(X):
try:
if scipy_sparse.isspmatrix(X):
return X
except UnavailableError:
pass
_array_type, is_sparse = determine_array_type_full(X)
if is_sparse:
if _array_type == "cupy":
Expand Down
12 changes: 1 addition & 11 deletions python/cuml/cuml/linear_model/linear_regression.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ class LinearRegression(LinearPredictMixin,

@device_interop_preparation
def __init__(self, *, algorithm='eig', fit_intercept=True,
copy_X=None, normalize=False,
copy_X=True, normalize=False,
handle=None, verbose=False, output_type=None):
IF GPUBUILD == 1:
if handle is None and algorithm == 'eig':
Expand All @@ -301,16 +301,6 @@ class LinearRegression(LinearPredictMixin,
raise TypeError(msg.format(algorithm))

self.intercept_value = 0.0
if copy_X is None:
warnings.warn(
"Starting from version 23.08, the new 'copy_X' parameter defaults "
"to 'True', ensuring a copy of X is created after passing it to "
"fit(), preventing any changes to the input, but with increased "
"memory usage. This represents a change in behavior from previous "
"versions. With `copy_X=False` a copy might still be created if "
"necessary. Explicitly set 'copy_X' to either True or False to "
"suppress this warning.", UserWarning)
copy_X = True
self.copy_X = copy_X

def _get_algorithm_int(self, algorithm):
Expand Down
5 changes: 3 additions & 2 deletions python/cuml/cuml/linear_model/logistic_regression.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import pprint
import cuml.internals
from cuml.solvers import QN
from cuml.internals.base import UniversalBase
from cuml.internals.mixins import ClassifierMixin, FMajorInputTagMixin
from cuml.internals.mixins import ClassifierMixin, FMajorInputTagMixin, SparseInputTagMixin
from cuml.common.array_descriptor import CumlArrayDescriptor
from cuml.internals.array import CumlArray
from cuml.common.doc_utils import generate_docstring
Expand All @@ -45,7 +45,8 @@ supported_solvers = ["qn"]

class LogisticRegression(UniversalBase,
ClassifierMixin,
FMajorInputTagMixin):
FMajorInputTagMixin,
SparseInputTagMixin):
"""
LogisticRegression is a linear model that is used to model probability of
occurrence of certain events, for example probability of success or fail of
Expand Down
5 changes: 3 additions & 2 deletions python/cuml/cuml/manifold/t_sne.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ from cuml.internals.array_sparse import SparseCumlArray
from cuml.common.sparse_utils import is_sparse
from cuml.common.doc_utils import generate_docstring
from cuml.common import input_to_cuml_array
from cuml.internals.mixins import CMajorInputTagMixin
from cuml.internals.mixins import CMajorInputTagMixin, SparseInputTagMixin
from cuml.common.sparsefuncs import extract_knn_infos
from cuml.metrics.distance_type cimport DistanceType
rmm = gpu_only_import('rmm')
Expand Down Expand Up @@ -119,7 +119,8 @@ cdef extern from "cuml/manifold/tsne.h" namespace "ML":


class TSNE(UniversalBase,
CMajorInputTagMixin):
CMajorInputTagMixin,
SparseInputTagMixin):
"""
t-SNE (T-Distributed Stochastic Neighbor Embedding) is an extremely
powerful dimensionality reduction technique that aims to maintain
Expand Down
5 changes: 3 additions & 2 deletions python/cuml/cuml/manifold/umap.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ from cuml.internals.input_utils import input_to_cuml_array
from cuml.internals.array import CumlArray
from cuml.internals.array_sparse import SparseCumlArray
from cuml.internals.mem_type import MemoryType
from cuml.internals.mixins import CMajorInputTagMixin
from cuml.internals.mixins import CMajorInputTagMixin, SparseInputTagMixin
from cuml.common.sparse_utils import is_sparse

from cuml.common.array_descriptor import CumlArrayDescriptor
Expand Down Expand Up @@ -136,7 +136,8 @@ IF GPUBUILD == 1:


class UMAP(UniversalBase,
CMajorInputTagMixin):
CMajorInputTagMixin,
SparseInputTagMixin):
"""
Uniform Manifold Approximation and Projection
Expand Down
2 changes: 2 additions & 0 deletions python/cuml/cuml/neighbors/kneighbors_classifier.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ from cuml.common.array_descriptor import CumlArrayDescriptor
from cuml.internals.mixins import ClassifierMixin
from cuml.common.doc_utils import generate_docstring
from cuml.internals.mixins import FMajorInputTagMixin
from cuml.internals.api_decorators import enable_device_interop

from cuml.internals.safe_imports import cpu_only_import
np = cpu_only_import('numpy')
Expand Down Expand Up @@ -246,6 +247,7 @@ class KNeighborsClassifier(ClassifierMixin,
'description': 'Labels probabilities',
'shape': '(n_samples, 1)'})
@cuml.internals.api_base_return_generic()
@enable_device_interop
def predict_proba(
self,
X,
Expand Down
2 changes: 2 additions & 0 deletions python/cuml/cuml/neighbors/kneighbors_regressor.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ from cuml.common.array_descriptor import CumlArrayDescriptor
from cuml.internals.mixins import RegressorMixin
from cuml.common.doc_utils import generate_docstring
from cuml.internals.mixins import FMajorInputTagMixin
from cuml.internals.api_decorators import enable_device_interop

from cuml.internals.safe_imports import cpu_only_import
np = cpu_only_import('numpy')
Expand Down Expand Up @@ -195,6 +196,7 @@ class KNeighborsRegressor(RegressorMixin,
'type': 'dense',
'description': 'Predicted values',
'shape': '(n_samples, n_features)'})
@enable_device_interop
def predict(self, X, convert_dtype=True) -> CumlArray:
"""
Use the trained k-nearest neighbors regression model to
Expand Down
5 changes: 3 additions & 2 deletions python/cuml/cuml/neighbors/nearest_neighbors.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ from cuml.internals.array import CumlArray
from cuml.internals.array_sparse import SparseCumlArray
from cuml.common.doc_utils import generate_docstring
from cuml.common.doc_utils import insert_into_docstring
from cuml.internals.mixins import CMajorInputTagMixin
from cuml.internals.mixins import CMajorInputTagMixin, SparseInputTagMixin
from cuml.internals.input_utils import input_to_cupy_array
from cuml.common import input_to_cuml_array
from cuml.common.sparse_utils import is_sparse
Expand Down Expand Up @@ -144,7 +144,8 @@ IF GPUBUILD == 1:


class NearestNeighbors(UniversalBase,
CMajorInputTagMixin):
CMajorInputTagMixin,
SparseInputTagMixin):
"""
NearestNeighbors is an queries neighborhoods from a given set of
datapoints. Currently, cuML supports k-NN queries, which define
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,19 @@ def custom_weights(distances):
assert acc > 0.7, "Accuracy should be reasonable with custom weights"


@pytest.mark.xfail(
reason="cuML and sklearn don't have matching exceptions yet"
)
def test_knn_classifier_invalid_algorithm(classification_data):
X, y = classification_data
with pytest.raises((ValueError, KeyError)):
model = KNeighborsClassifier(algorithm="invalid_algorithm")
model.fit(X, y)


@pytest.mark.xfail(
reason="cuML and sklearn don't have matching exceptions yet"
)
def test_knn_classifier_invalid_metric(classification_data):
X, y = classification_data
with pytest.raises(ValueError):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,19 @@ def custom_weights(distances):
assert r2 > 0.7, "R^2 score should be reasonable with custom weights"


@pytest.mark.xfail(
reason="cuML and sklearn don't have matching exceptions yet"
)
def test_knn_regressor_invalid_algorithm(regression_data):
X, y = regression_data
with pytest.raises((ValueError, KeyError)):
model = KNeighborsRegressor(algorithm="invalid_algorithm")
model.fit(X, y)


@pytest.mark.xfail(
reason="cuML and sklearn don't have matching exceptions yet"
)
def test_knn_regressor_invalid_metric(regression_data):
X, y = regression_data
with pytest.raises(ValueError):
Expand Down
77 changes: 77 additions & 0 deletions python/cuml/cuml/tests/experimental/accel/test_sparse_dispatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#
# Copyright (c) 2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import pytest
import numpy as np

from cuml.internals.global_settings import GlobalSettings
from scipy.sparse import csr_matrix
from sklearn.cluster import KMeans, DBSCAN
from sklearn.decomposition import TruncatedSVD
from sklearn.linear_model import (
LinearRegression,
LogisticRegression,
ElasticNet,
Ridge,
Lasso,
)
from sklearn.neighbors import (
NearestNeighbors,
KNeighborsClassifier,
KNeighborsRegressor,
)
from sklearn.base import is_classifier, is_regressor
from hdbscan import HDBSCAN
from umap import UMAP

estimators = {
"KMeans": lambda: KMeans(n_clusters=2, random_state=0),
"DBSCAN": lambda: DBSCAN(eps=1.0),
"TruncatedSVD": lambda: TruncatedSVD(n_components=1, random_state=0),
"LinearRegression": lambda: LinearRegression(),
"LogisticRegression": lambda: LogisticRegression(),
"ElasticNet": lambda: ElasticNet(),
"Ridge": lambda: Ridge(),
"Lasso": lambda: Lasso(),
"NearestNeighbors": lambda: NearestNeighbors(n_neighbors=1),
"UMAP": lambda: UMAP(n_components=1),
"HDBSCAN": lambda: HDBSCAN(),
}


@pytest.mark.parametrize("estimator_name", list(estimators.keys()))
def test_sparse_support(estimator_name):
if not GlobalSettings().accelerator_active and estimator_name == "UMAP":
pytest.skip(reason="UMAP CPU library fails on this small dataset")
X_sparse = csr_matrix([[0, 1], [1, 0]])
y_class = np.array([0, 1])
y_reg = np.array([0.0, 1.0])
estimator = estimators[estimator_name]()
# Fit or fit_transform depending on the estimator type
if isinstance(estimator, (KMeans, DBSCAN, TruncatedSVD, NearestNeighbors)):
if hasattr(estimator, "fit_transform"):
estimator.fit_transform(X_sparse)
else:
estimator.fit(X_sparse)
else:
# For classifiers and regressors, decide which y to provide
if is_classifier(estimator):
estimator.fit(X_sparse, y_class)
elif is_regressor(estimator):
estimator.fit(X_sparse, y_reg)
else:
# Just in case there's an unexpected type
estimator.fit(X_sparse)
Loading

0 comments on commit 7211507

Please sign in to comment.