Techniques in Self-Attention Generative Adversarial Networks

Last Updated on July 25, 2023 by Editorial Team

Author(s): Sherwin Chen

Originally published on Towards AI.

A discussion of the techniques used in SAGAN, such as spectral normalization and conditional batch normalization.

Image generated by my implementation of SAGAN on celebA dataset after 120k iterations

Introduction

Self-Attention Generative Adversarial Networks (SAGAN), an architecture proposed by Han Zhang et al. in the ICML 2019 proceedings (PMLR) [1], has been shown experimentally to significantly outperform prior work in image synthesis. In this article, we discuss several techniques involved in SAGAN, including self-attention, spectral normalization, conditional batch normalization, and the projection discriminator.

Bonus: we give a simple code example for each key component. Be aware that the code provided here is simplified for illustrative purposes only. For the full implementation, you may refer to my repo for SAGAN on GitHub, or the official implementation from Google Brain.

Self-Attention

Motivation

Source: Self-Attention Generative Adversarial Networks.

GANs have been successful at modeling texture, but they often fail to capture geometric or structural patterns that occur consistently in some classes. For example, synthesized dogs are often drawn with realistic fur texture but without clearly defined separate feet. One explanation is that convolutional layers are good at capturing local structure but have trouble discovering long-range dependencies, for two reasons:

  1. Although deep ConvNets are theoretically capable of capturing long-range dependencies, it is hard for optimization algorithms to find parameter values that carefully coordinate multiple layers to capture them, and these parameters may be statistically brittle and prone to failure when applied to previously unseen data.
  2. Large convolutional kernels increase the representational capacity, but they are computationally less efficient.

Self-attention, on the other hand, exhibits a better balance between the ability to model long-range dependencies and computational and statistical efficiency. Based on these ideas, Han Zhang et al. proposed SAGANs to introduce a self-attention mechanism into convolutional GANs.

Self-Attention with Images

In a previous article, we discussed the self-attention mechanism as applied to 3D sequential tensors to capture temporal dependencies. To apply self-attention to images, Han Zhang et al. suggest three major modifications:

  1. Replace fully connected layers with 1-by-1 convolutional layers.
  2. Reshape 4D tensors into 3D tensors (merging height and width) before computing attention, and reshape them back afterward.
  3. Multiply the output of the attention layer by a scale parameter and add back the input feature map:

     $$y_i = \gamma\, o_i + x_i$$

     where $o_i$ is the output of the attention layer, $x_i$ is the input feature, and γ is a learnable scalar initialized to 0. Introducing the learnable γ allows the network to first rely on the cues in the local neighborhood (since this is easier) and then gradually learn to assign more weight to the non-local evidence. The intuition is straightforward: we want to learn the easy task first and then progressively increase the complexity of the task. [1]

Python Code

Python code for self-attention with 4D tensor
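
The three modifications listed above can be sketched in plain NumPy. This is a minimal illustration rather than the code from my repo or the official implementation: it assumes an NHWC layout, treats each 1-by-1 convolution as a matrix product over the channel axis, and the weight names (`w_f`, `w_g`, `w_h`, `w_v`) and the channel reductions used in the example shapes are assumptions made for the example.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # stabilise the exponentials
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def conv1x1(x, w):
    # A 1-by-1 convolution on an NHWC tensor is a matmul over the channel axis.
    return x @ w

def self_attention(x, w_f, w_g, w_h, w_v, gamma):
    """x: (N, H, W, C) feature map. Returns (output with x's shape, attention map)."""
    n, h, w, c = x.shape
    # Modifications 1 and 2: 1x1 convs, then merge height and width into one axis.
    f = conv1x1(x, w_f).reshape(n, h * w, -1)  # keys
    g = conv1x1(x, w_g).reshape(n, h * w, -1)  # queries
    v = conv1x1(x, w_h).reshape(n, h * w, -1)  # values
    # attn[b, j, i]: how much output position j attends to input position i.
    attn = softmax(g @ f.transpose(0, 2, 1), axis=-1)  # (N, HW, HW)
    o = (attn @ v).reshape(n, h, w, -1)  # reshape back to a 4D tensor
    o = conv1x1(o, w_v)  # project back to C channels
    # Modification 3: scale by gamma and add back the input feature map.
    return gamma * o + x, attn

# Hypothetical shapes: 16 channels, keys/queries reduced to C/8, values to C/2.
rng = np.random.default_rng(0)
x = rng.normal(size=(2, 4, 4, 16))
w_f, w_g = rng.normal(size=(16, 2)), rng.normal(size=(16, 2))
w_h, w_v = rng.normal(size=(16, 8)), rng.normal(size=(8, 16))
y, attn = self_attention(x, w_f, w_g, w_h, w_v, gamma=0.0)
```

With `gamma=0.0`, as at initialization, the layer returns its input unchanged, which is what lets the network start from purely local cues.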

Spectral Normalization

Motivation

Before getting into the details of spectral normalization, we briefly introduce some basic ideas to ensure we are on the same page.

  1. A flat local minimum of a function is less sensitive to the input perturbation.
  2. A Hessian matrix describes the local curvature of a multi-variate function at a local minimum; it measures the sensitivity of a function to its input at a local minimum.
  3. The spectral norm of a real matrix is equal to its largest singular value. In particular, for a real symmetric matrix (e.g., a Hessian matrix), the spectral norm is the largest absolute value of its eigenvalues. For more detailed discussions on the spectral norm and the related concept of K-Lipschitz continuous functions, please refer to [7].

The flat local minimum on the black curve is projected somewhere (the blue diamond) near the local minimum of the test function (the red dotted curve), while the projection of the sharp local minimum deviates from the local minimum of the test function. Source: On Large-Batch Training For Deep Learning: Generalization Gap And Sharp Minima

In [3], Yuichi Yoshida et al. stress that a flat local minimum of a loss function generalizes better than a sharp one (according to (1)), and they formulate flatness in terms of the eigenvalues of the Hessian matrix of the loss function (according to (2)). Following this line of thought, they prove that, under some constraints, bounding the spectral norm of the weight matrix at each layer is sufficient to achieve a flat local minimum (partially according to (3)). Therefore, they propose to regularize the spectral norm of each weight matrix in the loss function, just like L2 regularization.

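Based on Y. Yoshida's work, Takeru Miyato et al. [2] develop spectral normalization, which explicitly normalizes the weight matrix in each layer so that its spectral norm satisfies the Lipschitz constraint σ(W)=1:

$$\bar{W}_{SN}(W) = \frac{W}{\sigma(W)}$$

where σ(W) is the spectral norm of W, a scalar. We can verify that the normalized matrix has unit spectral norm by showing

$$\sigma\big(\bar{W}_{SN}(W)\big) = \sigma\!\left(\frac{W}{\sigma(W)}\right) = \frac{\sigma(W)}{\sigma(W)} = 1$$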

Takeru Miyato et al. further prove that spectral normalization regularizes the gradient of W, preventing the column space of W from concentrating into one particular direction. This precludes the transformation of each layer from becoming sensitive in one direction.
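
As a quick sanity check of the normalization itself, the sketch below divides a random matrix by its largest singular value (computed exactly with an SVD rather than by power iteration) and confirms that the result has spectral norm 1. The function names are illustrative and not taken from any implementation.

```python
import numpy as np

def exact_spectral_norm(W):
    # Singular values are returned in descending order, so the first is the largest.
    return np.linalg.svd(W, compute_uv=False)[0]

def spectral_normalize_exact(W):
    # Divide W by its spectral norm so the result has spectral norm 1.
    return W / exact_spectral_norm(W)

rng = np.random.default_rng(0)
W = rng.normal(size=(5, 3))
W_sn = spectral_normalize_exact(W)

# A matrix with spectral norm 1 never lengthens a vector: ||W_sn x|| <= ||x||.
x = rng.normal(size=3)
stretch = np.linalg.norm(W_sn @ x) / np.linalg.norm(x)
```

Here `stretch` is at most 1 for any nonzero `x`, which is the 1-Lipschitz property the constraint is meant to enforce for a linear layer.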

How to Compute The Spectral Norm?

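Assume W is of shape (N, M) and we have a randomly initialized vector u of dimension N. The power iteration method approximates the spectral norm of W as follows:

$$\text{for } t = 1, \dots, T: \qquad v \leftarrow \frac{W^\top u}{\lVert W^\top u \rVert_2}, \qquad u \leftarrow \frac{W v}{\lVert W v \rVert_2}$$

$$\sigma(W) \approx u^\top W v$$

where u and v approximate the first left and right singular vectors of W. In practice, T=1 is sufficient, since W changes only gradually during training and u is carried over from one training step to the next.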

Python Code
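
A minimal NumPy sketch of power iteration and the resulting normalization follows. It is an illustration under simplifying assumptions, not the code from my repo: W is a plain 2D matrix (a convolution kernel would first be reshaped to 2D), the function names are made up for the example, and a small epsilon is added to avoid division by zero.

```python
import numpy as np

def power_iteration(W, u, n_iters=1, eps=1e-12):
    """Estimate the spectral norm of W (shape (N, M)), starting from u (shape (N,))."""
    for _ in range(n_iters):
        v = W.T @ u
        v = v / (np.linalg.norm(v) + eps)  # approximate first right singular vector
        u = W @ v
        u = u / (np.linalg.norm(u) + eps)  # approximate first left singular vector
    sigma = u @ W @ v
    return sigma, u, v

def spectral_normalize(W, u, n_iters=1):
    """Return W divided by its estimated spectral norm, plus the updated u."""
    sigma, u, _ = power_iteration(W, u, n_iters)
    return W / sigma, u

rng = np.random.default_rng(0)
W = rng.normal(size=(6, 4))
u = rng.normal(size=6)

# During training, u would be stored and reused so that one iteration per step suffices.
W_sn, u = spectral_normalize(W, u, n_iters=1)
```

With a single iteration from a random start, the estimate is only rough; it improves across training steps because `u` is kept and refined each time.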

Conditional Batch Normalization

Conditional batch normalization was first proposed by Harm de Vries, Florian Strub et al. [4]. The central idea is to condition the γ and β of batch normalization on some input x (e.g., a language embedding), which is done by adding f(x) and h(x) to γ and β, respectively. Here, f and h could be any function (e.g., a one-hidden-layer MLP). In this way, they can incorporate additional information into a pre-trained network with minimal overhead.

SAGAN can be implemented as a form of conditional GAN (cGAN) by integrating class labels into both the generator and the discriminator. In the generator, this is achieved through conditional batch normalization layers, where each label gets its own γ and β. In the discriminator, this is accomplished by projection, a method we will see in the next section. Here we provide the code for conditional batch normalization from [6] with some annotations.

Python Code
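
The class-conditional variant used in the generator can be sketched in NumPy as follows. This is a simplified stand-in and not the official code: it assumes an NHWC layout and integer class labels, uses batch statistics only (no moving averages for inference), and the class name `ConditionalBatchNorm` is chosen for the example.

```python
import numpy as np

class ConditionalBatchNorm:
    """Batch normalization whose scale and shift are looked up per class label."""

    def __init__(self, num_classes, num_channels, eps=1e-5):
        # One (gamma, beta) row per class; in a real model these are learned.
        self.gamma = np.ones((num_classes, num_channels))
        self.beta = np.zeros((num_classes, num_channels))
        self.eps = eps

    def __call__(self, x, labels):
        """x: (N, H, W, C) activations; labels: (N,) integer class ids."""
        # Normalize each channel with statistics over batch, height and width.
        mean = x.mean(axis=(0, 1, 2), keepdims=True)
        var = x.var(axis=(0, 1, 2), keepdims=True)
        x_hat = (x - mean) / np.sqrt(var + self.eps)
        # Pick each sample's gamma and beta, then broadcast over height and width.
        gamma = self.gamma[labels][:, None, None, :]  # (N, 1, 1, C)
        beta = self.beta[labels][:, None, None, :]    # (N, 1, 1, C)
        return gamma * x_hat + beta

rng = np.random.default_rng(0)
cbn = ConditionalBatchNorm(num_classes=10, num_channels=8)
x = rng.normal(size=(4, 5, 5, 8))
labels = np.array([3, 1, 3, 7])
out = cbn(x, labels)  # same shape as x
```

Samples that share a label share the same scale and shift, so the label steers the feature statistics without adding any extra input to the convolutions.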

Projection Discriminator

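In [5], Takeru Miyato and Masanori Koyama propose to incorporate class labels into the discriminator via projection. To see how it works, we denote the conditional discriminator as D(x,y)=σ(f(x,y)), where σ is the sigmoid function and f(x,y) is a function of x and y. Throughout this section, we write p for the data distribution and q for the generator distribution. The discriminator minimizes

$$\mathcal{L}(D) = -\mathbb{E}_{(x,y)\sim p}\big[\log D(x,y)\big] - \mathbb{E}_{(x,y)\sim q}\big[\log\big(1 - D(x,y)\big)\big]$$

We first derive the optimal discriminator by setting the derivative of this loss with respect to D(x,y) to zero

$$-\frac{p(x,y)}{D(x,y)} + \frac{q(x,y)}{1 - D(x,y)} = 0$$

Solving this equation, we get the optimal discriminator

$$D^*(x,y) = \frac{p(x,y)}{p(x,y) + q(x,y)}$$

By replacing the discriminator with σ(f(x,y)), we have

$$\frac{1}{1 + \exp\big(-f^*(x,y)\big)} = \frac{p(x,y)}{p(x,y) + q(x,y)}$$

This gives us the logits

$$f^*(x,y) = \log\frac{p(x,y)}{q(x,y)} = \log\frac{p(y \mid x)}{q(y \mid x)} + r(x), \qquad r(x) = \log\frac{p(x)}{q(x)}$$

Now we take a closer look at p(y|x), a categorical distribution usually expressed as a softmax function. Its log-linear model is

$$\log p(y = c \mid x) = {v_c^p}^\top \phi(x) - \log Z_p(\phi(x))$$

where $Z_p(\phi(x)) = \sum_j \exp\big({v_j^p}^\top \phi(x)\big)$ is the partition function; q(y|x) is modeled in the same way with $v_c^q$ and $Z_q$. The log-likelihood ratio, therefore, would take the following form:

$$\log\frac{p(y = c \mid x)}{q(y = c \mid x)} = \big(v_c^p - v_c^q\big)^\top \phi(x) - \big(\log Z_p(\phi(x)) - \log Z_q(\phi(x))\big)$$

Now, if we put $v_c = v_c^p - v_c^q$, and put the normalization constants together with r(x) into one expression ψ(ϕ(x)), we can rewrite f(x,y) as

$$f(x, y = c) = v_c^\top \phi(x) + \psi(\phi(x))$$

If we use y to denote a one-hot vector of the label and use V to denote the embedding matrix consisting of the row vectors $v_c$, we can rewrite the above model as

$$f(x, y) = y^\top V \phi(x) + \psi(\phi(x))$$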

This formulation introduces the label information via an inner product as shown in the following figure.

Projection discriminator. Source: cGANs with Projection Discriminator

Python Code
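
The final layer of a projection discriminator can be sketched in NumPy as below: the logit is the inner product between the label's embedding (a row of V) and the feature ϕ(x), plus an unconditional term ψ(ϕ(x)). This is a simplified illustration, not the code from [5] or my repo. It assumes ϕ(x) has already been computed by the convolutional body, and it takes ψ to be a single linear layer, which is an assumption made for the example.

```python
import numpy as np

def projection_logits(phi, labels, V, w_psi, b_psi=0.0):
    """Compute f(x, y) = y^T V phi(x) + psi(phi(x)) for a batch.

    phi:    (N, D) features phi(x) from the discriminator body
    labels: (N,) integer class ids
    V:      (num_classes, D) label embedding matrix
    w_psi:  (D,) weights of the linear layer used here as psi (an assumption)
    """
    psi = phi @ w_psi + b_psi               # unconditional term, shape (N,)
    # Indexing V by label is the same as multiplying by a one-hot vector y.
    proj = np.sum(V[labels] * phi, axis=1)  # inner product per sample, shape (N,)
    return proj + psi

rng = np.random.default_rng(0)
phi = rng.normal(size=(5, 4))       # 5 samples, 4-dimensional features
V = rng.normal(size=(3, 4))         # 3 classes
w_psi = rng.normal(size=4)
labels = np.array([0, 2, 1, 1, 0])
logits = projection_logits(phi, labels, V, w_psi)  # shape (5,)
```

Using an embedding lookup instead of an explicit one-hot product gives the same logits while avoiding the multiplication by a mostly-zero matrix.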

Miscellanea

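In this section, we briefly mention several other techniques adopted by SAGANs:

  • SAGANs use the hinge loss as the adversarial loss, which is defined as

    $$L_D = -\mathbb{E}_{(x,y)\sim p_{data}}\big[\min(0, -1 + D(x,y))\big] - \mathbb{E}_{z\sim p_z,\, y\sim p_{data}}\big[\min(0, -1 - D(G(z), y))\big]$$

    $$L_G = -\mathbb{E}_{z\sim p_z,\, y\sim p_{data}}\big[D(G(z), y)\big]$$

    where D here denotes the discriminator's raw (pre-sigmoid) output.
  • SAGANs use different learning rates for the generator and the discriminator, the so-called Two-Timescale Update Rule (TTUR). For ImageNet, they use 0.0004 for the discriminator and 0.0001 for the generator. In my implementation, I use 0.0001 for the discriminator and 0.00005 for the generator for the celebA dataset.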
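
The hinge version of the adversarial loss can be written in a few lines of NumPy. In this sketch, `real_logits` and `fake_logits` stand for the discriminator's raw outputs on real and generated samples; the function names are illustrative, and the identity -min(0, -1 + d) = max(0, 1 - d) is used to express the loss with `np.maximum`.

```python
import numpy as np

def d_hinge_loss(real_logits, fake_logits):
    # Penalize real samples scored below +1 and fake samples scored above -1.
    real_term = np.mean(np.maximum(0.0, 1.0 - real_logits))
    fake_term = np.mean(np.maximum(0.0, 1.0 + fake_logits))
    return real_term + fake_term

def g_hinge_loss(fake_logits):
    # The generator simply tries to raise the discriminator's score on fakes.
    return -np.mean(fake_logits)

# TTUR learning rates quoted in the text for ImageNet.
LEARNING_RATES = {"discriminator": 4e-4, "generator": 1e-4}

real_logits = np.array([2.0, 0.5])
fake_logits = np.array([-3.0, 0.0])
d_loss = d_hinge_loss(real_logits, fake_logits)  # 0.25 + 0.5 = 0.75
g_loss = g_hinge_loss(fake_logits)               # 1.5
```

Samples that the discriminator already classifies with a margin of at least 1 contribute nothing to its loss, so it stops pushing on examples it has already separated.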

END

I hope this article helps you build some intuition for SAGAN. Feel free to leave a comment if you run into anything that is wrong or unclear.

References

  1. Han Zhang et al. Self-Attention Generative Adversarial Networks. In ICML 2019.
  2. Takeru Miyato et al. Spectral Normalization for Generative Adversarial Networks. In ICLR 2018.
  3. Yuichi Yoshida et al. Spectral Norm Regularization for Improving the Generalizability of Deep Learning.
  4. Harm de Vries, Florian Strub et al. Modulating early visual processing by language.
  5. Takeru Miyato, Masanori Koyama. cGANs with Projection Discriminator.
  6. Official Code for SAGAN
  7. A detailed discussion on spectral norm by Christian Cosgrove
