@@ -129,7 +129,7 @@ def dot_product_attention_weights(
129129 query = query .reshape (query .shape [:- 2 ] + (k_heads , n_rep , query .shape [- 1 ]))
130130 # Expand Key: [..., K, H_k, D] -> [..., K, H_k, 1, D]
131131 key = jnp .expand_dims (key , axis = - 2 )
132-
132+
133133 # Contract: q(h)gd, k(h)1d -> hgqk (h=H_k, g=n_rep)
134134 einsum_str = '...qhgd,...kh1d->...hgqk'
135135 else :
@@ -140,7 +140,7 @@ def dot_product_attention_weights(
140140 # calculate attention matrix
141141 depth = query .shape [- 1 ]
142142 query = query / jnp .sqrt (depth ).astype (dtype )
143-
143+
144144 # attn weight shape is (batch..., num_heads, q_length, kv_length)
145145 attn_weights = jnp .einsum (einsum_str , query , key , precision = precision )
146146
@@ -174,7 +174,7 @@ def dot_product_attention_weights(
174174 keep_prob = 1.0 - dropout_rate
175175 # Note: We use original key.ndim because we might have expanded key dim
176176 ndim_base = key .ndim - 1 if is_gqa else key .ndim
177-
177+
178178 if broadcast_dropout :
179179 # dropout is broadcast across the batch + head dimensions
180180 dropout_shape = tuple ([1 ] * (ndim_base - 2 )) + attn_weights .shape [- 2 :]
@@ -261,10 +261,9 @@ def dot_product_attention(
261261 ), 'q, k, v batch dims must match.'
262262 assert key .shape [- 3 ] == value .shape [- 3 ], 'k, v lengths must match.'
263263
264- # Criteria that invoke the more optimized dot product attention
265264 # We skip this optimization for GQA (mismatched heads) to use manual broadcasting
266- if ( dropout_rate == 0.0 and module == None and
267- query . shape [ - 2 ] == key . shape [ - 2 ] == value . shape [ - 2 ]) :
265+ # Criteria that invoke the more optimized dot product attention
266+ if dropout_rate == 0.0 and module is None :
268267 # make sure qkv batch are compressed to one dim
269268 query_shape = query .shape
270269 if len (query_shape ) > 4 :
@@ -303,7 +302,7 @@ def reshape_4d(x):
303302 v_heads = value .shape [- 2 ]
304303 if q_heads % v_heads != 0 :
305304 raise ValueError (f"Query heads ({ q_heads } ) must be multiple of Value heads ({ v_heads } )" )
306-
305+
307306 n_rep = q_heads // v_heads
308307 # Reshape weights: [..., H_v, n_rep, Q, K]
309308 attn_weights = attn_weights .reshape (attn_weights .shape [:- 3 ] + (v_heads , n_rep ) + attn_weights .shape [- 2 :])
0 commit comments