Thursday, March 27, 2014

Parsing Drug Dosages in text using Finite State Machines


Someone recently pointed out an issue with the Drug Dosage FSM in Apache cTakes on the cTakes mailing list. Looking at the code revealed a fairly complex implementation based on a hierarchy of Finite State Machines (FSMs). The intuition behind the implementation is that drug dosage text in doctors' notes tends to follow a standard-ish format, and FSMs can be used to exploit this structure and pull relevant entities out of the text. The paper Extracting Structured Medication Event Information from Discharge Summaries has more information about this problem. The authors provide their own solution, called the Merki Medication Parser. Here is a link to their Online Demo and source code (Perl).

I've never used FSMs myself, although I have seen them used to model (more structured) systems. So the idea of using FSMs to parse semi-structured text such as this seemed interesting, and I decided to try it out myself. The implementation I describe here is nowhere near as complex as the one in cTakes, but on the flip side, it is neither as accurate, as broad, nor as bulletproof.

My solution uses the drug dosage phrase data provided in this Pattern Matching article by Erin Rhode (which also comes with a Perl-based solution), as well as its dictionaries (with additions by me), to model the phrases with the state diagram below. I built the diagram by eyeballing the outputs from Erin Rhode's program. I then implemented the state diagram with a home-grown FSM based on ideas from Electric Monk's post on FSMs in Python and the documentation for the Java library Tungsten FSM. I initially tried to use Tungsten FSM, but ended up with extremely verbose Scala code because of Scala's stricter generics system.


The FSM implementation is described below. It is built up by the client adding all the states (the nodes in the diagram above) and transitions (the edges connecting the nodes). Each transition specifies a Guard, essentially a predicate that dictates whether the machine can transition to the target state, and an Action that defines what to do when the transition is attempted. For each input text to be parsed, the client constructs a new FSM, then passes a list of tokens to the run() method, which attempts a transition for each token in the list.

// Source: src/main/scala/com/mycompany/scalcium/utils/FSM.scala
package com.mycompany.scalcium.utils

import scala.collection.mutable.ArrayBuffer

trait Guard[T] {
  def accept(token: T): Boolean
}

trait Action[T] {
  def perform(currState: String, token: T): Unit
}

class FSM[T](val debug: Boolean = false) {

  val states = ArrayBuffer[String]()
  val transitions = scala.collection.mutable.Map[
    String,ArrayBuffer[(String,Guard[T],Action[T])]]()
  var currState: String = "START"
  
  def addState(state: String): Unit = {
    states += state
  }
  
  def addTransition(from: String, to: String,
      guard: Guard[T], action: Action[T]): Unit = {
    val translist = transitions.getOrElse(from, ArrayBuffer()) 
    translist += ((to, guard, action))
    transitions.put(from, translist)
  } 
  
  def transition(token: T): Unit = {
    // find all transitions out of the current state whose guard
    // accepts this token
    val targetStates = transitions.getOrElse(currState, List())
      .filter(tga => tga._2.accept(token))
    if (targetStates.size == 1) {
      // exactly one acceptable target: move there and fire its action
      val targetState = targetStates.head
      if (debug)
        Console.println("%s -> %s".format(currState, targetState._1))
      currState = targetState._1
      targetState._3.perform(currState, token)
    } else {
      // zero or ambiguous matches: stay in the current state, but still
      // fire the first transition's action so the token is not lost
      transitions.get(currState) match {
        case Some(x) => x.head._3.perform(currState, token)
        case None => {}
      }
    }
  }
  
  def run(tokens: List[T]): Unit = tokens.foreach(transition(_))
}
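
Before getting to the real client, here is a toy usage sketch of the API (mine, not part of the project code): a two-state machine that moves from START to NUM when it sees an all-digit token.

// Toy usage sketch (hypothetical), showing how a client wires up the FSM
val digitGuard = new Guard[String] {
  override def accept(token: String): Boolean = token.forall(_.isDigit)
}
val printAction = new Action[String] {
  override def perform(currState: String, token: String): Unit =
    Console.println("collected %s in state %s".format(token, currState))
}
val fsm = new FSM[String](debug = true)
fsm.addState("START")
fsm.addState("NUM")
fsm.addTransition("START", "NUM", digitGuard, printAction)
fsm.run(List("42"))
// prints: START -> NUM
//         collected 42 in state NUM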

The client instantiates the FSM, then populates its states and transitions via calls to addState() and addTransition() respectively. The client also defines custom Guard objects. In general, a Guard on a transition from A to B checks whether the current token is acceptable for B; if so, the machine transitions to B, otherwise it remains in state A. Since there are multiple states transitioning to the END state, we want END to be the least attractive target, hence the noGuard - its accept() always returns false, so the machine never actually moves to END, and a token that no guard accepts is simply collected under the current state.

The other Guards are the DictGuard, the RegexGuard and the CombinedGuard. The DictGuard is a gazetteer, accepting a token that matches an entry in its input file. The RegexGuard accepts a token that matches any of the patterns specified in its input file, and the CombinedGuard is the union of a DictGuard and a RegexGuard.

The CollectAction captures (state, token) pairs as the FSM processes the tokens. The pairs are returned as a Map[state,List[token]] by the getSymbolTable() method of CollectAction. The action is associated with the FSM as a whole rather than with individual states, but I retained the Action-per-state idea just in case I ever need it.

// Source: src/main/scala/com/mycompany/scalcium/drugdosage/DrugDosageFSM.scala
package com.mycompany.scalcium.drugdosage

import java.io.File
import java.util.regex.Pattern

import scala.collection.TraversableOnce.flattenTraversableOnce
import scala.collection.mutable.ArrayBuffer
import scala.io.Source

import com.mycompany.scalcium.utils.Action
import com.mycompany.scalcium.utils.FSM
import com.mycompany.scalcium.utils.Guard

class BoolGuard(val refValue: Boolean) extends Guard[String] {
  override def accept(token: String): Boolean = refValue
}

class DictGuard(val file: File) extends Guard[String] {
  val words = Source.fromFile(file).getLines()
    .map(line => line.toLowerCase().split(" ").toList)
    .flatten
    .toSet
  
  override def accept(token: String): Boolean = {
    words.contains(token.toLowerCase())
  }
}

class RegexGuard(val file: File) extends Guard[String] {
  val patterns = Source.fromFile(file).getLines()
    .map(line => Pattern.compile(line))
    .toList
  
  override def accept(token: String): Boolean =
    patterns.exists(pattern => pattern.matcher(token).matches())
}

class CombinedGuard(val file: File, val pfile: File) 
      extends Guard[String] {
  val words = Source.fromFile(file).getLines()
    .map(line => line.toLowerCase().split(" ").toList)
    .flatten
    .toSet
  val patterns = Source.fromFile(pfile).getLines()
    .map(line => Pattern.compile(line))
    .toList
    
  override def accept(token: String): Boolean = {
    acceptWord(token) || acceptPattern(token)
  }
  
  def acceptWord(token: String): Boolean = 
    words.contains(token.toLowerCase())
    
  def acceptPattern(token: String): Boolean =
    patterns.exists(pattern => pattern.matcher(token).matches())
}

class CollectAction(val debug: Boolean) extends Action[String] {
  val stab = new ArrayBuffer[(String,String)]()
  
  override def perform(currState: String, token: String): Unit = {
    if (debug)
      Console.println("setting: %s to %s".format(token, currState))
    stab += ((currState, token))
  }
  
  def getSymbolTable(): Map[String,List[String]] = {
    stab.groupBy(kv => kv._1)
      .map(kv => (kv._1, kv._2.map(_._2).toList))
  }
}

class DrugDosageFSM(val drugFile: File, 
    val freqFile: File, 
    val routeFile: File,
    val unitsFile: File,
    val numPatternsFile: File,
    val debug: Boolean = false) {

  
  def parse(s: String): Map[String,List[String]] = {
    val collector = new CollectAction(debug)
    val fsm = buildFSM(collector, debug)
    fsm.run(s.toLowerCase()
        .replaceAll("[,;]", " ")
        .replaceAll("\\s+", " ")
        .split(" ")
        .toList)
    collector.getSymbolTable()
  }
  
  def buildFSM(collector: CollectAction, debug: Boolean): FSM[String] = {
    val fsm = new FSM[String](debug)
    // states
    fsm.addState("START")
    fsm.addState("DRUG")
    fsm.addState("DOSAGE")
    fsm.addState("ROUTE")
    fsm.addState("FREQ")
    fsm.addState("QTY")
    fsm.addState("REFILL")
    fsm.addState("END")
  
    val noGuard = new BoolGuard(false)
    val drugGuard = new DictGuard(drugFile)
    val dosageGuard = new CombinedGuard(unitsFile, numPatternsFile)
    val freqGuard = new DictGuard(freqFile)
    val qtyGuard = new RegexGuard(numPatternsFile)
    val routeGuard = new DictGuard(routeFile)
    val refillGuard = new CombinedGuard(unitsFile, numPatternsFile)
    
    // transitions
    fsm.addTransition("START", "DRUG", drugGuard, collector)
    fsm.addTransition("DRUG", "DOSAGE", dosageGuard, collector)
    fsm.addTransition("DRUG", "FREQ", freqGuard, collector)
    fsm.addTransition("DRUG", "ROUTE", routeGuard, collector)
    fsm.addTransition("DOSAGE", "ROUTE", routeGuard, collector)
    fsm.addTransition("DOSAGE", "FREQ", freqGuard, collector)
    fsm.addTransition("ROUTE", "FREQ", freqGuard, collector)
    fsm.addTransition("FREQ", "QTY", qtyGuard, collector)
    fsm.addTransition("FREQ", "END", noGuard, collector)
    fsm.addTransition("QTY", "REFILL", refillGuard, collector)
    fsm.addTransition("QTY", "END", noGuard, collector)
    fsm.addTransition("REFILL", "END", noGuard, collector)
    
    fsm
  }
}  

We call this client using the JUnit test code below:

// Source: src/test/scala/com/mycompany/scalcium/drugdosage/DrugDosageFSMTest.scala
package com.mycompany.scalcium.drugdosage

import org.junit.Test
import java.io.File
import scala.io.Source
import java.io.PrintWriter
import java.io.FileWriter

class DrugDosageFSMTest {

  val datadir = "/path/to/data/directory"
    
  @Test
  def testParse(): Unit = {
    val ddFSM = new DrugDosageFSM(
        new File(datadir, "drugs.dict"),
        new File(datadir, "frequencies.dict"),
        new File(datadir, "routes.dict"),
        new File(datadir, "units.dict"),
        new File(datadir, "num_patterns.dict"),
        false)
    val writer = new PrintWriter(new FileWriter(
      new File(datadir, "fsm_output.txt")))
    Source.fromFile(new File(datadir, "input.txt"))
      .getLines()
      .foreach(line => {
         val stab = ddFSM.parse(line)
         writer.println(line)
         stab.toList
           .sortBy(kv => kv._1)
           .foreach(kv => 
             writer.println(kv._1 + ": " + kv._2.mkString(" ")))
         writer.println()
    })
    writer.flush()
    writer.close()
  }
}

As mentioned already, the input data and the entity dictionaries come from Erin Rhode's Pattern Matching article. I have made a few additions to these files, and added a new file for specifying numeric patterns called num_patterns.dict, so my version of the data files can be found on Github here.
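
For illustration, here is a small, hypothetical sketch of how RegexGuard consumes such a file - the two patterns below are made-up stand-ins, not the actual contents of num_patterns.dict:

// Hypothetical demo of RegexGuard; the patterns are stand-ins,
// not the real contents of num_patterns.dict
import java.io.{File, PrintWriter}

val f = File.createTempFile("num_patterns", ".dict")
val w = new PrintWriter(f)
w.println("\\d+")        // integers, e.g. "81", "100"
w.println("\\d+\\.\\d+") // decimals, e.g. "0.5"
w.close()

val guard = new RegexGuard(f)
guard.accept("81")  // true
guard.accept("bid") // false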

Although the output is not perfect, it doesn't look too bad. Here are the first 5 results; the full output can be viewed in the fsm_output.txt file on GitHub at the link above.

hydrocortizone cream, apply to rash bid
DRUG: hydrocortizone cream
FREQ: bid
ROUTE: apply to rash

albuterol inhaler one to two puffs bid
DOSAGE: one to two puffs
DRUG: albuterol inhaler
FREQ: bid

Enteric coated aspirin 81 mg tablets one qd
DOSAGE: 81 mg tablets one
DRUG: enteric coated aspirin
FREQ: qd

Vitamin B12 1000 mcg IM
DOSAGE: 1000 mcg
DRUG: vitamin b12
ROUTE: im

atenolol 50 mg tabs one qd, #100, one year
DOSAGE: 50 mg tabs one
DRUG: atenolol
FREQ: qd #100
QTY: one year
...

There is obviously a lot of scope for improvement here. For one thing, I suspect that my state diagram is much too simple to model the phrases. But it's a start, and I plan on exploring some other solutions as well. More when I do.

Wednesday, March 19, 2014

Using Weka to identify Smokers from Patient Discharge Data


The ability to identify smokers from the general population is important to several health studies. One such initiative was organized by the Informatics for Integrating Biology and the Bedside (i2b2) project, as described in this article. The objective is to build a model that will predict the smoking status for a previously unseen patient discharge note. It seemed like an interesting problem, so I thought I'd try my hand at it. This post describes that effort.

Input


The dataset is provided as 3 XML files, corresponding to the training, evaluation (ground-truth) and testing (submission) datasets. Since I am not participating in the challenge, I am only interested in the first two. The training set has 398 patient records and the evaluation set has 104. In both cases, each record is annotated with one of 5 smoking statuses - CURRENT SMOKER, PAST SMOKER, NON-SMOKER, SMOKER and UNKNOWN. You can get these datasets here after registering and signing an end-user NDA.

The files have the following structure.

<ROOT>
  <RECORD ID="123">
    <SMOKING STATUS="CURRENT_SMOKER"></SMOKING>
    <TEXT>
Plain text with sentences and non-sentences. Sentences are terminated
with period (.).
    </TEXT>
  </RECORD>
... more records ...
</ROOT>

Prior Work


Since the experiment is as much an attempt to educate as to entertain myself, I decided to take a look at Apache cTakes and Hitex, two popular open-source projects I know of that analyze and extract information from data in clinical pipelines, and see how they do it. From a limited scan of the code, both of them seem to use each sentence in the (multi-sentence) record for training, setting the target variable for each sentence to the SMOKING@STATUS value for the entire record. At test/prediction time, individual sentences are passed in and their classes predicted. Presumably the smoking status attributed to the full discharge note is the maximally predicted class across its sentences (other than UNKNOWN), as sketched below. Both systems use Support Vector Machine (SVM) classifiers.
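
Here is a minimal, hypothetical sketch of that record-level aggregation; predictSentence is a stand-in for the per-sentence classifier these systems train:

// Hypothetical sketch of the per-sentence voting scheme described above;
// predictSentence stands in for the per-sentence (SVM) classifier
def predictRecord(sentences: List[String],
    predictSentence: String => String): String = {
  val votes = sentences.map(predictSentence)
    .filter(_ != "UNKNOWN")      // ignore uninformative predictions
  if (votes.isEmpty) "UNKNOWN"
  else votes.groupBy(identity)   // count votes per predicted class
    .maxBy(_._2.size)._1         // maximally predicted class wins
}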

My Approach - Overview


Although the approach seemed to address the problem of extremely wide data (many more features than examples), I felt that it diluted the value of the smoking annotation, since only a few sentences in the entire record would specifically indicate the status.

So my first attempt was to build a multi-class classifier that treated each record as a training row. I experimented with various classification algorithms and various sizes of feature sets using the Weka GUI. My best model had an accuracy of 91.21% on the training data (indicating that it wasn't underfitting) and an accuracy of 71.15% on the evaluation data.

Hoping to do better, I made a second attempt that exploits the structure of the target variable: first decide smoker vs non-smoker, then send the data down one of two other classifiers that decide the type of smoker or non-smoker. Overall results were slightly better, with an accuracy of 95.48% against training data and 73.08% against evaluation data. The top-level classifier in the stack, which detects smoker vs non-smoker, did considerably better, achieving an accuracy of 96.23% against training data and 84.62% against evaluation data.

In both cases, the input XML files were read by a parser to extract the smoking status and the sentences from the text. The records were written out into ARFF files (Weka's standard input format). Here is the code for parsing the XML files and writing out ARFF files. The buildMulticlassArff() method builds the ARFF file for the first approach, and the buildSmokerNonSmokerArff() and buildSubClassifierArffs() methods are used for the second approach.

// Source: src/main/scala/com/mycompany/scalcium/smoke/ArffBuilder.scala
package com.mycompany.scalcium.smoke

import java.io.File
import java.io.PrintWriter
import java.io.FileWriter

class ArffBuilder {

  def buildMulticlassArff(xmlin: File, 
      arffout: File): Unit = {
    val preprocessor = new Preprocessor()
    val arffWriter = openWriter(arffout)
    writeHeader(arffWriter, preprocessor.multiClassTargets)
    val rootElem = scala.xml.XML.loadFile(xmlin)
    (rootElem \ "RECORD").map(record => {
      val smoking = (record \ "SMOKING" \ "@STATUS").text
      val text = (record \ "TEXT").text
        .split("\n")
        .filter(line => line.endsWith("."))
        .map(line => preprocessor.preprocess(line))
        .mkString(" ")
      writeData(arffWriter, smoking, text,
        preprocessor.multiClassTargets)
    })
    closeWriter(arffWriter)
  }
  
  def buildSmokerNonSmokerArff(xmlin: File, 
      arffout: File): Unit = {
    val preprocessor = new Preprocessor()
    val arffWriter = openWriter(arffout)
    writeHeader(arffWriter, 
      preprocessor.smokerNonSmokerTargets)
    val rootElem = scala.xml.XML.loadFile(xmlin)
    (rootElem \ "RECORD").map(record => {
      val smoking = (record \ "SMOKING" \ "@STATUS").text
      val text = (record \ "TEXT").text
        .split("\n")
        .filter(line => line.endsWith("."))
        .map(line => preprocessor.preprocess(line))
        .mkString(" ")
      writeData(arffWriter, smoking, text, 
        preprocessor.smokerNonSmokerTargets)
    })
    closeWriter(arffWriter)
  }
  
  def buildSubClassifierArffs(xmlin: File,
      smokSubArff: File, nonSmokArff: File): Unit = {
    val preprocessor = new Preprocessor()
    val smokArffWriter = openWriter(smokSubArff)
    val nonSmokArffWriter = openWriter(nonSmokArff)
    writeHeader(smokArffWriter, 
      preprocessor.smokerSubTargets)
    writeHeader(nonSmokArffWriter, 
      preprocessor.nonSmokerSubTargets)
    val rootElem = scala.xml.XML.loadFile(xmlin)
    (rootElem \ "RECORD").map(record => {
      val smoking = (record \ "SMOKING" \ "@STATUS").text
      val text = (record \ "TEXT").text
        .split("\n")
        .filter(line => line.endsWith("."))
        .map(line => preprocessor.preprocess(line))
        .mkString(" ")
      if (preprocessor.smokerNonSmokerTargets(smoking) == 1)
        writeData(smokArffWriter, smoking, text, 
          preprocessor.smokerSubTargets)
      else 
        writeData(nonSmokArffWriter, smoking, text, 
          preprocessor.nonSmokerSubTargets)
    })
    closeWriter(smokArffWriter)
    closeWriter(nonSmokArffWriter)
  }
  
  def openWriter(f: File): PrintWriter = 
    new PrintWriter(new FileWriter(f))
  
  def writeHeader(w: PrintWriter, 
      targets: Map[String,Int]): Unit = {
    val classes = targets.map(kv => kv._2)
                         .toSet
                         .toList
                         .sortWith(_ < _)
                         .mkString(",")
    w.println("@relation smoke")
    w.println()
    w.println("@attribute class {%s}".format(classes)) 
    w.println("@attribute text string")
    w.println()
    w.println("@data")
  }
  
  def writeData(w: PrintWriter, 
      smoking: String, body: String,
      targets: Map[String,Int]): Unit = {
    w.println("%d,\"%s\"".format(
      targets(smoking), body))
  }
  
  def closeWriter(w: PrintWriter): Unit = {
    w.flush()
    w.close()
  }
}

Each sentence in the text was normalized by lowercasing and removing punctuation and stopwords. The preprocessing logic is kept in a separate class because the test data needs to be preprocessed in the same way at prediction/evaluation time. Here is the code for the preprocessor.

// Source: src/main/scala/com/mycompany/scalcium/smoke/Preprocessor.scala
package com.mycompany.scalcium.smoke

import java.util.regex.Pattern
import org.apache.lucene.analysis.core.StopAnalyzer
import com.mycompany.scalcium.utils.Tokenizer

class Preprocessor {

  val punctPattern = Pattern.compile("\\p{Punct}")
  val spacePattern = Pattern.compile("\\s+")
  val classPattern = Pattern.compile("class")
  
  val stopwords = StopAnalyzer.ENGLISH_STOP_WORDS_SET
  val tokenizer = Tokenizer.getTokenizer("opennlp")

  val multiClassTargets = Map(
    ("CURRENT SMOKER", 0),
    ("NON-SMOKER", 1),
    ("PAST SMOKER", 2),
    ("SMOKER", 3),
    ("UNKNOWN", 4))
  val smokerNonSmokerTargets = Map(
    ("CURRENT SMOKER", 1),
    ("PAST SMOKER", 1),
    ("SMOKER", 1),
    ("NON-SMOKER", 0),
    ("UNKNOWN", 0))
  val smokerSubTargets = Map(
    ("CURRENT SMOKER", 0),
    ("PAST SMOKER", 1),
    ("SMOKER", 2))
  val nonSmokerSubTargets = Map(
    ("NON-SMOKER", 0),
    ("UNKNOWN", 1))

  def preprocess(sin: String): String = {
    // lowercase, strip punctuation, collapse whitespace, and rename the
    // token "class" to "clazz" (presumably to avoid a clash with the
    // ARFF class attribute)
    val sinClean = replaceAll(classPattern, "clazz",
      replaceAll(spacePattern, " ",
      replaceAll(punctPattern, " ", sin.toLowerCase())))
    // stopword removal
    val sinStp = tokenizer.wordTokenize(sinClean)
      .filter(word => !stopwords.contains(word))
      .mkString(" ")
    sinStp
  }
  
  def replaceAll(pattern: Pattern, 
      replacement: String, 
      input: String): String = {
    pattern.matcher(input)
      .replaceAll(replacement)
  }
}

The ARFF files so created were vectorized and binarized from the Weka GUI using Weka's built-in filters - StringToWordVector for vectorizing the text and NumericToBinary for binarizing the resulting word vector. The resulting data was saved as a vectorized ARFF file, which was tested using 3-fold cross validation against various built-in Weka classifiers. A sketch of doing the same filtering in code is shown below.
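
For reference, this is a minimal sketch (mine; the file names are hypothetical) of running those two filters programmatically instead of through the GUI:

// Minimal sketch of the vectorize/binarize step done in the Weka GUI;
// file names are hypothetical
import java.io.{BufferedReader, FileReader, FileWriter, PrintWriter}
import weka.core.Instances
import weka.filters.Filter
import weka.filters.unsupervised.attribute.{NumericToBinary, StringToWordVector}

object VectorizeArff extends App {
  val reader = new BufferedReader(new FileReader("smoke_mc.arff"))
  val raw = new Instances(reader)
  reader.close()
  raw.setClassIndex(0)

  // keep the top 2000 word features, as in the experiments above
  val s2wv = new StringToWordVector()
  s2wv.setWordsToKeep(2000)
  s2wv.setInputFormat(raw)
  val vectorized = Filter.useFilter(raw, s2wv)

  // binarize the resulting word vector
  val n2b = new NumericToBinary()
  n2b.setInputFormat(vectorized)
  val binarized = Filter.useFilter(vectorized, n2b)

  val writer = new PrintWriter(new FileWriter("smoke_mc_vectorized.arff"))
  writer.println(binarized.toString())  // Instances.toString emits ARFF
  writer.close()
}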

The winning model(s) were finally evaluated against the training data (to verify that we don't have an underfitting model) and the evaluation data (to measure the "true" quality of the solution). Since we have to vectorize and binarize the evaluation data against the original features, our evaluation process first uses the training data to train the classifier, then uses the trained classifier to predict the target variable for the evaluation set. The code to do this is shown below - the evaluate() method does this for the first approach and evaluateStacked() for the second.

// Source: src/main/scala/com/mycompany/scalcium/smoke/SmokingClassifier.scala
package com.mycompany.scalcium.smoke

import java.io.BufferedReader
import java.io.File
import java.io.FileReader
import scala.Array.canBuildFrom
import weka.classifiers.trees.J48
import weka.core.Attribute
import weka.core.Instances
import weka.core.SparseInstance
import weka.classifiers.Classifier
import weka.classifiers.functions.SMO
import weka.classifiers.rules.PART
import weka.core.Instance

case class TrainedModel(classifier: Classifier, 
    dataset: Instances, 
    vocab: Map[String,Int])
    
class SmokingClassifier(arffIn: File, 
    smokingArffIn: File = null, 
    nonSmokingArffIn: File = null) {

  val preprocessor = new Preprocessor()
  
  val trainedModel = trainJ48(arffIn)
  val trainedSmokSubModel = trainSMO(smokingArffIn)
  val trainedNonSmokSubModel = trainPART(nonSmokingArffIn)

  def evaluate(testfile: File, 
      targets: Map[String,Int], 
      model: TrainedModel): Double = {
    var numTested = 0D
    var numCorrect = 0D
    val rootElem = scala.xml.XML.loadFile(testfile)
    (rootElem \ "RECORD").map(record => {
      val smoking = (record \ "SMOKING" \ "@STATUS").text
      val text = (record \ "TEXT").text
        .split("\n")
        .filter(line => line.endsWith("."))
        .map(line => preprocessor.preprocess(line))
        .mkString(" ")
      val ypred = predict(text, model)
      numTested += 1D
      if (ypred == targets(smoking)) 
        numCorrect += 1D
    })
    100.0D * numCorrect / numTested
  }
  
  def evaluateStacked(testfile: File,
      topModel: TrainedModel,
      nonSmokingSubModel: TrainedModel,
      smokingSubModel: TrainedModel): Double = {
    var numTested = 0D
    var numCorrect = 0D
    val rootElem = scala.xml.XML.loadFile(testfile)
    (rootElem \ "RECORD").map(record => {
      val smoking = (record \ "SMOKING" \ "@STATUS").text
      val text = (record \ "TEXT").text
        .split("\n")
        .filter(line => line.endsWith("."))
        .map(line => preprocessor.preprocess(line))
        .mkString(" ")
      val topPred = predict(text, topModel)
      if (topPred == 1) { // smoking
        val subPred = predict(text, smokingSubModel)
        if (preprocessor.smokerSubTargets.contains(smoking) &&
            subPred == preprocessor.smokerSubTargets(smoking))
          numCorrect += 1
      } else { // non-smoking
        val subPred = predict(text, nonSmokingSubModel)
        if (preprocessor.nonSmokerSubTargets.contains(smoking) &&
            subPred == preprocessor.nonSmokerSubTargets(smoking))
          numCorrect += 1
      }
      numTested += 1
    })    
    100.0D * numCorrect / numTested
  }

  /////////////////// predict ////////////////////
  
  def predict(input: String, 
      model: TrainedModel): Int = {
    val inst = buildInstance(input, model)
    val pdist = model.classifier.distributionForInstance(
      inst)
    pdist.zipWithIndex.maxBy(_._1)._2
  }
  
  def buildInstance(input: String, 
      model: TrainedModel): Instance = {
    val inst = new SparseInstance(model.vocab.size)
    input.split(" ")
      .foreach(word => {
        if (model.vocab.contains(word)) {
          inst.setValue(model.vocab(word), 1)
        }
    })
    inst.setDataset(model.dataset)
    inst
  }

  /////////////////// train models //////////////////
  
  def trainJ48(arff: File): TrainedModel = {
    trainModel(arff, new J48())  
  }
  
  def trainSMO(arff: File): TrainedModel = {
    if (arff == null) null 
    else {
      val smo = new SMO()
      smo.setC(0.1D)
      smo.setToleranceParameter(0.1D)
      trainModel(arff, smo)
    }
  }
  
  def trainPART(arff: File): TrainedModel = {
    if (arff == null) null
    else trainModel(arff, new PART())
  }
  
  def trainModel(arff: File, 
      classifier: Classifier): TrainedModel = {
    val reader = new BufferedReader(new FileReader(arff))
    val _instances = new Instances(reader)
    reader.close()
    _instances.setClassIndex(0)
    val _vocab = scala.collection.mutable.Map[String,Int]()
    val e = _instances.enumerateAttributes()
    while (e.hasMoreElements()) {
      val attrib = e.nextElement().asInstanceOf[Attribute]
      if (! "class".equals(attrib.name())) {
        // replace the _binarized suffix
        val stripname = attrib.name().replace("_binarized", "")
        _vocab += ((stripname, attrib.index()))
      }
    }
    classifier.buildClassifier(_instances)
    TrainedModel(classifier, _instances, _vocab.toMap)
  }
}

Finally, I used this JUnit test class to selectively call methods as I progressed through the experiment.

// Source: src/test/scala/com/mycompany/scalcium/smoke/SmokingClassifierTest.scala
package com.mycompany.scalcium.smoke

import org.junit.Test
import java.io.File
import org.apache.commons.io.FileUtils
import java.util.regex.Pattern

class SmokingClassifierTest {

  val datadir = new File("/path/to/my/data")
  val trainfile = new File(datadir, "smokers_surrogate_train_all_version2.xml")
  val testfile = new File(datadir, "smokers_surrogate_test_all_groundtruth_version2.xml")
  val preprocessor = new Preprocessor()
  
  @Test
  def testBuildMultiClassArff(): Unit = {
    val ab = new ArffBuilder()
    ab.buildMulticlassArff(trainfile,
      new File(datadir, "smoke_mc.arff"))
  }

  @Test
  def testBuildSmokerNonSmokerArff(): Unit = {
    val ab = new ArffBuilder()
    ab.buildSmokerNonSmokerArff(trainfile,
      new File(datadir, "smoke_sns.arff"))
  }
  
  @Test
  def testBuildSubclassifierArffs(): Unit = {
    val ab = new ArffBuilder()
    ab.buildSubClassifierArffs(trainfile, 
      new File(datadir, "smoke_subs.arff"), 
      new File(datadir, "smoke_subn.arff"))
  }

  @Test
  def testEvaluateMultiClass(): Unit = {
    val sc = new SmokingClassifier(
      new File(datadir, "smoke_mc_vectorized.arff"))
    val trainResult = sc.evaluate(trainfile, 
      preprocessor.multiClassTargets, 
      sc.trainedModel)
    Console.println("training accuracy=%f".format(trainResult))
    val testResult = sc.evaluate(testfile,
      preprocessor.multiClassTargets,
      sc.trainedModel)
    Console.println("test accuracy=%f".format(testResult))
  }
  
  @Test
  def testEvaluateSmokerNonSmoker(): Unit = {
    val sc = new SmokingClassifier(
      new File(datadir, "smoke_sns_vectorized.arff"))
    val trainResult = sc.evaluate(trainfile, 
      preprocessor.smokerNonSmokerTargets,
      sc.trainedModel)
    Console.println("training accuracy=%f".format(trainResult))
    val testResult = sc.evaluate(testfile, 
      preprocessor.smokerNonSmokerTargets,
      sc.trainedModel)
    Console.println("test accuracy=%f".format(testResult))
  }
  
  @Test
  def testEvaluateStackedClassifier(): Unit = {
    val sc = new SmokingClassifier(
      new File(datadir, "smoke_sns_vectorized.arff"),
      new File(datadir, "smoke_subs_vectorized.arff"),
      new File(datadir, "smoke_subn_vectorized.arff"))
    val trainResult = sc.evaluateStacked(trainfile, 
      sc.trainedModel, sc.trainedNonSmokSubModel, 
      sc.trainedSmokSubModel)
    Console.println("training accuracy=%f".format(trainResult))
    val testResult = sc.evaluateStacked(testfile, 
      sc.trainedModel, sc.trainedNonSmokSubModel, 
      sc.trainedSmokSubModel)
    Console.println("test accuracy=%f".format(testResult))
  }

}

Analysis and Results - Approach #1


I tested 5 different (Weka built-in) classifiers against different sizes of the word vector. By default, StringToWordVector creates a vector from the top 1000 word features; I varied this from 1000 to 5000 for each algorithm and plotted the results. The data is shown below:

#-words   NaiveBayes   SMO     J48     RandomForest   PART
1000      57.29        67.59   70.85   61.81          65.08
2000      59.55        67.84   79.9    62.81          76.63
3000      59.55        66.33   79.15   63.56          76.38
4000      61.81        66.58   78.64   60.3           75.62
5000      61.31        67.34   76.13   62.81          73.37


As can be seen, accuracy peaks at a vector size of 2000 words for most of the classifiers, and the best results come from the J48 (and to a lesser extent, the PART) classifier, so we continue with these going forward.

My next improvement was to remove stopwords. Accuracy for J48 went up to 79.40% for an input vector size of 2000 words. I then tried parsing the sentences with OpenNLP to extract noun phrases and use them as features instead of words - the sweet spot moves up to 3000 words, but the accuracy is 61.56%, lower than in the previous experiment, so the extra effort of noun-phrase chunking doesn't seem to be worth it. (A sketch of the noun-phrase extraction follows.)
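
For the curious, here is a minimal, hypothetical sketch of noun-phrase extraction with OpenNLP's POS tagger and chunker; the model file paths are assumptions, and this is not necessarily how my Tokenizer wrapper does it:

// Hypothetical sketch of noun-phrase extraction with OpenNLP;
// model file paths are assumptions
import java.io.FileInputStream
import scala.collection.mutable.ArrayBuffer
import opennlp.tools.postag.{POSModel, POSTaggerME}
import opennlp.tools.chunker.{ChunkerModel, ChunkerME}

object NounPhraseChunker {
  val tagger = new POSTaggerME(new POSModel(
    new FileInputStream("en-pos-maxent.bin")))
  val chunker = new ChunkerME(new ChunkerModel(
    new FileInputStream("en-chunker.bin")))

  // groups B-NP/I-NP chunk tags back into noun phrase strings
  def nounPhrases(tokens: Array[String]): List[String] = {
    val tags = chunker.chunk(tokens, tagger.tag(tokens))
    val phrases = ArrayBuffer[List[String]]()
    var curr = List[String]()
    tags.zip(tokens).foreach { case (tag, token) =>
      if (tag == "B-NP") {
        if (curr.nonEmpty) phrases += curr
        curr = List(token)
      } else if (tag == "I-NP") curr = curr :+ token
      else {
        if (curr.nonEmpty) phrases += curr
        curr = List()
      }
    }
    if (curr.nonEmpty) phrases += curr
    phrases.map(_.mkString(" ")).toList
  }
}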

The "best" classifier so far was the J48 classifier with 2000 words and stopwords removed. The confusion matrix for the classifier is shown below. Here non-smokers, past-smokers and unknowns are being classified the best, current-smoker is being misclassified quite frequently as non-smoker and past-smoker.

Confusion Matrix
-------------------

   a   b   c   d   e   <-- classified as
  16   5  10   1   3 |   a = 0 (CURRENT_SMOKER)
  10  47   4   0   5 |   b = 1 (NON-SMOKER)
  10   5  19   1   1 |   c = 2 (PAST SMOKER)
   3   1   1   0   4 |   d = 3 (SMOKER)
   8   7   1   0 236 |   e = 4 (UNKNOWN)

The pruned tree for the J48 classifier looks like this. I edited out the _binarized suffixes from the tree labels to make it more readable. What is interesting is that the classifier keys off words such as quit, pack, smoke, tobacco, smoking and smoker to figure out the target variable, much the same way a human reader would.

J48 pruned tree
------------------

quit = 0
|   pack = 0
|   |   smoke = 0
|   |   |   tobacco = 0
|   |   |   |   smoking = 0
|   |   |   |   |   smoker = 0
|   |   |   |   |   |   drinks = 0
|   |   |   |   |   |   |   married = 0: 5 (253.0/8.0)
|   |   |   |   |   |   |   married = 1
|   |   |   |   |   |   |   |   bid = 0: 5 (5.0/1.0)
|   |   |   |   |   |   |   |   bid = 1: 2 (3.0)
|   |   |   |   |   |   drinks = 1: 1 (2.0/1.0)
|   |   |   |   |   smoker = 1
|   |   |   |   |   |   100 = 0
|   |   |   |   |   |   |   0 = 0: 1 (3.0)
|   |   |   |   |   |   |   0 = 1: 3 (3.0)
|   |   |   |   |   |   100 = 1: 2 (3.0)
|   |   |   |   smoking = 1
|   |   |   |   |   68 = 0
|   |   |   |   |   |   aspirin = 0
|   |   |   |   |   |   |   feeling = 0
|   |   |   |   |   |   |   |   4 = 0
|   |   |   |   |   |   |   |   |   i = 0: 1 (7.0)
|   |   |   |   |   |   |   |   |   i = 1: 2 (2.0)
|   |   |   |   |   |   |   |   4 = 1: 2 (7.0)
|   |   |   |   |   |   |   feeling = 1: 3 (2.0)
|   |   |   |   |   |   aspirin = 1: 3 (3.0)
|   |   |   |   |   68 = 1: 4 (2.0)
|   |   |   tobacco = 1
|   |   |   |   dilatation = 0
|   |   |   |   |   possibility = 0
|   |   |   |   |   |   former = 0
|   |   |   |   |   |   |   passed = 0: 2 (29.0/1.0)
|   |   |   |   |   |   |   passed = 1: 3 (2.0/1.0)
|   |   |   |   |   |   former = 1: 3 (2.0)
|   |   |   |   |   possibility = 1: 3 (2.0/1.0)
|   |   |   |   dilatation = 1: 1 (2.0)
|   |   smoke = 1: 2 (21.0/1.0)
|   pack = 1
|   |   central = 0
|   |   |   10 = 0: 3 (4.0/1.0)
|   |   |   10 = 1: 1 (16.0/1.0)
|   |   central = 1: 5 (2.0)
quit = 1
|   added = 0
|   |   day = 0
|   |   |   0 = 0: 1 (2.0)
|   |   |   0 = 1: 3 (2.0)
|   |   day = 1: 3 (17.0)
|   added = 1: 1 (2.0)

As mentioned earlier, this classifier had an accuracy of 91.21% against the training set and 71.15% against the evaluation set.

Analysis and Results - Approach #2


Something I noticed about the annotations is that they are somewhat ad-hoc, i.e., SMOKER could mean CURRENT or PAST SMOKER, but is a category by itself. UNKNOWN is a bit tricky: we could assume either that the smoking status was not considered because the patient was obviously a non-smoker, or that the clinician neglected to record it. So we could think of the annotations as being structured - the top split would be SMOKER vs NON-SMOKER; within SMOKER would be CURRENT SMOKER, PAST SMOKER and SMOKER, and within NON-SMOKER would be UNKNOWN and NON-SMOKER.

So I decided to use a classifier to first determine SMOKER vs NON-SMOKER, then based on the results of that send it to a second classifier to determine the "type" of SMOKER or NON-SMOKER.

Once again, I rebuilt the ARFF file for the top-level classifier, then used the Weka GUI to vectorize and binarize the text, and tried the J48 and PART classifiers against different word vector sizes. The data is shown below. The corresponding accuracies were 96.23% against the training set and 84.62% against the evaluation set - so we can be pretty accurate in predicting smoker vs non-smoker from the discharge notes.

#-words   J48     PART
1000      84.67   79.65
2000      87.69   83.64
3000      84.92   84.17
4000      85.16   84.17
5000      86.68   85.43


Once more, the winner appears to be J48 with a 2000-word vector. We then build two sub-ARFF files, one for the SMOKER subclasses and one for the NON-SMOKER subclasses. For each ARFF file, I tested various classifiers with a word vector size of 2000.

For the smoking category, the best performer was SMO with c=0.1 and toleranceParameter=0.1. 3-fold cross validation gave 57.5% accuracy (a slight bump from 55% with the default parameters, although this did not affect the accuracy of the final stacked classifier).

For the non-smoking category, the best performer was PART with an accuracy of 87.73% with 3-fold cross validation.
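
These runs were all done through the Weka GUI; for reference, here is a minimal sketch (mine; the file name is hypothetical) of reproducing one of them in code:

// Minimal sketch of a 3-fold cross validation run outside the Weka GUI;
// the ARFF file name is hypothetical
import java.io.{BufferedReader, FileReader}
import java.util.Random
import weka.classifiers.Evaluation
import weka.classifiers.functions.SMO
import weka.core.Instances

object CrossValidate extends App {
  val reader = new BufferedReader(new FileReader("smoke_subs_vectorized.arff"))
  val data = new Instances(reader)
  reader.close()
  data.setClassIndex(0)

  val smo = new SMO()
  smo.setC(0.1D)                  // parameters found above
  smo.setToleranceParameter(0.1D)

  val eval = new Evaluation(data)
  eval.crossValidateModel(smo, data, 3, new Random(42))
  Console.println("accuracy=%f".format(eval.pctCorrect()))
}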

The overall accuracy was 95.48% against the training set and 73.08% against the evaluation set. Overall accuracy was calculated by sending each record through the classifier stack and comparing the final prediction with the original annotation.

Acknowledgements


  • Datasets for this experiment were provided by i2b2, for which I am very thankful. i2b2 also requires users of the datasets to acknowledge them with the following boilerplate:
    "Deidentified clinical records used in this research were provided by the i2b2 National Center for Biomedical Computing funded by U54LM008748 and were originally prepared for the Shared Tasks for Challenges in NLP for Clinical Data organized by Dr. Ozlem Uzuner, i2b2 and SUNY."
  • For generating all the models, I used the Weka data mining toolkit, without which I would not have been able to try out and compare so many models in the time I took. Weka is required to be cited as follows:
    Mark Hall, Eibe Frank, Geoffrey Holmes, Bernhard Pfahringer, Peter Reutemann, Ian H. Witten (2009); The WEKA Data Mining Software: An Update; SIGKDD Explorations, Volume 11, Issue 1.

Additionally, as a result of the NDA with i2b2, I am forbidden from sharing the contents of the datasets I used (even as snippets for illustrative purposes). If you want the data, you will need to register and sign the NDA yourself.

Also, as a condition of using Weka (since it is GPL), my model and source code must be made freely available for other people to use - the ARFF files (vectorized and binarized with the top 2000 features) corresponding to the original multiclass classifier and the three classifiers comprising the stacked classifier can be found on GitHub here. All the source code used in this experiment is available on this page.

In addition, here are some links to documents that I found invaluable while building this solution.