The BERT (Bidirectional Encoder Representations from Transformers) model was proposed in the paper BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding (Devlin et al., 2019). BERT is the encoder part of the encoder-decoder Transformer architecture proposed in Attention Is All You Need (Vaswani et al., 2017). The BERT model is pre-trained on two tasks against a large corpus of text in a self-supervised manner -- first, to predict masked words in a sentence, and second, to predict whether one sentence follows another; these are called the Masked Language Modeling and Next Sentence Prediction tasks respectively. These pre-trained models can be further fine-tuned for tasks as diverse as classification, sequence prediction, and question answering.
Since the release of BERT, the research community has done a lot of work around Transformers and BERT-like architectures, so much so that HuggingFace has dedicated its enormously popular transformers library to helping people work efficiently and easily with popular Transformer architectures. Among other things, the HuggingFace transformers library provides a unified interface for working with many different kinds of Transformer architectures (each with slightly different details), as well as providing weights for many pre-trained Transformer models.
Most of my work with transformers so far has been around fine-tuning them for Question Answering and Sequence Prediction. I recently came across a blog post Examining BERT's raw embeddings by Ajit Rajasekharan, where he describes how one can use a plain BERT model (pre-trained only, no fine-tuning required) and later a BERT Masked Language Model (MLM), as a Language Model, to help with Word Sense Disambiguation (WSD).
The idea is rooted in the model's ability to produce contextual embeddings for the words in a sentence. A pre-trained model has learned enough about the language it was trained on to produce different embeddings for a homonym depending on the sentence context it appears in. For example, a pre-trained model would produce a different vector representation for the word "bank" if it is used in the context of a bank robbery versus a river bank. This is different from how older word embeddings such as word2vec work, where a word has a single embedding regardless of the sentence context in which it appears.
An important point here is that there is no fine-tuning involved; we leverage the knowledge inherent in the pre-trained models for our WSD experiments, and use these models in inference mode.
In this post, I will summarize these ideas from Ajit Rajasekharan's blog post, and provide Jupyter notebooks with implementations of these ideas using the HuggingFace transformers library.
WSD using raw BERT embeddings
Our first experiment uses a pre-trained BERT model initialized with the weights of a bert-base-cased model. We extract a matrix of "base" embeddings for each word in the model's vocabulary. We then pass in sentences containing our ambiguous word into the pre-trained BERT model, and capture the input embedding and output embedding for our ambiguous word. Our first sentence uses the word "bank" in the context of banking, and our second sentence uses it in the context of a river bank.
We then compute the cosine similarity between the embedding (input and output) for our ambiguous word against all the words in the vocabulary, and plot the histogram of cosine similarities. We notice that in both cases, the histogram shows a long tail, but the histogram for the output embedding seems to have a shorter tail, perhaps because there is less uncertainty once the context is known.
We then identify the words in the vocabulary whose embeddings have the highest cosine similarity to the embedding for our ambiguous word. As expected, the words most similar to both input embeddings relate to banking (presumably because banking is the dominant usage of the word in the language). For the output embeddings, also as expected, the words most similar to our ambiguous word relate to banking in the first sentence and to rivers in the second.
The notebook WSD Using BERT Raw Embeddings contains the implementation described above.
WSD using BERT MLM
In our second experiment, we mask out the word "bank" in our two sentences and replace it with the [MASK] token. We then pass these sentences through a BERT Masked Language Model (MLM) initialized with weights from a bert-base-cased model. The output of the MLM is a 3-dimensional tensor of logits, where the first dimension is the number of sentences in the batch (1 in our case), the second dimension is the number of tokens in the input sentence, and the third dimension is the size of the vocabulary. Effectively, the output provides an unnormalized score for every word in the vocabulary at each token position in the input.
As before, we identify the logits corresponding to the masked position in the input sequence, then compute the softmax over those logits to convert them to probabilities. We then extract the top k (k=20) terms with the highest probabilities.
Again, as expected, predictions for the masked word are predominantly around banking for the first sentence, and predominantly around rivers for the second sentence.
The notebook WSD Using BERT Masked Language Model contains the implementation described above.
So that's all I had for today. Even though I understood the idea in Ajit Rajasekharan's blog post at a high level, and had even attempted something similar for WSD using non-contextual word embeddings (using the average of word embeddings across a span of text around the ambiguous word), it was interesting to actually go into the transformer model and figure out how to make things work. I hope you found it interesting as well.