Skip to content

Commit eea7502

Browse files
committed
Refactor GQA: Use einsum broadcasting & enable JAX fast-path
1 parent 03f8928 commit eea7502

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

flax/nnx/nn/attention.py

Lines changed: 6 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -129,7 +129,7 @@ def dot_product_attention_weights(
129129
query = query.reshape(query.shape[:-2] + (k_heads, n_rep, query.shape[-1]))
130130
# Expand Key: [..., K, H_k, D] -> [..., K, H_k, 1, D]
131131
key = jnp.expand_dims(key, axis=-2)
132-
132+
133133
# Contract: q(h)gd, k(h)1d -> hgqk (h=H_k, g=n_rep)
134134
einsum_str = '...qhgd,...kh1d->...hgqk'
135135
else:
@@ -140,7 +140,7 @@ def dot_product_attention_weights(
140140
# calculate attention matrix
141141
depth = query.shape[-1]
142142
query = query / jnp.sqrt(depth).astype(dtype)
143-
143+
144144
# attn weight shape is (batch..., num_heads, q_length, kv_length)
145145
attn_weights = jnp.einsum(einsum_str, query, key, precision=precision)
146146

@@ -174,7 +174,7 @@ def dot_product_attention_weights(
174174
keep_prob = 1.0 - dropout_rate
175175
# Note: We use original key.ndim because we might have expanded key dim
176176
ndim_base = key.ndim - 1 if is_gqa else key.ndim
177-
177+
178178
if broadcast_dropout:
179179
# dropout is broadcast across the batch + head dimensions
180180
dropout_shape = tuple([1] * (ndim_base - 2)) + attn_weights.shape[-2:]
@@ -261,10 +261,9 @@ def dot_product_attention(
261261
), 'q, k, v batch dims must match.'
262262
assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.'
263263

264-
# Criteria that invoke the more optimized dot product attention
265264
# We skip this optimization for GQA (mismatched heads) to use manual broadcasting
266-
if (dropout_rate == 0.0 and module == None and
267-
query.shape[-2] == key.shape[-2] == value.shape[-2]):
265+
# Criteria that invoke the more optimized dot product attention
266+
if dropout_rate == 0.0 and module is None:
268267
# make sure qkv batch are compressed to one dim
269268
query_shape = query.shape
270269
if len(query_shape) > 4:
@@ -303,7 +302,7 @@ def reshape_4d(x):
303302
v_heads = value.shape[-2]
304303
if q_heads % v_heads != 0:
305304
raise ValueError(f"Query heads ({q_heads}) must be multiple of Value heads ({v_heads})")
306-
305+
307306
n_rep = q_heads // v_heads
308307
# Reshape weights: [..., H_v, n_rep, Q, K]
309308
attn_weights = attn_weights.reshape(attn_weights.shape[:-3] + (v_heads, n_rep) + attn_weights.shape[-2:])

0 commit comments

Comments
 (0)