Saturday, March 23, 2024

Book Report: Machine Learning for Drug Discovery

Drug Discovery is a field where biochemists (and more recently computer scientists) turn ideas into potential medications. I first came across a few applications in this area when checking out how to build Graph Neural Networks (GNN) as part of auditing the CS224W: Machine Learning with Graphs course from Stanford, some learnings of which I recycled into my Deep Learning with Graphs tutorial at ODSC 2021. Of course, drug discovery is much more than just GNNs, I mention this only because this happened to be my entry point into this fascinating world. However, I will hasten to add that despite having made an entrance, I am still parked pretty solidly close to the entrance (or exit, depending on your point of view).

But I am always looking to learn more about stuff I find interesting, so when I was offered a chance to review Dr Noah Flynn's Machine Learning for Drug Discovery published by Manning, I jumped on it. The book is currently in MEAP (Manning Early Access Program), so only 5 chapters are available at the moment, but once the book is completed there will be 15 chapters in all. The intended audience of the book, as the title suggests, is computational biochemists, i.e. the ones who attempt to solve Drug Discovery problems using Machine Learning. There are two main routes to becoming a computational biochemist -- either you are a biochemist and you learn the ML, or you are an ML person and you learn the biochemistry. The book is aimed at both categories of readers.

As someone in the latter category, I had to spend much more time on the biochemistry aspects. I suspect that most readers of this review would also fall into this category. For them, I would say that while the ML part is sophisticated enough to solve the problem at hand, the methods and practices should already be familiar to most ML people. The most useful things that I think you would get out of this book are as follows:

  • Framing the Drug Discovery problem as a ML problem
  • Preprocessing and Encoding inputs
  • Getting data to train your ML model

For the first one, you either need to have a biochemistry background yourself, or you need to pair with someone who does. I suppose you could get by with a life sciences or chemistry background as well, or acquire enough biochemistry knowledge over time in this field, and this book may even take you part of the way there, but be aware that the learning curve is steep.

For the second and the third items, I thought the book was super useful. Most chapters are built as case studies around a Drug Discovery problem, so as you go through the chapters, you will learn about the sites to acquire your datasets from, and the techniques to preprocess the data from these sites into a form suitable for consumption by your ML model. At least the first 5 chapters deal with fairly simple ML models, which may or may not be familiar to you depending on your industry, so you might also learn a few things about evaluating or tuning these models that you didn't know before (I did).

The first chapter introduces the reader to the domain and talks about the need for computational approaches to Drug Discovery. It introduces the terminology and the RDKit software library, an open-source cheminformatics toolkit that provides implementations of many common operations needed for computational Drug Discovery (sort of like a specialized supplement to Scikit-Learn for general ML). It also covers high level rules of thumb for identifying drug-like compounds, such as Lipinski's rule of 5. It then covers some use cases common in Drug Discovery, ranging from Virtual Screening to Generative and Synthetic Chemistry. It also covers some popular (and public) repositories for Chemistry data, such as ChEMBL, PubChem, Protein Data Bank (PDB), etc.
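To make the rule of 5 concrete, here is a minimal RDKit sketch of the check; the aspirin SMILES string is just an illustrative example, not something from the book.

# Lipinski's rule of 5 with RDKit; the molecule is aspirin, used only as an example
from rdkit import Chem
from rdkit.Chem import Descriptors, Lipinski

mol = Chem.MolFromSmiles("CC(=O)Oc1ccccc1C(=O)O")

passes_ro5 = all([
    Descriptors.MolWt(mol) <= 500,        # molecular weight <= 500 Da
    Descriptors.MolLogP(mol) <= 5,        # octanol-water partition coefficient (logP) <= 5
    Lipinski.NumHDonors(mol) <= 5,        # at most 5 hydrogen bond donors
    Lipinski.NumHAcceptors(mol) <= 10,    # at most 10 hydrogen bond acceptors
])
print(passes_ro5)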

The second chapter demonstrates Ligand based Screening, where you already have a reference molecule with some of the desired properties, and you want to search the chemical space for molecules similar to that one, with the objective of finding more drugs like the one you started with. The case study here is to identify potential anti-malarial compounds. The dataset for this comes packaged with RDKit itself as Structure Data Files (SDF), which describe each molecule using a SMILES (Simplified Molecular Input Line Entry System) string. The chapter walks us through converting the SMILES to MOL format, then using RDKit to extract specialized chemical features from the MOL and SMILES, preprocessing to filter out uninteresting molecules based on rule based thresholds such as bio-availability and molecular weight, structure based thresholds such as toxicity, and specific substructural patterns (similar to subgraph motifs). It then uses RDKit to generate Morgan fingerprints from the remaining molecules (MOL). Morgan (and other) fingerprints are similar to embeddings in NLP, except that they encode structural information through a more deterministic process, and are hence more explainable than embeddings. Finally, these fingerprints are compared with the reference molecule using Tanimoto similarity and the nearest neighbors found.
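For readers who have not used RDKit before, here is a minimal sketch of the fingerprint-and-similarity step described above; the SMILES strings are placeholders rather than the book's anti-malarial data.

# Morgan fingerprints and Tanimoto similarity with RDKit (placeholder molecules)
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem

ref = Chem.MolFromSmiles("CC(=O)Oc1ccccc1C(=O)O")    # reference molecule (placeholder)
cand = Chem.MolFromSmiles("CC(=O)Nc1ccc(O)cc1")      # candidate molecule (placeholder)

# radius-2 Morgan fingerprints folded into 2048-bit vectors
fp_ref = AllChem.GetMorganFingerprintAsBitVect(ref, radius=2, nBits=2048)
fp_cand = AllChem.GetMorganFingerprintAsBitVect(cand, radius=2, nBits=2048)

sim = DataStructs.TanimotoSimilarity(fp_ref, fp_cand)
print(f"Tanimoto similarity: {sim:.3f}")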

Chapter 3 continues with the problem of Ligand based screening, but tries to predict cardiotoxicity of the anti-malarial compounds found in the previous chapter using a linear model. This is done indirectly: if a compound blocks the hERG potassium channel (hERG is the human Ether-a-go-go-Related Gene), it is considered cardiotoxic, and vice versa. A linear model (Scikit-Learn SGD Classifier) is trained using the hERG dataset from the Therapeutic Data Commons (TDC). The chapter shows some Exploratory Data Analysis (EDA) on the data, using standard preprocessing as described in the previous chapter. An additional step here is to standardize the data for classification. The author provides the biochemistry reasoning behind this step, but uses the implementation already provided by RDKit. Finally Morgan fingerprints are used to train the SGD Classifier. Because the elements of Morgan fingerprints have meaning, the weights of the resulting SGD model can be used to determine feature importances. There is also some discussion here of cross validation, L1/L2 regularization, removing collinearity, adding interaction terms and hyperparameter sweeps.
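As a rough illustration of what that training loop looks like (not the book's code), here is a sketch with random placeholder data standing in for the hERG fingerprints and labels; the loss="log_loss" setting assumes a recent scikit-learn.

# Sketch: a linear classifier on Morgan fingerprint bits, with feature importances
# read off the weights; random placeholder data, variable names are mine
import numpy as np
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import cross_val_score

rng = np.random.default_rng(42)
X = rng.integers(0, 2, size=(200, 2048)).astype(float)   # fingerprint bit vectors
y = rng.integers(0, 2, size=200)                         # 1 = blocks hERG (cardiotoxic)

clf = SGDClassifier(loss="log_loss", penalty="l2", alpha=1e-4, random_state=42)
print(cross_val_score(clf, X, y, cv=5, scoring="roc_auc"))

clf.fit(X, y)
# each fingerprint bit maps back to a substructure, so the largest |weights|
# point at the substructures driving the prediction
top_bits = np.argsort(np.abs(clf.coef_[0]))[::-1][:10]
print(top_bits)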

Chapter 4 explores building a linear regression model to predict solubility, i.e. how much of the drug would be absorbed by the system. The dataset used to train the regressor is AqSolDB, also from TDC. This chapter introduces the idea of scaffold splitting, a technique common with biochemical datasets that preserves the structural / chemical similarity within each split. It also briefly describes outlier removal at the extremes, which requires chemistry knowledge. The RDKit library is used to extract features from the dataset, and the model is trained to minimize the Mean Squared Error loss. The RANSAC (RANdom SAmple Consensus) technique is introduced as a way to make models more robust to outliers. On the ML side, there is some discussion of the bias-variance tradeoff and Learning / Validation curves.
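Scaffold splitting was new to me, so here is a hedged sketch of a simple Murcko-scaffold split using RDKit; the SMILES list and the 80/20 split logic are just illustrative.

# Sketch: group molecules by Bemis-Murcko scaffold, then split by whole groups
from collections import defaultdict
from rdkit.Chem.Scaffolds import MurckoScaffold

smiles_list = ["CC(=O)Oc1ccccc1C(=O)O", "CC(=O)Nc1ccc(O)cc1", "CCO"]   # placeholders

scaffold_to_idx = defaultdict(list)
for i, smi in enumerate(smiles_list):
    scaffold_to_idx[MurckoScaffold.MurckoScaffoldSmiles(smiles=smi)].append(i)

# assign whole scaffold groups to train until ~80% is reached, the rest to test,
# so structurally similar molecules never straddle the split
train_idx, test_idx = [], []
for group in sorted(scaffold_to_idx.values(), key=len, reverse=True):
    (train_idx if len(train_idx) < 0.8 * len(smiles_list) else test_idx).extend(group)
print(train_idx, test_idx)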

The fifth and last chapter of the MEAP (at the time of writing this review) deals with predicting how well the body will metabolize the drug. Typically, drugs are broken down by enzymes in the liver, a large proportion of which are collectively known as the Cytochrome P450 superfamily. As before, metabolism is predicted indirectly by whether the drug inhibits Cytochrome P450 -- if it does, then it will not get metabolized easily, and vice versa. The dataset used to train the model is the CYP3A4 dataset, also from TDC. Data is prepared using the (by now) standard pipeline, and the classifier is trained to make binary predictions of whether the input inhibits Cytochrome P450 or not. The chapter discusses the utility of Reliability Plots in Performance Evaluation and Platt scaling for calibrating probabilities. It also talks about Data Augmentation, Class Weights and other approaches to deal with class imbalance. Various models are trained and evaluated, and their important features identified and visualized with RDKit Similarity Maps. The chapter ends with a short discussion on Multi-label classification.
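To make the calibration discussion concrete, here is a sketch of Platt scaling plus class weighting with scikit-learn; the random data is a stand-in for the CYP3A4 fingerprints and labels, and all the settings are illustrative.

# Sketch: Platt scaling (sigmoid calibration) and class weighting for an
# imbalanced inhibition dataset, with the data needed for a reliability plot
import numpy as np
from sklearn.calibration import CalibratedClassifierCV, calibration_curve
from sklearn.linear_model import SGDClassifier

rng = np.random.default_rng(42)
X = rng.integers(0, 2, size=(500, 2048)).astype(float)   # placeholder fingerprints
y = rng.integers(0, 2, size=500)                         # placeholder inhibition labels

base = SGDClassifier(loss="hinge", class_weight="balanced", random_state=42)
# method="sigmoid" is Platt scaling: fit a logistic curve to the decision scores
calibrated = CalibratedClassifierCV(base, method="sigmoid", cv=5)
calibrated.fit(X, y)

probs = calibrated.predict_proba(X)[:, 1]
frac_pos, mean_pred = calibration_curve(y, probs, n_bins=10)   # reliability plot data
print(list(zip(mean_pred, frac_pos)))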

The pandemic and the rapid discovery of the COVID vaccine gave a lot of us (at least those of us who were watching) a ringside view into the fascinating world of drug discovery. This book provides yet another peek into this world, with its carefully crafted case studies and examples. Overall, I think you will learn a lot about drug discovery if you go through this book, both on the biochemistry side and the ML side. There are exercises at the end of each chapter; doing them will help you get more familiar with RDKit and hopefully more effective at computational drug discovery.

Sunday, March 17, 2024

Hierarchical (and other) Indexes using LlamaIndex for RAG Content Enrichment

At our weekly This Week in Machine Learning (TWIML) meetings, (our leader and facilitator) Darin Plutchok pointed out a LinkedIn blog post on Semantic Chunking that has recently been implemented in the LangChain framework. Unlike more traditional chunking approaches that use the number of tokens or separator tokens as a guide, this one chunks groups of sentences into semantic units by breaking them when the (semantic) similarity between consecutive sentences (or sentence-grams) falls below some predefined threshold. I had tried it earlier (pre-LangChain) and while results were reasonable, it needed a lot of processing, so I went back to what I was using before.

I was also recently exploring LlamaIndex as part of the effort to familiarize myself with the GenAI ecosystem. LlamaIndex supports hierarchical indexes natively, meaning it provides the data structures that make building them easier and more natural. Unlike the typical RAG index, which is just a sequence of chunks (and their vectors), hierarchical indexes would cluster chunks into parent chunks, and parent chunks into grandparent chunks, and so on. A parent chunk would generally inherit or merge most of the metadata from its children, and its text would be a summary of its children's text contents. To illustrate my point about LlamaIndex data structures having natural support for this kind of setup, here are the definitions of the LlamaIndex TextNode (the LlamaIndex Document object is just a child of TextNode with an additional doc_id: str field) and the LangChain Document. Of particular interest is the relationships field, which allows pointers to other chunks using named relationships PARENT, CHILD, NEXT, PREVIOUS, SOURCE, etc. Arguably, the LlamaIndex TextNode can be represented more generally and succinctly by the LangChain Document, but the hooks do help to support hierarchical indexing more naturally.

# this is a LlamaIndex TextNode
class TextNode:
  id_: str = None
  embedding: Optional[List[float]] = None
  extra_info: Dict[str, Any]
  excluded_embed_metadata_keys: List[str] = None
  excluded_llm_metadata_keys: List[str] = None
  relationships: Dict[NodeRelationship, Union[RelatedNodeInfo, List[RelatedNodeInfo]]] = None
  text: str
  start_char_idx: Optional[int] = None
  end_char_idx: Optional[int] = None
  text_template: str = "{metadata_str}\n\n{content}"
  metadata_template: str = "{key}: {value}"
  metadata_separator: str = "\n"

# and this is a LangChain Document
class Document:
  page_content: str
  metadata: Dict[str, Any]

In any case, having discovered the hammer that is LlamaIndex, I began to see a lot of potential hierarchical-index nails. One such nail that occurred to me was to use Semantic Chunking to cluster consecutive chunks rather than sentences (or sentence-grams), and then create parent nodes from these chunk clusters. Instead of computing cosine similarity between consecutive sentence vectors to build up chunks, we compute cosine similarity across consecutive chunk vectors and split them up into clusters based on some similarity threshold, i.e. if the similarity drops below the threshold, we terminate the cluster and start a new one.

Both LangChain and LlamaIndex have implementations of Semantic Chunking (for sentence clustering into chunks, not chunk clustering into parent chunks). LangChain's Semantic Chunking allows you to set the threshold using percentiles, standard deviation and inter-quartile range, while the LlamaIndex implementation supports only the percentile threshold. But intuitively, here's how you could get an idea of the percentile threshold to use -- thresholds for the other methods can be computed similarly. Assume your content has N chunks and you expect K clusters (based on your understanding of the data or from other estimates); then, assuming a uniform distribution, there would be N/K chunks in each cluster, which means roughly K of the N consecutive-pair similarities mark cluster breakpoints. If K/N is approximately 20%, then your percentile threshold would be approximately 80.
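Put as a (simplistic) back-of-the-envelope calculation:

# roughly K of the consecutive-pair similarities should fall below the breakpoint
# threshold, so the percentile works out to about 100 * (1 - K / N)
def percentile_threshold(n_chunks: int, n_clusters: int) -> float:
    return 100.0 * (1.0 - n_clusters / n_chunks)

print(percentile_threshold(500, 100))   # -> 80.0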

LlamaIndex provides an IngestionPipeline which takes a list of TransformComponent objects. My pipeline looks something like below. The last component is a custom subclass of TransformComponent; all you need to do is to override its __call__ method, which takes a List[TextNode] and returns a List[TextNode].

from llama_index.core import SimpleDirectoryReader
from llama_index.core.ingestion import IngestionPipeline
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

transformations = [
    SentenceSplitter(),                       # text splitter
    HuggingFaceEmbedding(),                   # embedding generator
    SemanticChunkingSummaryNodeBuilder()      # custom summary node builder
]
ingestion_pipeline = IngestionPipeline(transformations=transformations)
docs = SimpleDirectoryReader("/path/to/input/docs").load_data()
nodes = ingestion_pipeline.run(documents=docs)

My custom component takes the desired cluster size K during construction. It uses the vectors computed by the (LlamaIndex provided) HuggingFaceEmbedding component to compute similarities between consecutive vectors, and uses K to compute a threshold. It then uses the threshold to cluster the chunks, resulting in a list of lists of chunks (List[List[TextNode]]). For each cluster, we create a summary TextNode and set its CHILD relationships to the cluster nodes, and the PARENT relationship of each child in the cluster to this new summary node. The text of each child node is first condensed using extractive summarization, then these condensed summaries are further summarized into one final summary using abstractive summarization. I used bert-extractive-summarizer with bert-base-uncased for the first and a HuggingFace summarization pipeline with facebook/bart-large-cnn for the second. I suppose I could have used an LLM for the second step, but it would have taken more time to build the index, and I have been experimenting with ideas described in the DeepLearning.AI course Open Source Models with HuggingFace.
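Here is a minimal sketch of what the clustering part of such a component might look like. The class and field names are mine, the summarization step is stubbed out, and it assumes the upstream HuggingFaceEmbedding component has already populated node.embedding; treat the relationship-setting details as an approximation rather than a definitive implementation.

from typing import Any, List

import numpy as np
from llama_index.core.schema import (NodeRelationship, RelatedNodeInfo,
                                      TextNode, TransformComponent)

def summarize(texts: List[str]) -> str:
    # placeholder for the extractive (bert-extractive-summarizer) followed by
    # abstractive (HuggingFace summarization pipeline) steps described above
    return " ".join(texts)[:512]

class SemanticChunkingSummaryNodeBuilder(TransformComponent):
    num_clusters: int = 100   # desired K, passed in at construction time

    def __call__(self, nodes: List[TextNode], **kwargs: Any) -> List[TextNode]:
        vecs = np.array([node.embedding for node in nodes])
        vecs = vecs / np.linalg.norm(vecs, axis=1, keepdims=True)
        sims = np.sum(vecs[:-1] * vecs[1:], axis=1)    # consecutive cosine similarities
        threshold = np.percentile(sims, 100.0 * self.num_clusters / len(nodes))

        # start a new cluster wherever similarity drops below the threshold
        clusters, current = [], [nodes[0]]
        for node, sim in zip(nodes[1:], sims):
            if sim < threshold:
                clusters.append(current)
                current = []
            current.append(node)
        clusters.append(current)

        # one summary (parent) node per cluster, wired up via relationships
        summary_nodes = []
        for cluster in clusters:
            parent = TextNode(text=summarize([n.text for n in cluster]))
            parent.relationships[NodeRelationship.CHILD] = [
                RelatedNodeInfo(node_id=n.node_id) for n in cluster]
            for n in cluster:
                n.relationships[NodeRelationship.PARENT] = RelatedNodeInfo(
                    node_id=parent.node_id)
            summary_nodes.append(parent)
        return nodes + summary_nodes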

Finally, I recalculate the embeddings for the summary nodes -- I ran the summary node texts through the HuggingFaceEmbedding, but I guess I could have done some aggregation (mean-pool / max-pool) on the child vectors as well.

Darin also pointed out another instance of a Hierarchical Index, proposed in RAPTOR: Recursive Abstractive Processing for Tree-Organized Retrieval and described in detail by the authors in this LlamaIndex webinar. This is a bit more radical than my idea of using semantic chunking to cluster consecutive chunks, in that it allows clustering of chunks across the entire corpus. One other significant difference is that it allows for soft-clustering, meaning a chunk can be a member of more than one cluster. They first reduce the dimensionality of the vector space using UMAP (Uniform Manifold Approximation and Projection) and then apply a Gaussian Mixture Model (GMM) to do the soft clustering. To find the optimum number of clusters K for the GMM, one can use a combination of AIC (Akaike Information Criterion) and BIC (Bayesian Information Criterion).
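For intuition, here is a hedged sketch of that recipe using umap-learn and scikit-learn; the array of chunk embeddings and the membership cutoff are placeholders.

# Sketch of RAPTOR-style soft clustering: UMAP to reduce dimensionality, then a
# Gaussian Mixture Model, scanning K with AIC / BIC
import numpy as np
import umap                                    # pip install umap-learn
from sklearn.mixture import GaussianMixture

embeddings = np.random.rand(500, 384)          # placeholder chunk embeddings

reduced = umap.UMAP(n_components=10, metric="cosine").fit_transform(embeddings)

scores = []
for k in range(2, 30):
    gmm = GaussianMixture(n_components=k, random_state=42).fit(reduced)
    scores.append((k, gmm.aic(reduced), gmm.bic(reduced)))
best_k = min(scores, key=lambda s: s[2])[0]    # e.g. pick the K with the lowest BIC

gmm = GaussianMixture(n_components=best_k, random_state=42).fit(reduced)
probs = gmm.predict_proba(reduced)             # soft assignments
memberships = probs > 0.1                      # chunk i belongs to every cluster j above the cutoff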

In my case, when training the GMM, the AIC kept decreasing as the number of clusters increased, and the BIC had its minimum value for K=10, which corresponds roughly to the 12 chapters in my Snowflake book (my test corpus). But there was a lot of overlap, which would force me to implement some sort of logic to take advantage of the soft clustering, which I didn't want to do, since I wanted to reuse code from my earlier Semantic Chunking node builder component. Ultimately, I settled on 90 clusters by using my original intuition to compute K, and the resulting clusters seem reasonably well separated as seen below.

Using the results of the clustering, I built this also as another custom LlamaIndex TransformComponent for hierarchical indexing. This implementation differs from the previous one only in the way it assigns nodes to clusters, all other details with respect to text summarization and metadata merging are identical.

For both these indexes, we have a choice to maintain the index as hierarchical, and decide which layer(s) to query based on the question, or add the summary nodes into the same level as the other chunks, and let vector similarity surface them when queries deal with cross-cutting issues that may be found together in these nodes. The RAPTOR paper reports that they don't see a significant gain using the first approach over the second. Because my query functionality is LangChain based, my approach has been to generate the nodes and then reformat them into LangChain Document objects and use LCEL to query the index and generate answers, so I haven't looked into querying from a hierarchical index at all.

Looking back on this work, I am reminded of similar choices when designing traditional search pipelines. Often there is a choice between building functionality into the index to support a cheaper query implementation, or building the logic into the query pipeline that may be more expensive but also more flexible. I think LlamaIndex started with the first approach (as evidenced by their blog posts Chunking Strategies for Large Language Models Part I and Evaluating Ideal Chunk Sizes for RAG Systems using LlamaIndex) while LangChain started with the second, even though nowadays there is a lot of convergence between the two frameworks.

Saturday, February 24, 2024

Thoughts on using LangChain LCEL with Claude

I got into Natural Language Processing (NLP) and Machine Learning (ML) through Search. And this led me into Generative AI (GenAI), which led me back to Search via Retrieval Augmented Generation (RAG). RAG started out relatively simple -- take a query, generate search results, use search results as context for a Large Language Model (LLM) to generate an abstractive summary of the results. Back when I started on my first "official" GenAI project in the middle of last year, there were not too many frameworks to support building GenAI components (at least not the prompt based ones), except maybe LangChain, which was just starting out. But prompting as a concept is not too difficult to understand and implement, so that's what we did at the time.

I did have plans to use LangChain in my project once it became more stable, so I started out building my components to be "langchain compliant". But that turned out to be a bad idea as LangChain continued its exponential (and from the outside at least, somewhat haphazard) growth and showed no signs of stabilizing. At one point, LangChain users were advised to make pip install -U langchain part of their daily morning routine! So anyway, we ended up building up our GenAI application by hooking up third party components with our own (non-framework) code, using Anthropic's Claude-v2 as our LLM, ElasticSearch as our lexical / vector document store and PostgreSQL as our conversational buffer.

While I continue to believe that the decision to go with our own code made more sense than trying to jump on the LangChain (or Semantic Kernel, or Haystack, or some other) train, I do regret it in some ways. A collateral benefit for people who adopted and stuck with LangChain were the ready-to-use implementations of cutting-edge RAG and GenAI techniques that the community implemented at almost the same pace as they were being proposed in academic papers. For the subset of these people that were even slightly curious about how these implementations worked, this offered a ringside view into the latest advances in the field and a chance to stay current with it, with minimal effort.

So anyway, in an attempt to replicate this benefit for myself (going forward at least), I decided to learn LangChain by doing a small side project. Earlier I needed to learn to use Snowflake for something else and had their free O'Reilly book on disk, so I converted it to text, chunked it, and put it into a Chroma vector store. I then tried to implement examples from the DeepLearning.AI courses LangChain: Chat with your Data and LangChain for LLM Application Development. The big difference is that the course examples use OpenAI's GPT-3 as their LLM whereas I use Claude-2 on AWS Bedrock in mine. In this post, I share the issues I faced and my solutions, hopefully this can help guide others in similar situations.

Couple of observations here. First, the granularity of GenAI components is necessarily larger than traditional software components, and this means application details that the developer of the component was working on can leak into the component itself (mostly through the prompt). To a user of the component, this can manifest as subtle bugs. Fortunately, LangChain developers seem to have also noticed this and have come up with the LangChain Expression Language (LCEL), a small set of reusable components that can be composed to create chains from the ground up. They have also marked a large number of Chains as Legacy Chains (to be converted to LCEL chains in the future).

Second, most of the components (or chains, since that is LangChain's central abstraction) are developed against OpenAI GPT-3 (or its chat version GPT-3.5 Turbo) whose strengths and weaknesses may be different from those of your LLM. For example, OpenAI is very good at generating JSON output, whereas Claude is better at generating XML. I have also seen that Claude can terminate XML / JSON output mid-output unless forced to complete using stop_sequences. This doesn't seem to be a problem GPT-3 users have observed -- when I mentioned this problem and the fix, I drew a blank on both counts.

To address the first issue, my general approach in trying to re-implement these examples has been to use LCEL to build my chains from scratch. I attempt to leverage the expertise available in LangChain by looking in the code or running the existing LangChain chain with langchain.debug set to True. Doing this helps me see the prompt being used and the flow, which I can use to adapt the prompt and flow for my LCEL chain. To address the second issue, I play to Claude's strengths by specifying XML output format in my prompts and parsing them as Pydantic objects for data transfer across chains.
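To make that concrete, here is a hedged sketch of how such a chain gets wired up against Claude on Bedrock. The model id, model_kwargs, placeholder prompt and stop sequence are illustrative, and I am assuming the langchain_community BedrockChat wrapper passes stop_sequences through to the Anthropic request, so treat it as a sketch rather than the exact code.

from langchain.globals import set_debug
from langchain_community.chat_models import BedrockChat
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

set_debug(True)   # print the rendered prompt and each step as the chain runs

model = BedrockChat(
    model_id="anthropic.claude-v2",
    model_kwargs={
        "temperature": 0.0,
        # keep Claude from stopping mid-XML; append the tag back before parsing
        "stop_sequences": ["</result>"],
    },
)
prompt_template = "Human: Answer briefly: {question}\n\nAssistant:"   # placeholder; the real grading prompt is shown below
prompt = ChatPromptTemplate.from_template(prompt_template)
chain = prompt | model | StrOutputParser()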

The example application I will use to illustrate these techniques here is derived from the Evaluation lesson from the LangChain for LLM Application Development course, and is illustrated in the diagram below. The application takes a chunk of text as input, and uses the Question Generation chain to generate multiple question-answer pairs from it. The questions and the original content are fed into the Question Answering chain, which uses the question to generate additional context from a vector retriever, and uses all three to generate an answer. The answer generated from the Question Generation chain and the answer generated from the Question Answering chain are fed into a Question Generation Evaluation chain, where the LLM grades one against the other, and generates an aggregate score for the questions generated from the chunk.

Each chain in this pipeline is actually quite simple: they take one or more inputs and generate a block of XML. All the chains are structured as follows:

from langchain_core.output_parsers import StrOutputParser

chain = prompt | model | StrOutputParser()

And all our prompts follow the same general format. Here is the prompt for the Evaluation chain (the third one) which I adapted from the QAEvalChain used in the lesson notebook. Developing from scratch using LCEL gives me the chance to use Claude's Human / Assistant format (see LangChain Guidelines for Anthropic) rather than depend on the generic prompt that happens to work well for GPT-3.

Human: You are a teacher grading a quiz.

You are given a question, the context the question is about, and the student's 
answer.

QUESTION: {question}
CONTEXT: {context}
STUDENT ANSWER: {predicted_answer}
TRUE ANSWER: {generated_answer}

You are to score the student's answer as either CORRECT or INCORRECT, based on the 
context.

Write out in a step by step manner your reasoning to be sure that your conclusion 
is correct. Avoid simply stating the correct answer at the outset.

Please provide your response in the following format:

<result>
    <qa_eval>
        <question>the question here</question>
        <student_answer>the student's answer here</student_answer>
        <true_answer>the true answer here</true_answer>
        <explanation>step by step reasoning here</explanation>
        <grade>CORRECT or INCORRECT here</grade>
    </qa_eval>
</result>

Grade the student answers based ONLY on their factual accuracy. Ignore differences in 
punctuation and phrasing between the student answer and true answer. It is OK if the 
student answer contains more information than the true answer, as long as it does not 
contain any conflicting statements.

Assistant:

In addition, I specify the formatting instructions explicitly in the prompt instead of using the canned ones from XMLOutputParser or PydanticOutputParser via get_format_instructions(), which are comparatively quite generic and sub-optimal. By convention, the outermost tag in my format is always <result>...</result>. The qa_eval tag inside result has a corresponding Pydantic class analog declared in the code as follows:

from pydantic import BaseModel, Field

class QAEval(BaseModel):
    question: str = Field(alias="question", description="question text")
    student_answer: str = Field(alias="student_answer",
                                description="answer predicted by QA chain")
    true_answer: str = Field(alias="true_answer",
                             description="answer generated by QG chain")
    explanation: str = Field(alias="explanation",
                             description="chain of thought for grading")
    grade: str = Field(alias="grade",
                       description="LLM grade CORRECT or INCORRECT")

After the StrOutputParser extracts the LLM output into a string, it is first passed through a regular expression to remove any content outside the <result>...</result> tags, then converted into the QAEval Pydantic object using the following code. This allows us to keep object manipulation between chains independent of the output format, as well as remove any need for format specific parsing.

import re
import xmltodict

from pydantic import Field
from pydantic.generics import GenericModel
from typing import Generic, List, Tuple, TypeVar

T = TypeVar("T")

class Result(GenericModel, Generic[T]):
    value: T = Field(alias="result")

def parse_response(response):
    response = response.strip()
    start_tag, end_tag = "<result>", "</result>"
    is_valid = response.startswith(start_tag) and response.endswith(end_tag)
    if not is_valid:
        pattern = f"(?:{start_tag})(.*)(?:{end_tag})"
        p = re.compile(pattern, re.DOTALL)
        m = p.search(response)
        if m is not None:
            response = start_tag + m.group(1) + end_tag
    resp_dict = xmltodict.parse(response)
    result = Result(**resp_dict)
    return result

# example call
response = chain.invoke({
    "question": "the question",
    "context": "the context",
    "predicted_answer": "the predicted answer",
    "generated_answer": "the generated answer"
})
result = parse_response(response)
qa_eval = result.value["qa_eval"]

One downside to this approach is that it uses the current version of the Pydantic toolkit (v2) whereas LangChain still uses Pydantic V1 internally, as described in LangChain's Pydantic compatibility page. This is why this conversion needs to be outside LangChain and in the application code. Ideally, I would like this to be part of a subclass of PydanticOutputParser where the formatting_instructions could be generated from the class definition as a nice side effect, but that would mean more work than I am prepared to do at this point :-). Meanwhile, this seems like a decent compromise.

That's all I had for today. Thank you for staying with me so far, and I hope you found this useful!

Saturday, February 03, 2024

Book Report: Allen B Downey's Probably Overthinking It

I have read Allen Downey's books on statistics in the past, when trying to turn myself from a Software Engineer into what Josh Wills says a Data Scientist is -- someone who is better at statistics than a Software Engineer and better at software than a statistician (with somewhat limited success in the first area, I will hasten to add). Last year, I had the good fortune to present at PyData Global 2023 (the video is out finally!) so had a free ticket to attend, and one of the talks I really enjoyed there was Allen Downey's talk Extremes, Outliers and GOATs: on life in a lognormal world. In it, he mentions that this is essentially the material from Chapter 4 of his book Probably Overthinking It. I liked his talk enough to buy the book, and I wanted to share my understanding of this book with you all, hence this post.

The book is not as dense as a "real" book on stats like say The Elements of Statistical Learning but is definitely not light reading. I tried reading it on a flight from San Francisco to Philadelphia (and back) and found it pretty heavy going. While the writing is lucid and illustrated with tons of well-explained and easy to understand examples, most of these were new concepts to me, and I wished I took notes after each chapter so I could relate all these concepts together enough to reason about them rather than just learn about them. So I did another pass through the book, this time with pen and paper, and I now feel more confident about talking to other people about it. Hopefully, this is also helpful for folks who have done (or planning to do) the first pass on the book but not the second.

Most people who are new to statistics (me included) lay great store in the Gaussian (Normal) distribution to explain or model various datasets. Chapter 1 challenges this idea and demonstrates that while individual traits may follow a Gaussian distribution, a combination of such traits can be a very restrictive filter. In other words, almost all of us are weird (i.e. not normal). For me, it also introduces the Cumulative Distribution Function (CDF) as a modeling tool.

The second chapter introduces the Inspection Paradox, which explains why it always seems like our wait time for the next train is longer than the average wait time between trains, among other things. The explanation lies in the sampling strategy -- if we sample our data from the population, we may get a skew from oversampling over-represented groups. It also describes a practical use case of this paradox to detect COVID superspreaders.

The third chapter describes what the author calls Preston's paradox, based on a 1976 paper by Samuel Preston. The paradox is that even if every woman has fewer children than her mother, the average family size can increase over time. The paradox is explained by an idea similar to the Inspection Paradox, i.e. because there are more women in existence from large families than small ones, a larger proportion of women would end up having large families than small ones, and overall that contributes to an increase in family size. The opposite can hold true as well, as demonstrated by the loosening of reproductive restrictions in China after its one-child policy, which did not have the desired effect of boosting family sizes.

Chapter 4 is the one the author talked about in the PyData Global talk. In it, he demonstrates that certain attributes are better explained by a log-normal distribution (i.e. the logs of the values follow a Gaussian) rather than our familiar Gaussian distribution. This is especially true for outlier type distributions, such as performance numbers of GOAT (Greatest Of All Time) athletes compared to the general population. The explanation for this is that GOAT performance is almost always a multiplicative combination of innate human prowess (nature) and these skills being effectively harnessed and trained (nurture), plus a whole lot of other factors that all have to line up just so for the event to happen, and whose contributions to the target are therefore multiplicative rather than additive, hence the effectiveness of the log-normal distribution over the normal one.

Chapter 5 explores different survival characteristics of different populations and classifies them as either NBUE (New Better than Used in Expectation) or NWUE (New Worse than Used in Expectation). The former would apply for predicting the remaining life of lightbulbs with use, and the latter would apply for predicting cancer survivability and child mortality over time. Using child mortality statistics, the author shows that as healthcare improves and becomes more predictable across age categories, the NWUE distribution changes to resemble more closely a NBUE distribution.

Chapter 6 explores Berkson's Paradox, where a sub-sample selected from a population using some selection criteria can create correlations that did not exist in the population, or correlations that are opposite to that observed in the population. Berkson originally pointed out the paradox as a warning about using hospital data (sub-sample) to make conclusions about the general population. The selection criteria restrict the general population in specific ways, leading to a change in composition of the traits in the sub-sample, thus leading to the paradox.

Chapter 7 warns about the dangers of interpreting correlation as causation, something most of us have probably read or heard about many many times in the popular Data Science literature. The main case study here is moms who smoke (or don't smoke) and their low birth weight (LBW) babies. A study concluded that while smokers were more likely to give birth to LBW babies, and LBW babies had a higher mortality rate, the mortality rate of LBW babies whose mothers smoked was 48% lower than that of those whose mothers didn't smoke. Further, LBW babies of non-smokers also had a higher rate of birth defects. Interpreting this correlation as causation, i.e. not heeding the warning, it seems like maternal smoking is beneficial for LBW babies, protecting them from mortality and birth defects. The explanation is that maternal smoking is not the only cause of LBW babies, and birth defects may be congenital and not linked to smoking. These two factors mean that there are biological explanations for LBW other than maternal smoking. This and a few other examples segue naturally into a brief high-level introduction to Causal Reasoning, which I also found useful.

Following on from GOAT events being better represented by log-normal rather than normal distributions, Chapter 8 describes applying this to model extremely rare events (such as earthquakes and stock market crashes), and concludes that while the log-normal distribution is more "long-tailed" than a Gaussian, rare events have an even longer tail that is better modeled by a log-Student-t (or Log-t) distribution (the Student-t is a Gaussian with longer / fatter tails). It also introduces the idea of a Tail distribution (the complement of a CDF; a survival chart is a tail distribution chart). The author also makes a brief reference to Nassim Taleb's Black Swan events, saying that the ability to model and predict them makes them more of Gray Swans.

Chapter 9 talks about the challenges in ensuring algorithmic fairness to all recipients of its predictions, which is very relevant given the many paradoxes the book has already covered. In this chapter, the author describes Bayes rule without mentioning it by name, calling it the "base rate" and the difference between the prior and posterior probabilities the "base rate fallacy". He also covers other aspects of fairness, citing differences across groups that an algorithm often does not see. This last part seemed to me to be related to the Inspection Paradox described earlier in the book.

Chapter 10 describes Simpson's Paradox, where sub-populations can exhibit similar correlations across the sub-populations but where the same traits are anti-correlated in the combined population. To some extent, this seems related to Berkson's Paradox. Among the examples cited, there is one about penguins, where within each species the beak size and body size are correlated, but across species they are anti-correlated. The explanation here is that there is a biological reason for the correlation within the species, but the anti-correlation is just a statistical artifact (correlation != causation in action, I guess?).

Chapter 11 is about how certain instances of Simpson's Paradox can be explained as a combination of other underlying factors. It is a truism that people get more conservative as they get older (i.e. if you are not a liberal when you are young, you have no heart, and if you are not a conservative when old, you have no brain). However, within each age group, it is observed that people actually get more liberal over time. This is explained as a combination of the age effect, the period effect, and the cohort effect. The age effect shows a positive correlation between adherence to traditional beliefs (conservativeness) and age. However, within each age group, it is observed that people get more liberal over time, i.e. the cohort effect. Finally the period effect deals with specific events during the time period under consideration, and this covers older people dying out and being replaced with younger (and more liberal) people.

Chapter 12 continues the discussion from the previous chapter and brings in the idea of the Overton Window, which dictates what views are considered acceptable at any particular point in time, and which changes over time as well. So what was thought to be liberal in decades past is now considered more conservative. So while an individual may get more liberal with time, the Overton Window has shifted faster towards liberalism. This can explain why an individual may find themselves getting more conservative as they age, relative to the world around them.

Overall, I enjoyed this book. I think the most impressive thing about this book was its use of generally available datasets to model physical and social environments, and using simulations to control for certain aspects of these data experiments. Also, I think I learned a few things about corner cases in Statistics which I think may be useful when reasoning about them in future. I hope I have sparked your curiosity about this book as well.

Monday, January 01, 2024

Knowledge Graph Aligned Entity Linker using SentenceTransformers

Most of us are familiar with Named Entity Recognizers (NERs) that can recognize spans in text as belonging to a small number of classes, such as Person (PER), Organization (ORG), Location (LOC), etc. These are usually multi-class classifier models, trained on input sequences to return BIO (Begin-Inside-Outside) tags for each token. However, recognizing entities in a Knowledge Graph (KG) using this approach is usually a much harder proposition, since a KG can contain thousands, even millions, of distinct entities, and it is just not practical to create a multi-class classifier for so many target classes. A common approach to building a NER for such a large number of entities is to use dictionary based matching. However, the approach suffers from the inability to do "fuzzy" or inexact matching, beyond standard normalization strategies such as lowercasing and stemming / lemmatizing, and requires you to specify up-front all possible synonyms that may be used to refer to a given entity.

An alternative approach may be to train another model, called a Named Entity Linker (NEL) that would take the spans recognized as candidate entities or phrases by the NER model, and then attempt to link the phrase to an entity in the KG. In this situation, the NER just learns to predict candidate phrases that may be entities of interest, which puts it on par with simpler PER/ORG/LOC style NERs in terms of complexity. The NER and NEL are pipelined together in a setup that is usually known as Named Entity Recognition and Linking (NERL).

In this post, I will describe a NEL model that I built for my 2023 Dev10 project. Our Dev10 program allows employees to use up to 10 working days per year to pursue a side-project, similar to Google's 20% program. The objective is to learn an embedding model where encodings of synonyms of a given entity are close together, and where encodings of synonyms of different entities are pushed far apart. We can then represent each entity in this space as the centroid of the encodings of its individual synonyms. Each candidate phrase output from the NER model can then be encoded using this embedding model, and its nearest neighbors in the embedding space would correspond to the most likely entities to link to.

The idea is inspired by Self-Alignment Pretraining for Biomedical Entity Representations (Liu et al, 2021) which produced the SapBERT model (SAP == Self Aligned Pretraining). It uses Contrastive Learning to fine-tune the BiomedBERT model. In this scenario, positive pairs are pairs of synonyms for the same entity in the KG and negative pairs are synonyms from different entities. It uses the Unified Medical Language System (UMLS) as its KG, to source synonym pairs.

I follow a largely similar approach in my project, except that I use the SentenceTransformers library to fine tune the BiomedBERT model. For my initial experiments, I also used the UMLS as my source of synonym pairs, mainly for reproducibility purposes since it is a free resource available for download to anyone. I tried fine-tuning a bert-base-uncased model and the BiomedBERT models, with MultipleNegativesRanking (MNR) as well as Triplet loss, the latter with Hard Negative Mining. My findings are in line with the SapBERT paper, i.e. that BiomedBERT performs better than BERT base, and that MNR performs better than Triplet loss. The last bit was something of a disappointment, since I had expected Triplet loss to perform better. It is possible that the Hard Negative Mining was not hard enough, or maybe I needed more than 5 negatives for each positive.
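For reference, here is a hedged sketch of the fine-tuning loop using the classic SentenceTransformers API; the synonym pairs are toy placeholders (in the project they come from UMLS) and the BiomedBERT checkpoint id is my assumption.

from sentence_transformers import InputExample, SentenceTransformer, losses
from torch.utils.data import DataLoader

# assumed HuggingFace id for BiomedBERT (formerly PubMedBERT)
model = SentenceTransformer("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")

train_examples = [                                   # toy synonym pairs
    InputExample(texts=["myocardial infarction", "heart attack"]),
    InputExample(texts=["hypertension", "high blood pressure"]),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)

# with MNR loss, every other pair in the batch serves as an in-batch negative
train_loss = losses.MultipleNegativesRankingLoss(model)
model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1, warmup_steps=10)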

You can learn more about the project in my GitHub repository sujitpal/kg-aligned-entity-linker, as well as find the code in there, in case you want to replicate it.

Here are some visualizations from my best model. The chart on the left shows the distribution of cosine similarities between known negative synonym pairs (orange curve) and known positive synonym pairs (blue curve). As you can see, there is almost no overlap. The heatmap on the right shows the cosine similarity of a set of 10 synonym pairs, where the diagonal corresponds to positive pairs and the non-diagonal elements correspond to negative pairs. As you can see, the distribution seems quite good.

I also built a small demo that shows what in my opinion is the main use case for this model. It is a NERL pipeline, where the NER component is the UMLS entity finder (en_core_sci_sm) from the SciSpacy project, and the NEL component is my best performing model (kgnel-bmbert-mnr). In order to look up nearest neighbors for a given phrase encoding, the NEL component also needs a vector store to hold the centroids of the encodings of entity synonyms; I used QDrant for this purpose. The QDrant vector store needs to be populated with the centroid embeddings in advance, and in order to cut down on the indexing and vectorization time, I only computed centroid embeddings for entities of type "Disease or Syndrome" and "Clinical Drug". The visualizations below show the outputs (rendered with displacy) of the NER component:

and that of the NEL component in my demo NERL pipeline. Note that only spans that were identified as a Disease or Drug with a confidence above a threshold were selected in this phase.
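The linking step itself boils down to an embed-and-search call. Below is a minimal sketch of how it could look, assuming a Qdrant collection that already holds one centroid vector per entity; the collection name, model path, payload fields and threshold are all illustrative.

# Sketch: link an NER span to a KG entity by nearest-centroid lookup in Qdrant
from qdrant_client import QdrantClient
from sentence_transformers import SentenceTransformer

encoder = SentenceTransformer("/path/to/kgnel-bmbert-mnr")       # the fine-tuned NEL model
client = QdrantClient(host="localhost", port=6333)

span_text = "heart attack"                                       # candidate span from the NER component
vec = encoder.encode(span_text)

hits = client.search(collection_name="entity_centroids",
                     query_vector=vec.tolist(), limit=3)
for hit in hits:
    if hit.score >= 0.7:                                         # confidence threshold
        print(hit.payload.get("cui"), hit.payload.get("name"), hit.score)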

Such a NERL pipeline could be used to mine new literature for new synonyms of existing entities. Once discovered, they could be added to the synonym list for the dictionary based NER to increase its recall.

Anyway, that was all I had for this post. Today is also January 1 2024, so I wanted to wish you all a very Happy New Year and a productive 2024 filled with many Machine Learning adventures!