Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions transformer_lens/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
utilities,
)
from . import loading_from_pretrained as loading
from . import supported_models
from .ActivationCache import ActivationCache
from .BertNextSentencePrediction import BertNextSentencePrediction
from .cache.key_value_cache import TransformerLensKeyValueCache
Expand Down
40 changes: 18 additions & 22 deletions transformer_lens/hook_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,31 +243,27 @@ def full_hook(
pt_handle = self.register_forward_hook(full_hook, prepend=prepend)
visible_hooks = self.fwd_hooks
elif dir == "bwd":
# Use tensor-level grad hooks instead of register_full_backward_hook
# to avoid BackwardHookFunctionBackward views that break downstream
# in-place ops (e.g. OLMo's query_states.clamp_()).
def _bwd_via_tensor_hook(
_module: torch.nn.Module,
_input: Any,
output: Any,
) -> None:
if isinstance(output, Tensor) and output.requires_grad:

def _grad_hook(grad: Tensor) -> Any:
result = full_hook(_module, _input, (grad,))
# full_hook may return a tuple (register_full_backward_hook
# convention) but tensor hooks expect Tensor or None.
if isinstance(result, tuple):
return result[0]
return result

output.register_hook(_grad_hook)
# register_full_backward_hook signature:
# hook(module, grad_input, grad_output) -> tuple(Tensor) | None
# The return value replaces grad_input. full_hook returns a bare
# Tensor (or None), so we wrap it in a tuple for PyTorch.
def _bwd_hook_wrapper(
module: torch.nn.Module,
grad_input: Any,
grad_output: Any,
):
result = full_hook(module, grad_input, grad_output)
if result is None:
return None
if isinstance(result, tuple):
return result
return (result,)

if isinstance(hook, partial):
_bwd_via_tensor_hook.__name__ = f"partial({hook.func.__repr__()},...)"
_bwd_hook_wrapper.__name__ = f"partial({hook.func.__repr__()},...)"
else:
_bwd_via_tensor_hook.__name__ = hook.__repr__()
pt_handle = self.register_forward_hook(_bwd_via_tensor_hook, prepend=prepend)
_bwd_hook_wrapper.__name__ = hook.__repr__()
pt_handle = self.register_full_backward_hook(_bwd_hook_wrapper, prepend=prepend)
visible_hooks = self.bwd_hooks
else:
raise ValueError(f"Invalid direction {dir}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@
from transformer_lens.model_bridge.supported_architectures.gpt2 import (
GPT2ArchitectureAdapter,
)
from transformer_lens.model_bridge.supported_architectures.gpt_oss import GPTOSSArchitectureAdapter
from transformer_lens.model_bridge.supported_architectures.gpt2_lm_head_custom import (
Gpt2LmHeadCustomArchitectureAdapter,
)
from transformer_lens.model_bridge.supported_architectures.gpt_oss import (
GPTOSSArchitectureAdapter,
)
from transformer_lens.model_bridge.supported_architectures.gptj import (
GptjArchitectureAdapter,
)
Expand Down Expand Up @@ -102,6 +104,7 @@
"Gemma2ArchitectureAdapter",
"Gemma3ArchitectureAdapter",
"GPT2ArchitectureAdapter",
"GPTOSSArchitectureAdapter",
"Gpt2LmHeadCustomArchitectureAdapter",
"GptjArchitectureAdapter",
"LlamaArchitectureAdapter",
Expand Down
70 changes: 70 additions & 0 deletions transformer_lens/model_bridge/supported_architectures/olmo.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""OLMo architecture adapter."""

import logging
from typing import Any

from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
Expand Down Expand Up @@ -141,6 +142,20 @@ def __init__(self, cfg: Any) -> None:
"unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
}

def prepare_model(self, hf_model: Any) -> None:
    """Neutralize OLMo's in-place clamp_ before hooks are attached.

    OLMo v1 applies ``query_states.clamp_()`` whenever ``config.clip_qkv``
    is set. Tensors routed through ``register_full_backward_hook`` become
    autograd views, and PyTorch rejects in-place modification of such
    views, so the clamp branch must be disabled around attention forwards.

    NOTE: the patched forward skips clip_qkv clamping entirely. Typical
    clip_qkv values (100+) rarely activate in practice; if exact clamping
    is required, add out-of-place clamp hooks on hook_q/hook_k/hook_v.
    """
    _patch_olmo_inplace_clamp(hf_model)

def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
"""Set up rotary embedding references for OLMo component testing.

Expand Down Expand Up @@ -172,3 +187,58 @@ def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> No
# Also set on the template for get_generalized_component() calls
attn_bridge = self.get_generalized_component("blocks.0.attn")
attn_bridge.set_rotary_emb(rotary_emb)


def _patch_olmo_inplace_clamp(hf_model: Any) -> None:
"""Patch OLMo attention to avoid in-place clamp_ that conflicts with backward hooks.

PyTorch's register_full_backward_hook wraps module outputs in
BackwardHookFunctionBackward views. OLMo's attention does
query_states.clamp_() on tensors derived from those views, which
PyTorch forbids.

Fix: wrap each attention layer's forward to temporarily clear
config.clip_qkv (preventing the in-place branch) and apply
out-of-place clamping via a forward hook instead.
"""
if not hasattr(hf_model, "model") or not hasattr(hf_model.model, "layers"):
return

clip_qkv = getattr(hf_model.config, "clip_qkv", None)
if clip_qkv is None:
return

import functools

patched = 0
for layer in hf_model.model.layers:
attn = getattr(layer, "self_attn", None)
if attn is None:
continue

original_forward = attn.forward

def _make_patched_forward(orig_fwd, clip_val=clip_qkv):
@functools.wraps(orig_fwd)
def patched_forward(*args, **kwargs):
# Temporarily disable clip_qkv so HF's in-place clamp_ is skipped
cfg = hf_model.config
saved = cfg.clip_qkv
cfg.clip_qkv = None
try:
return orig_fwd(*args, **kwargs)
finally:
cfg.clip_qkv = saved

return patched_forward

attn.forward = _make_patched_forward(original_forward)
patched += 1

if patched > 0:
logging.info(
"Patched %d OLMo attention layer(s): disabled in-place clamp_ "
"(clip_qkv=%.1f) for backward hook compatibility.",
patched,
clip_qkv,
)
11 changes: 11 additions & 0 deletions transformer_lens/model_bridge/supported_architectures/olmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,17 @@ def __init__(self, cfg: Any) -> None:
"unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
}

def prepare_model(self, hf_model: Any) -> None:
    """Disable OLMoE's in-place clamp_ so backward hooks work.

    OLMoE shares OLMo v1's clip_qkv clamp_ conflict with
    register_full_backward_hook; delegate to the shared patcher
    (see OlmoArchitectureAdapter.prepare_model).
    """
    # Function-scope import — presumably to avoid a module-level import
    # cycle between the olmo and olmoe adapter modules; confirm.
    from transformer_lens.model_bridge.supported_architectures.olmo import (
        _patch_olmo_inplace_clamp,
    )

    _patch_olmo_inplace_clamp(hf_model)

def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
"""Set up rotary embedding references for OLMoE component testing.

Expand Down
Loading
Loading