A Beginner’s Guide to Building a Conditional GAN.
Author(s): Pere Martra
Originally published on Towards AI.
A comprehensive guide to creating conditional GANs with TensorFlow, Python, and Keras for image generation.
A GAN can be used to make images similar to those of the Dataset it has been trained on. A conditional GAN also allows us to choose the kind of images we want to generate.
To achieve this, we must modify the structure of the models that comprise the GAN so that they accept as input the label that indicates the type of image to be generated.
In this article, I am going to use the MNIST dataset, which is composed of 10 categories corresponding to the numbers from 0 to 9. It is a Dataset that I have used in a previous article, so you can easily compare the changes made to the GAN to transform it into a conditional GAN.
The series on GANs currently consists of four articles. In the first one, you can find a GAN that works with the MNIST Dataset.
GANs From Zero to Hero
The code for this article can be found on:
Kaggle: https://www.kaggle.com/code/peremartramanonellas/gan-tutorial-4-how-to-create-a-conditional-gan
Github: https://github.com/oopere/GANs/blob/main/C4_COND_GAN_MNIST.ipynb
All the Notebooks in the GAN series are available in a GitHub repository under an MIT License. Feel free to clone it, and don’t forget to watch or star it if you want to receive updates and new notebooks.
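The snippets below use np and keras directly; a minimal setup, assuming TensorFlow 2.x, would be:

import numpy as np
from tensorflow import keras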
General Structure of a Conditional GAN.
The structure is mostly the same as for a normal GAN. We have a Generator that creates the images and a Discriminator that decides whether the images are real and belong to the Dataset or whether they come from the Generator.
It is important to note that both the Generator and the Discriminator must now receive the class to which the images belong.
As you may have already guessed, both the Generator and the Discriminator will be multipath (multi-input) models, so we will no longer be able to use the Sequential API. Instead, we will use the Functional API.
If you need an explanation of how the Functional API works or the non-sequential models, you can refer to this article:
How To Predict Multiple Variables With One Model? And Why!
The Generator of our conditional GAN.
Any GAN Generator receives random data drawn from a Gaussian distribution, also called noise, as input, which is then transformed into an image of the desired size.
The size of the noise, also known as the latent space, that I have chosen for this Generator is 50. This noise must be converted into a 28×28×1 image. To achieve this, we will start by transforming it into a smaller 7×7 representation that will reach its target size through upsampling.
In the case of a conditional GAN, the noise will be accompanied by an indicator of the class to which the generated image must belong, that is, the conditional information.
In one branch, the conditional indicator is transformed to a 7×7×1 shape to match the spatial size of the source image.
In the other branch, the noise is transformed to 7×7×128, which represents our 7×7 source image with a depth of 128 channels.
These two branches are joined using the Concatenate layer. With that, the conditional indicator becomes one more channel, and we have a 7×7×129 image where one of the channels carries the information on the class to which the image belongs.
Let’s see the definition of the block that receives the class indicator data:
# label input
in_label = keras.layers.Input(shape=(1,))
# embedding for categorical input
li = keras.layers.Embedding(n_classes, 50)(in_label)
# linear multiplication
n_nodes = 7 * 7
li = keras.layers.Dense(n_nodes)(li)
# reshape to additional channel
li = keras.layers.Reshape((7, 7, 1))(li)
First, an Input layer is defined to receive the indicator with a single variable Input(shape=(1,)).
This value is passed through the Embedding layer, where it is transformed into a vector of real numbers. The vector has a size of 50. n_classes indicates the number of existing classes in the dataset.
The next Dense layer applies a linear transformation to the vector obtained from the Embedding layer, projecting it to 7 × 7 = 49 values.
Finally, we apply a Reshape to the data returned by the Dense layer. Thus, we have the desired 7×7×1 format.
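If you want to verify the resulting shape, a quick hypothetical check (not part of the original notebook):

# wrap the label branch in a temporary model and inspect its output shape
print(keras.Model(in_label, li).output_shape)  # (None, 7, 7, 1)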
Now let’s see the branch that transforms the noise to a 7×7 image.
# image generator input
in_lat = keras.layers.Input(shape=(latent_dim,))
# foundation for 7x7 image
n_nodes = 128 * 7 * 7
gen = keras.layers.Dense(n_nodes)(in_lat)
gen = keras.layers.LeakyReLU(alpha=0.2)(gen)
gen = keras.layers.Reshape((7, 7, 128))(gen)
In the first line, we define the Input layer for the noise.
Then a Dense layer is defined, which processes the noise to generate an output with the indicated number of nodes. We want it to be 7 × 7 with a depth of 128, that is, 6272 nodes.
We use a LeakyReLU activation layer to allow negative values: instead of zeroing them out like ReLU, it scales them by the alpha factor, 0.2 in our case. Using a LeakyReLU activation instead of ReLU is one of the GAN Hacks recommendations.
We finish with a Reshape that gives the desired format to the data.
Now let’s look at the fusion and the common part of the model:
#merge image gen and label input
merge = keras.layers.Concatenate()([gen, li])
# upsample to 14x14
gen = keras.layers.Conv2DTranspose(128, (4,4), strides=(2,2), padding='same',
activation=keras.layers.LeakyReLU(alpha=0.2))(merge)
gen = keras.layers.BatchNormalization()(gen)
# upsample to 28x28
gen = keras.layers.Conv2DTranspose(128, (4,4), strides=(2,2), padding='same',
activation=keras.layers.LeakyReLU(alpha=0.2))(gen)
gen = keras.layers.BatchNormalization()(gen)
# output
out_layer = keras.layers.Conv2D(1, (7,7), activation='tanh', padding='same')(gen)
# define model
model = keras.Model([in_lat, in_label], out_layer)
The code block starts with a Concatenate layer that combines the two branches, returning the merged data in 7×7×129, where one of the channels contains the conditional information.
Subsequently, we begin the upsampling process that is standard in any GAN. In this case, two Conv2DTranspose layers are used to go first from 7×7 to 14×14 and finally to 28×28. Note the use of the LeakyReLU activation and the BatchNormalization layer after each upsampling.
You can find more information about the upsampling process in the first article of the GAN tutorial: https://medium.com/towards-artificial-intelligence/creating-our-first-optimized-dcgan-12edde5e34c6
Finally, the model is defined using the two input layers and the output layer. The output layer is connected to all the layers we have accumulated in gen, which covers all the layers of the model, since it incorporates merge, the fusion of the two previously defined branches.
Let’s now see all the Generator code together:
# define the standalone generator model
def define_generator(latent_dim, n_classes=10):
    # label input
    in_label = keras.layers.Input(shape=(1,))
    # embedding for categorical input
    li = keras.layers.Embedding(n_classes, 50)(in_label)
    # linear multiplication
    n_nodes = 7 * 7
    li = keras.layers.Dense(n_nodes)(li)
    # reshape to additional channel
    li = keras.layers.Reshape((7, 7, 1))(li)
    # image generator input
    in_lat = keras.layers.Input(shape=(latent_dim,))
    # foundation for 7x7 image
    n_nodes = 128 * 7 * 7
    gen = keras.layers.Dense(n_nodes)(in_lat)
    gen = keras.layers.LeakyReLU(alpha=0.2)(gen)
    gen = keras.layers.Reshape((7, 7, 128))(gen)
    # merge image gen and label input
    merge = keras.layers.Concatenate()([gen, li])
    # upsample to 14x14
    gen = keras.layers.Conv2DTranspose(128, (4,4), strides=(2,2), padding='same',
                                       activation=keras.layers.LeakyReLU(alpha=0.2))(merge)
    gen = keras.layers.BatchNormalization()(gen)
    # upsample to 28x28
    gen = keras.layers.Conv2DTranspose(128, (4,4), strides=(2,2), padding='same',
                                       activation=keras.layers.LeakyReLU(alpha=0.2))(gen)
    gen = keras.layers.BatchNormalization()(gen)
    # output
    out_layer = keras.layers.Conv2D(1, (7,7), activation='tanh', padding='same')(gen)
    # define model
    model = keras.Model([in_lat, in_label], out_layer)
    return model

noise_size = 50
generator = define_generator(noise_size)
generator.summary()
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_2 (InputLayer) [(None, 50)] 0 []
input_1 (InputLayer) [(None, 1)] 0 []
dense_1 (Dense) (None, 6272) 319872 ['input_2[0][0]']
embedding (Embedding) (None, 1, 50) 500 ['input_1[0][0]']
leaky_re_lu (LeakyReLU) (None, 6272) 0 ['dense_1[0][0]']
dense (Dense) (None, 1, 49) 2499 ['embedding[0][0]']
reshape_1 (Reshape) (None, 7, 7, 128) 0 ['leaky_re_lu[0][0]']
reshape (Reshape) (None, 7, 7, 1) 0 ['dense[0][0]']
concatenate (Concatenate) (None, 7, 7, 129) 0 ['reshape_1[0][0]',
'reshape[0][0]']
conv2d_transpose (Conv2DTransp (None, 14, 14, 128) 264320 ['concatenate[0][0]']
ose)
batch_normalization (BatchNorm (None, 14, 14, 128) 512 ['conv2d_transpose[0][0]']
alization)
conv2d_transpose_1 (Conv2DTran (None, 28, 28, 128) 262272 ['batch_normalization[0][0]']
spose)
batch_normalization_1 (BatchNo (None, 28, 28, 128) 512 ['conv2d_transpose_1[0][0]']
rmalization)
conv2d (Conv2D) (None, 28, 28, 1) 6273 ['batch_normalization_1[0][0]']
==================================================================================================
Total params: 856,760
Trainable params: 856,248
Non-trainable params: 512
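As a quick sanity check (a hypothetical step, not part of the original notebook), you can run the untrained generator on a batch of noise and confirm the output shape:

# feed 4 random latent vectors and 4 random class labels to the generator
noise = np.random.randn(4, noise_size)
labels = np.random.randint(0, 10, 4)
fake_images = generator.predict([noise, labels])
print(fake_images.shape)  # (4, 28, 28, 1)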
Now that we have seen how the generator of our GAN has been created, let’s see the discriminator.
The Discriminator of our conditional GAN.
The Discriminator will be responsible for deciding whether an image comes from the original Dataset or whether the Generator has created it. In this case, since it is a conditional GAN, apart from the image it also receives the conditional information that indicates to which class the image belongs.
As with the Generator, we have two separate branches: one receives the conditional information, and the other receives the image.
Let’s see the code for creating the two branches and their merge:
in_label = keras.layers.Input(shape=(1,))
li = keras.layers.Embedding(n_classes, 50)(in_label)
n_nodes = in_shape[0] * in_shape[1]
li = keras.layers.Dense(n_nodes)(li)
li = keras.layers.Reshape((in_shape[0], in_shape[1], 1))(li)
in_image = keras.layers.Input(shape=in_shape)
merge = keras.layers.Concatenate()([in_image, li])
The first two lines define the Input and pass it through an Embedding, just like with the generator.
Then we process the output with a Dense layer that returns a vector of width × height elements, containing the processed information.
With the Reshape layer, we give the vector the same spatial shape as the input image.
Finally, we combine the two branches with the Concatenate layer, which returns the image information plus a channel for the conditional information.
To finish the Discriminator, we need to carry out the downsampling process, as in any other GAN, to go from the 28×28 image to a binary decision indicating whether or not the image was created by the Generator.
Here is the complete code of the discriminator:
def define_discriminator(in_shape=(28, 28, 1), n_classes=10):
    in_label = keras.layers.Input(shape=(1,))
    li = keras.layers.Embedding(n_classes, 50)(in_label)
    n_nodes = in_shape[0] * in_shape[1]
    li = keras.layers.Dense(n_nodes)(li)
    li = keras.layers.Reshape((in_shape[0], in_shape[1], 1))(li)
    in_image = keras.layers.Input(shape=in_shape)
    merge = keras.layers.Concatenate()([in_image, li])
    # downsample
    fe = keras.layers.Conv2D(128, (3, 3), strides=(2, 2), padding='same',
                             activation=keras.layers.LeakyReLU(alpha=0.2))(merge)
    fe = keras.layers.Dropout(0.4)(fe)
    # downsample
    fe = keras.layers.Conv2D(128, (3, 3), strides=(2, 2), padding='same',
                             activation=keras.layers.LeakyReLU(alpha=0.2))(fe)
    fe = keras.layers.Dropout(0.4)(fe)
    fe = keras.layers.Flatten()(fe)
    out_layer = keras.layers.Dense(1, activation='sigmoid')(fe)
    model = keras.Model([in_image, in_label], out_layer)
    opt = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
    return model
discriminator = define_discriminator()
discriminator.summary()
Model: "model_1"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_3 (InputLayer) [(None, 1)] 0 []
embedding_1 (Embedding) (None, 1, 50) 500 ['input_3[0][0]']
dense_2 (Dense) (None, 1, 784) 39984 ['embedding_1[0][0]']
input_4 (InputLayer) [(None, 28, 28, 1)] 0 []
reshape_2 (Reshape) (None, 28, 28, 1) 0 ['dense_2[0][0]']
concatenate_1 (Concatenate) (None, 28, 28, 2) 0 ['input_4[0][0]',
'reshape_2[0][0]']
conv2d_1 (Conv2D) (None, 14, 14, 128) 2432 ['concatenate_1[0][0]']
dropout (Dropout) (None, 14, 14, 128) 0 ['conv2d_1[0][0]']
conv2d_2 (Conv2D) (None, 7, 7, 128) 147584 ['dropout[0][0]']
dropout_1 (Dropout) (None, 7, 7, 128) 0 ['conv2d_2[0][0]']
flatten (Flatten) (None, 6272) 0 ['dropout_1[0][0]']
dense_3 (Dense) (None, 1) 6273 ['flatten[0][0]']
==================================================================================================
Total params: 196,773
Trainable params: 196,773
Non-trainable params: 0
__________________________________________________________________________________________________
With this, we already have the two models that constitute our conditional GAN; now we need to join them.
The structure of the GAN.
In order to combine the models, I will use a function that receives the Generator and the Discriminator and returns the complete model, already assembled and compiled.
# define the Conditional GAN
def define_gan(generator, discriminator):
    # make discriminator non-trainable
    discriminator.trainable = False
    # get noise and label from generator
    gen_noise, gen_label = generator.input
    # get output from generator
    gen_output = generator.output
    # connect image and label input from generator as inputs to discriminator
    gan_output = discriminator([gen_output, gen_label])
    # define the GAN model
    model = keras.Model([gen_noise, gen_label], gan_output)
    # compile model
    opt = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt)
    return model
The most important thing about this code is that we are using the image and conditional information created by the generator as inputs to the discriminator.
The code is pretty self-explanatory. The discriminator layers are marked as non-trainable. We retrieve the noise and label inputs of the Generator, as well as its output. Finally, we create gan_output by calling the discriminator with the output of the generator and the conditional label.
We create the model using the Keras Model class. As the first parameter, it receives the inputs, and as the second, the output.
We finish by compiling the model.
Even though the structure seems very complex, it is actually a sum of the two models.
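With the two models created earlier, assembling the complete GAN used later by the training loop is a single call:

# assemble the complete conditional GAN from the models defined above
GAN = define_gan(generator, discriminator)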
GAN training.
With the models already created, now it’s time to train the GAN. The process is simple. We must decide how many epochs we want to train, and each epoch must use all the data from the dataset.
The Dataset is usually divided into batches to perform the training with several images at the same time. Within each epoch, we will have as many steps as elements in the Dataset divided by the size of the batch, so that at the end of the steps we will have used all the available data. For example, with the 60,000 MNIST training images and a batch size of 128, each epoch consists of int(60000 / 128) = 468 steps.
Each step must contain the following actions:
- Train the Discriminator. This is done in two blocks:
— We pass real images to the discriminator, with their corresponding labels.
— We pass fake images, created with the Generator, to the discriminator, with their corresponding labels.
- Train the Generator:
— We pass noise marked with true-image labels to the GAN. Within the GAN, this noise and these labels are used to generate images that are passed to the discriminator, and the generator modifies its weights to reduce the loss. Within the GAN, the discriminator has its layers marked as non-trainable, so only the generator weights can be modified.
To perform these actions, I am going to create several auxiliary functions, so that the main training function will be easy to read and maintain.
Support functions.
def load_dataset():
    # download the training images
    (X_train, y_train), (_, _) = keras.datasets.mnist.load_data()
    # normalize pixel values
    X_train = X_train.astype(np.float32) / 255
    # reshape and rescale
    X_train = X_train.reshape(-1, 28, 28, 1) * 2. - 1.
    return [X_train, y_train]
This function loads the Dataset from Keras and normalizes the pixel values of the images so that they fall between -1 and 1.
The function returns both the image and its class.
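A quick hypothetical check of that range (not part of the original code):

# verify that the pixels were rescaled to [-1, 1]
X_check, _ = load_dataset()
print(X_check.min(), X_check.max())  # -1.0 1.0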
def get_dataset_samples(dataset, n_samples):
    images, labels = dataset
    ix = np.random.randint(0, images.shape[0], n_samples)
    X, labels = images[ix], labels[ix]
    y = np.ones((n_samples, 1))
    return [X, labels], y
The function receives a Dataset and returns the number of elements indicated in n_samples. The elements are selected using random indices, which gives a certain randomness to the data.
Apart from returning the images and the category to which they belong, it also returns a label indicating that the images are real.
def generate_noise(noise_size, n_samples, n_classes=10):
    # generate noise
    x_input = np.random.randn(noise_size * n_samples)
    # shape to adjust to batch size
    z_input = x_input.reshape(n_samples, noise_size)
    # generate labels
    labels = np.random.randint(0, n_classes, n_samples)
    return [z_input, labels]
This function is responsible for creating the noise that the Generator receives as input. It returns both the noise and some random class indicators that show which class the image will belong to.
def generate_fake_samples(generator, latent_dim, n_samples):
    # get the noise calling the function
    z_input, labels_input = generate_noise(latent_dim, n_samples)
    images = generator.predict([z_input, labels_input])
    # create class labels
    y = np.zeros((n_samples, 1))
    return [images, labels_input], y
The above function creates images using the Generator. To achieve this, it first calls generate_noise to get the input for the generator, and then calls the generator to obtain the images.
It returns the images accompanied by their class indicators, plus a list of labels indicating that they are generated images.
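The training function below also calls a plot_results helper that is not defined in this article. A minimal sketch, assuming matplotlib and the [-1, 1] pixel range used above, could be:

import matplotlib.pyplot as plt

def plot_results(images, n_cols=None):
    # arrange the images in a grid of n_cols columns
    n_cols = n_cols or len(images)
    n_rows = (len(images) - 1) // n_cols + 1
    # rescale pixels from [-1, 1] back to [0, 1] for display
    images = (images + 1) / 2
    plt.figure(figsize=(n_cols, n_rows))
    for i, image in enumerate(images):
        plt.subplot(n_rows, n_cols, i + 1)
        plt.imshow(image.reshape(28, 28), cmap='binary')
        plt.axis('off')
    plt.show()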
The training function.
We are going to use the helper functions to build the training function. This has two main parts: first the Discriminator is trained and then the Generator.
The discriminator is trained in two batches. In the first one, we pass real images of the Dataset, retrieved with the get_dataset_samples function. In the second block, we pass fake images made with the Generator.
The reason for using two training batches instead of one is that it has been shown to be more efficient. It is one of the recommendations of the GAN Hacks.
#TRAIN THE DISCRIMINATOR
# get randomly selected 'real' samples
[X_real, labels_real], y_real = get_dataset_samples(dataset, half_batch)
# update discriminator model weights
d_loss1, _ = discriminator.train_on_batch([X_real, labels_real], y_real)
# generate 'fake' examples
[X_fake, labels], y_fake = generate_fake_samples(generator, noise_size, half_batch)
# update discriminator model weights
d_loss2, _ = discriminator.train_on_batch([X_fake, labels], y_fake)
The block that trains the Generator uses the generate_noise function to obtain the necessary input data. The labels say that it is real data, although it is actually generated data, and everything is passed to the train_on_batch function of the complete GAN model.
#TRAIN THE GENERATOR
# prepare points in latent space as input for the generator
[z_input, labels_input] = generate_noise(noise_size, n_batch)
# create inverted labels for the fake samples
y_gan = np.ones((n_batch, 1))
# update the generator via the discriminator's error
g_loss = GAN.train_on_batch([z_input, labels_input], y_gan)
All the code together creates the following function:
def train_gan(generator, discriminator, GAN, dataset, noise_size=100, n_epochs=30, n_batch=512):
    steps = int(dataset[0].shape[0] / n_batch)
    half_batch = int(n_batch / 2)
    # manually enumerate epochs
    for e in range(n_epochs):
        # enumerate batches over the training set
        for s in range(steps):
            # TRAIN THE DISCRIMINATOR
            # get randomly selected 'real' samples
            [X_real, labels_real], y_real = get_dataset_samples(dataset, half_batch)
            # update discriminator model weights
            d_loss1, _ = discriminator.train_on_batch([X_real, labels_real], y_real)
            # generate 'fake' examples
            [X_fake, labels], y_fake = generate_fake_samples(generator, noise_size, half_batch)
            # update discriminator model weights
            d_loss2, _ = discriminator.train_on_batch([X_fake, labels], y_fake)
            # TRAIN THE GENERATOR
            # prepare points in latent space as input for the generator
            [z_input, labels_input] = generate_noise(noise_size, n_batch)
            # create inverted labels for the fake samples
            y_gan = np.ones((n_batch, 1))
            # update the generator via the discriminator's error
            g_loss = GAN.train_on_batch([z_input, labels_input], y_gan)
            # summarize loss on this batch
            print('>%d, %d/%d, d1=%.3f, d2=%.3f g=%.3f' %
                  (e+1, s+1, steps, d_loss1, d_loss2, g_loss))
        plot_results(X_fake, 8)
    # save the generator model
    generator.save('cgan_generator.h5')
Apart from training the Discriminator and the Generator, some generated images are shown in each epoch, and the model is saved at the end of the training. Only the Generator is saved, which is the one we are interested in using at inference time.
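The call below also needs the dataset variable, which, assuming the helper defined above, is obtained with:

# load the MNIST images and labels with the helper defined earlier
dataset = load_dataset()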
You just have to make the following call to perform the training:
train_gan(generator, discriminator, GAN, dataset, noise_size, n_epochs=30, n_batch=128)
Using the model to generate a specific class of numbers.
Now that we have a conditional GAN trained, we can generate images of any class we want. In our case, they are numbers that belong to the MNIST Dataset, that is, numbers from 0 to 9.
model = keras.models.load_model('cgan_generator.h5')
latent_points, labels = generate_noise(noise_size, 20)
labels = np.ones(20) * 5
X = model.predict([latent_points, labels])
plot_results(X, 10)
labels = np.ones(20) * 8
X = model.predict([latent_points, labels])
plot_results(X, 10)
As you can see, it’s as easy as loading the saved model and calling its predict function, passing it the noise and the class label of the images we want to generate.
In this case, we are generating two different numbers: fives and eights.
What uses can a Conditional GAN have?
If you remember the previous article, where I explained how to use TPUs to increase the performance of our GAN, we generated faces of famous people. With a conditional GAN, we could have indicated specific characteristics, such as hair color, eye color, or sex… A conditional GAN can be conditioned on several class indicators.
It can be used to balance a dataset, given the possibility of generating data of the required type.
It can be used in image-creation tools, where it is the user who indicates some attributes of the image.
It can generate custom images from a text input or from a selection of attributes.
And many more…
What have we learned?
We have taken a giant step in the creation of GANs. Although we have only used one class indicator in our conditional GAN, we have established the foundation for creating much more complex conditional GANs.
We have used multiple inputs to train a model.
This is a long article, but we have made a lot of progress since the first article about GANs.
I hope you liked it!
I write about TensorFlow and machine learning regularly. Consider following me on Medium to get updates about new articles. And, of course, you are welcome to connect with me on LinkedIn.
More articles in the GAN series:
GANs From Zero to Hero
If you like TensorFlow and want to know some interesting techniques, check my series: TensorFlow Beyond The Basics.
TensorFlow beyond the basics
Published via Towards AI