[AI/ML] Spatial Transformer Networks (STN) — Overview, Challenges And Proposed Improvements
Author(s): Shashwat Gupta
Originally published on Towards AI.
Spatial transformer networks (STNs) allow models to manipulate spatial information dynamically, handling transformations such as scaling and rotation before features reach subsequent tasks. They enhance recognition accuracy by letting models focus on the essential visual regions, with minimal dependence on pooling layers. Although STNs have been covered extensively in multiple studies, this blog delves into their practical advantages and disadvantages. We also examine P-STN, a proposed improvement from 2020 that introduces probabilistic transformations for more robust training. Building more adaptable and precise machine learning models relies on understanding STNs and their advancements.
Disclaimer: Much of this section is based on the original Spatial Transformer Networks paper and related resources [1, 2, 3].
Spatial Transformer Networks (STN):
STNs (Spatial Transformer Networks), introduced by Max Jaderberg et al., are modules that can learn to adjust the spatial information in a model, making it more robust to changes such as warping. Before STNs, achieving such robustness required many layers of max-pooling. Unlike pooling layers, which have fixed, small receptive fields, a spatial transformer can dynamically transform an image or feature map by applying a different transformation for each input. These transformations affect the entire feature map and can include scaling, cropping, rotation, and non-rigid bending.
This capability allows networks to focus on important parts of an image (a process called attention) and adjust these parts to a standard position, making it easier to recognize them in later layers. STNs expand on the idea of attention modules by handling spatial transformations. They can be trained using regular back-propagation, which means the entire model can be trained all at once. STNs are useful for various tasks, including image classification, locating objects, and managing spatial attention.
The STN consists of the following three parts:
- Localisation Net
- Grid Generator
- Sampler
1. Localisation Network:
It takes the input feature map U ∈ R^{H×W×C} and outputs the parameters of the transformation, θ = f_loc(U). It can take any form (e.g. a small convolutional or fully-connected network) but should include a final regression layer that produces the transformation parameters θ.
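To make this concrete, here is a minimal localisation network in PyTorch, a sketch in the spirit of the official PyTorch STN tutorial [5]; the layer sizes assume a 28×28 single-channel input and are purely illustrative. The final regression layer is initialised to output the identity transform, a common trick so that training starts from "no warp":

```python
import torch
import torch.nn as nn

class LocNet(nn.Module):
    """Minimal localisation network: theta = f_loc(U), theta a 2x3 affine matrix."""
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7), nn.MaxPool2d(2), nn.ReLU(),
            nn.Conv2d(8, 10, kernel_size=5), nn.MaxPool2d(2), nn.ReLU(),
        )
        self.regressor = nn.Sequential(
            nn.Linear(10 * 3 * 3, 32), nn.ReLU(),
            nn.Linear(32, 6),  # the 6 parameters of a 2D affine transform
        )
        # Start at the identity transform: weight = 0, bias = [1,0,0,0,1,0].
        self.regressor[-1].weight.data.zero_()
        self.regressor[-1].bias.data.copy_(torch.tensor([1., 0., 0., 0., 1., 0.]))

    def forward(self, u):  # u: (N, 1, 28, 28)
        theta = self.regressor(self.features(u).flatten(1))
        return theta.view(-1, 2, 3)
```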
2. Parametrised Sampling Grid:
The output pixels are computed by applying a sampling kernel centred at each location of the input feature map. The only constraint is that the transformation must be differentiable with respect to its parameters, to allow back-propagation. A good heuristic is to parametrise the transformation in a low-dimensional way, which reduces the complexity of the task assigned to the localisation network; the network can also learn the target grid representation itself. For example, if τ_θ = M_θ B, where B is a target grid representation, then both θ and B can be learned.
In our case, we analyse 2D affine transformations, which the following equation summarises, mapping output (target) grid coordinates (x^t_i, y^t_i) to source coordinates (x^s_i, y^s_i) in the input:

(x^s_i, y^s_i)^T = τ_θ(G_i) = A_θ (x^t_i, y^t_i, 1)^T, where A_θ = [θ₁₁ θ₁₂ θ₁₃; θ₂₁ θ₂₂ θ₂₃]
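In PyTorch, `F.affine_grid` generates exactly this kind of parametrised sampling grid τ_θ(G): given a batch of 2×3 matrices A_θ, it returns a source coordinate (x^s, y^s) for every output pixel. A small sketch (the 45-degree rotation is just an example):

```python
import torch
import torch.nn.functional as F

# One 2x3 affine matrix per batch element; here, a 45-degree rotation.
theta = torch.tensor([[[0.707, -0.707, 0.0],
                       [0.707,  0.707, 0.0]]])
# The grid holds, for each output pixel, the (x_s, y_s) source location to sample.
grid = F.affine_grid(theta, size=(1, 1, 28, 28), align_corners=False)
print(grid.shape)  # torch.Size([1, 28, 28, 2])
```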
3. Differentiable Image Sampling:
To perform a spatial transformation of the input feature map, a sampler must take the set of sampling points τ_θ(G), along with the input feature map U, and produce the sampled output feature map V. Each (x^s_i, y^s_i) coordinate in τ_θ(G) defines the spatial location in the input where a sampling kernel is applied to obtain the value at a particular pixel in the output V. This can be written as:

V^c_i = Σ_n^H Σ_m^W U^c_{nm} · k(x^s_i − m; Φₓ) · k(y^s_i − n; Φᵧ)   ∀i ∈ [1 … H′W′], ∀c ∈ [1 … C]
where Φₓ and Φᵧ are the parameters of a generic sampling kernel k() which defines the image interpolation (e.g. bilinear), U^c_{nm} is the value at location (n, m) in channel c of the input, and V^c_i is the output value for pixel i at location (x^t_i , y^t_i ) in channel c. Note that the sampling is done identically for each channel of the input, so every channel is transformed identically (this preserves spatial consistency between channels).
In theory, any sampling kernel can be used, as long as (sub-)gradients can be defined with respect to x^s_i and y^s_i. For example, using the integer sampling kernel reduces the above equation to:

V^c_i = Σ_n^H Σ_m^W U^c_{nm} · δ(⌊x^s_i + 0.5⌋ − m) · δ(⌊y^s_i + 0.5⌋ − n)
where ⌊x + 0.5⌋ rounds x to the nearest integer and δ() is the Kronecker delta function. This sampling kernel simply copies the value at the pixel nearest to (x^s_i, y^s_i) to the output location (x^t_i, y^t_i). Alternatively, a bilinear sampling kernel can be used, giving:

V^c_i = Σ_n^H Σ_m^W U^c_{nm} · max(0, 1 − |x^s_i − m|) · max(0, 1 − |y^s_i − n|)
To allow back-propagation of the loss through this sampling mechanism, we can define the gradients with respect to U and G. For the bilinear sampling kernel above, the partial derivatives are:

∂V^c_i / ∂U^c_{nm} = max(0, 1 − |x^s_i − m|) · max(0, 1 − |y^s_i − n|)

∂V^c_i / ∂x^s_i = Σ_n^H Σ_m^W U^c_{nm} · max(0, 1 − |y^s_i − n|) · g(m, x^s_i), where g(m, x^s_i) = 0 if |m − x^s_i| ≥ 1, 1 if m ≥ x^s_i, and −1 if m < x^s_i

and similarly for ∂V^c_i / ∂y^s_i.
This gives us a (sub-)differentiable sampling mechanism, allowing loss gradients to flow back not only to the input feature map but also to the sampling grid coordinates and, therefore, back to the transformation parameters θ and localization network since ∂x^{s}_i / ∂θ and ∂y^{s}_{i}/ ∂θ can be easily derived. Due to discontinuities in the sampling functions, sub-gradients must be used. This sampling mechanism can be implemented very efficiently on GPU by ignoring the sum over all input locations and instead just looking at the kernel support region for each output pixel.
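The bilinear case is what PyTorch's `F.grid_sample` implements, and the gradients described above flow to both the feature map and the grid (and hence to θ), as this small check illustrates:

```python
import torch
import torch.nn.functional as F

U = torch.randn(1, 1, 28, 28, requires_grad=True)              # input feature map
theta = torch.tensor([[[1.0, 0.0, 0.1],
                       [0.0, 1.0, 0.0]]], requires_grad=True)  # small x-shift

grid = F.affine_grid(theta, U.size(), align_corners=False)        # tau_theta(G)
V = F.grid_sample(U, grid, mode='bilinear', align_corners=False)  # sampled output

# Gradients reach both U and theta through the (sub-)differentiable sampler.
V.sum().backward()
print(U.grad.shape, theta.grad.shape)  # torch.Size([1, 1, 28, 28]) torch.Size([1, 2, 3])
```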
For better warping, STNs can be cascaded, passing the output of one STN to the next (as in [2]), optionally with an additional unwarped input as conditioning (as in [1]).
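A series combination can be sketched by stacking two warp stages, each refining the output of the previous one (reusing the hypothetical LocNet above; note that IC-STN [1] actually composes the warp parameters rather than re-warping pixels, so this is only the simplest cascade):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class STN(nn.Module):
    """One warp stage: predict theta from the input, then resample the input."""
    def __init__(self):
        super().__init__()
        self.loc = LocNet()  # the localisation network sketched earlier

    def forward(self, u):
        theta = self.loc(u)                                         # (N, 2, 3)
        grid = F.affine_grid(theta, u.size(), align_corners=False)
        return F.grid_sample(u, grid, align_corners=False)

# Cascade: the output of one STN is fed to the next.
cascade = nn.Sequential(STN(), STN())
out = cascade(torch.randn(4, 1, 28, 28))
```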
Pros and cons of STNs:
The overall pros of STNs are:
- STNs are fast, and applying them requires few modifications to the downstream model
- They can also be used to downsample or oversample a feature map (downsampling with fixed, small support might lead to an aliasing effect)
- Multiple STNs can be combined. A series combination (the output of one STN feeding into the next, with or without an unwarped conditional input) allows more complex transformations to be learned.
- Parallel combinations are effective when there is more than one part of the image to focus on: with two parallel STNs on the CUB-200-2011 bird classification dataset, one was shown to become a head detector and the other a body detector (see the sketch after this list).
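A parallel arrangement can be sketched as below, under the same assumptions and reusing the STN module above; the actual CUB-200-2011 setup in the paper is more elaborate:

```python
import torch
import torch.nn as nn

# Two STNs attend to different image parts (e.g. head and body); their outputs
# are concatenated along the channel dimension for the downstream classifier.
heads = nn.ModuleList([STN(), STN()])
x = torch.randn(4, 1, 28, 28)
features = torch.cat([stn(x) for stn in heads], dim=1)  # (4, 2, 28, 28)
```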
However, STNs are notoriously known to suffer from the following defects:
1. Boundary effects arise because the warped image, not the underlying geometric information, is propagated (e.g. if an image is rotated, an STN can undo the rotation, but it cannot restore degraded boundaries such as cut-off corners). This could be solved by boundary-aware sampling.
2. A single STN application is insufficient to learn complex transformations. This could be solved by hierarchical, cascaded STNs (i.e. STNs in series) with multi-scale transformations.
3. Training Difficulty: Hard to train due to sensitivity to small mis-predictions in transformation parameters — solved in P-STN (below)
4. Sensitivity to Errors: Mis-predicted transformations can lead to poor localization, adversely affecting downstream tasks — solved in P-STN (below)
P-STN: An Improvement over STN
Probabilistic Spatial Transformer Networks (P-STN), by Schwöbel et al. [7], address limitations 3 and 4 by introducing a probabilistic framework to the transformation process. Instead of predicting a single deterministic transformation, P-STN estimates a distribution over possible transformations.
This probabilistic approach offers several key improvements:
1. Robustness Through Marginalization:
- Multiple Transformations: By sampling multiple transformations from the estimated distribution, P-STN effectively “looks” at the input from various perspectives. This marginalization over transformations mitigates the impact of any single mis-predicted transformation.
- Smoother Loss Landscape: The integration over multiple transformations results in a more stable and smoother loss landscape, facilitating easier and more reliable training.
2. Enhanced Data Augmentation:
- Learned Augmentations: The stochastic transformations serve as a form of learned data augmentation, automatically generating diverse training samples that improve the model’s generalization capabilities.
- Improved Downstream Performance: This augmentation leads to better classification accuracy, increased robustness, and improved model calibration.
3. Applicability to Diverse Domains:
- While initially designed for image data, P-STN’s probabilistic nature allows it to generalize effectively to non-visual domains, such as time-series data, further demonstrating its versatility.
The core mathematical change is to marginalise the downstream likelihood over the estimated transformation distribution, approximated by Monte Carlo sampling (stated here in simplified form):

p(y | x) = ∫ p(y | T_θ(x)) p(θ | x) dθ ≈ (1/S) Σ_{s=1}^{S} p(y | T_{θ_s}(x)),   θ_s ∼ p(θ | x)

where T_θ denotes the spatial transformer warp and S is the number of sampled transformations; the parameters of p(θ | x) are learned by variational inference.
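A hedged sketch of this loss in PyTorch, assuming hypothetical `loc_mu`, `loc_logvar`, and `classifier` modules that are not the paper's exact architecture: the localiser predicts a diagonal Gaussian over the six affine parameters, and the loss averages the classification negative log-likelihood over S sampled warps.

```python
import torch
import torch.nn.functional as F

def pstn_loss(x, y, loc_mu, loc_logvar, classifier, S=8):
    """Monte Carlo average of the NLL over S sampled transformations."""
    mu, logvar = loc_mu(x), loc_logvar(x)              # each (N, 6)
    nll = 0.0
    for _ in range(S):
        eps = torch.randn_like(mu)                     # reparameterisation trick
        theta = (mu + eps * (0.5 * logvar).exp()).view(-1, 2, 3)
        grid = F.affine_grid(theta, x.size(), align_corners=False)
        logits = classifier(F.grid_sample(x, grid, align_corners=False))
        nll = nll + F.cross_entropy(logits, y)
    return nll / S  # average NLL over multiple transformations
```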
Illustrative Benefits:
- Reduced Sensitivity to Transformation Errors:
STN Loss ∝ Negative Log-Likelihood of a Single Transformation
P-STN Loss ∝ Average Negative Log-Likelihood over Multiple Transformations
By averaging over multiple transformations, P-STN reduces the impact of any single erroneous transformation, leading to a more stable and reliable training process.
- Improved Calibration:
Calibration Error_STN > Calibration Error_P-STN
P-STN’s approach of considering multiple transformations results in better-calibrated probabilities, as evidenced by lower calibration errors compared to STN.
Probabilistic Spatial Transformer Networks enhance the original STN framework by introducing a distribution over possible spatial transformations. This probabilistic approach leads to more robust training, effective data augmentation, improved classification performance, and better-calibrated models. The integration of variational inference and Monte Carlo sampling in P-STN provides a principled way to handle transformation uncertainties, making it a significant advancement over traditional STNs.
I write about technology, investing and books I read. Here is an index to my other blogs (sorted by topic): https://medium.com/@shashwat.gpt/index-welcome-to-my-reflections-on-code-and-capital-2ac34c7213d9
References:
1. C.-H. Lin and S. Lucey. Inverse Compositional Spatial Transformer Networks (IC-STN): https://arxiv.org/pdf/1612.03897.pdf
2. STN: https://paperswithcode.com/method/stn
3. Video (with slides, CV reading group resources): https://www.youtube.com/watch?v=6NOQC_fl1hQ&t=162s
4. K. Lenc and A. Vedaldi. Understanding image representations by measuring their equivariance and equivalence. CVPR, 2015 (defines affine invariance, equivariance, and equivalence criteria)
5. STN PyTorch implementation: https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html
6. Scattering networks: https://paperswithcode.com/paper/invariant-scattering-convolution-networks#code
7. P. Schwöbel et al. Probabilistic Spatial Transformer Networks (P-STN): https://backend.orbit.dtu.dk/ws/portalfiles/portal/280953750/2004.03637.pdf