How to Fine-Tune Meta SAM

Last Updated on August 1, 2023 by Editorial Team

Author(s): Luhui Hu

Originally published on Towards AI.

Photo by Wolfgang Hasselmann on Unsplash

Checkpoint-based fine-tuning of Meta pre-trained SAM for domain images

Transformers and foundation models have reinvented AI and ushered in a new AI era of unifying NLP and CV, two main AI domains. Meta’s two large-scale models are leading the way in the current AI open-source world: LLaMA for LLM and SAM for CV.

SAM stands for Segment Anything Model, a state-of-the-art, transformer-based (ViT backbone) foundation model for image segmentation. It can segment objects in images it has never seen, with zero-shot generalization. Segmentation is a core CV task and an essential building block for image classification and object detection.

Leveraging and fine-tuning LLaMA and SAM has become a trend. This article introduces how to fine-tune Meta's pre-trained SAM from its checkpoint on new images. SAM is fantastic, but it is still a research or demo project, so fine-tuning is crucial for improving accuracy and performance on domain images.

Setup

First, we need to install all the necessary dependencies as specified in the Meta SAM readme as follows.

# Run these in a notebook cell (Jupyter or Colab); in a shell, drop the leading "!"
!pip install git+https://github.com/facebookresearch/segment-anything.git
!pip install opencv-python pycocotools matplotlib onnxruntime onnx

To make the process easier to follow, this article provides three implementations of how to fine-tune SAM:

  1. Fine-tuning SAM for new images based on the foundation model checkpoint
  2. Fine-tuning SAM for new images with masks to improve domain-specific accuracy
  3. Fine-tuning SAM according to SAM's preferred optimizer and dataset formats

First Fine-Tuning Example Based on the Checkpoint

Training SAM from scratch is very expensive: the SAM paper reports 256 A100 GPUs for 68 hours, roughly three days. Retraining the entire model each time is impractical, so training or fine-tuning from a pre-trained foundation model checkpoint is strongly recommended. This is also the reference baseline for transfer learning.

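The segment_anything package provides sam_model_registry, a dictionary that maps each model type to a builder function that loads the corresponding checkpoint. The code below illustrates how to fine-tune a foundation model like SAM starting from that checkpoint.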

from segment_anything import SamPredictor, sam_model_registry
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import torch
import os

# Loading the model based on checkpoint
sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>")
predictor = SamPredictor(sam)

# Define dataset
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Sort so the index-to-file mapping is stable across runs
        self.images = sorted(os.listdir(self.root_dir))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir, self.images[idx])
        # Force three channels; grayscale or RGBA files would not match the encoder input
        image = Image.open(img_name).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image

# Define transforms
transform = transforms.Compose([
    # SAM's image encoder works on 1024x1024 inputs. A square resize distorts the
    # aspect ratio; segment_anything's ResizeLongestSide is the alternative that keeps it
    transforms.Resize((1024, 1024)),
    # Keep pixel values in 0-255; sam.preprocess() applies SAM's own normalization
    transforms.PILToTensor(),
])

# Load custom dataset
dataset = CustomDataset(root_dir='<path_to_images>', transform=transform)

# Create a DataLoader
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

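# Fine-tuning the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
sam.to(device)  # SamPredictor stores the model as .model; the sam object is used directly here
sam.train()

# Define loss function and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(sam.parameters(), lr=0.001, momentum=0.9)

for epoch in range(2):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(dataloader, 0):
        inputs = data.to(device).float()

        optimizer.zero_grad()

        # Sam.forward() runs under torch.no_grad() and returns thresholded masks,
        # so the sub-modules are called directly to keep gradients and raw logits
        embeddings = sam.image_encoder(sam.preprocess(inputs))
        sparse, dense = sam.prompt_encoder(points=None, boxes=None, masks=None)
        logits = []
        for embedding in embeddings:  # the mask decoder handles one image at a time
            low_res_logits, _ = sam.mask_decoder(
                image_embeddings=embedding.unsqueeze(0),
                image_pe=sam.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse,
                dense_prompt_embeddings=dense,
                multimask_output=False,
            )
            logits.append(low_res_logits)
        logits = torch.cat(logits, dim=0)  # (B, 1, 256, 256) foreground logits

        # There are no labels here. As a placeholder (an assumption of this sketch,
        # not part of SAM), the model's own thresholded predictions act as pseudo-labels
        outputs = torch.cat([torch.zeros_like(logits), logits], dim=1)  # background/foreground logits
        targets = (logits.detach() > sam.mask_threshold).long().squeeze(1)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 2000 == 1999:  # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')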

In the above, "<model_type>" and "<path/to/checkpoint>" need to be replaced with the model type (one of the sam_model_registry keys: "vit_h", "vit_l", "vit_b", or "default") and the path to the matching checkpoint file. Also, '<path_to_images>' is the path to the image dataset.

Here, we fine-tune the entire model. In some cases, we may want to freeze the weights of some layers (typically the earlier layers) and fine-tune only the later ones. This can be done by setting the requires_grad attribute of the parameters we want to freeze to False.
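As a minimal sketch of that idea, the helper below freezes every parameter under a named sub-module and passes only the remaining parameters to the optimizer. The freeze_submodules helper and the TinySam stand-in are hypothetical and exist only so the example runs without a checkpoint; the three attribute names mirror the top-level sub-modules of the real Sam class (image_encoder, prompt_encoder, mask_decoder).

```python
from typing import Sequence

import torch
from torch import nn


def freeze_submodules(model: nn.Module, names: Sequence[str]) -> None:
    """Disable gradients for every parameter under the given attribute names."""
    for name in names:
        for param in getattr(model, name).parameters():
            param.requires_grad = False


class TinySam(nn.Module):
    """Hypothetical stand-in that only shares Sam's top-level attribute names."""

    def __init__(self) -> None:
        super().__init__()
        self.image_encoder = nn.Linear(8, 8)
        self.prompt_encoder = nn.Linear(8, 8)
        self.mask_decoder = nn.Linear(8, 1)


model = TinySam()
freeze_submodules(model, ["image_encoder", "prompt_encoder"])

# Hand only the parameters that still require gradients to the optimizer.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.001, momentum=0.9)
```

With a real SAM model, the same call on the sam object would leave only the mask decoder trainable, which is a common way to cut memory use when the image encoder is large.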

The example uses CrossEntropyLoss purely for illustration; it may not be the best choice for a given task. Depending on the use case, a different loss function may work better. For a segmentation task, a loss that measures mask overlap, such as the Dice loss or the Jaccard (Intersection over Union) loss, is usually more suitable.
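To make the overlap-based losses concrete, here is a small sketch of soft Dice and soft IoU losses for binary masks. The function names, the smoothing term eps, and the toy tensors are assumptions for illustration; they are not taken from the SAM codebase.

```python
import torch


def dice_loss(logits: torch.Tensor, targets: torch.Tensor, eps: float = 1.0) -> torch.Tensor:
    """Soft Dice loss; logits and targets have shape (B, 1, H, W)."""
    probs = torch.sigmoid(logits).flatten(1)
    targets = targets.flatten(1)
    intersection = (probs * targets).sum(dim=1)
    denominator = probs.sum(dim=1) + targets.sum(dim=1)
    dice = (2.0 * intersection + eps) / (denominator + eps)
    return 1.0 - dice.mean()


def iou_loss(logits: torch.Tensor, targets: torch.Tensor, eps: float = 1.0) -> torch.Tensor:
    """Soft Jaccard (IoU) loss with the same input shapes as dice_loss."""
    probs = torch.sigmoid(logits).flatten(1)
    targets = targets.flatten(1)
    intersection = (probs * targets).sum(dim=1)
    union = probs.sum(dim=1) + targets.sum(dim=1) - intersection
    return 1.0 - ((intersection + eps) / (union + eps)).mean()


# Toy mask: the top half of a 4x4 image is foreground.
target = torch.zeros(1, 1, 4, 4)
target[..., :2, :] = 1.0
good_logits = (target * 2 - 1) * 10.0  # confident and correct
bad_logits = -good_logits              # confident and wrong
```

Both losses are close to 0 for good_logits and close to 1 for bad_logits, and either function can replace the criterion in the training loop as long as the model output is a single-channel logit map.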

So far, the model has been fine-tuned with unlabeled data (i.e., in an unsupervised manner). Next, we add labels for the images by modifying the CustomDataset class to load them and adjusting the loss computation accordingly.

Domain-Specific Fine-Tuning with Masks

For a segmentation task with labeled data, we need to make a few adjustments:

  1. The dataset should return both images and their corresponding segmentation masks.
  2. The loss function should be suitable for segmentation tasks, such as the Dice loss or the Jaccard/Intersection over Union loss. For simplicity, we’ll use the BCEWithLogitsLoss, a combination of a Sigmoid layer and the BCELoss (Binary Cross Entropy Loss) in one single class.
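The "Sigmoid layer plus BCELoss in one class" description can be checked directly. The short sketch below uses random stand-in tensors (not SAM outputs) and computes the loss both ways.

```python
import torch

torch.manual_seed(0)
logits = torch.randn(2, 1, 4, 4)                     # raw mask logits, as a decoder would emit
targets = torch.randint(0, 2, (2, 1, 4, 4)).float()  # binary ground-truth masks

# Fused version: applies the sigmoid internally in a numerically stable way.
fused = torch.nn.BCEWithLogitsLoss()(logits, targets)

# Two-step version: explicit sigmoid followed by plain binary cross entropy.
two_step = torch.nn.BCELoss()(torch.sigmoid(logits), targets)
```

The two values agree up to floating-point error, so the model should output raw logits (no sigmoid) when BCEWithLogitsLoss is the criterion.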

Below is the adjusted Python code for this requirement:

from segment_anything import SamPredictor, sam_model_registry
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import torch
import os

# Loading the model
sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>")
predictor = SamPredictor(sam)

# Define dataset with masks
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, mask_dir, transform=None, mask_transform=None):
        self.root_dir = root_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.mask_transform = mask_transform
        # Sort so the index-to-file mapping is stable across runs
        self.images = sorted(os.listdir(self.root_dir))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir, self.images[idx])
        mask_name = os.path.join(self.mask_dir, self.images[idx])

        image = Image.open(img_name).convert("RGB")
        mask = Image.open(mask_name).convert("L")  # single-channel mask

        if self.transform:
            image = self.transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)

        return image, mask

# Define transforms
transform = transforms.Compose([
    transforms.Resize((1024, 1024)),  # SAM's image encoder works on 1024x1024 inputs
    transforms.PILToTensor(),  # keep 0-255 values; sam.preprocess() normalizes them
])
# Masks get their own transform: nearest-neighbor resizing keeps label values
# intact, and no image normalization is applied to them
mask_transform = transforms.Compose([
    transforms.Resize((1024, 1024), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.PILToTensor(),
])

# Load custom dataset
dataset = CustomDataset(root_dir='<path_to_images>', mask_dir='<path_to_masks>',
                        transform=transform, mask_transform=mask_transform)

# Create a DataLoader
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

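# Fine-tuning the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
sam.to(device)  # SamPredictor stores the model as .model; the sam object is used directly here
sam.train()

# Define loss function and optimizer
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(sam.parameters(), lr=0.001, momentum=0.9)

for epoch in range(2):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(dataloader, 0):
        inputs, labels = data
        inputs = inputs.to(device).float()
        labels = (labels.to(device) > 0).float()  # non-zero mask pixels are foreground

        optimizer.zero_grad()

        # Sam.forward() runs under torch.no_grad() and returns thresholded masks,
        # so the sub-modules are called directly to keep gradients and raw logits
        embeddings = sam.image_encoder(sam.preprocess(inputs))
        sparse, dense = sam.prompt_encoder(points=None, boxes=None, masks=None)
        logits = []
        for embedding in embeddings:  # the mask decoder handles one image at a time
            low_res_logits, _ = sam.mask_decoder(
                image_embeddings=embedding.unsqueeze(0),
                image_pe=sam.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse,
                dense_prompt_embeddings=dense,
                multimask_output=False,
            )
            logits.append(low_res_logits)

        # Upscale the 256x256 logits to the resolution of the ground-truth masks
        outputs = sam.postprocess_masks(
            torch.cat(logits, dim=0),
            input_size=tuple(inputs.shape[-2:]),
            original_size=tuple(labels.shape[-2:]),
        )
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 2000 == 1999:  # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')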

Again, "<model_type>" and "<path/to/checkpoint>" need to be replaced with the appropriate model type and path to the checkpoint file, and '<path_to_images>' and '<path_to_masks>' are the paths to the image dataset and mask dataset, respectively.

Here, we assume the masks are in the same format as the images and are named identically to their corresponding images. If the masks are in a different format or are named differently, the __getitem__ method of the CustomDataset class needs to be adjusted accordingly.
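For instance, if the masks were stored as PNG files with a "_mask" suffix (a hypothetical naming scheme, chosen only for illustration), the lookup could be isolated in a small helper like the one below.

```python
from pathlib import Path


def mask_path_for(image_path: str, mask_dir: str,
                  suffix: str = "_mask", ext: str = ".png") -> Path:
    """Map 'images/cat.jpg' to 'masks/cat_mask.png' (hypothetical naming scheme)."""
    stem = Path(image_path).stem  # file name without directory or extension
    return Path(mask_dir) / f"{stem}{suffix}{ext}"


example = mask_path_for("images/cat.jpg", "masks")
```

Inside __getitem__, the mask_name line would then call this helper instead of joining mask_dir with the image file name.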

It also assumes the masks are single-channel (grayscale) images where the pixel value indicates the class of that pixel. If the masks are in a different format, we need to adjust the code accordingly.

The BCEWithLogitsLoss is suitable for binary segmentation tasks (i.e., tasks where each pixel can belong to one of two classes). If we have a multi-class segmentation task (i.e., each pixel can belong to one of more than two classes), we need to use a different loss function, such as the CrossEntropyLoss. In this case, we also need to ensure the masks are in the correct format (a single-channel image where the pixel value indicates the class of that pixel).
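The shape and dtype requirements for the multi-class case are easy to get wrong, so here is a small sketch with random stand-in tensors. The class count of 3 is arbitrary, and the logits are not SAM outputs: a model used this way would need to emit one channel per class.

```python
import torch

num_classes = 3
torch.manual_seed(0)
logits = torch.randn(2, num_classes, 4, 4)        # (B, C, H, W): one channel per class
masks = torch.randint(0, num_classes, (2, 4, 4))  # (B, H, W): integer class index per pixel

criterion = torch.nn.CrossEntropyLoss()
loss = criterion(logits, masks)

# A mask that arrives as a (B, 1, H, W) float tensor must be squeezed and cast first.
loaded = masks.unsqueeze(1).float()
loss_from_loaded = criterion(logits, loaded.squeeze(1).long())
```

The key differences from the binary case are that the target has no channel dimension and holds integer class indices rather than floating-point values.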

SAM Specific Fine-Tuning

With the above practices, we can try to fine-tune Meta SAM with prompt encoding for domain-specific images. This is closer to what SAM supports.
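A minimal sketch of one such prompt-based training step is shown below. It is an illustration under stated assumptions, not the original listing: the train_step function name, the choice to freeze the image and prompt encoders, and the BCE loss are all assumptions, while the sub-module calls (image_encoder, prompt_encoder, mask_decoder, postprocess_masks) follow the Sam class in segment_anything.

```python
import torch
import torch.nn.functional as F


def train_step(sam, optimizer, image, box, gt_mask, input_size, original_image_size):
    """Run one box-prompted fine-tuning step and return the loss value.

    image:   (1, 3, H, W) float tensor, already resized and normalized
             (for example with ResizeLongestSide and sam.preprocess)
    box:     (1, 4) float tensor, XYXY pixel coordinates in the resized image
    gt_mask: (1, 1, H0, W0) float tensor with values in {0, 1}
    """
    # Assumption of this sketch: the encoders stay fixed, only the decoder learns.
    with torch.no_grad():
        image_embedding = sam.image_encoder(image)
        sparse, dense = sam.prompt_encoder(points=None, boxes=box, masks=None)

    low_res_logits, _ = sam.mask_decoder(
        image_embeddings=image_embedding,
        image_pe=sam.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse,
        dense_prompt_embeddings=dense,
        multimask_output=False,
    )
    # Upscale the low-resolution logits back to the original image size.
    logits = sam.postprocess_masks(low_res_logits, input_size, original_image_size)

    loss = F.binary_cross_entropy_with_logits(logits, gt_mask)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

A caller would build the optimizer over sam.mask_decoder.parameters() only (AdamW is the optimizer reported in the SAM paper), call train_step once per dataset record, and save sam.state_dict() when the loop finishes.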

In the training code for this approach, input_size and original_image_size are the dimensions of the input images and original images, respectively. 'custom_dataset' is the actual dataset that yields images, bounding boxes (or other prompts), and ground truth masks. It is a Python list of records laid out like SAM's SA-1B annotation files, for example:

custom_dataset = [
    {
        "image": {
            "image_id": 0,
            "width": 1024,
            "height": 768,
            "file_name": "path_to_images/image0.jpg",
        },
        "annotations": [
            {
                "id": 0,
                "segmentation": rle0,  # Replace with the actual RLE for the mask
                "bbox": [100, 100, 200, 200],
                "area": 40000,
                "predicted_iou": 0.85,
                "stability_score": 0.95,
                "crop_box": [50, 50, 150, 150],
                "point_coords": [[125, 125]],
            },
        ],
    },
    {
        "image": {
            "image_id": 1,
            "width": 1024,
            "height": 768,
            "file_name": "path_to_images/image1.jpg",
        },
        "annotations": [
            {
                "id": 1,
                "segmentation": rle1,  # Replace with the actual RLE for the mask
                "bbox": [200, 200, 300, 300],
                "area": 40000,
                "predicted_iou": 0.85,
                "stability_score": 0.95,
                "crop_box": [150, 150, 250, 250],
                "point_coords": [[225, 225]],
            },
        ],
    },
    # More image entries...
]
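One detail worth a short sketch: assuming these records follow the SA-1B convention, "bbox" is stored as [x, y, width, height], while SAM's box prompts use corner coordinates (XYXY). The helper name below is hypothetical.

```python
def xywh_to_xyxy(bbox):
    """Convert [x, y, width, height] to [x_min, y_min, x_max, y_max]."""
    x, y, w, h = bbox
    return [x, y, x + w, y + h]


record_bbox = [100, 100, 200, 200]  # "bbox" of the first annotation above
box_prompt = xywh_to_xyxy(record_bbox)
```

Under the same assumption, the "segmentation" field holds a COCO-style RLE, which pycocotools.mask.decode turns into a binary mask array for the loss computation.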

After the training loop, the code saves the fine-tuned model weights to a file named ‘fine_tuned_sam.pth’. We can load these weights later for inference on similar data or further fine-tuning.
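The save and reload round trip can be sketched as follows. A tiny nn.Linear stands in for the fine-tuned SAM model (an assumption so the example runs without a checkpoint), and a temporary directory replaces the working directory.

```python
import os
import tempfile

import torch
from torch import nn

model = nn.Linear(4, 2)  # hypothetical stand-in for the fine-tuned SAM model

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "fine_tuned_sam.pth")
    # Save only the weights (the state_dict), not the whole pickled module.
    torch.save(model.state_dict(), path)

    # Later: rebuild the same architecture and load the saved weights into it.
    restored = nn.Linear(4, 2)
    restored.load_state_dict(torch.load(path, map_location="cpu"))
```

Because the file holds a plain state_dict, it should also be usable as the checkpoint argument of sam_model_registry for the matching model type, since the builder loads a state dict from that path.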

What’s Next

This article describes how to fine-tune SAM and other foundation models, but it is just a starting point. For efficiency, we also need to consider parallel training across distributed GPUs. It is equally important to tune the hyperparameters and choose the right loss function.

If you are interested in fine-tuning foundation models, please feel free to leave comments, insights, and/or questions here, or reach out to Luhui on LinkedIn.

References

  1. Meta SAM GitHub: https://github.com/facebookresearch/segment-anything
  2. Meta SAM Blog with links to demo and paper: https://ai.facebook.com/blog/segment-anything-foundation-model-image-segmentation/
  3. HuggingFace discussions: https://github.com/huggingface/transformers/issues/22592
  4. Exploring Plain Vision Transformer Backbones for Object Detection: https://arxiv.org/abs/2203.16527


Published via Towards AI
