4 changes: 4 additions & 0 deletions lmdeploy/messages.py
@@ -129,6 +129,10 @@ class GenerationConfig:
# router replay
return_routed_experts: bool = False

# ngram
ngram_size: int = 0
ngram_threshold: int = 0

def convert_stop_bad_words_to_ids(self, tokenizer: Tokenizer):
"""Convert stop_words/bad_sords to ids and append the ids to
stop_token_ids/bad_token_ids."""
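For orientation, a minimal usage sketch (not part of this diff). It assumes the two new fields are forwarded to the PyTorch engine exactly as wired up in the files below; the model name, backend config, and threshold values are illustrative only.

from lmdeploy import GenerationConfig, PytorchEngineConfig, pipeline

# Illustrative values: stop a request once the same 8-token window has
# already appeared at least 32 times among its generated tokens.
gen_config = GenerationConfig(
    max_new_tokens=1024,
    ngram_size=8,        # length of the repeated window to look for
    ngram_threshold=32,  # matches required before forcing the stop token
)

pipe = pipeline('internlm/internlm2_5-7b-chat',  # any PyTorch-engine model
                backend_config=PytorchEngineConfig())
print(pipe(['Repeat the word "hello" forever.'], gen_config=gen_config))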
156 changes: 146 additions & 10 deletions lmdeploy/pytorch/engine/logits_process.py
@@ -1,8 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Optional, Tuple
from functools import lru_cache
from typing import Any, Dict, List, Tuple

import numpy as np
import torch

from lmdeploy.messages import LogitsProcessor
@@ -29,7 +31,7 @@ def _process_bad_words_(scores: torch.Tensor,
return scores


def _process_repetition_penalty_(scores: torch.Tensor, input_ids: torch.LongTensor, penalty: torch.Tensor):
def _process_repetition_penalty_(scores: torch.Tensor, input_ids: torch.Tensor, penalty: torch.Tensor):
"""Process repetition penalty."""
score = torch.gather(scores, 1, input_ids)
penalty = penalty.to(score.dtype)
@@ -68,6 +70,116 @@ def _filter_minp_sorted_(scores: torch.Tensor, minp: torch.Tensor, filter_value:
return scores


@lru_cache(maxsize=1)
def _ngram_one(dtype: torch.dtype, device: torch.device):
return torch.ones(1, dtype=dtype, device=device)


def ngram(token_ids: torch.Tensor, n: torch.Tensor, threshold: torch.Tensor, max_n: int, same_n: bool = False):
"""Compute n-gram matches between sliding windows and a target sequence.

For each batch, performs cosine similarity checking between:
- All sliding windows of length `max_n` from the full sequence
- The last `max_n` tokens of the sequence (target window)

A match is counted when both:
1. Cosine similarity ≈ 1 (normalized vectors match)
2. Vector lengths match (preventing zero/normalization artifacts)

Parameters
----------
token_ids : torch.Tensor
Input token IDs of shape (batch_size, seq_len).
Values are typically ≥0 (0 may represent padding/special tokens).
n : torch.Tensor
Effective n-gram length for each batch element, shape (batch_size,).
When `same_n=False`, positions beyond `n` in the last `max_n` tokens are masked.
threshold : torch.Tensor
Minimum number of matching windows required for validity, shape (batch_size,).
max_n : int
Maximum n-gram length (window size for matching).
same_n : bool, default False
If True, use full `max_n`-length windows regardless of `n`.
If False, mask positions where index < (max_n - n) in the target window.

Returns
-------
matched_mask : torch.Tensor
Boolean mask of shape (batch_size, seq_len - max_n + 1) indicating
which sliding windows match the target n-gram.
found : torch.Tensor
Boolean tensor of shape (batch_size,) indicating whether each batch
element has at least `threshold` matches.
"""

batch_size, seq_len = token_ids.size()
if seq_len < max_n:
# Not enough tokens to form a single n-gram
matched_mask = torch.zeros((batch_size, 0), dtype=torch.bool, device=token_ids.device)
found = torch.zeros((batch_size, ), dtype=torch.bool, device=token_ids.device)
return matched_mask, found
# token_ids could be 0, so we add 1 to avoid div 0
token_ids = token_ids.to(torch.float32) + 1

# normalize ids
norm = token_ids[:, -max_n:]
if not same_n:
# fill 0 for n < max_n
mask = torch.arange(max_n, device=token_ids.device).unsqueeze(0) >= (max_n - n.unsqueeze(1))
norm = norm * mask.to(torch.float32)
norm = norm.norm(2, dim=-1, keepdim=True)
normed_ids = token_ids / norm

# concatenate p1 and p2 so we can check cosine distance and vector length in one conv1d
normed_n_ids = normed_ids[:, -max_n:]
normed_ids_p2 = normed_ids * normed_ids
ones_ids = torch.ones_like(normed_n_ids)
if not same_n:
# fill 0 for n < max_n
normed_n_ids = normed_n_ids * mask.to(torch.float32)
ones_ids = ones_ids * mask.to(torch.float32)
normed_ids = torch.cat([normed_ids, normed_ids_p2], dim=0)
normed_n_ids = torch.cat([normed_n_ids, ones_ids], dim=0)

# check cos distance & check vector length
match_norm = torch.conv1d(normed_ids.unsqueeze(0), normed_n_ids.unsqueeze(1), groups=batch_size * 2)[0]
match_norm, match_ones = match_norm.chunk(2, dim=0)

# both match result should be close to 1
one_tensor = _ngram_one(dtype=match_norm.dtype, device=match_norm.device)
matched_mask = match_norm.isclose(one_tensor) & match_ones.isclose(one_tensor)

# threshold
count = matched_mask.sum(-1)
found = (count >= threshold) & (threshold > 0)

return matched_mask, found
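As a sanity check (not part of the diff), the matcher can be exercised directly. Because the trailing window always matches itself, a threshold of 2 means "the last n tokens already occurred at least once earlier"; the import path assumes the function stays module-level in logits_process.py.

import torch

from lmdeploy.pytorch.engine.logits_process import ngram

token_ids = torch.tensor([[7, 8, 9, 1, 2, 7, 8, 9]])  # last 3 tokens repeat the window at index 0
matched_mask, found = ngram(token_ids,
                            n=torch.tensor([3]),
                            threshold=torch.tensor([2]),
                            max_n=3)
print(matched_mask)  # shape (1, 6); True at window start indices 0 and 5
print(found)         # tensor([True])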


def _filter_ngram_(
scores: torch.Tensor,
stop_words: torch.Tensor,
generated_ids: torch.Tensor,
n: torch.Tensor,
threshold: torch.Tensor,
max_n: int,
same_n: bool = False,
):
"""Filter ngram."""
if stop_words is None or stop_words.numel() == 0:
return scores
# use first stop words
_, found = ngram(generated_ids, n, threshold, max_n, same_n)
stop_words = stop_words[:, 0]
# fill all scores -inf
scores.masked_fill_(found[:, None], -float('inf'))
# set stop words to 0
stop_scores = scores.gather(1, stop_words[:, None])
stop_scores.masked_fill_(found[:, None], 0)
scores.scatter_(1, stop_words[:, None], stop_scores)
Comment on lines +167 to +179
Copilot AI (Jan 23, 2026):

_filter_ngram_ always uses stop_words[:, 0] without considering stop_mask. When some sequences have no stop words, __get_bad_words pads their row with 0s and a false mask; this function will still treat token 0 as a stop word and can force generation of token 0 when found is true. Pass/use stop_mask (select the first valid stop word per batch, and skip batches with none) or precompute a per-sequence stop token id when enabling n-gram stopping.

Suggested change
):
"""Filter ngram."""
if stop_words is None or stop_words.numel() == 0:
return scores
# use first stop words
_, found = ngram(generated_ids, n, threshold, max_n, same_n)
stop_words = stop_words[:, 0]
# fill all scores -inf
scores.masked_fill_(found[:, None], -float('inf'))
# set stop words to 0
stop_scores = scores.gather(1, stop_words[:, None])
stop_scores.masked_fill_(found[:, None], 0)
scores.scatter_(1, stop_words[:, None], stop_scores)
stop_mask: torch.Tensor | None = None,
):
"""Filter ngram."""
if stop_words is None or stop_words.numel() == 0:
return scores
# determine which sequences have ngram matches
_, found = ngram(generated_ids, n, threshold, max_n, same_n)
batch_size = scores.size(0)
if stop_mask is not None:
# has_stop indicates which batch elements have at least one valid stop word
if stop_mask.numel() == 0:
return scores
has_stop = stop_mask.any(dim=1)
if not has_stop.any():
# no sequences have valid stop words, nothing to do
return scores
# compute index of first valid stop word per batch element
num_stops = stop_mask.size(1)
idxs = torch.arange(
num_stops, device=stop_mask.device, dtype=torch.long
).unsqueeze(0).expand_as(stop_mask)
idxs = idxs.masked_fill(~stop_mask, num_stops)
first_idxs = idxs.argmin(dim=1)
batch_indices = torch.arange(batch_size, device=stop_words.device, dtype=torch.long)
stop_tokens = stop_words[batch_indices, first_idxs]
# only apply forcing where both an ngram is found and a valid stop word exists
valid_found = found & has_stop
if not valid_found.any():
return scores
scores.masked_fill_(valid_found[:, None], -float('inf'))
stop_tokens_exp = stop_tokens[:, None]
stop_scores = scores.gather(1, stop_tokens_exp)
stop_scores.masked_fill_(valid_found[:, None], 0)
scores.scatter_(1, stop_tokens_exp, stop_scores)
else:
# fallback: use the first stop word in each row, as originally implemented
stop_tokens = stop_words[:, 0]
# fill all scores -inf where an ngram is found
scores.masked_fill_(found[:, None], -float('inf'))
# set stop word scores to 0
stop_tokens_exp = stop_tokens[:, None]
stop_scores = scores.gather(1, stop_tokens_exp)
stop_scores.masked_fill_(found[:, None], 0)
scores.scatter_(1, stop_tokens_exp, stop_scores)

return scores
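A toy run (again not part of the diff) showing the net effect: once a repetition is found, every logit in that row is pushed to -inf except the first stop word, which is set to 0, so the next token sampled must be the stop token. Values are illustrative and the private import is for demonstration only.

import torch

from lmdeploy.pytorch.engine.logits_process import _filter_ngram_

scores = torch.randn(1, 6)                           # toy vocabulary of 6 tokens
stop_words = torch.tensor([[5]])                     # pretend token 5 is EOS
generated_ids = torch.tensor([[1, 2, 3, 1, 2, 3]])   # the 3-gram (1, 2, 3) repeats
out = _filter_ngram_(scores, stop_words, generated_ids,
                     n=torch.tensor([3]), threshold=torch.tensor([2]), max_n=3)
print(out)  # tensor([[-inf, -inf, -inf, -inf, -inf, 0.]])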


def _multinomial_sampling(scores: torch.Tensor,
seeds: torch.LongTensor,
offsets: torch.LongTensor,
@@ -84,7 +196,7 @@ def _multinomial_sampling(scores: torch.Tensor,
class SamplingInputsDelta:
num_ignore_eos: torch.Tensor = None
random_offsets: torch.Tensor = None
all_ids: Optional[torch.Tensor] = None
all_ids: None | torch.Tensor = None


@dataclass
@@ -104,16 +216,27 @@ class SamplingInputs:
min_top_p: float = 1.0
response_formats: Tuple[str] = ()
logits_processors: List[List[LogitsProcessor]] = None
max_num_logprobs: Optional[int] = None
all_ids: Optional[torch.Tensor] = None
max_num_logprobs: None | int = None
all_ids: None | torch.Tensor = None
num_ignore_eos: torch.Tensor = None
batch_size: int = 0
session_ctx: Optional[List[Dict[str, Any]]] = None
session_to_cleanup: Optional[List[int]] = None
session_ctx: None | List[Dict[str, Any]] = None
session_to_cleanup: None | List[int] = None
# for repetition_penalty and ngram
generated_ids: torch.Tensor | None = None
generated_ids_cpu: np.ndarray | None = None

# n gram
ngram_size: torch.Tensor = None
ngram_threshold: torch.Tensor = None
max_ngram_size: int = 0
ngram_same_n: bool = False

def to_device(self, device: str, non_blocking: bool = False):
"""To device."""
out_dict = dict()
if self.generated_ids_cpu is not None:
self.generated_ids = torch.from_numpy(self.generated_ids_cpu.copy())
for f in fields(self):
k = f.name
v = getattr(self, k)
Comment on lines +238 to 242
Copilot AI (Jan 23, 2026):

SamplingInputs.to_device() mutates self by setting self.generated_ids before building the output instance, and it forces a full .copy() of generated_ids_cpu. This introduces side effects and an extra memory copy on every device transfer. Prefer keeping to_device pure (compute a local tensor from generated_ids_cpu and put it into out_dict) and avoid the unconditional copy unless it’s required for contiguity.

Suggested change
if self.generated_ids_cpu is not None:
self.generated_ids = torch.from_numpy(self.generated_ids_cpu.copy())
for f in fields(self):
k = f.name
v = getattr(self, k)
# Build a local `generated_ids` tensor from `generated_ids_cpu` without
# mutating `self` and without forcing an unnecessary copy.
if self.generated_ids_cpu is not None:
gen_np = self.generated_ids_cpu
# Ensure the array is suitable for torch.from_numpy without an
# extra copy unless required for contiguity or writability.
if (not gen_np.flags['C_CONTIGUOUS']) or (not gen_np.flags['WRITEABLE']):
gen_np = np.ascontiguousarray(gen_np)
generated_ids = torch.from_numpy(gen_np)
else:
generated_ids = self.generated_ids
for f in fields(self):
k = f.name
if k == 'generated_ids':
v = generated_ids
else:
v = getattr(self, k)

@@ -168,8 +291,8 @@ class FusedLogitsProcessor:
def __init__(
self,
sampling_inputs: SamplingInputs,
logprobs_mode: Optional[str] = None,
guided_decoding_manager: Optional[GuidedDecodingManager] = None,
logprobs_mode: None | str = None,
guided_decoding_manager: None | GuidedDecodingManager = None,
):
self.sampling_inputs: SamplingInputs = sampling_inputs
self.logprobs_mode = logprobs_mode
@@ -238,7 +361,20 @@ async def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:

repetition_penalty = sampling_inputs.repetition_penalty
if repetition_penalty is not None:
scores = _process_repetition_penalty_(scores, all_ids, repetition_penalty)
generated_ids = sampling_inputs.generated_ids
scores = _process_repetition_penalty_(scores, generated_ids, repetition_penalty)
Comment on lines +364 to +365
Copilot AI (Jan 23, 2026):

Repetition penalty is now applied to sampling_inputs.generated_ids (new tokens only) instead of the full history (all_ids/valid_ids). This changes behavior compared to standard repetition penalty (which usually includes the prompt/history) and compared to the previous implementation in this file. If the intent is to preserve previous semantics, keep using the full token history; otherwise, this behavior change should be explicitly validated and covered by an integration test.

Suggested change
generated_ids = sampling_inputs.generated_ids
scores = _process_repetition_penalty_(scores, generated_ids, repetition_penalty)
scores = _process_repetition_penalty_(scores, all_ids, repetition_penalty)


if sampling_inputs.max_ngram_size > 0:
generated_ids = sampling_inputs.generated_ids
scores = _filter_ngram_(
scores,
sampling_inputs.stop_words,
generated_ids,
sampling_inputs.ngram_size,
sampling_inputs.ngram_threshold,
sampling_inputs.max_ngram_size,
sampling_inputs.ngram_same_n,
)

temperature = sampling_inputs.temperature
if temperature is not None:
22 changes: 14 additions & 8 deletions lmdeploy/pytorch/messages.py
@@ -2,7 +2,7 @@
import enum
from collections import defaultdict
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List

import numpy as np
import torch
@@ -56,13 +56,17 @@ class SamplingParam:
bad_words: List[int] = field(default_factory=list)
max_new_tokens: int = 512
min_new_tokens: int = 0
response_format: Optional[str] = None
logits_processors: Optional[List[LogitsProcessor]] = None
response_format: None | str = None
logits_processors: None | List[LogitsProcessor] = None
out_logits: bool = False
out_last_hidden_states: bool = False
num_logprobs: int = -1
return_routed_experts: bool = False

# ngram
ngram_size: int = 0
ngram_threshold: int = 0

@classmethod
def from_gen_config(cls, gen_config: GenerationConfig):
"""From gen config."""
@@ -144,6 +148,8 @@ def from_gen_config(cls, gen_config: GenerationConfig):
out_logits=(output_logits is not None),
num_logprobs=logprobs,
return_routed_experts=gen_config.return_routed_experts,
ngram_size=gen_config.ngram_size,
ngram_threshold=gen_config.ngram_threshold,
)


@@ -262,7 +268,7 @@ def add_sequence(self,
adapter_name: str = None,
multimodals: MultiModalInputs = None,
input_embeddings: List[InputEmbeddings] = None,
migration_request: Optional[MigrationRequest] = None,
migration_request: None | MigrationRequest = None,
resp_cache: bool = False,
preserve_cache: bool = False) -> 'SchedulerSequence':
"""Add a new message."""
@@ -604,7 +610,7 @@ class SchedulerSequence:
model_meta: Dict[str, Any] = None

# For Disaggregation
migration_request: Optional[MigrationRequest] = None
migration_request: None | MigrationRequest = None
resp_cache: bool = False
preserve_cache: bool = False

@@ -698,7 +704,7 @@ def routed_experts(self) -> np.ndarray:
else:
return None

def append_routed_experts(self, routed_experts: Union[Tensor, np.ndarray]):
def append_routed_experts(self, routed_experts: Tensor | np.ndarray):
"""Append routed experts."""
if not self.return_routed_experts:
return
@@ -756,7 +762,7 @@ def logits(self):
"""Get logits."""
return self.all_logits.get_logits()

def append_logits(self, logits: Union[Tensor, np.ndarray]):
def append_logits(self, logits: Tensor | np.ndarray):
"""Append logits."""
if not self.return_logits:
return
@@ -776,7 +782,7 @@ def get_input_multimodals(self):
def record_event(
self,
event_type: EventType,
timestamp: Optional[float] = None,
timestamp: None | float = None,
) -> None:
self.engine_events.append(EngineEvent.new_event(event_type, timestamp))

41 changes: 40 additions & 1 deletion lmdeploy/pytorch/strategies/ar/sampling.py
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import numpy as np
import torch
from torch.profiler import record_function

@@ -15,7 +16,7 @@

def _gather_all_ids(pad_id: int, seqs: SeqList, sampling_inputs: SamplingInputs):
"""Gather history."""
if sampling_inputs.repetition_penalty is None and not any(sampling_inputs.logits_processors):
if not any(sampling_inputs.logits_processors):
return None
batch = len(seqs)
max_len = max(seq.num_valid_ids for seq in seqs)
@@ -29,6 +30,22 @@ def _gather_all_ids(pad_id: int, seqs: SeqList, sampling_inputs: SamplingInputs)
return output


def _gather_generated_ids(pad_id: int, seqs: SeqList, sampling_inputs: SamplingInputs) -> np.ndarray | None:
"""Gather history."""
if sampling_inputs.repetition_penalty is None and sampling_inputs.max_ngram_size == 0:
return None
batch = len(seqs)
max_len = max(seq.num_new_tokens for seq in seqs)
output = np.full((batch, max_len), pad_id, dtype=np.int64)
for idx, seq in enumerate(seqs):
h_len = seq.num_new_tokens
if h_len == 0:
continue
h_ids = seq.generated_ids
output[idx, -h_len:] = h_ids
return output
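A standalone mirror of the padding layout above (toy values, pad_id assumed 0): rows are right-aligned so every sequence's most recent tokens share the same trailing columns, which is what the matcher's last-max_n window relies on.

import numpy as np

pad_id = 0
generated = [[11, 12], [21, 22, 23, 24]]            # per-sequence new tokens
max_len = max(len(ids) for ids in generated)
output = np.full((len(generated), max_len), pad_id, dtype=np.int64)
for idx, ids in enumerate(generated):
    if ids:
        output[idx, -len(ids):] = ids
print(output)
# [[ 0  0 11 12]
#  [21 22 23 24]]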


def _get_num_ignore_eos(seqs: SeqList):
"""Get num ignore eos."""
ret = [seq.sampling_param.min_new_tokens - seq.num_new_tokens for seq in seqs]
@@ -61,6 +78,8 @@ def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs:
num_logprobs = [None] * batch_size
session_to_cleanup = self.session_to_cleanup
self.session_to_cleanup = []
ngram_sizes = [None] * batch_size
ngram_thresholds = [None] * batch_size

def __gather_params():
"""Gather params."""
@@ -84,6 +103,8 @@ def __gather_params():
stop_words[idx] = sw
logits_processors[idx] = param.logits_processors
num_logprobs[idx] = param.num_logprobs
ngram_sizes[idx] = param.ngram_size
ngram_thresholds[idx] = param.ngram_threshold

def __get_topp(top_p):
"""Get topp."""
@@ -165,6 +186,19 @@ def __get_bad_words(bad_words):
'seq_id': seq.seq_id,
} for seq in seqs]

# ngram
max_ngram_size = max(ngram_sizes)
if max_ngram_size == 0:
ngram_sizes = None
ngram_thresholds = None
ngram_same_n = True
else:
ngram_sizes = torch.tensor(ngram_sizes)
ngram_thresholds = torch.tensor(ngram_thresholds)
Comment on lines +196 to +197
Copilot AI (Jan 23, 2026):

max_ngram_size = max(ngram_sizes) enables n-gram logic for the whole batch if any request sets ngram_size > 0. In that case, requests with ngram_size == 0 (or ngram_threshold == 0) are still included in the tensors, which can lead to invalid masking (e.g., all-false masks when n==0) and unnecessary computation. Consider normalizing per-request settings (e.g., set ngram_threshold=0 and ngram_size=max_ngram_size for disabled rows, or filter them out) so that “disabled” requests don’t participate in n-gram matching.

Suggested change
ngram_sizes = torch.tensor(ngram_sizes)
ngram_thresholds = torch.tensor(ngram_thresholds)
# Normalize per-request ngram settings so that disabled requests
# do not have n == 0 in the tensors. For any row where ngram_size
# is non-positive or ngram_threshold is non-positive, we treat the
# request as "disabled" for n-gram matching by assigning
# ngram_size = max_ngram_size and ngram_threshold = 0. This keeps
# shapes and masks valid while ensuring those rows do not
# participate in blocking logic.
normalized_ngram_sizes: list[int] = []
normalized_ngram_thresholds: list[float] = []
for size, threshold in zip(ngram_sizes, ngram_thresholds):
if size <= 0 or threshold <= 0:
normalized_ngram_sizes.append(max_ngram_size)
normalized_ngram_thresholds.append(0.0)
else:
normalized_ngram_sizes.append(size)
normalized_ngram_thresholds.append(threshold)
ngram_sizes = torch.tensor(normalized_ngram_sizes)
ngram_thresholds = torch.tensor(normalized_ngram_thresholds)

ngram_same_n = (ngram_sizes == max_ngram_size).all().item()
if ngram_same_n:
ngram_sizes = None

sampling_input = SamplingInputs(
temperature=temperature,
bad_words=bad_words,
@@ -185,10 +219,15 @@ def __get_bad_words(bad_words):
batch_size=batch_size,
session_ctx=session_ctx,
session_to_cleanup=session_to_cleanup,
ngram_size=ngram_sizes,
ngram_threshold=ngram_thresholds,
max_ngram_size=max_ngram_size,
ngram_same_n=ngram_same_n,
)

pad_token_id = self.pad_token_id
sampling_input.all_ids = _gather_all_ids(pad_token_id, seqs, sampling_input)
sampling_input.generated_ids_cpu = _gather_generated_ids(pad_token_id, seqs, sampling_input)
sampling_input.num_ignore_eos = _get_num_ignore_eos(seqs)
return sampling_input
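To make the ngram_same_n decision above concrete, a toy mirror with illustrative sizes: uniform per-request sizes drop the size tensor and let ngram() skip its per-row masking, while mixed sizes keep it.

import torch

for ngram_sizes in ([4, 4, 4], [2, 4, 3]):
    max_ngram_size = max(ngram_sizes)
    same_n = bool((torch.tensor(ngram_sizes) == max_ngram_size).all().item())
    print(f'{ngram_sizes} -> max_ngram_size={max_ngram_size}, same_n={same_n}')
# [4, 4, 4] -> max_ngram_size=4, same_n=True
# [2, 4, 3] -> max_ngram_size=4, same_n=False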
