Showing posts with label hidden-markov-model. Show all posts

Saturday, August 08, 2020

Disambiguating SciSpacy + UMLS entities using the Viterbi algorithm

The SciSpacy project from AllenAI provides a language model trained on biomedical text, which can be used for Named Entity Recognition (NER) of biomedical entities using the standard SpaCy API. SpaCy's language models (at least the English one) assign entities types such as PERSON, GPE, ORG, etc. SciSpacy entities, by contrast, all have the single type ENTITY. To classify them further, SciSpacy provides Named Entity Linking (NEL) functionality through its integration with various ontology providers, such as the Unified Medical Language System (UMLS), Medical Subject Headings (MeSH), RxNorm, Gene Ontology (GO), and Human Phenotype Ontology (HPO).


The NER and NEL processes are decoupled. The NER process finds candidate entity spans, and these spans are matched against the respective ontologies, which may result in each span matching zero or more ontology entries. Each candidate span is then linked to all of the entries it matched.

I tried annotating the COVID-19 Open Research Dataset (CORD-19) against UMLS using the SciSpacy integration described above, and I noticed significant ambiguity in the linking results. Specifically, annotating approximately 22 million sentences in the CORD-19 dataset results in 113 million candidate entity spans, which get linked to 166 million UMLS concepts, i.e., on average, each candidate span resolves to about 1.5 UMLS concepts. However, the distribution is Zipfian: approximately 46.87% of entity spans resolve to a single concept, with a long tail of entity spans linked to as many as 67 UMLS concepts.

In this post, I will describe a strategy to disambiguate the linked entities. Based on limited testing, this chooses the correct concept about 73% of the time. 

The strategy is based on the intuition that an ambiguously linked entity span is more likely to resolve to a concept that is closely related to the concepts for the other, non-ambiguously linked entity spans in the sentence. In other words, the best target label to choose for an ambiguous entity is the one that is semantically closest to the labels of the other entities in the sentence. Or, even more succinctly, and with apologies to John Firth, an entity is known by the company it keeps.

The NER and NEL processes provided by the SciSpacy library allow us to reduce a sentence to a collection of entity spans, each of which maps to zero or more UMLS concepts. Each UMLS concept maps to one or more Semantic Types, which represent high-level subject categories. So essentially, a sentence can be reduced to a graph of semantic types using the following steps.

Consider the sentence below, in which the NER step identifies six candidate spans: staining, antibodies, cat, antigens, binding, and antibodies.
The fact that viral antigens could not be demonstrated with the used staining is not the result of antibodies present in the cat that already bound to these antigens and hinder binding of other antibodies.
The NEL step then attempts to match these spans against the UMLS ontology. The results of the matching are shown below. As noted earlier, each UMLS concept maps to one or more semantic types, and these are shown as well.
   
Entity-ID | Entity Span | Concept-ID | Concept Primary Name | Semantic Type Code | Semantic Type Name
1 | staining | C0487602 | Staining method | T059 | Laboratory Procedure
2 | antibodies | C0003241 | Antibodies | T116 | Amino Acid, Peptide, or Protein
  |  |  |  | T129 | Immunologic Factor
3 | cat | C0007450 | Felis catus | T015 | Mammal
  |  | C0008169 | Chloramphenicol O-Acetyltransferase | T116 | Amino Acid, Peptide, or Protein
  |  |  |  | T126 | Enzyme
  |  | C0325089 | Family Felidae | T015 | Mammal
  |  | C1366498 | Chloramphenicol Acetyl Transferase Gene | T028 | Gene or Genome
4 | antigens | C0003320 | Antigens | T129 | Immunologic Factor
5 | binding | C1145667 | Binding action | T052 | Activity
  |  | C1167622 | Binding (Molecular Function) | T044 | Molecular Function
6 | antibodies | C0003241 | Antibodies | T116 | Amino Acid, Peptide, or Protein
  |  |  |  | T129 | Immunologic Factor

The sequence of entity spans, each mapped to one or more semantic type codes, can be represented by a graph of semantic type nodes as shown below. Here, each vertical grouping corresponds to an entity position. The BOS node is a special node representing the beginning of the sequence. Based on our intuition above, entity disambiguation is now just a matter of finding the most likely path through the graph.
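A minimal sketch of that reduction, assuming each linked span has already been turned into a list of (concept ID, semantic types) pairs. The input format and the name `build_lattice` are hypothetical simplifications; with SciSpacy, the concept IDs would come from each entity's linker output (for example `ent._.kb_ents` in recent versions) and the semantic types from the linker's knowledge base.

```python
def build_lattice(linked_spans):
    """Reduce linked entity spans to a lattice of candidate semantic types.

    linked_spans: list of (span_text, [(concept_id, [semantic_type, ...]), ...]),
    a hypothetical simplified form of the entity-linker output.
    Returns one de-duplicated list of semantic types per span, in order.
    """
    lattice = []
    for _span_text, concepts in linked_spans:
        types = []
        for _concept_id, sem_types in concepts:
            for sem_type in sem_types:
                if sem_type not in types:
                    types.append(sem_type)
        lattice.append(types)
    return lattice


# Two of the spans from the table above.
spans = [
    ("antibodies", [("C0003241", ["T116", "T129"])]),
    ("cat", [("C0007450", ["T015"]), ("C0008169", ["T116", "T126"]),
             ("C0325089", ["T015"]), ("C1366498", ["T028"])]),
]
print(build_lattice(spans))  # [['T116', 'T129'], ['T015', 'T116', 'T126', 'T028']]
```

Each inner list is one vertical grouping of the graph; concepts that share a semantic type (Felis catus and Family Felidae are both T015) collapse into a single node.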



Of course, "most likely" implies that we need to know the probabilities for transitioning between semantic types. We can think of the graph as a Markov Chain, and consider the probability of each node in the graph as being determined only by its previous node. Fortunately, this information is already available as a result of the NER + NEL process for the entire CORD-19 dataset, where approximately half of the entity spans mapped unambiguously to a single UMLS concept. Most concepts map to a single semantic type, but in cases where they map to multiple, we consider them as separate records. We compute pairwise transition probabilities across semantic types for these unambiguously linked pairs across the CORD-19 dataset and create our transition matrix. In addition, we also create a matrix of emission probabilities that identify the probabilities of resolving to a concept given a semantic type. 
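The counting step can be sketched as follows. This is a minimal illustration, not the author's gist: the function names are hypothetical, the input is assumed to be already filtered down to unambiguously linked spans, and the post does not say how unseen transitions were handled, so the add-one smoothing is my assumption.

```python
import math
from collections import Counter, defaultdict

BOS = "BOS"  # beginning-of-sequence node


def estimate_transitions(sentences):
    """Log P(next type | previous type) from sentences of unambiguously
    linked spans, each given as a list of semantic type codes.
    Add-one smoothing over the observed types is an assumption."""
    counts = defaultdict(Counter)
    types = set()
    for sent in sentences:
        prev = BOS
        for sem_type in sent:
            counts[prev][sem_type] += 1
            types.add(sem_type)
            prev = sem_type
    return {
        prev: {t: math.log((nexts[t] + 1) / (sum(nexts.values()) + len(types)))
               for t in types}
        for prev, nexts in counts.items()
    }


def estimate_emissions(pairs):
    """P(concept | semantic type) from (semantic type, concept ID) pairs."""
    counts = defaultdict(Counter)
    for sem_type, cui in pairs:
        counts[sem_type][cui] += 1
    return {t: {cui: n / sum(c.values()) for cui, n in c.items()}
            for t, c in counts.items()}


trans = estimate_transitions([["T059", "T116", "T129"], ["T116", "T129"]])
print(math.exp(trans["T116"]["T129"]))  # (2 + 1) / (2 + 3)
```

Storing log-probabilities up front means the path search below only has to add numbers.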

Using the transition probabilities, we can traverse each path in the graph from starting to ending position, computing the path probability as the product of transition probabilities (or for computational reasons, the sum of log-probabilities) of the edges. However, better methods exist, such as the Viterbi algorithm, which allows us to save on repeated computation of common edge sequences across multiple paths. This is what we used to compute the most likely path through our semantic type graph. 
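For comparison, the exhaustive approach can be sketched as below. Even the six positions of the example sentence give 1 × 2 × 4 × 1 × 2 × 2 = 32 paths, and the count multiplies with every additional ambiguous span. The dictionary layout of `log_trans` and the floor log-probability used for unseen transitions are assumptions for illustration.

```python
import itertools
import math

# Candidate semantic types per entity position, from the table above.
lattice = [
    ["T059"],                          # staining
    ["T116", "T129"],                  # antibodies
    ["T015", "T116", "T126", "T028"],  # cat
    ["T129"],                          # antigens
    ["T052", "T044"],                  # binding
    ["T116", "T129"],                  # antibodies
]


def path_log_prob(path, log_trans, start="BOS", floor=math.log(1e-6)):
    """Sum of log transition probabilities along start -> path[0] -> ... ."""
    score, prev = 0.0, start
    for sem_type in path:
        score += log_trans.get(prev, {}).get(sem_type, floor)
        prev = sem_type
    return score


def brute_force_best(lattice, log_trans):
    """Score every path through the lattice; return the best one and the count."""
    paths = list(itertools.product(*lattice))
    best = max(paths, key=lambda p: path_log_prob(p, log_trans))
    return best, len(paths)


best, n_paths = brute_force_best(lattice, {})
print(n_paths)  # 32
```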

The Viterbi algorithm consists of two phases: forward and backward. In the forward phase, we move left to right, computing at each step the log-probability of the best path into each node, as shown by the vectors below each position in the figure. When multiple nodes transition into a single node (such as the transitions from [T129, T116] to [T126]), we compute the score for each incoming path and keep the maximum.

In the backward phase, we move from right to left, choosing the maximum-probability node at each step. This is shown in the figure as boxed entries. We can then look up the appropriate semantic type and return the most likely sequence of semantic types (shown in bold at the bottom of the figure).
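A minimal sketch of both phases over such a lattice. It uses the textbook form of Viterbi, recording a back-pointer for every node in the forward phase and following those pointers from the best final node. The `log_trans` layout and the floor value for unseen transitions are my assumptions, not necessarily how the author's gist does it.

```python
import math


def viterbi(lattice, log_trans, start="BOS", floor=math.log(1e-6)):
    """Most likely path of semantic types through the lattice.

    lattice: list of candidate-type lists, one per entity position.
    log_trans: dict prev_type -> {next_type: log P(next | prev)}.
    """
    def lp(prev, nxt):
        return log_trans.get(prev, {}).get(nxt, floor)

    # Forward phase: best score of any path ending at each node, plus back-pointer.
    scores = [{t: lp(start, t) for t in lattice[0]}]
    backptrs = [{t: None for t in lattice[0]}]
    for pos in range(1, len(lattice)):
        cur_scores, cur_back = {}, {}
        for t in lattice[pos]:
            # Several nodes may lead into t; keep only the best incoming one.
            prev = max(scores[-1], key=lambda p: scores[-1][p] + lp(p, t))
            cur_scores[t] = scores[-1][prev] + lp(prev, t)
            cur_back[t] = prev
        scores.append(cur_scores)
        backptrs.append(cur_back)

    # Backward phase: start from the best final node and follow back-pointers.
    last = max(scores[-1], key=scores[-1].get)
    path = [last]
    for pos in range(len(lattice) - 1, 0, -1):
        path.append(backptrs[pos][path[-1]])
    path.reverse()
    return path, scores[-1][last]


toy_lattice = [["A", "B"], ["C", "D"]]
toy_trans = {
    "BOS": {"A": math.log(0.9), "B": math.log(0.1)},
    "A": {"C": math.log(0.2), "D": math.log(0.8)},
    "B": {"C": math.log(0.9), "D": math.log(0.1)},
}
path, score = viterbi(toy_lattice, toy_trans)
print(path)  # ['A', 'D']
```

The work per position is proportional to the number of node pairs between adjacent positions, instead of the product of all group sizes that the exhaustive search pays.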

However, our objective is to return disambiguated concept linkages for entities. Given a disambiguated semantic type and multiple possibilities indicated by SciSpacy's linking process, we use the emission probabilities to choose the most likely concept to apply at the position. The result for our example is shown in the table below.
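Once the path fixes a semantic type for each position, picking the concept reduces to an argmax over the linker's candidates that carry that type. A minimal sketch, with a hypothetical candidate format and made-up emission numbers in the usage:

```python
def choose_concept(candidates, sem_type, emissions):
    """Return the candidate concept ID with the highest P(concept | sem_type).

    candidates: list of (concept_id, [semantic_type, ...]) for one span.
    emissions: dict semantic_type -> {concept_id: probability}.
    Only candidates that carry sem_type are considered; returns None if none do.
    """
    scored = [
        (emissions.get(sem_type, {}).get(cui, 0.0), cui)
        for cui, types in candidates
        if sem_type in types
    ]
    return max(scored)[1] if scored else None


cat = [("C0007450", ["T015"]), ("C0008169", ["T116", "T126"]),
       ("C0325089", ["T015"]), ("C1366498", ["T028"])]
emissions = {"T015": {"C0007450": 0.7, "C0325089": 0.3}}  # made-up numbers
print(choose_concept(cat, "T015", emissions))  # C0007450
```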

Entity-ID | Entity Span | Concept-ID | Concept Primary Name | Semantic Type Code | Semantic Type Name | Correct?
1 | staining | C0487602 | Staining method | T059 | Laboratory Procedure | N/A*
2 | antibodies | C0003241 | Antibodies | T116 | Amino Acid, Peptide, or Protein | Yes
3 | cat | C0008169 | Chloramphenicol O-Acetyltransferase | T116 | Amino Acid, Peptide, or Protein | No
4 | antigens | C0003320 | Antigens | T129 | Immunologic Factor | N/A*
5 | binding | C1145667 | Binding action | T052 | Activity | Yes
6 | antibodies | C0003241 | Antibodies | T116 | Amino Acid, Peptide, or Protein | Yes
(* N/A: non-ambiguous mappings)

I thought this technique might be interesting to share, hence this post. In the spirit of reproducibility, I have also provided the following artifacts for your convenience.
  1. Code: This GitHub gist contains code that illustrates NER + NEL on an input sentence using SciSpacy and its UMLS integration, and then applies my adaptation of the Viterbi method (as described in this post) to disambiguate ambiguous entity linkages.
  2. Data: I have also provided the transition and emission matrices, and their associated lookup tables, as these can be time-consuming to generate from scratch from the CORD-19 dataset.
As always, I appreciate your feedback. Please let me know if you find flaws with my approach, and/or if you know of a better approach for entity disambiguation.

Saturday, March 30, 2013

A HMM based Gene Tagger using NLTK


In Prof. Michael Collins' Natural Language Processing class on Coursera, the first programming assignment consisted of building a Gene Named Entity Recognizer using a Hidden Markov Model. The training set consists of 13,795 sentences in which words that are part of gene names are marked with an I tag and all other words are marked with an O tag. Also provided were a tagged validation set of 509 sentences to test the algorithm against, and a test set of 510 sentences to tag.

While the Coursera Honor code prevents students from sharing solutions (including programming assignments), this post does not qualify as one. First, the assignment expressly forbids the use of NLTK, since the objective of the assignment is to implement the Viterbi algorithm from scratch and to experiment with different tweaks to make it better. Second, I wasn't able to implement the transition probabilities for trigram backoff properly, although the problem is very likely my imperfect understanding of the NLTK API. So you are better off doing the assignment from scratch as required.

However, I think that a lot of stuff in the NLP/ML domain is applied NLP/ML, so knowing how to solve something with an existing library/toolkit is as important as knowing (and being able to implement) the algorithm itself. I thought it would be interesting to use NLTK to build the tagger, hence this post.

Interestingly, the Masters in Language Technology program at the University of Gothenburg, Sweden, uses NLTK for its lab work, one of which provided me with insights on how to use NLTK's HMM package. It's good to know that there are others with the same opinion as me :-).

An HMM is a function of three probability distributions: the prior probabilities, which describe the probabilities of seeing the different tags in the data; the transition probabilities, which define the probability of seeing a tag conditioned on the previous tag; and the emission probabilities, which define the probability of seeing a word conditioned on a tag. An example may make this clearer.
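To make the three distributions concrete, here is a minimal sketch that estimates them by counting over tagged sentences. It uses plain Python Counters instead of NLTK's FreqDist classes, and the function name and toy sentence are hypothetical.

```python
from collections import Counter, defaultdict


def hmm_distributions(tagged_sents):
    """Maximum-likelihood estimates of priors P(tag), transitions
    P(tag | previous tag) and emissions P(word | tag)."""
    priors, trans, emits = Counter(), defaultdict(Counter), defaultdict(Counter)
    for sent in tagged_sents:
        prev = None
        for word, tag in sent:
            priors[tag] += 1
            emits[tag][word] += 1
            if prev is not None:
                trans[prev][tag] += 1
            prev = tag

    def normalize(counter):
        total = sum(counter.values())
        return {k: v / total for k, v in counter.items()}

    return (normalize(priors),
            {t: normalize(c) for t, c in trans.items()},
            {t: normalize(c) for t, c in emits.items()})


priors, trans, emits = hmm_distributions([[("the", "O"), ("BRCA1", "I"), ("gene", "O")]])
print(emits["O"])  # {'the': 0.5, 'gene': 0.5}
```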

The assignment starts off by asking us to construct a simple HMM for the training data, measure its performance using the validation data, and finally apply the model to predict the IO tags for the test data. We then identify rare words (those occurring fewer than 5 times in the training data), replace them with the string "_RARE_", and measure the performance. The next step is to identify classes of rare words, such as numerics, all caps, etc., replace them with "_RARE_xx_" strings, and measure the performance once again. Finally, the model is modified to work with transition probabilities that are conditioned on the two previous tags instead of one, i.e., the trigram model.

The last part is where I had problems. My initial approach was to convert everything to bigrams, so that the priors were the probabilities of seeing the different tag pairs in the data, the transition probabilities were the probabilities of seeing a tag bigram (t2,t3) conditioned on the previous tag bigram (t1,t2), and the emission probabilities were the probabilities of seeing a word bigram (w1,w2) conditioned on the corresponding tag pair (t1,t2). Mikhail Korobov, one of my classmates in Prof. Collins' course, was kind enough to point me to this paper (PDF), which uses a similar approach.

Because the transition probabilities are now sparser, the precision and recall numbers went down pretty drastically. I then attempted to build a conditional probability distribution for the transition probabilities that used a backoff model, i.e., if the MLE for the trigram probability P(t3|t1,t2) was not available, fall back to the MLE for the bigram probability P(t3|t2), and finally to the MLE for the unigram probability N(t3)/N. Unfortunately, this made all the predicted tag bigrams look like (I,O), so the model is now as good as a (binary) broken clock.
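A minimal sketch of the backoff estimate described above, using plain Counters over tag n-grams rather than NLTK classes (the function name and count layout are hypothetical). This simple backoff does not redistribute probability mass, so the values for a given (t1, t2) context need not sum to 1; a real implementation would renormalize or use Katz-style discounting.

```python
from collections import Counter


def backoff_prob(t1, t2, t3, trigrams, bigrams, unigrams):
    """P(t3 | t1, t2): trigram MLE if the trigram was seen, else bigram MLE
    P(t3 | t2) if that bigram was seen, else unigram MLE N(t3) / N."""
    if trigrams[(t1, t2, t3)] > 0:
        context = sum(n for (a, b, _), n in trigrams.items() if (a, b) == (t1, t2))
        return trigrams[(t1, t2, t3)] / context
    if bigrams[(t2, t3)] > 0:
        context = sum(n for (a, _), n in bigrams.items() if a == t2)
        return bigrams[(t2, t3)] / context
    return unigrams[t3] / sum(unigrams.values())


tags = ["O", "O", "I", "O"]
trigrams = Counter(zip(tags, tags[1:], tags[2:]))
bigrams = Counter(zip(tags, tags[1:]))
unigrams = Counter(tags)
print(backoff_prob("I", "O", "O", trigrams, bigrams, unigrams))  # 0.5, via the bigram (O, O)
```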

Here is the code for the tagger (also available on my GitHub). I started with the baseline and added each feature required by the assignment behind a new key=value command-line parameter. Each successive key turns on all the previous ones (see the main method to see what I mean). The key names are also used to label the results.

from __future__ import division

import itertools
import sys

import collections
import nltk
import nltk.probability
import numpy as np

######################## utility methods ###########################

def findRareWords(train_file):
  """
  Extra pass through the training file to identify rare
  words and deal with them in the accumulation process.
  """
  wordsFD = nltk.FreqDist()
  reader = nltk.corpus.reader.TaggedCorpusReader(".", train_file)
  for word in reader.words():
    wordsFD.inc(word.lower())
  return set(filter(lambda x: wordsFD[x] < 5, wordsFD.keys()))

def normalizeRareWord(word, rareWords, replaceRare):
  """
  Introduce a bit of variety. Even though they are all rare,
  we classify rare words into "kinds" of rare words.
  """
  # findRareWords() stores lowercased words, so look up the lowercased form
  if word.lower() in rareWords:
    if replaceRare:
      if any(c.isdigit() for c in word):
        # contains at least one digit
        return "_RARE_NUMERIC_"
      elif word.upper() == word:
        return "_RARE_ALLCAPS_"
      elif word[-1:].isupper():
        return "_RARE_LASTCAP_"
      else:
        return "_RARE_"
    else:
      return "_RARE_"
  else:
    return word

def pad(sent, tags=True):
  """
  Pad sentences with the start and stop tags and return padded
  sentence.
  """
  if tags:
    padded = [("<*>", "<*>"), ("<*>", "<*>")]
  else:
    padded = ["<*>", "<*>"]
  padded.extend(sent)
  if tags:
    padded.append(("<$>", "<$>"))
  else:
    padded.append("<$>")
  return padded

def calculateMetrics(actual, predicted):
  """
  Returns the number of cases where prediction and actual NER
  tags are the same, divided by the number of tags for the
  sentence.
  """
  pred_p = map(lambda x: "I" if x == "I" else "O", predicted)
  cm = nltk.metrics.confusionmatrix.ConfusionMatrix(actual, pred_p)
  keys = ["I", "O"]
  metrics = np.matrix(np.zeros((2, 2)))
  for x, y in [(x, y) for x in range(0, 2)
                      for y in range(0, 2)]:
    try:
      metrics[x, y] = cm[keys[x], keys[y]]
    except KeyError:
      pass
  tp = metrics[0, 0]
  tn = metrics[1, 1]
  fp = metrics[0, 1]
  fn = metrics[1, 0]
  precision = 0 if (tp + fp) == 0 else tp / (tp + fp)
  recall = 0 if (tp + fn) == 0 else tp / (tp + fn)
  fmeasure = (0 if (precision + recall) == 0
    else (2 * precision * recall) / (precision + recall))
  accuracy = (tp + tn) / (tp + tn + fp + fn)
  return (precision, recall, fmeasure, accuracy)

def writeResult(fout, hmm, words):
  """
  Writes out the result in the required format.
  """
  tags = hmm.best_path(words)
  for word, tag in zip(words, tags)[2:-1]:
    fout.write("%s %s\n" % (word, tag))
  fout.write("\n")

def bigramToUnigram(bigrams):
  """
  Convert a list of bigrams to the equivalent unigram list.
  """
  unigrams = [bigrams[0][0]]
  unigrams.extend([x[1] for x in bigrams])
  return unigrams

def calculateBackoffTransCPD(tagsFD, transCFD, trans2CFD):
  """
  Uses a backoff model to calculate a smoothed conditional
  probability distribution on the training data. The HMM states
  are tag bigrams, so each distribution maps a state (t1, t2) to
  the states (t2, t3) that can follow it.
  """
  probDistDict = collections.defaultdict(nltk.DictionaryProbDist)
  tags = tagsFD.keys()
  conds = [x for x in itertools.permutations(tags, 2)]
  for tag in tags:
    conds.append((tag, tag))
  for (t1, t2) in conds:
    probDict = collections.defaultdict(float)
    trigramsFD = trans2CFD[(t1, t2)]  # samples are tag bigrams (t2, t3)
    bigramsFD = transCFD[t2]          # samples are single tags t3
    for t3 in tags:
      # FreqDist.freq() already returns count / N
      if trigramsFD.N() > 0 and trigramsFD[(t2, t3)] > 0:
        prob = trigramsFD.freq((t2, t3))
      elif bigramsFD.N() > 0 and bigramsFD[t3] > 0:
        prob = bigramsFD.freq(t3)
      else:
        prob = tagsFD.freq(t3)
      probDict[(t2, t3)] = prob
    # backoff estimates do not sum to 1, so renormalize
    probDistDict[(t1, t2)] = nltk.DictionaryProbDist(probDict, normalize=True)
  return nltk.DictionaryConditionalProbDist(probDistDict)

class Accumulator:
  """
  Convenience class to accumulate all the frequencies
  into a set of data structures.
  """
  def __init__(self, rareWords, replaceRare, useTrigrams):
    self.rareWords = rareWords
    self.replaceRare = replaceRare
    self.useTrigrams = useTrigrams
    self.words = set()
    self.tags = set()
    self.priorsFD = nltk.FreqDist()
    self.transitionsCFD = nltk.ConditionalFreqDist()
    self.outputsCFD = nltk.ConditionalFreqDist()
    # additional data structures for trigram
    self.transitions2CFD = nltk.ConditionalFreqDist()
    self.tagsFD = nltk.FreqDist()

  def addSentence(self, sent, norm_func):
    # preprocess
    unigrams = [(norm_func(word, self.rareWords, self.replaceRare), tag)
      for (word, tag) in sent]
    prevTag = None
    prev2Tag = None
    if self.useTrigrams:
      # each state is represented by a tag bigram
      bigrams = nltk.bigrams(unigrams)
      for ((w1, t1), (w2, t2)) in bigrams:
        self.words.add((w1, w2))
        self.tags.add((t1, t2))
        self.priorsFD.inc((t1, t2))
        self.outputsCFD[(t1, t2)].inc((w1, w2))
        if prevTag is not None:
          self.transitionsCFD[prevTag].inc(t2)
        if prev2Tag is not None:
          self.transitions2CFD[prev2Tag].inc((t1, t2))
        prevTag = t2
        prev2Tag = (t1, t2)
        self.tagsFD.inc(prevTag)
    else:
      # each state is represented by an tag unigram
      for word, tag in unigrams:
        self.words.add(word)
        self.tags.add(tag)
        self.priorsFD.inc(tag)
        self.outputsCFD[tag].inc(word)
        if prevTag is not None:
          self.transitionsCFD[prevTag].inc(tag)
        prevTag = tag

####################### train, validate, test ##################

def train(train_file, 
    rareWords, replaceRare, useTrigrams, trigramBackoff):
  """
  Read the file and populate the various frequency and
  conditional frequency distributions and build the HMM
  off these data structures.
  """
  acc = Accumulator(rareWords, replaceRare, useTrigrams)
  reader = nltk.corpus.reader.TaggedCorpusReader(".", train_file)
  for sent in reader.tagged_sents():
    unigrams = pad(sent)
    acc.addSentence(unigrams, normalizeRareWord)
  if useTrigrams:
    if trigramBackoff:
      backoffCPD = calculateBackoffTransCPD(acc.tagsFD, acc.transitionsCFD,
        acc.transitions2CFD)
      return nltk.HiddenMarkovModelTagger(list(acc.words), list(acc.tags),
        backoffCPD,
        nltk.ConditionalProbDist(acc.outputsCFD, nltk.ELEProbDist),
        nltk.ELEProbDist(acc.priorsFD))
    else:
      return nltk.HiddenMarkovModelTagger(list(acc.words), list(acc.tags),
        nltk.ConditionalProbDist(acc.transitions2CFD, nltk.ELEProbDist,
        len(acc.transitions2CFD.conditions())),
        nltk.ConditionalProbDist(acc.outputsCFD, nltk.ELEProbDist),
        nltk.ELEProbDist(acc.priorsFD))
  else:
    return nltk.HiddenMarkovModelTagger(list(acc.words), list(acc.tags),
      nltk.ConditionalProbDist(acc.transitionsCFD, nltk.ELEProbDist,
      len(acc.transitionsCFD.conditions())),
      nltk.ConditionalProbDist(acc.outputsCFD, nltk.ELEProbDist),
      nltk.ELEProbDist(acc.priorsFD))

def validate(hmm, validation_file, rareWords, replaceRare, useTrigrams):
  """
  Tests the HMM against the validation file.
  """
  precision = 0
  recall = 0
  fmeasure = 0
  accuracy = 0
  nSents = 0
  reader = nltk.corpus.reader.TaggedCorpusReader(".", validation_file)
  for sent in reader.tagged_sents():
    sent = pad(sent)
    words = [word for (word, tag) in sent]
    tags = [tag for (word, tag) in sent]
    normWords = map(lambda x: normalizeRareWord(
      x, rareWords, replaceRare), words)
    if useTrigrams:
      # convert words to word bigrams
      normWords = nltk.bigrams(normWords)
    predictedTags = hmm.best_path(normWords)
    if useTrigrams:
      # convert tag bigrams back to unigrams
      predictedTags = bigramToUnigram(predictedTags)
    (p, r, f, a) = calculateMetrics(tags[2:-1], predictedTags[2:-1])
    precision += p
    recall += r
    fmeasure += f
    accuracy += a
    nSents += 1
  print("Accuracy=%f, Precision=%f, Recall=%f, F1-Measure=%f\n" %
    (accuracy/nSents, precision/nSents, recall/nSents,
    fmeasure/nSents))

def test(hmm, test_file, result_file, rareWords, replaceRare, useTrigrams):
  """
  Tests the HMM against the test file (without tags) and writes
  out the results to the result file.
  """
  fin = open(test_file, 'rb')
  fout = open(result_file, 'wb')
  for line in fin:
    line = line.strip()
    words = pad([word for word in line.split(" ")], tags=False)
    normWords = map(lambda x: normalizeRareWord(
      x, rareWords, replaceRare), words)
    if useTrigrams:
      # convert words to word bigrams
      normWords = nltk.bigrams(normWords)
    tags = hmm.best_path(normWords)
    if useTrigrams:
      # convert tag bigrams back to unigrams
      tags = bigramToUnigram(tags)
    fout.write(" ".join(["/".join([word, tag])
      for (word, tag) in (zip(words, tags))[2:-1]]) + "\n")
  fin.close()
  fout.close()

def main():
  normalizeRare = False
  replaceRare = False
  useTrigrams = False
  trigramBackoff = False
  if len(sys.argv) > 1:
    args = sys.argv[1:]
    for arg in args:
      k, v = arg.split("=")
      if k == "normalize-rare":
        normalizeRare = True if v.lower() == "true" else False
      elif k == "replace-rare":
        normalizeRare = True
        replaceRare = True if v.lower() == "true" else False
      elif k == "use-trigrams":
        normalizeRare = True
        replaceRare = True
        useTrigrams = True if v.lower() == "true" else False
      elif k == "trigram-backoff":
        normalizeRare = True
        replaceRare = True
        useTrigrams = True
        trigramBackoff = True if v.lower() == "true" else False
      else:
        continue
  rareWords = set()
  if normalizeRare:
    rareWords = findRareWords("gene.train")
  hmm = train("gene.train",
    rareWords, replaceRare, useTrigrams, trigramBackoff)
  validate(hmm, "gene.validate", rareWords, replaceRare, useTrigrams)
  test(hmm, "gene.test", "gene.test.out",
    rareWords, replaceRare, useTrigrams)
  
if __name__ == "__main__":
  main()

And here are the results.

Model Accuracy Precision Recall F1-Measure
(baseline) 0.941894 0.337643 0.346306 0.324261
normalize-rare 0.940066 0.319570 0.333753 0.309895
replace-rare 0.940960 0.319570 0.335538 0.311532
use-trigrams 0.901701 0.001742 0.011788 0.002912
trigram-backoff 0.902182 0.000000 0.000000 0.000000

As you can see, the best results in this case seem to come from the baseline bigram model. Prior to this, I had built an NER by modeling it as a classification problem. However, the use of HMMs for this problem domain looks fairly common, as we can see from the papers here (PDF) and here (PDF). Another resource I came across that I want to explore further is NLTK's TnT (Trigrams'n'Tags) tagger, which seems to be a better fit for the trigram-backoff model.

Saturday, March 16, 2013

The Wikipedia Bob Alice HMM example using scikit-learn


Recently I needed to build a Hidden Markov Model (HMM). I have played with HMMs previously, but it was a while ago, so I needed to brush up on the underlying concepts. For that, the Wikipedia article is actually quite effective. My objective was to take an off-the-shelf HMM implementation, train it, and use it to predict (i.e., treat the HMM algorithm itself as a black box).

Scikit-Learn, an open-source Python machine-learning library, has several HMM implementations. The documentation is somewhat light, though, so I wanted to see if I could implement the Bob-Alice example from the Wikipedia article (there is a similar example in the Wikipedia article on the Viterbi algorithm), and if the resulting HMM returned believable results.

The Bob-Alice example is described here. Here is the corresponding implementation using Python and scikit-learn.

from __future__ import division, print_function
import numpy as np
# sklearn.hmm was deprecated in scikit-learn 0.16 and removed in 0.17;
# the code continues as the separate hmmlearn package
from sklearn import hmm

states = ["Rainy", "Sunny"]
n_states = len(states)

observations = ["walk", "shop", "clean"]
n_observations = len(observations)

start_probability = np.array([0.6, 0.4])

transition_probability = np.array([
  [0.7, 0.3],
  [0.4, 0.6]
])

emission_probability = np.array([
  [0.1, 0.4, 0.5],
  [0.6, 0.3, 0.1]
])

model = hmm.MultinomialHMM(n_components=n_states)
# set the parameters through the public properties instead of _set_* methods
model.startprob_ = start_probability
model.transmat_ = transition_probability
model.emissionprob_ = emission_probability

# predict a sequence of hidden states based on visible states
bob_says = [0, 2, 1, 1, 2, 0]
logprob, alice_hears = model.decode(bob_says, algorithm="viterbi")
print("Bob says:", ", ".join(observations[x] for x in bob_says))
print("Alice hears:", ", ".join(states[x] for x in alice_hears))

The output of this code is shown below. As you can see, it looks quite reasonable given the constraints in the example.

Bob says: walk, clean, shop, shop, clean, walk
Alice hears: Sunny, Rainy, Rainy, Rainy, Rainy, Sunny
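As a cross-check that does not depend on the (since removed) sklearn.hmm module, here is a minimal NumPy sketch of the Viterbi recursion with the same matrices. The function name `viterbi` is my own, and it multiplies plain probabilities, which is fine for a six-step sequence but would underflow on long ones (use log-probabilities there).

```python
import numpy as np

states = ["Rainy", "Sunny"]
start_p = np.array([0.6, 0.4])
trans_p = np.array([[0.7, 0.3], [0.4, 0.6]])
emit_p = np.array([[0.1, 0.4, 0.5], [0.6, 0.3, 0.1]])


def viterbi(obs, start_p, trans_p, emit_p):
    """Most likely hidden state sequence (and its probability) for obs."""
    T, n = len(obs), len(start_p)
    delta = np.zeros((T, n))           # best path probability ending in each state
    psi = np.zeros((T, n), dtype=int)  # back-pointer to the best previous state
    delta[0] = start_p * emit_p[:, obs[0]]
    for t in range(1, T):
        cand = delta[t - 1][:, None] * trans_p  # cand[i, j]: from state i to j
        psi[t] = cand.argmax(axis=0)
        delta[t] = cand.max(axis=0) * emit_p[:, obs[t]]
    path = [int(delta[-1].argmax())]
    for t in range(T - 1, 0, -1):
        path.append(int(psi[t][path[-1]]))
    path.reverse()
    return path, float(delta[-1].max())


bob_says = [0, 2, 1, 1, 2, 0]
path, prob = viterbi(bob_says, start_p, trans_p, emit_p)
print("Alice hears:", ", ".join(states[s] for s in path))
# Alice hears: Sunny, Rainy, Rainy, Rainy, Rainy, Sunny
```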

Even though it's a silly little example, it helped me understand how to model a Named Entity Recognizer as an HMM for a Coursera class I am taking. Hopefully it helps you with something (or at least you found it interesting :-)).

Saturday, November 08, 2008

IR Math in Java : HMM Based POS Tagger/Recognizer

As you know, I have been slowly working my way through Dr Konchady's TMAP book, and coding up the algorithms in Java. By doing so, I hope to understand the techniques and mathematical models presented, so I can apply them to real-life problems in the future. In this post I describe an implementation of a Hidden Markov Model based Part of Speech recognizer/tagger, based on the material presented in Chapter 4 of the TMAP book.

Background

What follows is my take on what an HMM is and how it can be used for Part of Speech (POS) tagging. For a more detailed, math-heavy, and possibly more accurate description of HMMs and their internals, read the Wikipedia article, Dr Rabiner's tutorial, or the TMAP book if you happen to own it. A Hidden Markov Model can be thought of as a probabilistic finite state machine. Its states can be represented by the set S = {S1, S2, ..., Sn}, which are not directly visible. What is visible is a set of Observations O = {O1, O2, ..., Om}, which are the result of the machine moving from one state to the other. The probabilities of the machine starting in each of the states Si are specified by the one-dimensional matrix Π of size n. The probabilities of the machine moving from one state to another are specified by a two-dimensional matrix A of size n*n. Finally, the probability of an observation being observed when the machine is in a certain state is given by the two-dimensional matrix B of size n*m. More succinctly,

  H = f(Π, A, B)
  where:
    H = the HMM
    Π = start probabilities. The element Πi represents 
        the probability that the HMM starts a sequence in
        State Si, where i in (0..n-1).
    A = transition probabilities. The element Ai,j represents
        the probability of a transition from State Si to
        State Sj, where i and j in (0..n-1).
    B = emission probabilities. The element Bi,j represents
        the probability of an Observation Oj occurring while
        the machine is in State Si, where i in (0..n-1)
        and j in (0..m-1).
    n = number of states.
    m = number of unique observations.
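In terms of these matrices, the probability of an observation sequence o_1 ... o_T together with a hidden state sequence q_1 ... q_T, and the Viterbi recursion used for tagging, can be written as follows. This is standard textbook notation; δ (best path probability ending in a state) and ψ (back-pointer) are symbols I am introducing here.

```latex
P(O, Q \mid H) = \Pi_{q_1} B_{q_1, o_1} \prod_{t=2}^{T} A_{q_{t-1}, q_t} \, B_{q_t, o_t}

\delta_1(i) = \Pi_i \, B_{i, o_1}
\delta_t(j) = \max_{i} \left[ \delta_{t-1}(i) \, A_{i,j} \right] B_{j, o_t}
\psi_t(j)   = \arg\max_{i} \left[ \delta_{t-1}(i) \, A_{i,j} \right]

\hat{q}_T = \arg\max_{i} \delta_T(i), \qquad \hat{q}_{t-1} = \psi_t(\hat{q}_t)
```

Replacing the max with a sum in the δ recursion gives the forward probabilities, whose sum at step T is P(O | H).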

The objective of POS tagging is to tag each word of a sentence with its part-of-speech tag. While some words can be unambiguously tagged, i.e., there is only one POS that the word is ever used for, quite a few can represent different POS depending on their usage in the sentence. For example, cold can be both a noun and an adjective, and catch can be both a noun and a verb. The fact that the word exists in the sentence is known, while the POS for the word is unknown. Therefore the HMM built for POS tagging would model the words as visible observations and the set of possible POS as the hidden states.

As far as POS tagging is concerned, the main problems that can be solved by HMMs are as follows. Given an HMM,

  1. Finding the most likely state sequence for a given observation sequence. In this case, we pass in a sentence, and tag each word with its most likely POS.
  2. Finding the most likely state for a given observation in a sequence. This is useful for word sense disambiguation (WSD), so we can tell the most likely POS that a particular word in a sentence belongs to.

The problems above are identical from the point of view of an HMM, and are solved using the Viterbi algorithm.

The second type of problem is to find the probability of a certain sequence of observations. This can be used to answer questions such as whether a sentence such as "I am going to the market" is more common than one such as "To the market I go". The Forward-Backward Algorithm is used to solve this kind of problem. This can be useful for applications that predict the most likely outcome given a set of input observations, but is probably not important from the perspective of POS tagging. Both of the above types of problem require an HMM built from (manually) tagged data.
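The sequence probability only needs the forward half of Forward-Backward. A minimal NumPy sketch, using a toy two-state, three-symbol HMM with made-up numbers (in POS tagging the symbols would be word indices and the states POS tags); this is an illustration, not Jahmm's API.

```python
import numpy as np

# Toy HMM parameters (made-up numbers for illustration).
start_p = np.array([0.6, 0.4])                         # Π
trans_p = np.array([[0.7, 0.3], [0.4, 0.6]])           # A
emit_p = np.array([[0.1, 0.4, 0.5], [0.6, 0.3, 0.1]])  # B


def sequence_probability(obs, start_p, trans_p, emit_p):
    """P(O | H) via the forward algorithm: sum over all hidden state paths."""
    alpha = start_p * emit_p[:, obs[0]]  # alpha_1(i) = Π_i * B[i, o_1]
    for o in obs[1:]:
        # alpha_t(j) = (sum_i alpha_{t-1}(i) * A[i, j]) * B[j, o_t]
        alpha = (alpha @ trans_p) * emit_p[:, o]
    return float(alpha.sum())


print(sequence_probability([0], start_p, trans_p, emit_p))  # 0.6*0.1 + 0.4*0.6
```

Because every hidden path is summed over, the probabilities of all observation sequences of a given length add up to 1, which makes a handy sanity check.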

A third type of problem is how to build an HMM given a corpus of untagged text. Such an HMM would allow us to solve the second type of problem. However, since the HMM has not been fed with tagged words, it must depend on a clustering algorithm such as K-Means to cluster the words into undefined hidden states, which are of no use when attempting to solve the first type of problem. The two learning algorithms used here are the K-Means Algorithm to build the initial HMM and the Baum-Welch Algorithm to refine it. As with the second type of problem, this does not have many applications where POS taggers are concerned.

I used the Java HMM library Jahmm to do all of the heavy computational lifting. It has implementations of the algorithms mentioned above, as well as several utility methods and classes to model various kinds of Observation.

Building the HMM from Tagged Data

For my tagged corpus, I used the Brown Corpus, downloading the data from the Natural Language Toolkit (NLTK) project. The corpus is a set of about 500 files containing one sentence per line, each manually tagged with a very comprehensive set of POS tags described here. Since I plan to use WordNet at some point with this data, and WordNet only categorizes words as one of 4 categories, I set up my own Part of Speech Enum called Pos which has 5 categories, the 4 from WordNet and OTHER. As a result, I had to dumb the Brown tags down using the translation table shown below:

BTAG POS BTAG POS BTAG POS BTAG POS
( OTHER FW-CS OTHER MD VERB RBR+CS ADVERB
) OTHER FW-DT OTHER MD* VERB RBT ADVERB
* OTHER FW-DT+BEZ OTHER MD+HV VERB RN ADVERB
, OTHER FW-DTS OTHER MD+PPSS VERB RP ADVERB
-- OTHER FW-HV VERB MD+TO VERB RP+IN ADVERB
. OTHER FW-IN OTHER NN NOUN TO OTHER
: OTHER FW-IN+AT OTHER NN$ NOUN TO+VB VERB
ABL OTHER FW-IN+NN OTHER NN+BEZ NOUN UH OTHER
ABN OTHER FW-IN+NP OTHER NN+HVD NOUN VB VERB
ABX OTHER FW-JJ ADJECTIVE NN+HVZ NOUN VB+AT VERB
AP OTHER FW-JJR ADJECTIVE NN+IN NOUN VB+IN VERB
AP$ OTHER FW-JJT ADJECTIVE NN+MD NOUN VB+JJ VERB
AP+AP OTHER FW-NN NOUN NN+NN NOUN VB+PPO VERB
AT ADJECTIVE FW-NN$ NOUN NNS NOUN VB+RP VERB
BE VERB FW-NNS NOUN NNS$ NOUN VB+TO VERB
BED VERB FW-NP NOUN NNS+MD NOUN VB+VB VERB
BED* VERB FW-NPS NOUN NP NOUN VBD VERB
BEDZ VERB FW-NR NOUN NP$ NOUN VBG VERB
BEDZ* VERB FW-OD NOUN NP+BEZ NOUN VBG+TO VERB
BEG VERB FW-PN OTHER NP+HVZ NOUN VBN VERB
BEM VERB FW-PP$ OTHER NP+MD NOUN VBN+TO VERB
BEM* VERB FW-PPL OTHER NPS NOUN VBZ VERB
BEN VERB FW-PPL+VBZ OTHER NPS$ NOUN WDT OTHER
BER VERB FW-PPO OTHER NR NOUN WDT+BER OTHER
BER* VERB FW-PPO+IN OTHER NR$ NOUN WDT+BER+PP OTHER
BEZ VERB FW-PPS OTHER NR+MD NOUN WDT+BEZ OTHER
BEZ* VERB FW-PPSS OTHER NRS NOUN WDT+DO+PPS OTHER
CC OTHER FW-PPSS+HV OTHER OD NOUN WDT+DOD OTHER
CD NOUN FW-QL OTHER PN OTHER WDT+HVZ OTHER
CD$ NOUN FW-RB ADVERB PN$ OTHER WP$ OTHER
CS OTHER FW-RB+CC ADVERB PN+BEZ OTHER WPO OTHER
DO VERB FW-TO+VB VERB PN+HVD OTHER WPS OTHER
DO* VERB FW-UH OTHER PN+HVZ OTHER WPS+BEZ OTHER
DO+PPSS VERB FW-VB VERB PN+MD OTHER WPS+HVD OTHER
DOD VERB FW-VBD VERB PP$ OTHER WPS+HVZ OTHER
DOD* VERB FW-VBG VERB PP$$ OTHER WPS+MD OTHER
DOZ VERB FW-VBN VERB PPL OTHER WQL OTHER
DOZ* VERB FW-VBZ VERB PPLS OTHER WRB ADVERB
DT OTHER FW-WDT OTHER PPO OTHER WRB+BER ADVERB
DT$ OTHER FW-WPO OTHER PPS OTHER WRB+BEZ ADVERB
DT+BEZ OTHER FW-WPS OTHER PPS+BEZ OTHER WRB+DO ADVERB
DT+MD OTHER HV VERB PPS+HVD OTHER WRB+DOD ADVERB
DTI OTHER HV* VERB PPS+HVZ OTHER WRB+DOD* ADVERB
DTS OTHER HV+TO VERB PPS+MD OTHER WRB+DOZ ADVERB
DTS+BEZ OTHER HVD VERB PPSS OTHER WRB+IN ADVERB
DTX OTHER HVD* VERB PPSS+BEM OTHER WRB+MD ADVERB
EX VERB HVG VERB PPSS+BER OTHER - -
EX+BEZ VERB HVN VERB PPSS+BEZ OTHER - -
EX+HVD VERB HVZ VERB PPSS+BEZ* OTHER - -
EX+HVZ VERB HVZ* VERB PPSS+HV OTHER - -
EX+MD VERB IN OTHER PPSS+HVD OTHER - -
FW-* OTHER IN+IN OTHER PPSS+MD OTHER - -
FW-AT ADJECTIVE IN+PPO OTHER PPSS+VB OTHER - -
FW-AT+NN ADJECTIVE JJ ADJECTIVE QL OTHER - -
FW-AT+NP ADJECTIVE JJ$ ADJECTIVE QLP OTHER - -
FW-BE VERB JJ+JJ ADJECTIVE RB ADVERB - -
FW-BER VERB JJR ADJECTIVE RB$ ADVERB - -
FW-BEZ VERB JJR+CS ADJECTIVE RB+BEZ ADVERB - -
FW-CC OTHER JJS ADJECTIVE RB+CS ADVERB - -
FW-CD NOUN JJT ADJECTIVE RBR ADVERB - -

You may notice that some of the mappings are not correct. Unfortunately, my knowledge of formal English grammar is not as good as I would like it to be. I was educated in an environment which held that a person learns to recognize incorrect grammar better by reading and writing enough grammatically correct sentences than by studying language rules. My grandfather, obviously regarding all this as crazy talk, briefly attempted to rectify that, armed with a Wren and Martin and an 18" ruler, but as you can probably see, it did not work out all that well :-).

The code for the Pos Enum is shown below. As mentioned earlier, it exposes a set of 5 POS values, and has a convenience method to convert a Brown tag into the corresponding Pos.

// Source: src/main/java/com/mycompany/myapp/postaggers/Pos.java
package com.mycompany.myapp.postaggers;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.util.HashMap;
import java.util.Map;

import org.apache.commons.lang.StringUtils;

/**
 * Enumeration of Parts of Speech being considered. Conversion from
 * Brown tags is handled by a convenience method.
 */
public enum Pos {

  NOUN, VERB, ADJECTIVE, ADVERB, OTHER;

  private static Map<String,Pos> bmap = null;
  private static final String translationFile = 
    "src/main/resources/brown_tags.txt";
  
  /**
   * Converts a Brown tag to its Pos equivalent. The tab-separated
   * translation table is loaded lazily on first use; tags that are
   * not in the table map to OTHER.
   */
  public static synchronized Pos fromBrownTag(String btag) 
      throws Exception {
    if (bmap == null) {
      Map<String,Pos> map = new HashMap<String,Pos>();
      BufferedReader reader = new BufferedReader(new InputStreamReader(
          new FileInputStream(translationFile)));
      try {
        String line;
        while ((line = reader.readLine()) != null) {
          // skip comments and blank lines
          if (line.startsWith("#") || StringUtils.isBlank(line)) {
            continue;
          }
          String[] cols = StringUtils.split(line, "\t");
          // keys are stored lowercased, so lookups are lowercased too
          map.put(StringUtils.lowerCase(cols[0]), 
            Pos.valueOf(StringUtils.trim(cols[1])));
        }
      } finally {
        reader.close();
      }
      bmap = map;
    }
    Pos pos = bmap.get(StringUtils.lowerCase(btag));
    if (pos == null) {
      return Pos.OTHER;
    }
    return pos;
  }
}
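One thing to watch for is that the Brown Corpus also attaches modifier suffixes to its tags. Examples are -tl (word in a title), -hl (word in a headline) and -nc (cited word). The very first sentence tags "Fulton" as np-tl. That tag is not in the translation table, so the Pos enum above returns OTHER for it. Below is a hedged sketch of a lookup that strips those suffixes before giving up. The class name BrownTagLookup is hypothetical, and the suffix handling is my assumption rather than part of the original code. It uses a tiny in-memory subset of the table instead of brown_tags.txt.

```java
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;

/** Hypothetical Brown-tag lookup that tolerates -tl/-hl/-nc modifier suffixes. */
final class BrownTagLookup {

  enum Pos { NOUN, VERB, ADJECTIVE, ADVERB, OTHER }

  // modifier suffixes used in the Brown Corpus tag set
  private static final String[] SUFFIXES = {"-tl", "-hl", "-nc"};

  private final Map<String, Pos> table = new HashMap<String, Pos>();

  BrownTagLookup() {
    // a handful of rows from the translation table, for illustration only
    table.put("nn", Pos.NOUN);
    table.put("np", Pos.NOUN);
    table.put("vbd", Pos.VERB);
    table.put("jj", Pos.ADJECTIVE);
    table.put("rb", Pos.ADVERB);
    table.put("fw-nn", Pos.NOUN);
  }

  Pos fromBrownTag(String btag) {
    String tag = btag.toLowerCase(Locale.ROOT);
    while (true) {
      Pos pos = table.get(tag);
      if (pos != null) {
        return pos;
      }
      String stripped = stripSuffix(tag);
      if (stripped.equals(tag)) {
        return Pos.OTHER;  // nothing left to strip: same fallback as the Pos enum
      }
      tag = stripped;
    }
  }

  private static String stripSuffix(String tag) {
    for (String suffix : SUFFIXES) {
      if (tag.endsWith(suffix)) {
        return tag.substring(0, tag.length() - suffix.length());
      }
    }
    return tag;
  }
}
```

This only strips known suffixes from the end. Prefixes such as fw- stay intact, because the table has explicit rows for them.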

The BrownCorpusReader reads through each tagged file in the Brown Corpus directory, extracts the word and the tag out of each tagged word, converts the Brown tag to its equivalent Pos value, and accumulates the occurrences into internal counters. Once all files are processed, the counters are normalized into the three probability matrices Π, A and B that we spoke about earlier.
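Stripped of the file handling, Jama matrices and commons-collections bags, the counting and normalization that BrownCorpusReader performs boils down to the sketch below. It works on toy sentences whose tags have already been mapped to state indices (e.g. Pos ordinals), using plain arrays. The class name HmmCounts is illustrative, and this is not a drop-in replacement for the class.

```java
import java.util.LinkedHashMap;
import java.util.Map;

/** Estimates pi, A and B by counting and row-normalizing (illustrative sketch). */
final class HmmCounts {

  final double[] pi;
  final double[][] a;
  final double[][] b;
  // word -> observation id, assigned in order of first appearance
  final Map<String, Integer> wordIds = new LinkedHashMap<String, Integer>();

  /**
   * @param words     words[s][t] is the t-th word of sentence s
   * @param tags      tags[s][t] is the state index (e.g. a Pos ordinal) of that word
   * @param numStates number of hidden states
   */
  HmmCounts(String[][] words, int[][] tags, int numStates) {
    for (String[] sentence : words) {
      for (String w : sentence) {
        if (!wordIds.containsKey(w)) {
          wordIds.put(w, wordIds.size());
        }
      }
    }
    pi = new double[numStates];
    a = new double[numStates][numStates];
    b = new double[numStates][wordIds.size()];
    for (int s = 0; s < words.length; s++) {
      for (int t = 0; t < words[s].length; t++) {
        int state = tags[s][t];
        if (t == 0) {
          pi[state] += 1;                         // sentence-initial state
        } else {
          a[tags[s][t - 1]][state] += 1;          // transition previous -> current
        }
        b[state][wordIds.get(words[s][t])] += 1;  // emission state -> word
      }
    }
    normalize(pi);
    for (double[] row : a) {
      normalize(row);
    }
    for (double[] row : b) {
      normalize(row);
    }
  }

  /** Scales a row so it sums to 1; rows without any counts stay all-zero. */
  private static void normalize(double[] row) {
    double sum = 0.0;
    for (double v : row) {
      sum += v;
    }
    if (sum == 0.0) {
      return;
    }
    for (int i = 0; i < row.length; i++) {
      row[i] /= sum;
    }
  }
}
```

These are plain maximum-likelihood estimates with no smoothing. A word never seen with a given POS gets an emission probability of exactly zero for that state, which is what the zeros in the serialized file reflect.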

Since the number of distinct words in any corpus is potentially quite large, we represent each word (or observation) in the HMM as an integer. That is why the BrownCorpusReader also dumps the list of unique words it found in the corpus into a flat file. That file can be pulled back into memory later to map between a word and its integer observation id, which is simply the word's position in the file.
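Because a word's id is just its line number, the mapping only holds if the dictionary is read back in the same order it was written, alongside the HMM built with it. Here is a minimal round trip using in-memory readers and writers instead of brown_dict.txt. The WordDictionary class and its idOf helper are hypothetical. Note that a word absent from the dictionary has no id at all, which is worth checking before building observations.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.Reader;
import java.io.Writer;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical helper: one word per line, id = 0-based line number. */
final class WordDictionary {

  static void write(List<String> words, Writer out) throws IOException {
    for (String word : words) {
      out.write(word + "\n");
    }
    out.flush();
  }

  static Map<String, Integer> read(Reader in) throws IOException {
    Map<String, Integer> ids = new HashMap<String, Integer>();
    BufferedReader reader = new BufferedReader(in);
    String word;
    int seq = 0;
    while ((word = reader.readLine()) != null) {
      ids.put(word, seq++);  // the id is the line number
    }
    return ids;
  }

  /** Returns the id for a word, failing loudly for out-of-vocabulary words. */
  static int idOf(Map<String, Integer> ids, String word) {
    Integer id = ids.get(word);
    if (id == null) {
      throw new IllegalArgumentException("Unknown word: " + word);
    }
    return id;
  }
}
```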

// Source: src/main/java/com/mycompany/myapp/postaggers/BrownCorpusReader.java
package com.mycompany.myapp.postaggers;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.StringTokenizer;

import org.apache.commons.collections15.Bag;
import org.apache.commons.collections15.bag.HashBag;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import Jama.Matrix;

/**
 * Reads a file or directory of tagged text from Brown Corpus and
 * computes the various probability matrices for the HMM.
 */
public class BrownCorpusReader {

  private final Log log = LogFactory.getLog(getClass());
  
  private String dataFilesLocation;
  private String wordDictionaryLocation;
  private boolean debug;

  private Bag<String> piCounts = new HashBag<String>();
  private Bag<String> aCounts = new HashBag<String>();
  private Map<String,Double[]> wordPosMap = 
    new HashMap<String,Double[]>();
  
  private Matrix pi;
  private Matrix a;
  private Matrix b;
  private List<String> words;
  
  public void setDataFilesLocation(String dataFilesLocation) {
    this.dataFilesLocation = dataFilesLocation;
  }
  
  public void setWordDictionaryLocation(String wordDictionaryLocation) {
    this.wordDictionaryLocation = wordDictionaryLocation;
  }
  
  public void setDebug(boolean debug) {
    this.debug = debug;
  }

  public void read() throws Exception {
    File location = new File(dataFilesLocation);
    File[] inputs;
    if (location.isDirectory()) {
      inputs = location.listFiles();
    } else {
      inputs = new File[] {location};
    }
    int currfile = 0;
    int totfiles = inputs.length;
    for (File input : inputs) {
      currfile++;
      log.info("Processing file (" + currfile + "/" + totfiles + "): " + 
        input.getName());
      BufferedReader reader = new BufferedReader(new InputStreamReader(
        new FileInputStream(input)));
      String line;
      while ((line = reader.readLine()) != null) {
        if (StringUtils.isEmpty(line)) {
          continue;
        }
        StringTokenizer tok = new StringTokenizer(line, " ");
        int wordIndex = 0;
        Pos prevPos = null;
        while (tok.hasMoreTokens()) {
          String taggedWord = tok.nextToken();
          String[] wordTagPair = StringUtils.split(
            StringUtils.lowerCase(StringUtils.trim(taggedWord)), "/");
          if (wordTagPair.length != 2) {
            continue;
          }
          Pos pos = Pos.fromBrownTag(wordTagPair[1]);
          Double[] posProbs = wordPosMap.get(wordTagPair[0]);
          if (posProbs == null) {
            // first time we see this word: one counter per Pos
            posProbs = new Double[Pos.values().length];
            for (int i = 0; i < posProbs.length; i++) {
              posProbs[i] = Double.valueOf(0.0D);
            }
            wordPosMap.put(wordTagPair[0], posProbs);
          }
          // the array is updated in place, so no need to put it back
          posProbs[pos.ordinal()] += 1.0D;
          if (wordIndex == 0) {
            // first word, update piCounts
            piCounts.add(pos.name());
          } else {
            aCounts.add(StringUtils.join(new String[] {
              prevPos.name(), pos.name()}, ":"));
          }
          prevPos = pos;
          wordIndex++;
        }
      }
      reader.close();
    }
    // normalize counts to probabilities
    int numPos = Pos.values().length;
    // compute pi
    pi = new Matrix(numPos, 1);
    for (int i = 0; i < numPos; i++) {
      pi.set(i, 0, piCounts.getCount((Pos.values()[i]).name()));
    }
    pi = pi.times(1 / pi.norm1());
    // compute a
    a = new Matrix(numPos, numPos);
    for (int i = 0; i < numPos; i++) {
      for (int j = 0; j < numPos; j++) {
        a.set(i, j, aCounts.getCount(StringUtils.join(new String[] {
          (Pos.values()[i]).name(), (Pos.values()[j]).name()
        }, ":")));
      }
    }
    // compute b
    int numWords = wordPosMap.size();
    words = new ArrayList<String>();
    words.addAll(wordPosMap.keySet());
    b = new Matrix(numPos, numWords);
    for (int i = 0; i < numPos; i++) {
      for (int j = 0; j < numWords; j++) {
        String word = words.get(j);
        b.set(i, j, wordPosMap.get(word)[i]);
      }
    }
    // normalize across rows for a and b (sum of cols in each row == 1.0);
    // rows with no counts are left as zeros to avoid 0/0 = NaN
    for (int i = 0; i < numPos; i++) {
      double rowSumA = 0.0D;
      for (int j = 0; j < numPos; j++) {
        rowSumA += a.get(i, j);
      }
      if (rowSumA > 0.0D) {
        for (int j = 0; j < numPos; j++) {
          a.set(i, j, (a.get(i, j) / rowSumA));
        }
      }
      double rowSumB = 0.0D;
      for (int j = 0; j < numWords; j++) {
        rowSumB += b.get(i, j);
      }
      if (rowSumB > 0.0D) {
        for (int j = 0; j < numWords; j++) {
          b.set(i, j, (b.get(i, j) / rowSumB));
        }
      }
    }
    // write out brown word dictionary for later use
    writeDictionary();
    // debug
    if (debug) {
      pi.print(8, 4);
      a.print(8, 4);
      b.print(8, 4);
      System.out.println(words.toString());
    }
  }
  
  public List<String> getWords() {
    return words;
  }
  
  public double[] getPi() {
    double[] pia = new double[pi.getRowDimension()];
    for (int i = 0; i < pia.length; i++) {
      pia[i] = pi.get(i, 0);
    }
    return pia;
  }
  
  public double[][] getA() {
    double[][] aa = new double[a.getRowDimension()][a.getColumnDimension()];
    for (int i = 0; i < a.getRowDimension(); i++) {
      for (int j = 0; j < a.getColumnDimension(); j++) {
        aa[i][j] = a.get(i, j);
      }
    }
    return aa;
  }
  
  public double[][] getB() {
    double[][] ba = new double[b.getRowDimension()][b.getColumnDimension()];
    for (int i = 0; i < b.getRowDimension(); i++) {
      for (int j = 0; j < b.getColumnDimension(); j++) {
        ba[i][j] = b.get(i, j);
      }
    }
    return ba;
  }

  private void writeDictionary() throws Exception {
    FileWriter dictWriter = new FileWriter(wordDictionaryLocation);
    for (String word : words) {
      dictWriter.write(word + "\n");
    }
    dictWriter.flush();
    dictWriter.close();
  }
}

We generate the HMM and serialize it to disk as a flat file. That decouples building the HMM from using it, saves a few CPU cycles, and makes the tests run a bit faster. In addition, if this solution were to be used in a real-life situation, it would be much faster to load the HMM from a flat file than to build it from a tagged corpus. Our serialized HMM file looks like this (edited to truncate the number of observations for readability). The five State blocks appear in the order of the Pos enum: NOUN, VERB, ADJECTIVE, ADVERB, OTHER.

On a quick side note, the Jahmm example uses the ObservationDiscrete class, based on an Enum, to model a small finite set of observations. This works well if the number of observations is small and known in advance. In our case, each unique word is an observation, and we have approximately 3900 of them. So we use the ObservationInteger class to model the observation, and our flat file serves as the mapping between an observation's integer id and the actual word.

Hmm v1.0

NbStates 5

State
Pi 0.127
A 0.155 0.156 0.019 0.025 0.645 
IntegerOPDF [0 0 0.00002 0.00003 0 0 0.00001 0 0.00001 ...]

State
Pi 0.057
A 0.095 0.195 0.168 0.094 0.449 
IntegerOPDF [0 0 0 0 0 0.00001 0 0 0 0 0 0 0.00001 0.00005 ...]

State
Pi 0.164
A 0.639 0.024 0.148 0.005 0.183 
IntegerOPDF [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ...]

State
Pi 0.083
A 0.052 0.228 0.111 0.041 0.569 
IntegerOPDF [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ...]

State
Pi 0.569
A 0.206 0.199 0.205 0.039 0.351 
IntegerOPDF [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0.0032 0.00033 0.00064 ...]

The following JUnit test snippet shows how we use the HmmTagger class (described below) to call the BrownCorpusReader and build and persist the HMM.

  @Test
  public void testBuildFromBrownAndWrite() throws Exception {
    HmmTagger hmmTagger = new HmmTagger();
    hmmTagger.setDataDir("/opt/brown-2.0");
    hmmTagger.setDictionaryLocation("src/test/resources/brown_dict.txt");
    hmmTagger.setHmmFileName("src/test/resources/hmm_tagger.dat");
    Hmm<ObservationInteger> hmm = hmmTagger.buildFromBrownCorpus();
    hmmTagger.saveToFile(hmm);
  }

HMM Tagger class

I then created an HmmTagger class that can build an HMM either from the BrownCorpusReader or from the serialized HMM file shown above. The HmmTagger contains all the methods needed to solve the common HMM problems listed above. The code for the HmmTagger is as follows:

// Source: src/main/java/com/mycompany/myapp/postaggers/HmmTagger.java
package com.mycompany.myapp.postaggers;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.Writer;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import be.ac.ulg.montefiore.run.jahmm.ForwardBackwardCalculator;
import be.ac.ulg.montefiore.run.jahmm.Hmm;
import be.ac.ulg.montefiore.run.jahmm.ObservationInteger;
import be.ac.ulg.montefiore.run.jahmm.OpdfInteger;
import be.ac.ulg.montefiore.run.jahmm.OpdfIntegerFactory;
import be.ac.ulg.montefiore.run.jahmm.ViterbiCalculator;
import be.ac.ulg.montefiore.run.jahmm.io.HmmReader;
import be.ac.ulg.montefiore.run.jahmm.io.HmmWriter;
import be.ac.ulg.montefiore.run.jahmm.io.OpdfIntegerReader;
import be.ac.ulg.montefiore.run.jahmm.io.OpdfReader;
import be.ac.ulg.montefiore.run.jahmm.io.OpdfWriter;
import be.ac.ulg.montefiore.run.jahmm.learn.BaumWelchLearner;
import be.ac.ulg.montefiore.run.jahmm.learn.KMeansLearner;
import be.ac.ulg.montefiore.run.jahmm.toolbox.KullbackLeiblerDistanceCalculator;

/**
 * HMM based POS Tagger.
 */
public class HmmTagger {

  private static final DecimalFormat OBS_FORMAT = 
    new DecimalFormat("##.#####");
  
  private final Log log = LogFactory.getLog(getClass());

  private String dataDir;
  private String dictionaryLocation;
  private String hmmFileName;
  
  private Map<String,Integer> words = 
    new HashMap<String,Integer>();
  
  public void setDataDir(String brownDataDir) {
    this.dataDir = brownDataDir;
  }

  public void setDictionaryLocation(String dictionaryLocation) {
    this.dictionaryLocation = dictionaryLocation;
  }

  public void setHmmFileName(String hmmFileName) {
    this.hmmFileName = hmmFileName;
  }

  /**
   * Builds up an HMM where states are parts of speech given by the Pos
   * Enum, and the observations are actual words in the tagged Brown
   * corpus. Each integer observation corresponds to the position of 
   * a word found in the Brown corpus.
   * @return an HMM.
   * @throws Exception if one is thrown.
   */
  public Hmm<ObservationInteger> buildFromBrownCorpus() 
      throws Exception {
    BrownCorpusReader brownReader = new BrownCorpusReader();
    brownReader.setDataFilesLocation(dataDir);
    brownReader.setWordDictionaryLocation(dictionaryLocation);
    brownReader.read();
    int nbStates = Pos.values().length;
    // the factory's argument is the number of distinct observations
    // (words), not the number of states; the Opdfs it creates are
    // replaced below, but the size should still be consistent
    int nbWords = brownReader.getWords().size();
    OpdfIntegerFactory factory = new OpdfIntegerFactory(nbWords);
    Hmm<ObservationInteger> hmm = 
      new Hmm<ObservationInteger>(nbStates, factory); 
    double[] pi = brownReader.getPi();
    for (int i = 0; i < nbStates; i++) {
      hmm.setPi(i, pi[i]);
    }
    double[][] a = brownReader.getA();
    for (int i = 0; i < nbStates; i++) {
      for (int j = 0; j < nbStates; j++) {
        hmm.setAij(i, j, a[i][j]);
      }
    }
    // one emission distribution (over all words) per state
    double[][] b = brownReader.getB();
    for (int i = 0; i < nbStates; i++) {
      hmm.setOpdf(i, new OpdfInteger(b[i]));
    }
    int seq = 0;
    for (String word : brownReader.getWords()) {
      words.put(word, seq);
      seq++;
    }
    return hmm;
  }
  
  /**
   * Builds an HMM from a formatted file describing the HMM. The format is
   * specified by the Jahmm project, and it has utility methods to read and
   * write HMMs from and to text files. We use this because the builder that
   * builds an HMM from the Brown corpus is computationally intensive and
   * this strategy provides us a way to partition the process.
   * @return a HMM
   * @throws Exception if one is thrown.
   */
  public Hmm<ObservationInteger> buildFromHmmFile() throws Exception {
    File hmmFile = new File(hmmFileName);
    if (! hmmFile.exists()) {
      throw new Exception("HMM File: " + hmmFile.getName() + 
        " does not exist");
    }
    FileReader fileReader = new FileReader(hmmFile);
    try {
      OpdfReader<OpdfInteger> opdfReader = new OpdfIntegerReader();
      Hmm<ObservationInteger> hmm = 
        HmmReader.read(fileReader, opdfReader);
      return hmm;
    } finally {
      fileReader.close();
    }
  }
  
  /**
   * Utility method to save an HMM into a formatted text file describing the
   * HMM. The format is specified by the Jahmm project, which also provides
   * utility methods to write a HMM to the text file.
   * @param hmm the HMM to write.
   * @throws Exception if one is thrown.
   */
  public void saveToFile(Hmm<ObservationInteger> hmm) 
      throws Exception {
    FileWriter fileWriter = new FileWriter(hmmFileName);
    // we create our own impl of the OpdfIntegerWriter because we want
    // to control the formatting of the opdf probabilities. With the 
    // default OpdfIntegerWriter, small probabilities get written in 
    // the exponential format, ie 1.234..E-4, which the HmmReader does
    // not recognize.
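    OpdfWriter<OpdfInteger> opdfWriter = 
      new OpdfWriter<OpdfInteger>() {
        @Override
        public void write(Writer writer, OpdfInteger opdf) 
            throws IOException {
          // one entry per word; a StringBuilder avoids quadratic
          // string concatenation over the whole vocabulary
          StringBuilder s = new StringBuilder("IntegerOPDF [");
          for (int i = 0; i < opdf.nbEntries(); i++) {
            s.append(OBS_FORMAT.format(opdf.probability(
              new ObservationInteger(i)))).append(" ");
          }
          writer.write(s.append("]\n").toString());
        }
      };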
    HmmWriter.write(fileWriter, opdfWriter, hmm);
    fileWriter.flush();
    fileWriter.close();
  }

  /**
   * Given the HMM, returns the probability of observing the sequence 
   * of words specified in the sentence. Uses the Forward-Backward 
   * algorithm to compute the probability.
   * @param sentence the sentence to check.
   * @param hmm a reference to a prebuilt HMM.
   * @return the probability of observing this sequence.
   * @throws Exception if one is thrown.
   */
  public double getObservationProbability(String sentence, 
      Hmm<ObservationInteger> hmm) throws Exception {
    String[] tokens = tokenizeSentence(sentence);
    List<ObservationInteger> observations = getObservations(tokens);
    ForwardBackwardCalculator fbc = 
      new ForwardBackwardCalculator(observations, hmm);
    return fbc.probability();
  }

  /**
   * Given an HMM and an untagged sentence, tags each word with the part of
   * speech it is most likely to belong in. Uses the Viterbi algorithm.
   * @param sentence the sentence to tag.
   * @param hmm the HMM to use.
   * @return a tagged sentence.
   * @throws Exception if one is thrown.
   */
  public String tagSentence(String sentence, 
      Hmm<ObservationInteger> hmm) throws Exception {
    String[] tokens = tokenizeSentence(sentence);
    List<ObservationInteger> observations = getObservations(tokens);
    ViterbiCalculator vc = new ViterbiCalculator(observations, hmm);
    int[] ids = vc.stateSequence();
    StringBuilder tagBuilder = new StringBuilder();
    for (int i = 0; i < ids.length; i++) {
      tagBuilder.append(tokens[i]).
        append("/").
        append((Pos.values()[ids[i]]).name()).
        append(" ");
    }
    return tagBuilder.toString();
  }
  
  /**
   * Given an HMM, a sentence and a word within the sentence which needs to 
   * be disambiguated, returns the most likely Pos for the specified word.
   * @param word the word to find the Pos for.
   * @param sentence the sentence.
   * @param hmm the HMM.
   * @return the most likely POS.
   * @throws Exception if one is thrown.
   */
  public Pos getMostLikelyPos(String word, String sentence, 
      Hmm<ObservationInteger> hmm) throws Exception {
    if (words == null || words.size() == 0) {
      loadWordsFromDictionary();
    }
    String[] tokens = tokenizeSentence(sentence);
    List<ObservationInteger> observations = getObservations(tokens);
    int wordPos = -1;
    for (int i = 0; i < tokens.length; i++) {
      if (tokens[i].equalsIgnoreCase(word)) {
        wordPos = i;
        break;
      }
    }
    if (wordPos == -1) {
      throw new IllegalArgumentException("Word [" + word + 
        "] does not exist in sentence [" + sentence + "]");
    }
    ViterbiCalculator vc = new ViterbiCalculator(observations, hmm);
    int[] ids = vc.stateSequence();
    return Pos.values()[ids[wordPos]];
  }

  /**
   * Given an existing HMM, this method will send in a List of sentences from
   * a possibly different untagged source, to refine the HMM.
   * @param sentences the List of sentences to teach.
   * @return a HMM that has been taught using the observation sequences.
   * @throws Exception if one is thrown.
   */
  public Hmm<ObservationInteger> teach(List<String> sentences)
      throws Exception {
    if (words == null || words.size() == 0) {
      loadWordsFromDictionary();
    }
    OpdfIntegerFactory factory = new OpdfIntegerFactory(words.size());
    List<List<ObservationInteger>> sequences = 
      new ArrayList<List<ObservationInteger>>();
    for (String sentence : sentences) {
      List<ObservationInteger> sequence = 
        getObservations(tokenizeSentence(sentence));
      sequences.add(sequence);
    }
    KMeansLearner<ObservationInteger> kml = 
      new KMeansLearner<ObservationInteger>(
      Pos.values().length, factory, sequences);
    Hmm<ObservationInteger> hmm = kml.iterate();
    // refine it with Baum-Welch Learner
    BaumWelchLearner bwl = new BaumWelchLearner();
    Hmm<ObservationInteger> refinedHmm = bwl.iterate(hmm, sequences);
    return refinedHmm;
  }
  
  /**
   * Convenience method to compute the distance between two HMMs. This can 
   * be used to stop the teaching process once more teaching is not
   * producing any appreciable improvement in the HMM, ie, the HMM
   * converges. The caller will need to match the result of this method 
   * with a number based on experience.
   * @param hmm1 the original HMM.
   * @param hmm2 the HMM that was most recently taught.
   * @return the difference measure between the two HMMs.
   * @throws Exception if one is thrown.
   */
  public double difference(Hmm<ObservationInteger> hmm1,
      Hmm<ObservationInteger> hmm2) throws Exception {
    KullbackLeiblerDistanceCalculator kdc = 
      new KullbackLeiblerDistanceCalculator();
    return kdc.distance(hmm1, hmm2);
  }
  
  private String[] tokenizeSentence(String sentence) {
    String[] tokens = StringUtils.split(
      StringUtils.lowerCase(StringUtils.trim(sentence)), " ");
    return tokens;
  }
  
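  private List<ObservationInteger> getObservations(String[] tokens)
      throws Exception {
    if (words == null || words.size() == 0) {
      loadWordsFromDictionary();
    }
    List<ObservationInteger> observations = 
      new ArrayList<ObservationInteger>();
    for (String token : tokens) {
      Integer id = words.get(token);
      if (id == null) {
        // unseen words have no observation id in this HMM; fail with a
        // clear message instead of a NullPointerException on unboxing
        throw new IllegalArgumentException("Word [" + token + 
          "] is not in the dictionary");
      }
      observations.add(new ObservationInteger(id));
    }
    return observations;
  }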
  
  private void loadWordsFromDictionary() throws Exception {
    BufferedReader reader = new BufferedReader(
      new FileReader(dictionaryLocation));
    String word;
    int seq = 0;
    while ((word = reader.readLine()) != null) {
      words.put(word, seq);
      seq++;
    }
    reader.close();
  }
}
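The custom OpdfWriter in saveToFile exists because Double.toString switches to scientific notation for values below 10^-3. According to the comment in that code, HmmReader cannot read that notation. The "##.#####" pattern avoids the exponent, but it also rounds every probability below 0.000005 down to zero. It also depends on the default locale, which may use a comma as the decimal separator. The sketch below compares it with BigDecimal.toPlainString, which keeps every significant digit without an exponent and is locale-independent. The ProbabilityFormat class is hypothetical, and I have not checked whether Jahmm's reader accepts the longer strings.

```java
import java.math.BigDecimal;
import java.text.DecimalFormat;

/** Compares ways of writing small probabilities (illustrative sketch). */
final class ProbabilityFormat {

  /** Java's default conversion, which uses an exponent below 1e-3. */
  static String defaultFormat(double p) {
    return Double.toString(p);
  }

  /** The pattern used in saveToFile: at most 5 fraction digits. */
  static String patternFormat(double p) {
    return new DecimalFormat("##.#####").format(p);
  }

  /** Plain decimal notation that keeps all significant digits. */
  static String plainFormat(double p) {
    return BigDecimal.valueOf(p).toPlainString();
  }
}
```

For example, defaultFormat(0.0001234) gives "1.234E-4" and plainFormat(0.0001234) gives "0.0001234". By contrast, patternFormat(0.000001) loses the value entirely.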

Word Sense Disambiguation

Given a sentence, a human reader can figure out the correct POS for each word almost immediately. With an HMM, we can only tell which POS is most likely for the word, given the sequence of words in the sentence. How often that guess is right depends on how large and accurate the HMM's training data set was. Here is how the HmmTagger is called to determine the most likely POS for a word in the sentence.

  @Test
  public void testWordSenseDisambiguation() throws Exception {
    HmmTagger hmmTagger = new HmmTagger();
    hmmTagger.setDataDir("/opt/brown-2.0");
    hmmTagger.setDictionaryLocation("src/test/resources/brown_dict.txt");
    hmmTagger.setHmmFileName("src/test/resources/hmm_tagger.dat");
    Hmm<ObservationInteger> hmm = 
      hmmTagger.buildFromHmmFile();
    String[] testSentences = new String[] {
      "The dog ran after the cat .",
      "Failure dogs his path .",
      "The cold steel cuts through the flesh .",
      "He had a bad cold .",
      "He will catch the ball .",
      "Salmon is the catch of the day ."
    };
    String[] testWords = new String[] {
      "dog",
      "dogs",
      "cold",
      "cold",
      "catch",
      "catch"
    };
    for (int i = 0; i < testSentences.length; i++) {
      System.out.println("Original sentence: " + testSentences[i]);
      Pos wordPos = hmmTagger.getMostLikelyPos(testWords[i], 
        testSentences[i], hmm); 
      System.out.println("Pos(" + testWords[i] + ")=" + wordPos);
    }
  }

And here are the results. As you can see, the HMM did well on all but the second sentence.

Original sentence: The dog ran after the cat .
Pos(dog)=NOUN

Original sentence: Failure dogs his path .
Pos(dogs)=NOUN

Original sentence: The cold steel cuts through the flesh .
Pos(cold)=ADJECTIVE

Original sentence: He had a bad cold .
Pos(cold)=NOUN

Original sentence: He will catch the ball .
Pos(catch)=VERB

Original sentence: Salmon is the catch of the day .
Pos(catch)=NOUN
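Both getMostLikelyPos and tagSentence delegate to Jahmm's ViterbiCalculator. For readers who want to see what that computes, here is a from-scratch sketch of the Viterbi algorithm over the same pi/a/b layout. It is an illustration, not Jahmm's implementation. It works in log space so that long sentences do not underflow.

```java
/** Viterbi decoding in log space for a discrete HMM (illustrative sketch). */
final class ViterbiSketch {

  /** Returns the most likely hidden state index for each observation. */
  static int[] decode(double[] pi, double[][] a, double[][] b, int[] obs) {
    int n = pi.length, len = obs.length;
    double[][] delta = new double[len][n];  // best log-prob of any path ending in state j at t
    int[][] back = new int[len][n];         // predecessor state on that best path
    for (int j = 0; j < n; j++) {
      delta[0][j] = Math.log(pi[j]) + Math.log(b[j][obs[0]]);
    }
    for (int t = 1; t < len; t++) {
      for (int j = 0; j < n; j++) {
        double best = Double.NEGATIVE_INFINITY;
        int arg = 0;
        for (int i = 0; i < n; i++) {
          double score = delta[t - 1][i] + Math.log(a[i][j]);
          if (score > best) {
            best = score;
            arg = i;
          }
        }
        delta[t][j] = best + Math.log(b[j][obs[t]]);
        back[t][j] = arg;
      }
    }
    // pick the best final state, then follow the back-pointers
    int[] path = new int[len];
    for (int j = 1; j < n; j++) {
      if (delta[len - 1][j] > delta[len - 1][path[len - 1]]) {
        path[len - 1] = j;
      }
    }
    for (int t = len - 1; t > 0; t--) {
      path[t - 1] = back[t][path[t]];
    }
    return path;
  }
}
```

Unlike the forward algorithm, which sums over all paths, Viterbi keeps only the best path into each state at each step. That is why it returns a single tag sequence rather than a probability.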

POS Tagging

POS Tagging uses the same algorithm as Word Sense Disambiguation. Given an HMM trained with a sufficiently large and accurate corpus of tagged words, we can now use it to automatically tag sentences from a similar corpus. Here is the JUnit code snippet to tag the sentences we used in our previous test.

  @Test
  public void testPosTagging() throws Exception {
    HmmTagger hmmTagger = new HmmTagger();
    hmmTagger.setDataDir("/opt/brown-2.0");
    hmmTagger.setDictionaryLocation("src/test/resources/brown_dict.txt");
    hmmTagger.setHmmFileName("src/test/resources/hmm_tagger.dat");
    Hmm<ObservationInteger> hmm = hmmTagger.buildFromHmmFile();
    // POS tagging
    String[] testSentences = new String[] {...};
    for (int i = 0; i < testSentences.length; i++) {
      System.out.println("Original sentence: " + testSentences[i]);
      System.out.println("Tagged sentence: " + 
        hmmTagger.tagSentence(testSentences[i], hmm));
    }
  }

And here are the results.

Original sentence: The dog ran after the cat .
Tagged sentence: the/ADJECTIVE dog/NOUN ran/VERB after/OTHER 
  the/ADJECTIVE cat/NOUN ./OTHER 

Original sentence: Failure dogs his path .
Tagged sentence: failure/NOUN dogs/NOUN his/OTHER path/NOUN ./OTHER 

Original sentence: The cold steel cuts through the flesh .
Tagged sentence: the/ADJECTIVE cold/ADJECTIVE steel/NOUN cuts/NOUN 
  through/OTHER the/ADJECTIVE flesh/NOUN ./OTHER 

Original sentence: He had a bad cold .
Tagged sentence: he/OTHER had/VERB a/ADJECTIVE bad/ADJECTIVE cold/NOUN 
  ./OTHER 

Original sentence: He will catch the ball .
Tagged sentence: he/OTHER will/VERB catch/VERB the/ADJECTIVE ball/NOUN 
  ./OTHER 

Original sentence: Salmon is the catch of the day .
Tagged sentence: salmon/NOUN is/VERB the/ADJECTIVE catch/NOUN of/OTHER 
  the/ADJECTIVE day/NOUN ./OTHER 

Sentence Likelihood

HMMs can be used to predict whether one sentence is more likely to occur than another, by comparing the observation probability of one sequence of words with that of another. For example, we find that the HMM believes sentences spoken by Master Yoda of Star Wars fame are less likely to occur in "normal" English than sentences with a similar meaning that you or I would speak.

  @Test
  public void testObservationProbability() throws Exception {
    HmmTagger hmmTagger = new HmmTagger();
    hmmTagger.setDataDir("/opt/brown-2.0");
    hmmTagger.setDictionaryLocation("src/test/resources/brown_dict.txt");
    hmmTagger.setHmmFileName("src/test/resources/hmm_tagger.dat");
    Hmm<ObservationInteger> hmm = hmmTagger.buildFromHmmFile();
    System.out.println("P(I am worried)=" + 
      hmmTagger.getObservationProbability("I am worried", hmm));
    System.out.println("P(Worried I am)=" +  
      hmmTagger.getObservationProbability("Worried I am", hmm));
  }

As expected, the results indicate that the HMM understands us better than it understands Master Yoda: it rates "I am worried" roughly 4.3 times as likely as "Worried I am".

P(I am worried)=5.446081633660202E-11
P(Worried I am)=1.2623833954125002E-11

Unsupervised Learning

The final problem we can solve with an HMM is to learn one from a set of untagged data. The resulting HMM can then be used for solving the Sentence Likelihood problem, but not the POS Tagging or the WSD problems, since its hidden states are not tied to any tag set. To set this up, I picked up a bunch of Yoda quotes from this page and fed them into a newly instantiated HMM. I then took the same two sentences and asked the HMM which was more probable. Here is the test code snippet to do that:
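Learning an HMM from untagged sentences is normally done with the Baum-Welch algorithm, an expectation-maximization procedure that uses the forward and backward probabilities to compute expected state and transition counts, then re-estimates the parameters from those counts. I am assuming, not confirming, that `teach()` relies on something similar. The following is a hypothetical, self-contained Java sketch of one Baum-Welch iteration for a discrete HMM; `BaumWelchSketch`, the toy symbol sequences and the starting parameters are all made up, and it does no scaling, so it is only suitable for short sequences.

```java
import java.util.Arrays;
import java.util.List;

/**
 * One Baum-Welch (EM) re-estimation step for a discrete HMM learned from unlabeled sequences.
 * Hypothetical, unscaled sketch for short toy sequences; not Jahmm's or HmmTagger's code.
 */
final class BaumWelchSketch {
  /** pi[i] = P(first state i), a[i][j] = P(state j | state i), b[i][k] = P(symbol k | state i). */
  static final class Model {
    final double[] pi;
    final double[][] a;
    final double[][] b;

    Model(double[] pi, double[][] a, double[][] b) {
      this.pi = pi;
      this.a = a;
      this.b = b;
    }
  }

  /** alpha[t][i] = P(o_0..o_t, state i at t). */
  static double[][] forward(Model m, int[] o) {
    int n = m.pi.length;
    double[][] alpha = new double[o.length][n];
    for (int i = 0; i < n; i++) alpha[0][i] = m.pi[i] * m.b[i][o[0]];
    for (int t = 1; t < o.length; t++) {
      for (int j = 0; j < n; j++) {
        double s = 0.0;
        for (int i = 0; i < n; i++) s += alpha[t - 1][i] * m.a[i][j];
        alpha[t][j] = s * m.b[j][o[t]];
      }
    }
    return alpha;
  }

  /** beta[t][i] = P(o_{t+1}..o_end | state i at t). */
  static double[][] backward(Model m, int[] o) {
    int n = m.pi.length, last = o.length - 1;
    double[][] beta = new double[o.length][n];
    Arrays.fill(beta[last], 1.0);
    for (int t = last - 1; t >= 0; t--) {
      for (int i = 0; i < n; i++) {
        double s = 0.0;
        for (int j = 0; j < n; j++) s += m.a[i][j] * m.b[j][o[t + 1]] * beta[t + 1][j];
        beta[t][i] = s;
      }
    }
    return beta;
  }

  /** Sum over all sequences of log P(o | m). */
  static double totalLogLikelihood(Model m, List<int[]> seqs) {
    double total = 0.0;
    for (int[] o : seqs) total += Math.log(Arrays.stream(forward(m, o)[o.length - 1]).sum());
    return total;
  }

  /** Re-estimates pi, a and b from expected counts gamma (states) and xi (transitions). */
  static Model iterate(Model m, List<int[]> seqs) {
    int n = m.pi.length, k = m.b[0].length;
    double[] pi = new double[n], aDen = new double[n], bDen = new double[n];
    double[][] aNum = new double[n][n], bNum = new double[n][k];
    for (int[] o : seqs) {
      double[][] alpha = forward(m, o), beta = backward(m, o);
      double p = Arrays.stream(alpha[o.length - 1]).sum(); // P(o | m)
      for (int t = 0; t < o.length; t++) {
        for (int i = 0; i < n; i++) {
          double gamma = alpha[t][i] * beta[t][i] / p; // P(state i at t | o)
          if (t == 0) pi[i] += gamma / seqs.size();
          bNum[i][o[t]] += gamma;
          bDen[i] += gamma;
          if (t + 1 < o.length) {
            aDen[i] += gamma;
            for (int j = 0; j < n; j++) { // xi = P(state i at t, state j at t+1 | o)
              aNum[i][j] += alpha[t][i] * m.a[i][j] * m.b[j][o[t + 1]] * beta[t + 1][j] / p;
            }
          }
        }
      }
    }
    double[][] a = new double[n][n], b = new double[n][k];
    for (int i = 0; i < n; i++) { // keep the old row if a state got no expected counts
      for (int j = 0; j < n; j++) a[i][j] = aDen[i] > 0 ? aNum[i][j] / aDen[i] : m.a[i][j];
      for (int s = 0; s < k; s++) b[i][s] = bDen[i] > 0 ? bNum[i][s] / bDen[i] : m.b[i][s];
    }
    return new Model(pi, a, b);
  }

  public static void main(String[] args) {
    // Toy "sentences" over a 3-symbol vocabulary; real input would be word indices.
    List<int[]> data = Arrays.asList(new int[] {0, 1, 2}, new int[] {2, 1, 0},
        new int[] {0, 0, 1}, new int[] {1, 2, 2});
    // Asymmetric start: if every state began identical, EM would keep them identical.
    Model m = new Model(new double[] {0.6, 0.4}, new double[][] {{0.7, 0.3}, {0.4, 0.6}},
        new double[][] {{0.5, 0.3, 0.2}, {0.1, 0.4, 0.5}});
    for (int iter = 0; iter <= 20; iter++) {
      if (iter % 5 == 0) System.out.println(iter + ": total log-likelihood = " + totalLogLikelihood(m, data));
      m = iterate(m, data);
    }
  }
}
```

Each call to `iterate` can only keep or raise the total log-likelihood that `main` prints every five iterations, but EM converges to a local optimum that depends on the starting guess. The learned states are also anonymous: nothing forces state 0 to mean NOUN, which is why a model learned this way can score sentences but cannot tag them.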

  @Test
  public void testTeachYodaAndObserveProbabilities() throws Exception {
    List<String> sentences = Arrays.asList(new String[] {
      "Powerful you have become .",
      "The dark side I sense in you .",
      "Grave danger you are in .",
      "Impatient you are .",
      "Try not .",
      "Do or do not there is no try .",
      "Consume you the dark path will .",
      "Always in motion is the future .",
      "Around the survivors a perimeter create .",
      "Size matters not .",
      "Blind we are if creation of this army we could not see .",
      "Help you I can yes .",
      "Strong am I with the force .",
      "Agree with you the council does .",
      "Your apprentice he will be .",
      "Worried I am .",
      "Always two there are .",
      "When 900 years you reach look as good you will not ."
    });
    HmmTagger hmmTagger = new HmmTagger();
    hmmTagger.setDataDir("/opt/brown-2.0");
    hmmTagger.setDictionaryLocation("src/test/resources/brown_dict.txt");
    Hmm<ObservationInteger> learnedHmm = hmmTagger.teach(sentences);
    System.out.println("P(Worried I am)=" +  
      hmmTagger.getObservationProbability("Worried I am", learnedHmm));
    System.out.println("P(I am worried)=" + 
      hmmTagger.getObservationProbability("I am worried", learnedHmm));
  }

Now, as you can see, this new HMM understands Yoda better than it understands us: it rates "Worried I am" about 1.8 times as likely as "I am worried" :-).

P(Worried I am)=4.455273233553778E-6
P(I am worried)=2.4569521508568634E-6

Conclusions

Personally, I found the learning curve quite steep. The theory was fairly easy to grasp intuitively, but understanding how to model POS tagging as an HMM took me a while. Once I crossed that hurdle, it took a fair bit of effort to figure out how to use Jahmm to build and solve an HMM.

I think it was worth it, though. HMMs are a very powerful modeling tool for text mining, and can be used to model a variety of real-life situations. Using a library such as Jahmm means that you just have to figure out how to model your problem, and then solve it with the tools provided.

Hopefully, if you've read this far and started out not knowing, or with only a vague idea of, what an HMM was and how it could be used for POS tagging (as was my situation a couple of months ago), this post has provided some useful information, as well as an example of using the Jahmm API to build and solve an HMM.

Update 2009-04-26: In recent posts, I have been building on code written and described in previous posts, so there were (and rightly so) quite a few requests for the code. So I've created a project on Sourceforge to host the code. You will find the complete source code built so far in the project's SVN repository.