From Concept to Creation: U-Net for Flawless Inpainting
Author(s): Dawid Kopeć
Originally published on Towards AI.
Introduction
Image inpainting is a powerful computer vision technique for restoring missing or damaged parts of images. This article takes a deep dive into building and implementing a U-Net architecture specifically for this task.
I will assume that you have a basic understanding of computer vision and deep learning. However, I will provide clear explanations of both image inpainting and U-Net operation for those who might be new to these concepts. Even for seasoned deep learning practitioners, my aim is to offer valuable insights through detailed explanations and potentially surprising practical considerations.
Although the U-Net approach itself is not novel, its application to image inpainting may be less widely described. This article aims to bridge that gap, offering a comprehensive guide for anyone interested in using U-Net for this exciting application.
All the code and more can be found in my project on GitHub.
From Flaws to Flawless: Understanding Inpainting
Image inpainting is a machine learning technique that is used to reconstruct missing parts of an image. It is widely used in fields such as historical preservation and photo retouching. Missing parts can be caused by damage, censorship, or other factors that affect the integrity of the image.
There are many different techniques for image inpainting, but they are all based on the same basic concept. The method finds and identifies the damaged area and then analyses the surrounding pixels based on that. By doing so, it is able to understand the context and structure of the image. With this knowledge, it is able to recreate the (hopefully) original appearance by generating the missing pixels.
But what exactly is the model supposed to generate? The whole image or just the missing part? There are different approaches, but the best answer is a bit of both. The model learns to generate a whole new image, but since in most cases we know where the damaged part is, we take only that region of the generated image and overlay it on top of the original image. This is because, by design, the model's output will look worse than the undamaged parts of the original.
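As a rough illustration of that overlay step (a minimal sketch; the tensor names are illustrative and not taken from the project code), assuming a binary mask that is 1 inside the damaged region and 0 elsewhere:

import torch

def composite(corrupted: torch.Tensor, predicted: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Keep the original pixels where the image is intact and use the model's
    # prediction only inside the damaged region.
    return corrupted * (1 - mask) + predicted * mask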
Currently, there are many great models for this task. Undoubtedly, some of the best are diffusion models: models that create new data by gradually removing noise from a corrupted version of real data. However, they have one big drawback: computational complexity. They take ages to train and, worse, inference is not much faster. Therefore, I want to introduce a slightly simpler and less complex architecture that can handle this task.
Beyond Segmentation: U-Net's Role in Flawless Inpainting
U-Net is a convolutional neural network architecture known for its U-shaped structure. It was originally introduced for biomedical image segmentation. Since its inception, U-Net has demonstrated significant potential and has been widely adopted for various other segmentation tasks. It is now one of the most common and influential models in the field of image segmentation. Beyond its primary use in image segmentation, U-Net has also been effectively applied to several other tasks, including image denoising, object detection, and even natural language processing (NLP).
What Makes U-Net Special for Image Inpainting?
U-Net's power lies in its unique U-shaped architecture, which resembles an encoder-decoder structure. Imagine the encoder as an analyst examining an image. It uses convolutional layers to identify patterns and features, while pooling layers summarise this information, reducing image size for a more holistic view.
The decoder, on the other hand, acts like a builder. It uses upsampling layers to increase the resolution of the analysed features and convolutional layers to refine them. This process allows for the gradual restoration of the image, making U-Net particularly well suited for inpainting tasks where missing elements need to be filled in.
One key advantage of U-Net over simpler autoencoders is the use of skip connections between the encoder and decoder layers. These connections act as information bridges, allowing the decoder to access the detailed features captured by the encoder. This not only helps maintain colour consistency and image properties, but also enables faster and more accurate image restoration, even after a relatively small number of training iterations.
Inside the Code: Implementing U-Net for Perfect Inpainting
In this section, I am going to introduce my U-Net implementation for image inpainting, built with the PyTorch and PyTorch Lightning libraries. I will focus on the implementation of the U-Net blocks, the skip connections, the loss function, and the training process.
For training and evaluation, I used the Nature image inpainting dataset from Kaggle. This dataset offers a diverse collection of over 100,000 natural scene images (City, Mountain, Fire, and Lake) at a resolution of 64×64, which makes it computationally efficient. The size and diversity of this dataset provide ideal conditions for the model to achieve good generalisation and reconstruction quality during inpainting. It is worth mentioning that the data were carefully divided into training, validation, and test sets to ensure a solid model evaluation.
Full details of the image preprocessing steps can be found in the GitHub project repository.
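As a rough sketch of such a split, PyTorch's random_split can be used; the dataset object, fractions, and seed below are assumptions rather than the project's exact values:

import torch
from torch.utils.data import Dataset, random_split

def split_dataset(dataset: Dataset, train_frac: float = 0.8, val_frac: float = 0.1, seed: int = 42):
    # Reproducible train/validation/test split.
    n = len(dataset)
    n_train = int(train_frac * n)
    n_val = int(val_frac * n)
    n_test = n - n_train - n_val
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, [n_train, n_val, n_test], generator=generator)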
Building Blocks: Encoder and Decoder
When it comes to implementation, let's take another look at the U-Net architecture. We can see the encoder on the left, the decoder on the right, and the so-called bottleneck in the middle. To simplify things, we can first focus on the encoder and decoder separately, as two classes. However, remember that the input of each decoder block must have the same resolution as the output of the corresponding encoder block to form a skip connection. While the decoder may have a different number of blocks, a symmetric architecture is commonly used for simplicity, and that is the implementation described here.
The U-Net encoder operates through a series of reusable blocks. Each encoder block consists of a few (usually two) pairs of a convolution layer and an activation function (e.g. ReLU), followed by a pooling layer. This block can therefore be implemented as a separate class; let's call it EncoderStep. What is more, these blocks are stacked one after the other to form the encoder. In this way, the number of blocks used in the U-Net model becomes a hyperparameter, which can then be adapted to the image inpainting task.
import torch
import torch.nn as nn


class EncoderStep(nn.Module):
"""
Encoder step in U-Net.
"""
def __init__(self, in_channels: int, out_channels: int) -> None:
"""
Initialize the encoder step.
Parameters
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
"""
super().__init__()
self.block = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.ReLU(),
)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
Like the encoder, the decoder also consists of blocks. These blocks mirror the structure of the encoder, with a few (usually two) pairs of a convolution layer followed by an activation function. However, instead of pooling, we use a transposed convolution layer (upsampling) to increase the resolution and gradually recover image details. Similarly to the encoder, the blocks are stacked on top of each other to form the decoder. Since we want the decoder and encoder to be symmetrical (to have the same number of blocks), the same number-of-blocks hyperparameter can be reused here. In this way, we create a second class, which we will call DecoderStep.
class DecoderStep(nn.Module):
"""
Decoder step in U-Net.
"""
def __init__(self, in_channels: int, out_channels: int) -> None:
"""
Initialize the decoder step.
Parameters
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
"""
super().__init__()
self.upconv = nn.ConvTranspose2d(
in_channels, out_channels, kernel_size=2, stride=2
)
self.block = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.ReLU(),
)
The Secret Weapon: Skip Connections
There is still one little thing we have forgotten: the skip connections. We can modify the EncoderStep class to return not just the pooled output, but also the feature map right before pooling. This becomes our "skip" connection. In the decoder's forward pass (inside the DecoderStep class), we can then accept not only the upsampled feature map, but also the corresponding "skip" connection from the encoder. The two are concatenated before being fed into the convolutional layers of the decoder block.
class EncoderStep(nn.Module):
"""
Encoder step in U-Net.
"""
def __init__(self, in_channels: int, out_channels: int) -> None:
"""
Initialize the encoder step.
Parameters
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
"""
super().__init__()
self.block = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.ReLU(),
)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass of the encoder step.
Parameters
----------
x : torch.Tensor
Input tensor.
Returns
-------
tuple[torch.Tensor, torch.Tensor]
The pooled output tensor and the pre-pooling feature map used as the skip connection.
"""
x = self.block(x)
x_pooled = self.pool(x)
return x_pooled, x
class DecoderStep(nn.Module):
"""
Decoder step in U-Net.
"""
def __init__(self, in_channels: int, out_channels: int) -> None:
"""
Initialize the decoder step.
Parameters
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
"""
super().__init__()
self.upconv = nn.ConvTranspose2d(
in_channels, out_channels, kernel_size=2, stride=2
)
self.block = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.ReLU(),
)
def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the decoder step.
Parameters
----------
x : torch.Tensor
Input tensor.
skip : torch.Tensor
Skip connection tensor.
Returns
-------
torch.Tensor
Output tensor.
"""
x = self.upconv(x)
x = torch.cat([x, skip], dim=1)
x = self.block(x)
return x
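As a quick sanity check (a sketch, not part of the project code), the two blocks can be wired together on a dummy tensor to confirm that the resolutions line up for the skip connection:

import torch

enc = EncoderStep(in_channels=3, out_channels=8)
dec = DecoderStep(in_channels=16, out_channels=8)

x = torch.randn(1, 3, 64, 64)               # dummy RGB image
pooled, skip = enc(x)                        # pooled: (1, 8, 32, 32), skip: (1, 8, 64, 64)
deep_features = torch.randn(1, 16, 32, 32)   # stand-in for a deeper feature map
restored = dec(deep_features, skip)          # (1, 8, 64, 64)
print(pooled.shape, skip.shape, restored.shape)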
Putting it All Together: The U-Net Model
Finally, we can create the complete U-Net model by combining the encoder, the decoder, a bottleneck (an encoder block without pooling, or equivalently a decoder block without the transposed convolution) and a so-called output layer at the end (a simple 1×1 convolution layer that makes sure the output has the right number of channels). Both the encoder and decoder blocks can be used repeatedly, and the number of blocks and the starting channels can be adjusted based on the complexity of your inpainting task.
class UNet(nn.Module):
"""
U-Net model implementation.
"""
def __init__(
self, input_channels: int = 3, num_blocks: int = 3, start_channels: int = 8
) -> None:
"""
Initialize the U-Net model.
Parameters
----------
input_channels : int, optional
Number of input channels, by default 3
num_blocks : int, optional
Number of encoder-decoder blocks, by default 3
start_channels : int, optional
Number of channels in the first encoder block, by default 8
"""
super().__init__()
self.encoders = nn.ModuleList()
self.decoders = nn.ModuleList()
self.encoders.append(EncoderStep(input_channels, start_channels))
channels = start_channels
for _ in range(1, num_blocks):
self.encoders.append(EncoderStep(channels, channels * 2))
channels *= 2
self.bottleneck = nn.Sequential(
nn.Conv2d(channels, channels * 2, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(channels * 2, channels * 2, kernel_size=3, padding=1),
nn.ReLU(),
)
channels *= 2
for _ in range(num_blocks):
self.decoders.append(DecoderStep(channels, channels // 2))
channels //= 2
self.output = nn.Conv2d(channels, input_channels, kernel_size=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the U-Net.
Parameters
----------
x : torch.Tensor
Input tensor.
Returns
-------
torch.Tensor
Output tensor.
"""
skips = []
for encoder in self.encoders:
x, skip = encoder(x)
skips.append(skip)
x = self.bottleneck(x)
for decoder, skip in zip(self.decoders, reversed(skips)):
x = decoder(x, skip)
x = self.output(x)
return x
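A short, optional sanity check (assuming the default constructor arguments) confirms that the model maps an input batch back to an output of the same shape:

import torch

model = UNet(input_channels=3, num_blocks=3, start_channels=8)
x = torch.randn(4, 3, 64, 64)   # a batch of four 64x64 RGB images
out = model(x)
print(out.shape)                # torch.Size([4, 3, 64, 64])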
Training the Inpainting Expert: Loss Function and the Learning Journey
Choosing the Right Weapon: Loss Functions for Image Inpainting
The success of any machine learning model rests on a well-defined loss function. There are many suitable loss functions we could use, but the one I used in my project is Mean Squared Error (MSE), for its simplicity and efficiency. It calculates the squared pixel-wise difference between the predicted image and the original image. While I used the entire image to calculate the loss, it can also be restricted to the corrupted region only.
Note that MSE is not always the best option: it can be sensitive to outliers, which is why it is good practice to consider the nature of your data. Alternatives such as L1 loss, which is less sensitive to outliers, or perceptual loss, which takes into account high-level image features, might be better choices in some cases.
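For reference, both MSE and L1 are available out of the box in PyTorch; a minimal comparison sketch (with random tensors standing in for real images) looks like this:

import torch
import torch.nn as nn

mse = nn.MSELoss()   # used in this project: average squared pixel difference
l1 = nn.L1Loss()     # average absolute pixel difference, less sensitive to outliers

y_pred = torch.rand(4, 3, 64, 64)
y_true = torch.rand(4, 3, 64, 64)
print(mse(y_pred, y_true).item(), l1(y_pred, y_true).item())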
Training: Guiding the Model Toward Perfection
During the training process, we iteratively feed batches of corrupted images (x) through the U-Net model. The model generates an inpainted image based on the input, which is then evaluated by the loss function. The loss function calculates the difference between the predicted image and the original image (y), guiding the optimisation process.
I implemented the training process by creating a custom UnetTrainer class using PyTorch Lightning. This class manages the training workflow, including both the training step and the validation step. If you have not used PyTorch Lightning before, I highly recommend exploring it, as it organises the training code and removes a lot of boilerplate. In this article, however, I will not discuss PyTorch Lightning in detail.
import pytorch_lightning as pl


class UnetTrainer(pl.LightningModule):
"""
A PyTorch Lightning Module for training a U-Net model.
This class handles the training, validation, and optimization of a U-Net model.
"""
...
def training_step(
self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
) -> torch.Tensor:
"""
Perform a training step.
Parameters
----------
batch : tuple[torch.Tensor, torch.Tensor]
The input and target tensors for the batch.
batch_idx : int
The index of the batch.
Returns
-------
torch.Tensor
The training loss for the step.
"""
x, y = batch
x, y = x.to(self.device), y.to(self.device)
y_pred = self(x)
loss = self.loss(y_pred, y)
self.log(
"train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
)
return loss
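The ... above elides the rest of the class, which lives in the repository. Purely for orientation, a hypothetical sketch of the pieces the training step relies on (the wrapped model, self.loss, and an optimiser) might look like the following; the names and hyperparameters are illustrative, not the project's actual values:

import pytorch_lightning as pl
import torch
import torch.nn as nn

class UnetTrainer(pl.LightningModule):
    """Hypothetical sketch of the elided parts; see the repository for the real code."""

    def __init__(self, model: nn.Module, lr: float = 1e-3) -> None:
        super().__init__()
        self.model = model
        self.loss = nn.MSELoss()  # the article uses MSE as the training loss
        self.lr = lr              # illustrative learning rate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.Adam(self.parameters(), lr=self.lr)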
Validation: Ensuring Generalisation Ability
While the loss function provides valuable feedback during training, its raw value does not always provide a clear picture of the model's generalisation ability. That is why I used a validation step to plot the predicted image against the original image, providing a visual reference to evaluate the model performance during the learning process. Including the corrupted image in the plot can offer more complete information, though I reserved this step for the evaluation stage.
import matplotlib.pyplot as plt


class UnetTrainer(pl.LightningModule):
"""
A PyTorch Lightning Module for training a U-Net model.
This class handles the training, validation, and optimization of a U-Net model.
"""
...
def validation_step(
self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
) -> torch.Tensor:
"""
Perform a validation step.
Parameters
----------
batch : tuple[torch.Tensor, torch.Tensor]
The input and target tensors for the batch.
batch_idx : int
The index of the batch.
Returns
-------
torch.Tensor
The validation loss for the step.
"""
x, y = batch
x, y = x.to(self.device), y.to(self.device)
y_pred = self(x)
loss = self.loss(y_pred, y)
self.log("val_loss", loss)
print(f"Validation loss: {loss}")
y_pred = y_pred[0].detach().cpu().numpy().transpose(1, 2, 0)
y_pred = (y_pred + 1) / 2 # Normalize to [0, 1]
y = y[0].detach().cpu().numpy().transpose(1, 2, 0)
y = (y + 1) / 2 # Normalize to [0, 1]
plt.style.use("default")
fig, axs = plt.subplots(1, 2, figsize=(8, 4))
axs[0].imshow(y_pred)
axs[0].set_title("Predicted")
axs[1].imshow(y)
axs[1].set_title("Ground Truth")
plt.suptitle(f"Epoch {self.current_epoch}")
plt.show()
return loss
The Devil is in the Details
Now that we have a solid understanding of U-Netβs core architecture, letβs go into some of the implementation details that were previously omitted to avoid complicating the basic concept.
Understanding Feature Maps and Starting Channels
One crucial aspect to consider is the starting channels parameter, but please do not confuse it with the input channels, which is the number of channels of the image (in this case we need 3 channels because the images are RGB). Starting channels represent the number of feature maps produced by the first encoder block.
A common practice is to maintain the same number of feature maps throughout all layers within a single block, to double the number of feature maps between consecutive encoder blocks, and to halve them symmetrically in the decoder. This approach allows the network to capture increasingly complex features while maintaining a good balance between depth and width. Since the number of blocks is a hyperparameter in this implementation, you only need to define the starting channels; the rest is derived from this rule. For example, with start_channels=16 and three blocks, the encoder blocks produce 16, 32, and 64 feature maps, and the bottleneck 128.
While larger models can achieve better results, they also come with increased time and computational complexity. In my case the images were small; with larger images you may need a bigger network, but I still encourage you to try smaller architectures first. I found that 3-4 blocks and about 16 starting channels were sufficient for my 64×64 images. Sometimes it is better to train a smaller model for more epochs than a larger model in the same amount of time. In the end, I encourage you to experiment, and perhaps even to use hyperparameter optimisers such as Optuna, which I recommend and also used in this project.
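For completeness, here is a hedged sketch of what an Optuna search over these two hyperparameters could look like; train_and_validate is an assumed helper that trains the model briefly and returns the validation loss, and the search ranges are illustrative:

import optuna

def objective(trial: optuna.Trial) -> float:
    num_blocks = trial.suggest_int("num_blocks", 2, 4)
    start_channels = trial.suggest_categorical("start_channels", [8, 16, 32])
    model = UNet(num_blocks=num_blocks, start_channels=start_channels)
    return train_and_validate(model)  # assumed helper returning the validation loss

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=20)
print(study.best_params)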
Kernel Size, Padding, and Stride: Balancing Efficiency and Feature Extraction
In terms of "how to set" the kernel size in convolutional and max pooling layers, I have always heard that it is intuitive and that, with time and more implemented models, a person develops a feeling for it. I have to agree, and it is hard for me to say explicitly why a given value is the most appropriate, because there is no universally best value; it is all part of the experiments. Smaller kernels (e.g. 3×3) are efficient at capturing local features but might miss larger patterns. Conversely, larger kernels can capture a wider context but require more computational resources. Max pooling layers, meanwhile, often use 2×2 kernels, effectively reducing the feature map's spatial dimensions while retaining the most significant features; however, this does not mean that other values cannot work better.
Padding is easier to explain: setting it to 1 (with a 3×3 kernel) ensures that the spatial dimensions of the feature map remain the same after convolution. A stride of 2 in the max pooling layers downsamples the feature map by half. Ultimately, depending on the specifics of the target task, each of these parameters can be adjusted to get the best results; just remember that every resolution change made in the encoder must be mirrored in the decoder.
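A short check (a sketch, not part of the project code) illustrates both effects:

import torch
import torch.nn as nn

x = torch.randn(1, 8, 64, 64)
conv = nn.Conv2d(8, 8, kernel_size=3, padding=1)
pool = nn.MaxPool2d(kernel_size=2, stride=2)
print(conv(x).shape)        # torch.Size([1, 8, 64, 64]) - padding=1 preserves the size
print(pool(conv(x)).shape)  # torch.Size([1, 8, 32, 32]) - stride=2 pooling halves it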
Training, Evaluation and Results
Now that the U-Net model has been built, it is time to train it using the training and validation data. Using PyTorch Lightning's built-in Trainer class, I trained the model for 30 epochs. The training process took approximately 20 to 30 minutes on Google Colab, making it a great option for those with limited resources. Instructions on how to move your project to this platform are described in my repository; be sure to check it out on GitHub.
# Example of how to run the training (train_loader and val_loader come from the data preparation step):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(start_channels=16).to(device)
UNet_trainer = UnetTrainer(model)
trainer = pl.Trainer(
accelerator=device.type,
max_epochs=30,
check_val_every_n_epoch=5,
limit_val_batches=1,
)
trainer.fit(UNet_trainer, train_loader, val_loader)
After that, we need to evaluate the model on the test data to verify its performance. To do that, we will use an evaluation function which shows five randomly selected images in their corrupted, original, and predicted versions, together with four metrics commonly used for the image inpainting task (a sketch of how to compute them follows the list):
- MSE (Mean Squared Error) – calculates the average squared difference between pixels in the original and inpainted images. The closer to 0, the better the result.
- NRMSE (Normalised Root Mean Squared Error) – a normalised version of the root of MSE, which makes errors easier to interpret and compare across images. The closer to 0, the better the result.
- PSNR (Peak Signal to Noise Ratio) – measures the ratio between the original image's signal (the desired information) and the noise (errors) introduced during inpainting. The higher the better; above 30 dB is generally considered good, and above 40 dB very good.
- SSIM (Structural Similarity Index Measure) – measures the structural similarity between the original and inpainted images, considering not only pixel brightness but also local structure and texture. The closer to 1 the better; values above 0.9 are typically considered very good.
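All four metrics are available in scikit-image; a sketch of how they can be computed, assuming a recent scikit-image version and images given as H×W×3 float arrays scaled to [0, 1]:

import numpy as np
from skimage.metrics import (
    mean_squared_error,
    normalized_root_mse,
    peak_signal_noise_ratio,
    structural_similarity,
)

def inpainting_metrics(original: np.ndarray, inpainted: np.ndarray) -> dict:
    # Compare the ground-truth image with the inpainted result.
    return {
        "MSE": mean_squared_error(original, inpainted),
        "NRMSE": normalized_root_mse(original, inpainted),
        "PSNR": peak_signal_noise_ratio(original, inpainted, data_range=1.0),
        "SSIM": structural_similarity(original, inpainted, channel_axis=-1, data_range=1.0),
    }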
As can be seen from the metrics (which, for the record, look good), there are flawless generations, but I am not going to show only the best ones; there are also some challenging cases where the inpainting is not perfect. These "hopeless cases" can occur for various reasons, such as very complex image regions or limited training data for certain types of scenes.
There is still room for progress…
Although the model is complete and its performance is satisfactory, there is still plenty of room for improvement. Here are a few ideas that could enhance the results even further.
Activation Function
While I have discussed the network's structure and the number of blocks and channels, there are additional aspects to consider within the blocks themselves. One area of potential improvement is the activation function. The model currently uses ReLU, but exploring functions like LeakyReLU might be beneficial. LeakyReLU addresses the "dying ReLU" problem, where activations can become zero and never recover, by allowing a small positive gradient for negative inputs.
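Swapping the activation is a one-line change inside the blocks; a sketch with illustrative channel counts:

import torch.nn as nn

block = nn.Sequential(
    nn.Conv2d(8, 16, kernel_size=3, padding=1),
    nn.LeakyReLU(negative_slope=0.01),  # small positive gradient for negative inputs
    nn.Conv2d(16, 16, kernel_size=3, padding=1),
    nn.LeakyReLU(negative_slope=0.01),
)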
Batch Normalization
Another idea is to incorporate batch normalization, which is currently absent. Batch normalization layers can be added within the blocks or in the bottleneck, either multiple times or just once. Their goal is to stabilise and potentially accelerate the training process.
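One possible placement, shown as a sketch with illustrative channel counts (the project itself does not use batch normalisation):

import torch.nn as nn

block = nn.Sequential(
    nn.Conv2d(8, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),  # normalise activations before the non-linearity
    nn.ReLU(),
    nn.Conv2d(16, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
)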
More Convolutional Layers
Adding more convolutional layers is another option. While this might be excessive for my problem, it could be beneficial for more complex tasks. More layers can enable the model to learn more intricate patterns and details in the data. (Be careful not to overdo it; too large a network can perform worse than a small one.)
Using Known Corruption for Improved Inpainting
Knowing the coordinates of the corrupted areas can be a significant advantage. This information can be used in the loss function, allowing the model to focus more heavily on those regions. Additionally, using it to overlay only the corrupted region onto the original photo, as described earlier, can lead to better results.
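As a hedged sketch of the first idea, a mask-weighted MSE could look like this; the mask convention (1 inside the corrupted region, 0 elsewhere) and the weight value are assumptions, not taken from the project:

import torch

def masked_mse(y_pred: torch.Tensor, y_true: torch.Tensor, mask: torch.Tensor,
               hole_weight: float = 6.0) -> torch.Tensor:
    # Emphasise the corrupted region while still penalising errors elsewhere.
    weights = 1.0 + (hole_weight - 1.0) * mask
    return (weights * (y_pred - y_true) ** 2).mean()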
Experimentation is Key!
It is important to remember that there is no one-size-fits-all approach. Each technique has its advantages and drawbacks, and some may be better suited to particular problems than others. Therefore, I strongly recommend experimenting with different techniques and approaches to achieve the best results.
Takeaways
- Image inpainting is a machine-learning technique that is used to reconstruct missing parts of an image.
- U-Net is a convolutional neural network architecture known for its U-shaped structure with an encoder-decoder architecture and skip connections.
- U-Net, originally made for segmentation, is great for other problems such as image inpainting.
- The encoder uses convolutional and pooling layers to identify patterns and features in the image.
- The decoder uses convolutional and upsampling layers to increase the resolution of the analysed features and to refine them.
- Both encoder and decoder blocks in U-Net must have matching resolutions for effective skip connections.
- Although a larger architecture can capture more complex patterns, bigger does not always mean better.
- Experimentation is the key to success.
References
[1] My personal project, https://github.com/Dawir7/Nature-inpainting
[2] Kenneth Leung, draw.io U-Net Architecture diagram, https://github.com/kennethleungty/Neural-Network-Architecture-Diagrams/blob/main/U-Net.drawio
Published via Towards AI