diff --git a/heat/cluster/tests/test_spectral.py b/heat/cluster/tests/test_spectral.py index eb59d92ce5..9e24dddfc5 100644 --- a/heat/cluster/tests/test_spectral.py +++ b/heat/cluster/tests/test_spectral.py @@ -35,51 +35,49 @@ def test_get_and_set_params(self): self.assertEqual(10, spectral.n_clusters) def test_fit_iris(self): - if ht.MPI_WORLD.size <= 4: - # todo: fix tests with >7 processes, NaNs appearing in spectral._spectral_embedding - # get some test data - iris = ht.load("heat/datasets/iris.csv", sep=";", split=0) - m = 10 - # fit the clusters - spectral = ht.cluster.Spectral( - n_clusters=3, gamma=1.0, metric="rbf", laplacian="fully_connected", n_lanczos=m - ) - spectral.fit(iris) - self.assertIsInstance(spectral.labels_, ht.DNDarray) + # get some test data + iris = ht.load("heat/datasets/iris.csv", sep=";", split=0) + m = 10 + # fit the clusters + spectral = ht.cluster.Spectral( + n_clusters=3, gamma=1.0, metric="rbf", laplacian="fully_connected", n_lanczos=m + ) + spectral.fit(iris) + self.assertIsInstance(spectral.labels_, ht.DNDarray) - spectral = ht.cluster.Spectral( - metric="euclidean", - laplacian="eNeighbour", - threshold=0.5, - boundary="upper", - n_lanczos=m, - ) - labels = spectral.fit_predict(iris) - self.assertIsInstance(labels, ht.DNDarray) + spectral = ht.cluster.Spectral( + metric="euclidean", + laplacian="eNeighbour", + threshold=0.5, + boundary="upper", + n_lanczos=m, + ) + labels = spectral.fit_predict(iris) + self.assertIsInstance(labels, ht.DNDarray) - spectral = ht.cluster.Spectral( - gamma=0.1, - metric="rbf", - laplacian="eNeighbour", - threshold=0.5, - boundary="upper", - n_lanczos=m, - ) - labels = spectral.fit_predict(iris) - self.assertIsInstance(labels, ht.DNDarray) + spectral = ht.cluster.Spectral( + gamma=0.1, + metric="rbf", + laplacian="eNeighbour", + threshold=0.5, + boundary="upper", + n_lanczos=m, + ) + labels = spectral.fit_predict(iris) + self.assertIsInstance(labels, ht.DNDarray) - kmeans = {"kmeans++": "kmeans++", "max_iter": 30, "tol": -1} - spectral = ht.cluster.Spectral( - n_clusters=3, gamma=1.0, normalize=True, n_lanczos=m, params=kmeans - ) - labels = spectral.fit_predict(iris) - self.assertIsInstance(labels, ht.DNDarray) + kmeans = {"kmeans++": "kmeans++", "max_iter": 30, "tol": -1} + spectral = ht.cluster.Spectral( + n_clusters=3, gamma=1.0, normalize=True, n_lanczos=m, params=kmeans + ) + labels = spectral.fit_predict(iris) + self.assertIsInstance(labels, ht.DNDarray) - # Errors - with self.assertRaises(NotImplementedError): - spectral = ht.cluster.Spectral(metric="ahalanobis", n_lanczos=m) + # Errors + with self.assertRaises(NotImplementedError): + spectral = ht.cluster.Spectral(metric="ahalanobis", n_lanczos=m) - iris_split = ht.load("heat/datasets/iris.csv", sep=";", split=1) - spectral = ht.cluster.Spectral(n_lanczos=20) - with self.assertRaises(NotImplementedError): - spectral.fit(iris_split) + iris_split = ht.load("heat/datasets/iris.csv", sep=";", split=1) + spectral = ht.cluster.Spectral(n_lanczos=20) + with self.assertRaises(NotImplementedError): + spectral.fit(iris_split)