Skip to content

Commit 6c7123b

Browse files
committed
address Curt's and Pranathi's comments
1 parent e8b947d commit 6c7123b

File tree

3 files changed

+8
-12
lines changed

3 files changed

+8
-12
lines changed

applications/llama_3.2_1b/src/block/feed_forward.py

Lines changed: 3 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -25,10 +25,6 @@ def __init__(
2525
super().__init__()
2626
self.cfg = cfg.copy()
2727

28-
assert cfg["use_aie_ffn_swiglu"] != (
29-
cfg["use_aie_ffn_silu"] or cfg["use_aie_ffn_gemm"] or cfg["use_aie_ffn_mul"]
30-
), "Cannot mix fused SwiGLU with individual AIE operators."
31-
3228
self.emb_dim = cfg["emb_dim"]
3329
self.hidden_dim = cfg["hidden_dim"]
3430

@@ -106,8 +102,8 @@ def forward(self, x):
106102
is_prefill = not is_vector or not self.cfg["use_kv_cache"]
107103

108104
if is_vector and self.cfg["use_kv_cache"] and self.cfg["use_aie_gemv"]:
109-
x_fc1 = self.aie_fc1_gemv(None, x)
110-
x_fc2 = self.aie_fc2_gemv(None, x)
105+
x_fc1 = self.aie_fc1_gemv(x)
106+
x_fc2 = self.aie_fc2_gemv(x)
111107
else:
112108
x_fc1 = self.fc1(x)
113109
x_fc2 = self.fc2(x)
@@ -120,7 +116,7 @@ def forward(self, x):
120116
x = x_fc1_silu * x_fc2
121117

122118
if is_vector and self.cfg["use_kv_cache"] and self.cfg["use_aie_gemv"]:
123-
result = self.aie_fc3_gemv(None, x)
119+
result = self.aie_fc3_gemv(x)
124120
return result.view(original_shape)
125121
else:
126122
return self.fc3(x).view(original_shape)

applications/llama_3.2_1b/src/block/gqa.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -166,15 +166,15 @@ def forward(self, x, mask, angles, input_pos=None):
166166
x_flat = x.reshape(1, -1) # Shape: (1, d_in)
167167
input_dtype = x.dtype
168168

169-
queries_flat = self.aie_query_gemv(None, x_flat)
169+
queries_flat = self.aie_query_gemv(x_flat)
170170
queries = queries_flat.reshape(b, num_tokens, self.d_out).to(input_dtype)
171171

172-
keys_flat = self.aie_key_gemv(None, x_flat)
172+
keys_flat = self.aie_key_gemv(x_flat)
173173
keys = keys_flat.reshape(
174174
b, num_tokens, self.num_kv_groups * self.head_dim
175175
).to(input_dtype)
176176

177-
values_flat = self.aie_value_gemv(None, x_flat)
177+
values_flat = self.aie_value_gemv(x_flat)
178178
values = values_flat.reshape(
179179
b, num_tokens, self.num_kv_groups * self.head_dim
180180
).to(input_dtype)
@@ -384,7 +384,7 @@ def my_mha(queries, keys, values):
384384
# Choose output projection based on phase
385385
if self.cfg["use_kv_cache"] and is_decode and self.cfg["use_aie_gemv"]:
386386
context_vec_flat = context_vec.reshape(1, -1)
387-
output_flat = self.aie_out_proj_gemv(None, context_vec_flat)
387+
output_flat = self.aie_out_proj_gemv(context_vec_flat)
388388
context_vec = output_flat.reshape(b, num_tokens, self.d_out).to(input_dtype)
389389
elif self.cfg["use_aie_attn_projection_gemm"]:
390390
context_vec_flat = context_vec.reshape(-1, self.d_out)

applications/llama_3.2_1b/src/operator/aie_gemv.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -107,7 +107,7 @@ def set_up(self):
107107
self.add_buffer("output", self.M)
108108
self.add_to_runlist("gemv", "matrix", "vector", "output")
109109

110-
def forward(self, matrix, vector):
110+
def forward(self, vector, matrix=None):
111111
"""Forward pass through GEMV operation
112112
113113
Args:

0 commit comments

Comments (0)