Skip to content

Commit acf168b

Browse files
committed
more fixes
1 parent 33e3bf7 commit acf168b

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

mlx_engine/cache.py

Lines changed: 3 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -33,16 +33,12 @@ def do_reuse(self) -> None:
3333
key_segments.append(self.keys[..., current_pos:write_start_idx, :])
3434
value_segments.append(self.values[..., current_pos:write_start_idx, :])
3535

36-
# add the reused segment with RoPE shift
37-
shift_by = write_start_idx - reuse_start_idx # intentionally negative!!!
3836
reuse_end_idx = reuse_start_idx + reuse_length
37+
current_pos = write_start_idx + reuse_length
3938

4039
key_segments.append(self.keys[..., reuse_start_idx:reuse_end_idx, :])
4140
value_segments.append(self.values[..., reuse_start_idx:reuse_end_idx, :])
4241

43-
current_pos = write_start_idx + reuse_length
44-
self.offset += shift_by
45-
4642
self.keys = mx.concatenate(key_segments, axis=2)
4743
self.values = mx.concatenate(value_segments, axis=2)
4844

@@ -52,12 +48,12 @@ def do_reuse(self) -> None:
5248
self.offset = self.keys.shape[2]
5349

5450
def trim(self, n) -> int:
55-
# trim does not respect keep, which must be the case
51+
# trim must not respect keep
5652
n = min(self.offset, n)
5753
if n <= 0:
5854
return 0
5955

60-
# do trim: put us back into the state before the circular buffer is full
56+
# put us back into the state before the circular buffer is full
6157
self.keys = self._temporal_order(self.keys)
6258
self.values = self._temporal_order(self.values)
6359

0 commit comments

Comments
 (0)