@@ -33,16 +33,12 @@ def do_reuse(self) -> None:
3333 key_segments .append (self .keys [..., current_pos :write_start_idx , :])
3434 value_segments .append (self .values [..., current_pos :write_start_idx , :])
3535
36- # add the reused segment with RoPE shift
37- shift_by = write_start_idx - reuse_start_idx # intentionally negative!!!
3836 reuse_end_idx = reuse_start_idx + reuse_length
37+ current_pos = write_start_idx + reuse_length
3938
4039 key_segments .append (self .keys [..., reuse_start_idx :reuse_end_idx , :])
4140 value_segments .append (self .values [..., reuse_start_idx :reuse_end_idx , :])
4241
43- current_pos = write_start_idx + reuse_length
44- self .offset += shift_by
45-
4642 self .keys = mx .concatenate (key_segments , axis = 2 )
4743 self .values = mx .concatenate (value_segments , axis = 2 )
4844
@@ -52,12 +48,12 @@ def do_reuse(self) -> None:
5248 self .offset = self .keys .shape [2 ]
5349
5450 def trim (self , n ) -> int :
55- # trim does not respect keep, which must be the case
51+ # trim must not respect keep
5652 n = min (self .offset , n )
5753 if n <= 0 :
5854 return 0
5955
60- # do trim: put us back into the state before the circular buffer is full
56+ # put us back into the state before the circular buffer is full
6157 self .keys = self ._temporal_order (self .keys )
6258 self .values = self ._temporal_order (self .values )
6359
0 commit comments