[feat] ULTRA-HSTU Mixture of Transducers (MoT)#492

Merged
tiankongdeguiji merged 26 commits into alibaba:master from tiankongdeguiji:feat/ultra-hstu-mot
May 1, 2026

Conversation


@tiankongdeguiji tiankongdeguiji commented Apr 29, 2026

Summary

Adds UltraHSTU as a new top-level model type implementing the four ULTRA-HSTU optimizations from Ding et al., 2026 (arXiv:2602.16986) on top of DlrmHSTU:

  • Semi-Local Attention (SLA): HSTU.stu.sla_k1 (local window) + sla_k2 (global prefix); requires Kernel.CUTLASS or Kernel.PYTORCH.
  • Mid-stack Attention Truncation: HSTU.attn_truncation_split_layer + attn_truncation_tail_len; drops the UIH prefix after the split layer, preserving contextual + targets.
  • Mixture of Transducers (MoT): repeated HSTU hstu in UltraHSTU runs N parallel transducers, one per UIH channel; per-candidate outputs concatenate on the embedding dim before the multi-task tower. Candidate and contextual groups stay shared. Channel selection is via a single new optional string HSTU.name field that replaces the default uih prefix on UIH-side feature_group keys (e.g. name="uih_click" reads uih_click.sequence / uih_click_action.sequence / ...); an empty name preserves all current DlrmHSTU configs (see the key-routing sketch below).
  • Selective rematerialization: already on STU via recompute_normed_x_in_backward / recompute_uvqk_in_backward; surfaced and documented for ULTRA-HSTU.

Single-channel UltraHSTU (one hstu entry, empty name) is behaviorally equivalent to DlrmHSTU; SLA and Attention Truncation are independently usable in that mode too.
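
A minimal sketch of the name-routing contract described above (the uih_keys helper is hypothetical; only the key scheme comes from this PR):

    def uih_keys(name: str = "") -> dict:
        # An empty name keeps the default "uih" prefix; a non-empty name
        # replaces it on every UIH-side feature_group key.
        prefix = name if name else "uih"
        return {
            "sequence": f"{prefix}.sequence",
            "sequence_length": f"{prefix}.sequence_length",
            "action": f"{prefix}_action.sequence",
            "watchtime": f"{prefix}_watchtime.sequence",
            "timestamp": f"{prefix}_timestamp.sequence",
        }

    assert uih_keys()["sequence"] == "uih.sequence"
    assert uih_keys("uih_click")["action"] == "uih_click_action.sequence"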

Implementation notes

  • Proto: one optional string name on HSTU; new UltraHSTU model message in multi_task_rank.proto mirroring DlrmHSTU's top-level fields with repeated HSTU hstu; UltraHSTU ultra_hstu = 207 in model.proto's oneof.
  • Runtime: UltraHSTU subclasses DlrmHSTU via two small extracted hooks (_build_transducer, _stu_embedding_dim); the wrapper _HSTUTransducerStack is bypassed for single-channel configs so AOTI / JIT trace shapes are unchanged. STULayer exposes attn_func_static_sig (a colon-separated string of init-time SLA constants) as a @property; STUStack precomputes a per-layer prev_attn_func_sig list at construction so the forward loop has zero sig bookkeeping and is fx-trace-safe by construction.
  • SLA + truncation infra hardened for fx symbolic-trace + AOT-export end-to-end (the cutlass kuairand integration test was the first to exercise them under export). Fixes:
    • build_sla_func_tensor: unconditional int32 cast (no Proxy.dtype check); torch.diff + repeat_interleave instead of searchsorted on a slice (avoids Inductor SliceView.get_stride() NotImplementedError); a short sketch follows this list.
    • compute_stu_truncation_plan returns precomputed total-int fields on STUTruncationPlan (total_dropped / total_kept / total_prefix / total_rest) so apply_stu_truncation_plan can pass them as total_len_* into split_2D_jagged and skip its .item() fallback under FakeTensor.
  • Tests + config + docs:
    • Renamed tzrec/tests/configs/dlrm_ultra_hstu_cutlass_kuairand_1k.config → ultra_hstu_cutlass_kuairand_1k.config; converted to ultra_hstu with two channels (uih_click, uih_view); added test_rank_ultra_hstu_cutlass_train_eval_export in rank_integration_test.py.
    • New kuairand-mot-1k parquet (split UIH by action_weight bits at preprocess time) wired into scripts/ci/ci_data.sh (oss-accelerate); existing kuairand-1k URLs also moved to oss-accelerate.
    • tzrec/models/ultra_hstu_test.py parametrizes over (single, multi_uniform, multi_hetero) channel topologies × NORMAL / FX_TRACE / JIT_SCRIPT / AOT_INDUCTOR × PYTORCH / TRITON.
    • tzrec/ops/hstu_attention_utils_test.py parametrizes every numerical-correctness test over NORMAL + FX_TRACE for direct fx-trace regression coverage.
    • docs/source/models/ultra_hstu.md (NEW) + README + generative.rst toctree.
    • tzrec version bump 1.1.16 → 1.1.17.
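
The torch.diff + repeat_interleave rewrite above computes per-position batch ids without slicing the offsets tensor; a runnable toy sketch (values illustrative):

    import torch

    # seq_offsets [0, 3, 5] -> lengths [3, 2] -> batch_ids [0, 0, 0, 1, 1]
    seq_offsets = torch.tensor([0, 3, 5], dtype=torch.int32)
    seq_lengths = torch.diff(seq_offsets)
    batch_ids = torch.repeat_interleave(
        torch.arange(seq_lengths.numel()), seq_lengths.long()
    )
    assert batch_ids.tolist() == [0, 0, 0, 1, 1]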

Test plan

  • python -m tzrec.modules.gr.preprocessors_test — back-compat (empty name defaults to existing keys).
  • python -m tzrec.modules.gr.hstu_transducer_test — back-compat.
  • python -m tzrec.modules.gr.stu_test — SLA cache + truncation back-compat.
  • python -m tzrec.ops.hstu_attention_utils_test — 27 parametrized cases (eager + fx-trace).
  • python -m tzrec.models.dlrm_hstu_test — DlrmHSTU refactor is behavior-preserving (full hypothesis sweep).
  • python -m tzrec.models.ultra_hstu_test — 3 channel topologies × hypothesis sweep across all graph types and kernels.
  • python -m unittest tzrec.tests.rank_integration_test.RankIntegrationTest.test_rank_ultra_hstu_cutlass_train_eval_export — end-to-end train → eval → export → predict on kuairand-mot-1k with two MoT channels under CUTLASS + AOT export.

🤖 Generated with Claude Code

tiankongdeguiji and others added 7 commits April 29, 2026 16:30
Add an optional `name` parameter to ContextualInterleavePreprocessor and
thread it through HSTUTransducer.__init__ -> create_input_preprocessor.
When `name == ""` (default) the four UIH-side keys
("uih.sequence", "uih.sequence_length", "uih_action.sequence",
"uih_watchtime.sequence", "uih_timestamp.sequence") are unchanged, so
all existing DlrmHSTU configs and tests pass without modification.
When `name != ""` (e.g. "consumption") the channel name *replaces* the
default `uih` prefix, so the preprocessor reads
"consumption.sequence", "consumption_action.sequence",
"consumption_watchtime.sequence", "consumption_timestamp.sequence".

This is the substrate for ULTRA-HSTU's Mixture of Transducers, where
N parallel HSTUTransducer stacks each consume a disjoint UIH channel
while sharing the candidate stream and contextual features.

Candidate-side keys ("candidate.sequence", "candidate.sequence_length",
"candidate_timestamp.sequence") and the contextual key (already
configurable via contextual_group_name) are intentionally unchanged.

Allowlist "MoT" in codespell to avoid the false positive on the MoT
acronym in docstrings.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ooks

Factor the transducer construction and the STU embedding dim accessor
out of __init__ into two overridable hooks, and split the post-assert
init body into _init_after_assert(). Subclasses (UltraHSTU) reuse the
same scaffolding while assigning their own model-type assertion and
constructing alternative transducer wrappers (e.g. a stack of MoT
channels with concatenated outputs).

No behavior change for DlrmHSTU; existing test_dlrm_hstu and
test_dlrm_hstu_task_weight pass unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add an optional `name` field to the HSTU proto message. When set, the
HSTUTransducer reads UIH-side keys with the channel name *replacing*
the default `uih` prefix (e.g. name="consumption" causes the
preprocessor to read consumption.sequence, consumption_action.sequence,
consumption_watchtime.sequence, consumption_timestamp.sequence). Empty
(default) preserves the existing uih.sequence, uih_action.sequence,
etc. lookups, so all current DlrmHSTU configs remain valid.

This field will be set per channel in the upcoming UltraHSTU model,
which holds `repeated HSTU hstu` (one entry per Mixture-of-Transducers
channel) and concatenates per-candidate outputs.

The wiring through HSTUTransducer.__init__ -> create_input_preprocessor
-> ContextualInterleavePreprocessor was added in the previous refactor
commit, so DlrmHSTU configs can opt-in to a name without any further
code change.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add the UltraHSTU model proto with `repeated HSTU hstu` (one entry per
Mixture-of-Transducers channel) and the same auxiliary fields as
DlrmHSTU (fusion_mtl_tower, max_seq_len, item_embedding_hidden_dim,
enable_global_average_loss, sequence_timestamp_is_ascending,
concat_contextual_features).

When `hstu` has a single entry the model behaves like DlrmHSTU
(modulo the optional `name` routing on the channel).  When `hstu`
has >= 2 entries every channel must set a unique non-empty `name`,
and the corresponding feature groups (uih-side replaced by `<name>`,
`<name>_action`, `<name>_watchtime`, `<name>_timestamp`) must be
defined.  The candidate-side and contextual feature groups are
shared across every channel.

Wire the new model type into the `oneof model` in model.proto with
field number 207 (206 is taken by pepnet).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add the UltraHSTU model class as a sibling to DlrmHSTU. UltraHSTU
holds N parallel HSTUTransducer stacks (one per channel listed in
model_config.ultra_hstu.hstu) and concatenates their per-candidate
outputs along the embedding dim before they reach the multi-task
tower.  Each channel's UIH-side feature group is named after the
channel (e.g. "consumption", "consumption_action", ...); the
candidate-side group and the contextual group are shared.

Implementation reuses the DlrmHSTU scaffolding via the
_build_transducer / _stu_embedding_dim hooks added in the previous
refactor commit.  Single-channel UltraHSTU configs degrade cleanly to
a bare HSTUTransducer (no _HSTUTransducerStack wrapper) so AOTI export
and JIT trace shapes are unchanged.

The auto-import in tzrec/__init__.py picks the model up via the
metaclass-based registry on BaseModel; no explicit __init__.py wiring
is required.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add tzrec/models/ultra_hstu_test.py mirroring dlrm_hstu_test: a
hypothesis-driven property test that builds a 2-channel UltraHSTU
("a", "b"), runs prediction across NORMAL / FX_TRACE / JIT_SCRIPT /
AOT_INDUCTOR graph types and PYTORCH / TRITON kernels, and asserts the
candidate-count shape on each task's logits/probs.  Each channel has
its own per-channel UIH-side feature groups (`a` / `a_action` /
`a_watchtime` / `a_timestamp` and the same for `b`) while `candidate`,
`candidate_timestamp`, and `contextual` are shared, exercising the
name-based key routing through the preprocessor.

Add tzrec/tests/configs/ultra_hstu_kuairand_1k.config: a kuairand-1k
demo cloned from dlrm_ultra_hstu_cutlass_kuairand_1k.config with the
dlrm_hstu block replaced by ultra_hstu and two HSTU sub-configs
(name="a", name="b"), each with the same SLA + truncation settings as
the source.  In this minimal demo both channels reference the same
underlying uih_seq features; in production each channel would consume
distinct sources (e.g. consumption vs engagement).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Rename _init_after_assert -> _init in DlrmHSTU and the override site
  in UltraHSTU.  Same intent, less noise.
- Use uih_click as the docstring/proto-comment channel-name example
  everywhere (preprocessors.py, hstu_transducer wiring, module.proto,
  ultra_hstu.py); the previous "consumption" wasn't aligned with the
  uih_<event> convention.
- Move UltraHSTU into the natural slot UltraHSTU ultra_hstu = 206 and
  bump PEPNet to pepnet = 207 so the ULTRA model sits next to its
  DlrmHSTU dlrm_hstu = 205 sibling.
- Drop the separate tzrec/tests/configs/ultra_hstu_kuairand_1k.config
  and convert the existing dlrm_ultra_hstu_cutlass_kuairand_1k.config
  to use the ultra_hstu model with two channels (uih_click, uih_view).
  The two channels reference the same underlying uih_seq features in
  this minimal demo; in production each channel would consume distinct
  sources.  Added test_rank_ultra_hstu_cutlass_train_eval_export in
  rank_integration_test.py to exercise the converted config end-to-end
  (train -> eval -> export -> predict), mirroring
  test_rank_dlrm_hstu_cutlass_train_eval_export.
- Simplify _HSTUTransducerStack.forward: every sub-transducer is built
  with return_full_embeddings=False, so the second tuple element is
  always None.  Drop the dead "if all(f is not None for f in
  full_list)" branch.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@tiankongdeguiji tiankongdeguiji added the claude-review (Let Claude Review) label Apr 29, 2026
@github-actions github-actions bot removed the claude-review (Let Claude Review) label Apr 29, 2026
Comment thread: tzrec/protos/model.proto (Outdated)

    RocketLaunching rocket_launching = 500;
-   PEPNet pepnet = 206;
+   PEPNet pepnet = 207;

Renumbering PEPNet from 206 → 207 and reusing tag 206 for UltraHSTU is wire-format-breaking. Any binary-serialized ModelConfig that previously stored pepnet at tag 206 will now silently parse as ultra_hstu (or be discarded as a type mismatch). In-tree configs are loaded via text_format/json_format (name-keyed) so the in-repo path is unaffected, but this is bad proto hygiene and dangerous for any downstream tooling, exported pipeline.config bundles, or external consumers.

Suggested fix: keep PEPNet pepnet = 206 at its original tag and assign UltraHSTU a fresh, never-published tag (e.g. 208). Tag numbers should be treated as immutable once published.

Comment on lines +88 to +95
    channels = self._model_config.hstu
    assert len(channels) >= 1, "UltraHSTU requires at least 1 hstu channel"
    if len(channels) >= 2:
        names = [c.name for c in channels]
        assert all(names) and len(set(names)) == len(names), (
            "When UltraHSTU has >= 2 channels every channel must set a "
            f"unique non-empty `name`, got {names!r}"
        )

Two related validation gaps in this block:

  1. Single-channel mode skips all name validation. A user with one channel can set name="candidate" (or any reserved feature-group name); _build_transducer will then call embedding_group.group_total_dim("candidate") — which exists, since candidate is the shared candidate-side group — and the UIH path will silently consume candidate-side embeddings without raising. Worth checking that single-channel name (when non-empty) is also non-reserved.

  2. No charset/whitespace check on name. The proto field is an unconstrained optional string. A name containing . would produce keys like "foo.bar.sequence"; EmbeddingGroupImpl.group_total_dim splits on . and would resolve via "foo", mapping the user's intent to an unrelated group. Whitespace or look-alike unicode would mismatch the dict key from the feature-group config, producing a confusing KeyError deep in forward.

Suggested: validate name once (non-reserved, no ., matches ^[A-Za-z_][A-Za-z0-9_]*$) for every channel regardless of count.
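
A sketch of the suggested validation (helper and constant names are hypothetical; the regex and reserved set come from this comment):

    import re

    _RESERVED = {"uih", "candidate", "contextual"}
    _NAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")

    def validate_channel_name(name: str) -> None:
        # Reject names that would alias shared groups or derail the
        # '.'-based EmbeddingGroup key lookup.
        if name in _RESERVED:
            raise ValueError(f"channel name {name!r} is reserved")
        if not _NAME_RE.match(name):
            raise ValueError(
                f"channel name {name!r} must match {_NAME_RE.pattern}"
            )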

Comment on lines +98 to +120
    def _build_transducer(
        self, contextual_feature_dim: int, max_contextual_seq_len: int
    ) -> torch.nn.Module:
        transducers: List[HSTUTransducer] = []
        for ch in self._model_config.hstu:
            uih_group = ch.name if ch.name else "uih"
            transducers.append(
                HSTUTransducer(
                    uih_embedding_dim=self.embedding_group.group_total_dim(uih_group),
                    target_embedding_dim=self.embedding_group.group_total_dim(
                        "candidate"
                    ),
                    contextual_feature_dim=contextual_feature_dim,
                    max_contextual_seq_len=max_contextual_seq_len,
                    contextual_group_name=self._contextual_group_name,
                    scaling_seqlen=self._model_config.max_seq_len,
                    **config_to_kwargs(ch),
                    return_full_embeddings=False,
                )
            )
        if len(transducers) == 1:
            return transducers[0]
        return _HSTUTransducerStack(transducers)

Two small concerns on the stack construction path:

  1. Unhelpful failure on missing groups. If a channel sets name="uih_click" but the user forgets to define one of the four required groups (uih_click, uih_click_action, uih_click_watchtime, uih_click_timestamp), the missing UIH-group surfaces here as a bare KeyError from group_total_dim, while a missing _action/_watchtime/_timestamp group only surfaces at runtime inside ContextualInterleavePreprocessor.forward. Consider asserting the four groups exist up-front (skipping _action/_watchtime when those encoders aren't configured) so misconfigs fail with a clear message at construction time (see the sketch after this comment).

  2. _HSTUTransducerStack always discards the second tuple element. The stack hardcodes return_full_embeddings=False per channel and returns None as the second slot, but if a future change ever flips return_full_embeddings=True on a sub-transducer (e.g. for an aux-loss head), cand[1] is silently dropped. Worth a one-line assert not transducer._return_full_embeddings in the stack constructor and a note in its docstring.
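
A sketch of the up-front check from point 1 (the has_group accessor is an assumption about the EmbeddingGroup API; substitute whatever membership test exists):

    def check_channel_groups(embedding_group, name: str) -> None:
        # Fail at construction time with the full list of missing groups,
        # instead of a bare KeyError deep inside forward.
        required = [
            name,
            f"{name}_action",
            f"{name}_watchtime",
            f"{name}_timestamp",
        ]
        missing = [g for g in required if not embedding_group.has_group(g)]
        if missing:
            raise ValueError(
                f"channel {name!r} is missing feature groups: {missing}"
            )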

Comment thread: tzrec/models/ultra_hstu_test.py (Outdated)
# Channel names. The MoT routing in ContextualInterleavePreprocessor
# substitutes the channel name for the default "uih" prefix on each
# UIH-side key (e.g. "a.sequence", "a_action.sequence", ...).
_CHANNELS = ("a", "b")

Two coverage gaps from hardcoding _CHANNELS = ("a", "b") and embedding_dim=512 for every channel:

  1. Single-channel path is never exercised. The branch at ultra_hstu.py:118-119 (if len(transducers) == 1: return transducers[0]) — the documented "behaves like DlrmHSTU" mode that bypasses _HSTUTransducerStack entirely — is the easiest path to silently break in a future refactor (e.g. shape regression: single tensor vs. concatenated). Recommend a separate test case (or parametrize over (("a",), ("a","b"))) and assert isinstance(model._hstu_transducer, HSTUTransducer) for the single-channel case.

  2. Heterogeneous channel dims are not exercised. With both channels at embedding_dim=512, sum(c.stu.embedding_dim for c in channels) is mathematically indistinguishable from N * dim or max(...) — a buggy _stu_embedding_dim would still pass. Use distinct dims (e.g. 256 and 512) in at least one case so the sum is the only correct answer.

Also: there are no negative tests for the __init__ validation assertions (len(channels) >= 1; >= 2 channels need unique non-empty names) — easy to add and they protect user-facing error paths.
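
A sketch of one such negative test (reuses this file's _build_model helper; remaining kwargs elided):

    def test_duplicate_channel_names_rejected(self):
        # Two channels sharing a name must trip the uniqueness assertion
        # in UltraHSTU.__init__.
        with self.assertRaisesRegex(AssertionError, "unique non-empty"):
            _build_model(channel_specs=[("a", 512), ("a", 512)])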

Comment thread: tzrec/models/ultra_hstu_test.py (Outdated)
Comment on lines +106 to +114
            embedding_name=f"{channel}_video_id_emb",
            num_buckets=1000,
        )
    ),
    feature_pb2.SeqFeatureConfig(
        id_feature=feature_pb2.IdFeature(
            feature_name="video_cat",
            embedding_dim=16,
            embedding_name=f"{channel}_video_cat_emb",

Per-channel embedding_name=f"{channel}_..._emb" gives every channel its own embedding table for the same physical feature (video_id, video_cat). Users will copy this pattern, and in production it multiplies sparse-parameter count, TBE forward/backward work, and all-to-all communication volume by N — typically the dominant cost in a sharded recsys deployment. The integration config (dlrm_ultra_hstu_cutlass_kuairand_1k.config) implicitly shares table names because both channels point at the same uih_seq__video_id, so the test config and the integration config disagree on the recommended pattern.

Suggest making the test share embedding_name across channels by default to set the right example, and add a class/proto docstring note recommending shared embedding_name unless per-channel tables are explicitly desired.
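
A sketch of the shared-table pattern recommended above (import path and field values are illustrative):

    from tzrec.protos import feature_pb2

    # The same embedding_name in every channel dedupes to one table in
    # EmbeddingGroup; per-channel tables become an explicit opt-in.
    shared_video_id = feature_pb2.SeqFeatureConfig(
        id_feature=feature_pb2.IdFeature(
            feature_name="video_id",
            embedding_dim=32,
            embedding_name="video_id_emb",
            num_buckets=1000,
        )
    )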

Comment on lines +78 to +101
message UltraHSTU {
  // ULTRA-HSTU model with Mixture of Transducers. Holds one HSTU
  // sub-config per channel; per-candidate outputs are concatenated
  // along the embedding dim. When `hstu` has a single entry the
  // model behaves like DlrmHSTU (modulo the optional name routing on
  // the channel). When `hstu` has >= 2 entries every channel must
  // set a unique non-empty `name`, and the corresponding feature
  // groups (uih-side replaced by `<name>`, `<name>_action`,
  // `<name>_watchtime`, `<name>_timestamp`) must be defined. The
  // candidate-side and contextual feature groups are shared across
  // every channel.
  repeated HSTU hstu = 1;
  // multi task tower config
  required FusionMTLTower fusion_mtl_tower = 2;
  // max sequence length
  required uint32 max_seq_len = 3;
  // item embedding mlp hidden dimension
  optional uint32 item_embedding_hidden_dim = 4 [default = 512];
  // enables loss averaging computation globally across all ranks (total rank)
  // instead of locally (local rank).
  optional bool enable_global_average_loss = 5 [default = true];
  // timestamp of sequence is ascending or descending
  optional bool sequence_timestamp_is_ascending = 6 [default = true];
  // concat all contextual features on channel dim as one token

Fields 2-7 of UltraHSTU are an exact copy of DlrmHSTU fields 2-7; only field 1 (hstu repeated vs single) differs. The Python _init() already references all of these via self._model_config.X and works for both messages purely by field-name duck typing — that contract silently breaks the day someone adds a field to one and forgets the other.

Lower-effort options: add a // keep in sync with DlrmHSTU 2-7 TODO comment to both messages. Higher-effort: extract a shared sub-message embedded in both, or just make DlrmHSTU.hstu repeated and retire UltraHSTU entirely.

@github-actions

Review Summary

Solid, well-scoped addition. The _init / _build_transducer / _stu_embedding_dim extraction in DlrmHSTU is a clean Template Method seam, the MoT key-routing in ContextualInterleavePreprocessor is small and self-contained, and the test matrix covers NORMAL / FX_TRACE / JIT_SCRIPT / AOT_INDUCTOR × PYTORCH / TRITON. Most findings are robustness / contract hygiene, not correctness.

Highest-impact items (see inline comments for details):

  • tzrec/protos/model.proto:73 — renumbering PEPNet 206→207 to give UltraHSTU tag 206 is wire-format-breaking. Reusing a previously-published tag for a different message is dangerous for any downstream binary-serialized ModelConfig. Prefer keeping PEPNet=206 and giving UltraHSTU a fresh tag (e.g. 208).
  • tzrec/models/ultra_hstu.py channel-name validation — single-channel mode skips name validation entirely; name has no charset/whitespace/reserved-token check, so values like "candidate" or "foo.bar" silently misroute via the embedding-group .split('.') lookup.
  • _HSTUTransducerStack (ultra_hstu.py:41-48) — silently discards the second tuple element from each sub-transducer; one-line invariant assertion would prevent a confusing future regression. Also worth asserting up-front that all four required UIH groups exist for each channel, instead of letting the failure surface deep in forward.
  • Test coverage gaps (ultra_hstu_test.py) — single-channel degradation path, heterogeneous per-channel embedding_dim, and __init__ validation negative tests are not exercised. The hard-coded per-channel embedding_name in the test also models a pattern that would multiply embedding tables N-fold in production; consider sharing embedding_name across channels to match the integration config.
  • multi_task_rank.proto: UltraHSTU fields 2-7 mirror DlrmHSTU exactly; consider a sync-comment or refactor to prevent silent drift.

Performance note (informational, not blocking): _HSTUTransducerStack.forward runs the N transducers in a serial Python for and re-executes the candidate-side preprocessor preamble (including N fx_int_item D2H syncs) per channel. Fine for 2 channels, but the design clearly aims at higher N where this becomes the dominant cost. Worth profiling and potentially hoisting the candidate-side state out of the inner loop or using CUDA streams.

Docs: the deferred docs/source/models/ultra_hstu.md is acknowledged in the PR body. Worth also touching dlrm_hstu.md group-name section and generative.rst / README.md model table when that doc lands.

tiankongdeguiji and others added 14 commits April 30, 2026 11:18
…eddings, slot swap

- ultra_hstu_test: parametrize over (single, multi_uniform, multi_hetero)
  channel topologies via @parameterized.expand on top of the existing
  hypothesis sweep.  Asserts on the single-channel branch returning a
  bare HSTUTransducer (not _HSTUTransducerStack), and on heterogeneous
  STU dims (256 + 512) so that _stu_embedding_dim's `sum` is the only
  correct answer (a buggy `N * dim` or `max(...)` would no longer pass).
  Drop per-channel embedding_name in the test helper so the pattern
  matches the integration config and won't push users toward
  per-channel embedding tables.

- ultra_hstu / module.proto: docstring + proto-comment notes
  recommending shared embedding_name across channels.  EmbeddingGroup
  dedupes by embedding_name, so per-channel tables are an opt-in for
  the unusual case where you specifically want disjoint tables (it
  multiplies sparse-parameter count, TBE forward/backward work, and
  all-to-all volume by the channel count).

- model.proto: pepnet = 206 (restored to its original deployed slot,
  placed adjacent to dlrm_hstu = 205); ultra_hstu = 207.

The kuairand-mot-1k integration test data is generated separately by
experiments/preprocess_kuairand_mot.py (gitignored, run locally) and
wired into ci_data.sh + the cutlass config in a follow-up commit once
the parquet is uploaded to OSS.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…lerate)

- ci_data.sh: switch every kuairand-1k* URL from oss-cn-beijing to
  oss-accelerate, and add wget lines for the new kuairand-mot-1k
  {train,eval} parquet so CI gets distinct per-channel UIH sequences.
  oss-accelerate is the global-acceleration endpoint and works from
  the GitHub-hosted CPU runners.

- .gitignore: cover data/test/kuairand-mot-1k* (the existing
  kuairand-1k* glob does not match the mot-prefixed files).

- dlrm_ultra_hstu_cutlass_kuairand_1k.config: point train/eval paths
  at the new kuairand-mot-1k parquet; replace the single uih_seq
  feature_config with click_seq + view_seq feature_configs (both
  declare the four sub-features and share video_id_emb so
  EmbeddingGroup dedupes the table); rewire the per-channel
  feature_groups (uih_click* -> click_seq__*, uih_view* ->
  view_seq__*) so the integration test now exercises truly distinct
  per-channel UIH streams (action_weight & 1 for clicks,
  action_weight & 64 for long-views) sharing the candidate side and
  contextual features.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ation export path

ultra_hstu_test on CPU-only CI was hitting "Torch not compiled with
CUDA enabled" because @unittest.skipIf(*gpu_unavailable) sat *outside*
@parameterized.expand and so didn't propagate to the generated
test_ultra_hstu_{0_single,1_multi_uniform,2_multi_hetero} methods.
Move skipIf below expand so each parametrized case carries the skip.

The cutlass-kuairand integration test exposed four pre-existing fx /
AOT-export bugs in shared HSTU SLA + truncation infra that no test
had previously exercised end-to-end (PR alibaba#486 added the config but did
not wire it into rank_integration_test).  Each fix is minimal:

1. tzrec/ops/hstu_attention_utils.py:build_sla_func_tensor —
   `if seq_offsets.dtype != torch.int32: ...` evaluates Proxy.dtype
   under fx.  Drop the conditional; `tensor.to(dtype)` is a no-op when
   already-matching, so unconditional cast costs nothing on the fast
   path.

2. tzrec/modules/gr/stu.py:STULayer.forward — the SLA cache
   hit-check `my_sig == prev_attn_func_sig` compared a tuple
   containing `x.size(0)` (a Proxy at trace time) against another
   Proxy-containing tuple, falling into Proxy.__eq__ → bool →
   TraceError.  Skip the cache under is_fx_tracing(); it's a
   runtime-only optimization.

3. tzrec/ops/hstu_attention_utils.py:compute_stu_truncation_plan —
   `int(new_lengths.max().item())` raises `int(Proxy)` under fx.
   Replace with `fx_int_item(new_lengths.max())` (already
   `@torch.fx.wrap`-ed in tzrec.utils.fx_util).  Same trick for the
   new offset totals (total_dropped/kept/prefix/rest) so
   apply_stu_truncation_plan can pass them as static `total_len_*`
   into split_2D_jagged and skip the `.item()` fallback inside the
   triton fake impl.  STUTruncationPlan grows four int fields — see
   the docstring.  hstu_transducer._replay_truncation_state now
   derives `post_truncation_total_uih_len` from
   `plan.total_kept - total_targets` instead of returning None, for
   the same reason.  hstu_transducer_test updated for the new
   keyword and the no-longer-None return.

4. tzrec/ops/hstu_attention_utils.py:build_sla_func_tensor —
   `searchsorted(seq_offsets_i32[1:], pos_global, ...)` failed
   Inductor lowering with `NotImplementedError: SliceView` because a
   slice produced a SliceView IR that searchsorted's
   `_boundaries_helper` can't take stride from.  Replace
   searchsorted with `torch.diff(seq_offsets_i32)` + `repeat_interleave`,
   which avoids both the slice and the searchsorted entirely while
   producing the same per-position batch_ids.

Also fix predict_input_path in test_rank_ultra_hstu_cutlass_train_eval_export
(was pointing at the old kuairand-1k parquet; now correctly uses the
kuairand-mot-1k file the rest of the test trains+evals against).

All of the following pass locally:
  python -m tzrec.models.dlrm_hstu_test                    (back-compat)
  python -m tzrec.models.ultra_hstu_test                   (parametrized)
  python -m tzrec.modules.gr.preprocessors_test            (back-compat)
  python -m tzrec.modules.gr.hstu_transducer_test          (back-compat)
  python -m tzrec.modules.gr.stu_test                      (back-compat)
  python -m unittest tzrec.tests.rank_integration_test.RankIntegrationTest.test_rank_ultra_hstu_cutlass_train_eval_export

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
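
The fx_int_item pattern cited in fix 3 is the standard @torch.fx.wrap idiom; a sketch (the real helper lives in tzrec.utils.fx_util, body assumed):

    import torch.fx

    @torch.fx.wrap
    def fx_int_item(t):
        # Under symbolic tracing the call is recorded as an opaque leaf
        # instead of evaluating int(Proxy) at trace time.
        return int(t.item())
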
…ev-sigs in STUStack

Replaces the is_fx_tracing() short-circuit in 32b6ad9 with the proper
fix.  The runtime sig comparison was failing under fx symbolic tracing
because the 6-tuple cache key embedded `x.size(0)` (a Proxy at trace
time).  The dynamic dim only ever served to invalidate the cache
across the one statically-known event -- mid-stack truncation -- so
encode that signal statically: `STUStack.__init__` precomputes a
per-layer `prev_attn_func_sig` that is `None` at layer 0 and at the
truncation-split layer, and the previous layer's static sig
otherwise.

The cache key is now a comma-separated string encoding
`(sla_k1, sla_k2, contextual_seq_len, num_heads, target_aware)`
exposed as `STULayer.attn_func_static_sig` (read-only @property).
Strings are pure Python primitives and trivially fx-trace-safe.
STULayer.forward's hit-check is `prev_attn_func is not None and
prev_attn_func_sig == self.attn_func_static_sig`.  forward returns a
2-tuple now (no sig in the return); the precomputed list is the
single source of truth for "which sig describes prev_attn_func".

STUStack.forward's loop body is now: optional `truncate_input` ->
call layer with precomputed sig -> carry the new attn_func forward.
No sig bookkeeping at all in the hot path.  At the truncation index
the precomputed sig is already None, so the layer's comparison fails
and it rebuilds with the post-truncation total_q.

Drops the `is_fx_tracing` import added in 32b6ad9; the band-aid is
gone.  Cache invariant preserved: at most two `build_sla_func_tensor`
calls per forward (layer 0 + the truncation-split layer).
stu_test cache tests adapted to the 2-tuple return + property.

Adds parametrized fx-trace coverage in hstu_attention_utils_test:
BuildSlaFuncTensorTraceTest and StuTruncationTraceTest each exercise
NORMAL + FX_TRACE TestGraphTypes on tiny nn.Module wrappers around
build_sla_func_tensor and the compute+apply_stu_truncation_plan pair.
Eager and traced outputs must match.  Targeted regression coverage
for the hstu_attention_utils.py Proxy fixes from 32b6ad9.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
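
A sketch of the precompute described above (the layers / truncation_split_layer names are placeholders; the invariant is from this commit):

    # None at layer 0 and at the truncation-split layer; otherwise the
    # previous layer's static sig.
    prev_sigs = [None]
    for i in range(1, len(layers)):
        prev_sigs.append(
            None
            if i == truncation_split_layer
            else layers[i - 1].attn_func_static_sig
        )
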
Replaces the dedicated BuildSlaFuncTensorTraceTest /
StuTruncationTraceTest classes from 355cd29 with
@parameterized.expand((NORMAL, FX_TRACE)) on the existing tests.
Each numerical-correctness assertion now pins both the eager output
*and* the fx-traceability of the helper in a single test method.

BuildSlaFuncTensorTest._build accepts a graph_type and routes
through a tiny nn.Module wrapper + create_test_module so NORMAL
runs eagerly and FX_TRACE runs the symbolic-traced module.  All 7
existing tests gain the parametrize decorator (14 cases total).
test_int32_offsets_skip_cast still verifies dtype-cast equivalence
between int32 and int64 offsets, both via the wrapper now.

StuTruncationTest.test_matches_reference cross-products the existing
5 input cases with (NORMAL, FX_TRACE) -> 10 cases.
test_replay_on_parallel_jagged gains the parametrize too;
test_validation_raises_on_negative_params stays eager-only since
validation lives in compute_stu_truncation_plan() construction
before any forward.

The wrapper modules carry a static `target_aware` flag so under fx
trace the `if num_targets is not None` branch is selected at
construction time -- a Proxy num_targets would otherwise always
trace the not-None branch and break callers that pass None at run
time.

Total goes from 13 tests -> 27 tests; same assertion strength,
+fx-trace regression coverage.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The three private nn.Module wrappers in hstu_attention_utils_test
(_BuildSlaFuncTensorWrapper, _StuTruncationWrapper,
_ReplayTruncationWrapper) are scaffolding for the parametrized
fx-trace tests; their constructor + forward arg lists already say
what they do.  Drop the class docstrings and inline notes.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
… separator

Tighten the multi-line WHY-comments added across the MoT/SLA/truncation
paths.  The intent (fx-trace safety, Inductor SliceView avoidance,
.item() vs FakeTensor) stays in one or two lines per touchpoint;
removed restated-from-code rationale.

Switch attn_func_static_sig from comma- to colon-separated so the
encoded sig (e.g. "256:32:0:4:1") reads as obviously categorical
rather than a numeric tuple repr.

No behavior change; stu_test (13/OK) and hstu_attention_utils_test
(27/OK) pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Cuts a release that includes UltraHSTU MoT (PR alibaba#492).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- docs/source/models/ultra_hstu.md (NEW): per-section walk through
  the four ULTRA-HSTU optimizations (SLA, mid-stack truncation, MoT,
  selective rematerialization) following the architecture in Ding et
  al., 2026 (arXiv:2602.16986); includes a two-channel kuairand
  config snippet and a field reference for the UltraHSTU /
  HSTU.name / sla_k1+k2 / attn_truncation_* additions.
- README.md: add ULTRA-HSTU row under "Generative Recommendation"
  with description "HSTU with Semi-Local Attention, Attention
  Truncation, and Mixture of Transducers".
- docs/source/models/generative.rst: add ultra_hstu to the toctree.
- Rename tzrec/tests/configs/dlrm_ultra_hstu_cutlass_kuairand_1k.config
  -> ultra_hstu_cutlass_kuairand_1k.config (drop the "dlrm_" prefix
  -- this config now uses the ultra_hstu model proto, not dlrm_hstu).
  rank_integration_test reference updated.
- Bump copyright year on the two new files (ultra_hstu.py,
  ultra_hstu_test.py) from 2025 to 2026.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
ultra_hstu.md gains the conventional 示例 (example) / 参考论文 (reference paper)
sections (data + config + export + paper link), following the pattern in
dlrm_hstu.md and other model docs.  The config link points at the user-uploaded
https://tzrec.oss-cn-beijing.aliyuncs.com/config/models/ultra_hstu_kuairand.config
(URL verified, HTTP 200) and reuses the kuairand-27k dataset.

Drops the now-stale # total_uih_len comment in hstu_transducer_test:
the assertion + variable name already say it.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Match the simpler pattern used in the deepfm/ple/dlrm/dssm docs: just a
config link, with no 数据 (data) / 模型导出 (model export) sub-sections.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Drop the 4-line note about attn_func_static_sig in the return
description; the property docstring already covers it.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@tiankongdeguiji tiankongdeguiji added the claude-review (Let Claude Review) label Apr 30, 2026
@github-actions github-actions bot removed the claude-review (Let Claude Review) label Apr 30, 2026

- **Mixture of Transducers (MoT)** — the model runs N `HSTUTransducer` instances in parallel, one per UIH channel (e.g. clicks, long plays); each channel has its own STU stack and its own SLA / truncation config. Each channel's per-candidate output embeddings are concatenated along the channel dim and fed to a single shared `FusionMTLTower`. **Candidate and contextual are shared**; UIH is split per channel; the channel count is determined by `repeated HSTU hstu`.

- **Selective Rematerialization** — the STU layer's backward pass recomputes two kinds of intermediate tensors on demand to save GPU memory: `recompute_normed_x_in_backward` controls whether the LayerNorm output is stored, and `recompute_uvqk_in_backward` controls whether the (U, V, Q, K) projections are stored. Both default to `true`, which matters most for large models where attention activation memory dominates.

The proto field names cited here don't exist. The actual fields in tzrec/protos/module.proto:249,251 are recompute_normed_x and recompute_uvqk (no _in_backward suffix — that suffix only appears as the internal kwarg in tzrec/ops/hstu_compute.py). A user copy-pasting recompute_normed_x_in_backward: true into a stu { ... } config will get a proto-parse error. Please rename to the actual proto field names.

Comment on lines +84 to +90
    assert len(channels) >= 1, "UltraHSTU requires at least 1 hstu channel"
    if len(channels) >= 2:
        names = [c.name for c in channels]
        assert all(names) and len(set(names)) == len(names), (
            "When UltraHSTU has >= 2 channels every channel must set a "
            f"unique non-empty `name`, got {names!r}"
        )

Two related validation gaps worth closing up-front:

  1. Reserved-name collisions are silently accepted. name="uih", name="candidate", or name="contextual" on a multi-channel config will pass the all(names) and uniqueness check, but the derived UIH-side keys (preprocessors.py:151-154) will then alias the candidate / contextual / default-uih groups. Suggest rejecting reserved names and *_action / *_watchtime / *_timestamp suffixes that would collide with the per-channel naming convention.

  2. Missing feature_group is reported deep, not at construction. A typo in name surfaces as a KeyError in preprocessors.py:373 (grouped_features[f"{self._uih_key}.sequence"]) at forward time, or as whatever embedding_group.group_total_dim(uih_group) raises mid-_build_transducer. Validating up-front that for every channel the four groups (<name>, <name>_action, <name>_watchtime, <name>_timestamp) exist on self.embedding_group and listing the missing keys would make misconfigurations obvious.

Comment on lines +343 to +403
    @unittest.skipIf(*gpu_unavailable)
    @given(
        graph_type=st.sampled_from(
            [
                TestGraphType.NORMAL,
                TestGraphType.FX_TRACE,
                TestGraphType.JIT_SCRIPT,
                TestGraphType.AOT_INDUCTOR,
            ]
        ),
        kernel=st.sampled_from([Kernel.PYTORCH, Kernel.TRITON]),
        contextual_group_type=st.sampled_from(
            [model_pb2.FeatureGroupType.DEEP, model_pb2.FeatureGroupType.SEQUENCE]
        ),
        sequence_timestamp_is_ascending=st.sampled_from([True, False]),
        enable_global_average_loss=st.sampled_from([True, False]),
        concat_contextual_features=st.sampled_from([True, False]),
    )
    @settings(
        verbosity=Verbosity.verbose,
        max_examples=6,
        deadline=None,
    )
    def test_ultra_hstu(
        self,
        case_name,
        channel_specs,
        graph_type,
        kernel,
        contextual_group_type,
        sequence_timestamp_is_ascending,
        enable_global_average_loss,
        concat_contextual_features,
    ) -> None:
        # JIT_SCRIPT only supports the PyTorch kernel today.
        assume(
            (graph_type == TestGraphType.JIT_SCRIPT and kernel == Kernel.PYTORCH)
            or graph_type != TestGraphType.JIT_SCRIPT
        )

        device = torch.device("cuda")
        ultra_hstu = _build_model(
            device=device,
            channel_specs=channel_specs,
            contextual_group_type=contextual_group_type,
            enable_global_average_loss=enable_global_average_loss,
            sequence_timestamp_is_ascending=sequence_timestamp_is_ascending,
            concat_contextual_features=concat_contextual_features,
        )

        # Single-channel UltraHSTU returns a bare HSTUTransducer (no
        # _HSTUTransducerStack wrapper) so the predict() path matches
        # DlrmHSTU's exactly. Multi-channel returns the stack.
        if len(channel_specs) == 1:
            self.assertIsInstance(ultra_hstu._hstu_transducer, HSTUTransducer)
            self.assertNotIsInstance(ultra_hstu._hstu_transducer, _HSTUTransducerStack)
        else:
            self.assertIsInstance(ultra_hstu._hstu_transducer, _HSTUTransducerStack)
            self.assertEqual(
                ultra_hstu._stu_embedding_dim(), sum(d for _, d in channel_specs)
            )

Coverage gaps worth closing:

  1. Three new __init__ error branches are untested (ultra_hstu.py:84-90): zero channels, len>=2 with empty name, len>=2 with duplicate names. All are CPU-only assertion checks — trivial to add and would prevent a regression silently dropping the validations.

  2. GPU-skip hides CPU-only structural assertions. Lines 396-403 (isinstance(..., _HSTUTransducerStack) and _stu_embedding_dim() == sum(d for _, d in channel_specs)) are pure-Python checks that gate the central MoT structural invariant (the only-correct-answer for multi_hetero is sum, not max or N*dim). Hiding them behind @unittest.skipIf(*gpu_unavailable) means CPU CI never validates them. Recommend splitting into a CPU-only structural test plus the GPU forward test.

  3. No single-channel-with-empty-name case — the ("single", [("a", 256)]) parametrize case uses a non-empty name, so the _build_transducer fallback to the "uih" group at ultra_hstu.py:98 is never exercised. This is the documented "painless migration (无痛迁移) from DlrmHSTU" path.

  4. Output is shape-checked only, not value-checked. A two-channel parity test (identical seeds for both channels, then assert the two halves of torch.cat(..., dim=-1) along the embedding dim are equal) would catch a broken concat order or wrong index in _HSTUTransducerStack.forward for negligible cost.
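
A sketch of the parity check from point 4 (setup elided; assumes both channels are built from identical seeds so the concatenated halves must match):

    out, _ = model._hstu_transducer(grouped_features)
    half = out.shape[-1] // 2
    # A broken concat order or wrong index in _HSTUTransducerStack.forward
    # would break this equality.
    torch.testing.assert_close(out[..., :half], out[..., half:])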

Comment on lines +37 to +44
    def forward(
        self, grouped_features: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        cand_list: List[torch.Tensor] = []
        for transducer in self._transducers:
            cand, _ = transducer(grouped_features)
            cand_list.append(cand)
        return torch.cat(cand_list, dim=-1), None

N transducers run serially in a Python for loop on a single CUDA stream — each transducer launches its own preprocessor + STU stack + postprocessor + many small jagged kernels. With per-channel embedding_dim typically 32–128, kernel-launch latency and SM under-occupancy will dominate at small batch sizes, so MoT scaling won't be linear in N.

The user-facing doc at docs/source/models/ultra_hstu.md:11 calls this "runs N HSTUTransducer in model-parallel" (模型并行运行 N 个 HSTUTransducer), which oversells the GPU behavior — eager + AOT-export both serialize today. At minimum, please soften the doc claim. Optionally, on the eager-train path you could record each iteration on a dedicated torch.cuda.Stream and join before torch.cat (skip on AOT-export); a sketch follows this comment.

Also, no shape/dim sanity check before torch.cat(..., dim=-1) — divergent per-candidate row counts across transducers (e.g. a future per-channel truncation behavior change) would silently misalign rows. A cheap precondition assert would prevent a hard-to-debug regression.
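
A sketch of the per-stream variant suggested above (eager-only; omits the record_stream bookkeeping a production version would need for allocator safety):

    import torch

    def forward_streams(transducers, grouped_features):
        current = torch.cuda.current_stream()
        streams = [torch.cuda.Stream() for _ in transducers]
        outs = []
        for t, s in zip(transducers, streams):
            s.wait_stream(current)  # inputs were produced on current
            with torch.cuda.stream(s):
                cand, _ = t(grouped_features)
                outs.append(cand)
        for s in streams:
            current.wait_stream(s)  # join before the concat
        return torch.cat(outs, dim=-1), None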

Comment thread: tzrec/modules/gr/stu.py
Comment on lines +437 to +442
    def attn_func_static_sig(self) -> str:
        """SLA NFUNC cache key (colon-separated, fx-trace-safe)."""
        return (
            f"{self._sla_k1}:{self._sla_k2}:{self._contextual_seq_len}:"
            f"{self._num_heads}:{int(self._target_aware)}"
        )

The cache key omits _max_attn_len and _causal, both of which affect NFUNC mask semantics. It happens to be safe today because STUStack constructs every layer in a stack with identical kwargs (hstu_transducer.py:111), so all layers in a stack share the same (max_attn_len, causal). But two layers that differ only in max_attn_len (a plausible future heterogeneous-layer config) would silently reuse a stale mask via this sig.

Defensive fix — include both in the sig:

    return (
        f"{self._sla_k1}:{self._sla_k2}:{self._contextual_seq_len}:"
        f"{self._max_attn_len}:{int(self._causal)}:"
        f"{self._num_heads}:{int(self._target_aware)}"
    )

@github-actions

Code Review Summary

Five reviewers covered code quality, performance, tests, docs, and security. Posted 5 inline comments on the highest-signal findings; collecting the rest here.

Most actionable (see inline)

  • Doc bug: docs/source/models/ultra_hstu.md:13 cites proto fields recompute_normed_x_in_backward / recompute_uvqk_in_backward — these don't exist; actual proto fields are recompute_normed_x / recompute_uvqk. A copy-paste from the docs into a config will fail to parse.
  • Validation: UltraHSTU.__init__ accepts reserved channel names (uih, candidate, contextual) and defers feature-group existence to a deep KeyError.
  • Tests: Three new error branches in UltraHSTU.__init__ are untested; structural MoT assertions in ultra_hstu_test.py:396-403 are CPU-only checks gated behind gpu_unavailable.
  • Docs vs reality: "runs N HSTUTransducer in model-parallel" (模型并行运行 N 个 HSTUTransducer) oversells GPU scaling — _HSTUTransducerStack runs serially on a single CUDA stream.
  • SLA cache key: STULayer.attn_func_static_sig omits max_attn_len and causal (safe today since all layers in a stack share kwargs, but a foot-gun for future heterogeneous stacks).

Other notes worth scanning

  • arXiv ID 2602.16986 — plausible but the high .16986 suffix for a Feb 2026 submission is on the edge of monthly volume. Worth double-checking the link before publishing.
  • scripts/ci/ci_data.sh — no set -euo pipefail and no checksum verification on the new kuairand-mot-1k parquet. The 32-hex hash in the URL is purely advisory; a partial download will silently produce CI noise.
  • MoT redundant work — under MoT the _combine_embeddings interleave mask (preprocessors.py:255-277) and contextual-prefix concat_2D_jagged (preprocessors.py:301-324) re-compute identical results across channels. Could share across channels if a profile shows them hot.
  • No MoT correctness parity test — current end-to-end test asserts only output shapes. Two channels with identical seeds vs. single-channel double-concat would catch a broken torch.cat(dim=-1) order or wrong index in _HSTUTransducerStack.forward.
  • STU.truncate_input declared @abc.abstractmethod with a docstring asking subclasses to "override and raise NotImplementedError" — that contradicts the abstract decorator. Either drop the decorator (provide a default raise body on the base) or drop the docstring guidance.
  • Selective rematerialization defaults — the new docs claim defaults are true, which matches module.proto:249,251. Just calling out that this is a behavior-affecting default users will likely not notice.

Praise

  • The fx-trace + AOT-export hardening is well-thought-out: STUTruncationPlan precomputes total_dropped / total_kept so split_2D_jagged skips its .item() fallback under FakeTensor; STUStack precomputes _prev_attn_func_sig_per_layer so the forward loop has zero sig bookkeeping. The _HSTUTransducerStack bypass when len == 1 is a nice touch — keeps single-channel UltraHSTU traceable identically to DlrmHSTU.
  • Good defensive choice in build_sla_func_tensor: the torch.diff + repeat_interleave rewrite (avoiding a slice-SliceView Inductor crash) and the clamp(L - T, min=0) to dodge silent NaN attention.
  • Test parametrization across (single, multi_uniform, multi_hetero) × all four TestGraphTypes is thorough on the structural side; the suggestions above are about what's not yet covered, not the breadth.

Both are init-time STU constants that affect attention semantics.
Excluding them was safe today because every layer in a STUStack is
constructed from the same kwargs (hstu_transducer.py:107), but a
future heterogeneous stack with mismatching max_attn_len or causal
would silently reuse a stale `attn_func` from the cache.  Cost is
negligible (one int + one bool added to the colon-separated key).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
eric-gecheng previously approved these changes Apr 30, 2026
Each integration-test pipeline stage runs as a torchrun subprocess
with stdout/stderr piped to ./tmp/<dir>/log_*.txt; the unittest
framework only sees AssertionError(False is not true).  CI workflow
has no upload-artifact step, so the per-stage logs are wiped with
the runner workspace and the real stack trace from any flaky
failure is unrecoverable from gh run view --log.

When run_cmd's subprocess fails terminally, dump the last 200 lines
of the captured log to stdout with a RUNCMD FAILED header so the
trace lands in CI logs directly.  Tail-only to keep CI output
bounded.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
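
A sketch of the tail-dump (function name hypothetical; the 200-line bound and the RUNCMD FAILED header are from this commit):

    def dump_log_tail(log_path: str, n: int = 200) -> None:
        with open(log_path, errors="replace") as f:
            tail = f.readlines()[-n:]
        print(f"===== RUNCMD FAILED: last {n} lines of {log_path} =====")
        print("".join(tail), end="")
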
tiankongdeguiji and others added 3 commits May 1, 2026 10:45
The cutlass_hstu_train_eval_export integration test was failing
deterministically (4 CI runs in a row) with an unrecoverable
AssertionError(False is not true) -- the actual stack trace lived in
torchrun's per-stage log_export.txt that the runner deletes on
cleanup.  Three changes:

1. tzrec/utils/misc_util.py:run_cmd -- bump the diagnostic tail-dump
   from 200 to 500 lines so the full Python traceback (Inductor stack
   + sympy frame + final exception) lands in CI stdout.  200 lines
   only captured the inner Inductor stack and was cut off mid-frame
   above the actual exception text.

2. tzrec/ops/hstu_attention_utils.py:build_sla_func_tensor -- add
   `_emit_check_total_q_positive(total_q)` (an @torch.fx.wrap'd
   helper that emits `torch._check_is_size + torch._check(>0)`).
   Inductor's combine_contiguous_dims for the (nheads, 3, total_q)
   output emits ModularIndexing(idx, total_q, 3) whose sympy
   simplifier raises ZeroDivisionError under AOT compile when
   total_q is dynamic and not provably > 0 (the `Max(1, total_q)`
   in computed strides is the smoking gun).  The check is fx-trace
   safe because the wrap turns it into a call_function leaf, and
   no-op outside torch.compiler.is_compiling() so eager runs pay
   nothing.

3. tzrec/ops/hstu_attention_utils.py:build_sla_func_tensor -- swap
   `repeat_interleave(arange(B), seq_lengths)` for `searchsorted(
   cumsum(seq_lengths), pos_global, right=True).clamp_max(B - 1)`.
   repeat_interleave's Inductor lowering generates additional
   ModularIndexing patterns that compound the same sympy-divisor
   problem; the searchsorted variant uses fewer dynamic-shape ops.
   `clamp_max(B - 1)` is an eager no-op (pos_global < boundaries[-1]
   always) but gives Inductor a provable upper bound for the
   downstream `seq_lengths[batch_ids]` indirect index, so it doesn't
   insert a defensive device-side assert that would fire under AOT
   compile (separately observed locally before the torch._check).
   As a bonus, fixes a subtle off-by-one in the original (pre-PR)
   `searchsorted(seq_offsets[1:], pos)` call: without right=True,
   tokens at exactly seq_offsets[i] were assigned to batch i-1
   instead of batch i.

Verification: ran
  python -m unittest tzrec.tests.rank_integration_test.RankIntegrationTest.test_rank_ultra_hstu_cutlass_train_eval_export
locally with a wiped /tmp/torchinductor_tianyi cache; passes
in 162s.  Pre-fix the same command failed in 125s with
`InductorError: ZeroDivisionError`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The previous fix (eee0f99) added a torch._check(total_q > 0) hint and
swapped repeat_interleave for searchsorted+clamp.  The hint passed
once locally with a fresh inductor cache but failed on rerun and on
CI -- fundamentally because Inductor's combine_contiguous_dims for
the (nheads, 3, total_q) output of `build_sla_func_tensor` emits
`ModularIndexing(idx, total_q, 3)` and sympy's simplifier hits a
ZeroDivisionError on the symbolic divisor regardless of any
torch._check constraint we add (the constraint apparently doesn't
propagate to `combine_contiguous_dims`'s shape-simplification path).

The robust fix is to make Inductor never see the broadcast: wrap
`func_2d.unsqueeze(0).expand(nheads, 3, total_q).contiguous()` in a
custom op (`tzrec::_sla_broadcast_func_to_heads`) that is a black
box to Inductor's fuser.  Inductor passes the (3, total_q) input
through unchanged and the eager-time impl materializes the head
dim outside any compiled region.

Uses the low-level torch.library.define/impl/register_fake API
matching cutlass_hstu_attention.py's convention; the
`@torch.library.custom_op` decorator triggers an AOTI multi-thread
predict deadlock per the comment in that file.

Drop the unused `_emit_check_total_q_positive` helper and update
the docstring -- the returned tensor is now contiguous (head dim
materialized inside the custom op), no longer a stride-0 broadcast
view.

Verification: ran
  python -m unittest tzrec.tests.rank_integration_test.RankIntegrationTest.test_rank_ultra_hstu_cutlass_train_eval_export
3 times in a row locally (fresh cache, then warm cache twice); all
three passed (163s, 146s, 151s -- second run faster due to cache
reuse on the now-correct compiled kernel).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
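
A sketch of the define/impl/register_fake pattern described above (schema simplified; register_fake needs a recent PyTorch, older releases used impl_abstract):

    import torch

    torch.library.define(
        "tzrec::_sla_broadcast_func_to_heads",
        "(Tensor func_2d, int nheads) -> Tensor",
    )

    @torch.library.impl(
        "tzrec::_sla_broadcast_func_to_heads", "CompositeExplicitAutograd"
    )
    def _sla_broadcast(func_2d, nheads):
        # Materialize the head dim eagerly; Inductor treats the op as a
        # black box and never sees the broadcast.
        return func_2d.unsqueeze(0).expand(nheads, *func_2d.shape).contiguous()

    @torch.library.register_fake("tzrec::_sla_broadcast_func_to_heads")
    def _sla_broadcast_fake(func_2d, nheads):
        return func_2d.new_empty((nheads, *func_2d.shape))
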
…ensor

Commit eee0f99 swapped repeat_interleave for searchsorted+clamp under
the (later disproven) hypothesis that repeat_interleave's Inductor
lowering was the source of the ZeroDivisionError.  The actual root
cause was the (nheads, 3, total_q) head broadcast (fixed in cf7c5e5
by wrapping that op in a black-box custom op).  With the broadcast
now opaque to Inductor, repeat_interleave's lowering is no longer
in the failure path and the searchsorted variant just adds two extra
ops (cumsum + clamp_max) to the SLA cache build for no benefit.

Restoring the simpler diff + repeat_interleave form from commit
32b6ad9.  Verified: ran the integration test 3x locally (fresh
inductor cache + warm cache twice); all three passed (166s, 149s,
147s) -- same time profile as the searchsorted version.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@tiankongdeguiji tiankongdeguiji merged commit f2d0116 into alibaba:master May 1, 2026
6 checks passed