Skip to content

Commit d48746f

Browse files
authored
Support larger matrices in GEMV and offload last-layer GEMV in llama (#64)
* allow larger matrices for GEMV and separate out input/output tile size * fix tests * address comments
1 parent 331dcca commit d48746f

File tree

12 files changed

+227
-111
lines changed

12 files changed

+227
-111
lines changed

aie_kernels/generic/mv.cc

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,27 +15,39 @@
1515

1616
#include <aie_api/aie.hpp>
1717

18-
void matvec_scalar(uint32_t m, uint32_t k, uint32_t row_offset, bfloat16 *a, bfloat16 *b, bfloat16 *c)
18+
void matvec_scalar(uint32_t m,
19+
uint32_t k,
20+
const bfloat16 *__restrict a,
21+
const bfloat16 *__restrict b,
22+
bfloat16 *__restrict c)
1923
{
2024
for (uint32_t row = 0; row < m; row++) {
2125
float acc = 0;
2226
for (uint32_t i = 0; i < k; i++) {
2327
acc += a[row * k + i] * b[i];
2428
}
25-
c[row + row_offset * m] = static_cast<bfloat16>(acc);
29+
c[row] = static_cast<bfloat16>(acc);
2630
}
2731
}
2832

33+
/*
34+
Matrix-vector multiplication kernel
35+
36+
- m: Number of output rows == number of rows in the input matrix
37+
- k: Number of columns in the input matrix == length of the input vector
38+
- a: Pointer to the input matrix, stored in row-major order
39+
- b: Pointer to the input vector
40+
- c: Pointer to the output vector
41+
- r: Vector size; data from the matrix and vector will be loaded in and processed in chunks of this size
42+
*/
2943
template <uint32_t r>
3044
void matvec_vectorized(uint32_t m,
3145
uint32_t k,
32-
uint32_t row_offset,
3346
const bfloat16 *__restrict a,
3447
const bfloat16 *__restrict b,
3548
bfloat16 *__restrict c)
3649
{
3750
::aie::set_rounding(aie::rounding_mode::conv_even);
38-
c += row_offset * m;
3951
bfloat16 *c_end = c + m;
4052
const bfloat16 *b_end = b + k;
4153
for (; c < c_end; c++) {
@@ -55,24 +67,30 @@ void matvec_vectorized(uint32_t m,
5567

5668
extern "C" {
5769

70+
/* The row offset parameter in the functions below is a workaround. The output will be written to c + row_offset * m.
71+
* This is simpler than to do pointer arithmetic in the calling MLIR code, but that's all this is for -- an offset into
72+
* `c`. */
73+
5874
void matvec_scalar_bf16_bf16(uint32_t m,
5975
uint32_t k,
6076
uint32_t row_offset,
61-
bfloat16 *a_in,
62-
bfloat16 *b_in,
63-
bfloat16 *c_out)
77+
const bfloat16 *__restrict a_in,
78+
const bfloat16 *__restrict b_in,
79+
bfloat16 *__restrict c_out)
6480
{
65-
matvec_scalar(m, k, row_offset, a_in, b_in, c_out);
81+
c_out += row_offset;
82+
matvec_scalar(m, k, a_in, b_in, c_out);
6683
}
6784

6885
void matvec_vectorized_bf16_bf16(uint32_t m,
6986
uint32_t k,
7087
uint32_t row_offset,
71-
bfloat16 *a_in,
72-
bfloat16 *b_in,
73-
bfloat16 *c_out)
88+
const bfloat16 *__restrict a_in,
89+
const bfloat16 *__restrict b_in,
90+
bfloat16 *__restrict c_out)
7491
{
75-
matvec_vectorized<64>(m, k, row_offset, a_in, b_in, c_out);
92+
c_out += row_offset;
93+
matvec_vectorized<64>(m, k, a_in, b_in, c_out);
7694
}
7795

7896
} // extern "C"

applications/llama_3.2_1b/configs/llama32_1b.json

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,24 @@
88
"hidden_dim": 8192,
99
"n_kv_groups": 8,
1010
"use_kv_cache": true,
11-
"use_aie_gemv": true,
1211
"rope_base": 500000.0,
1312
"dtype": "bfloat16",
1413
"use_aie_final_norm": true,
1514
"use_aie_ffn_gemm": false,
1615
"use_aie_ffn_silu": false,
1716
"use_aie_ffn_mul": false,
1817
"use_aie_ffn_swiglu": true,
18+
"use_aie_ffn_gemv": true,
1919
"use_aie_attn_projection_gemm": true,
20+
"use_aie_gqa_gemv": true,
2021
"use_aie_rope": true,
2122
"use_aie_norm1": true,
2223
"use_aie_norm2": true,
2324
"use_aie_residual": true,
2425
"use_aie_regular_mha": false,
2526
"use_aie_fused_mha": true,
2627
"use_aie_final_gemm": true,
28+
"use_aie_final_gemv": true,
2729
"rope_freq": {
2830
"factor": 32.0,
2931
"low_freq_factor": 1.0,

applications/llama_3.2_1b/src/block/feed_forward.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -115,18 +115,30 @@ def __init__(
115115
cfg["hidden_dim"], cfg["emb_dim"], dtype=cfg["dtype"], bias=False
116116
)
117117

118-
if self.cfg["use_kv_cache"] and self.cfg["use_aie_gemv"]:
119-
aie_gemv_config = {"num_aie_columns": 1, "is_mv": False}
118+
if self.cfg["use_kv_cache"] and self.cfg["use_aie_ffn_gemv"]:
119+
aie_gemv_config = {"num_aie_columns": 8, "is_mv": False}
120120
# FC1 and FC2: emb_dim -> hidden_dim
121121
self.aie_fc1_gemv = AIEGEMV(
122-
M=self.hidden_dim, K=self.emb_dim, **aie_gemv_config
122+
M=self.hidden_dim,
123+
K=self.emb_dim,
124+
tile_size_input=1,
125+
tile_size_output=self.hidden_dim // 16,
126+
**aie_gemv_config,
123127
)
124128
self.aie_fc2_gemv = AIEGEMV(
125-
M=self.hidden_dim, K=self.emb_dim, **aie_gemv_config
129+
M=self.hidden_dim,
130+
K=self.emb_dim,
131+
tile_size_input=1,
132+
tile_size_output=self.hidden_dim // 16,
133+
**aie_gemv_config,
126134
)
127135
# FC3: hidden_dim -> emb_dim
128136
self.aie_fc3_gemv = AIEGEMV(
129-
M=self.emb_dim, K=self.hidden_dim, **aie_gemv_config
137+
M=self.emb_dim,
138+
K=self.hidden_dim,
139+
tile_size_input=1,
140+
tile_size_output=self.emb_dim // 16,
141+
**aie_gemv_config,
130142
)
131143

132144
# Initialize AIE elementwise multiply
@@ -176,7 +188,7 @@ def forward(self, x):
176188
else:
177189
return self.aie_swiglu_decode(x)
178190

179-
if is_decode_with_kv and self.cfg["use_aie_gemv"]:
191+
if is_decode_with_kv and self.cfg["use_aie_ffn_gemv"]:
180192
x_fc1 = self.aie_fc1_gemv(x)
181193
x_fc2 = self.aie_fc2_gemv(x)
182194
else:
@@ -199,14 +211,14 @@ def forward(self, x):
199211
else:
200212
x = x_fc1_silu * x_fc2
201213

202-
if is_decode_with_kv and self.cfg["use_aie_gemv"]:
214+
if is_decode_with_kv and self.cfg["use_aie_ffn_gemv"]:
203215
result = self.aie_fc3_gemv(x)
204216
return result.view(original_shape)
205217
else:
206218
return self.fc3(x).view(original_shape)
207219

208220
def assign_weights(self, l, fc1, fc2, fc3):
209-
if self.cfg["use_kv_cache"] and self.cfg["use_aie_gemv"]:
221+
if self.cfg["use_kv_cache"] and self.cfg["use_aie_ffn_gemv"]:
210222
self.aie_fc1_gemv.weight = fc1
211223
self.aie_fc2_gemv.weight = fc2
212224
self.aie_fc3_gemv.weight = fc3

applications/llama_3.2_1b/src/block/gqa.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -115,18 +115,42 @@ def __init__(
115115
)
116116

117117
# Initialize AIE GEMV operators for decode phase (when using KV cache)
118-
if self.cfg["use_kv_cache"] and self.cfg["use_aie_gemv"]:
118+
if self.cfg["use_kv_cache"] and self.cfg["use_aie_gqa_gemv"]:
119119

120120
aie_gemv_config = {
121-
"num_aie_columns": 1,
121+
"num_aie_columns": 8,
122122
"is_mv": False,
123123
"use_static_weight": True,
124124
}
125-
self.aie_query_gemv = AIEGEMV(M=d_out, K=d_in, **aie_gemv_config)
125+
self.aie_query_gemv = AIEGEMV(
126+
M=d_out,
127+
K=d_in,
128+
tile_size_input=1,
129+
tile_size_output=d_out // 16,
130+
**aie_gemv_config,
131+
)
126132
kv_out_dim = num_kv_groups * self.head_dim
127-
self.aie_key_gemv = AIEGEMV(M=kv_out_dim, K=d_in, **aie_gemv_config)
128-
self.aie_value_gemv = AIEGEMV(M=kv_out_dim, K=d_in, **aie_gemv_config)
129-
self.aie_out_proj_gemv = AIEGEMV(M=d_out, K=d_out, **aie_gemv_config)
133+
self.aie_key_gemv = AIEGEMV(
134+
M=kv_out_dim,
135+
K=d_in,
136+
tile_size_input=1,
137+
tile_size_output=kv_out_dim // 16,
138+
**aie_gemv_config,
139+
)
140+
self.aie_value_gemv = AIEGEMV(
141+
M=kv_out_dim,
142+
K=d_in,
143+
tile_size_input=1,
144+
tile_size_output=kv_out_dim // 16,
145+
**aie_gemv_config,
146+
)
147+
self.aie_out_proj_gemv = AIEGEMV(
148+
M=d_out,
149+
K=d_out,
150+
tile_size_input=1,
151+
tile_size_output=d_out // 16,
152+
**aie_gemv_config,
153+
)
130154

131155
# Initialize AIE GEMM operators
132156
if self.cfg["use_aie_attn_projection_gemm"]:
@@ -159,7 +183,7 @@ def forward(self, x, mask, angles, input_pos=None):
159183
is_decode = input_pos is not None
160184

161185
# Choose between GEMM (prefill) and GEMV (decode) based on KV cache usage
162-
if self.cfg["use_kv_cache"] and is_decode and self.cfg["use_aie_gemv"]:
186+
if self.cfg["use_kv_cache"] and is_decode and self.cfg["use_aie_gqa_gemv"]:
163187
# Decode phase with KV cache - use GEMV for single token
164188
# weight.T @ input, which is vector-matrix multiplication (So, is_mv=False)
165189
x_flat = x.reshape(1, -1) # Shape: (1, d_in)
@@ -376,7 +400,7 @@ def my_mha(queries, keys, values):
376400
context_vec = context_vec.reshape(b, num_tokens, self.d_out)
377401

378402
# Choose output projection based on phase
379-
if self.cfg["use_kv_cache"] and is_decode and self.cfg["use_aie_gemv"]:
403+
if self.cfg["use_kv_cache"] and is_decode and self.cfg["use_aie_gqa_gemv"]:
380404
context_vec_flat = context_vec.reshape(1, -1)
381405
output_flat = self.aie_out_proj_gemv(context_vec_flat)
382406
context_vec = output_flat.reshape(b, num_tokens, self.d_out)
@@ -390,7 +414,7 @@ def my_mha(queries, keys, values):
390414
return context_vec
391415

392416
def assign_weights(self, l, w_query, w_key, w_value, w_out_proj):
393-
if self.cfg["use_kv_cache"] and self.cfg["use_aie_gemv"]:
417+
if self.cfg["use_kv_cache"] and self.cfg["use_aie_gqa_gemv"]:
394418
self.aie_query_gemv.weight = w_query
395419
self.aie_key_gemv.weight = w_key
396420
self.aie_value_gemv.weight = w_value

applications/llama_3.2_1b/src/model_with_json.py

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,7 @@
1212
from pathlib import Path
1313
from src.block.transformer import TransformerBlock
1414
from operators.rope.rope_utils import compute_rope_params
15-
from operators import (
16-
AIERMSNorm,
17-
AIEGEMM,
18-
)
15+
from operators import AIERMSNorm, AIEGEMM, AIEGEMV
1916
from rich.console import Console
2017
from rich.text import Text
2118

@@ -35,20 +32,22 @@ def dtype_from_string(inp):
3532
config_options = {
3633
"dtype": (dtype_from_string, torch.float32, "Data type"),
3734
"use_kv_cache": (bool, False, "[Model] KV Cache"),
38-
"use_aie_gemv": (bool, False, "[Decode] GEMV"),
3935
"use_aie_rope": (bool, False, "[Attention] Rope"),
4036
"use_aie_attn_projection_gemm": (bool, False, "[Attention] QKV GEMM"),
4137
"use_aie_regular_mha": (bool, False, "[Attention] Regular MHA"),
4238
"use_aie_fused_mha": (bool, False, "[Attention] Fused MHA"),
39+
"use_aie_gqa_gemv": (bool, False, "[Attention] GEMV (Decode)"),
4340
"use_aie_ffn_gemm": (bool, False, "[FFN] GEMM"),
4441
"use_aie_ffn_mul": (bool, False, "[FFN] Elementwise Mul"),
4542
"use_aie_ffn_silu": (bool, False, "[FFN] SiLU"),
4643
"use_aie_ffn_swiglu": (bool, False, "[FFN] Runlist-based SwiGLU"),
44+
"use_aie_ffn_gemv": (bool, False, "[FFN] GEMV (Decode)"),
4745
"use_aie_residual": (bool, False, "[Transformer] Residual Addition"),
4846
"use_aie_norm1": (bool, False, "[Transformer] Pre Norm"),
4947
"use_aie_norm2": (bool, False, "[Transformer] Post Norm"),
5048
"use_aie_final_norm": (bool, False, "[Transformer] Final Norm"),
5149
"use_aie_final_gemm": (bool, False, "[Transformer] Final GEMM"),
50+
"use_aie_final_gemv": (bool, False, "[Transformer] Final GEMV"),
5251
}
5352
# fmt: on
5453

@@ -190,25 +189,32 @@ def __init__(
190189
M_for_gemm = self.prompt_length
191190
else:
192191
M_for_gemm = self.prompt_length + self.num_tokens
193-
self.out_head_aie = AIEGEMM(
192+
self.out_head_prefill = AIEGEMM(
194193
M=M_for_gemm,
195194
K=self.cfg["emb_dim"],
196195
N=self.cfg["vocab_size"],
197196
**aie_config_prefill,
198197
)
198+
aie_gemv_config = {
199+
"num_aie_columns": 8,
200+
"is_mv": True,
201+
"use_static_weight": True,
202+
"num_aie_columns": 8,
203+
"tile_size_input": 4,
204+
"tile_size_output": 32,
205+
}
206+
# FC1 and FC2: emb_dim -> hidden_dim
207+
if self.cfg["use_aie_final_gemv"]:
208+
self.out_head_decode = AIEGEMV(
209+
M=self.cfg["vocab_size"], K=self.cfg["emb_dim"], **aie_gemv_config
210+
)
199211
else:
200212
self.out_head = nn.Linear(
201213
self.cfg["emb_dim"],
202214
self.cfg["vocab_size"],
203215
bias=False,
204216
dtype=self.cfg["dtype"],
205217
)
206-
self.out_head = nn.Linear(
207-
self.cfg["emb_dim"],
208-
self.cfg["vocab_size"],
209-
bias=False,
210-
dtype=self.cfg["dtype"],
211-
)
212218

213219
# Reusable utilities
214220
cos, sin = compute_rope_params(
@@ -269,12 +275,10 @@ def forward(self, in_idx, input_pos=None, use_kv_cache=False):
269275
x = self.final_norm(x)
270276

271277
if self.cfg["use_aie_final_gemm"]:
272-
if is_decode_with_kv and self.cfg["use_aie_gemv"]:
273-
# TODO: Create GEMV operator
274-
# logits = self.aie_out_head_gemv(x)
275-
logits = self.out_head(x) # Running on CPU
278+
if is_decode_with_kv and self.cfg["use_aie_final_gemv"]:
279+
logits = self.out_head_decode(x)
276280
else:
277-
logits = self.out_head_aie(x)
281+
logits = self.out_head_prefill(x)
278282
else:
279283
logits = self.out_head(x)
280284

@@ -292,20 +296,11 @@ def assign_weights(self, final_norm, out_head, out_head_name):
292296
f"model.norm.weight",
293297
)
294298

295-
self.out_head.weight = assign(
296-
self.out_head.weight,
297-
out_head,
298-
out_head_name,
299-
)
300-
# TODO: Offload GEMV to NPU
301-
# if self.cfg["use_kv_cache"] and self.cfg["use_aie_gemv"]:
302-
# self.aie_out_head_gemv.weight = out_head
303299
if self.cfg["use_aie_final_gemm"]:
304300
# Want column-major for B
305-
self.out_head_aie.weight = out_head.T
306-
# TODO: Create separate linear layers for prefill and decode (with gemm/gemv)
307-
# if self.cfg["use_kv_cache"]:
308-
# self.out_head.weight = out_head.T
301+
self.out_head_prefill.weight = out_head.T
302+
if self.cfg["use_aie_final_gemv"]:
303+
self.out_head_decode.weight = out_head.T
309304
else:
310305
self.out_head.weight = assign(
311306
self.out_head.weight,

operators/common/aie_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def add_buffer(self, name, count, dtype=bfloat16, static_data=None):
6868
if static_data is not None:
6969
assert (
7070
static_data.nbytes <= self.buffers[name]
71-
), f"Static data for buffer {name} exceeds allocated size."
71+
), f"Static data for buffer {name} exceeds allocated size: expected {self.buffers[name]} bytes, got {static_data.nbytes} bytes."
7272
static_data_bytes = static_data.flatten().view(np.uint8).tobytes()
7373
if static_data_bytes not in self.context.static_data_pool:
7474
self.context.static_data_pool[static_data_bytes] = None

operators/common/compilation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,7 @@ def compile(self, artifacts):
361361
"--no-xbridge",
362362
"--peano",
363363
str(self.peano_dir),
364+
"--dynamic-objFifos",
364365
]
365366
do_compile_xclbin = mlir_source in mlir_sources_to_xclbins
366367
do_compile_insts_bin = mlir_source in mlir_sources_to_insts_bins

0 commit comments

Comments (0)