Understanding Attention In Transformers
Author(s): Shashank Bhushan
Introduction
Transformers are everywhere in machine learning these days. What started as a novel architecture for sequence-to-sequence language tasks such as translation and question answering can now be found in virtually every ML domain, from computer vision to audio to recommendation. While a transformer has multiple components, the core piece is undoubtedly its use of the attention mechanism. In this post we will start by going over what attention is and how transformers use it; next, we will cover some theoretical reasoning for why it works and scales so well; and finally, we will look at some shortcomings of the original Transformer proposal and potential improvements.
What Is Attention
Suppose you are given the sentence "Today's date went well" and need to figure out what the word "date" means. The word itself has many meanings (the fruit, a calendar date, etc.), so there is no single universal meaning we can fall back on. Instead, we have to rely on the context, or in other words attend to the context, "went well", to understand that it probably refers to a romantic date. Now let's see how we could mathematically a) find these context words and b) use them to arrive at the correct word meaning.
First, we break the sentence down into words or tokens (in practical applications the tokens are generally sub-words) and then replace each word with its corresponding embedding from a pre-trained system. If you are not sure what embeddings are, just think of them as n-dimensional vectors that semantically represent a word. These representations also capture relational properties between words: for example, the vector offset between the representations of King and Queen is roughly the same as the offset between Man and Woman.
Now let's get back to attention. Given that these embeddings capture semantic information, one way to find the context words is to compute the similarity between word embeddings. Words with high similarity are likely to appear together in text, which makes them good candidates for providing contextual information. The similarity can be computed with functions such as cosine similarity or the dot product. Once we have computed the similarity of the target word with every word in the sentence (including the target word itself), we can take a weighted sum of the word embeddings, using the similarities as weights, to get an updated embedding for the target word.
If it is not clear why a weighted sum would work, think of the initial embedding of the target word "date" as an average representation that captures all of the word's different meanings; we want to move this representation in a direction that is more aligned with its current context. The embedding similarities tell us how strongly each word should influence the final direction of the embedding for "date", and the weighted sum moves the embedding in exactly that direction.
Note: The similarity weights should be normalized so that they sum to 1 before taking the weighted sum.
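To make this concrete, here is a minimal NumPy sketch of the idea. The embeddings are made up purely for illustration; the dot product is used as the similarity function and softmax as the normalization.

```python
import numpy as np

# Made-up 4-dimensional embeddings for the tokens of "Today's date went well".
# Real embeddings come from a pre-trained model and have hundreds of dimensions.
embeddings = {
    "todays": np.array([0.9, 0.1, 0.0, 0.2]),
    "date":   np.array([0.5, 0.5, 0.1, 0.1]),
    "went":   np.array([0.1, 0.8, 0.2, 0.0]),
    "well":   np.array([0.2, 0.7, 0.1, 0.1]),
}
tokens = list(embeddings)
target = embeddings["date"]

# 1) Similarity of "date" with every token (dot product), including itself.
scores = np.array([target @ embeddings[t] for t in tokens])

# 2) Normalize the similarities so they sum to 1.
weights = np.exp(scores) / np.exp(scores).sum()

# 3) Weighted sum of all embeddings = context-aware embedding for "date".
updated_date = sum(w * embeddings[t] for w, t in zip(weights, tokens))

print(dict(zip(tokens, weights.round(3))))  # how much each word contributes
print(updated_date)                         # the updated embedding for "date"
```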
How Is Attention Used in Transformers
Now we are ready to examine how attention is defined in the original Transformer paper, Attention Is All You Need.
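The scaled dot-product attention is defined as:

Attention(Q, K, V) = softmax(Q*Kᵀ / √dₖ) * V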
Here Q, K, and V are N×D matrices called the query, key, and value matrices respectively, where N is the number of tokens/words and each token is represented by a D-dimensional vector. (The √dₖ scaling factor and the optional masking are details I will ignore for simplicity.) While this may look very different from what we just described, it does the same thing. If we assume Q, K, and V are all the same N×D matrix, then Q*Kᵀ is a matrix operation that computes the similarity for every pair of tokens at once. The output is an N×N matrix whose ith row contains the similarity of the ith token with all the tokens in the sentence. Using the earlier example of "Today's date went well", the row at index 1 (assuming 0-indexed rows) would store the similarity of the word "date" with all the other words in the sentence.
The softmax then normalizes these similarity values so that each row sums to 1. Finally, the matrix multiplication of the softmax output with V computes the weighted sum of the embeddings for all tokens/words at once.
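In code, the whole equation is just a couple of matrix operations. Here is a minimal NumPy version (with the √dₖ scaling, without masking):

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    """Scaled dot-product attention. Q, K, V are (N, D) arrays."""
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)        # (N, N) pairwise similarities
    weights = softmax(scores, axis=-1)   # each row sums to 1
    return weights @ V                   # (N, D) weighted sums of the value vectors

# With Q = K = V = the token embedding matrix, this reproduces the
# "similarity + weighted sum" computation from the previous section.
X = np.random.randn(4, 8)   # 4 tokens, 8-dimensional embeddings
out = attention(X, X, X)    # shape (4, 8)
```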
So why are there 3 different matrices and what's their significance?
The query, key, and value matrices are transformed versions of the input embedding matrix. Before being passed to the attention mechanism, the input embeddings go through three separate linear projections; the reason for doing this is to give the attention mechanism more learning capacity. The terminology is borrowed from database systems: Q (queries) represents the items for which we want to compute attention, K (keys) represents the items over which attention is computed (like keys in a database), and V (values) holds the content that is combined to produce the output.
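A sketch of these projections, reusing attention() from the snippet above and with random matrices standing in for learned projection weights:

```python
import numpy as np

N, E = 4, 8                   # 4 tokens, embedding size 8
X = np.random.randn(N, E)     # input embedding matrix

# In a trained model W_q, W_k, W_v are learned; they are random here purely
# for illustration.
W_q, W_k, W_v = (np.random.randn(E, E) for _ in range(3))

Q, K, V = X @ W_q, X @ W_k, X @ W_v   # three different "views" of the same input
out = attention(Q, K, V)              # same attention() as before, shape (4, 8)
```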
Multi-Head Attention
Each transformer block has multiple attention mechanisms, or heads, running in parallel. The outputs of the individual heads are concatenated together and then passed through a linear layer to generate the final output. The sketch after the list below shows the overall setup.
There are two important things to call out about the multi-head setup:
- Each attention head has its own projection matrices for creating its Q, K, and V matrices. This allows each head to focus on different properties of the input.
- The per-token dimension of each head's Q, K, and V matrices, call it d_head, is smaller than the original embedding dimension E, with d_head = E/n where n is the number of attention heads. This is why the outputs of the heads are concatenated: concatenation brings the output dimension back to E. The point of this split is to keep multi-head attention roughly computationally equivalent to a single attention head operating on embeddings of size E.
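A minimal sketch of the multi-head setup, again with random (untrained) weights and reusing attention() from above:

```python
import numpy as np

def multi_head_attention(X, n_heads=2):
    """Illustrative multi-head attention: no masking, random weights."""
    N, E = X.shape
    d_head = E // n_heads                  # per-head dimension, E / n
    head_outputs = []
    for _ in range(n_heads):
        # Each head gets its own projections from E down to d_head.
        W_q, W_k, W_v = (np.random.randn(E, d_head) for _ in range(3))
        head_outputs.append(attention(X @ W_q, X @ W_k, X @ W_v))  # (N, d_head)
    concat = np.concatenate(head_outputs, axis=-1)  # back to (N, E)
    W_o = np.random.randn(E, E)                     # final linear layer
    return concat @ W_o

out = multi_head_attention(np.random.randn(4, 8), n_heads=2)  # shape (4, 8)
```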
What Makes Transformers Powerful
Unlike a convolutional or recurrent neural network, a Transformer, or rather its attention setup, does not bake any assumptions about the problem into the architecture; the relationships are instead learned from the data. A transformer can mimic both a convolutional neural network and an RNN. This lack of assumptions, i.e. lack of inductive bias, is what makes transformers so powerful.
This, however, also means that transformers need a lot of data to generalize well. In the first Vision Transformer paper, "An Image is Worth 16×16 Words", the authors noted that transformers only started to outperform CNN-based architectures when trained on very large datasets such as JFT-300M, which has roughly 300M labeled images.
Another way to understand the statement "transformers lack inductive bias" is to think of the attention mechanism as a fully connected neural network whose input is V and whose weights are the output of the softmax(Q*Kᵀ) operation. The dynamic, input-dependent nature of these weights is what gives transformers their descriptive power.
Making Transformers More Efficient
One of the key drawbacks of the transformer architecture is its time complexity. In the attention mechanism, each token needs to attend to every token in the sequence, making the operation quadratic with respect to the input sequence length. While researchers are approaching this problem in many different ways, a few directions that stand out to me are:
Pattern-Based Methods
These methods simplify the self-attention matrix by limiting the field of view to either fixed/predefined or learnable patterns. Some examples of pattern-based methods are:
- Band or Block Attention: This pattern relies on the fact that most data has a strong locality property, so it is natural to restrict each query to attend only to its neighboring tokens. Models such as Blockwise (Qiu et al.) use this approach; a minimal masking sketch follows below.
- Strided Attention: This pattern is similar to band/block attention, except that the receptive field (the attention field) is larger, with gaps in between. Models such as the Sparse Transformer (Child et al.) employ this method.
- Learnable Patterns: Instead of relying on fixed, predetermined patterns, the patterns are learned from the data. An example is the Reformer (Kitaev et al., 2020), which uses a hash-based similarity measure to efficiently cluster tokens in the Q and K matrices and then limits attention for a given token to the tokens in its cluster.
These pattern-based approaches add inductive bias to the model architecture to make it more efficient.
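As an illustration of the band idea, here is a rough sketch of banded attention in which each token may only attend to tokens within a small window around itself. This is a generic illustration, not any specific paper's implementation; it reuses softmax() from earlier and, unlike a real efficient implementation, still materializes the full N×N score matrix:

```python
import numpy as np

def banded_attention(Q, K, V, window=1):
    """Attention restricted to a band: token i only attends to tokens j
    with |i - j| <= window."""
    N, d = Q.shape
    scores = Q @ K.T / np.sqrt(d)
    idx = np.arange(N)
    band = np.abs(idx[:, None] - idx[None, :]) <= window   # True inside the band
    scores = np.where(band, scores, -np.inf)               # block everything else
    weights = softmax(scores, axis=-1)                     # rows still sum to 1
    return weights @ V
```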
Low-Rank Methods
Low-rank methods are another way to improve efficiency. The key idea is to assume that there exists a low-rank approximation of the self-attention matrix. Linformer (Wang et al., 2020) is an early example of this technique. It first empirically shows that self-attention is low rank, and then exploits this by projecting the key and value matrices (each N×D) down to k×D matrices, where k ≪ N.
They show that k can be chosen as a constant, independent of the sequence length, thereby making the runtime complexity linear with respect to the input sequence length N.
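A rough sketch of the low-rank idea (not Linformer's exact implementation: the projection here is random rather than learned, and softmax() is reused from earlier):

```python
import numpy as np

def low_rank_attention(Q, K, V, k=32):
    """Project K and V along the sequence dimension from (N, D) down to (k, D)
    before attention, so the score matrix is (N, k) instead of (N, N)."""
    N, D = K.shape
    E_proj = np.random.randn(k, N) / np.sqrt(N)   # stand-in for a learned projection
    K_low, V_low = E_proj @ K, E_proj @ V         # both (k, D)
    scores = Q @ K_low.T / np.sqrt(D)             # (N, k): linear in N for fixed k
    weights = softmax(scores, axis=-1)
    return weights @ V_low                        # (N, D)
```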
That's all for this post, thanks for reading!
References
- Vaswani et al. Attention Is All You Need
- Dosovitskiy et al. An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale
- Tay et al. Efficient Transformers: A Survey
- Lin et al. A Survey of Transformers
- Child et al. Generating Long Sequences with Sparse Transformers
- Qiu et al. Blockwise Self-Attention for Long Document Understanding
- Kitaev et al. Reformer: The Efficient Transformer
- Wang et al. Linformer: Self-Attention with Linear Complexity