Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CPU/GPU interop with RandomForest #6175

Open
wants to merge 8 commits into
base: branch-25.02
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion conda/environments/all_cuda-118_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ dependencies:
- sphinx-markdown-tables
- statsmodels
- sysroot_linux-64==2.28
- treelite==4.3.0
- treelite==4.4.1
- umap-learn==0.5.6
- xgboost>=2.1.0
name: all_cuda-118_arch-x86_64
2 changes: 1 addition & 1 deletion conda/environments/all_cuda-125_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ dependencies:
- sphinx-markdown-tables
- statsmodels
- sysroot_linux-64==2.28
- treelite==4.3.0
- treelite==4.4.1
- umap-learn==0.5.6
- xgboost>=2.1.0
name: all_cuda-125_arch-x86_64
6 changes: 3 additions & 3 deletions cpp/cmake/thirdparty/get_treelite.cmake
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#=============================================================================
# Copyright (c) 2021-2024, NVIDIA CORPORATION.
# Copyright (c) 2021-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -78,7 +78,7 @@ function(find_and_configure_treelite)
rapids_export_find_package_root(BUILD Treelite [=[${CMAKE_CURRENT_LIST_DIR}]=] EXPORT_SET cuml-exports)
endfunction()

find_and_configure_treelite(VERSION 4.3.0
PINNED_TAG 575e4208f2b18e40d818c338ecb95d7a26e69aab
find_and_configure_treelite(VERSION 4.4.1
PINNED_TAG 386bd0de99f5a66584c7e58221ee38ce606ad1ae
EXCLUDE_FROM_ALL ${CUML_EXCLUDE_TREELITE_FROM_ALL}
BUILD_STATIC_LIBS ${CUML_USE_TREELITE_STATIC})
2 changes: 1 addition & 1 deletion dependencies.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ dependencies:
- output_types: [conda, requirements, pyproject]
packages:
- &cython cython>=3.0.0
- &treelite treelite==4.3.0
- &treelite treelite==4.4.1

py_run_cuml:
common:
Expand Down
28 changes: 25 additions & 3 deletions python/cuml/cuml/ensemble/randomforest_common.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
# Copyright (c) 2020-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -13,7 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import treelite.sklearn
from cuml.internals.safe_imports import gpu_only_import
from cuml.internals.api_decorators import device_interop_preparation

cp = gpu_only_import('cupy')
import math
import warnings
Expand All @@ -24,7 +27,7 @@ np = cpu_only_import('numpy')
from cuml import ForestInference
from cuml.fil.fil import TreeliteModel
from pylibraft.common.handle import Handle
from cuml.internals.base import Base
from cuml.internals.base import UniversalBase
from cuml.internals.array import CumlArray
from cuml.common.exceptions import NotFittedError
import cuml.internals
Expand All @@ -39,7 +42,7 @@ from cuml.common.array_descriptor import CumlArrayDescriptor
from cuml.prims.label.classlabels import make_monotonic, check_labels


class BaseRandomForestModel(Base):
class BaseRandomForestModel(UniversalBase):
_param_names = ['n_estimators', 'max_depth', 'handle',
'max_features', 'n_bins',
'split_criterion', 'min_samples_leaf',
Expand Down Expand Up @@ -67,6 +70,7 @@ class BaseRandomForestModel(Base):

classes_ = CumlArrayDescriptor()

@device_interop_preparation
def __init__(self, *, split_criterion, n_streams=4, n_estimators=100,
max_depth=16, handle=None, max_features='sqrt', n_bins=128,
bootstrap=True,
Expand Down Expand Up @@ -268,6 +272,24 @@ class BaseRandomForestModel(Base):
self.treelite_handle = <uintptr_t> tl_handle
return self.treelite_handle

def cpu_to_gpu(self):
    """
    Rebuild the GPU-side forest state from the CPU estimator.

    The estimator held in ``self._cpu_model`` is imported with
    ``treelite.sklearn.import_model``, round-tripped through Treelite
    bytes into a ``TreeliteModel`` and serialized again into
    ``self.treelite_serialized_model``. A Treelite handle is then
    obtained, ``dtype`` and ``update_labels`` are reset, and the base
    class transfer is invoked last.
    """
    sk_forest = treelite.sklearn.import_model(self._cpu_model)
    forest_bytes = sk_forest.serialize_bytes()

    # Stored on the instance rather than in a local.
    # NOTE(review): presumably this keeps the underlying handle alive
    # after this method returns -- confirm before turning it into a local.
    self._temp = TreeliteModel.from_treelite_bytes(forest_bytes)
    self.treelite_serialized_model = treelite_serialize(self._temp.handle)

    # The serialized model must be in place before the handle is requested.
    self._obtain_treelite_handle()

    # NOTE(review): dtype is pinned to float64 here; presumably because the
    # scikit-learn importer yields float64 models -- confirm.
    self.dtype = np.float64
    self.update_labels = False

    super().cpu_to_gpu()

def gpu_to_cpu(self):
    """
    Rebuild the CPU estimator from the GPU-side forest.

    Obtains the Treelite handle for this forest, serializes it to
    Treelite bytes, loads those bytes with the ``treelite`` Python
    package and exports the result through ``treelite.sklearn`` into
    ``self._cpu_model``.
    """
    self._obtain_treelite_handle()

    # Wrap the existing handle without taking ownership of it; this
    # estimator keeps using ``self.treelite_handle`` afterwards.
    borrowed_model = TreeliteModel.from_treelite_model_handle(
        self.treelite_handle,
        take_handle_ownership=False)
    serialized_forest = borrowed_model.to_treelite_bytes()

    self._cpu_model = treelite.sklearn.export_model(
        treelite.Model.deserialize_bytes(serialized_forest))

@cuml.internals.api_base_return_generic(set_output_type=True,
set_n_features_in=True,
get_output_type=False)
Expand Down
10 changes: 10 additions & 0 deletions python/cuml/cuml/ensemble/randomforestclassifier.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#

# distutils: language = c++
from cuml.internals.api_decorators import device_interop_preparation
from cuml.internals.api_decorators import enable_device_interop
from cuml.internals.safe_imports import (
cpu_only_import,
gpu_only_import,
Expand Down Expand Up @@ -247,6 +249,9 @@ class RandomForestClassifier(BaseRandomForestModel,
<https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html>`_.
"""

_cpu_estimator_import_path = 'sklearn.ensemble.RandomForestClassifier'

@device_interop_preparation
def __init__(self, *, split_criterion=0, handle=None, verbose=False,
output_type=None,
**kwargs):
Expand Down Expand Up @@ -337,6 +342,9 @@ class RandomForestClassifier(BaseRandomForestModel,
self.treelite_serialized_model = None
self.n_cols = None

def get_attr_names(self):
    """
    Return the list of attribute names for this estimator.

    Always an empty list for this class.
    NOTE(review): presumably consulted by the device-interop base class
    when mirroring attributes between CPU and GPU models -- confirm.
    """
    return list()

def convert_to_treelite_model(self):
"""
Converts the cuML RF model to a Treelite model
Expand Down Expand Up @@ -417,6 +425,7 @@ class RandomForestClassifier(BaseRandomForestModel,
@cuml.internals.api_base_return_any(set_output_type=False,
set_output_dtype=True,
set_n_features_in=False)
@enable_device_interop
def fit(self, X, y, convert_dtype=True):
"""
Perform Random Forest Classification on the input data
Expand Down Expand Up @@ -555,6 +564,7 @@ class RandomForestClassifier(BaseRandomForestModel,
@insert_into_docstring(parameters=[('dense', '(n_samples, n_features)')],
return_values=[('dense', '(n_samples, 1)')])
@cuml.internals.api_base_return_array(get_output_dtype=True)
@enable_device_interop
def predict(self, X, predict_model="GPU", threshold=0.5,
algo='auto', convert_dtype=True,
fil_sparse_format='auto') -> CumlArray:
Expand Down
11 changes: 10 additions & 1 deletion python/cuml/cuml/ensemble/randomforestregressor.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

# distutils: language = c++

from cuml.internals.api_decorators import device_interop_preparation
from cuml.internals.api_decorators import enable_device_interop
from cuml.internals.safe_imports import (
cpu_only_import,
gpu_only_import,
Expand Down Expand Up @@ -250,6 +251,9 @@ class RandomForestRegressor(BaseRandomForestModel,
<https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestRegressor.html>`_.
"""

_cpu_estimator_import_path = 'sklearn.ensemble.RandomForestRegressor'

@device_interop_preparation
def __init__(self, *,
split_criterion=2,
accuracy_metric='r2',
Expand Down Expand Up @@ -341,6 +345,9 @@ class RandomForestRegressor(BaseRandomForestModel,
self.treelite_serialized_model = None
self.n_cols = None

def get_attr_names(self):
    """
    Return the list of attribute names for this estimator.

    Always an empty list for this class.
    NOTE(review): presumably consulted by the device-interop base class
    when mirroring attributes between CPU and GPU models -- confirm.
    """
    return list()

def convert_to_treelite_model(self):
"""
Converts the cuML RF model to a Treelite model
Expand Down Expand Up @@ -412,6 +419,7 @@ class RandomForestRegressor(BaseRandomForestModel,
domain="cuml_python")
@generate_docstring()
@cuml.internals.api_base_return_any_skipall
@enable_device_interop
def fit(self, X, y, convert_dtype=True):
"""
Perform Random Forest Regression on the input data
Expand Down Expand Up @@ -534,6 +542,7 @@ class RandomForestRegressor(BaseRandomForestModel,
domain="cuml_python")
@insert_into_docstring(parameters=[('dense', '(n_samples, n_features)')],
return_values=[('dense', '(n_samples, 1)')])
@enable_device_interop
def predict(self, X, predict_model="GPU",
algo='auto', convert_dtype=True,
fil_sparse_format='auto') -> CumlArray:
Expand Down
14 changes: 13 additions & 1 deletion python/cuml/cuml/fil/fil.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2019-2024, NVIDIA CORPORATION.
# Copyright (c) 2019-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -72,6 +72,7 @@ cdef extern from "treelite/c_api.h":
const char* filename) except +
cdef int TreeliteDeserializeModelFromBytes(const char* bytes_seq, size_t len,
TreeliteModelHandle* out) except +
# Serializes the model behind `handle`; the byte buffer and its length are
# returned through `out_bytes` / `out_bytes_len`.
# Declared `except +` like the neighbouring Treelite declarations so that a
# C++ exception escaping the call is translated into a Python exception.
cdef int TreeliteSerializeModelToBytes(TreeliteModelHandle handle,
                                       const char** out_bytes,
                                       size_t* out_bytes_len) except +
cdef int TreeliteGetHeaderField(
TreeliteModelHandle model, const char * name, TreelitePyBufferFrame* out_frame) except +
cdef const char* TreeliteGetLastError()
Expand Down Expand Up @@ -192,6 +193,17 @@ cdef class TreeliteModel():
model.set_handle(handle)
return model

def to_treelite_bytes(self) -> bytes:
    """
    Serialize the wrapped Treelite model to a byte string.

    Returns
    -------
    bytes
        The byte sequence produced by ``TreeliteSerializeModelToBytes``
        for the model handle held by this object.

    Raises
    ------
    RuntimeError
        If this object holds no model handle, or if Treelite reports a
        serialization failure (the message includes Treelite's last
        error string).
    """
    cdef const char* out_bytes = NULL
    cdef size_t out_bytes_len = 0
    cdef int res
    cdef str err_msg

    # Explicit check instead of `assert`: assertions can be compiled out,
    # which would let a NULL handle reach the Treelite C API.
    if self.handle == NULL:
        raise RuntimeError(
            "Failed to serialize Treelite model (model handle is NULL)")

    res = TreeliteSerializeModelToBytes(
        self.handle, &out_bytes, &out_bytes_len)
    if res < 0:
        err_msg = TreeliteGetLastError().decode("UTF-8")
        raise RuntimeError(f"Failed to serialize Treelite model ({err_msg})")

    # Slice with the explicit length so embedded NUL bytes are preserved
    # when the C buffer is converted to a Python bytes object.
    return out_bytes[:out_bytes_len]

@classmethod
def from_filename(cls, filename, model_type="xgboost_ubj"):
"""
Expand Down
23 changes: 22 additions & 1 deletion python/cuml/cuml/tests/test_device_selection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION.
# Copyright (c) 2022-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -37,6 +37,7 @@
from cuml.decomposition import PCA, TruncatedSVD
from cuml.cluster import KMeans
from cuml.cluster import DBSCAN
from cuml.ensemble import RandomForestClassifier, RandomForestRegressor
from cuml.common.device_selection import DeviceType, using_device_type
from cuml.testing.utils import assert_dbscan_equal
from hdbscan import HDBSCAN as refHDBSCAN
Expand Down Expand Up @@ -1011,3 +1012,23 @@ def test_dbscan_methods(train_device, infer_device):
assert_dbscan_equal(
ref_output, output, X_train_blob, model.core_sample_indices_, eps
)


@pytest.mark.parametrize("train_device", ["cpu", "gpu"])
@pytest.mark.parametrize("infer_device", ["cpu", "gpu"])
def test_random_forest_regressor(train_device, infer_device):
    """Fit a RandomForestRegressor on one device type and predict on another.

    Covers every train/infer combination of "cpu" and "gpu".
    """
    model = RandomForestRegressor()
    with using_device_type(train_device):
        model.fit(X_train_reg, y_train_reg)
    with using_device_type(infer_device):
        predictions = model.predict(X_test_reg)

    # Minimal sanity check so the test fails on an empty or truncated
    # result, not only when an exception is raised.
    assert len(predictions) == len(X_test_reg)


@pytest.mark.parametrize("train_device", ["cpu", "gpu"])
@pytest.mark.parametrize("infer_device", ["cpu", "gpu"])
def test_random_forest_classifier(train_device, infer_device):
    """Fit a RandomForestClassifier on one device type and predict on another.

    Covers every train/infer combination of "cpu" and "gpu".
    """
    model = RandomForestClassifier()
    with using_device_type(train_device):
        model.fit(X_train_blob, y_train_blob)
    with using_device_type(infer_device):
        predictions = model.predict(X_test_blob)

    # Minimal sanity check so the test fails on an empty or truncated
    # result, not only when an exception is raised.
    assert len(predictions) == len(X_test_blob)
4 changes: 2 additions & 2 deletions python/cuml/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ dependencies = [
"rapids-dask-dependency==25.2.*,>=0.0.0a0",
"rmm==25.2.*,>=0.0.0a0",
"scipy>=1.8.0",
"treelite==4.3.0",
"treelite==4.4.1",
] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`.
classifiers = [
"Intended Audience :: Developers",
Expand Down Expand Up @@ -184,7 +184,7 @@ requires = [
"ninja",
"pylibraft==25.2.*,>=0.0.0a0",
"rmm==25.2.*,>=0.0.0a0",
"treelite==4.3.0",
"treelite==4.4.1",
] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`.

[tool.scikit-build]
Expand Down
Loading