Saturday, March 23, 2024

Book Report: Machine Learning for Drug Discovery

Drug Discovery is a field where biochemists (and more recently computer scientists) turn ideas into potential medications. I first came across a few applications in this area when checking out how to build Graph Neural Networks (GNNs) as part of auditing the CS224W: Machine Learning with Graphs course from Stanford, some learnings of which I recycled into my Deep Learning with Graphs tutorial at ODSC 2021. Of course, drug discovery is much more than just GNNs; I mention this only because it happened to be my entry point into this fascinating world. However, I will hasten to add that despite having made an entrance, I am still parked pretty solidly close to the entrance (or exit, depending on your point of view).

But I am always looking to learn more about stuff I find interesting, so when I was offered a chance to review Dr Noah Flynn's Machine Learning for Drug Discovery published by Manning, I jumped on it. The book is currently in MEAP (Manning Early Access Program), so only 5 chapters are available right now, but once the book is completed, there are going to be 15 chapters in all. The intended audience of the book, as the title suggests, is computational biochemists, i.e. the ones who attempt to solve Drug Discovery problems using Machine Learning. There are two main ways to become a computational biochemist -- either you are a biochemist and you learn the ML, or you are an ML person and you learn the biochemistry. The book is aimed at both categories of readers.

As someone in the latter category, I had to spend much more time on the biochemistry aspects. I suspect that most readers of this review would also fall into this category. For them, I would say that while the ML part is sophisticated enough to solve the problem at hand, it relies on methods and practices that should be familiar to most ML people already. The most useful things that I think you would get out of this book are as follows:

  • Framing the Drug Discovery problem as a ML problem
  • Preprocessing and Encoding inputs
  • Getting data to train your ML model

For the first one, you either need to have a biochemistry background yourself, or you need to pair with someone who does. I suppose you could get by with a life sciences or chemistry background as well, or acquire enough biochemistry knowledge over time in this field, and this book may even take you part of the way there, but be aware that the learning curve is steep.

For the second and the third items, I thought the book was super useful. Most chapters are built as case studies around a Drug Discovery problem, so as you go through the chapters, you will learn about the sites to acquire your datasets from, and the techniques to preprocess the data from these sites into a form suitable for consumption by your ML model. At least the first 5 chapters deal with fairly simple ML models, which may or may not be familiar to you depending on your industry, so you might also learn a few things about evaluating or tuning these models that you didn't know before (I did).

The first chapter introduces the reader to the domain and talks about the need for computational approaches to Drug Discovery. It introduces the terminology and the RDKit software library, an open-source cheminformatics toolkit that provides implementations of many common operations needed for computational Drug Discovery (sort of like a specialized supplement to Scikit-Learn for general ML). It also covers high level rules of thumb for detecting drug compounds, such as Lipinski's rule of 5. It then covers some common use cases in Drug Discovery, ranging from Virtual Screening to Generative and Synthetic Chemistry. It also covers some popular (and public) repositories for Chemistry data, such as ChEMBL, PubChem, Protein Data Bank (PDB), etc.
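As a rough illustration of Lipinski's rule of 5 (not code from the book), the check can be sketched in plain Python once the four descriptors are known. The `Descriptors` container and the function names are hypothetical; in practice the descriptor values would come from a toolkit such as RDKit, and the aspirin numbers below are approximate and only for illustration.

```python
from dataclasses import dataclass


@dataclass
class Descriptors:
    """Precomputed molecular descriptors (hypothetical container, not an RDKit type)."""
    mol_weight: float      # molecular weight in daltons
    log_p: float           # octanol-water partition coefficient
    h_bond_donors: int
    h_bond_acceptors: int


def lipinski_violations(d: Descriptors) -> int:
    """Count how many of the four rule-of-5 criteria are violated."""
    checks = [
        d.mol_weight > 500,
        d.log_p > 5,
        d.h_bond_donors > 5,
        d.h_bond_acceptors > 10,
    ]
    return sum(checks)


def passes_rule_of_five(d: Descriptors, max_violations: int = 1) -> bool:
    """A common reading of the rule tolerates at most one violation."""
    return lipinski_violations(d) <= max_violations


# Approximate, illustrative descriptor values for aspirin.
aspirin = Descriptors(mol_weight=180.2, log_p=1.2, h_bond_donors=1, h_bond_acceptors=4)
print(lipinski_violations(aspirin), passes_rule_of_five(aspirin))
```

The `max_violations` parameter is there because the rule is a heuristic filter rather than a hard law; setting it to 0 gives the strictest reading.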

The second chapter demonstrates Ligand based Screening, where you already have a reference molecule with some of the desired properties, and you want to search the chemical space for molecules similar to that one, with the objective of finding more drugs like the one you started with. The case study here is to identify potential anti-malarial compounds. The dataset for this comes packaged with RDKit itself as Structure-Data Files (SDF), which describe each molecule using a SMILES (Simplified Molecular Input Line Entry System) string. The chapter walks us through converting the SMILES to MOL format, then using RDKit to extract specialized chemical features from the MOL and SMILES, and preprocessing to filter out uninteresting molecules based on rule based thresholds such as bio-availability and molecular weight, structure based thresholds such as toxicity, and specific substructural patterns (similar to subgraph motifs). It then uses RDKit to generate Morgan fingerprints from the remaining molecules (MOL). Morgan (and other) fingerprints are similar to embeddings in NLP, except that they encode structural information through a more deterministic process, and are hence more explainable than embeddings. Finally, these fingerprints are compared with that of the reference molecule using Tanimoto similarity to find its nearest neighbors.
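To make the last step concrete, here is a minimal sketch of Tanimoto similarity and nearest-neighbor ranking. It assumes each fingerprint is represented as the set of its "on" bit positions; the molecule names and bit sets are made up, and a real pipeline would use RDKit's fingerprint objects rather than Python sets.

```python
from typing import Dict, List, Set, Tuple


def tanimoto(a: Set[int], b: Set[int]) -> float:
    """Tanimoto (Jaccard) similarity between two sets of on-bit indices."""
    if not a and not b:
        return 0.0
    shared = len(a & b)
    return shared / (len(a) + len(b) - shared)


def nearest_neighbors(reference: Set[int], library: Dict[str, Set[int]],
                      top_k: int = 2) -> List[Tuple[str, float]]:
    """Rank library molecules by similarity to the reference, best first."""
    scored = [(name, tanimoto(reference, fp)) for name, fp in library.items()]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)[:top_k]


# Made-up fingerprints: each set lists the bit positions that are switched on.
reference_fp = {1, 4, 7, 9}
library_fps = {
    "mol_a": {1, 4, 7, 8},
    "mol_b": {2, 3, 5},
    "mol_c": {1, 4, 7, 9, 12},
}
print(nearest_neighbors(reference_fp, library_fps))
```

Here `mol_c` shares 4 of 5 distinct bits with the reference (similarity 0.8) and ranks ahead of `mol_a` (3 of 5, similarity 0.6), while `mol_b` shares none.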

Chapter 3 continues with the problem of Ligand based screening, but tries to predict cardiotoxicity of the anti-malarial compounds found in the previous chapter using a linear model. This is done indirectly: if the compound is predicted to block the hERG (human Ether-à-go-go-Related Gene) potassium channel, then it is considered cardiotoxic, and vice versa. A linear model (Scikit-Learn SGDClassifier) is trained using the hERG dataset from the Therapeutic Data Commons (TDC). The chapter shows some Exploratory Data Analysis (EDA) on the data, using standard preprocessing as described in the previous chapter. An additional step here is to standardize (regularize) the data for classification. The author provides the biochemistry reasoning behind this step, but uses the implementation already provided by RDKit. Finally, Morgan fingerprints are used to train the SGD Classifier. Because the elements of Morgan fingerprints have meaning, the weights of the resulting SGD model can be used to determine feature importances. There is also some discussion here of cross validation, L1/L2 regularization, removing collinearity, adding interaction terms and hyperparameter sweeps.
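The idea of reading feature importances off a linear model's weights can be sketched without any library. The snippet below is a pure-Python stand-in for an SGD-trained logistic regression (it is not Scikit-Learn's SGDClassifier and not the book's code), trained on made-up 4-bit "fingerprints" in which only bit 2 carries signal.

```python
import math
from typing import List, Tuple


def train_sgd_logistic(X: List[List[int]], y: List[int], epochs: int = 200,
                       lr: float = 0.1, l2: float = 0.001) -> Tuple[List[float], float]:
    """Fit logistic regression with plain per-sample SGD and an L2 penalty."""
    weights = [0.0] * len(X[0])
    bias = 0.0
    for _ in range(epochs):
        for features, label in zip(X, y):
            z = bias + sum(w * x for w, x in zip(weights, features))
            error = 1.0 / (1.0 + math.exp(-z)) - label  # gradient of log loss w.r.t. z
            weights = [w - lr * (error * x + l2 * w) for w, x in zip(weights, features)]
            bias -= lr * error
    return weights, bias


# Toy data: the label equals bit 2; bits 0, 1 and 3 are balanced noise.
X = [
    [1, 0, 1, 0], [0, 1, 1, 0], [1, 1, 1, 1], [0, 0, 1, 1],
    [1, 0, 0, 0], [0, 1, 0, 1], [1, 1, 0, 0], [0, 0, 0, 1],
]
y = [1, 1, 1, 1, 0, 0, 0, 0]

weights, bias = train_sgd_logistic(X, y)
# Rank fingerprint bits by the magnitude of their learned weight.
ranking = sorted(range(len(weights)), key=lambda i: abs(weights[i]), reverse=True)
print(ranking[0])
```

With real Morgan fingerprints, a highly ranked bit can be mapped back to the substructure that sets it, which is what makes this kind of inspection useful.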

Chapter 4 explores building a linear regression model to predict solubility, i.e. how much of the drug would be absorbed by the system. The dataset used to train the regressor is AqSolDB, also from TDC. This chapter introduces the idea of scaffold splitting, a technique common with biochemical datasets that preserves the structural / chemical similarity within each split. It also briefly describes outlier removal at the extremes, which requires chemistry knowledge. The RDKit library is used to extract features from the dataset, and the model is trained to minimize the Mean Squared Error loss. The chapter also introduces RANSAC (RANdom SAmple Consensus), a technique that makes models more robust to outliers. On the ML side, there is some discussion on the bias-variance tradeoff and Learning / Validation curves.
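A minimal sketch of scaffold splitting follows, assuming each molecule has already been reduced to a scaffold identifier (the scaffold labels below are hypothetical; in practice they are typically derived from the molecule's core ring structure using a toolkit like RDKit). The greedy "largest groups go to train first" policy is one common convention, not necessarily the one the book uses.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def scaffold_split(scaffolds: List[str],
                   test_fraction: float = 0.25) -> Tuple[List[int], List[int]]:
    """Split molecule indices so that no scaffold appears in both train and test."""
    groups: Dict[str, List[int]] = defaultdict(list)
    for idx, scaffold in enumerate(scaffolds):
        groups[scaffold].append(idx)

    # Fill the training set with whole scaffold groups, largest first.
    ordered = sorted(groups.values(), key=len, reverse=True)
    train_cutoff = (1.0 - test_fraction) * len(scaffolds)
    train: List[int] = []
    test: List[int] = []
    for group in ordered:
        if len(train) + len(group) <= train_cutoff:
            train.extend(group)
        else:
            test.extend(group)
    return train, test


# Hypothetical scaffold label for each of 8 molecules.
scaffolds = ["benzene", "benzene", "benzene", "pyridine",
             "pyridine", "indole", "indole", "furan"]
train_idx, test_idx = scaffold_split(scaffolds)
print(train_idx, test_idx)
```

Because whole groups move together, the test set contains only scaffolds the model has never seen, which gives a more honest estimate of how it generalizes to new chemistry than a random split.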

The fifth and last chapter of the MEAP (at the time of writing this review) deals with predicting how well the body will metabolize the drug. Typically, drugs are broken down by enzymes in the liver, a large proportion of which are collectively known as the Cytochrome P450 superfamily. As before, metabolism is predicted indirectly by whether the drug inhibits Cytochrome P450 -- if it does, then it will not get metabolized easily, and vice versa. The dataset used to train the model is the CYP3A4 dataset, also from TDC. Data is prepared using the same (by now) standard pipeline and the classifier is trained to make binary predictions of whether the input inhibits Cytochrome P450 or not. The chapter discusses the utility of Reliability Plots in Performance Evaluation and Platt scaling for calibrating probabilities. It also talks about how to deal with imbalanced datasets using Data Augmentation, Class Weights and other approaches. Various models are trained and evaluated, and their important features identified and visualized with RDKit Similarity Map. The chapter ends with a short discussion on Multi-label classification.
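For readers who have not met reliability plots before, the numbers behind one can be computed in a few lines. This is a sketch with made-up labels and predicted probabilities, not output from any model in the book: predictions are grouped into equal-width probability bins, and for each bin the mean predicted probability is compared against the observed fraction of positives.

```python
from typing import List, Tuple


def reliability_bins(y_true: List[int], y_prob: List[float],
                     n_bins: int = 5) -> List[Tuple[float, float, int]]:
    """Return (mean predicted probability, observed positive fraction, count) per non-empty bin."""
    stats = [[0.0, 0, 0] for _ in range(n_bins)]  # probability sum, positives, count
    for label, prob in zip(y_true, y_prob):
        idx = min(int(prob * n_bins), n_bins - 1)  # prob == 1.0 falls in the last bin
        stats[idx][0] += prob
        stats[idx][1] += label
        stats[idx][2] += 1
    return [(total / count, positives / count, count)
            for total, positives, count in stats if count > 0]


# Made-up binary labels and predicted probabilities.
y_true = [0, 0, 1, 0, 1, 1, 1, 1]
y_prob = [0.1, 0.1, 0.3, 0.3, 0.7, 0.7, 0.9, 0.9]
for mean_pred, frac_pos, count in reliability_bins(y_true, y_prob):
    print(f"{mean_pred:.2f} {frac_pos:.2f} {count}")
```

A well calibrated classifier has the two columns roughly equal in every bin; systematic gaps are what a calibration step such as Platt scaling tries to correct.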

The pandemic and the rapid discovery of the COVID vaccine gave a lot of us (at least those of us that were watching) a ringside view into the fascinating world of drug discovery. This book provides yet another peek into this world, with its carefully crafted case studies and examples. Overall, I think you will learn a lot about drug discovery if you go through this book, both on the biochemistry side and the ML side. There are exercises at the end of each chapter; doing these would help you get more familiar with RDKit and hopefully more effective at computational drug discovery.

Sunday, March 17, 2024

Hierarchical (and other) Indexes using LlamaIndex for RAG Content Enrichment

At our weekly This Week in Machine Learning (TWIML) meetings, (our leader and facilitator) Darin Plutchok pointed out a LinkedIn blog post on Semantic Chunking that has been recently implemented in the LangChain framework. Unlike more traditional chunking approaches that use number of tokens or separator tokens as a guide, this one chunks groups of sentences into semantic units by breaking them when the (semantic) similarity between consecutive sentences (or sentence-grams) falls below some predefined threshold. I had tried it earlier (pre-LangChain) and while results were reasonable, it needed a lot of processing, so I went back to what I was using before.

I was also recently exploring LlamaIndex as part of the effort to familiarize myself with the GenAI ecosystem. LlamaIndex supports hierarchical indexes natively, meaning it provides the data structures that make building them easier and more natural. Unlike the typical RAG index, which is just a sequence of chunks (and their vectors), hierarchical indexes would cluster chunks into parent chunks, and parent chunks into grandparent chunks, and so on. A parent chunk would generally inherit or merge most of the metadata from its children, and its text would be a summary of its children's text contents. To illustrate my point about LlamaIndex data structures having natural support for this kind of setup, here are the definitions of the LlamaIndex TextNode (the LlamaIndex Document object is just a child of TextNode with an additional doc_id: str field) and the LangChain Document. Of particular interest is the relationships field, which allows pointers to other chunks using named relationships PARENT, CHILD, NEXT, PREVIOUS, SOURCE, etc. Arguably, the LlamaIndex TextNode can be represented more generally and succinctly by the LangChain Document, but the hooks do help to support hierarchical indexing more naturally.

# this is a LlamaIndex TextNode (fields only, simplified)
from typing import Any, Dict, List, Optional, Union
from llama_index.core.schema import NodeRelationship, RelatedNodeInfo

class TextNode:
  id_: str = None
  embedding: Optional[List[float]] = None
  extra_info: Dict[str, Any]
  excluded_embed_metadata_keys: List[str] = None
  excluded_llm_metadata_keys: List[str] = None
  relationships: Dict[NodeRelationship, Union[RelatedNodeInfo, List[RelatedNodeInfo]]] = None
  text: str
  start_char_idx: Optional[int] = None
  end_char_idx: Optional[int] = None
  text_template: str = "{metadata_str}\n\n{content}"
  metadata_template: str = "{key}: {value}"
  metadata_separator: str = "\n"

# and this is a LangChain Document
class Document:
  page_content: str
  metadata: Dict[str, Any]

In any case, having discovered the hammer that is LlamaIndex, I began to see a lot of potential hierarchical index nails. One such nail that occurred to me was to use Semantic Chunking to cluster consecutive chunks rather than sentences (or sentence-grams), and then create parent nodes from these chunk clusters. Instead of computing cosine similarity between consecutive sentence vectors to build up chunks, we compute cosine similarity across consecutive chunk vectors and split them up into clusters based on some similarity threshold, i.e. if the similarity drops below the threshold, we terminate the cluster and start a new one.
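The clustering rule just described can be sketched in a few lines of plain Python. This is an illustration under simple assumptions (vectors as lists of floats, a fixed similarity threshold), not the actual component; the function names and toy vectors are made up.

```python
import math
from typing import List, Sequence


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity between two vectors (0.0 if either has zero length)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def cluster_consecutive(vectors: Sequence[Sequence[float]],
                        threshold: float) -> List[List[int]]:
    """Group consecutive chunk indices, starting a new cluster on a similarity drop."""
    if not vectors:
        return []
    clusters = [[0]]
    for i in range(1, len(vectors)):
        if cosine(vectors[i - 1], vectors[i]) < threshold:
            clusters.append([i])       # similarity dropped: terminate and start anew
        else:
            clusters[-1].append(i)     # still similar: extend the current cluster
    return clusters


# Toy 2-d "chunk vectors": the topic shifts after chunk 1 and again after chunk 3.
chunk_vectors = [(1.0, 0.0), (0.9, 0.1), (0.0, 1.0), (0.1, 0.9), (1.0, 0.1)]
print(cluster_consecutive(chunk_vectors, threshold=0.5))
```

Because only neighbors are compared, the clusters are always contiguous runs of chunks, which is what makes a summary parent node for each run meaningful.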

Both LangChain and LlamaIndex have implementations of Semantic Chunking (for clustering sentences into chunks, not chunks into parent chunks). LangChain's Semantic Chunking allows you to set the threshold using percentiles, standard deviation, or inter-quartile range, while the LlamaIndex implementation supports only the percentile threshold. Intuitively, here's how you could get an idea of the percentile threshold to use -- thresholds for the other methods can be computed similarly. Assume your content has N chunks and K clusters (based on your understanding of the data or from other estimates). Assuming a uniform distribution, there would be N/K chunks in each cluster, and roughly K of the N boundaries between consecutive chunks would be cluster breaks. If K/N is approximately 20%, then your percentile threshold would be approximately 80.
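The back-of-the-envelope estimate can be written down as a tiny helper. This sketch assumes that splitting N chunks into K clusters needs K-1 breaks among the N-1 consecutive pairs, and that the percentile is taken over the distances between consecutive items (so a break happens where the distance exceeds that percentile); the function name is hypothetical.

```python
def breakpoint_percentile(n_chunks: int, n_clusters: int) -> float:
    """Percentile of consecutive-pair distances above which a break is placed."""
    if n_chunks < 2 or not 1 <= n_clusters <= n_chunks:
        raise ValueError("need n_chunks >= 2 and 1 <= n_clusters <= n_chunks")
    n_pairs = n_chunks - 1      # consecutive pairs that could host a break
    n_breaks = n_clusters - 1   # breaks needed to form n_clusters groups
    return 100.0 * (1.0 - n_breaks / n_pairs)


# 21 clusters over 101 chunks means 20 of the 100 pairs are breaks, i.e. 20%.
print(breakpoint_percentile(n_chunks=101, n_clusters=21))
```

For large N the result is close to 100 * (1 - K/N), which is the rule of thumb in words.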

LlamaIndex provides an IngestionPipeline which takes a list of TransformComponent objects. My pipeline looks something like below. The last component is a custom subclass of TransformComponent; all you need to do is override its __call__ method, which takes a List[TextNode] and returns a List[TextNode].

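from llama_index.core import SimpleDirectoryReader
from llama_index.core.ingestion import IngestionPipeline

# text_splitter, embedding_generator and summary_node_builder are instances
# constructed earlier (constructor arguments omitted here)
transformations = [
    text_splitter,         # SentenceSplitter
    embedding_generator,   # HuggingFaceEmbedding
    summary_node_builder,  # SemanticChunkingSummaryNodeBuilder
]
ingestion_pipeline = IngestionPipeline(transformations=transformations)
docs = SimpleDirectoryReader("/path/to/input/docs").load_data()
nodes = ingestion_pipeline.run(documents=docs)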

My custom component takes the desired number of clusters K during construction. It uses the vectors computed by the (LlamaIndex provided) HuggingFaceEmbedding component to compute similarities between consecutive vectors and uses K to compute a threshold to use. It then uses the threshold to cluster the chunks, resulting in a list of lists of chunks List[List[TextNode]]. For each cluster, we create a summary TextNode and set its CHILD relationships to the cluster nodes, and the PARENT relationship of each child in the cluster to this new summary node. The text of the child nodes is first condensed using extractive summarization, then these condensed summaries are further summarized into one final summary using abstractive summarization. I used bert-extractive-summarizer with bert-base-uncased for the first and a HuggingFace summarization pipeline with facebook/bart-large-cnn for the second. I suppose I could have used an LLM for the second step, but it would have taken more time to build the index, and I have been experimenting with ideas described in the DeepLearning.AI course Open Source Models with HuggingFace.
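The parent/child wiring can be illustrated with a stripped-down stand-in for TextNode. Everything below is hypothetical scaffolding for illustration: `Node` is a plain dataclass (not the LlamaIndex class), relationships are stored as lists of ids keyed by a string, and `first_sentences` is a trivial placeholder for the extractive-then-abstractive summarization step.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List


@dataclass
class Node:
    """Minimal stand-in for a text node with named relationships."""
    id_: str
    text: str
    relationships: Dict[str, List[str]] = field(default_factory=dict)


def first_sentences(texts: List[str]) -> str:
    """Placeholder summarizer: keep the first sentence of each child text."""
    return " ".join(t.split(". ")[0].rstrip(".") + "." for t in texts)


def build_summary_nodes(clusters: List[List[Node]],
                        summarize: Callable[[List[str]], str]) -> List[Node]:
    """Create one summary (parent) node per cluster and link it to its children."""
    parents = []
    for i, cluster in enumerate(clusters):
        parent = Node(id_=f"summary-{i}", text=summarize([n.text for n in cluster]))
        parent.relationships["CHILD"] = [n.id_ for n in cluster]
        for child in cluster:
            child.relationships["PARENT"] = [parent.id_]
        parents.append(parent)
    return parents


chunks = [
    Node("c0", "First topic opens here. More detail follows."),
    Node("c1", "First topic continues. Even more detail."),
    Node("c2", "Second topic starts. It has its own detail."),
]
summary_nodes = build_summary_nodes([chunks[:2], chunks[2:]], first_sentences)
print(summary_nodes[0].text)
```

Passing the summarizer in as a callable keeps the linking logic independent of how the summary text is produced, so swapping in a model-based summarizer would not change the node wiring.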

Finally, I recalculate the embeddings for the summary nodes -- I ran the summary node texts through the HuggingFaceEmbedding, but I guess I could have done some aggregation (mean-pool / max-pool) on the child vectors as well.
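For reference, the aggregation alternative mentioned above is simple to express. This is a sketch on made-up child vectors, pooling each dimension across the children of a cluster; the function names are illustrative.

```python
from typing import List, Sequence


def mean_pool(vectors: Sequence[Sequence[float]]) -> List[float]:
    """Average each dimension across the child vectors."""
    return [sum(column) / len(vectors) for column in zip(*vectors)]


def max_pool(vectors: Sequence[Sequence[float]]) -> List[float]:
    """Take the maximum of each dimension across the child vectors."""
    return [max(column) for column in zip(*vectors)]


# Made-up embeddings for two child chunks.
child_vectors = [[1.0, 0.0, 2.0], [3.0, 4.0, 0.0]]
print(mean_pool(child_vectors), max_pool(child_vectors))
```

Pooling is cheaper than re-embedding the summary text, but the resulting vector represents the children rather than the summary itself, so the two choices may retrieve differently.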

Darin also pointed out another instance of a Hierarchical Index, proposed in the RAPTOR: Recursive Abstractive Processing for Tree-Organized Retrieval paper and described in detail by the authors in this LlamaIndex webinar. This is a bit more radical than my idea of using semantic chunking to cluster consecutive chunks, in that it allows clustering of chunks across the entire corpus. One other significant difference is that it allows for soft-clustering, meaning a chunk can be a member of more than one cluster. They first reduce the dimensionality of the vector space using UMAP (Uniform Manifold Approximation and Projection) and then apply a Gaussian Mixture Model (GMM) to do the soft clustering. To find the optimum number of clusters K for the GMM, one can use a combination of AIC (Akaike Information Criterion) and BIC (Bayesian Information Criterion).
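As a reminder of how the two criteria trade fit against complexity, here is a small sketch using the standard definitions AIC = 2p - 2 ln L and BIC = p ln n - 2 ln L, where p is the number of free parameters, n the number of samples and ln L the log-likelihood (lower is better for both). The log-likelihood values below are made up for illustration, and the parameter count assumes a full-covariance GMM.

```python
import math


def gmm_param_count(n_components: int, n_dims: int) -> int:
    """Free parameters of a full-covariance Gaussian mixture."""
    means = n_components * n_dims
    covariances = n_components * n_dims * (n_dims + 1) // 2
    mixing_weights = n_components - 1
    return means + covariances + mixing_weights


def aic(log_likelihood: float, n_params: int) -> float:
    return 2 * n_params - 2 * log_likelihood


def bic(log_likelihood: float, n_params: int, n_samples: int) -> float:
    return n_params * math.log(n_samples) - 2 * log_likelihood


# Made-up total log-likelihoods for GMMs fitted with K = 2, 3, 4 components.
log_likelihoods = {2: -520.0, 3: -480.0, 4: -470.0}
n_samples, n_dims = 200, 2

aic_scores = {k: aic(ll, gmm_param_count(k, n_dims)) for k, ll in log_likelihoods.items()}
bic_scores = {k: bic(ll, gmm_param_count(k, n_dims), n_samples)
              for k, ll in log_likelihoods.items()}
best_by_aic = min(aic_scores, key=aic_scores.get)
best_by_bic = min(bic_scores, key=bic_scores.get)
print(best_by_aic, best_by_bic)
```

In this toy setup AIC keeps improving up to K=4 while BIC, which penalizes parameters more heavily, bottoms out at K=3. With a fitted model, libraries such as Scikit-Learn expose these scores directly on the mixture object, so the helper functions are only needed to see the arithmetic.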

In my case, when training the GMM, the AIC kept decreasing as the number of clusters increased, and the BIC had its minimum value for K=10, which corresponds roughly to the 12 chapters in my Snowflake book (my test corpus). But there was a lot of overlap, which would force me to implement some sort of logic to take advantage of the soft clustering, which I didn't want to do, since I wanted to reuse code from my earlier Semantic Chunking node builder component. Ultimately, I settled on 90 clusters by using my original intuition to compute K, and the resulting clusters seem reasonably well separated as seen below.

Using the results of the clustering, I built another custom LlamaIndex TransformComponent for hierarchical indexing. This implementation differs from the previous one only in the way it assigns nodes to clusters; all other details with respect to text summarization and metadata merging are identical.

For both these indexes, we have a choice to maintain the index as hierarchical, and decide which layer(s) to query based on the question, or add the summary nodes into the same level as the other chunks, and let vector similarity surface them when queries deal with cross-cutting issues that may be found together in these nodes. The RAPTOR paper reports that they don't see a significant gain using the first approach over the second. Because my query functionality is LangChain based, my approach has been to generate the nodes and then reformat them into LangChain Document objects and use LCEL to query the index and generate answers, so I haven't looked into querying from a hierarchical index at all.

Looking back on this work, I am reminded of similar choices when designing traditional search pipelines. Often there is a choice between building functionality into the index to support a cheaper query implementation, or building the logic into the query pipeline that may be more expensive but also more flexible. I think LlamaIndex started with the first approach (as evidenced by their blog posts Chunking Strategies for Large Language Models Part I and Evaluating Ideal Chunk Sizes for RAG Systems using LlamaIndex) while LangChain started with the second, even though nowadays there is a lot of convergence between the two frameworks.