Saturday, March 03, 2018

An implementation of the Silhouette Score metric on Spark


While clustering some data on Spark recently, I needed a quantitative metric to evaluate the quality of the clustering. I couldn't find anything built-in, so (predictably) I went looking on Google, where I found this Stack Overflow page discussing this very thing. However, as you can see from the accepted answer, the Silhouette score by definition is not very scalable, since it requires measuring the average distance from each point to every other point in its cluster, as well as the average distance from each point to the points in each of the other clusters. This is clearly an O(N²) operation and not likely to scale for large N. What caught my eye, however, was the answer by Geoffrey Anderson, who suggested using the Simplified Silhouette score, which requires only the distance from every point to its own centroid and is thus an O(N) operation and therefore quite scalable.

However, while this is indeed the algorithm suggested in the cited paper (Multiple K Means++ Clustering of Satellite Image Using Hadoop MapReduce and Spark by Sharma, Shokeen and Mathur), this simplified metric only reports on how tight each cluster is. The intent of the original Silhouette score is to report both on the tightness of individual clusters and on the separation between clusters. It turns out that a simplified Silhouette metric that captures both does exist, and is defined in detail in this paper titled An Analysis of the Application of Simplified Silhouette to the Evaluation of k-means Clustering Validity (PDF) by Wang, et al. Interestingly, its per-point formula has the same form as the one used in the implementation of Silhouette score in Scikit-Learn, although Scikit-Learn computes it from mean pairwise distances rather than from distances to centroids.

The latter formulation of the Simplified Silhouette Index (SSI) works as follows. For each point i, call the distance to its own cluster centroid a_i, and call the distance to the nearest neighboring centroid b_i. The Silhouette score for the i-th point is then SSI_i = (b_i - a_i) / max(a_i, b_i). Here a_i is the indicator of cluster tightness and b_i is the indicator of cluster separation. One thing to notice is that in order to compute b_i you need to compute the distance from point i to all the other centroids except its own, so the complexity of the algorithm is O(Nk), where k is the number of clusters. That is still linear in N, but can be high for large k.


Values for SSI can vary in the range [-1, 1]. Values near 0 indicate overlapping clusters, and negative values generally indicate that the point has been wrongly clustered, since it is closer to a different cluster's centroid than to its own. The best values of SSI are close to 1. The average SSI across all the points in the corpus gives us an indication of how good the clustering is.
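To make these ranges concrete, here is a small sketch in plain Scala (no Spark) that scores a single point against a set of centroids. The names `ssiForPoint` and `dist`, and the toy centroids, are made up for illustration; only the formula (b - a) / max(a, b) and the a = b = 0 guard come from this post.

```scala
object SsiExample {
  // Euclidean distance between two equal-length vectors
  def dist(x: Array[Double], y: Array[Double]): Double =
    math.sqrt(x.zip(y).map { case (p, q) => (p - q) * (p - q) }.sum)

  // SSI for one point: a = distance to its own centroid,
  // b = distance to the nearest other centroid (assumes at least 2 clusters)
  def ssiForPoint(p: Array[Double], cid: Int, centroids: Map[Int, Array[Double]]): Double = {
    val a = dist(p, centroids(cid))
    val b = (centroids - cid).values.map(c => dist(p, c)).min
    if (a == 0.0 && b == 0.0) 0.0 else (b - a) / math.max(a, b)
  }

  // hypothetical toy centroids, 10 units apart on the x axis
  val centroids: Map[Int, Array[Double]] =
    Map(0 -> Array(0.0, 0.0), 1 -> Array(10.0, 0.0))

  def main(args: Array[String]): Unit = {
    // all three points are assigned to cluster 0
    println(ssiForPoint(Array(1.0, 0.0), 0, centroids)) // a = 1, b = 9: well clustered
    println(ssiForPoint(Array(5.0, 0.0), 0, centroids)) // a = b = 5: on the boundary
    println(ssiForPoint(Array(8.0, 0.0), 0, centroids)) // a = 8, b = 2: closer to the other centroid
  }
}
```

The three points come out at roughly 0.89, exactly 0, and -0.75, which lines up with the "close to 1 is good, near 0 is overlapping, negative is misassigned" reading above.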

In addition, the distribution of SSI values in each cluster can be histogrammed as shown in this KMeans clustering and Silhouette analysis example on Scikit-Learn. Clusterings where the points are distributed approximately equally across clusters tend to be better, given similar values of average SSI.
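One possible way to get at these per-cluster distributions is sketched below in plain Scala. It assumes the (cluster ID, SSI) pairs have been collected or sampled down to something that fits in local memory; `summarize`, `ClusterSummary`, the bin count and the sample scores are all hypothetical names and values chosen for illustration.

```scala
object SsiByCluster {
  // per-cluster size, mean SSI, and counts of SSI values in equal-width bins over [-1, 1]
  case class ClusterSummary(size: Int, meanSsi: Double, histogram: Vector[Int])

  def summarize(scores: Seq[(Long, Double)], numBins: Int = 4): Map[Long, ClusterSummary] =
    scores.groupBy(_._1).map { case (cid, recs) =>
      val ssis = recs.map(_._2)
      val binWidth = 2.0 / numBins
      // map an SSI in [-1, 1] to a bin index; an SSI of exactly 1.0 goes in the last bin
      val bins = ssis.map(s => math.min(((s + 1.0) / binWidth).toInt, numBins - 1))
      val hist = Vector.tabulate(numBins)(i => bins.count(_ == i))
      cid -> ClusterSummary(ssis.size, ssis.sum / ssis.size, hist)
    }

  def main(args: Array[String]): Unit = {
    // made-up (clusterId, ssi) pairs
    val scores = Seq((0L, 0.9), (0L, 0.7), (0L, -0.2), (1L, 0.4), (1L, 0.6))
    summarize(scores).toSeq.sortBy(_._1).foreach { case (cid, s) =>
      println(s"cluster $cid: size=${s.size}, mean SSI=${s.meanSsi}, histogram=${s.histogram}")
    }
  }
}
```

The `size` field gives the distribution of points across clusters, and the histogram shows whether a cluster's mean is hiding a tail of negative scores.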

A flow diagram for my implementation is shown below. There are 3 inputs needed: an RDD of point vectors keyed by a sequential record ID (rid), an RDD of predictions consisting of record ID and cluster ID pairs, and an RDD of centroids consisting of cluster ID and centroid vector pairs. The RDD of points is just the input vectors with an additional sequential record ID, which you can easily provide with a zipWithIndex call on the original input RDD of vectors to cluster. The prediction RDD is the output of model.predict on the clustering model. The centroids RDD is the output of model.clusterCenters, with an additional zipWithIndex to get the cluster IDs.

The point and prediction RDDs are joined on the record ID, and the centroids RDD is converted to a lookup dictionary of cluster ID to centroid vector and broadcast to the workers. The joined RDD and the broadcast lookup table are used to compute SSI for each point. We retain the cluster ID in case we want to compute the histograms by cluster. I didn't do this because my data is too large for this information to be visually meaningful; I ended up plotting the distribution of the number of points in each cluster instead.
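To make the data flow concrete without a cluster, here is a rough sketch that uses plain Scala collections as stand-ins for the three RDDs. The toy vectors and the `nearest` helper (a stand-in for model.predict) are assumptions for illustration only; in the real job these would be RDDs produced by the clustering step.

```scala
object InputsSketch {
  type Vec = Array[Double]

  // Euclidean distance between two equal-length vectors
  def dist(x: Vec, y: Vec): Double =
    math.sqrt(x.zip(y).map { case (p, q) => (p - q) * (p - q) }.sum)

  // hypothetical stand-in for model.predict: index of the closest centroid
  def nearest(p: Vec, centers: Seq[Vec]): Long =
    centers.indices.minBy(i => dist(p, centers(i))).toLong

  // toy input vectors, and toy centers standing in for model.clusterCenters
  val vectors: Seq[Vec] = Seq(Array(0.0, 1.0), Array(9.0, 0.0), Array(1.0, 0.0))
  val centers: Seq[Vec] = Seq(Array(0.0, 0.0), Array(10.0, 0.0))

  // points: (rid, vector), with record IDs assigned by zipWithIndex
  val points: Seq[(Long, Vec)] =
    vectors.zipWithIndex.map { case (v, i) => (i.toLong, v) }

  // predictions: (rid, cid)
  val predictions: Seq[(Long, Long)] =
    points.map { case (rid, v) => (rid, nearest(v, centers)) }

  // centroids: cluster ID -> centroid vector (the lookup table that would be broadcast)
  val cid2centroid: Map[Long, Vec] =
    centers.zipWithIndex.map { case (c, i) => (i.toLong, c) }.toMap

  // join points and predictions on rid, giving (rid, (vector, cid))
  val predictionByRid: Map[Long, Long] = predictions.toMap
  val joined: Seq[(Long, (Vec, Long))] =
    points.map { case (rid, v) => (rid, (v, predictionByRid(rid))) }

  def main(args: Array[String]): Unit =
    joined.foreach { case (rid, (v, cid)) =>
      println(s"rid=$rid point=${v.mkString(",")} cid=$cid centroid=${cid2centroid(cid).mkString(",")}")
    }
}
```

Each joined record carries everything needed to score a point: its vector, its cluster ID, and (through the lookup table) its own centroid and all the other centroids.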


All the input RDDs are available to the algorithm as structured text files from the clustering process. Here is the Scala code I used to implement this Simplified Silhouette index. We use Databricks Notebooks for most of our Spark analytics work, so the code below doesn't contain the boilerplate that you need to start Spark jobs directly on EMR or a standard Spark cluster. But it should be relatively simple to add that stuff in if needed.

import breeze.linalg._

val OUTPUT_POINTS = "/path/to/points"
val OUTPUT_PREDICTIONS = "/path/to/predictions"
val OUTPUT_CENTROIDS = "/path/to/centroids"

// read centroids file and convert to lookup table, then broadcast to workers
val centroidsRDD = sc.textFile(OUTPUT_CENTROIDS)
  .map(line => {
    val Array(cid, cstr) = line.split('\t')
    val cvec = DenseVector(cstr.split(',').map(x => x.toDouble))
    (cid.toLong, cvec)
  })

val cid2centroid = centroidsRDD.collect.toMap
val b_cid2centroid = sc.broadcast(cid2centroid)

// read points file
val pointsRDD = sc.textFile(OUTPUT_POINTS)
  .map(line => {
    val Array(rid, pstr) = line.split('\t')
    val pvec = DenseVector(pstr.split(',').map(x => x.toDouble))
    (rid.toLong, pvec)
  })
pointsRDD.persist()

// read predictions file
val predictionsRDD = sc.textFile(OUTPUT_PREDICTIONS)
  .map(line => {
    val Array(rid, cid) = line.split('\t')
    (rid.toLong, cid.toLong)
  })
predictionsRDD.persist()

// join pointsRDD and predictionsRDD, look up centroid vectors from the broadcast,
// and compute a, b and SSI for each point; the cluster ID is kept with each score
def euclideanDist(v1: DenseVector[Double], v2: DenseVector[Double]): Double = norm(v1 - v2, 2)

val predictedPointsRDD = pointsRDD.join(predictionsRDD)    // (rid, (pvec, cid))
  .map { case (_, (pvec, cid)) =>
    val centroids = b_cid2centroid.value
    // a: distance from the point to its own centroid
    val aDist = euclideanDist(pvec, centroids(cid))
    // b: distance to the nearest other centroid (assumes at least 2 clusters)
    val bDist = centroids
      .filter { case (otherCid, _) => otherCid != cid }
      .map { case (_, otherCvec) => euclideanDist(pvec, otherCvec) }
      .min
    // guard against 0/0 when the point coincides with both centroids
    val ssi = if (aDist == 0.0 && bDist == 0.0) 0.0
              else (bDist - aDist) / scala.math.max(aDist, bDist)
    (cid, ssi)
  }
predictedPointsRDD.persist()

// compute mean SSI over all points
val meanSSI = predictedPointsRDD.map(_._2).mean()
print("mean SSI: %.5f\n".format(meanSSI))

predictedPointsRDD.unpersist()
pointsRDD.unpersist()
predictionsRDD.unpersist()

I have used this code to evaluate two clustering algorithms (KMeans and Bisecting KMeans) on my data, using two different approaches to vectorizing the data, and with various values of K (number of clusters). The mean SSI metric let me reduce an entire clustering run to a single number that I could compare across runs. This was very helpful in deciding which outputs to keep and which to discard without having to manually inspect each output. I hope this code is useful to others who might need to evaluate their clustering algorithms on Spark.


4 comments:

Anonymous said...

Very shortly this website will be famous amid all
blogging visitors, due to it's nice articles or reviews

Sujit Pal said...

Thanks for the kind words, Anonymous.

Anonymous said...

Thanks so much for this! Do you have a Python version of this?

Sujit Pal said...

Thanks, and no, I don't have a Python version, sorry. It should be fairly easy to write, though (maybe easier, because numpy has a cleaner API than breeze IMO). I used Scala here mainly because you can inline code inside one of Spark's higher order functions; with Python you would need to write a separate function to compute the block that begins on line 40 of the listing (the map over the joined RDD) and call it within the map call.