Friday, February 21, 2014

Fuzzy String Matching with SolrTextTagger


As you know, I've been experimenting with approaches for associating entities from a dictionary with text. David Smiley (Solr/Lucene committer best known for his work with Solr Spatial) pointed me to SolrTextTagger, part of his OpenSextant project, an application that very efficiently recognizes place names in large blocks of plain text by using Lucene as a Finite State Transducer (FST). Apart from the OpenSextant project, SolrTextTagger is also used in Apache Stanbol, which provides a set of reusable components for semantic content management. This post describes a variant of the Fuzzy String Matching algorithm I described in my previous post, but using SolrTextTagger and Solr instead of Lucene.

The idea is similar to how I used Redis and Aho-Corasick, but likely far more efficient in terms of both speed and memory usage. What it meant for me (relative to my previous algorithm) was that I could eliminate building and querying phrase n-grams in client code for partial searches, and replace that with a single pass of the query phrase against the entire dictionary of 6.8 million terms. In addition, in my current algorithm (described below), I replaced the cascading search across the different variants of the phrase with a single search, further reducing the number of round trips to Lucene per query.

The use of Solr seemed a bit heavyweight at first, but one advantage of Solr is that the lookup resource can be shared across the network. Further, the SolrTextTagger service supports an fq parameter, so you could potentially serve many dictionaries (from a single FST) out of the same Solr instance (although I didn't use that here).
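As a rough sketch of that idea (I didn't do this here): if each dictionary's records carried an extra discriminator field (the source field below is hypothetical, it is not part of the schema shown later), a tag request could be restricted to a single dictionary by adding an fq parameter, something like this:

import org.apache.solr.client.solrj.SolrRequest
import org.apache.solr.client.solrj.request.ContentStreamUpdateRequest
import org.apache.solr.common.params.CommonParams
import org.apache.solr.common.params.ModifiableSolrParams
import org.apache.solr.common.util.ContentStreamBase

// build a /tag request filtered to one dictionary via a hypothetical
// "source" field; everything else is the same as the tag() method shown later
def tagRequest(phrase: String, dictionary: String): ContentStreamUpdateRequest = {
  val params = new ModifiableSolrParams()
  params.add("overlaps", "LONGEST_DOMINANT_RIGHT")
  params.add(CommonParams.FQ, "source:" + dictionary)
  val req = new ContentStreamUpdateRequest("")
  req.addContentStream(new ContentStreamBase.StringStream(phrase))
  req.setMethod(SolrRequest.METHOD.POST)
  req.setPath("/tag")
  req.setParams(params)
  req
}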

The current algorithm shares many ideas from the previous one. Incoming text is partitioned into (noun) phrases, and each phrase is sent into the algorithm. It first tries to match the exact phrase, and some standard transformations on the exact phrase (punctuation and case normalized, alpha sorted and stemmed), against pre-transformed String (non-tokenized) fields in the index. If a match happens, then it is reported and the algorithm exits. If no match happens, the phrase is sent to the SolrTextTagger (tag) service, which matches the punctuation and case normalized phrase against the pre-built Lucene FST, and gets back phrase substrings that match entities in the dictionary. Results are deduped by the CUI field and the first (highest ranking) record is retained.

Scoring is manual, and deliberately so. I tried using raw Lucene scores, but found them hard to interpret consistently. Here is how the scoring works. In the case of the full phrase, a match implies a match against either the original or one of the three transformed versions. We compute a measure of overlap based on Levenshtein distance for each candidate and pick the closest one. Depending on which version matched best, we discount the overlap by that level's score (100, 75, 50 or 25). We do the same thing for the partial matches, except that we calculate the overlap against the matched substring and its transforms, and discount it further by the ratio of the number of words in the substring to the number of words in the original phrase.
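To make the arithmetic concrete, here is a minimal standalone sketch of the same calculation as the computeScore() method shown later (the phrase and description are illustrative, not taken from the actual run). The original form of the query has the highest overlap with the dictionary description, so its level (100) is used; had the best overlap come from the normalized form instead, the level would have been 75.

import org.apache.commons.lang3.StringUtils

// query "Lung cancer" matched against the dictionary description "Lung Cancer"
val descr = "Lung Cancer"
val candidates = List(
  "Lung cancer",   // original query phrase        -> level 100
  "lung cancer",   // case/punct normalized        -> level  75
  "cancer lung",   // alpha sorted                 -> level  50
  "lung cancer")   // stemmed (no stop words here) -> level  25
val levels = List(100.0f, 75.0f, 50.0f, 25.0f)

// overlap = 1 - (edit distance / length of the longer string)
val (bestLevel, bestOverlap) = candidates.zip(levels).map { case (cand, level) =>
  val maxlen = math.max(cand.length, descr.length).toFloat
  val dist = StringUtils.getLevenshteinDistance(cand, descr).toFloat
  (level, 1.0f - dist / maxlen)
}.maxBy(_._2)

val fullPhraseScore = bestLevel * bestOverlap  // ~90.9: level 100, overlap 10/11
// a partial match is discounted further by word coverage, eg a 2-word
// substring matched inside a 4-word phrase:
val partialScore = fullPhraseScore * (2.0f / 4.0f)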

Here are the steps to replicate my setup.

Download and Build SolrTextTagger


The code for SolrTextTagger resides on GitHub, so to download and build the custom Solr JAR, execute the following sequence of commands. This will create a solr-text-tagger-1.3-SNAPSHOT.jar file in the target subdirectory of the SolrTextTagger project.

sujit@tsunami:~/Downloads$ git clone \
    https://github.com/OpenSextant/SolrTextTagger.git
sujit@tsunami:~/Downloads$ cd SolrTextTagger
sujit@tsunami:~/Downloads/SolrTextTagger$ mvn test
sujit@tsunami:~/Downloads/SolrTextTagger$ mvn package

Download and Customize Solr


Solr is available for download from the Apache Solr website. After downloading, expand it locally, then update the schema.xml and solrconfig.xml in the conf subdirectory as shown below:

sujit@tsunami:~/Downloads$ tar xvzf solr-4.6.1.tgz
sujit@tsunami:~/Downloads$ cd solr-4.6.1/example/solr/collection1/conf

Update the schema.xml to replace the field definitions with our own. Our fields list and the definition of the field type "tag" (copied from the SolrTextTagger documentation) are shown below. The "id" field is just an integer sequence (the unique key for Solr), the "cui" and "descr" fields come from the CUI and STR columns of the UMLS database, and descr_norm, descr_sorted and descr_stemmed are the case/punctuation normalized, alpha sorted and stemmed versions of STR, respectively. The descr_tagged field is identical to descr_norm but is analyzed differently, as specified below.

<fields>
    <field name="id" type="string" indexed="true" stored="true" 
      required="true"/>
    <field name="cui" type="string" indexed="true" stored="true"/>
    <field name="descr" type="string" indexed="true" stored="true"/>
    <field name="descr_norm" type="string" indexed="true" stored="true"/>
    <field name="descr_sorted" type="string" indexed="true" stored="true"/>
    <field name="descr_stemmed" type="string" indexed="true" stored="true"/>
    <field name="descr_tagged" type="tag" indexed="true" stored="false" 
         omitTermFreqAndPositions="true" omitNorms="true"/>
    <copyField source="descr_norm" dest="descr_tagged"/>
    <dynamicField name="*" type="string" indexed="true" stored="true"/>
  </fields>
  <uniqueKey>id</uniqueKey>
  <types>
    <fieldType name="tag" class="solr.TextField" positionIncrementGap="100">
      <analyzer>
        <tokenizer class="solr.StandardTokenizerFactory"/>
        <filter class="solr.EnglishPossessiveFilterFactory"/>
        <filter class="solr.ASCIIFoldingFilterFactory"/>
        <filter class="solr.LowerCaseFilterFactory"/>
      </analyzer>
    </fieldType>
    ...
  </types>

We then add in the requestHandler definition for SolrTextTagger's tag service into the solrconfig.xml file (also in conf). The definition is shown below:

<requestHandler name="/tag" 
      class="org.opensextant.solrtexttagger.TaggerRequestHandler">
    <str name="indexedField">descr_tagged</str>
    <str name="storedField">descr_norm</str>
    <bool name="partialMatches">false</bool>
    <int name="valueMaxLen">5000</int>
    <str name="cacheFile">taggerCache.dat</str>
  </requestHandler>

Finally, we create a lib directory and copy the solr-text-tagger-1.3-SNAPSHOT.jar into it. Then we go up to the example directory and start Solr, which will now be listening on port 8983 on localhost.

sujit@tsunami:~/Downloads/solr-4.6.1/example/solr/collection1$ mkdir lib
sujit@tsunami:~/Downloads/solr-4.6.1/example/solr/collection1$ cp \
    ~/Downloads/SolrTextTagger/target/*jar lib/
sujit@tsunami:~/Downloads/solr-4.6.1/example/solr/collection1$ cd ../..
sujit@tsunami:~/Downloads/solr-4.6.1/example$ java -jar start.jar

Load Data and Build FST


We use the same cuistr1.csv file that we extracted from our MySQL UMLS database earlier. I could have written custom code to load the data directly into the index, but since I had started experimenting with SolrTextTagger using curl, I just wrote some code to convert the (CUI,STR) CSV format into JSON, with additional fields for the case/punctuation normalized, alpha sorted and stemmed versions. I used Scala for this since I already had the transformations coded up from last week. Once I had generated the JSON file (cuistr1.json), I uploaded it into Solr and built the FST.
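For reference, each line of cuistr1.json holds one record like the following (the id is just a running sequence number; this particular CUI/description pair shows up in the results later in the post):

{"id":0,"cui":"C0027051","descr":"Heart Attack","descr_norm":"heart attack","descr_sorted":"attack heart","descr_stemmed":"heart attack"}

The upload and the FST build are done with the following two curl commands.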

sujit@tsunami:~/Downloads$ curl \
    "http://localhost:8983/solr/update/json?commit=true" \
    --data-binary @cuistr1.json -H 'Content-type:application/json'
sujit@tsunami:~/Downloads$ curl "http://localhost:8983/solr/tag?build=true"

The data is now ready to use. The code for the algorithm is shown below: the buildIndexJson() method was used to create the cuistr1.json file, and the annotateConcepts() method matches a phrase against the dictionary.

// Source: src/main/scala/com/mycompany/scalcium/umls/UmlsTagger2.scala
package com.mycompany.scalcium.umls

import java.io.File
import java.io.FileWriter
import java.io.PrintWriter
import java.io.StringReader
import java.util.regex.Pattern

import scala.collection.JavaConversions.asScalaIterator
import scala.collection.mutable.ArrayBuffer
import scala.io.Source

import org.apache.commons.lang3.StringUtils
import org.apache.lucene.analysis.Analyzer
import org.apache.lucene.analysis.standard.StandardAnalyzer
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute
import org.apache.lucene.util.Version
import org.apache.solr.client.solrj.SolrRequest
import org.apache.solr.client.solrj.impl.HttpSolrServer
import org.apache.solr.client.solrj.request.ContentStreamUpdateRequest
import org.apache.solr.common.SolrDocumentList
import org.apache.solr.common.params.CommonParams
import org.apache.solr.common.params.ModifiableSolrParams
import org.apache.solr.common.util.ContentStreamBase

class UmlsTagger2(val solrServerUrl: String) {

  val punctPattern = Pattern.compile("\\p{Punct}")
  val spacePattern = Pattern.compile("\\s+")
  
  case class Suggestion(val score: Float, 
    val descr: String, val cui: String)
      
  val solrServer = new HttpSolrServer(solrServerUrl)
  
  def buildIndexJson(inputFile: File, 
      outputFile: File): Unit = {
    val writer = new PrintWriter(new FileWriter(outputFile))
    writer.println("[")
    var i = 0
    Source.fromFile(inputFile)
      .getLines()
      .foreach(line => {
        val Array(cui, str) = line
          .replace("\",\"", "\t")
          .replaceAll("\"", "")
          .replaceAll("\\\\", "")
          .split("\t")
        val strNorm = normalizeCasePunct(str)
        val strSorted = sortWords(strNorm)
        val strStemmed = stemWords(strNorm)
        val obuf = new StringBuilder()
        if (i > 0) obuf.append(",")
        obuf.append("{")
          .append("\"id\":").append(i).append(",")
          .append("\"cui\":\"").append(cui).append("\",")
          .append("\"descr\":\"").append(str).append("\",")
          .append("\"descr_norm\":\"").append(strNorm).append("\",")
          .append("\"descr_sorted\":\"").append(strSorted).append("\",")
          .append("\"descr_stemmed\":\"").append(strStemmed).append("\"")
          .append("}")
        writer.println(obuf.toString)
        i += 1
      })
    writer.println("]")
    writer.flush()
    writer.close()
  }

  def annotateConcepts(phrase: String): 
      List[Suggestion] = {
    // check for full match
    val suggestions = ArrayBuffer[Suggestion]()
    select(phrase) match {
      case Some(suggestion) => suggestions += suggestion
      case None => tag(phrase) match {
        case Some(subSuggs) => suggestions ++= subSuggs
        case None => {}
      }
    }
    suggestions.toList
  }

  ///////////// phrase munging methods //////////////
  
  def normalizeCasePunct(str: String): String = {
    val str_lps = punctPattern
      .matcher(str.toLowerCase())
      .replaceAll(" ")
    spacePattern.matcher(str_lps).replaceAll(" ")
  }

  def sortWords(str: String): String = {
    val words = str.split(" ")
    words.sortWith(_ < _).mkString(" ")
  }
  
  def stemWords(str: String): String = {
    val stemmedWords = ArrayBuffer[String]()
    val tokenStream = getAnalyzer().tokenStream(
      "str_stemmed", new StringReader(str))
    val ctattr = tokenStream.addAttribute(
      classOf[CharTermAttribute])    
    tokenStream.reset()
    while (tokenStream.incrementToken()) {
      stemmedWords += ctattr.toString()
    }
    stemmedWords.mkString(" ")
  }
  
  def getAnalyzer(): Analyzer = {
    new StandardAnalyzer(Version.LUCENE_46)
  }
  
  ///////////////// solr search methods //////////////
  
  def select(phrase: String): Option[Suggestion] = {
    val phraseNorm = normalizeCasePunct(phrase)
    val phraseSorted = sortWords(phraseNorm)
    val phraseStemmed = stemWords(phraseNorm)
    // construct query
    val query = """descr:"%s" descr_norm:"%s" descr_sorted:"%s" descr_stemmed:"%s""""
      .format(phrase, phraseNorm, phraseSorted, phraseStemmed)
    val params = new ModifiableSolrParams()
    params.add(CommonParams.Q, query)
    params.add(CommonParams.ROWS, String.valueOf(1))
    params.add(CommonParams.FL, "*,score")
    val rsp = solrServer.query(params)
    val results = rsp.getResults()
    if (results.getNumFound() > 0L) {
      val sdoc = results.get(0)
      val descr = sdoc.getFieldValue("descr").asInstanceOf[String]
      val cui = sdoc.getFieldValue("cui").asInstanceOf[String]
      val score = computeScore(descr, 
        List(phrase, phraseNorm, phraseSorted, phraseStemmed))
      Some(Suggestion(score, descr, cui))
    } else None
  }

  def tag(phrase: String): Option[List[Suggestion]] = {
    val phraseNorm = normalizeCasePunct(phrase)
    val params = new ModifiableSolrParams()
    params.add("overlaps", "LONGEST_DOMINANT_RIGHT")
    val req = new ContentStreamUpdateRequest("")
    req.addContentStream(new ContentStreamBase.StringStream(phrase))
    req.setMethod(SolrRequest.METHOD.POST)
    req.setPath("/tag")
    req.setParams(params)
    val rsp = req.process(solrServer)
    val results = rsp.getResponse()
      .get("matchingDocs")
      .asInstanceOf[SolrDocumentList]
    val nwordsInPhrase = phraseNorm.split(" ").length.toFloat
    val suggestions = results.iterator().map(sdoc => {
        val descr = sdoc.getFieldValue("descr").asInstanceOf[String]
        val cui = sdoc.getFieldValue("cui").asInstanceOf[String]
        val nWordsInDescr = descr.split(" ").length.toFloat
        val descrNorm = normalizeCasePunct(descr)
        val descrSorted = sortWords(descrNorm)
        val descrStemmed = stemWords(descrNorm)
        val nwords = descrNorm.split(" ").length.toFloat
        val score = (nwords / nwordsInPhrase) * 
          computeScore(descr, 
          List(descr, descrNorm, descrSorted, descrStemmed))
        Suggestion(score, descr, cui)      
      })
      .toList
      .groupBy(_.cui) // dedup by cui
      .map(_._2.toList.head)
      .toList
      .sortWith((a,b) => a.score > b.score) // sort by score
    Some(suggestions)
  }

  def computeScore(s: String, 
      candidates: List[String]): Float = {
    val levels = List(100.0F, 75.0F, 50.0F, 25.0F)
    val candLevels = candidates.zip(levels).toMap
    val topscore = candidates.map(candidate => {
        val maxlen = Math.max(candidate.length(), s.length()).toFloat
        val dist = StringUtils.getLevenshteinDistance(candidate, s).toFloat
        (candidate, 1.0F - (dist / maxlen))
      })    
      .sortWith((a, b) => a._2 > b._2)
      .head
    val level = candLevels.getOrElse(topscore._1, 0.0F)
    level * topscore._2
  }

  //////////////// misc methods ////////////////
  
  def formatSuggestion(sugg: Suggestion): String = {
    "[%6.2f%%] (%s) %s"
      .format(sugg.score, sugg.cui, sugg.descr)
  }
}

To run the code, we use the following JUnit test (commenting and uncommenting tests as needed).

// Source: src/test/scala/com/mycompany/scalcium/umls/UmlsTagger2Test.scala 
package com.mycompany.scalcium.umls

import java.io.File

import org.junit.Assert
import org.junit.Test

class UmlsTagger2Test {

  @Test
  def testBuildIndex(): Unit = {
    val tagger = new UmlsTagger2("")
    tagger.buildIndexJson(
      new File("/home/sujit/Projects/med_data/cuistr1.csv"), 
      new File("/home/sujit/Projects/med_data/cuistr1.json"))
  }
  
  @Test
  def testGetFull(): Unit = {
    val tagger = new UmlsTagger2("http://localhost:8983/solr")
    val phrases = List("Lung Cancer", "Heart Attack", "Diabetes")
    phrases.foreach(phrase => {
      Console.println()
      Console.println("Query: %s".format(phrase))
      val suggestions = tagger.select(phrase)
      suggestions match {
        case Some(suggestion) => {
          Console.println(tagger.formatSuggestion(suggestion))
          Assert.assertNotNull(suggestion.cui)
        }
        case None =>
          Assert.fail("No results for [%s]".format(phrase))
      }
    })
  }
  
  @Test
  def testGetPartial(): Unit = {
    val tagger = new UmlsTagger2("http://localhost:8983/solr")
    val phrases = List(
        "Heart Attack and diabetes",
        "carcinoma (small-cell) of lung",
        "asthma side effects")
    phrases.foreach(phrase => {
      Console.println()
      Console.println("Query: %s".format(phrase))
      val suggestions = tagger.tag(phrase)
      suggestions match {
        case Some(psuggs) => {
          psuggs.foreach(psugg => {
            Console.println(tagger.formatSuggestion(psugg))    
          })
          Assert.assertNotNull(psuggs)
        }
        case None =>
          Assert.fail("No results for [%s]".format(phrase))
      }
    })
  }
  
  @Test
  def testAnnotateConcepts(): Unit = {
    val tagger = new UmlsTagger2("http://localhost:8983/solr")
    val phrases = List("Lung Cancer", 
        "Heart Attack", 
        "Diabetes",
        "Heart Attack and diabetes",
        "carcinoma (small-cell) of lung",
        "asthma side effects"
    )
    phrases.foreach(phrase => {
      Console.println()
      Console.println("Query: %s".format(phrase))
      val suggestions = tagger.annotateConcepts(phrase)
      suggestions.foreach(suggestion => {
        Console.println(tagger.formatSuggestion(suggestion))
      })
    })
  }
}

The results of testAnnotateConcepts() are shown below. I've used the same query terms as with the previous algorithm, and the results are quite similar.

Query: Lung Cancer
[100.00%] (C0242379) Lung Cancer

Query: Heart Attack
[100.00%] (C0027051) Heart Attack

Query: Diabetes
[100.00%] (C0011847) Diabetes

Query: Heart Attack and diabetes
[ 25.00%] (C0002838) AND
[ 25.00%] (C1515981) And
[ 25.00%] (C0011847) Diabetes
[ 25.00%] (C1706368) And
[ 12.50%] (C1550557) and
[ 12.50%] (C0027051) heart attack
[  6.25%] (C0011849) diabetes
[  6.25%] (C0011860) diabetes

Query: carcinoma (small-cell) of lung
[ 26.79%] (C0149925) small cell carcinoma of lung

Query: asthma side effects
[ 33.33%] (C2984299) Asthma
[ 16.67%] (C0001688) side effects
[  8.33%] (C0004096) asthma

That's all I have for today. Hopefully next week I will talk about something else :-).

Sunday, February 16, 2014

Fuzzy String Matching against UMLS Data


I started looking at the Unified Medical Language System (UMLS) Database because it is an important data source for the Apache cTakes pipeline. UMLS integrates multiple medical terminologies into a single dictionary, which can be used for recognizing medical entities in text. Our proprietary concept mapping algorithm uses a superset of UMLS (combined with other terminologies, and manually curated by our team of in-house medical folks) to extract meaning from medical text. Both cTakes and Hitex (another NLP application focused on medical text) use Metamap (via its Java API), a tool provided by the National Library of Medicine (NLM) to recognize UMLS concepts in text.

I was curious about how Metamap did it; it turns out it uses a pretty complex algorithm, described in this paper by Alan Aronson (the creator of Metamap), and is written in Prolog and C. In parallel, I also happened to come across the Python fuzzywuzzy library for fuzzy string matching, and the Java OpenReconcile project (used in the OpenRefine project) for matching entities in text. This gave me an idea for an algorithm that could be used to annotate text using the UMLS database as a dictionary, which I describe here. It is very different from what we use at work; the advantage of the latter is that it has been vetted and tweaked extensively over the years to yield good results across our entire taxonomy.

My matching algorithm matches the incoming phrase against the original strings from UMLS, then against the case and punctuation normalized version (i.e., lowercased with punctuation replaced by spaces), then against a version with all the words in the normalized string alpha sorted, and finally against a version with all the words in the normalized string stemmed using Lucene's StandardAnalyzer. Each level in the sequence above is assigned a score from the sequence [100, 90, 80, 70]. If a match is found at any level, the lower levels are not executed. If no match is found, we construct n-grams of the phrase, where n varies from N-1 down to 1. Like the incoming phrase, each n-gram is transformed and matched against the dictionary, and a match short-circuits the pipeline for that n-gram. Words in n-grams that matched a dictionary entry are ignored in subsequent matches. Scores for n-grams are discounted by the ratio of the number of words in the n-gram to the number of words in the entire phrase, as illustrated in the sketch below.
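For example, here is a minimal sketch of the n-gram generation for one of the test phrases used later; the actual code below additionally skips n-grams whose words have all been matched already.

val words = "heart attack and diabetes".split(" ")
// n varies from N-1 down to 1
val ngrams = (words.size - 1 until 0 by -1).flatMap(n =>
  words.sliding(n).map(_.mkString(" ")))
// -> heart attack and, attack and diabetes,
//    heart attack, attack and, and diabetes,
//    heart, attack, and, diabetes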

I store the dictionary as a Lucene index, although I am not using any of Lucene's relevance scoring capabilities (all lookups are exact TermQuery matches), and I could just as easily have stored the dictionary in a relational database or key-value store. I do use Lucene's StandardAnalyzer to stem the dictionary entries during indexing and the incoming phrases during searching.
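To illustrate the transformations, here is what one of the inexact test phrases from later in the post looks like at each level (a quick sketch, assuming the StandardAnalyzer's default English stop word set):

val original    = "carcinoma (small-cell) of lung"
val normalized  = "carcinoma small cell of lung"  // lowercased, punctuation replaced by spaces
val alphaSorted = "carcinoma cell lung of small"
val stemmed     = "carcinoma small cell lung"     // "of" dropped as a stop word
// the UMLS string "small cell carcinoma of lung" alpha-sorts to the same value,
// which is why this query matches at the sorted level in the results below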

To get the data out of the UMLS database into a flat file, I run the following SQL on the MRCONSO table.

mysql> select CUI, STR from MRCONSO 
...    where LAT = 'ENG' 
...    into outfile '/tmp/cuistr.csv' 
...    fields terminated by ',' 
...    enclosed by '"' lines 
...    terminated by '\n';
Query OK, 7871075 rows affected (23.30 sec)

I then removed duplicates (there are some records with the same CUI and STR values from different sources) using this Unix call. This resulted in 6,818,318 rows of input.

sujit@tsunami:~$ cat cuistr.csv | sort | uniq > cuistr1.csv

Here is the code. The buildIndex() method creates the index and the annotateConcepts() method matches the incoming phrase (as described above) to the entries in the index.

// Source: src/main/scala/com/mycompany/scalcium/umls/UmlsTagger.scala
package com.mycompany.scalcium.umls

import java.io.File
import java.io.StringReader
import java.util.concurrent.atomic.AtomicLong
import java.util.regex.Pattern

import scala.Array.canBuildFrom
import scala.Array.fallbackCanBuildFrom
import scala.collection.mutable.ArrayBuffer
import scala.io.Source

import org.apache.lucene.analysis.Analyzer
import org.apache.lucene.analysis.standard.StandardAnalyzer
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute
import org.apache.lucene.document.Document
import org.apache.lucene.document.Field
import org.apache.lucene.document.Field.Index
import org.apache.lucene.document.Field.Store
import org.apache.lucene.index.DirectoryReader
import org.apache.lucene.index.IndexReader
import org.apache.lucene.index.IndexWriter
import org.apache.lucene.index.IndexWriterConfig
import org.apache.lucene.index.Term
import org.apache.lucene.search.BooleanClause.Occur
import org.apache.lucene.search.BooleanQuery
import org.apache.lucene.search.IndexSearcher
import org.apache.lucene.search.ScoreDoc
import org.apache.lucene.search.TermQuery
import org.apache.lucene.store.SimpleFSDirectory
import org.apache.lucene.util.Version

class UmlsTagger {

  val punctPattern = Pattern.compile("\\p{Punct}")
  val spacePattern = Pattern.compile("\\s+")
  
  def buildIndex(inputFile: File, 
      luceneDir: File): Unit = {
    // set up the index writer
    val analyzer = getAnalyzer()
    val iwconf = new IndexWriterConfig(Version.LUCENE_46, analyzer)
    iwconf.setOpenMode(IndexWriterConfig.OpenMode.CREATE)
    val writer = new IndexWriter(new SimpleFSDirectory(luceneDir), iwconf)
    // read through input file and write out to lucene
    val counter = new AtomicLong(0L)
    val linesReadCounter = new AtomicLong(0L)
    Source.fromFile(inputFile)
        .getLines()
        .foreach(line => {
      val linesRead = linesReadCounter.incrementAndGet()
      if (linesRead % 1000 == 0) Console.println("%d lines read".format(linesRead))
      val Array(cui, str) = line
        .replace("\",\"", "\t")
        .replaceAll("\"", "")
        .split("\t")
      val strNorm = normalizeCasePunct(str)
      val strSorted = sortWords(strNorm)
      val strStemmed = stemWords(strNorm)
      // write full str record 
      // str = exact string
      // str_norm = case and punct normalized, exact
      // str_sorted = str_norm sorted
      // str_stemmed = str_norm stemmed
      val fdoc = new Document()
      val fid = counter.incrementAndGet()
      fdoc.add(new Field("id", fid.toString, Store.YES, Index.NOT_ANALYZED))
      fdoc.add(new Field("cui", cui, Store.YES, Index.NOT_ANALYZED))
      fdoc.add(new Field("str", str, Store.YES, Index.NOT_ANALYZED))
      fdoc.add(new Field("str_norm", strNorm, Store.YES, Index.NOT_ANALYZED))
      fdoc.add(new Field("str_sorted", strSorted, Store.YES, Index.NOT_ANALYZED))
      fdoc.add(new Field("str_stemmed", strStemmed, Store.YES, Index.NOT_ANALYZED))
      writer.addDocument(fdoc)
      if (fid % 1000 == 0) writer.commit()
    })
    writer.commit()
    writer.close()
  }

  def annotateConcepts(phrase: String, 
      luceneDir: File): 
      List[(Double,String,String)] = {
    val suggestions = ArrayBuffer[(Double,String,String)]()
    val reader = DirectoryReader.open(
      new SimpleFSDirectory(luceneDir)) 
    val searcher = new IndexSearcher(reader)
    // try to match full string
    suggestions ++= cascadeSearch(searcher, reader, 
      phrase, 1.0)
    if (suggestions.size == 0) {
      // no exact match found, fall back to inexact matches
      val words = normalizeCasePunct(phrase)
        .split(" ")
      val foundWords = scala.collection.mutable.Set[String]()
      for (nword <- words.size - 1 until 0 by -1) {
        words.sliding(nword)
          .map(ngram => ngram.mkString(" "))
          .foreach(ngram => {
            if (alreadySeen(foundWords, ngram)) {
              val ngramWords = ngram.split(" ")
              val ratio = ngramWords.size.toDouble / words.size
              val inexactSuggestions = cascadeSearch(
                searcher, reader, ngram, ratio)
              if (inexactSuggestions.size > 0) {
                suggestions ++= inexactSuggestions
                foundWords ++= ngramWords
              }
            }    
          })       
      }
    }
    if (suggestions.size > 0) {
      // dedup by cui, keeping the first matched
      val deduped = suggestions.groupBy(_._3)
        .map(kv => kv._2.head)
        .toList  
        .sortWith((a,b) => a._1 > b._1)
      suggestions.clear
      suggestions ++= deduped
    }
    // clean up
    reader.close()
    // return results
    suggestions.toList
  }
  
  def printConcepts(
      concepts: List[(Double,String,String)]): 
      Unit = {
    concepts.foreach(concept => 
      Console.println("[%6.2f%%] (%s) %s"
      .format(concept._1, concept._3, concept._2)))
  }
  
  def normalizeCasePunct(str: String): String = {
    val str_lps = punctPattern
      .matcher(str.toLowerCase())
      .replaceAll(" ")
    spacePattern.matcher(str_lps).replaceAll(" ")
  }

  def sortWords(str: String): String = {
    val words = str.split(" ")
    words.sortWith(_ < _).mkString(" ")
  }
  
  def stemWords(str: String): String = {
    val stemmedWords = ArrayBuffer[String]()
    val tokenStream = getAnalyzer().tokenStream(
      "str_stemmed", new StringReader(str))
    val ctattr = tokenStream.addAttribute(
      classOf[CharTermAttribute])    
    tokenStream.reset()
    while (tokenStream.incrementToken()) {
      stemmedWords += ctattr.toString()
    }
    stemmedWords.mkString(" ")
  }
  
  def getAnalyzer(): Analyzer = {
    new StandardAnalyzer(Version.LUCENE_46)
  }
  
  case class StrDocument(id: Int, 
      cui: String, str: String, 
      strNorm: String, strSorted: String, 
      strStemmed: String)
      
  def getDocument(reader: IndexReader,
      hit: ScoreDoc): StrDocument = {
    val doc = reader.document(hit.doc)
    StrDocument(doc.get("id").toInt,
      doc.get("cui"), doc.get("str"), 
      doc.get("str_norm"), doc.get("str_sorted"), 
      doc.get("str_stemmed"))
  }
  
  def cascadeSearch(searcher: IndexSearcher,
      reader: IndexReader, phrase: String,
      ratio: Double): 
      List[(Double,String,String)] = {
    val results = ArrayBuffer[(Double,String,String)]()
    // exact match (100.0%)
    val query1 = new TermQuery(new Term("str", phrase))
    val hits1 = searcher.search(query1, 1).scoreDocs
    if (hits1.size > 0) {
      results += hits1.map(hit => {
        val doc = getDocument(reader, hit)
        (100.0 * ratio, doc.str, doc.cui)
      })
      .toList
      .head
    }
    // match normalized phrase (90%)
    val normPhrase = normalizeCasePunct(phrase)
    if (results.size == 0) {
      val query2 = new TermQuery(new Term("str_norm", normPhrase))
      val hits2 = searcher.search(query2, 1).scoreDocs
      if (hits2.size > 0) {
        results += hits2.map(hit => {
          val doc = getDocument(reader, hit)
          (90.0 * ratio, doc.str, doc.cui)
        })
        .toList
        .head
      }
    }
    // match sorted phrase (80%)
    val sortedPhrase = sortWords(normPhrase)
    if (results.size == 0) {
      val query3 = new TermQuery(new Term("str_sorted", sortedPhrase))
      val hits3 = searcher.search(query3, 1).scoreDocs
      if (hits3.size > 0) {
        results += hits3.map(hit => {
          val doc = getDocument(reader, hit)
          (80.0 * ratio, doc.str, doc.cui)
        })
        .toList
        .head
      }
    }
    // match stemmed phrase (70%)
    val stemmedPhrase = stemWords(normPhrase)
    if (results.size == 0) {
      val query4 = new TermQuery(new Term("str_stemmed", stemmedPhrase))
      val hits4 = searcher.search(query4, 1).scoreDocs
      if (hits4.size > 0) {
        results += hits4.map(hit => {
          val doc = getDocument(reader, hit)
          (70.0 * ratio, doc.str, doc.cui)
        })
        .toList
        .head
      }
    }
    results.toList
  }
  
  // despite the name, this returns true if the ngram is still worth
  // searching, ie nothing has been matched yet or the ngram contains
  // at least one word that has not already been matched
  def alreadySeen(
      refset: scala.collection.mutable.Set[String], 
      ngram: String): Boolean = {
    val words = ngram.split(" ")
    val numUnmatched = words.filter(word => 
      !refset.contains(word))
      .size
    if (refset.size == 0) true
    else if (numUnmatched > 0) true 
    else false
  }
}

I just used JUnit tests to exercise the various pieces of functionality. Building the index takes about 40 minutes on my laptop. Here is the JUnit test class.

// Source: src/test/scala/com/mycompany/scalcium/umls/UmlsTaggerTest.scala
package com.mycompany.scalcium.umls

import org.junit.Test
import java.io.File
import org.junit.Assert

class UmlsTaggerTest {

  @Test
  def testSortWords(): Unit = {
    val s = "heart attack and diabetes"
    val tagger = new UmlsTagger()
    Assert.assertEquals("and attack diabetes heart", tagger.sortWords(s))
  }
  
  @Test
  def testStemWords(): Unit = {
    val s = "and attack diabetes heart"
    val tagger = new UmlsTagger()
    Assert.assertEquals("attack diabetes heart", tagger.stemWords(s))
  }

  @Test
  def testBuild(): Unit = {
    val input = new File("/home/sujit/Projects/med_data/cuistr1.csv")
    val output = new File("/home/sujit/Projects/med_data/umlsindex")
    val tagger = new UmlsTagger()
    tagger.buildIndex(input, output)
  }
  
  @Test
  def testMapSingleConcept(): Unit = {
    val luceneDir = new File("/home/sujit/Projects/med_data/umlsindex")
    val tagger = new UmlsTagger()
    val strs = List("Lung Cancer", "Heart Attack", "Diabetes")
    strs.foreach(str => {
      val concepts = tagger.annotateConcepts(str, luceneDir)
      Console.println("Query: " + str)
      tagger.printConcepts(concepts)
      Assert.assertEquals(1, concepts.size)
      Assert.assertEquals(100.0D, concepts.head._1, 0.1D)
    })
  }

  @Test
  def testMapMultipleConcepts(): Unit = {
    val luceneDir = new File("/home/sujit/Projects/med_data/umlsindex")
    val tagger = new UmlsTagger()
    val strs = List(
        "Heart Attack and diabetes",
        "carcinoma (small-cell) of lung",
        "asthma side effects")
    strs.foreach(str => {
      val concepts = tagger.annotateConcepts(str, luceneDir)
      Console.println("Query: " + str)
      tagger.printConcepts(concepts)
    })
  }
}

The last two tests in the JUnit class above return exact and inexact match results for some queries. The first three queries are exact and the next three are inexact. Output from these tests is shown below:

Query: Lung Cancer
[100.00%] (C0242379) Lung Cancer

Query: Heart Attack
[100.00%] (C0027051) Heart Attack

Query: Diabetes
[100.00%] (C0011847) Diabetes

Query: Heart Attack and diabetes
[ 50.00%] (C0027051) heart attack
[ 35.00%] (C0011847) Diabetes
[ 35.00%] (C0699795) Attack

Query: carcinoma (small-cell) of lung
[ 80.00%] (C0149925) small cell carcinoma of lung

Query: asthma side effects
[ 66.67%] (C0001688) side effects
[ 33.33%] (C0004096) asthma

As you can see, the results don't look too bad. The algorithm is not very complex either; it borrows a lot from the ideas used in the fuzzywuzzy and OpenReconcile projects. Of course, I have only tried this on a very small set of queries, so it remains to be seen whether it produces good results as consistently as the one we use at work.