
[None][feat] Skip-softmax-stat: stat-collecting FMHA cubin variants + JSON logger #13665

Draft
bobboli wants to merge 3 commits into NVIDIA:main from bobboli:feat/skip-softmax-stat

Conversation


@bobboli bobboli commented Apr 30, 2026

Summary

  • Replace the -DSKIP_SOFTMAX_STAT=ON build flag with runtime kernel selection. New _skipSoftmaxStat cubin variants ship next to the production _skipSoftmax cubins; the dispatcher picks between them via a Launch_params.enableSkipSoftmaxStat bit. Default perf path is unchanged.
  • New SkipSoftmaxAttentionConfig.stat_log_path (PyTorch backend). When set, an inline SkipSoftmaxStatLogger collects per-step / per-layer (skipped, total) block counts and dumps JSON on PyExecutor.shutdown(). Validator rejects the combination with CUDA graphs.
  • Replaces the legacy TRTLLM_PRINT_SKIP_SOFTMAX_STAT env var + per-layer print() path.
  • Adds a kernel microbench under tests/microbenchmarks/skip_softmax/ (LLM blog-16 shapes + Wan2.2 diffusion shapes) driven by a new bin/fmha.exe -skip-softmax-stat flag.

Three commits: kernel-side variant emission / dispatcher; PyTorch config + JSON logger; microbench.

bobboli added 3 commits April 30, 2026 07:10
Replace the project-wide -DSKIP_SOFTMAX_STAT=ON build flag (for the
FMHA prefill path) with a runtime-selected kernel variant. Stat-collecting
FMHA cubins now ship by default alongside the production _skipSoftmax
cubins; the dispatcher picks between them via a new
Launch_params.enableSkipSoftmaxStat flag, hashed into the
kernel-selection bitmask at bit 15.
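The bit-15 selection hash described above can be sketched as follows. This is an illustrative Python model only: the real logic lives in the C++ hashID overloads in fused_multihead_attention_v2.cpp, and `hash_id`/`base_hash` are hypothetical names; only the shift value (15) comes from this commit.

```python
# Hypothetical sketch of folding a boolean kernel-selection flag into a
# hash bitmask at a fixed shift. Only kEnableSkipSoftmaxStatShift = 15 is
# taken from the commit message; everything else is a placeholder.
K_ENABLE_SKIP_SOFTMAX_STAT_SHIFT = 15

def hash_id(base_hash: int, enable_skip_softmax_stat: bool) -> int:
    """Fold the stat flag into an existing kernel-selection hash."""
    return base_hash | (int(enable_skip_softmax_stat) << K_ENABLE_SKIP_SOFTMAX_STAT_SHIFT)

# Two otherwise-identical configs now hash to different kernel variants:
prod = hash_id(0b1010, enable_skip_softmax_stat=False)
stat = hash_id(0b1010, enable_skip_softmax_stat=True)
```

Because the flag occupies its own bit, the stat and production variants can never collide in the selection table.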

Kernel side:
- src/fmha/warpspec/kernel_traits.h: new ENABLE_SKIP_SOFTMAX_STAT_
  template parameter (and constexpr alias) on Kernel_traits and the
  Hopper_qgmma_e4m3_fp32 specialization.
- src/fmha/warpspec/{compute,epilogue}.h: replace #ifdef SKIP_SOFTMAX_STAT
  with `if constexpr (Kernel_traits::ENABLE_SKIP_SOFTMAX_STAT)`. The
  total_blocks/skipped_blocks Softmax_base members and the params-struct
  pointer fields become unconditional.
- src/fused_multihead_attention.{h,cpp} + demo_bert_params.h: drop the
  #ifdef around stat pointer fields; gate alloc/print/free in fmha.exe
  on a new -skip-softmax-stat CLI flag plumbed via
  Fused_multihead_attention_launch_params::enable_skip_softmax_stat.

setup.py kernel emitter:
- New 'enable_skip_softmax_stat' axis in kernel_spec; cubin name suffix
  '_skipSoftmaxStat' on top of '_skipSoftmax'; reverse-parser handles the
  substring relationship correctly. Enumeration extended to (skip,stat)
  pairs (False,False)/(True,False)/(True,True). DISABLE_SKIP_SOFTMAX
  env-var short-circuit removed.
- Ktraits instantiation strings (5 mask variants + the API code path)
  thread the new bool through.
- Metadata struct emitter writes mEnableSkipSoftmaxStat into all four
  variants of FusedMultiHeadAttentionKernelMetaInfoV2.
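The substring relationship the reverse-parser has to handle is that '_skipSoftmax' is a prefix of '_skipSoftmaxStat', so the longer suffix must be tested first. A minimal sketch (the function name and exact cubin-name layout are assumptions, not the actual setup.py code):

```python
def parse_cubin_flags(name: str) -> tuple[bool, bool]:
    """Return (skip_softmax, skip_softmax_stat) from a cubin name.

    '_skipSoftmax' is a substring of '_skipSoftmaxStat', so the longer
    marker must be tested first; otherwise every stat variant would be
    misclassified as a plain skip-softmax cubin.
    """
    if "_skipSoftmaxStat" in name:
        return True, True
    if "_skipSoftmax" in name:
        return True, False
    return False, False
```

This also matches the enumerated (skip, stat) pairs: (False, False), (True, False), and (True, True) are representable, while a stat-only cubin without skip-softmax is not.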

Dispatcher / runtime:
- contextFusedMultiHeadAttention/cubin/fmha_cubin.h: new
  mEnableSkipSoftmaxStat field (cubin/fmha_cubin.cpp regenerated by
  setup.py at build time).
- fused_multihead_attention_v2.{h,cpp}: kEnableSkipSoftmaxStatShift=15
  hash bit; both hashID overloads, hashFromParams, and the diagnostic
  dump consume the new flag.
- fmhaRunner.cpp: surface MHARunnerParams::enableSkipSoftmaxStat into
  Launch_params.enableSkipSoftmaxStat.
- common/attentionOp.{cpp,h}: prefill FMHA path drops its
  SKIP_SOFTMAX_STAT-guarded fields and the throw-on-env path.
  mSkipSoftmaxTotalBlocks/Skipped + new mEnableSkipSoftmaxStat are
  unconditional and default-nullptr/false.
- thop/attentionOp.cpp: skip_softmax_stat std::optional<Tensor>
  presence drives mEnableSkipSoftmaxStat (Python passes None when stats
  are off so .has_value() is the natural gate).

Build wiring:
- cpp/CMakeLists.txt: keep the SKIP_SOFTMAX_STAT option but downgrade it
  to "legacy XQA-decoder Skip-Softmax stat path". The FMHA prefill path
  no longer depends on it.
- cpp/kernels/fmha_v2/Makefile: replace the SKIP_SOFTMAX_STAT comment
  with a pointer to the new -skip-softmax-stat fmha.exe CLI flag.
- common/envUtils.{h,cpp}: remove getEnvPrintSkipSoftmaxStat
  (unused after the prefill path drops the env-var fallback).

XQA decoder kernels are intentionally untouched — their existing
SKIP_SOFTMAX_STAT-gated stat path remains opt-in via the build flag.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
…Torch backend

Wire the runtime-selected stat-collecting FMHA kernel variants into a
single config dial: SkipSoftmaxAttentionConfig.stat_log_path. When set,
the per-step / per-layer (skipped, total) block counters land in a
JSON dump at executor shutdown. Replaces the legacy
TRTLLM_PRINT_SKIP_SOFTMAX_STAT env var + per-layer print() path.

llmapi/llm_args.py:
- New stat_log_path: Optional[str] field on SkipSoftmaxAttentionConfig
  (co-located with threshold_scale_factor and target_sparsity).
- Cross-config validator on TorchLlmArgs rejects the combination of
  stat_log_path + cuda_graph_config.batch_sizes — per-layer .zero_()
  and host-side stat readback both break graph capture. Emits a debug
  warning when stat_log_path is set so users do not accidentally
  leave it on for benchmark runs.
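The cross-config rejection can be sketched as a plain function (hedged: the real check is a validator on TorchLlmArgs, and `validate_stat_log_path` plus its argument names are invented for illustration):

```python
# Hypothetical standalone version of the cross-config check; the actual
# implementation is a validator on TorchLlmArgs, not this function.
def validate_stat_log_path(stat_log_path, cuda_graph_batch_sizes):
    """Reject stat collection when CUDA graph capture is configured."""
    if stat_log_path is not None and cuda_graph_batch_sizes:
        raise ValueError(
            "stat_log_path is incompatible with CUDA graphs: per-layer "
            ".zero_() and host-side stat readback break graph capture."
        )
```

Either option alone is fine; only the combination is rejected.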

_torch/attention_backend/trtllm.py:
- New SkipSoftmaxStatLogger class (inlined): step → layer → {total,
  skipped} dict, auto-detecting step boundaries from layer-index
  ordering. Per-step metadata (num_contexts / num_ctx_tokens /
  num_gen_tokens) is recorded once when layer 0 fires for the step.
  dump() writes JSON via Path.write_text(); an atexit fallback covers
  crashes that bypass PyExecutor.shutdown().
- Module-level singleton routes records from per-layer wrappers, so
  step counters stay coherent.
- TrtllmAttentionWrapper:
  * skip_softmax_stat tensor allocated lazily — None on the production
    path so we do not reserve 8 bytes of VRAM per layer for an unused
    buffer.
  * skip_softmax_stat_enabled resolved from
    sparse_attention_config.stat_log_path during plan(); zeros and
    reads back the GPU counters only when enabled. The tensor is now
    passed to torch.ops.trtllm.attention as None when stats are off.
  * Replace the legacy print(SKIP_SOFTMAX_STAT: layer{i}: ...) block
    with a logger.record() call.
- TrtLLMAttention.forward stamps per-step ctx/gen meta on the
  singleton logger before each wrapper.run().
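A minimal sketch of the logger's core mechanics, under stated assumptions: the class name and record/dump methods follow the commit message, but the internal layout, the step-boundary heuristic shown here (layer index wrapping back), and the per-step metadata handling are simplified reconstructions, not the actual trtllm.py code.

```python
import atexit
import json
from pathlib import Path

class SkipSoftmaxStatLogger:
    """Sketch of the step -> layer -> {skipped, total} stat logger.

    Step boundaries are auto-detected from layer-index ordering: a record
    whose layer index is <= the previous one starts a new step.
    """

    def __init__(self, path: str):
        self.path = path
        self.steps = []            # one {layer_idx: counters} dict per step
        self._last_layer = None
        atexit.register(self.dump)  # fallback if shutdown() is bypassed

    def record(self, layer_idx: int, skipped: int, total: int):
        if self._last_layer is None or layer_idx <= self._last_layer:
            self.steps.append({})   # layer index wrapped: new step begins
        self._last_layer = layer_idx
        self.steps[-1][layer_idx] = {"skipped": skipped, "total": total}

    def dump(self):
        Path(self.path).write_text(json.dumps(self.steps, indent=2))
```

A module-level instance of such a class is enough to keep step counters coherent across per-layer wrappers, since every wrapper routes its records through the same object.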

_torch/pyexecutor/py_executor.py:
- PyExecutor.shutdown() flushes the stat logger before tearing down
  the worker thread / resource managers, so users get a deterministic
  dump under trtllm-serve. atexit remains as a fallback.

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
…bench

Two-pass driver under tests/microbenchmarks/skip_softmax/:

- llm_configs.py: Qwen3-30B-A3B GQA proxy (q_heads=64, kv_heads=4,
  head_dim=128). Prefill bs=1 x seq{16k, 64k} x dtype{bf16, e4m3} x
  causal mask. Decode bs=64 x kv{16k, 64k} x q=1 x dtype{bf16, e4m3}
  x causal mask. threshold_scale_factor sweep lifted directly from
  blog 16's "Performance Benchmark" table.
- diffusion_configs.py: Wan2.2-T2V-A14B attn1 proxy (24 heads,
  head_dim=128, bf16, bidirectional mask) over seq{16k, 32k, 41k,
  65k}. target_sparsity is intentionally ignored (it is a
  per-model calibration knob and does not generalise to random
  data); threshold_scale_factor is swept over a log-spaced range
  wide enough to bracket 0% to 99% achieved sparsity.
- bench_skip_softmax.py: shells out to cpp/kernels/fmha_v2/bin/fmha.exe.
  Pass 1 enables -skip-softmax-stat (the new CLI flag) so the kernel
  picks the _skipSoftmaxStat cubin variant and parses
  "Skip-Softmax .: skipped / total" from stdout. Pass 2 omits the
  flag so the production _skipSoftmax cubin is selected — no
  atomic-counter overhead — and parses "Elapsed ......: <us>" for
  median latency. Speedup is computed against the threshold=0 baseline
  per config.
- generate_report.py: joins the two CSVs by (config, threshold),
  emits report.md with one table per config and a matplotlib
  speedup-vs-achieved-sparsity scatter (PNG).
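The log-spaced threshold sweep used by the diffusion configs can be sketched like this (the bounds and point count here are illustrative placeholders, not the benchmark's actual values):

```python
import math

def log_spaced(lo: float, hi: float, n: int) -> list[float]:
    """n values spaced evenly in log10 between lo and hi, inclusive."""
    step = (math.log10(hi) - math.log10(lo)) / (n - 1)
    return [10 ** (math.log10(lo) + i * step) for i in range(n)]

# Hypothetical bounds: wide enough that the smallest threshold skips
# almost nothing and the largest skips almost everything.
thresholds = log_spaced(1e-3, 10.0, 9)
```

Log spacing matters here because achieved sparsity responds roughly multiplicatively to the threshold, so linear spacing would waste most points at one end of the sparsity range.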

The microbench drives the kernel directly via fmha.exe (extended in
the previous commit with -skip-softmax-stat) rather than reaching
through the LLM pipeline.
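The two stdout parses can be sketched with regexes. Hedged: the exact line formats emitted by fmha.exe are assumptions modeled on the quoted fragments above ("Skip-Softmax .: skipped / total" and "Elapsed ......: <us>"), so the patterns are deliberately loose about the dot padding.

```python
import re

# Loose patterns for the two fmha.exe output lines quoted in the commit
# message; the dot padding between label and colon is treated as variable.
STAT_RE = re.compile(r"Skip-Softmax\s*\.*\s*:\s*(\d+)\s*/\s*(\d+)")
ELAPSED_RE = re.compile(r"Elapsed\s*\.*\s*:\s*([\d.]+)")

def parse_stat(stdout: str):
    """Return (skipped, total) block counts, or None if absent."""
    m = STAT_RE.search(stdout)
    return (int(m.group(1)), int(m.group(2))) if m else None

def parse_elapsed_us(stdout: str):
    """Return the elapsed time in microseconds, or None if absent."""
    m = ELAPSED_RE.search(stdout)
    return float(m.group(1)) if m else None
```

Keeping the stat parse and the latency parse separate mirrors the two-pass design: pass 1 (stat cubin) only needs `parse_stat`, pass 2 (production cubin) only needs `parse_elapsed_us`.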

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
@bobboli bobboli force-pushed the feat/skip-softmax-stat branch from 280ef8d to 425cf1f on April 30, 2026 14:11