-
Notifications
You must be signed in to change notification settings - Fork 649
support ngram logits processor #4288
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,8 +1,10 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Copyright (c) OpenMMLab. All rights reserved. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import asyncio | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from dataclasses import dataclass, fields | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Any, Dict, List, Optional, Tuple | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from functools import lru_cache | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Any, Dict, List, Tuple | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import numpy as np | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from lmdeploy.messages import LogitsProcessor | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -29,7 +31,7 @@ def _process_bad_words_(scores: torch.Tensor, | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return scores | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _process_repetition_penalty_(scores: torch.Tensor, input_ids: torch.LongTensor, penalty: torch.Tensor): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _process_repetition_penalty_(scores: torch.Tensor, input_ids: torch.Tensor, penalty: torch.Tensor): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Process repetition penalty.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| score = torch.gather(scores, 1, input_ids) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| penalty = penalty.to(score.dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -68,6 +70,116 @@ def _filter_minp_sorted_(scores: torch.Tensor, minp: torch.Tensor, filter_value: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return scores | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @lru_cache(maxsize=1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _ngram_one(dtype: torch.dtype, device: torch.device): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return torch.ones(1, dtype=dtype, device=device) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def ngram(token_ids: torch.Tensor, n: torch.Tensor, threshold: torch.Tensor, max_n: int, same_n: bool = False): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Compute n-gram matches between sliding windows and a target sequence. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| For each batch, performs cosine similarity checking between: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - All sliding windows of length `max_n` from the full sequence | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - The last `max_n` tokens of the sequence (target window) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| A match is counted when both: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 1. Cosine similarity ≈ 1 (normalized vectors match) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 2. Vector lengths match (preventing zero/normalization artifacts) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Parameters | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ---------- | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| token_ids : torch.Tensor | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Input token IDs of shape (batch_size, seq_len). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Values are typically ≥0 (0 may represent padding/special tokens). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| n : torch.Tensor | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Effective n-gram length for each batch element, shape (batch_size,). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| When `same_n=False`, positions beyond `n` in the last `max_n` tokens are masked. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| threshold : torch.Tensor | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Minimum number of matching windows required for validity, shape (batch_size,). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| max_n : int | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Maximum n-gram length (window size for matching). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| same_n : bool, default False | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| If True, use full `max_n`-length windows regardless of `n`. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| If False, mask positions where index < (max_n - n) in the target window. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Returns | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ------- | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| matched_mask : torch.Tensor | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Boolean mask of shape (batch_size, seq_len - max_n + 1) indicating | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| which sliding windows match the target n-gram. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| found : torch.Tensor | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Boolean tensor of shape (batch_size,) indicating whether each batch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| element has at least `threshold` matches. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| batch_size, seq_len = token_ids.size() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if seq_len < max_n: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Not enough tokens to form a single n-gram | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| matched_mask = torch.zeros((batch_size, 0), dtype=torch.bool, device=token_ids.device) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| found = torch.zeros((batch_size, ), dtype=torch.bool, device=token_ids.device) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return matched_mask, found | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # token_ids could be 0, so we add 1 to avoid div 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| token_ids = token_ids.to(torch.float32) + 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # normalize ids | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| norm = token_ids[:, -max_n:] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not same_n: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # fill 0 for n < max_n | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mask = torch.arange(max_n, device=token_ids.device).unsqueeze(0) >= (max_n - n.unsqueeze(1)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| norm = norm * mask.to(torch.float32) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| norm = norm.norm(2, dim=-1, keepdim=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| normed_ids = token_ids / norm | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # concate p1 and p2 so we can check distance and vector in one conv1d | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| normed_n_ids = normed_ids[:, -max_n:] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| normed_ids_p2 = normed_ids * normed_ids | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ones_ids = torch.ones_like(normed_n_ids) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not same_n: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # fill 0 for n < max_n | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| normed_n_ids = normed_n_ids * mask.to(torch.float32) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ones_ids = ones_ids * mask.to(torch.float32) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| normed_ids = torch.cat([normed_ids, normed_ids_p2], dim=0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| normed_n_ids = torch.cat([normed_n_ids, ones_ids], dim=0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # check cos distance & check vector length | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| match_norm = torch.conv1d(normed_ids.unsqueeze(0), normed_n_ids.unsqueeze(1), groups=batch_size * 2)[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| match_norm, match_ones = match_norm.chunk(2, dim=0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # both match result should be close to 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| one_tensor = _ngram_one(dtype=match_norm.dtype, device=match_norm.device) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| matched_mask = match_norm.isclose(one_tensor) & match_ones.isclose(one_tensor) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # threshold | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| count = matched_mask.sum(-1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| found = (count >= threshold) & (threshold > 0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return matched_mask, found | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _filter_ngram_( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| scores: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| stop_words: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| generated_ids: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| n: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| threshold: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| max_n: int, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| same_n: bool = False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Filter ngram.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if stop_words is None or stop_words.numel() == 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return scores | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # use first stop words | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _, found = ngram(generated_ids, n, threshold, max_n, same_n) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| stop_words = stop_words[:, 0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # fill all scores -inf | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| scores.masked_fill_(found[:, None], -float('inf')) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # set stop words to 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| stop_scores = scores.gather(1, stop_words[:, None]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| stop_scores.masked_fill_(found[:, None], 0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| scores.scatter_(1, stop_words[:, None], stop_scores) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+167
to
+179
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | |
| """Filter ngram.""" | |
| if stop_words is None or stop_words.numel() == 0: | |
| return scores | |
| # use first stop words | |
| _, found = ngram(generated_ids, n, threshold, max_n, same_n) | |
| stop_words = stop_words[:, 0] | |
| # fill all scores -inf | |
| scores.masked_fill_(found[:, None], -float('inf')) | |
| # set stop words to 0 | |
| stop_scores = scores.gather(1, stop_words[:, None]) | |
| stop_scores.masked_fill_(found[:, None], 0) | |
| scores.scatter_(1, stop_words[:, None], stop_scores) | |
| stop_mask: torch.Tensor | None = None, | |
| ): | |
| """Filter ngram.""" | |
| if stop_words is None or stop_words.numel() == 0: | |
| return scores | |
| # determine which sequences have ngram matches | |
| _, found = ngram(generated_ids, n, threshold, max_n, same_n) | |
| batch_size = scores.size(0) | |
| if stop_mask is not None: | |
| # has_stop indicates which batch elements have at least one valid stop word | |
| if stop_mask.numel() == 0: | |
| return scores | |
| has_stop = stop_mask.any(dim=1) | |
| if not has_stop.any(): | |
| # no sequences have valid stop words, nothing to do | |
| return scores | |
| # compute index of first valid stop word per batch element | |
| num_stops = stop_mask.size(1) | |
| idxs = torch.arange( | |
| num_stops, device=stop_mask.device, dtype=torch.long | |
| ).unsqueeze(0).expand_as(stop_mask) | |
| idxs = idxs.masked_fill(~stop_mask, num_stops) | |
| first_idxs = idxs.argmin(dim=1) | |
| batch_indices = torch.arange(batch_size, device=stop_words.device, dtype=torch.long) | |
| stop_tokens = stop_words[batch_indices, first_idxs] | |
| # only apply forcing where both an ngram is found and a valid stop word exists | |
| valid_found = found & has_stop | |
| if not valid_found.any(): | |
| return scores | |
| scores.masked_fill_(valid_found[:, None], -float('inf')) | |
| stop_tokens_exp = stop_tokens[:, None] | |
| stop_scores = scores.gather(1, stop_tokens_exp) | |
| stop_scores.masked_fill_(valid_found[:, None], 0) | |
| scores.scatter_(1, stop_tokens_exp, stop_scores) | |
| else: | |
| # fallback: use the first stop word in each row, as originally implemented | |
| stop_tokens = stop_words[:, 0] | |
| # fill all scores -inf where an ngram is found | |
| scores.masked_fill_(found[:, None], -float('inf')) | |
| # set stop word scores to 0 | |
| stop_tokens_exp = stop_tokens[:, None] | |
| stop_scores = scores.gather(1, stop_tokens_exp) | |
| stop_scores.masked_fill_(found[:, None], 0) | |
| scores.scatter_(1, stop_tokens_exp, stop_scores) |
Copilot
AI
Jan 23, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SamplingInputs.to_device() mutates self by setting self.generated_ids before building the output instance, and it forces a full .copy() of generated_ids_cpu. This introduces side effects and an extra memory copy on every device transfer. Prefer keeping to_device pure (compute a local tensor from generated_ids_cpu and put it into out_dict) and avoid the unconditional copy unless it’s required for contiguity.
| if self.generated_ids_cpu is not None: | |
| self.generated_ids = torch.from_numpy(self.generated_ids_cpu.copy()) | |
| for f in fields(self): | |
| k = f.name | |
| v = getattr(self, k) | |
| # Build a local `generated_ids` tensor from `generated_ids_cpu` without | |
| # mutating `self` and without forcing an unnecessary copy. | |
| if self.generated_ids_cpu is not None: | |
| gen_np = self.generated_ids_cpu | |
| # Ensure the array is suitable for torch.from_numpy without an | |
| # extra copy unless required for contiguity or writability. | |
| if (not gen_np.flags['C_CONTIGUOUS']) or (not gen_np.flags['WRITEABLE']): | |
| gen_np = np.ascontiguousarray(gen_np) | |
| generated_ids = torch.from_numpy(gen_np) | |
| else: | |
| generated_ids = self.generated_ids | |
| for f in fields(self): | |
| k = f.name | |
| if k == 'generated_ids': | |
| v = generated_ids | |
| else: | |
| v = getattr(self, k) |
Copilot
AI
Jan 23, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Repetition penalty is now applied to sampling_inputs.generated_ids (new tokens only) instead of the full history (all_ids/valid_ids). This changes behavior compared to standard repetition penalty (which usually includes the prompt/history) and compared to the previous implementation in this file. If the intent is to preserve previous semantics, keep using the full token history; otherwise, this behavior change should be explicitly validated and covered by an integration test.
| generated_ids = sampling_inputs.generated_ids | |
| scores = _process_repetition_penalty_(scores, generated_ids, repetition_penalty) | |
| scores = _process_repetition_penalty_(scores, all_ids, repetition_penalty) |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,6 +1,7 @@ | ||||||||||||||||||||||||||||||||||||||||||
| # Copyright (c) OpenMMLab. All rights reserved. | ||||||||||||||||||||||||||||||||||||||||||
| from typing import List | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| import numpy as np | ||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||
| from torch.profiler import record_function | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -15,7 +16,7 @@ | |||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def _gather_all_ids(pad_id: int, seqs: SeqList, sampling_inputs: SamplingInputs): | ||||||||||||||||||||||||||||||||||||||||||
| """Gather history.""" | ||||||||||||||||||||||||||||||||||||||||||
| if sampling_inputs.repetition_penalty is None and not any(sampling_inputs.logits_processors): | ||||||||||||||||||||||||||||||||||||||||||
| if not any(sampling_inputs.logits_processors): | ||||||||||||||||||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||||||||||||||||||
| batch = len(seqs) | ||||||||||||||||||||||||||||||||||||||||||
| max_len = max(seq.num_valid_ids for seq in seqs) | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -29,6 +30,22 @@ def _gather_all_ids(pad_id: int, seqs: SeqList, sampling_inputs: SamplingInputs) | |||||||||||||||||||||||||||||||||||||||||
| return output | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def _gather_generated_ids(pad_id: int, seqs: SeqList, sampling_inputs: SamplingInputs) -> np.ndarray | None: | ||||||||||||||||||||||||||||||||||||||||||
| """Gather history.""" | ||||||||||||||||||||||||||||||||||||||||||
| if sampling_inputs.repetition_penalty is None and sampling_inputs.max_ngram_size == 0: | ||||||||||||||||||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||||||||||||||||||
| batch = len(seqs) | ||||||||||||||||||||||||||||||||||||||||||
| max_len = max(seq.num_new_tokens for seq in seqs) | ||||||||||||||||||||||||||||||||||||||||||
| output = np.full((batch, max_len), pad_id, dtype=np.int64) | ||||||||||||||||||||||||||||||||||||||||||
| for idx, seq in enumerate(seqs): | ||||||||||||||||||||||||||||||||||||||||||
| h_len = seq.num_new_tokens | ||||||||||||||||||||||||||||||||||||||||||
| if h_len == 0: | ||||||||||||||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||||||||||||||
| h_ids = seq.generated_ids | ||||||||||||||||||||||||||||||||||||||||||
| output[idx, -h_len:] = h_ids | ||||||||||||||||||||||||||||||||||||||||||
| return output | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def _get_num_ignore_eos(seqs: SeqList): | ||||||||||||||||||||||||||||||||||||||||||
| """Get num ignore eos.""" | ||||||||||||||||||||||||||||||||||||||||||
| ret = [seq.sampling_param.min_new_tokens - seq.num_new_tokens for seq in seqs] | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -61,6 +78,8 @@ def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs: | |||||||||||||||||||||||||||||||||||||||||
| num_logprobs = [None] * batch_size | ||||||||||||||||||||||||||||||||||||||||||
| session_to_cleanup = self.session_to_cleanup | ||||||||||||||||||||||||||||||||||||||||||
| self.session_to_cleanup = [] | ||||||||||||||||||||||||||||||||||||||||||
| ngram_sizes = [None] * batch_size | ||||||||||||||||||||||||||||||||||||||||||
| ngram_thresholds = [None] * batch_size | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def __gather_params(): | ||||||||||||||||||||||||||||||||||||||||||
| """Gather params.""" | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -84,6 +103,8 @@ def __gather_params(): | |||||||||||||||||||||||||||||||||||||||||
| stop_words[idx] = sw | ||||||||||||||||||||||||||||||||||||||||||
| logits_processors[idx] = param.logits_processors | ||||||||||||||||||||||||||||||||||||||||||
| num_logprobs[idx] = param.num_logprobs | ||||||||||||||||||||||||||||||||||||||||||
| ngram_sizes[idx] = param.ngram_size | ||||||||||||||||||||||||||||||||||||||||||
| ngram_thresholds[idx] = param.ngram_threshold | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def __get_topp(top_p): | ||||||||||||||||||||||||||||||||||||||||||
| """Get topp.""" | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -165,6 +186,19 @@ def __get_bad_words(bad_words): | |||||||||||||||||||||||||||||||||||||||||
| 'seq_id': seq.seq_id, | ||||||||||||||||||||||||||||||||||||||||||
| } for seq in seqs] | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # ngram | ||||||||||||||||||||||||||||||||||||||||||
| max_ngram_size = max(ngram_sizes) | ||||||||||||||||||||||||||||||||||||||||||
| if max_ngram_size == 0: | ||||||||||||||||||||||||||||||||||||||||||
| ngram_sizes = None | ||||||||||||||||||||||||||||||||||||||||||
| ngram_thresholds = None | ||||||||||||||||||||||||||||||||||||||||||
| ngram_same_n = True | ||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||
| ngram_sizes = torch.tensor(ngram_sizes) | ||||||||||||||||||||||||||||||||||||||||||
| ngram_thresholds = torch.tensor(ngram_thresholds) | ||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+196
to
+197
|
||||||||||||||||||||||||||||||||||||||||||
| ngram_sizes = torch.tensor(ngram_sizes) | |
| ngram_thresholds = torch.tensor(ngram_thresholds) | |
| # Normalize per-request ngram settings so that disabled requests | |
| # do not have n == 0 in the tensors. For any row where ngram_size | |
| # is non-positive or ngram_threshold is non-positive, we treat the | |
| # request as "disabled" for n-gram matching by assigning | |
| # ngram_size = max_ngram_size and ngram_threshold = 0. This keeps | |
| # shapes and masks valid while ensuring those rows do not | |
| # participate in blocking logic. | |
| normalized_ngram_sizes: list[int] = [] | |
| normalized_ngram_thresholds: list[float] = [] | |
| for size, threshold in zip(ngram_sizes, ngram_thresholds): | |
| if size <= 0 or threshold <= 0: | |
| normalized_ngram_sizes.append(max_ngram_size) | |
| normalized_ngram_thresholds.append(0.0) | |
| else: | |
| normalized_ngram_sizes.append(size) | |
| normalized_ngram_thresholds.append(threshold) | |
| ngram_sizes = torch.tensor(normalized_ngram_sizes) | |
| ngram_thresholds = torch.tensor(normalized_ngram_thresholds) |
Uh oh!
There was an error while loading. Please reload this page.