Some time back, I found myself thinking about different data augmentation strategies for imbalanced datasets, i.e., datasets in which one or more classes are over-represented compared to the others, and wondering how these strategies stack up against one another. So I decided to set up a simple experiment to compare them. This post describes the experiment and its results.
The dataset I chose for this experiment was the SMS Spam Collection Dataset from Kaggle, a collection of almost 5600 text messages, consisting of 4825 (87%) ham and 747 (13%) spam messages. The network is a simple 3-layer fully connected network (FCN), whose input is a 512-element vector generated from the text message using the Google Universal Sentence Encoder (GUSE), and whose output is the argmax of a 2-element vector (representing "ham" or "spam"). The text augmentation strategies I considered in my experiment are as follows:
- Baseline -- this is a baseline for result comparison. Since the task is binary classification, the metric we chose is accuracy. We train the network for 10 epochs using Cross Entropy loss and the AdamW optimizer with a learning rate of 1e-3 (a sketch of the network and training setup appears after this list).
- Class Weights -- class weights attempt to address data imbalance by giving more weight to the minority class. Here we assign class weights to our loss function, proportional to the inverse of each class's count in the training data.
- Undersampling Majority Class -- in this scenario, we sample from the majority class as many records as there are in the minority class, and use only this sampled subset of the majority class plus the full minority class for training.
- Oversampling Minority Class -- this is the opposite scenario, where we sample (with replacement) from the minority class as many records as there are in the majority class. The sampled set will contain repetitions. We then use the sampled set plus the majority class for training.
- SMOTE -- this is a variant of the previous strategy of oversampling the minority class. SMOTE (Synthetic Minority Oversampling TEchnique) ensures more heterogeneity in the oversampled minority class by creating synthetic records that interpolate between real records. SMOTE requires the input data to be vectorized.
- Text Augmentation -- like the two previous approaches, this is another oversampling strategy. Heuristics and ontologies are used to make changes to the input text while preserving its meaning as far as possible. I used TextAttack, a Python library for text augmentation (and for generating adversarial examples).
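To make the setup concrete, here is a minimal PyTorch sketch of the network and the baseline training configuration. The hidden layer sizes are illustrative guesses on my part; only the 512-element input and 2-element output are fixed by the description above.

```python
import torch
import torch.nn as nn

# 3-layer FCN: 512-d GUSE vector in, 2 logits (ham/spam) out.
# Hidden sizes (128, 64) are illustrative, not the exact notebook values.
model = nn.Sequential(
    nn.Linear(512, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 2),            # raw logits; argmax(dim=-1) gives ham/spam
)

loss_fn = nn.CrossEntropyLoss()  # for the Class Weights scenario, pass
                                 # weight=torch.tensor([w_ham, w_spam])
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
```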
A few points to note here.
First, all the sampling methods, i.e., all the strategies listed above except the Baseline and Class Weights, require you to split your data into training, validation, and test splits before they are applied, and the sampling should be done only on the training split. Otherwise, you risk data leakage, where the augmented data leaks into the validation and test splits, giving you overly optimistic results during model development which will invariably not hold up as you move your model into production.
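As a sketch of what this looks like in practice, the split happens first and the naive oversampling touches only the training split. The X and y below are dummy stand-ins for the GUSE vectors and ham/spam labels.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Dummy stand-ins: X for the (N, 512) matrix of GUSE vectors,
# y for the 0/1 ham/spam labels (~13% spam, as in the dataset).
rng = np.random.default_rng(42)
X = rng.standard_normal((5572, 512))
y = (rng.random(5572) < 0.13).astype(int)

# Split 70/10/20 *before* any resampling is applied.
X_trainval, X_test, y_trainval, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(
    X_trainval, y_trainval, test_size=0.125, stratify=y_trainval,
    random_state=42)   # 0.125 of the remaining 80% = 10% overall

# Naive oversampling of the minority class, done on the training split
# only, so no duplicated record can leak into validation or test.
minority = np.where(y_train == 1)[0]
majority = np.where(y_train == 0)[0]
oversampled = rng.choice(minority, size=len(majority), replace=True)
keep = np.concatenate([majority, oversampled])
X_train_os, y_train_os = X_train[keep], y_train[keep]
```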
Second, augmenting your data using SMOTE can only be done on vectorized data, since the idea is to find and use points in feature hyperspace that lie "in between" your existing data points. Because of this, I decided to pre-vectorize my text inputs using GUSE. The other augmentation approaches considered here don't need the input to be pre-vectorized.
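Continuing the sketch above, SMOTE itself is essentially a one-liner on the pre-vectorized training split. I'm showing imbalanced-learn's implementation here as one common choice, not necessarily the exact library used in the notebooks.

```python
from imblearn.over_sampling import SMOTE

# Resample only the training split; X_train holds the 512-d GUSE vectors.
# SMOTE interpolates between a minority point and one of its k nearest
# minority neighbors to create synthetic points.
smote = SMOTE(random_state=42)
X_train_sm, y_train_sm = smote.fit_resample(X_train, y_train)
```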
The code for this experiment is divided into two notebooks.
- blog_text_augment_01.ipynb -- In this notebook, I split the dataset into a train/validation/test split of 70/10/20, and generate vector representations for each text message using GUSE. I also oversample the minority class (spam) by generating approximately 5 augmentations for each record with TextAttack, and generate their vector representations as well (a sketch of these steps appears after this list).
- blog_text_augment_02.ipynb -- In this notebook, I define a common network, which I retrain using PyTorch for each of the six augmentation scenarios listed above, and compare their accuracies.
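Here is a rough sketch of the vectorization and augmentation steps from the first notebook. The EmbeddingAugmenter is just one of several augmenters TextAttack provides; treat it as an illustrative choice rather than the exact recipe used in the notebook.

```python
import tensorflow_hub as hub
from textattack.augmentation import EmbeddingAugmenter

# GUSE produces a 512-d vector per input string.
guse = hub.load("https://tfhub.dev/google/universal-sentence-encoder/4")

# Generate ~5 augmented variants per spam message, then vectorize them.
augmenter = EmbeddingAugmenter(transformations_per_example=5)
texts = augmenter.augment("WINNER!! You have won a free entry to the contest")
vectors = guse(texts).numpy()   # shape: (num_augmentations, 512)
```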
Results are shown below, and seem to indicate that the oversampling strategies work best, both the naive one and the one based on SMOTE. The next best choice seems to be class weights. This seems understandable, since oversampling gives the network the most data to train with; that is probably also why undersampling doesn't work well. I was also a bit surprised that text augmentation did not perform as well as the other oversampling strategies.
However, the differences here are quite small and possibly not significant (note that the y-axis in the bar chart is truncated to 0.95-1.0, which exaggerates the differences). I also found that the results varied across multiple runs, probably as a result of different weight initializations. But overall, the pattern shown above was the most common.
Edit 2021-02-13: @Yorko suggested using confidence intervals to address the concern above (see comments below), so I collected the results from 10 runs and computed the mean and standard deviation for each approach across the runs. The updated bar chart above shows the mean values, with error bars of +/- 2 standard deviations around each mean. Thanks to the error bars, we can now draw a few additional conclusions. First, SMOTE oversampling can indeed give better results than naive oversampling. Second, the undersampling results can be highly variable.
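For reference, the error bars were computed along these lines. The accuracy values in this sketch are random placeholders, not my actual results.

```python
import numpy as np
import matplotlib.pyplot as plt

# Placeholder accuracies: one list of 10 run results per strategy.
rng = np.random.default_rng(0)
run_accuracies = {name: 0.97 + 0.01 * rng.standard_normal(10)
                  for name in ["baseline", "class weights", "undersample",
                               "oversample", "SMOTE", "text augment"]}

names = list(run_accuracies)
means = [np.mean(run_accuracies[n]) for n in names]
errs = [2 * np.std(run_accuracies[n]) for n in names]   # +/- 2 std devs

plt.bar(names, means, yerr=errs, capsize=4)
plt.ylim(0.95, 1.0)   # truncated axis, as in the chart above
plt.ylabel("accuracy")
plt.show()
```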
Comments:

Yorko: How about confidence intervals then?
Me: Thanks Yorko, that is a good idea, I will update the post with a bar chart with confidence intervals.