.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "_gallery/machine_learning/kmeans.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download__gallery_machine_learning_kmeans.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr__gallery_machine_learning_kmeans.py:


K-Means Clustering of MNIST Dataset
===================================

Example showing how to perform K-Means clustering on the 8x8 handwritten
digits dataset bundled with scikit-learn (``load_digits``), a small MNIST-like
dataset.

.. GENERATED FROM PYTHON SOURCE LINES 7-118


.. image-sg:: /_gallery/machine_learning/images/sphx_glr_kmeans_001.webp
   :alt: kmeans
   :srcset: /_gallery/machine_learning/images/sphx_glr_kmeans_001.webp
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /opt/hostedtoolcache/Python/3.12.9/x64/lib/python3.12/site-packages/pygfx/objects/_ruler.py:267: RuntimeWarning: divide by zero encountered in divide
      screen_full = (ndc_full[:, :2] / ndc_full[:, 3:4]) * half_canvas_size
    /opt/hostedtoolcache/Python/3.12.9/x64/lib/python3.12/site-packages/pygfx/objects/_ruler.py:267: RuntimeWarning: invalid value encountered in divide
      screen_full = (ndc_full[:, :2] / ndc_full[:, 3:4]) * half_canvas_size
    /opt/hostedtoolcache/Python/3.12.9/x64/lib/python3.12/site-packages/pygfx/objects/_ruler.py:279: RuntimeWarning: invalid value encountered in divide
      screen_sel = (ndc_sel[:, :2] / ndc_sel[:, 3:4]) * half_canvas_size
    /home/runner/work/fastplotlib/fastplotlib/fastplotlib/graphics/_features/_base.py:18: UserWarning: casting float64 array to float32
      warn(f"casting {array.dtype} array to float32")




|

.. code-block:: Python


    # test_example = false

    import fastplotlib as fpl
    import numpy as np
    from sklearn.datasets import load_digits
    from sklearn.cluster import KMeans
    from sklearn.decomposition import PCA

    # load the data
    mnist = load_digits()

    # get the data and labels
    data = mnist['data']  # (1797, 64)
    labels = mnist['target']  # (1797,)

    # visualize the first 5 digits
    # NOTE: this is just to give a sense of the dataset if you are unfamiliar,
    # the more interesting visualization is below :D
    fig_data = fpl.Figure(shape=(1, 5), size=(900, 300))

    # iterate through each subplot
    for i, subplot in enumerate(fig_data):
        # reshape each image to (8, 8)
        subplot.add_image(data[i].reshape(8, 8), cmap="gray", interpolation="linear")
        # add the label as a title
        subplot.set_title(f"Label: {labels[i]}")
        # turn off the axes and toolbar
        subplot.axes.visible = False
        subplot.toolbar = False

    fig_data.show()

    # project the data from 64 dimensions down to the number of unique digits
    n_digits = len(np.unique(labels))  # 10

    reduced_data = PCA(n_components=n_digits).fit_transform(data)  # (1797, 10)

    # perform K-Means clustering, keeping the best of 4 runs
    kmeans = KMeans(n_clusters=n_digits, n_init=4)
    # fit the lower-dimensional data
    kmeans.fit(reduced_data)

    # get the centroids (centers of the clusters)
    centroids = kmeans.cluster_centers_

    # plot the kmeans result and corresponding original image
    figure = fpl.Figure(
        shape=(1, 2),
        size=(700, 400),
        cameras=["3d", "2d"],
        controller_types=[["fly", "panzoom"]]
    )

    # hide the axes in both subplots
    figure[0, 0].axes.visible = False
    figure[0, 1].axes.visible = False

    figure[0, 0].set_title("K-means clustering of PCA-reduced data")

    # plot the centroids (first three principal components only)
    figure[0, 0].add_scatter(
        data=np.vstack([centroids[:, 0], centroids[:, 1], centroids[:, 2]]).T,
        colors="white",
        sizes=15
    )
    # plot the down-projected data (first three principal components only)
    digit_scatter = figure[0, 0].add_scatter(
        data=np.vstack([reduced_data[:, 0], reduced_data[:, 1], reduced_data[:, 2]]).T,
        sizes=5,
        cmap="tab10",  # use a qualitative cmap
        cmap_transform=kmeans.labels_,  # color by the predicted cluster
    )

    # initial index
    ix = 0

    # plot the initial image
    digit_img = figure[0, 1].add_image(
        data=data[ix].reshape(8, 8),
        cmap="gray",
        name="digit",
        interpolation="linear"
    )

    # change the color and size of the initial selected data point
    digit_scatter.colors[ix] = "magenta"
    digit_scatter.sizes[ix] = 10

    # define event handler to update the selected data point
    @digit_scatter.add_event_handler("pointer_enter")
    def update(ev):
        # reset colors and sizes
        digit_scatter.cmap = "tab10"
        digit_scatter.sizes = 5

        # update with new selection
        ix = ev.pick_info["vertex_index"]

        digit_scatter.colors[ix] = "magenta"
        digit_scatter.sizes[ix] = 10

        # update the digit image
        figure[0, 1]["digit"].data = data[ix].reshape(8, 8)

    figure.show()

    # NOTE: `if __name__ == "__main__"` is NOT how to use fastplotlib interactively
    # please see our docs for using fastplotlib interactively in ipython and jupyter
    if __name__ == "__main__":
        print(__doc__)
        fpl.loop.run()


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.828 seconds)


.. _sphx_glr_download__gallery_machine_learning_kmeans.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: kmeans.ipynb <kmeans.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: kmeans.py <kmeans.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: kmeans.zip <kmeans.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_