

Attention Sinks and Where to Cache Them: A Visual Walkthrough for the Streaming LLM Implementation


Last Updated on November 5, 2023 by Editorial Team

Author(s): David R. Winer

Originally published on Towards AI.

GPT-2 Rolling Cache Text Generator Block

One of the latest AI papers drawing headlines describes a technique for Generative Pre-trained Transformer (GPT) architectures that enables efficient, unlimited-length context windows for text generation. It builds on a discovery about “attention sinks”: in autoregressive (next-token-prediction) models, the earliest tokens absorb a disproportionate share of attention, doing most of the work of stabilizing self-attention’s representation of the text. The technique is very practical because it requires no fine-tuning and only minimal modifications to the GPT architecture. This post focuses on what those modifications are, at a detailed level, so that you can feel comfortable putting them into practice.

To recap why this matters: for a vanilla LLM, memory and processing time grow quadratically with context length when generating the next token. On top of that, many models are not actually trained on very long inputs, so quality degrades as inputs get longer. And every time the model generates a token, the window grows. Imagine GPT writing the end of a book: to account for everything it has written, the model needs a very long context window; otherwise, the ending would not wrap up all of the plot details.

The Paper:

The rest of this post is focused on the actual technique, not on justifying it or examining the results. The text in the paper describing the technique is relatively brief. Essentially, you pick some number of “attention sinks,” and after them you keep a fixed-size queue of token embeddings. At every iteration, when you generate the next token embedding, you keep the sink embeddings and discard the embedding of the oldest token after the sinks.
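The cache policy above can be sketched in a few lines of Python. This is a minimal sketch operating on plain token strings rather than KV embeddings; the function name and defaults are illustrative, not from the paper's code.

```python
def update_cache(cache, new_token, n_sinks=3, max_tokens=7):
    """Append new_token; if over budget, evict the oldest token(s) after the sinks."""
    cache = cache + [new_token]
    if len(cache) > max_tokens:
        # Keep the sinks intact, drop the earliest post-sink token(s).
        n_evict = len(cache) - max_tokens
        cache = cache[:n_sinks] + cache[n_sinks + n_evict:]
    return cache

cache = ["Hmm", "okay", "so", "this", "is", "some", "input"]
cache = update_cache(cache, "text")
# → ["Hmm", "okay", "so", "is", "some", "input", "text"]
```

Note that the sinks are never candidates for eviction; only the rolling window after them shrinks.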

So here’s a running example for the text starting with

Hmm, okay so this is some input

And let’s say we have 3 attention sinks and a maximum token length of just 7. Initially, we run all the tokens through the layers to produce 7 token embeddings, and we go to generate the 8th token. Let’s say the next token it produces is “text”, and in bold is the token we’ll evict next.

[Hmm, okay, so, this, is, some, input] → “text”

Then, on the next iteration, we roll the queue forward and evict the earliest token occurring after the sinks.

[Hmm, okay, so, is, some, input, text] → “and”

And we keep doing this, on and on.

[Hmm, okay, so, some, input, text, and] → “this”

The other thing to keep in mind is that the position embeddings don’t roll forward; they stay the same. That means the position embedding associated with a given token changes from one iteration to the next.

Visualization Details

The compute steps will be shown visually using a node graph visualization tool. Each block is an operation that takes the inputs on the left side and produces the data for the output variables on the right side. Links denote the passing of data from outputs to inputs, and circles on inputs mean the data is specified in place and is static.

Operations are either composite or primitive. A composite operation contains an “unbox” icon and decomposes into a sub-graph whose inputs are the parent’s inputs and whose outputs are the parent’s outputs. A primitive operation cannot be decomposed further and corresponds to a low-level tensor operation, as in NumPy or TensorFlow. Colors indicate data type: blue means integer, purple/pink means a decimal type, and green means text. Link patterns indicate data shape: a solid link is a scalar, while the number of dots between the dashes gives the number of dimensions of the array. At the bottom of each graph, a table lists the shape, type, and operation name of every variable carrying data in the model.

I’ve covered and used the visualization in previous posts, such as creating a reference map for GPT Fully Visualized and BERT Fully Visualized, and for visual walkthroughs about Graph Attention Networks, the LoRA fine-tuning method, and BERTScore.

Visualized Implementation

Let’s dive into the implementation. What we’re seeing below is the beginning of a loop. At each iteration, we have the next token and the concatenated tokens so far. We pass only the new token into GPT, because we’ll be reading and writing the cached embeddings at each layer. This implementation trick makes generation more efficient: at each new iteration, we only need to compute embeddings for the latest token.

Before we get further, let’s examine some hyperparameters (global variables) of the model. The global constants are values that are static in your graph.

We will read from the public database containing the GPT-2 weights and cache them in the specified directory path “gpt_2_rolling_cache”. These cache paths store every weight tensor on disk, where each one functions like a model parameter held in memory.

You can see that we set the number of attention sinks to 3 tokens and Max Tokens to 7. That means the model will never process more than 7 tokens at a time, which is pretty short, but it’s just an example. Typically, that number would match the original context length used in training, which for this small GPT-2 model is 32. Every time we process the next token, we drop the earliest token cached after the attention sinks, so in total, we only look at the 3 attention sinks + the last 4 tokens at each iteration.
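The hyperparameters above can be collected in one place. This is a hypothetical config snippet mirroring the values in this walkthrough; the key names are mine, not from the graph.

```python
# Hypothetical settings matching the example in this walkthrough.
CONFIG = {
    "weights_cache_dir": "gpt_2_rolling_cache",  # where GPT-2 weights are cached on disk
    "n_attention_sinks": 3,  # earliest tokens that are never evicted
    "max_tokens": 7,         # total cache budget: 3 sinks + 4 rolling tokens
}

# The rolling window after the sinks holds max_tokens - n_attention_sinks tokens.
window = CONFIG["max_tokens"] - CONFIG["n_attention_sinks"]
# → 4
```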

But when we say, “look at tokens”, what does that really mean? Let’s dive into the layers. Looking at just layer 0, you can follow the breadcrumbs to see where we are inside the architecture. Here, we are fetching the dense projection layers for the Key and Value weights.

Inside the Key Rolling Cache, we are reading weights from cache. Note that we are in a conditional block, so on the first iteration, we’ll just write to cache without reading. The cache includes the token embeddings on the previous iteration. The embeddings have shape [1, 12, 7, 64].

  • Dimension 0 is for the size of the batch (1),
  • Dimension 1 is for the number of attention heads (12),
  • Dimension 2 is for the number of tokens (7),
  • Dimension 3 is for the hidden size (768) divided by the number of attention heads (64).
Reading from cache at each layer
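The per-layer cache update over those [1, 12, 7, 64] tensors can be sketched with NumPy. This is a simplified sketch of the split-concatenate-evict pattern described below, assuming the shapes from the article; the function name is illustrative.

```python
import numpy as np

N_SINKS, MAX_TOKENS = 3, 7

def roll_kv_cache(cache, new):
    """cache: [1, 12, T, 64] from the previous iteration; new: [1, 12, n_new, 64]."""
    # Split off the attention sinks on dimension 2 (the token axis).
    sinks, rest = cache[:, :, :N_SINKS], cache[:, :, N_SINKS:]
    # Concatenate the incoming token embedding(s) along the token axis.
    rest = np.concatenate([rest, new], axis=2)
    # Evict the oldest post-sink token(s) so the total stays within budget.
    n_evict = max(0, sinks.shape[2] + rest.shape[2] - MAX_TOKENS)
    rest = rest[:, :, n_evict:]
    # The sink embeddings skip forward and are concatenated back in front.
    return np.concatenate([sinks, rest], axis=2)

cache = np.zeros((1, 12, 7, 64))       # previous iteration's cache
new = np.ones((1, 12, 1, 64))          # embedding for the next token
cache = roll_kv_cache(cache, new)
# cache.shape stays (1, 12, 7, 64)
```

The same update runs once for the key cache and once for the value cache at every layer.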

The link routing around the read-from-file operation carries the incoming token(s) only. On the first iteration of the loop in our example it is [1, 12, 7, 64], and on every subsequent iteration it carries only the next token, [1, 12, 1, 64]. The first step is to split off the attention sinks (on dimension 2) and then concatenate the new embedding along the dimension 2 axis. The attention sink embeddings skip forward to be concatenated back later. Inside the evict block, we evict 1 or more tokens from the oldest end of the queue.

Rolling cache logic

Inside the evict block, you can see that we calculate how many token embeddings to slice (i.e., evict, which sounds better) off the beginning of dimension 2. In general, each new token causes 1 token to get evicted.


Finally, we take the result, concatenate it with the attention sink embeddings in front, and pass it forward. We do the same for the key and value caches at each layer when we fetch the Query Key Value layers within the Self-Attention operation.

Query Key Value fetch embeddings

Finally, all that’s left is the position encoding. Within the “Create Position IDs” block, we update the position-ID logic. The logic is relatively simple: either we increment the position ID for the next token, because we haven’t reached the maximum token length yet, or otherwise we keep it the same and fetch the same position embedding.
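That branch can be written as a one-liner. This is a sketch of the position-ID rule just described, assuming the example's budget of 7 tokens; the function name is illustrative.

```python
MAX_TOKENS = 7

def next_position_id(n_tokens_seen):
    """Increment until the cache is full, then pin at the last position slot."""
    return min(n_tokens_seen, MAX_TOKENS - 1)

# Position IDs assigned to tokens 0 through 9:
[next_position_id(i) for i in range(10)]
# → [0, 1, 2, 3, 4, 5, 6, 6, 6, 6]
```

This is why, as noted earlier, the position embedding associated with a given token changes between iterations: the IDs are tied to cache slots, not to the tokens themselves.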

Create Position IDs

As an example, I compare GPT-2 side by side without and with the rolling cache, generating 20 tokens starting from the example I gave earlier, “Hmm okay this is some input text”. This is still short and hardly requires the rolling cache, but it shows that the mechanism works.

GPT-2 without rolling cache:

GPT-2 with rolling cache (max 7 tokens and 3 attention sinks):

They are different, which is expected, but is one better than the other?

Thanks for reading! The full graph is a JSON available on Github. What did you think? Did I get anything wrong? Anything you want to see next? Let me know in the comments!


Published via Towards AI
