[#12784][feat] AutoDeploy: Optimize DeepSeek-R1 model performance #12946
taylor-yb-lee wants to merge 60 commits into NVIDIA:main from
Conversation
Force-pushed 45a1967 to 0a1a52c
Force-pushed fdb717b to 58aacc8
📝 Walkthrough

The changes introduce specialized support for the Blackwell (SM100f) GPU architecture with early detection paths in FP8 quantization, implement TMA-aligned tensor layout utilities for kernel fusion, and refactor output-splitting logic to use …

Changes
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks: ✅ 1 | ❌ 2
❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (2)
581-584: ⚠️ Potential issue | 🟠 Major: Restore TMA layout after concatenating UE8M0 fine-grained scales.
On SM100f, `FineGrainedFP8LinearQuantization.post_load_hook()` converts `weight_scale_inv` to TMA-aligned `torch.int`, and `trtllm_finegrained_fp8_linear()` switches to `fp8_swap_ab_gemm` based only on that dtype. `torch.cat(...)` here drops the column-major/TMA layout, so the fused path can feed row-major UE8M0 scales straight into DeepGEMM. This needs the same layout re-materialization you added in `fuse_swiglu.py`, or the fused path should opt out of the DeepGEMM fast-path.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py` around lines 581 - 584, The concatenation of weight_scale_inv via torch.cat drops the TMA-aligned column-major layout produced by FineGrainedFP8LinearQuantization.post_load_hook, so update the fused path after fused_weight_scale_inv = torch.cat(weight_scale_inv, dim=0) to re-materialize the original TMA/column-major layout and torch.int dtype (same approach used in fuse_swiglu.py) so trtllm_finegrained_fp8_linear still sees the TMA-aligned torch.int and chooses fp8_swap_ab_gemm; alternatively, if re-materializing is undesirable, mark the fused path to opt out of the DeepGEMM fast-path by ensuring the layout/dtype check used by trtllm_finegrained_fp8_linear fails.
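As a rough illustration of the layout fix this comment asks for, here is a minimal stand-in, assuming the swap-AB fast-path keys off column-major storage; the repo's own helper (as used in `fuse_swiglu.py`) should be preferred where available:

```python
import torch

def rematerialize_column_major(scales: torch.Tensor) -> torch.Tensor:
    # torch.cat always yields row-major storage, which would defeat a
    # layout-sensitive check downstream. Rebuild column-major storage while
    # keeping the logical shape: t().contiguous().t() gives stride (1, rows).
    return scales.t().contiguous().t()

# Example: concatenate two per-child UE8M0 scale blocks, then restore layout.
blocks = [torch.randint(0, 255, (4, 8), dtype=torch.int32) for _ in range(2)]
fused_weight_scale_inv = rematerialize_column_major(torch.cat(blocks, dim=0))
assert fused_weight_scale_inv.stride() == (1, fused_weight_scale_inv.shape[0])
```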
442-446: ⚠️ Potential issue | 🟠 Major: Don't unconditionally claim shape metadata is valid.
`_insert_fused_gemm()` and `_insert_fused_quant_gemm()` only populate `meta["val"]` when the original linear node already had it. If a mixed-children fusion runs without that metadata, the new `torch.narrow`/`contiguous` nodes are still shape-less, but `has_valid_shapes=True` suppresses the fallback shape-prop pass and leaves downstream shape-based transforms with incomplete metadata.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py` around lines 442 - 446, The TransformInfo currently sets has_valid_shapes=True unconditionally which hides missing shape metadata created by _insert_fused_gemm() and _insert_fused_quant_gemm(); change the code that builds the TransformInfo so has_valid_shapes is computed (not hard-coded) by inspecting the fused outputs' metadata (e.g., check meta.get("val") presence on the newly created nodes such as the torch.narrow/contiguous results or verify that the original linear nodes provided shape meta) and set has_valid_shapes=False if any fused node lacks meta["val"] so the fallback shape-prop pass can run; use the same symbol names (_insert_fused_gemm, _insert_fused_quant_gemm, TransformInfo, torch.narrow, contiguous) to locate where to compute this flag.
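As a rough sketch of computing the flag instead of hard-coding it; the `TransformInfo` below is a simplified stand-in for the real result type, not the actual class:

```python
from dataclasses import dataclass
from typing import Iterable

@dataclass
class TransformInfo:  # simplified stand-in for the transform's result type
    num_matches: int
    has_valid_shapes: bool

def build_transform_info(new_nodes: Iterable, num_matches: int) -> TransformInfo:
    # Any fused node (e.g. the torch.narrow/contiguous results) missing
    # meta["val"] means shape metadata is incomplete; report False so the
    # fallback shape-propagation pass runs.
    has_valid_shapes = all(
        getattr(n, "meta", {}).get("val") is not None for n in new_nodes
    )
    return TransformInfo(num_matches=num_matches, has_valid_shapes=has_valid_shapes)
```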
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/auto_deploy/model_registry/configs/deepseek-r1.yaml`:
- Around line 35-37: The example config re-enables the piecewise CUDA-graph
splitting; change the compile_model setting so piecewise_enabled is false to
match the rollback and keep the regressing path disabled, leaving
piecewise_num_tokens intact for future use (update the compile_model:
piecewise_enabled value only).
In `@tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py`:
- Around line 46-53: The imported FP8 helpers resmooth_to_fp8_e8m0 and
transform_sf_into_required_layout may be None; update the SM100f post-load hook
(the function that currently calls these helpers unconditionally) to check for
their presence before invoking them and either skip the FP8 conversion path or
raise a clear, descriptive error; apply the same None-guarding pattern to the
other occurrence around lines 934-946 so both call sites test "if
resmooth_to_fp8_e8m0 is not None and transform_sf_into_required_layout is not
None" (or handle each helper individually) and log or fallback appropriately
instead of calling a NoneType. (A sketch of this guard follows this list.)
- Around line 948-957: The code currently replaces the quantization scale slot
with an nn.Parameter (using setattr on scale_attr), but weight_scale_inv must
remain a buffer so later fusion code (gm.get_buffer(...) in fusion.py) can find
it; modify the transform in the block that sets scale_attr so that when
scale_attr corresponds to the weight_scale_inv buffer you call
target_module.register_buffer(scale_attr, transformed_scale) (ensuring the
tensor is detached/converted to the correct device/dtype) instead of setting an
nn.Parameter, while leaving the weight replacement (attr_name) as an
nn.Parameter as before. (A buffer-preserving sketch follows this list.)
In
`@tests/unittest/auto_deploy/singlegpu/transformations/library/test_gemm_fusion.py`:
- Around line 643-645: In the forward method(s) in test_gemm_fusion.py where the
shape is unpacked (e.g., the forward function that sets batch_size, seq_len, _ =
x.shape), rename the intentionally unused bindings to start with an underscore
(e.g., _batch_size, _seq_len) so Ruff stops flagging them as unused; apply the
same change to the other occurrence(s) of the same unpacking later in the file
(the forward that currently uses batch_size and seq_len) by prefixing those
names with an underscore. (A minimal example follows this list.)
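For the first quantization.py comment above, a minimal sketch of the None-guard, assuming the two helpers are optional imports; the import path shown is hypothetical:

```python
try:
    # Hypothetical import path; in the real module these may come from the
    # DeepGEMM dependency and be left as None when it is absent.
    from deep_gemm_utils import resmooth_to_fp8_e8m0, transform_sf_into_required_layout
except ImportError:
    resmooth_to_fp8_e8m0 = None
    transform_sf_into_required_layout = None

def fp8_helpers_available() -> bool:
    """Guard both helpers before the SM100f post-load hook calls them."""
    return (
        resmooth_to_fp8_e8m0 is not None
        and transform_sf_into_required_layout is not None
    )

# In the hook: `if not fp8_helpers_available(): raise RuntimeError(...)`, or
# skip the UE8M0 conversion path entirely, instead of calling a NoneType.
```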
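For the second quantization.py comment, a sketch that keeps `weight_scale_inv` registered as a buffer while the weight remains an `nn.Parameter`; `replace_quant_scale` is an illustrative name, not the transform's actual function:

```python
import torch
import torch.nn as nn

def replace_quant_scale(module: nn.Module, scale_attr: str, new_scale: torch.Tensor) -> None:
    # weight_scale_inv must remain a *buffer* so gm.get_buffer(...) in
    # fusion.py can still locate it; re-register rather than setattr-ing
    # an nn.Parameter over it.
    if hasattr(module, scale_attr):
        delattr(module, scale_attr)  # drop any prior parameter/buffer slot
    module.register_buffer(scale_attr, new_scale.detach())

# The weight itself keeps the old behaviour:
#   setattr(module, weight_attr, nn.Parameter(new_weight))
```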
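And for the test_gemm_fusion.py comment, the underscore rename in miniature:

```python
import torch

def forward(x: torch.Tensor) -> torch.Tensor:
    # Leading underscores mark the bindings as intentionally unused, so
    # Ruff's dummy-variable pattern no longer flags them.
    _batch_size, _seq_len, _ = x.shape
    return x.reshape(-1, x.shape[-1])
```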
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: 5438b43f-9552-4332-930c-c6f38a64bef2
📒 Files selected for processing (6)
- examples/auto_deploy/model_registry/configs/deepseek-r1.yaml
- tensorrt_llm/_torch/auto_deploy/custom_ops/quantization/torch_quant.py
- tensorrt_llm/_torch/auto_deploy/transform/library/fuse_swiglu.py
- tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py
- tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
- tests/unittest/auto_deploy/singlegpu/transformations/library/test_gemm_fusion.py
Force-pushed 46a7475 to 84b0b9e
Force-pushed b80aadf to 50dd694
Force-pushed 50dd694 to 65d38ef
/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

PR_Github #44763 [ run ] triggered by Bot. Commit:
PR_Github #44763 [ run ] completed with state
Force-pushed b7d1a93 to 4550f61
/bot run --disable-fail-fast --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

PR_Github #45908 [ run ] triggered by Bot. Commit:
…comments
Bring back the per-branch comments ("Force TRT-LLM ops", "Force PyTorch
distributed ops", "Automatically select based on availability", "Use
TRT-LLM optimized ops in MPI mode", "Use PyTorch distributed ops in
demollm mode") that were dropped when SYMM_MEM AllGather was added.
They explain at a glance why each branch picks the ops it does.
Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
…lper

_requantize_to_128x128_ue8m0 was added in 75833d1 to push TP-misaligned FP8 projections (q_a N=192, kv_a N=72 at TP=8) through DeepGEMM. ce9bc00 then flipped that policy: misaligned projections now skip DeepGEMM and fall back to cuBLAS with float32 scales, because the requantize introduced power-of-2 rounding that hurt accuracy (DSR1 MMLU 82.31 -> 84.16).

After that flip the function has no callers: aligned projections (q_b, kv_b, MoE, etc.) go through resmooth_to_fp8_e8m0, and misaligned ones are skipped before any requantize would happen. Drop the dead helper and rephrase the post_load_hook comment so it no longer references it.

Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
The "Phase 0/1" naming implied sequential phases of one operation, but they are mutually exclusive matches selected by graph shape (unfused vs fused GEMMs). Rename to "Case 0/1" — the unmatched case is just skipped rather than feeding into the next. Also extend the module docstring with ASCII flow diagrams for both cases showing which ops live on the main vs aux stream and where the streams synchronise back, and clarify that case 0's AllGather uses symm_mem (not NCCL) after the d967d8c rework. Updates the MLA reference in multi_stream_utils.py to match. Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
The previous diagram showed symm_mem_all_gather_aux directly, which read as if the transform only matches that specific op. The matcher actually accepts any AllGather variant in _ALL_GATHER_OPS; only the rewrite is fixed to symm_mem_all_gather_aux.

Split the docstring so the Match step uses an abstract <any AllGather op> placeholder and a dedicated Rewrite section explains the substitution.

Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
…tized note

"Pattern" reads more accurately than "case" for graph-shape-driven match alternatives. Also drop the "(fallback for non-quantized graphs)" qualifier on pattern 1: the current _is_linear matches quantized linears too (fp8, fake-quant), so the qualifier is stale from when pattern 1 only handled torch_linear_simple/aten.linear.

Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
bc4999c split the piecewise multi-stream bypass per-path via an aux_has_collective kwarg, leaving non-collective aux paths (e.g. multi_stream_moe shared-expert overlap) actually switching streams during piecewise capture/replay. The piecewise CUDA graphs share one memory pool, so concurrent main/aux replays on that pool race and corrupt buffers, manifesting as cudaErrorIllegalAddress on the first decode forward (DeepSeek-R1, world_size=8).

Restore the pre-bc4999c7e3 semantics: when disable_aux_stream_switch is set, every begin/end/wait_aux_stream_passthrough becomes a no-op regardless of aux_has_collective. The kwarg is preserved as a documentation hint (it still flags collective-bearing aux paths) but does not influence the bypass. Decode (monolithic) capture is unaffected because disable_aux_stream_switch stays False there, so full multi-stream overlap is captured.

Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
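As a rough sketch of the restored bypass semantics this commit describes (names follow the message; the bodies are illustrative stand-ins for multi_stream_utils.py, not its code):

```python
import torch

disable_aux_stream_switch = False  # set True around piecewise capture/replay

def begin_aux_stream_passthrough(
    x: torch.Tensor,
    aux_stream: torch.cuda.Stream,
    aux_has_collective: bool = False,  # documentation hint only; never gates the bypass
) -> torch.Tensor:
    if disable_aux_stream_switch:
        return x  # no-op: stay on the caller's stream, nothing runs on aux
    aux_stream.wait_stream(torch.cuda.current_stream())  # normal multi-stream path
    return x
```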
… bypass
Drop the in-branch piecewise multi-stream bypass that was added across
several commits (introduces flag, gate by aux_has_collective, all-or-
nothing rebypass). The same problem is solved upstream via the
disable_multi_stream context manager landing next from
suyogg/fix-ms-pcg-att2 (cherry-pick to follow), so this branch's bypass
is removed before that commit is applied.
Reverted pieces:
- multi_stream_utils.py: restored to main's state (drops
disable_aux_stream_switch flag, the aux_has_collective kwarg on
begin/end/wait_aux_stream_passthrough, the caller_stream.synchronize()
block, and re-adds the x.record_stream(aux_stream) hint).
- torch_cudagraph.py: drop the import of multi_stream_utils and the
disable_aux_stream_switch=True line in warmup_and_capture.
- multi_stream_attn.py (MLA Phase 0): drop kv_kwargs={"aux_has_collective":
True} and the kwargs= passing on begin/end/wait_aux_stream_passthrough
call sites.
No other in-branch work (Phase 0/1 MLA, Phase 2 cleanup in 43270db,
piecewise compat changes, doc/refactor commits) is touched.
Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
…se partitions

After cherry-picking d8f5de4 (disable_multi_stream context manager that no-ops stream-switch passthroughs and aux-stream impls inside the piecewise CG warmup/capture/forward), the line-316 first-pass partition rule that split stream-switch nodes into their own dynamic partitions is redundant and contradicts the cherry-pick's design.

Drop _is_stream_switch_node and the special-case branch in split_graph_at_dynamic_ops so stream-switch passthroughs ride along inside the surrounding static partition. At piecewise capture/replay time disable_multi_stream() makes them no-ops, so they capture safely; outside piecewise (monolithic decode CG) they execute normally for full multi-stream overlap. _STREAM_SWITCH_FUNCTION_NAMES is left as a documentation sentinel (matching the cherry-pick).

Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
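A minimal sketch of how such a disable_multi_stream context manager can work; this is an assumed shape, not the cherry-picked implementation:

```python
from contextlib import contextmanager

_multi_stream_disabled = False  # consulted by the stream-switch passthroughs

@contextmanager
def disable_multi_stream():
    # While active, begin/end/wait passthroughs return their input
    # unchanged, so captured piecewise partitions never switch streams.
    global _multi_stream_disabled
    prev = _multi_stream_disabled
    _multi_stream_disabled = True
    try:
        yield
    finally:
        _multi_stream_disabled = prev
```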
…sforms

is_dist_op, get_lm_head_node, and multi_stream_attn previously only knew about trtllm_dist_all_gather / torch_dist_all_gather, so the symm-mem allgather ops were invisible. With allgather_strategy: SYMM_MEM (set in deepseek-r1.yaml) this caused silent regressions: gather_logits_before_lm_head no-oped on the lm-head AG, and the multi-stream MLA aux rewrite was hardcoded to symm_mem_all_gather_aux regardless of the matched main-stream op.

Centralize the AllGather op set as ALL_GATHER_OPS in node_utils, switch the three sites to use it, and derive the aux op from the matched main op (symm_mem_all_gather -> symm_mem_all_gather_aux for a separate workspace; NCCL variants reused as-is).

Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
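The derivation rule in miniature, with op handles shown as strings for brevity (note the next commit collapses the symm_mem ops away entirely):

```python
ALL_GATHER_OPS = {
    "trtllm_dist_all_gather",
    "torch_dist_all_gather",
    "symm_mem_all_gather",
}

def aux_all_gather_op(main_op: str) -> str:
    # symm_mem needs a separate workspace on the aux stream; NCCL-backed
    # variants are safe to re-emit as-is.
    if main_op == "symm_mem_all_gather":
        return "symm_mem_all_gather_aux"
    return main_op

assert aux_all_gather_op("symm_mem_all_gather") == "symm_mem_all_gather_aux"
assert aux_all_gather_op("trtllm_dist_all_gather") == "trtllm_dist_all_gather"
```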
… + workspace_id
The previous shape encoded AllGather strategy in op identity:
symm_mem_all_gather, symm_mem_all_gather_torch, and symm_mem_all_gather_aux
were distinct from trtllm_dist_all_gather / torch_dist_all_gather, and
sharding.py:_get_dist_ops branched the op handle by AllGatherStrategy.
This is asymmetric with AllReduce, where the op is backend-only and the
strategy flows through as an op argument; every downstream pattern-match
site that keyed on the AG op handle (is_dist_op, get_lm_head_node,
multi_stream_attn) had to enumerate every strategy variant.
Collapse to one op per backend, mirroring AllReduce:
trtllm_dist_all_gather(tensor, dim, sizes, strategy, workspace_id)
torch_dist_all_gather (tensor, dim, sizes, strategy, workspace_id)
strategy="AUTO" -> NCCL all-gather (TRT-LLM optimized / torch.distributed)
strategy="SYMM_MEM" -> SymmetricMemoryAllGather, falls back to NCCL
workspace_id -> picks the symm_mem ProcessGroup/workspace; aux
streams use a non-zero id to avoid workspace
conflicts with the main-stream allgather
Ops symm_mem_all_gather, symm_mem_all_gather_torch, and
symm_mem_all_gather_aux are removed. _get_dist_ops returns backend-only
ops; the call site passes strategy through dist_lookup. multi-stream MLA
re-emits the matched AllGather on the aux stream with workspace_id=1.
all_gather_ops() returns the two unified ops only.
Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
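A self-contained sketch of the backend-only dispatch described above; both backend functions are stand-ins for the real NCCL and multimem paths, and the argument types are assumptions based on the signature in the commit message:

```python
from typing import Optional, Sequence
import torch

def _nccl_all_gather(t: torch.Tensor, dim: int, sizes: Optional[Sequence[int]]) -> torch.Tensor:
    return t  # stand-in for the pipelined NCCL ring all-gather

def _symm_mem_all_gather(t: torch.Tensor, dim: int, workspace_id: int) -> torch.Tensor:
    return t  # stand-in for the multimem kernel on workspace `workspace_id`

def trtllm_dist_all_gather(
    tensor: torch.Tensor,
    dim: int,
    sizes: Optional[Sequence[int]] = None,
    strategy: str = "AUTO",
    workspace_id: int = 0,
) -> torch.Tensor:
    if strategy == "SYMM_MEM":
        # Falls back to NCCL when symm_mem is unusable; aux streams pass a
        # non-zero workspace_id to avoid clashing with the main stream.
        return _symm_mem_all_gather(tensor, dim, workspace_id)
    return _nccl_all_gather(tensor, dim, sizes)  # strategy == "AUTO"
```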
Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
- Add tests/unittest/_torch/multi_gpu/test_allgather.py mirroring test_allreduce.py: parametrize over strategy (NCCL, SYMM_MEM) and world_size (4, 8); verify correctness vs dist.all_gather and CUDA Graph capture/replay for SYMM_MEM. SYMM_MEM cases auto-skip when device capability or world_size is outside the MULTIMEM support matrix.
- Rename test_ad_allreduce_strategies.py to test_ad_dist_strategies.py to reflect its broader scope, and add test_allgather_strategy_propagation (AUTO, SYMM_MEM) mirroring test_allreduce_strategy_propagation: a column shard with dist_op="all_gather" emits an allgather node carrying the configured strategy at args[3].

Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com> (cherry picked from commit a803e2a)
DSR1 (and other TP-sharded models) emit AllGather along the last hidden dim. SymmetricMemoryAllGather previously rejected dim != 0 in can_use_symm_mem(), so every SYMM_MEM AllGather silently fell back to NCCL ring-LL at runtime even though the FX graph carries the SYMM_MEM strategy. Restore the transpose-to-dim-0 path in forward() so the multimem kernel is actually used.

For prefill-sized tensors on the piecewise path (where multi-stream overlap is disabled by the [NVIDIA#13321] accuracy fix), the transpose+contiguous copies plus the multimem all-at-once kernel end up fully exposed on the critical path, and lose to NCCL ring-LL's pipelined bandwidth. Add a per-(device-capability, world-size) _TRANSPOSE_PERF_THRESHOLD (1 MiB on B200/H100 ws=8) that gates the dim != 0 path: small tensors stay on multimem (decode wins), large tensors fall back to NCCL (prefill bounded). The dim == 0 path is unaffected and continues to be gated only by the workspace-overflow check in _MAX_SIZES.

DSR1 (8xB200, conc=256, isl=osl=1000) net effect:
- Decode AllGather kernel returns to multimem_all_gather_kernel (was ncclDevKernel_AllGather_RING_LL); ITL p50 36.88 -> 35.10 ms.
- Prefill AllGather (>1 MiB output) stays on NCCL ring-LL; TTFT tails do not blow up the way they did with multimem-on-main-stream.
- OSL Mismatch Count remains 0 (correctness preserved).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
(cherry picked from commit 1e1c08cab4b8d561a032934b560c5c23880e70a0)
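The gate in miniature; the threshold table below is illustrative, seeded with the single data point given in this message (1 MiB at world_size 8 on H100/B200):

```python
import torch

# (compute capability major, world_size) -> max bytes for the dim != 0 multimem path
_TRANSPOSE_PERF_THRESHOLD = {(9, 8): 1 << 20, (10, 8): 1 << 20}

def use_multimem_path(t: torch.Tensor, dim: int, cc_major: int, world_size: int) -> bool:
    if dim == 0:
        return True  # gated only by the workspace-overflow check elsewhere
    # dim != 0 pays transpose+contiguous copies; above the threshold the
    # pipelined NCCL ring-LL wins on the exposed prefill path.
    limit = _TRANSPOSE_PERF_THRESHOLD.get((cc_major, world_size))
    return limit is not None and t.numel() * t.element_size() <= limit
```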
Force-pushed 91e0b38 to 9992506
Disable pw cudagraph until pwc + multistream error fixed in main

Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
Force-pushed 9992506 to c1943cd
/bot run --disable-fail-fast --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

PR_Github #46542 [ run ] triggered by Bot. Commit:
Force-pushed d16355d to 21b218e
Revise comments in symm_mem_allgather

Validate GPU pinning in SymmetricMemoryAllGather init

Signed-off-by: Taylor Yeonbok Lee <249374542+taylor-yb-lee@users.noreply.github.com>
Force-pushed 21b218e to 7e0e791
/bot run --disable-fail-fast --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

PR_Github #46546 [ run ] triggered by Bot. Commit:
PR_Github #46546 [ run ] completed with state
Summary by CodeRabbit
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.