Commit 9fb1a8f

style: clean up comments
1 parent 819c390 commit 9fb1a8f

2 files changed: +66 −39 lines changed


flax/nnx/nn/attention.py

Lines changed: 3 additions & 3 deletions
@@ -239,8 +239,8 @@ def dot_product_attention(
   query, key, value = promote_dtype((query, key, value), dtype=dtype)  # type: ignore[bad-unpacking]
   dtype = query.dtype
 
-  # GQA: Broadcast value heads to match query heads if needed.
-  # 1. Handle Key Broadcasting
+  # broadcast value heads to match query heads if needed.
+  # handle key broadcasting
   if query.ndim == key.ndim and query.shape[-2] != key.shape[-2]:
     q_heads = query.shape[-2]
     k_heads = key.shape[-2]
@@ -249,7 +249,7 @@ def dot_product_attention(
     n_rep = q_heads // k_heads
     key = jnp.repeat(key, n_rep, axis=-2)
 
-  # 2. Handle Value Broadcasting
+  # handle value broadcasting
   if query.ndim == value.ndim and query.shape[-2] != value.shape[-2]:
     q_heads = query.shape[-2]
     v_heads = value.shape[-2]
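
For context, the comments touched above sit inside the GQA (grouped-query attention) head-broadcasting block of dot_product_attention. Below is a minimal standalone sketch of what that block does; the helper name and the exact error wording are illustrative assumptions, not the module's actual API.

import jax.numpy as jnp

def broadcast_kv_heads(query, key, value):
  """Repeat key/value heads along axis -2 so they match the query's head count."""
  def maybe_repeat(x):
    if query.ndim == x.ndim and query.shape[-2] != x.shape[-2]:
      q_heads, kv_heads = query.shape[-2], x.shape[-2]
      # The tests in gqa_test.py below expect a ValueError containing "must be multiple".
      if q_heads % kv_heads != 0:
        raise ValueError(
            f'Number of query heads ({q_heads}) must be multiple of '
            f'key/value heads ({kv_heads}).'
        )
      x = jnp.repeat(x, q_heads // kv_heads, axis=-2)
    return x
  return maybe_repeat(key), maybe_repeat(value)

With a query of shape (B, T, 6, D) and key/value of shape (B, S, 3, D), each key/value head is repeated twice, after which standard dot-product attention applies unchanged.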

tests/nnx/nn/gqa_test.py

Lines changed: 63 additions & 36 deletions
@@ -1,39 +1,66 @@
-from flax import nnx
-import jax.numpy as jnp
 import jax
+import jax.numpy as jnp
+from flax import nnx
+import numpy as np
 
 class TestGQA:
-  def test_gqa_broadcasting(self):
-    # 1. Define Shapes
-    B, T, S = 2, 4, 5
-    D = 8
-
-    # GQA Config: Query=6 heads, Key/Value=3 heads (Ratio=2)
-    num_heads_q = 6
-    num_heads_kv = 3
-
-    # 2. Create Inputs
-    k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
-    query = jax.random.normal(k1, (B, T, num_heads_q, D))
-    key = jax.random.normal(k2, (B, S, num_heads_kv, D))
-    value = jax.random.normal(k3, (B, S, num_heads_kv, D))
-
-    # 3. Run Attention (Should not crash)
-    output = nnx.dot_product_attention(query, key, value)
-
-    # 4. Verify Output Shape matches Query heads (6), not Key heads (3)
-    assert output.shape == (B, T, num_heads_q, D)
-
-  def test_gqa_invalid_heads(self):
-    # Test that it raises an error if heads aren't divisible
-    B, T, D = 1, 4, 8
-    query = jnp.ones((B, T, 5, D))  # 5 heads
-    key = jnp.ones((B, T, 2, D))  # 2 heads (5 is not divisible by 2)
-    value = key
-
-    try:
-      nnx.dot_product_attention(query, key, value)
-      assert False, "Should have raised ValueError"
-    except ValueError as e:
-      # Adjusted to match the actual error message in attention.py
-      assert "must be multiple" in str(e)
+  def test_gqa_shapes(self):
+    B, T, S = 2, 4, 5
+    D = 8
+    num_heads_q = 6
+    num_heads_kv = 3
+
+    k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
+    query = jax.random.normal(k1, (B, T, num_heads_q, D))
+    key = jax.random.normal(k2, (B, S, num_heads_kv, D))
+    value = jax.random.normal(k3, (B, S, num_heads_kv, D))
+
+    output = nnx.dot_product_attention(query, key, value)
+    expected_shape = (B, T, num_heads_q, D)
+    assert output.shape == expected_shape
+
+  def test_gqa_invalid_heads(self):
+    B, T, D = 1, 4, 8
+    query = jnp.ones((B, T, 5, D))
+    key = jnp.ones((B, T, 2, D))
+    value = key
+
+    try:
+      nnx.dot_product_attention(query, key, value)
+      assert False, "Should have raised ValueError"
+    except ValueError as e:
+      # Fixed expectation to match your code's error message
+      assert "must be multiple" in str(e)
+
+  def test_gqa_parity_with_jax(self):
+    class DummyModule(nnx.Module):
+      pass
+
+    dummy_module = DummyModule()
+
+    B, T, S, D = 2, 8, 8, 16
+    num_heads_q = 4
+    num_heads_kv = 2
+
+    rng = jax.random.key(42)
+    k1, k2, k3 = jax.random.split(rng, 3)
+
+    query = jax.random.normal(k1, (B, T, num_heads_q, D))
+    key = jax.random.normal(k2, (B, S, num_heads_kv, D))
+    value = jax.random.normal(k3, (B, S, num_heads_kv, D))
+
+    # Manually repeat heads for JAX reference
+    n_rep = num_heads_q // num_heads_kv
+    key_jax = jnp.repeat(key, n_rep, axis=-2)
+    value_jax = jnp.repeat(value, n_rep, axis=-2)
+
+    jax_out = jax.nn.dot_product_attention(query, key_jax, value_jax)
+
+    # NNX should handle broadcasting internally
+    nnx_out = nnx.dot_product_attention(
+        query, key, value,
+        module=dummy_module
+    )
+
+    # Relaxed tolerance to 1e-3
+    np.testing.assert_allclose(nnx_out, jax_out, atol=1e-3, rtol=1e-3)
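
The parity test above checks nnx against jax.nn.dot_product_attention with manually repeated key/value heads. As a purely illustrative reference, not code from this repository, the same GQA computation can be written without materializing the repeated tensors by grouping query heads so that each group shares one key/value head; the head ordering below matches what jnp.repeat(..., axis=-2) produces.

import jax
import jax.numpy as jnp

def gqa_reference(query, key, value):
  """Naive GQA: query (B, T, Hq, D), key/value (B, S, Hkv, D), Hq a multiple of Hkv."""
  B, T, Hq, D = query.shape
  Hkv = key.shape[-2]
  n_rep = Hq // Hkv  # query heads served by each key/value head
  # Query head k * n_rep + g becomes group g of key/value head k.
  q = query.reshape(B, T, Hkv, n_rep, D)
  scores = jnp.einsum('btkgd,bskd->bkgts', q, key) / jnp.sqrt(D)
  weights = jax.nn.softmax(scores, axis=-1)  # normalize over the source length S
  out = jnp.einsum('bkgts,bskd->btkgd', weights, value)
  return out.reshape(B, T, Hq, D)

With the default 1/sqrt(D) scaling, this should agree, up to numerical tolerance, with repeating the key/value heads and running standard dot-product attention.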
