This repository contains an implementation of Group Query Attention (GQA), an efficient variant of multi-head attention used in modern transformer models like LLaMA.
In standard multi-head attention, each attention head has its own query, key, and value projections. Group Query Attention optimizes this by grouping multiple query heads to share the same key and value heads, reducing computational cost and memory usage while maintaining performance.
- Efficiency: Fewer key-value computations compared to standard multi-head attention.
- Scalability: Particularly beneficial for large models with many attention heads.
- Compatibility: Can be used as a drop-in replacement in transformer architectures.
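To make the grouping concrete, here is a minimal, self-contained sketch of the core idea, independent of the `GroupQueryAttention` class in this repository: keys and values are produced for only `num_kv_groups` heads and then repeated so that each KV head is shared by several query heads. The tensor names and shapes below are illustrative, not the exact internals of `gqa.py`.

```python
import torch

# Illustrative configuration (not tied to this repo's internals)
batch_size, seq_len = 2, 10
num_heads, num_kv_groups, head_dim = 8, 2, 64
group_size = num_heads // num_kv_groups  # 4 query heads per KV group

# Queries get one head each; keys/values get only num_kv_groups heads
q = torch.randn(batch_size, num_heads, seq_len, head_dim)
k = torch.randn(batch_size, num_kv_groups, seq_len, head_dim)
v = torch.randn(batch_size, num_kv_groups, seq_len, head_dim)

# Repeat each KV head so it is shared by group_size query heads
k = k.repeat_interleave(group_size, dim=1)  # -> (batch, num_heads, seq, head_dim)
v = v.repeat_interleave(group_size, dim=1)

# Standard scaled dot-product attention on the expanded tensors
scores = (q @ k.transpose(-2, -1)) / head_dim**0.5
context = torch.softmax(scores, dim=-1) @ v
print(context.shape)  # torch.Size([2, 8, 10, 64])
```

The savings come from the smaller key/value projections and the smaller KV cache (both shrink by `num_heads // num_kv_groups`); the attention scores themselves are still computed for every query head.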
Clone the repository and install the required dependencies:
```bash
git clone https://github.com/shaheennabi/Group-Query-Attention.git
cd Group-Query-Attention
pip install -r requirements.txt
```

The main component is the `GroupQueryAttention` class in `gqa.py`. Here's a basic example:
```python
import torch
from gqa import GroupQueryAttention
# Model parameters
d_in = 512 # Input dimension
num_heads = 8 # Number of attention heads
num_kv_groups = 2 # Number of key-value groups (heads per group = 4)
head_dim = 64 # Dimension per head
# Initialize the attention layer
gqa = GroupQueryAttention(
    d_in=d_in,
    num_heads=num_heads,
    num_kv_groups=num_kv_groups,
    head_dim=head_dim,
)
# Example input
batch_size = 2
seq_len = 10
x = torch.randn(batch_size, seq_len, d_in)
# Forward pass (simplified - actual usage requires mask, cos, sin for RoPE)
# Note: This is a basic example; see code for full parameters
output, cache = gqa(x, mask=None, cos=None, sin=None)
print(output.shape)  # (batch_size, seq_len, d_in)
```

- Grouped KV Heads: Reduces KV cache size and computation.
- RoPE Support: Includes Rotary Position Embedding for better positional encoding.
- KV Caching: Efficient for autoregressive generation.
- RMSNorm: Uses Root Mean Square Layer Normalization.
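For a forward pass that actually exercises RoPE and a causal mask, the sketch below shows one common way to precompute the cos/sin tables and a boolean mask. The RoPE frequency formula, the expected cos/sin layout, and the mask convention are assumptions here; check `gqa.py` for the exact conventions this implementation uses (including how the returned cache is fed back in during autoregressive generation).

```python
import torch
from gqa import GroupQueryAttention

def rope_params(head_dim, seq_len, base=10000.0):
    # Standard RoPE angle tables; the layout expected by gqa.py may differ.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(seq_len).float(), inv_freq)  # (seq_len, head_dim // 2)
    angles = torch.cat([angles, angles], dim=-1)                   # (seq_len, head_dim)
    return angles.cos(), angles.sin()

batch_size, seq_len, d_in = 2, 10, 512
gqa = GroupQueryAttention(d_in=d_in, num_heads=8, num_kv_groups=2, head_dim=64)

cos, sin = rope_params(head_dim=64, seq_len=seq_len)
# Boolean causal mask: True marks positions that must not be attended to (assumed convention).
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

x = torch.randn(batch_size, seq_len, d_in)
output, cache = gqa(x, mask=mask, cos=cos, sin=sin)
print(output.shape)  # expected: (batch_size, seq_len, d_in)
```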
- Query Heads: `num_heads` (e.g., 8)
- KV Groups: `num_kv_groups` (e.g., 2)
- Group Size: `num_heads // num_kv_groups` (e.g., 4 queries per KV pair)
With this configuration, each KV group serves 4 query heads, which all share the same key and value projections.
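As a quick back-of-the-envelope check of the savings for this configuration (assuming the K/V projections map `d_in` to `num_kv_groups * head_dim`, which is the usual GQA layout):

```python
d_in, num_heads, num_kv_groups, head_dim = 512, 8, 2, 64

# Parameters in the K and V projection matrices (biases ignored)
mha_kv_params = 2 * d_in * (num_heads * head_dim)      # 2 * 512 * 512 = 524,288
gqa_kv_params = 2 * d_in * (num_kv_groups * head_dim)  # 2 * 512 * 128 = 131,072
print(mha_kv_params // gqa_kv_params)                  # 4x fewer KV projection parameters

# The per-token KV cache shrinks by the same factor
mha_cache = 2 * num_heads * head_dim      # 1024 cached values per token per layer
gqa_cache = 2 * num_kv_groups * head_dim  # 256 cached values per token per layer
```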
This project is licensed under the MIT License - see the LICENSE file for details.
Feel free to open issues or submit pull requests for improvements and bug fixes.