Skip to content

Commit dc5c511

Browse files
committed
Address Curts Comments
1 parent c7923d6 commit dc5c511

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

applications/llama_3.2_1b/src/block/feed_forward.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,8 @@ def forward(self, x):
109109
)
110110

111111
is_decode_with_kv = is_vector and self.cfg["use_kv_cache"]
112-
is_prefill = not is_vector or not self.cfg["use_kv_cache"]
113112

114-
if is_vector and self.cfg["use_kv_cache"] and self.cfg["use_aie_gemv"]:
113+
if is_decode_with_kv and self.cfg["use_aie_gemv"]:
115114
x_fc1 = self.aie_fc1_gemv(x)
116115
x_fc2 = self.aie_fc2_gemv(x)
117116
else:
@@ -125,7 +124,7 @@ def forward(self, x):
125124
else:
126125
x = x_fc1_silu * x_fc2
127126

128-
if is_vector and self.cfg["use_kv_cache"] and self.cfg["use_aie_gemv"]:
127+
if is_decode_with_kv and self.cfg["use_aie_gemv"]:
129128
result = self.aie_fc3_gemv(x)
130129
return result.view(original_shape)
131130
else:

0 commit comments

Comments
 (0)