Group Query Attention

This repository contains an implementation of Group Query Attention (GQA), an efficient variant of multi-head attention used in modern transformer models like LLaMA.

What is Group Query Attention?

In standard multi-head attention, each attention head has its own query, key, and value projections. Group Query Attention optimizes this by grouping multiple query heads to share the same key and value heads, reducing computational cost and memory usage while maintaining performance.

  • Efficiency: Fewer key and value projections and a smaller KV cache than standard multi-head attention (see the quick comparison below).
  • Scalability: Particularly beneficial for large models with many attention heads.
  • Compatibility: Can be used as a drop-in replacement in transformer architectures.
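
As a quick back-of-the-envelope illustration of the memory saving, the snippet below compares per-token KV-cache sizes for standard multi-head attention and GQA, using the same example numbers as the Usage section below. It is a sketch of the general idea, not a measurement of this repository's implementation.

# Example configuration (matches the Usage example below)
num_heads = 8        # query heads
num_kv_groups = 2    # shared key/value heads (groups)
head_dim = 64

# Per-token KV-cache elements (keys + values)
mha_kv = 2 * num_heads * head_dim       # standard MHA: every head stores its own K and V
gqa_kv = 2 * num_kv_groups * head_dim   # GQA: only one K/V per group is stored

print(mha_kv, gqa_kv)  # 1024 vs 256 -> 4x smaller KV cache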

Installation

Clone the repository and install the required dependencies:

git clone https://github.com/shaheennabi/Group-Query-Attention.git
cd Group-Query-Attention
pip install -r requirements.txt

Usage

The main component is the GroupQueryAttention class in gqa.py. Here's a basic example:

import torch
from gqa import GroupQueryAttention

# Model parameters
d_in = 512          # Input dimension
num_heads = 8       # Number of attention heads
num_kv_groups = 2   # Number of key-value groups (heads per group = 4)
head_dim = 64       # Dimension per head

# Initialize the attention layer
gqa = GroupQueryAttention(
    d_in=d_in,
    num_heads=num_heads,
    num_kv_groups=num_kv_groups,
    head_dim=head_dim
)

# Example input
batch_size = 2
seq_len = 10
x = torch.randn(batch_size, seq_len, d_in)

# Forward pass (simplified). The layer also expects an attention mask and
# RoPE cos/sin tensors; None is used here only to keep the example short.
# See gqa.py for the full signature, and the RoPE sketch below.
output, cache = gqa(x, mask=None, cos=None, sin=None)
print(output.shape)  # (batch_size, seq_len, d_in)
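
The mask, cos, and sin arguments above are left as None only for brevity. The sketch below shows one standard (LLaMA-style) way to precompute RoPE cos/sin tables and a causal mask; the helper name precompute_rope and the exact shapes gqa.py expects are assumptions here, not the repository's API.

import torch

# Standard RoPE table precomputation -- an assumption about what the
# cos/sin inputs look like, not necessarily this repo's convention.
def precompute_rope(head_dim, max_seq_len, base=10000.0):
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(max_seq_len).float()
    angles = torch.outer(positions, inv_freq)        # (max_seq_len, head_dim // 2)
    angles = torch.cat([angles, angles], dim=-1)     # (max_seq_len, head_dim)
    return torch.cos(angles), torch.sin(angles)

cos, sin = precompute_rope(head_dim=64, max_seq_len=10)

# Upper-triangular causal mask: position i may not attend to positions > i
mask = torch.triu(torch.ones(10, 10, dtype=torch.bool), diagonal=1)

# These would then be passed as gqa(x, mask=mask, cos=cos, sin=sin),
# assuming the shapes match what gqa.py expects.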

Key Features

  • Grouped KV Heads: Reduces KV cache size and computation.
  • RoPE Support: Applies Rotary Position Embeddings to queries and keys to encode positional information.
  • KV Caching: Efficient for autoregressive generation.
  • RMSNorm: Uses Root Mean Square Layer Normalization (a minimal version is sketched below).
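
For reference, RMSNorm normalizes by the root-mean-square of the activations instead of centering and scaling by mean and variance as LayerNorm does. A minimal version looks like this; the repository's implementation may differ in details such as eps or dtype handling.

import torch

def rms_norm(x, weight, eps=1e-6):
    # Normalize by the root mean square over the last dimension, then scale
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x / rms) * weight

x = torch.randn(2, 10, 512)
weight = torch.ones(512)          # learnable scale in the real layer
print(rms_norm(x, weight).shape)  # torch.Size([2, 10, 512])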

Architecture Details

  • Query Heads: num_heads (e.g., 8)
  • KV Groups: num_kv_groups (e.g., 2)
  • Group Size: num_heads // num_kv_groups (e.g., 4 queries per KV pair)

With these example values, each KV group's key and value projections are shared by 4 query heads, as sketched below.
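
The snippet below sketches the core grouping trick: each key/value head is repeated so it lines up with its group of query heads before attention scores are computed. It mirrors the usual GQA formulation; gqa.py may organize the tensors differently.

import torch

batch, seq_len, head_dim = 2, 10, 64
num_heads, num_kv_groups = 8, 2
group_size = num_heads // num_kv_groups          # 4 query heads per KV head

q = torch.randn(batch, num_heads, seq_len, head_dim)
k = torch.randn(batch, num_kv_groups, seq_len, head_dim)

# Repeat each KV head group_size times so it matches its query group
k_expanded = k.repeat_interleave(group_size, dim=1)    # (2, 8, 10, 64)
attn_scores = (q @ k_expanded.transpose(-2, -1)) / head_dim ** 0.5
print(attn_scores.shape)   # torch.Size([2, 8, 10, 10])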

License

This project is licensed under the MIT License - see the LICENSE file for details.

Contributing

Feel free to open issues or submit pull requests for improvements and bug fixes.
