Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions makefile
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
RUN := uv run
TL_UV_ACTIVE ?= 0
ACTIVE_FLAG := $(if $(filter 1 true TRUE yes YES on ON,$(TL_UV_ACTIVE)), --active,)
RUN := uv run$(ACTIVE_FLAG)
UV_SYNC := uv sync$(ACTIVE_FLAG)

# Rerun args for flaky tests (httpx timeouts during HF Hub downloads)
# Remove this line when no longer needed
RERUN_ARGS := --reruns 2 --reruns-delay 5

dep:
uv sync
$(UV_SYNC)

format:
$(RUN) pycln --all . --exclude "__init__.py"
Expand Down Expand Up @@ -59,12 +62,12 @@ notebook-test:
$(RUN) pytest --nbval-sanitize-with demos/doc_sanitize.cfg demos/Grokking_Demo.ipynb $(RERUN_ARGS)

test:
make unit-test
make integration-test
make acceptance-test
make benchmark-test
make docstring-test
make notebook-test
$(MAKE) unit-test
$(MAKE) integration-test
$(MAKE) acceptance-test
$(MAKE) benchmark-test
$(MAKE) docstring-test
$(MAKE) notebook-test

docs-hot-reload:
$(RUN) docs-hot-reload
Expand Down
44 changes: 38 additions & 6 deletions tests/integration/model_bridge/compatibility/test_legacy_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,15 @@ def test_cache_hook_names_present(self, transformer_bridge, prompt, expected_hoo
def test_cache_hook_equality_with_hooked_transformer(
self, transformer_bridge, hooked_transformer, prompt, expected_hooks
):
"""Test that TransformerBridge cache values match HookedTransformer cache values."""
"""Test that TransformerBridge cache values match HookedTransformer cache values.

Raw attention-score caches intentionally use different masked sentinels:
HookedTransformer stores ``-inf`` for masked causal positions, while
TransformerBridge preserves HuggingFace's finite additive mask
representation using ``torch.finfo(dtype).min``. The unmasked scores and
resulting attention pattern should still match within floating-point
precision.
"""
_, bridge_cache = transformer_bridge.run_with_cache(prompt)
_, hooked_transformer_cache = hooked_transformer.run_with_cache(prompt)

Expand All @@ -157,11 +165,35 @@ def test_cache_hook_equality_with_hooked_transformer(
f"TransformerBridge shape {bridge_activation.shape}"
)

# Allow for some numerical differences due to different implementations
# Use nanmean to handle -inf values in attention scores (which produce nan when subtracted)
mean_abs_diff = torch.nanmean(
torch.abs(hooked_transformer_activation - bridge_activation)
)
if hook == "blocks.0.attn.hook_attn_scores":
masked_positions = torch.isinf(hooked_transformer_activation)
unmasked_positions = ~masked_positions

assert torch.allclose(
hooked_transformer_activation[unmasked_positions],
bridge_activation[unmasked_positions],
atol=1e-6,
rtol=1e-6,
), (
"Unmasked attention scores should match within float32 " "numerical precision"
)

masked_bridge_values = bridge_activation[masked_positions]
min_dtype = torch.finfo(bridge_activation.dtype).min

assert masked_positions.any(), "Expected causal masking in attention scores"
assert torch.isfinite(masked_bridge_values).all(), (
"TransformerBridge should keep masked attention scores finite "
"to mirror HuggingFace additive masking semantics"
)
assert torch.all(masked_bridge_values == min_dtype), (
"Masked TransformerBridge attention scores should use dtype min "
"instead of HookedTransformer's -inf sentinel"
)
continue

# Remaining legacy-compatible hooks are finite on this prompt, mean abs diff suffices
mean_abs_diff = torch.abs(hooked_transformer_activation - bridge_activation).mean()
assert mean_abs_diff < 0.5, (
f"Hook {hook} does not match between HookedTransformer and TransformerBridge. "
f"Mean absolute difference: {mean_abs_diff}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,83 @@
class TestJointQKVAttention:
"""Test JointQKVAttentionBridge functionality."""

@classmethod
def _make_additive_mask(cls, boolean_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
min_dtype = torch.finfo(dtype).min
return torch.where(
boolean_mask,
torch.zeros((), dtype=dtype, device=boolean_mask.device),
torch.full((), min_dtype, dtype=dtype, device=boolean_mask.device),
)

@classmethod
def _make_reconstruct_attention_qkv(cls) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
q = torch.tensor(
[
[
[[1.0, 0.0], [0.5, -0.5]],
[[0.3, 0.7], [0.2, -0.1]],
[[-0.4, 0.6], [0.1, 0.9]],
]
],
dtype=torch.float32,
)
k = torch.tensor(
[
[
[[0.9, 0.1], [0.2, -0.3]],
[[0.5, 0.4], [0.3, 0.2]],
[[-0.2, 0.8], [0.7, 0.1]],
]
],
dtype=torch.float32,
)
v = torch.tensor(
[
[
[[0.2, 1.0], [0.1, 0.6]],
[[0.4, 0.3], [0.8, 0.2]],
[[0.7, 0.5], [0.9, 0.4]],
]
],
dtype=torch.float32,
)
return q, k, v

def _assert_non_4d_mask_preserves_causality(
    self,
    bridge,
    *,
    position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
) -> None:
    """Assert a 2D boolean padding mask and its additive float equivalent
    produce identical attention, with causal masking still applied.

    Args:
        bridge: Attention bridge exposing ``_reconstruct_attention``.
        position_embeddings: Optional (cos, sin) pair forwarded to the
            bridge for rotary variants; omitted from the call when None.
    """
    q, k, v = self._make_reconstruct_attention_qkv()
    padding_mask = torch.tensor([[True, True, False]])
    float_mask = self._make_additive_mask(padding_mask, q.dtype)

    extra_kwargs = (
        {}
        if position_embeddings is None
        else {"position_embeddings": position_embeddings}
    )

    def run(mask):
        # Fresh clones per call so the two runs cannot share mutated inputs.
        return bridge._reconstruct_attention(
            q.clone(), k.clone(), v.clone(), attention_mask=mask, **extra_kwargs
        )

    output_from_bool, pattern_from_bool = run(padding_mask)
    output_from_float, pattern_from_float = run(float_mask)

    assert torch.allclose(output_from_bool, output_from_float)
    assert torch.allclose(pattern_from_bool, pattern_from_float)
    # Causality: query 0 attends only key 0; query 1 cannot see key 2.
    assert torch.all(pattern_from_bool[:, :, 0, 1:] == 0)
    assert torch.all(pattern_from_bool[:, :, 1, 2] == 0)
    # Padding: masked key position 2 gets zero weight from every query.
    assert torch.all(pattern_from_bool[..., 2] == 0)

def test_q_hook_out_mutation_applied_in_forward_pass(self):
"""Test that mutations made to q.hook_out are applied in the forward pass result."""

Expand All @@ -33,7 +110,8 @@ def forward(self, input):
k_transformation = MockLinear(in_features=128, out_features=384)
v_transformation = MockLinear(in_features=128, out_features=384)

split_qkv_matrix = lambda x: (q_transformation, k_transformation, v_transformation)
def split_qkv_matrix(_component):
return q_transformation, k_transformation, v_transformation

# Create a mock attention layer for testing, doesn't do anything because we're only interested in the QKV components
class MockAttention(torch.nn.Module):
Expand Down Expand Up @@ -119,7 +197,8 @@ def forward(self, input):
k_transformation = MockLinear(in_features=128, out_features=384)
v_transformation = MockLinear(in_features=128, out_features=384)

split_qkv_matrix = lambda x: (q_transformation, k_transformation, v_transformation)
def split_qkv_matrix(_component):
return q_transformation, k_transformation, v_transformation

# Create a mock attention layer for testing, doesn't do anything because we're only interested in the QKV components
class MockAttention(torch.nn.Module):
Expand Down Expand Up @@ -205,7 +284,8 @@ def forward(self, input):
k_transformation = MockLinear(in_features=128, out_features=384)
v_transformation = MockLinear(in_features=128, out_features=384)

split_qkv_matrix = lambda x: (q_transformation, k_transformation, v_transformation)
def split_qkv_matrix(_component):
return q_transformation, k_transformation, v_transformation

# Create a mock attention layer for testing, doesn't do anything because we're only interested in the QKV components
class MockAttention(torch.nn.Module):
Expand Down Expand Up @@ -267,3 +347,141 @@ def v_hook_fn(v_output, hook):
assert not torch.allclose(
baseline_output, hooked_output
), "Output with v_hook_out mutation should be different from baseline"

def test_reconstruct_attention_boolean_mask_matches_additive_mask(self):
    """Boolean 4D masks should behave exactly like additive float masks.

    Regression test for the HuggingFace causal-mask path used by
    TransformerBridge: without the boolean-mask conversion inside
    ``_reconstruct_attention()``, a boolean mask would be added as raw
    ``0``/``1`` values and yield substantially different scores and
    patterns than the equivalent additive float mask.
    """

    class TestConfig:
        n_heads = 2
        d_model = 4

    class MockOriginalAttention(torch.nn.Module):
        # Minimal stand-in: only the attribute the bridge touches.
        def __init__(self):
            super().__init__()
            self.attn_dropout = torch.nn.Identity()

    attention_bridge = JointQKVAttentionBridge(name="qkv", config=TestConfig())
    attention_bridge.add_module("_original_component", MockOriginalAttention())

    q, k, v = self._make_reconstruct_attention_qkv()
    keep_mask = torch.tensor(
        [[[[False, False, False], [False, True, False], [False, True, True]]]]
    )
    float_mask = self._make_additive_mask(keep_mask, q.dtype)

    def reconstruct(mask):
        # Clone inputs each call so the two runs stay independent.
        return attention_bridge._reconstruct_attention(
            q.clone(), k.clone(), v.clone(), attention_mask=mask
        )

    output_from_bool, pattern_from_bool = reconstruct(keep_mask)
    output_from_float, pattern_from_float = reconstruct(float_mask)

    assert torch.isfinite(output_from_bool).all()
    assert torch.isfinite(pattern_from_bool).all()
    assert torch.allclose(output_from_bool, output_from_float)
    assert torch.allclose(pattern_from_bool, pattern_from_float)

def test_rotary_reconstruct_attention_boolean_mask_matches_additive_mask(self):
    """Rotary joint-QKV attention should treat boolean and additive masks identically."""

    from transformer_lens.model_bridge.generalized_components.joint_qkv_position_embeddings_attention import (
        JointQKVPositionEmbeddingsAttentionBridge,
    )

    class TestConfig:
        n_heads = 2
        d_model = 4

    class MockOriginalAttention(torch.nn.Module):
        # Minimal stand-in: only the attribute the bridge touches.
        def __init__(self):
            super().__init__()
            self.attn_dropout = torch.nn.Identity()

    rotary_bridge = JointQKVPositionEmbeddingsAttentionBridge(name="qkv", config=TestConfig())
    rotary_bridge.add_module("_original_component", MockOriginalAttention())

    q, k, v = self._make_reconstruct_attention_qkv()
    keep_mask = torch.tensor(
        [[[[False, False, False], [False, True, False], [False, True, True]]]]
    )
    float_mask = self._make_additive_mask(keep_mask, q.dtype)
    # cos=1, sin=0 — presumably the identity rotary rotation; confirm
    # against the bridge's rotary application if semantics change.
    rotary_embeddings = (
        torch.ones(1, 3, 2, dtype=torch.float32),
        torch.zeros(1, 3, 2, dtype=torch.float32),
    )

    def reconstruct(mask):
        # Clone inputs each call so the two runs stay independent.
        return rotary_bridge._reconstruct_attention(
            q.clone(),
            k.clone(),
            v.clone(),
            attention_mask=mask,
            position_embeddings=rotary_embeddings,
        )

    output_from_bool, pattern_from_bool = reconstruct(keep_mask)
    output_from_float, pattern_from_float = reconstruct(float_mask)

    assert torch.isfinite(output_from_bool).all()
    assert torch.isfinite(pattern_from_bool).all()
    assert torch.allclose(output_from_bool, output_from_float)
    assert torch.allclose(pattern_from_bool, pattern_from_float)

def test_reconstruct_attention_non_4d_mask_preserves_causality(self):
    """Non-4D masks should still receive the local causal mask in the base bridge."""

    class TestConfig:
        n_heads = 2
        d_model = 4

    class StubOriginalAttention(torch.nn.Module):
        # Minimal stand-in: only the attribute the bridge touches.
        def __init__(self):
            super().__init__()
            self.attn_dropout = torch.nn.Identity()

    base_bridge = JointQKVAttentionBridge(name="qkv", config=TestConfig())
    base_bridge.add_module("_original_component", StubOriginalAttention())

    self._assert_non_4d_mask_preserves_causality(base_bridge)

def test_rotary_reconstruct_attention_non_4d_mask_preserves_causality(self):
    """Rotary joint-QKV attention should match base masking semantics for non-4D masks."""

    from transformer_lens.model_bridge.generalized_components.joint_qkv_position_embeddings_attention import (
        JointQKVPositionEmbeddingsAttentionBridge,
    )

    class TestConfig:
        n_heads = 2
        d_model = 4

    class StubOriginalAttention(torch.nn.Module):
        # Minimal stand-in: only the attribute the bridge touches.
        def __init__(self):
            super().__init__()
            self.attn_dropout = torch.nn.Identity()

    rotary_bridge = JointQKVPositionEmbeddingsAttentionBridge(name="qkv", config=TestConfig())
    rotary_bridge.add_module("_original_component", StubOriginalAttention())

    # cos=1, sin=0 — presumably the identity rotary rotation; confirm
    # against the bridge's rotary application if semantics change.
    identity_rotary = (
        torch.ones(1, 3, 2, dtype=torch.float32),
        torch.zeros(1, 3, 2, dtype=torch.float32),
    )

    self._assert_non_4d_mask_preserves_causality(
        rotary_bridge,
        position_embeddings=identity_rotary,
    )
Loading
Loading