
[Speculative Decoding]【Hackathon 10th Spring No.49】Adapt ngram_match and hybrid_mtp_ngram gpu kernels#7103

Open
NKNaN wants to merge 2 commits into PaddlePaddle:develop from NKNaN:ngram

Conversation


@NKNaN NKNaN commented Mar 31, 2026

Motivation

rfc: PaddlePaddle/community#1213

Modifications

  • Implementation: two kernels.
    • Stage 1: count_and_find_candidate_kernel, launched with grid <<<max_batch_size + 1, 1024>>>.
      • Block 0 uses BlockReduce to compute the global unprocessed_batch_size.
      • Blocks 1..N each handle one batch and run the candidate search (over input_ids / pre_ids) in parallel.
    • Stage 2: truncate_candidate, launched as <<<1, 1024>>>, performs the threshold-based truncation and write-back in one place.
      • This stage uses CUB BlockScan for prefix sums (processed_batch_size / sum_token_num) to compute each batch's token allocation cap and complete the truncation.
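The two-stage flow above can be sketched on the host side in Python (a minimal model only; the function names are illustrative, and the real kernels operate on GPU buffers using CUB BlockReduce/BlockScan):

```python
# Host-side Python sketch of the two-stage kernel flow (illustrative only;
# the real implementation runs on the GPU with CUB BlockReduce/BlockScan).

def count_and_find_candidates(candidate_lens_per_batch, max_draft_tokens):
    """Stage 1 analogue: each block finds a candidate draft length
    for its batch, capped at max_draft_tokens."""
    return [min(n, max_draft_tokens) for n in candidate_lens_per_batch]

def truncate_candidates(candidate_lens, tokennum_threshold):
    """Stage 2 analogue: prefix-sum the candidate lengths and truncate
    so the running total of draft tokens stays under the threshold."""
    allocated, total = [], 0
    for n in candidate_lens:
        allowed = max(0, min(n, tokennum_threshold - total))  # per-batch cap
        allocated.append(allowed)
        total += allowed
    return allocated
```

For example, with candidate lengths [3, 4, 2] and a threshold of 5, the second stage keeps 3 tokens for the first batch, truncates the second to 2, and drops the third entirely.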

Usage or Command

None

Accuracy Tests

https://github.com/NKNaN/FastDeploy_ngram_match_kernel

Checklist

  • Add at least one tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code and run pre-commit before committing.
  • Add unit tests. If no unit tests are added, please explain why in this PR.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

@paddle-bot

paddle-bot bot commented Mar 31, 2026

Thanks for your contribution!

@NKNaN NKNaN closed this Apr 2, 2026
@NKNaN NKNaN reopened this Apr 2, 2026
@NKNaN NKNaN closed this Apr 2, 2026
@NKNaN NKNaN reopened this Apr 2, 2026
@freeliuzc freeliuzc requested a review from Copilot April 2, 2026 11:37
Contributor

Copilot AI left a comment


Pull request overview

This PR adapts ngram_match and hybrid_mtp_ngram in speculative decoding from the original, largely CPU/host-side logic to a two-stage GPU kernel implementation, reducing latency and Host<->Device copies. It is an operator performance optimization and interface adaptation on the spec_decode path.

Changes:

  • Add/replace the ngram_match CUDA implementation, split into two stage kernels: "count + candidate search" and "threshold truncation + write-back".
  • Update the NgramProposer and MTPProposer call sites to consume GPU inputs directly, adding/reusing GPU copy-buffer arguments to match the new operator signatures.
  • Update the related unit tests for the new signatures and switch the test device to GPU.

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 5 comments.

File Description
tests/operators/test_ngram_match.py Adapt to the new ngram_match signature; run on GPU and copy results back for checking
tests/operators/test_hybrid_mtp_ngram.py Adapt to the new hybrid_mtp_ngram signature; build inputs and assert on GPU
fastdeploy/spec_decode/ngram.py NgramProposer switched to GPU inputs plus new copy buffers (avoids .cpu()/.cuda() round trips)
fastdeploy/spec_decode/mtp.py MTPProposer call to hybrid_mtp_ngram adapted to the new signature; copy buffers cached
custom_ops/gpu_ops/speculate_decoding/ngram_match.cu New two-stage ngram_match CUDA kernel implementation; static op registration
custom_ops/gpu_ops/speculate_decoding/ngram_match.cc Remove the old host/CPU-style implementation
custom_ops/gpu_ops/speculate_decoding/ngram_match_core.cuh Extract the sliding-window ngram-search device inline function shared by both ops
custom_ops/gpu_ops/speculate_decoding/draft_model/ngram_match_mixed.cu Adapt hybrid_mtp_ngram to the two-stage GPU kernels and the new signature
custom_ops/gpu_ops/cpp_extensions.cc Update the C++ extension declarations to match the new parameters
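As background, the sliding-window ngram search that ngram_match_core.cuh factors out can be modeled in Python roughly as follows (a simplified sketch, not the actual device code; the function name and exact argument set are illustrative):

```python
def ngram_search(context, tail, max_ngram_size, min_ngram_size, max_draft_tokens):
    """Find the longest suffix n-gram of `tail` inside `context`; on a hit,
    propose up to max_draft_tokens of the tokens that followed the match."""
    # Try the longest suffix n-gram first, shrinking down to min_ngram_size.
    for n in range(min(max_ngram_size, len(tail)), min_ngram_size - 1, -1):
        pattern = tail[-n:]
        for i in range(len(context) - n + 1):  # slide the window over context
            if context[i:i + n] == pattern:
                return context[i + n:i + n + max_draft_tokens]
    return []  # no match: propose no draft tokens
```

In the real kernels this search runs per batch inside a thread block, scanning input_ids (and pre_ids for the hybrid variant) rather than Python lists.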

Comment on lines 24 to +25
def setUp(self):
paddle.set_device("cpu")
paddle.set_device("gpu")

Copilot AI Apr 2, 2026


Unconditionally switching to GPU in setUp() will fail outright on CPU-only builds or environments without compiled CUDA, and CI may be unable to run the test. Consider checking paddle.is_compiled_with_cuda() first in setUp(), calling skipTest when it is false, and only then calling paddle.set_device("gpu").
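The suggested guard could look roughly like this (a sketch using a stand-in predicate, since paddle.is_compiled_with_cuda() requires a Paddle install; in the real test the stand-in would be the Paddle call, followed by paddle.set_device("gpu")):

```python
import unittest

def cuda_is_available():
    # Stand-in for paddle.is_compiled_with_cuda(); in the real test this
    # would be the Paddle call itself.
    return False

class TestNgramMatch(unittest.TestCase):
    def setUp(self):
        if not cuda_is_available():
            self.skipTest("requires a CUDA-enabled paddle build")
        # paddle.set_device("gpu") would only execute past this point.

    def test_kernel(self):
        self.fail("never reached in CPU-only environments")

result = unittest.TextTestRunner(verbosity=0).run(
    unittest.defaultTestLoader.loadTestsFromTestCase(TestNgramMatch)
)
```

With this pattern the suite reports the test as skipped instead of erroring on CPU-only machines.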

Copilot generated this review using guidance from repository custom instructions.
Comment on lines 26 to 28
def setUp(self):
paddle.set_device("gpu")
self.max_bsz = 2

Copilot AI Apr 2, 2026


Same as above: an unconditional set_device("gpu") in setUp() makes this test unrunnable without a CUDA build or GPU. Consider adding a paddle.is_compiled_with_cuda() check with skipTest; if needed, also skip at finer granularity based on device capability.

Comment on lines +284 to +325
const int NTHREADS = 1024;

int *d_unprocessed_ptr;
cudaGetSymbolAddress(reinterpret_cast<void **>(&d_unprocessed_ptr),
d_ngram_unprocessed_batch_size);

ngram_count_and_find_candidate_kernel<NTHREADS>
<<<max_batch_size + 1, NTHREADS>>>(
input_ids.data<int64_t>(),
input_ids_len.data<int64_t>(),
token_ids_all.data<int64_t>(),
prompt_lens.data<int64_t>(),
step_idx.data<int64_t>(),
draft_token_num.data<int>(),
const_cast<int64_t *>(draft_tokens.data<int64_t>()),
const_cast<int64_t *>(draft_tokens_copy.data<int64_t>()),
const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
const_cast<int32_t *>(seq_lens_this_time_copy.data<int32_t>()),
const_cast<int32_t *>(seq_lens_encoder.data<int32_t>()),
const_cast<int32_t *>(seq_lens_decoder.data<int32_t>()),
const_cast<int64_t *>(max_dec_len.data<int64_t>()),
input_ids_stride,
max_model_len,
draft_tokens_stride,
max_ngram_size,
max_draft_tokens,
d_unprocessed_ptr,
max_batch_size);

ngram_truncate_candidate<NTHREADS><<<1, NTHREADS>>>(
step_idx.data<int64_t>(),
draft_token_num.data<int>(),
const_cast<int64_t *>(max_dec_len.data<int64_t>()),
const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
const_cast<int32_t *>(seq_lens_this_time_copy.data<int32_t>()),
const_cast<int64_t *>(draft_tokens.data<int64_t>()),
const_cast<int64_t *>(draft_tokens_copy.data<int64_t>()),
draft_tokens_stride,
max_batch_size,
max_draft_tokens,
tokennum_threshold,
d_unprocessed_ptr);

Copilot AI Apr 2, 2026


These CUDA kernel launches do not use the stream provided by Paddle via tensor.stream() (the current code launches on the default stream 0). When Paddle/phi's execution stream is not the default stream, this breaks ordering dependencies between operators and can lead to races or wrong results. Consider obtaining the stream from one of the input tensors (e.g. seq_lens_this_time.stream() / input_ids.stream()) and passing it explicitly in the launch configuration of both kernels.

Copilot uses AI. Check for mistakes.
Comment on lines +289 to +329
const int NTHREADS = 1024;

int *d_unprocessed_ptr;
cudaGetSymbolAddress(reinterpret_cast<void **>(&d_unprocessed_ptr),
d_mixed_unprocessed_batch_size);

mixed_count_and_find_candidate_kernel<NTHREADS>
<<<max_batch_size + 1, NTHREADS>>>(
input_ids.data<int64_t>(),
input_ids_len.data<int64_t>(),
pre_ids.data<int64_t>(),
step_idx.data<int64_t>(),
draft_token_num.data<int>(),
const_cast<int64_t *>(draft_tokens.data<int64_t>()),
const_cast<int64_t *>(draft_tokens_copy.data<int64_t>()),
const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
const_cast<int32_t *>(seq_lens_this_time_copy.data<int32_t>()),
const_cast<int32_t *>(seq_lens_decoder.data<int32_t>()),
const_cast<int64_t *>(max_dec_len.data<int64_t>()),
input_ids_stride,
pre_ids_stride,
draft_tokens_stride,
max_ngram_size,
min_ngram_size,
max_draft_tokens,
d_unprocessed_ptr,
max_batch_size);

mixed_truncate_candidate<NTHREADS><<<1, NTHREADS>>>(
    step_idx.data<int64_t>(),
    draft_token_num.data<int>(),
    const_cast<int32_t *>(seq_lens_decoder.data<int32_t>()),
    const_cast<int64_t *>(max_dec_len.data<int64_t>()),
    input_ids_stride,
    pre_ids_stride,
    const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
    const_cast<int32_t *>(seq_lens_this_time_copy.data<int32_t>()),
    const_cast<int64_t *>(draft_tokens.data<int64_t>()),
    const_cast<int64_t *>(draft_tokens_copy.data<int64_t>()),
    draft_tokens_stride,
    max_batch_size,
    max_ngram_size,
    min_ngram_size,
    max_draft_tokens,
    tokennum_threshold,
    d_unprocessed_ptr);
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same issue: the HybridMtpNgram kernel launches here are also not bound to Paddle's execution stream (the .stream() of input_ids / seq_lens_this_time is unused), which can race asynchronously with surrounding operators. Consider launching with <<<..., 0, cu_stream>>>, where cu_stream is obtained explicitly via tensor.stream().

Comment on lines 1225 to +1228
hybrid_mtp_ngram(
self.model_inputs["input_ids_cpu"],
self.model_inputs["input_ids_len"],
self.model_inputs["pre_ids"]._copy_to(device, True),
self.model_inputs["step_idx"].cpu(),
self.target_model_inputs["actual_draft_token_num"].cpu(),
draft_tokens,
seq_lens_this_time,
seq_lens_decoder,
self.model_inputs["max_dec_len"].cpu(),
self.model_inputs["input_ids"],
self.model_inputs["input_ids_len"].cuda(),
self.model_inputs["pre_ids"],

Copilot AI Apr 2, 2026


Calling .cuda() on self.model_inputs["input_ids_len"] on every invocation incurs an extra H2D copy and a new tensor allocation each step (input_ids_len is initialized with device="cpu" in ProposerInputBatch), which hurts per-step latency. Consider caching a one-time GPU buffer, as done for draft_tokens_copy, and refreshing that GPU buffer whenever input_ids_len is updated (or maintaining input_ids_len on the GPU outright).
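The caching pattern this suggests can be illustrated framework-free (a Paddle-free sketch; the class and method names are hypothetical, and the "device copy" is simulated with a plain Python list):

```python
class DeviceBufferCache:
    """Illustrative cache: copy the host value to the 'device' once,
    then reuse the cached copy until the host value is explicitly updated."""

    def __init__(self):
        self._gpu_copy = None
        self.h2d_copies = 0  # counts simulated host-to-device transfers

    def to_device(self, host_value):
        if self._gpu_copy is None:
            self._gpu_copy = list(host_value)  # simulated one-time H2D copy
            self.h2d_copies += 1
        return self._gpu_copy

    def update(self, host_value):
        self._gpu_copy = list(host_value)  # refresh only when the input changes
        self.h2d_copies += 1

cache = DeviceBufferCache()
cache.to_device([3, 5])  # first call performs the copy
cache.to_device([3, 5])  # later calls reuse the cached buffer
```

With real Paddle tensors, the update step would copy the new CPU values into the existing GPU buffer instead of allocating a fresh tensor every call.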
