Sunday, October 17, 2021

Fine-tuning OpenAI CLIP for different domains

In July this year, a group of us on the TWIML Slack Channel came together and participated in the Flax/JAX Community Week organized by Hugging Face and Google Cloud. Our project was about fine-tuning OpenAI's CLIP model on RSICD (the Remote Sensing Image Captioning Dataset), and it ended up placing third.

The code for the project is available on GitHub at arampacha/CLIP-rsicd if you are curious about how we went about it or want to replicate our efforts. Our fine-tuned model is available on the Hugging Face model repository at flax-community/clip-rsicd-v2, along with instructions on how to use it for inference on your own remote sensing / satellite data. We also have a Streamlit-based demo that shows its application to image search and to finding features in images using text descriptions. Finally, there is a post on the Hugging Face blog titled Fine tuning CLIP with Remote Sensing (Satellite) images and captions. I hope you find these useful; do check them out.

Even before this project, I had been considering learning a joint embedding for medical images and their captions, as described in the Contrastive Learning of Medical Visual Representations from Paired Images and Text (ConVIRT) paper by Zhang et al. (2020), and using it to power a text-to-image search application. Based on the RSICD project, however, CLIP looked like a better and more modern alternative.

Elsevier has a Dev-10 program for its engineers, under which they are given 10 working days (2 weeks) to build something that does not necessarily have to align with company objectives but is somewhat work-related. When my Dev-10 days came up in early September, I used them to fine-tune the same OpenAI CLIP baseline we had used for the Flax/JAX Community Week, but with the ImageCLEF 2017 Image Captioning dataset. Happily, the results were just as encouraging as with RSICD; if anything, the improvement was even more dramatic.

During the RSICD exercise, the fine-tuning itself was done by other members of the team; my contributions were the evaluation framework, the image augmentation piece, the demo, and later the blog post. On the ImageCLEF exercise I was the only developer, so while a lot of the code was borrowed or adapted from the first project, there were some important differences besides the dataset.

First, for RSICD we used JAX/Flax on a TPU-enabled instance on Google Cloud, whereas for ImageCLEF I used PyTorch on a single-GPU EC2 instance on AWS (with the Deep Learning AMI). I found that the Hugging Face wrapper for CLIP provides much of the functionality we had implemented explicitly in the RSICD code, so I leveraged it as much as possible, resulting in slightly cleaner and more readable code (even if I say so myself :-)).
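
One example of that built-in support: Hugging Face's `CLIPModel` computes CLIP's contrastive loss itself when called with `return_loss=True`, so the training loop does not need its own loss function. For readers curious about what that loss is, here is a minimal numpy sketch of the symmetric contrastive objective described in the CLIP paper. It is an illustration only, not code from this project, and the fixed `logit_scale` is an assumption (CLIP actually learns this temperature during training).

```python
import numpy as np


def log_softmax(x: np.ndarray, axis: int) -> np.ndarray:
    """Numerically stable log-softmax along the given axis."""
    shifted = x - x.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


def clip_contrastive_loss(image_emb: np.ndarray, text_emb: np.ndarray,
                          logit_scale: float = 100.0) -> float:
    """Symmetric cross-entropy over a batch of image-text similarity logits.

    Row i of image_emb and row i of text_emb are assumed to be a matching
    pair; every other row in the batch serves as a negative example.
    """
    # L2-normalize so that dot products are cosine similarities.
    image_emb = image_emb / np.linalg.norm(image_emb, axis=1, keepdims=True)
    text_emb = text_emb / np.linalg.norm(text_emb, axis=1, keepdims=True)
    # logits[i, j] = scaled similarity of caption i and image j.
    logits = logit_scale * text_emb @ image_emb.T
    # The correct match for caption i is image i, i.e. the diagonal.
    text_to_image = -np.diag(log_softmax(logits, axis=1)).mean()  # softmax over images
    image_to_text = -np.diag(log_softmax(logits, axis=0)).mean()  # softmax over captions
    return float((text_to_image + image_to_text) / 2.0)
```

Correctly paired batches drive the loss toward zero, while mismatched pairs push it up, which is what pulls each caption and its image together in the shared embedding space.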

Second, I didn't do any image or text augmentation as we had for RSICD. RSICD had a total of 10k images with approximately 5 captions per image, of which we used about 7k for training. ImageCLEF, on the other hand, had about 160k images and captions, of which I used 140k for training. In addition, the RSICD model was trained on a TPU with 4 parallel devices, while the ImageCLEF model was trained on a single GPU. Because of this, I used subsampling of the training set as a form of regularization instead, and used early stopping to terminate training once validation accuracy stopped improving.
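
The subsampling plus early stopping idea can be sketched roughly as follows. This is a hypothetical outline, not the project's code: `train_one_epoch` and `evaluate` are placeholder callbacks, and the subsampling fraction, patience, and epoch limit are assumed values rather than the settings actually used.

```python
import random
from typing import Callable, List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def subsample(train_set: Sequence[T], fraction: float, rng: random.Random) -> List[T]:
    """Draw a fresh random subset (without replacement) of the training set."""
    k = max(1, int(len(train_set) * fraction))
    return rng.sample(list(train_set), k)


def train_with_early_stopping(
    train_set: Sequence[T],
    train_one_epoch: Callable[[List[T]], None],  # hypothetical: trains on the given examples
    evaluate: Callable[[], float],               # hypothetical: returns validation accuracy
    fraction: float = 0.5,  # assumed subsampling rate, not the value used in the post
    patience: int = 2,      # epochs to wait for an improvement before stopping
    max_epochs: int = 20,
    seed: int = 42,
) -> Tuple[int, float]:
    """Train on a new random subsample each epoch; stop when validation stops improving."""
    rng = random.Random(seed)
    best_epoch, best_score = -1, float("-inf")
    stale_epochs = 0
    for epoch in range(max_epochs):
        train_one_epoch(subsample(train_set, fraction, rng))
        score = evaluate()
        if score > best_score:
            best_epoch, best_score = epoch, score
            stale_epochs = 0
            # A real training loop would save a model checkpoint here.
        else:
            stale_epochs += 1
            if stale_epochs >= patience:
                break
    return best_epoch, best_score
```

Because each epoch sees a different random slice of the data, the model is less likely to memorize any particular subset, and the best checkpoint (rather than the last one) is what gets kept.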

Third, with the benefit of hindsight, I settled on a more industry-standard evaluation metric, Mean Reciprocal Rank (MRR@k), instead of the less strict and somewhat ad hoc Hits@k metric I had used for the first exercise.
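
The difference between the two metrics is easiest to see side by side. The sketch below assumes a single relevant image per caption query, which is an assumption made for illustration; Hits@k gives full credit whether the relevant image is at rank 1 or rank k, while MRR@k credits it with 1/rank.

```python
from typing import Callable, Hashable, List, Sequence, Tuple


def hits_at_k(ranked: Sequence[Hashable], relevant: Hashable, k: int) -> float:
    """1.0 if the relevant item appears anywhere in the top k results, else 0.0."""
    return 1.0 if relevant in ranked[:k] else 0.0


def reciprocal_rank_at_k(ranked: Sequence[Hashable], relevant: Hashable, k: int) -> float:
    """1/rank of the relevant item if it appears in the top k results, else 0.0."""
    for rank, item in enumerate(ranked[:k], start=1):
        if item == relevant:
            return 1.0 / rank
    return 0.0


def mean_over_queries(
    metric: Callable[[Sequence[Hashable], Hashable, int], float],
    results: List[Tuple[Sequence[Hashable], Hashable]],
    k: int,
) -> float:
    """Average a per-query metric; results holds (ranked_list, relevant_item) pairs."""
    return sum(metric(ranked, relevant, k) for ranked, relevant in results) / len(results)
```

For example, if the relevant image is ranked first for one query, second for another, and missing from the top 3 for a third, Hits@3 is 2/3 while MRR@3 is (1 + 0.5 + 0) / 3 = 0.5, so MRR penalizes the second query for not ranking the right image at the top.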

And fourth, because the data volume for my second image search demo was much larger (200k images instead of 10k), I switched from NMSLib to Vespa, the open source hybrid vector + text search engine from Yahoo!. With it, I was able to serve image search results based on lexical matches between the query and caption text, on vector-space matches between the CLIP query vector and the CLIP image vectors, and on a hybrid of the two, ranked by combining their relevance scores.
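
Vespa expresses this kind of combination in a rank profile. The snippet below is neither Vespa syntax nor the rank profile used in the demo; it is a plain-Python sketch of one common way to blend the two signals, assuming lexical scores (e.g. BM25) and cosine similarities are each min-max normalized and then combined with a hypothetical weight `alpha`.

```python
from typing import Dict, List


def min_max_normalize(scores: Dict[str, float]) -> Dict[str, float]:
    """Rescale scores to [0, 1]; if all scores are equal, map them to 1.0."""
    if not scores:
        return {}
    lo, hi = min(scores.values()), max(scores.values())
    if hi == lo:
        return {doc: 1.0 for doc in scores}
    return {doc: (s - lo) / (hi - lo) for doc, s in scores.items()}


def hybrid_rank(lexical: Dict[str, float], vector: Dict[str, float],
                alpha: float = 0.5) -> List[str]:
    """Rank documents by alpha * lexical + (1 - alpha) * vector relevance.

    Documents missing from one result list get 0.0 for that component;
    ties are broken by document id so the output is deterministic.
    """
    lex = min_max_normalize(lexical)
    vec = min_max_normalize(vector)
    docs = set(lex) | set(vec)
    combined = {d: alpha * lex.get(d, 0.0) + (1 - alpha) * vec.get(d, 0.0) for d in docs}
    return sorted(combined, key=lambda d: (-combined[d], d))
```

Normalizing first matters because raw BM25 scores are unbounded while cosine similarities live in [-1, 1]; without it, whichever signal has the larger numeric range would dominate the blend.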

Unfortunately, I am not able to share the code. The work was done on company time with company resources, so the code rightfully belongs to the company, and I am also hopeful that it could be used to power image search (or related) functionality in a production application. In general, though, it is similar to the RSICD version, with the differences enumerated above.

However, just to give some idea of the kind of results you can expect from a fine-tuned CLIP model, here are a couple of screenshots. The results are for the queries "computed tomography" and "computed tomography deep vein thrombosis". Both sets of results come from vector matching, i.e., they are ranked by cosine similarity between the CLIP encoding of the query text and the CLIP encoding of each image.
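
Vector matching of this kind boils down to a cosine-similarity ranking. Here is a minimal brute-force numpy sketch, assuming the query and image vectors have already been produced by the fine-tuned model (with Hugging Face's `CLIPModel`, for instance, via `get_text_features` and `get_image_features`); the toy vectors in the usage note are made up. At the scale of the demo, a search engine such as Vespa would typically use an approximate nearest-neighbor index rather than scoring every image.

```python
import numpy as np


def top_k_by_cosine(query_vec: np.ndarray, image_vecs: np.ndarray, k: int = 5) -> np.ndarray:
    """Return indices of the k image vectors most cosine-similar to the query, best first."""
    q = query_vec / np.linalg.norm(query_vec)
    imgs = image_vecs / np.linalg.norm(image_vecs, axis=1, keepdims=True)
    sims = imgs @ q  # one cosine similarity per image
    k = min(k, len(sims))
    # argpartition selects the k best candidates without a full sort ...
    top = np.argpartition(-sims, k - 1)[:k]
    # ... and only those k are then sorted by descending similarity.
    return top[np.argsort(-sims[top])]
```

With toy 2-d "embeddings" `[[1, 0], [0, 1], [1, 1], [-1, 0]]` and the query `[1.0, 0.1]`, asking for the top 2 returns indices 0 and 2, since those vectors point closest to the query's direction.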

As you can see, CLIP returns relevant images for both high-level and detailed queries, which indicates how rich the embedding is. My main takeaways from this series of exercises are twofold: first, CLIP's joint image-text encoding is a seriously powerful idea and is super-effective; and second, transformer models trained on general data (natural images and text in this case) can be fine-tuned effectively for specialized domains using relatively small amounts of data.