1212from pathlib import Path
1313from src .block .transformer import TransformerBlock
1414from operators .rope .rope_utils import compute_rope_params
15- from operators import (
16- AIERMSNorm ,
17- AIEGEMM ,
18- )
15+ from operators import AIERMSNorm , AIEGEMM , AIEGEMV
1916from rich .console import Console
2017from rich .text import Text
2118
@@ -35,20 +32,22 @@ def dtype_from_string(inp):
3532config_options = {
3633 "dtype" : (dtype_from_string , torch .float32 , "Data type" ),
3734 "use_kv_cache" : (bool , False , "[Model] KV Cache" ),
38- "use_aie_gemv" : (bool , False , "[Decode] GEMV" ),
3935 "use_aie_rope" : (bool , False , "[Attention] Rope" ),
4036 "use_aie_attn_projection_gemm" : (bool , False , "[Attention] QKV GEMM" ),
4137 "use_aie_regular_mha" : (bool , False , "[Attention] Regular MHA" ),
4238 "use_aie_fused_mha" : (bool , False , "[Attention] Fused MHA" ),
39+ "use_aie_gqa_gemv" : (bool , False , "[Attention] GEMV (Decode)" ),
4340 "use_aie_ffn_gemm" : (bool , False , "[FFN] GEMM" ),
4441 "use_aie_ffn_mul" : (bool , False , "[FFN] Elementwise Mul" ),
4542 "use_aie_ffn_silu" : (bool , False , "[FFN] SiLU" ),
4643 "use_aie_ffn_swiglu" : (bool , False , "[FFN] Runlist-based SwiGLU" ),
44+ "use_aie_ffn_gemv" : (bool , False , "[FFN] GEMV (Decode)" ),
4745 "use_aie_residual" : (bool , False , "[Transformer] Residual Addition" ),
4846 "use_aie_norm1" : (bool , False , "[Transformer] Pre Norm" ),
4947 "use_aie_norm2" : (bool , False , "[Transformer] Post Norm" ),
5048 "use_aie_final_norm" : (bool , False , "[Transformer] Final Norm" ),
5149 "use_aie_final_gemm" : (bool , False , "[Transformer] Final GEMM" ),
50+ "use_aie_final_gemv" : (bool , False , "[Transformer] Final GEMV" ),
5251}
5352# fmt: on
5453
@@ -190,25 +189,32 @@ def __init__(
190189 M_for_gemm = self .prompt_length
191190 else :
192191 M_for_gemm = self .prompt_length + self .num_tokens
193- self .out_head_aie = AIEGEMM (
192+ self .out_head_prefill = AIEGEMM (
194193 M = M_for_gemm ,
195194 K = self .cfg ["emb_dim" ],
196195 N = self .cfg ["vocab_size" ],
197196 ** aie_config_prefill ,
198197 )
198+ aie_gemv_config = {
199+ "num_aie_columns" : 8 ,
200+ "is_mv" : True ,
201+ "use_static_weight" : True ,
202+ "num_aie_columns" : 8 ,
203+ "tile_size_input" : 4 ,
204+ "tile_size_output" : 32 ,
205+ }
206+ # FC1 and FC2: emb_dim -> hidden_dim
207+ if self .cfg ["use_aie_final_gemv" ]:
208+ self .out_head_decode = AIEGEMV (
209+ M = self .cfg ["vocab_size" ], K = self .cfg ["emb_dim" ], ** aie_gemv_config
210+ )
199211 else :
200212 self .out_head = nn .Linear (
201213 self .cfg ["emb_dim" ],
202214 self .cfg ["vocab_size" ],
203215 bias = False ,
204216 dtype = self .cfg ["dtype" ],
205217 )
206- self .out_head = nn .Linear (
207- self .cfg ["emb_dim" ],
208- self .cfg ["vocab_size" ],
209- bias = False ,
210- dtype = self .cfg ["dtype" ],
211- )
212218
213219 # Reusable utilities
214220 cos , sin = compute_rope_params (
@@ -269,12 +275,10 @@ def forward(self, in_idx, input_pos=None, use_kv_cache=False):
269275 x = self .final_norm (x )
270276
271277 if self .cfg ["use_aie_final_gemm" ]:
272- if is_decode_with_kv and self .cfg ["use_aie_gemv" ]:
273- # TODO: Create GEMV operator
274- # logits = self.aie_out_head_gemv(x)
275- logits = self .out_head (x ) # Running on CPU
278+ if is_decode_with_kv and self .cfg ["use_aie_final_gemv" ]:
279+ logits = self .out_head_decode (x )
276280 else :
277- logits = self .out_head_aie (x )
281+ logits = self .out_head_prefill (x )
278282 else :
279283 logits = self .out_head (x )
280284
@@ -292,20 +296,11 @@ def assign_weights(self, final_norm, out_head, out_head_name):
292296 f"model.norm.weight" ,
293297 )
294298
295- self .out_head .weight = assign (
296- self .out_head .weight ,
297- out_head ,
298- out_head_name ,
299- )
300- # TODO: Offload GEMV to NPU
301- # if self.cfg["use_kv_cache"] and self.cfg["use_aie_gemv"]:
302- # self.aie_out_head_gemv.weight = out_head
303299 if self .cfg ["use_aie_final_gemm" ]:
304300 # Want column-major for B
305- self .out_head_aie .weight = out_head .T
306- # TODO: Create separate linear layers for prefill and decode (with gemm/gemv)
307- # if self.cfg["use_kv_cache"]:
308- # self.out_head.weight = out_head.T
301+ self .out_head_prefill .weight = out_head .T
302+ if self .cfg ["use_aie_final_gemv" ]:
303+ self .out_head_decode .weight = out_head .T
309304 else :
310305 self .out_head .weight = assign (
311306 self .out_head .weight ,
0 commit comments