Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,53 @@ def test_init_xavier_normal(self, d_model, d_mlp):
assert torch.allclose(x_new, x, rtol=1e-2)


class TestTokenizeAndConcatenate:
    """Tests for tokenize_and_concatenate utility function."""

    def test_no_split_tokens_across_chunks(self):
        """
        Regression test for issue #1133.

        The old implementation of tokenize_and_concatenate chunked text by raw
        character count, so a chunk boundary could land mid-word and produce
        adjacent tokens that a single-pass tokenization of the same text would
        never emit. Here we tokenize the text once without chunking, collect
        every consecutive token pair from that reference, and require that each
        consecutive pair in the chunked output also appears in the reference —
        i.e. chunking introduced no artificial token pairs.
        """
        from datasets import Dataset
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained("gpt2")

        # Repeat one long word so that, under naive character-based chunking,
        # at least one chunk boundary is guaranteed to fall inside a word.
        text = "Military " * 500
        dataset = Dataset.from_dict({"text": [text]})

        result = utils.tokenize_and_concatenate(
            dataset,
            tokenizer,
            streaming=False,
            max_length=64,
            add_bos_token=False,
        )

        # Reference: tokenize the identical text in a single pass (no chunking).
        reference = tokenizer(text, return_tensors="np")["input_ids"].flatten()

        # All consecutive token pairs that occur naturally in the reference.
        valid_pairs = {(a, b) for a, b in zip(reference[:-1], reference[1:])}

        # Each adjacent pair produced by the chunked path must be a natural pair.
        produced = result["tokens"].numpy().flatten()
        for pair in zip(produced[:-1], produced[1:]):
            assert pair in valid_pairs, (
                f"Token pair {pair} found in tokenize_and_concatenate output "
                f"but never occurs in natural tokenization. "
                f"This indicates a word was split across chunk boundaries."
            )


def test_tokenize_and_concatenate_no_spurious_sequence_length_warning():
"""Test that tokenize_and_concatenate does not emit the HF 'sequence length longer than maximum' warning."""
from datasets import Dataset
Expand Down
21 changes: 17 additions & 4 deletions transformer_lens/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,12 +382,25 @@ def tokenize_function(examples: Any) -> dict[str, np.ndarray]:
if not full_text.strip():
return {"tokens": np.array([], dtype=np.int64)}

# Divide into 20 chunks of ~ equal length
# Divide into 20 chunks of ~ equal length, splitting at whitespace
# boundaries to avoid cutting words in half (which creates token pairs
# that would never occur in naturally tokenized text - see issue #1133)
num_chunks = 20
chunk_length = (len(full_text) - 1) // num_chunks + 1
chunks = [
full_text[i * chunk_length : (i + 1) * chunk_length] for i in range(num_chunks)
]
chunks = []
start = 0
lookahead = chunk_length // 10
for i in range(num_chunks):
end = min(start + chunk_length, len(full_text))
# Advance end to the next whitespace boundary to avoid splitting mid-token.
# Lookahead is bounded so pathological inputs (e.g. no whitespace) degrade
# gracefully to character-based splitting rather than consuming the rest of
# the string.
boundary = min(end + lookahead, len(full_text))
while end < boundary and not full_text[end].isspace():
end += 1
chunks.append(full_text[start:end])
start = end
# Tokenize the chunks in parallel. Uses NumPy because HuggingFace map doesn't want tensors returned
tokens = tokenizer(chunks, return_tensors="np", padding=True)["input_ids"].flatten()
# Drop padding tokens
Expand Down
Loading