In this tutorial, we'll show you how to fine-tune two different transformer models, BERT and DistilBERT, for two different NLP problems: sentiment analysis and duplicate question detection.
Since first being developed and released in the Attention Is All You Need paper, Transformers have completely redefined the field of Natural Language Processing (NLP), setting the state of the art on numerous tasks such as question answering, language generation, and named-entity recognition. Here we won't go into too much detail about what a Transformer is, but rather how to apply and train one to help achieve a task at hand. The main things to keep in mind conceptually are that Transformers are very good at dealing with sequential data (text, speech, etc.), that they act as an encoder-decoder framework in which the encoder maps data to a representational space before the decoder maps it to the output, and that they scale incredibly well to parallel processing hardware such as GPUs.
Transformers in the field of NLP have been trained on massive amounts of text data, which allows them to understand both the syntax and semantics of a language very well. For example, the original GPT model, published in Improving Language Understanding by Generative Pre-Training, was trained on BooksCorpus, a collection of over 7,000 unique unpublished books. Likewise, the famous BERT model, released in the paper BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding, was trained on both BooksCorpus and English Wikipedia. For readers interested in diving into the neural network architecture of a Transformer, the original paper and The Illustrated Transformer are two great resources.
The main benefit of Transformers, and what we will explore throughout the rest of this blog, is that once pre-trained they can be quickly fine-tuned for numerous downstream tasks and often perform well right out of the box. This is primarily because a pre-trained Transformer already understands language, which lets fine-tuning focus on learning the specific task: question answering, language generation, named-entity recognition, or whatever other goal someone has in mind for their model.
Stanford Sentiment Treebank v2 (SST2)
The first task our models will be trained for is sentiment analysis. Sentiment analysis is a long-standing benchmark in the field of NLP with the goal of detecting whether some text is positive, negative, or somewhere in between. It has many use cases, such as detecting whether a product is viewed positively or negatively based on customer reviews, or whether a candidate has a high or low approval rating based on tweets. The dataset we will use to train a sentiment analysis model is the Stanford Sentiment Treebank v2 (SST2) dataset, which contains 11,855 movie review sentences. This task and dataset are part of the General Language Understanding Evaluation (GLUE) Benchmark, a collection of resources for training, evaluating, and analyzing natural language understanding systems.
Here are some examples from this dataset where numbers closer to 0 represent negative sentiment and numbers closer to 1 represent positive:
Quora Question Pairs (QQP)
The second task our models will be trained for is duplicate question detection. This task also has various use cases, such as removing similar questions from the Quora platform to limit confusion among users. The dataset we will use to train a duplicate question detection model is the Quora Question Pairs dataset. This task/dataset is also part of the GLUE Benchmark.
Here are a number of examples from this dataset, where 0 represents a non-duplicate pair and 1 represents a duplicate pair:
Two different Transformer-based architectures will be trained for the tasks/datasets above. Pre-trained models will be loaded from the HuggingFace Transformers Repo, which contains over 60 different network types. The HuggingFace Model Hub is also a great resource, containing over 10,000 different pre-trained Transformers covering a wide variety of tasks.
The first architecture we will train is DistilBERT which was open sourced and released in DistilBERT, a distilled version of BERT: smaller, faster, cheaper, and lighter. This Transformer is 40% smaller than BERT while retaining 97% of the language understanding capabilities and also being 60% faster. We will train this architecture for both the SST2 and QQP datasets.
The second architecture we will train is BERT published in BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. This was the first Transformer that really showed the power of this model type in the NLP domain by setting a new state-of-the-art on eleven different NLP tasks at the time of its release.
We will train this architecture for the SST2 dataset only.
With this background in mind, let's now take a look at the code and train/fine-tune these models! Here we use the PyTorch deep learning framework and only include code for the SST2 dataset. To run this code yourself, feel free to check out our Colab Notebook, which can be easily edited to accommodate the QQP dataset as well.
Creating the Dataset
First let's create our PyTorch Dataset class for SST2. This class defines three important functions with the following purposes:
- __init__: initializes the class and loads in the dataset
- __len__: gets the length of the dataset
- __getitem__: returns the item at a given index (batching and any random shuffling are handled by the DataLoader, not here)
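The three methods above can be sketched as follows. The class and attribute names here are our own, since the notebook itself isn't reproduced in this post; the split is assumed to expose "sentence" and "label" columns, as a HuggingFace Datasets split does:

```python
from torch.utils.data import Dataset


class SST2Dataset(Dataset):
    """A PyTorch Dataset over SST2 sentences and labels (hypothetical names)."""

    def __init__(self, split):
        # Load the sentences and their 0/1 sentiment labels from the split.
        self.sentences = split["sentence"]
        self.labels = split["label"]

    def __len__(self):
        # Number of examples in the split.
        return len(self.labels)

    def __getitem__(self, idx):
        # Return the sentence/label pair at the requested index.
        return self.sentences[idx], self.labels[idx]
```

A DataLoader built on top of this class then takes care of shuffling and batching; `__getitem__` only needs to return the example at a given index.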
Next let's create a couple of helper functions to do things like get the GPU, transfer data to it, etc. Neural networks, especially Transformer-based ones, nearly always train faster on accelerator hardware such as GPUs, so it is critical to send both the model and the data there when a GPU is available. This allows for a significant training speedup as parallel processing capabilities can be utilized.
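A minimal sketch of such helpers (the names are our own, not necessarily those used in the notebook):

```python
import torch


def get_device():
    # Prefer a CUDA GPU when one is available; otherwise fall back to the CPU.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def to_device(data, device):
    # Recursively move a tensor, or a list/tuple of tensors, to the device.
    if isinstance(data, (list, tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device, non_blocking=True)
```

With these in place, `model.to(get_device())` moves the model once at startup, and `to_device(batch, device)` moves each batch inside the training loop.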
Defining the Loss Function
Now we will define the loss function. Since we are training a classifier to predict whether a sentence has positive or negative sentiment, or whether two questions are duplicates, we will use the binary cross entropy loss function. The math behind this loss is:
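The equation itself isn't reproduced in this text version of the post; the standard form of binary cross entropy averaged over N examples is:

$$\text{BCE} = -\frac{1}{N}\sum_{i=1}^{N}\Big[\,y_i \log p(y_i) + (1 - y_i)\log\big(1 - p(y_i)\big)\Big]$$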
Here y is the true label (0 or 1) and p(y) is the model's predicted probability that the label is 1. By minimizing this value, our network learns to make more accurate predictions.
Model Training / Evaluation
Next let's write the core training/evaluation logic to fine-tune and test our model, which consists of three primary functions: train_model, train, and evaluate.
The train_model function works by first evaluating the pre-trained model on the validation set to measure performance before any training has taken place. It then loops over three epochs, training the model on the training set and evaluating its performance on the validation set. An epoch is one full pass over all the data in a dataset.
The train function trains the model for one epoch. Note that before any training takes place, our model is put into training mode, which tells PyTorch to use training-time behavior in layers such as dropout. All batches in an epoch are then looped over by iterating over the PyTorch DataLoader. Each batch is passed through the tokenizer, and the resulting tokens are sent to the model for sentiment score predictions. Following the de facto PyTorch training loop setup, a loss value is computed, the optimizer's gradients are zeroed out, gradients are derived from the loss via backpropagation, and the model is updated by taking an optimizer step.
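That de facto loop can be sketched as follows. To keep the sketch self-contained and quick to run, it uses a stand-in linear model on pre-tokenized tensor batches rather than the actual DistilBERT model and tokenizer, but the loop structure is the same:

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


def train(model, loader, optimizer, device):
    model.train()  # enable training-time behavior such as dropout
    total_loss = 0.0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        logits = model(inputs)                  # forward pass
        loss = F.cross_entropy(logits, labels)  # compute the loss
        optimizer.zero_grad()                   # clear gradients from the last step
        loss.backward()                         # backpropagate
        optimizer.step()                        # update the parameters
        total_loss += loss.item()
    return total_loss / len(loader)


# Toy stand-in for the transformer and the tokenized dataset:
model = torch.nn.Linear(8, 2)
data = TensorDataset(torch.randn(16, 8), torch.randint(0, 2, (16,)))
loader = DataLoader(data, batch_size=4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
avg_loss = train(model, loader, optimizer, torch.device("cpu"))
```

In the real notebook the forward pass goes through the tokenizer and DistilBERT instead of the linear layer, but the zero_grad / backward / step sequence is identical.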
The evaluate function has a similar setup to train, except the final gradient zeroing, backpropagation, and optimizer step are removed, since the model should not be trained on the validation set. Another difference is that here the model is set to evaluation mode, which disables training-time behavior such as dropout; running evaluation under torch.no_grad() additionally skips gradient bookkeeping, allowing for faster inference.
Built into both the train and evaluate functions is a call to count_correct, which computes the number of correct sentiment predictions per batch, allowing a final accuracy score to be derived across the entire dataset. Also note that softmax is applied to the model's output, mapping raw scores to probabilities.
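A sketch of what count_correct might look like (the function name comes from the post; this implementation is our own):

```python
import torch


def count_correct(logits, labels):
    # Map raw model scores to probabilities, then take the
    # highest-probability class as the prediction.
    probs = torch.softmax(logits, dim=1)
    preds = probs.argmax(dim=1)
    return (preds == labels).sum().item()


logits = torch.tensor([[2.0, -1.0], [0.5, 1.5], [3.0, 0.0]])
labels = torch.tensor([0, 1, 1])
count_correct(logits, labels)  # → 2: the first two predictions match, the third does not
```

Summing these counts over all batches and dividing by the dataset size gives the accuracy figures reported below.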
Putting it all Together
Now that we have defined all the functions needed to train our model, we can finally fine-tune it and see what happens! Note that SST2 is one of many datasets hosted in HuggingFace Datasets, making it incredibly easy to load and use.
To train a BERT model instead of DistilBERT use the following:
After fine-tuning both DistilBERT and BERT on the SST2 dataset for 3 epochs their performance was evaluated on the validation and test sets. Numbers below are accuracy scores averaged across 8 separate model training runs:
Although code for training on QQP is not shown in this blog, our Colab Notebook can easily be modified to accommodate this data. The primary changes are editing the PyTorch dataset to handle two text inputs to the model, question 1 and question 2, and adjusting the input to the tokenizer accordingly. Results from fine-tuning DistilBERT on QQP for 3 epochs, with performance evaluated on the validation set, can be seen below. Note that the accuracy score is averaged across 8 separate model training runs:
In this blog we learned how to fine-tune Transformers on downstream tasks, specifically sentiment analysis and duplicate question detection. Fine-tuning pre-trained Transformers saves significant time, and performance is often high right out of the box. By comparison, training from scratch takes longer and uses orders of magnitude more compute and energy to reach the same performance metrics.
Feel free to check out the Colab Notebook that comes with this blog to run the experiments yourself! Also, if you would like to download and use the models we have developed they can be found on the HuggingFace Model Hub at the following locations: