Understanding Mamba and Selective State Space Models (SSMs)
Author(s): Matthew Gunton
Originally published on Towards AI.
The Transformer architecture has been the foundation of most major
large language models (LLMs) on the market today, delivering impressive
performance and revolutionizing the field. However, this success comes
with limitations. One major challenge is that Transformers with
self-attention mechanisms inherently attend to the entire context window,
leading to quadratic scaling costs as input sizes increase. This has a
direct impact on training and inference times, making it increasingly
expensive to work with larger inputs. The quest for an architecture that
balances performance with scalability is crucial, as it could unlock new
use cases for LLMs overnight. In this blog, we'll explore a novel block
architecture that aims to achieve just that: harnessing the power of large
language models without the scalability limitations of traditional
Transformers.
The authors of "Mamba: Linear-Time Sequence Modeling with Selective State Spaces" have found a way to apply a type of machine learning called State Space Models (SSMs) so that they are competitive with Transformers. Let's dive into this new architecture!
State Space Models
The first question you may have is: what exactly are State Space Models (SSMs)? The basic idea here is to model a system that changes over time. To accomplish this, we choose values corresponding to parts of our system (A, B, C), which stay the same through each iteration. We then have three vectors that represent how the system changes: h (the state vector), x (the input vector), and y (the output vector), where h′ is the next iteration of the state vector. The key idea here is that x, h, and y take new values at each time step t, but A, B, and C stay the same. The basic equations are shown below:
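For reference, here are the standard SSM equations as given in the Mamba paper, in both their continuous-time form and the discretized (recurrent) form used in practice, where Ā and B̄ are the discretized matrices obtained with step size Δ:

```latex
% Continuous-time state space model
h'(t) = A\,h(t) + B\,x(t), \qquad y(t) = C\,h(t)

% Discretized (recurrent) form, using step size \Delta
h_t = \bar{A}\,h_{t-1} + \bar{B}\,x_t, \qquad y_t = C\,h_t
```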
Moreover, as A, B, and C themselves do not vary with time, SSMs typically use convolution under the hood, which applies the same kernel to every part of the input sequence, achieving high-performance computation during both inference and training.
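To make the convolution view concrete, here is a minimal NumPy sketch (an illustration only, not the optimized S4 implementation) that builds the SSM kernel K = (CB̄, CĀB̄, CĀ²B̄, ...) and applies it to a 1-D input sequence; the function name and toy matrices are made up for the example:

```python
import numpy as np

def ssm_as_convolution(A_bar, B_bar, C, x):
    """Compute the SSM output by convolving x with the kernel K = (C B, C A B, C A^2 B, ...).

    A_bar: (N, N) discretized state matrix
    B_bar: (N, 1) discretized input matrix
    C:     (1, N) output matrix
    x:     (L,)   1-D input sequence
    """
    L = len(x)
    kernel = np.empty(L)
    A_power_B = B_bar                        # starts as A_bar^0 @ B_bar
    for k in range(L):
        kernel[k] = (C @ A_power_B).item()   # k-th kernel entry: C @ A_bar^k @ B_bar
        A_power_B = A_bar @ A_power_B
    # Causal convolution: y_t depends only on x_0 ... x_t
    return np.convolve(x, kernel)[:L]

# Toy example with a 2-dimensional hidden state (values are arbitrary)
A_bar = np.array([[0.9, 0.0],
                  [0.1, 0.8]])
B_bar = np.array([[1.0],
                  [0.5]])
C = np.array([[1.0, -0.5]])
x = np.random.randn(16)
print(ssm_as_convolution(A_bar, B_bar, C, x).shape)  # (16,)
```

Because A, B, and C do not change with time, this single kernel is reused at every position, which is exactly what makes the parallel convolutional form possible; the selection mechanism described below breaks this property.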
These models have historically been used for signal processing, economics, and control systems; however, they have been less useful for tasks involving discrete data, such as text.
Selective State Space Model
To address the problem with discrete data, the authors introduce a new version of SSMs called a Selective State Space Model. There are two big changes here to the typical SSM. First, they introduce a selection mechanism, which helps us filter out or focus on certain data. Second, as a consequence of the selection mechanism, we can no longer use convolution; thus, the authors introduced a selective scan.
Selection Mechanism
Starting off with the selection mechanism, the authors give the model the ability to select data by changing B, C, and Δ to be time-variant (meaning they now vary based on t). Below are both a typical implementation of SSMs (S4) and the Selective State Space Model:
Starting off by explaining the variables, x and y are tensors with dimensions B (denoting batch size), L (denoting sequence length), and D (denoting dimension size). N is an arbitrarily chosen value that determines the size of the tensors that follow. Δ is a tensor used for the transition from the current state h to the next state h′. s_B, s_C, and s_Δ are the projection functions that make the parameters input-dependent (specifically, s_B(x) = Linear_N(x), s_C(x) = Linear_N(x), s_Δ(x) = Broadcast_D(Linear_1(x)), and τ_Δ = softplus, where Linear_N is a linear projection to dimension N), and discretize converts the matrices or tensors from continuous time into discrete time.
Stepping back, the major change then is the dependence on the input. In Algorithm 1 we are dealing with matrices, whereas in Algorithm 2 we have tensors for B, C, and Δ. The additional dimension comes from running our projection functions on the input tensor x and writing the result into the corresponding tensor. Whereas before B and C carried through all of their information, the model can now determine which information is pertinent and keep only that. As a consequence of this new dependency on x, Algorithm 2 is time (or input) varying.
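As an illustration of the selection mechanism, here is a minimal PyTorch sketch that produces the input-dependent B, C, and Δ with the linear projections described above. The class and argument names are my own, and the real Mamba implementation adds details (such as a learned bias and special initialization for Δ) that are omitted here:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelectionProjections(nn.Module):
    """Produce the input-dependent B, C, and delta tensors of a selective SSM (illustrative sketch)."""

    def __init__(self, d_model: int, d_state: int):
        super().__init__()
        self.s_B = nn.Linear(d_model, d_state)   # s_B(x) = Linear_N(x)
        self.s_C = nn.Linear(d_model, d_state)   # s_C(x) = Linear_N(x)
        self.s_delta = nn.Linear(d_model, 1)     # s_delta(x) = Broadcast_D(Linear_1(x))

    def forward(self, x):                        # x: (batch B, length L, channels D)
        B = self.s_B(x)                          # (B, L, N), now depends on the input
        C = self.s_C(x)                          # (B, L, N), now depends on the input
        # Broadcast the scalar projection across the D channels, then apply tau_delta = softplus
        delta = F.softplus(self.s_delta(x).expand(-1, -1, x.shape[-1]))  # (B, L, D)
        return B, C, delta

x = torch.randn(2, 32, 64)                       # batch=2, L=32, D=64
B, C, delta = SelectionProjections(d_model=64, d_state=16)(x)
print(B.shape, C.shape, delta.shape)             # (2, 32, 16), (2, 32, 16), (2, 32, 64)
```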
This is perhaps the biggest change from a typical Transformer: rather than having self-attention, we instead have the selection mechanism to determine what the model should focus on.
Selective Scan
As a consequence of making B and C input-variant, we can no longer use convolution. To get around this, the authors created a "selective scan" algorithm designed to be hardware-aware for better performance.
From the point of view of a Graphics Processing Unit (GPU), the balancing act is between capacity and speed. High Bandwidth Memory (HBM) has lots of space to hold data but is slow. Static Random Access Memory (SRAM) is fast but cannot hold a lot of data.
To understand selective scan, let's first understand a standard scan operation applied to a Selective SSM. To do so, we would need to load the entire tensor of shape (B, L, D, N) into HBM, as SRAM cannot handle data of that size. We then apply the calculations of our scan operation to every part of the tensor in HBM, losing a lot of time moving data between addresses in memory.
By comparison, the selective scan is much more memory-efficient. It does not materialize the entire tensor of shape (B, L, D, N); instead, it operates only on the much smaller parameters (A, B, C, Δ), where A has shape (D, N), B and C have shape (B, L, N), and Δ has shape (B, L, D). Because we are operating on significantly smaller data, we are able to hold a lot more of it in SRAM, dramatically decreasing our calculation times. Once the calculations are done, the final output y of shape (B, L, D) is written back out. The figure above shows where each variable is stored inside of the GPU's memory.
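The computation the scan performs can be written as a simple sequential recurrence. The sketch below is a naive, unfused PyTorch reference version of that math (with a simplified discretization), not the hardware-aware kernel the authors actually implement; the function name and toy shapes are illustrative:

```python
import torch

def naive_selective_scan(x, delta, A, B, C):
    """Reference (unfused) selective scan.

    x:     (batch, L, D)   input
    delta: (batch, L, D)   input-dependent step size
    A:     (D, N)          state matrix (kept in log/diagonal form in real Mamba; plain here)
    B:     (batch, L, N)   input-dependent input matrix
    C:     (batch, L, N)   input-dependent output matrix
    """
    batch, L, D = x.shape
    N = A.shape[1]
    h = torch.zeros(batch, D, N, device=x.device, dtype=x.dtype)
    ys = []
    for t in range(L):
        # Simplified discretization at step t
        A_bar = torch.exp(delta[:, t, :, None] * A)           # (batch, D, N)
        B_bar = delta[:, t, :, None] * B[:, t, None, :]        # (batch, D, N)
        h = A_bar * h + B_bar * x[:, t, :, None]               # state update
        ys.append((h * C[:, t, None, :]).sum(-1))              # y_t = C_t h_t, shape (batch, D)
    return torch.stack(ys, dim=1)                              # (batch, L, D)

y = naive_selective_scan(
    torch.randn(2, 16, 8),            # x
    torch.rand(2, 16, 8),             # delta
    -torch.rand(8, 4),                # A (negative values keep the state stable)
    torch.randn(2, 16, 4),            # B
    torch.randn(2, 16, 4),            # C
)
print(y.shape)                        # torch.Size([2, 16, 8])
```

The fused kernel computes essentially this recurrence, but keeps the state h and the intermediate Ā and B̄ in SRAM rather than materializing them in HBM.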
The great trade-off with selective scan comes when you want to do backward passes. Because we do not keep the intermediate calculations anywhere in memory, we need to recompute them, in essence trading compute for memory.
Mamba Block Structure
Now that we understand SSMs, we can see how the authors used them to create the Mamba architecture.
The authors relied on two block designs to create Mamba: Hungry, Hungry Hippos (H3) and Gated Multi-Layer Perceptrons (Gated MLPs). As Mamba combines both, let's explain how each of these block structures works.
Multi-Layer Perceptrons (MLPs) are extremely common in neural network architectures. They are feed-forward neural networks where each neuron in a layer is connected to every neuron in the previous layer. The "gated" part of Gated MLPs adds further control over the information flow: the input is projected into two branches, and one branch, passed through a nonlinearity, acts as a gate that elementwise-multiplies the other, determining how much of the signal is passed along.
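As a rough sketch, assuming the common two-branch formulation of a gated MLP (the expansion factor, activation choice, and names here are illustrative, not taken from the paper):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedMLP(nn.Module):
    """Minimal gated MLP: one branch is transformed, the other gates it elementwise."""

    def __init__(self, d_model: int, expansion: int = 2):
        super().__init__()
        d_inner = expansion * d_model
        self.in_proj = nn.Linear(d_model, 2 * d_inner)    # produces both branches at once
        self.out_proj = nn.Linear(d_inner, d_model)

    def forward(self, x):                                  # x: (batch, length, d_model)
        value, gate = self.in_proj(x).chunk(2, dim=-1)
        return self.out_proj(value * F.silu(gate))         # the gate controls what flows through

x = torch.randn(2, 32, 64)
print(GatedMLP(64)(x).shape)                               # torch.Size([2, 32, 64])
```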
H3 is about using the SSM to remember previous tokens and then multiplying this result with the other vectors to enable comparisons. To break this down, we first project the input into the familiar Key, Value, and Query vectors. The Key goes through a "shift" SSM that is meant to give some memory of previous tokens. The output is then multiplied with the Value tokens, giving our first comparison between tokens. We then run the "diagonal" SSM to propagate these token interactions across the entire sequence. We end with the Query multiplication, so that the stored interactions can be compared with the current elements.
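To make this concrete, here is a heavily simplified, illustrative H3-style block in PyTorch. The "shift" SSM is approximated by a one-token shift and the "diagonal" SSM by a per-channel exponential moving average; the real H3 uses learned SSM parameterizations for both, so treat this purely as a structural sketch:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimplifiedH3(nn.Module):
    """Illustrative H3-style block (simplified stand-ins for the shift and diagonal SSMs)."""

    def __init__(self, d_model: int):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.decay = nn.Parameter(torch.zeros(d_model))   # per-channel decay for the "diagonal" stand-in
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):                                  # x: (batch, length, d_model)
        q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        # "Shift" SSM stand-in: each position sees the previous token's key.
        k_shift = F.pad(k, (0, 0, 1, 0))[:, :-1, :]
        kv = k_shift * v                                   # first token interaction
        # "Diagonal" SSM stand-in: causal exponential moving average over the sequence.
        a = torch.sigmoid(self.decay)
        h = torch.zeros_like(kv[:, 0, :])
        states = []
        for t in range(kv.shape[1]):
            h = a * h + (1 - a) * kv[:, t, :]
            states.append(h)
        s = torch.stack(states, dim=1)
        return self.out_proj(q * s)                        # compare stored interactions with the query

x = torch.randn(2, 32, 64)
print(SimplifiedH3(64)(x).shape)                           # torch.Size([2, 32, 64])
```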
Mamba takes the gating feature from the Gated MLPs and then combines this with convolution and selective SSM transformation. Note that while selective SSMs can no longer be computed using convolution under the hood, there is no reason a Mamba block cannot have a convolution inside of it. From a high level, we now have a new block structure that can pass through certain parts of the input via the gating mechanism and then also focus on certain parts of that input via the selective SSM. Some consequences of the above are discussed below.
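Before looking at those consequences, here is a rough, self-contained sketch of what such a block might look like, combining the gating, a short causal convolution, and a naive selective SSM. Dimensions, names, and the simplified discretization are illustrative; the official implementation fuses the scan into a custom GPU kernel and wraps the block with normalization and residual connections:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimplifiedMambaBlock(nn.Module):
    """Illustrative Mamba-style block: gating + short convolution + selective SSM (not the official kernel)."""

    def __init__(self, d_model: int, d_state: int = 16, expansion: int = 2, d_conv: int = 4):
        super().__init__()
        d_inner = expansion * d_model
        self.in_proj = nn.Linear(d_model, 2 * d_inner)                # SSM branch + gate branch
        self.conv = nn.Conv1d(d_inner, d_inner, d_conv,
                              padding=d_conv - 1, groups=d_inner)      # causal depthwise convolution
        self.s_B = nn.Linear(d_inner, d_state)
        self.s_C = nn.Linear(d_inner, d_state)
        self.s_delta = nn.Linear(d_inner, 1)
        self.A_log = nn.Parameter(torch.zeros(d_inner, d_state))       # A = -exp(A_log) keeps the state stable
        self.out_proj = nn.Linear(d_inner, d_model)

    def forward(self, x):                                              # x: (batch, L, d_model)
        batch, L, _ = x.shape
        u, gate = self.in_proj(x).chunk(2, dim=-1)                     # (batch, L, d_inner) each
        u = self.conv(u.transpose(1, 2))[:, :, :L].transpose(1, 2)     # causal local mixing
        u = F.silu(u)

        # Input-dependent (selective) SSM parameters
        B = self.s_B(u)                                                # (batch, L, N)
        C = self.s_C(u)                                                # (batch, L, N)
        delta = F.softplus(self.s_delta(u))                            # (batch, L, 1)
        A = -torch.exp(self.A_log)                                     # (d_inner, N)

        # Naive sequential selective scan (the paper replaces this with a fused kernel)
        h = torch.zeros(batch, u.shape[-1], A.shape[-1], device=x.device)
        ys = []
        for t in range(L):
            A_bar = torch.exp(delta[:, t, :, None] * A)                # (batch, d_inner, N)
            h = A_bar * h + delta[:, t, :, None] * B[:, t, None, :] * u[:, t, :, None]
            ys.append((h * C[:, t, None, :]).sum(-1))
        y = torch.stack(ys, dim=1)                                     # (batch, L, d_inner)

        return self.out_proj(y * F.silu(gate))                         # gated output, back to d_model

x = torch.randn(2, 32, 64)
print(SimplifiedMambaBlock(64)(x).shape)                               # torch.Size([2, 32, 64])
```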
Mamba Inference Time
With the new block architecture, we are able to get significant improvements for both training and inference time.
First, because we are not doing attention, we do not have to worry about the quadratic scaling that comes with larger input sizes. If you go back to Figure 1 from the paper, you can see that the state being passed is the same size regardless of the input length. Consequently, while larger input lengths require more calculations, the cost increases only at a linear rate. This is in contrast to attention, where the attention pattern grows quadratically as the input increases (doubling the context roughly quadruples the attention computation but only doubles the SSM's). This linear versus quadratic scaling means that, all else equal, the costs and latency for training and inference will be dramatically better for SSMs than for Transformers.
From the graph below, we can see that the inference throughput at a fixed prompt length is significantly higher for Mamba than for similarly sized Transformer models. Indeed, as the batch size goes up, the gap between the architectures widens considerably.
Moreover, when comparing the time to complete the transformation, we see that the selective scan (labeled "ours") performs considerably better than all other implementations, especially distinguishing itself at greater sequence lengths. As the world looks to LLMs to process increasingly more verbose data, better performance at greater lengths may become a key differentiator.
Mamba-2
Roughly 5 months after submitting the first Mamba paper, the authors released a second paper called "Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality". This paper expanded upon the Mamba block structure in a number of ways to further improve computation efficiency and scalability. If there is sufficient interest, I will go deeper into this paper at some later point.
Conclusion
The authors of Mamba have ambitions to make their architecture the new bedrock for complex ML systems, from chat interactions like ChatGPT to DNA sequencing and analysis. The results they display are promising, especially how they get around the quadratic scaling that self-attention suffers from.
Similar to the conclusion for YOCO, it remains to be seen whether the industry will move to adopt the new architecture. With so many people learning about Transformers, it's possible that there is some amount of inertia preventing people from wanting to study a new architecture. Nevertheless, the results speak for themselves. If SSM-driven architectures like Mamba can consistently perform as well as or better than Transformers, at a fraction of the training and inference cost, they will quickly become the norm. We could also see a merging of architectures: as we saw with Mamba, combining different block structures can lead to interesting results, and we may start to see more unique architectures arise.
Even if the Mamba architecture in particular doesn't take off, the drive toward more efficient models is a positive force, driving down costs for businesses and helping us protect our planet.
It is an exciting time to be building.
[1] Gu, A., et al., "Mamba: Linear-Time Sequence Modeling with Selective State Spaces" (2024), arXiv
[2] Bourdois, L., et al., "Introduction to State Space Models (SSM)" (2024), HuggingFace
[3] Fu, D., et al., "Hungry Hungry Hippos: Towards Language Modeling with State Space Models" (2023), arXiv
[4] Dao, T., et al., "Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality" (2024), arXiv
[5] Dumoulin, V., et al., "Convolution arithmetic - Padding strides odd transposed.gif" (2019), WikiMedia