Sunday, September 27, 2015

Sentence Similarity using Word2Vec and Word Mover's Distance


Some time back, I read about the Word Mover's Distance (WMD) in the paper From Word Embeddings to Document Distances by Kusner, Sun, Kolkin and Weinberger. The WMD is a distance function that measures the distance between two texts as the cumulative sum of the minimum distances each word in one text must move in vector space to reach the closest word in the other text. In the paper, the authors provide some examples where WMD is calculated against a Word2Vec vector space. Since Word2Vec word embeddings preserve aspects of a word's context, they are a good way to capture semantic meaning (or difference in meaning) when calculating WMD.
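
In symbols, and with the closest-word simplification used in the computation below (each word in one sentence simply travels to its single nearest word in the other, rather than being fractionally assigned to several words as in the full optimal-transport formulation of the paper), the distance between sentences A and B is approximately

    \mathrm{WMD}(A, B) \approx \sum_{w \in A} \min_{w' \in B} \lVert v_w - v_{w'} \rVert_2

where v_w is the Word2Vec vector for word w.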

The paper reminded me of a similar (in intent) algorithm that I had implemented earlier and written about in my post Computing Semantic Similarity for Short Sentences. There, we captured the semantic meaning using an external semantic network (Wordnet).

Since the problems were so similar, I figured that it might be interesting to compute the WMD for the sentence pairs from that paper and see how the results match up with intuition. I already had a dump of the GoogleNews vectors (pretrained on about 100B words of Google News) lying around from a previous project. That paper reported results over a dataset of just 16 short sentence pairs, so I decided to do this interactively on Spark using a Databricks notebook. We use Databricks at work and it's ideal for this kind of quick and dirty ad-hoc work.

First we load up our 16 sentence pairs. The input has three tab-separated columns: sentence 1, sentence 2, and the original score. Since we don't care about the original score, we discard it and convert each input line to a pair of sentences.

Since we want to compare words across sentences in the same pair, it makes sense to have these words on the same worker when they are compared, so we add an index key to each sentence pair. The output of this cell is an RDD of ((sentence1: String, sentence2: String), index: Long).

import org.apache.spark.storage.StorageLevel

val sentencePairs = sc.textFile("sentence_pairs.txt")
    .map(line => {
        val Array(s1, s2, _) = line.split('\t')
        (s1, s2)
    })
    .zipWithIndex
    .persist(StorageLevel.MEMORY_AND_DISK)
sentencePairs.count()

WMD between two sentences (or between any two blobs of text) is computed as the sum of the distances from each word in one text to its closest word in the other. The words are pre-processed to remove stopwords, so the next cell pulls in a list of English stopwords, which I convert to a Set and broadcast to the worker nodes.

val stopwords = sc.textFile("stopwords.txt").collect.toSet
val bStopwords = sc.broadcast(stopwords)

We now split both sentences into words (removing punctuation and splitting on whitespace), remove stopwords from each, and flatMap the result into the format (index: Long, (word1: String, word2: String)). This gives us a list of 71 word pairs.

def getWordPairs(id: Long, s1: String, s2: String, stopwords: Set[String]): 
        List[(Long, (String, String))] = {
    val w1s = s1.toLowerCase
          .replaceAll("\\p{Punct}", "")
          .split(" ")
          .filter(w => !stopwords.contains(w))
    val w2s = s2.toLowerCase
          .replaceAll("\\p{Punct}", "")
          .split(" ")
          .filter(w => !stopwords.contains(w))
    val wpairs = for (w1 <- w1s; w2 <- w2s) yield (id, (w1, w2))
    wpairs.toList
}

val wordPairs = sentencePairs.flatMap(ssi => 
    getWordPairs(ssi._2, ssi._1._1, ssi._1._2, bStopwords.value))
wordPairs.count()
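
To make the shape concrete, here is a hypothetical invocation of getWordPairs using one of the pairs from the results table further down (the index 8 and the three-element stopword set are illustrative only, not what the notebook actually passes in):

getWordPairs(8L, "It is a dog.", "It is a pig.", Set("it", "is", "a"))
// => List((8, ("dog", "pig"))) - both sentences reduce to a single content word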

Next we ingest the Word2Vec vectors. I've used Gensim's Word2Vec module to convert the original Word2Vec binary format to TSV. Each line of this dataset is (word: String, comma-separated list of 300 vector elements).

val w2vs = sc.textFile("GoogleNews-vectors-negative300.tsv")
    .map(line => {
        val Array(word, vector) = line.split('\t')
        (word, vector)
    })

Next, we join wordPairs against the w2vs RDD, first on the RHS word and then on the LHS word, to pick up the 300-dimensional Word2Vec vector for each side. Since we do a lot of moving things around, I have used case matching instead of the less intuitive underscore syntax for tuple elements and sub-elements. Note that we need to hang on to the left word, because we ultimately want to find the RHS word that is closest to each LHS word.

import breeze.linalg._

def dist(lvec: String, rvec: String): Double = {
    val lv = DenseVector(lvec.split(',').map(_.toDouble))
    val rv = DenseVector(rvec.split(',').map(_.toDouble))
    math.sqrt(sum((lv - rv) :* (lv - rv)))
}

val wordVectors = wordPairs.map({case (idx, (lword, rword)) => 
        (rword, (idx, lword))})
    .join(w2vs)    // (rword, ((idx, lword), rvec))
    .map({case (rword, ((idx, lword), rvec)) => (lword, (idx, rvec))})
    .join(w2vs)    // (lword, ((idx, rvec), lvec))
    .map({case (lword, ((idx, rvec), lvec)) => ((idx, lword), (lvec, rvec))})
    .map({case ((idx, lword), (lvec, rvec)) => 
        ((idx, lword), List(dist(lvec, rvec)))}) 
    .persist(StorageLevel.MEMORY_AND_DISK)
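
As a quick sanity check of dist(), here is a made-up 3-element example (real vectors have 300 elements, of course):

dist("1,2,3", "1,2,5")   // => 2.0, i.e. sqrt(0 + 0 + 4)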

I used Euclidean distance in Word2Vec space as the distance between words; I also tried cosine distance (1 - cosine similarity) with similar results. We then take, for each LHS word, the distance to its closest RHS word, and sum these shortest distances over all LHS words to get the WMD for the sentence pair.

val bestWMDs = wordVectors.reduceByKey((a, b) => a ++ b)
    .mapValues(dists => dists.sortWith(_ < _).head)  // dist to closest word
    .map({case ((idx, lword), wmd) => (idx, wmd)})
    .reduceByKey((a, b) => a + b)                    // sum all wmds for sent
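
For reference, here is a minimal sketch of the cosine distance variant mentioned above. It assumes the same comma-separated vector strings that dist() takes and that breeze.linalg is already imported; the name cosineDist is mine, not from the original notebook. Swapping it in for dist in the wordVectors computation reproduces the cosine experiment.

def cosineDist(lvec: String, rvec: String): Double = {
    val lv = DenseVector(lvec.split(',').map(_.toDouble))
    val rv = DenseVector(rvec.split(',').map(_.toDouble))
    // cosine distance = 1 - cosine similarity
    1.0 - (lv dot rv) / (norm(lv) * norm(rv))
}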

Finally, we join these WMD scores back into the original dataset using the pair index that we originally generated using zipWithIndex.

val sqlContext = new org.apache.spark.sql.SQLContext(sc)
import sqlContext.implicits._

case class SentencePair(s1: String, s2: String, wmd: Double)
val results = sentencePairs.map(_.swap)
    .join(bestWMDs)
    .map({case (id, ((s1, s2), wmd)) => SentencePair(s1, s2, wmd)})
val resultsDF = sqlContext.createDataFrame(results)
    .orderBy($"s1".asc, $"wmd".asc)
display(resultsDF)

The results are shown below. The pairs are sorted first by the LHS sentence and then by WMD (lowest first), so the closest match for each LHS sentence appears before the pairs that are not as close.

LHS Sentence | RHS Sentence | WMD
A glass of cider. | A full cup of apple juice. | 2.2169259719396095
Canis familiaris are animals. | Dogs are common pets. | 1.859694788966317
Dogs are animals. | They are common pets. | 1.4537090848972198
I have a hammer. | Take some nails. | 1.1578027104196844
I have a hammer. | Take some apples. | 1.3028564676146912
I have a pen. | Where is ink? | 1.020277185488236
I have a pen. | Where do you live? | 1.3924941078355293
I like that bachelor. | I like that unmarried man. | 1.176742725809037
It is a dog. | That must be your dog. | 0
It is a dog. | It is a pig. | 1.04864558369858
It is a dog. | It is a log. | 1.3798001799052624
John is very nice. | Is John very nice? | 0
Red alcoholic drink. | Fresh orange juice. | 3.1161814560971166
Red alcoholic drink. | A bottle of wine. | 3.386809492524872
Red alcoholic drink. | Fresh apple juice. | 3.505168296314785
Red alcoholic drink. | An English dictionary. | 4.106139922327307

As you can see, the scoring seems intuitively correct. For example, it finds that "A glass of cider" and "A full cup of apple juice" are quite similar, even though they share no words (other than stopwords). Similarly, "I have a hammer" is more similar to "Take some nails" than to "Take some apples". The only intuitively incorrect result in this set is that "Red alcoholic drink" comes out more similar to "Fresh orange juice" than to "A bottle of wine". However, "A bottle of wine" does rank as more similar to "Red alcoholic drink" than either "Fresh apple juice" or "An English dictionary". So overall, the approach seems to work on my limited dataset.

In my case, I already have the two sentences and just have to find the distance between them. In cases where you have to find the closest sentence among many candidates, the complexity of the exact WMD computation is O(p³ log p), where p is the number of unique words. One suggestion is to prune the set of candidate RHS sentences by thresholding on the Word Centroid Distance (WCD) or the relaxed WMD (see the paper for details) between the two sentences, and to run the full WMD only on the pruned set of sentence pairs.
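
To illustrate the pruning idea, here is a rough sketch of the Word Centroid Distance, under the assumption that each sentence has already been resolved into a sequence of Breeze vectors (the function and variable names are mine, and breeze.linalg is assumed to be imported as before). Candidate sentences whose centroids are far from the query sentence's centroid can be dropped before running the more expensive all-pairs computation.

def wcd(lvecs: Seq[DenseVector[Double]], rvecs: Seq[DenseVector[Double]]): Double = {
    // centroid (mean vector) of each sentence's word vectors
    val lcentroid = lvecs.reduce(_ + _) / lvecs.size.toDouble
    val rcentroid = rvecs.reduce(_ + _) / rvecs.size.toDouble
    norm(lcentroid - rcentroid)   // Euclidean distance between centroids
}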

11 comments (moderated to prevent spam):

Anonymous said...

how did you convert the GoogleNews-vectors-negative300.bin to GoogleNews-vectors-negative300.tsv

Unknown said...

sounds great, it is very in now to add more info to word2vec http://ai.stanford.edu/~amaas/papers/wvSent_acl2011.pdf Learning Word Vectors for Sentiment Analysis or http://anthology.aclweb.org/P/P14/P14-1146.pdf Learning Sentiment-Specific Word Embedding or Coooolll A Deep Learning System for Twitter Sentiment Classification http://www.aclweb.org/anthology/S14-2033, but may you share some code to study this Word Movers Distance, sander.stepanov@gmail.com

Sujit Pal said...

Hi Sander, there is a link to the original WMD paper at the top of the post. Also I calculate WMD in my post using the code snippet that starts with "val bestWMDs".

Unknown said...

I see, I am working with Python, R, and Matlab, it is pity I can not understand this code, do you know Python, R, and Matlab examples or some simple description . By the reference in
http://sujitpal.blogspot.ca/2014/12/semantic-similarity-for-short-sentences.html
to A reader recently recommended a paper for me to read - Sentence Similarity Based on Semantic Nets and Corpus Statistics. not working do you know fresh link?

Sujit Pal said...

@Anonymous: apologies for the delay in replying, just saw your comment sitting in my queue, must have missed it earlier. I used Gensim's Word2Vec module to do the conversion from BIN to TSV.

@Sander: I don't have code for WMD in either of these languages, but here is the definition: The WMD is a distance function that measures the distance between two texts as the cumulative sum of minimum distance each word in one text must move in vector space to the closest word in the other text. So basically do an all-pairs between words in the two sentences to find the closest word pairs in word2vec space, then sum these distances together.

For the "Sentence Similarity Based on Semantic Nets and Corpus Statistics" paper, I found a couple of references behind the IEEE and ACM paywalls, but the original link to sinica.edu.tw seems to be dead. I guess you are out of luck unless you or your organization have access (I don't unfortunately).

Anonymous said...

This is so great publishing open source for learning, any way I'm also very new on the scala. Could you issue the full files of the code on github or (touy_say@hotmail.com), that will make deep understanding what you did in your implementations.

Thanks for your will issue.

Sujit Pal said...

Thank you. I did this using a Scala Databricks Notebook, and I have already provided snippets in the post. I have also downloaded the notebook as HTML and put it into a Github Gist here. Unfortunately gist does not render the page, so you will have to download it locally and view it through the browser using file:///path/to/downloaded.html. If you have access to a Databricks Notebook setup (they have a community edition also), you can import it into that also.

Anonymous said...

Hi, this is great post and would like to implement WMD using R. Do you know any existing R packages or references to implement in R.
Great to have reply and thanks for your time.

Sujit Pal said...

Thank you, and sorry, but I don't know of R packages that implement WMD. But it should be possible to implement it yourself, this blog post has a nice explanation and a Python implementation.

Mr Elusive said...

Nice work! How does your implementation account for the "flow" of one LHS word to multiple RHS words?

Sujit Pal said...

Thank you. To answer your question, it doesn't, it just uses shortest distance between single words on the LHS to RHS. I guess we could do a limited form of that by adding 2-grams and 3-grams to our list of tokens on either end, and removing the subsumed tokens after an n-gram has matched.