@@ -91,9 +91,9 @@ def __init__(
         logits_all: bool = False,
         embedding: bool = False,
         offload_kqv: bool = True,
-        flash_attn: bool = False,
         op_offload: Optional[bool] = None,
         swa_full: Optional[bool] = None,
+        flash_attn: Optional[bool] = None,
         # Sampling Params
         no_perf: bool = False,
         last_n_tokens_size: int = 64,
@@ -173,7 +173,7 @@ def __init__(
             logits_all: Return logits for all tokens, not just the last token. Must be True for completion to return logprobs.
             embedding: Embedding mode only.
             offload_kqv: Offload K, Q, V to GPU.
-            flash_attn: Use flash attention.
+            flash_attn: Use flash attention. None = auto, True = enabled, False = disabled.
             op_offload: offload host tensor operations to device
             swa_full: use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
             no_perf: Measure performance timings.
@@ -341,7 +341,16 @@ def __init__(
         self._logits_all = logits_all if draft_model is None else True
         self.context_params.embeddings = embedding  # TODO: Rename to embeddings
         self.context_params.offload_kqv = offload_kqv
-        self.context_params.flash_attn = flash_attn
+        if flash_attn is None:
+            self.context_params.flash_attn_type = llama_cpp.LLAMA_FLASH_ATTN_TYPE_AUTO
+        elif flash_attn:
+            self.context_params.flash_attn_type = (
+                llama_cpp.LLAMA_FLASH_ATTN_TYPE_ENABLED
+            )
+        else:
+            self.context_params.flash_attn_type = (
+                llama_cpp.LLAMA_FLASH_ATTN_TYPE_DISABLED
+            )

         if op_offload is not None:
             self.context_params.op_offload = op_offload
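The tri-state default keeps the public API a simple Optional[bool] while deferring the real decision to llama.cpp's flash_attn_type enum. A minimal usage sketch (the model path is a placeholder):

from llama_cpp import Llama

# flash_attn=None (the new default) maps to LLAMA_FLASH_ATTN_TYPE_AUTO,
# letting llama.cpp pick per model/backend; True/False force it on or off.
llm = Llama(model_path="./model.gguf")                        # auto
llm_on = Llama(model_path="./model.gguf", flash_attn=True)    # enabled
llm_off = Llama(model_path="./model.gguf", flash_attn=False)  # disabled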
@@ -934,7 +943,8 @@ def generate(

             sample_idx += 1
             if stopping_criteria is not None and stopping_criteria(
-                self._input_ids[: sample_idx], self._scores[sample_idx - self.n_tokens, :]
+                self._input_ids[:sample_idx],
+                self._scores[sample_idx - self.n_tokens, :],
             ):
                 return
             tokens_or_none = yield token
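The reflow makes the callback contract at this call site easier to read: a stopping criteria receives the token ids generated so far and the logits row for the current position. A sketch of a conforming callable (the stop_on helper is hypothetical, not part of this change):

import numpy as np

def stop_on(token_id: int):
    # Returns an (input_ids, logits) -> bool callable matching the call
    # site above; stops generation once token_id has just been sampled.
    def criteria(input_ids: np.ndarray, logits: np.ndarray) -> bool:
        return input_ids.size > 0 and input_ids[-1] == token_id
    return criteria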
@@ -1041,7 +1051,9 @@ def embed(
         data: Union[List[List[float]], List[List[List[float]]]] = []

         def decode_batch(seq_sizes: List[int]):
-            llama_cpp.llama_kv_self_clear(self._ctx.ctx)
+            mem = llama_cpp.llama_get_memory(self._ctx.ctx)
+            if mem is not None:
+                llama_cpp.llama_memory_clear(mem, True)
             self._ctx.decode(self._batch)
             self._batch.reset()

@@ -1112,7 +1124,9 @@ def decode_batch(seq_sizes: List[int]):

         output = data[0] if isinstance(input, str) else data

-        llama_cpp.llama_kv_self_clear(self._ctx.ctx)
+        mem = llama_cpp.llama_get_memory(self._ctx.ctx)
+        if mem is not None:
+            llama_cpp.llama_memory_clear(mem, True)
         self.reset()

         if return_count:
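Both call sites replace the removed llama_kv_self_clear with the llama_memory API: fetch the context's memory handle, then clear it, with True also erasing the data buffers. From the caller's side nothing changes; a hedged sketch of embedding after this change (the model path is a placeholder):

llm = Llama(model_path="./model.gguf", embedding=True)
# Memory is now cleared via llama_get_memory/llama_memory_clear between
# batches, so separate embed() calls should not leak cache state.
vec = llm.embed("hello world")  # List[float] for a single string input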
@@ -1157,9 +1171,9 @@ def _create_completion(
         bos_token_id: int = self.token_bos()
         cls_token_id: int = self._model.token_cls()
         sep_token_id: int = self._model.token_sep()
-        prefix_token_id: int = 0  # self._model.token_prefix() # TODO: Fix
-        middle_token_id: int = 0  # self._model.token_middle() # TODO: Fix
-        suffix_token_id: int = 0  # self._model.token_suffix() # TODO: Fix
+        prefix_token_id: int = self._model.token_prefix()
+        middle_token_id: int = self._model.token_middle()
+        suffix_token_id: int = self._model.token_suffix()
         add_space_prefix: bool = (
             self.metadata.get("tokenizer.ggml.add_space_prefix", "true") == "true"
         )
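With the real prefix/middle/suffix ids restored in place of the hard-coded 0 placeholders, fill-in-the-middle prompts can be built from the model's own special tokens. A minimal sketch, assuming a FIM-capable model and the common prefix-suffix-middle ordering (which this diff itself does not guarantee); the three ids stand in for the values read via token_prefix()/token_middle()/token_suffix() above, and the code fragments are placeholders:

fim_tokens = (
    [prefix_token_id]
    + llm.tokenize(b"def add(a, b):\n    return ", add_bos=False, special=False)
    + [suffix_token_id]
    + llm.tokenize(b"\n\nprint(add(1, 2))\n", add_bos=False, special=False)
    + [middle_token_id]  # the model continues from here with the middle span
)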
@@ -1315,7 +1329,7 @@ def logit_bias_processor(
         if seed is not None:
             self.set_seed(seed)
         else:
-            self.set_seed(random.Random(self._seed).randint(0, 2 ** 32))
+            self.set_seed(random.Random(self._seed).randint(0, 2**32))

         finish_reason = "length"
         multibyte_fix = 0
@@ -2056,7 +2070,10 @@ def create_chat_completion_openai_v1(
             stream = kwargs.get("stream", False)  # type: ignore
             assert isinstance(stream, bool)
             if stream:
-                return (ChatCompletionChunk(**chunk) for chunk in self.create_chat_completion(*args, **kwargs))  # type: ignore
+                return (
+                    ChatCompletionChunk(**chunk)
+                    for chunk in self.create_chat_completion(*args, **kwargs)
+                )  # type: ignore
             else:
                 return ChatCompletion(**self.create_chat_completion(*args, **kwargs))  # type: ignore
         except ImportError:
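Behavior is unchanged by the reflow: with stream=True the method still returns a generator of OpenAI-typed chunks. A usage sketch (requires the optional openai package; the prompt is a placeholder):

for chunk in llm.create_chat_completion_openai_v1(
    messages=[{"role": "user", "content": "Say hi"}],
    stream=True,
):
    delta = chunk.choices[0].delta.content
    if delta is not None:
        print(delta, end="", flush=True)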
@@ -2096,7 +2113,7 @@ def __getstate__(self):
             logits_all=self._logits_all,
             embedding=self.context_params.embeddings,
             offload_kqv=self.context_params.offload_kqv,
-            flash_attn=self.context_params.flash_attn,
+            flash_attn=self.context_params.flash_attn_type,
             op_offload=self.context_params.op_offload,
             swa_full=self.context_params.swa_full,
             # Sampling Params
@@ -2316,19 +2333,23 @@ def from_pretrained(
         )

         if additional_files:
-            for additonal_file_name in additional_files:
+            for additional_file_name in additional_files:
                 # find the additional shard file:
-                matching_additional_files = [file for file in file_list if fnmatch.fnmatch(file, additonal_file_name)]
+                matching_additional_files = [
+                    file
+                    for file in file_list
+                    if fnmatch.fnmatch(file, additional_file_name)
+                ]

                 if len(matching_additional_files) == 0:
                     raise ValueError(
-                        f"No file found in {repo_id} that match {additonal_file_name}\n\n"
+                        f"No file found in {repo_id} that match {additional_file_name}\n\n"
                         f"Available Files:\n{json.dumps(file_list)}"
                     )

                 if len(matching_additional_files) > 1:
                     raise ValueError(
-                        f"Multiple files found in {repo_id} matching {additonal_file_name}\n\n"
+                        f"Multiple files found in {repo_id} matching {additional_file_name}\n\n"
                         f"Available Files:\n{json.dumps(files)}"
                     )

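Beyond the additonal_file_name typo fix, the reflow clarifies the contract: each entry in additional_files is an fnmatch glob that must match exactly one file in the repo listing, otherwise a ValueError is raised as shown above. A hedged example (repo and file names are placeholders):

llm = Llama.from_pretrained(
    repo_id="someuser/some-model-GGUF",
    filename="some-model-00001-of-00002.gguf",
    additional_files=["some-model-00002-of-*.gguf"],  # must match exactly one file
)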