
From Pixels to Predictions: Building a Transformer for Images
Author(s): Vicki Y Mu
Originally published on Towards AI.
Convolutional neural networks have been the driving force behind almost every major breakthrough in computer vision — but what if they’ve been holding us back all along?
In 2020, a team of researchers at Google asked the bold question: Can we throw out convolutions entirely and still build world-class image models? Their answer — the Vision Transformer (ViT) — sparked a new era in deep learning.

I’m an undergraduate student at MIT with an interest in computer vision and generative models, and I recently implemented a Vision Transformer from scratch to better understand its architecture. This post is a distilled guide to that process, blending theory, visuals, and code.
We’ll walk through how ViTs turn images into sequences, how attention works in this context, and how it compares to the CNNs you’re used to. By the end, you’ll have a working ViT in PyTorch and a much deeper understanding of how it all fits together.
1. Background and Intuition
1.1 From Recurrent Models to the Rise of Transformers (in NLP)
Before 2017, NLP was ruled by RNNs and LSTMs: models that powered everything from machine translation to language modeling. But despite their early success, they had fundamental limitations. Because they processed sequences one token at a time, training couldn’t be parallelized. And as sequences got longer, they struggled to retain information from earlier tokens. These bottlenecks made it difficult to scale up, especially for tasks that required a deep, global understanding of language.
In 2017, researchers at Google made a breakthrough in their paper Attention Is All You Need. It proposed a new architecture — the Transformer — built around a simple but powerful idea: self-attention. Instead of processing tokens one at a time, self-attention allows each token to directly consider every other token in the sequence.
Put simply, each word learns to ask questions (queries), decide what to listen to (keys), and gather relevant information (values).
This mechanism eliminated the need for recurrence and fixed step order, sidestepping the main weaknesses of RNNs.

Within just two years, the Transformer architecture completely took over NLP. It proved more efficient, easier to scale, and better at modeling long-range dependencies than any of its predecessors. Transformers quickly became the backbone of major breakthrough models: BERT (for bidirectional context), GPT (for generative, causal language modeling), and T5 (for sequence-to-sequence tasks).
RNNs were replaced by attention in NLP — but what about computer vision? At the time, CNNs dominated the field, but they came with their own set of limitations. Convolutions are inherently local, making it difficult for CNNs to capture long-range dependencies. They also rely heavily on spatial priors and careful feature engineering.
So the natural next question emerged: if attention could replace recurrence…could it replace convolution too?
1.2 Can Attention Replace Convolution? The Shift to Vision
In 2020, Dosovitskiy et al. introduced the Vision Transformer (ViT) in their paper An Image Is Worth 16×16 Words. They proposed a bold idea: what if we treated an image like a sentence?
Instead of relying on convolutional filters, they divided images into patches and fed them into a standard transformer. While early ViTs needed massive datasets to compete with CNNs, the approach proved that attention-based models could work for vision — not just language.
Since its release, the Vision Transformer has sparked a wave of improvements:
- DeiT introduced smarter training strategies to reduce ViT’s reliance on huge datasets
- Swin Transformer added hierarchical structure to better handle local spatial patterns
- DINO and DINOv2 showed that ViTs could learn rich visual representations without any labels at all — unlocking powerful self-supervised features for downstream tasks.
What began as a bold experiment has now become a core building block in modern computer vision.
2. How Vision Transformers Work
2.1 Patch Embedding
Transformers were originally designed to process sequences, like sentences made out of word tokens. But images are 2D grids of pixels, not 1D sequences. So how do we feed an image into a transformer?
The Vision Transformer solves this by dividing the image into non-overlapping square patches (e.g. 16×16 pixels). Each patch is then flattened into a 1D vector and linearly projected into a fixed-size embedding — just like token embeddings in NLP.
For example:
- A 224×224 image with 16×16 patches produces (224/16)² = 196 patches.
- Each patch has shape 3×16×16 (RGB).
- Each flattened patch is therefore a 3·16·16 = 768-dim vector, which is then linearly projected to the embedding dimension (768 for ViT-Base).
Instead of a sentence of words, ViT sees a sequence of image patch embeddings.

Analogy: Just like a tokenizer turns a sentence into a sequence of word embeddings, the ViT turns an image into a sequence of patch embeddings.
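As a quick sanity check on those numbers, here is a minimal, throwaway sketch (not part of the model we build later) that slices a dummy 224×224 image into 16×16 patches with plain tensor reshaping:

import torch

# Dummy RGB image batch: [B, C, H, W]
x = torch.randn(1, 3, 224, 224)
p = 16  # patch size

# Carve H and W into a 14x14 grid of 16x16 patches, then flatten each patch
patches = x.unfold(2, p, p).unfold(3, p, p)   # [1, 3, 14, 14, 16, 16]
patches = patches.permute(0, 2, 3, 1, 4, 5)   # [1, 14, 14, 3, 16, 16]
patches = patches.reshape(1, -1, 3 * p * p)   # [1, 196, 768]

print(patches.shape)  # torch.Size([1, 196, 768]): 196 patches, each a 768-dim vector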
2.2 Class Token and Positional Embeddings
Transformers need two extra ingredients to work properly with image sequences:
- a [CLS] token to aggregate global information, and
- positional embeddings to encode spatial structure.
In ViT, a special learnable token is prepended to the input sequence. During self-attention, this token attends to every patch — and becomes the representation used for final classification.
Transformers are permutation-invariant — they don’t inherently understand order. To give the model spatial awareness, we add a unique positional embedding to each token in the sequence.
Both the [CLS] token and positional embeddings are learned parameters, updated during training.
2.3 Multi-Head Self-Attention (MHSA)
At the heart of the Vision Transformer is the multi-head self-attention mechanism — the part that allows the model to understand how image patches relate to each other, regardless of spatial distance.
Instead of using one attention function, MHSA splits the input into multiple “heads”. Each head learns to focus on different aspects of the input — some might focus on edges, others on texture, others on spatial layout. Their outputs are then concatenated and projected back into the original embedding space.
How it works, step by step:
- The input sequence of tokens (shape [B, N, D]) is linearly projected into queries Q, keys K, and values V.
- Each attention head computes scaled dot-product attention: Attention(Q, K, V) = softmax(QKᵀ / √d_k) · V, where d_k is the dimension of each head.
- Multiple heads run in parallel, and their outputs are concatenated and linearly projected back.
Why “multi-head”?
Each head attends to different parts of the sequence. This allows the model to understand complex relationships in parallel — not just spatial proximity, but also semantic structure.

2.4 Transformer Encoder
Once we have self-attention, we wrap it inside a larger unit: the Transformer block. This block is the fundamental building unit of ViTs (and NLP Transformers too). It combines:
- LayerNorm → Multi-Head Attention → Residual Connection
- LayerNorm → MLP (Feedforward Network) → Residual Connection
Each block enables the model to attend globally and transform features across layers while maintaining stability with normalization and residuals.
What’s inside a ViT Transformer block:
- LayerNorm before attention (called pre-norm).
- Multi-head self-attention applied to the normalized input.
- A residual connection adds the attention output back.
- Another LayerNorm, followed by a small MLP.
- Another residual connection adds the MLP output.

This structure repeats across all transformer layers (e.g., 12 layers in ViT-Base).
2.5 Classification Head
After processing the input through multiple Transformer blocks, the model needs a way to produce a final prediction. In Vision Transformers, this is handled by the classification head.
During the embedding step, we added a special [CLS] token at the beginning of the sequence. Just like in BERT, this token is intended to aggregate information from all the image patches through self-attention. After passing through all Transformer layers, the final embedding of the [CLS] token is used as a summary representation of the entire image.
This vector is then passed through a simple linear layer to output class logits.
3. Implementation Walkthrough
All core modules — patch embedding, MHSA, encoder blocks — are implemented from scratch. No timm shortcuts.
3.1 Patch Embedding
To convert image patches into a sequence of embeddings, we use a clever trick. Instead of writing a manual for-loop to extract and flatten patches, we can use a Conv2d layer with kernel_size = patch_size and stride = patch_size.
This extracts non-overlapping patches and applies a learned linear projection — all in a single operation. It’s clean, efficient, and easy to backpropagate through.
import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size)

    def forward(self, x):
        # x shape: [B, 3, 224, 224]
        x = self.proj(x)        # [B, embed_dim, H/patch, W/patch]
        x = x.flatten(2)        # [B, embed_dim, num_patches]
        x = x.transpose(1, 2)   # [B, num_patches, embed_dim]
        return x
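A quick shape check for this module (a throwaway snippet, assuming the imports above):

patch_embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
dummy = torch.randn(2, 3, 224, 224)   # batch of 2 RGB images
out = patch_embed(dummy)
print(out.shape)                      # torch.Size([2, 196, 768])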
3.2 Class Token and Positional Embeddings
Here we define a ViTEmbed module that:
- Prepends a learnable [CLS] token to the sequence
- Adds a learnable positional embedding to each token (including [CLS])
This produces a sequence shaped [B, num_patches + 1, embed_dim] — ready for the transformer encoder.
class ViTEmbed(nn.Module):
    def __init__(self, num_patches, embed_dim):
        super().__init__()
        # Learnable [CLS] token (1 per model)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))  # [1, 1, D]
        # Learnable positional embeddings (1 per token, including CLS)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))  # [1, N+1, D]

    def forward(self, x):
        batch_size = x.shape[0]
        # Expand [CLS] token to match batch size
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # [B, 1, D]
        # Prepend CLS token to patch embeddings
        x = torch.cat((cls_tokens, x), dim=1)  # [B, N+1, D]
        # Add positional embeddings
        x = x + self.pos_embed  # [B, N+1, D]
        return x
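Continuing the shape check from 3.1 (reusing patch_embed and dummy from that snippet):

vit_embed = ViTEmbed(num_patches=196, embed_dim=768)
tokens = vit_embed(patch_embed(dummy))
print(tokens.shape)                   # torch.Size([2, 197, 768]): 196 patches + 1 CLS token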
3.3 Multi-Head Self-Attention
Let’s implement one of the most important parts of the Vision Transformer: multi-head self-attention.
Each input token is linearly projected into a query (Q), key (K), and value (V) vector. Attention is computed in parallel across multiple heads, then concatenated and projected back to the original embedding dimension.
class MyMultiheadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Learnable projections for Q, K, V
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        # Final output projection
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        B, T, C = x.shape  # [batch, seq_len, embed_dim]

        # Project input into Q, K, V
        Q = self.q_proj(x)
        K = self.k_proj(x)
        V = self.v_proj(x)

        # Reshape into heads: [B, num_heads, T, head_dim]
        def split_heads(tensor):
            return tensor.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        Q = split_heads(Q)
        K = split_heads(K)
        V = split_heads(V)

        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1))  # [B, heads, T, T]
        scores /= self.head_dim ** 0.5
        attn = torch.softmax(scores, dim=-1)

        # Apply attention to values
        out = torch.matmul(attn, V)  # [B, heads, T, head_dim]

        # Recombine heads
        out = out.transpose(1, 2).contiguous().view(B, T, C)

        # Final linear projection
        return self.out_proj(out)
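A quick sanity check that attention preserves the sequence shape:

attn = MyMultiheadAttention(embed_dim=768, num_heads=12)
x = torch.randn(2, 197, 768)          # [B, tokens, embed_dim]
print(attn(x).shape)                  # torch.Size([2, 197, 768])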
3.4 Transformer Encoder
We now wrap everything together in a Transformer block — a modular unit that stacks self-attention and MLP layers with residual connections. This design lets the model reason globally (through self-attention) and then transform those representations (through the MLP), all while preserving stability via skip connections.
In this implementation:
- We built our own MyMultiheadAttention class earlier to demystify how attention works under the hood; for convenience and efficiency, the block below uses PyTorch’s built-in nn.MultiheadAttention (a variant wired to our own class is sketched after the code).
- We apply LayerNorm before both the attention and MLP layers (a “pre-norm” design).
- The mlp_ratio controls the width of the MLP’s hidden layer (usually 3–4× wider than the embedding dimension).
Let’s build the full Transformer block:
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(embed_dim * mlp_ratio), embed_dim)
        )

    def forward(self, x):
        # Self-attention with residual connection (pre-norm: normalize, then attend)
        normed = self.norm1(x)
        x = x + self.attn(normed, normed, normed)[0]
        # MLP with residual connection
        x = x + self.mlp(self.norm2(x))
        return x
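If you prefer to keep everything hand-rolled, here is a minimal sketch of the same block wired to the MyMultiheadAttention class from section 3.3 instead of PyTorch’s built-in layer (note that our class takes a single input and returns a tensor, while nn.MultiheadAttention takes query/key/value and returns a tuple):

class TransformerBlockCustomAttn(nn.Module):
    """Same pre-norm block, but using MyMultiheadAttention from section 3.3."""
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MyMultiheadAttention(embed_dim, num_heads)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(embed_dim * mlp_ratio), embed_dim),
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))   # single input, returns a tensor
        x = x + self.mlp(self.norm2(x))
        return x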
3.5 Putting It All Together
Now that we’ve built all the key components of a Vision Transformer — patch embedding, positional encoding, multi-head self-attention, Transformer blocks, and the [CLS] token — it’s time to assemble everything into a full model.
In the code below:
- We use our PatchEmbed, ViTEmbed, and TransformerBlock classes from earlier.
- The [CLS] token is passed through all transformer layers and then normalized.
- We add a classification head: a single nn.Linear layer that maps the [CLS] token embedding to class logits.
- This architecture mirrors the original ViT-Base (12 layers, 12 heads, 768-dim embeddings), but it’s easy to scale.
class SimpleViT(nn.Module):
    def __init__(
        self, img_size=224, patch_size=16, in_chans=3,
        embed_dim=768, depth=12, num_heads=12, num_classes=1000
    ):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = (img_size // patch_size) ** 2
        self.vit_embed = ViTEmbed(num_patches, embed_dim)

        # Stack transformer blocks
        self.blocks = nn.Sequential(*[
            TransformerBlock(embed_dim, num_heads) for _ in range(depth)
        ])

        # Final normalization before classification
        self.norm = nn.LayerNorm(embed_dim)

        # Linear classification head (using the CLS token)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):                # [batch_size, channels, height, width]
        x = self.patch_embed(x)          # -> [B, N, D]
        x = self.vit_embed(x)            # add CLS token + pos embed
        x = self.blocks(x)               # transformer layers
        x = self.norm(x)                 # final LayerNorm over all tokens
        return self.head(x[:, 0])        # classify using the CLS token
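A quick end-to-end check of the default, ViT-Base-sized configuration (roughly 86M parameters, so it takes a moment to instantiate):

vit = SimpleViT()                     # defaults: 224x224 input, 16x16 patches, 12 layers
dummy = torch.randn(2, 3, 224, 224)
print(vit(dummy).shape)               # torch.Size([2, 1000])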
4. Training the ViT
4.1 Dataset: CIFAR-10
We trained our Vision Transformer (ViT) on CIFAR-10, a well-known benchmark dataset with 60,000 images across 10 classes (e.g., airplanes, cats, ships). Each image is only 32×32 pixels, which makes CIFAR-10:
- Lightweight and quick to train on
- Challenging enough to evaluate small models
- Easy to visualize, which helps interpret what the model is learning
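For reference, here is a minimal sketch of how the CIFAR-10 data can be loaded with torchvision (the exact transforms and normalization statistics in the accompanying notebook may differ):

import torch
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    # Commonly used CIFAR-10 channel means and stds
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])

train_set = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
test_set = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=80, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=80, shuffle=False)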
4.2 Model Setup: Adapting ViT for CIFAR-10
ViTs were originally designed for large-scale datasets like ImageNet, so we made several adjustments to make training feasible on CIFAR-10 with limited compute:
- Input size: set to 32×32 to match CIFAR-10’s native resolution
- Patch size: 4×4 → yields 64 tokens per image
- Embedding dimension: 192 (smaller than ImageNet-scale ViTs)
- Depth: 6 transformer blocks
- Heads: 3 attention heads (192 ÷ 3 = 64 dim per head)
- Positional embeddings: Recomputed for 64+1 tokens
- Batch size: 80 — balances speed and memory on Colab
# Refactored SimpleViT config for CIFAR-10
model = SimpleViT(
    img_size=32,       # CIFAR-10 images are 32x32
    patch_size=4,      # 4x4 patches → 64 tokens
    in_chans=3,
    embed_dim=192,     # Smaller embedding size
    depth=6,           # Fewer transformer blocks
    num_heads=3,       # Divides evenly into 192
    num_classes=10     # For CIFAR-10
).to(device)
4.3 Training Setup
The model was trained using:
- Optimizer: Adam (
lr = 3e-4
) - Loss: CrossEntropyLoss
- Hardware: Google Colab T4 GPU
Training was efficient — about 30 seconds per epoch, thanks to:
- Fewer transformer blocks and tokens
- Larger batch size (80)
- Explicit use of CUDA (
to(device)
)
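The sketch assumes model and device from the config cell in 4.2 and the train_loader/test_loader from 4.1; the notebook in the repo may differ in details such as logging:

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

for epoch in range(30):
    # Train for one epoch
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()

    # Evaluate test accuracy
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    print(f"epoch {epoch + 1}: test accuracy = {correct / total:.3f}")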
4.4 Results
We trained our Vision Transformer for 30 epochs, totaling ~15 minutes on a GPU. By the end of training, the model achieved approximately 60% accuracy on the CIFAR-10 test set — a solid baseline given the model’s simplicity and the relatively small dataset size.
Learning Progress
As shown in the training plots below:
- Training loss steadily decreased, indicating that the model was effectively minimizing prediction error on the training set.
- Test accuracy improved rapidly within the first 10 epochs, plateauing around 60% thereafter. This suggests the model learned quickly but struggled to generalize further without additional techniques like data augmentation or regularization.

Here are a few example outputs from the model. While it correctly identified many samples (like cats and frogs), it struggled with visually similar classes (e.g., misclassifying a ship as an airplane).

The bar chart below shows how well the model performed across all 10 classes. Notably:
- The model performed best on ship, automobile, and frog classes — likely due to more distinctive visual features.
- Performance lagged on cat and bird, which may be harder to distinguish due to higher intra-class variation and similar textures or shapes shared with other animals.

5. Limitations and Extensions
Despite their success, Vision Transformers (ViTs) come with trade-offs. Here’s a summary of what to keep in mind:
5.1 Limitations
- Data-Hungry by Design
ViTs lack the strong inductive biases of CNNs (like locality and translation invariance), which means they typically require large datasets to perform well.
→ This is why the original ViT was pretrained on massive private datasets.
- Quadratic Time Complexity
The self-attention mechanism scales with the square of the number of input tokens — making ViTs computationally expensive for high-resolution images. For an image split into N patches, attention scales as O(N²) (see the quick calculation below).
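To make the quadratic growth concrete, a quick back-of-the-envelope calculation (illustrative only):

def attention_entries(img_size, patch_size):
    """Pairwise attention scores per head, ignoring the CLS token."""
    n = (img_size // patch_size) ** 2
    return n * n

print(attention_entries(224, 16))   # 196 patches  ->  38,416 scores
print(attention_entries(448, 16))   # 784 patches  -> 614,656 scores (16x more)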
5.2 Extensions and Improvements
Researchers have developed several workarounds and improvements to address these issues:
- DeiT (Data-efficient Image Transformer)
A version of ViT trained without large private datasets, using knowledge distillation from a CNN teacher to improve performance on smaller datasets like ImageNet.
- Pretrained Backbones + Fine-Tuning
Instead of training ViTs from scratch, most modern pipelines use pretrained ViTs and then fine-tune them on downstream tasks with fewer samples.
- Swin Transformer
Introduces a hierarchical structure similar to CNNs by using local window-based attention that shifts across layers — making it efficient and scalable for high-resolution inputs.
- 🛠 Fine-tuning on Small Datasets
Techniques like freezing early layers, adding task-specific heads, or leveraging self-supervised pretraining (e.g., DINO, MAE) can help ViTs adapt well to limited data.
In short, while ViTs opened the door to attention-based vision modeling, their full potential is best realized when paired with large-scale pretraining, architectural tweaks, or smart training tricks.
6. GitHub + Colab
View the GitHub Repo
Includes a clean folder structure with:
- vit_cifar10.ipynb notebook
- images/ folder for visualizations
- requirements.txt for easy installation
Open in Colab
Readers can fork and run the notebook directly in the browser.
Installation
git clone https://github.com/vickiiimu/vit-cifar10-tutorial.git
cd vit-cifar10-tutorial
pip install -r requirements.txt
7. Conclusion
Congratulations — you’ve just built a Vision Transformer from scratch!
Along the way, we covered the intuitions behind attention, walked through ViT’s architecture block-by-block, and reimplemented core components like patch embedding, positional encoding, and multi-head self-attention. If you followed the walkthrough, you now have a functioning ViT model you fully understand.
Whether you’re here to learn the internals, prototype your own vision ideas, or just scratch the itch of curiosity — this is your sandbox.
Feedback welcome
If you spot bugs, have suggestions, or build something cool on top of this, feel free to open an issue or pull request on GitHub.
Further reading
- An Image is Worth 16×16 Words (ViT paper)
- Attention is All You Need (Transformer paper)
- DeiT: Data-efficient training of ViTs
Thanks for reading! 👋
See you in the next post.
About Me
I’m a sophomore at MIT studying physics and artificial intelligence. This post is part of a series — Recreated from Scratch — where I reimplement foundational AI papers to learn them inside and out. I’ll be posting new walkthroughs every few weeks, diving into a different model or paper each time.
If you enjoyed this walkthrough, feel free to follow me on GitHub, follow me on Medium, or even reach out. I love chatting about research, open-source projects, and all things deep learning.
Published via Towards AI