[feat] ULTRA-HSTU Mixture of Transducers (MoT) #492
tiankongdeguiji merged 26 commits into alibaba:master
Conversation
Add an optional `name` parameter to ContextualInterleavePreprocessor and
thread it through HSTUTransducer.__init__ -> create_input_preprocessor.
When `name == ""` (default) the four UIH-side keys
("uih.sequence", "uih.sequence_length", "uih_action.sequence",
"uih_watchtime.sequence", "uih_timestamp.sequence") are unchanged, so
all existing DlrmHSTU configs and tests pass without modification.
When `name != ""` (e.g. "consumption") the channel name *replaces* the
default `uih` prefix, so the preprocessor reads
"consumption.sequence", "consumption_action.sequence",
"consumption_watchtime.sequence", "consumption_timestamp.sequence".
This is the substrate for ULTRA-HSTU's Mixture of Transducers, where
N parallel HSTUTransducer stacks each consume a disjoint UIH channel
while sharing the candidate stream and contextual features.
Candidate-side keys ("candidate.sequence", "candidate.sequence_length",
"candidate_timestamp.sequence") and the contextual key (already
configurable via contextual_group_name) are intentionally unchanged.
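The routing rule above can be sketched as a tiny helper (the name `uih_keys` is hypothetical; the real logic lives inside `ContextualInterleavePreprocessor`):

```python
def uih_keys(name: str = "") -> dict:
    """Sketch of the key routing: an empty name keeps the default "uih"
    prefix; a non-empty name replaces it on every UIH-side key."""
    prefix = name if name else "uih"
    return {
        "sequence": f"{prefix}.sequence",
        "sequence_length": f"{prefix}.sequence_length",
        "action": f"{prefix}_action.sequence",
        "watchtime": f"{prefix}_watchtime.sequence",
        "timestamp": f"{prefix}_timestamp.sequence",
    }
```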
Allowlist "MoT" in codespell to avoid the false positive on the MoT
acronym in docstrings.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ooks Factor the transducer construction and the STU embedding dim accessor out of __init__ into two overridable hooks, and split the post-assert init body into _init_after_assert(). Subclasses (UltraHSTU) reuse the same scaffolding while assigning their own model-type assertion and constructing alternative transducer wrappers (e.g. a stack of MoT channels with concatenated outputs). No behavior change for DlrmHSTU; existing test_dlrm_hstu and test_dlrm_hstu_task_weight pass unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add an optional `name` field to the HSTU proto message. When set, the HSTUTransducer reads UIH-side keys with the channel name *replacing* the default `uih` prefix (e.g. name="consumption" causes the preprocessor to read consumption.sequence, consumption_action.sequence, consumption_watchtime.sequence, consumption_timestamp.sequence). Empty (default) preserves the existing uih.sequence, uih_action.sequence, etc. lookups, so all current DlrmHSTU configs remain valid. This field will be set per channel in the upcoming UltraHSTU model, which holds `repeated HSTU hstu` (one entry per Mixture-of-Transducers channel) and concatenates per-candidate outputs. The wiring through HSTUTransducer.__init__ -> create_input_preprocessor -> ContextualInterleavePreprocessor was added in the previous refactor commit, so DlrmHSTU configs can opt-in to a name without any further code change. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add the UltraHSTU model proto with `repeated HSTU hstu` (one entry per Mixture-of-Transducers channel) and the same auxiliary fields as DlrmHSTU (fusion_mtl_tower, max_seq_len, item_embedding_hidden_dim, enable_global_average_loss, sequence_timestamp_is_ascending, concat_contextual_features). When `hstu` has a single entry the model behaves like DlrmHSTU (modulo the optional `name` routing on the channel). When `hstu` has >= 2 entries every channel must set a unique non-empty `name`, and the corresponding feature groups (uih-side replaced by `<name>`, `<name>_action`, `<name>_watchtime`, `<name>_timestamp`) must be defined. The candidate-side and contextual feature groups are shared across every channel. Wire the new model type into the `oneof model` in model.proto with field number 207 (206 is taken by pepnet). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add the UltraHSTU model class as a sibling to DlrmHSTU. UltraHSTU holds N parallel HSTUTransducer stacks (one per channel listed in model_config.ultra_hstu.hstu) and concatenates their per-candidate outputs along the embedding dim before they reach the multi-task tower. Each channel's UIH-side feature group is named after the channel (e.g. "consumption", "consumption_action", ...); the candidate-side group and the contextual group are shared. Implementation reuses the DlrmHSTU scaffolding via the _build_transducer / _stu_embedding_dim hooks added in the previous refactor commit. Single-channel UltraHSTU configs degrade cleanly to a bare HSTUTransducer (no _HSTUTransducerStack wrapper) so AOTI export and JIT trace shapes are unchanged. The auto-import in tzrec/__init__.py picks the model up via the metaclass-based registry on BaseModel; no explicit __init__.py wiring is required. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
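The concatenation step described above can be mirrored in plain Python (hypothetical shapes; the real code concatenates per-candidate channel tensors along the embedding dim):

```python
# Each channel emits one embedding per candidate; channels are fused by
# concatenating along the embedding dim before the multi-task tower.
num_candidates = 3
channel_dims = [256, 512]  # e.g. two channels with different STU dims
channel_outputs = [
    [[0.0] * dim for _ in range(num_candidates)] for dim in channel_dims
]
# Fused output: per candidate, the concatenation of every channel's row.
fused = [
    sum((chan[i] for chan in channel_outputs), []) for i in range(num_candidates)
]
assert len(fused) == num_candidates
assert len(fused[0]) == sum(channel_dims)  # 256 + 512 = 768
```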
Add tzrec/models/ultra_hstu_test.py mirroring dlrm_hstu_test: a
hypothesis-driven property test that builds a 2-channel UltraHSTU
("a", "b"), runs prediction across NORMAL / FX_TRACE / JIT_SCRIPT /
AOT_INDUCTOR graph types and PYTORCH / TRITON kernels, and asserts the
candidate-count shape on each task's logits/probs. Each channel has
its own per-channel UIH-side feature groups (`a` / `a_action` /
`a_watchtime` / `a_timestamp` and the same for `b`) while `candidate`,
`candidate_timestamp`, and `contextual` are shared, exercising the
name-based key routing through the preprocessor.
Add tzrec/tests/configs/ultra_hstu_kuairand_1k.config: a kuairand-1k
demo cloned from dlrm_ultra_hstu_cutlass_kuairand_1k.config with the
dlrm_hstu block replaced by ultra_hstu and two HSTU sub-configs
(name="a", name="b"), each with the same SLA + truncation settings as
the source. In this minimal demo both channels reference the same
underlying uih_seq features; in production each channel would consume
distinct sources (e.g. consumption vs engagement).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Rename _init_after_assert -> _init in DlrmHSTU and the override site in UltraHSTU. Same intent, less noise. - Use uih_click as the docstring/proto-comment channel-name example everywhere (preprocessors.py, hstu_transducer wiring, module.proto, ultra_hstu.py); the previous "consumption" wasn't aligned with the uih_<event> convention. - Move UltraHSTU into the natural slot UltraHSTU ultra_hstu = 206 and bump PEPNet to pepnet = 207 so the ULTRA model sits next to its DlrmHSTU dlrm_hstu = 205 sibling. - Drop the separate tzrec/tests/configs/ultra_hstu_kuairand_1k.config and convert the existing dlrm_ultra_hstu_cutlass_kuairand_1k.config to use the ultra_hstu model with two channels (uih_click, uih_view). The two channels reference the same underlying uih_seq features in this minimal demo; in production each channel would consume distinct sources. Added test_rank_ultra_hstu_cutlass_train_eval_export in rank_integration_test.py to exercise the converted config end-to-end (train -> eval -> export -> predict), mirroring test_rank_dlrm_hstu_cutlass_train_eval_export. - Simplify _HSTUTransducerStack.forward: every sub-transducer is built with return_full_embeddings=False, so the second tuple element is always None. Drop the dead "if all(f is not None for f in full_list)" branch. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
  RocketLaunching rocket_launching = 500;
- PEPNet pepnet = 206;
+ PEPNet pepnet = 207;
Renumbering PEPNet from 206 → 207 and reusing tag 206 for UltraHSTU is wire-format-breaking. Any binary-serialized ModelConfig that previously stored pepnet at tag 206 will now silently parse as ultra_hstu (or be discarded as a type mismatch). In-tree configs are loaded via text_format/json_format (name-keyed) so the in-repo path is unaffected, but this is bad proto hygiene and dangerous for any downstream tooling, exported pipeline.config bundles, or external consumers.
Suggested fix: keep PEPNet pepnet = 206 at its original tag and assign UltraHSTU a fresh, never-published tag (e.g. 208). Tag numbers should be treated as immutable once published.
channels = self._model_config.hstu
assert len(channels) >= 1, "UltraHSTU requires at least 1 hstu channel"
if len(channels) >= 2:
    names = [c.name for c in channels]
    assert all(names) and len(set(names)) == len(names), (
        "When UltraHSTU has >= 2 channels every channel must set a "
        f"unique non-empty `name`, got {names!r}"
    )
Two related validation gaps in this block:

- Single-channel mode skips all name validation. A user with one channel can set `name="candidate"` (or any other reserved feature-group name); `_build_transducer` will then call `embedding_group.group_total_dim("candidate")`, which exists since `candidate` is the shared candidate-side group, and the UIH path will silently consume candidate-side embeddings without raising. Worth checking that a single-channel `name` (when non-empty) is also non-reserved.

- No charset/whitespace check on `name`. The proto field is an unconstrained `optional string`. A `name` containing `.` would produce keys like `"foo.bar.sequence"`; `EmbeddingGroupImpl.group_total_dim` splits on `.` and would resolve via `"foo"`, mapping the user's intent to an unrelated group. Whitespace or look-alike unicode would mismatch the dict key from the feature-group config, producing a confusing `KeyError` deep in forward.

Suggested: validate `name` once (non-reserved, no `.`, matches `^[A-Za-z_][A-Za-z0-9_]*$`) for every channel regardless of count.
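A minimal sketch of the suggested up-front check (function name and reserved set are hypothetical, not the tzrec API):

```python
import re

# Hypothetical reserved set: the shared/default group names a channel
# name must not alias.
_RESERVED = {"uih", "candidate", "contextual"}
_NAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")


def validate_channel_name(name: str) -> None:
    """Reject reserved names and anything outside a safe identifier charset."""
    if name in _RESERVED:
        raise ValueError(f"channel name {name!r} collides with a reserved group")
    if not _NAME_RE.fullmatch(name):
        raise ValueError(f"channel name {name!r} must match {_NAME_RE.pattern}")
```

Running this once per channel at `__init__` time turns both failure modes above into a clear construction-time error.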
def _build_transducer(
    self, contextual_feature_dim: int, max_contextual_seq_len: int
) -> torch.nn.Module:
    transducers: List[HSTUTransducer] = []
    for ch in self._model_config.hstu:
        uih_group = ch.name if ch.name else "uih"
        transducers.append(
            HSTUTransducer(
                uih_embedding_dim=self.embedding_group.group_total_dim(uih_group),
                target_embedding_dim=self.embedding_group.group_total_dim(
                    "candidate"
                ),
                contextual_feature_dim=contextual_feature_dim,
                max_contextual_seq_len=max_contextual_seq_len,
                contextual_group_name=self._contextual_group_name,
                scaling_seqlen=self._model_config.max_seq_len,
                **config_to_kwargs(ch),
                return_full_embeddings=False,
            )
        )
    if len(transducers) == 1:
        return transducers[0]
    return _HSTUTransducerStack(transducers)
Two small concerns on the stack construction path:

- Silent friendliness gap on missing groups. If a channel sets `name="uih_click"` but the user forgets to define one of the four required groups (`uih_click`, `uih_click_action`, `uih_click_watchtime`, `uih_click_timestamp`), the missing UIH group surfaces here as a bare `KeyError` from `group_total_dim`, while a missing `_action` / `_watchtime` / `_timestamp` group only surfaces at runtime inside `ContextualInterleavePreprocessor.forward`. Consider asserting the four groups exist up-front (skipping `_action` / `_watchtime` when those encoders aren't configured) so misconfigs fail with a clear message at construction time.

- `_HSTUTransducerStack` always discards the second tuple element. The stack hardcodes `return_full_embeddings=False` per channel and returns `None` as the second slot, but if a future change ever flips `return_full_embeddings=True` on a sub-transducer (e.g. for an aux-loss head), `cand[1]` is silently dropped. Worth a one-line `assert not transducer._return_full_embeddings` in the stack constructor and a note in its docstring.
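The up-front group check suggested in the first bullet could look like this (a hedged sketch; `group_names` stands in for the embedding group's known group names, and the function name is hypothetical):

```python
def check_channel_groups(group_names, name, has_action=True, has_watchtime=True):
    """Fail at construction time, listing every missing per-channel group."""
    required = [name, f"{name}_timestamp"]
    if has_action:
        required.append(f"{name}_action")
    if has_watchtime:
        required.append(f"{name}_watchtime")
    missing = sorted(g for g in required if g not in group_names)
    if missing:
        raise ValueError(f"channel {name!r} is missing feature groups: {missing}")
```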
# Channel names. The MoT routing in ContextualInterleavePreprocessor
# substitutes the channel name for the default "uih" prefix on each
# UIH-side key (e.g. "a.sequence", "a_action.sequence", ...).
_CHANNELS = ("a", "b")
Two coverage gaps from hardcoding `_CHANNELS = ("a", "b")` and `embedding_dim=512` for every channel:

- Single-channel path is never exercised. The branch at ultra_hstu.py:118-119 (`if len(transducers) == 1: return transducers[0]`), the documented "behaves like DlrmHSTU" mode that bypasses `_HSTUTransducerStack` entirely, is the easiest path to silently break in a future refactor (e.g. shape regression: single tensor vs. concatenated). Recommend a separate test case (or parametrize over `(("a",), ("a", "b"))`) and assert `isinstance(model._hstu_transducer, HSTUTransducer)` for the single-channel case.

- Heterogeneous channel dims are not exercised. With both channels at `embedding_dim=512`, `sum(c.stu.embedding_dim for c in channels)` is mathematically indistinguishable from `N * dim` or `max(...)`; a buggy `_stu_embedding_dim` would still pass. Use distinct dims (e.g. 256 and 512) in at least one case so the sum is the only correct answer.

Also: there are no negative tests for the `__init__` validation assertions (`len(channels) >= 1`; >= 2 channels need unique non-empty names). Easy to add, and they protect user-facing error paths.
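The heterogeneous-dims point is easy to verify numerically; with uniform dims three candidate formulas coincide, while 256 + 512 separates them:

```python
# With uniform channel dims, sum / N*dim / 2*max are indistinguishable,
# so a buggy _stu_embedding_dim would still pass a shape check.
uniform = [512, 512]
assert sum(uniform) == len(uniform) * uniform[0] == 2 * max(uniform)

# With heterogeneous dims, only sum gives the correct fused dim.
hetero = [256, 512]
assert sum(hetero) == 768
assert sum(hetero) != len(hetero) * max(hetero)  # N * dim would give 1024
assert sum(hetero) != max(hetero)                # max(...) would give 512
```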
            embedding_name=f"{channel}_video_id_emb",
            num_buckets=1000,
        )
    ),
    feature_pb2.SeqFeatureConfig(
        id_feature=feature_pb2.IdFeature(
            feature_name="video_cat",
            embedding_dim=16,
            embedding_name=f"{channel}_video_cat_emb",
Per-channel embedding_name=f"{channel}_..._emb" gives every channel its own embedding table for the same physical feature (video_id, video_cat). Users will copy this pattern, and in production it multiplies sparse-parameter count, TBE forward/backward work, and all-to-all communication volume by N — typically the dominant cost in a sharded recsys deployment. The integration config (dlrm_ultra_hstu_cutlass_kuairand_1k.config) implicitly shares table names because both channels point at the same uih_seq__video_id, so the test config and the integration config disagree on the recommended pattern.
Suggest making the test share embedding_name across channels by default to set the right example, and add a class/proto docstring note recommending shared embedding_name unless per-channel tables are explicitly desired.
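The table-count consequence can be sketched in a few lines (hedged: per the review text, EmbeddingGroup dedupes tables by `embedding_name`; the set arithmetic below just models that dedup):

```python
channels = ["uih_click", "uih_view"]

# Per-channel naming: one distinct table per channel per feature.
per_channel_tables = {f"{c}_video_id_emb" for c in channels}
assert len(per_channel_tables) == len(channels)  # N tables

# Shared naming: every channel points at the same table name, deduped to 1.
shared_tables = {"video_id_emb" for _ in channels}
assert len(shared_tables) == 1
```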
message UltraHSTU {
    // ULTRA-HSTU model with Mixture of Transducers. Holds one HSTU
    // sub-config per channel; per-candidate outputs are concatenated
    // along the embedding dim. When `hstu` has a single entry the
    // model behaves like DlrmHSTU (modulo the optional name routing on
    // the channel). When `hstu` has >= 2 entries every channel must
    // set a unique non-empty `name`, and the corresponding feature
    // groups (uih-side replaced by `<name>`, `<name>_action`,
    // `<name>_watchtime`, `<name>_timestamp`) must be defined. The
    // candidate-side and contextual feature groups are shared across
    // every channel.
    repeated HSTU hstu = 1;
    // multi task tower config
    required FusionMTLTower fusion_mtl_tower = 2;
    // max sequence length
    required uint32 max_seq_len = 3;
    // item embedding mlp hidden dimension
    optional uint32 item_embedding_hidden_dim = 4 [default = 512];
    // enables loss averaging computation globally across all ranks (total rank)
    // instead of locally (local rank).
    optional bool enable_global_average_loss = 5 [default = true];
    // timestamp of sequence is ascending or descending
    optional bool sequence_timestamp_is_ascending = 6 [default = true];
    // concat all contextual features on channel dim as one token
Fields 2-7 of UltraHSTU are an exact copy of DlrmHSTU fields 2-7; only field 1 (hstu repeated vs single) differs. The Python _init() already references all of these via self._model_config.X and works for both messages purely by field-name duck typing — that contract silently breaks the day someone adds a field to one and forgets the other.
Lower-effort options: add a // keep in sync with DlrmHSTU 2-7 TODO comment to both messages. Higher-effort: extract a shared sub-message embedded in both, or just make DlrmHSTU.hstu repeated and retire UltraHSTU entirely.
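The "keep in sync" suggestion can be enforced by a tiny unit test; a hedged sketch (field lists written out literally here; in practice they would come from each message's `DESCRIPTOR.fields`):

```python
def assert_in_sync(a_fields, b_fields, shared):
    """Fail loudly if any shared field name is missing from either message."""
    missing = (shared - set(a_fields)) | (shared - set(b_fields))
    if missing:
        raise AssertionError(f"proto fields drifted: {sorted(missing)}")


SHARED = {
    "fusion_mtl_tower",
    "max_seq_len",
    "item_embedding_hidden_dim",
    "enable_global_average_loss",
    "sequence_timestamp_is_ascending",
    "concat_contextual_features",
}
# Stand-ins for the two messages' field-name sets; both also carry `hstu`.
assert_in_sync(SHARED | {"hstu"}, SHARED | {"hstu"}, SHARED)  # in sync: no error
```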
Review Summary

Solid, well-scoped addition. Highest-impact items (see inline comments for details):

Performance note (informational, not blocking): Docs: the deferred
…eddings, slot swap - ultra_hstu_test: parametrize over (single, multi_uniform, multi_hetero) channel topologies via @parameterized.expand on top of the existing hypothesis sweep. Asserts on the single-channel branch returning a bare HSTUTransducer (not _HSTUTransducerStack), and on heterogeneous STU dims (256 + 512) so that _stu_embedding_dim's `sum` is the only correct answer (a buggy `N * dim` or `max(...)` would no longer pass). Drop per-channel embedding_name in the test helper so the pattern matches the integration config and won't push users toward per-channel embedding tables. - ultra_hstu / module.proto: docstring + proto-comment notes recommending shared embedding_name across channels. EmbeddingGroup dedupes by embedding_name, so per-channel tables are an opt-in for the unusual case where you specifically want disjoint tables (it multiplies sparse-parameter count, TBE forward/backward work, and all-to-all volume by the channel count). - model.proto: pepnet = 206 (restored to its original deployed slot, placed adjacent to dlrm_hstu = 205); ultra_hstu = 207. The kuairand-mot-1k integration test data is generated separately by experiments/preprocess_kuairand_mot.py (gitignored, run locally) and wired into ci_data.sh + the cutlass config in a follow-up commit once the parquet is uploaded to OSS. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…lerate)
- ci_data.sh: switch every kuairand-1k* URL from oss-cn-beijing to
oss-accelerate, and add wget lines for the new kuairand-mot-1k
{train,eval} parquet so CI gets distinct per-channel UIH sequences.
oss-accelerate is the global-acceleration endpoint and works from
the GitHub-hosted CPU runners.
- .gitignore: cover data/test/kuairand-mot-1k* (the existing
kuairand-1k* glob does not match the mot-prefixed files).
- dlrm_ultra_hstu_cutlass_kuairand_1k.config: point train/eval paths
at the new kuairand-mot-1k parquet; replace the single uih_seq
feature_config with click_seq + view_seq feature_configs (both
declare the four sub-features and share video_id_emb so
EmbeddingGroup dedupes the table); rewire the per-channel
feature_groups (uih_click* -> click_seq__*, uih_view* ->
view_seq__*) so the integration test now exercises truly distinct
per-channel UIH streams (action_weight & 1 for clicks,
action_weight & 64 for long-views) sharing the candidate side and
contextual features.
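The event-mask split mentioned above can be illustrated in a few lines (hedged: the bit values 1 and 64 come from the commit message; the example rows are made up):

```python
# kuairand-style action_weight bitfields routed into click vs long-view
# channels: bit 1 marks a click, bit 64 marks a long view.
rows = [1, 64, 65, 0, 3]
clicks = [bool(w & 1) for w in rows]
long_views = [bool(w & 64) for w in rows]
assert clicks == [True, False, True, False, True]
assert long_views == [False, True, True, False, False]
```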
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ation export path
ultra_hstu_test on CPU-only CI was hitting "Torch not compiled with
CUDA enabled" because @unittest.skipIf(*gpu_unavailable) sat *outside*
@parameterized.expand and so didn't propagate to the generated
test_ultra_hstu_{0_single,1_multi_uniform,2_multi_hetero} methods.
Move skipIf below expand so each parametrized case carries the skip.
The cutlass-kuairand integration test exposed four pre-existing fx /
AOT-export bugs in shared HSTU SLA + truncation infra that no test
had previously exercised end-to-end (PR alibaba#486 added the config but did
not wire it into rank_integration_test). Each fix is minimal:
1. tzrec/ops/hstu_attention_utils.py:build_sla_func_tensor —
`if seq_offsets.dtype != torch.int32: ...` evaluates Proxy.dtype
under fx. Drop the conditional; `tensor.to(dtype)` is a no-op when
already-matching, so unconditional cast costs nothing on the fast
path.
2. tzrec/modules/gr/stu.py:STULayer.forward — the SLA cache
hit-check `my_sig == prev_attn_func_sig` compared a tuple
containing `x.size(0)` (a Proxy at trace time) against another
Proxy-containing tuple, falling into Proxy.__eq__ → bool →
TraceError. Skip the cache under is_fx_tracing(); it's a
runtime-only optimization.
3. tzrec/ops/hstu_attention_utils.py:compute_stu_truncation_plan —
`int(new_lengths.max().item())` raises `int(Proxy)` under fx.
Replace with `fx_int_item(new_lengths.max())` (already
`@torch.fx.wrap`-ed in tzrec.utils.fx_util). Same trick for the
new offset totals (total_dropped/kept/prefix/rest) so
apply_stu_truncation_plan can pass them as static `total_len_*`
into split_2D_jagged and skip the `.item()` fallback inside the
triton fake impl. STUTruncationPlan grows four int fields — see
the docstring. hstu_transducer._replay_truncation_state now
derives `post_truncation_total_uih_len` from
`plan.total_kept - total_targets` instead of returning None, for
the same reason. hstu_transducer_test updated for the new
keyword and the no-longer-None return.
4. tzrec/ops/hstu_attention_utils.py:build_sla_func_tensor —
`searchsorted(seq_offsets_i32[1:], pos_global, ...)` failed
Inductor lowering with `NotImplementedError: SliceView` because a
slice produced a SliceView IR that searchsorted's
`_boundaries_helper` can't take stride from. Replace
searchsorted with `torch.diff(seq_offsets_i32)` + `repeat_interleave`,
which avoids both the slice and the searchsorted entirely while
producing the same per-position batch_ids.
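Fix 4's replacement can be mirrored in plain Python (the real code uses `torch.diff` + `torch.repeat_interleave` on tensors; list comprehensions stand in here):

```python
# Derive per-position batch ids from jagged seq_offsets without searchsorted.
seq_offsets = [0, 3, 5, 9]  # 3 sequences of lengths 3, 2, 4

# torch.diff(seq_offsets): per-sequence lengths.
lengths = [b - a for a, b in zip(seq_offsets, seq_offsets[1:])]

# repeat_interleave(arange(N), lengths): batch id repeated per position.
batch_ids = [i for i, n in enumerate(lengths) for _ in range(n)]

assert lengths == [3, 2, 4]
assert batch_ids == [0, 0, 0, 1, 1, 2, 2, 2, 2]
```

This produces the same per-position batch ids as the searchsorted formulation while avoiding both the slice of `seq_offsets` and searchsorted itself.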
Also fix predict_input_path in test_rank_ultra_hstu_cutlass_train_eval_export
(was pointing at the old kuairand-1k parquet; now correctly uses the
kuairand-mot-1k file the rest of the test trains+evals against).
All six tests pass locally:
python -m tzrec.models.dlrm_hstu_test (back-compat)
python -m tzrec.models.ultra_hstu_test (parametrized)
python -m tzrec.modules.gr.preprocessors_test (back-compat)
python -m tzrec.modules.gr.hstu_transducer_test (back-compat)
python -m tzrec.modules.gr.stu_test (back-compat)
python -m unittest tzrec.tests.rank_integration_test.RankIntegrationTest.test_rank_ultra_hstu_cutlass_train_eval_export
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ev-sigs in STUStack Replaces the is_fx_tracing() short-circuit in 32b6ad9 with the proper fix. The runtime sig comparison was failing under fx symbolic tracing because the 6-tuple cache key embedded `x.size(0)` (a Proxy at trace time). The dynamic dim only ever served to invalidate the cache across the one statically-known event -- mid-stack truncation -- so encode that signal statically: `STUStack.__init__` precomputes a per-layer `prev_attn_func_sig` that is `None` at layer 0 and at the truncation-split layer, and the previous layer's static sig otherwise. The cache key is now a comma-separated string encoding `(sla_k1, sla_k2, contextual_seq_len, num_heads, target_aware)` exposed as `STULayer.attn_func_static_sig` (read-only @property). Strings are pure Python primitives and trivially fx-trace-safe. STULayer.forward's hit-check is `prev_attn_func is not None and prev_attn_func_sig == self.attn_func_static_sig`. forward returns a 2-tuple now (no sig in the return); the precomputed list is the single source of truth for "which sig describes prev_attn_func". STUStack.forward's loop body is now: optional `truncate_input` -> call layer with precomputed sig -> carry the new attn_func forward. No sig bookkeeping at all in the hot path. At the truncation index the precomputed sig is already None, so the layer's comparison fails and it rebuilds with the post-truncation total_q. Drops the `is_fx_tracing` import added in 32b6ad9; the band-aid is gone. Cache invariant preserved: at most two `build_sla_func_tensor` calls per forward (layer 0 + the truncation-split layer). stu_test cache tests adapted to the 2-tuple return + property. Adds parametrized fx-trace coverage in hstu_attention_utils_test: BuildSlaFuncTensorTraceTest and StuTruncationTraceTest each exercise NORMAL + FX_TRACE TestGraphTypes on tiny nn.Module wrappers around build_sla_func_tensor and the compute+apply_stu_truncation_plan pair. Eager and traced outputs must match.
Targeted regression coverage for the hstu_attention_utils.py Proxy fixes from 32b6ad9. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Replaces the dedicated BuildSlaFuncTensorTraceTest / StuTruncationTraceTest classes from 355cd29 with @parameterized.expand((NORMAL, FX_TRACE)) on the existing tests. Each numerical-correctness assertion now pins both the eager output *and* the fx-traceability of the helper in a single test method. BuildSlaFuncTensorTest._build accepts a graph_type and routes through a tiny nn.Module wrapper + create_test_module so NORMAL runs eagerly and FX_TRACE runs the symbolic-traced module. All 7 existing tests gain the parametrize decorator (14 cases total). test_int32_offsets_skip_cast still verifies dtype-cast equivalence between int32 and int64 offsets, both via the wrapper now. StuTruncationTest.test_matches_reference cross-products the existing 5 input cases with (NORMAL, FX_TRACE) -> 10 cases. test_replay_on_parallel_jagged gains the parametrize too; test_validation_raises_on_negative_params stays eager-only since validation lives in compute_stu_truncation_plan() construction before any forward. The wrapper modules carry a static `target_aware` flag so under fx trace the `if num_targets is not None` branch is selected at construction time -- a Proxy num_targets would otherwise always trace the not-None branch and break callers that pass None at run time. Total goes from 13 tests -> 27 tests; same assertion strength, +fx-trace regression coverage. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The three private nn.Module wrappers in hstu_attention_utils_test (_BuildSlaFuncTensorWrapper, _StuTruncationWrapper, _ReplayTruncationWrapper) are scaffolding for the parametrized fx-trace tests; their constructor + forward arg lists already say what they do. Drop the class docstrings and inline notes. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
… separator Tighten the multi-line WHY-comments added across the MoT/SLA/truncation paths. The intent (fx-trace safety, Inductor SliceView avoidance, .item() vs FakeTensor) stays in one or two lines per touchpoint; removed restated-from-code rationale. Switch attn_func_static_sig from comma- to colon-separated so the encoded sig (e.g. "256:32:0:4:1") reads as obviously categorical rather than a numeric tuple repr. No behavior change; stu_test (13/OK) and hstu_attention_utils_test (27/OK) pass. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
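The colon-separated static sig described above (e.g. "256:32:0:4:1") can be reconstructed as a one-liner (hypothetical standalone function; the real code is a property on `STULayer`):

```python
def attn_func_static_sig(sla_k1, sla_k2, contextual_seq_len, num_heads, target_aware):
    # Plain-str cache keys are fx-trace-safe: no tensor Proxies, and the colon
    # separator makes the sig read as categorical rather than a numeric tuple.
    return f"{sla_k1}:{sla_k2}:{contextual_seq_len}:{num_heads}:{int(target_aware)}"


assert attn_func_static_sig(256, 32, 0, 4, True) == "256:32:0:4:1"
```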
Cuts a release that includes UltraHSTU MoT (PR alibaba#492). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- docs/source/models/ultra_hstu.md (NEW): per-section walk through the four ULTRA-HSTU optimizations (SLA, mid-stack truncation, MoT, selective rematerialization) following the architecture in Ding et al., 2026 (arXiv:2602.16986); includes a two-channel kuairand config snippet and a field reference for the UltraHSTU / HSTU.name / sla_k1+k2 / attn_truncation_* additions. - README.md: add ULTRA-HSTU row under "Generative Recommendation" with description "HSTU with Semi-Local Attention, Attention Truncation, and Mixture of Transducers". - docs/source/models/generative.rst: add ultra_hstu to the toctree. - Rename tzrec/tests/configs/dlrm_ultra_hstu_cutlass_kuairand_1k.config -> ultra_hstu_cutlass_kuairand_1k.config (drop the "dlrm_" prefix -- this config now uses the ultra_hstu model proto, not dlrm_hstu). rank_integration_test reference updated. - Bump copyright year on the two new files (ultra_hstu.py, ultra_hstu_test.py) from 2025 to 2026. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
ultra_hstu.md gains the conventional 示例 / 参考论文 (example / reference paper) sections (data + config + export + paper link) following the pattern in dlrm_hstu.md and other model docs. Config link points at the user-uploaded https://tzrec.oss-cn-beijing.aliyuncs.com/config/models/ultra_hstu_kuairand.config (URL verified, HTTP 200); reuses the kuairand-27k dataset. Drops the now-stale # total_uih_len comment in hstu_transducer_test: the assertion + variable name already say it. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Match the simpler pattern used in deepfm/ple/dlrm/dssm docs: just a config link, no 数据 / 模型导出 (data / model export) sub-sections.
Drop the 4-line note about attn_func_static_sig in the return description; the property docstring already covers it. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- **Mixture of Transducers (MoT)**: the model runs N `HSTUTransducer`s in parallel, one per UIH channel (e.g. clicks, long plays); each channel has its own STU stack and its own SLA / truncation config. Each channel's per-candidate output embeddings are concatenated along the channel dim and fed to a shared `FusionMTLTower`. **Candidate and contextual features are shared**; UIH is split per channel; the channel count is determined by `repeated HSTU hstu`.
- **Selective Rematerialization**: the STU layer's backward pass recomputes two classes of intermediate tensors on demand to save memory: `recompute_normed_x_in_backward` controls whether the LayerNorm output is stored, and `recompute_uvqk_in_backward` controls whether the (U, V, Q, K) projections are stored. Both default to `true`, which matters most for large models where attention activation memory dominates.
The proto field names cited here don't exist. The actual fields in tzrec/protos/module.proto:249,251 are recompute_normed_x and recompute_uvqk (no _in_backward suffix — that suffix only appears as the internal kwarg in tzrec/ops/hstu_compute.py). A user copy-pasting recompute_normed_x_in_backward: true into a stu { ... } config will get a proto-parse error. Please rename to the actual proto field names.
assert len(channels) >= 1, "UltraHSTU requires at least 1 hstu channel"
if len(channels) >= 2:
    names = [c.name for c in channels]
    assert all(names) and len(set(names)) == len(names), (
        "When UltraHSTU has >= 2 channels every channel must set a "
        f"unique non-empty `name`, got {names!r}"
    )
Two related validation gaps worth closing up-front:

- Reserved-name collisions are silently accepted. `name="uih"`, `name="candidate"`, or `name="contextual"` on a multi-channel config will pass the `all(names)` and uniqueness check, but the derived UIH-side keys (preprocessors.py:151-154) will then alias the candidate / contextual / default-uih groups. Suggest rejecting reserved names and `*_action` / `*_watchtime` / `*_timestamp` suffixes that would collide with the per-channel naming convention.

- Missing feature_group is reported deep, not at construction. A typo in `name` surfaces as a `KeyError` in preprocessors.py:373 (`grouped_features[f"{self._uih_key}.sequence"]`) at forward time, or as whatever `embedding_group.group_total_dim(uih_group)` raises mid-`_build_transducer`. Validating up-front that for every channel the four groups (`<name>`, `<name>_action`, `<name>_watchtime`, `<name>_timestamp`) exist on `self.embedding_group` and listing the missing keys would make misconfigurations obvious.
@unittest.skipIf(*gpu_unavailable)
@given(
    graph_type=st.sampled_from(
        [
            TestGraphType.NORMAL,
            TestGraphType.FX_TRACE,
            TestGraphType.JIT_SCRIPT,
            TestGraphType.AOT_INDUCTOR,
        ]
    ),
    kernel=st.sampled_from([Kernel.PYTORCH, Kernel.TRITON]),
    contextual_group_type=st.sampled_from(
        [model_pb2.FeatureGroupType.DEEP, model_pb2.FeatureGroupType.SEQUENCE]
    ),
    sequence_timestamp_is_ascending=st.sampled_from([True, False]),
    enable_global_average_loss=st.sampled_from([True, False]),
    concat_contextual_features=st.sampled_from([True, False]),
)
@settings(
    verbosity=Verbosity.verbose,
    max_examples=6,
    deadline=None,
)
def test_ultra_hstu(
    self,
    case_name,
    channel_specs,
    graph_type,
    kernel,
    contextual_group_type,
    sequence_timestamp_is_ascending,
    enable_global_average_loss,
    concat_contextual_features,
) -> None:
    # JIT_SCRIPT only supports the PyTorch kernel today.
    assume(
        (graph_type == TestGraphType.JIT_SCRIPT and kernel == Kernel.PYTORCH)
        or graph_type != TestGraphType.JIT_SCRIPT
    )

    device = torch.device("cuda")
    ultra_hstu = _build_model(
        device=device,
        channel_specs=channel_specs,
        contextual_group_type=contextual_group_type,
        enable_global_average_loss=enable_global_average_loss,
        sequence_timestamp_is_ascending=sequence_timestamp_is_ascending,
        concat_contextual_features=concat_contextual_features,
    )

    # Single-channel UltraHSTU returns a bare HSTUTransducer (no
    # _HSTUTransducerStack wrapper) so the predict() path matches
    # DlrmHSTU's exactly. Multi-channel returns the stack.
    if len(channel_specs) == 1:
        self.assertIsInstance(ultra_hstu._hstu_transducer, HSTUTransducer)
        self.assertNotIsInstance(ultra_hstu._hstu_transducer, _HSTUTransducerStack)
    else:
        self.assertIsInstance(ultra_hstu._hstu_transducer, _HSTUTransducerStack)
        self.assertEqual(
            ultra_hstu._stu_embedding_dim(), sum(d for _, d in channel_specs)
        )
Coverage gaps worth closing:

- Three new `__init__` error branches are untested (ultra_hstu.py:84-90): zero channels, `len >= 2` with empty name, `len >= 2` with duplicate names. All are CPU-only assertion checks — trivial to add and would prevent a regression silently dropping the validations.
- GPU-skip hides CPU-only structural assertions. Lines 396-403 (`isinstance(..., _HSTUTransducerStack)` and `_stu_embedding_dim() == sum(d for _, d in channel_specs)`) are pure-Python checks that gate the central MoT structural invariant (the only correct answer for `multi_hetero` is the sum, not the max or `N * dim`). Hiding them behind `@unittest.skipIf(*gpu_unavailable)` means CPU CI never validates them. Recommend splitting into a CPU-only structural test plus the GPU forward test.
- No single-channel-with-empty-name case — the `("single", [("a", 256)])` parametrize case uses a non-empty name, so the `_build_transducer` fallback to the `"uih"` group at ultra_hstu.py:98 is never exercised. This is the documented "painless migration from DlrmHSTU" path.
- Output is shape-checked only, not value-checked. A two-channel parity test (identical seeds for both channels, then assert the two halves of `torch.cat(..., dim=-1)` along the embedding dim are equal) would catch a broken concat order or a wrong index in `_HSTUTransducerStack.forward` for negligible cost.
```python
def forward(
    self, grouped_features: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    cand_list: List[torch.Tensor] = []
    for transducer in self._transducers:
        cand, _ = transducer(grouped_features)
        cand_list.append(cand)
    return torch.cat(cand_list, dim=-1), None
```
N transducers run serially in a Python for loop on a single CUDA stream — each transducer launches its own preprocessor + STU stack + postprocessor + many small jagged kernels. With per-channel embedding_dim typically 32–128, kernel-launch latency and SM under-occupancy will dominate at small batch sizes, so MoT scaling won't be linear in N.

The user-facing doc at docs/source/models/ultra_hstu.md:11 calls this "模型并行运行 N 个 HSTUTransducer" ("runs N HSTUTransducers in model parallel"), which oversells the GPU behavior — eager + AOT-export both serialize today. At minimum, please soften the doc claim. Optionally, on the eager-train path you could record each iteration on a dedicated torch.cuda.Stream and join before torch.cat (skip on AOT-export).

Also, there is no shape/dim sanity check before `torch.cat(..., dim=-1)` — divergent per-candidate row counts across transducers (e.g. a future per-channel truncation behavior change) would silently misalign rows. A cheap precondition assert would prevent a hard-to-debug regression.
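A hedged sketch combining both suggestions: per-channel CUDA streams on the eager path, and a row-count precondition before the concat. The names (`self._transducers`, `grouped_features`) follow the quoted `forward()`; the stream handling is an assumption, deliberately bypassed when compiling/exporting, and a production version would also need to respect allocator stream semantics (e.g. `Tensor.record_stream`).

```python
from typing import Dict, List, Optional, Tuple

import torch


def forward(
    self, grouped_features: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    cand_list: List[torch.Tensor] = []
    if torch.cuda.is_available() and not torch.compiler.is_compiling():
        # Eager path: overlap the N independent transducers on side streams.
        streams = [torch.cuda.Stream() for _ in self._transducers]
        current = torch.cuda.current_stream()
        for transducer, stream in zip(self._transducers, streams):
            stream.wait_stream(current)  # inputs ready before the side stream runs
            with torch.cuda.stream(stream):
                cand, _ = transducer(grouped_features)
            cand_list.append(cand)
        for stream in streams:
            current.wait_stream(stream)  # join all channels before the concat
    else:
        # Compile/export (and CPU) path: keep the simple serial loop.
        for transducer in self._transducers:
            cand, _ = transducer(grouped_features)
            cand_list.append(cand)
    # Cheap precondition: divergent per-candidate row counts would otherwise
    # silently misalign rows in the concatenated output.
    rows = cand_list[0].size(0)
    assert all(c.size(0) == rows for c in cand_list), (
        "per-candidate row counts diverged across MoT channels"
    )
    return torch.cat(cand_list, dim=-1), None
```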
```python
def attn_func_static_sig(self) -> str:
    """SLA NFUNC cache key (colon-separated, fx-trace-safe)."""
    return (
        f"{self._sla_k1}:{self._sla_k2}:{self._contextual_seq_len}:"
        f"{self._num_heads}:{int(self._target_aware)}"
    )
```
The cache key omits `_max_attn_len` and `_causal`, both of which affect NFUNC mask semantics. It happens to be safe today because STUStack constructs every layer in a stack with identical kwargs (hstu_transducer.py:111), so all layers in a stack share the same (max_attn_len, causal). But two layers that differ only in max_attn_len (a plausible future heterogeneous-layer config) would silently reuse a stale mask via this sig.

Defensive fix — include both in the sig:

```python
return (
    f"{self._sla_k1}:{self._sla_k2}:{self._contextual_seq_len}:"
    f"{self._max_attn_len}:{int(self._causal)}:"
    f"{self._num_heads}:{int(self._target_aware)}"
)
```
**Code Review Summary**

Five reviewers covered code quality, performance, tests, docs, and security. Posted 5 inline comments on the highest-signal findings; collecting the rest here.

**Most actionable (see inline)**
**Other notes worth scanning**

**Praise**
Both are init-time STU constants that affect attention semantics. Excluding them was safe today because every layer in a STUStack is constructed from the same kwargs (hstu_transducer.py:107), but a future heterogeneous stack with mismatching max_attn_len or causal would silently reuse a stale `attn_func` from the cache. Cost is negligible (one int + one bool added to the colon-separated key). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Each integration-test pipeline stage runs as a torchrun subprocess with stdout/stderr piped to ./tmp/<dir>/log_*.txt; the unittest framework only sees AssertionError(False is not true). CI workflow has no upload-artifact step, so the per-stage logs are wiped with the runner workspace and the real stack trace from any flaky failure is unrecoverable from gh run view --log. When run_cmd's subprocess fails terminally, dump the last 200 lines of the captured log to stdout with a RUNCMD FAILED header so the trace lands in CI logs directly. Tail-only to keep CI output bounded. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The cutlass_hstu_train_eval_export integration test was failing deterministically (4 CI runs in a row) with an unrecoverable AssertionError(False is not true) -- the actual stack trace lived in torchrun's per-stage log_export.txt that the runner deletes on cleanup. Three changes: 1. tzrec/utils/misc_util.py:run_cmd -- bump the diagnostic tail-dump from 200 to 500 lines so the full Python traceback (Inductor stack + sympy frame + final exception) lands in CI stdout. 200 lines only captured the inner Inductor stack and was cut off mid-frame above the actual exception text. 2. tzrec/ops/hstu_attention_utils.py:build_sla_func_tensor -- add `_emit_check_total_q_positive(total_q)` (an @torch.fx.wrap'd helper that emits `torch._check_is_size + torch._check(>0)`). Inductor's combine_contiguous_dims for the (nheads, 3, total_q) output emits ModularIndexing(idx, total_q, 3) whose sympy simplifier raises ZeroDivisionError under AOT compile when total_q is dynamic and not provably > 0 (the `Max(1, total_q)` in computed strides is the smoking gun). The check is fx-trace safe because the wrap turns it into a call_function leaf, and no-op outside torch.compiler.is_compiling() so eager runs pay nothing. 3. tzrec/ops/hstu_attention_utils.py:build_sla_func_tensor -- swap `repeat_interleave(arange(B), seq_lengths)` for `searchsorted( cumsum(seq_lengths), pos_global, right=True).clamp_max(B - 1)`. repeat_interleave's Inductor lowering generates additional ModularIndexing patterns that compound the same sympy-divisor problem; the searchsorted variant uses fewer dynamic-shape ops. `clamp_max(B - 1)` is an eager no-op (pos_global < boundaries[-1] always) but gives Inductor a provable upper bound for the downstream `seq_lengths[batch_ids]` indirect index, so it doesn't insert a defensive device-side assert that would fire under AOT compile (separately observed locally before the torch._check). 
As a bonus, fixes a subtle off-by-one in the original (pre-PR) `searchsorted(seq_offsets[1:], pos)` call: without right=True, tokens at exactly seq_offsets[i] were assigned to batch i-1 instead of batch i. Verification: ran python -m unittest tzrec.tests.rank_integration_test.RankIntegrationTest.test_rank_ultra_hstu_cutlass_train_eval_export locally with a wiped /tmp/torchinductor_tianyi cache; passes in 162s. Pre-fix the same command failed in 125s with `InductorError: ZeroDivisionError`. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The previous fix (eee0f99) added a torch._check(total_q > 0) hint and swapped repeat_interleave for searchsorted+clamp. The hint passed once locally with a fresh inductor cache but failed on rerun and on CI -- fundamentally because Inductor's combine_contiguous_dims for the (nheads, 3, total_q) output of `build_sla_func_tensor` emits `ModularIndexing(idx, total_q, 3)` and sympy's simplifier hits a ZeroDivisionError on the symbolic divisor regardless of any torch._check constraint we add (the constraint apparently doesn't propagate to `combine_contiguous_dims`'s shape-simplification path). The robust fix is to make Inductor never see the broadcast: wrap `func_2d.unsqueeze(0).expand(nheads, 3, total_q).contiguous()` in a custom op (`tzrec::_sla_broadcast_func_to_heads`) that is a black box to Inductor's fuser. Inductor passes the (3, total_q) input through unchanged and the eager-time impl materializes the head dim outside any compiled region. Uses the low-level torch.library.define/impl/register_fake API matching cutlass_hstu_attention.py's convention; the `@torch.library.custom_op` decorator triggers an AOTI multi-thread predict deadlock per the comment in that file. Drop the unused `_emit_check_total_q_positive` helper and update the docstring -- the returned tensor is now contiguous (head dim materialized inside the custom op), no longer a stride-0 broadcast view. Verification: ran python -m unittest tzrec.tests.rank_integration_test.RankIntegrationTest.test_rank_ultra_hstu_cutlass_train_eval_export 3 times in a row locally (fresh cache, then warm cache twice); all three passed (163s, 146s, 151s -- second run faster due to cache reuse on the now-correct compiled kernel). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ensor Commit eee0f99 swapped repeat_interleave for searchsorted+clamp under the (later disproven) hypothesis that repeat_interleave's Inductor lowering was the source of the ZeroDivisionError. The actual root cause was the (nheads, 3, total_q) head broadcast (fixed in cf7c5e5 by wrapping that op in a black-box custom op). With the broadcast now opaque to Inductor, repeat_interleave's lowering is no longer in the failure path and the searchsorted variant just adds two extra ops (cumsum + clamp_max) to the SLA cache build for no benefit. Restoring the simpler diff + repeat_interleave form from commit 32b6ad9. Verified: ran the integration test 3x locally (fresh inductor cache + warm cache twice); all three passed (166s, 149s, 147s) -- same time profile as the searchsorted version. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Summary
Adds `UltraHSTU` as a new top-level model type implementing the four ULTRA-HSTU optimizations from Ding et al., 2026 (arXiv:2602.16986) on top of `DlrmHSTU`:

- **SLA**: `HSTU.stu.sla_k1` (local window) + `sla_k2` (global prefix); requires `Kernel.CUTLASS` or `Kernel.PYTORCH`.
- **Attention Truncation**: `HSTU.attn_truncation_split_layer` + `attn_truncation_tail_len`; drops the UIH prefix after the split layer, preserving contextual + targets.
- **Mixture of Transducers (MoT)**: `repeated HSTU hstu` in `UltraHSTU` runs N parallel transducers, one per UIH channel; per-candidate outputs concatenate on the embedding dim before the multi-task tower. Candidate and contextual groups stay shared. Channel selection is via a single new `optional string HSTU.name` field that replaces the default `uih` prefix on UIH-side feature_group keys (e.g. `name="uih_click"` reads `uih_click.sequence` / `uih_click_action.sequence` / ...). Empty `name` preserves all current `DlrmHSTU` configs.
- **Recompute in backward**: on `STU` via `recompute_normed_x_in_backward` / `recompute_uvqk_in_backward`; surfaced + documented for ULTRA-HSTU.
Single-channel `UltraHSTU` (one `hstu` entry, empty `name`) is behaviorally equivalent to `DlrmHSTU`; SLA and Attention Truncation are independently usable in that mode too.

Implementation notes

- Proto: `optional string name` on `HSTU`; new `UltraHSTU` model message in multi_task_rank.proto mirroring `DlrmHSTU`'s top-level fields with `repeated HSTU hstu`; `UltraHSTU ultra_hstu = 207` in model.proto's oneof.
- `UltraHSTU` subclasses `DlrmHSTU` via two small extracted hooks (`_build_transducer`, `_stu_embedding_dim`); the wrapper `_HSTUTransducerStack` is bypassed for single-channel configs so AOTI / JIT trace shapes are unchanged.
- `STULayer` exposes `attn_func_static_sig` (a colon-separated string of init-time SLA constants) as a `@property`; `STUStack` precomputes a per-layer `prev_attn_func_sig` list at construction so the forward loop has zero sig bookkeeping and is fx-trace-safe by construction.
- `build_sla_func_tensor`: unconditional `int32` cast (no Proxy.dtype check); `torch.diff` + `repeat_interleave` instead of `searchsorted` on a slice (avoids Inductor `SliceView.get_stride()` NotImplementedError).
- `compute_stu_truncation_plan` returns precomputed total-int fields on `STUTruncationPlan` (`total_dropped` / `total_kept` / `total_prefix` / `total_rest`) so `apply_stu_truncation_plan` can pass them as `total_len_*` into `split_2D_jagged` and skip its `.item()` fallback under FakeTensor.
- Test config: `tzrec/tests/configs/dlrm_ultra_hstu_cutlass_kuairand_1k.config` → `ultra_hstu_cutlass_kuairand_1k.config`; converted to `ultra_hstu` with two channels (`uih_click`, `uih_view`); added `test_rank_ultra_hstu_cutlass_train_eval_export` in rank_integration_test.py.
- Data: kuairand-mot-1k (`action_weight` bits split at preprocess time) wired into scripts/ci/ci_data.sh (oss-accelerate); existing kuairand-1k URLs also moved to oss-accelerate.
- `tzrec/models/ultra_hstu_test.py` parametrizes over (single, multi_uniform, multi_hetero) channel topologies × NORMAL / FX_TRACE / JIT_SCRIPT / AOT_INDUCTOR × PYTORCH / TRITON.
- `tzrec/ops/hstu_attention_utils_test.py` parametrizes every numerical-correctness test over NORMAL + FX_TRACE for direct fx-trace regression coverage.
- Docs: docs/source/models/ultra_hstu.md (NEW) + README + generative.rst toctree.
- `tzrec` version bump 1.1.16 → 1.1.17.

Test plan
- `python -m tzrec.modules.gr.preprocessors_test` — back-compat (empty `name` defaults to existing keys).
- `python -m tzrec.modules.gr.hstu_transducer_test` — back-compat.
- `python -m tzrec.modules.gr.stu_test` — SLA cache + truncation back-compat.
- `python -m tzrec.ops.hstu_attention_utils_test` — 27 parametrized cases (eager + fx-trace).
- `python -m tzrec.models.dlrm_hstu_test` — DlrmHSTU refactor is behavior-preserving (full hypothesis sweep).
- `python -m tzrec.models.ultra_hstu_test` — 3 channel topologies × hypothesis sweep across all graph types and kernels.
- `python -m unittest tzrec.tests.rank_integration_test.RankIntegrationTest.test_rank_ultra_hstu_cutlass_train_eval_export` — end-to-end train → eval → export → predict on kuairand-mot-1k with two MoT channels under CUTLASS + AOT export.

🤖 Generated with Claude Code