def colors_from_lbs(lbs, colors=None):
    """Map integer cluster labels to colour strings.

    Labels wrap around the palette (``label % len(palette)``), so labels
    larger than the palette, and negative labels, are still valid.

    Parameters
    ----------
    lbs : array-like of int
        Cluster labels; may be empty.
    colors : sequence of str, optional
        Palette to index into. Defaults to the 20-colour list below.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``lbs`` holding one palette entry per
        label.

    Raises
    ------
    ValueError
        If ``colors`` is an empty sequence.
    """
    mpl_20 = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
              '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',
              '#3397dc', '#ff993e', '#3fca3f', '#df5152', '#a985ca',
              '#ad7165', '#e992ce', '#999999', '#dbdc3c', '#35d8e9']

    palette = np.array(mpl_20 if colors is None else colors)
    if len(palette) == 0:
        raise ValueError("colors must contain at least one entry")

    labels = np.asarray(lbs)
    if labels.size == 0:
        # An empty sequence becomes a float64 array, which NumPy refuses to
        # use as an index; force an integer dtype so empty input yields an
        # empty result instead of an IndexError.
        labels = labels.astype(np.intp)

    return palette[labels % len(palette)]
def _nearest_center_labels(X, centers):
    """Return, for every row of ``X``, the index of the closest row of ``centers``."""
    # Broadcast to (n_samples, n_centers, n_features) pairwise differences,
    # then pick the centre with the smallest Euclidean distance.
    offsets = X[:, np.newaxis, :] - centers[np.newaxis, :, :]
    return np.linalg.norm(offsets, axis=2).argmin(axis=1)


def plot_figure(axe_list, X, centers):
    """Redraw the cluster scatter plot and the per-cluster sample counts.

    Parameters
    ----------
    axe_list : sequence of two matplotlib Axes
        ``(ax_orig, ax_redim)``: the scatter axes and the bar-chart axes.
    X : array-like
        Samples, one per row; columns 0 and 1 are plotted.
    centers : array-like
        Cluster centres, one per row; columns 0 and 1 are plotted.

    Each sample is assigned to its nearest centre. The assignment is computed
    directly with NumPy, so redrawing never modifies the fitted estimator.
    """
    ax_orig, ax_redim = axe_list

    # Accept plain lists (e.g. raw selector vertices) as well as arrays.
    X = np.asarray(X, dtype=np.float64)
    centers = np.asarray(centers, dtype=np.float64)
    labels = _nearest_center_labels(X, centers)

    # NOTE(review): Axes.clear() drops every artist attached to ax_orig,
    # presumably including the PolygonSelector handles; the explicit scatter
    # of the centres below keeps them visible. Confirm on the target
    # matplotlib version.
    ax_orig.clear()
    ax_orig.scatter(X[:, 0], X[:, 1], alpha=0.3, label="samples", c=colors_from_lbs(labels))
    ax_orig.scatter(centers[:, 0], centers[:, 1], s=50, c='black', edgecolors='r')
    ax_orig.set(
        aspect="auto",
        title="K-means",
        xlabel="first feature",
        ylabel="second feature",
    )

    ax_redim.clear()
    n_centers = len(centers)
    class_name = ['class {0}'.format(i) for i in range(n_centers)]
    # Number of samples assigned to each centre (0 for centres that win nothing).
    counts = [int(np.sum(labels == i)) for i in range(n_centers)]

    ax_redim.bar(class_name, counts, label=class_name)
    # NOTE(review): the x axis shows cluster names; "Main feature" looks like
    # a leftover label. Left unchanged pending confirmation.
    ax_redim.set(
        aspect="auto",
        title="Clustering results",
        xlabel="Main feature",
        ylabel="Number of samples",
    )
    # Use the axes' own figure so the function does not rely on a global `fig`.
    ax_orig.figure.canvas.draw_idle()


plot_figure((ax_orig, ax_redim), X, centers)


def onselect(verts):
    """PolygonSelector callback: use the polygon vertices as cluster centres and redraw."""
    plot_figure((ax_orig, ax_redim), X, np.array(verts))


selector = PolygonSelector(
    ax_orig,
    onselect=onselect,
    props=dict(color='r', linestyle='', linewidth=3, alpha=0.6, label="Component"),
)
# Start the polygon on the fitted centres so they can be dragged immediately.
selector.verts = centers

plt.tight_layout()
plt.show()
+ "\n", + "# pca = PCA(n_components=1).fit(X)\n", + "# component = pca.components_.reshape(-1)\n", + "\n", + "# # print(pca.components_)\n", + "# # print(pca.explained_variance_)\n", + "# # print(list(zip(pca.components_, pca.explained_variance_)))\n", + "\n", + "# # fig, (ax_orig, ax_redim) = plt.subplots(1, 2, figsize=(12, 6))\n", + "# # ax_orig.scatter(X[:, 0], X[:, 1], alpha=0.3, label=\"samples\")\n", + "# # x_center = np.mean(X, axis=0)\n", + "\n", + "# comp_vector = [component, x_center]\n", + "\n", + "# ax_orig.set(\n", + "# aspect=\"auto\", \n", + "# title=\"2-dimensional dataset with principal components\",\n", + "# xlabel=\"first feature\",\n", + "# ylabel=\"second feature\",\n", + "# )\n", + "\n", + "\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.19" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}