Skip to content

Commit a13dcaa

Browse files
committed
Address Curts comments
1 parent d3e412f commit a13dcaa

File tree

9 files changed

+320
-321
lines changed

9 files changed

+320
-321
lines changed

applications/llama_3.2_1b/configs/llama32_1b.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
"use_aie_norm2": true,
2323
"use_aie_residual": true,
2424
"use_aie_regular_mha": false,
25-
"use_aie_fused_mha": false,
25+
"use_aie_fused_mha": true,
2626
"use_aie_final_gemm": false,
2727
"rope_freq": {
2828
"factor": 32.0,

applications/llama_3.2_1b/src/block/feed_forward.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved.
77
# SPDX-License-Identifier: Apache-2.0
88

9-
import logging
109
import torch
1110
import torch.nn as nn
1211
from ..utils import torch_to_numpy, assign
@@ -129,6 +128,7 @@ def forward(self, x):
129128
or (len(x.shape) == 3 and x.shape[0] == 1 and x.shape[1] == 1)
130129
)
131130

131+
is_prefill = not is_vector or not self.cfg["use_kv_cache"]
132132
is_decode_with_kv = is_vector and self.cfg["use_kv_cache"]
133133

134134
if self.cfg["use_aie_ffn_swiglu"]:

applications/llama_3.2_1b/src/operator/aie_base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ def prepare_runtime(cls):
6060
cls.static_data_pool[buffer_data] = bo
6161

6262
for op in cls.registered_operators:
63+
if len(op.kernels) == 0:
64+
# Operator likely is used as a sub-operator in another operator and does need any setup.
65+
continue
6366
logging.info(f"Preparing runtime for AIE operator: {op.__class__.__name__}")
6467

6568
# Set up for each kernel

applications/llama_3.2_1b/src/operator/aie_elementwise_mul.py

Lines changed: 52 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -18,52 +18,18 @@
1818
from pathlib import Path
1919

2020

21-
def get_elementwise_mul_artifacts(
22-
base_dir,
23-
device_type,
24-
size,
25-
tile_size=2048,
26-
num_columns=4,
27-
num_channels=2,
28-
prefix="eltwise_mul_",
29-
):
30-
file_name_base = f"{prefix}{num_columns}c_{num_channels}ch_{size}_{tile_size}t"
31-
32-
mlir_artifact = PythonGeneratedMLIRArtifact.new(
33-
f"{file_name_base}.mlir",
34-
import_path=base_dir / "example" / "elementwise_mul" / "eltwise_mul.py",
35-
callback_fn="my_eltwise_mul",
36-
callback_args=[
37-
device_type,
38-
size,
39-
num_columns,
40-
num_channels,
41-
tile_size,
42-
0,
43-
],
44-
)
45-
46-
xclbin_artifact = XclbinArtifact.new(
47-
f"{file_name_base}.xclbin",
48-
depends=[
49-
mlir_artifact,
50-
KernelObjectArtifact.new(
51-
f"mul.o", depends=[SourceArtifact.new("aie_kernels/generic/mul.cc")]
52-
),
53-
],
54-
)
55-
56-
insts_artifact = InstsBinArtifact.new(
57-
f"{file_name_base}.bin", depends=[mlir_artifact]
58-
)
59-
60-
return xclbin_artifact, insts_artifact
61-
62-
6321
class AIEElementwiseMul(AIEOperatorBase):
6422
"""AIE-accelerated element-wise multiplication"""
6523

66-
def __init__(self, size, num_columns=None, num_channels=None, tile_size=None):
24+
def __init__(
25+
self,
26+
size,
27+
num_columns=None,
28+
num_channels=None,
29+
tile_size=None,
30+
trace_size=0,
31+
do_set_up=True,
32+
):
6733
self.size = size
6834

6935
# Enforce ShimDMA limits for elementwise_mul (uses 2 inputs per core)
@@ -80,20 +46,54 @@ def __init__(self, size, num_columns=None, num_channels=None, tile_size=None):
8046
self.num_columns = num_columns
8147
self.num_channels = num_channels
8248
self.tile_size = tile_size
49+
self.trace_size = trace_size
50+
self.do_set_up = do_set_up
8351

8452
AIEOperatorBase.__init__(self)
8553

54+
def get_artifacts(self, prefix="eltwise_mul_"):
55+
file_name_base = f"{prefix}{self.num_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t"
56+
57+
mlir_artifact = PythonGeneratedMLIRArtifact.new(
58+
f"{file_name_base}.mlir",
59+
import_path=self.base_dir
60+
/ "example"
61+
/ "elementwise_mul"
62+
/ "eltwise_mul.py",
63+
callback_fn="my_eltwise_mul",
64+
callback_args=[
65+
self.device_manager.device_type,
66+
self.size,
67+
self.num_columns,
68+
self.num_channels,
69+
self.tile_size,
70+
self.trace_size,
71+
],
72+
)
73+
74+
xclbin_artifact = XclbinArtifact.new(
75+
f"{file_name_base}.xclbin",
76+
depends=[
77+
mlir_artifact,
78+
KernelObjectArtifact.new(
79+
f"mul.o", depends=[SourceArtifact.new("aie_kernels/generic/mul.cc")]
80+
),
81+
],
82+
)
83+
84+
insts_artifact = InstsBinArtifact.new(
85+
f"{file_name_base}.bin", depends=[mlir_artifact]
86+
)
87+
88+
return xclbin_artifact, insts_artifact
89+
8690
def set_up(self):
91+
# If this operator is only used as a sub-operator in another operator that sets it up, we should skip the setup here as those artifacts and buffers may not be needed.
92+
if not self.do_set_up:
93+
return
94+
8795
# Compilation artifacts
88-
xclbin_artifact, insts_artifact = get_elementwise_mul_artifacts(
89-
self.base_dir,
90-
self.device_manager.device_type,
91-
self.size,
92-
self.tile_size,
93-
self.num_columns,
94-
self.num_channels,
95-
prefix="",
96-
)
96+
xclbin_artifact, insts_artifact = self.get_artifacts()
9797

9898
# Override device_type in the mlir_artifact's callback_args if needed
9999
mlir_artifact = xclbin_artifact.depends[0]

0 commit comments

Comments
 (0)