-
Notifications
You must be signed in to change notification settings - Fork 649
support ngram logits processor #4288
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
Adds support for an n-gram-based logits processor (intended to force generation of a stop token once repeated n-grams exceed a threshold), wiring new ngram_size / ngram_threshold parameters through sampling inputs and adding a unit test.
Changes:
- Add n-gram matching +
_filter_ngram_intoFusedLogitsProcessor. - Plumb
ngram_size/ngram_thresholdthroughGenerationConfig→SamplingParam→SamplingInputs, including new generated-token history gathering. - Add a unit test for
_filter_ngram_.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 7 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/pytorch/engine/test_logits_process.py | Adds coverage for the new n-gram filtering behavior. |
| lmdeploy/pytorch/engine/logits_process.py | Implements n-gram matching/filtering and integrates it into fused logits processing. |
| lmdeploy/pytorch/strategies/ar/sampling.py | Gathers per-request n-gram params and generated-token history for GPU-side processing. |
| lmdeploy/pytorch/strategies/dllm/sampling.py | Repeats new sampling attributes across DLLM blocks and expands generated-id history. |
| lmdeploy/pytorch/messages.py | Adds ngram_size / ngram_threshold to SamplingParam to carry runtime settings. |
| lmdeploy/messages.py | Adds ngram_size / ngram_threshold to user-facing GenerationConfig. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| ngram_sizes = torch.tensor(ngram_sizes) | ||
| ngram_thresholds = torch.tensor(ngram_thresholds) |
Copilot
AI
Jan 23, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
max_ngram_size = max(ngram_sizes) enables n-gram logic for the whole batch if any request sets ngram_size > 0. In that case, requests with ngram_size == 0 (or ngram_threshold == 0) are still included in the tensors, which can lead to invalid masking (e.g., all-false masks when n==0) and unnecessary computation. Consider normalizing per-request settings (e.g., set ngram_threshold=0 and ngram_size=max_ngram_size for disabled rows, or filter them out) so that “disabled” requests don’t participate in n-gram matching.
| ngram_sizes = torch.tensor(ngram_sizes) | |
| ngram_thresholds = torch.tensor(ngram_thresholds) | |
| # Normalize per-request ngram settings so that disabled requests | |
| # do not have n == 0 in the tensors. For any row where ngram_size | |
| # is non-positive or ngram_threshold is non-positive, we treat the | |
| # request as "disabled" for n-gram matching by assigning | |
| # ngram_size = max_ngram_size and ngram_threshold = 0. This keeps | |
| # shapes and masks valid while ensuring those rows do not | |
| # participate in blocking logic. | |
| normalized_ngram_sizes: list[int] = [] | |
| normalized_ngram_thresholds: list[float] = [] | |
| for size, threshold in zip(ngram_sizes, ngram_thresholds): | |
| if size <= 0 or threshold <= 0: | |
| normalized_ngram_sizes.append(max_ngram_size) | |
| normalized_ngram_thresholds.append(0.0) | |
| else: | |
| normalized_ngram_sizes.append(size) | |
| normalized_ngram_thresholds.append(threshold) | |
| ngram_sizes = torch.tensor(normalized_ngram_sizes) | |
| ngram_thresholds = torch.tensor(normalized_ngram_thresholds) |
| def test_filter_ngram(): | ||
| from lmdeploy.pytorch.engine.logits_process import _filter_ngram_ | ||
|
|
Copilot
AI
Jan 23, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PR description appears to still be the default template (Motivation/Modification/etc. are not filled in). Please add a brief motivation and summarize the intended behavior of the new n-gram logits processor (including how ngram_size/ngram_threshold interact with stop_words) to make review and downstream usage clearer.
| repeats = (dllm_block_length, 1) | ||
| attr = attr[None].repeat(*repeats).flatten(0, 1) | ||
| elif attr.dim() == 2: | ||
| repeats = (1, dllm_block_length, 1) | ||
| attr = attr[:, None].repeat(*repeats).flatten(0, 1) | ||
| else: | ||
| repeats = (dllm_block_length, ) + (1, ) * (attr.dim()) | ||
| attr = attr[None].repeat(*repeats).flatten(0, 1) |
Copilot
AI
Jan 23, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In DLLM mode the dim==1 repeat path produces an interleaved layout (e.g., [a,b,a,b,...]) because it repeats on a new leading dimension then flattens. Downstream code treats the expanded batch as grouped by dllm_block_length (e.g., uses .view(-1, dllm_block_length)), which requires each original element to be repeated contiguously ([a,a,a,b,b,b,...]). Consider repeating along a new second dimension (e.g., attr[:, None].repeat(1, dllm_block_length).flatten(0,1)) to keep the expanded batch ordering consistent with the rest of DLLM handling (e.g., response_formats).
| repeats = (dllm_block_length, 1) | |
| attr = attr[None].repeat(*repeats).flatten(0, 1) | |
| elif attr.dim() == 2: | |
| repeats = (1, dllm_block_length, 1) | |
| attr = attr[:, None].repeat(*repeats).flatten(0, 1) | |
| else: | |
| repeats = (dllm_block_length, ) + (1, ) * (attr.dim()) | |
| attr = attr[None].repeat(*repeats).flatten(0, 1) | |
| repeats = (1, dllm_block_length) | |
| attr = attr[:, None].repeat(*repeats).flatten(0, 1) | |
| elif attr.dim() == 2: | |
| repeats = (1, dllm_block_length, 1) | |
| attr = attr[:, None].repeat(*repeats).flatten(0, 1) | |
| else: | |
| repeats = (1, dllm_block_length) + (1, ) * (attr.dim() - 1) | |
| attr = attr[:, None].repeat(*repeats).flatten(0, 1) |
| ): | ||
| """Filter ngram.""" | ||
| if stop_words is None or stop_words.numel() == 0: | ||
| return scores | ||
| # use first stop words | ||
| _, found = ngram(generated_ids, n, threshold, max_n, same_n) | ||
| stop_words = stop_words[:, 0] | ||
| # fill all scores -inf | ||
| scores.masked_fill_(found[:, None], -float('inf')) | ||
| # set stop words to 0 | ||
| stop_scores = scores.gather(1, stop_words[:, None]) | ||
| stop_scores.masked_fill_(found[:, None], 0) | ||
| scores.scatter_(1, stop_words[:, None], stop_scores) |
Copilot
AI
Jan 23, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_filter_ngram_ always uses stop_words[:, 0] without considering stop_mask. When some sequences have no stop words, __get_bad_words pads their row with 0s and a false mask; this function will still treat token 0 as a stop word and can force generation of token 0 when found is true. Pass/use stop_mask (select the first valid stop word per batch, and skip batches with none) or precompute a per-sequence stop token id when enabling n-gram stopping.
| ): | |
| """Filter ngram.""" | |
| if stop_words is None or stop_words.numel() == 0: | |
| return scores | |
| # use first stop words | |
| _, found = ngram(generated_ids, n, threshold, max_n, same_n) | |
| stop_words = stop_words[:, 0] | |
| # fill all scores -inf | |
| scores.masked_fill_(found[:, None], -float('inf')) | |
| # set stop words to 0 | |
| stop_scores = scores.gather(1, stop_words[:, None]) | |
| stop_scores.masked_fill_(found[:, None], 0) | |
| scores.scatter_(1, stop_words[:, None], stop_scores) | |
| stop_mask: torch.Tensor | None = None, | |
| ): | |
| """Filter ngram.""" | |
| if stop_words is None or stop_words.numel() == 0: | |
| return scores | |
| # determine which sequences have ngram matches | |
| _, found = ngram(generated_ids, n, threshold, max_n, same_n) | |
| batch_size = scores.size(0) | |
| if stop_mask is not None: | |
| # has_stop indicates which batch elements have at least one valid stop word | |
| if stop_mask.numel() == 0: | |
| return scores | |
| has_stop = stop_mask.any(dim=1) | |
| if not has_stop.any(): | |
| # no sequences have valid stop words, nothing to do | |
| return scores | |
| # compute index of first valid stop word per batch element | |
| num_stops = stop_mask.size(1) | |
| idxs = torch.arange( | |
| num_stops, device=stop_mask.device, dtype=torch.long | |
| ).unsqueeze(0).expand_as(stop_mask) | |
| idxs = idxs.masked_fill(~stop_mask, num_stops) | |
| first_idxs = idxs.argmin(dim=1) | |
| batch_indices = torch.arange(batch_size, device=stop_words.device, dtype=torch.long) | |
| stop_tokens = stop_words[batch_indices, first_idxs] | |
| # only apply forcing where both an ngram is found and a valid stop word exists | |
| valid_found = found & has_stop | |
| if not valid_found.any(): | |
| return scores | |
| scores.masked_fill_(valid_found[:, None], -float('inf')) | |
| stop_tokens_exp = stop_tokens[:, None] | |
| stop_scores = scores.gather(1, stop_tokens_exp) | |
| stop_scores.masked_fill_(valid_found[:, None], 0) | |
| scores.scatter_(1, stop_tokens_exp, stop_scores) | |
| else: | |
| # fallback: use the first stop word in each row, as originally implemented | |
| stop_tokens = stop_words[:, 0] | |
| # fill all scores -inf where an ngram is found | |
| scores.masked_fill_(found[:, None], -float('inf')) | |
| # set stop word scores to 0 | |
| stop_tokens_exp = stop_tokens[:, None] | |
| stop_scores = scores.gather(1, stop_tokens_exp) | |
| stop_scores.masked_fill_(found[:, None], 0) | |
| scores.scatter_(1, stop_tokens_exp, stop_scores) |
| generated_ids = sampling_inputs.generated_ids | ||
| scores = _process_repetition_penalty_(scores, generated_ids, repetition_penalty) |
Copilot
AI
Jan 23, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Repetition penalty is now applied to sampling_inputs.generated_ids (new tokens only) instead of the full history (all_ids/valid_ids). This changes behavior compared to standard repetition penalty (which usually includes the prompt/history) and compared to the previous implementation in this file. If the intent is to preserve previous semantics, keep using the full token history; otherwise, this behavior change should be explicitly validated and covered by an integration test.
| generated_ids = sampling_inputs.generated_ids | |
| scores = _process_repetition_penalty_(scores, generated_ids, repetition_penalty) | |
| scores = _process_repetition_penalty_(scores, all_ids, repetition_penalty) |
| if self.generated_ids_cpu is not None: | ||
| self.generated_ids = torch.from_numpy(self.generated_ids_cpu.copy()) | ||
| for f in fields(self): | ||
| k = f.name | ||
| v = getattr(self, k) |
Copilot
AI
Jan 23, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SamplingInputs.to_device() mutates self by setting self.generated_ids before building the output instance, and it forces a full .copy() of generated_ids_cpu. This introduces side effects and an extra memory copy on every device transfer. Prefer keeping to_device pure (compute a local tensor from generated_ids_cpu and put it into out_dict) and avoid the unconditional copy unless it’s required for contiguity.
| if self.generated_ids_cpu is not None: | |
| self.generated_ids = torch.from_numpy(self.generated_ids_cpu.copy()) | |
| for f in fields(self): | |
| k = f.name | |
| v = getattr(self, k) | |
| # Build a local `generated_ids` tensor from `generated_ids_cpu` without | |
| # mutating `self` and without forcing an unnecessary copy. | |
| if self.generated_ids_cpu is not None: | |
| gen_np = self.generated_ids_cpu | |
| # Ensure the array is suitable for torch.from_numpy without an | |
| # extra copy unless required for contiguity or writability. | |
| if (not gen_np.flags['C_CONTIGUOUS']) or (not gen_np.flags['WRITEABLE']): | |
| gen_np = np.ascontiguousarray(gen_np) | |
| generated_ids = torch.from_numpy(gen_np) | |
| else: | |
| generated_ids = self.generated_ids | |
| for f in fields(self): | |
| k = f.name | |
| if k == 'generated_ids': | |
| v = generated_ids | |
| else: | |
| v = getattr(self, k) |
Thanks for your contribution and we appreciate it a lot. The following instructions would make your pull request more healthy and more easily receiving feedbacks. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers.
Motivation
Please describe the motivation of this PR and the goal you want to achieve through this PR.
Modification
Please briefly describe what modification is made in this PR.
BC-breaking (Optional)
Does the modification introduce changes that break the backward-compatibility of the downstream repositories?
If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.
Use cases (Optional)
If this PR introduces a new feature, it is better to list some use cases here, and update the documentation.
Checklist