
Fix boolean 4D attention-mask handling in joint-QKV bridge attention reconstruction#1198

Merged
jlarson4 merged 11 commits into TransformerLensOrg:dev-3.x from speediedan:joint-qkv-attention-fix
Mar 13, 2026

Conversation


@speediedan speediedan commented Mar 12, 2026

This change fixes a masking bug in the joint-QKV bridge attention reconstruction path used by both JointQKVAttentionBridge._reconstruct_attention() and JointQKVPositionEmbeddingsAttentionBridge._reconstruct_attention(). The issue surfaced during downstream backend-parity work in Interpretune and, more importantly, caused TransformerBridge to diverge from native HuggingFace behavior on padded inputs.

The change is small in code size but high-impact in behavior: when HuggingFace supplied a 4D boolean causal mask, the bridge reconstruction path treated that boolean mask as additive numeric values and effectively let masked positions leak into attention.

The Issue

The joint-QKV _reconstruct_attention() implementations are the manual attention-reconstruction paths used by bridge attention modules built from fused QKV projections, including the rotary-enabled variant. They are not fallback branches inside a larger conditional; they are the bridge's explicit path for rebuilding attention after the fused QKV projections have been split into individually hookable components. Before this fix, when attention_mask was present the methods assumed the mask could be added directly to float attention scores.

That assumption is valid for HuggingFace's additive float masks, but not for boolean 4D masks.

In the failing case:

  1. HuggingFace produced a 4D boolean mask (True → "attend", False → "mask").
  2. _reconstruct_attention() added that boolean tensor directly to attn_scores, which silently cast True/False to 1.0/0.0.
  3. Attend positions therefore received a +1 bias, while masked positions were not actually suppressed.
  4. The fallback path also relied on separate causal masking logic instead of treating HuggingFace's 4D mask as already authoritative for both causal and padding semantics.

The result was that TransformerBridge no longer matched the native HuggingFace forward pass on padded inputs. In downstream Interpretune (a pre-MVP latent-space analysis framework), analysis runs on left-padded GPT-2 batches showed materially different answer-position logits and parity failures between TransformerBridge and NNsight, even though both are expected to approximate HuggingFace-native execution.

Why This Matters

The bug was discovered while debugging the previously failing downstream parity test:

tests/core/test_sae_backend_parity.py::TestLogitDiffsBaseBackendParity

That test compares TransformerBridge and NNsight on the same HuggingFace model forward pass. It is intentionally strict because both backends are expected to preserve HuggingFace semantics and differ only in interception mechanism.

The broader investigation established the following backend model:

  • TransformerBridge and NNsight are the HuggingFace-native execution paths.
  • HookedTransformer is a separate TransformerLens forward path with processed weights and known architectural divergence on padded workloads.
  • For parity questions involving padding and masking, HuggingFace-native behavior was deemed the sensible source of truth.

This fix reconciles TransformerBridge with that source of truth by making the bridge attention-reconstruction path preserve HuggingFace mask semantics.

Root Cause

HuggingFace causal mask generation can produce either:

  • an additive float mask: 0.0 for attend, min_dtype for masked positions
  • a boolean 4D mask: True for attend, False for masked positions

The previous implementation only handled the additive case correctly.

For boolean masks it performed the equivalent of:

attn_scores = attn_scores + attention_mask

which becomes:

  • scores + 1.0 at valid positions
  • scores + 0.0 at masked positions

That is the opposite of the intended behavior for masked entries.
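The miscast is easy to reproduce in isolation. This minimal sketch (hypothetical shapes, not the actual bridge code) shows PyTorch silently promoting a boolean mask to 1.0/0.0 when it is added to float scores:

```python
import torch

# (batch, heads, query, key) attention scores, all zero for illustration
attn_scores = torch.zeros(1, 1, 3, 3)

# Boolean causal mask: True = attend, False = mask
bool_mask = torch.tril(torch.ones(3, 3, dtype=torch.bool)).view(1, 1, 3, 3)

# The pre-fix behavior: direct addition casts True/False to 1.0/0.0,
# so attend positions gain a +1 bias and masked positions are untouched.
buggy = attn_scores + bool_mask
```

After this addition, `buggy[0, 0, 0, 0]` is `1.0` (a spurious bias at a valid position) while `buggy[0, 0, 0, 2]` is `0.0` (a masked position left unsuppressed), which is exactly inverted from the additive-mask convention.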

The Fix

The joint-QKV attention-reconstruction implementations now:

  1. trim HuggingFace-provided masks to the active sequence length
  2. detect boolean masks explicitly
  3. convert boolean masks to additive float masks using torch.where(mask, 0.0, min_dtype)
  4. use the HuggingFace 4D mask directly when present, instead of relying on a separate secondary masking path
  5. retain the simple causal tril fallback only when no HuggingFace mask is provided

Key Design Decisions

1. Only HuggingFace-style 4D masks are authoritative for causal semantics

When HuggingFace has already prepared a 4D attention mask, that mask should be treated as the canonical expression of both causal and padding semantics. Lower-rank masks such as 2D or 3D padding masks do not encode causality on their own, so the bridge still applies its local causal tril mask before layering the caller-provided padding mask on top.
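The lower-rank case can be sketched as follows (hypothetical helper name; the real bridge code differs in structure): causality is applied locally first, and the caller's 2D padding mask is layered on top rather than trusted to encode causal structure.

```python
import torch

def apply_padding_with_local_causal(attn_scores, padding_mask_2d):
    """Sketch: 2D padding masks do not encode causality, so apply tril locally.

    attn_scores: (batch, heads, query, key)
    padding_mask_2d: (batch, key), 1 = attend, 0 = padding
    """
    min_dtype = torch.finfo(attn_scores.dtype).min
    q_len, k_len = attn_scores.shape[-2:]

    # Local causal mask first, since the 2D mask carries no causal information.
    causal = torch.tril(
        torch.ones(q_len, k_len, dtype=torch.bool, device=attn_scores.device)
    )
    attn_scores = attn_scores.masked_fill(~causal, min_dtype)

    # Then layer the caller-provided padding mask on top.
    pad = padding_mask_2d[:, None, None, :] == 0
    return attn_scores.masked_fill(pad, min_dtype)
```

Because both masks are applied with `masked_fill` against the same sentinel, a position masked by either source ends up at `min_dtype` exactly once, with no double-counting.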

2. torch.finfo(dtype).min used consistently for all masked positions

torch.finfo(dtype).min is used for every masking path in the bridge: both when converting boolean 4D masks and in the fallback causal tril path that activates when no HuggingFace mask is supplied. This approach:

  • Matches HuggingFace's additive-mask convention (HF itself fills its causal mask with min_dtype, not -inf).
  • Avoids numerical problems (e.g., NaN gradients or unstable softmax sums) that can occur when a row is fully masked and all inputs to softmax are -inf.
  • Preserves parity with other HuggingFace-native backends.
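The fully-masked-row hazard is easy to demonstrate: softmax over an all-`-inf` row produces NaN (the max-subtraction step computes `-inf - -inf`), while an all-`finfo.min` row degrades gracefully to a uniform distribution.

```python
import torch

inf_row = torch.full((4,), float("-inf"))
min_row = torch.full((4,), torch.finfo(torch.float32).min)

torch.softmax(inf_row, dim=-1)  # all NaN: -inf - (-inf) is undefined
torch.softmax(min_row, dim=-1)  # uniform 0.25 each: finite and stable
```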

HookedTransformer continues to use -inf in its masking path because it is not subject to this constraint. The difference between TransformerBridge's min_dtype and HookedTransformer's -inf at masked positions is reflected explicitly in the updated legacy hook compatibility test.

3. Shared mask application logic keeps the rotary and non-rotary paths aligned

The rotary attention bridge differs from the base bridge only in how it prepares Q and K before score computation. The causal-vs-padding mask rules should be identical. To avoid further drift between the two implementations, the mask application logic now lives in a shared helper on JointQKVAttentionBridge and is reused by JointQKVPositionEmbeddingsAttentionBridge.

Downstream Impact

This bug fix was motivated by Interpretune's latent space operation backend parity work.

Relevant downstream references:

  • tests/core/test_sae_backend_parity.py::TestLogitDiffsBaseBackendParity
  • docs/ht_bridge_parity_behavior.md
  • debug_4way_canonical_comparison.py

That investigation showed:

  • TransformerBridge and NNsight should track HuggingFace-native outputs closely.
  • HookedTransformer follows a different masking implementation and should not be used as the raw parity reference for padded attention behavior.
  • After the mask fix, TransformerBridge matched native HuggingFace outputs to float precision in the canonical comparison workflow.

Test Changes

This PR adds and updates tests in three places.

1. New unit regression for boolean-mask conversion

Added focused regressions in:

tests/unit/model_bridge/generalized_components/test_joint_qkv_attention.py

The new tests construct a boolean 4D attention mask and the equivalent additive float mask, then verify that both the base joint-QKV bridge and the rotary joint-QKV bridge produce the same attention pattern and output for both representations.

This test fails without the fix because the old implementation adds boolean masks as 0/1 values rather than converting them to additive masked scores.

The same test module now also verifies that non-4D masks still preserve causal semantics in both bridges: a 2D padding mask must not re-enable attention to future tokens, and the boolean and additive representations must still match.
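The core equivalence those regressions assert can be sketched like this (a standalone simplification, not the actual test code, which exercises the real bridge modules):

```python
import torch

def attention_pattern(scores, mask):
    """Simplified stand-in for the bridge path: normalize the mask, then softmax."""
    if mask.dtype == torch.bool:
        # Boolean masks are converted to the additive convention before adding.
        mask = torch.where(mask, 0.0, torch.finfo(scores.dtype).min)
    return torch.softmax(scores + mask, dim=-1)
```

A boolean 4D causal mask and its additive float equivalent must produce identical attention patterns, with masked positions receiving (numerically) zero probability; the pre-fix addition of raw 0/1 values fails both properties.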

2. Updated legacy hook compatibility expectation

Updated:

tests/integration/model_bridge/compatibility/test_legacy_hooks.py::TestLegacyHookCompatibility::test_cache_hook_equality_with_hooked_transformer

Before this change, the test required raw equality for all cached hooks, including blocks.0.attn.hook_attn_scores.

That assertion is no longer correct for the masked entries of hook_attn_scores:

  • HookedTransformer stores -inf at masked causal positions.
  • TransformerBridge now stores torch.finfo(dtype).min, matching HuggingFace additive-mask semantics.

The updated test keeps the original compatibility guarantee for every other expected hook and refines the hook_attn_scores check to assert:

  • unmasked attention scores match within float32 numerical precision
  • HookedTransformer masked positions are -inf
  • TransformerBridge masked positions are finite and equal to torch.finfo(dtype).min

This is a stricter and more correct assertion than the previous nanmean(abs(diff)) < 0.5 check because it encodes the exact semantic reason the raw tensor representation differs.
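The refined three-part check can be sketched as follows (hypothetical function and tensor names; the actual test operates on cached hook values):

```python
import torch

def assert_masked_score_semantics(ht_scores, bridge_scores):
    """Sketch of the refined hook_attn_scores assertion described above."""
    min_val = torch.finfo(bridge_scores.dtype).min
    # Identify masked positions via HookedTransformer's -inf sentinel.
    ht_masked = torch.isinf(ht_scores) & (ht_scores < 0)

    # Unmasked scores must match within float precision.
    torch.testing.assert_close(ht_scores[~ht_masked], bridge_scores[~ht_masked])
    # HookedTransformer masked positions are -inf.
    assert torch.all(ht_scores[ht_masked] == float("-inf"))
    # TransformerBridge masked positions are finite and equal to finfo.min.
    assert torch.all(torch.isfinite(bridge_scores[ht_masked]))
    assert torch.all(bridge_scores[ht_masked] == min_val)
```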

Testing

Validated locally with current uv.lock including transformers==5.0.0, huggingface-hub==1.3.4, and datasets==4.0.0.

  • New unit regressions in tests/unit/model_bridge/generalized_components/test_joint_qkv_attention.py pass with the fix and fail without it.
  • The shared regression fixtures in tests/unit/model_bridge/generalized_components/test_joint_qkv_attention.py now use class-level helpers for the additive-mask conversion and shared Q/K/V tensors so the base and rotary tests stay aligned.
  • tests/integration/model_bridge/compatibility/test_legacy_hooks.py::TestLegacyHookCompatibility::test_cache_hook_equality_with_hooked_transformer now passes with the updated masked-score assertion.
  • To accommodate out-of-place virtualenv workflows, a few lines were added to the local makefile so that setting TL_UV_ACTIVE=1 toggles uv run --active / uv sync --active behavior without requiring manual edits.

Files Changed

File Change
joint_qkv_attention.py Fix boolean 4D attention-mask handling in _reconstruct_attention() and clarify the manual attention-reconstruction path
joint_qkv_position_embeddings_attention.py Apply the same boolean/additive-mask fix to the rotary joint-QKV attention-reconstruction path
test_joint_qkv_attention.py Add regressions covering boolean-mask to additive-mask equivalence for both joint-QKV bridge variants
test_legacy_hooks.py Update legacy hook compatibility test for masked hook_attn_scores semantics
makefile Allow TL_UV_ACTIVE=1 to opt into uv run --active and uv sync --active without editing the file

Why The Legacy Test Update Is Justified

The legacy test change is not loosening correctness. It is correcting the reference model used by the assertion.

The old assertion implicitly assumed that HookedTransformer and TransformerBridge should serialize masked raw attention scores identically. That assumption is no longer valid once TransformerBridge correctly mirrors HuggingFace additive masking.

The updated test keeps a strong compatibility contract while reflecting the architectural reality:

  • HookedTransformer and TransformerBridge still agree exactly on unmasked attention scores for this path.
  • HookedTransformer still exposes masked positions clearly via -inf.
  • TransformerBridge now correctly exposes HuggingFace-style finite masked sentinels.
  • The downstream user-visible quantity that matters, the attention pattern, remains compatible.

Type of change

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes (with noted appropriate modification)
  • I have not rewritten tests relating to key interfaces which would af

@speediedan speediedan marked this pull request as ready for review March 12, 2026 18:20
Copilot AI review requested due to automatic review settings March 12, 2026 18:20
Copilot AI left a comment

Pull request overview

This PR fixes attention-mask handling in the joint-QKV attention reconstruction paths used by TransformerBridge components, specifically correcting behavior when HuggingFace provides a 4D boolean attention mask so masked positions do not leak into attention.

Changes:

  • Convert 4D boolean attention masks into additive (float) masks using torch.finfo(dtype).min before adding to attention scores.
  • Treat HuggingFace-provided 4D masks as authoritative for causal+padding semantics (and only fall back to a local tril causal mask when no HF mask is provided).
  • Add/update regressions and compatibility assertions around masked attention-score sentinel values; update Makefile to optionally use uv --active.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 1 comment.

Show a summary per file
File Description
transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py Fix boolean 4D mask handling and adjust causal masking behavior in manual attention reconstruction.
transformer_lens/model_bridge/generalized_components/joint_qkv_position_embeddings_attention.py Apply analogous boolean-mask handling to the rotary (position-embeddings) reconstruction path.
tests/unit/model_bridge/generalized_components/test_joint_qkv_attention.py Add regressions asserting boolean 4D masks match equivalent additive masks for both bridge variants.
tests/integration/model_bridge/compatibility/test_legacy_hooks.py Update legacy cache equality expectations for hook_attn_scores masked sentinel values.
makefile Add TL_UV_ACTIVE toggle to use uv run/sync --active; use $(MAKE) for sub-targets.


@speediedan speediedan marked this pull request as draft March 12, 2026 19:31
@speediedan speediedan marked this pull request as ready for review March 12, 2026 21:31
@jlarson4 jlarson4 merged commit 42a2d52 into TransformerLensOrg:dev-3.x Mar 13, 2026
18 checks passed