![Parameter-Efficient Fine-Tuning (PEFT): A Hands-On Guide with LoRA](https://i2.wp.com/miro.medium.com/v2/resize:fit:700/0*nIWCHJBSdXRbAynw.png?w=1920&resize=1920,921&ssl=1)
Parameter-Efficient Fine-Tuning (PEFT): A Hands-On Guide with LoRA
Author(s): BeastBoyJay
Originally published on Towards AI.
Imagine building a powerful AI model without needing massive computational resources. PEFT makes that possible, and I'll show you how with LoRA from scratch.
Introduction
Traditional fine-tuning challenges:
Fine-tuning large models sounds cool, until reality hits. Imagine trying to sculpt a masterpiece but needing a giant crane just to lift your tools. That's what traditional fine-tuning feels like. You're working with millions (sometimes billions) of parameters, and the computational cost can skyrocket faster than your coffee bill during finals week.
Hardware Struggles:
- Got a spare supercomputer lying around? Probably not.
- GPUs heat up like your phone during a marathon PUBG session.
- RAM gets maxed out faster than your Netflix binge in 4K.
Data Dilemma:
- You need a ton of data, or your model behaves like a forgetful student on exam day.
- Gathering and cleaning that much data? A nightmare in itself.
Snail-Speed Training:
- Hit "run" and wait… and wait… and maybe even take a nap while your model chugs along.
Maintenance Mayhem:
- Tiny tweaks mean re-training the whole colossal beast.
- Waste of time, energy, and your already-thin patience.
The solution: PEFT
PEFT is the answer to this bulky, traditional fine-tuning routine. Think of PEFT (Parameter-Efficient Fine-Tuning) as upgrading a car by just changing the tires instead of rebuilding the whole engine. Instead of retraining every parameter in a massive model, PEFT tweaks just the essential parts, saving time, resources, and sanity.
Why it rocks:
- Resource-Smart: No supercomputer required.
- Time-Saving: Faster results with minimal effort.
- Scalable: Handles large models like a pro.
What is PEFT?
PEFT (Parameter-Efficient Fine-Tuning) is like giving your AI model a performance boost by only adjusting the most important parameters, rather than retraining the entire thing. Think of it as overclocking your model without needing to upgrade the whole motherboard.
Why Is PEFT Necessary?
Reduced Training Costs:
Instead of burning through a fortune in GPU time to retrain the whole model, PEFT lets you fine-tune with minimal resources, saving both cash and computing power.
Faster Adaptation to Tasks:
PEFT allows you to quickly adapt large models to new tasks by only tuning the necessary components β speeding up the training process without sacrificing accuracy.
Minimal Memory Requirements:
Rather than loading the entire model into memory, PEFT uses fewer resources, letting you work on large-scale models without draining your system.
How does PEFT work?
The core idea of PEFT is to update only a small subset of parameters that are crucial for task-specific performance, instead of updating the entire set of model parameters during fine-tuning. This is done by introducing task-specific "adapters" or by manipulating a few selected layers (like attention or feed-forward layers) that control the output for a given task.
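To make that concrete, here is a minimal sketch of the general recipe in PyTorch: freeze everything that was pre-trained, attach a small trainable piece, and hand only the new parameters to the optimizer. The layer sizes and the tiny task head below are made-up stand-ins for illustration, not part of any particular PEFT method.
import torch
import torch.nn as nn

# Stand-in for a large pre-trained block (sizes are illustrative)
pretrained_block = nn.Linear(512, 512)
for param in pretrained_block.parameters():
    param.requires_grad = False  # the pre-trained knowledge stays frozen

# The small, task-specific piece: the only thing we actually train
task_head = nn.Linear(512, 10)
optimizer = torch.optim.Adam(task_head.parameters(), lr=1e-3)

trainable = sum(p.numel() for p in task_head.parameters())
frozen = sum(p.numel() for p in pretrained_block.parameters())
print(f"Tuning {trainable:,} parameters, leaving {frozen:,} frozen")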
Types of PEFT techniques:
LoRA (Low-Rank Adaptation):
Let's talk about one of the coolest tricks in PEFT (Parameter-Efficient Fine-Tuning): LoRA. Imagine you've got this massive pre-trained model, like a Transformer, that's already packed with all sorts of knowledge. Now, instead of modifying everything in the model, LoRA lets you tweak just the essentials, specifically, a few sneaky little low-rank matrices that help the model adapt to new tasks. The rest of the model stays frozen in time, like an immovable fortress, while LoRA does its magic.
So, how does LoRA work its sorcery?
Here's the gist of it: Let's say there's a weight matrix W in the model (maybe in the attention mechanism, where the model decides what's important in the input). LoRA keeps W frozen and instead learns a small correction to it, expressed as the product of two much smaller matrices, B and A. Mathematically:
W' = W + ΔW, where ΔW = B × A
These matrices, A and B, are low-rank, which, in nerd terms, means they have way fewer parameters to deal with compared to the original weight matrix: if W is d × k, then B is d × r and A is r × k for some small rank r. The magic? Because A and B are so much smaller, we've got fewer parameters to tune during fine-tuning.
But that's not all. Here's the real kicker:
When it comes to fine-tuning, LoRA focuses only on training the parameters of A and B. The rest of the massive model stays locked, untouched. It's like having the keys to just one door in a huge mansion: you're making minimal changes, but they're all targeted and impactful.
By doing this, you reduce the number of parameters you need to update during fine-tuning, which makes the whole process way more efficient. You're getting the same task-specific performance without the heavy lifting of retraining everything. It's like finding the shortcut in a maze: you still reach the goal, but with way less effort!
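To put rough numbers on that, here is a tiny sketch of the idea in PyTorch. The 1000×1000 weight and the rank of 8 are arbitrary illustrative values, not taken from the model we build later.
import torch

d, k, r = 1000, 1000, 8            # weight dimensions and LoRA rank (illustrative)
W = torch.randn(d, k)              # frozen pre-trained weight
B = torch.zeros(d, r)              # B starts at zero and A starts random,
A = torch.randn(r, k)              # so the update B @ A is zero before training

delta_W = B @ A                    # low-rank update, same shape as W
W_effective = W + delta_W          # what the layer effectively uses

print(W.numel())                   # 1,000,000 parameters to touch in full fine-tuning
print(A.numel() + B.numel())       # only 16,000 trainable parameters with LoRA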
Adapters:
Let's talk about Adapters, not the kind you plug into your phone charger, but these nifty little modules that slot into the transformer architecture like a perfect puzzle piece!
Imagine you've got a powerful pre-trained model, and you need to adapt it to a new task. Instead of retraining the entire thing, you introduce an adapter: a lightweight, task-specific module that fits neatly after each transformer block. The best part? You don't have to touch the core model at all. It's like adding a few extra gears to a well-oiled machine without dismantling the whole thing.
Here's the lowdown on how adapters work:
- Insertion into Layers: Think of an adapter as a mini-module that slides in after key layers in the transformer, like right after the attention or feed-forward layers. It usually consists of a couple of fully connected layers arranged as a bottleneck: a down-projection to a much smaller dimension, a nonlinearity, and an up-projection back to the original size (because, let's face it, we don't want to mess with the model's flow). It's like a sleek, efficient middleman.
- Task-Specific Tuning: Here's where the fun happens: When you fine-tune the model, only the adapter parameters are updated. That means the core model, packed with all its pre-trained knowledge, stays frozen, like a wise professor who's teaching you everything they know, but you're just adding some extra knowledge with the adapter. The adapter absorbs the task-specific tweaks without messing up the original wisdom of the model.
The Big Win?
The core model retains its massive, generalized knowledge while the adapter learns just enough to tackle the new task. It's like teaching a world-class musician a new song without changing their entire repertoire. Efficient, fast, and keeps things clean.
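For a feel of what such a module could look like, here is a minimal, hypothetical bottleneck adapter in PyTorch. The class name, the 768 hidden size, the 64-dim bottleneck, and the residual connection are my illustrative choices following the common adapter recipe, not code from a specific library.
import torch.nn as nn

class BottleneckAdapter(nn.Module):
    # Hypothetical adapter: down-project, nonlinearity, up-project, add residual
    def __init__(self, hidden_size=768, bottleneck_size=64):
        super().__init__()
        self.down = nn.Linear(hidden_size, bottleneck_size)
        self.act = nn.ReLU()
        self.up = nn.Linear(bottleneck_size, hidden_size)

    def forward(self, x):
        # The residual connection lets the frozen model's output pass through
        # untouched, plus a small learned correction from the adapter
        return x + self.up(self.act(self.down(x)))
Dropped in right after an attention or feed-forward block, only this module's parameters would be handed to the optimizer.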
Prefix Tuning:
Let's get into the groove of Prefix Tuning, a clever, minimalist trick that adds just the right amount of guidance to steer a model without overhauling its entire structure. It's like giving your car a gentle nudge to take a different route without touching the engine. Cool, right?
Here's how Prefix Tuning works its magic:
- Learnable Prefix: Picture this: before the model gets to process the input text, you prepend a small, task-specific set of tokens; this is your prefix. It's like a little note that says, "Hey, focus on this when you're working!" These tokens are learnable, meaning you can train them to carry the relevant task information. Importantly, the rest of the model's weights stay locked down, untouched.
- Controlling Attention: The prefix isn't just a random add-on. These tokens guide the model's attention mechanisms, telling it which parts of the input to focus on. It's like placing a signpost at the start of the road, directing the model on where to head next. So, when the model generates an output, it's subtly influenced by the prefix tokens, helping it stay on track for the specific task at hand.
The Beauty of Prefix Tuning?
The brilliance of prefix tuning lies in its simplicity. You're not retraining the entire model or altering its inner workings. Instead, you're enhancing its attention, just enough to guide it in the right direction for the task you need it to perform.
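As a simplified sketch of the mechanics, the snippet below prepends a handful of learnable prefix vectors to a batch of input embeddings; only the prefix would be trained. Real prefix tuning typically injects the learned vectors into the keys and values of every attention layer, and all the sizes here are made up for illustration.
import torch
import torch.nn as nn

batch, seq_len, hidden = 4, 16, 768          # illustrative sizes
prefix_len = 10

# The only trainable parameters: a small bank of prefix vectors
prefix = nn.Parameter(torch.randn(prefix_len, hidden) * 0.02)

token_embeddings = torch.randn(batch, seq_len, hidden)   # would come from the frozen model
prefixed = torch.cat([prefix.expand(batch, -1, -1), token_embeddings], dim=1)
print(prefixed.shape)                        # torch.Size([4, 26, 768])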
BitFit:
Let's dive into BitFit, a deceptively simple yet highly effective PEFT technique that's like tweaking just the small dials on a well-tuned machine to get the perfect result. Instead of overhauling the entire system, BitFit focuses on the tiniest components to make a big impact.
How BitFit Works:
- Bias Tuning: Imagine your model is a giant network of gears and levers (aka weights) that are already trained and doing their thing. Now, instead of retraining every gear, BitFit zooms in on the bias terms, the extra parameters that get added to the output of each layer. These bias terms are like small adjustments that help shift the model's output in the right direction, but they don't have the complexity or weight of the entire model's weights.
- Minimalist Fine-Tuning: The trick is that only the bias terms are tuned, while the rest of the model's weights remain frozen. Bias terms are much smaller in number compared to the full set of weights, so you're making very targeted changes. It's like fine-tuning the volume on a speaker without touching the entire sound system. You're still getting the desired sound (or task performance), but without the hassle of fiddling with everything.
Why BitFit Rocks:
The real charm of BitFit is its efficiency. By focusing on just a few parameters, you're able to fine-tune a model for a specific task while keeping the computational load light. It's a great way to make tweaks without the heavy lifting of full model fine-tuning, making it fast and resource-friendly.
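In PyTorch, BitFit boils down to a few lines: freeze everything, then switch the bias terms back on. A minimal sketch using a throwaway stand-in model:
import torch.nn as nn

demo_model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))  # stand-in model

for name, param in demo_model.named_parameters():
    # Train only the bias terms; every weight matrix stays frozen
    param.requires_grad = name.endswith("bias")

trainable = sum(p.numel() for p in demo_model.parameters() if p.requires_grad)
total = sum(p.numel() for p in demo_model.parameters())
print(f"Training {trainable:,} of {total:,} parameters")  # 266 of 203,530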
Implementing LoRA from scratch in PyTorch:
Now I will walk you through implementing LoRA from scratch so that you get a deeper understanding of it.
Importing necessary libraries:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
from tqdm import tqdm
Making the torch model deterministic:
_ = torch.manual_seed(0)
Training a small model:
Let's have some fun with LoRA! We'll start by building a small, simple model to classify those classic MNIST digits, you know, the ones everyone loves to work with when learning machine learning. But here's the twist: instead of stopping at basic digit classification, we're going to take it up a notch.
We'll identify one digit our network struggles with (maybe it just doesn't vibe with the number 7?), and fine-tune the whole thing using LoRA to make it smarter and better at recognizing that tricky number. It's going to be a cool mix of training, tweaking, and improving, perfect for seeing LoRA in action!
Loading the Dataset:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
# Load the MNIST dataset
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)
# Load the MNIST test set
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)
# Define the device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Model Architecture:
class SimpleNN(nn.Module):
    def __init__(self, hidden_size_1=1000, hidden_size_2=2000):
        super(SimpleNN, self).__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x
model = SimpleNN().to(device)
Training Loop:
def train(train_loader, model, epochs=5, total_iterations_limit=None):
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    total_iterations = 0

    for epoch in range(epochs):
        model.train()

        loss_sum = 0
        num_iterations = 0

        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = model(x.view(-1, 28*28))
            loss = cross_el(output, y)
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()

            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return
train(train_loader, model, epochs=1)
After executing the above code, your small model will be trained and ready for inference. But before that, let me keep a copy of the original weights (cloning them) so later we can prove that fine-tuning with LoRA doesn't alter the original weights.
original_weights = {}
for name, param in model.named_parameters():
original_weights[name] = param.clone().detach()
Now, testing the performance of the trained model:
def test():
    correct = 0
    total = 0

    wrong_counts = [0 for i in range(10)]

    with torch.no_grad():
        for data in tqdm(test_loader, desc='Testing'):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 784))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct += 1
                else:
                    wrong_counts[y[idx]] += 1
                total += 1

    print(f'Accuracy: {round(correct/total, 3)}')
    for i in range(len(wrong_counts)):
        print(f'wrong counts for the digit {i}: {wrong_counts[i]}')
test()
Output:
Accuracy: 0.954
wrong counts for the digit 0: 31
wrong counts for the digit 1: 17
wrong counts for the digit 2: 46
wrong counts for the digit 3: 74
wrong counts for the digit 4: 29
wrong counts for the digit 5: 7
wrong counts for the digit 6: 36
wrong counts for the digit 7: 80
wrong counts for the digit 8: 25
wrong counts for the digit 9: 116
As you can see, the worst-performing digit is "9".
LoRA Implementation:
Define the LoRA parametrization as described in the paper. The full details on how PyTorch parametrizations work are here: https://pytorch.org/tutorials/intermediate/parametrizations.html
class LoRAParametrization(nn.Module):
    def __init__(self, features_in, features_out, rank=1, alpha=1, device='cpu'):
        super().__init__()
        # Section 4.1 of the paper:
        # We use a random Gaussian initialization for A and zero for B, so ΔW = BA is zero at the beginning of training
        self.lora_A = nn.Parameter(torch.zeros((rank, features_out)).to(device))
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank)).to(device))
        nn.init.normal_(self.lora_A, mean=0, std=1)

        # Section 4.1 of the paper:
        # We then scale ΔWx by α/r, where α is a constant in r.
        # When optimizing with Adam, tuning α is roughly the same as tuning the learning rate if we scale the initialization appropriately.
        # As a result, we simply set α to the first r we try and do not tune it.
        # This scaling helps to reduce the need to retune hyperparameters when we vary r.
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights):
        if self.enabled:
            # Return W + (B*A)*scale
            return original_weights + torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape) * self.scale
        else:
            return original_weights
import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    # Only add the parameterization to the weight matrix, ignore the Bias
    # From section 4.2 of the paper:
    # We limit our study to only adapting the attention weights for downstream tasks and freeze the MLP modules (so they are not trained in downstream tasks) both for simplicity and parameter-efficiency.
    # [...]
    # We leave the empirical investigation of [...], and biases to a future work.
    features_in, features_out = layer.weight.shape
    return LoRAParametrization(
        features_in, features_out, rank=rank, alpha=lora_alpha, device=device
    )
parametrize.register_parametrization(
model.linear1, "weight", linear_layer_parameterization(model.linear1, device)
)
parametrize.register_parametrization(
model.linear2, "weight", linear_layer_parameterization(model.linear2, device)
)
parametrize.register_parametrization(
model.linear3, "weight", linear_layer_parameterization(model.linear3, device)
)
def enable_disable_lora(enabled=True):
    for layer in [model.linear1, model.linear2, model.linear3]:
        layer.parametrizations["weight"][0].enabled = enabled
Display the number of parameters added by LoRA.
total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([model.linear1, model.linear2, model.linear3]):
    total_parameters_lora += layer.parametrizations["weight"][0].lora_A.nelement() + layer.parametrizations["weight"][0].lora_B.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {layer.parametrizations["weight"][0].lora_A.shape} + Lora_B: {layer.parametrizations["weight"][0].lora_B.shape}'
    )

# The non-LoRA parameters count must match the original network
# (recompute the original count from the cloned original weights)
total_parameters_original = sum(p.nelement() for p in original_weights.values())
assert total_parameters_non_lora == total_parameters_original

print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
parameters_increment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters increment: {parameters_increment:.3f}%')
Output:
Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000]) + Lora_A: torch.Size([1, 784]) + Lora_B: torch.Size([1000, 1])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000]) + Lora_A: torch.Size([1, 1000]) + Lora_B: torch.Size([2000, 1])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10]) + Lora_A: torch.Size([1, 2000]) + Lora_B: torch.Size([10, 1])
Total number of parameters (original): 2,807,010
Total number of parameters (original + LoRA): 2,813,804
Parameters introduced by LoRA: 6,794
Parameters increment: 0.242%
Freezing all the parameters of the original network and only fine-tuning the ones introduced by LoRA. Then we fine-tune the model on the digit 9, and only for 100 batches.
# Freeze the non-LoRA parameters
for name, param in model.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

# Load the MNIST dataset again, keeping only the digit 9
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
digit_9_indices = mnist_trainset.targets == 9
mnist_trainset.data = mnist_trainset.data[digit_9_indices]
mnist_trainset.targets = mnist_trainset.targets[digit_9_indices]
# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)
# Train the network with LoRA only on the digit 9 and only for 100 batches (hoping that it would improve the performance on the digit 9)
train(train_loader, model, epochs=1, total_iterations_limit=100)
After training, the model now contains the newly introduced LoRA weights.
Verifying that the fine-tuning didn't alter the original weights, but only the ones introduced by LoRA.
# Check that the frozen parameters are still unchanged by the finetuning
assert torch.all(model.linear1.parametrizations.weight.original == original_weights['linear1.weight'])
assert torch.all(model.linear2.parametrizations.weight.original == original_weights['linear2.weight'])
assert torch.all(model.linear3.parametrizations.weight.original == original_weights['linear3.weight'])
enable_disable_lora(enabled=True)
# The new linear1.weight is obtained by the "forward" function of our LoRA parametrization
# The original weights have been moved to model.linear1.parametrizations.weight.original
# More info here: https://pytorch.org/tutorials/intermediate/parametrizations.html#inspecting-a-parametrized-module
assert torch.equal(model.linear1.weight, model.linear1.parametrizations.weight.original + (model.linear1.parametrizations.weight[0].lora_B @ model.linear1.parametrizations.weight[0].lora_A) * model.linear1.parametrizations.weight[0].scale)
enable_disable_lora(enabled=False)
# If we disable LoRA, the linear1.weight is the original one
assert torch.equal(model.linear1.weight, original_weights['linear1.weight'])
Testing the network with LoRA enabled (the digit 9 should be classified better)
# Test with LoRA enabled
enable_disable_lora(enabled=True)
test()
Output:
Accuracy: 0.924
wrong counts for the digit 0: 47
wrong counts for the digit 1: 27
wrong counts for the digit 2: 65
wrong counts for the digit 3: 240
wrong counts for the digit 4: 89
wrong counts for the digit 5: 32
wrong counts for the digit 6: 54
wrong counts for the digit 7: 137
wrong counts for the digit 8: 61
wrong counts for the digit 9: 9
Testing the network with LoRA disabled (the accuracy and error counts must be the same as the original network)
enable_disable_lora(enabled=False)
test()
Output:
Accuracy: 0.954
wrong counts for the digit 0: 31
wrong counts for the digit 1: 17
wrong counts for the digit 2: 46
wrong counts for the digit 3: 74
wrong counts for the digit 4: 29
wrong counts for the digit 5: 7
wrong counts for the digit 6: 36
wrong counts for the digit 7: 80
wrong counts for the digit 8: 25
wrong counts for the digit 9: 116
Conclusion:
The implementation we've walked through demonstrates the power and efficiency of LoRA in practice. Through our MNIST example, we've seen how LoRA can significantly improve model performance on specific tasks (like digit "9" recognition) while adding only 0.242% more parameters to the original model. This perfectly illustrates why PEFT techniques, particularly LoRA, are becoming increasingly important in the AI landscape.
Key takeaways from our exploration:
- PEFT techniques like LoRA make fine-tuning accessible even with limited computational resources
- By focusing on crucial parameters, we can achieve significant improvements in task-specific performance
- The original model weights remain unchanged, allowing for multiple task-specific adaptations
- The implementation requires minimal code changes to existing architectures
The future of AI model adaptation lies in such efficient techniques that balance performance with resource utilization. As models continue to grow in size and complexity, PEFT approaches will become even more crucial for practical applications.
GitHub Repository:
I have created a project in which you can fine-tune ResNet on your custom dataset using the technique we have just learned.
For the complete code and implementation details, visit: github.com/yourusername/peft-lora-guide
Published via Towards AI