[None][chore] Refactor attention forward context #13662
yuxianq wants to merge 7 commits into NVIDIA:main from
Conversation
|
/bot run --disable-fail-fast |
|
PR_Github #46362 [ run ] triggered by Bot. Commit: |
|
PR_Github #46362 [ run ] completed with state
|
65fbeef to
46cd18f
Compare
|
/bot run --disable-fail-fast |
|
/bot help |
GitHub Bot Help
Provide a user friendly way for developers to interact with a Jenkins server. See details below for each supported subcommand.
Details
run
Launch build/test pipelines. All previously running jobs will be killed.
kill
Kill all running builds associated with pull request.
skip
Skip testing for latest commit on pull request.
reuse-pipeline
Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
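For example (subcommand usage inferred from the help above; only the run form appears verbatim in this thread):
/bot run --disable-fail-fast
/bot run --disable-fail-fast --add-multi-gpu-test
/bot kill
/bot reuse-pipeline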
|
PR_Github #46376 [ run ] triggered by Bot. Commit: |
|
PR_Github #46376 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #46479 [ run ] triggered by Bot. Commit: |
|
PR_Github #46479 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #46562 [ run ] triggered by Bot. Commit: |
|
PR_Github #46562 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #46573 [ run ] triggered by Bot. Commit: |
|
PR_Github #46573 [ run ] completed with state |
|
/bot run --disable-fail-fast --add-multi-gpu-test |
|
PR_Github #46593 [ run ] triggered by Bot. Commit: |
|
PR_Github #46593 [ run ] completed with state
|
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
e2d1938 to
81d33fe
Compare
|
/bot run --disable-fail-fast --add-multi-gpu-test |
|
PR_Github #46718 [ run ] triggered by Bot. Commit: |
📝 Walkthrough
This PR introduces an AttentionForwardContext-based refactoring of attention forwarding.
Changes
Attention Forward Context Refactoring
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
tensorrt_llm/_torch/attention_backend/sparse/dsa.py (1)
1942-1993: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
Grow the RoPE table before invoking the DSA MLA append kernel.
This override now reads self.rotary_cos_sin directly, but unlike TrtllmAttention.mla_rope_append_paged_kv_assign_q() it never calls _ensure_rope_table_size(). Once max_seq_len grows past the constructor-time table size, this path can launch the kernel with an undersized cos/sin buffer.
Suggested fix

     def mla_rope_append_paged_kv_assign_q(
         self,
         q: torch.Tensor,
         latent_cache: torch.Tensor,
         metadata: DSAtrtllmAttentionMetadata,
         is_generation: bool = False,
         **kwargs,
     ) -> None:
         """Apply RoPE, append latent cache to paged KV, and assign query for MLA."""
         if is_generation:
             cached_token_indptr = metadata.gen_cached_token_indptr
             kv_indptr = metadata.gen_kv_indptr
             num_seqs = metadata.num_generations
             max_seq_len = metadata.max_gen_seq_len
             block_offsets = metadata.kv_cache_block_offsets[:, metadata.num_contexts:]
         else:
             cached_token_indptr = metadata.ctx_cached_token_indptr
             kv_indptr = metadata.ctx_kv_indptr
             num_seqs = metadata.num_contexts
             max_seq_len = metadata.max_ctx_seq_len
             block_offsets = metadata.kv_cache_block_offsets

         assert self.is_mla_enable and self.mla_params is not None
         assert metadata.kv_cache_manager is not None
+        self._ensure_rope_table_size(metadata.kv_cache_manager.max_seq_len)

         sink_token_length = 0
         beam_width = 1

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In tensorrt_llm/_torch/attention_backend/sparse/dsa.py around lines 1942-1993: the MLA RoPE path in mla_rope_append_paged_kv_assign_q reads self.rotary_cos_sin directly and can launch the DSA kernel with an undersized cos/sin buffer when max_seq_len has grown; before calling torch.ops.trtllm.mla_rope_append_paged_kv_assign_q, call the existing rope table grow helper (e.g. self._ensure_rope_table_size(max_seq_len) or the equivalent method used by TrtllmAttention.mla_rope_append_paged_kv_assign_q) using the computed max_seq_len so self.rotary_cos_sin is large enough, keeping the rest of the call and parameters unchanged.

tensorrt_llm/_torch/attention_backend/trtllm.py (1)
1586-1599: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
Treat out_scale_sf as a quantized-output request too.
When output is absent, auto-allocation only enters the quantized path if ctx.out_scale is set. The NVFP4 path consumes ctx.out_scale_sf instead, so callers that rely on auto-allocation can silently fall back to a dense output buffer and skip the NVFP4 output path.
Suggested fix

     if output is None:
         # Output is not provided.
         is_gen_only = ctx.attention_input_type == AttentionInputType.generation_only
         outputs = self.create_output(
             q,
-            is_quantize_output=ctx.out_scale is not None,
+            is_quantize_output=(ctx.out_scale is not None
+                                or ctx.out_scale_sf is not None),
             metadata=metadata,
             attention_mask=ctx.attention_mask,
             use_paged_context_fmha=use_paged_context_fmha,
             is_mla_enable=self.is_mla_enable,
             is_gen_only=is_gen_only,

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In tensorrt_llm/_torch/attention_backend/trtllm.py around lines 1586-1599: the current check for auto-allocation only treats the presence of ctx.out_scale as an indicator for quantized output, but it should also consider ctx.out_scale_sf to correctly handle NVFP4 paths. In the code section around the output allocation logic using self.create_output, update the condition for is_quantize_output to check if either ctx.out_scale or ctx.out_scale_sf is not None. This ensures the NVFP4 quantized output path is properly triggered during auto-allocation.

tensorrt_llm/_torch/attention_backend/vanilla.py (1)
308-318: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
Pass ctx.attention_window_size through the no-cache path.
The cached branch honors the context field via _single_request_forward() (line 429), but no_kv_cache_forward() (line 308) does not accept it. Line 397 only passes attention_mask, dropping ctx.attention_window_size. Consequently, flash_attn_varlen_func() runs with the commented-out infinite window (window_size=(-1, -1)) instead of the configured window size, causing sliding-window no-cache requests to produce full-causal outputs instead of local attention.
To fix, add attention_window_size: Optional[int] = None to the no_kv_cache_forward() signature, pass it from forward() (line 397), and use it in the flash_attn_varlen_func() call (line 359) with window_size=(attention_window_size - 1, 0) when set.
Also applies to: 383–397
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In tensorrt_llm/_torch/attention_backend/vanilla.py around lines 308-318: the no-cache path must accept and propagate the context window size: add an optional parameter attention_window_size: Optional[int] = None to no_kv_cache_forward and update forward to pass ctx.attention_window_size into that call (matching how _single_request_forward uses it); then, in the flash_attn_varlen_func invocation inside no_kv_cache_forward, replace the current default/infinite window with window_size=(attention_window_size - 1, 0) when attention_window_size is set (otherwise keep the existing fallback), ensuring flash_attn_varlen_func receives the correct sliding-window size for local attention.
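As a companion to the comment above, a minimal sketch of the suggested change, not the backend's actual code: the real no_kv_cache_forward() takes more parameters, and the flash_attn import is an assumption about what the vanilla backend already depends on.

```python
# Minimal sketch of the suggested fix; the parameter list is reduced for
# clarity and the flash_attn dependency is assumed, not confirmed by this PR.
from typing import Optional

import torch
from flash_attn import flash_attn_varlen_func


def no_kv_cache_forward(q: torch.Tensor,
                        k: torch.Tensor,
                        v: torch.Tensor,
                        cu_seqlens: torch.Tensor,
                        max_seqlen: int,
                        attention_window_size: Optional[int] = None) -> torch.Tensor:
    # Restrict attention to the last `attention_window_size` tokens when a
    # sliding window is configured; otherwise keep the infinite window.
    window_size = ((attention_window_size - 1, 0)
                   if attention_window_size is not None else (-1, -1))
    return flash_attn_varlen_func(q, k, v,
                                  cu_seqlens_q=cu_seqlens,
                                  cu_seqlens_k=cu_seqlens,
                                  max_seqlen_q=max_seqlen,
                                  max_seqlen_k=max_seqlen,
                                  causal=True,
                                  window_size=window_size)
```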
🧹 Nitpick comments (1)
tests/unittest/_torch/attention/sparse/test_rocketkv.py (1)
263-266: ⚡ Quick win
Add one assertion for a non-default context path.
Both new calls pass AttentionForwardContext() with all defaults, so they only verify signature plumbing. They would still pass if a non-default field were dropped during propagation, or if merge_attention_forward_context(...) regressed on its error paths. A small backend-level test that goes through TrtllmAttention.forward (or the merge helper directly) with one non-default field and one mixed ctx + legacy-kwargs rejection case would cover the risky part of this refactor much better.
As per coding guidelines, "Coverage expectations: Assess whether new/changed tests cover happy path, important edge cases, and failure modes relevant to the feature or fix."
Also applies to: 529-531
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unittest/_torch/attention/sparse/test_rocketkv.py` around lines 263 - 266, Add a test that exercises a non-default AttentionForwardContext and the mixed-ctx/legacy kwargs rejection: construct an AttentionForwardContext with at least one non-default field set (e.g., a non-empty torch device/flag or an explicit timestamp/seed field used by merge_attention_forward_context), call trtllm_attn.sparse_kv_predict (or TrtllmAttention.forward) with that ctx and assert the propagated/merged values are preserved, then add a second case that passes both a ctx and a legacy kwarg (to trigger merge_attention_forward_context rejection) and assert it raises the expected error; reference AttentionForwardContext, trtllm_attn.sparse_kv_predict, TrtllmAttention.forward, and merge_attention_forward_context to locate where to modify tests.
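A rough sketch of the kind of test the nitpick above asks for; the import path, the AttentionForwardContext field used, and the exception type raised for unknown kwargs are assumptions, not the PR's actual API.

```python
# Hypothetical test sketch; import path, field names, and the exception type
# for unknown kwargs are assumptions based on the review discussion above.
import pytest

from tensorrt_llm._torch.attention_backend.interface import (
    AttentionForwardContext, merge_attention_forward_context)


def test_merge_preserves_non_default_field():
    ctx = AttentionForwardContext(attention_window_size=128)
    merged = merge_attention_forward_context(ctx)
    assert merged.attention_window_size == 128


def test_merge_rejects_unknown_legacy_kwarg():
    with pytest.raises((TypeError, ValueError)):
        merge_attention_forward_context(AttentionForwardContext(),
                                        not_a_real_forward_kwarg=1)
```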
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tensorrt_llm/_torch/attention_backend/flashinfer.py`:
- Around line 756-759: Replace the assert with an explicit runtime check that
raises a ValueError when ctx.attention_mask == CustomAttentionMask.CUSTOM and
attention_mask_data is None: in the block around attention_mask_data,
ctx.attention_mask, CustomAttentionMask.CUSTOM, and attention_mask_type (set to
int(AttentionMaskType.custom_mask)), change the assert to an if-statement that
raises ValueError("attention_mask_data is required for custom attention mask.")
so validation is preserved even under -O optimization.
In `@tensorrt_llm/_torch/attention_backend/trtllm.py`:
- Around line 1132-1137: Ensure the cached RoPE max-position metadata is updated
when growing the table: in _ensure_rope_table_size (after updating
self.rope_params.max_positions and recreating
self.rotary_inv_freq/self.rotary_cos_sin via
self.rope_params.create_rope_const_params()) also assign the new max value to
self.rotary_embedding_max_positions so _run() will forward the correct, in-sync
max-position metadata; update any related code paths that rely on
rotary_embedding_max_positions to use this refreshed value.
---
Outside diff comments:
In `@tensorrt_llm/_torch/attention_backend/sparse/dsa.py`:
- Around line 1942-1993: The MLA RoPE path in mla_rope_append_paged_kv_assign_q
reads self.rotary_cos_sin directly and can launch the DSA kernel with an
undersized cos/sin buffer when max_seq_len has grown; before calling
torch.ops.trtllm.mla_rope_append_paged_kv_assign_q, call the existing rope table
grow helper (e.g. self._ensure_rope_table_size(max_seq_len) or the equivalent
method used by TrtllmAttention.mla_rope_append_paged_kv_assign_q) using the
computed max_seq_len so self.rotary_cos_sin is large enough, keeping the rest of
the call and parameters unchanged.
In `@tensorrt_llm/_torch/attention_backend/trtllm.py`:
- Around line 1586-1599: The current check for auto-allocation only treats the
presence of ctx.out_scale as an indicator for quantized output, but it should
also consider ctx.out_scale_sf to correctly handle NVFP4 paths. In the code
section around the output allocation logic using self.create_output, update the
condition for is_quantize_output to check if either ctx.out_scale or
ctx.out_scale_sf is not None. This ensures the NVFP4 quantized output path is
properly triggered during auto-allocation.
In `@tensorrt_llm/_torch/attention_backend/vanilla.py`:
- Around line 308-318: The no-cache path must accept and propagate the context
window size: add an optional parameter attention_window_size: Optional[int] =
None to no_kv_cache_forward and update forward to pass ctx.attention_window_size
into that call (matching how _single_request_forward uses it); then, in the
flash_attn_varlen_func invocation inside no_kv_cache_forward, replace the
current default/infinite window with window_size=(attention_window_size - 1, 0)
when attention_window_size is set (otherwise keep the existing fallback),
ensuring flash_attn_varlen_func receives the correct sliding-window size for
local attention.
---
Nitpick comments:
In `@tests/unittest/_torch/attention/sparse/test_rocketkv.py`:
- Around line 263-266: Add a test that exercises a non-default
AttentionForwardContext and the mixed-ctx/legacy kwargs rejection: construct an
AttentionForwardContext with at least one non-default field set (e.g., a
non-empty torch device/flag or an explicit timestamp/seed field used by
merge_attention_forward_context), call trtllm_attn.sparse_kv_predict (or
TrtllmAttention.forward) with that ctx and assert the propagated/merged values
are preserved, then add a second case that passes both a ctx and a legacy kwarg
(to trigger merge_attention_forward_context rejection) and assert it raises the
expected error; reference AttentionForwardContext,
trtllm_attn.sparse_kv_predict, TrtllmAttention.forward, and
merge_attention_forward_context to locate where to modify tests.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 9148f771-7dc2-4722-9d12-4598321aa4e0
📒 Files selected for processing (11)
tensorrt_llm/_torch/attention_backend/__init__.py
tensorrt_llm/_torch/attention_backend/flashinfer.py
tensorrt_llm/_torch/attention_backend/interface.py
tensorrt_llm/_torch/attention_backend/sparse/dsa.py
tensorrt_llm/_torch/attention_backend/sparse/rocket.py
tensorrt_llm/_torch/attention_backend/star_flashinfer.py
tensorrt_llm/_torch/attention_backend/trtllm.py
tensorrt_llm/_torch/attention_backend/vanilla.py
tensorrt_llm/_torch/modules/attention.py
tests/unittest/_torch/attention/sparse/test_rocketkv.py
tests/unittest/_torch/attention/sparse/test_sparse_attention.py
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
|
/bot run --disable-fail-fast --add-multi-gpu-test |
|
PR_Github #46776 [ run ] triggered by Bot. Commit: |
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
|
/bot run --disable-fail-fast --add-multi-gpu-test |
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
|
/bot run --disable-fail-fast --add-multi-gpu-test |
|
PR_Github #46786 [ run ] triggered by Bot. Commit: |
|
/bot run --disable-fail-fast --add-multi-gpu-test |
|
PR_Github #46952 [ run ] triggered by Bot. Commit: |
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
32d6532 to
71c0dd4
Compare
|
/bot run --disable-fail-fast --add-multi-gpu-test |
|
PR_Github #46956 [ run ] triggered by Bot. Commit: |
|
PR_Github #46952 [ run ] completed with state |
|
PR_Github #46956 [ run ] completed with state
|
|
/bot run --disable-fail-fast --add-multi-gpu-test |
1 similar comment
|
/bot run --disable-fail-fast --add-multi-gpu-test |
|
PR_Github #47037 [ run ] triggered by Bot. Commit: |
Summary by CodeRabbit
Refactor
Description
Refactor PyTorch attention backend forwarding around an explicit AttentionForwardContext. This removes the TRTLLM attention wrapper, moves the former plan/run data flow into TrtllmAttention.forward and _run, and makes vanilla, FlashInfer, StarAttention, TRTLLM, DSA, and Rocket use the same context merge path.
The merge helper now returns only AttentionForwardContext and rejects unknown forward kwargs. The sparse TRTLLM hooks also accept the context object directly.
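For readers outside the attention backends, a minimal sketch of the intended shape of the context and merge path; the field list and helper signature below are illustrative (drawn from the review comments above), not the exact definitions in interface.py.

```python
# Illustrative sketch only; the real AttentionForwardContext in this PR may
# have a different field list, and the merge helper a different signature.
from dataclasses import dataclass, fields
from typing import Any, Optional


@dataclass
class AttentionForwardContext:
    attention_mask: Any = None
    attention_window_size: Optional[int] = None
    out_scale: Any = None
    out_scale_sf: Any = None


def merge_attention_forward_context(
        ctx: Optional[AttentionForwardContext] = None,
        **kwargs: Any) -> AttentionForwardContext:
    """Fold legacy forward kwargs into a single context; reject unknown ones."""
    ctx = ctx if ctx is not None else AttentionForwardContext()
    valid = {f.name for f in fields(AttentionForwardContext)}
    unknown = set(kwargs) - valid
    if unknown:
        raise ValueError(f"Unknown attention forward kwargs: {sorted(unknown)}")
    for name, value in kwargs.items():
        if value is not None:
            setattr(ctx, name, value)
    return ctx
```

Under this shape, every backend forward (vanilla, FlashInfer, StarAttention, TRTLLM, DSA, Rocket) consumes the merged ctx instead of a grab-bag of keyword arguments, and stale call sites fail loudly rather than having their kwargs silently dropped.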
Design Documents
Test Coverage
pre-commit run isort --files ...
pre-commit run yapf --files ...
pre-commit run ruff-legacy --files ...
python -m py_compile tensorrt_llm/_torch/attention_backend/interface.py tensorrt_llm/_torch/attention_backend/trtllm.py tensorrt_llm/_torch/attention_backend/vanilla.py tensorrt_llm/_torch/attention_backend/flashinfer.py tensorrt_llm/_torch/attention_backend/star_flashinfer.py tensorrt_llm/_torch/attention_backend/sparse/dsa.py tensorrt_llm/_torch/attention_backend/sparse/rocket.py
git diff --check
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True-enable_chunked_prefill=True] passed, and accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B::test_auto_dtype passed.
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.