Advanced Attention Mechanisms — II
Last Updated on November 13, 2024 by Editorial Team
Author(s): Arion Das
Originally published on Towards AI.
Flash Attention.
You can refer to its predecessors here: KV cache, sliding window attention, MHA, MQA, uptraining, & GQA. These methods were employed to bring down memory and compute requirements, but at the cost of some quality degradation. We’ll look into some ways of fixing that here.
WHY?
We have to understand the current workings of the GPU to get the intuition behind flash attention.
Observe how GPU SRAM is far faster than HBM but much smaller: on an A100, roughly 19 TB/s of bandwidth and ~20 MB of SRAM versus ~1.5 TB/s and 40 GB of HBM. Now, here’s a quick look at what we’ve been doing so far:
1) Matrix multiplication (Q, K)
a) read Q, K to SRAM (1)
b) compute A = Q x K
c) write A to HBM (2)
2) Mask
a) read A to SRAM (3)
b) mask A into A’
c) write A’ to HBM (4)
3) SoftMax
a) read A’ to SRAM (5)
b) SoftMax A’ to A’’
c) write A’’ to HBM (6)
Observe how many HBM read-write operations are performed for such simple computations. This is not ideal; the available hardware is not being used optimally.
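To make these round trips concrete, here is a minimal PyTorch sketch of the naive pipeline (the sizes and tensor names are illustrative, not from the paper). Each intermediate, A, A’, and A’’, is a full n x n matrix that a real GPU kernel would ship back to HBM between steps.

```python
import torch

n, d_k = 1024, 64                       # sequence length, head dimension (illustrative)
Q = torch.randn(n, d_k)
K = torch.randn(n, d_k)
V = torch.randn(n, d_k)

# 1) matrix multiplication: A = Q x K^T -> a full n x n score matrix (written to HBM)
A = Q @ K.T / d_k ** 0.5

# 2) mask: A -> A' (causal mask), another full n x n intermediate
causal = torch.tril(torch.ones(n, n, dtype=torch.bool))
A_masked = A.masked_fill(~causal, float("-inf"))

# 3) softmax: A' -> A'', yet another n x n intermediate read from and written to HBM
A_soft = torch.softmax(A_masked, dim=-1)

out = A_soft @ V                        # final weighted sum over the values
```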
FLASH ATTENTION
|| paper ||
The paper identifies a crucial missing ingredient: making the attention mechanism I/O-aware.
Their main aim was to avoid reading and writing the attention matrix to and from HBM. This required:
1) computing the SoftMax reduction without access to the whole input.
2) not storing the large intermediate attention matrix for the backward pass.
Now, our operation looks like this:
1) Read Q,K to SRAM
2) Compute A = QxK
3) Mask A into A’
4) SoftMax A’ to A’’
5) Write A’’ to HBM
To put it into perspective, FlashAttention on GPT-2 results in a 7.6x speedup on the attention computation.
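In practice you rarely have to hand-write the fused kernel. As a hedged example, recent PyTorch releases expose torch.nn.functional.scaled_dot_product_attention, which can dispatch to a FlashAttention-style fused kernel when the hardware, dtype, and build support it (the sizes below are illustrative, and a CUDA GPU is assumed):

```python
import torch
import torch.nn.functional as F

# batch, heads, sequence length, head dimension (illustrative sizes)
B, H, n, d = 1, 8, 1024, 64
q = torch.randn(B, H, n, d, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, n, d, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, n, d, device="cuda", dtype=torch.float16)

# One fused call: scoring, masking, softmax, and the value product all happen
# inside a single kernel, so the n x n attention matrix never touches HBM.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```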
TILING
Notice that SRAM capacity is only about 1/2000th of HBM capacity (roughly 20 MB versus 40 GB). Hence, the authors propose a workaround to make the operation computable in this small but fast memory space.
The aim is to break the query, key, and value matrices into tiles that fit into SRAM, allowing the GPU's many streaming multiprocessors to process them in parallel. Effectively, instead of calculating attention for the entire sequence in one go, the algorithm processes each tile sequentially or, preferably, in parallel (because this is where the real speedup comes from).
This can be done in the following steps:
Step 1: Chunking Query, Key & Value Tensors
Given a sequence of length n and a tile size t, the tensors are broken into chunks such that (see the sketch after this list):
Q : tiles of size t x d_k
K : tiles of size t x d_k
V : tiles of size t x d_v
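A minimal sketch of the chunking, assuming a tile size t chosen so that a tile of Q, K, and V fits in SRAM (the sizes are illustrative):

```python
import torch

n, d_k, d_v, t = 1024, 64, 64, 128      # sequence length, head dims, tile size (illustrative)
Q = torch.randn(n, d_k)
K = torch.randn(n, d_k)
V = torch.randn(n, d_v)

Q_tiles = Q.split(t, dim=0)             # each tile: t x d_k
K_tiles = K.split(t, dim=0)             # each tile: t x d_k
V_tiles = V.split(t, dim=0)             # each tile: t x d_v
```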
Step 2: Computing Partial Attention within Each Tile
For each pair of Q and K tiles, FlashAttention computes the scaled dot-product attention:
Compute Q x K^T for the tile: gives an attention score matrix of size t x t.
Apply SoftMax Normalization: to ensure the tile's partial attention scores are normalized within the tile context.
Multiply with V Tile: the softmaxed scores for the tile are then used to weight the corresponding V tile.
Step 3: Accumulate the results
The partial results from each tile are accumulated tile by tile, ensuring the final output is equivalent to full softmax(QK^T)V attention but with much lower memory requirements.
This reduces memory overhead, makes better use of memory bandwidth, and enables parallel processing. Refer to Christian Mills’ blog for the CUDA C++ implementation details.
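Before diving into CUDA, here is a hedged, framework-level sketch of Steps 2 and 3 for a single attention head in plain PyTorch. The function name and tile size are mine; it is not the paper's kernel, just the same idea expressed at tensor level, and it leans on the running-max and running-denominator correction explained in the next section (masking is omitted for brevity):

```python
import torch

def tiled_attention(Q, K, V, t=128):
    """Reference sketch of tiled attention for one head (not the CUDA kernel).

    Loops over K/V tiles for each query tile, keeping only t x t scores in
    memory and folding every tile into a running (max, denominator, output)
    state, so the full n x n attention matrix is never materialized.
    """
    n, d_k = Q.shape
    out = torch.empty(n, V.shape[1])

    for q_start in range(0, n, t):
        q_tile = Q[q_start:q_start + t] * d_k ** -0.5           # t x d_k, pre-scaled
        m = torch.full((q_tile.shape[0], 1), float("-inf"))     # running row max
        denom = torch.zeros(q_tile.shape[0], 1)                 # running softmax denominator
        acc = torch.zeros(q_tile.shape[0], V.shape[1])          # running weighted sum of V

        for k_start in range(0, n, t):
            s = q_tile @ K[k_start:k_start + t].T               # t x t partial scores
            m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
            correction = torch.exp(m - m_new)                   # rescale old state to the new max
            p = torch.exp(s - m_new)                            # stabilized partial numerators
            denom = denom * correction + p.sum(dim=-1, keepdim=True)
            acc = acc * correction + p @ V[k_start:k_start + t]
            m = m_new

        out[q_start:q_start + t] = acc / denom
    return out
```

On random inputs this matches torch.softmax(Q @ K.T / d_k ** 0.5, dim=-1) @ V up to floating-point error, even though no n x n matrix is ever built.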
SOFTMAX DENOMINATOR OPTIMIZATION
Stabilizing SoftMax
Looking at the softmax formula, we find
p_i = exp(l_i) / sum_j(exp(l_j)),
where p’s are probabilities and l’s are logits.
Exponentials of large logits can overflow (and those of very negative logits can underflow), so the authors stabilize the softmax by subtracting the maximum logit (m) from all the logits, ensuring each numerator lies within (0, 1].
p_i = exp(l_i - m) / sum_j(exp(l_j - m)).
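A quick numeric check of the trick (the logit values are made up purely to force an overflow):

```python
import torch

logits = torch.tensor([1000.0, 1001.0, 1002.0])

naive = torch.exp(logits) / torch.exp(logits).sum()           # exp overflows: all nan
stable = torch.softmax(logits, dim=0)                         # subtracts the max internally

m = logits.max()
manual = torch.exp(logits - m) / torch.exp(logits - m).sum()  # same result as `stable`

print(naive)   # tensor([nan, nan, nan])
print(stable)  # tensor([0.0900, 0.2447, 0.6652])
print(manual)  # tensor([0.0900, 0.2447, 0.6652])
```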
Incremental / Online SoftMax
I’ll explain this using an analogy.
Imagine a new startup. It promises high revenue and projects profits. Unfortunately, in the first few years it faces losses and can’t generate the desired revenue. No profit. However, in the long term, it performs incredibly well and generates high profits just as it promised. (Yes, like Amazon.)
Similarly, here we don’t go to the trouble of computing the entire softmax denominator for every block/tile. Instead, we compute the softmax using only the block seen so far. As an example, take the block with vectors x1 & x2 and compute its softmax with a denominator that sums over just those two entries.
Is this correct?
No, if we consider the softmax for the entire attention computation.
Yes, if we consider just the block/tile at i = 2.
Now, have a look at how the calculation evolves up to the last timestep: each new block rescales the running maximum and denominator to account for the logits it contributes.
The softmax value at every timestep other than the last one is incomplete, but we eventually reach the exact value at the final timestep. I think this is a great workaround proposed by the researchers; it helps reduce the memory requirements.
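A hedged one-dimensional demo of the same idea: process the logits two at a time, keep only a running max and running denominator, and check that the final result matches the ordinary softmax even though every intermediate state was 'incomplete':

```python
import torch

logits = torch.tensor([2.0, 5.0, 1.0, 3.0, 4.0, 0.5])   # made-up logits
blocks = logits.split(2)                                 # process two logits per block

m = torch.tensor(float("-inf"))                          # running max
denom = torch.tensor(0.0)                                # running softmax denominator

for block in blocks:
    m_new = torch.maximum(m, block.max())
    # rescale what has been accumulated so far to the new max, then add the block
    denom = denom * torch.exp(m - m_new) + torch.exp(block - m_new).sum()
    m = m_new

online = torch.exp(logits - m) / denom                   # uses the converged max/denominator
full = torch.softmax(logits, dim=0)

print(torch.allclose(online, full))                      # True
```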
I tried annotating the various phases of this incredible algorithm; still, I would highly recommend going through the implementation in the paper to build the intuition.
Good luck understanding it in one go! If you’re willing to discuss the intricate details, shoot me a mail here: [email protected]
This concludes the advanced attention mechanisms (as of November 12, 2024). I’ll add further articles to this series should more attention mechanisms come up.