From ebc21a7ad9b8e7fca03d5d631491670891e55a8f Mon Sep 17 00:00:00 2001 From: Yang Long Date: Fri, 28 Jun 2024 10:41:25 +0100 Subject: [PATCH 1/2] update --- examples/interactive_Kmeans.py | 18 +++ .../Kmeans interactive _ more clusters.ipynb | 130 ++++++++++++++++++ .../yang/Kmeans interactive _ more.ipynb | 124 +++++++++++++++++ notebooks/yang/Kmeans interactive.ipynb | 43 +----- vidar/__init__.py | 1 + vidar/interactive/interactive_layout_yang.py | 86 ++++++++++++ .../tests/test_interactive_layout_yang.py | 29 ++++ 7 files changed, 392 insertions(+), 39 deletions(-) create mode 100644 examples/interactive_Kmeans.py create mode 100644 notebooks/yang/Kmeans interactive _ more clusters.ipynb create mode 100644 notebooks/yang/Kmeans interactive _ more.ipynb create mode 100644 vidar/__init__.py create mode 100644 vidar/interactive/tests/test_interactive_layout_yang.py diff --git a/examples/interactive_Kmeans.py b/examples/interactive_Kmeans.py new file mode 100644 index 0000000..37ed93c --- /dev/null +++ b/examples/interactive_Kmeans.py @@ -0,0 +1,18 @@ +import sys +sys.path.append('..') + +import numpy as np +from vidar import InteractiveKMeans, interactive_kmeans + +rng = np.random.RandomState(0) +n_samples = 1000 +cov = [[0.4, 0], [0, 0.4]] +X = np.concatenate([ + rng.multivariate_normal(mean=[-2, 0], cov=cov, size=n_samples), + rng.multivariate_normal(mean=[2, 0], cov=cov, size=n_samples), + rng.multivariate_normal(mean=[0.3, 1], cov=cov, size=n_samples) + ]) + +n_clusters = 3 +app = interactive_kmeans(X, n_clusters) +app.show() \ No newline at end of file diff --git a/notebooks/yang/Kmeans interactive _ more clusters.ipynb b/notebooks/yang/Kmeans interactive _ more clusters.ipynb new file mode 100644 index 0000000..c023ace --- /dev/null +++ b/notebooks/yang/Kmeans interactive _ more clusters.ipynb @@ -0,0 +1,130 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# K-means interactive\n", + "\n", + "> Yang" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib qt\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "from matplotlib.widgets import Button\n", + "from matplotlib.widgets import PolygonSelector\n", + "from sklearn.cluster import KMeans\n", + "\n", + "def colors_from_lbs(lbs, colors=None):\n", + " mpl_20 = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',\n", + " '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',\n", + " '#3397dc', '#ff993e', '#3fca3f', '#df5152', '#a985ca',\n", + " '#ad7165', '#e992ce', '#999999', '#dbdc3c', '#35d8e9']\n", + " \n", + " if colors is None:\n", + " colors = np.array(mpl_20)\n", + " else:\n", + " colors = np.array(colors)\n", + " lbs = np.array(lbs) % len(colors)\n", + " return colors[lbs]\n", + "\n", + "rng = np.random.RandomState(0)\n", + "n_samples = 1000\n", + "cov = [[0.4, 0], [0, 0.4]]\n", + "X = np.concatenate([\n", + " rng.multivariate_normal(mean=[-2, 0], cov=cov, size=n_samples), \n", + " rng.multivariate_normal(mean=[2, 0], cov=cov, size=n_samples),\n", + " rng.multivariate_normal(mean=[0.5, 1], cov=cov, size=n_samples)\n", + " ])\n", + "\n", + "n_clusters=3\n", + "kmeans = KMeans(n_clusters=n_clusters, random_state=0, n_init=\"auto\")\n", + "labels = kmeans.fit_predict(X)\n", + "\n", + "centers = kmeans.cluster_centers_\n", + "\n", + "fig, (ax_orig, ax_redim) = plt.subplots(1, 2, figsize=(12, 6))\n", + "\n", + "def plot_figure(axe_list, X, centers):\n", + " ax_orig, ax_redim = axe_list\n", + "\n", + " kmeans.cluster_centers_ = np.array(centers, dtype=np.float64)\n", + " labels = kmeans.predict(X) \n", + "\n", + " ax_orig.clear()\n", + " ax_orig.scatter(X[:, 0], X[:, 1], alpha=0.3, label=\"samples\", c=colors_from_lbs(labels))\n", + " ax_orig.scatter(centers[:,0], centers[:,1], s=50, c='black', edgecolors='r')\n", + " ax_orig.set(\n", + " aspect=\"auto\", \n", + " title=\"Interactive K-means\",\n", + " xlabel=\"first feature\",\n", + " ylabel=\"second feature\",\n", + " )\n", + "\n", + " ax_redim.clear()\n", + " class_name = ['class {0}'.format(i+1) for i in range(len(centers))]\n", + "\n", + " # update labels\n", + " counts = [np.sum(labels==i) for i in range(len(centers))]\n", + " \n", + "\n", + " ax_redim.bar(class_name, counts, label=class_name, color=colors_from_lbs(range(len(centers))))\n", + " ax_redim.set(\n", + " aspect=\"auto\",\n", + " title=\"Clustering results\",\n", + " xlabel=\"Main feature\",\n", + " ylabel=\"Number of samples\",\n", + " )\n", + " fig.canvas.draw_idle()\n", + "\n", + "plot_figure((ax_orig, ax_redim), X, centers)\n", + "\n", + "def onselect(verts):\n", + " centers = np.array(verts)\n", + " plot_figure((ax_orig, ax_redim), X, centers)\n", + "\n", + "selector = PolygonSelector(ax_orig, onselect=onselect, \n", + " props=dict(color='r', linestyle='', linewidth=3, alpha=0.6, label=f\"Component\"))\n", + "selector.verts = centers\n", + "\n", + "plt.tight_layout()\n", + "plt.show()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.19" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/yang/Kmeans interactive _ more.ipynb b/notebooks/yang/Kmeans interactive _ more.ipynb new file mode 100644 index 0000000..72f8c63 --- /dev/null +++ b/notebooks/yang/Kmeans interactive _ more.ipynb @@ -0,0 +1,124 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# K-means interactive\n", + "\n", + "> Yang" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib qt\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "from matplotlib.widgets import Button\n", + "from matplotlib.widgets import PolygonSelector\n", + "from sklearn.cluster import KMeans\n", + "\n", + "def colors_from_lbs(lbs, colors=None):\n", + " mpl_20 = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',\n", + " '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',\n", + " '#3397dc', '#ff993e', '#3fca3f', '#df5152', '#a985ca',\n", + " '#ad7165', '#e992ce', '#999999', '#dbdc3c', '#35d8e9']\n", + " \n", + " if colors is None:\n", + " colors = np.array(mpl_20)\n", + " else:\n", + " colors = np.array(colors)\n", + " lbs = np.array(lbs) % len(colors)\n", + " return colors[lbs]\n", + "\n", + "rng = np.random.RandomState(0)\n", + "n_samples = 1000\n", + "cov = [[0.4, 0], [0, 0.4]]\n", + "X = np.concatenate([\n", + " rng.multivariate_normal(mean=[-2, 0], cov=cov, size=n_samples), \n", + " rng.multivariate_normal(mean=[2, 0], cov=cov, size=n_samples),\n", + " rng.multivariate_normal(mean=[0.3, 1], cov=cov, size=n_samples)\n", + " ])\n", + "\n", + "kmeans = KMeans(n_clusters=2, random_state=0, n_init=\"auto\")\n", + "labels = kmeans.fit_predict(X)\n", + "\n", + "centers = kmeans.cluster_centers_\n", + "\n", + "fig, (ax_orig, ax_redim) = plt.subplots(1, 2, figsize=(12, 6))\n", + "\n", + "def plot_figure(axe_list, X, centers):\n", + " ax_orig, ax_redim = axe_list\n", + "\n", + " kmeans.cluster_centers_ = np.array(centers, dtype=np.float64)\n", + " labels = kmeans.predict(X) \n", + "\n", + " ax_orig.clear()\n", + " ax_orig.scatter(X[:, 0], X[:, 1], alpha=0.3, label=\"samples\", c=colors_from_lbs(labels))\n", + " ax_orig.scatter(centers[:,0], centers[:,1], s=50, c='black', edgecolors='r')\n", + " ax_orig.set(\n", + " aspect=\"auto\", \n", + " title=\"Interactive K-means\",\n", + " xlabel=\"first feature\",\n", + " ylabel=\"second feature\",\n", + " )\n", + "\n", + " ax_redim.clear()\n", + " class_name = ['class {0}'.format(i+1) for i in range(len(centers))]\n", + "\n", + " # update labels\n", + " counts = [np.sum(labels==i) for i in range(len(centers))]\n", + " \n", + "\n", + " ax_redim.bar(class_name, counts, \n", + " label=class_name,\n", + " color=colors_from_lbs(range(len(centers))))\n", + " ax_redim.set(\n", + " aspect=\"auto\",\n", + " title=\"Clustering results\",\n", + " xlabel=\"Main feature\",\n", + " ylabel=\"Number of samples\",\n", + " )\n", + " fig.canvas.draw_idle()\n", + "\n", + "plot_figure((ax_orig, ax_redim), X, centers)\n", + "\n", + "def onselect(verts):\n", + " centers = np.array(verts)\n", + " plot_figure((ax_orig, ax_redim), X, centers)\n", + "\n", + "selector = PolygonSelector(ax_orig, onselect=onselect, \n", + " props=dict(color='r', linestyle='', linewidth=3, alpha=0.6, label=f\"Component\"))\n", + "selector.verts = centers\n", + "\n", + "plt.tight_layout()\n", + "plt.show()\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.19" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/yang/Kmeans interactive.ipynb b/notebooks/yang/Kmeans interactive.ipynb index 9bd717c..e3c2c46 100644 --- a/notebooks/yang/Kmeans interactive.ipynb +++ b/notebooks/yang/Kmeans interactive.ipynb @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -63,13 +63,13 @@ " ax_orig.scatter(centers[:,0], centers[:,1], s=50, c='black', edgecolors='r')\n", " ax_orig.set(\n", " aspect=\"auto\", \n", - " title=\"K-means\",\n", + " title=\"Interactive K-means\",\n", " xlabel=\"first feature\",\n", " ylabel=\"second feature\",\n", " )\n", "\n", " ax_redim.clear()\n", - " class_name = ['class {0}'.format(i) for i in range(len(centers))]\n", + " class_name = ['class {0}'.format(i+1) for i in range(len(centers))]\n", "\n", " # update labels\n", " counts = [np.sum(labels==i) for i in range(len(centers))]\n", @@ -95,17 +95,6 @@ " props=dict(color='r', linestyle='', linewidth=3, alpha=0.6, label=f\"Component\"))\n", "selector.verts = centers\n", "\n", - "\n", - "# ax_redim.hist((X @ component.T - x_center @ component.T),50)\n", - "# ax_redim.set(\n", - "# aspect=\"auto\",\n", - "# title=\"1-dimensional dataset after dimension reduction\",\n", - "# xlabel=\"Main feature\",\n", - "# ylabel=\"Number of samples\",\n", - "# )\n", - "#_asp = np.diff(ax_orig.get_ylim())[0] / np.diff(ax_orig.get_xlim())[0]\n", - "#ax_redim.set_aspect(_asp)\n", - "\n", "plt.tight_layout()\n", "plt.show()\n" ] @@ -115,31 +104,7 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "\n", - "\n", - "# pca = PCA(n_components=1).fit(X)\n", - "# component = pca.components_.reshape(-1)\n", - "\n", - "# # print(pca.components_)\n", - "# # print(pca.explained_variance_)\n", - "# # print(list(zip(pca.components_, pca.explained_variance_)))\n", - "\n", - "# # fig, (ax_orig, ax_redim) = plt.subplots(1, 2, figsize=(12, 6))\n", - "# # ax_orig.scatter(X[:, 0], X[:, 1], alpha=0.3, label=\"samples\")\n", - "# # x_center = np.mean(X, axis=0)\n", - "\n", - "# comp_vector = [component, x_center]\n", - "\n", - "# ax_orig.set(\n", - "# aspect=\"auto\", \n", - "# title=\"2-dimensional dataset with principal components\",\n", - "# xlabel=\"first feature\",\n", - "# ylabel=\"second feature\",\n", - "# )\n", - "\n", - "\n" - ] + "source": [] } ], "metadata": { diff --git a/vidar/__init__.py b/vidar/__init__.py new file mode 100644 index 0000000..a4dd1d6 --- /dev/null +++ b/vidar/__init__.py @@ -0,0 +1 @@ +from .interactive.interactive_layout_yang import InteractiveKMeans, interactive_kmeans \ No newline at end of file diff --git a/vidar/interactive/interactive_layout_yang.py b/vidar/interactive/interactive_layout_yang.py index e69de29..19ce49e 100644 --- a/vidar/interactive/interactive_layout_yang.py +++ b/vidar/interactive/interactive_layout_yang.py @@ -0,0 +1,86 @@ +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.widgets import Button +from matplotlib.widgets import PolygonSelector +from sklearn.cluster import KMeans + + +def colors_from_lbs(lbs, colors=None): + mpl_20 = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', + '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf', + '#3397dc', '#ff993e', '#3fca3f', '#df5152', '#a985ca', + '#ad7165', '#e992ce', '#999999', '#dbdc3c', '#35d8e9'] + + if colors is None: + colors = np.array(mpl_20) + else: + colors = np.array(colors) + lbs = np.array(lbs) % len(colors) + return colors[lbs] + +class InteractiveKMeans: + + def __init__(self, fig, axes, X, n_clusters, **kwargs): + self.fig = fig + self.X = X + self.ax_orig, self.ax_redim = axes + self.n_clusters = n_clusters + + self.kmeans = KMeans(n_clusters=n_clusters, random_state=40, n_init="auto") + self.labels = self.kmeans.fit_predict(X) + self.centers = self.kmeans.cluster_centers_ + + self.plot_figure() # plot the initial figure + + self.selector = PolygonSelector(self.ax_orig, + onselect=self.onselect, + props=dict(color='r', linestyle='', linewidth=3, alpha=0.6, label=f"Component") + ) + self.selector.verts = self.centers + + def onselect(self, verts): + self.centers = np.array(verts) + self.update_cluster() + self.plot_figure() + + def update_cluster(self): + self.kmeans.cluster_centers_ = np.array(self.centers, dtype=np.float64) + self.labels = self.kmeans.predict(self.X) + + def plot_figure(self): + self.ax_orig.clear() + self.ax_orig.scatter(self.X[:, 0], self.X[:, 1], + alpha=0.3, label="samples", + c=colors_from_lbs(self.labels)) + self.ax_orig.scatter(self.centers[:,0], self.centers[:,1], + s=50, c='black', edgecolors='r') + self.ax_orig.set( + aspect="auto", + title="Interactive K-means", + xlabel="first feature", + ylabel="second feature", + ) + + self.ax_redim.clear() + class_name = ['class {0}'.format(i+1) for i in range(len(self.centers))] + + # update labels + counts = [np.sum(self.labels==i) for i in range(len(self.centers))] + + self.ax_redim.bar(class_name, counts, label=class_name, color=colors_from_lbs(range(len(self.centers)))) + self.ax_redim.set( + aspect="auto", + title="Clustering results", + xlabel="Main feature", + ylabel="Number of samples", + ) + self.fig.canvas.draw_idle() + + def show(self): + plt.tight_layout() + plt.show() + +def interactive_kmeans(X, n_clusters=2, **kwargs): + fig, axes = plt.subplots(1, 2, figsize=(12, 6)) + app = InteractiveKMeans(fig, axes, X, n_clusters=n_clusters, **kwargs) + return app \ No newline at end of file diff --git a/vidar/interactive/tests/test_interactive_layout_yang.py b/vidar/interactive/tests/test_interactive_layout_yang.py new file mode 100644 index 0000000..c29e935 --- /dev/null +++ b/vidar/interactive/tests/test_interactive_layout_yang.py @@ -0,0 +1,29 @@ +import pytest +import numpy as np + +class TestInteractiveKMeans: + + def test_module(self): + try: + from vidar import InteractiveKMeans, interactive_kmeans + except ModuleNotFoundError: + pytest.fail("Not found: InteractiveKMeans, interactive_kmeans") + + def test_function_call(self): + from vidar import interactive_kmeans + + X = np.random.random((10,2)) + with pytest.raises(AttributeError): + app = interactive_kmeans(X) + app.xx() + + try: + app = interactive_kmeans(X) + app.update_cluster() + app.plot_figure() + except AttributeError: + pytest.fail( + "Function missing: the function missing in the class InteractiveKMeans" + ) + + \ No newline at end of file From 02249890819c53d0901b2c0457be8e5d5b8337a5 Mon Sep 17 00:00:00 2001 From: Yang Long Date: Fri, 28 Jun 2024 11:23:58 +0100 Subject: [PATCH 2/2] update interactive PCA --- examples/interactive_Kmeans.py | 4 +- examples/interactive_PCA.py | 17 +++++ notebooks/yang/PCA interactive.ipynb | 21 +++++- vidar/__init__.py | 3 +- vidar/interactive/interactive_layout_yang.py | 71 +++++++++++++++++++ .../tests/test_interactive_layout_yang.py | 27 ++++++- 6 files changed, 137 insertions(+), 6 deletions(-) create mode 100644 examples/interactive_PCA.py diff --git a/examples/interactive_Kmeans.py b/examples/interactive_Kmeans.py index 37ed93c..54901bc 100644 --- a/examples/interactive_Kmeans.py +++ b/examples/interactive_Kmeans.py @@ -2,7 +2,7 @@ sys.path.append('..') import numpy as np -from vidar import InteractiveKMeans, interactive_kmeans +from vidar import interactive_kmeans rng = np.random.RandomState(0) n_samples = 1000 @@ -13,6 +13,6 @@ rng.multivariate_normal(mean=[0.3, 1], cov=cov, size=n_samples) ]) -n_clusters = 3 +n_clusters = 2 app = interactive_kmeans(X, n_clusters) app.show() \ No newline at end of file diff --git a/examples/interactive_PCA.py b/examples/interactive_PCA.py new file mode 100644 index 0000000..8e9986a --- /dev/null +++ b/examples/interactive_PCA.py @@ -0,0 +1,17 @@ +import sys +sys.path.append('..') + +import numpy as np +from vidar import interactive_PCA + +rng = np.random.RandomState(0) +n_samples = 1000 +cov = [[1, 0], [0, 1]] +X = np.concatenate([ + rng.multivariate_normal(mean=[-2, 0], cov=cov, size=n_samples), + rng.multivariate_normal(mean=[2, 0], cov=cov, size=n_samples)]) + + +n_components = 1 +app = interactive_PCA(X, n_components) +app.show() \ No newline at end of file diff --git a/notebooks/yang/PCA interactive.ipynb b/notebooks/yang/PCA interactive.ipynb index aaf0b9b..73a5cbd 100644 --- a/notebooks/yang/PCA interactive.ipynb +++ b/notebooks/yang/PCA interactive.ipynb @@ -10,7 +10,7 @@ }, { "cell_type": "code", - "execution_count": 115, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -50,6 +50,14 @@ " ylabel=\"second feature\",\n", ")\n", "\n", + "ax_orig.plot(\n", + " [x_center[0], component[0]],\n", + " [x_center[1], component[1]],\n", + " label=f\"PCA\",\n", + " linewidth=5,\n", + " color=f\"orange\",\n", + " )\n", + "\n", "\n", "def onselect(verts):\n", " _x_center, _total_vector = verts\n", @@ -64,10 +72,19 @@ " xlabel=\"Main feature\",\n", " ylabel=\"Number of samples\",\n", " )\n", + "\n", + " #ax_orig.clear()\n", + " #ax_orig.scatter(X[:, 0], X[:, 1], alpha=0.3, label=\"samples\")\n", + " # ax_orig.plot(\n", + " # [x_center[0], _component[0]],\n", + " # [x_center[1], _component[1]],\n", + " # linewidth=5,\n", + " # color=f\"C{3}\",\n", + " # )\n", " fig.canvas.draw()\n", "\n", "selector = PolygonSelector(ax_orig, onselect=onselect, \n", - " props=dict(color='r', linestyle='-', linewidth=3, alpha=0.6, label=f\"Component\"))\n", + " props=dict(color='r', linestyle='-', linewidth=3, alpha=0.6, label=f\"Interactive\"))\n", "component, x_center = comp_vector\n", "selector.verts = [x_center, x_center + component]\n", "ax_orig.legend()\n", diff --git a/vidar/__init__.py b/vidar/__init__.py index a4dd1d6..de9a4cc 100644 --- a/vidar/__init__.py +++ b/vidar/__init__.py @@ -1 +1,2 @@ -from .interactive.interactive_layout_yang import InteractiveKMeans, interactive_kmeans \ No newline at end of file +from .interactive.interactive_layout_yang import InteractiveKMeans, interactive_kmeans +from .interactive.interactive_layout_yang import InteractivePCA, interactive_PCA \ No newline at end of file diff --git a/vidar/interactive/interactive_layout_yang.py b/vidar/interactive/interactive_layout_yang.py index 19ce49e..40e3478 100644 --- a/vidar/interactive/interactive_layout_yang.py +++ b/vidar/interactive/interactive_layout_yang.py @@ -3,6 +3,7 @@ from matplotlib.widgets import Button from matplotlib.widgets import PolygonSelector from sklearn.cluster import KMeans +from sklearn.decomposition import PCA def colors_from_lbs(lbs, colors=None): @@ -80,7 +81,77 @@ def show(self): plt.tight_layout() plt.show() +class InteractivePCA: + + def __init__(self, fig, axes, X, n_components, **kwargs): + self.fig = fig + self.X = X + self.ax_orig, self.ax_redim = axes + self.n_components = n_components + + self.pca = PCA(n_components=n_components).fit(X) + self.component = self.pca.components_.reshape(-1) + self.x_center = np.mean(X, axis=0) + + self.ax_orig.scatter(X[:, 0], X[:, 1], alpha=0.3, label="samples") + self.ax_orig.set( + aspect="auto", + title="2-dimensional dataset with principal components", + xlabel="first feature", + ylabel="second feature", + ) + + self.ax_orig.plot( + [self.x_center[0], self.component[0]], + [self.x_center[1], self.component[1]], + label=f"PCA", + linewidth=5, + color=f"purple", + ) + + #self.comp_vector = [self.component, selfx_center] + self.selector = PolygonSelector(self.ax_orig, onselect=self.onselect, + props=dict(color='r', linestyle='-', linewidth=3, alpha=0.6, label=f"Interactive")) + + self.selector.verts = [self.x_center, self.x_center + self.component] + + self.ax_orig.legend() + self.ax_redim.hist((self.X @ self.component.T - self.x_center @ self.component.T),50) + self.ax_redim.set( + aspect="auto", + title="1-dimensional dataset after dimension reduction", + xlabel="Main feature", + ylabel="Number of samples", + ) + + def onselect(self, verts): + _x_center, _total_vector = verts + _component = np.array(_total_vector) - np.array(_x_center) + self.ax_redim.clear() + self.ax_redim.hist((self.X @ _component.T - _x_center @ _component.T),50) + self.ax_redim.set( + aspect="auto", + title="1-dimensional dataset after dimension reduction", + xlabel="Main feature", + ylabel="Number of samples", + ) + + self.fig.canvas.draw() + + def show(self): + plt.tight_layout() + plt.show() + + def get_components(self): + return self.selector.verts + def interactive_kmeans(X, n_clusters=2, **kwargs): fig, axes = plt.subplots(1, 2, figsize=(12, 6)) app = InteractiveKMeans(fig, axes, X, n_clusters=n_clusters, **kwargs) + return app + + +def interactive_PCA(X, n_components=1, **kwargs): + fig, axes = plt.subplots(1, 2, figsize=(12, 6)) + app = InteractivePCA(fig, axes, X, n_components=n_components, **kwargs) return app \ No newline at end of file diff --git a/vidar/interactive/tests/test_interactive_layout_yang.py b/vidar/interactive/tests/test_interactive_layout_yang.py index c29e935..a65cd64 100644 --- a/vidar/interactive/tests/test_interactive_layout_yang.py +++ b/vidar/interactive/tests/test_interactive_layout_yang.py @@ -26,4 +26,29 @@ def test_function_call(self): "Function missing: the function missing in the class InteractiveKMeans" ) - \ No newline at end of file +class TestInteractivePCA: + + def test_module(self): + try: + from vidar import InteractivePCA, interactive_PCA + except ModuleNotFoundError: + pytest.fail("Not found: InteractivePCA, interactive_PCA") + + def test_function_call(self): + from vidar import interactive_PCA + + X = np.random.random((10,2)) + with pytest.raises(AttributeError): + app = interactive_PCA(X) + app.xx() + + try: + app = interactive_PCA(X) + app.get_components() + #app.show() + except AttributeError: + pytest.fail( + "Function missing: the function missing in the class InteractivePCA" + ) + + \ No newline at end of file