Parallel initialization for k-means #1754

Open · wants to merge 3 commits into base: main
151 changes: 107 additions & 44 deletions heat/cluster/_kcluster.py
@@ -3,6 +3,8 @@
"""

import heat as ht
import torch
from heat.cluster.batchparallelclustering import _kmex
from typing import Optional, Union, Callable
from heat.core.dndarray import DNDarray

@@ -94,14 +96,22 @@ def functional_value_(self) -> DNDarray:
"""
return self._functional_value

def _initialize_cluster_centers(self, x: DNDarray):
def _initialize_cluster_centers(
self, x: DNDarray, oversampling: float = 100, iter_multiplier: float = 20
):
"""
Initializes the K-Means centroids.

Parameters
----------
x : DNDarray
The data to initialize the clusters for. Shape = (n_samples, n_features)

oversampling : float
oversampling factor used in the k-means|| initialization of centroids

iter_multiplier : float
factor that increases the number of iterations used in the initialization of centroids
"""
# always initialize the random state
if self.random_state is not None:
@@ -123,53 +133,106 @@ def _initialize_cluster_centers(self, x: DNDarray):
raise ValueError("passed centroids do not match cluster count or data shape")
self._cluster_centers = self.init.resplit(None)

# Smart centroid guessing, random sampling with probability weight proportional to distance to existing centroids
# Parallelized centroid guessing using the k-means|| algorithm
elif self.init == "probability_based":
# First, check along which axis the data is sliced
if x.split is None or x.split == 0:
centroids = ht.zeros(
(self.n_clusters, x.shape[1]), split=None, device=x.device, comm=x.comm
)
sample = ht.random.randint(0, x.shape[0] - 1).item()
_, displ, _ = x.comm.counts_displs_shape(shape=x.shape, axis=0)
proc = 0
for p in range(x.comm.size):
if displ[p] > sample:
break
proc = p
x0 = ht.zeros(x.shape[1], dtype=x.dtype, device=x.device, comm=x.comm)
if x.comm.rank == proc:
idx = sample - displ[proc]
x0 = ht.array(x.lloc[idx, :], device=x.device, comm=x.comm)
x0.comm.Bcast(x0, root=proc)
centroids[0, :] = x0
for i in range(1, self.n_clusters):
distances = ht.spatial.distance.cdist(x, centroids, quadratic_expansion=True)
D2 = distances.min(axis=1)
D2.resplit_(axis=None)
prob = D2 / D2.sum()
random_position = ht.random.rand()
sample = 0
sum = 0
for j in range(len(prob)):
if sum > random_position:
break
sum += prob[j].item()
sample = j
proc = 0
for p in range(x.comm.size):
if displ[p] > sample:
break
proc = p
xi = ht.zeros(x.shape[1], dtype=x.dtype)
if x.comm.rank == proc:
idx = sample - displ[proc]
xi = ht.array(x.lloc[idx, :], device=x.device, comm=x.comm)
xi.comm.Bcast(xi, root=proc)
centroids[i, :] = xi

# Define a list of random, uniformly distributed probabilities,
# which is later used to sample the centroids
sample = ht.random.rand(x.shape[0], split=x.split)
# Define a random integer serving as a label to pick the first centroid randomly
init_idx = ht.random.randint(0, x.shape[0] - 1).item()
# Randomly select first centroid and organize it as a tensor, in order to use the function cdist later.
# This tensor will be filled continuously as the function proceeds
# We assume that the centroids fit into the memory of a single GPU
centroids = ht.expand_dims(x[init_idx, :].resplit_(None), axis=0)
# Calculate the initial cost of the clustering after the first centroid selection
# and use it as an indicator for the order of magnitude for the number of necessary iterations
init_distance = ht.spatial.distance.cdist(x, centroids, quadratic_expansion=True)
# --> init_distance calculates the Euclidean distance between data points x and initial centroids
# output format: tensor
init_min_distance = init_distance.min(axis=1)
# --> Pick the minimal distance of the data points to each centroid
# output format: vector
init_cost = init_min_distance.sum()
# --> Now calculate the cost
# output format: scalar
#
# Iteratively fill the tensor storing the centroids
for _ in ht.arange(0, iter_multiplier * ht.log(init_cost)):
Collaborator: For loop counters, the standard Python range should be more effective.

Collaborator (Author): Done.
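A minimal sketch of the suggested change (an editorial illustration, assuming init_cost is a scalar DNDarray whose value is read with .item() before building the range):

    import math

    # Drive the loop with a plain Python range instead of iterating a DNDarray.
    n_rounds = int(iter_multiplier * math.log(init_cost.item()))
    for _ in range(n_rounds):
        ...  # one oversampling round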

# Calculate the distance between data points and the current set of centroids
distance = ht.spatial.distance.cdist(x, centroids, quadratic_expansion=True)
min_distance = distance.min(axis=1)
# Sample each point in the data to a new set of centroids
prob = oversampling * min_distance / min_distance.sum()
# --> probability distribution with oversampling factor
# output format: vector
idx = ht.where(sample <= prob)
Collaborator: One may think about moving the creation of sample here, to have some new sampling at every step of the loop.
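A sketch of that suggestion (my reading of the reviewer's intent, not code from the PR): redraw the uniform thresholds inside the loop so each round samples independently.

    # Fresh uniform draws per round instead of a single draw before the loop.
    sample = ht.random.rand(x.shape[0], split=x.split)
    idx = ht.where(sample <= prob)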

# --> choose indices to sample the data according to prob
# output format: vector
local_data = x[idx].resplit_(centroids.split)
# --> pick the data points that are identified as possible centroids and make sure
# that data points and centroids are split in the same way
# output format: vector
centroids = ht.row_stack((centroids, local_data))
# --> stack the data points with these indices to the DNDarray of centroids
# output format: tensor
# Evaluate distance between final centroids and data points
if centroids.shape[0] <= self.n_clusters:
raise ValueError(
"The oversampling factor and/or the number of iterations are chosen"
Collaborator: Why two strings? It may also be helpful to use something like f"The oversampling factor (={oversampling}) ..." to give the user the values in the error message.

Collaborator (Author): For the sake of readability, I would like to avoid single lines of code that are too long and thus decided to split the string across two lines. However, in Python you need a separate string literal for each line of code. The values of the parameters have been added to the error message.
"too small for the initialization of cluster centers."
)
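For illustration (a generic Python sketch, not the PR's exact wording): adjacent string literals concatenate implicitly, and each line may be an f-string carrying the parameter values.

    raise ValueError(
        f"The oversampling factor (={oversampling}) and/or the number of "
        f"iterations (={iter_multiplier}) are chosen too small for the "
        "initialization of cluster centers."
    )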
# Evaluate the distance between data and the final set of centroids for the initialization
final_distance = ht.spatial.distance.cdist(x, centroids, quadratic_expansion=True)
# For each data point in x, find the index of the centroid that is closest
final_idx = ht.argmin(final_distance, axis=1)
# Introduce weights, i.e., the number of data points closest to each centroid
# (count how often the same index in final_idx occurs)
weights = ht.zeros(centroids.shape[0], split=centroids.split)
for i in range(centroids.shape[0]):
weights[i] = ht.sum(final_idx == i)
# Recluster the oversampled centroids using standard k-means ++ (here we use the
# already implemented version in torch)
centroids = centroids.resplit_(None)
centroids = centroids.larray
weights = weights.resplit_(None)
weights = weights.larray
# --> first transform relevant arrays into torch tensors
if ht.MPI_WORLD.rank == 0:
batch_kmeans = _kmex(
centroids,
p=2,
n_clusters=self.n_clusters,
init="++",
max_iter=self.max_iter,
tol=self.tol,
random_state=None,
weights=weights,
)
# --> apply standard k-means ++
# Note: as we only recluster the centroids for initialization with standard k-means ++,
# this list of centroids can also be used to initialize k-medians and k-medoids
reclustered_centroids = batch_kmeans[0]
# --> access the reclustered centroids
else:
# ensure that all processes have the same data
reclustered_centroids = torch.zeros(
(self.n_clusters, centroids.shape[1]),
dtype=x.dtype.torch_type(),
device=centroids.device,
Collaborator: centroids.device.torch_device()?

Collaborator (Author): Could you please specify your comment?

)
# --> tensor of zeros with the same size as the reclustered centroids, in order to
# allocate memory with the correct type in all processes (necessary for broadcast)
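One possible reading of the thread above (an editorial assumption, not from the PR): at this point centroids has already been replaced by its local torch tensor via .larray, so centroids.device is a torch.device and can be passed to torch.zeros directly; a conversion such as the torch_device accessor the reviewer mentions would only be needed when starting from a heat Device.

    import heat as ht
    import torch

    a = ht.zeros((4, 3))                  # heat DNDarray; a.device is a heat Device
    t = a.larray                          # local torch tensor; t.device is a torch.device

    torch.zeros(2, device=t.device)       # fine: already a torch.device
    torch.zeros(2, device=a.device.torch_device)  # heat Device converted for torch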
ht.MPI_WORLD.Bcast(
reclustered_centroids, root=0
) # by default it is broadcasted from process 0
reclustered_centroids = ht.array(reclustered_centroids, split=x.split)
Collaborator: I don't know whether we want to split the centroids, as they are probably only a small number. So this might produce overhead compared to having split=None here.

Collaborator (Author): split was set to None.
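A sketch of the resolved line per the author's reply (assuming the follow-up commit only changes the split argument): the few centroids are replicated on every process rather than distributed.

    # Centroids are few; replication avoids the overhead of a split array.
    reclustered_centroids = ht.array(reclustered_centroids, split=None)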

# --> transform back to DNDarray
self._cluster_centers = reclustered_centroids
# --> final result for initialized cluster centers
else:
raise NotImplementedError("Not implemented for other splitting-axes")
self._cluster_centers = centroids

elif self.init == "batchparallel":
if x.split == 0:
21 changes: 15 additions & 6 deletions heat/cluster/batchparallelclustering.py
@@ -4,7 +4,8 @@

import heat as ht
import torch
from heat.cluster._kcluster import _KCluster

# from heat.cluster._kcluster import _KCluster
from heat.core.dndarray import DNDarray
from warnings import warn
from math import log
@@ -19,10 +20,14 @@
"""


def _initialize_plus_plus(X, n_clusters, p, random_state=None, max_samples=2**24 - 1):
def _initialize_plus_plus(
X, n_clusters, p, random_state=None, weights: torch.tensor = 1, max_samples=2**24 - 1
):
"""
Auxiliary function: single-process k-means++/k-medians++ initialization in pytorch
p is the norm used for computing distances
weights allows adding weights to the distribution function, so that data points with higher weights are preferred;
note that weights must have one entry per sample, i.e. the same length as X.shape[0]
The value max_samples=2**24 - 1 is necessary as PyTorch's multinomial currently only
supports this number of different categories.
"""
@@ -37,11 +42,11 @@ def _initialize_plus_plus(X, n_clusters, p, random_state=None, max_samples=2**24
for i in range(1, n_clusters):
dist = torch.cdist(X, X[idxs[:i]], p=p)
dist = torch.min(dist, dim=1)[0]
idxs[i] = torch.multinomial(dist, 1)
idxs[i] = torch.multinomial(weights * dist, 1)
return X[idxs]
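A quick usage sketch of the weighted initialization (synthetic values, editorial illustration only):

    import torch

    X = torch.randn(100, 3)
    w = torch.rand(100)  # one weight per sample
    centers = _initialize_plus_plus(X, n_clusters=4, p=2, weights=w)
    assert centers.shape == (4, 3)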


def _kmex(X, p, n_clusters, init, max_iter, tol, random_state=None):
def _kmex(X, p, n_clusters, init, max_iter, tol, random_state=None, weights: torch.tensor = 1.0):
"""
Auxiliary function: single-process k-means and k-medians in pytorch
p is the norm used for computing distances: p=2 implies k-means, p=1 implies k-medians.
Expand All @@ -55,7 +60,7 @@ def _kmex(X, p, n_clusters, init, max_iter, tol, random_state=None):
raise ValueError("if a torch tensor, init must have shape (n_clusters, n_features).")
centers = init
elif init == "++":
centers = _initialize_plus_plus(X, n_clusters, p, random_state)
centers = _initialize_plus_plus(X, n_clusters, p, random_state, weights)
elif init == "random":
idxs = torch.randint(0, X.shape[0], (n_clusters,))
centers = X[idxs]
@@ -169,7 +174,7 @@ def functional_value_(self) -> float:
"""
return self._functional_value

def fit(self, x: DNDarray):
def fit(self, x: DNDarray, weights: torch.tensor = 1):
Collaborator: This is the fit of batchparallel clustering. If we allow for specification of weights here, it must be a DNDarray, and below the corresponding local arrays would have to be used for the local clusterings.
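A sketch of what the reviewer describes (an editorial assumption about intent, not code from the PR): accept weights as a DNDarray split like x and hand its process-local torch tensor to _kmex.

    from typing import Optional

    def fit(self, x: DNDarray, weights: Optional[DNDarray] = None):
        if weights is None:
            local_weights = 1.0  # neutral weighting, as in the current default
        else:
            if not isinstance(weights, DNDarray) or weights.split != 0:
                raise TypeError("weights must be a DNDarray with split=0, like x")
            local_weights = weights.larray  # process-local chunk handed to _kmex
        ...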

"""
Computes the centroid of the clustering algorithm to fit the data ``x``.

Expand All @@ -178,6 +183,8 @@ def fit(self, x: DNDarray):
x : DNDarray
Training instances to cluster. Shape = (n_samples, n_features). It must hold x.split=0.

weights : torch.tensor
Weights added to the distribution function used in the local clustering algorithm _kmex
"""
if not isinstance(x, DNDarray):
raise TypeError(f"input needs to be a ht.DNDarray, but was {type(x)}")
@@ -198,6 +205,7 @@
self.max_iter,
self.tol,
local_random_state,
weights,
)

# hierarchical approach to obtain "global" cluster centers from the "local" centers
@@ -233,6 +241,7 @@
self.max_iter,
self.tol,
local_random_state,
weights,
Collaborator: See above.
)
del gathered_centers_local
n_iters_local += n_iters_local_new
9 changes: 7 additions & 2 deletions heat/cluster/kmeans.py
@@ -102,7 +102,7 @@ def _update_centroids(self, x: DNDarray, matching_centroids: DNDarray):

return new_cluster_centers

def fit(self, x: DNDarray) -> self:
def fit(self, x: DNDarray, oversampling: float = 100, iter_multiplier: float = 20) -> self:
Collaborator: Are there references for these default values? I had a look into dask and they use oversampling=2 as default.
"""
Computes the centroid of a k-means clustering.

Expand All @@ -111,13 +111,18 @@ def fit(self, x: DNDarray) -> self:
x : DNDarray
Training instances to cluster. Shape = (n_samples, n_features)

oversampling : float
oversampling factor used for the k-means|| initialization of centroids

iter_multiplier : float
factor that increases the number of iterations used in the initialization of centroids
"""
# input sanitation
if not isinstance(x, DNDarray):
raise ValueError(f"input needs to be a ht.DNDarray, but was {type(x)}")

# initialize the clustering
self._initialize_cluster_centers(x)
self._initialize_cluster_centers(x, oversampling, iter_multiplier)
self._n_iter = 0

# iteratively fit the points to the centroids
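A usage sketch of the extended interface (synthetic data; the init name and parameter values here are assumptions for illustration, not defaults endorsed by the PR):

    import heat as ht

    X = ht.random.randn(1000, 3, split=0)
    km = ht.cluster.KMeans(n_clusters=4, init="probability_based", max_iter=100)
    km.fit(X, oversampling=100, iter_multiplier=20)
    print(km.cluster_centers_.shape)  # (4, 3)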
11 changes: 9 additions & 2 deletions heat/cluster/kmedians.py
@@ -65,6 +65,7 @@ def _update_centroids(self, x: DNDarray, matching_centroids: DNDarray):
----------
x : DNDarray
Input data

matching_centroids : DNDarray
Array filled with indices ``i`` indicating to which cluster ``ci`` each sample point in x is assigned

@@ -103,21 +104,27 @@ def _update_centroids(self, x: DNDarray, matching_centroids: DNDarray):

return new_cluster_centers

def fit(self, x: DNDarray):
def fit(self, x: DNDarray, oversampling: float = 100, iter_multiplier: float = 20):
"""
Computes the centroid of a k-medians clustering.

Parameters
----------
x : DNDarray
Training instances to cluster. Shape = (n_samples, n_features)

oversampling : float
oversampling factor used in the k-means|| initialization of centroids

iter_multiplier : float
factor that increases the number of iterations used in the initialization of centroids
"""
# input sanitation
if not isinstance(x, ht.DNDarray):
raise ValueError(f"input needs to be a ht.DNDarray, but was {type(x)}")

# initialize the clustering
self._initialize_cluster_centers(x)
self._initialize_cluster_centers(x, oversampling, iter_multiplier)
self._n_iter = 0

# iteratively fit the points to the centroids
9 changes: 7 additions & 2 deletions heat/cluster/kmedoids.py
@@ -114,21 +114,26 @@ def _update_centroids(self, x: DNDarray, matching_centroids: DNDarray):

return new_cluster_centers

def fit(self, x: DNDarray):
def fit(self, x: DNDarray, oversampling: float = 100, iter_multiplier: float = 20):
"""
Computes the centroid of a k-medoids clustering.

Parameters
----------
x : DNDarray
Training instances to cluster. Shape = (n_samples, n_features)
oversampling : float
oversampling factor used in the k-means|| initialization of centroids

iter_multiplier : float
factor that increases the number of iterations used in the initialization of centroids
"""
# input sanitation
if not isinstance(x, DNDarray):
raise ValueError(f"input needs to be a ht.DNDarray, but was {type(x)}")

# initialize the clustering
self._initialize_cluster_centers(x)
self._initialize_cluster_centers(x, oversampling, iter_multiplier)
self._n_iter = 0

# iteratively fit the points to the centroids
5 changes: 1 addition & 4 deletions heat/cluster/tests/test_kmedoids.py
@@ -112,10 +112,7 @@ def test_spherical_clusters(self):
self.assertEqual(kmedoid.cluster_centers_.shape, (4, 3))
for i in range(kmedoid.cluster_centers_.shape[0]):
self.assertTrue(
ht.any(
ht.sum(ht.abs(kmedoid.cluster_centers_[i, :] - data.astype(ht.float32)), axis=1)
== 0
)
ht.any(ht.sum(ht.abs(kmedoid.cluster_centers_[i, :] - data), axis=1) == 0)
)

# on Ints (different radius, offset and datatype)
3 changes: 2 additions & 1 deletion heat/core/indexing.py
@@ -69,8 +69,9 @@ def nonzero(x: DNDarray) -> DNDarray:

if x.ndim == 1:
lcl_nonzero = lcl_nonzero.squeeze(dim=1)

for g in range(len(gout) - 1, -1, -1):
if gout[g] == 1:
if gout[g] == 1 and len(gout) > 1:
del gout[g]

return DNDarray(
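A minimal sketch of why the added guard matters (my reading of the fix, with an assumed toy input): without the len(gout) > 1 check, a 1-D input with a single nonzero entry would have its only output dimension deleted.

    import heat as ht

    x = ht.array([5])                 # 1-D input, one nonzero element
    print(ht.nonzero(x).shape)        # with the guard, the result keeps shape (1,)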