Techniques in Self-Attention Generative Adversarial Networks
Last Updated on July 25, 2023 by Editorial Team
Author(s): Sherwin Chen
Originally published on Towards AI.
A discussion of the techniques used in SAGAN, such as spectral normalization and conditional batch normalization.
Introduction
Self-Attention Generative Adversarial Networks (SAGAN), an architecture proposed by Han Zhang et al. at ICML 2019 [1], has been shown experimentally to significantly outperform prior work in image synthesis. In this article, we discuss several techniques involved in SAGAN, including self-attention, spectral normalization, conditional batch normalization, and the projection discriminator.
Bonus: we give simple example code for each key component. Be aware that the code provided here is simplified for illustrative purposes only. For a full implementation, you may refer to my repo for SAGAN on GitHub, or to the official implementation from Google Brain.
Self-Attention
Motivation
GANs have been successful at modeling texture, but they often fail to capture geometric or structural patterns that occur consistently in some classes. For example, synthesized dogs are often drawn with realistic fur texture but without clearly defined separate feet. One explanation is that convolutional layers are good at capturing local structure but have trouble discovering long-range dependencies, for two reasons:
- Although deep ConvNets are theoretically capable of capturing long-range dependencies, it is hard for optimization algorithms to find parameter values that carefully coordinate multiple layers to capture them, and such parameterizations may be statistically brittle and prone to failure on previously unseen data.
- Larger convolutional kernels increase representational capacity, but they give up the computational and statistical efficiency of local convolutions.
Self-attention, on the other hand, exhibits a better balance between the ability to model long-range dependencies and computational and statistical efficiency. Based on these ideas, Han Zhang et al. proposed SAGANs to introduce a self-attention mechanism into convolutional GANs.
Self-Attention with Images
We discussed the self-attention mechanism in the previous article, where it is applied to 3D sequential data to capture temporal dependencies. To apply self-attention to images, Han Zhang et al. make three major modifications:
- Replace fully connected layers with 1-by-1 convolutional layers.
- Reshape 4D tensors into 3D tensors (merging height and width) before computing attention, and reshape them back afterward.
- Multiply the output of the attention layer by a scale parameter and add back the input feature map:

$$y = \gamma o + x$$

where o is the output of the attention layer, x is the input feature map, and γ is a learnable scalar initialized to 0. Introducing the learnable γ allows the network to first rely on the cues in the local neighborhood, since this is easier, and then gradually learn to assign more weight to the non-local evidence. The intuition is straightforward: we want to learn the easy task first and then progressively increase the complexity of the task. [1]
Python Code
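The three modifications above can be sketched in a few lines of NumPy. This is a minimal illustration, not the SAGAN reference code: it assumes channels-last (NHWC) tensors and treats each 1-by-1 convolution as a per-pixel matrix product. The names `conv1x1`, `self_attention`, `w_f`, `w_g`, `w_h`, `w_v`, and the reduced channel sizes are choices made here for illustration.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def conv1x1(x, w):
    """1-by-1 convolution on an NHWC tensor, i.e. a per-pixel linear map.

    x: (B, H, W, C_in), w: (C_in, C_out) -> (B, H, W, C_out)
    """
    return x @ w


def self_attention(x, w_f, w_g, w_h, w_v, gamma):
    """Self-attention over the spatial positions of an image feature map."""
    batch, height, width, _ = x.shape
    n = height * width

    # Modification 1: 1-by-1 convolutions replace fully connected layers.
    # Modification 2: merge height and width so attention sees a sequence.
    keys = conv1x1(x, w_f).reshape(batch, n, -1)     # (B, N, C_k)
    queries = conv1x1(x, w_g).reshape(batch, n, -1)  # (B, N, C_k)
    values = conv1x1(x, w_h).reshape(batch, n, -1)   # (B, N, C_v)

    # Each output position attends over all input positions.
    attn = softmax(queries @ keys.transpose(0, 2, 1), axis=-1)  # (B, N, N)
    o = attn @ values                                           # (B, N, C_v)

    # Reshape back to 4D and project to the input channel count.
    o = conv1x1(o.reshape(batch, height, width, -1), w_v)

    # Modification 3: scale by gamma and add the input back.
    return gamma * o + x
```

With `gamma = 0.0` the layer returns its input unchanged, which matches the initialization described above.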
Spectral Normalization
Motivation
Before getting into the details of spectral normalization, we briefly introduce some basic ideas to ensure we are on the same page.
- A flat local minimum of a function is less sensitive to input perturbations.
- The Hessian matrix describes the local curvature of a multivariate function; at a local minimum, it measures how sensitive the function is to perturbations of its input.
- The spectral norm of a real matrix is equal to its largest singular value. Specifically, for a symmetric real matrix (e.g., a Hessian matrix), the spectral norm is its largest eigenvalue in absolute value. For a more detailed discussion of the spectral norm and the related concept of K-Lipschitz continuous functions, please refer to [7].
In [3], Yuichi Yoshida et al. stress that a flat local minimum of a loss function generalizes better than a sharp one (the first point above), and they characterize flatness through the eigenvalues of the Hessian matrix of the loss function (the second point). Following this line of thought, they show that, under some constraints, bounding the spectral norm of the weight matrix at each layer is sufficient to achieve a flat local minimum (partially by the third point). They therefore propose to add the spectral norm of each weight matrix to the loss function as a regularizer, just like L2 regularization.
Building on Yoshida's work, Takeru Miyato et al. [2] develop spectral normalization, which explicitly normalizes the weight matrix in each layer so that it satisfies the Lipschitz constraint σ(W)=1:

$$\bar{W}_{SN}(W) = \frac{W}{\sigma(W)}$$

where σ(W) is the spectral norm of W, a scalar. We can verify that the normalized matrix has unit spectral norm by showing

$$\sigma\left(\bar{W}_{SN}(W)\right) = \sigma\left(\frac{W}{\sigma(W)}\right) = \frac{\sigma(W)}{\sigma(W)} = 1$$
Takeru Miyato et al. further prove that spectral normalization regularizes the gradient of W, preventing the column space of W from concentrating into one particular direction. This precludes the transformation of each layer from becoming sensitive in one direction.
How to Compute The Spectral Norm?
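Assume W has shape (N, M) and we have a randomly initialized vector u of length N. The power iteration method approximates the spectral norm of W by repeating the following two updates T times and then reading off σ(W):

$$v \leftarrow \frac{W^\top u}{\lVert W^\top u \rVert_2}, \qquad u \leftarrow \frac{W v}{\lVert W v \rVert_2}, \qquad \sigma(W) \approx u^\top W v$$

where u and v approximate the first left and right singular vectors of W. In practice, T=1 is sufficient, because u is carried over between training steps and W changes only gradually.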
Python Code
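The power iteration above can be sketched in NumPy as follows. This is a minimal illustration, not the code from [2] or any library: the function names `l2_normalize`, `spectral_norm`, and `spectrally_normalize`, and the small `eps` guard against division by zero, are assumptions made here.

```python
import numpy as np


def l2_normalize(v, eps=1e-12):
    """Scale a vector to unit L2 norm (eps guards against division by zero)."""
    return v / (np.linalg.norm(v) + eps)


def spectral_norm(w, u, n_iters=1):
    """Estimate the largest singular value of w by power iteration.

    w: (N, M) weight matrix.
    u: (N,) running estimate of the first left singular vector.
    Returns the estimate sigma and the updated u, which the caller is
    expected to keep and pass in again at the next training step.
    """
    for _ in range(n_iters):
        v = l2_normalize(w.T @ u)  # (M,) right singular vector estimate
        u = l2_normalize(w @ v)    # (N,) left singular vector estimate
    sigma = u @ w @ v
    return sigma, u


def spectrally_normalize(w, u, n_iters=1):
    """Return w divided by its estimated spectral norm, plus the updated u."""
    sigma, u = spectral_norm(w, u, n_iters)
    return w / sigma, u
```

In a training loop one would typically call `spectrally_normalize` with `n_iters=1` at every step and feed the returned `u` back in, so the estimate improves while W is being updated.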
Conditional Batch Normalization
Conditional batch normalization was first proposed by Harm de Vries, Florian Strub et al. [4]. The central idea is to condition the γ and β of batch normalization on some input x (e.g., a language embedding), which is done by adding f(x) and h(x) to γ and β, respectively. Here, f and h can be any functions (e.g., a one-hidden-layer MLP). In this way, additional information can be incorporated into a pre-trained network with minimal overhead.
SAGAN can be implemented as a conditional GAN (cGAN) by integrating class labels into both the generator and the discriminator. In the generator, this is achieved through conditional batch normalization layers, where each label gets its own γ and β. In the discriminator, it is accomplished by projection, a method covered in the next section. Here we provide the code for conditional batch normalization from [6] with some annotations.
Python Code
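As a stand-in illustration of the idea, here is a minimal NumPy sketch of class-conditional batch normalization. It is not the official implementation: it assumes channels-last (NHWC) tensors, uses batch statistics only (no moving averages for inference), and the names `conditional_batch_norm`, `gamma_table`, and `beta_table` are hypothetical.

```python
import numpy as np


def conditional_batch_norm(x, labels, gamma_table, beta_table, eps=1e-5):
    """Batch normalization whose scale and shift depend on the class label.

    x:           (B, H, W, C) feature map.
    labels:      (B,) integer class ids.
    gamma_table: (num_classes, C) per-class scale.
    beta_table:  (num_classes, C) per-class shift.
    """
    # Normalize each channel with statistics over batch, height and width.
    mean = x.mean(axis=(0, 1, 2), keepdims=True)
    var = x.var(axis=(0, 1, 2), keepdims=True)
    x_hat = (x - mean) / np.sqrt(var + eps)

    # Look up one (gamma, beta) pair per sample and broadcast over H and W.
    gamma = gamma_table[labels][:, None, None, :]  # (B, 1, 1, C)
    beta = beta_table[labels][:, None, None, :]    # (B, 1, 1, C)
    return gamma * x_hat + beta
```

In a real model the two tables would be learnable parameters, for example initialized to ones and zeros so that the layer starts out as plain batch normalization.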
Projection Discriminator
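In [5], Takeru Miyato and Masanori Koyama propose to incorporate class labels into the discriminator through a projection. To see how it works, we denote the conditional discriminator as D(x,y)=σ(f(x,y)), where σ is the sigmoid function and f(x,y) is a function of x and y. Writing p for the data distribution and q for the generator distribution, the discriminator minimizes

$$L(D) = -\mathbb{E}_{p(x,y)}\left[\log D(x,y)\right] - \mathbb{E}_{q(x,y)}\left[\log\left(1 - D(x,y)\right)\right]$$

We first derive the optimal discriminator by setting the derivative of this loss with respect to D(x,y) to zero:

$$-\frac{p(x,y)}{D(x,y)} + \frac{q(x,y)}{1 - D(x,y)} = 0$$

Solving this equation, we get the optimal discriminator

$$D^*(x,y) = \frac{p(x,y)}{p(x,y) + q(x,y)}$$

By replacing the discriminator with σ(f(x,y)), we have

$$\frac{1}{1 + \exp\left(-f^*(x,y)\right)} = \frac{p(x,y)}{p(x,y) + q(x,y)}$$

This gives us the logits

$$f^*(x,y) = \log\frac{p(x,y)}{q(x,y)} = \log\frac{p(y|x)}{q(y|x)} + \log\frac{p(x)}{q(x)} = r(y|x) + r(x)$$

Now we take a closer look at p(y|x), a categorical distribution usually expressed as a softmax function. Its log-linear model is

$$\log p(y=c|x) = {v_c^p}^\top \phi(x) - \log Z_p(\phi(x))$$

where Z_p(φ(x)) is the partition function; q(y|x) is modeled in the same way with v_c^q and Z_q. The log-likelihood ratio, therefore, takes the following form:

$$r(y=c|x) = \log\frac{p(y=c|x)}{q(y=c|x)} = \left(v_c^p - v_c^q\right)^\top \phi(x) - \left(\log Z_p(\phi(x)) - \log Z_q(\phi(x))\right)$$

Now, if we put v_c = v_c^p - v_c^q, and put the normalization constants together with r(x) into one expression ψ(φ(x)), we can rewrite f(x,y) as

$$f(x, y=c) = v_c^\top \phi(x) + \psi(\phi(x))$$

If we use y to denote the one-hot vector of the label and V to denote the embedding matrix whose rows are the vectors v_c, we can rewrite the above model as

$$f(x,y) = y^\top V \phi(x) + \psi(\phi(x))$$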
This formulation introduces the label information via an inner product as shown in the following figure.
Python Code
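The final form of f(x,y) can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the code from [5]: the feature extractor φ is taken as given (the function receives its output `phi_x`), ψ is assumed to be a single linear layer with hypothetical parameters `w_psi` and `b_psi`, and `projection_logits` is a name chosen here.

```python
import numpy as np


def projection_logits(phi_x, labels, embedding, w_psi, b_psi):
    """Discriminator logits f(x, y) = y^T V phi(x) + psi(phi(x)).

    phi_x:     (B, D) features phi(x) from the discriminator body.
    labels:    (B,) integer class ids.
    embedding: (num_classes, D) matrix V, one row per class.
    w_psi:     (D,) weights of the linear map psi (an assumption here).
    b_psi:     scalar bias of psi.
    """
    # Unconditional term psi(phi(x)).
    psi = phi_x @ w_psi + b_psi              # (B,)

    # Row lookup is equivalent to y^T V for a one-hot y.
    v_y = embedding[labels]                  # (B, D)

    # Label information enters through an inner product with phi(x).
    projection = np.sum(v_y * phi_x, axis=1)  # (B,)
    return projection + psi
```

Looking up the embedding row by index avoids building the one-hot matrix explicitly while computing the same quantity.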
Miscellanea
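In this section, we briefly mention several other techniques adopted by SAGANs:
- SAGANs use the hinge loss as the adversarial loss, which is defined as

$$L_D = -\mathbb{E}_{(x,y)\sim p_{data}}\left[\min\left(0, -1 + D(x,y)\right)\right] - \mathbb{E}_{z\sim p_z,\, y\sim p_{data}}\left[\min\left(0, -1 - D(G(z),y)\right)\right]$$

$$L_G = -\mathbb{E}_{z\sim p_z,\, y\sim p_{data}}\left[D(G(z),y)\right]$$

- SAGANs use different learning rates for the generator and the discriminator, the so-called Two-Timescale Update Rule (TTUR). For ImageNet, they use 0.0004 for the discriminator and 0.0001 for the generator. In my implementation, I use 0.0001 for the discriminator and 0.00005 for the generator on the CelebA dataset.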
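The hinge loss from the first item above can be sketched in NumPy as follows. This is a minimal illustration: the function names are chosen here, and `d_real` and `d_fake` are assumed to be raw discriminator outputs (logits, not probabilities) on real and generated samples.

```python
import numpy as np


def d_hinge_loss(d_real, d_fake):
    """Discriminator hinge loss.

    Uses -min(0, -1 + D) = max(0, 1 - D) for real samples and
    -min(0, -1 - D) = max(0, 1 + D) for generated samples.
    """
    real_term = np.mean(np.maximum(0.0, 1.0 - d_real))
    fake_term = np.mean(np.maximum(0.0, 1.0 + d_fake))
    return real_term + fake_term


def g_hinge_loss(d_fake):
    """Generator loss: raise the discriminator output on generated samples."""
    return -np.mean(d_fake)
```

Real samples scored above 1 and generated samples scored below -1 contribute nothing to the discriminator loss, so the discriminator stops pushing on examples it already separates with a margin.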
END
I hope this article helps you build some intuition for SAGAN. Feel free to leave a comment if you find something wrong or unclear.
References
- [1] Han Zhang et al. Self-Attention Generative Adversarial Networks. In ICML 2019.
- [2] Takeru Miyato et al. Spectral Normalization for Generative Adversarial Networks. In ICLR 2018.
- [3] Yuichi Yoshida et al. Spectral Norm Regularization for Improving the Generalizability of Deep Learning.
- [4] Harm de Vries, Florian Strub et al. Modulating Early Visual Processing by Language.
- [5] Takeru Miyato, Masanori Koyama. cGANs with Projection Discriminator.
- [6] Official Code for SAGAN.
- [7] A detailed discussion on spectral norm by Christian Cosgrove.