Skip to content

Commit

Permalink
Update test_spectral.py - removed the restriction to comm_world size …
Browse files Browse the repository at this point in the history
…< 7 (#1200)

Co-authored-by: Claudia Comito <[email protected]>
  • Loading branch information
mrfh92 and ClaudiaComito authored Aug 23, 2023
1 parent 086457a commit 089ab2f
Showing 1 changed file with 41 additions and 43 deletions.
84 changes: 41 additions & 43 deletions heat/cluster/tests/test_spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,51 +35,49 @@ def test_get_and_set_params(self):
self.assertEqual(10, spectral.n_clusters)

def test_fit_iris(self):
if ht.MPI_WORLD.size <= 4:
# todo: fix tests with >7 processes, NaNs appearing in spectral._spectral_embedding
# get some test data
iris = ht.load("heat/datasets/iris.csv", sep=";", split=0)
m = 10
# fit the clusters
spectral = ht.cluster.Spectral(
n_clusters=3, gamma=1.0, metric="rbf", laplacian="fully_connected", n_lanczos=m
)
spectral.fit(iris)
self.assertIsInstance(spectral.labels_, ht.DNDarray)
# get some test data
iris = ht.load("heat/datasets/iris.csv", sep=";", split=0)
m = 10
# fit the clusters
spectral = ht.cluster.Spectral(
n_clusters=3, gamma=1.0, metric="rbf", laplacian="fully_connected", n_lanczos=m
)
spectral.fit(iris)
self.assertIsInstance(spectral.labels_, ht.DNDarray)

spectral = ht.cluster.Spectral(
metric="euclidean",
laplacian="eNeighbour",
threshold=0.5,
boundary="upper",
n_lanczos=m,
)
labels = spectral.fit_predict(iris)
self.assertIsInstance(labels, ht.DNDarray)
spectral = ht.cluster.Spectral(
metric="euclidean",
laplacian="eNeighbour",
threshold=0.5,
boundary="upper",
n_lanczos=m,
)
labels = spectral.fit_predict(iris)
self.assertIsInstance(labels, ht.DNDarray)

spectral = ht.cluster.Spectral(
gamma=0.1,
metric="rbf",
laplacian="eNeighbour",
threshold=0.5,
boundary="upper",
n_lanczos=m,
)
labels = spectral.fit_predict(iris)
self.assertIsInstance(labels, ht.DNDarray)
spectral = ht.cluster.Spectral(
gamma=0.1,
metric="rbf",
laplacian="eNeighbour",
threshold=0.5,
boundary="upper",
n_lanczos=m,
)
labels = spectral.fit_predict(iris)
self.assertIsInstance(labels, ht.DNDarray)

kmeans = {"kmeans++": "kmeans++", "max_iter": 30, "tol": -1}
spectral = ht.cluster.Spectral(
n_clusters=3, gamma=1.0, normalize=True, n_lanczos=m, params=kmeans
)
labels = spectral.fit_predict(iris)
self.assertIsInstance(labels, ht.DNDarray)
kmeans = {"kmeans++": "kmeans++", "max_iter": 30, "tol": -1}
spectral = ht.cluster.Spectral(
n_clusters=3, gamma=1.0, normalize=True, n_lanczos=m, params=kmeans
)
labels = spectral.fit_predict(iris)
self.assertIsInstance(labels, ht.DNDarray)

# Errors
with self.assertRaises(NotImplementedError):
spectral = ht.cluster.Spectral(metric="ahalanobis", n_lanczos=m)
# Errors
with self.assertRaises(NotImplementedError):
spectral = ht.cluster.Spectral(metric="ahalanobis", n_lanczos=m)

iris_split = ht.load("heat/datasets/iris.csv", sep=";", split=1)
spectral = ht.cluster.Spectral(n_lanczos=20)
with self.assertRaises(NotImplementedError):
spectral.fit(iris_split)
iris_split = ht.load("heat/datasets/iris.csv", sep=";", split=1)
spectral = ht.cluster.Spectral(n_lanczos=20)
with self.assertRaises(NotImplementedError):
spectral.fit(iris_split)

0 comments on commit 089ab2f

Please sign in to comment.