Skip to content

Commit 331dcca

Browse files
kurtis-b-1 and andrej
authored
Offload Final Linear Layer (#48)
* Use output directory instead of /tmp directory for building * Can run GEMM with a wide B matrix * Moved creating buffers for buffer lists B/C inside golden reference so that the major order logic is there instead * Simplifying some logic in the operator * Further simplification of the gemm operator * Formatting * Adjust comments for C tile streaming through shim DMA based on the separate_c_tiles parameter * Remove using the separated C tile runtime streams in test * Can offload the last linear layer, but TTFT goes way up--reason seems to be that the actual forward operation for last linear layer is ~4.5s, likely due to reading the output buffers and converting the output from np to torch since the GEMM kernel itself should take ~200ms with bfp16 emulation enabled * Modified the torch_to_numpy/numpy_to_torch conversions to use zero-copy reinterprets, and removed unnecessary .to() calls which could result in extram unnecessary passes over memory * Make read_buffer for BOs zero-copy * Fix functionality when copy is True with read_buffer() * Run decode stage on CPU for final linear layer, which fixes toks per sec but the outptut tokens still inconsistent with CPU-only inference * Fix CPU final linear layer run with KV cache enabled and formatting * Use map view for writing buffers like with reading buffers * Make separate_c_tiles parameter based on partition_N value * Formatting * Use corect shapes for forward pass (padded N vs actual N) * Formatting * Clean up code and comments in new versions of read_buffer()/write_buffer() methods * Fix comments in numpy/torch conversion utils * fixes after merge * format * fixes --------- Co-authored-by: andrej <[email protected]>
1 parent a4b6ffe commit 331dcca

File tree

12 files changed

+485
-162
lines changed

12 files changed

+485
-162
lines changed

applications/llama_3.2_1b/configs/llama32_1b.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
"use_aie_residual": true,
2424
"use_aie_regular_mha": false,
2525
"use_aie_fused_mha": true,
26-
"use_aie_final_gemm": false,
26+
"use_aie_final_gemm": true,
2727
"rope_freq": {
2828
"factor": 32.0,
2929
"low_freq_factor": 1.0,

applications/llama_3.2_1b/inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,7 @@ def set_prefill_time():
400400
parser.add_argument(
401401
"--prompt_len",
402402
type=int,
403-
default=64,
403+
default=2048,
404404
help="Truncate prompt to this many tokens.",
405405
)
406406
parser.add_argument(

applications/llama_3.2_1b/src/block/gqa.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -163,38 +163,33 @@ def forward(self, x, mask, angles, input_pos=None):
163163
# Decode phase with KV cache - use GEMV for single token
164164
# weight.T @ input, which is vector-matrix multiplication (So, is_mv=False)
165165
x_flat = x.reshape(1, -1) # Shape: (1, d_in)
166-
input_dtype = x.dtype
167166

168167
queries_flat = self.aie_query_gemv(x_flat)
169-
queries = queries_flat.reshape(b, num_tokens, self.d_out).to(input_dtype)
168+
queries = queries_flat.reshape(b, num_tokens, self.d_out)
170169

171170
keys_flat = self.aie_key_gemv(x_flat)
172-
keys = keys_flat.reshape(
173-
b, num_tokens, self.num_kv_groups * self.head_dim
174-
).to(input_dtype)
171+
keys = keys_flat.reshape(b, num_tokens, self.num_kv_groups * self.head_dim)
175172

176173
values_flat = self.aie_value_gemv(x_flat)
177174
values = values_flat.reshape(
178175
b, num_tokens, self.num_kv_groups * self.head_dim
179-
).to(input_dtype)
176+
)
180177

181178
elif self.cfg["use_aie_attn_projection_gemm"]:
182179
# Prefill phase - use GEMM for multiple tokens
183180
x_flat = x.reshape(-1, d_in)
184181
input_dtype = x.dtype
185182

186183
queries_flat = self.aie_query(x_flat)
187-
queries = queries_flat.reshape(b, num_tokens, self.d_out).to(input_dtype)
184+
queries = queries_flat.reshape(b, num_tokens, self.d_out)
188185

189186
keys_flat = self.aie_key(x_flat)
190-
keys = keys_flat.reshape(
191-
b, num_tokens, self.num_kv_groups * self.head_dim
192-
).to(input_dtype)
187+
keys = keys_flat.reshape(b, num_tokens, self.num_kv_groups * self.head_dim)
193188

194189
values_flat = self.aie_value(x_flat)
195190
values = values_flat.reshape(
196191
b, num_tokens, self.num_kv_groups * self.head_dim
197-
).to(input_dtype)
192+
)
198193
else:
199194
queries = self.W_query(x)
200195
keys = self.W_key(x)
@@ -348,9 +343,9 @@ def apply_rope_and_transpose(tensor, num_heads_dim, angle_slice):
348343
def my_mha(queries, keys, values):
349344
inv_scale = 1 / np.sqrt(values.shape[-1])
350345
context_vec = torch.nn.functional.scaled_dot_product_attention(
351-
queries.to(torch.bfloat16).to("cpu"),
352-
keys.to(torch.bfloat16).to("cpu"),
353-
values.to(torch.bfloat16).to("cpu"),
346+
queries,
347+
keys,
348+
values,
354349
dropout_p=0.0,
355350
is_causal=True,
356351
scale=inv_scale,
@@ -384,11 +379,11 @@ def my_mha(queries, keys, values):
384379
if self.cfg["use_kv_cache"] and is_decode and self.cfg["use_aie_gemv"]:
385380
context_vec_flat = context_vec.reshape(1, -1)
386381
output_flat = self.aie_out_proj_gemv(context_vec_flat)
387-
context_vec = output_flat.reshape(b, num_tokens, self.d_out).to(input_dtype)
382+
context_vec = output_flat.reshape(b, num_tokens, self.d_out)
388383
elif self.cfg["use_aie_attn_projection_gemm"]:
389384
context_vec_flat = context_vec.reshape(-1, self.d_out)
390385
output_flat = self.aie_out_proj(context_vec_flat)
391-
context_vec = output_flat.reshape(b, num_tokens, self.d_out).to(input_dtype)
386+
context_vec = output_flat.reshape(b, num_tokens, self.d_out)
392387
else:
393388
context_vec = self.out_proj(context_vec)
394389

applications/llama_3.2_1b/src/model_with_json.py

Lines changed: 86 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
from pathlib import Path
1313
from src.block.transformer import TransformerBlock
1414
from operators.rope.rope_utils import compute_rope_params
15-
from operators import AIERMSNorm
15+
from operators import (
16+
AIERMSNorm,
17+
AIEGEMM,
18+
)
1619
from rich.console import Console
1720
from rich.text import Text
1821

@@ -169,7 +172,37 @@ def __init__(
169172
self.cfg["emb_dim"], eps=1e-5, dtype=self.cfg["dtype"]
170173
)
171174

172-
# Depedns on use_aie_final_gemm
175+
# Offload final linear layer if enabled
176+
if self.cfg.get("use_aie_final_gemm", False):
177+
# Since this GEMM has such a large N dimension, partition the N dimension by 4,
178+
# and GEMM will execute for a workload of that smaller N dimension across different buffers of B and C
179+
aie_config_prefill = {
180+
"num_aie_columns": 8,
181+
"tile_m": 64,
182+
"tile_k": 64,
183+
"tile_n": 64,
184+
"b_col_maj": True,
185+
"use_static_weight": True,
186+
"separate_c_tiles": True,
187+
"partition_N": 4,
188+
}
189+
if self.cfg["use_kv_cache"]:
190+
M_for_gemm = self.prompt_length
191+
else:
192+
M_for_gemm = self.prompt_length + self.num_tokens
193+
self.out_head_aie = AIEGEMM(
194+
M=M_for_gemm,
195+
K=self.cfg["emb_dim"],
196+
N=self.cfg["vocab_size"],
197+
**aie_config_prefill,
198+
)
199+
else:
200+
self.out_head = nn.Linear(
201+
self.cfg["emb_dim"],
202+
self.cfg["vocab_size"],
203+
bias=False,
204+
dtype=self.cfg["dtype"],
205+
)
173206
self.out_head = nn.Linear(
174207
self.cfg["emb_dim"],
175208
self.cfg["vocab_size"],
@@ -194,6 +227,22 @@ def forward(self, in_idx, input_pos=None, use_kv_cache=False):
194227
tok_embeds = self.tok_emb(in_idx)
195228
x = tok_embeds
196229

230+
# Check if input is a vector (decode phase) or matrix (prefill phase)
231+
# Handle 1D: (emb_dim,), 2D: (1, emb_dim), or 3D: (1, 1, emb_dim)
232+
is_vector = (
233+
len(x.shape) == 1
234+
or (len(x.shape) == 2 and x.shape[0] == 1)
235+
or (len(x.shape) == 3 and x.shape[0] == 1 and x.shape[1] == 1)
236+
)
237+
238+
# (batch, sequence, embedding) where sequence=1 indicates decode
239+
if len(x.shape) == 3:
240+
is_decode_with_kv = (x.shape[1] == 1) and self.cfg["use_kv_cache"]
241+
elif len(x.shape) == 2:
242+
is_decode_with_kv = (x.shape[0] == 1) and self.cfg["use_kv_cache"]
243+
else:
244+
is_decode_with_kv = False
245+
197246
num_tokens = x.shape[1]
198247

199248
# During generation phase with KV cache, don't create a mask
@@ -219,19 +268,47 @@ def forward(self, in_idx, input_pos=None, use_kv_cache=False):
219268
else:
220269
x = self.final_norm(x)
221270

222-
logits = self.out_head(x.to(self.cfg["dtype"]))
271+
if self.cfg["use_aie_final_gemm"]:
272+
if is_decode_with_kv and self.cfg["use_aie_gemv"]:
273+
# TODO: Create GEMV operator
274+
# logits = self.aie_out_head_gemv(x)
275+
logits = self.out_head(x) # Running on CPU
276+
else:
277+
logits = self.out_head_aie(x)
278+
else:
279+
logits = self.out_head(x)
223280

224281
return logits
225282

226-
def assign_weights(self, final_norm):
283+
def assign_weights(self, final_norm, out_head, out_head_name):
227284
if self.cfg.get("use_aie_final_norm", False):
228285
self.aie_final_norm_prefill.weight = final_norm
229286
if self.cfg["use_kv_cache"]:
230287
self.aie_final_norm_decode.weight = final_norm
231-
return
288+
else:
289+
self.final_norm.weight = assign(
290+
self.final_norm.weight,
291+
final_norm,
292+
f"model.norm.weight",
293+
)
232294

233-
self.final_norm.weight = assign(
234-
self.final_norm.weight,
235-
final_norm,
236-
f"model.norm.weight",
295+
self.out_head.weight = assign(
296+
self.out_head.weight,
297+
out_head,
298+
out_head_name,
237299
)
300+
# TODO: Offload GEMV to NPU
301+
# if self.cfg["use_kv_cache"] and self.cfg["use_aie_gemv"]:
302+
# self.aie_out_head_gemv.weight = out_head
303+
if self.cfg["use_aie_final_gemm"]:
304+
# Want column-major for B
305+
self.out_head_aie.weight = out_head.T
306+
# TODO: Create separate linear layers for prefill and decode (with gemm/gemv)
307+
# if self.cfg["use_kv_cache"]:
308+
# self.out_head.weight = out_head.T
309+
else:
310+
self.out_head.weight = assign(
311+
self.out_head.weight,
312+
out_head,
313+
out_head_name,
314+
)

applications/llama_3.2_1b/src/utils.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -126,15 +126,13 @@ def load_weights_into_llama(model, param_config, params):
126126
)
127127

128128
# Load output layer weights
129-
model.assign_weights(params["model.norm.weight"])
130-
131129
if "lm_head.weight" in params.keys():
132-
model.out_head.weight = assign(
133-
model.out_head.weight, params["lm_head.weight"], "lm_head.weight"
130+
model.assign_weights(
131+
params["model.norm.weight"], params["lm_head.weight"], "lm_head.weight"
134132
)
135133
else:
136-
model.out_head.weight = assign(
137-
model.out_head.weight,
134+
model.assign_weights(
135+
params["model.norm.weight"],
138136
params["model.embed_tokens.weight"],
139137
"model.embed_tokens.weight",
140138
)

operators/common/aie_base.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -89,25 +89,41 @@ def add_to_runlist(self, kernel_name, *args):
8989
def get_bo(self, buffer_name):
9090
return self.buffer_bos[buffer_name]
9191

92-
def read_buffer(self, buffer_name, shape, dtype=bfloat16):
92+
def read_buffer(self, buffer_name, shape, copy=False, dtype=bfloat16):
9393
"""Read buffer and return values as a numpy array"""
94-
size = np.prod(shape) * np.dtype(dtype).itemsize
95-
output_bytes = self.get_bo(buffer_name).read(size, 0)
96-
output_data_flat = np.frombuffer(output_bytes, dtype=dtype)
97-
return output_data_flat.reshape(*shape)
94+
# Create a byte accessible memory view of the buffer object
95+
mv = self.get_bo(buffer_name).map()
96+
97+
# Interpret the buffer as a 1-dimensional array then change its view to the expected shape
98+
arr = np.frombuffer(mv, dtype=dtype, count=np.prod(shape)).reshape(shape)
99+
100+
# Return an independent copy of the array if needed
101+
return arr.copy() if copy else arr
98102

99103
def read_buffer_as_torch(self, buffer_name, shape, dtype=bfloat16):
100104
return numpy_to_torch(self.read_buffer(buffer_name, shape, dtype))
101105

102106
def write_buffer(self, buffer_name, array):
103107
"""Write buffer from a numpy array into a XRT buffer object"""
104-
if isinstance(array, torch.Tensor):
105-
numpy_array = torch_to_numpy(array)
106-
else:
107-
numpy_array = array
108108
if buffer_name in self.buffer_static_data:
109109
raise RuntimeError(f"Cannot write to static buffer: {buffer_name}")
110-
self.get_bo(buffer_name).write(numpy_array.flatten().view(np.uint8), 0)
110+
111+
# Normalize the source
112+
if isinstance(array, torch.Tensor):
113+
src = torch_to_numpy(array)
114+
else:
115+
src = np.asarray(array)
116+
117+
# Create a flattened 1D byte view of the source
118+
src_bytes = src.ravel().view(np.uint8)
119+
120+
bo = self.get_bo(buffer_name)
121+
mv = bo.map() # byte accessible memory view
122+
# Interpret the buffer as a 1-dimensional array
123+
dst_bytes = np.frombuffer(mv, dtype=np.uint8, count=bo.size())
124+
125+
# The BO is an existing array, so copyto() can be called, which doesn't create a new array
126+
np.copyto(dst_bytes[: src_bytes.size], src_bytes, casting="no")
111127

112128
@abstractmethod
113129
def set_up_artifacts(self):

operators/common/aie_context.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,16 +98,48 @@ def prepare_runtime(self):
9898
)
9999

100100
# If multiple buffers (of the same binned size) are used in the
101-
# same kernel invocation, they require separate allocations.
101+
# same kernel invocation OR across different invocations with shared
102+
# buffers, they require separate allocations.
102103
conflicting_buffers = {} # map buffer -> {set of conflicting buffers}
103-
for kernel, *args in op.runlist:
104+
buffer_to_runlist_entries = {} # map buffer -> set of runlist entry indices
105+
106+
# First pass: track which buffers appear in which runlist entries
107+
for idx, (kernel, *args) in enumerate(op.runlist):
108+
for arg in args:
109+
buffer_to_runlist_entries.setdefault(arg, set()).add(idx)
110+
111+
# Second pass: determine conflicts
112+
for idx, (kernel, *args) in enumerate(op.runlist):
104113
for arg in args:
105114
if arg in op.buffer_static_data:
115+
# Static buffers never conflict
106116
continue
107117
pool_sz = get_pool_sz(op.buffers[arg])
118+
119+
# Buffers conflict if they're in the same runlist entry
108120
conflicting_args = {
109121
a for a in args if get_pool_sz(op.buffers[a]) == pool_sz
110122
} - {arg}
123+
124+
# Also conflict with buffers in other runlist entries that share
125+
# a buffer with this entry
126+
for other_arg in args:
127+
if other_arg == arg:
128+
continue
129+
for other_idx in buffer_to_runlist_entries.get(
130+
other_arg, set()
131+
):
132+
if other_idx != idx:
133+
_, *other_args = op.runlist[other_idx]
134+
conflicting_args.update(
135+
{
136+
a
137+
for a in other_args
138+
if get_pool_sz(op.buffers[a]) == pool_sz
139+
and a != arg
140+
}
141+
)
142+
111143
conflicting_buffers[arg] = conflicting_buffers.get(
112144
arg, set()
113145
).union(conflicting_args)

operators/common/utils.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,31 @@
2323

2424

2525
def torch_to_numpy(tensor: torch.Tensor) -> np.ndarray:
26-
if tensor.dtype == torch.bfloat16:
27-
float_arr = tensor.float().detach().cpu().numpy()
28-
return float_arr.astype(bfloat16)
29-
return tensor.detach().cpu().numpy()
26+
# Detach (to drop grad) and ensure on CPU
27+
t = tensor.detach()
28+
if t.device.type != "cpu":
29+
t = t.cpu()
30+
# Ensure contiguous for safe view operations
31+
if not t.is_contiguous():
32+
t = t.contiguous()
33+
34+
if t.dtype == torch.bfloat16:
35+
# View the same memory as uint16, then as NumPy bfloat16
36+
# This avoids numeric conversion and extra passes over memory.
37+
u16_np = t.view(torch.uint16).numpy() # shares memory
38+
return u16_np.view(np.dtype("bfloat16")) # reinterpret
39+
40+
return t.numpy()
3041

3142

3243
def numpy_to_torch(array: np.ndarray) -> torch.Tensor:
33-
device = torch.device("cpu")
34-
if array.dtype == bfloat16:
35-
return torch.from_numpy(array.astype(np.float32)).to(torch.bfloat16).to(device)
36-
return torch.from_numpy(array).to(device)
44+
# Ensure contiguous to let from_numpy create a view
45+
if not array.flags["C_CONTIGUOUS"]:
46+
array = np.ascontiguousarray(array)
47+
48+
if array.dtype == np.dtype("bfloat16"):
49+
# reinterpret the same memory as uint16, then view as torch.bfloat16
50+
t_u16 = torch.from_numpy(array.view(np.uint16))
51+
return t_u16.view(torch.bfloat16) # view
52+
53+
return torch.from_numpy(array)

0 commit comments

Comments
 (0)