Fine-tuning Embeddings for RAG applications
Author(s): Anuar Sharafudinov
Originally published on Towards AI.
The rise of Retrieval-Augmented Generation (RAG) has revolutionized how we build intelligent applications. At its core, RAG is about efficiently turning large chunks of text into actionable embeddings and then letting an AI model piece together contextually relevant answers.
However, what works in theory can stumble in real-world scenarios. Why? One of the biggest culprits is poor or unclear embedding representations. Often, these representations don't align well with the demands of production-level applications, particularly for tasks like question answering.
The solution is fine-tuning embeddings: an impactful way to enhance your RAG implementation.
RAG 101: How It Works
Let's break it down. Here's the typical RAG workflow:
- User Input: A user submits a question or query.
- Query Embedding: The system generates an embedding for the query.
- Chunk Matching: It searches for the chunk embeddings most similar to the query using cosine similarity.
- Answer Generation: The contents of the retrieved top chunks are sent as context to a language model, which generates the final response.
This setup works well in theory. However, when embeddings lack precision, the results can feel off-target, especially when dealing with large datasets.
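To make the retrieval step concrete, here is a minimal sketch of top-k chunk matching with cosine similarity. The embed() function is a placeholder for whichever embedding model you use, and chunk_embeddings is assumed to be a 2D NumPy array of precomputed chunk embeddings; none of these names come from a specific library.

import numpy as np

def top_k_chunks(query, chunks, chunk_embeddings, embed, k=5):
    q = embed(query)                                     # 1D embedding of the user query
    q = q / np.linalg.norm(q)                            # normalize so the dot product equals cosine similarity
    m = chunk_embeddings / np.linalg.norm(chunk_embeddings, axis=1, keepdims=True)
    scores = m @ q                                       # cosine similarity against every chunk
    best = np.argsort(-scores)[:k]                       # indices of the k most similar chunks
    return [(chunks[i], float(scores[i])) for i in best]

The texts of the retrieved chunks are then placed into the prompt that goes to the language model.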
The Fine-Tuning Solution
What if you could pre-train your embeddings to anticipate the kinds of questions your users might ask?
Here's the idea:
- Generate Question-Chunk Pairs: For each chunk of text in your dataset, generate multiple potential questions it could answer.
- Fine-Tune the Embedding Model: Train the model to pull embeddings of related questions and chunks closer together in multidimensional space while pushing unrelated ones further apart.
While this approach might seem like overfitting, it actually optimizes for generalization: fine-tuning embeddings in this way equips the system to handle unseen queries with improved accuracy.
The Results Speak for Themselves
Fine-tuning embeddings yielded remarkable improvements across several models. For training, we used one of our internal experimental datasets. It consists of 52 chunks, each approximately 800 tokens long. For each chunk, we used Anthropic's Claude-3-Sonnet to generate 3-5 corresponding questions.
To evaluate performance, we measured how often the correct chunk appeared within the top 3, top 5, and top 10 retrieved results. To provide broader context, we also included results for OpenAI's text-embedding-3-large. However, since it is a closed-source model, we could not apply fine-tuning to it.
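For reference, here is a rough sketch of how such a hit-rate (recall@k) metric can be computed from precomputed question and chunk embeddings; the function and variable names are illustrative, not taken from our scripts.

import numpy as np

def recall_at_k(question_embeddings, true_chunk_ids, chunk_embeddings, k):
    # normalize both sides so the dot product equals cosine similarity
    q = question_embeddings / np.linalg.norm(question_embeddings, axis=1, keepdims=True)
    c = chunk_embeddings / np.linalg.norm(chunk_embeddings, axis=1, keepdims=True)
    scores = q @ c.T                                     # (num_questions, num_chunks) similarity matrix
    top_k = np.argsort(-scores, axis=1)[:, :k]           # indices of the k best chunks per question
    hits = [true_id in row for true_id, row in zip(true_chunk_ids, top_k)]
    return sum(hits) / len(hits)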
Here's a snapshot of the results:
Open-Sourcing the Code
If you're inspired to experiment with fine-tuning, we've got you covered. Check out our code repository with training and testing scripts for the Alibaba-NLP/gte-Qwen2-1.5B-instruct and jinaai/jina-embeddings-v3 models. The repo also includes support for two training methods: TripletMarginLoss and CosineEmbeddingLoss.
Model requirements
- Alibaba-NLP/gte-Qwen2-1.5B-instruct requires about 30GB of VRAM; a GPU with 40GB of memory or more (e.g., an A100) is recommended. Its forward-pass logic is standard and can be applied to many similar embedding models.
- jinaai/jina-embeddings-v3 is a very lightweight model requiring only 8GB of GPU memory for fine-tuning. Its forward-pass logic is slightly model-specific, but the core idea is the same; a generic sketch of how embeddings are typically extracted follows below.
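As a rough illustration of a "standard" forward pass, the sketch below obtains sentence embeddings from a Hugging Face model via mean pooling over the last hidden states. The actual pooling differs per model (gte-Qwen2 uses last-token pooling, and jina-embeddings-v3 exposes its own encoding path), so treat this as an assumption-laden sketch rather than the repository's implementation.

import torch
from transformers import AutoModel, AutoTokenizer

def embed_texts(texts, model_name, device="cuda"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device).eval()
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(device)
    with torch.no_grad():
        hidden = model(**batch).last_hidden_state         # (batch, seq_len, dim)
    mask = batch["attention_mask"].unsqueeze(-1)           # zero out padding positions
    emb = (hidden * mask).sum(dim=1) / mask.sum(dim=1)     # mean pooling over real tokens
    return torch.nn.functional.normalize(emb, dim=-1)      # unit-length embeddings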
Training methods
1. TripletMarginLoss. This method uses an anchor (a), a positive sample (p), and a negative sample (n):
- Anchor (a): Chunk content embedding
- Positive sample (p): A corresponding question embedding
- Negative sample (n): An unrelated question embedding
To build a training set, create (chunk, questions) pairs and randomly select unrelated questions as negative samples, as in the sketch below.
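Here is a minimal PyTorch sketch of one training step with TripletMarginLoss; encode() is a placeholder for your trainable embedding model and is assumed to return a tensor of shape (1, D).

import random
import torch.nn as nn

triplet_loss = nn.TripletMarginLoss(margin=1.0)

def triplet_step(encode, pairs, optimizer):
    # pairs: list of (chunk_text, [question_1, question_2, ...])
    chunk, questions = random.choice(pairs)
    _, other_questions = random.choice([p for p in pairs if p[0] != chunk])

    a = encode(chunk)                            # anchor: chunk embedding
    p = encode(random.choice(questions))         # positive: a question generated from this chunk
    n = encode(random.choice(other_questions))   # negative: a question from an unrelated chunk

    loss = triplet_loss(a, p, n)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()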
2. CosineEmbeddingLoss. This method pairs a chunk with either a related or an unrelated question drawn from the training set (see the sketch after this list):
- x1: The chunk embedding
- x2: Either a positive or negative sample embedding
- y: Label indicating if x2 is positive (y=1) or negative (y=-1).
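The corresponding sketch for CosineEmbeddingLoss, with the same placeholder encode() returning (1, D) tensors:

import torch
import torch.nn as nn

cosine_loss = nn.CosineEmbeddingLoss()

def cosine_step(encode, chunk, question, is_positive, optimizer):
    x1 = encode(chunk)                                   # chunk embedding
    x2 = encode(question)                                # related or unrelated question embedding
    y = torch.tensor([1.0 if is_positive else -1.0])     # +1 for a related pair, -1 for an unrelated one
    loss = cosine_loss(x1, x2, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()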
Adapting the Code
To use your own dataset, modify the prepare_data function in train.py. Ensure it returns chunks and their corresponding questions as pairs.
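A prepare_data replacement could look roughly like the following; the file name and JSON layout here are placeholders, and the exact return format expected by train.py may differ, so check the repository before copying.

import json

def prepare_data(path="my_dataset.json"):
    # expects a JSON list like: [{"chunk": "...", "questions": ["...", "..."]}, ...]
    with open(path) as f:
        records = json.load(f)
    return [(r["chunk"], r["questions"]) for r in records]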
Note: The repository does not include question-generation logic, but various approaches are available. Below, we've included the sample code we used, for reference.
# 1. Split the document into chunks (simple way)
import tiktoken

def split_into_chunks(content, chunk_size):
    enc = tiktoken.get_encoding("o200k_base")
    tokens = enc.encode(content)          # token ids for the whole document
    left, chunks = 0, []
    while left < len(tokens):
        chunks.append(enc.decode(tokens[left : left + chunk_size]))  # decode a fixed-size token window
        left += chunk_size
    return chunks

chunks = split_into_chunks(document_content, 400)
# 2. Generate questions for each chunk
import anthropic

def anthropic_run(system_prompt, user_message):
    client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY)
    message = client.messages.create(
        model="claude-3-sonnet-20240229",  # or "claude-3-opus-20240229"
        max_tokens=4096,
        system=system_prompt,
        messages=[
            {"role": "user", "content": user_message}
        ],
    )
    return message.content[0].text
system_prompt = '''
Given a chunk from document. Generate 3-5 questions related to the chunk. Each question must be full and not require additional context.
Example output:
1. How to open new account?
2. How much BMW X5 costs?
'''
import re

for chunk in chunks:
    text = chunk  # if your chunks carry extra metadata (e.g., keywords), prepend it to the text here
    out = anthropic_run(system_prompt, text)
    # extract numbered questions like "1. How to open new account?"
    question_pattern = re.compile(r'^\s*\d+\.\s+(.*)', re.MULTILINE)
    questions = question_pattern.findall(out)
    print(text, questions)
# now you have (chunk, questions) pairs
Continuous Improvement
Another advantage of this approach is its potential for continuous improvement. Over time, as new cases emerge, you can retrain the model. For instance, if the corresponding chunk wasn't found in the top 10 (avoid larger cutoffs to sidestep the "lost in the middle" issue) and the LLM failed to generate an answer, simply add that question and its correct chunk to the training set. This ensures the system evolves to handle similar issues more effectively in the future.
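In practice, this can be as simple as appending the failed query and its correct chunk to the same (chunk, questions) dataset and re-running training. A small sketch, assuming the JSON layout used in the prepare_data example above:

import json

def log_hard_example(question, correct_chunk, path="my_dataset.json"):
    # Called when the correct chunk was not retrieved in the top 10 for this question.
    with open(path) as f:
        records = json.load(f)
    records.append({"chunk": correct_chunk, "questions": [question]})
    with open(path, "w") as f:
        json.dump(records, f, ensure_ascii=False, indent=2)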