Saturday, March 30, 2013

An HMM-based Gene Tagger using NLTK


In Prof. Michael Collins's Natural Language Processing class on Coursera, the first programming assignment consisted of building a Gene Named Entity Recognizer using a Hidden Markov Model. The training set consists of 13,795 sentences in which words representing genes are marked with an I tag and all other words with an O tag. Also provided are a tagged validation set of 509 sentences to test the algorithm against, and an (untagged) test set of 510 sentences to tag.

While the Coursera Honor Code prevents students from sharing solutions (including programming assignments), this post does not qualify as a solution. First, the assignment expressly forbids the use of NLTK, since its objective is to implement the Viterbi algorithm from scratch and to experiment with different tweaks to improve it. Second, I wasn't able to implement the transition probabilities for trigram backoff properly, although the problem is very likely my imperfect understanding of the NLTK API. So you are better off doing the assignment from scratch as required.

However, I think that a lot of stuff in the NLP/ML domain is applied NLP/ML, so knowing how to solve something with an existing library/toolkit is as important as knowing (and being able to implement) the algorithm itself. I thought it would be interesting to use NLTK to build the tagger, hence this post.

Interestingly, the Masters in Language Technology program at the University of Gothenburg, Sweden, uses NLTK for its lab work, and one of these labs provided me with insights into how to use NLTK's HMM package. It's good to know that there are others who share my opinion :-).

An HMM is defined by three probability distributions: the prior probabilities, which describe the probability of seeing each tag in the data; the transition probabilities, which define the probability of seeing a tag conditioned on the previous tag; and the emission probabilities, which define the probability of seeing a word conditioned on a tag. An example may make this clearer.
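
As a toy illustration (the tags, words and probabilities below are made up, not taken from the assignment data), here is a minimal sketch of how these three distributions map onto NLTK's HiddenMarkovModelTagger constructor, the same constructor the tagger code further down uses.

import nltk

# states (tags) and symbols (words) for a made-up two-tag example
states = ["I", "O"]
symbols = ["BRCA1", "protein", "the"]

# priors: P(tag)
priors = nltk.DictionaryProbDist({"I": 0.1, "O": 0.9})

# transitions: P(tag | previous tag)
transitions = nltk.DictionaryConditionalProbDist({
  "I": nltk.DictionaryProbDist({"I": 0.4, "O": 0.6}),
  "O": nltk.DictionaryProbDist({"I": 0.1, "O": 0.9})})

# emissions: P(word | tag)
outputs = nltk.DictionaryConditionalProbDist({
  "I": nltk.DictionaryProbDist({"BRCA1": 0.8, "protein": 0.15, "the": 0.05}),
  "O": nltk.DictionaryProbDist({"BRCA1": 0.05, "protein": 0.25, "the": 0.7})})

hmm = nltk.HiddenMarkovModelTagger(symbols, states, transitions, outputs, priors)
print(hmm.best_path(["the", "BRCA1", "protein"]))  # most likely tag sequence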

The assignment starts off by asking us to construct a simple HMM model for the training data, measure its performance using the validation data, and finally apply the model to predict the I/O tags for the test data. We then identify rare words (those occurring fewer than 5 times in the training data), replace them with the string "_RARE_", and measure the performance again. The next step is to classify the rare words into types, such as numerics, all caps, etc., replace them with "_RARE_xx_" strings, and measure the performance once again. Finally, the model is modified to use transition probabilities conditioned on the two previous tags instead of one, i.e., the trigram model.
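
To make the rare word handling concrete, here is a minimal sketch of the idea on a couple of made-up tagged sentences (the full implementation in the tagger code below does the same thing over the training file):

import nltk

# toy tagged sentences standing in for the training data
train_sents = [[("BRCA1", "I"), ("is", "O"), ("a", "O"), ("gene", "O")],
               [("the", "O"), ("p53", "I"), ("protein", "O")]]

# count word frequencies and flag words seen fewer than 5 times as rare
wordsFD = nltk.FreqDist(w.lower() for sent in train_sents for (w, t) in sent)
rareWords = set(w for w in wordsFD.keys() if wordsFD[w] < 5)

# replace rare words with a placeholder before training the HMM
normalized = [[("_RARE_" if w.lower() in rareWords else w, t) for (w, t) in sent]
              for sent in train_sents]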

The last part is where I had problems. My initial approach was to convert everything to bigrams, so that the priors became the probabilities of seeing the different tag pairs in the data; the transition probabilities became the probability of seeing a tag bigram (t2,t3) conditioned on the previous tag bigram (t1,t2); and the emission probabilities became the probability of seeing a word bigram (w1,w2) conditioned on the corresponding tag bigram (t1,t2). Mikhail Korobov, one of my classmates in Prof. Collins's course, was kind enough to point me to this paper (PDF), which uses a similar approach.
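
To make the bigram-state idea concrete, here is a small sketch, on a made-up padded sentence, of how the (word, tag) unigrams are turned into the word-pair symbols and tag-pair states that the HMM actually sees:

import nltk

# a made-up padded sentence of (word, tag) unigrams
sent = [("<*>", "<*>"), ("<*>", "<*>"),
        ("the", "O"), ("BRCA1", "I"), ("gene", "O"), ("<$>", "<$>")]

# each HMM symbol becomes a word pair and each state a tag pair
for ((w1, t1), (w2, t2)) in nltk.bigrams(sent):
  print("symbol=%s state=%s" % ((w1, w2), (t1, t2)))
# symbol=('<*>', '<*>') state=('<*>', '<*>')
# symbol=('<*>', 'the') state=('<*>', 'O')
# ... and so on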

Because the transition probabilities are now much sparser, the precision and recall numbers went down pretty drastically. I then attempted to build a conditional probability distribution for the transition probabilities that used a backoff model, i.e., if the MLE for the trigram probability P(t3|t1,t2) was not available, fall back to the MLE for the bigram probability P(t3|t2), and finally to the MLE for the unigram probability N(t3)/N. Unfortunately, this made all the predicted tag bigrams look like (I,O), so the model is now about as good as a (binary) broken clock.

Here is the code for the tagger (also available on my GitHub). I wrote it incrementally, starting with the baseline model and adding the functionality required by each part of the assignment; each addition is turned on with a new key=value command-line parameter. Each successive key also turns on all previous keys (see the main method to see what I mean), and the key names are used to identify the results below.

from __future__ import division

import itertools
import sys

import collections
import nltk
import nltk.probability
import numpy as np

######################## utility methods ###########################

def findRareWords(train_file):
  """
  Extra pass through the training file to identify rare
  words and deal with them in the accumulation process.
  """
  wordsFD = nltk.FreqDist()
  reader = nltk.corpus.reader.TaggedCorpusReader(".", train_file)
  for word in reader.words():
    wordsFD.inc(word.lower())
  return set(filter(lambda x: wordsFD[x] < 5, wordsFD.keys()))

def normalizeRareWord(word, rareWords, replaceRare):
  """
  Introduce a bit of variety. Even though they are all rare,
  we classify rare words into "kinds" of rare words.
  """
  if word.lower() in rareWords:
    if replaceRare:
      if any(c.isdigit() for c in word):
        return "_RARE_NUMERIC_"
      elif word.upper() == word:
        return "_RARE_ALLCAPS_"
      elif word[-1:].isupper():
        return "_RARE_LASTCAP_"
      else:
        return "_RARE_"
    else:
      return "_RARE_"
  else:
    return word

def pad(sent, tags=True):
  """
  Pad sentences with the start and stop tags and return padded
  sentence.
  """
  if tags:
    padded = [("<*>", "<*>"), ("<*>", "<*>")]
  else:
    padded = ["<*>", "<*>"]
  padded.extend(sent)
  if tags:
    padded.append(("<$>", "<$>"))
  else:
    padded.append("<$>")
  return padded

def calculateMetrics(actual, predicted):
  """
  Returns the number of cases where prediction and actual NER
  tags are the same, divided by the number of tags for the
  sentence.
  """
  pred_p = map(lambda x: "I" if x == "I" else "O", predicted)
  cm = nltk.metrics.confusionmatrix.ConfusionMatrix(actual, pred_p)
  keys = ["I", "O"]
  metrics = np.matrix(np.zeros((2, 2)))
  for x, y in [(x, y) for x in range(0, 2)
                      for y in range(0, 2)]:
    try:
      metrics[x, y] = cm[keys[x], keys[y]]
    except KeyError:
      pass
  tp = metrics[0, 0]
  tn = metrics[1, 1]
  fp = metrics[0, 1]
  fn = metrics[1, 0]
  precision = 0 if (tp + fp) == 0 else tp / (tp + fp)
  recall = 0 if (tp + fn) == 0 else tp / (tp + fn)
  fmeasure = (0 if (precision + recall) == 0
    else (2 * precision * recall) / (precision + recall))
  accuracy = (tp + tn) / (tp + tn + fp + fn)
  return (precision, recall, fmeasure, accuracy)

def writeResult(fout, hmm, words):
  """
  Writes out the result in the required format.
  """
  tags = hmm.best_path(words)
  for word, tag in zip(words, tags)[2:-1]:
    fout.write("%s %s\n" % (word, tag))
  fout.write("\n")

def bigramToUnigram(bigrams):
  """
  Convert a list of bigrams to the equivalent unigram list.
  """
  unigrams = [bigrams[0][0]]
  unigrams.extend([x[1] for x in bigrams])
  return unigrams

def calculateBackoffTransCPD(tagsFD, transCFD, trans2CFD):
  """
  Uses a backoff model to calculate a smoothed conditional
  probability distribution on the training data.
  """
  probDistDict = collections.defaultdict(nltk.DictionaryProbDist)
  tags = tagsFD.keys()
  conds = [x for x in itertools.permutations(tags, 2)]
  for tag in tags:
    conds.append((tag, tag))
  for (t1, t2) in conds:
    probDict = collections.defaultdict(float)
    prob = 0
    for t3 in tags:
      trigramsFD = trans2CFD[(t1, t2)]
      if trigramsFD.N() > 0 and trigramsFD.freq(t3) > 0:
        prob = trigramsFD.freq(t3) / trigramsFD.N()
      else:
        bigramsFD = transCFD[t2]
        if bigramsFD.N() > 0 and bigramsFD.freq(t3) > 0:
          prob = bigramsFD.freq(t3) / bigramsFD.N()
        else:
          prob = tagsFD[t3] / tagsFD.N()
      probDict[t3] = prob
    probDistDict[(t1, t2)] = nltk.DictionaryProbDist(probDict)
  return nltk.DictionaryConditionalProbDist(probDistDict)

class Accumulator:
  """
  Convenience class to accumulate all the frequencies
  into a set of data structures.
  """
  def __init__(self, rareWords, replaceRare, useTrigrams):
    self.rareWords = rareWords
    self.replaceRare = replaceRare
    self.useTrigrams = useTrigrams
    self.words = set()
    self.tags = set()
    self.priorsFD = nltk.FreqDist()
    self.transitionsCFD = nltk.ConditionalFreqDist()
    self.outputsCFD = nltk.ConditionalFreqDist()
    # additional data structures for trigram
    self.transitions2CFD = nltk.ConditionalFreqDist()
    self.tagsFD = nltk.FreqDist()

  def addSentence(self, sent, norm_func):
    # preprocess
    unigrams = [(norm_func(word, self.rareWords, self.replaceRare), tag)
      for (word, tag) in sent]
    prevTag = None
    prev2Tag = None
    if self.useTrigrams:
      # each state is represented by a tag bigram
      bigrams = nltk.bigrams(unigrams)
      for ((w1, t1), (w2, t2)) in bigrams:
        self.words.add((w1, w2))
        self.tags.add((t1, t2))
        self.priorsFD.inc((t1, t2))
        self.outputsCFD[(t1, t2)].inc((w1, w2))
        if prevTag is not None:
          self.transitionsCFD[prevTag].inc(t2)
        if prev2Tag is not None:
          self.transitions2CFD[prev2Tag].inc((t1, t2))
        prevTag = t2
        prev2Tag = (t1, t2)
        self.tagsFD.inc(prevTag)
    else:
      # each state is represented by a tag unigram
      for word, tag in unigrams:
        self.words.add(word)
        self.tags.add(tag)
        self.priorsFD.inc(tag)
        self.outputsCFD[tag].inc(word)
        if prevTag is not None:
          self.transitionsCFD[prevTag].inc(tag)
        prevTag = tag

####################### train, validate, test ##################

def train(train_file, 
    rareWords, replaceRare, useTrigrams, trigramBackoff):
  """
  Read the file and populate the various frequency and
  conditional frequency distributions and build the HMM
  off these data structures.
  """
  acc = Accumulator(rareWords, replaceRare, useTrigrams)
  reader = nltk.corpus.reader.TaggedCorpusReader(".", train_file)
  for sent in reader.tagged_sents():
    unigrams = pad(sent)
    acc.addSentence(unigrams, normalizeRareWord)
  if useTrigrams:
    if trigramBackoff:
      backoffCPD = calculateBackoffTransCPD(acc.tagsFD, acc.transitionsCFD,
        acc.transitions2CFD)
      return nltk.HiddenMarkovModelTagger(list(acc.words), list(acc.tags),
        backoffCPD,
        nltk.ConditionalProbDist(acc.outputsCFD, nltk.ELEProbDist),
        nltk.ELEProbDist(acc.priorsFD))
    else:
      return nltk.HiddenMarkovModelTagger(list(acc.words), list(acc.tags),
        nltk.ConditionalProbDist(acc.transitions2CFD, nltk.ELEProbDist,
        len(acc.transitions2CFD.conditions())),
        nltk.ConditionalProbDist(acc.outputsCFD, nltk.ELEProbDist),
        nltk.ELEProbDist(acc.priorsFD))
  else:
    return nltk.HiddenMarkovModelTagger(list(acc.words), list(acc.tags),
      nltk.ConditionalProbDist(acc.transitionsCFD, nltk.ELEProbDist,
      len(acc.transitionsCFD.conditions())),
      nltk.ConditionalProbDist(acc.outputsCFD, nltk.ELEProbDist),
      nltk.ELEProbDist(acc.priorsFD))

def validate(hmm, validation_file, rareWords, replaceRare, useTrigrams):
  """
  Tests the HMM against the validation file.
  """
  precision = 0
  recall = 0
  fmeasure = 0
  accuracy = 0
  nSents = 0
  reader = nltk.corpus.reader.TaggedCorpusReader(".", validation_file)
  for sent in reader.tagged_sents():
    sent = pad(sent)
    words = [word for (word, tag) in sent]
    tags = [tag for (word, tag) in sent]
    normWords = map(lambda x: normalizeRareWord(
      x, rareWords, replaceRare), words)
    if useTrigrams:
      # convert words to word bigrams
      normWords = nltk.bigrams(normWords)
    predictedTags = hmm.best_path(normWords)
    if useTrigrams:
      # convert tag bigrams back to unigrams
      predictedTags = bigramToUnigram(predictedTags)
    (p, r, f, a) = calculateMetrics(tags[2:-1], predictedTags[2:-1])
    precision += p
    recall += r
    fmeasure += f
    accuracy += a
    nSents += 1
  print("Accuracy=%f, Precision=%f, Recall=%f, F1-Measure=%f\n" %
    (accuracy/nSents, precision/nSents, recall/nSents,
    fmeasure/nSents))

def test(hmm, test_file, result_file, rareWords, replaceRare, useTrigrams):
  """
  Tests the HMM against the test file (without tags) and writes
  out the results to the result file.
  """
  fin = open(test_file, 'rb')
  fout = open(result_file, 'wb')
  for line in fin:
    line = line.strip()
    words = pad([word for word in line.split(" ")], tags=False)
    normWords = map(lambda x: normalizeRareWord(
      x, rareWords, replaceRare), words)
    if useTrigrams:
      # convert words to word bigrams
      normWords = nltk.bigrams(normWords)
    tags = hmm.best_path(normWords)
    if useTrigrams:
      # convert tag bigrams back to unigrams
      tags = bigramToUnigram(tags)
    fout.write(" ".join(["/".join([word, tag])
      for (word, tag) in (zip(words, tags))[2:-1]]) + "\n")
  fin.close()
  fout.close()

def main():
  normalizeRare = False
  replaceRare = False
  useTrigrams = False
  trigramBackoff = False
  if len(sys.argv) > 1:
    args = sys.argv[1:]
    for arg in args:
      k, v = arg.split("=")
      if k == "normalize-rare":
        normalizeRare = True if v.lower() == "true" else False
      elif k == "replace-rare":
        normalizeRare = True
        replaceRare = True if v.lower() == "true" else False
      elif k == "use-trigrams":
        normalizeRare = True
        replaceRare = True
        useTrigrams = True if v.lower() == "true" else False
      elif k == "trigram-backoff":
        normalizeRare = True
        replaceRare = True
        useTrigrams = True
        trigramBackoff = True if v.lower() == "true" else False
      else:
        continue
  rareWords = set()
  if normalizeRare:
    rareWords = findRareWords("gene.train")
  hmm = train("gene.train",
    rareWords, replaceRare, useTrigrams, trigramBackoff)
  validate(hmm, "gene.validate", rareWords, replaceRare, useTrigrams)
  test(hmm, "gene.test", "gene.test.out",
    rareWords, replaceRare, useTrigrams)
  
if __name__ == "__main__":
  main()
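
Assuming the code above is saved as, say, hmm_gene_tagger.py (the filename is mine, not something the assignment specifies), the different configurations reported below can be run like this:

python hmm_gene_tagger.py
python hmm_gene_tagger.py normalize-rare=true
python hmm_gene_tagger.py replace-rare=true
python hmm_gene_tagger.py use-trigrams=true
python hmm_gene_tagger.py trigram-backoff=true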

And here are the results.

Model             Accuracy   Precision   Recall     F1-Measure
(baseline)        0.941894   0.337643    0.346306   0.324261
normalize-rare    0.940066   0.319570    0.333753   0.309895
replace-rare      0.940960   0.319570    0.335538   0.311532
use-trigrams      0.901701   0.001742    0.011788   0.002912
trigram-backoff   0.902182   0.000000    0.000000   0.000000

As you can see, the best results in this case seem to come from the baseline bigram model. Prior to this, I had built an NER by modeling it as a classification problem. However, the use of HMMs for this problem domain looks fairly common, as we can see from the papers here (PDF) and here (PDF). Another resource I came across that I want to explore further is NLTK's TnT (Trigrams'n'Tags) tagger, which seems to be a better fit for the trigram-backoff model.
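
For reference, here is a rough sketch of what training NLTK's TnT tagger on the same data might look like; I haven't run this against the assignment data, and TnT tags unknown words as 'Unk' unless it is given a separate unknown-word tagger, so the rare word handling above would still be needed in some form.

import nltk
from nltk.tag import tnt

# read the same training and validation files used above
train_reader = nltk.corpus.reader.TaggedCorpusReader(".", "gene.train")
val_reader = nltk.corpus.reader.TaggedCorpusReader(".", "gene.validate")

# TnT builds the trigram transition model (with its own smoothing) internally
tagger = tnt.TnT()
tagger.train(train_reader.tagged_sents())

# tag a few validation sentences to eyeball the output
for sent in val_reader.tagged_sents()[:5]:
  words = [w for (w, t) in sent]
  print(tagger.tag(words))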
