Fine-Tune LLMs with Unsloth
Last Updated on October 31, 2024 by Editorial Team
Author(s): Barhoumi Mosbeh
Originally published on Towards AI.
Why Fine-Tune When We Have RAG?
It's a question I see a lot: with RAG (Retrieval-Augmented Generation) becoming increasingly popular, why bother with fine-tuning at all? While RAG is fantastic for many use cases, fine-tuning still has its place in the ML toolkit.
Here's why: fine-tuning allows you to fundamentally alter how your model "thinks" about specific domains. While RAG provides context at inference time, fine-tuning builds domain expertise directly into the model's weights. This is particularly powerful when you need:
- Consistent domain-specific behavior
- Faster inference (no need to search through external documents)
- Specialized knowledge that's difficult to capture in reference documents
Plus, there's a compelling cost argument: you can fine-tune smaller models for specific tasks and achieve performance comparable to much larger models at a fraction of the hosting cost.
I found this Reddit discussion on the topic helpful; have a look (the link is in the Resources section below).
Enter Unsloth: Making Fine-Tuning Accessible
Training times have always been one of the biggest barriers to fine-tuning. That's where Unsloth comes in: a new optimization framework that claims to make LLM training up to 30x faster.
The secret to Unsloth's efficiency lies in deep optimization. While PyTorch and Transformers are built for flexibility across different architectures, Unsloth takes a more focused approach: it combines techniques like QLoRA and custom Triton kernels with architecture-specific optimizations to squeeze maximum performance out of the training process.
Hands-on: Fine-Tuning for SQL Generation
Let's put this into practice by fine-tuning a model to generate SQL queries. We'll use Llama-3.2-3B, a 3-billion-parameter model that strikes a good balance between capability and resource requirements.
First, it's important to find a good dataset. Choosing the right data is crucial because a small language model trained on data relevant to the task at hand can actually outperform much larger general-purpose models. Our goal is a small, fast LLM that generates SQL queries from table context.
One of the most significant datasets for this purpose is Synthetic Text to SQL, which contains over 105,000 records with columns for the SQL prompt, the table context, the generated SQL, an explanation, a complexity rating, and more.
Here is the link to the dataset: Synthetic Text to SQL (https://huggingface.co/datasets/gretelai/synthetic_text_to_sql)
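If you'd like a quick look at the schema before writing any formatting code, here is a minimal sketch (the printed column names are the ones we rely on later; the dataset has several more):
from datasets import load_dataset

# Quick peek at the dataset's splits and columns before any formatting.
ds_preview = load_dataset("gretelai/synthetic_text_to_sql")
print(ds_preview)                        # splits and row counts
print(ds_preview["train"].column_names)  # includes sql_prompt, sql_context, sql, sql_explanation, ...
print(ds_preview["train"][0]["sql_prompt"])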
Setting Up the Environment
First, let's install the necessary packages. We'll need to manage our PyTorch installation carefully:
%%capture
!pip install pip3-autoremove
!pip-autoremove torch torchvision torchaudio -y
!pip install "torch==2.4.0" "xformers==0.0.27.post2" triton torchvision torchaudio
!pip install "unsloth[kaggle-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install datasets
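Before loading anything, it's worth confirming that the pinned PyTorch build is in place and that a CUDA GPU is visible, since everything below assumes one. A quick check:
import torch

print(torch.__version__)  # should report 2.4.0 from the pinned install above
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # e.g. a T4 on free Colab/Kaggle tiers
else:
    print("No CUDA GPU detected - the 4-bit fine-tuning below requires one")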
Loading the Model
Now we'll load our base model using Unsloth's optimized loader:
from unsloth import FastLanguageModel
import torch
max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Llama-3.2-3B-bnb-4bit",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)
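Optionally, you can sanity-check the base model before any training. Here is a minimal sketch using Unsloth's inference mode (the prompt is just an arbitrary example):
# Optional sanity check: see what the raw base model produces before any training.
FastLanguageModel.for_inference(model)  # enable Unsloth's fast inference mode

inputs = tokenizer(
    "Write a SQL query that counts the rows in a table named users.",
    return_tensors="pt",
).to("cuda")
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

FastLanguageModel.for_training(model)  # switch back before fine-tuning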
Setting Up PEFT
We'll load the PEFT (Parameter-Efficient Fine-Tuning) model, which uses LoRA (Low-Rank Adaptation) adapters. If you're not familiar with these terms, don't worry: LoRA adapters let us update only 1-10% of the model's parameters during fine-tuning. Without them, we'd have to retrain the entire model, which would be far more time-consuming, computationally intensive, and expensive. Unsloth provides the recommended settings below for optimal performance; we'll use the default configuration for this tutorial, but feel free to explore and adjust these parameters based on your needs.
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,  # Choose any number > 0! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0,  # Supports any, but = 0 is optimized
    bias = "none",     # Supports any, but = "none" is optimized
    use_gradient_checkpointing = "unsloth",  # 4x longer contexts auto supported!
    random_state = 3407,
    use_rslora = False,   # We support rank stabilized LoRA
    loftq_config = None,  # And LoftQ
)
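To see how little of the model LoRA actually touches, you can print the trainable parameter count; get_peft_model returns a PEFT-style model, so the standard PEFT helper should work:
# Optional: confirm how small a fraction of the 3B weights LoRA trains.
model.print_trainable_parameters()
# Prints something like: trainable params: ... || all params: ... || trainable%: ...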
Data
Now, this is where things can get a little tricky depending on which dataset you're using. Every dataset has its own schema, but they all need to be converted into a single format the language model can learn from. For this tutorial we'll use the Alpaca prompt template, a common choice for instruction fine-tuning. It looks like this:
Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
[The task or question you want the model to perform/answer]
### Input:
[Additional context or information needed to complete the task. This can be empty if the instruction is self-contained]
### Response:
[The expected output or answer you want the model to learn]
For our SQL project, we're specifically interested in four columns:
- The SQL query prompt (sql_prompt), which becomes the instruction
- The table context (sql_context), which becomes the input
- The generated SQL code (sql)
- The explanation of the code (sql_explanation)
The last two are concatenated to form the response.
from datasets import Dataset, load_dataset
# Define the prompt template with variables matching the loop content
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Input:
{input}
### Response:
{response}"""
# Set the EOS token (assuming the tokenizer is already defined)
EOS_TOKEN = tokenizer.eos_token
# Formatting function to apply the prompt template to the dataset
def formatting_prompts_func(examples):
    company_databases = examples["sql_context"]
    prompts = examples["sql_prompt"]
    sqls = examples["sql"]
    explanations = examples["sql_explanation"]
    texts = []
    for company_database, prompt, sql, explanation in zip(company_databases, prompts, sqls, explanations):
        # Substitute the correct placeholders
        text = alpaca_prompt.format(
            instruction=prompt,
            input=company_database,
            response=sql + " " + explanation
        ) + EOS_TOKEN
        texts.append(text)
    return {"text": texts}  # Ensure the formatted text is returned as a "text" field
# Load dataset and map formatting function to add prompts
ds = load_dataset("gretelai/synthetic_text_to_sql")
formatted_ds = ds.map(formatting_prompts_func, batched=True) # Apply formatting
# Select the 'train' split from the formatted dataset
train_dataset = formatted_ds['train']
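It's worth printing one formatted record to confirm the template and EOS token were applied as expected:
# Sanity check: inspect the first formatted training example (truncated for readability).
print(train_dataset[0]["text"][:500])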
Training Configuration
There are a lot of parameters here, and each one could be described at length. For example, max_steps sets how many training steps to perform; seed fixes the random number generator so results are reproducible; and warmup_steps gradually increases the learning rate at the start of training.
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported
from trl import SFTTrainer
# Trainer setup
trainer = SFTTrainer(
    model=model,                    # Ensure model is defined
    tokenizer=tokenizer,            # Ensure tokenizer is defined
    train_dataset=train_dataset,    # The 'train' split from formatted_ds
    dataset_text_field="text",      # The field we created with formatted prompts
    max_seq_length=max_seq_length,  # Ensure max_seq_length is defined
    dataset_num_proc=2,
    packing=False,                  # packing=True can make training ~5x faster for short sequences
    args=TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        warmup_steps=5,
        max_steps=60,
        learning_rate=2e-4,
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        output_dir="outputs",
        report_to="none",  # Disable WANDB logging
    ),
)
Now that we have everything set up, let's run the training. And that's it!
trainer_stats = trainer.train()
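Once training finishes, you'll typically want to save the LoRA adapters and try the model on a fresh prompt. Here is a minimal sketch that reuses the alpaca_prompt template from earlier (the question and schema are made-up examples):
# Save only the lightweight LoRA adapter weights plus the tokenizer.
model.save_pretrained("lora_model")
tokenizer.save_pretrained("lora_model")

# Switch to inference mode and reuse the Alpaca template, leaving the response empty.
FastLanguageModel.for_inference(model)
prompt = alpaca_prompt.format(
    instruction="How many orders were placed in 2023?",      # hypothetical question
    input="CREATE TABLE orders (id INT, order_date DATE);",  # hypothetical schema
    response="",  # the model generates this part
)
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))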
Resources
- Code on Colab: https://colab.research.google.com/drive/1BHuj-8mA8lvxJQyQsCyp-7vlGkXuRxKY?usp=sharing
- Unsloth on Hugging Face: https://huggingface.co/unsloth
- Reddit discussion on Unsloth: https://www.reddit.com/r/LocalLLaMA/comments/1ar7e4m/unsloth_whats_the_catch_seems_too_good_to_be_true/