Credit: xkcd (https://xkcd.com/388/)

Introduction to k-means clustering

K-means clustering is an important algorithm for cluster analysis in unsupervised learning. The goal of k-means is to partition n observations into k clusters, with each observation assigned to the cluster whose center (centroid) is closest to it.
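
Formally, k-means seeks the partition S_1, …, S_k that minimizes the within-cluster sum of squares (the standard objective behind the steps below):

$$\min_{S_1,\dots,S_k} \; \sum_{i=1}^{k} \sum_{x \in S_i} \lVert x - \mu_i \rVert^2, \qquad \mu_i = \frac{1}{\lvert S_i \rvert} \sum_{x \in S_i} x$$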

The k-means algorithm for obtaining k clusters works as follows:

  1. Randomly pick k of the n observations; these serve as the initial centroids.

  2. For every other observation, find the closest centroid under a distance measure such as Euclidean distance, and assign the observation to that centroid's cluster.

  3. After all observations are assigned to the k clusters, compute the mean of each axis/dimension over the observations in each cluster. These means become the new centroids (a new centroid may coincide with an existing observation or be an arbitrary point).

  4. Repeat steps 2 and 3 until convergence is reached (i.e., no change in the assignment of observations to clusters) or until the maximum number of iterations is reached.
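
To make the steps concrete, here is a minimal from-scratch sketch in Python with NumPy (our own illustration; the tutorial itself uses library implementations):

import numpy as np

def kmeans_sketch(X, k, max_iter=100, seed=0):
    rng = np.random.default_rng(seed)
    # Step 1: randomly pick k observations as the initial centroids
    centroids = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(max_iter):
        # Step 2: assign every observation to its nearest centroid (Euclidean distance)
        dists = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
        labels = dists.argmin(axis=1)
        # Step 3: new centroid = per-dimension mean of each cluster
        # (for simplicity this sketch assumes no cluster goes empty)
        centroids_new = np.array([X[labels == j].mean(axis=0) for j in range(k)])
        # Step 4: stop once the assignments, and hence the centroids, no longer change
        if np.allclose(centroids_new, centroids):
            break
        centroids = centroids_new
    return labels, centroids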

The following is a nice illustration of the k-means algorithm.

Picture credit: Learn by Marketing (http://www.learnbymarketing.com)

Data description

You can download the iris dataset used in this tutorial in CSV format or in Stata data format.

The dataset contains 3 classes of 50 instances each, where each class refers to a type of iris plant. The three types of plants are Setosa, Versicolor, and Virginica. There are no missing values in the dataset.

There are 5 attributes in the dataset: sepal length in cm, sepal width in cm, petal length in cm, petal width in cm, and class (name of species).

Tutorial

We provide examples of k-means clustering in 3 languages: R, Python, and Stata. Besides k-means clustering, we also give examples of how to do multidimensional scaling to expand your skill set.

R

Preliminary preparation

rm(list = ls())
# Load data
iris = read.csv('https://raw.githubusercontent.com/pydata/pandas/master/pandas/tests/data/iris.csv')
# Add a numeric species label: setosa = 1, versicolor = 2, virginica = 3
iris$Species_label = 1
iris$Species_label[iris$Name == "Iris-versicolor"] = 2
iris$Species_label[iris$Name == "Iris-virginica"] = 3
# Show data
head(iris)
##   SepalLength SepalWidth PetalLength PetalWidth        Name Species_label
## 1         5.1        3.5         1.4        0.2 Iris-setosa             1
## 2         4.9        3.0         1.4        0.2 Iris-setosa             1
## 3         4.7        3.2         1.3        0.2 Iris-setosa             1
## 4         4.6        3.1         1.5        0.2 Iris-setosa             1
## 5         5.0        3.6         1.4        0.2 Iris-setosa             1
## 6         5.4        3.9         1.7        0.4 Iris-setosa             1

First of all, we load the iris data and add a species label column: Setosa is coded 1, Versicolor 2, and Virginica 3. The first six rows of the data are shown above.

K-means clustering

# Kmeans cluster
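# Note: kmeans draws random initial centers, so the cluster numbering can vary
# between runs; calling set.seed() beforehand makes the result reproducible.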
fit1 = kmeans(iris[,1:4],3)

# Add new label to original data
iris$Kmeans_label = fit1$cluster
head(iris)
##   SepalLength SepalWidth PetalLength PetalWidth        Name Species_label
## 1         5.1        3.5         1.4        0.2 Iris-setosa             1
## 2         4.9        3.0         1.4        0.2 Iris-setosa             1
## 3         4.7        3.2         1.3        0.2 Iris-setosa             1
## 4         4.6        3.1         1.5        0.2 Iris-setosa             1
## 5         5.0        3.6         1.4        0.2 Iris-setosa             1
## 6         5.4        3.9         1.7        0.4 Iris-setosa             1
##   Kmeans_label
## 1            1
## 2            1
## 3            1
## 4            1
## 5            1
## 6            1

We use the kmeans function to do the clustering. We set the number of clusters to 3; a random set of (distinct) rows of the data is then chosen as the initial centers. We also add a Kmeans_label column with each observation's assigned cluster. For details, see the kmeans help page.

Evaluate performance

Confusion matrix

# confusion matrix
table(iris$Name, iris$Kmeans_label)
##                  
##                    1  2  3
##   Iris-setosa     50  0  0
##   Iris-versicolor  0  2 48
##   Iris-virginica   0 36 14

First of all, we create the confusion matrix. As the table shows, all setosa flowers are classified correctly, 2 of versicolor flowers are classified as virginica, and 14 virginica flowers are classified as versicolor.

Visualization with scatter plot

plot(iris[,1:4], col = fit1$cluster)

Secondly, we want to see the distribution of each group. As the plot shows, while the setosa flowers are separated from the other 2 species, k-means cannot separate the other two species perfectly.

Visualization with multidimensional scaling

#mds plot
dist = dist(iris[,1:4])
mds = cmdscale(dist)
plot(mds[,1], mds[,2],  type="n", xlab = "", ylab = "", axes = FALSE,main = "Iris")
points(mds[,1], mds[,2], pch = iris$Species_label, col = iris$Kmeans_label)
legend("bottomright",c("Kmeans ID=1","Kmeans ID=2","Kmeans ID=3","setosa","versicolor","virginica"), lty=c(1,1,1,NA,NA,NA), pch=c(NA,NA,NA,1:3),lwd=c(2.5,2.5,2.5,1,1,1),col=c(1:3,"black","black","black"),cex=0.7)

Then, we use multidimensional scaling to produce a two-dimensional map; in R, this is done with the cmdscale function. We use different shapes to represent the species and different colors to distinguish the clusters.

We can see that the setosa flowers are all classified correctly by k-means. However, some versicolor flowers are classified as virginica, and vice versa.

Python

Preliminary preparation

##Load packages
 # packages for data manipulation
import pandas as pd
import numpy as np
 # packages to run Kmeans and Multidimensional Scaling
from sklearn.manifold import MDS
from sklearn.cluster import KMeans
 # packages to generate plots
import matplotlib.pyplot as plt
import matplotlib.lines as mlines #create legend
import matplotlib.patches as mpatches  #create legend
import seaborn as sns

  1. Pandas and Numpy are two core packages that are useful in data management and analysis.

  2. Sklearn is a simple and efficient tool for machine learning. The K-means clustering function is included in this module.

  3. Matplotlib and Seaborn are among the most frequently used modules when visualization is needed.

##Load data
iris=pd.read_csv('https://raw.githubusercontent.com/pydata/pandas/master/pandas/tests/data/iris.csv')
##Add label: setosa = 0, versicolor = 1, virginica = 2
iris['Species_label']=0
iris.loc[iris['Name']=="Iris-versicolor",'Species_label']=1
iris.loc[iris['Name']=="Iris-virginica",'Species_label']=2
##Show data
iris.head()

  SepalLength   SepalWidth  PetalLength PetalWidth  Name    Species_label
0   5.1         3.5         1.4         0.2      Iris-setosa     0
1   4.9         3.0         1.4         0.2      Iris-setosa     0
2   4.7         3.2         1.3         0.2      Iris-setosa     0
3   4.6         3.1         1.5         0.2      Iris-setosa     0
4   5.0         3.6         1.4         0.2      Iris-setosa     0

Firstly, load the iris data as a pandas data frame.

Then add a species label column for easier comparison: Setosa is 0, Versicolor 1, and Virginica 2.

The first 5 rows of the data frame are shown above.

K-means clustering

##Kmeans cluster
kmeans_3=KMeans(n_clusters=3,init='random')
labels_3=kmeans_3.fit_predict(iris.iloc[:,0:4])
 #Second method to obtain clusters
dist_3=kmeans_3.fit_transform(iris.iloc[:,0:4])
labels_3_2=kmeans_3.labels_

##Add new label to original data
iris['Kmeans_label']=labels_3
iris.head()

 SepalLength    SepalWidth  PetalLength PetalWidth  Name    Species_label Kmeans_label
0   5.1         3.5         1.4         0.2      Iris-setosa     0       0
1   4.9         3.0         1.4         0.2      Iris-setosa     0       0
2   4.7         3.2         1.3         0.2      Iris-setosa     0       0
3   4.6         3.1         1.5         0.2      Iris-setosa     0       0
4   5.0         3.6         1.4         0.2      Iris-setosa     0       0

We set the number of clusters to 3 and choose the initial centroids randomly. We can also set the maximum number of iterations, the seed used by the random number generator, etc. through other parameters of the function. In this tutorial, we use the default values for the other parameters.

After setting up the parameters, use fit_predict to obtain the clustering results. We can also use fit_transform to acquire the cluster-distance space and then read labels_ to get the same results as fit_predict. Both approaches are shown above; we use the fit_predict results to continue the analysis.
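
For example, a sketch of pinning down those optional parameters (the specific values here are our own, not part of the tutorial):

 #Hypothetical settings: cap the iterations and fix the seed for repeatable runs
kmeans_fixed=KMeans(n_clusters=3,init='random',max_iter=300,random_state=42)
labels_fixed=kmeans_fixed.fit_predict(iris.iloc[:,0:4])
print(kmeans_fixed.inertia_)  #within-cluster sum of squares of the final solution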

Evaluate performance

Confusion matrix

Use the crosstab function in pandas to sum up the clustering results and compare them with the original species. The results are presented as percentages of each row.

pd.crosstab(index=iris['Name'],columns=iris['Kmeans_label'],rownames=['Species'],colnames=['Kmeans']). \
     apply(lambda r: r/r.sum()*100, axis=1)
     
Kmeans             0      1      2
Species         
Iris-setosa     100.0   0.0    0.0
Iris-versicolor   0.0   4.0   96.0
Iris-virginica    0.0  72.0   28.0

The parameter ‘index’ contains the flower species, while ‘columns’ refers to the clusters generated by k-means. We can use ‘rownames’ and ‘colnames’ to rename the row and column titles.

Visualization with scatter matrix

The pairplot function in the seaborn package is a good way to show the distribution of each group. It plots pairwise relationships in a dataset: each variable is shared across the y-axes of one row and the x-axes of one column. Meanwhile, the diagonal shows the univariate distribution of each variable.

The parameter ‘hue’ selects the variable that maps plot aspects to different colors; ‘palette’ is the set of colors used for the mapping.

##Pairplot
sns.pairplot(iris.iloc[:,[0,1,2,3,6]],hue='Kmeans_label',palette="husl")
plt.show()

As we can see from the pairplot above, the group represented by the pink points is clearly separated from the other two groups, while there is some overlap between the green and the blue points.

Visualization with multidimensional scaling

##Multidimensional scaling
mds=MDS(n_components=2,dissimilarity='euclidean',n_jobs=1)
pos=mds.fit(iris.iloc[:,0:4]).embedding_

##embedding_ is already a NumPy array, so this conversion is a no-op safeguard
pos=np.array(pos)

We use the MDS class in sklearn.manifold to do multidimensional scaling, utilizing Euclidean distance as the measure of dissimilarity. After setting up the parameters, call fit and read the embedding_ attribute to obtain the 2-D coordinates.

##Define different groups
kmeans_id=np.array(iris['Kmeans_label'])
species_id=np.array(iris['Species_label'])
group00=pos[(kmeans_id==0)&(species_id==0)]
group01=pos[(kmeans_id==0)&(species_id==1)]
group02=pos[(kmeans_id==0)&(species_id==2)]
group10=pos[(kmeans_id==1)&(species_id==0)]
group11=pos[(kmeans_id==1)&(species_id==1)]
group12=pos[(kmeans_id==1)&(species_id==2)]
group20=pos[(kmeans_id==2)&(species_id==0)]
group21=pos[(kmeans_id==2)&(species_id==1)]
group22=pos[(kmeans_id==2)&(species_id==2)]

##Plot to show different groups
plt.figure(figsize=(10,8))
plt.xlim([-3,3])
plt.ylim([-5,5])
plt.scatter(group00[:,0],group00[:,1],c='r',marker='^',)
plt.scatter(group01[:,0],group01[:,1],c='r',marker='o',)
plt.scatter(group02[:,0],group02[:,1],c='r',marker='*',)
plt.scatter(group10[:,0],group10[:,1],c='g',marker='^',)
plt.scatter(group11[:,0],group11[:,1],c='g',marker='o',)
plt.scatter(group12[:,0],group12[:,1],c='g',marker='*',)
plt.scatter(group20[:,0],group20[:,1],c='b',marker='^',)
plt.scatter(group21[:,0],group21[:,1],c='b',marker='o',)
plt.scatter(group22[:,0],group22[:,1],c='b',marker='*',)

 #Plot legend
red_patch = mpatches.Patch(color='red', label='Kmeans Id=0')
green_patch = mpatches.Patch(color='green', label='Kmeans Id=1')
blue_patch = mpatches.Patch(color='blue', label='Kmeans Id=2')
triangle = mlines.Line2D([],[],ls='None',color='black',marker='^', label='Setosa')
circle = mlines.Line2D([],[],ls='None',color='black',marker='o', label='Versicolor')
star = mlines.Line2D([],[],ls='None',color='black',marker='*',markersize=10, label='Virginica')
plt.legend(handles=[red_patch,green_patch,blue_patch,triangle,circle,star])

plt.show()

We use Matplotlib pyplot to draw the scatter plot. Although all Setosa flowers are classified as one cluster, there are some misclassifications for Versicolor and Virginica flowers.

Stata

Version information: Code from this page was tested in Stata 14

Preliminary preparation

You can download the iris data set in Stata data format here.

use "~\iris.dta", clear 

K-means clustering

cluster kmeans is the Stata command for k-means clustering. The basic syntax looks like this:
cluster kmeans [varlist], k(#) [options]

cluster kmeans Sepal_Length Sepal_Width Petal_Length Petal_Width, k(3) measure(L2) start(krandom)
  • varlist: the variables used for clustering; in the example code we use Sepal_Length, Sepal_Width, Petal_Length, and Petal_Width.
  • k(#): the number of clusters we want; in the example code we ask for 3 clusters.
  • measure(L2): the similarity or dissimilarity measure for k-means. The default is L2, which is Euclidean distance; you can change it to, e.g., Canberra or L1 (absolute-value distance). See measure_option for more details.
  • start(): how the initial k group centers are chosen. The default is krandom, meaning k observations are randomly drawn as the centers. Another option is random, meaning the k centers are drawn from a uniform distribution over the range of the data. See the options section of the Stata manual entry for cluster kmeans for more details.
    After running the command, Stata generates a new variable in the dataset (here _clus_1) indicating each observation's cluster assignment.

Evaluate performance

Confusion matrix

Let’s see how the clusters from k-means match the actual species of iris flowers using the tabulate command in Stata, as sketched below.
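
A minimal version of that tabulation might look like this (assuming the species variable is named Species and the k-means variable is _clus_1, as in the code later in this section):

tabulate Species _clus_1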

We can see the clusters from k-means look decent for the setosa and versicolor species, but not so good for virginica.

Visualization with scatter matrix

Unlike ggplot, Stata does not allow directly changing the shape of points in a scatter plot matrix, hence the shapes need to be defined manually.

// generate label for different clusters
gen label = "o" if _clus_1 == 1
replace label = "+" if _clus_1 == 2
replace label = "*" if _clus_1 == 3

// scatter matrix with prespecified shape
graph matrix Sepal_Length Sepal_Width Petal_Length Petal_Width, half ms(none) mla(label)

Note: adding “half” to the options will draw only the lower-triangular portion of the scatter plot matrix.

From the scatter matrix, we can see one cluster (marked with *) is clearly separated from the other two clusters, while those two overlap somewhat, indicating that k-means may have difficulty distinguishing between those species of flowers.

Visualization with multidimensional scaling

Multidimensional scaling with a distance matrix

We can use the command matrix dissimilarity in Stata to create a dissimilarity matrix between the observations (rows) of the dataset; similarity measures can also be requested (see the Stata documentation for matrix dissimilarity).

The syntax of matrix dissimilarity looks like this:

matrix dissimilarity matname = [varlist]

  • matname is the name you assign to the dissimilarity matrix.
  • [varlist] specifies the variables used to generate the dissimilarity matrix; enter the variables without square brackets or commas.
matrix dissimilarity Distance = Sepal* Petal*

The code above creates a dissimilarity matrix called Distance using all variables whose names start with “Sepal” or “Petal”, so the lengths and widths of both the sepals and the petals are used.

To see the dissimilarity matrix we just created, the following code can be used.

mat list Distance

The command for multidimensional scaling from a dissimilarity matrix in Stata is mdsmat.
The code below uses the matrix Distance; the config option makes Stata print the MDS coordinates of each observation.

mdsmat Distance, config

Shortcut for extracting the multidimensional scaling coordinates

In fact, we do not need to copy those coordinates by hand: after mdsmat runs, Stata stores the MDS configuration behind the scenes, and you can refer to it with e(Y). Listing the full dissimilarity matrix above was for your own edification :)

mat Coor = e(Y)

The code above extracts the stored MDS matrix and assigns it to the new name Coor.

svmat Coor, name(MDS)

The code above appends the matrix Coor to the existing dataset; basically, it adds 2 columns (MDS1 and MDS2) that represent the MDS coordinates of each observation.

Stata does not allow a scatter plot where one grouping variable controls the colors and another controls the shapes, hence we generate two separate scatter plots: one by species and another by cluster. Two methods for generating scatter plots by groups are demonstrated here.

  1. Combine 3 different scatter plots, one per cluster, into one.
    The following code plots MDS1 and MDS2 for all the observations from cluster 1 first, then adds the observations from cluster 2 and subsequently cluster 3. By default, Stata differentiates points from different groups using different colors. You can customize the labels for each group as shown below.
graph twoway (scatter MDS2 MDS1 if _clus_1 == 1) (scatter MDS2 MDS1 if _clus_1 == 2) (scatter MDS2 MDS1 if _clus_1 == 3), legend(label(1 Cluster1) label(2 Cluster2) label(3 Cluster3))

Scatter plot with clusters

  2. Separate one variable by groups, then plot all the pieces in one scatter plot.
    The code below separates MDS2 by Species, creating 3 new variables (MDS21, MDS22, MDS23); each contains the MDS2 values for one species. See the screenshot of the data set.
// Separate variable by grouping variable
separate MDS2, by(Species)

The code below creates a scatter plot using MDS1 and the 3 groups of MDS2; the msymbol option allows customizing the shapes for the different groups.

graph twoway scatter MDS21 MDS22 MDS23 MDS1, msymbol(h S D) legend(label(1 setosa) label(2 versicolor) label(3 virginica))

Scatter plot with actual species

Comparing the scatter plot of the actual species with the k-means clustering result, it is clear that k-means produces non-overlapping clusters; this can be problematic when there is a region where observations from different groups are mixed together.

Final note

Problem with non-linearly separable observations

The major feature of k-means is that it produces linearly separable, non-overlapping clusters.

Let’s try to draw a line to separate 3 species.

We can see that we need a curve to separate the versicolor and virginica groups. In k-means, however, the boundary between two clusters is a straight line (the perpendicular bisector of the segment joining their centroids), which leads to the following result. This is why k-means cannot fully distinguish between these two species. It is a common problem for k-means when several groups are similar in some attributes or variables.
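
To visualize this, here is a sketch (re-using the pandas data frame from the Python section; the variable choices are our own) that colors the plane by nearest centroid, making the straight-line cluster boundaries visible:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans

##Fit k-means on the two petal measurements only, so the regions can be drawn in 2-D
X=iris[['PetalLength','PetalWidth']].values
km=KMeans(n_clusters=3,random_state=0).fit(X)

##Label a fine grid by nearest centroid; the region borders come out as straight lines
xx,yy=np.meshgrid(np.linspace(X[:,0].min()-1,X[:,0].max()+1,300),
                  np.linspace(X[:,1].min()-1,X[:,1].max()+1,300))
zz=km.predict(np.c_[xx.ravel(),yy.ravel()]).reshape(xx.shape)

plt.contourf(xx,yy,zz,alpha=0.3)                    #piecewise-linear cluster regions
plt.scatter(X[:,0],X[:,1],c=iris['Species_label'])  #actual species
plt.xlabel('PetalLength')
plt.ylabel('PetalWidth')
plt.show()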

Problem with non-unique solutions

K-means clustering is an optimization problem solved by numerical methods. When it converges, it converges to a local optimum, and there may be several distinct local optima. Therefore, depending on the initial selection of the k centroids (which are drawn randomly), the k-means algorithm may not converge to a unique solution.

We run the k-means algorithm on the same dataset many times. Although most of the time it converges to the result on the left, in some cases it converges to the result on the right.
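
Here is a sketch of one way to observe this (re-using the Python setup above; the seeds are arbitrary): run k-means with a single random initialization under several seeds and compare the final within-cluster sums of squares.

from sklearn.cluster import KMeans

X=iris.iloc[:,0:4]
for seed in range(5):
    km=KMeans(n_clusters=3,init='random',n_init=1,random_state=seed).fit(X)
    print(seed,round(km.inertia_,2))  #different seeds can reach different local optima
 #In practice, keep n_init > 1 (several restarts) and let the algorithm return the best run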

References and resources

  1. K-means clustering on Wikipedia
  2. K-means lecture by Robert Tibshirani and Trevor Hastie