Fix boolean 4D attention-mask handling in joint-QKV bridge attention reconstruction#1198
Merged
jlarson4 merged 11 commits into TransformerLensOrg:dev-3.x on Mar 13, 2026
Conversation
…avior with `TL_UV_ACTIVE=1` instead of requiring manual edits for out-of-place virtualenv workflows.
Contributor
Pull request overview
This PR fixes attention-mask handling in the joint-QKV attention reconstruction paths used by TransformerBridge components, specifically correcting behavior when HuggingFace provides a 4D boolean attention mask so masked positions do not leak into attention.
Changes:
- Convert 4D boolean attention masks into additive (float) masks using `torch.finfo(dtype).min` before adding to attention scores.
- Treat HuggingFace-provided 4D masks as authoritative for causal+padding semantics (and only fall back to a local `tril` causal mask when no HF mask is provided).
- Add/update regressions and compatibility assertions around masked attention-score sentinel values; update the Makefile to optionally use `uv --active`.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 1 comment.
Show a summary per file
| File | Description |
|---|---|
| transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py | Fix boolean 4D mask handling and adjust causal masking behavior in manual attention reconstruction. |
| transformer_lens/model_bridge/generalized_components/joint_qkv_position_embeddings_attention.py | Apply analogous boolean-mask handling to the rotary (position-embeddings) reconstruction path. |
| tests/unit/model_bridge/generalized_components/test_joint_qkv_attention.py | Add regressions asserting boolean 4D masks match equivalent additive masks for both bridge variants. |
| tests/integration/model_bridge/compatibility/test_legacy_hooks.py | Update legacy cache equality expectations for hook_attn_scores masked sentinel values. |
| makefile | Add TL_UV_ACTIVE toggle to use uv run/sync --active; use $(MAKE) for sub-targets. |
This change fixes a masking bug in the joint-QKV bridge attention reconstruction path used by both `JointQKVAttentionBridge._reconstruct_attention()` and `JointQKVPositionEmbeddingsAttentionBridge._reconstruct_attention()`. The issue surfaced during downstream backend-parity work in Interpretune and, more importantly, caused TransformerBridge to diverge from native HuggingFace behavior on padded inputs.

The change is small in code size but high-impact in behavior: when HuggingFace supplied a 4D boolean causal mask, the bridge reconstruction path treated that boolean mask as additive numeric values and effectively let masked positions leak into attention.
The Issue
The joint-QKV `_reconstruct_attention()` implementations are the manual attention-reconstruction paths used by bridge attention modules built from fused QKV projections, including the rotary-enabled variant. They are not fallback branches inside a larger conditional; they are the bridge's explicit path for rebuilding attention after the fused QKV projections have been split into individually hookable components.

Before this fix, when `attention_mask` was present, the methods assumed the mask could be added directly to float attention scores. That assumption is valid for HuggingFace's additive float masks, but not for boolean 4D masks.
In the failing case:
- HuggingFace supplied a 4D boolean attention mask (`True` -> "attend", `False` -> "mask").
- `_reconstruct_attention()` added that boolean tensor directly to `attn_scores`, which silently cast `True`/`False` to `1.0`/`0.0`.
- Valid positions therefore received a spurious `+1` bias, while masked positions were not actually suppressed.

The result was that TransformerBridge no longer matched the native HuggingFace forward pass on padded inputs. In downstream Interpretune (a pre-MVP latent space analysis framework) analysis runs on left-padded GPT-2 batches, this showed up as materially different answer-position logits and parity failures between TransformerBridge and NNsight, even though both are expected to approximate HuggingFace-native execution.
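The cast behavior above can be reproduced with a minimal torch sketch (toy tensors, not the bridge's actual code):

```python
import torch

# Toy attention scores and a 4D boolean mask (True = attend, False = mask),
# in the [batch, heads, query, key] layout HuggingFace uses.
scores = torch.zeros(1, 1, 2, 2)
bool_mask = torch.tensor([[[[True, False],
                            [True, True]]]])

# The buggy path added the boolean tensor directly to the float scores,
# which silently promotes True/False to 1.0/0.0.
buggy = scores + bool_mask

print(buggy[0, 0])
# tensor([[1., 0.],
#         [1., 1.]])
# Valid positions get a spurious +1 bias; the masked position keeps score
# 0.0 and is never suppressed.
```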
Why This Matters
The bug was discovered while debugging the previously failing downstream parity test:
`tests/core/test_sae_backend_parity.py::TestLogitDiffsBaseBackendParity`

That test compares TransformerBridge and NNsight on the same HuggingFace model forward pass. It is intentionally strict because both backends are expected to preserve HuggingFace semantics and differ only in interception mechanism.
The broader investigation established the native HuggingFace forward pass as the source of truth that both backends must preserve. This fix reconciles TransformerBridge to that source of truth by making the bridge attention-reconstruction path preserve HuggingFace mask semantics.
Root Cause
HuggingFace causal mask generation can produce either:

- an additive float mask: `0.0` for attend, `min_dtype` for masked positions
- a boolean mask: `True` for attend, `False` for masked positions

The previous implementation only handled the additive case correctly.
For boolean masks it performed the equivalent of adding the raw boolean tensor to the scores, which becomes:

- `scores + 1.0` at valid positions
- `scores + 0.0` at masked positions

That is the opposite of the intended behavior for masked entries.
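The intended conversion can be sketched with toy tensors (illustrative only, not the bridge code itself): turn the boolean mask into an additive float mask first, then add it to the scores.

```python
import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min

bool_mask = torch.tensor([[[[True, False],
                            [True, True]]]])

# 0.0 where attention is allowed, min_dtype where it is masked.
additive = torch.where(bool_mask, torch.tensor(0.0), torch.tensor(min_dtype))

scores = torch.zeros(1, 1, 2, 2, dtype=dtype)
pattern = torch.softmax(scores + additive, dim=-1)

print(pattern[0, 0, 0])  # tensor([1., 0.]) -- the masked key gets 0 probability
```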
The Fix
The joint-QKV attention-reconstruction implementations now:

- convert boolean 4D masks to additive masks with `torch.where(mask, 0.0, min_dtype)`
- treat a HuggingFace-provided 4D mask as authoritative, applying the local `tril` fallback only when no HuggingFace mask is provided

Key Design Decisions
1. Only HuggingFace-style 4D masks are authoritative for causal semantics
When HuggingFace has already prepared a 4D attention mask, that mask should be treated as the canonical expression of both causal and padding semantics. Lower-rank masks such as 2D or 3D padding masks do not encode causality on their own, so the bridge still applies its local causal `tril` mask before layering the caller-provided padding mask on top.

2. `torch.finfo(dtype).min` used consistently for all masked positions

`torch.finfo(dtype).min` is used for every masking path in the bridge: both when converting boolean 4D masks and in the fallback causal `tril` path that activates when no HuggingFace mask is supplied. This approach:

- matches HuggingFace additive-mask semantics exactly (`min_dtype`, not `-inf`).
- avoids numerical pathologies (`NaN` gradients or unstable softmax sums) that can occur when a row is fully masked and all inputs to softmax are `-inf`.
-infin its masking path because it is not subject to this constraint. The difference between TransformerBridge'smin_dtypeand HookedTransformer's-infat masked positions is reflected explicitly in the updated legacy hook compatibility test.3. Shared mask application logic keeps the rotary and non-rotary paths aligned
The rotary attention bridge differs from the base bridge only in how it prepares Q and K before score computation. The causal-vs-padding mask rules should be identical. To avoid further drift between the two implementations, the mask application logic now lives in a shared helper on `JointQKVAttentionBridge` and is reused by `JointQKVPositionEmbeddingsAttentionBridge`.

Downstream Impact
This bug fix was motivated by Interpretune's latent space operation backend parity work.
Relevant downstream references:
tests/core/test_sae_backend_parity.py::TestLogitDiffsBaseBackendParitydocs/ht_bridge_parity_behavior.mddebug_4way_canonical_comparison.pyThat investigation showed:
Test Changes
This PR adds and updates tests in three places.
1. New unit regression for boolean-mask conversion
Added focused regressions in:
`tests/unit/model_bridge/generalized_components/test_joint_qkv_attention.py`

The new tests construct a boolean 4D attention mask and the equivalent additive float mask, then verify that both the base joint-QKV bridge and the rotary joint-QKV bridge produce the same attention pattern and output for both representations.

This test fails without the fix because the old implementation adds boolean masks as `0`/`1` values rather than converting them to additive masked scores.

The same test module now also verifies that non-4D masks still preserve causal semantics in both bridges: a 2D padding mask must not re-enable attention to future tokens, and the boolean and additive representations must still match.
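A hypothetical distillation of what such a regression checks (names and shapes are illustrative; the real tests exercise the bridge modules themselves): a boolean 4D mask, once converted, must produce the same attention pattern as an independently built additive mask.

```python
import torch

min_dtype = torch.finfo(torch.float32).min
scores = torch.randn(1, 1, 4, 4)

# Boolean causal mask: True on and below the diagonal.
bool_mask = torch.tril(torch.ones(4, 4, dtype=torch.bool)).view(1, 1, 4, 4)

# Additive float mask built independently, HuggingFace-style.
additive_mask = torch.zeros(1, 1, 4, 4)
additive_mask[~bool_mask] = min_dtype

# Bridge-style conversion of the boolean mask.
converted = torch.where(bool_mask, torch.tensor(0.0), torch.tensor(min_dtype))

p_bool = torch.softmax(scores + converted, dim=-1)
p_add = torch.softmax(scores + additive_mask, dim=-1)
assert torch.allclose(p_bool, p_add)
```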
2. Updated legacy hook compatibility expectation
Updated:
`tests/integration/model_bridge/compatibility/test_legacy_hooks.py::TestLegacyHookCompatibility::test_cache_hook_equality_with_hooked_transformer`

Before this change, the test required raw equality for all cached hooks, including `blocks.0.attn.hook_attn_scores`. That assertion is no longer correct for the masked entries of `hook_attn_scores`:

- HookedTransformer writes `-inf` at masked causal positions.
- TransformerBridge now writes `torch.finfo(dtype).min`, matching HuggingFace additive-mask semantics.

The updated test keeps the original compatibility guarantee for every other expected hook and refines the `hook_attn_scores` check to assert:

- HookedTransformer's masked positions are `-inf`
- TransformerBridge's masked positions are `torch.finfo(dtype).min`
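Illustratively, the refined check has roughly this shape (hypothetical stand-in tensors, not the actual test code):

```python
import torch

min_dtype = torch.finfo(torch.float32).min

# Stand-ins for cached hook_attn_scores from each implementation.
ht_scores = torch.tensor([[0.5, float("-inf")],
                          [0.2, 0.7]])
bridge_scores = torch.tensor([[0.5, min_dtype],
                              [0.2, 0.7]])

masked = torch.isinf(ht_scores)
# Masked positions: -inf in HookedTransformer, min_dtype in the bridge.
assert torch.all(bridge_scores[masked] == min_dtype)
# Unmasked positions must still match exactly.
assert torch.equal(ht_scores[~masked], bridge_scores[~masked])
# Downstream of softmax, the two representations agree numerically.
assert torch.allclose(torch.softmax(ht_scores, dim=-1),
                      torch.softmax(bridge_scores, dim=-1))
```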
This is a stricter and more correct assertion than the previous `nanmean(abs(diff)) < 0.5` check because it encodes the exact semantic reason the raw tensor representation differs.

Testing
Validated locally with the current uv.lock, including `transformers==5.0.0`, `huggingface-hub==1.3.4`, and `datasets==4.0.0`.

- The new regressions in `tests/unit/model_bridge/generalized_components/test_joint_qkv_attention.py` pass with the fix and fail without it.
- Tests in `tests/unit/model_bridge/generalized_components/test_joint_qkv_attention.py` now use class-level helpers for the additive-mask conversion and shared Q/K/V tensors so the base and rotary tests stay aligned.
- `tests/integration/model_bridge/compatibility/test_legacy_hooks.py::TestLegacyHookCompatibility::test_cache_hook_equality_with_hooked_transformer` now passes with the updated masked-score assertion.
- Updated the `makefile` so it can now toggle `uv run --active`/`uv sync --active` behavior with `TL_UV_ACTIVE=1` instead of requiring manual edits.

Files Changed
- `joint_qkv_attention.py`: fix `_reconstruct_attention()` and clarify the manual attention-reconstruction path
- `joint_qkv_position_embeddings_attention.py`: apply the same boolean-mask handling to the rotary reconstruction path
- `test_joint_qkv_attention.py`: add boolean-vs-additive mask regressions for both bridge variants
- `test_legacy_hooks.py`: update expectations for `hook_attn_scores` semantics
- `makefile`: add `TL_UV_ACTIVE=1` to opt into `uv run --active` and `uv sync --active` without editing the file

Why The Legacy Test Update Is Justified
The legacy test change is not loosening correctness. It is correcting the reference model used by the assertion.
The old assertion implicitly assumed that HookedTransformer and TransformerBridge should serialize masked raw attention scores identically. That assumption is no longer valid once TransformerBridge correctly mirrors HuggingFace additive masking.
The updated test keeps a strong compatibility contract while reflecting the architectural reality:

- TransformerBridge now mirrors HuggingFace additive-mask semantics (`torch.finfo(dtype).min`) at masked positions.
- HookedTransformer keeps its own masking convention of `-inf`.

Type of change
Checklist: