@@ -108,31 +108,44 @@ def dot_product_attention_weights(
108108 dtype = query .dtype
109109
110110 assert query .ndim == key .ndim , 'q, k must have same rank.'
111+ assert query .shape [:- 3 ] == key .shape [:- 3 ], 'q, k batch dims must match.'
112+ assert query .shape [- 1 ] == key .shape [- 1 ], 'q, k depths must match.'
111113
112114 # check if we need to broadcast Key heads to match Query heads
115+ is_gqa = False
113116 if query .shape [- 2 ] != key .shape [- 2 ]:
114117 q_heads = query .shape [- 2 ]
115118 k_heads = key .shape [- 2 ]
116119
117120 if q_heads % k_heads != 0 :
118121 raise ValueError (
119- f"Query heads ({ q_heads } ) must be a multiple of "
122+ f"Query heads ({ q_heads } ) must be multiple of "
120123 f"Key heads ({ k_heads } ) for Grouped Query Attention."
121124 )
122125
123126 n_rep = q_heads // k_heads
124- key = jnp .repeat (key , n_rep , axis = - 2 )
127+ is_gqa = True
128+ # Reshape Query: [..., Q, H_k * n_rep, D] -> [..., Q, H_k, n_rep, D]
129+ query = query .reshape (query .shape [:- 2 ] + (k_heads , n_rep , query .shape [- 1 ]))
130+ # Expand Key: [..., K, H_k, D] -> [..., K, H_k, 1, D]
131+ key = jnp .expand_dims (key , axis = - 2 )
132+
133+ # Contract: q(h)gd, k(h)1d -> hgqk (h=H_k, g=n_rep)
134+ einsum_str = '...qhgd,...kh1d->...hgqk'
135+ else :
136+ q_heads = query .shape [- 2 ]
137+ einsum_str = '...qhd,...khd->...hqk'
138+ assert query .shape [- 2 ] == key .shape [- 2 ], 'q, k num_heads must match.'
125139
126- assert query .shape [:- 3 ] == key .shape [:- 3 ], 'q, k batch dims must match.'
127- assert query .shape [- 2 ] == key .shape [- 2 ], 'q, k num_heads must match.'
128- assert query .shape [- 1 ] == key .shape [- 1 ], 'q, k depths must match.'
129140 # calculate attention matrix
130141 depth = query .shape [- 1 ]
131142 query = query / jnp .sqrt (depth ).astype (dtype )
143+
132144 # attn weight shape is (batch..., num_heads, q_length, kv_length)
133- attn_weights = jnp .einsum (
134- '...qhd,...khd->...hqk' , query , key , precision = precision
135- )
145+ attn_weights = jnp .einsum (einsum_str , query , key , precision = precision )
146+
147+ if is_gqa :
148+ attn_weights = attn_weights .reshape (attn_weights .shape [:- 4 ] + (q_heads , attn_weights .shape [- 2 ], attn_weights .shape [- 1 ]))
136149
137150 # apply attention bias: masking, dropout, proximity bias, etc.
138151 if bias is not None :
@@ -159,9 +172,12 @@ def dot_product_attention_weights(
159172 # apply attention dropout
160173 if not deterministic and dropout_rate > 0.0 :
161174 keep_prob = 1.0 - dropout_rate
175+ # Note: We use original key.ndim because we might have expanded key dim
176+ ndim_base = key .ndim - 1 if is_gqa else key .ndim
177+
162178 if broadcast_dropout :
163179 # dropout is broadcast across the batch + head dimensions
164- dropout_shape = tuple ([1 ] * (key . ndim - 2 )) + attn_weights .shape [- 2 :]
180+ dropout_shape = tuple ([1 ] * (ndim_base - 2 )) + attn_weights .shape [- 2 :]
165181 keep = random .bernoulli (dropout_rng , keep_prob , dropout_shape ) # type: ignore
166182 else :
167183 keep = random .bernoulli (dropout_rng , keep_prob , attn_weights .shape ) # type: ignore
@@ -240,36 +256,15 @@ def dot_product_attention(
240256 dtype = query .dtype
241257
242258 assert key .ndim == query .ndim == value .ndim , 'q, k, v must have same rank.'
243-
244- # broadcast value heads to match query heads if needed.
245- # handle key broadcasting
246- if query .shape [- 2 ] != key .shape [- 2 ]:
247- q_heads = query .shape [- 2 ]
248- k_heads = key .shape [- 2 ]
249- if q_heads % k_heads != 0 :
250- raise ValueError (f"Query heads ({ q_heads } ) must be multiple of Key heads ({ k_heads } )" )
251- n_rep = q_heads // k_heads
252- key = jnp .repeat (key , n_rep , axis = - 2 )
253-
254- # handle value broadcasting
255- if query .shape [- 2 ] != value .shape [- 2 ]:
256- q_heads = query .shape [- 2 ]
257- v_heads = value .shape [- 2 ]
258- if q_heads % v_heads != 0 :
259- raise ValueError (f"Query heads ({ q_heads } ) must be multiple of Value heads ({ v_heads } )" )
260- n_rep = q_heads // v_heads
261- value = jnp .repeat (value , n_rep , axis = - 2 )
262-
263259 assert (
264260 query .shape [:- 3 ] == key .shape [:- 3 ] == value .shape [:- 3 ]
265261 ), 'q, k, v batch dims must match.'
266- assert (
267- query .shape [- 2 ] == key .shape [- 2 ] == value .shape [- 2 ]
268- ), 'q, k, v num_heads must match.'
269262 assert key .shape [- 3 ] == value .shape [- 3 ], 'k, v lengths must match.'
270263
271264 # Criteria that invoke the more optimized dot product attention
272- if dropout_rate == 0.0 and module == None :
265+ # We skip this optimization for GQA (mismatched heads) to use manual broadcasting
266+ if (dropout_rate == 0.0 and module == None and
267+ query .shape [- 2 ] == key .shape [- 2 ] == value .shape [- 2 ]):
273268 # make sure qkv batch are compressed to one dim
274269 query_shape = query .shape
275270 if len (query_shape ) > 4 :
@@ -302,9 +297,28 @@ def reshape_4d(x):
302297 )
303298
304299 # return weighted sum over values for each query position
305- return jnp .einsum (
306- '...hqk,...khd->...qhd' , attn_weights , value , precision = precision
307- )
300+ # Check if we need to broadcast Value heads to match Query heads (GQA)
301+ if attn_weights .shape [- 3 ] != value .shape [- 2 ]:
302+ q_heads = attn_weights .shape [- 3 ]
303+ v_heads = value .shape [- 2 ]
304+ if q_heads % v_heads != 0 :
305+ raise ValueError (f"Query heads ({ q_heads } ) must be multiple of Value heads ({ v_heads } )" )
306+
307+ n_rep = q_heads // v_heads
308+ # Reshape weights: [..., H_v, n_rep, Q, K]
309+ attn_weights = attn_weights .reshape (attn_weights .shape [:- 3 ] + (v_heads , n_rep ) + attn_weights .shape [- 2 :])
310+ # Expand Value: [..., K, H_v, 1, D]
311+ value = jnp .expand_dims (value , axis = - 2 )
312+ # Contract: hgqk, kh1d -> qhgd (h=H_v, g=n_rep)
313+ out = jnp .einsum ('...hgqk,...kh1d->...qhgd' , attn_weights , value , precision = precision )
314+ # Flatten: [..., Q, H_q, D]
315+ out = out .reshape (out .shape [:- 3 ] + (q_heads , out .shape [- 1 ]))
316+ else :
317+ out = jnp .einsum (
318+ '...hqk,...khd->...qhd' , attn_weights , value , precision = precision
319+ )
320+
321+ return out
308322
309323
310324class MultiHeadAttention (Module ):