Skip to content

Commit

Permalink
Support CPU execution for make_blobs
Browse files Browse the repository at this point in the history
  • Loading branch information
wphicks committed Jan 23, 2025
1 parent f29293f commit 87727c1
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 52 deletions.
31 changes: 15 additions & 16 deletions python/cuml/cuml/datasets/blobs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
# Copyright (c) 2020-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -18,6 +18,7 @@
import cuml.internals

from collections.abc import Iterable
from cuml.internals.global_settings import GlobalSettings
from cuml.internals.safe_imports import gpu_only_import
from cuml.datasets.utils import _create_rs_generator
from cuml.internals.safe_imports import (
Expand All @@ -44,8 +45,7 @@ def _get_centers(rs, centers, center_box, n_samples, n_features, dtype):
center_box[0],
center_box[1],
size=(n_centers, n_features),
dtype=dtype,
)
).astype(dtype)

else:
if n_features != centers.shape[1]:
Expand All @@ -63,8 +63,7 @@ def _get_centers(rs, centers, center_box, n_samples, n_features, dtype):
center_box[0],
center_box[1],
size=(n_centers, n_features),
dtype=dtype,
)
).astype(dtype)
try:
assert len(centers) == n_centers
except TypeError:
Expand Down Expand Up @@ -173,7 +172,8 @@ def make_blobs(
# Set the default output type to "array" rather than "cupy", so that the
# default also works for CPU execution. This will be ignored if the user
# has set `cuml.global_settings.output_type`. Only necessary for array
# generation methods that do not take an array as input.
cuml.internals.set_api_output_type("cupy")
cuml.internals.set_api_output_type("array")
xpy = GlobalSettings().xpy

generator = _create_rs_generator(random_state=random_state)

Expand All @@ -191,7 +191,7 @@ def make_blobs(
)

if isinstance(cluster_std, numbers.Real):
cluster_std = cp.full(len(centers), cluster_std)
cluster_std = xpy.full(len(centers), cluster_std)

if isinstance(n_samples, Iterable):
n_samples_per_center = n_samples
Expand All @@ -201,9 +201,9 @@ def make_blobs(
for i in range(n_samples % n_centers):
n_samples_per_center[i] += 1

X = cp.zeros(n_samples * n_features, dtype=dtype)
X = xpy.zeros(n_samples * n_features, dtype=dtype)
X = X.reshape((n_samples, n_features), order=order)
y = cp.zeros(n_samples, dtype=dtype)
y = xpy.zeros(n_samples, dtype=dtype)

if shuffle:
proba_samples_per_center = np.array(n_samples_per_center) / np.sum(
Expand All @@ -213,21 +213,20 @@ def make_blobs(
n_centers, n_samples, replace=True, p=proba_samples_per_center
)
for i, (n, std) in enumerate(zip(n_samples_per_center, cluster_std)):
center_indices = cp.where(shuffled_sample_indices == i)
center_indices = xpy.where(shuffled_sample_indices == i)

y[center_indices[0]] = i

X_k = generator.normal(
scale=std,
size=(len(center_indices[0]), n_features),
dtype=dtype,
)
).astype(dtype)

# NOTE: Adding the loc explicitly as cupy has a bug
# when calling generator.normal with an array for loc.
# cupy.random.normal, however, works with the same
# arguments
cp.add(X_k, centers[i], out=X_k)
xpy.add(X_k, centers[i], out=X_k)
X[center_indices[0], :] = X_k
else:
stop = 0
Expand All @@ -236,11 +235,11 @@ def make_blobs(

y[start:stop] = i

X_k = generator.normal(
scale=std, size=(n, n_features), dtype=dtype
X_k = generator.normal(scale=std, size=(n, n_features)).astype(
dtype
)

cp.add(X_k, centers[i], out=X_k)
xpy.add(X_k, centers[i], out=X_k)
X[start:stop, :] = X_k

if return_centers:
Expand Down
16 changes: 9 additions & 7 deletions python/cuml/cuml/datasets/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2023, NVIDIA CORPORATION.
# Copyright (c) 2020-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -13,23 +13,25 @@
# limitations under the License.
#

from cuml.internals.global_settings import GlobalSettings
from cuml.internals.safe_imports import gpu_only_import

cp = gpu_only_import("cupy")


def _create_rs_generator(random_state):
"""
This is a utility function that returns an instance of CuPy RandomState
This is a utility function that returns an instance of a CuPy or NumPy
RandomState, depending on the currently selected global device type.
Parameters
----------
random_state : None, int, or CuPy RandomState
The random_state from which the CuPy random state is generated
random_state : None, int, or RandomState
The random_state from which the random state is generated
"""

if isinstance(random_state, (type(None), int)):
return cp.random.RandomState(seed=random_state)
elif isinstance(random_state, cp.random.RandomState):
return GlobalSettings().xpy.random.RandomState(seed=random_state)
elif isinstance(random_state, GlobalSettings().xpy.random.RandomState):
return random_state
else:
raise ValueError("random_state type must be int or CuPy RandomState")
raise ValueError("random_state type must be int or RandomState")
64 changes: 35 additions & 29 deletions python/cuml/cuml/tests/test_make_blobs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019-2023, NVIDIA CORPORATION.
# Copyright (c) 2019-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -15,10 +15,10 @@

import cuml
import pytest
from cuml import global_settings
from cuml.common.device_selection import using_device_type
from cuml.internals.safe_imports import gpu_only_import

cp = gpu_only_import("cupy")

# Testing parameters for scalar parameter tests

dtype = ["single", "double"]
Expand Down Expand Up @@ -55,6 +55,7 @@
@pytest.mark.parametrize("shuffle", shuffle)
@pytest.mark.parametrize("random_state", random_state)
@pytest.mark.parametrize("order", ["F", "C"])
@pytest.mark.parametrize("device", ["gpu", "cpu"])
def test_make_blobs_scalar_parameters(
dtype,
n_samples,
Expand All @@ -65,31 +66,36 @@ def test_make_blobs_scalar_parameters(
shuffle,
random_state,
order,
device,
):

out, labels = cuml.make_blobs(
dtype=dtype,
n_samples=n_samples,
n_features=n_features,
centers=centers,
cluster_std=0.001,
center_box=center_box,
shuffle=shuffle,
random_state=random_state,
order=order,
)

assert out.shape == (n_samples, n_features), "out shape mismatch"
assert labels.shape == (n_samples,), "labels shape mismatch"

if order == "F":
assert out.flags["F_CONTIGUOUS"]
elif order == "C":
assert out.flags["C_CONTIGUOUS"]

if centers is None:
assert cp.unique(labels).shape == (3,), "unexpected number of clusters"
elif centers <= n_samples:
assert cp.unique(labels).shape == (
centers,
), "unexpected number of clusters"
with using_device_type(device):

out, labels = cuml.make_blobs(
dtype=dtype,
n_samples=n_samples,
n_features=n_features,
centers=centers,
cluster_std=0.001,
center_box=center_box,
shuffle=shuffle,
random_state=random_state,
order=order,
)

assert out.shape == (n_samples, n_features), "out shape mismatch"
assert labels.shape == (n_samples,), "labels shape mismatch"

if order == "F":
assert out.flags["F_CONTIGUOUS"]
elif order == "C":
assert out.flags["C_CONTIGUOUS"]

if centers is None:
assert global_settings.xpy.unique(labels).shape == (
3,
), "unexpected number of clusters"
elif centers <= n_samples:
assert global_settings.xpy.unique(labels).shape == (
centers,
), "unexpected number of clusters"

0 comments on commit 87727c1

Please sign in to comment.