diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py
index 02ee73807..6c48352e2 100644
--- a/tests/unit/test_utils.py
+++ b/tests/unit/test_utils.py
@@ -437,6 +437,53 @@ def test_init_xavier_normal(self, d_model, d_mlp):
         assert torch.allclose(x_new, x, rtol=1e-2)
 
 
+class TestTokenizeAndConcatenate:
+    """Tests for tokenize_and_concatenate utility function."""
+
+    def test_no_split_tokens_across_chunks(self):
+        """
+        Regression test for issue #1133.
+        tokenize_and_concatenate previously split text into chunks by character
+        count, which could cut words in half and produce token pairs that would
+        never occur in naturally tokenized text. This test verifies that all
+        tokens in the output also appear consecutively in a clean tokenization
+        of the same text, confirming no artificial token pairs were introduced.
+        """
+        from datasets import Dataset
+        from transformers import AutoTokenizer
+
+        tokenizer = AutoTokenizer.from_pretrained("gpt2")
+
+        # Construct text where character-based splitting would cut mid-word.
+        # Repeating a long word many times ensures a chunk boundary falls inside it.
+        text = "Military " * 500
+        dataset = Dataset.from_dict({"text": [text]})
+
+        result = utils.tokenize_and_concatenate(
+            dataset,
+            tokenizer,
+            streaming=False,
+            max_length=64,
+            add_bos_token=False,
+        )
+
+        # Tokenize the same text cleanly in one shot (no chunking)
+        clean_tokens = tokenizer(text, return_tensors="np")["input_ids"].flatten()
+
+        # Build a set of all consecutive pairs from the clean tokenization
+        clean_pairs = set(zip(clean_tokens[:-1], clean_tokens[1:]))
+
+        # Every consecutive pair in our output must exist in the clean pairs
+        output_tokens = result["tokens"].numpy().flatten()
+        for i in range(len(output_tokens) - 1):
+            pair = (output_tokens[i], output_tokens[i + 1])
+            assert pair in clean_pairs, (
+                f"Token pair {pair} found in tokenize_and_concatenate output "
+                f"but never occurs in natural tokenization. "
+                f"This indicates a word was split across chunk boundaries."
+            )
+
+
 def test_tokenize_and_concatenate_no_spurious_sequence_length_warning():
     """Test that tokenize_and_concatenate does not emit the HF 'sequence length longer than maximum' warning."""
     from datasets import Dataset
diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py
index 8bf112a20..92fa17368 100644
--- a/transformer_lens/utils.py
+++ b/transformer_lens/utils.py
@@ -382,12 +382,25 @@ def tokenize_function(examples: Any) -> dict[str, np.ndarray]:
         if not full_text.strip():
             return {"tokens": np.array([], dtype=np.int64)}
 
-        # Divide into 20 chunks of ~ equal length
+        # Divide into 20 chunks of ~ equal length, splitting at whitespace
+        # boundaries to avoid cutting words in half (which creates token pairs
+        # that would never occur in naturally tokenized text - see issue #1133)
         num_chunks = 20
         chunk_length = (len(full_text) - 1) // num_chunks + 1
-        chunks = [
-            full_text[i * chunk_length : (i + 1) * chunk_length] for i in range(num_chunks)
-        ]
+        chunks = []
+        start = 0
+        lookahead = chunk_length // 10
+        for i in range(num_chunks):
+            end = min(start + chunk_length, len(full_text))
+            # Advance end to the next whitespace boundary to avoid splitting mid-token.
+            # Lookahead is bounded so pathological inputs (e.g. no whitespace) degrade
+            # gracefully to character-based splitting rather than consuming the rest of
+            # the string.
+            boundary = min(end + lookahead, len(full_text))
+            while end < boundary and not full_text[end].isspace():
+                end += 1
+            chunks.append(full_text[start:end])
+            start = end
         # Tokenize the chunks in parallel. Uses NumPy because HuggingFace map doesn't want tensors returned
         tokens = tokenizer(chunks, return_tensors="np", padding=True)["input_ids"].flatten()
         # Drop padding tokens