Sunday, November 30, 2008

IR Math in Java : Citation based Ranking

If you are a regular reader, you know that I have been working my way through Dr Manu Konchady's TMAP book in an effort to teach myself some Information Retrieval theory. This week, I talk about my experience implementing Google's PageRank algorithm in Java, as described in Chapter 6 of this book and the PageRank Wikipedia page. In the process, I also ended up developing a Sparse Matrix implementation in order to compute PageRank for real data collections, which I contributed back to the commons-math project.

The PageRank algorithm was originally proposed by Google's founders, and while it does form part of the core of what SEO types refer to as The Google Algorithm, the Algorithm is significantly more comprehensive and complex. My intent is not to reverse engineer this stuff, nor to hack it. I think the algorithm is interesting, and thought it would be worth figuring out how to code this up in Java.

The PageRank algorithm is based on the citation model (hence the title of this post), i.e., if a scholarly paper is considered to be of interest, other scholarly papers cite it as a reference. Similarly, a page with good information is linked to by other pages on the web. The PageRank of a page is the sum of normalized PageRanks of pages that point to it. If a page links out to more than one page, its contribution to the target page's PageRank is its PageRank divided by the number of pages it links out to. Obviously, this is kind of a chicken and egg problem, since each page's rank depends on the ranks of other pages, so it needs to be solved iteratively, starting from an initial guess.

In addition, there is a damping factor d to simulate a random surfer, who clicks on links but eventually gets bored and does a new search and starts over. To compensate for the damping factor, a constant factor c is added to the PageRank formula. The formula is thus:

  rj = c + (d * Σ ri / ni)
  where:
    rj = PageRank for page j
    d = damping factor, usually 0.85
    c = (1 - d) / N
    ri = PageRank for page i which points to page j
    ni = Number of outbound links from page i
    N = number of documents in the collection

This would translate to a set of linear equations, and could thus be re-written as a recursive matrix equation. As much as I would like to say that I arrived at this epiphany all by myself, I really just worked backwards from the formula on the Wikipedia page.

  R = C + d * A * R0
  where:
    R  = a column vector of size N, containing the ranks of pages in the collection.
    C  = a constant column vector containing [ci]
    d  = scalar damping factor
    A  = a NxN square matrix containing the initial probabilities 1/N for each (i,j)
         where page(i) links to page(j), and 0 for all other (i,j).
    R0 = the initial guess for the page ranks, all set to 1/N.
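
One detail worth checking against the scalar formula: for R = C + d * A * R0 to reproduce rj = c + d * Σ ri / ni term by term, the entry in row j and column i of A would be 1/ni whenever page i links to page j. The sketch below is a minimal illustration of that iteration using plain arrays instead of commons-math. The class name PageRankSketch, the adjacency-list input and the maxIter cap are assumptions made for this example and are not part of the code in this post.

```java
import java.util.Arrays;

/** Sketch only: plain-array PageRank iteration with illustrative names. */
class PageRankSketch {

  /**
   * links[i] lists the pages that page i links to. Iterates
   * r = c + d * A * r0 until the summed absolute change drops below eps
   * or maxIter iterations have run.
   */
  static double[] rank(int[][] links, double d, double eps, int maxIter) {
    int n = links.length;
    double[] r0 = new double[n];
    Arrays.fill(r0, 1.0 / n);            // initial guess: 1/N for every page
    double c = (1.0 - d) / n;            // constant term from the formula
    for (int iter = 0; iter < maxIter; iter++) {
      double[] r = new double[n];
      Arrays.fill(r, c);
      for (int i = 0; i < n; i++) {
        for (int j : links[i]) {
          // page i passes rank ri / ni to every page j it links to
          r[j] += d * r0[i] / links[i].length;
        }
      }
      double delta = 0.0;
      for (int k = 0; k < n; k++) {
        delta += Math.abs(r[k] - r0[k]);
      }
      r0 = r;
      if (delta < eps) {
        break;                           // converged
      }
    }
    return r0;
  }

  public static void main(String[] args) {
    // three pages in a cycle: 0 -> 1 -> 2 -> 0
    int[][] links = { {1}, {2}, {0} };
    System.out.println(Arrays.toString(rank(links, 0.85, 1e-9, 100)));
  }
}
```

In the three-page cycle every page ends up with a rank of about 1/3, as symmetry suggests. Pages with no outbound links simply contribute nothing in this sketch, so some rank leaks out of the system; handling that is left out on purpose.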

We populate the matrices on the right hand side, then compute R. At each stage we check for convergence (if it is close enough to the previous value of R). If not, we set R0 from R and recompute. Here is the code to do this:

// Source: src/main/java/com/mycompany/myapp/ranking/PageRanker.java
package com.mycompany.myapp.ranking;

import java.util.List;
import java.util.Map;

import org.apache.commons.lang.StringUtils;
import org.apache.commons.math.linear.RealMatrix;
import org.apache.commons.math.linear.SparseRealMatrixImpl;

public class PageRanker {

  private Map<String,Boolean> linkMap;
  private double d;
  private double threshold;
  private List<String> docIds;
  private int numDocs;
  
  public void setLinkMap(Map<String,Boolean> linkMap) {
    this.linkMap = linkMap;
  }
  
  public void setDocIds(List<String> docIds) {
    this.docIds = docIds;
    this.numDocs = docIds.size();
  }
  
  public void setDampingFactor(double dampingFactor) {
    this.d = dampingFactor;
  }
  
  public void setConvergenceThreshold(double threshold) {
    this.threshold = threshold;
  }
  
  public RealMatrix rank() throws Exception {
    // create and initialize the probability matrix, start with all
    // equal probability p(i,j) of 0 or 1/n depending on if there is 
    // a link or not from page i to j.
    RealMatrix a = new SparseRealMatrixImpl(numDocs, numDocs);
    for (int i = 0; i < numDocs; i++) {
      for (int j = 0; j < numDocs; j++) {
        String key = StringUtils.join(new String[] {
          docIds.get(i), docIds.get(j)
        }, ",");
        if (linkMap.containsKey(key)) {
          a.setEntry(i, j, 1.0D / numDocs);
        }
      }
    }
    // create and initialize the constant matrix
    RealMatrix c = new SparseRealMatrixImpl(numDocs, 1);
    for (int i = 0; i < numDocs; i++) {
      c.setEntry(i, 0, ((1.0D - d) / numDocs));
    }
    // create and initialize the rank matrix
    RealMatrix r0 = new SparseRealMatrixImpl(numDocs, 1);
    for (int i = 0; i < numDocs; i++) {
      r0.setEntry(i, 0, (1.0D / numDocs));
    }
    // solve for the pagerank matrix r
    RealMatrix r;
    int i = 0;
    for(;;) {
      r = c.add(a.scalarMultiply(d).multiply(r0));
      // check for convergence
      if (r.subtract(r0).getNorm() < threshold) {
        break;
      }
      r0 = r.copy();
      i++;
    }
    return r;
  }
}

Here is the JUnit code to call the class. We set up the damping factor and the convergence threshold. We use the picture of the graph on the Wikipedia PageRank article as our initial dataset. The dataset is represented as comma-delimited pairs of linked page ids.

  @Test
  public void testRankWithToyData() throws Exception {
    Map<String,Boolean> linkMap = getLinkMapFromDatafile(
      "src/test/resources/pagerank_links.txt");
    PageRanker ranker = new PageRanker();
    ranker.setLinkMap(linkMap);
    ranker.setDocIds(Arrays.asList(new String[] {
      "1", "2", "3", "4", "5", "6", "7"
    }));
    ranker.setDampingFactor(0.85D);
    ranker.setConvergenceThreshold(0.001D);
    RealMatrix pageranks = ranker.rank();
    log.debug("pageRank=" + pageranks.toString());
  }

  private Map<String,Boolean> getLinkMapFromDatafile(String filename) 
      throws Exception {
    Map<String,Boolean> linkMap = new HashMap<String,Boolean>();
    BufferedReader reader = new BufferedReader(new FileReader(filename));
    String line;
    while ((line = reader.readLine()) != null) {
      if (StringUtils.isEmpty(line) || line.startsWith("#")) {
        continue;
      }
      String[] pairs = StringUtils.split(line, "\t");
      linkMap.put(pairs[0], Boolean.TRUE);
    }
    return linkMap;
  }

You may have noticed that I am using calls to SparseRealMatrixImpl, which does not exist in the commons-math codebase at the time of this writing. I implemented SparseRealMatrixImpl because when I tried to run the algorithm against a real interlinked data collection of about 6000+ documents, I would consistently get an OutOfMemoryError with the code that used a RealMatrixImpl (which uses a two dimensional double array as its backing store).

The SparseRealMatrixImpl subclasses RealMatrixImpl, but uses a Map<Point,Double> as its backing store. The Point class is a simple struct-type private data holder class that encapsulates the row and column number of the data element. Only non-zero matrix elements are actually stored in the Map. This works out because the largest matrix (A) contains mostly zeros, i.e., comparatively few pages are actually linked. The patch is available here.
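
To make the idea concrete, here is a minimal sketch of a map-backed sparse matrix. It is not the code from the patch: the class name SparseMatrixSketch is made up, and it packs the row and column into a single long key instead of using a Point class, purely to keep the example short. For scale, a dense 6000 x 6000 array of doubles needs roughly 288 MB for the values alone, whereas the map only grows with the number of links.

```java
import java.util.HashMap;
import java.util.Map;

/** Sketch only: a map-backed sparse matrix, not the commons-math patch. */
class SparseMatrixSketch {

  private final int rows;
  private final int cols;
  // key = row * cols + col; only non-zero values are stored
  private final Map<Long, Double> entries = new HashMap<Long, Double>();

  SparseMatrixSketch(int rows, int cols) {
    this.rows = rows;
    this.cols = cols;
  }

  private long key(int row, int col) {
    if (row < 0 || row >= rows || col < 0 || col >= cols) {
      throw new IndexOutOfBoundsException(row + "," + col);
    }
    return (long) row * cols + col;
  }

  void set(int row, int col, double value) {
    if (value == 0.0) {
      entries.remove(key(row, col));     // zeros are never stored
    } else {
      entries.put(key(row, col), value);
    }
  }

  double get(int row, int col) {
    Double value = entries.get(key(row, col));
    return value == null ? 0.0 : value;  // absent entries read as zero
  }

  /** Matrix-vector product that visits only the stored entries. */
  double[] multiply(double[] x) {
    if (x.length != cols) {
      throw new IllegalArgumentException("vector length must equal cols");
    }
    double[] y = new double[rows];
    for (Map.Entry<Long, Double> e : entries.entrySet()) {
      int row = (int) (e.getKey() / cols);
      int col = (int) (e.getKey() % cols);
      y[row] += e.getValue() * x[col];
    }
    return y;
  }

  int storedEntries() {
    return entries.size();
  }
}
```

The multiply method shows the other benefit of this layout: the work is proportional to the number of stored entries rather than to rows times columns.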

Update 2009-04-26: In recent posts, I have been building on code written and described in previous posts, so there were (and rightly so) quite a few requests for the code. So I've created a project on Sourceforge to host the code. You will find the complete source code built so far in the project's SVN repository.

Saturday, November 22, 2008

Jab - Inflict pain on your Java Application

When load testing web applications, I usually take a few URLs and run them through tools such as Apache Bench (ab) or more recently, Siege. Recently, however, I needed to compare performance under load for code querying data from a MySQL database table versus a Lucene index. I could have built a simple web-based interface around this code and used the tools mentioned above, but it seemed like too much work, so I looked around to see if there was anything in library form that I could use to load test Java components.

The first result on my Google search came up with information about Mike Clark's JUnitPerf project, which consists of a set of Test decorators designed to work with JUnit 3.x. Since I use JUnit 4.x, I would have to write JUnit 3.x style code and run it under JUnit 4.x, which is something I'd rather not do unless I really, really have to. That was not the biggest problem, however. Since both my components depended on external resources, they would have to be pre-instantiated for the test times to be realistic. Since JUnitPerf wraps an existing Test, which then runs within the JUnit runner, the instantiation would have to be done either within the @Test or @Before equivalent methods, or I would have to write another @BeforeClass style JUnit 3.8 decorator. In the first two cases, tests run with JUnitPerf's LoadTest would include the resource setup times. So I decided to write my own little framework which was JUnit-agnostic and yet runnable from within JUnit 4.x, and which allowed me to set up resources outside the code being tested.

Overview

My framework borrows the idea of using the Decorator pattern from JUnitPerf. It consists of 2 interfaces, 5 different Test Decorator implementations, and a couple of utility classes. The only dependencies are Java 1.5+, commons-lang, commons-math, commons-logging and log4j. I call it jab (JAva Bench), drawing inspiration for the name from Apache Bench. It can also be thought of as something that inflicts pain on your Java application by putting it under load (hence the title of this post).

Component Descriptions

ITestable

The ITestable interface provides the template that a piece of code must implement in order to be tested with jab. The resources argument passes in all the pre-instantiated resources that the ITestable needs to execute. Further down, I show you a real-life example, which incidentally was also the code that drove the building of this framework - there are two example implementations of ITestable in there.

// Source: src/main/java/com/mycompany/jab/ITestable.java
package com.mycompany.jab;

import java.util.Map;

public interface ITestable {
  public void execute(Map<String,Object> resources) throws Exception;
}
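
As an illustration of what an implementation might look like, here is a hypothetical stand-in for the MockTestable used in the usage examples further down; its source is not shown in this post. The resource keys "text" and "sleepMillis", the class names, and the timeOnce helper are all assumptions for this sketch. The interface is redeclared only so that the snippet compiles on its own.

```java
import java.util.HashMap;
import java.util.Map;

// Redeclared here only to keep the sketch self-contained.
interface ITestable {
  void execute(Map<String,Object> resources) throws Exception;
}

/** Hypothetical mock: reads its resources, then sleeps to simulate work. */
class MockTestableSketch implements ITestable {
  public void execute(Map<String,Object> resources) throws Exception {
    String text = (String) resources.get("text");
    Long sleepMillis = (Long) resources.get("sleepMillis");
    if (text == null || sleepMillis == null) {
      throw new IllegalArgumentException("missing resource");
    }
    Thread.sleep(sleepMillis);           // stands in for real work
  }
}

/** Times a single execute() call, the same way a test wrapper would. */
class MockTestableDemo {
  static long timeOnce(ITestable testable, Map<String,Object> resources)
      throws Exception {
    long start = System.currentTimeMillis();
    testable.execute(resources);
    return System.currentTimeMillis() - start;
  }

  public static void main(String[] args) throws Exception {
    Map<String,Object> resources = new HashMap<String,Object>();
    resources.put("text", "Some random text");
    resources.put("sleepMillis", 20L);
    System.out.println(timeOnce(new MockTestableSketch(), resources) + " ms");
  }
}
```

The point of the resources map is visible here: anything expensive is built once by the caller, so the measured time covers only the work inside execute().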

ITest

ITest is the interface that all our Test instances implement. This is really something internal to the framework, providing a template for people writing new Test implementations.

// Source: src/main/java/com/mycompany/jab/ITest.java
package com.mycompany.jab;

import java.util.List;

public interface ITest extends Cloneable {
  public void runTest() throws Exception;
  public Double getAggregatedObservation();
  public List<Double> getObservations();
  public Object clone();
}

SingleTest

This is the most basic (and central) implementation of ITest. All it does is wrap the ITestable.execute() call within two calls to System.currentTimeMillis() to grab the wallclock times, and calculate and update the elapsed times into the appropriate counters.

// Source: src/main/java/com/mycompany/jab/SingleTest.java
package com.mycompany.jab;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Models a single test. All it does is attach timers around the 
 * ITestable.execute() call.
 */
public class SingleTest implements ITest {

  private final Log log = LogFactory.getLog(getClass());
  
  private Class<? extends ITestable> testableClass;
  private Map<String,Object> resources;
  private ITestable testable;
  
  private List<Double> observations = new ArrayList<Double>();
  
  public SingleTest(Class<? extends ITestable> testableClass, 
      Map<String,Object> resources) throws Exception {
    this.testableClass = testableClass;
    this.resources = resources;
    this.testable = testableClass.newInstance();
  }

  public Double getAggregatedObservation() {
    return getObservations().get(0);
  }
  
  public List<Double> getObservations() {
    return observations;
  }

  public void runTest() throws Exception {
    try {
      observations.clear();
      long start = System.currentTimeMillis();
      testable.execute(resources);
      long stop = System.currentTimeMillis();
      observations.add(new Double(stop - start));
    } catch (Exception e) {
      observations.add(-1.0D); // a negative observation marks a failed run
      e.printStackTrace();
    }
  }
  
  @Override
  public Object clone() {
    try {
      return new SingleTest(this.testableClass, this.resources);
    } catch (Exception e) {
      log.error("Cloning object of class: " + this.getClass() + 
        " failed", e);
      return null;
    }
  }
}

This is the only ITest implementation that has access to the ITestable. More complex ITest implementations wrap a SingleTest. Instantiating a SingleTest is simple. The example below shows it being instantiated with an ITestable implementation called MockTestable.

    // instantiate and populate a Map<String,Object> in
    // your @Before annotated method
    resources.put("text", "Some random text");
    ...
    // instantiate a SingleTest in your @Test annotated method
    // and run it
    SingleTest test = new SingleTest(MockTestable.class, resources);
    test.runTest();
    // return the aggregated observation
    double elapsed = test.getAggregatedObservation();

AggregationPolicy

The next two implementations are really decorators for the SingleTest, which can be used to run the underlying test in serial or in parallel. Now that we will have multiple elapsed time observations, we need to be able to control what we will do with these multiple observations. The default is to expose the average of these observations using the getAggregatedObservation() method. However, this is tunable, using the AggregationPolicy argument in the constructor. The AggregationPolicy is a simple enum as shown below:

// Source: src/main/java/com/mycompany/jab/AggregationPolicy.java
package com.mycompany.jab;

/**
 * Enumerates the possible aggregation policies for observations returned
 * from RepeatedTest and ConcurrentTest (and other combo tests in the 
 * future).
 */
public enum AggregationPolicy {

  SUM, AVERAGE, MAX, MIN, VARIANCE, STDDEV, COUNT, FAILED, SUCCEEDED;
  
}

Most of the values are self-explanatory, corresponding to various common statistical measures. FAILED and SUCCEEDED return the number of failed and successful runs respectively, and COUNT returns the total number of runs.

Aggregator

The Aggregator provides utility methods to actually do the aggregation that is requested using the AggregationPolicy. We rely on the StatUtils class in commons-math to do the heavy lifting.

// Source: src/main/java/com/mycompany/jab/Aggregator.java
package com.mycompany.jab;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.math.stat.StatUtils;

/**
 * Aggregates a List of Double observations into a single Double value
 * based on the specified aggregation policy.
 */
public class Aggregator {

  private double[] failures;
  private double[] successes;

  public Aggregator(List<Double> observations) {
    List<Double> sobs = new ArrayList<Double>();
    List<Double> fobs = new ArrayList<Double>();
    for (Iterator<Double> sit = observations.iterator(); 
        sit.hasNext();) {
      Double obs = sit.next();
      if (obs < 0.0D) {
        fobs.add(obs);
      } else {
        sobs.add(obs);
      }
    }
    this.successes = ArrayUtils.toPrimitive(sobs.toArray(new Double[0]));
    this.failures = ArrayUtils.toPrimitive(fobs.toArray(new Double[0]));
  }

  public Double aggregate(AggregationPolicy policy) {
    switch(policy) {
    case SUM:
      return StatUtils.sum(successes);
    case MAX:
      return StatUtils.max(successes);
    case MIN:
      return StatUtils.min(successes);
    case VARIANCE:
      return StatUtils.variance(successes);
    case STDDEV:
      return Math.sqrt(StatUtils.variance(successes));
    case COUNT:
      return ((double) (successes.length + failures.length));
    case FAILED:
      return ((double) failures.length);
    case SUCCEEDED:
      return ((double) successes.length);
    case AVERAGE:
    default:
      return StatUtils.mean(successes);
    }
  }
}
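
To see what the policies do with a concrete set of numbers, here is a small dependency-free sketch of the same success/failure split. The class and method names are made up for this example, and it covers only the average and the failure count rather than every policy.

```java
import java.util.Arrays;
import java.util.List;

/** Sketch only: the failure/success split without commons-math. */
class AggregationSketch {

  /** Mean of the non-negative (successful) observations, NaN if none. */
  static double averageOfSuccesses(List<Double> observations) {
    double sum = 0.0;
    int n = 0;
    for (double obs : observations) {
      if (obs >= 0.0) {
        sum += obs;
        n++;
      }
    }
    return n == 0 ? Double.NaN : sum / n;
  }

  /** Number of observations recorded as failures (negative values). */
  static int failed(List<Double> observations) {
    int n = 0;
    for (double obs : observations) {
      if (obs < 0.0) {
        n++;
      }
    }
    return n;
  }

  public static void main(String[] args) {
    // three successful runs and one failure (-1.0)
    List<Double> obs = Arrays.asList(12.0, 8.0, -1.0, 10.0);
    System.out.println(averageOfSuccesses(obs)); // prints 10.0
    System.out.println(failed(obs));             // prints 1
  }
}
```

Note that the failed run does not drag the average down: it is excluded from the timing statistics and only shows up in the failure count.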

RepeatedTest

A RepeatedTest decorates an ITest, usually a SingleTest. All it does is run the decorated ITest a specified number of times, collecting and aggregating the elapsed time observations. The type of aggregation is specified with an AggregationPolicy. The default AggregationPolicy is AVERAGE, meaning that the aggregated observation is the average of the individual aggregated observations from the ITests.

// Source: src/main/java/com/mycompany/jab/RepeatedTest.java
package com.mycompany.jab;

import java.util.ArrayList;
import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Models a test that consists of running a test a fixed number of times
 * in series.
 */
public class RepeatedTest implements ITest {

  private final Log log = LogFactory.getLog(getClass());
  
  private ITest test;
  private int numIterations;
  private AggregationPolicy policy;
  private long delayMillis;
  
  private List<Double> observations = new ArrayList<Double>();
  
  public RepeatedTest(ITest test, int numIterations) {
    this(test, numIterations, AggregationPolicy.AVERAGE, 0L);
  }
  
  public RepeatedTest(ITest test, int numIterations, 
      AggregationPolicy policy) {
    this(test, numIterations, policy, 0L);
  }
  
  public RepeatedTest(ITest test, int numIterations, 
      AggregationPolicy policy, long delayMillis) {
    this.test = test;
    this.numIterations = numIterations;
    this.policy = policy;
    this.delayMillis = delayMillis;
  }

  public Double getAggregatedObservation() {
    Aggregator aggregator = new Aggregator(getObservations());
    return aggregator.aggregate(policy);
  }

  public List<Double> getObservations() {
    return observations;
  }

  public void runTest() throws Exception {
    ITest clone = (ITest) test.clone();
    for (int i = 0; i < numIterations; i++) {
      clone.runTest();
      observations.add(clone.getAggregatedObservation());
      if (delayMillis > 0L) {
        try { Thread.sleep(delayMillis); }
        catch (InterruptedException e) {;}
      }
    }
  }
  
  @Override
  public Object clone() {
    return new RepeatedTest(this.test, this.numIterations, this.policy, 
      this.delayMillis);
  }
}

As you can see, there are three constructors that you can use. The simplest one specifies the ITest and the number of repetitions, the second one overrides the default AggregationPolicy to be used, and the third one specifies that the test should wait a specified number of milliseconds between ITest invocations. Here are some usage examples.

    // run the SingleTest 10 times, no delay, default aggregation
    RepeatedTest test1 = new RepeatedTest(
      new SingleTest(MockTestable.class, resources), 10);

    // run the SingleTest 10 times, no delay, override aggregation
    // policy to return the sum of the 10 observations
    RepeatedTest test2 = new RepeatedTest(
      new SingleTest(MockTestable.class, resources), 10,
      AggregationPolicy.SUM);

    // run the SingleTest 10 times, with default aggregation,
    // and a 10ms delay between each invocation
    RepeatedTest test3 = new RepeatedTest(
      new SingleTest(MockTestable.class, resources), 10,
      AggregationPolicy.AVERAGE, 10L);

ConcurrentTest

A ConcurrentTest decorates an ITest and runs a specific number of these ITests concurrently. Like RepeatedTest, its default AggregationPolicy is AVERAGE, which can be overridden. Also like RepeatedTest, it allows you to specify a delay between spawning successive parallel ITest instances.

// Source: src/main/java/com/mycompany/jab/ConcurrentTest.java
package com.mycompany.jab;

import java.util.List;
import java.util.ArrayList;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Models multiple running concurrent jobs.
 */
public class ConcurrentTest implements ITest {

  private final Log log = LogFactory.getLog(getClass());
  
  private ITest test;
  private int numConcurrent;
  private AggregationPolicy policy;
  private long delayMillis;
  
  private List<Callable<ITest>> callables = null;
  
  private List<Double> observations = new ArrayList<Double>();

  public ConcurrentTest(ITest test, int numConcurrent) throws Exception {
    this(test, numConcurrent, AggregationPolicy.AVERAGE, 0L);
  }
  
  public ConcurrentTest(ITest test, int numConcurrent, 
      AggregationPolicy policy) throws Exception {
    this(test, numConcurrent, policy, 0L);
  }
  
  public ConcurrentTest(ITest test, int numConcurrent, 
      AggregationPolicy policy, long delayMillis) throws Exception {
    this.test = test;
    this.numConcurrent = numConcurrent;
    this.delayMillis = delayMillis;
    this.policy = policy;
    this.callables = 
      new ArrayList<Callable<ITest>>(numConcurrent);
    for (int i = 0; i < numConcurrent; i++) {
      final ITest clone = (ITest) this.test.clone();
      callables.add(new Callable<ITest>() {
        public ITest call() throws Exception {
          clone.runTest();
          return clone;
      }});
    }
  }

  public Double getAggregatedObservation() {
    Aggregator aggregator = new Aggregator(getObservations());
    return aggregator.aggregate(policy);
  }

  public List<Double> getObservations() {
    return observations;
  }

  public void runTest() throws Exception {
    ExecutorService executor = Executors.newFixedThreadPool(numConcurrent);
    List<Future<ITest>> tests = 
      new ArrayList<Future<ITest>>();
    for (int i = 0; i < numConcurrent; i++) {
      Future<ITest> test = executor.submit(callables.get(i));
      tests.add(test);
      if (delayMillis > 0L) {
        try { Thread.sleep(delayMillis); }
        catch (InterruptedException e) {;}
      }
    }
    for (Future<ITest> future : tests) {
      future.get();
    }
    executor.shutdown();
    for (int i = 0; i < numConcurrent; i++) {
      ITest test = tests.get(i).get();
      observations.add(test.getAggregatedObservation());
    }
  }
  
  @Override
  public Object clone() {
    try {
      // pass delayMillis too, so the clone keeps the configured spawn delay
      return new ConcurrentTest(this.test, this.numConcurrent, this.policy,
        this.delayMillis);
    } catch (Exception e) {
      log.error("Cloning object of class: " + this.getClass() + 
        " failed", e);
      return null;
    }
  }
}

I picked up some pointers on the new Java 1.5 threading style from this blog post on recursor. As you can see, the constructors are similar to those for RepeatedTest. Here are some usage examples:

    // run SingleTest in parallel with 10 threads, no delay
    // between thread spawning
    ConcurrentTest test1 = new ConcurrentTest(
      new SingleTest(MockTestable.class, resources), 10);

    // run SingleTest in parallel with 10 threads, override the
    // AggregationPolicy to SUM, no delay between thread spawning
    ConcurrentTest test2 = new ConcurrentTest(
      new SingleTest(MockTestable.class, resources), 10,
      AggregationPolicy.SUM);

    // run SingleTest in parallel with 10 threads, default 
    // AggregationPolicy, with delay of 10ms between thread spawning
    ConcurrentTest test3 = new ConcurrentTest(
      new SingleTest(MockTestable.class, resources), 10,
      AggregationPolicy.AVERAGE, 10L);
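
The threading pattern inside ConcurrentTest can also be looked at in isolation. The sketch below is a stripped-down illustration rather than part of jab: it times a plain Runnable on a fixed thread pool and uses ExecutorService.invokeAll to wait for all jobs, instead of submitting them one at a time as the class above does. The class name ConcurrentTimingSketch and the timeConcurrently method are assumptions for this example.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

/** Sketch only: time the same task on several threads at once. */
class ConcurrentTimingSketch {

  /** Returns one elapsed time in milliseconds per concurrent run. */
  static List<Long> timeConcurrently(final Runnable task, int numConcurrent)
      throws Exception {
    ExecutorService executor = Executors.newFixedThreadPool(numConcurrent);
    try {
      List<Callable<Long>> jobs = new ArrayList<Callable<Long>>();
      for (int i = 0; i < numConcurrent; i++) {
        jobs.add(new Callable<Long>() {
          public Long call() {
            long start = System.nanoTime();
            task.run();
            return (System.nanoTime() - start) / 1000000L;
          }
        });
      }
      List<Long> elapsed = new ArrayList<Long>();
      // invokeAll blocks until every job has finished
      for (Future<Long> future : executor.invokeAll(jobs)) {
        elapsed.add(future.get());
      }
      return elapsed;
    } finally {
      executor.shutdown();               // always release the pool threads
    }
  }

  public static void main(String[] args) throws Exception {
    Runnable task = new Runnable() {
      public void run() {
        try {
          Thread.sleep(10L);
        } catch (InterruptedException e) {
          Thread.currentThread().interrupt();
        }
      }
    };
    System.out.println(timeConcurrently(task, 4));
  }
}
```

Unlike ConcurrentTest, this version has no way to stagger thread start-up, so it only corresponds to the zero-delay case.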

TimedTest

A TimedTest is passed an ITest and a maximum allowed time. The underlying ITest is allowed to run to completion, and if the aggregated observation exceeds the maximum allowed time, it is recorded as a failure.

// Source: src/main/java/com/mycompany/jab/TimedTest.java
package com.mycompany.jab;

import java.util.ArrayList;
import java.util.List;

/**
 * Models a test which has an upper time limit. If the test runs beyond
 * that period, it is counted as a failure.
 */
public class TimedTest implements ITest {

  private ITest test;
  private long maxElapsedMillis;
  private AggregationPolicy policy;
  private List<Double> observations = new ArrayList<Double>();
  
  public TimedTest(ITest test, long maxElapsedMillis) {
    this(test, maxElapsedMillis, AggregationPolicy.AVERAGE);
  }

  public TimedTest(ITest test, long maxElapsedMillis, 
      AggregationPolicy policy) {
    this.test = test;
    this.maxElapsedMillis = maxElapsedMillis;
    this.policy = policy;
  }
  
  public Double getAggregatedObservation() {
    Aggregator aggregator = new Aggregator(observations);
    return aggregator.aggregate(policy);
  }

  public List<Double> getObservations() {
    return observations;
  }

  public void runTest() throws Exception {
    test.runTest();
    // compare the wrapped test's aggregated time with the allowed maximum,
    // and record the result in this TimedTest's own observations
    double elapsed = test.getAggregatedObservation();
    if (elapsed > maxElapsedMillis) {
      observations.add(-1.0D);
    } else {
      observations.add(elapsed);
    }
  }
  
  @Override
  public Object clone() {
    return new TimedTest(this.test, this.maxElapsedMillis, this.policy);
  }
}

Calling patterns are similar to the RepeatedTest and ConcurrentTest decorators. Here are some examples:

    // Construct a timed test, setting the maximum allowed time to
    // 10ms, and count the number of failures.
    TimedTest test1 = new TimedTest(
      new SingleTest(MockTestable.class, resources), 10L,
      AggregationPolicy.FAILED);

    // Construct a timed test, setting the maximum allowed time to
    // 2000ms (2s).
    TimedTest test2 = new TimedTest(
      new SingleTest(MockTestable.class, resources), 2000L);

ThroughputTest

This test measures the throughput, i.e. the number of times the test ran within the maximum allowed time period. This is useful when you want to stress test a component for a given time period, say 10 minutes, and see how many times it ran.

// Source: src/main/java/com/mycompany/jab/ThroughputTest.java
package com.mycompany.jab;

import java.util.ArrayList;
import java.util.List;

/**
 * Given a test and a maximum time to run, returns the number of times
 * the test was run in the time provided.
 */
public class ThroughputTest implements ITest {

  private ITest test;
  private long maxElapsedMillis;
  private List<Double> observations = new ArrayList<Double>();
  private AggregationPolicy policy;

  public ThroughputTest(ITest test, long maxElapsedMillis) {
    this(test, maxElapsedMillis, AggregationPolicy.AVERAGE);
  }

  public ThroughputTest(ITest test, long maxElapsedMillis, 
      AggregationPolicy policy) {
    this.test = test;
    this.maxElapsedMillis = maxElapsedMillis;
    this.policy = policy;
  }

  public Double getAggregatedObservation() {
    Aggregator aggregator = new Aggregator(this.observations);
    return aggregator.aggregate(policy);
  }

  public List<Double> getObservations() {
    return observations;
  }

  public void runTest() throws Exception {
    long totalElapsed = 0L;
    for (;;) {
      long start = System.currentTimeMillis();
      this.test.runTest();
      long end = System.currentTimeMillis();
      long elapsed = end - start;
      observations.add((double) elapsed);
      totalElapsed += elapsed;
      if (totalElapsed > maxElapsedMillis) {
        break;
      }
    }
  }

  @Override
  public Object clone() {
    return new ThroughputTest(this.test, this.maxElapsedMillis, this.policy);
  }
}

And here is an example of how to call this. As you can see, you can nest decorators fairly deep, although it is left to you to determine what kind of nesting makes sense.

    // Declare a test that runs for 15s, which consists of 5 parallel
    // invocations of a set of 5 serial invocations of the SingleTest
    ThroughputTest test = new ThroughputTest(
      new ConcurrentTest(new RepeatedTest(
      new SingleTest(MockTestable.class, resources), 5),
      5), 15000L);
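
The core of a throughput measurement is just a loop with a time budget. The sketch below shows that loop on its own, outside the jab classes, counting how many times a Runnable completes before the budget runs out. The names ThroughputSketch and countRuns are assumptions for this example. With the ThroughputTest class above, the run count itself would presumably be read by passing AggregationPolicy.COUNT, since the default AVERAGE returns the mean elapsed time per run.

```java
/** Sketch only: count how many runs of a task fit into a time budget. */
class ThroughputSketch {

  /**
   * Runs the task repeatedly until budgetMillis has elapsed and returns
   * the number of completed runs. The task always runs at least once, and
   * a run that is in progress when the budget expires is allowed to finish.
   */
  static int countRuns(Runnable task, long budgetMillis) {
    long budgetNanos = budgetMillis * 1000000L;
    long start = System.nanoTime();
    int count = 0;
    do {
      task.run();
      count++;
    } while (System.nanoTime() - start < budgetNanos);
    return count;
  }

  public static void main(String[] args) {
    Runnable task = new Runnable() {
      public void run() {
        try {
          Thread.sleep(5L);              // stands in for real work
        } catch (InterruptedException e) {
          Thread.currentThread().interrupt();
        }
      }
    };
    System.out.println(countRuns(task, 50L) + " runs in about 50ms");
  }
}
```

The sketch measures against the wall clock from the start of the loop, whereas ThroughputTest sums the individual elapsed times; the two differ slightly when there is overhead between runs.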

A real-life example

I tested the code above with a MockTestable that slept for 10s to simulate some kind of load. But the whole reason I built this was so I could do this sort of thing on real-life components. Here is a JUnit test that runs searches against 2 components and compares their performance under load. The searchers are modeled as ITestable implementations.

// Source: src/test/java/com/mycompany/jab/example/MySQLSearchTestable.java
package com.mycompany.jab.example;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Queue;

import javax.sql.DataSource;

import com.mycompany.jab.ITestable;

public class MySQLSearchTestable implements ITestable {

  public void execute(Map<String,Object> resources) throws Exception {
    // get references to various resources
    Queue<String> mysqlQueue = 
      (Queue<String>) resources.get("mysqlQueue");
    DataSource dataSource = (DataSource) resources.get("dataSource");
    String imuidQuery = (String) resources.get("sqlQuery");
    Integer preparedStmtFetchSize = 
      (Integer) resources.get("preparedStmtFetchSize");
    String randomImuid = mysqlQueue.poll();
    // do the work
    List<Result> results = new ArrayList<Result>();
    Connection conn = dataSource.getConnection();
    PreparedStatement ps = conn.prepareStatement(imuidQuery);
    ps.setFetchSize(preparedStmtFetchSize);
    ps.setString(1, randomImuid);
    ResultSet rs = null;
    try {
      rs = ps.executeQuery();
      while (rs.next()) {
        // populate a Result object
        Result result = new Result();
        // result.field = rs.getString(n) type calls 
        // deliberately removed
        ...
        results.add(result);
      }
    } finally {
      if (rs != null) {
        try { rs.close(); } catch (Exception e) {;}
      }
      if (ps != null) {
        try { ps.close(); } catch (Exception e) {;}
      }
      if (conn != null) {
        try { conn.close(); } catch (Exception e) {;}
      }
    }
  }
}
// Source: src/test/java/com/mycompany/jab/example/LuceneSearchTestable.java
package com.mycompany.jab.example;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Queue;

import org.apache.lucene.document.Document;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.Hits;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TermQuery;

import com.mycompany.jab.ITestable;

public class LuceneSearchTestable implements ITestable {

  public void execute(Map<String,Object> resources) throws Exception {
    // get references to various resources
    Queue<String> luceneQueue = 
      (Queue<String>) resources.get("luceneQueue");
    IndexSearcher searcher = (IndexSearcher) resources.get("searcher");
    // start the test
    List<Result> results = new ArrayList<Result>();
    String id = luceneQueue.poll();
    Hits hits = searcher.search(new TermQuery(new Term("myId", id)));
    int numHits = hits.length();
    for (int i = 0; i < numHits; i++) {
      Result result = new Result();
      // result.field = doc.get("fieldName") type calls 
      // deliberately removed
      ...
      results.add(result);
    }
  }
}

As you can see, these two testables are simple code that runs an SQL query against a database table and a TermQuery against a Lucene index, respectively. All the expensive resources (and some inexpensive ones) are passed to the ITestable via the resources map. The resources are created in the calling JUnit test, which also uses the jab mini-framework to build a pair of progressively larger ConcurrentTests (one per ITestable) by varying the number of users. Each ConcurrentTest runs one RepeatedTest per simulated user, and each RepeatedTest invokes one of the two ITestables shown above 10 times. The observations from each run are aggregated and written out to a flat file in tab-delimited format. Here is the code for the JUnit test.

// Source: src/test/java/com/mycompany/jab/example/JabExampleTest.java
package com.mycompany.jab.example;

import java.io.FileOutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Queue;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ConcurrentLinkedQueue;

import javax.sql.DataSource;

import org.apache.commons.dbcp.ConnectionFactory;
import org.apache.commons.dbcp.DriverManagerConnectionFactory;
import org.apache.commons.dbcp.PoolableConnectionFactory;
import org.apache.commons.dbcp.PoolingDataSource;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.commons.pool.ObjectPool;
import org.apache.commons.pool.impl.GenericObjectPool;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.IndexSearcher;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;

import com.mycompany.jab.AggregationPolicy;
import com.mycompany.jab.Aggregator;
import com.mycompany.jab.ConcurrentTest;
import com.mycompany.jab.RepeatedTest;
import com.mycompany.jab.SingleTest;

/**
 * Harness to compare Lucene and MySQL cp index performance.
 */
public class JabExampleTest {

  // ======== Configuration Parameters =========
  
  private static final String INDEX_PATH = "/path/to/index";
  private static final String DATA_FILE = "/tmp/output.dat";
  private static final String DB_URL = "jdbc:mysql://localhost:3306/test";
  private static final String DB_USER = "root";
  private static final String DB_PASS = "secret";
  private static final int DB_POOL_INITIAL_SIZE = 1;
  private static final int DB_POOL_MAX_ACTIVE = 50;
  private static final int DB_POOL_MAX_WAIT = 5000;

  private static final int NUM_SEARCHES_PER_USER = 10;
  private static final int[] NUM_CONCURRENT_USERS = 
    new int[] {5,10,15,20,25,30,35,40,45,50,55,60,65,70,75,80,85,90,95,100};
  
  // ========= global vars for internal use ============
  
  private final Log log = LogFactory.getLog(getClass());
  
  private static final DecimalFormat DF = new DecimalFormat("####");
  
  private static IndexSearcher searcher;
  private static DataSource dataSource;
  private static List<String> uniqueIds;
  private static Random randomizer;
  private static PrintWriter outputWriter;

  private static final String MYSQL_QUERY = "select * from foo where ...";
  private static final int PREPARED_STATEMENT_FETCH_SIZE = 200;

  @BeforeClass
  public static void setUpBeforeTest() throws Exception {
    uniqueIds = getUniqueIds(INDEX_PATH);
    randomizer = new Random();
    searcher = new IndexSearcher(INDEX_PATH);
    dataSource = getPoolingDataSource();
    outputWriter = new PrintWriter(new OutputStreamWriter(
      new FileOutputStream(DATA_FILE)));
  }

  @AfterClass
  public static void tearDownAfterClass() throws Exception {
    searcher.close();
    outputWriter.flush();
    outputWriter.close();
  }

  @Test
  public void testCompareSearches() throws Exception {
    // set up reporting
    outputWriter.println(StringUtils.join(new String[] {
      "NUM-USERS",
      "LUCENE-AVG",
      "LUCENE-MAX",
      "LUCENE-MIN",
      "LUCENE-FAIL",
      "MYSQL-AVG",
      "MYSQL-MAX",
      "MYSQL-MIN",
      "MYSQL-FAIL"
    }, "\t"));
    // set up resources
    Map<String,Object> resources = new HashMap<String,Object>();
    resources.put("searcher", searcher);
    resources.put("dataSource", dataSource);
    resources.put("sqlQuery", MYSQL_QUERY);
    resources.put("preparedStmtFetchSize", PREPARED_STATEMENT_FETCH_SIZE);
    for (int numConcurrent : NUM_CONCURRENT_USERS) {
      // compute the random ids
      List<String> randomIds = 
        getRandomIds(numConcurrent * NUM_SEARCHES_PER_USER);
      Queue<String> luceneQueue = 
        new ConcurrentLinkedQueue<String>();
      luceneQueue.addAll(randomIds);
      Queue<String> mysqlQueue = 
        new ConcurrentLinkedQueue<String>();
      mysqlQueue.addAll(randomIds);
      resources.put("luceneQueue", luceneQueue);
      resources.put("mysqlQueue", mysqlQueue);
      // set up the tests
      log.debug("Running test with " + numConcurrent + " users...");
      ConcurrentTest luceneTest = new ConcurrentTest(
        new RepeatedTest(new SingleTest(
        LuceneSearchTestable.class, resources), 
        NUM_SEARCHES_PER_USER), numConcurrent);
      ConcurrentTest mysqlTest = new ConcurrentTest(
        new RepeatedTest(new SingleTest(
        MySQLSearchTestable.class, resources),
        NUM_SEARCHES_PER_USER), numConcurrent);
      // run them
      luceneTest.runTest();
      mysqlTest.runTest();
      // collect information and output to report
      Aggregator luceneAggregator = 
        new Aggregator(luceneTest.getObservations());
      Aggregator mysqlAggregator = 
        new Aggregator(mysqlTest.getObservations());
      outputWriter.println(StringUtils.join(new String[] {
        String.valueOf(numConcurrent),
        DF.format(luceneAggregator.aggregate(AggregationPolicy.AVERAGE)),
        DF.format(luceneAggregator.aggregate(AggregationPolicy.MAX)),
        DF.format(luceneAggregator.aggregate(AggregationPolicy.MIN)),
        DF.format(luceneAggregator.aggregate(AggregationPolicy.FAILED)),
        DF.format(mysqlAggregator.aggregate(AggregationPolicy.AVERAGE)),
        DF.format(mysqlAggregator.aggregate(AggregationPolicy.MAX)),
        DF.format(mysqlAggregator.aggregate(AggregationPolicy.MIN)),
        DF.format(mysqlAggregator.aggregate(AggregationPolicy.FAILED))
      }, "\t"));
    }
  }
  
  // ========= Methods to build and populate resources as applicable ========
  
  private static DataSource getPoolingDataSource() throws Exception {
    ObjectPool connectionPool = new GenericObjectPool(null);
    Properties connProps = new Properties();
    connProps.put("user", DB_USER);
    connProps.put("password", DB_PASS);
    connProps.put("initialSize", String.valueOf(DB_POOL_INITIAL_SIZE));
    connProps.put("maxActive", String.valueOf(DB_POOL_MAX_ACTIVE));
    connProps.put("maxWait", String.valueOf(DB_POOL_MAX_WAIT));
    Class.forName("com.mysql.jdbc.Driver");
    ConnectionFactory connectionFactory = 
      new DriverManagerConnectionFactory(DB_URL, connProps);
    PoolableConnectionFactory pcf = new PoolableConnectionFactory(
      connectionFactory, connectionPool, null, null, false, false);
    return new PoolingDataSource(connectionPool);
  }

  private static List<String> getUniqueIds(String cpIndexPath) 
      throws Exception {
    Set<String> uniqueImuidSet = new HashSet<String>();
    IndexReader reader = IndexReader.open(cpIndexPath);
    int numDocs = reader.maxDoc();
    for (int i = 0; i < numDocs; i++) {
      Document doc = reader.document(i);
      uniqueImuidSet.add(doc.get("myId"));
    }
    List<String> idlist = new ArrayList<String>();
    idlist.addAll(uniqueImuidSet);
    reader.close();
    return idlist;
  }

  private List<String> getRandomIds(int numRandom) {
    List<String> randomImuids = new ArrayList<String>();
    for (int i = 0; i < numRandom; i++) {
      int random = randomizer.nextInt(uniqueIds.size());
      randomImuids.add(uniqueIds.get(random));
    }
    return randomImuids;
  }
}
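
The Aggregator and AggregationPolicy classes used in the test belong to the jab mini-framework and are not listed in this section. As a rough illustration of what the AVERAGE, MAX and MIN policies boil down to, here is a minimal sketch. The class name ObservationStats, its Policy enum and the use of a plain list of elapsed milliseconds are all assumptions made for this example, not the framework's actual API, and the FAILED policy is left out because the shape of a failed observation is not shown here.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical stand-in for the jab Aggregator; names are assumptions.
class ObservationStats {

  enum Policy { AVERAGE, MAX, MIN }

  // Reduce a list of elapsed times (in milliseconds) to a single number.
  static double aggregate(List<Long> elapsedMillis, Policy policy) {
    if (elapsedMillis.isEmpty()) {
      throw new IllegalArgumentException("no observations to aggregate");
    }
    double sum = 0.0;
    long max = Long.MIN_VALUE;
    long min = Long.MAX_VALUE;
    for (long t : elapsedMillis) {
      sum += t;
      max = Math.max(max, t);
      min = Math.min(min, t);
    }
    switch (policy) {
      case AVERAGE: return sum / elapsedMillis.size();
      case MAX:     return max;
      case MIN:     return min;
      default:      throw new IllegalArgumentException("unknown policy " + policy);
    }
  }

  public static void main(String[] args) {
    // four made-up observations, purely for demonstration
    List<Long> observations = Arrays.asList(12L, 7L, 30L, 11L);
    System.out.println("avg=" + aggregate(observations, Policy.AVERAGE));
    System.out.println("max=" + aggregate(observations, Policy.MAX));
    System.out.println("min=" + aggregate(observations, Policy.MIN));
  }
}
```

Each row of the report would then be one call per policy for the Lucene observations and one per policy for the MySQL observations.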

Test Results

Although the results of this exercise are not relevant for this post (since I am just describing the framework and how to use it), I thought they would be interesting, so I am including them here.

NUM-USERS  LUCENE-AVG  LUCENE-MAX  LUCENE-MIN  LUCENE-FAIL  MYSQL-AVG  MYSQL-MAX  MYSQL-MIN  MYSQL-FAIL
5          39          423         6           0            49         544        3          0
10         11          19          4           0            14         23         9          0
15         12          19          3           0            14         21         6          0
20         15          34          1           0            20         28         6          0
25         19          35          6           0            23         38         12         0
30         25          44          4           0            26         37         9          0
35         25          44          5           0            27         41         8          0
40         45          81          16          0            44         62         5          0
45         43          69          9           0            45         64         10         0
50         33          73          0           0            42         66         11         0
55         28          67          2           0            33         59         6          0
60         37          72          3           0            41         69         2          0
65         43          84          3           0            50         81         4          0
70         33          63          7           0            38         70         3          0
75         53          95          17          0            53         84         8          0
80         105         232         0           0            64         113        5          0
85         52          102         6           0            62         100        4          0
90         71          146         4           0            71         110        13         0
95         56          99          12          0            65         106        2          0
100        64          128         11          0            73         121        6          0

To visualize the data, I used the following gnuplot script to transform the average time observations into a graph.

set multiplot
set key off
set xlabel '#-users'
set ylabel 'response(ms)'
set yrange [0:150]
plot 'perfcomp.dat' using 1:2 with lines lt 1
plot 'perfcomp.dat' using 1:6 with lines lt 2

The graph is shown below. Not too many surprises here: quite a few people have reached the same conclusion, namely that it is as performant, and often more convenient, to serve the results of exact-match queries from a MySQL database as from a Lucene index.

Conclusion

Prior to this, I would either wrap a component in a web interface and use ab or siege, or write JUnit tests that did the multithreading inline with the code being tested. I think this approach is cleaner and perhaps more scalable, since it separates the component being tested from the actual test parameters, allowing you to model more complex scenarios.

I am curious as to what other people do in similar situations. If you have had similar needs, I would appreciate knowing how you approached it. I am also curious if other people think this is complete enough to release as a project - I don't really want the headache of maintaining and improving the project, I just figure that it may be useful to have it somewhere where people can download it, use and maybe improve it and check the fixes/features back in.

Also, I don't normally write multi-threaded code, simply because it's not needed that often for the stuff I work on, so there may be obvious bugs that a reader who does multi-threaded stuff for a living (and some who do not) may spot immediately. If so, please let me know and I will make the necessary corrections.

Friday, November 14, 2008

IR Math in Java : Rule based POS Tagger

In my previous post, I described an HMM based Part of Speech tagger. This post describes a rule based POS tagger loosely based on ideas underlying the design of the Brill Tagger, as described in Chapter 4 of the TMAP book.

The rule based tagger provides a single method, tagSentence(), which takes a sentence as input and returns the same sentence tagged with the appropriate parts of speech. I was too lazy to add a convenience method to return the POS for a given word in a sentence, since it's really simple, and I have already done this in the code in my previous post.

Wordnet is used to find the part of speech for each word in the sentence. We use MIT Java Wordnet Interface (JWI) to access the Wordnet database. Wordnet can only recognize the following four parts of speech - NOUN, VERB, ADJECTIVE and ADVERB. Therefore, our POS tagging is restricted to these four and OTHER. We enhance our Pos enum class from our previous post with methods to convert Wordnet POS to and from our Pos enum, as shown below:

// Source: src/main/java/com/mycompany/myapp/postaggers/Pos.java
package com.mycompany.myapp.postaggers;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.util.HashMap;
import java.util.Map;

import org.apache.commons.collections15.BidiMap;
import org.apache.commons.collections15.bidimap.DualHashBidiMap;
import org.apache.commons.lang.StringUtils;

import edu.mit.jwi.item.POS;

public enum Pos {

  NOUN, VERB, ADJECTIVE, ADVERB, OTHER;

  private static Map<String,Pos> bmap = null;
  private static BidiMap<Pos,POS> wmap = null;
  private static final String translationFile = 
    "src/main/resources/brown_tags.txt";
  
  public static Pos fromBrownTag(String btag) throws Exception {
    // .. omitted for brevity, see previous post for body
  }

  public static Pos fromWordnetPOS(POS pos) {
    if (wmap == null) {
      wmap = buildPosBidiMap();
    }
    return wmap.getKey(pos);
  }
  
  public static POS toWordnetPos(Pos pos) {
    if (wmap == null) {
      wmap = buildPosBidiMap();
    }
    return wmap.get(pos);
  }

  private static BidiMap<Pos,POS> buildPosBidiMap() {
    wmap = new DualHashBidiMap<Pos,POS>();
    wmap.put(Pos.NOUN, POS.NOUN);
    wmap.put(Pos.VERB, POS.VERB);
    wmap.put(Pos.ADJECTIVE, POS.ADJECTIVE);
    wmap.put(Pos.ADVERB, POS.ADVERB);
    wmap.put(Pos.OTHER, null);
    return wmap;
  }
}

When Wordnet is asked what the POS of a particular word is, it can return one of the following three results, each of which is handled differently as described below:

  1. No POS found for word
  2. Single Unique POS found for word
  3. Multiple POS found for word

No POS found for Word

In this case, Wordnet may not know about the word being checked, or it could be a proper noun. We use a combination of word pattern rules to try and guess the POS for this word. If none of the patterns match, then the POS is considered to be OTHER.

First, we check whether the first letter is uppercase. If it is, we assume that the word is a proper noun, and therefore tag it with Pos.NOUN.

If not, we check if the word ends with one of the known suffixes that exist in our suffix to POS mappings, longest suffix first. The POS corresponding to the first matched suffix is used to tag the word.

If not, the word is tagged as Pos.OTHER.
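
The three fallback steps above can be sketched in a few lines of standalone Java. This is only an illustration under some assumptions: the class name UnknownWordRules is made up, the Pos enum is redeclared locally so the example compiles on its own, and the "ing" and "thing" suffix entries are hypothetical additions chosen to show why the longest suffix has to be tried first (the other entries come from the suffix file listed later in this post).

```java
import java.util.Comparator;
import java.util.Locale;
import java.util.Map;
import java.util.TreeMap;

class UnknownWordRules {

  enum Pos { NOUN, VERB, ADJECTIVE, ADVERB, OTHER }

  // Longest suffix first; equal lengths fall back to alphabetical order
  // so that two different suffixes never compare as equal.
  static final Comparator<String> LONGEST_FIRST = (a, b) -> {
    if (a.length() != b.length()) {
      return Integer.compare(b.length(), a.length());
    }
    return a.compareTo(b);
  };

  // Guess a POS for a word that the dictionary does not know.
  static Pos guess(String word, Map<String, Pos> suffixToPos) {
    if (word.isEmpty()) {
      return Pos.OTHER;
    }
    if (Character.isUpperCase(word.charAt(0))) {
      return Pos.NOUN;                       // rule 1: assume proper noun
    }
    String lower = word.toLowerCase(Locale.ROOT);
    for (Map.Entry<String, Pos> entry : suffixToPos.entrySet()) {
      if (lower.endsWith(entry.getKey())) {
        return entry.getValue();             // rule 2: first (longest) suffix wins
      }
    }
    return Pos.OTHER;                        // rule 3: give up
  }

  static Map<String, Pos> sampleSuffixes() {
    Map<String, Pos> suffixes = new TreeMap<>(LONGEST_FIRST);
    suffixes.put("er", Pos.NOUN);
    suffixes.put("ment", Pos.NOUN);
    suffixes.put("ness", Pos.NOUN);
    suffixes.put("ing", Pos.VERB);    // hypothetical entry
    suffixes.put("thing", Pos.NOUN);  // hypothetical entry
    return suffixes;
  }

  public static void main(String[] args) {
    Map<String, Pos> suffixes = sampleSuffixes();
    for (String word : new String[] {"Zork", "zorking", "zorkthing", "xyzzy"}) {
      System.out.println(word + "/" + guess(word, suffixes));
    }
  }
}
```

Because the map iterates longest suffix first, a word ending in "thing" is matched by the "thing" entry before the shorter "ing" entry gets a chance.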

Single Unique POS found for Word

In this case, there is no confusion -- Wordnet tells us that there is a single POS found for the word, so we tag the word with this POS, and continue on with our life...err, I mean, the next word.

Multiple POS found for Word

When Wordnet reports multiple POS possibilities for a word, it means that the word can be used as different POS depending on where it is used in the sentence - in other words, the context determines the POS.

The tagger considers the context as the word trigram surrounding the word (i.e. before/current/after). Two rules are fired, the word-backward rule and the word-forward rule, unless the word happens to occur at the beginning or end of the sentence, in which case only one rule is fired. The objective of these rules is to find the most likely POS for the word based on the POS of its anterior and posterior neighbors.

The probabilities used to compute the likelihood come from the transition probabilities (A values) that were computed from the Brown Corpus in my previous post.

Each rule looks up the probability of each POS pair (before/current and current/after) occurring and associates it with the corresponding candidate POS of the word. The probabilities from the two rules are then added, and the candidate POS with the highest total is used to tag the word.
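
To make the arithmetic concrete, here is a small standalone sketch of the two neighbor rules, using the transition probabilities from the pos_trans_prob.txt file listed later in this post. The class name NeighborRuleSketch and the locally redeclared Pos enum are assumptions for this example, as is the premise that Wordnet reports "state" as both NOUN and VERB, "sad" as ADJECTIVE and "not" as ADVERB. A neighbor with no known POS is passed as an empty list.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

class NeighborRuleSketch {

  enum Pos { NOUN, VERB, ADJECTIVE, ADVERB, OTHER }

  // Rows: POS of the earlier word; columns: POS of the later word.
  // Values copied from the pos_trans_prob.txt listing in this post.
  static final double[][] TP = {
    {0.155, 0.156, 0.019, 0.025, 0.645},
    {0.095, 0.195, 0.168, 0.094, 0.449},
    {0.639, 0.024, 0.148, 0.005, 0.183},
    {0.052, 0.228, 0.111, 0.041, 0.569},
    {0.206, 0.199, 0.205, 0.039, 0.351}
  };

  // Score every candidate POS by summing the backward and forward
  // transition probabilities, and return the best scoring candidate.
  // Returns OTHER if neither neighbor contributes any probability.
  static Pos disambiguate(List<Pos> candidates, List<Pos> prevPos,
      List<Pos> nextPos) {
    Pos best = Pos.OTHER;
    double bestScore = 0.0;
    for (Pos candidate : candidates) {
      double score = 0.0;
      for (Pos prev : prevPos) {   // backward rule: prev -> candidate
        score += TP[prev.ordinal()][candidate.ordinal()];
      }
      for (Pos next : nextPos) {   // forward rule: candidate -> next
        score += TP[candidate.ordinal()][next.ordinal()];
      }
      if (score > bestScore) {
        best = candidate;
        bestScore = score;
      }
    }
    return best;
  }

  public static void main(String[] args) {
    List<Pos> state = Arrays.asList(Pos.NOUN, Pos.VERB);
    // "a sad state": previous word assumed ADJECTIVE, no next word
    System.out.println(disambiguate(state,
      Arrays.asList(Pos.ADJECTIVE), Collections.<Pos>emptyList()));
    // "not state the": previous word assumed ADVERB, next word unknown
    System.out.println(disambiguate(state,
      Arrays.asList(Pos.ADVERB), Collections.<Pos>emptyList()));
  }
}
```

In the first call the backward rule compares ADJECTIVE to NOUN (0.639) against ADJECTIVE to VERB (0.024), so NOUN wins; in the second it compares ADVERB to NOUN (0.052) against ADVERB to VERB (0.228), so VERB wins.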

In addition, we could have used the initial state probabilities (Π) from our last post to similarly resolve ambiguous POS for the first word in the sentence, but we did not do this.

Tagger code

The code for the tagger is shown below:

// Source: src/main/java/com/mycompany/myapp/postaggers/RuleBasedTagger.java
package com.mycompany.myapp.postaggers;

import java.io.BufferedReader;
import java.io.FileReader;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import com.mycompany.myapp.clustering.ByValueComparator;
import com.mycompany.myapp.tokenizers.Token;
import com.mycompany.myapp.tokenizers.TokenType;
import com.mycompany.myapp.tokenizers.WordTokenizer;

import edu.mit.jwi.Dictionary;
import edu.mit.jwi.IDictionary;
import edu.mit.jwi.item.IIndexWord;

public class RuleBasedTagger {

  private final Log log = LogFactory.getLog(getClass());
  
  private class Context {
    public String prev;
    public List<Pos> prevPos;
    public String curr;
    public List<Pos> nextPos;
    public String next;
    public String toString() {
      return StringUtils.join(new String[] {prev,curr,next}, "/");
    }
  };
  
  private IDictionary wordnetDictionary;
  private Map<String,Pos> suffixPosMap;
  private double[][] tp;
  
  public void setWordnetDictLocation(String wordnetDictLocation) 
      throws Exception {
    this.wordnetDictionary = new Dictionary(
      new URL("file", null, wordnetDictLocation));
    this.wordnetDictionary.open();
  }

  public void setSuffixMappingLocation(String suffixMappingLocation) 
      throws Exception {
    String line;
    this.suffixPosMap = new TreeMap<String,Pos>(
      new Comparator<String>() {
        public int compare(String s1, String s2) {
          int l1 = s1.length();
          int l2 = s2.length();
          if (l1 == l2) {
            return s1.compareTo(s2);
          } else {
            return (l2 > l1 ? 1 : -1);
          }
        }
      }
    );
    BufferedReader reader = new BufferedReader(
      new FileReader(suffixMappingLocation));
    while ((line = reader.readLine()) != null) {
      if (StringUtils.isEmpty(line) || line.startsWith("#")) {
        continue;
      }
      String[] suffixPosPair = StringUtils.split(line, "\t");
      suffixPosMap.put(suffixPosPair[0], Pos.valueOf(suffixPosPair[1]));
    }
    reader.close();
  }

  public void setTransitionProbabilityDatafile(
      String transitionProbabilityDatafile) throws Exception {
    int numPos = Pos.values().length;
    tp = new double[numPos][numPos];
    BufferedReader reader = new BufferedReader(
      new FileReader(transitionProbabilityDatafile));
    int i = 0; // row
    String line;
    while ((line = reader.readLine()) != null) {
      if (StringUtils.isEmpty(line) || line.startsWith("#")) {
        continue;
      }
      String[] parts = StringUtils.split(line, "\t");
      for (int j = 0; j < parts.length; j++) {
        tp[i][j] = Double.valueOf(parts[j]);
      }
      i++;
    }
    reader.close();
  }

  public String tagSentence(String sentence) throws Exception {
    StringBuilder taggedSentenceBuilder = new StringBuilder();
    WordTokenizer wordTokenizer = new WordTokenizer();
    wordTokenizer.setText(sentence);
    List<Token> tokens = new ArrayList<Token>();
    Token token = null;
    while ((token = wordTokenizer.nextToken()) != null) {
      tokens.add(token);
    }
    // extract the words from the tokens
    List<String> words = new ArrayList<String>();
    for (Token tok : tokens) {
      if (tok.getType() == TokenType.WORD) {
        words.add(tok.getValue());
      }
    }
    // for each word, find the pos
    int position = 0;
    for (String word : words) {
      Pos partOfSpeech = getPartOfSpeech(words, word, position);
      taggedSentenceBuilder.append(word).
        append("/").
        append(partOfSpeech.name()).
        append(" ");
      position++;
    }
    return taggedSentenceBuilder.toString();
  }

  private Pos getPartOfSpeech(List<String> wordList, String word, 
      int position) {
    List<Pos> partsOfSpeech = getPosFromWordnet(word);
    int numPos = partsOfSpeech.size();
    if (numPos == 0) {
      // unknown Pos, apply word rules to figure out Pos
      if (startsWithUppercase(word)) {
        return Pos.NOUN;
      }
      Pos pos = getPosBasedOnSuffixRules(word);
      if (pos != null) {
        return pos;
      } else {
        return Pos.OTHER;
      }
    } else if (numPos == 1) {
      // unique Pos, return
      return partsOfSpeech.get(0);
    } else {
      // ambiguous Pos, apply disambiguation rules
      Context context = getContext(wordList, position);
      Map<Pos,Double> posProbs = new HashMap<Pos,Double>();
      if (context.prev != null) {
        // backward neighbor rule
        accumulatePosProbabilities(posProbs, word, partsOfSpeech, 
          context.prev, context.prevPos, false);
      }
      if (context.next != null) {
        // forward neighbor rule
        accumulatePosProbabilities(posProbs, word, partsOfSpeech, 
          context.next, context.nextPos, true);
      }
      if (posProbs.size() == 0) {
        return Pos.OTHER;
      } else {
        ByValueComparator<Pos,Double> bvc = 
          new ByValueComparator<Pos,Double>(posProbs);
        List<Pos> poslist = new ArrayList<Pos>();
        poslist.addAll(posProbs.keySet());
        Collections.sort(poslist, Collections.reverseOrder(bvc));
        return poslist.get(0);
      }
    }
  }

  private List<Pos> getPosFromWordnet(String word) {
    List<Pos> poslist = new ArrayList<Pos>();
    for (Pos pos : Pos.values()) {
      try {
        IIndexWord indexWord = 
          wordnetDictionary.getIndexWord(word, Pos.toWordnetPos(pos));
        if (indexWord != null) {
          poslist.add(pos);
        }
      } catch (NullPointerException e) {
        // JWI throws this if it cannot find the word in its dictionary
        // so we just dont add anything to the poslist.
        continue;
      }
    }
    return poslist;
  }

  private boolean startsWithUppercase(String word) {
    // true if the first character of the word is an uppercase letter
    return word.length() > 0 && Character.isUpperCase(word.charAt(0));
  }

  private Pos getPosBasedOnSuffixRules(String word) {
    for (String suffix : suffixPosMap.keySet()) {
      if (StringUtils.lowerCase(word).endsWith(suffix)) {
        return suffixPosMap.get(suffix);
      }
    }
    return null;
  }

  private Context getContext(List<String> words, int wordPosition) {
    Context context = new Context();
    if ((wordPosition - 1) >= 0) {
      context.prev = words.get(wordPosition - 1);
      context.prevPos = getPosFromWordnet(context.prev);
    }
    context.curr = words.get(wordPosition);
    if ((wordPosition + 1) < words.size()) {
      context.next = words.get(wordPosition + 1);
      context.nextPos = getPosFromWordnet(context.next);
    }
    return context;
  }
  
  private void accumulatePosProbabilities(
      Map<Pos,Double> posProbabilities,
      String word, List<Pos> wordPosList, String neighbor, 
      List<Pos> neighborPosList, boolean isForwardRule) {
    if (isForwardRule) {
      for (Pos wordPos : wordPosList) {
        for (Pos neighborPos : neighborPosList) {
          double prob = 
            tp[wordPos.ordinal()][neighborPos.ordinal()];
          updatePosProbabilities(posProbabilities, wordPos, prob);
        }
      }
    } else {
      for (Pos neighborPos : neighborPosList) {
        for (Pos wordPos : wordPosList) {
          double prob = 
            tp[neighborPos.ordinal()][wordPos.ordinal()];
          updatePosProbabilities(posProbabilities, wordPos, prob);
        }
      }
    }
  }

  private void updatePosProbabilities(
      Map<Pos,Double> posProbabilities,
      Pos wordPos, double prob) {
    Double origProb = posProbabilities.get(wordPos);
    if (origProb == null) {
      posProbabilities.put(wordPos, prob);
    } else {
      posProbabilities.put(wordPos, prob + origProb);
    }
  }
}

Test Code and Data Files

The test code for this is really simple. All we do is to instantiate the RuleBasedTagger and set into it the location of the Wordnet dictionary, the suffix mapping data file and the transition probabilities (A values) from the HMM data file that we built in our previous post. The files are described in more detail below. Once instantiated and set up, we feed it a set of sentences, and get back POS-tagged sentences. Here is the code for the JUnit test.

// Source: src/test/java/com/mycompany/myapp/postaggers/RuleBasedTaggerTest.java
package com.mycompany.myapp.postaggers;

import org.junit.Test;

public class RuleBasedTaggerTest {

  private String[] INPUT_TEXTS = {
    "The growing popularity of Linux in Asia, Europe, and the US is " +
    "a major concern for Microsoft.",
    "Jaguar will sell its new XJ-6 model in the US for a small fortune.",
    "The union is in a sad state.",
    "Please do not state the obvious.",
    "I am looking forward to the state of the union address.",
    "I have a bad cold today.",
    "The cold war was over long ago."
  };

  @Test
  public void testTagSentence() throws Exception {
    for (String sentence : INPUT_TEXTS) {
      RuleBasedTagger tagger = new RuleBasedTagger();
      tagger.setWordnetDictLocation("/opt/wordnet-3.0/dict");
      tagger.setSuffixMappingLocation("src/main/resources/pos_suffixes.txt");
      tagger.setTransitionProbabilityDatafile(
        "src/main/resources/pos_trans_prob.txt");
      String taggedSentence = tagger.tagSentence(sentence);
      System.out.println("Original: " + sentence);
      System.out.println("Tagged:   " + taggedSentence);
    }
  }
}

I am using Wordnet-3.0 data files which can be downloaded from here.

The suffix to POS mapping file was created manually from data available here and here. The suffix and POS are tab separated. A partial listing follows:

# Source: src/main/resources/pos_suffixes.txt
# POS Suffixes
#SUFFIX POS
dom     NOUN
ity     NOUN
ment    NOUN
sion    NOUN
tion    NOUN
ness    NOUN
ance    NOUN
ence    NOUN
er      NOUN
...

The transition probabilities (A values) are taken from the HMM text file, which was generated from the Brown Corpus as described in my previous post. The file is a tab separated file of observed probabilities for transitioning from one POS to another. The numbers should add up to 1 across each line. Here is what the file looks like:

# Source: src/main/resources/pos_trans_prob.txt
# NOUN  VERB    ADJECTIVE  ADVERB  OTHER
0.155   0.156   0.019      0.025   0.645
0.095   0.195   0.168      0.094   0.449
0.639   0.024   0.148      0.005   0.183
0.052   0.228   0.111      0.041   0.569
0.206   0.199   0.205      0.039   0.351
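
Since the tagger trusts this file blindly, a quick sanity check on it can be worthwhile. The sketch below is not part of the tagger; the class name TransitionMatrixCheck and the tolerance values are assumptions for illustration. It parses the rows (splitting on any whitespace rather than only tabs) and verifies that each row sums to 1 within a tolerance. With the values listed above the row sums range from 0.999 to 1.001 because of rounding, so an exact comparison would fail.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class TransitionMatrixCheck {

  // Parse non-empty, non-comment lines into rows of doubles.
  static double[][] parse(List<String> lines) {
    List<double[]> rows = new ArrayList<>();
    for (String line : lines) {
      String trimmed = line.trim();
      if (trimmed.isEmpty() || trimmed.startsWith("#")) {
        continue;
      }
      String[] parts = trimmed.split("\\s+");
      double[] row = new double[parts.length];
      for (int j = 0; j < parts.length; j++) {
        row[j] = Double.parseDouble(parts[j]);
      }
      rows.add(row);
    }
    return rows.toArray(new double[0][]);
  }

  // True if every row sums to 1.0 within the given tolerance.
  static boolean rowsSumToOne(double[][] matrix, double tolerance) {
    for (double[] row : matrix) {
      double sum = 0.0;
      for (double value : row) {
        sum += value;
      }
      if (Math.abs(sum - 1.0) > tolerance) {
        return false;
      }
    }
    return true;
  }

  static List<String> sampleLines() {
    return Arrays.asList(
      "# NOUN  VERB    ADJECTIVE  ADVERB  OTHER",
      "0.155   0.156   0.019      0.025   0.645",
      "0.095   0.195   0.168      0.094   0.449",
      "0.639   0.024   0.148      0.005   0.183",
      "0.052   0.228   0.111      0.041   0.569",
      "0.206   0.199   0.205      0.039   0.351");
  }

  public static void main(String[] args) {
    double[][] tp = parse(sampleLines());
    System.out.println("rows: " + tp.length);
    System.out.println("sums to 1 (+/- 0.005): " + rowsSumToOne(tp, 0.005));
  }
}
```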

Results

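The results of running the test are shown below (edited for readability by adding newlines to break up the original and tagged sentences so words don't break across lines). As you can see, the tagging is fairly accurate for content words such as nouns, verbs and adjectives, although several function words (for example "a", "I" and "for") end up with the wrong tag.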

Original: The growing popularity of Linux in Asia, Europe, and the US is 
          a major concern for Microsoft.
Tagged:   The/OTHER growing/ADJECTIVE popularity/NOUN of/OTHER Linux/NOUN 
          in/ADJECTIVE Asia/NOUN Europe/NOUN and/OTHER the/OTHER US/NOUN 
          is/OTHER a/NOUN major/ADJECTIVE concern/NOUN for/NOUN 
          Microsoft/OTHER 

Original: Jaguar will sell its new XJ-6 model in the US for a small fortune.
Tagged:   Jaguar/NOUN will/NOUN sell/VERB its/OTHER new/OTHER XJ-6/OTHER
          model/ADJECTIVE in/NOUN the/OTHER US/NOUN for/NOUN a/NOUN 
          small/ADJECTIVE fortune/NOUN 

Original: The union is in a sad state.
Tagged:   The/OTHER union/OTHER is/OTHER in/ADJECTIVE a/NOUN sad/ADJECTIVE
          state/NOUN 

Original: Please do not state the obvious.
Tagged:   Please/VERB do/VERB not/ADVERB state/VERB the/OTHER 
          obvious/ADJECTIVE 

Original: I am looking forward to the state of the union address.
Tagged:   I/ADJECTIVE am/NOUN looking/ADJECTIVE forward/NOUN to/OTHER 
          the/OTHER state/OTHER of/OTHER the/OTHER union/ADJECTIVE 
          address/NOUN 

Original: I have a bad cold today.
Tagged:   I/ADJECTIVE have/NOUN a/NOUN bad/ADJECTIVE cold/NOUN today/NOUN 

Original: The cold war was over long ago.
Tagged:   The/OTHER cold/ADJECTIVE war/NOUN was/OTHER over/ADVERB 
          long/VERB ago/ADJECTIVE 

Conclusion

Unlike the HMM approach, where all the information was built into the model at the outset, the rule based approach takes advantage of the fact that we can use a tagged dictionary (Wordnet) to tell the POS for a word as we encounter it, and that most words resolve to a single POS. For those that are not recognized by Wordnet, we use simple word pattern matching to deduce the POS. For those that resolve to multiple POS, we use collocation probability rules to disambiguate them.

And now, for something completely different...

On a completely different note, regular (i.e. more than first-time) readers may have noticed some new widgets on my blog. This week, I added a tag cloud widget, spiffed up the "Blogs I read" widget with a new one from Blogger which displays the favicon of the blog and reports on the last time it was updated, added social bookmarking links at the bottom of the post, and added URLs for Atom and RSS Syndication feeds for my blog. Hope you like them, and no, this is not some sinister black-hat SEO bid to pollute/enrich the World Wide Web with my blog contents. For my reasons, read on...

I've wanted to get a tag cloud ever since I wrote about how to build one with Python. I think it is an awesome way to provide a snapshot of the entire site contents at a glance. However, Blogger does not give you a Tag Cloud Widget, although it does give you a Label widget on which this one is based, and I never had the time to mess around with the HTML Template until last week.

For the social bookmarking links, I noticed that one of you had submitted one of my posts to Digg. Whoever it was, thank you, especially since you had to do this manually. For those of you who like my posts enough to submit them to your favorite social bookmarking site, there is now a bank of links under my byline and tags which you can click, and which will pre-populate the title and URL so you don't have to cut and paste.

The syndication URLs just seemed to be a good idea because I've written quite a few times about ROME and RSS, and it's funny that I didn't offer a syndication feed URL myself before this. So now I offer two, one for RSS and another for Atom - take your pick :-).

Update 2009-04-26: In recent posts, I have been building on code written and described in previous posts, so there were (and rightly so) quite a few requests for the code. So I've created a project on Sourceforge to host the code. You will find the complete source code built so far in the project's SVN repository.

Saturday, November 08, 2008

IR Math in Java : HMM Based POS Tagger/Recognizer

As you know, I have been slowly working my way through Dr Konchady's TMAP book, and coding up the algorithms in Java. By doing so, I hope to understand the techniques and mathematical models presented, so I can apply them to real-life problems in the future. In this post I describe an implementation of a Hidden Markov Model based Part of Speech recognizer/tagger, based on the material presented in Chapter 4 of the TMAP book.

Background

What follows is my take on what an HMM is and how it can be used for Part of Speech (POS) tagging. For a more detailed, math-heavy, and possibly more accurate description of HMMs and their internals, read the Wikipedia article, Dr Rabiner's tutorial, or the TMAP book if you happen to own it. A Hidden Markov Model can be thought of as a probabilistic finite state machine. Its states, which are not directly visible, can be represented by the set S = {S1, S2, ..., Sn}. What is visible is a set of Observations O = {O1, O2, ..., Om} which are the result of the machine moving from one state to the other. The probability of the machine starting in each of the states Si is specified by the one-dimensional matrix Π of size n. The probability of the machine moving from one state to another is specified by a two-dimensional matrix A of size n*n. Finally, the probability of an observation being observed when the machine is in a certain state is given by the two-dimensional matrix B of size n*m. More succinctly,

  H = f(Π, A, B)
  where:
    H = the HMM
    Π = start probabilities. The element Πi represents 
        the probability that the HMM starts a sequence in
        State Si, where i in (0..n-1).
    A = transition probabilities. The element Ai,j represents
        the probability of a transition from State Si to
        State Sj, where i and j in (0..n-1).
    B = emission probabilities. The element Bi,j represents
        the probability of an Observation Oj occurring while
        the machine is in State Si, where i in (0..n-1)
        and j in (0..m-1).
    n = number of states.
    m = number of unique observations.
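
To make the notation concrete, here is a minimal sketch of the three parameter sets held as plain Java arrays, with a check that Π and every row of A and B is a probability distribution. The class name ToyHmm, the method names and all the numbers are made up for illustration (n = 2 states, m = 3 observations); this is not how Jahmm stores an HMM internally.

```java
// Hypothetical toy HMM: n = 2 hidden states, m = 3 distinct observations.
final class ToyHmm {
  final double[] pi;   // start probabilities, length n
  final double[][] a;  // transition probabilities, n x n
  final double[][] b;  // emission probabilities, n x m

  ToyHmm(double[] pi, double[][] a, double[][] b) {
    this.pi = pi;
    this.a = a;
    this.b = b;
  }

  /** True if all values are non-negative and sum to 1 within a small tolerance. */
  static boolean isDistribution(double[] row) {
    double sum = 0.0D;
    for (double p : row) {
      if (p < 0.0D) {
        return false;
      }
      sum += p;
    }
    return Math.abs(sum - 1.0D) < 1e-9;
  }

  /** Checks that Pi and each row of A and B is a probability distribution. */
  boolean isValid() {
    if (!isDistribution(pi)) {
      return false;
    }
    for (double[] row : a) {
      if (!isDistribution(row)) {
        return false;
      }
    }
    for (double[] row : b) {
      if (!isDistribution(row)) {
        return false;
      }
    }
    return true;
  }

  public static void main(String[] args) {
    ToyHmm hmm = new ToyHmm(
      new double[] {0.6, 0.4},                            // Pi
      new double[][] {{0.7, 0.3}, {0.4, 0.6}},            // A
      new double[][] {{0.5, 0.4, 0.1}, {0.1, 0.3, 0.6}}); // B
    System.out.println("valid HMM: " + hmm.isValid());
  }
}
```

A check like this is handy after building the matrices from counts, since a row that does not sum to 1 usually points to a normalization slip.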

The objective of POS tagging is to tag each word of a sentence with its part-of-speech tag. While some words can be unambiguously tagged, ie there is only one POS that the word is ever used for, there are quite a few which can represent different POS depending on their usage in the sentence. For example, cold can be both a noun and an adjective, and catch can be both a noun and a verb. The fact that the word exists in the sentence is known, while the POS for the word is unknown. Therefore the HMM built for POS tagging would model the words as visible observations and the set of possible POS as the hidden states.

As far as POS tagging is concerned, the main problems that can be solved by HMMs are as follows. Given an HMM,

  1. Finding the most likely state sequence for a given observation sequence. In this case, we pass in a sentence, and tag each word with its most likely POS.
  2. Finding the most likely state for a given observation in a sequence. This is useful for word sense disambiguation (WSD), so we can tell the most likely POS that a particular word in a sentence belongs to.

The problems above are identical from the point of view of a HMM, and are solved using the Viterbi algorithm.
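
For readers who want to see what the Viterbi algorithm does under the hood, here is a minimal, self-contained sketch over plain arrays. It is not Jahmm's ViterbiCalculator. The class name, the method signature and the toy numbers are assumptions made for illustration (think of state 0 as NOUN, state 1 as VERB, and observations 0..2 as three word ids). It works in log space so that long sentences do not underflow.

```java
import java.util.Arrays;

// Hypothetical sketch of the Viterbi algorithm on a toy HMM.
final class ViterbiSketch {

  /** Returns the most likely hidden state sequence for the given observations. */
  static int[] viterbi(double[] pi, double[][] a, double[][] b, int[] obs) {
    int n = pi.length;
    int t = obs.length;
    if (t == 0) {
      return new int[0];
    }
    // delta[k][j] = best log-probability of any path that ends in state j at step k
    double[][] delta = new double[t][n];
    // back[k][j] = the state at step k-1 on that best path
    int[][] back = new int[t][n];
    for (int j = 0; j < n; j++) {
      delta[0][j] = Math.log(pi[j]) + Math.log(b[j][obs[0]]);
    }
    for (int k = 1; k < t; k++) {
      for (int j = 0; j < n; j++) {
        int best = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < n; i++) {
          double score = delta[k - 1][i] + Math.log(a[i][j]);
          if (score > bestScore) {
            bestScore = score;
            best = i;
          }
        }
        delta[k][j] = bestScore + Math.log(b[j][obs[k]]);
        back[k][j] = best;
      }
    }
    // pick the best final state, then walk the back pointers
    int last = 0;
    for (int j = 1; j < n; j++) {
      if (delta[t - 1][j] > delta[t - 1][last]) {
        last = j;
      }
    }
    int[] path = new int[t];
    path[t - 1] = last;
    for (int k = t - 1; k > 0; k--) {
      path[k - 1] = back[k][path[k]];
    }
    return path;
  }

  public static void main(String[] args) {
    double[] pi = {0.6, 0.4};
    double[][] a = {{0.7, 0.3}, {0.4, 0.6}};
    double[][] b = {{0.5, 0.4, 0.1}, {0.1, 0.3, 0.6}};
    // prints [0, 0, 1]
    System.out.println(Arrays.toString(viterbi(pi, a, b, new int[] {0, 1, 2})));
  }
}
```

Tagging a whole sentence uses the full returned path, while disambiguating a single word just reads one element of it, which is why the two problems look identical to the HMM.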

The second type of problem is to find the probability of a certain sequence of observations. This can be used to answer questions such as whether a sentence such as "I am going to the market" is more common than one such as "To the market I go". The Forward-Backward Algorithm is used to solve this kind of problem. This can be useful for applications that predict the most likely outcome given a set of input observations, but is probably not important from the perspective of POS tagging. Both of the above types of problem need an HMM built from (manually) tagged data.
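
As a rough illustration of how a sequence probability is computed, the sketch below implements only the forward pass of the Forward-Backward algorithm, which is enough to get the probability of an observation sequence. It is not Jahmm's ForwardBackwardCalculator. The class name, the method name and the toy numbers are assumptions, and the sketch does no scaling, so it would underflow on very long sequences.

```java
// Hypothetical sketch: forward pass only, no scaling.
final class ForwardSketch {

  /** Returns P(obs | HMM), summed over all possible hidden state paths. */
  static double sequenceProbability(double[] pi, double[][] a,
      double[][] b, int[] obs) {
    int n = pi.length;
    if (obs.length == 0) {
      return 1.0D; // the empty sequence is certain
    }
    // alpha[j] = P(observations so far, current state = j)
    double[] alpha = new double[n];
    for (int j = 0; j < n; j++) {
      alpha[j] = pi[j] * b[j][obs[0]];
    }
    for (int k = 1; k < obs.length; k++) {
      double[] next = new double[n];
      for (int j = 0; j < n; j++) {
        double sum = 0.0D;
        for (int i = 0; i < n; i++) {
          sum += alpha[i] * a[i][j];
        }
        next[j] = sum * b[j][obs[k]];
      }
      alpha = next;
    }
    double total = 0.0D;
    for (double p : alpha) {
      total += p;
    }
    return total;
  }

  public static void main(String[] args) {
    double[] pi = {0.6, 0.4};
    double[][] a = {{0.7, 0.3}, {0.4, 0.6}};
    double[][] b = {{0.5, 0.4, 0.1}, {0.1, 0.3, 0.6}};
    double p1 = sequenceProbability(pi, a, b, new int[] {0, 1, 2});
    double p2 = sequenceProbability(pi, a, b, new int[] {2, 1, 0});
    System.out.println("P(0,1,2) = " + p1 + ", P(2,1,0) = " + p2);
  }
}
```

Comparing the two probabilities printed by main is the toy equivalent of asking which of two word orders the model considers more likely.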

A third problem of HMMs is how to build one given a corpus of untagged text. Such an HMM would allow us to solve the second type of problem. However, since the HMM has not been fed with tagged words, it must depend on a clustering algorithm such as K-Means to cluster the words into undefined hidden states, which are of no use when attempting to solve the second type of problem. The two learning algorithms used here are the K-Means Algorithm to build the initial HMM and the Baum-Welch Algorithm to refine it. As with the second type of problem, this does not have many applications where POS taggers are concerned.

I used the Java HMM library Jahmm to do all of the heavy computational lifting. It has implementations of the algorithms mentioned above, as well as several utility methods and classes to model various kinds of Observation.

Building the HMM from Tagged Data

For my tagged corpus, I used the Brown Corpus, downloading the data from the Natural Language Toolkit Project (NLTK). The corpus is a set of about 500 files containing one sentence per line, each manually tagged with a very comprehensive set of POS tags described here. Since I plan to use Wordnet at some point with this data, and Wordnet only categorizes words as one of 4 categories, I set up my own Part of Speech Enum called Pos which has 5 categories, the 4 from Wordnet and OTHER. As a result, I had to dumb the Brown tags down using the translation table shown below:

BTAG POS BTAG POS BTAG POS BTAG POS
( OTHER FW-CS OTHER MD VERB RBR+CS ADVERB
) OTHER FW-DT OTHER MD* VERB RBT ADVERB
* OTHER FW-DT+BEZ OTHER MD+HV VERB RN ADVERB
, OTHER FW-DTS OTHER MD+PPSS VERB RP ADVERB
-- OTHER FW-HV VERB MD+TO VERB RP+IN ADVERB
. OTHER FW-IN OTHER NN NOUN TO OTHER
: OTHER FW-IN+AT OTHER NN$ NOUN TO+VB VERB
ABL OTHER FW-IN+NN OTHER NN+BEZ NOUN UH OTHER
ABN OTHER FW-IN+NP OTHER NN+HVD NOUN VB VERB
ABX OTHER FW-JJ ADJECTIVE NN+HVZ NOUN VB+AT VERB
AP OTHER FW-JJR ADJECTIVE NN+IN NOUN VB+IN VERB
AP$ OTHER FW-JJT ADJECTIVE NN+MD NOUN VB+JJ VERB
AP+AP OTHER FW-NN NOUN NN+NN NOUN VB+PPO VERB
AT ADJECTIVE FW-NN$ NOUN NNS NOUN VB+RP VERB
BE VERB FW-NNS NOUN NNS$ NOUN VB+TO VERB
BED VERB FW-NP NOUN NNS+MD NOUN VB+VB VERB
BED* VERB FW-NPS NOUN NP NOUN VBD VERB
BEDZ VERB FW-NR NOUN NP$ NOUN VBG VERB
BEDZ* VERB FW-OD NOUN NP+BEZ NOUN VBG+TO VERB
BEG VERB FW-PN OTHER NP+HVZ NOUN VBN VERB
BEM VERB FW-PP$ OTHER NP+MD NOUN VBN+TO VERB
BEM* VERB FW-PPL OTHER NPS NOUN VBZ VERB
BEN VERB FW-PPL+VBZ OTHER NPS$ NOUN WDT OTHER
BER VERB FW-PPO OTHER NR NOUN WDT+BER OTHER
BER* VERB FW-PPO+IN OTHER NR$ NOUN WDT+BER+PP OTHER
BEZ VERB FW-PPS OTHER NR+MD NOUN WDT+BEZ OTHER
BEZ* VERB FW-PPSS OTHER NRS NOUN WDT+DO+PPS OTHER
CC OTHER FW-PPSS+HV OTHER OD NOUN WDT+DOD OTHER
CD NOUN FW-QL OTHER PN OTHER WDT+HVZ OTHER
CD$ NOUN FW-RB ADVERB PN$ OTHER WP$ OTHER
CS OTHER FW-RB+CC ADVERB PN+BEZ OTHER WPO OTHER
DO VERB FW-TO+VB VERB PN+HVD OTHER WPS OTHER
DO* VERB FW-UH OTHER PN+HVZ OTHER WPS+BEZ OTHER
DO+PPSS VERB FW-VB VERB PN+MD OTHER WPS+HVD OTHER
DOD VERB FW-VBD VERB PP$ OTHER WPS+HVZ OTHER
DOD* VERB FW-VBG VERB PP$$ OTHER WPS+MD OTHER
DOZ VERB FW-VBN VERB PPL OTHER WQL OTHER
DOZ* VERB FW-VBZ VERB PPLS OTHER WRB ADVERB
DT OTHER FW-WDT OTHER PPO OTHER WRB+BER ADVERB
DT$ OTHER FW-WPO OTHER PPS OTHER WRB+BEZ ADVERB
DT+BEZ OTHER FW-WPS OTHER PPS+BEZ OTHER WRB+DO ADVERB
DT+MD OTHER HV VERB PPS+HVD OTHER WRB+DOD ADVERB
DTI OTHER HV* VERB PPS+HVZ OTHER WRB+DOD* ADVERB
DTS OTHER HV+TO VERB PPS+MD OTHER WRB+DOZ ADVERB
DTS+BEZ OTHER HVD VERB PPSS OTHER WRB+IN ADVERB
DTX OTHER HVD* VERB PPSS+BEM OTHER WRB+MD ADVERB
EX VERB HVG VERB PPSS+BER OTHER - -
EX+BEZ VERB HVN VERB PPSS+BEZ OTHER - -
EX+HVD VERB HVZ VERB PPSS+BEZ* OTHER - -
EX+HVZ VERB HVZ* VERB PPSS+HV OTHER - -
EX+MD VERB IN OTHER PPSS+HVD OTHER - -
FW-* OTHER IN+IN OTHER PPSS+MD OTHER - -
FW-AT ADJECTIVE IN+PPO OTHER PPSS+VB OTHER - -
FW-AT+NN ADJECTIVE JJ ADJECTIVE QL OTHER - -
FW-AT+NP ADJECTIVE JJ$ ADJECTIVE QLP OTHER - -
FW-BE VERB JJ+JJ ADJECTIVE RB ADVERB - -
FW-BER VERB JJR ADJECTIVE RB$ ADVERB - -
FW-BEZ VERB JJR+CS ADJECTIVE RB+BEZ ADVERB - -
FW-CC OTHER JJS ADJECTIVE RB+CS ADVERB - -
FW-CD NOUN JJT ADJECTIVE RBR ADVERB - -

You may notice that some of the mappings are not correct. Unfortunately, my knowledge of formal English grammar is not as good as I would like it to be, owing to having been educated in an environment that posited that a person can learn to recognize incorrect grammar better by reading and writing enough grammatically correct sentences rather than through a study of language rules. My grandfather, obviously regarding all this as crazy talk, briefly attempted to rectify that, armed with a Wren and Martin and an 18" ruler, but as you can probably see, it did not work out all that well :-).

The code for the Pos Enum is shown below. As mentioned earlier, it exposes a set of 5 POS values, and has a convenience method to convert the Brown tag into corresponding Pos.

// Source: src/main/java/com/mycompany/myapp/postaggers/Pos.java
package com.mycompany.myapp.postaggers;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.util.HashMap;
import java.util.Map;

import org.apache.commons.lang.StringUtils;

/**
 * Enumeration of Parts of Speech being considered. Conversions from
 * Brown Tags and Wordnet tags are handled by convenience methods.
 */
public enum Pos {

  NOUN, VERB, ADJECTIVE, ADVERB, OTHER;

  private static Map<String,Pos> bmap = null;
  private static final String translationFile = 
    "src/main/resources/brown_tags.txt";
  
  public static Pos fromBrownTag(String btag) throws Exception {
    if (bmap == null) {
      bmap = new HashMap<String,Pos>();
      BufferedReader reader = new BufferedReader(new InputStreamReader(
          new FileInputStream(translationFile)));
      String line;
      while ((line = reader.readLine()) != null) {
        if (line.startsWith("#")) {
          continue;
        }
        String[] cols = StringUtils.split(line, "\t");
        bmap.put(StringUtils.lowerCase(cols[0]), Pos.valueOf(cols[1])); 
      }
      reader.close();
    }
    Pos pos = bmap.get(btag);
    if (pos == null) {
      return Pos.OTHER;
    }
    return pos;
  }
}

The BrownCorpusReader reads through each tagged file in the Brown Corpus directory, extracts the word and the tag out of each tagged word, converts the Brown tag to its equivalent Pos value, and accumulates the occurrences into internal counters. Once all files are processed, the counters are normalized into the three probability matrices Π, A and B that we spoke about earlier.

Since the number of words tagged in any corpus is potentially quite large, we represent the words (or observations) in the HMM as an integer. That is why the BrownCorpusReader also dumps out a list of unique words it found in the corpus into a flat file which can be pulled back into memory later to do the mapping between the word and the integer observation Id.
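
The word-to-integer mapping described above can be sketched as a small two-way lookup. This is an illustration only: the class WordIndex and its method names are hypothetical and do not appear in the code that follows, which keeps the words in a List and writes them to a flat file instead. The lowercasing mirrors what the reader does to each tagged word.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

// Hypothetical two-way mapping between words and integer observation ids.
final class WordIndex {
  private final List<String> idToWord = new ArrayList<String>();
  private final Map<String,Integer> wordToId = new HashMap<String,Integer>();

  /** Returns the id for a word, assigning the next free id on first sight. */
  int idFor(String word) {
    String key = word.toLowerCase(Locale.ROOT);
    Integer id = wordToId.get(key);
    if (id == null) {
      id = idToWord.size(); // ids are positions in the word list
      idToWord.add(key);
      wordToId.put(key, id);
    }
    return id;
  }

  /** Returns the word stored for an id. */
  String wordFor(int id) {
    return idToWord.get(id);
  }

  int size() {
    return idToWord.size();
  }

  public static void main(String[] args) {
    WordIndex index = new WordIndex();
    for (String w : "The cold war was over long ago".split(" ")) {
      System.out.println(w + " -> " + index.idFor(w));
    }
  }
}
```

Because an id is just the position of the word in the list, writing the list to a file one word per line and reading it back in the same order reproduces the same ids.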

// Source: src/main/java/com/mycompany/myapp/postaggers/BrownCorpusReader.java
package com.mycompany.myapp.postaggers;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.StringTokenizer;

import org.apache.commons.collections15.Bag;
import org.apache.commons.collections15.bag.HashBag;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import Jama.Matrix;

/**
 * Reads a file or directory of tagged text from Brown Corpus and
 * computes the various probability matrices for the HMM.
 */
public class BrownCorpusReader {

  private final Log log = LogFactory.getLog(getClass());
  
  private String dataFilesLocation;
  private String wordDictionaryLocation;
  private boolean debug;

  private Bag<String> piCounts = new HashBag<String>();
  private Bag<String> aCounts = new HashBag<String>();
  private Map<String,Double[]> wordPosMap = 
    new HashMap<String,Double[]>();
  
  private Matrix pi;
  private Matrix a;
  private Matrix b;
  private List<String> words;
  
  public void setDataFilesLocation(String dataFilesLocation) {
    this.dataFilesLocation = dataFilesLocation;
  }
  
  public void setWordDictionaryLocation(String wordDictionaryLocation) {
    this.wordDictionaryLocation = wordDictionaryLocation;
  }
  
  public void setDebug(boolean debug) {
    this.debug = debug;
  }

  public void read() throws Exception {
    File location = new File(dataFilesLocation);
    File[] inputs;
    if (location.isDirectory()) {
      inputs = location.listFiles();
    } else {
      inputs = new File[] {location};
    }
    int currfile = 0;
    int totfiles = inputs.length;
    for (File input : inputs) {
      currfile++;
      log.info("Processing file (" + currfile + "/" + totfiles + "): " + 
        input.getName());
      BufferedReader reader = new BufferedReader(new InputStreamReader(
        new FileInputStream(input)));
      String line;
      while ((line = reader.readLine()) != null) {
        if (StringUtils.isEmpty(line)) {
          continue;
        }
        StringTokenizer tok = new StringTokenizer(line, " ");
        int wordIndex = 0;
        Pos prevPos = null;
        while (tok.hasMoreTokens()) {
          String taggedWord = tok.nextToken();
          String[] wordTagPair = StringUtils.split(
            StringUtils.lowerCase(StringUtils.trim(taggedWord)), "/");
          if (wordTagPair.length != 2) {
            continue;
          }
          Pos pos = Pos.fromBrownTag(wordTagPair[1]);
          if (! wordPosMap.containsKey(wordTagPair[0])) {
            // create an entry
            Double[] posProbs = new Double[Pos.values().length];
            for (int i = 0; i < posProbs.length; i++) {
              posProbs[i] = new Double(0.0D);
            }
            wordPosMap.put(wordTagPair[0], posProbs);
          }
          Double[] posProbs = wordPosMap.get(wordTagPair[0]);
          posProbs[pos.ordinal()] += 1.0D;
          wordPosMap.put(wordTagPair[0], posProbs);
          if (wordIndex == 0) {
            // first word, update piCounts
            piCounts.add(pos.name());
          } else {
            aCounts.add(StringUtils.join(new String[] {
              prevPos.name(), pos.name()}, ":"));
          }
          prevPos = pos;
          wordIndex++;
        }
      }
      reader.close();
    }
    // normalize counts to probabilities
    int numPos = Pos.values().length;
    // compute pi
    pi = new Matrix(numPos, 1);
    for (int i = 0; i < numPos; i++) {
      pi.set(i, 0, piCounts.getCount((Pos.values()[i]).name()));
    }
    pi = pi.times(1 / pi.norm1());
    // compute a
    a = new Matrix(numPos, numPos);
    for (int i = 0; i < numPos; i++) {
      for (int j = 0; j < numPos; j++) {
        a.set(i, j, aCounts.getCount(StringUtils.join(new String[] {
          (Pos.values()[i]).name(), (Pos.values()[j]).name()
        }, ":")));
      }
    }
    // compute b
    int numWords = wordPosMap.size();
    words = new ArrayList<String>();
    words.addAll(wordPosMap.keySet());
    b = new Matrix(numPos, numWords);
    for (int i = 0; i < numPos; i++) {
      for (int j = 0; j < numWords; j++) {
        String word = words.get(j);
        b.set(i, j, wordPosMap.get(word)[i]);
      }
    }
    // normalize across rows for a and b (sum of cols in each row == 1.0)
    for (int i = 0; i < numPos; i++) {
      double rowSumA = 0.0D;
      for (int j = 0; j < numPos; j++) {
        rowSumA += a.get(i, j);
      }
      for (int j = 0; j < numPos; j++) {
        a.set(i, j, (a.get(i, j) / rowSumA));
      }
      double rowSumB = 0.0D;
      for (int j = 0; j < numWords; j++) {
        rowSumB += b.get(i, j);
      }
      for (int j = 0; j < numWords; j++) {
        b.set(i, j, (b.get(i, j) / rowSumB));
      }
    }
    // write out brown word dictionary for later use
    writeDictionary();
    // debug
    if (debug) {
      pi.print(8, 4);
      a.print(8, 4);
      b.print(8, 4);
      System.out.println(words.toString());
    }
  }
  
  public List<String> getWords() {
    return words;
  }
  
  public double[] getPi() {
    double[] pia = new double[pi.getRowDimension()];
    for (int i = 0; i < pia.length; i++) {
      pia[i] = pi.get(i, 0);
    }
    return pia;
  }
  
  public double[][] getA() {
    double[][] aa = new double[a.getRowDimension()][a.getColumnDimension()];
    for (int i = 0; i < a.getRowDimension(); i++) {
      for (int j = 0; j < a.getColumnDimension(); j++) {
        aa[i][j] = a.get(i, j);
      }
    }
    return aa;
  }
  
  public double[][] getB() {
    double[][] ba = new double[b.getRowDimension()][b.getColumnDimension()];
    for (int i = 0; i < b.getRowDimension(); i++) {
      for (int j = 0; j < b.getColumnDimension(); j++) {
        ba[i][j] = b.get(i, j);
      }
    }
    return ba;
  }

  private void writeDictionary() throws Exception {
    FileWriter dictWriter = new FileWriter(wordDictionaryLocation);
    for (String word : words) {
      dictWriter.write(word + "\n");
    }
    dictWriter.flush();
    dictWriter.close();
  }
}
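
The row normalization at the end of read() divides each count by its row sum. Here is a standalone sketch of that step over a plain 2D array. The class name and method name are hypothetical, and the guard for an all-zero row is an addition that the code above does not have (there, a POS that never occurs in the corpus would produce NaN entries).

```java
import java.util.Arrays;

// Hypothetical sketch: turn a matrix of counts into row-wise probabilities.
final class CountNormalizer {

  /** Returns a new matrix in which every non-empty row sums to 1. */
  static double[][] normalizeRows(double[][] counts) {
    double[][] probs = new double[counts.length][];
    for (int i = 0; i < counts.length; i++) {
      double rowSum = 0.0D;
      for (double c : counts[i]) {
        rowSum += c;
      }
      probs[i] = new double[counts[i].length];
      for (int j = 0; j < counts[i].length; j++) {
        // assumption: leave an all-zero row as zeros rather than dividing by 0
        probs[i][j] = (rowSum == 0.0D) ? 0.0D : counts[i][j] / rowSum;
      }
    }
    return probs;
  }

  public static void main(String[] args) {
    // toy transition counts: row = previous POS, column = current POS
    double[][] counts = {{2, 6}, {0, 0}};
    // prints [[0.25, 0.75], [0.0, 0.0]]
    System.out.println(Arrays.deepToString(normalizeRows(counts)));
  }
}
```

Whether an empty row should become all zeros or a uniform distribution depends on how the HMM library treats it, so take the choice made here as one option among several.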

We generate the HMM and serialize it to disk as a flat file. That decouples the building of the HMM from the actual usage, saves a few CPU cycles, and makes the tests run a bit faster. In addition, if this solution were to be used in a real-life situation, it would be much faster to load the HMM from a flat file than to build it from a tagged corpus. Our serialized HMM file looks like this (edited to truncate the number of observations for readability).

On a quick side note, the Jahmm example uses the ObservationDiscrete class based on an Enum to model a small finite set of observations. This works well if the number of observations in your set is small and well known. In our case, we consider a unique word as an observation, and we have approximately 3900 of them, so we used the ObservationInteger class to model the observation, and our flat file serves as a mapping from the integer id for the Observation to the actual word.

Hmm v1.0

NbStates 5

State
Pi 0.127
A 0.155 0.156 0.019 0.025 0.645 
IntegerOPDF [0 0 0.00002 0.00003 0 0 0.00001 0 0.00001 ...]

State
Pi 0.057
A 0.095 0.195 0.168 0.094 0.449 
IntegerOPDF [0 0 0 0 0 0.00001 0 0 0 0 0 0 0.00001 0.00005 ...]

State
Pi 0.164
A 0.639 0.024 0.148 0.005 0.183 
IntegerOPDF [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ...]

State
Pi 0.083
A 0.052 0.228 0.111 0.041 0.569 
IntegerOPDF [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ...]

State
Pi 0.569
A 0.206 0.199 0.205 0.039 0.351 
IntegerOPDF [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0.0032 0.00033 0.00064 ...]

The following JUnit test snippet shows how we use the HmmTagger class (described below) to call the BrownCorpusReader and build and persist the HMM.

  @Test
  public void testBuildFromBrownAndWrite() throws Exception {
    HmmTagger hmmTagger = new HmmTagger();
    hmmTagger.setDataDir("/opt/brown-2.0");
    hmmTagger.setDictionaryLocation("src/test/resources/brown_dict.txt");
    hmmTagger.setHmmFileName("src/test/resources/hmm_tagger.dat");
    Hmm<ObservationInteger> hmm = hmmTagger.buildFromBrownCorpus();
    hmmTagger.saveToFile(hmm);
  }

HMM Tagger class

I then create a HmmTagger class that can build an HMM from the BrownCorpusReader as well as from a serialized HMM file shown above. The HmmTagger contains all the methods that are needed to solve the common HMM problems listed above. The code for the HmmTagger is as follows:

// Source: src/main/java/com/mycompany/myapp/postaggers/HmmTagger.java
package com.mycompany.myapp.postaggers;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.Writer;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import be.ac.ulg.montefiore.run.jahmm.ForwardBackwardCalculator;
import be.ac.ulg.montefiore.run.jahmm.Hmm;
import be.ac.ulg.montefiore.run.jahmm.ObservationInteger;
import be.ac.ulg.montefiore.run.jahmm.OpdfInteger;
import be.ac.ulg.montefiore.run.jahmm.OpdfIntegerFactory;
import be.ac.ulg.montefiore.run.jahmm.ViterbiCalculator;
import be.ac.ulg.montefiore.run.jahmm.io.HmmReader;
import be.ac.ulg.montefiore.run.jahmm.io.HmmWriter;
import be.ac.ulg.montefiore.run.jahmm.io.OpdfIntegerReader;
import be.ac.ulg.montefiore.run.jahmm.io.OpdfReader;
import be.ac.ulg.montefiore.run.jahmm.io.OpdfWriter;
import be.ac.ulg.montefiore.run.jahmm.learn.BaumWelchLearner;
import be.ac.ulg.montefiore.run.jahmm.learn.KMeansLearner;
import be.ac.ulg.montefiore.run.jahmm.toolbox.KullbackLeiblerDistanceCalculator;

/**
 * HMM based POS Tagger.
 */
public class HmmTagger {

  private static final DecimalFormat OBS_FORMAT = 
    new DecimalFormat("##.#####");
  
  private final Log log = LogFactory.getLog(getClass());

  private String dataDir;
  private String dictionaryLocation;
  private String hmmFileName;
  
  private Map<String,Integer> words = 
    new HashMap<String,Integer>();
  
  public void setDataDir(String brownDataDir) {
    this.dataDir = brownDataDir;
  }

  public void setDictionaryLocation(String dictionaryLocation) {
    this.dictionaryLocation = dictionaryLocation;
  }

  public void setHmmFileName(String hmmFileName) {
    this.hmmFileName = hmmFileName;
  }

  /**
   * Builds up an HMM where states are parts of speech given by the Pos
   * Enum, and the observations are actual words in the tagged Brown
   * corpus. Each integer observation corresponds to the position of 
   * a word found in the Brown corpus.
   * @return an HMM.
   * @throws Exception if one is thrown.
   */
  public Hmm<ObservationInteger> buildFromBrownCorpus() 
      throws Exception {
    BrownCorpusReader brownReader = new BrownCorpusReader();
    brownReader.setDataFilesLocation(dataDir);
    brownReader.setWordDictionaryLocation(dictionaryLocation);
    brownReader.read();
    int nbStates = Pos.values().length;
    OpdfIntegerFactory factory = new OpdfIntegerFactory(nbStates);
    Hmm<ObservationInteger> hmm = 
      new Hmm<ObservationInteger>(nbStates, factory); 
    double[] pi = brownReader.getPi();
    for (int i = 0; i < nbStates; i++) {
      hmm.setPi(i, pi[i]);
    }
    double[][] a = brownReader.getA();
    for (int i = 0; i < nbStates; i++) {
      for (int j = 0; j < nbStates; j++) {
        hmm.setAij(i, j, a[i][j]);
      }
    }
    double[][] b = brownReader.getB();
    for (int i = 0; i < nbStates; i++) {
      // one emission distribution (a row of B) per state
      hmm.setOpdf(i, new OpdfInteger(b[i]));
    }
    int seq = 0;
    for (String word : brownReader.getWords()) {
      words.put(word, seq);
      seq++;
    }
    return hmm;
  }
  
  /**
   * Builds an HMM from a formatted file describing the HMM. The format is
   * specified by the Jahmm project, and it has utility methods to read and
   * write HMMs from and to text files. We use this because the builder that
   * builds an HMM from the Brown corpus is computationally intensive and
   * this strategy provides us a way to partition the process.
   * @return a HMM
   * @throws Exception if one is thrown.
   */
  public Hmm<ObservationInteger> buildFromHmmFile() throws Exception {
    File hmmFile = new File(hmmFileName);
    if (! hmmFile.exists()) {
      throw new Exception("HMM File: " + hmmFile.getName() + 
        " does not exist");
    }
    FileReader fileReader = new FileReader(hmmFile);
    try {
      OpdfReader<OpdfInteger> opdfReader = new OpdfIntegerReader();
      return HmmReader.read(fileReader, opdfReader);
    } finally {
      // release the file handle even if parsing fails
      fileReader.close();
    }
  }
  
  /**
   * Utility method to save an HMM into a formatted text file describing the
   * HMM. The format is specified by the Jahmm project, which also provides
   * utility methods to write a HMM to the text file.
   * @param hmm the HMM to write.
   * @throws Exception if one is thrown.
   */
  public void saveToFile(Hmm<ObservationInteger> hmm) 
      throws Exception {
    FileWriter fileWriter = new FileWriter(hmmFileName);
    // we create our own impl of the OpdfIntegerWriter because we want
    // to control the formatting of the opdf probabilities. With the 
    // default OpdfIntegerWriter, small probabilities get written in 
    // the exponential format, ie 1.234..E-4, which the HmmReader does
    // not recognize.
    OpdfWriter<OpdfInteger> opdfWriter = 
      new OpdfWriter<OpdfInteger>() {
        @Override
        public void write(Writer writer, OpdfInteger opdf) 
            throws IOException {
          String s = "IntegerOPDF [";
          for (int i = 0; i < opdf.nbEntries(); i++) {
            s += OBS_FORMAT.format(opdf.probability(
              new ObservationInteger(i))) + " ";
          }
          writer.write(s + "]\n");
        }
    };
    HmmWriter.write(fileWriter, opdfWriter, hmm);
    fileWriter.flush();
    fileWriter.close();
  }

  /**
   * Given the HMM, returns the probability of observing the sequence 
   * of words specified in the sentence. Uses the Forward-Backward 
   * algorithm to compute the probability.
   * @param sentence the sentence to check.
   * @param hmm a reference to a prebuilt HMM.
   * @return the probability of observing this sequence.
   * @throws Exception if one is thrown.
   */
  public double getObservationProbability(String sentence, 
      Hmm<ObservationInteger> hmm) throws Exception {
    String[] tokens = tokenizeSentence(sentence);
    List<ObservationInteger> observations = getObservations(tokens);
    ForwardBackwardCalculator fbc = 
      new ForwardBackwardCalculator(observations, hmm);
    return fbc.probability();
  }

  /**
   * Given an HMM and an untagged sentence, tags each word with the part of
   * speech it is most likely to belong in. Uses the Viterbi algorithm.
   * @param sentence the sentence to tag.
   * @param hmm the HMM to use.
   * @return a tagged sentence.
   * @throws Exception if one is thrown.
   */
  public String tagSentence(String sentence, 
      Hmm<ObservationInteger> hmm) throws Exception {
    String[] tokens = tokenizeSentence(sentence);
    List<ObservationInteger> observations = getObservations(tokens);
    ViterbiCalculator vc = new ViterbiCalculator(observations, hmm);
    int[] ids = vc.stateSequence();
    StringBuilder tagBuilder = new StringBuilder();
    for (int i = 0; i < ids.length; i++) {
      tagBuilder.append(tokens[i]).
        append("/").
        append((Pos.values()[ids[i]]).name()).
        append(" ");
    }
    return tagBuilder.toString();
  }
  
  /**
   * Given an HMM, a sentence and a word within the sentence which needs to 
   * be disambiguated, returns the most likely Pos for the specified word.
   * @param word the word to find the Pos for.
   * @param sentence the sentence.
   * @param hmm the HMM.
   * @return the most likely POS.
   * @throws Exception if one is thrown.
   */
  public Pos getMostLikelyPos(String word, String sentence, 
      Hmm<ObservationInteger> hmm) throws Exception {
    if (words == null || words.size() == 0) {
      loadWordsFromDictionary();
    }
    String[] tokens = tokenizeSentence(sentence);
    List<ObservationInteger> observations = getObservations(tokens);
    int wordPos = -1;
    for (int i = 0; i < tokens.length; i++) {
      if (tokens[i].equalsIgnoreCase(word)) {
        wordPos = i;
        break;
      }
    }
    if (wordPos == -1) {
      throw new IllegalArgumentException("Word [" + word + 
        "] does not exist in sentence [" + sentence + "]");
    }
    ViterbiCalculator vc = new ViterbiCalculator(observations, hmm);
    int[] ids = vc.stateSequence();
    return Pos.values()[ids[wordPos]];
  }

  /**
   * Given an existing HMM, this method will send in a List of sentences from
   * a possibly different untagged source, to refine the HMM.
   * @param sentences the List of sentences to teach.
   * @return a HMM that has been taught using the observation sequences.
   * @throws Exception if one is thrown.
   */
  public Hmm<ObservationInteger> teach(List<String> sentences)
      throws Exception {
    if (words == null || words.size() == 0) {
      loadWordsFromDictionary();
    }
    OpdfIntegerFactory factory = new OpdfIntegerFactory(words.size());
    List<List<ObservationInteger>> sequences = 
      new ArrayList<List<ObservationInteger>>();
    for (String sentence : sentences) {
      List<ObservationInteger> sequence = 
        getObservations(tokenizeSentence(sentence));
      sequences.add(sequence);
    }
    KMeansLearner<ObservationInteger> kml = 
      new KMeansLearner<ObservationInteger>(
      Pos.values().length, factory, sequences);
    Hmm<ObservationInteger> hmm = kml.iterate();
    // refine it with Baum-Welch Learner
    BaumWelchLearner bwl = new BaumWelchLearner();
    Hmm<ObservationInteger> refinedHmm = bwl.iterate(hmm, sequences);
    return refinedHmm;
  }
  
  /**
   * Convenience method to compute the distance between two HMMs. This can 
   * be used to stop the teaching process once more teaching is not
   * producing any appreciable improvement in the HMM, ie, the HMM
   * converges. The caller will need to match the result of this method 
   * with a number based on experience.
   * @param hmm1 the original HMM.
   * @param hmm2 the HMM that was most recently taught.
   * @return the difference measure between the two HMMs.
   * @throws Exception if one is thrown.
   */
  public double difference(Hmm<ObservationInteger> hmm1,
      Hmm<ObservationInteger> hmm2) throws Exception {
    KullbackLeiblerDistanceCalculator kdc = 
      new KullbackLeiblerDistanceCalculator();
    return kdc.distance(hmm1, hmm2);
  }
  
  private String[] tokenizeSentence(String sentence) {
    String[] tokens = StringUtils.split(
      StringUtils.lowerCase(StringUtils.trim(sentence)), " ");
    return tokens;
  }
  
  private List<ObservationInteger> getObservations(String[] tokens)
      throws Exception {
    if (words == null || words.size() == 0) {
      loadWordsFromDictionary();
    }
    List<ObservationInteger> observations = 
      new ArrayList<ObservationInteger>();
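    for (String token : tokens) {
      Integer wordId = words.get(token);
      if (wordId == null) {
        // fail with a clear message instead of an unboxing NullPointerException
        throw new IllegalArgumentException("Word [" + token + 
          "] is not in the dictionary");
      }
      observations.add(new ObservationInteger(wordId));
    }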
    return observations;
  }
  
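  private void loadWordsFromDictionary() throws Exception {
    BufferedReader reader = new BufferedReader(
      new FileReader(dictionaryLocation));
    try {
      String word;
      int seq = 0;
      // one word per line; the line number becomes the observation id
      while ((word = reader.readLine()) != null) {
        words.put(word, seq);
        seq++;
      }
    } finally {
      // close the reader even if reading fails part way through
      reader.close();
    }
  }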
}

Word Sense Disambiguation

Given a sentence, a human reader can figure out the correct POS for each word almost immediately, but an HMM can only tell us the most likely POS for a word given the sequence of words in the sentence. Obviously, how well it does depends on how large and accurate the HMM's training data set was. Here is how the HmmTagger is called to determine the most likely POS for a word in a sentence.

  @Test
  public void testWordSenseDisambiguation() throws Exception {
    HmmTagger hmmTagger = new HmmTagger();
    hmmTagger.setDataDir("/opt/brown-2.0");
    hmmTagger.setDictionaryLocation("src/test/resources/brown_dict.txt");
    hmmTagger.setHmmFileName("src/test/resources/hmm_tagger.dat");
    Hmm<ObservationInteger> hmm = 
      hmmTagger.buildFromHmmFile();
    String[] testSentences = new String[] {
      "The dog ran after the cat .",
      "Failure dogs his path .",
      "The cold steel cuts through the flesh .",
      "He had a bad cold .",
      "He will catch the ball .",
      "Salmon is the catch of the day ."
    };
    String[] testWords = new String[] {
      "dog",
      "dogs",
      "cold",
      "cold",
      "catch",
      "catch"
    };
    for (int i = 0; i < testSentences.length; i++) {
      System.out.println("Original sentence: " + testSentences[i]);
      Pos wordPos = hmmTagger.getMostLikelyPos(testWords[i], 
        testSentences[i], hmm); 
      System.out.println("Pos(" + testWords[i] + ")=" + wordPos);
    }
  }

And here are the results. As you can see, the HMM did well on all but the second sentence.

Original sentence: The dog ran after the cat .
Pos(dog)=NOUN

Original sentence: Failure dogs his path .
Pos(dogs)=NOUN

Original sentence: The cold steel cuts through the flesh .
Pos(cold)=ADJECTIVE

Original sentence: He had a bad cold .
Pos(cold)=NOUN

Original sentence: He will catch the ball .
Pos(catch)=VERB

Original sentence: Salmon is the catch of the day .
Pos(catch)=NOUN

POS Tagging

POS Tagging uses the same algorithm as Word Sense Disambiguation. Given an HMM trained with a sufficiently large and accurate corpus of tagged words, we can now use it to automatically tag sentences from a similar corpus. Here is the JUnit code snippet to tag the sentences we used in our previous test.

  @Test
  public void testPosTagging() throws Exception {
    HmmTagger hmmTagger = new HmmTagger();
    hmmTagger.setDataDir("/opt/brown-2.0");
    hmmTagger.setDictionaryLocation("src/test/resources/brown_dict.txt");
    hmmTagger.setHmmFileName("src/test/resources/hmm_tagger.dat");
    Hmm<ObservationInteger> hmm = hmmTagger.buildFromHmmFile();
    // POS tagging
    String[] testSentences = new String[] {...};
    for (int i = 0; i < testSentences.length; i++) {
      System.out.println("Original sentence: " + testSentences[i]);
      System.out.println("Tagged sentence: " + 
        hmmTagger.tagSentence(testSentences[i], hmm));
    }
  }

And here are the results.

Original sentence: The dog ran after the cat .
Tagged sentence: the/ADJECTIVE dog/NOUN ran/VERB after/OTHER 
  the/ADJECTIVE cat/NOUN ./OTHER 

Original sentence: Failure dogs his path .
Tagged sentence: failure/NOUN dogs/NOUN his/OTHER path/NOUN ./OTHER 

Original sentence: The cold steel cuts through the flesh .
Tagged sentence: the/ADJECTIVE cold/ADJECTIVE steel/NOUN cuts/NOUN 
  through/OTHER the/ADJECTIVE flesh/NOUN ./OTHER 

Original sentence: He had a bad cold .
Tagged sentence: he/OTHER had/VERB a/ADJECTIVE bad/ADJECTIVE cold/NOUN 
  ./OTHER 

Original sentence: He will catch the ball .
Tagged sentence: he/OTHER will/VERB catch/VERB the/ADJECTIVE ball/NOUN 
  ./OTHER 

Original sentence: Salmon is the catch of the day .
Tagged sentence: salmon/NOUN is/VERB the/ADJECTIVE catch/NOUN of/OTHER 
  the/ADJECTIVE day/NOUN ./OTHER 
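
To make the decoding step behind tagSentence a little more concrete, here is a minimal sketch of the Viterbi algorithm on a made-up two-state model. This is not Jahmm's implementation. The class name ToyViterbi, the two states and all the probabilities are assumptions picked purely for illustration.

```java
import java.util.Arrays;

// Toy Viterbi decoder. The model in main() is invented for illustration.
class ToyViterbi {

  /**
   * Returns the most likely state sequence for the observations.
   * pi[s]   = probability of starting in state s
   * a[s][t] = probability of moving from state s to state t
   * b[s][o] = probability of state s emitting observation o
   */
  static int[] decode(double[] pi, double[][] a, double[][] b, int[] obs) {
    int n = pi.length;
    int len = obs.length;
    double[][] delta = new double[len][n]; // best path probability so far
    int[][] back = new int[len][n];        // best predecessor of each state
    for (int s = 0; s < n; s++) {
      delta[0][s] = pi[s] * b[s][obs[0]];
    }
    for (int t = 1; t < len; t++) {
      for (int s = 0; s < n; s++) {
        int best = 0;
        for (int p = 1; p < n; p++) {
          if (delta[t - 1][p] * a[p][s] > delta[t - 1][best] * a[best][s]) {
            best = p;
          }
        }
        back[t][s] = best;
        delta[t][s] = delta[t - 1][best] * a[best][s] * b[s][obs[t]];
      }
    }
    // pick the best final state, then follow the back pointers
    int last = 0;
    for (int s = 1; s < n; s++) {
      if (delta[len - 1][s] > delta[len - 1][last]) {
        last = s;
      }
    }
    int[] path = new int[len];
    path[len - 1] = last;
    for (int t = len - 1; t > 0; t--) {
      path[t - 1] = back[t][path[t]];
    }
    return path;
  }

  public static void main(String[] args) {
    // States: 0 = NOUN, 1 = VERB. Observations: 0 = "dog", 1 = "ran".
    double[] pi = {0.8, 0.2};
    double[][] a = {{0.3, 0.7}, {0.6, 0.4}};
    double[][] b = {{0.9, 0.1}, {0.2, 0.8}};
    System.out.println(Arrays.toString(decode(pi, a, b, new int[] {0, 1})));
  }
}
```

The returned array plays the same role as the ids array in tagSentence: one state index per token, which is then mapped to a Pos name.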

Sentence Likelihood

HMMs can be used to predict if one sentence is more likely to occur than another one, by comparing the observation probability of a certain sequence of words with another sequence. So for example, we find that the HMM believes that sentences spoken by Master Yoda of Star Wars fame are less likely to occur in "normal" English than sentences expressing similar meaning that you or I would speak.

  @Test
  public void testObservationProbability() throws Exception {
    HmmTagger hmmTagger = new HmmTagger();
    hmmTagger.setDataDir("/opt/brown-2.0");
    hmmTagger.setDictionaryLocation("src/test/resources/brown_dict.txt");
    hmmTagger.setHmmFileName("src/test/resources/hmm_tagger.dat");
    Hmm<ObservationInteger> hmm = hmmTagger.buildFromHmmFile();
    System.out.println("P(I am worried)=" + 
      hmmTagger.getObservationProbability("I am worried", hmm));
    System.out.println("P(Worried I am)=" +  
      hmmTagger.getObservationProbability("Worried I am", hmm));
  }

As expected, our results indicate that the HMM understands us better than it understands Master Yoda.

P(I am worried)=5.446081633660202E-11
P(Worried I am)=1.2623833954125002E-11
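
For readers wondering what an observation probability actually is, the sketch below runs the forward algorithm on a toy model: it sums the probability of the observation sequence over every possible state path. It is only meant to illustrate the idea, not to reproduce what Jahmm's ForwardBackwardCalculator does internally. The class name ToyForward and the model numbers are assumptions.

```java
// Toy forward algorithm: P(observations | model), summed over all state paths.
class ToyForward {

  /**
   * pi[s]   = probability of starting in state s
   * a[s][t] = probability of moving from state s to state t
   * b[s][o] = probability of state s emitting observation o
   */
  static double probability(double[] pi, double[][] a, double[][] b,
      int[] obs) {
    int n = pi.length;
    double[] alpha = new double[n];
    for (int s = 0; s < n; s++) {
      alpha[s] = pi[s] * b[s][obs[0]];
    }
    for (int t = 1; t < obs.length; t++) {
      double[] next = new double[n];
      for (int s = 0; s < n; s++) {
        double sum = 0.0;
        for (int p = 0; p < n; p++) {
          sum += alpha[p] * a[p][s];
        }
        next[s] = sum * b[s][obs[t]];
      }
      alpha = next;
    }
    double total = 0.0;
    for (double v : alpha) {
      total += v;
    }
    return total;
  }

  public static void main(String[] args) {
    // States: 0 = NOUN, 1 = VERB. Observations: 0 = "dog", 1 = "ran".
    double[] pi = {0.8, 0.2};
    double[][] a = {{0.3, 0.7}, {0.6, 0.4}};
    double[][] b = {{0.9, 0.1}, {0.2, 0.8}};
    System.out.println("P(dog ran)=" +
      probability(pi, a, b, new int[] {0, 1}));
    System.out.println("P(ran dog)=" +
      probability(pi, a, b, new int[] {1, 0}));
  }
}
```

With this toy model the first ordering comes out more probable than the second, which is the same kind of comparison the test above makes between the two word orders.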

Unsupervised Learning

The final problem we can solve with an HMM is to build one from a set of untagged data. This HMM can then be used for solving the Sentence Likelihood problem, but not the POS Tagging or the WSD problems. To set this up, I picked up a bunch of Yoda quotes from this page and fed them into a newly instantiated HMM. I then took the same two sentences and asked the HMM which was more probable. Here is the test code snippet to do that:

  @Test
  public void testTeachYodaAndObserveProbabilities() throws Exception {
    List<String> sentences = Arrays.asList(new String[] {
      "Powerful you have become .",
      "The dark side I sense in you .",
      "Grave danger you are in .",
      "Impatient you are .",
      "Try not .",
      "Do or do not there is no try .",
      "Consume you the dark path will .",
      "Always in motion is the future .",
      "Around the survivors a perimeter create .",
      "Size matters not .",
      "Blind we are if creation of this army we could not see .",
      "Help you I can yes .",
      "Strong am I with the force .",
      "Agree with you the council does .",
      "Your apprentice he will be .",
      "Worried I am .",
      "Always two there are .",
      "When 900 years you reach look as good you will not ."
    });
    HmmTagger hmmTagger = new HmmTagger();
    hmmTagger.setDataDir("/opt/brown-2.0");
    hmmTagger.setDictionaryLocation("src/test/resources/brown_dict.txt");
    Hmm<ObservationInteger> learnedHmm = hmmTagger.teach(sentences);
    System.out.println("P(Worried I am)=" +  
      hmmTagger.getObservationProbability("Worried I am", learnedHmm));
    System.out.println("P(I am worried)=" + 
      hmmTagger.getObservationProbability("I am worried", learnedHmm));
  }

Now, as you can see, this new HMM understands Yoda better than it understands us :-).

P(Worried I am)=4.455273233553778E-6
P(I am worried)=2.4569521508568634E-6

Conclusions

Personally, this learning curve was quite a steep one for me. The theory was fairly easy to grasp from an intuitive standpoint, but understanding how to model the POS tagging problem as an HMM took me a while. Once I crossed that hurdle, it took me a fair bit of effort to figure out how to use Jahmm to build and solve an HMM.

I think it was worth it, though. HMMs are a very powerful modeling tool for text mining, and can be used to model a variety of real life situations. Using a library such as Jahmm means that you just have to figure out how to model your problem and to solve it using the tools provided.

Hopefully, if you've been reading this far, and you started out not knowing or with only a vague idea of what an HMM was and how it could be used for POS tagging (as was my situation a couple of months ago), this post has provided some information as well as an example of using the Jahmm API to build and solve an HMM.

Update 2009-04-26: In recent posts, I have been building on code written and described in previous posts, so there were (and rightly so) quite a few requests for the code. So I've created a project on Sourceforge to host the code. You will find the complete source code built so far in the project's SVN repository.

Saturday, November 01, 2008

My First MapReduce with Hadoop

Last week, I described a Phrase Spelling corrector that depended on Word Collocation Probabilities. The probabilities (from the occur_a and occur_ab tables) were based on data that were in the ballpark of what they should be, but the fact remains that I cooked them up. So this week, I describe a Hadoop Map-Reduce class that pulls this information out of real user-entered search terms from Apache access logs.

Background

I learned about the existence of the Map-Reduce programming style about a year ago, when a colleague told us about a free Hadoop seminar hosted by Yahoo!. Unfortunately, by the time I got around to enrolling, all the seats were taken. I did read the original Map-Reduce paper by Jeffrey Dean and Sanjay Ghemawat of Google at the time, and while it seemed a neat idea, I wasn't sure where or how I would be able to use it in my work. So not being able to go to the seminar didn't seem like a huge loss.

More recently, however, as I spend more time working in and around text mining, I find myself spending more time waiting for programs to finish executing than I spend writing them. I had made a few half-hearted attempts to try to pick up Hadoop on my own, but it is a fairly steep learning curve, and I was never able to find the time to get to a point where I could model my existing algorithm as an equivalent Map-Reduce program, before having to move on to other things.

On a somewhat unrelated note, I recently also joined East Bay IT Group (EBIG) in an attempt to meet other like-minded tech professionals and to enhance my chances of landing a job closer to home. Just kidding about that last one, since nobody at EBIG seems to be working anywhere east of Oakland. So in any case, the first talk (since my joining) on the Java SIG was on Hadoop, by Owen O'Malley of Yahoo, so I figured that attending it would be a good way to quickly ramp up on Hadoop. I am happy to say that the talk was very informative (thanks Owen!) and I did get enough out of it to be able to write my own code soon after.

Specifications

The structure of our search URL is as follows:

  http://.../search?q1=term[&name=value...]
  where term: a single or multi-word query term

The idea is to run through all the access logs and count unique occurrences of single words and unique word pairs from the q1 values. These counts will later be fed into the occur_a and occur_ab tables described in my previous post.

For development, I just use the access_log files that are in my /var/log/httpd directory (just 4 of them), but the final run on a cluster (not described in this post) will use a year's worth of log files.

The MapReduce Class

Here's the code for the MapReduce class. As you can see, the Map and Reduce classes are written as inner classes of the main class. This seems to be the preferred style so I went with it, but it may be more unit-testable if you put each job into its own package and put the Map and Reduce classes inside that package. I did test the supporting classes, and did a quick test run, and things came out ok, so...

The Map class is called MapClass and the Reduce class is called ReduceClass. In addition, there is a PartitionerClass that attempts to send the pair output from the Map to one Reducer and the singleton output to another, so that they are "sorted" in the final output. Apparently, however, you cannot have more than one Reducer in a non-clustered environment, so a Partitioner cannot route Map output to a second Reducer (because it does not exist). That is why the PartitionerClass is defined but commented out in the JobConf settings.

Once the Map and Reduce classes are defined, the main method sets up a JobConf and sets the Map and Reduce classes into it. The framework takes care of the rest. Basically the input is read line by line and passed into a bank of available Map classes. The outputs are accumulated into temporary file(s). Once all Map classes are done, the pairs written by the Map classes are sent to a bank of Reducers, which accumulate them. Once all the Reducers are done, the program ends.

// Source: src/main/java/com/mycompany/accessloganalyzer/AccessLogAnalyzer.java
package com.mycompany.accessloganalyzer;

import java.io.IOException;
import java.util.Arrays;
import java.util.EnumMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;

import org.apache.commons.lang.StringUtils;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapred.FileInputFormat;
import org.apache.hadoop.mapred.FileOutputFormat;
import org.apache.hadoop.mapred.JobClient;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.MapReduceBase;
import org.apache.hadoop.mapred.Mapper;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.Partitioner;
import org.apache.hadoop.mapred.Reducer;
import org.apache.hadoop.mapred.Reporter;

import com.mycompany.accessloganalyzer.NcsaLogParser.NcsaLog;

public class AccessLogAnalyzer {

  private static class MapClass extends MapReduceBase 
    implements Mapper<WritableComparable<Text>,Writable,
                      WritableComparable<Text>,Writable> {

    public void map(WritableComparable<Text> key,Writable value,
        OutputCollector<WritableComparable<Text>,Writable> output,
        Reporter reporter) throws IOException {
      String line = ((Text) value).toString();
      EnumMap<NcsaLog,String> values = NcsaLogParser.parse(line);
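      String url = values.get(NcsaLog.REQUEST_URL);
      // url is null if the log line could not be parsed
      if (url != null && url.startsWith("/search")) {
        Map<String,String> parameters = 
          NcsaLogParser.getUrlParameters(url);
        String searchTerms = parameters.get("q1");
        // StringUtils.split returns null for null input, ie, when the
        // request has no q1 parameter, so skip such requests
        String[] terms = StringUtils.split(searchTerms, " ");
        if (terms == null) {
          return;
        }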
        for (String term : terms) {
          output.collect(new Text(term), new LongWritable(1));
        }
        if (terms.length > 1) {
          // need to have at least 2 words to generate pair-wise combinations
          CombinationGenerator combinationGenerator = 
            new CombinationGenerator(terms.length, 2);
          Set<Pair> combinations = new HashSet<Pair>();
          while (combinationGenerator.hasMore()) {
            int[] indices = combinationGenerator.getNext();
            combinations.add(new Pair(terms[indices[0]], terms[indices[1]]));
          }
          for (Pair combination : combinations) {
            output.collect(new Text(combination.toString()), 
              new LongWritable(1));
          }
        }
      }
    }
  }
  
  private static class ReduceClass extends MapReduceBase 
    implements Reducer<WritableComparable<Text>,Writable,
                       WritableComparable<Text>,Writable> {

    public void reduce(WritableComparable<Text> key, 
        Iterator<Writable> values,
        OutputCollector<WritableComparable<Text>,Writable> output,
        Reporter reporter) throws IOException {
      long occurs = 0;
      while (values.hasNext()) {
        occurs += ((LongWritable) values.next()).get();
      }
      output.collect(key, new LongWritable(occurs));
    }
  }

  private static class PartitionerClass  
    implements Partitioner<WritableComparable<Text>,Writable> {

    public void configure(JobConf conf) { /* NOOP */ }

    public int getPartition(WritableComparable<Text> key, Writable value, 
        int numReduceTasks) {
      if (numReduceTasks > 1) {
        String k = ((Text) key).toString();
        return (k.contains(",") ? 1 : 0);
      }
      return 0;
    }
  }
  
  static class Pair {
    public String first;
    public String second;

    public Pair(String first, String second) {
      String[] pair = new String[] {first, second};
      Arrays.sort(pair);
      this.first = pair[0];
      this.second = pair[1];
    }
    
    @Override
    public int hashCode() {
      return toString().hashCode();
    }
    
    @Override
    public boolean equals(Object obj) {
      if (!(obj instanceof Pair)) {
        return false;
      }
      Pair that = (Pair) obj;
      return (this.first.equals(that.first) &&
        this.second.equals(that.second));
    }
    
    @Override
    public String toString() {
      return StringUtils.join(new String[] {first, second}, ",");
    }
  }

  public static void main(String[] argv) throws IOException {
    if (argv.length != 2) {
      System.err.println("Usage: calc input_path output_path");
      System.exit(-1);
    }
    
    JobConf jobConf = new JobConf(AccessLogAnalyzer.class);
    
    FileInputFormat.addInputPath(jobConf, new Path(argv[0]));
    FileOutputFormat.setOutputPath(jobConf, new Path(argv[1]));
    
    jobConf.setOutputKeyClass(Text.class);
    jobConf.setOutputValueClass(LongWritable.class);
    
    jobConf.setMapperClass(MapClass.class);
    jobConf.setCombinerClass(ReduceClass.class);
    jobConf.setReducerClass(ReduceClass.class);
//    jobConf.setPartitionerClass(PartitionerClass.class);
    
    jobConf.setNumReduceTasks(2);
    
    JobClient.runJob(jobConf);
  }
}
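
The job above registers ReduceClass as both the combiner and the reducer. A small in-memory sketch, with no Hadoop involved, may help show why that is safe for this job: summing is associative and commutative, so pre-summing each map task's output locally and then summing again gives the same totals. Everything here (the class CombineThenReduce, the two pretend map tasks and their keys) is hypothetical and only imitates the data flow.

```java
import java.util.AbstractMap.SimpleEntry;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

class CombineThenReduce {

  // Sums the counts for each key. The same function plays both the
  // combiner role (per map task) and the reducer role (over everything).
  static Map<String, Long> sumByKey(List<Map.Entry<String, Long>> records) {
    Map<String, Long> totals = new TreeMap<String, Long>();
    for (Map.Entry<String, Long> record : records) {
      Long soFar = totals.get(record.getKey());
      totals.put(record.getKey(),
        (soFar == null ? 0L : soFar) + record.getValue());
    }
    return totals;
  }

  // Emits (key, 1) for every key, the way the mapper does.
  static List<Map.Entry<String, Long>> emitOnes(String... keys) {
    List<Map.Entry<String, Long>> out =
      new ArrayList<Map.Entry<String, Long>>();
    for (String key : keys) {
      out.add(new SimpleEntry<String, Long>(key, 1L));
    }
    return out;
  }

  static Map<String, Long> run() {
    // Two pretend map tasks, each with its own local output.
    List<Map.Entry<String, Long>> task0 =
      emitOnes("breast", "cancer", "breast,cancer", "cancer");
    List<Map.Entry<String, Long>> task1 = emitOnes("cancer", "asthma");
    // Combiner pass: collapse each task's output before it is shuffled.
    List<Map.Entry<String, Long>> shuffled =
      new ArrayList<Map.Entry<String, Long>>();
    shuffled.addAll(sumByKey(task0).entrySet());
    shuffled.addAll(sumByKey(task1).entrySet());
    // Reducer pass: sum the partial counts into final totals.
    return sumByKey(shuffled);
  }

  public static void main(String[] args) {
    System.out.println(run());
  }
}
```

A reducer that computed something like an average could not be reused as a combiner this way, since an average of partial averages is generally not the overall average.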

Supporting Classes

NCSA Log Parser

I went looking for an NCSA Log Parser but couldn't find one, so I wrote my own. I tried to make it generic, since I will probably be re-using this parser to pull other stuff out of the logs in the future. The parser described below parses the NCSA Common Log file format, which is the smallest of the three NCSA Log formats. Here is the code:

// Source: src/main/java/com/mycompany/accessloganalyzer/NcsaLogParser.java
package com.mycompany.accessloganalyzer;

import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.net.URLDecoder;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.Map;
import java.util.StringTokenizer;

import org.apache.commons.lang.StringUtils;

public class NcsaLogParser {
  
  public enum NcsaLog {
    HOST,
    PROTOCOL,
    USERNAME,
    DATE,
    TIME,
    TIMEZONE,
    REQUEST_METHOD,
    REQUEST_URL,
    REQUEST_PROTOCOL,
    STATUS_CODE,
    BYTE_COUNT
  };
  
  public static EnumMap<NcsaLog,String> parse(String logline) {
    EnumMap<NcsaLog,String> values = 
      new EnumMap<NcsaLog, String>(NcsaLog.class);
    StringTokenizer tok = new StringTokenizer(logline, " ");
    if (tok.hasMoreTokens()) {
      values.put(NcsaLog.HOST, tok.nextToken());
      values.put(NcsaLog.PROTOCOL, tok.nextToken());
      values.put(NcsaLog.USERNAME, tok.nextToken());
      String dttm = tok.nextToken();
      values.put(NcsaLog.DATE, dttm.substring(1, dttm.indexOf(':')));
      values.put(NcsaLog.TIME, dttm.substring(dttm.indexOf(':') + 1));
      String tz = tok.nextToken();
      values.put(NcsaLog.TIMEZONE, tz.substring(0, tz.length() - 1));
      String requestMethod = tok.nextToken();
      values.put(NcsaLog.REQUEST_METHOD, requestMethod.substring(1));
      values.put(NcsaLog.REQUEST_URL, tok.nextToken());
      String requestProtocol = tok.nextToken();
      values.put(NcsaLog.REQUEST_PROTOCOL, 
        requestProtocol.substring(0, requestProtocol.length() - 1));
      values.put(NcsaLog.STATUS_CODE, tok.nextToken());
      values.put(NcsaLog.BYTE_COUNT, tok.nextToken());
    }
    return values;
  }
  
  public static Map<String,String> getUrlParameters(String url) throws IOException {
    Map<String,String> parameters = new HashMap<String,String>();
    int pos = url.indexOf('?');
    if (pos == -1) {
      return parameters;
    }
    String queryString = url.substring(pos + 1);
    String[] nvps = queryString.split("&");
    for (String nvp : nvps) {
      String[] pair = nvp.split("=");
      if (pair.length != 2) {
        continue;
      }
      String key = pair[0];
      String value = pair[1];
      // URL decode the value, replacing + and %20 etc chars with their
      // non-encoded equivalents.
      try {
        value = URLDecoder.decode(value, "UTF-8");
      } catch (UnsupportedEncodingException e) {
        throw new IOException("Unsupported encoding", e);
      }
      // replace all punctuation by space
      value = value.replaceAll("\\p{Punct}", " ");
      // lowercase it
      value = StringUtils.lowerCase(value);
      parameters.put(key, value); 
    }
    return parameters;
  }
}
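
As an aside, the same Common Log Format line can also be picked apart with a single regular expression instead of a StringTokenizer. The sketch below is an alternative technique, not the parser used in the job. The class name ClfRegexSketch and the sample log line are made up for illustration.

```java
import java.util.regex.Matcher;
import java.util.regex.Pattern;

class ClfRegexSketch {

  // host ident user [date:time zone] "method url protocol" status bytes
  private static final Pattern CLF = Pattern.compile(
    "^(\\S+) (\\S+) (\\S+) \\[([^:]+):(\\S+) ([^\\]]+)\\] "
    + "\"(\\S+) (\\S+) (\\S+)\" (\\d{3}) (\\S+)$");

  // Returns the request URL, or null if the line does not match the format.
  static String requestUrl(String line) {
    Matcher m = CLF.matcher(line);
    return m.matches() ? m.group(8) : null;
  }

  public static void main(String[] args) {
    // A made-up log line in Common Log Format.
    String line = "127.0.0.1 - - [01/Nov/2008:10:15:30 -0700] "
      + "\"GET /search?q1=breast+cancer HTTP/1.1\" 200 2326";
    System.out.println(requestUrl(line));
  }
}
```

One practical difference is that a line which does not fit the format simply fails to match and yields null, rather than running out of tokens part way through.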

Combination Generator

I needed a way to enumerate all pairs of words I find in a multi-word phrase. Michael Gilleland has already written one that works great, so all I did was to just copy this into my own package structure and use it. You can read/snag the code from Michael's site.
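
For the special case of pairs, which is all this job needs, the enumeration is small enough to sketch with two nested loops. This is not Michael's CombinationGenerator and does not share its API. The class name PairIndices is an assumption made for this example.

```java
import java.util.ArrayList;
import java.util.List;

class PairIndices {

  // All index pairs (i, j) with i < j over n items, n * (n - 1) / 2 in all.
  static List<int[]> pairs(int n) {
    List<int[]> out = new ArrayList<int[]>();
    for (int i = 0; i < n - 1; i++) {
      for (int j = i + 1; j < n; j++) {
        out.add(new int[] {i, j});
      }
    }
    return out;
  }

  public static void main(String[] args) {
    String[] terms = {"breast", "cancer", "symptoms"};
    for (int[] pair : pairs(terms.length)) {
      System.out.println(terms[pair[0]] + "," + terms[pair[1]]);
    }
  }
}
```

A general generator is still the better choice if you ever need combinations of more than two words.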

Packaging

Hadoop needs the classes packaged a certain way. Along with the compiled classes, you also want to add in any runtime dependency JAR files in a lib/ directory. You can optionally supply a MANIFEST.MF file specifying the Main-Class if you want to use the java -jar calling style. Since I use Maven, but don't really know of an easy way to write ad-hoc scripts to build new Maven goals, I decided to generate an Ant build.xml using Maven, and then write a new target.

  1. To generate the Ant build.xml, run mvn ant:ant.
  2. Add hadoop.jar to build.classpath
  3. Add definitions for input and output directories for the job
  4. Add the hadoop build and run target (shown below).

My target to package and run the jar is shown below. In the future, when I run this on a remote cluster, I will decouple the two steps.

  <target name="hadoop-log-analyzer" 
      description="Launch AccessLogAnalyzer job on Hadoop" depends="compile">
    <!-- create new directory target/lib and copy required runtime
         dependencies for the hadoop job into it -->
    <delete dir="${maven.build.directory}/jars/lib"/>
    <mkdir dir="${maven.build.directory}/jars/lib"/>
    <copy todir="${maven.build.directory}/jars/lib" flatten="true">
      <fileset dir="${maven.repo.local}">
        <include name="commons-lang/commons-lang/2.1/commons-lang-2.1.jar"/>
      </fileset>
    </copy>
    <!-- create jar file with classes and libraries -->
    <jar jarfile="${maven.build.directory}/log-analyzer.jar">
      <fileset dir="${maven.build.directory}/classes"/>
      <fileset dir="${maven.build.directory}/jars"/>
      <manifest>
        <attribute name="Main-Class"
          value="com.mycompany.accessloganalyzer.AccessLogAnalyzer"/>
      </manifest>
    </jar>
    <!-- clean up output directory -->
    <delete dir="${basedir}/src/main/resources/access_log_outputs"/>
    <!-- run jar in hadoop -->
    <exec executable="bin/hadoop" dir="/opt/hadoop-0.18.1">
      <arg value="jar"/>
      <arg value="${basedir}/target/log-analyzer.jar"/>
      <arg value="${basedir}/src/main/resources/access_logs"/>
      <arg value="${basedir}/src/main/resources/access_log_outputs"/>
    </exec>
  </target>

My only runtime dependency was the commons-lang JAR, which the target above copies into the lib/ directory of the package.

(Local) Dev Testing

Unlike Owen, my other computer is not a data center. In fact, unless you count my work desktop, my laptop is the only computer I have. So I need to be able to test the job on my laptop first. Here is what I had to do.

  1. Explode the Hadoop tarball into a local directory. My hadoop directory (HADOOP_HOME) is /opt/hadoop-0.18.1.
  2. In the $HADOOP_HOME/conf/hadoop-env.sh file, update JAVA_HOME to point to whatever it is on your machine. The export is commented out, uncomment and update.
  3. Run the package using the Ant target

Here is the output of running ant hadoop-log-analyzer from the command line (edited for readability by removing dates from the logs).

hadoop-log-analyzer:
     [exec] jvm.JvmMetrics: Initializing JVM Metrics with \
       processName=JobTracker, sessionId=
     [exec] mapred.FileInputFormat: Total input paths to process : 5
     [exec] mapred.JobClient: Running job: job_local_1
     [exec] mapred.MapTask: numReduceTasks: 1
     [exec] mapred.LocalJobRunner: \
       file:~/myapp/src/main/resources/access_logs/access_log.3:0+118794
     [exec] mapred.TaskRunner: Task 'map_0000' done.
     [exec] mapred.MapTask: numReduceTasks: 1
     [exec] mapred.LocalJobRunner: \
       file:~/myapp/src/main/resources/access_logs/access_log.4:0+3811
     [exec] mapred.TaskRunner: Task 'map_0001' done.
     [exec] mapred.MapTask: numReduceTasks: 1
     [exec] mapred.LocalJobRunner: \
       file:~/myapp/src/main/resources/access_logs/access_log:0+446816
     [exec] mapred.TaskRunner: Task 'map_0002' done.
     [exec] mapred.MapTask: numReduceTasks: 1
     [exec] mapred.LocalJobRunner: \
       file:~/myapp/src/main/resources/access_logs/access_log.2:0+99752
     [exec] mapred.TaskRunner: Task 'map_0003' done.
     [exec] mapred.MapTask: numReduceTasks: 1
     [exec] mapred.LocalJobRunner: \
       file:~/myapp/src/main/resources/access_logs/access_log.1:0+36810
     [exec] mapred.TaskRunner: Task 'map_0004' done.
     [exec] mapred.LocalJobRunner: reduce > reduce
     [exec] mapred.TaskRunner: Task 'reduce_mglk8q' done.
     [exec] mapred.TaskRunner: Saved output of task \
       'reduce_mglk8q' to file:~/myapp/src/main/resources/access_log_outputs
     [exec] mapred.JobClient: Job complete: job_local_1
     [exec] mapred.JobClient: Counters: 9
     [exec] mapred.JobClient:   Map-Reduce Framework
     [exec] mapred.JobClient:     Map input records=2661
     [exec] mapred.JobClient:     Map output records=186
     [exec] mapred.JobClient:     Map input bytes=705983
     [exec] mapred.JobClient:     Map output bytes=3381
     [exec] mapred.JobClient:     Combine input records=186
     [exec] mapred.JobClient:     Combine output records=34
     [exec] mapred.JobClient:     Reduce input groups=31
     [exec] mapred.JobClient:     Reduce input records=34
     [exec] mapred.JobClient:     Reduce output records=31

Output is in the output directory in a file called part-00000, and here are a few lines from it to show what it looks like. In a full clustered system, we would be able to use the Partitioner to send the pairs and the singletons to two different Reducers, so they would end up in separate output files.

asthma  6
breast  7
breast,cancer   7
cancer  18
diabetes        3
diaries 1
diaries,headache        1
disease 6
...
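
Since the singletons and the pairs end up interleaved in a single file here, a small post-processing step can separate them before they are loaded into the occur_a and occur_ab tables. The sketch below assumes tab-separated key and count columns and uses the comma in the key to tell pairs from singletons, the same test the PartitionerClass uses. The class name OutputSplitter is hypothetical, and the actual database load is left out.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

class OutputSplitter {

  // Routes each "key<TAB>count" line to the singles or the pairs map.
  static void split(List<String> lines, Map<String, Long> singles,
      Map<String, Long> pairs) {
    for (String line : lines) {
      String[] cols = line.split("\t");
      if (cols.length != 2) {
        continue; // skip blank or malformed lines
      }
      long count = Long.parseLong(cols[1].trim());
      if (cols[0].contains(",")) {
        pairs.put(cols[0], count);
      } else {
        singles.put(cols[0], count);
      }
    }
  }

  public static void main(String[] args) {
    // A few lines in the shape of the job output shown above.
    List<String> lines = Arrays.asList(
      "breast\t7", "breast,cancer\t7", "cancer\t18");
    Map<String, Long> singles = new TreeMap<String, Long>();
    Map<String, Long> pairs = new TreeMap<String, Long>();
    split(lines, singles, pairs);
    System.out.println("singles: " + singles);
    System.out.println("pairs: " + pairs);
  }
}
```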

So that's pretty much it. The MapReduce style is a very powerful mechanism that allows average developers with domain expertise to write code that can be run within a framework, such as Hadoop, on large clusters to parallelize the computation. So it's worth knowing, especially if you need to write batch programs that run on large data sets. I believe that I now understand enough about Hadoop to be able to use it effectively, and have reached a stage where I can pick up what I don't know. I hope this article has helped you in a similar way as well.