Saturday, October 11, 2008

IR Math in Java : Cluster Visualization

I've been trying to learn clustering algorithms lately. I was planning to write about them this week, but some last-minute refactoring to remove redundancies and make the code more readable resulted in everything going to hell. So I guess I will have to write about them next week.

Almost all clustering algorithms (at least the ones I have seen) seem to be non-deterministic, mainly because they select documents randomly from the collection to build the initial clusters. As a result, they can come up with wildly different clusters depending on how the initial clusters were formed. In my previous (un-refactored) code, for example, the K-Means algorithm converged to the same set of clusters most of the time, but after the changes, it no longer does.
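If the immediate goal is just to make runs repeatable while debugging, one option is to draw the initial documents from a seeded random number generator. The sketch below is a hypothetical helper, not part of my clustering code: the class name `InitialClusterSeeds`, the method `pickSeeds` and the seed value are all made up for illustration.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/**
 * Hypothetical helper (not from the original code): picks k distinct
 * document indices to seed the initial clusters. A fixed seed makes the
 * choice, and therefore the whole clustering run, repeatable.
 */
public class InitialClusterSeeds {

  public static List<Integer> pickSeeds(int numDocs, int k, long seed) {
    if (k < 1 || k > numDocs) {
      throw new IllegalArgumentException("need 1 <= k <= numDocs");
    }
    List<Integer> indices = new ArrayList<Integer>();
    for (int i = 0; i < numDocs; i++) {
      indices.add(i);
    }
    // shuffle with a seeded generator, then keep the first k indices
    Collections.shuffle(indices, new Random(seed));
    return new ArrayList<Integer>(indices.subList(0, k));
  }

  public static void main(String[] args) {
    // both calls use the same seed, so they print the same 3 indices
    System.out.println(pickSeeds(7, 3, 42L));
    System.out.println(pickSeeds(7, 3, 42L));
  }
}
```

Fixing the seed does not make the resulting clusters any more "correct"; it only makes a given run reproducible, so it hides the sensitivity to the initial choice rather than removing it.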

After struggling with the non-determinism for some time, I decided I needed to see for myself what the "correct" clusters were. If I could visualize the documents as points in the n-dimensional term space, clumps of points would correspond to clusters. The problem was that a chart gives me only 2 (or at most 3) dimensions to work with.

Luckily for me, people smarter than I have faced and solved the same problem, and they have been kind enough to write about it on the web. The solution is Dimensionality Reduction: extract the 2 or 3 most interesting components (the Principal Components) from the term-document matrix and use them as the coordinates in a 2-dimensional or 3-dimensional scatter chart.

The mathematical background for Principal Component Analysis (PCA) is explained very nicely in this tutorial, which I quote verbatim below.

The mathematical technique used in PCA is called eigen analysis: we solve for the eigenvalues and eigenvectors of a square symmetric matrix with sums of squares and cross products. The eigenvector associated with the largest eigenvalue has the same direction as the first principal component. The eigenvector associated with the second largest eigenvalue determines the direction of the second principal component. The sum of the eigenvalues equals the trace of the square matrix and the maximum number of eigenvectors equals the number of rows (or columns) of this matrix.
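To see the last claim of the quote in action, here is a small dependency-free sketch (plain Java rather than Jama, so it runs on its own). It centers a made-up 5 x 2 data matrix, builds the sums of squares and cross products matrix, finds its eigenvalues with the cyclic Jacobi rotation method for symmetric matrices, and compares their sum with the trace. The class and method names are hypothetical.

```java
import java.util.Arrays;

/** Hypothetical sketch: eigen analysis of a sums of squares and cross products matrix. */
public class EigenAnalysisSketch {

  /** Centers each column of data, then returns the SSCP matrix Xc^T * Xc. */
  public static double[][] sscp(double[][] data) {
    int rows = data.length;
    int cols = data[0].length;
    double[] mean = new double[cols];
    for (double[] row : data) {
      for (int j = 0; j < cols; j++) {
        mean[j] += row[j] / rows;
      }
    }
    double[][] c = new double[cols][cols];
    for (double[] row : data) {
      for (int i = 0; i < cols; i++) {
        for (int j = 0; j < cols; j++) {
          c[i][j] += (row[i] - mean[i]) * (row[j] - mean[j]);
        }
      }
    }
    return c;
  }

  /** Eigenvalues of a symmetric matrix via cyclic Jacobi rotations, largest first. */
  public static double[] eigenvaluesDescending(double[][] sym) {
    int n = sym.length;
    double[][] a = new double[n][];
    for (int i = 0; i < n; i++) {
      a[i] = sym[i].clone(); // work on a copy
    }
    for (int sweep = 0; sweep < 100; sweep++) {
      double off = 0.0; // sum of squared off-diagonal entries
      for (int p = 0; p < n; p++) {
        for (int q = p + 1; q < n; q++) {
          off += a[p][q] * a[p][q];
        }
      }
      if (off < 1e-20) {
        break; // effectively diagonal
      }
      for (int p = 0; p < n; p++) {
        for (int q = p + 1; q < n; q++) {
          if (a[p][q] == 0.0) {
            continue;
          }
          // rotation angle chosen so that the new a[p][q] is zero
          double theta = (a[q][q] - a[p][p]) / (2.0 * a[p][q]);
          double t = (theta == 0.0) ? 1.0
              : Math.signum(theta) / (Math.abs(theta) + Math.sqrt(theta * theta + 1.0));
          double c = 1.0 / Math.sqrt(t * t + 1.0);
          double s = t * c;
          for (int k = 0; k < n; k++) { // a = a * J (columns p and q)
            double akp = a[k][p];
            double akq = a[k][q];
            a[k][p] = c * akp - s * akq;
            a[k][q] = s * akp + c * akq;
          }
          for (int k = 0; k < n; k++) { // a = J^T * a (rows p and q)
            double apk = a[p][k];
            double aqk = a[q][k];
            a[p][k] = c * apk - s * aqk;
            a[q][k] = s * apk + c * aqk;
          }
        }
      }
    }
    double[] eig = new double[n];
    for (int i = 0; i < n; i++) {
      eig[i] = a[i][i];
    }
    Arrays.sort(eig); // ascending
    for (int i = 0; i < n / 2; i++) { // reverse to descending
      double tmp = eig[i];
      eig[i] = eig[n - 1 - i];
      eig[n - 1 - i] = tmp;
    }
    return eig;
  }

  public static double trace(double[][] m) {
    double sum = 0.0;
    for (int i = 0; i < m.length; i++) {
      sum += m[i][i];
    }
    return sum;
  }

  public static void main(String[] args) {
    double[][] data = { {2.5, 2.4}, {0.5, 0.7}, {2.2, 2.9}, {1.9, 2.2}, {3.1, 3.0} };
    double[][] c = sscp(data);
    double[] eig = eigenvaluesDescending(c);
    System.out.println("eigenvalues: " + Arrays.toString(eig));
    System.out.println("sum = " + (eig[0] + eig[1]) + ", trace = " + trace(c));
  }
}
```

With Jama, `new EigenvalueDecomposition(c).getRealEigenvalues()` on a symmetric `Matrix c` would give the same eigenvalues without the hand-written loop.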

The tutorial then goes on to explain the algorithm to use for extracting the most interesting dimensions from a non-square matrix such as our term-document matrix (see Section 6, Algorithms). Essentially, it decomposes the term-document matrix A, using Singular Value Decomposition (SVD), into 3 matrices U, S and V such that the following equation holds.

$$A = U S V^{T}$$
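Multiplying out $A^{T} A$ shows how this decomposition ties back to the eigen analysis in the quote. Assume the economy-size SVD that Jama computes for an $m \times n$ matrix with $m \ge n$ (here $m$ terms and $n$ documents), so that $U$ is $m \times n$ with $U^{T} U = I$, and $S$ and $V$ are $n \times n$:

$$
A^{T} A = (U S V^{T})^{T} (U S V^{T}) = V S U^{T} U S V^{T} = V S^{2} V^{T},
\qquad
A^{T} U = V S U^{T} U = V S
$$

The first identity says that the columns of $V$ are the eigenvectors of $A^{T} A$, with the squared singular values as eigenvalues. $A^{T} A$ holds the sums of squares and cross products between document columns, although computed here without subtracting the column means. The second identity says that document $i$, projected onto the $k$-th left singular vector, has coordinate $v_{ik} s_k$, so plotting row $i$ of $V$ gives the same picture with each axis rescaled by $1/s_k$.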

Here S is a square diagonal matrix whose diagonal entries (the singular values) are non-negative and sorted in descending order. Since A has one column per document, V has one row per document. So for a 2-dimensional reduction, the principal components correspond to the first 2 columns of V, and each document is plotted using the first 2 values in its row of V; for a 3-dimensional reduction, the first 3 columns of V are used.
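Because the singular values are sorted, one way to judge how much a 2D or 3D chart leaves out is to compare the squares of the first k singular values with the sum of all of them (the squared singular values add up to the sum of squares of every entry of A). The helper below is hypothetical, not part of the visualizer code; with Jama its input would come from `svd.getSingularValues()`. Since the matrix is not mean-centered, read the result as a share of the total sum of squares rather than of variance in the strict PCA sense.

```java
/** Hypothetical helper: share of the total sum of squares kept by the first k components. */
public class RetainedFraction {

  /**
   * singularValues must be sorted in descending order, which is how
   * Jama's SingularValueDecomposition.getSingularValues() returns them.
   */
  public static double retainedFraction(double[] singularValues, int k) {
    double total = 0.0;
    double kept = 0.0;
    for (int i = 0; i < singularValues.length; i++) {
      double sq = singularValues[i] * singularValues[i];
      total += sq;
      if (i < k) {
        kept += sq;
      }
    }
    return total == 0.0 ? 0.0 : kept / total;
  }

  public static void main(String[] args) {
    double[] s = {3.0, 2.0, 1.0}; // made-up singular values
    System.out.printf("k=2 keeps %.4f of the total%n", retainedFraction(s, 2));
  }
}
```

A value close to 1 for k = 2 or 3 suggests the scatter chart is a fair picture of the document space; a low value means clumps seen on the chart should be taken with more caution.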

The Java code to generate data for drawing the charts is trivial, mainly because we use the Jama matrix library, which does all the heavy lifting for SVD calculations.

```java
// Source: src/main/java/com/mycompany/myapp/clustering/PcaClusterVisualizer.java
package com.mycompany.myapp.clustering;

import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;

import Jama.Matrix;
import Jama.SingularValueDecomposition;

public class PcaClusterVisualizer {

  private static final String PLOT_2D_OUTPUT = "plot2d.dat";
  private static final String PLOT_3D_OUTPUT = "plot3d.dat";

  /**
   * Writes 2D and 3D gnuplot data files for the documents in tdMatrix.
   * tdMatrix has one row per term and one column per document (Jama's
   * SVD is documented for matrices with rows >= columns), and needs at
   * least 3 documents since 3 columns of V are extracted.
   */
  public void reduce(Matrix tdMatrix, String[] docNames) throws IOException {
    SingularValueDecomposition svd =
      new SingularValueDecomposition(tdMatrix);
    Matrix v = svd.getV(); // one row per document
    // we know that the diagonal of S is ordered, so we can take the
    // first 3 cols from V, for use in plot2d (first 2) and plot3d
    Matrix vRed = v.getMatrix(0, v.getRowDimension() - 1, 0, 2);
    PrintWriter plot2dWriter = null;
    PrintWriter plot3dWriter = null;
    try {
      plot2dWriter = new PrintWriter(new FileWriter(PLOT_2D_OUTPUT));
      plot3dWriter = new PrintWriter(new FileWriter(PLOT_3D_OUTPUT));
      for (int i = 0; i < vRed.getRowDimension(); i++) { // one line per doc
        // singular vectors are only defined up to sign; taking the
        // absolute value puts every point in the positive quadrant/octant
        plot2dWriter.printf("%6.4f %6.4f %s%n",
          Math.abs(vRed.get(i, 0)), Math.abs(vRed.get(i, 1)), docNames[i]);
        plot3dWriter.printf("%6.4f %6.4f %6.4f %s%n",
          Math.abs(vRed.get(i, 0)), Math.abs(vRed.get(i, 1)),
          Math.abs(vRed.get(i, 2)), docNames[i]);
      }
    } finally {
      // close() also flushes, and runs even if writing fails
      if (plot2dWriter != null) plot2dWriter.close();
      if (plot3dWriter != null) plot3dWriter.close();
    }
  }
}
```
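Singular vectors are only determined up to sign: flipping the sign of the k-th column of both U and V leaves U S V^T unchanged, so a different library (or a refactoring) can legitimately return a mirror-image chart. `Math.abs` hides that, but it also folds points lying on opposite sides of an axis onto the same side. A common alternative convention, sketched here as a hypothetical helper and not what produced the data files in this post, is to flip each whole column so that its largest-magnitude entry is positive, which keeps the relative positions of the points intact.

```java
/** Hypothetical helper: deterministic sign convention for singular vectors. */
public class SignConvention {

  /**
   * Returns a copy of v in which each of the first k columns is negated,
   * if needed, so that its largest-magnitude entry is positive.
   * The input array is left unchanged.
   */
  public static double[][] normalizeSigns(double[][] v, int k) {
    double[][] out = new double[v.length][];
    for (int i = 0; i < v.length; i++) {
      out[i] = v[i].clone();
    }
    for (int col = 0; col < k; col++) {
      int maxRow = 0;
      for (int row = 1; row < out.length; row++) {
        if (Math.abs(out[row][col]) > Math.abs(out[maxRow][col])) {
          maxRow = row;
        }
      }
      if (out[maxRow][col] < 0) { // flip the whole column, not single entries
        for (int row = 0; row < out.length; row++) {
          out[row][col] = -out[row][col];
        }
      }
    }
    return out;
  }
}
```

With Jama, this could be applied as `SignConvention.normalizeSigns(svd.getV().getArrayCopy(), 3)` before writing the plot files; the coordinates could then be negative, so the data shown below would change.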

The term-document matrix is generated from the collection of 7 document titles that I have been using for my experiments, using the following snippet of JUnit code. See my earlier post IR Math in Java : TF, IDF and LSI for the actual data and for details on the classes used.

```java
// Imports used by this test. VectorGenerator and IdfIndexer come from my
// earlier "IR Math in Java : TF, IDF and LSI" post; the StringUtils import
// assumes commons-lang 2.x.
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.Reader;
import java.io.StringReader;
import java.util.LinkedHashMap;
import java.util.Map;

import org.apache.commons.lang.StringUtils;
import org.junit.Test;
import org.springframework.jdbc.datasource.DriverManagerDataSource;

import Jama.Matrix;

  // ... inside the JUnit test class ...
  @Test
  public void testPcaClusterVisualization() throws Exception {
    // in the actual code, this setup lives in a @Before method;
    // it has been inlined here for readability
    VectorGenerator vectorGenerator = new VectorGenerator();
    vectorGenerator.setDataSource(new DriverManagerDataSource(
      "com.mysql.jdbc.Driver", "jdbc:mysql://localhost:3306/tmdb",
      "irstuff", "irstuff"));
    Map<String,Reader> documents =
      new LinkedHashMap<String,Reader>();
    BufferedReader reader = new BufferedReader(
      new FileReader("src/test/resources/data/indexing_sample_data.txt"));
    try {
      String line = null;
      while ((line = reader.readLine()) != null) {
        // each line is "docName;title"
        String[] docTitleParts = StringUtils.split(line, ";");
        documents.put(docTitleParts[0], new StringReader(docTitleParts[1]));
      }
    } finally {
      reader.close();
    }
    vectorGenerator.generateVector(documents);
    IdfIndexer indexer = new IdfIndexer();
    Matrix tdMatrix = indexer.transform(vectorGenerator.getMatrix());
    String[] documentNames = vectorGenerator.getDocumentNames();
    // this is my actual @Test block
    PcaClusterVisualizer visualizer = new PcaClusterVisualizer();
    visualizer.reduce(tdMatrix, documentNames);
  }
```
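For readers without the code from the earlier post, here is a much simpler, hypothetical stand-in for the VectorGenerator and IdfIndexer step. It assumes each line of the data file looks like `D1;Human machine interface for computer applications` (which is what the parsing loop above implies), lower-cases the titles, splits them on non-letters and counts raw term frequencies, with no stop-word removal and no IDF weighting. The class name is made up, and the resulting matrix will differ from the one used for the charts in this post.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.TreeMap;

/** Hypothetical sketch: raw term-frequency term-document matrix from "name;title" lines. */
public class SimpleTermDocMatrix {

  public final List<String> terms = new ArrayList<String>();    // one per row
  public final List<String> docNames = new ArrayList<String>(); // one per column
  public final double[][] counts;                               // counts[term][doc]

  public SimpleTermDocMatrix(List<String> lines) {
    Map<String, String> docs = new LinkedHashMap<String, String>();
    for (String line : lines) {
      int sep = line.indexOf(';');
      if (sep < 0) {
        continue; // skip malformed lines
      }
      docs.put(line.substring(0, sep).trim(), line.substring(sep + 1));
    }
    // term -> per-document counts; TreeMap keeps the terms sorted
    Map<String, int[]> termCounts = new TreeMap<String, int[]>();
    int d = 0;
    for (Map.Entry<String, String> e : docs.entrySet()) {
      docNames.add(e.getKey());
      for (String token : e.getValue().toLowerCase(Locale.ROOT).split("[^a-z]+")) {
        if (token.isEmpty()) {
          continue; // leading separator produces an empty token
        }
        int[] row = termCounts.get(token);
        if (row == null) {
          row = new int[docs.size()];
          termCounts.put(token, row);
        }
        row[d]++;
      }
      d++;
    }
    counts = new double[termCounts.size()][docs.size()];
    int t = 0;
    for (Map.Entry<String, int[]> e : termCounts.entrySet()) {
      terms.add(e.getKey());
      for (int j = 0; j < docs.size(); j++) {
        counts[t][j] = e.getValue()[j];
      }
      t++;
    }
  }
}
```

`new Jama.Matrix(tdm.counts)` would turn the counts into a Jama matrix with one row per term and one column per document, which is the shape `reduce()` expects.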

Running the test generates 2 data files, which gnuplot uses as input to draw the 2D and 3D scatter charts. The data files and the gnuplot scripts that draw the charts are shown below.

```text
# plot2d.dat
0.0000 0.2261 D1
0.0468 0.0000 D2
0.0000 0.7363 D3
0.0000 0.6378 D4
0.0000 0.0000 D5
0.8751 0.0000 D6
0.4817 0.0000 D7
```
```text
# plot3d.dat
0.0000 0.2261 0.0000 D1
0.0468 0.0000 0.2997 D2
0.0000 0.7363 0.0000 D3
0.0000 0.6378 0.0000 D4
0.0000 0.0000 0.0000 D5
0.8751 0.0000 0.4723 D6
0.4817 0.0000 0.8289 D7
```
```gnuplot
# plot2d.gp
set style data labels
unset key
plot 'plot2d.dat' using 1:2:3 \
  with labels font "arial,11" \
  textcolor lt 1
```
```gnuplot
# plot3d.gp
set style data labels
unset key
splot 'plot3d.dat' using 1:2:3:4 \
  with labels font "arial,11" \
  textcolor lt 1
```

From the charts, it appears that the following clusters may be valid for our test document set. Notice that although D3 and D7 appear to overlap on the 3D chart, they are not close on the 2D chart or in the data; the overlap is probably an artifact of the angle from which the 3D chart is viewed. In any case, the results look believable, although not perfect, which could be due to the dimensionality reduction and/or the small data set.

```text
C0: [D1, D2, D5]
    D1  Human machine interface for computer applications
    D2  A survey of user opinion of computer system response time
    D5  The generation of random, binary and ordered trees
C1: [D3, D4]
    D3  The EPS user interface management system
    D4  System and human system engineering testing of EPS
C2: [D7]
    D7  Graph minors: A survey
C3: [D6]
    D6  The intersection graph of paths in trees
```
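The D3/D7 overlap on the 3D chart can be checked numerically. The small sketch below (the class name is hypothetical) computes Euclidean distances using the coordinates copied from plot3d.dat: D3 and D4 are about 0.0985 apart, while D3 and D7 are about 1.21 apart, which supports reading the overlap as a viewing-angle effect rather than real proximity.

```java
/** Hypothetical sketch: Euclidean distance between plotted points. */
public class PlotDistances {

  public static double distance(double[] a, double[] b) {
    double sum = 0.0;
    for (int i = 0; i < a.length; i++) {
      double diff = a[i] - b[i];
      sum += diff * diff;
    }
    return Math.sqrt(sum);
  }

  public static void main(String[] args) {
    // coordinates copied from plot3d.dat
    double[] d3 = {0.0000, 0.7363, 0.0000};
    double[] d4 = {0.0000, 0.6378, 0.0000};
    double[] d7 = {0.4817, 0.0000, 0.8289};
    System.out.printf("D3-D4: %.4f%n", distance(d3, d4)); // about 0.0985
    System.out.printf("D3-D7: %.4f%n", distance(d3, d7)); // about 1.2088
  }
}
```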

I think this post may be helpful to programmers like me who are just getting into IR (most people who are heavily into IR would probably know this stuff already). Text mining algorithms, by their very nature, need to deal with n-dimensional data, and the ability to visualize the data in 2D or 3D can be quite enlightening, so this is a useful tool to have in one's text mining toolbox.

Update 2009-04-26: In recent posts, I have been building on code written and described in previous posts, so there were (rightly so) quite a few requests for the code. I've created a project on SourceForge to host it; you will find the complete source code built so far in the project's SVN repository.
