
[KSM] support keep sampling mask #7146

Open
zeroRains wants to merge 4 commits into PaddlePaddle:release/2.5 from zeroRains:kms_2.5

Conversation

@zeroRains (Contributor) commented Apr 2, 2026

Motivation

Add the keep_sampling_mask feature; see PR #6725 for details.

💡 If this PR is a Cherry Pick, the PR title needs to follow the format by adding the [Cherry-Pick] label at the very beginning and appending the original PR ID at the end. For example, [Cherry-Pick][CI] Add check trigger and logic(#5191)


Modifications

  • Added a _compute_sampling_mask method in sampler.py
  • Added the launch flag --enable-keep-sampling-mask

Usage or Command

Server launch command:

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
MODEL_PATH="/root/paddlejob/tmpspace/GLM-4.5-Air/"
python -m fastdeploy.entrypoints.openai.api_server \
    --port 9293 \
    --host $(hostname -i) \
    --model "$MODEL_PATH" \
    --disable-custom-all-reduce \
    --tensor-parallel-size 8 \
    --max-model-len 131072 \
    --max-num-seqs 32 \
    --gpu-memory-utilization 0.9 \
    --graph-optimization-config '{"use_cudagraph":true}' \
    --enable-logprob \
    --enable-keep-sampling-mask \
    --speculative-config '{"method":"mtp","num_speculative_tokens":1,"num_model_steps":1,"model":"'$MODEL_PATH'"}'
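With the server up, a client can read the new per-token mask from the chat response. The payload layout below is an assumption based on this PR's protocol description (sampling_mask carried on each choice as one sparse index list per token step), not verified against the implementation:

```python
# Hypothetical response fragment; the sampling_mask field layout is assumed
# from this PR's protocol changes (one sparse vocab-index list per token step).
response = {
    "choices": [
        {
            "message": {"content": "Hi"},
            "sampling_mask": [[12, 340, 551], [7, 98]],
        }
    ]
}

def allowed_tokens_per_step(choice):
    """Map each generated token step to its retained vocab indices."""
    return {step: ids for step, ids in enumerate(choice.get("sampling_mask") or [])}

masks = allowed_tokens_per_step(response["choices"][0])
```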

Accuracy Tests

yes

Checklist

  • Add at least one tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code; run pre-commit before committing.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

Copilot AI review requested due to automatic review settings April 2, 2026 05:36

paddle-bot bot commented Apr 2, 2026

Thanks for your contribution!

Copilot AI left a comment

Pull request overview

This PR adds a keep_sampling_mask capability to the sampling flow: after top_p/top_k sampling it emits, for each token step, the set of retained vocab indices (in sparse form), and threads it through the engine output and the OpenAI protocol response so clients can inspect the sampling constraints/candidate set.

Changes:

  • Added the launch flags --enable-keep-sampling-mask / --enable_keep_sampling_mask, enabling sampling_mask production along the engine→worker→sampler chain.
  • The sampler computes the sparse top_k+top_p sampling_mask (plus logZ), writes it into SamplerOutput, and sends it back to TokenProcessor over a ZMQ side-channel (FD_USE_GET_SAVE_OUTPUT_V1=0).
  • Added a sampling_mask field to the OpenAI serving/protocol and Engine output structures, returned in both stream and full responses.
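The sparse-mask computation the sampler performs can be illustrated with a small NumPy sketch (a hypothetical stand-in for _compute_sampling_mask, not the PR's actual code):

```python
import numpy as np

def sparse_sampling_mask(probs, top_p, top_k):
    """Per row, return the sorted vocab indices kept by top-k then top-p.

    probs: [B, V] normalized probabilities; top_p: [B] floats;
    top_k: [B] ints, where 0 disables top-k for that row.
    """
    masks = []
    for p_row, tp, tk in zip(probs, top_p, top_k):
        order = np.argsort(-p_row)                # vocab ids, descending prob
        if tk > 0:
            order = order[:tk]
        kept = p_row[order] / p_row[order].sum()  # renormalize after top-k
        cum = np.cumsum(kept)
        # Keep tokens whose cumulative mass *before* them is <= top_p,
        # so the highest-probability token always survives.
        n_keep = int(np.searchsorted(cum - kept, tp, side="right"))
        masks.append(np.sort(order[:max(n_keep, 1)]))
    return masks
```

For top_p >= 1.0 this keeps the whole row (or the whole top-k prefix), which is exactly the degenerate case flagged in the review comments below.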

Reviewed changes

Copilot reviewed 16 out of 16 changed files in this pull request and generated 9 comments.

Show a summary per file

File — Description
fastdeploy/worker/worker_process.py — Add the enable_keep_sampling_mask CLI argument on the worker side
fastdeploy/config.py — Add the enable_keep_sampling_mask option to ModelConfig
fastdeploy/engine/args_utils.py — Add enable_keep_sampling_mask to the engine CLI/EngineArgs and pass it through to ModelConfig
fastdeploy/engine/engine.py — Pass the enable_keep_sampling_mask store_true flag when launching workers
fastdeploy/engine/common_engine.py — Same as above (common_engine worker launch)
fastdeploy/model_executor/layers/sample/meta_data.py — Add the keep_sampling_mask switch field to SamplingMetadata
fastdeploy/model_executor/layers/sample/sampler.py — Add _compute_sampling_mask and compute the sparse mask + logZ before sampling
fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py — Add top_k renorm logic to tensor top_p sampling in the base branch
fastdeploy/worker/output.py — Add sampling_mask / logz_per_batch fields to SamplerOutput
fastdeploy/model_executor/pre_and_post_process.py — save_output_normal / post_process_specualate pass sampling_mask through and send it over the ZMQ side-channel
fastdeploy/worker/gpu_model_runner.py — Initialize the sampling_mask ZMQ client in the GPU runner and pass the switch into SamplingMetadata and post_process
fastdeploy/output/token_processor.py — Add a sampling_mask ZMQ PULL server to TokenProcessor and fill sampling_mask in the batch output
fastdeploy/output/stream_transfer_data.py — Add the sampling_mask field to StreamTransferData (used by the V1 output path)
fastdeploy/engine/request.py — Add the sampling_mask field to CompletionOutput and include it in to_dict()
fastdeploy/entrypoints/openai/protocol.py — Add the sampling_mask field to OpenAI ChatCompletion stream/full choices
fastdeploy/entrypoints/openai/serving_chat.py — OpenAI chat stream/full passes sampling_mask through, normalizing shape and flattening

Comment on lines +85 to +92
self.use_sampling_mask = getattr(self.cfg.model_config, "enable_keep_sampling_mask", False)
if not envs.FD_USE_GET_SAVE_OUTPUT_V1 and self.use_sampling_mask:
    rank_id = self.cfg.parallel_config.local_data_parallel_id
    port = self.cfg.parallel_config.engine_worker_queue_port[rank_id]
    self.sampling_mask_zmq_server = ZmqIpcServer(
        name=f"sampling_mask_output_rank_{rank_id}_{port}", mode=zmq.PULL
    )
    llm_logger.info(f"create zmq sampling_mask_output_rank_{rank_id}_{port}")

Copilot AI Apr 2, 2026

When enable_keep_sampling_mask is on, this creates the ZMQ PULL server for sampling_mask, and _process_batch_output() later blocks on receive_pyobj_once(block=True). Only the GPU worker path currently sends sampling_mask; enabling the flag on XPU/GCU/HPU/Metax would block here forever and hang the service. Suggestions: 1) only enable on supported platforms (e.g. current_platform.is_cuda()), or 2) use a non-blocking/timeout receive and skip sampling_mask for the step when nothing arrives.
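The timeout fallback in option 2 could look like the following generic polling wrapper (receive_fn and its None-on-empty convention are assumptions for illustration, not FastDeploy's actual ZMQ API):

```python
import time

def receive_with_timeout(receive_fn, timeout_s=0.05, poll_interval_s=0.001):
    """Poll a non-blocking receive until data arrives or the deadline passes.

    receive_fn() should return the payload, or None when nothing is queued.
    Returns None on timeout so the caller can skip sampling_mask this step.
    """
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        payload = receive_fn()
        if payload is not None:
            return payload
        time.sleep(poll_interval_s)
    return None
```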

Comment on lines +169 to +176
cum_probs = paddle.cumsum(renorm_sorted_probs, axis=-1) # [B, V]
topp_mask = (cum_probs - renorm_sorted_probs) <= top_p # [B, V]
# When top_p[i] >= 1.0, keep the entire row.
topp_mask = paddle.where(
(top_p >= 1.0).expand_as(topp_mask),
paddle.ones_like(topp_mask),
topp_mask,
)

Copilot AI Apr 2, 2026

When top_p >= 1.0 and top_k is disabled, final_mask becomes an all-True row, max_k approaches vocab_size, and sorted_indices[:, :max_k] is copied D2H wholesale into a huge Python list (every token returns the full vocab as indices). The CPU/memory/bandwidth cost is severe and can stall the service. Suggestion: in this untruncated case return no sampling_mask (None), or cap the result (e.g. at most top_k_max indices) and enforce the limit in docs/argument validation.
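The suggested cap could be applied when flattening the kept indices, along these lines (function name, shapes, and the None convention for oversized rows are illustrative, not from the PR):

```python
import numpy as np

def capped_mask_indices(sorted_indices, k_per_row, cap=1024):
    """Return per-row index lists, skipping rows whose kept set exceeds `cap`.

    Rows that keep more than `cap` tokens (e.g. top_p >= 1.0 with no top_k)
    yield None instead of a near-vocab-size Python list.
    """
    out = []
    for row, k in zip(sorted_indices, k_per_row):
        out.append(None if k > cap else row[:k].tolist())
    return out
```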

Comment on lines 432 to +439
delta=delta_message,
logprobs=logprobs_res,
draft_logprobs=draft_logprobs_res,
sampling_mask=(
    self._make_sampling_mask_list(output["sampling_mask"])
    if output.get("sampling_mask") is not None
    else None
),

Copilot AI Apr 2, 2026

The new sampling_mask field is threaded into the OpenAI response on both the stream and full paths, but while the existing tests under tests/entrypoints/openai/ cover the response structure extensively, none assert on sampling_mask. Suggested unit tests: 1) stream: the sampling_mask shape of each delta (Non-MTP: [[...]]; MTP/Spec: [[...], ...]); 2) non-stream: the flattened choice.sampling_mask aligns with the token count.

Copilot generated this review using guidance from repository custom instructions.
"--enable-keep-sampling-mask",
action="store_true",
help=(
"Enable output of keep_sampling_mask as sparse vocab index list per token step "

Copilot AI Apr 2, 2026

The argument's help text uses the field name keep_sampling_mask, but the field actually exposed by the protocol is sampling_mask (see the OpenAI protocol / CompletionOutput). Suggest aligning the help text with the external field name (e.g. uniformly sampling_mask) so users don't expect the response field to be called keep_sampling_mask.

Suggested change
"Enable output of keep_sampling_mask as sparse vocab index list per token step "
"Enable output of sampling_mask as sparse vocab index list per token step "

@zeroRains zeroRains changed the title from [KSM] support keep samping mask to [KSM] support keep sampling mask on Apr 2, 2026
Copilot AI review requested due to automatic review settings April 2, 2026 07:15
Copilot AI left a comment

Pull request overview

Copilot reviewed 23 out of 23 changed files in this pull request and generated 5 comments.

Comment on lines +197 to +201
k_per_row = final_mask.astype("int32").sum(axis=-1) # [B]
max_k = int(k_per_row.max().item())

# ------------------------------------------------------------------
# Stage 5: compute logZ_K for renormalization

Copilot AI Apr 2, 2026

_compute_sampling_mask calls k_per_row.max() / sorted_indices[:, :max_k] even when real_bsz == 0 (e.g. total_accepted == 0 in the speculative path, or an otherwise empty batch), which can raise or produce invalid slices. Suggest an early return at the top of the function for real_bsz == 0 (empty list + empty array) that skips the max()/argsort logic.
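The suggested fast return might be factored as a small guard (names and the (masks, logZ) return shape are assumptions about the surrounding code, shown with NumPy rather than Paddle):

```python
import numpy as np

def guard_empty_batch(probs):
    """Fast-path for an empty batch: return (handled, result).

    With real_bsz == 0 there is nothing to reduce, so skip max()/argsort
    entirely and hand back an empty mask list plus an empty logZ array.
    """
    if probs.shape[0] == 0:
        return True, ([], np.empty([0], dtype=np.float32))
    return False, None
```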

Comment on lines +997 to 1002
# Extract target logits/probs at accepted positions (shared by logprobs and sampling_mask).
# When both are enabled, reuse target_logits to derive target_probs (avoid a second kernel call).
total_accepted = int(accept_nums.sum().item())
target_logits = paddle.empty([total_accepted, logits.shape[1]], dtype=logits.dtype)
speculate_get_target_logits(
target_logits,

Copilot AI Apr 2, 2026

On the speculative path, total_accepted can be 0 (all accept_num entries are 0). The code still creates a shape=[0, vocab] target_logits and proceeds to the softmax/mask computation, which hits the empty-batch problem in _compute_sampling_mask or makes other ops fail on empty tensors. Suggest skipping the logprobs/sampling_mask computation when total_accepted == 0 and returning sampling_mask=[] with an empty logz_per_batch (while still sending an empty dict to the side-channel so the main process doesn't block).

# where the value is a list[int] or list[list[int]] of allowed token ids
sampling_masks_per_request = {}
if self.use_sampling_mask and not envs.FD_USE_GET_SAVE_OUTPUT_V1 and hasattr(self, "sampling_mask_zmq_server"):
_, mask_data = self.sampling_mask_zmq_server.receive_pyobj_once(block=True)

Copilot AI Apr 2, 2026

This uses a blocking receive_pyobj_once(block=True) on the sampling_mask side-channel. If the worker side never sends (e.g. a non-CUDA runner with no send implemented, or a rank/path that never triggers the send), the main thread blocks forever and the service hangs. Suggestions: 1) only block once the worker side is confirmed enabled and the side-channel established; or 2) switch to a non-blocking/timeout poll and fall back to not filling sampling_mask after the timeout.

Suggested change
_, mask_data = self.sampling_mask_zmq_server.receive_pyobj_once(block=True)
# Use a bounded non-blocking polling loop to avoid deadlock if worker does not send.
mask_data = None
max_wait_ms = 50.0
start_ts = time.monotonic()
try:
    while (time.monotonic() - start_ts) * 1000.0 < max_wait_ms:
        _, mask_data = self.sampling_mask_zmq_server.receive_pyobj_once(block=False)
        if mask_data is not None:
            break
        # Sleep briefly before next poll to avoid busy-waiting.
        time.sleep(0.001)
except Exception:
    # If side-channel fails, fall back to decoding without sampling masks.
    mask_data = None

Comment on lines +416 to +421
# Send sampling_mask via ZMQ side-channel when enabled.
if sampler_output.sampling_mask is not None and model_output.mp_rank == 0:
    # sampling_mask is List[np.ndarray] of sparse int indices, one array per request.
    mask_dict = {i: arr.tolist() for i, arr in enumerate(sampler_output.sampling_mask)}

    sampling_mask_zmq_client.send_pyobj(mask_dict)

Copilot AI Apr 2, 2026

The side-channel send does not re-order the batch in sampler_output (under enable_pd_reorder, sampled_token_ids is recovered, but sampling_mask is still enumerated in the original order), so the main process maps masks to the wrong request by batch_id. Suggest applying the same recover/reorder to sampler_output (at least sampling_mask) before sending, or building the mask_dict keys explicitly via index_to_batch_id.
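Building the dict keyed by the recovered batch ids, as suggested, might look like this (index_to_batch_id is a hypothetical mapping from sampler-output row to the batch slot the main process expects; the PR may expose it differently):

```python
def mask_dict_by_batch_id(masks, index_to_batch_id):
    """Key each mask by its request's batch id instead of its tensor row.

    masks: per-row sparse index lists from sampler output.
    index_to_batch_id: row i -> batch slot expected by the main process.
    """
    return {index_to_batch_id[i]: m for i, m in enumerate(masks)}
```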

Comment on lines +88 to +92
if top_k_list and any(x > 0 for x in top_k_list):
    from fastdeploy.model_executor.ops.gpu import top_k_renorm_probs

    x = top_k_renorm_probs(x, top_k)

Copilot AI Apr 2, 2026

This adds top_k_renorm_probs logic to the default (base) top_p_sampling branch, meaning top_k is no longer limited to the rejection class. The docstring at the top of the function still says top_k is "Only used when FD_SAMPLING_CLASS is rejection", which will mislead users and maintainers. Suggest updating that docstring and spelling out when and how top_k takes effect in the base branch (renorm, then top_p).


codecov-commenter commented Apr 2, 2026

Codecov Report

❌ Patch coverage is 73.01587% with 34 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (release/2.5@bd48640). Learn more about missing BASE report.

Files with missing lines Patch % Lines
fastdeploy/model_executor/layers/sample/sampler.py 78.57% 9 Missing and 3 partials ⚠️
fastdeploy/entrypoints/openai/serving_chat.py 38.46% 6 Missing and 2 partials ⚠️
fastdeploy/output/token_processor.py 60.00% 4 Missing and 2 partials ⚠️
fastdeploy/model_executor/pre_and_post_process.py 72.22% 3 Missing and 2 partials ⚠️
...executor/layers/sample/ops/top_k_top_p_sampling.py 0.00% 2 Missing and 1 partial ⚠️
Additional details and impacted files
@@              Coverage Diff               @@
##             release/2.5    #7146   +/-   ##
==============================================
  Coverage               ?   69.10%           
==============================================
  Files                  ?      390           
  Lines                  ?    54356           
  Branches               ?     8576           
==============================================
  Hits                   ?    37564           
  Misses                 ?    14064           
  Partials               ?     2728           
Flag Coverage Δ
GPU 69.10% <73.01%> (?)

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.

@fastdeploy-bot fastdeploy-bot left a comment

🤖 AI Code Review | 2026-04-02 20:12 CST

📋 Review summary

PR overview: adds the keep_sampling_mask feature, returning the sparse vocab indices retained after top_p/top_k sampling
Scope of changes: sampler, pre_and_post_process, serving_chat, token_processor, gpu_model_runner
Impact tags: [OP] [APIServer] [Engine]

📝 PR convention check

The [KSM] tag in the PR title is not in the official tag list.

Suggested title (copy-ready):

  • [Feature] support keep sampling mask

Issues

Severity — File — Summary
🔴 Bug — sampler.py:103 — Declared return type does not match the actual return value
🔴 Bug — pre_and_post_process.py:632 — The non-speculative decoding path does not pass the sampling_mask_zmq_client argument
🟡 Suggestion — serving_chat.py:1018 — _make_sampling_mask_list does not handle the empty-list edge case

Overall assessment

The implementation is clearly structured and covers both MTP and non-MTP scenarios. However, there are two blocking issues: the incorrect type declaration and the missing argument on the non-speculative decoding path; the latter breaks the feature entirely in non-MTP mode. Recommend fixing before merge.

    sampling_mask_zmq_client,
)
else:
    post_process_normal(
🔴 Bug: the non-speculative decoding path does not pass the sampling_mask_zmq_client argument

The post_process function accepts a sampling_mask_zmq_client parameter (line 601), but the non-speculative branch calls post_process_normal without passing it. As a result, enable_keep_sampling_mask cannot transmit sampling_mask data over the ZMQ channel in non-MTP mode.

Suggested fix:

  1. Add a sampling_mask_zmq_client parameter to post_process_normal
  2. Handle the ZMQ send of sampling_mask inside post_process_normal (or when calling save_output_normal)
  3. Pass the argument at this call site

zeroRains (Contributor, Author) replied:

The ZMQ send logic is already invoked in save_output_normal; post_process_normal does not need to send anything over ZMQ.

Copilot AI review requested due to automatic review settings April 2, 2026 13:01

Copilot AI left a comment

Pull request overview

Copilot reviewed 23 out of 23 changed files in this pull request and generated 4 comments.

Comments suppressed due to low confidence (1)

fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py:102

  • Top_k renorm is applied here in the non-rejection/base branch too, making top_k effective on the paddle.tensor.top_p_sampling path. The function docstring still says top_k is "Only used when FD_SAMPLING_CLASS is rejection", which can mislead callers. Suggest updating the docstring (or confirming this behavior is indeed intended).

            _, ids = native_top_p_sampling(x, top_p)
        else:
            if top_k_list and any(x > 0 for x in top_k_list):
                from fastdeploy.model_executor.ops.gpu import top_k_renorm_probs

                x = top_k_renorm_probs(x, top_k)

            _, ids = paddle.tensor.top_p_sampling(
                x,
                top_p,
                threshold=threshold,
                topp_seed=topp_seed,
                seed=seed,
                k=k,
                mode="truncated",
            )
    return _, ids

Comment on lines +416 to +422
# Send sampling_mask via ZMQ side-channel when enabled.
if sampler_output.sampling_mask is not None and model_output.mp_rank == 0:
    # sampling_mask is List[np.ndarray] of sparse int indices, one array per request.
    mask_dict = {i: arr.tolist() for i, arr in enumerate(sampler_output.sampling_mask)}

    sampling_mask_zmq_client.send_pyobj(mask_dict)


Copilot AI Apr 2, 2026

The send here does not check whether sampling_mask_zmq_client is None; if the caller never created/passed a client (e.g. some runner paths, or a failed initialization), this raises AttributeError and aborts inference. Suggest checking sampling_mask_zmq_client before send_pyobj, and logging a warning or skipping the send when it is missing.
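The suggested guard can be factored into a small helper (names are illustrative; the real call site would use sampling_mask_zmq_client and llm_logger):

```python
def send_sampling_mask(client, mask_dict, logger=None):
    """Send the mask dict only when the side-channel client exists.

    Skips (and optionally warns) instead of raising AttributeError when the
    client was never initialized on this runner path. Returns True on send.
    """
    if client is None:
        if logger is not None:
            logger.warning("sampling_mask client missing; skipping send")
        return False
    client.send_pyobj(mask_dict)
    return True
```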

Comment on lines +557 to 572
# Send sampling_mask via ZMQ side-channel when enabled.
if sampler_output.sampling_mask is not None and model_output.mp_rank == 0:
    # sampling_mask is List[np.ndarray] of sparse int indices, length = total_accepted_tokens.
    # Group by request using accept_num so each entry is List[np.ndarray] (n arrays per req).
    real_bsz = model_output.accept_num.shape[0]
    accept_nums = model_output.accept_num[:real_bsz].flatten().tolist()
    mask_dict = {}
    offset = 0
    for i, n in enumerate(accept_nums):
        n = int(n)
        if n > 0:
            # List of n sparse index arrays, one per accepted token
            mask_dict[i] = [arr.tolist() for arr in sampler_output.sampling_mask[offset : offset + n]]
        offset += n
    sampling_mask_zmq_client.send_pyobj(mask_dict)


Copilot AI Apr 2, 2026

The speculative path likewise calls send_pyobj without checking that sampling_mask_zmq_client is not None; an uninitialized client raises AttributeError. Suggest adding the non-None check before sending and degrading gracefully when it is missing (e.g. skip the send and log).

Comment on lines +751 to +752
_, mask_data = self.sampling_mask_zmq_server.receive_pyobj_once(block=True)
if mask_data is not None and isinstance(mask_data, dict):

Copilot AI Apr 2, 2026

This discards err (the first return value) from receive_pyobj_once. When the ZMQ socket is closed or errors out, it returns an error string and closes proactively; the current logic silently ignores this, losing sampling_mask and making the failure hard to debug. Suggest checking err, logging a warning, and degrading gracefully (and avoiding an indefinite block where necessary).

Suggested change
_, mask_data = self.sampling_mask_zmq_server.receive_pyobj_once(block=True)
if mask_data is not None and isinstance(mask_data, dict):
err, mask_data = self.sampling_mask_zmq_server.receive_pyobj_once(block=True)
if err:
    # Log warning and degrade gracefully by skipping sampling masks for this step
    llm_logger.warning(
        "Failed to receive sampling mask from ZMQ side-channel: %s", err
    )
elif mask_data is not None and isinstance(mask_data, dict):

Comment on lines +647 to +656
# Compute sampling mask BEFORE top_k_top_p_sampling modifies probs.
# Binary mask [num_reqs, vocab_size]: 1 = retained by top_k/top_p, 0 = truncated.
sampling_mask = None
if sampling_metadata.keep_sampling_mask:
    sampling_mask = _compute_sampling_mask(
        probs,
        sampling_metadata.top_p,
        top_k=sampling_metadata.top_k,
        top_k_list=sampling_metadata.top_k_list,
    )

Copilot AI Apr 2, 2026

The comment says "Binary mask [num_reqs, vocab_size]", but sampling_mask actually returns a List[np.ndarray] of sparse vocab indices, not a dense bool mask. Suggest updating the comment to match the returned format so maintainers and callers aren't misled.


Labels

None yet

Projects

None yet


4 participants