[None][feat] Skip-softmax-stat: stat-collecting FMHA cubin variants + JSON logger #13665
Draft
bobboli wants to merge 3 commits into NVIDIA:main
Replace the project-wide -DSKIP_SOFTMAX_STAT=ON build flag (for the
FMHA prefill path) with a runtime-selected kernel variant. Stat-collecting
FMHA cubins now ship by default alongside the production _skipSoftmax
cubins; the dispatcher picks between them via a new
Launch_params.enableSkipSoftmaxStat flag, hashed into the
kernel-selection bitmask at bit 15.
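As a rough illustration, here is the shape of that selection hash in Python (the real hashID overloads are C++, described under "Dispatcher / runtime" below; everything except the bit-15 shift and the flag itself is a placeholder):

```python
# Python sketch of the selection-hash composition; the real hashID overloads
# live in fused_multihead_attention_v2.cpp. Only the shift value 15 and the
# flag come from this change; everything else here is a placeholder.
K_ENABLE_SKIP_SOFTMAX_STAT_SHIFT = 15  # mirrors kEnableSkipSoftmaxStatShift

def hash_id(base_bits: int, enable_skip_softmax_stat: bool) -> int:
    h = base_bits  # dtype / head-size / mask-type bits, elided here
    if enable_skip_softmax_stat:
        h |= 1 << K_ENABLE_SKIP_SOFTMAX_STAT_SHIFT
    return h
```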
Kernel side:
- src/fmha/warpspec/kernel_traits.h: new ENABLE_SKIP_SOFTMAX_STAT_
template parameter (and constexpr alias) on Kernel_traits and the
Hopper_qgmma_e4m3_fp32 specialization.
- src/fmha/warpspec/{compute,epilogue}.h: replace #ifdef SKIP_SOFTMAX_STAT
with `if constexpr (Kernel_traits::ENABLE_SKIP_SOFTMAX_STAT)`. The
total_blocks/skipped_blocks Softmax_base members and the params-struct
pointer fields become unconditional.
- src/fused_multihead_attention.{h,cpp} + demo_bert_params.h: drop the
#ifdef around stat pointer fields; gate alloc/print/free in fmha.exe
on a new -skip-softmax-stat CLI flag plumbed via
Fused_multihead_attention_launch_params::enable_skip_softmax_stat.
setup.py kernel emitter:
- New 'enable_skip_softmax_stat' axis in kernel_spec; cubin name suffix
  '_skipSoftmaxStat' on top of '_skipSoftmax'; the reverse-parser handles
  the substring relationship correctly (see the sketch after this list).
  Enumeration extended to (skip, stat) pairs
  (False, False) / (True, False) / (True, True). The DISABLE_SKIP_SOFTMAX
  env-var short-circuit is removed.
- Ktraits instantiation strings (5 mask variants + the API code path)
thread the new bool through.
- Metadata struct emitter writes mEnableSkipSoftmaxStat into all four
variants of FusedMultiHeadAttentionKernelMetaInfoV2.
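The enumeration and naming scheme is easiest to see in a small sketch. The helper names below are illustrative, not the actual setup.py functions; the load-bearing detail is the suffix ordering in the reverse parse:

```python
# Sketch of the (skip, stat) axis and the cubin-name round-trip. The reverse
# parse must test the longer suffix first, because '_skipSoftmax' is a
# substring of '_skipSoftmaxStat'.
ENUMERATED_PAIRS = [(False, False), (True, False), (True, True)]

def cubin_suffix(skip_softmax: bool, stat: bool) -> str:
    if stat:  # the stat variant implies the skip-softmax path
        return "_skipSoftmaxStat"
    return "_skipSoftmax" if skip_softmax else ""

def parse_cubin_name(name: str) -> tuple:
    if "_skipSoftmaxStat" in name:
        return (True, True)
    if "_skipSoftmax" in name:
        return (True, False)
    return (False, False)

# Round-trip check over the three enumerated pairs.
assert all(parse_cubin_name(cubin_suffix(s, t)) == (s, t)
           for s, t in ENUMERATED_PAIRS)
```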
Dispatcher / runtime:
- contextFusedMultiHeadAttention/cubin/fmha_cubin.h: new
mEnableSkipSoftmaxStat field (cubin/fmha_cubin.cpp regenerated by
setup.py at build time).
- fused_multihead_attention_v2.{h,cpp}: kEnableSkipSoftmaxStatShift=15
hash bit; both hashID overloads, hashFromParams, and the diagnostic
dump consume the new flag.
- fmhaRunner.cpp: surface MHARunnerParams::enableSkipSoftmaxStat into
Launch_params.enableSkipSoftmaxStat.
- common/attentionOp.{cpp,h}: prefill FMHA path drops its
SKIP_SOFTMAX_STAT-guarded fields and the throw-on-env path.
mSkipSoftmaxTotalBlocks/Skipped + new mEnableSkipSoftmaxStat are
unconditional and default-nullptr/false.
- thop/attentionOp.cpp: the presence of the skip_softmax_stat
  std::optional<Tensor> drives mEnableSkipSoftmaxStat (Python passes None
  when stats are off, so .has_value() is the natural gate).
Build wiring:
- cpp/CMakeLists.txt: keep the SKIP_SOFTMAX_STAT option but downgrade it
to "legacy XQA-decoder Skip-Softmax stat path". The FMHA prefill path
no longer depends on it.
- cpp/kernels/fmha_v2/Makefile: replace the SKIP_SOFTMAX_STAT comment
with a pointer to the new -skip-softmax-stat fmha.exe CLI flag.
- common/envUtils.{h,cpp}: remove getEnvPrintSkipSoftmaxStat
(unused after the prefill path drops the env-var fallback).
XQA decoder kernels are intentionally untouched — their existing
SKIP_SOFTMAX_STAT-gated stat path remains opt-in via the build flag.
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
…Torch backend
Wire the runtime-selected stat-collecting FMHA kernel variants into a
single config dial: SkipSoftmaxAttentionConfig.stat_log_path. When set,
the per-step / per-layer (skipped, total) block counters land in a
JSON dump at executor shutdown. Replaces the legacy
TRTLLM_PRINT_SKIP_SOFTMAX_STAT env var + per-layer print() path.
llmapi/llm_args.py:
- New stat_log_path: Optional[str] field on SkipSoftmaxAttentionConfig
(co-located with threshold_scale_factor and target_sparsity).
- Cross-config validator on TorchLlmArgs rejects the combination of
stat_log_path + cuda_graph_config.batch_sizes — per-layer .zero_()
and host-side stat readback both break graph capture. Emits a debug
warning when stat_log_path is set so users do not accidentally
leave it on for benchmark runs.
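A minimal pydantic-style sketch of that validator follows; apart from stat_log_path, threshold_scale_factor, target_sparsity, and cuda_graph_config.batch_sizes, the model shapes and defaults are assumptions:

```python
# Minimal sketch of the cross-config check, assuming pydantic v2 models as
# used in llm_args.py. The real TorchLlmArgs carries many more fields.
from typing import List, Optional
from pydantic import BaseModel, model_validator

class SkipSoftmaxAttentionConfig(BaseModel):
    threshold_scale_factor: float = 0.0  # default is an assumption
    target_sparsity: Optional[float] = None
    stat_log_path: Optional[str] = None  # the new field

class CudaGraphConfig(BaseModel):
    batch_sizes: Optional[List[int]] = None

class TorchLlmArgs(BaseModel):
    sparse_attention_config: Optional[SkipSoftmaxAttentionConfig] = None
    cuda_graph_config: Optional[CudaGraphConfig] = None

    @model_validator(mode="after")
    def _check_stat_log_path(self):
        sa, cg = self.sparse_attention_config, self.cuda_graph_config
        if sa is not None and sa.stat_log_path:
            if cg is not None and cg.batch_sizes:
                # Per-layer .zero_() and host-side readback break capture.
                raise ValueError("stat_log_path is incompatible with "
                                 "cuda_graph_config.batch_sizes")
            # The real validator also emits a debug warning here so the stat
            # path is not accidentally left on for benchmark runs.
        return self
```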
_torch/attention_backend/trtllm.py:
- New SkipSoftmaxStatLogger class (inlined): step → layer → {total,
  skipped} dict, auto-detecting step boundaries from layer-index
  ordering (see the sketch after this section). Per-step metadata
  (num_contexts / num_ctx_tokens / num_gen_tokens) is recorded once when
  layer 0 fires for the step. dump() writes JSON via Path.write_text();
  an atexit fallback covers crashes that bypass PyExecutor.shutdown().
- Module-level singleton routes records from per-layer wrappers, so
step counters stay coherent.
- TrtllmAttentionWrapper:
* skip_softmax_stat tensor allocated lazily — None on the production
path so we do not reserve 8 bytes of VRAM per layer for an unused
buffer.
* skip_softmax_stat_enabled resolved from
sparse_attention_config.stat_log_path during plan(); zeros and
reads back the GPU counters only when enabled. The tensor is now
passed to torch.ops.trtllm.attention as None when stats are off.
* Replace the legacy print(SKIP_SOFTMAX_STAT: layer{i}: ...) block
with a logger.record() call.
- TrtLLMAttention.forward stamps per-step ctx/gen meta on the
singleton logger before each wrapper.run().
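The sketch referenced above, condensing the logger to its essentials; method names, the step-detection rule, and the JSON layout are assumptions from this description rather than the literal implementation:

```python
# Condensed sketch of the step/layer stat logger described above.
import atexit
import json
from pathlib import Path
from typing import Optional

class SkipSoftmaxStatLogger:
    def __init__(self, path: str):
        self._path = Path(path)
        self._steps: list = []         # one {"meta": ..., "layers": ...} per step
        self._last_layer = -1
        self._pending_meta: dict = {}  # stamped by forward() before wrapper.run()
        atexit.register(self.dump)     # fallback if shutdown() is bypassed

    def set_step_meta(self, num_contexts: int, num_ctx_tokens: int,
                      num_gen_tokens: int) -> None:
        self._pending_meta = {"num_contexts": num_contexts,
                              "num_ctx_tokens": num_ctx_tokens,
                              "num_gen_tokens": num_gen_tokens}

    def record(self, layer_idx: int, total: int, skipped: int) -> None:
        # A non-increasing layer index means a new forward step has started.
        if not self._steps or layer_idx <= self._last_layer:
            self._steps.append({"meta": self._pending_meta, "layers": {}})
        self._last_layer = layer_idx
        self._steps[-1]["layers"][layer_idx] = {"total": total,
                                                "skipped": skipped}

    def dump(self) -> None:
        if self._steps:
            self._path.write_text(json.dumps(self._steps, indent=2))

# Module-level singleton so counters stay coherent across per-layer wrappers.
_STAT_LOGGER: Optional[SkipSoftmaxStatLogger] = None
```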
_torch/pyexecutor/py_executor.py:
- PyExecutor.shutdown() flushes the stat logger before tearing down
the worker thread / resource managers, so users get a deterministic
dump under trtllm-serve. atexit remains as a fallback.
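Roughly, the intended ordering (a sketch reusing the _STAT_LOGGER singleton from the previous sketch; everything else about the real PyExecutor is elided):

```python
# Sketch of the deterministic flush on shutdown; the real shutdown() also
# joins the worker thread and releases resource managers.
class PyExecutor:
    def shutdown(self):
        # Flush skip-softmax stats first so the JSON dump lands reliably
        # under trtllm-serve; atexit remains only as a crash fallback.
        if _STAT_LOGGER is not None:
            _STAT_LOGGER.dump()
        # ... worker-thread / resource-manager teardown elided ...
```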
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
…bench
Two-pass driver under tests/microbenchmarks/skip_softmax/:
- llm_configs.py: Qwen3-30B-A3B GQA proxy (q_heads=64, kv_heads=4,
head_dim=128). Prefill bs=1 x seq{16k, 64k} x dtype{bf16, e4m3} x
causal mask. Decode bs=64 x kv{16k, 64k} x q=1 x dtype{bf16, e4m3}
x causal mask. threshold_scale_factor sweep lifted directly from
blog 16's "Performance Benchmark" table.
- diffusion_configs.py: Wan2.2-T2V-A14B attn1 proxy (24 heads,
head_dim=128, bf16, bidirectional mask) over seq{16k, 32k, 41k,
  65k}. target_sparsity is deliberately ignored: it is a per-model
  calibration knob and does not generalise to random data. Instead,
  threshold_scale_factor is swept over a log-spaced range wide
  enough to bracket 0% to 99% achieved sparsity.
- bench_skip_softmax.py: shells out to cpp/kernels/fmha_v2/bin/fmha.exe.
Pass 1 enables -skip-softmax-stat (the new CLI flag) so the kernel
picks the _skipSoftmaxStat cubin variant and parses
"Skip-Softmax .: skipped / total" from stdout. Pass 2 omits the
flag so the production _skipSoftmax cubin is selected — no
atomic-counter overhead — and parses "Elapsed ......: <us>" for
median latency. Speedup is computed against the threshold=0 baseline
per config.
- generate_report.py: joins the two CSVs by (config, threshold),
emits report.md with one table per config and a matplotlib
speedup-vs-achieved-sparsity scatter (PNG).
The microbench drives the kernel directly via fmha.exe (extended in
the previous commit with -skip-softmax-stat) rather than reaching
through the LLM pipeline.
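A condensed sketch of that two-pass flow; the regexes are inferred from the quoted stdout formats and the per-config CLI arguments are elided, so treat everything except the binary path and the -skip-softmax-stat flag as assumption:

```python
# Condensed sketch of the two-pass driver in bench_skip_softmax.py.
import re
import subprocess

FMHA_EXE = "cpp/kernels/fmha_v2/bin/fmha.exe"

def stat_pass(args):
    # Pass 1: the flag makes the dispatcher pick the _skipSoftmaxStat cubin.
    out = subprocess.run([FMHA_EXE, *args, "-skip-softmax-stat"],
                         capture_output=True, text=True, check=True).stdout
    m = re.search(r"Skip-Softmax\s*\.*\s*:\s*(\d+)\s*/\s*(\d+)", out)
    skipped, total = int(m.group(1)), int(m.group(2))
    return skipped / total  # achieved sparsity

def perf_pass(args):
    # Pass 2: no flag, so the production _skipSoftmax cubin runs with no
    # atomic-counter overhead.
    out = subprocess.run([FMHA_EXE, *args],
                         capture_output=True, text=True, check=True).stdout
    return float(re.search(r"Elapsed\s*\.*\s*:\s*([\d.]+)", out).group(1))

def speedup(latency_us, baseline_us):
    # Baseline is the threshold_scale_factor == 0 run of the same config.
    return baseline_us / latency_us
```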
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
Summary
- Replaces the -DSKIP_SOFTMAX_STAT=ON build flag with runtime kernel selection. New _skipSoftmaxStat cubin variants ship next to the production _skipSoftmax cubins; the dispatcher picks between them via a Launch_params.enableSkipSoftmaxStat bit. The default perf path is unchanged.
- SkipSoftmaxAttentionConfig.stat_log_path (PyTorch backend): when set, an inline SkipSoftmaxStatLogger collects per-step / per-layer (skipped, total) block counts and dumps JSON on PyExecutor.shutdown(). A validator rejects the combination with CUDA graphs.
- Replaces the legacy TRTLLM_PRINT_SKIP_SOFTMAX_STAT env var + per-layer print() path.
- Microbench under tests/microbenchmarks/skip_softmax/ (blog-16 LLM shapes + Wan2.2 diffusion shapes), driven by the new bin/fmha.exe -skip-softmax-stat flag.
- Three commits: kernel-side variant emission / dispatcher; PyTorch config + JSON logger; microbench.