
Conversation

@ayulockedin
Contributor

@ayulockedin commented Jan 7, 2026

What does this PR do?

This PR adds support for Grouped Query Attention (GQA) to nnx.dot_product_attention.

Previously, nnx.dot_product_attention required the number of heads in Query, Key, and Value to be identical. This caused a shape mismatch error when trying to use GQA configurations (where multiple Query heads share a single Key/Value head).

Changes Implemented:

  • Added broadcasting logic in dot_product_attention_weights to repeat Key heads to match Query heads.
  • Added broadcasting logic in dot_product_attention to repeat Value heads to match the Attention Weights.
  • Added a validation check to ensure query_heads is divisible by key_heads (raising a clear ValueError if not).
  • Added a new test file tests/nnx/nn/gqa_test.py covering valid GQA shapes and invalid configuration handling.

This change brings nnx into parity with jax.nn.dot_product_attention, enabling modern architectures (like Llama 3) to be implemented in NNX.
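For illustration, a minimal sketch of the broadcasting idea described above (the helper name and shapes here are hypothetical; this is not the exact code added in the PR):

```python
import jax.numpy as jnp

def repeat_kv_heads(key, value, num_query_heads):
  """Repeat key/value heads so GQA inputs match the query head count.

  key/value have shape [batch, seq_len, num_kv_heads, head_dim].
  """
  num_kv_heads = key.shape[-2]
  if num_query_heads % num_kv_heads != 0:
    raise ValueError(
        f'Number of query heads ({num_query_heads}) must be divisible by '
        f'number of key/value heads ({num_kv_heads}).'
    )
  reps = num_query_heads // num_kv_heads
  # Each K/V head is repeated `reps` times along the heads axis, so a group
  # of `reps` consecutive query heads shares one original key/value head.
  return jnp.repeat(key, reps, axis=-2), jnp.repeat(value, reps, axis=-2)
```

With broadcasting like this in place, a Llama-style configuration (e.g. 32 query heads over 8 key/value heads) no longer triggers a shape mismatch.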

Fixes #5177

Checklist

  • This PR fixes a minor issue (e.g.: typo or small bug) or improves the docs (you can dismiss the other checks if that's the case).
  • This change is discussed in a Github issue/discussion (please add a link).
  • The documentation and docstrings adhere to the documentation guidelines.
  • This change includes necessary high-coverage tests. (No quality testing = no merge!)

@ayulockedin
Contributor Author

Hi @cgarciae . Could you please take a look at this PR when you have a moment? Thanks!

@samanklesaria
Collaborator

Hi @ayulockedin - thanks for the PR! Would you be able to add some tests explicitly comparing the results of nnx.dot_product_attention to jax.nn.dot_product_attention to make sure they produce the same results for some random inputs? You'll need to give the nnx one a module argument so it doesn't just call the jax one.

@ayulockedin
Contributor Author

@samanklesaria On it, thanks!

@ayulockedin force-pushed the feat/nnx-gqa-support branch from 35eaf6f to 9fb1a8f on January 8, 2026 at 15:45
@ayulockedin
Contributor Author

@samanklesaria Thanks for the review! I've added test_gqa_parity_with_jax to tests/nnx/nn/gqa_test.py.

It forces the internal NNX python implementation (by passing a dummy module) and compares the output against jax.nn.dot_product_attention (where I manually broadcast the GQA inputs) to ensure numerical equivalence. All tests are passing.
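For readers following along, a rough sketch of what such a parity test could look like. The shapes, the `Dummy` placeholder module, and the tolerance are illustrative assumptions, not the exact contents of gqa_test.py, and whether a bare placeholder module suffices to force the python path is itself an assumption based on the description above:

```python
import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx


class Dummy(nnx.Module):
  """Placeholder module; passing it keeps nnx from delegating to jax.nn."""


def test_gqa_parity_with_jax():
  batch, seq, q_heads, kv_heads, head_dim = 2, 8, 8, 2, 16
  k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
  q = jax.random.normal(k1, (batch, seq, q_heads, head_dim))
  k = jax.random.normal(k2, (batch, seq, kv_heads, head_dim))
  v = jax.random.normal(k3, (batch, seq, kv_heads, head_dim))

  # GQA-shaped inputs go straight into the NNX python implementation.
  out_nnx = nnx.dot_product_attention(q, k, v, module=Dummy())

  # Reference: broadcast K/V heads manually, then call the JAX primitive.
  reps = q_heads // kv_heads
  out_jax = jax.nn.dot_product_attention(
      q, jnp.repeat(k, reps, axis=-2), jnp.repeat(v, reps, axis=-2)
  )
  np.testing.assert_allclose(out_nnx, out_jax, atol=1e-5)
```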

@ayulockedin
Contributor Author

@samanklesaria Good catch. I've refactored this so the rank assertions happen before broadcasting, which let me simplify the logic. Does this look good now?

@samanklesaria
Collaborator

@ayulockedin Looks like pre-commit hooks are failing. Make sure you've set up pre-commit hooks as in https://flax.readthedocs.io/en/stable/contributing.html

@ayulockedin
Contributor Author

Agreed. I removed the extra shape checks. The code now falls back to jax.nn.dot_product_attention whenever dropout_rate == 0 and module is None. @samanklesaria Thanks a lot for reviewing and helping me :)
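To make the dispatch concrete, here's a rough sketch of the control flow being described. The helper `_nnx_python_attention` is a hypothetical stand-in for the pure-python NNX path, not the actual Flax source:

```python
import jax
import jax.numpy as jnp


def _nnx_python_attention(query, key, value):
  # Simplified reference path: repeat K/V heads for GQA, then softmax(QK^T / sqrt(d)) V.
  reps = query.shape[-2] // key.shape[-2]
  key = jnp.repeat(key, reps, axis=-2)
  value = jnp.repeat(value, reps, axis=-2)
  logits = jnp.einsum('...qhd,...khd->...hqk', query, key) / jnp.sqrt(query.shape[-1])
  return jnp.einsum('...hqk,...khd->...qhd', jax.nn.softmax(logits, axis=-1), value)


def dot_product_attention(query, key, value, *, dropout_rate=0.0, module=None):
  if dropout_rate == 0.0 and module is None:
    # Fast path: the JAX primitive handles GQA broadcasting itself
    # (per the parity note in the PR description).
    return jax.nn.dot_product_attention(query, key, value)
  # Dropout or intermediate sowing requested: use the python implementation.
  return _nnx_python_attention(query, key, value)
```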

@samanklesaria
Collaborator

@ayulockedin once all the tests pass I'll give it another look, but so far I don't see any major issues.

@ayulockedin
Contributor Author

@samanklesaria There was another small typo that made the checks fail, but I've fixed it in the latest commit. It should be good to run the checks again. Thanks!

Collaborator

@samanklesaria left a comment


Looks good to me!

@ayulockedin
Contributor Author

@samanklesaria Just a heads up: I've opened Issue #5198 to track the follow-up work for the MultiHeadAttention module updates.

Once this functional PR lands, I plan to tackle that issue to bring full GQA parity to the class API (adding num_key_value_heads to init). Just wanted to link the two contextually so we have a roadmap!

@cgarciae
Collaborator

@ayulockedin @samanklesaria let's move the tests to attention_test.py.

@ayulockedin
Contributor Author

Done! I've moved the GQA tests to attention_test.py (added as the TestGQADotProductAttention class) and removed the separate test file.

I also ran the pre-commit hooks and squashed everything into a single clean commit. Ready for review! @cgarciae



Development

Successfully merging this pull request may close these issues.

Support GQA in nnx dot_product_attention

3 participants