
[TRTLLM-12432][perf] ltx2: drop redundant pe all-gather in AV cross-attention #13687

Open
luyiyun1021 wants to merge 1 commit into NVIDIA:main from luyiyun1021:dev/ltx2-cross-pe-full

Conversation

luyiyun1021 (Collaborator) commented on May 1, 2026

@coderabbitai summary

Description

Under Ulysses sequence parallelism, LTX-2's AV cross-attention K-side does an explicit dist.all_gather on the (cos, sin) RoPE tensors every block × every step (BasicAVTransformerBlock._sp_gather_pe). The pe is deterministic — precomputed once per generate() and held un-sharded inside MultiModalTransformerArgsPreprocessor.prepare, then sliced for Ulysses scatter. The all-gather rebuilds data the rank already had a copy of pre-shard.
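
For context, the pattern this PR removes looked roughly like the sketch below (a paraphrase, not the verbatim TRT-LLM code; the signature and the seq shard dimension are assumptions):

```python
import torch
import torch.distributed as dist

def _sp_gather_pe(cos_local: torch.Tensor, sin_local: torch.Tensor, group=None):
    # Ran once per block per step under Ulysses SP: every rank re-assembles
    # the full (cos, sin) RoPE tensors it already held before sharding.
    def gather(t: torch.Tensor) -> torch.Tensor:
        parts = [torch.empty_like(t) for _ in range(dist.get_world_size(group))]
        dist.all_gather(parts, t.contiguous(), group=group)
        return torch.cat(parts, dim=1)  # re-concatenate along the sharded seq dim
    return gather(cos_local), gather(sin_local)
```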

This PR removes the cos/sin all-gathers by carrying both layouts on TransformerArgs.

Design

| Field | Layout | Used by |
|---|---|---|
| cross_positional_embeddings_local | sharded along seq | Q-side rope on the sharded query |
| cross_positional_embeddings_full | un-sharded | K-side rope on the all-gathered K |

Both fields are populated once at construction (MultiModalTransformerArgsPreprocessor.prepare) from the same static_cross_pe. LTXModel._shard_transformer_args slices _local = _shard_pe(args.cross_positional_embeddings_full); _full passes through. AV cross-attn reads pe=*._local and k_pe=*._full directly — no _sp_gather_pe call.
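
In code terms, the shape of the change is roughly the following (a design sketch, not the exact diff; field names follow the table above, and the even-split slicing math is an assumption):

```python
from dataclasses import dataclass, replace
from typing import Optional, Tuple

import torch

@dataclass
class TransformerArgs:
    # Abridged to the two fields this PR introduces; the real dataclass
    # carries many more (x, timesteps, ...).
    cross_positional_embeddings_local: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    cross_positional_embeddings_full: Optional[Tuple[torch.Tensor, torch.Tensor]] = None

def _shard_transformer_args(args: TransformerArgs, rank: int, world_size: int) -> TransformerArgs:
    # Slice the local (Q-side) view from the full pe once; the full pe passes
    # through unchanged for the K-side, so no gather is ever needed.
    cos, sin = args.cross_positional_embeddings_full
    chunk = cos.shape[1] // world_size            # assumes seq divides evenly
    s, e = rank * chunk, (rank + 1) * chunk
    return replace(
        args,
        cross_positional_embeddings_local=(cos[:, s:e], sin[:, s:e]),
        cross_positional_embeddings_full=(cos, sin),
    )
```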

Why precomputed fields rather than a Python conditional or env-var gate? cuda_graph captures the kernel sequence once; after capture, Python branches are never re-evaluated, so gating the gather in Python has no effect on the replayed graph (we verified empirically that an env-var-gated approach left the AllGather kernel count unchanged). A field whose value is fixed at trace time is captured directly. Pre-computing in _shard_transformer_args, instead of slicing inline, keeps the sharding policy in one place and avoids creating fresh view objects per block × step.
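
A minimal illustration of the capture semantics (a standalone sketch using torch.cuda.CUDAGraph; requires a CUDA device, and the real pipeline's capture path is more involved):

```python
import os

import torch

# Whatever a Python branch evaluates to *during capture* is baked into the
# graph; flipping the env var afterwards changes nothing, because replay()
# re-launches the recorded kernels without re-running any Python.
x = torch.randn(16, device="cuda")
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    if os.environ.get("SKIP_GATHER") != "1":   # evaluated exactly once, at capture
        y = x * 2                              # stand-in for the all-gather kernel
    else:
        y = x.clone()
g.replay()  # same kernel sequence regardless of the env var's value now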

Changes

  • transformer_args.py: replace cross_positional_embeddings with _local + _full; populate both at construction.
  • transformer_ltx2.py::LTXModel._shard_transformer_args: slice _local from _full; pass _full through.
  • transformer_ltx2.py::BasicAVTransformerBlock.forward: AV cross-attn reads _local (Q) / _full (K); sketched after this list.
  • Delete the now-unused _sp_gather_pe method.
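
The consuming side then reduces to plain field reads (a sketch; apply_rope is a hypothetical stand-in for the model's rotary application, and q/k construction is elided):

```python
import torch

def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Minimal rotary stand-in: rotate channel pairs by the precomputed angles.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

def av_cross_attn_rope(q_local, k_full, args):
    # Q stays sharded along seq -> rope with the sharded (_local) pe.
    # K was already assembled by the existing SendRecv collective -> rope with
    # the precomputed (_full) pe; no _sp_gather_pe, no extra AllGather kernels.
    cos_l, sin_l = args.cross_positional_embeddings_local
    cos_f, sin_f = args.cross_positional_embeddings_full
    return apply_rope(q_local, cos_l, sin_l), apply_rope(k_full, cos_f, sin_f)
```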

Net diff: +15 / −24 across two files.

Lossless

cos/sin are deterministic and computed once. The previous shard→gather→concat round-trip is memcpy-only (no fp ops). Reusing the un-sharded source produces a bit-identical tensor for K-side rope. Q-side path is unchanged (just renamed _local).
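
A single-process sketch of why the round-trip is lossless (shapes are illustrative; the real path shards across ranks with torch.distributed):

```python
import torch

full = torch.randn(1, 1024, 64)          # stand-in for a precomputed cos/sin tensor
world_size = 2
shards = full.chunk(world_size, dim=1)   # per-rank slices: views, no fp ops
gathered = torch.cat(shards, dim=1)      # what all_gather + concat reconstructs
assert torch.equal(gathered, full)       # bit-identical: pure data movement
```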

Validation

Setup: 2-GPU pure Ulysses (dit_cfg=1, dit_ulysses=2), single-stage 768×1280×121 frames, gs=3.0, seed=42, cuda_graph + torch_compile both ON. A/B by toggling git commits on the same worktree.

Bit-identical — raw video tensor sha256 (pre-encode), 40-step end-to-end:

baseline (HEAD~1):   371cee010e51cfc4ebb0af74a783a2f95b069ad49c70f703b07d406edc5e1c98
this patch (HEAD):   371cee010e51cfc4ebb0af74a783a2f95b069ad49c70f703b07d406edc5e1c98
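
The digests above were taken over raw tensor bytes along these lines (a sketch; the actual validation harness is not part of this PR):

```python
import hashlib

import torch

def video_sha256(video: torch.Tensor) -> str:
    # Reinterpret the raw bytes (dtype-agnostic) and hash them, so any
    # numerical divergence between baseline and patch changes the digest.
    raw = video.detach().cpu().contiguous().view(torch.uint8)
    return hashlib.sha256(raw.numpy().tobytes()).hexdigest()
```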

NCCL kernels (10-step profiled iter, nsys with --cuda-graph-trace=node, sum over both GPUs):

| Kernel | baseline | this patch | Δ |
|---|---|---|---|
| AllGather count | 7720 | 3880 | −3840 (−49.7 %) |
| AllGather total | 605.45 ms | 434.61 ms | −170.84 ms (−28.2 %) |
| SendRecv count | 7680 | 7680 | 0 (unchanged ✓) |
| SendRecv total | 1275.09 ms | 1277.20 ms | +2.11 ms (noise) |
| Total NCCL | 1880.54 ms | 1711.83 ms | −168.71 ms (−9.0 %) |

AllGather count drops by nearly half (−49.7 %): every cos/sin all-gather scheduled by _sp_gather_pe is removed from the captured graph. SendRecv (the project-before-gather K/V collective) is untouched, as designed.

E2E (40-step, single timed run, post-warmup):

| Metric | baseline | this patch |
|---|---|---|
| E2E | 26.104 s | 25.894 s (−210 ms, −0.8 %) |

The per-GPU NCCL saving (~85 ms per 10-step iter, i.e. ~340 ms over 40 steps, or ~1.3 % of the 26.1 s E2E) predicts roughly a 1 % end-to-end gain at 2 GPUs; the observed −0.8 % is consistent in direction and magnitude. Larger Ulysses sizes are expected to see a similar ~9 % reduction in NCCL kernel time.

Test Coverage

No new tests — change is a data-routing rewrite of an existing collective, covered by existing LTX-2 generate-path and tests/unittest/_torch/visual_gen/... pipeline tests. No-op for single-GPU / non-Ulysses configurations: _shard_transformer_args is not called and both fields hold the same un-sharded tensor.

PR Checklist

  • PR description clearly explains what and why.

  • PR follows TRT-LLM CODING GUIDELINES.

  • Test cases are provided for new code paths.

  • Any new dependencies have been scanned for license and vulnerabilities.

  • CODEOWNERS updated if ownership changes.

  • Documentation updated as needed.

  • Update tava architecture diagram if significant design change.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

luyiyun1021 requested a review from a team as a code owner on May 1, 2026 14:48
coderabbitai (bot) commented on May 1, 2026

📝 Walkthrough

Added a new optional cross_positional_embeddings_full field to the TransformerArgs dataclass to support precomputed positional embeddings. Updated the LTX2 transformer to use this field for AV cross-attention RoPE handling in sharded configurations, replacing the previous gather-based approach.

Changes

| Cohort / File(s) | Summary |
|---|---|
| Cross-positional embeddings enhancement: tensorrt_llm/_torch/visual_gen/models/ltx2/ltx2_core/transformer_args.py, tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py | Added an optional cross_positional_embeddings_full attribute to the TransformerArgs dataclass. Updated _shard_transformer_args to thread this field through transformer args, and modified AV cross-attention RoPE computation to use precomputed embeddings instead of gathering from sharded embeddings. |

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

🚥 Pre-merge checks: ✅ 5 passed

| Check name | Status | Explanation |
|---|---|---|
| Docstring Coverage | ✅ Passed | Docstring coverage is 100.00 %, which is sufficient (required threshold: 80.00 %). |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Description check | ✅ Passed | The description includes all required sections: clear explanation of the issue and solution with design rationale, validation data with performance metrics, test coverage justification, and completed PR checklist with all items verified. |
| Title check | ✅ Passed | The title accurately describes the main optimization: dropping redundant all-gather operations for positional embeddings in audio-video cross-attention within the LTX2 model under sequence parallelism. |

Comment @coderabbitai help to get the list of available commands and usage tips.

coderabbitai (bot) left a comment

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py (1)

1174-1184: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Preserve existing full cross-PE when sharding args (Line 1183).

Line 1183 always copies from args.cross_positional_embeddings. If this helper is ever invoked on already-sharded args, the “full” PE can be overwritten by a shard.

Suggested fix
         return replace(
             args,
             x=args.x[:, s:e],
             timesteps=_shard(args.timesteps),
             embedded_timestep=_shard(args.embedded_timestep),
             positional_embeddings=_shard_pe(args.positional_embeddings),
             cross_positional_embeddings=_shard_pe(args.cross_positional_embeddings),
             cross_scale_shift_timestep=_shard(args.cross_scale_shift_timestep),
             cross_gate_timestep=_shard(args.cross_gate_timestep),
-            cross_positional_embeddings_full=args.cross_positional_embeddings,
+            cross_positional_embeddings_full=(
+                args.cross_positional_embeddings_full
+                if args.cross_positional_embeddings_full is not None
+                else args.cross_positional_embeddings
+            ),
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py` around lines
1174 - 1184, The helper currently always sets cross_positional_embeddings_full
to args.cross_positional_embeddings which can overwrite an existing full
(unsharded) value when called on already-sharded args; change the assignment in
the replace call to prefer the existing full value if present (e.g. use
getattr(args, "cross_positional_embeddings_full",
args.cross_positional_embeddings) or equivalent) so that
cross_positional_embeddings_full is preserved when already set, otherwise fall
back to args.cross_positional_embeddings.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 4d7fb4d1-9137-4f19-b2d3-f0a9e644dbf3

📥 Commits

Reviewing files that changed from the base of the PR and between 0b44a58 and e29b559.

📒 Files selected for processing (2)
  • tensorrt_llm/_torch/visual_gen/models/ltx2/ltx2_core/transformer_args.py
  • tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py

luyiyun1021 force-pushed the dev/ltx2-cross-pe-full branch from e29b559 to 925a056 on May 1, 2026 15:33
luyiyun1021 changed the title from "[None][perf] ltx2: pipe full cross_pe via TransformerArgs to skip AV cross-attn pe gather" to "[None][perf] ltx2: split cross_pe into _local/_full to skip AV cross-attn pe gather" on May 1, 2026
luyiyun1021 (Collaborator, Author) commented:

/bot run --disable-fail-fast

tensorrt-cicd (Collaborator) commented:

PR_Github #46500 [ run ] triggered by Bot. Commit: 925a056 Link to invocation

luyiyun1021 force-pushed the dev/ltx2-cross-pe-full branch from 925a056 to 17c061c on May 1, 2026 15:48
luyiyun1021 (Collaborator, Author) commented:

/bot run --disable-fail-fast

tensorrt-cicd (Collaborator) commented:

PR_Github #46503 [ run ] triggered by Bot. Commit: 17c061c Link to invocation

luyiyun1021 force-pushed the dev/ltx2-cross-pe-full branch from 17c061c to 2976160 on May 1, 2026 16:10
luyiyun1021 changed the title from "[None][perf] ltx2: split cross_pe into _local/_full to skip AV cross-attn pe gather" to "[TRTLLM-12432][perf] ltx2: split cross_pe into _local/_full to skip AV cross-attn pe gather" on May 1, 2026
luyiyun1021 (Collaborator, Author) commented:

/bot run --disable-fail-fast

…ttention

Replaces TransformerArgs.cross_positional_embeddings with two precomputed
fields:

  cross_positional_embeddings_local : sharded along seq (Q-side rope)
  cross_positional_embeddings_full  : un-sharded (K-side rope on gathered K)

Both fields are populated once at construction time in
MultiModalTransformerArgsPreprocessor.prepare from the precomputed full pe.
LTXModel._shard_transformer_args slices *_local* from *_full* for Ulysses
scatter; *_full* is passed through. AV cross-attn K-side now reads *_full*
directly, so the cos/sin all-gather scheduled by _sp_gather_pe is
unconditionally absent from the captured cuda graph.

Removes the now-unused _sp_gather_pe method.

Lossless: cos/sin are deterministic, precomputed once per generate(); the
previous shard->gather->concat round-trip was a memcpy-only no-op that this
patch elides. Verified by bit-identical raw-video sha256 across baseline
and patched runs (40-step, 2-GPU Ulysses, seed=42).

Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
luyiyun1021 force-pushed the dev/ltx2-cross-pe-full branch from 2976160 to 19e3ccb on May 1, 2026 16:12
luyiyun1021 changed the title from "[TRTLLM-12432][perf] ltx2: split cross_pe into _local/_full to skip AV cross-attn pe gather" to "[TRTLLM-12432][perf] ltx2: drop redundant pe all-gather in AV cross-attention" on May 1, 2026
luyiyun1021 (Collaborator, Author) commented:

/bot run --disable-fail-fast

tensorrt-cicd (Collaborator) commented:

PR_Github #46505 [ run ] triggered by Bot. Commit: 19e3ccb Link to invocation

tensorrt-cicd (Collaborator) commented:

PR_Github #46506 [ run ] triggered by Bot. Commit: 19e3ccb Link to invocation

tensorrt-cicd (Collaborator) commented:

PR_Github #46506 [ run ] completed with state SUCCESS. Commit: 19e3ccb
/LLM/main/L0_MergeRequest_PR pipeline #36567 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

