@@ -97,11 +97,17 @@ def __init__(
9797
9898 # Initialize AIE RoPE operator
9999 if self .cfg ["use_aie_rope" ]:
100- self .aie_rope = AIERope (
101- num_aie_columns = 1 ,
102- num_channels = 1 ,
100+ self .aie_rope_prefill = AIERope (
103101 size = self .prompt_length * self .head_dim ,
104102 last_dim = self .head_dim ,
103+ num_aie_columns = 1 ,
104+ method_type = 0 ,
105+ )
106+ self .aie_rope_decode = AIERope (
107+ size = self .head_dim ,
108+ last_dim = self .head_dim ,
109+ num_aie_columns = 1 ,
110+ method_type = 0 ,
105111 )
106112
107113 # Initialize fused AIE MHA operator
@@ -158,6 +164,10 @@ def forward(self, x, mask, angles, input_pos=None):
158164 is_prefill = input_pos is None
159165 is_decode = input_pos is not None
160166
167+ # Step 1.
168+ # ---
169+ # Linear projections -- calculate queries, keys and values by multiplying embedding vector (in decode) or matrix (in prefill) with weight matrices
170+
161171 # Choose between GEMM (prefill) and GEMV (decode) based on KV cache usage
162172 if self .cfg ["use_kv_cache" ] and is_decode and self .cfg ["use_aie_gemv" ]:
163173 # Decode phase with KV cache - use GEMV for single token
@@ -195,10 +205,21 @@ def forward(self, x, mask, angles, input_pos=None):
195205 keys = self .W_key (x )
196206 values = self .W_value (x )
197207
208+ # Each attention head gets its own slice of the embedding dimension.
209+ # For each head, we have query, key and value.
210+ # In grouped-query attention, the keys and values are shared across groups of heads.
211+ # Therefore, we have self.num_heads queries, and self.num_kv_groups (== self.num_heads in case of regular attention) keys and values.
212+ # Each head can be applied independently to its subslice of the embedding dimension.
198213 keys = keys .view (b , num_tokens , self .num_kv_groups , self .head_dim )
199214 values = values .view (b , num_tokens , self .num_kv_groups , self .head_dim )
200215 queries = queries .view (b , num_tokens , self .num_heads , self .head_dim )
201216
217+ # Step 2.
218+ # ---
219+ # Apply positional encoding to keys and queries.
220+ # The positional embedding is applied independently to each head.
221+ # It modifies the embedding vectors to encode where in the sequence each token is located.
222+
202223 # Determine angle slice based on KV cache usage and phase
203224 if self .cfg ["use_kv_cache" ] and is_decode :
204225 # Decode phase with KV cache: use single position
@@ -208,27 +229,28 @@ def forward(self, x, mask, angles, input_pos=None):
208229 # Prefill phase or no KV cache: use all tokens
209230 angle_slice = angles [:num_tokens , :]
210231
211- # Apply RoPE with AIE or CPU fallback
232+ # Apply RoPE with AIE
212233 def apply_rope_and_transpose (tensor , num_heads_dim , angle_slice ):
213- expected_seq_len = (
214- 1 if (self .cfg ["use_kv_cache" ] and is_decode ) else self .prompt_length
215- )
216- can_use_aie = (
217- self .cfg ["use_aie_rope" ]
218- and tensor .shape [- 1 ] == self .head_dim
219- and tensor .shape [- 2 ] == expected_seq_len
234+ transposed = (
235+ tensor .view (num_tokens , num_heads_dim , self .head_dim )
236+ .transpose (0 , 1 )
237+ .contiguous ()
220238 )
221-
222- if can_use_aie :
223- # AIE RoPE path: flatten -> apply -> reshape -> transpose
224- tensor = self .aie_rope ( tensor . view ( b , num_tokens , - 1 ) , angle_slice )
225- return tensor . view (
226- b , num_tokens , num_heads_dim , self .head_dim
227- ). transpose ( 1 , 2 )
239+ angle_slice = angle_slice . to ( dtype = tensor . dtype )
240+ if self . cfg [ "use_aie_rope" ] :
241+ if is_prefill :
242+ result = self .aie_rope_prefill ( transposed , angle_slice )
243+ else :
244+ result = self .aie_rope_decode ( transposed , angle_slice )
245+ result = result . view ( b , num_heads_dim , num_tokens , self . head_dim )
228246 else :
229- # CPU RoPE path: transpose -> apply
230- tensor = tensor .transpose (1 , 2 )
231- return apply_rope (tensor , angle_slice )
247+ result = apply_rope (
248+ transposed .view (1 , num_heads_dim , num_tokens , self .head_dim ),
249+ angle_slice ,
250+ )
251+ # ref = apply_rope(transposed.view(1, num_heads_dim, num_tokens, self.head_dim), angle_slice)
252+ # assert torch.allclose(ref, result, atol=0.7, rtol=0.07), "AIE RoPE result does not match reference"
253+ return result
232254
233255 keys = apply_rope_and_transpose (keys , self .num_kv_groups , angle_slice )
234256 queries = apply_rope_and_transpose (queries , self .num_heads , angle_slice )
@@ -248,10 +270,18 @@ def apply_rope_and_transpose(tensor, num_heads_dim, angle_slice):
248270 keys = cached_keys
249271 values = cached_values
250272
251- # Expand keys and values to match query heads for all cases (grouped query attention)
273+ # Step 3.
274+ # ---
275+ # Since the keys and values are shared across groups of heads in grouped-query attention,
276+ # we now expand (repeat) the same keys and values so that each head has its own keys and values.
252277 keys = keys .repeat_interleave (self .group_size , dim = 1 )
253278 values = values .repeat_interleave (self .group_size , dim = 1 )
254279
280+ # Step 4.
281+ # ---
282+ # Compute attention scores (independently for each head), apply softmax to get attention weights, then apply those weights to the attention values to get output.
283+ # Attention scores are the dot-product of queries and keys.
284+
255285 # Use fused AIE MHA if enabled and conditions are met
256286 if is_prefill or not self .cfg ["use_kv_cache" ]:
257287 if (
0 commit comments