Grouped-Query Attention (GQA) Explained
Last Updated on December 30, 2023 by Editorial Team
Author(s): Florian
Originally published on Towards AI.
From Principles to Llama2 Code Explanation
The standard practice for autoregressive decoding is to cache the keys and values of the previous tokens in the sequence. This way they are not recomputed at every step, and attention is computed faster. However, the key-value cache (KV cache) grows linearly with both context length and batch size. As the context window or batch size increases, its memory cost in multi-head attention (MHA) models rises significantly.
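To make the growth concrete, here is a back-of-the-envelope sketch in Python. It assumes the usual layout in which every layer caches one key vector and one value vector per KV head per token, stored in fp16/bf16 (2 bytes per element). The function name `kv_cache_bytes` and the 32-layer, 32-head, head_dim 128 configuration are hypothetical and chosen only for illustration.

```python
def kv_cache_bytes(n_layers: int, n_kv_heads: int, head_dim: int,
                   seq_len: int, batch_size: int, bytes_per_elem: int = 2) -> int:
    """Bytes needed to cache keys and values for every layer, head, and token.

    The leading 2 counts keys and values separately; bytes_per_elem=2
    assumes fp16/bf16 storage. For MHA, n_kv_heads equals the number of
    attention heads.
    """
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch_size * bytes_per_elem


# Hypothetical MHA configuration for illustration: 32 layers, 32 heads, head_dim 128.
config = {"n_layers": 32, "n_kv_heads": 32, "head_dim": 128}

for seq_len, batch_size in [(1024, 1), (4096, 1), (4096, 8)]:
    gib = kv_cache_bytes(**config, seq_len=seq_len, batch_size=batch_size) / 2**30
    print(f"seq_len={seq_len}, batch_size={batch_size}: {gib:.1f} GiB")
# seq_len=1024, batch_size=1: 0.5 GiB
# seq_len=4096, batch_size=1: 2.0 GiB
# seq_len=4096, batch_size=8: 16.0 GiB
```

Every factor enters multiplicatively, so the cache scales linearly with context length and with batch size. The `n_kv_heads` factor is the one that MQA and GQA reduce.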
Multi-query attention (MQA) shares a single key-value head across all query heads. This shrinks the KV cache and greatly speeds up decoder inference.
However, MQA can degrade model quality. What we actually want is fast inference with quality on par with MHA, and this is where grouped-query attention (GQA) [1] comes into play.
Grouped-query attention (GQA) interpolates between multi-query and multi-head attention. The query heads are split into groups, and each group shares a single key-value head. It achieves quality close to multi-head attention while keeping speed comparable to multi-query attention.
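A minimal NumPy sketch of the grouping idea follows. It assumes a single sequence (no batch dimension), no causal mask, and no KV cache. Query heads are split into `n_kv_heads` contiguous groups, and every head in a group attends using that group's shared key and value head. The function names `grouped_query_attention` and `softmax` are hypothetical and meant for illustration. This is not Llama2's code.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the row max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def grouped_query_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Scaled dot-product attention where query heads share KV heads in groups.

    q:    (n_heads, seq_q, head_dim)
    k, v: (n_kv_heads, seq_k, head_dim), with n_heads divisible by n_kv_heads
    returns (n_heads, seq_q, head_dim)
    """
    n_heads, seq_q, head_dim = q.shape
    n_kv_heads = k.shape[0]
    assert n_heads % n_kv_heads == 0, "n_heads must be a multiple of n_kv_heads"
    group_size = n_heads // n_kv_heads

    # Contiguous query heads form one group: (n_kv_heads, group_size, seq_q, head_dim).
    qg = q.reshape(n_kv_heads, group_size, seq_q, head_dim)

    # Every query head h in group g is scored against the same k[g] ...
    scores = np.einsum("ghsd,gtd->ghst", qg, k) / np.sqrt(head_dim)
    weights = softmax(scores, axis=-1)
    # ... and mixes the same v[g].
    out = np.einsum("ghst,gtd->ghsd", weights, v)
    return out.reshape(n_heads, seq_q, head_dim)


# Usage: 8 query heads sharing 2 KV heads (4 query heads per group).
rng = np.random.default_rng(0)
n_heads, n_kv_heads, seq, head_dim = 8, 2, 5, 16
q = rng.standard_normal((n_heads, seq, head_dim))
k = rng.standard_normal((n_kv_heads, seq, head_dim))
v = rng.standard_normal((n_kv_heads, seq, head_dim))
out = grouped_query_attention(q, k, v)
print(out.shape)  # (8, 5, 16)
```

Setting `n_kv_heads = 1` makes every query head share one key-value head (MQA). Setting `n_kv_heads = n_heads` gives each query head its own (MHA).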
Because GQA is relatively new, many well-known large language models were released before it and do not use it. Since its proposal, however, it has been adopted by popular models such as Llama2 [2] (in its larger variants) and Mistral 7B [3].
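As we understand it, Llama2's reference PyTorch code implements GQA with a small helper, `repeat_kv`. This helper copies each key-value head `n_rep = n_heads // n_kv_heads` times so the rest of the attention code can treat K and V as having one head per query head. The following is a NumPy re-expression of that idea using a (batch, seq_len, n_kv_heads, head_dim) layout. It is a sketch, not Meta's source.

```python
import numpy as np


def repeat_kv(x: np.ndarray, n_rep: int) -> np.ndarray:
    """Expand (batch, seq_len, n_kv_heads, head_dim) to
    (batch, seq_len, n_kv_heads * n_rep, head_dim).

    Each KV head is repeated n_rep times in place, so output head j
    holds a copy of KV head j // n_rep.
    """
    if n_rep == 1:
        return x
    bs, slen, n_kv_heads, head_dim = x.shape
    # Insert a repeat axis after the head axis, broadcast it to n_rep,
    # then merge it into the head axis (reshape copies the broadcast view).
    expanded = np.broadcast_to(x[:, :, :, None, :], (bs, slen, n_kv_heads, n_rep, head_dim))
    return expanded.reshape(bs, slen, n_kv_heads * n_rep, head_dim)


# Usage: batch=2, seq_len=3, 2 KV heads of dim 4, expanded for 8 query heads.
x = np.arange(2 * 3 * 2 * 4).reshape(2, 3, 2, 4)
y = repeat_kv(x, n_rep=4)
print(y.shape)  # (2, 3, 8, 4)
```

In this approach the repetition happens after keys and values are read from the cache. The cache itself therefore only needs to hold `n_kv_heads` heads, and the expanded copies are temporary.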
GQA can be seen as an intermediate or generalized form of MQA and MHA:
When there is only one group in GQA, it is equivalent to MQA. When the number of groups in GQA is equal…
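The range between the two extremes can be made concrete by listing which query heads share each key-value head as the number of groups changes. This sketch assumes a contiguous grouping: query head `h` belongs to group `h // (n_heads // n_groups)`. The helper `query_head_groups` is hypothetical.

```python
def query_head_groups(n_heads: int, n_groups: int) -> list[list[int]]:
    """For each group (one shared KV head), list the query heads that use it.

    Assumes contiguous grouping: query head h is in group h // (n_heads // n_groups).
    """
    if n_heads % n_groups != 0:
        raise ValueError("n_heads must be divisible by n_groups")
    size = n_heads // n_groups
    return [list(range(g * size, (g + 1) * size)) for g in range(n_groups)]


n_heads = 8
for n_groups, label in [(1, "MQA"), (2, "GQA"), (n_heads, "MHA")]:
    print(f"{label} (groups={n_groups}): {query_head_groups(n_heads, n_groups)}")
# MQA (groups=1): [[0, 1, 2, 3, 4, 5, 6, 7]]
# GQA (groups=2): [[0, 1, 2, 3], [4, 5, 6, 7]]
# MHA (groups=8): [[0], [1], [2], [3], [4], [5], [6], [7]]
```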