
Self-Attention in Transformers: Computation Logic and Implementation

Last Updated on May 9, 2024 by Editorial Team

Author(s): Anthony Demeusy

Originally published on Towards AI.

Self-attention untangles the relationships between tokens in deep learning

Attention serves as a fundamental concept for transformer architecture and for Large Language Models, playing a pivotal role in capturing dependencies between different words in a sequence. It intervenes in several building blocks of the Transformer architecture, more specifically, the multi-head self-attention, cross-attention, and masked attention stages.

Attention-based stages in the Transformer architecture, based on Attention Is All You Need, Vaswani et al.,
arXiv:1706.03762

While numerous resources delve into the interpretation of attention, self-attention, and their variations, the intricacies of their calculation often remain opaque, assuming substantial prior knowledge of related works and publications.

This article will first endeavor to demystify these aspects by providing a comprehensive explanation of the calculation logic applied in a self-attention head from the ground up. The focus will then shift to detailing how the computation is optimized in practical scenarios, all this assuming only a foundational understanding of matrix multiplication and the vector dot product (which — by the way — is essentially a particular case of matrix multiplication).

By untangling the intricacies of self-attention calculation and optimization, this article aims to make the topic accessible to a wide audience and to empower readers with a clear and practical understanding of this essential component of the transformer models.

Calculation logic

In the transformer architecture, as in many tasks related to Natural Language Processing, the input text is transformed into a series of vectors by the embedding stage. These vectors are called tokens. Each token roughly represents one word and, in the case of transformers, also includes some information about its position in the input, gained at the positional encoding stage.

Once the input text is embedded, calculating the self-attention can be envisioned as:

  • Computing attention weights reflecting how each input token relates to the other ones.
  • For each input token, computing the sum of vectors derived from all input tokens, weighted by that token's attention weights.

In other terms, we use attention weights to determine how much each input vector should contribute to the output vector associated with each input token, effectively generating linear combinations of the input tokens as output.

Obtaining attention weights

The first step to obtain the attention weight for one input token i relative to another input token j is to multiply the input token i by a matrix Wq and the input token j by a matrix Wk. Both matrices are optimized during the training of the transformer. These two multiplications generate two vectors of the same length, respectively named the query vector — for the transformed token i — and the key vector — for the transformed token j.

Obtaining one query vector and one key vector

Conceptually, the role of these operations is to transform the input tokens into two new vectors, the query vector and the key vector, in a way that allows the model to capture semantic information, better measure the similarity between the two inputs, and appreciate how strongly they are associated with each other.
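As a minimal numpy sketch of this step — assuming tokens are row vectors, and using random stand-ins for the trained matrices Wq and Wk, whose dimensions here are purely illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)

d_model, d_k = 4, 3                  # toy embedding and query/key sizes
x_i = rng.normal(size=d_model)       # embedding of input token i (hypothetical)
x_j = rng.normal(size=d_model)       # embedding of input token j

# Wq and Wk stand in for the projection matrices learned during training
Wq = rng.normal(size=(d_model, d_k))
Wk = rng.normal(size=(d_model, d_k))

q_i = x_i @ Wq                       # query vector for token i
k_j = x_j @ Wk                       # key vector for token j
```

Both resulting vectors have the same length d_k, which is what makes their dot product well-defined in the next step.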

From here, a raw attention weight for the input i relative to the input j can be obtained by computing a similarity metric between the query vector for i and the key vector for j. In the transformer architecture described in the original paper, Attention Is All You Need, this similarity metric is the dot product of the query vector and the key vector. The result of this dot product is a number that depends on the input tokens i and j and on the two matrices Wq and Wk.

Computing a single raw weight

Attention weights for the input token i relative to all input tokens, including itself, can be obtained by applying the same principle.

All raw weights for a single query vector

All the raw weights for an input token i are then scaled by dividing them by the square root of the query and key vectors' length, and by applying a softmax function so that all the weights add up to 1. The division is a nuance introduced in Attention Is All You Need, as the team suspected that “for large values of dk, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients”.
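The scaling and softmax for a single token can be sketched as follows — again with random stand-ins for the inputs and trained matrices, and a hand-rolled softmax for clarity:

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_model, d_k = 5, 4, 3
X = rng.normal(size=(n, d_model))   # one row per input token (hypothetical)
Wq = rng.normal(size=(d_model, d_k))
Wk = rng.normal(size=(d_model, d_k))

q_i = X[0] @ Wq                     # query vector for token i = 0
keys = X @ Wk                       # key vectors for every token

raw = keys @ q_i                    # raw weights: dot product of q_i with each key
scaled = raw / np.sqrt(d_k)         # divide by sqrt of the query/key length
weights = np.exp(scaled) / np.exp(scaled).sum()   # softmax: weights sum to 1
```

After the softmax, the n weights are all positive and sum to 1, so they can be read as the share each input token contributes to token i's output.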

Scaling the attention weights

The same process can then be used to compute the series of attention weights for all input tokens.

Iterating to get all attention weights

Obtaining value vectors and context vectors

Attention weights determine how much each input token should contribute to the output vectors. However, before performing this weighted sum, the input tokens are transformed by multiplying them by a matrix Wv, which is — again — optimized during the training process. This operation can be viewed as a fine-tuning step aiming at extracting features from the input vectors. The resulting vectors are called ‘value vectors’.

Value vectors

Once we have these value vectors, we compute the i-th output vector, also referred to as a context vector, as the sum of all value vectors weighted by the attention weights for this i-th input token.
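This weighted sum reduces to a single vector-matrix product. A minimal sketch, using uniform weights as a stand-in for attention weights already computed in the previous step:

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_model, d_v = 5, 4, 3
X = rng.normal(size=(n, d_model))   # one row per input token (hypothetical)
Wv = rng.normal(size=(d_model, d_v))

values = X @ Wv                     # one value vector per input token

# Attention weights for token i (uniform stand-in; they must sum to 1)
weights = np.full(n, 1.0 / n)
context_i = weights @ values        # weighted sum of all value vectors
```

With uniform weights the context vector is simply the mean of the value vectors; real attention weights shift this average toward the most relevant tokens.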

Computing a context vector

Another way to understand the construction of context vectors is by expressing them as follows:

Finally, the same process can be used for each input to obtain the different context vectors.

Iterative calculation of context vectors

At this point, we can display the complete calculation logic by unwrapping the weight calculation stage, showing weight calculation and context vector calculation in parallel.

Complete calculation logic of a self-attention module
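The complete token-by-token logic can be sketched as one function — a deliberately naive loop that mirrors the steps above rather than an optimized implementation (the matrix names and sizes are illustrative):

```python
import numpy as np

def self_attention(X, Wq, Wk, Wv):
    """Token-by-token self-attention, following the calculation logic above."""
    n = X.shape[0]
    d_k = Wq.shape[1]
    values = X @ Wv                                # value vector for every token
    contexts = []
    for i in range(n):
        q_i = X[i] @ Wq                            # query vector for token i
        # Raw weights: dot product of q_i with every key vector
        raw = np.array([q_i @ (X[j] @ Wk) for j in range(n)])
        scaled = raw / np.sqrt(d_k)                # scale by sqrt(d_k)
        w = np.exp(scaled) / np.exp(scaled).sum()  # softmax over the weights
        contexts.append(w @ values)                # weighted sum of value vectors
    return np.vstack(contexts)

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 4))
Wq, Wk, Wv = (rng.normal(size=(4, 3)) for _ in range(3))
C = self_attention(X, Wq, Wk, Wv)
```

The output has one context vector per input token, with the value-vector length as its width.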

Practical implementation

While remaining aligned with the principles explained above, in practice there are ways to improve computational efficiency. In particular, matrices are used to achieve faster calculations. Understanding how this is achieved can also shed some light on notations and formulas you may find in the literature.

So far, in order to obtain the attention weights, each token has been represented as a vector, and each of these vectors was multiplied by a matrix Wq to generate query vectors, by a matrix Wk to generate key vectors, and by a matrix Wv to generate value vectors. In practice, an equivalent way of obtaining the same results is to stack all input vectors in a single matrix and multiply this matrix by Wq, Wk, and Wv to obtain matrices containing all query, key, and value vectors at once. These matrices are called the Query matrix, the Key matrix, and the Value matrix, respectively noted Q, K, and V.
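The equivalence is easy to check in numpy — stacking the tokens as rows of X makes each row of Q the query vector of the corresponding token (sizes and matrices below are illustrative stand-ins):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_model, d_k = 5, 4, 3
X = rng.normal(size=(n, d_model))   # all input tokens stacked as rows
Wq, Wk, Wv = (rng.normal(size=(d_model, d_k)) for _ in range(3))

Q = X @ Wq   # all query vectors at once, one per row
K = X @ Wk   # all key vectors at once
V = X @ Wv   # all value vectors at once

# Row i of Q equals the query vector of token i computed individually
assert np.allclose(Q[2], X[2] @ Wq)
```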

Two equivalent workflows to obtain the Query, Key and Value matrices

Then, you can observe that multiplying two matrices can be envisioned as calculating the dot products of the rows of the first matrix with the columns of the second matrix.

Matrix multiplication as dot-products
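A tiny example makes this concrete — entry (i, j) of a matrix product is exactly the dot product of row i of the first factor with column j of the second:

```python
import numpy as np

A = np.array([[1.0, 2.0],
              [3.0, 4.0]])
B = np.array([[5.0, 6.0],
              [7.0, 8.0]])

# Entry (0, 1) of A @ B is the dot product of row 0 of A with column 1 of B
assert (A @ B)[0, 1] == A[0] @ B[:, 1]
```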

Therefore, instead of calculating the dot products of all query vectors and key vectors one by one, the raw/unscaled weights are computed by multiplying the matrix Q — whose rows are the query vectors — by the transpose of K — whose columns are the key vectors.

At this point, all values in the matrix are scaled by dividing them by the square root of the query and key vectors' length, and the softmax function is applied row-wise to obtain the scaled attention weights, leading to the formula:
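With the row-stacked convention used above, and writing d_k for the length of the query and key vectors, the scaled attention weight matrix can be written as:

```latex
A = \operatorname{softmax}\!\left( \frac{Q K^{\top}}{\sqrt{d_k}} \right)
```

where the softmax is applied independently to each row, so that every row of A sums to 1.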

Now looking at the computation of the context vectors, computing a weighted sum of all value vectors can also be represented as the product of the vector containing the sequence of attention weights and the value matrix.

This calculation has to be carried out for each sequence of scaled weights. Fortunately, we can keep using matrix multiplication here as well, simply by stacking both the weight sequences and the context vectors in matrices. Below is an illustration of this principle.
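A short numpy sketch of this stacking — using a random row-stochastic matrix as a stand-in for the scaled attention weights:

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_v = 4, 3
A = rng.random((n, n))
A /= A.sum(axis=1, keepdims=True)   # each row: scaled attention weights, summing to 1
V = rng.normal(size=(n, d_v))       # value matrix, one value vector per row

C = A @ V                           # every context vector in a single multiplication
```

Row i of C is exactly the weighted sum of the value vectors with the weights from row i of A, so the per-token loop collapses into one matrix product.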

By injecting the expression for the self-attention matrix, we obtain another formula you may find in the literature:
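Substituting the expression for the scaled attention weights into the context-vector product gives the well-known form from Attention Is All You Need:

```latex
\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\!\left( \frac{Q K^{\top}}{\sqrt{d_k}} \right) V
```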

Conclusion

Since it was first introduced by Bahdanau, Cho, and Bengio [2] in 2014, the concept of attention in neural networks has been critical to advancements in deep learning and the development of generative AI.

By detailing the calculation logic pertaining to the self-attention mechanism, this article facilitates a deep comprehension of its governing principles. It also bridges the gap between the calculation logic and the practical computation, and serves as a stepping stone to variations such as masked/causal attention and cross-attention, which underpin the development of cutting-edge models such as Large Language Models and vision transformers.

If you found this article helpful, please show your support by clapping for this article and considering subscribing for more articles on machine learning and data analysis. Your engagement and feedback are highly valued as they play a crucial role in the continued delivery of high-quality content.

You can also support my work by buying me a coffee. Your support helps me continue to create and share informative content. It’s a simple and appreciated gesture that keeps the momentum going : Buy Me a Coffee.

References & resources

[1] Ashish Vaswani, et al. “Attention is all you need.” NIPS 2017.

[2] Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. “Neural machine translation by jointly learning to align and translate.” ICLR 2015.


