Skip to content

Commit 03f8928

Browse files
committed
Refactor GQA: Replace jnp.repeat with einsum broadcasting for memory efficiency
1 parent 89c08b4 commit 03f8928

File tree

1 file changed

+50
-36
lines changed

1 file changed

+50
-36
lines changed

flax/nnx/nn/attention.py

Lines changed: 50 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -108,31 +108,44 @@ def dot_product_attention_weights(
108108
dtype = query.dtype
109109

110110
assert query.ndim == key.ndim, 'q, k must have same rank.'
111+
assert query.shape[:-3] == key.shape[:-3], 'q, k batch dims must match.'
112+
assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'
111113

112114
# check if we need to broadcast Key heads to match Query heads
115+
is_gqa = False
113116
if query.shape[-2] != key.shape[-2]:
114117
q_heads = query.shape[-2]
115118
k_heads = key.shape[-2]
116119

117120
if q_heads % k_heads != 0:
118121
raise ValueError(
119-
f"Query heads ({q_heads}) must be a multiple of "
122+
f"Query heads ({q_heads}) must be multiple of "
120123
f"Key heads ({k_heads}) for Grouped Query Attention."
121124
)
122125

123126
n_rep = q_heads // k_heads
124-
key = jnp.repeat(key, n_rep, axis=-2)
127+
is_gqa = True
128+
# Reshape Query: [..., Q, H_k * n_rep, D] -> [..., Q, H_k, n_rep, D]
129+
query = query.reshape(query.shape[:-2] + (k_heads, n_rep, query.shape[-1]))
130+
# Expand Key: [..., K, H_k, D] -> [..., K, H_k, 1, D]
131+
key = jnp.expand_dims(key, axis=-2)
132+
133+
# Contract: q(h)gd, k(h)1d -> hgqk (h=H_k, g=n_rep)
134+
einsum_str = '...qhgd,...kh1d->...hgqk'
135+
else:
136+
q_heads = query.shape[-2]
137+
einsum_str = '...qhd,...khd->...hqk'
138+
assert query.shape[-2] == key.shape[-2], 'q, k num_heads must match.'
125139

126-
assert query.shape[:-3] == key.shape[:-3], 'q, k batch dims must match.'
127-
assert query.shape[-2] == key.shape[-2], 'q, k num_heads must match.'
128-
assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'
129140
# calculate attention matrix
130141
depth = query.shape[-1]
131142
query = query / jnp.sqrt(depth).astype(dtype)
143+
132144
# attn weight shape is (batch..., num_heads, q_length, kv_length)
133-
attn_weights = jnp.einsum(
134-
'...qhd,...khd->...hqk', query, key, precision=precision
135-
)
145+
attn_weights = jnp.einsum(einsum_str, query, key, precision=precision)
146+
147+
if is_gqa:
148+
attn_weights = attn_weights.reshape(attn_weights.shape[:-4] + (q_heads, attn_weights.shape[-2], attn_weights.shape[-1]))
136149

137150
# apply attention bias: masking, dropout, proximity bias, etc.
138151
if bias is not None:
@@ -159,9 +172,12 @@ def dot_product_attention_weights(
159172
# apply attention dropout
160173
if not deterministic and dropout_rate > 0.0:
161174
keep_prob = 1.0 - dropout_rate
175+
# Note: We use original key.ndim because we might have expanded key dim
176+
ndim_base = key.ndim - 1 if is_gqa else key.ndim
177+
162178
if broadcast_dropout:
163179
# dropout is broadcast across the batch + head dimensions
164-
dropout_shape = tuple([1] * (key.ndim - 2)) + attn_weights.shape[-2:]
180+
dropout_shape = tuple([1] * (ndim_base - 2)) + attn_weights.shape[-2:]
165181
keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape) # type: ignore
166182
else:
167183
keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape) # type: ignore
@@ -240,36 +256,15 @@ def dot_product_attention(
240256
dtype = query.dtype
241257

242258
assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
243-
244-
# broadcast value heads to match query heads if needed.
245-
# handle key broadcasting
246-
if query.shape[-2] != key.shape[-2]:
247-
q_heads = query.shape[-2]
248-
k_heads = key.shape[-2]
249-
if q_heads % k_heads != 0:
250-
raise ValueError(f"Query heads ({q_heads}) must be multiple of Key heads ({k_heads})")
251-
n_rep = q_heads // k_heads
252-
key = jnp.repeat(key, n_rep, axis=-2)
253-
254-
# handle value broadcasting
255-
if query.shape[-2] != value.shape[-2]:
256-
q_heads = query.shape[-2]
257-
v_heads = value.shape[-2]
258-
if q_heads % v_heads != 0:
259-
raise ValueError(f"Query heads ({q_heads}) must be multiple of Value heads ({v_heads})")
260-
n_rep = q_heads // v_heads
261-
value = jnp.repeat(value, n_rep, axis=-2)
262-
263259
assert (
264260
query.shape[:-3] == key.shape[:-3] == value.shape[:-3]
265261
), 'q, k, v batch dims must match.'
266-
assert (
267-
query.shape[-2] == key.shape[-2] == value.shape[-2]
268-
), 'q, k, v num_heads must match.'
269262
assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.'
270263

271264
# Criteria that invoke the more optimized dot product attention
272-
if dropout_rate == 0.0 and module == None:
265+
# We skip this optimization for GQA (mismatched heads) to use manual broadcasting
266+
if (dropout_rate == 0.0 and module == None and
267+
query.shape[-2] == key.shape[-2] == value.shape[-2]):
273268
# make sure qkv batch are compressed to one dim
274269
query_shape = query.shape
275270
if len(query_shape) > 4:
@@ -302,9 +297,28 @@ def reshape_4d(x):
302297
)
303298

304299
# return weighted sum over values for each query position
305-
return jnp.einsum(
306-
'...hqk,...khd->...qhd', attn_weights, value, precision=precision
307-
)
300+
# Check if we need to broadcast Value heads to match Query heads (GQA)
301+
if attn_weights.shape[-3] != value.shape[-2]:
302+
q_heads = attn_weights.shape[-3]
303+
v_heads = value.shape[-2]
304+
if q_heads % v_heads != 0:
305+
raise ValueError(f"Query heads ({q_heads}) must be multiple of Value heads ({v_heads})")
306+
307+
n_rep = q_heads // v_heads
308+
# Reshape weights: [..., H_v, n_rep, Q, K]
309+
attn_weights = attn_weights.reshape(attn_weights.shape[:-3] + (v_heads, n_rep) + attn_weights.shape[-2:])
310+
# Expand Value: [..., K, H_v, 1, D]
311+
value = jnp.expand_dims(value, axis=-2)
312+
# Contract: hgqk, kh1d -> qhgd (h=H_v, g=n_rep)
313+
out = jnp.einsum('...hgqk,...kh1d->...qhgd', attn_weights, value, precision=precision)
314+
# Flatten: [..., Q, H_q, D]
315+
out = out.reshape(out.shape[:-3] + (q_heads, out.shape[-1]))
316+
else:
317+
out = jnp.einsum(
318+
'...hqk,...khd->...qhd', attn_weights, value, precision=precision
319+
)
320+
321+
return out
308322

309323

310324
class MultiHeadAttention(Module):

0 commit comments

Comments (0)