Empowering Computer Vision with Zero-Shot Learning in the Data-Centric Small Data Age
Last Updated on July 17, 2023 by Editorial Team
Author(s): Luhui Hu
Originally published on Towards AI.
Inside zero-shot learning from OpenAI CLIP, Microsoft RegionCLIP, and Meta SAM
The era of big data has provided Machine Learning (ML) and Computer Vision (CV) with ample fuel for training increasingly complex models. However, in many practical scenarios, we're faced with the challenge of data scarcity. Situations where we have only a few, or even no, labeled examples for a particular class are not uncommon. Fortunately, advancements in Zero-Shot and Few-Shot Learning are bringing new solutions to the table. Inspired by models like OpenAI's CLIP, Microsoft's RegionCLIP, and Meta's Segment Anything Model (SAM), we explore how zero-shot learning is revolutionizing the field of CV.
A Leap into Zero-shot Learning
Zero-Shot Learning aims to classify unseen categories, i.e., those not present during training. The key is to leverage auxiliary information that describes the relationships between different classes. For instance, a learned joint embedding space of images and text allows the model to leverage descriptions or attributes of unseen classes.
OpenAI's CLIP model illustrates this principle perfectly. By learning to associate images and text in a shared embedding space, CLIP can match images of unseen objects against natural-language descriptions of their classes, showcasing strong zero-shot performance.
The first letter of CLIP stands for Contrastive. Contrastive learning is a self-supervised learning method that allows models to recognize patterns in data by identifying similarities and differences, without needing labeled data. It's like teaching the model to understand the underlying structure of the data, without explicit instructions.
Its relationship with zero-shot learning, a technique for making predictions about unseen classes, is complementary: contrastive pre-training first extracts useful features from the data, and zero-shot inference then uses those features to make predictions for new, unseen classes. Together, the two make a model more efficient and less dependent on large amounts of labeled data, as the sketch below illustrates.
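As a minimal sketch of zero-shot classification with the openai/CLIP package, the snippet below encodes an image and a handful of natural-language class descriptions into the shared embedding space and scores them against each other; the image path and class names are illustrative placeholders.

import torch
import clip  # pip install git+https://github.com/openai/CLIP.git
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# Candidate classes described in natural language (placeholders)
class_names = ["a dog", "a cat", "an axolotl"]
text = clip.tokenize([f"a photo of {c}" for c in class_names]).to(device)

# Encode the image and score it against every class description
image = preprocess(Image.open("photo.jpg")).unsqueeze(0).to(device)
with torch.no_grad():
    logits_per_image, _ = model(image, text)
    probs = logits_per_image.softmax(dim=-1)

print(dict(zip(class_names, probs[0].tolist())))

Because the class set is just a list of strings, swapping in entirely new categories requires no retraining, only new text prompts.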
RegionCLIP: Zero-Shot Object Detection
Microsoft's RegionCLIP, a hybrid of Faster R-CNN (a popular object detection framework) and CLIP, offers an interesting case of zero-shot learning applied to object detection. In RegionCLIP, region proposals from Faster R-CNN are fed into CLIP, and the output probabilities are then used for the detection of unseen classes.
A simplified Python sketch of such a pipeline feeds region crops from a Faster R-CNN model into CLIP's image encoder (using the openai/CLIP package; the image path is a placeholder):
import torch
import clip  # the openai/CLIP package
from PIL import Image
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.transforms.functional import to_tensor

# Initialize models
device = "cuda" if torch.cuda.is_available() else "cpu"
detection_model = fasterrcnn_resnet50_fpn(pretrained=True).eval().to(device)
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)

# Forward an image through the detection model (file name is a placeholder)
image = Image.open("example.jpg").convert("RGB")
with torch.no_grad():
    detections = detection_model([to_tensor(image).to(device)])

# Crop each region proposal and embed it with CLIP's image encoder
region_proposals = detections[0]["boxes"]
crops = [image.crop(tuple(box.int().tolist())) for box in region_proposals]
clip_inputs = torch.stack([clip_preprocess(c) for c in crops]).to(device)
with torch.no_grad():
    clip_embeddings = clip_model.encode_image(clip_inputs)
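Continuing the sketch above, the region embeddings can then be scored against text embeddings of class names that never appeared in detection training, which is the core of the zero-shot detection idea; the class names below are placeholders.

# Encode (possibly unseen) class names as text; names are placeholders
class_names = ["a traffic cone", "a skateboard", "a fire hydrant"]
text_tokens = clip.tokenize([f"a photo of {c}" for c in class_names]).to(device)
with torch.no_grad():
    text_embeddings = clip_model.encode_text(text_tokens)

# Cosine similarity between each region and each class description
region_feats = clip_embeddings / clip_embeddings.norm(dim=-1, keepdim=True)
text_feats = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
scores = (region_feats @ text_feats.T).softmax(dim=-1)  # per-region class probabilities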
Segment Anything (SAM): Zero-Shot Learning for Image Segmentation
Meta's Segment Anything Model (SAM) applies zero-shot learning to the image segmentation task. Designed to be promptable, SAM can generate valid segmentation masks in real time given a segmentation prompt, making it a fantastic tool for interactive usage scenarios.
The model has three main components: an image encoder for computing an image embedding, a prompt encoder for embedding prompts (points, boxes, and masks), and a mask decoder for predicting segmentation masks. The abridged class skeleton below, taken from the segment-anything repository, reflects this structure:
# Abridged from segment_anything/modeling/sam.py
import torch
from torch import nn
from torch.nn import functional as F
from typing import Any, Dict, List, Tuple

from .image_encoder import ImageEncoderViT
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder


class Sam(nn.Module):
    mask_threshold: float = 0.0
    image_format: str = "RGB"

    def __init__(
        self,
        image_encoder: ImageEncoderViT,  # ViT backbone that computes the image embedding
        prompt_encoder: PromptEncoder,   # embeds point, box, and mask prompts
        mask_decoder: MaskDecoder,       # predicts masks from image and prompt embeddings
        pixel_mean: List[float] = [123.675, 116.28, 103.53],
        pixel_std: List[float] = [58.395, 57.12, 57.375],
    ) -> None:
        ...

    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        # Removes padding and upscales predicted masks to the original image size
        ...

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        # Normalizes pixel values and pads the input to a square
        ...
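For interactive use, the segment-anything package wraps these components behind a SamPredictor API. A minimal prompt-based sketch might look like the following, assuming a SAM checkpoint has been downloaded from the repository and using a placeholder image path and point coordinates.

import numpy as np
from PIL import Image
from segment_anything import sam_model_registry, SamPredictor

# Load a pretrained SAM checkpoint (downloaded separately from the repo)
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
predictor = SamPredictor(sam)

# SAM expects an HxWx3 uint8 RGB array; the file name is a placeholder
image = np.array(Image.open("photo.jpg").convert("RGB"))
predictor.set_image(image)

# Prompt with a single foreground point (coordinates are placeholders)
masks, scores, logits = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),   # 1 = foreground, 0 = background
    multimask_output=True,        # return several candidate masks
)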
With its "data engine" approach, SAM was trained on the SA-1B dataset, which collected over one billion masks, setting a new record in the domain of image segmentation.
In a Nutshell
Zero-Shot Learning in the data-centric and small-data world brings about new possibilities for Computer Vision, alleviating the traditional dependency on large labeled datasets. Through the introduction of models like CLIP, RegionCLIP, and SAM, engineers and scientists can now leverage these techniques for various CV tasks, including image classification, object detection, and image segmentation, even when labeled data is scarce. As the field continues to advance, the adaptability and performance of these models are only set to improve.
Resources
- OpenAI CLIP GitHub: https://github.com/openai/CLIP
- Meta SAM GitHub: https://github.com/facebookresearch/segment-anything
- Microsoft RegionCLIP: https://github.com/microsoft/RegionCLIP
- Federated Zero-Shot Learning for Visual Recognition: https://arxiv.org/abs/2209.01994
- Zero-Shot Learning β A Comprehensive Evaluation of the Good, the Bad and the Ugly: https://arxiv.org/abs/1707.00600
Published via Towards AI