feat(nnx): add Grouped Query Attention (GQA) support #5180
base: main
Conversation
Hi @cgarciae. Could you please take a look at this PR when you have a moment? Thanks!
Hi @ayulockedin - thanks for the PR! Would you be able to add some tests explicitly comparing the results of nnx.dot_product_attention against jax.nn.dot_product_attention?
@samanklesaria On it, thanks!
Force-pushed 35eaf6f to 9fb1a8f
@samanklesaria Thanks for the review! I've added test_gqa_parity_with_jax to tests/nnx/nn/gqa_test.py. It forces the internal NNX Python implementation (by passing a dummy module) and compares the output against jax.nn.dot_product_attention (where I manually broadcast the GQA inputs) to ensure numerical equivalence. All tests are passing.
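For context, a minimal sketch of the kind of parity check described above (the shape values and tolerance are illustrative assumptions, and the dummy-module trick is omitted; this is not the exact test):

```python
# Hypothetical parity-test sketch; the real test additionally forces the
# internal NNX Python implementation by passing a dummy `module`.
import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx

batch, q_len, kv_len, head_dim = 2, 8, 8, 16
num_heads, num_kv_heads = 8, 2  # 4 query heads share each key/value head

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(kq, (batch, q_len, num_heads, head_dim))
key = jax.random.normal(kk, (batch, kv_len, num_kv_heads, head_dim))
value = jax.random.normal(kv, (batch, kv_len, num_kv_heads, head_dim))

# GQA path in nnx: fewer key/value heads than query heads.
out_nnx = nnx.dot_product_attention(query, key, value)

# Reference: manually broadcast the key/value heads so every query head has
# its own copy, then call the JAX primitive with equal head counts.
repeats = num_heads // num_kv_heads
out_ref = jax.nn.dot_product_attention(
    query,
    jnp.repeat(key, repeats, axis=2),
    jnp.repeat(value, repeats, axis=2),
)

np.testing.assert_allclose(out_nnx, out_ref, atol=1e-5)
```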
Force-pushed 9fb1a8f to 7ac9f52
@samanklesaria Good catch. I've refactored this so the rank assertions happen before broadcasting, which let me simplify the logic. Does it look good now?
@ayulockedin Looks like the pre-commit hooks are failing. Make sure you've set them up as described in https://flax.readthedocs.io/en/stable/contributing.html
Agreed. I removed the extra shape checks. The code now falls back to jax.nn.dot_product_attention whenever dropout_rate == 0 and module is None. @samanklesaria, thanks a lot for reviewing and helping me :)
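In rough pseudocode, the dispatch being described is something like the sketch below (the argument list is heavily simplified; only the branch condition comes from the comment above, everything else is an assumption):

```python
# Simplified dispatch sketch; the real nnx.dot_product_attention accepts
# many more arguments (bias, mask, dtype, precision, ...).
import jax

def dot_product_attention(query, key, value, *, dropout_rate=0.0, module=None):
  if dropout_rate == 0.0 and module is None:
    # Fast path: jax.nn.dot_product_attention already accepts GQA shapes,
    # i.e. the number of key/value heads may divide the number of query heads.
    return jax.nn.dot_product_attention(query, key, value)
  # Otherwise use the pure-Python NNX implementation, which is needed to
  # apply dropout or to record intermediates on `module`.
  raise NotImplementedError("fallback path omitted in this sketch")
```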
@ayulockedin Once all the tests pass I'll give it another look, but so far I don't see any major issues.
@samanklesaria There was another small typo that made the checks fail, but I've fixed it in the latest commit. It should be good to run the checks again, thanks!
Force-pushed 7233b20 to 2ba1a78
Force-pushed 2ba1a78 to c3a4286
samanklesaria left a comment
Looks good to me!
@samanklesaria Just a heads up: I've opened Issue #5198 to track the follow-up work for the MultiHeadAttention module updates. Once this functional PR lands, I plan to tackle that issue to bring full GQA parity to the class API (adding num_key_value_heads to init). Just wanted to link the two so we have a roadmap!
@ayulockedin @samanklesaria let's move the tests to attention_test.py.
Force-pushed e2294b6 to 4daedfe
Done! I've moved the GQA tests to attention_test.py (added as the TestGQADotProductAttention class) and removed the separate test file. I also ran the pre-commit hooks and squashed everything into a single clean commit. Ready for review! @cgarciae
What does this PR do?
This PR adds support for Grouped Query Attention (GQA) to nnx.dot_product_attention.
Previously, nnx.dot_product_attention required the number of heads in Query, Key, and Value to be identical. This caused a shape mismatch error when trying to use GQA configurations (where multiple Query heads share a single Key/Value head).
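As an illustration of the configuration this enables (the shape values below are arbitrary examples, not taken from the PR):

```python
# GQA-shaped inputs: 16 query heads attending through only 4 key/value heads.
import jax
from flax import nnx

batch, seq, head_dim = 2, 128, 64
num_query_heads, num_kv_heads = 16, 4  # each KV head is shared by 4 query heads

q = jax.numpy.zeros((batch, seq, num_query_heads, head_dim))
k = jax.numpy.zeros((batch, seq, num_kv_heads, head_dim))
v = jax.numpy.zeros((batch, seq, num_kv_heads, head_dim))

# Before this PR the mismatched head counts raised a shape error; with GQA
# support the call succeeds and returns (batch, seq, num_query_heads, head_dim).
out = nnx.dot_product_attention(q, k, v)
```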
Changes Implemented:
This change brings nnx into parity with jax.nn.dot_product_attention, enabling modern architectures (like Llama 3) to be implemented in NNX.
Fixes #5177
Checklist