
Advanced Attention Mechanisms — II

Last Updated on November 13, 2024 by Editorial Team

Author(s): Arion Das

Originally published on Towards AI.

flash attention (from source)

Flash Attention.
You can refer to its predecessors here: KV cache, sliding window attention, MHA, MQA, uptraining, & GQA. These methods were employed to bring down memory and compute requirements, but at the cost of quality degradation. We'll look into some ways of fixing that here.

WHY?

We have to understand how the GPU currently works to get the intuition behind flash attention.

GPU memory segments & their operating speeds (image by author)

Observe how GPU SRAM is far faster than HBM. Now, here's a quick look at what standard attention has been doing so far:

1) Matrix multiplication (Q,K)
a) read Q, K to SRAM (1)
b) compute A = QxK
c) write A to HBM (2)

2) Mask
a) read A to SRAM (3)
b) mask A into A’
c) write A’ to HBM (4)

3) SoftMax
a) read A’ to SRAM (5)
b) SoftMax A’ to A’’
c) write A’’ to HBM (6)

Observe how many read/write operations are performed for such simple computations. This is not ideal; the available hardware is not being used optimally.
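As a point of reference, here is a minimal PyTorch sketch of this naive pattern (the function and tensor names are illustrative); each intermediate (the score matrix, the masked matrix, and the probabilities) is materialized as a full n x n tensor, which on a GPU round-trips through HBM:

```python
import torch
import torch.nn.functional as F

def naive_attention(Q, K, V, causal_mask):
    # 1) A = Q x K^T: the full (n x n) score matrix is materialized
    A = (Q @ K.transpose(-2, -1)) / (Q.size(-1) ** 0.5)
    # 2) Mask A into A'
    A_masked = A.masked_fill(~causal_mask, float("-inf"))
    # 3) SoftMax A' into A''
    A_prob = F.softmax(A_masked, dim=-1)
    # Final weighting with V
    return A_prob @ V

n, d = 1024, 64
Q, K, V = torch.randn(n, d), torch.randn(n, d), torch.randn(n, d)
mask = torch.tril(torch.ones(n, n, dtype=torch.bool))
out = naive_attention(Q, K, V, mask)  # (n, d)
```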

FLASH ATTENTION

|| paper ||

The paper mentions a crucial missing component—making the attention mechanism I/O aware.
Their main aim was to avoid reading and writing the attention matrix to and from HBM. This required:

1) computing the SoftMax reduction without access to the whole input.
2) not storing the large intermediate attention matrix for the backward pass.

Now, the operation looks like this:

1) Read Q,K to SRAM
2) Compute A = QxK
3) Mask A into A’
4) SoftMax A’ to A’’
5) Write A’’ to HBM
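In practice you rarely hand-write this fused kernel. As a hedged example, recent PyTorch releases (2.x) expose torch.nn.functional.scaled_dot_product_attention, which can dispatch to a FlashAttention-style fused kernel on supported GPUs, so the masking, softmax, and value weighting happen without writing the n x n matrix back to HBM:

```python
import torch
import torch.nn.functional as F

# Shapes: (batch, heads, seq_len, head_dim); half precision on a CUDA device
# is what typically enables the fused/flash backend (backend choice is up to PyTorch).
q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)

# Masking, softmax, and the product with V are fused into one kernel call.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```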

putting flash attention into perspective (from source)

To put it into perspective, FlashAttention yields a 7.6x speedup on the attention computation of GPT-2.

TILING

Notice how SRAM is roughly 1/2000th the size of HBM. Hence, the authors propose a workaround to make the operation computable in this small yet fast memory space.

The aim is to break the attention computation down into tiles small enough to fit into SRAM, allowing the GPU's many compute units to process them in parallel. Effectively, instead of calculating attention for the entire sequence in one go, the algorithm processes each tile sequentially or, preferably, in parallel (because this is where the real optimization lies).
This can be done in the following steps:

Step 1: Chunking Query, Key & Value Tensors

Given a sequence of length n & each tile being of size t, the sequence is broken into chunks such that:
Q : tiles of size t x d_k
K : tiles of size t x d_k
V : tiles of size t x d_v
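A minimal sketch of this chunking step in PyTorch (the tile size t and all tensor names here are illustrative, not the paper's notation):

```python
import torch

n, t = 1024, 128      # sequence length and tile size (illustrative)
d_k, d_v = 64, 64     # key/query and value head dimensions

Q = torch.randn(n, d_k)
K = torch.randn(n, d_k)
V = torch.randn(n, d_v)

# Split along the sequence dimension into tiles small enough to fit in SRAM.
Q_tiles = Q.split(t, dim=0)   # each tile: (t, d_k)
K_tiles = K.split(t, dim=0)   # each tile: (t, d_k)
V_tiles = V.split(t, dim=0)   # each tile: (t, d_v)
```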

Step 2: Computing Partial Attention within Each Tile

For each pair of Q and K tiles, FlashAttention computes the scaled dot-product attention:
Compute Q x K^T for the tile: gives an attention score matrix of size t x t.
Apply SoftMax Normalization: to ensure the tile's partial attention scores are normalized within the tile context.
Multiply with V Tile: softmaxed scores for the tile are then used to weigh the corresponding V tile.

Step 3: Accumulate the results
Attention scores from each tile are accumulated tile-by-tile, ensuring the final result is equivalent to full attention, softmax(QK^T)V, but with much lower memory requirements.
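Here is a simplified pure-PyTorch sketch of Steps 2 and 3 for a single query tile (no custom CUDA, non-causal; the running-max rescaling it relies on is the online softmax trick covered in the next section):

```python
import torch

def tiled_attention_row(Q_tile, K_tiles, V_tiles):
    """Attention output for one query tile, accumulated tile-by-tile over K/V."""
    d_k = Q_tile.size(-1)
    scale = d_k ** -0.5
    out = torch.zeros(Q_tile.size(0), V_tiles[0].size(-1))
    row_max = torch.full((Q_tile.size(0), 1), float("-inf"))  # running max of scores
    row_sum = torch.zeros(Q_tile.size(0), 1)                  # running softmax denominator

    for K_tile, V_tile in zip(K_tiles, V_tiles):
        S = (Q_tile @ K_tile.T) * scale                       # (t, t) partial scores
        new_max = torch.maximum(row_max, S.max(dim=-1, keepdim=True).values)
        correction = torch.exp(row_max - new_max)             # rescale old accumulators
        P = torch.exp(S - new_max)                            # unnormalized tile softmax
        row_sum = row_sum * correction + P.sum(dim=-1, keepdim=True)
        out = out * correction + P @ V_tile
        row_max = new_max

    return out / row_sum                                      # normalize once, at the end
```

Running this for every query tile and concatenating the outputs matches the naive attention from earlier (without the causal mask), which is an easy way to sanity-check the accumulation.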

Tiling in flash attention (image by author)

This reduces memory overhead, optimizes bandwidth usage, & enables parallel processing. Refer to Christian Mill's blog for the CUDA C++ implementation details.

SOFTMAX DENOMINATOR OPTIMIZATION

Stabilizing SoftMax

Looking at the softmax formula, we find
p_i = exp(l_i) / sum_j(exp(l_j)),
where p’s are probabilities and l’s are logits.

Exponential values might explode in either direction, so the authors stabilize the softmax operation by subtracting the maximum logit (m) from all the logits, ensuring each exponentiated term lies within (0, 1].
p_i = exp(l_i - m) / sum_j(exp(l_j - m)).
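A quick numerical sketch of why the subtraction matters (the values are illustrative):

```python
import torch

logits = torch.tensor([1000.0, 999.0, 998.0])

naive = torch.exp(logits) / torch.exp(logits).sum()            # exp(1000) overflows -> nan
m = logits.max()
stable = torch.exp(logits - m) / torch.exp(logits - m).sum()   # well-defined

print(naive)   # tensor([nan, nan, nan])
print(stable)  # approximately tensor([0.6652, 0.2447, 0.0900])
```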

Stabilized SoftMax (image by author)

Incremental / Online SoftMax

I’ll explain this using an analogy.

Imagine a new startup. It promises high revenue and projects profits. Unfortunately, in the first few years it faces losses and can’t generate the desired revenue. No profit. However, in the long term, it performs incredibly well and generates high profits just as it promised. (Yes, like Amazon.)

Similarly, here we don't go to the trouble of computing the entire denominator for every block / tile. Instead, we compute the softmax using only the values in that particular block. As an example, for the block with vectors x1 & x2,

incremental softmax at timestamp = 2 (image by author)

Is this correct?

No, if we consider the softmax for the entire attention computation.
Yes, if we consider just the block/tile at i = 2.

Now, have a look at the calculation evolving to the last timestep.

incremental softmax at timestamp = n (image by author)

The softmax value at every timestamp other than the last one is incomplete, but we eventually get to the required value at the last timestamp. I think this is a great workaround proposed by the researchers; it helps reduce the memory requirements.
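A minimal sketch of the same running computation in isolation, assuming we only ever see one block of logits at a time (the variable names are illustrative):

```python
import torch

def online_softmax_stats(logit_blocks):
    """Running max m and denominator d over a stream of logit blocks."""
    m = torch.tensor(float("-inf"))   # running max seen so far
    d = torch.tensor(0.0)             # running (rescaled) denominator
    for block in logit_blocks:
        new_m = torch.maximum(m, block.max())
        # Rescale the old denominator to the new max before adding this block's terms.
        d = d * torch.exp(m - new_m) + torch.exp(block - new_m).sum()
        m = new_m
    return m, d

logits = torch.randn(8)
m, d = online_softmax_stats(logits.split(2))
exact = torch.softmax(logits, dim=0)
recomputed = torch.exp(logits - m) / d
print(torch.allclose(exact, recomputed))  # True: the final (m, d) give the full softmax
```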

annotated flash attention algorithm (image by author)

I tried annotating the various phases of this incredible algorithm. I would still highly recommend going through the implementation in the paper to build the intuition.

Good luck understanding it in one go! If you're willing to discuss the intricate details, shoot me an email: [email protected]

This concludes the advanced attention mechanisms (as of November 12, 2024). I’ll add further articles to this series should more attention mechanisms come up.


The Rotation of the Earth really makes my day
