diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py
index 8b54705e31..00da23baf6 100644
--- a/lmdeploy/messages.py
+++ b/lmdeploy/messages.py
@@ -129,6 +129,10 @@ class GenerationConfig:
     # router replay
     return_routed_experts: bool = False
 
+    # ngram
+    ngram_size: int = 0
+    ngram_threshold: int = 0
+
     def convert_stop_bad_words_to_ids(self, tokenizer: Tokenizer):
         """Convert stop_words/bad_sords to ids and append the ids to
         stop_token_ids/bad_token_ids."""
diff --git a/lmdeploy/pytorch/engine/logits_process.py b/lmdeploy/pytorch/engine/logits_process.py
index d6b5542581..9190e14925 100644
--- a/lmdeploy/pytorch/engine/logits_process.py
+++ b/lmdeploy/pytorch/engine/logits_process.py
@@ -1,8 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import asyncio
 from dataclasses import dataclass, fields
-from typing import Any, Dict, List, Optional, Tuple
+from functools import lru_cache
+from typing import Any, Dict, List, Tuple
 
+import numpy as np
 import torch
 
 from lmdeploy.messages import LogitsProcessor
@@ -29,7 +31,7 @@ def _process_bad_words_(scores: torch.Tensor,
     return scores
 
 
-def _process_repetition_penalty_(scores: torch.Tensor, input_ids: torch.LongTensor, penalty: torch.Tensor):
+def _process_repetition_penalty_(scores: torch.Tensor, input_ids: torch.Tensor, penalty: torch.Tensor):
     """Process repetition penalty."""
     score = torch.gather(scores, 1, input_ids)
     penalty = penalty.to(score.dtype)
@@ -68,6 +70,116 @@ def _filter_minp_sorted_(scores: torch.Tensor, minp: torch.Tensor, filter_value:
     return scores
 
 
+@lru_cache(maxsize=1)
+def _ngram_one(dtype: torch.dtype, device: torch.device):
+    return torch.ones(1, dtype=dtype, device=device)
+
+
+def ngram(token_ids: torch.Tensor, n: torch.Tensor, threshold: torch.Tensor, max_n: int, same_n: bool = False):
+    """Compute n-gram matches between sliding windows and a target sequence.
+
+    For each batch element, cosine-similarity matching is performed between:
+    - All sliding windows of length `max_n` from the full sequence
+    - The last `max_n` tokens of the sequence (target window)
+
+    A match is counted when both:
+    1. Cosine similarity ≈ 1 (normalized vectors match)
+    2. Vector lengths match (preventing zero/normalization artifacts)
+
+    Parameters
+    ----------
+    token_ids : torch.Tensor
+        Input token IDs of shape (batch_size, seq_len).
+        Values are typically ≥0 (0 may represent padding/special tokens).
+    n : torch.Tensor
+        Effective n-gram length for each batch element, shape (batch_size,).
+        When `same_n=False`, positions beyond `n` in the last `max_n` tokens are masked.
+    threshold : torch.Tensor
+        Minimum number of matching windows required for validity, shape (batch_size,).
+    max_n : int
+        Maximum n-gram length (window size for matching).
+    same_n : bool, default False
+        If True, use full `max_n`-length windows regardless of `n`.
+        If False, mask positions where index < (max_n - n) in the target window.
+
+    Returns
+    -------
+    matched_mask : torch.Tensor
+        Boolean mask of shape (batch_size, seq_len - max_n + 1) indicating
+        which sliding windows match the target n-gram.
+    found : torch.Tensor
+        Boolean tensor of shape (batch_size,) indicating whether each batch
+        element has at least `threshold` matches.
+ """ + + batch_size, seq_len = token_ids.size() + if seq_len < max_n: + # Not enough tokens to form a single n-gram + matched_mask = torch.zeros((batch_size, 0), dtype=torch.bool, device=token_ids.device) + found = torch.zeros((batch_size, ), dtype=torch.bool, device=token_ids.device) + return matched_mask, found + # token_ids could be 0, so we add 1 to avoid div 0 + token_ids = token_ids.to(torch.float32) + 1 + + # normalize ids + norm = token_ids[:, -max_n:] + if not same_n: + # fill 0 for n < max_n + mask = torch.arange(max_n, device=token_ids.device).unsqueeze(0) >= (max_n - n.unsqueeze(1)) + norm = norm * mask.to(torch.float32) + norm = norm.norm(2, dim=-1, keepdim=True) + normed_ids = token_ids / norm + + # concate p1 and p2 so we can check distance and vector in one conv1d + normed_n_ids = normed_ids[:, -max_n:] + normed_ids_p2 = normed_ids * normed_ids + ones_ids = torch.ones_like(normed_n_ids) + if not same_n: + # fill 0 for n < max_n + normed_n_ids = normed_n_ids * mask.to(torch.float32) + ones_ids = ones_ids * mask.to(torch.float32) + normed_ids = torch.cat([normed_ids, normed_ids_p2], dim=0) + normed_n_ids = torch.cat([normed_n_ids, ones_ids], dim=0) + + # check cos distance & check vector length + match_norm = torch.conv1d(normed_ids.unsqueeze(0), normed_n_ids.unsqueeze(1), groups=batch_size * 2)[0] + match_norm, match_ones = match_norm.chunk(2, dim=0) + + # both match result should be close to 1 + one_tensor = _ngram_one(dtype=match_norm.dtype, device=match_norm.device) + matched_mask = match_norm.isclose(one_tensor) & match_ones.isclose(one_tensor) + + # threshold + count = matched_mask.sum(-1) + found = (count >= threshold) & (threshold > 0) + + return matched_mask, found + + +def _filter_ngram_( + scores: torch.Tensor, + stop_words: torch.Tensor, + generated_ids: torch.Tensor, + n: torch.Tensor, + threshold: torch.Tensor, + max_n: int, + same_n: bool = False, +): + """Filter ngram.""" + if stop_words is None or stop_words.numel() == 0: + return scores + # use first stop words + _, found = ngram(generated_ids, n, threshold, max_n, same_n) + stop_words = stop_words[:, 0] + # fill all scores -inf + scores.masked_fill_(found[:, None], -float('inf')) + # set stop words to 0 + stop_scores = scores.gather(1, stop_words[:, None]) + stop_scores.masked_fill_(found[:, None], 0) + scores.scatter_(1, stop_words[:, None], stop_scores) + return scores + + def _multinomial_sampling(scores: torch.Tensor, seeds: torch.LongTensor, offsets: torch.LongTensor, @@ -84,7 +196,7 @@ def _multinomial_sampling(scores: torch.Tensor, class SamplingInputsDelta: num_ignore_eos: torch.Tensor = None random_offsets: torch.Tensor = None - all_ids: Optional[torch.Tensor] = None + all_ids: None | torch.Tensor = None @dataclass @@ -104,16 +216,27 @@ class SamplingInputs: min_top_p: float = 1.0 response_formats: Tuple[str] = () logits_processors: List[List[LogitsProcessor]] = None - max_num_logprobs: Optional[int] = None - all_ids: Optional[torch.Tensor] = None + max_num_logprobs: None | int = None + all_ids: None | torch.Tensor = None num_ignore_eos: torch.Tensor = None batch_size: int = 0 - session_ctx: Optional[List[Dict[str, Any]]] = None - session_to_cleanup: Optional[List[int]] = None + session_ctx: None | List[Dict[str, Any]] = None + session_to_cleanup: None | List[int] = None + # for repetition_penalty and ngram + generated_ids: torch.Tensor | None = None + generated_ids_cpu: np.ndarray | None = None + + # n gram + ngram_size: torch.Tensor = None + ngram_threshold: torch.Tensor = None + max_ngram_size: 
+    ngram_same_n: bool = False
 
     def to_device(self, device: str, non_blocking: bool = False):
         """To device."""
         out_dict = dict()
+        if self.generated_ids_cpu is not None:
+            self.generated_ids = torch.from_numpy(self.generated_ids_cpu.copy())
         for f in fields(self):
             k = f.name
             v = getattr(self, k)
@@ -168,8 +291,8 @@ class FusedLogitsProcessor:
     def __init__(
         self,
         sampling_inputs: SamplingInputs,
-        logprobs_mode: Optional[str] = None,
-        guided_decoding_manager: Optional[GuidedDecodingManager] = None,
+        logprobs_mode: None | str = None,
+        guided_decoding_manager: None | GuidedDecodingManager = None,
     ):
         self.sampling_inputs: SamplingInputs = sampling_inputs
         self.logprobs_mode = logprobs_mode
@@ -238,7 +361,20 @@ async def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:
 
         repetition_penalty = sampling_inputs.repetition_penalty
         if repetition_penalty is not None:
-            scores = _process_repetition_penalty_(scores, all_ids, repetition_penalty)
+            generated_ids = sampling_inputs.generated_ids
+            scores = _process_repetition_penalty_(scores, generated_ids, repetition_penalty)
+
+        if sampling_inputs.max_ngram_size > 0:
+            generated_ids = sampling_inputs.generated_ids
+            scores = _filter_ngram_(
+                scores,
+                sampling_inputs.stop_words,
+                generated_ids,
+                sampling_inputs.ngram_size,
+                sampling_inputs.ngram_threshold,
+                sampling_inputs.max_ngram_size,
+                sampling_inputs.ngram_same_n,
+            )
 
         temperature = sampling_inputs.temperature
         if temperature is not None:
diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py
index c020403fa8..fb71465d1e 100644
--- a/lmdeploy/pytorch/messages.py
+++ b/lmdeploy/pytorch/messages.py
@@ -2,7 +2,7 @@
 import enum
 from collections import defaultdict
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List
 
 import numpy as np
 import torch
@@ -56,13 +56,17 @@ class SamplingParam:
     bad_words: List[int] = field(default_factory=list)
     max_new_tokens: int = 512
     min_new_tokens: int = 0
-    response_format: Optional[str] = None
-    logits_processors: Optional[List[LogitsProcessor]] = None
+    response_format: None | str = None
+    logits_processors: None | List[LogitsProcessor] = None
     out_logits: bool = False
     out_last_hidden_states: bool = False
    num_logprobs: int = -1
     return_routed_experts: bool = False
 
+    # ngram
+    ngram_size: int = 0
+    ngram_threshold: int = 0
+
     @classmethod
     def from_gen_config(cls, gen_config: GenerationConfig):
         """From gen config."""
@@ -144,6 +148,8 @@ def from_gen_config(cls, gen_config: GenerationConfig):
             out_logits=(output_logits is not None),
             num_logprobs=logprobs,
             return_routed_experts=gen_config.return_routed_experts,
+            ngram_size=gen_config.ngram_size,
+            ngram_threshold=gen_config.ngram_threshold,
         )
 
 
@@ -262,7 +268,7 @@ def add_sequence(self,
                      adapter_name: str = None,
                      multimodals: MultiModalInputs = None,
                      input_embeddings: List[InputEmbeddings] = None,
-                     migration_request: Optional[MigrationRequest] = None,
+                     migration_request: None | MigrationRequest = None,
                      resp_cache: bool = False,
                      preserve_cache: bool = False) -> 'SchedulerSequence':
         """Add a new message."""
@@ -604,7 +610,7 @@ class SchedulerSequence:
     model_meta: Dict[str, Any] = None
 
     # For Disaggregation
-    migration_request: Optional[MigrationRequest] = None
+    migration_request: None | MigrationRequest = None
     resp_cache: bool = False
     preserve_cache: bool = False
 
@@ -698,7 +704,7 @@ def routed_experts(self) -> np.ndarray:
         else:
             return None
 
-    def append_routed_experts(self, routed_experts: Union[Tensor, np.ndarray]):
+    def append_routed_experts(self, routed_experts: Tensor | np.ndarray):
         """Append routed experts."""
         if not self.return_routed_experts:
             return
@@ -756,7 +762,7 @@ def logits(self):
         """Get logits."""
         return self.all_logits.get_logits()
 
-    def append_logits(self, logits: Union[Tensor, np.ndarray]):
+    def append_logits(self, logits: Tensor | np.ndarray):
         """Append logits."""
         if not self.return_logits:
             return
@@ -776,7 +782,7 @@ def get_input_multimodals(self):
     def record_event(
         self,
         event_type: EventType,
-        timestamp: Optional[float] = None,
+        timestamp: None | float = None,
     ) -> None:
         self.engine_events.append(EngineEvent.new_event(event_type, timestamp))
 
diff --git a/lmdeploy/pytorch/strategies/ar/sampling.py b/lmdeploy/pytorch/strategies/ar/sampling.py
index 4db051ddcf..cb95308af8 100644
--- a/lmdeploy/pytorch/strategies/ar/sampling.py
+++ b/lmdeploy/pytorch/strategies/ar/sampling.py
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from typing import List
 
+import numpy as np
 import torch
 from torch.profiler import record_function
 
@@ -15,7 +16,7 @@
 
 def _gather_all_ids(pad_id: int, seqs: SeqList, sampling_inputs: SamplingInputs):
     """Gather history."""
-    if sampling_inputs.repetition_penalty is None and not any(sampling_inputs.logits_processors):
+    if not any(sampling_inputs.logits_processors):
         return None
     batch = len(seqs)
     max_len = max(seq.num_valid_ids for seq in seqs)
@@ -29,6 +30,22 @@ def _gather_all_ids(pad_id: int, seqs: SeqList, sampling_inputs: SamplingInputs)
     return output
 
 
+def _gather_generated_ids(pad_id: int, seqs: SeqList, sampling_inputs: SamplingInputs) -> np.ndarray | None:
+    """Gather generated ids."""
+    if sampling_inputs.repetition_penalty is None and sampling_inputs.max_ngram_size == 0:
+        return None
+    batch = len(seqs)
+    max_len = max(seq.num_new_tokens for seq in seqs)
+    output = np.full((batch, max_len), pad_id, dtype=np.int64)
+    for idx, seq in enumerate(seqs):
+        h_len = seq.num_new_tokens
+        if h_len == 0:
+            continue
+        h_ids = seq.generated_ids
+        output[idx, -h_len:] = h_ids
+    return output
+
+
 def _get_num_ignore_eos(seqs: SeqList):
     """Get num ignore eos."""
     ret = [seq.sampling_param.min_new_tokens - seq.num_new_tokens for seq in seqs]
@@ -61,6 +78,8 @@ def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs:
         num_logprobs = [None] * batch_size
         session_to_cleanup = self.session_to_cleanup
         self.session_to_cleanup = []
+        ngram_sizes = [None] * batch_size
+        ngram_thresholds = [None] * batch_size
 
         def __gather_params():
             """Gather params."""
@@ -84,6 +103,8 @@ def __gather_params():
                 stop_words[idx] = sw
                 logits_processors[idx] = param.logits_processors
                 num_logprobs[idx] = param.num_logprobs
+                ngram_sizes[idx] = param.ngram_size
+                ngram_thresholds[idx] = param.ngram_threshold
 
         def __get_topp(top_p):
             """Get topp."""
@@ -165,6 +186,19 @@ def __get_bad_words(bad_words):
             'seq_id': seq.seq_id,
         } for seq in seqs]
 
+        # ngram
+        max_ngram_size = max(ngram_sizes)
+        if max_ngram_size == 0:
+            ngram_sizes = None
+            ngram_thresholds = None
+            ngram_same_n = True
+        else:
+            ngram_sizes = torch.tensor(ngram_sizes)
+            ngram_thresholds = torch.tensor(ngram_thresholds)
+            ngram_same_n = (ngram_sizes == max_ngram_size).all().item()
+            if ngram_same_n:
+                ngram_sizes = None
+
         sampling_input = SamplingInputs(
             temperature=temperature,
             bad_words=bad_words,
@@ -185,10 +219,15 @@ def __get_bad_words(bad_words):
             batch_size=batch_size,
             session_ctx=session_ctx,
             session_to_cleanup=session_to_cleanup,
+            ngram_size=ngram_sizes,
+            ngram_threshold=ngram_thresholds,
+            max_ngram_size=max_ngram_size,
+            ngram_same_n=ngram_same_n,
         )
 
         pad_token_id = self.pad_token_id
         sampling_input.all_ids = _gather_all_ids(pad_token_id, seqs, sampling_input)
+        sampling_input.generated_ids_cpu = _gather_generated_ids(pad_token_id, seqs, sampling_input)
         sampling_input.num_ignore_eos = _get_num_ignore_eos(seqs)
         return sampling_input
 
diff --git a/lmdeploy/pytorch/strategies/dllm/sampling.py b/lmdeploy/pytorch/strategies/dllm/sampling.py
index 5a027e922d..d7c8bc4716 100644
--- a/lmdeploy/pytorch/strategies/dllm/sampling.py
+++ b/lmdeploy/pytorch/strategies/dllm/sampling.py
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from typing import List
 
+import numpy as np
 import torch
 from torch.profiler import record_function
 
@@ -42,15 +43,34 @@ def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs:
             'random_offsets',
             'all_ids',
             'num_ignore_eos',
+            'ngram_size',
+            'ngram_threshold',
         ]
         for name in update_attr_names:
             attr = getattr(out, name)
             if attr is None:
                 continue
-            repeats = (dllm_block_length, ) + (1, ) * (attr.dim())
-            attr = attr[None].repeat(*repeats).flatten(0, 1)
+            if attr.dim() == 1:
+                repeats = (dllm_block_length, 1)
+                attr = attr[None].repeat(*repeats).flatten(0, 1)
+            elif attr.dim() == 2:
+                repeats = (1, dllm_block_length, 1)
+                attr = attr[:, None].repeat(*repeats).flatten(0, 1)
+            else:
+                repeats = (dllm_block_length, ) + (1, ) * (attr.dim())
+                attr = attr[None].repeat(*repeats).flatten(0, 1)
             setattr(out, name, attr)
 
+        # update generated_ids_cpu
+        if out.generated_ids_cpu is not None:
+            generated_ids_cpu = out.generated_ids_cpu
+            if generated_ids_cpu.shape[1] == 0:
+                out.generated_ids_cpu = np.repeat(generated_ids_cpu, dllm_block_length, axis=0)
+            else:
+                generated_ids_cpu = np.repeat(generated_ids_cpu[:, None], dllm_block_length, axis=1)
+                generated_ids_cpu = np.reshape(generated_ids_cpu, (-1, generated_ids_cpu.shape[-1]))
+                out.generated_ids_cpu = generated_ids_cpu
+
         if len(out.response_formats) > 0:
             new_resp_formats = []
             for resp in out.response_formats:
diff --git a/tests/pytorch/engine/test_logits_process.py b/tests/pytorch/engine/test_logits_process.py
index b901879be4..7a85a62d4f 100644
--- a/tests/pytorch/engine/test_logits_process.py
+++ b/tests/pytorch/engine/test_logits_process.py
@@ -124,3 +124,28 @@ def test_filter_minp_sorted():
 
     out = _filter_minp_sorted_(scores, min_p)
     torch.testing.assert_close(out, gt)
+
+
+def test_filter_ngram():
+    from lmdeploy.pytorch.engine.logits_process import _filter_ngram_
+
+    generated_ids = torch.tensor(
+        [[2, 3, 4, 1, 2, 3, 4, 2, 3, 4], [9, 8, 7, 3, 8, 7, 5, 9, 8, 7], [9, 8, 7, 3, 8, 7, 5, 9, 8, 7]],
+        dtype=torch.int64)
+    n = torch.tensor([3, 3, 2], dtype=torch.int64)
+    threshold = torch.tensor([3, 3, 3], dtype=torch.int64)
+
+    batch_size = generated_ids.size(0)
+    max_n = n.max().item()
+    same_n = n.eq(max_n).all().item()
+    vocab_size = 100
+
+    scores = torch.rand(batch_size, vocab_size)
+    stop_words = torch.randint(0, vocab_size, (batch_size, 3), dtype=torch.int64)
+    _filter_ngram_(scores, stop_words, generated_ids, n, threshold, max_n, same_n)
+
+    assert not scores[1].isinf().any().item()
+    assert scores[0].isinf().sum().item() == vocab_size - 1
+    assert scores[2].isinf().sum().item() == vocab_size - 1
+    assert scores[0, stop_words[0, 0]] == 0
+    assert scores[2, stop_words[2, 0]] == 0
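
Below is a standalone sketch (not part of the patch) of the conv1d trick that `ngram` in logits_process.py relies on, shown on a single toy sequence with made-up values. A window counts as a match only when both the cosine-style term and the squared-norm term equal 1, i.e. the window is identical to the trailing n-gram.

# illustrative sketch of the conv1d-based n-gram matching, assuming a single sequence
import torch

token_ids = torch.tensor([[2, 3, 4, 1, 2, 3, 4, 2, 3, 4]])  # trailing trigram (2, 3, 4) appears 3 times
max_n = 3

ids = token_ids.to(torch.float32) + 1                  # shift by 1 so a 0 token cannot zero the norm
target_norm = ids[:, -max_n:].norm(2, dim=-1, keepdim=True)
normed_ids = ids / target_norm                         # every token scaled by the target n-gram's norm
normed_target = normed_ids[:, -max_n:]

# cross-correlate each sliding window with the target n-gram (cosine-style term) ...
dot = torch.conv1d(normed_ids.unsqueeze(0), normed_target.unsqueeze(1))[0]
# ... and the squared ids with an all-ones kernel (window-norm term)
sq = torch.conv1d((normed_ids * normed_ids).unsqueeze(0), torch.ones_like(normed_target).unsqueeze(1))[0]

one = torch.ones(1)
matched = dot.isclose(one) & sq.isclose(one)
print(matched.int())   # windows starting at positions 0, 4 and 7 match
print(matched.sum())   # tensor(3): with threshold <= 3 this sequence would be flagged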
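A hedged usage sketch of the new knobs through the public API; the model path and prompt are placeholders, and only `ngram_size` and `ngram_threshold` come from this patch (both default to 0, which keeps the feature disabled). As read from the patch, once the trailing `ngram_size`-gram has appeared at least `ngram_threshold` times among the generated tokens, the logits are masked so only the first stop word can be sampled, forcing an early stop.

# hedged usage sketch, assuming the PyTorch engine backend
from lmdeploy import GenerationConfig, PytorchEngineConfig, pipeline

pipe = pipeline('Qwen/Qwen2.5-7B-Instruct',          # placeholder model path
                backend_config=PytorchEngineConfig())

gen_config = GenerationConfig(
    max_new_tokens=1024,
    ngram_size=8,       # length of the repeated window to detect
    ngram_threshold=4,  # number of occurrences that triggers the early stop
)
print(pipe('Write a short story.', gen_config=gen_config))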