Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions fastdeploy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ def __init__(
self.enable_logprob = False
self.max_logprobs = 20
self.logprobs_mode = "raw_logprobs"
self.enable_keep_sampling_mask = False
self.redundant_experts_num = 0
self.seed = 0
self.quantization = None
Expand Down
20 changes: 20 additions & 0 deletions fastdeploy/engine/args_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,14 @@ class EngineArgs:
Must be explicitly enabled via the `--enable-logprob` startup parameter to output logprob values.
"""

enable_keep_sampling_mask: bool = False
"""
When enabled, the server returns a sparse index list for each generated token, indicating
which vocabulary positions were retained after top_p/top_k sampling, and streams it to
the client. In MTP (multi-token prediction) scenarios this field is a List[List[int]],
where each inner list contains the retained vocabulary indices for a predicted token.
"""

max_logprobs: int = 20
"""
Maximum number of log probabilities to return when `enable_logprob` is True. The default value comes from the default for the
Expand Down Expand Up @@ -872,6 +880,18 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=EngineArgs.enable_logprob,
help="Enable output of token-level log probabilities.",
)
model_group.add_argument(
"--enable-keep-sampling-mask",
action="store_true",
default=EngineArgs.enable_keep_sampling_mask,
help=(
"Enable output of sampling mask as a sparse index list over the vocabulary. "
"For non-MTP decoding, this is a list[int] per token step indicating which "
"vocabulary indices were kept after top_p/top_k sampling. "
"For MTP decoding, this is a list[list[int]] per token step, where each inner "
"list corresponds to one MTP group."
),
)
model_group.add_argument(
"--max-logprobs",
type=int,
Expand Down
1 change: 1 addition & 0 deletions fastdeploy/engine/common_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2371,6 +2371,7 @@ def _start_worker_service(self):
"lm_head_fp32": self.cfg.model_config.lm_head_fp32,
"enable_entropy": self.cfg.model_config.enable_entropy,
"enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule,
"enable_keep_sampling_mask": self.cfg.model_config.enable_keep_sampling_mask,
}
for worker_flag, value in worker_store_true_flag.items():
if value:
Expand Down
1 change: 1 addition & 0 deletions fastdeploy/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,7 @@ def _start_worker_service(self):
"shutdown_comm_group_if_worker_idle": self.cfg.parallel_config.shutdown_comm_group_if_worker_idle,
"enable_entropy": self.cfg.model_config.enable_entropy,
"enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule,
"enable_keep_sampling_mask": self.cfg.model_config.enable_keep_sampling_mask,
}
for worker_flag, value in worker_store_true_flag.items():
if value:
Expand Down
5 changes: 5 additions & 0 deletions fastdeploy/engine/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,10 @@ class CompletionOutput:
multipart: Optional[list[Any]] = None
num_image_tokens: Optional[int] = None
enable_parser: bool = False
# Sparse indices of retained vocab ids:
# - Non-MTP: list[int]
# - MTP: list[list[int]]
sampling_mask: Optional[Any] = None

def to_dict(self):
"""
Expand All @@ -745,6 +749,7 @@ def to_dict(self):
"text": self.text,
"reasoning_content": self.reasoning_content,
"reasoning_token_num": self.reasoning_token_num,
"sampling_mask": self.sampling_mask,
}

@classmethod
Expand Down
5 changes: 5 additions & 0 deletions fastdeploy/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,8 @@ class ChatCompletionResponseChoice(BaseModel):
logprobs: Optional[LogProbs] = None
draft_logprobs: Optional[LogProbs] = None
prompt_logprobs: Optional[PromptLogprobs] = None
# Per-token retained vocab indices from top_p/top_k sampling: List[List[int]], one list of vocab indices per token
sampling_mask: Optional[List[List[int]]] = None
finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop", "abort"]]
speculate_metrics: Optional[SpeculateMetrics] = None

Expand Down Expand Up @@ -333,6 +335,9 @@ class ChatCompletionResponseStreamChoice(BaseModel):
logprobs: Optional[LogProbs] = None
draft_logprobs: Optional[LogProbs] = None
prompt_logprobs: Optional[PromptLogprobs] = None
# Per-token index list of retained positions after top_p sampling.
# Non-MTP: [[idx, ...]] (1 token/step). MTP: [[idx, ...], ...] (N accepted tokens/step).
sampling_mask: Optional[List[List[int]]] = None
finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop", "abort"]] = None
arrival_time: Optional[float] = None
speculate_metrics: Optional[SpeculateMetrics] = None
Expand Down
31 changes: 31 additions & 0 deletions fastdeploy/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,11 @@ async def chat_completion_stream_generator(
delta=delta_message,
logprobs=logprobs_res,
draft_logprobs=draft_logprobs_res,
sampling_mask=(
self._make_sampling_mask_list(output["sampling_mask"])
if output.get("sampling_mask") is not None
else None
),
Comment on lines 432 to +439
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里新增了 sampling_mask 字段并透传到 OpenAI 响应(stream/full 两条路径),但现有 tests/entrypoints/openai/ 下对响应结构有较多覆盖,当前没有看到对 sampling_mask 的断言用例。建议补充单测:1)stream 场景每个 delta 的 sampling_mask 形状(Non-MTP: [[...]];MTP/Spec: [[...], ...]);2)non-stream 场景最终 choice.sampling_mask 的扁平化结果与 token 数对齐。

Copilot generated this review using guidance from repository custom instructions.
arrival_time=arrival_time,
speculate_metrics=output_speculate_metrics,
)
Expand Down Expand Up @@ -577,6 +582,7 @@ async def chat_completion_full_generator(
decoder_base_url=self.tokenizer_base_url,
)
prompt_logprobs_res_list = [[] for _ in range(num_choices)]
sampling_mask_list = [[] for _ in range(num_choices)]
speculate_metrics = [None for _ in range(num_choices)]
choices = []
while num_choices > 0:
Expand Down Expand Up @@ -656,6 +662,9 @@ async def chat_completion_full_generator(
)
if prompt_logprobs_res:
prompt_logprobs_res_list[idx].extend(clamp_prompt_logprobs(prompt_logprobs_res))
output_sampling_mask = output.get("sampling_mask", None)
if output_sampling_mask is not None:
sampling_mask_list[idx].append(self._make_sampling_mask_list(output_sampling_mask))
speculate_metrics[idx] = data["metrics"].get("speculate_metrics", None)
if data["finished"]:
trace_carrier = data.get("trace_carrier")
Expand Down Expand Up @@ -691,6 +700,7 @@ async def chat_completion_full_generator(
draft_logprob_contents=draft_logprob_contents,
response_processor=response_processor,
prompt_logprobs_res_list=prompt_logprobs_res_list,
sampling_mask_list=sampling_mask_list,
max_tokens=max_tokens,
speculate_metrics=speculate_metrics[idx],
)
Expand Down Expand Up @@ -745,6 +755,7 @@ async def _create_chat_completion_choice(
logprob_contents: list,
draft_logprob_contents: list,
prompt_logprobs_res_list: list,
sampling_mask_list: list,
response_processor: ChatResponseProcessor,
max_tokens: int,
speculate_metrics: SpeculateMetrics | None,
Expand Down Expand Up @@ -782,6 +793,10 @@ async def _create_chat_completion_choice(
draft_logprobs_full_res = LogProbs(content=draft_logprob_contents[idx])
if prompt_logprobs_res_list[idx]:
prompt_logprobs_full_res = prompt_logprobs_res_list[idx]
# Flatten per-step List[List[int]] into a single List[List[int]] over all tokens.
sampling_mask_full_res = None
if sampling_mask_list and sampling_mask_list[idx]:
sampling_mask_full_res = [mask for step in sampling_mask_list[idx] for mask in step]

num_cached_tokens[idx] = data.get("num_cached_tokens", 0)
num_input_image_tokens[idx] = data.get("num_input_image_tokens", 0)
Expand All @@ -806,6 +821,7 @@ async def _create_chat_completion_choice(
logprobs=logprobs_full_res,
draft_logprobs=draft_logprobs_full_res,
prompt_logprobs=prompt_logprobs_full_res,
sampling_mask=sampling_mask_full_res,
finish_reason=finish_reason,
speculate_metrics=speculate_metrics,
)
Expand Down Expand Up @@ -989,3 +1005,18 @@ def _make_logprob_dict(
)
for token_id, logprob, rank, token in zip(logprob_token_ids, logprobs, ranks, decoded_tokens)
}

@staticmethod
def _make_sampling_mask_list(sampling_mask) -> List[List[int]]:
"""Wrap sampling_mask into a uniform List[List[int]] format.

sampling_mask is already in sparse-index form (no bool-to-index conversion needed):
Non-MTP: List[int] (indices for 1 token/step) → [[idx, ...]]
MTP: List[List[int]] (indices for N tokens/step) → [[idx, ...], ...]
"""
assert sampling_mask is not None
if sampling_mask and isinstance(sampling_mask[0], list):
# MTP: already List[List[int]], return as-is
return sampling_mask
# Non-MTP: already List[int], wrap in outer list for uniform format
return [sampling_mask]
2 changes: 2 additions & 0 deletions fastdeploy/model_executor/layers/sample/meta_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,5 @@ class SamplingMetadata:
# Add for HPU post-processing
seq_lens_encoder: Optional[paddle.Tensor] = None
seq_lens_decoder: Optional[paddle.Tensor] = None
# Add for keep sampling mask
keep_sampling_mask: Optional[bool] = None
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ def top_k_top_p_sampling(

_, ids = native_top_p_sampling(x, top_p)
else:
if top_k_list and any(x > 0 for x in top_k_list):
from fastdeploy.model_executor.ops.gpu import top_k_renorm_probs

x = top_k_renorm_probs(x, top_k)

Comment on lines +88 to +92
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里在默认(base)top_p_sampling 分支中新增了 top_k_renorm_probs 逻辑,意味着 top_k 不再仅限于 rejection class。当前函数顶部 docstring 仍写“top_k Only used when FD_SAMPLING_CLASS is rejection”,容易误导使用者/后续维护。建议同步更新该说明,并明确 base 分支下 top_k 的生效条件与行为(renorm + 再做 top_p)。

Copilot uses AI. Check for mistakes.
_, ids = paddle.tensor.top_p_sampling(
x,
top_p,
Expand Down
Loading
Loading