Attention Sinks and Where to Cache Them: A Visual Walkthrough for the Streaming LLM Implementation
Author(s): David R. Winer
Originally published on Towards AI.
One of the latest AI papers drawing headlines is a technique for Generative Pre-trained Transformer (GPT) model architectures that enables efficient, unlimited-length context windows for text generation. This is made possible by leveraging a discovery about "attention sinks", i.e., that the earliest tokens in autoregressive (next-token) prediction do most of the work for self-attention as it builds up a representation of the text. It is very practical because it doesn't require fine-tuning and only requires minimal modifications to the GPT architecture. This post is focused on what those modifications are at a detailed level so that you can feel comfortable putting them into practice.
To remind you why it's important: a vanilla LLM needs ever more memory and processing time as the context length grows (self-attention cost scales quadratically with sequence length) to generate the next token. Plus, many models are not actually trained on very long inputs, so they degrade as inputs get longer. Every time the model generates the next token, the window gets longer. Imagine GPT writing the end of a book. For the model to understand everything it has written, it needs to keep a very long context window; otherwise, the ending of the book would not wrap up all of the plot details.
The Paper:
The rest of this post is focused on the actual technique, not on justifying it or examining the results. The actual text in the paper about the technique is relatively brief. Essentially, you pick some number of "attention sinks", and then you keep a fixed-size queue of token embeddings after the sinks. At every iteration, when you generate the next token embedding, you keep the sink embeddings and discard the oldest embeddings in the queue (the ones right after the sinks).
So here's a running example for the text starting with
Hmm, okay so this is some input
And let's say we have 3 attention sinks and a maximum token length of just 7. Initially, we run all the tokens through the layers to produce 7 token embeddings, and we go to generate the 8th token. Let's say the next token it produces is "text", and in bold is the token we'll evict next.
[Hmm, okay, so, this, is, some, input] → "text"
Then, on the next iteration, we roll the queue forward and evict the earliest token after the sinks.
[Hmm, okay, so, is, some, input, text] → "and"
And we would keep doing this on and on.
[Hmm, okay, so, some, input, text, and] → "this"
The other thing to keep in mind is that the position embeddings don't get rolled forward with the tokens; they stay the same. That means the position embedding associated with a given token changes on each iteration.
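If it helps to see the bookkeeping as code, here is a minimal Python sketch. It operates on plain token strings just to mirror the example above (the actual implementation rolls cached embeddings, as we'll see later), and the constant and function names are mine:

```python
# Rolling window with attention sinks, mirroring the running example
# (3 sinks, 7 tokens total in the window).
NUM_SINKS = 3
MAX_TOKENS = 7

def roll_window(window, new_token):
    """Append new_token; if the window overflows, evict the oldest non-sink token."""
    window = window + [new_token]
    if len(window) > MAX_TOKENS:
        # Keep the sinks, drop the earliest token right after them.
        window = window[:NUM_SINKS] + window[NUM_SINKS + 1:]
    return window

window = ["Hmm", "okay", "so", "this", "is", "some", "input"]
window = roll_window(window, "text")  # [Hmm, okay, so, is, some, input, text]
window = roll_window(window, "and")   # [Hmm, okay, so, some, input, text, and]
```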
Visualization Details
The compute steps will be shown visually using a node graph visualization tool. Each block is an operation that takes the inputs on the left side and produces the data for the output variables on the right side. Links denote the passing of data from outputs to inputs, and circles on inputs mean the data is specified in place and is static.
Operations are either composite, marked with an "unbox" icon, which decompose into a sub-graph whose inputs are the parent's inputs and whose outputs are the parent's outputs, or primitive, meaning they cannot be decomposed further and correspond to low-level tensor operations like those in NumPy or TensorFlow. Colors indicate data type and patterns indicate the data shape. Blue means the data type is an integer, purple/pink means it's a decimal data type, and green means it's text. Solid links indicate that the data shape is scalar, whereas dots in the link indicate the number of dimensions of the array (the number of dots between the dashes). At the bottom of each graph is a table that characterizes the shape, type, and operation name of each variable that is carrying data in the model.
I've covered and used the visualization in previous posts, such as creating a reference map for GPT Fully Visualized and BERT Fully Visualized, and for visual walkthroughs about Graph Attention Networks, the LoRA fine-tuning method, and BERTScore.
Visualized Implementation
Let's dive into the implementation. What we're seeing below is the beginning of a loop. At each iteration, we have the next token and the concatenated tokens so far. We only pass the new token into GPT because we'll be reading and writing the cached embeddings at each layer. This implementation trick makes things more efficient: at each new iteration, we only need to compute embeddings for the latest token.
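As a rough sketch of that outer loop, assuming a hypothetical `model` object that maintains its own per-layer rolling cache (the interface here is illustrative, not the one in the graph):

```python
def generate(model, prompt_ids, num_new_tokens):
    """Greedy generation where only the newest token is fed forward each step."""
    generated = list(prompt_ids)
    next_ids = list(prompt_ids)          # first pass: the whole prompt
    for _ in range(num_new_tokens):
        logits = model(next_ids)         # model reads/writes its rolling cache internally
        next_token = int(logits[-1].argmax())
        generated.append(next_token)
        next_ids = [next_token]          # later passes: only the newest token
    return generated
```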
Before we get further, let's examine some hyperparameters (global variables) of the model. The global constants are values that are static in your graph.
We will read from the public database containing GPT-2 weights and cache them in the specified directory path "gpt_2_rolling_cache". This cache path is used to store every weight on disk so that it functions like a model parameter held in memory.
You can see that we set the number of attention sinks to 3 tokens and Max Tokens to 7. That means we limit the model to processing no more than 7 tokens at a time, which is pretty short, but it's just an example. Typically, that number would match the original context length used in training, which for this small GPT-2 model is 32. Every time we process the next token, we'll drop the earliest token we have cached after the attention sinks, so in total we only look at the 3 attention sinks + the last 4 tokens at each iteration.
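Written out as plain constants (the names here are illustrative, not the ones in the graph):

```python
CACHE_DIR = "gpt_2_rolling_cache"  # where the GPT-2 weights and cached state live
NUM_ATTENTION_SINKS = 3            # sink tokens kept at every iteration
MAX_TOKENS = 7                     # total window: 3 sinks + the 4 most recent tokens
```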
But when we say "look at tokens", what does that really mean? Let's dive into the layers. Looking at just layer 0, you can follow the breadcrumbs to see where we are inside the architecture. Here, we are fetching the dense projection layers for the Key and Value weights.
Inside the Key Rolling Cache, we are reading from the cache. Note that we are in a conditional block, so on the first iteration, we'll just write to cache without reading. The cache includes the token embeddings from the previous iteration. The embeddings have shape [1, 12, 7, 64].
- Dimension 0 is for the size of the batch (1),
- Dimension 1 is for the number of attention heads (12),
- Dimension 2 is for the number of tokens (7),
- Dimension 3 is for the hidden size (768) divided by the number of attention heads, i.e., 64.
The link that routes around the read-from-file operation carries the incoming token(s) only. On the first iteration of the loop in our example, it is [1, 12, 7, 64]; on every subsequent iteration, it runs only for the next token, which is [1, 12, 1, 64]. The first thing we'll do is split out the attention sinks (on dimension 2) and then concatenate the new embedding along the dimension 2 axis. The attention sink embeddings skip forward to get concatenated back in afterwards. Inside the evict block, we'll evict 1 or more of the oldest non-sink tokens from the queue.
Inside the evict block, you can see that we calculate how many token embeddings to slice (i.e., evict) off the beginning of dimension 2. In general, each new token causes 1 token to get evicted.
Finally, we take the result, concatenate the attention sink embeddings back in front of it, and pass it forward. We do the same for the keys and values in each layer when we fetch the Query, Key, and Value layers within the Self-Attention operation.
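Here is a minimal NumPy sketch of that per-layer cache update, using the shapes from the example; the function and variable names are mine, not the ones in the graph:

```python
import numpy as np

NUM_SINKS = 3
MAX_TOKENS = 7

def update_rolling_cache(cached, new):
    """Roll one layer's key (or value) cache.

    cached: [1, 12, T, 64] embeddings read from cache, or None on the first iteration
    new:    [1, 12, t, 64] embeddings for the incoming token(s)
    Returns the merged tensor to write back to cache.
    """
    if cached is None:                      # first iteration: write without reading
        return new
    sinks = cached[:, :, :NUM_SINKS, :]     # attention sinks are kept as-is
    rest = cached[:, :, NUM_SINKS:, :]      # rolling part of the queue
    rest = np.concatenate([rest, new], axis=2)
    overflow = NUM_SINKS + rest.shape[2] - MAX_TOKENS
    if overflow > 0:
        rest = rest[:, :, overflow:, :]     # evict the oldest non-sink tokens
    return np.concatenate([sinks, rest], axis=2)
```

On the example shapes, a cached [1, 12, 7, 64] tensor plus a new [1, 12, 1, 64] token comes back out as [1, 12, 7, 64], with the oldest non-sink token sliced away.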
Finally, all that's left is the position encoding. Within the "Create Position IDs" block, we update the position ID logic. The logic is relatively simple: either we increment the position ID for the next token, because we haven't reached the maximum token length yet, or we keep it the same and fetch the same position embedding.
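As a sketch of that rule (again with illustrative names):

```python
def next_position_id(num_cached_tokens, max_tokens=7):
    """Position ID for the incoming token: count up until the window is full,
    then keep reusing the last position so the embeddings never roll forward."""
    return min(num_cached_tokens, max_tokens - 1)
```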
As an example, I am comparing GPT-2 without the rolling cache and with the rolling cache side by side, generating 20 tokens starting with the example I gave earlier, "Hmm okay this is some input text". This is still short and hardly requires the rolling cache, but it shows that it is working.
GPT-2 without rolling cache:
GPT-2 with rolling cache (max 7 tokens and 3 attention sinks):
They are different, which is expected, but is one better than the other?
Thanks for reading! The full graph is a JSON available on GitHub. What did you think? Did I get anything wrong? Anything you want to see next? Let me know in the comments!