RAG Research Paper Explained: Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks
Last Updated on January 3, 2025 by Editorial Team
Author(s): Aman Agrawal
Originally published on Towards AI.
Step Inside the World of RAG for a Detailed Breakdown of Its Core Components and Advanced Fine-Tuning Strategies
RAG is one of the most hyped terms in the AI and LLM domain, and many articles tell us about advanced techniques in Retrieval-Augmented Generation (RAG), but few explore what happens behind the scenes. If you've never encountered an article explaining the mechanism of retrieval and generation, this one is for you. This article starts from scratch, so it is going to be a long one in which we understand the whole mechanism in depth. I am not going to write up yet another variant of RAG (adaptive, corrective, contextual, order-preserving, etc.); there is plenty of deep information available on those. Instead, I have written down the training mechanism that RAG used in the original research paper, Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks.
I have written this article with the motivation of explaining the paper in simple language, with mathematical intuition (loss functions, backprop) and simple examples. I also hope that maybe someday someone will tweak the pre-training mechanism of RAG itself instead of building a wrapper on vanilla RAG or crafting another variant, because I believe more in engineering the retriever and generator mechanisms than in feeding better English to the generator model for better accuracy. RAG is not that scalable in production, when apps are used by millions of users, because it comes with the cost of high token usage; some extra engineering across the whole RAG pipeline is needed to bring that cost down, and that can only happen when one has a basic understanding of how the pre-training and fine-tuning of vanilla RAG, around which the variants are built, has been done.
This will be a detailed breakdown of the two mechanisms discussed in the RAG research paper:
- RAG sequence
- RAG token-based
In this article, we discuss the two main parts of RAG separately: the retriever and the generator. We explain how each part works and how they come together to make RAG function. We'll also cover how these systems are trained and improved through fine-tuning and pre-training, and how errors are managed through loss functions and backpropagation. The goal is to cover everything in depth, using simple examples (I have taken the help of ChatGPT to create examples at various points, and the article mentions wherever an AI tool was used to generate text and related examples) to help understand each part of the process clearly. The concepts behind fine-tuning the various RAG techniques remain more or less the same, so understanding the mechanism of vanilla RAG as discussed in the original paper, Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks, will develop the core concepts.
So, let's go ahead with this. It is a long read, so have patience and read it through to the end; it will give you a lot of perspective on how RAG works, inside and out.
Let's go step by step: first we develop the intuition behind Retrieval-Augmented Generation and RAG fine-tuning, and then we go into the details of joint pre-training and fine-tuning of RAG.
Traditional vanilla RAG feeds the generative model retrieved context (from an external database) along with the input, so that the model produces a better answer than it would alone, without any retrieved context.
So, using RAG, the generated response should be better than that of a purely generative model without retrieval, because we are giving the model the right context and the right kind of information for generating the answer against a query.
The key difference between the two models is the input being fed to the generative model:
- In a purely generative model, the input is "Query" (only the model's internal knowledge, stored in the form of weights).
- In RAG, the input is "Context + Query" (retrieved context + the model's internal knowledge).
Architecture-wise, the difference is the addition of a retriever in RAG, which gives the generative model context from non-parametric memory (an external database).
Thinking a step further: what else can we do to enhance the performance of this generative model?
Tune RAG on external (non-parametric) memory for our task/dataset. We have often heard of fine-tuning an LLM on our task, and also of using RAG over our external dataset. But have you ever heard of fine-tuning + RAG together on our task? Why not, if the task is very complex or we need better accuracy than just fine-tuning or just RAG? We can combine the powers of fine-tuning and RAG and achieve wonders.
We can do RAG fine-tuning (it comes at the cost of complex engineering for the RAG setup, plus the compute required for fine-tuning).
But it's essential to understand when to use what. If the answers of a purely generative model suffice, the model's internal knowledge is sufficient for our task and we don't need RAG; it completely depends on the task. If the generation task is very general, the knowledge of the generative model alone is enough, but if the task is knowledge-intensive, meaning the answers must come from some dedicated external knowledge base, then the model's own knowledge falls short and we have to pave our way towards RAG, and maybe RAG fine-tuning (if the task is very complex).
A short example addressing why and how to do RAG fine-tuning...
Suppose we have a special type of task that requires generating answers not just from what a model knows through its stored weights, but from a specific, detailed external source, like a medical database. This task is complex because the answers need to come from detailed, specific information not found in general knowledge, which means the task is knowledge-intensive.
Initially, we tried using a standard generative model trained on a broad range of topics. As expected, it didn't perform well, because it had no access to the specialised knowledge in our external database, which is crucial for answering these specific queries.
To improve our results, we turned to a Retrieval-Augmented Generation (RAG) approach. By integrating RAG, we enhance our model's answers by pulling relevant information from the external database right before generating a response. This added context can significantly boost the model's ability to produce accurate answers.
But we saw a chance to do even better by fine-tuning the RAG model directly on our external database. Since all our queries would come from this specialised source, it made sense to optimise the RAG system for it. For this fine-tuning, we didn't need to label the documents for retrieval, just the questions and their correct answers. This made the fine-tuning process much more straightforward and required far less labelling effort: only (query, correct answer) pairs over the database.
So we created a labelled dataset specifically for this fine-tuning: pairs of queries and their correct answers. This set was all we needed to fine-tune our RAG model, with no need to tag which documents the answers came from (unless we wanted to, which would have been a much bigger task for only a slight performance improvement).
With this labelled dataset, we're good to go for RAG fine-tuning on our data.
Fine-Tuning in RAG
Fine-tuning in Retrieval-Augmented Generation (RAG) is essential to optimise the system for specific tasks and improve both the retrieval and generation components.
Let's break down the complex concepts step by step with the help of examples and cover all the aspects involved. To get started, it's important to have an overview of certain components and technical jargon, so let's first get an overview of fine-tuning, the retriever, the generator, the Dense Passage Retriever, and the various other components involved.
1. What is Fine-tuning in RAG?
Fine-tuning in RAG involves updating the parameters of its two main components:
a. Retriever: improves retrieval quality, ensuring the most relevant passages are selected from the external database (generally, labels for this part are not created for RAG fine-tuning).
b. Generator: ensures the generative model (seq2seq, as in the research paper) uses the retrieved passages effectively to produce accurate, context-aware answers. The seq2seq (BART) generative model is used here, as in the research paper; otherwise, a decoder-only model like GPT could also have been used.
We mentioned the original RAG setup before: in the original RAG paper, a Dense Passage Retriever (DPR) is used. The use of DPR is the reason we don't require extensive labelling of the correct document from the external database for answer generation. The DPR is a model pre-trained specifically on the task of retrieving correct documents, and this pre-trained retriever is then used in the joint RAG fine-tuning pipeline for retrieval. The DPR used in the original RAG setup was based on BERT-base.
So, let's see how the pre-training of the DPR happens...
The pre-training of the DPR involves training two separate BERT-based encoders, one for processing queries and another for passages, on a dataset containing known "positive" and "negative" passages. Positive passages are those that directly answer the query, while negative passages do not. This training enables the model to effectively rank passages by their relevance to a given query.
By training the DPR encoders to differentiate between these positives and negatives, the model learns to identify the most relevant passages when faced with real queries from users. Once this pre-training is complete, the DPR can be integrated into the retrieval component of the RAG system, enhancing its ability to retrieve accurate and contextually appropriate information before passing it to the generative model for answers.
This focused approach ensures that the RAG system not only generates responses but does so based on the most relevant information available from the external database.
If we completely skip the initial retrieval pre-training step and go straight into RAG fine-tuning without a pre-trained retriever, the retrieval part may start out very poorly. The generator would not receive good, relevant information at all, making the entire system less effective and training much slower. The original idea behind RAG is to use a strong retrieval model (like a DPR already trained on a suitable dataset) as the starting point.
Throughout this article, we've used the term "joint RAG fine-tuning" frequently. It's important to note that the term might sound like a simultaneous update of the weights of all the components of the RAG pipeline. However, in the original study, while the components are fine-tuned jointly, only the weights of the query encoder and the generator were updated. The weights of the document encoder (passage encoder) remained fixed, because a pre-trained retriever was used.
The idea of updating all components simultaneously (document encoder, query encoder, and generator) does exist but would be computationally intensive, and it was not practised in the original RAG paper. The phrase "joint fine-tuning" may imply this broader application, but in practice the fine-tuning was more selective, to manage computational resources effectively.
This distinction is crucial for understanding practical implementations of RAG and aligning expectations with the original research outcomes. While theoretically possible, full joint fine-tuning across all RAG components is more complex and resource-demanding than the methods actually employed. A BERT-base model or a normal sentence transformer could also be used as a retriever, but DPR performs better, as it is pre-trained on the retrieval task.
Let's get into a bit more detail about the two components of RAG, summing up the flow with the help of examples.
a. Retriever
- Model Used: often Dense Passage Retrieval (DPR), which embeds queries and documents into the same dense vector space for similarity-based retrieval. A BERT-type model is used for this; sentence transformers or any encoder-only model could also be used (ideally pre-trained for retrieval before being placed in the RAG pipeline).
How Retriever is Fine-Tuned (Before RAG set-up):
Provide training examples where a query is paired with the following:
- Positive passages: Relevant passages that contain the correct answer.
- Negative passages: Irrelevant passages to penalise incorrect retrieval.
The retriever learns to maximise the similarity for positives and minimise it for negatives.
Example:
Query: "Who discovered penicillin?"
- Positive passage: "Alexander Fleming discovered penicillin in 1928."
- Negative passage: "Penicillin is an antibiotic used to treat infections."
Result: Retriever improves its embeddings for better passage ranking.
(The above example has been generated with the help of the AI tool ChatGPT)
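To make the mechanics concrete, here is a minimal PyTorch sketch of this DPR-style contrastive step, using the penicillin example above. It is an illustrative assumption, not the paper's actual training code: the checkpoint names, the single-pair batch, and the [CLS] pooling are my choices.

```python
# Hedged sketch of DPR-style contrastive training: two BERT encoders,
# dot-product similarity, softmax over (positive + negative) passages.
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
query_encoder = AutoModel.from_pretrained("bert-base-uncased")    # BERT for queries
passage_encoder = AutoModel.from_pretrained("bert-base-uncased")  # BERT for passages

def embed(encoder, texts):
    batch = tok(texts, padding=True, truncation=True, return_tensors="pt")
    return encoder(**batch).last_hidden_state[:, 0]  # [CLS] vector as the embedding

queries   = ["Who discovered penicillin?"]
positives = ["Alexander Fleming discovered penicillin in 1928."]
negatives = ["Penicillin is an antibiotic used to treat infections."]

q = embed(query_encoder, queries)                  # (B, d)
p = embed(passage_encoder, positives + negatives)  # (2B, d), positives listed first
scores = q @ p.T                                   # dot-product similarity, (B, 2B)
labels = torch.arange(len(queries))                # for query i, passage i is positive
loss = F.cross_entropy(scores, labels)             # maximise positive, minimise negatives
loss.backward()                                    # gradients flow into both encoders
```

In a real run there would be many pairs per batch (so every other passage in the batch acts as an extra negative), but the objective is the same softmax over similarity scores.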
b. Generator (This happens during RAG fine-tuning):
To make this clear: the weights of both the retriever and the generator are free to be adjusted during fine-tuning, but in the original RAG setup the retriever is already pre-trained on the retrieval task, while the generator is being fine-tuned for the first time on our correct labels y against the input queries x.
Model Used: a sequence-to-sequence model like BART or T5 (or a decoder-only model like GPT), depending on the task. (BART in the original paper.)
How It's Fine-Tuned:
- Provide training data where the input consists of:
- Input Query (x) + Retrieved Passages (DPR)
- The target output y is the expected answer.
The generator learns to combine the retrieved knowledge and generate fluent text.
Example:
- Input: (Query: "Who discovered penicillin?" + Retrieved Passage: "Alexander Fleming discovered penicillin in 1928.")
- Target Output: "Alexander Fleming discovered penicillin."
(The above example has been generated with the help of the AI tool ChatGPT)
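Here is a similarly hedged sketch of one generator fine-tuning step with BART, where the input is the query concatenated with a retrieved passage and the target is the answer. The "facebook/bart-base" checkpoint and the plain "query + context" formatting are assumptions for illustration, not a prescription from the paper.

```python
# Hedged sketch: one fine-tuning step of a seq2seq generator (BART).
# Input = query concatenated with the retrieved passage; target = the answer.
from transformers import BartTokenizer, BartForConditionalGeneration

tok = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

source = ("Who discovered penicillin? "
          "context: Alexander Fleming discovered penicillin in 1928.")
target = "Alexander Fleming discovered penicillin."

inputs = tok(source, return_tensors="pt")
labels = tok(target, return_tensors="pt").input_ids

# Passing labels makes the model run with teacher forcing internally and
# return the token-level cross-entropy loss against the ground truth.
out = model(**inputs, labels=labels)
out.loss.backward()   # gradients update the generator's weights
```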
Joint Fine-Tuning in RAG
RAG supports joint fine-tuning, where both the retriever and generator are optimised simultaneously:
Process:
- Retrieve passages for a query.
- Generate an answer using the passages.
- Compute the loss:
Retrieval Loss: based on whether the correct document was retrieved.
Here, the "correct doc" does not mean a labelled document, because we do not label correct documents in the original RAG setup. The retriever's weights are trained via the gradient signals from the generator's success or failure in producing the correct answer for an input query (and again, in the original setup, the passage encoder's weights are fixed). This is detailed in the next few sections.
Generation Loss: based on how well the generated answer matches the ground truth (y).
Example:
Query: "When did the Wright brothers invent the aeroplane?"
- Retrieved Passages: ["The Wright brothers flew the first aeroplane in 1903."]
- Target Output: "The Wright brothers invented the aeroplane in 1903."
Fine-tuning optimises:
- Retriever: improves focus on passages mentioning "Wright brothers" and "1903."
- Generator: learns to combine the retrieved content to generate the correct answer with whatever the model has learned from training. (A toy sketch of this joint loss follows below.)
(The above example was made by the author)
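To see how the retriever can learn without any document labels, consider this toy sketch of the marginalised loss (all tensors are made-up stand-ins, not the paper's code). The gradient reaches the query encoder only through P(z|x), while the document embeddings stay frozen, mirroring the original setup.

```python
# Toy sketch of the joint RAG-Sequence loss: gradients reach the query
# encoder only through P(z|x); the passage (document) index stays frozen.
import torch

torch.manual_seed(0)
d = 8                                    # embedding size (made up)
query_encoder = torch.nn.Linear(d, d)    # stand-in for the BERT query encoder
doc_embs = torch.randn(3, d)             # frozen passage-encoder outputs (no grad)

x = torch.randn(d)                       # stand-in embedding of the input query
q = query_encoder(x)                     # trainable query representation
log_p_z = torch.log_softmax(doc_embs @ q, dim=0)   # log P(z|x) over 3 docs

# Stand-in per-document log-likelihoods log P(y*|x,z) from the generator;
# in a real model these would also carry gradients into the generator.
log_p_y_given_z = torch.tensor([-0.5, -2.0, -3.0], requires_grad=True)

# log P(y*|x) = logsumexp_z [ log P(z|x) + log P(y*|x,z) ]  (marginalisation)
loss = -torch.logsumexp(log_p_z + log_p_y_given_z, dim=0)
loss.backward()                          # one gradient step for the whole pipeline
print(query_encoder.weight.grad.abs().sum() > 0)   # the retriever got a signal
```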
Aspects of Fine-Tuning in RAG
Data Requirements
Fine-tuning requires labelled data with the following:
- Queries.
- Ground-truth answers.
- A pre-trained retriever is preferred.
Advantages of Fine-Tuning RAG
It allows dynamic knowledge updating by retraining the retriever on new data without retraining the generator. This does not require heavy computation: we already use pre-trained models, which carry a lot of knowledge in their weights, and the external database becomes a plus that lets them perform well on our task.
Example Use Case of RAG Fine-Tuning
Task: Open-domain QA for medical questions.
Step 1: Fine-tune the retriever:
- Provide medical questions and passages from a medical knowledge base.
- Optimise retrieval for relevant medical information.
Step 2: Fine-tune the generator:
- Use retrieved medical text to train the generator to produce accurate, patient-friendly answers.
Input:
Query:
- "What are the symptoms of diabetes?"
Retrieved:
- "Diabetes symptoms include frequent urination, thirst, and fatigue."
Generated Output:
- "The symptoms of diabetes include frequent urination, excessive thirst, and feeling tired."
(The above example of query, retrieval and generation has been generated with the help of the AI tool ChatGPT)
Summary
Fine-tuning in RAG is about:
- Training the retriever to fetch relevant information effectively (before joint RAG fine-tuning).
- Training the generator to use retrieved content to generate answers.
This process adapts RAG to specific domains or tasks, making it highly effective for knowledge-intensive applications.
Let's discuss another very important aspect of the setup: teacher forcing.
Since we don't create labels for document retrieval for the (query, ground-truth answer) pairs, and our training data contains only (query, answer) pairs, one should ask what happens if the generative model gets the wrong document during pre-training or fine-tuning, that is, when an irrelevant document that does not contain the needed information is fed to the generative model to generate the correct answer. This is where the concept of teacher forcing comes in.
Let's clarify how teacher forcing works in the context of RAG and what happens when incorrect passages are retrieved.
1. How RAG Handles Incorrect Passages
During joint fine-tuning in RAG:
- The retriever retrieves passages, and these are passed to the generator as they are, even if they are incorrect, meaning they do not contain the needed information.
- The generator tries to produce the correct output based on the retrieved passages. If the passages are incorrect, it still tries to generate an answer but will likely perform poorly.
2. Role of Teacher Forcing
Teacher forcing is typically used in the generator's training phase:
- It involves providing the ground-truth answer (correct output) at each training step, even if the retrieved passages are wrong or incomplete.
- This helps the generator learn to generate the correct answer pattern regardless of the retrieval quality; it learns to deal with noise.
3. Why Use Incorrect Passages?
Allowing the generator to work with the retrieved passages, even if they are wrong, is important because:
- It reflects the real-world setting, where retrieval is not always perfect.
- The generator learns to handle noisy or incomplete information better, which can be considered a good thing.
If correct passages were directly substituted during training (bypassing the retrieverβs mistakes), the generator wouldnβt learn how to handle real-world retrieval errors.
The generator side:
- For instance, let's say the passage retriever was not able to give relevant documents/context to the generative model.
- The generative model attempts to generate the answer with whatever context is given to it.
- The correct answer is provided via teacher forcing (from the ground-truth labels on the generator's side) to compute the loss and update the weights. If the generator cannot produce the correct answer from the retrieved passages, a heavy penalty is imposed, and that penalty also tells the RAG mechanism that the retriever is not retrieving properly, so its weights get updated too. (In the original RAG paper, the passage encoder's weights are fixed while the query encoder's weights are free, so it is the query encoder that gets updated when a heavy loss results from inefficient retrieval.) A toy sketch of this teacher-forcing step follows below.
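Here is that tiny toy illustration of teacher forcing (made-up tensors, not the actual RAG code): the decoder consumes the ground-truth prefix at every position, and the loss is computed against the ground-truth tokens even when the retrieved context was useless.

```python
# Toy sketch of teacher forcing at the generator: the decoder input is the
# ground-truth sequence shifted right, regardless of retrieval quality.
import torch
import torch.nn.functional as F

vocab = ["<s>", "He", "goes", "to", "school", "house"]
gold = torch.tensor([1, 2, 3, 4])           # "He goes to school"
decoder_input = torch.tensor([0, 1, 2, 3])  # "<s> He goes to": what a real
                                            # decoder would consume at each step

# Stand-in decoder output: logits over the vocab at each of the 4 steps. With
# a bad retrieved passage the model may prefer "house" (index 5) at the end.
logits = torch.randn(4, len(vocab), requires_grad=True)

# Cross-entropy against the gold tokens: a wrong high-confidence prediction is
# penalised heavily, and that signal flows back through the whole pipeline.
loss = F.cross_entropy(logits, gold)
loss.backward()
```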
Two methods of RAG fine-tuning are mentioned in Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks (the original RAG research paper); a short loading sketch for both follows this list:
- RAG sequence
- RAG token method
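Both variants are available off the shelf in the Hugging Face transformers library, and the sketch below shows how they are wired up. It uses the library's small dummy index (which additionally needs the datasets and faiss packages) so it runs without downloading the full Wikipedia index; swap in RagTokenForGeneration and "facebook/rag-token-nq" for the token-level variant.

```python
# Sketch: loading the two RAG variants from Hugging Face transformers.
# use_dummy_dataset=True swaps in a tiny toy index instead of full Wikipedia.
from transformers import RagRetriever, RagSequenceForGeneration, RagTokenizer

tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
retriever = RagRetriever.from_pretrained(
    "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
)
model = RagSequenceForGeneration.from_pretrained(
    "facebook/rag-sequence-nq", retriever=retriever
)
# For the token-level variant, the analogous pieces would be
# RagTokenForGeneration and the "facebook/rag-token-nq" checkpoint.

inputs = tokenizer("who discovered penicillin?", return_tensors="pt")
ids = model.generate(input_ids=inputs["input_ids"])
print(tokenizer.batch_decode(ids, skip_special_tokens=True))
```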
RAG Sequence Model
RAG-Sequence: this method treats the entire output generation as influenced by a single set of retrieved documents per input query; the set of top-k docs stays the same until the whole response to the query has been generated. The probabilities of the responses generated from each document are combined (marginalised) to produce the final output. More on this, in detail and with examples, below.
In the context of the RAG-Sequence method, marginalisation refers to combining the outputs generated from each of the top-k documents into a single, final output (refer to the diagram of the RAG-Sequence model above). This involves computing the overall probability of the sequence from the individual contributions of each document-conditioned generation, combining over the top-k docs to produce the final answer y (from the model), which is then compared against y* (the ground-truth label) in the loss function.
How Marginalisation Works in RAG-Sequence:
- Generate Individual Outputs: the model first generates a separate output for each of the top-k documents retrieved by the retriever (k is set by us: how many of the total documents we feed to the generator, usually k = 5 to 10). Each doc has an associated probability P(z|x) of being relevant to the input query based on the content of the corresponding document. The intuition for P(z|x) is the cosine similarity between the query embedding and the document embedding.
- Compute Combined Probability: marginalisation then takes these probabilities and combines them to determine the final output probability. See the diagram above: after retrieval of z1, z2, z3 from the external database, corresponding answers y1, y2, y3 are produced from the docs with scores P(z1|x), P(z2|x), P(z3|x), and then the responses y1, y2, y3 are marginalised to get the final output y.
If this P(z|x) has confused you, it will become clear in the next section of this article; for now, one just has to understand what happens in the RAG-Sequence and RAG-Token mechanisms.
Purpose:
The purpose of marginalisation is to ensure that the final output is not dependent on just a single document but reflects the collective evidence and relevance of all the top documents considered. This makes the final decision more versatile and less biased towards anything particular, or, as we say, it helps reduce overfitting.
Equation of RAG Sequence Probability
The equation describes the probability computation for generating a sequence y given an input x using the RAG-Sequence method. It involves integrating out (or marginalising over) all the possible documents z retrieved based on the input.
Let's break down the equation into more understandable parts:
Equation Breakdown
- The RAG-Sequence probability is:
P(RAG-Sequence)(y|x) ≈ ∑(z ∈ top-k) Pm(z|x) × ∏i Pθ(yi | x, z, y1:i−1)
- This equation states that the probability of generating the sequence y given x in the RAG-Sequence model is approximated by summing over the top-k documents z that are most relevant to x.
- Pm(z|x): probability that a document z is relevant to x, as determined by the retriever model (the cosine/dot product between the query embedding and the embeddings of the documents stored in the vector database).
- Pθ(y|x,z): probability of generating the sequence y, given input x and document z, as determined by the generator model.
- The equation is the same as what is explained above; the only difference is that attention is laid on the generative part: the Pθ term is expanded to show how the next word is predicted given all previous tokens (before the current timestep), the input and the retrieved context.
- ∏i Pθ(yi | x, z, y1:i−1) shows the probability of generating each token yi in the sequence.
The model's output response (y) depends on the following:
- x: the input
- z: the retrieved document
- y1:i−1: all the previous tokens generated in the sequence
This is typical of sequence-generation models, where each next word depends on all the previous words (the auto-regressive property, at inference time; remember, training of the generative model is not auto-regressive, it is done through teacher forcing).
Overall, this method allows the RAG-Sequence model to generate responses that are well-informed by relevant external information, making it powerful for tasks where context and detail are important.
Summing up the flow, loss functions and example in the RAG sequence model
It has been a lot to digest. Let's sum up the flow once; we will then discuss the loss functions involved and walk through the whole flow with the help of an example.
1. Retrieval of Top k Documents:
- For a given input x, the model retrieves the top k most relevant documents from the non-parametric memory (in the example and diagram above, k = 3 with documents z1, z2, z3).
- The relevance of each document z to the input x is determined by a semantic search or matching score, P(z | x), which could be computed using embedding similarities: the cosine/dot product between the query vector and the document embedding vector.
2. Generation of the Sequence:
For each document z, the model computes the probability of generating the entire sequence y token by token, where each token generation is conditioned on:
- The input x,
- The retrieved document z,
- All previously generated tokens in the sequence.
The sequence generation for a single document z1, with the term expanded, looks like: P(y0 | x, z1, start token) × P(y1 | x, z1, y0) × … × P(yN | x, z1, y0, y1, …, yN−1). This is the generative part.
By the counting principle, to calculate the probability of the whole sequence, that is, the simultaneous occurrence of all its tokens, we multiply the terms.
3. Marginalisation Over Documents:
The final probability of the sequence y given x is calculated by marginalising over the sequences generated for each document, that is, summing the expanded terms (shown above for P(z1|x)).
For all top k = 3 docs (z1, z2, z3 in the example image above), the full equation looks like:
P(RAG-Sequence)(y|x) = P(z1|x) × P(y0 | x, z1, start token) × P(y1 | x, z1, y0) × … × P(yN | x, z1, y0, y1, …, yN−1)
+ P(z2|x) × P(y0 | x, z2, start token) × P(y1 | x, z2, y0) × … × P(yN | x, z2, y0, y1, …, yN−1)
+ P(z3|x) × P(y0 | x, z3, start token) × P(y1 | x, z3, y0) × … × P(yN | x, z3, y0, y1, …, yN−1)
4. Generator's Role and Softmax:
During sequence generation, at each token-prediction step, the model's generator (a sequence-to-sequence model like BART, as in the original paper) uses its last layer's softmax to determine the most probable next token based on the current context (input, selected document, and previously generated tokens). This section is covered in detail under loss functions.
5. Overall Probability Representation:
The equation outlined represents how the probability of the generated sequence y given the input x is computed in the RAG-Sequence model. It captures the entire mechanism of how each component (retriever and generator) contributes to the modelβs output in a probabilistic framework.
More on the intuition via the counting principle that we study in permutations and combinations
The intuition behind the summation sign (∑) at the start of the equation and the multiplication (∏) after the retriever probability should now be clear. It is analogous to the counting principle in permutations and combinations. In the RAG-Sequence model we have k paths (the top-k docs from the non-parametric memory Z), so we use SUM (logical OR); and because we have to calculate the probability of the whole sequence (every word the generator generates) simultaneously, considering all the words that complete the sentence, we use MULTIPLICATION (logical AND). This gives the marginalised probability of the RAG-Sequence model using the top-k docs from Z and generating a sentence of sequence length N.
We have understood the equation breakdown of the RAG-Sequence model; now let's understand the significance of the number it gives, P(y|x). It is very important to understand the intuition behind this number, as it represents the likelihood that we need to maximise to get the best weights (or whose negative log we minimise). Understanding this intuition will help in understanding the loss function used during RAG training/fine-tuning. (The model's response is y; the correct ground truth is y*.)
The probability p(RAG-Sequence) in the RAG-Sequence model represents the likelihood that the sequence y is the appropriate response to the given input query x.
Here's what this probability number signifies:
- In training, the negative log of p(RAG-Sequence) (the NLL) is directly used as the training loss, enabling effective backpropagation during the model's learning phase.
- In inference, a higher p(RAG-Sequence) means more relevance, a better fit, and stronger model confidence.
Now, let's understand the training, backprop and loss-function part of the RAG-Sequence model. For that, one should first know how training works in a generator-only model; it then becomes easy to understand the intuition behind maximising the p(RAG-Sequence) probability (that is, maximising the likelihood), how teacher forcing is involved in generator-only models, and how the comparison with ground-truth labels is made.
I will attach my handwritten notes, which show in detail, with examples, how the loss functions work in generator-only, RAG-sequence-based, and RAG-token-based setups, at the end of this article before the CONCLUSION header. (I will share a drive link containing the notes as a PDF.)
Training Process and Loss Function
In a RAG model, training involves fine-tuning both the retriever (in the original paper, the query encoder's weights are trainable and the passage encoder's are fixed) and the generator to optimise sequence generation based on the retrieved documents. The typical loss function used in this context is the Negative Log Likelihood (NLL), which is expected, since the whole process concerns generating the output sequence; the loss used is indeed the cross-entropy log loss once the y* and y tokens are compared. I have attached my notes below for an in-depth understanding of this.
Loss Function:
The loss function for a training instance when using the RAG model is the negative log-likelihood (log loss) of the correct sequence given the input. Mathematically, this is expressed as:
Loss = -log P(y* | x)
Where:
- y* is the correct (ground truth) sequence for the given input x.
- P(y* | x) is the probability of generating the correct sequence calculated by the model.
It can be interpreted as the model's conditional probability of predicting the correct sequence, given input x.
The model predicts a probability (between 0 and 1) for every token in the generator's vocabulary, but for the loss function we penalise the model's prediction against the correct label: the basic principle of categorical cross-entropy loss.
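A two-line toy illustration of that principle: the softmax spreads probability mass over the whole vocabulary, and the cross-entropy simply picks out the negative log of the probability assigned to the correct token.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0]])   # toy scores over a 3-word vocab
gold = torch.tensor([0])                    # index of the correct token
loss = F.cross_entropy(logits, gold)        # = -log softmax(logits)[0, gold]
print(loss, -torch.log_softmax(logits, dim=-1)[0, 0])  # the two values match
```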
This probability is detailed as follows:
P(y* | x) = ∑z P(z | x) × ∏i P(yi* | x, z, y*1:i−1)
Where:
- P(z | x) is the probability of retrieving document z given x, determined by the retriever.
- P(yi* | x, z, y*1:i−1) is the probability of generating the i-th token of the sequence given the input, the retrieved document, and all previous correct tokens. In mathematical terms, this term signifies the conditional probability of the model generating yi* (the correct token at the i-th position of the sequence) given the input, the retrieved document and the correct prefix. If one thinks about it properly, this is indeed what the cross-entropy log loss suggests.
Back-propagation and Parameter Updates:
During training, back-propagation is used to adjust the parameters of both the retriever (in the original RAG paper, the passage encoder's weights are fixed and the query encoder's weights are trainable) and the generator by minimising the loss function. The gradient of the loss with respect to the model parameters is computed as:
∂Loss/∂(parameters) = -(1/P(y*|x)) × ∂P(y*|x)/∂(parameters)
Gradient Calculation Using Chain Rule:
To minimise the loss, we need to calculate the gradients of the loss function with respect to the parameters of the generator (θ) and the retriever (m). Using the chain rule (and the product rule on the sum over documents), these gradients are:
∂Loss/∂θ = -(1/P(y*|x)) × ∑z P(z|x) × ∂Pθ(y*|x,z)/∂θ
∂Loss/∂m = -(1/P(y*|x)) × ∑z Pθ(y*|x,z) × ∂Pm(z|x)/∂m
This detailed breakdown explains how the training process aligns the model's outputs with the correct labels by fine-tuning both the retrieval and generation processes through the loss function.
The equations above also show the update of the retriever's weights. This is the most general case, where all the parameters in the pipeline are allowed to train; that is generally not feasible or optimal with respect to the compute it takes, so the passage encoder's weights are usually kept fixed and receive no updates. The equations are just meant to show the most general case possible.
Let's understand all of this with an easy example in which k = 2: the top 2 docs are retrieved to generate the whole sequence given input x.
Example Setup:
- Input Query: "Where does he go every morning?"
- Ground Truth Sentence: "He goes to school."
- Retrieved Documents:
- Doc 1: "He often goes to school."
- Doc 2: "He usually goes to the house."
Correct Sequence Generation for Each Document:
Doc 1 (Correct Document) Calculation:
- Probabilities for "He goes to school" from Doc 1:
- "He": Prob = 0.9
- "goes": Prob = 0.8
- "to": Prob = 0.9
- "school": Prob = 0.8
Doc 2 (Incorrect Document) Calculation:
- Probabilities for "He goes to school" from Doc 2:
- "He": Prob = 0.9
- "goes": Prob = 0.8
- "to": Prob = 0.9
- "school": Prob = 0.2. (The model actually preferred "house", with Prob = 0.6, but since we have to take the probability the generative model assigns to the correct label, which was "school", we use 0.2.)
Calculating the Total Probability P(y*|x)
Contribution from Doc 1:
P(correct sequence from Doc 1 | x, Doc 1) = 0.9 × 0.8 × 0.9 × 0.8 = 0.5184
Weighted by retrieval probability:
P(Doc 1|x) × P(correct sequence from Doc 1 | x, Doc 1) = 0.6 × 0.5184 = 0.31104
Contribution from Doc 2:
P(correct sequence from Doc 2 | x, Doc 2) = 0.9 × 0.8 × 0.9 × 0.2 = 0.1296
Weighted by retrieval probability:
P(Doc 2|x) × P(correct sequence from Doc 2 | x, Doc 2) = 0.4 × 0.1296 = 0.05184
Total Probability of the Correct Sentence P(y*|x):
Combining contributions from both documents:
P(y*|x) = 0.31104 + 0.05184 = 0.36288
Loss Calculation and Backpropagation:
The negative log-likelihood (NLL) of the ground truth sequence:
Loss = -log(0.36288)
(The above example was created using ChatGPT by giving a suitable prompt to create an example showing the loss involved mathematically)
This setup ensures that each token's probability from the correct ground-truth sequence is used when calculating the total probability of generating the correct sequence. The loss function penalises the model more heavily when incorrect predictions are made (like predicting "house" instead of "school"), especially when the incorrect prediction is made with high confidence. This method ensures that the model is trained to increase accuracy over time, aligning better with the correct labels.
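The same arithmetic, written out as a few lines of Python with the made-up probabilities from the example above:

```python
import math

p_doc = [0.6, 0.4]                         # P(Doc 1|x), P(Doc 2|x)
token_probs = [
    [0.9, 0.8, 0.9, 0.8],                  # "He goes to school" under Doc 1
    [0.9, 0.8, 0.9, 0.2],                  # same gold tokens under Doc 2
]

def seq_prob(tokens):                      # product over the sequence (AND)
    out = 1.0
    for p in tokens:
        out *= p
    return out

# Marginalise over documents (OR): sum of P(z|x) * P(y*|x,z)
p_y_given_x = sum(pz * seq_prob(toks) for pz, toks in zip(p_doc, token_probs))
loss = -math.log(p_y_given_x)
print(round(p_y_given_x, 5), round(loss, 5))   # 0.36288 and its NLL = 1.01368
```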
This is how things work in the RAG-Sequence mechanism; we have discussed it in great detail, so it will be easy going from here on. The other mechanism, RAG-Token, is not written up as extensively, since most of the concepts are the same; only the marginalisation part differs, and that is covered in detail.
RAG Token Model
Having gone through such a detailed description of the RAG-Sequence model, the RAG-Token model is easy to understand. In short, the only difference is that marginalisation happens at every generated token: at every token, the top-k docs are retrieved, their outputs are marginalised, and then the next token is predicted. The intuition behind the maximum-likelihood function also carries over: the summation moves inside the likelihood and the product (∏) moves outside, because marginalisation happens at every step, and to get the probability of the whole sequence we multiply the marginalised retrieval-and-generation output at every token.
The counting-principle analogy here: top-k paths (summation/marginalisation) during retrieval from the non-parametric memory at each step, and then, to calculate the probability of generating the whole sequence (every token), we multiply the marginalised probability at every token.
Loss Function for RAG-Token Model
The loss function focuses on maximising the likelihood of the correct sequence y*, given the input x. This is done by calculating the probability of each token in the sequence conditioned on the input and all preceding ground truth tokens using retrieved documents relevant to that context.
Loss Function
Loss = -log ∏(i=1..N) ∑z P(z | x, y*1:i−1) × Pθ(y*i | x, z, y*1:i−1)
- N is the length of the sequence y*.
- y*i represents the i-th token in the ground-truth sequence.
- y*1:i−1 represents all the ground-truth tokens from the start up to i − 1.
- P(z | x, y*1:i−1) is the probability of retrieving document z based on the input x and the ground-truth sequence up to i − 1, computed by the retriever.
- Pθ(y*i | x, z, y*1:i−1) is the probability of generating the ground-truth token y*i given x, the retrieved document z, and all previous ground-truth tokens, computed by the generator.
The intuition of the loss function remains the same: the loss is computed for every token yi generated by the model against the ground truth yi*, and it is essentially the conditional probability of the model predicting the correct token yi* given the input, all previous correct tokens, and the retrieved context; that is exactly what the categorical cross-entropy loss over tokens on the generator's side suggests.
Example Setup
- Input Query: "Where does he go every morning?"
- Ground Truth Sequence: "He goes to school."
- Documents (7 total):
- Doc 1: "He often goes to school."
- Doc 2: "He usually goes to the house."
- Doc 3: "He frequently visits a park."
- Doc 4: "He sometimes travels to a market."
- Doc 5: "He wakes up early."
- Doc 6: "He loves to read books."
- Doc 7: "He attends classes at school."
We assume that each token generation step involves retrieving the top 2 documents from this pool based on the current partial sequence and the original query.
Token-by-Token Retrieval & Generation Example
Token 1: "He"
- Context so far: input x = "Where does he go every morning?" (no tokens generated yet).
- Top 2 docs retrieved: suppose the model finds Doc 1 and Doc 5 most relevant (Doc 1 mentions "He often goes", and Doc 5 relates to "He" but not to going anywhere, just "wakes up early").
P(Doc 1|x) = 0.5
P(Doc 5|x) = 0.5
- Probabilities of generating "He":
P("He"|x, Doc 1) = 0.9
P("He"|x, Doc 5) = 0.7 (less directly relevant, but still plausible)
- Marginalised probability for "He":
P("He"|x) = P(Doc 1|x) × P("He"|x, Doc 1) + P(Doc 5|x) × P("He"|x, Doc 5)
= (0.5 × 0.9) + (0.5 × 0.7)
= 0.45 + 0.35
= 0.8
Token 2: "goes"
- Now the context is x plus the generated token "He".
- With "He" established, the model might now find documents mentioning going somewhere more relevant. At this step, the top 2 docs might be Doc 1 and Doc 2 (both talk about going to a location, either school or house).
P(Doc 1|x, "He") = 0.6
P(Doc 2|x, "He") = 0.4
- Probabilities of "goes":
P("goes"|x, "He", Doc 1) = 0.85 (Doc 1 strongly correlates with going)
P("goes"|x, "He", Doc 2) = 0.8 (Doc 2 also mentions "goes")
- Marginalised probability for "goes":
P("goes"|x, "He") = (0.6 × 0.85) + (0.4 × 0.8)
= 0.51 + 0.32
= 0.83
Token 3: "to"
- Context: "He goes"
- Given "He goes", documents that mention going to specific places become more relevant. Let's say the top 2 at this step are Doc 1 and Doc 7, since Doc 7 explicitly says, "He attends classes at school," which aligns with going to a place:
P(Doc 1|x, "He goes") = 0.55
P(Doc 7|x, "He goes") = 0.45
- Probabilities of "to":
P("to"|x, "He goes", Doc 1) = 0.9 (Doc 1 closely related to going to school)
P("to"|x, "He goes", Doc 7) = 0.85 (Doc 7: "attends classes at school" also implies going to school)
- Marginalised probability for "to":
P("to"|x, "He goes") = (0.55 × 0.9) + (0.45 × 0.85)
= 0.495 + 0.3825
= 0.8775 (≈ 0.88)
Token 4: "school"
- Context: "He goes to"
- At this final token, documents that mention "school" become highly relevant. Let's assume the top docs are Doc 1 and Doc 7 again: Doc 1 explicitly says, "He often goes to school," and Doc 7 mentions "attends classes at school."
P(Doc 1|x, "He goes to") = 0.6
P(Doc 7|x, "He goes to") = 0.4
- Probabilities of "school":
P("school"|x, "He goes to", Doc 1) = 0.8 (direct match)
P("school"|x, "He goes to", Doc 7) = 0.7 (also strongly implies the school context)
- Marginalised probability for "school":
P("school"|x, "He goes to") = (0.6 × 0.8) + (0.4 × 0.7)
= 0.48 + 0.28
= 0.76
Combining the Token Probabilities
For the RAG-Token model, we multiply the marginalised probabilities at each token step:
P(y*|x) = P("He"|x) × P("goes"|x, "He") × P("to"|x, "He goes") × P("school"|x, "He goes to")
= 0.8 × 0.83 × 0.88 × 0.76
Let's calculate step by step:
0.8 × 0.83 = 0.664
0.664 × 0.88 ≈ 0.58432
0.58432 × 0.76 ≈ 0.44408
So, P(y*|x) ≈ 0.444 (this is just an example scenario with made-up probabilities).
Loss Calculation
Loss = -log(P(y*|x)) = -log(0.444)
(The above example is created using ChatGPT by giving a suitably detailed prompt to create an example showing the loss involved mathematically in the RAG token-based model)
The intuition here is: at each token, the model retrieves the top 2 documents given the input query and the ground-truth tokens generated up to the current token. The final probability of generating the entire correct sequence is the product of these token-wise marginalised probabilities. (A short numeric sketch of this arithmetic follows below.)
Retrieval probabilities can shift token by token. For example, at token one we chose Doc 1 and Doc 5, while at token two we shifted to Doc 1 and Doc 2, and so on. This simulates the RAG-Token approach, where retrieval happens at every token step.
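And the token-level arithmetic from this walkthrough as a short Python sketch (numbers copied from above; the exact product comes out ≈ 0.44282 because the walkthrough rounded 0.8775 up to 0.88 before multiplying):

```python
import math

# Per-token (P(z|context), P(token|z, context)) pairs from the walkthrough.
steps = [
    [(0.5, 0.9), (0.5, 0.7)],     # "He":     Doc 1, Doc 5
    [(0.6, 0.85), (0.4, 0.8)],    # "goes":   Doc 1, Doc 2
    [(0.55, 0.9), (0.45, 0.85)],  # "to":     Doc 1, Doc 7
    [(0.6, 0.8), (0.4, 0.7)],     # "school": Doc 1, Doc 7
]

p_sequence = 1.0
for docs in steps:
    # Marginalise over the top-2 docs at THIS token (sum), then multiply
    # the per-token marginals across the whole sequence (product).
    p_token = sum(pz * ptok for pz, ptok in docs)
    p_sequence *= p_token

loss = -math.log(p_sequence)
print(round(p_sequence, 5), round(loss, 4))  # ≈ 0.44282 and NLL ≈ 0.8146
```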
So that's how things work in the RAG token-based model. Clearly, because retrieval happens at every token in the RAG token-based model, compared with the RAG sequence-based model where retrieval for the entire sequence happens once, the token-based model takes more time and resources; it is resource-heavy and should be the choice for complex downstream tasks.
Drive Link for notes on loss functions.
Conclusion
The base of this whole article was the original RAG research paper, and we discussed the two mechanisms involved, the RAG sequence-based and RAG token-based models, in detail. Fine-tuning RAG needs good engineering skills to get good results on a downstream task (depending on the task): choosing the right retriever, generator model and vector database, and writing the fine-tuning training script for the data, are all essential skills involved, as is choosing the most optimal setup for the custom dataset/downstream task in production-ready solutions. Understanding the core principles of RAG fine-tuning is a plus on top of those engineering skills.
If possible, we can think of changing the RAG technique at the pre-training and fine-tuning stage, since we cannot feed many documents as context (like k = 10) without adding a lot of cost from high token usage; that is why RAG applications don't scale up well. I hope that changes someday; it was my main motivation for writing out my knowledge about pre-training and joint RAG fine-tuning.
My contact details
Email: [email protected]
Twitter: https://x.com/r4plh
GitHub: https://github.com/r4plh
LinkedIn: https://www.linkedin.com/in/aman-agrawal-bbb3641b8/
References
- Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks
- Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks
- BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
- BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension
- I have also used Grammarly for editing my article: https://app.grammarly.com/
Published via Towards AI