[KSM] support keep sampling mask#7146
zeroRains wants to merge 2 commits into PaddlePaddle:release/2.5
Conversation
Thanks for your contribution!
Pull request overview
This PR adds a keep_sampling_mask capability to the sampling pipeline: after top_p/top_k sampling, it outputs the set of kept vocabulary indices for each token step (in sparse form) and passes it all the way through to the engine output and the OpenAI protocol response, so clients can inspect the sampling constraint / candidate set.
Changes:
- Added the startup flags --enable-keep-sampling-mask/--enable_keep_sampling_mask and enabled sampling_mask production along the engine→worker→sampler chain.
- Computed the sparse top_k+top_p sampling_mask (plus logZ) in the sampler, wrote it into SamplerOutput, and sent it back to TokenProcessor over a ZMQ side-channel (FD_USE_GET_SAVE_OUTPUT_V1=0).
- Added a sampling_mask field to the OpenAI serving/protocol and Engine output structures, returned in both stream and full responses.
Reviewed changes
Copilot reviewed 16 out of 16 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| fastdeploy/worker/worker_process.py | Add the enable_keep_sampling_mask CLI argument on the worker side |
| fastdeploy/config.py | Add the enable_keep_sampling_mask option to ModelConfig |
| fastdeploy/engine/args_utils.py | Add enable_keep_sampling_mask to the engine CLI/EngineArgs and pass it through to ModelConfig |
| fastdeploy/engine/engine.py | Pass the enable_keep_sampling_mask store_true flag through when launching workers |
| fastdeploy/engine/common_engine.py | Same as above (common_engine worker launch) |
| fastdeploy/model_executor/layers/sample/meta_data.py | Add the keep_sampling_mask switch field to SamplingMetadata |
| fastdeploy/model_executor/layers/sample/sampler.py | Add _compute_sampling_mask and compute the sparse mask + logZ before sampling |
| fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py | Add top_k renorm logic to the base branch's tensor top_p sampling |
| fastdeploy/worker/output.py | Add sampling_mask / logz_per_batch fields to SamplerOutput |
| fastdeploy/model_executor/pre_and_post_process.py | Pass sampling_mask through save_output_normal / post_process_specualate and send it over the ZMQ side-channel |
| fastdeploy/worker/gpu_model_runner.py | Initialize the sampling_mask ZMQ client in the GPU runner and pass the switch into SamplingMetadata and post_process |
| fastdeploy/output/token_processor.py | Add a sampling_mask ZMQ PULL server to TokenProcessor and fill sampling_mask into the batch output |
| fastdeploy/output/stream_transfer_data.py | Add the sampling_mask field to StreamTransferData (used by the V1 output path) |
| fastdeploy/engine/request.py | Add the sampling_mask field to CompletionOutput and include it in to_dict() |
| fastdeploy/entrypoints/openai/protocol.py | Add the sampling_mask field to OpenAI ChatCompletion stream/full choices |
| fastdeploy/entrypoints/openai/serving_chat.py | Pass sampling_mask through OpenAI chat stream/full, with shape unification and flattening |
```python
self.use_sampling_mask = getattr(self.cfg.model_config, "enable_keep_sampling_mask", False)
if not envs.FD_USE_GET_SAVE_OUTPUT_V1 and self.use_sampling_mask:
    rank_id = self.cfg.parallel_config.local_data_parallel_id
    port = self.cfg.parallel_config.engine_worker_queue_port[rank_id]
    self.sampling_mask_zmq_server = ZmqIpcServer(
        name=f"sampling_mask_output_rank_{rank_id}_{port}", mode=zmq.PULL
    )
    llm_logger.info(f"create zmq sampling_mask_output_rank_{rank_id}_{port}")
```
When enable_keep_sampling_mask is on, this creates a ZMQ PULL server for sampling_mask, and _process_batch_output() later blocks on receive_pyobj_once(block=True). Currently only the GPU worker path sends sampling_mask; enabling this switch on XPU/GCU/HPU/Metax platforms will block here forever and hang the service. Suggestions: 1) only enable it on supported platforms (e.g. current_platform.is_cuda()), or 2) switch to a non-blocking/timeout receive and skip sampling_mask for the step when nothing arrives.
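A minimal sketch of suggestion 1, as a standalone predicate (the function name and boolean parameters are illustrative, not the actual FastDeploy API):

```python
def should_create_sampling_mask_server(enable_keep_sampling_mask: bool,
                                       use_get_save_output_v1: bool,
                                       platform_is_cuda: bool) -> bool:
    """Create the PULL server only when the worker side is known to send masks."""
    if not enable_keep_sampling_mask:
        return False
    if use_get_save_output_v1:
        # The V1 output path carries sampling_mask in StreamTransferData,
        # not over the side-channel.
        return False
    # Only the GPU (CUDA) runner currently sends sampling_mask, so other
    # platforms must not create a server that would block forever.
    return platform_is_cuda
```

With a guard like this, XPU/GCU/HPU/Metax deployments never enter the blocking receive path at all.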
```python
cum_probs = paddle.cumsum(renorm_sorted_probs, axis=-1)  # [B, V]
topp_mask = (cum_probs - renorm_sorted_probs) <= top_p  # [B, V]
# When top_p[i] >= 1.0, keep the entire row.
topp_mask = paddle.where(
    (top_p >= 1.0).expand_as(topp_mask),
    paddle.ones_like(topp_mask),
    topp_mask,
)
```
When top_p >= 1.0 and top_k is disabled, final_mask becomes all-True for the row and max_k approaches vocab_size; the code then copies sorted_indices[:, :max_k] device-to-host wholesale and builds a huge Python list (every token returns the full vocabulary of indices). This incurs very significant CPU/memory/bandwidth overhead and can even bring the service down. Suggest either not returning sampling_mask at all (set it to None) in this "untruncated" case, or adding a cap (e.g. at most top_k_max indices) and enforcing it in the docs/argument validation.
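A NumPy sketch of the suggested cap, operating on a boolean keep-mask over sorted positions (function name and the `max_keep` parameter are hypothetical; the real code works on Paddle tensors):

```python
import numpy as np

def sparse_mask_rows(final_mask: np.ndarray, sorted_indices: np.ndarray,
                     max_keep: int = 1024):
    """Convert a [B, V] keep-mask over sorted positions into sparse index lists.

    Rows whose keep-set exceeds max_keep (e.g. top_p >= 1.0 with no top_k)
    return None instead of materialising nearly the whole vocabulary.
    """
    out = []
    for row_mask, row_idx in zip(final_mask, sorted_indices):
        k = int(row_mask.sum())
        if k > max_keep:
            out.append(None)  # untruncated row: skip the huge D2H copy
        else:
            out.append(row_idx[:k].tolist())
    return out
```

Because topp_mask is a prefix mask over sorted probabilities, taking the first k sorted indices recovers exactly the kept token ids.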
```python
delta=delta_message,
logprobs=logprobs_res,
draft_logprobs=draft_logprobs_res,
sampling_mask=(
    self._make_sampling_mask_list(output["sampling_mask"])
    if output.get("sampling_mask") is not None
    else None
),
```
This adds the sampling_mask field and passes it through to the OpenAI response (both stream and full paths), but the existing tests under tests/entrypoints/openai/ have broad coverage of the response structure and there is no assertion on sampling_mask yet. Suggest adding unit tests for: 1) the shape of sampling_mask in each stream delta (Non-MTP: [[...]]; MTP/Spec: [[...], ...]); 2) in the non-stream case, that the flattened choice.sampling_mask aligns with the token count.
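A sketch of what such an assertion could check, written as a standalone predicate (the helper name and dict-shaped delta are hypothetical; the real tests under tests/entrypoints/openai/ build responses through the serving layer):

```python
def check_stream_delta_mask(delta: dict, num_tokens: int) -> bool:
    """Non-MTP stream case: sampling_mask is a list with one index-list per token."""
    mask = delta.get("sampling_mask")
    if mask is None:
        return False
    return (
        len(mask) == num_tokens
        and all(isinstance(step, list) for step in mask)
        and all(isinstance(idx, int) for step in mask for idx in step)
    )
```

The non-stream check would be analogous: flatten the per-step lists of each output and assert the outer length equals the number of generated tokens in choice.sampling_mask.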
```python
"--enable-keep-sampling-mask",
action="store_true",
help=(
    "Enable output of keep_sampling_mask as sparse vocab index list per token step "
```
The help text uses the field name keep_sampling_mask, but the externally-facing protocol/response field is actually sampling_mask (see OpenAI protocol / CompletionOutput). Suggest aligning the help text with the external field name (e.g. uniformly call it sampling_mask) so users don't assume the response field is also named keep_sampling_mask.
```diff
-    "Enable output of keep_sampling_mask as sparse vocab index list per token step "
+    "Enable output of sampling_mask as sparse vocab index list per token step "
```
```python
k_per_row = final_mask.astype("int32").sum(axis=-1)  # [B]
max_k = int(k_per_row.max().item())

# ------------------------------------------------------------------
# Stage 5: compute logZ_K for renormalization
```
When real_bsz == 0 (e.g. total_accepted == 0 in the speculative case, or an empty batch in extreme cases), _compute_sampling_mask still calls k_per_row.max(), sorted_indices[:, :max_k], etc., which can error out directly or produce invalid slices. Suggest a fast return at the top of the function for real_bsz == 0 (empty list + empty array), skipping the subsequent max()/argsort logic.
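A NumPy sketch of the suggested guard (the function name and simplified signature are illustrative; the real method takes Paddle tensors and runs the Stage 1-5 pipeline after the guard):

```python
import numpy as np

def compute_sampling_mask_guarded(probs: np.ndarray):
    """Return (sampling_mask, logz_per_batch); short-circuit empty batches."""
    real_bsz = probs.shape[0]
    if real_bsz == 0:
        # Fast path: no rows, so never reach max()/argsort on empty tensors.
        return [], np.zeros([0], dtype=np.float32)
    # ... the normal Stage 1-5 computation would follow here ...
    raise NotImplementedError("sketch: non-empty path elided")
```

Callers then receive an empty-but-well-typed result for empty batches instead of a shape error.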
```python
# Extract target logits/probs at accepted positions (shared by logprobs and sampling_mask).
# When both are enabled, reuse target_logits to derive target_probs (avoid a second kernel call).
total_accepted = int(accept_nums.sum().item())
target_logits = paddle.empty([total_accepted, logits.shape[1]], dtype=logits.dtype)
speculate_get_target_logits(
    target_logits,
```
On the speculative path, total_accepted can be 0 (accept_num all zeros). The code still creates a shape=[0, vocab] target_logits and proceeds through softmax/sampling-mask computation, eventually hitting the empty-batch problem in _compute_sampling_mask or errors from other ops on empty tensors. Suggest skipping the logprobs/sampling_mask computation when total_accepted == 0 and returning sampling_mask=[] with an empty logz_per_batch (while still sending an empty dict over the side-channel so the main process does not block).
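A sketch of the suggested payload construction, making the "always send something" rule explicit (the function name is hypothetical; in the real code the dict is built in pre_and_post_process.py before send_pyobj):

```python
def build_side_channel_payload(total_accepted: int, sampling_mask=None) -> dict:
    """Build the dict sent over the ZMQ side-channel for one step.

    Even when nothing was accepted, an empty dict must be sent so the
    blocking receive in TokenProcessor is never left waiting forever.
    """
    if total_accepted == 0:
        return {}
    # Normal case: one sparse index list per accepted position.
    return {i: mask for i, mask in enumerate(sampling_mask or [])}
```

The key invariant is that the sender emits exactly one message per step, empty or not, matching the receiver's one-receive-per-step expectation.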
```python
# where the value is a list[int] or list[list[int]] of allowed token ids
sampling_masks_per_request = {}
if self.use_sampling_mask and not envs.FD_USE_GET_SAVE_OUTPUT_V1 and hasattr(self, "sampling_mask_zmq_server"):
    _, mask_data = self.sampling_mask_zmq_server.receive_pyobj_once(block=True)
```
The sampling_mask side-channel uses a blocking receive_pyobj_once(block=True) here. If the worker side never sends (e.g. a non-CUDA runner with no sender implemented, or a rank/path that never triggers the send), the main thread blocks forever and the service hangs. Suggestions: 1) only block on receive once the worker side is confirmed to have the switch enabled and the side-channel established; or 2) switch to a non-blocking poll with a timeout, falling back to not filling sampling_mask when the timeout expires.
```diff
-            _, mask_data = self.sampling_mask_zmq_server.receive_pyobj_once(block=True)
+            # Use a bounded non-blocking polling loop to avoid deadlock if worker does not send.
+            mask_data = None
+            max_wait_ms = 50.0
+            start_ts = time.monotonic()
+            try:
+                while (time.monotonic() - start_ts) * 1000.0 < max_wait_ms:
+                    _, mask_data = self.sampling_mask_zmq_server.receive_pyobj_once(block=False)
+                    if mask_data is not None:
+                        break
+                    # Sleep briefly before next poll to avoid busy-waiting.
+                    time.sleep(0.001)
+            except Exception:
+                # If side-channel fails, fall back to decoding without sampling masks.
+                mask_data = None
```
```python
# Send sampling_mask via ZMQ side-channel when enabled.
if sampler_output.sampling_mask is not None and model_output.mp_rank == 0:
    # sampling_mask is List[np.ndarray] of sparse int indices, one array per request.
    mask_dict = {i: arr.tolist() for i, arr in enumerate(sampler_output.sampling_mask)}
    sampling_mask_zmq_client.send_pyobj(mask_dict)
```
When sending sampling_mask over the side-channel, sampler_output is not batch-reordered (in the enable_pd_reorder case sampled_token_ids gets recovered, but sampling_mask is still enumerated in the original order), so the main process maps batch_id to the wrong request. Suggest applying the same recover/reorder to sampler_output (at least sampling_mask) as to sampled_token_ids before sending, or explicitly constructing mask_dict's key order from index_to_batch_id.
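A sketch of the second suggestion, keying the dict by recovered batch ids instead of raw enumeration order (index_to_batch_id is assumed to map sampler-output position to the original batch id; the helper name is illustrative):

```python
def build_mask_dict(sampling_mask, index_to_batch_id):
    """Key the side-channel dict by original batch id, not output position.

    Without this, a reordered sampler output (enable_pd_reorder) makes the
    TokenProcessor attribute masks to the wrong requests.
    """
    return {index_to_batch_id[i]: mask for i, mask in enumerate(sampling_mask)}
```

The receiver can then look up masks by batch_id regardless of how the worker reordered its outputs.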
```python
if top_k_list and any(x > 0 for x in top_k_list):
    from fastdeploy.model_executor.ops.gpu import top_k_renorm_probs

    x = top_k_renorm_probs(x, top_k)
```
This adds top_k_renorm_probs logic to the default (base) top_p_sampling branch, meaning top_k is no longer limited to the rejection class. The docstring at the top of the function still says "top_k Only used when FD_SAMPLING_CLASS is rejection", which will mislead users and future maintainers. Suggest updating that note in sync, making explicit when top_k takes effect in the base branch and what it does (renorm, then top_p).
Motivation
Add the keep_sampling_mask feature; see PR #6725 for details.
Modifications
Added a _compute_sampling_mask method in sampler.py.
Added the startup flag --enable-keep-sampling-mask.
Usage or Command
Service launch command:
Accuracy Tests
yes
Checklist
- Add a tag from the following list to the PR title: [FDConfig], [APIServer], [Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]
- Run pre-commit before commit.
- When submitting to a release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.