[TRTLLM-12432][perf] ltx2: drop redundant pe all-gather in AV cross-attention#13687
luyiyun1021 wants to merge 1 commit into NVIDIA:main from
Conversation
📝 Walkthrough: Added a new optional `cross_positional_embeddings_full` field.
Caution: Some comments are outside the diff and can't be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py (1)
1174-1184: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win — Preserve existing full cross-PE when sharding args (Line 1183).

Line 1183 always copies from `args.cross_positional_embeddings`. If this helper is ever invoked on already-sharded args, the "full" PE can be overwritten by a shard.

Suggested fix:

```diff
 return replace(
     args,
     x=args.x[:, s:e],
     timesteps=_shard(args.timesteps),
     embedded_timestep=_shard(args.embedded_timestep),
     positional_embeddings=_shard_pe(args.positional_embeddings),
     cross_positional_embeddings=_shard_pe(args.cross_positional_embeddings),
     cross_scale_shift_timestep=_shard(args.cross_scale_shift_timestep),
     cross_gate_timestep=_shard(args.cross_gate_timestep),
-    cross_positional_embeddings_full=args.cross_positional_embeddings,
+    cross_positional_embeddings_full=(
+        args.cross_positional_embeddings_full
+        if args.cross_positional_embeddings_full is not None
+        else args.cross_positional_embeddings
+    ),
 )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py`, around lines 1174-1184: the helper currently always sets cross_positional_embeddings_full to args.cross_positional_embeddings, which can overwrite an existing full (unsharded) value when called on already-sharded args. Change the assignment in the replace call to prefer the existing full value if present (e.g. use getattr(args, "cross_positional_embeddings_full", args.cross_positional_embeddings) or equivalent) so that cross_positional_embeddings_full is preserved when already set, otherwise fall back to args.cross_positional_embeddings.
📒 Files selected for processing (2)
tensorrt_llm/_torch/visual_gen/models/ltx2/ltx2_core/transformer_args.py
tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py
Force-pushed from e29b559 to 925a056 (Compare)

/bot run --disable-fail-fast

PR_Github #46500 [ run ] triggered by Bot. Commit:
Force-pushed from 925a056 to 17c061c (Compare)

/bot run --disable-fail-fast

PR_Github #46503 [ run ] triggered by Bot. Commit:
Force-pushed from 17c061c to 2976160 (Compare)

/bot run --disable-fail-fast
[TRTLLM-12432][perf] ltx2: drop redundant pe all-gather in AV cross-attention

Replaces TransformerArgs.cross_positional_embeddings with two precomputed fields:
- cross_positional_embeddings_local: sharded along seq (Q-side rope)
- cross_positional_embeddings_full: un-sharded (K-side rope on gathered K)

Both fields are populated once at construction time in MultiModalTransformerArgsPreprocessor.prepare from the precomputed full pe. LTXModel._shard_transformer_args slices *_local* from *_full* for Ulysses scatter; *_full* is passed through. AV cross-attn K-side now reads *_full* directly, so the cos/sin all-gather scheduled by _sp_gather_pe is unconditionally absent from the captured cuda graph. Removes the now-unused _sp_gather_pe method.

Lossless: cos/sin are deterministic, precomputed once per generate(); the previous shard->gather->concat round-trip was a memcpy-only no-op that this patch elides. Verified by bit-identical raw-video sha256 across baseline and patched runs (40-step, 2-GPU Ulysses, seed=42).

Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
Force-pushed from 2976160 to 19e3ccb (Compare)

/bot run --disable-fail-fast

PR_Github #46505 [ run ] triggered by Bot. Commit:

PR_Github #46506 [ run ] triggered by Bot. Commit:

PR_Github #46506 [ run ] completed with state
@coderabbitai summary
Description
Under Ulysses sequence parallelism, LTX-2's AV cross-attention K-side does an explicit `dist.all_gather` on the (cos, sin) RoPE tensors every block × every step (`BasicAVTransformerBlock._sp_gather_pe`). The pe is deterministic — precomputed once per `generate()` and held un-sharded inside `MultiModalTransformerArgsPreprocessor.prepare`, then sliced for Ulysses scatter. The all-gather rebuilds data the rank already had a copy of pre-shard. This PR removes the cos/sin all-gathers by carrying both layouts on `TransformerArgs`.
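For context, a rough standalone sketch of the pattern being removed (not the actual TRT-LLM implementation — names, shapes, and the launch setup are illustrative; run under e.g. `torchrun --nproc_per_node=2`):

```python
import torch
import torch.distributed as dist


def sp_gather_pe(cos_shard, sin_shard, sp_size):
    """All-gather sequence-sharded (cos, sin) RoPE tensors back to full length."""
    cos_parts = [torch.empty_like(cos_shard) for _ in range(sp_size)]
    sin_parts = [torch.empty_like(sin_shard) for _ in range(sp_size)]
    dist.all_gather(cos_parts, cos_shard)  # rebuilds data every rank already had
    dist.all_gather(sin_parts, sin_shard)  # before the pre-attention scatter
    return torch.cat(cos_parts, dim=0), torch.cat(sin_parts, dim=0)


if __name__ == "__main__":
    dist.init_process_group("nccl")               # launched via torchrun
    rank, world = dist.get_rank(), dist.get_world_size()
    torch.cuda.set_device(rank)
    torch.manual_seed(0)                          # every rank builds the same "full" pe
    full = torch.randn(128, 64, device="cuda")
    shard = full.chunk(world, dim=0)[rank]        # Ulysses-style sequence shard
    cos_full, sin_full = sp_gather_pe(shard, shard, world)
    assert torch.equal(cos_full, full)            # the gather only rebuilt known data
```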
Design
- `cross_positional_embeddings_local`
- `cross_positional_embeddings_full`

Both fields are populated once at construction (`MultiModalTransformerArgsPreprocessor.prepare`) from the same `static_cross_pe`. `LTXModel._shard_transformer_args` slices `_local = _shard_pe(args.cross_positional_embeddings_full)`; `_full` passes through. AV cross-attn reads `pe=*._local` and `k_pe=*._full` directly — no `_sp_gather_pe` call.
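A minimal sketch of that layout and of the sharding step, assuming a heavily simplified stand-in for `TransformerArgs` (the real class has many more fields; `prepare` and `shard_transformer_args` below are illustrative reductions of the methods named above):

```python
from dataclasses import dataclass, replace
from typing import Optional

import torch


@dataclass
class TransformerArgs:  # simplified stand-in, not the real class
    cross_positional_embeddings_local: Optional[tuple] = None  # Q-side rope, seq-sharded
    cross_positional_embeddings_full: Optional[tuple] = None   # K-side rope, un-sharded


def prepare(static_cross_pe: tuple) -> TransformerArgs:
    # Populated once per generate(): before sharding, both layouts hold the same tensors.
    return TransformerArgs(
        cross_positional_embeddings_local=static_cross_pe,
        cross_positional_embeddings_full=static_cross_pe,
    )


def shard_transformer_args(args: TransformerArgs, rank: int, sp: int) -> TransformerArgs:
    def _shard_pe(pe):
        cos, sin = pe
        return cos.chunk(sp, dim=0)[rank], sin.chunk(sp, dim=0)[rank]

    return replace(
        args,
        # _local is sliced from _full for the Ulysses scatter; _full passes through.
        cross_positional_embeddings_local=_shard_pe(args.cross_positional_embeddings_full),
        cross_positional_embeddings_full=args.cross_positional_embeddings_full,
    )
```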
Why precomputed fields rather than a Python conditional or env-var gate? cuda_graph captures the kernel sequence, so a Python branch has no effect on the captured graph (we verified empirically that an env-var-gated approach left the AllGather kernel count unchanged). A field whose value is fixed at trace time is captured directly. Pre-computing in `_shard_transformer_args` instead of slicing inline keeps the sharding policy in one place and avoids creating fresh view objects per block × step.
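As a toy illustration of that capture behavior (standalone PyTorch, unrelated to the LTX-2 code; assumes a CUDA device), the Python branch below is evaluated once during capture, so flipping the flag afterwards cannot change what the graph replays:

```python
import torch

use_gather = False                          # Python-level gate, checked only at capture

static_x = torch.randn(8, device="cuda")
static_y = torch.empty_like(static_x)

# Warm up on a side stream before capture, as the CUDA-graph docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_y.copy_(static_x + 1)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    if use_gather:                          # resolved during capture, not at replay
        static_y.copy_(static_x * 2)
    else:
        static_y.copy_(static_x + 1)

use_gather = True                           # too late: the else-branch is already baked in
g.replay()
torch.cuda.synchronize()
assert torch.allclose(static_y, static_x + 1)   # replay still runs the captured path
```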
Changes

- `transformer_args.py`: replace `cross_positional_embeddings` with `_local` + `_full`; populate both at construction.
- `transformer_ltx2.py::LTXModel._shard_transformer_args`: slice `_local` from `_full`; pass `_full` through.
- `transformer_ltx2.py::BasicAVTransformerBlock.forward`: AV cross-attn reads `_local` (Q) / `_full` (K) — see the sketch after this list.
- Remove the now-unused `_sp_gather_pe` method.
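A rough sketch of that consumption side (illustrative pseudocode, not the real `BasicAVTransformerBlock.forward`): Q-side RoPE uses the sharded pe while K-side RoPE on the gathered K uses the full pe, with no collective in between:

```python
def av_cross_attention(q, k_gathered, v_gathered, args, apply_rope, attention):
    # Q is still sequence-sharded: rotate it with the local (sharded) cos/sin.
    q = apply_rope(q, *args.cross_positional_embeddings_local)
    # K has already been gathered to full sequence length by the Ulysses SendRecv path,
    # so rotate it with the precomputed full cos/sin -- no _sp_gather_pe all-gather needed.
    k = apply_rope(k_gathered, *args.cross_positional_embeddings_full)
    return attention(q, k, v_gathered)
```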
Net diff: +15 / −24 across two files.

Lossless
cos/sin are deterministic and computed once. The previous shard→gather→concat round-trip is memcpy-only (no fp ops). Reusing the un-sharded source produces a bit-identical tensor for K-side rope. The Q-side path is unchanged (just renamed `_local`).
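That claim is easy to sanity-check in isolation; a toy single-process stand-in for the shard→gather→concat round-trip (shapes illustrative):

```python
import torch

cos = torch.randn(11616, 64)              # stand-in for the precomputed full cos table
shards = list(cos.chunk(2, dim=0))        # Ulysses-style sequence sharding across 2 ranks
roundtrip = torch.cat(shards, dim=0)      # what shard -> all_gather -> concat rebuilt
assert torch.equal(roundtrip, cos)        # bit-identical: no floating-point math involved
```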
Validation

Setup: 2-GPU pure Ulysses (`dit_cfg=1`, `dit_ulysses=2`), single-stage 768×1280×121 frames, gs=3.0, seed=42, `cuda_graph` + `torch_compile` both ON. A/B by toggling git commits on the same worktree.

Bit-identical — raw video tensor sha256 (pre-encode), 40-step end-to-end.
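A minimal sketch of how such a raw-tensor digest can be computed (assumes the pre-encode video is held as a torch tensor; this is not the actual validation script):

```python
import hashlib

import torch


def tensor_sha256(t: torch.Tensor) -> str:
    # Hash the raw bytes (CPU, contiguous) so the digest reflects exact bit patterns.
    return hashlib.sha256(t.detach().cpu().contiguous().numpy().tobytes()).hexdigest()


# baseline_video / patched_video would be the raw decoded tensors from the two runs:
# assert tensor_sha256(baseline_video) == tensor_sha256(patched_video)
```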
NCCL kernels (10-step profiled iteration, nsys with `--cuda-graph-trace=node`, summed over both GPUs): AllGather count is exactly halved — every cos/sin all-gather scheduled by `_sp_gather_pe` is removed from the captured graph. SendRecv (the project-before-gather K/V collective) is untouched, as designed.

E2E (40-step, single timed run, post-warmup):
Per-GPU NCCL saving (~85 ms per 10-step iteration) translates to ~1% E2E at 2 GPUs; the observed delta direction is consistent. Larger Ulysses sizes are expected to see a similar ~9% NCCL kernel-time reduction.
Test Coverage
No new tests — the change is a data-routing rewrite of an existing collective, covered by the existing LTX-2 generate-path and `tests/unittest/_torch/visual_gen/...` pipeline tests. No-op for single-GPU / non-Ulysses configurations: `_shard_transformer_args` is not called and both fields hold the same un-sharded tensor.

PR Checklist
PR description clearly explains what and why.
PR Follows TRT-LLM CODING GUIDELINES.
Test cases are provided for new code paths.
Any new dependencies have been scanned for license and vulnerabilities.
CODEOWNERS updated if ownership changes.
Documentation updated as needed.
Update tava architecture diagram if significant design change.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment `/bot help`.