File tree Expand file tree Collapse file tree 1 file changed +2
-3
lines changed
applications/llama_3.2_1b/src/block Expand file tree Collapse file tree 1 file changed +2
-3
lines changed Original file line number Diff line number Diff line change @@ -109,9 +109,8 @@ def forward(self, x):
109109 )
110110
111111 is_decode_with_kv = is_vector and self .cfg ["use_kv_cache" ]
112- is_prefill = not is_vector or not self .cfg ["use_kv_cache" ]
113112
114- if is_vector and self . cfg [ "use_kv_cache" ] and self .cfg ["use_aie_gemv" ]:
113+ if is_decode_with_kv and self .cfg ["use_aie_gemv" ]:
115114 x_fc1 = self .aie_fc1_gemv (x )
116115 x_fc2 = self .aie_fc2_gemv (x )
117116 else :
@@ -125,7 +124,7 @@ def forward(self, x):
125124 else :
126125 x = x_fc1_silu * x_fc2
127126
128- if is_vector and self . cfg [ "use_kv_cache" ] and self .cfg ["use_aie_gemv" ]:
127+ if is_decode_with_kv and self .cfg ["use_aie_gemv" ]:
129128 result = self .aie_fc3_gemv (x )
130129 return result .view (original_shape )
131130 else :
You can’t perform that action at this time.
0 commit comments