Parallel initialization for k-means #1754
base: main
Changed file 1 of 3:
@@ -3,6 +3,8 @@
"""

import heat as ht
import torch
from heat.cluster.batchparallelclustering import _kmex
from typing import Optional, Union, Callable
from heat.core.dndarray import DNDarray

@@ -94,14 +96,22 @@ def functional_value_(self) -> DNDarray:
        """
        return self._functional_value

    def _initialize_cluster_centers(self, x: DNDarray):
    def _initialize_cluster_centers(
        self, x: DNDarray, oversampling: float = 100, iter_multiplier: float = 20
    ):
        """
        Initializes the K-Means centroids.

        Parameters
        ----------
        x : DNDarray
            The data to initialize the clusters for. Shape = (n_samples, n_features)

        oversampling : float
            Oversampling factor used in the k-means|| initialization of the centroids.

        iter_multiplier : float
            Factor that increases the number of iterations used in the initialization of the centroids.
        """
        # always initialize the random state
        if self.random_state is not None:
@@ -123,53 +133,106 @@ def _initialize_cluster_centers(self, x: DNDarray):
                raise ValueError("passed centroids do not match cluster count or data shape")
            self._cluster_centers = self.init.resplit(None)

        # Smart centroid guessing, random sampling with probability weight proportional to distance to existing centroids
        # Parallelized centroid guessing using the k-means|| algorithm
        elif self.init == "probability_based":
            # First, check along which axis the data is sliced
            if x.split is None or x.split == 0:
                centroids = ht.zeros(
                    (self.n_clusters, x.shape[1]), split=None, device=x.device, comm=x.comm
                )
                sample = ht.random.randint(0, x.shape[0] - 1).item()
                _, displ, _ = x.comm.counts_displs_shape(shape=x.shape, axis=0)
                proc = 0
                for p in range(x.comm.size):
                    if displ[p] > sample:
                        break
                    proc = p
                x0 = ht.zeros(x.shape[1], dtype=x.dtype, device=x.device, comm=x.comm)
                if x.comm.rank == proc:
                    idx = sample - displ[proc]
                    x0 = ht.array(x.lloc[idx, :], device=x.device, comm=x.comm)
                x0.comm.Bcast(x0, root=proc)
                centroids[0, :] = x0
                for i in range(1, self.n_clusters):
                    distances = ht.spatial.distance.cdist(x, centroids, quadratic_expansion=True)
                    D2 = distances.min(axis=1)
                    D2.resplit_(axis=None)
                    prob = D2 / D2.sum()
                    random_position = ht.random.rand()
                    sample = 0
                    sum = 0
                    for j in range(len(prob)):
                        if sum > random_position:
                            break
                        sum += prob[j].item()
                        sample = j
                    proc = 0
                    for p in range(x.comm.size):
                        if displ[p] > sample:
                            break
                        proc = p
                    xi = ht.zeros(x.shape[1], dtype=x.dtype)
                    if x.comm.rank == proc:
                        idx = sample - displ[proc]
                        xi = ht.array(x.lloc[idx, :], device=x.device, comm=x.comm)
                    xi.comm.Bcast(xi, root=proc)
                    centroids[i, :] = xi

                # Define a list of random, uniformly distributed probabilities,
                # which is later used to sample the centroids
                sample = ht.random.rand(x.shape[0], split=x.split)
                # Define a random integer serving as a label to pick the first centroid randomly
                init_idx = ht.random.randint(0, x.shape[0] - 1).item()
                # Randomly select the first centroid and organize it as a tensor, in order to use the function cdist later.
                # This tensor will be filled continuously as this function proceeds.
                # We assume that the centroids fit into the memory of a single GPU
                centroids = ht.expand_dims(x[init_idx, :].resplit_(None), axis=0)
                # Calculate the initial cost of the clustering after the first centroid selection
                # and use it as an indicator for the order of magnitude of the number of necessary iterations
                init_distance = ht.spatial.distance.cdist(x, centroids, quadratic_expansion=True)
                # --> init_distance calculates the Euclidean distance between data points x and initial centroids
                #     output format: tensor
                init_min_distance = init_distance.min(axis=1)
                # --> Pick the minimal distance of the data points to each centroid
                #     output format: vector
                init_cost = init_min_distance.sum()
                # --> Now calculate the cost
                #     output format: scalar
                #
                # Iteratively fill the tensor storing the centroids
                for _ in ht.arange(0, iter_multiplier * ht.log(init_cost)):
                    # Calculate the distance between data points and the current set of centroids
                    distance = ht.spatial.distance.cdist(x, centroids, quadratic_expansion=True)
                    min_distance = distance.min(axis=1)
                    # Sample data points as candidates for the new set of centroids
                    prob = oversampling * min_distance / min_distance.sum()
                    # --> probability distribution with oversampling factor
                    #     output format: vector
                    idx = ht.where(sample <= prob)
[Review comment] One may think about moving the creation of sample here to have some new sampling every step of the loop.
                    # --> choose indices to sample the data according to prob
                    #     output format: vector
                    local_data = x[idx].resplit_(centroids.split)
                    # --> pick the data points that are identified as possible centroids and make sure
                    #     that data points and centroids are split in the same way
                    #     output format: vector
                    centroids = ht.row_stack((centroids, local_data))
                    # --> stack the data points with these indices to the DNDarray of centroids
                    #     output format: tensor
                # Evaluate distance between final centroids and data points
                if centroids.shape[0] <= self.n_clusters:
                    raise ValueError(
                        "The oversampling factor and/or the number of iterations are chosen"
[Review comment] Why two strings?
[Reply] For the sake of readability, I would like to avoid single lines of code that are too long and thus decided to split the string in two lines. However, in Python you need a separate string for each line of code. The values of the parameters have been added to the error message.
"too small for the initialization of cluster centers." | ||
) | ||
# Evaluate the distance between data and the final set of centroids for the initialization | ||
final_distance = ht.spatial.distance.cdist(x, centroids, quadratic_expansion=True) | ||
# For each data point in x, find the index of the centroid that is closest | ||
final_idx = ht.argmin(final_distance, axis=1) | ||
# Introduce weights, i.e., the number of data points closest to each centroid | ||
# (count how often the same index in final_idx occurs) | ||
weights = ht.zeros(centroids.shape[0], split=centroids.split) | ||
for i in range(centroids.shape[0]): | ||
weights[i] = ht.sum(final_idx == i) | ||
# Recluster the oversampled centroids using standard k-means ++ (here we use the | ||
# already implemented version in torch) | ||
centroids = centroids.resplit_(None) | ||
centroids = centroids.larray | ||
weights = weights.resplit_(None) | ||
weights = weights.larray | ||
# --> first transform relevant arrays into torch tensors | ||
if ht.MPI_WORLD.rank == 0: | ||
batch_kmeans = _kmex( | ||
centroids, | ||
p=2, | ||
n_clusters=self.n_clusters, | ||
init="++", | ||
max_iter=self.max_iter, | ||
tol=self.tol, | ||
random_state=None, | ||
weights=weights, | ||
) | ||
# --> apply standard k-means ++ | ||
# Note: as we only recluster the centroids for initialization with standard k-means ++, | ||
# this list of centroids can also be used to initialize k-medians and k-medoids | ||
reclustered_centroids = batch_kmeans[0] | ||
# --> access the reclustered centroids | ||
else: | ||
# ensure that all processes have the same data | ||
reclustered_centroids = torch.zeros( | ||
(self.n_clusters, centroids.shape[1]), | ||
dtype=x.dtype.torch_type(), | ||
device=centroids.device, | ||
[Review comment] centroids.device.torch_device()?
[Reply] Could you please specify your comment?
                    )
                    # --> tensor of zeros with the same size as the reclustered centroids, in order to
                    #     allocate memory with the correct type on all processes (necessary for the broadcast)
                ht.MPI_WORLD.Bcast(
                    reclustered_centroids, root=0
                )  # by default it is broadcast from process 0
                reclustered_centroids = ht.array(reclustered_centroids, split=x.split)
[Review comment] I don't know whether we want to split the centroids, as they are probably only a small number. So this might produce overhead compared to having split=None here.
[Reply] split was set to None
                # --> transform back to DNDarray
                self._cluster_centers = reclustered_centroids
                # --> final result for initialized cluster centers
            else:
                raise NotImplementedError("Not implemented for other splitting-axes")
            self._cluster_centers = centroids

        elif self.init == "batchparallel":
            if x.split == 0:
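For readers following the review, here is a minimal single-process sketch of the k-means|| scheme implemented above (plain PyTorch, no Heat/MPI; all names such as kmeans_parallel_init_sketch, ell and rounds are illustrative, and the fixed rounds count stands in for the PR's iter_multiplier * log(initial cost) bound):

import torch

def kmeans_parallel_init_sketch(X, n_clusters, ell=100, rounds=5):
    # pick the first candidate uniformly at random
    candidates = X[torch.randint(0, X.shape[0], (1,))]
    for _ in range(rounds):
        # squared distance of every point to its nearest current candidate
        d2 = torch.cdist(X, candidates).min(dim=1).values ** 2
        # oversampled, independent Bernoulli draws: about `ell` new candidates per round on average
        prob = torch.clamp(ell * d2 / d2.sum(), max=1.0)
        picked = torch.rand(X.shape[0]) <= prob
        candidates = torch.cat([candidates, X[picked]], dim=0)
    # weight each candidate by the number of points for which it is the closest candidate
    closest = torch.cdist(X, candidates).argmin(dim=1)
    weights = torch.bincount(closest, minlength=candidates.shape[0]).float()
    # recluster the small weighted candidate set down to n_clusters (weighted k-means++)
    centers = candidates[torch.multinomial(weights, 1)]
    for _ in range(1, n_clusters):
        dist = torch.cdist(candidates, centers).min(dim=1).values
        centers = torch.cat([centers, candidates[torch.multinomial(weights * dist, 1)]], dim=0)
    return centers

With ell playing the role of oversampling and a handful of rounds, this reproduces in miniature the sampling-then-weighted-reclustering structure of the distributed code above.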
Changed file 2 of 3:
@@ -4,7 +4,8 @@

import heat as ht
import torch
from heat.cluster._kcluster import _KCluster

# from heat.cluster._kcluster import _KCluster
from heat.core.dndarray import DNDarray
from warnings import warn
from math import log
@@ -19,10 +20,14 @@
"""


def _initialize_plus_plus(X, n_clusters, p, random_state=None, max_samples=2**24 - 1):
def _initialize_plus_plus(
    X, n_clusters, p, random_state=None, weights: torch.tensor = 1, max_samples=2**24 - 1
):
    """
    Auxiliary function: single-process k-means++/k-medians++ initialization in pytorch
    p is the norm used for computing distances
    weights allows one to add weights to the distribution function, so that data points with higher weights are preferred;
    note that weights must have one entry per sample, i.e. the same length as the first dimension of X
    The value max_samples=2**24 - 1 is necessary as PyTorch's multinomial currently only
    supports this number of different categories.
    """
@@ -37,11 +42,11 @@ def _initialize_plus_plus(X, n_clusters, p, random_state=None, max_samples=2**24
    for i in range(1, n_clusters):
        dist = torch.cdist(X, X[idxs[:i]], p=p)
        dist = torch.min(dist, dim=1)[0]
        idxs[i] = torch.multinomial(dist, 1)
        idxs[i] = torch.multinomial(weights * dist, 1)
    return X[idxs]

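A toy snippet (not part of the diff) illustrating the effect of the new weights factor in isolation: with equal distances, a point carrying five times the weight is drawn roughly five times as often.

import torch

torch.manual_seed(0)
dist = torch.tensor([1.0, 1.0, 1.0, 1.0])      # equal distances to the current centers
weights = torch.tensor([1.0, 1.0, 1.0, 5.0])   # the last point carries five times the weight
draws = torch.multinomial(weights * dist, 10000, replacement=True)
print(torch.bincount(draws) / 10000.0)         # roughly tensor([0.125, 0.125, 0.125, 0.625])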
def _kmex(X, p, n_clusters, init, max_iter, tol, random_state=None):
def _kmex(X, p, n_clusters, init, max_iter, tol, random_state=None, weights: torch.tensor = 1.0):
    """
    Auxiliary function: single-process k-means and k-medians in pytorch
    p is the norm used for computing distances: p=2 implies k-means, p=1 implies k-medians.

@@ -55,7 +60,7 @@ def _kmex(X, p, n_clusters, init, max_iter, tol, random_state=None):
            raise ValueError("if a torch tensor, init must have shape (n_clusters, n_features).")
        centers = init
    elif init == "++":
        centers = _initialize_plus_plus(X, n_clusters, p, random_state)
        centers = _initialize_plus_plus(X, n_clusters, p, random_state, weights)
    elif init == "random":
        idxs = torch.randint(0, X.shape[0], (n_clusters,))
        centers = X[idxs]

@@ -169,7 +174,7 @@ def functional_value_(self) -> float:
        """
        return self._functional_value

    def fit(self, x: DNDarray):
    def fit(self, x: DNDarray, weights: torch.tensor = 1):
[Review comment] This is the fit of batchparallel clustering. If we allow for specification of weights here, it must be a DNDarray, and below the corresponding local arrays would have to be used for the local clusterings.
""" | ||
Computes the centroid of the clustering algorithm to fit the data ``x``. | ||
|
||
|
@@ -178,6 +183,8 @@ def fit(self, x: DNDarray): | |
x : DNDarray | ||
Training instances to cluster. Shape = (n_samples, n_features). It must hold x.split=0. | ||
|
||
weights: torch.tensor | ||
Add weights to the distribution function used in the clustering algorithm in kmex | ||
""" | ||
if not isinstance(x, DNDarray): | ||
raise TypeError(f"input needs to be a ht.DNDarray, but was {type(x)}") | ||
|
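A hypothetical usage sketch of the weights argument proposed here (not part of the diff; as noted in the review, the accepted type may still change from a torch tensor to a DNDarray, and the constructor defaults are assumed):

import heat as ht
import torch

x = ht.random.randn(10000, 3, split=0)   # samples distributed along axis 0
w = torch.ones(x.lshape[0])              # one weight per process-local sample, matching the local _kmex call in this diff

clusterer = ht.cluster.BatchParallelKMeans(n_clusters=4)
clusterer.fit(x, weights=w)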
@@ -198,6 +205,7 @@ def fit(self, x: DNDarray):
            self.max_iter,
            self.tol,
            local_random_state,
            weights,
        )

        # hierarchical approach to obtain "global" cluster centers from the "local" centers
@@ -233,6 +241,7 @@ def fit(self, x: DNDarray):
                    self.max_iter,
                    self.tol,
                    local_random_state,
                    weights,
[Review comment] see above
                )
                del gathered_centers_local
                n_iters_local += n_iters_local_new

Changed file 3 of 3:
@@ -102,7 +102,7 @@ def _update_centroids(self, x: DNDarray, matching_centroids: DNDarray):

        return new_cluster_centers

    def fit(self, x: DNDarray) -> self:
    def fit(self, x: DNDarray, oversampling: float = 100, iter_multiplier: float = 20) -> self:
[Review comment] Are there references for these default values? I had a look into dask and they use oversampling=2 as default.
""" | ||
Computes the centroid of a k-means clustering. | ||
|
||
|
@@ -111,13 +111,18 @@ def fit(self, x: DNDarray) -> self: | |
x : DNDarray | ||
Training instances to cluster. Shape = (n_samples, n_features) | ||
|
||
oversampling : float | ||
oversampling factor used for the k-means|| initializiation of centroids | ||
|
||
iter_multiplier : float | ||
factor that increases the number of iterations used in the initialization of centroids | ||
""" | ||
# input sanitation | ||
if not isinstance(x, DNDarray): | ||
raise ValueError(f"input needs to be a ht.DNDarray, but was {type(x)}") | ||
|
||
# initialize the clustering | ||
self._initialize_cluster_centers(x) | ||
self._initialize_cluster_centers(x, oversampling, iter_multiplier) | ||
self._n_iter = 0 | ||
|
||
# iteratively fit the points to the centroids | ||
|
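For illustration (not part of the diff), a hypothetical call of the extended fit; the data and parameter values are just examples, and the defaults are those proposed in this PR:

import heat as ht

x = ht.random.randn(10000, 3, split=0)   # toy data, split along the sample axis
kmeans = ht.cluster.KMeans(n_clusters=8, init="probability_based")
kmeans.fit(x, oversampling=100, iter_multiplier=20)
print(kmeans.cluster_centers_)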
[Review comment] For loop counters, the standard Python range should be more effective.
[Reply] Done