K-means clustering is an important algorithm for cluster analysis in unsupervised learning. The goal of k-means is to partition n observations into k clusters, with each observation assigned to the cluster whose center (centroid) is closest to it.
The k-means algorithm works as follows:
1. Randomly pick k of the n observations to serve as the initial centroids.
2. For every observation, find the closest centroid using a distance measure such as Euclidean distance, and assign the observation to that centroid's cluster.
3. After all observations are assigned to the k clusters, compute the mean of each axis/dimension over the observations in each cluster. These means form the new centroids of the k clusters (a new centroid may coincide with an existing observation, but in general it is an arbitrary point).
4. Repeat steps 2 and 3 until convergence is reached (i.e., no change in the assignment of observations to clusters) or the maximum number of iterations is reached. A minimal from-scratch sketch of this loop is shown below.
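To make these steps concrete, here is a minimal from-scratch sketch in Python/NumPy (illustrative only, not the implementation used in the examples below; empty clusters are not handled):

import numpy as np

def kmeans(X, k, max_iter=100, seed=0):
    # Step 1: pick k distinct observations as the initial centroids
    rng = np.random.default_rng(seed)
    centroids = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(max_iter):
        # Step 2: assign each observation to its nearest centroid (Euclidean)
        dists = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
        labels = dists.argmin(axis=1)
        # Step 3: recompute each centroid as the mean of its assigned points
        new_centroids = np.array([X[labels == j].mean(axis=0) for j in range(k)])
        # Step 4: stop when centroids (and hence assignments) no longer change
        if np.allclose(new_centroids, centroids):
            break
        centroids = new_centroids
    return labels, centroids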
The following is a nice illustration of the k-means algorithm.
You can download the iris dataset used in this tutorial in CSV format and in Stata data format.
The dataset contains 3 classes of 50 instances each, where each class refers to a type of iris plant: Setosa, Versicolor, and Virginica. There are no missing values in the dataset.
There are 5 attributes in the dataset: sepal length in cm, sepal width in cm, petal length in cm, petal width in cm, and class (species name).
We provide examples of k-means clustering in 3 languages: R, Python, and Stata. Besides k-means clustering, we also give examples of how to do multidimensional scaling to expand your skill set.
rm(list=ls())
# load data
iris=read.csv('https://raw.githubusercontent.com/pydata/pandas/master/pandas/tests/data/iris.csv')
# Add label
iris$Species_label=1
for (i in 1:nrow(iris)) {
  if (iris[i, 5] == "Iris-versicolor") {
    iris[i, 6] = 2
  } else if (iris[i, 5] == "Iris-virginica") {
    iris[i, 6] = 3
  }
}
# Show data
head(iris)
## SepalLength SepalWidth PetalLength PetalWidth Name Species_label
## 1 5.1 3.5 1.4 0.2 Iris-setosa 1
## 2 4.9 3.0 1.4 0.2 Iris-setosa 1
## 3 4.7 3.2 1.3 0.2 Iris-setosa 1
## 4 4.6 3.1 1.5 0.2 Iris-setosa 1
## 5 5.0 3.6 1.4 0.2 Iris-setosa 1
## 6 5.4 3.9 1.7 0.4 Iris-setosa 1
First of all, we load the iris data and add a species label column: Setosa is 1, Versicolor is 2, and Virginica is 3. The first six rows of the data are shown above.
# Kmeans cluster
fit1 = kmeans(iris[,1:4],3)
# Add new label to original data
iris$Kmeans_label = fit1$cluster
head(iris)
## SepalLength SepalWidth PetalLength PetalWidth Name Species_label
## 1 5.1 3.5 1.4 0.2 Iris-setosa 1
## 2 4.9 3.0 1.4 0.2 Iris-setosa 1
## 3 4.7 3.2 1.3 0.2 Iris-setosa 1
## 4 4.6 3.1 1.5 0.2 Iris-setosa 1
## 5 5.0 3.6 1.4 0.2 Iris-setosa 1
## 6 5.4 3.9 1.7 0.4 Iris-setosa 1
## Kmeans_label
## 1 1
## 2 1
## 3 1
## 4 1
## 5 1
## 6 1
We use the kmeans function to do the clustering. We set the number of clusters to 3, and a random set of (distinct) rows of the data is then chosen as the initial centers. We also add a Kmeans_label column with the cluster assignments. For details, see the kmeans help page.
# confusion matrix
table(iris$Name, iris$Kmeans_label)
##
## 1 2 3
## Iris-setosa 50 0 0
## Iris-versicolor 0 2 48
## Iris-virginica 0 36 14
Next, we create the confusion matrix. As the table shows, all setosa flowers are clustered correctly, 2 of the versicolor flowers are classified as virginica, and 14 virginica flowers are classified as versicolor.
plot(iris[,1:4], col = fit1$cluster)
We also want to see the distribution of each group. As the plot shows, while the setosa flowers are separated from the other two species, k-means cannot separate the other two species perfectly.
#mds plot
d = dist(iris[,1:4])   # pairwise Euclidean distances (avoids masking dist())
mds = cmdscale(d)
plot(mds[,1], mds[,2], type="n", xlab = "", ylab = "", axes = FALSE,main = "Iris")
points(mds[,1], mds[,2], pch = iris$Species_label, col = iris$Kmeans_label)
legend("bottomright",c("Kmeans ID=1","Kmeans ID=2","Kmeans ID=3","setosa","versicolor","virginica"), lty=c(1,1,1,NA,NA,NA), pch=c(NA,NA,NA,1:3),lwd=c(2.5,2.5,2.5,1,1,1),col=c(1:3,"black","black","black"),cex=0.7)
Then, we use multidimensional scaling to produce a two-dimensional map. In R, we use the cmdscale function. We use different shapes to represent the species and different colors to distinguish the k-means clusters.
We can see that the setosa flowers are all classified correctly by k-means. However, some versicolor flowers are classified as virginica, and vice versa.
##Load packages
# packages for data manipulation
import pandas as pd
import numpy as np
# packages to run Kmeans and Multidimensional Scaling
from sklearn.manifold import MDS
from sklearn.cluster import KMeans
# packages to generate plots
import matplotlib.pyplot as plt
import matplotlib.lines as mlines #create legend
import matplotlib.patches as mpatches #create legend
import seaborn as sns
Pandas and NumPy are two core packages for data management and analysis.
Scikit-learn (sklearn) provides simple and efficient tools for machine learning; the k-means clustering function is included in this package.
Matplotlib and Seaborn are among the most frequently used packages for visualization.
##Load data
iris=pd.read_csv('https://raw.githubusercontent.com/pydata/pandas/master/pandas/tests/data/iris.csv')
##Add label
iris['Species_label']=0
for i in range(len(iris)):
    if iris.iloc[i,4]=="Iris-versicolor":
        iris.iloc[i,5]=1
    elif iris.iloc[i,4]=="Iris-virginica":
        iris.iloc[i,5]=2
##Show data
iris.head()
SepalLength SepalWidth PetalLength PetalWidth Name Species_label
0 5.1 3.5 1.4 0.2 Iris-setosa 0
1 4.9 3.0 1.4 0.2 Iris-setosa 0
2 4.7 3.2 1.3 0.2 Iris-setosa 0
3 4.6 3.1 1.5 0.2 Iris-setosa 0
4 5.0 3.6 1.4 0.2 Iris-setosa 0
First, we load the iris data as a pandas data frame.
Then we add a species label column for easier comparison later: Setosa is 0, Versicolor is 1, and Virginica is 2.
The first 5 rows of the data frame are shown above.
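As an aside, the same labels can be created without the explicit loop; a one-line sketch using pandas' map:

##Vectorized alternative to the labeling loop above
iris['Species_label']=iris['Name'].map({'Iris-setosa':0,'Iris-versicolor':1,'Iris-virginica':2})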
##Kmeans cluster
kmeans_3=KMeans(n_clusters=3,init='random')
labels_3=kmeans_3.fit_predict(iris.iloc[:,0:4])
#Second method to obtain clusters
dist_3=kmeans_3.fit_transform(iris.iloc[:,0:4])
labels_3_2=kmeans_3.labels_
##Add new label to original data
iris['Kmeans_label']=labels_3
iris.head()
SepalLength SepalWidth PetalLength PetalWidth Name Species_label Kmeans_label
0 5.1 3.5 1.4 0.2 Iris-setosa 0 0
1 4.9 3.0 1.4 0.2 Iris-setosa 0 0
2 4.7 3.2 1.3 0.2 Iris-setosa 0 0
3 4.6 3.1 1.5 0.2 Iris-setosa 0 0
4 5.0 3.6 1.4 0.2 Iris-setosa 0 0
We set the number of clusters to 3 and choose the initial centroids randomly. We can also set the maximum number of iterations, the seed used by the random number generator, and so on via other parameters of the function. In this tutorial, we use the default values for the other parameters.
After setting up the parameters, we use fit_predict to obtain the clustering results. We can also use fit_transform to acquire the cluster-distance space and then read labels_ to obtain the cluster assignments; the code is shown above. Note, however, that calling fit_transform refits the model with a fresh random initialization, so its labels_ need not match the earlier fit_predict results unless the seed is fixed. We use the results of fit_predict to continue the analysis.
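If you need the two approaches to agree exactly, fix the seed; a small sketch (random_state is a standard KMeans parameter, and the check below is purely illustrative):

##With a fixed seed, refitting reproduces the same clustering,
##so labels_ after fit_transform matches fit_predict
km=KMeans(n_clusters=3,init='random',random_state=0)
labels_a=km.fit_predict(iris.iloc[:,0:4])
km.fit_transform(iris.iloc[:,0:4])
assert (km.labels_==labels_a).all()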
Using the crosstab function in pandas, we tabulate the clustering results against the original species. The results are presented as percentages within each row.
pd.crosstab(index=iris['Name'],columns=iris['Kmeans_label'],rownames=['Species'],colnames=['Kmeans']). \
apply(lambda r: r/r.sum()*100, axis=1)
Kmeans 0 1 2
Species
Iris-setosa 100.0 0.0 0.0
Iris-versicolor 0.0 4.0 96.0
Iris-virginica 0.0 72.0 28.0
The parameter ‘index’ contains the flower species, while ‘columns’ refers to the clusters generated by k-means. We can use ‘rownames’ and ‘colnames’ to rename the row and column titles.
The pairplot function in the seaborn package is a good way to show the distribution of each group. It plots pairwise relationships in a dataset: each variable is shared across the y-axes of one row and the x-axes of one column, while the diagonal shows the univariate distribution of each variable.
The parameter ‘hue’ is the variable that maps plot aspects to different colors, and ‘palette’ is the set of colors used for the mapping.
##Pairplot
sns.pairplot(iris.iloc[:,[0,1,2,3,6]],hue='Kmeans_label',palette="husl")
plt.show()
As we can see from the pairplot above, the group represented by the pink points is clearly separated from the other two groups, while there is some overlap between the green and the blue points.
##Multidimensional scaling
mds=MDS(n_components=2,dissimilarity='euclidean',n_jobs=1)
pos=mds.fit(iris.iloc[:,0:4]).embedding_
##embedding_ is already a numpy array; this cast is just a safeguard
pos=np.array(pos)
We use the MDS function in sklearn.manifold to do multidimensional scaling, using Euclidean distances as the measure of dissimilarity. After setting up the parameters, we call fit and read embedding_ to obtain the coordinates.
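Equivalently, fit_transform fits the model and returns the embedded coordinates in one step:

##Equivalent shortcut: fit_transform fits and returns the embedding directly
pos=mds.fit_transform(iris.iloc[:,0:4])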
##Define different groups
kmeans_id=np.array(iris['Kmeans_label'])
species_id=np.array(iris['Species_label'])
group00=pos[np.intersect1d(np.where(kmeans_id==0),np.where(species_id==0))]
group01=pos[np.intersect1d(np.where(kmeans_id==0),np.where(species_id==1))]
group02=pos[np.intersect1d(np.where(kmeans_id==0),np.where(species_id==2))]
group10=pos[np.intersect1d(np.where(kmeans_id==1),np.where(species_id==0))]
group11=pos[np.intersect1d(np.where(kmeans_id==1),np.where(species_id==1))]
group12=pos[np.intersect1d(np.where(kmeans_id==1),np.where(species_id==2))]
group20=pos[np.intersect1d(np.where(kmeans_id==2),np.where(species_id==0))]
group21=pos[np.intersect1d(np.where(kmeans_id==2),np.where(species_id==1))]
group22=pos[np.intersect1d(np.where(kmeans_id==2),np.where(species_id==2))]
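As an aside, each of the nine selections above can be written more compactly with a joint boolean mask; a sketch for one group:

##Equivalent, more compact selection using a joint boolean mask
##e.g. observations placed in k-means cluster 0 whose true species label is 1
group01=pos[(kmeans_id==0)&(species_id==1)]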
##Plot to show different groups
plt.figure(figsize=(10,8))
plt.xlim([-3,3])
plt.ylim([-5,5])
plt.scatter(group00[:,0],group00[:,1],c='r',marker='^',)
plt.scatter(group01[:,0],group01[:,1],c='r',marker='o',)
plt.scatter(group02[:,0],group02[:,1],c='r',marker='*',)
plt.scatter(group10[:,0],group10[:,1],c='g',marker='^',)
plt.scatter(group11[:,0],group11[:,1],c='g',marker='o',)
plt.scatter(group12[:,0],group12[:,1],c='g',marker='*',)
plt.scatter(group20[:,0],group20[:,1],c='b',marker='^',)
plt.scatter(group21[:,0],group21[:,1],c='b',marker='o',)
plt.scatter(group22[:,0],group22[:,1],c='b',marker='*',)
#Plot legend
red_patch = mpatches.Patch(color='red', label='Kmeans Id=0')
green_patch = mpatches.Patch(color='green', label='Kmeans Id=1')
blue_patch = mpatches.Patch(color='blue', label='Kmeans Id=2')
triangle = mlines.Line2D([],[],ls='None',color='black',marker='^', label='Setosa')
circle = mlines.Line2D([],[],ls='None',color='black',marker='o', label='Versicolor')
star = mlines.Line2D([],[],ls='None',color='black',marker='*',markersize=10, label='Virginica')
plt.legend(handles=[red_patch,green_patch,blue_patch,triangle,circle,star])
plt.show()
We use Matplotlib pyplot to draw the scatter plot. Although all Setosa flowers are classified as one cluster, there are some misclassifications for Versicolor and Virginica flowers.
Version information: the code on this page was tested in Stata 14.
You can download the iris data set in Stata data format here.
use "~\iris.dta", clear
cluster kmeans is the Stata command for k-means clustering. The basic syntax looks like this:
cluster kmeans [varlist], k(#) [options]
cluster kmeans Sepal_Length Sepal_Width Petal_Length Petal_Width, k(3) measure(L2) start(krandom)
Let’s see how well the clusters from k-means resemble the actual species of iris flowers using the tabulate command in Stata.
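The tabulation can be produced as follows (assuming the species variable is named Species, as it is later on this page; cluster kmeans saves its cluster assignments in the variable _clus_1):
tabulate Species _clus_1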
We can see that the clusters from k-means match the setosa and versicolor species well, but not so well for virginica.
Unlike ggplot, Stata does not allow directly changing the shape of points in a scatter plot matrix, hence the shapes need to be defined manually.
// generate label for different clusters
gen label = "o" if _clus_1 == 1
replace label = "+" if _clus_1 == 2
replace label = "*" if _clus_1 == 3
// scatter matrix with prespecified shape
graph matrix Sepal_Length Sepal_Width Petal_Length Petal_Width, half ms(none) mla(label)
Note: adding “half” to the options will draw only the lower triangular portion of the scatter matrix.
From the scatter matrix, we can see that one cluster (marked with *) is clearly separated from the other two clusters, while the other two clusters overlap somewhat, indicating that k-means may have difficulty distinguishing between those species of flowers.
We can use the function matrix dissimilarity in Stata to create a dissimilarity matrix between the observations (rows) of the dataset. You may guess that a similarity measure can be created in a similar fashion.
The syntax of matrix dissimilarity looks like this:
matrix dissimilarity matname = [varlist]
matrix dissimilarity Distance = Sepal* Petal*
The code above creates a dissimilarity matrix called Distance, using the variables whose names start with “Sepal” or “Petal”, so the lengths and widths of both sepal and petal are used.
To see the dissimilarity matrix we just created, the following code can be used.
mat list Distance
The Stata command for multidimensional scaling from a dissimilarity matrix is mdsmat.
The code below uses the matrix Distance for multidimensional scaling; the config option makes Stata print the MDS coordinates of each observation.
mdsmat Distance, config
In fact, we do not need to print the coordinates ourselves: Stata stores the MDS configuration behind the scenes when mdsmat runs, and we can refer to it using e(Y). Printing the dissimilarity matrix and the MDS configuration above was for your own edification :)
mat Coor = e(Y)
The code above extracts the MDS coordinate matrix from e(Y) and assigns it to a new name, Coor.
svmat Coor, name(MDS)
The code above appends the matrix Coor to the existing dataset; essentially, it adds 2 columns (MDS1 and MDS2) that represent the MDS coordinates of each observation.
Stata does not allow a scatter plot in which different grouping variables control the color and the shape of points, hence we generate two separate scatter plots: one grouped by cluster and the other by species. Two methods of generating scatter plots by groups are demonstrated here.
graph twoway (scatter MDS2 MDS1 if _clus_1 == 1) (scatter MDS2 MDS1 if _clus_1 == 2) (scatter MDS2 MDS1 if _clus_1 == 3), legend(label(1 Cluster1) label(2 Cluster2) label(3 Cluster3))
Scatter plot with clusters
// Separate variable by grouping variable
separate MDS2, by(Species)
The code below creates a scatter plot using MDS1 and the 3 groups of MDS2; the msymbol option allows customizing the shape of each group.
graph twoway scatter MDS21 MDS22 MDS23 MDS1, msymbol(h S D) legend(label(1 setosa) label(2 versicolor) label(3 virginica))
Scatter plot with actual species
Comparing the scatter plot of the actual species with the k-means result, it is clear that k-means produces non-overlapping clusters; this can be problematic when there is a region where observations from different groups are mixed together.
The major feature of k-means is that it produces linearly separable, non-overlapping clusters: because each observation is assigned to its nearest centroid, the boundary between two adjacent clusters is the perpendicular bisector of the segment joining their centroids, which is a straight line (a hyperplane in higher dimensions).
Let's try to draw a line to separate the 3 species.
We can see that a curve is needed to separate the versicolor and virginica groups. In k-means, however, the boundary between clusters must be a straight line, which leads to the following result. This is why k-means cannot distinguish between these two species, and it is a common problem for k-means when several groups are similar in some attributes or variables.
K-means is an optimization problem solved with iterative numerical methods. If it converges, it converges to a local optimum; however, there may be several distinct local optima. Therefore, depending on the initial selection of the k centroids (which are drawn randomly), the k-means algorithm may not converge to a unique solution.
We ran the k-means algorithm on the same dataset many times. Although most of the time it converges to the result on the left, in some cases it converges to the result on the right.
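A minimal sketch of this experiment in Python (reusing the iris data frame from the Python section above; n_init=1 forces a single random start per run, so different local optima become visible):

from sklearn.cluster import KMeans

# Run k-means from several random starts and compare the within-cluster
# sum of squares (inertia): different values mean different local optima.
for seed in range(10):
    km=KMeans(n_clusters=3,init='random',n_init=1,random_state=seed)
    km.fit(iris.iloc[:,0:4])
    print(seed,round(km.inertia_,2))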