[KSM] support keep sampling mask #7146
zeroRains wants to merge 4 commits into PaddlePaddle:release/2.5
Conversation
Thanks for your contribution!
Pull request overview

This PR adds a keep_sampling_mask capability to the sampling flow: after top_p/top_k sampling, it outputs the set of retained vocab indices (in sparse form) for each token step and propagates it through to the engine output and the OpenAI-protocol response, so clients can inspect the sampling constraints / candidate set.

Changes:
- New launch flag --enable-keep-sampling-mask/--enable_keep_sampling_mask, enabling sampling_mask production along the engine→worker→sampler chain.
- The sampler computes the sparse top_k+top_p sampling_mask (plus logZ), writes it into SamplerOutput, and sends it back to TokenProcessor over a ZMQ side-channel (FD_USE_GET_SAVE_OUTPUT_V1=0).
- OpenAI serving/protocol and the engine output structures gain a sampling_mask field, returned in both stream and full responses.
Reviewed changes
Copilot reviewed 16 out of 16 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| fastdeploy/worker/worker_process.py | Add the enable_keep_sampling_mask CLI argument on the worker side |
| fastdeploy/config.py | Add the enable_keep_sampling_mask option to ModelConfig |
| fastdeploy/engine/args_utils.py | Add enable_keep_sampling_mask to the engine CLI/EngineArgs and pass it through to ModelConfig |
| fastdeploy/engine/engine.py | Pass the enable_keep_sampling_mask store_true flag through when launching workers |
| fastdeploy/engine/common_engine.py | Same as above (common_engine worker launch) |
| fastdeploy/model_executor/layers/sample/meta_data.py | Add the keep_sampling_mask switch field to SamplingMetadata |
| fastdeploy/model_executor/layers/sample/sampler.py | Add _compute_sampling_mask and compute the sparse mask + logZ before sampling |
| fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py | Add top_k renorm logic to the base branch's tensor top_p sampling |
| fastdeploy/worker/output.py | Add sampling_mask / logz_per_batch fields to SamplerOutput |
| fastdeploy/model_executor/pre_and_post_process.py | Pass sampling_mask through save_output_normal / post_process_specualate and send it over the ZMQ side-channel |
| fastdeploy/worker/gpu_model_runner.py | Initialize the sampling_mask ZMQ client in the GPU runner and pass the switch into SamplingMetadata and post_process |
| fastdeploy/output/token_processor.py | Add a sampling_mask ZMQ PULL server to TokenProcessor and fill sampling_mask into batch outputs |
| fastdeploy/output/stream_transfer_data.py | Add a sampling_mask field to StreamTransferData (for the V1 output path) |
| fastdeploy/engine/request.py | Add a sampling_mask field to CompletionOutput and include it in to_dict() |
| fastdeploy/entrypoints/openai/protocol.py | Add a sampling_mask field to OpenAI ChatCompletion stream/full choices |
| fastdeploy/entrypoints/openai/serving_chat.py | Pass sampling_mask through OpenAI chat stream/full, with shape unification and flattening |
```python
self.use_sampling_mask = getattr(self.cfg.model_config, "enable_keep_sampling_mask", False)
if not envs.FD_USE_GET_SAVE_OUTPUT_V1 and self.use_sampling_mask:
    rank_id = self.cfg.parallel_config.local_data_parallel_id
    port = self.cfg.parallel_config.engine_worker_queue_port[rank_id]
    self.sampling_mask_zmq_server = ZmqIpcServer(
        name=f"sampling_mask_output_rank_{rank_id}_{port}", mode=zmq.PULL
    )
    llm_logger.info(f"create zmq sampling_mask_output_rank_{rank_id}_{port}")
```
When enable_keep_sampling_mask is on, this creates the ZMQ PULL server for sampling_mask, and the later _process_batch_output() blocks on receive_pyobj_once(block=True). Currently only the GPU worker path sends sampling_mask; enabling this switch on XPU/GCU/HPU/Metax and similar platforms will block here forever and hang the service. Suggestions: 1) only enable on supported platforms (e.g. current_platform.is_cuda()), or 2) switch to a non-blocking/timeout receive and skip this step's sampling_mask when nothing arrives.
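The second suggestion (a bounded, non-blocking receive) could look roughly like the sketch below; `recv_once` stands in for `receive_pyobj_once` and its `(err, obj)` return convention, and the helper name, timeout, and poll interval are illustrative only:

```python
import time

def receive_mask_with_timeout(recv_once, max_wait_ms=50.0, poll_interval_s=0.001):
    # Bounded poll: returns the received payload, or None if nothing
    # arrives within max_wait_ms, instead of blocking forever.
    deadline = time.monotonic() + max_wait_ms / 1000.0
    while time.monotonic() < deadline:
        err, obj = recv_once(block=False)
        if err:
            return None  # socket closed / errored: skip this step's mask
        if obj is not None:
            return obj
        time.sleep(poll_interval_s)
    return None
```

On timeout the caller would simply leave sampling_mask unset for that step rather than hang the service.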
```python
cum_probs = paddle.cumsum(renorm_sorted_probs, axis=-1)  # [B, V]
topp_mask = (cum_probs - renorm_sorted_probs) <= top_p  # [B, V]
# When top_p[i] >= 1.0, keep the entire row.
topp_mask = paddle.where(
    (top_p >= 1.0).expand_as(topp_mask),
    paddle.ones_like(topp_mask),
    topp_mask,
)
```
When top_p >= 1.0 and top_k is disabled, final_mask becomes an all-True row, max_k approaches vocab_size, and the subsequent sorted_indices[:, :max_k] D2H copy produces a huge Python list (the full vocab's indices per token). That brings very significant CPU/memory/bandwidth overhead and can even grind the service to a halt. Suggestion: in this "untruncated" case either return no sampling_mask at all (set it to None), or add an upper bound (e.g. return at most top_k_max indices) and enforce it in docs/argument validation.
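The cap idea can be sketched in numpy as follows. This assumes `final_mask` is prefix-true in sorted order (as a pure top-p mask is), and `cap` is a hypothetical limit, not an existing parameter:

```python
import numpy as np

def sparse_keep_indices(sorted_indices, final_mask, cap=1024):
    # sorted_indices: [B, V] vocab ids sorted by descending prob
    # final_mask:     [B, V] bool, True = kept by top_k/top_p
    # Slice to min(max_k, cap) columns so the D2H copy stays bounded
    # even when top_p >= 1.0 keeps the whole row.
    k_per_row = final_mask.sum(axis=-1)
    max_k = min(int(k_per_row.max()), cap)
    head_idx = sorted_indices[:, :max_k]
    head_mask = final_mask[:, :max_k]
    return [idx[m].tolist() for idx, m in zip(head_idx, head_mask)]
```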
```python
delta=delta_message,
logprobs=logprobs_res,
draft_logprobs=draft_logprobs_res,
sampling_mask=(
    self._make_sampling_mask_list(output["sampling_mask"])
    if output.get("sampling_mask") is not None
    else None
),
```
This adds the sampling_mask field and passes it through to the OpenAI response (both stream and full paths), but the existing tests under tests/entrypoints/openai/ cover the response structure extensively and currently contain no assertions for sampling_mask. Suggested unit tests: 1) stream: the shape of sampling_mask on each delta (non-MTP: [[...]]; MTP/Spec: [[...], ...]); 2) non-stream: the flattened choice.sampling_mask in the final response aligns with the token count.
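A minimal shape-invariant check along the lines of the suggested tests might look like this; the helper name and the exact payload convention are assumptions based on the shapes described above:

```python
def check_sampling_mask_shape(sampling_mask, num_tokens, speculative):
    # Non-MTP:  one index list per delta           -> [[...]]
    # MTP/Spec: one index list per accepted token  -> [[...], [...], ...]
    assert isinstance(sampling_mask, list)
    expected = num_tokens if speculative else 1
    assert len(sampling_mask) == expected
    for per_token in sampling_mask:
        assert isinstance(per_token, list)
        assert all(isinstance(tok, int) for tok in per_token)
```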
```python
"--enable-keep-sampling-mask",
action="store_true",
help=(
    "Enable output of keep_sampling_mask as sparse vocab index list per token step "
```
The argument's help text uses the field name keep_sampling_mask, but the externally visible protocol/response field is actually sampling_mask (see OpenAI protocol / CompletionOutput). Suggest aligning the help text with the external field name (e.g. call it sampling_mask everywhere) so users don't expect the response field to be called keep_sampling_mask.
Suggested change:

```diff
-    "Enable output of keep_sampling_mask as sparse vocab index list per token step "
+    "Enable output of sampling_mask as sparse vocab index list per token step "
```
```python
k_per_row = final_mask.astype("int32").sum(axis=-1)  # [B]
max_k = int(k_per_row.max().item())

# ------------------------------------------------------------------
# Stage 5: compute logZ_K for renormalization
```
_compute_sampling_mask calls k_per_row.max() / sorted_indices[:, :max_k] etc. even when real_bsz == 0 (e.g. the speculative case with total_accepted == 0, or an edge-case empty batch), which can error out directly or produce invalid slices. Suggest a fast return at the top of the function for real_bsz == 0 (empty list + empty array), skipping the subsequent max()/argsort logic.
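A numpy sketch of the suggested fast path; the real function also handles top_k and other details not reproduced here, and the function body below is a simplified top-p-only stand-in:

```python
import numpy as np

def compute_sampling_mask(probs, top_p):
    # Fast return for an empty batch: no max()/argsort on empty tensors.
    real_bsz = probs.shape[0]
    if real_bsz == 0:
        return [], np.empty([0], dtype=np.float32)
    order = np.argsort(-probs, axis=-1)            # descending prob order
    sorted_probs = np.take_along_axis(probs, order, axis=-1)
    cum = np.cumsum(sorted_probs, axis=-1)
    keep = (cum - sorted_probs) <= top_p[:, None]  # top-p prefix mask
    masks = [order[i][keep[i]].tolist() for i in range(real_bsz)]
    logz = np.log(np.where(keep, sorted_probs, 0.0).sum(axis=-1))
    return masks, logz
```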
```python
# Extract target logits/probs at accepted positions (shared by logprobs and sampling_mask).
# When both are enabled, reuse target_logits to derive target_probs (avoid a second kernel call).
total_accepted = int(accept_nums.sum().item())
target_logits = paddle.empty([total_accepted, logits.shape[1]], dtype=logits.dtype)
speculate_get_target_logits(
    target_logits,
```
On the speculative path total_accepted can be 0 (accept_num all zero). The code still creates a shape=[0, vocab] target_logits and proceeds to compute softmax / the sampling mask, eventually hitting _compute_sampling_mask's empty-batch problem or other operators erroring on empty tensors. Suggest skipping the logprobs/sampling_mask computation entirely when total_accepted == 0 and returning sampling_mask=[] and an empty logz_per_batch (while still sending an empty dict over the side-channel so the main process doesn't block).
```python
# where the value is a list[int] or list[list[int]] of allowed token ids
sampling_masks_per_request = {}
if self.use_sampling_mask and not envs.FD_USE_GET_SAVE_OUTPUT_V1 and hasattr(self, "sampling_mask_zmq_server"):
    _, mask_data = self.sampling_mask_zmq_server.receive_pyobj_once(block=True)
```
The sampling_mask side-channel uses a blocking receive_pyobj_once(block=True) here. If the worker side never sends (e.g. a non-CUDA runner that doesn't implement the send, or a rank/path that never triggers it), the main thread blocks forever and the service hangs. Suggestions: 1) only block on receive once the worker side is confirmed to have the side-channel enabled and established; or 2) switch to a non-blocking/timeout poll and fall back to not populating sampling_mask on timeout.
Suggested change:

```diff
-_, mask_data = self.sampling_mask_zmq_server.receive_pyobj_once(block=True)
+# Use a bounded non-blocking polling loop to avoid deadlock if worker does not send.
+mask_data = None
+max_wait_ms = 50.0
+start_ts = time.monotonic()
+try:
+    while (time.monotonic() - start_ts) * 1000.0 < max_wait_ms:
+        _, mask_data = self.sampling_mask_zmq_server.receive_pyobj_once(block=False)
+        if mask_data is not None:
+            break
+        # Sleep briefly before next poll to avoid busy-waiting.
+        time.sleep(0.001)
+except Exception:
+    # If side-channel fails, fall back to decoding without sampling masks.
+    mask_data = None
```
```python
# Send sampling_mask via ZMQ side-channel when enabled.
if sampler_output.sampling_mask is not None and model_output.mp_rank == 0:
    # sampling_mask is List[np.ndarray] of sparse int indices, one array per request.
    mask_dict = {i: arr.tolist() for i, arr in enumerate(sampler_output.sampling_mask)}

    sampling_mask_zmq_client.send_pyobj(mask_dict)
```
The side-channel send of sampling_mask does not apply the batch re-order to sampler_output (in the enable_pd_reorder scenario sampled_token_ids gets recovered, but sampling_mask is still enumerated in the original order), so the main process will map batch_id to the wrong request. Suggest applying the same recover/reorder as sampled_token_ids to sampler_output (at least to sampling_mask) before sending, or constructing mask_dict's keys explicitly via index_to_batch_id.
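The explicit-key construction could look like this sketch, where `index_to_batch_id` is a hypothetical mapping from post-reorder position back to the original batch id:

```python
def build_mask_dict(sampling_mask, index_to_batch_id):
    # Key the side-channel payload by the recovered batch id rather than
    # raw enumeration order, so a pd-reordered batch maps to the right request.
    return {
        index_to_batch_id[i]: (arr.tolist() if hasattr(arr, "tolist") else list(arr))
        for i, arr in enumerate(sampling_mask)
    }
```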
```python
if top_k_list and any(x > 0 for x in top_k_list):
    from fastdeploy.model_executor.ops.gpu import top_k_renorm_probs

    x = top_k_renorm_probs(x, top_k)
```
This adds top_k_renorm_probs logic to the default (base) top_p_sampling branch, meaning top_k is no longer limited to the rejection class. The docstring at the top of the function still says top_k is "Only used when FD_SAMPLING_CLASS is rejection", which can mislead users and maintainers. Suggest updating that note in sync and spelling out when and how top_k takes effect on the base branch (renorm, then top_p).
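For the docstring update, the base-branch behavior to document is: zero out everything outside each row's top-k, renormalize, then run top-p on the result. A numpy sketch of that renorm step, mirroring what `top_k_renorm_probs` is described as doing (not the actual GPU kernel):

```python
import numpy as np

def top_k_renorm(probs, top_k):
    # Keep each row's top-k probabilities, zero the rest, renormalize.
    # Rows with k <= 0 are left untouched (top-k disabled for that row).
    out = np.zeros_like(probs)
    for i, k in enumerate(top_k):
        if k <= 0:
            out[i] = probs[i]
            continue
        kept = np.argsort(-probs[i])[:k]
        out[i, kept] = probs[i, kept]
        out[i] /= out[i].sum()
    return out
```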
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```
@@           Coverage Diff            @@
##           release/2.5    #7146   +/- ##
==============================================
  Coverage             ?   69.10%
==============================================
  Files                ?      390
  Lines                ?    54356
  Branches             ?     8576
==============================================
  Hits                 ?    37564
  Misses               ?    14064
  Partials             ?     2728
```

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.
fastdeploy-bot left a comment

🤖 AI Code Review | 2026-04-02 20:12 CST
📋 Review Summary

PR overview: adds a keep_sampling_mask feature that returns the sparse indices of vocab positions retained after top_p/top_k sampling.
Scope of changes: sampler, pre_and_post_process, serving_chat, token_processor, gpu_model_runner
Impact tags: [OP] [APIServer] [Engine]

📝 PR Convention Check

The [KSM] tag in the PR title is not in the official tag list.
Suggested title (copy-paste ready):
[Feature] support keep sampling mask

Issues

| Severity | File | Summary |
|---|---|---|
| 🔴 Bug | sampler.py:103 | Declared return type does not match the actual return value |
| 🔴 Bug | pre_and_post_process.py:632 | Non-speculative decoding path does not pass the sampling_mask_zmq_client argument |
| 🟡 Suggestion | serving_chat.py:1018 | _make_sampling_mask_list does not handle the empty-list edge case |

Overall

The implementation is clearly structured, and both MTP and non-MTP scenarios are covered. However, there are two blocking issues: the wrong type declaration and the missing argument on the non-speculative decoding path; the latter breaks the feature entirely in non-MTP mode. Recommend fixing before merging.
```python
        sampling_mask_zmq_client,
    )
else:
    post_process_normal(
```
🔴 Bug: the non-speculative decoding path does not pass the sampling_mask_zmq_client argument

post_process accepts a sampling_mask_zmq_client parameter (line 601), but the non-speculative decoding branch calls post_process_normal without passing it. As a result, in non-MTP mode enable_keep_sampling_mask cannot transfer sampling_mask data over the ZMQ channel.

Suggested fix:
- Add a sampling_mask_zmq_client parameter to post_process_normal
- Handle the sampling_mask ZMQ send inside post_process_normal (or when calling save_output_normal)
- Pass the parameter at this call site
The ZMQ send logic is already invoked in save_output_normal; post_process_normal itself does not need to send over ZMQ.
Pull request overview
Copilot reviewed 23 out of 23 changed files in this pull request and generated 4 comments.
Comments suppressed due to low confidence (1)
fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py:102
- The non-rejection (base) branch also renorms by top_k here, making top_k take effect on the paddle.tensor.top_p_sampling path. The function docstring still says top_k is "Only used when FD_SAMPLING_CLASS is rejection", which can mislead callers/users. Suggest updating the docstring in sync (or confirming this behavior is indeed intended).
```python
    _, ids = native_top_p_sampling(x, top_p)
else:
    if top_k_list and any(x > 0 for x in top_k_list):
        from fastdeploy.model_executor.ops.gpu import top_k_renorm_probs

        x = top_k_renorm_probs(x, top_k)
    _, ids = paddle.tensor.top_p_sampling(
        x,
        top_p,
        threshold=threshold,
        topp_seed=topp_seed,
        seed=seed,
        k=k,
        mode="truncated",
    )
return _, ids
```
```python
# Send sampling_mask via ZMQ side-channel when enabled.
if sampler_output.sampling_mask is not None and model_output.mp_rank == 0:
    # sampling_mask is List[np.ndarray] of sparse int indices, one array per request.
    mask_dict = {i: arr.tolist() for i, arr in enumerate(sampler_output.sampling_mask)}

    sampling_mask_zmq_client.send_pyobj(mask_dict)
```
The sampling_mask send does not check whether sampling_mask_zmq_client is None; if the caller never created/passed a client (e.g. certain runner paths or a failed initialization), this raises AttributeError and aborts inference. Suggest checking that sampling_mask_zmq_client is not None before send_pyobj, and logging a warning or skipping the send when it's missing.
```python
# Send sampling_mask via ZMQ side-channel when enabled.
if sampler_output.sampling_mask is not None and model_output.mp_rank == 0:
    # sampling_mask is List[np.ndarray] of sparse int indices, length = total_accepted_tokens.
    # Group by request using accept_num so each entry is List[np.ndarray] (n arrays per req).
    real_bsz = model_output.accept_num.shape[0]
    accept_nums = model_output.accept_num[:real_bsz].flatten().tolist()
    mask_dict = {}
    offset = 0
    for i, n in enumerate(accept_nums):
        n = int(n)
        if n > 0:
            # List of n sparse index arrays, one per accepted token
            mask_dict[i] = [arr.tolist() for arr in sampler_output.sampling_mask[offset : offset + n]]
            offset += n
    sampling_mask_zmq_client.send_pyobj(mask_dict)
```
The speculative path likewise doesn't validate sampling_mask_zmq_client before send_pyobj; an uninitialized client raises AttributeError. Suggest adding a None check before sending and degrading gracefully when missing (e.g. skip the send and log).
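The None check both comments ask for is a small wrapper; `safe_send_mask` is a hypothetical helper name, not an existing API:

```python
import logging

def safe_send_mask(client, mask_dict, logger=logging.getLogger(__name__)):
    # Degrade gracefully instead of raising AttributeError when the
    # side-channel client was never initialized on this runner path.
    if client is None:
        logger.warning("sampling_mask enabled but zmq client missing; skipping send")
        return False
    client.send_pyobj(mask_dict)
    return True
```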
```python
_, mask_data = self.sampling_mask_zmq_server.receive_pyobj_once(block=True)
if mask_data is not None and isinstance(mask_data, dict):
```
This discards the err (first return value) from receive_pyobj_once. When the ZMQ socket is closed or errors, it returns an error string and closes the socket proactively; the current logic silently ignores that, losing sampling_mask and making troubleshooting difficult. Suggest checking err and logging a warning / degrading (and avoiding an indefinite block where necessary).
Suggested change:

```diff
-_, mask_data = self.sampling_mask_zmq_server.receive_pyobj_once(block=True)
-if mask_data is not None and isinstance(mask_data, dict):
+err, mask_data = self.sampling_mask_zmq_server.receive_pyobj_once(block=True)
+if err:
+    # Log warning and degrade gracefully by skipping sampling masks for this step
+    llm_logger.warning(
+        "Failed to receive sampling mask from ZMQ side-channel: %s", err
+    )
+elif mask_data is not None and isinstance(mask_data, dict):
```
```python
# Compute sampling mask BEFORE top_k_top_p_sampling modifies probs.
# Binary mask [num_reqs, vocab_size]: 1 = retained by top_k/top_p, 0 = truncated.
sampling_mask = None
if sampling_metadata.keep_sampling_mask:
    sampling_mask = _compute_sampling_mask(
        probs,
        sampling_metadata.top_p,
        top_k=sampling_metadata.top_k,
        top_k_list=sampling_metadata.top_k_list,
    )
```
The comment says "Binary mask [num_reqs, vocab_size]", but sampling_mask actually returns a List[np.ndarray] of sparse vocab indices (not a dense bool mask). Suggest updating the comment to match the returned format so later maintainers/callers aren't misled.
Motivation
Add the keep_sampling_mask feature; see PR #6725 for details.
Modifications
Add a _compute_sampling_mask method in sampler.py
Add the launch flag --enable-keep-sampling-mask
Usage or Command
Service launch command:
Accuracy Tests
yes
Checklist
- PR title tag (choose from): [FDConfig], [APIServer], [Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]
- Run pre-commit before commit.
- If submitting to a release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.