Skip to content

Commit

Permalink
Merge pull request #14 from longyangking/main
Browse files Browse the repository at this point in the history
add K-means interactive
  • Loading branch information
longyangking authored Jun 27, 2024
2 parents 6f2d6d8 + bbee437 commit 8873d5b
Showing 1 changed file with 166 additions and 0 deletions.
166 changes: 166 additions & 0 deletions notebooks/yang/Kmeans interactive.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# K-means interactive\n",
"\n",
"> Yang"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib qt\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from matplotlib.widgets import Button\n",
"from matplotlib.widgets import PolygonSelector\n",
"from sklearn.cluster import KMeans\n",
"\n",
"def colors_from_lbs(lbs, colors=None):\n",
" mpl_20 = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',\n",
" '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',\n",
" '#3397dc', '#ff993e', '#3fca3f', '#df5152', '#a985ca',\n",
" '#ad7165', '#e992ce', '#999999', '#dbdc3c', '#35d8e9']\n",
" \n",
" if colors is None:\n",
" colors = np.array(mpl_20)\n",
" else:\n",
" colors = np.array(colors)\n",
" lbs = np.array(lbs) % len(colors)\n",
" return colors[lbs]\n",
"\n",
"rng = np.random.RandomState(0)\n",
"n_samples = 1000\n",
"cov = [[1, 0], [0, 1]]\n",
"X = np.concatenate([\n",
" rng.multivariate_normal(mean=[-2, 0], cov=cov, size=n_samples), \n",
" rng.multivariate_normal(mean=[2, 0], cov=cov, size=n_samples)])\n",
"\n",
"kmeans = KMeans(n_clusters=2, random_state=0, n_init=\"auto\")\n",
"labels = kmeans.fit_predict(X)\n",
"\n",
"#kmeans.cluster_centers_ = np.array([[1,2],[2,1]],dtype=kmeans.cluster_centers_.dtype)\n",
"#centers = kmeans.predict(X)\n",
"centers = kmeans.cluster_centers_\n",
"\n",
"fig, (ax_orig, ax_redim) = plt.subplots(1, 2, figsize=(12, 6))\n",
"\n",
"def plot_figure(axe_list, X, centers):\n",
" ax_orig, ax_redim = axe_list\n",
"\n",
" #labels = [np.argmax([np.linalg.norm(x-center) for center in centers]) for x in X]\n",
" kmeans.cluster_centers_ = np.array(centers, dtype=np.float64)\n",
" labels = kmeans.predict(X) \n",
"\n",
" ax_orig.clear()\n",
" ax_orig.scatter(X[:, 0], X[:, 1], alpha=0.3, label=\"samples\", c=colors_from_lbs(labels))\n",
" ax_orig.scatter(centers[:,0], centers[:,1], s=50, c='black', edgecolors='r')\n",
" ax_orig.set(\n",
" aspect=\"auto\", \n",
" title=\"K-means\",\n",
" xlabel=\"first feature\",\n",
" ylabel=\"second feature\",\n",
" )\n",
"\n",
" ax_redim.clear()\n",
" class_name = ['class {0}'.format(i) for i in range(len(centers))]\n",
"\n",
" # update labels\n",
" counts = [np.sum(labels==i) for i in range(len(centers))]\n",
" \n",
"\n",
" ax_redim.bar(class_name, counts, label=class_name)\n",
" #ax_redim.bar((X @ _component.T - _x_center @ _component.T),50)\n",
" ax_redim.set(\n",
" aspect=\"auto\",\n",
" title=\"Clustering results\",\n",
" xlabel=\"Main feature\",\n",
" ylabel=\"Number of samples\",\n",
" )\n",
" fig.canvas.draw_idle()\n",
"\n",
"plot_figure((ax_orig, ax_redim), X, centers)\n",
"\n",
"def onselect(verts):\n",
" centers = np.array(verts)\n",
" plot_figure((ax_orig, ax_redim), X, centers)\n",
"\n",
"selector = PolygonSelector(ax_orig, onselect=onselect, \n",
" props=dict(color='r', linestyle='', linewidth=3, alpha=0.6, label=f\"Component\"))\n",
"selector.verts = centers\n",
"\n",
"\n",
"# ax_redim.hist((X @ component.T - x_center @ component.T),50)\n",
"# ax_redim.set(\n",
"# aspect=\"auto\",\n",
"# title=\"1-dimensional dataset after dimension reduction\",\n",
"# xlabel=\"Main feature\",\n",
"# ylabel=\"Number of samples\",\n",
"# )\n",
"#_asp = np.diff(ax_orig.get_ylim())[0] / np.diff(ax_orig.get_xlim())[0]\n",
"#ax_redim.set_aspect(_asp)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"# pca = PCA(n_components=1).fit(X)\n",
"# component = pca.components_.reshape(-1)\n",
"\n",
"# # print(pca.components_)\n",
"# # print(pca.explained_variance_)\n",
"# # print(list(zip(pca.components_, pca.explained_variance_)))\n",
"\n",
"# # fig, (ax_orig, ax_redim) = plt.subplots(1, 2, figsize=(12, 6))\n",
"# # ax_orig.scatter(X[:, 0], X[:, 1], alpha=0.3, label=\"samples\")\n",
"# # x_center = np.mean(X, axis=0)\n",
"\n",
"# comp_vector = [component, x_center]\n",
"\n",
"# ax_orig.set(\n",
"# aspect=\"auto\", \n",
"# title=\"2-dimensional dataset with principal components\",\n",
"# xlabel=\"first feature\",\n",
"# ylabel=\"second feature\",\n",
"# )\n",
"\n",
"\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.19"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit 8873d5b

Please sign in to comment.