Generative AI Foundations: Training a Vanilla GAN for Fashion

Last Updated on July 22, 2024 by Editorial Team

Author(s): Amit Kharel

Originally published on Towards AI.

Photo by Mateusz Wacławek on Unsplash
GAN learning to generate Images [By Author]


Let’s step back, take a break from the hype around LLMs/Transformers, and get to know one of the foremost Gen AI revolutions: Generative Adversarial Networks (GANs).

What is a GAN?

A GAN is a deep learning architecture in which two neural networks compete with each other to generate new data that resembles the training dataset. The two networks are the Generator Model and the Discriminator Model. The Generator Model learns to generate new data from random input noise, while the Discriminator Model learns to decide whether a given sample is real (from the training set) or fake (from the generator).
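For reference, the original GAN paper (Goodfellow et al., 2014) formalizes this competition as a two-player minimax game: the Discriminator D tries to maximize the value function below, and the Generator G tries to minimize it.

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right]
  + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right]
```

Here x is a real image, z is the random input noise, and D(·) is the Discriminator’s estimated probability that its input is real. In practice, the Generator is often trained to maximize log D(G(z)) instead, which gives it stronger gradients early in training.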

And that’s where the magic happens.

As the Discriminator Model learns to distinguish between real and fake data, the Generator Model improves its ability to generate data that is indistinguishable from real data. A key practical goal is to keep the two models roughly balanced, so that training doesn’t end up favoring either one. This matters for two reasons:

  1. If the Discriminator Model becomes too powerful, it will confidently identify the Generator’s data as fake. The Generator then keeps receiving strong signals that its outputs are wrong but almost no useful information on how to improve them, so it never learns to fool the Discriminator.
  2. If the Generator Model becomes too powerful, it can learn to fool the Discriminator by exploiting its weaknesses, producing data the Discriminator accepts as real even though it doesn’t resemble the desired output.
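Point 1 is easier to see in numbers. The sketch below is our own illustration, not the author’s code. It compares the gradient the Generator receives, taken with respect to the Discriminator’s logit on a fake sample, under two losses: the original minimax loss log(1 - D(G(z))), and the “non-saturating” loss -log D(G(z)), which is what `loss(z_outputs, D_true_labels)` computes in the training loop.

```python
import torch

def grad_wrt_logit(loss_fn, logits):
    """Gradient of loss_fn(sigmoid(logits)).sum() with respect to the logits."""
    logits = logits.clone().requires_grad_(True)
    loss = loss_fn(torch.sigmoid(logits)).sum()
    (grad,) = torch.autograd.grad(loss, logits)
    return grad

# Discriminator logits for fake samples: -6 means D is very sure the sample is fake.
logits = torch.tensor([-6.0, -3.0, 0.0])

# Saturating objective from the minimax game: the generator minimizes log(1 - D(G(z))).
saturating = grad_wrt_logit(lambda d: torch.log(1 - d), logits)

# Non-saturating objective: minimize -log D(G(z)), i.e. BCE against the "real" label 1.
non_saturating = grad_wrt_logit(lambda d: -torch.log(d), logits)

print(saturating)      # roughly [-0.0025, -0.0474, -0.5000]
print(non_saturating)  # roughly [-0.9975, -0.9526, -0.5000]
```

When D is confident (logit -6), the saturating gradient almost vanishes. The non-saturating gradient stays close to 1 in magnitude, so the Generator still gets a usable learning signal.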

Well, it does sound interesting. Let’s dive into the code to see how it works behind the scenes.

Table of Contents

· Setting Up
· Loading Fashion MNIST Dataset
· Building a Vanilla GAN
∘ Generator Model
∘ Discriminator Model
· Training
· Final Code
· Test and Evaluation

Setting Up

First, install the required libraries.

pip install torch torchvision matplotlib

Import all the required libraries.

import os
import sys
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
import numpy as np
import datetime
from matplotlib.pyplot import imshow, imsave

# Jupyter/IPython only: show plots inline (remove this line in a plain .py script)
%matplotlib inline

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Loading Fashion MNIST Dataset

We’ll be using the Fashion MNIST dataset, which contains 70,000 grayscale 28×28 images of clothing (60,000 for training and 10,000 for testing) across 10 classes.

image_dim = (28,28)
batch_size = 64
n_noise = 100
max_epoch = 100
n_critic = 2 # the number of iterations of the critic per generator iteration

# image transform: convert to a tensor in [0, 1], then normalize to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

dataset = datasets.FashionMNIST(root='fashion_mnist', train=True, transform=transform, download=True)

# data loader for training
data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=True)
Dataset Visualization [Image by Author]

Building a Vanilla GAN

Now for the interesting part. For this project, we’ll build both models as simple fully connected networks (multilayer perceptrons), which still do a good job on small images like these.
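An MLP needs nonlinear activations between its `nn.Linear` layers. Without them, any stack of Linear layers collapses into a single affine map, no matter how deep it is. The sketch below is a quick check using arbitrary layer sizes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Two Linear layers with no activation in between.
linear_stack = nn.Sequential(nn.Linear(100, 256), nn.Linear(256, 784))
x = torch.randn(8, 100)

with torch.no_grad():
    first, second = linear_stack[0], linear_stack[1]
    # Fold both layers into one equivalent affine map: W = W2 @ W1, b = W2 @ b1 + b2.
    W = second.weight @ first.weight
    b = second.weight @ first.bias + second.bias

    stacked = linear_stack(x)
    folded = x @ W.T + b

print(torch.allclose(stacked, folded, atol=1e-4))  # True
```

With a nonlinearity such as `nn.LeakyReLU` between the layers, no such folding exists. That is what lets the network model non-linear structure in the images.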

Generator Model

This model takes a random noise vector of size n_noise and returns a generated (fake) image.

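class Generator(nn.Module):
    """Simple Generator w/ MLP"""
    def __init__(self, input_size=n_noise, output_size=784):
        super(Generator, self).__init__()
        # assumed activations: LeakyReLU(0.2) between layers, and a final Tanh
        # that keeps outputs in [-1, 1] to match the Normalize transform
        self.layer = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, output_size),
            nn.Tanh(),
        )

    def forward(self, x):
        x = self.layer(x)
        return x.view(x.size(0), 1, *image_dim)  # (N, 784) -> (N, 1, 28, 28)

# define the model
G = Generator(input_size=n_noise, output_size=image_dim[0] * image_dim[1]).to(DEVICE)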
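Because the model is a plain MLP, its size is easy to work out by hand. Each `nn.Linear(in, out)` holds in × out weights plus out biases, and activations such as LeakyReLU and Tanh have no parameters. A small sketch that counts the Generator’s parameters from its layer widths:

```python
import torch.nn as nn

def count_params(model: nn.Module) -> int:
    """Total number of trainable parameters in a model."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Generator layer widths: n_noise (100) -> 256 -> 512 -> 1024 -> 784 (28 * 28).
widths = [100, 256, 512, 1024, 784]
pairs = list(zip(widths[:-1], widths[1:]))

# Only the Linear layers carry parameters, so they are all we need to count.
generator_linears = nn.Sequential(*[nn.Linear(w_in, w_out) for w_in, w_out in pairs])

# in * out weights + out biases, per layer
by_formula = sum(w_in * w_out + w_out for w_in, w_out in pairs)
print(count_params(generator_linears), by_formula)  # 1486352 1486352
```

So the Generator has roughly 1.5M parameters, most of them in the last two layers.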

Let’s visualize what our Generator model comes up with before training:

@torch.no_grad()  # sampling only, no need to track gradients
def get_sample_image(G, n_samples=100):
    """Get a grid of sample images from the generator."""
    z = torch.randn(n_samples, n_noise).to(DEVICE)
    y_hat = G(z).view(n_samples, *image_dim) # (n_samples, 28, 28)
    result = y_hat.cpu().numpy()

    n_rows = int(np.sqrt(n_samples))
    n_cols = int(np.sqrt(n_samples))

    assert n_rows * n_cols == n_samples

    # tile the images row by row into one (n_rows * 28, n_cols * 28) array
    img = np.zeros([image_dim[0] * n_rows, image_dim[1] * n_cols])
    for j in range(n_rows):
        img[j*image_dim[0]:(j+1)*image_dim[0]] = np.concatenate([x for x in result[j*n_cols:(j+1)*n_cols]], axis=-1)
    return img

imshow(get_sample_image(G, n_samples=64), cmap='gray')
Initial Generator Image Visualization [Image by Author]

As expected, it’s pure noise: the Generator can only improve once a Discriminator Model gives it feedback on what looks real and what doesn’t.
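The loop in `get_sample_image` tiles the samples into one grid image. The same layout can be produced with a single reshape and transpose in NumPy. This is a sketch of an alternative, not the author’s code, and `tile_images` and `tile_images_loop` are hypothetical helper names. It is checked against a loop that mirrors the original function:

```python
import numpy as np

def tile_images(images: np.ndarray, n_rows: int, n_cols: int) -> np.ndarray:
    """Arrange images of shape (n_rows * n_cols, H, W) into one (n_rows * H, n_cols * W) grid."""
    n, h, w = images.shape
    assert n == n_rows * n_cols
    # (rows, cols, H, W) -> (rows, H, cols, W) -> (rows * H, cols * W)
    return images.reshape(n_rows, n_cols, h, w).transpose(0, 2, 1, 3).reshape(n_rows * h, n_cols * w)

def tile_images_loop(images: np.ndarray, n_rows: int, n_cols: int) -> np.ndarray:
    """Loop-based reference with the same layout as get_sample_image."""
    n, h, w = images.shape
    img = np.zeros([h * n_rows, w * n_cols])
    for j in range(n_rows):
        img[j * h:(j + 1) * h] = np.concatenate(list(images[j * n_cols:(j + 1) * n_cols]), axis=-1)
    return img

rng = np.random.default_rng(0)
samples = rng.standard_normal((64, 28, 28))  # stand-in for 64 generated 28x28 images
grid = tile_images(samples, 8, 8)
print(grid.shape)  # (224, 224)
```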

Discriminator Model

This model takes in images from both the training dataset and the Generator and returns a prediction between 0 and 1: its estimate of the probability that the input is a real image.

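class Discriminator(nn.Module):
    """Simple Discriminator w/ MLP"""
    def __init__(self, input_size=784, output_size=1):
        super(Discriminator, self).__init__()
        # assumed activations: LeakyReLU(0.2) between layers; the final Sigmoid
        # turns the score into a probability in (0, 1), as nn.BCELoss requires
        self.layer = nn.Sequential(
            nn.Linear(input_size, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, output_size),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten (N, 1, 28, 28) -> (N, 784)
        x = self.layer(x)
        return x

# define the model
D = Discriminator(input_size=image_dim[0] * image_dim[1], output_size=1).to(DEVICE)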
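`nn.BCELoss` expects probabilities, so the Discriminator has to squash its output into (0, 1), typically with a final `nn.Sigmoid()`. A commonly used alternative is to return raw logits and use `nn.BCEWithLogitsLoss`, which applies the sigmoid internally in a numerically more stable way. A quick check on random scores that the two give the same loss:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(64, 1) * 2                  # raw discriminator scores, before any sigmoid
targets = torch.randint(0, 2, (64, 1)).float()  # 1 = real, 0 = fake

# Option A: Sigmoid at the end of the model + nn.BCELoss on the probabilities.
loss_a = nn.BCELoss()(torch.sigmoid(logits), targets)

# Option B: raw logits + nn.BCEWithLogitsLoss, which fuses the sigmoid into the loss.
loss_b = nn.BCEWithLogitsLoss()(logits, targets)

print(torch.allclose(loss_a, loss_b, atol=1e-4))  # True
```

For small MLPs like these, either option works. The fused version mainly matters when the Discriminator becomes very confident and its probabilities get close to 0 or 1.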


Training

To train the models, we first create two sets of labels: true and fake. The true labels are paired with images from the training dataset, so the Discriminator learns to assign real images the label 1. Similarly, the fake labels (0) are paired with the images produced by the Generator Model.

D_true_labels = torch.ones(batch_size, 1).to(DEVICE) # True Label for real images
D_fake_labels = torch.zeros(batch_size, 1).to(DEVICE) # False Label for fake images

loss = nn.BCELoss() # Binary Cross Entropy Loss
D_opt = torch.optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
G_opt = torch.optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))

if not os.path.exists('results'):
    os.makedirs('results')  # folder for the sample images saved during training
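With labels of 1 for real and 0 for fake, the binary cross-entropy BCE(p, y) = -[y·log(p) + (1 - y)·log(1 - p)] keeps only one of its two terms. The sketch below uses made-up Discriminator scores, which are illustrative values and not outputs of the models above, to show what each of the three losses in the training loop reduces to:

```python
import torch
import torch.nn as nn

bce = nn.BCELoss()  # averages over the batch by default

d_real = torch.tensor([[0.9], [0.8], [0.7]])  # made-up D scores on real images
d_fake = torch.tensor([[0.2], [0.1], [0.3]])  # made-up D scores on generated images
ones, zeros = torch.ones_like(d_real), torch.zeros_like(d_fake)

D_x_loss = bce(d_real, ones)   # reduces to -mean(log D(x))
D_z_loss = bce(d_fake, zeros)  # reduces to -mean(log(1 - D(G(z))))
G_loss = bce(d_fake, ones)     # reduces to -mean(log D(G(z))): G wants D to say "real"

print(D_x_loss.item(), D_z_loss.item(), G_loss.item())
```

The Generator loss reuses the true labels: it is small only when the Discriminator assigns high “real” probability to generated images.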

Now we loop over the epochs. On every step, the Discriminator is trained to distinguish between real and fake data; every n_critic steps, the Generator Model uses the Discriminator’s feedback to improve its ability to generate convincing fake images.

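step = 0  # global step counter across epochs
MODEL_NAME = 'VanillaGAN'  # hypothetical name, used only as a prefix for the saved sample files

for epoch in range(max_epoch):
    for idx, (images, _) in enumerate(data_loader):
        # Discriminator: real images should be classified as real (1)
        x = images.to(DEVICE)
        x_outputs = D(x)
        D_x_loss = loss(x_outputs, D_true_labels)

        # Discriminator: generated images should be classified as fake (0)
        z = torch.randn(batch_size, n_noise).to(DEVICE)
        z_outputs = D(G(z).detach())  # detach: this update trains D only
        D_z_loss = loss(z_outputs, D_fake_labels)

        D_loss = D_x_loss + D_z_loss
        D_opt.zero_grad()
        D_loss.backward()
        D_opt.step()

        # Generator: every n_critic steps, try to make D classify fakes as real
        if step % n_critic == 0:
            z = torch.randn(batch_size, n_noise).to(DEVICE)
            z_outputs = D(G(z))
            G_loss = loss(z_outputs, D_true_labels)
            G_opt.zero_grad()
            G_loss.backward()
            G_opt.step()

        if step % 1000 == 0:
            print('Epoch: {}/{}, Step: {}, D Loss: {}, G Loss: {}'.format(epoch, max_epoch, step, D_loss.item(), G_loss.item()))

            samples = get_sample_image(G, n_samples=64)
            imsave('results/{}_step{}.jpg'.format(MODEL_NAME, str(step).zfill(3)), samples, cmap='gray')
        step += 1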
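Whether gradients reach the Generator during an update depends on `.detach()`. The sketch below uses small toy stand-in models rather than the article’s networks. It shows that detaching the generated batch in the Discriminator step keeps the Generator’s gradients untouched, while the Generator step, without `.detach()`, backpropagates through D into G:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
G = nn.Sequential(nn.Linear(4, 8), nn.Tanh())     # toy stand-in generator
D = nn.Sequential(nn.Linear(8, 1), nn.Sigmoid())  # toy stand-in discriminator
bce = nn.BCELoss()

z = torch.randn(16, 4)
fake = G(z)

# Discriminator step: detach() cuts the graph, so no gradient flows back into G.
d_loss = bce(D(fake.detach()), torch.zeros(16, 1))
d_loss.backward()
g_untouched = all(p.grad is None for p in G.parameters())
d_updated = all(p.grad is not None for p in D.parameters())

# Generator step: no detach, so gradients reach G's parameters through D.
D.zero_grad()
g_loss = bce(D(fake), torch.ones(16, 1))
g_loss.backward()
g_updated = all(p.grad is not None for p in G.parameters())

print(g_untouched, d_updated, g_updated)  # True True True
```

During the Generator step, D’s parameters also receive gradients, but only the Generator’s optimizer steps. Those stale gradients are cleared by the `zero_grad()` call before the next Discriminator update.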

Final Code

You can copy the code below into a Python file and run it to train the models. The generated sample images are saved in the results folder.

import os
import sys
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
import numpy as np
from matplotlib.pyplot import imshow, imsave

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

image_dim = (28,28)
batch_size = 64
n_noise = 100
max_epoch = 100
n_critic = 5 # the number of iterations of the critic per generator iteration
step = 0 # the number of iterations

transform = transforms.Compose([
    transforms.ToTensor(),                # pixels -> tensor in [0, 1]
    transforms.Normalize((0.5,), (0.5,))  # [0, 1] -> [-1, 1]
])
dataset = datasets.FashionMNIST(root='fashion_mnist', train=True, transform=transform, download=True)
data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=True)

@torch.no_grad()  # sampling only, no need to track gradients
def get_sample_image(G, n_samples=100):
    """Get a grid of sample images from the generator."""
    z = torch.randn(n_samples, n_noise).to(DEVICE)
    y_hat = G(z).view(n_samples, *image_dim) # (n_samples, 28, 28)
    result = y_hat.cpu().numpy()

    n_rows = int(np.sqrt(n_samples))
    n_cols = int(np.sqrt(n_samples))

    assert n_rows * n_cols == n_samples

    # tile the images row by row into one (n_rows * 28, n_cols * 28) array
    img = np.zeros([image_dim[0] * n_rows, image_dim[1] * n_cols])
    for j in range(n_rows):
        img[j*image_dim[0]:(j+1)*image_dim[0]] = np.concatenate([x for x in result[j*n_cols:(j+1)*n_cols]], axis=-1)
    return img

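class Generator(nn.Module):
    """Simple Generator w/ MLP"""
    def __init__(self, input_size=n_noise, output_size=784):
        super(Generator, self).__init__()
        # assumed activations: LeakyReLU(0.2) between layers, and a final Tanh
        # that keeps outputs in [-1, 1] to match the Normalize transform
        self.layer = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, output_size),
            nn.Tanh(),
        )

    def forward(self, x):
        x = self.layer(x)
        return x.view(x.size(0), 1, *image_dim)  # (N, 784) -> (N, 1, 28, 28)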

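class Discriminator(nn.Module):
    """Simple Discriminator w/ MLP"""
    def __init__(self, input_size=784, output_size=1):
        super(Discriminator, self).__init__()
        # assumed activations: LeakyReLU(0.2) between layers; the final Sigmoid
        # turns the score into a probability in (0, 1), as nn.BCELoss requires
        self.layer = nn.Sequential(
            nn.Linear(input_size, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, output_size),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten (N, 1, 28, 28) -> (N, 784)
        x = self.layer(x)
        return x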

G = Generator(input_size=n_noise, output_size=image_dim[0] * image_dim[1]).to(DEVICE)
G = torch.compile(G)  # optional speed-up; requires PyTorch 2.0+

D = Discriminator(input_size=image_dim[0] * image_dim[1], output_size=1).to(DEVICE)
D = torch.compile(D)  # optional speed-up; requires PyTorch 2.0+

D_true_labels = torch.ones(batch_size, 1).to(DEVICE) # True Label for real images
D_fake_labels = torch.zeros(batch_size, 1).to(DEVICE) # False Label for fake images

loss = nn.BCELoss() # Binary Cross Entropy Loss
D_opt = torch.optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
G_opt = torch.optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))

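if not os.path.exists('results'):
    os.makedirs('results')  # folder for the sample images saved during training

MODEL_NAME = 'VanillaGAN'  # hypothetical name, used only as a prefix for the saved sample files

for epoch in range(max_epoch):
    for idx, (images, _) in enumerate(data_loader):
        # Discriminator: real images should be classified as real (1)
        x = images.to(DEVICE)
        x_outputs = D(x)
        D_x_loss = loss(x_outputs, D_true_labels)

        # Discriminator: generated images should be classified as fake (0)
        z = torch.randn(batch_size, n_noise).to(DEVICE)
        z_outputs = D(G(z).detach())  # detach: this update trains D only
        D_z_loss = loss(z_outputs, D_fake_labels)

        D_loss = D_x_loss + D_z_loss
        D_opt.zero_grad()
        D_loss.backward()
        D_opt.step()

        # Generator: every n_critic steps, try to make D classify fakes as real
        if step % n_critic == 0:
            z = torch.randn(batch_size, n_noise).to(DEVICE)
            z_outputs = D(G(z))
            G_loss = loss(z_outputs, D_true_labels)
            G_opt.zero_grad()
            G_loss.backward()
            G_opt.step()

        if step % 2000 == 0:
            print('Epoch: {}/{}, Step: {}, D Loss: {}, G Loss: {}'.format(epoch, max_epoch, step, D_loss.item(), G_loss.item()))

            samples = get_sample_image(G, n_samples=64)
            imsave('results/{}_step{}.jpg'.format(MODEL_NAME, str(step).zfill(3)), samples, cmap='gray')
        step += 1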

The losses change noticeably during the initial steps. Once the models reach a rough equilibrium, both losses should stay relatively stable, with only minor fluctuations, until the end of training.
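As a reference point for reading the printed losses: at the theoretical equilibrium of the original GAN game, the optimal Discriminator outputs 0.5 for every input. Plugging that value into the losses used here gives D Loss = 2 ln 2 ≈ 1.386 and G Loss = ln 2 ≈ 0.693. Real training runs rarely sit exactly at these values, so treat them as a rough baseline rather than a target:

```python
import math
import torch
import torch.nn as nn

bce = nn.BCELoss()
half = torch.full((64, 1), 0.5)  # D outputs 0.5 for every input at the theoretical equilibrium

# Same combination of losses as in the training loop.
D_loss = bce(half, torch.ones(64, 1)) + bce(half, torch.zeros(64, 1))
G_loss = bce(half, torch.ones(64, 1))

print(f"D_loss = {D_loss.item():.4f} (2 ln 2 = {2 * math.log(2):.4f})")  # 1.3863
print(f"G_loss = {G_loss.item():.4f} (ln 2 = {math.log(2):.4f})")        # 0.6931
```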


Test and Evaluation

Let’s see what our model learned during training:

Image By Author

Pretty good results. You can try training for more epochs to see whether the generated images become sharper. But there it is: all four images you see above are fake and were generated by our model.

Thanks for reading! If you’re interested in current trends in Generative AI and want to learn more about LLMs, check out the articles below on building your own GPT-2 model from scratch.

Building GPT-2 with PyTorch (Part 1)

Ready to build your own GPT?

Building GPT-2 with PyTorch (Part 2)

Build and Train a 29M GPT-2 Model from scratch

Published via Towards AI