diff --git a/makefile b/makefile index 874cd9e13..dc7c96c98 100644 --- a/makefile +++ b/makefile @@ -1,11 +1,14 @@ -RUN := uv run +TL_UV_ACTIVE ?= 0 +ACTIVE_FLAG := $(if $(filter 1 true TRUE yes YES on ON,$(TL_UV_ACTIVE)), --active,) +RUN := uv run$(ACTIVE_FLAG) +UV_SYNC := uv sync$(ACTIVE_FLAG) # Rerun args for flaky tests (httpx timeouts during HF Hub downloads) # Remove this line when no longer needed RERUN_ARGS := --reruns 2 --reruns-delay 5 dep: - uv sync + $(UV_SYNC) format: $(RUN) pycln --all . --exclude "__init__.py" @@ -59,12 +62,12 @@ notebook-test: $(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Grokking_Demo.ipynb $(RERUN_ARGS) test: - make unit-test - make integration-test - make acceptance-test - make benchmark-test - make docstring-test - make notebook-test + $(MAKE) unit-test + $(MAKE) integration-test + $(MAKE) acceptance-test + $(MAKE) benchmark-test + $(MAKE) docstring-test + $(MAKE) notebook-test docs-hot-reload: $(RUN) docs-hot-reload diff --git a/tests/integration/model_bridge/compatibility/test_legacy_hooks.py b/tests/integration/model_bridge/compatibility/test_legacy_hooks.py index 66f3f7948..0ecf33c18 100644 --- a/tests/integration/model_bridge/compatibility/test_legacy_hooks.py +++ b/tests/integration/model_bridge/compatibility/test_legacy_hooks.py @@ -139,7 +139,15 @@ def test_cache_hook_names_present(self, transformer_bridge, prompt, expected_hoo def test_cache_hook_equality_with_hooked_transformer( self, transformer_bridge, hooked_transformer, prompt, expected_hooks ): - """Test that TransformerBridge cache values match HookedTransformer cache values.""" + """Test that TransformerBridge cache values match HookedTransformer cache values. + + Raw attention-score caches intentionally use different masked sentinels: + HookedTransformer stores ``-inf`` for masked causal positions, while + TransformerBridge preserves HuggingFace's finite additive mask + representation using ``torch.finfo(dtype).min``. The unmasked scores and + resulting attention pattern should still match within floating-point + precision. + """ _, bridge_cache = transformer_bridge.run_with_cache(prompt) _, hooked_transformer_cache = hooked_transformer.run_with_cache(prompt) @@ -157,11 +165,35 @@ def test_cache_hook_equality_with_hooked_transformer( f"TransformerBridge shape {bridge_activation.shape}" ) - # Allow for some numerical differences due to different implementations - # Use nanmean to handle -inf values in attention scores (which produce nan when subtracted) - mean_abs_diff = torch.nanmean( - torch.abs(hooked_transformer_activation - bridge_activation) - ) + if hook == "blocks.0.attn.hook_attn_scores": + masked_positions = torch.isinf(hooked_transformer_activation) + unmasked_positions = ~masked_positions + + assert torch.allclose( + hooked_transformer_activation[unmasked_positions], + bridge_activation[unmasked_positions], + atol=1e-6, + rtol=1e-6, + ), ( + "Unmasked attention scores should match within float32 " "numerical precision" + ) + + masked_bridge_values = bridge_activation[masked_positions] + min_dtype = torch.finfo(bridge_activation.dtype).min + + assert masked_positions.any(), "Expected causal masking in attention scores" + assert torch.isfinite(masked_bridge_values).all(), ( + "TransformerBridge should keep masked attention scores finite " + "to mirror HuggingFace additive masking semantics" + ) + assert torch.all(masked_bridge_values == min_dtype), ( + "Masked TransformerBridge attention scores should use dtype min " + "instead of HookedTransformer's -inf sentinel" + ) + continue + + # Remaining legacy-compatible hooks are finite on this prompt, mean abs diff suffices + mean_abs_diff = torch.abs(hooked_transformer_activation - bridge_activation).mean() assert mean_abs_diff < 0.5, ( f"Hook {hook} does not match between HookedTransformer and TransformerBridge. " f"Mean absolute difference: {mean_abs_diff}" diff --git a/tests/unit/model_bridge/generalized_components/test_joint_qkv_attention.py b/tests/unit/model_bridge/generalized_components/test_joint_qkv_attention.py index 6be812214..7c2fddc00 100644 --- a/tests/unit/model_bridge/generalized_components/test_joint_qkv_attention.py +++ b/tests/unit/model_bridge/generalized_components/test_joint_qkv_attention.py @@ -10,6 +10,83 @@ class TestJointQKVAttention: """Test JointQKVAttentionBridge functionality.""" + @classmethod + def _make_additive_mask(cls, boolean_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + min_dtype = torch.finfo(dtype).min + return torch.where( + boolean_mask, + torch.zeros((), dtype=dtype, device=boolean_mask.device), + torch.full((), min_dtype, dtype=dtype, device=boolean_mask.device), + ) + + @classmethod + def _make_reconstruct_attention_qkv(cls) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q = torch.tensor( + [ + [ + [[1.0, 0.0], [0.5, -0.5]], + [[0.3, 0.7], [0.2, -0.1]], + [[-0.4, 0.6], [0.1, 0.9]], + ] + ], + dtype=torch.float32, + ) + k = torch.tensor( + [ + [ + [[0.9, 0.1], [0.2, -0.3]], + [[0.5, 0.4], [0.3, 0.2]], + [[-0.2, 0.8], [0.7, 0.1]], + ] + ], + dtype=torch.float32, + ) + v = torch.tensor( + [ + [ + [[0.2, 1.0], [0.1, 0.6]], + [[0.4, 0.3], [0.8, 0.2]], + [[0.7, 0.5], [0.9, 0.4]], + ] + ], + dtype=torch.float32, + ) + return q, k, v + + def _assert_non_4d_mask_preserves_causality( + self, + bridge, + *, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> None: + q, k, v = self._make_reconstruct_attention_qkv() + boolean_mask = torch.tensor([[True, True, False]]) + additive_mask = self._make_additive_mask(boolean_mask, q.dtype) + reconstruct_kwargs = {} + if position_embeddings is not None: + reconstruct_kwargs["position_embeddings"] = position_embeddings + + bool_output, bool_pattern = bridge._reconstruct_attention( + q.clone(), + k.clone(), + v.clone(), + attention_mask=boolean_mask, + **reconstruct_kwargs, + ) + additive_output, additive_pattern = bridge._reconstruct_attention( + q.clone(), + k.clone(), + v.clone(), + attention_mask=additive_mask, + **reconstruct_kwargs, + ) + + assert torch.allclose(bool_output, additive_output) + assert torch.allclose(bool_pattern, additive_pattern) + assert torch.all(bool_pattern[:, :, 0, 1:] == 0) + assert torch.all(bool_pattern[:, :, 1, 2] == 0) + assert torch.all(bool_pattern[..., 2] == 0) + def test_q_hook_out_mutation_applied_in_forward_pass(self): """Test that mutations made to q.hook_out are applied in the forward pass result.""" @@ -33,7 +110,8 @@ def forward(self, input): k_transformation = MockLinear(in_features=128, out_features=384) v_transformation = MockLinear(in_features=128, out_features=384) - split_qkv_matrix = lambda x: (q_transformation, k_transformation, v_transformation) + def split_qkv_matrix(_component): + return q_transformation, k_transformation, v_transformation # Create a mock attention layer for testing, doesn't do anything because we're only interested in the QKV components class MockAttention(torch.nn.Module): @@ -119,7 +197,8 @@ def forward(self, input): k_transformation = MockLinear(in_features=128, out_features=384) v_transformation = MockLinear(in_features=128, out_features=384) - split_qkv_matrix = lambda x: (q_transformation, k_transformation, v_transformation) + def split_qkv_matrix(_component): + return q_transformation, k_transformation, v_transformation # Create a mock attention layer for testing, doesn't do anything because we're only interested in the QKV components class MockAttention(torch.nn.Module): @@ -205,7 +284,8 @@ def forward(self, input): k_transformation = MockLinear(in_features=128, out_features=384) v_transformation = MockLinear(in_features=128, out_features=384) - split_qkv_matrix = lambda x: (q_transformation, k_transformation, v_transformation) + def split_qkv_matrix(_component): + return q_transformation, k_transformation, v_transformation # Create a mock attention layer for testing, doesn't do anything because we're only interested in the QKV components class MockAttention(torch.nn.Module): @@ -267,3 +347,141 @@ def v_hook_fn(v_output, hook): assert not torch.allclose( baseline_output, hooked_output ), "Output with v_hook_out mutation should be different from baseline" + + def test_reconstruct_attention_boolean_mask_matches_additive_mask(self): + """Boolean 4D masks should be equivalent to additive masks. + + This regression test covers the HuggingFace causal-mask path used by + TransformerBridge. Without the boolean-mask conversion in + ``_reconstruct_attention()``, boolean masks are added as ``0``/``1`` + and produce substantively different scores and patterns than the equivalent additive + float mask. + """ + + class TestConfig: + n_heads = 2 + d_model = 4 + + class MockOriginalAttention(torch.nn.Module): + def __init__(self): + super().__init__() + self.attn_dropout = torch.nn.Identity() + + bridge = JointQKVAttentionBridge(name="qkv", config=TestConfig()) + bridge.add_module("_original_component", MockOriginalAttention()) + q, k, v = self._make_reconstruct_attention_qkv() + boolean_mask = torch.tensor( + [[[[False, False, False], [False, True, False], [False, True, True]]]] + ) + additive_mask = self._make_additive_mask(boolean_mask, q.dtype) + + bool_output, bool_pattern = bridge._reconstruct_attention( + q.clone(), + k.clone(), + v.clone(), + attention_mask=boolean_mask, + ) + additive_output, additive_pattern = bridge._reconstruct_attention( + q.clone(), + k.clone(), + v.clone(), + attention_mask=additive_mask, + ) + + assert torch.isfinite(bool_output).all() + assert torch.isfinite(bool_pattern).all() + assert torch.allclose(bool_output, additive_output) + assert torch.allclose(bool_pattern, additive_pattern) + + def test_rotary_reconstruct_attention_boolean_mask_matches_additive_mask(self): + """Rotary joint-QKV attention should treat boolean and additive masks identically.""" + + from transformer_lens.model_bridge.generalized_components.joint_qkv_position_embeddings_attention import ( + JointQKVPositionEmbeddingsAttentionBridge, + ) + + class TestConfig: + n_heads = 2 + d_model = 4 + + class MockOriginalAttention(torch.nn.Module): + def __init__(self): + super().__init__() + self.attn_dropout = torch.nn.Identity() + + bridge = JointQKVPositionEmbeddingsAttentionBridge(name="qkv", config=TestConfig()) + bridge.add_module("_original_component", MockOriginalAttention()) + q, k, v = self._make_reconstruct_attention_qkv() + boolean_mask = torch.tensor( + [[[[False, False, False], [False, True, False], [False, True, True]]]] + ) + additive_mask = self._make_additive_mask(boolean_mask, q.dtype) + position_embeddings = ( + torch.ones(1, 3, 2, dtype=torch.float32), + torch.zeros(1, 3, 2, dtype=torch.float32), + ) + + bool_output, bool_pattern = bridge._reconstruct_attention( + q.clone(), + k.clone(), + v.clone(), + attention_mask=boolean_mask, + position_embeddings=position_embeddings, + ) + additive_output, additive_pattern = bridge._reconstruct_attention( + q.clone(), + k.clone(), + v.clone(), + attention_mask=additive_mask, + position_embeddings=position_embeddings, + ) + + assert torch.isfinite(bool_output).all() + assert torch.isfinite(bool_pattern).all() + assert torch.allclose(bool_output, additive_output) + assert torch.allclose(bool_pattern, additive_pattern) + + def test_reconstruct_attention_non_4d_mask_preserves_causality(self): + """Non-4D masks should still receive the local causal mask in the base bridge.""" + + class TestConfig: + n_heads = 2 + d_model = 4 + + class MockOriginalAttention(torch.nn.Module): + def __init__(self): + super().__init__() + self.attn_dropout = torch.nn.Identity() + + bridge = JointQKVAttentionBridge(name="qkv", config=TestConfig()) + bridge.add_module("_original_component", MockOriginalAttention()) + + self._assert_non_4d_mask_preserves_causality(bridge) + + def test_rotary_reconstruct_attention_non_4d_mask_preserves_causality(self): + """Rotary joint-QKV attention should match base masking semantics for non-4D masks.""" + + from transformer_lens.model_bridge.generalized_components.joint_qkv_position_embeddings_attention import ( + JointQKVPositionEmbeddingsAttentionBridge, + ) + + class TestConfig: + n_heads = 2 + d_model = 4 + + class MockOriginalAttention(torch.nn.Module): + def __init__(self): + super().__init__() + self.attn_dropout = torch.nn.Identity() + + bridge = JointQKVPositionEmbeddingsAttentionBridge(name="qkv", config=TestConfig()) + bridge.add_module("_original_component", MockOriginalAttention()) + position_embeddings = ( + torch.ones(1, 3, 2, dtype=torch.float32), + torch.zeros(1, 3, 2, dtype=torch.float32), + ) + + self._assert_non_4d_mask_preserves_causality( + bridge, + position_embeddings=position_embeddings, + ) diff --git a/transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py b/transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py index 5e26431f7..cf1314120 100644 --- a/transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py +++ b/transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py @@ -185,8 +185,8 @@ def _default_split_qkv_matrix( # Get the combined QKV component using the 'qkv' submodule name if "qkv" not in self.submodules: raise ValueError( - f"No 'qkv' submodule found in JointQKVAttentionBridge. " - f"Please define a 'qkv' submodule or provide a custom split_qkv_matrix function." + "No 'qkv' submodule found in JointQKVAttentionBridge. " + "Please define a 'qkv' submodule or provide a custom split_qkv_matrix function." ) # Get the actual qkv component name from the bridge @@ -355,10 +355,49 @@ def _apply_attention_input_hook(self, *args: Any, **kwargs: Any) -> torch.Tensor raise ValueError("No input tensor found in args or kwargs") return self.hook_in(input_tensor) + def _apply_reconstruct_attention_mask( + self, + attn_scores: torch.Tensor, + attention_mask: torch.Tensor | None, + seq_len: int, + ) -> torch.Tensor: + """Apply causal and optional attention masking to reconstructed scores. + + HuggingFace-style 4D masks already encode causal semantics, so they are + treated as authoritative. Lower-rank masks do not, so the local causal + mask is still applied before adding the caller-provided padding mask. + """ + min_dtype = torch.finfo(attn_scores.dtype).min + use_direct_hf_mask = attention_mask is not None and attention_mask.ndim >= 4 + if not use_direct_hf_mask: + causal_mask = torch.tril( + torch.ones(seq_len, seq_len, device=attn_scores.device, dtype=torch.bool) + ) + attn_scores = attn_scores.masked_fill(~causal_mask, min_dtype) + + if attention_mask is None: + return attn_scores + + if attention_mask.shape[-1] != seq_len: + attention_mask = attention_mask[..., :seq_len] + if attention_mask.ndim >= 3 and attention_mask.shape[-2] != seq_len: + attention_mask = attention_mask[..., :seq_len, :] + + if attention_mask.dtype == torch.bool: + attention_mask = torch.where( + attention_mask, + torch.zeros((), dtype=attn_scores.dtype, device=attn_scores.device), + torch.full((), min_dtype, dtype=attn_scores.dtype, device=attn_scores.device), + ) + else: + attention_mask = attention_mask.to(dtype=attn_scores.dtype) + + return attn_scores + attention_mask + def _reconstruct_attention( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, **kwargs ) -> tuple: - """Manual attention computation as fallback using TransformerLens computation logic.""" + """Manual attention reconstruction used by the bridge after splitting fused QKV projections.""" original_component = self.original_component assert original_component is not None assert self.config is not None @@ -395,22 +434,19 @@ def _reconstruct_attention( # using torch.baddbmm for numerical stability. Mirror that behavior here. reorder_and_upcast = getattr(self, "_reorder_and_upcast_attn", False) if reorder_and_upcast: - # Upcast Q/K to float32 for matmul, then apply combined scale - q_f32 = q.to(torch.float32) - k_f32 = k.to(torch.float32) - attn_scores = torch.matmul(q_f32, k_f32.transpose(-2, -1)) * scale + q_scores = q.to(torch.float32) + k_scores = k.to(torch.float32) else: - attn_scores = torch.matmul(q, k.transpose(-2, -1)) * scale - - causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=q.device)) - attn_scores = attn_scores.masked_fill(causal_mask == 0, float("-inf")) + q_scores = q + k_scores = k + attn_scores = torch.matmul(q_scores, k_scores.transpose(-2, -1)) * scale attention_mask = kwargs.get("attention_mask", None) - if attention_mask is not None: - if attention_mask.shape[-1] != seq_len: - attention_mask = attention_mask[..., :seq_len] - if attention_mask.shape[-2] != seq_len: - attention_mask = attention_mask[..., :seq_len, :] - attn_scores = attn_scores + attention_mask + attn_scores = self._apply_reconstruct_attention_mask( + attn_scores=attn_scores, + attention_mask=attention_mask, + seq_len=seq_len, + ) + attn_scores = self.hook_attn_scores(attn_scores) # Softmax in float32 when upcast mode is active, then cast back diff --git a/transformer_lens/model_bridge/generalized_components/joint_qkv_position_embeddings_attention.py b/transformer_lens/model_bridge/generalized_components/joint_qkv_position_embeddings_attention.py index 6fbcd0f7f..926ce8bb3 100644 --- a/transformer_lens/model_bridge/generalized_components/joint_qkv_position_embeddings_attention.py +++ b/transformer_lens/model_bridge/generalized_components/joint_qkv_position_embeddings_attention.py @@ -195,7 +195,7 @@ def rotate_half(x): def _reconstruct_attention( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, **kwargs ) -> tuple: - """Manual attention computation with rotary position embeddings. + """Manual attention reconstruction with rotary position embeddings. This overrides the parent class to apply rotary embeddings to Q and K before computing attention scores. @@ -238,18 +238,12 @@ def _reconstruct_attention( scale = head_dim ** (-0.5) attn_scores = torch.matmul(q, k.transpose(-2, -1)) * scale - # Apply causal mask - causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=q.device)) - attn_scores = attn_scores.masked_fill(causal_mask == 0, float("-inf")) - - # Apply attention mask if provided attention_mask = kwargs.get("attention_mask", None) - if attention_mask is not None: - if attention_mask.shape[-1] != seq_len: - attention_mask = attention_mask[..., :seq_len] - if attention_mask.shape[-2] != seq_len: - attention_mask = attention_mask[..., :seq_len, :] - attn_scores = attn_scores + attention_mask + attn_scores = self._apply_reconstruct_attention_mask( + attn_scores=attn_scores, + attention_mask=attention_mask, + seq_len=seq_len, + ) # Apply hook to attention scores attn_scores = self.hook_attn_scores(attn_scores)