Advanced Recipes for Contrastive Learning
Author(s): Raj Sangani
Originally published on Towards AI.
Note: This blog assumes basic familiarity with dense retrievers, RAG systems, and contrastive learning. A very solid introduction can be found in the original DPR paper from FAIR.
Table of Contents
- Possible issues with traditional training methods for retrievers
- Massively increasing batch size using gradient caching (by a factor of 1000 on modern GPUs!)
- Further increasing batch size in multi-GPU settings
- FURTHER increasing the number of in-batch negatives through an alternative loss function
- Are in-batch negatives REALLY negatives? Filtering out false in-batch negatives
Terminology
For consistency, here are some terms that will appear in later sections:
Query or Anchor: The question for which we are retrieving documents.
Positive: The document which correctly answers the query.
Negative: Any unrelated document that does not answer the query.
Hard Negative: A document which is very similar to the query but does not contain the correct answer.
In-batch negatives: For every (query, positive) pair in a batch, the positive documents of all the other queries can be used as negatives, since they are unrelated to the query.
1. Issues with Dense Contrastive Learning
The objective we optimize in dense retrieval looks like this
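(The equation image from the original post is not reproduced here; below is a standard reconstruction of the in-batch objective, where sim(q_i, p_j) denotes the, possibly temperature-scaled, similarity between query i and document j in a batch of size B.)

$$\mathcal{L} = -\frac{1}{B}\sum_{i=1}^{B} \log \frac{\exp\big(\mathrm{sim}(q_i, p_i)\big)}{\sum_{j=1}^{B} \exp\big(\mathrm{sim}(q_i, p_j)\big)}$$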
The summation term in the denominator contains documents that are positives for other queries in the batch but are used as negatives for the (query, positive) pair in the numerator.
Several studies show that a large batch size B allows the use of more in-batch negatives and hence yields a better metric space.
Unfortunately, when using transformer models to embed long documents, we run out of GPU memory quickly, since long sequences take up huge amounts of activation memory. Here is a great blog that details memory usage in decoder models.
As a result, we have to resort to smaller batches, which is not helpful. Even turning to techniques such as gradient accumulation is of no use, since we need all the negatives in memory at once: the loss is inseparable across examples because of the shared denominator.
2. Enter Gradient Caching
The authors of the paper Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup identify this problem and solve it.
They observe that we can separate the backpropagation of the contrastive loss into two independent parts with respect to the batch examples:
- from loss to representation,
- from representation to model parameters
I would strongly encourage the mathematically curious reader to refer to section 3.2 of the paper for a detailed analysis of this independence.
Based on these observations, they list the following steps to compute the gradients:
1. Graph-less Forward
Before gradient computation, run an extra forward pass for each batch instance to get its representation. Importantly, this forward pass runs without constructing the computation graph. Collect and store all representations computed.
2. Representation Gradient Computation and Caching
Compute the contrastive loss for the batch based on the representations from Step 1, this time constructing the corresponding computation graph. A backward pass is then run to populate the gradients for each representation. Note that the retriever model is not included in this gradient computation.
3. Sub-batch Gradient Accumulation
Construct sub-batches of the batch (a sub-batch could be as small as a single instance). Run the model forward one sub-batch at a time to compute representations and build the corresponding computation graph. We take the sub-batch's representation gradients from the cache and run backpropagation through the encoder. Gradients are accumulated for the model parameters across all sub-batches.
4. Optimization
When all sub-batches are processed, step the optimizer to update the model parameters as if the full batch had been processed in a single pass.
At the cost of some training time due to the extra forward passes, we can now fit extremely large batches onto a GPU. I have personally used this method to increase batch size by a factor of 1024 with a sub-batch size of 1!
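To make the four steps concrete, here is a minimal PyTorch sketch of the idea. Names and signatures are illustrative assumptions: it assumes a single encoder that returns one embedding per input, inputs passed as dicts of tensors, and a dot-product similarity with a temperature. The authors also provide an implementation (GradCache) that handles many practical details omitted here.

```python
import torch
import torch.nn.functional as F

def grad_cache_step(encoder, optimizer, queries, docs, sub_batch_size=8, temperature=0.05):
    """One training step in the spirit of gradient caching (illustrative sketch).

    `queries` and `docs` are assumed to be dicts of equally sized tensors
    (e.g. input_ids, attention_mask) that can be sliced along dim 0.
    """
    def sub_batches(batch):
        n = next(iter(batch.values())).size(0)
        for i in range(0, n, sub_batch_size):
            yield {k: v[i:i + sub_batch_size] for k, v in batch.items()}

    # Step 1: graph-less forward pass to collect all representations.
    with torch.no_grad():
        q_reps = torch.cat([encoder(**s) for s in sub_batches(queries)])
        d_reps = torch.cat([encoder(**s) for s in sub_batches(docs)])

    # Step 2: compute the contrastive loss on the cached representations only,
    # and store the gradient of the loss w.r.t. each representation.
    q_reps = q_reps.requires_grad_()
    d_reps = d_reps.requires_grad_()
    scores = q_reps @ d_reps.T / temperature            # (B, B) similarity matrix
    labels = torch.arange(scores.size(0), device=scores.device)
    loss = F.cross_entropy(scores, labels)              # in-batch negatives
    loss.backward()                                     # fills q_reps.grad / d_reps.grad

    # Step 3: re-encode sub-batch by sub-batch WITH the graph and backpropagate
    # the cached representation gradients into the encoder parameters.
    for cached_grad, batch in ((q_reps.grad, queries), (d_reps.grad, docs)):
        offset = 0
        for s in sub_batches(batch):
            reps = encoder(**s)
            n = reps.size(0)
            reps.backward(cached_grad[offset:offset + n])  # accumulates parameter grads
            offset += n

    # Step 4: a single optimizer step, as if the full batch fit in memory.
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()
```

The extra graph-less forward pass in Step 1 is the price we pay; the memory-heavy backward pass only ever sees `sub_batch_size` examples at a time.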
3. Further increasing batch size through multiple GPUs
With the availability of multiple GPUs, say A GPUs, we can further increase batch size by a factor of A.
For each (query, positive) pair in a batch of size B, we have B-1 in-batch negatives. As we can see in the figure, if we compute representations on each GPU and then share them across GPUs, we use the same amount of memory per GPU but increase the number of negatives by a factor equal to the number of available GPUs!
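Here is a minimal sketch of this cross-GPU negative sharing with torch.distributed; the helper names are assumptions, it presumes distributed training is already initialized, and it computes the full loss on every rank for simplicity (in practice you may only keep the local query rows and let gradient averaging across ranks do the rest). Note that all_gather does not propagate gradients, so the local shard is swapped back in to keep the graph intact.

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

def gather_with_grad(local_reps: torch.Tensor) -> torch.Tensor:
    """Gather representations from every GPU; keep gradients for the local shard."""
    world_size = dist.get_world_size()
    gathered = [torch.zeros_like(local_reps) for _ in range(world_size)]
    dist.all_gather(gathered, local_reps)
    # all_gather returns detached tensors, so re-insert the local tensor to let
    # autograd flow into this rank's encoder.
    gathered[dist.get_rank()] = local_reps
    return torch.cat(gathered, dim=0)

def cross_gpu_in_batch_loss(q_reps, d_reps, temperature=0.05):
    """In-batch negatives pooled across all GPUs: A*B - 1 negatives per query."""
    all_q = gather_with_grad(q_reps)
    all_d = gather_with_grad(d_reps)
    scores = all_q @ all_d.T / temperature
    labels = torch.arange(scores.size(0), device=scores.device)
    return F.cross_entropy(scores, labels)
```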
4. Bidirectional Contrastive Loss
In-batch sampling is a very smart way of saving on memory, but can we do better? ABSOLUTELY!
The authors of the paper Towards General Text Embeddings with Multi-stage Contrastive Learning cleverly utilize the whole batch to effectively increase the number of negatives.
They propose a loss of the same form as the objective above, but with an enlarged partition function Z in the denominator.
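(The original equation images are not reproduced here; based on the description below, the loss takes roughly the following form, where s(·,·) is a temperature-scaled similarity and B is the batch size. The exact handling of the diagonal terms may differ slightly from the paper.)

$$\mathcal{L} = -\frac{1}{B}\sum_{i=1}^{B} \log \frac{\exp\big(s(q_i, p_i)\big)}{Z_i}$$

$$Z_i = \sum_{j} \exp\big(s(q_i, p_j)\big) + \sum_{j \neq i} \exp\big(s(q_i, q_j)\big) + \sum_{j} \exp\big(s(p_i, q_j)\big) + \sum_{j \neq i} \exp\big(s(p_i, p_j)\big)$$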
For a given (query, positive) pair, they use not only the traditional in-batch negatives but also
- similarities between the given query and all other queries in the batch
- similarities between the given positive and all other positives in the batch
- similarities between the given positive and all other queries in the batch
to increase the effective number of negatives (a PyTorch sketch of this loss appears after this list).
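Below is a minimal PyTorch sketch of such a bidirectional in-batch loss, assuming L2-normalized query and positive embeddings; the treatment of the self-similarity and duplicate-positive terms is a simplification of mine and may differ from the paper's exact formulation.

```python
import torch

def bidirectional_in_batch_loss(q_reps, p_reps, temperature=0.05):
    """Contrastive loss that also uses query-query, positive-positive and
    positive-query similarities as extra in-batch negatives (illustrative)."""
    B = q_reps.size(0)
    qp = q_reps @ p_reps.T / temperature        # query vs. positives
    qq = q_reps @ q_reps.T / temperature        # query vs. other queries
    pq = p_reps @ q_reps.T / temperature        # positive vs. other queries
    pp = p_reps @ p_reps.T / temperature        # positive vs. other positives

    eye = torch.eye(B, dtype=torch.bool, device=q_reps.device)
    # Self-similarities s(q_i, q_i) and s(p_i, p_i) are not negatives, so drop them.
    qq = qq.masked_fill(eye, float("-inf"))
    pp = pp.masked_fill(eye, float("-inf"))
    # s(p_i, q_i) duplicates the positive score; drop it to keep one positive per row.
    pq = pq.masked_fill(eye, float("-inf"))

    positives = qp.diag()                               # s(q_i, p_i)
    partition = torch.cat([qp, qq, pq, pp], dim=1)      # (B, 4B) candidate scores
    return -(positives - torch.logsumexp(partition, dim=1)).mean()
```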
This modified objective yields improvements in scores on the MTEB benchmark.
5. Can in-batch negatives be considered as REAL negatives?
So far, we have focused on increasing the batch size and finding more negatives within a batch, but there are often cases where the documents we sample as in-batch negatives for a (query, positive) pair are actually not unrelated to the query.
This makes them false negatives!
Thankfully, the author of GISTEmbed: Guided In-sample Selection of Training Negatives for Text Embedding Fine-tuning has a very simple fix for this.
The goal is to discard, from the loss in the previous section, the similarity scores of those candidates that score higher than the given (query, positive) pair.
One can use a stronger guide model (a pre-trained encoder model used as a cross-encoder) to calculate similarity scores for all the candidates in the partition function Z and mask out every score that is higher than the guide's score for the (query, positive) pair, because such candidate pairs can't be interpreted as negatives.
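A minimal sketch of this filtering for the plain in-batch case is shown below. For simplicity it scores candidates with precomputed embeddings from a frozen guide model (the argument names are mine); a cross-encoder guide, as described above, would instead score each (query, candidate) pair jointly, and GISTEmbed itself has more detailed masking rules, so treat this purely as an illustration of the idea.

```python
import torch
import torch.nn.functional as F

def guided_in_batch_loss(q_reps, d_reps, guide_q, guide_d, temperature=0.05):
    """In-batch contrastive loss where candidates that the guide model rates as
    similar as (or more similar than) the true positive are masked out."""
    scores = q_reps @ d_reps.T / temperature          # model being trained
    with torch.no_grad():
        guide_scores = guide_q @ guide_d.T            # frozen guide model's view
        pos_scores = guide_scores.diag().unsqueeze(1) # guide score of (q_i, p_i)
        # Candidates the guide rates >= the true positive are likely false negatives.
        false_negatives = guide_scores >= pos_scores
        false_negatives.fill_diagonal_(False)         # never mask the positive itself
    scores = scores.masked_fill(false_negatives, float("-inf"))
    labels = torch.arange(scores.size(0), device=scores.device)
    return F.cross_entropy(scores, labels)
```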
Closing Thoughts
Note that the ideas mentioned above are intended to squeeze every last bit of performance out of your retrieval system, but the core power of such a system always lies in the question-document pairs themselves.
Training with a dataset that has a diverse set of questions and documents that are clean and unambiguous is the bread and butter for any retrieval system!
On a personal note, I really enjoyed writing this blog since this is my first blog after almost two years! I hope you enjoyed reading it!
Check out my GitHub for some other projects. You can contact me here. Thank you for your time!
If you liked this, here are some more of my articles (on Towards Data Science):
- A Comprehensive Guide on Model Calibration: What, When, and How (Part 1: Learn about calibrating machine learning models to obtain sensible and interpretable probabilities as outputs)
- Overfitting is not the only problem Regularisation can help with (Understanding mathematically how Ridge Regression helps in cases where the number of features exceeds data points)
- Dealing with features that have high cardinality (A simple utility I use to address categorical features with many unique values)
References
- Dense Passage Retrieval for Open-Domain Question Answering https://arxiv.org/abs/2004.04906
- RocketQA: An Optimized Training Approach to Dense Passage Retrieval for Open-Domain Question Answering https://arxiv.org/abs/2010.08191
- Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup https://arxiv.org/abs/2101.06983
- Towards General Text Embeddings with Multi-stage Contrastive Learning https://arxiv.org/abs/2308.03281
- GISTEmbed: Guided In-sample Selection of Training Negatives for Text Embedding Fine-tuning https://arxiv.org/abs/2402.16829