Sunday, August 30, 2015

Categorizing Medical Transcripts using DMOZ and Word2Vec


Some time back, a colleague mentioned during a conversation that his PhD dissertation involved using DMOZ to cluster query terms. The technique seemed interesting in relation to a problem we were trying to solve at the time, but it also got me thinking that the idea could be useful for categorizing documents as well. DMOZ provides a comprehensive hierarchy of web categories, along with links to representative documents in each category. So the category (or categories) of an unseen document can be determined by computing its similarity to those representative documents. Word2Vec vectors can be used to project both the reference and the test documents into the same latent space for comparison. To test this idea, I used a subset of DMOZ categories under Medical Specialties to categorize a small collection of medical transcription documents I had crawled some time back. This post describes the effort.

DMOZ data is available for download as a pair of large RDF files: structure.rdf and content.rdf. The structure.rdf file contains the categories, defined as path-like strings. Each category is declared with a Topic tag, and its child categories are referenced with narrow and symbolic tags. The content.rdf file contains the same path-like categories, also declared with the Topic tag, and uses link tags to point to representative web pages for each category. Since we are only interested in Medical Specialties, we used the following code to parse the two RDF files in a streaming manner (SAX) and extract the links from them. Each link is then sent to Boilerpipe's ArticleExtractor, which strips out the HTML markup and removes irrelevant text from the page (using a combination of rules and machine learning).

// Source: src/main/scala/dmozcat/preprocessing/ReferencePageExtractor.scala
package dmozcat.preprocessing

import java.io.File
import java.io.FileWriter
import java.io.PrintWriter
import java.net.URL
import java.util.concurrent.FutureTask
import java.util.concurrent.Callable
import java.util.concurrent.TimeUnit
import java.util.concurrent.TimeoutException

import javax.xml.parsers.SAXParserFactory

import org.xml.sax.Attributes
import org.xml.sax.helpers.DefaultHandler

import de.l3s.boilerpipe.extractors.ArticleExtractor

import scala.collection.mutable.ArrayBuffer

object ReferencePageExtractor extends App {
    val dataDir = "data"
    val structXml = new File(dataDir, "structure.rdf.u8")
    val contentXml = new File(dataDir, "content.rdf.u8")
    val outputCsv = new File(dataDir, "ref-pages.csv")
    val extractor = new ReferencePageExtractor(structXml, contentXml, outputCsv)
    extractor.extract()
}

class ReferencePageExtractor(structXml: File, contentXml: File, outputCsv: File) {
    
    def extract(): Unit = {
        val factory = SAXParserFactory.newInstance()
        // parse structure RDF to get list of topics
        val structParser = factory.newSAXParser()
        val structHandler = new DmozStructHandler()
        structParser.parse(structXml, structHandler)
        val topicSet = structHandler.topicSet()
        // parse content RDF to get list of URLs for each topic
        val contentParser = factory.newSAXParser()
        val contentHandler = new DmozContentHandler(topicSet)
        contentParser.parse(contentXml, contentHandler)
        val contentUrls = contentHandler.contentUrls()
        // download pages and write to file
        val writer = new PrintWriter(new FileWriter(outputCsv), true)
        contentUrls.foreach(topicUrl => {
            val topic = topicUrl._1
            val url = topicUrl._2
            val downloadTask = new FutureTask(new Callable[String]() {
                def call(): String = {
                    try {
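                        // collapse all whitespace, including tabs and newlines,
                        // so each page fits on one line of the TSV output
                        val text = ArticleExtractor.INSTANCE
                            .getText(new URL(url))
                            .replaceAll("\\s+", " ")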
                        Console.println("Downloading %s for topic: %s"
                            .format(url, topic))
                        text
                    } catch {
                        case e: Exception => "__ERROR__"
                    }
                }
            })
            new Thread(downloadTask).start()
            try {
                val text = downloadTask.get(60, TimeUnit.SECONDS)
                if (!text.equals("__ERROR__")) {
                    writer.println("%s\t%s\t%s".format(topic, url, text))
                } else {
                    Console.println("Download Error, skipping")
                }
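            } catch {
                case e: TimeoutException =>
                    // best effort: interrupt the worker thread so a stalled
                    // download does not keep running in the background
                    downloadTask.cancel(true)
                    Console.println("Timed out, skipping")
            }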
        })
        writer.flush()
        writer.close()
    }
}

class DmozStructHandler extends DefaultHandler {
    
    val contentTopics = Set("narrow", "symbolic")
    
    var isRelevant = false
    val topics = ArrayBuffer[String]()
    
    def topicSet(): Set[String] = topics.toSet
    
    override def startElement(uri: String, localName: String, 
            qName: String, attrs: Attributes): Unit = {
        if (!isRelevant && qName.equals("Topic")) {
            val numAttrs = attrs.getLength()
            val topicName = (0 until numAttrs)
                .filter(i => attrs.getQName(i).equals("r:id"))
                .map(i => attrs.getValue(i))
                .head
            if (topicName.equals("Top/Health/Medicine/Medical_Specialties"))
                isRelevant = true
        }
        if (isRelevant && contentTopics.contains(qName)) {
            val numAttrs = attrs.getLength()
            val contentTopicName = (0 until numAttrs)
                .filter(i => attrs.getQName(i).equals("r:resource"))
                .map(i => attrs.getValue(i))
                .map(v => if (v.indexOf(':') > -1)
                    v.substring(v.indexOf(':') + 1) else v)
                .head
            topics += contentTopicName
        }
    }
    
    override def endElement(uri: String, localName: String, 
            qName: String): Unit = {
        if (isRelevant && qName.equals("Topic")) isRelevant = false
    }
}

class DmozContentHandler(topics: Set[String]) extends DefaultHandler {
    
    var isRelevant = false
    var currentTopicName: String = null
    val contents = ArrayBuffer[(String, String)]()
    
    def contentUrls(): List[(String, String)] = contents.toList
    
    override def startElement(uri: String, localName: String, 
            qName: String, attrs: Attributes): Unit = {
        if (!isRelevant && qName.equals("Topic")) {
            val numAttrs = attrs.getLength()
            val topicName = (0 until numAttrs)
                .filter(i => attrs.getQName(i).equals("r:id"))
                .map(i => attrs.getValue(i))
                .head
            if (topics.contains(topicName)) {
                isRelevant = true
                currentTopicName = topicName
            }
        }
        if (isRelevant && qName.equals("link")) {
            val numAttrs = attrs.getLength()
            val link = (0 until numAttrs)
                .filter(i => attrs.getQName(i).equals("r:resource"))
                .map(i => attrs.getValue(i))
                .head
            contents += ((currentTopicName, link))
        }
    }
    
    override def endElement(uri: String, localName: String, 
            qName: String): Unit = {
        if (isRelevant && qName.equals("Topic")) {
            isRelevant = false
            currentTopicName = null
        }
    }
}

The output of this code is a tab-separated file containing the category name, the page URL, and the extracted text. There are 464 records in all, spanning 44 unique categories. The data looks something like this:

Top/Health/Medicine/Medical_Specialties/Aerospace_Medicine  http://www.civilavmed.com/  Home How are you as an AME adapting to the new ...
Top/Health/Medicine/Medical_Specialties/Aerospace_Medicine  http://www.asma.org/        Sunday, September 20, 2015 8:00 AM ICASM 2015 The Congress ...
...
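To make the RDF layout described earlier more concrete, here is a minimal sketch that runs the same two-pass SAX logic over tiny inline samples instead of the real files. The XML snippets, the object name DmozSaxSketch and its helper methods are all hypothetical; the samples only mimic the shape assumed by the code above (Topic elements carrying an r:id attribute, and narrow, symbolic and link children carrying an r:resource attribute), and the real dumps contain many more elements.

```scala
import java.io.StringReader
import javax.xml.parsers.SAXParserFactory
import org.xml.sax.{Attributes, InputSource}
import org.xml.sax.helpers.DefaultHandler
import scala.collection.mutable.ArrayBuffer

// Sketch only: names and sample data are illustrative, not from the post's repository.
object DmozSaxSketch {

  // Hypothetical miniature of structure.rdf.
  val structureSample: String =
    """<RDF xmlns:r="http://www.w3.org/TR/RDF/">
      |  <Topic r:id="Top/Health/Medicine/Medical_Specialties">
      |    <narrow r:resource="Top/Health/Medicine/Medical_Specialties/Cardiology"/>
      |    <symbolic r:resource="Podiatry:Top/Health/Medicine/Medical_Specialties/Podiatry"/>
      |  </Topic>
      |  <Topic r:id="Top/Arts">
      |    <narrow r:resource="Top/Arts/Music"/>
      |  </Topic>
      |</RDF>""".stripMargin

  // Hypothetical miniature of content.rdf.
  val contentSample: String =
    """<RDF xmlns:r="http://www.w3.org/TR/RDF/">
      |  <Topic r:id="Top/Health/Medicine/Medical_Specialties/Cardiology">
      |    <link r:resource="http://cardio.example.org/"/>
      |  </Topic>
      |  <Topic r:id="Top/Arts/Music">
      |    <link r:resource="http://music.example.org/"/>
      |  </Topic>
      |</RDF>""".stripMargin

  // Looks up an attribute by its qualified name (the factory is not namespace aware).
  private def attr(attrs: Attributes, qName: String): Option[String] =
    Option(attrs.getValue(qName))

  private def parse(xml: String, handler: DefaultHandler): Unit =
    SAXParserFactory.newInstance().newSAXParser()
      .parse(new InputSource(new StringReader(xml)), handler)

  // Pass 1: collect the narrow/symbolic children of one parent topic.
  def childTopics(xml: String, parent: String): Set[String] = {
    val found = ArrayBuffer[String]()
    parse(xml, new DefaultHandler {
      var inParent = false
      override def startElement(uri: String, localName: String,
                                qName: String, attrs: Attributes): Unit =
        if (qName == "Topic") inParent = attr(attrs, "r:id").contains(parent)
        else if (inParent && (qName == "narrow" || qName == "symbolic"))
          // symbolic resources look like "Alias:Top/..."; keep only the path
          attr(attrs, "r:resource").foreach(v => found += v.substring(v.indexOf(':') + 1))
      override def endElement(uri: String, localName: String, qName: String): Unit =
        if (qName == "Topic") inParent = false
    })
    found.toSet
  }

  // Pass 2: collect (topic, url) pairs for links under the wanted topics.
  def topicLinks(xml: String, topics: Set[String]): List[(String, String)] = {
    val links = ArrayBuffer[(String, String)]()
    parse(xml, new DefaultHandler {
      var current: Option[String] = None
      override def startElement(uri: String, localName: String,
                                qName: String, attrs: Attributes): Unit =
        if (qName == "Topic") current = attr(attrs, "r:id").filter(topics.contains)
        else if (qName == "link")
          for (t <- current; url <- attr(attrs, "r:resource")) links += ((t, url))
      override def endElement(uri: String, localName: String, qName: String): Unit =
        if (qName == "Topic") current = None
    })
    links.toList
  }
}
```

Calling childTopics with the Medical_Specialties path should return the Cardiology and Podiatry paths, and passing that set to topicLinks keeps only the Cardiology link. As in the post's handlers, the trick is a piece of state switched on at the relevant Topic start tag and off again at its end tag, which is what lets SAX stream through files too large to load into memory.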

The next step is to vectorize the reference text (the third column in the file shown above). We use OpenNLP's models to segment the text into sentences and extract noun phrases, using the excellent Scala interface provided by the NLP Tools Project from the University of Washington.

One of the artifacts published by the Word2Vec project is a set of vectors for around 3 million words and phrases, trained on the Google News dataset. It is released as a binary file, and we use Gensim's Word2Vec module to convert it to a TSV format that our code can read (one word per line, followed by a tab and the comma-separated vector components). From the noun phrases extracted by OpenNLP, we construct unigrams, bigrams and trigrams and look them up in this word vector dictionary. The vectors for the ngrams found are averaged and then L2-normalized to build a single vector for each category.

// Source: src/main/scala/dmozcat/vectorize/TextVectorizer.scala
package dmozcat.vectorize

import scala.Array.canBuildFrom

import org.apache.log4j.Level
import org.apache.log4j.Logger
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD.rddToPairRDDFunctions

import breeze.linalg.DenseVector
import breeze.linalg.InjectNumericOps
import breeze.linalg.norm

object TextVectorizer {

    def main(args: Array[String]): Unit = {
        Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
        Logger.getLogger("org.apache.spark.storage.BlockManager").setLevel(Level.ERROR)
        
        // arguments
        val awsAccessKey = args(0)
        val awsSecretKey = args(1)
        val inputFile = args(2)
        val word2vecFile = args(3)
        val outputDir = args(4)
        
        val conf = new SparkConf()
        conf.setAppName("TextVectorizer")
        
        val sc = new SparkContext(conf)
        
        sc.hadoopConfiguration.set("fs.s3n.awsAccessKeyId", awsAccessKey)
        sc.hadoopConfiguration.set("fs.s3n.awsSecretAccessKey", awsSecretKey)
        
        val input = sc.textFile(inputFile)
            .map(line => parseLine(line))
            .mapPartitions(p => extractNGrams(p)) // (key, ngram)
            .map(kv => (kv._2, kv._1))            // (ngram, key)
            
        val wordVectors = sc.textFile(word2vecFile)
            .map(line => {
                val Array(word, vecstr) = line.split("\t")
                val vector = new DenseVector(vecstr.split(",").map(_.toDouble))
                (word.toLowerCase, vector)        // (ngram, vector)
            })
        
        // join input to wordVectors by word
        val inputVectors = input.join(wordVectors)  // (ngram, (key, vector))
            .map(nkv => (nkv._2._1, nkv._2._2))     // (key, vector)
            .aggregateByKey((0, DenseVector.zeros[Double](300)))(
                (acc, value) => (acc._1 + 1, acc._2 + value),
                (acc1, acc2) => (acc1._1 + acc2._1, acc1._2 + acc2._2))
            .mapValues(countVec => 
                (1.0D / countVec._1) * countVec._2) // (key, mean(vector))
            
        // save document (id, vector) pair as flat file
        inputVectors.map(kvec => {
            val key = kvec._1
            val value = (kvec._2 / norm(kvec._2, 2))
                .toArray
                .map("%.5f".format(_))
                .mkString(",")
            "%s\t%s".format(key, value)
        }).saveAsTextFile(outputDir)
    }
        
    def parseLine(line: String): (String, String) = {
        val cols = line.split("\t")
        val key = cols.head
        val text = cols.last
        (key, text)
    }
    
    def extractNGrams(p: Iterator[(String, String)]): 
            Iterator[(String, String)] = {
        val t2n = new NGramExtractor()
        p.flatMap(keyText => t2n.ngrams(keyText))
    }
}
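Stripped of Spark, the join and aggregateByKey in TextVectorizer amount to: parse the word-vector file, look up each (key, ngram) pair in it, drop the ngrams that are missing, average the remaining vectors per key, and scale each mean to unit L2 norm. The sketch below does this with plain Scala collections. KeyVectorSketch and its methods are hypothetical names, and the test vectors are toy 3-dimensional values rather than the 300-dimensional Google News vectors.

```scala
// Sketch only: a local, Spark-free version of the per-key vector aggregation.
object KeyVectorSketch {

  type Vec = Array[Double]

  // Parses one line of the word-vector file: "word<TAB>v1,v2,...".
  def parseVectorLine(line: String): (String, Vec) = {
    val Array(word, vecStr) = line.split("\t")
    (word.toLowerCase, vecStr.split(",").map(_.toDouble))
  }

  // Scales a vector to unit L2 norm; a zero vector is returned unchanged.
  def l2Normalize(v: Vec): Vec = {
    val n = math.sqrt(v.map(x => x * x).sum)
    if (n == 0.0) v else v.map(_ / n)
  }

  // Groups (key, ngram) pairs by key, averages the vectors of the ngrams
  // found in the dictionary, and L2-normalizes each mean vector.
  // Keys with no ngram in the dictionary are dropped, as an inner join would.
  def keyVectors(pairs: Seq[(String, String)],
                 dict: Map[String, Vec]): Map[String, Vec] =
    pairs
      .flatMap { case (key, ngram) => dict.get(ngram).map(v => (key, v)) }
      .groupBy(_._1)
      .map { case (key, kvs) =>
        val vecs = kvs.map(_._2)
        val sum = vecs.reduce((a, b) => a.zip(b).map { case (x, y) => x + y })
        key -> l2Normalize(sum.map(_ / vecs.size))
      }
}
```

Averaging before normalizing points in the same direction as summing before normalizing, so for cosine similarity the two are interchangeable. The Spark version carries a (count, sum) pair because aggregateByKey merges partial results from different partitions, and a mean can only be recovered from such partials if the counts travel with the sums.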

The text processing behind the vectorization is encapsulated in NGramExtractor, whose code is shown below. If you have used the OpenNLP API directly, you will appreciate the convenience of the NLPTools API. Plus, since the models are packaged as dependencies, you no longer have to deal with the OpenNLP model files explicitly. The only downside of NLPTools is the number of dependencies you must specify in your build file. This is because NLPTools unifies access to a number of underlying tools, each with its own license, and it handles this by packaging each one as a separate JAR.

// Source: src/main/scala/dmozcat/vectorize/NGramExtractor.scala
package dmozcat.vectorize

import edu.knowitall.tool.sentence.OpenNlpSentencer
import edu.knowitall.tool.postag.OpenNlpPostagger
import edu.knowitall.tool.tokenize.OpenNlpTokenizer
import edu.knowitall.tool.chunk.OpenNlpChunker

import scala.collection.mutable.ArrayBuffer

class NGramExtractor {

    val sentencer = new OpenNlpSentencer
    val postagger = new OpenNlpPostagger
    val tokenizer = new OpenNlpTokenizer
    val chunker = new OpenNlpChunker
    
    def ngrams(keyText: (String, String)): List[(String, String)] = {
        val key = keyText._1
        val text = keyText._2
        // segment text into sentences
        val sentences = sentencer.segment(text)
        // extract noun phrases from sentences
        val nounPhrases = sentences.flatMap(segment => {
            val sentence = segment.text
            val chunks = chunker.chunk(sentence)
            chunks.filter(chunk => chunk.chunk.endsWith("-NP"))
                .map(chunk => (chunk.string, chunk.chunk))
                .foldLeft(List.empty[String])((acc, x) => x match {
                    // I-NP continues the current noun phrase, if there is one
                    case (s, "I-NP") if acc.nonEmpty => (acc.head + " " + s) :: acc.tail
                    // B-NP (or a stray I-NP with no open phrase) starts a new one
                    case (s, _) => s :: acc
                }).reverse
        })
        // extract ngrams (n=1,2,3) from noun phrases
        val ngrams = nounPhrases.flatMap(nounPhrase => {
            val words = nounPhrase.toLowerCase.split(" ")
            words.size match {
                case 0 => List()
                case 1 => words
                case 2 => words ++ words.sliding(2).map(_.mkString("_"))
                case _ => words ++ 
                    words.sliding(2).map(_.mkString("_")) ++
                    words.sliding(3).map(_.mkString("_"))
            }
        })
        ngrams.map(ngram => (key, ngram)).toList
    }
}
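Once the chunk tags are available, the two core transformations in NGramExtractor do not depend on OpenNLP at all. Here they are as pure functions over (token, chunk tag) pairs, roughly what the chunker's string and chunk fields provide: merge B-NP/I-NP runs into noun phrases, then emit the unigrams, bigrams and trigrams of each phrase joined with underscores, the convention the Google News vectors use for multi-word phrases. NounPhraseSketch and the sample tags are hypothetical.

```scala
// Sketch only: OpenNLP-free versions of the noun phrase and ngram steps.
object NounPhraseSketch {

  // Folds IOB chunk tags into noun phrases: "B-NP" starts a phrase and
  // "I-NP" extends the current one. Like the original, non-NP chunks are
  // filtered out first; an "I-NP" with no open phrase starts a new phrase.
  def nounPhrases(tagged: Seq[(String, String)]): List[String] =
    tagged.filter(_._2.endsWith("-NP"))
      .foldLeft(List.empty[String]) {
        case (head :: tail, (tok, "I-NP")) => (head + " " + tok) :: tail
        case (acc, (tok, _))               => tok :: acc
      }
      .reverse

  // Emits all 1-, 2- and 3-grams of a phrase, lowercased and joined by "_".
  def ngrams(phrase: String): Seq[String] = {
    val words = phrase.toLowerCase.split(" ").filter(_.nonEmpty).toSeq
    (1 to math.min(3, words.size)).flatMap(n => words.sliding(n).map(_.mkString("_")))
  }
}
```

The min(3, size) bound matters because sliding(n) on a collection shorter than n returns a single short window rather than nothing, which is also why the original code matches on the number of words.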

The text of the documents to be categorized is vectorized the same way, by extracting noun phrases and looking up vectors for their unigrams, bigrams and trigrams in the Word2Vec dictionary. To reuse the same code, we first convert the documents, stored as individual files in a directory, into a single file where each line contains the filename and the corresponding text, separated by a tab.

// Source: src/main/scala/dmozcat/preprocessing/TestDataReformatter.scala
package dmozcat.preprocessing

import java.io.File
import scala.io.Source
import java.io.PrintWriter
import java.io.FileWriter

object TestDataReformatter extends App {
    
    val inputDir = new File("/Users/palsujit/Projects/med_data/mtcrawler/texts")
    val outputFile = new PrintWriter(new FileWriter(new File("/tmp/testdata.csv")), true)
    inputDir.listFiles().foreach(inputFile => {
        val filename = inputFile.getName()
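        val source = Source.fromFile(inputFile)
        // terminate each line with a period, then collapse all whitespace
        // (including tabs) so each document becomes a single TSV record
        val text = try {
            source.getLines
                .map(line => if (line.endsWith(".")) line else line + ".")
                .mkString(" ")
                .replaceAll("\\s+", " ")
        } finally {
            source.close()
        }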
        outputFile.println("%s\t%s".format(filename, text))
    })
    outputFile.flush()
    outputFile.close()
}
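The per-file transformation in TestDataReformatter can be pulled out into a pure function, which makes the TSV contract easy to check: one record per line, the key in the first column, and text free of tabs and newlines so that parseLine in TextVectorizer (which keeps the first and last columns) recovers it intact. A minimal sketch; RecordSketch and toRecord are hypothetical names, and the period appended to unterminated lines is presumably there so the sentence segmenter sees a boundary at the end of each line.

```scala
// Sketch only: hypothetical helpers mirroring TestDataReformatter and parseLine.
object RecordSketch {

  // Joins the lines of one document into a single TSV record, appending a
  // period to lines that lack one (presumably to give the sentence segmenter
  // a boundary) and collapsing all whitespace, including tabs, to one space.
  def toRecord(filename: String, lines: Seq[String]): String = {
    val text = lines
      .map(line => if (line.endsWith(".")) line else line + ".")
      .mkString(" ")
      .replaceAll("\\s+", " ")
    s"$filename\t$text"
  }

  // Mirrors TextVectorizer.parseLine: first column is the key, last is the text.
  def parseLine(line: String): (String, String) = {
    val cols = line.split("\t")
    (cols.head, cols.last)
  }
}
```

A stray tab inside the text would otherwise split the record into extra columns, and cols.last would silently keep only the final fragment.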

We vectorize the reference and test datasets in two separate Amazon EMR jobs. Amazon officially announced support for Spark on EMR a couple of months ago (at the Spark Summit in San Francisco in June this year), and I had been meaning to try it out. Their instructions for running Spark on EMR are very detailed and worked perfectly for me; the only thing you need to know is to switch to the "advanced option" when defining your cluster.

Essentially, Spark on EMR runs in client mode. The first step invokes an Amazon custom job that executes a Hadoop dfs command to copy your JAR from S3 to the driver node. The next step actually runs your JAR against the Spark cluster. In this case, the JAR needs to have the OpenNLP models baked in, which can be done by building it with "sbt assembly".

When run against the reference text, the job outputs a file of (category name, category vector) pairs; when run against the test dataset, it outputs a file of (filename, vector) pairs.

Next, we compute the cosine similarity of each document vector against each of the category vectors (both in word2vec space). Since there are only 44 category vectors, we stack them into a 44x300 matrix (the word2vec vectors provided are 300-dimensional) and broadcast it to the workers. Finding the best categories is then simply a matter of multiplying this matrix by the document vector. Because both the matrix rows and the document vector are normalized to unit L2 norm, the result is a vector of cosine similarities, one per category. We then find the top N (here 3) similarities and their associated category names.
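Without Breeze, the similarity step is just one dot product per category row. The sketch below, in plain Scala with hypothetical names (CosineTopN, RowMajor) and made-up 2-dimensional vectors, stacks the unit-length category vectors into a row-major array, multiplies it by the unit-length document vector, and keeps the top N. Concatenating the category vectors end to end gives exactly this row-major layout; Breeze's DenseMatrix is column-major, which is why the post's KNearestCategories code builds an ncols x nrows matrix from the concatenated array and then transposes it.

```scala
// Sketch only: a Breeze-free version of the matrix-vector cosine ranking.
object CosineTopN {

  // Scales a vector to unit L2 norm (assumes a non-zero vector).
  def normalize(v: Array[Double]): Array[Double] = {
    val n = math.sqrt(v.map(x => x * x).sum)
    v.map(_ / n)
  }

  // Dense matrix stored row-major: row i occupies data(i * ncols until (i + 1) * ncols).
  final class RowMajor(val nrows: Int, val ncols: Int, val data: Array[Double]) {
    require(data.length == nrows * ncols, "data size must be nrows * ncols")

    // Matrix-vector product: one dot product per row.
    def times(v: Array[Double]): Array[Double] =
      Array.tabulate(nrows)(i => (0 until ncols).map(j => data(i * ncols + j) * v(j)).sum)
  }

  // Returns the n categories most similar to doc, with their cosine similarities.
  def topCategories(doc: Array[Double], n: Int,
                    cats: Seq[(String, Array[Double])]): List[(String, Double)] = {
    // concatenating the unit-length category vectors yields the row-major layout
    val matrix = new RowMajor(cats.size, doc.length, cats.flatMap(c => normalize(c._2)).toArray)
    val sims = matrix.times(normalize(doc)) // one cosine similarity per category
    sims.zipWithIndex
      .sortBy { case (sim, _) => -sim }
      .take(n)
      .map { case (sim, i) => (cats(i)._1, sim) }
      .toList
  }
}
```

Because both sides have unit length, each entry of the product is a cosine similarity, and with only a few dozen categories the matrix is small enough to broadcast to every worker.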

// Source: src/main/scala/dmozcat/vectorize/KNearestCategories.scala
package dmozcat.vectorize

import scala.Array.canBuildFrom
import org.apache.log4j.Level
import org.apache.log4j.Logger
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import breeze.linalg.DenseVector
import breeze.linalg.norm
import breeze.linalg.DenseMatrix

object KNearestCategories {
    
    def main(args: Array[String]): Unit = {
        
        Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
        Logger.getLogger("org.apache.spark.storage.BlockManager").setLevel(Level.ERROR)
        
        // arguments
        val awsAccessKey = args(0)
        val awsSecretKey = args(1)
        val refVectorsDir = args(2)
        val testVectorsDir = args(3)
        val bestcatsFile = args(4)
        
        val conf = new SparkConf()
        conf.setAppName("KNearestCategories")
        
        val sc = new SparkContext(conf)
        
        sc.hadoopConfiguration.set("fs.s3n.awsAccessKeyId", awsAccessKey)
        sc.hadoopConfiguration.set("fs.s3n.awsSecretAccessKey", awsSecretKey)
    
        // read reference vectors and broadcast them to workers for
        // replicated join
        val refVectors = sc.textFile(refVectorsDir)
            .map(line => {
                val Array(catStr, vecStr) = line.split("\t")
                val cat = catStr.split("/").last
                val vec = new DenseVector(vecStr.split(",").map(_.toDouble))
                val l2 = norm(vec, 2.0)
                (cat, vec / l2)
            }).collect
        val categories = refVectors.map(_._1)
        // we want to take the Array[DenseVector[Double]] and convert
        // it to DenseMatrix[Double] so we can do matrix-vector multiplication
        // for computing similarities later
        val nrows = categories.size
        val ncols = refVectors.map(_._2).head.length
        val catVectors = refVectors.map(_._2)
            .reduce((a, b) => DenseVector.vertcat(a, b))
            .toArray
        val catMatrix = new DenseMatrix[Double](ncols, nrows, catVectors).t
        // broadcast it
        val bCategories = sc.broadcast(categories)
        val bCatMatrix = sc.broadcast(catMatrix) 
        
        // read test vectors representing each test document
        val testVectors = sc.textFile(testVectorsDir)
            .map(line => {
                val Array(filename, vecStr) = line.split("\t")
                val vec = DenseVector(vecStr.split(",").map(_.toDouble))
                val l2 = norm(vec, 2.0)
                (filename, vec / l2)
            })
            .map(fileVec => 
                (fileVec._1, topCategories(fileVec._2, 3,
                    bCatMatrix.value, bCategories.value)))
            
        testVectors.map(fileCats => "%s\t%s".format(
                fileCats._1, fileCats._2.mkString(", ")))
            .coalesce(1)
            .saveAsTextFile(bestcatsFile)
    }

    def topCategories(vec: DenseVector[Double], n: Int,
            catMatrix: DenseMatrix[Double],
            categories: Array[String]): List[String] = {
        val cosims = catMatrix * vec
        cosims.toArray.zipWithIndex
            .sortWith((a, b) => a._1 > b._1)
            .take(n)                           // argmax(n)
            .map(simIdx => (categories(simIdx._2), simIdx._1)) // (cat,sim)
            .map(catSim => "%s (%.3f)".format(catSim._1, catSim._2))
            .toList
    }
}

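This is also a Spark job executed on EMR. One thing I learned is that you cannot reliably use broadcast variables in Scala objects that extend App: App relies on delayed initialization, so the object's fields may not yet be initialized when closures referring to them run on the workers. This is discussed in this StackOverflow page and this Spark JIRA, and it is why the Spark jobs above define an explicit main method instead. The final output looks like this: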

2022.txt  Pain_Management (0.875), Podiatry (0.831), Sports_Medicine (0.826)
1928.txt  Pain_Management (0.858), Behavioral_Medicine (0.843), Podiatry (0.840)
1296.txt  Surgery (0.864), Podiatry (0.849), Radiotherapy (0.832)
0996.txt  Cardiology (0.864), Pain_Management (0.847), Radiotherapy (0.846)
2000.txt  Pain_Management (0.751), Cardiology (0.736), Radiotherapy (0.735)
0853.txt  Radiotherapy (0.773), Cardiology (0.767), Podiatry (0.734)
2228.txt  Pain_Management (0.918), Podiatry (0.904), Surgery (0.889)
0361.txt  Surgery (0.859), Podiatry (0.848), Pain_Management (0.843)
0916.txt  Cardiology (0.823), Radiotherapy (0.820), Pain_Management (0.801)
3078.txt  Palliative_Care (0.903), Surgery (0.887), Critical_Care (0.885)

That's all I have for today. It was a lot of fun playing with Word2Vec and the KnowItAll APIs for OpenNLP, as well as using Spark on EMR. Nowadays DMOZ is considered somewhat outdated compared to Wikipedia, but it provides a nice hierarchical topic taxonomy that can be used effectively for this kind of work. I hope you enjoyed reading this, and that it gave you some ideas for categorizing your own data. All the code for this post is available on GitHub here.
