@@ -75,11 +75,7 @@ def free_model():
         self._exit_stack.callback(free_model)
 
     def close(self):
-        if self.sampler is not None:
-            # NOTE: Must remove custom samplers before free or llama.cpp will try to free them
-            for i, _ in reversed(self.custom_samplers):
-                llama_cpp.llama_sampler_chain_remove(self.sampler, i)
-            self.custom_samplers.clear()
+        # NOTE: LlamaModel doesn't manage samplers (that's LlamaSampler's job)
        self._exit_stack.close()
 
     def __del__(self):
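The removed cleanup doesn't disappear; it becomes `LlamaSampler`'s responsibility. A minimal sketch of what the sampler-side `close()` looks like, reusing exactly the block removed above and the fields `LlamaSampler` is shown to own later in this diff (hypothetical method body, not part of this commit):

```python
def close(self):
    if self.sampler is not None:
        # Detach custom samplers first, or llama_sampler_free would try
        # to free the Python-owned callback objects itself.
        for i, _ in reversed(self.custom_samplers):
            llama_cpp.llama_sampler_chain_remove(self.sampler, i)
        self.custom_samplers.clear()
    self._exit_stack.close()
```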
@@ -292,19 +288,26 @@ def kv_cache_clear(self):
 
     def kv_cache_seq_rm(self, seq_id: int, p0: int, p1: int):
         assert self.memory is not None, "Memory is not initialized"
-        seq_id = seq_id if seq_id >= 0 else 0
+        # seq_id < 0 means "all sequences" - this is valid per llama.cpp docs
         llama_cpp.llama_memory_seq_rm(self.memory, seq_id, p0, p1)
 
     def kv_cache_seq_cp(self, seq_id_src: int, seq_id_dst: int, p0: int, p1: int):
         assert self.memory is not None, "Memory is not initialized"
+        # Negative seq_id not documented for cp - require non-negative IDs
+        assert seq_id_src >= 0, f"seq_id_src must be >= 0, got {seq_id_src}"
+        assert seq_id_dst >= 0, f"seq_id_dst must be >= 0, got {seq_id_dst}"
         llama_cpp.llama_memory_seq_cp(self.memory, seq_id_src, seq_id_dst, p0, p1)
 
     def kv_cache_seq_keep(self, seq_id: int):
         assert self.memory is not None, "Memory is not initialized"
+        # Negative seq_id not documented for keep - require non-negative ID
+        assert seq_id >= 0, f"seq_id must be >= 0, got {seq_id}"
         llama_cpp.llama_memory_seq_keep(self.memory, seq_id)
 
     def kv_cache_seq_shift(self, seq_id: int, p0: int, p1: int, shift: int):
         assert self.memory is not None, "Memory is not initialized"
+        # Negative seq_id not documented for shift - require non-negative ID
+        assert seq_id >= 0, f"seq_id must be >= 0, got {seq_id}"
         llama_cpp.llama_memory_seq_add(self.memory, seq_id, p0, p1, shift)
 
     def get_state_size(self) -> int:
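For reference, the asymmetry these asserts encode: llama.cpp documents negative IDs as a wildcard for `llama_memory_seq_rm` (and negative `p0`/`p1` as open-ended ranges), but not for the cp/keep/add variants. A usage sketch, assuming a `LlamaContext` instance `ctx`:

```python
# seq_id < 0 matches every sequence; p0/p1 < 0 mean "from start" / "to end"
ctx.kv_cache_seq_rm(-1, 0, -1)    # clear the cache across all sequences
ctx.kv_cache_seq_rm(0, 128, -1)   # drop sequence 0 from position 128 onward

ctx.kv_cache_seq_cp(0, 1, 0, -1)  # OK: both sequence IDs non-negative
ctx.kv_cache_seq_keep(-1)         # AssertionError under the new checks
```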
@@ -355,7 +358,9 @@ def get_embeddings_seq(self, seq_id: int):
     # Sampling functions - deprecated, use LlamaSampler instead
 
     def set_rng_seed(self, seed: int):
-        raise NotImplementedError("set_rng_seed is deprecated, use LlamaSampler instead")
+        raise NotImplementedError(
+            "set_rng_seed is deprecated, use LlamaSampler instead"
+        )
 
     def sample_repetition_penalties(
         self,
@@ -366,30 +371,44 @@ def sample_repetition_penalties(
         penalty_freq: float,
         penalty_present: float,
     ):
-        raise NotImplementedError("sample_repetition_penalties is deprecated, use LlamaSampler instead")
+        raise NotImplementedError(
+            "sample_repetition_penalties is deprecated, use LlamaSampler instead"
+        )
 
     def sample_softmax(self, candidates: "_LlamaTokenDataArray"):
-        raise NotImplementedError("sample_softmax is deprecated, use LlamaSampler instead")
+        raise NotImplementedError(
+            "sample_softmax is deprecated, use LlamaSampler instead"
+        )
 
     def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int):
-        raise NotImplementedError("sample_top_k is deprecated, use LlamaSampler instead")
+        raise NotImplementedError(
+            "sample_top_k is deprecated, use LlamaSampler instead"
+        )
 
     def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
-        raise NotImplementedError("sample_top_p is deprecated, use LlamaSampler instead")
+        raise NotImplementedError(
+            "sample_top_p is deprecated, use LlamaSampler instead"
+        )
 
     def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
-        raise NotImplementedError("sample_min_p is deprecated, use LlamaSampler instead")
+        raise NotImplementedError(
+            "sample_min_p is deprecated, use LlamaSampler instead"
+        )
 
     def sample_typical(
         self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int
     ):
-        raise NotImplementedError("sample_typical is deprecated, use LlamaSampler instead")
+        raise NotImplementedError(
+            "sample_typical is deprecated, use LlamaSampler instead"
+        )
 
     def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float):
         raise NotImplementedError("sample_temp is deprecated, use LlamaSampler instead")
 
     def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar):
-        raise NotImplementedError("sample_grammar is deprecated, use LlamaSampler instead")
+        raise NotImplementedError(
+            "sample_grammar is deprecated, use LlamaSampler instead"
+        )
 
     def sample_token_mirostat(
         self,
@@ -399,7 +418,9 @@ def sample_token_mirostat(
         m: int,
         mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float],
     ) -> int:
-        raise NotImplementedError("sample_token_mirostat is deprecated, use LlamaSampler instead")
+        raise NotImplementedError(
+            "sample_token_mirostat is deprecated, use LlamaSampler instead"
+        )
 
     def sample_token_mirostat_v2(
         self,
@@ -408,17 +429,25 @@ def sample_token_mirostat_v2(
         eta: float,
         mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float],
     ) -> int:
-        raise NotImplementedError("sample_token_mirostat_v2 is deprecated, use LlamaSampler instead")
+        raise NotImplementedError(
+            "sample_token_mirostat_v2 is deprecated, use LlamaSampler instead"
+        )
 
     def sample_token_greedy(self, candidates: "_LlamaTokenDataArray") -> int:
-        raise NotImplementedError("sample_token_greedy is deprecated, use LlamaSampler instead")
+        raise NotImplementedError(
+            "sample_token_greedy is deprecated, use LlamaSampler instead"
+        )
 
     def sample_token(self, candidates: "_LlamaTokenDataArray") -> int:
-        raise NotImplementedError("sample_token is deprecated, use LlamaSampler instead")
+        raise NotImplementedError(
+            "sample_token is deprecated, use LlamaSampler instead"
+        )
 
     # Grammar
     def grammar_accept_token(self, grammar: LlamaGrammar, token: int):
-        raise NotImplementedError("grammar_accept_token is deprecated, use LlamaSampler instead")
+        raise NotImplementedError(
+            "grammar_accept_token is deprecated, use LlamaSampler instead"
+        )
 
     def reset_timings(self):
         llama_cpp.llama_perf_context_reset(self.ctx)
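All of the `sample_*` stubs above point at `LlamaSampler`. A rough sketch of the replacement flow, assuming chain-builder methods such as `add_top_k`/`add_temp`/`add_dist` exist alongside the `add_*` methods that do appear later in this diff (names and signatures may differ between versions):

```python
sampler = LlamaSampler()
sampler.add_top_k(40)            # assumed: wraps llama_sampler_init_top_k
sampler.add_temp(0.8)            # assumed: wraps llama_sampler_init_temp
sampler.add_dist(seed=1234)      # assumed: final distribution-based pick
token = sampler.sample(ctx, -1)  # sample from the last decoded logits
```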
@@ -493,7 +522,7 @@ def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
             self.batch.seq_id[j][0] = seq_id
             self.batch.n_seq_id[j] = 1
             self.batch.logits[j] = logits_all
-        self.batch.logits[n_tokens - 1] = True
+        self.batch.logits[n_tokens0 + n_tokens - 1] = True
 
 
 class LlamaTokenDataArray:
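The logits-index fix above only matters once a batch holds more than one sequence; a worked example of the indexing:

```python
# llama_batch slot layout after two add_sequence calls:
#
#   slots:  0  1  2  3  4 | 5  6  7
#           \-- seq 0 ---/ \- seq 1 -/
#
# Adding seq 1 with n_tokens0 = 5 (tokens already in the batch), n_tokens = 3:
#   old: logits[3 - 1]     -> slot 2, wrongly re-flags a token inside seq 0
#   new: logits[5 + 3 - 1] -> slot 7, seq 1's final token, as intended
```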
@@ -602,16 +631,16 @@ def sample(
         logits_array: Optional[npt.NDArray[np.single]] = None,
     ):
         # This method is deprecated in favor of using LlamaSampler directly
-        raise NotImplementedError("LlamaSamplingContext.sample is deprecated, use LlamaSampler instead")
+        raise NotImplementedError(
+            "LlamaSamplingContext.sample is deprecated, use LlamaSampler instead"
+        )
 
     def accept(self, ctx_main: LlamaContext, id: int, apply_grammar: bool):
         self.prev.append(id)
 
 
 class CustomSampler:
-    def __init__(
-        self, apply_func: Callable[[llama_cpp.llama_token_data_array], None]
-    ):
+    def __init__(self, apply_func: Callable[[llama_cpp.llama_token_data_array], None]):
         self.apply_func = apply_func
 
     def apply_wrapper(
@@ -646,6 +675,7 @@ def __init__(self):
         params = llama_cpp.llama_sampler_chain_default_params()
         self.sampler = llama_cpp.llama_sampler_chain_init(params)
         self.custom_samplers: List[Tuple[int, CustomSampler]] = []
+        self._pinned_buffers: List[ctypes.Array] = []  # Pin C arrays to prevent GC
         self._exit_stack = ExitStack()
 
         def free_sampler():
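`_pinned_buffers` exists because ctypes arrays are ordinary Python objects: once nothing references them, the garbage collector may free the underlying memory even though llama.cpp still holds a pointer into it. A minimal illustration of the hazard being prevented:

```python
import ctypes

def make_ptr():
    arr = (ctypes.c_char_p * 2)(b"foo", b"bar")
    # Return only the raw address; no Python reference to `arr` survives.
    return ctypes.cast(arr, ctypes.c_void_p)

ptr = make_ptr()
# `arr` is now collectable, so dereferencing `ptr` from C code is a
# use-after-free. Appending the array to a list that lives as long as
# the sampler (as the hunks below do) keeps the memory valid.
```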
@@ -723,28 +753,32 @@ def add_grammar(self, model: LlamaModel, grammar: LlamaGrammar):
         llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
 
     def add_grammar_lazy_patterns(
-        self,
-        model: LlamaModel,
+        self,
+        model: LlamaModel,
         grammar: LlamaGrammar,
         trigger_patterns: List[str],
-        trigger_tokens: List[int]
+        trigger_tokens: List[int],
     ):
         # Convert patterns to C array
         pattern_ptrs = (ctypes.c_char_p * len(trigger_patterns))()
         for i, pattern in enumerate(trigger_patterns):
             pattern_ptrs[i] = pattern.encode("utf-8")
-
+
         # Convert tokens to C array
         token_array = (llama_cpp.llama_token * len(trigger_tokens))(*trigger_tokens)
-
+
+        # Pin buffers to prevent garbage collection while C code may reference them
+        self._pinned_buffers.append(pattern_ptrs)
+        self._pinned_buffers.append(token_array)
+
         sampler = llama_cpp.llama_sampler_init_grammar_lazy_patterns(
             model.vocab,
             grammar._grammar.encode("utf-8"),
             grammar._root.encode("utf-8"),
             pattern_ptrs,
             len(trigger_patterns),
             token_array,
-            len(trigger_tokens)
+            len(trigger_tokens),
         )
         llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
 
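Hypothetical usage of the method above. The trigger pattern and the idea of arming a grammar on a tool-call marker are illustrative assumptions, not taken from this diff:

```python
sampler.add_grammar_lazy_patterns(
    model,
    json_grammar,                      # a LlamaGrammar, e.g. for tool-call JSON
    trigger_patterns=["<tool_call>"],  # grammar stays inert until this appears
    trigger_tokens=[],                 # or arm on explicit token IDs instead
)
```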
@@ -771,13 +805,16 @@ def add_dry(
         dry_base: float,
         dry_allowed_length: int,
         dry_penalty_last_n: int,
-        seq_breakers: List[str]
+        seq_breakers: List[str],
     ):
         # Convert seq_breakers to C array
         breaker_ptrs = (ctypes.c_char_p * len(seq_breakers))()
         for i, breaker in enumerate(seq_breakers):
             breaker_ptrs[i] = breaker.encode("utf-8")
-
+
+        # Pin buffer to prevent garbage collection
+        self._pinned_buffers.append(breaker_ptrs)
+
         sampler = llama_cpp.llama_sampler_init_dry(
             model.vocab,
             n_ctx_train,
@@ -786,25 +823,22 @@ def add_dry(
             dry_allowed_length,
             dry_penalty_last_n,
             breaker_ptrs,
-            len(seq_breakers)
+            len(seq_breakers),
         )
         llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
 
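A hedged usage sketch for `add_dry`. The numeric values echo commonly cited DRY settings but are illustrative here, not defaults of this library:

```python
sampler.add_dry(
    model,
    n_ctx_train=4096,        # training context length of the loaded model
    dry_multiplier=0.8,      # 0.0 disables the penalty entirely
    dry_base=1.75,           # penalty grows exponentially with repeat length
    dry_allowed_length=2,    # repeats longer than this get penalized
    dry_penalty_last_n=-1,   # -1: scan the whole context for repeats
    seq_breakers=["\n", ":", "\"", "*"],  # boundaries that reset matching
)
```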
-    def add_logit_bias(
-        self,
-        n_vocab: int,
-        logit_bias: Dict[int, float]
-    ):
+    def add_logit_bias(self, n_vocab: int, logit_bias: Dict[int, float]):
         # Convert logit_bias dict to C array
         bias_array = (llama_cpp.llama_logit_bias * len(logit_bias))()
         for i, (token, bias) in enumerate(logit_bias.items()):
             bias_array[i].token = token
             bias_array[i].bias = bias
-
+
+        # Pin buffer to prevent garbage collection
+        self._pinned_buffers.append(bias_array)
+
         sampler = llama_cpp.llama_sampler_init_logit_bias(
-            n_vocab,
-            len(logit_bias),
-            bias_array
+            n_vocab, len(logit_bias), bias_array
         )
         llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
 
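Hypothetical usage of `add_logit_bias`; the token IDs and vocabulary size are illustrative:

```python
sampler.add_logit_bias(
    n_vocab=32000,              # vocabulary size of the loaded model
    logit_bias={
        13: float("-inf"),      # effectively bans token 13
        42: 2.5,                # nudges token 42's logit upward
    },
)
```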
@@ -838,15 +872,18 @@ def reset(self):
     def clone(self):
         # NOTE: Custom samplers cannot be cloned due to Python callback limitations
         if self.custom_samplers:
-            raise NotImplementedError("Cannot clone LlamaSampler that contains custom samplers")
-
+            raise NotImplementedError(
+                "Cannot clone LlamaSampler that contains custom samplers"
+            )
+
         cloned_sampler = llama_cpp.llama_sampler_clone(self.sampler)
         # Create a new wrapper around the cloned sampler
         new_sampler = LlamaSampler.__new__(LlamaSampler)
         new_sampler.sampler = cloned_sampler
         new_sampler.custom_samplers = []
+        new_sampler._pinned_buffers = []  # __new__ bypasses __init__, so initialize this here too
         new_sampler._exit_stack = ExitStack()
-
+
         def free_sampler():
             if new_sampler.sampler is not None:
                 llama_cpp.llama_sampler_free(new_sampler.sampler)
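The clone restriction in practice, as a sketch (`add_custom` is assumed to be the method that registers a `CustomSampler`; it is not shown in this diff):

```python
s = LlamaSampler()
s.add_temp(0.7)            # native llama.cpp sampler: clonable
s2 = s.clone()             # fine

s.add_custom(my_apply_fn)  # registers a Python callback in the chain
s.clone()                  # raises NotImplementedError, per the guard above
```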