Skip to content

Commit f34ee29

Browse files
committed
use AIE RoPE
1 parent 331dcca commit f34ee29

File tree

2 files changed

+60
-32
lines changed
  • applications/llama_3.2_1b/src/block
  • operators/rope

2 files changed

+60
-32
lines changed

applications/llama_3.2_1b/src/block/gqa.py

Lines changed: 52 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,17 @@ def __init__(
9797

9898
# Initialize AIE RoPE operator
9999
if self.cfg["use_aie_rope"]:
100-
self.aie_rope = AIERope(
101-
num_aie_columns=1,
102-
num_channels=1,
100+
self.aie_rope_prefill = AIERope(
103101
size=self.prompt_length * self.head_dim,
104102
last_dim=self.head_dim,
103+
num_aie_columns=1,
104+
method_type=0,
105+
)
106+
self.aie_rope_decode = AIERope(
107+
size=self.head_dim,
108+
last_dim=self.head_dim,
109+
num_aie_columns=1,
110+
method_type=0,
105111
)
106112

107113
# Initialize fused AIE MHA operator
@@ -158,6 +164,10 @@ def forward(self, x, mask, angles, input_pos=None):
158164
is_prefill = input_pos is None
159165
is_decode = input_pos is not None
160166

167+
# Step 1.
168+
# ---
169+
# Linear projections -- calculate queries, keys and values by multiplying embedding vector (in decode) or matrix (in prefill) with weight matrices
170+
161171
# Choose between GEMM (prefill) and GEMV (decode) based on KV cache usage
162172
if self.cfg["use_kv_cache"] and is_decode and self.cfg["use_aie_gemv"]:
163173
# Decode phase with KV cache - use GEMV for single token
@@ -195,10 +205,21 @@ def forward(self, x, mask, angles, input_pos=None):
195205
keys = self.W_key(x)
196206
values = self.W_value(x)
197207

208+
# Each attention head gets its own slice of the embedding dimension.
209+
# For each head, we have query, key and value.
210+
# In grouped-query attention, the keys and values are shared across groups of heads.
211+
# Therefore, we have self.num_heads queries, and self.num_kv_groups (== self.num_heads in case of regular attention) keys and values.
212+
# Each head can be applied independently to its subslice of the embedding dimension.
198213
keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim)
199214
values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)
200215
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
201216

217+
# Step 2.
218+
# ---
219+
# Apply positional encoding to keys and queries.
220+
# The positional embedding is applied independently to each head.
221+
# It modifies the embedding vectors to encode where in the sequence each token is located.
222+
202223
# Determine angle slice based on KV cache usage and phase
203224
if self.cfg["use_kv_cache"] and is_decode:
204225
# Decode phase with KV cache: use single position
@@ -208,27 +229,28 @@ def forward(self, x, mask, angles, input_pos=None):
208229
# Prefill phase or no KV cache: use all tokens
209230
angle_slice = angles[:num_tokens, :]
210231

211-
# Apply RoPE with AIE or CPU fallback
232+
# Apply RoPE with AIE
212233
def apply_rope_and_transpose(tensor, num_heads_dim, angle_slice):
213-
expected_seq_len = (
214-
1 if (self.cfg["use_kv_cache"] and is_decode) else self.prompt_length
215-
)
216-
can_use_aie = (
217-
self.cfg["use_aie_rope"]
218-
and tensor.shape[-1] == self.head_dim
219-
and tensor.shape[-2] == expected_seq_len
234+
transposed = (
235+
tensor.view(num_tokens, num_heads_dim, self.head_dim)
236+
.transpose(0, 1)
237+
.contiguous()
220238
)
221-
222-
if can_use_aie:
223-
# AIE RoPE path: flatten -> apply -> reshape -> transpose
224-
tensor = self.aie_rope(tensor.view(b, num_tokens, -1), angle_slice)
225-
return tensor.view(
226-
b, num_tokens, num_heads_dim, self.head_dim
227-
).transpose(1, 2)
239+
angle_slice = angle_slice.to(dtype=tensor.dtype)
240+
if self.cfg["use_aie_rope"]:
241+
if is_prefill:
242+
result = self.aie_rope_prefill(transposed, angle_slice)
243+
else:
244+
result = self.aie_rope_decode(transposed, angle_slice)
245+
result = result.view(b, num_heads_dim, num_tokens, self.head_dim)
228246
else:
229-
# CPU RoPE path: transpose -> apply
230-
tensor = tensor.transpose(1, 2)
231-
return apply_rope(tensor, angle_slice)
247+
result = apply_rope(
248+
transposed.view(1, num_heads_dim, num_tokens, self.head_dim),
249+
angle_slice,
250+
)
251+
# ref = apply_rope(transposed.view(1, num_heads_dim, num_tokens, self.head_dim), angle_slice)
252+
# assert torch.allclose(ref, result, atol=0.7, rtol=0.07), "AIE RoPE result does not match reference"
253+
return result
232254

233255
keys = apply_rope_and_transpose(keys, self.num_kv_groups, angle_slice)
234256
queries = apply_rope_and_transpose(queries, self.num_heads, angle_slice)
@@ -248,10 +270,18 @@ def apply_rope_and_transpose(tensor, num_heads_dim, angle_slice):
248270
keys = cached_keys
249271
values = cached_values
250272

251-
# Expand keys and values to match query heads for all cases (grouped query attention)
273+
# Step 3.
274+
# ---
275+
# Since the keys and values are shared across groups of heads in grouped-query attention,
276+
# we now expand (repeat) the same keys and values so that each head has its own keys and values.
252277
keys = keys.repeat_interleave(self.group_size, dim=1)
253278
values = values.repeat_interleave(self.group_size, dim=1)
254279

280+
# Step 4.
281+
# ---
282+
# Compute attention scores (independently for each head), apply softmax to get attention weights, then apply those weights to the attention values to get output.
283+
# Attention scores are the dot-product of queries and keys.
284+
255285
# Use fused AIE MHA if enabled and conditions are met
256286
if is_prefill or not self.cfg["use_kv_cache"]:
257287
if (

operators/rope/op.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def __init__(
5151
def set_up_artifacts(self):
5252
# Compilation artifacts
5353
operator_dir = Path(__file__).parent
54-
file_name_base = f"rope_{self.num_aie_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t_{self.method_type}m"
54+
file_name_base = f"rope_{self.num_aie_columns}c_{self.size}_{self.tile_size}t_{self.method_type}m"
5555

5656
mlir_artifact = PythonGeneratedMLIRArtifact.new(
5757
f"{file_name_base}.mlir",
@@ -119,7 +119,7 @@ def forward(self, x, y):
119119
and x.shape[-2:] == y.shape
120120
)
121121
if not applicable:
122-
raise AIEOPeratorConstraintError("AIERope: incompatible tensor shape(s)")
122+
raise AIEOperatorConstraintError("AIERope: incompatible tensor shape(s)")
123123

124124
original_shape = x.shape
125125
if len(x.shape) > 2:
@@ -137,6 +137,7 @@ def forward(self, x, y):
137137
batch_data = x[i:end_idx, :]
138138

139139
# Pad if necessary to match expected rows_per_batch
140+
angle_offset = i % y.shape[0]
140141
if batch_data.shape[0] < rows_per_batch:
141142
padding = torch.zeros(
142143
rows_per_batch - batch_data.shape[0],
@@ -146,12 +147,13 @@ def forward(self, x, y):
146147
)
147148
batch_data_padded = torch.cat([batch_data, padding], dim=0)
148149
result = self._process_batch(
149-
batch_data_padded, y[i % y.shape[0] : batch_size]
150+
batch_data_padded, y[angle_offset : angle_offset + rows_per_batch]
150151
)
151152
result = result[: batch_data.shape[0], :]
152153
else:
153-
result = self._process_batch(batch_data, y[i % y.shape[0] : batch_size])
154-
154+
result = self._process_batch(
155+
batch_data, y[angle_offset : angle_offset + rows_per_batch]
156+
)
155157
results.append(result)
156158

157159
# Concatenate all batch results
@@ -165,13 +167,9 @@ def forward(self, x, y):
165167

166168
def _process_batch(self, batch_data, angle_data):
167169
"""Process a batch of sequences through the AIE kernel"""
168-
batch_flat = batch_data.view(-1)
169-
170-
# Calculate buffer sizes for the batch
171-
input_size = batch_data.nbytes
172170

173171
# Write data to buffers
174-
self.write_buffer("input", batch_data)
172+
self.write_buffer("in", batch_data)
175173
self.write_buffer("angles", angle_data)
176174
test_pattern = np.zeros(len(batch_data), dtype=bfloat16)
177175
self.write_buffer("output", test_pattern)

0 commit comments

Comments
 (0)