Generative AI Foundations: Training a Vanilla GAN for Fashion
Last Updated on July 22, 2024 by Editorial Team
Author(s): Amit Kharel
Originally published on Towards AI.
Let's step back and take a break from the over-hype of LLMs/Transformers and get to know one of the foremost Gen AI revolutions: Generative Adversarial Networks (GANs).
What is a GAN?
A GAN is a deep learning architecture in which two neural networks compete with each other to generate new data that resembles the training dataset. It consists of two models: the Generator Model and the Discriminator Model. The Generator Model learns to generate new data from random input noise, while the Discriminator Model learns to discriminate whether the data is real (from the training set) or fake (from the generator).
And that's where the magic happens.
As the Discriminator Model learns to distinguish between real and fake data, the Generator Model improves its ability to generate data that is indistinguishable from real data. The main goal is to ensure both models are equally powerful, so the loss doesn't favor either model. This is important for two reasons:
- If the Discriminator Model becomes too powerful, it will confidently identify the Generator's data as fake, and the Generator won't be able to fool it, because it consistently receives strong signals that its outputs are incorrect.
- If the Generator Model becomes too powerful, it will generate data that doesn't resemble the desired output but can still fool the Discriminator into thinking it's real by exploiting its weaknesses.
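This tug-of-war is usually written as a two-player minimax game. The formula below is the standard GAN objective, not something specific to this article's code; the symbols are assumptions introduced here for illustration: D(x) is the Discriminator's probability that x is real, G(z) is the image generated from noise z, p_data is the distribution of real images, and p_z is the noise distribution.

```latex
\min_G \max_D \; V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```

The Discriminator tries to push V up by scoring real images near 1 and generated images near 0, while the Generator tries to push V down by making D(G(z)) large.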
Well, it does sound interesting. Let's dive into the code to see how it works behind the scenes.
Table of Contents
· Setting Up
· Loading Fashion MNIST Dataset
· Building a Vanilla GAN
  ∘ Generator Model
  ∘ Discriminator Model
· Training
· Final Code
· Results
Setting Up
First, install the required libraries.
pip install torch torchvision matplotlib
Import all the required libraries.
import os
import sys
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
import numpy as np
import datetime
from matplotlib.pyplot import imshow, imsave
%matplotlib inline
MODEL_NAME = "VanillaGAN"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Loading Fashion MNIST Dataset
We'll be using the Fashion MNIST dataset, which contains grayscale clothing images of size (28, 28).
image_dim = (28,28)
batch_size = 64
n_noise = 100
max_epoch = 100
n_critic = 2 # the number of iterations of the critic per generator iteration
image_dim
# image transformer
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
dataset = datasets.FashionMNIST(root='fashion_mnist', train=True, transform=transform, download=True)
# data loader for training
data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=True)
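It helps to see what the transform above does to pixel values. ToTensor scales each pixel to [0, 1], and Normalize((0.5,), (0.5,)) then subtracts 0.5 and divides by 0.5, which lands the data in [-1, 1], the same range a Tanh output layer produces. Here is a minimal pure-Python sketch of that arithmetic; the helper names normalize and denormalize are hypothetical and not part of torchvision.

```python
def normalize(pixel, mean=0.5, std=0.5):
    """Mirror the per-channel arithmetic of Normalize: (pixel - mean) / std."""
    return (pixel - mean) / std


def denormalize(value, mean=0.5, std=0.5):
    """Invert the normalization to get back a [0, 1] pixel intensity."""
    return value * std + mean


# Black, mid-gray and white pixels after ToTensor (already scaled to [0, 1]).
for p in (0.0, 0.5, 1.0):
    print(p, "->", normalize(p))
```

The inverse mapping is handy if you ever want to convert generated images back to [0, 1] before saving them.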
Building a Vanilla GAN
The interesting part is here. For this project, we'll build both models as simple fully connected networks (MLPs), which still do a good job on small images.
Generator Model
This model takes in random noise of size n_noise and returns a fake generated image.
class Generator(nn.Module):
    """
    Simple Generator w/ MLP
    """
    def __init__(self, input_size=n_noise, output_size=784):
        super(Generator, self).__init__()
        self.layer = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, output_size),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.layer(x)
        return x.view(x.size(0), 1, *image_dim)

# define the model
G = Generator(input_size=n_noise, output_size=image_dim[0] * image_dim[1]).to(DEVICE)
Let's visualize what our Generator model comes up with before training:
def get_sample_image(G, n_samples=100):
    """
    get sample images from generator
    """
    z = torch.randn(n_samples, n_noise).to(DEVICE)
    y_hat = G(z).view(n_samples, *image_dim) # (n_samples, 28, 28)
    result = y_hat.detach().cpu().numpy()
    n_rows = int(np.sqrt(n_samples))
    n_cols = int(np.sqrt(n_samples))
    assert n_rows * n_cols == n_samples
    img = np.zeros([image_dim[0] * n_rows, image_dim[1] * n_cols])
    for j in range(n_rows):
        # each grid row is image_dim[0] pixels tall
        img[j*image_dim[0]:(j+1)*image_dim[0]] = np.concatenate([x for x in result[j*n_cols:(j+1)*n_cols]], axis=-1)
    return img
Well, it's a noisy image, but the Generator can only learn when there's a Discriminator Model teaching it what's real and what's not.
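The tiling done by get_sample_image above can be checked on its own, without a model. The following is a small NumPy sketch of the same idea, with a hypothetical helper tile_images and tiny non-square dummy "images" so that mixing up height and width would show up immediately. It is an illustration of the layout, not the article's exact function.

```python
import numpy as np


def tile_images(images, n_rows, n_cols):
    """Arrange (N, H, W) images into one (n_rows*H, n_cols*W) grid, row by row."""
    n, h, w = images.shape
    assert n == n_rows * n_cols
    grid = np.zeros((n_rows * h, n_cols * w), dtype=images.dtype)
    for j in range(n_rows):
        # Put n_cols images side by side, then drop them into grid row j.
        row = np.concatenate(list(images[j * n_cols:(j + 1) * n_cols]), axis=-1)
        grid[j * h:(j + 1) * h] = row
    return grid


# Four constant 2x3 "images" with values 0..3 make the layout easy to inspect.
demo = np.stack([np.full((2, 3), k) for k in range(4)])
grid = tile_images(demo, n_rows=2, n_cols=2)
print(grid)
```

Images 0 and 1 end up in the top row of the grid and images 2 and 3 in the bottom row.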
Discriminator Model
This model takes in images from both the training dataset and the generator, and returns a prediction between 0 and 1, indicating how real the data is.
class Discriminator(nn.Module):
    """
    Simple Discriminator w/ MLP
    """
    def __init__(self, input_size=784, output_size=1):
        super(Discriminator, self).__init__()
        self.layer = nn.Sequential(
            nn.Linear(input_size, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, output_size),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.layer(x)
        return x

# define the model
D = Discriminator(input_size=image_dim[0] * image_dim[1], output_size=1).to(DEVICE)
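Since the goal is to keep both models roughly equally powerful, it is worth comparing their sizes. The sketch below counts learnable parameters from the layer widths used in the two classes above, with plain arithmetic instead of PyTorch. The helper names are hypothetical, and parameter count is only a rough proxy for capacity.

```python
def linear_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


def batchnorm_params(n_features):
    """Learnable scale and shift of BatchNorm1d (running stats are buffers)."""
    return 2 * n_features


g_sizes = [100, 256, 512, 1024, 784]   # Generator: noise -> flattened image
d_sizes = [784, 1024, 512, 256, 1]     # Discriminator: flattened image -> score

g_total = sum(linear_params(a, b) for a, b in zip(g_sizes, g_sizes[1:]))
g_total += sum(batchnorm_params(n) for n in g_sizes[1:-1])  # three hidden BatchNorm layers
d_total = sum(linear_params(a, b) for a, b in zip(d_sizes, d_sizes[1:]))

print("Generator parameters:    ", g_total)
print("Discriminator parameters:", d_total)
```

Both networks come out at roughly 1.5 million parameters, so neither one starts with an obvious size advantage.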
Training
To train the model, we first initialize two sets of labels: true and fake. The true labels will be used with the images from the training dataset and fed to the Discriminator, where it learns to assign these images a true label (1). Similarly, the fake labels will be assigned to the images from the Generator Model.
D_true_labels = torch.ones(batch_size, 1).to(DEVICE) # True Label for real images
D_fake_labels = torch.zeros(batch_size, 1).to(DEVICE) # False Label for fake images
loss = nn.BCELoss() # Binary Cross Entropy Loss
D_opt = torch.optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
G_opt = torch.optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
os.makedirs('results', exist_ok=True)  # create the output folder if it is missing
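To make the role of the two label sets concrete, here is a minimal sketch of binary cross-entropy for a single prediction, written with the standard library only. The values 0.9 and 0.2 are made-up Discriminator outputs chosen for illustration; nn.BCELoss applies the same formula and, by default, averages it over the batch.

```python
import math


def bce(prediction, target):
    """Binary cross-entropy for one prediction in (0, 1) and a 0/1 target."""
    return -(target * math.log(prediction) + (1 - target) * math.log(1 - prediction))


d_real = 0.9  # hypothetical D(x) for a real image
d_fake = 0.2  # hypothetical D(G(z)) for a generated image

# Discriminator: real images should score 1, generated images should score 0.
d_loss = bce(d_real, 1.0) + bce(d_fake, 0.0)
# Generator: wants its images scored as real, so it is compared against label 1.
g_loss = bce(d_fake, 1.0)

print(f"D loss: {d_loss:.4f}, G loss: {g_loss:.4f}")
```

In this made-up case the Discriminator is winning: its loss is small, while the Generator's loss is large because its image was scored as only 20% likely to be real.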
Now, we loop over each epoch, training the Discriminator to distinguish between real and fake data. Every n_critic steps, the Generator Model will use the Discriminator's feedback to improve its ability to generate convincing fake images.
step = 0 # the number of iterations so far
for epoch in range(max_epoch):
    for idx, (images, _) in enumerate(data_loader):
        # Train the Discriminator on a real batch and a generated batch
        x = images.to(DEVICE)
        x_outputs = D(x)
        D_x_loss = loss(x_outputs, D_true_labels)

        z = torch.randn(batch_size, n_noise).to(DEVICE)
        z_outputs = D(G(z))
        D_z_loss = loss(z_outputs, D_fake_labels)
        D_loss = D_x_loss + D_z_loss

        D.zero_grad()
        D_loss.backward()
        D_opt.step()

        # Train the Generator once every n_critic steps
        if step % n_critic == 0:
            D.eval()
            z = torch.randn(batch_size, n_noise).to(DEVICE)
            z_outputs = D(G(z))
            G_loss = loss(z_outputs, D_true_labels)

            G.zero_grad()
            G_loss.backward()
            G_opt.step()
            D.train()

        # Log the losses and save a grid of generated samples
        if step % 1000 == 0:
            print('Epoch: {}/{}, Step: {}, D Loss: {}, G Loss: {}'.format(epoch, max_epoch, step, D_loss.item(), G_loss.item()))
            samples = get_sample_image(G, n_samples=64)
            imsave('results/{}_step{}.jpg'.format(MODEL_NAME, str(step).zfill(3)), samples, cmap='gray')

        step += 1
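The effect of n_critic on the update schedule is easy to underestimate, so here is a quick sketch that counts updates for one epoch. It assumes the 60,000-image Fashion MNIST training split with batch_size = 64 and drop_last=True; the helper count_updates is hypothetical and only mirrors the step % n_critic == 0 rule from the loop above.

```python
def count_updates(total_steps, n_critic):
    """Count D and G updates when G trains only on steps divisible by n_critic."""
    d_updates = total_steps  # the Discriminator is updated on every step
    g_updates = sum(1 for step in range(total_steps) if step % n_critic == 0)
    return d_updates, g_updates


steps_per_epoch = 60000 // 64  # full batches only, because drop_last=True

for n_critic in (1, 2, 5):
    d_updates, g_updates = count_updates(steps_per_epoch, n_critic)
    print(f"n_critic={n_critic}: D updates={d_updates}, G updates={g_updates}")
```

A larger n_critic gives the Discriminator more practice between Generator updates, which is one knob for keeping the two models balanced.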
Final Code
You can copy and paste the code below into a Python file and run it to train the model. The generated images are saved in the results folder.
import os
import sys
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
import numpy as np
from matplotlib.pyplot import imshow, imsave
MODEL_NAME = "VanillaGAN"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image_dim = (28,28)
batch_size = 64
n_noise = 100
max_epoch = 100
n_critic = 5 # the number of iterations of the critic per generator iteration
step = 0 # the number of iterations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
dataset = datasets.FashionMNIST(root='fashion_mnist', train=True, transform=transform, download=True)
data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=True)
def get_sample_image(G, n_samples=100):
    """
    get sample images from generator
    """
    z = torch.randn(n_samples, n_noise).to(DEVICE)
    y_hat = G(z).view(n_samples, *image_dim) # (n_samples, 28, 28)
    result = y_hat.detach().cpu().numpy()
    n_rows = int(np.sqrt(n_samples))
    n_cols = int(np.sqrt(n_samples))
    assert n_rows * n_cols == n_samples
    img = np.zeros([image_dim[0] * n_rows, image_dim[1] * n_cols])
    for j in range(n_rows):
        # each grid row is image_dim[0] pixels tall
        img[j*image_dim[0]:(j+1)*image_dim[0]] = np.concatenate([x for x in result[j*n_cols:(j+1)*n_cols]], axis=-1)
    return img
class Generator(nn.Module):
    """
    Simple Generator w/ MLP
    """
    def __init__(self, input_size=n_noise, output_size=784):
        super(Generator, self).__init__()
        self.layer = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, output_size),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.layer(x)
        return x.view(x.size(0), 1, *image_dim)
class Discriminator(nn.Module):
    """
    Simple Discriminator w/ MLP
    """
    def __init__(self, input_size=784, output_size=1):
        super(Discriminator, self).__init__()
        self.layer = nn.Sequential(
            nn.Linear(input_size, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, output_size),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.layer(x)
        return x
G = Generator(input_size=n_noise, output_size=image_dim[0] * image_dim[1]).to(DEVICE)
G = torch.compile(G)
D = Discriminator(input_size=image_dim[0] * image_dim[1], output_size=1).to(DEVICE)
D = torch.compile(D)
D_true_labels = torch.ones(batch_size, 1).to(DEVICE) # True Label for real images
D_fake_labels = torch.zeros(batch_size, 1).to(DEVICE) # False Label for fake images
loss = nn.BCELoss() # Binary Cross Entropy Loss
D_opt = torch.optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
G_opt = torch.optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
os.makedirs('results', exist_ok=True)  # create the output folder if it is missing
for epoch in range(max_epoch):
    for idx, (images, _) in enumerate(data_loader):
        # Train the Discriminator on a real batch and a generated batch
        x = images.to(DEVICE)
        x_outputs = D(x)
        D_x_loss = loss(x_outputs, D_true_labels)

        z = torch.randn(batch_size, n_noise).to(DEVICE)
        z_outputs = D(G(z))
        D_z_loss = loss(z_outputs, D_fake_labels)
        D_loss = D_x_loss + D_z_loss

        D.zero_grad()
        D_loss.backward()
        D_opt.step()

        # Train the Generator once every n_critic steps
        if step % n_critic == 0:
            D.eval()
            z = torch.randn(batch_size, n_noise).to(DEVICE)
            z_outputs = D(G(z))
            G_loss = loss(z_outputs, D_true_labels)

            G.zero_grad()
            G_loss.backward()
            G_opt.step()
            D.train()

        # Log the losses and save a grid of generated samples
        if step % 2000 == 0:
            print('Epoch: {}/{}, Step: {}, D Loss: {}, G Loss: {}'.format(epoch, max_epoch, step, D_loss.item(), G_loss.item()))
            samples = get_sample_image(G, n_samples=64)
            imsave('results/{}_step{}.jpg'.format(MODEL_NAME, str(step).zfill(3)), samples, cmap='gray')

        step += 1
The losses will change gradually during the initial steps, but once the models reach equilibrium, both losses should remain relatively stable (with only minor fluctuations) until the end of training.
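As a rough reference point for what "equilibrium" looks like in the printed numbers, consider the idealized case where the Discriminator can no longer tell real from fake and outputs 0.5 for every image. The sketch below plugs that value into the binary cross-entropy terms used in the training loop. This is a theoretical ideal under that assumption; real training runs usually hover around such values rather than hitting them exactly.

```python
import math

d_output = 0.5  # assumed idealized equilibrium: D is maximally unsure

# D_loss in the training loop is the sum of the real term and the fake term.
d_loss = -math.log(d_output) - math.log(1 - d_output)
# G_loss compares D(G(z)) against the "real" label 1.
g_loss = -math.log(d_output)

print(f"D loss ~ {d_loss:.4f}, G loss ~ {g_loss:.4f}")
```

If the logged D loss collapses toward 0 while the G loss keeps climbing, the Discriminator is overpowering the Generator, which is the imbalance described at the start of the article.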
Results
Let's see what our model learned over the training:
Pretty good results. You can try training for more steps to see if it improves the generated images' clarity. But there it is: all four images you see above are fake and generated by our models.
Thanks for reading! If you're interested in the current trends of Generative AI and want to learn more about LLMs, check out the articles below on building your own GPT-2 model from scratch.
Building GPT-2 with PyTorch (Part 1)
Ready to build your own GPT?
pub.towardsai.net
Building GPT-2 with PyTorch (Part 2)
Build and Train a 29M GPT-2 Model from scratch
pub.towardsai.net
Published via Towards AI