Friday, October 24, 2008

Phrase Spelling Corrector using Word Collocation Probabilities

Spelling correction is one of those things that people don't notice when it works well. Indeed, for web-based search applications, its manifestation is usually a little "Did you mean: xxx?" component that appears when the application is not able to recognize the term being queried for. Unobtrusive as it is, however, users do notice when the suggestion is incorrect.

There are various approaches to spelling corrections. One popular approach is to use a Lucene index with character n-grams for terms in the index - I have written previously about my implementation of this approach.

Another popular approach is to compute edit costs and return from a dictionary the words that are within a predefined edit cost from the misspelt word. This is the approach used by GNU Aspell and its Java cousin Jazzy, which we use here.

Both these approaches work very well for single words, so they are very usable for applications such as word processors, where you need to be able to flag and suggest alternatives for misspelt words. In a typical search page, however, a user can type in a multi-word phrase, with one or more words misspelt. The job of the spelling corrector, in this case, is to tie the best suggestions together so that the corrected phrase makes sense within the context of the original phrase. A much harder problem, as you will no doubt agree.

Various approaches to solve this have been suggested and tried - I noticed one such suggestion almost by accident on the Aspell TODO list, which set me thinking about this whole thing again.

Thinking about this suggestion a bit, I realized that a much simpler way would be to compute conditional probabilities between consecutive words in the phrase, and then consider the "best" suggestion to be the one which connects the words via the most probable path, i.e. the path with the highest sum of conditional probabilities. This effectively boils down to a graph theory problem of computing the shortest path in a weighted directed graph. This post describes an implementation of this idea.

Consider Knud Sorensen's example from the Aspell TODO list. Two misspelt phrases and their corrected forms are shown below. As you can see, the correct form of the misspelt word 'fone' differs based on the other words in the phrase.

    a fone number => a phone number
    a fone dress  => a fine dress

The list below shows the suggestions returned by Jazzy for the misspelt word 'fone', ordered by cost, i.e. the first suggestion is the one with the least edit cost to convert from the misspelt word. Notice that neither 'phone' nor 'fine' is the first result. The Java code for the CLI that I built for quickly looking up suggestions is available later in this post.

sujit@sirocco:~$ mvn -o exec:java \
  -Dexec.mainClass=com.mycompany.myapp.Shell \
  -Dexec.args=true
jazzy> fone
foe, one, fine, bone, zone, fore, lone, fond, font, hone, cone, gone, 
none, done, tone, fen, foes, on, fan, fin, fun, money, phone, son, fee, 
for, fog, fox, fined, found, fount, fence, fines, finer, honey, non, 
don, ton, ion, yon, vane, vine, June, gene, mane, mine, mono, bane, 
bony, pane, pine, pony, sane, sine, fade, fate, food, foot, vote, face, 
foci, fuse, fare, fire, four, free, fame, foam, fume, file, flee, foil, 
fool, foul, fowl, fake, lane, line, fife, five, fogs, find, fund, fans, 
fins, fang, cane, nine, dine, dune, tune, wane, wine
jazzy> \q

The approach I propose is to construct a graph of our input phrase ('a fone book'), adding vertices corresponding to the original word and each of its spelling suggestions, as shown below. The edge weights represent the conditional probability of the edge target B following the edge source A, i.e. P(B|A). The numbers are all cooked up for this example, but I describe a way to compute them further down. I still need to populate my database tables with real data; I will describe this in a subsequent post.

What you will immediately notice is that we cannot prune the graph as we encounter each word in the phrase, i.e. we cannot select the most likely suggestion as we parse each word, since the "best path" is the most probable path through the graph from the start vertex to the finish vertex.

Since we are going to use Dijkstra's shortest path algorithm to find the shortest path (aka Graph Geodesic) through the graph, we need to convert the edge probabilities to a weight function w_{A,B}, like so:

  w_{A,B} = 1 - P(B|A)
  where:
    w_{A,B} = cost to get from vertex A to B
    P(B|A) = probability of the occurrence of B given A

The probability P(B|A) can be computed as the number of times A and B co-occur in our dataset divided by the number of times word A occurs in the dataset, as shown below:

  By the definition of conditional probability:
    P(B ∩ A) = P(B|A) * P(A)
  so:
    P(B|A) = P(B ∩ A) / P(A)
           = N(B ∩ A) / N(A)
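
The two formulas above combine into a single function from raw counts to an edge weight. Here is a minimal sketch of that arithmetic; the class and method names are made up for illustration, and the counts in main are invented rather than measured.

```java
// Hypothetical helper, not part of the code in this post: turns raw
// counts into the edge weight w = 1 - N(A,B) / N(A).
class EdgeWeightSketch {

  /** P(B|A) = N(A,B) / N(A); taken to be 0 if A was never seen. */
  static double conditionalProbability(long nA, long nAB) {
    if (nA <= 0L) {
      return 0.0D; // unseen A: no evidence for any B following it
    }
    return (double) nAB / (double) nA;
  }

  /** Edge weight for the shortest path search: small when B is likely given A. */
  static double edgeWeight(long nA, long nAB) {
    return 1.0D - conditionalProbability(nA, nAB);
  }

  public static void main(String[] args) {
    // invented counts: A seen 100 times, A and B seen together 13 times
    System.out.println(edgeWeight(100L, 13L));
    // A never seen: the weight falls back to the maximum of 1.0
    System.out.println(edgeWeight(0L, 0L));
  }
}
```

Guarding the N(A) = 0 case up front avoids a divide by zero and gives unknown words the most expensive edge.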

To get experimental values for N(A) and N(B ∩ A), we will need to extract data from actual search terms used by users from our Apache access logs and populate the following tables:

mysql> desc occur_a;
+---------+---------------+------+-----+---------+-------+
| Field   | Type          | Null | Key | Default | Extra |
+---------+---------------+------+-----+---------+-------+
| word    | varchar(32)   | NO   | PRI | NULL    |       | 
| n_words | mediumint(11) | NO   |     | NULL    |       | 
+---------+---------------+------+-----+---------+-------+

mysql> desc occur_ab;
+---------+---------------+------+-----+---------+-------+
| Field   | Type          | Null | Key | Default | Extra |
+---------+---------------+------+-----+---------+-------+
| word_a  | varchar(32)   | NO   | PRI | NULL    |       | 
| word_b  | varchar(32)   | NO   | PRI | NULL    |       | 
| n_words | mediumint(11) | NO   |     | NULL    |       | 
+---------+---------------+------+-----+---------+-------+
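
One possible way to tally these counts, once the search phrases have been pulled out of the logs, is sketched here. Everything in it is an assumption made for illustration: the class name, the in-memory maps standing in for the two tables, and in particular the choice to count only adjacent words as co-occurring. The pair key is alphabetized so that word_a is lexically ahead of word_b.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

// Hypothetical sketch: tallies the counts that would go into the occur_a
// and occur_ab tables. Treating "co-occurrence" as adjacency within a
// query is an assumption made for this example.
class OccurrenceCounter {

  final Map<String, Integer> occurA = new TreeMap<String, Integer>();
  final Map<String, Integer> occurAB = new TreeMap<String, Integer>();

  void addQuery(String query) {
    String[] tokens = query.toLowerCase().trim().split("[ -]+");
    for (int i = 0; i < tokens.length; i++) {
      if (tokens[i].isEmpty()) {
        continue; // empty query or leading delimiter
      }
      occurA.merge(tokens[i], 1, Integer::sum);
      if (i + 1 < tokens.length) {
        // store the pair alphabetized: word_a lexically ahead of word_b
        String[] pair = new String[] {tokens[i], tokens[i + 1]};
        Arrays.sort(pair);
        occurAB.merge(pair[0] + " " + pair[1], 1, Integer::sum);
      }
    }
  }

  public static void main(String[] args) {
    OccurrenceCounter counter = new OccurrenceCounter();
    List<String> queries = Arrays.asList(
      "a phone book", "phone book", "a fine dress");
    for (String query : queries) {
      counter.addQuery(query);
    }
    System.out.println(counter.occurA);
    System.out.println(counter.occurAB);
  }
}
```

Each map entry would then become one row: the key supplies word (or word_a and word_b) and the value supplies n_words.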

Without any data in the database tables, the code degrades very gracefully. It just returns what we typed in, as you can see below. This happens because we always insert the original word in the first position of the suggestion list returned by Jazzy, so the "best" among equals is the one that comes first. As before, the Java code for this CLI is provided later in the article.

sujit@sirocco:~$ mvn -o exec:java \
  -Dexec.mainClass=com.mycompany.myapp.Shell
spell-check> a fone book
a fone book
spell-check> a fone dress
a fone dress

Once some (still cooked up) occurrence data is entered manually into these tables...

mysql> select * from occur_a;
+-------+---------+
| word  | n_words |
+-------+---------+
| a     |     100 | 
| book  |      43 | 
| dress |      10 | 
| fine  |      12 | 
| phone |      18 | 
+-------+---------+
5 rows in set (0.00 sec)

mysql> select * from occur_ab;
+--------+--------+---------+
| word_a | word_b | n_words |
+--------+--------+---------+
| a      | fine   |       8 | 
| a      | phone  |      13 | 
| book   | phone  |      12 | 
| dress  | fine   |       7 | 
+--------+--------+---------+
4 rows in set (0.00 sec)

...our spelling corrector behaves more intelligently. The beauty of this approach is that its intelligence can be localized to your industry. So for example, if you were in the clothing business, your search terms are more likely to include fine dresses than phone books, and therefore P(dress|fine) would be higher than P(dress|phone).

sujit@sirocco:~$ mvn -o exec:java \
  -Dexec.mainClass=com.mycompany.myapp.Shell
spell-check> a fone book
a phone book
spell-check> a fone dress
a fine dress
spell-check> fone book
phone book
spell-check> fone dress
fine dress
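
To see where these answers come from, the path costs can be worked out by hand from the cooked up counts. The sketch below does that in memory, under some stated assumptions: the class and method names are hypothetical, the candidate list for 'fone' is a hand-picked subset of the Jazzy suggestions, and because the graph is layered (one layer of candidates per word) it uses a simple layer-by-layer dynamic programme in place of Dijkstra's algorithm, which finds the same cheapest path on a graph of this shape. Each layer is assumed to be non-empty.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

// Hypothetical sketch, not the SpellingCorrector from this post: same
// edge weights, but the counts live in memory and the candidate lists
// are hand-picked instead of coming from Jazzy.
class PhrasePathSketch {

  static final Map<String, Integer> OCCUR_A = new HashMap<String, Integer>();
  static final Map<String, Integer> OCCUR_AB = new HashMap<String, Integer>();
  static long sumWords = 0L;

  static {
    OCCUR_A.put("a", 100);
    OCCUR_A.put("book", 43);
    OCCUR_A.put("dress", 10);
    OCCUR_A.put("fine", 12);
    OCCUR_A.put("phone", 18);
    OCCUR_AB.put("a fine", 8);
    OCCUR_AB.put("a phone", 13);
    OCCUR_AB.put("book phone", 12);
    OCCUR_AB.put("dress fine", 7);
    for (int n : OCCUR_A.values()) {
      sumWords += n;
    }
  }

  // w = 1 - P(B) out of START (prev == null), else 1 - N(A,B) / N(A)
  static double weight(String prev, String cur) {
    if (prev == null) {
      int nb = OCCUR_A.getOrDefault(cur, 0);
      return 1.0D - (double) nb / sumWords;
    }
    int na = OCCUR_A.getOrDefault(prev, 0);
    if (na == 0) {
      return 1.0D; // unknown A: most expensive edge
    }
    String[] pair = new String[] {prev, cur};
    Arrays.sort(pair); // pairs are stored alphabetized
    int nab = OCCUR_AB.getOrDefault(pair[0] + " " + pair[1], 0);
    return 1.0D - (double) nab / na;
  }

  // layers[i] holds the candidates for the i-th word of the phrase
  static String bestPath(String[][] layers) {
    int n = layers.length;
    if (n == 0) {
      return "";
    }
    double[][] cost = new double[n][];
    int[][] back = new int[n][];
    for (int i = 0; i < n; i++) {
      cost[i] = new double[layers[i].length];
      back[i] = new int[layers[i].length];
      for (int j = 0; j < layers[i].length; j++) {
        if (i == 0) {
          cost[i][j] = weight(null, layers[i][j]);
          continue;
        }
        cost[i][j] = Double.POSITIVE_INFINITY;
        for (int k = 0; k < layers[i - 1].length; k++) {
          double c = cost[i - 1][k] + weight(layers[i - 1][k], layers[i][j]);
          if (c < cost[i][j]) { // strict '<' keeps the earlier candidate on ties
            cost[i][j] = c;
            back[i][j] = k;
          }
        }
      }
    }
    // every edge into END costs 1.0, so END just picks the cheapest last word
    int best = 0;
    for (int j = 1; j < cost[n - 1].length; j++) {
      if (cost[n - 1][j] < cost[n - 1][best]) {
        best = j;
      }
    }
    String[] words = new String[n];
    for (int i = n - 1; i >= 0; i--) {
      words[i] = layers[i][best];
      best = back[i][best];
    }
    return String.join(" ", words);
  }

  public static void main(String[] args) {
    String[] fone = new String[] {"fone", "phone", "fine"};
    System.out.println(bestPath(new String[][] {{"a"}, fone, {"book"}}));
    System.out.println(bestPath(new String[][] {{"a"}, fone, {"dress"}}));
  }
}
```

For 'a fone dress', for example, the edge a -> fine costs 1 - 8/100 = 0.92 and fine -> dress costs 1 - 7/12, roughly 0.42, while a -> phone costs 0.87 but phone -> dress costs the full 1.0 because that pair never co-occurs. The path through 'fine' is therefore cheaper overall, even though 'phone' is the cheaper first step, which is why the graph cannot be pruned word by word.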

Here is the code for the actual Spelling corrector. It uses Jazzy for its word suggestions, and JGraphT to construct a graph and run Dijkstra's shortest path algorithm (included in the JGraphT library) to find the most likely path based on word co-occurrence probabilities.

// Source: src/main/java/com/mycompany/myapp/SpellingCorrector.java
package com.mycompany.myapp;

import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import javax.sql.DataSource;

import org.apache.commons.lang.StringUtils;
import org.jgrapht.alg.DijkstraShortestPath;
import org.jgrapht.graph.ClassBasedEdgeFactory;
import org.jgrapht.graph.DefaultWeightedEdge;
import org.jgrapht.graph.SimpleDirectedWeightedGraph;
import org.springframework.dao.IncorrectResultSizeDataAccessException;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.datasource.DriverManagerDataSource;

import com.swabunga.spell.engine.SpellDictionary;
import com.swabunga.spell.engine.SpellDictionaryHashMap;
import com.swabunga.spell.engine.Word;

/**
 * Uses probability of word-collocations to determine best phrases to be
 * returned from a SpellingCorrector for multi-word misspelt queries.
 */
public class SpellingCorrector {

  private static final int SCORE_THRESHOLD = 200;
  private static final String DICTIONARY_FILENAME = 
    "src/main/resources/english.0";
  
  private long occurASumWords = 1L;
  private JdbcTemplate jdbcTemplate;
  
  @SuppressWarnings("unchecked")
  public String getSuggestion(String input) throws Exception {
    // initialize Jazzy spelling dictionary
    SpellDictionary dictionary = new SpellDictionaryHashMap(
      new File(DICTIONARY_FILENAME));
    // initialize database connection
    DataSource dataSource = new DriverManagerDataSource(
      "com.mysql.jdbc.Driver", "jdbc:mysql://localhost:3306/spelldb", 
      "foo", "secret");
    jdbcTemplate = new JdbcTemplate(dataSource);
    occurASumWords = jdbcTemplate.queryForLong(
      "select sum(n_words) from occur_a");
    if (occurASumWords == 0L) {
      // just a hack to prevent divide by zero for empty db
      occurASumWords = 1L;
    }
    // set up graph and create root vertex
    final SimpleDirectedWeightedGraph<SuggestedWord,DefaultWeightedEdge> g = 
      new SimpleDirectedWeightedGraph<SuggestedWord,DefaultWeightedEdge>(
      new ClassBasedEdgeFactory<SuggestedWord,DefaultWeightedEdge>(
      DefaultWeightedEdge.class));
    SuggestedWord startVertex = new SuggestedWord("START", 0);
    g.addVertex(startVertex);
    // set up variables to hold results of previous iteration
    List<SuggestedWord> prevVertices = 
      new ArrayList<SuggestedWord>();
    List<SuggestedWord> currentVertices = 
      new ArrayList<SuggestedWord>();
    int tokenId = 1;
    prevVertices.add(startVertex);
    // parse the string
    String[] tokens = input.toLowerCase().split("[ -]");
    for (String token : tokens) {
      // build up spelling suggestions for individual word
      List<String> possibleTokens = new ArrayList<String>();
      if (token.trim().length() <= 2) {
        // people usually don't make mistakes for words of 2 characters
        // or less, just pass it back unchanged
        possibleTokens.add(token);
      } else if (dictionary.isCorrect(token)) {
        // no need to find suggestions, token is recognized as valid spelling
        possibleTokens.add(token);
      } else {
        possibleTokens.add(token);
        List<Word> words = 
          dictionary.getSuggestions(token, SCORE_THRESHOLD);
        for (Word word : words) {
          possibleTokens.add(word.getWord());
        }
      }
      // populate the graph with these values
      for (String possibleToken : possibleTokens) {
        SuggestedWord currentVertex = 
          new SuggestedWord(possibleToken, tokenId); 
        g.addVertex(currentVertex);
        currentVertices.add(currentVertex);
        for (SuggestedWord prevVertex : prevVertices) {
          DefaultWeightedEdge edge = new DefaultWeightedEdge();
          double weight = computeEdgeWeight(
            prevVertex.token, currentVertex.token);
          g.setEdgeWeight(edge, weight);
          g.addEdge(prevVertex, currentVertex, edge);
        }
      }
      prevVertices.clear();
      prevVertices.addAll(currentVertices);
      currentVertices.clear();
      tokenId++;
    } // for token : tokens
    // finally set the end vertex
    SuggestedWord endVertex = new SuggestedWord("END", tokenId);
    g.addVertex(endVertex);
    for (SuggestedWord prevVertex : prevVertices) {
      DefaultWeightedEdge edge = new DefaultWeightedEdge();
      g.setEdgeWeight(edge, 1.0D);
      g.addEdge(prevVertex, endVertex, edge);
    }
    // find shortest path between START and END
    DijkstraShortestPath<SuggestedWord,DefaultWeightedEdge> dijkstra =
      new DijkstraShortestPath<SuggestedWord, DefaultWeightedEdge>(
      g, startVertex, endVertex);
    List<DefaultWeightedEdge> edges = dijkstra.getPathEdgeList();
    List<String> bestMatch = new ArrayList<String>();
    for (DefaultWeightedEdge edge : edges) {
      if (startVertex.equals(g.getEdgeSource(edge))) {
        // skip the START vertex
        continue;
      }
      bestMatch.add(g.getEdgeSource(edge).token);
    }
    return StringUtils.join(bestMatch.iterator(), " ");
  }

  private Double computeEdgeWeight(String prevToken, String currentToken) {
    if (prevToken.equals("START")) {
      // this is the first word, return 1-P(B)
      try {
        double nb = (Double) jdbcTemplate.queryForObject(
          "select n_words/? from occur_a where word = ?", 
          new Object[] {occurASumWords, currentToken}, Double.class);
        return 1.0D - nb;
      } catch (IncorrectResultSizeDataAccessException e) {
        // in case there is no match, then we should return weight of 1
        return 1.0D;
      }
    }
    double na = 0.0D;
    try {
      na = (Double) jdbcTemplate.queryForObject(
        "select n_words from occur_a where word = ?", 
        new String[] {prevToken}, Double.class);
    } catch (IncorrectResultSizeDataAccessException e) {
      // no match, should be 0
      na = 0.0D;
    }
    if (na == 0.0D) {
      // if N(A) == 0, A does not exist, and hence N(A ^ B) == 0 too,
      // so we guard against a DivideByZero and an additional useless
      // computation.
      return 1.0D;
    }
    // for the A^B lookup, alphabetize so A is lexically ahead of B
    // since that is the way we store it in the database
    String[] tokens = new String[] {prevToken, currentToken};
    Arrays.sort(tokens); // alphabetize before lookup
    double nba = 0.0D;
    try {
      nba = (Double) jdbcTemplate.queryForObject(
        "select n_words from occur_ab where word_a = ? and word_b = ?",
        tokens, Double.class);
    } catch (IncorrectResultSizeDataAccessException e) {
      // no result found so N(B^A) = 0
      nba = 0.0D;
    }
    return 1.0D - (nba / na);
  }

  /**
   * Holder for the graph vertex information.
   */
  private class SuggestedWord {
    public String token;
    public int id;
    
    public SuggestedWord(String token, int id) {
      this.token = token;
      this.id = id;
    }
    
    @Override
    public int hashCode() {
      return toString().hashCode();
    }
    
    @Override
    public boolean equals(Object obj) {
      if (!(obj instanceof SuggestedWord)) {
        return false;
      }
      SuggestedWord that = (SuggestedWord) obj;
      return (this.id == that.id && 
        this.token.equals(that.token));
    }
    
    @Override
    public String toString() {
      return id + ":" + token;
    }
  };
}

The CLI proved to be very useful for checking out assumptions quickly when I was developing the algorithm. It's quite simple: it just wraps the functionality within a JLine ConsoleReader. I have included it here for completeness and to illustrate how easy it is to build. Depending on the presence of a command line argument, it can function either as an interface over the Jazzy dictionary or to the Phrase Spelling Corrector described in this post.

// Source: src/main/java/com/mycompany/myapp/Shell.java
package com.mycompany.myapp;

import java.io.File;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import jline.ConsoleReader;

import org.apache.commons.lang.StringUtils;

import com.swabunga.spell.engine.SpellDictionary;
import com.swabunga.spell.engine.SpellDictionaryHashMap;
import com.swabunga.spell.engine.Word;

public class Shell {

  private final int SPELL_CHECK_THRESHOLD = 250;

  public Shell() throws Exception {
    ConsoleReader reader = new ConsoleReader(
      System.in, new PrintWriter(System.out));
    SpellingCorrector spellingCorrector = new SpellingCorrector();
    for (;;) {
      String line = reader.readLine("spell-check> ");
      if ("\\q".equals(line)) {
        break;
      }
      System.out.println(spellingCorrector.getSuggestion(line));
    }
  }

  // === this is really for exploratory testing purposes ===
  
  /**
   * Wrapper over Jazzy native spell checking functionality.
   * @param b always true (to differentiate from the new ctor).
   * @throws Exception if one is thrown.
   */
  public Shell(boolean b) throws Exception {
    ConsoleReader reader = new ConsoleReader(
      System.in, new PrintWriter(System.out));
    SpellDictionary dictionary = new SpellDictionaryHashMap(
      new File("src/main/resources/english.0"));
    for (;;) {
      String line = reader.readLine("jazzy> ");
      if ("\\q".equals(line)) {
        break;
      }
      String suggestions = suggest(dictionary, line);
      System.out.println(suggestions);
    }
  }
  
  /**
   * Looks up single words from Jazzy's English dictionary.
   * @param dictionary the dictionary object to look up.
   * @param incorrect the suspected misspelt word.
   * @return if the incorrect word is correct according to 
   * Jazzy's dictionary, then it is returned, else a set of possible
   * corrections is returned. If no possible corrections were found, 
   * this method returns (no suggestions).
   */
  @SuppressWarnings("unchecked")
  private String suggest(SpellDictionary dictionary, String incorrect) {
    if (dictionary.isCorrect(incorrect)) {
      // return the entered word
      return incorrect;
    }
    List<Word> words = dictionary.getSuggestions(
      incorrect, SPELL_CHECK_THRESHOLD);
    List<String> suggestions = new ArrayList<String>();
    final Map<String,Integer> costs = 
      new HashMap<String,Integer>();
    for (Word word : words) {
      costs.put(word.getWord(), word.getCost());
      suggestions.add(word.getWord());
    }
    if (suggestions.size() == 0) {
      return "(no suggestions)";
    }
    Collections.sort(suggestions, new Comparator<String>() {
      public int compare(String s1, String s2) {
        Integer cost1 = costs.get(s1);
        Integer cost2 = costs.get(s2);
        return cost1.compareTo(cost2);
      }
    });
    return StringUtils.join(suggestions.iterator(), ", ");
  }
  
  public static void main(String[] args) throws Exception {
    if (args.length == 1) {
      // word mode
      new Shell(true);
    } else {
      new Shell();
    }
  }
}

Note that I still don't know whether this works well for a large set of misspelt phrases. I need to put this through a lot more real data to say that with any degree of certainty. It is also fairly slow in my development testing. I have a few ideas as to how that can be improved, although I will attempt them after I have some real data to play with. As always, any suggestions/corrections are much appreciated.

Update 2009-04-26: In recent posts, I have been building on code written and described in previous posts, so there were (and rightly so) quite a few requests for the code. So I've created a project on Sourceforge to host the code. You will find the complete source code built so far in the project's SVN repository.

Saturday, October 18, 2008

IR Math in Java : Experiments in Clustering

As I mentioned last week, I have been trying to teach myself clustering algorithms. Having used the Carrot API to do some clustering work about 6 months ago, I have been curious about the clustering algorithms themselves. Carrot offers you a choice of several built-in clustering algorithms, so you just use one depending on your needs. Obviously, this presupposes that you know enough about the algorithms themselves to make the decision (which wasn't the case for me, unfortunately). So what better way to learn than to implement the algorithms in code? So this post covers some popular clustering algorithms implemented in Java.

The code in this post is based on algorithms from various sources. Sources are mentioned in the individual sections as well as listed in the references section below. I describe my implementations and test results for the following algorithms:

  1. K-Means Clustering
  2. Quality Threshold (QT) Clustering
  3. Simulated Annealing Clustering
  4. Nearest Neighbor Clustering
  5. Genetic Algorithm Clustering

K-Means Algorithm

For K-Means clustering, one seeds a fixed number of clusters, k, with random seed documents from the collection. An estimate of k for a document collection of N documents is given by the heuristic below:

  k = floor(sqrt(N))
  where:
    N = number of documents in collection.
    k = the estimate of the number of clusters.

The algorithm starts by seeding the clusters with one random document each from the collection. It computes the centroid μ for each cluster using the following formula:

  μ = sqrt(sum(x_i^2)) / N
  where:
    μ = centroid of a cluster.
    N = number of documents in the collection.
    x_i = the i-th document vector.

For each document, we compute the similarity between the document and the centroid of each cluster, and assign the document to the cluster whose centroid is most similar. The similarity measure used is Cosine Similarity - I use code from my article on similarity metrics.
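
The assignment step can be sketched in a few lines. This is a simplified illustration, not the clusterer from this post: documents and centroids are plain double[] term vectors of equal length rather than Jama matrices, the centroids are simply given, and all names are hypothetical.

```java
// Hypothetical sketch: documents and centroids are plain term vectors
// (double[]) here rather than the Jama matrices used in this post.
class NearestCentroidSketch {

  /** Heuristic from the text: k = floor(sqrt(N)). */
  static int estimateK(int numDocs) {
    return (int) Math.floor(Math.sqrt(numDocs));
  }

  /** Cosine similarity; taken to be 0 if either vector is all zeros. */
  static double cosine(double[] a, double[] b) {
    double dot = 0.0D;
    double normA = 0.0D;
    double normB = 0.0D;
    for (int i = 0; i < a.length; i++) {
      dot += a[i] * b[i];
      normA += a[i] * a[i];
      normB += b[i] * b[i];
    }
    if (normA == 0.0D || normB == 0.0D) {
      return 0.0D;
    }
    return dot / (Math.sqrt(normA) * Math.sqrt(normB));
  }

  /** Index of the most similar centroid; the first one wins ties. */
  static int mostSimilar(double[] doc, double[][] centroids) {
    int best = 0;
    double bestSimilarity = Double.NEGATIVE_INFINITY;
    for (int j = 0; j < centroids.length; j++) {
      double similarity = cosine(doc, centroids[j]);
      if (similarity > bestSimilarity) {
        best = j;
        bestSimilarity = similarity;
      }
    }
    return best;
  }

  public static void main(String[] args) {
    double[][] centroids = {{0.0D, 1.0D, 0.0D}, {1.0D, 0.0D, 0.0D}};
    double[] doc = {1.0D, 0.0D, 0.1D};
    System.out.println("k for 7 documents: " + estimateK(7));
    System.out.println("assigned to cluster " + mostSimilar(doc, centroids));
  }
}
```

Starting the running maximum at negative infinity means a document with zero similarity to every centroid still lands in cluster 0 by the tie rule rather than by accident.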

At the end of this step, we have a list of clusters fully populated with the documents from the collection. We then recompute the centroid based on the documents in these clusters, and repeat the above step until the new cluster is no better than the previous cluster.

Although common sense would suggest that we should use the same measure, i.e. Euclidean distance, for both similarity and centroid calculations, this does not work in practice - there will be no improvement in the cluster after the initial population. So it is important to use a different measure to calculate similarity.

Here is the code for my K-Means clusterer. The algorithm used is the one in the TMAP book[1].

// Source: src/main/java/com/mycompany/myapp/clustering/KMeansClusterer.java
package com.mycompany.myapp.clustering;

import java.util.ArrayList;
import java.util.List;

import org.apache.commons.collections15.CollectionUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import Jama.Matrix;

public class KMeansClusterer {

  private final Log log = LogFactory.getLog(getClass());
  
  private String[] initialClusterAssignments = null;
  
  public void setInitialClusterAssignments(String[] documentNames) {
    this.initialClusterAssignments = documentNames;
  }
  
  public List<Cluster> cluster(DocumentCollection collection) {
    int numDocs = collection.size();
    int numClusters = 0;
    if (initialClusterAssignments == null) {
      // compute initial cluster assignments
      IdGenerator idGenerator = new IdGenerator(numDocs);
      numClusters = (int) Math.floor(Math.sqrt(numDocs));
      initialClusterAssignments = new String[numClusters];
      for (int i = 0; i < numClusters; i++) {
        int docId = idGenerator.getNextId();
        initialClusterAssignments[i] = collection.getDocumentNameAt(docId);
      }
    } else {
      numClusters = initialClusterAssignments.length;
    }

    // build initial clusters
    List<Cluster> clusters = new ArrayList<Cluster>();
    for (int i = 0; i < numClusters; i++) {
      Cluster cluster = new Cluster("C" + i);
      cluster.addDocument(initialClusterAssignments[i], 
        collection.getDocument(initialClusterAssignments[i]));
      clusters.add(cluster);
    }
    log.debug("..Initial clusters:" + clusters.toString());

    List<Cluster> prevClusters = new ArrayList<Cluster>();

    // Repeat until termination conditions are satisfied
    for (;;) {
      // For every cluster i, (re-)compute the centroid based on the
      // current member documents. (We have moved 2.2 above 2.1 because
      // this needs to be done before every iteration).
      Matrix[] centroids = new Matrix[numClusters];
      for (int i = 0; i < numClusters; i++) {
        Matrix centroid = clusters.get(i).getCentroid();
        centroids[i] = centroid;
      }
      // For every document d, find the cluster i whose centroid is 
      // most similar, assign d to cluster i. (If a document is 
      // equally similar from all centroids, then just dump it into 
      // cluster 0).
      for (int i = 0; i < numDocs; i++) {
        int bestCluster = 0;
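        // Double.MIN_VALUE is the smallest positive double, not the most
        // negative one, so start from negative infinity instead
        double maxSimilarity = Double.NEGATIVE_INFINITY;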
        Matrix document = collection.getDocumentAt(i);
        String docName = collection.getDocumentNameAt(i);
        for (int j = 0; j < numClusters; j++) {
          double similarity = clusters.get(j).getSimilarity(document);
          if (similarity > maxSimilarity) {
            bestCluster = j;
            maxSimilarity = similarity;
          }
        }
        for (Cluster cluster : clusters) {
          if (cluster.getDocument(docName) != null) {
            cluster.removeDocument(docName);
          }
        }
        clusters.get(bestCluster).addDocument(docName, document);
      }
      log.debug("..Intermediate clusters: " + clusters.toString());

      // Check for termination -- minimal or no change to the assignment
      // of documents to clusters.
      if (CollectionUtils.isEqualCollection(clusters, prevClusters)) {
        break;
      }
      prevClusters.clear();
      prevClusters.addAll(clusters);
    }
    // Return list of clusters
    log.debug("..Final clusters: " + clusters.toString());
    return clusters;
  }
}

The K-Means algorithm seems to be reasonably fast. However, the problem with it is that the solution is very sensitive to initial cluster seeding. Here are some results I got from using a random number generator to seed the clusters.

==== Results from K-Means clustering ==== (seeds: [D5,D2])
C0:[D1, D3, D4, D5, D6]
C1:[D2, D7]

==== Results from K-Means clustering ==== (seeds: [D3,D2])
C0:[D1, D3, D4, D5, D6]
C1:[D2, D7]

However, if I seed the clusters manually, using the results from my cluster visualization article last week, then I get results which look reasonable:

==== Results from K-Means clustering ==== (seeds: [D1,D3])
C0:[D1, D2, D5, D6, D7]
C1:[D3, D4]

Quality Threshold (QT) Algorithm

The Quality Threshold (QT) algorithm uses a maximum diameter settable by the user to cluster documents. The first cluster is built with the first document in the collection. As long as other documents are close enough to be within the diameter specified, they are added to the cluster. Once all documents are read, the documents that have been added to the cluster are set aside and the algorithm repeated recursively on the rest of the document collection. The program stops when there are no more documents. The number of levels that the program recurses down to corresponds to the number of clusters formed as a result.

The distance between a document and a cluster is computed using Complete Linkage Distance, i.e. the distance between the document and the furthest document in the cluster.
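
As a rough illustration of complete linkage, here is a small sketch. It assumes Euclidean distance between plain double[] vectors, which may not be the measure the Cluster class in this post uses, and it assumes an empty cluster has distance 0 so that the first document is always admitted. The names are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch of complete linkage distance. Euclidean distance
// is an assumption made for this example.
class CompleteLinkageSketch {

  static double euclidean(double[] a, double[] b) {
    double sum = 0.0D;
    for (int i = 0; i < a.length; i++) {
      double d = a[i] - b[i];
      sum += d * d;
    }
    return Math.sqrt(sum);
  }

  /** Distance from doc to its furthest cluster member; 0 for an empty cluster. */
  static double completeLinkage(double[] doc, List<double[]> members) {
    double max = 0.0D;
    for (double[] member : members) {
      max = Math.max(max, euclidean(doc, member));
    }
    return max;
  }

  public static void main(String[] args) {
    // one QT-style pass: admit documents that stay within the diameter
    double maxDiameter = 0.4D;
    double[][] docs = {{0.0D, 0.0D}, {0.1D, 0.2D}, {0.9D, 0.9D}};
    List<double[]> cluster = new ArrayList<double[]>();
    for (double[] doc : docs) {
      if (completeLinkage(doc, cluster) < maxDiameter) {
        cluster.add(doc);
      }
    }
    System.out.println("documents admitted: " + cluster.size());
  }
}
```

Because the furthest member decides, adding a document can never bring two members of the cluster further apart than the diameter allows.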

Here is the code for my QT Clustering program. The algorithm used was from this Wikipedia article[2].

// Source: src/main/java/com/mycompany/myapp/clustering/QtClusterer.java
package com.mycompany.myapp.clustering;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.commons.math.stat.descriptive.moment.Mean;

import Jama.Matrix;

import com.mycompany.myapp.similarity.CosineSimilarity;

public class QtClusterer {

  private final Log log = LogFactory.getLog(getClass());
  
  private double maxDiameter;
  private boolean randomizeDocuments;
  
  public void setMaxRadius(double maxRadius) {
    this.maxDiameter = maxRadius * 2.0D;
  }
  
  public void setRandomizeDocuments(boolean randomizeDocuments) {
    this.randomizeDocuments = randomizeDocuments;
  }
  
  public List<Cluster> cluster(DocumentCollection collection) {
    if (randomizeDocuments) {
      collection.shuffle();
    }
    List<Cluster> clusters = new ArrayList<Cluster>();
    Set<String> clusteredDocNames = new HashSet<String>();
    cluster_r(collection, clusters, clusteredDocNames, 0);
    return clusters;
  }

  private void cluster_r(DocumentCollection collection, 
      List<Cluster> clusters, 
      Set<String> clusteredDocNames, int level) {
    int numDocs = collection.size();
    int numClustered = clusteredDocNames.size();
    if (numDocs == numClustered) {
      return;
    }
    Cluster cluster = new Cluster("C" + level);
    for (int i = 0; i < numDocs; i++) {
      Matrix document = collection.getDocumentAt(i);
      String docName = collection.getDocumentNameAt(i);
      if (clusteredDocNames.contains(docName)) {
        continue;
      }
      log.debug("max dist=" + cluster.getCompleteLinkageDistance(document));
      if (cluster.getCompleteLinkageDistance(document) < maxDiameter) {
        cluster.addDocument(docName, document);
        clusteredDocNames.add(docName);
      }
    }
    if (cluster.size() == 0) {
      log.warn("No clusters added at level " + level + ", check diameter");
    }
    clusters.add(cluster);
    cluster_r(collection, clusters, clusteredDocNames, level + 1);
  }
}

The algorithm is easy to understand, and always returns the same clusters for a given document order (it is deterministic unless the collection is shuffled first). Using a diameter threshold of 0.4, I was able to get the two clusters shown below:

==== Results from Qt Clustering ==== (diameter: 0.4D)
C0:[D6, D7]
C1:[D2, D1, D4, D5, D3]

Simulated Annealing Algorithm

The Simulated Annealing clustering algorithm is based on the Annealing process in metallurgy, where it is used to harden metals by cooling the molten metal in steps.

The algorithm starts by setting an initial "temperature", and builds an initial set of clusters using some population process. Mod-based partitioning is used to populate the initial clusters, although a degree of randomness can be added by shuffling the collection.
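
As a rough illustration of mod-based partitioning, the sketch below assigns document i to cluster (i % k), with k taken as the floor of the square root of the number of documents. The class name `ModPartitionSketch` and the plain lists of names are assumptions made for this example; they are not part of the clusterer code in this post.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class ModPartitionSketch {

  // Assigns the document at position i to cluster (i % k), where
  // k = floor(sqrt(n)). The Math.max guard (an addition for this sketch)
  // keeps k at 1 or more for very small inputs.
  static List<List<String>> partition(List<String> docNames) {
    int k = Math.max(1, (int) Math.floor(Math.sqrt(docNames.size())));
    List<List<String>> clusters = new ArrayList<List<String>>();
    for (int i = 0; i < k; i++) {
      clusters.add(new ArrayList<String>());
    }
    for (int i = 0; i < docNames.size(); i++) {
      clusters.get(i % k).add(docNames.get(i));
    }
    return clusters;
  }

  public static void main(String[] args) {
    List<String> docs =
      Arrays.asList("D1", "D2", "D3", "D4", "D5", "D6", "D7");
    // prints [[D1, D3, D5, D7], [D2, D4, D6]]
    System.out.println(partition(docs));
  }
}
```

With seven documents, k works out to 2, so the documents alternate between the two clusters. Shuffling the list before partitioning would give a different starting point.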

At each temperature setting, we exchange two random documents between two random clusters for a specified number of times. We then check to see if the solution improved or degraded based on the average radius of the cluster. Depending on the current temperature setting, we compute a probability that we should accept the degraded solution (aka downhill move). The probability is given by:

  P = exp((S_{i-1} - S_i) / T)
  where:
    S_{i-1} = value of solution at loop (i-1)
    S_i = value of solution at loop (i)
    T = current temperature setting

If the probability is at least the specified threshold, we accept the downhill move; otherwise we go back to the previous set of clusters. Once the exchanges for the current temperature are done, we lower the temperature according to the cooling schedule (the code below divides it by 10) and return to exchanging random documents between random clusters. We keep doing this until the temperature reaches a certain cutoff point.

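From the probability equation above, the algorithm allows more downhill moves at the beginning, when the temperature is high, and fewer towards its end, when the temperature gets lower (for a downhill move the exponent is negative, so P shrinks as T decreases). Accepting some downhill moves allows more exploration of the solution space than the K-Means or QT clustering methods.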
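
To see how the temperature affects the acceptance probability, here is a minimal sketch that evaluates the formula for one downhill move at several temperatures. The solution values 1.0 and 1.5 are made up for illustration, and the class and method names are hypothetical.

```java
class AcceptanceProbabilitySketch {

  // P = exp((previous - current) / T). When the solution value is an
  // average radius to be minimised, a downhill move has current > previous,
  // so the exponent is negative and P lies between 0 and 1.
  static double acceptanceProbability(
      double previous, double current, double temperature) {
    return Math.exp((previous - current) / temperature);
  }

  public static void main(String[] args) {
    double previous = 1.0D; // hypothetical value of the solution at loop i-1
    double current = 1.5D;  // hypothetical (worse) value at loop i
    for (double t = 100.0D; t >= 1.0D; t /= 10.0D) {
      System.out.printf("T=%.0f P=%.4f%n",
        t, acceptanceProbability(previous, current, t));
    }
  }
}
```

For this move the probability is roughly 0.995 at T = 100, 0.951 at T = 10 and 0.607 at T = 1. With a threshold of 0.7, the same degradation would be accepted at the two higher temperatures and rejected at the lowest one.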

The code for my Simulated Annealing Clusterer is shown below. The algorithm comes from the TMAP book[1].

// Source: src/main/java/com/mycompany/myapp/clustering/SimulatedAnnealingClusterer.java
package com.mycompany.myapp.clustering;

import java.util.ArrayList;
import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

public class SimulatedAnnealingClusterer {

  private final Log log = LogFactory.getLog(getClass());

  private boolean randomizeDocs;
  private double initialTemperature;
  private double finalTemperature;
  private double downhillProbabilityCutoff;
  private int numberOfLoops;
  
  public void setRandomizeDocs(boolean randomizeDocs) {
    this.randomizeDocs = randomizeDocs;
  }
  
  public void setInitialTemperature(double initialTemperature) {
    this.initialTemperature = initialTemperature;
  }
  
  public void setFinalTemperature(double finalTemperature) {
    this.finalTemperature = finalTemperature;
  }
  
  public void setDownhillProbabilityCutoff(
      double downhillProbabilityCutoff) {
    this.downhillProbabilityCutoff = downhillProbabilityCutoff;
  }
  
  public void setNumberOfLoops(int numberOfLoops) {
    this.numberOfLoops = numberOfLoops;
  }
  
  public List<Cluster> cluster(DocumentCollection collection) {
    // Get initial set of clusters... 
    int numDocs = collection.size();
    int numClusters = (int) Math.floor(Math.sqrt(numDocs));
    List<Cluster> clusters = new ArrayList<Cluster>();
    for (int i = 0; i < numClusters; i++) {
      clusters.add(new Cluster("C" + i));
    }
    // ...and set initial temperature parameter T.
    double temperature = initialTemperature;
    // Randomly assign documents to the k clusters.
    if (randomizeDocs) {
      collection.shuffle();
    }
    for (int i = 0; i < numDocs; i++) {
      int targetCluster = i % numClusters;
      clusters.get(targetCluster).addDocument(
        collection.getDocumentNameAt(i),
        collection.getDocument(collection.getDocumentNameAt(i)));
    }
    log.debug("..Initial clusters: " + clusters.toString());
    // Repeat until temperature is reduced to the minimum.
    while (temperature > finalTemperature) {
      double previousAverageRadius = 0.0D;
      List<Cluster> prevClusters = new ArrayList<Cluster>();
      // Run loop NUM_LOOP times.
      for (int loop = 0; loop < numberOfLoops; loop++) {
        // Find a new set of clusters by altering the membership of some
        // documents. Start by picking two clusters at random
        List<Integer> randomClusterIds = getRandomClusterIds(clusters);
        // pick two documents out of the clusters at random
        List<String> randomDocumentNames = 
          getRandomDocumentNames(collection, randomClusterIds, clusters);
        // exchange the two random documents among the random clusters.
        clusters.get(randomClusterIds.get(0)).removeDocument(
          randomDocumentNames.get(0));
        clusters.get(randomClusterIds.get(0)).addDocument(
          randomDocumentNames.get(1), 
          collection.getDocument(randomDocumentNames.get(1)));
        clusters.get(randomClusterIds.get(1)).removeDocument(
          randomDocumentNames.get(1));
        clusters.get(randomClusterIds.get(1)).addDocument(
          randomDocumentNames.get(0), 
          collection.getDocument(randomDocumentNames.get(0)));
        // Compare the difference between the values of the new and old
        // set of clusters. If there is an improvement, accept the new 
        // set of clusters, otherwise accept the new set of clusters with
        // probability p.
        log.debug("..Intermediate clusters: " + clusters.toString());
        double averageRadius = getAverageRadius(clusters);
        if (averageRadius > previousAverageRadius) {
          // possible downhill move, calculate the probability of it being 
          // accepted
          double probability = 
            Math.exp((previousAverageRadius - averageRadius)/temperature);
          if (probability < downhillProbabilityCutoff) {
            // go back to the cluster before the changes
            clusters.clear();
            clusters.addAll(prevClusters);
            continue;
          }
        }
        prevClusters.clear();
        prevClusters.addAll(clusters);
        previousAverageRadius = averageRadius;
      }
      // Reduce the temperature based on the cooling schedule.
      temperature = temperature / 10;
    }
    // Return the final set of clusters.
    return clusters;
  }

  private List<Integer> getRandomClusterIds(
      List<Cluster> clusters) {
    IdGenerator clusterIdGenerator = new IdGenerator(clusters.size());
    List<Integer> randomClusterIds = new ArrayList<Integer>();
    for (int i = 0; i < 2; i++) {
      randomClusterIds.add(clusterIdGenerator.getNextId());
    }
    return randomClusterIds;
  }

  private List<String> getRandomDocumentNames(
      DocumentCollection collection, 
      List<Integer> randomClusterIds, 
      List<Cluster> clusters) {
    List<String> randomDocumentNames = new ArrayList<String>();
    for (Integer randomClusterId : randomClusterIds) {
      Cluster randomCluster = clusters.get(randomClusterId);
      IdGenerator documentIdGenerator = 
        new IdGenerator(randomCluster.size());
      randomDocumentNames.add(
        randomCluster.getDocumentName(documentIdGenerator.getNextId()));
    }
    return randomDocumentNames;
  }

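  private double getAverageRadius(List<Cluster> clusters) {
    double score = 0.0D;
    for (Cluster cluster : clusters) {
      // recompute the centroid first: getRadius() measures distances from
      // the cached centroid and returns 0 if it was never computed
      cluster.getCentroid();
      score += cluster.getRadius();
    }
    return (score / clusters.size());
  }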
}

Results from Simulated Annealing vary across runs, which is expected, since this is essentially a Monte Carlo simulation. Some results from multiple runs, with the initial and final temperatures set to 100 and 1, and the downhill probability threshold set to 0.7, are shown below. One way to arrive at a good set of final results may be to aggregate the results from multiple runs into a single one.

==== Results from Simulated Annealing Clustering ====
C0:[D1, D3, D2, D4]
C1:[D5, D6, D7]

==== Results from Simulated Annealing Clustering ====
C0:[D3, D4, D2, D5]
C1:[D7, D1, D6]

==== Results from Simulated Annealing Clustering ====
C0:[D1, D7, D5, D6]
C1:[D4, D2, D3]
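
One possible way to aggregate several runs (a suggestion only, not something implemented in this post) is to count how often each pair of documents ends up in the same cluster. The sketch below does this for the three runs listed above; the class `CoAssociationSketch` and its methods are hypothetical names.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

class CoAssociationSketch {

  // The three runs listed above, each as a list of clusters.
  static List<List<List<String>>> runsFromPost() {
    return Arrays.asList(
      Arrays.asList(Arrays.asList("D1", "D3", "D2", "D4"),
                    Arrays.asList("D5", "D6", "D7")),
      Arrays.asList(Arrays.asList("D3", "D4", "D2", "D5"),
                    Arrays.asList("D7", "D1", "D6")),
      Arrays.asList(Arrays.asList("D1", "D7", "D5", "D6"),
                    Arrays.asList("D4", "D2", "D3")));
  }

  // Counts, for every pair of documents, the number of runs in which both
  // were placed in the same cluster. Keys look like "D2:D3".
  static Map<String,Integer> countPairs(List<List<List<String>>> runs) {
    Map<String,Integer> counts = new TreeMap<String,Integer>();
    for (List<List<String>> run : runs) {
      for (List<String> cluster : run) {
        for (int i = 0; i < cluster.size(); i++) {
          for (int j = i + 1; j < cluster.size(); j++) {
            String a = cluster.get(i);
            String b = cluster.get(j);
            // order the two names so each pair has a single key
            String key = a.compareTo(b) < 0 ? a + ":" + b : b + ":" + a;
            Integer old = counts.get(key);
            counts.put(key, old == null ? 1 : old + 1);
          }
        }
      }
    }
    return counts;
  }

  public static void main(String[] args) {
    List<List<List<String>>> runs = runsFromPost();
    for (Map.Entry<String,Integer> e : countPairs(runs).entrySet()) {
      if (e.getValue() == runs.size()) {
        System.out.println(e.getKey() + " together in every run");
      }
    }
  }
}
```

For these three runs, the pairs D2:D3, D2:D4, D3:D4 and D6:D7 are together every time, which hints at a stable core of {D2, D3, D4} and {D6, D7}. The remaining documents could then be assigned by majority vote or left for manual inspection.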

Nearest Neighbor Algorithm

This algorithm is classified as a Genetic algorithm in the TMAP book, but a subsequent section in the book describes a genetic clustering algorithm that involves mutations and crossovers. I guess the latter type is commonly associated with genetic algorithms in general. However, the Nearest Neighbor algorithm is popular for clustering genes as well, so I guess calling it a genetic algorithm is not incorrect.

The algorithm first sorts the documents according to the sum of similarities with their 2r neighbors. It then loops through the documents in descending order of the sum of similarities. If a document is already assigned to a cluster, it is skipped; otherwise a new cluster is created with the document as its seed. Then all neighboring documents to the right and left of the current document that are not already assigned and have a similarity greater than a specified threshold are added to the cluster.
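
The pre-sorting step can be sketched in isolation as follows. This is an assumption-laden illustration: the fitness values are made up, and the anonymous comparator simply stands in for whatever value-based comparator is used to order the document names.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class FitnessSortSketch {

  // Returns the document names ordered by descending fitness, where the
  // fitness of a document is the sum of similarities with its neighbors.
  static List<String> sortByFitnessDescending(
      final Map<String,Double> fitnesses) {
    List<String> names = new ArrayList<String>(fitnesses.keySet());
    Collections.sort(names, new Comparator<String>() {
      public int compare(String a, String b) {
        // arguments swapped to get descending order
        return Double.compare(fitnesses.get(b), fitnesses.get(a));
      }
    });
    return names;
  }

  public static void main(String[] args) {
    Map<String,Double> fitnesses = new HashMap<String,Double>();
    fitnesses.put("D1", 1.2D); // hypothetical sums of similarities
    fitnesses.put("D2", 2.5D);
    fitnesses.put("D3", 0.7D);
    // prints [D2, D1, D3]
    System.out.println(sortByFitnessDescending(fitnesses));
  }
}
```

Documents with the highest sum of similarities come first, so they get the first chance to seed a cluster.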

My code for the Nearest Neighbor algorithm is shown below. The algorithm comes from the TMAP book[1].

// Source: src/main/java/com/mycompany/myapp/clustering/NearestNeighborClusterer.java
package com.mycompany.myapp.clustering;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

public class NearestNeighborClusterer {

  private final Log log = LogFactory.getLog(getClass());
  
  private int numNeighbors;
  private double similarityThreshold;
  
  public void setNumNeighbors(int numNeighbors) {
    this.numNeighbors = numNeighbors;
  }

  public void setSimilarityThreshold(double similarityThreshold) {
    this.similarityThreshold = similarityThreshold;
  }

  public List<Cluster> cluster(DocumentCollection collection) {
    // get neighbors for every document
    Map<String,Double> similarityMap = collection.getSimilarityMap();
    Map<String,List<String>> neighborMap = 
      new HashMap<String,List<String>>();
    for (String documentName : collection.getDocumentNames()) {
      neighborMap.put(documentName, 
        collection.getNeighbors(documentName, similarityMap, numNeighbors));
    }
    // compute sum of similarities of every document with its numNeighbors
    Map<String,Double> fitnesses = 
      getFitnesses(collection, similarityMap, neighborMap);
    List<String> sortedDocNames = new ArrayList<String>();
    // sort by sum of similarities descending
    sortedDocNames.addAll(collection.getDocumentNames());
    Collections.sort(sortedDocNames, Collections.reverseOrder(
      new ByValueComparator<String,Double>(fitnesses)));
    List<Cluster> clusters = new ArrayList<Cluster>();
    int clusterId = 0;
    // Loop through the list of documents in descending order of the sum 
    // of the similarities.
    Map<String,String> documentClusterMap = 
      new HashMap<String,String>();
    for (String docName : sortedDocNames) {
      // skip if document already assigned to cluster
      if (documentClusterMap.containsKey(docName)) {
        continue;
      }
      // create cluster with current document
      Cluster cluster = new Cluster("C" + clusterId);
      cluster.addDocument(docName, collection.getDocument(docName));
      documentClusterMap.put(docName, cluster.getId());
      // find all neighboring documents to the left and right of the current
      // document that are not assigned to a cluster, and have a similarity
      // greater than our threshold. Add these documents to the new cluster
      List<String> neighbors = neighborMap.get(docName);
      for (String neighbor : neighbors) {
        if (documentClusterMap.containsKey(neighbor)) {
          continue;
        }
        double similarity = similarityMap.get(
          StringUtils.join(new String[] {docName, neighbor}, ":"));
        if (similarity < similarityThreshold) {
          continue;
        }
        cluster.addDocument(neighbor, collection.getDocument(neighbor));
        documentClusterMap.put(neighbor, cluster.getId());
      }
      clusters.add(cluster);
      clusterId++;
    }
    return clusters;
  }

  private Map<String,Double> getFitnesses(
      DocumentCollection collection, 
      Map<String,Double> similarityMap,
      Map<String,List<String>> neighbors) {
    Map<String,Double> fitnesses = new HashMap<String,Double>();
    for (String docName : collection.getDocumentNames()) {
      double fitness = 0.0D;
      for (String neighborDoc : neighbors.get(docName)) {
        String key = StringUtils.join(
          new String[] {docName, neighborDoc}, ":");
        fitness += similarityMap.get(key);
      }
      fitnesses.put(docName, fitness);
    }
    return fitnesses;
  }
}

Although there is extra pre-sorting work that needs to be done, the algorithm is relatively simple to understand. With a similarity threshold set to 0.25, I get the following results:

=== Clusters from Nearest Neighbor Algorithm === (sim threshold = 0.25)
C0:[D4, D3]
C1:[D7, D6]
C2:[D2]
C3:[D1]
C4:[D5]

Genetic Algorithm

The code for this algorithm was written using an algorithm described in the paper by Maulik and Bandopadhayaya[5]. In the language of genetics, a document is a gene and a cluster is a chromosome. Clusters get fitter by reproducing and passing on their best traits, in a process similar to Darwinian evolution.

The algorithm starts by estimating the number of clusters and partitioning the document collection by mod value into the clusters. It then computes the fitness across all clusters in the current generation. After that, it executes a few (configurable) crossover operations followed by a mutation operation. Crossover involves selecting two cut points and exchanging the documents for the portion between the cut points, thus creating a new cluster. Mutation selects two random documents from two random clusters and exchanges them. At the end of each generation, the fitness of the clusters is recomputed. The algorithm terminates when the fitness does not increase any more across generations.
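
As a small illustration of the crossover operation described above, the sketch below swaps the portion between two cut points of two clusters, represented here as plain lists of document names. The class name and the list representation are assumptions for this example only.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class TwoPointCrossoverSketch {

  // Swaps the elements at positions [from, to) between the two lists,
  // in place. Both cut points must fit inside the shorter list.
  static void crossover(List<String> a, List<String> b, int from, int to) {
    int minSize = Math.min(a.size(), b.size());
    if (from < 0 || from > to || to > minSize) {
      throw new IllegalArgumentException("bad cut points");
    }
    for (int i = from; i < to; i++) {
      String tmp = a.get(i);
      a.set(i, b.get(i));
      b.set(i, tmp);
    }
  }

  public static void main(String[] args) {
    List<String> c0 =
      new ArrayList<String>(Arrays.asList("D1", "D3", "D5", "D7"));
    List<String> c1 =
      new ArrayList<String>(Arrays.asList("D2", "D4", "D6"));
    crossover(c0, c1, 1, 3);
    // prints [D1, D4, D6, D7] [D2, D3, D5]
    System.out.println(c0 + " " + c1);
  }
}
```

Using `set` rather than remove-then-add keeps the positions stable while the swap is in progress, which avoids indexes shifting under the loop.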

My code for the genetic clustering algorithm is shown below:

// Source: src/main/java/com/mycompany/myapp/clustering/GeneticClusterer.java
package com.mycompany.myapp.clustering;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

public class GeneticClusterer {

  private final Log log = LogFactory.getLog(getClass());
  
  private boolean randomizeData;
  private int numCrossoversPerMutation;
  private int maxGenerations;
  
  public void setRandomizeData(boolean randomizeData) {
    this.randomizeData = randomizeData;
  }
  
  public void setNumberOfCrossoversPerMutation(int ncpm) {
    this.numCrossoversPerMutation = ncpm;
  }

  public void setMaxGenerations(int maxGenerations) {
    this.maxGenerations = maxGenerations;
  }
  
  public List<Cluster> cluster(DocumentCollection collection) {
    // get initial clusters
    int k = (int) Math.floor(Math.sqrt(collection.size()));
    List<Cluster> clusters = new ArrayList<Cluster>();
    for (int i = 0; i < k; i++) {
      Cluster cluster = new Cluster("C" + i);
      clusters.add(cluster);
    }
    if (randomizeData) {
      collection.shuffle();
    }
    // load it up using mod partitioning, this is P(0)
    int docId = 0;
    for (String documentName : collection.getDocumentNames()) {
      int clusterId = docId % k;
      clusters.get(clusterId).addDocument(
        documentName, collection.getDocument(documentName));
      docId++;
    }
    log.debug("Initial clusters = " + clusters.toString());
    // holds previous cluster in the compute loop
    List<Cluster> prevClusters = new ArrayList<Cluster>();
    double prevFitness = 0.0D;
    int generations = 0;
    for (;;) {
      // compute fitness for P(t)
      double fitness = computeFitness(clusters);
      // if termination condition achieved, break and return clusters
      if (prevFitness > fitness) {
        clusters.clear();
        clusters.addAll(prevClusters);
        break;
      }
      // even if termination condition not met, terminate after the
      // maximum number of generations
      if (generations > maxGenerations) {
        break;
      }
      // do specified number of crossover operations for this generation
      for (int i = 0; i < numCrossoversPerMutation; i++) {
        crossover(clusters, collection, i);
        generations++;
      }
      // followed by a single mutation per generation
      mutate(clusters, collection);
      generations++;
      log.debug("..Intermediate clusters (" + generations + "): " +
        clusters.toString());
      // hold on to previous solution
      prevClusters.clear();
      prevClusters.addAll(clusters);
      prevFitness = computeFitness(prevClusters);
    }
    return clusters;
  }
  
  /**
   * Come up with something arbitrary. Just compute the sum of the radii of
   * the clusters.
   * @param clusters the clusters in the current generation.
   * @return the sum of the radii of the clusters.
   */
  private double computeFitness(List<Cluster> clusters) {
    double radius = 0.0D;
    for (Cluster cluster : clusters) {
      cluster.getCentroid();
      radius += cluster.getRadius();
    }
    return radius;
  }
  
  /**
   * Selects two random clusters from the list, then selects two cut-points
   * based on the minimum cluster size of the two clusters. Exchanges the
   * documents between the cut points.
   * @param clusters the clusters to operate on.
   * @param sequence the sequence number of the cross over operation.
   */
  public void crossover(List<Cluster> clusters, 
      DocumentCollection collection, int sequence) {
    IdGenerator clusterIdGenerator = new IdGenerator(clusters.size());
    int[] clusterIds = new int[2];
    clusterIds[0] = clusterIdGenerator.getNextId();
    clusterIds[1] = clusterIdGenerator.getNextId();
    int minSize = Math.min(
      clusters.get(clusterIds[0]).size(), 
      clusters.get(clusterIds[1]).size());
    IdGenerator docIdGenerator = new IdGenerator(minSize);
    int[] cutPoints = new int[2];
    cutPoints[0] = docIdGenerator.getNextId();
    cutPoints[1] = docIdGenerator.getNextId();
    Arrays.sort(cutPoints);
    Cluster cluster1 = clusters.get(clusterIds[0]);
    Cluster cluster2 = clusters.get(clusterIds[1]);
    for (int i = 0; i < cutPoints[0]; i++) {
      String docName1 = cluster1.getDocumentName(i);
      String docName2 = cluster2.getDocumentName(i);
      cluster1.removeDocument(docName1);
      cluster2.addDocument(docName1, collection.getDocument(docName1));
      cluster2.removeDocument(docName2);
      cluster1.addDocument(docName2, collection.getDocument(docName2));
    }
    // leave the documents between the cut points alone
    for (int i = cutPoints[1]; i < minSize; i++) {
      String docName1 = cluster1.getDocumentName(i);
      String docName2 = cluster2.getDocumentName(i);
      cluster1.removeDocument(docName1);
      cluster2.addDocument(docName1, collection.getDocument(docName1));
      cluster2.removeDocument(docName2);
      cluster1.addDocument(docName2, collection.getDocument(docName2));
    }
    // rebuild the Cluster list, replacing the changed clusters.
    List<Cluster> crossoverClusters = new ArrayList<Cluster>();
    int clusterId = 0;
    for (Cluster cluster : clusters) {
      if (clusterId == clusterIds[0]) {
        crossoverClusters.add(cluster1);
      } else if (clusterId == clusterIds[1]) {
        crossoverClusters.add(cluster2);
      } else {
        crossoverClusters.add(cluster);
      }
      clusterId++;
    }
    clusters.clear();
    clusters.addAll(crossoverClusters);
  }
  
  /**
   * Exchanges a random document between two random clusters in the list.
   * @param clusters the clusters to operate on.
   */
  private void mutate(List<Cluster> clusters, 
      DocumentCollection collection) {
    // choose two random clusters
    IdGenerator clusterIdGenerator = new IdGenerator(clusters.size());
    int[] clusterIds = new int[2];
    clusterIds[0] = clusterIdGenerator.getNextId();
    clusterIds[1] = clusterIdGenerator.getNextId();
    Cluster cluster1 = clusters.get(clusterIds[0]);
    Cluster cluster2 = clusters.get(clusterIds[1]);
    // choose two random documents in the clusters
    int minSize = Math.min(
      clusters.get(clusterIds[0]).size(), 
      clusters.get(clusterIds[1]).size());
    IdGenerator docIdGenerator = new IdGenerator(minSize);
    String docName1 = cluster1.getDocumentName(docIdGenerator.getNextId());
    String docName2 = cluster2.getDocumentName(docIdGenerator.getNextId());
    // exchange the documents
    cluster1.removeDocument(docName1);
    cluster1.addDocument(docName2, collection.getDocument(docName2));
    cluster2.removeDocument(docName2);
    cluster2.addDocument(docName1, collection.getDocument(docName1));
    // rebuild the cluster list, replacing changed clusters
    List<Cluster> mutatedClusters = new ArrayList<Cluster>();
    int clusterId = 0;
    for (Cluster cluster : clusters) {
      if (clusterId == clusterIds[0]) {
        mutatedClusters.add(cluster1);
      } else if (clusterId == clusterIds[1]) {
        mutatedClusters.add(cluster2);
      } else {
        mutatedClusters.add(cluster);
      }
      clusterId++;
    }
    clusters.clear();
    clusters.addAll(mutatedClusters);
  }
}

To measure the fitness of a generation (i.e. the list of clusters), I decided on the sum of the radii of the clusters. I guess I could have used a fancier function such as the sum of similarities in the Nearest Neighbor algorithm. In any case, I terminated the algorithm after 500 generations, and the result it came up with is shown below:

=== Clusters from Genetic Algorithm ===
C0:[D2, D1, D3, D4]
C1:[D5, D7, D6]

Supporting classes

In order to make the code for the clusterers clean and readable, a lot of code is factored out into supporting classes. They are shown below:

Cluster.java

This class models a cluster as a list of named document objects. It provides various methods to compute properties of a cluster given its members, such as the centroid, and the Euclidean distance or cosine similarity of a document from the cluster centroid. The code is shown below:

// Source: src/main/java/com/mycompany/myapp/clustering/Cluster.java
package com.mycompany.myapp.clustering;

import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.commons.math.stat.descriptive.rank.Max;

import Jama.Matrix;

public class Cluster {
  
  private final Log log = LogFactory.getLog(getClass());
  
  private String id;
  private Map<String,Matrix> docs = 
    new LinkedHashMap<String,Matrix>();
  private List<String> docNames = new LinkedList<String>();
  
  private Matrix centroid = null;
  
  public Cluster(String id) {
    super();
    this.id = id;
  }
  
  public String getId() {
    return id;
  }
  
  public Set<String> getDocumentNames() {
    return docs.keySet();
  }

  public String getDocumentName(int pos) {
    return docNames.get(pos);
  }
  
  public Matrix getDocument(String documentName) {
    return docs.get(documentName);
  }

  public Matrix getDocument(int pos) {
    return docs.get(docNames.get(pos));
  }
  
  public void addDocument(String docName, Matrix docMatrix) {
    docs.put(docName, docMatrix);
    docNames.add(docName);
    log.debug("...." + id + " += " + docName);
  }

  public void removeDocument(String docName) {
    docs.remove(docName);
    docNames.remove(docName);
    log.debug("...." + id + " -= " + docName);
  }

  public int size() {
    return docs.size();
  }
  
  public boolean contains(String docName) {
    return docs.containsKey(docName);
  }
  
  /**
   * Returns a document (term vector) consisting of the average of the 
   * coordinates of the documents in the cluster. Returns a null Matrix
   * if there are no documents in the cluster. 
   * @return the centroid of the cluster, or null if no documents have 
   * been added to the cluster.
   */
  public Matrix getCentroid() {
    if (docs.size() == 0) {
      return null;
    }
    Matrix d = docs.get(docNames.get(0));
    centroid = new Matrix(d.getRowDimension(), d.getColumnDimension()); 
    for (String docName : docs.keySet()) {
      Matrix docMatrix = docs.get(docName);
      centroid = centroid.plus(docMatrix);
    }
    centroid = centroid.times(1.0D / docs.size());
    return centroid;
  }

  /**
   * Returns the radius of the cluster. The radius is the average Euclidean
   * distance between each document's term vector and the centroid. The
   * centroid must already have been computed with getCentroid().
   * @return the radius of the cluster.
   */
  public double getRadius() {
    double radius = 0.0D;
    if (centroid != null) {
      for (String docName : docNames) {
        Matrix doc = getDocument(docName);
        radius += doc.minus(centroid).normF();
      }
    }
    return radius / docNames.size();
  }
  
  /**
   * Returns the Euclidean distance between the centroid of this cluster
   * and the new document.
   * @param doc the document to be measured for distance.
   * @return the Euclidean distance between the cluster centroid and the 
   * document.
   */
  public double getEucledianDistance(Matrix doc) {
    if (centroid != null) {
      return (doc.minus(centroid)).normF();
    }
    return 0.0D;
  }
  
  /**
   * Returns the maximum distance from the specified document to any of
   * the documents in the cluster.
   * @param doc the document to be measured for distance.
   * @return the complete linkage distance from the cluster.
   */
  public double getCompleteLinkageDistance(Matrix doc) {
    Max max = new Max();
    if (docs.size() ==0) {
      return 0.0D;
    }
    double[] distances = new double[docs.size()];
    for (int i = 0; i < distances.length; i++) {
      Matrix clusterDoc = docs.get(docNames.get(i));
      distances[i] = clusterDoc.minus(doc).normF();
    }
    return max.evaluate(distances);
  }
  
  /**
   * Returns the cosine similarity between the centroid of this cluster
   * and the new document.
   * @param doc the document to be measured for similarity.
   * @return the similarity of the centroid of the cluster to the document.
   */
  public double getSimilarity(Matrix doc) {
    if (centroid != null) {
      double dotProduct = centroid.arrayTimes(doc).norm1();
      double normProduct = centroid.normF() * doc.normF();
      return dotProduct / normProduct;
    }
    return 0.0D;
  }

  @Override
  public boolean equals(Object obj) {
    if (!(obj instanceof Cluster)) {
      return false;
    }
    Cluster that = (Cluster) obj;
    String[] thisDocNames = this.getDocumentNames().toArray(new String[0]);
    String[] thatDocNames = that.getDocumentNames().toArray(new String[0]);
    if (thisDocNames.length != thatDocNames.length) {
      return false;
    }
    Arrays.sort(thisDocNames);
    Arrays.sort(thatDocNames);
    return ArrayUtils.isEquals(thisDocNames, thatDocNames);
  }
  
  @Override
  public int hashCode() {
    String[] docNames = getDocumentNames().toArray(new String[0]);
    Arrays.sort(docNames);
    return StringUtils.join(docNames, ",").hashCode();
  }
  
  @Override
  public String toString() {
    return id + ":" + docs.keySet().toString();
  }
}
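
The vector arithmetic behind the centroid, distance and similarity methods can be illustrated with plain arrays instead of Jama matrices. This is only a sketch under that simplification; the class `ClusterMathSketch` and the two-term documents are made up for the example.

```java
class ClusterMathSketch {

  // Component-wise mean of the document vectors.
  static double[] centroid(double[][] docs) {
    double[] c = new double[docs[0].length];
    for (double[] doc : docs) {
      for (int i = 0; i < c.length; i++) {
        c[i] += doc[i] / docs.length;
      }
    }
    return c;
  }

  // Square root of the sum of squared coordinate differences.
  static double euclideanDistance(double[] a, double[] b) {
    double sum = 0.0D;
    for (int i = 0; i < a.length; i++) {
      sum += (a[i] - b[i]) * (a[i] - b[i]);
    }
    return Math.sqrt(sum);
  }

  // Dot product divided by the product of the vector lengths.
  static double cosineSimilarity(double[] a, double[] b) {
    double dot = 0.0D;
    double normA = 0.0D;
    double normB = 0.0D;
    for (int i = 0; i < a.length; i++) {
      dot += a[i] * b[i];
      normA += a[i] * a[i];
      normB += b[i] * b[i];
    }
    return dot / (Math.sqrt(normA) * Math.sqrt(normB));
  }

  public static void main(String[] args) {
    // two hypothetical documents over a two-term vocabulary
    double[][] docs = { {1.0D, 0.0D}, {0.0D, 1.0D} };
    double[] c = centroid(docs);
    System.out.println("centroid = [" + c[0] + ", " + c[1] + "]");
    System.out.println("distance = " + euclideanDistance(docs[0], c));
    System.out.println("similarity = " + cosineSimilarity(docs[0], c));
  }
}
```

Here the centroid is [0.5, 0.5], and both the distance and the cosine similarity of the first document to the centroid come out at about 0.707. The radius of a cluster is then the average of these distances over all of its documents.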

DocumentCollection.java

The DocumentCollection represents the collection of documents previously represented by the term-document matrix. It provides convenience accessor methods, and other methods to compute similarity of documents to its collection and get neighboring documents by similarity. These were used by the nearest neighbor algorithm.

// Source: src/main/java/com/mycompany/myapp/clustering/DocumentCollection.java
package com.mycompany.myapp.clustering;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.commons.lang.StringUtils;

import Jama.Matrix;

import com.mycompany.myapp.similarity.CosineSimilarity;

public class DocumentCollection {

  private Matrix tdMatrix;
  private Map<String,Matrix> documentMap;
  private List<String> documentNames;
  
  public DocumentCollection(Matrix tdMatrix, String[] docNames) {
    int position = 0;
    this.tdMatrix = tdMatrix;
    this.documentMap = new HashMap<String,Matrix>();
    this.documentNames = new ArrayList<String>();
    for (String documentName : docNames) {
      documentMap.put(documentName, tdMatrix.getMatrix(
        0, tdMatrix.getRowDimension() - 1, position, position));
      documentNames.add(documentName);
      position++;
    }
  }

  public int size() {
    return documentMap.keySet().size();
  }
  
  public List<String> getDocumentNames() {
    return documentNames;
  }
  
  public String getDocumentNameAt(int position) {
    return documentNames.get(position);
  }
  
  public Matrix getDocumentAt(int position) {
    return documentMap.get(documentNames.get(position));
  }
  
  public Matrix getDocument(String documentName) {
    return documentMap.get(documentName);
  }
  
  public void shuffle() {
    Collections.shuffle(documentNames);
  }
  
  public Map<String,Double> getSimilarityMap() {
    Map<String,Double> similarityMap = 
      new HashMap<String,Double>();
    CosineSimilarity similarity = new CosineSimilarity();
    Matrix similarityMatrix = similarity.transform(tdMatrix);
    for (int i = 0; i < similarityMatrix.getRowDimension(); i++) {
      for (int j = 0; j < similarityMatrix.getColumnDimension(); j++) {
        String sourceDoc = getDocumentNameAt(i);
        String targetDoc = getDocumentNameAt(j);
        similarityMap.put(StringUtils.join(
          new String[] {sourceDoc, targetDoc}, ":"),
          similarityMatrix.get(i, j));
      }
    }
    return similarityMap;
  }
  
  public List<String> getNeighbors(String docName,
      Map<String,Double> similarityMap, int numNeighbors) {
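    // the returned list holds docName itself plus numNeighbors neighbors,
    // so numNeighbors + 1 must not exceed the collection size
    if (numNeighbors >= size()) {
      throw new IllegalArgumentException(
        "numNeighbors too large, max: " + (size() - 1));
    }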
    final Map<String,Double> differenceMap = 
      new HashMap<String,Double>();
    List<String> neighbors = new ArrayList<String>();
    neighbors.addAll(documentNames);
    for (String documentName : documentNames) {
      String key = StringUtils.join(
        new String[] {docName, documentName}, ":");
      double difference = Math.abs(similarityMap.get(key) - 1.0D);
      differenceMap.put(documentName, difference);
    }
    Collections.sort(neighbors, 
      new ByValueComparator<String,Double>(differenceMap));
    return neighbors.subList(0, numNeighbors + 1);
  }
}
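
The CosineSimilarity class used by getSimilarityMap() is not shown in this post. As a rough illustration of what the pairwise step amounts to, here is a minimal sketch that works on plain double[] document vectors instead of Jama matrices. The class and method names (SimilarityMapSketch, cosine, similarityMap) are hypothetical and only stand in for the real classes; the "source:target" key format is the one used above.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical stand-in for the similarity step; this is NOT the
// CosineSimilarity class referenced by DocumentCollection.
class SimilarityMapSketch {

  // Cosine of the angle between two term vectors; 0.0 if either is all zeros.
  static double cosine(double[] a, double[] b) {
    double dot = 0.0, normA = 0.0, normB = 0.0;
    for (int i = 0; i < a.length; i++) {
      dot += a[i] * b[i];
      normA += a[i] * a[i];
      normB += b[i] * b[i];
    }
    if (normA == 0.0 || normB == 0.0) {
      return 0.0;
    }
    return dot / (Math.sqrt(normA) * Math.sqrt(normB));
  }

  // Builds a map keyed by "source:target", like getSimilarityMap() does.
  static Map<String,Double> similarityMap(String[] names, double[][] docs) {
    Map<String,Double> result = new LinkedHashMap<String,Double>();
    for (int i = 0; i < names.length; i++) {
      for (int j = 0; j < names.length; j++) {
        result.put(names[i] + ":" + names[j], cosine(docs[i], docs[j]));
      }
    }
    return result;
  }

  public static void main(String[] args) {
    String[] names = {"D1", "D2", "D3"};
    // made-up term vectors, one row per document
    double[][] docs = {{1, 0, 1}, {1, 0, 1}, {0, 1, 0}};
    System.out.println(similarityMap(names, docs));
  }
}
```

Identical vectors get a similarity close to 1 and vectors with no terms in common get 0, which is why getNeighbors() can rank documents by how far their similarity is from 1.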

IdGenerator.java

IdGenerator is a "safe" random number generator that returns unique numbers until its range is exhausted. It is constructed with an upper bound, so it returns each number from 0 to (upper bound - 1) exactly once, in random order, and then starts a fresh cycle in which the numbers repeat.

// Source: src/main/java/com/mycompany/myapp/clustering/IdGenerator.java
package com.mycompany.myapp.clustering;

import java.util.HashSet;
import java.util.Random;
import java.util.Set;

public class IdGenerator {

  private int upperBound;
  
  private Random randomizer;
  private Set<Integer> ids = new HashSet<Integer>();
  
  public IdGenerator(int upperBound) {
    this.upperBound = upperBound;
    randomizer = new Random();
  }
  
  public int getNextId() {
    if (ids.size() == upperBound) {
      ids.clear();
    }
    for (;;) {
      int id = randomizer.nextInt(upperBound);
      if (ids.contains(id)) {
        continue;
      }
      ids.add(id);
      return id;
    }
  }
}
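
The retry loop above needs more and more attempts as the pool of unused ids shrinks. If that ever becomes a concern, one possible alternative is to shuffle the ids up front and hand them out in order. The following is only a sketch of that idea with the same contract (every id in the range once per cycle); ShuffledIdGenerator is a hypothetical name and is not part of the code described in this post.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical alternative to IdGenerator: every id in [0, upperBound)
// is returned once before any id repeats.
class ShuffledIdGenerator {

  private final List<Integer> pool = new ArrayList<Integer>();
  private int next = 0;

  ShuffledIdGenerator(int upperBound) {
    if (upperBound <= 0) {
      throw new IllegalArgumentException("upperBound must be positive");
    }
    for (int i = 0; i < upperBound; i++) {
      pool.add(i);
    }
    Collections.shuffle(pool);
  }

  int getNextId() {
    if (next == pool.size()) {
      // all ids handed out: reshuffle and start a new cycle
      Collections.shuffle(pool);
      next = 0;
    }
    return pool.get(next++);
  }

  public static void main(String[] args) {
    ShuffledIdGenerator generator = new ShuffledIdGenerator(5);
    for (int i = 0; i < 10; i++) {
      System.out.print(generator.getNextId() + " ");
    }
    System.out.println();
  }
}
```

The trade-off is memory: the shuffled list holds one entry per possible id, whereas the original only stores the ids it has already returned.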

ByValueComparator.java

The ByValueComparator is a generic comparator that allows you to sort a List based on a supporting map. The idea for this came from Jeffrey Bigham's blog post Sorting Java Map by Value, although I have used Java Generics to allow it to sort any kind of Map.

// Source: src/main/java/com/mycompany/myapp/clustering/ByValueComparator.java
package com.mycompany.myapp.clustering;

import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;

public class 
    ByValueComparator<K,V extends Comparable<? super V>> 
    implements Comparator<K> {

  private Map<K,V> map = new HashMap<K,V>();
  
  public ByValueComparator(Map<K,V> map) {
    this.map = map;
  }

  public int compare(K k1, K k2) {
    return map.get(k1).compareTo(map.get(k2));
  }
}
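
To show the kind of ordering this comparator produces, here is a small usage-style sketch. It assumes Java 8 or later and uses Comparator.comparing with a method reference, which gives the same ascending-by-value ordering as compare() above; the class name SortByValueSketch and the numbers are made up for illustration.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class SortByValueSketch {

  // Returns a copy of the keys ordered by their value in the supporting
  // map, smallest value first.
  static List<String> sortedByValue(List<String> keys, Map<String,Double> values) {
    List<String> sorted = new ArrayList<String>(keys);
    // same ordering as map.get(k1).compareTo(map.get(k2))
    sorted.sort(Comparator.comparing(values::get));
    return sorted;
  }

  public static void main(String[] args) {
    // made-up "difference from 1.0" values, as in getNeighbors()
    Map<String,Double> differences = new HashMap<String,Double>();
    differences.put("D1", 0.0);
    differences.put("D2", 0.7);
    differences.put("D3", 0.2);
    List<String> names = new ArrayList<String>(differences.keySet());
    System.out.println(sortedByValue(names, differences)); // closest first
  }
}
```

Like ByValueComparator, this throws a NullPointerException if a key in the list has no entry in the map, so the list and the map need to be kept in sync.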

Test case

The test case is a simple JUnit test case that runs through all the clustering code, using my little collection of seven document titles to build the term-document matrix. Here is the code:

// Source: src/test/java/com/mycompany/myapp/clustering/ClusteringTest.java
package com.mycompany.myapp.clustering;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.Reader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import org.apache.commons.lang.StringUtils;
import org.junit.Before;
import org.junit.Test;
import org.springframework.jdbc.datasource.DriverManagerDataSource;

import Jama.Matrix;

import com.mycompany.myapp.indexers.IdfIndexer;
import com.mycompany.myapp.indexers.VectorGenerator;

public class ClusteringTest {

  private Matrix tdMatrix;
  private String[] documentNames;
  
  private DocumentCollection documentCollection;
  
  @Before
  public void setUp() throws Exception {
    VectorGenerator vectorGenerator = new VectorGenerator();
    vectorGenerator.setDataSource(new DriverManagerDataSource(
      "com.mysql.jdbc.Driver", 
      "jdbc:mysql://localhost:3306/tmdb", 
      "irstuff", "irstuff"));
    Map<String,Reader> documents = 
      new LinkedHashMap<String,Reader>();
    BufferedReader reader = new BufferedReader(
      new FileReader("src/test/resources/data/indexing_sample_data.txt"));
    String line = null;
    while ((line = reader.readLine()) != null) {
      String[] docTitleParts = StringUtils.split(line, ";");
      documents.put(docTitleParts[0], new StringReader(docTitleParts[1]));
    }
    vectorGenerator.generateVector(documents);
    IdfIndexer indexer = new IdfIndexer();
    tdMatrix = indexer.transform(vectorGenerator.getMatrix());
    documentNames = vectorGenerator.getDocumentNames();
    documentCollection = new DocumentCollection(tdMatrix, documentNames);
  }

  @Test
  public void testKMeansClustering() throws Exception {
    KMeansClusterer clusterer = new KMeansClusterer();
    clusterer.setInitialClusterAssignments(new String[] {"D1", "D3"});
    List<Cluster> clusters = clusterer.cluster(documentCollection);
    System.out.println("=== Clusters from K-Means algorithm ===");
    for (Cluster cluster : clusters) {
      System.out.println(cluster.toString());
    }
  }

  @Test
  public void testQtClustering() throws Exception {
    QtClusterer clusterer = new QtClusterer();
    clusterer.setMaxRadius(0.40D);
    clusterer.setRandomizeDocuments(true);
    List<Cluster> clusters = clusterer.cluster(documentCollection);
    System.out.println("=== Clusters from QT Algorithm ===");
    for (Cluster cluster : clusters) {
      System.out.println(cluster.toString());
    }
  }

  @Test
  public void testSimulatedAnnealingClustering() throws Exception {
    SimulatedAnnealingClusterer clusterer = 
      new SimulatedAnnealingClusterer();
    clusterer.setRandomizeDocs(false);
    clusterer.setNumberOfLoops(5);
    clusterer.setInitialTemperature(100.0D);
    clusterer.setFinalTemperature(1.0D);
    clusterer.setDownhillProbabilityCutoff(0.7D);
    List<Cluster> clusters = clusterer.cluster(documentCollection);
    System.out.println(
      "=== Clusters from Simulated Annealing Algorithm ===");
    for (Cluster cluster : clusters) {
      System.out.println(cluster.toString());
    }
  }
  
  @Test
  public void testNearestNeighborClustering() throws Exception {
    NearestNeighborClusterer clusterer = new NearestNeighborClusterer();
    clusterer.setNumNeighbors(2);
    clusterer.setSimilarityThreshold(0.25);
    List<Cluster> clusters = clusterer.cluster(documentCollection);
    System.out.println("=== Clusters from Nearest Neighbor Algorithm ===");
    for (Cluster cluster : clusters) {
      System.out.println(cluster.toString());
    }
  }
  
  @Test
  public void testGeneticAlgorithmClustering() throws Exception {
    GeneticClusterer clusterer = new GeneticClusterer();
    clusterer.setNumberOfCrossoversPerMutation(5);
    clusterer.setMaxGenerations(500);
    clusterer.setRandomizeData(false);
    List<Cluster> clusters = clusterer.cluster(documentCollection);
    System.out.println("=== Clusters from Genetic Algorithm ===");
    for (Cluster cluster : clusters) {
      System.out.println(cluster.toString());
    }
  }
}

References

References to books and Internet articles that the above code is based on, in no particular order:

  1. Text Mining Application Programming, by Dr. Manu Konchady.
  2. Wikipedia article on QT (Quality Threshold) clustering.
  3. Wikipedia article on Nearest-Neighbor algorithm.
  4. Research article on Hierarchic document clustering using a genetic algorithm by Robertson, Santimetvirul and Willett, of the University of Sheffield, UK.
  5. Research article on Genetic algorithm-based clustering technique (requires PDF download) by Maulik and Bandopadhyay, of Government Engineering College, Kalyani, India and Indian Statistical Institute, Calcutta, India.

Update 2009-04-26: In recent posts, I have been building on code written and described in previous posts, so there were (and rightly so) quite a few requests for the code. So I've created a project on Sourceforge to host the code. You will find the complete source code built so far in the project's SVN repository.

Saturday, October 11, 2008

IR Math in Java : Cluster Visualization

I've been trying to learn clustering algorithms lately. I was planning to write about them this week, but some last minute refactoring to remove redundancies and make the code more readable resulted in everything going to hell. So I guess I will have to write about them next week.

Almost all clustering algorithms (at least the ones I have seen) seem to be non-deterministic, mainly because they select documents randomly from the collection to build the initial clusters. As a result, they can come up with wildly different clusters depending on how the initial clusters were formed. In my previous (un-refactored) code, for example, the K-Means algorithm converged to the same set of clusters most of the time, but with the changes, they no longer do.

Working through this for some time, I decided I needed to see for myself what the "correct" clusters were. So if I could visualize the documents as points in the n-dimensional term space, clumps of points would correspond to clusters. The problem was that I had only 2 (or maximum 3) dimensions of visualization to work with.

Luckily for me, smarter people than I have faced and solved the same problem, and they have been kind enough to write about it on the web. The solution is to do Dimensionality Reduction, extracting from the term-document matrix the first 2 or 3 most interesting components (or Principal Components) and use them as the values for a 2-dimensional or 3-dimensional scatter chart.

The mathematical background for Principal Component Analysis (PCA) is explained very nicely in this tutorial, which I quote verbatim below.

The mathematical technique used in PCA is called eigen analysis: we solve for the eigenvalues and eigenvectors of a square symmetric matrix with sums of squares and cross products. The eigenvector associated with the largest eigenvalue has the same direction as the first principal component. The eigenvector associated with the second largest eigenvalue determines the direction of the second principal component. The sum of the eigenvalues equals the trace of the square matrix and the maximum number of eigenvectors equals the number of rows (or columns) of this matrix.

It then goes on to explain the algorithm that should be used for reducing and extracting the most interesting dimensions (See Section 6, Algorithms) for a non-square matrix such as our term-document matrix. Essentially, it decomposes the term-document matrix A using Singular Value Decomposition (SVD) into 3 matrices, U, S and V, where the following equation holds.

  A = U * S * V^T

Here S is a square diagonal matrix, where values are in descending order down the diagonal. So for a 2 dimensional reduction, the principal components correspond to the first 2 columns of V, and for a 3 dimensional reduction, the principal components correspond to the first 3 columns of V.
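
Written out, and assuming the term-document matrix A has one row per term and one column per document (which is how DocumentCollection slices it), the reduction described above can be summarized as follows. This is only a restatement of the text, not an additional result.

```latex
A = U \, S \, V^{T}, \qquad A \in \mathbb{R}^{m \times n}
\quad (m \text{ terms},\ n \text{ documents})

\text{document } j \;\mapsto\; (V_{j1},\, V_{j2}) \quad \text{(2D plot)}, \qquad
\text{document } j \;\mapsto\; (V_{j1},\, V_{j2},\, V_{j3}) \quad \text{(3D plot)}
```

In other words, row j of V supplies the coordinates for document j, and the first columns of V are used because they pair with the largest singular values in S.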

The Java code to generate data for drawing the charts is trivial, mainly because we use the Jama matrix library, which does all the heavy lifting for SVD calculations.

// Source: src/main/java/com/mycompany/myapp/clustering/PcaClusterVisualizer.java
package com.mycompany.myapp.clustering;

import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;

import Jama.Matrix;
import Jama.SingularValueDecomposition;

public class PcaClusterVisualizer {

  private final String PLOT_2D_OUTPUT = "plot2d.dat";
  private final String PLOT_3D_OUTPUT = "plot3d.dat";
  
  public void reduce(Matrix tdMatrix, String[] docNames) throws IOException {
    PrintWriter plot2dWriter = 
      new PrintWriter(new FileWriter(PLOT_2D_OUTPUT));
    PrintWriter plot3dWriter = 
      new PrintWriter(new FileWriter(PLOT_3D_OUTPUT));
    SingularValueDecomposition svd = 
      new SingularValueDecomposition(tdMatrix);
    Matrix v = svd.getV();
    // we know that the diagonal of S is ordered, so we can take the
    // first 3 cols from V, for use in plot2d and plot3d
    Matrix vRed = v.getMatrix(0, v.getRowDimension() - 1, 0, 2);
    for (int i = 0; i < v.getRowDimension(); i++) { // rows
      plot2dWriter.printf("%6.4f %6.4f %s%n", 
        Math.abs(vRed.get(i, 0)), Math.abs(vRed.get(i, 1)), docNames[i]);
      plot3dWriter.printf("%6.4f %6.4f %6.4f %s%n", 
        Math.abs(vRed.get(i, 0)), Math.abs(vRed.get(i, 1)), 
        Math.abs(vRed.get(i, 2)), docNames[i]);
    }
    plot2dWriter.flush();
    plot3dWriter.flush();
    plot2dWriter.close();
    plot3dWriter.close();
  }
}

The term-document matrix is generated from my 7 document title collection that I have been using for my experiments, using the following snippet of JUnit code. See one of my earlier posts titled IR Math in Java : TF, IDF and LSI for the actual data and details on the classes being used.

  @Test
  public void testPcaClusterVisualization() throws Exception {
    // for brevity, this block is in a @Before method in the actual
    // code, it has been globbed together here for readability
    VectorGenerator vectorGenerator = new VectorGenerator();
    vectorGenerator.setDataSource(new DriverManagerDataSource(
      "com.mysql.jdbc.Driver", "jdbc:mysql://localhost:3306/tmdb", 
      "irstuff", "irstuff"));
    Map<String,Reader> documents = 
      new LinkedHashMap<String,Reader>();
    BufferedReader reader = new BufferedReader(
      new FileReader("src/test/resources/data/indexing_sample_data.txt"));
    String line = null;
    while ((line = reader.readLine()) != null) {
      String[] docTitleParts = StringUtils.split(line, ";");
      documents.put(docTitleParts[0], new StringReader(docTitleParts[1]));
    }
    vectorGenerator.generateVector(documents);
    IdfIndexer indexer = new IdfIndexer();
    tdMatrix = indexer.transform(vectorGenerator.getMatrix());
    documentNames = vectorGenerator.getDocumentNames();
    documentCollection = new DocumentCollection(tdMatrix, documentNames);
    // this is my actual @Test block
    PcaClusterVisualizer visualizer = new PcaClusterVisualizer();
    visualizer.reduce(tdMatrix, documentNames);
  }

This generates 2 data files which are used as inputs to gnuplot to generate 2D and 3D scatter charts. The data, the charts, and the gnuplot code to generate the charts are shown below:

# plot2d.dat
0.0000 0.2261 D1
0.0468 0.0000 D2
0.0000 0.7363 D3
0.0000 0.6378 D4
0.0000 0.0000 D5
0.8751 0.0000 D6
0.4817 0.0000 D7

# plot3d.dat
0.0000 0.2261 0.0000 D1
0.0468 0.0000 0.2997 D2
0.0000 0.7363 0.0000 D3
0.0000 0.6378 0.0000 D4
0.0000 0.0000 0.0000 D5
0.8751 0.0000 0.4723 D6
0.4817 0.0000 0.8289 D7

# plot2d.gp
set style data labels
unset key
plot 'plot2d.dat' using 1:2:3 \
  with labels font "arial,11" \
  textcolor lt 1

# plot3d.gp
set style data labels
unset key
splot 'plot3d.dat' using 1:2:3:4 \
  with labels font "arial,11" \
  textcolor lt 1

From the charts above, it appears that the following clusters may be valid for our test document set. Notice that although D3 and D7 appear really close (overlapped) on the 3D chart, they don't seem to be close going by the 2D chart or the data. In any case, the results look believable, although not perfect, but that could be due to dimensionality reduction and/or the small data set.

C0: [D1, D2, D5]
    D1  Human machine interface for <b>computer</b> applications
    D2  A survey of user opinion of <b>computer</b> system response time
    D5  The generation of random, binary and ordered trees
C1: [D3, D4]
    D3  The <b>EPS</b> user interface management <b>system</b>
    D4  <b>System</b> and human system engineering testing of <b>EPS</b>
C2: [D7]
    D7  Graph minors: A survey
C3: [D6]
    D6  The intersection graph of paths in trees

I think this post may be helpful to programmers like me who are just getting into IR (most people who are heavily into IR would probably know this stuff already). Text mining algorithms, by their very nature, need to deal with n-dimensional data, and the ability to visualize the data in 2D or 3D can be quite enlightening, so this is a useful tool to have in one's text mining toolbox.

Update 2009-04-26: In recent posts, I have been building on code written and described in previous posts, so there were (and rightly so) quite a few requests for the code. So I've created a project on Sourceforge to host the code. You will find the complete source code built so far in the project's SVN repository.

Saturday, October 04, 2008

Measuring and Graphing Search Quality

A colleague recently started me off on this whole thing. We have been working on improving our indexing algorithms, and his (rhetorical) question was how anybody could assert (as we were hoping to assert) that the changes being made were improving (or going to improve) the quality of search. His point was that if you cannot measure it, you cannot manage it. As usual (at least for him), he had part of the solution worked out already - his ideas form the basis of the user-based scoring for precision calculations described below.

The E-Measure

Looking for some unrelated information on the web, I came across this paper by Jones, Robertson, Santimetvirul and Willett which contains a description of the E-Measure (or effectiveness measure) that can be used to quantify search quality, which looked about perfect for my purposes. The description of the E-Measure from the article is paraphrased below:

                       (1 + β^2) * P * R
  E(P,R) = 100 * (1 -  -----------------)
                         (β^2 * P) + R
  where:
    P = precision
    R = recall
    β = a coefficient indicating the relative importance of 
        recall vs precision. If set to 1.0, precision and 
        recall are equally important. If set to 2.0, recall
        is weighted twice as heavily as precision, etc.
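
As a quick way to experiment with the formula, here is a minimal Java sketch of E(P,R). The class name EMeasureSketch is hypothetical, and returning 100 when both P and R are zero is my own assumption, since the formula is undefined (0/0) at that point.

```java
class EMeasureSketch {

  // E(P,R) = 100 * (1 - ((1 + beta^2) * P * R) / ((beta^2 * P) + R))
  static double e(double p, double r, double beta) {
    double b2 = beta * beta;
    double denominator = (b2 * p) + r;
    if (denominator == 0.0) {
      // assumption: no precision and no recall is treated as the worst case
      return 100.0;
    }
    return 100.0 * (1.0 - ((1.0 + b2) * p * r) / denominator);
  }

  public static void main(String[] args) {
    System.out.printf("E(1.0, 1.0), beta=1: %.2f%n", e(1.0, 1.0, 1.0));
    System.out.printf("E(0.5, 1.0), beta=2: %.2f%n", e(0.5, 1.0, 2.0));
    System.out.printf("E(1.0, 0.5), beta=2: %.2f%n", e(1.0, 0.5, 2.0));
  }
}
```

Comparing the last two calls shows how β shifts the penalty: with β=2, halving recall costs more (E of about 44.4) than halving precision (E of about 16.7).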

For my own understanding, I drew some gnuplot charts of E against P and R, which I also include below - they may be helpful to you as well. As you can see, the quality of search is inversely related to the value of E, i.e., if E goes down, search quality goes up, and vice versa. When P and R trade off against each other (the third chart, where R = 1 - P), the best (lowest) value of E occurs when P and R are about equal (depending on the value of β, of course).

# Plot of E(P,R) holding R=1 and varying P=x with various beta
# beta=1 (red) - equal importance of P and R
# beta=0.5 (green) - P twice as important as R
# beta=2.0 (blue) - R twice as important as P
set multiplot
set xlabel 'x'
set ylabel 'E(x,1)'
set key off
set xrange [0:1]
set yrange [0:100]
beta=1
plot 100*(1-((1+beta**2)*x*1/((beta**2*x)+1))) linetype 1
beta=0.5
plot 100*(1-((1+beta**2)*x*1/((beta**2*x)+1))) linetype 2
beta=2.0
plot 100*(1-((1+beta**2)*x*1/((beta**2*x)+1))) linetype 3
unset multiplot

# Plot of E(P,R) holding P=1 and varying R=x with various beta
# beta=1 (red) - equal importance of P and R
# beta=0.5 (green) - P twice as important as R
# beta=2.0 (blue) - R twice as important as P
set multiplot
set xlabel 'x'
set ylabel 'E(1,x)'
set key off
set xrange [0:1]
set yrange [0:100]
beta=1
plot 100*(1-((1+beta**2)*1*x/((beta**2*1)+x))) linetype 1
beta=0.5
plot 100*(1-((1+beta**2)*1*x/((beta**2*1)+x))) linetype 2
beta=2.0
plot 100*(1-((1+beta**2)*1*x/((beta**2*1)+x))) linetype 3
unset multiplot

# Plot of E(P,R) where P=x and R=1-x with various beta
# beta=1 (red) - equal importance of P and R
# beta=0.5 (green) - P twice as important as R
# beta=2.0 (blue) - R twice as important as P
set multiplot
set xlabel 'x'
set ylabel 'E(x,1-x)'
set key off
set xrange [0:1]
set yrange [0:100]
beta=1
plot 100*(1-((1+beta**2)*x*(1-x)/((beta**2*x)+(1-x)))) linetype 1
beta=0.5
plot 100*(1-((1+beta**2)*x*(1-x)/((beta**2*x)+(1-x)))) linetype 2
beta=2.0
plot 100*(1-((1+beta**2)*x*(1-x)/((beta**2*x)+(1-x)))) linetype 3
unset multiplot

In this article, I describe how I calculate E for a given set of indexes built off the same corpus, each index corresponding to a block of iterative algorithmic changes in the index building code.

Calculating Recall

The formula for recall is r/T, where r is the number of relevant documents returned out of a total of T documents available for the given topic. It is not reasonable to compute T for every benchmark query, and in any case, my objective is to compare the increase or decrease in recall based on the original index. So, assuming a query Q on two different indexes, let r1 and r2 be the number of relevant documents returned:

  R1 = r1 / T
  R2 = r2 / T
  Therefore:
  R2 / R1 = r2 / r1

Based on the above, I define R for an index as the average number of relevant results returned across all my benchmark queries against that index, normalized so that the largest value across all indexes is 1.
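
The normalization step can be sketched as below. This is an illustration only: RecallNormalizerSketch and its method are hypothetical names, the index names and counts are made up, and the sketch assumes the per-index averages have already been computed.

```java
import java.util.LinkedHashMap;
import java.util.Map;

class RecallNormalizerSketch {

  // Scales the per-index average result counts so the largest becomes 1.0.
  static Map<String,Double> normalize(Map<String,Double> avgRelevantByIndex) {
    double max = 0.0;
    for (double value : avgRelevantByIndex.values()) {
      max = Math.max(max, value);
    }
    Map<String,Double> normalized = new LinkedHashMap<String,Double>();
    for (Map.Entry<String,Double> entry : avgRelevantByIndex.entrySet()) {
      // if no index returned anything, report 0.0 instead of dividing by zero
      normalized.put(entry.getKey(), max == 0.0 ? 0.0 : entry.getValue() / max);
    }
    return normalized;
  }

  public static void main(String[] args) {
    Map<String,Double> counts = new LinkedHashMap<String,Double>();
    counts.put("baseline", 120.0);   // made-up numbers, for illustration only
    counts.put("candidate", 150.0);
    System.out.println(normalize(counts));
  }
}
```

Because T cancels out in R2 / R1, only the ratios between indexes matter, and dividing by the maximum keeps R in the [0,1] range that the E-measure formula expects.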

Calculating Precision

The formula for precision is r/n, where r is the number of relevant documents returned from a total of n documents returned from a search. This is easy enough to calculate, but does not capture position information, ie, the fact that a good result at the top of the results is more valuable than one at the bottom.

For that, I use the index created before our code changes as the baseline index to measure precision. For each query in our set of benchmark queries, a human user scores the top 30 search results using a 5-point scale, -2 being the worst and +2 being the best. I consider only 30 because according to studies such as these, users rarely go beyond the 3rd page of search results. The scores are captured in a database table such as the one shown below:

+-------------+--------------+------+-----+---------+-------+
| Field       | Type         | Null | Key | Default | Extra |
+-------------+--------------+------+-----+---------+-------+
| query_term  | varchar(128) | NO   | PRI |         |       | 
| result_url  | varchar(128) | NO   | PRI |         |       | 
| search_type | varchar(32)  | NO   | PRI |         |       | 
| position    | int(11)      | NO   |     |         |       | 
| score       | int(11)      | NO   |     |         |       | 
+-------------+--------------+------+-----+---------+-------+

The overall precision for the index is calculated as the average of the sum of weighted scores for each query result, across all queries against that index. The weight reflects the importance of the score based on its position.

  P = Σ (s_i * w_i) / N_scored
  where:
    P = the precision of a given query
    s_i = the score for result at position i
    w_i = atan(30 - i) / atan(30)
    N_scored = number of results which were scored

The plot of the w(i) function for i=[0..29] is shown below. As you can see, the score for the top result gets a weight of 1, the weight stays close to 1 for most of the list, and the scores near the bottom are progressively deboosted, down to a weight of about 0.51 at position 29.

# Plot of w(x) = atan(30-x)/atan(30) for x=[0..29]
set xlabel 'x'
set ylabel 'w(x)'
set xrange [0:29]
set yrange [0:1]
plot atan(30-x)/atan(30)
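
For readers who prefer code to formulas, the weight function and the per-query precision can be sketched in Java as follows. WeightedPrecisionSketch is a hypothetical name, the scores are made up, and representing an unscored result as null (and returning 0 when nothing is scored) are my own assumptions for the sketch.

```java
class WeightedPrecisionSketch {

  static final int MAX_RESULTS = 30;

  // w_i = atan(30 - i) / atan(30), for zero-based position i
  static double weight(int position) {
    return Math.atan(MAX_RESULTS - position) / Math.atan(MAX_RESULTS);
  }

  // scores[i] is the human score (-2..+2) for the result at position i,
  // or null if that result has not been scored.
  static double precision(Integer[] scores) {
    double sum = 0.0;
    int scored = 0;
    for (int i = 0; i < scores.length && i < MAX_RESULTS; i++) {
      if (scores[i] != null) {
        sum += scores[i] * weight(i);
        scored++;
      }
    }
    // assumption: a query with no scored results contributes 0
    return scored == 0 ? 0.0 : sum / scored;
  }

  public static void main(String[] args) {
    System.out.printf("w(0)=%.4f w(20)=%.4f w(29)=%.4f%n",
      weight(0), weight(20), weight(29));
    Integer[] scores = {2, 1, null, -1}; // made-up scores
    System.out.printf("P=%.4f%n", precision(scores));
  }
}
```

Note that P computed this way ranges over the score scale (-2 to +2) rather than 0 to 1, so it is a position-weighted average score rather than a precision in the textbook sense.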

In addition, when issuing the same query against a new index created with an improved algorithm, we may find new results coming in to replace the existing results. These new results represent our uncertainty factor when calculating E. For search results for which we cannot find scores in the hon_scores table, we update an uncertainty metric using the maximum possible score, i.e.:

  U = Σ (M * w_i) / N_unscored
  where:
    U = the uncertainty for a given query
    M = maximum score possible, in this case +2
    w_i = atan(30 - i) / atan(30)
    i = position of the result about which we are uncertain
    N_unscored = number of results which were unscored, i.e., new.

The value of U is used to calculate upper and lower bounds for the E-measure by calculating E(P+U, R) and E(P-U, R).
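
A minimal sketch of the uncertainty calculation for a single query is shown below. UncertaintySketch is a hypothetical name, the positions and the precision value in main() are made up, and returning 0 when there are no unscored results is my own assumption.

```java
class UncertaintySketch {

  static final int MAX_RESULTS = 30;
  static final double MAX_SCORE = 2.0; // M, the best possible score

  // w_i = atan(30 - i) / atan(30), for zero-based position i
  static double weight(int position) {
    return Math.atan(MAX_RESULTS - position) / Math.atan(MAX_RESULTS);
  }

  // U = sum(M * w_i) / N_unscored, over the positions of the unscored results
  static double uncertainty(int[] unscoredPositions) {
    if (unscoredPositions.length == 0) {
      return 0.0; // assumption: nothing new means no uncertainty
    }
    double sum = 0.0;
    for (int position : unscoredPositions) {
      sum += MAX_SCORE * weight(position);
    }
    return sum / unscoredPositions.length;
  }

  public static void main(String[] args) {
    double p = 0.9;                                 // made-up precision
    double u = uncertainty(new int[] {3, 17, 29});  // made-up positions
    System.out.printf("U=%.4f, P range=[%.4f, %.4f]%n", u, p - u, p + u);
  }
}
```

The two ends of that P range are what get fed into E(P-U, R) and E(P+U, R) to produce the bounds.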

Calculating Effectiveness

Once the baseline scores are set up, a backend process runs all the benchmark queries through the other indexes in the collection (if not run already), and populates a table such as this one:

+---------------+-------------+------+-----+---------+-------+
| Field         | Type        | Null | Key | Default | Extra |
+---------------+-------------+------+-----+---------+-------+
| index_name    | varchar(32) | NO   | PRI |         |       | 
| search_type   | varchar(32) | NO   | PRI |         |       | 
| prec          | float(8,4)  | NO   |     |         |       | 
| uncertainity  | float(8,4)  | NO   |     |         |       | 
| recall        | float(8,4)  | NO   |     |         |       | 
| effectiveness | float(8,4)  | NO   |     |         |       | 
| effective_lb  | float(8,4)  | NO   |     |         |       | 
| effective_ub  | float(8,4)  | NO   |     |         |       | 
+---------------+-------------+------+-----+---------+-------+

Because the back-end code is part of a Spring web application, it is injected with quite a few specialized data access beans, and if I had to show them all, this post would get very long. So I just provide pseudo-code for this job here. Essentially, all it does is execute a fixed set of queries against a fixed set of indexes, loop through the results looking for matches against the baseline, and calculate recall and precision appropriately.

for (indexName in indexNames):
  searcher = buildSearcher(indexName)
  precision, recall, uncertainity = 0
  n_queryterms = 0

  for (queryterm in queryterms):
    hits = searcher.search(queryterm)
    recall += hits.length
    hits = hits[0,30]
    position = n_scored = n_unscored = 0
    query_precision = query_uncertainity = 0

    for (hit in hits):
      url = hit.url
      weight = atan(30 - position) / atan(30)
      if (url is scored):
        query_precision += score * weight
        n_scored++
      else:
        query_uncertainity += 2 * weight
        n_unscored++
      position++

    # average for query; skip the division when a bucket is empty
    if (n_scored > 0):
      query_precision = query_precision / n_scored
    if (n_unscored > 0):
      query_uncertainity = query_uncertainity / n_unscored
    n_queryterms++
    precision += query_precision
    uncertainity += query_uncertainity

  # average recall, precision and uncertainity for all query terms
  # for a single index
  recall = recall / n_queryterms
  precision = precision / n_queryterms
  uncertainity = uncertainity / n_queryterms

  # save the raw values (first pass)
  save(recall, precision, uncertainity) for indexName

# After results for all indexes are populated, normalize recall so the max
# value across all indexes is 1
normalize_recall()
for (indexName in indexNames):
  # p, r, u are the saved precision, normalized recall and uncertainity
  # Calculate e(p,r), e(p+u,r) and e(p-u,r) (second pass)
  effectiveness = compute_e(p, r)
  effectiveness_lowerbound = compute_e(p - u, r)
  effectiveness_upperbound = compute_e(p + u, r)
  # Save updated values (second pass)
  save(recall, effectiveness, 
    effectiveness_lowerbound, effectiveness_upperbound) 
    for indexName

Graphing the Effectiveness measures

The chart(s) are generated dynamically off the data populated into the database by the backend process described above. I could have just used a table to display the results, but a graph makes things easier to visualize, and besides, I have been meaning to try out JFreeChart for a while, and this seemed a good place to use it.

The code that allows the user to score individual search results and calculates the effectiveness scores is all part of a Spring web application, so I needed a way to show the graph on a web page. The controller just reads information off the table and builds a chart, converts it to a PNG bytestream and writes it into the response. The application allows scoring for different kinds of search, so multiple charts can be generated and shown on the same page.

Here is the code for the controller that generates the chart. The JFreeChart project has a pay-for-documentation business model, but there are any number of examples available on the web, which is where I got most of my information. I provide some comments in the code, but if you need more explanation, I would suggest looking at the many available JFreeChart examples.

// Source: src/main/java/com/mycompany/myapp/controllers/GraphController.java
package com.mycompany.myapp.controllers;

import java.awt.BasicStroke;
import java.awt.Color;
import java.awt.Paint;
import java.io.OutputStream;
import java.text.DecimalFormat;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.apache.commons.math.stat.descriptive.rank.Max;
import org.apache.commons.math.stat.descriptive.rank.Min;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartUtilities;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.annotations.CategoryLineAnnotation;
import org.jfree.chart.axis.CategoryAxis;
import org.jfree.chart.axis.CategoryLabelPositions;
import org.jfree.chart.axis.NumberAxis;
import org.jfree.chart.labels.StandardCategoryItemLabelGenerator;
import org.jfree.chart.plot.CategoryPlot;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.chart.renderer.category.LineAndShapeRenderer;
import org.jfree.data.category.DefaultCategoryDataset;
import org.springframework.beans.factory.annotation.Required;
import org.springframework.web.bind.ServletRequestUtils;
import org.springframework.web.servlet.ModelAndView;
import org.springframework.web.servlet.mvc.Controller;

import com.mycompany.myapp.daos.EMeasureDao;
import com.healthline.util.Pair;

public class GraphController implements Controller {

  private EMeasureDao emeasureDao;

  @Required
  public void setEmeasureDao(EMeasureDao emeasureDao) {
    this.emeasureDao = emeasureDao;
  }

  public ModelAndView handleRequest(HttpServletRequest request,
     HttpServletResponse response) throws Exception {

    String searchType = 
      ServletRequestUtils.getRequiredStringParameter(request, "st");
    List<Map<String,Object>> scoresForSearchType = 
      emeasureDao.getScoresForSearchType(searchType);

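    // Sentinels for the running min/max of each axis. Note that
    // Double.MIN_VALUE is the smallest positive double, not the most
    // negative one, so this only works while scores are never negative.
    double minYe = Double.MAX_VALUE;
    double maxYe = Double.MIN_VALUE;
    double minYpr = Double.MAX_VALUE;
    double maxYpr = Double.MIN_VALUE;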
    DefaultCategoryDataset prDataset = new DefaultCategoryDataset();
    DefaultCategoryDataset eDataset = new DefaultCategoryDataset();
    Map<String,Pair<Double,Double>> candlesticks = 
      new LinkedHashMap<String,Pair<Double,Double>>();
    for (Map<String,Object> scoreForSearchType : scoresForSearchType) {
      String indexName = (String) scoreForSearchType.get("INDEX_NAME");
      Double precision = new Double((Float) scoreForSearchType.get("PREC"));
      Double recall = new Double((Float) scoreForSearchType.get("RECALL"));
      Double effectiveness = 
        new Double((Float) scoreForSearchType.get("EFFECTIVENESS"));
      Double effectiveLb = 
        new Double((Float) scoreForSearchType.get("EFFECTIVE_LB"));
      Double effectiveUb = 
        new Double((Float) scoreForSearchType.get("EFFECTIVE_UB"));
      eDataset.addValue(effectiveness, "E-Measure", indexName);
      prDataset.addValue(precision, "Precision", indexName);
      prDataset.addValue(recall, "Recall", indexName);
      candlesticks.put(indexName, 
        new Pair<Double,Double>(effectiveLb, effectiveUb));
      minYe = min(new double[] {
        minYe, effectiveness, effectiveLb, effectiveUb});
      maxYe = max(new double[] {
        maxYe, effectiveness, effectiveLb, effectiveUb});
      minYpr = min(new double[] {minYpr, precision, recall});
      maxYpr = max(new double[] {maxYpr, precision, recall});
    }
    
    JFreeChart chart = ChartFactory.createLineChart(
      "", "Indexes", "E-Measure (%)", eDataset, 
      PlotOrientation.VERTICAL, true, true, false);
    CategoryPlot plot = (CategoryPlot) chart.getPlot();
    
    // show vertical gridlines
    plot.setDomainGridlinePaint(Color.white);
    plot.setDomainGridlineStroke(CategoryPlot.DEFAULT_GRIDLINE_STROKE);
    plot.setDomainGridlinesVisible(true);

    // customize domain (x-axis)
    CategoryAxis domainAxis = plot.getDomainAxis();
    domainAxis.setCategoryLabelPositions(CategoryLabelPositions.DOWN_45);
    domainAxis.setTickLabelsVisible(true);

    // customize range (y-axis).
    NumberAxis rangeAxis = (NumberAxis) plot.getRangeAxis();
    rangeAxis.setLowerBound(minYe == Double.MAX_VALUE ? 0.0D : minYe * 0.9D);
    rangeAxis.setUpperBound(maxYe == Double.MIN_VALUE ? 200.0D : 
      maxYe * 1.1D);
    rangeAxis.setLabelPaint(Color.red);

    // display data values for e-measure
    LineAndShapeRenderer renderer = 
      (LineAndShapeRenderer) plot.getRenderer();
    DecimalFormat decimalFormat = new DecimalFormat("###.##");
    renderer.setSeriesItemLabelGenerator(0, 
      new StandardCategoryItemLabelGenerator(
      StandardCategoryItemLabelGenerator.DEFAULT_LABEL_FORMAT_STRING, 
      decimalFormat));
    renderer.setSeriesPaint(0, Color.red);
    renderer.setSeriesStroke(0, new BasicStroke(2.0F, 
        BasicStroke.CAP_ROUND, BasicStroke.JOIN_ROUND));
    renderer.setSeriesItemLabelsVisible(0, true);
    renderer.setBaseItemLabelsVisible(true);
    plot.setRenderer(renderer);
    
    // set candlestick annotations on e-measure for uncertainty
    for (String indexName : candlesticks.keySet()) {
      Pair<Double,Double> hilo = candlesticks.get(indexName);
      plot.addAnnotation(new CategoryLineAnnotation(
        indexName, hilo.getFirst(), 
        indexName, hilo.getSecond(), 
        Color.red,
        new BasicStroke(2.0F, BasicStroke.CAP_ROUND,
        BasicStroke.JOIN_ROUND)));
    }
    
    // add precision and recall with right hand side y-axis (0..2)
    
    NumberAxis prRangeAxis = new NumberAxis("Precision/Recall");
    prRangeAxis.setLowerBound(minYpr == Double.MAX_VALUE ? 0.0D : 
      minYpr * 0.9D);
    prRangeAxis.setUpperBound(maxYpr == Double.MIN_VALUE ? 2.0D : 
      maxYpr * 1.1D);
    plot.setRangeAxis(1, prRangeAxis);
    plot.setDataset(1, prDataset);
    plot.mapDatasetToRangeAxis(1, 1);
    // display data values
    Paint[] colors = new Paint[] {Color.green, Color.blue};
    LineAndShapeRenderer prRenderer = new LineAndShapeRenderer();
    DecimalFormat prDecimalFormat = new DecimalFormat("#.##");
    for (int i = 0; i < 2; i++) {
      prRenderer.setSeriesItemLabelGenerator(i, 
        new StandardCategoryItemLabelGenerator(
        StandardCategoryItemLabelGenerator.DEFAULT_LABEL_FORMAT_STRING, 
        prDecimalFormat));
      prRenderer.setSeriesPaint(i, colors[i]);
      prRenderer.setSeriesStroke(i, new BasicStroke(2.0F, 
        BasicStroke.CAP_ROUND, BasicStroke.JOIN_ROUND));
      prRenderer.setSeriesItemLabelsVisible(i, true);
    }
    prRenderer.setBaseItemLabelsVisible(true);
    plot.setRenderer(1, prRenderer);
    
    // output to response
    OutputStream responseOutputStream = response.getOutputStream();
    ChartUtilities.writeChartAsPNG(responseOutputStream, chart, 750, 400);
    responseOutputStream.flush();
    responseOutputStream.close();
    return null;
  }

  private double max(double[] values) {
    Max max = new Max();
    return max.evaluate(values);
  }

  private double min(double[] values) {
    Min min = new Min();
    return min.evaluate(values);
  }
}
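
The axis-range logic in the controller (pad the observed minimum and maximum by 10%, and fall back to fixed bounds when there is no data) can be pulled out into a small helper. The following is only a minimal sketch and not part of the original controller: the class name AxisRange and the method paddedRange are hypothetical. It uses Double.POSITIVE_INFINITY and Double.NEGATIVE_INFINITY as starting values, which avoids relying on Double.MIN_VALUE being smaller than every score.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical helper, not part of the original GraphController.
public class AxisRange {

  /**
   * Returns {lower, upper} bounds padded by 10% around the observed values,
   * or the supplied defaults when the list is empty.
   */
  public static double[] paddedRange(List<Double> values,
      double defaultLower, double defaultUpper) {
    if (values.isEmpty()) {
      return new double[] {defaultLower, defaultUpper};
    }
    double min = Double.POSITIVE_INFINITY;
    double max = Double.NEGATIVE_INFINITY;
    for (double v : values) {
      min = Math.min(min, v);
      max = Math.max(max, v);
    }
    // same 0.9 / 1.1 padding factors that the controller applies
    return new double[] {min * 0.9D, max * 1.1D};
  }

  public static void main(String[] args) {
    double[] range = paddedRange(
      Arrays.asList(5.0D, 21.79D, 7.5D), 0.0D, 200.0D);
    System.out.println(range[0] + " .. " + range[1]);
  }
}
```

As in the controller, the multiplicative padding assumes non-negative values; with negative scores the 0.9 and 1.1 factors would shrink the range instead of widening it.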

The controller is invoked from an image tag in the JSP page, as shown below. That way, a page can contain multiple image tags, and the browser requests all of them in parallel while the page loads.

<img src="/honscorer/_graph.do"/>
Here is what a generated chart looks like:

As the chart above shows, both recall and precision initially increased from the baseline index, and the E-Measure came down from 21.79 to 5. Some algorithm change between 2008-07-09 and 2008-07-24x then caused a slight decrease in precision and a slight uptick in the E-Measure. I think automated search quality metrics such as these can be quite useful, both as an early warning system for unexpected side effects of an algorithm change and as a way to measure how a change or set of changes affects overall search quality.
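
To illustrate the early warning idea, here is a minimal sketch that walks the E-Measure scores in index order and flags any index whose score rose by more than a tolerance over its predecessor. It is not part of the application described above: the class EMeasureWatch, the method findRegressions, the tolerance parameter, and the sample scores are all made up for illustration. It assumes, as in the chart discussion, that a lower E-Measure is better.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch, not part of the original application.
public class EMeasureWatch {

  /**
   * Returns the names of indexes whose E-Measure rose by more than
   * the given tolerance relative to the index immediately before them.
   * The map must iterate in index order (e.g. a LinkedHashMap).
   */
  public static List<String> findRegressions(
      Map<String,Double> eMeasureByIndex, double tolerance) {
    List<String> flagged = new ArrayList<String>();
    Double previous = null;
    for (Map.Entry<String,Double> entry : eMeasureByIndex.entrySet()) {
      double current = entry.getValue();
      if (previous != null && current - previous > tolerance) {
        flagged.add(entry.getKey());
      }
      previous = current;
    }
    return flagged;
  }

  public static void main(String[] args) {
    // made-up scores, for illustration only
    Map<String,Double> scores = new LinkedHashMap<String,Double>();
    scores.put("index-a", 20.0D);
    scores.put("index-b", 5.0D);
    scores.put("index-c", 6.5D);
    System.out.println(findRegressions(scores, 1.0D)); // prints [index-c]
  }
}
```

The tolerance keeps small fluctuations from being reported; a real check would probably also take the lower and upper bounds of each score into account.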