TransUNet — No More CNNs for Image Segmentation
Last Updated on September 27, 2022 by Editorial Team
Author(s): Leo Wang
Originally published on Towards AI.
TransUNet — Revolutionize Traditional Image Segmentation
Renovating U-Net by combining CNNs and Transformers to achieve state-of-the-art (SOTA) results on image segmentation tasks.
Nowadays, U-Net dominates image segmentation tasks, especially in medical imaging. Most of the U-Nets proposed so far use Convolutional Neural Networks (CNNs) as their underlying structure.
However, CNNs can only exploit short-range (or local) information effectively because of their small convolutional kernels, so they fail to capture long-range information sufficiently in tasks whose features have long-range relations.
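To make the locality argument concrete, here is a small sketch of standard receptive-field arithmetic (not taken from the TransUNet paper). It computes how many input pixels, along one axis, a single output position of a stack of convolutions can "see". The function name `receptive_field` and the layer configurations are illustrative assumptions.

```python
def receptive_field(layers):
    """Receptive field, in input pixels along one axis, of stacked conv layers.

    layers: list of (kernel_size, stride) pairs, ordered from input to output.
    """
    rf = 1    # one output position before any layer covers exactly one pixel
    jump = 1  # spacing, in input pixels, between neighbouring positions of the current layer
    for kernel, stride in layers:
        rf += (kernel - 1) * jump
        jump *= stride
    return rf

# Stride-1 3x3 convolutions: the receptive field grows by only 2 pixels per layer.
for depth in (1, 5, 10):
    print(depth, receptive_field([(3, 1)] * depth))
# 1 3
# 5 11
# 10 21
```

Even ten stacked 3×3 layers only cover a 21×21 window, whereas a single self-attention layer lets every position attend to every other position in the input.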
Because CNNs and Transformers can compensate for each other's shortcomings in image segmentation, Chen et al. proposed TransUNet, which they present as the first Transformer-based framework for medical image segmentation. It is also worth mentioning that the authors first tried a pure Transformer architecture for segmentation; it did not work as well as the version that introduces CNNs, because Transformers are less effective than CNNs at exploiting local features. This result supports combining the two.
Recall that U-Net (Fig. 1) consists of an encoder (the down-sampling path) and a decoder (the up-sampling path). The encoder compresses the image's features into high-level feature maps, and the decoder uses these maps, together with details passed from the encoder at each level, to produce a final mask with the same spatial dimensions as the input.
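The encoder-decoder idea can be illustrated with a toy, parameter-free NumPy sketch: max pooling stands in for the down-sampling path, nearest-neighbour repetition for the up-sampling path, and each decoder level concatenates the encoder features of the same resolution (the skip connection). Real U-Nets apply learned convolutions at every level; the helper names here are hypothetical, and only the tensor shapes are meant to be instructive.

```python
import numpy as np

def downsample(x):
    """2x2 max pooling of a (C, H, W) array; H and W are assumed even."""
    c, h, w = x.shape
    return x.reshape(c, h // 2, 2, w // 2, 2).max(axis=(2, 4))

def upsample(x):
    """2x nearest-neighbour up-sampling of a (C, H, W) array."""
    return x.repeat(2, axis=1).repeat(2, axis=2)

def toy_unet(image, depth=3):
    skips = []
    x = image
    # Encoder (down-sampling path): keep each level's features, then halve the resolution.
    for _ in range(depth):
        skips.append(x)
        x = downsample(x)
    # Decoder (up-sampling path): double the resolution and concatenate the
    # same-resolution encoder features along the channel axis.
    for skip in reversed(skips):
        x = np.concatenate([upsample(x), skip], axis=0)
    return x

image = np.random.default_rng(0).random((1, 64, 64))
out = toy_unet(image)
print(out.shape)  # (4, 64, 64): same height and width as the input
```

The channel count grows here only because of the concatenations; a real U-Net would convolve after each concatenation and finish with a layer that maps to the number of classes.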
Similarly, TransUNet comprises an encoder and a decoder that encode and decode image information to produce a segmentation. Unlike traditional U-Nets, however, TransUNet uses a hybrid CNN-Transformer architecture as its encoder, learning both high-resolution spatial information from the CNN and global context information from the Transformer.
To break down the steps:
- First, a CNN is used as a feature extractor to generate a feature map from the input, as shown in the pink box in Fig. 2.
- At each level of the feature extractor, the output feature map (which encodes intermediate-level features) is also concatenated with the decoder path at the same level, as shown by the dotted arrows in Fig. 2.
- Then, the feature map is tokenized (vectorized) by a linear projection into a 2D embedding of shape (n_patch, D), where n_patch is the number of patches and D is the embedding dimension. The embedding weights are pre-trained, and position embeddings are added so that the tokens retain the positional information of the feature map (if you don't see how yet, don't worry; it won't hinder your understanding of TransUNet much).
- The embeddings are then fed into 12 Transformer layers, which capture long-range (global) relations across the image rather than only short-range ones. Each layer, illustrated in Fig. 2 (a), consists of a multi-head self-attention (MSA) module and a multi-layer perceptron (MLP) module. MSA is the basic building block of Transformers, and the MLP is simply a small stack of fully connected layers.
- Lastly, to prepare for the up-sampling path, the output of shape (n_patch, D) is reshaped to (D, H/16, W/16), where H and W are the input height and width. H/16 and W/16 indicate that, by this point, the spatial resolution has been reduced by a factor of 16 by the preceding operations.
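The encoder steps above can be traced in a minimal NumPy sketch. It assumes the hybrid setting in which the CNN output is already at 1/16 resolution and each spatial position becomes one token (a 1×1 patch). Weights are random rather than pre-trained, LayerNorm has no learnable scale or shift, and the sizes (D = 64, 4 heads, a 256-channel stand-in feature map) are illustrative assumptions. What the sketch is meant to show is the token count, the MSA/MLP residual structure of each layer, and the final reshape to (D, H/16, W/16).

```python
import numpy as np

rng = np.random.default_rng(0)

def layer_norm(x, eps=1e-6):
    """LayerNorm over the embedding axis (no learnable scale/shift in this sketch)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def gelu(x):
    """tanh approximation of GELU."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

def msa(z, w_qkv, w_out, heads):
    """Multi-head self-attention over n tokens of width d."""
    n, d = z.shape
    q, k, v = np.split(z @ w_qkv, 3, axis=-1)                        # each (n, d)
    q, k, v = (t.reshape(n, heads, d // heads).transpose(1, 0, 2) for t in (q, k, v))
    attn = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(d // heads))    # (heads, n, n): all pairs
    return (attn @ v).transpose(1, 0, 2).reshape(n, d) @ w_out

def transformer_layer(z, heads=4, mlp_ratio=4):
    d = z.shape[1]
    # Random weights stand in for learned (pre-trained) parameters.
    w_qkv, w_out = rng.normal(0, 0.02, (d, 3 * d)), rng.normal(0, 0.02, (d, d))
    w1, w2 = rng.normal(0, 0.02, (d, mlp_ratio * d)), rng.normal(0, 0.02, (mlp_ratio * d, d))
    z = z + msa(layer_norm(z), w_qkv, w_out, heads)   # z' = MSA(LN(z)) + z
    z = z + gelu(layer_norm(z) @ w1) @ w2             # z  = MLP(LN(z')) + z'
    return z

def encode(feature_map, d=64, num_layers=12):
    """feature_map: CNN output of shape (C, H/16, W/16); returns (D, H/16, W/16)."""
    c, hp, wp = feature_map.shape
    tokens = feature_map.reshape(c, hp * wp).T        # (n_patch, C): one token per position
    w_embed = rng.normal(0, 0.02, (c, d))             # linear projection to D
    pos_embed = rng.normal(0, 0.02, (hp * wp, d))     # position embeddings
    z = tokens @ w_embed + pos_embed                  # (n_patch, D)
    for _ in range(num_layers):
        z = transformer_layer(z)
    return z.T.reshape(d, hp, wp)                     # back to a spatial map

H, W = 224, 224
cnn_features = rng.normal(size=(256, H // 16, W // 16))  # stand-in for the CNN feature map
hidden = encode(cnn_features)
print(hidden.shape)  # (64, 14, 14)
```

With a 224×224 input, n_patch = 14 × 14 = 196 tokens, and every attention map is 196 × 196, so each token can draw on information from anywhere in the image within a single layer.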
The up-sampling path, which the authors call the cascaded upsampler (CUP), is fairly straightforward and uses no fancy techniques:
- First, the output of the CNN-Transformer encoder is passed through a 3×3 convolution layer with ReLU activation, upsampled by a factor of 2, and then concatenated with the output of the third-level CNN feature extractor.
- The resulting feature maps are again passed through a 3×3 convolution with ReLU activation, upsampled, and concatenated with the output of the second-level CNN feature extractor.
- This step is repeated once more with the first-level CNN features. After a final up-sampling to full resolution, the output is a mask of shape (C, H, W), where C is the number of target classes, H the image height, and W the image width.
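The decoder steps can likewise be sketched in NumPy. This is a simplified, hypothetical version: weights are random, batch normalization is omitted, the channel counts and the 64×64 input size are arbitrary, and the final 2× up-sampling followed by a 3×3 "segmentation head" that maps to C class channels is an assumption about how the last step reaches full resolution.

```python
import numpy as np

rng = np.random.default_rng(0)

def relu(x):
    return np.maximum(x, 0.0)

def upsample(x):
    """2x nearest-neighbour up-sampling of a (C, H, W) array."""
    return x.repeat(2, axis=1).repeat(2, axis=2)

def conv3x3(x, out_channels):
    """3x3 convolution (zero padding 1, random weights, no bias) on a (C_in, H, W) array."""
    c_in, h, w = x.shape
    weight = rng.normal(0.0, 0.1, (out_channels, c_in, 3, 3))
    padded = np.pad(x, ((0, 0), (1, 1), (1, 1)))
    out = np.zeros((out_channels, h, w))
    for i in range(3):
        for j in range(3):
            # Add the contribution of kernel tap (i, j) at every output position.
            out += np.einsum("oc,chw->ohw", weight[:, :, i, j], padded[:, i:i + h, j:j + w])
    return out

def decode(hidden, skips, num_classes, stem_channels=64, channels=(32, 16, 8)):
    """hidden: (D, H/16, W/16); skips: CNN features at H/8, H/4, H/2 (3rd, 2nd, 1st level)."""
    x = relu(conv3x3(hidden, stem_channels))       # 3x3 conv + ReLU on the encoder output
    for skip, ch in zip(skips, channels):
        x = upsample(x)                            # double the resolution
        x = np.concatenate([x, skip], axis=0)      # skip connection: concat along channels
        x = relu(conv3x3(x, ch))                   # 3x3 conv + ReLU
    x = relu(conv3x3(upsample(x), channels[-1]))   # assumed final 2x step to (H, W)
    return conv3x3(x, num_classes)                 # assumed head: one channel per class

H, W, num_classes = 64, 64, 9
hidden = rng.normal(size=(64, H // 16, W // 16))   # stand-in for the encoder output
skips = [rng.normal(size=(32, H // 8, W // 8)),    # 3rd-level CNN features
         rng.normal(size=(16, H // 4, W // 4)),    # 2nd-level CNN features
         rng.normal(size=(8, H // 2, W // 2))]     # 1st-level CNN features
logits = decode(hidden, skips, num_classes)
mask = logits.argmax(axis=0)                       # per-pixel class label
print(logits.shape, mask.shape)  # (9, 64, 64) (64, 64)
```

Taking the argmax over the class axis turns the (C, H, W) output into a single label map with the same height and width as the input image.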
The authors also observed that incorporating low-level features more extensively (that is, using more skip connections) generally leads to better segmentation accuracy.
The model was evaluated on the Synapse multi-organ segmentation dataset, using the Dice Similarity Coefficient (DSC) and the Hausdorff distance (HD) as metrics.
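For reference, here is a minimal sketch of both metrics on binary 2-D masks. Dice measures overlap, 2|A∩B| / (|A| + |B|), so higher is better; the Hausdorff distance is the largest distance from a point in one mask to the nearest point in the other, so lower is better. Assumptions: both masks are non-empty, the distance is computed over all foreground pixels (medical-imaging toolkits typically use boundary pixels and often report a 95th-percentile variant instead), and the function names are illustrative. For multi-organ segmentation, such scores are usually computed per organ and then averaged.

```python
import numpy as np

def dice(pred, target):
    """Dice similarity coefficient 2|A∩B| / (|A| + |B|) between two binary masks."""
    pred, target = pred.astype(bool), target.astype(bool)
    total = pred.sum() + target.sum()
    if total == 0:
        return 1.0  # both masks empty: treated here as perfect agreement
    return 2.0 * np.logical_and(pred, target).sum() / total

def hausdorff(pred, target):
    """Symmetric Hausdorff distance (in pixels) between the foreground pixel sets.

    Assumes both masks contain at least one foreground pixel.
    """
    a = np.argwhere(pred)     # (N, 2) coordinates of foreground pixels in pred
    b = np.argwhere(target)   # (M, 2) coordinates of foreground pixels in target
    dists = np.linalg.norm(a[:, None, :] - b[None, :, :], axis=-1)  # (N, M) pairwise distances
    # Worst-case nearest-neighbour distance, taken in both directions.
    return max(dists.min(axis=1).max(), dists.min(axis=0).max())

pred = np.zeros((8, 8), dtype=int)
target = np.zeros((8, 8), dtype=int)
pred[2:5, 2:5] = 1      # 3x3 square
target[2:5, 3:6] = 1    # same square shifted one pixel to the right
print(round(dice(pred, target), 4))  # 0.6667
print(hausdorff(pred, target))       # 1.0
```

The two squares share 6 of their 9 pixels, giving a Dice score of 12/18, while no pixel is more than one pixel away from the other mask, giving a Hausdorff distance of 1.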
As the results show, TransUNet outperformed most of the existing SOTA architectures it was compared with, such as V-Net, ResNet U-Net, ResNet Attention U-Net, and Vision Transformers. This suggests that its Transformer-based encoder leverages self-attention more effectively than CNN-based U-Nets that incorporate self-attention.
Thank you! ❤️