Build and Train a Vision Transformer from Scratch
Last Updated on August 1, 2023 by Editorial Team
Author(s): Mikhail Kravets
Originally published on Towards AI.
A few years ago, it was hard to imagine what a transformer even was; today, it is hard to imagine a modern neural network that doesn't use one.
In this tutorial, we'll build a Vision Transformer using PyTorch and PyTorch Lightning. Along with the ViT model, you will also see how to organize your code in a well-structured and efficient manner.
All the code of the tutorial can be found in the vision_transformer repository.
Overview
Let's have a quick theory overview before we proceed to the practical part of the tutorial.
Transformer & self-attention
The history of transformers began with the paper Attention Is All You Need [2]. Initially, they were used for machine translation but were later extended to solve a wide range of tasks. Jay Alammar explains transformers in detail in his article The Illustrated Transformer [5].
In the diagram below, you can see the architecture of the transformer network for the machine translation task.
A transformer consists of Encoder and Decoder blocks. We only need the encoders for the Vision Transformer model.
The Encoder (and likewise the Decoder) is based on a mechanism called self-attention.
The Multi-Head Attention block calculates an importance (or attention) score for each element in a sequence. For example, let's take the sentence
The animal didn't cross the street because it was too tired. [5]
One attention score vector for the element `it` may look like the one shown in the figure below.
Encoder and Decoder blocks are nearly identical, except for one small difference. An Encoder can attend to all elements in a sequence when calculating attention scores; you can see encoder attention in figure 3, and the BERT model is an example of an encoder-only architecture. A Decoder, on the contrary, can attend only to the previous elements in a sequence when calculating attention scores; the GPT model, for instance, is a decoder-only model.
Vision Transformer
As already mentioned above, we can use transformers for image classification tasks. The main difference between a Vision Transformer and an NLP transformer is that we must apply a special embedding operation to the images.
The image embedding begins with image preprocessing. The image is split into 2D patches, as shown in figure 4. The resulting number of patches N can be calculated as

N = (H x W) / P^2

where H is the image height, W is the image width, and P is the patch size.
After we've got the image patches, we flatten each patch into a 1D vector. The size M of a flattened patch vector can be calculated as

M = P^2 x C

where C is the number of color channels (C = 3 for RGB).
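To make this concrete with the numbers we will use later: CIFAR10 images have H = W = 32, and we will use a patch size of P = 4, so N = (32 x 32) / 4^2 = 64 patches, each flattened into a vector of M = 4 x 4 x 3 = 48 values. These are exactly the 64 x 48 matrices that will show up in the code below.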
After the patch transformation is done, we have an image represented as a matrix of size N x M. This matrix is the input tensor that is fed to the model. The input tensor goes through a linear projection, is concatenated with a learnable [class] token parameter, and is summed with a learnable position embedding. Position embeddings are discussed in Appendix D.4 of the original paper [3].
After processing the input tensor with embeddings, there is a standard set of encoders with a classification head at the end.
Let's move on to the code 🎸.
Installation
As mentioned above, the vision_transformer repository contains the full working project, but installing the dependencies may be tricky.
CPU or MPS
If you plan to train the network on a CPU or you're using an Apple computer, you can follow the standard installation flow, i.e., install the packages from the requirements.txt file:
pip install -r requirements.txt
CUDA
If you're going to train the network on a GPU, you should install PyTorch with the command from the official website and then install the remaining packages from requirements-cuda.txt. The installation may look like this:
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
pip install -r requirements-cuda.txt
Dataset
We use the CIFAR10 dataset, which consists of 60,000 images across 10 classes (50k for training and 10k for validation/testing). Each image is 32 x 32 with 3 RGB color channels.
Full code that manages datasets can be found at src/dataset.py.
The first thing we should do is import the required objects and define constants:
from pathlib import Path

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import transforms, AutoAugment, AutoAugmentPolicy

BASE_DIR = Path(__file__).parent.parent

# Module-level constants used by the transforms below
# (CIFAR10 images are 32 x 32; rotations are sampled from -30 to 30 degrees)
im_size = 32
rotation_degrees = 30
PyTorch already provides the CIFAR10 dataset in its companion package torchvision, so we use it directly from torchvision.datasets.
Lightning Data Module
We wrap the CIFAR10 dataset in a pytorch_lightning.LightningDataModule. A data module simplifies the usage of datasets and, especially, data loaders during the training phase.
Let's look at the full code of the data module first and then go through it step by step.
class CIFAR10DataModule(pl.LightningDataModule):

    def __init__(self, batch_size: int, patch_size: int = 4, val_batch_size: int = 16):
        super().__init__()
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size
        self.train_transform = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),
                transforms.RandomResizedCrop(size=(im_size, im_size)),
                transforms.RandomRotation(degrees=rotation_degrees),
                AutoAugment(AutoAugmentPolicy.CIFAR10),
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
                PatchifyTransform(patch_size)
            ]
        )
        self.val_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
                PatchifyTransform(patch_size)
            ]
        )
        self.patch_size = patch_size
        self.ds_train = None
        self.ds_val = None

    def prepare_data(self) -> None:
        CIFAR10(BASE_DIR.joinpath('data/cifar'), train=True, transform=self.train_transform, download=True)
        CIFAR10(BASE_DIR.joinpath('data/cifar'), train=False, transform=self.val_transform, download=True)

    def setup(self, stage: str) -> None:
        self.ds_train = CIFAR10(BASE_DIR.joinpath('data/cifar'), train=True, transform=self.train_transform)
        self.ds_val = CIFAR10(BASE_DIR.joinpath('data/cifar'), train=False, transform=self.val_transform)

    def train_dataloader(self):
        # Due to the small dataset we don't need to use multiprocessing
        return DataLoader(self.ds_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.ds_val, batch_size=self.val_batch_size)

    @property
    def classes(self):
        return 10  # CIFAR10 has 10 possible classes
In the code above, we create a class CIFAR10DataModule that inherits from LightningDataModule. The main intent of the data module is to create data loaders, not datasets, which is why we pass the batch sizes to the constructor.
LightningDataModule has several methods to override (a short usage sketch follows the list):

- prepare_data is called within a single CPU process, meaning that your data will not be corrupted by concurrent downloads. It is called before training, so we use it to download the CIFAR10 data into a local directory;
- setup is called after prepare_data. Here we instantiate the train and validation datasets;
- train_dataloader returns the data loader for training;
- val_dataloader returns the data loader for validation;
- classes is just a property that returns the number of classes.
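Here is a minimal usage sketch (not part of the repository); in practice, pl.Trainer calls these hooks for you:

dm = CIFAR10DataModule(batch_size=512, patch_size=4)
dm.prepare_data()      # downloads CIFAR10 into <BASE_DIR>/data/cifar
dm.setup(stage="fit")  # instantiates the train and validation datasets

patches, labels = next(iter(dm.train_dataloader()))
print(patches.shape)   # torch.Size([512, 64, 48]) -> batch x patches x flattened patch
print(labels.shape)    # torch.Size([512])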
Transforms
A transform is an operation applied to an image before we pass it to the model. We define two sets of transforms: one for training and one for validation. The training transform looks like this:
self.train_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop(size=(im_size, im_size)),
        transforms.RandomRotation(degrees=rotation_degrees),
        AutoAugment(AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
        PatchifyTransform(patch_size)
    ]
)
CIFAR10 has only 50,000 training images, which is a relatively small dataset for a neural network, so we use data augmentation to expand it. The first four transforms do that:

- RandomHorizontalFlip flips the image horizontally with 50% probability (by default);
- RandomResizedCrop crops a random portion of the image and resizes it to the given size (in our case 32 x 32);
- RandomRotation rotates the image by a random angle from a range (in our case from -30 to 30 degrees);
- AutoAugment is a set of learned auto-augmentation policies, described in AutoAugment: Learning Augmentation Strategies from Data [6]. We use the policy learned for CIFAR10.
Then we convert the image to a tensor of size 3 x 32 x 32, normalize it with the CIFAR10 mean and std values, and split the image tensor into patches.
Patchify Image
Image patching is, in my opinion, the hardest part of the project. In order to keep the code well-structured, I moved image patching into its own transform class:
class PatchifyTransform:

    def __init__(self, patch_size):
        self.patch_size = patch_size

    def __call__(self, img: torch.Tensor):
        res = img.unfold(1, self.patch_size, self.patch_size)   # 3 x 8 x 32 x 4
        res = res.unfold(2, self.patch_size, self.patch_size)   # 3 x 8 x 8 x 4 x 4
        return res.reshape(-1, self.patch_size * self.patch_size * 3)  # -1 x 48 == 64 x 48
A PyTorch Tensor has a method called unfold(dimension, size, step) that does exactly what we need: it slides a window along the given dimension and unfolds each window into a new dimension. Let's take the __call__ method apart. Its first line:
res = img.unfold(1, self.patch_size, self.patch_size)
The tensor img has size 3 x 32 x 32, and self.patch_size equals 4. The unfold method slides over dimension 1 (which has 32 elements) with a window of size 4 and a step of 4, and puts each extracted window into a new dimension.
So now we have a new tensor res with size 3 x 8 x 32 x 4. For easier understanding, I follow this logic:

- Discard the color dimension for a moment. The tensor is 8 x 32 x 4;
- This tensor can be seen as an 8 x 32 matrix of four-element vectors, for example, [0.3, 0.01, 0.4, 0.7];
- Now add the color dimension back: there is an 8 x 32 matrix of four-element vectors for each RGB color channel.

On the second line, we unfold the second spatial dimension of the image:
res = res.unfold(2, self.patch_size, self.patch_size)
Now the size of res is 3 x 8 x 8 x 4 x 4, which may seem insane but isn't. You can read it as follows:

- Discard the color dimension for now, leaving an 8 x 8 x 4 x 4 tensor;
- Each element of the 8 x 8 matrix contains a 4 x 4 patch;
- Bring back the color dimension, and you have three 8 x 8 matrices of 4 x 4 patches.

What remains is to reshape the tensor res back into a 2D matrix. This is achieved with the reshape method:
res.reshape(-1, self.patch_size * self.patch_size * 3)
After the reshape operation is done, res is a matrix of size 64 x 48. If you have worked on NLP tasks, you may notice how similar this is to a sentence embedding: the first dimension corresponds to a word (here, a patch) and the second dimension corresponds to the context vector of that word.
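As a quick sanity check (a sketch, not part of the repository), you can run the transform on a random image tensor and verify the shape described above:

import torch

t = PatchifyTransform(patch_size=4)
img = torch.rand(3, 32, 32)  # a fake normalized CIFAR10 image

patches = t(img)
print(patches.shape)         # torch.Size([64, 48]) -> 64 patches, 48 values each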
Model
We build the model as a set of standard PyTorch modules, except for the main ViT module, which inherits from LightningModule. The model diagram is displayed in the figure below.
The full code of the model can be found at src/basic.py.
Letβs take a detailed look at each module.
ImageEmbedding
The Image Embedding module accepts a batch of patchified images and returns the full embedding of the patches with the [class] token prepended.
# Imports used by the model modules shown below
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from torch import nn
from torch.optim import AdamW, lr_scheduler


class ImageEmbedding(nn.Module):

    def __init__(self, size: int, hidden_size: int, num_patches: int, dropout: float = 0.2):
        super().__init__()
        self.projection = nn.Linear(size, hidden_size)
        self.class_token = nn.Parameter(torch.rand(1, hidden_size))
        self.position = nn.Parameter(torch.rand(1, num_patches + 1, hidden_size))
        self.dropout = nn.Dropout(dropout)

    def forward(self, inp: torch.Tensor):
        res = self.projection(inp)

        class_token = self.class_token.repeat(res.size(0), 1, 1)  # batch_size x 1 x hidden_size
        res = torch.concat([class_token, res], dim=1)

        position = self.position.repeat(res.size(0), 1, 1)
        return self.dropout(res + position)
An interesting thing to look at is how we create the class_token and position embedding parameters:
self.class_token = nn.Parameter(torch.rand(1, hidden_size))
self.position = nn.Parameter(torch.rand(1, num_patches + 1, hidden_size))
Tensors created via nn.Parameter are registered as model parameters, added to the computational graph, and trained during the fitting process.
class_token is a tensor of size 1 x hidden_size, which we later repeat for each element in the batch. The position tensor has size 1 x (num_patches + 1) x hidden_size and is also repeated for each element in the batch. Its sequence dimension is num_patches + 1 because the [class] token needs a position embedding as well.
Let's take a closer look at the forward method:
def forward(self, inp: torch.Tensor):
    res = self.projection(inp)
First of all, we accept the inp tensor, which has size batch_size x 64 x 48, and pass it through a linear projection layer.
The context size of 48 in the input tensor is too small; the model would barely be able to capture dependencies between the input and the target. So we expand it to hidden_size. The projection res has size batch_size x 64 x hidden_size.
In the next operation, we repeat the class_token parameter for each element in the batch and concatenate it with res:
class_token = self.class_token.repeat(res.size(0), 1, 1)
res = torch.concat([class_token, res], dim=1)
The size of the class_token tensor is batch_size x 1 x hidden_size.
position = self.position.repeat(res.size(0), 1, 1)
return self.dropout(res + position)
The operations above repeat the position tensor for each image in the batch, sum the res and position tensors, and pass the result through the dropout layer.
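A quick shape check (again a sketch, not part of the repository) ties the embedding back to the patchified CIFAR10 batch:

emb = ImageEmbedding(size=48, hidden_size=512, num_patches=64)

patches = torch.rand(8, 64, 48)  # a small fake batch: 8 images, 64 patches, 48 values each
out = emb(patches)
print(out.shape)  # torch.Size([8, 65, 512]) -> [class] token + 64 patches, each of hidden_size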
AttentionHead
Once we've got the input embeddings, they are sent to each AttentionHead. The AttentionHead module looks like this:
class AttentionHead(nn.Module):

    def __init__(self, size: int):  # size is the hidden size
        super(AttentionHead, self).__init__()
        self.query = nn.Linear(size, size)
        self.key = nn.Linear(size, size)
        self.value = nn.Linear(size, size)

    def forward(self, input_tensor: torch.Tensor):
        q, k, v = self.query(input_tensor), self.key(input_tensor), self.value(input_tensor)

        scale = q.size(1) ** 0.5
        scores = torch.bmm(q, k.transpose(1, 2)) / scale
        scores = F.softmax(scores, dim=-1)

        # (batch x seq x seq) @ (batch x seq x hidden) -> batch x seq x hidden
        output = torch.bmm(scores, v)
        return output
The size argument here is our hidden size.
Let's set example values for the parameters, which we'll refer to later:
batch_size = 64
sequence_size = 64
hidden_size = 512
num_heads = 8
Now let's take a look at the code of the forward method:
def forward(self, input_tensor):
    q, k, v = self.query(input_tensor), self.key(input_tensor), self.value(input_tensor)
First, we create the query, key, and value projections of input_tensor, which has size batch_size x sequence_size x hidden_size, or 64 x 64 x 512.
In the next set of operations, we calculate the attention scores using the famous scaled dot-product attention formula.
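For reference, the formula from Attention Is All You Need [2] is

Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V

where d_k in the paper is the dimension of the keys. In the code below, the scale is taken from q.size(1):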
scale = q.size(1) ** 0.5
scores = torch.bmm(q, k.transpose(1, 2)) / scale
scores = F.softmax(scores, dim=-1)
The size of scores is sequence_size x sequence_size (64 x 64) for each element of the batch, meaning that each element of the sequence has an attention score with every other element in the sequence.
Note that we do not apply any masking. After the scores are calculated, we multiply them with the value tensor.
output = torch.bmm(scores, v)
The size of output is the same as the size of input_tensor: batch_size x sequence_size x hidden_size, or 64 x 64 x 512.
MultiHeadAttention
The intent of the MultiHeadAttention module is to unite the attention heads.
class MultiHeadAttention(nn.Module):

    def __init__(self, size: int, num_heads: int):
        super().__init__()
        self.heads = nn.ModuleList([AttentionHead(size) for _ in range(num_heads)])
        self.linear = nn.Linear(size * num_heads, size)

    def forward(self, input_tensor: torch.Tensor):
        s = [head(input_tensor) for head in self.heads]
        s = torch.cat(s, dim=-1)

        output = self.linear(s)
        return output
The size argument in this module is the hidden size.
We calculate the output of each attention head and concatenate the outputs along the last dimension (dimension 2).
def forward(self, input_tensor: torch.Tensor):
    s = [head(input_tensor) for head in self.heads]
    s = torch.cat(s, dim=-1)
The resulting size of s is batch_size x sequence_size x (num_heads * hidden_size). With the example values above, that is 64 x 64 x 4096.
output = self.linear(s)
return output
Then we pass s through the linear layer. The size of output is batch_size x sequence_size x hidden_size, or 64 x 64 x 512, the same as the input size.
Encoder
The Encoder module contains multi-head attention and a feed-forward network. It also applies layer normalization to the data.
class Encoder(nn.Module):

    def __init__(self, size: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        self.attention = MultiHeadAttention(size, num_heads)
        self.feed_forward = nn.Sequential(
            nn.Linear(size, 4 * size),
            nn.Dropout(dropout),
            nn.GELU(),
            nn.Linear(4 * size, size),
            nn.Dropout(dropout)
        )
        self.norm_attention = nn.LayerNorm(size)
        self.norm_feed_forward = nn.LayerNorm(size)

    def forward(self, input_tensor):
        attn = input_tensor + self.attention(self.norm_attention(input_tensor))
        output = attn + self.feed_forward(self.norm_feed_forward(attn))
        return output
The feed-forward network is created as a Sequential module:
self.feed_forward = nn.Sequential(
    nn.Linear(size, 4 * size),
    nn.Dropout(dropout),
    nn.GELU(),
    nn.Linear(4 * size, size),
    nn.Dropout(dropout)
)
We make the hidden layer of the feed-forward network four times wider than the model's hidden size to make it more expressive and able to capture more complex dependencies between the input and the target. It helps mitigate the vanishing gradients problem as well.
We use the Gaussian Error Linear Units (GELU) activation function.
The forward method is short and expressive:
def forward(self, input_tensor):
    attn = input_tensor + self.attention(self.norm_attention(input_tensor))
    output = attn + self.feed_forward(self.norm_feed_forward(attn))
    return output
Note that we apply the normalization layer before the attention and feed-forward sub-layers, in contrast to Attention Is All You Need [2]. This arrangement is called pre-normalization. The paper Understanding the Difficulty of Training Transformers [4] analyzes both approaches.
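For comparison, a post-normalization (Post-LN) forward pass, roughly what the original transformer does, would move the LayerNorm calls after the residual additions. Here is a hypothetical variant that is not used in this project:

# Hypothetical Post-LN variant of Encoder.forward (not used here)
def forward(self, input_tensor):
    attn = self.norm_attention(input_tensor + self.attention(input_tensor))
    output = self.norm_feed_forward(attn + self.feed_forward(attn))
    return output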
ViT
Finally, we are ready for the main module of the Vision Transformer, ViT. This class is too big to show all at once, so I split it into several parts and cover them individually.
ViT not only ties together all parts of the model but also provides the training functionality. Note that ViT inherits from pl.LightningModule, not nn.Module.
class ViT(pl.LightningModule):

    def __init__(self, size: int, hidden_size: int, num_patches: int, num_classes: int, num_heads: int,
                 num_encoders: int, emb_dropout: float = 0.1, dropout: float = 0.1,
                 lr: float = 1e-4, min_lr: float = 4e-5,
                 weight_decay: float = 0.1, epochs: int = 200):
        super().__init__()
        self.save_hyperparameters()

        self.lr = lr
        self.min_lr = min_lr
        self.weight_decay = weight_decay
        self.epochs = epochs

        self.embedding = ImageEmbedding(size, hidden_size, num_patches, dropout=emb_dropout)
        self.encoders = nn.Sequential(
            *[Encoder(hidden_size, num_heads, dropout=dropout) for _ in range(num_encoders)],
        )
        self.mlp_head = nn.Linear(hidden_size, num_classes)
We create the model modules in the constructor:
- Input Embedding layer;
- Set of encoders;
- MLP head that does the final classification.
The forward pass is implemented in the forward method:
def forward(self, input_tensor: torch.Tensor):
    emb = self.embedding(input_tensor)
    attn = self.encoders(emb)
    return self.mlp_head(attn[:, 0, :])
First, we convert the input tensor into the model's inner representation, with the [class] token and positional encoding added.
The emb tensor has size batch_size x (sequence_size + 1) x hidden_size: with the values defined above, that is 64 x 65 x 512 (64 patches plus the prepended [class] token). Then we pass emb through the sequential set of encoders.
The attn tensor has the same size as emb: 64 x 65 x 512. The first element of the sequence in attn corresponds to the [class] token, so we pass only this element to the mlp_head and return its output.
attn[:, 0, :] has size batch_size x hidden_size, or 64 x 512. The function outputs a tensor of logits of size batch_size x num_classes, or 64 x 10.
Now we can use the output of the model to organize the training process.
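Before moving on to training, a small smoke test (again a sketch, not part of the repository) confirms the end-to-end shapes with the hyperparameters we will use later in train.py:

model = ViT(
    size=48, hidden_size=512, num_patches=64, num_classes=10,
    num_heads=8, num_encoders=6,
)

fake_batch = torch.rand(4, 64, 48)  # 4 patchified CIFAR10 images
logits = model(fake_batch)
print(logits.shape)                 # torch.Size([4, 10]) -> one logit per class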
Training
The training process starts in the ViT class. There are a few more methods that participate in the training:

- configure_optimizers
- configure_parameters
- training_step
- validation_step

Let's run through each of them.
def configure_optimizers(self):
    optimizer = AdamW(self.configure_parameters(), lr=self.lr)
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, self.epochs, eta_min=self.min_lr)
    return {"optimizer": optimizer, "lr_scheduler": scheduler}
We create our optimizers and schedulers in the configure_optimizers method; you can read more in the Lightning Optimization docs. We use AdamW with the CosineAnnealingLR scheduler. There is also a nice article that visualizes various learning rate schedulers: A Visual Guide to Learning Rate Schedulers in PyTorch [7].
Note that we don't simply pass self.parameters() to AdamW. We configure parameter groups in our custom method configure_parameters:
def configure_parameters(self):
    no_decay_modules = (nn.LayerNorm,)
    decay_modules = (nn.Linear,)

    decay = set()
    no_decay = set()
    for module_name, module in self.named_modules():
        if module is self:
            continue

        for param_name, value in module.named_parameters():
            full_name = f"{module_name}.{param_name}" if module_name else param_name

            if param_name.endswith('bias'):
                no_decay.add(full_name)
            elif param_name.endswith('weight') and isinstance(module, no_decay_modules):
                no_decay.add(full_name)
            elif param_name.endswith('weight') and isinstance(module, decay_modules):
                decay.add(full_name)

    optim_groups = [
        {"params": [v for name, v in self.named_parameters() if name in decay],
         "weight_decay": self.weight_decay},
        {"params": [v for name, v in self.named_parameters() if name in no_decay],
         "weight_decay": 0}
    ]
    return optim_groups
The code of the method above is adapted from Andrej Karpathy's minGPT. The LayerNorm module performs its own normalization, so we disable weight decay for it. configure_parameters prepares two groups of parameters: one with weight decay enabled and one with weight decay disabled.
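Reusing the model instance from the smoke test above, you can quickly inspect how the parameters are grouped (again just a sketch):

groups = model.configure_parameters()
for i, group in enumerate(groups):
    n_params = sum(p.numel() for p in group["params"])
    print(f"group {i}: weight_decay={group['weight_decay']}, "
          f"{len(group['params'])} tensors, {n_params} parameters")

Next, the training step: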
def training_step(self, batch, batch_idx):
    input_batch, target = batch
    logits = self(input_batch)
    loss = F.cross_entropy(logits, target)

    if batch_idx % 5 == 0:
        self.log("train_acc", logit_accuracy(logits, target), prog_bar=True)
    self.log("train_loss", loss)
    return loss
training_step should return the loss of that particular training step. We use the cross_entropy loss function. The logits tensor has size batch_size x num_classes, while the target tensor has size (batch_size,); these are exactly the arguments that the cross_entropy function expects from us.
Also, every few batches, we log the loss and model accuracy to tensorboard.
The code of validation_step is essentially the same as that of training_step:
def validation_step(self, batch, batch_idx):
    input_batch, target = batch
    output = self(input_batch)
    loss = F.cross_entropy(output, target)

    self.log("val_loss", loss, prog_bar=True)
    self.log("val_accuracy", logit_accuracy(output, target), prog_bar=True)
    return loss
A function that calculates accuracy looks like this:
def logit_accuracy(logits: torch.Tensor, target: torch.Tensor) -> float:
    idx = logits.max(1).indices
    acc = (idx == target).int()
    return acc.sum() / torch.numel(acc)
The logit_accuracy function takes two tensors as arguments:

- logits is the output of the model. It has a size of batch_size x 10 (10 because we have 10 possible classes);
- target is the target tensor for the batch. It has a size of (batch_size,).

On the first line of logit_accuracy, we take the class index with the maximum logit value, i.e., the index of the highest value along dimension 1:
idx = logits.max(1).indices
The idx tensor now has the same size as target: (batch_size,).
Then we create a tensor acc whose elements are 0 or 1: 1 means the model output matches the target for that element of the batch, and 0 means the prediction is incorrect.
In the return statement, we calculate the ratio of correctly predicted classes to the total number of elements.
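Here is a tiny worked example (not from the repository) on a batch of three samples with three classes:

logits = torch.tensor([[0.1, 2.0, 0.3],   # predicted class 1
                       [1.5, 0.2, 0.1],   # predicted class 0
                       [0.1, 0.2, 3.0]])  # predicted class 2
target = torch.tensor([1, 2, 2])

print(logit_accuracy(logits, target))  # tensor(0.6667): 2 out of 3 predictions are correct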
Lightning Trainer
The script that runs model training is located in train.py.
First of all, we import all the required objects there:
import torch
import pytorch_lightning as pl
from pathlib import Path
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from src.dataset import CIFAR10DataModule
from src.models.basic import ViT
Then we set constants and hyperparameters of the model:
BASE_DIR = Path(__file__).parent
LIGHTNING_DIR = BASE_DIR.joinpath("data/lightning")
MODELS_DIR = LIGHTNING_DIR.joinpath("models")
LOG_EVERY_N_STEPS = 50
MAX_EPOCHS = 200
BATCH_SIZE = 512
VAL_BATCH_SIZE = 512
PATCH_SIZE = 4
SIZE = PATCH_SIZE * PATCH_SIZE * 3
HIDDEN_SIZE = 512
NUM_PATCHES = int(32 * 32 / PATCH_SIZE ** 2) # 32 x 32 is the size of image in CIFAR10
NUM_HEADS = 8
NUM_ENCODERS = 6
DROPOUT = 0.1
EMB_DROPOUT = 0.1
LEARNING_RATE = 1e-4
MIN_LEARNING_RATE = 2.5e-5
WEIGHT_DECAY = 1e-6
where

- BASE_DIR is the base directory of the project;
- LIGHTNING_DIR is the directory where Lightning stores models and logs;
- MODELS_DIR is the directory where Lightning stores models;
- LOG_EVERY_N_STEPS defines how often to log statistics to tensorboard;
- MAX_EPOCHS is the maximum number of epochs to run;
- SIZE is the size of the context vector of each element in the input sequence;
- HIDDEN_SIZE is the size of the context vector after the embedding is applied;
- NUM_PATCHES is the total number of patches per image;
- NUM_HEADS is the number of attention heads;
- NUM_ENCODERS is the number of sequential encoders;
- DROPOUT is the dropout probability applied in the encoders;
- EMB_DROPOUT is the dropout probability applied in the embedding.
Since we're training the model on a GPU device, we can speed up the training process by using lower-precision float32 matrix multiplications. We enable it with the command:
torch.set_float32_matmul_precision('medium')
In the if __name__ == '__main__' section, we create the data module and instantiate the model:
data = CIFAR10DataModule(batch_size=BATCH_SIZE, patch_size=PATCH_SIZE)
model = ViT(
    size=SIZE,
    hidden_size=HIDDEN_SIZE,
    num_patches=NUM_PATCHES,
    num_classes=data.classes,
    num_heads=NUM_HEADS,
    num_encoders=NUM_ENCODERS,
    emb_dropout=EMB_DROPOUT,
    dropout=DROPOUT,
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    epochs=MAX_EPOCHS
)
Then we create several useful callbacks:
checkpoint_callback = ModelCheckpoint(
    dirpath=MODELS_DIR,
    monitor="val_loss",
    save_last=True,
    verbose=True
)
es = EarlyStopping(monitor="val_loss", mode="min", patience=10)
lr_monitor = LearningRateMonitor(logging_interval='epoch')
- The ModelCheckpoint callback saves the model after each epoch into the MODELS_DIR directory, so if the training process breaks, we won't lose the progress;
- EarlyStopping monitors the validation loss and stops training if it hasn't improved in the last 10 epochs;
- LearningRateMonitor adds a visualization of the learning rate to tensorboard.

Finally, we create the trainer:
trainer = pl.Trainer(
    accelerator="cuda",
    precision="bf16",
    default_root_dir=LIGHTNING_DIR,
    log_every_n_steps=LOG_EVERY_N_STEPS,
    max_epochs=MAX_EPOCHS,
    callbacks=[checkpoint_callback, es, lr_monitor],
    resume_from_checkpoint=MODELS_DIR.joinpath("last.ckpt")
)

trainer.fit(model, data)
We create an instance of pytorch_lightning.Trainer and run it with model and data. For GPU training, we set cuda as the accelerator. We also set the precision to bf16; you can read more about PyTorch Lightning precision management in N-bit Precision.
One of the biggest advantages of pytorch_lightning (in my opinion) is that you don't need to call .to(device) on every tensor you have; you just pass the accelerator to the trainer.
We run the model training with the command:
python train.py
You will see the training progress in the terminal. Note that if you don't have a last.ckpt checkpoint saved, you should remove the resume_from_checkpoint argument from the Trainer creation.
Run this command to visualize training in Tensorboard:
tensorboard --logdir=data/lightning/lightning_logs
CPU Training
If there is only a CPU or MPS device available, you can train a smaller model.
Set the following hyperparameters in train.py:
BATCH_SIZE = 256
VAL_BATCH_SIZE = 256
PATCH_SIZE = 4
SIZE = PATCH_SIZE * PATCH_SIZE * 3 # 4 * 4 * 3 (RGB colors)
HIDDEN_SIZE = 512
NUM_PATCHES = int(32 * 32 / PATCH_SIZE ** 2) # 32 x 32 is the size of image in CIFAR10
NUM_HEADS = 8
NUM_ENCODERS = 4
DROPOUT = 0.1
EMB_DROPOUT = 0.16
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-6
update the accelerator value in the Trainer:
trainer = pl.Trainer(
    accelerator="cpu",
    default_root_dir=LIGHTNING_DIR,
    log_every_n_steps=LOG_EVERY_N_STEPS,
    max_epochs=MAX_EPOCHS,
    callbacks=[checkpoint_callback, es, lr_monitor],
    resume_from_checkpoint=MODELS_DIR.joinpath("last.ckpt")
)

trainer.fit(model, data)
and remove the AutoAugment transform in src/dataset.py, leaving:
self.train_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop(size=(im_size, im_size)),
        transforms.RandomRotation(degrees=rotation_degrees),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
        PatchifyTransform(patch_size)
    ]
)
The model with the above parameters can be trained to ~80% accuracy in around 6-8 hours on an MPS device.
Training Results
I trained the model on an RTX 4090; 200 epochs of training took almost 3 hours. At the end of the 200th epoch, the model was ~83% accurate on the validation data.
Below you may see charts of the training progress.
The model has neither plateaued nor started overfitting, so we could continue training and get even higher accuracy.
Evaluate Model
Now it's time to evaluate the model by running classify.py. Follow the code on GitHub to see the full implementation of the script.
First of all, we create the CIFAR10 validation dataset and load the trained model:
model = ViT.load_from_checkpoint(MODELS_DIR.joinpath('last.ckpt'))
model.eval()
Then we can use the model to classify the images.
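The essence of the script boils down to something like the following sketch (simplified; it assumes model, PatchifyTransform, and BASE_DIR from above are in scope, and the actual classify.py may differ):

import torch
from torchvision.datasets import CIFAR10
from torchvision.transforms import transforms

# Same preprocessing as the validation transform, plus patchification
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
    PatchifyTransform(4),
])

ds = CIFAR10(BASE_DIR.joinpath('data/cifar'), train=False, transform=transform)
patches, target = ds[0]  # a single validation image

with torch.no_grad():
    logits = model(patches.unsqueeze(0).to(model.device))  # add a batch dimension: 1 x 64 x 48
    predicted = logits.argmax(dim=-1).item()

print(f"Predicted class: {predicted} - {ds.classes[predicted]}")
print(f"Target class: {target} - {ds.classes[target]}")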
If you run the script, you should see a nice 32 x 32 frog.
And the output in the terminal:
Predicted class: 6 - frog
Target class: 6 - frog
Summary
As you can see, there is nothing overly complicated about the Vision Transformer: it uses the same self-attention mechanism as any other transformer model. Although production models should be pre-trained on a huge amount of data, we showed that a Vision Transformer can be trained even on the CIFAR10 dataset. A powerful compute device is still required, though: 200 epochs of training of the above model took almost 3 hours on an RTX 4090.
We also fulfilled the secondary goal of the tutorial and showed how well the code can be organized when we combine pytorch and pytorch_lightning. This approach can be used to prepare virtually any production-ready model.
References
[1] Mikhail Kravets, vision_transformer (2023). GitHub repository containing the full code of the tutorial;
[2] Ashish Vaswani, Llion Jones, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Aidan N. Gomez, Łukasz Kaiser, Illia Polosukhin, Attention Is All You Need (2017);
[3] Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby, An Image Is Worth 16×16 Words: Transformers For Image Recognition At Scale (2021);
[4] Liyuan Liu, Xiaodong Liu, Jianfeng Gao, Weizhu Chen, Jiawei Han, Understanding the Difficulty of Training Transformers (2020);
[5] Jay Alammar, The Illustrated Transformer (2018);
[6] Ekin D. Cubuk, Barret Zoph, Dandelion Mane, Vijay Vasudevan, Quoc V. Le, AutoAugment: Learning Augmentation Strategies from Data (2019);
[7] Leonie Monigatti, A Visual Guide to Learning Rate Schedulers in PyTorch (2022).
Published via Towards AI